# Source code for angr.analyses.s_propagator

from __future__ import annotations

import contextlib
from collections.abc import Mapping
from collections import defaultdict

import networkx

from ailment.block import Block
from ailment.expression import (
    Const,
    VirtualVariable,
    VirtualVariableCategory,
    StackBaseOffset,
    Load,
    Convert,
    Expression,
)
from ailment.statement import Assignment, Store, Return, Jump, ConditionalJump

from angr.knowledge_plugins.functions import Function
from angr.code_location import CodeLocation, ExternalCodeLocation
from angr.analyses import Analysis, register_analysis
from angr.utils.ssa import (
    get_vvar_uselocs,
    get_vvar_deflocs,
    has_ite_expr,
    has_ite_stmt,
    is_phi_assignment,
    is_const_assignment,
    is_const_and_vvar_assignment,
    is_const_vvar_load_assignment,
    is_const_vvar_load_dirty_assignment,
    is_const_vvar_tmp_assignment,
    get_tmp_uselocs,
    get_tmp_deflocs,
    phi_assignment_get_src,
    has_store_stmt_in_between_stmts,
)


class SPropagatorModel:
    """
    Result container of SPropagator: expression replacements keyed by code location, plus the IDs of virtual
    variables that were found to be dead.
    """

    def __init__(self):
        # IDs of vvars that are definitely dead (they are usually not removed by default because they are stack
        # variables)
        self.dead_vvar_ids: set[int] = set()
        # code location -> {expression found at that location -> expression to replace it with}
        self.replacements: Mapping[CodeLocation, Mapping[Expression, Expression]] = dict()


