import logging
import numbers
import functools
import operator
from functools import reduce
# Module-level logger for this backend (not referenced elsewhere in this file as shown).
l = logging.getLogger("claripy.backends.backend_vsa")
from . import Backend, BackendError
from ..vsa import RegionAnnotation
# NOTE: the VSA domain types (StridedInterval, DiscreteStridedIntervalSet, ValueSet, BoolResult, ...), Base,
# the operation tables and Balancer are imported at the *bottom* of this module, after BackendVSA is defined.
def arg_filter(f):
    """
    Wrap ``f`` so that it rejects a plain Python number as its first positional argument.

    ``bool`` values count as numbers here, because ``bool`` is part of the ``numbers.Number`` hierarchy.

    :param f: Callable taking positional arguments only.
    :return: A wrapper that raises ``BackendError`` when ``args[0]`` is a ``numbers.Number`` and otherwise
             returns ``f(*args)``.
    """

    @functools.wraps(f)
    def _reject_numbers(*args):
        leading = args[0]
        if not isinstance(leading, numbers.Number):
            return f(*args)
        raise BackendError("Unsupported argument type %s" % type(leading))

    return _reject_numbers
[docs]def normalize_arg_order(f):
@functools.wraps(f)
def normalizer(*args):
if len(args) != 2:
raise BackendError("Unsupported arguments number %d" % len(args))
if type(args[0]) not in {
StridedInterval,
DiscreteStridedIntervalSet,
ValueSet,
}: # pylint:disable=unidiomatic-typecheck
if type(args[1]) not in {
StridedInterval,
DiscreteStridedIntervalSet,
ValueSet,
}: # pylint:disable=unidiomatic-typecheck
raise BackendError("Unsupported arguments")
args = [args[1], args[0]]
return f(*args)
return normalizer
def convert_args(f):
    """
    Wrap a backend method ``f(self, ast)`` so that it receives an AST whose arguments are already converted.

    Each element of ``ast.args`` is passed through ``self.convert``, in order. The list of results is handed to
    ``ast.swap_args`` to build the AST that ``f`` actually receives.

    :param f: Method taking ``(self, ast)``.
    :return: The wrapping method, whose return value is whatever ``f`` returns.
    """

    @functools.wraps(f)
    def converter(self, ast):
        converted = [self.convert(operand) for operand in ast.args]
        return f(self, ast.swap_args(converted))

    return converter
