diff --git a/pyreason/scripts/annotation_functions/annotation_functions.py b/pyreason/scripts/annotation_functions/annotation_functions.py index 75eb9b6f..e928910d 100755 --- a/pyreason/scripts/annotation_functions/annotation_functions.py +++ b/pyreason/scripts/annotation_functions/annotation_functions.py @@ -3,8 +3,6 @@ import numba import numpy as np -import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval - @numba.njit def _get_weighted_sum(annotations, weights, mode='lower'): """ @@ -49,7 +47,7 @@ def average(annotations, weights): lower, upper = _check_bound(avg_lower, avg_upper) - return interval.closed(lower, upper) + return (lower, upper) @numba.njit def average_lower(annotations, weights): @@ -67,7 +65,7 @@ def average_lower(annotations, weights): lower, upper = _check_bound(avg_lower, max_upper) - return interval.closed(lower, upper) + return (lower, upper) @numba.njit def maximum(annotations, weights): @@ -82,7 +80,7 @@ def maximum(annotations, weights): lower, upper = _check_bound(max_lower, max_upper) - return interval.closed(lower, upper) + return (lower, upper) @numba.njit @@ -98,4 +96,4 @@ def minimum(annotations, weights): lower, upper = _check_bound(min_lower, min_upper) - return interval.closed(lower, upper) + return (lower, upper) diff --git a/pyreason/scripts/interval/interval.py b/pyreason/scripts/interval/interval.py index 18d25ede..66b7bc11 100755 --- a/pyreason/scripts/interval/interval.py +++ b/pyreason/scripts/interval/interval.py @@ -4,8 +4,12 @@ class Interval(structref.StructRefProxy): - def __new__(cls, lower, upper, s=False): - return structref.StructRefProxy.__new__(cls, lower, upper, s, lower, upper) + def __new__(cls, lower, upper, s=False, prev_l=None, prev_u=None): + if prev_l is None: + prev_l = lower + if prev_u is None: + prev_u = upper + return structref.StructRefProxy.__new__(cls, lower, upper, s, prev_l, prev_u) @property @njit diff --git a/tests/unit/disable_jit/test_annotation_functions.py b/tests/unit/disable_jit/test_annotation_functions.py index fcb94c0b..ac887827 100644 --- a/tests/unit/disable_jit/test_annotation_functions.py +++ b/tests/unit/disable_jit/test_annotation_functions.py @@ -4,8 +4,6 @@ import pyreason.scripts.annotation_functions.annotation_functions as af -af.interval = SimpleNamespace(closed=lambda l, u, static=False: SimpleNamespace(lower=l, upper=u)) - def _interval(lower, upper): return SimpleNamespace(lower=lower, upper=upper) @@ -42,23 +40,27 @@ def test_check_bound(lower, upper, expected): def test_average(): annotations, weights = _example_annotations() result = af.average(annotations, weights) - assert result.lower == pytest.approx(1.4 / 3) - assert result.upper == pytest.approx(0.6) + assert isinstance(result, tuple), f"expected tuple, got {type(result)}" + assert result[0] == pytest.approx(1.4 / 3) + assert result[1] == pytest.approx(0.6) def test_average_lower(): annotations, weights = _example_annotations() result = af.average_lower(annotations, weights) - assert result.lower == pytest.approx(1.4 / 3) - assert result.upper == pytest.approx(0.6) + assert isinstance(result, tuple), f"expected tuple, got {type(result)}" + assert result[0] == pytest.approx(1.4 / 3) + assert result[1] == pytest.approx(0.6) def test_maximum(): annotations, weights = _example_annotations() result = af.maximum(annotations, weights) - assert result.lower == pytest.approx(1.0) - assert result.upper == pytest.approx(1.0) + assert isinstance(result, tuple), f"expected tuple, got {type(result)}" + assert result[0] == pytest.approx(1.0) + assert result[1] == pytest.approx(1.0) def test_minimum(): annotations, weights = _example_annotations() result = af.minimum(annotations, weights) - assert result.lower == pytest.approx(0.4) - assert result.upper == pytest.approx(0.6) + assert isinstance(result, tuple), f"expected tuple, got {type(result)}" + assert result[0] == pytest.approx(0.4) + assert result[1] == pytest.approx(0.6)