import functools
import itertools
import logging
import math
import numbers
from functools import reduce
logger = logging.getLogger("claripy.vsa.strided_interval")
from ..backend_object import BackendObject
[docs]def reversed_processor(f):
def processor(self, *args, **kwargs):
if self._reversed:
# Reverse it for real. We have to accept the precision penalty.
reversed_thing = self._reverse()
return f(reversed_thing, *args, **kwargs)
return f(self, *args, **kwargs)
return processor
[docs]def normalize_types(f):
@functools.wraps(f)
def normalizer(self, o):
"""
Convert any object to an object that we can process.
"""
# Special handler for union
if f.__name__ == "union" and isinstance(o, DiscreteStridedIntervalSet):
return o.union(self)
if isinstance(o, ValueSet) or isinstance(o, DiscreteStridedIntervalSet):
# if it's singlevalued, we can convert it to a StridedInterval
if o.cardinality == 1:
o = o.stridedinterval()
else:
# It should be put to o.__radd__(self) when o is a ValueSet
return NotImplemented
if isinstance(o, Base) or isinstance(self, Base):
return NotImplemented
if isinstance(self, BVV):
self = self.value
if isinstance(o, BVV):
o = o.value
if isinstance(o, numbers.Number):
min_bits = self.bits if hasattr(self, "bits") else 64
repr_bits = StridedInterval.min_bits(o)
n_bits = max(repr_bits, min_bits)
si = StridedInterval(bits=n_bits, stride=0, lower_bound=o, upper_bound=o)
if o < 0:
si.upper_bound &= (1 << n_bits) - 1
si.lower_bound &= (1 << n_bits) - 1
mask = (2**n_bits - 1) - (2**repr_bits - 1)
si.lower_bound |= mask
si.upper_bound |= mask
o = si
if isinstance(self, numbers.Number):
min_bits = o.bits if hasattr(o, "bits") else 64
repr_bits = StridedInterval.min_bits(self)
n_bits = max(repr_bits, min_bits)
si = StridedInterval(bits=n_bits, stride=0, lower_bound=self, upper_bound=self)
if self < 0:
si.upper_bound &= (1 << n_bits) - 1
si.lower_bound &= (1 << n_bits) - 1
mask = (2**n_bits - 1) - (2**repr_bits - 1)
si.lower_bound |= mask
si.upper_bound |= mask
self = si
if f.__name__ != "concat":
# Make sure they have the same length
common_bits = max(o.bits, self.bits)
if o.bits < common_bits:
o = o.agnostic_extend(common_bits)
if self.bits < common_bits:
self = self.agnostic_extend(common_bits)
#
# Handling cases where one or both operands are reversed
#
# Assumption: Other than a few operations ("Concat" is one of them), reverse only comes from endianness
# conversion.
#
# if reverse(a) is lossless, reverse(a) op b -> a._reverse() op b
# a op reverse(b) -> reverse(a._reverse() op b)
# if reverse(b) is lossless, reverse(a) op b -> reverse(a op b._reverse())
# a op reverse(b) -> a op b._reverse()
# else:
# Force reverse and bear the loss in precision
def _lossless_reverse(a):
return a.uninitialized or a.is_top or a.is_integer
reverse_back = False
if f.__name__ in {"concat"}:
# TODO: Some optimizations can be applied to concat
if self._reversed:
self = self._reverse()
if o._reversed:
o = o._reverse()
else:
if not self._reversed and not o._reversed:
pass
elif self._reversed and o._reversed:
reverse_back = True
self = self.copy()
self._reversed = False
o = o.copy()
o._reversed = False
else:
# one of the operands is reversed
if _lossless_reverse(self):
self = self._reverse()
if o._reversed:
reverse_back = True
elif _lossless_reverse(o):
o = o._reverse()
if self._reversed:
reverse_back = True
else:
# Force reverse
if self._reversed:
self = self._reverse()
if o._reversed:
o = o._reverse()
ret = f(self, o)
if isinstance(ret, StridedInterval):
if isinstance(self, StridedInterval) and self.uninitialized:
ret.uninitialized = True
if isinstance(o, StridedInterval) and o.uninitialized:
ret.uninitialized = True
if reverse_back and isinstance(ret, StridedInterval):
ret = ret.reverse()
return ret
return normalizer
si_id_ctr = itertools.count()
# Whether DiscreteStridedIntervalSet should be used or not. Sometimes we manually set it to False to allow easy
# implementation of test cases.
allow_dsis = False
[docs]class WarrenMethods:
"""
Methods as suggested in book.
Hackers Delight.
"""
[docs] @staticmethod
def min_or(a, b, c, d, w):
"""
Lower bound of result of ORing 2-intervals.
:param a: Lower bound of first interval
:param b: Upper bound of first interval
:param c: Lower bound of second interval
:param d: Upper bound of second interval
:param w: bit width
:return: Lower bound of ORing 2-intervals
"""
m = 1 << (w - 1)
while m != 0:
if ((~a) & c & m) != 0:
temp = (a | m) & -m
if temp <= b:
a = temp
break
elif (a & (~c) & m) != 0:
temp = (c | m) & -m
if temp <= d:
c = temp
break
m >>= 1
return a | c
[docs] @staticmethod
def max_or(a, b, c, d, w):
"""
Upper bound of result of ORing 2-intervals.
:param a: Lower bound of first interval
:param b: Upper bound of first interval
:param c: Lower bound of second interval
:param d: Upper bound of second interval
:param w: bit width
:return: Upper bound of ORing 2-intervals
"""
m = 1 << (w - 1)
while m != 0:
if (b & d & m) != 0:
temp = (b - m) | (m - 1)
if temp >= a:
b = temp
break
temp = (d - m) | (m - 1)
if temp >= c:
d = temp
break
m >>= 1
return b | d
[docs] @staticmethod
def min_and(a, b, c, d, w):
"""
Lower bound of result of ANDing 2-intervals.
:param a: Lower bound of first interval
:param b: Upper bound of first interval
:param c: Lower bound of second interval
:param d: Upper bound of second interval
:param w: bit width
:return: Lower bound of ANDing 2-intervals
"""
m = 1 << (w - 1)
while m != 0:
if (~a & ~c & m) != 0:
temp = (a | m) & -m
if temp <= b:
a = temp
break
temp = (c | m) & -m
if temp <= d:
c = temp
break
m >>= 1
return a & c
[docs] @staticmethod
def max_and(a, b, c, d, w):
"""
Upper bound of result of ANDing 2-intervals.
:param a: Lower bound of first interval
:param b: Upper bound of first interval
:param c: Lower bound of second interval
:param d: Upper bound of second interval
:param w: bit width
:return: Upper bound of ANDing 2-intervals
"""
m = 1 << (w - 1)
while m != 0:
if ((~d) & b & m) != 0:
temp = (b & ~m) | (m - 1)
if temp >= a:
b = temp
break
elif (d & (~b) & m) != 0:
temp = (d & ~m) | (m - 1)
if temp >= c:
d = temp
break
m >>= 1
return b & d
[docs] @staticmethod
def min_xor(a, b, c, d, w):
"""
Lower bound of result of XORing 2-intervals.
:param a: Lower bound of first interval
:param b: Upper bound of first interval
:param c: Lower bound of second interval
:param d: Upper bound of second interval
:param w: bit width
:return: Lower bound of XORing 2-intervals
"""
m = 1 << (w - 1)
while m != 0:
if ((~a) & c & m) != 0:
temp = (a | m) & -m
if temp <= b:
a = temp
elif (a & (~c) & m) != 0:
temp = (c | m) & -m
if temp <= d:
c = temp
m >>= 1
return a ^ c
[docs] @staticmethod
def max_xor(a, b, c, d, w):
"""
Upper bound of result of XORing 2-intervals.
:param a: Lower bound of first interval
:param b: Upper bound of first interval
:param c: Lower bound of second interval
:param d: Upper bound of second interval
:param w: bit width
:return: Upper bound of XORing 2-intervals
"""
m = 1 << (w - 1)
while m != 0:
if (b & d & m) != 0:
temp = (b - m) | (m - 1)
if temp >= a:
b = temp
else:
temp = (d - m) | (m - 1)
if temp >= c:
d = temp
m >>= 1
return b ^ d
[docs]class StridedInterval(BackendObject):
"""
A Strided Interval is represented in the following form::
<bits> stride[lower_bound, upper_bound]
For more details, please refer to relevant papers like TIE and WYSINWYE.
This implementation is signedness-agostic, please refer to [1] *Signedness-Agnostic Program Analysis: Precise Integer
Bounds for Low-Level Code* by Jorge A. Navas, etc. for more details.
Note that this implementation only takes hint from [1]. Such a work has been improved to be more precise
(and still sound) when dealing with strided intervals.
DO NOT expect to see a 1-to-1 reproduction of [1].
Thanks all corresponding authors for their outstanding works.
"""
[docs] def __init__(
self, name=None, bits=0, stride=None, lower_bound=None, upper_bound=None, uninitialized=False, bottom=False
):
self._name = name
if self._name is None:
self._name = "SI_%d" % next(si_id_ctr)
self._bits = bits
self._stride = stride if stride is not None else 1
self._lower_bound = lower_bound if lower_bound is not None else 0
self._upper_bound = upper_bound if upper_bound is not None else (2**bits - 1)
if lower_bound is not None and not isinstance(lower_bound, numbers.Number):
raise ClaripyVSAError("'lower_bound' must be an int. %s is not supported." % type(lower_bound))
if upper_bound is not None and not isinstance(upper_bound, numbers.Number):
raise ClaripyVSAError("'upper_bound' must be an int. %s is not supported." % type(upper_bound))
self._reversed = False
self._is_bottom = bottom
self.uninitialized = uninitialized
if self._upper_bound is not None and bits == 0:
self._bits = self._min_bits()
if self._upper_bound is None:
self._upper_bound = StridedInterval.max_int(self.bits)
if self._lower_bound is None:
self._lower_bound = StridedInterval.min_int(self.bits)
# For lower bound and upper bound, we always store the unsigned version
self._lower_bound &= 2**bits - 1
self._upper_bound &= 2**bits - 1
self.normalize()
[docs] def copy(self):
si = StridedInterval(
name=self._name,
bits=self.bits,
stride=self.stride,
lower_bound=self.lower_bound,
upper_bound=self.upper_bound,
uninitialized=self.uninitialized,
bottom=self._is_bottom,
)
si._reversed = self._reversed
return si
[docs] def nameless_copy(self):
si = StridedInterval(
name=None,
bits=self.bits,
stride=self.stride,
lower_bound=self.lower_bound,
upper_bound=self.upper_bound,
uninitialized=self.uninitialized,
bottom=self._is_bottom,
)
si._reversed = self._reversed
return si
[docs] def normalize(self):
if self.bits == 8 and self.reversed:
self._reversed = False
if self.is_empty:
return self
if self.lower_bound == self.upper_bound:
self._stride = 0
if self.lower_bound < 0:
self.lower_bound &= 2**self.bits - 1
self._normalize_top()
if self._stride < 0:
raise Exception("Why does this happen?")
return self
[docs] def eval(self, n, signed=False):
"""
Evaluate this StridedInterval to obtain a list of concrete integers.
:param n: Upper bound for the number of concrete integers
:param signed: Treat this StridedInterval as signed or unsigned
:return: A list of at most `n` concrete integers
"""
if self.is_empty:
# no value is available
return []
if self._reversed:
return self._reverse().eval(n, signed=signed)
results = []
if self.stride == 0 and n > 0:
results.append(self.lower_bound)
else:
if signed:
# View it as a signed integer
bounds = self._signed_bounds()
else:
# View it as an unsigned integer
bounds = self._unsigned_bounds()
for lb, ub in bounds:
while len(results) < n and lb <= ub:
results.append(lb)
lb += self.stride # It will not overflow
return results
[docs] def solution(self, b):
"""
Checks whether an integer is solution of the current strided Interval
:param b: integer to check
:return: True if b belongs to the current Strided Interval, False otherwhise
"""
if isinstance(b, numbers.Number):
b = StridedInterval(lower_bound=b, upper_bound=b, stride=0, bits=self.bits)
else:
raise ClaripyOperationError(
'Oops, Strided intervals cannot be passed as "' "parameter to function solution. To implement"
)
if self.intersection(b).is_empty:
return False
return True
#
# Private methods
#
def __hash__(self):
return hash(
(
f"{self.bits:x} {self.lower_bound:x} {self.upper_bound:x} {self.stride:x}",
self._reversed,
self.uninitialized,
)
)
def _normalize_top(self):
if self.lower_bound == self._modular_add(self.upper_bound, 1, self.bits) and self.stride == 1:
# This is a TOP!
# Normalize it
self.lower_bound = 0
self.upper_bound = self.max_int(self.bits)
def _ssplit(self):
"""
Split `self` at the south pole, which is the same as in unsigned arithmetic.
When returning two StridedIntervals (which means a splitting occurred), it is guaranteed that the first
StridedInterval is on the right side of the south pole.
:return: a list of split StridedIntervals, that contains either one or two StridedIntervals
"""
south_pole_right = self.max_int(self.bits) # 111...1
# south_pole_left = 0
# Is `self` straddling the south pole?
if self.upper_bound < self.lower_bound:
# It straddles the south pole!
a_upper_bound = south_pole_right - ((south_pole_right - self.lower_bound) % self.stride)
a = StridedInterval(
bits=self.bits,
stride=self.stride,
lower_bound=self.lower_bound,
upper_bound=a_upper_bound,
uninitialized=self.uninitialized,
)
b_lower_bound = self._modular_add(a_upper_bound, self.stride, self.bits)
b = StridedInterval(
bits=self.bits,
stride=self.stride,
lower_bound=b_lower_bound,
upper_bound=self.upper_bound,
uninitialized=self.uninitialized,
)
return [a, b]
else:
return [self.copy()]
def _nsplit(self):
"""
Split `self` at the north pole, which is the same as in signed arithmetic.
:return: A list of split StridedIntervals
"""
north_pole_left = self.max_int(self.bits - 1) # 01111...1
north_pole_right = 2 ** (self.bits - 1) # 1000...0
# Is `self` straddling the north pole?
straddling = False
if self.upper_bound >= north_pole_right:
if self.lower_bound > self.upper_bound:
# Yes it does!
straddling = True
elif self.lower_bound <= north_pole_left:
straddling = True
else:
if self.lower_bound > self.upper_bound and self.lower_bound <= north_pole_left:
straddling = True
if straddling:
a_upper_bound = north_pole_left - ((north_pole_left - self.lower_bound) % self.stride)
a = StridedInterval(
bits=self.bits,
stride=self.stride,
lower_bound=self.lower_bound,
upper_bound=a_upper_bound,
uninitialized=self.uninitialized,
)
b_lower_bound = a_upper_bound + self.stride
b = StridedInterval(
bits=self.bits,
stride=self.stride,
lower_bound=b_lower_bound,
upper_bound=self.upper_bound,
uninitialized=self.uninitialized,
)
return [a, b]
else:
return [self.copy()]
def _psplit(self):
"""
Split `self` at both north and south poles.
:return: A list of split StridedIntervals
"""
nsplit_list = self._nsplit()
psplit_list = []
for si in nsplit_list:
psplit_list.extend(si._ssplit())
return psplit_list
def _signed_bounds(self):
"""
Get lower bound and upper bound for `self` in signed arithmetic.
:return: a list of (lower_bound, upper_bound) tuples
"""
nsplit = self._nsplit()
if len(nsplit) == 1:
lb = nsplit[0].lower_bound
ub = nsplit[0].upper_bound
lb = self._unsigned_to_signed(lb, self.bits)
ub = self._unsigned_to_signed(ub, self.bits)
return [(lb, ub)]
elif len(nsplit) == 2:
# nsplit[0] is on the left hemisphere, and nsplit[1] is on the right hemisphere
# The left one
lb_1 = nsplit[0].lower_bound
ub_1 = nsplit[0].upper_bound
# The right one
lb_2 = nsplit[1].lower_bound
ub_2 = nsplit[1].upper_bound
# Then convert them to negative numbers
lb_2 = self._unsigned_to_signed(lb_2, self.bits)
ub_2 = self._unsigned_to_signed(ub_2, self.bits)
return [(lb_1, ub_1), (lb_2, ub_2)]
else:
raise Exception("WTF")
def _unsigned_bounds(self):
"""
Get lower bound and upper bound for `self` in unsigned arithmetic.
:return: a list of (lower_bound, upper_bound) tuples.
"""
ssplit = self._ssplit()
if len(ssplit) == 1:
lb = ssplit[0].lower_bound
ub = ssplit[0].upper_bound
return [(lb, ub)]
elif len(ssplit) == 2:
# ssplit[0] is on the left hemisphere, and ssplit[1] is on the right hemisphere
lb_1 = ssplit[0].lower_bound
ub_1 = ssplit[0].upper_bound
lb_2 = ssplit[1].lower_bound
ub_2 = ssplit[1].upper_bound
return [(lb_1, ub_1), (lb_2, ub_2)]
else:
raise Exception("WTF")
def _rshift_logical(self, shift_amount):
"""
Logical shift right with a concrete shift amount
:param int shift_amount: Number of bits to shift right.
:return: The new StridedInterval after right shifting
:rtype: StridedInterval
"""
if self.is_empty:
return self
# If straddling the south pole, we'll have to split it into two, perform logical right shift on them
# individually, then union the result back together for better precision. Note that it's an improvement from
# the original WrappedIntervals paper.
ssplit = self._ssplit()
if len(ssplit) == 1:
l = self.lower_bound >> shift_amount
u = self.upper_bound >> shift_amount
stride = max(self.stride >> shift_amount, 1)
return StridedInterval(
bits=self.bits, lower_bound=l, upper_bound=u, stride=stride, uninitialized=self.uninitialized
)
else:
a = ssplit[0]._rshift_logical(shift_amount)
b = ssplit[1]._rshift_logical(shift_amount)
return a.union(b)
def _rshift_arithmetic(self, shift_amount):
"""
Arithmetic shift right with a concrete shift amount
:param int shift_amount: Number of bits to shift right.
:return: The new StridedInterval after right shifting
:rtype: StridedInterval
"""
if self.is_empty:
return self
# If straddling the north pole, we'll have to split it into two, perform arithmetic right shift on them
# individually, then union the result back together for better precision. Note that it's an improvement from
# the original WrappedIntervals paper.
nsplit = self._nsplit()
if len(nsplit) == 1:
# preserve the highest bit :-)
highest_bit_set = self.lower_bound > StridedInterval.signed_max_int(nsplit[0].bits)
l = self.lower_bound >> shift_amount
u = self.upper_bound >> shift_amount
stride = max(self.stride >> shift_amount, 1)
mask = (2**shift_amount - 1) << (self.bits - shift_amount)
if highest_bit_set:
l = l | mask
u = u | mask
if l == u:
stride = 0
return StridedInterval(
bits=self.bits, lower_bound=l, upper_bound=u, stride=stride, uninitialized=self.uninitialized
)
else:
a = nsplit[0]._rshift_arithmetic(shift_amount)
b = nsplit[1]._rshift_arithmetic(shift_amount)
return a.union(b)
#
# Comparison operations
#
[docs] def identical(self, o):
"""
Used to make exact comparisons between two StridedIntervals. Usually it is only used in test cases.
:param o: The other StridedInterval to compare with.
:return: True if they are exactly same, False otherwise.
"""
return (
self.bits == o.bits
and self.stride == o.stride
and self.lower_bound == o.lower_bound
and self.upper_bound == o.upper_bound
)
[docs] @normalize_types
def SLT(self, o):
"""
Signed less than
:param o: The other operand
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
signed_bounds_1 = self._signed_bounds()
signed_bounds_2 = o._signed_bounds()
ret = []
for lb_1, ub_1 in signed_bounds_1:
for lb_2, ub_2 in signed_bounds_2:
if ub_1 < lb_2:
ret.append(TrueResult())
elif lb_1 >= ub_2:
ret.append(FalseResult())
else:
ret.append(MaybeResult())
if all(r.identical(TrueResult()) for r in ret):
return TrueResult()
elif all(r.identical(FalseResult()) for r in ret):
return FalseResult()
else:
return MaybeResult()
[docs] @normalize_types
def SLE(self, o):
"""
Signed less than or equal to.
:param o: The other operand.
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
signed_bounds_1 = self._signed_bounds()
signed_bounds_2 = o._signed_bounds()
ret = []
for lb_1, ub_1 in signed_bounds_1:
for lb_2, ub_2 in signed_bounds_2:
if ub_1 <= lb_2:
ret.append(TrueResult())
elif lb_1 > ub_2:
ret.append(FalseResult())
else:
ret.append(MaybeResult())
if all(r.identical(TrueResult()) for r in ret):
return TrueResult()
elif all(r.identical(FalseResult()) for r in ret):
return FalseResult()
else:
return MaybeResult()
[docs] @normalize_types
def SGT(self, o):
"""
Signed greater than.
:param o: The other operand
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
signed_bounds_1 = self._signed_bounds()
signed_bounds_2 = o._signed_bounds()
ret = []
for lb_1, ub_1 in signed_bounds_1:
for lb_2, ub_2 in signed_bounds_2:
if lb_1 > ub_2:
ret.append(TrueResult())
elif ub_1 <= lb_2:
ret.append(FalseResult())
else:
ret.append(MaybeResult())
if all(r.identical(TrueResult()) for r in ret):
return TrueResult()
elif all(r.identical(FalseResult()) for r in ret):
return FalseResult()
else:
return MaybeResult()
[docs] @normalize_types
def SGE(self, o):
"""
Signed greater than or equal to.
:param o: The other operand
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
signed_bounds_1 = self._signed_bounds()
signed_bounds_2 = o._signed_bounds()
ret = []
for lb_1, ub_1 in signed_bounds_1:
for lb_2, ub_2 in signed_bounds_2:
if lb_1 >= ub_2:
ret.append(TrueResult())
elif ub_1 < lb_2:
ret.append(FalseResult())
else:
ret.append(MaybeResult())
if all(r.identical(TrueResult()) for r in ret):
return TrueResult()
elif all(r.identical(FalseResult()) for r in ret):
return FalseResult()
else:
return MaybeResult()
[docs] @normalize_types
def ULT(self, o):
"""
Unsigned less than.
:param o: The other operand
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
unsigned_bounds_1 = self._unsigned_bounds()
unsigned_bounds_2 = o._unsigned_bounds()
ret = []
for lb_1, ub_1 in unsigned_bounds_1:
for lb_2, ub_2 in unsigned_bounds_2:
if ub_1 < lb_2:
ret.append(TrueResult())
elif lb_1 >= ub_2:
ret.append(FalseResult())
else:
ret.append(MaybeResult())
if all(r.identical(TrueResult()) for r in ret):
return TrueResult()
elif all(r.identical(FalseResult()) for r in ret):
return FalseResult()
else:
return MaybeResult()
[docs] @normalize_types
def ULE(self, o):
"""
Unsigned less than or equal to.
:param o: The other operand
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
unsigned_bounds_1 = self._unsigned_bounds()
unsigned_bounds_2 = o._unsigned_bounds()
ret = []
for lb_1, ub_1 in unsigned_bounds_1:
for lb_2, ub_2 in unsigned_bounds_2:
if ub_1 <= lb_2:
ret.append(TrueResult())
elif lb_1 > ub_2:
ret.append(FalseResult())
else:
ret.append(MaybeResult())
if all(r.identical(TrueResult()) for r in ret):
return TrueResult()
elif all(r.identical(FalseResult()) for r in ret):
return FalseResult()
else:
return MaybeResult()
[docs] @normalize_types
def UGT(self, o):
"""
Signed greater than.
:param o: The other operand
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
unsigned_bounds_1 = self._unsigned_bounds()
unsigned_bounds_2 = o._unsigned_bounds()
ret = []
for lb_1, ub_1 in unsigned_bounds_1:
for lb_2, ub_2 in unsigned_bounds_2:
if lb_1 > ub_2:
ret.append(TrueResult())
elif ub_1 <= lb_2:
ret.append(FalseResult())
else:
ret.append(MaybeResult())
if all(r.identical(TrueResult()) for r in ret):
return TrueResult()
elif all(r.identical(FalseResult()) for r in ret):
return FalseResult()
else:
return MaybeResult()
[docs] @normalize_types
def UGE(self, o):
"""
Unsigned greater than or equal to.
:param o: The other operand
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
unsigned_bounds_1 = self._unsigned_bounds()
unsigned_bounds_2 = o._unsigned_bounds()
ret = []
for lb_1, ub_1 in unsigned_bounds_1:
for lb_2, ub_2 in unsigned_bounds_2:
if lb_1 >= ub_2:
ret.append(TrueResult())
elif ub_1 < lb_2:
ret.append(FalseResult())
else:
ret.append(MaybeResult())
if all(r.identical(TrueResult()) for r in ret):
return TrueResult()
elif all(r.identical(FalseResult()) for r in ret):
return FalseResult()
else:
return MaybeResult()
[docs] @normalize_types
def eq(self, o):
"""
Equal
:param o: The ohter operand
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
if self.is_integer and o.is_integer:
# Two integers
if self.lower_bound == o.lower_bound:
# They are equal
return TrueResult()
else:
# They are not equal
return FalseResult()
else:
if self.name == o.name:
return TrueResult() # They are the same guy
si_intersection = self.intersection(o)
if si_intersection.is_empty:
return FalseResult()
else:
return MaybeResult()
#
# Overriding default operators in Python
#
def __len__(self):
"""
Get the length in bits of this variable.
:return:
"""
return self._bits
def __eq__(self, o):
return self.eq(o)
def __ne__(self, o):
return ~(self.eq(o))
def __gt__(self, other):
"""
Unsigned greater than
:param other: The other operand
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
return self.UGT(other)
def __ge__(self, other):
"""
Unsigned greater than or equal to
:param other: The other operand
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
return self.UGE(other)
def __lt__(self, other):
"""
Unsigned less than
:param other: The other operand
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
return self.ULT(other)
def __le__(self, other):
"""
Unsigned less than or equal to
:param other: The other operand
:return: TrueResult(), FalseResult(), or MaybeResult()
"""
return self.ULE(other)
def __add__(self, o):
return self.add(o)
def __radd__(self, o):
return self.add(o)
def __sub__(self, o):
return self.sub(o)
def __rsub__(self, o):
return StridedInterval(bits=self.bits, stride=0, lower_bound=o, upper_bound=o).sub(self)
@normalize_types
def __mul__(self, o):
return self.mul(o)
@normalize_types
def __mod__(self, o):
# TODO: Make a better approximatiom
# FIXME: this is the implementation of the unsigned modulo
# implement also the signed one.
if o.is_integer and o.lower_bound == 0:
return StridedInterval.empty(o.bits)
if self.is_integer and o.is_integer:
r = self.lower_bound % o.lower_bound
si = StridedInterval(bits=self.bits, stride=0, lower_bound=r, upper_bound=r)
return si
all_resulting_intervals = []
for s in self._ssplit():
for t in o._ssplit():
card = s.udiv(t).cardinality
if card == 1:
tmp = s.sub(s.udiv(t)).mul(t)
else:
tmp = StridedInterval(bits=self.bits, stride=1, lower_bound=0, upper_bound=o.upper_bound - 1)
all_resulting_intervals.append(tmp)
return StridedInterval.least_upper_bound(*all_resulting_intervals).normalize()
@normalize_types
def __floordiv__(self, o):
"""
Unsigned division
:param o: The divisor
:return: The quotient (self / o)
"""
return self.udiv(o)
def __truediv__(self, other):
return self // other # decline to involve floating point numbers at ALL
def __neg__(self):
return self.bitwise_not()
def __invert__(self):
return self.bitwise_not()
@normalize_types
def __or__(self, other):
return self.bitwise_or(other)
@normalize_types
def __and__(self, other):
return self.bitwise_and(other)
def __rand__(self, other):
return self.__and__(other)
@normalize_types
def __xor__(self, other):
return self.bitwise_xor(other)
def __rxor__(self, other):
return self.__xor__(other)
def __lshift__(self, other):
return self.lshift(other)
def __rshift__(self, shift_amount):
"""
Arithmetic shift right.
:param StridedInterval shift_amount: Number of bits to shift right.
:return: The shifted StridedInterval object
:rtype: StridedInterval
"""
return self.rshift_arithmetic(shift_amount)
def __repr__(self):
if self.is_empty:
s = "<%d>[EmptySI]" % (self._bits)
else:
lower_bound = self._lower_bound if type(self._lower_bound) == str else "%#x" % self._lower_bound
upper_bound = self._upper_bound if type(self._upper_bound) == str else "%#x" % self._upper_bound
s = "<%d>0x%x[%s, %s]%s" % (
self._bits,
self._stride,
lower_bound,
upper_bound,
"R" if self._reversed else "",
)
if self.uninitialized:
s += "(uninit)"
return s
#
# Other operations
#
[docs] def LShR(self, shift_amount):
"""
Logical shift right.
:param StridedInterval shift_amount: The amount of shifting
:return: The shifted StridedInterval object
:rtype: StridedInterval
"""
return self.rshift_logical(shift_amount)
#
# Properties
#
@property
def name(self):
return self._name
@property
def reversed(self):
return self._reversed
@property
def size(self):
logger.warning("StridedInterval.size will be deprecated soon. Please use StridedInterval.cardinality instead.")
return self.cardinality
@property
def cardinality(self):
if self.is_bottom:
return 0
elif self.is_integer:
return 1
else:
return (self._modular_sub(self._upper_bound, self._lower_bound, self.bits) + self._stride) // self._stride
@property
def complement(self):
"""
Return the complement of the interval
Refer section 3.1 augmented for managing strides
:return:
"""
# case 1
if self.is_empty:
return StridedInterval.top(self.bits)
# case 2
if self.is_top:
return StridedInterval.empty(self.bits)
# case 3
y_plus_1 = StridedInterval._modular_add(self.upper_bound, 1, self.bits)
x_minus_1 = StridedInterval._modular_sub(self.lower_bound, 1, self.bits)
# the new stride has to be the GCD between the old stride and the distance
# between the new lower bound and the new upper bound. This assure that in
# the new interval the boundaries are valid solution when the SI is
# evaluated.
dist = StridedInterval._wrapped_cardinality(y_plus_1, x_minus_1, self.bits) - 1
# the new SI is an integer
if dist < 0:
new_stride = 0
elif self._stride == 0:
new_stride = 1
else:
new_stride = math.gcd(self._stride, dist)
return StridedInterval(
lower_bound=y_plus_1,
upper_bound=x_minus_1,
bits=self.bits,
stride=new_stride,
uninitialized=self.uninitialized,
)
@property
def lower_bound(self):
return self._lower_bound
@lower_bound.setter
def lower_bound(self, value):
self._lower_bound = value
@property
def upper_bound(self):
return self._upper_bound
@upper_bound.setter
def upper_bound(self, value):
self._upper_bound = value
@property
def bits(self):
return self._bits
@property
def stride(self):
return self._stride
@stride.setter
def stride(self, value):
self._stride = value
@property
@reversed_processor
def max(self):
"""
Treat this StridedInterval as a set of unsigned numbers, and return the greatest one
:return: the greatest number in this StridedInterval when evaluated as unsigned, or None if empty
"""
if not self.is_empty:
splitted = self._ssplit()
return splitted[0].upper_bound
else:
# It is empty!
return None
@property
@reversed_processor
def min(self):
"""
Treat this StridedInterval as a set of unsigned numbers, and return the smallest one
:return: the smallest number in this StridedInterval when evaluated as unsigned, or None if empty
"""
if not self.is_empty:
splitted = self._ssplit()
return splitted[-1].lower_bound
else:
# It is empty
return None
@property
def unique(self):
return self.lower_bound is not None and self.lower_bound == self.upper_bound
def _min_bits(self):
v = self._upper_bound
assert v >= 0
return StridedInterval.min_bits(v)
@property
def is_empty(self):
"""
The same as is_bottom
:return: True/False
"""
return self.is_bottom
@property
def is_top(self):
"""
If this is a TOP value.
:return: True if this is a TOP
"""
return self.stride == 1 and self.lower_bound == self._modular_add(self.upper_bound, 1, self.bits)
@property
def is_bottom(self):
"""
Whether this StridedInterval is a BOTTOM, in other words, describes an empty set of integers.
:return: True/False
"""
return self._is_bottom
@property
def is_integer(self):
"""
If this is an integer, i.e. self.lower_bound == self.upper_bound.
:return: True if this is an integer, False otherwise
"""
return self.lower_bound == self.upper_bound
@property
def is_interval(self):
return not self.is_integer
@property
def n_values(self):
return (StridedInterval._wrapped_cardinality(self.lower_bound, self.upper_bound, self.bits) // self.stride) + 1
#
# Modular arithmetic
#
@staticmethod
def _modular_add(a, b, bits):
return (a + b) % (2**bits)
@staticmethod
def _modular_sub(a, b, bits):
return (a - b) % (2**bits)
@staticmethod
def _modular_mul(a, b, bits):
return (a * b) % (2**bits)
#
# Helper methods
#
[docs] @staticmethod
def lcm(a, b):
"""
Get the least common multiple.
:param a: The first operand (integer)
:param b: The second operand (integer)
:return: Their LCM
"""
return a * b // math.gcd(a, b)
[docs] @staticmethod
def gcd(a, b):
"""
Get the greatest common divisor.
:param a: The first operand (integer)
:param b: The second operand (integer)
:return: Their GCD
"""
return math.gcd(a, b)
[docs] @staticmethod
def highbit(k):
return 1 << (k - 1)
[docs] @staticmethod
def min_bits(val, max_bits=None):
if val == 0:
return 1
elif val < 0:
if max_bits is None:
return int(math.log(-val, 2) + 1) + 1
else:
assert isinstance(max_bits, int)
return int(math.log((((1 << max_bits) - 1) & ~(-val)) + 1, 2) + 1)
else:
# FIXME: Support other bits
# Here we assume the maximum val is 64 bits
# Special case to deal with the floating-point imprecision
if val > 0xFFFFFFFFFFFE0000 and val <= 0x10000000000000000:
return 64
return int(math.log(val, 2) + 1)
[docs] @staticmethod
def max_int(k):
return StridedInterval.highbit(k + 1) - 1
[docs] @staticmethod
def min_int(k):
return -StridedInterval.highbit(k)
[docs] @staticmethod
def signed_max_int(k):
return 2 ** (k - 1) - 1
[docs] @staticmethod
def signed_min_int(k):
return -(2 ** (k - 1))
@staticmethod
def _to_negative(a, bits):
return -((1 << bits) - a)
[docs] @staticmethod
def upper(bits, i, stride):
"""
:return:
"""
if stride >= 1:
offset = i % stride
max = StridedInterval.max_int(bits) # pylint:disable=redefined-builtin
max_offset = max % stride
if max_offset >= offset:
o = max - (max_offset - offset)
else:
o = max - ((max_offset + stride) - offset)
return o
else:
return StridedInterval.max_int(bits)
[docs] @staticmethod
def lower(bits, i, stride):
"""
:return:
"""
if stride >= 1:
offset = i % stride
min = StridedInterval.min_int(bits) # pylint:disable=redefined-builtin
min_offset = min % stride
if offset >= min_offset:
o = min + (offset - min_offset)
else:
o = min + ((offset + stride) - min_offset)
return o
else:
return StridedInterval.min_int(bits)
@staticmethod
def _gap(src_interval, tar_interval):
"""
Refer section 3.1; gap function.
:param src_interval: first argument or interval 1
:param tar_interval: second argument or interval 2
:return: Interval representing gap between two intervals
"""
assert src_interval.bits == tar_interval.bits, "Number of bits should be same for operands"
# use the same variable names as in paper
s = src_interval
t = tar_interval
(_, b) = (s.lower_bound, s.upper_bound)
(c, _) = (t.lower_bound, t.upper_bound)
w = s.bits
# case 1
if (not t._surrounds_member(b)) and (not s._surrounds_member(c)):
# FIXME: maybe we can do better here and to not fix the stride to 1
# FIXME: found the first common integer for more precision
return StridedInterval(lower_bound=c, upper_bound=b, bits=w, stride=1).complement
# otherwise
return StridedInterval.empty(w)
[docs] @staticmethod
def top(bits, name=None, uninitialized=False):
"""
Get a TOP StridedInterval.
:return:
"""
return StridedInterval(
name=name,
bits=bits,
stride=1,
lower_bound=0,
upper_bound=StridedInterval.max_int(bits),
uninitialized=uninitialized,
)
[docs] @staticmethod
def empty(bits):
return StridedInterval(bits=bits, bottom=True)
@staticmethod
def _wrapped_cardinality(x, y, bits):
"""
Return the cardinality for a set of number (| x, y |) on the wrapped-interval domain.
:param x: The first operand (an integer)
:param y: The second operand (an integer)
:return: The cardinality
"""
if x == ((y + 1) % (2**bits)):
return 2**bits
else:
return ((y - x) + 1) & (2**bits - 1)
@staticmethod
def _is_msb_zero(v, bits):
"""
Checks if the most significant bit is zero (i.e. is the integer positive under signed arithmetic).
:param v: The integer to check with
:param bits: Bits of the integer
:return: True or False
"""
return (v & (2**bits - 1)) & (2 ** (bits - 1)) == 0
@staticmethod
def _is_msb_one(v, bits):
"""
Checks if the most significant bit is one (i.e. is the integer negative under signed arithmetic).
:param v: The integer to check with
:param bits: Bits of the integer
:return: True or False
"""
return not StridedInterval._is_msb_zero(v, bits)
@staticmethod
def _get_msb(v, bits):
"""
Get the MSB (most significant bit).
:param v: The integer
:param bits: Bits of the integer
:return: the MSB
"""
if StridedInterval._is_msb_zero(v, bits):
return 0
return 1
@staticmethod
def _unsigned_to_signed(v, bits):
"""
Convert an unsigned integer to a signed integer.
:param v: The unsigned integer
:param bits: How many bits this integer should be
:return: The converted signed integer
"""
if StridedInterval._is_msb_zero(v, bits):
return v
else:
return -(2**bits - v)
@staticmethod
def _wrapped_overflow_add(a, b):
"""
Determines if an overflow happens during the addition of `a` and `b`.
:param a: The first operand (StridedInterval)
:param b: The other operand (StridedInterval)
:return: True if overflows, False otherwise
"""
if a.is_integer and a.lower_bound == 0:
# Special case: if `a` or `b` is a zero
card_self = 0
else:
card_self = StridedInterval._wrapped_cardinality(a.lower_bound, a.upper_bound, a.bits)
if b.is_integer and b.lower_bound == 0:
# Special case: if `a` or `b` is a zero
card_b = 0
else:
card_b = StridedInterval._wrapped_cardinality(b.lower_bound, b.upper_bound, b.bits)
return (card_self + card_b) > (StridedInterval.max_int(a.bits) + 1)
@staticmethod
def _wrapped_overflow_sub(a, b):
"""
Determines if an overflow happens during the subtraction of `a` and `b`.
:param a: The first operand (StridedInterval)
:param b: The other operand (StridedInterval)
:return: True if overflows, False otherwise
"""
return StridedInterval._wrapped_overflow_add(a, b)
@staticmethod
def _wrapped_unsigned_mul(a, b):
"""
Perform wrapped unsigned multiplication on two StridedIntervals.
:param a: The first operand (StridedInterval)
:param b: The second operand (StridedInterval)
:return: The multiplication result
"""
if a.bits != b.bits:
logger.warning("Signed mul: two parameters have different bit length")
bits = max(a.bits, b.bits)
lb = a.lower_bound * b.lower_bound
ub = a.upper_bound * b.upper_bound
uninit_flag = a.uninitialized | b.uninitialized
if (ub - lb) < (2**bits):
if b.is_integer:
# Multiplication with an integer, and it does not overflow!
stride = abs(a.stride * b.lower_bound)
elif a.is_integer:
stride = abs(a.lower_bound * b.stride)
else:
stride = math.gcd(a.stride, b.stride)
return StridedInterval(bits=bits, stride=stride, lower_bound=lb, upper_bound=ub, uninitialized=uninit_flag)
else:
# Overflow occurred
return StridedInterval.top(bits, uninitialized=False)
@staticmethod
def _wrapped_signed_mul(a, b):
"""
Perform wrapped signed multiplication on two StridedIntervals.
:param a: The first operand (StridedInterval)
:param b: The second operand (StridedInterval)
:return: The product
"""
# NOTE: interval here should never straddle poles
# FIXME: add assert to be sure of it!
if a.bits != b.bits:
logger.warning("Signed mul: two parameters have different bit length")
bits = max(a.bits, b.bits)
# shorter SI
a_lb_positive = StridedInterval._is_msb_zero(a.lower_bound, bits)
a_ub_positive = StridedInterval._is_msb_zero(a.upper_bound, bits)
b_lb_positive = StridedInterval._is_msb_zero(b.lower_bound, bits)
b_ub_positive = StridedInterval._is_msb_zero(b.upper_bound, bits)
uninit_flag = a.uninitialized | b.uninitialized
if b.is_integer:
if b_lb_positive:
stride = abs(a.stride * b.lower_bound)
else:
# if the number is negative we have to get its value first
stride = abs(a.stride * StridedInterval._unsigned_to_signed(b.lower_bound, bits))
elif a.is_integer:
if a_lb_positive:
stride = abs(b.stride * a.lower_bound)
else:
# if the number is negative we have to get its value first:
stride = abs(b.stride * StridedInterval._unsigned_to_signed(a.lower_bound, bits))
else:
stride = math.gcd(a.stride, b.stride)
if a_lb_positive and a_ub_positive and b_lb_positive and b_ub_positive:
# [2, 5] * [10, 20] = [20, 100]
lb = a.lower_bound * b.lower_bound
ub = a.upper_bound * b.upper_bound
if ub - lb < (2**bits):
return StridedInterval(
bits=bits, stride=stride, lower_bound=lb, upper_bound=ub, uninitialized=uninit_flag
)
else:
return StridedInterval.top(bits, uninitialized=uninit_flag)
elif not a_lb_positive and not a_ub_positive and not b_lb_positive and not b_ub_positive:
# [-5, -2] * [-20, -10] = [20, 100]
lb = StridedInterval._unsigned_to_signed(a.upper_bound, bits) * StridedInterval._unsigned_to_signed(
b.upper_bound, bits
)
ub = StridedInterval._unsigned_to_signed(a.lower_bound, bits) * StridedInterval._unsigned_to_signed(
b.lower_bound, bits
)
if ub - lb < (2**bits):
return StridedInterval(
bits=bits, stride=stride, lower_bound=lb, upper_bound=ub, uninitialized=uninit_flag
)
else:
return StridedInterval.top(bits, uninitialized=uninit_flag)
elif not a_lb_positive and not a_ub_positive and b_lb_positive and b_ub_positive:
# [-10, -2] * [2, 5] = [-50, -4]
lb = StridedInterval._unsigned_to_signed(a.lower_bound, bits) * b.upper_bound
ub = StridedInterval._unsigned_to_signed(a.upper_bound, bits) * b.lower_bound
# since the intervals do not straddle the poles, ub is greater than lb
if ub - lb < (2**bits):
lb &= 2**bits - 1
ub &= 2**bits - 1
return StridedInterval(
bits=bits, stride=stride, lower_bound=lb, upper_bound=ub, uninitialized=uninit_flag
)
else:
return StridedInterval.top(bits, uninitialized=uninit_flag)
elif a_lb_positive and a_ub_positive and not b_lb_positive and not b_ub_positive:
# [2, 10] * [-5, -2] = [-50, -4]
lb = a.upper_bound * StridedInterval._unsigned_to_signed(b.lower_bound, bits)
ub = a.lower_bound * StridedInterval._unsigned_to_signed(b.upper_bound, bits)
# since the intervals do not straddle the poles, ub is greater than lb
if ub - lb < (2**bits):
lb &= 2**bits - 1
ub &= 2**bits - 1
return StridedInterval(
bits=bits, stride=stride, lower_bound=lb, upper_bound=ub, uninitialized=uninit_flag
)
else:
return StridedInterval.top(bits, uninitialized=uninit_flag)
else:
raise Exception(f"We shouldn't see this case: {a} * {b}")
@staticmethod
def _wrapped_unsigned_div(a, b):
"""
Perform wrapped unsigned division on two StridedIntervals.
:param a: The dividend (StridedInterval)
:param b: The divisor (StridedInterval)
:return: The quotient
"""
bits = max(a.bits, b.bits)
divisor_lb, divisor_ub = b.lower_bound, b.upper_bound
uninit_flag = a.uninitialized | b.uninitialized
# Make sure divisor_lb and divisor_ub is not 0
if divisor_lb == 0:
# Can we increment it?
if divisor_ub == 0:
# We can't :-(
return StridedInterval.empty(bits)
else:
divisor_lb += 1
# If divisor_ub is 0, decrement it to get last but one element
if divisor_ub == 0:
divisor_ub = (divisor_ub - 1) & (2**bits - 1)
lb = a.lower_bound // divisor_ub
ub = a.upper_bound // divisor_lb
# TODO: Can we make a more precise estimate of the stride?
stride = 1
return StridedInterval(bits=bits, stride=stride, lower_bound=lb, upper_bound=ub, uninitialized=uninit_flag)
@staticmethod
def _wrapped_signed_div(a, b):
"""
Perform wrapped unsigned division on two StridedIntervals.
:param a: The dividend (StridedInterval)
:param b: The divisor (StridedInterval)
:return: The quotient
"""
bits = max(a.bits, b.bits)
# Make sure the divisor is not 0
divisor_lb = b.lower_bound
divisor_ub = b.upper_bound
uninit_flag = a.uninitialized | b.uninitialized
if divisor_lb == 0:
# Try to increment it
if divisor_ub == 0:
return StridedInterval.empty(bits)
else:
divisor_lb = 1
# If divisor_ub is 0, decrement it to get last but one element
if divisor_ub == 0:
divisor_ub = (divisor_ub - 1) & (2**bits - 1)
dividend_positive = StridedInterval._is_msb_zero(a.lower_bound, bits)
divisor_positive = StridedInterval._is_msb_zero(b.lower_bound, bits)
# TODO: Can we make a more precise estimate of the stride?
stride = 1
if dividend_positive and divisor_positive:
# They are all positive numbers!
lb = a.lower_bound // divisor_ub
ub = a.upper_bound // divisor_lb
elif dividend_positive and not divisor_positive:
# + / -
lb = a.upper_bound // StridedInterval._unsigned_to_signed(divisor_ub, bits)
ub = a.lower_bound // StridedInterval._unsigned_to_signed(divisor_lb, bits)
elif not dividend_positive and divisor_positive:
# - / +
lb = StridedInterval._unsigned_to_signed(a.lower_bound, bits) // divisor_lb
ub = StridedInterval._unsigned_to_signed(a.upper_bound, bits) // divisor_ub
else:
# - / -
lb = StridedInterval._unsigned_to_signed(a.upper_bound, bits) // StridedInterval._unsigned_to_signed(
b.lower_bound, bits
)
ub = StridedInterval._unsigned_to_signed(a.lower_bound, bits) // StridedInterval._unsigned_to_signed(
b.upper_bound, bits
)
return StridedInterval(bits=bits, stride=stride, lower_bound=lb, upper_bound=ub, uninitialized=uninit_flag)
#
# Membership testing and poset ordering
#
@staticmethod
def _lex_lte(x, y, bits):
"""
Lexicographical LTE comparison
:param x: The first operand (integer)
:param y: The second operand (integer)
:param bits: bit-width of the operands
:return: True or False
"""
return (x & (2**bits - 1)) <= (y & (2**bits - 1))
@staticmethod
def _lex_lt(x, y, bits):
"""
Lexicographical LT comparison
:param x: The first operand (integer)
:param y: The second operand (integer)
:param bits: bit-width of the operands
:return: True or False
"""
return (x & (2**bits - 1)) < (y & (2**bits - 1))
def _surrounds_member(self, v):
s = self
return self._lex_lte(v - s.lower_bound, s.upper_bound - s.lower_bound, s.bits)
def _is_surrounded(self, b):
"""
Perform a wrapped LTE comparison only considering the SI bounds
:param a: The first operand
:param b: The second operand
:return: True if a <= b, False otherwise
"""
a = self
if a.is_empty:
return True
if a.is_top and b.is_top:
return True
elif a.is_top:
return False
elif b.is_top:
return True
if b._surrounds_member(a.lower_bound) and b._surrounds_member(a.upper_bound):
if (
(b.lower_bound == a.lower_bound and b.upper_bound == a.upper_bound)
or not a._surrounds_member(b.lower_bound)
or not a._surrounds_member(b.upper_bound)
):
return True
return False
#
# Arithmetic operations
#
[docs] @reversed_processor
def neg(self):
"""
Unary operation: neg
:return: 0 - self
"""
si = StridedInterval(bits=self.bits, stride=0, lower_bound=0, upper_bound=0).sub(self)
si.uninitialized = self.uninitialized
return si
[docs] @normalize_types
def add(self, b):
"""
Binary operation: add
:param b: The other operand
:return: self + b
"""
new_bits = max(self.bits, b.bits)
# TODO: Some improvements can be made here regarding the following case
# TODO: SI<16>0xff[0x0, 0xff] + 3
# TODO: In current implementation, it overflows, but it doesn't have to
# optimization
# case: SI<16>0xff[0x0, 0xff] + 3
"""
if self.is_top and b.is_integer:
si = self.copy()
si.lower_bound = b.lower_bound
return si
elif b.is_top and self.is_integer:
si = b.copy()
si.lower_bound = self.lower_bound
return si
"""
# FIXME
overflow = self._wrapped_overflow_add(self, b)
if overflow:
return StridedInterval.top(self.bits)
lb = self._modular_add(self.lower_bound, b.lower_bound, new_bits)
ub = self._modular_add(self.upper_bound, b.upper_bound, new_bits)
# Is it initialized?
uninitialized = self.uninitialized or b.uninitialized
# Take the GCD of two operands' strides
stride = math.gcd(self.stride, b.stride)
return StridedInterval(
bits=new_bits, stride=stride, lower_bound=lb, upper_bound=ub, uninitialized=uninitialized
).normalize()
[docs] @normalize_types
def sub(self, b):
"""
Binary operation: sub
:param b: The other operand
:return: self - b
"""
new_bits = max(self.bits, b.bits)
overflow = self._wrapped_overflow_sub(self, b)
if overflow:
return StridedInterval.top(self.bits)
lb = self._modular_sub(self.lower_bound, b.upper_bound, new_bits)
ub = self._modular_sub(self.upper_bound, b.lower_bound, new_bits)
# Is it initialized?
uninitialized = self.uninitialized or b.uninitialized
# Take the GCD of two operands' strides
stride = math.gcd(self.stride, b.stride)
return StridedInterval(
bits=new_bits, stride=stride, lower_bound=lb, upper_bound=ub, uninitialized=uninitialized
).normalize()
[docs] @normalize_types
def mul(self, o):
"""
Binary operation: multiplication
:param o: The other operand
:return: self * o
"""
if self.is_integer and o.is_integer:
# Two integers!
a, b = self.lower_bound, o.lower_bound
ret = StridedInterval(bits=self.bits, stride=0, lower_bound=a * b, upper_bound=a * b)
if a * b > (2**self.bits - 1):
logger.warning("Overflow in multiplication detected.")
return ret.normalize()
else:
# All other cases
# Cut from both north pole and south pole
si1_psplit = self._psplit()
si2_psplit = o._psplit()
all_resulting_intervals = []
for si1 in si1_psplit:
for si2 in si2_psplit:
tmp_unsigned_mul = self._wrapped_unsigned_mul(si1, si2)
tmp_signed_mul = self._wrapped_signed_mul(si1, si2)
for tmp_meet in tmp_unsigned_mul._multi_valued_intersection(tmp_signed_mul):
all_resulting_intervals.append(tmp_meet)
return StridedInterval.least_upper_bound(*all_resulting_intervals).normalize()
[docs] @normalize_types
def sdiv(self, o):
"""
Binary operation: signed division
:param o: The divisor
:return: (self / o) in signed arithmetic
"""
# TODO: copy the code from wrapped interval
splitted_dividends = self._psplit()
splitted_divisors = o._psplit()
resulting_intervals = set()
for dividend in splitted_dividends:
for divisor in splitted_divisors:
tmp = self._wrapped_signed_div(dividend, divisor)
resulting_intervals.add(tmp)
return StridedInterval.least_upper_bound(*resulting_intervals).normalize()
[docs] @normalize_types
def udiv(self, o):
"""
Binary operation: unsigned division
:param o: The divisor
:return: (self / o) in unsigned arithmetic
"""
# FIXME: copy the code fromm wrapped interval
splitted_dividends = self._ssplit()
splitted_divisors = o._ssplit()
resulting_intervals = set()
for dividend in splitted_dividends:
for divisor in splitted_divisors:
tmp = self._wrapped_unsigned_div(dividend, divisor)
resulting_intervals.add(tmp)
return StridedInterval.least_upper_bound(*resulting_intervals).normalize()
# FIXME: preserve uninitialized flag?
[docs] @reversed_processor
def bitwise_not(self):
"""
Unary operation: bitwise not
:return: ~self
"""
splitted_si = self._ssplit()
if len(splitted_si) == 0:
return StridedInterval.empty(self.bits)
result_interval = []
for si in splitted_si:
lb = ~si.upper_bound
ub = ~si.lower_bound
stride = self.stride
tmp = StridedInterval(bits=self.bits, stride=stride, lower_bound=lb, upper_bound=ub)
result_interval.append(tmp)
si = StridedInterval.least_upper_bound(*result_interval).normalize()
# preserve the uninitialized flag
si.uninitialized = self.uninitialized
return si
[docs] @normalize_types
def bitwise_or(self, t):
"""
Binary operation: logical or
:param b: The other operand
:return: self | b
"""
"""
This implementation combines the approaches used by 'WYSINWYX: what you see is not what you execute'
paper and 'Signedness-Agnostic Program Analysis: Precise Integer Bounds for Low-Level Code'. The
first paper provides an sound way to approximate the stride, whereas the second provides a way
to calculate the or operation using wrapping intervals.
Note that, even though according Warren's work 'Hacker's delight', one should follow different
approaches to calculate the minimun/maximum values of an or operations according on the type
of the operands (signed/unsigned). On the other other hand, by splitting the wrapping-intervals
at the south pole, we can safely and soundly only use the Warren's functions for unsigned
integers.
"""
s = self
result_interval = []
for u in s._ssplit():
for v in t._ssplit():
w = u.bits
# u |w v
if u.is_integer:
s_t = StridedInterval._ntz(v.stride)
elif v.is_integer:
s_t = StridedInterval._ntz(u.stride)
else:
s_t = min(StridedInterval._ntz(u.stride), StridedInterval._ntz(v.stride))
if u.is_integer and u.lower_bound == 0:
new_stride = v.stride
elif v.is_integer and v.lower_bound == 0:
new_stride = u.stride
else:
new_stride = 2**s_t
mask = (1 << s_t) - 1
r = (u.lower_bound & mask) | (v.lower_bound & mask)
m = (2**w) - 1
low_bound = WarrenMethods.min_or(
u.lower_bound & (~mask & m),
u.upper_bound & (~mask & m),
v.lower_bound & (~mask & m),
v.upper_bound & (~mask & m),
w,
)
upper_bound = WarrenMethods.max_or(
u.lower_bound & (~mask & m),
u.upper_bound & (~mask & m),
v.lower_bound & (~mask & m),
v.upper_bound & (~mask & m),
w,
)
if low_bound == upper_bound:
new_stride = 0
new_interval = StridedInterval(
lower_bound=((low_bound & (~mask & m)) | r),
upper_bound=((upper_bound & (~mask & m)) | r),
bits=w,
stride=new_stride,
)
result_interval.append(new_interval)
return StridedInterval.least_upper_bound(*result_interval).normalize()
[docs] @normalize_types
def bitwise_and(self, t):
"""
Binary operation: logical and
:param b: The other operand
:return:
"""
"""
The following code implements the and operations as presented in the paper
'Signedness-Agnostic Program Analysis: Precise Integer Bounds for Low-Level Code'
"""
s = self
def number_of_ones(n):
ctr = 0
while n > 0:
ctr += 1
n &= n - 1
return ctr
# Optimization: if one of the two intervals is an integer and contains only one one we can be precise
for a, b in [[s, t], [t, s]]:
if a.is_integer and number_of_ones(a.lower_bound) == 1:
if a.lower_bound == (1 << (t.bits - 1)):
# It's testing the sign bit
stride = 1 << (a.bits - 1)
if b.is_integer:
if b.lower_bound == stride:
return StridedInterval(bits=b.bits, stride=0, lower_bound=stride, upper_bound=stride)
else:
return StridedInterval(bits=b.bits, stride=0, lower_bound=0, upper_bound=0)
else:
is_sol = (
a.lower_bound - b.lower_bound
) % b.stride == 0 and b.lower_bound <= a.lower_bound <= b.upper_bound
if is_sol:
return StridedInterval(bits=b.bits, stride=stride, lower_bound=0, upper_bound=stride)
else:
return StridedInterval(bits=b.bits, stride=0, lower_bound=0, upper_bound=0)
else:
# FIXME: implement case only one 1 not in first position
pass
# paper's and
new_interval = s.bitwise_not().bitwise_or(t.bitwise_not()).bitwise_not()
return new_interval.normalize()
[docs] @normalize_types
def bitwise_xor(self, t):
"""
Operation xor
:param t: The other operand.
"""
# Using same variables as in paper
s = self
new_interval = (
(s.bitwise_not().bitwise_or(t)).bitwise_not().bitwise_or(s.bitwise_or(t.bitwise_not()).bitwise_not())
)
return new_interval.normalize()
def _pre_shift(self, shift_amount):
def get_range(expr):
"""
Get the range of bits for shifting
:param expr:
:return: A tuple of maximum and minimum bits to shift
"""
def round(max, x): # pylint:disable=redefined-builtin
if x < 0 or x > max:
return max
else:
return x
if isinstance(expr, numbers.Number):
return (expr, expr)
assert type(expr) is StridedInterval
if expr.is_integer:
return (round(self.bits, expr.lower_bound), round(self.bits, expr.lower_bound))
else:
if expr.lower_bound < 0:
if expr.upper_bound >= 0:
return (0, self.bits)
else:
return (self.bits, self.bits)
else:
return (round(self.bits, self.lower_bound), round(self.bits, self.upper_bound))
lower, upper = get_range(shift_amount)
# TODO: Is trancating necessary?
return lower, upper
[docs] @reversed_processor
def rshift_logical(self, shift_amount):
"""
Logical shift right.
:param StridedInterval shift_amount: The amount of shifting
:return: The shifted StridedInterval
:rtype: StridedInterval
"""
lower, upper = self._pre_shift(shift_amount)
# Shift the lower_bound and upper_bound by all possible amounts, and union all possible results
ret = None
for amount in range(lower, upper + 1):
si_ = self._rshift_logical(amount)
ret = si_ if ret is None else ret.union(si_)
if ret is None:
return StridedInterval.top(self.bits)
ret.normalize()
ret.uninitialized = self.uninitialized
return ret
def _unrev_rshift_logical(self, shift_amount):
"""
Logical shift right.
:param StridedInterval shift_amount: The amount of shifting
:return: The shifted StridedInterval
:rtype: StridedInterval
"""
lower, upper = self._pre_shift(shift_amount)
# Shift the lower_bound and upper_bound by all possible amounts, and union all possible results
ret = None
for amount in range(lower, upper + 1):
si_ = self._rshift_logical(amount)
ret = si_ if ret is None else ret.union(si_)
if ret is None:
return StridedInterval.top(self.bits)
ret.normalize()
ret.uninitialized = self.uninitialized
return ret
[docs] @reversed_processor
def rshift_arithmetic(self, shift_amount):
"""
Arithmetic shift right.
:param StridedInterval shift_amount: The amount of shifting
:return: The shifted StridedInterval
:rtype: StridedInterval
"""
lower, upper = self._pre_shift(shift_amount)
# Shift the lower_bound and upper_bound by all possible amounts, and union all possible results
ret = None
for amount in range(lower, upper + 1):
si_ = self._rshift_arithmetic(amount)
ret = si_ if ret is None else ret.union(si_)
if ret is None:
return StridedInterval.top(self.bits)
ret.normalize()
ret.uninitialized = self.uninitialized
return ret
[docs] @reversed_processor
def lshift(self, shift_amount):
lower, upper = self._pre_shift(shift_amount)
# Shift the lower_bound and upper_bound by all possible amounts, and
# get min/max values from all the resulting values
new_lower_bound = None
new_upper_bound = None
for shift_amount in range(lower, upper + 1):
l = self.lower_bound << shift_amount
if new_lower_bound is None or l < new_lower_bound:
new_lower_bound = l
u = self.upper_bound << shift_amount
if new_upper_bound is None or u > new_upper_bound:
new_upper_bound = u
# NOTE: If this is an arithmetic operation, we should take care
# of sign-changes.
ret = StridedInterval(
bits=self.bits,
stride=max(self.stride << lower, 1),
lower_bound=new_lower_bound,
upper_bound=new_upper_bound,
uninitialized=self.uninitialized,
)
ret.normalize()
return ret
[docs] @reversed_processor
def cast_low(self, tok):
assert tok <= self.bits
mask = (1 << tok) - 1
if self.stride >= (1 << tok):
logger.warning("Tried to cast_low an interval to an interval shorter than its stride.")
if tok == self.bits:
return self.copy()
else:
# the interval can be represented in tok bits
if (self.lower_bound & mask) == self.lower_bound and (self.upper_bound & mask) == self.upper_bound:
return StridedInterval(
bits=tok,
stride=self.stride,
lower_bound=self.lower_bound,
upper_bound=self.upper_bound,
uninitialized=self.uninitialized,
)
# the range between lower bound and upper bound can be represented
# in the new SI
elif 0 <= (self.upper_bound - self.lower_bound) <= mask:
l = self.lower_bound & mask
u = self.upper_bound & mask
return StridedInterval(
bits=tok, stride=self.stride, lower_bound=l, upper_bound=u, uninitialized=self.uninitialized
)
elif (self.upper_bound & mask == self.lower_bound & mask) and (
(self.upper_bound - self.lower_bound) & mask == 0
):
# This operation doesn't affect the stride. Stride should be 0 then.
bound = self.lower_bound & mask
return StridedInterval(
bits=tok, stride=0, lower_bound=bound, upper_bound=bound, uninitialized=self.uninitialized
)
else:
ntz = StridedInterval._ntz(self.stride)
if tok > ntz:
new_lower = self.lower_bound & ((2**ntz) - 1)
stride = 2**ntz
ret = self.top(tok, uninitialized=self.uninitialized)
ret.stride = stride
ret.lower_bound = new_lower
k = (ret.upper_bound - ret.lower_bound) // ret.stride
ret.upper_bound = ret.stride * k + ret.lower_bound
else:
ret = StridedInterval(
bits=tok,
stride=0,
lower_bound=(self.lower_bound & ((2**tok) - 1)),
upper_bound=(self.upper_bound & ((2**tok) - 1)),
)
return ret
def _unrev_cast_low(self, tok):
assert tok <= self.bits
mask = (1 << tok) - 1
if self.stride >= (1 << tok):
logger.warning("Tried to cast_low an interval to a an interval shorter than its stride.")
if tok == self.bits:
return self.copy()
else:
# the interval can be represented in tok bits
if (self.lower_bound & mask) == self.lower_bound and (self.upper_bound & mask) == self.upper_bound:
return StridedInterval(
bits=tok,
stride=self.stride,
lower_bound=self.lower_bound,
upper_bound=self.upper_bound,
uninitialized=self.uninitialized,
)
# the range between lower bound and upper bound can be represented
# in the new SI
elif self.upper_bound - self.lower_bound <= mask:
l = self.lower_bound & mask
u = self.upper_bound & mask
# Keep the signs!
if self.lower_bound < 0:
# how this should happen ?
logger.warning("Lower bound values is less than 0")
l = StridedInterval._to_negative(l, tok)
if self.upper_bound < 0:
# how this should happen ?
logger.warning("Upper bound value is less than 0")
u = StridedInterval._to_negative(u, tok)
return StridedInterval(
bits=tok, stride=self.stride, lower_bound=l, upper_bound=u, uninitialized=self.uninitialized
)
elif (self.upper_bound & mask == self.lower_bound & mask) and (
(self.upper_bound - self.lower_bound) & mask == 0
):
# This operation doesn't affect the stride. Stride should be 0 then.
bound = self.lower_bound & mask
return StridedInterval(
bits=tok, stride=0, lower_bound=bound, upper_bound=bound, uninitialized=self.uninitialized
)
else:
# TODO: How can we do better here? For example, keep the stride information?
return self.top(tok, uninitialized=self.uninitialized)
[docs] @normalize_types
def concat(self, b):
# Zero-extend
a = self.nameless_copy()
a._bits += b.bits
new_si = a.lshift(b.bits)
new_b = b.copy()
# Zero-extend b
new_b._bits = new_si.bits
if new_si.is_integer:
# We can be more precise!
new_si._bits = new_b.bits
new_si._stride = new_b.stride
new_si._lower_bound = new_si.lower_bound + b.lower_bound
new_si._upper_bound = new_si.upper_bound + b.upper_bound
return new_si
else:
return new_si.bitwise_or(new_b)
def _unrev_extract(self, high_bit, low_bit):
assert low_bit >= 0
bits = high_bit - low_bit + 1
if low_bit != 0:
ret = self._unrev_rshift_logical(low_bit)
else:
ret = self.copy()
if bits != self.bits:
ret = ret._unrev_cast_low(bits)
ret.uninitialized = self.uninitialized
return ret.normalize()
[docs] @reversed_processor
def agnostic_extend(self, new_length):
"""
Unary operation: SignExtend
:param new_length: New length after sign-extension
:return: A new StridedInterval
"""
"""
In a sign-agnostic implementation of strided-intervals a number can be signed or unsigned both.
Given a SI, we must pay attention how we extend its lower bound and upper bound.
Assuming that the lower bound is in the left emishpere (positive number).
Let's assume first that the SI is signed and its upper bound is in the right emisphere. Extending it with leading
1s (i.e., its MSB) is correct given that its values would be preserved.
On the other hand if the number is unsigned we should not replicate its MSB, since this would increase the value
of the upper bound in the new interval. In this case the correct approach would be to add 0 in front of the number,
i.e., moving it to the left emisphere. But this approach wouldn't be correct in the first scenario (signed SI).
The solution in this case is extend the upper bound with 1s. This gives us an overapproximation of the original
SI.
Extending this intuition, the implementation follows the below rules:
(UB: upper bound, LB: lower bound, RE: right emisphere, LE: left emisphere)
1* UB:LE and LB:LE: add leading 0s (sound).
2* UB:RE and LB:RE and the LB is closer to the north pole: add leading 0s to LB and leading 1s to the UB (sound)
3* UB:RE and LB:RE and UB is closer to the north pole: add leading 1s to LB and UB both (sound).
4* UB:LE and LB:RE: add leading 0s to UB and leading 0s to LB (sound).
5* UB:RE and LB:LE: add leading 0s to LB and leading 1s to UB (sound).
6* UB:RE and LB:RE and LB = UB: add leading 0s to LB and 1s to UB and add stride from LB to UB ****
"""
si = self.copy()
si._bits = new_length
leading_1_lb = False
leading_1_ub = False
ub_msb = self._get_msb(self.upper_bound, self.bits)
lb_msb = self._get_msb(self.lower_bound, self.bits)
# the only one which chages the stride
case_6 = False
# LB:RE cases
if lb_msb == 1:
# 2
if ub_msb == 1 and self.upper_bound > self.lower_bound:
leading_1_ub = True
# 3
if ub_msb == 1 and self.lower_bound > self.upper_bound:
leading_1_ub = True
leading_1_lb = True
# 6
if ub_msb == 1 and self.lower_bound == self.upper_bound:
leading_1_ub = True
case_6 = True
# 5
elif ub_msb == 1:
leading_1_ub = True
if leading_1_lb:
mask = (2**new_length - 1) - (2**self.bits - 1)
si._lower_bound |= mask
if leading_1_ub:
mask = (2**new_length - 1) - (2**self.bits - 1)
si._upper_bound |= mask
if case_6:
si.stride = si.upper_bound - si.lower_bound
return si
[docs] @reversed_processor
def zero_extend(self, new_length):
"""
Unary operation: ZeroExtend
:param new_length: New length after zero-extension
:return: A new StridedInterval
"""
si = self.copy()
si._bits = new_length
return si
[docs] @reversed_processor
def sign_extend(self, new_length):
"""
Unary operation: SignExtend
:param new_length: New length after sign-extension
:return: A new StridedInterval
"""
msb = self.extract(self.bits - 1, self.bits - 1).eval(2)
if msb == [0]:
# All positive numbers
return self.zero_extend(new_length)
if msb == [1]:
# All negative numbers
si = self.copy()
si._bits = new_length
mask = (2**new_length - 1) - (2**self.bits - 1)
si._lower_bound |= mask
si._upper_bound |= mask
else:
# Both positive numbers and negative numbers
numbers = self._nsplit()
# Since there are both positive and negative numbers, there must be two bounds after nsplit
# assert len(numbers) == 2
all_resulting_intervals = []
assert len(numbers) > 0
for n in numbers:
a, b = n.lower_bound, n.upper_bound
mask_a = 0
mask_b = 0
mask_n = ((1 << (new_length - n.bits)) - 1) << n.bits
if StridedInterval._get_msb(a, n.bits) == 1:
mask_a = mask_n
if StridedInterval._get_msb(b, n.bits) == 1:
mask_b = mask_n
si_ = StridedInterval(bits=new_length, stride=n.stride, lower_bound=a | mask_a, upper_bound=b | mask_b)
all_resulting_intervals.append(si_)
si = StridedInterval.least_upper_bound(*all_resulting_intervals).normalize()
si.uninitialized = self.uninitialized
return si
[docs] @normalize_types
def union(self, b):
"""
The union operation. It might return a DiscreteStridedIntervalSet to allow for better precision in analysis.
:param b: Operand
:return: A new DiscreteStridedIntervalSet, or a new StridedInterval.
"""
if not allow_dsis:
return StridedInterval.least_upper_bound(self, b)
else:
if (
self.cardinality > discrete_strided_interval_set.DEFAULT_MAX_CARDINALITY_WITHOUT_COLLAPSING
or b.cardinality > discrete_strided_interval_set.DEFAULT_MAX_CARDINALITY_WITHOUT_COLLAPSING
):
return StridedInterval.least_upper_bound(self, b)
else:
dsis = DiscreteStridedIntervalSet(bits=self._bits, si_set={self})
return dsis.union(b)
@staticmethod
def _bigger(interval1, interval2):
"""
Return interval with bigger cardinality
Refer Section 3.1
:param interval1: first interval
:param interval2: second interval
:return: Interval or interval2 whichever has greater cardinality
"""
if interval2.cardinality > interval1.cardinality:
return interval2.copy()
return interval1.copy()
@staticmethod
def _ntz(x):
"""
Get the number of consecutive zeros
:param x:
:return:
"""
if x == 0:
return 0
y = (~x) & (x - 1) # There is actually a bug in BAP until 0.8
def bits(y):
n = 0
while y != 0:
n += 1
y >>= 1
return n
return bits(y)
[docs] @staticmethod
def least_upper_bound(*intervals_to_join):
"""
Pseudo least upper bound.
Join the given set of intervals into a big interval. The resulting strided interval is the one which in
all the possible joins of the presented SI, presented the least number of values.
The number of joins to compute is linear with the number of intervals to join.
Draft of proof:
Considering three generic SI (a,b, and c) ordered from their lower bounds, such that
a.lower_bund <= b.lower_bound <= c.lower_bound, where <= is the lexicographic less or equal.
The only joins which have sense to compute are:
* a U b U c
* b U c U a
* c U a U b
All the other combinations fall in either one of these cases. For example: b U a U c does not make make sense
to be calculated. In fact, if one draws this union, the result is exactly either (b U c U a) or (a U b U c) or
(c U a U b).
:param intervals_to_join: Intervals to join
:return: Interval that contains all intervals
"""
assert len(intervals_to_join) > 0, "No intervals to join"
# Check if all intervals are of same width
all_same = all(x.bits == intervals_to_join[0].bits for x in intervals_to_join)
assert all_same, "All intervals to join should be same"
# Optimization: If we have only one interval, then return that interval as result
if len(intervals_to_join) == 1:
return intervals_to_join[0].copy()
# Optimization: If we have only two intervals, the pseudo-join is fine and more precise
if len(intervals_to_join) == 2:
return StridedInterval.pseudo_join(intervals_to_join[0], intervals_to_join[1])
# sort the intervals in increasing left bound
sorted_intervals = sorted(intervals_to_join, key=lambda x: x.lower_bound)
# Fig 3 of the paper
ret = None
# we try all possible joins (linear with the number of SI to join)
# and we return the one with the least number of values.
for i in range(len(sorted_intervals)):
# let's join all of them
si = reduce(
lambda x, y: StridedInterval.pseudo_join(x, y, False), sorted_intervals[i:] + sorted_intervals[0:i]
)
if ret is None or ret.n_values > si.n_values:
ret = si
if any([x for x in intervals_to_join if x.uninitialized]):
ret.uninitialized = True
return ret
@normalize_types
def _union(self, b):
# FIXME: to remove
# this function is here only for retro compatibility with the other parts of angr
return StridedInterval.pseudo_join(self, b)
[docs] @staticmethod
def pseudo_join(s, b, smart_join=True):
"""
It two intervals in a way that the resulting SI is the one that has the least
SI cardinality (i.e., which represents the least number of elements) possible if the smart_join flag is enabled,
otherwise it just joins the SI according the order they are passed to the function.
The pseudo-join operation is not associative in wrapping intervals (please refer to section 3.1 paper
'Signedness-Agnostic Program Analysis: Precise Integer Bounds for Low-Level Code'), Therefore the join of three
WI may give us different results according on the order we join them. All of the results will be sound, though.
Please use the function least_upper_bound as a stub.
:param s: The first SI
:param b: The other SI.
:param smart_join: Enable the smart join behavior. If this flag is set, this function joins the two SI in a way
that the resulting Si has least number of elements (more precise). If it is unset, this
function will join the two SI according on the order they are passed to the function.
:return: A new StridedInterval
"""
assert s.bits == b.bits
w = s.bits
if s._reversed != b._reversed:
logger.warning("Incoherent reversed flag between operands %s and %s", s, b)
uninit_flag = s.uninitialized | b.uninitialized
#
# Trivial cases
#
if s.is_empty:
return b
if b.is_empty:
return s
if s.is_integer and b.is_integer:
u = max(s.upper_bound, b.upper_bound) if smart_join else b.upper_bound
l = min(s.lower_bound, b.lower_bound) if smart_join else s.lower_bound
stride = abs(u - l)
return StridedInterval(bits=w, stride=stride, lower_bound=l, upper_bound=u, uninitialized=uninit_flag)
#
# Other cases
#
if s._is_surrounded(b):
# Containment: s <= b
new_stride = StridedInterval.gcd(s.stride, b.stride) if s.is_interval else b.stride
new_stride = StridedInterval.gcd(new_stride, s._modular_sub(s.lower_bound, b.lower_bound, w))
return StridedInterval(
bits=w,
stride=new_stride,
lower_bound=b.lower_bound,
upper_bound=b.upper_bound,
uninitialized=uninit_flag,
)
elif b._is_surrounded(s):
# Containment: b <= s
# TODO: This case is missing in the original implementation. Is that a bug?
new_stride = StridedInterval.gcd(s.stride, b.stride) if b.is_interval else s.stride
new_stride = StridedInterval.gcd(new_stride, s._modular_sub(b.lower_bound, s.lower_bound, w))
return StridedInterval(
bits=w,
stride=new_stride,
lower_bound=s.lower_bound,
upper_bound=s.upper_bound,
uninitialized=uninit_flag,
)
elif (
s._surrounds_member(b.lower_bound)
and s._surrounds_member(b.upper_bound)
and b._surrounds_member(s.lower_bound)
and b._surrounds_member(s.upper_bound)
):
# The union of them covers the entire sphere
return StridedInterval.top(w, uninitialized=uninit_flag)
elif s._surrounds_member(b.lower_bound):
# Overlapping. Nor s or b are integer here.
# We return the join with less values
new_stride = StridedInterval.gcd(s.stride, b.stride)
new_stride = StridedInterval.gcd(new_stride, s._modular_sub(b.lower_bound, s.lower_bound, w))
return StridedInterval(
bits=w,
stride=new_stride,
lower_bound=s.lower_bound,
upper_bound=b.upper_bound,
uninitialized=uninit_flag,
)
elif b._surrounds_member(s.lower_bound):
# Overlapping. Nor s or b are integer here.
# We return the join with less values
new_stride = StridedInterval.gcd(s.stride, b.stride)
new_stride = StridedInterval.gcd(new_stride, s._modular_sub(s.lower_bound, b.lower_bound, w))
return StridedInterval(
bits=w,
stride=new_stride,
lower_bound=b.lower_bound,
upper_bound=s.upper_bound,
uninitialized=uninit_flag,
)
else:
# no overlapping.
# we join the two intervals according on the order they are given
if not smart_join:
if s.is_integer:
new_stride = StridedInterval.gcd(b.stride, s._modular_sub(b.lower_bound, s.lower_bound, w))
elif b.is_integer:
new_stride = StridedInterval.gcd(s.stride, s._modular_sub(b.lower_bound, s.lower_bound, w))
else:
new_stride = StridedInterval.gcd(s.stride, b.stride)
new_stride = StridedInterval.gcd(
new_stride, StridedInterval._wrapped_cardinality(s.lower_bound, b.lower_bound, w) - 1
)
return StridedInterval(
bits=w,
stride=new_stride,
lower_bound=s.lower_bound,
upper_bound=b.upper_bound,
uninitialized=uninit_flag,
)
# Else: smart join.
# we return the join which produce an interval with the least number of values
if s.is_integer:
new_stride = b.stride
elif b.is_integer:
new_stride = s.stride
else:
new_stride = StridedInterval.gcd(s.stride, b.stride)
# from b to s
new_stride1 = StridedInterval.gcd(
new_stride, StridedInterval._wrapped_cardinality(b.lower_bound, s.lower_bound, w) - 1
)
# from s to b
new_stride2 = StridedInterval.gcd(
new_stride, StridedInterval._wrapped_cardinality(s.lower_bound, b.lower_bound, w) - 1
)
si1 = StridedInterval(
bits=w,
stride=new_stride1,
lower_bound=b.lower_bound,
upper_bound=s.upper_bound,
uninitialized=uninit_flag,
)
si2 = StridedInterval(
bits=w,
stride=new_stride2,
lower_bound=s.lower_bound,
upper_bound=b.upper_bound,
uninitialized=uninit_flag,
)
if si1.n_values <= si2.n_values:
return si1
else:
return si2
@staticmethod
def _minimal_common_integer(si_0, si_1):
"""
Calculates the minimal integer that appears in both StridedIntervals.
As a wrapper method of _minimal_common_integer_splitted(), this method takes arbitrary StridedIntervals.
For more information, please refer to the comment of _minimal_common_integer_splitted().
:param si_0: the first StridedInterval
:type si_0: StridedInterval
:param si_1: the second StridedInterval
:type si_1: StridedInterval
:return: the minimal common integer, or None if there is no common integer
"""
si_0_splitted = si_0._ssplit()
si_1_splitted = si_1._ssplit()
len_0, len_1 = len(si_0_splitted), len(si_1_splitted)
if len_0 == 1 and len_1 == 2:
# Swap them so we don't have to handle dual
si_0_splitted, si_1_splitted = si_1_splitted, si_0_splitted
len_0, len_1 = len_1, len_0
if len_0 == 1 and len_1 == 1:
# No splitting was necessary
return StridedInterval._minimal_common_integer_splitted(si_0, si_1)
if len_0 == 2 and len_1 == 1:
int_0 = StridedInterval._minimal_common_integer_splitted(si_0_splitted[0], si_1_splitted[0])
int_1 = StridedInterval._minimal_common_integer_splitted(si_0_splitted[1], si_1_splitted[0])
else:
# len_0 == 2 and len_1 == 2
int_0 = StridedInterval._minimal_common_integer_splitted(si_0_splitted[0], si_1_splitted[0])
int_1 = StridedInterval._minimal_common_integer_splitted(si_0_splitted[1], si_1_splitted[1])
if int_0 is None:
return int_1
elif int_1 is None:
return int_0
else:
return int_0
[docs] @staticmethod
def extended_euclid(a, b):
"""
It calculates the GCD of a and b, and two values x and y such that:
a*x + b*y = GCD(a,b).
This code has been taken from the project sympy.
:param a: first integer
:param b: second integer
:return: x,y and the GCD of a and b
"""
if b == 0:
return (1, 0, a)
x0, y0, d = StridedInterval.extended_euclid(b, a % b)
x, y = y0, x0 - (a // b) * y0
return x, y, d
[docs] @staticmethod
def sign(a):
return -1 if a < 0 else 1
[docs] @staticmethod
def igcd(a, b):
"""
:param a: First integer
:param b: Second integer
:return: the integer GCD between a and b
"""
a = int(round(a))
b = int(round(b))
if b < 0:
b = -b
while b:
a, b = b, a % b
if a == 1 or b == 1:
return 1
return a
[docs] @staticmethod
def diop_natural_solution_linear(c, a, b):
"""
It finds the fist natural solution of the diophantine equation
a*x + b*y = c. Some lines of this code are taken from the project
sympy.
:param c: constant
:param a: quotient of x
:param b: quotient of y
:return: the first natural solution of the diophatine equation
"""
def get_intersection(a, b, a_dir, b_dir):
# Do the intersection between two
# ranges.
if (a_dir, b_dir) == (">=", ">="):
lb = a if a > b else b
ub = float("inf")
elif (a_dir, b_dir) == ("<=", ">="):
if a > b:
lb = b
ub = a
else:
lb = None
ub = None
elif (a_dir, b_dir) == (">=", "<="):
if b > a:
lb = a
ub = b
else:
lb = None
ub = None
elif (a_dir, b_dir) == ("<=", "<="):
ub = a if a < b else b
lb = float("-inf")
return lb, ub
d = StridedInterval.igcd(a, StridedInterval.igcd(b, c))
a = a // d
b = b // d
c = c // d
if c == 0:
return (0, 0)
else:
x0, y0, d = StridedInterval.extended_euclid(int(abs(a)), int(abs(b)))
x0 = x0 * StridedInterval.sign(a)
y0 = y0 * StridedInterval.sign(b)
if c % d == 0:
"""
Integer solutions are: (c*x0 + b*t, c*y0 - a*t)
we have to get the first positive solution, which means
that we have to solve the following disequations for t:
c*x0 + b*t >= 0 and c*y0 - a*t >= 0.
"""
assert b != 0
assert a != 0
t0 = (-c * x0) / float(b)
t1 = (c * y0) / float(a)
# direction of the disequation depends on b and a sign
if b < 0:
t0_dir = "<="
else:
t0_dir = ">="
if a < 0:
t1_dir = ">="
else:
t1_dir = "<="
# calculate the intersection between the found
# solution intervals to get the common solutions
# for t.
lb, ub = get_intersection(t0, t1, t0_dir, t1_dir)
# Given that we are looking for the first value
# which solve the diophantine equation, we have to
# select the value of t closer to 0.
if lb <= 0 and ub >= 0:
t = ub if abs(ub) < abs(lb) else lb
elif lb == float("inf") or lb == float("-inf"):
t = ub
elif ub == float("inf") or ub == float("-inf"):
t = lb
else:
t = ub if abs(ub) < abs(lb) else lb
# round the value of t
if t == ub:
t = int(math.floor(t))
else:
t = int(math.ceil(t))
return (c * x0 + b * t, c * y0 - a * t)
else:
return (None, None)
@staticmethod
def _minimal_common_integer_splitted(si_0, si_1):
"""
Calculates the minimal integer that appears in both StridedIntervals.
It's equivalent to finding an integral solution for equation `ax + b = cy + d` that makes `ax + b` minimal
si_0.stride, si_1.stride being a and c, and si_0.lower_bound, si_1.lower_bound being b and d, respectively.
Upper bounds are used to check whether the minimal common integer exceeds the bound or not. None is returned
if no minimal common integers can be found within the range.
Some assumptions:
# - None of the StridedIntervals straddles the south pole. Consequently, we have x <= max_int(si.bits) and y <=
# max_int(si.bits)
# - a, b, c, d are all positive integers
# - x >= 0, y >= 0
:param StridedInterval si_0: the first StridedInterval
:param StridedInterval si_1: the second StrideInterval
:return: the minimal common integer, or None if there is no common integer
"""
a, c = si_0.stride, si_1.stride
b, d = si_0.lower_bound, si_1.lower_bound
# if any of them is an integer
if si_0.is_integer:
if si_1.is_integer:
return None if si_0.lower_bound != si_1.lower_bound else si_0.lower_bound
elif (
si_0.lower_bound >= si_1.lower_bound
and si_0.lower_bound <= si_1.upper_bound
and (si_0.lower_bound - si_1.lower_bound) % si_1.stride == 0
):
return si_0.lower_bound
else:
return None
elif si_1.is_integer:
return StridedInterval._minimal_common_integer_splitted(si_1, si_0)
# shortcut
if si_0.upper_bound < si_1.lower_bound or si_1.upper_bound < si_0.lower_bound:
# They don't overlap at all
return None
if (d - b) % StridedInterval.gcd(a, c) != 0:
# They don't overlap
return None
"""
Given two strided intervals a = sa[lba, uba] and b = sb[lbb, ubb], the first integer shared
by them is found by finding the minimum values of ka and kb which solve the equation:
ka * sa + lba = kb * sb + lbb
In particular one can solve the above diophantine equation and find the parameterized solutions
of ka and kb, with respect to a parameter t.
The minimum natural value of the parameter t which gives two positive natural values of ka and kb
is used to resolve ka and kb, and finally to solve the above equation and get the minimum shared integer.
"""
x, y = StridedInterval.diop_natural_solution_linear(-(b - d), a, -c)
if a is None or b is None:
return None
first_integer = x * a + b
assert first_integer == y * c + d
if (
first_integer >= si_0.lower_bound
and first_integer <= si_0.upper_bound
and first_integer >= si_1.lower_bound
and first_integer <= si_1.upper_bound
):
return first_integer
else:
return None
[docs] @normalize_types
def intersection(self, b):
intersection = self._multi_valued_intersection(b)
v = intersection[0]
if len(intersection) == 2:
v = StridedInterval.pseudo_join(v, intersection[1])
return v
@normalize_types
def _multi_valued_intersection(self, b):
if self.is_empty or b.is_empty:
return (StridedInterval.empty(self.bits),)
assert self.bits == b.bits
if self.is_integer and b.is_integer:
if self.lower_bound == b.lower_bound:
# They are the same number!
ret = (
StridedInterval(
bits=self.bits, stride=0, lower_bound=self.lower_bound, upper_bound=self.lower_bound
),
)
else:
ret = (StridedInterval.empty(self.bits),)
elif self.is_integer:
integer = self.lower_bound
if (b.lower_bound - integer) % b.stride == 0 and b._surrounds_member(integer):
ret = (StridedInterval(bits=self.bits, stride=0, lower_bound=integer, upper_bound=integer),)
else:
ret = (StridedInterval.empty(self.bits),)
elif b.is_integer:
integer = b.lower_bound
if (integer - self.lower_bound) % self.stride == 0 and self._surrounds_member(integer):
ret = (StridedInterval(bits=self.bits, stride=0, lower_bound=integer, upper_bound=integer),)
else:
ret = (StridedInterval.empty(self.bits),)
else:
# None of the operands is an integer
# Note that this is not a faithful implementation of the WI paper, rather it is based on WrappedMeet() in
# wrapped-intervals:lib/RangeAnalysis/WrappedRange.cpp . Please see wrapped-intervals on GitHub at
# https://github.com/sav-tools/wrapped-intervals
new_stride = self.lcm(self.stride, b.stride)
if self._is_surrounded(b):
# Containment case
# `b` may fully contain `self`
lb = StridedInterval._minimal_common_integer(self, b)
if lb is None:
ret = (StridedInterval.empty(self.bits),)
else:
ub = self._modular_add(
self._modular_sub(self.upper_bound, lb, self.bits) // new_stride * new_stride, lb, self.bits
)
ret = (StridedInterval(bits=self.bits, stride=new_stride, lower_bound=lb, upper_bound=ub),)
elif b._is_surrounded(self):
# Containment case 2
# `self` contains `b`
lb = StridedInterval._minimal_common_integer(self, b)
if lb is None:
ret = (StridedInterval.empty(self.bits),)
else:
ub = self._modular_add(
self._modular_sub(b.upper_bound, lb, self.bits) // new_stride * new_stride, lb, self.bits
)
ret = (StridedInterval(bits=self.bits, stride=new_stride, lower_bound=lb, upper_bound=ub),)
elif (
self._surrounds_member(b.lower_bound)
and self._surrounds_member(b.upper_bound)
and b._surrounds_member(self.lower_bound)
and b._surrounds_member(self.upper_bound)
):
# One cover the other
# bounds of the two common intervals
# among the SIs
lb_s0 = self.lower_bound
ub_s0 = b.upper_bound
lb_s1 = b.lower_bound
ub_s1 = self.upper_bound
# Let's build the SIs
s0 = StridedInterval(bits=self.bits, lower_bound=lb_s0, upper_bound=ub_s0, stride=self.stride)
s1 = StridedInterval(bits=self.bits, lower_bound=lb_s1, upper_bound=ub_s1, stride=b.stride)
# and find the first common integer
lb_s0_new = StridedInterval._minimal_common_integer(s0, b)
lb_s1_new = StridedInterval._minimal_common_integer(s1, self)
if lb_s0_new is None:
s0 = StridedInterval.empty(self.bits)
else:
ub_s0_new = self._modular_add(
self._modular_sub(ub_s0, lb_s0_new, self.bits) // new_stride * new_stride, lb_s0_new, self.bits
)
s0 = StridedInterval(
bits=self.bits, lower_bound=lb_s0_new, upper_bound=ub_s0_new, stride=new_stride
)
if lb_s1_new is None:
s1 = StridedInterval.empty(self.bits)
else:
ub_s1_new = self._modular_add(
self._modular_sub(ub_s1, lb_s1_new, self.bits) // new_stride * new_stride, lb_s1_new, self.bits
)
s1 = StridedInterval(
bits=self.bits, lower_bound=lb_s1_new, upper_bound=ub_s1_new, stride=new_stride
)
ret = (s0, s1)
# here we have four cases since the overlapping depends also on the stride
elif self._surrounds_member(b.lower_bound):
# Overlapping case 1a
lb = StridedInterval._minimal_common_integer(b, self)
if lb is None:
ret = (StridedInterval.empty(self.bits),)
else:
ub = self._modular_add(
self._modular_sub(self.upper_bound, lb, self.bits) // new_stride * new_stride, lb, self.bits
)
ret = (StridedInterval(bits=self.bits, stride=new_stride, lower_bound=lb, upper_bound=ub),)
elif self._surrounds_member(b.upper_bound):
# Overlapping case 1b
lb = StridedInterval._minimal_common_integer(b, self)
if lb is None:
ret = (StridedInterval.empty(self.bits),)
else:
ub = self._modular_add(
self._modular_sub(b.upper_bound, lb, self.bits) // new_stride * new_stride, lb, self.bits
)
ret = (StridedInterval(bits=self.bits, stride=new_stride, lower_bound=lb, upper_bound=ub),)
elif b._surrounds_member(self.lower_bound):
# Overlapping case 2a
lb = StridedInterval._minimal_common_integer(self, b)
if lb is None:
ret = (StridedInterval.empty(self.bits),)
else:
ub = self._modular_add(
self._modular_sub(b.upper_bound, lb, self.bits) // new_stride * new_stride, lb, self.bits
)
ret = (StridedInterval(bits=self.bits, stride=new_stride, lower_bound=lb, upper_bound=ub),)
elif b._surrounds_member(self.upper_bound):
# Overlapping case 2b
lb = StridedInterval._minimal_common_integer(self, b)
if lb is None:
ret = (StridedInterval.empty(self.bits),)
else:
ub = self._modular_add(
self._modular_sub(self.upper_bound, lb, self.bits) // new_stride * new_stride, lb, self.bits
)
ret = (StridedInterval(bits=self.bits, stride=new_stride, lower_bound=lb, upper_bound=ub),)
else:
# Disjoint case
ret = (StridedInterval.empty(self.bits),)
ret = tuple(r.normalize() for r in ret)
return ret
[docs] @normalize_types
def widen(self, b):
ret = None
if self.is_empty and not b.is_empty:
ret = StridedInterval.top(bits=self.bits)
elif self.is_empty:
ret = b
elif b.is_empty:
ret = self
else:
new_stride = StridedInterval.gcd(self.stride, b.stride)
l = (
StridedInterval.lower(self.bits, self.lower_bound, new_stride)
if b.lower_bound < self.lower_bound
else self.lower_bound
)
u = (
StridedInterval.upper(self.bits, self.upper_bound, new_stride)
if b.upper_bound > self.upper_bound
else self.upper_bound
)
if new_stride == 0:
if self.is_integer and b.is_integer:
ret = StridedInterval(bits=self.bits, stride=1, lower_bound=l, upper_bound=u)
else:
raise ClaripyOperationError("SI: operands are not reduced.")
else:
ret = StridedInterval(bits=self.bits, stride=new_stride, lower_bound=l, upper_bound=u)
ret.normalize()
return ret
[docs] def reverse(self):
"""
This is a delayed reversing function. All it really does is to invert the _reversed property of this
StridedInterval object.
:return: None
"""
if self.bits == 8:
# We cannot reverse a one-byte value
return self
si = self.copy()
si._reversed = not si._reversed
return si
def _reverse(self):
"""
This method reverses the StridedInterval object for real. Do expect loss of precision for most cases!
:return: A new reversed StridedInterval instance
"""
o = self.copy()
# Clear ok reversed flag
o._reversed = not o._reversed
if o.bits == 8:
# No need for reversing
return o.copy()
if o.is_top:
# A TOP is still a TOP after reversing
si = o.copy()
return si
else:
if not o.is_integer:
# We really don't want to do that... but well, sometimes it just happens...
logger.warning("Reversing a real strided-interval %s is bad", self)
# Reversing an integer is easy
rounded_bits = ((o.bits + 7) // 8) * 8
list_bytes = []
si = None
for i in range(0, rounded_bits, 8):
b = o._unrev_extract(min(i + 7, o.bits - 1), i)
list_bytes.append(b)
for b in list_bytes:
si = b if si is None else si.concat(b)
si.uninitialized = self.uninitialized
si._reversed = o._reversed
return si
"""
This reverse operation is unsound and incomplete, but allows the reverse operation to be......
"""
def _involuted_reverse(self):
"""
This method reverses the StridedInterval object for real. Do expect loss of precision for most cases!
:return: A new reversed StridedInterval instance
"""
def inv_is_top(si):
return si.stride == 1 and self._lower_bound == StridedInterval._modular_add(self._upper_bound, 1, self.bits)
o = self.copy()
# Clear the reversed flag
o._reversed = not o._reversed
if o.bits == 8:
# No need for reversing
return o.copy()
if inv_is_top(o):
# A TOP is still a TOP after reversing
si = o.copy()
return si
else:
lb = o._lower_bound
ub = o._upper_bound
rounded_bits = ((o.bits + 7) // 8) * 8
lb_r = []
ub_r = []
for i in range(0, rounded_bits, 8):
if i != 0:
lb = lb >> 8
ub = ub >> 8
lb_r.append(lb & 0xFF)
ub_r.append(ub & 0xFF)
si_lb = None
si_ub = None
for b in lb_r:
if si_lb is None:
si_lb = b
else:
si_lb <<= 8
si_lb |= b
for b in ub_r:
if si_ub is None:
si_ub = b
else:
si_ub <<= 8
si_ub |= b
si = StridedInterval(
bits=o.bits, lower_bound=si_lb, upper_bound=si_ub, stride=o._stride, uninitialized=o.uninitialized
)
si._reversed = o._reversed
if not o.is_integer:
# We really don't want to do that... but well, sometimes it just happens...
logger.warning("Reversing a real strided-interval %s is bad", self)
return si
[docs]def CreateStridedInterval(
name=None,
bits=0,
stride=None,
lower_bound=None,
upper_bound=None,
uninitialized=False,
to_conv=None,
discrete_set=False,
discrete_set_max_cardinality=None,
):
"""
:param name:
:param bits:
:param stride:
:param lower_bound:
:param upper_bound:
:param to_conv:
:param bool discrete_set:
:param int discrete_set_max_cardinality:
:return:
"""
if to_conv is not None:
if isinstance(to_conv, Base):
to_conv = to_conv._model_vsa
if isinstance(to_conv, StridedInterval):
# No conversion will be done
return to_conv
if not isinstance(to_conv, (numbers.Number, BVV)):
raise ClaripyOperationError("Unsupported to_conv type %s" % type(to_conv))
if stride is not None or lower_bound is not None or upper_bound is not None:
raise ClaripyOperationError("You cannot specify both to_conv and other parameters at the same time.")
if isinstance(to_conv, BVV):
bits = to_conv.bits
to_conv_value = to_conv.value
else:
bits = bits
to_conv_value = to_conv
stride = 0
lower_bound = to_conv_value
upper_bound = to_conv_value
bi = StridedInterval(
name=name,
bits=bits,
stride=stride,
lower_bound=lower_bound,
upper_bound=upper_bound,
uninitialized=uninitialized,
)
if not discrete_set:
return bi
else:
dsis = DiscreteStridedIntervalSet(
name=name, bits=bits, si_set={bi}, max_cardinality=discrete_set_max_cardinality
)
return dsis
from .errors import ClaripyVSAError
from ..errors import ClaripyOperationError
from .bool_result import TrueResult, FalseResult, MaybeResult
from . import discrete_strided_interval_set
from .discrete_strided_interval_set import DiscreteStridedIntervalSet
from .valueset import ValueSet
from ..ast.base import Base
from ..bv import BVV