Source code for angr.analyses.decompiler.redundant_label_remover

# pylint:disable=unused-argument
from typing import Set, Optional, Tuple, Dict

import ailment

from .sequence_walker import SequenceWalker
from .structuring.structurer_nodes import SequenceNode
from .utils import first_nonlabel_statement


class RedundantLabelRemover:
    """
    Remove redundant labels.

    This optimization pass contains two separate passes. The first pass (self._walker0) finds all redundant labels
    (e.g., two or more labels for the same location) and records the replacement label for redundant labels in
    self._new_jump_target. The second pass (self._walker1) removes all labels that (a) are not referenced anywhere
    (determined by jump_targets), or (b) are deemed replaceable by the first pass, and retargets jumps that pointed
    at a replaced label to its replacement.

    A replacement label is always kept when at least one label redirected to it was referenced, even if the
    replacement itself did not appear in jump_targets; otherwise the retargeted jumps would point at a label that
    has been removed.
    """

    def __init__(self, node, jump_targets: Set[Tuple[int, Optional[int]]]):
        """
        :param node:            Root node to simplify. Blocks and statements inside it are updated in place, and the
                                same object is exposed as ``self.result``.
        :param jump_targets:    (address, block index) keys of labels that are targeted by a jump. The caller's
                                collection is not modified.
        """
        self.root = node

        # labels that duplicate an earlier label in the same run of consecutive labels
        self._labels_to_remove: Set[ailment.Stmt.Label] = set()
        # (addr, block_idx) of a redundant label -> (addr, block_idx) of the label that replaces it
        self._new_jump_target: Dict[Tuple[int, Optional[int]], Tuple[int, Optional[int]]] = {}

        handlers0 = {
            SequenceNode: self._handle_Sequence,
        }
        self._walker0 = SequenceWalker(handlers=handlers0)
        self._walker0.walk(self.root)

        # Jumps to a referenced redundant label are redirected to its replacement, so the replacement becomes a jump
        # target even if nothing jumped to it originally. Work on a copy so the caller's collection stays untouched.
        self._jump_targets: Set[Tuple[int, Optional[int]]] = set(jump_targets)
        self._jump_targets.update(
            new_target for old_target, new_target in self._new_jump_target.items() if old_target in jump_targets
        )

        handlers1 = {
            ailment.Block: self._handle_Block,
        }
        self._walker1 = SequenceWalker(handlers=handlers1)
        self._walker1.walk(self.root)

        self.result = self.root

    #
    # Handlers
    #

    def _handle_Sequence(self, node: SequenceNode, **kwargs):
        """
        Mark every label after the first one in a run of consecutive labels (among the direct children of ``node``)
        as redundant, and record the first label of the run as its replacement. Then keep walking into the children.
        """
        # merge consecutive labels
        last_label_addr: Optional[Tuple[int, Optional[int]]] = None
        for node_ in node.nodes:
            if isinstance(node_, ailment.Block):
                # A block made only of labels (or with no statements) does not end the current run, so a run can
                # span several consecutive blocks.
                if node_.statements:
                    for stmt in node_.statements:
                        if isinstance(stmt, ailment.Stmt.Label):
                            if last_label_addr is None:
                                # first label of a run: this one is kept
                                last_label_addr = stmt.ins_addr, stmt.block_idx
                            else:
                                # this label is useless - we should replace this label with the last label
                                self._labels_to_remove.add(stmt)
                                self._new_jump_target[(stmt.ins_addr, stmt.block_idx)] = last_label_addr
                        else:
                            # any real statement ends the run
                            last_label_addr = None
                            break
            else:
                # non-block children (nested structures) end the run
                last_label_addr = None

        return self._walker0._handle_Sequence(node, **kwargs)

    def _handle_Block(self, block: ailment.Block, **kwargs):
        """
        Drop unreferenced or redundant labels from ``block`` and retarget its jumps away from redundant labels.
        The block is updated in place and nothing is returned.
        """
        if block.statements:
            # The removal test for a label does not depend on the other statements, so one filtering pass removes
            # exactly the labels a repeated remove-and-rescan would.
            kept = [stmt for stmt in block.statements if not self._is_removable_label(stmt)]
            if len(kept) != len(block.statements):
                # assign a fresh list instead of mutating the existing one in place
                block.statements = kept

        # NOTE(review): only the first non-label statement is checked for a ConditionalJump; a conditional jump
        # elsewhere in the block is not retargeted here - confirm this is intended.
        first_stmt = first_nonlabel_statement(block)
        if isinstance(first_stmt, ailment.Stmt.ConditionalJump):
            new_true = self._redirected_cond_target(first_stmt.true_target)
            if new_true is not None:
                first_stmt.true_target = new_true
            new_false = self._redirected_cond_target(first_stmt.false_target)
            if new_false is not None:
                first_stmt.false_target = new_false

        if block.statements:
            last_stmt = block.statements[-1]
            if isinstance(last_stmt, ailment.Stmt.Jump) and isinstance(last_stmt.target, ailment.Expr.Const):
                new_target = self._new_jump_target.get((last_stmt.target.value, last_stmt.target_idx))
                if new_target is not None:
                    last_stmt.target = self._const_with_value(last_stmt.target, new_target[0])
                    last_stmt.target_idx = new_target[1]

    #
    # Helpers
    #

    def _is_removable_label(self, stmt) -> bool:
        """Return True if ``stmt`` is a label that nothing jumps to, or one marked redundant by the first pass."""
        return isinstance(stmt, ailment.Stmt.Label) and (
            (stmt.ins_addr, stmt.block_idx) not in self._jump_targets or stmt in self._labels_to_remove
        )

    def _redirected_cond_target(self, target) -> Optional[ailment.Expr.Const]:
        """
        Return a replacement Const for a conditional-jump target that points at a redundant label, or None when the
        target is not a Const or needs no redirection. Conditional targets are looked up with a block index of None.
        """
        if not isinstance(target, ailment.Expr.Const):
            return None
        new_target = self._new_jump_target.get((target.value, None))
        if new_target is None:
            return None
        return self._const_with_value(target, new_target[0])

    @staticmethod
    def _const_with_value(const: ailment.Expr.Const, value) -> ailment.Expr.Const:
        """Build a copy of ``const`` whose value is ``value``; idx, variable, bits and tags are carried over."""
        return ailment.Expr.Const(
            const.idx,
            const.variable,
            value,
            const.bits,
            **const.tags,
        )