# Hybrid Gate Selector: 3x speedup over Transformer baseline (T4 GPU benchmark) #889
## Motivation
Transformer attention has quadratic compute and memory cost in sequence length.
This proposal routes tokens dynamically: low-importance tokens
go through a linear-cost SSM (Mamba), while critical tokens go through full attention.
The gate is learned, not static; tokens are redirected, not discarded.
## Architecture
```
Input tokens
      ↓
[Gate Selector] ← small learned network (Linear → GELU → Linear → Sigmoid)
      ↓
score < 0.5 → Mamba SSM  (linear cost, ~80% of tokens)
score ≥ 0.5 → Attention  (quadratic cost, ~20% of tokens)
      ↓
[Weighted fusion by gate score]
      ↓
Output
```
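The diagram above can be sketched as a single hybrid layer. This is a minimal, illustrative sketch, not the issue's implementation: the `HybridGateLayer` name, head count, and the dense fusion (computing both paths for every token) are assumptions, and the SSM path is a stand-in `nn.Linear`, matching the placeholder noted in the benchmark section.

```python
import torch
import torch.nn as nn

class HybridGateLayer(nn.Module):
    """Sketch: score each token with a small gate, run a cheap linear
    path and an attention path, then fuse the two by gate score."""

    def __init__(self, d_model, n_heads=4, threshold=0.5):
        super().__init__()
        self.threshold = threshold
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model // 4, bias=False),
            nn.GELU(),
            nn.Linear(d_model // 4, 1, bias=False),
        )
        self.ssm = nn.Linear(d_model, d_model)  # placeholder for a real Mamba block
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, x):                        # x: (batch, seq, d_model)
        scores = torch.sigmoid(self.gate(x))     # (batch, seq, 1), per-token weight
        cheap = self.ssm(x)                      # linear-cost path
        heavy, _ = self.attn(x, x, x)            # quadratic-cost path
        # weighted fusion by gate score, as in the diagram
        return scores * heavy + (1.0 - scores) * cheap

x = torch.randn(2, 16, 64)
out = HybridGateLayer(64)(x)
print(out.shape)  # torch.Size([2, 16, 64])
```

Note that this dense sketch runs both paths on all tokens, so it shows only the fusion; the reported speedup requires a dispatch variant that gathers just the tokens with `score >= threshold` into the attention call.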
## Benchmark (T4 GPU, PyTorch 2.5.1+cu121)
| seq_len | Hybrid | Transformer baseline | Speedup |
|---|---|---|---|
| 128 | 3.20ms | 10.43ms | 3.26x |
| 512 | 7.63ms | 25.28ms | 3.31x |
| 2048 | 21.68ms | 49.87ms | 2.30x |
Note: the Mamba block is currently a linear placeholder.
Speedup is expected to increase further with a real `mamba-ssm` SSM block.
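The issue does not include the timing script; a minimal sketch of how such per-forward latencies are typically collected (warmup counts and function name are assumptions), synchronizing on CUDA so GPU kernels are fully timed:

```python
import time
import torch

def bench_ms(module, x, warmup=10, iters=50):
    """Average forward latency of `module(x)` in milliseconds."""
    with torch.no_grad():
        for _ in range(warmup):              # warm up kernels and caches
            module(x)
        if x.device.type == "cuda":
            torch.cuda.synchronize()         # drain pending GPU work
        start = time.perf_counter()
        for _ in range(iters):
            module(x)
        if x.device.type == "cuda":
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3

layer = torch.nn.Linear(64, 64)
x = torch.randn(8, 128, 64)
ms = bench_ms(layer, x)
print(f"{ms:.3f} ms")
```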
## Code
Gate selector core (full code available on request):
```python
import torch
import torch.nn as nn

class GateSelector(nn.Module):
    def __init__(self, d_model, threshold=0.5):
        super().__init__()
        self.threshold = threshold
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model // 4, bias=False),
            nn.GELU(),
            nn.Linear(d_model // 4, 1, bias=False),
        )

    def forward(self, x):
        # per-token routing score in (0, 1)
        scores = torch.sigmoid(self.gate(x))
        # boolean mask: True -> route token to attention
        mask = scores.squeeze(-1) >= self.threshold
        return scores, mask
```
## Next steps
- Replace placeholder with a real `mamba-ssm` block
- Train gate on a downstream task to learn optimal token routing
- Benchmark on long-context tasks (≥4096 tokens)
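On the gate-training item: because the fusion is a soft blend by gate score, the gate is differentiable end to end. A hedged sketch of one way to train it (the toy classification task, score-weighted pooling, and the load-balancing term steering mean gate activation toward the ~20% attention budget are all illustrative assumptions, not from the issue):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, target_attn_frac = 64, 0.2  # ~20% of tokens to attention, per the diagram
gate = nn.Sequential(
    nn.Linear(d_model, d_model // 4, bias=False),
    nn.GELU(),
    nn.Linear(d_model // 4, 1, bias=False),
)
head = nn.Linear(d_model, 10)        # toy downstream classifier (assumption)
opt = torch.optim.Adam(list(gate.parameters()) + list(head.parameters()), lr=1e-3)

x = torch.randn(8, 32, d_model)      # toy batch
y = torch.randint(0, 10, (8,))

scores = torch.sigmoid(gate(x))                               # (8, 32, 1)
pooled = (scores * x).sum(1) / scores.sum(1).clamp_min(1e-6)  # score-weighted pooling
task_loss = F.cross_entropy(head(pooled), y)
# auxiliary term: keep mean gate activation near the attention budget
balance_loss = (scores.mean() - target_attn_frac) ** 2
loss = task_loss + 0.1 * balance_loss
loss.backward()
opt.step()
print(f"loss = {loss.item():.3f}")
```

The balance term matters because without it a task loss alone can push the gate to send everything through attention, erasing the speedup.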
## Discussion
Happy to collaborate or provide the full implementation.
This could integrate as an optional hybrid layer in existing Mamba architectures.