Source code for claripy.balancer

from typing import Set
import logging
import operator

l = logging.getLogger("claripy.balancer")


class Balancer:
    """
    The Balancer is an equation redistributor. The idea is to take an AST and rebalance it to, for example, isolate
    unknown terms on one side of an inequality.
    """

    def __init__(self, helper, c, validation_frontend=None):
        """
        Queue the ITE-excavated form of ``c`` and immediately run the balancing loop.

        :param helper: Object used to evaluate ASTs (its ``eval``, ``min`` and ``max`` are called by this class).
        :param c: The constraint to balance; its ``ite_excavated`` attribute is queued as the first truism.
        :param validation_frontend: Optional frontend; when given, every new bound is asserted against its
                                    ``min``/``max`` result.
        """
        self._helper = helper
        self._validation_frontend = validation_frontend
        self._truisms = []
        self._processed_truisms = set()
        self._identified_assumptions = set()
        self._lower_bounds = {}
        self._upper_bounds = {}

        self._queue_truism(c.ite_excavated)

        self.sat = True
        try:
            self._doit()
        except ClaripyBalancerUnsatError:
            self.bounds = {}
            self.sat = False
        except BackendError:
            # best effort: keep whatever bounds were found before the backend failed
            l.debug("Backend error in balancer.", exc_info=True)

    @property
    def compat_ret(self):
        """The ``(sat, replacements)`` pair."""
        return (self.sat, self.replacements)

    def _replacements_iter(self):
        """
        Yield ``(ast, bounded_ast)`` pairs for every AST that has a recorded lower or upper bound.

        A missing bound defaults to 0 (lower) or the all-ones value of the AST's width (upper).
        """
        all_keys = set(self._lower_bounds) | set(self._upper_bounds)
        for k in all_keys:
            max_int = (1 << len(k.ast)) - 1
            min_int = 0
            mn = self._lower_bounds.get(k, min_int)
            mx = self._upper_bounds.get(k, max_int)
            bound_si = BVS("bound", len(k.ast), min=mn, max=mx)
            l.debug("Yielding bound %s for %s.", bound_si, k.ast)
            if k.ast.op == "Reverse":
                # report the replacement for the un-reversed operand instead
                yield (k.ast.args[0], k.ast.intersection(bound_si).reversed)
            else:
                yield (k.ast, k.ast.intersection(bound_si))

    def _add_lower_bound(self, o, b):
        """Record ``b`` as a lower bound for ``o``, keeping the larger of ``b`` and any existing bound."""
        l.debug("Adding lower bound %s for %s.", b, o)
        if o.cache_key in self._lower_bounds:
            old_b = self._lower_bounds[o.cache_key]
            l.debug("... old bound: %s", old_b)
            b = max(b, old_b)
            l.debug("... new bound: %s", b)

        if self._validation_frontend is not None:
            emin = self._validation_frontend.min(o)
            bmin = self._helper.min(b)
            assert emin >= bmin

        self._lower_bounds[o.cache_key] = b

    def _add_upper_bound(self, o, b):
        """Record ``b`` as an upper bound for ``o``, keeping the smaller of ``b`` and any existing bound."""
        l.debug("Adding upper bound %s for %s.", b, o)
        if o.cache_key in self._upper_bounds:
            old_b = self._upper_bounds[o.cache_key]
            l.debug("... old bound: %s", old_b)
            b = min(b, old_b)
            l.debug("... new bound: %s", b)

        if self._validation_frontend is not None:
            emax = self._validation_frontend.max(o)
            bmax = self._helper.max(b)
            assert emax <= bmax

        self._upper_bounds[o.cache_key] = b

    @property
    def replacements(self):
        """The list of ``(ast, bounded_ast)`` pairs produced by ``_replacements_iter``."""
        return list(self._replacements_iter())

    #
    # AST helper functions
    #

    def _same_bound_bv(self, a):
        """Build a fresh ``BVS`` with the same width, min, max and stride as the VSA conversion of ``a``."""
        si = backends.vsa.convert(a)
        mx = self._max(a)
        mn = self._min(a)
        return BVS("bounds", len(a), min=mn, max=mx, stride=si._stride)

    @staticmethod
    def _cardinality(a):
        """Return ``a.cardinality`` for AST objects and 0 for anything else."""
        return a.cardinality if isinstance(a, Base) else 0

    @staticmethod
    def _min(a, signed=False):
        """Return the minimum value of ``a`` according to the VSA backend (signed or unsigned)."""
        converted = backends.vsa.convert(a)
        if isinstance(converted, vsa.ValueSet):
            if len(converted.regions) == 1:
                converted = next(iter(converted.regions.values()))
            else:
                # unfortunately, this is a real abstract pointer
                # the minimum value will be 0 or MIN_INT
                if signed:
                    return -(1 << (len(converted) - 1))
                else:
                    return 0

        if not signed:
            bounds = converted._unsigned_bounds()
        else:
            bounds = converted._signed_bounds()
        return min(mn for mn, _ in bounds)

    @staticmethod
    def _max(a, signed=False):
        """Return the maximum value of ``a`` according to the VSA backend (signed or unsigned)."""
        converted = backends.vsa.convert(a)
        if isinstance(converted, vsa.ValueSet):
            if len(converted.regions) == 1:
                converted = next(iter(converted.regions.values()))
            else:
                # unfortunately, this is a real abstract pointer
                # the maximum value will be the signed or unsigned MAX_INT
                if signed:
                    return (1 << (len(converted) - 1)) - 1
                else:
                    return (1 << len(converted)) - 1

        if not signed:
            bounds = converted._unsigned_bounds()
        else:
            bounds = converted._signed_bounds()
        return max(mx for _, mx in bounds)

    def _range(self, a, signed=False):
        """Return the ``(min, max)`` pair of ``a``."""
        return (self._min(a, signed=signed), self._max(a, signed=signed))

    @staticmethod
    def _invert_comparison(a):
        """Return the logical negation of ``a``."""
        return _all_operations.Not(a)

    #
    # Truism alignment
    #

    def _align_truism(self, truism):
        """
        Align the truism itself and then its left-hand argument.

        If the aligned result is not VSA-identical to the input, the problem is logged and the original truism is
        returned unchanged.
        """
        outer_aligned = self._align_ast(truism)
        inner_aligned = outer_aligned.make_like(
            outer_aligned.op, (self._align_ast(outer_aligned.args[0]),) + outer_aligned.args[1:]
        )

        if not backends.vsa.identical(inner_aligned, truism):
            l.critical(
                "ERROR: the balancer is messing up an AST. This must be looked into. Please submit the binary and "
                "script to the angr project, if possible. Outer op is %s and inner op is %s.",
                truism.op,
                truism.args[0].op,
            )
            return truism

        return inner_aligned

    def _align_ast(self, a):
        """
        Aligns the AST so that the argument with the highest cardinality is on the left.

        :return: a new AST.
        """
        try:
            if isinstance(a, BV):
                return self._align_bv(a)
            elif isinstance(a, Bool) and len(a.args) == 2 and a.args[1].cardinality > a.args[0].cardinality:
                return self._reverse_comparison(a)
            else:
                return a
        except ClaripyBalancerError:
            return a

    @staticmethod
    def _reverse_comparison(a):
        """
        Swap the operands of the comparison ``a``, replacing its operator with the one listed in ``opposites``.

        :raises ClaripyBalancerError: if the operator has no opposite, cannot be resolved, or cannot be applied.
        """
        try:
            new_op = opposites[a.op]
        except KeyError as e:
            raise ClaripyBalancerError("unable to reverse comparison %s (missing from 'opposites')" % a.op) from e

        try:
            # dunder names live in the operator module, everything else in claripy's operation table
            if new_op.startswith("__"):
                op = getattr(operator, new_op)
            else:
                op = getattr(_all_operations, new_op)
        except AttributeError as e:
            raise ClaripyBalancerError("unable to reverse comparison %s (AttributeError)" % a.op) from e

        try:
            return op(*a.args[::-1])
        except ClaripyOperationError as e:
            raise ClaripyBalancerError("unable to reverse comparison %s (ClaripyOperationError)" % a.op) from e

    def _align_bv(self, a):
        """Sort the arguments of a commutative operation by descending cardinality, or dispatch to ``_align_<op>``."""
        if a.op in commutative_operations:
            return a.make_like(a.op, tuple(sorted(a.args, key=lambda v: -self._cardinality(v))))

        try:
            op = getattr(self, "_align_" + a.op)
        except AttributeError:
            return a
        return op(a)

    def _align___sub__(self, a):
        """Rewrite a subtraction as an addition of negated terms when the first operand is not the largest."""
        cardinalities = [self._cardinality(v) for v in a.args]
        if max(cardinalities) == cardinalities[0]:
            return a

        adjusted = tuple(operator.__neg__(v) for v in a.args[1:]) + a.args[:1]
        return a.make_like("__add__", tuple(sorted(adjusted, key=lambda v: -self._cardinality(v))))

    #
    # Find bounds
    #

    def _doit(self):
        """
        This function processes the list of truisms and finds bounds for ASTs.

        :raises ClaripyBalancerUnsatError: if a truism is found to be false.
        """
        while self._truisms:
            truism = self._truisms.pop()

            if truism in self._processed_truisms:
                continue

            unpacked_truisms = self._unpack_truisms(truism)
            if is_false(truism):
                raise ClaripyBalancerUnsatError()

            self._processed_truisms.add(truism)
            if unpacked_truisms:
                # the sub-truisms replace this one
                self._queue_truisms(unpacked_truisms, check_true=True)
                continue

            if not self._handleable_truism(truism):
                continue

            truism = self._adjust_truism(truism)

            assumptions = self._get_assumptions(truism)
            if truism not in self._identified_assumptions and assumptions:
                l.debug("Queued assumptions %s for truism %s.", assumptions, truism)
                self._truisms.extend(assumptions)
                self._identified_assumptions.update(assumptions)

            l.debug("Processing truism %s", truism)
            balanced_truism = self._balance(truism)
            l.debug("... handling")
            self._handle(balanced_truism)

    def _queue_truism(self, t, check_true=False):
        """Queue ``t``; with ``check_true`` set, skip it when it is already known to be true."""
        if not check_true or not is_true(t):
            self._truisms.append(t)

    def _queue_truisms(self, ts, check_true=False):
        """Queue every truism in ``ts``; with ``check_true`` set, skip those already known to be true."""
        if check_true:
            self._truisms.extend(t for t in ts if not is_true(t))
        else:
            self._truisms.extend(ts)

    @staticmethod
    def _handleable_truism(t):
        """
        Checks whether we can handle this truism. The truism should already be aligned.
        """
        if len(t.args) < 2:
            l.debug("can't do anything with an unop bool")
            return False
        elif t.args[0].cardinality > 1 and t.args[1].cardinality > 1:
            l.debug("can't do anything because we have multiple multivalued guys")
            return False
        elif t.op == "If":
            l.debug("can't handle If")
            return False
        else:
            return True

    @staticmethod
    def _adjust_truism(t):
        """
        Swap the operands of the truism if the unknown variable is on the right side and the concrete value is on the
        left side.
        """
        if t.args[0].cardinality == 1 and t.args[1].cardinality > 1:
            return Balancer._reverse_comparison(t)
        return t

    #
    # Assumptions management
    #

    @staticmethod
    def _get_assumptions(t):
        """
        Given a constraint, _get_assumptions() returns a set of constraints that are implicitly
        assumed to be true. For example, `x <= 10` would return `x >= 0`.
        """
        if t.op in ("__le__", "__lt__", "ULE", "ULT"):
            return [t.args[0] >= 0]
        elif t.op in ("__ge__", "__gt__", "UGE", "UGT"):
            return [t.args[0] <= 2 ** len(t.args[0]) - 1]
        elif t.op in ("SLE", "SLT"):
            return [_all_operations.SGE(t.args[0], -(1 << (len(t.args[0]) - 1)))]
        elif t.op in ("SGE", "SGT"):
            return [_all_operations.SLE(t.args[0], (1 << (len(t.args[0]) - 1)) - 1)]
        else:
            return []

    #
    # Truism extractor
    #

    def _unpack_truisms(self, c) -> Set:
        """
        Given a constraint, _unpack_truisms() returns a set of constraints that must be True for
        this constraint to be True.
        """
        try:
            op = getattr(self, "_unpack_truisms_" + c.op)
        except AttributeError:
            return set()

        return op(c)

    def _unpack_truisms_And(self, c):
        """Union of the truisms unpacked from every conjunct."""
        return set.union(*[self._unpack_truisms(a) for a in c.args])

    def _unpack_truisms_Not(self, c):
        """Push the negation through an inner And/Or (De Morgan) and unpack the result."""
        if c.args[0].op == "And":
            return self._unpack_truisms(_all_operations.Or(*[_all_operations.Not(a) for a in c.args[0].args]))
        elif c.args[0].op == "Or":
            return self._unpack_truisms(_all_operations.And(*[_all_operations.Not(a) for a in c.args[0].args]))
        else:
            return set()

    def _unpack_truisms_Or(self, c):
        """
        Unpack a disjunction: when exactly one disjunct is not known to be false, that disjunct must hold.

        :raises ClaripyBalancerUnsatError: if every disjunct is false.
        """
        vals = [is_false(v) for v in c.args]
        if all(vals):
            raise ClaripyBalancerUnsatError()
        elif vals.count(False) == 1:
            return self._unpack_truisms(c.args[vals.index(False)])
        else:
            return set()

    #
    # Dealing with constraints
    #

    comparison_info = {}
    # Tuples look like (is_lt, is_eq, is_unsigned)
    comparison_info["SLT"] = (True, False, False)
    comparison_info["SLE"] = (True, True, False)
    comparison_info["SGT"] = (False, False, False)
    comparison_info["SGE"] = (False, True, False)
    comparison_info["ULT"] = (True, False, True)
    comparison_info["ULE"] = (True, True, True)
    comparison_info["UGT"] = (False, False, True)
    comparison_info["UGE"] = (False, True, True)
    comparison_info["__lt__"] = comparison_info["ULT"]
    comparison_info["__le__"] = comparison_info["ULE"]
    comparison_info["__gt__"] = comparison_info["UGT"]
    comparison_info["__ge__"] = comparison_info["UGE"]

    #
    # Simplification routines
    #

    def _balance(self, truism):
        """
        Repeatedly apply the ``_balance_<op>`` handler of the left-hand side until it returns its input unchanged.

        The original truism is returned when it cannot be balanced (unary truism, non-AST left side, multivalued
        right side, missing handler, or a handler raising ``ClaripyBalancerError``).
        """
        l.debug("Balancing %s", truism)

        # can't balance single-arg bools (Not) for now
        if len(truism.args) == 1:
            return truism
        if not isinstance(truism.args[0], Base):
            return truism

        try:
            inner_aligned = self._align_truism(truism)
            if inner_aligned.args[1].cardinality > 1:
                l.debug("can't do anything because we have multiple multivalued guys")
                return truism

            try:
                balancer = getattr(self, "_balance_%s" % inner_aligned.args[0].op)
            except AttributeError:
                l.debug("Balance handler %s is not found in balancer. Consider implementing.", truism.args[0].op)
                return truism

            balanced = balancer(inner_aligned)
            if balanced is inner_aligned:
                # the handler made no progress; we are done
                return balanced
            return self._balance(balanced)
        except ClaripyBalancerError:
            l.warning("Balance handler for operation %s raised exception.", truism.args[0].op)
            return truism

    @staticmethod
    def _balance_Reverse(truism):
        """For (in)equality, move the Reverse from the left-hand side onto the right-hand side."""
        if truism.op in ["__eq__", "__ne__"]:
            return truism.make_like(truism.op, (truism.args[0].args[0], truism.args[1].reversed))
        else:
            return truism

    @staticmethod
    def _balance___add__(truism):
        """Keep the first addend on the left and subtract the remaining addends from the right-hand side."""
        if len(truism.args) != 2:
            return truism
        new_lhs = truism.args[0].args[0]
        old_rhs = truism.args[1]
        other_adds = truism.args[0].args[1:]
        new_rhs = truism.args[0].make_like("__sub__", (old_rhs,) + other_adds)
        return truism.make_like(truism.op, (new_lhs, new_rhs))

    @staticmethod
    def _balance___sub__(truism):
        """Keep the first operand on the left and add the remaining operands to the right-hand side."""
        if len(truism.args) != 2:
            return truism
        new_lhs = truism.args[0].args[0]
        old_rhs = truism.args[1]
        other_adds = truism.args[0].args[1:]
        new_rhs = truism.args[0].make_like("__add__", (old_rhs,) + other_adds)
        return truism.make_like(truism.op, (new_lhs, new_rhs))

    @staticmethod
    def _balance_ZeroExt(truism):
        """Drop a ZeroExt on the left when the corresponding high bits of the right-hand side are zero."""
        num_zeroes, inner = truism.args[0].args
        other_side = truism.args[1][len(truism.args[1]) - 1 : len(truism.args[1]) - num_zeroes]

        if is_true(other_side == 0):
            # We can safely eliminate this layer of ZeroExt
            new_args = (inner, truism.args[1][len(truism.args[1]) - num_zeroes - 1 : 0])
            return truism.make_like(truism.op, new_args)

        return truism

    @staticmethod
    def _balance_SignExt(truism):
        """Drop a SignExt on the left when the extension bits on both sides are VSA-identical."""
        num_zeroes = truism.args[0].args[0]
        left_side = truism.args[0][len(truism.args[1]) - 1 : len(truism.args[1]) - num_zeroes]
        other_side = truism.args[1][len(truism.args[1]) - 1 : len(truism.args[1]) - num_zeroes]

        # TODO: what if this is a set value, but *not* the same as other_side
        if backends.vsa.identical(left_side, other_side):
            # We can safely eliminate this layer of SignExt
            new_args = (truism.args[0].args[1], truism.args[1][len(truism.args[1]) - num_zeroes - 1 : 0])
            return truism.make_like(truism.op, new_args)

        return truism

    @staticmethod
    def _balance_Extract(truism):
        """
        Remove an Extract on the left-hand side by zero-padding the right-hand side.

        This is done when the bits of the inner value above and/or below the extracted range are known to be zero,
        or when the extract starts at bit 0 and is compared with an unsigned operator against a ``BVV``.
        """
        high, low, inner = truism.args[0].args
        inner_size = len(inner)

        if high < inner_size - 1:
            left_msb = inner[inner_size - 1 : high + 1]
            left_msb_zero = is_true(left_msb == 0)
        else:
            left_msb = None
            left_msb_zero = None

        if low > 0:
            # the bits *below* the extracted range are [low - 1 : 0]; the padding added to the
            # right-hand side must be exactly that wide for the sizes to match `inner`
            left_lsb = inner[low - 1 : 0]
            left_lsb_zero = is_true(left_lsb == 0)
        else:
            left_lsb = None
            left_lsb_zero = None

        if left_msb_zero and left_lsb_zero:
            new_left = inner
            new_right = _all_operations.Concat(BVV(0, len(left_msb)), truism.args[1], BVV(0, len(left_lsb)))
            return truism.make_like(truism.op, (new_left, new_right))
        elif left_msb_zero:
            new_left = inner
            new_right = _all_operations.Concat(BVV(0, len(left_msb)), truism.args[1])
            return truism.make_like(truism.op, (new_left, new_right))
        elif left_lsb_zero:
            new_left = inner
            new_right = _all_operations.Concat(truism.args[1], BVV(0, len(left_lsb)))
            return truism.make_like(truism.op, (new_left, new_right))

        if low == 0 and truism.args[1].op == "BVV" and truism.op not in {"SGE", "SLE", "SGT", "SLT"}:
            # single-valued rhs value with an unsigned operator
            # Eliminate Extract on lhs and zero-extend the value on rhs
            new_left = inner
            new_right = _all_operations.ZeroExt(inner.size() - truism.args[1].size(), truism.args[1])
            return truism.make_like(truism.op, (new_left, new_right))

        return truism

    @staticmethod
    def _balance___and__(truism):
        """Remove a low-bit mask (``x & 0b0..01..1``) when ``x`` is a ZeroExt whose payload fits the mask."""
        if len(truism.args[0].args) != 2:
            return truism
        op0, op1 = truism.args[0].args

        if op1.op == "BVV":
            # if all low bits of right are 1 and all high bits of right are 0, then this is equivalent to Extract()
            v = op1.args[0]
            low_ones = 0
            while v != 0:
                if v & 1 == 0:
                    # a zero bit below a set bit: the mask is not a contiguous run of low ones. abort
                    return truism
                low_ones += 1
                v >>= 1
            if low_ones == 0:
                # this should probably never happen
                new_left = truism.args[0].make_like("BVV", (0, truism.args[0].size()))
                return truism.make_like(truism.op, (new_left, truism.args[1]))

            if op0.op == "ZeroExt" and op0.args[0] + low_ones == op0.size():
                # ZeroExt(56, a) & 0xff == a  if a.size() == 8
                # we can safely remove __and__
                new_left = op0
                return truism.make_like(truism.op, (new_left, truism.args[1]))

        return truism

    @staticmethod
    def _balance_Concat(truism):
        """Strip the most significant Concat operand when it and the matching right-hand bits are both zero."""
        size = len(truism.args[0])
        left_msb = truism.args[0].args[0]
        right_msb = truism.args[1][size - 1 : size - len(left_msb)]

        if is_true(left_msb == 0) and is_true(right_msb == 0):
            # we can cut these guys off!
            remaining_left = _all_operations.Concat(*truism.args[0].args[1:])
            remaining_right = truism.args[1][size - len(left_msb) - 1 : 0]
            return truism.make_like(truism.op, (remaining_left, remaining_right))
        else:
            # TODO: handle non-zero single-valued cases
            return truism

    def _balance___lshift__(self, truism):
        """Remove a left shift by a single concrete amount when the low bits of the right-hand side are zero."""
        lhs = truism.args[0]
        rhs = truism.args[1]
        shift_amount_expr = lhs.args[1]
        expr = lhs.args[0]

        shift_amount_values = self._helper.eval(shift_amount_expr, 2)
        if len(shift_amount_values) != 1:
            return truism
        shift_amount = shift_amount_values[0]

        rhs_lower = _all_operations.Extract(shift_amount - 1, 0, rhs)
        rhs_lower_values = self._helper.eval(rhs_lower, 2)
        if len(rhs_lower_values) == 1 and rhs_lower_values[0] == 0:
            # we can remove the __lshift__
            return truism.make_like(truism.op, (expr, rhs >> shift_amount))

        return truism

    def _balance_If(self, truism):
        """
        Resolve an If on the left-hand side when only one of its branches can satisfy the comparison.

        The branch condition (or its negation) is queued as a new truism.

        :raises ClaripyBalancerUnsatError: if neither branch can satisfy the comparison.
        """
        condition, true_expr, false_expr = truism.args[0].args

        try:
            # dunder names live in the operator module, everything else in claripy's operation table
            op_source = operator if truism.op.startswith("__") else _all_operations
            true_condition = getattr(op_source, truism.op)(true_expr, truism.args[1])
            false_condition = getattr(op_source, truism.op)(false_expr, truism.args[1])
        except ClaripyOperationError:
            # the condition was probably a Not (TODO)
            return truism

        can_true = backends.vsa.has_true(true_condition)
        can_false = backends.vsa.has_true(false_condition)
        must_true = backends.vsa.is_true(true_condition)
        must_false = backends.vsa.is_true(false_condition)

        if can_true and can_false:
            # always satisfiable
            return truism
        elif not (can_true or can_false):
            # neither branch is satisfiable, so the truism cannot hold
            raise ClaripyBalancerUnsatError()
        elif must_true or (can_true and not can_false):
            # it will always be true
            self._queue_truism(condition)
            return truism.make_like(truism.op, (true_expr, truism.args[1]))
        elif must_false or (can_false and not can_true):
            # it will always be false
            self._queue_truism(self._invert_comparison(condition))
            return truism.make_like(truism.op, (false_expr, truism.args[1]))

    #
    # Constraint handlers
    #

    def _handle(self, truism):
        """
        Dispatch a balanced truism to its ``_handle_<op>`` method, if one exists.

        :raises ClaripyBalancerUnsatError: if the truism is false.
        """
        l.debug("Handling %s", truism)

        if is_false(truism):
            raise ClaripyBalancerUnsatError()
        elif self._cardinality(truism.args[0]) == 1:
            # we are down to single-cardinality arguments, so our work is not
            # necessary
            return

        try:
            handler = getattr(self, "_handle_%s" % truism.op)
        except AttributeError:
            l.debug("No handler for operation %s", truism.op)
            return

        handler(truism)

    def _handle_comparison(self, truism):
        """
        Handles all comparisons.

        Records an upper bound (less-than family) or a lower bound (greater-than family) for the left-hand side.

        :raises ClaripyBalancerUnsatError: if the required bound lies outside the representable range.
        """
        is_lt, is_equal, is_unsigned = self.comparison_info[truism.op]

        size = len(truism.args[0])
        int_max = 2**size - 1 if is_unsigned else 2 ** (size - 1) - 1
        # unsigned values cannot go below zero; using the signed minimum here would make the
        # unsat check below unreachable for unsigned comparisons such as `x <u 0`
        int_min = 0 if is_unsigned else -(2 ** (size - 1))

        left_min = self._min(truism.args[0], signed=not is_unsigned)
        left_max = self._max(truism.args[0], signed=not is_unsigned)
        right_min = self._min(truism.args[1], signed=not is_unsigned)
        right_max = self._max(truism.args[1], signed=not is_unsigned)

        # strict comparisons tighten the bound by one
        bound_max = right_max if is_equal else (right_max - 1 if is_lt else right_max + 1)
        bound_min = right_min if is_equal else (right_min - 1 if is_lt else right_min + 1)

        if is_lt and bound_max < int_min:
            # nothing representable is small enough
            raise ClaripyBalancerUnsatError()
        elif not is_lt and bound_min > int_max:
            # nothing representable is large enough
            raise ClaripyBalancerUnsatError()

        if is_lt:
            current_max = min(int_max, left_max, bound_max)
            self._add_upper_bound(truism.args[0], current_max)
        else:
            current_min = max(int_min, left_min, bound_min)
            self._add_lower_bound(truism.args[0], current_min)

    def _handle___eq__(self, truism):
        """Bound the left-hand side (and a multivalued right-hand side) by the range both sides share."""
        lhs, rhs = truism.args
        if rhs.cardinality != 1:
            common = self._same_bound_bv(lhs.intersection(rhs))
            mn, mx = self._range(common)
            self._add_upper_bound(lhs, mx)
            self._add_upper_bound(rhs, mx)
            self._add_lower_bound(lhs, mn)
            self._add_lower_bound(rhs, mn)
        else:
            mn, mx = self._range(rhs)
            self._add_upper_bound(lhs, mx)
            self._add_lower_bound(lhs, mn)

    def _handle___ne__(self, truism):
        """Tighten a bound of the left-hand side when it must differ from 0 or from the maximum value."""
        lhs, rhs = truism.args
        if rhs.cardinality == 1:
            val = self._helper.eval(rhs, 1)[0]
            max_int = vsa.StridedInterval.max_int(len(rhs))

            if val == 0:
                self._add_lower_bound(lhs, val + 1)
            elif val == max_int or val == -1:
                self._add_upper_bound(lhs, max_int - 1)

    def _handle_If(self, truism):
        """Queue the If condition (or its negation) when one of the branches is known to be false."""
        if is_false(truism.args[2]):
            self._queue_truism(truism.args[0])
        elif is_false(truism.args[1]):
            self._queue_truism(self._invert_comparison(truism.args[0]))

    # every ordering comparison shares one handler; comparison_info supplies the per-op flags
    _handle___lt__ = _handle_comparison
    _handle___le__ = _handle_comparison
    _handle___gt__ = _handle_comparison
    _handle___ge__ = _handle_comparison
    _handle_ULT = _handle_comparison
    _handle_ULE = _handle_comparison
    _handle_UGT = _handle_comparison
    _handle_UGE = _handle_comparison
    _handle_SLT = _handle_comparison
    _handle_SLE = _handle_comparison
    _handle_SGT = _handle_comparison
    _handle_SGE = _handle_comparison
def is_true(a):
    """Return the result of the VSA backend's ``is_true`` check on ``a``."""
    vsa_backend = backends.vsa
    return vsa_backend.is_true(a)
def is_false(a):
    """Return the result of the VSA backend's ``is_false`` check on ``a``."""
    vsa_backend = backends.vsa
    return vsa_backend.is_false(a)
from .errors import ClaripyBalancerError, ClaripyBalancerUnsatError, ClaripyOperationError, BackendError from .ast.base import Base from .ast.bool import Bool from .ast.bv import BVV, BVS, BV from . import _all_operations from .backend_manager import backends from . import vsa from .operations import opposites, commutative_operations