# Source code for angr.analyses.decompiler.optimization_passes.condition_constprop

from __future__ import annotations

import networkx

from ailment import AILBlockWalker, Block
from ailment.statement import ConditionalJump, Statement
from ailment.expression import Const, BinaryOp, VirtualVariable

from angr.analyses.decompiler.region_identifier import RegionIdentifier
from .optimization_pass import OptimizationPass, OptimizationPassStage


class ConstantCondition:
    """
    Records one opportunity for substituting a constant value for a vvar.

    Instances carry the ID of the vvar, the constant value, and the address and index of the block that the
    opportunity is associated with.
    """

    def __init__(self, vvar_id: int, value: Const, block_addr: int, block_idx: int | None):
        self.vvar_id, self.value = vvar_id, value
        self.block_addr, self.block_idx = block_addr, block_idx

    def __repr__(self):
        # the block is rendered as "<hex address>-<index>"
        location = f"{self.block_addr:#x}-{self.block_idx}"
        return f"<ConstCond vvar_{self.vvar_id} == {self.value} since {location}>"


class CCondPropBlockWalker(AILBlockWalker):
    """
    AIL block walker used by ConditionConstantPropagation; it replaces a given vvar with a constant value.
    """

    def __init__(self, vvar_id: int, const_value: Const):
        super().__init__()
        self.vvar_id = vvar_id
        self.const_value = const_value
        # output of walk(): a copy of the walked block holding the replaced statements, or None when no
        # statement was replaced
        self._new_block: Block | None = None

    def walk(self, block: Block):
        """
        Walk a block and return the rewritten copy, or None if no statement was replaced.
        """
        # reset the output first so that the result of an earlier walk is not returned again
        self._new_block = None
        super().walk(block)
        return self._new_block

    def _handle_stmt(self, stmt_idx: int, stmt: Statement, block: Block):  # type: ignore
        replacement = super()._handle_stmt(stmt_idx, stmt, block)
        if replacement is None:
            return
        # the replacement statement goes into a copy of the block, which is created on first use
        if self._new_block is None:
            self._new_block = block.copy()
        self._new_block.statements[stmt_idx] = replacement

    def _handle_VirtualVariable(  # type: ignore
        self, expr_idx: int, expr: VirtualVariable, stmt_idx: int, stmt: Statement, block: Block | None
    ) -> Const | None:
        if expr.varid != self.vvar_id:
            # not the vvar that this walker is looking for
            return None
        const = self.const_value
        # the new constant reuses the idx and the tags of the vvar expression it stands in for
        return Const(expr.idx, None, const.value, const.bits, **expr.tags)


class ConditionConstantPropagation(OptimizationPass):
    """
    Reason about constant propagation opportunities from conditionals and propagate constants in the graph accordingly.
    """

    ARCHES = None
    PLATFORMS = None
    STAGE = OptimizationPassStage.AFTER_SINGLE_BLOCK_SIMPLIFICATION
    NAME = "Propagate constants using information deduced from conditionals."
    # the class docstring above doubles as the description, so its text is part of the runtime behavior
    DESCRIPTION = __doc__.strip()  # type: ignore

    def __init__(self, func, **kwargs):
        """
        Initialize the pass via the base class and immediately run self.analyze().

        :param func:    Forwarded to OptimizationPass.__init__().
        :param kwargs:  Forwarded to OptimizationPass.__init__().
        """
        super().__init__(func, **kwargs)
        self.analyze()

    def _check(self):
        """
        Determine whether there is anything to propagate.

        :return: (False, None) if no constant condition is found; otherwise (True, cache), where cache is a dict
                 holding the found conditions under the key "cconds".
        """
        cconds = self._find_const_conditions()
        if not cconds:
            return False, None
        return True, {"cconds": cconds}

    def _analyze(self, cache=None):
        """
        Replace vvars with constants in the blocks where a constant condition applies.

        :param cache:   Optional dict as returned by _check(). If it is missing or has no "cconds" entry, the
                        constant conditions are recomputed.
        """
        if not cache or cache.get("cconds", None) is None:  # noqa: SIM108
            cconds = self._find_const_conditions()
        else:
            cconds = cache["cconds"]

        if not cconds:
            return

        # group cconds according to their sources
        cconds_by_src: dict[tuple[int, int | None], list[ConstantCondition]] = {}
        for ccond in cconds:
            cconds_by_src.setdefault((ccond.block_addr, ccond.block_idx), []).append(ccond)

        # calculate a dominance frontier for each block
        entry_node_addr, entry_node_idx = self.entry_node_addr
        entry_node = self._get_block(entry_node_addr, idx=entry_node_idx)
        df = networkx.algorithms.dominance_frontiers(self._graph, entry_node)

        for (src_addr, src_idx), src_cconds in cconds_by_src.items():
            head_block = self._get_block(src_addr, idx=src_idx)
            if head_block is None:
                continue
            frontier = df.get(head_block)
            if frontier is None:
                continue
            # only the blocks between the head block and its dominance frontier (exclusive) are rewritten
            graph_slice = RegionIdentifier.slice_graph(self._graph, head_block, frontier, include_frontier=False)
            # NOTE(review): every ccond walks the blocks of graph_slice, which is computed once before any update.
            # If several cconds of the same source rewrite the same block, each rewrite presumably starts from the
            # un-rewritten block - confirm that _update_block() copes with that.
            for ccond in src_cconds:
                walker = CCondPropBlockWalker(ccond.vvar_id, ccond.value)
                for block in graph_slice:
                    new_block = walker.walk(block)
                    if new_block is not None:
                        self._update_block(block, new_block)

    def _find_const_conditions(self) -> list[ConstantCondition]:
        """
        Collect constant conditions from the conditional jumps that terminate blocks in the graph.

        A block contributes a condition when its last statement is a ConditionalJump with constant true and false
        targets whose condition compares a vvar against an integer constant using CmpEQ or CmpNE.

        :return: A list of the constant conditions found; empty if there are none.
        """
        cconds = []

        for block in self._graph:
            if not block.statements:
                continue

            last_stmt = block.statements[-1]
            if (
                not isinstance(last_stmt, ConditionalJump)
                or not isinstance(last_stmt.true_target, Const)
                or not isinstance(last_stmt.false_target, Const)
            ):
                continue

            cond = last_stmt.condition
            if not isinstance(cond, BinaryOp):
                continue

            # normalize the operands so that a constant on the left-hand side ends up in op1
            op0, op1 = cond.operands
            if isinstance(op0, Const):
                op0, op1 = op1, op0
            if not (isinstance(op0, VirtualVariable) and isinstance(op1, Const) and op1.is_int):
                continue

            if cond.op == "CmpEQ":
                # vvar == const: the condition is attached to the true target
                ccond = ConstantCondition(
                    op0.varid, op1, last_stmt.true_target.value, last_stmt.true_target_idx  # type: ignore
                )
                cconds.append(ccond)
            elif cond.op == "CmpNE":
                # vvar != const: the condition is attached to the false target
                ccond = ConstantCondition(
                    op0.varid, op1, last_stmt.false_target.value, last_stmt.false_target_idx  # type: ignore
                )
                cconds.append(ccond)

        return cconds