Skip to content

Hybrid Gate Selector: 3x speedup over Transformer baseline (T4 GPU benchmark) #889

@LuisAlbertoMK

Description

@LuisAlbertoMK

Motivation

Transformers have quadratic memory cost with sequence length.
This proposal routes tokens dynamically: low-importance tokens
go through a linear SSM (Mamba), critical tokens through full Attention.
The gate is learned, not static — tokens are redirected, not discarded.

Architecture

Input tokens
     ↓
[Gate Selector]  ← small learned network (Linear → GELU → Linear → Sigmoid)
     ↓
score < 0.5 → Mamba SSM   (linear cost, ~80% of tokens)
score ≥ 0.5 → Attention   (quadratic cost, ~20% of tokens)
     ↓
[Weighted fusion by gate score]
     ↓
Output

Benchmark — GPU T4, PyTorch 2.5.1+cu121

seq_len Hybrid Transformer baseline Speedup
128 3.20ms 10.43ms 3.26x
512 7.63ms 25.28ms 3.31x
2048 21.68ms 49.87ms 2.30x

Note: Mamba block is currently a linear placeholder.
Speedup expected to increase further with mamba-ssm real SSM block.

Code

Gate selector core (full code available on request):

class GateSelector(nn.Module):
    def __init__(self, d_model, threshold=0.5):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(d_model, d_model//4, bias=False),
            nn.GELU(),
            nn.Linear(d_model//4, 1, bias=False))
    def forward(self, x):
        scores = torch.sigmoid(self.gate(x))
        mask = scores.squeeze(-1) >= self.threshold
        return scores, mask

Next steps

  • Replace placeholder with real mamba-ssm block
  • Train gate on downstream task to learn optimal token routing
  • Benchmark on long-context tasks (≥4096 tokens)

Discussion

Happy to collaborate or provide full implementation.
This could integrate as an optional hybrid layer in existing Mamba architectures.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions