Source code for angr.analyses.decompiler.region_simplifiers.switch_cluster_simplifier

# pylint:disable=no-self-use,arguments-renamed
import enum
from typing import DefaultDict, Any, List, Union, Dict, Tuple, Set, Optional
from collections import OrderedDict, defaultdict

import ailment

from ....utils.constants import SWITCH_MISSING_DEFAULT_NODE_ADDR
from ..structuring.structurer_nodes import SwitchCaseNode, ConditionNode, SequenceNode, MultiNode
from ..sequence_walker import SequenceWalker


class CmpOp(enum.Enum):
    """
    All supported comparison operators.

    Only strict inequalities are represented. SwitchClusterFinder._process_condition folds the non-strict forms
    into these members by adjusting the compared constant (``x <= v`` becomes ``x < v + 1``, ``x >= v`` becomes
    ``x > v - 1``).
    """

    LT = 0  # variable < value
    GT = 1  # variable > value
    EQ = 2  # variable == value
    NE = 3  # variable != value
class ConditionalRegion:
    """
    Describes a conditional region: one comparison of a variable against a constant and the node holding it.

    :ivar variable: The variable that is being compared.
    :ivar op:       The comparison operator (annotated as a CmpOp member).
    :ivar value:    The constant the variable is compared against.
    :ivar node:     The node carrying the comparison (annotated as a ConditionNode or an ailment.Block).
    :ivar parent:   The parent node supplied by the creator of this object; None when not given.
    """

    __slots__ = (
        "variable",
        "op",
        "value",
        "node",
        "parent",
    )

    def __init__(self, variable, op: CmpOp, value: int, node: Union[ConditionNode, ailment.Block], parent=None):
        # where the comparison lives
        self.node = node
        self.parent = parent
        # what the comparison is
        self.variable = variable
        self.op = op
        self.value = value

    def __repr__(self):
        location = format(self.node.addr, "#x")
        comparison = f"{self.variable} {self.op} {self.value}"
        return f"<{comparison} @ {location}, parent node {self.parent}>"
class SwitchCaseRegion:
    """
    Describes an already-recovered switch region.

    :ivar variable: The variable that the switch dispatches on.
    :ivar node:     The recovered switch-case node (annotated as a SwitchCaseNode).
    :ivar parent:   The parent node supplied by the creator of this object; None when not given.
    """

    __slots__ = (
        "variable",
        "node",
        "parent",
    )

    def __init__(self, variable, node: SwitchCaseNode, parent=None):
        self.node = node
        self.parent = parent
        self.variable = variable

    def __repr__(self):
        case_count = len(self.node.cases)
        location = format(self.node.addr, "#x")
        return f"<{self.variable} ({case_count} cases) @ {location}, parent node {self.parent}>"
