from typing import Set, TYPE_CHECKING
import logging
l = logging.getLogger("claripy.frontends.composite_frontend")
import weakref
import itertools
symbolic_count = itertools.count()
from .constrained_frontend import ConstrainedFrontend
from claripy.ast.strings import String
if TYPE_CHECKING:
from claripy import SolverCompositeChild
class CompositeFrontend(ConstrainedFrontend):
    """A frontend that partitions constraints into independent groups, handing each
    group to its own child solver so that queries only touch the relevant subset.
    """

    def __init__(self, template_frontend, template_frontend_string, track=False, **kwargs):
        """
        :param template_frontend:        Frontend blank-copied to create new child solvers.
        :param template_frontend_string: Frontend blank-copied for string-variable children.
        :param track:                    Whether constraint tracking is enabled.
        """
        super().__init__(**kwargs)
        # maps each variable name to the child solver responsible for it
        self._solvers = {}
        # children that received constraints since their last satisfiability check
        self._unchecked_solvers = weakref.WeakSet()
        # children we own outright and may mutate in place (others are copy-on-write)
        self._owned_solvers = weakref.WeakSet()
        self._template_frontend = template_frontend
        self._template_frontend_string = template_frontend_string
        self._unsat = False
        self._track = track
def _blank_copy(self, c):
super()._blank_copy(c)
c._unchecked_solvers = weakref.WeakSet()
c._owned_solvers = weakref.WeakSet()
c._solvers = {}
c._template_frontend = self._template_frontend
if hasattr(self, "_template_frontend_string"):
c._template_frontend_string = self._template_frontend_string
c._unsat = False
c._track = self._track
def _copy(self, c):
super()._copy(c)
c._unsat = self._unsat
c._track = self._track
c._solvers = dict(self._solvers)
c._unchecked_solvers = weakref.WeakSet(self._unchecked_solvers)
self._owned_solvers = weakref.WeakSet() # for the COW
return c
#
# Serialization stuff
#
def __getstate__(self):
return self._solvers, self._template_frontend, self._unsat, self._track, super().__getstate__()
def __setstate__(self, s):
self._solvers, self._template_frontend, self._unsat, self._track, base_state = s
self._owned_solvers = weakref.WeakSet(self._solver_list)
self._unchecked_solvers = weakref.WeakSet()
super().__setstate__(base_state)
[docs] def downsize(self):
for e in self._solver_list:
e.downsize()
#
# Frontend management
#
@property
def _solver_list(self):
seen_solvers = set()
solver_list = []
for s in self._solvers.values():
if id(s) in seen_solvers:
continue
seen_solvers.add(id(s))
solver_list.append(s)
return solver_list
@property
def variables(self):
return set(self._solvers.keys())
# this is really hacky, but we want to avoid having our variables messed with
@variables.setter
def variables(self, v):
pass
#
# Solver list management
#
def _solvers_for_variables(self, names):
seen_solvers = set()
existing_solvers = []
for n in names:
if n not in self._solvers:
continue
s = self._solvers[n]
if id(s) in seen_solvers:
continue
seen_solvers.add(id(s))
existing_solvers.append(s)
return existing_solvers
@staticmethod
def _names_for(names=None, lst=None, lst2=None, e=None, v=None) -> Set[str]:
if names is None:
names = set()
if e is not None and isinstance(e, Base):
names.update(e.variables)
if v is not None and isinstance(v, Base):
names.update(v.variables)
if lst is not None:
for ee in lst:
if isinstance(ee, Base):
names.update(ee.variables)
if lst2 is not None:
for ee in lst2:
if isinstance(ee, Base):
names.update(ee.variables)
return names
def _merged_solver_for(self, *args, **kwargs):
return self._solver_for_names(self._names_for(*args, **kwargs))
    def _solver_for_names(self, names: Set[str]) -> "SolverCompositeChild":
        """
        Get a merged child solver for variables specified in `names`.

        :param names: A set of variable names.
        :return: A composite child solver.
        """
        l.debug("composite_solver._merged_solver_for() running with %d names", len(names))

        # compute a transitive closure for all variable names: a child solver may
        # link a requested variable to further variables, whose solvers must then
        # be pulled in as well, until no new names appear.
        all_names = set(names)
        new_names = names
        solvers = set()
        while True:
            tmp_solvers = self._solvers_for_variables(new_names)
            solvers |= set(tmp_solvers)
            all_names |= new_names
            tmp_names = set()
            for solver in solvers:
                tmp_names |= solver.variables
            # only names we have not yet processed need another round
            new_names = tmp_names.difference(all_names)
            if not new_names:
                break

        solvers = list(solvers)
        if len(solvers) == 0:
            # no existing child covers these variables: create a blank solver.
            # String variables get the dedicated string-solver template.
            if any(var for var in names if var.startswith(String.STRING_TYPE_IDENTIFIER)):
                l.debug("... creating new solver for strings")
                return self._template_frontend_string.blank_copy()
            else:
                l.debug("... creating new solver")
                return self._template_frontend.blank_copy()
        elif len(solvers) == 1:
            # exactly one child already covers everything; use it directly
            l.debug("... got one solver")
            return solvers[0]
        else:
            # several children are entangled by these names; combine them into one
            l.debug(".... combining %d solvers", len(solvers))
            return solvers[0].combine(solvers[1:])
