Source code for angr.analyses.decompiler.optimization_passes.ret_addr_save_simplifier

from typing import Tuple, Optional, List, Any
import logging

import ailment

from ....calling_conventions import SimRegArg, DEFAULT_CC
from .optimization_pass import OptimizationPass, OptimizationPassStage

_l = logging.getLogger(name=__name__)


class RetAddrSaveSimplifier(OptimizationPass):
    """
    Removes code in function prologues and epilogues for saving and restoring return address registers (ra, lr, etc.),
    generally seen in non-leaf functions.
    """

    # NOTE: DESCRIPTION is derived from the class docstring above, so editing that docstring changes the description
    # string this pass exposes at runtime.
    ARCHES = ["MIPS32", "MIPS64"]
    PLATFORMS = ["linux"]
    STAGE = OptimizationPassStage.AFTER_GLOBAL_SIMPLIFICATION
    NAME = "Simplify return address storage"
    DESCRIPTION = __doc__.strip()

    def __init__(self, func, **kwargs):
        """
        Initialize the pass for a function and immediately run the analysis.

        :param func:    The function to optimize; forwarded to the base class.
        :param kwargs:  Extra keyword arguments, forwarded unchanged to the base class.
        """
        super().__init__(func, **kwargs)

        self.analyze()

    def _check(self):
        """
        Decide whether this pass applies to the current function.

        :return: A tuple of (applicable, cache). `cache` is empty when the pass does not apply; otherwise it holds
                 "save_stmt" (the tuple returned by `_find_retaddr_save_stmt`) and "restore_stmts" (the restoring
                 statements, or an empty list when none were found or when any of them loads from a location other
                 than the one the return address was saved to).
        """
        arch = self.project.arch
        if arch.name not in DEFAULT_CC:
            return False, {}

        default_cc = DEFAULT_CC[arch.name](arch)
        if not isinstance(default_cc.return_addr, SimRegArg):
            return False, {}

        save_stmt = self._find_retaddr_save_stmt()
        if save_stmt is None:
            # Nothing to remove; no need to scan the endpoints for restoring statements.
            return False, {}

        # Note that restoring statements may not exist since they can be effectively removed by other optimization
        restore_stmts = self._find_retaddr_restore_stmt()

        save_dst = save_stmt[2]
        if restore_stmts is None or not all(src == save_dst for _, _, src in restore_stmts):
            # Only the saving statement can be removed safely.
            restore_stmts = []

        return True, {
            "save_stmt": save_stmt,
            "restore_stmts": restore_stmts,
        }

    def _analyze(self, cache=None):
        """
        Remove the return address saving statement and all known restoring statements from their blocks.

        :param cache:   Optional dict as produced by `_check`. Entries that are missing (or None) are recomputed.
        :return:        None
        """
        save_stmt = None
        restore_stmts = None

        if cache is not None:
            save_stmt = cache.get("save_stmt", None)
            restore_stmts = cache.get("restore_stmts", None)

        if save_stmt is None:
            save_stmt = self._find_retaddr_save_stmt()
        if restore_stmts is None:
            restore_stmts = self._find_retaddr_restore_stmt()

        if save_stmt is None:
            return

        # Group the statements to delete by the block object that holds them. If the saving statement and a
        # restoring statement live in the same block, copying and updating that block once per statement would
        # start the second edit from the stale original block and lose (or undo) the first removal.
        removals = {}  # id(block) -> (block, [statement indices]); dicts preserve first-seen order
        for block, stmt_idx, _ in [save_stmt, *(restore_stmts or [])]:
            removals.setdefault(id(block), (block, []))[1].append(stmt_idx)

        # The first block comes first, followed by the endpoint blocks.
        for block, stmt_indices in removals.values():
            block_copy = block.copy()
            # Pop from the back so that the remaining indices stay valid.
            for stmt_idx in sorted(set(stmt_indices), reverse=True):
                block_copy.statements.pop(stmt_idx)
            self._update_block(block, block_copy)

    def _retaddr_reg_offset(self) -> int:
        """
        Look up the register offset of the return address register of the default calling convention.

        Callers are expected to have verified (as `_check` does) that the architecture has a default calling
        convention whose return address lives in a register.

        :return: The first element of the architecture's register entry for the return address register.
        """
        arch = self.project.arch
        retaddr = DEFAULT_CC[arch.name](arch).return_addr
        assert isinstance(retaddr, SimRegArg)
        return arch.registers[retaddr.reg_name][0]

    def _find_retaddr_save_stmt(self) -> Optional[Tuple[Any, int, ailment.Expr.StackBaseOffset]]:
        """
        Find the AIL statement that saves the return address to a stack slot.

        Only the block at the function's address is searched, and the first matching statement wins.

        :return:    A tuple of (block, statement_idx, save_dst) or None if not found.
        """
        first_block = self._get_block(self._func.addr)
        if first_block is None:
            return None

        retaddr_reg = self._retaddr_reg_offset()

        for idx, stmt in enumerate(first_block.statements):
            if not (isinstance(stmt, ailment.Stmt.Store) and isinstance(stmt.addr, ailment.Expr.StackBaseOffset)):
                continue

            data = stmt.data
            # Pattern 1: the return address register itself is stored.
            stores_retaddr_reg = isinstance(data, ailment.Expr.Register) and data.reg_offset == retaddr_reg
            # Pattern 2: the stored value is a stack base expression with offset 0.
            # NOTE(review): presumably what the return address register looks like after earlier simplification
            # - confirm.
            stores_stack_base = isinstance(data, ailment.Expr.StackBaseOffset) and data.offset == 0

            # The destination must be a slot below the stack base.
            if (stores_retaddr_reg or stores_stack_base) and stmt.addr.offset < 0:
                return first_block, idx, stmt.addr

        # Not found
        return None

    def _find_retaddr_restore_stmt(self) -> Optional[List[Tuple[Any, int, ailment.Expr.StackBaseOffset]]]:
        """
        Find the AIL statements that restore the return address from a stack slot.

        Every block at every function endpoint is searched; at most one statement (the first match) is recorded
        per block. A block without a restoring statement is tolerated only when its endpoint is a callout or
        jumpout site.

        :return:    A list of tuples, where each tuple is like (block, statement_idx, load_src), or None if some
                    endpoint block that is not at a callout/jumpout site has no restoring statement.
        """
        callouts_and_jumpouts = {n.addr for n in self._func.callout_sites + self._func.jumpout_sites}
        retaddr_reg = self._retaddr_reg_offset()

        retaddr_restore_stmts = []
        for endpoint in self._func.endpoints:
            for endpoint_block in self._get_blocks(endpoint.addr):
                for idx, stmt in enumerate(endpoint_block.statements):
                    if (
                        isinstance(stmt, ailment.Stmt.Assignment)
                        and isinstance(stmt.dst, ailment.Expr.Register)
                        and stmt.dst.reg_offset == retaddr_reg
                        and isinstance(stmt.src, ailment.Expr.Load)
                        and isinstance(stmt.src.addr, ailment.Expr.StackBaseOffset)
                    ):
                        retaddr_restore_stmts.append((endpoint_block, idx, stmt.src.addr))
                        break
                else:
                    # The statement loop finished without a match for this block.
                    if endpoint.addr not in callouts_and_jumpouts:
                        _l.debug("Could not find retaddr restoring statement in function %#x.", endpoint.addr)
                        return None
                    _l.debug(
                        "No retaddr restoring statement is found at callout/jumpout site %#x. "
                        "Might be expected.",
                        endpoint.addr,
                    )

        return retaddr_restore_stmts