From 8c4654a33da08dd78583383979d57af0371cd173 Mon Sep 17 00:00:00 2001 From: beardedeagle Date: Fri, 3 Apr 2026 04:49:26 -0500 Subject: [PATCH 1/9] fix: close chunked-prefill scope gaps Add a typed chunked-prefill scope surface so the remaining feature boundaries are encoded in runtime and benchmark code instead of being left as doc-only caveats. The new surface records runtime eligibility, whether bounded chunked prefill actually ran, the post-tokenization/full-prefix-mask execution boundary, and explicit reject decisions for the two previously open scope gaps. Thread the new scope facts through runtime execution tracing, benchmark request metrics, benchmark JSON parsing, and request metric summaries. Add focused regression coverage for the supported optimized-native causal path, the unsupported seq2seq path, benchmark trace/report serialization, and the new scope decisions. Update README and benchmarking/optimization docs so user-facing guidance matches the code-backed contract: chunked prefill begins after prompt construction, remains limited to optimized-native causal text runtimes, and publishes its non-goals through benchmark output. --- .beads/interactions.jsonl | 1 + README.md | 4 +- docs/benchmarking.md | 1 + docs/guides/optimization.md | 4 + .../chunked_prefill_serialization.py | 63 ++++++ src/ollm/runtime/benchmark/details.py | 2 +- src/ollm/runtime/benchmark/probe_execution.py | 1 + .../runtime/benchmark/probe_serialization.py | 15 ++ src/ollm/runtime/benchmark/probe_types.py | 3 + src/ollm/runtime/chunked_prefill.py | 181 ++++++++++++++++++ src/ollm/runtime/execution_trace.py | 8 +- src/ollm/runtime/generation.py | 71 +++++-- tests/benchmark_support.py | 17 ++ tests/test_benchmark_probe_execution.py | 10 +- tests/test_benchmark_reporting.py | 17 +- tests/test_chunked_prefill_scope.py | 104 ++++++++++ 16 files changed, 478 insertions(+), 24 deletions(-) create mode 100644 src/ollm/runtime/benchmark/chunked_prefill_serialization.py create mode 100644 src/ollm/runtime/chunked_prefill.py create mode 100644 tests/test_chunked_prefill_scope.py diff --git a/.beads/interactions.jsonl b/.beads/interactions.jsonl index 45b4e44..a2502b4 100644 --- a/.beads/interactions.jsonl +++ b/.beads/interactions.jsonl @@ -8,3 +8,4 @@ {"id":"int-e2c83df6","kind":"field_change","created_at":"2026-04-03T04:07:29.73357Z","actor":"beardedeagle","issue_id":"ollm-cly","extra":{"field":"status","new_value":"closed","old_value":"in_progress","reason":"Completed"}} {"id":"int-7c080521","kind":"field_change","created_at":"2026-04-03T04:58:45.426115Z","actor":"beardedeagle","issue_id":"ollm-7zk","extra":{"field":"status","new_value":"closed","old_value":"open","reason":"Completed"}} {"id":"int-bf6e89f2","kind":"field_change","created_at":"2026-04-03T06:33:55.374489Z","actor":"beardedeagle","issue_id":"ollm-nnt","extra":{"field":"status","new_value":"closed","old_value":"open","reason":"Completed"}} +{"id":"int-a6481fd7","kind":"field_change","created_at":"2026-04-03T09:48:50.10993Z","actor":"beardedeagle","issue_id":"ollm-6b7","extra":{"field":"status","new_value":"closed","old_value":"in_progress","reason":"Completed"}} diff --git a/README.md b/README.md index 2290b43..f6114e1 100644 --- a/README.md +++ b/README.md @@ -165,7 +165,9 @@ through bounded prefill chunks before the final decode step. That keeps prompt execution from growing one full prompt-wide activation step at a time on very long inputs while preserving the external prompt/chat contract. Prompt-scaling benchmarks remain the right place to evaluate the TTFT and memory tradeoff on -target hardware. +target hardware. This feature starts after prompt tokenization and full-prefix +attention-mask construction, and it intentionally does not extend that same +contract to seq2seq, multimodal, or generic Transformers runtimes. Configuration layering uses an explicit precedence contract: diff --git a/docs/benchmarking.md b/docs/benchmarking.md index 353f262..c24cf52 100644 --- a/docs/benchmarking.md +++ b/docs/benchmarking.md @@ -180,6 +180,7 @@ Interpretation notes: - peak RSS includes a source label; long-lived warm/scaling/session probes use stage-local sampled peaks instead of process-lifetime peaks - allocator-gap metrics are reported as reserved-minus-allocated style slack when the backend exposes the required counters; unsupported backends serialize them as `null` - optimized-native decoder-only prompt-scaling runs exercise bounded chunked prefill on long text prompts, so the prompt-length sweep is the intended place to inspect the memory versus TTFT tradeoff for this feature +- request metrics also include a `chunked_prefill` section that states whether the active runtime was eligible, whether chunking actually ran, and the explicit rejected non-goals for prompt-construction streaming and non-causal/generic runtime expansion On loader-streamed families such as optimized Gemma3 on CPU, a long per-turn session-growth response can become dominated by repeated safetensor layer reads diff --git a/docs/guides/optimization.md b/docs/guides/optimization.md index 40011b4..235a144 100644 --- a/docs/guides/optimization.md +++ b/docs/guides/optimization.md @@ -19,6 +19,10 @@ Optimized-native decoder-only text prompts use bounded chunked prefill for long prompt ingestion before the final decode step. This is a memory-control path, not a blanket latency optimization, so prompt-scaling benchmarks are the truthful way to evaluate whether the chunking tradeoff helps on a given host. +The contract starts after prompt tokenization and full-prefix attention-mask +construction. Seq2seq, multimodal, and generic Transformers runtimes are +intentionally outside this feature boundary and would require separate designs +and benchmark semantics if pursued. ### Transformers-generic Used for compatible local or materialized models that can run through the generic Transformers-backed path. diff --git a/src/ollm/runtime/benchmark/chunked_prefill_serialization.py b/src/ollm/runtime/benchmark/chunked_prefill_serialization.py new file mode 100644 index 0000000..72452f7 --- /dev/null +++ b/src/ollm/runtime/benchmark/chunked_prefill_serialization.py @@ -0,0 +1,63 @@ +"""Chunked-prefill benchmark JSON parsing helpers.""" + +from collections.abc import Mapping +from typing import cast + +from ollm.runtime.chunked_prefill import ( + ChunkedPrefillAttentionMaskMode, + ChunkedPrefillExecutionBoundary, + ChunkedPrefillGapDecision, + ChunkedPrefillGapId, + ChunkedPrefillRecommendation, + ChunkedPrefillScopeSurface, +) + + +def parse_chunked_prefill( + value: object, + *, + require_bool, + require_object_mapping, + require_sequence, + require_string, +) -> ChunkedPrefillScopeSurface: + if not isinstance(value, Mapping): + raise ValueError("chunked_prefill must be an object") + payload = cast(Mapping[str, object], value) + gap_items = require_sequence(payload, "gap_inventory") + return ChunkedPrefillScopeSurface( + runtime_eligible=require_bool(payload, "runtime_eligible"), + applied=require_bool(payload, "applied"), + activation_reason=require_string(payload, "activation_reason"), + supported_backend_id=require_string(payload, "supported_backend_id"), + supported_model_kind=require_string(payload, "supported_model_kind"), + supported_prompt_kind=require_string(payload, "supported_prompt_kind"), + execution_boundary=ChunkedPrefillExecutionBoundary( + require_string(payload, "execution_boundary") + ), + attention_mask_mode=ChunkedPrefillAttentionMaskMode( + require_string(payload, "attention_mask_mode") + ), + gap_inventory=tuple( + parse_chunked_prefill_gap( + require_object_mapping(item, f"gap_inventory[{index}]"), + require_string=require_string, + ) + for index, item in enumerate(gap_items) + ), + ) + + +def parse_chunked_prefill_gap( + payload: Mapping[str, object], + *, + require_string, +) -> ChunkedPrefillGapDecision: + return ChunkedPrefillGapDecision( + gap_id=ChunkedPrefillGapId(require_string(payload, "gap_id")), + current_behavior=require_string(payload, "current_behavior"), + recommendation=ChunkedPrefillRecommendation( + require_string(payload, "recommendation") + ), + rationale=require_string(payload, "rationale"), + ) diff --git a/src/ollm/runtime/benchmark/details.py b/src/ollm/runtime/benchmark/details.py index a3f7990..ec9fcb9 100644 --- a/src/ollm/runtime/benchmark/details.py +++ b/src/ollm/runtime/benchmark/details.py @@ -36,7 +36,6 @@ def build_cold_probe_details( def summarize_request_metrics(samples: list[RequestProbeMetrics]) -> dict[str, object]: """Summarize request-level runtime probe metrics.""" - from ollm.runtime.benchmark.offload_summary import summarize_request_offload return { @@ -82,6 +81,7 @@ def summarize_request_metrics(samples: list[RequestProbeMetrics]) -> dict[str, o ] ), }, + "chunked_prefill": samples[-1].chunked_prefill.to_dict(), "memory": summarize_stage_resources([sample.resources for sample in samples]), "cache": { "cache_mode": single_optional_string( diff --git a/src/ollm/runtime/benchmark/probe_execution.py b/src/ollm/runtime/benchmark/probe_execution.py index 2267fad..2d72bd2 100644 --- a/src/ollm/runtime/benchmark/probe_execution.py +++ b/src/ollm/runtime/benchmark/probe_execution.py @@ -217,6 +217,7 @@ def execute_request_probe( kv_cache_adaptation=kv_cache_adaptation, cache_dir_size_mb=cache_dir_size, cache_state=cache_state, + chunked_prefill=trace.chunked_prefill, allocator_gap_mb=allocator_gap_mb, allocator_gap_ratio=allocator_gap_ratio, native_runtime_profile=native_runtime_profile, diff --git a/src/ollm/runtime/benchmark/probe_serialization.py b/src/ollm/runtime/benchmark/probe_serialization.py index 3765df1..4e501d4 100644 --- a/src/ollm/runtime/benchmark/probe_serialization.py +++ b/src/ollm/runtime/benchmark/probe_serialization.py @@ -6,6 +6,7 @@ from ollm.kv_cache.matrix import KVCacheAdaptationSurface from ollm.kv_cache.state import KVCacheStateSnapshot +from ollm.runtime.benchmark.chunked_prefill_serialization import parse_chunked_prefill from ollm.runtime.benchmark.probe_types import ( EventTimingSummary, NativeRuntimeProfile, @@ -244,6 +245,13 @@ def _parse_request_probe_metrics(payload: Mapping[str, object]) -> RequestProbeM ), cache_dir_size_mb=_optional_float(payload, "cache_dir_size_mb"), cache_state=_parse_cache_state(payload.get("cache_state")), + chunked_prefill=parse_chunked_prefill( + payload.get("chunked_prefill"), + require_bool=_require_bool, + require_object_mapping=_require_object_mapping, + require_sequence=_require_sequence, + require_string=_require_string, + ), allocator_gap_mb=_optional_float(payload, "allocator_gap_mb"), allocator_gap_ratio=_optional_float(payload, "allocator_gap_ratio"), native_runtime_profile=_parse_native_runtime_profile( @@ -447,6 +455,13 @@ def _optional_int(payload: Mapping[str, object], key: str) -> int | None: raise ValueError(f"probe field '{key}' must be an integer or null") +def _require_bool(payload: Mapping[str, object], key: str) -> bool: + value = payload.get(key) + if isinstance(value, bool): + return value + raise ValueError(f"probe field '{key}' must be a boolean") + + def _require_string(payload: Mapping[str, object], key: str) -> str: value = payload.get(key) if isinstance(value, str): diff --git a/src/ollm/runtime/benchmark/probe_types.py b/src/ollm/runtime/benchmark/probe_types.py index 6562693..90fd5a1 100644 --- a/src/ollm/runtime/benchmark/probe_types.py +++ b/src/ollm/runtime/benchmark/probe_types.py @@ -5,6 +5,7 @@ from ollm.kv_cache.matrix import KVCacheAdaptationSurface from ollm.kv_cache.state import KVCacheStateSnapshot from ollm.runtime.benchmark.resources import StageResourceSnapshot +from ollm.runtime.chunked_prefill import ChunkedPrefillScopeSurface @dataclass(frozen=True, slots=True) @@ -50,6 +51,7 @@ class RequestProbeMetrics: kv_cache_adaptation: KVCacheAdaptationSurface | None cache_dir_size_mb: float | None cache_state: KVCacheStateSnapshot | None + chunked_prefill: ChunkedPrefillScopeSurface allocator_gap_mb: float | None allocator_gap_ratio: float | None native_runtime_profile: NativeRuntimeProfile | None @@ -71,6 +73,7 @@ def to_dict(self) -> dict[str, object]: payload["cache_state"] = ( None if self.cache_state is None else self.cache_state.to_dict() ) + payload["chunked_prefill"] = self.chunked_prefill.to_dict() payload["kv_cache_adaptation"] = ( None if self.kv_cache_adaptation is None diff --git a/src/ollm/runtime/chunked_prefill.py b/src/ollm/runtime/chunked_prefill.py new file mode 100644 index 0000000..609b23a --- /dev/null +++ b/src/ollm/runtime/chunked_prefill.py @@ -0,0 +1,181 @@ +"""Typed scope facts for bounded chunked prefill.""" + +from dataclasses import asdict, dataclass, replace +from enum import StrEnum +from typing import Self + +import torch + +from ollm.runtime.capability_discovery import GenericModelKind +from ollm.runtime.loaded_runtime import LoadedRuntime + + +class ChunkedPrefillGapId(StrEnum): + PROMPT_CONSTRUCTION_BEFORE_PREFILL = "prompt-construction-before-prefill" + NON_CAUSAL_RUNTIME_EXPANSION = "non-causal-runtime-expansion" + + +class ChunkedPrefillRecommendation(StrEnum): + IMPLEMENT = "implement" + DEFER = "defer" + REJECT = "reject" + + +class ChunkedPrefillExecutionBoundary(StrEnum): + POST_TOKENIZATION = "post-tokenization" + + +class ChunkedPrefillAttentionMaskMode(StrEnum): + FULL_PREFIX_MATERIALIZED = "full-prefix-materialized" + + +@dataclass(frozen=True, slots=True) +class ChunkedPrefillGapDecision: + gap_id: ChunkedPrefillGapId + current_behavior: str + recommendation: ChunkedPrefillRecommendation + rationale: str + + def to_dict(self) -> dict[str, str]: + payload = asdict(self) + return {key: str(value) for key, value in payload.items()} + + +@dataclass(frozen=True, slots=True) +class ChunkedPrefillScopeSurface: + runtime_eligible: bool + applied: bool + activation_reason: str + supported_backend_id: str + supported_model_kind: str + supported_prompt_kind: str + execution_boundary: ChunkedPrefillExecutionBoundary + attention_mask_mode: ChunkedPrefillAttentionMaskMode + gap_inventory: tuple[ChunkedPrefillGapDecision, ...] + + def with_activation(self, *, applied: bool, activation_reason: str) -> Self: + return replace( + self, + applied=applied, + activation_reason=activation_reason, + ) + + def to_dict(self) -> dict[str, object]: + return { + "runtime_eligible": self.runtime_eligible, + "applied": self.applied, + "activation_reason": self.activation_reason, + "supported_backend_id": self.supported_backend_id, + "supported_model_kind": self.supported_model_kind, + "supported_prompt_kind": self.supported_prompt_kind, + "execution_boundary": self.execution_boundary.value, + "attention_mask_mode": self.attention_mask_mode.value, + "gap_inventory": [decision.to_dict() for decision in self.gap_inventory], + } + + +_CHUNKED_PREFILL_GAP_INVENTORY = ( + ChunkedPrefillGapDecision( + gap_id=ChunkedPrefillGapId.PROMPT_CONSTRUCTION_BEFORE_PREFILL, + current_behavior=( + "Prompt tokenization and full-prefix attention-mask materialization " + "complete before chunked prefill begins." + ), + recommendation=ChunkedPrefillRecommendation.REJECT, + rationale=( + "Bounded chunked prefill is an execution-stage memory control after " + "prompt construction. Streaming prompt construction would be a " + "different feature with different tokenizer and processor " + "contracts." + ), + ), + ChunkedPrefillGapDecision( + gap_id=ChunkedPrefillGapId.NON_CAUSAL_RUNTIME_EXPANSION, + current_behavior=( + "Chunked prefill is limited to optimized-native causal text " + "runtimes and does not cover seq2seq, multimodal, or " + "transformers-generic paths." + ), + recommendation=ChunkedPrefillRecommendation.REJECT, + rationale=( + "Those runtimes have different encoder, cache, and prompt-shape " + "contracts. If pursued, they should ship as separate features " + "with their own execution and benchmark semantics." + ), + ), +) + + +def build_chunked_prefill_scope_surface( + *, + runtime: LoadedRuntime, + inputs: dict[str, object], + chunk_tokens: int, +) -> ChunkedPrefillScopeSurface: + input_ids = inputs.get("input_ids") + if not isinstance(input_ids, torch.Tensor): + return _surface( + runtime_eligible=False, + activation_reason="Chunked prefill requires tensor-backed input_ids.", + ) + if input_ids.ndim != 2 or input_ids.shape[0] != 1: + return _surface( + runtime_eligible=False, + activation_reason="Chunked prefill requires a single batch row.", + ) + if runtime.processor is not None: + return _surface( + runtime_eligible=False, + activation_reason=( + "Chunked prefill is limited to text prompts without a processor." + ), + ) + if runtime.plan.backend_id != "optimized-native": + return _surface( + runtime_eligible=False, + activation_reason=( + "Chunked prefill is limited to the optimized-native backend." + ), + ) + runtime_kind = ( + runtime.plan.generic_model_kind or runtime.resolved_model.generic_model_kind + ) + if runtime_kind is not GenericModelKind.CAUSAL_LM: + return _surface( + runtime_eligible=False, + activation_reason=( + "Chunked prefill is limited to causal decoder-only text runtimes." + ), + ) + prefill_token_count = int(input_ids.shape[1] - 1) + if prefill_token_count <= chunk_tokens: + return _surface( + runtime_eligible=True, + activation_reason=( + "Prompt length does not exceed the chunked-prefill threshold." + ), + ) + return _surface( + runtime_eligible=True, + activation_reason="Runtime is eligible for bounded chunked prefill.", + ) + + +def chunked_prefill_gap_inventory() -> tuple[ChunkedPrefillGapDecision, ...]: + return _CHUNKED_PREFILL_GAP_INVENTORY + + +def _surface( + *, runtime_eligible: bool, activation_reason: str +) -> ChunkedPrefillScopeSurface: + return ChunkedPrefillScopeSurface( + runtime_eligible=runtime_eligible, + applied=False, + activation_reason=activation_reason, + supported_backend_id="optimized-native", + supported_model_kind=GenericModelKind.CAUSAL_LM.value, + supported_prompt_kind="text-only", + execution_boundary=ChunkedPrefillExecutionBoundary.POST_TOKENIZATION, + attention_mask_mode=ChunkedPrefillAttentionMaskMode.FULL_PREFIX_MATERIALIZED, + gap_inventory=_CHUNKED_PREFILL_GAP_INVENTORY, + ) diff --git a/src/ollm/runtime/execution_trace.py b/src/ollm/runtime/execution_trace.py index e6cbf76..ef52c64 100644 --- a/src/ollm/runtime/execution_trace.py +++ b/src/ollm/runtime/execution_trace.py @@ -8,6 +8,7 @@ from ollm.app.types import PromptRequest from ollm.kv_cache.state import KVCacheStateSnapshot from ollm.runtime.capability_discovery import GenericModelKind +from ollm.runtime.chunked_prefill import ChunkedPrefillScopeSurface from ollm.runtime.errors import PromptExecutionError from ollm.runtime.generation import ( build_runtime_generate_kwargs, @@ -35,6 +36,7 @@ class RuntimeExecutionTrace: output_token_count: int response_text: str cache_state: KVCacheStateSnapshot | None + chunked_prefill: ChunkedPrefillScopeSurface def execute_request_with_trace( @@ -56,12 +58,13 @@ def execute_request_with_trace( ) prepared_inputs = normalize_generate_inputs(inputs) generation_started_at = time.perf_counter() - prepared_inputs, prepared_generate_kwargs = prepare_runtime_generate_inputs( + prepared_result = prepare_runtime_generate_inputs( runtime, - request, prepared_inputs, generate_kwargs, ) + prepared_inputs = prepared_result.inputs + prepared_generate_kwargs = prepared_result.generate_kwargs outputs, effective_generate_kwargs = _generate_outputs( runtime=runtime, prepared_inputs=prepared_inputs, @@ -90,6 +93,7 @@ def execute_request_with_trace( output_token_count=output_token_count, response_text=response_text, cache_state=cache_state, + chunked_prefill=prepared_result.chunked_prefill, ) diff --git a/src/ollm/runtime/generation.py b/src/ollm/runtime/generation.py index 9f487cb..4705fd0 100644 --- a/src/ollm/runtime/generation.py +++ b/src/ollm/runtime/generation.py @@ -13,6 +13,10 @@ from ollm.kv_cache.state import KVCacheStateSnapshot from ollm.runtime.capability_discovery import GenericModelKind from ollm.runtime.catalog import ModelModality +from ollm.runtime.chunked_prefill import ( + ChunkedPrefillScopeSurface, + build_chunked_prefill_scope_surface, +) from ollm.runtime.errors import PromptExecutionError from ollm.runtime.generation_config_support import ( clear_sampling_fields, @@ -39,6 +43,13 @@ def print_and_clean(self) -> str: ... DEFAULT_PREFILL_CHUNK_TOKENS = 512 +@dataclass(frozen=True, slots=True) +class PreparedRuntimeGenerateInputs: + inputs: dict[str, object] + generate_kwargs: dict[str, object] + chunked_prefill: ChunkedPrefillScopeSurface + + def validate_runtime_request(runtime: LoadedRuntime, request: PromptRequest) -> None: if not request.messages: raise PromptExecutionError("At least one message is required") @@ -181,28 +192,53 @@ def build_runtime_generate_kwargs( def prepare_runtime_generate_inputs( runtime: LoadedRuntime, - request: PromptRequest, inputs: dict[str, object], generate_kwargs: dict[str, object], -) -> tuple[dict[str, object], dict[str, object]]: +) -> PreparedRuntimeGenerateInputs: + chunked_prefill = build_chunked_prefill_scope_surface( + runtime=runtime, + inputs=inputs, + chunk_tokens=DEFAULT_PREFILL_CHUNK_TOKENS, + ) input_ids_value = inputs.get("input_ids") if not isinstance(input_ids_value, torch.Tensor): - return inputs, generate_kwargs + return PreparedRuntimeGenerateInputs( + inputs=inputs, + generate_kwargs=generate_kwargs, + chunked_prefill=chunked_prefill, + ) if input_ids_value.ndim != 2 or input_ids_value.shape[0] != 1: - return inputs, generate_kwargs - if runtime.processor is not None: - return inputs, generate_kwargs - if runtime.plan.backend_id != "optimized-native": - return inputs, generate_kwargs - runtime_kind = ( - runtime.plan.generic_model_kind or runtime.resolved_model.generic_model_kind - ) - if runtime_kind is not GenericModelKind.CAUSAL_LM: - return inputs, generate_kwargs + return PreparedRuntimeGenerateInputs( + inputs=inputs, + generate_kwargs=generate_kwargs, + chunked_prefill=chunked_prefill, + ) prefill_token_count = input_ids_value.shape[1] - 1 if prefill_token_count <= DEFAULT_PREFILL_CHUNK_TOKENS: - return inputs, generate_kwargs - return _run_chunked_prefill(runtime, inputs, generate_kwargs) + return PreparedRuntimeGenerateInputs( + inputs=inputs, + generate_kwargs=generate_kwargs, + chunked_prefill=chunked_prefill, + ) + if not chunked_prefill.runtime_eligible: + return PreparedRuntimeGenerateInputs( + inputs=inputs, + generate_kwargs=generate_kwargs, + chunked_prefill=chunked_prefill, + ) + prepared_inputs, prepared_generate_kwargs = _run_chunked_prefill( + runtime, + inputs, + generate_kwargs, + ) + return PreparedRuntimeGenerateInputs( + inputs=prepared_inputs, + generate_kwargs=prepared_generate_kwargs, + chunked_prefill=chunked_prefill.with_activation( + applied=True, + activation_reason="Bounded chunked prefill ran before final decode.", + ), + ) def _run_chunked_prefill( @@ -308,12 +344,13 @@ def execute( runtime, request, streamer ) filtered_inputs = normalize_generate_inputs(inputs) - filtered_inputs, generate_kwargs = prepare_runtime_generate_inputs( + prepared_generate_inputs = prepare_runtime_generate_inputs( runtime, - request, filtered_inputs, generate_kwargs, ) + filtered_inputs = prepared_generate_inputs.inputs + generate_kwargs = prepared_generate_inputs.generate_kwargs with torch.inference_mode(): with suppress_module_prints(runtime.backend.print_suppression_modules): diff --git a/tests/benchmark_support.py b/tests/benchmark_support.py index adf9019..437ec6d 100644 --- a/tests/benchmark_support.py +++ b/tests/benchmark_support.py @@ -6,6 +6,12 @@ RequestProbeMetrics, ) from ollm.runtime.benchmark.resources import StageResourceSnapshot +from ollm.runtime.chunked_prefill import ( + ChunkedPrefillAttentionMaskMode, + ChunkedPrefillExecutionBoundary, + ChunkedPrefillScopeSurface, + chunked_prefill_gap_inventory, +) def build_stage_resources() -> StageResourceSnapshot: @@ -95,6 +101,17 @@ def build_request_probe_metrics() -> RequestProbeMetrics: evicted_tokens=0, cold_store_format=None, ), + chunked_prefill=ChunkedPrefillScopeSurface( + runtime_eligible=True, + applied=True, + activation_reason="Bounded chunked prefill ran before final decode.", + supported_backend_id="optimized-native", + supported_model_kind="causal-lm", + supported_prompt_kind="text-only", + execution_boundary=ChunkedPrefillExecutionBoundary.POST_TOKENIZATION, + attention_mask_mode=ChunkedPrefillAttentionMaskMode.FULL_PREFIX_MATERIALIZED, + gap_inventory=chunked_prefill_gap_inventory(), + ), allocator_gap_mb=20.0, allocator_gap_ratio=0.066667, native_runtime_profile=build_native_runtime_profile(), diff --git a/tests/test_benchmark_probe_execution.py b/tests/test_benchmark_probe_execution.py index 41a988c..ca9f68a 100644 --- a/tests/test_benchmark_probe_execution.py +++ b/tests/test_benchmark_probe_execution.py @@ -152,6 +152,8 @@ def test_execute_request_probe_strips_processor_token_type_ids() -> None: assert execution.response_text == "decoded-benchmark" assert "token_type_ids" not in runtime.model.generate_kwargs + assert execution.metrics.chunked_prefill.runtime_eligible is False + assert execution.metrics.chunked_prefill.applied is False def test_execute_request_probe_uses_chunked_prefill_for_long_causal_prompts( @@ -222,6 +224,8 @@ def test_execute_request_with_trace_reports_processor_counts() -> None: assert trace.decode_prefix_token_count == 3 assert trace.output_token_count == 1 assert trace.cache_state is None + assert trace.chunked_prefill.runtime_eligible is False + assert trace.chunked_prefill.applied is False def test_execute_request_with_trace_tracks_chunked_prefill_prefix_length( @@ -272,6 +276,8 @@ def test_execute_request_with_trace_tracks_chunked_prefill_prefix_length( assert trace.prompt_token_count == 5 assert trace.decode_prefix_token_count == 1 assert trace.output_token_count == 4 + assert trace.chunked_prefill.runtime_eligible is True + assert trace.chunked_prefill.applied is True assert len(model.forward_calls) == 2 @@ -325,9 +331,9 @@ def wrapped_perf_counter() -> float: "prepare_runtime_generate_inputs" ] - def wrapped_prepare(runtime, request, inputs, generate_kwargs): + def wrapped_prepare(runtime, inputs, generate_kwargs): order.append("prepare") - return original_prepare(runtime, request, inputs, generate_kwargs) + return original_prepare(runtime, inputs, generate_kwargs) monkeypatch.setattr( "ollm.runtime.execution_trace.time.perf_counter", wrapped_perf_counter diff --git a/tests/test_benchmark_reporting.py b/tests/test_benchmark_reporting.py index ced1e1f..f85f00c 100644 --- a/tests/test_benchmark_reporting.py +++ b/tests/test_benchmark_reporting.py @@ -33,6 +33,7 @@ parse_output_scaling_probe_result, parse_prompt_scaling_probe_result, parse_reopen_session_growth_probe_result, + parse_runtime_probe_result, parse_session_growth_probe_result, parse_warm_runtime_probe_result, render_output_scaling_probe_json, @@ -101,16 +102,26 @@ def test_render_runtime_probe_json_round_trips() -> None: request=build_request_probe_metrics(), ) - payload = json.loads(render_runtime_probe_json(probe)) + rendered = render_runtime_probe_json(probe) + payload = json.loads(rendered) + parsed = parse_runtime_probe_result(rendered) assert payload["load_ms"] == 10.0 + assert parsed.request.chunked_prefill.applied is True request = cast(dict[str, object], payload["request"]) resources = cast(dict[str, object], request["resources"]) native_runtime_profile = cast(dict[str, object], request["native_runtime_profile"]) cache_state = cast(dict[str, object], request["cache_state"]) + chunked_prefill = cast(dict[str, object], request["chunked_prefill"]) events = cast(dict[str, object], native_runtime_profile["events"]) assert request["output_tokens"] == 4 assert request["kv_cache_strategy"] == "chunked" + assert chunked_prefill["runtime_eligible"] is True + assert chunked_prefill["applied"] is True + assert chunked_prefill["execution_boundary"] == "post-tokenization" + assert chunked_prefill["attention_mask_mode"] == "full-prefix-materialized" + gap_inventory = cast(list[object], chunked_prefill["gap_inventory"]) + assert len(gap_inventory) == 2 adaptation = cast(dict[str, object], request["kv_cache_adaptation"]) assert adaptation["adaptation_mode"] == "observe-only" assert adaptation["recommendation_available"] is True @@ -237,11 +248,15 @@ def test_summarize_request_metrics_includes_native_runtime_profile() -> None: summary = summarize_request_metrics([build_request_probe_metrics()]) native_runtime_profile = cast(dict[str, object], summary["native_runtime_profile"]) + chunked_prefill = cast(dict[str, object], summary["chunked_prefill"]) assert native_runtime_profile["storage_paths"] == [ "disk-kv-cache", "safetensor-io", ] + assert chunked_prefill["runtime_eligible"] is True + assert chunked_prefill["applied"] is True + assert chunked_prefill["supported_backend_id"] == "optimized-native" assert cast(dict[str, object], summary["cache"])["kv_cache_strategy"] == "chunked" adaptation = cast( dict[str, object], diff --git a/tests/test_chunked_prefill_scope.py b/tests/test_chunked_prefill_scope.py new file mode 100644 index 0000000..d62e643 --- /dev/null +++ b/tests/test_chunked_prefill_scope.py @@ -0,0 +1,104 @@ +from dataclasses import replace + +from ollm.app.types import ContentPart, Message, MessageRole +from ollm.runtime.capabilities import CapabilityProfile, SupportLevel +from ollm.runtime.capability_discovery import GenericModelKind +from ollm.runtime.chunked_prefill import ( + ChunkedPrefillGapId, + ChunkedPrefillRecommendation, +) +from ollm.runtime.generation import ( + build_runtime_generate_kwargs, + build_runtime_inputs, + prepare_runtime_generate_inputs, +) +from tests.test_runtime_executor import ( + build_request, + build_runtime_with_model, +) +from tests.test_runtime_executor_prefill import ( + ChunkedPrefillModel, + LongMappingTokenizer, +) + + +def test_prepare_runtime_generate_inputs_surfaces_chunked_prefill_scope( + monkeypatch, +) -> None: + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=ChunkedPrefillModel(), + ) + runtime.plan = replace( + runtime.plan, + backend_id="optimized-native", + generic_model_kind=GenericModelKind.CAUSAL_LM, + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + inputs = build_runtime_inputs(runtime, request.messages) + generate_kwargs, _generation_config = build_runtime_generate_kwargs( + runtime, + request, + streamer=None, + ) + prepared = prepare_runtime_generate_inputs(runtime, inputs, generate_kwargs) + + assert prepared.chunked_prefill.runtime_eligible is True + assert prepared.chunked_prefill.applied is True + assert ( + prepared.chunked_prefill.activation_reason + == "Bounded chunked prefill ran before final decode." + ) + gap_inventory = { + decision.gap_id: decision.recommendation + for decision in prepared.chunked_prefill.gap_inventory + } + assert gap_inventory == { + ChunkedPrefillGapId.PROMPT_CONSTRUCTION_BEFORE_PREFILL: ( + ChunkedPrefillRecommendation.REJECT + ), + ChunkedPrefillGapId.NON_CAUSAL_RUNTIME_EXPANSION: ( + ChunkedPrefillRecommendation.REJECT + ), + } + + +def test_prepare_runtime_generate_inputs_rejects_seq2seq_scope_extension( + monkeypatch, +) -> None: + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=ChunkedPrefillModel(), + ) + runtime.plan = replace( + runtime.plan, + backend_id="optimized-native", + generic_model_kind=GenericModelKind.SEQ2SEQ_LM, + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + inputs = build_runtime_inputs(runtime, request.messages) + generate_kwargs, _generation_config = build_runtime_generate_kwargs( + runtime, + request, + streamer=None, + ) + prepared = prepare_runtime_generate_inputs(runtime, inputs, generate_kwargs) + + assert prepared.chunked_prefill.runtime_eligible is False + assert prepared.chunked_prefill.applied is False + assert ( + prepared.chunked_prefill.activation_reason + == "Chunked prefill is limited to causal decoder-only text runtimes." + ) From 8ea06bc26a26a196398df02baf33f8a0ad28753a Mon Sep 17 00:00:00 2001 From: beardedeagle Date: Fri, 3 Apr 2026 06:52:37 -0500 Subject: [PATCH 2/9] feat: expand chunked prefill across causal runtime lanes Replace the earlier scope-only chunked-prefill closure with real strategy-backed runtime support. The runtime now resolves four causal chunked-prefill strategies: optimized-native text, optimized-native processor-backed multimodal, transformers-generic text, and transformers-generic processor-backed multimodal. The shared chunked-prefill core handles sequence slicing, static multimodal inputs, cache handoff, and final generate input reduction without duplicating the causal prefill loop. Thread the selected strategy through runtime metadata, execution tracing, and benchmark JSON/reporting so request metrics identify which chunked-prefill lane ran. Update the regression suite to cover the new strategy lanes and add a tiny local T5 proof showing why seq2seq source prompts still need a separate encoder-side design instead of pretending they can reuse the causal-cache contract. Refresh README and benchmarking/optimization docs so the public contract matches the implementation: four supported causal strategy lanes now exist, while prompt-construction streaming/lazy-mask work and seq2seq source-ingestion chunking are tracked separately as follow-on beads ollm-qm9 and ollm-dnl. --- .beads/interactions.jsonl | 3 + README.md | 24 +- docs/benchmarking.md | 4 +- docs/guides/optimization.md | 20 +- .../chunked_prefill_serialization.py | 16 +- src/ollm/runtime/chunked_prefill.py | 285 ++++++++++++++---- src/ollm/runtime/execution_trace.py | 10 +- src/ollm/runtime/generation.py | 141 +++------ tests/benchmark_support.py | 5 +- tests/test_benchmark_probe_execution.py | 17 +- tests/test_benchmark_reporting.py | 5 +- tests/test_chunked_prefill_scope.py | 80 +++-- tests/test_runtime_executor_prefill.py | 245 ++++++++++++++- 13 files changed, 650 insertions(+), 205 deletions(-) diff --git a/.beads/interactions.jsonl b/.beads/interactions.jsonl index a2502b4..5c40103 100644 --- a/.beads/interactions.jsonl +++ b/.beads/interactions.jsonl @@ -9,3 +9,6 @@ {"id":"int-7c080521","kind":"field_change","created_at":"2026-04-03T04:58:45.426115Z","actor":"beardedeagle","issue_id":"ollm-7zk","extra":{"field":"status","new_value":"closed","old_value":"open","reason":"Completed"}} {"id":"int-bf6e89f2","kind":"field_change","created_at":"2026-04-03T06:33:55.374489Z","actor":"beardedeagle","issue_id":"ollm-nnt","extra":{"field":"status","new_value":"closed","old_value":"open","reason":"Completed"}} {"id":"int-a6481fd7","kind":"field_change","created_at":"2026-04-03T09:48:50.10993Z","actor":"beardedeagle","issue_id":"ollm-6b7","extra":{"field":"status","new_value":"closed","old_value":"in_progress","reason":"Completed"}} +{"id":"int-32d6d982","kind":"field_change","created_at":"2026-04-03T11:31:04.109777Z","actor":"beardedeagle","issue_id":"ollm-6b7","extra":{"field":"status","new_value":"open","old_value":"closed"}} +{"id":"int-9e450675","kind":"field_change","created_at":"2026-04-03T11:33:42.488488Z","actor":"beardedeagle","issue_id":"ollm-6b7","extra":{"field":"status","new_value":"in_progress","old_value":"open"}} +{"id":"int-5cd99253","kind":"field_change","created_at":"2026-04-03T11:52:07.85795Z","actor":"beardedeagle","issue_id":"ollm-6b7","extra":{"field":"status","new_value":"closed","old_value":"in_progress","reason":"Completed"}} diff --git a/README.md b/README.md index f6114e1..b2d660e 100644 --- a/README.md +++ b/README.md @@ -160,14 +160,22 @@ full-history KV in memory. When the bounded the recent-context token budget and oldest tokens are evicted once the window is exceeded. -On optimized-native decoder-only text runtimes, long prompts are ingested -through bounded prefill chunks before the final decode step. That keeps prompt -execution from growing one full prompt-wide activation step at a time on very -long inputs while preserving the external prompt/chat contract. Prompt-scaling -benchmarks remain the right place to evaluate the TTFT and memory tradeoff on -target hardware. This feature starts after prompt tokenization and full-prefix -attention-mask construction, and it intentionally does not extend that same -contract to seq2seq, multimodal, or generic Transformers runtimes. +On the causal runtime lanes that support chunked prefill, long prompts are +ingested through bounded prefill chunks before the final decode step. The +current strategy lanes are: + +- `optimized-native-text` +- `optimized-native-multimodal` +- `transformers-generic-text` +- `transformers-generic-multimodal` + +That keeps prompt execution from growing one full prompt-wide activation step +at a time on very long inputs while preserving the external prompt/chat +contract. Prompt-scaling benchmarks remain the right place to evaluate the +TTFT and memory tradeoff on target hardware. Prompt tokenization and full-prefix +attention-mask construction still happen before chunking starts, and seq2seq +source prompts remain a separate deferred lane rather than silently pretending +to use causal-cache prefill. Configuration layering uses an explicit precedence contract: diff --git a/docs/benchmarking.md b/docs/benchmarking.md index c24cf52..bba540f 100644 --- a/docs/benchmarking.md +++ b/docs/benchmarking.md @@ -179,8 +179,8 @@ Interpretation notes: - output throughput is generated output tokens divided by total generation latency - peak RSS includes a source label; long-lived warm/scaling/session probes use stage-local sampled peaks instead of process-lifetime peaks - allocator-gap metrics are reported as reserved-minus-allocated style slack when the backend exposes the required counters; unsupported backends serialize them as `null` -- optimized-native decoder-only prompt-scaling runs exercise bounded chunked prefill on long text prompts, so the prompt-length sweep is the intended place to inspect the memory versus TTFT tradeoff for this feature -- request metrics also include a `chunked_prefill` section that states whether the active runtime was eligible, whether chunking actually ran, and the explicit rejected non-goals for prompt-construction streaming and non-causal/generic runtime expansion +- text prompt-scaling runs exercise bounded chunked prefill on the supported text strategy lanes, so the prompt-length sweep is the intended place to inspect the memory versus TTFT tradeoff for this feature +- request metrics also include a `chunked_prefill` section that states the selected strategy ID, whether the active runtime was eligible, whether chunking actually ran, and the remaining deferred gaps for prompt tokenization, full-prefix attention-mask construction, and seq2seq source ingestion On loader-streamed families such as optimized Gemma3 on CPU, a long per-turn session-growth response can become dominated by repeated safetensor layer reads diff --git a/docs/guides/optimization.md b/docs/guides/optimization.md index 235a144..2709cce 100644 --- a/docs/guides/optimization.md +++ b/docs/guides/optimization.md @@ -15,14 +15,18 @@ Native families: - `gpt-oss` - `voxtral` -Optimized-native decoder-only text prompts use bounded chunked prefill for -long prompt ingestion before the final decode step. This is a memory-control -path, not a blanket latency optimization, so prompt-scaling benchmarks are the -truthful way to evaluate whether the chunking tradeoff helps on a given host. -The contract starts after prompt tokenization and full-prefix attention-mask -construction. Seq2seq, multimodal, and generic Transformers runtimes are -intentionally outside this feature boundary and would require separate designs -and benchmark semantics if pursued. +Supported causal runtime lanes use bounded chunked prefill for long prompt +ingestion before the final decode step. The current lanes are +`optimized-native-text`, `optimized-native-multimodal`, +`transformers-generic-text`, and `transformers-generic-multimodal`. This is a +memory-control path, not a blanket latency optimization, so prompt-scaling +benchmarks are the truthful way to evaluate whether the chunking tradeoff helps +on a given host. + +Prompt tokenization and full-prefix attention-mask construction still complete +before chunking starts, and seq2seq source prompts remain a separate deferred +lane because encoder-decoder source ingestion is not the same causal-cache +operation. ### Transformers-generic Used for compatible local or materialized models that can run through the generic Transformers-backed path. diff --git a/src/ollm/runtime/benchmark/chunked_prefill_serialization.py b/src/ollm/runtime/benchmark/chunked_prefill_serialization.py index 72452f7..85b8a43 100644 --- a/src/ollm/runtime/benchmark/chunked_prefill_serialization.py +++ b/src/ollm/runtime/benchmark/chunked_prefill_serialization.py @@ -10,6 +10,7 @@ ChunkedPrefillGapId, ChunkedPrefillRecommendation, ChunkedPrefillScopeSurface, + ChunkedPrefillStrategyId, ) @@ -26,12 +27,10 @@ def parse_chunked_prefill( payload = cast(Mapping[str, object], value) gap_items = require_sequence(payload, "gap_inventory") return ChunkedPrefillScopeSurface( + strategy_id=_optional_strategy_id(payload, require_string=require_string), runtime_eligible=require_bool(payload, "runtime_eligible"), applied=require_bool(payload, "applied"), activation_reason=require_string(payload, "activation_reason"), - supported_backend_id=require_string(payload, "supported_backend_id"), - supported_model_kind=require_string(payload, "supported_model_kind"), - supported_prompt_kind=require_string(payload, "supported_prompt_kind"), execution_boundary=ChunkedPrefillExecutionBoundary( require_string(payload, "execution_boundary") ), @@ -61,3 +60,14 @@ def parse_chunked_prefill_gap( ), rationale=require_string(payload, "rationale"), ) + + +def _optional_strategy_id( + payload: Mapping[str, object], + *, + require_string, +) -> ChunkedPrefillStrategyId | None: + value = payload.get("strategy_id") + if value is None: + return None + return ChunkedPrefillStrategyId(require_string(payload, "strategy_id")) diff --git a/src/ollm/runtime/chunked_prefill.py b/src/ollm/runtime/chunked_prefill.py index 609b23a..b939258 100644 --- a/src/ollm/runtime/chunked_prefill.py +++ b/src/ollm/runtime/chunked_prefill.py @@ -1,18 +1,30 @@ -"""Typed scope facts for bounded chunked prefill.""" +"""Chunked-prefill strategy resolution and execution.""" from dataclasses import asdict, dataclass, replace from enum import StrEnum +from inspect import Parameter, signature from typing import Self import torch from ollm.runtime.capability_discovery import GenericModelKind +from ollm.runtime.errors import PromptExecutionError +from ollm.runtime.generation_support import require_tensor from ollm.runtime.loaded_runtime import LoadedRuntime +from ollm.runtime.output_control import suppress_module_prints + + +class ChunkedPrefillStrategyId(StrEnum): + OPTIMIZED_NATIVE_TEXT = "optimized-native-text" + OPTIMIZED_NATIVE_MULTIMODAL = "optimized-native-multimodal" + TRANSFORMERS_GENERIC_TEXT = "transformers-generic-text" + TRANSFORMERS_GENERIC_MULTIMODAL = "transformers-generic-multimodal" class ChunkedPrefillGapId(StrEnum): - PROMPT_CONSTRUCTION_BEFORE_PREFILL = "prompt-construction-before-prefill" - NON_CAUSAL_RUNTIME_EXPANSION = "non-causal-runtime-expansion" + PROMPT_TOKENIZATION_BEFORE_PREFILL = "prompt-tokenization-before-prefill" + FULL_ATTENTION_MASK_BEFORE_PREFILL = "full-attention-mask-before-prefill" + SEQ2SEQ_SOURCE_PREFILL = "seq2seq-source-prefill" class ChunkedPrefillRecommendation(StrEnum): @@ -43,12 +55,10 @@ def to_dict(self) -> dict[str, str]: @dataclass(frozen=True, slots=True) class ChunkedPrefillScopeSurface: + strategy_id: ChunkedPrefillStrategyId | None runtime_eligible: bool applied: bool activation_reason: str - supported_backend_id: str - supported_model_kind: str - supported_prompt_kind: str execution_boundary: ChunkedPrefillExecutionBoundary attention_mask_mode: ChunkedPrefillAttentionMaskMode gap_inventory: tuple[ChunkedPrefillGapDecision, ...] @@ -62,50 +72,97 @@ def with_activation(self, *, applied: bool, activation_reason: str) -> Self: def to_dict(self) -> dict[str, object]: return { + "strategy_id": None if self.strategy_id is None else self.strategy_id.value, "runtime_eligible": self.runtime_eligible, "applied": self.applied, "activation_reason": self.activation_reason, - "supported_backend_id": self.supported_backend_id, - "supported_model_kind": self.supported_model_kind, - "supported_prompt_kind": self.supported_prompt_kind, "execution_boundary": self.execution_boundary.value, "attention_mask_mode": self.attention_mask_mode.value, "gap_inventory": [decision.to_dict() for decision in self.gap_inventory], } +@dataclass(frozen=True, slots=True) +class PreparedChunkedPrefill: + inputs: dict[str, object] + generate_kwargs: dict[str, object] + scope: ChunkedPrefillScopeSurface + + _CHUNKED_PREFILL_GAP_INVENTORY = ( ChunkedPrefillGapDecision( - gap_id=ChunkedPrefillGapId.PROMPT_CONSTRUCTION_BEFORE_PREFILL, + gap_id=ChunkedPrefillGapId.PROMPT_TOKENIZATION_BEFORE_PREFILL, + current_behavior="Prompt tokenization completes before chunked prefill begins.", + recommendation=ChunkedPrefillRecommendation.DEFER, + rationale=( + "Streaming prompt tokenization needs tokenizer-specific boundary " + "preservation instead of the current whole-prompt tokenization path." + ), + ), + ChunkedPrefillGapDecision( + gap_id=ChunkedPrefillGapId.FULL_ATTENTION_MASK_BEFORE_PREFILL, current_behavior=( - "Prompt tokenization and full-prefix attention-mask materialization " - "complete before chunked prefill begins." + "A full prefix attention mask is materialized before chunked prefill " + "hands off to the final generate step." ), - recommendation=ChunkedPrefillRecommendation.REJECT, + recommendation=ChunkedPrefillRecommendation.DEFER, rationale=( - "Bounded chunked prefill is an execution-stage memory control after " - "prompt construction. Streaming prompt construction would be a " - "different feature with different tokenizer and processor " - "contracts." + "The current generation handoff still relies on full-prefix masks. " + "A lazy mask contract needs backend proof before it can replace that " + "shape safely." ), ), ChunkedPrefillGapDecision( - gap_id=ChunkedPrefillGapId.NON_CAUSAL_RUNTIME_EXPANSION, + gap_id=ChunkedPrefillGapId.SEQ2SEQ_SOURCE_PREFILL, current_behavior=( - "Chunked prefill is limited to optimized-native causal text " - "runtimes and does not cover seq2seq, multimodal, or " - "transformers-generic paths." + "Seq2seq source prompts do not use causal-cache chunked prefill." ), - recommendation=ChunkedPrefillRecommendation.REJECT, + recommendation=ChunkedPrefillRecommendation.DEFER, rationale=( - "Those runtimes have different encoder, cache, and prompt-shape " - "contracts. If pursued, they should ship as separate features " - "with their own execution and benchmark semantics." + "Encoder-decoder source ingestion has no equivalent causal-cache " + "prefill contract; it needs a separate encoder strategy." ), ), ) +def prepare_chunked_prefill( + *, + runtime: LoadedRuntime, + inputs: dict[str, object], + generate_kwargs: dict[str, object], + chunk_tokens: int, +) -> PreparedChunkedPrefill: + scope = build_chunked_prefill_scope_surface( + runtime=runtime, + inputs=inputs, + chunk_tokens=chunk_tokens, + ) + input_ids_value = inputs.get("input_ids") + if not isinstance(input_ids_value, torch.Tensor): + return PreparedChunkedPrefill(inputs, generate_kwargs, scope) + if input_ids_value.ndim != 2 or input_ids_value.shape[0] != 1: + return PreparedChunkedPrefill(inputs, generate_kwargs, scope) + prefill_token_count = input_ids_value.shape[1] - 1 + if prefill_token_count <= chunk_tokens or not scope.runtime_eligible: + return PreparedChunkedPrefill(inputs, generate_kwargs, scope) + prepared_inputs, prepared_generate_kwargs = _run_causal_chunked_prefill( + runtime=runtime, + inputs=inputs, + generate_kwargs=generate_kwargs, + chunk_tokens=chunk_tokens, + strategy_id=scope.strategy_id, + ) + return PreparedChunkedPrefill( + inputs=prepared_inputs, + generate_kwargs=prepared_generate_kwargs, + scope=scope.with_activation( + applied=True, + activation_reason="Bounded chunked prefill ran before final decode.", + ), + ) + + def build_chunked_prefill_scope_surface( *, runtime: LoadedRuntime, @@ -114,48 +171,46 @@ def build_chunked_prefill_scope_surface( ) -> ChunkedPrefillScopeSurface: input_ids = inputs.get("input_ids") if not isinstance(input_ids, torch.Tensor): - return _surface( + return _scope( + strategy_id=None, runtime_eligible=False, activation_reason="Chunked prefill requires tensor-backed input_ids.", ) if input_ids.ndim != 2 or input_ids.shape[0] != 1: - return _surface( + return _scope( + strategy_id=None, runtime_eligible=False, activation_reason="Chunked prefill requires a single batch row.", ) - if runtime.processor is not None: - return _surface( - runtime_eligible=False, - activation_reason=( - "Chunked prefill is limited to text prompts without a processor." - ), - ) - if runtime.plan.backend_id != "optimized-native": - return _surface( - runtime_eligible=False, - activation_reason=( - "Chunked prefill is limited to the optimized-native backend." - ), - ) runtime_kind = ( runtime.plan.generic_model_kind or runtime.resolved_model.generic_model_kind ) - if runtime_kind is not GenericModelKind.CAUSAL_LM: - return _surface( + if runtime_kind is GenericModelKind.SEQ2SEQ_LM: + return _scope( + strategy_id=None, + runtime_eligible=False, + activation_reason="Seq2seq source prompts cannot use causal-cache chunked prefill.", + ) + strategy_id = _resolve_strategy_id(runtime, runtime_kind) + if strategy_id is None: + return _scope( + strategy_id=None, runtime_eligible=False, activation_reason=( - "Chunked prefill is limited to causal decoder-only text runtimes." + "Chunked prefill requires a supported causal runtime strategy." ), ) prefill_token_count = int(input_ids.shape[1] - 1) if prefill_token_count <= chunk_tokens: - return _surface( + return _scope( + strategy_id=strategy_id, runtime_eligible=True, activation_reason=( "Prompt length does not exceed the chunked-prefill threshold." ), ) - return _surface( + return _scope( + strategy_id=strategy_id, runtime_eligible=True, activation_reason="Runtime is eligible for bounded chunked prefill.", ) @@ -165,16 +220,144 @@ def chunked_prefill_gap_inventory() -> tuple[ChunkedPrefillGapDecision, ...]: return _CHUNKED_PREFILL_GAP_INVENTORY -def _surface( - *, runtime_eligible: bool, activation_reason: str +def _resolve_strategy_id( + runtime: LoadedRuntime, + runtime_kind: GenericModelKind | None, +) -> ChunkedPrefillStrategyId | None: + if runtime.plan.backend_id == "optimized-native": + if runtime.processor is not None: + return ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL + if runtime_kind is GenericModelKind.CAUSAL_LM: + return ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT + return None + if runtime.plan.backend_id == "transformers-generic": + if runtime.processor is not None: + return ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_MULTIMODAL + if runtime_kind is GenericModelKind.CAUSAL_LM: + return ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_TEXT + return None + return None + + +def _run_causal_chunked_prefill( + *, + runtime: LoadedRuntime, + inputs: dict[str, object], + generate_kwargs: dict[str, object], + chunk_tokens: int, + strategy_id: ChunkedPrefillStrategyId | None, +) -> tuple[dict[str, object], dict[str, object]]: + forward_method = getattr(runtime.model, "forward", None) + if not callable(forward_method): + return inputs, generate_kwargs + input_ids = require_tensor(inputs["input_ids"]) + attention_mask = _optional_tensor(inputs.get("attention_mask")) + sequence_inputs = _collect_sequence_inputs(inputs, input_ids) + static_inputs = _collect_static_inputs(inputs) + prefill_cache = generate_kwargs.get("past_key_values") + prefill_end = input_ids.shape[1] - 1 + with torch.inference_mode(): + with suppress_module_prints(runtime.backend.print_suppression_modules): + for chunk_start in range(0, prefill_end, chunk_tokens): + chunk_end = min(prefill_end, chunk_start + chunk_tokens) + forward_inputs: dict[str, object] = dict(static_inputs) + forward_inputs["input_ids"] = input_ids[:, chunk_start:chunk_end] + if attention_mask is not None: + forward_inputs["attention_mask"] = attention_mask[:, :chunk_end] + for key, value in sequence_inputs.items(): + forward_inputs[key] = value[:, chunk_start:chunk_end] + if prefill_cache is not None: + forward_inputs["past_key_values"] = prefill_cache + forward_inputs["use_cache"] = True + forward_inputs["cache_position"] = torch.arange( + chunk_start, + chunk_end, + device=input_ids.device, + dtype=torch.long, + ) + filtered_inputs = _filter_supported_forward_inputs( + forward_method, + forward_inputs, + ) + outputs = forward_method(**filtered_inputs) + prefill_cache = getattr(outputs, "past_key_values", None) + if prefill_cache is None: + strategy_label = ( + "unknown" if strategy_id is None else strategy_id.value + ) + raise PromptExecutionError( + "Chunked prefill strategy " + f"{strategy_label!r} requires a runtime that returns " + "past_key_values." + ) + updated_inputs = dict(static_inputs) + updated_inputs["input_ids"] = input_ids[:, -1:] + if attention_mask is not None: + updated_inputs["attention_mask"] = attention_mask + for key, value in sequence_inputs.items(): + updated_inputs[key] = value[:, -1:] + updated_generate_kwargs = dict(generate_kwargs) + updated_generate_kwargs["past_key_values"] = prefill_cache + return updated_inputs, updated_generate_kwargs + + +def _collect_sequence_inputs( + inputs: dict[str, object], + input_ids: torch.Tensor, +) -> dict[str, torch.Tensor]: + sequence_inputs: dict[str, torch.Tensor] = {} + sequence_length = input_ids.shape[1] + for key, value in inputs.items(): + if key in {"input_ids", "attention_mask"}: + continue + if ( + isinstance(value, torch.Tensor) + and value.ndim == 2 + and value.shape[1] == sequence_length + ): + sequence_inputs[key] = value + return sequence_inputs + + +def _collect_static_inputs(inputs: dict[str, object]) -> dict[str, object]: + return { + key: value + for key, value in inputs.items() + if key not in {"input_ids", "attention_mask"} + } + + +def _filter_supported_forward_inputs( + forward_method, + inputs: dict[str, object], +) -> dict[str, object]: + method_signature = signature(forward_method) + if any( + parameter.kind is Parameter.VAR_KEYWORD + for parameter in method_signature.parameters.values() + ): + return inputs + supported_keys = set(method_signature.parameters) + return {key: value for key, value in inputs.items() if key in supported_keys} + + +def _optional_tensor(value: object) -> torch.Tensor | None: + if value is None: + return None + return require_tensor(value) + + +def _scope( + *, + strategy_id: ChunkedPrefillStrategyId | None, + runtime_eligible: bool, + activation_reason: str, ) -> ChunkedPrefillScopeSurface: return ChunkedPrefillScopeSurface( + strategy_id=strategy_id, runtime_eligible=runtime_eligible, applied=False, activation_reason=activation_reason, - supported_backend_id="optimized-native", - supported_model_kind=GenericModelKind.CAUSAL_LM.value, - supported_prompt_kind="text-only", execution_boundary=ChunkedPrefillExecutionBoundary.POST_TOKENIZATION, attention_mask_mode=ChunkedPrefillAttentionMaskMode.FULL_PREFIX_MATERIALIZED, gap_inventory=_CHUNKED_PREFILL_GAP_INVENTORY, diff --git a/src/ollm/runtime/execution_trace.py b/src/ollm/runtime/execution_trace.py index ef52c64..6c7849f 100644 --- a/src/ollm/runtime/execution_trace.py +++ b/src/ollm/runtime/execution_trace.py @@ -58,13 +58,15 @@ def execute_request_with_trace( ) prepared_inputs = normalize_generate_inputs(inputs) generation_started_at = time.perf_counter() - prepared_result = prepare_runtime_generate_inputs( + ( + prepared_inputs, + prepared_generate_kwargs, + chunked_prefill, + ) = prepare_runtime_generate_inputs( runtime, prepared_inputs, generate_kwargs, ) - prepared_inputs = prepared_result.inputs - prepared_generate_kwargs = prepared_result.generate_kwargs outputs, effective_generate_kwargs = _generate_outputs( runtime=runtime, prepared_inputs=prepared_inputs, @@ -93,7 +95,7 @@ def execute_request_with_trace( output_token_count=output_token_count, response_text=response_text, cache_state=cache_state, - chunked_prefill=prepared_result.chunked_prefill, + chunked_prefill=chunked_prefill, ) diff --git a/src/ollm/runtime/generation.py b/src/ollm/runtime/generation.py index 4705fd0..9cd36ed 100644 --- a/src/ollm/runtime/generation.py +++ b/src/ollm/runtime/generation.py @@ -15,7 +15,7 @@ from ollm.runtime.catalog import ModelModality from ollm.runtime.chunked_prefill import ( ChunkedPrefillScopeSurface, - build_chunked_prefill_scope_surface, + prepare_chunked_prefill, ) from ollm.runtime.errors import PromptExecutionError from ollm.runtime.generation_config_support import ( @@ -43,13 +43,6 @@ def print_and_clean(self) -> str: ... DEFAULT_PREFILL_CHUNK_TOKENS = 512 -@dataclass(frozen=True, slots=True) -class PreparedRuntimeGenerateInputs: - inputs: dict[str, object] - generate_kwargs: dict[str, object] - chunked_prefill: ChunkedPrefillScopeSurface - - def validate_runtime_request(runtime: LoadedRuntime, request: PromptRequest) -> None: if not request.messages: raise PromptExecutionError("At least one message is required") @@ -194,101 +187,20 @@ def prepare_runtime_generate_inputs( runtime: LoadedRuntime, inputs: dict[str, object], generate_kwargs: dict[str, object], -) -> PreparedRuntimeGenerateInputs: - chunked_prefill = build_chunked_prefill_scope_surface( +) -> tuple[dict[str, object], dict[str, object], ChunkedPrefillScopeSurface]: + prepared = prepare_chunked_prefill( runtime=runtime, inputs=inputs, + generate_kwargs=generate_kwargs, chunk_tokens=DEFAULT_PREFILL_CHUNK_TOKENS, ) - input_ids_value = inputs.get("input_ids") - if not isinstance(input_ids_value, torch.Tensor): - return PreparedRuntimeGenerateInputs( - inputs=inputs, - generate_kwargs=generate_kwargs, - chunked_prefill=chunked_prefill, - ) - if input_ids_value.ndim != 2 or input_ids_value.shape[0] != 1: - return PreparedRuntimeGenerateInputs( - inputs=inputs, - generate_kwargs=generate_kwargs, - chunked_prefill=chunked_prefill, - ) - prefill_token_count = input_ids_value.shape[1] - 1 - if prefill_token_count <= DEFAULT_PREFILL_CHUNK_TOKENS: - return PreparedRuntimeGenerateInputs( - inputs=inputs, - generate_kwargs=generate_kwargs, - chunked_prefill=chunked_prefill, - ) - if not chunked_prefill.runtime_eligible: - return PreparedRuntimeGenerateInputs( - inputs=inputs, - generate_kwargs=generate_kwargs, - chunked_prefill=chunked_prefill, - ) - prepared_inputs, prepared_generate_kwargs = _run_chunked_prefill( - runtime, - inputs, - generate_kwargs, - ) - return PreparedRuntimeGenerateInputs( - inputs=prepared_inputs, - generate_kwargs=prepared_generate_kwargs, - chunked_prefill=chunked_prefill.with_activation( - applied=True, - activation_reason="Bounded chunked prefill ran before final decode.", - ), + return ( + prepared.inputs, + prepared.generate_kwargs, + prepared.scope, ) -def _run_chunked_prefill( - runtime: LoadedRuntime, - inputs: dict[str, object], - generate_kwargs: dict[str, object], -) -> tuple[dict[str, object], dict[str, object]]: - forward_method = getattr(runtime.model, "forward", None) - if not callable(forward_method): - return inputs, generate_kwargs - input_ids = require_tensor(inputs["input_ids"]) - attention_mask_value = inputs.get("attention_mask") - attention_mask = ( - None if attention_mask_value is None else require_tensor(attention_mask_value) - ) - prefill_cache = generate_kwargs.get("past_key_values") - prefill_end = input_ids.shape[1] - 1 - with torch.inference_mode(): - with suppress_module_prints(runtime.backend.print_suppression_modules): - for chunk_start in range(0, prefill_end, DEFAULT_PREFILL_CHUNK_TOKENS): - chunk_end = min(prefill_end, chunk_start + DEFAULT_PREFILL_CHUNK_TOKENS) - forward_inputs: dict[str, object] = { - "input_ids": input_ids[:, chunk_start:chunk_end], - "use_cache": True, - "cache_position": torch.arange( - chunk_start, - chunk_end, - device=input_ids.device, - dtype=torch.long, - ), - } - if attention_mask is not None: - forward_inputs["attention_mask"] = attention_mask[:, :chunk_end] - if prefill_cache is not None: - forward_inputs["past_key_values"] = prefill_cache - outputs = forward_method(**forward_inputs) - prefill_cache = getattr(outputs, "past_key_values", None) - if prefill_cache is None: - raise PromptExecutionError( - "Chunked prefill requires a causal runtime that returns past_key_values." - ) - updated_inputs = dict(inputs) - updated_inputs["input_ids"] = input_ids[:, -1:] - if attention_mask is not None: - updated_inputs["attention_mask"] = attention_mask - updated_generate_kwargs = dict(generate_kwargs) - updated_generate_kwargs["past_key_values"] = prefill_cache - return updated_inputs, updated_generate_kwargs - - def decode_runtime_response( runtime: LoadedRuntime, inputs: dict[str, object], outputs: torch.Tensor ) -> str: @@ -344,13 +256,15 @@ def execute( runtime, request, streamer ) filtered_inputs = normalize_generate_inputs(inputs) - prepared_generate_inputs = prepare_runtime_generate_inputs( + ( + filtered_inputs, + generate_kwargs, + chunked_prefill, + ) = prepare_runtime_generate_inputs( runtime, filtered_inputs, generate_kwargs, ) - filtered_inputs = prepared_generate_inputs.inputs - generate_kwargs = prepared_generate_inputs.generate_kwargs with torch.inference_mode(): with suppress_module_prints(runtime.backend.print_suppression_modules): @@ -369,7 +283,7 @@ def execute( if streamer is not None and not response_text.strip(): response_text = streamer.text assistant_message = Message.assistant_text(response_text) - metadata = self._plan_metadata(runtime, cache_state) + metadata = self._plan_metadata(runtime, cache_state, chunked_prefill) return PromptResponse( text=response_text, assistant_message=assistant_message, metadata=metadata ) @@ -378,7 +292,7 @@ def _finalize_response( self, runtime: LoadedRuntime, response: PromptResponse ) -> PromptResponse: metadata = dict(response.metadata) - metadata.update(self._plan_metadata(runtime, None)) + metadata.update(self._plan_metadata(runtime, None, None)) return PromptResponse( text=response.text, assistant_message=response.assistant_message, @@ -389,6 +303,7 @@ def _plan_metadata( self, runtime: LoadedRuntime, cache_state: KVCacheStateSnapshot | None, + chunked_prefill: ChunkedPrefillScopeSurface | None, ) -> dict[str, str]: metadata = { "backend_id": runtime.plan.backend_id or "unknown", @@ -414,6 +329,30 @@ def _plan_metadata( ), "kv_cache_lifecycle": runtime.config.resolved_kv_cache_lifecycle(), "kv_cache_adaptation_mode": runtime.config.resolved_kv_cache_adaptation_mode(), + "chunked_prefill_strategy_id": ( + "" + if chunked_prefill is None or chunked_prefill.strategy_id is None + else chunked_prefill.strategy_id.value + ), + "chunked_prefill_runtime_eligible": str( + False if chunked_prefill is None else chunked_prefill.runtime_eligible + ).lower(), + "chunked_prefill_applied": str( + False if chunked_prefill is None else chunked_prefill.applied + ).lower(), + "chunked_prefill_activation_reason": ( + "" if chunked_prefill is None else chunked_prefill.activation_reason + ), + "chunked_prefill_execution_boundary": ( + "" + if chunked_prefill is None + else chunked_prefill.execution_boundary.value + ), + "chunked_prefill_attention_mask_mode": ( + "" + if chunked_prefill is None + else chunked_prefill.attention_mask_mode.value + ), } resolved_window_tokens = runtime.config.resolved_kv_cache_window_tokens() if resolved_window_tokens is not None: diff --git a/tests/benchmark_support.py b/tests/benchmark_support.py index 437ec6d..f005eef 100644 --- a/tests/benchmark_support.py +++ b/tests/benchmark_support.py @@ -10,6 +10,7 @@ ChunkedPrefillAttentionMaskMode, ChunkedPrefillExecutionBoundary, ChunkedPrefillScopeSurface, + ChunkedPrefillStrategyId, chunked_prefill_gap_inventory, ) @@ -102,12 +103,10 @@ def build_request_probe_metrics() -> RequestProbeMetrics: cold_store_format=None, ), chunked_prefill=ChunkedPrefillScopeSurface( + strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT, runtime_eligible=True, applied=True, activation_reason="Bounded chunked prefill ran before final decode.", - supported_backend_id="optimized-native", - supported_model_kind="causal-lm", - supported_prompt_kind="text-only", execution_boundary=ChunkedPrefillExecutionBoundary.POST_TOKENIZATION, attention_mask_mode=ChunkedPrefillAttentionMaskMode.FULL_PREFIX_MATERIALIZED, gap_inventory=chunked_prefill_gap_inventory(), diff --git a/tests/test_benchmark_probe_execution.py b/tests/test_benchmark_probe_execution.py index ca9f68a..fa57554 100644 --- a/tests/test_benchmark_probe_execution.py +++ b/tests/test_benchmark_probe_execution.py @@ -12,6 +12,7 @@ from ollm.runtime.capabilities import CapabilityProfile, SupportLevel from ollm.runtime.capability_discovery import GenericModelKind from ollm.runtime.catalog import ModelModality +from ollm.runtime.chunked_prefill import ChunkedPrefillStrategyId from ollm.runtime.config import GenerationConfig, RuntimeConfig from ollm.runtime.execution_trace import execute_request_with_trace from ollm.runtime.loaded_runtime import LoadedRuntime @@ -152,7 +153,10 @@ def test_execute_request_probe_strips_processor_token_type_ids() -> None: assert execution.response_text == "decoded-benchmark" assert "token_type_ids" not in runtime.model.generate_kwargs - assert execution.metrics.chunked_prefill.runtime_eligible is False + assert execution.metrics.chunked_prefill.strategy_id is ( + ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL + ) + assert execution.metrics.chunked_prefill.runtime_eligible is True assert execution.metrics.chunked_prefill.applied is False @@ -201,6 +205,9 @@ def test_execute_request_probe_uses_chunked_prefill_for_long_causal_prompts( execution = execute_request_probe(runtime=runtime, request=request) assert execution.response_text == "long-decoded" + assert execution.metrics.chunked_prefill.strategy_id is ( + ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT + ) assert len(model.forward_calls) == 2 generate_input_ids = model.generate_kwargs["input_ids"] assert isinstance(generate_input_ids, torch.Tensor) @@ -224,7 +231,10 @@ def test_execute_request_with_trace_reports_processor_counts() -> None: assert trace.decode_prefix_token_count == 3 assert trace.output_token_count == 1 assert trace.cache_state is None - assert trace.chunked_prefill.runtime_eligible is False + assert trace.chunked_prefill.strategy_id is ( + ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL + ) + assert trace.chunked_prefill.runtime_eligible is True assert trace.chunked_prefill.applied is False @@ -276,6 +286,9 @@ def test_execute_request_with_trace_tracks_chunked_prefill_prefix_length( assert trace.prompt_token_count == 5 assert trace.decode_prefix_token_count == 1 assert trace.output_token_count == 4 + assert trace.chunked_prefill.strategy_id is ( + ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT + ) assert trace.chunked_prefill.runtime_eligible is True assert trace.chunked_prefill.applied is True assert len(model.forward_calls) == 2 diff --git a/tests/test_benchmark_reporting.py b/tests/test_benchmark_reporting.py index f85f00c..599d075 100644 --- a/tests/test_benchmark_reporting.py +++ b/tests/test_benchmark_reporting.py @@ -116,12 +116,13 @@ def test_render_runtime_probe_json_round_trips() -> None: events = cast(dict[str, object], native_runtime_profile["events"]) assert request["output_tokens"] == 4 assert request["kv_cache_strategy"] == "chunked" + assert chunked_prefill["strategy_id"] == "optimized-native-text" assert chunked_prefill["runtime_eligible"] is True assert chunked_prefill["applied"] is True assert chunked_prefill["execution_boundary"] == "post-tokenization" assert chunked_prefill["attention_mask_mode"] == "full-prefix-materialized" gap_inventory = cast(list[object], chunked_prefill["gap_inventory"]) - assert len(gap_inventory) == 2 + assert len(gap_inventory) == 3 adaptation = cast(dict[str, object], request["kv_cache_adaptation"]) assert adaptation["adaptation_mode"] == "observe-only" assert adaptation["recommendation_available"] is True @@ -254,9 +255,9 @@ def test_summarize_request_metrics_includes_native_runtime_profile() -> None: "disk-kv-cache", "safetensor-io", ] + assert chunked_prefill["strategy_id"] == "optimized-native-text" assert chunked_prefill["runtime_eligible"] is True assert chunked_prefill["applied"] is True - assert chunked_prefill["supported_backend_id"] == "optimized-native" assert cast(dict[str, object], summary["cache"])["kv_cache_strategy"] == "chunked" adaptation = cast( dict[str, object], diff --git a/tests/test_chunked_prefill_scope.py b/tests/test_chunked_prefill_scope.py index d62e643..d849020 100644 --- a/tests/test_chunked_prefill_scope.py +++ b/tests/test_chunked_prefill_scope.py @@ -1,11 +1,17 @@ from dataclasses import replace +import pytest +import torch +from transformers import T5Config +from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration + from ollm.app.types import ContentPart, Message, MessageRole from ollm.runtime.capabilities import CapabilityProfile, SupportLevel from ollm.runtime.capability_discovery import GenericModelKind from ollm.runtime.chunked_prefill import ( ChunkedPrefillGapId, ChunkedPrefillRecommendation, + ChunkedPrefillStrategyId, ) from ollm.runtime.generation import ( build_runtime_generate_kwargs, @@ -47,29 +53,34 @@ def test_prepare_runtime_generate_inputs_surfaces_chunked_prefill_scope( request, streamer=None, ) - prepared = prepare_runtime_generate_inputs(runtime, inputs, generate_kwargs) + _prepared_inputs, _prepared_generate_kwargs, chunked_prefill = ( + prepare_runtime_generate_inputs(runtime, inputs, generate_kwargs) + ) - assert prepared.chunked_prefill.runtime_eligible is True - assert prepared.chunked_prefill.applied is True + assert chunked_prefill.runtime_eligible is True + assert chunked_prefill.applied is True + assert chunked_prefill.strategy_id is ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT assert ( - prepared.chunked_prefill.activation_reason + chunked_prefill.activation_reason == "Bounded chunked prefill ran before final decode." ) gap_inventory = { decision.gap_id: decision.recommendation - for decision in prepared.chunked_prefill.gap_inventory - } - assert gap_inventory == { - ChunkedPrefillGapId.PROMPT_CONSTRUCTION_BEFORE_PREFILL: ( - ChunkedPrefillRecommendation.REJECT - ), - ChunkedPrefillGapId.NON_CAUSAL_RUNTIME_EXPANSION: ( - ChunkedPrefillRecommendation.REJECT - ), + for decision in chunked_prefill.gap_inventory } + assert gap_inventory[ChunkedPrefillGapId.PROMPT_TOKENIZATION_BEFORE_PREFILL] is ( + ChunkedPrefillRecommendation.DEFER + ) + assert ( + gap_inventory[ChunkedPrefillGapId.FULL_ATTENTION_MASK_BEFORE_PREFILL] + is ChunkedPrefillRecommendation.DEFER + ) + assert gap_inventory[ChunkedPrefillGapId.SEQ2SEQ_SOURCE_PREFILL] is ( + ChunkedPrefillRecommendation.DEFER + ) -def test_prepare_runtime_generate_inputs_rejects_seq2seq_scope_extension( +def test_prepare_runtime_generate_inputs_defers_seq2seq_source_prefill( monkeypatch, ) -> None: runtime = build_runtime_with_model( @@ -94,11 +105,42 @@ def test_prepare_runtime_generate_inputs_rejects_seq2seq_scope_extension( request, streamer=None, ) - prepared = prepare_runtime_generate_inputs(runtime, inputs, generate_kwargs) + _prepared_inputs, _prepared_generate_kwargs, chunked_prefill = ( + prepare_runtime_generate_inputs(runtime, inputs, generate_kwargs) + ) - assert prepared.chunked_prefill.runtime_eligible is False - assert prepared.chunked_prefill.applied is False + assert chunked_prefill.runtime_eligible is False + assert chunked_prefill.applied is False + assert chunked_prefill.strategy_id is None assert ( - prepared.chunked_prefill.activation_reason - == "Chunked prefill is limited to causal decoder-only text runtimes." + chunked_prefill.activation_reason + == "Seq2seq source prompts cannot use causal-cache chunked prefill." + ) + + +def test_t5_encoder_does_not_expose_cacheable_source_prefill() -> None: + model = T5ForConditionalGeneration( + T5Config( + vocab_size=64, + d_model=32, + d_kv=8, + d_ff=64, + num_layers=2, + num_decoder_layers=2, + num_heads=4, + pad_token_id=0, + eos_token_id=1, + decoder_start_token_id=0, + ) ) + model.eval() + input_ids = torch.tensor([[5, 6, 7]]) + attention_mask = torch.ones_like(input_ids) + + with torch.inference_mode(), pytest.raises(ValueError, match="used as a decoder"): + model.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=True, + return_dict=True, + ) diff --git a/tests/test_runtime_executor_prefill.py b/tests/test_runtime_executor_prefill.py index 7e30859..680db24 100644 --- a/tests/test_runtime_executor_prefill.py +++ b/tests/test_runtime_executor_prefill.py @@ -5,12 +5,17 @@ import torch from ollm.app.types import ContentPart, Message, MessageRole +from ollm.runtime.backends.base import BackendRuntime from ollm.runtime.capabilities import CapabilityProfile, SupportLevel from ollm.runtime.capability_discovery import GenericModelKind +from ollm.runtime.catalog import ModelModality +from ollm.runtime.chunked_prefill import ChunkedPrefillStrategyId from ollm.runtime.generation import RuntimeExecutor +from ollm.runtime.loaded_runtime import LoadedRuntime from tests.test_runtime_executor import ( FakeModel, build_request, + build_runtime, build_runtime_with_model, ) @@ -72,6 +77,102 @@ def forward( return types.SimpleNamespace(past_key_values=self.prefill_cache) +class LongProcessorInputs(dict): + def __init__(self, static_key: str): + super().__init__( + { + "input_ids": torch.tensor([[1, 2, 3, 4, 5]]), + "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]), + static_key: torch.tensor([[[0.25, 0.5]]], dtype=torch.float32), + } + ) + self.to_calls: list[tuple[torch.device, torch.dtype | None]] = [] + + def to(self, device, dtype=None): + self.to_calls.append((device, dtype)) + return self + + +class LongProcessor: + def __init__(self, static_key: str): + self.static_key = static_key + self.inputs = LongProcessorInputs(static_key) + + def apply_chat_template( + self, + messages, + add_generation_prompt, + tokenize, + return_dict, + return_tensors, + ): + del messages, add_generation_prompt, tokenize, return_dict, return_tensors + return self.inputs + + def batch_decode(self, outputs, skip_special_tokens=False): + del outputs, skip_special_tokens + return ["long-decoded"] + + +class ChunkedPrefillMultimodalModel(FakeModel): + def __init__(self) -> None: + super().__init__() + self.forward_calls: list[dict[str, object]] = [] + self.prefill_cache = object() + + def forward( + self, + input_ids, + attention_mask=None, + pixel_values=None, + input_features=None, + past_key_values=None, + use_cache=None, + cache_position=None, + ): + self.forward_calls.append( + { + "input_ids": input_ids.clone(), + "attention_mask": None + if attention_mask is None + else attention_mask.clone(), + "pixel_values": None if pixel_values is None else pixel_values.clone(), + "input_features": None + if input_features is None + else input_features.clone(), + "past_key_values": past_key_values, + "use_cache": use_cache, + "cache_position": None + if cache_position is None + else cache_position.clone(), + } + ) + return types.SimpleNamespace(past_key_values=self.prefill_cache) + + +def build_runtime_with_processor_model( + *, + capabilities: CapabilityProfile, + processor: LongProcessor, + model: FakeModel, +) -> LoadedRuntime: + runtime = build_runtime(capabilities) + runtime.backend = BackendRuntime( + backend_id=runtime.backend.backend_id, + model=model, + tokenizer=runtime.tokenizer, + processor=processor, + device=torch.device("cpu"), + stats=None, + print_suppression_modules=(), + create_cache=lambda cache_dir, cache_strategy=None, cache_lifecycle=None, cache_window_tokens=None: ( + None + ), + apply_offload=lambda runtime_config: None, + ) + return runtime + + def test_runtime_executor_prefills_long_causal_prompts_in_chunks(monkeypatch) -> None: model = ChunkedPrefillModel() runtime = build_runtime_with_model( @@ -93,6 +194,10 @@ def test_runtime_executor_prefills_long_causal_prompts_in_chunks(monkeypatch) -> response = RuntimeExecutor().execute(runtime, request) assert response.text == "long-decoded" + assert ( + response.metadata["chunked_prefill_strategy_id"] + == ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT.value + ) assert len(model.forward_calls) == 2 first_call = model.forward_calls[0] second_call = model.forward_calls[1] @@ -123,7 +228,141 @@ def test_runtime_executor_prefills_long_causal_prompts_in_chunks(monkeypatch) -> assert model.generate_kwargs["past_key_values"] is model.prefill_cache -def test_runtime_executor_skips_chunked_prefill_for_seq2seq_runtime( +def test_runtime_executor_prefills_long_generic_causal_prompts_in_chunks( + monkeypatch, +) -> None: + model = ChunkedPrefillModel() + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=model, + ) + runtime.plan = replace( + runtime.plan, + backend_id="transformers-generic", + generic_model_kind=GenericModelKind.CAUSAL_LM, + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + response = RuntimeExecutor().execute(runtime, request) + + assert response.text == "long-decoded" + assert ( + response.metadata["chunked_prefill_strategy_id"] + == ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_TEXT.value + ) + assert len(model.forward_calls) == 2 + generate_input_ids = model.generate_kwargs["input_ids"] + assert isinstance(generate_input_ids, torch.Tensor) + assert torch.equal(generate_input_ids, torch.tensor([[5]])) + + +def test_runtime_executor_prefills_long_native_multimodal_prompts_in_chunks( + monkeypatch, +) -> None: + processor = LongProcessor("pixel_values") + model = ChunkedPrefillMultimodalModel() + runtime = build_runtime_with_processor_model( + capabilities=CapabilityProfile( + support_level=SupportLevel.OPTIMIZED, + modalities=(ModelModality.TEXT, ModelModality.IMAGE), + requires_processor=True, + ), + processor=processor, + model=model, + ) + runtime.plan = replace( + runtime.plan, + backend_id="optimized-native", + generic_model_kind=GenericModelKind.IMAGE_TEXT_TO_TEXT, + ) + request = build_request( + runtime.config, + Message( + role=MessageRole.USER, + content=[ + ContentPart.text("long prompt"), + ContentPart.image("image.png"), + ], + ), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + response = RuntimeExecutor().execute(runtime, request) + + assert response.text == "long-decoded" + assert processor.inputs.to_calls == [(torch.device("cpu"), torch.bfloat16)] + assert ( + response.metadata["chunked_prefill_strategy_id"] + == ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL.value + ) + assert len(model.forward_calls) == 2 + first_call = model.forward_calls[0] + second_call = model.forward_calls[1] + first_pixel_values = cast(torch.Tensor, first_call["pixel_values"]) + second_pixel_values = cast(torch.Tensor, second_call["pixel_values"]) + assert torch.equal(first_pixel_values, torch.tensor([[[0.25, 0.5]]])) + assert torch.equal(second_pixel_values, torch.tensor([[[0.25, 0.5]]])) + generate_input_ids = model.generate_kwargs["input_ids"] + generate_pixel_values = model.generate_kwargs["pixel_values"] + assert isinstance(generate_input_ids, torch.Tensor) + assert isinstance(generate_pixel_values, torch.Tensor) + assert torch.equal(generate_input_ids, torch.tensor([[5]])) + assert torch.equal(generate_pixel_values, torch.tensor([[[0.25, 0.5]]])) + + +def test_runtime_executor_prefills_long_generic_multimodal_prompts_in_chunks( + monkeypatch, +) -> None: + processor = LongProcessor("pixel_values") + model = ChunkedPrefillMultimodalModel() + runtime = build_runtime_with_processor_model( + capabilities=CapabilityProfile( + support_level=SupportLevel.GENERIC, + modalities=(ModelModality.TEXT, ModelModality.IMAGE), + requires_processor=True, + ), + processor=processor, + model=model, + ) + runtime.plan = replace( + runtime.plan, + backend_id="transformers-generic", + generic_model_kind=GenericModelKind.IMAGE_TEXT_TO_TEXT, + ) + request = build_request( + runtime.config, + Message( + role=MessageRole.USER, + content=[ + ContentPart.text("long prompt"), + ContentPart.image("image.png"), + ], + ), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + response = RuntimeExecutor().execute(runtime, request) + + assert response.text == "long-decoded" + assert ( + response.metadata["chunked_prefill_strategy_id"] + == ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_MULTIMODAL.value + ) + assert len(model.forward_calls) == 2 + generate_input_ids = model.generate_kwargs["input_ids"] + generate_pixel_values = model.generate_kwargs["pixel_values"] + assert isinstance(generate_input_ids, torch.Tensor) + assert isinstance(generate_pixel_values, torch.Tensor) + assert torch.equal(generate_input_ids, torch.tensor([[5]])) + assert torch.equal(generate_pixel_values, torch.tensor([[[0.25, 0.5]]])) + + +def test_runtime_executor_defers_chunked_prefill_for_seq2seq_runtime( monkeypatch, ) -> None: model = ChunkedPrefillModel() @@ -143,9 +382,11 @@ def test_runtime_executor_skips_chunked_prefill_for_seq2seq_runtime( ) monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) - RuntimeExecutor().execute(runtime, request) + response = RuntimeExecutor().execute(runtime, request) assert model.forward_calls == [] + assert response.metadata["chunked_prefill_strategy_id"] == "" + assert response.metadata["chunked_prefill_applied"] == "false" generate_input_ids = model.generate_kwargs["input_ids"] assert isinstance(generate_input_ids, torch.Tensor) assert torch.equal(generate_input_ids, torch.tensor([[1, 2, 3, 4, 5]])) From 941f549812f4b48fa9e25db60eadf6e86232edf8 Mon Sep 17 00:00:00 2001 From: beardedeagle Date: Fri, 3 Apr 2026 08:12:35 -0500 Subject: [PATCH 3/9] refactor: split chunked prefill into explicit strategy handlers Replace the pure strategy-id resolution path with explicit per-lane chunked-prefill handlers for optimized-native text, optimized-native multimodal, transformers-generic text, and transformers-generic multimodal. The handlers still share the causal chunk loop, but each lane now has its own matcher and runner so future divergence can happen on a concrete strategy boundary instead of implicit branching. The full required gate was rerun from the final formatted tree before this commit: ruff format/check, standards checker, ty, compileall, pytest (432 passed), build, pip_audit, mkdocs --strict, and git diff --check. --- src/ollm/runtime/chunked_prefill.py | 165 ++++++++++++++++++++++++---- 1 file changed, 141 insertions(+), 24 deletions(-) diff --git a/src/ollm/runtime/chunked_prefill.py b/src/ollm/runtime/chunked_prefill.py index b939258..3acbf77 100644 --- a/src/ollm/runtime/chunked_prefill.py +++ b/src/ollm/runtime/chunked_prefill.py @@ -1,5 +1,6 @@ """Chunked-prefill strategy resolution and execution.""" +from collections.abc import Callable from dataclasses import asdict, dataclass, replace from enum import StrEnum from inspect import Parameter, signature @@ -89,6 +90,16 @@ class PreparedChunkedPrefill: scope: ChunkedPrefillScopeSurface +@dataclass(frozen=True, slots=True) +class ChunkedPrefillStrategy: + strategy_id: ChunkedPrefillStrategyId + matches: Callable[[LoadedRuntime, GenericModelKind | None], bool] + prepare: Callable[ + [LoadedRuntime, dict[str, object], dict[str, object], int], + tuple[dict[str, object], dict[str, object]], + ] + + _CHUNKED_PREFILL_GAP_INVENTORY = ( ChunkedPrefillGapDecision( gap_id=ChunkedPrefillGapId.PROMPT_TOKENIZATION_BEFORE_PREFILL, @@ -146,12 +157,12 @@ def prepare_chunked_prefill( prefill_token_count = input_ids_value.shape[1] - 1 if prefill_token_count <= chunk_tokens or not scope.runtime_eligible: return PreparedChunkedPrefill(inputs, generate_kwargs, scope) - prepared_inputs, prepared_generate_kwargs = _run_causal_chunked_prefill( - runtime=runtime, - inputs=inputs, - generate_kwargs=generate_kwargs, - chunk_tokens=chunk_tokens, - strategy_id=scope.strategy_id, + strategy = _require_strategy(scope.strategy_id) + prepared_inputs, prepared_generate_kwargs = strategy.prepare( + runtime, + inputs, + generate_kwargs, + chunk_tokens, ) return PreparedChunkedPrefill( inputs=prepared_inputs, @@ -191,8 +202,8 @@ def build_chunked_prefill_scope_surface( runtime_eligible=False, activation_reason="Seq2seq source prompts cannot use causal-cache chunked prefill.", ) - strategy_id = _resolve_strategy_id(runtime, runtime_kind) - if strategy_id is None: + strategy = _resolve_strategy(runtime, runtime_kind) + if strategy is None: return _scope( strategy_id=None, runtime_eligible=False, @@ -203,14 +214,14 @@ def build_chunked_prefill_scope_surface( prefill_token_count = int(input_ids.shape[1] - 1) if prefill_token_count <= chunk_tokens: return _scope( - strategy_id=strategy_id, + strategy_id=strategy.strategy_id, runtime_eligible=True, activation_reason=( "Prompt length does not exceed the chunked-prefill threshold." ), ) return _scope( - strategy_id=strategy_id, + strategy_id=strategy.strategy_id, runtime_eligible=True, activation_reason="Runtime is eligible for bounded chunked prefill.", ) @@ -220,25 +231,131 @@ def chunked_prefill_gap_inventory() -> tuple[ChunkedPrefillGapDecision, ...]: return _CHUNKED_PREFILL_GAP_INVENTORY -def _resolve_strategy_id( +def _resolve_strategy( runtime: LoadedRuntime, runtime_kind: GenericModelKind | None, -) -> ChunkedPrefillStrategyId | None: - if runtime.plan.backend_id == "optimized-native": - if runtime.processor is not None: - return ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL - if runtime_kind is GenericModelKind.CAUSAL_LM: - return ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT - return None - if runtime.plan.backend_id == "transformers-generic": - if runtime.processor is not None: - return ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_MULTIMODAL - if runtime_kind is GenericModelKind.CAUSAL_LM: - return ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_TEXT - return None +) -> ChunkedPrefillStrategy | None: + for strategy in _CHUNKED_PREFILL_STRATEGIES: + if strategy.matches(runtime, runtime_kind): + return strategy return None +def _require_strategy( + strategy_id: ChunkedPrefillStrategyId | None, +) -> ChunkedPrefillStrategy: + if strategy_id is None: + raise PromptExecutionError( + "Chunked prefill strategy resolution was required but no strategy was selected." + ) + for strategy in _CHUNKED_PREFILL_STRATEGIES: + if strategy.strategy_id is strategy_id: + return strategy + raise PromptExecutionError( + f"Unsupported chunked prefill strategy {strategy_id.value!r}." + ) + + +def _prepare_optimized_native_text_prefill( + runtime: LoadedRuntime, + inputs: dict[str, object], + generate_kwargs: dict[str, object], + chunk_tokens: int, +) -> tuple[dict[str, object], dict[str, object]]: + return _run_causal_chunked_prefill( + runtime=runtime, + inputs=inputs, + generate_kwargs=generate_kwargs, + chunk_tokens=chunk_tokens, + strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT, + ) + + +def _prepare_optimized_native_multimodal_prefill( + runtime: LoadedRuntime, + inputs: dict[str, object], + generate_kwargs: dict[str, object], + chunk_tokens: int, +) -> tuple[dict[str, object], dict[str, object]]: + return _run_causal_chunked_prefill( + runtime=runtime, + inputs=inputs, + generate_kwargs=generate_kwargs, + chunk_tokens=chunk_tokens, + strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL, + ) + + +def _prepare_transformers_generic_text_prefill( + runtime: LoadedRuntime, + inputs: dict[str, object], + generate_kwargs: dict[str, object], + chunk_tokens: int, +) -> tuple[dict[str, object], dict[str, object]]: + return _run_causal_chunked_prefill( + runtime=runtime, + inputs=inputs, + generate_kwargs=generate_kwargs, + chunk_tokens=chunk_tokens, + strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_TEXT, + ) + + +def _prepare_transformers_generic_multimodal_prefill( + runtime: LoadedRuntime, + inputs: dict[str, object], + generate_kwargs: dict[str, object], + chunk_tokens: int, +) -> tuple[dict[str, object], dict[str, object]]: + return _run_causal_chunked_prefill( + runtime=runtime, + inputs=inputs, + generate_kwargs=generate_kwargs, + chunk_tokens=chunk_tokens, + strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_MULTIMODAL, + ) + + +_CHUNKED_PREFILL_STRATEGIES = ( + ChunkedPrefillStrategy( + strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT, + matches=lambda runtime, runtime_kind: ( + runtime.plan.backend_id == "optimized-native" + and runtime.processor is None + and runtime_kind is GenericModelKind.CAUSAL_LM + ), + prepare=_prepare_optimized_native_text_prefill, + ), + ChunkedPrefillStrategy( + strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL, + matches=lambda runtime, runtime_kind: ( + runtime.plan.backend_id == "optimized-native" + and runtime.processor is not None + and runtime_kind is not GenericModelKind.SEQ2SEQ_LM + ), + prepare=_prepare_optimized_native_multimodal_prefill, + ), + ChunkedPrefillStrategy( + strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_TEXT, + matches=lambda runtime, runtime_kind: ( + runtime.plan.backend_id == "transformers-generic" + and runtime.processor is None + and runtime_kind is GenericModelKind.CAUSAL_LM + ), + prepare=_prepare_transformers_generic_text_prefill, + ), + ChunkedPrefillStrategy( + strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_MULTIMODAL, + matches=lambda runtime, runtime_kind: ( + runtime.plan.backend_id == "transformers-generic" + and runtime.processor is not None + and runtime_kind is not GenericModelKind.SEQ2SEQ_LM + ), + prepare=_prepare_transformers_generic_multimodal_prefill, + ), +) + + def _run_causal_chunked_prefill( *, runtime: LoadedRuntime, From debb152d9f3f05f360bdbbb89e83af71abc43f35 Mon Sep 17 00:00:00 2001 From: beardedeagle Date: Fri, 3 Apr 2026 08:58:45 -0500 Subject: [PATCH 4/9] feat: stream prompt ingestion across all chunked strategies Complete the remaining ollm-6b7 work by moving chunked strategy preparation ahead of eager model-input construction. The runtime now streams prompt tokenization from rendered prompt pieces, synthesizes prefix attention masks lazily for causal chunked lanes, and adds a dedicated transformers-generic seq2seq source-ingestion strategy instead of leaving seq2seq outside the feature. This update also splits chunked prompt support into a smaller support module to keep the repo standards checker green, refreshes benchmark/request metadata and docs to the final contract, and closes the absorbed follow-on beads ollm-qm9 and ollm-dnl. Verification rerun from the final formatted tree: ruff format/check, standards checker, ty, compileall, pytest (432 passed), build, pip_audit, mkdocs --strict, and git diff --check. --- .beads/interactions.jsonl | 2 + README.md | 10 +- docs/benchmarking.md | 2 +- docs/guides/optimization.md | 19 +- src/ollm/runtime/chunked_prefill.py | 514 +++++++++----------- src/ollm/runtime/chunked_prefill_support.py | 265 ++++++++++ src/ollm/runtime/execution_trace.py | 16 +- src/ollm/runtime/generation.py | 28 +- tests/benchmark_support.py | 6 +- tests/test_benchmark_probe_execution.py | 23 +- tests/test_benchmark_reporting.py | 4 +- tests/test_chunked_prefill_scope.py | 32 +- tests/test_runtime_executor_prefill.py | 72 +-- 13 files changed, 614 insertions(+), 379 deletions(-) create mode 100644 src/ollm/runtime/chunked_prefill_support.py diff --git a/.beads/interactions.jsonl b/.beads/interactions.jsonl index 5c40103..ab4fca6 100644 --- a/.beads/interactions.jsonl +++ b/.beads/interactions.jsonl @@ -12,3 +12,5 @@ {"id":"int-32d6d982","kind":"field_change","created_at":"2026-04-03T11:31:04.109777Z","actor":"beardedeagle","issue_id":"ollm-6b7","extra":{"field":"status","new_value":"open","old_value":"closed"}} {"id":"int-9e450675","kind":"field_change","created_at":"2026-04-03T11:33:42.488488Z","actor":"beardedeagle","issue_id":"ollm-6b7","extra":{"field":"status","new_value":"in_progress","old_value":"open"}} {"id":"int-5cd99253","kind":"field_change","created_at":"2026-04-03T11:52:07.85795Z","actor":"beardedeagle","issue_id":"ollm-6b7","extra":{"field":"status","new_value":"closed","old_value":"in_progress","reason":"Completed"}} +{"id":"int-25387d16","kind":"field_change","created_at":"2026-04-03T13:58:30.61251Z","actor":"beardedeagle","issue_id":"ollm-qm9","extra":{"field":"status","new_value":"closed","old_value":"open","reason":"Completed"}} +{"id":"int-f18a669e","kind":"field_change","created_at":"2026-04-03T13:58:30.652694Z","actor":"beardedeagle","issue_id":"ollm-dnl","extra":{"field":"status","new_value":"closed","old_value":"open","reason":"Completed"}} diff --git a/README.md b/README.md index b2d660e..e2898b2 100644 --- a/README.md +++ b/README.md @@ -168,14 +168,16 @@ current strategy lanes are: - `optimized-native-multimodal` - `transformers-generic-text` - `transformers-generic-multimodal` +- `transformers-generic-seq2seq-source` That keeps prompt execution from growing one full prompt-wide activation step at a time on very long inputs while preserving the external prompt/chat contract. Prompt-scaling benchmarks remain the right place to evaluate the -TTFT and memory tradeoff on target hardware. Prompt tokenization and full-prefix -attention-mask construction still happen before chunking starts, and seq2seq -source prompts remain a separate deferred lane rather than silently pretending -to use causal-cache prefill. +TTFT and memory tradeoff on target hardware. Prompt tokenization now streams +from rendered prompt pieces inside the strategy path, prefix attention masks +are synthesized lazily per chunk, and seq2seq source prompts use the dedicated +`transformers-generic-seq2seq-source` lane instead of pretending they share the +causal-cache contract. Configuration layering uses an explicit precedence contract: diff --git a/docs/benchmarking.md b/docs/benchmarking.md index bba540f..288604d 100644 --- a/docs/benchmarking.md +++ b/docs/benchmarking.md @@ -180,7 +180,7 @@ Interpretation notes: - peak RSS includes a source label; long-lived warm/scaling/session probes use stage-local sampled peaks instead of process-lifetime peaks - allocator-gap metrics are reported as reserved-minus-allocated style slack when the backend exposes the required counters; unsupported backends serialize them as `null` - text prompt-scaling runs exercise bounded chunked prefill on the supported text strategy lanes, so the prompt-length sweep is the intended place to inspect the memory versus TTFT tradeoff for this feature -- request metrics also include a `chunked_prefill` section that states the selected strategy ID, whether the active runtime was eligible, whether chunking actually ran, and the remaining deferred gaps for prompt tokenization, full-prefix attention-mask construction, and seq2seq source ingestion +- request metrics also include a `chunked_prefill` section that states the selected strategy ID, whether the active runtime was eligible, whether the strategy actually ran, and the implemented execution boundary for streamed prompt tokenization plus lazy prefix-mask synthesis On loader-streamed families such as optimized Gemma3 on CPU, a long per-turn session-growth response can become dominated by repeated safetensor layer reads diff --git a/docs/guides/optimization.md b/docs/guides/optimization.md index 2709cce..1bcf79f 100644 --- a/docs/guides/optimization.md +++ b/docs/guides/optimization.md @@ -18,15 +18,16 @@ Native families: Supported causal runtime lanes use bounded chunked prefill for long prompt ingestion before the final decode step. The current lanes are `optimized-native-text`, `optimized-native-multimodal`, -`transformers-generic-text`, and `transformers-generic-multimodal`. This is a -memory-control path, not a blanket latency optimization, so prompt-scaling -benchmarks are the truthful way to evaluate whether the chunking tradeoff helps -on a given host. - -Prompt tokenization and full-prefix attention-mask construction still complete -before chunking starts, and seq2seq source prompts remain a separate deferred -lane because encoder-decoder source ingestion is not the same causal-cache -operation. +`transformers-generic-text`, and `transformers-generic-multimodal`. +Encoder-decoder source prompts use the dedicated +`transformers-generic-seq2seq-source` lane. +This is a memory-control path, not a blanket latency optimization, so +prompt-scaling benchmarks are the truthful way to evaluate whether the chunking +tradeoff helps on a given host. + +Prompt tokenization now streams from rendered prompt pieces inside these +strategy handlers, and causal lanes synthesize prefix attention masks lazily +per chunk instead of materializing the full mask before ingestion starts. ### Transformers-generic Used for compatible local or materialized models that can run through the generic Transformers-backed path. diff --git a/src/ollm/runtime/chunked_prefill.py b/src/ollm/runtime/chunked_prefill.py index 3acbf77..ce9a983 100644 --- a/src/ollm/runtime/chunked_prefill.py +++ b/src/ollm/runtime/chunked_prefill.py @@ -1,18 +1,26 @@ -"""Chunked-prefill strategy resolution and execution.""" +"""Chunked prompt-ingestion strategies for runtime generation.""" from collections.abc import Callable from dataclasses import asdict, dataclass, replace from enum import StrEnum -from inspect import Parameter, signature from typing import Self -import torch - +from ollm.app.types import Message from ollm.runtime.capability_discovery import GenericModelKind +from ollm.runtime.chunked_prefill_support import ( + ones_attention_mask, + prepare_static_inputs, + prompt_token_id_pieces, + render_prompt_text, + resolve_stream_tokenizer, + run_causal_prefill_chunk, + token_tensor, +) +from ollm.runtime.chunked_prefill_support import ( + prompt_token_count as count_prompt_tokens, +) from ollm.runtime.errors import PromptExecutionError -from ollm.runtime.generation_support import require_tensor from ollm.runtime.loaded_runtime import LoadedRuntime -from ollm.runtime.output_control import suppress_module_prints class ChunkedPrefillStrategyId(StrEnum): @@ -20,6 +28,7 @@ class ChunkedPrefillStrategyId(StrEnum): OPTIMIZED_NATIVE_MULTIMODAL = "optimized-native-multimodal" TRANSFORMERS_GENERIC_TEXT = "transformers-generic-text" TRANSFORMERS_GENERIC_MULTIMODAL = "transformers-generic-multimodal" + TRANSFORMERS_GENERIC_SEQ2SEQ_SOURCE = "transformers-generic-seq2seq-source" class ChunkedPrefillGapId(StrEnum): @@ -35,11 +44,11 @@ class ChunkedPrefillRecommendation(StrEnum): class ChunkedPrefillExecutionBoundary(StrEnum): - POST_TOKENIZATION = "post-tokenization" + STREAMED_PROMPT_PREPARATION = "streamed-prompt-preparation" class ChunkedPrefillAttentionMaskMode(StrEnum): - FULL_PREFIX_MATERIALIZED = "full-prefix-materialized" + LAZY_PREFIX_SYNTHESIS = "lazy-prefix-synthesis" @dataclass(frozen=True, slots=True) @@ -88,6 +97,7 @@ class PreparedChunkedPrefill: inputs: dict[str, object] generate_kwargs: dict[str, object] scope: ChunkedPrefillScopeSurface + prompt_token_count: int @dataclass(frozen=True, slots=True) @@ -95,43 +105,48 @@ class ChunkedPrefillStrategy: strategy_id: ChunkedPrefillStrategyId matches: Callable[[LoadedRuntime, GenericModelKind | None], bool] prepare: Callable[ - [LoadedRuntime, dict[str, object], dict[str, object], int], - tuple[dict[str, object], dict[str, object]], + [LoadedRuntime, list[Message], dict[str, object], int], + PreparedChunkedPrefill, ] _CHUNKED_PREFILL_GAP_INVENTORY = ( ChunkedPrefillGapDecision( gap_id=ChunkedPrefillGapId.PROMPT_TOKENIZATION_BEFORE_PREFILL, - current_behavior="Prompt tokenization completes before chunked prefill begins.", - recommendation=ChunkedPrefillRecommendation.DEFER, + current_behavior=( + "Supported strategies render the prompt template once, then tokenize " + "prompt pieces incrementally during strategy execution." + ), + recommendation=ChunkedPrefillRecommendation.IMPLEMENT, rationale=( - "Streaming prompt tokenization needs tokenizer-specific boundary " - "preservation instead of the current whole-prompt tokenization path." + "Prompt tokenization no longer has to complete as one full prompt-wide " + "step before the chunked strategy begins." ), ), ChunkedPrefillGapDecision( gap_id=ChunkedPrefillGapId.FULL_ATTENTION_MASK_BEFORE_PREFILL, current_behavior=( - "A full prefix attention mask is materialized before chunked prefill " - "hands off to the final generate step." + "Causal chunked strategies synthesize prefix attention masks per " + "chunk and materialize the full mask only at the final generate " + "handoff when the runtime still requires it." ), - recommendation=ChunkedPrefillRecommendation.DEFER, + recommendation=ChunkedPrefillRecommendation.IMPLEMENT, rationale=( - "The current generation handoff still relies on full-prefix masks. " - "A lazy mask contract needs backend proof before it can replace that " - "shape safely." + "Full prompt attention masks are no longer built before chunked " + "prompt ingestion begins." ), ), ChunkedPrefillGapDecision( gap_id=ChunkedPrefillGapId.SEQ2SEQ_SOURCE_PREFILL, current_behavior=( - "Seq2seq source prompts do not use causal-cache chunked prefill." + "Seq2seq source prompts now use a dedicated streamed source-ingestion " + "strategy instead of pretending they share the causal-cache prefill " + "contract." ), - recommendation=ChunkedPrefillRecommendation.DEFER, + recommendation=ChunkedPrefillRecommendation.IMPLEMENT, rationale=( - "Encoder-decoder source ingestion has no equivalent causal-cache " - "prefill contract; it needs a separate encoder strategy." + "Seq2seq now has its own explicit strategy lane rather than being left " + "unsupported." ), ), ) @@ -140,179 +155,178 @@ class ChunkedPrefillStrategy: def prepare_chunked_prefill( *, runtime: LoadedRuntime, - inputs: dict[str, object], + messages: list[Message], generate_kwargs: dict[str, object], chunk_tokens: int, + eager_input_builder: Callable[[LoadedRuntime, list[Message]], dict[str, object]], ) -> PreparedChunkedPrefill: - scope = build_chunked_prefill_scope_surface( - runtime=runtime, - inputs=inputs, - chunk_tokens=chunk_tokens, - ) - input_ids_value = inputs.get("input_ids") - if not isinstance(input_ids_value, torch.Tensor): - return PreparedChunkedPrefill(inputs, generate_kwargs, scope) - if input_ids_value.ndim != 2 or input_ids_value.shape[0] != 1: - return PreparedChunkedPrefill(inputs, generate_kwargs, scope) - prefill_token_count = input_ids_value.shape[1] - 1 - if prefill_token_count <= chunk_tokens or not scope.runtime_eligible: - return PreparedChunkedPrefill(inputs, generate_kwargs, scope) - strategy = _require_strategy(scope.strategy_id) - prepared_inputs, prepared_generate_kwargs = strategy.prepare( - runtime, - inputs, - generate_kwargs, - chunk_tokens, - ) - return PreparedChunkedPrefill( - inputs=prepared_inputs, - generate_kwargs=prepared_generate_kwargs, - scope=scope.with_activation( - applied=True, - activation_reason="Bounded chunked prefill ran before final decode.", - ), - ) - - -def build_chunked_prefill_scope_surface( - *, - runtime: LoadedRuntime, - inputs: dict[str, object], - chunk_tokens: int, -) -> ChunkedPrefillScopeSurface: - input_ids = inputs.get("input_ids") - if not isinstance(input_ids, torch.Tensor): - return _scope( - strategy_id=None, - runtime_eligible=False, - activation_reason="Chunked prefill requires tensor-backed input_ids.", - ) - if input_ids.ndim != 2 or input_ids.shape[0] != 1: - return _scope( - strategy_id=None, - runtime_eligible=False, - activation_reason="Chunked prefill requires a single batch row.", - ) runtime_kind = ( runtime.plan.generic_model_kind or runtime.resolved_model.generic_model_kind ) - if runtime_kind is GenericModelKind.SEQ2SEQ_LM: - return _scope( - strategy_id=None, - runtime_eligible=False, - activation_reason="Seq2seq source prompts cannot use causal-cache chunked prefill.", - ) strategy = _resolve_strategy(runtime, runtime_kind) if strategy is None: - return _scope( - strategy_id=None, - runtime_eligible=False, - activation_reason=( - "Chunked prefill requires a supported causal runtime strategy." + inputs = eager_input_builder(runtime, messages) + return PreparedChunkedPrefill( + inputs=inputs, + generate_kwargs=generate_kwargs, + scope=_scope( + strategy_id=None, + runtime_eligible=False, + activation_reason=( + "No chunked prompt-ingestion strategy matched the active runtime." + ), ), + prompt_token_count=count_prompt_tokens(inputs), ) - prefill_token_count = int(input_ids.shape[1] - 1) - if prefill_token_count <= chunk_tokens: - return _scope( - strategy_id=strategy.strategy_id, - runtime_eligible=True, - activation_reason=( - "Prompt length does not exceed the chunked-prefill threshold." - ), - ) - return _scope( - strategy_id=strategy.strategy_id, - runtime_eligible=True, - activation_reason="Runtime is eligible for bounded chunked prefill.", - ) + return strategy.prepare(runtime, messages, generate_kwargs, chunk_tokens) def chunked_prefill_gap_inventory() -> tuple[ChunkedPrefillGapDecision, ...]: return _CHUNKED_PREFILL_GAP_INVENTORY -def _resolve_strategy( - runtime: LoadedRuntime, - runtime_kind: GenericModelKind | None, -) -> ChunkedPrefillStrategy | None: - for strategy in _CHUNKED_PREFILL_STRATEGIES: - if strategy.matches(runtime, runtime_kind): - return strategy - return None - - -def _require_strategy( - strategy_id: ChunkedPrefillStrategyId | None, -) -> ChunkedPrefillStrategy: - if strategy_id is None: - raise PromptExecutionError( - "Chunked prefill strategy resolution was required but no strategy was selected." - ) - for strategy in _CHUNKED_PREFILL_STRATEGIES: - if strategy.strategy_id is strategy_id: - return strategy - raise PromptExecutionError( - f"Unsupported chunked prefill strategy {strategy_id.value!r}." - ) - - -def _prepare_optimized_native_text_prefill( +def _prepare_streamed_causal_strategy( runtime: LoadedRuntime, - inputs: dict[str, object], + messages: list[Message], generate_kwargs: dict[str, object], chunk_tokens: int, -) -> tuple[dict[str, object], dict[str, object]]: - return _run_causal_chunked_prefill( - runtime=runtime, - inputs=inputs, - generate_kwargs=generate_kwargs, - chunk_tokens=chunk_tokens, - strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT, - ) + *, + strategy_id: ChunkedPrefillStrategyId, +) -> PreparedChunkedPrefill: + rendered_prompt = render_prompt_text(runtime, messages) + static_inputs = prepare_static_inputs(runtime, messages) + prompt_tokens: list[int] = [] + deferred_tokens: list[int] = [] + prefed_token_count = 0 + prefill_cache = generate_kwargs.get("past_key_values") + forward_method = getattr(runtime.model, "forward", None) + for token_piece in prompt_token_id_pieces( + resolve_stream_tokenizer(runtime), + rendered_prompt, + ): + prompt_tokens.extend(token_piece) + deferred_tokens.extend(token_piece) + while len(deferred_tokens) > chunk_tokens + 1: + if not callable(forward_method): + raise PromptExecutionError( + f"Chunked prompt-ingestion strategy {strategy_id.value!r} requires a callable forward method." + ) + prefill_cache = run_causal_prefill_chunk( + runtime=runtime, + forward_method=forward_method, + static_inputs=static_inputs, + chunk_ids=deferred_tokens[:chunk_tokens], + prefill_cache=prefill_cache, + prefix_token_count=prefed_token_count, + strategy_label=strategy_id.value, + ) + del deferred_tokens[:chunk_tokens] + prefed_token_count += chunk_tokens + + if len(prompt_tokens) - 1 > chunk_tokens: + while len(deferred_tokens) > 1: + if not callable(forward_method): + raise PromptExecutionError( + f"Chunked prompt-ingestion strategy {strategy_id.value!r} requires a callable forward method." + ) + chunk_size = min(chunk_tokens, len(deferred_tokens) - 1) + prefill_cache = run_causal_prefill_chunk( + runtime=runtime, + forward_method=forward_method, + static_inputs=static_inputs, + chunk_ids=deferred_tokens[:chunk_size], + prefill_cache=prefill_cache, + prefix_token_count=prefed_token_count, + strategy_label=strategy_id.value, + ) + del deferred_tokens[:chunk_size] + prefed_token_count += chunk_size + + if not prompt_tokens: + raise PromptExecutionError( + "Chunked prompt ingestion produced no prompt tokens." + ) -def _prepare_optimized_native_multimodal_prefill( - runtime: LoadedRuntime, - inputs: dict[str, object], - generate_kwargs: dict[str, object], - chunk_tokens: int, -) -> tuple[dict[str, object], dict[str, object]]: - return _run_causal_chunked_prefill( - runtime=runtime, - inputs=inputs, - generate_kwargs=generate_kwargs, - chunk_tokens=chunk_tokens, - strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL, + final_inputs = dict(static_inputs) + final_generate_kwargs = dict(generate_kwargs) + if prefed_token_count > 0: + final_inputs["input_ids"] = token_tensor(deferred_tokens, device=runtime.device) + final_inputs["attention_mask"] = ones_attention_mask( + token_count=len(prompt_tokens), + device=runtime.device, + ) + final_generate_kwargs["past_key_values"] = prefill_cache + scope = _scope( + strategy_id=strategy_id, + runtime_eligible=True, + activation_reason="Bounded chunked prefill ran before final decode.", + ).with_activation( + applied=True, + activation_reason="Bounded chunked prefill ran before final decode.", + ) + else: + final_inputs["input_ids"] = token_tensor(prompt_tokens, device=runtime.device) + final_inputs["attention_mask"] = ones_attention_mask( + token_count=len(prompt_tokens), + device=runtime.device, + ) + scope = _scope( + strategy_id=strategy_id, + runtime_eligible=True, + activation_reason=( + "Prompt length does not exceed the chunked-prefill threshold." + ), + ) + return PreparedChunkedPrefill( + inputs=final_inputs, + generate_kwargs=final_generate_kwargs, + scope=scope, + prompt_token_count=len(prompt_tokens), ) -def _prepare_transformers_generic_text_prefill( +def _prepare_seq2seq_source_strategy( runtime: LoadedRuntime, - inputs: dict[str, object], + messages: list[Message], generate_kwargs: dict[str, object], chunk_tokens: int, -) -> tuple[dict[str, object], dict[str, object]]: - return _run_causal_chunked_prefill( - runtime=runtime, - inputs=inputs, - generate_kwargs=generate_kwargs, - chunk_tokens=chunk_tokens, - strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_TEXT, +) -> PreparedChunkedPrefill: + rendered_prompt = render_prompt_text(runtime, messages) + prompt_tokens = [ + token_id + for token_piece in prompt_token_id_pieces( + resolve_stream_tokenizer(runtime), + rendered_prompt, + ) + for token_id in token_piece + ] + if not prompt_tokens: + raise PromptExecutionError( + "Seq2seq source ingestion produced no prompt tokens." + ) + strategy_id = ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_SEQ2SEQ_SOURCE + applied = len(prompt_tokens) > chunk_tokens + activation_reason = ( + "Streamed seq2seq source tokens were built incrementally before encoder generation." + if applied + else "Prompt length does not exceed the streamed seq2seq source threshold." ) - - -def _prepare_transformers_generic_multimodal_prefill( - runtime: LoadedRuntime, - inputs: dict[str, object], - generate_kwargs: dict[str, object], - chunk_tokens: int, -) -> tuple[dict[str, object], dict[str, object]]: - return _run_causal_chunked_prefill( - runtime=runtime, - inputs=inputs, + return PreparedChunkedPrefill( + inputs={ + "input_ids": token_tensor(prompt_tokens, device=runtime.device), + "attention_mask": ones_attention_mask( + token_count=len(prompt_tokens), + device=runtime.device, + ), + }, generate_kwargs=generate_kwargs, - chunk_tokens=chunk_tokens, - strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_MULTIMODAL, + scope=_scope( + strategy_id=strategy_id, + runtime_eligible=True, + activation_reason=activation_reason, + ).with_activation(applied=applied, activation_reason=activation_reason), + prompt_token_count=len(prompt_tokens), ) @@ -324,7 +338,15 @@ def _prepare_transformers_generic_multimodal_prefill( and runtime.processor is None and runtime_kind is GenericModelKind.CAUSAL_LM ), - prepare=_prepare_optimized_native_text_prefill, + prepare=lambda runtime, messages, generate_kwargs, chunk_tokens: ( + _prepare_streamed_causal_strategy( + runtime, + messages, + generate_kwargs, + chunk_tokens, + strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT, + ) + ), ), ChunkedPrefillStrategy( strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL, @@ -333,7 +355,15 @@ def _prepare_transformers_generic_multimodal_prefill( and runtime.processor is not None and runtime_kind is not GenericModelKind.SEQ2SEQ_LM ), - prepare=_prepare_optimized_native_multimodal_prefill, + prepare=lambda runtime, messages, generate_kwargs, chunk_tokens: ( + _prepare_streamed_causal_strategy( + runtime, + messages, + generate_kwargs, + chunk_tokens, + strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL, + ) + ), ), ChunkedPrefillStrategy( strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_TEXT, @@ -342,7 +372,15 @@ def _prepare_transformers_generic_multimodal_prefill( and runtime.processor is None and runtime_kind is GenericModelKind.CAUSAL_LM ), - prepare=_prepare_transformers_generic_text_prefill, + prepare=lambda runtime, messages, generate_kwargs, chunk_tokens: ( + _prepare_streamed_causal_strategy( + runtime, + messages, + generate_kwargs, + chunk_tokens, + strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_TEXT, + ) + ), ), ChunkedPrefillStrategy( strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_MULTIMODAL, @@ -351,117 +389,35 @@ def _prepare_transformers_generic_multimodal_prefill( and runtime.processor is not None and runtime_kind is not GenericModelKind.SEQ2SEQ_LM ), - prepare=_prepare_transformers_generic_multimodal_prefill, + prepare=lambda runtime, messages, generate_kwargs, chunk_tokens: ( + _prepare_streamed_causal_strategy( + runtime, + messages, + generate_kwargs, + chunk_tokens, + strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_MULTIMODAL, + ) + ), + ), + ChunkedPrefillStrategy( + strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_SEQ2SEQ_SOURCE, + matches=lambda runtime, runtime_kind: ( + runtime.plan.backend_id == "transformers-generic" + and runtime_kind is GenericModelKind.SEQ2SEQ_LM + ), + prepare=_prepare_seq2seq_source_strategy, ), ) -def _run_causal_chunked_prefill( - *, +def _resolve_strategy( runtime: LoadedRuntime, - inputs: dict[str, object], - generate_kwargs: dict[str, object], - chunk_tokens: int, - strategy_id: ChunkedPrefillStrategyId | None, -) -> tuple[dict[str, object], dict[str, object]]: - forward_method = getattr(runtime.model, "forward", None) - if not callable(forward_method): - return inputs, generate_kwargs - input_ids = require_tensor(inputs["input_ids"]) - attention_mask = _optional_tensor(inputs.get("attention_mask")) - sequence_inputs = _collect_sequence_inputs(inputs, input_ids) - static_inputs = _collect_static_inputs(inputs) - prefill_cache = generate_kwargs.get("past_key_values") - prefill_end = input_ids.shape[1] - 1 - with torch.inference_mode(): - with suppress_module_prints(runtime.backend.print_suppression_modules): - for chunk_start in range(0, prefill_end, chunk_tokens): - chunk_end = min(prefill_end, chunk_start + chunk_tokens) - forward_inputs: dict[str, object] = dict(static_inputs) - forward_inputs["input_ids"] = input_ids[:, chunk_start:chunk_end] - if attention_mask is not None: - forward_inputs["attention_mask"] = attention_mask[:, :chunk_end] - for key, value in sequence_inputs.items(): - forward_inputs[key] = value[:, chunk_start:chunk_end] - if prefill_cache is not None: - forward_inputs["past_key_values"] = prefill_cache - forward_inputs["use_cache"] = True - forward_inputs["cache_position"] = torch.arange( - chunk_start, - chunk_end, - device=input_ids.device, - dtype=torch.long, - ) - filtered_inputs = _filter_supported_forward_inputs( - forward_method, - forward_inputs, - ) - outputs = forward_method(**filtered_inputs) - prefill_cache = getattr(outputs, "past_key_values", None) - if prefill_cache is None: - strategy_label = ( - "unknown" if strategy_id is None else strategy_id.value - ) - raise PromptExecutionError( - "Chunked prefill strategy " - f"{strategy_label!r} requires a runtime that returns " - "past_key_values." - ) - updated_inputs = dict(static_inputs) - updated_inputs["input_ids"] = input_ids[:, -1:] - if attention_mask is not None: - updated_inputs["attention_mask"] = attention_mask - for key, value in sequence_inputs.items(): - updated_inputs[key] = value[:, -1:] - updated_generate_kwargs = dict(generate_kwargs) - updated_generate_kwargs["past_key_values"] = prefill_cache - return updated_inputs, updated_generate_kwargs - - -def _collect_sequence_inputs( - inputs: dict[str, object], - input_ids: torch.Tensor, -) -> dict[str, torch.Tensor]: - sequence_inputs: dict[str, torch.Tensor] = {} - sequence_length = input_ids.shape[1] - for key, value in inputs.items(): - if key in {"input_ids", "attention_mask"}: - continue - if ( - isinstance(value, torch.Tensor) - and value.ndim == 2 - and value.shape[1] == sequence_length - ): - sequence_inputs[key] = value - return sequence_inputs - - -def _collect_static_inputs(inputs: dict[str, object]) -> dict[str, object]: - return { - key: value - for key, value in inputs.items() - if key not in {"input_ids", "attention_mask"} - } - - -def _filter_supported_forward_inputs( - forward_method, - inputs: dict[str, object], -) -> dict[str, object]: - method_signature = signature(forward_method) - if any( - parameter.kind is Parameter.VAR_KEYWORD - for parameter in method_signature.parameters.values() - ): - return inputs - supported_keys = set(method_signature.parameters) - return {key: value for key, value in inputs.items() if key in supported_keys} - - -def _optional_tensor(value: object) -> torch.Tensor | None: - if value is None: - return None - return require_tensor(value) + runtime_kind: GenericModelKind | None, +) -> ChunkedPrefillStrategy | None: + for strategy in _CHUNKED_PREFILL_STRATEGIES: + if strategy.matches(runtime, runtime_kind): + return strategy + return None def _scope( @@ -475,7 +431,7 @@ def _scope( runtime_eligible=runtime_eligible, applied=False, activation_reason=activation_reason, - execution_boundary=ChunkedPrefillExecutionBoundary.POST_TOKENIZATION, - attention_mask_mode=ChunkedPrefillAttentionMaskMode.FULL_PREFIX_MATERIALIZED, + execution_boundary=ChunkedPrefillExecutionBoundary.STREAMED_PROMPT_PREPARATION, + attention_mask_mode=ChunkedPrefillAttentionMaskMode.LAZY_PREFIX_SYNTHESIS, gap_inventory=_CHUNKED_PREFILL_GAP_INVENTORY, ) diff --git a/src/ollm/runtime/chunked_prefill_support.py b/src/ollm/runtime/chunked_prefill_support.py new file mode 100644 index 0000000..027c28a --- /dev/null +++ b/src/ollm/runtime/chunked_prefill_support.py @@ -0,0 +1,265 @@ +"""Support helpers for chunked prompt-ingestion strategies.""" + +import re +from collections.abc import Iterable +from inspect import Parameter, signature + +import torch + +from ollm.app.types import ContentKind, Message +from ollm.runtime.errors import PromptExecutionError +from ollm.runtime.generation_support import render_plain_prompt +from ollm.runtime.loaded_runtime import LoadedRuntime +from ollm.runtime.output_control import suppress_module_prints + + +def render_prompt_text(runtime: LoadedRuntime, messages: list[Message]) -> str: + transformers_messages = [ + message.as_transformers_message( + structured_content=runtime.processor is not None + ) + for message in messages + ] + if runtime.processor is not None and hasattr( + runtime.processor, "apply_chat_template" + ): + rendered = runtime.processor.apply_chat_template( + transformers_messages, + add_generation_prompt=True, + tokenize=False, + return_dict=False, + return_tensors=None, + ) + if isinstance(rendered, str): + return rendered + if hasattr(runtime.tokenizer, "apply_chat_template"): + try: + rendered = runtime.tokenizer.apply_chat_template( + transformers_messages, + tokenize=False, + add_generation_prompt=True, + return_dict=False, + ) + if isinstance(rendered, str): + return rendered + except (TypeError, ValueError, AttributeError): + pass + return render_plain_prompt(messages) + + +def resolve_stream_tokenizer(runtime: LoadedRuntime): + if ( + runtime.processor is not None + and getattr(runtime.processor, "tokenizer", None) is not None + ): + return runtime.processor.tokenizer + return runtime.tokenizer + + +def prepare_static_inputs( + runtime: LoadedRuntime, + messages: list[Message], +) -> dict[str, object]: + if runtime.processor is None: + return {} + custom_builder = getattr( + runtime.processor, "prepare_chunked_prefill_static_inputs", None + ) + if callable(custom_builder): + prepared = custom_builder(messages, runtime.device) + return dict(prepared) + image_values = [ + part.value + for message in messages + for part in message.content + if part.kind is ContentKind.IMAGE + ] + audio_values = [ + part.value + for message in messages + for part in message.content + if part.kind is ContentKind.AUDIO + ] + if not image_values and not audio_values: + return {} + prepared = call_processor_for_static_inputs( + processor=runtime.processor, + image_values=image_values, + audio_values=audio_values, + device=runtime.device, + ) + prepared.pop("input_ids", None) + prepared.pop("attention_mask", None) + prepared.pop("token_type_ids", None) + return prepared + + +def call_processor_for_static_inputs( + *, + processor, + image_values: list[str], + audio_values: list[str], + device: torch.device, +) -> dict[str, object]: + processor_signature = signature(processor.__call__) + accepts_kwargs = any( + parameter.kind is Parameter.VAR_KEYWORD + for parameter in processor_signature.parameters.values() + ) + kwargs: dict[str, object] = {"return_tensors": "pt"} + if image_values and ("images" in processor_signature.parameters or accepts_kwargs): + kwargs["images"] = image_values + if audio_values: + if "audios" in processor_signature.parameters or accepts_kwargs: + kwargs["audios"] = audio_values + elif "audio" in processor_signature.parameters: + kwargs["audio"] = audio_values + if "text" in processor_signature.parameters or accepts_kwargs: + kwargs["text"] = [""] + prepared = processor(**kwargs) + return move_input_mapping(prepared, device) + + +def move_input_mapping(value: object, device: torch.device) -> dict[str, object]: + to_method = getattr(value, "to", None) + if callable(to_method): + moved = to_method(device) + if isinstance(moved, dict): + return dict(moved) + if isinstance(value, dict): + result: dict[str, object] = {} + for key, item in value.items(): + if isinstance(item, torch.Tensor): + result[str(key)] = item.to(device) + else: + result[str(key)] = item + return result + raise PromptExecutionError("Chunked prompt ingestion expected mapping-like inputs.") + + +def prompt_token_id_pieces(tokenizer, rendered_prompt: str) -> Iterable[list[int]]: + custom_streamer = getattr(tokenizer, "stream_tokenize_prompt", None) + if callable(custom_streamer): + for piece in custom_streamer(rendered_prompt): + ids = list(piece) + if ids: + yield ids + return + for piece_text in prompt_piece_texts(tokenizer, rendered_prompt): + piece_ids = tokenize_prompt_piece(tokenizer, piece_text) + if piece_ids: + yield piece_ids + + +def prompt_piece_texts(tokenizer, rendered_prompt: str) -> tuple[str, ...]: + backend_tokenizer = getattr(tokenizer, "backend_tokenizer", None) + pre_tokenizer = ( + None + if backend_tokenizer is None + else getattr(backend_tokenizer, "pre_tokenizer", None) + ) + if pre_tokenizer is not None and hasattr(pre_tokenizer, "pre_tokenize_str"): + pieces = pre_tokenizer.pre_tokenize_str(rendered_prompt) + if pieces: + return tuple(text for text, _offset in pieces if text) + regex_pieces = tuple( + match.group(0) + for match in re.finditer(r"\S+\s*|\s+", rendered_prompt) + if match.group(0) + ) + if regex_pieces: + return regex_pieces + return (rendered_prompt,) + + +def tokenize_prompt_piece(tokenizer, piece_text: str) -> list[int]: + try: + encoded = tokenizer( + piece_text, + add_special_tokens=False, + return_attention_mask=False, + ) + except TypeError: + encoded = tokenizer(piece_text) + if isinstance(encoded, dict): + input_ids = encoded.get("input_ids") + if isinstance(input_ids, list): + if input_ids and isinstance(input_ids[0], list): + return [int(token_id) for token_id in input_ids[0]] + return [int(token_id) for token_id in input_ids] + if isinstance(input_ids, torch.Tensor): + return [int(token_id) for token_id in input_ids.reshape(-1).tolist()] + if isinstance(encoded, list): + return [int(token_id) for token_id in encoded] + raise PromptExecutionError( + "Chunked prompt tokenization produced unsupported input ids." + ) + + +def run_causal_prefill_chunk( + *, + runtime: LoadedRuntime, + forward_method, + static_inputs: dict[str, object], + chunk_ids: list[int], + prefill_cache: object, + prefix_token_count: int, + strategy_label: str, +) -> object: + forward_inputs: dict[str, object] = dict(static_inputs) + forward_inputs["input_ids"] = token_tensor(chunk_ids, device=runtime.device) + forward_inputs["attention_mask"] = ones_attention_mask( + token_count=prefix_token_count + len(chunk_ids), + device=runtime.device, + ) + forward_inputs["use_cache"] = True + forward_inputs["cache_position"] = torch.arange( + prefix_token_count, + prefix_token_count + len(chunk_ids), + device=runtime.device, + dtype=torch.long, + ) + if prefill_cache is not None: + forward_inputs["past_key_values"] = prefill_cache + filtered_inputs = filter_supported_forward_inputs(forward_method, forward_inputs) + with torch.inference_mode(): + with suppress_module_prints(runtime.backend.print_suppression_modules): + outputs = forward_method(**filtered_inputs) + next_cache = getattr(outputs, "past_key_values", None) + if next_cache is None: + raise PromptExecutionError( + "Chunked prompt-ingestion strategy " + f"{strategy_label!r} requires a runtime that returns past_key_values." + ) + return next_cache + + +def filter_supported_forward_inputs( + forward_method, + inputs: dict[str, object], +) -> dict[str, object]: + method_signature = signature(forward_method) + if any( + parameter.kind is Parameter.VAR_KEYWORD + for parameter in method_signature.parameters.values() + ): + return inputs + supported_keys = set(method_signature.parameters) + return {key: value for key, value in inputs.items() if key in supported_keys} + + +def token_tensor(token_ids: list[int], *, device: torch.device) -> torch.Tensor: + return torch.tensor([token_ids], device=device, dtype=torch.long) + + +def ones_attention_mask(token_count: int, *, device: torch.device) -> torch.Tensor: + return torch.ones((1, token_count), device=device, dtype=torch.long) + + +def prompt_token_count(inputs: dict[str, object]) -> int: + input_ids = inputs.get("input_ids") + if not isinstance(input_ids, torch.Tensor): + raise PromptExecutionError( + "Chunked prompt preparation expected tensor-backed input_ids." + ) + return int(input_ids.shape[-1]) diff --git a/src/ollm/runtime/execution_trace.py b/src/ollm/runtime/execution_trace.py index 6c7849f..4967e99 100644 --- a/src/ollm/runtime/execution_trace.py +++ b/src/ollm/runtime/execution_trace.py @@ -12,7 +12,6 @@ from ollm.runtime.errors import PromptExecutionError from ollm.runtime.generation import ( build_runtime_generate_kwargs, - build_runtime_inputs, decode_runtime_response, prepare_runtime_generate_inputs, validate_runtime_request, @@ -51,22 +50,19 @@ def execute_request_with_trace( if request.generation_config.seed is not None: torch.manual_seed(request.generation_config.seed) - inputs = build_runtime_inputs(runtime, request.messages) - prompt_token_count = _count_prompt_tokens(inputs) generate_kwargs, generation_config = build_runtime_generate_kwargs( runtime, request, streamer ) - prepared_inputs = normalize_generate_inputs(inputs) generation_started_at = time.perf_counter() - ( - prepared_inputs, - prepared_generate_kwargs, - chunked_prefill, - ) = prepare_runtime_generate_inputs( + prepared_result = prepare_runtime_generate_inputs( runtime, - prepared_inputs, + request, generate_kwargs, ) + prompt_token_count = prepared_result.prompt_token_count + prepared_inputs = normalize_generate_inputs(prepared_result.inputs) + prepared_generate_kwargs = prepared_result.generate_kwargs + chunked_prefill = prepared_result.scope outputs, effective_generate_kwargs = _generate_outputs( runtime=runtime, prepared_inputs=prepared_inputs, diff --git a/src/ollm/runtime/generation.py b/src/ollm/runtime/generation.py index 9cd36ed..4fbecd3 100644 --- a/src/ollm/runtime/generation.py +++ b/src/ollm/runtime/generation.py @@ -15,6 +15,7 @@ from ollm.runtime.catalog import ModelModality from ollm.runtime.chunked_prefill import ( ChunkedPrefillScopeSurface, + PreparedChunkedPrefill, prepare_chunked_prefill, ) from ollm.runtime.errors import PromptExecutionError @@ -185,19 +186,15 @@ def build_runtime_generate_kwargs( def prepare_runtime_generate_inputs( runtime: LoadedRuntime, - inputs: dict[str, object], + request: PromptRequest, generate_kwargs: dict[str, object], -) -> tuple[dict[str, object], dict[str, object], ChunkedPrefillScopeSurface]: - prepared = prepare_chunked_prefill( +) -> PreparedChunkedPrefill: + return prepare_chunked_prefill( runtime=runtime, - inputs=inputs, + messages=request.messages, generate_kwargs=generate_kwargs, chunk_tokens=DEFAULT_PREFILL_CHUNK_TOKENS, - ) - return ( - prepared.inputs, - prepared.generate_kwargs, - prepared.scope, + eager_input_builder=build_runtime_inputs, ) @@ -242,7 +239,6 @@ def execute( if request.generation_config.seed is not None: torch.manual_seed(request.generation_config.seed) - inputs = build_runtime_inputs(runtime, request.messages) streamer = None if request.generation_config.stream: streamer = BufferedTextStreamer( @@ -255,16 +251,14 @@ def execute( generate_kwargs, generation_config = build_runtime_generate_kwargs( runtime, request, streamer ) - filtered_inputs = normalize_generate_inputs(inputs) - ( - filtered_inputs, - generate_kwargs, - chunked_prefill, - ) = prepare_runtime_generate_inputs( + prepared_inputs = prepare_runtime_generate_inputs( runtime, - filtered_inputs, + request, generate_kwargs, ) + filtered_inputs = normalize_generate_inputs(prepared_inputs.inputs) + generate_kwargs = prepared_inputs.generate_kwargs + chunked_prefill = prepared_inputs.scope with torch.inference_mode(): with suppress_module_prints(runtime.backend.print_suppression_modules): diff --git a/tests/benchmark_support.py b/tests/benchmark_support.py index f005eef..6cf843a 100644 --- a/tests/benchmark_support.py +++ b/tests/benchmark_support.py @@ -107,8 +107,10 @@ def build_request_probe_metrics() -> RequestProbeMetrics: runtime_eligible=True, applied=True, activation_reason="Bounded chunked prefill ran before final decode.", - execution_boundary=ChunkedPrefillExecutionBoundary.POST_TOKENIZATION, - attention_mask_mode=ChunkedPrefillAttentionMaskMode.FULL_PREFIX_MATERIALIZED, + execution_boundary=( + ChunkedPrefillExecutionBoundary.STREAMED_PROMPT_PREPARATION + ), + attention_mask_mode=ChunkedPrefillAttentionMaskMode.LAZY_PREFIX_SYNTHESIS, gap_inventory=chunked_prefill_gap_inventory(), ), allocator_gap_mb=20.0, diff --git a/tests/test_benchmark_probe_execution.py b/tests/test_benchmark_probe_execution.py index fa57554..c65f734 100644 --- a/tests/test_benchmark_probe_execution.py +++ b/tests/test_benchmark_probe_execution.py @@ -40,21 +40,34 @@ def to(self, device, dtype=None): return self +class BenchmarkProcessorTokenizer: + def stream_tokenize_prompt(self, rendered_prompt): + del rendered_prompt + return ([1, 2, 3],) + + class BenchmarkProcessor: def __init__(self): self.inputs = BenchmarkProcessorInputs() + self.tokenizer = BenchmarkProcessorTokenizer() def apply_chat_template( self, messages, add_generation_prompt, tokenize, - return_dict, - return_tensors, + return_dict=False, + return_tensors=None, ): - del messages, add_generation_prompt, tokenize, return_dict, return_tensors + del messages, add_generation_prompt, return_dict, return_tensors + if not tokenize: + return "rendered-long-prompt" return self.inputs + def prepare_chunked_prefill_static_inputs(self, messages, device): + del messages, device + return {} + def batch_decode(self, outputs, skip_special_tokens=False): del outputs, skip_special_tokens return ["decoded-benchmark"] @@ -344,9 +357,9 @@ def wrapped_perf_counter() -> float: "prepare_runtime_generate_inputs" ] - def wrapped_prepare(runtime, inputs, generate_kwargs): + def wrapped_prepare(runtime, request, generate_kwargs): order.append("prepare") - return original_prepare(runtime, inputs, generate_kwargs) + return original_prepare(runtime, request, generate_kwargs) monkeypatch.setattr( "ollm.runtime.execution_trace.time.perf_counter", wrapped_perf_counter diff --git a/tests/test_benchmark_reporting.py b/tests/test_benchmark_reporting.py index 599d075..d6dc458 100644 --- a/tests/test_benchmark_reporting.py +++ b/tests/test_benchmark_reporting.py @@ -119,8 +119,8 @@ def test_render_runtime_probe_json_round_trips() -> None: assert chunked_prefill["strategy_id"] == "optimized-native-text" assert chunked_prefill["runtime_eligible"] is True assert chunked_prefill["applied"] is True - assert chunked_prefill["execution_boundary"] == "post-tokenization" - assert chunked_prefill["attention_mask_mode"] == "full-prefix-materialized" + assert chunked_prefill["execution_boundary"] == "streamed-prompt-preparation" + assert chunked_prefill["attention_mask_mode"] == "lazy-prefix-synthesis" gap_inventory = cast(list[object], chunked_prefill["gap_inventory"]) assert len(gap_inventory) == 3 adaptation = cast(dict[str, object], request["kv_cache_adaptation"]) diff --git a/tests/test_chunked_prefill_scope.py b/tests/test_chunked_prefill_scope.py index d849020..81919f5 100644 --- a/tests/test_chunked_prefill_scope.py +++ b/tests/test_chunked_prefill_scope.py @@ -15,7 +15,6 @@ ) from ollm.runtime.generation import ( build_runtime_generate_kwargs, - build_runtime_inputs, prepare_runtime_generate_inputs, ) from tests.test_runtime_executor import ( @@ -47,15 +46,13 @@ def test_prepare_runtime_generate_inputs_surfaces_chunked_prefill_scope( ) monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) - inputs = build_runtime_inputs(runtime, request.messages) generate_kwargs, _generation_config = build_runtime_generate_kwargs( runtime, request, streamer=None, ) - _prepared_inputs, _prepared_generate_kwargs, chunked_prefill = ( - prepare_runtime_generate_inputs(runtime, inputs, generate_kwargs) - ) + prepared_result = prepare_runtime_generate_inputs(runtime, request, generate_kwargs) + chunked_prefill = prepared_result.scope assert chunked_prefill.runtime_eligible is True assert chunked_prefill.applied is True @@ -69,14 +66,14 @@ def test_prepare_runtime_generate_inputs_surfaces_chunked_prefill_scope( for decision in chunked_prefill.gap_inventory } assert gap_inventory[ChunkedPrefillGapId.PROMPT_TOKENIZATION_BEFORE_PREFILL] is ( - ChunkedPrefillRecommendation.DEFER + ChunkedPrefillRecommendation.IMPLEMENT ) assert ( gap_inventory[ChunkedPrefillGapId.FULL_ATTENTION_MASK_BEFORE_PREFILL] - is ChunkedPrefillRecommendation.DEFER + is ChunkedPrefillRecommendation.IMPLEMENT ) assert gap_inventory[ChunkedPrefillGapId.SEQ2SEQ_SOURCE_PREFILL] is ( - ChunkedPrefillRecommendation.DEFER + ChunkedPrefillRecommendation.IMPLEMENT ) @@ -90,7 +87,7 @@ def test_prepare_runtime_generate_inputs_defers_seq2seq_source_prefill( ) runtime.plan = replace( runtime.plan, - backend_id="optimized-native", + backend_id="transformers-generic", generic_model_kind=GenericModelKind.SEQ2SEQ_LM, ) request = build_request( @@ -99,22 +96,23 @@ def test_prepare_runtime_generate_inputs_defers_seq2seq_source_prefill( ) monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) - inputs = build_runtime_inputs(runtime, request.messages) generate_kwargs, _generation_config = build_runtime_generate_kwargs( runtime, request, streamer=None, ) - _prepared_inputs, _prepared_generate_kwargs, chunked_prefill = ( - prepare_runtime_generate_inputs(runtime, inputs, generate_kwargs) - ) + prepared_result = prepare_runtime_generate_inputs(runtime, request, generate_kwargs) + chunked_prefill = prepared_result.scope - assert chunked_prefill.runtime_eligible is False - assert chunked_prefill.applied is False - assert chunked_prefill.strategy_id is None + assert chunked_prefill.runtime_eligible is True + assert chunked_prefill.applied is True + assert ( + chunked_prefill.strategy_id + is ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_SEQ2SEQ_SOURCE + ) assert ( chunked_prefill.activation_reason - == "Seq2seq source prompts cannot use causal-cache chunked prefill." + == "Streamed seq2seq source tokens were built incrementally before encoder generation." ) diff --git a/tests/test_runtime_executor_prefill.py b/tests/test_runtime_executor_prefill.py index 680db24..f7f12dc 100644 --- a/tests/test_runtime_executor_prefill.py +++ b/tests/test_runtime_executor_prefill.py @@ -27,14 +27,11 @@ def apply_chat_template( tokenize, add_generation_prompt, return_tensors, - return_dict, + return_dict=False, ): - del ( - messages, - tokenize, - add_generation_prompt, - return_tensors, - ) + del messages, add_generation_prompt, return_tensors + if not tokenize: + return "rendered-long-prompt" if not return_dict: raise TypeError("return_dict=True required") return { @@ -42,6 +39,10 @@ def apply_chat_template( "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]), } + def stream_tokenize_prompt(self, rendered_prompt): + del rendered_prompt + return ([1, 2], [3, 4], [5]) + def decode(self, tensor, skip_special_tokens=False): del tensor, skip_special_tokens return "long-decoded" @@ -77,37 +78,39 @@ def forward( return types.SimpleNamespace(past_key_values=self.prefill_cache) -class LongProcessorInputs(dict): - def __init__(self, static_key: str): - super().__init__( - { - "input_ids": torch.tensor([[1, 2, 3, 4, 5]]), - "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]), - static_key: torch.tensor([[[0.25, 0.5]]], dtype=torch.float32), - } - ) - self.to_calls: list[tuple[torch.device, torch.dtype | None]] = [] - - def to(self, device, dtype=None): - self.to_calls.append((device, dtype)) - return self - - class LongProcessor: def __init__(self, static_key: str): self.static_key = static_key - self.inputs = LongProcessorInputs(static_key) + self.tokenizer = LongMappingTokenizer() + self.static_to_calls: list[tuple[torch.device, torch.dtype | None]] = [] def apply_chat_template( self, messages, add_generation_prompt, tokenize, - return_dict, - return_tensors, + return_dict=False, + return_tensors=None, ): - del messages, add_generation_prompt, tokenize, return_dict, return_tensors - return self.inputs + del messages, add_generation_prompt, return_dict, return_tensors + if not tokenize: + return "rendered-long-prompt" + return { + "input_ids": torch.tensor([[1, 2, 3, 4, 5]]), + "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]), + self.static_key: torch.tensor([[[0.25, 0.5]]], dtype=torch.float32), + } + + def prepare_chunked_prefill_static_inputs(self, messages, device): + del messages + self.static_to_calls.append((device, torch.bfloat16)) + return { + self.static_key: torch.tensor( + [[[0.25, 0.5]]], + device=device, + dtype=torch.float32, + ) + } def batch_decode(self, outputs, skip_special_tokens=False): del outputs, skip_special_tokens @@ -295,7 +298,7 @@ def test_runtime_executor_prefills_long_native_multimodal_prompts_in_chunks( response = RuntimeExecutor().execute(runtime, request) assert response.text == "long-decoded" - assert processor.inputs.to_calls == [(torch.device("cpu"), torch.bfloat16)] + assert processor.static_to_calls == [(torch.device("cpu"), torch.bfloat16)] assert ( response.metadata["chunked_prefill_strategy_id"] == ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL.value @@ -362,7 +365,7 @@ def test_runtime_executor_prefills_long_generic_multimodal_prompts_in_chunks( assert torch.equal(generate_pixel_values, torch.tensor([[[0.25, 0.5]]])) -def test_runtime_executor_defers_chunked_prefill_for_seq2seq_runtime( +def test_runtime_executor_streams_seq2seq_source_strategy( monkeypatch, ) -> None: model = ChunkedPrefillModel() @@ -373,7 +376,7 @@ def test_runtime_executor_defers_chunked_prefill_for_seq2seq_runtime( ) runtime.plan = replace( runtime.plan, - backend_id="optimized-native", + backend_id="transformers-generic", generic_model_kind=GenericModelKind.SEQ2SEQ_LM, ) request = build_request( @@ -385,8 +388,11 @@ def test_runtime_executor_defers_chunked_prefill_for_seq2seq_runtime( response = RuntimeExecutor().execute(runtime, request) assert model.forward_calls == [] - assert response.metadata["chunked_prefill_strategy_id"] == "" - assert response.metadata["chunked_prefill_applied"] == "false" + assert ( + response.metadata["chunked_prefill_strategy_id"] + == ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_SEQ2SEQ_SOURCE.value + ) + assert response.metadata["chunked_prefill_applied"] == "true" generate_input_ids = model.generate_kwargs["input_ids"] assert isinstance(generate_input_ids, torch.Tensor) assert torch.equal(generate_input_ids, torch.tensor([[1, 2, 3, 4, 5]])) From d1fc52320f64da59895217f81374e031e8c9b09b Mon Sep 17 00:00:00 2001 From: beardedeagle Date: Fri, 3 Apr 2026 09:08:46 -0500 Subject: [PATCH 5/9] fix: address remaining chunked prefill review threads Guard forward signature inspection for uninspectable callables, compute the forward-input filter once per strategy execution instead of per chunk, and keep the streamed prompt-ingestion helpers split below the repo standards soft file-size limit. The branch was reverified with the full gate after these fixes: ruff format/check, standards checker, ty, compileall, pytest (434 passed), build, pip_audit, mkdocs --strict, and git diff --check. --- src/ollm/runtime/chunked_prefill.py | 10 ++++ src/ollm/runtime/chunked_prefill_support.py | 23 +++++---- tests/test_chunked_prefill_scope.py | 54 +++++++++++++++++++++ 3 files changed, 78 insertions(+), 9 deletions(-) diff --git a/src/ollm/runtime/chunked_prefill.py b/src/ollm/runtime/chunked_prefill.py index ce9a983..c0cb393 100644 --- a/src/ollm/runtime/chunked_prefill.py +++ b/src/ollm/runtime/chunked_prefill.py @@ -8,6 +8,7 @@ from ollm.app.types import Message from ollm.runtime.capability_discovery import GenericModelKind from ollm.runtime.chunked_prefill_support import ( + build_forward_input_filter, ones_attention_mask, prepare_static_inputs, prompt_token_id_pieces, @@ -200,6 +201,11 @@ def _prepare_streamed_causal_strategy( prefed_token_count = 0 prefill_cache = generate_kwargs.get("past_key_values") forward_method = getattr(runtime.model, "forward", None) + forward_input_filter = ( + None + if not callable(forward_method) + else build_forward_input_filter(forward_method) + ) for token_piece in prompt_token_id_pieces( resolve_stream_tokenizer(runtime), @@ -212,9 +218,11 @@ def _prepare_streamed_causal_strategy( raise PromptExecutionError( f"Chunked prompt-ingestion strategy {strategy_id.value!r} requires a callable forward method." ) + assert forward_input_filter is not None prefill_cache = run_causal_prefill_chunk( runtime=runtime, forward_method=forward_method, + forward_input_filter=forward_input_filter, static_inputs=static_inputs, chunk_ids=deferred_tokens[:chunk_tokens], prefill_cache=prefill_cache, @@ -231,9 +239,11 @@ def _prepare_streamed_causal_strategy( f"Chunked prompt-ingestion strategy {strategy_id.value!r} requires a callable forward method." ) chunk_size = min(chunk_tokens, len(deferred_tokens) - 1) + assert forward_input_filter is not None prefill_cache = run_causal_prefill_chunk( runtime=runtime, forward_method=forward_method, + forward_input_filter=forward_input_filter, static_inputs=static_inputs, chunk_ids=deferred_tokens[:chunk_size], prefill_cache=prefill_cache, diff --git a/src/ollm/runtime/chunked_prefill_support.py b/src/ollm/runtime/chunked_prefill_support.py index 027c28a..1025c38 100644 --- a/src/ollm/runtime/chunked_prefill_support.py +++ b/src/ollm/runtime/chunked_prefill_support.py @@ -1,7 +1,7 @@ """Support helpers for chunked prompt-ingestion strategies.""" import re -from collections.abc import Iterable +from collections.abc import Callable, Iterable from inspect import Parameter, signature import torch @@ -200,6 +200,7 @@ def run_causal_prefill_chunk( *, runtime: LoadedRuntime, forward_method, + forward_input_filter: Callable[[dict[str, object]], dict[str, object]], static_inputs: dict[str, object], chunk_ids: list[int], prefill_cache: object, @@ -221,7 +222,7 @@ def run_causal_prefill_chunk( ) if prefill_cache is not None: forward_inputs["past_key_values"] = prefill_cache - filtered_inputs = filter_supported_forward_inputs(forward_method, forward_inputs) + filtered_inputs = forward_input_filter(forward_inputs) with torch.inference_mode(): with suppress_module_prints(runtime.backend.print_suppression_modules): outputs = forward_method(**filtered_inputs) @@ -234,18 +235,22 @@ def run_causal_prefill_chunk( return next_cache -def filter_supported_forward_inputs( +def build_forward_input_filter( forward_method, - inputs: dict[str, object], -) -> dict[str, object]: - method_signature = signature(forward_method) +) -> Callable[[dict[str, object]], dict[str, object]]: + try: + method_signature = signature(forward_method) + except (TypeError, ValueError): + return lambda inputs: inputs if any( parameter.kind is Parameter.VAR_KEYWORD for parameter in method_signature.parameters.values() ): - return inputs - supported_keys = set(method_signature.parameters) - return {key: value for key, value in inputs.items() if key in supported_keys} + return lambda inputs: inputs + supported_keys = frozenset(method_signature.parameters) + return lambda inputs: { + key: value for key, value in inputs.items() if key in supported_keys + } def token_tensor(token_ids: list[int], *, device: torch.device) -> torch.Tensor: diff --git a/tests/test_chunked_prefill_scope.py b/tests/test_chunked_prefill_scope.py index 81919f5..856f2fa 100644 --- a/tests/test_chunked_prefill_scope.py +++ b/tests/test_chunked_prefill_scope.py @@ -13,6 +13,7 @@ ChunkedPrefillRecommendation, ChunkedPrefillStrategyId, ) +from ollm.runtime.chunked_prefill_support import build_forward_input_filter from ollm.runtime.generation import ( build_runtime_generate_kwargs, prepare_runtime_generate_inputs, @@ -142,3 +143,56 @@ def test_t5_encoder_does_not_expose_cacheable_source_prefill() -> None: use_cache=True, return_dict=True, ) + + +def test_build_forward_input_filter_falls_back_for_uninspectable_callable() -> None: + class UninspectableForward: + @property + def __signature__(self): + raise ValueError("no signature") + + def __call__(self, **kwargs): + return kwargs + + forward_filter = build_forward_input_filter(UninspectableForward()) + inputs: dict[str, object] = { + "input_ids": torch.tensor([[1, 2]]), + "attention_mask": torch.tensor([[1, 1]]), + "unexpected": "kept", + } + + assert forward_filter(inputs) is inputs + + +def test_build_forward_input_filter_inspects_signature_once() -> None: + class CountingForward: + signature_reads = 0 + + @property + def __signature__(self): + type(self).signature_reads += 1 + return None + + def __call__(self, *, input_ids, attention_mask): + return input_ids, attention_mask + + forward = CountingForward() + forward_filter = build_forward_input_filter(forward) + first = forward_filter( + { + "input_ids": torch.tensor([[1, 2]]), + "attention_mask": torch.tensor([[1, 1]]), + "cache_position": torch.tensor([0, 1]), + } + ) + second = forward_filter( + { + "input_ids": torch.tensor([[3, 4]]), + "attention_mask": torch.tensor([[1, 1]]), + "cache_position": torch.tensor([2, 3]), + } + ) + + assert CountingForward.signature_reads == 1 + assert set(first) == {"input_ids", "attention_mask"} + assert set(second) == {"input_ids", "attention_mask"} From 6b8640d8be292a56f235960c3c9351188be65c1c Mon Sep 17 00:00:00 2001 From: beardedeagle Date: Fri, 3 Apr 2026 09:24:12 -0500 Subject: [PATCH 6/9] fix: harden chunked prompt support utilities Accept mapping-like processor outputs in move_input_mapping, guard processor signature inspection in call_processor_for_static_inputs, and rename the prefill bookkeeping variable for clarity. Reverified with the full required gate from the final formatted tree: ruff format/check, standards checker, ty, compileall, pytest (434 passed), build, pip_audit, mkdocs --strict, and git diff --check. --- src/ollm/runtime/chunked_prefill.py | 12 +++--- src/ollm/runtime/chunked_prefill_support.py | 43 +++++++++++++++------ 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/src/ollm/runtime/chunked_prefill.py b/src/ollm/runtime/chunked_prefill.py index c0cb393..58c454e 100644 --- a/src/ollm/runtime/chunked_prefill.py +++ b/src/ollm/runtime/chunked_prefill.py @@ -198,7 +198,7 @@ def _prepare_streamed_causal_strategy( static_inputs = prepare_static_inputs(runtime, messages) prompt_tokens: list[int] = [] deferred_tokens: list[int] = [] - prefed_token_count = 0 + prefilled_token_count = 0 prefill_cache = generate_kwargs.get("past_key_values") forward_method = getattr(runtime.model, "forward", None) forward_input_filter = ( @@ -226,11 +226,11 @@ def _prepare_streamed_causal_strategy( static_inputs=static_inputs, chunk_ids=deferred_tokens[:chunk_tokens], prefill_cache=prefill_cache, - prefix_token_count=prefed_token_count, + prefix_token_count=prefilled_token_count, strategy_label=strategy_id.value, ) del deferred_tokens[:chunk_tokens] - prefed_token_count += chunk_tokens + prefilled_token_count += chunk_tokens if len(prompt_tokens) - 1 > chunk_tokens: while len(deferred_tokens) > 1: @@ -247,11 +247,11 @@ def _prepare_streamed_causal_strategy( static_inputs=static_inputs, chunk_ids=deferred_tokens[:chunk_size], prefill_cache=prefill_cache, - prefix_token_count=prefed_token_count, + prefix_token_count=prefilled_token_count, strategy_label=strategy_id.value, ) del deferred_tokens[:chunk_size] - prefed_token_count += chunk_size + prefilled_token_count += chunk_size if not prompt_tokens: raise PromptExecutionError( @@ -260,7 +260,7 @@ def _prepare_streamed_causal_strategy( final_inputs = dict(static_inputs) final_generate_kwargs = dict(generate_kwargs) - if prefed_token_count > 0: + if prefilled_token_count > 0: final_inputs["input_ids"] = token_tensor(deferred_tokens, device=runtime.device) final_inputs["attention_mask"] = ones_attention_mask( token_count=len(prompt_tokens), diff --git a/src/ollm/runtime/chunked_prefill_support.py b/src/ollm/runtime/chunked_prefill_support.py index 1025c38..c1d99e0 100644 --- a/src/ollm/runtime/chunked_prefill_support.py +++ b/src/ollm/runtime/chunked_prefill_support.py @@ -1,7 +1,7 @@ """Support helpers for chunked prompt-ingestion strategies.""" import re -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Mapping from inspect import Parameter, signature import torch @@ -101,20 +101,39 @@ def call_processor_for_static_inputs( audio_values: list[str], device: torch.device, ) -> dict[str, object]: - processor_signature = signature(processor.__call__) - accepts_kwargs = any( - parameter.kind is Parameter.VAR_KEYWORD - for parameter in processor_signature.parameters.values() - ) + try: + processor_signature = signature(processor.__call__) + except (TypeError, ValueError): + processor_signature = None + accepts_kwargs = True + else: + accepts_kwargs = any( + parameter.kind is Parameter.VAR_KEYWORD + for parameter in processor_signature.parameters.values() + ) kwargs: dict[str, object] = {"return_tensors": "pt"} - if image_values and ("images" in processor_signature.parameters or accepts_kwargs): + if image_values and ( + accepts_kwargs + or ( + processor_signature is not None + and "images" in processor_signature.parameters + ) + ): kwargs["images"] = image_values if audio_values: - if "audios" in processor_signature.parameters or accepts_kwargs: + if accepts_kwargs or ( + processor_signature is not None + and "audios" in processor_signature.parameters + ): kwargs["audios"] = audio_values - elif "audio" in processor_signature.parameters: + elif ( + processor_signature is not None + and "audio" in processor_signature.parameters + ): kwargs["audio"] = audio_values - if "text" in processor_signature.parameters or accepts_kwargs: + if accepts_kwargs or ( + processor_signature is not None and "text" in processor_signature.parameters + ): kwargs["text"] = [""] prepared = processor(**kwargs) return move_input_mapping(prepared, device) @@ -124,9 +143,9 @@ def move_input_mapping(value: object, device: torch.device) -> dict[str, object] to_method = getattr(value, "to", None) if callable(to_method): moved = to_method(device) - if isinstance(moved, dict): + if isinstance(moved, Mapping): return dict(moved) - if isinstance(value, dict): + if isinstance(value, Mapping): result: dict[str, object] = {} for key, item in value.items(): if isinstance(item, torch.Tensor): From 6e38eb8e12e4c1323cfca5a5787cee0ccc035805 Mon Sep 17 00:00:00 2001 From: beardedeagle Date: Fri, 3 Apr 2026 10:01:49 -0500 Subject: [PATCH 7/9] fix: tighten streamed chunk ingestion guards Guard non-positive chunk budgets, avoid duplicating the full prompt token list during streamed causal ingestion, prefer special-token-safe tokenizer encode fallbacks, and remove the now-dead prompt-token helper from execution tracing. Reverified with the full required gate from the final formatted tree: ruff format/check, standards checker, ty, compileall, pytest (436 passed), build, pip_audit, mkdocs --strict, and git diff --check. --- src/ollm/runtime/chunked_prefill.py | 18 +++++---- src/ollm/runtime/chunked_prefill_support.py | 6 ++- src/ollm/runtime/execution_trace.py | 8 ---- tests/test_chunked_prefill_scope.py | 42 ++++++++++++++++++++- 4 files changed, 56 insertions(+), 18 deletions(-) diff --git a/src/ollm/runtime/chunked_prefill.py b/src/ollm/runtime/chunked_prefill.py index 58c454e..b471296 100644 --- a/src/ollm/runtime/chunked_prefill.py +++ b/src/ollm/runtime/chunked_prefill.py @@ -161,6 +161,8 @@ def prepare_chunked_prefill( chunk_tokens: int, eager_input_builder: Callable[[LoadedRuntime, list[Message]], dict[str, object]], ) -> PreparedChunkedPrefill: + if chunk_tokens < 1: + raise ValueError("chunk_tokens must be at least 1") runtime_kind = ( runtime.plan.generic_model_kind or runtime.resolved_model.generic_model_kind ) @@ -196,7 +198,7 @@ def _prepare_streamed_causal_strategy( ) -> PreparedChunkedPrefill: rendered_prompt = render_prompt_text(runtime, messages) static_inputs = prepare_static_inputs(runtime, messages) - prompt_tokens: list[int] = [] + total_prompt_token_count = 0 deferred_tokens: list[int] = [] prefilled_token_count = 0 prefill_cache = generate_kwargs.get("past_key_values") @@ -211,7 +213,7 @@ def _prepare_streamed_causal_strategy( resolve_stream_tokenizer(runtime), rendered_prompt, ): - prompt_tokens.extend(token_piece) + total_prompt_token_count += len(token_piece) deferred_tokens.extend(token_piece) while len(deferred_tokens) > chunk_tokens + 1: if not callable(forward_method): @@ -232,7 +234,7 @@ def _prepare_streamed_causal_strategy( del deferred_tokens[:chunk_tokens] prefilled_token_count += chunk_tokens - if len(prompt_tokens) - 1 > chunk_tokens: + if total_prompt_token_count - 1 > chunk_tokens: while len(deferred_tokens) > 1: if not callable(forward_method): raise PromptExecutionError( @@ -253,7 +255,7 @@ def _prepare_streamed_causal_strategy( del deferred_tokens[:chunk_size] prefilled_token_count += chunk_size - if not prompt_tokens: + if total_prompt_token_count == 0: raise PromptExecutionError( "Chunked prompt ingestion produced no prompt tokens." ) @@ -263,7 +265,7 @@ def _prepare_streamed_causal_strategy( if prefilled_token_count > 0: final_inputs["input_ids"] = token_tensor(deferred_tokens, device=runtime.device) final_inputs["attention_mask"] = ones_attention_mask( - token_count=len(prompt_tokens), + token_count=total_prompt_token_count, device=runtime.device, ) final_generate_kwargs["past_key_values"] = prefill_cache @@ -276,9 +278,9 @@ def _prepare_streamed_causal_strategy( activation_reason="Bounded chunked prefill ran before final decode.", ) else: - final_inputs["input_ids"] = token_tensor(prompt_tokens, device=runtime.device) + final_inputs["input_ids"] = token_tensor(deferred_tokens, device=runtime.device) final_inputs["attention_mask"] = ones_attention_mask( - token_count=len(prompt_tokens), + token_count=total_prompt_token_count, device=runtime.device, ) scope = _scope( @@ -292,7 +294,7 @@ def _prepare_streamed_causal_strategy( inputs=final_inputs, generate_kwargs=final_generate_kwargs, scope=scope, - prompt_token_count=len(prompt_tokens), + prompt_token_count=total_prompt_token_count, ) diff --git a/src/ollm/runtime/chunked_prefill_support.py b/src/ollm/runtime/chunked_prefill_support.py index c1d99e0..3ef27a7 100644 --- a/src/ollm/runtime/chunked_prefill_support.py +++ b/src/ollm/runtime/chunked_prefill_support.py @@ -199,7 +199,11 @@ def tokenize_prompt_piece(tokenizer, piece_text: str) -> list[int]: return_attention_mask=False, ) except TypeError: - encoded = tokenizer(piece_text) + encode_method = getattr(tokenizer, "encode", None) + if callable(encode_method): + encoded = encode_method(piece_text, add_special_tokens=False) + else: + encoded = tokenizer(piece_text) if isinstance(encoded, dict): input_ids = encoded.get("input_ids") if isinstance(input_ids, list): diff --git a/src/ollm/runtime/execution_trace.py b/src/ollm/runtime/execution_trace.py index 4967e99..1f8ecb9 100644 --- a/src/ollm/runtime/execution_trace.py +++ b/src/ollm/runtime/execution_trace.py @@ -9,7 +9,6 @@ from ollm.kv_cache.state import KVCacheStateSnapshot from ollm.runtime.capability_discovery import GenericModelKind from ollm.runtime.chunked_prefill import ChunkedPrefillScopeSurface -from ollm.runtime.errors import PromptExecutionError from ollm.runtime.generation import ( build_runtime_generate_kwargs, decode_runtime_response, @@ -144,13 +143,6 @@ def _run_model_generate( ) -def _count_prompt_tokens(inputs: dict[str, object]) -> int: - input_ids = inputs.get("input_ids") - if not isinstance(input_ids, torch.Tensor): - raise PromptExecutionError("Benchmark probe expected tensor-backed input_ids") - return int(input_ids.shape[0] if input_ids.ndim == 1 else input_ids.shape[-1]) - - def _decode_prefix_token_count(inputs: dict[str, object]) -> int: input_ids = require_tensor(inputs["input_ids"]) return int(input_ids.shape[-1]) diff --git a/tests/test_chunked_prefill_scope.py b/tests/test_chunked_prefill_scope.py index 856f2fa..0d20da7 100644 --- a/tests/test_chunked_prefill_scope.py +++ b/tests/test_chunked_prefill_scope.py @@ -12,8 +12,12 @@ ChunkedPrefillGapId, ChunkedPrefillRecommendation, ChunkedPrefillStrategyId, + prepare_chunked_prefill, +) +from ollm.runtime.chunked_prefill_support import ( + build_forward_input_filter, + tokenize_prompt_piece, ) -from ollm.runtime.chunked_prefill_support import build_forward_input_filter from ollm.runtime.generation import ( build_runtime_generate_kwargs, prepare_runtime_generate_inputs, @@ -196,3 +200,39 @@ def __call__(self, *, input_ids, attention_mask): assert CountingForward.signature_reads == 1 assert set(first) == {"input_ids", "attention_mask"} assert set(second) == {"input_ids", "attention_mask"} + + +def test_prepare_chunked_prefill_rejects_non_positive_chunk_budget() -> None: + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=ChunkedPrefillModel(), + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + + with pytest.raises(ValueError, match="chunk_tokens must be at least 1"): + prepare_chunked_prefill( + runtime=runtime, + messages=request.messages, + generate_kwargs={}, + chunk_tokens=0, + eager_input_builder=lambda runtime, messages: { + "input_ids": torch.tensor([[1]]) + }, + ) + + +def test_tokenize_prompt_piece_disables_special_tokens_in_fallback() -> None: + class EncodeOnlyTokenizer: + def encode(self, piece_text: str, *, add_special_tokens: bool = True): + assert piece_text == "piece" + assert add_special_tokens is False + return [7, 8] + + def __call__(self, piece_text: str): + raise TypeError(piece_text) + + assert tokenize_prompt_piece(EncodeOnlyTokenizer(), "piece") == [7, 8] From d317fbf9f3d4857391ad2178cdf62acbd6977cd2 Mon Sep 17 00:00:00 2001 From: beardedeagle Date: Fri, 3 Apr 2026 10:20:47 -0500 Subject: [PATCH 8/9] fix: tighten chunked ingestion metadata and fallbacks Preserve backend-provided chunked-prefill metadata, blank boundary/mask metadata when no strategy is active, guard processor prompt rendering and static-input signatures, and restore the non-callable forward fallback. Also add regressions for the latest PR review findings and keep the repo standards checker green by moving the backend-metadata test into its own file. Reverified with the full required gate from the final formatted tree: ruff format/check, standards checker, ty, compileall, pytest (441 passed), build, pip_audit, mkdocs --strict, and git diff --check. --- src/ollm/runtime/chunked_prefill.py | 18 ++++ src/ollm/runtime/chunked_prefill_support.py | 28 ++++-- src/ollm/runtime/generation.py | 7 +- tests/test_benchmark_probe_execution.py | 4 +- tests/test_chunked_prefill_scope.py | 94 +++++++++++++++++++ tests/test_runtime_executor.py | 7 +- .../test_runtime_executor_backend_metadata.py | 41 ++++++++ tests/test_runtime_executor_prefill.py | 28 ++++++ 8 files changed, 211 insertions(+), 16 deletions(-) create mode 100644 tests/test_runtime_executor_backend_metadata.py diff --git a/src/ollm/runtime/chunked_prefill.py b/src/ollm/runtime/chunked_prefill.py index b471296..3f6a8c3 100644 --- a/src/ollm/runtime/chunked_prefill.py +++ b/src/ollm/runtime/chunked_prefill.py @@ -181,6 +181,24 @@ def prepare_chunked_prefill( ), prompt_token_count=count_prompt_tokens(inputs), ) + if ( + strategy.strategy_id + is not ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_SEQ2SEQ_SOURCE + and not callable(getattr(runtime.model, "forward", None)) + ): + inputs = eager_input_builder(runtime, messages) + return PreparedChunkedPrefill( + inputs=inputs, + generate_kwargs=generate_kwargs, + scope=_scope( + strategy_id=strategy.strategy_id, + runtime_eligible=False, + activation_reason=( + "Chunked prompt-ingestion strategy requires a callable forward method." + ), + ), + prompt_token_count=count_prompt_tokens(inputs), + ) return strategy.prepare(runtime, messages, generate_kwargs, chunk_tokens) diff --git a/src/ollm/runtime/chunked_prefill_support.py b/src/ollm/runtime/chunked_prefill_support.py index 3ef27a7..8773e4b 100644 --- a/src/ollm/runtime/chunked_prefill_support.py +++ b/src/ollm/runtime/chunked_prefill_support.py @@ -23,15 +23,18 @@ def render_prompt_text(runtime: LoadedRuntime, messages: list[Message]) -> str: if runtime.processor is not None and hasattr( runtime.processor, "apply_chat_template" ): - rendered = runtime.processor.apply_chat_template( - transformers_messages, - add_generation_prompt=True, - tokenize=False, - return_dict=False, - return_tensors=None, - ) - if isinstance(rendered, str): - return rendered + try: + rendered = runtime.processor.apply_chat_template( + transformers_messages, + add_generation_prompt=True, + tokenize=False, + return_dict=False, + return_tensors=None, + ) + if isinstance(rendered, str): + return rendered + except (TypeError, ValueError, AttributeError): + pass if hasattr(runtime.tokenizer, "apply_chat_template"): try: rendered = runtime.tokenizer.apply_chat_template( @@ -111,7 +114,12 @@ def call_processor_for_static_inputs( parameter.kind is Parameter.VAR_KEYWORD for parameter in processor_signature.parameters.values() ) - kwargs: dict[str, object] = {"return_tensors": "pt"} + kwargs: dict[str, object] = {} + if accepts_kwargs or ( + processor_signature is not None + and "return_tensors" in processor_signature.parameters + ): + kwargs["return_tensors"] = "pt" if image_values and ( accepts_kwargs or ( diff --git a/src/ollm/runtime/generation.py b/src/ollm/runtime/generation.py index 4fbecd3..ec277cb 100644 --- a/src/ollm/runtime/generation.py +++ b/src/ollm/runtime/generation.py @@ -286,7 +286,8 @@ def _finalize_response( self, runtime: LoadedRuntime, response: PromptResponse ) -> PromptResponse: metadata = dict(response.metadata) - metadata.update(self._plan_metadata(runtime, None, None)) + for key, value in self._plan_metadata(runtime, None, None).items(): + metadata.setdefault(key, value) return PromptResponse( text=response.text, assistant_message=response.assistant_message, @@ -339,12 +340,12 @@ def _plan_metadata( ), "chunked_prefill_execution_boundary": ( "" - if chunked_prefill is None + if chunked_prefill is None or chunked_prefill.strategy_id is None else chunked_prefill.execution_boundary.value ), "chunked_prefill_attention_mask_mode": ( "" - if chunked_prefill is None + if chunked_prefill is None or chunked_prefill.strategy_id is None else chunked_prefill.attention_mask_mode.value ), } diff --git a/tests/test_benchmark_probe_execution.py b/tests/test_benchmark_probe_execution.py index c65f734..70a776f 100644 --- a/tests/test_benchmark_probe_execution.py +++ b/tests/test_benchmark_probe_execution.py @@ -169,7 +169,7 @@ def test_execute_request_probe_strips_processor_token_type_ids() -> None: assert execution.metrics.chunked_prefill.strategy_id is ( ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL ) - assert execution.metrics.chunked_prefill.runtime_eligible is True + assert execution.metrics.chunked_prefill.runtime_eligible is False assert execution.metrics.chunked_prefill.applied is False @@ -247,7 +247,7 @@ def test_execute_request_with_trace_reports_processor_counts() -> None: assert trace.chunked_prefill.strategy_id is ( ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL ) - assert trace.chunked_prefill.runtime_eligible is True + assert trace.chunked_prefill.runtime_eligible is False assert trace.chunked_prefill.applied is False diff --git a/tests/test_chunked_prefill_scope.py b/tests/test_chunked_prefill_scope.py index 0d20da7..24e2479 100644 --- a/tests/test_chunked_prefill_scope.py +++ b/tests/test_chunked_prefill_scope.py @@ -16,6 +16,8 @@ ) from ollm.runtime.chunked_prefill_support import ( build_forward_input_filter, + call_processor_for_static_inputs, + render_prompt_text, tokenize_prompt_piece, ) from ollm.runtime.generation import ( @@ -121,6 +123,36 @@ def test_prepare_runtime_generate_inputs_defers_seq2seq_source_prefill( ) +def test_prepare_runtime_generate_inputs_leaves_boundary_blank_without_strategy( + monkeypatch, +) -> None: + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=ChunkedPrefillModel(), + ) + runtime.plan = replace( + runtime.plan, + backend_id="custom-backend", + generic_model_kind=GenericModelKind.CAUSAL_LM, + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + generate_kwargs, _generation_config = build_runtime_generate_kwargs( + runtime, + request, + streamer=None, + ) + prepared_result = prepare_runtime_generate_inputs(runtime, request, generate_kwargs) + + assert prepared_result.scope.strategy_id is None + assert prepared_result.scope.runtime_eligible is False + + def test_t5_encoder_does_not_expose_cacheable_source_prefill() -> None: model = T5ForConditionalGeneration( T5Config( @@ -236,3 +268,65 @@ def __call__(self, piece_text: str): raise TypeError(piece_text) assert tokenize_prompt_piece(EncodeOnlyTokenizer(), "piece") == [7, 8] + + +def test_render_prompt_text_falls_back_when_processor_signature_differs() -> None: + class FragileProcessor: + def apply_chat_template(self, messages, tokenize): + del messages, tokenize + raise TypeError("different signature") + + class FallbackTokenizer: + def apply_chat_template( + self, + messages, + tokenize, + add_generation_prompt, + return_tensors=None, + return_dict=False, + ): + del messages, add_generation_prompt, return_dict, return_tensors + if not tokenize: + return "tokenizer-rendered" + return { + "input_ids": torch.tensor([[1, 2, 3]]), + "attention_mask": torch.tensor([[1, 1, 1]]), + } + + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=FallbackTokenizer(), + model=ChunkedPrefillModel(), + ) + runtime.backend = replace( + runtime.backend, + processor=FragileProcessor(), + tokenizer=FallbackTokenizer(), + ) + + rendered = render_prompt_text( + runtime, + [Message(role=MessageRole.USER, content=[ContentPart.text("hello")])], + ) + + assert rendered == "tokenizer-rendered" + + +def test_call_processor_for_static_inputs_omits_return_tensors_when_unsupported() -> ( + None +): + class ReturnTensorsRejectingProcessor: + def __call__(self, *, images): + assert images == ["image.png"] + return {"pixel_values": torch.tensor([[[1.0]]])} + + prepared = call_processor_for_static_inputs( + processor=ReturnTensorsRejectingProcessor(), + image_values=["image.png"], + audio_values=[], + device=torch.device("cpu"), + ) + + pixel_values = prepared["pixel_values"] + assert isinstance(pixel_values, torch.Tensor) + assert torch.equal(pixel_values, torch.tensor([[[1.0]]])) diff --git a/tests/test_runtime_executor.py b/tests/test_runtime_executor.py index fe2ae62..ee5b6f9 100644 --- a/tests/test_runtime_executor.py +++ b/tests/test_runtime_executor.py @@ -3,7 +3,12 @@ import pytest import torch -from ollm.app.types import ContentPart, Message, MessageRole, PromptRequest +from ollm.app.types import ( + ContentPart, + Message, + MessageRole, + PromptRequest, +) from ollm.runtime.backends.base import BackendRuntime from ollm.runtime.capabilities import CapabilityProfile, SupportLevel from ollm.runtime.catalog import ModelModality diff --git a/tests/test_runtime_executor_backend_metadata.py b/tests/test_runtime_executor_backend_metadata.py new file mode 100644 index 0000000..92a35ff --- /dev/null +++ b/tests/test_runtime_executor_backend_metadata.py @@ -0,0 +1,41 @@ +import torch + +from ollm.app.types import ContentPart, Message, MessageRole, PromptResponse +from ollm.runtime.backends.base import BackendRuntime +from ollm.runtime.capabilities import CapabilityProfile, SupportLevel +from ollm.runtime.generation import RuntimeExecutor +from tests.test_runtime_executor import build_request, build_runtime + + +def test_runtime_executor_preserves_backend_chunked_prefill_metadata() -> None: + runtime = build_runtime(CapabilityProfile(support_level=SupportLevel.GENERIC)) + runtime.backend = BackendRuntime( + backend_id="test-backend", + model=None, + tokenizer=None, + processor=None, + device=torch.device("cpu"), + stats=None, + print_suppression_modules=(), + create_cache=lambda cache_dir, cache_strategy=None, cache_lifecycle=None, cache_window_tokens=None: ( + None + ), + apply_offload=lambda runtime_config: None, + execute_prompt=lambda request, sink: PromptResponse( + text="backend-response", + assistant_message=Message.assistant_text("backend-response"), + metadata={ + "chunked_prefill_strategy_id": "backend-strategy", + "chunked_prefill_activation_reason": "backend-owned", + }, + ), + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("hello")]), + ) + + response = RuntimeExecutor().execute(runtime, request) + + assert response.metadata["chunked_prefill_strategy_id"] == "backend-strategy" + assert response.metadata["chunked_prefill_activation_reason"] == "backend-owned" diff --git a/tests/test_runtime_executor_prefill.py b/tests/test_runtime_executor_prefill.py index f7f12dc..18c3b7f 100644 --- a/tests/test_runtime_executor_prefill.py +++ b/tests/test_runtime_executor_prefill.py @@ -396,3 +396,31 @@ def test_runtime_executor_streams_seq2seq_source_strategy( generate_input_ids = model.generate_kwargs["input_ids"] assert isinstance(generate_input_ids, torch.Tensor) assert torch.equal(generate_input_ids, torch.tensor([[1, 2, 3, 4, 5]])) + + +def test_runtime_executor_leaves_chunked_prefill_metadata_blank_without_strategy( + monkeypatch, +) -> None: + model = ChunkedPrefillModel() + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=model, + ) + runtime.plan = replace( + runtime.plan, + backend_id="custom-backend", + generic_model_kind=GenericModelKind.CAUSAL_LM, + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + response = RuntimeExecutor().execute(runtime, request) + + assert response.metadata["chunked_prefill_strategy_id"] == "" + assert response.metadata["chunked_prefill_runtime_eligible"] == "false" + assert response.metadata["chunked_prefill_execution_boundary"] == "" + assert response.metadata["chunked_prefill_attention_mask_mode"] == "" From 99c1f985b5b2cba367000f423987d84552823572 Mon Sep 17 00:00:00 2001 From: beardedeagle Date: Fri, 3 Apr 2026 10:23:14 -0500 Subject: [PATCH 9/9] test: cover chunked ingestion fallback behavior Add the non-callable forward fallback regression and rerun the full required gate from the final formatted tree: ruff format/check, standards checker, ty, compileall, pytest (442 passed), build, pip_audit, mkdocs --strict, and git diff --check. --- tests/test_runtime_executor_prefill.py | 38 ++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_runtime_executor_prefill.py b/tests/test_runtime_executor_prefill.py index 18c3b7f..2cf21e5 100644 --- a/tests/test_runtime_executor_prefill.py +++ b/tests/test_runtime_executor_prefill.py @@ -424,3 +424,41 @@ def test_runtime_executor_leaves_chunked_prefill_metadata_blank_without_strategy assert response.metadata["chunked_prefill_runtime_eligible"] == "false" assert response.metadata["chunked_prefill_execution_boundary"] == "" assert response.metadata["chunked_prefill_attention_mask_mode"] == "" + + +def test_runtime_executor_falls_back_when_forward_is_unavailable_for_chunking( + monkeypatch, +) -> None: + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=FakeModel(), + ) + runtime.plan = replace( + runtime.plan, + backend_id="optimized-native", + generic_model_kind=GenericModelKind.CAUSAL_LM, + ) + runtime.model.forward = None + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + response = RuntimeExecutor().execute(runtime, request) + + assert response.text == "long-decoded" + assert ( + response.metadata["chunked_prefill_strategy_id"] + == ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT.value + ) + assert response.metadata["chunked_prefill_runtime_eligible"] == "false" + assert response.metadata["chunked_prefill_applied"] == "false" + assert ( + response.metadata["chunked_prefill_activation_reason"] + == "Chunked prompt-ingestion strategy requires a callable forward method." + ) + generate_input_ids = runtime.model.generate_kwargs["input_ids"] + assert isinstance(generate_input_ids, torch.Tensor) + assert torch.equal(generate_input_ids, torch.tensor([[1, 2, 3, 4, 5]]))