diff --git a/.beads/interactions.jsonl b/.beads/interactions.jsonl index 45b4e44..ab4fca6 100644 --- a/.beads/interactions.jsonl +++ b/.beads/interactions.jsonl @@ -8,3 +8,9 @@ {"id":"int-e2c83df6","kind":"field_change","created_at":"2026-04-03T04:07:29.73357Z","actor":"beardedeagle","issue_id":"ollm-cly","extra":{"field":"status","new_value":"closed","old_value":"in_progress","reason":"Completed"}} {"id":"int-7c080521","kind":"field_change","created_at":"2026-04-03T04:58:45.426115Z","actor":"beardedeagle","issue_id":"ollm-7zk","extra":{"field":"status","new_value":"closed","old_value":"open","reason":"Completed"}} {"id":"int-bf6e89f2","kind":"field_change","created_at":"2026-04-03T06:33:55.374489Z","actor":"beardedeagle","issue_id":"ollm-nnt","extra":{"field":"status","new_value":"closed","old_value":"open","reason":"Completed"}} +{"id":"int-a6481fd7","kind":"field_change","created_at":"2026-04-03T09:48:50.10993Z","actor":"beardedeagle","issue_id":"ollm-6b7","extra":{"field":"status","new_value":"closed","old_value":"in_progress","reason":"Completed"}} +{"id":"int-32d6d982","kind":"field_change","created_at":"2026-04-03T11:31:04.109777Z","actor":"beardedeagle","issue_id":"ollm-6b7","extra":{"field":"status","new_value":"open","old_value":"closed"}} +{"id":"int-9e450675","kind":"field_change","created_at":"2026-04-03T11:33:42.488488Z","actor":"beardedeagle","issue_id":"ollm-6b7","extra":{"field":"status","new_value":"in_progress","old_value":"open"}} +{"id":"int-5cd99253","kind":"field_change","created_at":"2026-04-03T11:52:07.85795Z","actor":"beardedeagle","issue_id":"ollm-6b7","extra":{"field":"status","new_value":"closed","old_value":"in_progress","reason":"Completed"}} +{"id":"int-25387d16","kind":"field_change","created_at":"2026-04-03T13:58:30.61251Z","actor":"beardedeagle","issue_id":"ollm-qm9","extra":{"field":"status","new_value":"closed","old_value":"open","reason":"Completed"}} +{"id":"int-f18a669e","kind":"field_change","created_at":"2026-04-03T13:58:30.652694Z","actor":"beardedeagle","issue_id":"ollm-dnl","extra":{"field":"status","new_value":"closed","old_value":"open","reason":"Completed"}} diff --git a/README.md b/README.md index 2290b43..e2898b2 100644 --- a/README.md +++ b/README.md @@ -160,12 +160,24 @@ full-history KV in memory. When the bounded the recent-context token budget and oldest tokens are evicted once the window is exceeded. -On optimized-native decoder-only text runtimes, long prompts are ingested -through bounded prefill chunks before the final decode step. That keeps prompt -execution from growing one full prompt-wide activation step at a time on very -long inputs while preserving the external prompt/chat contract. Prompt-scaling -benchmarks remain the right place to evaluate the TTFT and memory tradeoff on -target hardware. +On the causal runtime lanes that support chunked prefill, long prompts are +ingested through bounded prefill chunks before the final decode step. The +current strategy lanes are: + +- `optimized-native-text` +- `optimized-native-multimodal` +- `transformers-generic-text` +- `transformers-generic-multimodal` +- `transformers-generic-seq2seq-source` + +That keeps prompt execution from growing one full prompt-wide activation step +at a time on very long inputs while preserving the external prompt/chat +contract. Prompt-scaling benchmarks remain the right place to evaluate the +TTFT and memory tradeoff on target hardware. Prompt tokenization now streams +from rendered prompt pieces inside the strategy path, prefix attention masks +are synthesized lazily per chunk, and seq2seq source prompts use the dedicated +`transformers-generic-seq2seq-source` lane instead of pretending they share the +causal-cache contract. Configuration layering uses an explicit precedence contract: diff --git a/docs/benchmarking.md b/docs/benchmarking.md index 353f262..288604d 100644 --- a/docs/benchmarking.md +++ b/docs/benchmarking.md @@ -179,7 +179,8 @@ Interpretation notes: - output throughput is generated output tokens divided by total generation latency - peak RSS includes a source label; long-lived warm/scaling/session probes use stage-local sampled peaks instead of process-lifetime peaks - allocator-gap metrics are reported as reserved-minus-allocated style slack when the backend exposes the required counters; unsupported backends serialize them as `null` -- optimized-native decoder-only prompt-scaling runs exercise bounded chunked prefill on long text prompts, so the prompt-length sweep is the intended place to inspect the memory versus TTFT tradeoff for this feature +- text prompt-scaling runs exercise bounded chunked prefill on the supported text strategy lanes, so the prompt-length sweep is the intended place to inspect the memory versus TTFT tradeoff for this feature +- request metrics also include a `chunked_prefill` section that states the selected strategy ID, whether the active runtime was eligible, whether the strategy actually ran, and the implemented execution boundary for streamed prompt tokenization plus lazy prefix-mask synthesis On loader-streamed families such as optimized Gemma3 on CPU, a long per-turn session-growth response can become dominated by repeated safetensor layer reads diff --git a/docs/guides/optimization.md b/docs/guides/optimization.md index 40011b4..1bcf79f 100644 --- a/docs/guides/optimization.md +++ b/docs/guides/optimization.md @@ -15,10 +15,19 @@ Native families: - `gpt-oss` - `voxtral` -Optimized-native decoder-only text prompts use bounded chunked prefill for -long prompt ingestion before the final decode step. This is a memory-control -path, not a blanket latency optimization, so prompt-scaling benchmarks are the -truthful way to evaluate whether the chunking tradeoff helps on a given host. +Supported causal runtime lanes use bounded chunked prefill for long prompt +ingestion before the final decode step. The current lanes are +`optimized-native-text`, `optimized-native-multimodal`, +`transformers-generic-text`, and `transformers-generic-multimodal`. +Encoder-decoder source prompts use the dedicated +`transformers-generic-seq2seq-source` lane. +This is a memory-control path, not a blanket latency optimization, so +prompt-scaling benchmarks are the truthful way to evaluate whether the chunking +tradeoff helps on a given host. + +Prompt tokenization now streams from rendered prompt pieces inside these +strategy handlers, and causal lanes synthesize prefix attention masks lazily +per chunk instead of materializing the full mask before ingestion starts. ### Transformers-generic Used for compatible local or materialized models that can run through the generic Transformers-backed path. diff --git a/src/ollm/runtime/benchmark/chunked_prefill_serialization.py b/src/ollm/runtime/benchmark/chunked_prefill_serialization.py new file mode 100644 index 0000000..85b8a43 --- /dev/null +++ b/src/ollm/runtime/benchmark/chunked_prefill_serialization.py @@ -0,0 +1,73 @@ +"""Chunked-prefill benchmark JSON parsing helpers.""" + +from collections.abc import Mapping +from typing import cast + +from ollm.runtime.chunked_prefill import ( + ChunkedPrefillAttentionMaskMode, + ChunkedPrefillExecutionBoundary, + ChunkedPrefillGapDecision, + ChunkedPrefillGapId, + ChunkedPrefillRecommendation, + ChunkedPrefillScopeSurface, + ChunkedPrefillStrategyId, +) + + +def parse_chunked_prefill( + value: object, + *, + require_bool, + require_object_mapping, + require_sequence, + require_string, +) -> ChunkedPrefillScopeSurface: + if not isinstance(value, Mapping): + raise ValueError("chunked_prefill must be an object") + payload = cast(Mapping[str, object], value) + gap_items = require_sequence(payload, "gap_inventory") + return ChunkedPrefillScopeSurface( + strategy_id=_optional_strategy_id(payload, require_string=require_string), + runtime_eligible=require_bool(payload, "runtime_eligible"), + applied=require_bool(payload, "applied"), + activation_reason=require_string(payload, "activation_reason"), + execution_boundary=ChunkedPrefillExecutionBoundary( + require_string(payload, "execution_boundary") + ), + attention_mask_mode=ChunkedPrefillAttentionMaskMode( + require_string(payload, "attention_mask_mode") + ), + gap_inventory=tuple( + parse_chunked_prefill_gap( + require_object_mapping(item, f"gap_inventory[{index}]"), + require_string=require_string, + ) + for index, item in enumerate(gap_items) + ), + ) + + +def parse_chunked_prefill_gap( + payload: Mapping[str, object], + *, + require_string, +) -> ChunkedPrefillGapDecision: + return ChunkedPrefillGapDecision( + gap_id=ChunkedPrefillGapId(require_string(payload, "gap_id")), + current_behavior=require_string(payload, "current_behavior"), + recommendation=ChunkedPrefillRecommendation( + require_string(payload, "recommendation") + ), + rationale=require_string(payload, "rationale"), + ) + + +def _optional_strategy_id( + payload: Mapping[str, object], + *, + require_string, +) -> ChunkedPrefillStrategyId | None: + value = payload.get("strategy_id") + if value is None: + return None + return ChunkedPrefillStrategyId(require_string(payload, "strategy_id")) diff --git a/src/ollm/runtime/benchmark/details.py b/src/ollm/runtime/benchmark/details.py index a3f7990..ec9fcb9 100644 --- a/src/ollm/runtime/benchmark/details.py +++ b/src/ollm/runtime/benchmark/details.py @@ -36,7 +36,6 @@ def build_cold_probe_details( def summarize_request_metrics(samples: list[RequestProbeMetrics]) -> dict[str, object]: """Summarize request-level runtime probe metrics.""" - from ollm.runtime.benchmark.offload_summary import summarize_request_offload return { @@ -82,6 +81,7 @@ def summarize_request_metrics(samples: list[RequestProbeMetrics]) -> dict[str, o ] ), }, + "chunked_prefill": samples[-1].chunked_prefill.to_dict(), "memory": summarize_stage_resources([sample.resources for sample in samples]), "cache": { "cache_mode": single_optional_string( diff --git a/src/ollm/runtime/benchmark/probe_execution.py b/src/ollm/runtime/benchmark/probe_execution.py index 2267fad..2d72bd2 100644 --- a/src/ollm/runtime/benchmark/probe_execution.py +++ b/src/ollm/runtime/benchmark/probe_execution.py @@ -217,6 +217,7 @@ def execute_request_probe( kv_cache_adaptation=kv_cache_adaptation, cache_dir_size_mb=cache_dir_size, cache_state=cache_state, + chunked_prefill=trace.chunked_prefill, allocator_gap_mb=allocator_gap_mb, allocator_gap_ratio=allocator_gap_ratio, native_runtime_profile=native_runtime_profile, diff --git a/src/ollm/runtime/benchmark/probe_serialization.py b/src/ollm/runtime/benchmark/probe_serialization.py index 3765df1..4e501d4 100644 --- a/src/ollm/runtime/benchmark/probe_serialization.py +++ b/src/ollm/runtime/benchmark/probe_serialization.py @@ -6,6 +6,7 @@ from ollm.kv_cache.matrix import KVCacheAdaptationSurface from ollm.kv_cache.state import KVCacheStateSnapshot +from ollm.runtime.benchmark.chunked_prefill_serialization import parse_chunked_prefill from ollm.runtime.benchmark.probe_types import ( EventTimingSummary, NativeRuntimeProfile, @@ -244,6 +245,13 @@ def _parse_request_probe_metrics(payload: Mapping[str, object]) -> RequestProbeM ), cache_dir_size_mb=_optional_float(payload, "cache_dir_size_mb"), cache_state=_parse_cache_state(payload.get("cache_state")), + chunked_prefill=parse_chunked_prefill( + payload.get("chunked_prefill"), + require_bool=_require_bool, + require_object_mapping=_require_object_mapping, + require_sequence=_require_sequence, + require_string=_require_string, + ), allocator_gap_mb=_optional_float(payload, "allocator_gap_mb"), allocator_gap_ratio=_optional_float(payload, "allocator_gap_ratio"), native_runtime_profile=_parse_native_runtime_profile( @@ -447,6 +455,13 @@ def _optional_int(payload: Mapping[str, object], key: str) -> int | None: raise ValueError(f"probe field '{key}' must be an integer or null") +def _require_bool(payload: Mapping[str, object], key: str) -> bool: + value = payload.get(key) + if isinstance(value, bool): + return value + raise ValueError(f"probe field '{key}' must be a boolean") + + def _require_string(payload: Mapping[str, object], key: str) -> str: value = payload.get(key) if isinstance(value, str): diff --git a/src/ollm/runtime/benchmark/probe_types.py b/src/ollm/runtime/benchmark/probe_types.py index 6562693..90fd5a1 100644 --- a/src/ollm/runtime/benchmark/probe_types.py +++ b/src/ollm/runtime/benchmark/probe_types.py @@ -5,6 +5,7 @@ from ollm.kv_cache.matrix import KVCacheAdaptationSurface from ollm.kv_cache.state import KVCacheStateSnapshot from ollm.runtime.benchmark.resources import StageResourceSnapshot +from ollm.runtime.chunked_prefill import ChunkedPrefillScopeSurface @dataclass(frozen=True, slots=True) @@ -50,6 +51,7 @@ class RequestProbeMetrics: kv_cache_adaptation: KVCacheAdaptationSurface | None cache_dir_size_mb: float | None cache_state: KVCacheStateSnapshot | None + chunked_prefill: ChunkedPrefillScopeSurface allocator_gap_mb: float | None allocator_gap_ratio: float | None native_runtime_profile: NativeRuntimeProfile | None @@ -71,6 +73,7 @@ def to_dict(self) -> dict[str, object]: payload["cache_state"] = ( None if self.cache_state is None else self.cache_state.to_dict() ) + payload["chunked_prefill"] = self.chunked_prefill.to_dict() payload["kv_cache_adaptation"] = ( None if self.kv_cache_adaptation is None diff --git a/src/ollm/runtime/chunked_prefill.py b/src/ollm/runtime/chunked_prefill.py new file mode 100644 index 0000000..3f6a8c3 --- /dev/null +++ b/src/ollm/runtime/chunked_prefill.py @@ -0,0 +1,467 @@ +"""Chunked prompt-ingestion strategies for runtime generation.""" + +from collections.abc import Callable +from dataclasses import asdict, dataclass, replace +from enum import StrEnum +from typing import Self + +from ollm.app.types import Message +from ollm.runtime.capability_discovery import GenericModelKind +from ollm.runtime.chunked_prefill_support import ( + build_forward_input_filter, + ones_attention_mask, + prepare_static_inputs, + prompt_token_id_pieces, + render_prompt_text, + resolve_stream_tokenizer, + run_causal_prefill_chunk, + token_tensor, +) +from ollm.runtime.chunked_prefill_support import ( + prompt_token_count as count_prompt_tokens, +) +from ollm.runtime.errors import PromptExecutionError +from ollm.runtime.loaded_runtime import LoadedRuntime + + +class ChunkedPrefillStrategyId(StrEnum): + OPTIMIZED_NATIVE_TEXT = "optimized-native-text" + OPTIMIZED_NATIVE_MULTIMODAL = "optimized-native-multimodal" + TRANSFORMERS_GENERIC_TEXT = "transformers-generic-text" + TRANSFORMERS_GENERIC_MULTIMODAL = "transformers-generic-multimodal" + TRANSFORMERS_GENERIC_SEQ2SEQ_SOURCE = "transformers-generic-seq2seq-source" + + +class ChunkedPrefillGapId(StrEnum): + PROMPT_TOKENIZATION_BEFORE_PREFILL = "prompt-tokenization-before-prefill" + FULL_ATTENTION_MASK_BEFORE_PREFILL = "full-attention-mask-before-prefill" + SEQ2SEQ_SOURCE_PREFILL = "seq2seq-source-prefill" + + +class ChunkedPrefillRecommendation(StrEnum): + IMPLEMENT = "implement" + DEFER = "defer" + REJECT = "reject" + + +class ChunkedPrefillExecutionBoundary(StrEnum): + STREAMED_PROMPT_PREPARATION = "streamed-prompt-preparation" + + +class ChunkedPrefillAttentionMaskMode(StrEnum): + LAZY_PREFIX_SYNTHESIS = "lazy-prefix-synthesis" + + +@dataclass(frozen=True, slots=True) +class ChunkedPrefillGapDecision: + gap_id: ChunkedPrefillGapId + current_behavior: str + recommendation: ChunkedPrefillRecommendation + rationale: str + + def to_dict(self) -> dict[str, str]: + payload = asdict(self) + return {key: str(value) for key, value in payload.items()} + + +@dataclass(frozen=True, slots=True) +class ChunkedPrefillScopeSurface: + strategy_id: ChunkedPrefillStrategyId | None + runtime_eligible: bool + applied: bool + activation_reason: str + execution_boundary: ChunkedPrefillExecutionBoundary + attention_mask_mode: ChunkedPrefillAttentionMaskMode + gap_inventory: tuple[ChunkedPrefillGapDecision, ...] + + def with_activation(self, *, applied: bool, activation_reason: str) -> Self: + return replace( + self, + applied=applied, + activation_reason=activation_reason, + ) + + def to_dict(self) -> dict[str, object]: + return { + "strategy_id": None if self.strategy_id is None else self.strategy_id.value, + "runtime_eligible": self.runtime_eligible, + "applied": self.applied, + "activation_reason": self.activation_reason, + "execution_boundary": self.execution_boundary.value, + "attention_mask_mode": self.attention_mask_mode.value, + "gap_inventory": [decision.to_dict() for decision in self.gap_inventory], + } + + +@dataclass(frozen=True, slots=True) +class PreparedChunkedPrefill: + inputs: dict[str, object] + generate_kwargs: dict[str, object] + scope: ChunkedPrefillScopeSurface + prompt_token_count: int + + +@dataclass(frozen=True, slots=True) +class ChunkedPrefillStrategy: + strategy_id: ChunkedPrefillStrategyId + matches: Callable[[LoadedRuntime, GenericModelKind | None], bool] + prepare: Callable[ + [LoadedRuntime, list[Message], dict[str, object], int], + PreparedChunkedPrefill, + ] + + +_CHUNKED_PREFILL_GAP_INVENTORY = ( + ChunkedPrefillGapDecision( + gap_id=ChunkedPrefillGapId.PROMPT_TOKENIZATION_BEFORE_PREFILL, + current_behavior=( + "Supported strategies render the prompt template once, then tokenize " + "prompt pieces incrementally during strategy execution." + ), + recommendation=ChunkedPrefillRecommendation.IMPLEMENT, + rationale=( + "Prompt tokenization no longer has to complete as one full prompt-wide " + "step before the chunked strategy begins." + ), + ), + ChunkedPrefillGapDecision( + gap_id=ChunkedPrefillGapId.FULL_ATTENTION_MASK_BEFORE_PREFILL, + current_behavior=( + "Causal chunked strategies synthesize prefix attention masks per " + "chunk and materialize the full mask only at the final generate " + "handoff when the runtime still requires it." + ), + recommendation=ChunkedPrefillRecommendation.IMPLEMENT, + rationale=( + "Full prompt attention masks are no longer built before chunked " + "prompt ingestion begins." + ), + ), + ChunkedPrefillGapDecision( + gap_id=ChunkedPrefillGapId.SEQ2SEQ_SOURCE_PREFILL, + current_behavior=( + "Seq2seq source prompts now use a dedicated streamed source-ingestion " + "strategy instead of pretending they share the causal-cache prefill " + "contract." + ), + recommendation=ChunkedPrefillRecommendation.IMPLEMENT, + rationale=( + "Seq2seq now has its own explicit strategy lane rather than being left " + "unsupported." + ), + ), +) + + +def prepare_chunked_prefill( + *, + runtime: LoadedRuntime, + messages: list[Message], + generate_kwargs: dict[str, object], + chunk_tokens: int, + eager_input_builder: Callable[[LoadedRuntime, list[Message]], dict[str, object]], +) -> PreparedChunkedPrefill: + if chunk_tokens < 1: + raise ValueError("chunk_tokens must be at least 1") + runtime_kind = ( + runtime.plan.generic_model_kind or runtime.resolved_model.generic_model_kind + ) + strategy = _resolve_strategy(runtime, runtime_kind) + if strategy is None: + inputs = eager_input_builder(runtime, messages) + return PreparedChunkedPrefill( + inputs=inputs, + generate_kwargs=generate_kwargs, + scope=_scope( + strategy_id=None, + runtime_eligible=False, + activation_reason=( + "No chunked prompt-ingestion strategy matched the active runtime." + ), + ), + prompt_token_count=count_prompt_tokens(inputs), + ) + if ( + strategy.strategy_id + is not ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_SEQ2SEQ_SOURCE + and not callable(getattr(runtime.model, "forward", None)) + ): + inputs = eager_input_builder(runtime, messages) + return PreparedChunkedPrefill( + inputs=inputs, + generate_kwargs=generate_kwargs, + scope=_scope( + strategy_id=strategy.strategy_id, + runtime_eligible=False, + activation_reason=( + "Chunked prompt-ingestion strategy requires a callable forward method." + ), + ), + prompt_token_count=count_prompt_tokens(inputs), + ) + return strategy.prepare(runtime, messages, generate_kwargs, chunk_tokens) + + +def chunked_prefill_gap_inventory() -> tuple[ChunkedPrefillGapDecision, ...]: + return _CHUNKED_PREFILL_GAP_INVENTORY + + +def _prepare_streamed_causal_strategy( + runtime: LoadedRuntime, + messages: list[Message], + generate_kwargs: dict[str, object], + chunk_tokens: int, + *, + strategy_id: ChunkedPrefillStrategyId, +) -> PreparedChunkedPrefill: + rendered_prompt = render_prompt_text(runtime, messages) + static_inputs = prepare_static_inputs(runtime, messages) + total_prompt_token_count = 0 + deferred_tokens: list[int] = [] + prefilled_token_count = 0 + prefill_cache = generate_kwargs.get("past_key_values") + forward_method = getattr(runtime.model, "forward", None) + forward_input_filter = ( + None + if not callable(forward_method) + else build_forward_input_filter(forward_method) + ) + + for token_piece in prompt_token_id_pieces( + resolve_stream_tokenizer(runtime), + rendered_prompt, + ): + total_prompt_token_count += len(token_piece) + deferred_tokens.extend(token_piece) + while len(deferred_tokens) > chunk_tokens + 1: + if not callable(forward_method): + raise PromptExecutionError( + f"Chunked prompt-ingestion strategy {strategy_id.value!r} requires a callable forward method." + ) + assert forward_input_filter is not None + prefill_cache = run_causal_prefill_chunk( + runtime=runtime, + forward_method=forward_method, + forward_input_filter=forward_input_filter, + static_inputs=static_inputs, + chunk_ids=deferred_tokens[:chunk_tokens], + prefill_cache=prefill_cache, + prefix_token_count=prefilled_token_count, + strategy_label=strategy_id.value, + ) + del deferred_tokens[:chunk_tokens] + prefilled_token_count += chunk_tokens + + if total_prompt_token_count - 1 > chunk_tokens: + while len(deferred_tokens) > 1: + if not callable(forward_method): + raise PromptExecutionError( + f"Chunked prompt-ingestion strategy {strategy_id.value!r} requires a callable forward method." + ) + chunk_size = min(chunk_tokens, len(deferred_tokens) - 1) + assert forward_input_filter is not None + prefill_cache = run_causal_prefill_chunk( + runtime=runtime, + forward_method=forward_method, + forward_input_filter=forward_input_filter, + static_inputs=static_inputs, + chunk_ids=deferred_tokens[:chunk_size], + prefill_cache=prefill_cache, + prefix_token_count=prefilled_token_count, + strategy_label=strategy_id.value, + ) + del deferred_tokens[:chunk_size] + prefilled_token_count += chunk_size + + if total_prompt_token_count == 0: + raise PromptExecutionError( + "Chunked prompt ingestion produced no prompt tokens." + ) + + final_inputs = dict(static_inputs) + final_generate_kwargs = dict(generate_kwargs) + if prefilled_token_count > 0: + final_inputs["input_ids"] = token_tensor(deferred_tokens, device=runtime.device) + final_inputs["attention_mask"] = ones_attention_mask( + token_count=total_prompt_token_count, + device=runtime.device, + ) + final_generate_kwargs["past_key_values"] = prefill_cache + scope = _scope( + strategy_id=strategy_id, + runtime_eligible=True, + activation_reason="Bounded chunked prefill ran before final decode.", + ).with_activation( + applied=True, + activation_reason="Bounded chunked prefill ran before final decode.", + ) + else: + final_inputs["input_ids"] = token_tensor(deferred_tokens, device=runtime.device) + final_inputs["attention_mask"] = ones_attention_mask( + token_count=total_prompt_token_count, + device=runtime.device, + ) + scope = _scope( + strategy_id=strategy_id, + runtime_eligible=True, + activation_reason=( + "Prompt length does not exceed the chunked-prefill threshold." + ), + ) + return PreparedChunkedPrefill( + inputs=final_inputs, + generate_kwargs=final_generate_kwargs, + scope=scope, + prompt_token_count=total_prompt_token_count, + ) + + +def _prepare_seq2seq_source_strategy( + runtime: LoadedRuntime, + messages: list[Message], + generate_kwargs: dict[str, object], + chunk_tokens: int, +) -> PreparedChunkedPrefill: + rendered_prompt = render_prompt_text(runtime, messages) + prompt_tokens = [ + token_id + for token_piece in prompt_token_id_pieces( + resolve_stream_tokenizer(runtime), + rendered_prompt, + ) + for token_id in token_piece + ] + if not prompt_tokens: + raise PromptExecutionError( + "Seq2seq source ingestion produced no prompt tokens." + ) + strategy_id = ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_SEQ2SEQ_SOURCE + applied = len(prompt_tokens) > chunk_tokens + activation_reason = ( + "Streamed seq2seq source tokens were built incrementally before encoder generation." + if applied + else "Prompt length does not exceed the streamed seq2seq source threshold." + ) + return PreparedChunkedPrefill( + inputs={ + "input_ids": token_tensor(prompt_tokens, device=runtime.device), + "attention_mask": ones_attention_mask( + token_count=len(prompt_tokens), + device=runtime.device, + ), + }, + generate_kwargs=generate_kwargs, + scope=_scope( + strategy_id=strategy_id, + runtime_eligible=True, + activation_reason=activation_reason, + ).with_activation(applied=applied, activation_reason=activation_reason), + prompt_token_count=len(prompt_tokens), + ) + + +_CHUNKED_PREFILL_STRATEGIES = ( + ChunkedPrefillStrategy( + strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT, + matches=lambda runtime, runtime_kind: ( + runtime.plan.backend_id == "optimized-native" + and runtime.processor is None + and runtime_kind is GenericModelKind.CAUSAL_LM + ), + prepare=lambda runtime, messages, generate_kwargs, chunk_tokens: ( + _prepare_streamed_causal_strategy( + runtime, + messages, + generate_kwargs, + chunk_tokens, + strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT, + ) + ), + ), + ChunkedPrefillStrategy( + strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL, + matches=lambda runtime, runtime_kind: ( + runtime.plan.backend_id == "optimized-native" + and runtime.processor is not None + and runtime_kind is not GenericModelKind.SEQ2SEQ_LM + ), + prepare=lambda runtime, messages, generate_kwargs, chunk_tokens: ( + _prepare_streamed_causal_strategy( + runtime, + messages, + generate_kwargs, + chunk_tokens, + strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL, + ) + ), + ), + ChunkedPrefillStrategy( + strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_TEXT, + matches=lambda runtime, runtime_kind: ( + runtime.plan.backend_id == "transformers-generic" + and runtime.processor is None + and runtime_kind is GenericModelKind.CAUSAL_LM + ), + prepare=lambda runtime, messages, generate_kwargs, chunk_tokens: ( + _prepare_streamed_causal_strategy( + runtime, + messages, + generate_kwargs, + chunk_tokens, + strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_TEXT, + ) + ), + ), + ChunkedPrefillStrategy( + strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_MULTIMODAL, + matches=lambda runtime, runtime_kind: ( + runtime.plan.backend_id == "transformers-generic" + and runtime.processor is not None + and runtime_kind is not GenericModelKind.SEQ2SEQ_LM + ), + prepare=lambda runtime, messages, generate_kwargs, chunk_tokens: ( + _prepare_streamed_causal_strategy( + runtime, + messages, + generate_kwargs, + chunk_tokens, + strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_MULTIMODAL, + ) + ), + ), + ChunkedPrefillStrategy( + strategy_id=ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_SEQ2SEQ_SOURCE, + matches=lambda runtime, runtime_kind: ( + runtime.plan.backend_id == "transformers-generic" + and runtime_kind is GenericModelKind.SEQ2SEQ_LM + ), + prepare=_prepare_seq2seq_source_strategy, + ), +) + + +def _resolve_strategy( + runtime: LoadedRuntime, + runtime_kind: GenericModelKind | None, +) -> ChunkedPrefillStrategy | None: + for strategy in _CHUNKED_PREFILL_STRATEGIES: + if strategy.matches(runtime, runtime_kind): + return strategy + return None + + +def _scope( + *, + strategy_id: ChunkedPrefillStrategyId | None, + runtime_eligible: bool, + activation_reason: str, +) -> ChunkedPrefillScopeSurface: + return ChunkedPrefillScopeSurface( + strategy_id=strategy_id, + runtime_eligible=runtime_eligible, + applied=False, + activation_reason=activation_reason, + execution_boundary=ChunkedPrefillExecutionBoundary.STREAMED_PROMPT_PREPARATION, + attention_mask_mode=ChunkedPrefillAttentionMaskMode.LAZY_PREFIX_SYNTHESIS, + gap_inventory=_CHUNKED_PREFILL_GAP_INVENTORY, + ) diff --git a/src/ollm/runtime/chunked_prefill_support.py b/src/ollm/runtime/chunked_prefill_support.py new file mode 100644 index 0000000..8773e4b --- /dev/null +++ b/src/ollm/runtime/chunked_prefill_support.py @@ -0,0 +1,301 @@ +"""Support helpers for chunked prompt-ingestion strategies.""" + +import re +from collections.abc import Callable, Iterable, Mapping +from inspect import Parameter, signature + +import torch + +from ollm.app.types import ContentKind, Message +from ollm.runtime.errors import PromptExecutionError +from ollm.runtime.generation_support import render_plain_prompt +from ollm.runtime.loaded_runtime import LoadedRuntime +from ollm.runtime.output_control import suppress_module_prints + + +def render_prompt_text(runtime: LoadedRuntime, messages: list[Message]) -> str: + transformers_messages = [ + message.as_transformers_message( + structured_content=runtime.processor is not None + ) + for message in messages + ] + if runtime.processor is not None and hasattr( + runtime.processor, "apply_chat_template" + ): + try: + rendered = runtime.processor.apply_chat_template( + transformers_messages, + add_generation_prompt=True, + tokenize=False, + return_dict=False, + return_tensors=None, + ) + if isinstance(rendered, str): + return rendered + except (TypeError, ValueError, AttributeError): + pass + if hasattr(runtime.tokenizer, "apply_chat_template"): + try: + rendered = runtime.tokenizer.apply_chat_template( + transformers_messages, + tokenize=False, + add_generation_prompt=True, + return_dict=False, + ) + if isinstance(rendered, str): + return rendered + except (TypeError, ValueError, AttributeError): + pass + return render_plain_prompt(messages) + + +def resolve_stream_tokenizer(runtime: LoadedRuntime): + if ( + runtime.processor is not None + and getattr(runtime.processor, "tokenizer", None) is not None + ): + return runtime.processor.tokenizer + return runtime.tokenizer + + +def prepare_static_inputs( + runtime: LoadedRuntime, + messages: list[Message], +) -> dict[str, object]: + if runtime.processor is None: + return {} + custom_builder = getattr( + runtime.processor, "prepare_chunked_prefill_static_inputs", None + ) + if callable(custom_builder): + prepared = custom_builder(messages, runtime.device) + return dict(prepared) + image_values = [ + part.value + for message in messages + for part in message.content + if part.kind is ContentKind.IMAGE + ] + audio_values = [ + part.value + for message in messages + for part in message.content + if part.kind is ContentKind.AUDIO + ] + if not image_values and not audio_values: + return {} + prepared = call_processor_for_static_inputs( + processor=runtime.processor, + image_values=image_values, + audio_values=audio_values, + device=runtime.device, + ) + prepared.pop("input_ids", None) + prepared.pop("attention_mask", None) + prepared.pop("token_type_ids", None) + return prepared + + +def call_processor_for_static_inputs( + *, + processor, + image_values: list[str], + audio_values: list[str], + device: torch.device, +) -> dict[str, object]: + try: + processor_signature = signature(processor.__call__) + except (TypeError, ValueError): + processor_signature = None + accepts_kwargs = True + else: + accepts_kwargs = any( + parameter.kind is Parameter.VAR_KEYWORD + for parameter in processor_signature.parameters.values() + ) + kwargs: dict[str, object] = {} + if accepts_kwargs or ( + processor_signature is not None + and "return_tensors" in processor_signature.parameters + ): + kwargs["return_tensors"] = "pt" + if image_values and ( + accepts_kwargs + or ( + processor_signature is not None + and "images" in processor_signature.parameters + ) + ): + kwargs["images"] = image_values + if audio_values: + if accepts_kwargs or ( + processor_signature is not None + and "audios" in processor_signature.parameters + ): + kwargs["audios"] = audio_values + elif ( + processor_signature is not None + and "audio" in processor_signature.parameters + ): + kwargs["audio"] = audio_values + if accepts_kwargs or ( + processor_signature is not None and "text" in processor_signature.parameters + ): + kwargs["text"] = [""] + prepared = processor(**kwargs) + return move_input_mapping(prepared, device) + + +def move_input_mapping(value: object, device: torch.device) -> dict[str, object]: + to_method = getattr(value, "to", None) + if callable(to_method): + moved = to_method(device) + if isinstance(moved, Mapping): + return dict(moved) + if isinstance(value, Mapping): + result: dict[str, object] = {} + for key, item in value.items(): + if isinstance(item, torch.Tensor): + result[str(key)] = item.to(device) + else: + result[str(key)] = item + return result + raise PromptExecutionError("Chunked prompt ingestion expected mapping-like inputs.") + + +def prompt_token_id_pieces(tokenizer, rendered_prompt: str) -> Iterable[list[int]]: + custom_streamer = getattr(tokenizer, "stream_tokenize_prompt", None) + if callable(custom_streamer): + for piece in custom_streamer(rendered_prompt): + ids = list(piece) + if ids: + yield ids + return + for piece_text in prompt_piece_texts(tokenizer, rendered_prompt): + piece_ids = tokenize_prompt_piece(tokenizer, piece_text) + if piece_ids: + yield piece_ids + + +def prompt_piece_texts(tokenizer, rendered_prompt: str) -> tuple[str, ...]: + backend_tokenizer = getattr(tokenizer, "backend_tokenizer", None) + pre_tokenizer = ( + None + if backend_tokenizer is None + else getattr(backend_tokenizer, "pre_tokenizer", None) + ) + if pre_tokenizer is not None and hasattr(pre_tokenizer, "pre_tokenize_str"): + pieces = pre_tokenizer.pre_tokenize_str(rendered_prompt) + if pieces: + return tuple(text for text, _offset in pieces if text) + regex_pieces = tuple( + match.group(0) + for match in re.finditer(r"\S+\s*|\s+", rendered_prompt) + if match.group(0) + ) + if regex_pieces: + return regex_pieces + return (rendered_prompt,) + + +def tokenize_prompt_piece(tokenizer, piece_text: str) -> list[int]: + try: + encoded = tokenizer( + piece_text, + add_special_tokens=False, + return_attention_mask=False, + ) + except TypeError: + encode_method = getattr(tokenizer, "encode", None) + if callable(encode_method): + encoded = encode_method(piece_text, add_special_tokens=False) + else: + encoded = tokenizer(piece_text) + if isinstance(encoded, dict): + input_ids = encoded.get("input_ids") + if isinstance(input_ids, list): + if input_ids and isinstance(input_ids[0], list): + return [int(token_id) for token_id in input_ids[0]] + return [int(token_id) for token_id in input_ids] + if isinstance(input_ids, torch.Tensor): + return [int(token_id) for token_id in input_ids.reshape(-1).tolist()] + if isinstance(encoded, list): + return [int(token_id) for token_id in encoded] + raise PromptExecutionError( + "Chunked prompt tokenization produced unsupported input ids." + ) + + +def run_causal_prefill_chunk( + *, + runtime: LoadedRuntime, + forward_method, + forward_input_filter: Callable[[dict[str, object]], dict[str, object]], + static_inputs: dict[str, object], + chunk_ids: list[int], + prefill_cache: object, + prefix_token_count: int, + strategy_label: str, +) -> object: + forward_inputs: dict[str, object] = dict(static_inputs) + forward_inputs["input_ids"] = token_tensor(chunk_ids, device=runtime.device) + forward_inputs["attention_mask"] = ones_attention_mask( + token_count=prefix_token_count + len(chunk_ids), + device=runtime.device, + ) + forward_inputs["use_cache"] = True + forward_inputs["cache_position"] = torch.arange( + prefix_token_count, + prefix_token_count + len(chunk_ids), + device=runtime.device, + dtype=torch.long, + ) + if prefill_cache is not None: + forward_inputs["past_key_values"] = prefill_cache + filtered_inputs = forward_input_filter(forward_inputs) + with torch.inference_mode(): + with suppress_module_prints(runtime.backend.print_suppression_modules): + outputs = forward_method(**filtered_inputs) + next_cache = getattr(outputs, "past_key_values", None) + if next_cache is None: + raise PromptExecutionError( + "Chunked prompt-ingestion strategy " + f"{strategy_label!r} requires a runtime that returns past_key_values." + ) + return next_cache + + +def build_forward_input_filter( + forward_method, +) -> Callable[[dict[str, object]], dict[str, object]]: + try: + method_signature = signature(forward_method) + except (TypeError, ValueError): + return lambda inputs: inputs + if any( + parameter.kind is Parameter.VAR_KEYWORD + for parameter in method_signature.parameters.values() + ): + return lambda inputs: inputs + supported_keys = frozenset(method_signature.parameters) + return lambda inputs: { + key: value for key, value in inputs.items() if key in supported_keys + } + + +def token_tensor(token_ids: list[int], *, device: torch.device) -> torch.Tensor: + return torch.tensor([token_ids], device=device, dtype=torch.long) + + +def ones_attention_mask(token_count: int, *, device: torch.device) -> torch.Tensor: + return torch.ones((1, token_count), device=device, dtype=torch.long) + + +def prompt_token_count(inputs: dict[str, object]) -> int: + input_ids = inputs.get("input_ids") + if not isinstance(input_ids, torch.Tensor): + raise PromptExecutionError( + "Chunked prompt preparation expected tensor-backed input_ids." + ) + return int(input_ids.shape[-1]) diff --git a/src/ollm/runtime/execution_trace.py b/src/ollm/runtime/execution_trace.py index e6cbf76..1f8ecb9 100644 --- a/src/ollm/runtime/execution_trace.py +++ b/src/ollm/runtime/execution_trace.py @@ -8,10 +8,9 @@ from ollm.app.types import PromptRequest from ollm.kv_cache.state import KVCacheStateSnapshot from ollm.runtime.capability_discovery import GenericModelKind -from ollm.runtime.errors import PromptExecutionError +from ollm.runtime.chunked_prefill import ChunkedPrefillScopeSurface from ollm.runtime.generation import ( build_runtime_generate_kwargs, - build_runtime_inputs, decode_runtime_response, prepare_runtime_generate_inputs, validate_runtime_request, @@ -35,6 +34,7 @@ class RuntimeExecutionTrace: output_token_count: int response_text: str cache_state: KVCacheStateSnapshot | None + chunked_prefill: ChunkedPrefillScopeSurface def execute_request_with_trace( @@ -49,19 +49,19 @@ def execute_request_with_trace( if request.generation_config.seed is not None: torch.manual_seed(request.generation_config.seed) - inputs = build_runtime_inputs(runtime, request.messages) - prompt_token_count = _count_prompt_tokens(inputs) generate_kwargs, generation_config = build_runtime_generate_kwargs( runtime, request, streamer ) - prepared_inputs = normalize_generate_inputs(inputs) generation_started_at = time.perf_counter() - prepared_inputs, prepared_generate_kwargs = prepare_runtime_generate_inputs( + prepared_result = prepare_runtime_generate_inputs( runtime, request, - prepared_inputs, generate_kwargs, ) + prompt_token_count = prepared_result.prompt_token_count + prepared_inputs = normalize_generate_inputs(prepared_result.inputs) + prepared_generate_kwargs = prepared_result.generate_kwargs + chunked_prefill = prepared_result.scope outputs, effective_generate_kwargs = _generate_outputs( runtime=runtime, prepared_inputs=prepared_inputs, @@ -90,6 +90,7 @@ def execute_request_with_trace( output_token_count=output_token_count, response_text=response_text, cache_state=cache_state, + chunked_prefill=chunked_prefill, ) @@ -142,13 +143,6 @@ def _run_model_generate( ) -def _count_prompt_tokens(inputs: dict[str, object]) -> int: - input_ids = inputs.get("input_ids") - if not isinstance(input_ids, torch.Tensor): - raise PromptExecutionError("Benchmark probe expected tensor-backed input_ids") - return int(input_ids.shape[0] if input_ids.ndim == 1 else input_ids.shape[-1]) - - def _decode_prefix_token_count(inputs: dict[str, object]) -> int: input_ids = require_tensor(inputs["input_ids"]) return int(input_ids.shape[-1]) diff --git a/src/ollm/runtime/generation.py b/src/ollm/runtime/generation.py index 9f487cb..ec277cb 100644 --- a/src/ollm/runtime/generation.py +++ b/src/ollm/runtime/generation.py @@ -13,6 +13,11 @@ from ollm.kv_cache.state import KVCacheStateSnapshot from ollm.runtime.capability_discovery import GenericModelKind from ollm.runtime.catalog import ModelModality +from ollm.runtime.chunked_prefill import ( + ChunkedPrefillScopeSurface, + PreparedChunkedPrefill, + prepare_chunked_prefill, +) from ollm.runtime.errors import PromptExecutionError from ollm.runtime.generation_config_support import ( clear_sampling_fields, @@ -182,75 +187,15 @@ def build_runtime_generate_kwargs( def prepare_runtime_generate_inputs( runtime: LoadedRuntime, request: PromptRequest, - inputs: dict[str, object], generate_kwargs: dict[str, object], -) -> tuple[dict[str, object], dict[str, object]]: - input_ids_value = inputs.get("input_ids") - if not isinstance(input_ids_value, torch.Tensor): - return inputs, generate_kwargs - if input_ids_value.ndim != 2 or input_ids_value.shape[0] != 1: - return inputs, generate_kwargs - if runtime.processor is not None: - return inputs, generate_kwargs - if runtime.plan.backend_id != "optimized-native": - return inputs, generate_kwargs - runtime_kind = ( - runtime.plan.generic_model_kind or runtime.resolved_model.generic_model_kind +) -> PreparedChunkedPrefill: + return prepare_chunked_prefill( + runtime=runtime, + messages=request.messages, + generate_kwargs=generate_kwargs, + chunk_tokens=DEFAULT_PREFILL_CHUNK_TOKENS, + eager_input_builder=build_runtime_inputs, ) - if runtime_kind is not GenericModelKind.CAUSAL_LM: - return inputs, generate_kwargs - prefill_token_count = input_ids_value.shape[1] - 1 - if prefill_token_count <= DEFAULT_PREFILL_CHUNK_TOKENS: - return inputs, generate_kwargs - return _run_chunked_prefill(runtime, inputs, generate_kwargs) - - -def _run_chunked_prefill( - runtime: LoadedRuntime, - inputs: dict[str, object], - generate_kwargs: dict[str, object], -) -> tuple[dict[str, object], dict[str, object]]: - forward_method = getattr(runtime.model, "forward", None) - if not callable(forward_method): - return inputs, generate_kwargs - input_ids = require_tensor(inputs["input_ids"]) - attention_mask_value = inputs.get("attention_mask") - attention_mask = ( - None if attention_mask_value is None else require_tensor(attention_mask_value) - ) - prefill_cache = generate_kwargs.get("past_key_values") - prefill_end = input_ids.shape[1] - 1 - with torch.inference_mode(): - with suppress_module_prints(runtime.backend.print_suppression_modules): - for chunk_start in range(0, prefill_end, DEFAULT_PREFILL_CHUNK_TOKENS): - chunk_end = min(prefill_end, chunk_start + DEFAULT_PREFILL_CHUNK_TOKENS) - forward_inputs: dict[str, object] = { - "input_ids": input_ids[:, chunk_start:chunk_end], - "use_cache": True, - "cache_position": torch.arange( - chunk_start, - chunk_end, - device=input_ids.device, - dtype=torch.long, - ), - } - if attention_mask is not None: - forward_inputs["attention_mask"] = attention_mask[:, :chunk_end] - if prefill_cache is not None: - forward_inputs["past_key_values"] = prefill_cache - outputs = forward_method(**forward_inputs) - prefill_cache = getattr(outputs, "past_key_values", None) - if prefill_cache is None: - raise PromptExecutionError( - "Chunked prefill requires a causal runtime that returns past_key_values." - ) - updated_inputs = dict(inputs) - updated_inputs["input_ids"] = input_ids[:, -1:] - if attention_mask is not None: - updated_inputs["attention_mask"] = attention_mask - updated_generate_kwargs = dict(generate_kwargs) - updated_generate_kwargs["past_key_values"] = prefill_cache - return updated_inputs, updated_generate_kwargs def decode_runtime_response( @@ -294,7 +239,6 @@ def execute( if request.generation_config.seed is not None: torch.manual_seed(request.generation_config.seed) - inputs = build_runtime_inputs(runtime, request.messages) streamer = None if request.generation_config.stream: streamer = BufferedTextStreamer( @@ -307,13 +251,14 @@ def execute( generate_kwargs, generation_config = build_runtime_generate_kwargs( runtime, request, streamer ) - filtered_inputs = normalize_generate_inputs(inputs) - filtered_inputs, generate_kwargs = prepare_runtime_generate_inputs( + prepared_inputs = prepare_runtime_generate_inputs( runtime, request, - filtered_inputs, generate_kwargs, ) + filtered_inputs = normalize_generate_inputs(prepared_inputs.inputs) + generate_kwargs = prepared_inputs.generate_kwargs + chunked_prefill = prepared_inputs.scope with torch.inference_mode(): with suppress_module_prints(runtime.backend.print_suppression_modules): @@ -332,7 +277,7 @@ def execute( if streamer is not None and not response_text.strip(): response_text = streamer.text assistant_message = Message.assistant_text(response_text) - metadata = self._plan_metadata(runtime, cache_state) + metadata = self._plan_metadata(runtime, cache_state, chunked_prefill) return PromptResponse( text=response_text, assistant_message=assistant_message, metadata=metadata ) @@ -341,7 +286,8 @@ def _finalize_response( self, runtime: LoadedRuntime, response: PromptResponse ) -> PromptResponse: metadata = dict(response.metadata) - metadata.update(self._plan_metadata(runtime, None)) + for key, value in self._plan_metadata(runtime, None, None).items(): + metadata.setdefault(key, value) return PromptResponse( text=response.text, assistant_message=response.assistant_message, @@ -352,6 +298,7 @@ def _plan_metadata( self, runtime: LoadedRuntime, cache_state: KVCacheStateSnapshot | None, + chunked_prefill: ChunkedPrefillScopeSurface | None, ) -> dict[str, str]: metadata = { "backend_id": runtime.plan.backend_id or "unknown", @@ -377,6 +324,30 @@ def _plan_metadata( ), "kv_cache_lifecycle": runtime.config.resolved_kv_cache_lifecycle(), "kv_cache_adaptation_mode": runtime.config.resolved_kv_cache_adaptation_mode(), + "chunked_prefill_strategy_id": ( + "" + if chunked_prefill is None or chunked_prefill.strategy_id is None + else chunked_prefill.strategy_id.value + ), + "chunked_prefill_runtime_eligible": str( + False if chunked_prefill is None else chunked_prefill.runtime_eligible + ).lower(), + "chunked_prefill_applied": str( + False if chunked_prefill is None else chunked_prefill.applied + ).lower(), + "chunked_prefill_activation_reason": ( + "" if chunked_prefill is None else chunked_prefill.activation_reason + ), + "chunked_prefill_execution_boundary": ( + "" + if chunked_prefill is None or chunked_prefill.strategy_id is None + else chunked_prefill.execution_boundary.value + ), + "chunked_prefill_attention_mask_mode": ( + "" + if chunked_prefill is None or chunked_prefill.strategy_id is None + else chunked_prefill.attention_mask_mode.value + ), } resolved_window_tokens = runtime.config.resolved_kv_cache_window_tokens() if resolved_window_tokens is not None: diff --git a/tests/benchmark_support.py b/tests/benchmark_support.py index adf9019..6cf843a 100644 --- a/tests/benchmark_support.py +++ b/tests/benchmark_support.py @@ -6,6 +6,13 @@ RequestProbeMetrics, ) from ollm.runtime.benchmark.resources import StageResourceSnapshot +from ollm.runtime.chunked_prefill import ( + ChunkedPrefillAttentionMaskMode, + ChunkedPrefillExecutionBoundary, + ChunkedPrefillScopeSurface, + ChunkedPrefillStrategyId, + chunked_prefill_gap_inventory, +) def build_stage_resources() -> StageResourceSnapshot: @@ -95,6 +102,17 @@ def build_request_probe_metrics() -> RequestProbeMetrics: evicted_tokens=0, cold_store_format=None, ), + chunked_prefill=ChunkedPrefillScopeSurface( + strategy_id=ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT, + runtime_eligible=True, + applied=True, + activation_reason="Bounded chunked prefill ran before final decode.", + execution_boundary=( + ChunkedPrefillExecutionBoundary.STREAMED_PROMPT_PREPARATION + ), + attention_mask_mode=ChunkedPrefillAttentionMaskMode.LAZY_PREFIX_SYNTHESIS, + gap_inventory=chunked_prefill_gap_inventory(), + ), allocator_gap_mb=20.0, allocator_gap_ratio=0.066667, native_runtime_profile=build_native_runtime_profile(), diff --git a/tests/test_benchmark_probe_execution.py b/tests/test_benchmark_probe_execution.py index 41a988c..70a776f 100644 --- a/tests/test_benchmark_probe_execution.py +++ b/tests/test_benchmark_probe_execution.py @@ -12,6 +12,7 @@ from ollm.runtime.capabilities import CapabilityProfile, SupportLevel from ollm.runtime.capability_discovery import GenericModelKind from ollm.runtime.catalog import ModelModality +from ollm.runtime.chunked_prefill import ChunkedPrefillStrategyId from ollm.runtime.config import GenerationConfig, RuntimeConfig from ollm.runtime.execution_trace import execute_request_with_trace from ollm.runtime.loaded_runtime import LoadedRuntime @@ -39,21 +40,34 @@ def to(self, device, dtype=None): return self +class BenchmarkProcessorTokenizer: + def stream_tokenize_prompt(self, rendered_prompt): + del rendered_prompt + return ([1, 2, 3],) + + class BenchmarkProcessor: def __init__(self): self.inputs = BenchmarkProcessorInputs() + self.tokenizer = BenchmarkProcessorTokenizer() def apply_chat_template( self, messages, add_generation_prompt, tokenize, - return_dict, - return_tensors, + return_dict=False, + return_tensors=None, ): - del messages, add_generation_prompt, tokenize, return_dict, return_tensors + del messages, add_generation_prompt, return_dict, return_tensors + if not tokenize: + return "rendered-long-prompt" return self.inputs + def prepare_chunked_prefill_static_inputs(self, messages, device): + del messages, device + return {} + def batch_decode(self, outputs, skip_special_tokens=False): del outputs, skip_special_tokens return ["decoded-benchmark"] @@ -152,6 +166,11 @@ def test_execute_request_probe_strips_processor_token_type_ids() -> None: assert execution.response_text == "decoded-benchmark" assert "token_type_ids" not in runtime.model.generate_kwargs + assert execution.metrics.chunked_prefill.strategy_id is ( + ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL + ) + assert execution.metrics.chunked_prefill.runtime_eligible is False + assert execution.metrics.chunked_prefill.applied is False def test_execute_request_probe_uses_chunked_prefill_for_long_causal_prompts( @@ -199,6 +218,9 @@ def test_execute_request_probe_uses_chunked_prefill_for_long_causal_prompts( execution = execute_request_probe(runtime=runtime, request=request) assert execution.response_text == "long-decoded" + assert execution.metrics.chunked_prefill.strategy_id is ( + ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT + ) assert len(model.forward_calls) == 2 generate_input_ids = model.generate_kwargs["input_ids"] assert isinstance(generate_input_ids, torch.Tensor) @@ -222,6 +244,11 @@ def test_execute_request_with_trace_reports_processor_counts() -> None: assert trace.decode_prefix_token_count == 3 assert trace.output_token_count == 1 assert trace.cache_state is None + assert trace.chunked_prefill.strategy_id is ( + ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL + ) + assert trace.chunked_prefill.runtime_eligible is False + assert trace.chunked_prefill.applied is False def test_execute_request_with_trace_tracks_chunked_prefill_prefix_length( @@ -272,6 +299,11 @@ def test_execute_request_with_trace_tracks_chunked_prefill_prefix_length( assert trace.prompt_token_count == 5 assert trace.decode_prefix_token_count == 1 assert trace.output_token_count == 4 + assert trace.chunked_prefill.strategy_id is ( + ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT + ) + assert trace.chunked_prefill.runtime_eligible is True + assert trace.chunked_prefill.applied is True assert len(model.forward_calls) == 2 @@ -325,9 +357,9 @@ def wrapped_perf_counter() -> float: "prepare_runtime_generate_inputs" ] - def wrapped_prepare(runtime, request, inputs, generate_kwargs): + def wrapped_prepare(runtime, request, generate_kwargs): order.append("prepare") - return original_prepare(runtime, request, inputs, generate_kwargs) + return original_prepare(runtime, request, generate_kwargs) monkeypatch.setattr( "ollm.runtime.execution_trace.time.perf_counter", wrapped_perf_counter diff --git a/tests/test_benchmark_reporting.py b/tests/test_benchmark_reporting.py index ced1e1f..d6dc458 100644 --- a/tests/test_benchmark_reporting.py +++ b/tests/test_benchmark_reporting.py @@ -33,6 +33,7 @@ parse_output_scaling_probe_result, parse_prompt_scaling_probe_result, parse_reopen_session_growth_probe_result, + parse_runtime_probe_result, parse_session_growth_probe_result, parse_warm_runtime_probe_result, render_output_scaling_probe_json, @@ -101,16 +102,27 @@ def test_render_runtime_probe_json_round_trips() -> None: request=build_request_probe_metrics(), ) - payload = json.loads(render_runtime_probe_json(probe)) + rendered = render_runtime_probe_json(probe) + payload = json.loads(rendered) + parsed = parse_runtime_probe_result(rendered) assert payload["load_ms"] == 10.0 + assert parsed.request.chunked_prefill.applied is True request = cast(dict[str, object], payload["request"]) resources = cast(dict[str, object], request["resources"]) native_runtime_profile = cast(dict[str, object], request["native_runtime_profile"]) cache_state = cast(dict[str, object], request["cache_state"]) + chunked_prefill = cast(dict[str, object], request["chunked_prefill"]) events = cast(dict[str, object], native_runtime_profile["events"]) assert request["output_tokens"] == 4 assert request["kv_cache_strategy"] == "chunked" + assert chunked_prefill["strategy_id"] == "optimized-native-text" + assert chunked_prefill["runtime_eligible"] is True + assert chunked_prefill["applied"] is True + assert chunked_prefill["execution_boundary"] == "streamed-prompt-preparation" + assert chunked_prefill["attention_mask_mode"] == "lazy-prefix-synthesis" + gap_inventory = cast(list[object], chunked_prefill["gap_inventory"]) + assert len(gap_inventory) == 3 adaptation = cast(dict[str, object], request["kv_cache_adaptation"]) assert adaptation["adaptation_mode"] == "observe-only" assert adaptation["recommendation_available"] is True @@ -237,11 +249,15 @@ def test_summarize_request_metrics_includes_native_runtime_profile() -> None: summary = summarize_request_metrics([build_request_probe_metrics()]) native_runtime_profile = cast(dict[str, object], summary["native_runtime_profile"]) + chunked_prefill = cast(dict[str, object], summary["chunked_prefill"]) assert native_runtime_profile["storage_paths"] == [ "disk-kv-cache", "safetensor-io", ] + assert chunked_prefill["strategy_id"] == "optimized-native-text" + assert chunked_prefill["runtime_eligible"] is True + assert chunked_prefill["applied"] is True assert cast(dict[str, object], summary["cache"])["kv_cache_strategy"] == "chunked" adaptation = cast( dict[str, object], diff --git a/tests/test_chunked_prefill_scope.py b/tests/test_chunked_prefill_scope.py new file mode 100644 index 0000000..24e2479 --- /dev/null +++ b/tests/test_chunked_prefill_scope.py @@ -0,0 +1,332 @@ +from dataclasses import replace + +import pytest +import torch +from transformers import T5Config +from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration + +from ollm.app.types import ContentPart, Message, MessageRole +from ollm.runtime.capabilities import CapabilityProfile, SupportLevel +from ollm.runtime.capability_discovery import GenericModelKind +from ollm.runtime.chunked_prefill import ( + ChunkedPrefillGapId, + ChunkedPrefillRecommendation, + ChunkedPrefillStrategyId, + prepare_chunked_prefill, +) +from ollm.runtime.chunked_prefill_support import ( + build_forward_input_filter, + call_processor_for_static_inputs, + render_prompt_text, + tokenize_prompt_piece, +) +from ollm.runtime.generation import ( + build_runtime_generate_kwargs, + prepare_runtime_generate_inputs, +) +from tests.test_runtime_executor import ( + build_request, + build_runtime_with_model, +) +from tests.test_runtime_executor_prefill import ( + ChunkedPrefillModel, + LongMappingTokenizer, +) + + +def test_prepare_runtime_generate_inputs_surfaces_chunked_prefill_scope( + monkeypatch, +) -> None: + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=ChunkedPrefillModel(), + ) + runtime.plan = replace( + runtime.plan, + backend_id="optimized-native", + generic_model_kind=GenericModelKind.CAUSAL_LM, + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + generate_kwargs, _generation_config = build_runtime_generate_kwargs( + runtime, + request, + streamer=None, + ) + prepared_result = prepare_runtime_generate_inputs(runtime, request, generate_kwargs) + chunked_prefill = prepared_result.scope + + assert chunked_prefill.runtime_eligible is True + assert chunked_prefill.applied is True + assert chunked_prefill.strategy_id is ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT + assert ( + chunked_prefill.activation_reason + == "Bounded chunked prefill ran before final decode." + ) + gap_inventory = { + decision.gap_id: decision.recommendation + for decision in chunked_prefill.gap_inventory + } + assert gap_inventory[ChunkedPrefillGapId.PROMPT_TOKENIZATION_BEFORE_PREFILL] is ( + ChunkedPrefillRecommendation.IMPLEMENT + ) + assert ( + gap_inventory[ChunkedPrefillGapId.FULL_ATTENTION_MASK_BEFORE_PREFILL] + is ChunkedPrefillRecommendation.IMPLEMENT + ) + assert gap_inventory[ChunkedPrefillGapId.SEQ2SEQ_SOURCE_PREFILL] is ( + ChunkedPrefillRecommendation.IMPLEMENT + ) + + +def test_prepare_runtime_generate_inputs_defers_seq2seq_source_prefill( + monkeypatch, +) -> None: + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=ChunkedPrefillModel(), + ) + runtime.plan = replace( + runtime.plan, + backend_id="transformers-generic", + generic_model_kind=GenericModelKind.SEQ2SEQ_LM, + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + generate_kwargs, _generation_config = build_runtime_generate_kwargs( + runtime, + request, + streamer=None, + ) + prepared_result = prepare_runtime_generate_inputs(runtime, request, generate_kwargs) + chunked_prefill = prepared_result.scope + + assert chunked_prefill.runtime_eligible is True + assert chunked_prefill.applied is True + assert ( + chunked_prefill.strategy_id + is ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_SEQ2SEQ_SOURCE + ) + assert ( + chunked_prefill.activation_reason + == "Streamed seq2seq source tokens were built incrementally before encoder generation." + ) + + +def test_prepare_runtime_generate_inputs_leaves_boundary_blank_without_strategy( + monkeypatch, +) -> None: + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=ChunkedPrefillModel(), + ) + runtime.plan = replace( + runtime.plan, + backend_id="custom-backend", + generic_model_kind=GenericModelKind.CAUSAL_LM, + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + generate_kwargs, _generation_config = build_runtime_generate_kwargs( + runtime, + request, + streamer=None, + ) + prepared_result = prepare_runtime_generate_inputs(runtime, request, generate_kwargs) + + assert prepared_result.scope.strategy_id is None + assert prepared_result.scope.runtime_eligible is False + + +def test_t5_encoder_does_not_expose_cacheable_source_prefill() -> None: + model = T5ForConditionalGeneration( + T5Config( + vocab_size=64, + d_model=32, + d_kv=8, + d_ff=64, + num_layers=2, + num_decoder_layers=2, + num_heads=4, + pad_token_id=0, + eos_token_id=1, + decoder_start_token_id=0, + ) + ) + model.eval() + input_ids = torch.tensor([[5, 6, 7]]) + attention_mask = torch.ones_like(input_ids) + + with torch.inference_mode(), pytest.raises(ValueError, match="used as a decoder"): + model.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=True, + return_dict=True, + ) + + +def test_build_forward_input_filter_falls_back_for_uninspectable_callable() -> None: + class UninspectableForward: + @property + def __signature__(self): + raise ValueError("no signature") + + def __call__(self, **kwargs): + return kwargs + + forward_filter = build_forward_input_filter(UninspectableForward()) + inputs: dict[str, object] = { + "input_ids": torch.tensor([[1, 2]]), + "attention_mask": torch.tensor([[1, 1]]), + "unexpected": "kept", + } + + assert forward_filter(inputs) is inputs + + +def test_build_forward_input_filter_inspects_signature_once() -> None: + class CountingForward: + signature_reads = 0 + + @property + def __signature__(self): + type(self).signature_reads += 1 + return None + + def __call__(self, *, input_ids, attention_mask): + return input_ids, attention_mask + + forward = CountingForward() + forward_filter = build_forward_input_filter(forward) + first = forward_filter( + { + "input_ids": torch.tensor([[1, 2]]), + "attention_mask": torch.tensor([[1, 1]]), + "cache_position": torch.tensor([0, 1]), + } + ) + second = forward_filter( + { + "input_ids": torch.tensor([[3, 4]]), + "attention_mask": torch.tensor([[1, 1]]), + "cache_position": torch.tensor([2, 3]), + } + ) + + assert CountingForward.signature_reads == 1 + assert set(first) == {"input_ids", "attention_mask"} + assert set(second) == {"input_ids", "attention_mask"} + + +def test_prepare_chunked_prefill_rejects_non_positive_chunk_budget() -> None: + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=ChunkedPrefillModel(), + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + + with pytest.raises(ValueError, match="chunk_tokens must be at least 1"): + prepare_chunked_prefill( + runtime=runtime, + messages=request.messages, + generate_kwargs={}, + chunk_tokens=0, + eager_input_builder=lambda runtime, messages: { + "input_ids": torch.tensor([[1]]) + }, + ) + + +def test_tokenize_prompt_piece_disables_special_tokens_in_fallback() -> None: + class EncodeOnlyTokenizer: + def encode(self, piece_text: str, *, add_special_tokens: bool = True): + assert piece_text == "piece" + assert add_special_tokens is False + return [7, 8] + + def __call__(self, piece_text: str): + raise TypeError(piece_text) + + assert tokenize_prompt_piece(EncodeOnlyTokenizer(), "piece") == [7, 8] + + +def test_render_prompt_text_falls_back_when_processor_signature_differs() -> None: + class FragileProcessor: + def apply_chat_template(self, messages, tokenize): + del messages, tokenize + raise TypeError("different signature") + + class FallbackTokenizer: + def apply_chat_template( + self, + messages, + tokenize, + add_generation_prompt, + return_tensors=None, + return_dict=False, + ): + del messages, add_generation_prompt, return_dict, return_tensors + if not tokenize: + return "tokenizer-rendered" + return { + "input_ids": torch.tensor([[1, 2, 3]]), + "attention_mask": torch.tensor([[1, 1, 1]]), + } + + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=FallbackTokenizer(), + model=ChunkedPrefillModel(), + ) + runtime.backend = replace( + runtime.backend, + processor=FragileProcessor(), + tokenizer=FallbackTokenizer(), + ) + + rendered = render_prompt_text( + runtime, + [Message(role=MessageRole.USER, content=[ContentPart.text("hello")])], + ) + + assert rendered == "tokenizer-rendered" + + +def test_call_processor_for_static_inputs_omits_return_tensors_when_unsupported() -> ( + None +): + class ReturnTensorsRejectingProcessor: + def __call__(self, *, images): + assert images == ["image.png"] + return {"pixel_values": torch.tensor([[[1.0]]])} + + prepared = call_processor_for_static_inputs( + processor=ReturnTensorsRejectingProcessor(), + image_values=["image.png"], + audio_values=[], + device=torch.device("cpu"), + ) + + pixel_values = prepared["pixel_values"] + assert isinstance(pixel_values, torch.Tensor) + assert torch.equal(pixel_values, torch.tensor([[[1.0]]])) diff --git a/tests/test_runtime_executor.py b/tests/test_runtime_executor.py index fe2ae62..ee5b6f9 100644 --- a/tests/test_runtime_executor.py +++ b/tests/test_runtime_executor.py @@ -3,7 +3,12 @@ import pytest import torch -from ollm.app.types import ContentPart, Message, MessageRole, PromptRequest +from ollm.app.types import ( + ContentPart, + Message, + MessageRole, + PromptRequest, +) from ollm.runtime.backends.base import BackendRuntime from ollm.runtime.capabilities import CapabilityProfile, SupportLevel from ollm.runtime.catalog import ModelModality diff --git a/tests/test_runtime_executor_backend_metadata.py b/tests/test_runtime_executor_backend_metadata.py new file mode 100644 index 0000000..92a35ff --- /dev/null +++ b/tests/test_runtime_executor_backend_metadata.py @@ -0,0 +1,41 @@ +import torch + +from ollm.app.types import ContentPart, Message, MessageRole, PromptResponse +from ollm.runtime.backends.base import BackendRuntime +from ollm.runtime.capabilities import CapabilityProfile, SupportLevel +from ollm.runtime.generation import RuntimeExecutor +from tests.test_runtime_executor import build_request, build_runtime + + +def test_runtime_executor_preserves_backend_chunked_prefill_metadata() -> None: + runtime = build_runtime(CapabilityProfile(support_level=SupportLevel.GENERIC)) + runtime.backend = BackendRuntime( + backend_id="test-backend", + model=None, + tokenizer=None, + processor=None, + device=torch.device("cpu"), + stats=None, + print_suppression_modules=(), + create_cache=lambda cache_dir, cache_strategy=None, cache_lifecycle=None, cache_window_tokens=None: ( + None + ), + apply_offload=lambda runtime_config: None, + execute_prompt=lambda request, sink: PromptResponse( + text="backend-response", + assistant_message=Message.assistant_text("backend-response"), + metadata={ + "chunked_prefill_strategy_id": "backend-strategy", + "chunked_prefill_activation_reason": "backend-owned", + }, + ), + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("hello")]), + ) + + response = RuntimeExecutor().execute(runtime, request) + + assert response.metadata["chunked_prefill_strategy_id"] == "backend-strategy" + assert response.metadata["chunked_prefill_activation_reason"] == "backend-owned" diff --git a/tests/test_runtime_executor_prefill.py b/tests/test_runtime_executor_prefill.py index 7e30859..2cf21e5 100644 --- a/tests/test_runtime_executor_prefill.py +++ b/tests/test_runtime_executor_prefill.py @@ -5,12 +5,17 @@ import torch from ollm.app.types import ContentPart, Message, MessageRole +from ollm.runtime.backends.base import BackendRuntime from ollm.runtime.capabilities import CapabilityProfile, SupportLevel from ollm.runtime.capability_discovery import GenericModelKind +from ollm.runtime.catalog import ModelModality +from ollm.runtime.chunked_prefill import ChunkedPrefillStrategyId from ollm.runtime.generation import RuntimeExecutor +from ollm.runtime.loaded_runtime import LoadedRuntime from tests.test_runtime_executor import ( FakeModel, build_request, + build_runtime, build_runtime_with_model, ) @@ -22,14 +27,11 @@ def apply_chat_template( tokenize, add_generation_prompt, return_tensors, - return_dict, + return_dict=False, ): - del ( - messages, - tokenize, - add_generation_prompt, - return_tensors, - ) + del messages, add_generation_prompt, return_tensors + if not tokenize: + return "rendered-long-prompt" if not return_dict: raise TypeError("return_dict=True required") return { @@ -37,6 +39,10 @@ def apply_chat_template( "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]), } + def stream_tokenize_prompt(self, rendered_prompt): + del rendered_prompt + return ([1, 2], [3, 4], [5]) + def decode(self, tensor, skip_special_tokens=False): del tensor, skip_special_tokens return "long-decoded" @@ -72,6 +78,104 @@ def forward( return types.SimpleNamespace(past_key_values=self.prefill_cache) +class LongProcessor: + def __init__(self, static_key: str): + self.static_key = static_key + self.tokenizer = LongMappingTokenizer() + self.static_to_calls: list[tuple[torch.device, torch.dtype | None]] = [] + + def apply_chat_template( + self, + messages, + add_generation_prompt, + tokenize, + return_dict=False, + return_tensors=None, + ): + del messages, add_generation_prompt, return_dict, return_tensors + if not tokenize: + return "rendered-long-prompt" + return { + "input_ids": torch.tensor([[1, 2, 3, 4, 5]]), + "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]), + self.static_key: torch.tensor([[[0.25, 0.5]]], dtype=torch.float32), + } + + def prepare_chunked_prefill_static_inputs(self, messages, device): + del messages + self.static_to_calls.append((device, torch.bfloat16)) + return { + self.static_key: torch.tensor( + [[[0.25, 0.5]]], + device=device, + dtype=torch.float32, + ) + } + + def batch_decode(self, outputs, skip_special_tokens=False): + del outputs, skip_special_tokens + return ["long-decoded"] + + +class ChunkedPrefillMultimodalModel(FakeModel): + def __init__(self) -> None: + super().__init__() + self.forward_calls: list[dict[str, object]] = [] + self.prefill_cache = object() + + def forward( + self, + input_ids, + attention_mask=None, + pixel_values=None, + input_features=None, + past_key_values=None, + use_cache=None, + cache_position=None, + ): + self.forward_calls.append( + { + "input_ids": input_ids.clone(), + "attention_mask": None + if attention_mask is None + else attention_mask.clone(), + "pixel_values": None if pixel_values is None else pixel_values.clone(), + "input_features": None + if input_features is None + else input_features.clone(), + "past_key_values": past_key_values, + "use_cache": use_cache, + "cache_position": None + if cache_position is None + else cache_position.clone(), + } + ) + return types.SimpleNamespace(past_key_values=self.prefill_cache) + + +def build_runtime_with_processor_model( + *, + capabilities: CapabilityProfile, + processor: LongProcessor, + model: FakeModel, +) -> LoadedRuntime: + runtime = build_runtime(capabilities) + runtime.backend = BackendRuntime( + backend_id=runtime.backend.backend_id, + model=model, + tokenizer=runtime.tokenizer, + processor=processor, + device=torch.device("cpu"), + stats=None, + print_suppression_modules=(), + create_cache=lambda cache_dir, cache_strategy=None, cache_lifecycle=None, cache_window_tokens=None: ( + None + ), + apply_offload=lambda runtime_config: None, + ) + return runtime + + def test_runtime_executor_prefills_long_causal_prompts_in_chunks(monkeypatch) -> None: model = ChunkedPrefillModel() runtime = build_runtime_with_model( @@ -93,6 +197,10 @@ def test_runtime_executor_prefills_long_causal_prompts_in_chunks(monkeypatch) -> response = RuntimeExecutor().execute(runtime, request) assert response.text == "long-decoded" + assert ( + response.metadata["chunked_prefill_strategy_id"] + == ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT.value + ) assert len(model.forward_calls) == 2 first_call = model.forward_calls[0] second_call = model.forward_calls[1] @@ -123,7 +231,7 @@ def test_runtime_executor_prefills_long_causal_prompts_in_chunks(monkeypatch) -> assert model.generate_kwargs["past_key_values"] is model.prefill_cache -def test_runtime_executor_skips_chunked_prefill_for_seq2seq_runtime( +def test_runtime_executor_prefills_long_generic_causal_prompts_in_chunks( monkeypatch, ) -> None: model = ChunkedPrefillModel() @@ -132,9 +240,143 @@ def test_runtime_executor_skips_chunked_prefill_for_seq2seq_runtime( tokenizer=LongMappingTokenizer(), model=model, ) + runtime.plan = replace( + runtime.plan, + backend_id="transformers-generic", + generic_model_kind=GenericModelKind.CAUSAL_LM, + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + response = RuntimeExecutor().execute(runtime, request) + + assert response.text == "long-decoded" + assert ( + response.metadata["chunked_prefill_strategy_id"] + == ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_TEXT.value + ) + assert len(model.forward_calls) == 2 + generate_input_ids = model.generate_kwargs["input_ids"] + assert isinstance(generate_input_ids, torch.Tensor) + assert torch.equal(generate_input_ids, torch.tensor([[5]])) + + +def test_runtime_executor_prefills_long_native_multimodal_prompts_in_chunks( + monkeypatch, +) -> None: + processor = LongProcessor("pixel_values") + model = ChunkedPrefillMultimodalModel() + runtime = build_runtime_with_processor_model( + capabilities=CapabilityProfile( + support_level=SupportLevel.OPTIMIZED, + modalities=(ModelModality.TEXT, ModelModality.IMAGE), + requires_processor=True, + ), + processor=processor, + model=model, + ) runtime.plan = replace( runtime.plan, backend_id="optimized-native", + generic_model_kind=GenericModelKind.IMAGE_TEXT_TO_TEXT, + ) + request = build_request( + runtime.config, + Message( + role=MessageRole.USER, + content=[ + ContentPart.text("long prompt"), + ContentPart.image("image.png"), + ], + ), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + response = RuntimeExecutor().execute(runtime, request) + + assert response.text == "long-decoded" + assert processor.static_to_calls == [(torch.device("cpu"), torch.bfloat16)] + assert ( + response.metadata["chunked_prefill_strategy_id"] + == ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_MULTIMODAL.value + ) + assert len(model.forward_calls) == 2 + first_call = model.forward_calls[0] + second_call = model.forward_calls[1] + first_pixel_values = cast(torch.Tensor, first_call["pixel_values"]) + second_pixel_values = cast(torch.Tensor, second_call["pixel_values"]) + assert torch.equal(first_pixel_values, torch.tensor([[[0.25, 0.5]]])) + assert torch.equal(second_pixel_values, torch.tensor([[[0.25, 0.5]]])) + generate_input_ids = model.generate_kwargs["input_ids"] + generate_pixel_values = model.generate_kwargs["pixel_values"] + assert isinstance(generate_input_ids, torch.Tensor) + assert isinstance(generate_pixel_values, torch.Tensor) + assert torch.equal(generate_input_ids, torch.tensor([[5]])) + assert torch.equal(generate_pixel_values, torch.tensor([[[0.25, 0.5]]])) + + +def test_runtime_executor_prefills_long_generic_multimodal_prompts_in_chunks( + monkeypatch, +) -> None: + processor = LongProcessor("pixel_values") + model = ChunkedPrefillMultimodalModel() + runtime = build_runtime_with_processor_model( + capabilities=CapabilityProfile( + support_level=SupportLevel.GENERIC, + modalities=(ModelModality.TEXT, ModelModality.IMAGE), + requires_processor=True, + ), + processor=processor, + model=model, + ) + runtime.plan = replace( + runtime.plan, + backend_id="transformers-generic", + generic_model_kind=GenericModelKind.IMAGE_TEXT_TO_TEXT, + ) + request = build_request( + runtime.config, + Message( + role=MessageRole.USER, + content=[ + ContentPart.text("long prompt"), + ContentPart.image("image.png"), + ], + ), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + response = RuntimeExecutor().execute(runtime, request) + + assert response.text == "long-decoded" + assert ( + response.metadata["chunked_prefill_strategy_id"] + == ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_MULTIMODAL.value + ) + assert len(model.forward_calls) == 2 + generate_input_ids = model.generate_kwargs["input_ids"] + generate_pixel_values = model.generate_kwargs["pixel_values"] + assert isinstance(generate_input_ids, torch.Tensor) + assert isinstance(generate_pixel_values, torch.Tensor) + assert torch.equal(generate_input_ids, torch.tensor([[5]])) + assert torch.equal(generate_pixel_values, torch.tensor([[[0.25, 0.5]]])) + + +def test_runtime_executor_streams_seq2seq_source_strategy( + monkeypatch, +) -> None: + model = ChunkedPrefillModel() + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=model, + ) + runtime.plan = replace( + runtime.plan, + backend_id="transformers-generic", generic_model_kind=GenericModelKind.SEQ2SEQ_LM, ) request = build_request( @@ -143,9 +385,80 @@ def test_runtime_executor_skips_chunked_prefill_for_seq2seq_runtime( ) monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) - RuntimeExecutor().execute(runtime, request) + response = RuntimeExecutor().execute(runtime, request) assert model.forward_calls == [] + assert ( + response.metadata["chunked_prefill_strategy_id"] + == ChunkedPrefillStrategyId.TRANSFORMERS_GENERIC_SEQ2SEQ_SOURCE.value + ) + assert response.metadata["chunked_prefill_applied"] == "true" generate_input_ids = model.generate_kwargs["input_ids"] assert isinstance(generate_input_ids, torch.Tensor) assert torch.equal(generate_input_ids, torch.tensor([[1, 2, 3, 4, 5]])) + + +def test_runtime_executor_leaves_chunked_prefill_metadata_blank_without_strategy( + monkeypatch, +) -> None: + model = ChunkedPrefillModel() + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=model, + ) + runtime.plan = replace( + runtime.plan, + backend_id="custom-backend", + generic_model_kind=GenericModelKind.CAUSAL_LM, + ) + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + response = RuntimeExecutor().execute(runtime, request) + + assert response.metadata["chunked_prefill_strategy_id"] == "" + assert response.metadata["chunked_prefill_runtime_eligible"] == "false" + assert response.metadata["chunked_prefill_execution_boundary"] == "" + assert response.metadata["chunked_prefill_attention_mask_mode"] == "" + + +def test_runtime_executor_falls_back_when_forward_is_unavailable_for_chunking( + monkeypatch, +) -> None: + runtime = build_runtime_with_model( + CapabilityProfile(support_level=SupportLevel.GENERIC), + tokenizer=LongMappingTokenizer(), + model=FakeModel(), + ) + runtime.plan = replace( + runtime.plan, + backend_id="optimized-native", + generic_model_kind=GenericModelKind.CAUSAL_LM, + ) + runtime.model.forward = None + request = build_request( + runtime.config, + Message(role=MessageRole.USER, content=[ContentPart.text("long prompt")]), + ) + monkeypatch.setattr("ollm.runtime.generation.DEFAULT_PREFILL_CHUNK_TOKENS", 2) + + response = RuntimeExecutor().execute(runtime, request) + + assert response.text == "long-decoded" + assert ( + response.metadata["chunked_prefill_strategy_id"] + == ChunkedPrefillStrategyId.OPTIMIZED_NATIVE_TEXT.value + ) + assert response.metadata["chunked_prefill_runtime_eligible"] == "false" + assert response.metadata["chunked_prefill_applied"] == "false" + assert ( + response.metadata["chunked_prefill_activation_reason"] + == "Chunked prompt-ingestion strategy requires a callable forward method." + ) + generate_input_ids = runtime.model.generate_kwargs["input_ids"] + assert isinstance(generate_input_ids, torch.Tensor) + assert torch.equal(generate_input_ids, torch.tensor([[1, 2, 3, 4, 5]]))