import decimal
import functools
import logging
import math
import struct
from decimal import Decimal
from enum import Enum

from .errors import ClaripyOperationError
from .backend_object import BackendObject
def compare_sorts(f):
    """
    Decorator for binary FPV methods that refuses to combine operands whose
    sorts differ.

    The wrapped method is only invoked when ``self.sort != o.sort`` is False.

    :raises TypeError: if the two operands' sorts compare unequal.
    """

    @functools.wraps(f)
    def _guarded(self, o):
        lhs_sort, rhs_sort = self.sort, o.sort
        if lhs_sort != rhs_sort:
            raise TypeError(f"FPVs are differently-sorted ({lhs_sort} and {rhs_sort})")
        return f(self, o)

    return _guarded
def normalize_types(f):
    """
    Decorator for binary FPV methods that coerces the right-hand operand.

    A plain Python float operand is wrapped into an FPV carrying ``self``'s
    sort before the wrapped method runs.

    :raises TypeError: if, after that coercion, either operand is not an FPV.
    """

    @functools.wraps(f)
    def _normalized(self, o):
        other = FPV(o, self.sort) if isinstance(o, float) else o
        if isinstance(self, FPV) and isinstance(other, FPV):
            return f(self, other)
        raise TypeError("must have two FPVs")

    return _normalized
class RM(Enum):
    """
    IEEE 754 rounding modes.

    See https://en.wikipedia.org/wiki/IEEE_754#Rounding_rules
    """

    RM_NearestTiesEven = "RM_RNE"
    RM_NearestTiesAwayFromZero = "RM_RNA"
    RM_TowardsZero = "RM_RTZ"
    RM_TowardsPositiveInf = "RM_RTP"
    RM_TowardsNegativeInf = "RM_RTN"

    @staticmethod
    def default():
        """Return the default rounding mode, round to nearest with ties to even."""
        return RM.RM_NearestTiesEven

    def pydecimal_equivalent_rounding_mode(self):
        """
        Return the :mod:`decimal` rounding constant equivalent to this mode.

        The result is suitable as the ``rounding`` argument of
        ``Decimal.to_integral_value`` / ``Decimal.quantize``.
        """
        return {
            RM.RM_TowardsPositiveInf: decimal.ROUND_CEILING,
            RM.RM_TowardsNegativeInf: decimal.ROUND_FLOOR,
            RM.RM_TowardsZero: decimal.ROUND_DOWN,
            RM.RM_NearestTiesEven: decimal.ROUND_HALF_EVEN,
            # ROUND_HALF_UP is "round to nearest, ties away from zero".
            # ROUND_UP would round *every* inexact value away from zero
            # (e.g. 2.3 -> 3), which is not a round-to-nearest mode.
            RM.RM_NearestTiesAwayFromZero: decimal.ROUND_HALF_UP,
        }[self]
# Module-level aliases for each RM member.
(
    RM_NearestTiesEven,
    RM_NearestTiesAwayFromZero,
    RM_TowardsZero,
    RM_TowardsPositiveInf,
    RM_TowardsNegativeInf,
) = (
    RM.RM_NearestTiesEven,
    RM.RM_NearestTiesAwayFromZero,
    RM.RM_TowardsZero,
    RM.RM_TowardsPositiveInf,
    RM.RM_TowardsNegativeInf,
)
class FSort:
    """
    A floating point sort described by its exponent and mantissa widths.

    For the built-in sorts the mantissa width includes the implicit leading
    bit (24 for single, 53 for double precision), so ``length``
    (exp + mantissa) equals the total bit width (32 / 64).
    """

    def __init__(self, name, exp, mantissa):
        self.name = name
        self.exp = exp
        self.mantissa = mantissa

    def __eq__(self, other):
        # Sorts are equal when their layouts match; `name` is only a label.
        try:
            return self.exp == other.exp and self.mantissa == other.mantissa
        except AttributeError:
            # Not sort-like: let Python fall back instead of crashing.
            return NotImplemented

    def __repr__(self):
        return self.name

    def __hash__(self):
        # Hash exactly the fields __eq__ compares; including `name` would let
        # two equal sorts hash differently and break set/dict lookups.
        return hash((self.exp, self.mantissa))

    @property
    def length(self):
        """Total bit width of the sort (exponent width + mantissa width)."""
        return self.exp + self.mantissa

    @staticmethod
    def from_size(n):
        """
        Return the built-in sort whose total width is `n` bits.

        :raises ClaripyOperationError: if `n` is neither 32 nor 64.
        """
        if n == 32:
            return FSORT_FLOAT
        elif n == 64:
            return FSORT_DOUBLE
        else:
            raise ClaripyOperationError(f"{n} is not a valid FSort size")

    @staticmethod
    def from_params(exp, mantissa):
        """
        Return the built-in sort with the given exponent/mantissa widths.

        :raises ClaripyOperationError: if the widths match neither (8, 24) nor
            (11, 53).
        """
        if exp == 8 and mantissa == 24:
            return FSORT_FLOAT
        elif exp == 11 and mantissa == 53:
            return FSORT_DOUBLE
        else:
            raise ClaripyOperationError("unrecognized FSort params")
