Compile-friendly batched model execution. Fuse N identical models into a single Wide model for massive speedups.
Instead of running N models sequentially:

```python
outputs = [model(x) for model in models]  # N kernel launches
```

Fuse them into one:

```python
import wide_compiler

wide = wide_compiler.compile(models, sample_input)
output = wide(packed_input)  # 1 kernel launch
```

Speedups: 2-173x depending on model type, N, and compilation mode (A100, compiled).
- 24 Primitives - Added 11 new primitives: RMSNorm (21x), AdaLayerNormZero (15x), MLPEmbedder (15x), CrossAttention (18x), ConvTranspose1d/2d (5-6x), BatchNorm3d (24x), RNN (6x), PReLU (12x), Dropout (174x), AdaptiveAvgPool2d (15x)
- 5 Flux-Style Blocks - WideMLP, WideAttention, WideJointAttention, WideDoubleStreamBlock, WideSingleStreamBlock for transformer architectures
- Registry-Based TracedWideModel - Fully dynamic primitive lookup, no more hardcoded WIDE_BUILDERS dict
- RMSNorm Support - Native PyTorch RMSNorm with 21x speedup at N=32
- Comprehensive Benchmarks - All 29 components (24 primitives + 5 blocks) tested and validated
- I/O Shape Documentation - Clear input/output formats for every primitive and block
```bash
git clone https://github.com/AbstractEyes/pytorch-parallel-compiler
cd pytorch-parallel-compiler
pip install -e .
```

```python
import torch
import wide_compiler
# Define your model
class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(64, 128)
        self.fc2 = torch.nn.Linear(128, 64)

    def forward(self, x):
        return self.fc2(torch.nn.functional.relu(self.fc1(x)))
# Create N models
models = [MLP().cuda() for _ in range(100)]
sample = torch.randn(1, 64).cuda()
# Compile to Wide model
wide = wide_compiler.compile(models, sample)
# Pack inputs and run
inputs = [torch.randn(32, 64).cuda() for _ in range(100)]
packed = wide_compiler.pack(inputs)
output = wide(packed)
# Unpack outputs
outputs = wide_compiler.unpack(output, n=100)
```
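To sanity-check the fused model against the sequential baseline, a comparison along these lines should work (reusing `models`, `inputs`, and `outputs` from the snippet above; the tolerance mirrors the library's default `validate_rtol=1e-3` and is illustrative, not prescriptive):

```python
# Reference: run each model on its own input sequentially
reference = [m(x) for m, x in zip(models, inputs)]

# Wide outputs should agree with the sequential baseline within tolerance
for ref, out in zip(reference, outputs):
    assert torch.allclose(ref, out, rtol=1e-3, atol=1e-5)
```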
```python
import torch
from wide_compiler.core.blocks import WideMLP, WideAttention
# Create N separate MLP blocks
mlps = [...] # Your N MLP modules
wide_mlp = WideMLP.from_modules(mlps, strategy='fused')
# Input: N-first format [N, B, T, D]
x = torch.randn(8, 4, 128, 256).cuda()
out = wide_mlp(x) # [8, 4, 128, 256]
# Create N attention blocks
attns = [...] # Your N attention modules
wide_attn = WideAttention.from_modules(attns, strategy='fused')
# With RoPE embeddings
rope = torch.randn(128, 64).cuda() # Positional embeddings
out = wide_attn(x, rope=rope)  # [8, 4, 128, 256]
```

```python
import torch
from wide_compiler.core.primitives import WideRMSNorm, WideLinear
# Create N separate RMSNorm layers
norms = [torch.nn.RMSNorm(256).cuda() for _ in range(8)]
wide_norm = WideRMSNorm.from_modules(norms, strategy='batched')
# Input: N-first format [N, B, T, D]
x = torch.randn(8, 4, 128, 256).cuda()
out = wide_norm(x)  # [8, 4, 128, 256]
```

Note: Wide primitives use N-first format `[N, B, ...]`. For automatic packing/unpacking with channel-packed format, use `TracedWideModel` (see above).
```python
import wide_compiler
# From list of models
wide = wide_compiler.compile(models, sample_input)
# From single model (creates N copies with different weights)
wide = wide_compiler.compile(MyModel(), sample_input, n=100)
# With torch.compile enabled
wide = wide_compiler.compile(models, sample_input, compile_model=True)
# With validation
wide = wide_compiler.compile(models, sample_input, validate=True)
# With config
config = wide_compiler.WideConfig.fast()
wide = wide_compiler.compile(models, sample_input, config=config)
```

```python
wide = (wide_compiler.WideBuilder(models)
        .with_sample(sample_input)
        .validate()
        .compile(mode='reduce-overhead')
        .build())
```

```python
# Pack N inputs: List[Tensor] → Tensor [B, N*C, ...]
packed = wide_compiler.pack(inputs)
# Unpack output: Tensor [B, N*C, ...] → List[Tensor]
outputs = wide_compiler.unpack(output, n=100)
```
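As a quick shape check on the packed layout described in the comments above (this assumes `pack` concatenates the N inputs along the channel dimension exactly as `[B, N*C, ...]` implies, and that `unpack` is its inverse):

```python
import torch
import wide_compiler

inputs = [torch.randn(32, 64) for _ in range(100)]
packed = wide_compiler.pack(inputs)
assert packed.shape == (32, 100 * 64)   # [B, N*C] for 2D inputs

restored = wide_compiler.unpack(packed, n=100)
assert all(torch.allclose(a, b) for a, b in zip(inputs, restored))
```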
```python
from wide_compiler import WideConfig

# Presets
config = WideConfig.default() # Basic, no compile
config = WideConfig.fast() # Compiled, no validation
config = WideConfig.debug() # Verbose, strict
config = WideConfig.safe() # With validation
# Custom
config = WideConfig(
compile=True,
compile_mode='reduce-overhead', # 'default', 'max-autotune'
validate=True,
validate_rtol=1e-3,
debug=True,
)
```

```bash
# Benchmark specific primitive
wide_compiler benchmark rmsnorm -p quick
wide_compiler benchmark dropout -p quick
wide_compiler benchmark multiheadcrossattention -p quick
# All 24 primitives:
# linear, conv1d, conv2d, conv3d, convtranspose1d, convtranspose2d,
# batchnorm1d, batchnorm2d, batchnorm3d, layernorm, groupnorm,
# instancenorm2d, rmsnorm, ada_layer_norm_zero_single,
# embedding, mlp_embedder, attention, multiheadcrossattention,
# gru, lstm, rnn, prelu, dropout, adaptiveavgpool2d
# Benchmark blocks (5 total)
wide_compiler benchmark mlp_block -p quick
wide_compiler benchmark attention_block -p quick
wide_compiler benchmark joint_attention -p quick
wide_compiler benchmark double_stream_block -p quick
wide_compiler benchmark single_stream_block -p quick
# With presets
wide_compiler benchmark rmsnorm -p quick # Quick (fewer configs)
wide_compiler benchmark rmsnorm -p full # Full sweep (default)
wide_compiler benchmark rmsnorm -p ci # CI preset (minimal)
# With torch.compile
wide_compiler benchmark rmsnorm -p quick -c
# Other options
wide_compiler benchmark rmsnorm -t 20 # Show top 20 results
wide_compiler benchmark rmsnorm -s # Auto-save with timestamp
wide_compiler benchmark rmsnorm -o results.json   # Save to specific file
```

```bash
# Run all 29 components (24 primitives + 5 blocks)
python test_cases.py
# Primitives only
python test_cases.py --primitives
# Blocks only
python test_cases.py --blocks
# With different presets
python test_cases.py --preset full
```

```bash
# Run correctness tests
wide_compiler test
# Show FX trace for built-in models
wide_compiler trace -m mlp
wide_compiler trace -m resblock
# Show library info
wide_compiler info
```

| Layer | Wide Version | I/O Format | Strategies | Best Speedup |
|---|---|---|---|---|
| nn.Embedding | WideEmbedding | [N,B,T]→[N,B,T,D] | indexed, gather, sequential | 27.1x @ N=32 |
| MLPEmbedder | WideMLPEmbedder | [N,B,D]→[N,B,Dout] | fused, sequential | 14.8x @ N=32 |
| nn.Linear | WideLinear | [N,B,...,Din]→[N,B,...,Dout] | einsum, sequential | 8.8x @ N=32 |
| Layer | Wide Version | I/O Format | Strategies | Best Speedup |
|---|---|---|---|---|
| nn.Conv1d | WideConv1d | [N,B,C,L]→[N,B,Cout,Lout] | grouped, sequential | 12.1x @ N=32 |
| nn.Conv2d | WideConv2d | [N,B,C,H,W]→[N,B,Cout,Hout,Wout] | grouped, channels_last, sequential | 6.2x @ N=32 |
| nn.ConvTranspose2d | WideConvTranspose2d | [N,B,C,H,W]→[N,B,Cout,Hout,Wout] | grouped, channels_last, sequential | 5.7x @ N=32 |
| nn.ConvTranspose1d | WideConvTranspose1d | [N,B,C,L]→[N,B,Cout,Lout] | grouped, sequential | 5.3x @ N=32 |
| nn.Conv3d | WideConv3d | [N,B,C,D,H,W]→[N,B,Cout,Dout,Hout,Wout] | grouped, sequential | 4.4x @ N=16 |
| Layer | Wide Version | I/O Format | Strategies | Best Speedup |
|---|---|---|---|---|
| nn.BatchNorm1d | WideBatchNorm1d | [N,B,C]→[N,B,C] | wide | 36.7x @ N=32 |
| nn.BatchNorm2d | WideBatchNorm2d | [N,B,C,H,W]→[N,B,C,H,W] | wide | 35.8x @ N=32 |
| nn.BatchNorm3d | WideBatchNorm3d | [N,B,C,D,H,W]→[N,B,C,D,H,W] | wide | 23.5x @ N=32 |
| nn.InstanceNorm2d | WideInstanceNorm2d | [N,B,C,H,W]→[N,B,C,H,W] | fused, sequential | 21.3x @ N=32 |
| nn.RMSNorm | WideRMSNorm | [N,B,...,D]→[N,B,...,D] | batched, sequential | 20.8x @ N=32 |
| AdaLayerNormZeroSingle | WideAdaLayerNormZeroSingle | [N,B,D],[N,B,Demb]→[N,B,D],gate | fused, sequential | 15.2x @ N=16 |
| nn.GroupNorm | WideGroupNorm | [N,B,C,...]→[N,B,C,...] | fused, sequential | 12.9x @ N=32 |
| nn.LayerNorm | WideLayerNorm | [N,B,...,D]→[N,B,...,D] | wide | 9.8x @ N=32 |
| Layer | Wide Version | I/O Format | Strategies | Best Speedup |
|---|---|---|---|---|
| MultiheadCrossAttention | WideMultiheadCrossAttention | [N,B,Tq,D],[N,B,Tkv,D]→[N,B,Tq,D] | fused, sequential | 17.8x @ N=32 |
| nn.MultiheadAttention | WideAttention | [N,B,T,D]→[N,B,T,D] | fused, sequential | 9.6x @ N=32 |
| Layer | Wide Version | I/O Format | Strategies | Best Speedup |
|---|---|---|---|---|
| nn.RNN | WideRNN | [N,B,T,Din]→[N,B,T,H],[N,B,H] | fused, sequential | 5.6x @ N=32 |
| nn.LSTM | WideLSTM | [N,B,T,Din]→[N,B,T,H],[N,B,H],[N,B,H] | fused, sequential | 3.3x @ N=32 |
| nn.GRU | WideGRU | [N,B,T,Din]→[N,B,T,H],[N,B,H] | fused, sequential | 2.9x @ N=32 |
| Layer | Wide Version | I/O Format | Strategies | Best Speedup |
|---|---|---|---|---|
| nn.Dropout | WideDropout | [N,B,...]→[N,B,...] | independent, shared, sequential | 173.7x @ N=32 |
| nn.AdaptiveAvgPool2d | WideAdaptiveAvgPool2d | [N,B,C,Hin,Win]→[N,B,C,Hout,Wout] | batched, sequential | 15.2x @ N=32 |
| nn.PReLU | WidePReLU | [N,B,C,...]→[N,B,C,...] | wide, sequential | 12.0x @ N=32 |
| F.relu, F.gelu, etc. | FunctionalOp | agnostic | — | — |
| +, -, *, /, @ | BinaryOp | agnostic | — | — |
All primitives operate on N-first format [N, B, ...] internally for optimal performance.
Benchmarks: A100 GPU, torch.compile (default mode), quick preset. See test_cases.py for full results.
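For readers who want to see the two layouts side by side, here is a minimal, library-independent sketch of converting between the channel-packed format `[B, N*C, ...]` and the N-first format `[N, B, C, ...]`; it assumes the packed tensor concatenates models along the channel axis, which is how the formats above are described:

```python
import torch

def to_n_first(packed: torch.Tensor, n: int) -> torch.Tensor:
    """[B, N*C, H, W] -> [N, B, C, H, W]."""
    b, nc, h, w = packed.shape
    return packed.view(b, n, nc // n, h, w).permute(1, 0, 2, 3, 4).contiguous()

def to_channel_packed(n_first: torch.Tensor) -> torch.Tensor:
    """[N, B, C, H, W] -> [B, N*C, H, W]."""
    n, b, c, h, w = n_first.shape
    return n_first.permute(1, 0, 2, 3, 4).reshape(b, n * c, h, w)

x = torch.randn(4, 8 * 3, 16, 16)   # B=4, N=8, C=3
assert torch.equal(to_channel_packed(to_n_first(x, n=8)), x)
```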
Higher-level composite blocks for transformer architectures.
| Block | I/O Format | Components | Best Speedup |
|---|---|---|---|
| WideAttention | [N,B,T,D]→[N,B,T,D] | QKV proj + SDPA + out proj | 10.7x @ N=16 |
| WideDoubleStreamBlock | [N,B,Ttxt,D],[N,B,Timg,D]→[N,B,Ttxt,D],[N,B,Timg,D] | JointAttn + 2x MLP + norms | 6.8x @ N=8 |
| WideJointAttention | [N,B,Ttxt,D],[N,B,Timg,D]→[N,B,Ttxt,D],[N,B,Timg,D] | Dual-stream attention | 5.5x @ N=32 |
| WideMLP | [N,B,T,D]→[N,B,T,D] | 2x Linear + activation | 3.9x @ N=16 |
| WideSingleStreamBlock | [N,B,T,D],[N,B,Demb]→[N,B,T,D] | AdaLN + Attn + MLP | 3.5x @ N=8 |
```python
import torch
from wide_compiler.core.blocks import WideDoubleStreamBlock
# Create N double-stream blocks
blocks = [...] # Your N DoubleStreamBlock modules
wide_block = WideDoubleStreamBlock.from_modules(blocks, strategy='fused')
# Input: Two streams (text and image)
txt = torch.randn(8, 4, 64, 256).cuda() # [N, B, Ttxt, D]
img = torch.randn(8, 4, 256, 256).cuda() # [N, B, Timg, D]
rope = torch.randn(320, 128).cuda() # [Ttxt+Timg, head_dim]
# Forward through block
txt_out, img_out = wide_block(txt, img, rope=rope)
```

All benchmarks: A100 GPU, torch.compile (default mode), quick preset (N=[4,8,16,32]).
| Primitive | Best Speedup | Strategy |
|---|---|---|
| Dropout | 173.7x | shared |
| BatchNorm1d | 36.7x | wide |
| BatchNorm2d | 35.8x | wide |
| Embedding | 27.1x | indexed/gather |
| BatchNorm3d | 23.5x | wide |
| InstanceNorm2d | 21.3x | fused |
| RMSNorm | 20.8x | batched |
| MultiheadCrossAttention | 17.8x | fused |
| AdaptiveAvgPool2d | 15.2x | batched |
| AdaLayerNormZeroSingle | 15.2x | fused |
| Primitive | Best Speedup | Strategy |
|---|---|---|
| MLPEmbedder | 14.8x | fused |
| Linear | 8.8x | einsum |
| Primitive | Best Speedup | Strategy |
|---|---|---|
| Conv1d | 12.1x | grouped |
| PReLU | 12.0x | wide |
| Conv2d | 6.2x | grouped |
| ConvTranspose2d | 5.7x | grouped |
| ConvTranspose1d | 5.3x | grouped |
| Conv3d | 4.4x | grouped |
| Primitive | Best Speedup | Strategy |
|---|---|---|
| Attention | 9.6x | fused |
| Primitive | Best Speedup | Strategy | Notes |
|---|---|---|---|
| GroupNorm | 12.9x | fused | |
| LayerNorm | 9.8x | wide | |
| RNN | 5.6x | fused | Speedup improved with compilation |
| LSTM | 3.3x | fused | Speedup improved with compilation |
| GRU | 2.9x | fused | Speedup improved with compilation |
- Dropout achieves extreme speedups (173.7x) with shared random state
- BatchNorm layers see massive gains with compilation (23-37x)
- Embedding scales exceptionally well (27.1x)
- RMSNorm provides 20.8x speedup, outperforms LayerNorm (9.8x) by 2.1x
- CrossAttention scales to 17.8x with compilation
- RNN layers benefit from compilation (GRU: 2.9x, LSTM: 3.3x, RNN: 5.6x)
- Conv1d reaches 12.1x with compilation
- Compilation is critical - most primitives see 2-5x additional speedup
| Block | N=4 | N=8 | N=16 | N=32 | Components |
|---|---|---|---|---|---|
| AttentionBlock | 4.0x | 7.6x | 10.7x | 8.4x | QKV proj + SDPA + norm |
| DoubleStreamBlock | 4.0x | 6.8x | — | — | JointAttn + 2xMLP + norms |
| JointAttention | 3.0x | 3.9x | 5.4x | 5.5x | Dual-stream QKV + concat attn |
| MLPBlock | 3.1x | 3.4x | 3.9x | 2.5x | 2x Linear + activation |
| SingleStreamBlock | 3.1x | 3.5x | — | — | AdaLN + Attn + MLP |
- FX Tracing - `torch.fx.symbolic_trace` captures the computation graph
- Registry Lookup - Each layer dynamically resolved via global registry (24 primitives registered)
- Wide Primitives - Each layer replaced with N-first fused equivalent:
  - `Linear` → Batched einsum `[N, B, I] @ [N, O, I]` (transposed weight)
  - `Conv2d` → Grouped convolution on `[N*B, C, H, W]` with `groups=N`
  - `Attention` → Reshape N→batch, single Flash Attention call
  - `RMSNorm` → Batched normalization `[N, B, D]`
  - All primitives operate on N-first `[N, B, ...]`
- Strategy Selection - Each primitive auto-selects optimal strategy based on N
- Compile-Friendly - 0 graph breaks, all native PyTorch ops
- Optimal Data Flow - Only 2 reshapes per forward pass:

  ```
  Input  [B, N*C, ...] → Unpack → [N, B, C, ...]
                ↓
  All stages operate on N-first (zero intermediate conversions)
                ↓
  Output [N, B, C, ...] → Pack → [B, N*C, ...]
  ```
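To make the `Linear` rewrite above concrete, here is a standalone sketch of the batched-einsum idea (not the library's actual `WideLinear` code): stack the N weight matrices into `[N, O, I]` and contract over the input dimension.

```python
import torch

n, b, d_in, d_out = 8, 4, 64, 128
layers = [torch.nn.Linear(d_in, d_out) for _ in range(n)]

W = torch.stack([l.weight for l in layers])   # [N, O, I]
bias = torch.stack([l.bias for l in layers])  # [N, O]

x = torch.randn(n, b, d_in)                                # N-first input [N, B, I]
y = torch.einsum('nbi,noi->nbo', x, W) + bias[:, None, :]  # [N, B, O]

# Equivalent to running the N layers separately
ref = torch.stack([layer(xi) for layer, xi in zip(layers, x)])
assert torch.allclose(y, ref, rtol=1e-4, atol=1e-5)
```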
Why N-first format is optimal:

```python
# All primitives operate on [N, B, ...] format
# Data flows through without any intermediate packing
# Only reshape at boundaries (input/output)
# Result: Maximum kernel fusion, minimum overhead
```

```
wide_compiler/
├── __init__.py
├── __main__.py
├── api.py                        # compile(), WideBuilder, pack(), unpack()
├── cli.py                        # CLI commands
├── test_cases.py                 # Reusable test suite (NEW)
└── core/
    ├── config.py                 # WideConfig
    ├── registry.py               # Dynamic primitive registration (24 primitives)
    ├── traced_wide.py            # FX tracing + TracedWideModel
    ├── ensemble_util.py          # pack_inputs(), unpack_outputs()
    ├── benchmark/                # Benchmark system
    │   ├── benchmark_api.py      # High-level API
    │   ├── benchmark_runner.py   # Execution engine
    │   ├── benchmark_schema.py   # BenchmarkJob, results
    │   └── benchmark_registry.py # Auto-discovery
    ├── blocks/                   # Flux-style composite blocks (NEW)
    │   ├── wide_mlp.py
    │   ├── wide_attention.py
    │   ├── wide_joint_attention.py
    │   ├── wide_double_stream_block.py
    │   └── wide_single_stream_block.py
    └── primitives/               # 24 wide primitives
        ├── wide_linear.py
        ├── wide_conv1d.py, wide_conv2d.py, wide_conv3d.py
        ├── wide_convtranspose1d.py, wide_convtranspose2d.py (NEW)
        ├── wide_batchnorm_1d.py, wide_batchnorm_2d.py, wide_batchnorm_3d.py (NEW)
        ├── wide_layernorm.py
        ├── wide_groupnorm.py
        ├── wide_instancenorm.py
        ├── wide_rmsnorm.py (NEW)
        ├── wide_ada_layer_norm_zero_single.py (NEW)
        ├── wide_embedding.py
        ├── wide_mlp_embedder.py (NEW)
        ├── wide_attention.py
        ├── wide_cross_attention.py (NEW)
        ├── wide_gru.py, wide_lstm.py, wide_rnn.py (NEW)
        ├── wide_prelu.py (NEW)
        ├── wide_dropout.py (NEW)
        └── wide_adaptive_avgpool2d.py (NEW)
```
- Identical architecture required - All N models must have same structure
- Static shapes - FX tracing requires fixed tensor shapes
- No dynamic control flow - `if`/`for` based on tensor values won't trace (see the sketch after this list)
- Single layer only - `num_layers=1` currently required
- Unidirectional only - `bidirectional=False` required
- batch_first only - `batch_first=True` required
- Compilation recommended - RNN layers see 2-6x speedup with torch.compile
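A minimal illustration of the dynamic-control-flow limitation, using plain `torch.fx` and no wide_compiler code: branching on a tensor value raises a trace error.

```python
import torch

class Gated(torch.nn.Module):
    def forward(self, x):
        # Branching on a tensor value cannot be captured by symbolic tracing
        if x.sum() > 0:
            return x * 2
        return x

try:
    torch.fx.symbolic_trace(Gated())
except torch.fx.proxy.TraceError as err:
    print("Not traceable:", err)
```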
- Ensemble models - Run N ensemble members in parallel
- Hyperparameter search - Evaluate N configurations simultaneously
- Population-based training - Evolve N agents together
- Monte Carlo dropout - N stochastic forward passes
- Transformer ensembles - N attention heads across models
- Flux-style diffusion - Parallel text/image streams with WideDoubleStreamBlock
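As a sketch of the ensemble use case, using only the `compile`/`pack`/`unpack` calls documented above (whether a given architecture traces cleanly depends on the limitations listed earlier):

```python
import torch
import wide_compiler

# 16 ensemble members with identical architecture but independent weights
members = [torch.nn.Sequential(torch.nn.Linear(64, 128),
                               torch.nn.ReLU(),
                               torch.nn.Linear(128, 10)).cuda()
           for _ in range(16)]

wide = wide_compiler.compile(members, torch.randn(1, 64).cuda())

x = torch.randn(32, 64).cuda()
packed = wide_compiler.pack([x] * 16)              # same batch to every member
preds = wide_compiler.unpack(wide(packed), n=16)   # 16 per-member predictions
mean_pred = torch.stack(preds).mean(dim=0)         # ensemble average, [32, 10]
```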
Apache License 2.0
AbstractPhil - HuggingFace | GitHub