Source code for angr.knowledge_plugins.key_definitions.rd_model

from typing import Dict, Tuple, Set, Union, Optional, TYPE_CHECKING, overload

from angr.knowledge_plugins.key_definitions.constants import ObservationPointType
from angr.code_location import CodeLocation

from .uses import Uses
from .live_definitions import LiveDefinitions

if TYPE_CHECKING:
    from angr.knowledge_plugins.key_definitions.definition import Definition


# TODO: Make ReachingDefinitionsModel serializable
class ReachingDefinitionsModel:
    """
    Models the definitions, uses, and memory of a ReachingDefinitionState object.

    Observations are stored in ``observed_results`` under keys of the form ``(location type, location, kind)``:

    - ``("insn", ins_addr, kind)``
    - ``("node", node_addr, kind)``
    - ``("stmt", (block_addr, stmt_idx), kind)`` or ``("stmt", (block_addr, stmt_idx, block_idx), kind)``
    """

    def __init__(self, func_addr: Optional[int] = None):
        """
        :param func_addr: Address of the function this model describes. Only used for pretty-printing.
        """
        self.func_addr = func_addr  # do not use. only for pretty-printing
        # Maps an observation key (see the class docstring) to the live definitions recorded at that point.
        self.observed_results: Dict[
            Tuple[str, Union[int, Tuple[int, int], Tuple[int, int, int]], ObservationPointType], LiveDefinitions
        ] = {}
        self.all_definitions: Set["Definition"] = set()
        self.all_uses = Uses()

    def __repr__(self):
        func_desc = f"[func {self.func_addr:#x}]" if self.func_addr is not None else ""
        return "<RDModel{} with {} observations>".format(func_desc, len(self.observed_results))

    def copy(self) -> "ReachingDefinitionsModel":
        """
        Create a new model with copies of this model's containers.

        ``observed_results`` and ``all_definitions`` are copied shallowly: the containers are new, the objects they
        hold are shared with this model.

        :return: The new model.
        """
        new = ReachingDefinitionsModel(self.func_addr)
        new.observed_results = self.observed_results.copy()
        new.all_definitions = self.all_definitions.copy()
        new.all_uses = self.all_uses.copy()
        return new

    def merge(self, model: "ReachingDefinitionsModel"):
        """
        Merge another model into this one, in place.

        Observations only present in `model` are adopted as-is (the object is shared, not copied). Observations
        present in both are merged via the ``merge()`` method of the stored value, and replaced only if that method
        reports that a merge occurred.

        :param model: The model to merge into this one. It is not modified by this method itself.
        """
        for key, other_defs in model.observed_results.items():
            if key not in self.observed_results:
                self.observed_results[key] = other_defs
                continue
            merged, merge_occurred = self.observed_results[key].merge(other_defs)
            if merge_occurred:
                self.observed_results[key] = merged
        # set.union() returns a new set and leaves the receiver untouched; update() merges in place.
        self.all_definitions.update(model.all_definitions)
        self.all_uses.merge(model.all_uses)

    def get_observation_by_insn(
        self, ins_addr: Union[int, CodeLocation], kind: ObservationPointType
    ) -> Optional[LiveDefinitions]:
        """
        Look up the observation recorded for an instruction.

        :param ins_addr: The instruction address, or a code location whose ``ins_addr`` is used.
        :param kind:     The observation point type the observation was recorded with.
        :return:         The recorded live definitions, or None if there is no such observation.
        :raises ValueError: If a code location without an instruction address is given.
        """
        if isinstance(ins_addr, int):
            addr = ins_addr
        else:
            if ins_addr.ins_addr is None:
                raise ValueError("CodeLocation must have an instruction address associated")
            addr = ins_addr.ins_addr
        return self.observed_results.get(("insn", addr, kind))

    def get_observation_by_node(
        self, node_addr: Union[int, CodeLocation], kind: ObservationPointType
    ) -> Optional[LiveDefinitions]:
        """
        Look up the observation recorded for a node.

        :param node_addr: The node address, or a code location whose ``block_addr`` is used.
        :param kind:      The observation point type the observation was recorded with.
        :return:          The recorded live definitions, or None if there is no such observation.
        """
        addr = node_addr if isinstance(node_addr, int) else node_addr.block_addr
        return self.observed_results.get(("node", addr, kind))

    @overload
    def get_observation_by_stmt(self, codeloc: CodeLocation, kind: ObservationPointType) -> Optional[LiveDefinitions]:
        ...

    @overload
    def get_observation_by_stmt(
        self, node_addr: int, stmt_idx: int, kind: ObservationPointType, *, block_idx: Optional[int] = None
    ) -> Optional[LiveDefinitions]:
        ...

    def get_observation_by_stmt(self, arg1, arg2, arg3=None, *, block_idx=None):
        """
        Look up the observation recorded for a statement.

        Two call forms are supported:

        - ``get_observation_by_stmt(codeloc, kind)``
        - ``get_observation_by_stmt(node_addr, stmt_idx, kind, block_idx=None)``

        When a block index is available (from the code location or the keyword argument), the lookup key is
        ``(block_addr, stmt_idx, block_idx)``; otherwise it is ``(block_addr, stmt_idx)``.

        :return: The recorded live definitions, or None if there is no such observation.
        :raises ValueError: If a code location without a statement index is given.
        """
        if isinstance(arg1, int):
            block_addr, stmt_idx, kind = arg1, arg2, arg3
        else:
            if arg1.stmt_idx is None:
                raise ValueError("CodeLocation must have a statement index associated")
            block_addr, stmt_idx, kind = arg1.block_addr, arg1.stmt_idx, arg2
            # The block index comes from the code location itself; the keyword argument belongs to the
            # (node_addr, stmt_idx, kind) call form only.
            block_idx = arg1.block_idx

        location = (block_addr, stmt_idx) if block_idx is None else (block_addr, stmt_idx, block_idx)
        return self.observed_results.get(("stmt", location, kind))