diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py
index f9964106..95bceed3 100644
--- a/src/orcapod/core/function_pod.py
+++ b/src/orcapod/core/function_pod.py
@@ -649,11 +649,13 @@ def as_table(
if column_config.sort_by_tags:
# TODO: reimplement using polars natively
+ output_table_schema = output_table.schema
output_table = (
pl.DataFrame(output_table)
.sort(by=self.keys()[0], descending=False)
.to_arrow()
)
+ output_table = arrow_utils.restore_schema_nullability(output_table, output_table_schema)
# output_table = output_table.sort_by(
# [(column, "ascending") for column in self.keys()[0]]
# )
diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py
index f4bf4b05..a360ad0e 100644
--- a/src/orcapod/core/nodes/function_node.py
+++ b/src/orcapod/core/nodes/function_node.py
@@ -743,7 +743,8 @@ def get_cached_results(
return {}
taginfo = self._filter_by_content_hash(taginfo)
-
+ taginfo_schema = taginfo.schema
+ results_schema = results.schema
filtered = (
pl.DataFrame(taginfo)
.join(
@@ -754,6 +755,7 @@ def get_cached_results(
.filter(pl.col(PIPELINE_ENTRY_ID_COL).is_in(entry_ids))
.to_arrow()
)
+ filtered = arrow_utils.restore_schema_nullability(filtered, taginfo_schema, results_schema)
if filtered.num_rows == 0:
return {}
@@ -987,12 +989,14 @@ def get_all_records(
return None
taginfo = self._filter_by_content_hash(taginfo)
-
+ taginfo_schema = taginfo.schema
+ results_schema = results.schema
joined = (
pl.DataFrame(taginfo)
.join(pl.DataFrame(results), on=constants.PACKET_RECORD_ID, how="inner")
.to_arrow()
)
+ joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema)
column_config = ColumnConfig.handle_config(columns, all_info=all_info)
@@ -1072,7 +1076,8 @@ def _load_all_cached_records(
return None
taginfo = self._filter_by_content_hash(taginfo)
-
+ taginfo_schema = taginfo.schema
+ results_schema = results.schema
joined = (
pl.DataFrame(taginfo)
.join(
@@ -1082,6 +1087,7 @@ def _load_all_cached_records(
)
.to_arrow()
)
+ joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema)
if joined.num_rows == 0:
return None
@@ -1202,6 +1208,8 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]:
if taginfo is not None and results is not None:
taginfo = self._filter_by_content_hash(taginfo)
+ taginfo_schema = taginfo.schema
+ results_schema = results.schema
joined = (
pl.DataFrame(taginfo)
.join(
@@ -1211,6 +1219,7 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]:
)
.to_arrow()
)
+ joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema)
if joined.num_rows > 0:
tag_keys = self._input_stream.keys()[0]
# Collect pipeline entry_ids for Phase 2 skip check
@@ -1435,11 +1444,13 @@ def as_table(
)
if column_config.sort_by_tags:
+ output_table_schema = output_table.schema
output_table = (
pl.DataFrame(output_table)
.sort(by=self.keys()[0], descending=False)
.to_arrow()
)
+ output_table = arrow_utils.restore_schema_nullability(output_table, output_table_schema)
return output_table
# ------------------------------------------------------------------
@@ -1524,6 +1535,8 @@ async def async_execute(
if taginfo is not None and results is not None:
taginfo = self._filter_by_content_hash(taginfo)
+ taginfo_schema = taginfo.schema
+ results_schema = results.schema
joined = (
pl.DataFrame(taginfo)
.join(
@@ -1533,6 +1546,7 @@ async def async_execute(
)
.to_arrow()
)
+ joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema)
if joined.num_rows > 0:
tag_keys = self._input_stream.keys()[0]
entry_ids_col = joined.column(
diff --git a/src/orcapod/core/operators/join.py b/src/orcapod/core/operators/join.py
index f4ddf479..7937034b 100644
--- a/src/orcapod/core/operators/join.py
+++ b/src/orcapod/core/operators/join.py
@@ -171,17 +171,33 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol:
counter += 1
new_name = f"{col}_{counter}"
rename_map[col] = new_name
+
+ # Build a reference schema for next_table with rename_map applied to
+ # field names, preserving nullable flags — must be done BEFORE the
+ # rename so we capture the original schema.
+ next_ref_schema = pa.schema([
+ pa.field(rename_map.get(f.name, f.name), f.type, nullable=f.nullable, metadata=f.metadata)
+ for f in next_table.schema
+ ])
+
if rename_map:
- next_table = pl.DataFrame(next_table).rename(rename_map).to_arrow()
+ # Use Arrow-native rename to avoid an unnecessary Polars round-trip.
+ next_table = next_table.rename_columns(
+ [rename_map.get(name, name) for name in next_table.column_names]
+ )
common_tag_keys = tag_keys.intersection(next_tag_keys)
common_tag_keys.add(COMMON_JOIN_KEY)
+ # Capture the left-side schema before the Polars join, which sets all
+ # fields to nullable=True regardless of the original schema.
+ table_ref_schema = table.schema
table = (
pl.DataFrame(table)
.join(pl.DataFrame(next_table), on=list(common_tag_keys), how="inner")
.to_arrow()
)
+ table = arrow_utils.restore_schema_nullability(table, table_ref_schema, next_ref_schema)
tag_keys.update(next_tag_keys)
@@ -196,13 +212,6 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol:
reordered_columns += [col for col in table.column_names if col not in tag_keys]
result_table = table.select(reordered_columns)
- # Derive nullable per column from actual null counts so that:
- # - Columns with no nulls (e.g. tag/packet fields after inner join) get
- # nullable=False, avoiding spurious T | None from Polars' all-True default.
- # - Columns that genuinely contain nulls (e.g. Optional fields, or source
- # info columns after cross-stream joins) keep nullable=True, preventing
- # cast failures.
- result_table = result_table.cast(arrow_utils.infer_schema_nullable(result_table))
return ArrowTableStream(
result_table,
tag_columns=tuple(tag_keys),
diff --git a/src/orcapod/utils/arrow_utils.py b/src/orcapod/utils/arrow_utils.py
index 9307cf0d..13ba410c 100644
--- a/src/orcapod/utils/arrow_utils.py
+++ b/src/orcapod/utils/arrow_utils.py
@@ -831,6 +831,77 @@ def infer_schema_nullable(table: "pa.Table") -> "pa.Schema":
)
+def restore_schema_nullability(
+ table: "pa.Table",
+ *reference_schemas: "pa.Schema",
+) -> "pa.Table":
+ """Restore nullable flags lost during an Arrow → Polars → Arrow round-trip.
+
+ Polars converts all Arrow fields to ``nullable=True`` when producing its
+ Arrow output. This function repairs the resulting table by looking up each
+ field name in the supplied reference schemas and reinstating the original
+ ``nullable`` flag (and type) from those references.
+
+ Fields that are **not** found in any reference schema are left exactly as
+ Polars produced them (``nullable=True``), which is safe for internally
+ generated sentinel columns (e.g. ``_exists``, ``__pipeline_entry_id``).
+
+ When the same field name appears in multiple reference schemas the
+ **last-supplied** schema wins, so callers can pass ``(left_schema,
+ right_schema)`` and get right-side nullability for join-key columns that
+ appear in both.
+
+ Args:
+ table: Arrow table produced by a Polars round-trip (all fields
+ ``nullable=True``).
+ *reference_schemas: One or more Arrow schemas carrying the original
+ nullable flags. Pass the schemas of every table that participated
+ in the Polars join/sort so that all columns are covered.
+
+ Returns:
+ A new Arrow table cast to the restored schema. Data is unchanged;
+ only the schema metadata (nullability, field type, field metadata)
+ is corrected.
+
+ Raises:
+ pyarrow.ArrowInvalid: If a restored field type is incompatible with
+ the actual column data (should not happen for well-formed
+ round-trips, but surfaced to the caller rather than silently
+ discarded).
+
+ Example::
+
+ taginfo_schema = taginfo.schema
+ results_schema = results.schema
+ joined = (
+ pl.DataFrame(taginfo)
+ .join(pl.DataFrame(results), on="record_id", how="inner")
+ .to_arrow()
+ )
+ joined = arrow_utils.restore_schema_nullability(
+ joined, taginfo_schema, results_schema
+ )
+ """
+ # Build a name → Field lookup; later schemas override earlier ones.
+ field_lookup: dict[str, "pa.Field"] = {}
+ for schema in reference_schemas:
+ for field in schema:
+ field_lookup[field.name] = field
+
+ restored_fields = []
+ for field in table.schema:
+ ref = field_lookup.get(field.name)
+ if ref is not None:
+ restored_fields.append(
+ pa.field(field.name, ref.type, nullable=ref.nullable, metadata=ref.metadata)
+ )
+ else:
+ restored_fields.append(field)
+
+ restored_schema = pa.schema(restored_fields, metadata=table.schema.metadata)
+ return table.cast(restored_schema)
+
+
def drop_columns_with_prefix(
table: "pa.Table",
prefix: str | tuple[str, ...],
@@ -896,8 +967,14 @@ def add_system_tag_columns(
source_id_array = pa.array(source_ids, type=pa.large_string())
record_id_array = pa.array(record_ids, type=pa.large_string())
- table = table.append_column(source_id_col_name, source_id_array)
- table = table.append_column(record_id_col_name, record_id_array)
+ # System tag columns are always computed, never null — declare nullable=False
+ # explicitly so the schema intent is not lost in Polars round-trips.
+ table = table.append_column(
+ pa.field(source_id_col_name, pa.large_string(), nullable=False), source_id_array
+ )
+ table = table.append_column(
+ pa.field(record_id_col_name, pa.large_string(), nullable=False), record_id_array
+ )
return table
@@ -1074,26 +1151,37 @@ def add_source_info(
exclude_columns: Collection[str] = (),
) -> "pa.Table":
"""Add source information to an Arrow table."""
- # Create a new column with the source information
+ # Create a base list of per-row source tokens; one entry per row.
if source_info is None or isinstance(source_info, str):
- source_column = [source_info] * table.num_rows
+ base_source = [source_info] * table.num_rows
elif isinstance(source_info, Collection):
if len(source_info) != table.num_rows:
raise ValueError(
"Length of source_info collection must match number of rows in the table."
)
- source_column = source_info
-
- # identify columns for which source columns should be created
+ base_source = list(source_info)
+ else:
+ raise TypeError(
+ f"source_info must be a str, a sized Collection[str], or None; "
+ f"got {type(source_info).__name__}"
+ )
+ # For each data column, build an independent _source_<col> column from the
+ # base tokens. We must NOT re-use the array produced for a previous column
+ # as input for the next one — doing so would accumulate column names
+ # (e.g. "src::col1::col2" instead of "src::col2").
for col in table.column_names:
if col.startswith(tuple(exclude_prefixes)) or col in exclude_columns:
continue
- source_column = pa.array(
- [f"{source_val}::{col}" for source_val in source_column],
+ col_source = pa.array(
+ [f"{source_val}::{col}" for source_val in base_source],
type=pa.large_string(),
)
- table = table.append_column(f"{constants.SOURCE_PREFIX}{col}", source_column)
+ # Source info columns are always computed strings, never null.
+ table = table.append_column(
+ pa.field(f"{constants.SOURCE_PREFIX}{col}", pa.large_string(), nullable=False),
+ col_source,
+ )
return table
diff --git a/tests/test_data/test_polars_nullability/__init__.py b/tests/test_data/test_polars_nullability/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_data/test_polars_nullability/test_function_node_nullability.py b/tests/test_data/test_polars_nullability/test_function_node_nullability.py
new file mode 100644
index 00000000..ec845c6a
--- /dev/null
+++ b/tests/test_data/test_polars_nullability/test_function_node_nullability.py
@@ -0,0 +1,387 @@
+"""
+Integration tests validating that FunctionNode and Join preserve nullable column
+constraints across the Arrow → Polars → Arrow round-trip that occurs during
+joins and cached-record retrieval.
+"""
+# ruff: noqa: E501
+
+import pyarrow as pa
+
+import orcapod as op
+from orcapod.core.nodes.function_node import FunctionNode
+from orcapod.databases import InMemoryArrowDatabase
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _get_function_nodes(pipeline: op.Pipeline) -> list[FunctionNode]:
+ """Return all FunctionNode instances from compiled pipeline nodes."""
+ return [n for n in pipeline.compiled_nodes.values() if isinstance(n, FunctionNode)]
+
+
+# ---------------------------------------------------------------------------
+# FunctionNode.get_all_records nullability
+# ---------------------------------------------------------------------------
+
+
+class TestFunctionNodeGetAllRecordsNullability:
+ """FunctionNode.get_all_records must preserve the non-nullable schema of
+ output columns whose Python type annotation is non-optional (e.g. ``int``)."""
+
+ def test_non_optional_return_type_yields_non_nullable_output_column(self):
+ """Output column from int return type must be non-nullable after get_all_records."""
+ database = InMemoryArrowDatabase()
+ source = op.sources.DictSource(
+ [{"id": 1, "x": 10}, {"id": 2, "x": 20}],
+ tag_columns=["id"],
+ )
+
+ @op.function_pod(output_keys=["result"])
+ def double(x: int) -> int:
+ return x * 2
+
+ pipeline = op.Pipeline("test_fn_nullable", database)
+ with pipeline:
+ double.pod(source)
+
+ pipeline.run()
+
+ fn_nodes = _get_function_nodes(pipeline)
+ assert len(fn_nodes) == 1, "Expected exactly one FunctionNode"
+ fn_node = fn_nodes[0]
+
+ table = fn_node.get_all_records()
+ assert table is not None, "get_all_records() returned None after pipeline.run()"
+
+ # "result" column has Python type int → nullable=False in Arrow schema
+ result_field = table.schema.field("result")
+ assert result_field.nullable is False, (
+ f"Expected 'result' column (int return type) to be non-nullable, "
+ f"but got nullable={result_field.nullable}. "
+ "Arrow→Polars→Arrow round-trip in get_all_records() dropped nullability."
+ )
+
+ def test_input_tag_column_non_nullable_after_get_all_records(self):
+ """Input tag columns that are non-nullable must remain so after Polars join."""
+ database = InMemoryArrowDatabase()
+
+ # DictSource with integer id — infer_schema_nullable sets nullable=False (no nulls)
+ source = op.sources.DictSource(
+ [{"id": 1, "x": 5}, {"id": 2, "x": 15}],
+ tag_columns=["id"],
+ )
+
+ @op.function_pod(output_keys=["result"])
+ def triple(x: int) -> int:
+ return x * 3
+
+ pipeline = op.Pipeline("test_fn_tag_nullable", database)
+ with pipeline:
+ triple.pod(source)
+
+ pipeline.run()
+
+ fn_nodes = _get_function_nodes(pipeline)
+ fn_node = fn_nodes[0]
+
+ table = fn_node.get_all_records()
+ assert table is not None
+
+ # "id" was non-nullable in the source; after the Polars join it must stay so
+ id_field = table.schema.field("id")
+ assert id_field.nullable is False, (
+ f"Expected 'id' tag column to be non-nullable after get_all_records(), "
+ f"but got nullable={id_field.nullable}."
+ )
+
+
+# ---------------------------------------------------------------------------
+# FunctionNode.iter_packets nullability
+# ---------------------------------------------------------------------------
+
+
+class TestFunctionNodeIterPacketsNullability:
+ """FunctionNode.iter_packets must yield packets whose underlying Arrow schema
+ preserves non-nullable column constraints."""
+
+ def test_iter_packets_from_database_preserves_non_nullable_output(self):
+ """Packets loaded from DB via iter_packets carry non-nullable output schema."""
+ database = InMemoryArrowDatabase()
+ source = op.sources.DictSource(
+ [{"id": 1, "x": 7}],
+ tag_columns=["id"],
+ )
+
+ @op.function_pod(output_keys=["result"])
+ def add_one(x: int) -> int:
+ return x + 1
+
+ pipeline = op.Pipeline("test_iter_packets_nullable", database)
+ with pipeline:
+ add_one.pod(source)
+
+ pipeline.run()
+
+ fn_nodes = _get_function_nodes(pipeline)
+ fn_node = fn_nodes[0]
+
+ # Force a DB-backed iteration by going through _iter_all_from_database
+ # (simulates the CACHE_ONLY path used after save/load)
+ packets_seen = list(fn_node._iter_all_from_database())
+ assert len(packets_seen) == 1, "Expected one packet from the database"
+
+ _tag, packet = packets_seen[0]
+ packet_schema = packet.arrow_schema()
+
+ result_field = packet_schema.field("result")
+ assert result_field.nullable is False, (
+ f"Packet 'result' field should be non-nullable (int return type), "
+ f"but got nullable={result_field.nullable}. "
+ "Arrow→Polars→Arrow round-trip in iter_packets dropped nullability."
+ )
+
+
+# ---------------------------------------------------------------------------
+# Join operator nullability
+# ---------------------------------------------------------------------------
+
+
+class TestJoinOperatorNullability:
+ """Join.op_forward must preserve non-nullable tag column flags through the
+ Polars inner join it uses internally."""
+
+ def test_join_preserves_non_nullable_shared_tag_column(self):
+ """Shared tag column remains non-nullable after stream join."""
+ # DictSource applies infer_schema_nullable → integer 'id' has nullable=False
+ source1 = op.sources.DictSource(
+ [{"id": 1, "x": 10}, {"id": 2, "x": 20}],
+ tag_columns=["id"],
+ )
+ source2 = op.sources.DictSource(
+ [{"id": 1, "y": 100}, {"id": 2, "y": 200}],
+ tag_columns=["id"],
+ )
+
+ joined_stream = source1.join(source2)
+ table = joined_stream.as_table()
+
+ id_field = table.schema.field("id")
+ assert id_field.nullable is False, (
+ f"Expected 'id' tag column to be non-nullable after Join, "
+ f"but got nullable={id_field.nullable}. "
+ "Arrow→Polars→Arrow round-trip in Join.op_forward dropped nullability."
+ )
+
+ def test_join_preserves_non_nullable_packet_columns(self):
+ """Packet columns that are non-nullable remain so after stream join."""
+ source1 = op.sources.DictSource(
+ [{"id": 1, "x": 10}, {"id": 2, "x": 20}],
+ tag_columns=["id"],
+ )
+ source2 = op.sources.DictSource(
+ [{"id": 1, "y": 100}, {"id": 2, "y": 200}],
+ tag_columns=["id"],
+ )
+
+ joined_stream = source1.join(source2)
+ table = joined_stream.as_table()
+
+ # "x" and "y" are packet columns with integer values → non-nullable
+ x_field = table.schema.field("x")
+ y_field = table.schema.field("y")
+
+ assert x_field.nullable is False, (
+ f"Expected 'x' packet column to be non-nullable after Join, "
+ f"but got nullable={x_field.nullable}."
+ )
+ assert y_field.nullable is False, (
+ f"Expected 'y' packet column to be non-nullable after Join, "
+ f"but got nullable={y_field.nullable}."
+ )
+
+ def test_join_preserves_nullable_optional_column_with_no_nulls(self):
+ """Optional[int] packet column (nullable=True) must remain nullable=True after
+ Join, even when the data contains no actual null values.
+
+ infer_schema_nullable incorrectly marks it nullable=False because it sees no
+ nulls in the data. restore_schema_nullability preserves schema intent.
+ """
+ # Explicitly declare "x" as nullable=True (Optional[int]) via pa.Schema
+ schema1 = pa.schema([
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("x", pa.int64(), nullable=True), # Optional — no actual nulls
+ ])
+ source1 = op.sources.DictSource(
+ [{"id": 1, "x": 10}, {"id": 2, "x": 20}],
+ tag_columns=["id"],
+ data_schema=schema1,
+ )
+ source2 = op.sources.DictSource(
+ [{"id": 1, "y": 100}, {"id": 2, "y": 200}],
+ tag_columns=["id"],
+ )
+
+ joined_stream = source1.join(source2)
+ table = joined_stream.as_table()
+
+ x_field = table.schema.field("x")
+ assert x_field.nullable is True, (
+ f"Expected 'x' (Optional[int], nullable=True in source schema) to remain "
+ f"nullable=True after Join, but got nullable={x_field.nullable}. "
+ "infer_schema_nullable incorrectly set it to nullable=False (no actual nulls "
+ "in data), ignoring schema intent."
+ )
+
+
+# ---------------------------------------------------------------------------
+# Join tag-column nullability with mixed nullable/non-nullable tag keys
+# ---------------------------------------------------------------------------
+
+
+class TestJoinTagColumnNullability:
+ """Join must preserve the exact nullable flag of every tag column —
+ both non-nullable mandatory keys and nullable optional keys — through
+ the Polars inner join used internally."""
+
+ def test_shared_tag_columns_mixed_nullability_preserved(self):
+ """When two sources share multiple tag columns with mixed nullable flags,
+ each flag is preserved correctly after the join.
+
+ Schema intent:
+ - "id" int64 nullable=False (mandatory join key)
+ - "group" utf8 nullable=True (Optional grouping key, no actual nulls)
+ """
+ tag_schema = pa.schema([
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("group", pa.utf8(), nullable=True),
+ ])
+ schema1 = pa.schema([
+ *tag_schema,
+ pa.field("x", pa.int64(), nullable=False),
+ ])
+ schema2 = pa.schema([
+ *tag_schema,
+ pa.field("y", pa.int64(), nullable=False),
+ ])
+
+ source1 = op.sources.DictSource(
+ [{"id": 1, "group": "a", "x": 10}, {"id": 2, "group": "b", "x": 20}],
+ tag_columns=["id", "group"],
+ data_schema=schema1,
+ )
+ source2 = op.sources.DictSource(
+ [{"id": 1, "group": "a", "y": 100}, {"id": 2, "group": "b", "y": 200}],
+ tag_columns=["id", "group"],
+ data_schema=schema2,
+ )
+
+ table = source1.join(source2).as_table()
+
+ id_field = table.schema.field("id")
+ group_field = table.schema.field("group")
+
+ assert id_field.nullable is False, (
+ f"'id' (non-nullable tag) must remain nullable=False after Join, "
+ f"got nullable={id_field.nullable}."
+ )
+ assert group_field.nullable is True, (
+ f"'group' (Optional tag, nullable=True) must remain nullable=True after Join "
+ f"even though data contains no actual nulls, got nullable={group_field.nullable}."
+ )
+
+ def test_non_shared_tag_columns_mixed_nullability_preserved(self):
+ """Tag columns that are unique to each side of a join (non-shared) also
+ preserve their nullable flags in the combined result.
+
+ source1 has tag "id" (non-nullable int).
+ source2 has tag "category" (nullable string, Optional, no actual nulls).
+ Neither tag is shared, so the join is a full cartesian product.
+ Both tag columns appear in the result and must keep their original nullable flags.
+ """
+ schema1 = pa.schema([
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("x", pa.int64(), nullable=False),
+ ])
+ schema2 = pa.schema([
+ pa.field("category", pa.utf8(), nullable=True), # Optional tag
+ pa.field("y", pa.int64(), nullable=False),
+ ])
+
+ source1 = op.sources.DictSource(
+ [{"id": 1, "x": 10}],
+ tag_columns=["id"],
+ data_schema=schema1,
+ )
+ source2 = op.sources.DictSource(
+ [{"category": "alpha", "y": 100}],
+ tag_columns=["category"],
+ data_schema=schema2,
+ )
+
+ table = source1.join(source2).as_table()
+
+ id_field = table.schema.field("id")
+ category_field = table.schema.field("category")
+
+ assert id_field.nullable is False, (
+ f"'id' (non-nullable tag from source1) must remain nullable=False after "
+ f"cartesian join, got nullable={id_field.nullable}."
+ )
+ assert category_field.nullable is True, (
+ f"'category' (Optional tag, nullable=True from source2) must remain "
+ f"nullable=True after cartesian join even with no actual nulls, "
+ f"got nullable={category_field.nullable}."
+ )
+
+ def test_three_way_join_tag_nullability_preserved(self):
+ """A three-way join (two Polars join iterations) correctly restores nullable
+ flags on all tag columns across both iterations.
+
+ shared tag "id" int64 nullable=False
+ shared tag "group" utf8 nullable=True (Optional, no actual nulls)
+ """
+ tag_schema = pa.schema([
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("group", pa.utf8(), nullable=True),
+ ])
+ schema1 = pa.schema([*tag_schema, pa.field("a", pa.int64(), nullable=False)])
+ schema2 = pa.schema([*tag_schema, pa.field("b", pa.int64(), nullable=True)]) # b is Optional
+ schema3 = pa.schema([*tag_schema, pa.field("c", pa.int64(), nullable=False)])
+
+ source1 = op.sources.DictSource(
+ [{"id": 1, "group": "x", "a": 1}],
+ tag_columns=["id", "group"],
+ data_schema=schema1,
+ )
+ source2 = op.sources.DictSource(
+ [{"id": 1, "group": "x", "b": 2}],
+ tag_columns=["id", "group"],
+ data_schema=schema2,
+ )
+ source3 = op.sources.DictSource(
+ [{"id": 1, "group": "x", "c": 3}],
+ tag_columns=["id", "group"],
+ data_schema=schema3,
+ )
+
+ table = source1.join(source2).join(source3).as_table()
+
+ id_field = table.schema.field("id")
+ group_field = table.schema.field("group")
+ b_field = table.schema.field("b")
+
+ assert id_field.nullable is False, (
+ f"'id' (non-nullable tag) must remain nullable=False after 3-way join, "
+ f"got nullable={id_field.nullable}."
+ )
+ assert group_field.nullable is True, (
+ f"'group' (Optional tag) must remain nullable=True after 3-way join, "
+ f"got nullable={group_field.nullable}."
+ )
+ assert b_field.nullable is True, (
+ f"'b' (Optional packet column) must remain nullable=True after 3-way join, "
+ f"got nullable={b_field.nullable}."
+ )
diff --git a/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py b/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py
new file mode 100644
index 00000000..4c24d0da
--- /dev/null
+++ b/tests/test_data/test_polars_nullability/test_restore_schema_nullability.py
@@ -0,0 +1,185 @@
+"""
+These tests document that Polars round-trips and joins widen all Arrow field
+nullability to nullable=True, and verify that restore_schema_nullability
+correctly reapplies the original nullable flags from reference schemas.
+"""
+
+import pyarrow as pa
+import polars as pl
+
+from orcapod.utils import arrow_utils
+
+
+class TestPolarsRoundTripLosesNullability:
+ """Demonstrate the root-cause: Polars always produces nullable=True."""
+
+ def test_polars_roundtrip_makes_non_nullable_fields_nullable(self):
+ """Polars DataFrame round-trip converts nullable=False to nullable=True."""
+ schema = pa.schema(
+ [
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("value", pa.large_string(), nullable=False),
+ ]
+ )
+ table = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]}, schema=schema)
+
+ # Precondition: original table has non-nullable fields
+ assert table.schema.field("id").nullable is False
+ assert table.schema.field("value").nullable is False
+
+ # After Polars round-trip all fields become nullable (the known bug)
+ roundtrip = pl.DataFrame(table).to_arrow()
+ assert roundtrip.schema.field("id").nullable is True
+ assert roundtrip.schema.field("value").nullable is True
+
+ def test_polars_join_makes_all_result_fields_nullable(self):
+ """Polars inner join result has nullable=True for all fields."""
+ left_schema = pa.schema(
+ [
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("name", pa.large_string(), nullable=False),
+ ]
+ )
+ right_schema = pa.schema(
+ [
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("score", pa.float64(), nullable=False),
+ ]
+ )
+ left = pa.table({"id": [1, 2], "name": ["a", "b"]}, schema=left_schema)
+ right = pa.table({"id": [1, 2], "score": [9.5, 8.0]}, schema=right_schema)
+
+ joined = (
+ pl.DataFrame(left)
+ .join(pl.DataFrame(right), on="id", how="inner")
+ .to_arrow()
+ )
+
+ # Bug: all columns nullable after join
+ assert joined.schema.field("id").nullable is True
+ assert joined.schema.field("name").nullable is True
+ assert joined.schema.field("score").nullable is True
+
+
+class TestRestoreSchemaNullability:
+ """Unit tests for arrow_utils.restore_schema_nullability."""
+
+ def test_restores_non_nullable_flags_after_polars_roundtrip(self):
+ """restore_schema_nullability fixes nullable=True caused by Polars."""
+ original_schema = pa.schema(
+ [
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("value", pa.large_string(), nullable=False),
+ ]
+ )
+ table = pa.table(
+ {"id": [1, 2, 3], "value": ["a", "b", "c"]}, schema=original_schema
+ )
+
+ # Simulate Polars round-trip (loses nullability)
+ roundtrip = pl.DataFrame(table).to_arrow()
+ assert roundtrip.schema.field("id").nullable is True # confirms the bug
+
+ # Fix: restore from original schema
+ restored = arrow_utils.restore_schema_nullability(roundtrip, original_schema)
+
+ assert restored.schema.field("id").nullable is False
+ assert restored.schema.field("value").nullable is False
+
+ def test_preserves_data_values_after_restore(self):
+ """restore_schema_nullability does not alter data, only schema metadata."""
+ original_schema = pa.schema(
+ [
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("x", pa.float64(), nullable=False),
+ ]
+ )
+ table = pa.table({"id": [1, 2], "x": [1.5, 2.5]}, schema=original_schema)
+ roundtrip = pl.DataFrame(table).to_arrow()
+
+ restored = arrow_utils.restore_schema_nullability(roundtrip, original_schema)
+
+ assert restored.to_pydict() == table.to_pydict()
+
+ def test_preserves_nullable_fields_that_are_nullable_in_reference(self):
+ """Fields that are nullable in the reference schema remain nullable."""
+ original_schema = pa.schema(
+ [
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("optional_val", pa.large_string(), nullable=True),
+ ]
+ )
+ table = pa.table(
+ {"id": [1, 2], "optional_val": ["x", None]}, schema=original_schema
+ )
+ roundtrip = pl.DataFrame(table).to_arrow()
+
+ restored = arrow_utils.restore_schema_nullability(roundtrip, original_schema)
+
+ assert restored.schema.field("id").nullable is False
+ assert restored.schema.field("optional_val").nullable is True
+
+ def test_leaves_extra_columns_as_nullable(self):
+ """Columns absent from the reference schema keep Polars-default nullable=True."""
+ original_schema = pa.schema(
+ [
+ pa.field("id", pa.int64(), nullable=False),
+ ]
+ )
+ full_schema = pa.schema(
+ [
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("_exists", pa.bool_(), nullable=True),
+ ]
+ )
+ table = pa.table({"id": [1, 2], "_exists": [True, False]}, schema=full_schema)
+ roundtrip = pl.DataFrame(table).to_arrow()
+
+ restored = arrow_utils.restore_schema_nullability(roundtrip, original_schema)
+
+ assert restored.schema.field("id").nullable is False
+ # "_exists" not in reference → left as Polars output (nullable=True)
+ assert restored.schema.field("_exists").nullable is True
+
+ def test_with_multiple_reference_schemas_later_wins(self):
+ """When the same field name appears in multiple reference schemas, the last wins."""
+ schema_a = pa.schema([pa.field("id", pa.int64(), nullable=True)])
+ schema_b = pa.schema([pa.field("id", pa.int64(), nullable=False)])
+
+ table = pa.table({"id": [1, 2]})
+ roundtrip = pl.DataFrame(table).to_arrow()
+
+ # schema_b comes last → nullable=False wins
+ restored = arrow_utils.restore_schema_nullability(roundtrip, schema_a, schema_b)
+ assert restored.schema.field("id").nullable is False
+
+ def test_restores_non_nullable_in_polars_join_result(self):
+ """restore_schema_nullability works on the output of a Polars join."""
+ left_schema = pa.schema(
+ [
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("name", pa.large_string(), nullable=False),
+ ]
+ )
+ right_schema = pa.schema(
+ [
+ pa.field("id", pa.int64(), nullable=False),
+ pa.field("score", pa.float64(), nullable=False),
+ ]
+ )
+ left = pa.table({"id": [1, 2], "name": ["a", "b"]}, schema=left_schema)
+ right = pa.table({"id": [1, 2], "score": [9.5, 8.0]}, schema=right_schema)
+
+ joined = (
+ pl.DataFrame(left)
+ .join(pl.DataFrame(right), on="id", how="inner")
+ .to_arrow()
+ )
+
+ restored = arrow_utils.restore_schema_nullability(
+ joined, left_schema, right_schema
+ )
+
+ assert restored.schema.field("id").nullable is False
+ assert restored.schema.field("name").nullable is False
+ assert restored.schema.field("score").nullable is False
diff --git a/tests/test_utils/test_arrow_utils.py b/tests/test_utils/test_arrow_utils.py
index 1ddbba1c..2831b90a 100644
--- a/tests/test_utils/test_arrow_utils.py
+++ b/tests/test_utils/test_arrow_utils.py
@@ -4,6 +4,7 @@
import pytest
from orcapod.utils.arrow_utils import (
+ add_source_info,
infer_schema_nullable,
make_schema_non_nullable,
prepare_prefixed_columns,
@@ -175,3 +176,58 @@ def test_none_schema_metadata_stays_none(self):
table = pa.table({"x": pa.array([1, 2], type=pa.int64())})
result = infer_schema_nullable(table)
assert not result.metadata
+
+
+# ---------------------------------------------------------------------------
+# add_source_info
+# ---------------------------------------------------------------------------
+
+
+class TestAddSourceInfo:
+ """add_source_info must produce one _source_<col> column per data column
+ whose values are exactly '<source>::<col>' for every row."""
+
+ def test_single_column_produces_correct_source_token(self):
+ """_source_ value is '::' for every row."""
+ table = pa.table({"x": pa.array([10, 20], type=pa.int64())})
+ result = add_source_info(table, "mysrc")
+
+ assert result.column("_source_x").to_pylist() == ["mysrc::x", "mysrc::x"]
+
+ def test_multi_column_each_source_column_uses_its_own_name(self):
+ """Each _source_<col> value is '<source>::<col>' — columns are independent."""
+ table = pa.table({
+ "x": pa.array([1, 2], type=pa.int64()),
+ "y": pa.array([3, 4], type=pa.int64()),
+ "z": pa.array([5, 6], type=pa.int64()),
+ })
+ result = add_source_info(table, "base")
+
+ assert result.column("_source_x").to_pylist() == ["base::x", "base::x"]
+ assert result.column("_source_y").to_pylist() == ["base::y", "base::y"]
+ assert result.column("_source_z").to_pylist() == ["base::z", "base::z"]
+
+ def test_per_row_source_tokens_combined_with_column_name(self):
+ """With a per-row source list, each row's token is '<source[i]>::<col>'."""
+ table = pa.table({
+ "a": pa.array([10, 20], type=pa.int64()),
+ "b": pa.array([30, 40], type=pa.int64()),
+ })
+ result = add_source_info(table, ["src0", "src1"])
+
+ assert result.column("_source_a").to_pylist() == ["src0::a", "src1::a"]
+ assert result.column("_source_b").to_pylist() == ["src0::b", "src1::b"]
+
+ def test_correct_number_of_source_columns_added(self):
+ """Exactly one _source_<col> column is added per non-excluded data column."""
+ table = pa.table({
+ "p": pa.array([1], type=pa.int64()),
+ "q": pa.array([2], type=pa.int64()),
+ "r": pa.array([3], type=pa.int64()),
+ })
+ result = add_source_info(table, "s")
+
+ assert result.num_columns == 6 # 3 data + 3 source
+ assert "_source_p" in result.column_names
+ assert "_source_q" in result.column_names
+ assert "_source_r" in result.column_names