diff --git a/CHANGELOG.md b/CHANGELOG.md index d9931d0d0e..e032538629 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - `SHAPInsight` breaking with `numpy>=2.4` due to no longer accepted implicit array to scalar conversion +- Using `np.isclose` for assessing equality of interval bounds instead of hard equality + check ### Removed - `parallel_runs` argument from `simulate_scenarios`, since parallelization diff --git a/baybe/utils/interval.py b/baybe/utils/interval.py index 4b1f4c0dca..2945c18dcc 100644 --- a/baybe/utils/interval.py +++ b/baybe/utils/interval.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Union import numpy as np -from attrs import define, field +from attrs import cmp_using, define, field from baybe.serialization import SerialMixin, converter from baybe.settings import active_settings @@ -39,6 +39,7 @@ class Interval(SerialMixin): default=float("-inf"), converter=lambda x: float("-inf") if x is None else float(x), validator=non_nan_float, + eq=cmp_using(eq=lambda a, b: bool(np.isclose(a, b))), ) """The lower end of the interval.""" @@ -46,6 +47,7 @@ class Interval(SerialMixin): default=float("inf"), converter=lambda x: float("inf") if x is None else float(x), validator=non_nan_float, + eq=cmp_using(eq=lambda a, b: bool(np.isclose(a, b))), ) """The upper end of the interval.""" diff --git a/tests/validation/test_interval_validation.py b/tests/validation/test_interval_validation.py index d232273758..5636f50466 100644 --- a/tests/validation/test_interval_validation.py +++ b/tests/validation/test_interval_validation.py @@ -24,3 +24,18 @@ def test_invalid_range(request, bounds): return with pytest.raises(ValueError): Interval(*bounds[::-1]) + + +@pytest.mark.parametrize( + ("other", "expected"), + [ + param(Interval(0, 1), True, id="exact_match"), + param(Interval(0, 0.9999999999999999), True, id="upper_float_imprecision"), + param(Interval(1e-16, 1 - 1e-16), True, id="both_float_imprecision"), + param(Interval(0, 0.5), False, id="different_upper"), + param(Interval(0.5, 1), False, id="different_lower"), + ], +) +def test_close_interval_bounds(other, expected): + """Intervals that are close up to floating-point precision are detected.""" + assert (Interval(0, 1) == other) == expected