Source code for angr.analyses.typehoon.simple_solver

# pylint:disable=missing-class-docstring
import itertools
from collections import defaultdict, deque
from typing import Union, Type, Callable

import networkx

from .typevars import (
    Existence,
    Equivalence,
    Subtype,
    TypeVariable,
    DerivedTypeVariable,
    HasField,
    Add,
    ConvertTo,
    IsArray,
)
from .typeconsts import (
    BottomType,
    TopType,
    TypeConstant,
    Int,
    Int8,
    Int16,
    Int32,
    Int64,
    Pointer,
    Pointer32,
    Pointer64,
    Struct,
    int_type,
    TypeVariableReference,
)

# Base type lattices, one per supported pointer width. Every edge points from a type to a type directly below it:
# TopType only has outgoing edges and BottomType only has incoming ones. The pointer type hangs below the integer type
# of the same width (Int64 -> Pointer64 on 64-bit, Int32 -> Pointer32 on 32-bit).
# The order of the edges is kept stable on purpose, since it determines the order in which networkx reports
# predecessors and successors.

# lattice for 64-bit binaries
BASE_LATTICE_64 = networkx.DiGraph()
BASE_LATTICE_64.add_edges_from(
    [
        (TopType, Int),
        (Int, Int64),
        (Int, Int32),
        (Int, Int16),
        (Int, Int8),
        (Int32, BottomType),
        (Int16, BottomType),
        (Int8, BottomType),
        (Int64, Pointer64),
        (Pointer64, BottomType),
    ]
)

# lattice for 32-bit binaries
BASE_LATTICE_32 = networkx.DiGraph()
BASE_LATTICE_32.add_edges_from(
    [
        (TopType, Int),
        (Int, Int64),
        (Int, Int32),
        (Int, Int16),
        (Int, Int8),
        (Int32, Pointer32),
        (Int64, BottomType),
        (Pointer32, BottomType),
        (Int16, BottomType),
        (Int8, BottomType),
    ]
)

# pointer width in bits -> lattice
BASE_LATTICES = {
    32: BASE_LATTICE_32,
    64: BASE_LATTICE_64,
}


class RecursiveType:
    """
    A plain record that pairs a type variable with an offset.

    NOTE(review): presumably describes a type that refers back to itself at `offset`; the class is not used anywhere
    in the code visible here, so confirm against its callers.
    """

    def __init__(self, typevar, offset):
        """
        :param typevar: The type variable to store.
        :param offset:  The offset to store alongside the type variable.
        """
        self.typevar = typevar
        self.offset = offset
