Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions pyreason/scripts/annotation_functions/annotation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import numba
import numpy as np

import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval

@numba.njit
def _get_weighted_sum(annotations, weights, mode='lower'):
"""
Expand Down Expand Up @@ -49,7 +47,7 @@ def average(annotations, weights):

lower, upper = _check_bound(avg_lower, avg_upper)

return interval.closed(lower, upper)
return (lower, upper)

@numba.njit
def average_lower(annotations, weights):
Expand All @@ -67,7 +65,7 @@ def average_lower(annotations, weights):

lower, upper = _check_bound(avg_lower, max_upper)

return interval.closed(lower, upper)
return (lower, upper)

@numba.njit
def maximum(annotations, weights):
Expand All @@ -82,7 +80,7 @@ def maximum(annotations, weights):

lower, upper = _check_bound(max_lower, max_upper)

return interval.closed(lower, upper)
return (lower, upper)


@numba.njit
Expand All @@ -98,4 +96,4 @@ def minimum(annotations, weights):

lower, upper = _check_bound(min_lower, min_upper)

return interval.closed(lower, upper)
return (lower, upper)
8 changes: 6 additions & 2 deletions pyreason/scripts/interval/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@


class Interval(structref.StructRefProxy):
def __new__(cls, lower, upper, s=False):
return structref.StructRefProxy.__new__(cls, lower, upper, s, lower, upper)
def __new__(cls, lower, upper, s=False, prev_l=None, prev_u=None):
if prev_l is None:
prev_l = lower
if prev_u is None:
prev_u = upper
return structref.StructRefProxy.__new__(cls, lower, upper, s, prev_l, prev_u)

@property
@njit
Expand Down
22 changes: 12 additions & 10 deletions tests/unit/disable_jit/test_annotation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import pyreason.scripts.annotation_functions.annotation_functions as af

af.interval = SimpleNamespace(closed=lambda l, u, static=False: SimpleNamespace(lower=l, upper=u))

def _interval(lower, upper):
return SimpleNamespace(lower=lower, upper=upper)

Expand Down Expand Up @@ -42,23 +40,27 @@ def test_check_bound(lower, upper, expected):
def test_average():
annotations, weights = _example_annotations()
result = af.average(annotations, weights)
assert result.lower == pytest.approx(1.4 / 3)
assert result.upper == pytest.approx(0.6)
assert isinstance(result, tuple), f"expected tuple, got {type(result)}"
assert result[0] == pytest.approx(1.4 / 3)
assert result[1] == pytest.approx(0.6)

def test_average_lower():
annotations, weights = _example_annotations()
result = af.average_lower(annotations, weights)
assert result.lower == pytest.approx(1.4 / 3)
assert result.upper == pytest.approx(0.6)
assert isinstance(result, tuple), f"expected tuple, got {type(result)}"
assert result[0] == pytest.approx(1.4 / 3)
assert result[1] == pytest.approx(0.6)

def test_maximum():
annotations, weights = _example_annotations()
result = af.maximum(annotations, weights)
assert result.lower == pytest.approx(1.0)
assert result.upper == pytest.approx(1.0)
assert isinstance(result, tuple), f"expected tuple, got {type(result)}"
assert result[0] == pytest.approx(1.0)
assert result[1] == pytest.approx(1.0)

def test_minimum():
annotations, weights = _example_annotations()
result = af.minimum(annotations, weights)
assert result.lower == pytest.approx(0.4)
assert result.upper == pytest.approx(0.6)
assert isinstance(result, tuple), f"expected tuple, got {type(result)}"
assert result[0] == pytest.approx(0.4)
assert result[1] == pytest.approx(0.6)