Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 108 additions & 27 deletions src/orcapod/core/nodes/function_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@
if TYPE_CHECKING:
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
else:
pa = LazyModule("pyarrow")
pc = LazyModule("pyarrow.compute")
pl = LazyModule("polars")


Expand Down Expand Up @@ -81,6 +83,7 @@ def __init__(
# Optional DB params for persistent mode:
pipeline_database: ArrowDatabaseProtocol | None = None,
result_database: ArrowDatabaseProtocol | None = None,
table_scope: Literal["pipeline_hash", "content_hash"] = "pipeline_hash",
):
if tracker_manager is None:
tracker_manager = DEFAULT_TRACKER_MANAGER
Expand Down Expand Up @@ -142,6 +145,13 @@ def __init__(
self._stored_pipeline_path: tuple[str, ...] = ()
self._stored_result_record_path: tuple[str, ...] = ()
self._descriptor: dict = {}
if table_scope not in ("pipeline_hash", "content_hash"):
raise ValueError(
f"Unknown table_scope {table_scope!r}. "
"Expected one of: 'pipeline_hash', 'content_hash'."
)
self._table_scope = table_scope
self._node_identity_path_cache: tuple[str, ...] | None = None

if pipeline_database is not None:
self.attach_databases(
Expand Down Expand Up @@ -187,6 +197,7 @@ def attach_databases(
self._pipeline_database = pipeline_database

# Clear all caches
self._node_identity_path_cache = None
self.clear_cache()
self._content_hash_cache.clear()
self._pipeline_hash_cache.clear()
Expand Down Expand Up @@ -216,6 +227,27 @@ def _require_pipeline_database(self) -> None:
"or supply one via Pipeline.load(..., pipeline_database=<db>)."
)

def _filter_by_content_hash(self, table: pa.Table) -> pa.Table:
    """Restrict *table* to the rows that belong to this node's run.

    Rows are matched on ``NODE_CONTENT_HASH_COL`` against this node's own
    content hash. The filter only applies when ``table_scope="pipeline_hash"``,
    where several runs write into one shared DB table and must be told apart
    row-by-row at read time; with ``table_scope="content_hash"`` every run
    owns its own table, so the input is returned unchanged.

    Raises:
        ValueError: if the discriminator column is absent from *table* while
            ``table_scope="pipeline_hash"`` (likely records written by an
            older version of the code).
    """
    # Fast path: per-run tables need no row-level disambiguation.
    if self._table_scope != "pipeline_hash":
        return table
    hash_col = constants.NODE_CONTENT_HASH_COL
    if hash_col not in table.column_names:
        raise ValueError(
            f"Cannot isolate records for table_scope='pipeline_hash': "
            f"required column {hash_col!r} is missing from the stored table. "
            "This may indicate records written by an older version of the code."
        )
    own_hash = self.content_hash().to_string()
    return table.filter(pc.equal(table.column(hash_col), own_hash))

# ------------------------------------------------------------------
# from_descriptor — reconstruct from a serialized pipeline descriptor
# ------------------------------------------------------------------
Expand Down Expand Up @@ -253,6 +285,20 @@ def from_descriptor(
pipeline_db = databases.get("pipeline")
result_db = databases.get("result") # pre-scoped; None if not provided

if "table_scope" not in descriptor:
raise ValueError(
f"FunctionNode descriptor is missing required 'table_scope' field: "
f"{descriptor.get('label', '<unlabeled>')}"
)
raw_table_scope = descriptor["table_scope"]
if raw_table_scope not in ("pipeline_hash", "content_hash"):
raise ValueError(
f"FunctionNode descriptor has invalid 'table_scope' value "
f"{raw_table_scope!r} for {descriptor.get('label', '<unlabeled>')}; "
f"expected one of ('pipeline_hash', 'content_hash')"
)
table_scope = cast(Literal["pipeline_hash", "content_hash"], raw_table_scope)

if function_pod is not None and input_stream is not None:
# Full / READ_ONLY / CACHE_ONLY mode: construct normally via __init__.
node = cls(
Expand All @@ -261,6 +307,7 @@ def from_descriptor(
pipeline_database=pipeline_db,
result_database=result_db,
label=descriptor.get("label"),
table_scope=table_scope,
)
node._descriptor = descriptor

Expand Down Expand Up @@ -329,6 +376,8 @@ def from_descriptor(
node._stored_result_record_path = tuple(
descriptor.get("result_record_path", ())
)
node._table_scope = table_scope
node._node_identity_path_cache = None

# Determine load status based on DB availability
node._load_status = LoadStatus.UNAVAILABLE
Expand Down Expand Up @@ -447,20 +496,27 @@ def keys(
def node_identity_path(self) -> tuple[str, ...]:
"""Return the node identity path for observer contextualization.

The identity path is ``pod.uri + (schema_hash, instance_hash)`` and
is computable independently of whether a pipeline database is attached.
When ``table_scope="pipeline_hash"`` (default) the path is
``pod.uri + (schema:{pipeline_hash},)`` — all runs that share the same
pipeline structure are routed to one shared table, with per-run
disambiguation via the ``_node_content_hash`` row-level column.

When ``table_scope="content_hash"`` the legacy path is returned:
``pod.uri + (schema:{pipeline_hash}, instance:{content_hash})``.

In live mode (pod present) the path is computed from the pod.
In read-only/UNAVAILABLE mode (no pod) the path stored from the
deserialized descriptor is returned (empty tuple when absent).
"""
if self._packet_function is None:
return self._stored_pipeline_path
if self._node_identity_path_cache is not None:
return self._node_identity_path_cache
pf = self._function_pod
return pf.uri + (
f"schema:{self.pipeline_hash().to_string()}",
f"instance:{self.content_hash().to_string()}",
)
path = pf.uri + (f"schema:{self.pipeline_hash().to_string()}",)
if self._table_scope != "pipeline_hash":
path += (f"instance:{self.content_hash().to_string()}",)
self._node_identity_path_cache = path
return path

@property
def node_uri(self) -> tuple[str, ...]:
Expand Down Expand Up @@ -490,6 +546,7 @@ def clear_cache(self) -> None:
self._cached_output_packets.clear()
self._cached_output_table = None
self._cached_content_hash_column = None
self._node_identity_path_cache = None
self._update_modified_time()

# ------------------------------------------------------------------
Expand Down Expand Up @@ -672,7 +729,6 @@ def get_cached_results(
self._require_pipeline_database()

PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id"
entry_id_set = set(entry_ids)

taginfo = self._pipeline_database.get_all_records(
self.node_identity_path,
Expand All @@ -686,37 +742,34 @@ def get_cached_results(
if taginfo is None or results is None:
return {}

joined = (
taginfo = self._filter_by_content_hash(taginfo)

filtered = (
pl.DataFrame(taginfo)
.join(
pl.DataFrame(results),
on=constants.PACKET_RECORD_ID,
how="inner",
)
.filter(pl.col(PIPELINE_ENTRY_ID_COL).is_in(entry_ids))
.to_arrow()
)

if joined.num_rows == 0:
return {}

# Filter to requested entry IDs
all_entry_ids = joined.column(PIPELINE_ENTRY_ID_COL).to_pylist()
mask = [eid in entry_id_set for eid in all_entry_ids]
filtered = joined.filter(pa.array(mask))

if filtered.num_rows == 0:
return {}

tag_keys = self._input_stream.keys()[0]
drop_cols = [
c
for c in filtered.column_names
if c.startswith(constants.META_PREFIX) or c == PIPELINE_ENTRY_ID_COL
if c.startswith(constants.META_PREFIX)
or c == PIPELINE_ENTRY_ID_COL
or c == constants.NODE_CONTENT_HASH_COL
]
data_table = filtered.drop([c for c in drop_cols if c in filtered.column_names])

stream = ArrowTableStream(data_table, tag_columns=tag_keys)
filtered_entry_ids = [eid for eid, m in zip(all_entry_ids, mask) if m]
filtered_entry_ids = filtered.column(PIPELINE_ENTRY_ID_COL).to_pylist()

result_dict: dict[str, tuple[TagProtocol, PacketProtocol]] = {}
for entry_id, (tag, packet) in zip(filtered_entry_ids, stream.iter_packets()):
Expand Down Expand Up @@ -801,19 +854,29 @@ def compute_pipeline_entry_id(
) -> str:
"""Compute a unique pipeline entry ID from tag + system tags + input packet hash.

This ID uniquely identifies a (tag, system_tags, input_packet) combination
and is used as the record ID in the pipeline database.
``NODE_CONTENT_HASH_COL`` is always included so that two runs processing
identical inputs each get a distinct entry ID, regardless of table scope.
This prevents the second run's pipeline record from being silently skipped
by the duplicate entry_id check.

Args:
tag: The tag (including system tags).
input_packet: The input packet.

Returns:
A hash string uniquely identifying this combination.
A hash string uniquely identifying this (tag, input_packet, node run)
combination.
"""
tag_with_hash = tag.as_table(columns={"system_tags": True}).append_column(
constants.INPUT_PACKET_HASH_COL,
pa.array([input_packet.content_hash().to_string()], type=pa.large_string()),
tag_with_hash = (
tag.as_table(columns={"system_tags": True})
.append_column(
constants.INPUT_PACKET_HASH_COL,
pa.array([input_packet.content_hash().to_string()], type=pa.large_string()),
)
.append_column(
constants.NODE_CONTENT_HASH_COL,
pa.array([self.content_hash().to_string()], type=pa.large_string()),
)
)
return self.data_context.arrow_hasher.hash_table(tag_with_hash).to_string()

Expand Down Expand Up @@ -866,6 +929,9 @@ def add_pipeline_record(
constants.PACKET_RECORD_ID: pa.array(
[packet_record_id], type=pa.large_string()
),
constants.NODE_CONTENT_HASH_COL: pa.array(
[self.content_hash().to_string()], type=pa.large_string()
),
f"{constants.META_PREFIX}input_packet{constants.CONTEXT_KEY}": pa.array(
[input_packet.data_context_key], type=pa.large_string()
),
Comment on lines 929 to 937
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NODE_CONTENT_HASH_COL is written per row, but the pipeline DB record ID (entry_id) is still computed from just (tag + system tags + input packet hash). In table_scope="pipeline_hash" mode the table is shared across runs, so overlapping inputs across runs will collide on the same entry_id; the second run will skip writing its row, and _filter_by_content_hash() will then hide the first run’s row—losing provenance for the second run. Consider making the pipeline record ID run-scoped when self._table_scope == "pipeline_hash" (e.g., include self.content_hash() in the entry ID or use a composite key) so identical packets can be recorded for multiple runs in the shared table.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in FunctionNode.compute_pipeline_entry_id. When table_scope='pipeline_hash', the node's content_hash() is appended as an extra column to the table being hashed, making the resulting entry ID run-scoped. This means two runs processing the same (tag, input_packet) produce distinct pipeline records in the shared table instead of the second run's record being silently skipped by the duplicate check.

Expand Down Expand Up @@ -920,6 +986,8 @@ def get_all_records(
if results is None or taginfo is None:
return None

taginfo = self._filter_by_content_hash(taginfo)

joined = (
pl.DataFrame(taginfo)
.join(pl.DataFrame(results), on=constants.PACKET_RECORD_ID, how="inner")
Expand All @@ -929,6 +997,9 @@ def get_all_records(
column_config = ColumnConfig.handle_config(columns, all_info=all_info)

drop_columns = []
# Always drop the node content hash column — it is an internal
# row-level discriminator, not a user-facing column.
drop_columns.append(constants.NODE_CONTENT_HASH_COL)
if not column_config.meta and not column_config.all_info:
drop_columns.extend(
c for c in joined.column_names if c.startswith(constants.META_PREFIX)
Expand Down Expand Up @@ -1000,6 +1071,8 @@ def _load_all_cached_records(
if taginfo is None or results is None:
return None

taginfo = self._filter_by_content_hash(taginfo)

joined = (
pl.DataFrame(taginfo)
.join(
Expand All @@ -1015,20 +1088,24 @@ def _load_all_cached_records(

# Tag keys are the user-facing tag columns from the pipeline DB table.
# Exclude: meta columns (__*), source columns (_source_*),
# system-tag columns (e.g. __tag_*), and the entry-ID column.
# system-tag columns (e.g. __tag_*), the entry-ID column, and the
# node content hash column.
tag_keys = tuple(
c
for c in taginfo.column_names
if not c.startswith(constants.META_PREFIX)
and not c.startswith(constants.SOURCE_PREFIX)
and not c.startswith(constants.SYSTEM_TAG_PREFIX)
and c != PIPELINE_ENTRY_ID_COL
and c != constants.NODE_CONTENT_HASH_COL
)

drop_cols = [
c
for c in joined.column_names
if c.startswith(constants.META_PREFIX) or c == PIPELINE_ENTRY_ID_COL
if c.startswith(constants.META_PREFIX)
or c == PIPELINE_ENTRY_ID_COL
or c == constants.NODE_CONTENT_HASH_COL
]
data_table = joined.drop([c for c in drop_cols if c in joined.column_names])
return tag_keys, data_table
Expand Down Expand Up @@ -1124,6 +1201,7 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]:
)

if taginfo is not None and results is not None:
taginfo = self._filter_by_content_hash(taginfo)
joined = (
pl.DataFrame(taginfo)
.join(
Expand All @@ -1148,6 +1226,7 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]:
for c in joined.column_names
if c.startswith(constants.META_PREFIX)
or c == PIPELINE_ENTRY_ID_COL
or c == constants.NODE_CONTENT_HASH_COL
]
data_table = joined.drop(
[c for c in drop_cols if c in joined.column_names]
Expand Down Expand Up @@ -1444,6 +1523,7 @@ async def async_execute(
)

if taginfo is not None and results is not None:
taginfo = self._filter_by_content_hash(taginfo)
joined = (
pl.DataFrame(taginfo)
.join(
Expand All @@ -1463,6 +1543,7 @@ async def async_execute(
for c in joined.column_names
if c.startswith(constants.META_PREFIX)
or c == PIPELINE_ENTRY_ID_COL
or c == constants.NODE_CONTENT_HASH_COL
]
data_table = joined.drop(
[c for c in drop_cols if c in joined.column_names]
Expand Down
Loading
Loading