Skip to content
2 changes: 2 additions & 0 deletions src/orcapod/core/function_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,11 +649,13 @@ def as_table(

if column_config.sort_by_tags:
# TODO: reimplement using polars natively
output_table_schema = output_table.schema
output_table = (
pl.DataFrame(output_table)
.sort(by=self.keys()[0], descending=False)
.to_arrow()
)
output_table = arrow_utils.restore_schema_nullability(output_table, output_table_schema)
# output_table = output_table.sort_by(
# [(column, "ascending") for column in self.keys()[0]]
# )
Expand Down
20 changes: 17 additions & 3 deletions src/orcapod/core/nodes/function_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,8 @@ def get_cached_results(
return {}

taginfo = self._filter_by_content_hash(taginfo)

taginfo_schema = taginfo.schema
results_schema = results.schema
filtered = (
pl.DataFrame(taginfo)
.join(
Expand All @@ -754,6 +755,7 @@ def get_cached_results(
.filter(pl.col(PIPELINE_ENTRY_ID_COL).is_in(entry_ids))
.to_arrow()
)
filtered = arrow_utils.restore_schema_nullability(filtered, taginfo_schema, results_schema)

if filtered.num_rows == 0:
return {}
Expand Down Expand Up @@ -987,12 +989,14 @@ def get_all_records(
return None

taginfo = self._filter_by_content_hash(taginfo)

taginfo_schema = taginfo.schema
results_schema = results.schema
joined = (
pl.DataFrame(taginfo)
.join(pl.DataFrame(results), on=constants.PACKET_RECORD_ID, how="inner")
.to_arrow()
)
joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema)

column_config = ColumnConfig.handle_config(columns, all_info=all_info)

Expand Down Expand Up @@ -1072,7 +1076,8 @@ def _load_all_cached_records(
return None

taginfo = self._filter_by_content_hash(taginfo)

taginfo_schema = taginfo.schema
results_schema = results.schema
joined = (
pl.DataFrame(taginfo)
.join(
Expand All @@ -1082,6 +1087,7 @@ def _load_all_cached_records(
)
.to_arrow()
)
joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema)

if joined.num_rows == 0:
return None
Expand Down Expand Up @@ -1202,6 +1208,8 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]:

if taginfo is not None and results is not None:
taginfo = self._filter_by_content_hash(taginfo)
taginfo_schema = taginfo.schema
results_schema = results.schema
joined = (
pl.DataFrame(taginfo)
.join(
Expand All @@ -1211,6 +1219,7 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]:
)
.to_arrow()
)
joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema)
if joined.num_rows > 0:
tag_keys = self._input_stream.keys()[0]
# Collect pipeline entry_ids for Phase 2 skip check
Expand Down Expand Up @@ -1435,11 +1444,13 @@ def as_table(
)

if column_config.sort_by_tags:
output_table_schema = output_table.schema
output_table = (
pl.DataFrame(output_table)
.sort(by=self.keys()[0], descending=False)
.to_arrow()
)
output_table = arrow_utils.restore_schema_nullability(output_table, output_table_schema)
Comment on lines 1446 to +1453
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR restores nullability after Polars .sort(...).to_arrow() in FunctionNode.as_table, but there is another Polars sort round-trip in FunctionPodStream.as_table (src/orcapod/core/function_pod.py, around the column_config.sort_by_tags block) that still drops nullable flags. Consider applying restore_schema_nullability there too for consistent behavior across stream/table materialization paths.

Copilot uses AI. Check for mistakes.
return output_table

# ------------------------------------------------------------------
Expand Down Expand Up @@ -1524,6 +1535,8 @@ async def async_execute(

if taginfo is not None and results is not None:
taginfo = self._filter_by_content_hash(taginfo)
taginfo_schema = taginfo.schema
results_schema = results.schema
joined = (
pl.DataFrame(taginfo)
.join(
Expand All @@ -1533,6 +1546,7 @@ async def async_execute(
)
.to_arrow()
)
joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema)
if joined.num_rows > 0:
tag_keys = self._input_stream.keys()[0]
entry_ids_col = joined.column(
Expand Down
25 changes: 17 additions & 8 deletions src/orcapod/core/operators/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,33 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol:
counter += 1
new_name = f"{col}_{counter}"
rename_map[col] = new_name

# Build a reference schema for next_table with rename_map applied to
# field names, preserving nullable flags — must be done BEFORE the
# rename so we capture the original schema.
next_ref_schema = pa.schema([
pa.field(rename_map.get(f.name, f.name), f.type, nullable=f.nullable, metadata=f.metadata)
for f in next_table.schema
])

if rename_map:
next_table = pl.DataFrame(next_table).rename(rename_map).to_arrow()
# Use Arrow-native rename to avoid an unnecessary Polars round-trip.
next_table = next_table.rename_columns(
[rename_map.get(name, name) for name in next_table.column_names]
)

common_tag_keys = tag_keys.intersection(next_tag_keys)
common_tag_keys.add(COMMON_JOIN_KEY)

# Capture the left-side schema before the Polars join, which sets all
# fields to nullable=True regardless of the original schema.
table_ref_schema = table.schema
table = (
pl.DataFrame(table)
.join(pl.DataFrame(next_table), on=list(common_tag_keys), how="inner")
.to_arrow()
)
table = arrow_utils.restore_schema_nullability(table, table_ref_schema, next_ref_schema)

tag_keys.update(next_tag_keys)

Expand All @@ -196,13 +212,6 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol:
reordered_columns += [col for col in table.column_names if col not in tag_keys]

result_table = table.select(reordered_columns)
# Derive nullable per column from actual null counts so that:
# - Columns with no nulls (e.g. tag/packet fields after inner join) get
# nullable=False, avoiding spurious T | None from Polars' all-True default.
# - Columns that genuinely contain nulls (e.g. Optional fields, or source
# info columns after cross-stream joins) keep nullable=True, preventing
# cast failures.
result_table = result_table.cast(arrow_utils.infer_schema_nullable(result_table))
return ArrowTableStream(
result_table,
tag_columns=tuple(tag_keys),
Expand Down
108 changes: 98 additions & 10 deletions src/orcapod/utils/arrow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,77 @@ def infer_schema_nullable(table: "pa.Table") -> "pa.Schema":
)


def restore_schema_nullability(
    table: "pa.Table",
    *reference_schemas: "pa.Schema",
) -> "pa.Table":
    """Reinstate field nullability erased by a Polars round-trip.

    Converting an Arrow table through Polars (``pl.DataFrame(t)...to_arrow()``)
    marks every output field ``nullable=True``. Given the schemas of the
    tables that fed the round-trip, this function rebuilds the output schema
    by replacing each field (type, ``nullable`` flag, and field metadata) with
    its counterpart from the references, matched by name.

    Lookup rules:
        * A field name present in several reference schemas resolves to the
          **last** schema supplied, so ``(left_schema, right_schema)`` yields
          right-side nullability for shared join keys.
        * A field absent from every reference (e.g. internally generated
          columns such as ``_exists`` or ``__pipeline_entry_id``) is kept
          exactly as Polars emitted it, i.e. ``nullable=True``.

    Args:
        table: Arrow table freshly produced by ``to_arrow()``.
        *reference_schemas: Schemas of the original input tables, carrying
            the authoritative nullable flags.

    Returns:
        ``table`` cast to the repaired schema — data untouched, only schema
        metadata (nullability, type, field metadata) corrected.

    Raises:
        pyarrow.ArrowInvalid: When a reference type cannot represent the
            actual column data. Deliberately propagated instead of being
            swallowed, since it signals a malformed round-trip.

    Example::

        taginfo_schema = taginfo.schema
        results_schema = results.schema
        joined = (
            pl.DataFrame(taginfo)
            .join(pl.DataFrame(results), on="record_id", how="inner")
            .to_arrow()
        )
        joined = arrow_utils.restore_schema_nullability(
            joined, taginfo_schema, results_schema
        )
    """
    # Flatten all references into one name → Field map; because dict
    # insertion order makes later assignments win, the last-supplied
    # schema takes precedence for duplicate names.
    lookup: dict[str, "pa.Field"] = {
        field.name: field
        for schema in reference_schemas
        for field in schema
    }

    # Rebuild each output field from its reference twin when one exists;
    # otherwise keep the Polars-produced field untouched.
    repaired = [
        pa.field(f.name, ref.type, nullable=ref.nullable, metadata=ref.metadata)
        if (ref := lookup.get(f.name)) is not None
        else f
        for f in table.schema
    ]

    # Preserve table-level schema metadata while swapping in repaired fields.
    return table.cast(pa.schema(repaired, metadata=table.schema.metadata))


def drop_columns_with_prefix(
table: "pa.Table",
prefix: str | tuple[str, ...],
Expand Down Expand Up @@ -896,8 +967,14 @@ def add_system_tag_columns(
source_id_array = pa.array(source_ids, type=pa.large_string())
record_id_array = pa.array(record_ids, type=pa.large_string())

table = table.append_column(source_id_col_name, source_id_array)
table = table.append_column(record_id_col_name, record_id_array)
# System tag columns are always computed, never null — declare nullable=False
# explicitly so the schema intent is not lost in Polars round-trips.
table = table.append_column(
pa.field(source_id_col_name, pa.large_string(), nullable=False), source_id_array
)
table = table.append_column(
pa.field(record_id_col_name, pa.large_string(), nullable=False), record_id_array
)
return table


Expand Down Expand Up @@ -1074,26 +1151,37 @@ def add_source_info(
exclude_columns: Collection[str] = (),
) -> "pa.Table":
"""Add source information to an Arrow table."""
# Create a new column with the source information
# Create a base list of per-row source tokens; one entry per row.
if source_info is None or isinstance(source_info, str):
source_column = [source_info] * table.num_rows
base_source = [source_info] * table.num_rows
elif isinstance(source_info, Collection):
if len(source_info) != table.num_rows:
Comment on lines 1155 to 1158
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

base_source is only assigned for source_info is None/str or when source_info is a Collection; any other runtime type will fall through and later raise an UnboundLocalError when building col_source. Add an explicit else: raise TypeError(...) (or normalize other iterables) so callers get a clear error and the function can’t crash with an unbound local.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in d4bcd72 — added else: raise TypeError(...) after the elif branch so any unsized iterable (e.g. a generator) gets a clear error instead of an UnboundLocalError.

raise ValueError(
"Length of source_info collection must match number of rows in the table."
)
source_column = source_info

# identify columns for which source columns should be created
base_source = list(source_info)
else:
raise TypeError(
f"source_info must be a str, a sized Collection[str], or None; "
f"got {type(source_info).__name__}"
)

# For each data column, build an independent _source_<col> column from the
# base tokens. We must NOT re-use the array produced for a previous column
# as input for the next one — doing so would accumulate column names
# (e.g. "src::col1::col2" instead of "src::col2").
for col in table.column_names:
if col.startswith(tuple(exclude_prefixes)) or col in exclude_columns:
continue
source_column = pa.array(
[f"{source_val}::{col}" for source_val in source_column],
col_source = pa.array(
[f"{source_val}::{col}" for source_val in base_source],
type=pa.large_string(),
)
table = table.append_column(f"{constants.SOURCE_PREFIX}{col}", source_column)
# Source info columns are always computed strings, never null.
table = table.append_column(
pa.field(f"{constants.SOURCE_PREFIX}{col}", pa.large_string(), nullable=False),
col_source,
)

return table

Expand Down
Empty file.
Loading
Loading