[4/n] Add vLLM integration for modelopt sparse attention #1127
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration: Path: .coderabbit.yaml | Review profile: CHILL | Plan: Pro
📒 Files selected for processing (7)
✅ Files skipped from review due to trivial changes (2)
🚧 Files skipped from review as they are similar to previous changes (1)

📝 Walkthrough

Adds sparse-attention support end-to-end: new vLLM worker variants that load sparse/quant configs and patch Attention impls, a ModelOpt vLLM sparse backend, Triton flash-attention paged KV-cache support, an example server script, and GPU tests validating paged-KV behavior.
Sequence Diagrams

```mermaid
sequenceDiagram
    participant Env as Env / JSON
    participant Worker as vLLM Worker
    participant Model as Model Loader
    participant Plugin as ModelOpt Plugin
    participant Attention as Attention Modules
    Env->>Worker: provide SPARSE_ATTN_CFG / SPARSE_CALIB_CONFIG_PATH / QUANT env
    Worker->>Model: load_model()
    Model-->>Worker: model with Attention modules
    Worker->>Plugin: set_sparse_config(sparse_cfg)
    Worker->>Attention: traverse & replace impls with ModelOptSparseAttentionImpl
    Attention-->>Worker: patched layers count
```

```mermaid
sequenceDiagram
    participant Input as Forward Input (query, kv, metadata)
    participant Impl as ModelOptSparseAttentionImpl
    participant Cache as Paged KV Cache (k_cache/v_cache, block_table)
    participant Kernel as triton_fa.attention
    participant Output as Output Tensor
    Input->>Impl: forward(query, key, value, kv_cache, attn_metadata)
    Impl->>Cache: unpack paged k_cache, v_cache, block_table, page_size
    Impl->>Impl: construct sparse_kw from per-layer config
    Impl->>Kernel: call attention(..., k_cache, v_cache, block_table, page_size, sparse_kw)
    Kernel-->>Impl: attention result
    Impl->>Output: write and return result
```
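The second diagram's forward path can be sketched in plain Python. This is a hypothetical illustration only: the function name `sparse_forward`, the `(k_cache, v_cache)` tuple layout, and the `kernel` callback signature are assumptions for the sketch, not the real vLLM/ModelOpt API.

```python
# Hypothetical sketch of the forward flow in the diagram above; real
# signatures in the vLLM plugin differ.
def sparse_forward(query, kv_cache, block_table, layer_cfg, kernel):
    # Unpack the paged KV cache: each cache is a sequence of fixed-size pages.
    k_cache, v_cache = kv_cache
    page_size = len(k_cache[0])  # tokens per page

    # Build per-layer sparse kwargs from the (calibrated) layer config.
    sparse_kw = {
        "sparsity_n": layer_cfg.get("sparsity_n"),
        "sparsity_m": layer_cfg.get("sparsity_m"),
        "num_sink_tokens": layer_cfg.get("num_sink_tokens", 0),
    }

    # Delegate to the paged attention kernel.
    return kernel(query, k_cache, v_cache, block_table, page_size, sparse_kw)


# Tiny smoke run with a stand-in kernel that just echoes the paged args.
page, kw_n = sparse_forward(
    query=[1.0],
    kv_cache=([[0, 0], [0, 0]], [[0, 0], [0, 0]]),  # two pages of 2 tokens
    block_table=[[0, 1]],
    layer_cfg={"sparsity_n": 2, "sparsity_m": 4},
    kernel=lambda q, k, v, bt, ps, kw: (ps, kw["sparsity_n"]),
)
```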
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed
Codecov Report

✅ All modified and coverable lines are covered by tests.

```diff
@@            Coverage Diff             @@
##             main    #1127      +/-   ##
==========================================
- Coverage   70.21%   70.19%   -0.03%
==========================================
  Files         230      230
  Lines       26073    26073
==========================================
- Hits        18308    18302       -6
- Misses       7765     7771       +6
```

☔ View full report in Codecov by Sentry.
26c6b3b to e4c4680
Actionable comments posted: 6
🧹 Nitpick comments (1)
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py (1)
290-336: Assert decode correctness, not just finiteness. This case still passes if paged decode reads the wrong blocks or masks the wrong keys, because it only checks shape/NaNs. Please compare `out` against a contiguous decode reference here as well. As per coding guidelines, "Write tests using pytest for all new features and examples; organize tests into `tests/unit` (fast CPU-based), `tests/gpu` (fast GPU-based), `tests/gpu_megatron` (Megatron-Core), `tests/gpu_trtllm` (TensorRT-LLM), and `tests/examples` (integration tests)" and "All test coverage checks in PRs must pass for new features and examples."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py` around lines 290 - 336, The test currently only checks shape/NaNs for paged decode; add a correctness assertion by computing a contiguous-reference decode (call the same attention function with the original k_flat/v_flat and without k_cache/v_cache/block_table/page_size, i.e., the non-paged code path) and compare outputs from test_paged_decode to that reference using torch.testing.assert_allclose (or torch.allclose with a small rtol/atol) to ensure values match; keep the existing shape and NaN checks and reuse q_flat, k_flat, v_flat, b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k, and scale so the only difference is paged vs contiguous.
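The paged-vs-contiguous comparison this comment asks for can be sketched independently of the Triton kernel: gather the paged K/V back into a contiguous layout via the block table and check a plain softmax-attention reference produces the same output. The helper names (`attention_ref`, `gather_paged`) are made up for this sketch; the real test would call the same `attention()` entry point with and without the paged arguments.

```python
import numpy as np

def attention_ref(q, k, v, scale):
    # Plain softmax attention for one decode query (q: [d], k/v: [t, d]).
    logits = k @ q * scale
    w = np.exp(logits - logits.max())
    w /= w.sum()
    return w @ v

def gather_paged(cache, block_table, seq_len):
    # Reassemble contiguous K or V from fixed-size pages via the block table.
    return np.concatenate([cache[b] for b in block_table])[:seq_len]

rng = np.random.default_rng(0)
d, page_size, seq_len = 4, 2, 5
k_flat = rng.normal(size=(seq_len, d))
v_flat = rng.normal(size=(seq_len, d))
q = rng.normal(size=d)

# Split contiguous K/V into pages (last page zero-padded) and record the table.
n_pages = -(-seq_len // page_size)
pad = n_pages * page_size - seq_len
k_cache = np.concatenate([k_flat, np.zeros((pad, d))]).reshape(n_pages, page_size, d)
v_cache = np.concatenate([v_flat, np.zeros((pad, d))]).reshape(n_pages, page_size, d)
block_table = list(range(n_pages))

out_paged = attention_ref(
    q,
    gather_paged(k_cache, block_table, seq_len),
    gather_paged(v_cache, block_table, seq_len),
    scale=d ** -0.5,
)
out_ref = attention_ref(q, k_flat, v_flat, scale=d ** -0.5)
```

The point of the check: the only difference between the two calls is paged vs. contiguous KV, so any block-lookup or masking bug shows up as a value mismatch rather than passing a shape/NaN check.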
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 137-140: The worker assumes a different kv_cache axis order than
the vllm plugin; instead of slicing kv_cache[:, 0] and kv_cache[:, 1], normalize
to the same layout used elsewhere by splitting with kv_cache.unbind(0) so
k_cache and v_cache come from that unbind; update page_size to derive from the
resulting k_cache shape (e.g., k_cache.shape[1]) and ensure any downstream uses
of k_cache/v_cache match this normalized layout (references: kv_cache, k_cache,
v_cache and the existing kv_cache.unbind(0) usage in the vllm plugin).
- Around line 297-300: The import inside compile_or_warm_up_model currently uses
a relative import ("from .fakequant_worker import _fakequant_run_prolog_worker,
quant_config") which fails when the module is loaded as a top-level module;
change it to a top-level/absolute import or dynamic import so the code works
whether loaded as a package or directly (e.g., use "from fakequant_worker import
_fakequant_run_prolog_worker, quant_config" or use
importlib.import_module("fakequant_worker") and grab the attributes). Update the
import used by compile_or_warm_up_model and any callers expecting
_fakequant_run_prolog_worker and quant_config accordingly so the class
SparseQuantWorker can be imported as a top-level module without ImportError.
- Around line 273-284: The replacement sets sliding_window=None which disables
local/sliding-window attention; update the instantiation of
ModelOptSparseAttentionImpl (the assignment to module.impl) to pass through the
original value (old_impl.sliding_window) instead of None, or add a guard that
rejects/raises when old_impl.sliding_window is non-None and unsupported; ensure
you reference old_impl.sliding_window and ModelOptSparseAttentionImpl in the
change so sliding-window behavior is preserved or explicitly handled.
- Around line 176-183: The decode call to triton_attention in
sparse_attn_worker.py is using is_causal=True which causes incorrect masking for
paged KV; change the triton_attention invocation (the call that sets
q=query[offset: offset+nd], k=query[:0], v=query[:0],
b_start_loc=dm.query_start_loc, b_seq_len=..., max_input_len=1) to pass
is_causal=False instead, matching the decode path in
modelopt/torch/kernels/hf_triton_attention.py so later cached KV tiles are not
truncated.
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 945-958: The code currently allows paged-mode (is_paged True) but
does not disable/guard autograd, causing backward to recompute using contiguous
K/V and dummy b_start_loc_k and produce wrong gradients; update the forward
entrypoint that sets is_paged (look for the block using is_paged, b_start_loc_k
and b_start_loc) to explicitly disallow autograd in paged mode by either raising
a clear exception when torch.is_grad_enabled() (or when requires_grad on inputs)
and is_paged is True, or by wrapping the paged-mode path in torch.no_grad() and
documenting that backward is unsupported; ensure the guard references is_paged
and b_start_loc_k so callers cannot silently run backward with dummy
b_start_loc_k.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py`:
- Around line 81-99: The current loop builds a single sparse_kw from the first
enabled entry in _sparse_config["sparse_cfg"] and breaks, which ignores
layer-specific patterns; instead, for each module use name-based matching to
select the correct layer_cfg (reuse the matching logic from
examples/vllm_serve/sparse_attn_worker.py::_match_sparse_config or call that
helper) and then build a per-module sparse_kw from that matched layer_cfg
(respecting fields like sparsity_n, sparsity_m, num_sink_tokens,
dense_window_size, skip_softmax_threshold); do not break out of the loop—apply
the matched config only to the current module or stash sparse_kw on the module
instance before swapping implementations so multiple patterns in the calibration
file are handled correctly.
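The per-module, name-based matching the comment asks for can be sketched as below. Glob-style matching via `fnmatch` is an assumption for this sketch; the real config may use a different pattern syntax, and `match_layer_cfg` is a hypothetical helper name.

```python
from fnmatch import fnmatch

def match_layer_cfg(module_name, sparse_cfg):
    # Return the first pattern-specific config whose glob matches this
    # module's name, falling back to the "default" entry.
    for pattern, layer_cfg in sparse_cfg.items():
        if pattern in ("default", "calibration"):
            continue
        if fnmatch(module_name, pattern):
            return layer_cfg
    return sparse_cfg.get("default", {"enable": False})


cfg = {
    "*layers.0.*": {"enable": True, "sparsity_n": 2, "sparsity_m": 4},
    "*layers.1.*": {"enable": True, "sparsity_n": 1, "sparsity_m": 4},
    "default": {"enable": False},
}
```

Applied per module (instead of breaking out of the loop after the first enabled entry), each attention impl gets the config for its own layer, so multiple patterns in the calibration file are honored.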
---
Nitpick comments:
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py`:
- Around line 290-336: The test currently only checks shape/NaNs for paged
decode; add a correctness assertion by computing a contiguous-reference decode
(call the same attention function with the original k_flat/v_flat and without
k_cache/v_cache/block_table/page_size, i.e., the non-paged code path) and
compare outputs from test_paged_decode to that reference using
torch.testing.assert_allclose (or torch.allclose with a small rtol/atol) to
ensure values match; keep the existing shape and NaN checks and reuse q_flat,
k_flat, v_flat, b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k, and
scale so the only difference is paged vs contiguous.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b27921dd-dcd2-4ce3-acc9-28816a998e1f
📒 Files selected for processing (7)
examples/vllm_serve/sparse_attn_worker.py
examples/vllm_serve/vllm_serve_sparse_attn.py
modelopt/torch/kernels/triton_fa.py
modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
pyproject.toml
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py
```python
is_paged = k_cache is not None

# Prefill: Q/K/V are the same packed tensor, reuse Q offsets for K/V.
# Decode: K/V is a separate KV cache tensor, caller must pass explicit metadata.
if b_seq_len_k is None:
    b_seq_len_k = b_seq_len
    b_start_loc_k = b_start_loc
    max_input_len_k = max_input_len

# Paged mode: b_start_loc_k may be None (KV is in paged cache, not contiguous).
# Provide a dummy tensor so Triton can compile the tl.load (it won't be used).
if b_start_loc_k is None:
    b_start_loc_k = torch.zeros_like(b_start_loc)
```
Reject paged-mode autograd until backward is implemented.
is_paged only changes the forward kernel. Backward still recomputes from contiguous K/V and b_start_loc_k, which are dummy/unused in paged mode, so a backward pass can return incorrect gradients instead of failing fast.
Possible guard

```diff
 is_paged = k_cache is not None
+if is_paged and (q.requires_grad or k.requires_grad or v.requires_grad):
+    raise NotImplementedError(
+        "Paged KV cache path is forward-only; backward is not implemented."
+    )
 # Prefill: Q/K/V are the same packed tensor, reuse Q offsets for K/V.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/kernels/triton_fa.py` around lines 945 - 958, The code
currently allows paged-mode (is_paged True) but does not disable/guard autograd,
causing backward to recompute using contiguous K/V and dummy b_start_loc_k and
produce wrong gradients; update the forward entrypoint that sets is_paged (look
for the block using is_paged, b_start_loc_k and b_start_loc) to explicitly
disallow autograd in paged mode by either raising a clear exception when
torch.is_grad_enabled() (or when requires_grad on inputs) and is_paged is True,
or by wrapping the paged-mode path in torch.no_grad() and documenting that
backward is unsupported; ensure the guard references is_paged and b_start_loc_k
so callers cannot silently run backward with dummy b_start_loc_k.
Signed-off-by: Kai Xu <kaix@nvidia.com>
Actionable comments posted: 2
♻️ Duplicate comments (1)
examples/vllm_serve/sparse_attn_worker.py (1)
145-152: ⚠️ Potential issue | 🟠 Major

Preserve existing `sliding_window` when replacing attention impl. Line 151 hardcodes `sliding_window=None`, which can silently change local/sliding-window attention behavior. Pass through `old_impl.sliding_window` (or explicitly reject unsupported non-None values).

Proposed fix

```diff
 module.impl = ModelOptSparseAttentionImpl(
     num_heads=old_impl.num_heads,
     head_size=old_impl.head_size,
     scale=old_impl.scale,
     num_kv_heads=old_impl.num_kv_heads,
     alibi_slopes=old_impl.alibi_slopes,
-    sliding_window=None,
+    sliding_window=old_impl.sliding_window,
     kv_cache_dtype=old_impl.kv_cache_dtype,
     logits_soft_cap=old_impl.logits_soft_cap,
     attn_type=old_impl.attn_type,
     kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name,
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/sparse_attn_worker.py` around lines 145 - 152, The construction of ModelOptSparseAttentionImpl currently hardcodes sliding_window=None which can change attention behavior; update the initializer in the replacement code to pass through old_impl.sliding_window (i.e., use sliding_window=old_impl.sliding_window) or, if non-None values are unsupported, explicitly check old_impl.sliding_window and raise an informative error before creating ModelOptSparseAttentionImpl; reference ModelOptSparseAttentionImpl, old_impl, and old_impl.sliding_window to locate and fix the code.
🧹 Nitpick comments (1)
examples/vllm_serve/sparse_attn_worker.py (1)
111-119: Remove or wire `_match_sparse_config` to avoid dead-path drift

`_match_sparse_config` is currently unused, which makes behavior harder to reason about and can drift from real matching logic. Either use it in the patching/selection flow or remove it until needed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/sparse_attn_worker.py` around lines 111 - 119, The helper function _match_sparse_config is unused and creates dead-path drift; either remove this function or wire it into the sparse patch/selection flow by replacing the current pattern-matching logic with a call to _match_sparse_config(module_name, sparse_cfg) (or call it from wherever sparse layer configs are looked up) so that matching behavior is centralized; update any callers that currently duplicate pattern checks to use _match_sparse_config and remove dead duplicates if you choose to keep it.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 81-83: The code currently returns whatever getattr(mtsa, cfg_name,
None) yields (cfg) and may hand back non-mapping objects; update the getter to
validate that cfg is a dict before returning (use isinstance(cfg, dict)) and
otherwise raise a clear error (e.g., ValueError) indicating that the symbol
named by cfg_name in modelopt.torch.sparsity.attention_sparsity must be a dict;
reference the getattr(mtsa, cfg_name, None) call and the cfg variable to locate
the change.
- Around line 92-106: The _load_sparse_config function currently trusts
arbitrary JSON from SPARSE_CALIB_CONFIG_PATH; update it to validate the loaded
object and each layer_cfg: assert the top-level JSON is a dict, allowed
top-level keys are strings and either "calibration" or pattern names, and each
layer_cfg is a dict before applying defaults; enforce allowed keys (e.g.,
"method", "backend", "enable", numeric sparsity params) and bounds for numeric
fields (e.g., sparsity percentages 0–100, integer layer indices >=0, and
reasonable max limits) and reject or clamp out-of-range values, raising a clear
exception on invalid schema; keep the existing defaults
(method="triton_sparse_softmax", backend="triton", enable=True) for valid
entries and ensure sparse_cfg["default"] = {"enable": False} remains set.
---
Duplicate comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 145-152: The construction of ModelOptSparseAttentionImpl currently
hardcodes sliding_window=None which can change attention behavior; update the
initializer in the replacement code to pass through old_impl.sliding_window
(i.e., use sliding_window=old_impl.sliding_window) or, if non-None values are
unsupported, explicitly check old_impl.sliding_window and raise an informative
error before creating ModelOptSparseAttentionImpl; reference
ModelOptSparseAttentionImpl, old_impl, and old_impl.sliding_window to locate and
fix the code.
---
Nitpick comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 111-119: The helper function _match_sparse_config is unused and
creates dead-path drift; either remove this function or wire it into the sparse
patch/selection flow by replacing the current pattern-matching logic with a call
to _match_sparse_config(module_name, sparse_cfg) (or call it from wherever
sparse layer configs are looked up) so that matching behavior is centralized;
update any callers that currently duplicate pattern checks to use
_match_sparse_config and remove dead duplicates if you choose to keep it.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: cf1d8866-f2b3-4adf-9b4a-eb396967c80e
📒 Files selected for processing (1)
examples/vllm_serve/sparse_attn_worker.py
```python
cfg = getattr(mtsa, cfg_name, None)
if cfg is not None:
    return cfg
```
Validate preset object type before returning it
If SPARSE_ATTN_CFG matches a non-dict symbol in modelopt.torch.sparsity.attention_sparsity, cfg is returned as-is and later consumed as a mapping, which can crash at runtime. Add an explicit isinstance(cfg, dict) guard and fail fast with a clear error.
Proposed fix
```diff
 cfg = getattr(mtsa, cfg_name, None)
 if cfg is not None:
-    return cfg
+    if not isinstance(cfg, dict):
+        raise ValueError(
+            f"Invalid sparse config preset '{cfg_name}': expected dict, got {type(cfg).__name__}."
+        )
+    return cfg
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/vllm_serve/sparse_attn_worker.py` around lines 81 - 83, The code
currently returns whatever getattr(mtsa, cfg_name, None) yields (cfg) and may
hand back non-mapping objects; update the getter to validate that cfg is a dict
before returning (use isinstance(cfg, dict)) and otherwise raise a clear error
(e.g., ValueError) indicating that the symbol named by cfg_name in
modelopt.torch.sparsity.attention_sparsity must be a dict; reference the
getattr(mtsa, cfg_name, None) call and the cfg variable to locate the change.
```python
def _load_sparse_config(path: str) -> dict:
    """Load offline calibration config JSON."""
    with open(path) as f:
        calib_cfg = json.load(f)

    sparse_cfg = {}
    for pattern, layer_cfg in calib_cfg.items():
        if pattern == "calibration":
            sparse_cfg[pattern] = layer_cfg
            continue
        layer_cfg.setdefault("method", "triton_sparse_softmax")
        layer_cfg.setdefault("backend", "triton")
        layer_cfg.setdefault("enable", True)
        sparse_cfg[pattern] = layer_cfg
    sparse_cfg["default"] = {"enable": False}
```
Harden calibration JSON parsing with schema and bounds checks
SPARSE_CALIB_CONFIG_PATH is env-driven input, but Line 94–106 accepts arbitrary JSON structure and values without validation. This can propagate malformed sparsity params into kernel calls and create avoidable failure/DoS risk. Validate top-level/object types, allowed keys, and integer ranges before applying defaults.
Proposed fix

```diff
 def _load_sparse_config(path: str) -> dict:
     """Load offline calibration config JSON."""
-    with open(path) as f:
+    with open(path, encoding="utf-8") as f:
         calib_cfg = json.load(f)
+    if not isinstance(calib_cfg, dict):
+        raise ValueError("Calibration config must be a JSON object mapping pattern -> layer config.")
     sparse_cfg = {}
     for pattern, layer_cfg in calib_cfg.items():
+        if not isinstance(pattern, str):
+            raise ValueError("Calibration config keys must be strings.")
         if pattern == "calibration":
             sparse_cfg[pattern] = layer_cfg
             continue
+        if not isinstance(layer_cfg, dict):
+            raise ValueError(f"Layer config for pattern '{pattern}' must be an object.")
+        for int_key in ("sparsity_n", "sparsity_m", "num_sink_tokens", "dense_window_size"):
+            if int_key in layer_cfg and (
+                not isinstance(layer_cfg[int_key], int) or layer_cfg[int_key] < 0
+            ):
+                raise ValueError(f"Invalid '{int_key}' for pattern '{pattern}': {layer_cfg[int_key]!r}")
         layer_cfg.setdefault("method", "triton_sparse_softmax")
         layer_cfg.setdefault("backend", "triton")
         layer_cfg.setdefault("enable", True)
         sparse_cfg[pattern] = layer_cfg
```

As per coding guidelines, "Validate inputs and enforce limits to reduce resource-exhaustion/DoS risk (e.g., file sizes, expected schema/shape for sparse config/calibration JSON)."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/vllm_serve/sparse_attn_worker.py` around lines 92 - 106, The
_load_sparse_config function currently trusts arbitrary JSON from
SPARSE_CALIB_CONFIG_PATH; update it to validate the loaded object and each
layer_cfg: assert the top-level JSON is a dict, allowed top-level keys are
strings and either "calibration" or pattern names, and each layer_cfg is a dict
before applying defaults; enforce allowed keys (e.g., "method", "backend",
"enable", numeric sparsity params) and bounds for numeric fields (e.g., sparsity
percentages 0–100, integer layer indices >=0, and reasonable max limits) and
reject or clamp out-of-range values, raising a clear exception on invalid
schema; keep the existing defaults (method="triton_sparse_softmax",
backend="triton", enable=True) for valid entries and ensure
sparse_cfg["default"] = {"enable": False} remains set.
4644bf5 to 54079b8
cjluo-nv left a comment
Summary: Adds paged KV cache support to the ModelOpt Triton flash attention kernel, a vLLM sparse attention plugin, worker classes for vLLM integration, and GPU tests for the paged KV path.
Issues Found:
- [Correctness] Backward pass silently broken for paged KV — `triton_fa.py`'s `forward()` saves `k`, `v` to `ctx` for backward, but in paged mode these are dummy/empty tensors (e.g., `k_dummy = torch.empty(0, ...)` from the vLLM plugin). If `.backward()` is ever called on paged-mode output, gradients for dK/dV will be computed against the dummy tensors, producing silently incorrect results. The backward should either raise `NotImplementedError("Backward not supported for paged KV cache")` when `is_paged=True`, or the limitation should be clearly documented in the `attention()` docstring. The current code just adds `None` return placeholders for the 4 new args without any guard.
- [Correctness] Unused `import_plugin` import in `plugins/__init__.py` — the diff adds `from modelopt.torch.utils import import_plugin` but it is never used in the file. The existing `__init__.py` doesn't use this import, and `vllm.py` imports directly from vLLM. This is dead code that should be removed, or if the intent was to use `import_plugin` for conditional vLLM import (as other plugin modules do), that wiring is missing.
- [Correctness] `_build_sparse_config` fallback logic is confusing — `sparse_attn_worker.py:78-86`: `getattr(mtsa, cfg_name, None)` is tried first, but then `SPARSE_SOFTMAX_DEFAULT` falls through to the hardcoded `_DEFAULT_SPARSE_CFG` dict. If `mtsa` actually defines `SPARSE_SOFTMAX_DEFAULT` in the future, the `getattr` path would return it and the hardcoded default would never be used, leading to silent behavior divergence. Consider making the precedence explicit or removing the duplication.
- [Readability] Duplicated paged V-tile loading — in the `triton_fa.py` `_attn_fwd` kernel, `_load_paged_v_tile` is called with identical arguments in two separate branches (skip-softmax path around line 476 and standard path around line 518). While Triton JIT constraints may require this, the duplicated 20-line call blocks are a readability concern. A comment explaining why the duplication is necessary would help.
- [Tests] No backward/gradient test for paged mode — `test_triton_fa_paged.py` only tests forward correctness. Given that paged mode changes the autograd Function's `forward` signature and backward is not updated to support paged KV, there should be at minimum a test asserting that backward raises an error (if a guard is added per issue #1), or this should be explicitly documented as inference-only.
- [Tests] No integration test for vLLM plugin — `ModelOptSparseAttentionImpl.forward()` and `ModelOptSparseAttentionBackend` have no test coverage. These are the most integration-critical new classes. Even a mock-based unit test validating the metadata translation logic (`cu_seqlens_q` → `b_start_loc`, `seq_lens` → `b_seq_len_k`) would catch regressions.
- [Tests] Inconsistent `b_start_loc_k` handling across tests — `test_paged_matches_contiguous` passes an explicit `b_start_loc_k=locs_k`, while `test_paged_no_nan` omits it (relying on the dummy-zeros fallback). Both should use the same calling convention to avoid masking bugs in the fallback path.
- [Readability] `if threshold:` truthiness check — `sparse_attn_worker.py:146`: `if threshold:` evaluates `False` for both `None` and `0.0`. While `0.0` correctly means "disabled", this is subtle. `if threshold is not None` would be clearer about intent (let the kernel handle the `0.0` case).
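The truthiness pitfall in the last issue can be shown in two lines. `threshold_configured` is a made-up helper name for illustration; the point is only the `is not None` comparison versus bare truthiness.

```python
# `if threshold:` cannot distinguish an explicit 0.0 from "not configured":
# bool(0.0) is False, so a configured 0.0 silently takes the disabled path.
def threshold_configured(threshold):
    return threshold is not None  # pass 0.0 through; only None means unset


checks = (
    bool(0.0),                    # False: why `if threshold:` conflates the two
    threshold_configured(0.0),    # True: explicit 0.0 still reaches the kernel
    threshold_configured(None),   # False: only None means "disabled"
)
```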
Suggestions:
- Consider adding `requires_grad=False` to `k_dummy` and `v_dummy` in the vLLM plugin to make the inference-only intent explicit and catch accidental backward calls early.
- The `_DEFAULT_SPARSE_CFG` hardcoded in the worker could reference `mtsa` constants instead, reducing drift risk.
- `test_paged_decode` uses `q_flat` with shape `[batch, num_heads, head_dim]` (3D) rather than the expected `[total_q_tokens, num_heads, head_dim]`. This works because `batch * 1 = batch` tokens, but the shape semantics are confusing — using `q_flat.reshape(batch, num_heads, head_dim)` explicitly or adding a comment would clarify.
Overall Assessment: The core kernel extension (paged KV tile loaders + IS_PAGED branching) is well-structured and the tests verify forward correctness across multiple configurations. The main blocking concern is the silent backward-pass breakage for paged mode — this needs at minimum a guard or documentation since the function is part of the public attention() API with autograd support. The unused import is a minor cleanup. The vLLM plugin lacks test coverage but is in examples/ territory so is less critical.
What does this PR do?
Type of change: New feature

Add vLLM integration for ModelOpt sparse attention with paged KV cache support.
Extends the Triton flash attention kernel (`triton_fa.py`) with paged KV cache support. The KV cache can be read directly from vLLM's non-contiguous paged cache via `block_table` lookup, avoiding expensive gather-to-contiguous copies. Both N:M sparse softmax and skip-softmax work with paged KV.

`SparseVLLMAttention` wraps vLLM's Attention layer. It lets vLLM write KV to its paged cache, then calls the ModelOpt Triton kernel with `k_cache`, `v_cache`, `block_table` for both prefill and decode. `SparseAttnWorker` patches vLLM attention modules at model load time. `SparseQuantWorker` combines quantization + sparse attention. Worker selection is automatic based on env vars (`SPARSE_ATTN_CFG`, `QUANT_CFG`). `vllm_serve_sparse_attn.py` launches a vLLM OpenAI-compatible server with sparse attention enabled.

Usage
Testing
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit
New Features
Tests