
[4/n] Add vLLM integration for modelopt sparse attention#1127

Open
kaix-nv wants to merge 4 commits into main from kaix/sparse_attn_vllm_integration

Conversation


@kaix-nv kaix-nv commented Mar 27, 2026

What does this PR do?

Type of change: New feature

Adds vLLM integration for ModelOpt sparse attention with paged KV cache support.

Extends the Triton flash attention kernel (triton_fa.py) with paged KV cache support. KV cache can be read directly from vLLM's non-contiguous paged cache via block_table lookup, avoiding expensive gather-to-contiguous copies. Both N:M sparse softmax and skip-softmax work with paged KV.

SparseVLLMAttention wraps vLLM's Attention layer. It lets vLLM write KV to its paged cache, then calls the ModelOpt Triton kernel with k_cache, v_cache, block_table for both prefill and decode.

SparseAttnWorker patches vLLM attention modules at model load time. SparseQuantWorker combines quantization+sparse attention. Worker selection is automatic based on env vars (SPARSE_ATTN_CFG, QUANT_CFG).

vllm_serve_sparse_attn.py launches a vLLM OpenAI-compatible server with sparse attention enabled.
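The env-var-driven worker selection described above could be sketched as follows (the function name and exact class paths are assumptions for illustration; only the env var names SPARSE_ATTN_CFG and QUANT_CFG come from the PR):

```python
import os

def select_worker_cls() -> str:
    """Pick a vLLM worker class from env vars (sketch; class paths assumed).

    SPARSE_ATTN_CFG alone   -> sparse-attention-only worker
    plus QUANT_CFG          -> combined quantization + sparse worker
    neither                 -> default vLLM worker
    """
    has_sparse = bool(os.environ.get("SPARSE_ATTN_CFG"))
    has_quant = bool(os.environ.get("QUANT_CFG"))
    if has_sparse and has_quant:
        return "sparse_attn_worker.SparseQuantWorker"
    if has_sparse:
        return "sparse_attn_worker.SparseAttnWorker"
    return "vllm.worker.Worker"
```

The resolved class path would then be handed to vLLM's engine as its worker class override.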

Usage

# Launch vLLM with 2:4 sparse attention
SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT python examples/vllm_serve/vllm_serve_sparse_attn.py \
    meta-llama/Llama-3.1-8B --max-model-len 8192

# Kernel-level paged KV API
from modelopt.torch.kernels import attention

out = attention(
    q, k, v, b_start_loc, b_seq_len, max_input_len,
    is_causal=False,
    softmax_scale=scale,
    b_start_loc_k=b_start_loc_k,
    b_seq_len_k=b_seq_len_k,
    max_input_len_k=max_kv_len,
    sparsity_n=2, sparsity_m=4,
    k_cache=k_cache,        # [num_blocks, page_size, num_kv_heads, head_dim]
    v_cache=v_cache,        # [num_blocks, page_size, num_kv_heads, head_dim]
    block_table=block_table, # [batch, max_blocks_per_seq]
    page_size=16,
)
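For intuition, the block_table indirection that the paged loaders rely on resolves a logical token position into a physical cache slot roughly as below (a pure-Python sketch of the indexing arithmetic, not the Triton kernel; the helper name is hypothetical):

```python
def paged_slot(block_table, batch_idx, token_pos, page_size):
    """Resolve (batch, logical token position) -> (physical block, in-page offset).

    block_table[b][i] holds the physical block id that stores tokens
    [i * page_size, (i + 1) * page_size) of sequence b, so a lookup is
    just a divmod plus one table read -- no gather-to-contiguous copy.
    """
    logical_block, offset = divmod(token_pos, page_size)
    physical_block = block_table[batch_idx][logical_block]
    return physical_block, offset

# With page_size=16, token 35 of a sequence lives in its 3rd logical
# block (index 2), at in-page offset 3.
```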

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Sparse attention support for vLLM with environment-driven configuration, optional calibration and quantization-aware worker options
    • Paged KV-cache support in the attention kernel for larger contexts and reduced memory usage
    • vLLM backend plugin and an example server launcher for sparse/quantized deployments
  • Tests

    • New GPU tests validating paged KV-cache attention correctness, stability, and sparsity scenarios


copy-pr-bot bot commented Mar 27, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai bot commented Mar 27, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 1b021776-1b88-455e-8f04-6e7e2b7df854

📥 Commits

Reviewing files that changed from the base of the PR and between 4644bf5 and 54079b8.

📒 Files selected for processing (7)
  • examples/vllm_serve/sparse_attn_worker.py
  • examples/vllm_serve/vllm_serve_sparse_attn.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
  • pyproject.toml
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py
✅ Files skipped from review due to trivial changes (2)
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
  • pyproject.toml
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/vllm_serve/vllm_serve_sparse_attn.py

📝 Walkthrough

Walkthrough

Adds sparse-attention support end-to-end: new vLLM worker variants that load sparse/quant configs and patch Attention impls, a ModelOpt vLLM sparse backend, Triton flash-attention paged KV-cache support, an example server script, and GPU tests validating paged-KV behavior.

Changes

  • Examples: workers & server (examples/vllm_serve/sparse_attn_worker.py, examples/vllm_serve/vllm_serve_sparse_attn.py): New SparseAttnWorker and SparseQuantWorker that read SPARSE_ATTN_CFG / SPARSE_CALIB_CONFIG_PATH and patch vLLM Attention impls; new server launcher that selects the worker based on env/quant presence and propagates env vars to Ray workers. Review attention patch logic and env handling.
  • Triton attention kernel, paged KV (modelopt/torch/kernels/triton_fa.py): Extended Triton flash-attention kernel to support a paged KV cache: paged K/V tile loaders, block_table/page_size args, branching for paged vs contiguous loads; public attention(...) API extended with k_cache, v_cache, block_table, page_size. Focus review on kernel branches and new API/backward return placeholders.
  • vLLM sparse backend plugin (modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py): New ModelOptSparseAttentionImpl and ModelOptSparseAttentionBackend that invoke the paged Triton kernel with per-layer sparse kwargs; unpacks the paged KV cache and maps FlashAttentionMetadata into kernel inputs. Inspect forward path, metadata translation, and sparse_kw usage.
  • Plugin init helper import (modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py): Added import of import_plugin from utilities.
  • Tests: GPU paged KV (tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py): New CUDA/Triton-gated tests exercising the paged KV path vs a contiguous reference, NaN/Inf checks, variable lengths, multiple page sizes, sparsity interaction, and decode mode. Review helpers that build k_cache/v_cache and block_table.
  • Config: linting (pyproject.toml): Added vllm to the isort known-third-party list.
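The GPU tests have to scatter a contiguous K/V into the paged layout before exercising the kernel. A pure-Python sketch of that layout construction (real tests would build torch tensors of shape [num_blocks, page_size, num_kv_heads, head_dim]; the helper name and signature here are hypothetical):

```python
def to_paged(kv_flat, b_start_loc, b_seq_len, page_size, pad=None):
    """Scatter a packed token stream into pages plus a block_table.

    kv_flat: per-token K (or V) entries for all sequences, packed
    back-to-back; b_start_loc[b] / b_seq_len[b] delimit sequence b.
    Returns (pages, block_table): pages[p] is a page_size-long list
    (padded with `pad`), and block_table[b][i] is the physical page
    holding tokens [i*page_size, (i+1)*page_size) of sequence b.
    """
    blocks_per_seq = [(n + page_size - 1) // page_size for n in b_seq_len]
    max_blocks = max(blocks_per_seq)
    pages, block_table = [], []
    for b, n_blocks in enumerate(blocks_per_seq):
        start, n = b_start_loc[b], b_seq_len[b]
        row = []
        for i in range(n_blocks):
            chunk = list(kv_flat[start + i * page_size : start + min(n, (i + 1) * page_size)])
            row.append(len(pages))  # physical page id for this logical block
            pages.append(chunk + [pad] * (page_size - len(chunk)))
        block_table.append(row + [-1] * (max_blocks - len(row)))  # -1 = unused slot
    return pages, block_table
```

A contiguous-reference test can then run the kernel once on kv_flat and once on (pages, block_table) and compare outputs.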

Sequence Diagrams

sequenceDiagram
    participant Env as Env / JSON
    participant Worker as vLLM Worker
    participant Model as Model Loader
    participant Plugin as ModelOpt Plugin
    participant Attention as Attention Modules

    Env->>Worker: provide SPARSE_ATTN_CFG / SPARSE_CALIB_CONFIG_PATH / QUANT env
    Worker->>Model: load_model()
    Model-->>Worker: model with Attention modules
    Worker->>Plugin: set_sparse_config(sparse_cfg)
    Worker->>Attention: traverse & replace impls with ModelOptSparseAttentionImpl
    Attention-->>Worker: patched layers count
sequenceDiagram
    participant Input as Forward Input (query, kv, metadata)
    participant Impl as ModelOptSparseAttentionImpl
    participant Cache as Paged KV Cache (k_cache/v_cache, block_table)
    participant Kernel as triton_fa.attention
    participant Output as Output Tensor

    Input->>Impl: forward(query, key, value, kv_cache, attn_metadata)
    Impl->>Cache: unpack paged k_cache, v_cache, block_table, page_size
    Impl->>Impl: construct sparse_kw from per-layer config
    Impl->>Kernel: call attention(..., k_cache, v_cache, block_table, page_size, sparse_kw)
    Kernel-->>Impl: attention result
    Impl->>Output: write and return result

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title clearly describes the main objective: adding vLLM integration for ModelOpt sparse attention. It directly reflects the core feature introduced across the multiple files in this changeset.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.
  • Security Anti-Patterns: ✅ Passed. No security anti-patterns found: no torch.load with weights_only=False, no numpy.load with allow_pickle=True, no eval/exec, no trust_remote_code=True, no nosec comments. JSON deserialization from user paths is safe.


@kaix-nv kaix-nv changed the title Add vLLM integration for modelopt sparse attention [4/n] Add vLLM integration for modelopt sparse attention Mar 27, 2026
Contributor

github-actions bot commented Mar 27, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1127/

Built to branch gh-pages at 2026-03-31 00:50 UTC.
Preview will be ready when the GitHub Pages deployment is complete.


codecov bot commented Mar 27, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.19%. Comparing base (74a8694) to head (54079b8).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1127      +/-   ##
==========================================
- Coverage   70.21%   70.19%   -0.03%     
==========================================
  Files         230      230              
  Lines       26073    26073              
==========================================
- Hits        18308    18302       -6     
- Misses       7765     7771       +6     

☔ View full report in Codecov by Sentry.

@kaix-nv kaix-nv force-pushed the kaix/sparse_attn_vllm_integration branch 6 times, most recently from 26c6b3b to e4c4680 on March 28, 2026 at 23:05
@kaix-nv kaix-nv marked this pull request as ready for review March 30, 2026 21:05
@kaix-nv kaix-nv requested review from a team as code owners March 30, 2026 21:05
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (1)
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py (1)

290-336: Assert decode correctness, not just finiteness.

This case still passes if paged decode reads the wrong blocks or masks the wrong keys, because it only checks shape/NaNs. Please compare out against a contiguous decode reference here as well.

As per coding guidelines, "Write tests using pytest for all new features and examples; organize tests into tests/unit (fast CPU-based), tests/gpu (fast GPU-based), tests/gpu_megatron (Megatron-Core), tests/gpu_trtllm (TensorRT-LLM), and tests/examples (integration tests)" and "All test coverage checks in PRs must pass for new features and examples."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py` around
lines 290 - 336, The test currently only checks shape/NaNs for paged decode; add
a correctness assertion by computing a contiguous-reference decode (call the
same attention function with the original k_flat/v_flat and without
k_cache/v_cache/block_table/page_size, i.e., the non-paged code path) and
compare outputs from test_paged_decode to that reference using
torch.testing.assert_allclose (or torch.allclose with a small rtol/atol) to
ensure values match; keep the existing shape and NaN checks and reuse q_flat,
k_flat, v_flat, b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k, and
scale so the only difference is paged vs contiguous.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 137-140: The worker assumes a different kv_cache axis order than
the vllm plugin; instead of slicing kv_cache[:, 0] and kv_cache[:, 1], normalize
to the same layout used elsewhere by splitting with kv_cache.unbind(0) so
k_cache and v_cache come from that unbind; update page_size to derive from the
resulting k_cache shape (e.g., k_cache.shape[1]) and ensure any downstream uses
of k_cache/v_cache match this normalized layout (references: kv_cache, k_cache,
v_cache and the existing kv_cache.unbind(0) usage in the vllm plugin).
- Around line 297-300: The import inside compile_or_warm_up_model currently uses
a relative import ("from .fakequant_worker import _fakequant_run_prolog_worker,
quant_config") which fails when the module is loaded as a top-level module;
change it to a top-level/absolute import or dynamic import so the code works
whether loaded as a package or directly (e.g., use "from fakequant_worker import
_fakequant_run_prolog_worker, quant_config" or use
importlib.import_module("fakequant_worker") and grab the attributes). Update the
import used by compile_or_warm_up_model and any callers expecting
_fakequant_run_prolog_worker and quant_config accordingly so the class
SparseQuantWorker can be imported as a top-level module without ImportError.
- Around line 273-284: The replacement sets sliding_window=None which disables
local/sliding-window attention; update the instantiation of
ModelOptSparseAttentionImpl (the assignment to module.impl) to pass through the
original value (old_impl.sliding_window) instead of None, or add a guard that
rejects/raises when old_impl.sliding_window is non-None and unsupported; ensure
you reference old_impl.sliding_window and ModelOptSparseAttentionImpl in the
change so sliding-window behavior is preserved or explicitly handled.
- Around line 176-183: The decode call to triton_attention in
sparse_attn_worker.py is using is_causal=True which causes incorrect masking for
paged KV; change the triton_attention invocation (the call that sets
q=query[offset: offset+nd], k=query[:0], v=query[:0],
b_start_loc=dm.query_start_loc, b_seq_len=..., max_input_len=1) to pass
is_causal=False instead, matching the decode path in
modelopt/torch/kernels/hf_triton_attention.py so later cached KV tiles are not
truncated.

In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 945-958: The code currently allows paged-mode (is_paged True) but
does not disable/guard autograd, causing backward to recompute using contiguous
K/V and dummy b_start_loc_k and produce wrong gradients; update the forward
entrypoint that sets is_paged (look for the block using is_paged, b_start_loc_k
and b_start_loc) to explicitly disallow autograd in paged mode by either raising
a clear exception when torch.is_grad_enabled() (or when requires_grad on inputs)
and is_paged is True, or by wrapping the paged-mode path in torch.no_grad() and
documenting that backward is unsupported; ensure the guard references is_paged
and b_start_loc_k so callers cannot silently run backward with dummy
b_start_loc_k.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py`:
- Around line 81-99: The current loop builds a single sparse_kw from the first
enabled entry in _sparse_config["sparse_cfg"] and breaks, which ignores
layer-specific patterns; instead, for each module use name-based matching to
select the correct layer_cfg (reuse the matching logic from
examples/vllm_serve/sparse_attn_worker.py::_match_sparse_config or call that
helper) and then build a per-module sparse_kw from that matched layer_cfg
(respecting fields like sparsity_n, sparsity_m, num_sink_tokens,
dense_window_size, skip_softmax_threshold); do not break out of the loop—apply
the matched config only to the current module or stash sparse_kw on the module
instance before swapping implementations so multiple patterns in the calibration
file are handled correctly.

---

Nitpick comments:
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py`:
- Around line 290-336: The test currently only checks shape/NaNs for paged
decode; add a correctness assertion by computing a contiguous-reference decode
(call the same attention function with the original k_flat/v_flat and without
k_cache/v_cache/block_table/page_size, i.e., the non-paged code path) and
compare outputs from test_paged_decode to that reference using
torch.testing.assert_allclose (or torch.allclose with a small rtol/atol) to
ensure values match; keep the existing shape and NaN checks and reuse q_flat,
k_flat, v_flat, b_start_loc_q, b_seq_len_q, b_start_loc_k, b_seq_len_k, and
scale so the only difference is paged vs contiguous.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b27921dd-dcd2-4ce3-acc9-28816a998e1f

📥 Commits

Reviewing files that changed from the base of the PR and between 24ceba6 and e4c4680.

📒 Files selected for processing (7)
  • examples/vllm_serve/sparse_attn_worker.py
  • examples/vllm_serve/vllm_serve_sparse_attn.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
  • pyproject.toml
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py

Comment on lines +945 to +958
is_paged = k_cache is not None

# Prefill: Q/K/V are the same packed tensor, reuse Q offsets for K/V.
# Decode: K/V is a separate KV cache tensor, caller must pass explicit metadata.
if b_seq_len_k is None:
    b_seq_len_k = b_seq_len
    b_start_loc_k = b_start_loc
    max_input_len_k = max_input_len

# Paged mode: b_start_loc_k may be None (KV is in paged cache, not contiguous).
# Provide a dummy tensor so Triton can compile the tl.load (it won't be used).
if b_start_loc_k is None:
    b_start_loc_k = torch.zeros_like(b_start_loc)

Contributor


⚠️ Potential issue | 🟠 Major

Reject paged-mode autograd until backward is implemented.

is_paged only changes the forward kernel. Backward still recomputes from contiguous K/V and b_start_loc_k, which are dummy/unused in paged mode, so a backward pass can return incorrect gradients instead of failing fast.

Possible guard
         is_paged = k_cache is not None
+        if is_paged and (q.requires_grad or k.requires_grad or v.requires_grad):
+            raise NotImplementedError(
+                "Paged KV cache path is forward-only; backward is not implemented."
+            )
 
         # Prefill: Q/K/V are the same packed tensor, reuse Q offsets for K/V.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 945 - 958, The code
currently allows paged-mode (is_paged True) but does not disable/guard autograd,
causing backward to recompute using contiguous K/V and dummy b_start_loc_k and
produce wrong gradients; update the forward entrypoint that sets is_paged (look
for the block using is_paged, b_start_loc_k and b_start_loc) to explicitly
disallow autograd in paged mode by either raising a clear exception when
torch.is_grad_enabled() (or when requires_grad on inputs) and is_paged is True,
or by wrapping the paged-mode path in torch.no_grad() and documenting that
backward is unsupported; ensure the guard references is_paged and b_start_loc_k
so callers cannot silently run backward with dummy b_start_loc_k.

kaix-nv added 4 commits March 30, 2026 17:44
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
examples/vllm_serve/sparse_attn_worker.py (1)

145-152: ⚠️ Potential issue | 🟠 Major

Preserve existing sliding_window when replacing attention impl

Line 151 hardcodes sliding_window=None, which can silently change local/sliding-window attention behavior. Pass through old_impl.sliding_window (or explicitly reject unsupported non-None values).

Proposed fix
             module.impl = ModelOptSparseAttentionImpl(
                 num_heads=old_impl.num_heads,
                 head_size=old_impl.head_size,
                 scale=old_impl.scale,
                 num_kv_heads=old_impl.num_kv_heads,
                 alibi_slopes=old_impl.alibi_slopes,
-                sliding_window=None,
+                sliding_window=old_impl.sliding_window,
                 kv_cache_dtype=old_impl.kv_cache_dtype,
                 logits_soft_cap=old_impl.logits_soft_cap,
                 attn_type=old_impl.attn_type,
                 kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 145 - 152, The
construction of ModelOptSparseAttentionImpl currently hardcodes
sliding_window=None which can change attention behavior; update the initializer
in the replacement code to pass through old_impl.sliding_window (i.e., use
sliding_window=old_impl.sliding_window) or, if non-None values are unsupported,
explicitly check old_impl.sliding_window and raise an informative error before
creating ModelOptSparseAttentionImpl; reference ModelOptSparseAttentionImpl,
old_impl, and old_impl.sliding_window to locate and fix the code.
🧹 Nitpick comments (1)
examples/vllm_serve/sparse_attn_worker.py (1)

111-119: Remove or wire _match_sparse_config to avoid dead-path drift

_match_sparse_config is currently unused, which makes behavior harder to reason about and can drift from real matching logic. Either use it in patching/selection flow or remove it until needed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 111 - 119, The helper
function _match_sparse_config is unused and creates dead-path drift; either
remove this function or wire it into the sparse patch/selection flow by
replacing the current pattern-matching logic with a call to
_match_sparse_config(module_name, sparse_cfg) (or call it from wherever sparse
layer configs are looked up) so that matching behavior is centralized; update
any callers that currently duplicate pattern checks to use _match_sparse_config
and remove dead duplicates if you choose to keep it.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 81-83: The code currently returns whatever getattr(mtsa, cfg_name,
None) yields (cfg) and may hand back non-mapping objects; update the getter to
validate that cfg is a dict before returning (use isinstance(cfg, dict)) and
otherwise raise a clear error (e.g., ValueError) indicating that the symbol
named by cfg_name in modelopt.torch.sparsity.attention_sparsity must be a dict;
reference the getattr(mtsa, cfg_name, None) call and the cfg variable to locate
the change.
- Around line 92-106: The _load_sparse_config function currently trusts
arbitrary JSON from SPARSE_CALIB_CONFIG_PATH; update it to validate the loaded
object and each layer_cfg: assert the top-level JSON is a dict, allowed
top-level keys are strings and either "calibration" or pattern names, and each
layer_cfg is a dict before applying defaults; enforce allowed keys (e.g.,
"method", "backend", "enable", numeric sparsity params) and bounds for numeric
fields (e.g., sparsity percentages 0–100, integer layer indices >=0, and
reasonable max limits) and reject or clamp out-of-range values, raising a clear
exception on invalid schema; keep the existing defaults
(method="triton_sparse_softmax", backend="triton", enable=True) for valid
entries and ensure sparse_cfg["default"] = {"enable": False} remains set.

---

Duplicate comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 145-152: The construction of ModelOptSparseAttentionImpl currently
hardcodes sliding_window=None which can change attention behavior; update the
initializer in the replacement code to pass through old_impl.sliding_window
(i.e., use sliding_window=old_impl.sliding_window) or, if non-None values are
unsupported, explicitly check old_impl.sliding_window and raise an informative
error before creating ModelOptSparseAttentionImpl; reference
ModelOptSparseAttentionImpl, old_impl, and old_impl.sliding_window to locate and
fix the code.

---

Nitpick comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 111-119: The helper function _match_sparse_config is unused and
creates dead-path drift; either remove this function or wire it into the sparse
patch/selection flow by replacing the current pattern-matching logic with a call
to _match_sparse_config(module_name, sparse_cfg) (or call it from wherever
sparse layer configs are looked up) so that matching behavior is centralized;
update any callers that currently duplicate pattern checks to use
_match_sparse_config and remove dead duplicates if you choose to keep it.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: cf1d8866-f2b3-4adf-9b4a-eb396967c80e

📥 Commits

Reviewing files that changed from the base of the PR and between e4c4680 and 4644bf5.

📒 Files selected for processing (1)
  • examples/vllm_serve/sparse_attn_worker.py

Comment on lines +81 to +83
cfg = getattr(mtsa, cfg_name, None)
if cfg is not None:
    return cfg
Contributor


⚠️ Potential issue | 🟠 Major

Validate preset object type before returning it

If SPARSE_ATTN_CFG matches a non-dict symbol in modelopt.torch.sparsity.attention_sparsity, cfg is returned as-is and later consumed as a mapping, which can crash at runtime. Add an explicit isinstance(cfg, dict) guard and fail fast with a clear error.

Proposed fix
     cfg = getattr(mtsa, cfg_name, None)
     if cfg is not None:
-        return cfg
+        if not isinstance(cfg, dict):
+            raise ValueError(
+                f"Invalid sparse config preset '{cfg_name}': expected dict, got {type(cfg).__name__}."
+            )
+        return cfg
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 81 - 83, The code
currently returns whatever getattr(mtsa, cfg_name, None) yields (cfg) and may
hand back non-mapping objects; update the getter to validate that cfg is a dict
before returning (use isinstance(cfg, dict)) and otherwise raise a clear error
(e.g., ValueError) indicating that the symbol named by cfg_name in
modelopt.torch.sparsity.attention_sparsity must be a dict; reference the
getattr(mtsa, cfg_name, None) call and the cfg variable to locate the change.

Comment on lines +92 to +106
def _load_sparse_config(path: str) -> dict:
    """Load offline calibration config JSON."""
    with open(path) as f:
        calib_cfg = json.load(f)

    sparse_cfg = {}
    for pattern, layer_cfg in calib_cfg.items():
        if pattern == "calibration":
            sparse_cfg[pattern] = layer_cfg
            continue
        layer_cfg.setdefault("method", "triton_sparse_softmax")
        layer_cfg.setdefault("backend", "triton")
        layer_cfg.setdefault("enable", True)
        sparse_cfg[pattern] = layer_cfg
    sparse_cfg["default"] = {"enable": False}
Contributor


⚠️ Potential issue | 🟠 Major

Harden calibration JSON parsing with schema and bounds checks

SPARSE_CALIB_CONFIG_PATH is env-driven input, but Line 94–106 accepts arbitrary JSON structure and values without validation. This can propagate malformed sparsity params into kernel calls and create avoidable failure/DoS risk. Validate top-level/object types, allowed keys, and integer ranges before applying defaults.

Proposed fix
 def _load_sparse_config(path: str) -> dict:
     """Load offline calibration config JSON."""
-    with open(path) as f:
+    with open(path, encoding="utf-8") as f:
         calib_cfg = json.load(f)
+    if not isinstance(calib_cfg, dict):
+        raise ValueError("Calibration config must be a JSON object mapping pattern -> layer config.")
 
     sparse_cfg = {}
     for pattern, layer_cfg in calib_cfg.items():
+        if not isinstance(pattern, str):
+            raise ValueError("Calibration config keys must be strings.")
         if pattern == "calibration":
             sparse_cfg[pattern] = layer_cfg
             continue
+        if not isinstance(layer_cfg, dict):
+            raise ValueError(f"Layer config for pattern '{pattern}' must be an object.")
+        for int_key in ("sparsity_n", "sparsity_m", "num_sink_tokens", "dense_window_size"):
+            if int_key in layer_cfg and (
+                not isinstance(layer_cfg[int_key], int) or layer_cfg[int_key] < 0
+            ):
+                raise ValueError(f"Invalid '{int_key}' for pattern '{pattern}': {layer_cfg[int_key]!r}")
         layer_cfg.setdefault("method", "triton_sparse_softmax")
         layer_cfg.setdefault("backend", "triton")
         layer_cfg.setdefault("enable", True)
         sparse_cfg[pattern] = layer_cfg

As per coding guidelines, "Validate inputs and enforce limits to reduce resource-exhaustion/DoS risk (e.g., file sizes, expected schema/shape for sparse config/calibration JSON)."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 92 - 106, The
_load_sparse_config function currently trusts arbitrary JSON from
SPARSE_CALIB_CONFIG_PATH; update it to validate the loaded object and each
layer_cfg: assert the top-level JSON is a dict, allowed top-level keys are
strings and either "calibration" or pattern names, and each layer_cfg is a dict
before applying defaults; enforce allowed keys (e.g., "method", "backend",
"enable", numeric sparsity params) and bounds for numeric fields (e.g., sparsity
percentages 0–100, integer layer indices >=0, and reasonable max limits) and
reject or clamp out-of-range values, raising a clear exception on invalid
schema; keep the existing defaults (method="triton_sparse_softmax",
backend="triton", enable=True) for valid entries and ensure
sparse_cfg["default"] = {"enable": False} remains set.

@kaix-nv kaix-nv force-pushed the kaix/sparse_attn_vllm_integration branch from 4644bf5 to 54079b8 on March 31, 2026 at 00:46
@kaix-nv kaix-nv requested review from Edwardf0t1 and jingyu-ml March 31, 2026 01:49
Collaborator

@cjluo-nv cjluo-nv left a comment


Summary: Adds paged KV cache support to the ModelOpt Triton flash attention kernel, a vLLM sparse attention plugin, worker classes for vLLM integration, and GPU tests for the paged KV path.

Issues Found:

  1. [Correctness] Backward pass silently broken for paged KV — In triton_fa.py, forward() saves k, v to ctx for backward, but in paged mode these are dummy/empty tensors (e.g., k_dummy = torch.empty(0, ...) from the vLLM plugin). If .backward() is ever called on paged-mode output, gradients for dK/dV will be computed against the dummy tensors, producing silently incorrect results. The backward should either raise NotImplementedError("Backward not supported for paged KV cache") when is_paged=True, or the limitation should be clearly documented in the attention() docstring. The current code just adds None return placeholders for the 4 new args without any guard.

  2. [Correctness] Unused import_plugin import in plugins/__init__.py — The diff adds from modelopt.torch.utils import import_plugin but it is never used in the file. The existing __init__.py doesn't use this import, and vllm.py imports directly from vLLM. This is dead code that should be removed, or if the intent was to use import_plugin for conditional vLLM import (as other plugin modules do), that wiring is missing.

  3. [Correctness] _build_sparse_config fallback logic is confusing — sparse_attn_worker.py:78-86: getattr(mtsa, cfg_name, None) is tried first, but then SPARSE_SOFTMAX_DEFAULT falls through to the hardcoded _DEFAULT_SPARSE_CFG dict. If mtsa actually defines SPARSE_SOFTMAX_DEFAULT in the future, the getattr path would return it and the hardcoded default would never be used, leading to silent behavior divergence. Consider making the precedence explicit or removing the duplication.

  4. [Readability] Duplicated paged V-tile loading — In the triton_fa.py _attn_fwd kernel, _load_paged_v_tile is called with identical arguments in two separate branches (skip-softmax path around line 476 and standard path around line 518). While Triton JIT constraints may require this, the duplicated 20-line call blocks are a readability concern. A comment explaining why the duplication is necessary would help.

  5. [Tests] No backward/gradient test for paged mode — test_triton_fa_paged.py only tests forward correctness. Given that paged mode changes the autograd Function's forward signature and backward is not updated to support paged KV, there should be at minimum a test asserting that backward raises an error (if a guard is added per issue #1), or this should be explicitly documented as inference-only.

  6. [Tests] No integration test for vLLM plugin — ModelOptSparseAttentionImpl.forward() and ModelOptSparseAttentionBackend have no test coverage. These are the most integration-critical new classes. Even a mock-based unit test validating the metadata translation logic (cu_seqlens_q → b_start_loc, seq_lens → b_seq_len_k) would catch regressions.

  7. [Tests] Inconsistent b_start_loc_k handling across tests — test_paged_matches_contiguous passes explicit b_start_loc_k=locs_k, while test_paged_no_nan omits it (relying on the dummy-zeros fallback). Both should use the same calling convention to avoid masking bugs in the fallback path.

  8. [Readability] if threshold: truthiness check — sparse_attn_worker.py:146: if threshold: evaluates False for both None and 0.0. While 0.0 correctly means "disabled", this is subtle. if threshold is not None would be clearer about intent (let the kernel handle the 0.0 case).
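The distinction in issue 8 above can be made concrete with a small helper; the name `resolve_skip_threshold` is hypothetical, chosen only to illustrate the suggested `is not None` check.

```python
def resolve_skip_threshold(threshold):
    """Return the threshold to pass to the kernel, or None when unset.

    `if threshold:` would treat 0.0 the same as None; the explicit None
    check forwards 0.0 unchanged and lets the kernel decide what it means.
    """
    if threshold is not None:
        return float(threshold)
    return None
```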

Suggestions:

  • Consider adding requires_grad=False to k_dummy and v_dummy in the vLLM plugin to make the inference-only intent explicit and catch accidental backward calls early.
  • The _DEFAULT_SPARSE_CFG hardcoded in the worker could reference mtsa constants instead, reducing drift risk.
  • test_paged_decode uses q_flat with shape [batch, num_heads, head_dim] (3D) rather than the expected [total_q_tokens, num_heads, head_dim]. This works because batch * 1 = batch tokens, but the shape semantics are confusing — using q_flat.reshape(batch, num_heads, head_dim) explicitly or adding a comment would clarify.
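Making the config precedence explicit (per issue 3 and the suggestion above) could look like the sketch below. The function name `resolve_sparse_config` and modeling the modelopt-provided configs as a plain dict are assumptions for illustration; the real worker reads them off the `mtsa` module with `getattr`.

```python
_DEFAULT_SPARSE_CFG = {"default": {"enable": False}}


def resolve_sparse_config(available_cfgs: dict, cfg_name: str) -> dict:
    """Look up a named sparse-attention config with explicit precedence."""
    # 1. A config published by modelopt (modeled here as a dict of named
    #    configs) always wins, so a future upstream SPARSE_SOFTMAX_DEFAULT
    #    would take precedence over the hardcoded fallback.
    cfg = available_cfgs.get(cfg_name)
    if cfg is not None:
        return cfg
    # 2. Only the documented fallback name maps to the hardcoded default.
    if cfg_name == "SPARSE_SOFTMAX_DEFAULT":
        return _DEFAULT_SPARSE_CFG
    raise ValueError(f"Unknown sparse attention config: {cfg_name!r}")
```

Raising on unknown names (rather than silently falling back) surfaces typos in SPARSE_ATTN_CFG immediately.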

Overall Assessment: The core kernel extension (paged KV tile loaders + IS_PAGED branching) is well-structured and the tests verify forward correctness across multiple configurations. The main blocking concern is the silent backward-pass breakage for paged mode — this needs at minimum a guard or documentation since the function is part of the public attention() API with autograd support. The unused import is a minor cleanup. The vLLM plugin lacks test coverage but is in examples/ territory so is less critical.
