from typing import Optional, Tuple, Union, List, DefaultDict, TYPE_CHECKING
from collections import defaultdict
import logging
import networkx
from ailment import Block, AILBlockWalkerBase
from ailment.statement import ConditionalJump
from ailment.expression import Expression, BinaryOp, Const, Load
from angr.utils.graph import GraphUtils
from ..utils import first_nonlabel_statement, remove_last_statement
from ..structuring.structurer_nodes import IncompleteSwitchCaseHeadStatement, SequenceNode, MultiNode
from .optimization_pass import OptimizationPass, OptimizationPassStage, MultipleBlocksException
if TYPE_CHECKING:
    # Imported for static type checking only; these names are used solely in
    # string annotations on the expression-handler methods below.
    from ailment.expression import UnaryOp, Convert

# Module-level logger, named after this module.
_l = logging.getLogger(name=__name__)
class Case:
    """
    Describes a case in a switch-case construct.

    Two cases compare equal (and hash identically) when their originating node, node type, variable hash, value,
    target, and next address all match. The compared expression ``expr`` is deliberately *not* part of the
    identity.
    """

    __slots__ = (
        "original_node",
        "node_type",
        "variable_hash",
        "expr",
        "value",
        "target",
        "next_addr",
    )

    def __init__(
        self, original_node, node_type: Optional[str], variable_hash, expr, value: Union[int, str], target, next_addr
    ):
        """
        :param original_node:   The node that the comparison was found in.
        :param node_type:       The kind of comparison node ("a", "b"), or None for a default case.
        :param variable_hash:   Stable hash of the expression that is being switched on.
        :param expr:            The expression that is being switched on (None for some default cases).
        :param value:           The constant that the expression is compared against, or the string "default".
        :param target:          Address of the node to jump to when this case matches.
        :param next_addr:       Address of the next comparison node, if any.
        """
        self.original_node = original_node
        self.node_type = node_type
        self.variable_hash = variable_hash
        self.expr = expr
        self.value = value
        self.target = target
        self.next_addr = next_addr

    def _key(self):
        # The fields that define the identity of a case; shared by __eq__ and __hash__ so that they always agree.
        return (
            self.original_node,
            self.node_type,
            self.variable_hash,
            self.value,
            self.target,
            self.next_addr,
        )

    def __repr__(self):
        if self.value == "default":
            return f"Case default@{self.target:#x}"
        return f"Case {self.original_node!r}@{self.target:#x}: {self.expr} == {self.value}"

    def __eq__(self, other):
        if not isinstance(other, Case):
            # let Python try the reflected comparison; `case == foreign` still evaluates to False in the end
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash((Case,) + self._key())
class StableVarExprHasher(AILBlockWalkerBase):
    """
    Obtain a stable hash of an AIL expression with respect to all variables and all operations applied on variables.

    The expression is walked once during construction. Each handler appends a token to an internal list, and the
    final hash (available as ``self.hash``) is the hash of that list converted to a tuple.
    """

    def __init__(self, expr: Expression):
        """
        :param expr:    The AIL expression to hash.
        """
        super().__init__()
        self.expr = expr
        self._hash_lst = []
        self.walk_expression(expr)
        self.hash = hash(tuple(self._hash_lst))

    def _handle_expr(self, expr_idx: int, expr: Expression, stmt_idx: int, stmt, block: Optional[Block]):
        # An expression that exposes a `variable` attribute is represented by that variable alone and is not
        # walked any further; everything else goes through the base walker.
        if hasattr(expr, "variable"):
            self._hash_lst.append(expr.variable)
        else:
            super()._handle_expr(expr_idx, expr, stmt_idx, stmt, block)

    def _handle_Load(self, expr_idx: int, expr: Load, stmt_idx: int, stmt, block: Optional[Block]):
        self._hash_lst.append("Load")
        # Delegate to the base Load handler, the same way the other type-specific overrides in this class
        # delegate to the base handler for their own expression type.
        super()._handle_Load(expr_idx, expr, stmt_idx, stmt, block)

    def _handle_BinaryOp(self, expr_idx: int, expr: BinaryOp, stmt_idx: int, stmt, block: Optional[Block]):
        self._hash_lst.append(expr.op)
        super()._handle_BinaryOp(expr_idx, expr, stmt_idx, stmt, block)

    def _handle_UnaryOp(self, expr_idx: int, expr: "UnaryOp", stmt_idx: int, stmt, block: Optional[Block]):
        self._hash_lst.append(expr.op)
        super()._handle_UnaryOp(expr_idx, expr, stmt_idx, stmt, block)

    def _handle_Const(self, expr_idx: int, expr: Const, stmt_idx: int, stmt, block: Optional[Block]):
        # constants contribute both their value and their bit width
        self._hash_lst.append((expr.value, expr.bits))

    def _handle_Convert(self, expr_idx: int, expr: "Convert", stmt_idx: int, stmt, block: Optional[Block]):
        self._hash_lst.append(expr.to_bits)
        super()._handle_Convert(expr_idx, expr, stmt_idx, stmt, block)