class SimpleSolver:
    """
    SimpleSolver is, literally, a simple, unification-based type constraint solver.

    Creating an instance immediately runs :meth:`solve` followed by :meth:`determine`; the resulting mapping is stored
    in ``self.solution``.
    """

    def __init__(self, bits: int, constraints):
        """
        :param bits:        Pointer size in bits. Only 32 and 64 are supported.
        :param constraints: The type constraints to solve. Must support ``|`` with a set (e.g., be a ``set``).
        :raises ValueError: If `bits` is neither 32 nor 64.
        """
        if bits not in (32, 64):
            # %r (not %d) so that a non-integer argument still ends up in this ValueError instead of triggering a
            # TypeError while the message is being formatted. For integers the message is unchanged.
            raise ValueError("Pointer size %r is not supported. Expect 32 or 64." % (bits,))

        self.bits = bits
        self._constraints = constraints
        self._base_lattice = BASE_LATTICES[bits]

        #
        # Solving state
        #
        self._equivalence = {}
        self._lower_bounds = defaultdict(BottomType)  # missing entries default to Bottom
        self._upper_bounds = defaultdict(TopType)  # missing entries default to Top
        self._recursive_types = defaultdict(set)

        self.solve()
        self.solution = self.determine()

    def solve(self):
        """
        Run all solving passes in order and populate the lower/upper bounds of type variables.
        """
        eq_constraints = self._eq_constraints_from_add()
        # build a new set instead of using |=, so the set object handed in by the caller is left untouched
        self._constraints = self._constraints | eq_constraints
        constraints = self._handle_equivalence()
        subtypevars, supertypevars = self._calculate_closure(constraints)
        self._find_recursive_types(subtypevars)
        self._compute_lower_upper_bounds(subtypevars, supertypevars)
        self._lower_struct_fields()
        self._convert_arrays(constraints)

    def determine(self):
        """
        Pick one type for each type variable based on the computed bounds.

        :return: A dict mapping type variables to their determined types. Variables that were unified away map to
                 whatever their replacement maps to, or None if the replacement has no entry.
        """
        solution = {}

        # non-derived type variables: prefer the lower bound, and fall back to the upper bound if it is Bottom
        for v, lb in self._lower_bounds.items():
            if isinstance(v, TypeVariable) and not isinstance(v, DerivedTypeVariable):
                solution[v] = self._upper_bounds[v] if isinstance(lb, BottomType) else lb

        # everything else that has an upper bound other than Top
        for v, ub in self._upper_bounds.items():
            if v not in solution and not isinstance(ub, TopType):
                solution[v] = ub

        # variables that have been replaced during unification take the type of their replacement
        for v, e in self._equivalence.items():
            if v not in solution:
                solution[v] = solution.get(e)

        return solution

    def _handle_equivalence(self):
        """
        Apply unification to Equivalence constraints.

        A type variable that equals a type constant is replaced by that constant. Other equivalences are grouped into
        connected components, and every member of a component is replaced by the member whose string form sorts
        first. The chosen replacements are stored in `self._equivalence`.

        :return: A new set with all Existence and Subtype constraints after replacement. Constraints of other kinds
                 are left out.
        """
        graph = networkx.Graph()
        replacements = {}

        # collect equivalence relations
        for constraint in self._constraints:
            if not isinstance(constraint, Equivalence):
                continue
            # | type_a == type_b
            # we apply unification and remove one of them
            ta, tb = constraint.type_a, constraint.type_b
            if isinstance(ta, TypeConstant) and isinstance(tb, TypeVariable):
                # replace tb with ta
                replacements[tb] = ta
            elif isinstance(ta, TypeVariable) and isinstance(tb, TypeConstant):
                # replace ta with tb
                replacements[ta] = tb
            else:
                # every remaining combination goes here. we will determine a representative later
                graph.add_edge(ta, tb)

        for component in networkx.connected_components(graph):
            # sort by string form to get a stable choice of representative
            representative, *others = sorted(component, key=str)
            for tv in others:
                replacements[tv] = representative

        # replace type variables in Existence and Subtype (subtype <: supertype) constraints
        constraints = set()
        for constraint in self._constraints:
            if isinstance(constraint, (Existence, Subtype)):
                replaced, new_constraint = constraint.replace(replacements)
                constraints.add(new_constraint if replaced else constraint)

        self._equivalence = replacements
        return constraints

    def _eq_constraints_from_add(self):
        """
        Handle Add constraints.

        For each Add constraint whose result is a non-derived type variable, every operand that is also a non-derived
        type variable is made equivalent to the result.

        :return: A set of new Equivalence constraints.
        """

        def _is_plain_typevar(t):
            return isinstance(t, TypeVariable) and not isinstance(t, DerivedTypeVariable)

        new_constraints = set()
        for constraint in self._constraints:
            if not isinstance(constraint, Add) or not _is_plain_typevar(constraint.type_r):
                continue
            for operand in (constraint.type_0, constraint.type_1):
                if _is_plain_typevar(operand):
                    new_constraints.add(Equivalence(operand, constraint.type_r))

        return new_constraints

    # the annotation is a string so that it is not evaluated when the method is defined
    def _pointer_class(self) -> "Union[Type[Pointer32], Type[Pointer64]]":
        """
        Get the pointer type class that matches `self.bits`.

        :raises NotImplementedError: If `self.bits` is neither 32 nor 64.
        """
        if self.bits == 32:
            return Pointer32
        elif self.bits == 64:
            return Pointer64
        raise NotImplementedError("Unsupported bits %d" % self.bits)

    def _calculate_closure(self, constraints):
        """
        Work through Existence and Subtype constraints until no constraint is left, collecting for each type variable
        the types below it and above it.

        :param constraints: The constraints to process. They are copied first; the argument is not modified.
        :return:            A tuple of two defaultdicts: (subtypevars, supertypevars).
        :raises Exception:  If an Equivalence constraint is encountered.
        :raises NotImplementedError: If a constraint of any other kind is encountered.
        """
        ptr_class = self._pointer_class()

        # a mapping from type variables to all the variables which are {super,sub}types of them
        subtypevars = defaultdict(set)  # {k: {v}}: v <: k
        supertypevars = defaultdict(set)  # {k: {v}}: k <: v

        constraints = set(constraints)  # make a copy

        while constraints:
            constraint = constraints.pop()

            if isinstance(constraint, Existence):
                # has a derived type
                tv = constraint.type_
                if isinstance(tv, DerivedTypeVariable) and isinstance(tv.label, HasField):
                    # the original variable is a pointer
                    v = tv.type_var.type_var
                    if isinstance(v, TypeVariable):
                        subtypevars[v].add(ptr_class(Struct(fields={tv.label.offset: int_type(tv.label.bits)})))

            elif isinstance(constraint, Subtype):
                # subtype <: supertype
                subtype, supertype = constraint.sub_type, constraint.super_type

                if isinstance(supertype, TypeVariable):
                    if subtype not in subtypevars[supertype] and supertype is not subtype:
                        subtypevars[supertype].add(subtype)
                        for s in supertypevars[subtype]:
                            # re-add impacted constraints
                            constraints.add(Subtype(subtype, s))
                    if subtype in subtypevars:
                        for v in subtypevars[subtype]:
                            if v not in subtypevars[supertype] and supertype is not v:
                                subtypevars[supertype].add(v)
                                for sup in supertypevars[v]:
                                    constraints.add(Subtype(subtype, sup))

                if isinstance(subtype, TypeVariable):
                    if supertype not in supertypevars[subtype] and subtype is not supertype:
                        supertypevars[subtype].add(supertype)
                        for s in subtypevars[supertype]:
                            # re-add impacted constraints
                            constraints.add(Subtype(s, supertype))
                    if supertype in supertypevars:
                        for v in supertypevars[supertype]:
                            if v not in supertypevars[subtype] and v is not subtype:
                                supertypevars[subtype].add(v)
                                for sup in supertypevars[v]:
                                    constraints.add(Subtype(subtype, sup))

            elif isinstance(constraint, Equivalence):
                raise Exception("Shouldn't exist anymore.")
            else:
                raise NotImplementedError("Unsupported instance type %s." % type(constraint))

        return subtypevars, supertypevars

    def _find_recursive_types(self, subtypevars):
        """
        Find field-access type variables (`x.<field>`) that have `x` itself as a subtype. For each of them, a pointer
        to a struct that references `x` at the field offset is added to the subtypes of `x`, and the offset is
        recorded in `self._recursive_types[x]`.

        :param subtypevars: The subtype mapping from _calculate_closure(). Updated in place.
        """
        ptr_class = self._pointer_class()

        # iterate over snapshots since subtypevars is updated inside the loop
        for var in list(subtypevars.keys()):
            sts = subtypevars[var].copy()
            if not (isinstance(var, DerivedTypeVariable) and isinstance(var.label, HasField)):
                continue
            for subtype_var in sts:
                if var.type_var.type_var == subtype_var:
                    subtypevars[subtype_var].add(
                        ptr_class(Struct({var.label.offset: TypeVariableReference(subtype_var)}))
                    )
                    self._recursive_types[subtype_var].add(var.label.offset)

    def _get_lower_bound(self, v):
        """
        Get the lower bound of `v`. A type constant is its own bound; anything else is looked up in
        `self._lower_bounds`, which creates a Bottom entry if there is none.
        """
        if isinstance(v, TypeConstant):
            return v
        return self._lower_bounds[v]

    def _get_upper_bound(self, v):
        """
        Get the upper bound of `v`. A type constant is its own bound. For a derived type variable without a recorded
        bound, the bound is derived from the bit width in its ConvertTo or HasField label when int_type() knows that
        width. Otherwise a Top entry is created.
        """
        if isinstance(v, TypeConstant):
            return v

        if v in self._upper_bounds:
            return self._upper_bounds[v]

        # try to compute it
        if isinstance(v, DerivedTypeVariable):
            ub = None
            if isinstance(v.label, ConvertTo):
                # after integer conversion
                ub = int_type(v.label.to_bits)
            elif isinstance(v.label, HasField):
                ub = int_type(v.label.bits)
            if ub is not None:
                self._upper_bounds[v] = ub

        # if all that failed, let the defaultdict generate a Top
        return self._upper_bounds[v]

    def _compute_lower_upper_bounds(self, subtypevars, supertypevars):
        """
        Fill `self._upper_bounds` and `self._lower_bounds`.

        :param subtypevars:   The subtype mapping from _calculate_closure().
        :param supertypevars: The supertype mapping from _calculate_closure().
        """
        # the upper bound of a type variable is the meet of itself and everything above it
        for typevar, upper_bounds in supertypevars.items():
            if typevar is None or isinstance(typevar, TypeConstant):
                continue
            self._upper_bounds[typevar] = self._meet(typevar, *upper_bounds, translate=self._get_upper_bound)

        # the lower bound of a type variable is the join of itself and everything below it
        seen = set()  # loop avoidance
        queue = deque(subtypevars)
        while queue:
            typevar = queue.popleft()
            lower_bounds = subtypevars[typevar]

            if typevar not in seen:
                # we detect if it depends on any other typevar upon the first encounter
                seen.add(typevar)
                if any(isinstance(sv, TypeVariable) and sv not in self._lower_bounds for sv in lower_bounds):
                    # oops - we should analyze the subtypevar first. revisit this one later
                    queue.append(typevar)
                    continue
            # upon the second encounter, we avoid loops and continue no matter what

            self._lower_bounds[typevar] = self._join(typevar, *lower_bounds, translate=self._get_lower_bound)

            # because of T-InheritR, fields are propagated *both ways* in a subtype relation
            for subtypevar in lower_bounds:
                if not isinstance(subtypevar, TypeVariable):
                    continue
                subtype_infimum = self._lower_bounds[subtypevar]
                if isinstance(subtype_infimum, Pointer) and isinstance(subtype_infimum.basetype, Struct):
                    self._lower_bounds[subtypevar] = self._join(
                        subtypevar, typevar, translate=self._get_lower_bound
                    )

    def _lower_struct_fields(self):
        """
        Push the lower bounds of field-access type variables into the struct of their base variable.

        NOTE(review): the field lookup assumes that the struct of the base variable has a field at the offset of the
        label and raises KeyError otherwise - confirm that this always holds.
        """
        # tv_680: ptr32(struct{0: int32})
        # tv_680.load.<32>@0: ptr32(struct{5: int8})
        #    becomes
        # tv_680: ptr32(struct{0: ptr32(struct{5: int8})})

        # only values of existing keys are re-assigned below, which is safe while iterating
        for outer, outer_lb in self._lower_bounds.items():
            if not (isinstance(outer, DerivedTypeVariable) and isinstance(outer.label, HasField)):
                continue
            if isinstance(outer_lb, BottomType):
                continue

            # unpack v
            base = outer.type_var.type_var
            if base not in self._lower_bounds:
                continue
            base_lb = self._lower_bounds[base]

            # make sure it's a pointer to a struct
            if not (isinstance(base_lb, Pointer) and isinstance(base_lb.basetype, Struct)):
                continue

            offset = outer.label.offset
            the_field = base_lb.basetype.fields[offset]
            # replace this field
            new_field = self._meet(the_field, outer_lb, translate=self._get_upper_bound)
            if new_field != the_field:
                new_fields = base_lb.basetype.fields.copy()
                new_fields[offset] = new_field
                base_lb = base_lb.__class__(Struct(new_fields))
                self._lower_bounds[base] = base_lb

            # another attempt: if a pointer to a struct has only one field, remove the struct
            if len(base_lb.basetype.fields) == 1 and 0 in base_lb.basetype.fields:
                base_lb = base_lb.__class__(base_lb.basetype.fields[0])
                self._lower_bounds[base] = base_lb

    def _convert_arrays(self, constraints):
        """
        For every type variable that is marked as an array by an Existence constraint and whose lower bound is a
        pointer to a struct with a field at offset 0, overwrite all fields of that struct with the first field. The
        struct is modified in place.

        :param constraints: The constraints to scan for IsArray labels.
        """
        for constraint in constraints:
            if not isinstance(constraint, Existence):
                continue
            inner = constraint.type_
            if not (isinstance(inner, DerivedTypeVariable) and isinstance(inner.label, IsArray)):
                continue
            if inner.type_var not in self._lower_bounds:
                continue
            curr_type = self._lower_bounds[inner.type_var]
            if not (isinstance(curr_type, Pointer) and isinstance(curr_type.basetype, Struct)):
                continue

            # replace all fields with the first field
            fields = curr_type.basetype.fields
            if 0 in fields:
                first_field = fields[0]
                for offset in fields:
                    fields[offset] = first_field

    def _abstract(self, t):  # pylint:disable=no-self-use
        """
        Get the class of `t`, which is what the nodes of the base lattice are.
        """
        return t.__class__

    def _concretize(self, n_cls, t1, t2, join_or_meet, translate):
        """
        Create an instance of the lattice node `n_cls` for the result of joining or meeting `t1` and `t2`.

        :param n_cls:        The lattice node (a class) that was found.
        :param t1:           The first type.
        :param t2:           The second type.
        :param join_or_meet: self._join or self._meet; used for merging the base types of two pointers.
        :param translate:    The translate callable to pass on to `join_or_meet`.
        :return:             An instance of `n_cls`.
        """
        ptr_class = self._pointer_class()

        if n_cls is not ptr_class:
            return n_cls()

        t1_is_ptr = isinstance(t1, ptr_class)
        t2_is_ptr = isinstance(t2, ptr_class)
        if t1_is_ptr and t2_is_ptr:
            # we need to merge them
            return ptr_class(join_or_meet(t1.basetype, t2.basetype, translate=translate))
        if t1_is_ptr:
            return t1
        if t2_is_ptr:
            return t2
        # huh? neither of them is a pointer
        return ptr_class(BottomType())

    @staticmethod
    def _merge_structs(t1, t2, join_or_meet, translate):
        """
        Merge two structs field by field. A field that exists in both structs is merged using `join_or_meet`; a field
        that exists in only one of them is taken over as is.
        """
        fields = {}
        for offset in sorted(set(t1.fields).union(t2.fields)):
            if offset in t1.fields and offset in t2.fields:
                fields[offset] = join_or_meet(t1.fields[offset], t2.fields[offset], translate=translate)
            elif offset in t1.fields:
                fields[offset] = t1.fields[offset]
            else:
                fields[offset] = t2.fields[offset]
        return Struct(fields=fields)

    def _join(self, *args, translate: Callable):
        """
        Get the least upper bound (V, maximum) of the arguments.

        :param args:      Types or type variables. More than two arguments are split in halves and joined recursively.
        :param translate: A callable that maps each argument to the type to work on.
        :return:          The joined type. Bottom if there are no arguments, and Top if no rule applies.
        """
        if len(args) == 0:
            return BottomType()
        if len(args) == 1:
            return translate(args[0])
        if len(args) > 2:
            split = len(args) // 2
            first = self._join(*args[:split], translate=translate)
            second = self._join(*args[split:], translate=translate)
            return self._join(first, second, translate=translate)

        t1 = translate(args[0])
        t2 = translate(args[1])

        # Trivial cases
        if t1 == t2:
            return t1
        if isinstance(t1, TopType):
            return t1
        if isinstance(t2, TopType):
            return t2
        if isinstance(t1, BottomType):
            return t2
        if isinstance(t2, BottomType):
            return t1
        if isinstance(t1, TypeVariableReference) and not isinstance(t2, TypeVariableReference):
            return t1
        if isinstance(t2, TypeVariableReference) and not isinstance(t1, TypeVariableReference):
            return t2

        # consult the graph
        t1_cls = self._abstract(t1)
        t2_cls = self._abstract(t2)

        if t1_cls in self._base_lattice and t2_cls in self._base_lattice:
            # breadth-first search from t1 for the first node that t2 is reachable from
            queue = deque([t1_cls])
            while queue:
                n = queue.popleft()
                if networkx.has_path(self._base_lattice, n, t2_cls):
                    return self._concretize(n, t1, t2, self._join, translate)
                # go up
                queue.extend(self._base_lattice.predecessors(n))

        # handling Struct
        if t1_cls is Struct and t2_cls is Struct:
            return self._merge_structs(t1, t2, self._join, translate)

        # single element and single-element struct
        if issubclass(t2_cls, Int) and t1_cls is Struct:
            # swap them
            t1, t1_cls, t2, t2_cls = t2, t2_cls, t1, t1_cls
        if issubclass(t1_cls, Int) and t2_cls is Struct and len(t2.fields) == 1 and 0 in t2.fields:
            # e.g., char & struct {0: char}
            return Struct(fields={0: self._join(t1, t2.fields[0], translate=translate)})

        ptr_class = self._pointer_class()

        # Struct and Pointers
        if t1_cls is ptr_class and t2_cls is Struct:
            # swap them
            t1, t1_cls, t2, t2_cls = t2, t2_cls, t1, t1_cls
        if t1_cls is Struct and len(t1.fields) == 1 and 0 in t1.fields:
            if t1.fields[0].size == 8 and t2_cls is Pointer64:
                # they are equivalent
                # e.g., struct{0: int64}  ptr64(int8)
                # return t2 since t2 is more specific
                return t2
            elif t1.fields[0].size == 4 and t2_cls is Pointer32:
                return t2

        return TopType()

    def _meet(self, *args, translate: Callable):
        """
        Get the greatest lower bound (^, minimum) of the arguments.

        :param args:      Types or type variables. More than two arguments are split in halves and met recursively.
        :param translate: A callable that maps each argument to the type to work on.
        :return:          The met type. Top if there are no arguments, and Bottom if no rule applies.
        """
        if len(args) == 0:
            return TopType()
        if len(args) == 1:
            return translate(args[0])
        if len(args) > 2:
            split = len(args) // 2
            first = self._meet(*args[:split], translate=translate)
            second = self._meet(*args[split:], translate=translate)
            return self._meet(first, second, translate=translate)

        t1 = translate(args[0])
        t2 = translate(args[1])

        # Trivial cases
        if t1 == t2:
            return t1
        if isinstance(t1, BottomType):
            return t1
        if isinstance(t2, BottomType):
            return t2
        if isinstance(t1, TopType):
            return t2
        if isinstance(t2, TopType):
            return t1
        if isinstance(t1, TypeVariableReference) and not isinstance(t2, TypeVariableReference):
            return t1
        if isinstance(t2, TypeVariableReference) and not isinstance(t1, TypeVariableReference):
            return t2

        # consult the graph
        t1_cls = self._abstract(t1)
        t2_cls = self._abstract(t2)

        if t1_cls in self._base_lattice and t2_cls in self._base_lattice:
            # breadth-first search from t1 for the first node that is reachable from t2
            queue = deque([t1_cls])
            while queue:
                n = queue.popleft()
                if networkx.has_path(self._base_lattice, t2_cls, n):
                    return self._concretize(n, t1, t2, self._meet, translate)
                # go down
                queue.extend(self._base_lattice.successors(n))

        # handling Struct
        if t1_cls is Struct and t2_cls is Struct:
            return self._merge_structs(t1, t2, self._meet, translate)

        # single element and single-element struct
        if issubclass(t2_cls, Int) and t1_cls is Struct:
            # swap them
            t1, t1_cls, t2, t2_cls = t2, t2_cls, t1, t1_cls
        if issubclass(t1_cls, Int) and t2_cls is Struct and len(t2.fields) == 1 and 0 in t2.fields:
            # e.g., char & struct {0: char}
            return Struct(fields={0: self._meet(t1, t2.fields[0], translate=translate)})

        ptr_class = self._pointer_class()

        # Struct and Pointers
        if t1_cls is ptr_class and t2_cls is Struct:
            # swap them
            t1, t1_cls, t2, t2_cls = t2, t2_cls, t1, t1_cls
        if t1_cls is Struct and len(t1.fields) == 1 and 0 in t1.fields:
            if t1.fields[0].size == 8 and t2_cls is Pointer64:
                # they are equivalent
                # e.g., struct{0: int64}  ptr64(int8)
                # return t2 since t2 is more specific
                return t2
            elif t1.fields[0].size == 4 and t2_cls is Pointer32:
                return t2

        return BottomType()