Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/stickler/doc_split/packet_evaluation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def calculate_ordering_score_per_group(
"""
Calculate Kendall's Tau for each document group.

Single-page groups are excluded (ordering undefined).
Single-page groups are assigned a perfect score of 1.0 (trivially in correct order).

Args:
data: DataFrame with group_id, page_number, page_number_predicted.
Expand All @@ -151,6 +151,7 @@ def calculate_ordering_score_per_group(

for group_id, group_data in data.groupby("group_id"):
if len(group_data) <= 1:
group_scores[group_id] = 1.0
continue

tau, _p_value = kendalltau(
Expand All @@ -163,9 +164,9 @@ def calculate_ordering_score_per_group(

def calculate_average_ordering_score(group_scores: Dict[Any, float]) -> float:
"""
Mean Kendall's Tau across all multi-page groups.
Mean Kendall's Tau across all groups (single-page groups score 1.0).

Returns 0 if no multi-page groups exist.
Returns 0 if no groups exist.
"""
if not group_scores:
return 0
Expand Down
9 changes: 5 additions & 4 deletions tests/doc_split/test_packet_evaluation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,16 +308,17 @@ def test_reverse_ordering(self):
avg = calculate_average_ordering_score(scores)
assert avg == pytest.approx(-1.0, abs=1e-4)

def test_single_page_groups_excluded(self):
"""Groups with only 1 page should not appear in ordering scores."""
def test_single_page_groups_score_perfect(self):
"""Single-page groups should receive a perfect ordering score of 1.0."""
data = [
_page("invoice", "inv-01", 1, "invoice", "inv-01", 1),
_page("form", "form-01", 2, "form", "form-01", 2),
]
df = pd.DataFrame(data)
scores = calculate_ordering_score_per_group(df)
assert len(scores) == 0
assert calculate_average_ordering_score(scores) == 0
assert len(scores) == 2
assert all(v == 1.0 for v in scores.values())
assert calculate_average_ordering_score(scores) == 1.0

def test_missing_columns_raises(self):
df = pd.DataFrame([{"foo": 1}])
Expand Down
Loading