class SwitchClusterFinder(SequenceWalker):
    """
    Find comparisons and switches in order to identify switch clusters.

    Walking starts in the constructor. Afterwards, `var2condnodes` maps each variable to the ConditionalRegion
    objects recorded for it and `var2switches` maps each variable to the SwitchCaseRegion objects recorded for it.
    """

    def __init__(self, node):
        super().__init__(
            {
                SwitchCaseNode: self._handle_SwitchCase,
                ConditionNode: self._handle_Condition,
                ailment.Block: self._handle_Block,
            }
        )

        self.var2condnodes: DefaultDict[Any, List[ConditionalRegion]] = defaultdict(list)
        self.var2switches: DefaultDict[Any, List[SwitchCaseRegion]] = defaultdict(list)

        self.walk(node)

    def _handle_Block(self, node: ailment.Block, parent=None, **kwargs):  # pylint:disable=unused-argument
        # only blocks that end with a conditional jump carry a comparison
        if not node.statements:
            return
        last_stmt = node.statements[-1]
        if isinstance(last_stmt, ailment.Stmt.ConditionalJump):
            self._process_condition(last_stmt.condition, node, parent)

    def _handle_Condition(self, node, parent=None, **kwargs):
        self._process_condition(node.condition, node, parent)
        return super()._handle_Condition(node, parent=parent, **kwargs)

    def _handle_SwitchCase(self, node: SwitchCaseNode, parent=None, **kwargs):
        switch_expr = node.switch_expr
        if hasattr(switch_expr, "variable"):
            switch_var = switch_expr.variable
            self.var2switches[switch_var].append(SwitchCaseRegion(switch_var, node, parent))
        return super()._handle_SwitchCase(node, parent=parent, **kwargs)

    def _process_condition(
        self, cond: ailment.Expr.Expression, node: Union[ConditionNode, ailment.Block], parent
    ):
        """
        Record `cond` as a ConditionalRegion if it compares a variable-carrying register/load against a constant.

        :param cond:    The condition expression; a leading "Not" is peeled off and folded into the operator.
        :param node:    The node that the condition was taken from.
        :param parent:  The parent of `node`.
        :raises TypeError: If the resolved comparison operator is not one of the six supported ones.
        """

        negated = isinstance(cond, ailment.Expr.UnaryOp) and cond.op == "Not"
        if negated:
            cond = cond.operand

        if not isinstance(cond, ailment.Expr.BinaryOp):
            return
        if cond.op not in {"CmpLE", "CmpLT", "CmpGE", "CmpGT", "CmpEQ", "CmpNE"}:
            return

        # the right-hand side must be a constant...
        rhs = cond.operands[1]
        if not isinstance(rhs, ailment.Expr.Const):
            return
        value = rhs.value

        # ... and the left-hand side a register or a load that has a variable attached
        lhs = cond.operands[0]
        if not isinstance(lhs, (ailment.Expr.Register, ailment.Expr.Load)):
            return
        if not hasattr(lhs, "variable"):
            return
        variable = lhs.variable

        if value is None or variable is None:
            return

        op_name = ailment.Expr.BinaryOp.COMPARISON_NEGATION[cond.op] if negated else cond.op

        # eliminate equal: non-strict comparisons become strict ones by shifting the constant
        normalization = {
            "CmpLE": (CmpOp.LT, 1),
            "CmpGE": (CmpOp.GT, -1),
            "CmpLT": (CmpOp.LT, 0),
            "CmpGT": (CmpOp.GT, 0),
            "CmpEQ": (CmpOp.EQ, 0),
            "CmpNE": (CmpOp.NE, 0),
        }
        if op_name not in normalization:
            raise TypeError(f"Unsupported real_op value {op_name}")
        cmp_op, shift = normalization[op_name]
        if shift:
            value += shift

        self.var2condnodes[variable].append(ConditionalRegion(variable, cmp_op, value, node, parent=parent))
class SwitchClusterReplacer(SequenceWalker):
    """
    Replace an identified switch cluster with a newly created SwitchCase node.

    The walk over `region` is started from the constructor. The condition handler returns `replace_with` for the
    node that is identical (by identity) to `to_replace` and defers to the base class for every other node.
    """

    def __init__(self, region, to_replace, replace_with):
        super().__init__({ConditionNode: self._handle_Condition})

        self.to_replace = to_replace
        self.replace_with = replace_with

        self.walk(region)

    def _handle_Condition(self, node, **kwargs):
        return self.replace_with if node is self.to_replace else super()._handle_Condition(node, **kwargs)
