# Source code for angr.analyses.decompiler.optimization_passes.lowered_switch_simplifier

from typing import Optional, Tuple, Union, List, DefaultDict, TYPE_CHECKING
from collections import defaultdict
import logging

import networkx

from ailment import Block
from ailment.statement import ConditionalJump
from ailment.expression import Expression, BinaryOp, Const, Load

from ...cfg.cfg_utils import CFGUtils
from ..utils import first_nonlabel_statement, remove_last_statement
from ..structuring.structurer_nodes import IncompleteSwitchCaseHeadStatement, SequenceNode, MultiNode
from ..ailblock_walker import AILBlockWalkerBase
from .optimization_pass import OptimizationPass, OptimizationPassStage, MultipleBlocksException

if TYPE_CHECKING:
    from ailment.expression import UnaryOp, Convert

_l = logging.getLogger(name=__name__)


class Case:
    """
    Describes a case in a switch-case construct.

    :ivar original_node:    The node that the case was recovered from.
    :ivar node_type:        A tag describing the kind of comparison node ("a" or "b" in this module), or None for
                            cases that are not backed by a comparison of their own (e.g., default cases).
    :ivar variable_hash:    Hash of the expression that is being switched on.
    :ivar expr:             The expression that is being switched on. It is deliberately ignored by ``__eq__`` and
                            ``__hash__``; ``variable_hash`` stands in for it.
    :ivar value:            The constant that the expression is compared against, or the string "default".
    :ivar target:           Address of the node that is executed when this case is taken.
    :ivar next_addr:        Address of the node that is executed when this case is not taken, or None.
    """

    __slots__ = (
        "original_node",
        "node_type",
        "variable_hash",
        "expr",
        "value",
        "target",
        "next_addr",
    )

    def __init__(
        self, original_node, node_type: Optional[str], variable_hash, expr, value: Union[int, str], target, next_addr
    ):
        self.original_node = original_node
        self.node_type = node_type
        self.variable_hash = variable_hash
        self.expr = expr
        self.value = value
        self.target = target
        self.next_addr = next_addr

    def __repr__(self):
        if self.value == "default":
            return f"Case default@{self.target:#x}"
        return f"Case {repr(self.original_node)}@{self.target:#x}: {self.expr} == {self.value}"

    def __eq__(self, other):
        if not isinstance(other, Case):
            return False
        # expr is intentionally left out of the comparison; keep this field list in sync with __hash__
        return (
            self.original_node == other.original_node
            and self.node_type == other.node_type
            and self.variable_hash == other.variable_hash
            and self.value == other.value
            and self.target == other.target
            and self.next_addr == other.next_addr
        )

    def __hash__(self):
        # hashes exactly the fields that __eq__ compares
        return hash(
            (Case, self.original_node, self.node_type, self.variable_hash, self.value, self.target, self.next_addr)
        )
