
Fused Add + RMSNorm pattern #55

Open
AndreSlavescu wants to merge 2 commits into Dao-AILab:main from AndreSlavescu:fused-add-rmsnorm

Conversation

@AndreSlavescu AndreSlavescu commented Nov 9, 2025

Fused Add + RMSNorm pattern found in Qwen3 Decoder Layer.

https://github.com/huggingface/transformers/blob/0dc2df5ddafe3cb5824ad24e85beba13e0aa6726/src/transformers/models/qwen3/modeling_qwen3.py#L271
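For reference, the pattern being fused is the decoder layer's residual add followed by RMSNorm. A minimal eager-mode sketch of the math (the function name and signature here are illustrative, not the kernel's actual API):

```python
import torch

def fused_add_rmsnorm_reference(hidden_states, residual, weight, eps=1e-6):
    # Residual add: the updated residual is also returned, since it feeds
    # the next layer's skip connection.
    residual = residual + hidden_states
    # RMSNorm computed in fp32 for stability, then cast back to the input
    # dtype (matching the bf16 inputs / fp32 weight setup benchmarked below).
    x = residual.to(torch.float32)
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    hidden_states = (weight * x).to(residual.dtype)
    return hidden_states, residual
```

The fused kernel does the add, the reduction, and the scale in one pass, so the updated residual does not take an extra round-trip through global memory between two separate ops.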

Hardware Setup:
H100 80GB PCI-e
CUDA Driver: 570.172.08
CUDA Runtime Version: 12.9

Benchmarks compared against Quack RMSNorm with residual path, a PyTorch eager baseline, and a PyTorch compiled baseline:

Forward Bench:

python3 benchmarks/benchmark_fused_add_rmsnorm.py
=== Fused Add + RMSNorm Forward Benchmark ===
Tensor dimensions: [32768, 32768]
Input / residual dtype: torch.bfloat16
Input tensor shapes:
  residual      : torch.Size([32768, 32768]), dtype: torch.bfloat16
  hidden_states : torch.Size([32768, 32768]), dtype: torch.bfloat16
  weight        : torch.Size([32768]), dtype: torch.float32

Executing Fused Add + RMSNorm kernel...
Fused kernel execution time: 3.5321 ms
Fused kernel mem throughput: 1824.02 GB/s

Executing RMSNorm kernel with residual path...
RMSNorm kernel execution time: 4.7570 ms
RMSNorm kernel mem throughput: 1354.33 GB/s

Executing PyTorch eager reference...
PyTorch eager reference execution time: 28.2111 ms
PyTorch eager reference mem throughput: 228.37 GB/s

Executing PyTorch compiled reference...
PyTorch compiled reference execution time: 3.3787 ms
PyTorch compiled reference mem throughput: 1906.85 GB/s

Comparisons:
Fused Add RMSNorm Forward Kernel vs RMSNorm kernel with Residual Path:   1.35x speedup
Fused Add RMSNorm Forward Kernel vs PyTorch compiled baseline:   0.96x speedup
Fused Add RMSNorm Forward Kernel vs PyTorch eager baseline:   7.99x speedup
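As a sanity check on the numbers above, the reported GB/s figures are consistent with counting three [32768, 32768] bf16 tensors of traffic per call (e.g. reading hidden_states and residual and writing the output, without double-counting an in-place residual update). This accounting is an assumption about the harness, but the arithmetic reproduces the logged throughputs:

```python
# Back-of-the-envelope throughput check for the forward numbers above.
# Assumption: the harness counts 3 x M x N bf16 elements of memory traffic.
M = N = 32768
bytes_per_elem = 2  # bf16
total_bytes = 3 * M * N * bytes_per_elem

def gbps(ms):
    # Convert an execution time in milliseconds to GB/s.
    return total_bytes / (ms * 1e-3) / 1e9

print(f"fused : {gbps(3.5321):.2f} GB/s")   # ~1824 GB/s, matches the log
print(f"eager : {gbps(28.2111):.2f} GB/s")  # ~228 GB/s, matches the log
```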

Backward Bench:

python3 benchmarks/benchmark_fused_add_rmsnorm.py --backward
=== Fused Add + RMSNorm Backward Benchmark ===
Tensor dimensions: [32768, 32768]
Input / residual dtype: torch.bfloat16

Executing fused backward kernel...
Fused backward kernel execution time: 4.2096 ms
Fused backward kernel mem throughput: 1530.52 GB/s

Executing PyTorch eager backward reference...
PyTorch eager backward execution time: 76.7777 ms
PyTorch eager backward mem throughput: 83.92 GB/s

Executing PyTorch compiled backward reference...
PyTorch compiled backward execution time: 10.8680 ms
PyTorch compiled backward mem throughput: 592.83 GB/s

Comparisons:
Fused Add RMSNorm Backward Kernel vs PyTorch eager backward:  18.24x speedup
Fused Add RMSNorm Backward Kernel vs PyTorch compiled backward:   2.58x speedup

@tridao
Member

tridao commented Nov 10, 2025

It's probably easier to add this as an option in benchmark_rmsnorm.py (I think there might be an option for residual there already).

@AndreSlavescu
Author

> It's probably easier to add this as an option in benchmark_rmsnorm.py (I think there might be an option for residual there already).

I ran a benchmark against the residual path in the existing RMSNorm implementation (forward benchmark listed above), and the fused one seems to perform better. Would it still be valuable to have a separate fused kernel? My idea was to have some parity with what LigerKernel has:

https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_add_rms_norm.py

@tridao
Copy link
Copy Markdown
Member

tridao commented Nov 10, 2025

How's the new one different from the existing kernel?
I think this code path is already implemented?