def is_simple_jump_node(node, case_addrs, targets: Optional[Set[int]] = None) -> bool:
    """
    Test if a node consists of nothing but labels and jumps into known case nodes.

    A SequenceNode or MultiNode is simple if all of its sub-nodes are simple. An ailment.Block is simple if every
    statement is one of the following:

    - a label;
    - a jump whose target is a constant in `case_addrs`;
    - a conditional jump where at least one of its two targets is either None or a constant in `case_addrs`.

    Any other node type is not simple.

    :param node:        The node to test.
    :param case_addrs:  A collection of addresses of case nodes.
    :param targets:     An optional set that receives the address of every accepted jump target, including targets
                        found in nested nodes. It may be partially filled when the function returns False.
    :return:            True if the node is a simple jump node, False otherwise.
    """

    if isinstance(node, (SequenceNode, MultiNode)):
        # forward `targets`, otherwise jump targets inside nested nodes are never reported to the caller
        return all(is_simple_jump_node(nn, case_addrs, targets=targets) for nn in node.nodes)

    if not isinstance(node, ailment.Block):
        return False

    for stmt in node.statements:
        if isinstance(stmt, ailment.Stmt.Label):
            continue

        if (
            isinstance(stmt, ailment.Stmt.Jump)
            and isinstance(stmt.target, ailment.Expr.Const)
            and stmt.target.value in case_addrs
        ):
            if targets is not None:
                targets.add(stmt.target.value)
        elif isinstance(stmt, ailment.Stmt.ConditionalJump):
            ok = False
            for jump_target in (stmt.true_target, stmt.false_target):
                if jump_target is None or (
                    isinstance(jump_target, ailment.Expr.Const) and jump_target.value in case_addrs
                ):
                    ok = True
                    if jump_target is not None and targets is not None:
                        targets.add(jump_target.value)
            if not ok:
                return False
        else:
            return False

    return True
def filter_cond_regions(cond_regions: List[ConditionalRegion], case_addrs: Set[int]) -> List[ConditionalRegion]:
    """
    Remove all conditional regions that cannot be merged into switch(es).

    A region backed by a ConditionNode is kept when each of its present branches is a simple jump node; a region
    backed by an ailment.Block is kept when the block itself is a simple jump node.

    :param cond_regions:    The conditional regions to filter.
    :param case_addrs:      Addresses of all case nodes.
    :return:                A new list with the regions that are kept, in their original order.
    :raises TypeError:      If a region is backed by a node of an unsupported type.
    """

    # TODO: the check is too strict!
    kept = []
    for region in cond_regions:
        node = region.node
        if isinstance(node, ConditionNode):
            # it should only contain a jump into one of the cases
            mergeable = all(
                branch is None or is_simple_jump_node(branch, case_addrs)
                for branch in (node.true_node, node.false_node)
            )
        elif isinstance(node, ailment.Block):
            mergeable = is_simple_jump_node(node, case_addrs)
        else:
            raise TypeError(f"Unsupported r.node type {type(region.node)}")

        if mergeable:
            kept.append(region)

    return kept
def update_switch_case_list(
    cases: List[Tuple[Union[int, Tuple[int, ...]], SequenceNode]],
    old_case_id: Union[int, Tuple[int, ...]],
    new_case_id: int,
) -> None:
    """
    Update cases in-place. Make new_case_id directly jump to old_case_id.

    The first entry whose case ID equals `old_case_id` (or, for a tuple of IDs, contains it) gets `new_case_id`
    appended to its IDs. At most one entry is changed; nothing happens if no entry matches.

    :param cases:       The list of (case ID or tuple of case IDs, case node) pairs to update.
    :param old_case_id: The existing case ID to look for.
    :param new_case_id: The case ID to add to the matching entry.
    :raises TypeError:  If an entry inspected before a match has IDs that are neither an int nor a tuple.
    """

    for position, (existing_ids, body) in enumerate(cases):
        if isinstance(existing_ids, int):
            found = old_case_id == existing_ids
            merged_ids = (existing_ids, new_case_id)
        elif isinstance(existing_ids, tuple):
            found = old_case_id == existing_ids or old_case_id in existing_ids
            merged_ids = existing_ids + (new_case_id,)
        else:
            raise TypeError(f"Unsupported case_ids type {type(existing_ids)}")

        if not found:
            continue

        # only the first matching entry is rewritten
        cases[position] = (merged_ids, body)
        return