class LoweredSwitchSimplifier(OptimizationPass):
    """
    Recognize and simplify lowered switch-case constructs.

    A lowered switch-case is a cascade of conditional jumps that all compare the same expression against
    constants. For every cascade that is found, the last statement of the first comparison node is replaced by an
    IncompleteSwitchCaseHeadStatement, the remaining comparison nodes are removed from the output graph, and their
    successors are attached to the new head directly. :meth:`restore_graph` re-creates the comparison chain from
    such a head statement.
    """

    ARCHES = [
        "AMD64",
    ]
    PLATFORMS = ["linux", "windows"]
    STAGE = OptimizationPassStage.BEFORE_REGION_IDENTIFICATION
    NAME = "Convert lowered switch-cases (if-else) to switch-cases"
    DESCRIPTION = (
        "Convert lowered switch-cases (if-else) to switch-cases. Only works when the Phoenix structuring "
        "algorithm is in use."
    )
    STRUCTURING = ["phoenix"]

    def __init__(self, func, blocks_by_addr=None, blocks_by_addr_and_idx=None, graph=None, **kwargs):
        super().__init__(
            func, blocks_by_addr=blocks_by_addr, blocks_by_addr_and_idx=blocks_by_addr_and_idx, graph=graph, **kwargs
        )
        self.analyze()

    def _check(self):
        # TODO: More filtering
        return True, None

    def _analyze(self, cache=None):
        """
        Rewrite every detected comparison cascade in a copy of the graph and store the copy in ``self.out_graph``.

        ``self.out_graph`` is left untouched when no cascade is found, and is reset to None when a rewritten
        cascade refers to a comparison node that is no longer in the graph copy.
        """
        variablehash_to_cases = self._find_cascading_switch_variable_comparisons()

        if not variablehash_to_cases:
            return

        graph_copy = networkx.DiGraph(self._graph)
        self.out_graph = graph_copy
        # maps each case node to the set of switch heads that now jump to it
        node_to_heads = defaultdict(set)

        for caselists in variablehash_to_cases.values():
            for cases in caselists:
                original_nodes = [case.original_node for case in cases if case.value != "default"]
                original_head: Block = original_nodes[0]
                original_nodes = original_nodes[1:]

                case_addrs = {(case.original_node, case.value, case.target, case.next_addr) for case in cases}

                expr = cases[0].expr

                # create a fake switch-case head node
                last_stmt = original_head.statements[-1]
                switch_stmt = IncompleteSwitchCaseHeadStatement(
                    last_stmt.idx, expr, case_addrs, ins_addr=last_stmt.ins_addr
                )
                new_head = original_head.copy()
                # replace the last instruction of the head node with switch_node
                new_head.statements[-1] = switch_stmt
                # update the block
                self._update_block(original_head, new_head)

                # sanity check that no switch head points to either itself
                # or to any if-head that was merged into the new switch head; this
                # would result in a successor node no longer being present in the graph
                if any(onode not in graph_copy for onode in original_nodes):
                    self.out_graph = None
                    return

                # add edges between the head and case nodes
                for onode in original_nodes:
                    for succ in list(graph_copy.successors(onode)):
                        if succ not in original_nodes:
                            graph_copy.add_edge(new_head, succ)
                            node_to_heads[succ].add(new_head)
                    graph_copy.remove_node(onode)

        # find shared case nodes and make copies of them
        # note that this only solves cases where *one* node is shared between switch-cases. a more general solution
        # requires jump threading reverter.
        for succ_node, heads in node_to_heads.items():
            if len(heads) <= 1:
                continue
            # each head gets a copy of the node!
            node_successors = list(graph_copy.successors(succ_node))
            next_id = 0 if succ_node.idx is None else succ_node.idx + 1
            graph_copy.remove_node(succ_node)
            for head in heads:
                node_copy = succ_node.copy()
                node_copy.idx = next_id
                next_id += 1
                graph_copy.add_edge(head, node_copy)
                for succ in node_successors:
                    # a self-loop on the shared node becomes a self-loop on the copy
                    graph_copy.add_edge(node_copy, node_copy if succ is succ_node else succ)

    def _find_cascading_switch_variable_comparisons(self):
        """
        Find cascades of comparison nodes that compare the same expression against constants.

        :return:    A mapping from the variable hash of the compared expression to the list of case lists that were
                    found for it. Case lists that do not pass the filters in :meth:`_is_valid_caselist` are dropped;
                    the key itself is kept even when its list ends up empty.
        """
        sorted_nodes = GraphUtils.quasi_topological_sort_nodes(self._graph)
        # node -> (node type, variable hash, expression, value, target address, next node address)
        variable_comparisons = {}
        for node in sorted_nodes:
            r = self._find_switch_variable_comparison_type_a(node)
            if r is not None:
                variable_comparisons[node] = ("a",) + r
                continue
            r = self._find_switch_variable_comparison_type_b(node)
            if r is not None:
                variable_comparisons[node] = ("b",) + r

        varhash_to_caselists: DefaultDict[int, List[List[Case]]] = defaultdict(list)

        for head in variable_comparisons:
            cases = []
            last_comp = None
            comp = head
            while True:
                comp_type, variable_hash, expr, value, target, next_addr = variable_comparisons[comp]
                last_varhash = cases[-1].variable_hash if cases else None
                if last_varhash is None or last_varhash == variable_hash:
                    if target == comp.addr:
                        # invalid
                        break
                    cases.append(Case(comp, comp_type, variable_hash, expr, value, target, next_addr))
                else:
                    # new variable!
                    if last_comp is not None:
                        cases.append(Case(last_comp, None, last_varhash, None, "default", comp.addr, None))
                    break

                if comp is not head:
                    # non-head node has at most one predecessor
                    if self._graph.in_degree[comp] > 1:
                        break

                succ_addrs = {succ.addr for succ in self._graph.successors(comp) if succ is not comp}
                if target in succ_addrs:
                    next_comp_addr = next((addr for addr in succ_addrs if addr != target), None)
                    if next_comp_addr is None:
                        break
                    try:
                        next_comp = self._get_block(next_comp_addr)
                    except MultipleBlocksException:
                        # multiple blocks :/ it's possible that other optimization passes have duplicated the default
                        # node. check it.
                        next_comp_many = list(self._get_blocks(next_comp_addr))
                        if next_comp_many[0] not in variable_comparisons:
                            cases.append(Case(comp, None, variable_hash, expr, "default", next_comp_addr, None))
                        # otherwise we don't support it
                        break
                    assert next_comp is not None
                    if next_comp in variable_comparisons:
                        # keep following the cascade
                        last_comp = comp
                        comp = next_comp
                        continue
                    cases.append(Case(comp, None, variable_hash, expr, "default", next_comp_addr, None))
                break

            if cases:
                v = cases[-1].variable_hash
                for idx, existing_cases in enumerate(varhash_to_caselists[v]):
                    if self.cases_issubset(existing_cases, cases):
                        # the new list covers an existing one: replace it
                        varhash_to_caselists[v][idx] = cases
                        break
                    if self.cases_issubset(cases, existing_cases):
                        # the new list is already covered: drop it
                        break
                else:
                    varhash_to_caselists[v].append(cases)

        for v, caselists in list(varhash_to_caselists.items()):
            varhash_to_caselists[v] = [cases for cases in caselists if self._is_valid_caselist(cases)]

        return varhash_to_caselists

    def _is_valid_caselist(self, cases: List[Case]) -> bool:
        """
        Test if a list of cases passes all filters and can be converted to a switch-case head.
        """
        # filter: there should be at least two non-default cases
        if sum(1 for case in cases if case.value != "default") < 2:
            return False

        # filter: no type-a node after the first case node
        if any(case.value != "default" and case.node_type == "a" for case in cases[1:]):
            return False

        # filter: each case is only reachable from a case node
        all_case_nodes = {case.original_node for case in cases}
        for case in cases:
            target_nodes = [succ for succ in self._graph.successors(case.original_node) if succ.addr == case.target]
            if len(target_nodes) != 1:
                return False
            target_node = target_nodes[0]
            # NOTE(review): this set was originally named `nonself_preds`, which suggests that `!=` may have been
            # intended. As written it only collects the predecessors that share the target's address. The
            # condition is kept as-is; confirm the intent before changing it.
            same_addr_preds = {pred for pred in self._graph.predecessors(target_node) if pred.addr == case.target}
            if not same_addr_preds.issubset(all_case_nodes):
                return False

        return True

    @staticmethod
    def _parse_switch_comparison(stmt) -> Optional[Tuple[int, Expression, int, int, int]]:
        """
        Parse a statement of the form `if (expr ==/!= constant) goto constant else goto constant`.

        :return:    A tuple of (variable hash, compared expression, compared value, address to jump to when the
                    expression equals the value, address to jump to otherwise), or None if the statement does not
                    have the expected form.
        """
        if not (
            isinstance(stmt, ConditionalJump)
            and isinstance(stmt.true_target, Const)
            and isinstance(stmt.false_target, Const)
        ):
            return None

        cond = stmt.condition
        if not isinstance(cond, BinaryOp) or not isinstance(cond.operands[1], Const):
            return None

        if cond.op == "CmpEQ":
            target = stmt.true_target.value
            next_node_addr = stmt.false_target.value
        elif cond.op == "CmpNE":
            target = stmt.false_target.value
            next_node_addr = stmt.true_target.value
        else:
            return None

        variable_hash = StableVarExprHasher(cond.operands[0]).hash
        return variable_hash, cond.operands[0], cond.operands[1].value, target, next_node_addr

    @staticmethod
    def _find_switch_variable_comparison_type_a(
        node,
    ) -> Optional[Tuple[int, Expression, int, int, int]]:
        # the type a is the last statement is a var == constant comparison, but
        # there is more than one non-label statement in the block
        if isinstance(node, Block) and node.statements:
            stmt = node.statements[-1]
            if stmt is not None and stmt is not first_nonlabel_statement(node):
                return LoweredSwitchSimplifier._parse_switch_comparison(stmt)
        return None

    @staticmethod
    def _find_switch_variable_comparison_type_b(
        node,
    ) -> Optional[Tuple[int, Expression, int, int, int]]:
        # the type b is the last statement is a var == constant comparison, and
        # there is only one non-label statement
        if isinstance(node, Block):
            stmt = first_nonlabel_statement(node)
            if stmt is not None and stmt is node.statements[-1]:
                return LoweredSwitchSimplifier._parse_switch_comparison(stmt)
        return None

    @staticmethod
    def restore_graph(
        node, last_stmt: IncompleteSwitchCaseHeadStatement, graph: networkx.DiGraph, full_graph: networkx.DiGraph
    ):
        """
        Re-create the chain of comparison nodes that is recorded in a switch-case head statement.

        Starting at `node`, each recorded comparison node is linked after the previous one and connected to its
        target node, and the direct edge from `node` to that target is removed. All changes are applied to both
        `graph` and `full_graph`. Finally, the last statement of `node` is removed.

        :param node:        The switch-case head node. May be a Block, a MultiNode, or a SequenceNode.
        :param last_stmt:   The switch-case head statement whose `case_addrs` describe the original cases.
        :param graph:       The graph to update.
        :param full_graph:  The full graph to update. Target nodes are looked up by address in this graph.
        """
        last_node = node
        ca_default = [ca for ca in last_stmt.case_addrs if ca[1] == "default"]
        # non-default nodes, indexed by the address of the original comparison node
        ca_others = {ca[0].addr: ca for ca in last_stmt.case_addrs if ca[1] != "default"}

        # extract the AIL block from last_node
        last_block = last_node
        if isinstance(last_block, SequenceNode):
            last_block = last_block.nodes[-1]
        if isinstance(last_block, MultiNode):
            last_block = last_block.nodes[-1]
        assert isinstance(last_block, Block)
        next_node_addr = last_block.addr

        while next_node_addr is not None and next_node_addr in ca_others:
            onode, _, target, next_node_addr = ca_others[next_node_addr]

            if first_nonlabel_statement(onode) is not onode.statements[-1]:
                # only keep the comparison itself
                onode = onode.copy(statements=[onode.statements[-1]])

            graph.add_edge(last_node, onode)
            full_graph.add_edge(last_node, onode)

            target_node = next(nn for nn in full_graph if nn.addr == target)
            graph.add_edge(onode, target_node)
            full_graph.add_edge(onode, target_node)

            if graph.has_edge(node, target_node):
                graph.remove_edge(node, target_node)
            if full_graph.has_edge(node, target_node):
                full_graph.remove_edge(node, target_node)

            # update last_node
            last_node = onode

        # default nodes
        if ca_default:
            _, _, target, _ = ca_default[0]
            default_target = next(nn for nn in full_graph if nn.addr == target)
            graph.add_edge(last_node, default_target)
            full_graph.add_edge(last_node, default_target)
            if graph.has_edge(node, default_target):
                graph.remove_edge(node, default_target)
            if full_graph.has_edge(node, default_target):
                full_graph.remove_edge(node, default_target)

        # all good - remove the last statement in node
        remove_last_statement(node)

    @staticmethod
    def cases_issubset(cases_0: List[Case], cases_1: List[Case]) -> bool:
        """
        Test if cases_0 is a subset of cases_1.
        """
        if len(cases_0) > len(cases_1):
            return False
        return all(case in cases_1 for case in cases_0)