Flash-ColReduce provides highly optimized Triton kernels for computing column-wise reductions of the attention matrix (sum, mean, or max) without materializing the full $M \times N$ matrix.
This primitive is essential for KV-cache pruning, token importance estimation, and attention analysis in Large Language Models (LLMs) and Vision-Language Models (VLMs). It powers the visual token pruning in SparseVILA.
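As a rough illustration of the KV-cache pruning use case, the sketch below keeps only the highest-scoring key/value positions per head. `prune_kv` is a hypothetical helper and not part of Flash-ColReduce; `scores` stands for any per-key importance tensor of shape `(B, H, N)`, such as the output of `flash_colreduce(q, k, reduction="sum")`.

```python
import torch


def prune_kv(k, v, scores, keep):
    """Keep the `keep` highest-scoring positions per (batch, head).

    Hypothetical helper for illustration only.
    k, v:   (B, H, N, D) cached keys and values
    scores: (B, H, N) per-key importance (e.g. a column-wise attention sum)
    """
    idx = scores.topk(keep, dim=-1).indices   # (B, H, keep)
    idx = idx.sort(dim=-1).values             # restore original token order

    def take(x):
        # Broadcast the kept indices over the feature dimension, then gather.
        gather_idx = idx.unsqueeze(-1).expand(-1, -1, -1, x.shape[-1])
        return x.gather(2, gather_idx)

    return take(k), take(v), idx
```

Sorting the kept indices preserves the original token order, which matters when positional information is tied to cache position.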
- 🚀 Efficient: Fused kernels compute column reductions in $O(N)$ memory.
- 🧩 Flexible: Supports causal and non-causal attention with irregular shapes ($M \neq N$).
- ✅ Exact: Uses online softmax for numerical precision and correct causal masking.

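For reference, the quantity being reduced can be written out in plain PyTorch. This sketch is not part of the library: `naive_colreduce` is a hypothetical helper, and it assumes the usual $1/\sqrt{d}$ softmax scaling and non-causal attention. It materializes the full $M \times N$ matrix, which is the cost the fused kernels are designed to avoid.

```python
import math

import torch


def naive_colreduce(q, k, reduction="sum"):
    """Reference column reduction that builds the full (M, N) attention matrix.

    Hypothetical helper for illustration only.
    q: (..., M, D), k: (..., N, D)  ->  (..., N)
    """
    scale = 1.0 / math.sqrt(q.shape[-1])  # assumption: standard softmax scaling
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)  # (..., M, N)
    if reduction == "sum":
        return attn.sum(dim=-2)    # total attention each key receives
    if reduction == "mean":
        return attn.mean(dim=-2)   # average attention over the M queries
    if reduction == "max":
        return attn.amax(dim=-2)   # strongest attention from any single query
    raise ValueError(f"unknown reduction: {reduction!r}")
```

Because every softmax row sums to 1, the `"sum"` output for each batch and head adds up to $M$ across all keys, which makes a convenient sanity check.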
- Python: 3.10+
- PyTorch: 2.1+ (with CUDA support)
- Triton: 3.0.0+
- GPU: NVIDIA GPU with Compute Capability 8.0+ (Ampere or newer recommended)
Install from PyPI:
pip install flash-colreduce

Or build from source:
git clone https://github.com/z-lab/flash-colreduce.git
cd flash-colreduce
pip install -e .

Compute a column-wise reduction of the attention matrix over the query dimension:
import torch
from flash_colreduce import flash_colreduce
q = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)
k = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)
flash_colreduce(q, k, reduction="sum") # Shape: (8, 16, 512)
flash_colreduce(q, k, reduction="mean") # Shape: (8, 16, 512)
flash_colreduce(q, k, reduction="max") # Shape: (8, 16, 512)

Handle autoregressive attention where the query and key lengths differ ($M \neq N$):
import torch
from flash_colreduce import flash_colreduce
q = torch.randn(1, 32, 128, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.float16)
flash_colreduce(q, k, reduction="sum", is_causal=True) # Shape: (1, 32, 4096)
flash_colreduce(q, k, reduction="mean", is_causal=True) # Shape: (1, 32, 4096)
flash_colreduce(q, k, reduction="max", is_causal=True) # Shape: (1, 32, 4096)

Flash-ColReduce achieves significant speedups and memory savings over naïve implementations. By fusing softmax and reduction into a single kernel, it avoids writing the full $M \times N$ attention matrix to memory.
Benchmarked on an NVIDIA RTX Pro 6000 Blackwell with FP16 precision.
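To make the fusion idea concrete, here is a plain PyTorch sketch of one way to compute a column-wise sum while only ever holding one tile of scores. It is an illustration under assumptions, not the library's kernel: `tiled_colsum` is a hypothetical name, the $1/\sqrt{d}$ scale and the two-pass structure are assumed, and the actual Triton kernels may be organized differently. The first pass collects online softmax statistics per query row; the second recomputes each tile, normalizes it, and reduces over the query axis.

```python
import math

import torch


def tiled_colsum(q, k, block=64):
    """Column-wise sum of softmax(q @ k.T / sqrt(d)) for q (M, d) and k (N, d).

    Hypothetical sketch: only (block x block) score tiles exist at any time.
    """
    M, d = q.shape
    N = k.shape[0]
    scale = 1.0 / math.sqrt(d)  # assumption: standard softmax scaling

    # Pass 1: online softmax, tracking a running max and normalizer per row.
    row_max = torch.full((M,), float("-inf"), dtype=q.dtype, device=q.device)
    row_sum = torch.zeros(M, dtype=q.dtype, device=q.device)
    for m0 in range(0, M, block):
        rows = slice(m0, m0 + block)
        for n0 in range(0, N, block):
            s = (q[rows] @ k[n0:n0 + block].T) * scale
            new_max = torch.maximum(row_max[rows], s.amax(dim=1))
            # Rescale the old normalizer to the new max, then add this tile.
            row_sum[rows] = (row_sum[rows] * torch.exp(row_max[rows] - new_max)
                             + torch.exp(s - new_max[:, None]).sum(dim=1))
            row_max[rows] = new_max
    lse = row_max + torch.log(row_sum)  # log-sum-exp of each score row

    # Pass 2: recompute each tile, normalize with lse, and sum over queries.
    out = torch.zeros(N, dtype=q.dtype, device=q.device)
    for n0 in range(0, N, block):
        cols = slice(n0, n0 + block)
        for m0 in range(0, M, block):
            rows = slice(m0, m0 + block)
            s = (q[rows] @ k[cols].T) * scale
            out[cols] += torch.exp(s - lse[rows, None]).sum(dim=0)
    return out
```

Beyond the inputs, this sketch stores only the per-row statistics, the output vector, and a single tile, so its extra memory grows as $O(M + N)$ rather than $O(MN)$. The price is that the scores are computed twice.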
Run the tests:

pip install -e ".[test]"
pytest -v

Run the benchmarks:

pip install -e ".[bench]"
python benchmarks/run.py

If you use Flash-ColReduce in your research, please cite the SparseVILA paper:
@inproceedings{khaki2025sparsevila,
title = {{SparseVILA: Decoupling Visual Sparsity for Efficient VLM Inference}},
author = {Khaki, Samir and Guo, Junxian and Tang, Jiaming and Yang, Shang and Chen, Yukang and Plataniotis, Konstantinos N and Lu, Yao and Han, Song and Liu, Zhijian},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
year = {2025}
}

- FlashAttention: The tiling and online softmax approach is heavily inspired by FlashAttention.
- SparseVILA: The original project that motivated this primitive.