LianyuHuang/nki-conv3d

nki-conv3d

The first NKI Conv3d kernel for AWS Trainium.

Why

Video generation models (Wan2.1/2.2, CogVideoX, HunyuanVideo) use 3D VAEs built on Conv3d / CausalConv3d. AWS Trainium has no native Conv3d support; the NKI ecosystem only has Conv1d. This gap blocks video generation models that depend on these VAEs from running on Trainium.

This repo fills that gap.

NKI Kernel | Exists Before This Repo
Conv1d | ✅ (nki-library)
Conv2d |
Conv3d | ❌ → ✅ this repo

Algorithm

Conv3d is decomposed into temporal slices of Conv2d, each computed via im2col + GEMM:

output[:, :, d, :, :] = Σ_{kd} Conv2d(input[:, :, d*s+kd, :, :], weight[:, :, kd, :, :])

This decomposition is exact (not an approximation). The host-side wrapper builds im2col matrices and pads all dimensions to multiples of 128, then a tiled NKI matmul kernel runs the GEMM via nisa.nc_matmul.
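
The decomposition above can be checked numerically. The following is a minimal NumPy sketch, not the repo's kernel: the helper names (pad_to_tile, conv2d_im2col, conv3d_by_slices, conv3d_direct) are hypothetical, padding and dilation are omitted, and the GEMM runs on the host with `@` instead of nisa.nc_matmul. It pads both GEMM operands with zeros to multiples of 128, which leaves the product unchanged once the result is cropped.

```python
import numpy as np

TILE = 128  # tile size the GEMM operands are padded to


def pad_to_tile(m, tile=TILE):
    """Zero-pad a 2D matrix so both dimensions are multiples of `tile`."""
    return np.pad(m, ((0, -m.shape[0] % tile), (0, -m.shape[1] % tile)))


def conv2d_im2col(x, w, stride=1):
    """Conv2d without padding as im2col + GEMM. x: (N, C, H, W), w: (O, C, KH, KW)."""
    n, c, h, wd = x.shape
    o, _, kh, kw = w.shape
    oh = (h - kh) // stride + 1
    ow = (wd - kw) // stride + 1
    # One row per (channel, kh, kw) tap, one column per output pixel.
    cols = np.empty((c * kh * kw, n * oh * ow), dtype=x.dtype)
    row = 0
    for ci in range(c):
        for i in range(kh):
            for j in range(kw):
                patch = x[:, ci, i:i + stride * oh:stride, j:j + stride * ow:stride]
                cols[row] = patch.reshape(-1)
                row += 1
    # Zero padding adds only zero terms to each dot product; crop afterwards.
    out = (pad_to_tile(w.reshape(o, -1)) @ pad_to_tile(cols))[:o, :n * oh * ow]
    return out.reshape(o, n, oh, ow).transpose(1, 0, 2, 3)


def conv3d_by_slices(x, w, stride=(1, 1)):
    """Sum of Conv2d slices over kd. x: (N, C, D, H, W), w: (O, C, KD, KH, KW).

    stride is (temporal, spatial); the spatial stride is shared by H and W.
    """
    s_t, s_hw = stride
    kd = w.shape[2]
    od = (x.shape[2] - kd) // s_t + 1
    frames = []
    for d in range(od):
        acc = 0
        for k in range(kd):
            acc = acc + conv2d_im2col(x[:, :, d * s_t + k], w[:, :, k], s_hw)
        frames.append(acc)
    return np.stack(frames, axis=2)


def conv3d_direct(x, w, stride=(1, 1)):
    """Naive sliding-window Conv3d used only as a cross-check."""
    s_t, s_hw = stride
    n, _, d, h, wd = x.shape
    o, _, kd, kh, kw = w.shape
    od, oh, ow = (d - kd) // s_t + 1, (h - kh) // s_hw + 1, (wd - kw) // s_hw + 1
    out = np.zeros((n, o, od, oh, ow))
    for a in range(od):
        for b in range(oh):
            for e in range(ow):
                patch = x[:, :, a * s_t:a * s_t + kd,
                          b * s_hw:b * s_hw + kh, e * s_hw:e * s_hw + kw]
                out[:, :, a, b, e] = np.einsum("ncdhw,ocdhw->no", patch, w)
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 5, 6, 6))
w = rng.standard_normal((4, 3, 3, 3, 3))
out = conv3d_by_slices(x, w, stride=(2, 1))
ref = conv3d_direct(x, w, stride=(2, 1))
print(out.shape, np.abs(out - ref).max())
```

The two results agree up to floating-point rounding, since the slice sum only reorders the additions of the direct convolution.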

Files

File | Description
conv3d.py | NKI kernel, the main deliverable
conv3d_ref.py | NumPy reference (im2col + matmul) for testing
test_conv3d.py | 138+ test cases across 3 layers

Test Coverage

Layer | Cases | Source
PyTorch standard | 12 | Adapted from torch/testing/_internal/common_nn.py
Wan2.1/2.2 VAE configs | 12 | Actual CausalConv3d shapes from wan/modules/vae.py
CogVideoX-5b VAE configs | 15 | From THUDM/CogVideoX-5b vae/config.json
HunyuanVideo VAE configs | 18 | From tencent/HunyuanVideo vae/config.json
BFloat16 precision | 12 | bf16-quantized inputs vs PyTorch bf16
Dilation | 8 | Uniform, spatial-only, temporal-only, asymmetric
Grouped / depthwise | 15 | groups=2/4, depthwise, with stride/padding/bias
Edge cases | 7+ | Single channel, D=1, mixed strides, causal padding

All tests compare against torch.nn.functional.conv3d as ground truth.
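
For readers unfamiliar with the "bf16-quantized inputs" in the table above, here is one way such inputs could be produced in plain NumPy. This is a sketch under assumptions: quantize_bf16 is a hypothetical helper, round-to-nearest-even is assumed, and the test suite may quantize differently (for example through torch.bfloat16).

```python
import numpy as np


def quantize_bf16(x):
    """Round float32 values to the nearest bfloat16, returned as float32.

    bfloat16 keeps the upper 16 bits of a float32 (sign, 8 exponent bits,
    7 mantissa bits), so the rounding happens on the raw bit pattern.
    NaN and infinity handling is ignored in this sketch.
    """
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    # Round to nearest, ties to even: add 0x7FFF plus the lowest kept bit.
    rounding = ((bits >> 16) & 1) + np.uint32(0x7FFF)
    rounded = (bits + rounding) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)


values = np.array([1.0, 1.0 + 2.0**-9, 1.0 + 2.0**-7, 3.14159], dtype=np.float32)
q = quantize_bf16(values)
print(q)
```

Near 1.0 the spacing between adjacent bfloat16 values is 2^-7, so 1 + 2^-9 rounds back to 1.0 while 1 + 2^-7 is preserved. That coarse spacing is why bf16 comparisons generally need looser tolerances than float32 ones.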

Quick Start

Run reference tests (any machine, no NKI needed)

pip install numpy pytest torch
pytest test_conv3d.py -k "Ref" -v

Run NKI kernel tests via Docker (macOS / any machine)

neuronxcc only runs on x86_64 Linux. Use Docker:

docker build --platform linux/amd64 -t nki-conv3d .
docker run --platform linux/amd64 nki-conv3d

This uses nki.simulate_kernel, so no Trainium hardware is required.

Run NKI kernel tests directly (x86_64 Linux only)

pip install neuronx-cc==2.* numpy torch pytest \
    --extra-index-url=https://pip.repos.neuron.amazonaws.com
pytest test_conv3d.py -k "NKI" -v

Use in your model

from conv3d import conv3d

# Calls NKI tiled_matmul_kernel internally (CPU simulation, no hardware needed)
result = conv3d(input_np, weight_np, bias_np,
                stride=(1, 1, 1), padding=(1, 1, 1))

CausalConv3d (Wan2.1/2.2 VAE)

Wan's CausalConv3d applies asymmetric temporal padding (2*pad frames before the first frame, 0 after the last) before calling standard Conv3d. This kernel handles the Conv3d part; causal padding is done at the Python wrapper level:

import numpy as np
from conv3d_ref import conv3d_ref

# Simulate CausalConv3d(3,3,3) with padding=(1,1,1)
input_causal = np.pad(input, ((0,0), (0,0), (2,0), (1,1), (1,1)), mode="constant")
output = conv3d_ref(input_causal, weight, stride=(1,1,1), padding=(0,0,0))
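
To see why front-only temporal padding makes the convolution causal, the sketch below uses a temporal-only kernel, so no repo code is required. causal_pad and temporal_conv are hypothetical helpers written for this illustration and are not part of conv3d.py or conv3d_ref.py.

```python
import numpy as np


def causal_pad(x, pad=(1, 1, 1)):
    """Pad (N, C, D, H, W): 2*pad_d zero frames in front, none behind."""
    pd, ph, pw = pad
    return np.pad(x, ((0, 0), (0, 0), (2 * pd, 0), (ph, ph), (pw, pw)),
                  mode="constant")


def temporal_conv(x, k):
    """Valid cross-correlation along the depth axis only; k has shape (KD,)."""
    kd = k.shape[0]
    od = x.shape[2] - kd + 1
    return sum(k[i] * x[:, :, i:i + od] for i in range(kd))


x = np.arange(16, dtype=np.float64).reshape(1, 1, 4, 2, 2)  # 4 frames of 2x2
k = np.array([0.5, 0.25, 0.25])                              # temporal kernel, KD=3

y = temporal_conv(causal_pad(x, pad=(1, 0, 0)), k)

# Perturb only the last input frame and recompute.
x_future = x.copy()
x_future[:, :, 3] += 100.0
y_future = temporal_conv(causal_pad(x_future, pad=(1, 0, 0)), k)

print(y.shape)                                        # depth is preserved: 4 frames
print(np.array_equal(y[:, :, :3], y_future[:, :, :3]))  # earlier frames unchanged
```

With a kernel depth of 3 and two zero frames in front, output frame t reads input frames t-2 through t, so changing a later frame cannot affect earlier outputs.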

Roadmap

  • NumPy reference implementation with im2col + matmul
  • NKI tiled matmul kernel with bulk nl.arange load/store
  • Comprehensive test suite (347 cases, all ref + NKI tests pass)
  • Wan2.1/2.2 VAE CausalConv3d compatibility tests (all channel sizes up to 384)
  • Vectorized im2col (conv3d_fused) — true on-device im2col blocked by NKI API (no gather DMA)
  • Performance benchmarks on trn1/trn2
  • bfloat16 precision tests (12 cases)
  • Dilation support (dilation > 1) with 8 test cases
  • Grouped / depthwise convolution (groups > 1) with 15 test cases
  • CogVideoX-5b VAE configs (15 cases) and HunyuanVideo VAE configs (18 cases)
  • Backward pass NumPy reference (grad_input, grad_weight, grad_bias, 38 tests)
  • Backward pass NKI kernel (conv3d_backward, 38 tests via Docker)
  • PR to aws-neuron/nki-library

License

Apache 2.0