def _shared_solvers(self, others):
"""
Returns a sequence of the solvers that self and others share.
"""
solvers_by_id = {id(s): s for s in self._solver_list}
common_solvers = set(solvers_by_id.keys())
other_sets = [{id(s) for s in cs._solver_list} for cs in others]
for o in other_sets:
common_solvers &= o
return [solvers_by_id[s] for s in common_solvers]
    def _variable_sets(self):
        # NOTE(review): this builds a set *of* variable sets, which raises TypeError
        # unless each child's .variables is hashable (e.g. a frozenset) — confirm the
        # child solvers' .variables type before relying on this.
        return {s.variables for s in self._solver_list}

    def _shared_varsets(self, others):
        # NOTE(review): `all_varsets` is not defined anywhere in this file; this looks
        # like dead or legacy code — verify callers before relying on it.
        common_varsets = self._variable_sets()
        for o in others:
            common_varsets &= o.all_varsets()
        return common_varsets
def _split_child(self, s):
ss = s.split()
if len(ss) == 1:
return [s]
l.debug("... split solver %r into %d parts", s, len(ss))
l.debug("... variable counts: %s", [len(cs.variables) for cs in ss])
for ns in ss:
self._owned_solvers.add(ns)
self._store_child(ns)
return ss
    def _reabsorb_solver(self, s):
        """
        Fold a (possibly merged/temporary) solver `s` back into the child map so that
        work done during a query — e.g. cached models — is not thrown away.
        """
        try:
            # nothing to do if s has no variables, or if it already *is* the
            # registered child for its variables
            if len(s.variables) == 0 or self._solvers[min(iter(s.variables))] is s:
                return
        except KeyError:
            # this happens when a variable is introduced due to constraint expansion
            return

        if isinstance(s, ModelCacheMixin):
            new_solvers = s.split()
            old_solvers = self._solvers_for_variables(s.variables)
            if len(new_solvers) == len(old_solvers):
                # same partitioning as before: update each registered child in place
                done = set()
                for ss in new_solvers:
                    if ss in done:
                        continue
                    done.add(ss)
                    v = min(iter(ss.variables))
                    self._solvers[v].update(ss)
            else:
                # the partitioning changed: take ownership of the new pieces and
                # register them over the stale children
                for ns in new_solvers:
                    self._owned_solvers.add(ns)
                    self._store_child(ns)
def _store_child(self, ns, extra_names=frozenset(), invalidate_cache=True):
for v in ns.variables | extra_names:
# os = self._solvers[v]
self._solvers[v] = ns
if invalidate_cache:
self._unchecked_solvers.add(ns)
# if isinstance(s, ModelCacheMixin):
# if len(os._models) < len(ns._models):
# print("GOT %d NEW MODELS (before: %d)" % (
# len(ns._models) - len(os._models), len(os._models)
# ))
# elif len(os._models) > len(ns._models):
# print("WARNING: LOST %d NEW MODELS (before: %d)" % (
# len(os._models) - len(ns._models), len(os._models)
# ))
# else:
# print("Remained at %d models." % len(os._models))
#
# Constraints
#
def _claim(self, s):
if s not in self._owned_solvers:
sc = s.branch()
self._owned_solvers.add(sc)
return sc
else:
return s
def _add_dependent_constraints(self, names, constraints, invalidate_cache=True, **kwargs):
if not invalidate_cache and len(self._solvers_for_variables(names)) > 1:
l.debug("Ignoring cross-solver helper constraints.")
return []
l.debug("Adding %d constraints to %d names", len(constraints), len(names))
s = self._claim(self._merged_solver_for(names=names))
added = s.add(constraints, invalidate_cache=invalidate_cache, **kwargs)
self._store_child(s, invalidate_cache=invalidate_cache)
return added
    def add(self, constraints, **kwargs):  # pylint:disable=arguments-differ
        """
        Add `constraints`, routing each independent group to the child solver that
        owns its variables. Purely concrete constraints are evaluated eagerly; a
        concrete False marks the whole composite state unsat.
        """
        split = self._split_constraints(constraints)
        child_added = []

        # l.debug("%s, solvers before: %d", self, len(self._solvers))
        unsure = []
        for names, set_constraints in split:
            if names == {"CONCRETE"}:
                try:
                    if any(backends.concrete.convert(c) is False for c in set_constraints):
                        self._unsat = True
                except BackendError:
                    # could not concretize after all; treated as "unsure" below
                    unsure.extend(set_constraints)
            else:
                child_added += self._add_dependent_constraints(names, set_constraints, **kwargs)
        # l.debug("... solvers after add: %d", len(self._solver_list))

        if len(unsure) > 0:
            # constraints we could not attribute to variables go to every child
            for s in self._solver_list:
                s = self._claim(s)
                s.add(unsure)
                self._store_child(s)

        return super().add(child_added)
