import functools
import numbers
import itertools
from .strided_interval import StridedInterval
DEFAULT_MAX_CARDINALITY_WITHOUT_COLLAPSING = 256 # We don't collapse until there are more than this many SIs
[docs]def apply_on_each_si(f):
@functools.wraps(f)
def operator(self, o=None):
if o is None:
# This is an unary operator.
new_si_set = set()
for a in self._si_set:
new_si_set.add(getattr(a, f.__name__)())
ret = DiscreteStridedIntervalSet(bits=self.bits, si_set=new_si_set)
return ret.normalize()
if isinstance(o, DiscreteStridedIntervalSet):
# We gotta apply the operation on each single object
new_si_set = set()
for a in self._si_set:
for b in o._si_set:
new_si_set.add(getattr(a, f.__name__)(b))
ret = DiscreteStridedIntervalSet(bits=self.bits, si_set=new_si_set)
return ret.normalize()
elif isinstance(o, (StridedInterval, numbers.Number, BVV)):
new_si_set = set()
for si in self._si_set:
new_si_set.add(getattr(si, f.__name__)(o))
ret = DiscreteStridedIntervalSet(bits=self.bits, si_set=new_si_set)
return ret.normalize()
else:
raise ClaripyVSAOperationError("Unsupported operand type %s" % (type(o)))
return operator
[docs]def convert_operand_to_si(f):
@functools.wraps(f)
def converter(self, o):
if isinstance(o, BVV):
o = o.value
if isinstance(o, numbers.Number):
o = StridedInterval(bits=self.bits, stride=0, lower_bound=o, upper_bound=o)
return f(self, o)
return converter
[docs]def collapse_operand(f):
@functools.wraps(f)
def collapser(self, o):
if isinstance(o, DiscreteStridedIntervalSet):
return f(self, o.collapse())
else:
return f(self, o)
return collapser
dsis_id_ctr = itertools.count()
[docs]class DiscreteStridedIntervalSet(StridedInterval):
"""
A DiscreteStridedIntervalSet represents one or more discrete StridedInterval instances.
"""
[docs] def __init__(self, name=None, bits=0, si_set=None, max_cardinality=None):
if name is None:
name = "DSIS_%d" % next(dsis_id_ctr)
# Initialize the set for strided intervals
if si_set is not None and len(si_set):
self._si_set = si_set
else:
self._si_set = set()
self._max_cardinality = (
DEFAULT_MAX_CARDINALITY_WITHOUT_COLLAPSING if max_cardinality is None else max_cardinality
)
StridedInterval.__init__(self, name=name, bits=bits)
# Update lower_bound and upper_bound
for si in self._si_set:
self._update_bounds(si)
self._update_bits(si)
#
# Properties
#
def __repr__(self):
representatives = ", ".join([i.__repr__() for i in list(self._si_set)[:5]])
if self.number_of_values > 5:
representatives += ", ..."
return "%s<%d>(%d){%s}" % (self._name, self._bits, self.number_of_values, representatives)
@property
def cardinality(self):
"""
This is an over-approximation of the cardinality of this DSIS.
:return:
"""
cardinality = 0
for si in self._si_set:
cardinality += si.cardinality
return cardinality
@property
def number_of_values(self):
return len(self._si_set)
@property
def stride(self):
return self.collapse().stride
#
# Special methods
#
[docs] def should_collapse(self):
return self.cardinality > self._max_cardinality
[docs] def collapse(self):
"""
Collapse into a StridedInterval instance.
:return: A new StridedInterval instance.
"""
if self.cardinality:
r = None
for si in self._si_set:
r = r._union(si) if r is not None else si
return r
else:
# This is an empty StridedInterval...
return StridedInterval.empty(self._bits)
[docs] def normalize(self):
"""
Return the collapsed object if ``should_collapse()`` is True, otherwise return self.
:return: A DiscreteStridedIntervalSet object.
"""
if self.should_collapse():
return self.collapse()
elif self.number_of_values == 1:
return list(self._si_set)[0]
else:
for si in self._si_set:
self._update_bits(si)
return self
[docs] def copy(self):
copied = DiscreteStridedIntervalSet(
bits=self._bits, si_set=self._si_set.copy(), max_cardinality=self._max_cardinality
)
return copied
def __hash__(self):
return id(self) # ...not sure how to do this. these objects are mutable.
#
# Operations
#
# Logical operations
@convert_operand_to_si
@collapse_operand
def __eq__(self, o):
"""
Operation ==
:param o: The other operand.
:return: An instance of BoolResult.
"""
return self.collapse() == o
@convert_operand_to_si
@collapse_operand
def __ne__(self, o):
"""
Operation !=
:param o: The other operand.
:return: An instance of BoolResult.
"""
return self.collapse() != o
@convert_operand_to_si
@collapse_operand
def __gt__(self, o):
"""
Operation >
:param o: The other operand.
:return: An instance of BoolResult.
"""
return self.collapse() > o
@convert_operand_to_si
@collapse_operand
def __le__(self, o):
"""
Operation <=
:param o: The other operand.
:return: An instance of BoolResult.
"""
return self.collapse() <= o
@convert_operand_to_si
@collapse_operand
def __lt__(self, o):
"""
Operation <
:param o: The other operand.
:return: An instance of BoolResult.
"""
return self.collapse() < o
# Bitwise operations
@convert_operand_to_si
@apply_on_each_si
def __and__(self, o):
"""
Operation &
:param o: The other operand.
:return: An instance of DiscreteStridedIntervalSet.
"""
def __rand__(self, o):
return self.__and__(o)
@convert_operand_to_si
@apply_on_each_si
def __or__(self, o):
"""
Operation |
:param o: The other operand.
:return: An instance of DiscreteStridedIntervalSet.
"""
def __ror__(self, o):
return self.__or__(o)
@convert_operand_to_si
@apply_on_each_si
def __xor__(self, o):
"""
Operation ^
:param o: The other operand.
:return: An instance of DiscreteStridedIntervalSet.
"""
def __rxor__(self, o):
return self.__xor__(o)
def __neg__(self):
"""
Operation ~
:return: The negated value.
"""
new_si_set = set()
for si in self._si_set:
new_si_set.add(~si)
r = DiscreteStridedIntervalSet(bits=self._bits, si_set=new_si_set)
return r.normalize()
def __invert__(self):
"""
Operation ~
:return: The negated value.
"""
return self.__neg__()
@apply_on_each_si
def __lshift__(self, o):
"""
Operation <<
:param o: The other operand.
:return: An instance of DiscreteStridedIntervalSet.
"""
@apply_on_each_si
def __rshift__(self, o):
"""
Operation >>
:param o: The other operand.
:return: An instance of DiscreteStridedIntervalSet.
"""
[docs] @apply_on_each_si
def concat(self, b):
"""
Operation concat
:param b: The other operand to concatenate with.
:return: The concatenated value.
"""
# Arithmetic operations
@convert_operand_to_si
@apply_on_each_si
def __add__(self, o):
"""
Operation +
:param o: The other operand.
:return:
"""
def __radd__(self, o):
return self.__add__(o)
@convert_operand_to_si
@apply_on_each_si
def __sub__(self, o):
"""
Operation -
:param o: The other operand.
:return:
"""
def __rsub__(self, o):
return self.__sub__(o)
@convert_operand_to_si
@apply_on_each_si
def __floordiv__(self, o):
"""
Operation /
:param o: The other operand.
:return:
"""
def __truediv__(self, o):
return self.__floordiv__(o) # floats not welcome
def __rfloordiv__(self, o):
return self.__floordiv__(o)
def __rtruediv__(self, o):
return self.__rfloordiv__(o)
@convert_operand_to_si
@apply_on_each_si
def __mod__(self, o):
"""
Operation %
:param o: The other operand.
:return:
"""
def __rmod__(self, o):
return self.__mod__(o)
# Evaluation
[docs] def eval(self, n, signed=False):
"""
:param n:
:param signed:
:return:
"""
# FIXME: "signed" is silently ignored now
ret = set()
for si in self._si_set:
ret |= set(si.eval(n))
if len(ret) >= n:
break
return list(ret)[:n]
# Set operations
[docs] def union(self, b):
if isinstance(b, DiscreteStridedIntervalSet):
return self._union_with_dsis(b)
elif isinstance(b, StridedInterval):
if not b.is_empty:
return self._union_with_si(b)
else:
return self.copy()
elif isinstance(b, ValueSet):
return b.union(self)
else:
raise ClaripyVSAOperationError("Unsupported operand type %s for operation union." % type(b))
[docs] def intersection(self, b):
if isinstance(b, DiscreteStridedIntervalSet):
return self._intersection_with_dsis(b)
elif isinstance(b, StridedInterval):
return self._intersection_with_si(b)
else:
raise ClaripyVSAOperationError("Unsupported operand type %s for operation intersection." % type(b))
# Other operations
[docs] @apply_on_each_si
def reverse(self):
"""
Operation Reverse
:return: None
"""
[docs] @apply_on_each_si
def sign_extend(self, new_length):
"""
Operation SignExt
:param new_length: The length to extend to.
:return: SignExtended value.
"""
[docs] @apply_on_each_si
def zero_extend(self, new_length):
"""
Operation ZeroExt
:param new_length: The length to extend to.
:return: ZeroExtended value.
"""
[docs] @collapse_operand
def widen(self, b):
"""
Widening operator.
:param b: The other operand.
:return: The widened result.
"""
return self.collapse().widen(b)
#
# Private methods
#
def _union_with_si(self, si):
"""
Union with another StridedInterval.
:param si:
:return:
"""
dsis = self.copy()
for si_ in dsis._si_set:
if BoolResult.is_true(si_ == si):
return dsis
dsis._si_set.add(si)
dsis._update_bounds(si)
return dsis.normalize()
def _union_with_dsis(self, dsis):
"""
Union with another DiscreteStridedIntervalSet.
:param dsis:
:return:
"""
copied = self.copy()
for a in dsis._si_set:
copied = copied.union(a)
if isinstance(copied, DiscreteStridedIntervalSet):
copied._update_bounds(dsis)
return copied.normalize()
def _intersection_with_si(self, si):
"""
Intersection with another :class:`StridedInterval`.
:param si: The other operand
:return:
"""
new_si_set = set()
for si_ in self._si_set:
r = si_.intersection(si)
new_si_set.add(r)
if len(new_si_set):
ret = DiscreteStridedIntervalSet(bits=self.bits, si_set=new_si_set)
if ret.should_collapse():
return ret.collapse()
else:
return ret
else:
# There is no intersection between two operands
return StridedInterval.empty(self.bits)
def _intersection_with_dsis(self, dsis):
"""
Intersection with another :class:`DiscreteStridedIntervalSet`.
:param dsis: The other operand.
:return:
"""
new_si_set = set()
for si in dsis._si_set:
r = self._intersection_with_si(si)
if isinstance(r, StridedInterval):
if not r.is_empty:
new_si_set.add(r)
else: # r is a DiscreteStridedIntervalSet
new_si_set |= r._si_set
if len(new_si_set):
ret = DiscreteStridedIntervalSet(bits=self.bits, si_set=new_si_set)
return ret.normalize()
else:
return StridedInterval.empty(self.bits)
def _update_bounds(self, val):
if not isinstance(val, StridedInterval):
raise ClaripyVSAOperationError("Unsupported operand type %s." % type(val))
if val._lower_bound is not None:
if self._lower_bound is None:
self._lower_bound = val.lower_bound
elif val.lower_bound < self._lower_bound:
self._lower_bound = val.lower_bound
if val._upper_bound is not None:
if self._upper_bound is None:
self._upper_bound = val.upper_bound
elif val.upper_bound > self._upper_bound:
self._upper_bound = val.upper_bound
def _update_bits(self, val):
if not isinstance(val, StridedInterval):
raise ClaripyVSAOperationError("Unsupported operand type %s." % type(val))
self._bits = val.bits
from .errors import ClaripyVSAOperationError
from .bool_result import BoolResult
from ..bv import BVV
from .valueset import ValueSet