-
Notifications
You must be signed in to change notification settings - Fork 5
fix: restore Arrow nullable flags lost in Polars round-trips (ENG-375) #130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2bcfb99
e4248c0
44dc263
3218c4c
b378465
0e73020
d4bcd72
464ac7c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -831,6 +831,77 @@ def infer_schema_nullable(table: "pa.Table") -> "pa.Schema": | |
| ) | ||
|
|
||
|
|
||
| def restore_schema_nullability( | ||
| table: "pa.Table", | ||
| *reference_schemas: "pa.Schema", | ||
| ) -> "pa.Table": | ||
| """Restore nullable flags lost during an Arrow → Polars → Arrow round-trip. | ||
|
|
||
| Polars converts all Arrow fields to ``nullable=True`` when producing its | ||
| Arrow output. This function repairs the resulting table by looking up each | ||
| field name in the supplied reference schemas and reinstating the original | ||
| ``nullable`` flag (and type) from those references. | ||
|
|
||
| Fields that are **not** found in any reference schema are left exactly as | ||
| Polars produced them (``nullable=True``), which is safe for internally | ||
| generated sentinel columns (e.g. ``_exists``, ``__pipeline_entry_id``). | ||
|
|
||
| When the same field name appears in multiple reference schemas the | ||
| **last-supplied** schema wins, so callers can pass ``(left_schema, | ||
| right_schema)`` and get right-side nullability for join-key columns that | ||
| appear in both. | ||
|
|
||
| Args: | ||
| table: Arrow table produced by a Polars round-trip (all fields | ||
| ``nullable=True``). | ||
| *reference_schemas: One or more Arrow schemas carrying the original | ||
| nullable flags. Pass the schemas of every table that participated | ||
| in the Polars join/sort so that all columns are covered. | ||
|
|
||
| Returns: | ||
| A new Arrow table cast to the restored schema. Data is unchanged; | ||
| only the schema metadata (nullability, field type, field metadata) | ||
| is corrected. | ||
|
|
||
| Raises: | ||
| pyarrow.ArrowInvalid: If a restored field type is incompatible with | ||
| the actual column data (should not happen for well-formed | ||
| round-trips, but surfaced to the caller rather than silently | ||
| discarded). | ||
|
|
||
| Example:: | ||
|
|
||
| taginfo_schema = taginfo.schema | ||
| results_schema = results.schema | ||
| joined = ( | ||
| pl.DataFrame(taginfo) | ||
| .join(pl.DataFrame(results), on="record_id", how="inner") | ||
| .to_arrow() | ||
| ) | ||
| joined = arrow_utils.restore_schema_nullability( | ||
| joined, taginfo_schema, results_schema | ||
| ) | ||
| """ | ||
| # Build a name → Field lookup; later schemas override earlier ones. | ||
| field_lookup: dict[str, "pa.Field"] = {} | ||
| for schema in reference_schemas: | ||
| for field in schema: | ||
| field_lookup[field.name] = field | ||
|
|
||
| restored_fields = [] | ||
| for field in table.schema: | ||
| ref = field_lookup.get(field.name) | ||
| if ref is not None: | ||
| restored_fields.append( | ||
| pa.field(field.name, ref.type, nullable=ref.nullable, metadata=ref.metadata) | ||
| ) | ||
| else: | ||
| restored_fields.append(field) | ||
|
|
||
| restored_schema = pa.schema(restored_fields, metadata=table.schema.metadata) | ||
| return table.cast(restored_schema) | ||
|
|
||
|
|
||
| def drop_columns_with_prefix( | ||
| table: "pa.Table", | ||
| prefix: str | tuple[str, ...], | ||
|
|
@@ -896,8 +967,14 @@ def add_system_tag_columns( | |
| source_id_array = pa.array(source_ids, type=pa.large_string()) | ||
| record_id_array = pa.array(record_ids, type=pa.large_string()) | ||
|
|
||
| table = table.append_column(source_id_col_name, source_id_array) | ||
| table = table.append_column(record_id_col_name, record_id_array) | ||
| # System tag columns are always computed, never null — declare nullable=False | ||
| # explicitly so the schema intent is not lost in Polars round-trips. | ||
| table = table.append_column( | ||
| pa.field(source_id_col_name, pa.large_string(), nullable=False), source_id_array | ||
| ) | ||
| table = table.append_column( | ||
| pa.field(record_id_col_name, pa.large_string(), nullable=False), record_id_array | ||
| ) | ||
| return table | ||
|
|
||
|
|
||
|
|
@@ -1074,26 +1151,37 @@ def add_source_info( | |
| exclude_columns: Collection[str] = (), | ||
| ) -> "pa.Table": | ||
| """Add source information to an Arrow table.""" | ||
| # Create a new column with the source information | ||
| # Create a base list of per-row source tokens; one entry per row. | ||
| if source_info is None or isinstance(source_info, str): | ||
| source_column = [source_info] * table.num_rows | ||
| base_source = [source_info] * table.num_rows | ||
| elif isinstance(source_info, Collection): | ||
| if len(source_info) != table.num_rows: | ||
|
Comment on lines 1155 to 1158
|
||
| raise ValueError( | ||
| "Length of source_info collection must match number of rows in the table." | ||
| ) | ||
| source_column = source_info | ||
|
|
||
| # identify columns for which source columns should be created | ||
| base_source = list(source_info) | ||
| else: | ||
| raise TypeError( | ||
| f"source_info must be a str, a sized Collection[str], or None; " | ||
| f"got {type(source_info).__name__}" | ||
| ) | ||
|
|
||
| # For each data column, build an independent _source_<col> column from the | ||
| # base tokens. We must NOT re-use the array produced for a previous column | ||
| # as input for the next one — doing so would accumulate column names | ||
| # (e.g. "src::col1::col2" instead of "src::col2"). | ||
| for col in table.column_names: | ||
| if col.startswith(tuple(exclude_prefixes)) or col in exclude_columns: | ||
| continue | ||
| source_column = pa.array( | ||
| [f"{source_val}::{col}" for source_val in source_column], | ||
| col_source = pa.array( | ||
| [f"{source_val}::{col}" for source_val in base_source], | ||
| type=pa.large_string(), | ||
| ) | ||
| table = table.append_column(f"{constants.SOURCE_PREFIX}{col}", source_column) | ||
| # Source info columns are always computed strings, never null. | ||
| table = table.append_column( | ||
| pa.field(f"{constants.SOURCE_PREFIX}{col}", pa.large_string(), nullable=False), | ||
| col_source, | ||
| ) | ||
|
|
||
| return table | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR restores nullability after Polars `.sort(...).to_arrow()` in `FunctionNode.as_table`, but there is another Polars sort round-trip in `FunctionPodStream.as_table` (`src/orcapod/core/function_pod.py`, around the `column_config.sort_by_tags` block) that still drops nullable flags. Consider applying `restore_schema_nullability` there too for consistent behavior across stream/table materialization paths.
.sort(...).to_arrow()inFunctionNode.as_table, but there is another Polars sort round-trip inFunctionPodStream.as_table(src/orcapod/core/function_pod.py, around thecolumn_config.sort_by_tagsblock) that still drops nullable flags. Consider applyingrestore_schema_nullabilitythere too for consistent behavior across stream/table materialization paths.