[docs] class SPropagatorAnalysis(Analysis): """ Constant and expression propagation that only supports SSA AIL graphs. """
[docs] def __init__( # pylint: disable=too-many-positional-arguments self, subject: Block | Function, func_graph: networkx.DiGraph | None = None, only_consts: bool = True, stack_pointer_tracker=None, func_args: set[VirtualVariable] | None = None, func_addr: int | None = None, ): if isinstance(subject, Block): self.block = subject self.func = None self.mode = "block" elif isinstance(subject, Function): self.block = None self.func = subject self.mode = "function" else: raise TypeError(f"Unsupported subject type {type(subject)}") self.func_graph = func_graph self.func_addr = func_addr self.func_args = func_args self.only_consts = only_consts self._sp_tracker = stack_pointer_tracker bp_as_gpr = False the_func = None if self.func is not None: the_func = self.func else: if self.func_addr is not None: with contextlib.suppress(KeyError): the_func = self.kb.functions.get_by_addr(self.func_addr) if the_func is not None: bp_as_gpr = the_func.info.get("bp_as_gpr", False) self._bp_as_gpr = bp_as_gpr # output self.model = SPropagatorModel() self._analyze()
@property def replacements(self): return self.model.replacements @property def dead_vvar_ids(self): return self.model.dead_vvar_ids def _analyze(self): blocks: dict[tuple[int, int | None], Block] match self.mode: case "block": assert self.block is not None blocks = {(self.block.addr, self.block.idx): self.block} case "function": assert self.func_graph is not None blocks = {(block.addr, block.idx): block for block in self.func_graph} case _: raise NotImplementedError # find all vvar definitions vvar_deflocs = get_vvar_deflocs(blocks.values()) # find all vvar uses vvar_uselocs = get_vvar_uselocs(blocks.values()) # update vvar_deflocs using function arguments if self.func_args: for func_arg in self.func_args: vvar_deflocs[func_arg] = ExternalCodeLocation() # find all ret sites and indirect jump sites retsites: set[tuple[int, int | None, int]] = set() jumpsites: set[tuple[int, int | None, int]] = set() for bb in blocks.values(): if bb.statements: if isinstance(bb.statements[-1], Return): retsites.add((bb.addr, bb.idx, len(bb.statements) - 1)) elif isinstance(bb.statements[-1], Jump): jumpsites.add((bb.addr, bb.idx, len(bb.statements) - 1)) replacements = defaultdict(dict) # find constant and other propagatable assignments vvarid_to_vvar = {} const_vvars: dict[int, Const] = {} for vvar, defloc in vvar_deflocs.items(): if not vvar.was_reg and not vvar.was_parameter: continue vvarid_to_vvar[vvar.varid] = vvar if isinstance(defloc, ExternalCodeLocation): continue assert defloc.block_addr is not None assert defloc.stmt_idx is not None block = blocks[(defloc.block_addr, defloc.block_idx)] stmt = block.statements[defloc.stmt_idx] r, v = is_const_assignment(stmt) if r: # replace wherever it's used assert v is not None const_vvars[vvar.varid] = v for vvar_at_use, useloc in vvar_uselocs[vvar.varid]: replacements[useloc][vvar_at_use] = v continue v = phi_assignment_get_src(stmt) if v is not None: src_varids = {vvar.varid if vvar is not None else None for _, vvar in 
v.src_and_vvars} if None not in src_varids and all(varid in const_vvars for varid in src_varids): all_int_src_varids: set[int] = {varid for varid in src_varids if varid is not None} src_values = { ( (const_vvars[varid].value, const_vvars[varid].bits) if isinstance(const_vvars[varid], Const) else const_vvars[varid] ) for varid in all_int_src_varids } if len(src_values) == 1: # replace it! const_value = const_vvars[next(iter(all_int_src_varids))] const_vvars[vvar.varid] = const_value for vvar_at_use, useloc in vvar_uselocs[vvar.varid]: replacements[useloc][vvar_at_use] = const_value # function mode only if self.mode == "function": assert self.func_graph is not None for vvar, defloc in vvar_deflocs.items(): if vvar.varid not in vvar_uselocs: continue if vvar.varid in const_vvars: continue if isinstance(defloc, ExternalCodeLocation): continue assert defloc.block_addr is not None assert defloc.stmt_idx is not None block = blocks[(defloc.block_addr, defloc.block_idx)] stmt = block.statements[defloc.stmt_idx] if ( (vvar.was_reg or vvar.was_parameter) and len(vvar_uselocs[vvar.varid]) <= 2 and isinstance(stmt, Assignment) and isinstance(stmt.src, Load) ): # do we want to propagate this Load expression if it's used for less than twice? # it's often seen in the following pattern, where propagation will be beneficial: # v0 = Load(...) 
# if (!v0) { # v1 = v0 + 1; # } can_replace = True for _, vvar_useloc in vvar_uselocs[vvar.varid]: if has_store_stmt_in_between_stmts(self.func_graph, blocks, defloc, vvar_useloc): can_replace = False if can_replace: # we can propagate this load because there is no store between its def and use for vvar_used, vvar_useloc in vvar_uselocs[vvar.varid]: replacements[vvar_useloc][vvar_used] = stmt.src continue if ( (vvar.was_reg or vvar.was_stack) and len(vvar_uselocs[vvar.varid]) == 2 and not is_phi_assignment(stmt) ): # a special case: in a typical switch-case construct, a variable may be used once for comparison # for the default case and then used again for constructing the jump target. we can propagate this # variable for such cases. uselocs = {loc for _, loc in vvar_uselocs[vvar.varid]} if self.is_vvar_used_for_addr_loading_switch_case(uselocs, blocks): for vvar_used, vvar_useloc in vvar_uselocs[vvar.varid]: replacements[vvar_useloc][vvar_used] = stmt.src # mark the vvar as dead and should be removed self.model.dead_vvar_ids.add(vvar.varid) continue if vvar.was_reg or vvar.was_parameter: if len(vvar_uselocs[vvar.varid]) == 1: vvar_used, vvar_useloc = next(iter(vvar_uselocs[vvar.varid])) if is_const_vvar_load_assignment(stmt) and not has_store_stmt_in_between_stmts( self.func_graph, blocks, defloc, vvar_useloc ): # we can propagate this load because there is no store between its def and use replacements[vvar_useloc][vvar_used] = stmt.src continue if is_const_and_vvar_assignment(stmt): # if the useloc is a phi assignment statement, ensure that stmt.src is the same as the phi # variable assert vvar_useloc.block_addr is not None assert vvar_useloc.stmt_idx is not None useloc_stmt = blocks[(vvar_useloc.block_addr, vvar_useloc.block_idx)].statements[ vvar_useloc.stmt_idx ] if is_phi_assignment(useloc_stmt): if ( isinstance(stmt.src, VirtualVariable) and stmt.src.oident == useloc_stmt.dst.oident and stmt.src.category == useloc_stmt.dst.category ): 
replacements[vvar_useloc][vvar_used] = stmt.src else: replacements[vvar_useloc][vvar_used] = stmt.src continue else: non_exitsite_uselocs = [ loc for _, loc in vvar_uselocs[vvar.varid] if (loc.block_addr, loc.block_idx, loc.stmt_idx) not in (retsites | jumpsites) ] if is_const_and_vvar_assignment(stmt): if len(non_exitsite_uselocs) == 1: # this vvar is used once if we exclude its uses at ret sites or jump sites. we can # propagate it for vvar_used, vvar_useloc in vvar_uselocs[vvar.varid]: replacements[vvar_useloc][vvar_used] = stmt.src continue if len(set(non_exitsite_uselocs)) == 1 and not has_ite_expr(stmt.src): useloc = non_exitsite_uselocs[0] assert useloc.block_addr is not None assert useloc.stmt_idx is not None useloc_stmt = blocks[(useloc.block_addr, useloc.block_idx)].statements[useloc.stmt_idx] if stmt.src.depth <= 3 and not has_ite_stmt(useloc_stmt): # remove duplicate use locs (e.g., if the variable is used multiple times by the # same statement) - but ensure stmt is simple enough for vvar_used, vvar_useloc in vvar_uselocs[vvar.varid]: replacements[vvar_useloc][vvar_used] = stmt.src continue # special logic for global variables: if it's used once or multiple times, and the variable is never # updated before it's used, we will propagate the load if (vvar.was_reg or vvar.was_parameter) and isinstance(stmt, Assignment): stmt_src = stmt.src # unpack conversions while isinstance(stmt_src, Convert): stmt_src = stmt_src.operand if ( isinstance(stmt_src, Load) and isinstance(stmt_src.addr, Const) and isinstance(stmt_src.addr.value, int) ): gv_updated = False for _vvar_used, vvar_useloc in vvar_uselocs[vvar.varid]: gv_updated |= self.is_global_variable_updated( self.func_graph, blocks, vvar.varid, stmt_src.addr.value, stmt_src.size, defloc, vvar_useloc, ) if not gv_updated: for vvar_used, vvar_useloc in vvar_uselocs[vvar.varid]: replacements[vvar_useloc][vvar_used] = stmt.src continue for vvar_id, uselocs in vvar_uselocs.items(): vvar = next(iter(uselocs))[0] if 
vvar_id not in vvarid_to_vvar else vvarid_to_vvar[vvar_id] if self._sp_tracker is not None and vvar.category == VirtualVariableCategory.REGISTER: if vvar.oident == self.project.arch.sp_offset: sp_bits = ( (self.project.arch.registers["sp"][1] * self.project.arch.byte_width) if "sp" in self.project.arch.registers else None ) for vvar_at_use, useloc in vvar_uselocs[vvar.varid]: sb_offset = self._sp_tracker.offset_before(useloc.ins_addr, self.project.arch.sp_offset) if sb_offset is not None: v = StackBaseOffset(None, self.project.arch.bits, sb_offset) if sp_bits is not None and vvar.bits < sp_bits: # truncation needed v = Convert(None, sp_bits, vvar.bits, False, v) replacements[useloc][vvar_at_use] = v continue if not self._bp_as_gpr and vvar.oident == self.project.arch.bp_offset: bp_bits = ( (self.project.arch.registers["bp"][1] * self.project.arch.byte_width) if "bp" in self.project.arch.registers else None ) for vvar_at_use, useloc in vvar_uselocs[vvar.varid]: sb_offset = self._sp_tracker.offset_before(useloc.ins_addr, self.project.arch.bp_offset) if sb_offset is not None: v = StackBaseOffset(None, self.project.arch.bits, sb_offset) if bp_bits is not None and vvar.bits < bp_bits: # truncation needed v = Convert(None, bp_bits, vvar.bits, False, v) replacements[useloc][vvar_at_use] = v continue # find all tmp definitions tmp_deflocs = get_tmp_deflocs(blocks.values()) # find all tmp uses tmp_uselocs = get_tmp_uselocs(blocks.values()) for block_loc, tmp_and_uses in tmp_uselocs.items(): for tmp_atom, tmp_uses in tmp_and_uses.items(): # take a look at the definition and propagate the definition if supported assert block_loc.block_addr is not None block = blocks[(block_loc.block_addr, block_loc.block_idx)] tmp_def_stmtidx = tmp_deflocs[block_loc][tmp_atom] stmt = block.statements[tmp_def_stmtidx] if isinstance(stmt, Assignment): r, v = is_const_assignment(stmt) if r: # we can propagate it! 
for tmp_used, tmp_use_stmtidx in tmp_uses: replacements[ CodeLocation(block_loc.block_addr, tmp_use_stmtidx, block_idx=block_loc.block_idx) ][tmp_used] = stmt.src continue r = is_const_vvar_tmp_assignment(stmt) if r: # we can propagate it! if isinstance(stmt.src, VirtualVariable): v = const_vvars.get(stmt.src.varid, stmt.src) else: v = stmt.src for tmp_used, tmp_use_stmtidx in tmp_uses: replacements[ CodeLocation(block_loc.block_addr, tmp_use_stmtidx, block_idx=block_loc.block_idx) ][tmp_used] = v continue if len(tmp_uses) <= 2: tmp_used, tmp_use_stmtidx = next(iter(tmp_uses)) if is_const_vvar_load_dirty_assignment(stmt): same_inst = ( block.statements[tmp_def_stmtidx].ins_addr == block.statements[tmp_use_stmtidx].ins_addr ) has_store = any( isinstance(stmt_, Store) for stmt_ in block.statements[tmp_def_stmtidx + 1 : tmp_use_stmtidx] ) if same_inst or not has_store: # we can propagate this load because either we do not consider memory aliasing problem # within the same instruction (blocks must be originally lifted with # CROSS_INSN_OPT=False), or there is no store between its def and use. replacements[ CodeLocation(block_loc.block_addr, tmp_use_stmtidx, block_idx=block_loc.block_idx) ][tmp_used] = stmt.src continue self.model.replacements = replacements
[docs] @staticmethod def is_global_variable_updated( func_graph, block_dict, varid: int, gv_addr: int, gv_size: int, defloc: CodeLocation, useloc: CodeLocation ) -> bool: defblock = block_dict[(defloc.block_addr, defloc.block_idx)] useblock = block_dict[(useloc.block_addr, useloc.block_idx)] # traverse a graph slice from the def block to the use block and check if the global variable is updated seen = {defblock} queue = [defblock] while queue: block = queue.pop(0) start_stmt_idx = defloc.stmt_idx if block is defblock else 0 # inclusive end_stmt_idx = useloc.stmt_idx if block is useblock else len(block.statements) # exclusive assert start_stmt_idx is not None assert end_stmt_idx is not None for idx in range(start_stmt_idx, end_stmt_idx): stmt = block.statements[idx] if isinstance(stmt, Store) and isinstance(stmt.addr, Const): store_addr = stmt.addr.value store_size = stmt.size if gv_addr <= store_addr < gv_addr + gv_size or store_addr <= gv_addr < store_addr + store_size: return True if block is useblock: continue for succ in func_graph.successors(block): if succ not in seen: abort_path = False for stmt in succ.statements: if is_phi_assignment(stmt) and any( vvar.varid == varid for _, vvar in stmt.src.src_and_vvars if vvar is not None ): # the virtual variable is no longer live after this point abort_path = True break if abort_path: continue seen.add(succ) queue.append(succ) return False
[docs] @staticmethod def is_vvar_used_for_addr_loading_switch_case(uselocs: set[CodeLocation], blocks) -> bool: """ Check if a virtual variable is used for loading an address in a switch-case construct. :param uselocs: The use locations of the virtual variable. :param blocks: All blocks of the current function. :return: True if the virtual variable is used for loading an address in a switch-case construct, False otherwise. """ if len(uselocs) != 2: return False useloc_0, useloc_1 = list(uselocs) block_0 = blocks[(useloc_0.block_addr, useloc_0.block_idx)] stmt_0 = block_0.statements[useloc_0.stmt_idx] block_1 = blocks[(useloc_1.block_addr, useloc_1.block_idx)] stmt_1 = block_1.statements[useloc_1.stmt_idx] if isinstance(stmt_0, Jump): stmt_0, stmt_1 = stmt_1, stmt_0 block_0, block_1 = block_1, block_0 if not isinstance(stmt_0, ConditionalJump) or not isinstance(stmt_1, Jump): return False # check if stmt_0 jumps to block_1 if not isinstance(stmt_0.true_target, Const) or not isinstance(stmt_0.false_target, Const): return False stmt_0_targets = { (stmt_0.true_target.value, stmt_0.true_target_idx), (stmt_0.false_target.value, stmt_0.false_target_idx), } return (block_1.addr, block_1.idx) in stmt_0_targets
# Register SPropagatorAnalysis via register_analysis under the name "SPropagator".
register_analysis(SPropagatorAnalysis, "SPropagator")