The first NKI Conv3d kernel for AWS Trainium.
Video generation models (Wan2.1/2.2, CogVideoX, HunyuanVideo) use 3D VAEs built on Conv3d / CausalConv3d. AWS Trainium has no native Conv3d support — the NKI ecosystem only has Conv1d. This gap blocks all video generation models from running on Trainium.
This repo fills that gap.
| NKI Kernel | Exists Before This Repo |
|---|---|
| Conv1d | ✅ (nki-library) |
| Conv2d | ❌ |
| Conv3d | ❌ → ✅ this repo |
Conv3d is decomposed into temporal slices of Conv2d, each computed via im2col + GEMM:
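output[:, :, d, :, :] = Σ_{kd} Conv2d(input[:, :, d*s+kd, :, :], weight[:, :, kd, :, :])
Here s is the temporal stride and kd ranges over the temporal taps of the kernel (shown for dilation 1).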
This decomposition is exact (not an approximation). The host-side wrapper builds im2col matrices and pads all dimensions to multiples of 128, then a tiled NKI matmul kernel runs the GEMM via nisa.nc_matmul.
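As a rough illustration of the decomposition above, the following NumPy sketch computes each output frame as a sum of Conv2d slices, with every Conv2d done as im2col plus one GEMM whose operands are zero-padded to multiples of 128. The function names (pad_to_tile, im2col_2d, conv2d_gemm, conv3d_by_slices) are hypothetical and are not taken from conv3d.py. Plain `@` stands in for the NKI matmul kernel, and the input is assumed to be already padded, with dilation 1 and groups 1.

```python
import numpy as np

TILE = 128  # GEMM operands are zero-padded to multiples of this, as described above


def pad_to_tile(mat, tile=TILE):
    """Zero-pad a 2D matrix so both dimensions are multiples of `tile`."""
    rows, cols = mat.shape
    return np.pad(mat, ((0, -rows % tile), (0, -cols % tile)))


def im2col_2d(x, kh, kw, stride):
    """x: (N, C, H, W) -> patch matrix of shape (N * OH * OW, C * kh * kw)."""
    n, c, h, w = x.shape
    sh, sw = stride
    oh, ow = (h - kh) // sh + 1, (w - kw) // sw + 1
    cols = np.empty((n, oh, ow, c, kh, kw), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            # All output positions that see kernel tap (i, j), moved to (N, OH, OW, C).
            window = x[:, :, i:i + sh * oh:sh, j:j + sw * ow:sw]
            cols[:, :, :, :, i, j] = window.transpose(0, 2, 3, 1)
    return cols.reshape(n * oh * ow, c * kh * kw), (oh, ow)


def conv2d_gemm(x, w, stride):
    """Conv2d without padding as im2col + one padded GEMM. w: (Cout, Cin, kh, kw)."""
    cout, _, kh, kw = w.shape
    cols, (oh, ow) = im2col_2d(x, kh, kw, stride)
    a = pad_to_tile(cols)                    # (M_pad, K_pad)
    b = pad_to_tile(w.reshape(cout, -1).T)   # (K_pad, N_pad)
    out = (a @ b)[:cols.shape[0], :cout]     # slice away the zero padding
    return out.reshape(x.shape[0], oh, ow, cout).transpose(0, 3, 1, 2)


def conv3d_by_slices(x, w, stride=(1, 1, 1)):
    """x: (N, Cin, D, H, W), w: (Cout, Cin, kd, kh, kw); input assumed pre-padded."""
    sd, sh, sw = stride
    kd = w.shape[2]
    od = (x.shape[2] - kd) // sd + 1
    frames = []
    for d in range(od):
        # output[:, :, d] = sum over kd of Conv2d(input[:, :, d*s+kd], weight[:, :, kd])
        frames.append(sum(conv2d_gemm(x[:, :, d * sd + k], w[:, :, k], (sh, sw))
                          for k in range(kd)))
    return np.stack(frames, axis=2)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((1, 2, 5, 6, 6)).astype(np.float32)
    w = rng.standard_normal((3, 2, 3, 3, 3)).astype(np.float32)
    print(conv3d_by_slices(x, w, stride=(1, 2, 2)).shape)
```

Because the padding is all zeros, the extra rows and columns contribute nothing to the products that are kept, which is why slicing the GEMM result recovers the unpadded answer.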
| File | Description |
|---|---|
| conv3d.py | NKI kernel (the main deliverable) |
| conv3d_ref.py | NumPy reference (im2col + matmul) for testing |
| test_conv3d.py | 138+ test cases across 3 layers |
| Layer | Cases | Source |
|---|---|---|
| PyTorch standard | 12 | Adapted from torch/testing/_internal/common_nn.py |
| Wan2.1/2.2 VAE configs | 12 | Actual CausalConv3d shapes from wan/modules/vae.py |
| CogVideoX-5b VAE configs | 15 | From THUDM/CogVideoX-5b vae/config.json |
| HunyuanVideo VAE configs | 18 | From tencent/HunyuanVideo vae/config.json |
| BFloat16 precision | 12 | bf16-quantized inputs vs PyTorch bf16 |
| Dilation | 8 | Uniform, spatial-only, temporal-only, asymmetric |
| Grouped / depthwise | 15 | groups=2/4, depthwise, with stride/padding/bias |
| Edge cases | 7+ | Single channel, D=1, mixed strides, causal padding |
All tests compare against torch.nn.functional.conv3d as ground truth.
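The BFloat16 row depends on inputs that are exactly representable in bf16. One way to produce such inputs with NumPy alone, offered as a sketch and not as the method used in test_conv3d.py, is to round the float32 bit pattern to its top 16 bits with round-to-nearest-even. The helper name quantize_bf16 is hypothetical, and NaN inputs are not handled.

```python
import numpy as np


def quantize_bf16(x):
    """Round float32 values to the nearest bfloat16 (ties to even), kept as float32.

    Assumption: inputs are finite; NaN payloads are not treated specially.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    # Add 0x7FFF plus the lowest kept bit, then truncate: round-to-nearest-even.
    rounding_bias = ((bits >> 16) & 1) + np.uint32(0x7FFF)
    rounded = (bits + rounding_bias) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)


if __name__ == "__main__":
    sample = np.array([0.1, -3.7, 100.25], dtype=np.float32)
    print(quantize_bf16(sample))
```

Feeding the quantized arrays to both the kernel under test and the reference keeps input rounding out of the comparison, so any remaining difference comes from accumulation.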
pip install numpy pytest torch
pytest test_conv3d.py -k "Ref" -v

neuronxcc only runs on x86_64 Linux. Use Docker:
docker build --platform linux/amd64 -t nki-conv3d .
docker run --platform linux/amd64 nki-conv3d

This uses nki.simulate_kernel, so no Trainium hardware is required.
pip install neuronx-cc==2.* numpy torch pytest \
--extra-index-url=https://pip.repos.neuron.amazonaws.com
pytest test_conv3d.py -k "NKI" -v

from conv3d import conv3d
# Calls NKI tiled_matmul_kernel internally (CPU simulation, no hardware needed)
result = conv3d(input_np, weight_np, bias_np,
                stride=(1, 1, 1), padding=(1, 1, 1))

Wan's CausalConv3d applies asymmetric temporal padding (2*pad, 0) before calling standard Conv3d. This kernel handles the Conv3d part; causal padding is done at the Python wrapper level:
import numpy as np
from conv3d_ref import conv3d_ref
# Simulate CausalConv3d(3,3,3) with padding=(1,1,1)
input_causal = np.pad(input, ((0,0), (0,0), (2,0), (1,1), (1,1)), mode="constant")
output = conv3d_ref(input_causal, weight, stride=(1,1,1), padding=(0,0,0))

- NumPy reference implementation with im2col + matmul
- NKI tiled matmul kernel with bulk nl.arange load/store
- Comprehensive test suite (347 cases, all ref + NKI tests pass)
- Wan2.1/2.2 VAE CausalConv3d compatibility tests (all channel sizes up to 384)
- Vectorized im2col (conv3d_fused); true on-device im2col is blocked by the NKI API (no gather DMA)
- Performance benchmarks on trn1/trn2
- bfloat16 precision tests (12 cases)
- Dilation support (dilation > 1) with 8 test cases
- Grouped / depthwise convolution (groups > 1) with 15 test cases
- CogVideoX-5b VAE configs (15 cases) and HunyuanVideo VAE configs (18 cases)
- Backward pass NumPy reference (grad_input, grad_weight, grad_bias, 38 tests)
- Backward pass NKI kernel (conv3d_backward, 38 tests via Docker)
- PR to aws-neuron/nki-library

- aws-neuron/nki-library — Official NKI kernels (Conv1d, Flash Attention, RoPE, RMSNorm)
- aws-neuron/nki-samples — NKI tutorials and examples
- Wan-Video/Wan2.1 — Video generation model whose 3D VAE needs this kernel
- neuronx-distributed-inference #57 — LTX-2 video model on Trainium (DiT only, no Conv3D)
Apache 2.0