# The two supported sorts: IEEE 754 single (32-bit) and double (64-bit) precision.
FSORT_FLOAT = FSort(name="FLOAT", exp=8, mantissa=24)
FSORT_DOUBLE = FSort(name="DOUBLE", exp=11, mantissa=53)
class FPV(BackendObject):
    """
    A concrete floating point value of sort FSORT_FLOAT or FSORT_DOUBLE.

    The value is stored as a Python float (double precision) for either sort,
    and arithmetic is performed directly on those Python floats; no rounding
    mode is consulted here.
    """

    __slots__ = ["sort", "value"]

    def __init__(self, value, sort):
        if not isinstance(value, float) or sort not in {FSORT_FLOAT, FSORT_DOUBLE}:
            raise ClaripyOperationError("FPV needs a sort (FSORT_FLOAT or FSORT_DOUBLE) and a float value")
        self.value = value
        self.sort = sort

    def __hash__(self):
        return hash((self.value, self.sort))

    def __getstate__(self):
        return self.value, self.sort

    def __setstate__(self, st):
        self.value, self.sort = st

    def __abs__(self):
        return FPV(abs(self.value), self.sort)

    def __neg__(self):
        return FPV(-self.value, self.sort)

    @staticmethod
    def _ieee_div(n, d):
        """
        Divide two Python floats, following IEEE 754 for a zero divisor.

        Python raises ZeroDivisionError for any float division by (signed)
        zero. IEEE 754 instead yields NaN for 0/0 and NaN/0, and otherwise an
        infinity whose sign is the product of the operand signs.
        """
        try:
            return n / d
        except ZeroDivisionError:
            if n == 0.0 or math.isnan(n):
                return math.nan
            return math.copysign(math.inf, math.copysign(1.0, n) * math.copysign(1.0, d))

    def fpSqrt(self):
        """
        Return the square root of this value with the same sort.

        math.sqrt raises ValueError for arguments below zero (including -inf);
        IEEE 754 defines that result as NaN, so NaN is returned instead.
        """
        try:
            root = math.sqrt(self.value)
        except ValueError:
            root = math.nan
        return FPV(root, self.sort)

    @normalize_types
    @compare_sorts
    def __add__(self, o):
        return FPV(self.value + o.value, self.sort)

    @normalize_types
    @compare_sorts
    def __sub__(self, o):
        return FPV(self.value - o.value, self.sort)

    @normalize_types
    @compare_sorts
    def __mul__(self, o):
        return FPV(self.value * o.value, self.sort)

    @normalize_types
    @compare_sorts
    def __mod__(self, o):
        # NOTE(review): Python's float % raises ZeroDivisionError for a zero
        # divisor and takes the divisor's sign; confirm this matches the
        # intended remainder semantics.
        return FPV(self.value % o.value, self.sort)

    @normalize_types
    @compare_sorts
    def __truediv__(self, o):
        return FPV(self._ieee_div(self.value, o.value), self.sort)

    def __floordiv__(self, other):  # decline to involve integers in this floating point process
        return self.__truediv__(other)

    #
    # Reverse arithmetic stuff
    #

    @normalize_types
    @compare_sorts
    def __radd__(self, o):
        return FPV(o.value + self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __rsub__(self, o):
        return FPV(o.value - self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __rmul__(self, o):
        return FPV(o.value * self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __rmod__(self, o):
        return FPV(o.value % self.value, self.sort)

    @normalize_types
    @compare_sorts
    def __rtruediv__(self, o):
        return FPV(self._ieee_div(o.value, self.value), self.sort)

    def __rfloordiv__(self, other):  # decline to involve integers in this floating point process
        return self.__rtruediv__(other)

    #
    # Boolean stuff
    #

    @normalize_types
    @compare_sorts
    def __eq__(self, o):
        return self.value == o.value

    @normalize_types
    @compare_sorts
    def __ne__(self, o):
        return self.value != o.value

    @normalize_types
    @compare_sorts
    def __lt__(self, o):
        return self.value < o.value

    @normalize_types
    @compare_sorts
    def __gt__(self, o):
        return self.value > o.value

    @normalize_types
    @compare_sorts
    def __le__(self, o):
        return self.value <= o.value

    @normalize_types
    @compare_sorts
    def __ge__(self, o):
        return self.value >= o.value

    def __repr__(self):
        return f"FPV({self.value:f}, {self.sort})"
def fpToFP(a1, a2, a3=None):
    """
    Returns a FP AST and has three signatures:

        fpToFP(ubvv, sort)
            Returns a FP AST whose IEEE 754 bit pattern is the bit pattern of
            the unsigned BVV `a1` (a reinterpretation, not a numeric
            conversion) and whose sort is `a2`.

        fpToFP(rm, fpv, sort)
            Returns a FP AST whose value is the same as the floating point `a2`
            and whose sort is `a3`.

        fpToFP(rm, sbvv, sort)
            Returns a FP AST whose value is the same as the signed BVV `a2` and
            whose sort is `a3`.

    In the last two forms the rounding mode `rm` is not consulted, and the
    resulting Python float is not re-rounded to the precision of `a3`.

    :raises ClaripyOperationError: for an unrecognized argument combination or
        sort, or when the BVV's value does not fit the sort's bit width.
    """
    if isinstance(a1, BVV) and isinstance(a2, FSort):
        sort = a2
        if sort == FSORT_FLOAT:
            pack, unpack = "I", "f"
        elif sort == FSORT_DOUBLE:
            pack, unpack = "Q", "d"
        else:
            raise ClaripyOperationError("unrecognized float sort")
        try:
            packed = struct.pack("<" + pack, a1.value)
            (unpacked,) = struct.unpack("<" + unpack, packed)
        except OverflowError as e:
            # struct.pack sometimes overflows
            raise ClaripyOperationError("OverflowError: " + str(e)) from e
        except struct.error as e:
            # Integer formats ("I"/"Q") report out-of-range values as
            # struct.error rather than OverflowError, e.g. a wide BVV whose
            # value exceeds 32 bits when converting to FLOAT.
            raise ClaripyOperationError("struct.error: " + str(e)) from e
        return FPV(unpacked, sort)
    elif isinstance(a1, RM) and isinstance(a2, FPV) and isinstance(a3, FSort):
        return FPV(a2.value, a3)
    elif isinstance(a1, RM) and isinstance(a2, BVV) and isinstance(a3, FSort):
        return FPV(float(a2.signed), a3)
    else:
        raise ClaripyOperationError("unknown types passed to fpToFP")
def fpToFPUnsigned(_rm, thing, sort):
    """
    Returns a FP AST whose value is the same as the unsigned BVV `thing` and
    whose sort is `sort`.

    The BVV's integer payload is converted with Python's ``float()``; the
    rounding mode `_rm` is not used.
    """
    as_float = float(thing.value)  # `thing` is expected to be a BVV
    return FPV(as_float, sort)
def fpToIEEEBV(fpv):
    """
    Reinterprets the IEEE 754 bit pattern of the floating point `fpv` as a
    bitvector.

    :param fpv: an FPV whose sort is FSORT_FLOAT or FSORT_DOUBLE.
    :return: A BV AST of width ``fpv.sort.length`` whose bit-pattern is the
        same as `fpv`.
    :raises ClaripyOperationError: if the sort is unrecognized, or if
        ``struct.pack`` overflows (e.g. a value too large for single precision).
    """
    sort = fpv.sort
    if sort == FSORT_FLOAT:
        float_fmt, int_fmt = "<f", "<I"
    elif sort == FSORT_DOUBLE:
        float_fmt, int_fmt = "<d", "<Q"
    else:
        raise ClaripyOperationError("unrecognized float sort")

    try:
        raw_bytes = struct.pack(float_fmt, fpv.value)
        bits = struct.unpack(int_fmt, raw_bytes)[0]
    except OverflowError as err:
        # struct.pack raises OverflowError when the value does not fit the float format
        raise ClaripyOperationError("OverflowError: " + str(err)) from err
    return BVV(bits, sort.length)
def fpFP(sgn, exp, mantissa):
    """
    Concatenates the bitvectors `sgn`, `exp` and `mantissa` and returns the
    corresponding IEEE754 floating point number.

    The sort is chosen from the concatenated width via ``FSort.from_size``
    (32 -> FLOAT, 64 -> DOUBLE).

    :return: A FP AST whose bit-pattern is the same as the concatenated
        bitvector
    :raises ClaripyOperationError: for an unsupported width, or if
        ``struct.pack`` overflows.
    """
    bits = Concat(sgn, exp, mantissa)
    sort = FSort.from_size(bits.size())
    if sort == FSORT_FLOAT:
        int_fmt, float_fmt = "<I", "<f"
    elif sort == FSORT_DOUBLE:
        int_fmt, float_fmt = "<Q", "<d"
    else:
        raise ClaripyOperationError("unrecognized float sort")

    try:
        raw_bytes = struct.pack(int_fmt, bits.value)
        as_float = struct.unpack(float_fmt, raw_bytes)[0]
    except OverflowError as err:
        # struct.pack sometimes overflows
        raise ClaripyOperationError("OverflowError: " + str(err)) from err
    return FPV(as_float, sort)
def fpToSBV(rm, fp, size):
    """
    Rounds the floating point `fp` to an integer using rounding mode `rm` and
    returns it as a bitvector of width `size`.

    NaN and infinite inputs produce a zero bitvector. The rounded integer is
    passed to BVV as-is; no range check against `size` is done here.

    :raises Exception: any unexpected rounding failure is logged and re-raised.
    """
    try:
        rounding_mode = rm.pydecimal_equivalent_rounding_mode()
        val = int(Decimal(fp.value).to_integral_value(rounding_mode))
    except (ValueError, OverflowError):
        # int() raises ValueError for a Decimal NaN and OverflowError for a
        # Decimal infinity.
        val = 0
    except Exception as ex:
        # Report through logging rather than printing to stdout, then propagate.
        logging.getLogger(__name__).error("Unhandled error during floating point rounding! %s", ex)
        raise
    return BVV(val, size)
def fpToUBV(rm, fp, size):
    """
    Rounds the floating point `fp` to an integer using rounding mode `rm` and
    returns it as a bitvector of width `size`.

    NaN and infinite inputs produce a zero bitvector. A negative rounded value
    that fits in `size` bits is stored in two's complement.

    :raises ClaripyOperationError: if the rounded value cannot be represented
        in `size` bits.
    """
    # todo: actually make unsigned (negative values are currently wrapped to two's complement)
    try:
        rounding_mode = rm.pydecimal_equivalent_rounding_mode()
        rounded = int(Decimal(fp.value).to_integral_value(rounding_mode))
    except (ValueError, OverflowError):
        # int() raises ValueError for a Decimal NaN and OverflowError for a
        # Decimal infinity.
        return BVV(0, size)

    limit = 1 << size
    val = rounded + limit if rounded < 0 else rounded
    # An explicit check rather than `assert`: the guard must not vanish under
    # `python -O`, and it has to run after the two's-complement wrap so that
    # wrap is actually reachable.
    if not 0 <= val < limit:
        raise ClaripyOperationError(
            f"Rounding produced values outside the BV range! rounding {fp.value} with rounding mode {rm} produced {rounded}"
        )
    return BVV(val, size)
def fpEQ(a, b):
    """
    Checks if floating point `a` is equal to floating point `b`.

    Evaluates ``a == b``; for FPV operands this compares the underlying float
    values (NaN is unequal to everything) and rejects mismatched sorts.
    """
    equal = a == b
    return equal
def fpNE(a, b):
    """
    Checks if floating point `a` is not equal to floating point `b`.

    Evaluates ``a != b``; a NaN operand makes the result True.
    """
    unequal = a != b
    return unequal
def fpGT(a, b):
    """
    Checks if floating point `a` is greater than floating point `b`.

    Evaluates ``a > b``; any comparison involving NaN is False.
    """
    greater = a > b
    return greater
def fpGEQ(a, b):
    """
    Checks if floating point `a` is greater than or equal to floating point `b`.

    Evaluates ``a >= b``; any comparison involving NaN is False.
    """
    at_least = a >= b
    return at_least
def fpLT(a, b):
    """
    Checks if floating point `a` is less than floating point `b`.

    Evaluates ``a < b``; any comparison involving NaN is False.
    """
    less = a < b
    return less
def fpLEQ(a, b):
    """
    Checks if floating point `a` is less than or equal to floating point `b`.

    Evaluates ``a <= b``; any comparison involving NaN is False.
    """
    at_most = a <= b
    return at_most
def fpAbs(x):
    """
    Returns the absolute value of the floating point `x`, via ``abs(x)``. So:

        a = FPV(-3.2, FSORT_DOUBLE)
        b = fpAbs(a)
        b is FPV(3.2, FSORT_DOUBLE)
    """
    magnitude = abs(x)
    return magnitude
def fpNeg(x):
    """
    Returns the additive inverse of the floating point `x`, via ``-x``. So:

        a = FPV(3.2, FSORT_DOUBLE)
        b = fpNeg(a)
        b is FPV(-3.2, FSORT_DOUBLE)
    """
    return -x
def fpSub(_rm, a, b):
    """
    Returns the subtraction of the floating point `a` by the floating point `b`.

    Evaluates ``a - b``; the rounding mode `_rm` is not used.
    """
    difference = a - b
    return difference
def fpAdd(_rm, a, b):
    """
    Returns the addition of two floating point numbers, `a` and `b`.

    Evaluates ``a + b``; the rounding mode `_rm` is not used.
    """
    total = a + b
    return total
def fpMul(_rm, a, b):
    """
    Returns the multiplication of two floating point numbers, `a` and `b`.

    Evaluates ``a * b``; the rounding mode `_rm` is not used.
    """
    product = a * b
    return product
def fpDiv(_rm, a, b):
    """
    Returns the division of the floating point `a` by the floating point `b`.

    Evaluates ``a / b``; the rounding mode `_rm` is not used.
    """
    quotient = a / b
    return quotient
def fpIsNaN(x):
    """
    Checks whether the argument is a floating point NaN.

    :param x: an FPV or a plain Python float. math.isnan needs a real number,
        so an FPV's ``value`` is unwrapped first.
    """
    value = x.value if isinstance(x, FPV) else x
    return math.isnan(value)
def fpIsInf(x):
    """
    Checks whether the argument is a floating point infinity (of either sign).

    :param x: an FPV or a plain Python float. math.isinf needs a real number,
        so an FPV's ``value`` is unwrapped first.
    """
    value = x.value if isinstance(x, FPV) else x
    return math.isinf(value)
from .bv import BVV, Concat