```python
import torch
from wide_compiler.core.blocks import WideMLP, WideAttention

# Create N separate MLP blocks
mlps = [...]  # Your N MLP modules
wide_mlp = WideMLP.from_modules(mlps, strategy='fused')

# Input: N-first format [N, B, T, D]
x = torch.randn(8, 4, 128, 256).cuda()
out = wide_mlp(x)  # [8, 4, 128, 256]

# Create N attention blocks
attns = [...]  # Your N attention modules
wide_attn = WideAttention.from_modules(attns, strategy='fused')

# With RoPE embeddings
rope = torch.randn(128, 64).cuda()  # Positional embeddings
out = wide_attn(x, rope=rope)  # [8, 4, 128, 256]
```
### Using Wide Primitives Directly (N-first format)
```python
import torch
from wide_compiler.core.primitives import WideRMSNorm, WideLinear

# Create N separate RMSNorm layers
norms = [torch.nn.RMSNorm(256).cuda() for _ in range(8)]
wide_norm = WideRMSNorm.from_modules(norms, strategy='batched')

# Input: N-first format [N, B, T, D]
x = torch.randn(8, 4, 128, 256).cuda()
out = wide_norm(x)  # [8, 4, 128, 256]
```
**Note:** Wide primitives use N-first format `[N, B, ...]`. For automatic packing/unpacking with channel-packed format, use `TracedWideModel` (see above).
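If you manage the N-first layout by hand, packing and unpacking reduce to a stack/unbind along a new leading dimension. A minimal sketch (the helper names `pack_n_first`/`unpack_n_first` are illustrative, not part of `wide_compiler`):

```python
import torch

def pack_n_first(tensors):
    # Stack N per-model batches along a new leading N dim:
    # N x [B, T, D] -> [N, B, T, D]
    return torch.stack(tensors, dim=0)

def unpack_n_first(x):
    # Split [N, B, T, D] back into a list of N tensors [B, T, D]
    return list(torch.unbind(x, dim=0))

inputs = [torch.randn(4, 128, 256) for _ in range(8)]
x = pack_n_first(inputs)      # [8, 4, 128, 256]
outputs = unpack_n_first(x)   # 8 tensors of shape [4, 128, 256]
```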
## API

### Main Entry Point
```python
import wide_compiler

# From a list of models
wide = wide_compiler.compile(models, sample_input)

# From a single model (creates N copies with different weights)
wide = wide_compiler.compile(MyModel(), sample_input, n=100)

# With torch.compile enabled
wide = wide_compiler.compile(models, sample_input, compile_model=True)

# With validation
wide = wide_compiler.compile(models, sample_input, validate=True)

# With a config
config = wide_compiler.WideConfig.fast()
wide = wide_compiler.compile(models, sample_input, config=config)
```
```python
from wide_compiler import WideConfig

# Presets
config = WideConfig.default()  # Basic, no compile
config = WideConfig.fast()     # Compiled, no validation
config = WideConfig.debug()    # Verbose, strict
config = WideConfig.safe()     # With validation

# Custom
config = WideConfig(
    compile=True,
    compile_mode='reduce-overhead',  # 'default', 'max-autotune'
    validate=True,
    validate_rtol=1e-3,
    debug=True,
)
```
## CLI

### Benchmark Primitives (v0.7.0 - 24 primitives)
```bash
# Benchmark a specific primitive
wide_compiler benchmark rmsnorm -p quick
wide_compiler benchmark dropout -p quick
wide_compiler benchmark multiheadcrossattention -p quick

# All 24 primitives:
# linear, conv1d, conv2d, conv3d, convtranspose1d, convtranspose2d,
# batchnorm1d, batchnorm2d, batchnorm3d, layernorm, groupnorm,
# instancenorm2d, rmsnorm, ada_layer_norm_zero_single,
# embedding, mlp_embedder, attention, multiheadcrossattention,
# gru, lstm, rnn, prelu, dropout, adaptiveavgpool2d

# Benchmark blocks (5 total)
wide_compiler benchmark mlp_block -p quick
wide_compiler benchmark attention_block -p quick
wide_compiler benchmark joint_attention -p quick
wide_compiler benchmark double_stream_block -p quick
wide_compiler benchmark single_stream_block -p quick

# With presets
wide_compiler benchmark rmsnorm -p quick   # Quick (fewer configs)
wide_compiler benchmark rmsnorm -p full    # Full sweep (default)
wide_compiler benchmark rmsnorm -p ci      # CI preset (minimal)

# With torch.compile
wide_compiler benchmark rmsnorm -p quick -c

# Other options
wide_compiler benchmark rmsnorm -t 20            # Show top 20 results
wide_compiler benchmark rmsnorm -s               # Auto-save with timestamp
wide_compiler benchmark rmsnorm -o results.json  # Save to a specific file
```
### Run Test Suite
```bash
# Run all 29 components (24 primitives + 5 blocks)
python test_cases.py

# Primitives only
python test_cases.py --primitives

# Blocks only
python test_cases.py --blocks

# With a different preset
python test_cases.py --preset full
```
### Other Commands
```bash
# Run correctness tests
wide_compiler test

# Show FX trace for built-in models
wide_compiler trace -m mlp
wide_compiler trace -m resblock

# Show library info
wide_compiler info
```
## Supported Layers (v0.7.0 - 24 primitives)
### Linear & Embedding

| Layer | Wide Version | I/O Format | Strategies | Best Speedup |
|---|---|---|---|---|
| `nn.Embedding` | `WideEmbedding` | `[N,B,T]→[N,B,T,D]` | indexed, gather, sequential | 27.1x @ N=32 |
| `MLPEmbedder` | `WideMLPEmbedder` | `[N,B,D]→[N,B,Dout]` | fused, sequential | 14.8x @ N=32 |
| `nn.Linear` | `WideLinear` | `[N,B,...,Din]→[N,B,...,Dout]` | einsum, sequential | 8.8x @ N=32 |
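The einsum strategy for batched linear layers can be sketched in plain PyTorch (an illustration of the batching idea, not the library's actual implementation): stack the N weight matrices along a new leading dimension and contract them all in a single `einsum`.

```python
import torch

N, B, Din, Dout = 8, 4, 32, 64
layers = [torch.nn.Linear(Din, Dout) for _ in range(N)]

# Stack per-model parameters: W -> [N, Dout, Din], b -> [N, Dout]
W = torch.stack([l.weight for l in layers])
bias = torch.stack([l.bias for l in layers])

x = torch.randn(N, B, Din)  # N-first input

# One einsum replaces N separate matmuls; bias broadcasts over B
wide_out = torch.einsum('nbi,noi->nbo', x, W) + bias[:, None, :]

# Reference: run each layer on its own slice
seq_out = torch.stack([layers[i](x[i]) for i in range(N)])
print(torch.allclose(wide_out, seq_out, atol=1e-5))  # True
```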
### Convolution Layers

| Layer | Wide Version | I/O Format | Strategies | Best Speedup |
|---|---|---|---|---|
| `nn.Conv1d` | `WideConv1d` | `[N,B,C,L]→[N,B,Cout,Lout]` | grouped, sequential | 12.1x @ N=32 |
| `nn.Conv2d` | `WideConv2d` | `[N,B,C,H,W]→[N,B,Cout,Hout,Wout]` | grouped, channels_last, sequential | 6.2x @ N=32 |
| `nn.ConvTranspose2d` | `WideConvTranspose2d` | `[N,B,C,H,W]→[N,B,Cout,Hout,Wout]` | grouped, channels_last, sequential | 5.7x @ N=32 |
| `nn.ConvTranspose1d` | `WideConvTranspose1d` | `[N,B,C,L]→[N,B,Cout,Lout]` | grouped, sequential | 5.3x @ N=32 |
| `nn.Conv3d` | `WideConv3d` | `[N,B,C,D,H,W]→[N,B,Cout,Dout,Hout,Wout]` | grouped, sequential | 4.4x @ N=16 |
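The grouped strategy amounts to folding the N models into the `groups` argument of a single convolution. A minimal sketch, assuming all N `Conv2d` layers share the same hyperparameters (this is the underlying trick, not `wide_compiler`'s code):

```python
import torch
import torch.nn.functional as F

N, B, C, Cout, H, W = 4, 2, 3, 5, 16, 16
convs = [torch.nn.Conv2d(C, Cout, kernel_size=3, padding=1) for _ in range(N)]

# Concatenate weights/biases so that group i holds model i's filters
weight = torch.cat([c.weight for c in convs])  # [N*Cout, C, 3, 3]
bias = torch.cat([c.bias for c in convs])      # [N*Cout]

x = torch.randn(N, B, C, H, W)

# Fold N into the channel dim: [N,B,C,H,W] -> [B, N*C, H, W]
xg = x.permute(1, 0, 2, 3, 4).reshape(B, N * C, H, W)
out = F.conv2d(xg, weight, bias, padding=1, groups=N)
out = out.reshape(B, N, Cout, H, W).permute(1, 0, 2, 3, 4)  # back to N-first

# Reference: apply each conv to its own slice
ref = torch.stack([convs[i](x[i]) for i in range(N)])
print(torch.allclose(out, ref, atol=1e-5))  # True
```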
### Normalization Layers

| Layer | Wide Version | I/O Format | Strategies | Best Speedup |
|---|---|---|---|---|
| `nn.BatchNorm1d` | `WideBatchNorm1d` | `[N,B,C]→[N,B,C]` | wide | 36.7x @ N=32 |
| `nn.BatchNorm2d` | `WideBatchNorm2d` | `[N,B,C,H,W]→[N,B,C,H,W]` | wide | 35.8x @ N=32 |
| `nn.BatchNorm3d` | `WideBatchNorm3d` | `[N,B,C,D,H,W]→[N,B,C,D,H,W]` | wide | 23.5x @ N=32 |
| `nn.InstanceNorm2d` | `WideInstanceNorm2d` | `[N,B,C,H,W]→[N,B,C,H,W]` | fused, sequential | 21.3x @ N=32 |
| `nn.RMSNorm` | `WideRMSNorm` | `[N,B,...,D]→[N,B,...,D]` | batched, sequential | 20.8x @ N=32 |
| `AdaLayerNormZeroSingle` | `WideAdaLayerNormZeroSingle` | `[N,B,D],[N,B,Demb]→[N,B,D],gate` | fused, sequential | 15.2x @ N=16 |
| `nn.GroupNorm` | `WideGroupNorm` | `[N,B,C,...]→[N,B,C,...]` | fused, sequential | 12.9x @ N=32 |
| `nn.LayerNorm` | `WideLayerNorm` | `[N,B,...,D]→[N,B,...,D]` | wide | 9.8x @ N=32 |
### Attention Layers

| Layer | Wide Version | I/O Format | Strategies | Best Speedup |
|---|---|---|---|---|
| `MultiheadCrossAttention` | `WideMultiheadCrossAttention` | `[N,B,Tq,D],[N,B,Tkv,D]→[N,B,Tq,D]` | fused, sequential | 17.8x @ N=32 |
| `nn.MultiheadAttention` | `WideAttention` | `[N,B,T,D]→[N,B,T,D]` | fused, sequential | 9.6x @ N=32 |
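The core of the fused attention strategy is folding the model dimension N into the batch dimension of a single `scaled_dot_product_attention` call. A sketch with already-projected Q/K/V tensors (illustrative only, not the library's code):

```python
import torch
import torch.nn.functional as F

N, B, H, T, Dh = 4, 2, 8, 16, 32
q = torch.randn(N, B, H, T, Dh)
k = torch.randn(N, B, H, T, Dh)
v = torch.randn(N, B, H, T, Dh)

# Fold N into the batch dim so one SDPA kernel serves all N models
out = F.scaled_dot_product_attention(
    q.reshape(N * B, H, T, Dh),
    k.reshape(N * B, H, T, Dh),
    v.reshape(N * B, H, T, Dh),
).reshape(N, B, H, T, Dh)

# Reference: per-model attention
ref = torch.stack(
    [F.scaled_dot_product_attention(q[i], k[i], v[i]) for i in range(N)]
)
print(torch.allclose(out, ref, atol=1e-5))  # True
```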
### RNN Layers

| Layer | Wide Version | I/O Format | Strategies | Best Speedup |
|---|---|---|---|---|
| `nn.RNN` | `WideRNN` | `[N,B,T,Din]→[N,B,T,H],[N,B,H]` | fused, sequential | 5.6x @ N=32 |
| `nn.LSTM` | `WideLSTM` | `[N,B,T,Din]→[N,B,T,H],[N,B,H],[N,B,H]` | fused, sequential | 3.3x @ N=32 |
| `nn.GRU` | `WideGRU` | `[N,B,T,Din]→[N,B,T,H],[N,B,H]` | fused, sequential | 2.9x @ N=32 |
### Other Layers

| Layer | Wide Version | I/O Format | Strategies | Best Speedup |
|---|---|---|---|---|
| `nn.Dropout` | `WideDropout` | `[N,B,...]→[N,B,...]` | independent, shared, sequential | 173.7x @ N=32 |
| `nn.AdaptiveAvgPool2d` | `WideAdaptiveAvgPool2d` | `[N,B,C,Hin,Win]→[N,B,C,Hout,Wout]` | batched, sequential | 15.2x @ N=32 |
| `nn.PReLU` | `WidePReLU` | `[N,B,C,...]→[N,B,C,...]` | wide, sequential | 12.0x @ N=32 |
| `F.relu`, `F.gelu`, etc. | `FunctionalOp` | agnostic | — | — |
| `+`, `-`, `*`, `/`, `@` | `BinaryOp` | agnostic | — | — |
*All primitives operate on N-first format `[N, B, ...]` internally for optimal performance.*

*Benchmarks: A100 GPU, `torch.compile` (default mode), quick preset. See `test_cases.py` for full results.*
## Flux-Style Blocks (v0.7.0 - 5 blocks)
Higher-level composite blocks for transformer architectures.
```python
# All primitives operate on [N, B, ...] format
# Data flows through without any intermediate packing
# Only reshape at boundaries (input/output)
# Result: Maximum kernel fusion, minimum overhead
```
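That flow can be illustrated in plain PyTorch (a sketch of the layout discipline, not the library internals): pack once at the input boundary, keep every intermediate op in `[N, B, ...]`, and unpack once at the output.

```python
import torch

N, B, D = 4, 2, 16
W1 = torch.randn(N, D, D)  # stand-ins for two stacked wide-linear weights
W2 = torch.randn(N, D, D)

inputs = [torch.randn(B, D) for _ in range(N)]

x = torch.stack(inputs)                  # pack once: [N, B, D]
h = torch.einsum('nbi,noi->nbo', x, W1)  # stays N-first
h = torch.relu(h)                        # elementwise ops are layout-agnostic
y = torch.einsum('nbi,noi->nbo', h, W2)  # stays N-first
outputs = torch.unbind(y, dim=0)         # unpack once at the output boundary
```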