diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py index f9964106..95bceed3 100644 --- a/src/orcapod/core/function_pod.py +++ b/src/orcapod/core/function_pod.py @@ -649,11 +649,13 @@ def as_table( if column_config.sort_by_tags: # TODO: reimplement using polars natively + output_table_schema = output_table.schema output_table = ( pl.DataFrame(output_table) .sort(by=self.keys()[0], descending=False) .to_arrow() ) + output_table = arrow_utils.restore_schema_nullability(output_table, output_table_schema) # output_table = output_table.sort_by( # [(column, "ascending") for column in self.keys()[0]] # ) diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index f4bf4b05..a360ad0e 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -743,7 +743,8 @@ def get_cached_results( return {} taginfo = self._filter_by_content_hash(taginfo) - + taginfo_schema = taginfo.schema + results_schema = results.schema filtered = ( pl.DataFrame(taginfo) .join( @@ -754,6 +755,7 @@ def get_cached_results( .filter(pl.col(PIPELINE_ENTRY_ID_COL).is_in(entry_ids)) .to_arrow() ) + filtered = arrow_utils.restore_schema_nullability(filtered, taginfo_schema, results_schema) if filtered.num_rows == 0: return {} @@ -987,12 +989,14 @@ def get_all_records( return None taginfo = self._filter_by_content_hash(taginfo) - + taginfo_schema = taginfo.schema + results_schema = results.schema joined = ( pl.DataFrame(taginfo) .join(pl.DataFrame(results), on=constants.PACKET_RECORD_ID, how="inner") .to_arrow() ) + joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema) column_config = ColumnConfig.handle_config(columns, all_info=all_info) @@ -1072,7 +1076,8 @@ def _load_all_cached_records( return None taginfo = self._filter_by_content_hash(taginfo) - + taginfo_schema = taginfo.schema + results_schema = results.schema joined = ( pl.DataFrame(taginfo) .join( @@ -1082,6 
+1087,7 @@ def _load_all_cached_records( ) .to_arrow() ) + joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema) if joined.num_rows == 0: return None @@ -1202,6 +1208,8 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: if taginfo is not None and results is not None: taginfo = self._filter_by_content_hash(taginfo) + taginfo_schema = taginfo.schema + results_schema = results.schema joined = ( pl.DataFrame(taginfo) .join( @@ -1211,6 +1219,7 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: ) .to_arrow() ) + joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema) if joined.num_rows > 0: tag_keys = self._input_stream.keys()[0] # Collect pipeline entry_ids for Phase 2 skip check @@ -1435,11 +1444,13 @@ def as_table( ) if column_config.sort_by_tags: + output_table_schema = output_table.schema output_table = ( pl.DataFrame(output_table) .sort(by=self.keys()[0], descending=False) .to_arrow() ) + output_table = arrow_utils.restore_schema_nullability(output_table, output_table_schema) return output_table # ------------------------------------------------------------------ @@ -1524,6 +1535,8 @@ async def async_execute( if taginfo is not None and results is not None: taginfo = self._filter_by_content_hash(taginfo) + taginfo_schema = taginfo.schema + results_schema = results.schema joined = ( pl.DataFrame(taginfo) .join( @@ -1533,6 +1546,7 @@ async def async_execute( ) .to_arrow() ) + joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema) if joined.num_rows > 0: tag_keys = self._input_stream.keys()[0] entry_ids_col = joined.column( diff --git a/src/orcapod/core/operators/join.py b/src/orcapod/core/operators/join.py index f4ddf479..7937034b 100644 --- a/src/orcapod/core/operators/join.py +++ b/src/orcapod/core/operators/join.py @@ -171,17 +171,33 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol: counter += 1 
new_name = f"{col}_{counter}" rename_map[col] = new_name + + # Build a reference schema for next_table with rename_map applied to + # field names, preserving nullable flags — must be done BEFORE the + # rename so we capture the original schema. + next_ref_schema = pa.schema([ + pa.field(rename_map.get(f.name, f.name), f.type, nullable=f.nullable, metadata=f.metadata) + for f in next_table.schema + ]) + if rename_map: - next_table = pl.DataFrame(next_table).rename(rename_map).to_arrow() + # Use Arrow-native rename to avoid an unnecessary Polars round-trip. + next_table = next_table.rename_columns( + [rename_map.get(name, name) for name in next_table.column_names] + ) common_tag_keys = tag_keys.intersection(next_tag_keys) common_tag_keys.add(COMMON_JOIN_KEY) + # Capture the left-side schema before the Polars join, which sets all + # fields to nullable=True regardless of the original schema. + table_ref_schema = table.schema table = ( pl.DataFrame(table) .join(pl.DataFrame(next_table), on=list(common_tag_keys), how="inner") .to_arrow() ) + table = arrow_utils.restore_schema_nullability(table, table_ref_schema, next_ref_schema) tag_keys.update(next_tag_keys) @@ -196,13 +212,6 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol: reordered_columns += [col for col in table.column_names if col not in tag_keys] result_table = table.select(reordered_columns) - # Derive nullable per column from actual null counts so that: - # - Columns with no nulls (e.g. tag/packet fields after inner join) get - # nullable=False, avoiding spurious T | None from Polars' all-True default. - # - Columns that genuinely contain nulls (e.g. Optional fields, or source - # info columns after cross-stream joins) keep nullable=True, preventing - # cast failures. 
- result_table = result_table.cast(arrow_utils.infer_schema_nullable(result_table)) return ArrowTableStream( result_table, tag_columns=tuple(tag_keys), diff --git a/src/orcapod/utils/arrow_utils.py b/src/orcapod/utils/arrow_utils.py index 9307cf0d..13ba410c 100644 --- a/src/orcapod/utils/arrow_utils.py +++ b/src/orcapod/utils/arrow_utils.py @@ -831,6 +831,77 @@ def infer_schema_nullable(table: "pa.Table") -> "pa.Schema": ) +def restore_schema_nullability( + table: "pa.Table", + *reference_schemas: "pa.Schema", +) -> "pa.Table": + """Restore nullable flags lost during an Arrow → Polars → Arrow round-trip. + + Polars converts all Arrow fields to ``nullable=True`` when producing its + Arrow output. This function repairs the resulting table by looking up each + field name in the supplied reference schemas and reinstating the original + ``nullable`` flag (and type) from those references. + + Fields that are **not** found in any reference schema are left exactly as + Polars produced them (``nullable=True``), which is safe for internally + generated sentinel columns (e.g. ``_exists``, ``__pipeline_entry_id``). + + When the same field name appears in multiple reference schemas the + **last-supplied** schema wins, so callers can pass ``(left_schema, + right_schema)`` and get right-side nullability for join-key columns that + appear in both. + + Args: + table: Arrow table produced by a Polars round-trip (all fields + ``nullable=True``). + *reference_schemas: One or more Arrow schemas carrying the original + nullable flags. Pass the schemas of every table that participated + in the Polars join/sort so that all columns are covered. + + Returns: + A new Arrow table cast to the restored schema. Data is unchanged; + only the schema metadata (nullability, field type, field metadata) + is corrected. 
+ + Raises: + pyarrow.ArrowInvalid: If a restored field type is incompatible with + the actual column data (should not happen for well-formed + round-trips, but surfaced to the caller rather than silently + discarded). + + Example:: + + taginfo_schema = taginfo.schema + results_schema = results.schema + joined = ( + pl.DataFrame(taginfo) + .join(pl.DataFrame(results), on="record_id", how="inner") + .to_arrow() + ) + joined = arrow_utils.restore_schema_nullability( + joined, taginfo_schema, results_schema + ) + """ + # Build a name → Field lookup; later schemas override earlier ones. + field_lookup: dict[str, "pa.Field"] = {} + for schema in reference_schemas: + for field in schema: + field_lookup[field.name] = field + + restored_fields = [] + for field in table.schema: + ref = field_lookup.get(field.name) + if ref is not None: + restored_fields.append( + pa.field(field.name, ref.type, nullable=ref.nullable, metadata=ref.metadata) + ) + else: + restored_fields.append(field) + + restored_schema = pa.schema(restored_fields, metadata=table.schema.metadata) + return table.cast(restored_schema) + + def drop_columns_with_prefix( table: "pa.Table", prefix: str | tuple[str, ...], @@ -896,8 +967,14 @@ def add_system_tag_columns( source_id_array = pa.array(source_ids, type=pa.large_string()) record_id_array = pa.array(record_ids, type=pa.large_string()) - table = table.append_column(source_id_col_name, source_id_array) - table = table.append_column(record_id_col_name, record_id_array) + # System tag columns are always computed, never null — declare nullable=False + # explicitly so the schema intent is not lost in Polars round-trips. 
+ table = table.append_column( + pa.field(source_id_col_name, pa.large_string(), nullable=False), source_id_array + ) + table = table.append_column( + pa.field(record_id_col_name, pa.large_string(), nullable=False), record_id_array + ) return table @@ -1074,26 +1151,37 @@ def add_source_info( exclude_columns: Collection[str] = (), ) -> "pa.Table": """Add source information to an Arrow table.""" - # Create a new column with the source information + # Create a base list of per-row source tokens; one entry per row. if source_info is None or isinstance(source_info, str): - source_column = [source_info] * table.num_rows + base_source = [source_info] * table.num_rows elif isinstance(source_info, Collection): if len(source_info) != table.num_rows: raise ValueError( "Length of source_info collection must match number of rows in the table." ) - source_column = source_info - - # identify columns for which source columns should be created + base_source = list(source_info) + else: + raise TypeError( + f"source_info must be a str, a sized Collection[str], or None; " + f"got {type(source_info).__name__}" + ) + # For each data column, build an independent _source_ column from the + # base tokens. We must NOT re-use the array produced for a previous column + # as input for the next one — doing so would accumulate column names + # (e.g. "src::col1::col2" instead of "src::col2"). for col in table.column_names: if col.startswith(tuple(exclude_prefixes)) or col in exclude_columns: continue - source_column = pa.array( - [f"{source_val}::{col}" for source_val in source_column], + col_source = pa.array( + [f"{source_val}::{col}" for source_val in base_source], type=pa.large_string(), ) - table = table.append_column(f"{constants.SOURCE_PREFIX}{col}", source_column) + # Source info columns are always computed strings, never null. 
+ table = table.append_column( + pa.field(f"{constants.SOURCE_PREFIX}{col}", pa.large_string(), nullable=False), + col_source, + ) return table diff --git a/tests/test_data/test_polars_nullability/__init__.py b/tests/test_data/test_polars_nullability/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_data/test_polars_nullability/test_function_node_nullability.py b/tests/test_data/test_polars_nullability/test_function_node_nullability.py new file mode 100644 index 00000000..ec845c6a --- /dev/null +++ b/tests/test_data/test_polars_nullability/test_function_node_nullability.py @@ -0,0 +1,387 @@ +""" +Integration tests validating that FunctionNode and Join preserve nullable column +constraints across the Arrow → Polars → Arrow round-trip that occurs during +joins and cached-record retrieval. +""" +# ruff: noqa: E501 + +import pyarrow as pa + +import orcapod as op +from orcapod.core.nodes.function_node import FunctionNode +from orcapod.databases import InMemoryArrowDatabase + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_function_nodes(pipeline: op.Pipeline) -> list[FunctionNode]: + """Return all FunctionNode instances from compiled pipeline nodes.""" + return [n for n in pipeline.compiled_nodes.values() if isinstance(n, FunctionNode)] + + +# --------------------------------------------------------------------------- +# FunctionNode.get_all_records nullability +# --------------------------------------------------------------------------- + + +class TestFunctionNodeGetAllRecordsNullability: + """FunctionNode.get_all_records must preserve the non-nullable schema of + output columns whose Python type annotation is non-optional (e.g. 
``int``).""" + + def test_non_optional_return_type_yields_non_nullable_output_column(self): + """Output column from int return type must be non-nullable after get_all_records.""" + database = InMemoryArrowDatabase() + source = op.sources.DictSource( + [{"id": 1, "x": 10}, {"id": 2, "x": 20}], + tag_columns=["id"], + ) + + @op.function_pod(output_keys=["result"]) + def double(x: int) -> int: + return x * 2 + + pipeline = op.Pipeline("test_fn_nullable", database) + with pipeline: + double.pod(source) + + pipeline.run() + + fn_nodes = _get_function_nodes(pipeline) + assert len(fn_nodes) == 1, "Expected exactly one FunctionNode" + fn_node = fn_nodes[0] + + table = fn_node.get_all_records() + assert table is not None, "get_all_records() returned None after pipeline.run()" + + # "result" column has Python type int → nullable=False in Arrow schema + result_field = table.schema.field("result") + assert result_field.nullable is False, ( + f"Expected 'result' column (int return type) to be non-nullable, " + f"but got nullable={result_field.nullable}. " + "Arrow→Polars→Arrow round-trip in get_all_records() dropped nullability." 
+ ) + + def test_input_tag_column_non_nullable_after_get_all_records(self): + """Input tag columns that are non-nullable must remain so after Polars join.""" + database = InMemoryArrowDatabase() + + # DictSource with integer id — infer_schema_nullable sets nullable=False (no nulls) + source = op.sources.DictSource( + [{"id": 1, "x": 5}, {"id": 2, "x": 15}], + tag_columns=["id"], + ) + + @op.function_pod(output_keys=["result"]) + def triple(x: int) -> int: + return x * 3 + + pipeline = op.Pipeline("test_fn_tag_nullable", database) + with pipeline: + triple.pod(source) + + pipeline.run() + + fn_nodes = _get_function_nodes(pipeline) + fn_node = fn_nodes[0] + + table = fn_node.get_all_records() + assert table is not None + + # "id" was non-nullable in the source; after the Polars join it must stay so + id_field = table.schema.field("id") + assert id_field.nullable is False, ( + f"Expected 'id' tag column to be non-nullable after get_all_records(), " + f"but got nullable={id_field.nullable}." 
+ ) + + +# --------------------------------------------------------------------------- +# FunctionNode.iter_packets nullability +# --------------------------------------------------------------------------- + + +class TestFunctionNodeIterPacketsNullability: + """FunctionNode.iter_packets must yield packets whose underlying Arrow schema + preserves non-nullable column constraints.""" + + def test_iter_packets_from_database_preserves_non_nullable_output(self): + """Packets loaded from DB via iter_packets carry non-nullable output schema.""" + database = InMemoryArrowDatabase() + source = op.sources.DictSource( + [{"id": 1, "x": 7}], + tag_columns=["id"], + ) + + @op.function_pod(output_keys=["result"]) + def add_one(x: int) -> int: + return x + 1 + + pipeline = op.Pipeline("test_iter_packets_nullable", database) + with pipeline: + add_one.pod(source) + + pipeline.run() + + fn_nodes = _get_function_nodes(pipeline) + fn_node = fn_nodes[0] + + # Force a DB-backed iteration by going through _iter_all_from_database + # (simulates the CACHE_ONLY path used after save/load) + packets_seen = list(fn_node._iter_all_from_database()) + assert len(packets_seen) == 1, "Expected one packet from the database" + + _tag, packet = packets_seen[0] + packet_schema = packet.arrow_schema() + + result_field = packet_schema.field("result") + assert result_field.nullable is False, ( + f"Packet 'result' field should be non-nullable (int return type), " + f"but got nullable={result_field.nullable}. " + "Arrow→Polars→Arrow round-trip in iter_packets dropped nullability." 
+ ) + + +# --------------------------------------------------------------------------- +# Join operator nullability +# --------------------------------------------------------------------------- + + +class TestJoinOperatorNullability: + """Join.op_forward must preserve non-nullable tag column flags through the + Polars inner join it uses internally.""" + + def test_join_preserves_non_nullable_shared_tag_column(self): + """Shared tag column remains non-nullable after stream join.""" + # DictSource applies infer_schema_nullable → integer 'id' has nullable=False + source1 = op.sources.DictSource( + [{"id": 1, "x": 10}, {"id": 2, "x": 20}], + tag_columns=["id"], + ) + source2 = op.sources.DictSource( + [{"id": 1, "y": 100}, {"id": 2, "y": 200}], + tag_columns=["id"], + ) + + joined_stream = source1.join(source2) + table = joined_stream.as_table() + + id_field = table.schema.field("id") + assert id_field.nullable is False, ( + f"Expected 'id' tag column to be non-nullable after Join, " + f"but got nullable={id_field.nullable}. " + "Arrow→Polars→Arrow round-trip in Join.op_forward dropped nullability." + ) + + def test_join_preserves_non_nullable_packet_columns(self): + """Packet columns that are non-nullable remain so after stream join.""" + source1 = op.sources.DictSource( + [{"id": 1, "x": 10}, {"id": 2, "x": 20}], + tag_columns=["id"], + ) + source2 = op.sources.DictSource( + [{"id": 1, "y": 100}, {"id": 2, "y": 200}], + tag_columns=["id"], + ) + + joined_stream = source1.join(source2) + table = joined_stream.as_table() + + # "x" and "y" are packet columns with integer values → non-nullable + x_field = table.schema.field("x") + y_field = table.schema.field("y") + + assert x_field.nullable is False, ( + f"Expected 'x' packet column to be non-nullable after Join, " + f"but got nullable={x_field.nullable}." + ) + assert y_field.nullable is False, ( + f"Expected 'y' packet column to be non-nullable after Join, " + f"but got nullable={y_field.nullable}." 
+ ) + + def test_join_preserves_nullable_optional_column_with_no_nulls(self): + """Optional[int] packet column (nullable=True) must remain nullable=True after + Join, even when the data contains no actual null values. + + infer_schema_nullable incorrectly marks it nullable=False because it sees no + nulls in the data. restore_schema_nullability preserves schema intent. + """ + # Explicitly declare "x" as nullable=True (Optional[int]) via pa.Schema + schema1 = pa.schema([ + pa.field("id", pa.int64(), nullable=False), + pa.field("x", pa.int64(), nullable=True), # Optional — no actual nulls + ]) + source1 = op.sources.DictSource( + [{"id": 1, "x": 10}, {"id": 2, "x": 20}], + tag_columns=["id"], + data_schema=schema1, + ) + source2 = op.sources.DictSource( + [{"id": 1, "y": 100}, {"id": 2, "y": 200}], + tag_columns=["id"], + ) + + joined_stream = source1.join(source2) + table = joined_stream.as_table() + + x_field = table.schema.field("x") + assert x_field.nullable is True, ( + f"Expected 'x' (Optional[int], nullable=True in source schema) to remain " + f"nullable=True after Join, but got nullable={x_field.nullable}. " + "infer_schema_nullable incorrectly set it to nullable=False (no actual nulls " + "in data), ignoring schema intent." + ) + + +# --------------------------------------------------------------------------- +# Join tag-column nullability with mixed nullable/non-nullable tag keys +# --------------------------------------------------------------------------- + + +class TestJoinTagColumnNullability: + """Join must preserve the exact nullable flag of every tag column — + both non-nullable mandatory keys and nullable optional keys — through + the Polars inner join used internally.""" + + def test_shared_tag_columns_mixed_nullability_preserved(self): + """When two sources share multiple tag columns with mixed nullable flags, + each flag is preserved correctly after the join. 
+ + Schema intent: + - "id" int64 nullable=False (mandatory join key) + - "group" utf8 nullable=True (Optional grouping key, no actual nulls) + """ + tag_schema = pa.schema([ + pa.field("id", pa.int64(), nullable=False), + pa.field("group", pa.utf8(), nullable=True), + ]) + schema1 = pa.schema([ + *tag_schema, + pa.field("x", pa.int64(), nullable=False), + ]) + schema2 = pa.schema([ + *tag_schema, + pa.field("y", pa.int64(), nullable=False), + ]) + + source1 = op.sources.DictSource( + [{"id": 1, "group": "a", "x": 10}, {"id": 2, "group": "b", "x": 20}], + tag_columns=["id", "group"], + data_schema=schema1, + ) + source2 = op.sources.DictSource( + [{"id": 1, "group": "a", "y": 100}, {"id": 2, "group": "b", "y": 200}], + tag_columns=["id", "group"], + data_schema=schema2, + ) + + table = source1.join(source2).as_table() + + id_field = table.schema.field("id") + group_field = table.schema.field("group") + + assert id_field.nullable is False, ( + f"'id' (non-nullable tag) must remain nullable=False after Join, " + f"got nullable={id_field.nullable}." + ) + assert group_field.nullable is True, ( + f"'group' (Optional tag, nullable=True) must remain nullable=True after Join " + f"even though data contains no actual nulls, got nullable={group_field.nullable}." + ) + + def test_non_shared_tag_columns_mixed_nullability_preserved(self): + """Tag columns that are unique to each side of a join (non-shared) also + preserve their nullable flags in the combined result. + + source1 has tag "id" (non-nullable int). + source2 has tag "category" (nullable string, Optional, no actual nulls). + Neither tag is shared, so the join is a full cartesian product. + Both tag columns appear in the result and must keep their original nullable flags. 
+ """ + schema1 = pa.schema([ + pa.field("id", pa.int64(), nullable=False), + pa.field("x", pa.int64(), nullable=False), + ]) + schema2 = pa.schema([ + pa.field("category", pa.utf8(), nullable=True), # Optional tag + pa.field("y", pa.int64(), nullable=False), + ]) + + source1 = op.sources.DictSource( + [{"id": 1, "x": 10}], + tag_columns=["id"], + data_schema=schema1, + ) + source2 = op.sources.DictSource( + [{"category": "alpha", "y": 100}], + tag_columns=["category"], + data_schema=schema2, + ) + + table = source1.join(source2).as_table() + + id_field = table.schema.field("id") + category_field = table.schema.field("category") + + assert id_field.nullable is False, ( + f"'id' (non-nullable tag from source1) must remain nullable=False after " + f"cartesian join, got nullable={id_field.nullable}." + ) + assert category_field.nullable is True, ( + f"'category' (Optional tag, nullable=True from source2) must remain " + f"nullable=True after cartesian join even with no actual nulls, " + f"got nullable={category_field.nullable}." + ) + + def test_three_way_join_tag_nullability_preserved(self): + """A three-way join (two Polars join iterations) correctly restores nullable + flags on all tag columns across both iterations. 
+ + shared tag "id" int64 nullable=False + shared tag "group" utf8 nullable=True (Optional, no actual nulls) + """ + tag_schema = pa.schema([ + pa.field("id", pa.int64(), nullable=False), + pa.field("group", pa.utf8(), nullable=True), + ]) + schema1 = pa.schema([*tag_schema, pa.field("a", pa.int64(), nullable=False)]) + schema2 = pa.schema([*tag_schema, pa.field("b", pa.int64(), nullable=True)]) # b is Optional + schema3 = pa.schema([*tag_schema, pa.field("c", pa.int64(), nullable=False)]) + + source1 = op.sources.DictSource( + [{"id": 1, "group": "x", "a": 1}], + tag_columns=["id", "group"], + data_schema=schema1, + ) + source2 = op.sources.DictSource( + [{"id": 1, "group": "x", "b": 2}], + tag_columns=["id", "group"], + data_schema=schema2, + ) + source3 = op.sources.DictSource( + [{"id": 1, "group": "x", "c": 3}], + tag_columns=["id", "group"], + data_schema=schema3, + ) + + table = source1.join(source2).join(source3).as_table() + + id_field = table.schema.field("id") + group_field = table.schema.field("group") + b_field = table.schema.field("b") + + assert id_field.nullable is False, ( + f"'id' (non-nullable tag) must remain nullable=False after 3-way join, " + f"got nullable={id_field.nullable}." + ) + assert group_field.nullable is True, ( + f"'group' (Optional tag) must remain nullable=True after 3-way join, " + f"got nullable={group_field.nullable}." + ) + assert b_field.nullable is True, ( + f"'b' (Optional packet column) must remain nullable=True after 3-way join, " + f"got nullable={b_field.nullable}." 
+ ) diff --git a/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py b/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py new file mode 100644 index 00000000..4c24d0da --- /dev/null +++ b/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py @@ -0,0 +1,185 @@ +""" +These tests document that Polars round-trips and joins widen all Arrow field +nullability to nullable=True, and verify that restore_schema_nullability +correctly reapplies the original nullable flags from reference schemas. +""" + +import pyarrow as pa +import polars as pl + +from orcapod.utils import arrow_utils + + +class TestPolarsRoundTripLosesNullability: + """Demonstrate the root-cause: Polars always produces nullable=True.""" + + def test_polars_roundtrip_makes_non_nullable_fields_nullable(self): + """Polars DataFrame round-trip converts nullable=False to nullable=True.""" + schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("value", pa.large_string(), nullable=False), + ] + ) + table = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]}, schema=schema) + + # Precondition: original table has non-nullable fields + assert table.schema.field("id").nullable is False + assert table.schema.field("value").nullable is False + + # After Polars round-trip all fields become nullable (the known bug) + roundtrip = pl.DataFrame(table).to_arrow() + assert roundtrip.schema.field("id").nullable is True + assert roundtrip.schema.field("value").nullable is True + + def test_polars_join_makes_all_result_fields_nullable(self): + """Polars inner join result has nullable=True for all fields.""" + left_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("name", pa.large_string(), nullable=False), + ] + ) + right_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("score", pa.float64(), nullable=False), + ] + ) + left = pa.table({"id": [1, 2], "name": ["a", "b"]}, 
schema=left_schema) + right = pa.table({"id": [1, 2], "score": [9.5, 8.0]}, schema=right_schema) + + joined = ( + pl.DataFrame(left) + .join(pl.DataFrame(right), on="id", how="inner") + .to_arrow() + ) + + # Bug: all columns nullable after join + assert joined.schema.field("id").nullable is True + assert joined.schema.field("name").nullable is True + assert joined.schema.field("score").nullable is True + + +class TestRestoreSchemaNullability: + """Unit tests for arrow_utils.restore_schema_nullability.""" + + def test_restores_non_nullable_flags_after_polars_roundtrip(self): + """restore_schema_nullability fixes nullable=True caused by Polars.""" + original_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("value", pa.large_string(), nullable=False), + ] + ) + table = pa.table( + {"id": [1, 2, 3], "value": ["a", "b", "c"]}, schema=original_schema + ) + + # Simulate Polars round-trip (loses nullability) + roundtrip = pl.DataFrame(table).to_arrow() + assert roundtrip.schema.field("id").nullable is True # confirms the bug + + # Fix: restore from original schema + restored = arrow_utils.restore_schema_nullability(roundtrip, original_schema) + + assert restored.schema.field("id").nullable is False + assert restored.schema.field("value").nullable is False + + def test_preserves_data_values_after_restore(self): + """restore_schema_nullability does not alter data, only schema metadata.""" + original_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("x", pa.float64(), nullable=False), + ] + ) + table = pa.table({"id": [1, 2], "x": [1.5, 2.5]}, schema=original_schema) + roundtrip = pl.DataFrame(table).to_arrow() + + restored = arrow_utils.restore_schema_nullability(roundtrip, original_schema) + + assert restored.to_pydict() == table.to_pydict() + + def test_preserves_nullable_fields_that_are_nullable_in_reference(self): + """Fields that are nullable in the reference schema remain nullable.""" + original_schema = 
pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("optional_val", pa.large_string(), nullable=True), + ] + ) + table = pa.table( + {"id": [1, 2], "optional_val": ["x", None]}, schema=original_schema + ) + roundtrip = pl.DataFrame(table).to_arrow() + + restored = arrow_utils.restore_schema_nullability(roundtrip, original_schema) + + assert restored.schema.field("id").nullable is False + assert restored.schema.field("optional_val").nullable is True + + def test_leaves_extra_columns_as_nullable(self): + """Columns absent from the reference schema keep Polars-default nullable=True.""" + original_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + ] + ) + full_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("_exists", pa.bool_(), nullable=True), + ] + ) + table = pa.table({"id": [1, 2], "_exists": [True, False]}, schema=full_schema) + roundtrip = pl.DataFrame(table).to_arrow() + + restored = arrow_utils.restore_schema_nullability(roundtrip, original_schema) + + assert restored.schema.field("id").nullable is False + # "_exists" not in reference → left as Polars output (nullable=True) + assert restored.schema.field("_exists").nullable is True + + def test_with_multiple_reference_schemas_later_wins(self): + """When the same field name appears in multiple reference schemas, the last wins.""" + schema_a = pa.schema([pa.field("id", pa.int64(), nullable=True)]) + schema_b = pa.schema([pa.field("id", pa.int64(), nullable=False)]) + + table = pa.table({"id": [1, 2]}) + roundtrip = pl.DataFrame(table).to_arrow() + + # schema_b comes last → nullable=False wins + restored = arrow_utils.restore_schema_nullability(roundtrip, schema_a, schema_b) + assert restored.schema.field("id").nullable is False + + def test_restores_non_nullable_in_polars_join_result(self): + """restore_schema_nullability works on the output of a Polars join.""" + left_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + 
pa.field("name", pa.large_string(), nullable=False), + ] + ) + right_schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("score", pa.float64(), nullable=False), + ] + ) + left = pa.table({"id": [1, 2], "name": ["a", "b"]}, schema=left_schema) + right = pa.table({"id": [1, 2], "score": [9.5, 8.0]}, schema=right_schema) + + joined = ( + pl.DataFrame(left) + .join(pl.DataFrame(right), on="id", how="inner") + .to_arrow() + ) + + restored = arrow_utils.restore_schema_nullability( + joined, left_schema, right_schema + ) + + assert restored.schema.field("id").nullable is False + assert restored.schema.field("name").nullable is False + assert restored.schema.field("score").nullable is False diff --git a/tests/test_utils/test_arrow_utils.py b/tests/test_utils/test_arrow_utils.py index 1ddbba1c..2831b90a 100644 --- a/tests/test_utils/test_arrow_utils.py +++ b/tests/test_utils/test_arrow_utils.py @@ -4,6 +4,7 @@ import pytest from orcapod.utils.arrow_utils import ( + add_source_info, infer_schema_nullable, make_schema_non_nullable, prepare_prefixed_columns, @@ -175,3 +176,58 @@ def test_none_schema_metadata_stays_none(self): table = pa.table({"x": pa.array([1, 2], type=pa.int64())}) result = infer_schema_nullable(table) assert not result.metadata + + +# --------------------------------------------------------------------------- +# add_source_info +# --------------------------------------------------------------------------- + + +class TestAddSourceInfo: + """add_source_info must produce one _source_{col} column per data column + whose values are exactly '{source_info}::{col}' for every row.""" + + def test_single_column_produces_correct_source_token(self): + """_source_{col} value is '{source_info}::{col}' for every row.""" + table = pa.table({"x": pa.array([10, 20], type=pa.int64())}) + result = add_source_info(table, "mysrc") + + assert result.column("_source_x").to_pylist() == ["mysrc::x", "mysrc::x"] + + def test_multi_column_each_source_column_uses_its_own_name(self): + """Each _source_{col} 
value is '{source_info}::{col}' — columns are independent.""" + table = pa.table({ + "x": pa.array([1, 2], type=pa.int64()), + "y": pa.array([3, 4], type=pa.int64()), + "z": pa.array([5, 6], type=pa.int64()), + }) + result = add_source_info(table, "base") + + assert result.column("_source_x").to_pylist() == ["base::x", "base::x"] + assert result.column("_source_y").to_pylist() == ["base::y", "base::y"] + assert result.column("_source_z").to_pylist() == ["base::z", "base::z"] + + def test_per_row_source_tokens_combined_with_column_name(self): + """With a per-row source list, each row's token is '{row_source}::{col}'.""" + table = pa.table({ + "a": pa.array([10, 20], type=pa.int64()), + "b": pa.array([30, 40], type=pa.int64()), + }) + result = add_source_info(table, ["src0", "src1"]) + + assert result.column("_source_a").to_pylist() == ["src0::a", "src1::a"] + assert result.column("_source_b").to_pylist() == ["src0::b", "src1::b"] + + def test_correct_number_of_source_columns_added(self): + """Exactly one _source_{col} column is added per non-excluded data column.""" + table = pa.table({ + "p": pa.array([1], type=pa.int64()), + "q": pa.array([2], type=pa.int64()), + "r": pa.array([3], type=pa.int64()), + }) + result = add_source_info(table, "s") + + assert result.num_columns == 6 # 3 data + 3 source + assert "_source_p" in result.column_names + assert "_source_q" in result.column_names + assert "_source_r" in result.column_names