Source code for claripy.frontend_mixins.model_cache_mixin

from typing import Tuple
import weakref
import itertools

from .. import errors


class ModelCache:
    _defaults = {0, 0.0, True}

    def __init__(self, model):
        self.model = model
        self.replacements = weakref.WeakKeyDictionary()
        self.constraint_only_replacements = weakref.WeakKeyDictionary()

    def __hash__(self):
        if not hasattr(self, "_hash"):
            self._hash = hash(frozenset(self.model.items()))  # pylint:disable=attribute-defined-outside-init
        return self._hash

    def __eq__(self, other):
        return self.model == other.model

    def __getstate__(self):
        return (self.model,)

    def __setstate__(self, s):
        self.model = s[0]
        self.replacements = weakref.WeakKeyDictionary()
        self.constraint_only_replacements = weakref.WeakKeyDictionary()

    #
    # Splitting support
    #

    def filter(self, variables):
        return ModelCache({k: self.model[k] for k in self.model if k in variables})

    @staticmethod
    def combine(*models):
        return ModelCache(dict(itertools.chain.from_iterable(m.model.items() for m in models)))

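    # Illustrative sketch (not part of the upstream module): filter() projects a
    # cached model onto a subset of variables, and combine() merges such models
    # back together; this is what the split()/combine() machinery in
    # ModelCacheMixin below relies on. The variable names are made up for the example.
    #
    #     >>> m = ModelCache({"x": 1, "y": 2})
    #     >>> m.filter({"x"}).model
    #     {'x': 1}
    #     >>> ModelCache.combine(m.filter({"x"}), m.filter({"y"})).model == m.model
    #     True
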
    #
    # Model-driven evaluation
    #

    def _leaf_op(self, a):
        return (
            all_operations.BVV(self.model.get(a.args[0], 0), a.length)
            if a.op == "BVS"
            else (
                all_operations.BoolV(self.model.get(a.args[0], True))
                if a.op == "BoolS"
                else (
                    all_operations.FPV(self.model.get(a.args[0], 0.0), a.args[1])
                    if a.op == "FPS"
                    else all_operations.StringV(self.model.get(a.args[0], ""))
                    if a.op == "StringS"
                    else a
                )
            )
        )

    def _leaf_op_existonly(self, a):
        return (
            all_operations.BVV(self.model[a.args[0]], a.length)
            if a.op == "BVS"
            else (
                all_operations.BoolV(self.model[a.args[0]])
                if a.op == "BoolS"
                else (
                    all_operations.FPV(self.model[a.args[0]], a.args[1])
                    if a.op == "FPS"
                    else all_operations.StringV(self.model[a.args[0]])
                    if a.op == "StringS"
                    else a
                )
            )
        )

    def eval_ast(self, ast, allow_unconstrained: bool = True):
        """
        Evaluate an AST, replacing symbols with their values from the model.

        :param ast:                 The AST to evaluate.
        :param allow_unconstrained: When set to True, treat non-existent variables as unconstrained variables and
                                    use arbitrary concrete values for them during evaluation. Otherwise, raise
                                    KeyError for non-existent variables.
        """
        if allow_unconstrained:
            new_ast = ast.replace_dict(self.replacements, leaf_operation=self._leaf_op)
        else:
            new_ast = ast.replace_dict(self.constraint_only_replacements, leaf_operation=self._leaf_op_existonly)
        return backends.concrete.eval(new_ast, 1)[0]

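    # Illustrative sketch (not part of the upstream module): with the default
    # allow_unconstrained=True, a BVS leaf that is missing from the model is
    # evaluated with the default value 0, while allow_unconstrained=False raises
    # KeyError instead. The variable name below is chosen for the example only.
    #
    #     >>> x = all_operations.BVS("x", 32, explicit_name=True)
    #     >>> cache = ModelCache({})
    #     >>> cache.eval_ast(x + 1)       # "x" falls back to the default 0
    #     1
    #     >>> cache.eval_ast(x + 1, allow_unconstrained=False)
    #     Traceback (most recent call last):
    #         ...
    #     KeyError: 'x'
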
    def eval_constraints(self, constraints):
        """Return whether the constraints are trivially satisfied by the cached model."""
        # eval_ast concretizes symbols and evaluates them; this can raise exceptions.
        try:
            return all(self.eval_ast(c) for c in constraints)
        except errors.ClaripyZeroDivisionError:
            return False

    def eval_list(self, asts, allow_unconstrained: bool = True) -> Tuple:
        """
        Evaluate a list of ASTs.

        :param asts:                A list of ASTs to evaluate.
        :param allow_unconstrained: When set to True, treat non-existent variables as unconstrained variables and
                                    use arbitrary concrete values for them during evaluation. Otherwise, raise
                                    KeyError for non-existent variables.
        :return:                    A tuple of evaluated results, one element per AST.
        """
        return tuple(self.eval_ast(c, allow_unconstrained=allow_unconstrained) for c in asts)


class ModelCacheMixin:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._models = set()
        self._exhausted = False
        self._eval_exhausted = weakref.WeakSet()
        self._max_exhausted = weakref.WeakSet()
        self._min_exhausted = weakref.WeakSet()
        self._max_signed_exhausted = weakref.WeakSet()
        self._min_signed_exhausted = weakref.WeakSet()

    def _blank_copy(self, c):
        super()._blank_copy(c)
        c._models = set()
        c._exhausted = False
        c._eval_exhausted = weakref.WeakSet()
        c._max_exhausted = weakref.WeakSet()
        c._min_exhausted = weakref.WeakSet()
        c._max_signed_exhausted = weakref.WeakSet()
        c._min_signed_exhausted = weakref.WeakSet()

    def _copy(self, c):
        super()._copy(c)
        c._models = set(self._models)
        c._exhausted = self._exhausted
        c._eval_exhausted = weakref.WeakSet(self._eval_exhausted)
        c._max_exhausted = weakref.WeakSet(self._max_exhausted)
        c._min_exhausted = weakref.WeakSet(self._min_exhausted)
        c._max_signed_exhausted = weakref.WeakSet(self._max_signed_exhausted)
        c._min_signed_exhausted = weakref.WeakSet(self._min_signed_exhausted)

    def __setstate__(self, base_state):
        super().__setstate__(base_state)
        self._models = set()
        self._exhausted = False
        self._eval_exhausted = weakref.WeakSet()
        self._max_exhausted = weakref.WeakSet()
        self._min_exhausted = weakref.WeakSet()
        self._max_signed_exhausted = weakref.WeakSet()
        self._min_signed_exhausted = weakref.WeakSet()

    #
    # Model cleaning
    #

    def simplify(self, *args, **kwargs):
        results = super().simplify(*args, **kwargs)
        if len(results) > 0 and any(c is false for c in results):
            self._models.clear()
        return results

    def _trivial_model_optimization(self):
        # Detect the common single-constraint case `<BVS> == <concrete value>` and seed the model cache
        # (and the exhausted markers) directly, without consulting the solver.
        c = self.constraints[0]
        if not (
            c.depth == 2
            and c.op == "__eq__"
            and len(c.variables) == 1
            and c.args[0].symbolic
            and not c.args[1].symbolic
            and c.args[0].op == "BVS"
        ):
            return

        self._models.add(ModelCache({next(iter(c.args[0].variables)): backends.concrete.eval(c.args[1], 1)[0]}))
        self._eval_exhausted.add(c.args[0].cache_key)
        self._max_exhausted.add(c.args[0].cache_key)
        self._min_exhausted.add(c.args[0].cache_key)
        self._max_signed_exhausted.add(c.args[0].cache_key)
        self._min_signed_exhausted.add(c.args[0].cache_key)

    def add(self, constraints, invalidate_cache=True, **kwargs):
        if len(constraints) == 0:
            return constraints

        old_vars = frozenset(self.variables)
        added = super().add(constraints, **kwargs)
        if len(added) == 0:
            return added

        if len(self.constraints) == 1 and len(self._models) == 0:
            self._trivial_model_optimization()

        new_vars = any(a.variables - old_vars for a in added)
        if new_vars or invalidate_cache:
            # shortcut for unsat
            if any(c is false for c in constraints):
                self._models.clear()

            still_valid = set(self._get_models(extra_constraints=added))
            if len(still_valid) != len(self._models):
                self._exhausted = False
                self._eval_exhausted.clear()
                self._max_exhausted.clear()
                self._min_exhausted.clear()
                self._max_signed_exhausted.clear()
                self._min_signed_exhausted.clear()
                self._models = still_valid

        return added

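    # Note (added for clarity, not part of the upstream module): whenever constraints
    # are added, every cached model is re-checked against them via _get_models();
    # models that no longer satisfy the constraint set are dropped, and if any model
    # is dropped, the exhausted markers are cleared because the cache may no longer be
    # complete. For example, a cached ModelCache({"x": 5}) survives adding x > 3 but
    # is discarded after adding x > 10.
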
    def split(self):
        results = super().split()
        for r in results:
            r._models = {m.filter(r.variables) for m in self._models}
        return results

    def combine(self, others):
        combined = super().combine(others)

        if any(len(o._models) == 0 for o in others) or len(self._models) == 0:
            # this would need a solve anyways, so screw it
            return combined

        vars_count = len(self.variables) + sum(len(s.variables) for s in others)
        all_vars = self.variables.union(*[s.variables for s in others])
        if vars_count != len(all_vars):
            # this is the case where there are variables missing from the models.
            # We'll need more intelligence here to handle it
            return combined

        model_lists = [self._models]
        model_lists.extend(o._models for o in others)
        combined._models.update(
            ModelCache.combine(*product)
            for product in itertools.islice(itertools.product(*model_lists), len(self._models))
        )
        return combined

    def update(self, other):
        """
        Update this cache mixin with results discovered by another mixin that was split off from it.
        """
        acceptable_models = [m for m in other._models if set(m.model.keys()) == self.variables]
        self._models.update(acceptable_models)
        self._eval_exhausted.update(other._eval_exhausted)
        self._max_exhausted.update(other._max_exhausted)
        self._min_exhausted.update(other._min_exhausted)
        self._max_signed_exhausted.update(other._max_signed_exhausted)
        self._min_signed_exhausted.update(other._min_signed_exhausted)

    #
    # Cache retrieval
    #

    def _model_hook(self, m):
        # Z3 might give us solutions for variables that we did not ask for, so we create a new dict with solutions
        # for only the variables that are under the solver's control
        m_ = {k: v for k, v in m.items() if k in self.variables}
        if m_:
            model = ModelCache(m_)
            self._models.add(model)

    def _get_models(self, extra_constraints=()):
        for m in self._models:
            if m.eval_constraints(extra_constraints):
                yield m

    def _get_batch_solutions(self, asts, n=None, extra_constraints=(), allow_unconstrained=True):
        results = set()

        for m in self._get_models(extra_constraints):
            try:
                results.add(m.eval_list(asts, allow_unconstrained=allow_unconstrained))
            except (ZeroDivisionError, KeyError):
                continue
            if len(results) == n:
                break

        return results

    def _get_solutions(self, e, n=None, extra_constraints=(), allow_unconstrained=True):
        return tuple(
            v[0]
            for v in self._get_batch_solutions(
                [e],
                n=n,
                extra_constraints=extra_constraints,
                allow_unconstrained=allow_unconstrained,
            )
        )

    #
    # Cached functions
    #

    def satisfiable(self, extra_constraints=(), **kwargs):
        for _ in self._get_models(extra_constraints=extra_constraints):
            return True
        return super().satisfiable(extra_constraints=extra_constraints, **kwargs)

    def batch_eval(self, asts, n, extra_constraints=(), **kwargs):
        results = self._get_batch_solutions(asts, n=n, extra_constraints=extra_constraints)

        if len(results) == n or (len(asts) == 1 and asts[0].cache_key in self._eval_exhausted):
            return results

        remaining = n - len(results)

        # TODO: faster to concat?
        if len(results) != 0:
            constraints = (
                all_operations.And(*[all_operations.Or(*[a != v for a, v in zip(asts, r)]) for r in results]),
            ) + tuple(extra_constraints)
        else:
            constraints = extra_constraints

        try:
            results.update(super().batch_eval(asts, remaining, extra_constraints=constraints, **kwargs))
        except UnsatError:
            if len(results) == 0:
                raise

        if len(extra_constraints) == 0 and len(results) < n:
            for e in asts:
                # only mark an AST as eval-exhausted if e.variables is a subset of variables that the current solver
                # knows about (from its constraints)
                if self.variables.issuperset(e.variables):
                    self._eval_exhausted.add(e.cache_key)

        return results

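    # Illustrative sketch (not part of the upstream module): when some solutions are
    # already cached, the recursive solver call excludes them by conjoining one
    # Or(...) per cached tuple. For asts = [x, y] and a cached result (1, 2), the
    # extra constraint built above is equivalent to
    #
    #     all_operations.Or(x != 1, y != 2)
    #
    # so the underlying solver can only return tuples that have not been seen yet.
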
    def eval(self, e, n, **kwargs):
        return tuple(r[0] for r in ModelCacheMixin.batch_eval(self, [e], n=n, **kwargs))

    def min(self, e, extra_constraints=(), signed=False, **kwargs):
        cached = []
        exhausted = self._min_signed_exhausted if signed else self._min_exhausted
        if e.cache_key in self._eval_exhausted or e.cache_key in exhausted:
            # we set allow_unconstrained to False because we expect every cached value for e to have been produced
            # by Z3, rather than being an arbitrarily assigned concrete value.
            cached = self._get_solutions(e, extra_constraints=extra_constraints, allow_unconstrained=False)

        if len(cached) > 0:
            signed_key = lambda v: v if v < 2 ** (len(e) - 1) else v - 2 ** len(e)
            return min(cached, key=signed_key if signed else lambda v: v)
        else:
            m = super().min(e, extra_constraints=extra_constraints, signed=signed, **kwargs)
            if len(extra_constraints) == 0:
                exhausted.add(e.cache_key)
            return m

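    # Illustrative sketch (not part of the upstream module): signed_key reinterprets
    # the cached unsigned values as two's-complement before comparing, using the bit
    # width len(e) of the expression. For an 8-bit expression, 0xFF maps to -1 while
    # 0x01 stays 1, so the signed minimum of {0x01, 0xFF} is 0xFF:
    #
    #     >>> signed_key = lambda v, bits=8: v if v < 2 ** (bits - 1) else v - 2 ** bits
    #     >>> min({0x01, 0xFF}, key=signed_key)
    #     255
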
    def max(self, e, extra_constraints=(), signed=False, **kwargs):
        cached = []
        exhausted = self._max_signed_exhausted if signed else self._max_exhausted
        if e.cache_key in self._eval_exhausted or e.cache_key in exhausted:
            cached = self._get_solutions(e, extra_constraints=extra_constraints, allow_unconstrained=False)

        if len(cached) > 0:
            signed_key = lambda v: v if v < 2 ** (len(e) - 1) else v - 2 ** len(e)
            return max(cached, key=signed_key if signed else lambda v: v)
        else:
            m = super().max(e, extra_constraints=extra_constraints, signed=signed, **kwargs)
            if len(extra_constraints) == 0:
                exhausted.add(e.cache_key)
            return m

    def solution(self, e, v, extra_constraints=(), **kwargs):
        if isinstance(v, Base):
            cached = self._get_batch_solutions([e, v], extra_constraints=extra_constraints)
            if any(ec == vc for ec, vc in cached):
                return True
        else:
            cached = self._get_solutions(e, extra_constraints=extra_constraints)
            if v in cached:
                return True

        return super().solution(e, v, extra_constraints=extra_constraints, **kwargs)


# These imports live at the bottom of the module to avoid circular imports.
from .. import backends, false
from ..errors import UnsatError
from ..ast import all_operations, Base
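
# Usage sketch (added for illustration, not part of the upstream module): this mixin
# is designed to be composed into a frontend class via cooperative inheritance, so
# cached models are consulted before falling through to a real solver call. In
# claripy the public solver classes (e.g. claripy.Solver) stack ModelCacheMixin with
# the other frontend mixins; with a single constraint of the form `x == 42`, the
# trivial-model optimization above typically lets eval() answer from the cache:
#
#     >>> import claripy
#     >>> s = claripy.Solver()
#     >>> x = claripy.BVS("x", 32)
#     >>> _ = s.add(x == 42)
#     >>> s.eval(x, 1)
#     (42,)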