Skip to content

Commit

Permalink
use a single clamper for floats
Browse files Browse the repository at this point in the history
  • Loading branch information
tybug committed Dec 21, 2024
1 parent 904bdd9 commit de1b32a
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 77 deletions.
3 changes: 3 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
RELEASE_TYPE: patch

This patch cleans up some internal code around clamping floats.
30 changes: 8 additions & 22 deletions hypothesis-python/src/hypothesis/internal/conjecture/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1576,8 +1576,7 @@ def draw_float(
(
sampler,
forced_sign_bit,
neg_clamper,
pos_clamper,
clamper,
nasty_floats,
) = self._draw_float_init_logic(
min_value=min_value,
Expand Down Expand Up @@ -1608,12 +1607,8 @@ def draw_float(
)
if allow_nan and math.isnan(result):
clamped = result
elif math.copysign(1.0, result) == -1:
assert neg_clamper is not None
clamped = -neg_clamper(-result)
else:
assert pos_clamper is not None
clamped = pos_clamper(result)
clamped = clamper(result)
if clamped != result and not (math.isnan(result) and allow_nan):
self._draw_float(forced=clamped, fake_forced=fake_forced)
result = clamped
Expand Down Expand Up @@ -1966,23 +1961,14 @@ def permitted(f: float) -> bool:
weights = [0.2 * len(nasty_floats)] + [0.8] * len(nasty_floats)
sampler = Sampler(weights, observe=False) if nasty_floats else None

pos_clamper = neg_clamper = None
if sign_aware_lte(0.0, max_value):
pos_min = max(min_value, smallest_nonzero_magnitude)
allow_zero = sign_aware_lte(min_value, 0.0)
pos_clamper = make_float_clamper(pos_min, max_value, allow_zero=allow_zero)
if sign_aware_lte(min_value, -0.0):
neg_max = min(max_value, -smallest_nonzero_magnitude)
allow_zero = sign_aware_lte(-0.0, max_value)
neg_clamper = make_float_clamper(
-neg_max, -min_value, allow_zero=allow_zero
)

forced_sign_bit: Optional[Literal[0, 1]] = None
if (pos_clamper is None) != (neg_clamper is None):
forced_sign_bit = 1 if neg_clamper else 0
if sign_aware_lte(min_value, -0.0) != sign_aware_lte(0.0, max_value):
forced_sign_bit = 1 if sign_aware_lte(min_value, -0.0) else 0

return (sampler, forced_sign_bit, neg_clamper, pos_clamper, nasty_floats)
clamper = make_float_clamper(
min_value, max_value, smallest_nonzero_magnitude, allow_nan
)
return (sampler, forced_sign_bit, clamper, nasty_floats)


# The set of available `PrimitiveProvider`s, by name. Other libraries, such as
Expand Down
59 changes: 36 additions & 23 deletions hypothesis-python/src/hypothesis/internal/floats.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from sys import float_info
from typing import TYPE_CHECKING, Callable, Literal, Optional, SupportsFloat, Union

from hypothesis.internal.conjecture.junkdrawer import clamp

if TYPE_CHECKING:
from typing import TypeAlias
else:
Expand Down Expand Up @@ -136,38 +138,49 @@ def next_up_normal(value: float, width: int, *, allow_subnormal: bool) -> float:
}
assert width_smallest_normals[64] == float_info.min

mantissa_mask = (1 << 52) - 1


def make_float_clamper(
min_float: float = 0.0,
max_float: float = math.inf,
*,
allow_zero: bool = False, # Allows +0.0 (even if minfloat > 0)
min_value: float,
max_value: float,
smallest_nonzero_magnitude: float,
allow_nan: bool,
) -> Optional[Callable[[float], float]]:
"""
Return a function that clamps positive floats into the given bounds.
Returns None when no values are allowed (min > max and zero is not allowed).
"""
if max_float < min_float:
if allow_zero:
min_float = max_float = 0.0
else:
return None

range_size = min(max_float - min_float, float_info.max)
mantissa_mask = (1 << 52) - 1

def float_clamper(float_val: float) -> float:
if min_float <= float_val <= max_float:
return float_val
if float_val == 0.0 and allow_zero:
return float_val
assert sign_aware_lte(min_value, max_value)
range_size = min(max_value - min_value, float_info.max)

def permitted(f: float) -> bool:
if math.isnan(f):
return allow_nan
if 0 < abs(f) < smallest_nonzero_magnitude:
return False
return sign_aware_lte(min_value, f) and sign_aware_lte(f, max_value)

def float_clamper(f: float) -> float:
if permitted(f):
return f
# Outside bounds; pick a new value, sampled from the allowed range,
# using the mantissa bits.
mant = float_to_int(float_val) & mantissa_mask
float_val = min_float + range_size * (mant / mantissa_mask)
mant = float_to_int(abs(f)) & mantissa_mask
f = min_value + range_size * (mant / mantissa_mask)

# if we resampled into the space disallowed by smallest_nonzero_magnitude,
# default to smallest_nonzero_magnitude.
if 0 < abs(f) < smallest_nonzero_magnitude:
f = smallest_nonzero_magnitude
# we must have either -smallest_nonzero_magnitude <= min_value or
# smallest_nonzero_magnitude >= max_value, or no values would be
# possible. If smallest_nonzero_magnitude is not valid (because it's
# larger than max_value), then -smallest_nonzero_magnitude must be valid.
if smallest_nonzero_magnitude > max_value:
f *= -1

# Re-enforce the bounds (just in case of floating point arithmetic error)
return max(min_float, min(max_float, float_val))
return clamp(min_value, f, max_value)

return float_clamper

Expand Down
61 changes: 29 additions & 32 deletions hypothesis-python/tests/cover/test_float_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@

import pytest

from hypothesis import example, given, strategies as st
from hypothesis import assume, example, given, strategies as st
from hypothesis.internal.floats import (
SMALLEST_SUBNORMAL,
count_between_floats,
make_float_clamper,
next_down,
next_up,
sign_aware_lte,
)


Expand All @@ -44,40 +46,35 @@ def test_next_float_equal(func, val):
assert func(val) == val


# invalid order -> clamper is None:
@example(2.0, 1.0, 3.0)
# exponent comparisons:
@example(1, float_info.max, 0)
@example(1, float_info.max, 1)
@example(1, float_info.max, 10)
@example(1, float_info.max, float_info.max)
@example(1, float_info.max, math.inf)
@example(1, float_info.max, 0, True)
@example(1, float_info.max, 1, True)
@example(1, float_info.max, 10, True)
@example(1, float_info.max, float_info.max, True)
@example(1, float_info.max, math.inf, True)
# mantissa comparisons:
@example(100.0001, 100.0003, 100.0001)
@example(100.0001, 100.0003, 100.0002)
@example(100.0001, 100.0003, 100.0003)
@given(st.floats(min_value=0), st.floats(min_value=0), st.floats(min_value=0))
def test_float_clamper(min_value, max_value, input_value):
clamper = make_float_clamper(min_value, max_value, allow_zero=False)
if max_value < min_value:
assert clamper is None
return
@example(100.0001, 100.0003, 100.0001, True)
@example(100.0001, 100.0003, 100.0002, True)
@example(100.0001, 100.0003, 100.0003, True)
@example(100.0001, 100.0003, math.nan, False)
@example(0, 10, math.nan, False)
@example(0, 10, math.nan, True)
@given(st.floats(), st.floats(), st.floats(), st.booleans())
def test_float_clamper(min_value, max_value, input_value, allow_nan):
assume(sign_aware_lte(min_value, max_value))

clamper = make_float_clamper(min_value, max_value, SMALLEST_SUBNORMAL, allow_nan)
clamped = clamper(input_value)
if min_value <= input_value <= max_value:
assert input_value == clamped
if math.isnan(clamped):
# we should only clamp to nan if nans are allowed.
assert allow_nan
else:
assert min_value <= clamped <= max_value

# otherwise, we should have clamped to something in the permitted range.
assert sign_aware_lte(min_value, clamped) and sign_aware_lte(clamped, max_value)

@example(0.01, math.inf, 0.0)
@given(st.floats(min_value=0), st.floats(min_value=0), st.floats(min_value=0))
def test_float_clamper_with_allowed_zeros(min_value, max_value, input_value):
clamper = make_float_clamper(min_value, max_value, allow_zero=True)
assert clamper is not None
clamped = clamper(input_value)
if input_value == 0.0 or max_value < min_value:
assert clamped == 0.0
elif min_value <= input_value <= max_value:
# if input_value was permitted in the first place, then the clamped value should
# be the same as the input value.
if sign_aware_lte(min_value, input_value) and sign_aware_lte(
input_value, max_value
):
assert input_value == clamped
else:
assert min_value <= clamped <= max_value

0 comments on commit de1b32a

Please sign in to comment.