Source code for angr.analyses.typehoon.simple_solver

# pylint:disable=missing-class-docstring
from typing import Union, Type, Set, Dict, Optional, Tuple, List, DefaultDict
import enum
from collections import defaultdict
import logging

import networkx

from angr.utils.constants import MAX_POINTSTO_BITS
from .typevars import (
    Existence,
    Subtype,
    Equivalence,
    Add,
    TypeVariable,
    DerivedTypeVariable,
    HasField,
    IsArray,
    TypeConstraint,
    Load,
    Store,
    BaseLabel,
    FuncIn,
    FuncOut,
    ConvertTo,
)
from .typeconsts import (
    BottomType,
    TopType,
    TypeConstant,
    Int,
    Int8,
    Int16,
    Int32,
    Int64,
    Pointer,
    Pointer32,
    Pointer64,
    Struct,
    Array,
    Function,
    int_type,
)
from .variance import Variance
from .dfa import DFAConstraintSolver, EmptyEpsilonNFAError

_l = logging.getLogger(__name__)


PRIMITIVE_TYPES = {
    TopType(),
    Int(),
    Int8(),
    Int16(),
    Int32(),
    Int64(),
    Pointer32(),
    Pointer64(),
    BottomType(),
    Struct(),
    Array(),
}

Top_ = TopType()
Int_ = Int()
Int64_ = Int64()
Int32_ = Int32()
Int16_ = Int16()
Int8_ = Int8()
Bottom_ = BottomType()
Pointer64_ = Pointer64()
Pointer32_ = Pointer32()
Struct_ = Struct()
Array_ = Array()

# lattice for 64-bit binaries
BASE_LATTICE_64 = networkx.DiGraph()
BASE_LATTICE_64.add_edge(Top_, Int_)
BASE_LATTICE_64.add_edge(Int_, Int64_)
BASE_LATTICE_64.add_edge(Int_, Int32_)
BASE_LATTICE_64.add_edge(Int_, Int16_)
BASE_LATTICE_64.add_edge(Int_, Int8_)
BASE_LATTICE_64.add_edge(Int32_, Bottom_)
BASE_LATTICE_64.add_edge(Int16_, Bottom_)
BASE_LATTICE_64.add_edge(Int8_, Bottom_)
BASE_LATTICE_64.add_edge(Int64_, Pointer64_)
BASE_LATTICE_64.add_edge(Pointer64_, Bottom_)

# lattice for 32-bit binaries
BASE_LATTICE_32 = networkx.DiGraph()
BASE_LATTICE_32.add_edge(Top_, Int_)
BASE_LATTICE_32.add_edge(Int_, Int64_)
BASE_LATTICE_32.add_edge(Int_, Int32_)
BASE_LATTICE_32.add_edge(Int_, Int16_)
BASE_LATTICE_32.add_edge(Int_, Int8_)
BASE_LATTICE_32.add_edge(Int32_, Pointer32_)
BASE_LATTICE_32.add_edge(Int64_, Bottom_)
BASE_LATTICE_32.add_edge(Pointer32_, Bottom_)
BASE_LATTICE_32.add_edge(Int16_, Bottom_)
BASE_LATTICE_32.add_edge(Int8_, Bottom_)

BASE_LATTICES = {
    32: BASE_LATTICE_32,
    64: BASE_LATTICE_64,
}
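
# Illustrative example (not part of the module): joins over the base lattice are computed as
# lowest common ancestors of the abstracted types (see SimpleSolver.join below). Assuming the
# 64-bit lattice defined above:
#
#     ancestor = networkx.lowest_common_ancestor(BASE_LATTICE_64, Int32_, Int64_)
#     assert ancestor == Int_      # Int32 and Int64 join at Int
#     ancestor = networkx.lowest_common_ancestor(BASE_LATTICE_64, Int64_, Pointer64_)
#     assert ancestor == Int64_    # Int64 is an ancestor of Pointer64 in the 64-bit lattice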


#
# Sketch
#


class SketchNodeBase:
    """
    The base class for nodes in a sketch.
    """

    __slots__ = ()


class SketchNode(SketchNodeBase):
    """
    Represents a node in a sketch graph.
    """

    __slots__ = ("typevar", "upper_bound", "lower_bound")

    def __init__(self, typevar: Union[TypeVariable, DerivedTypeVariable]):
        self.typevar: Union[TypeVariable, DerivedTypeVariable] = typevar
        self.upper_bound = TopType()
        self.lower_bound = BottomType()

    def __repr__(self):
        return f"{self.lower_bound} <: {self.typevar} <: {self.upper_bound}"

    def __eq__(self, other):
        return isinstance(other, SketchNode) and self.typevar == other.typevar

    def __hash__(self):
        return hash((SketchNode, self.typevar))


class RecursiveRefNode(SketchNodeBase):
    """
    Represents a cycle in a sketch graph. This is equivalent to sketches.LabelNode in the reference implementation
    of retypd.
    """

    def __init__(self, target: DerivedTypeVariable):
        self.target: DerivedTypeVariable = target

    def __hash__(self):
        return hash((RecursiveRefNode, self.target))

    def __eq__(self, other):
        return type(other) is RecursiveRefNode and other.target == self.target


class Sketch:
    """
    Describes the sketch of a type variable.
    """

    __slots__ = (
        "graph",
        "root",
        "node_mapping",
        "solver",
    )

    def __init__(self, solver: "SimpleSolver", root: TypeVariable):
        self.root: SketchNode = SketchNode(root)
        self.graph = networkx.DiGraph()
        self.node_mapping: Dict[Union[TypeVariable, DerivedTypeVariable], SketchNodeBase] = {}
        self.solver = solver

        # add the root node
        self.graph.add_node(self.root)
        self.node_mapping[root] = self.root

    def lookup(self, typevar: Union[TypeVariable, DerivedTypeVariable]) -> Optional[SketchNodeBase]:
        if typevar in self.node_mapping:
            return self.node_mapping[typevar]
        node: Optional[SketchNodeBase] = None
        if isinstance(typevar, DerivedTypeVariable):
            node = self.node_mapping[SimpleSolver._to_typevar_or_typeconst(typevar.type_var)]
            for label in typevar.labels:
                succs = []
                for _, dst, data in self.graph.out_edges(node, data=True):
                    if "label" in data and data["label"] == label:
                        succs.append(dst)
                assert len(succs) <= 1
                if not succs:
                    return None
                node = succs[0]
                if isinstance(node, RecursiveRefNode):
                    node = self.lookup(node.target)
        return node

    def add_edge(self, src: SketchNodeBase, dst: SketchNodeBase, label):
        self.graph.add_edge(src, dst, label=label)

    def add_constraint(self, constraint: TypeConstraint) -> None:
        # sub <: super
        if not isinstance(constraint, Subtype):
            return
        subtype = self.flatten_typevar(constraint.sub_type)
        supertype = self.flatten_typevar(constraint.super_type)
        if SimpleSolver._typevar_inside_set(subtype, PRIMITIVE_TYPES) and not SimpleSolver._typevar_inside_set(
            supertype, PRIMITIVE_TYPES
        ):
            super_node: Optional[SketchNode] = self.lookup(supertype)
            if super_node is not None:
                super_node.lower_bound = self.solver.join(super_node.lower_bound, subtype)
        elif SimpleSolver._typevar_inside_set(supertype, PRIMITIVE_TYPES) and not SimpleSolver._typevar_inside_set(
            subtype, PRIMITIVE_TYPES
        ):
            sub_node: Optional[SketchNode] = self.lookup(subtype)
            # assert sub_node is not None
            if sub_node is not None:
                sub_node.upper_bound = self.solver.meet(sub_node.upper_bound, supertype)

    @staticmethod
    def flatten_typevar(
        derived_typevar: Union[TypeVariable, TypeConstant, DerivedTypeVariable]
    ) -> Union[DerivedTypeVariable, TypeVariable, TypeConstant]:
        # pylint:disable=too-many-boolean-expressions
        if (
            isinstance(derived_typevar, DerivedTypeVariable)
            and isinstance(derived_typevar.type_var, Pointer)
            and SimpleSolver._typevar_inside_set(derived_typevar.type_var.basetype, PRIMITIVE_TYPES)
            and len(derived_typevar.labels) == 2
            and isinstance(derived_typevar.labels[0], Load)
            and isinstance(derived_typevar.labels[1], HasField)
            and derived_typevar.labels[1].offset == 0
            and derived_typevar.labels[1].bits == MAX_POINTSTO_BITS
        ):
            return derived_typevar.type_var.basetype
        return derived_typevar
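

# Illustrative note (not part of the module): a Sketch is a labeled tree (plus RecursiveRefNode
# back-references for cycles) rooted at a type variable. Each edge carries a label such as Load,
# Store, FuncIn, or HasField, and the leaves carry lower/upper bounds that add_constraint()
# tightens with primitive types. For a variable `v` that is dereferenced and read at offset 0,
# the graph roughly looks like (informal notation):
#
#     SketchNode(v) --Load--> SketchNode(v.load) --HasField--> SketchNode(v.load.<field>)
#
# where the bounds on the leaf eventually determine the C-like type of the pointed-to field.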


#
# Constraint graph
#


class ConstraintGraphTag(enum.Enum):
    LEFT = 0
    RIGHT = 1
    UNKNOWN = 2


class FORGOTTEN(enum.Enum):
    PRE_FORGOTTEN = 0
    POST_FORGOTTEN = 1


class ConstraintGraphNode:
    __slots__ = ("typevar", "variance", "tag", "forgotten")

    def __init__(
        self,
        typevar: Union[TypeVariable, DerivedTypeVariable],
        variance: Variance,
        tag: ConstraintGraphTag,
        forgotten: FORGOTTEN,
    ):
        self.typevar = typevar
        self.variance = variance
        self.tag = tag
        self.forgotten = forgotten

    def __repr__(self):
        variance_str = "CO" if self.variance == Variance.COVARIANT else "CONTRA"
        if self.tag == ConstraintGraphTag.LEFT:
            tag_str = "L"
        elif self.tag == ConstraintGraphTag.RIGHT:
            tag_str = "R"
        else:
            tag_str = "U"
        forgotten_str = "PRE" if self.forgotten == FORGOTTEN.PRE_FORGOTTEN else "POST"
        s = f"{self.typevar}#{variance_str}.{tag_str}.{forgotten_str}"
        if ":" in s:
            return '"' + s + '"'
        return s

    def __eq__(self, other):
        if not isinstance(other, ConstraintGraphNode):
            return False
        return (
            self.typevar == other.typevar
            and self.variance == other.variance
            and self.tag == other.tag
            and self.forgotten == other.forgotten
        )

    def __hash__(self):
        return hash((ConstraintGraphNode, self.typevar, self.variance, self.tag, self.forgotten))

    def forget_last_label(self) -> Optional[Tuple["ConstraintGraphNode", BaseLabel]]:
        if isinstance(self.typevar, DerivedTypeVariable) and self.typevar.labels:
            last_label = self.typevar.labels[-1]
            if len(self.typevar.labels) == 1:
                prefix = self.typevar.type_var
            else:
                prefix = DerivedTypeVariable(self.typevar.type_var, None, labels=self.typevar.labels[:-1])
            if self.variance == last_label.variance:
                variance = Variance.COVARIANT
            else:
                variance = Variance.CONTRAVARIANT
            return (
                ConstraintGraphNode(prefix, variance, self.tag, FORGOTTEN.PRE_FORGOTTEN),
                self.typevar.labels[-1],
            )
        return None

    def recall(self, label: BaseLabel) -> "ConstraintGraphNode":
        if isinstance(self.typevar, DerivedTypeVariable):
            labels = self.typevar.labels + (label,)
            typevar = self.typevar.type_var
        elif isinstance(self.typevar, TypeVariable):
            labels = (label,)
            typevar = self.typevar
        elif isinstance(self.typevar, TypeConstant):
            labels = (label,)
            typevar = self.typevar
        else:
            raise TypeError(f"Unsupported type {type(self.typevar)}")
        if self.variance == label.variance:
            variance = Variance.COVARIANT
        else:
            variance = Variance.CONTRAVARIANT
        if not labels:
            var = typevar
        else:
            var = DerivedTypeVariable(typevar, None, labels=labels)
        return ConstraintGraphNode(var, variance, self.tag, FORGOTTEN.PRE_FORGOTTEN)

    def inverse(self) -> "ConstraintGraphNode":
        if self.tag == ConstraintGraphTag.LEFT:
            tag = ConstraintGraphTag.RIGHT
        elif self.tag == ConstraintGraphTag.RIGHT:
            tag = ConstraintGraphTag.LEFT
        else:
            tag = ConstraintGraphTag.UNKNOWN
        if self.variance == Variance.COVARIANT:
            variance = Variance.CONTRAVARIANT
        else:
            variance = Variance.COVARIANT
        return ConstraintGraphNode(self.typevar, variance, tag, self.forgotten)

    def inverse_wo_tag(self) -> "ConstraintGraphNode":
        """
        Invert the variance only.
        """
        if self.variance == Variance.COVARIANT:
            variance = Variance.CONTRAVARIANT
        else:
            variance = Variance.COVARIANT
        return ConstraintGraphNode(self.typevar, variance, self.tag, self.forgotten)
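

# Illustrative example (not part of the module): "forget" pops the last label of a derived type
# variable and "recall" pushes it back, which is how the recall/forget edges of the constraint
# graph below are built. Assuming `v_load` is a hypothetical DerivedTypeVariable with a single
# Load() label on top of a plain TypeVariable `v`:
#
#     node = ConstraintGraphNode(v_load, Variance.COVARIANT, ConstraintGraphTag.LEFT, FORGOTTEN.PRE_FORGOTTEN)
#     prefix, last_label = node.forget_last_label()             # prefix wraps the bare type variable `v`
#     assert prefix.recall(last_label).typevar == node.typevar  # recalling the popped label restores `v.load`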


#
# The solver
#


class SimpleSolver:
    """
    SimpleSolver is, by its name, a simple solver. Most of this solver is based on the (complex) simplification logic
    that the retypd paper describes and the retypd re-implementation (https://github.com/GrammaTech/retypd)
    implements. Additionally, we add a few improvements, such as allowing type propagation of known struct names.
    """

    def __init__(self, bits: int, constraints, typevars):
        if bits not in (32, 64):
            raise ValueError("Pointer size %d is not supported. Expect 32 or 64." % bits)

        self.bits = bits
        self._constraints: Dict[TypeVariable, Set[TypeConstraint]] = constraints
        self._typevars: Set[TypeVariable] = typevars
        self._base_lattice = BASE_LATTICES[bits]
        self._base_lattice_inverted = networkx.DiGraph()
        for src, dst in self._base_lattice.edges:
            self._base_lattice_inverted.add_edge(dst, src)

        #
        # Solving state
        #
        self._equivalence = defaultdict(dict)

        for typevar in list(self._constraints):
            if self._constraints[typevar]:
                self._constraints[typevar] |= self._eq_constraints_from_add(typevar)
                self._constraints[typevar] = self._handle_equivalence(typevar)

        equ_classes, sketches, _ = self.solve()

        self.solution = {}
        self._solution_cache = {}
        self.determine(equ_classes, sketches, self.solution)

        for typevar in list(self._constraints):
            self._convert_arrays(self._constraints[typevar])
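
    # Illustrative usage (not part of the module; variable names are hypothetical): the solver is
    # typically driven by the Typehoon analysis, which collects per-variable constraint sets and
    # then reads back C-like types from the solution:
    #
    #     solver = SimpleSolver(64, constraints_by_typevar, all_typevars)
    #     c_like_type = solver.solution.get(some_typevar)
    #
    # where `constraints_by_typevar` maps each TypeVariable to its set of TypeConstraints and
    # `all_typevars` is the set of type variables of interest.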

    def solve(self):
        """
        Steps:

        For each type variable,

        - Infer the shape in its sketch
        - Build the constraint graph
        - Collect all constraints
        - Apply constraints to derive the lower and upper bounds
        """
        typevars = set(self._constraints) | self._typevars
        constraints = set()
        for tv in typevars:
            if tv in self._constraints:
                constraints |= self._constraints[tv]

        # collect typevars used in the constraint set
        constrained_typevars = set()
        for constraint in constraints:
            if isinstance(constraint, Subtype):
                for t in (constraint.sub_type, constraint.super_type):
                    if isinstance(t, DerivedTypeVariable):
                        if t.type_var in typevars:
                            constrained_typevars.add(t.type_var)
                    elif isinstance(t, TypeVariable):
                        if t in typevars:
                            constrained_typevars.add(t)

        equivalence_classes, sketches = self.infer_shapes(typevars, constraints)

        # TODO: Handle global variables

        type_schemes = constraints

        constraintset2tvs = defaultdict(set)
        for idx, tv in enumerate(constrained_typevars):
            _l.debug("Collecting constraints for type variable %r (%d/%d)", tv, idx + 1, len(constrained_typevars))
            # build a sub constraint set for the type variable
            constraint_subset = frozenset(self._generate_constraint_subset(constraints, {tv}))
            constraintset2tvs[constraint_subset].add(tv)

        for idx, (constraint_subset, tvs) in enumerate(constraintset2tvs.items()):
            _l.debug(
                "Solving %d constraints for type variables %r (%d/%d)",
                len(constraint_subset),
                tvs,
                idx + 1,
                len(constraintset2tvs),
            )
            base_constraint_graph = self._generate_constraint_graph(constraint_subset, tvs | PRIMITIVE_TYPES)
            for idx_0, tv in enumerate(tvs):
                _l.debug("Solving for type variable %r (%d/%d)", tv, idx_0 + 1, len(tvs))
                primitive_constraints = self._generate_primitive_constraints({tv}, base_constraint_graph)
                for primitive_constraint in primitive_constraints:
                    sketches[tv].add_constraint(primitive_constraint)

        return equivalence_classes, sketches, type_schemes

    def infer_shapes(
        self, typevars: Set[TypeVariable], constraints: Set[TypeConstraint]
    ) -> Tuple[Dict, Dict[TypeVariable, Sketch]]:
        """
        Computing sketches from constraint sets. Implements Algorithm E.1 in the retypd paper.
        """
        equivalence_classes, quotient_graph = self.compute_quotient_graph(constraints)

        sketches: Dict[TypeVariable, Sketch] = {}
        for tv in typevars:
            sketches[tv] = Sketch(self, tv)

        for tv, sketch in sketches.items():
            sketch_node = sketch.lookup(tv)
            graph_node = equivalence_classes.get(tv, None)
            # assert graph_node is not None
            if graph_node is None:
                continue
            visited = {graph_node: sketch_node}
            self._get_all_paths(quotient_graph, sketch, graph_node, visited)

        return equivalence_classes, sketches

    def compute_quotient_graph(self, constraints: Set[TypeConstraint]):
        """
        Compute the quotient graph (the constraint graph modulo ~ in Algorithm E.1 in the retypd paper) with respect
        to a given set of type constraints.
        """
        g = networkx.DiGraph()

        # collect all derived type variables
        typevars = self._typevars_from_constraints(constraints)
        g.add_nodes_from(typevars)

        # add paths for each derived type variable into the graph
        for tv in typevars:
            last_node = tv
            prefix = tv
            while isinstance(prefix, DerivedTypeVariable) and prefix.labels:
                prefix = prefix.longest_prefix()
                if prefix is None:
                    continue
                g.add_edge(prefix, last_node, label=last_node.labels[-1])
                last_node = prefix

        # compute the constraint graph modulo ~
        equivalence_classes = {node: node for node in g}

        load = Load()
        store = Store()
        for node in g.nodes:
            lbl_to_node = {}
            for succ in g.successors(node):
                lbl_to_node[succ.labels[-1]] = succ
            if load in lbl_to_node and store in lbl_to_node:
                self._unify(equivalence_classes, lbl_to_node[load], lbl_to_node[store], g)

        for constraint in constraints:
            if isinstance(constraint, Subtype):
                if self._typevar_inside_set(constraint.super_type, PRIMITIVE_TYPES) or self._typevar_inside_set(
                    constraint.sub_type, PRIMITIVE_TYPES
                ):
                    continue
                self._unify(equivalence_classes, constraint.super_type, constraint.sub_type, g)

        # there can be multiple edges between two nodes; each edge is associated with a different label
        out_graph = networkx.MultiDiGraph()
        for src, dst, data in g.edges(data=True):
            src_cls = equivalence_classes[src]
            dst_cls = equivalence_classes[dst]
            label = None if not data else data["label"]
            if label is not None and out_graph.has_edge(src_cls, dst_cls):
                # do not add the same edge twice
                existing_labels = {
                    data_["label"]
                    for _, dst_cls_, data_ in out_graph.out_edges(src_cls, data=True)
                    if dst_cls_ == dst_cls and data
                }
                if label in existing_labels:
                    continue
            out_graph.add_edge(src_cls, dst_cls, label=label)

        return equivalence_classes, out_graph
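
    # Worked example (informal, not part of the module): if the constraint set mentions both
    # `p.load.<32>@0` and `p.store.<32>@0`, the Load and Store successors of `p` are unified into
    # one equivalence class, since reads and writes through `p` describe the same memory object.
    # Likewise, any `a <: b` constraint between non-primitive variables merges the classes of `a`
    # and `b`. The returned quotient graph has one node per equivalence class, with labeled edges
    # deduplicated across the merged members.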

    def _generate_primitive_constraints(
        self,
        non_primitive_endpoints: Set[Union[TypeVariable, DerivedTypeVariable]],
        constraint_graph,
    ) -> Set[TypeConstraint]:
        # FIXME: Extract interesting variables
        constraints_0 = self._solve_constraints_between(constraint_graph, non_primitive_endpoints, PRIMITIVE_TYPES)
        constraints_1 = self._solve_constraints_between(constraint_graph, PRIMITIVE_TYPES, non_primitive_endpoints)
        return constraints_0 | constraints_1

    @staticmethod
    def _typevars_from_constraints(constraints: Set[TypeConstraint]) -> Set[Union[TypeVariable, DerivedTypeVariable]]:
        """
        Collect derived type variables from a set of constraints.
        """
        typevars: Set[Union[TypeVariable, DerivedTypeVariable]] = set()

        for constraint in constraints:
            if isinstance(constraint, Subtype):
                typevars.add(constraint.sub_type)
                typevars.add(constraint.super_type)
            # TODO: Other types of constraints?

        return typevars

    @staticmethod
    def _get_all_paths(
        graph: networkx.DiGraph,
        sketch: Sketch,
        node: DerivedTypeVariable,
        visited: Dict[Union[TypeVariable, DerivedTypeVariable], SketchNode],
    ):
        if node not in graph:
            return
        curr_node = visited[node]
        for _, succ, data in graph.out_edges(node, data=True):
            label = data["label"]
            if succ not in visited:
                if isinstance(curr_node.typevar, DerivedTypeVariable):
                    base_typevar = curr_node.typevar.type_var
                    labels = curr_node.typevar.labels
                elif isinstance(curr_node.typevar, TypeVariable):
                    base_typevar = curr_node.typevar
                    labels = ()
                else:
                    raise TypeError("Unexpected")
                labels += (label,)
                succ_derived_typevar = DerivedTypeVariable(
                    base_typevar,
                    None,
                    labels=labels,
                )
                succ_node = SketchNode(succ_derived_typevar)
                sketch.add_edge(curr_node, succ_node, label)
                visited[succ] = succ_node
                SimpleSolver._get_all_paths(graph, sketch, succ, visited)
                del visited[succ]
            else:
                # a cycle exists
                ref_node = RecursiveRefNode(visited[succ].typevar)
                sketch.add_edge(curr_node, ref_node, label)

    @staticmethod
    def _unify(
        equivalence_classes: Dict, cls0: DerivedTypeVariable, cls1: DerivedTypeVariable, graph: networkx.DiGraph
    ) -> None:
        # first convert cls0 and cls1 to their equivalence classes
        cls0 = equivalence_classes[cls0]
        cls1 = equivalence_classes[cls1]

        # unify if needed
        if cls0 != cls1:
            # MakeEquiv
            existing_elements = {key for key, item in equivalence_classes.items() if item in {cls0, cls1}}
            rep_cls = cls0
            for elem in existing_elements:
                equivalence_classes[elem] = rep_cls

            # the logic below refers to the retypd reference implementation; it is different from Algorithm E.1.
            # note that graph is used read-only in this method, so we do not need to make a copy of the edges
            for _, dst0, data0 in graph.out_edges(cls0, data=True):
                if "label" in data0 and data0["label"] is not None:
                    for _, dst1, data1 in graph.out_edges(cls1, data=True):
                        if (
                            data0["label"] == data1["label"]
                            or isinstance(data0["label"], Load)
                            and isinstance(data1["label"], Store)
                        ):
                            SimpleSolver._unify(
                                equivalence_classes, equivalence_classes[dst0], equivalence_classes[dst1], graph
                            )

    def _eq_constraints_from_add(self, typevar: TypeVariable):
        """
        Handle Add constraints.
        """
        new_constraints = set()
        for constraint in self._constraints[typevar]:
            if isinstance(constraint, Add):
                if (
                    isinstance(constraint.type_0, TypeVariable)
                    and not isinstance(constraint.type_0, DerivedTypeVariable)
                    and isinstance(constraint.type_r, TypeVariable)
                    and not isinstance(constraint.type_r, DerivedTypeVariable)
                ):
                    new_constraints.add(Equivalence(constraint.type_0, constraint.type_r))
                if (
                    isinstance(constraint.type_1, TypeVariable)
                    and not isinstance(constraint.type_1, DerivedTypeVariable)
                    and isinstance(constraint.type_r, TypeVariable)
                    and not isinstance(constraint.type_r, DerivedTypeVariable)
                ):
                    new_constraints.add(Equivalence(constraint.type_1, constraint.type_r))
        return new_constraints

    def _handle_equivalence(self, typevar: TypeVariable):
        graph = networkx.Graph()

        replacements = {}
        constraints = set()

        # collect equivalence relations
        for constraint in self._constraints[typevar]:
            if isinstance(constraint, Equivalence):
                # type_a == type_b
                # we apply unification and remove one of them
                ta, tb = constraint.type_a, constraint.type_b
                if isinstance(ta, TypeConstant) and isinstance(tb, TypeVariable):
                    # replace tb with ta
                    replacements[tb] = ta
                elif isinstance(ta, TypeVariable) and isinstance(tb, TypeConstant):
                    # replace ta with tb
                    replacements[ta] = tb
                else:
                    # they are both type variables. we will determine a representative later
                    graph.add_edge(ta, tb)

        for components in networkx.connected_components(graph):
            components_lst = list(sorted(components, key=lambda x: str(x)))  # pylint:disable=unnecessary-lambda
            representative = components_lst[0]
            for tv in components_lst[1:]:
                replacements[tv] = representative

        # replace
        for constraint in self._constraints[typevar]:
            if isinstance(constraint, Existence):
                replaced, new_constraint = constraint.replace(replacements)
                if replaced:
                    constraints.add(new_constraint)
                else:
                    constraints.add(constraint)
            elif isinstance(constraint, Subtype):
                # subtype <: supertype
                # replace type variables
                replaced, new_constraint = constraint.replace(replacements)
                if replaced:
                    constraints.add(new_constraint)
                else:
                    constraints.add(constraint)

        # import pprint
        # print("Replacements")
        # pprint.pprint(replacements)
        # print("Constraints (after replacement)")
        # pprint.pprint(constraints)

        self._equivalence = replacements
        return constraints

    def _convert_arrays(self, constraints):
        for constraint in constraints:
            if not isinstance(constraint, Existence):
                continue
            inner = constraint.type_
            if isinstance(inner, DerivedTypeVariable) and isinstance(inner.one_label(), IsArray):
                if inner.type_var in self.solution:
                    curr_type = self.solution[inner.type_var]
                    if isinstance(curr_type, Pointer) and isinstance(curr_type.basetype, Struct):
                        # replace all fields with the first field
                        if 0 in curr_type.basetype.fields:
                            first_field = curr_type.basetype.fields[0]
                            for offset in curr_type.basetype.fields.keys():
                                curr_type.basetype.fields[offset] = first_field

    #
    # Constraint graph
    #

    @staticmethod
    def _generate_constraint_subset(
        constraints: Set[TypeConstraint], typevars: Set[TypeVariable]
    ) -> Set[TypeConstraint]:
        subset = set()
        related_typevars = set(typevars)
        while True:
            new = set()
            for constraint in constraints:
                if constraint in subset:
                    continue
                if isinstance(constraint, Subtype):
                    if isinstance(constraint.sub_type, DerivedTypeVariable):
                        subt = constraint.sub_type.type_var
                    elif isinstance(constraint.sub_type, TypeVariable):
                        subt = constraint.sub_type
                    else:
                        subt = None
                    if isinstance(constraint.super_type, DerivedTypeVariable):
                        supert = constraint.super_type.type_var
                    elif isinstance(constraint.super_type, TypeVariable):
                        supert = constraint.super_type
                    else:
                        supert = None
                    if subt in related_typevars or supert in related_typevars:
                        new.add(constraint)
                        if subt is not None:
                            related_typevars.add(subt)
                        if supert is not None:
                            related_typevars.add(supert)
            if not new:
                break
            subset |= new
        return subset

    def _generate_constraint_graph(
        self, constraints: Set[TypeConstraint], interesting_variables: Set[DerivedTypeVariable]
    ) -> networkx.DiGraph:
        """
        A constraint graph is the same as the finite state transducer that is presented in Appendix D in the retypd
        paper.
        """
        graph = networkx.DiGraph()
        for constraint in constraints:
            if isinstance(constraint, Subtype):
                self._constraint_graph_add_edges(
                    graph, constraint.sub_type, constraint.super_type, interesting_variables
                )
        self._constraint_graph_saturate(graph)
        self._constraint_graph_remove_self_loops(graph)
        self._constraint_graph_recall_forget_split(graph)
        return graph

    @staticmethod
    def _constraint_graph_add_recall_edges(graph: networkx.DiGraph, node: ConstraintGraphNode) -> None:
        while True:
            r = node.forget_last_label()
            if r is None:
                break
            prefix, last_label = r
            graph.add_edge(prefix, node, label=(last_label, "recall"))
            node = prefix

    @staticmethod
    def _constraint_graph_add_forget_edges(graph: networkx.DiGraph, node: ConstraintGraphNode) -> None:
        while True:
            r = node.forget_last_label()
            if r is None:
                break
            prefix, last_label = r
            graph.add_edge(node, prefix, label=(last_label, "forget"))
            node = prefix

    def _constraint_graph_add_edges(
        self,
        graph: networkx.DiGraph,
        subtype: Union[TypeVariable, DerivedTypeVariable],
        supertype: Union[TypeVariable, DerivedTypeVariable],
        interesting_variables: Set[DerivedTypeVariable],
    ):
        # left and right tags
        if self._typevar_inside_set(self._to_typevar_or_typeconst(subtype), interesting_variables):
            left_tag = ConstraintGraphTag.LEFT
        else:
            left_tag = ConstraintGraphTag.UNKNOWN
        if self._typevar_inside_set(self._to_typevar_or_typeconst(supertype), interesting_variables):
            right_tag = ConstraintGraphTag.RIGHT
        else:
            right_tag = ConstraintGraphTag.UNKNOWN

        # nodes
        forward_src = ConstraintGraphNode(subtype, Variance.COVARIANT, left_tag, FORGOTTEN.PRE_FORGOTTEN)
        forward_dst = ConstraintGraphNode(supertype, Variance.COVARIANT, right_tag, FORGOTTEN.PRE_FORGOTTEN)
        graph.add_edge(forward_src, forward_dst)

        # add recall edges and forget edges
        self._constraint_graph_add_recall_edges(graph, forward_src)
        self._constraint_graph_add_forget_edges(graph, forward_dst)

        # backward edges
        backward_src = forward_dst.inverse()
        backward_dst = forward_src.inverse()
        graph.add_edge(backward_src, backward_dst)
        self._constraint_graph_add_recall_edges(graph, backward_src)
        self._constraint_graph_add_forget_edges(graph, backward_dst)

    @staticmethod
    def _constraint_graph_saturate(graph: networkx.DiGraph) -> None:
        """
        The saturation algorithm D.2 as described in the appendix of the retypd paper.
        """
        R: DefaultDict[ConstraintGraphNode, Set[Tuple[BaseLabel, ConstraintGraphNode]]] = defaultdict(set)

        # initialize the reaching-push sets R(x)
        for x, y, data in graph.edges(data=True):
            if "label" in data and data.get("label")[1] == "forget":
                d = data["label"][0], x
                R[y].add(d)

        # repeat until a fixed point is reached
        changed = True
        while changed:
            changed = False
            for x, y, data in graph.edges(data=True):
                if "label" not in data:
                    if R[y].issuperset(R[x]):
                        continue
                    changed = True
                    R[y] |= R[x]
            for x, y, data in graph.edges(data=True):
                lbl = data.get("label")
                if lbl and lbl[1] == "recall":
                    for label, z in R[x]:
                        if not graph.has_edge(z, y):
                            changed = True
                            graph.add_edge(z, y)
            v_contravariant = []
            for node in graph.nodes:
                node: ConstraintGraphNode
                if node.variance == Variance.CONTRAVARIANT:
                    v_contravariant.append(node)
            # lazily apply saturation rules corresponding to S-Pointer
            for x in v_contravariant:
                for z_label, z in R[x]:
                    label = None
                    if isinstance(z_label, Store):
                        label = Load()
                    elif isinstance(z_label, Load):
                        label = Store()
                    if label is not None:
                        x_inverse = x.inverse_wo_tag()
                        d = label, z
                        if d not in R[x_inverse]:
                            changed = True
                            R[x_inverse].add(d)

    @staticmethod
    def _constraint_graph_remove_self_loops(graph: networkx.DiGraph):
        for node in list(graph.nodes):
            if graph.has_edge(node, node):
                graph.remove_edge(node, node)

    @staticmethod
    def _constraint_graph_recall_forget_split(graph: networkx.DiGraph):
        """
        Ensure that recall edges are not reachable after traversing a forget node.
        """
        for src, dst, data in list(graph.edges(data=True)):
            src: ConstraintGraphNode
            dst: ConstraintGraphNode
            if "label" in data and data["label"][1] == "recall":
                continue
            forget_src = ConstraintGraphNode(src.typevar, src.variance, src.tag, FORGOTTEN.POST_FORGOTTEN)
            forget_dst = ConstraintGraphNode(dst.typevar, dst.variance, dst.tag, FORGOTTEN.POST_FORGOTTEN)
            if "label" in data and data["label"][1] == "forget":
                graph.remove_edge(src, dst)
                graph.add_edge(src, forget_dst, **data)
            graph.add_edge(forget_src, forget_dst, **data)

    @staticmethod
    def _to_typevar_or_typeconst(
        obj: Union[TypeVariable, DerivedTypeVariable, TypeConstant]
    ) -> Union[TypeVariable, TypeConstant]:
        if isinstance(obj, DerivedTypeVariable):
            return SimpleSolver._to_typevar_or_typeconst(obj.type_var)
        elif isinstance(obj, TypeVariable):
            return obj
        elif isinstance(obj, TypeConstant):
            return obj
        raise TypeError(f"Unsupported type {type(obj)}")

    #
    # Graph solver
    #

    @staticmethod
    def _typevar_inside_set(typevar, typevar_set: Set[Union[TypeConstant, TypeVariable, DerivedTypeVariable]]) -> bool:
        if typevar in typevar_set:
            return True
        if isinstance(typevar, Struct) and Struct_ in typevar_set:
            if not typevar.fields:
                return True
            return all(
                SimpleSolver._typevar_inside_set(field_typevar, typevar_set)
                for field_typevar in typevar.fields.values()
            )
        if isinstance(typevar, Array) and Array_ in typevar_set:
            return SimpleSolver._typevar_inside_set(typevar.element, typevar_set)
        if isinstance(typevar, Pointer) and (Pointer32_ in typevar_set or Pointer64_ in typevar_set):
            return SimpleSolver._typevar_inside_set(typevar.basetype, typevar_set)
        return False

    def _solve_constraints_between(
        self,
        graph: networkx.DiGraph,
        starts: Set[Union[TypeConstant, TypeVariable, DerivedTypeVariable]],
        ends: Set[Union[TypeConstant, TypeVariable, DerivedTypeVariable]],
    ) -> Set[TypeConstraint]:
        start_nodes = set()
        end_nodes = set()
        for node in graph.nodes:
            node: ConstraintGraphNode
            if (
                self._typevar_inside_set(self._to_typevar_or_typeconst(node.typevar), starts)
                and node.tag == ConstraintGraphTag.LEFT
            ):
                start_nodes.add(node)
            if (
                self._typevar_inside_set(self._to_typevar_or_typeconst(node.typevar), ends)
                and node.tag == ConstraintGraphTag.RIGHT
            ):
                end_nodes.add(node)

        if not start_nodes or not end_nodes:
            return set()

        dfa_solver = DFAConstraintSolver()
        try:
            return dfa_solver.generate_constraints_between(graph, start_nodes, end_nodes)
        except EmptyEpsilonNFAError:
            return set()

    #
    # Type lattice
    #

    def join(self, t1: Union[TypeConstant, TypeVariable], t2: Union[TypeConstant, TypeVariable]) -> TypeConstant:
        abstract_t1 = self.abstract(t1)
        abstract_t2 = self.abstract(t2)
        if abstract_t1 in self._base_lattice and abstract_t2 in self._base_lattice:
            ancestor = networkx.lowest_common_ancestor(self._base_lattice, abstract_t1, abstract_t2)
            if ancestor == abstract_t1:
                return t1
            elif ancestor == abstract_t2:
                return t2
            else:
                return ancestor
        if t1 == Bottom_:
            return t2
        if t2 == Bottom_:
            return t1
        return Bottom_

    def meet(self, t1: Union[TypeConstant, TypeVariable], t2: Union[TypeConstant, TypeVariable]) -> TypeConstant:
        abstract_t1 = self.abstract(t1)
        abstract_t2 = self.abstract(t2)
        if abstract_t1 in self._base_lattice_inverted and abstract_t2 in self._base_lattice_inverted:
            ancestor = networkx.lowest_common_ancestor(self._base_lattice_inverted, abstract_t1, abstract_t2)
            if ancestor == abstract_t1:
                return t1
            elif ancestor == abstract_t2:
                return t2
            else:
                return ancestor
        if t1 == Top_:
            return t2
        if t2 == Top_:
            return t1
        return Top_
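
    # Illustrative examples (not part of the module), assuming a 64-bit solver instance `solver`:
    #
    #     solver.join(Int32_, Int64_)    # -> Int_: their lowest common ancestor in BASE_LATTICE_64
    #     solver.meet(Int_, Pointer64_)  # -> Pointer64_: the lowest common ancestor in the inverted lattice
    #     solver.join(Bottom_, Struct_)  # -> Struct_: Struct_ is not a lattice node, and Bottom_ is the
    #                                    #    identity element for join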

    @staticmethod
    def abstract(t: Union[TypeConstant, TypeVariable]) -> Union[TypeConstant, TypeVariable]:
        if isinstance(t, Pointer32):
            return Pointer32()
        elif isinstance(t, Pointer64):
            return Pointer64()
        return t

    def determine(
        self,
        equivalent_classes: Dict[TypeVariable, TypeVariable],
        sketches,
        solution: Dict,
        nodes: Optional[Set[SketchNode]] = None,
    ) -> None:
        """
        Determine C-like types from sketches.

        :param equivalent_classes: A dictionary mapping each type variable to its representative in the equivalence
                                   class over ~.
        :param sketches:           A dictionary storing sketches for each type variable.
        :param solution:           The dictionary storing C-like types for each type variable. Output.
        :param nodes:              Optional. Nodes that should be considered in the sketch.
        :return:                   None
        """
        for typevar, sketch in sketches.items():
            self._determine(equivalent_classes, typevar, sketch, solution, nodes=nodes)

        for v, e in self._equivalence.items():
            if v not in solution and e in solution:
                solution[v] = solution[e]

    def _determine(
        self, equivalent_classes, the_typevar, sketch, solution: Dict, nodes: Optional[Set[SketchNode]] = None
    ):
        """
        Return the solution from sketches
        """
        if not nodes:
            # TODO: resolve references
            node = sketch.lookup(the_typevar)
            assert node is not None
            nodes = {node}

        # consult the cache
        cached_results = set()
        for node in nodes:
            if node.typevar in self._solution_cache:
                cached_results.add(self._solution_cache[node.typevar])
        if len(cached_results) == 1:
            return next(iter(cached_results))
        elif len(cached_results) > 1:
            # we get nodes for multiple type variables?
            raise RuntimeError("Getting nodes for multiple type variables. Unexpected.")

        # collect all successors and the paths (labels) of this type variable
        path_and_successors = []
        last_labels = []
        for node in nodes:
            path_and_successors += self._collect_sketch_paths(node, sketch)
        for labels, _ in path_and_successors:
            if labels:
                last_labels.append(labels[-1])

        # now, what is this variable?
        result = None

        if last_labels and all(isinstance(label, (FuncIn, FuncOut)) for label in last_labels):
            # create a dummy result and dump it to the cache
            func_type = Function([], [])
            result = self._pointer_class()(basetype=func_type)
            for node in nodes:
                self._solution_cache[node.typevar] = result

            # this is a function variable
            func_inputs = defaultdict(set)
            func_outputs = defaultdict(set)
            for labels, succ in path_and_successors:
                last_label = labels[-1] if labels else None
                if isinstance(last_label, FuncIn):
                    func_inputs[last_label.loc].add(succ)
                elif isinstance(last_label, FuncOut):
                    func_outputs[last_label.loc].add(succ)
                else:
                    raise RuntimeError("Unreachable")

            input_args = []
            output_values = []
            for vals, out in [(func_inputs, input_args), (func_outputs, output_values)]:
                for idx in range(0, max(vals) + 1):
                    if idx in vals:
                        sol = self._determine(equivalent_classes, the_typevar, sketch, solution, nodes=vals[idx])
                        out.append(sol)
                    else:
                        out.append(None)

            # back patch
            func_type.params = input_args
            func_type.outputs = output_values

            for node in nodes:
                solution[node.typevar] = result

        elif path_and_successors:
            # maybe this is a pointer to a struct?
            if len(nodes) == 1:
                the_node = next(iter(nodes))
                if (
                    isinstance(the_node.upper_bound, self._pointer_class())
                    and isinstance(the_node.upper_bound.basetype, Struct)
                    and the_node.upper_bound.basetype.name
                ):
                    # handle pointers to known struct types
                    result = (
                        the_node.lower_bound
                        if not isinstance(the_node.lower_bound, BottomType)
                        else the_node.upper_bound
                    )
                    for node in nodes:
                        solution[node.typevar] = result
                        self._solution_cache[node.typevar] = result
                    return result

            # create a dummy result and shove it into the cache
            struct_type = Struct(fields={})
            result = self._pointer_class()(struct_type)
            for node in nodes:
                self._solution_cache[node.typevar] = result

            # this might be a struct
            fields = {}
            candidate_bases = defaultdict(set)
            for labels, succ in path_and_successors:
                last_label = labels[-1] if labels else None
                if isinstance(last_label, HasField):
                    candidate_bases[last_label.offset].add(last_label.bits // 8)

            node_to_base = {}
            for labels, succ in path_and_successors:
                last_label = labels[-1] if labels else None
                if isinstance(last_label, HasField):
                    for start_offset, sizes in candidate_bases.items():
                        for size in sizes:
                            if last_label.offset > start_offset:
                                if last_label.offset < start_offset + size:
                                    # ???
                                    node_to_base[succ] = start_offset

            node_by_offset = defaultdict(set)
            for labels, succ in path_and_successors:
                last_label = labels[-1] if labels else None
                if isinstance(last_label, HasField):
                    if succ in node_to_base:
                        node_by_offset[node_to_base[succ]].add(succ)
                    else:
                        node_by_offset[last_label.offset].add(succ)

            for offset, child_nodes in node_by_offset.items():
                sol = self._determine(equivalent_classes, the_typevar, sketch, solution, nodes=child_nodes)
                if isinstance(sol, TopType):
                    sol = int_type(min(candidate_bases[offset]) * 8)
                fields[offset] = sol

            if not fields:
                result = Top_
                for node in nodes:
                    self._solution_cache[node.typevar] = result
            else:
                # back-patch
                struct_type.fields = fields
                for node in nodes:
                    solution[node.typevar] = result

        if not path_and_successors or result in {Top_, None}:
            # this is probably a primitive variable
            lower_bound = Bottom_
            upper_bound = Top_
            for node in nodes:
                lower_bound = self.join(lower_bound, node.lower_bound)
                upper_bound = self.meet(upper_bound, node.upper_bound)
            # TODO: Support variables that are accessed via differently sized pointers
            result = lower_bound if not isinstance(lower_bound, BottomType) else upper_bound

            for node in nodes:
                solution[node.typevar] = result
                self._solution_cache[node.typevar] = result

        # import pprint
        # print("Solution")
        # pprint.pprint(result)

        return result

    @staticmethod
    def _collect_sketch_paths(node: SketchNodeBase, sketch: Sketch) -> List[Tuple[List[BaseLabel], SketchNodeBase]]:
        """
        Collect all paths that go from `typevar` to its leaves.
        """
        paths = []
        visited: Set[SketchNodeBase] = set()
        queue: List[Tuple[List[BaseLabel], SketchNodeBase]] = [([], node)]
        while queue:
            curr_labels, curr_node = queue.pop(0)
            if curr_node in visited:
                continue
            visited.add(curr_node)

            out_edges = sketch.graph.out_edges(curr_node, data=True)
            for _, succ, data in out_edges:
                if isinstance(succ, RecursiveRefNode):
                    ref = succ
                    succ: Optional[SketchNode] = sketch.lookup(succ.target)
                    if succ is None:
                        # failed to resolve...
                        _l.warning(
                            "Failed to resolve reference node to a real sketch node for type variable %s", ref.target
                        )
                        continue
                label = data["label"]
                if isinstance(label, ConvertTo):
                    # drop conv labels for now
                    continue
                if isinstance(label, IsArray):
                    continue
                new_labels = curr_labels + [label]
                succ: SketchNode
                if isinstance(succ.typevar, DerivedTypeVariable) and isinstance(
                    succ.typevar.labels[-1], (Load, Store)
                ):
                    queue.append((new_labels, succ))
                else:
                    paths.append((new_labels, succ))

        return paths

    def _pointer_class(self) -> Union[Type[Pointer32], Type[Pointer64]]:
        if self.bits == 32:
            return Pointer32
        elif self.bits == 64:
            return Pointer64
        raise NotImplementedError("Unsupported bits %d" % self.bits)