#
# Solving
#
def _ensure_sat(self, extra_constraints):
if self._unsat or (len(extra_constraints) == 0 and not self.satisfiable()):
raise UnsatError("CompositeSolver is already unsat")
    def check_satisfiability(self, extra_constraints=(), exact=None):
        """
        Return "SAT", "UNSAT", or "UNKNOWN" for the composite state under
        `extra_constraints`, re-checking only children that need it.
        """
        if self._unsat:
            return "UNSAT"

        l.debug("%r checking satisfiability...", self)

        if len(extra_constraints) != 0:
            # first check the (merged) child covering the extra constraints
            extra_solver = self._merged_solver_for(lst=extra_constraints)
            extra_solver_satness = extra_solver.check_satisfiability(extra_constraints=extra_constraints, exact=exact)
            if extra_solver_satness in {"UNSAT", "UNKNOWN"}:
                return extra_solver_satness
            self._reabsorb_solver(extra_solver)

        # re-check every child that gained constraints since its last check
        for s in self._unchecked_solvers:
            if extra_constraints and s.variables & extra_solver.variables:
                # skip solvers covered by extra constraints (they were checked above)
                continue
            if len(s.variables) == 0 or self._solvers[min(iter(s.variables))] is not s:
                # this happens when a parent solver didn't check all unchecked solvers, and we have stale
                # child solvers in the unchecked list
                continue
            satness = s.check_satisfiability(exact=exact)
            if satness in {"UNSAT", "UNKNOWN"}:
                return satness

        self._unchecked_solvers.clear()
        return "SAT"
[docs] def satisfiable(self, extra_constraints=(), exact=None):
return self.check_satisfiability(extra_constraints=extra_constraints, exact=exact) == "SAT"
[docs] def eval(self, e, n, extra_constraints=(), exact=None):
self._ensure_sat(extra_constraints=extra_constraints)
ms = self._merged_solver_for(e=e, lst=extra_constraints)
r = ms.eval(e, n, extra_constraints=extra_constraints, exact=exact)
self._reabsorb_solver(ms)
return r
[docs] def batch_eval(self, exprs, n, extra_constraints=(), exact=None):
self._ensure_sat(extra_constraints=extra_constraints)
ms = self._merged_solver_for(lst2=exprs, lst=extra_constraints)
r = ms.batch_eval(exprs, n, extra_constraints=extra_constraints, exact=exact)
self._reabsorb_solver(ms)
return r
[docs] def max(self, e, extra_constraints=(), signed=False, exact=None):
self._ensure_sat(extra_constraints=extra_constraints)
ms = self._merged_solver_for(e=e, lst=extra_constraints)
r = ms.max(e, extra_constraints=extra_constraints, signed=signed, exact=exact)
self._reabsorb_solver(ms)
return r
[docs] def min(self, e, extra_constraints=(), signed=False, exact=None):
self._ensure_sat(extra_constraints=extra_constraints)
ms = self._merged_solver_for(e=e, lst=extra_constraints)
r = ms.min(e, extra_constraints=extra_constraints, signed=signed, exact=exact)
self._reabsorb_solver(ms)
return r
[docs] def solution(self, e, v, extra_constraints=(), exact=None):
self._ensure_sat(extra_constraints=extra_constraints)
ms = self._merged_solver_for(e=e, v=v, lst=extra_constraints)
r = ms.solution(e, v, extra_constraints=extra_constraints, exact=exact)
self._reabsorb_solver(ms)
return r
[docs] def is_true(self, e, extra_constraints=(), exact=None):
# self._ensure_sat(extra_constraints=extra_constraints)
ms = self._merged_solver_for(e=e, lst=extra_constraints)
r = ms.is_true(e, extra_constraints=extra_constraints, exact=exact)
# self._reabsorb_solver(ms)
return r
[docs] def is_false(self, e, extra_constraints=(), exact=None):
# self._ensure_sat(extra_constraints=extra_constraints)
ms = self._merged_solver_for(e=e, lst=extra_constraints)
r = ms.is_false(e, extra_constraints=extra_constraints, exact=exact)
# self._reabsorb_solver(ms)
return r
[docs] def unsat_core(self, extra_constraints=()):
if self.satisfiable(extra_constraints=extra_constraints):
return ()
cores = []
for solver in self._solver_list:
cores.extend(list(solver.unsat_core(extra_constraints=extra_constraints)))
return cores
    def simplify(self):
        """
        Simplify every child solver, splitting children that become separable.
        Returns the new top-level constraint list.
        """
        if self._unsat:
            return self.constraints

        new_constraints = []

        l.debug("Simplifying %r with %d solvers", self, len(self._solver_list))
        for s in self._solver_list:
            if isinstance(s, SimplifySkipperMixin) and s._simplified:
                # unchanged since its last simplification; keep as-is
                new_constraints += s.constraints
                continue

            l.debug("... simplifying child solver %r", s)
            s.simplify()
            results = self._split_child(s)
            for ns in results:
                if isinstance(ns, SimplifySkipperMixin):
                    ns._simplified = True
            # NOTE(review): collects `s.constraints` (the pre-split child), not the
            # constraints of the split parts — confirm this is intended.
            new_constraints += s.constraints

        l.debug("... after-split, %r has %d solvers", self, len(self._solver_list))
        self.constraints = new_constraints
        return new_constraints