class StableVarExprHasher(AILBlockWalkerBase):
    """
    Obtain a stable hash of an AIL expression with respect to all variables and all operations applied on variables.

    The expression is walked once in the constructor. Every handler appends a token to an internal list, and the
    final hash is the hash of that list as a tuple. Because Python salts ``str`` hashes per process, the resulting
    values should only be compared with other hashes computed in the same interpreter run.

    :ivar expr: The expression that was hashed.
    :ivar hash: The resulting hash value.
    """

    def __init__(self, expr: Expression):
        super().__init__()
        self.expr = expr
        self._hash_lst = []

        self.walk_expression(expr)
        self.hash = hash(tuple(self._hash_lst))

    def _handle_expr(self, expr_idx: int, expr: Expression, stmt_idx: int, stmt, block: Optional[Block]):
        # any expression that carries a `variable` attribute is represented by that variable alone; its
        # sub-expressions are not visited
        if hasattr(expr, "variable"):
            self._hash_lst.append(expr.variable)
        else:
            super()._handle_expr(expr_idx, expr, stmt_idx, stmt, block)

    def _handle_Load(self, expr_idx: int, expr: Load, stmt_idx: int, stmt, block: Optional[Block]):
        self._hash_lst.append("Load")
        # delegate to the base Load handler, like every other handler in this class does for its own expression
        # type. going through the generic _handle_expr here would hand the same Load expression back to the
        # type-based dispatcher instead of descending into it.
        super()._handle_Load(expr_idx, expr, stmt_idx, stmt, block)

    def _handle_BinaryOp(self, expr_idx: int, expr: BinaryOp, stmt_idx: int, stmt, block: Optional[Block]):
        self._hash_lst.append(expr.op)
        super()._handle_BinaryOp(expr_idx, expr, stmt_idx, stmt, block)

    def _handle_UnaryOp(self, expr_idx: int, expr: "UnaryOp", stmt_idx: int, stmt, block: Optional[Block]):
        self._hash_lst.append(expr.op)
        super()._handle_UnaryOp(expr_idx, expr, stmt_idx, stmt, block)

    def _handle_Const(self, expr_idx: int, expr: Const, stmt_idx: int, stmt, block: Optional[Block]):
        # constants are leaves: record value and width, nothing to descend into
        self._hash_lst.append((expr.value, expr.bits))

    def _handle_Convert(self, expr_idx: int, expr: "Convert", stmt_idx: int, stmt, block: Optional[Block]):
        self._hash_lst.append(expr.to_bits)
        super()._handle_Convert(expr_idx, expr, stmt_idx, stmt, block)
