diff --git a/src/stickler/doc_split/packet_evaluation_metrics.py b/src/stickler/doc_split/packet_evaluation_metrics.py index b466397..0760de7 100644 --- a/src/stickler/doc_split/packet_evaluation_metrics.py +++ b/src/stickler/doc_split/packet_evaluation_metrics.py @@ -134,7 +134,7 @@ def calculate_ordering_score_per_group( """ Calculate Kendall's Tau for each document group. - Single-page groups are excluded (ordering undefined). + Single-page groups are assigned a perfect score of 1.0 (trivially in correct order). Args: data: DataFrame with group_id, page_number, page_number_predicted. @@ -151,6 +151,7 @@ def calculate_ordering_score_per_group( for group_id, group_data in data.groupby("group_id"): if len(group_data) <= 1: + group_scores[group_id] = 1.0 continue tau, _p_value = kendalltau( @@ -163,9 +164,9 @@ def calculate_ordering_score_per_group( def calculate_average_ordering_score(group_scores: Dict[Any, float]) -> float: """ - Mean Kendall's Tau across all multi-page groups. + Mean Kendall's Tau across all groups (single-page groups score 1.0). - Returns 0 if no multi-page groups exist. + Returns 0 if no groups exist. """ if not group_scores: return 0 diff --git a/tests/doc_split/test_packet_evaluation_metrics.py b/tests/doc_split/test_packet_evaluation_metrics.py index 97fd7ae..4c4684d 100644 --- a/tests/doc_split/test_packet_evaluation_metrics.py +++ b/tests/doc_split/test_packet_evaluation_metrics.py @@ -308,16 +308,17 @@ def test_reverse_ordering(self): avg = calculate_average_ordering_score(scores) assert avg == pytest.approx(-1.0, abs=1e-4) - def test_single_page_groups_excluded(self): - """Groups with only 1 page should not appear in ordering scores.""" + def test_single_page_groups_score_perfect(self): + """Single-page groups should receive a perfect ordering score of 1.0.""" data = [ _page("invoice", "inv-01", 1, "invoice", "inv-01", 1), _page("form", "form-01", 2, "form", "form-01", 2), ] df = pd.DataFrame(data) scores = calculate_ordering_score_per_group(df) - assert len(scores) == 0 - assert calculate_average_ordering_score(scores) == 0 + assert len(scores) == 2 + assert all(v == 1.0 for v in scores.values()) + assert calculate_average_ordering_score(scores) == 1.0 def test_missing_columns_raises(self): df = pd.DataFrame([{"foo": 1}])