def simplify_switch_clusters(
    region, var2condnodes: Dict[Any, List[ConditionalRegion]], var2switches: Dict[Any, List[SwitchCaseRegion]]
):
    """
    Identify switch clusters and simplify each of them.

    For every switch variable, the switch regions and the conditional regions recorded for that variable are
    merged into a single SwitchCaseNode, provided that all of the following hold:

    - every switch region is the true or false branch of one of the conditional regions;
    - the switch regions agree on at most one default node address;
    - no case ID appears twice;
    - all involved conditional regions share the same parent node;
    - the remaining (standalone) conditions are of a supported shape: none at all, or exactly one conditional
      jump that either tests for equality and goes to one case/default, or tests LT/GT and goes to default.

    The merged node replaces the involved nodes in place: in the node list of a SequenceNode parent, or through
    SwitchClusterReplacer when the parent is a ConditionNode. Clusters that do not qualify are left untouched.

    :param region:          The region to simplify.
    :param var2condnodes:   A dict that stores the mapping from (potential) switch variables to conditional regions.
    :param var2switches:    A dict that stores the mapping from switch variables to switch-case regions.
    :return:                None
    :raises TypeError:      If a conditional region is backed by a node of an unsupported type.
    """

    for variable, switch_regions in var2switches.items():
        # .get() keeps a plain dict from raising KeyError and keeps a defaultdict from growing a new empty entry
        cond_regions = list(var2condnodes.get(variable, ()))

        if not switch_regions or not cond_regions:
            continue

        # each switch region belongs to a conditional region
        switch_region_to_parent_region: Dict[SwitchCaseRegion, Tuple[ConditionalRegion, str]] = {}
        used_condnodes_and_branch = set()
        for r in switch_regions:
            for cr in cond_regions:
                if isinstance(cr.node, ConditionNode):
                    if cr.node.true_node is r.node:
                        switch_region_to_parent_region[r] = (cr, "true")
                        used_condnodes_and_branch.add((cr, "true"))
                        break
                    if cr.node.false_node is r.node:
                        switch_region_to_parent_region[r] = (cr, "false")
                        used_condnodes_and_branch.add((cr, "false"))
                        break

        if len(switch_region_to_parent_region) != len(switch_regions):
            continue

        # there is at most one default node
        default_node_addrs = set()
        for r in switch_regions:
            dn = r.node.default_node
            if dn is None:
                continue
            if (
                dn.addr == SWITCH_MISSING_DEFAULT_NODE_ADDR
                and isinstance(dn, ailment.Block)
                and dn.statements
                and isinstance(dn.statements[-1], ailment.Stmt.Jump)
                and isinstance(dn.statements[-1].target, ailment.Expr.Const)
            ):
                # parse the jump target out of the manually created default node
                default_node_addrs.add(dn.statements[-1].target.value)
            else:
                default_node_addrs.add(dn.addr)
        if len(default_node_addrs) > 1:
            continue

        # ensure cases in each switch do not overlap
        case_ids = set()
        overlaps = False
        for r in switch_regions:
            for case_id in r.node.cases:
                if case_id in case_ids:
                    overlaps = True
                    break
                case_ids.add(case_id)
            if overlaps:
                break
        if overlaps:
            continue

        used_condnodes = [condnode for condnode, _ in used_condnodes_and_branch]

        # collect addresses for all case nodes
        case_addr_to_case_id: Dict[int, Union[int, Tuple[int, ...], str]] = {}
        for sr in switch_regions:
            for case_id, case_node in sr.node.cases.items():
                case_addr_to_case_id[case_node.addr] = case_id
            if sr.node.default_node is not None:
                case_addr_to_case_id[sr.node.default_node.addr] = "default"
        case_addrs: Set[int] = set(case_addr_to_case_id)

        # filter cond_regions
        mergeable_cond_regions = filter_cond_regions(cond_regions, case_addrs)

        # list all unmatched conditional nodes
        standalone_condnodes: List[Tuple[ConditionalRegion, str]] = []
        standalone_condjumps: List[ConditionalRegion] = []
        for cr in mergeable_cond_regions:
            if isinstance(cr.node, ConditionNode):
                if cr.node.true_node is not None and (cr, "true") not in used_condnodes_and_branch:
                    standalone_condnodes.append((cr, "true"))
                if cr.node.false_node is not None and (cr, "false") not in used_condnodes_and_branch:
                    standalone_condnodes.append((cr, "false"))
            elif isinstance(cr.node, ailment.Block):
                # the last statement is a conditional jump
                standalone_condjumps.append(cr)
            else:
                raise TypeError(f"Unsupported cr.node type {type(cr.node)}")

        # all conditional regions belong to the same parent node
        parents = {r.parent for r in mergeable_cond_regions + used_condnodes}
        if len(parents) != 1:
            continue

        # the hard part: ensure there are no gaps between value ranges of each conditional region
        parent = next(iter(parents))

        # build the switch-case node
        cases = []
        default_node = None

        # first let's merge all switch cases
        # to ensure proper handling of fallthrough into the default node, we add the one with a default node at the end
        switch_regions_no_default_nodes = []
        switch_regions_default_nodes = []
        for sr in switch_regions:
            if sr.node.default_node is not None and not switch_regions_default_nodes:
                switch_regions_default_nodes.append(sr)
            else:
                switch_regions_no_default_nodes.append(sr)

        for sr in switch_regions_no_default_nodes:
            cases.extend(sr.node.cases.items())
        if switch_regions_default_nodes:
            cases.extend(switch_regions_default_nodes[0].node.cases.items())
            default_node = switch_regions_default_nodes[0].node.default_node

        # then let's handle conditional nodes and jumps
        # for now, we only support two types:
        #
        #   if (var == A) {
        #       goto one_case/default;
        #   }
        #   ... the actual switch
        #
        # or
        #
        #   if (var > A) {
        #       goto default;
        #   }
        #   ... the actual switch
        ok = False
        if not standalone_condnodes:
            if not standalone_condjumps:
                ok = True
            elif len(standalone_condjumps) == 1:
                condjump = standalone_condjumps[0]
                targets = set()
                if is_simple_jump_node(condjump.node, case_addrs, targets=targets) and len(targets) == 1:
                    target = next(iter(targets))
                    if condjump.op == CmpOp.EQ:
                        # which case is it jumping to?
                        existing_case_id = case_addr_to_case_id[target]
                        if existing_case_id == "default":
                            # just add it at the end
                            cases.append((condjump.value, SequenceNode(0, nodes=[])))
                        else:
                            update_switch_case_list(
                                cases,
                                existing_case_id,
                                condjump.value,
                            )
                        ok = True
                    elif condjump.op in {CmpOp.GT, CmpOp.LT}:
                        if default_node is not None and target == default_node.addr:
                            # it goes to default anyway - we don't need to do anything
                            ok = True

        if not ok:
            continue

        # build the SwitchCase node and replace old nodes in the parent node
        # the switch expression and address come from the switch with a default node; when no switch in the cluster
        # has a default node, fall back to the first switch instead of indexing an empty list
        template_region = switch_regions_default_nodes[0] if switch_regions_default_nodes else switch_regions[0]
        cases_dict = OrderedDict(cases)
        new_switchcase = SwitchCaseNode(
            template_region.node.switch_expr,
            cases_dict,
            default_node,
            addr=template_region.node.addr,
        )

        # what are we trying to replace?
        if isinstance(parent, SequenceNode):
            start_idx = None  # inclusive
            end_idx = None  # not inclusive
            for idx, node in enumerate(parent.nodes):
                if (
                    any(cr.node is node for cr, _ in standalone_condnodes)
                    or any(cr.node is node for cr in standalone_condjumps)
                    or any(cr.node is node for cr in used_condnodes)
                ):
                    if start_idx is None:
                        start_idx = idx
                    # a match after end_idx has been set means there is a gap; such nodes are left in place
                elif start_idx is not None and end_idx is None:
                    end_idx = idx

            if start_idx is None:
                continue
            if end_idx is None:
                end_idx = len(parent.nodes)

            # we want to replace parent.nodes[start_idx : end_idx] with a SwitchCaseNode
            parent.nodes[start_idx:end_idx] = [new_switchcase]

        elif isinstance(parent, ConditionNode):
            # replace this condition node with the new SwitchCase node
            SwitchClusterReplacer(region, parent, new_switchcase)

        else:
            # unsupported for now
            continue