From 2bcfb9941d3e91e5d4a6798a7f535aa1b40a0883 Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 03:21:08 +0000 Subject: [PATCH 1/7] fix: restore Arrow nullable flags lost in Polars round-trips (ENG-375) Polars converts all Arrow fields to nullable=True when producing its Arrow output, corrupting schema intent for non-optional columns. Add `arrow_utils.restore_schema_nullability()` which reinstates the original nullable flag for each field by name using the reference schemas captured before any Polars operation. This is schema-intent-based (unlike `infer_schema_nullable` which is data-based and would incorrectly mark Optional columns as non-nullable when they happen to have no nulls). Apply the fix at all Polars round-trip sites: - FunctionNode: all 5 join sites and the as_table sort - Join.static_process: the per-iteration inner join loop (removes the previous infer_schema_nullable workaround) Tests added (red-green-refactor): - Unit tests for restore_schema_nullability including the Optional[T] correctness case that infer_schema_nullable fails - Integration tests for FunctionNode.get_all_records, FunctionNode._iter_all_from_database, and Join.op_forward Co-Authored-By: Claude Sonnet 4.6 --- src/orcapod/core/nodes/function_node.py | 17 ++ src/orcapod/core/operators/join.py | 20 +- src/orcapod/utils/arrow_utils.py | 71 ++++++ .../test_polars_nullability/__init__.py | 0 .../test_function_node_nullability.py | 237 ++++++++++++++++++ .../test_restore_schema_nullability.py | 186 ++++++++++++++ 6 files changed, 524 insertions(+), 7 deletions(-) create mode 100644 tests/test_data/test_polars_nullability/__init__.py create mode 100644 tests/test_data/test_polars_nullability/test_function_node_nullability.py create mode 100644 tests/test_data/test_polars_nullability/test_restore_schema_nullability.py diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py 
index e5799862..5b61218d 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -686,6 +686,8 @@ def get_cached_results( if taginfo is None or results is None: return {} + taginfo_schema = taginfo.schema + results_schema = results.schema joined = ( pl.DataFrame(taginfo) .join( @@ -695,6 +697,7 @@ def get_cached_results( ) .to_arrow() ) + joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema) if joined.num_rows == 0: return {} @@ -920,11 +923,14 @@ def get_all_records( if results is None or taginfo is None: return None + taginfo_schema = taginfo.schema + results_schema = results.schema joined = ( pl.DataFrame(taginfo) .join(pl.DataFrame(results), on=constants.PACKET_RECORD_ID, how="inner") .to_arrow() ) + joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema) column_config = ColumnConfig.handle_config(columns, all_info=all_info) @@ -1000,6 +1006,8 @@ def _load_all_cached_records( if taginfo is None or results is None: return None + taginfo_schema = taginfo.schema + results_schema = results.schema joined = ( pl.DataFrame(taginfo) .join( @@ -1009,6 +1017,7 @@ def _load_all_cached_records( ) .to_arrow() ) + joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema) if joined.num_rows == 0: return None @@ -1124,6 +1133,8 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: ) if taginfo is not None and results is not None: + taginfo_schema = taginfo.schema + results_schema = results.schema joined = ( pl.DataFrame(taginfo) .join( @@ -1133,6 +1144,7 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: ) .to_arrow() ) + joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema) if joined.num_rows > 0: tag_keys = self._input_stream.keys()[0] # Collect pipeline entry_ids for Phase 2 skip check @@ -1356,11 +1368,13 @@ def as_table( ) if column_config.sort_by_tags: + 
output_table_schema = output_table.schema output_table = ( pl.DataFrame(output_table) .sort(by=self.keys()[0], descending=False) .to_arrow() ) + output_table = arrow_utils.restore_schema_nullability(output_table, output_table_schema) return output_table # ------------------------------------------------------------------ @@ -1444,6 +1458,8 @@ async def async_execute( ) if taginfo is not None and results is not None: + taginfo_schema = taginfo.schema + results_schema = results.schema joined = ( pl.DataFrame(taginfo) .join( @@ -1453,6 +1469,7 @@ async def async_execute( ) .to_arrow() ) + joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema) if joined.num_rows > 0: tag_keys = self._input_stream.keys()[0] entry_ids_col = joined.column( diff --git a/src/orcapod/core/operators/join.py b/src/orcapod/core/operators/join.py index f4ddf479..17768014 100644 --- a/src/orcapod/core/operators/join.py +++ b/src/orcapod/core/operators/join.py @@ -171,17 +171,30 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol: counter += 1 new_name = f"{col}_{counter}" rename_map[col] = new_name + + # Build a reference schema for next_table with rename_map applied to + # field names, preserving nullable flags — must be done BEFORE the + # Polars rename call below, which loses all nullability information. + next_ref_schema = pa.schema([ + pa.field(rename_map.get(f.name, f.name), f.type, nullable=f.nullable, metadata=f.metadata) + for f in next_table.schema + ]) + if rename_map: next_table = pl.DataFrame(next_table).rename(rename_map).to_arrow() common_tag_keys = tag_keys.intersection(next_tag_keys) common_tag_keys.add(COMMON_JOIN_KEY) + # Capture the left-side schema before the Polars join, which sets all + # fields to nullable=True regardless of the original schema. 
+ table_ref_schema = table.schema table = ( pl.DataFrame(table) .join(pl.DataFrame(next_table), on=list(common_tag_keys), how="inner") .to_arrow() ) + table = arrow_utils.restore_schema_nullability(table, table_ref_schema, next_ref_schema) tag_keys.update(next_tag_keys) @@ -196,13 +209,6 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol: reordered_columns += [col for col in table.column_names if col not in tag_keys] result_table = table.select(reordered_columns) - # Derive nullable per column from actual null counts so that: - # - Columns with no nulls (e.g. tag/packet fields after inner join) get - # nullable=False, avoiding spurious T | None from Polars' all-True default. - # - Columns that genuinely contain nulls (e.g. Optional fields, or source - # info columns after cross-stream joins) keep nullable=True, preventing - # cast failures. - result_table = result_table.cast(arrow_utils.infer_schema_nullable(result_table)) return ArrowTableStream( result_table, tag_columns=tuple(tag_keys), diff --git a/src/orcapod/utils/arrow_utils.py b/src/orcapod/utils/arrow_utils.py index 9307cf0d..05ec455a 100644 --- a/src/orcapod/utils/arrow_utils.py +++ b/src/orcapod/utils/arrow_utils.py @@ -831,6 +831,77 @@ def infer_schema_nullable(table: "pa.Table") -> "pa.Schema": ) +def restore_schema_nullability( + table: "pa.Table", + *reference_schemas: "pa.Schema", +) -> "pa.Table": + """Restore nullable flags lost during an Arrow → Polars → Arrow round-trip. + + Polars converts all Arrow fields to ``nullable=True`` when producing its + Arrow output. This function repairs the resulting table by looking up each + field name in the supplied reference schemas and reinstating the original + ``nullable`` flag (and type) from those references. + + Fields that are **not** found in any reference schema are left exactly as + Polars produced them (``nullable=True``), which is safe for internally + generated sentinel columns (e.g. ``_exists``, ``__pipeline_entry_id``). 
+ + When the same field name appears in multiple reference schemas the + **last-supplied** schema wins, so callers can pass ``(left_schema, + right_schema)`` and get right-side nullability for join-key columns that + appear in both. + + Args: + table: Arrow table produced by a Polars round-trip (all fields + ``nullable=True``). + *reference_schemas: One or more Arrow schemas carrying the original + nullable flags. Pass the schemas of every table that participated + in the Polars join/sort so that all columns are covered. + + Returns: + A new Arrow table cast to the restored schema. Data is unchanged; + only the schema metadata (nullability, field type, field metadata) + is corrected. + + Raises: + pyarrow.ArrowInvalid: If a restored field type is incompatible with + the actual column data (should not happen for well-formed + round-trips, but surfaced to the caller rather than silently + discarded). + + Example:: + + taginfo_schema = taginfo.schema + results_schema = results.schema + joined = ( + pl.DataFrame(taginfo) + .join(pl.DataFrame(results), on="record_id", how="inner") + .to_arrow() + ) + joined = arrow_utils.restore_schema_nullability( + joined, taginfo_schema, results_schema + ) + """ + # Build a name → Field lookup; later schemas override earlier ones. 
+ field_lookup: dict[str, "pa.Field"] = {} + for schema in reference_schemas: + for field in schema: + field_lookup[field.name] = field + + restored_fields = [] + for field in table.schema: + ref = field_lookup.get(field.name) + if ref is not None: + restored_fields.append( + pa.field(field.name, ref.type, nullable=ref.nullable, metadata=ref.metadata) + ) + else: + restored_fields.append(field) + + restored_schema = pa.schema(restored_fields, metadata=table.schema.metadata) + return table.cast(restored_schema) + + def drop_columns_with_prefix( table: "pa.Table", prefix: str | tuple[str, ...], diff --git a/tests/test_data/test_polars_nullability/__init__.py b/tests/test_data/test_polars_nullability/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_data/test_polars_nullability/test_function_node_nullability.py b/tests/test_data/test_polars_nullability/test_function_node_nullability.py new file mode 100644 index 00000000..2f50d93d --- /dev/null +++ b/tests/test_data/test_polars_nullability/test_function_node_nullability.py @@ -0,0 +1,237 @@ +""" +Integration tests: FunctionNode and Join preserve non-nullable column constraints +after the Arrow → Polars → Arrow round-trip that occurs during joins. + +RED phase: tests should fail before the fix is applied. 
+""" + +import pyarrow as pa +import pytest + +import orcapod as op +from orcapod.core.nodes.function_node import FunctionNode +from orcapod.databases import InMemoryArrowDatabase + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_function_nodes(pipeline: op.Pipeline) -> list[FunctionNode]: + """Return all FunctionNode instances from compiled pipeline nodes.""" + return [n for n in pipeline.compiled_nodes.values() if isinstance(n, FunctionNode)] + + +# --------------------------------------------------------------------------- +# FunctionNode.get_all_records nullability +# --------------------------------------------------------------------------- + + +class TestFunctionNodeGetAllRecordsNullability: + """FunctionNode.get_all_records must preserve the non-nullable schema of + output columns whose Python type annotation is non-optional (e.g. ``int``).""" + + def test_non_optional_return_type_yields_non_nullable_output_column(self): + """Output column from int return type must be non-nullable after get_all_records.""" + database = InMemoryArrowDatabase() + source = op.sources.DictSource( + [{"id": 1, "x": 10}, {"id": 2, "x": 20}], + tag_columns=["id"], + ) + + @op.function_pod(output_keys=["result"]) + def double(x: int) -> int: + return x * 2 + + pipeline = op.Pipeline("test_fn_nullable", database) + with pipeline: + double.pod(source) + + pipeline.run() + + fn_nodes = _get_function_nodes(pipeline) + assert len(fn_nodes) == 1, "Expected exactly one FunctionNode" + fn_node = fn_nodes[0] + + table = fn_node.get_all_records() + assert table is not None, "get_all_records() returned None after pipeline.run()" + + # "result" column has Python type int → nullable=False in Arrow schema + result_field = table.schema.field("result") + assert result_field.nullable is False, ( + f"Expected 'result' column (int return type) to be non-nullable, " 
+ f"but got nullable={result_field.nullable}. " + "Arrow→Polars→Arrow round-trip in get_all_records() dropped nullability." + ) + + def test_input_tag_column_non_nullable_after_get_all_records(self): + """Input tag columns that are non-nullable must remain so after Polars join.""" + database = InMemoryArrowDatabase() + + # DictSource with integer id — infer_schema_nullable sets nullable=False (no nulls) + source = op.sources.DictSource( + [{"id": 1, "x": 5}, {"id": 2, "x": 15}], + tag_columns=["id"], + ) + + @op.function_pod(output_keys=["result"]) + def triple(x: int) -> int: + return x * 3 + + pipeline = op.Pipeline("test_fn_tag_nullable", database) + with pipeline: + triple.pod(source) + + pipeline.run() + + fn_nodes = _get_function_nodes(pipeline) + fn_node = fn_nodes[0] + + table = fn_node.get_all_records() + assert table is not None + + # "id" was non-nullable in the source; after the Polars join it must stay so + id_field = table.schema.field("id") + assert id_field.nullable is False, ( + f"Expected 'id' tag column to be non-nullable after get_all_records(), " + f"but got nullable={id_field.nullable}." 
+ ) + + +# --------------------------------------------------------------------------- +# FunctionNode.iter_packets nullability +# --------------------------------------------------------------------------- + + +class TestFunctionNodeIterPacketsNullability: + """FunctionNode.iter_packets must yield packets whose underlying Arrow schema + preserves non-nullable column constraints.""" + + def test_iter_packets_from_database_preserves_non_nullable_output(self): + """Packets loaded from DB via iter_packets carry non-nullable output schema.""" + database = InMemoryArrowDatabase() + source = op.sources.DictSource( + [{"id": 1, "x": 7}], + tag_columns=["id"], + ) + + @op.function_pod(output_keys=["result"]) + def add_one(x: int) -> int: + return x + 1 + + pipeline = op.Pipeline("test_iter_packets_nullable", database) + with pipeline: + add_one.pod(source) + + pipeline.run() + + fn_nodes = _get_function_nodes(pipeline) + fn_node = fn_nodes[0] + + # Force a DB-backed iteration by going through _iter_all_from_database + # (simulates the CACHE_ONLY path used after save/load) + packets_seen = list(fn_node._iter_all_from_database()) + assert len(packets_seen) == 1, "Expected one packet from the database" + + _tag, packet = packets_seen[0] + packet_schema = packet.arrow_schema() + + result_field = packet_schema.field("result") + assert result_field.nullable is False, ( + f"Packet 'result' field should be non-nullable (int return type), " + f"but got nullable={result_field.nullable}. " + "Arrow→Polars→Arrow round-trip in iter_packets dropped nullability." 
+ ) + + +# --------------------------------------------------------------------------- +# Join operator nullability +# --------------------------------------------------------------------------- + + +class TestJoinOperatorNullability: + """Join.op_forward must preserve non-nullable tag column flags through the + Polars inner join it uses internally.""" + + def test_join_preserves_non_nullable_shared_tag_column(self): + """Shared tag column remains non-nullable after stream join.""" + # DictSource applies infer_schema_nullable → integer 'id' has nullable=False + source1 = op.sources.DictSource( + [{"id": 1, "x": 10}, {"id": 2, "x": 20}], + tag_columns=["id"], + ) + source2 = op.sources.DictSource( + [{"id": 1, "y": 100}, {"id": 2, "y": 200}], + tag_columns=["id"], + ) + + joined_stream = source1.join(source2) + table = joined_stream.as_table() + + id_field = table.schema.field("id") + assert id_field.nullable is False, ( + f"Expected 'id' tag column to be non-nullable after Join, " + f"but got nullable={id_field.nullable}. " + "Arrow→Polars→Arrow round-trip in Join.op_forward dropped nullability." + ) + + def test_join_preserves_non_nullable_packet_columns(self): + """Packet columns that are non-nullable remain so after stream join.""" + source1 = op.sources.DictSource( + [{"id": 1, "x": 10}, {"id": 2, "x": 20}], + tag_columns=["id"], + ) + source2 = op.sources.DictSource( + [{"id": 1, "y": 100}, {"id": 2, "y": 200}], + tag_columns=["id"], + ) + + joined_stream = source1.join(source2) + table = joined_stream.as_table() + + # "x" and "y" are packet columns with integer values → non-nullable + x_field = table.schema.field("x") + y_field = table.schema.field("y") + + assert x_field.nullable is False, ( + f"Expected 'x' packet column to be non-nullable after Join, " + f"but got nullable={x_field.nullable}." + ) + assert y_field.nullable is False, ( + f"Expected 'y' packet column to be non-nullable after Join, " + f"but got nullable={y_field.nullable}." 
+ ) + + def test_join_preserves_nullable_optional_column_with_no_nulls(self): + """Optional[int] packet column (nullable=True) must remain nullable=True after + Join, even when the data contains no actual null values. + + infer_schema_nullable incorrectly marks it nullable=False because it sees no + nulls in the data. restore_schema_nullability preserves schema intent. + """ + # Explicitly declare "x" as nullable=True (Optional[int]) via pa.Schema + schema1 = pa.schema([ + pa.field("id", pa.int64(), nullable=False), + pa.field("x", pa.int64(), nullable=True), # Optional — no actual nulls + ]) + source1 = op.sources.DictSource( + [{"id": 1, "x": 10}, {"id": 2, "x": 20}], + tag_columns=["id"], + data_schema=schema1, + ) + source2 = op.sources.DictSource( + [{"id": 1, "y": 100}, {"id": 2, "y": 200}], + tag_columns=["id"], + ) + + joined_stream = source1.join(source2) + table = joined_stream.as_table() + + x_field = table.schema.field("x") + assert x_field.nullable is True, ( + f"Expected 'x' (Optional[int], nullable=True in source schema) to remain " + f"nullable=True after Join, but got nullable={x_field.nullable}. " + "infer_schema_nullable incorrectly set it to nullable=False (no actual nulls " + "in data), ignoring schema intent." + ) diff --git a/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py b/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py new file mode 100644 index 00000000..ecefeddc --- /dev/null +++ b/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py @@ -0,0 +1,186 @@ +""" +Unit tests for the restore_schema_nullability helper in arrow_utils. + +RED phase: all tests in this file must fail before the helper exists. 
+""" + +import pyarrow as pa +import polars as pl +import pytest + +from orcapod.utils import arrow_utils + + +class TestPolarsRoundTripLosesNullability: + """Demonstrate the root-cause: Polars always produces nullable=True.""" + + def test_polars_roundtrip_makes_non_nullable_fields_nullable(self): + """Polars DataFrame round-trip converts nullable=False to nullable=True.""" + schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("value", pa.large_string(), nullable=False), + ] + ) + table = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]}, schema=schema) + + # Precondition: original table has non-nullable fields + assert table.schema.field("id").nullable is False + assert table.schema.field("value").nullable is False + + # After Polars round-trip all fields become nullable (the known bug) + roundtrip = pl.DataFrame(table).to_arrow() + assert roundtrip.schema.field("id").nullable is True + assert roundtrip.schema.field("value").nullable is True + + def test_polars_join_makes_all_result_fields_nullable(self): + """Polars inner join result has nullable=True for all fields.""" + left_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("name", pa.large_string(), nullable=False), + ] + ) + right_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("score", pa.float64(), nullable=False), + ] + ) + left = pa.table({"id": [1, 2], "name": ["a", "b"]}, schema=left_schema) + right = pa.table({"id": [1, 2], "score": [9.5, 8.0]}, schema=right_schema) + + joined = ( + pl.DataFrame(left) + .join(pl.DataFrame(right), on="id", how="inner") + .to_arrow() + ) + + # Bug: all columns nullable after join + assert joined.schema.field("id").nullable is True + assert joined.schema.field("name").nullable is True + assert joined.schema.field("score").nullable is True + + +class TestRestoreSchemaHullability: + """Unit tests for arrow_utils.restore_schema_nullability.""" + + def 
test_restores_non_nullable_flags_after_polars_roundtrip(self): + """restore_schema_nullability fixes nullable=True caused by Polars.""" + original_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("value", pa.large_string(), nullable=False), + ] + ) + table = pa.table( + {"id": [1, 2, 3], "value": ["a", "b", "c"]}, schema=original_schema + ) + + # Simulate Polars round-trip (loses nullability) + roundtrip = pl.DataFrame(table).to_arrow() + assert roundtrip.schema.field("id").nullable is True # confirms the bug + + # Fix: restore from original schema + restored = arrow_utils.restore_schema_nullability(roundtrip, original_schema) + + assert restored.schema.field("id").nullable is False + assert restored.schema.field("value").nullable is False + + def test_preserves_data_values_after_restore(self): + """restore_schema_nullability does not alter data, only schema metadata.""" + original_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("x", pa.float64(), nullable=False), + ] + ) + table = pa.table({"id": [1, 2], "x": [1.5, 2.5]}, schema=original_schema) + roundtrip = pl.DataFrame(table).to_arrow() + + restored = arrow_utils.restore_schema_nullability(roundtrip, original_schema) + + assert restored.to_pydict() == table.to_pydict() + + def test_preserves_nullable_fields_that_are_nullable_in_reference(self): + """Fields that are nullable in the reference schema remain nullable.""" + original_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("optional_val", pa.large_string(), nullable=True), + ] + ) + table = pa.table( + {"id": [1, 2], "optional_val": ["x", None]}, schema=original_schema + ) + roundtrip = pl.DataFrame(table).to_arrow() + + restored = arrow_utils.restore_schema_nullability(roundtrip, original_schema) + + assert restored.schema.field("id").nullable is False + assert restored.schema.field("optional_val").nullable is True + + def 
test_leaves_extra_columns_as_nullable(self): + """Columns absent from the reference schema keep Polars-default nullable=True.""" + original_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + ] + ) + full_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("_exists", pa.bool_(), nullable=True), + ] + ) + table = pa.table({"id": [1, 2], "_exists": [True, False]}, schema=full_schema) + roundtrip = pl.DataFrame(table).to_arrow() + + restored = arrow_utils.restore_schema_nullability(roundtrip, original_schema) + + assert restored.schema.field("id").nullable is False + # "_exists" not in reference → left as Polars output (nullable=True) + assert restored.schema.field("_exists").nullable is True + + def test_with_multiple_reference_schemas_later_wins(self): + """When the same field name appears in multiple reference schemas, the last wins.""" + schema_a = pa.schema([pa.field("id", pa.int64(), nullable=True)]) + schema_b = pa.schema([pa.field("id", pa.int64(), nullable=False)]) + + table = pa.table({"id": [1, 2]}) + roundtrip = pl.DataFrame(table).to_arrow() + + # schema_b comes last → nullable=False wins + restored = arrow_utils.restore_schema_nullability(roundtrip, schema_a, schema_b) + assert restored.schema.field("id").nullable is False + + def test_restores_non_nullable_in_polars_join_result(self): + """restore_schema_nullability works on the output of a Polars join.""" + left_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("name", pa.large_string(), nullable=False), + ] + ) + right_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("score", pa.float64(), nullable=False), + ] + ) + left = pa.table({"id": [1, 2], "name": ["a", "b"]}, schema=left_schema) + right = pa.table({"id": [1, 2], "score": [9.5, 8.0]}, schema=right_schema) + + joined = ( + pl.DataFrame(left) + .join(pl.DataFrame(right), on="id", how="inner") + .to_arrow() + ) + + restored = 
arrow_utils.restore_schema_nullability( + joined, left_schema, right_schema + ) + + assert restored.schema.field("id").nullable is False + assert restored.schema.field("name").nullable is False + assert restored.schema.field("score").nullable is False From e4248c040962c255e88e8914779be984613d8f78 Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 03:24:55 +0000 Subject: [PATCH 2/7] test: add explicit tag-column nullability tests for Join (ENG-375) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add TestJoinTagColumnNullability with three cases: 1. Shared tag columns with mixed nullable flags — both "id" (nullable=False) and "group" (nullable=True, Optional, no actual nulls) are preserved through the Polars inner join on those shared keys. 2. Non-shared tag columns from each side of a cartesian-product join — "id" (non-nullable) from source1 and "category" (nullable) from source2 both retain their flags in the result. 3. Three-way join (two Polars iterations) — verifies that restore_schema_nullability is applied correctly at each iteration, not just the last. Co-Authored-By: Claude Sonnet 4.6 --- .../test_function_node_nullability.py | 152 ++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/tests/test_data/test_polars_nullability/test_function_node_nullability.py b/tests/test_data/test_polars_nullability/test_function_node_nullability.py index 2f50d93d..9f9298f2 100644 --- a/tests/test_data/test_polars_nullability/test_function_node_nullability.py +++ b/tests/test_data/test_polars_nullability/test_function_node_nullability.py @@ -4,6 +4,7 @@ RED phase: tests should fail before the fix is applied. 
""" +# ruff: noqa: E501 import pyarrow as pa import pytest @@ -235,3 +236,154 @@ def test_join_preserves_nullable_optional_column_with_no_nulls(self): "infer_schema_nullable incorrectly set it to nullable=False (no actual nulls " "in data), ignoring schema intent." ) + + +# --------------------------------------------------------------------------- +# Join tag-column nullability with mixed nullable/non-nullable tag keys +# --------------------------------------------------------------------------- + + +class TestJoinTagColumnNullability: + """Join must preserve the exact nullable flag of every tag column — + both non-nullable mandatory keys and nullable optional keys — through + the Polars inner join used internally.""" + + def test_shared_tag_columns_mixed_nullability_preserved(self): + """When two sources share multiple tag columns with mixed nullable flags, + each flag is preserved correctly after the join. + + Schema intent: + - "id" int64 nullable=False (mandatory join key) + - "group" utf8 nullable=True (Optional grouping key, no actual nulls) + """ + tag_schema = pa.schema([ + pa.field("id", pa.int64(), nullable=False), + pa.field("group", pa.utf8(), nullable=True), + ]) + schema1 = pa.schema([ + *tag_schema, + pa.field("x", pa.int64(), nullable=False), + ]) + schema2 = pa.schema([ + *tag_schema, + pa.field("y", pa.int64(), nullable=False), + ]) + + source1 = op.sources.DictSource( + [{"id": 1, "group": "a", "x": 10}, {"id": 2, "group": "b", "x": 20}], + tag_columns=["id", "group"], + data_schema=schema1, + ) + source2 = op.sources.DictSource( + [{"id": 1, "group": "a", "y": 100}, {"id": 2, "group": "b", "y": 200}], + tag_columns=["id", "group"], + data_schema=schema2, + ) + + table = source1.join(source2).as_table() + + id_field = table.schema.field("id") + group_field = table.schema.field("group") + + assert id_field.nullable is False, ( + f"'id' (non-nullable tag) must remain nullable=False after Join, " + f"got nullable={id_field.nullable}." 
+ ) + assert group_field.nullable is True, ( + f"'group' (Optional tag, nullable=True) must remain nullable=True after Join " + f"even though data contains no actual nulls, got nullable={group_field.nullable}." + ) + + def test_non_shared_tag_columns_mixed_nullability_preserved(self): + """Tag columns that are unique to each side of a join (non-shared) also + preserve their nullable flags in the combined result. + + source1 has tag "id" (non-nullable int). + source2 has tag "category" (nullable string, Optional, no actual nulls). + Neither tag is shared, so the join is a full cartesian product. + Both tag columns appear in the result and must keep their original nullable flags. + """ + schema1 = pa.schema([ + pa.field("id", pa.int64(), nullable=False), + pa.field("x", pa.int64(), nullable=False), + ]) + schema2 = pa.schema([ + pa.field("category", pa.utf8(), nullable=True), # Optional tag + pa.field("y", pa.int64(), nullable=False), + ]) + + source1 = op.sources.DictSource( + [{"id": 1, "x": 10}], + tag_columns=["id"], + data_schema=schema1, + ) + source2 = op.sources.DictSource( + [{"category": "alpha", "y": 100}], + tag_columns=["category"], + data_schema=schema2, + ) + + table = source1.join(source2).as_table() + + id_field = table.schema.field("id") + category_field = table.schema.field("category") + + assert id_field.nullable is False, ( + f"'id' (non-nullable tag from source1) must remain nullable=False after " + f"cartesian join, got nullable={id_field.nullable}." + ) + assert category_field.nullable is True, ( + f"'category' (Optional tag, nullable=True from source2) must remain " + f"nullable=True after cartesian join even with no actual nulls, " + f"got nullable={category_field.nullable}." + ) + + def test_three_way_join_tag_nullability_preserved(self): + """A three-way join (two Polars join iterations) correctly restores nullable + flags on all tag columns across both iterations. 
+ + shared tag "id" int64 nullable=False + shared tag "group" utf8 nullable=True (Optional, no actual nulls) + """ + tag_schema = pa.schema([ + pa.field("id", pa.int64(), nullable=False), + pa.field("group", pa.utf8(), nullable=True), + ]) + schema1 = pa.schema([*tag_schema, pa.field("a", pa.int64(), nullable=False)]) + schema2 = pa.schema([*tag_schema, pa.field("b", pa.int64(), nullable=True)]) # b is Optional + schema3 = pa.schema([*tag_schema, pa.field("c", pa.int64(), nullable=False)]) + + source1 = op.sources.DictSource( + [{"id": 1, "group": "x", "a": 1}], + tag_columns=["id", "group"], + data_schema=schema1, + ) + source2 = op.sources.DictSource( + [{"id": 1, "group": "x", "b": 2}], + tag_columns=["id", "group"], + data_schema=schema2, + ) + source3 = op.sources.DictSource( + [{"id": 1, "group": "x", "c": 3}], + tag_columns=["id", "group"], + data_schema=schema3, + ) + + table = source1.join(source2).join(source3).as_table() + + id_field = table.schema.field("id") + group_field = table.schema.field("group") + b_field = table.schema.field("b") + + assert id_field.nullable is False, ( + f"'id' (non-nullable tag) must remain nullable=False after 3-way join, " + f"got nullable={id_field.nullable}." + ) + assert group_field.nullable is True, ( + f"'group' (Optional tag) must remain nullable=True after 3-way join, " + f"got nullable={group_field.nullable}." + ) + assert b_field.nullable is True, ( + f"'b' (Optional packet column) must remain nullable=True after 3-way join, " + f"got nullable={b_field.nullable}." + ) From 44dc263dac91a9dccab33e48a1ed278d10b02433 Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 03:54:38 +0000 Subject: [PATCH 3/7] fix: declare system-tag and source-info columns as nullable=False at creation add_system_tag_columns and add_source_info were appending columns with table.append_column(string_name, array), which Arrow defaults to nullable=True. 
Those columns are always computed and never null, so nullable=False is the correct schema intent. This caused a regression in TestJoinOutputSchemaSystemTags after restore_schema_nullability replaced infer_schema_nullable in Join: restore_schema_nullability faithfully preserves nullable=True for these columns (from the reference schema captured before the Polars join), while the operator's output_schema prediction expected nullable=False. Fix: pass an explicit pa.field(..., nullable=False) to append_column in both functions, so the schema is correct from the moment of creation. Co-Authored-By: Claude Sonnet 4.6 --- src/orcapod/utils/arrow_utils.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/orcapod/utils/arrow_utils.py b/src/orcapod/utils/arrow_utils.py index 05ec455a..33d1e8dc 100644 --- a/src/orcapod/utils/arrow_utils.py +++ b/src/orcapod/utils/arrow_utils.py @@ -967,8 +967,14 @@ def add_system_tag_columns( source_id_array = pa.array(source_ids, type=pa.large_string()) record_id_array = pa.array(record_ids, type=pa.large_string()) - table = table.append_column(source_id_col_name, source_id_array) - table = table.append_column(record_id_col_name, record_id_array) + # System tag columns are always computed, never null — declare nullable=False + # explicitly so the schema intent is not lost in Polars round-trips. + table = table.append_column( + pa.field(source_id_col_name, pa.large_string(), nullable=False), source_id_array + ) + table = table.append_column( + pa.field(record_id_col_name, pa.large_string(), nullable=False), record_id_array + ) return table @@ -1164,7 +1170,11 @@ def add_source_info( [f"{source_val}::{col}" for source_val in source_column], type=pa.large_string(), ) - table = table.append_column(f"{constants.SOURCE_PREFIX}{col}", source_column) + # Source info columns are always computed strings, never null. 
+ table = table.append_column( + pa.field(f"{constants.SOURCE_PREFIX}{col}", pa.large_string(), nullable=False), + source_column, + ) return table From 3218c4cf60a5ed5dbdef02098bb0bd338b603355 Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 04:28:20 +0000 Subject: [PATCH 4/7] fix: address PR review comments (ENG-375) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove unused `import pytest` from both test modules - Rename TestRestoreSchemaHullability → TestRestoreSchemaNullability (typo fix) - Fix source_column accumulation bug in add_source_info: each _source_<col> column was being built from the array produced for the previous column, causing tokens like "src::col1::col2" instead of "src::col2". Introduce base_source captured once before the loop so every column is independently derived from the original per-row source tokens. - Apply restore_schema_nullability after the Polars sort in FunctionPodStream.as_table (function_pod.py) to match the same fix already applied in FunctionNode.as_table, covering the stream materialization path.
Co-Authored-By: Claude Sonnet 4.6 --- src/orcapod/core/function_pod.py | 2 ++ src/orcapod/utils/arrow_utils.py | 18 ++++++++++-------- .../test_function_node_nullability.py | 1 - .../test_restore_schema_nullability.py | 3 +-- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py index f9964106..95bceed3 100644 --- a/src/orcapod/core/function_pod.py +++ b/src/orcapod/core/function_pod.py @@ -649,11 +649,13 @@ def as_table( if column_config.sort_by_tags: # TODO: reimplement using polars natively + output_table_schema = output_table.schema output_table = ( pl.DataFrame(output_table) .sort(by=self.keys()[0], descending=False) .to_arrow() ) + output_table = arrow_utils.restore_schema_nullability(output_table, output_table_schema) # output_table = output_table.sort_by( # [(column, "ascending") for column in self.keys()[0]] # ) diff --git a/src/orcapod/utils/arrow_utils.py b/src/orcapod/utils/arrow_utils.py index 33d1e8dc..8ae928cb 100644 --- a/src/orcapod/utils/arrow_utils.py +++ b/src/orcapod/utils/arrow_utils.py @@ -1151,29 +1151,31 @@ def add_source_info( exclude_columns: Collection[str] = (), ) -> "pa.Table": """Add source information to an Arrow table.""" - # Create a new column with the source information + # Create a base list of per-row source tokens; one entry per row. if source_info is None or isinstance(source_info, str): - source_column = [source_info] * table.num_rows + base_source = [source_info] * table.num_rows elif isinstance(source_info, Collection): if len(source_info) != table.num_rows: raise ValueError( "Length of source_info collection must match number of rows in the table." ) - source_column = source_info - - # identify columns for which source columns should be created + base_source = list(source_info) + # For each data column, build an independent _source_ column from the + # base tokens. 
We must NOT re-use the array produced for a previous column + # as input for the next one — doing so would accumulate column names + # (e.g. "src::col1::col2" instead of "src::col2"). for col in table.column_names: if col.startswith(tuple(exclude_prefixes)) or col in exclude_columns: continue - source_column = pa.array( - [f"{source_val}::{col}" for source_val in source_column], + col_source = pa.array( + [f"{source_val}::{col}" for source_val in base_source], type=pa.large_string(), ) # Source info columns are always computed strings, never null. table = table.append_column( pa.field(f"{constants.SOURCE_PREFIX}{col}", pa.large_string(), nullable=False), - source_column, + col_source, ) return table diff --git a/tests/test_data/test_polars_nullability/test_function_node_nullability.py b/tests/test_data/test_polars_nullability/test_function_node_nullability.py index 9f9298f2..9114b8aa 100644 --- a/tests/test_data/test_polars_nullability/test_function_node_nullability.py +++ b/tests/test_data/test_polars_nullability/test_function_node_nullability.py @@ -7,7 +7,6 @@ # ruff: noqa: E501 import pyarrow as pa -import pytest import orcapod as op from orcapod.core.nodes.function_node import FunctionNode diff --git a/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py b/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py index ecefeddc..903979fa 100644 --- a/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py +++ b/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py @@ -6,7 +6,6 @@ import pyarrow as pa import polars as pl -import pytest from orcapod.utils import arrow_utils @@ -62,7 +61,7 @@ def test_polars_join_makes_all_result_fields_nullable(self): assert joined.schema.field("score").nullable is True -class TestRestoreSchemaHullability: +class TestRestoreSchemaNullability: """Unit tests for arrow_utils.restore_schema_nullability.""" def 
test_restores_non_nullable_flags_after_polars_roundtrip(self): From b3784657849c4462b4852f46b4993d99dc5537b0 Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 04:32:10 +0000 Subject: [PATCH 5/7] test: add regression tests for add_source_info accumulation bug (ENG-375) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TestAddSourceInfo in test_arrow_utils.py covers four cases: 1. Single column — baseline: _source_x = "src::x". 2. Multi-column — the regression case: each _source_<col> must equal "base::<col>", not "base::prev_col::<col>". With the bug, _source_y would be "base::x::y" because source_column was reused across iterations. 3. Per-row source tokens — same accumulation bug with a list input: _source_b row 0 must be "src0::b", not "src0::a::b". 4. Column count — one _source_<col> per data column, no more. Co-Authored-By: Claude Sonnet 4.6 --- tests/test_utils/test_arrow_utils.py | 77 ++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/test_utils/test_arrow_utils.py b/tests/test_utils/test_arrow_utils.py index 1ddbba1c..f0d5cdcd 100644 --- a/tests/test_utils/test_arrow_utils.py +++ b/tests/test_utils/test_arrow_utils.py @@ -4,6 +4,7 @@ import pytest from orcapod.utils.arrow_utils import ( + add_source_info, infer_schema_nullable, make_schema_non_nullable, prepare_prefixed_columns, @@ -175,3 +176,79 @@ def test_none_schema_metadata_stays_none(self): table = pa.table({"x": pa.array([1, 2], type=pa.int64())}) result = infer_schema_nullable(table) assert not result.metadata + + +# --------------------------------------------------------------------------- +# add_source_info +# --------------------------------------------------------------------------- + + +class TestAddSourceInfo: + """add_source_info must produce one independent _source_ column per + data column, each derived from the original source token — not from the + output of a
previous column's computation. + + The regression being guarded against: source_column was overwritten inside + the per-column loop, so _source_col2 would contain "src::col1::col2" + instead of "src::col2". + """ + + def test_single_column_produces_correct_source_token(self): + """Baseline: a single data column gets _source_ = '::'.""" + table = pa.table({"x": pa.array([10, 20], type=pa.int64())}) + result = add_source_info(table, "mysrc") + + source_vals = result.column("_source_x").to_pylist() + assert source_vals == ["mysrc::x", "mysrc::x"] + + def test_multi_column_each_source_derives_from_base_token(self): + """Each _source_ must equal '::', not '::prev_col::'. + + With the bug present, _source_y would be "mysrc::x::y" because the loop + re-used the pa.array produced for _source_x as the starting point for _source_y. + """ + table = pa.table({ + "x": pa.array([1, 2], type=pa.int64()), + "y": pa.array([3, 4], type=pa.int64()), + "z": pa.array([5, 6], type=pa.int64()), + }) + result = add_source_info(table, "base") + + assert result.column("_source_x").to_pylist() == ["base::x", "base::x"] + assert result.column("_source_y").to_pylist() == ["base::y", "base::y"], ( + "_source_y must be 'base::y', not 'base::x::y' — " + "each column must derive from the original source token, not from the " + "previous column's output array." + ) + assert result.column("_source_z").to_pylist() == ["base::z", "base::z"] + + def test_per_row_source_tokens_not_accumulated(self): + """Per-row source tokens are also built independently per column. + + With the bug, row 0 of _source_b would be "src0::a::b" instead of "src0::b". 
+ """ + table = pa.table({ + "a": pa.array([10, 20], type=pa.int64()), + "b": pa.array([30, 40], type=pa.int64()), + }) + result = add_source_info(table, ["src0", "src1"]) + + assert result.column("_source_a").to_pylist() == ["src0::a", "src1::a"] + assert result.column("_source_b").to_pylist() == ["src0::b", "src1::b"], ( + "_source_b row 0 must be 'src0::b', not 'src0::a::b'." + ) + + def test_correct_number_of_source_columns_added(self): + """One _source_ column is added for each non-excluded data column.""" + table = pa.table({ + "p": pa.array([1], type=pa.int64()), + "q": pa.array([2], type=pa.int64()), + "r": pa.array([3], type=pa.int64()), + }) + result = add_source_info(table, "s") + + # Original 3 columns + 3 source columns = 6 + assert result.num_columns == 6 + assert "_source_p" in result.column_names + assert "_source_q" in result.column_names + assert "_source_r" in result.column_names From 0e73020a5808a702583814e9f462f3a1647f2457 Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 04:33:35 +0000 Subject: [PATCH 6/7] test: clean up add_source_info tests to assert expected form only (ENG-375) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove assertion messages and docstring prose that described the buggy output pattern. Tests now simply assert the exact contract — each _source_<col> value is '<source>::<col>' — which is sufficient to catch the regression without mentioning what the wrong value would be.
Co-Authored-By: Claude Sonnet 4.6 --- tests/test_utils/test_arrow_utils.py | 45 ++++++++-------------------- 1 file changed, 12 insertions(+), 33 deletions(-) diff --git a/tests/test_utils/test_arrow_utils.py b/tests/test_utils/test_arrow_utils.py index f0d5cdcd..2831b90a 100644 --- a/tests/test_utils/test_arrow_utils.py +++ b/tests/test_utils/test_arrow_utils.py @@ -184,29 +184,18 @@ def test_none_schema_metadata_stays_none(self): class TestAddSourceInfo: - """add_source_info must produce one independent _source_ column per - data column, each derived from the original source token — not from the - output of a previous column's computation. - - The regression being guarded against: source_column was overwritten inside - the per-column loop, so _source_col2 would contain "src::col1::col2" - instead of "src::col2". - """ + """add_source_info must produce one _source_ column per data column + whose values are exactly '::' for every row.""" def test_single_column_produces_correct_source_token(self): - """Baseline: a single data column gets _source_ = '::'.""" + """_source_ value is '::' for every row.""" table = pa.table({"x": pa.array([10, 20], type=pa.int64())}) result = add_source_info(table, "mysrc") - source_vals = result.column("_source_x").to_pylist() - assert source_vals == ["mysrc::x", "mysrc::x"] - - def test_multi_column_each_source_derives_from_base_token(self): - """Each _source_ must equal '::', not '::prev_col::'. + assert result.column("_source_x").to_pylist() == ["mysrc::x", "mysrc::x"] - With the bug present, _source_y would be "mysrc::x::y" because the loop - re-used the pa.array produced for _source_x as the starting point for _source_y. 
- """ + def test_multi_column_each_source_column_uses_its_own_name(self): + """Each _source_ value is '::' — columns are independent.""" table = pa.table({ "x": pa.array([1, 2], type=pa.int64()), "y": pa.array([3, 4], type=pa.int64()), @@ -215,18 +204,11 @@ def test_multi_column_each_source_derives_from_base_token(self): result = add_source_info(table, "base") assert result.column("_source_x").to_pylist() == ["base::x", "base::x"] - assert result.column("_source_y").to_pylist() == ["base::y", "base::y"], ( - "_source_y must be 'base::y', not 'base::x::y' — " - "each column must derive from the original source token, not from the " - "previous column's output array." - ) + assert result.column("_source_y").to_pylist() == ["base::y", "base::y"] assert result.column("_source_z").to_pylist() == ["base::z", "base::z"] - def test_per_row_source_tokens_not_accumulated(self): - """Per-row source tokens are also built independently per column. - - With the bug, row 0 of _source_b would be "src0::a::b" instead of "src0::b". - """ + def test_per_row_source_tokens_combined_with_column_name(self): + """With a per-row source list, each row's token is '::'.""" table = pa.table({ "a": pa.array([10, 20], type=pa.int64()), "b": pa.array([30, 40], type=pa.int64()), @@ -234,12 +216,10 @@ def test_per_row_source_tokens_not_accumulated(self): result = add_source_info(table, ["src0", "src1"]) assert result.column("_source_a").to_pylist() == ["src0::a", "src1::a"] - assert result.column("_source_b").to_pylist() == ["src0::b", "src1::b"], ( - "_source_b row 0 must be 'src0::b', not 'src0::a::b'." 
- ) + assert result.column("_source_b").to_pylist() == ["src0::b", "src1::b"] def test_correct_number_of_source_columns_added(self): - """One _source_ column is added for each non-excluded data column.""" + """Exactly one _source_ column is added per non-excluded data column.""" table = pa.table({ "p": pa.array([1], type=pa.int64()), "q": pa.array([2], type=pa.int64()), @@ -247,8 +227,7 @@ def test_correct_number_of_source_columns_added(self): }) result = add_source_info(table, "s") - # Original 3 columns + 3 source columns = 6 - assert result.num_columns == 6 + assert result.num_columns == 6 # 3 data + 3 source assert "_source_p" in result.column_names assert "_source_q" in result.column_names assert "_source_r" in result.column_names From d4bcd7295fddeba35f8b090be80c405210cf48b9 Mon Sep 17 00:00:00 2001 From: "agent-kurodo[bot]" <268466204+agent-kurodo[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 05:36:29 +0000 Subject: [PATCH 7/7] fix: address second round of PR review comments (ENG-375) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - join.py: replace Polars rename round-trip with Arrow-native rename_columns(), eliminating an unnecessary Arrow→Polars→Arrow conversion before the join. - arrow_utils.py: add explicit else/TypeError in add_source_info so callers passing an unsized iterable (e.g. a generator) get a clear error instead of an UnboundLocalError from base_source being unbound. - test_function_node_nullability.py: remove stale "RED phase" phrasing from module docstring; replace with description of the behavioral contract being validated. - test_restore_schema_nullability.py: same stale docstring fix. 
Co-Authored-By: Claude Sonnet 4.6 --- src/orcapod/core/operators/join.py | 7 +++++-- src/orcapod/utils/arrow_utils.py | 5 +++++ .../test_function_node_nullability.py | 7 +++---- .../test_restore_schema_nullability.py | 6 +++--- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/orcapod/core/operators/join.py b/src/orcapod/core/operators/join.py index 17768014..7937034b 100644 --- a/src/orcapod/core/operators/join.py +++ b/src/orcapod/core/operators/join.py @@ -174,14 +174,17 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol: # Build a reference schema for next_table with rename_map applied to # field names, preserving nullable flags — must be done BEFORE the - # Polars rename call below, which loses all nullability information. + # rename so we capture the original schema. next_ref_schema = pa.schema([ pa.field(rename_map.get(f.name, f.name), f.type, nullable=f.nullable, metadata=f.metadata) for f in next_table.schema ]) if rename_map: - next_table = pl.DataFrame(next_table).rename(rename_map).to_arrow() + # Use Arrow-native rename to avoid an unnecessary Polars round-trip. + next_table = next_table.rename_columns( + [rename_map.get(name, name) for name in next_table.column_names] + ) common_tag_keys = tag_keys.intersection(next_tag_keys) common_tag_keys.add(COMMON_JOIN_KEY) diff --git a/src/orcapod/utils/arrow_utils.py b/src/orcapod/utils/arrow_utils.py index 8ae928cb..13ba410c 100644 --- a/src/orcapod/utils/arrow_utils.py +++ b/src/orcapod/utils/arrow_utils.py @@ -1160,6 +1160,11 @@ def add_source_info( "Length of source_info collection must match number of rows in the table." ) base_source = list(source_info) + else: + raise TypeError( + f"source_info must be a str, a sized Collection[str], or None; " + f"got {type(source_info).__name__}" + ) # For each data column, build an independent _source_ column from the # base tokens. 
We must NOT re-use the array produced for a previous column diff --git a/tests/test_data/test_polars_nullability/test_function_node_nullability.py b/tests/test_data/test_polars_nullability/test_function_node_nullability.py index 9114b8aa..ec845c6a 100644 --- a/tests/test_data/test_polars_nullability/test_function_node_nullability.py +++ b/tests/test_data/test_polars_nullability/test_function_node_nullability.py @@ -1,8 +1,7 @@ """ -Integration tests: FunctionNode and Join preserve non-nullable column constraints -after the Arrow → Polars → Arrow round-trip that occurs during joins. - -RED phase: tests should fail before the fix is applied. +Integration tests validating that FunctionNode and Join preserve nullable column +constraints across the Arrow → Polars → Arrow round-trip that occurs during +joins and cached-record retrieval. """ # ruff: noqa: E501 diff --git a/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py b/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py index 903979fa..4c24d0da 100644 --- a/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py +++ b/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py @@ -1,7 +1,7 @@ """ -Unit tests for the restore_schema_nullability helper in arrow_utils. - -RED phase: all tests in this file must fail before the helper exists. +These tests document that Polars round-trips and joins widen all Arrow field +nullability to nullable=True, and verify that restore_schema_nullability +correctly reapplies the original nullable flags from reference schemas. """ import pyarrow as pa