class BackendVSA(Backend):
    """
    Claripy backend that evaluates ASTs over value-set-analysis (VSA) abstract domains.

    The backend objects handled here are ``StridedInterval``, ``DiscreteStridedIntervalSet``, ``ValueSet`` and
    ``BoolResult`` instances. Plain Python numbers are passed through unchanged by :meth:`_convert`.
    """

    def __init__(self):
        Backend.__init__(self)

        # Expression-level operations, resolved against methods of this instance (op_class=self).
        self._make_expr_ops(set(expression_set_operations), op_class=self)
        # Raw operations the VSA domains support, resolved against BackendVSA (op_module=BackendVSA).
        self._make_raw_ops(set(backend_operations_vsa_compliant), op_module=BackendVSA)

        self._op_raw["Reverse"] = BackendVSA.Reverse
        self._op_raw["If"] = self.If
        self._op_expr["BVV"] = self.BVV
        self._op_expr["BoolV"] = self.BoolV
        self._op_expr["BVS"] = self.BVS

        # Reducible operators: the _op_* helpers left-fold any number of operands.
        self._op_raw["__add__"] = self._op_add
        self._op_raw["__sub__"] = self._op_sub
        self._op_raw["__mul__"] = self._op_mul
        self._op_raw["__or__"] = self._op_or
        self._op_raw["__xor__"] = self._op_xor
        self._op_raw["__and__"] = self._op_and
        self._op_raw["__mod__"] = self._op_mod

    #
    # n-ary operator helpers: left fold of the Python operator over all operands
    #

    @staticmethod
    def _op_add(*args):
        return reduce(operator.__add__, args)

    @staticmethod
    def _op_sub(*args):
        return reduce(operator.__sub__, args)

    @staticmethod
    def _op_mul(*args):
        return reduce(operator.__mul__, args)

    @staticmethod
    def _op_or(*args):
        return reduce(operator.__or__, args)

    @staticmethod
    def _op_xor(*args):
        return reduce(operator.__xor__, args)

    @staticmethod
    def _op_and(*args):
        return reduce(operator.__and__, args)

    @staticmethod
    def _op_mod(*args):
        return reduce(operator.__mod__, args)

    def convert(self, expr):
        """
        Convert ``expr`` into a VSA backend object.

        A claripy ``Base`` AST is first replaced by its ``ite_excavated`` form. Any other value is handed to
        ``Backend.convert`` unchanged.
        """
        return Backend.convert(self, expr.ite_excavated if isinstance(expr, Base) else expr)

    def _convert(self, a):
        """
        Convert a leaf value into a VSA backend object.

        :param a: A Python ``bool``, a Python number, or an existing VSA object.
        :return: ``TrueResult()``/``FalseResult()`` for a ``bool``. Numbers and VSA objects are returned as-is.
        :raises BackendError: For any other type.
        """
        # bool must be tested first: bool is a subclass of int (and therefore a numbers.Number), so testing
        # numbers.Number first would return True/False raw and leave the bool branch unreachable.
        if isinstance(a, bool):
            return TrueResult() if a else FalseResult()
        if isinstance(a, numbers.Number):
            return a
        if isinstance(a, (StridedInterval, DiscreteStridedIntervalSet, ValueSet, BoolResult)):
            return a
        # Not supported
        raise BackendError()

    def _eval(self, expr, n, extra_constraints=(), solver=None, model_callback=None):
        """
        Return up to ``n`` concrete values of ``expr``.

        Delegates to ``expr.eval(n)`` for strided intervals and value sets. For a ``BoolResult``, returns its
        ``value`` (``n`` is ignored). The constraint/solver arguments are accepted for interface compatibility
        and are not used.

        :raises BackendError: For any other type.
        """
        if isinstance(expr, StridedInterval):
            return expr.eval(n)
        elif isinstance(expr, ValueSet):
            return expr.eval(n)
        elif isinstance(expr, BoolResult):
            return expr.value
        else:
            raise BackendError("Unsupported type %s" % type(expr))

    def _min(self, expr, extra_constraints=(), signed=False, solver=None, model_callback=None):
        """
        Return the (unsigned) lower bound of ``expr``.

        A top strided interval yields 0. ``signed`` is currently ignored.

        :raises BackendError: If ``expr`` is neither a ``StridedInterval`` nor a ``ValueSet``.
        """
        # TODO: signed min
        if isinstance(expr, StridedInterval):
            if expr.is_top:
                # TODO: Return
                return 0
            return expr.min
        elif isinstance(expr, ValueSet):
            return expr.min
        else:
            raise BackendError("Unsupported expr type %s" % type(expr))

    def _max(self, expr, extra_constraints=(), signed=False, solver=None, model_callback=None):
        """
        Return the (unsigned) upper bound of ``expr``.

        A top strided interval yields ``StridedInterval.max_int(expr.bits)``. ``signed`` is currently ignored.

        :raises BackendError: If ``expr`` is neither a ``StridedInterval`` nor a ``ValueSet``.
        """
        # TODO: signed max
        if isinstance(expr, StridedInterval):
            if expr.is_top:
                # TODO:
                return StridedInterval.max_int(expr.bits)
            return expr.max
        elif isinstance(expr, ValueSet):
            return expr.max
        else:
            raise BackendError("Unsupported expr type %s" % type(expr))

    def _solution(self, obj, v, extra_constraints=(), solver=None, model_callback=None):
        """
        Return whether ``v`` may be a value of ``obj``.

        - ``BoolResult``: the two objects' ``value`` collections share an element.
        - ``StridedInterval``: the intersection with ``v`` is non-empty.
        - ``ValueSet``: any of its per-region strided intervals intersects ``v``.

        :raises NotImplementedError: For any other type of ``obj``.
        """
        if isinstance(obj, BoolResult):
            return len(set(v.value) & set(obj.value)) > 0
        if isinstance(obj, StridedInterval):
            return not obj.intersection(v).is_empty
        if isinstance(obj, ValueSet):
            for _, si in obj.items():
                if not si.intersection(v).is_empty:
                    return True
            return False
        raise NotImplementedError(type(obj).__name__)

    def _has_true(self, o, extra_constraints=(), solver=None, model_callback=None):
        """Delegate to ``BoolResult.has_true``."""
        return BoolResult.has_true(o)

    def _has_false(self, o, extra_constraints=(), solver=None, model_callback=None):
        """Delegate to ``BoolResult.has_false``."""
        return BoolResult.has_false(o)

    def _is_true(self, o, extra_constraints=(), solver=None, model_callback=None):
        """Delegate to ``BoolResult.is_true``."""
        return BoolResult.is_true(o)

    def _is_false(self, o, extra_constraints=(), solver=None, model_callback=None):
        """Delegate to ``BoolResult.is_false``."""
        return BoolResult.is_false(o)

    #
    # Backend Operations
    #

    def simplify(self, e):
        """Simplification is not supported by this backend; always raises ``BackendError``."""
        raise BackendError("nope")

    def _identical(self, a, b):
        """Return ``a.identical(b)`` when ``a`` and ``b`` have exactly the same type, else ``False``."""
        if type(a) != type(b):
            return False
        return a.identical(b)

    def _unique(self, obj):  # pylint:disable=unused-argument,no-self-use
        """
        Return ``obj.unique`` for strided intervals and value sets.

        :raises BackendError: For any other type.
        """
        if isinstance(obj, StridedInterval):
            return obj.unique
        elif isinstance(obj, ValueSet):
            return obj.unique
        else:
            raise BackendError("Not supported type of operand %s" % type(obj))

    def _cardinality(self, a):  # pylint:disable=unused-argument,no-self-use
        """Return ``a.cardinality``."""
        return a.cardinality

    def name(self, a):
        """Return the name of a ``StridedInterval``, or ``None`` for any other object."""
        if isinstance(a, StridedInterval):
            return a.name
        else:
            return None

    def apply_annotation(self, bo, annotation):
        """
        Apply an annotation on the backend object.

        :param BackendObject bo: The backend object.
        :param Annotation annotation: The annotation to be applied
        :return: A new BackendObject
        :rtype: BackendObject
        """
        # Currently we only support RegionAnnotation
        if not isinstance(annotation, RegionAnnotation):
            return bo

        if not isinstance(bo, ValueSet):
            # Convert it to a ValueSet first
            # Note that the original value is not kept at all. If you want to convert a StridedInterval to a ValueSet,
            # you gotta do the conversion by calling AST.annotate() from outside.
            bo = ValueSet.empty(bo.bits)

        return bo.apply_annotation(annotation)

    def BVV(self, ast):  # pylint:disable=unused-argument,no-self-use
        """
        Build a singleton strided interval from a ``BVV`` AST (``ast.args == (value, bits)``).

        A ``None`` value produces an empty strided interval of that width.
        """
        if ast.args[0] is None:
            return StridedInterval.empty(ast.args[1])
        else:
            return CreateStridedInterval(bits=ast.args[1], stride=0, lower_bound=ast.args[0], upper_bound=ast.args[0])

    @staticmethod
    def BoolV(ast):  # pylint:disable=unused-argument
        """Return ``TrueResult()`` if ``ast.args[0]`` is truthy, else ``FalseResult()``."""
        return TrueResult() if ast.args[0] else FalseResult()

    @staticmethod
    def And(a, *args):
        """Left fold of ``&`` over ``a`` and the remaining operands."""
        return reduce(operator.__and__, args, a)

    @staticmethod
    def Not(a):
        """Return ``~a``."""
        return ~a

    @staticmethod
    @normalize_arg_order
    def ULT(a, b):
        return a.ULT(b)

    @staticmethod
    @normalize_arg_order
    def ULE(a, b):
        return a.ULE(b)

    @staticmethod
    @normalize_arg_order
    def UGT(a, b):
        return a.UGT(b)

    @staticmethod
    @normalize_arg_order
    def UGE(a, b):
        return a.UGE(b)

    @staticmethod
    @normalize_arg_order
    def SLT(a, b):
        return a.SLT(b)

    @staticmethod
    @normalize_arg_order
    def SLE(a, b):
        return a.SLE(b)

    @staticmethod
    @normalize_arg_order
    def SGT(a, b):
        return a.SGT(b)

    @staticmethod
    @normalize_arg_order
    def SGE(a, b):
        return a.SGE(b)

    @staticmethod
    def BVS(ast):  # pylint:disable=unused-argument
        """
        Build a strided interval from a ``BVS`` AST.

        ``ast.args`` is unpacked as ``(name, min, max, stride, uninitialized, discrete_set, max_card)``, and the
        width comes from ``ast.size()``.
        """
        size = ast.size()
        name, mn, mx, stride, uninitialized, discrete_set, max_card = ast.args
        return CreateStridedInterval(
            name=name,
            bits=size,
            lower_bound=mn,
            upper_bound=mx,
            stride=stride,
            uninitialized=uninitialized,
            discrete_set=discrete_set,
            discrete_set_max_cardinality=max_card,
        )

    def If(self, cond, t, f):
        """
        Over-approximate ``cond ? t : f``.

        Returns ``f`` if ``cond`` cannot be true, ``t`` if it cannot be false, and otherwise ``t.union(f)``.
        """
        if not self.has_true(cond):
            return f
        elif not self.has_false(cond):
            return t
        else:
            return t.union(f)

    # TODO: Implement other operations!

    @staticmethod
    def Or(*args):
        """Union of all operands, folded left to right."""
        first = args[0]
        others = args[1:]

        for o in others:
            first = first.union(o)

        return first

    @staticmethod
    def __rshift__(expr, shift_amount):  # pylint:disable=unexpected-special-method-signature
        return expr.__rshift__(shift_amount)

    @staticmethod
    def LShR(expr, shift_amount):
        return expr.LShR(shift_amount)

    @staticmethod
    def Concat(*args):
        """
        Concatenate the operands left to right via ``.concat``.

        :raises BackendError: If any operand is not exactly a StridedInterval, DiscreteStridedIntervalSet or ValueSet.
        """
        ret = None

        for expr in args:
            if type(expr) not in {
                StridedInterval,
                DiscreteStridedIntervalSet,
                ValueSet,
            }:  # pylint:disable=unidiomatic-typecheck
                raise BackendError("Unsupported expr type %s" % type(expr))

            ret = ret.concat(expr) if ret is not None else expr

        return ret

    @staticmethod
    def SignExt(*args):
        """
        Sign-extend ``args[1]`` by ``args[0]`` bits.

        :raises BackendError: If the operand is not exactly a StridedInterval or DiscreteStridedIntervalSet.
        """
        new_bits = args[0]
        expr = args[1]

        if type(expr) not in {StridedInterval, DiscreteStridedIntervalSet}:  # pylint:disable=unidiomatic-typecheck
            raise BackendError("Unsupported expr type %s" % type(expr))

        return expr.sign_extend(new_bits + expr.bits)

    @staticmethod
    def ZeroExt(*args):
        """
        Zero-extend ``args[1]`` by ``args[0]`` bits.

        :raises BackendError: If the operand is not exactly a StridedInterval or DiscreteStridedIntervalSet.
        """
        new_bits = args[0]
        expr = args[1]

        if type(expr) not in {StridedInterval, DiscreteStridedIntervalSet}:  # pylint:disable=unidiomatic-typecheck
            raise BackendError("Unsupported expr type %s" % type(expr))

        return expr.zero_extend(new_bits + expr.bits)

    @staticmethod
    def Reverse(arg):
        """
        Return ``arg.reverse()``.

        :raises BackendError: If ``arg`` is not exactly a StridedInterval, DiscreteStridedIntervalSet or ValueSet.
        """
        if type(arg) not in {
            StridedInterval,
            DiscreteStridedIntervalSet,
            ValueSet,
        }:  # pylint:disable=unidiomatic-typecheck
            raise BackendError("Unsupported expr type %s" % type(arg))

        return arg.reverse()

    @convert_args
    def union(self, ast):  # pylint:disable=unused-argument,no-self-use
        """
        Union of the two (converted) operands of ``ast``.

        Falls back to the reflected call if the first operand returns ``NotImplemented``.
        """
        if len(ast.args) != 2:
            raise BackendError("Incorrect number of arguments (%d) passed to BackendVSA.union()." % len(ast.args))

        ret = ast.args[0].union(ast.args[1])

        if ret is NotImplemented:
            ret = ast.args[1].union(ast.args[0])

        return ret

    @convert_args
    def intersection(self, ast):  # pylint:disable=unused-argument,no-self-use
        """Intersection of the two (converted) operands of ``ast``, folded left to right."""
        if len(ast.args) != 2:
            raise BackendError(
                "Incorrect number of arguments (%d) passed to BackendVSA.intersection()." % len(ast.args)
            )

        ret = None

        for arg in ast.args:
            if ret is None:
                ret = arg
            else:
                ret = ret.intersection(arg)

        return ret

    @convert_args
    def widen(self, ast):  # pylint:disable=unused-argument,no-self-use
        """
        Widen the first (converted) operand of ``ast`` by the second.

        Falls back to the reflected call if the first operand returns ``NotImplemented``.
        """
        if len(ast.args) != 2:
            raise BackendError("Incorrect number of arguments (%d) passed to BackendVSA.widen()." % len(ast.args))

        ret = ast.args[0].widen(ast.args[1])
        if ret is NotImplemented:
            ret = ast.args[1].widen(ast.args[0])

        return ret

    @staticmethod
    def CreateTopStridedInterval(bits, name=None, uninitialized=False):  # pylint:disable=unused-argument,no-self-use
        """Return ``StridedInterval.top`` of the given width."""
        return StridedInterval.top(bits, name, uninitialized=uninitialized)

    def constraint_to_si(self, expr):
        """Run a ``Balancer`` over ``expr`` and return its ``compat_ret``."""
        return Balancer(self, expr).compat_ret
# These imports are deliberately placed after the definitions above.
# NOTE(review): presumably this breaks an import cycle with ..ast / ..vsa / ..balancer — confirm before moving.
# Late binding works because the decorators and methods above only look these names up at call time.
from ..ast.base import Base
from ..operations import backend_operations_vsa_compliant, expression_set_operations
from ..vsa import (
StridedInterval,
CreateStridedInterval,
DiscreteStridedIntervalSet,
ValueSet,
AbstractLocation,
BoolResult,
TrueResult,
FalseResult,
)
from ..balancer import Balancer
# Expose the strided-interval factory as a static helper on the backend class (attached here because the
# factory is only imported above).
BackendVSA.CreateStridedInterval = staticmethod(CreateStridedInterval)