class LoweredSwitchSimplifier(OptimizationPass):
    """
    Recognize and simplify lowered switch-case constructs.

    A lowered switch-case is a cascade of blocks that each end with a conditional jump comparing the same expression
    against a constant (``expr == C`` or ``expr != C``). For every such cascade, the last statement of the first
    block is replaced by an IncompleteSwitchCaseHeadStatement, the remaining comparison blocks are removed from the
    graph, and their successors are attached directly to the new head.
    """

    ARCHES = [
        "AMD64",
    ]
    PLATFORMS = ["linux", "windows"]
    STAGE = OptimizationPassStage.BEFORE_REGION_IDENTIFICATION
    NAME = "Convert lowered switch-cases (if-else) to switch-cases"
    DESCRIPTION = (
        "Convert lowered switch-cases (if-else) to switch-cases. Only works when the Phoenix structuring "
        "algorithm is in use."
    )
    STRUCTURING = ["phoenix"]

    def __init__(self, func, blocks_by_addr=None, blocks_by_addr_and_idx=None, graph=None, **kwargs):
        super().__init__(
            func, blocks_by_addr=blocks_by_addr, blocks_by_addr_and_idx=blocks_by_addr_and_idx, graph=graph, **kwargs
        )
        self.analyze()

    def _check(self):
        # TODO: More filtering
        return True, None

    def _analyze(self, cache=None):
        """
        Rewrite every detected comparison cascade in a copy of the graph and publish the copy as ``self.out_graph``.
        Nothing is changed (and ``out_graph`` is not set) when no candidate is found.
        """

        variablehash_to_cases = self._find_cascading_switch_variable_comparisons()

        if not variablehash_to_cases:
            return

        graph_copy = networkx.DiGraph(self._graph)
        self.out_graph = graph_copy
        # successor node -> all switch heads that it has been attached to
        node_to_heads = defaultdict(set)

        for caselists in variablehash_to_cases.values():
            for cases in caselists:
                original_nodes = [case.original_node for case in cases if case.value != "default"]
                original_head: Block = original_nodes[0]
                original_nodes = original_nodes[1:]

                case_addrs = {(case.original_node, case.value, case.target, case.next_addr) for case in cases}

                expr = cases[0].expr

                # create a fake switch-case head node
                switch_stmt = IncompleteSwitchCaseHeadStatement(
                    original_head.statements[-1].idx, expr, case_addrs, ins_addr=original_head.statements[-1].ins_addr
                )
                new_head = original_head.copy()
                # replace the last instruction of the head node with switch_node
                new_head.statements[-1] = switch_stmt
                # update the block
                self._update_block(original_head, new_head)

                # add edges between the head and case nodes
                for onode in original_nodes:
                    successors = list(graph_copy.successors(onode))
                    for succ in successors:
                        if succ not in original_nodes:
                            graph_copy.add_edge(new_head, succ)
                            node_to_heads[succ].add(new_head)
                    graph_copy.remove_node(onode)

        # find shared case nodes and make copies of them
        # note that this only solves cases where *one* node is shared between switch-cases. a more general solution
        # requires jump threading reverter.
        for succ_node, heads in node_to_heads.items():
            if len(heads) > 1:
                # each head gets a copy of the node!
                node_successors = list(graph_copy.successors(succ_node))
                next_id = 0 if succ_node.idx is None else succ_node.idx + 1
                graph_copy.remove_node(succ_node)
                for head in heads:
                    node_copy = succ_node.copy()
                    node_copy.idx = next_id
                    next_id += 1
                    graph_copy.add_edge(head, node_copy)
                    for succ in node_successors:
                        if succ is succ_node:
                            # preserve self-loops on the copy
                            graph_copy.add_edge(node_copy, node_copy)
                        else:
                            graph_copy.add_edge(node_copy, succ)

    def _find_cascading_switch_variable_comparisons(self):
        """
        Find cascades of comparison nodes that compare the same expression against constants.

        :return:    A mapping from the hash of the compared expression to a list of case lists. Each case list
                    describes one candidate switch-case construct. A key may map to an empty list if all of its
                    candidates were filtered out.
        """

        sorted_nodes = CFGUtils.quasi_topological_sort_nodes(self._graph)

        # node -> (node type, variable hash, expression, value, target address, next address)
        variable_comparisons = {}
        for node in sorted_nodes:
            r = self._find_switch_variable_comparison_type_a(node)
            if r is not None:
                variable_comparisons[node] = ("a",) + r
                continue
            r = self._find_switch_variable_comparison_type_b(node)
            if r is not None:
                variable_comparisons[node] = ("b",) + r

        varhash_to_caselists: DefaultDict[int, List[List[Case]]] = defaultdict(list)

        # starting from each comparison node, follow the "not taken" edges for as long as the next node is a
        # comparison on the same expression
        for head in variable_comparisons:
            cases = []
            last_comp = None
            comp = head
            while True:
                comp_type, variable_hash, expr, value, target, next_addr = variable_comparisons[comp]
                last_varhash = cases[-1].variable_hash if cases else None
                if last_varhash is None or last_varhash == variable_hash:
                    if target == comp.addr:
                        # invalid
                        break
                    cases.append(Case(comp, comp_type, variable_hash, expr, value, target, next_addr))
                else:
                    # new variable!
                    if last_comp is not None:
                        cases.append(Case(last_comp, None, last_varhash, None, "default", comp.addr, None))
                    break

                if comp is not head and self._graph.in_degree[comp] > 1:
                    # non-head node has at most one predecessor
                    break

                successors = [succ for succ in self._graph.successors(comp) if succ is not comp]
                succ_addrs = {succ.addr for succ in successors}
                if target in succ_addrs:
                    next_comp_addr = next((succ_addr for succ_addr in succ_addrs if succ_addr != target), None)
                    if next_comp_addr is None:
                        break
                    try:
                        next_comp = self._get_block(next_comp_addr)
                    except MultipleBlocksException:
                        # multiple blocks :/ it's possible that other optimization passes have duplicated the default
                        # node. check it.
                        next_comp_many = list(self._get_blocks(next_comp_addr))
                        if next_comp_many[0] not in variable_comparisons:
                            cases.append(Case(comp, None, variable_hash, expr, "default", next_comp_addr, None))
                        # otherwise we don't support it
                        break
                    assert next_comp is not None
                    if next_comp in variable_comparisons:
                        last_comp = comp
                        comp = next_comp
                        continue
                    # the next node is not a comparison: it is the default case
                    cases.append(Case(comp, None, variable_hash, expr, "default", next_comp_addr, None))
                break

            if cases:
                # keep only the longest chain: replace an existing chain that is contained in the new one, and drop
                # the new one if it is contained in an existing chain
                v = cases[-1].variable_hash
                for idx, existing_cases in list(enumerate(varhash_to_caselists[v])):
                    if self.cases_issubset(existing_cases, cases):
                        varhash_to_caselists[v][idx] = cases
                        break
                    if self.cases_issubset(cases, existing_cases):
                        break
                else:
                    varhash_to_caselists[v].append(cases)

        for v, caselists in list(varhash_to_caselists.items()):
            varhash_to_caselists[v] = [cases for cases in caselists if self._is_valid_caselist(cases)]

        return varhash_to_caselists

    def _is_valid_caselist(self, cases: List[Case]) -> bool:
        """
        Test if a candidate case list passes all filters and may be converted to a switch-case construct.

        :param cases:   The candidate case list.
        :return:        True if the candidate should be kept, False if it must be dropped.
        """

        # filter: there should be at least two non-default cases
        if sum(1 for case in cases if case.value != "default") < 2:
            return False

        # filter: no type-a node after the first case node
        if any(case.value != "default" and case.node_type == "a" for case in cases[1:]):
            return False

        # filter: each case is only reachable from a case node
        all_case_nodes = {case.original_node for case in cases}
        for case in cases:
            target_nodes = [succ for succ in self._graph.successors(case.original_node) if succ.addr == case.target]
            if len(target_nodes) != 1:
                return False
            target_node = target_nodes[0]
            # NOTE(review): the name says "non-self", but the condition keeps only predecessors whose address
            # *equals* the target address. the behavior is kept as is - confirm whether `!=` was intended.
            nonself_preds = {pred for pred in self._graph.predecessors(target_node) if pred.addr == case.target}
            if not nonself_preds.issubset(all_case_nodes):
                return False

        return True

    @staticmethod
    def _parse_switch_comparison(stmt) -> Optional[Tuple[int, Expression, int, int, int]]:
        """
        Parse a statement of the form ``if (expr == C) goto A else goto B`` or ``if (expr != C) goto A else goto B``
        where A and B are constants.

        :param stmt:    The statement to parse.
        :return:        A tuple of (hash of expr, expr, C, address taken when expr == C, address taken otherwise), or
                        None if the statement does not have the expected form.
        """

        if not (
            isinstance(stmt, ConditionalJump)
            and isinstance(stmt.true_target, Const)
            and isinstance(stmt.false_target, Const)
        ):
            return None

        cond = stmt.condition
        if not isinstance(cond, BinaryOp) or not isinstance(cond.operands[1], Const):
            return None

        if cond.op == "CmpEQ":
            target = stmt.true_target.value
            next_node_addr = stmt.false_target.value
        elif cond.op == "CmpNE":
            target = stmt.false_target.value
            next_node_addr = stmt.true_target.value
        else:
            return None

        variable_hash = StableVarExprHasher(cond.operands[0]).hash
        value = cond.operands[1].value
        return variable_hash, cond.operands[0], value, target, next_node_addr

    @staticmethod
    def _find_switch_variable_comparison_type_a(
        node,
    ) -> Optional[Tuple[int, Expression, int, int, int]]:
        # the type a is the last statement is a var == constant comparison, but
        # there is more than one non-label statement in the block

        if isinstance(node, Block) and node.statements:
            stmt = node.statements[-1]
            if stmt is not None and stmt is not first_nonlabel_statement(node):
                return LoweredSwitchSimplifier._parse_switch_comparison(stmt)

        return None

    @staticmethod
    def _find_switch_variable_comparison_type_b(
        node,
    ) -> Optional[Tuple[int, Expression, int, int, int]]:
        # the type b is the last statement is a var == constant comparison, and
        # there is only one non-label statement

        if isinstance(node, Block):
            stmt = first_nonlabel_statement(node)
            if stmt is not None and stmt is node.statements[-1]:
                return LoweredSwitchSimplifier._parse_switch_comparison(stmt)

        return None

    @staticmethod
    def restore_graph(
        node, last_stmt: IncompleteSwitchCaseHeadStatement, graph: networkx.DiGraph, full_graph: networkx.DiGraph
    ):
        """
        Re-create the cascade of comparison nodes that is recorded in an incomplete switch-case head statement.

        Starting from `node`, the recorded non-default cases are followed through their next-address links. Each
        comparison node is chained after the previous one and connected to its target, and the direct edge between
        `node` and that target is removed. The default target, if there is one, is attached to the last node of the
        chain. Finally, the last statement of `node` is removed. Both graphs are modified in place.

        :param node:        The node that ends with `last_stmt`. It may be a Block, or a SequenceNode or MultiNode
                            whose last node leads to a Block.
        :param last_stmt:   The incomplete switch-case head statement that describes the cases.
        :param graph:       The graph to update.
        :param full_graph:  The full graph to update. Target nodes are looked up by address in this graph.
        :raises StopIteration:  If a recorded target address does not belong to any node in `full_graph`.
        """

        last_node = node
        ca_default = [ca for ca in last_stmt.case_addrs if ca[1] == "default"]

        # non-default nodes, indexed by the address of the node that performs the comparison
        ca_others = {ca[0].addr: ca for ca in last_stmt.case_addrs if ca[1] != "default"}

        # extract the AIL block from last_node
        last_block = last_node
        if isinstance(last_block, SequenceNode):
            last_block = last_block.nodes[-1]
        if isinstance(last_block, MultiNode):
            last_block = last_block.nodes[-1]
        assert isinstance(last_block, Block)
        next_node_addr = last_block.addr

        while next_node_addr is not None and next_node_addr in ca_others:
            onode, _, target, next_node_addr = ca_others[next_node_addr]

            # only the comparison itself goes into the chain; any statement before it stays where it is
            if first_nonlabel_statement(onode) is not onode.statements[-1]:
                onode = onode.copy(statements=[onode.statements[-1]])

            graph.add_edge(last_node, onode)
            full_graph.add_edge(last_node, onode)

            target_node = next(nn for nn in full_graph if nn.addr == target)
            graph.add_edge(onode, target_node)
            full_graph.add_edge(onode, target_node)

            if graph.has_edge(node, target_node):
                graph.remove_edge(node, target_node)
            if full_graph.has_edge(node, target_node):
                full_graph.remove_edge(node, target_node)

            # update last_node
            last_node = onode

        # default nodes
        if ca_default:
            target = ca_default[0][2]
            default_target = next(nn for nn in full_graph if nn.addr == target)
            graph.add_edge(last_node, default_target)
            full_graph.add_edge(last_node, default_target)

            if graph.has_edge(node, default_target):
                graph.remove_edge(node, default_target)
            if full_graph.has_edge(node, default_target):
                full_graph.remove_edge(node, default_target)

        # all good - remove the last statement in node
        remove_last_statement(node)

    @staticmethod
    def cases_issubset(cases_0: List[Case], cases_1: List[Case]) -> bool:
        """
        Test if cases_0 is a subset of cases_1.

        :param cases_0: The case list that is expected to be contained.
        :param cases_1: The case list that is expected to contain every case of cases_0.
        :return:        True if every case in cases_0 is equal to some case in cases_1, False otherwise.
        """

        if len(cases_0) > len(cases_1):
            return False
        return all(case in cases_1 for case in cases_0)