#
# Merging and splitting
#
[docs] def finalize(self):
for s in self._solver_list:
s.finalize()
@property
def timeout(self):
return self._template_frontend.timeout
@timeout.setter
def timeout(self, t):
self._template_frontend.timeout = t
for s in self._solver_list:
s.timeout = t
@property
def max_memory(self):
return self._template_frontend.max_memory
@max_memory.setter
def max_memory(self, val):
# this is technically wrong. we cannot enforce a memory limit for a pool shared among multiple solvers
self._template_frontend.max_memory = val
for s in self._solver_list:
s.max_memory = val
@staticmethod
def _merge_with_ancestor(common_ancestor, merge_conditions):
merged = common_ancestor.branch()
merged.add([Or(*merge_conditions)])
# import ipdb; ipdb.set_trace()
return True, merged
    def merge(self, others, merge_conditions, common_ancestor=None):
        """
        Merge this frontend with `others` under the given merge conditions.

        :param others:           Other composite frontends to merge in.
        :param merge_conditions: One condition per merged state.
        :param common_ancestor:  If given, merging is done by branching it instead.
        :return:                 (True, merged_frontend)
        """
        if common_ancestor is not None:
            return self._merge_with_ancestor(common_ancestor, merge_conditions)

        l.debug("Merging %s with %d other solvers.", self, len(others))
        merged = self.blank_copy()
        common_solvers = self._shared_solvers(others)
        common_ids = {id(s) for s in common_solvers}
        l.debug("... %s common solvers", len(common_solvers))

        # children shared by all frontends carry over unchanged; every frontend
        # must give up ownership so nobody mutates them in place afterwards
        for s in common_solvers:
            self._owned_solvers.discard(s)
            for o in others:
                o._owned_solvers.discard(s)
            for v in s.variables:
                merged._solvers[v] = s

        # per frontend: collapse its non-shared children into a single solver ...
        noncommon_solvers = [[s for s in cs._solver_list if id(s) not in common_ids] for cs in [self] + others]

        l.debug("... merging noncommon solvers")
        combined_noncommons = []
        for ns in noncommon_solvers:
            l.debug("... %d", len(ns))
            if len(ns) == 0:
                # this frontend has no noncommon constraints
                pass
            elif len(ns) == 1:
                # one noncommon solver for this frontend
                combined_noncommons.append(ns[0])
            else:
                # multiple noncommon solvers; combine them into one
                combined_noncommons.append(ns[0].combine(ns[1:]))

        # ... then merge those combined solvers under the merge conditions
        if len(combined_noncommons):
            _, merged_noncommon = combined_noncommons[0].merge(combined_noncommons[1:], merge_conditions)
            merged._owned_solvers.add(merged_noncommon)
            merged._store_child(merged_noncommon)

        merged.constraints = list(itertools.chain.from_iterable(a.constraints for a in merged._solver_list))
        return True, merged
[docs] def combine(self, others):
combined = self.blank_copy()
combined.add(self.constraints)
for o in others:
combined.add(o.constraints)
return combined
[docs] def split(self):
return [s.branch() for s in self._solver_list]
from ..ast import Base
from ..ast.bool import Or
from .. import backends
from ..errors import BackendError, UnsatError
from ..frontend_mixins.model_cache_mixin import ModelCacheMixin
from ..frontend_mixins.simplify_skipper_mixin import SimplifySkipperMixin