Flash-ColReduce

PyPI License Python 3.10+

Flash-ColReduce provides highly optimized Triton kernels for computing column-wise reductions of the attention matrix (sum, mean, or max) without materializing the full $O(N^2)$ attention weights.

This primitive is essential for KV-cache pruning, token importance estimation, and attention analysis in Large Language Models (LLMs) and Vision-Language Models (VLMs). It powers the visual token pruning in SparseVILA.

Highlights

  • 🚀 Efficient: Fused kernels compute column reductions in $O(N)$ memory.
  • 🧩 Flexible: Supports causal and non-causal attention with irregular shapes ($M \neq N$).
  • ✅ Exact: Uses online softmax for numerical precision and correct causal masking.
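The online softmax used for exactness is the same streaming recurrence popularized by FlashAttention: a running maximum and a rescaled normalizer are updated per score, so no exponential ever overflows. A minimal scalar sketch of the idea (illustrative only, not the Triton kernel):

```python
import math

def online_softmax(scores):
    """Numerically stable softmax computed with one streaming pass of statistics.

    Maintains a running maximum `m` and a normalizer `l` that is rescaled
    whenever a new maximum appears -- the per-tile recurrence FlashAttention uses.
    """
    m = float("-inf")  # running maximum of scores seen so far
    l = 0.0            # running sum of exp(s - m)
    for s in scores:
        m_new = max(m, s)
        # Rescale the old normalizer to the new maximum, then add the new term.
        l = l * math.exp(m - m_new) + math.exp(s - m_new)
        m = m_new
    return [math.exp(s - m) / l for s in scores]
```

Because only `(m, l)` is carried between tiles, the kernel never needs the full row of attention weights in memory at once.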

Prerequisites

  • Python: 3.10+
  • PyTorch: 2.1+ (with CUDA support)
  • Triton: 3.0.0+
  • GPU: NVIDIA GPU with Compute Capability 8.0+ (Ampere or newer recommended)

Installation

Install from PyPI:

pip install flash-colreduce

Or build from source:

git clone https://github.com/z-lab/flash-colreduce.git
cd flash-colreduce
pip install -e .

Usage

1. Non-Causal Attention

Compute a column-wise reduction of the attention matrix over the query dimension.

import torch
from flash_colreduce import flash_colreduce

q = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)
k = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)

flash_colreduce(q, k, reduction="sum")  # Shape: (8, 16, 512)
flash_colreduce(q, k, reduction="mean")  # Shape: (8, 16, 512)
flash_colreduce(q, k, reduction="max")  # Shape: (8, 16, 512)

2. Causal Attention

Handle autoregressive attention where $M \neq N$. The kernel applies a right-aligned causal mask matching KV-cached decoding behavior.

import torch
from flash_colreduce import flash_colreduce

q = torch.randn(1, 32, 128, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.float16)

flash_colreduce(q, k, reduction="sum", is_causal=True)  # Shape: (1, 32, 4096)
flash_colreduce(q, k, reduction="mean", is_causal=True)  # Shape: (1, 32, 4096)
flash_colreduce(q, k, reduction="max", is_causal=True)  # Shape: (1, 32, 4096)
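For sanity-checking, the outputs above can be compared against a naive PyTorch reference that does materialize the full attention matrix. The sketch below reflects the semantics described in this README (scaled dot-product scores, right-aligned causal mask); the library's own tests are the authoritative definition:

```python
import torch

def naive_colreduce(q, k, reduction="sum", is_causal=False):
    # Reference only: materializes the full (B, H, M, N) attention
    # matrix, so it costs O(N^2) memory.
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * scale
    if is_causal:
        M, N = scores.shape[-2:]
        # Right-aligned mask: query i attends to keys j <= i + (N - M),
        # matching KV-cached decoding where the cache holds N >= M keys.
        keep = torch.ones(M, N, dtype=torch.bool, device=q.device).tril(N - M)
        scores = scores.masked_fill(~keep, float("-inf"))
    attn = scores.softmax(dim=-1)          # (B, H, M, N)
    if reduction == "sum":
        return attn.sum(dim=-2)            # (B, H, N)
    if reduction == "mean":
        return attn.mean(dim=-2)
    return attn.amax(dim=-2)
```

Since each softmax row sums to one, the `"sum"` reduction always totals $M$ across the key dimension, which makes a quick correctness check.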

Performance

Flash-ColReduce achieves significant speedups and memory savings over naïve implementations. By fusing softmax and reduction into a single kernel, it avoids writing the $B \times H \times M \times N$ attention matrix to GPU memory.
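To put rough numbers on that, consider an illustrative prefill shape (these shapes are examples, not benchmark settings): the materialized attention matrix dwarfs the $O(N)$ reduction output.

```python
# Illustrative memory footprint: B=8 batches, H=32 heads, M=N=4096 tokens.
B, H, M, N = 8, 32, 4096, 4096

attn_matrix_bytes = B * H * M * N * 2  # full FP16 attention weights
reduced_bytes = B * H * N * 4          # one FP32 value per key column

print(f"materialized: {attn_matrix_bytes / 2**30:.1f} GiB")  # 8.0 GiB
print(f"fused output: {reduced_bytes / 2**20:.1f} MiB")      # 4.0 MiB
```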

Benchmark results on an NVIDIA RTX Pro 6000 Blackwell with FP16 precision.

Development

Running Tests

pip install -e ".[test]"
pytest -v

Running Benchmarks

pip install -e ".[bench]"
python benchmarks/run.py

Citation

If you use Flash-ColReduce in your research, please cite the SparseVILA paper:

@inproceedings{khaki2025sparsevila,
  title = {{SparseVILA: Decoupling Visual Sparsity for Efficient VLM Inference}},
  author = {Khaki, Samir and Guo, Junxian and Tang, Jiaming and Yang, Shang and Chen, Yukang and Plataniotis, Konstantinos N and Lu, Yao and Han, Song and Liu, Zhijian},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year = {2025}
}

License

MIT License

Acknowledgments

  • FlashAttention: The tiling and online softmax approach is heavily inspired by FlashAttention.
  • SparseVILA: The original project that motivated this primitive.
