diff --git a/.gitignore b/.gitignore index 3645ff0..9938bae 100755 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,6 @@ -.DS_Store -.vscode -__pycache__ - -# build results -build/ -dist/ -*.egg-info -*.so - -# checkpoints +venv/ +__pycache__/ +.vscode/ checkpoints/ - -# outputs -output*/ -*.mp4 \ No newline at end of file +output/ +*.egg-info/ diff --git a/README_AMD_WINDOWS.md b/README_AMD_WINDOWS.md new file mode 100644 index 0000000..52f013c --- /dev/null +++ b/README_AMD_WINDOWS.md @@ -0,0 +1,172 @@ +# TurboDiffusion - AMD ROCm on Windows Setup Guide + +This guide explains how to build and run TurboDiffusion on Windows with AMD GPUs using ROCm. + +> **Note:** These steps should also work on Linux with minor modifications (use bash commands instead of PowerShell, `source venv/bin/activate` instead of `.\venv\Scripts\Activate.ps1`, and skip the Visual Studio environment setup). However, Linux support has not been tested yet and may have issues. + +## Supported Hardware + +TurboDiffusion on Windows has been tested with RDNA3/RDNA3.5 GPUs (gfx1100, gfx1101, gfx1102, gfx1151). + +## Prerequisites + +- Windows 10/11 +- Python 3.11, 3.12, or 3.13 +- Visual Studio 2022 with C++ build tools +- AMD Adrenaline driver (latest recommended) + +## Installation + +### 1. Install ROCm and PyTorch from TheRock + +Follow the instructions at [ROCm/TheRock RELEASES.md](https://github.com/ROCm/TheRock/blob/main/RELEASES.md) to install ROCm and PyTorch wheels for your GPU architecture. + +#### Create a Virtual Environment + +```powershell +python -m venv venv +.\venv\Scripts\Activate.ps1 +``` + +#### Install PyTorch (includes ROCm SDK as dependency) + +For **gfx1151** (AMD Strix Halo iGPU): +```powershell +pip install --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ --pre torch torchaudio torchvision +``` + +For **gfx110X** (RX 7900 XTX, RX 7800 XT, RX 7700S, Radeon 780M): +```powershell +pip install --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/ --pre torch torchaudio torchvision +``` + +For **gfx120X** (RX 9060, RX 9070): +```powershell +pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ --pre torch torchaudio torchvision +``` + +#### Initialize ROCm SDK + +```powershell +rocm-sdk init +``` + +#### Install Triton with AMD Windows Support + +```powershell +pip install triton-windows +``` + +### 2. Set Environment Variables + +Open a PowerShell terminal and run: + +```powershell +# Activate Visual Studio environment +cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } } + +# Activate the virtual environment +.\venv\Scripts\Activate.ps1 + +# Set ROCm paths using rocm-sdk +$ROCM_ROOT = (rocm-sdk path --root).Trim() +$ROCM_BIN = (rocm-sdk path --bin).Trim() +$env:ROCM_HOME = $ROCM_ROOT +$env:PATH = "$ROCM_ROOT\lib\llvm\bin;$ROCM_BIN;$env:PATH" + +# Set compiler and build settings +$env:CC = "clang-cl" +$env:CXX = "clang-cl" +$env:DISTUTILS_USE_SDK = "1" + +# Enable experimental features +$env:FLASH_ATTENTION_TRITON_AMD_ENABLE = "TRUE" +$env:TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL = "1" + +# Set PYTHONPATH for TurboDiffusion +$env:PYTHONPATH = "turbodiffusion" +``` + +### 3. Build and Install TurboDiffusion + +```powershell +cd +pip install --no-build-isolation -e . +``` + +### 4. 
Install SpargeAttn (Optional, for sparse attention)
+
+If you want to use sparse attention with TurboDiffusion, clone the AMD Windows fork:
+
+```powershell
+git clone --branch jam/amd_windows https://github.com/jammm/SpargeAttn.git
+cd SpargeAttn
+pip install --no-build-isolation -v .
+```
+
+## Running Inference
+
+### Text-to-Video with Wan2.1
+
+```powershell
+# Make sure environment variables are set (see step 2)
+
+python turbodiffusion/inference/wan2.1_t2v_infer.py `
+  --model Wan2.1-1.3B `
+  --dit_path checkpoints/TurboWan2.1-T2V-1.3B-480P-quant.pth `
+  --resolution 480p `
+  --prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage." `
+  --num_samples 1 `
+  --num_steps 4 `
+  --quant_linear `
+  --attention_type sagesla `
+  --sla_topk 0.1
+```
+
+### Available Attention Types
+
+- `sdpa` - PyTorch Scaled Dot Product Attention
+- `sagesla` - SageAttention with Sparse Linear Attention (requires SpargeAttn)
+
+## Environment Variable Summary
+
+| Variable | Value | Description |
+|----------|-------|-------------|
+| `CC` | `clang-cl` | C compiler |
+| `CXX` | `clang-cl` | C++ compiler |
+| `DISTUTILS_USE_SDK` | `1` | Use SDK for distutils |
+| `ROCM_HOME` | Output of `rocm-sdk path --root` | ROCm SDK root path |
+| `PATH` | Include LLVM and ROCm bin | Required for hipcc, clang, lld-link |
+| `FLASH_ATTENTION_TRITON_AMD_ENABLE` | `TRUE` | Enable Triton Flash Attention on AMD |
+| `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL` | `1` | Enable experimental aotriton kernels |
+| `PYTHONPATH` | `turbodiffusion` | Include the turbodiffusion module |
+
+## Known Issues
+
+1. **Triton compiler warnings** - You may see `clang-cl: warning: unknown argument ignored` warnings during the first run. These are harmless.
+
+2. **First run is slow** - Triton and MIOpen kernels are compiled on first use and cached. Subsequent runs will be faster.
+
+3. **No FP8 support on RDNA3** - RDNA3 GPUs don't support FP8, so FP16/BF16 kernels are used.
+
+## Troubleshooting
+
+### "LoadLibrary failed" or "cannot find amdhip64.dll"
+
+Make sure you ran `rocm-sdk init` after installing the ROCm SDK packages.
+
+### "LINK : fatal error LNK1104: cannot open file 'python312.lib'"
+
+Ensure the Visual Studio environment is activated before building:
+```powershell
+cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } }
+```
+
+### "PermissionError" when compiling Triton kernels
+
+This is a known Windows issue with temp-file handling. Make sure you're using the latest `triton-windows` package (`pip install --upgrade triton-windows`).
+
+### "flash_attn is not installed" warning
+
+This warning is expected. The `flash_attn` package is not available on AMD GPUs, but a Triton-based attention path is used instead when `FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE` is set.
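+
+## Verifying the Installation
+
+A minimal smoke test - it assumes only the PyTorch/Triton wheels from step 1 and the `turbo_diffusion_ops` extension built in step 3 installed without errors. Run it inside the activated virtual environment:
+
+```python
+import torch
+print("torch:", torch.__version__)
+print("HIP runtime:", torch.version.hip)          # non-None on a ROCm build of PyTorch
+print("GPU visible:", torch.cuda.is_available())  # ROCm devices are exposed through the torch.cuda API
+print("device:", torch.cuda.get_device_name(0))
+
+import triton                                      # provided by the triton-windows package
+print("triton:", triton.__version__)
+
+import turbo_diffusion_ops                         # the C++/HIP extension built in step 3
+print("turbo_diffusion_ops:", turbo_diffusion_ops.__file__)
+```
+
+To exercise the rocWMMA int8 GEMM kernel directly, a single-block check like the sketch below can be used. The shapes and the per-128x128-block scale layout follow what the kernel in `gemm/kernel_rocwmma.hpp` reads (row-major int8 `A` of shape `(m, k)`, int8 `B` of shape `(n, k)`, float32 block scales, `k` a multiple of 128, FP16/BF16 output); the scale values here are arbitrary:
+
+```python
+import torch
+import turbo_diffusion_ops as ops
+
+m = n = k = 128                                    # a single 128x128 quantization block
+A = torch.randint(-128, 128, (m, k), dtype=torch.int8, device="cuda")
+B = torch.randint(-128, 128, (n, k), dtype=torch.int8, device="cuda")
+A_S = torch.full((m // 128, k // 128), 1e-3, device="cuda")  # per-block scales (float32)
+B_S = torch.full((n // 128, k // 128), 1e-3, device="cuda")
+C = torch.empty(m, n, dtype=torch.float16, device="cuda")
+
+ops.gemm_cuda(A, A_S, B, B_S, C)                   # binding name is shared with the CUDA build
+
+ref = (A.float() @ B.float().T) * (A_S.item() * B_S.item())
+print("max abs error:", (C.float() - ref).abs().max().item())
+```
+
+If the call raises "Unsupported output data type for int8 gemm", allocate `C` as FP16 or BF16; other output dtypes are rejected by the binding.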
+ diff --git a/build_ext_log.txt b/build_ext_log.txt new file mode 100644 index 0000000..7fd75b3 Binary files /dev/null and b/build_ext_log.txt differ diff --git a/pyproject.toml b/pyproject.toml index 4036059..38b34f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ dependencies = [ "torch>=2.7.0", "torchvision", - "triton>=3.3.0", + "triton-windows>=3.3.0", "flash-attn", "einops", "numpy", diff --git a/rocwmma_lib b/rocwmma_lib new file mode 160000 index 0000000..c360d54 --- /dev/null +++ b/rocwmma_lib @@ -0,0 +1 @@ +Subproject commit c360d5484a5f2c8dacb166154dbe462c4777db5a diff --git a/setup.py b/setup.py index 2ca06f9..5d1bcbe 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ """ Copyright (c) 2025 by TurboDiffusion team. -Licensed under the Apache License, Version 2.0 (the "License"); +Licensed under the Apache License, Version 2.0 (the "License") Citation (please cite if you use this code): @@ -16,60 +16,117 @@ from pathlib import Path from setuptools import setup, find_packages from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os +import sys + +import torch + +is_rocm = torch.version.hip is not None + +# On Windows, deduplicate INCLUDE/LIB/LIBPATH to avoid "command line too long" errors +if sys.platform == 'win32': + for var in ['INCLUDE', 'LIB', 'LIBPATH']: + val = os.environ.get(var, '') + if val: + unique = [] + seen = set() + for p in val.split(';'): + if p.lower() not in seen and p: + seen.add(p.lower()) + unique.append(p) + os.environ[var] = ';'.join(unique) ops_dir = Path(__file__).parent / "turbodiffusion" / "ops" cutlass_dir = ops_dir / "cutlass" +rocwmma_dir = Path(__file__).parent / "rocwmma_lib" / "projects" / "rocwmma" / "library" / "include" -nvcc_flags = [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=--verbose,--warn-on-local-memory-usage", - "-lineinfo", - "-DCUTLASS_DEBUG_TRACE_LEVEL=0", - "-DNDEBUG", - "-Xcompiler", - "-fPIC" -] +if is_rocm: + # HIP/ROCm build with rocWMMA + hip_flags = [ + "-O3", + "-std=c++17", + "-D__HIP_PLATFORM_AMD__", + "-DNDEBUG", + # Undefine PyTorch's half conversion restrictions - rocWMMA needs these + "-U__HIP_NO_HALF_OPERATORS__", + "-U__HIP_NO_HALF_CONVERSIONS__", + ] + + # Windows-specific: add C/C++ runtime libraries for clang-cl + extra_libraries = [] + extra_link_args = [] + if sys.platform == 'win32': + extra_libraries = ["msvcrt", "vcruntime", "ucrt"] + # Force linking with MSVC C++ runtime + extra_link_args = ["/DEFAULTLIB:msvcprt"] + + ext_modules = [ + CUDAExtension( + name="turbo_diffusion_ops", + sources=[ + "turbodiffusion/ops/bindings.cpp", + "turbodiffusion/ops/quant/quant.hip", + "turbodiffusion/ops/norm/rmsnorm.hip", + "turbodiffusion/ops/norm/layernorm.hip", + "turbodiffusion/ops/gemm/gemm_rocwmma.hip", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17", "-D__HIP_PLATFORM_AMD__"], + "nvcc": hip_flags, + }, + include_dirs=[ + str(rocwmma_dir), + str(ops_dir), + ], + libraries=extra_libraries, + extra_link_args=extra_link_args, + ) + ] +else: + # CUDA build with CUTLASS + nvcc_flags = [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + 
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "-lineinfo", + "-DNDEBUG", + ] -cc_flag = [ - "-gencode", "arch=compute_120a,code=sm_120a", - "-gencode", "arch=compute_100,code=sm_100", - "-gencode", "arch=compute_90,code=sm_90", - "-gencode", "arch=compute_89,code=sm_89", - "-gencode", "arch=compute_80,code=sm_80" -] + cc_flag = [ + "-gencode", "arch=compute_120a,code=sm_120a", + "-gencode", "arch=compute_100,code=sm_100", + "-gencode", "arch=compute_90,code=sm_90", + "-gencode", "arch=compute_89,code=sm_89", + "-gencode", "arch=compute_80,code=sm_80" + ] -ext_modules = [ - CUDAExtension( - name="turbo_diffusion_ops", - sources=[ - "turbodiffusion/ops/bindings.cpp", - "turbodiffusion/ops/quant/quant.cu", - "turbodiffusion/ops/norm/rmsnorm.cu", - "turbodiffusion/ops/norm/layernorm.cu", - "turbodiffusion/ops/gemm/gemm.cu" - ], - extra_compile_args={ - "cxx": ["-O3", "-std=c++17"], - "nvcc": nvcc_flags + ["-DEXECMODE=0"] + cc_flag + ["--threads", "4"], - }, - include_dirs=[ - cutlass_dir / "include", - cutlass_dir / "tools" / "util" / "include", - ops_dir - ], - libraries=["cuda"], - ) -] + ext_modules = [ + CUDAExtension( + name="turbo_diffusion_ops", + sources=[ + "turbodiffusion/ops/bindings.cpp", + "turbodiffusion/ops/quant/quant.cu", + "turbodiffusion/ops/norm/rmsnorm.cu", + "turbodiffusion/ops/norm/layernorm.cu", + "turbodiffusion/ops/gemm/gemm.cu" + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"], + "nvcc": nvcc_flags + ["-DEXECMODE=0"] + cc_flag, + }, + include_dirs=[ + str(cutlass_dir / "include"), + str(cutlass_dir / "tools" / "util" / "include"), + str(ops_dir), + ], + libraries=["cuda"], + ) + ] setup( packages=find_packages( diff --git a/turbodiffusion/SLA/core.py b/turbodiffusion/SLA/core.py index 430bfe0..3c56b88 100755 --- a/turbodiffusion/SLA/core.py +++ b/turbodiffusion/SLA/core.py @@ -17,6 +17,9 @@ import torch.nn as nn import torch.nn.functional as F +# Check for ROCm +IS_ROCM = torch.version.hip is not None + SAGESLA_ENABLED = True try: import spas_sage_attn._qattn as qattn @@ -182,10 +185,19 @@ def forward(self, q, k, v, return_sparsity=False): k = k.transpose(1, 2).contiguous() v = v.transpose(1, 2).contiguous() - arch = get_cuda_arch(q.device.index) + arch = get_cuda_arch(q.device.index) if not IS_ROCM else "rocm" + headdim = q.size(-1) if arch == "sm90": sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=64, BLKK=128) + elif IS_ROCM: + # ROCm: use smaller tiles for head_dim=128 to reduce register pressure + # head_dim=64: CTA_Q=64, CTA_K=64 + # head_dim=128: CTA_Q=32, CTA_K=16 (best performance at 10% sparsity) + blkq = 32 if headdim == 128 else 64 + blkk = 16 if headdim == 128 else 64 + sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=blkq, BLKK=blkk) else: + # Use 128x64 blocks for sm80-like archs sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=128, BLKK=64) q = q.to(self.dtype) @@ -195,26 +207,46 @@ def forward(self, q, k, v, return_sparsity=False): ########## SPARGE BEGIN ########## km = k.mean(dim=-2, keepdim=True) - headdim = q.size(-1) if arch == "sm90": q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, 64, 128) + elif IS_ROCM: + # ROCm: use smaller tiles for head_dim=128 to reduce register pressure + # head_dim=64: CTA_Q=64, CTA_K=64 + # head_dim=128: CTA_Q=32, CTA_K=16 (best performance at 10% sparsity) + blkq = 32 if headdim == 128 else 64 + blkk = 16 if headdim == 128 else 64 + 
q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, blkq, blkk) else: + # Use 128x64 block sizes for sm80-like archs q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, 128, 64) lut, valid_block_num = block_map_lut_triton(sparse_map) scale = 1.0 / (headdim ** 0.5) assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale." - o_s = torch.empty_like(q) - - if arch in ("sm80", "sm86", "sm87"): + if IS_ROCM: + # ROCm: kernel natively supports both float16 and bfloat16 + o_s = torch.empty_like(q) + pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device) + # Pass V in its native dtype (fp16 or bf16) - kernel handles both + qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold( + q_int8, k_int8, v, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, 1, False, 1, scale, 0 + ) + elif arch in ("sm80", "sm86", "sm87"): + # NVIDIA sm80-sm87: requires FP16 V, kernel outputs float16 + o_s = torch.empty(q.shape, dtype=torch.float16, device=q.device) pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device) v_fp16 = v.to(torch.float16) + qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold( q_int8, k_int8, v_fp16, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, 1, False, 1, scale, 0 ) + # Convert back to original dtype (may be bfloat16) + o_s = o_s.to(self.dtype) else: + # NVIDIA sm89+: use FP8 V kernels + o_s = torch.empty_like(q) b, h_kv, kv_len, head_dim = v.shape padded_len = (kv_len + 127) // 128 * 128 v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device) diff --git a/turbodiffusion/imaginaire/utils/distributed.py b/turbodiffusion/imaginaire/utils/distributed.py index e42ca42..b5d24e6 100644 --- a/turbodiffusion/imaginaire/utils/distributed.py +++ b/turbodiffusion/imaginaire/utils/distributed.py @@ -28,11 +28,11 @@ import pynvml import torch import torch.distributed as dist -from torch.distributed import get_process_group_ranks from imaginaire.utils.device import Device if dist.is_available(): + from torch.distributed import get_process_group_ranks from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes diff --git a/turbodiffusion/imaginaire/utils/misc.py b/turbodiffusion/imaginaire/utils/misc.py index 06d4236..634b403 100644 --- a/turbodiffusion/imaginaire/utils/misc.py +++ b/turbodiffusion/imaginaire/utils/misc.py @@ -30,8 +30,10 @@ import numpy as np import termcolor import torch -from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed._tensor.api import DTensor + +if torch.distributed.is_available(): + from torch.distributed._functional_collectives import AsyncCollectiveTensor + from torch.distributed._tensor.api import DTensor from imaginaire.utils import distributed, log diff --git a/turbodiffusion/ops/bindings.cpp b/turbodiffusion/ops/bindings.cpp index a87adf5..221f1c1 100644 --- a/turbodiffusion/ops/bindings.cpp +++ b/turbodiffusion/ops/bindings.cpp @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Python bindings for TurboDiffusion GPU operations. + * Supports both CUDA (NVIDIA) and HIP (AMD ROCm) backends. 
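+ * The same bindings are compiled against either the CUDA (.cu) or the HIP (.hip)
+ * kernel sources; setup.py selects the source list based on torch.version.hip.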
+ */ + #include namespace py = pybind11; @@ -12,4 +21,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { register_rms_norm(m); register_layer_norm(m); register_gemm(m); -} \ No newline at end of file +} diff --git a/turbodiffusion/ops/common/common_hip.hpp b/turbodiffusion/ops/common/common_hip.hpp new file mode 100644 index 0000000..df11630 --- /dev/null +++ b/turbodiffusion/ops/common/common_hip.hpp @@ -0,0 +1,108 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once + +#include +#include +#include + +// Define CUTLASS macros for HIP compatibility +#ifndef CUTLASS_HOST_DEVICE +#define CUTLASS_HOST_DEVICE __host__ __device__ +#endif +#ifndef CUTLASS_DEVICE +#define CUTLASS_DEVICE __device__ +#endif +#ifndef CUTLASS_HOST +#define CUTLASS_HOST __host__ +#endif + +// Define __grid_constant__ if not available (CUDA 11.5+ feature) +#ifndef __grid_constant__ +#define __grid_constant__ +#endif + +// Define CUTLASS pragma macros for HIP +#ifndef CUTLASS_PRAGMA_UNROLL +#define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") +#endif +#ifndef CUTLASS_PRAGMA_NO_UNROLL +#define CUTLASS_PRAGMA_NO_UNROLL _Pragma("nounroll") +#endif + +inline CUTLASS_HOST_DEVICE int64_t cdiv(int64_t const& a, int64_t const &b) { + return (a + b - 1) / b; +} + +// Note: Don't define max/min as they conflict with HIP builtins +// Use std::max/std::min or the built-in max/min instead + +#define MIN(a, b) ((a) > (b) ? (b) : (a)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ +[&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return (__VA_ARGS__)(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return (__VA_ARGS__)(); \ + } \ +}() + +#define CUDA_CHECK(call) \ +{ \ + hipError_t err = call; \ + if (err != hipSuccess) { \ + fprintf(stderr, "CUDA Error at %s:%d: %s\n", __FILE__, __LINE__, hipGetErrorString(err)); \ + exit(err); \ + } \ +} + +#define CONFIG_SWITCH(N, ...) \ +[&] { \ + if (N <= 1024) { \ + constexpr int NUM_THR_PER_CTA = 128; \ + constexpr int MAX_HIDDEN_SIZE = 1024; \ + return (__VA_ARGS__)(); \ + } else if (N <= 2048) { \ + constexpr int NUM_THR_PER_CTA = 128; \ + constexpr int MAX_HIDDEN_SIZE = 2048; \ + return (__VA_ARGS__)(); \ + } else if (N <= 4096) { \ + constexpr int NUM_THR_PER_CTA = 128; \ + constexpr int MAX_HIDDEN_SIZE = 4096; \ + return (__VA_ARGS__)(); \ + } else if (N <= 8192) { \ + constexpr int NUM_THR_PER_CTA = 256; \ + constexpr int MAX_HIDDEN_SIZE = 8192; \ + return (__VA_ARGS__)(); \ + } else { \ + constexpr int NUM_THR_PER_CTA = 256; \ + constexpr int MAX_HIDDEN_SIZE = 16384; \ + return (__VA_ARGS__)(); \ + } \ +}() + + +template + void create_tensor( + torch::Device const &device, + std::optional &output, + std::optional &scale, + int m, int n + ) { + int num_block_m = cdiv(m, BlockSize); + int num_block_n = cdiv(n, BlockSize); + if (!output.has_value()) { + output.emplace(torch::empty( + {m, n}, + torch::TensorOptions().device(device).dtype(torch::kInt8) + )); + scale.emplace(torch::empty( + {num_block_m, num_block_n}, + torch::TensorOptions().device(device).dtype(torch::kFloat32) + )); + } + } + diff --git a/turbodiffusion/ops/common/launch_hip.hpp b/turbodiffusion/ops/common/launch_hip.hpp new file mode 100644 index 0000000..a2cc17d --- /dev/null +++ b/turbodiffusion/ops/common/launch_hip.hpp @@ -0,0 +1,45 @@ +// !!! This is a file automatically generated by hipify!!! 
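+// Launch helpers for the HIP build: thin __global__ wrappers (device_kernel,
+// device_kernel_with_launch_bounds) plus launch_kernel(), which raises the dynamic
+// shared-memory limit via hipFuncSetAttribute before dispatching with hipLaunchKernelGGL.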
+#pragma once + +#include +#include + + +template +__global__ void device_kernel( + __grid_constant__ typename Kernel::Params const params +) { + extern __shared__ char smem[]; + Kernel op; + op(params, smem); +} + +template +__global__ __launch_bounds__(Kernel::MaxThreadsPerBlock, Kernel::MinBlocksPerMultiprocessor) +void device_kernel_with_launch_bounds( + __grid_constant__ typename Kernel::Params const params +) { + extern __shared__ char smem[]; + Kernel op; + op(params, smem); +} + +template +void launch_kernel( + typename Kernel::Params const ¶ms, + dim3 grid_shape, + dim3 cta_shape, + size_t ShmSize, + hipStream_t stream = nullptr +) { + auto func = device_kernel; + if (ShmSize >= 48 * 1024) { + CUDA_CHECK(hipFuncSetAttribute( + func, + hipFuncAttributeMaxDynamicSharedMemorySize, + ShmSize + )); + } + hipLaunchKernelGGL(( func), dim3(grid_shape), dim3(cta_shape), ShmSize, stream, params); + CUDA_CHECK(hipGetLastError()); +} diff --git a/turbodiffusion/ops/common/load.hpp b/turbodiffusion/ops/common/load.hpp index f72bcad..ae8ca7d 100644 --- a/turbodiffusion/ops/common/load.hpp +++ b/turbodiffusion/ops/common/load.hpp @@ -1,5 +1,48 @@ #pragma once +// Include common_hip.hpp for CUTLASS macro definitions when building for HIP +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +#include "common/common_hip.hpp" +#include +#include + +// Helper functions for type conversion on HIP +namespace turbo_hip { + +template +__device__ __forceinline__ T from_float(float val) { + return static_cast(val); +} + +template <> +__device__ __forceinline__ __half from_float<__half>(float val) { + return __float2half(val); +} + +template <> +__device__ __forceinline__ hip_bfloat16 from_float(float val) { + return hip_bfloat16(val); +} + +template +__device__ __forceinline__ float to_float(T val) { + return static_cast(val); +} + +template <> +__device__ __forceinline__ float to_float<__half>(__half val) { + return __half2float(val); +} + +template <> +__device__ __forceinline__ float to_float(hip_bfloat16 val) { + return static_cast(val); +} + +} // namespace turbo_hip + +#endif + template < class InputDtype_, int TileM_, @@ -30,8 +73,13 @@ class Loader { void const *thr_input_ptr = (void*)((InputDtype*)cta_input_ptr + thr_m_offset * n + thr_n_offset); InputDtype tmp_reg[NumElementPerThread]; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < NumElementPerThread; ++i) + for (int i = 0; i < NumElementPerThread; ++i) { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + tmp_reg[i] = turbo_hip::from_float(0.f); +#else tmp_reg[i] = InputDtype(0.f); +#endif + } bool pred = IsEvenM ? true : thr_m_offset + blk_m * TileM < m; int limit = IsEvenN ? 
NumElementPerThread : MIN(NumElementPerThread, n - (blk_n * TileN + thr_n_offset)); if (n_alignment % 128 == 0) @@ -42,8 +90,13 @@ class Loader { _load(thr_input_ptr, (void*)tmp_reg, limit, pred); CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < NumElementPerThread; ++i) + for (int i = 0; i < NumElementPerThread; ++i) { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + *((float*)thr_output_reg + i) = turbo_hip::to_float(tmp_reg[i]); +#else *((float*)thr_output_reg + i) = static_cast(reinterpret_cast(tmp_reg[i])); +#endif + } } private: diff --git a/turbodiffusion/ops/common/numeric_conversion_hip.hpp b/turbodiffusion/ops/common/numeric_conversion_hip.hpp new file mode 100644 index 0000000..c8fccf9 --- /dev/null +++ b/turbodiffusion/ops/common/numeric_conversion_hip.hpp @@ -0,0 +1,51 @@ +// Compatibility header for CUTLASS numeric conversion on HIP/ROCm +// This provides a minimal subset of CUTLASS functionality needed for TurboDiffusion + +#pragma once + +#include +#include + +namespace cutlass { + +// FloatRoundStyle enum (subset of CUTLASS) +enum class FloatRoundStyle { + round_to_nearest = 0, + round_toward_zero = 1, + round_toward_infinity = 2, + round_toward_neg_infinity = 3, +}; + +// NumericConverter template - provides float to int8 conversion with rounding +template +struct NumericConverter { + __device__ __host__ __forceinline__ + To operator()(From const& val) const { + return static_cast(val); + } +}; + +// Specialization for float to int8_t with round_to_nearest +template <> +struct NumericConverter { + __device__ __host__ __forceinline__ + int8_t operator()(float val) const { + // Round to nearest and clamp to int8 range [-128, 127] + val = fmaxf(-128.0f, fminf(127.0f, rintf(val))); + return static_cast(val); + } +}; + +// Specialization for float to int8_t with round_toward_zero +template <> +struct NumericConverter { + __device__ __host__ __forceinline__ + int8_t operator()(float val) const { + // Truncate and clamp to int8 range [-128, 127] + val = fmaxf(-128.0f, fminf(127.0f, truncf(val))); + return static_cast(val); + } +}; + +} // namespace cutlass + diff --git a/turbodiffusion/ops/common/platform.hpp b/turbodiffusion/ops/common/platform.hpp new file mode 100644 index 0000000..284af0a --- /dev/null +++ b/turbodiffusion/ops/common/platform.hpp @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Platform abstraction layer for CUDA/HIP compatibility. + * This header provides unified macros and types for both backends. 
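+ * Used by the rocWMMA GEMM path (gemm/kernel_rocwmma.hpp and gemm/launch_rocwmma.hpp).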
+ */ + +#pragma once + +// Detect platform +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + #define TURBO_PLATFORM_HIP 1 + #define TURBO_PLATFORM_CUDA 0 +#else + #define TURBO_PLATFORM_HIP 0 + #define TURBO_PLATFORM_CUDA 1 +#endif + +// Include appropriate runtime headers +#if TURBO_PLATFORM_HIP + #include + #include +#else + #include + #include + #include +#endif + +// Stream type abstraction +#if TURBO_PLATFORM_HIP + using turboStream_t = hipStream_t; + using turboError_t = hipError_t; + #define turboSuccess hipSuccess + #define turboGetErrorString hipGetErrorString + #define turboGetLastError hipGetLastError + #define turboFuncSetAttribute hipFuncSetAttribute + #define turboFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize +#else + using turboStream_t = cudaStream_t; + using turboError_t = cudaError_t; + #define turboSuccess cudaSuccess + #define turboGetErrorString cudaGetErrorString + #define turboGetLastError cudaGetLastError + #define turboFuncSetAttribute cudaFuncSetAttribute + #define turboFuncAttributeMaxDynamicSharedMemorySize cudaFuncAttributeMaxDynamicSharedMemorySize +#endif + +// Device function qualifiers +#if TURBO_PLATFORM_HIP + #define TURBO_HOST __host__ + #define TURBO_DEVICE __device__ + #define TURBO_HOST_DEVICE __host__ __device__ + #define TURBO_KERNEL __global__ + #define TURBO_INLINE __forceinline__ +#else + #define TURBO_HOST __host__ + #define TURBO_DEVICE __device__ + #define TURBO_HOST_DEVICE __host__ __device__ + #define TURBO_KERNEL __global__ + #define TURBO_INLINE __forceinline__ +#endif + +// Pragma unroll +#if TURBO_PLATFORM_HIP + #define TURBO_PRAGMA_UNROLL _Pragma("unroll") + #define TURBO_PRAGMA_NO_UNROLL _Pragma("nounroll") +#else + #define TURBO_PRAGMA_UNROLL _Pragma("unroll") + #define TURBO_PRAGMA_NO_UNROLL _Pragma("nounroll") +#endif + +// Error checking macro +#define TURBO_CHECK(call) \ + do { \ + turboError_t err = (call); \ + if (err != turboSuccess) { \ + fprintf(stderr, "GPU Error at %s:%d: %s\n", __FILE__, __LINE__, \ + turboGetErrorString(err)); \ + exit(err); \ + } \ + } while (0) + +// Half precision types +#if TURBO_PLATFORM_HIP + using half_t = __half; + using bfloat16_t = hip_bfloat16; + + TURBO_DEVICE TURBO_INLINE float __int2float_rn_hip(int x) { + return static_cast(x); + } + #define __int2float_rn __int2float_rn_hip + + TURBO_DEVICE TURBO_INLINE float __int_as_float_hip(int x) { + return __int_as_float(x); + } +#else + #include + using half_t = cutlass::half_t; + using bfloat16_t = cutlass::bfloat16_t; +#endif + +// Warp/Wave primitives +#if TURBO_PLATFORM_HIP + // RDNA3 uses wave32 + #define TURBO_WARP_SIZE 32 + #define TURBO_FULL_MASK 0xFFFFFFFFu + + TURBO_DEVICE TURBO_INLINE float warpReduceSum(float val) { + TURBO_PRAGMA_UNROLL + for (int offset = TURBO_WARP_SIZE / 2; offset > 0; offset >>= 1) { + val += __shfl_xor(val, offset, TURBO_WARP_SIZE); + } + return val; + } + + TURBO_DEVICE TURBO_INLINE float warpReduceMax(float val) { + TURBO_PRAGMA_UNROLL + for (int offset = TURBO_WARP_SIZE / 2; offset > 0; offset >>= 1) { + val = fmaxf(val, __shfl_xor(val, offset, TURBO_WARP_SIZE)); + } + return val; + } +#else + #define TURBO_WARP_SIZE 32 + #define TURBO_FULL_MASK 0xFFFFFFFFu + + TURBO_DEVICE TURBO_INLINE float warpReduceSum(float val) { + TURBO_PRAGMA_UNROLL + for (int offset = TURBO_WARP_SIZE / 2; offset > 0; offset >>= 1) { + val += __shfl_xor_sync(TURBO_FULL_MASK, val, offset); + } + return val; + } + + TURBO_DEVICE TURBO_INLINE float warpReduceMax(float val) { + 
TURBO_PRAGMA_UNROLL + for (int offset = TURBO_WARP_SIZE / 2; offset > 0; offset >>= 1) { + val = fmaxf(val, __shfl_xor_sync(TURBO_FULL_MASK, val, offset)); + } + return val; + } +#endif + +// Synchronization +#if TURBO_PLATFORM_HIP + #define __syncwarp() __syncthreads() +#endif + +// Kernel launch helper +template +TURBO_KERNEL void device_kernel_impl( + typename Kernel::Params const params +) { + extern __shared__ char smem[]; + Kernel op; + op(params, smem); +} + +template +void launch_kernel_unified( + typename Kernel::Params const& params, + dim3 grid_shape, + dim3 cta_shape, + size_t ShmSize, + turboStream_t stream = nullptr +) { + auto func = device_kernel_impl; + if (ShmSize >= 48 * 1024) { + TURBO_CHECK(turboFuncSetAttribute( + func, + turboFuncAttributeMaxDynamicSharedMemorySize, + ShmSize + )); + } +#if TURBO_PLATFORM_HIP + hipLaunchKernelGGL(func, dim3(grid_shape), dim3(cta_shape), ShmSize, stream, params); +#else + func<<>>(params); +#endif + TURBO_CHECK(turboGetLastError()); +} + +// Numeric conversion helpers +namespace turbo { + +template +TURBO_DEVICE TURBO_INLINE To convert(From val) { + return static_cast(val); +} + +#if TURBO_PLATFORM_HIP +template <> +TURBO_DEVICE TURBO_INLINE int8_t convert(float val) { + // Round to nearest with clamping + val = fmaxf(-128.0f, fminf(127.0f, rintf(val))); + return static_cast(val); +} +#else +template <> +TURBO_DEVICE TURBO_INLINE int8_t convert(float val) { + return cutlass::NumericConverter()(val); +} +#endif + +} // namespace turbo + diff --git a/turbodiffusion/ops/common/store_hip.hpp b/turbodiffusion/ops/common/store_hip.hpp new file mode 100644 index 0000000..3c1b4e1 --- /dev/null +++ b/turbodiffusion/ops/common/store_hip.hpp @@ -0,0 +1,78 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once +#include "common/common_hip.hpp" + + +template < + class OutputDtype_, + int TileM_, + int TileN_, + int NumThrPerCta_, + bool IsEvenM, + bool IsEvenN, + bool Round = true, + bool SaveScale = true +> +class Saver { +public: + using OutputDtype = OutputDtype_; + + static constexpr int TileM = TileM_; + static constexpr int TileN = TileN_; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int NumElementPerThread = TileM * TileN / NumThrPerCta; + static constexpr int NumThrPerRow = TileN / NumElementPerThread; + + static_assert(TileM * TileN % NumThrPerCta == 0); + static_assert(NumThrPerCta % TileM == 0); + + CUTLASS_DEVICE void + store(void *Optr, void *OSptr, void *reg, float scale_inv, int64_t m, int64_t n, int blk_m, int blk_n, int tid) { + int n_alignment = (n & 31) * sizeof(OutputDtype); + int thr_m_offset = tid / NumThrPerRow; + int thr_n_offset = (tid % NumThrPerRow) * NumElementPerThread; + void *cta_output_ptr = (void*)((OutputDtype*)Optr + blk_m * TileM * (Round ? cdiv(n, TileN) * TileN : n) + blk_n * TileN); + void *thr_output_ptr = (void*)((OutputDtype*)cta_output_ptr + thr_m_offset * (Round ? cdiv(n, TileN) * TileN : n) + thr_n_offset); + bool pred = IsEvenM ? true : thr_m_offset + blk_m * TileM < m; + int limit = IsEvenN ? 
NumElementPerThread : MIN(NumElementPerThread, n - (blk_n * TileN + thr_n_offset)); + if (n_alignment % 128 == 0) + _store(thr_output_ptr, reg, limit, pred); + else if (n_alignment % 64 == 0) + _store(thr_output_ptr, reg, limit, pred); + else + _store(thr_output_ptr, reg, limit, pred); + + if constexpr (SaveScale) { + if (tid == 0) { + *((float*)OSptr + blk_m * cdiv(n, TileN)+ blk_n) = scale_inv; + } + } + } + +private: + template + CUTLASS_DEVICE void + _store(void *thr_output_ptr, void *reg, int limit, bool pred) { + static constexpr int NumElementPerStore = sizeof(StoreDataType) / sizeof(OutputDtype); + if (pred) { + if constexpr (IsEven) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; i += NumElementPerStore) { + *(StoreDataType*)((OutputDtype*)thr_output_ptr + i) = *(StoreDataType*)((OutputDtype*)reg + i); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < limit; i += NumElementPerStore) { + if (limit - i > NumElementPerStore) + *(StoreDataType*)((OutputDtype*)thr_output_ptr + i) = *(StoreDataType*)((OutputDtype*)reg + i); + else { + for (int j = 0; j < limit - i; ++j) { + *((OutputDtype*)thr_output_ptr + i + j) = *((OutputDtype*)reg + i + j); + } + } + } + } + } + } + +}; diff --git a/turbodiffusion/ops/gemm/gemm_rocwmma.hip b/turbodiffusion/ops/gemm/gemm_rocwmma.hip new file mode 100644 index 0000000..53ca6d7 --- /dev/null +++ b/turbodiffusion/ops/gemm/gemm_rocwmma.hip @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + * + * rocWMMA-based GEMM for AMD ROCm GPUs. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/common_hip.hpp" +#include "gemm/launch_rocwmma.hpp" + +void int8_gemm( + at::Tensor const& A, at::Tensor const& A_S, + at::Tensor const& B, at::Tensor const& B_S, + torch::Tensor& C +) { + static constexpr int swizzle_dir = 1; + static constexpr int swizzle_size_log = 5; + + int k = B.size(1); + int m = A.size(0); + int n = B.size(0); + + switch (C.scalar_type()) { + case torch::kHalf: { + int8_gemm_rocwmma<__half>( + (int8_t*)A.data_ptr(), A_S.data_ptr(), + (int8_t*)B.data_ptr(), B_S.data_ptr(), + (__half*)C.data_ptr(), + m, n, k, swizzle_dir, swizzle_size_log, + at::hip::getCurrentHIPStream().stream() + ); + break; + } + + case torch::kBFloat16: { + int8_gemm_rocwmma( + (int8_t*)A.data_ptr(), A_S.data_ptr(), + (int8_t*)B.data_ptr(), B_S.data_ptr(), + (hip_bfloat16*)C.data_ptr(), + m, n, k, swizzle_dir, swizzle_size_log, + at::hip::getCurrentHIPStream().stream() + ); + break; + } + + default: { + std::cerr << "Observing: " << C.scalar_type() << " for the output datatype which is invalid"; + throw std::runtime_error("Unsupported output data type for int8 gemm."); + } + } +} + +void register_gemm(pybind11::module_ &m) { + m.def("gemm_cuda", &int8_gemm); +} + diff --git a/turbodiffusion/ops/gemm/kernel_hip.hpp b/turbodiffusion/ops/gemm/kernel_hip.hpp new file mode 100644 index 0000000..9e2ec80 --- /dev/null +++ b/turbodiffusion/ops/gemm/kernel_hip.hpp @@ -0,0 +1,523 @@ +// !!! 
This is a file automatically generated by hipify!!! +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + */ + +#pragma once + +#include +#include "cute/tensor_hip.hpp" + +#include "common/common_hip.hpp" +#include "gemm/utils_hip.hpp" + +using namespace cute; + +template < + class OutputDtype_, + bool IsEvenM, + bool IsEvenN +> +struct GemmKernel { + using ElementA = int8_t; + using ElementB = int8_t; + using OutputDtype = OutputDtype_; + using AccumulatorDtype = int32_t; + static constexpr int BlockSize = 128; + static constexpr int TileM = 128; + static constexpr int TileN = 128; + static constexpr int TileK = 128; + static constexpr int Stage = 3; + static constexpr int EpiStage = 2; + + static_assert( + BlockSize % TileM == 0 + && BlockSize % TileN == 0 + && BlockSize % TileK == 0 + ); + + static constexpr int NumTilePerBlock = BlockSize / TileK; + + using SmemLayoutAtom = decltype( + composition( + Swizzle<3, 4, 3>{}, + make_layout( + make_shape(Int<8>{}, Int{}), + make_stride(Int{}, Int<1>{}) + ) + ) + ); + + using SmemLayoutA = decltype( + tile_to_shape( + SmemLayoutAtom{}, + make_shape(Int{}, Int{}, Int{}) + ) + ); + + using SmemLayoutB = decltype( + tile_to_shape( + SmemLayoutAtom{}, + make_shape(Int{}, Int{}, Int{}) + ) + ); + + using MmaOP = cute::SM80_16x8x32_S32S8S8S32_TN; + using TiledMma = decltype( + make_tiled_mma( + MMA_Atom>{}, + make_layout(make_shape( + _4{}, _2{}, _1{} + )), + make_tile(Int<64>{}, Int<32>{}, Int<32>{}) + ) + ); + + using G2SCopyAtomA = Copy_Atom>, ElementA>; + using G2SCopyAtomB = Copy_Atom>, ElementB>; + using G2STiledCopyA = decltype( + make_tiled_copy( + G2SCopyAtomA{}, + make_layout( + make_shape(Int<64>{}, Int<4>{}), + make_stride(Int<4>{}, Int<1>{}) + ), + make_layout(make_shape(Int<1>{}, Int<16>{})) + ) + ); + using G2STiledCopyB = decltype( + make_tiled_copy( + G2SCopyAtomB{}, + make_layout( + make_shape(Int<64>{}, Int<4>{}), + make_stride(Int<4>{}, Int<1>{}) + ), + make_layout(make_shape(Int<1>{}, Int<16>{})) + ) + ); + + using S2RCopyAtomA = Copy_Atom, ElementA>; + using S2RCopyAtomB = Copy_Atom, ElementB>; + using S2RTiledCopyA = decltype(make_tiled_copy_A(S2RCopyAtomA{}, TiledMma{})); + using S2RTiledCopyB = decltype(make_tiled_copy_B(S2RCopyAtomB{}, TiledMma{})); + + // epilogue + using SmemLayoutAtomD = decltype( + composition( + Swizzle<2, 3, 3>{}, + make_layout( + make_shape(Int<32>{}, Int<32>{}), + LayoutRight{} + ) + ) + ); + + using SmemLayoutD = decltype( + tile_to_shape( + SmemLayoutAtomD{}, + make_shape(Int<64>{}, Int<32>{}, Int{}) + ) + ); + + using R2SCopyAtomD = Copy_Atom>, OutputDtype>; + using R2STiledCopyD = decltype(make_tiled_copy_C(R2SCopyAtomD{}, TiledMma{})); + + using S2GCopyAtomD = Copy_Atom, OutputDtype>; + using S2GCopyD = decltype(make_tiled_copy( + S2GCopyAtomD{}, + make_layout(Shape<_64, _4>{}), + make_layout(Shape<_1, _8>{}) + )); + + using TileShape = decltype(make_shape(Int{}, Int{}, Int{})); + + struct SharedStorageAB: cute::aligned_struct<128> { + array_aligned, 128> smem_A; + array_aligned, 128> smem_B; + array_aligned smem_AS; + array_aligned 
smem_BS; + array_aligned smem_AF; + }; + + struct SharedStorageD: cute::aligned_struct<128> { + array_aligned> smem_D; + }; + + union SharedStorage { + SharedStorageAB storage_AB; + SharedStorageD storage_D; + }; + + + struct Params { + void const* Aptr; + void const* ASptr; + void const* Bptr; + void const* BSptr; + void* Dptr; + int64_t const m; + int64_t const n; + int64_t const k; + int const swizzle_dir; + int const swizzle_size; + }; + + using Arguments = Params; + + static constexpr int ThreadNum = size(TiledMma{}); + static constexpr int ShmSize = sizeof(SharedStorage); + static constexpr bool FastInt2Float = false; + + static bool can_implement(int64_t m, int64_t n, int64_t k) { + if (k % BlockSize != 0) return false; + if ((n * sizeof(OutputDtype)) % 16 != 0) + return false; + return true; + } + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3(cdiv(m, TileM) * cdiv(n, TileN)); + } + + CUTLASS_HOST_DEVICE + static auto get_block_coord( + int64_t m_blocks, + int64_t n_blocks, + int const swizzle_dir, + int64_t const swizzle_size_log + ) { + int64_t blk_m; + int64_t blk_n; + + if (swizzle_dir == 1) + std::swap(m_blocks, n_blocks); + + if (swizzle_size_log == 0) { + blk_m = blockIdx.x % m_blocks; + blk_n = blockIdx.x / m_blocks; + } else { + int64_t group_size = n_blocks << swizzle_size_log; + int64_t num_groups = m_blocks >> swizzle_size_log; + int64_t group_idx = blockIdx.x / group_size; + int64_t local_idx = blockIdx.x % group_size; + if (group_idx == num_groups) { + blk_m = (num_groups << swizzle_size_log) + local_idx % (m_blocks - (num_groups << swizzle_size_log)); + blk_n = local_idx / (m_blocks - (num_groups << swizzle_size_log)); + } else { + blk_m = (local_idx & ((1LL << swizzle_size_log) - 1)) + (group_idx << swizzle_size_log); + blk_n = local_idx >> swizzle_size_log; + } + } + + if (swizzle_dir == 1) + std::swap(blk_m, blk_n); + + return make_coord(blk_m, blk_n); + } + + CUTLASS_DEVICE + void operator()( + Params const& params, char* smem_data + ) { + + SharedStorage& shared_storage = *reinterpret_cast(smem_data); + + auto t_idx = threadIdx.x; + + int64_t const m = params.m; + int64_t const n = params.n; + int64_t const k = params.k; + int const swizzle_dir = params.swizzle_dir; + int const swizzle_size = params.swizzle_size; + + Tensor A = make_tensor( + make_gmem_ptr(params.Aptr), + make_shape(m, k), + make_stride(k, _1{}) + ); + Tensor B = make_tensor( + make_gmem_ptr(params.Bptr), + make_shape(m, k), + make_stride(k, _1{}) + ); + Tensor AS = make_tensor( + make_gmem_ptr(params.ASptr), + make_shape(cdiv(m, BlockSize), cdiv(k, BlockSize)), + make_stride(cdiv(k, BlockSize), _1{}) + ); + Tensor BS = make_tensor( + make_gmem_ptr(params.BSptr), + make_shape(cdiv(n, BlockSize), cdiv(k, BlockSize)), + make_stride(cdiv(k, BlockSize), _1{}) + ); + Tensor D = make_tensor( + make_gmem_ptr(params.Dptr), + make_shape(m, n), + LayoutRight{} + ); + + auto [m_coord, n_coord] = get_block_coord( + cdiv(m, size<0>(TileShape{})), + cdiv(n, size<1>(TileShape{})), + swizzle_dir, swizzle_size + ); + + int32_t blk_m_coord = m_coord / (BlockSize / TileM); + int32_t blk_n_coord = n_coord / (BlockSize / TileN); + + // local tile + auto gA = local_tile(A, TileShape{}, make_coord(m_coord, n_coord, _), Step<_1, X, _1>{}); + auto gB = local_tile(B, TileShape{}, make_coord(m_coord, n_coord, _), Step{}); + auto gD = local_tile(D, TileShape{}, make_coord(m_coord, n_coord, _), Step<_1, _1, X>{}); + + // 
shared memory + Tensor sA = make_tensor( + make_smem_ptr(shared_storage.storage_AB.smem_A.data()), + SmemLayoutA{} + ); + + Tensor sB = make_tensor( + make_smem_ptr(shared_storage.storage_AB.smem_B.data()), + SmemLayoutB{} + ); + + // register + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_slice(t_idx); + auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); + auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); + auto tDrC = thr_mma.partition_fragment_C(gD); // mma accumulator + auto tDrD = make_tensor_like(tDrC); // float accumulator + + if constexpr (FastInt2Float) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tDrC); ++i) + tDrC(i) = 0x4B400000; + } else { + clear(tDrC); + } + + clear(tDrD); + + + // global to shared copy + G2STiledCopyA g2s_tiled_copy_a; + auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(t_idx); + auto tAgA = g2s_thr_copy_a.partition_S(gA); + auto tAsA = g2s_thr_copy_a.partition_D(sA); + auto cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); + auto tAcA = g2s_thr_copy_a.partition_S(cA); + int const m_limit = m - TileM * m_coord; + int const n_limit = n - TileN * n_coord; + + G2STiledCopyB g2s_tiled_copy_b; + auto g2s_thr_copy_b = g2s_tiled_copy_b.get_slice(t_idx); + auto tBgB = g2s_thr_copy_b.partition_S(gB); + auto tBsB = g2s_thr_copy_b.partition_D(sB); + auto cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); + auto tBcB = g2s_thr_copy_a.partition_S(cB); + + + // shared to register copy + S2RTiledCopyA s2r_tiled_copy_a; + auto s2r_thr_copy_a = s2r_tiled_copy_a.get_slice(t_idx); + auto tCsA = s2r_thr_copy_a.partition_S(sA); + auto tCrA_view = s2r_thr_copy_a.retile_D(tCrA); + + S2RTiledCopyB s2r_tiled_copy_b; + auto s2r_thr_copy_b = s2r_tiled_copy_b.get_slice(t_idx); + auto tCsB = s2r_thr_copy_b.partition_S(sB); + auto tCrB_view = s2r_thr_copy_b.retile_D(tCrB); + + // pipeline status + int64_t g2s_a_tile = 0; + int64_t g2s_b_tile = 0; + int g2s_a_smem = 0; + int g2s_b_smem = 0; + + int g2s_tile_in_block = 0; + int g2s_block = 0; // b block idx + + int s2r_a_smem = 0; + int s2r_b_smem = 0; + int s2r_tile_in_block = 0; + + int mma_block_a = 0; + int mma_block_b = 0; + + int ntile = k / TileK; + // load scale and fallback + // we assume all ptrs are 128bit aligned + // auto smem_fallback_A = raw_pointer_cast(make_smem_ptr(shared_storage.storage_AB.smem_AF.data())); + // auto smem_scale_A = raw_pointer_cast(make_smem_ptr(shared_storage.storage_AB.smem_AS.data())); + // auto smem_scale_B = raw_pointer_cast(make_smem_ptr(shared_storage.storage_AB.smem_BS.data())); + __syncthreads(); + + + int32_t fallbackA_load = 0; + int32_t fallbackA_mma = 0; + + // copy first Stage - 1 tile + CUTLASS_PRAGMA_UNROLL + for (int i = 0, _i = min(Stage - 1, ntile); i < _i; ++i) { + if (g2s_b_tile < ntile) { + g2s_tile_in_block = (g2s_tile_in_block + 1) % NumTilePerBlock; + copy_AB(g2s_tiled_copy_a, tAgA, tAsA, tAcA, g2s_a_tile, g2s_a_smem, m_limit); + copy_AB(g2s_tiled_copy_b, tBgB, tBsB, tBcB, g2s_b_tile, g2s_b_smem, n_limit); + ++g2s_b_tile; + ++g2s_b_smem; + ++g2s_block; + g2s_a_tile = g2s_block * NumTilePerBlock; + ++g2s_a_smem; + } + cp_async_fence(); + } + + constexpr int nk = size<2>(tCrA); + float scale_a = AS(blk_m_coord, 0); + float scale_b = BS(blk_n_coord, 0); + + CUTLASS_PRAGMA_NO_UNROLL + for (int64_t mma_b_tile = 0; mma_b_tile < ntile; ++mma_b_tile) { + s2r_tile_in_block = (s2r_tile_in_block + 1) % NumTilePerBlock; + cp_async_wait(); + __syncthreads(); + + // do mma first + CUTLASS_PRAGMA_UNROLL + for (int ik = 0; ik < nk; ++ik) { + 
cute::copy(s2r_tiled_copy_a, tCsA(_, _, ik, s2r_a_smem), + tCrA_view(_, _, ik)); + cute::copy(s2r_tiled_copy_b, tCsB(_, _, ik, s2r_b_smem), + tCrB_view(_, _, ik)); + cute::gemm(tiled_mma, tDrC, tCrA(_, _, ik), tCrB(_, _, ik), tDrC); + } + + // a s2r increase anyway + s2r_a_smem = (s2r_a_smem + 1) % Stage; + + // get next s2r b tile int64_t + // end of a block + + // dequant first + dequant( + tDrC.data(), tDrD.data(), scale_a * scale_b + ); + + s2r_b_smem = (s2r_b_smem + 1) % Stage; + // b advance + ++mma_block_b; + if (mma_block_b < size<1>(BS)) scale_b = BS(blk_n_coord, mma_block_b); + mma_block_a = mma_block_b; + if (mma_block_a < size<1>(AS)) scale_a = AS(blk_m_coord, mma_block_a); + + // load next stage + if (g2s_b_tile < ntile) { + g2s_tile_in_block = (g2s_tile_in_block + 1) % NumTilePerBlock; + copy_AB(g2s_tiled_copy_a, tAgA, tAsA, tAcA, g2s_a_tile, g2s_a_smem, m_limit); + copy_AB(g2s_tiled_copy_b, tBgB, tBsB, tBcB, g2s_b_tile, g2s_b_smem, n_limit); + ++g2s_b_tile; + g2s_b_smem = (g2s_b_smem + 1) % Stage; + ++g2s_block; + g2s_a_tile = g2s_block * NumTilePerBlock; + g2s_a_smem = (g2s_a_smem + 1) % Stage; + } + cp_async_fence(); + } + + + // epilogue + + Tensor sD = make_tensor( + make_smem_ptr(shared_storage.storage_D.smem_D.data()), + SmemLayoutD{} + ); + + R2STiledCopyD r2s_tiled_copy_d; + auto r2s_thr_copy_d = r2s_tiled_copy_d.get_slice(t_idx); + auto tDrD_r2s = r2s_thr_copy_d.retile_S(tDrD); + auto tDsD_r2s = r2s_thr_copy_d.partition_D(sD); + + S2GCopyD s2g_tiled_copy_d; + auto s2g_thr_copy_d = s2g_tiled_copy_d.get_slice(t_idx); + auto tDsD_s2g = s2g_thr_copy_d.partition_S(sD); + auto tDgD_s2g = s2g_thr_copy_d.partition_D(gD); + Tensor cD = make_identity_tensor(make_shape(Int{}, Int{})); + auto tDcD_s2g = s2g_thr_copy_d.partition_D(cD); + + auto tDgD_s2gx = group_modes<1, 3>(tDgD_s2g); // (CPY_, CPY_MN) + auto tDrD_r2sx = group_modes<1, 3>(tDrD_r2s); // (CPY_, CPY_MN) + auto tDcD_s2gx = group_modes<1, 3>(tDcD_s2g); + + int32_t step = size<3>(tDsD_r2s); // pipe + CUTLASS_PRAGMA_UNROLL + for (int32_t i = 0; i < size<1>(tDrD_r2sx); i += step) { + CUTLASS_PRAGMA_UNROLL + for (int32_t j = 0; j < step; ++j) { + if constexpr (std::is_same::value) { + cute::copy(r2s_tiled_copy_d, tDrD_r2sx(_, i + j), tDsD_r2s(_, 0, 0, j)); + } else { + auto t = make_tensor_like(tDrD_r2sx(_, i + j)); + cute::copy(tDrD_r2sx(_, i + j), t); + cute::copy(r2s_tiled_copy_d, t, tDsD_r2s(_, 0, 0, j)); + } + } + + __syncthreads(); + + // shm -> global + if constexpr (IsEvenM && IsEvenN) { + CUTLASS_PRAGMA_UNROLL + for (int32_t j = 0; j < step; ++j) + cute::copy(s2g_tiled_copy_d, tDsD_s2g(_, 0, 0, j), tDgD_s2gx(_, i + j)); + } else if constexpr (IsEvenN) { + CUTLASS_PRAGMA_UNROLL + for (int32_t j = 0; j < step; ++j) { + if (get<0>(tDcD_s2gx(0, i + j)) < m_limit) + cute::copy(s2g_tiled_copy_d, tDsD_s2g(_, 0, 0, j), tDgD_s2gx(_, i + j)); + } + } else if constexpr (IsEvenM) { + CUTLASS_PRAGMA_UNROLL + for (int32_t j = 0; j < step; ++j) + if (get<1>(tDcD_s2gx(size<0>(tDsD_s2g) - 1, i + j)) < n_limit) { + cute::copy(s2g_tiled_copy_d, tDsD_s2g(_, 0, 0, j), tDgD_s2gx(_, i + j)); + } else { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<0>(tDsD_s2g); ++k) + if (get<1>(tDcD_s2gx(k, i + j)) < n_limit) + tDgD_s2gx(k, i + j) = tDsD_s2g(k, 0, 0, j); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int32_t j = 0; j < step; ++j) + if (get<0>(tDcD_s2gx(0, i + j)) < m_limit) { + if (get<1>(tDcD_s2gx(size<0>(tDsD_s2g) - 1, i + j)) < n_limit) { + cute::copy(s2g_tiled_copy_d, tDsD_s2g(_, 0, 0, j), tDgD_s2gx(_, i + j)); + } else { + 
for (int32_t k = 0; k < size<0>(tDsD_s2g); ++k) + if (get<1>(tDcD_s2gx(k, i + j)) < n_limit) + tDgD_s2gx(k, i + j) = tDsD_s2g(k, 0, 0, j); + } + } + } + __syncthreads(); + } + } + +}; + \ No newline at end of file diff --git a/turbodiffusion/ops/gemm/kernel_rocwmma.hpp b/turbodiffusion/ops/gemm/kernel_rocwmma.hpp new file mode 100644 index 0000000..ee06c67 --- /dev/null +++ b/turbodiffusion/ops/gemm/kernel_rocwmma.hpp @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + * + * rocWMMA-based GEMM kernel for AMD RDNA3 GPUs. + * This kernel performs int8 GEMM with per-block quantization scaling. + * + * Based on rocWMMA from https://github.com/ROCm/rocm-libraries/tree/develop/projects/rocwmma + */ + +#pragma once + +#include + +// Undefine the no-half-conversion macros that PyTorch sets +// rocWMMA needs these conversions to work properly +#ifdef __HIP_NO_HALF_OPERATORS__ +#undef __HIP_NO_HALF_OPERATORS__ +#endif +#ifdef __HIP_NO_HALF_CONVERSIONS__ +#undef __HIP_NO_HALF_CONVERSIONS__ +#endif + +#include +#include +#include + +#include "common/platform.hpp" + +using namespace rocwmma; + +// Helper for float to output type conversion +template +TURBO_DEVICE TURBO_INLINE T float_to_output(float val); + +template <> +TURBO_DEVICE TURBO_INLINE __half float_to_output<__half>(float val) { + return __float2half(val); +} + +template <> +TURBO_DEVICE TURBO_INLINE hip_bfloat16 float_to_output(float val) { + return hip_bfloat16(val); +} + +template <> +TURBO_DEVICE TURBO_INLINE float float_to_output(float val) { + return val; +} + +template <> +TURBO_DEVICE TURBO_INLINE int32_t float_to_output(float val) { + return static_cast(val); +} + +// RDNA3 (gfx11) specific parameters +// Wave size: 32, Block sizes: 16x16x16 +namespace rdna3 { + constexpr uint32_t ROCWMMA_M = 16u; + constexpr uint32_t ROCWMMA_N = 16u; + constexpr uint32_t ROCWMMA_K = 16u; + constexpr uint32_t WAVE_SIZE = 32u; + constexpr uint32_t QUANT_BLOCK = 128u; // Quantization block size +} + +template < + class OutputDtype_, + bool IsEvenM, + bool IsEvenN +> +struct GemmKernelRocWMMA { + using ElementA = int8_t; + using ElementB = int8_t; + using OutputDtype = OutputDtype_; + using AccumulatorDtype = int32_t; + using ComputeDtype = int32_t; // MMA accumulator type + + // Tile sizes + static constexpr int TileM = rdna3::ROCWMMA_M; + static constexpr int TileN = rdna3::ROCWMMA_N; + static constexpr int TileK = rdna3::ROCWMMA_K; + static constexpr int BlockSize = rdna3::QUANT_BLOCK; + static constexpr int WaveSize = rdna3::WAVE_SIZE; + + // Warp tile: how many MMA tiles each wave computes + static constexpr int WarpTileM = 2; // 2 tiles in M direction = 32 + static constexpr int WarpTileN = 2; // 2 tiles in N direction = 32 + + // Thread block configuration + static constexpr int TBlockX = 128; // 4 waves + static constexpr int TBlockY = 1; + static constexpr int NumWarps = TBlockX / WaveSize; // 4 waves + + // Macro tile computed by entire thread block + static constexpr int MacroTileM = NumWarps * WarpTileM * TileM; // 4 * 2 * 16 = 128 + static constexpr int MacroTileN 
= TBlockY * WarpTileN * TileN; // 1 * 2 * 16 = 32 + + // Fragment types - using row_major for A and col_major for B (NT layout) + using FragA = fragment; + using FragB = fragment; + using FragAcc = fragment; + + struct Params { + void const* Aptr; + void const* ASptr; + void const* Bptr; + void const* BSptr; + void* Dptr; + int64_t const m; + int64_t const n; + int64_t const k; + int const swizzle_dir; + int const swizzle_size; + }; + + using Arguments = Params; + + static constexpr int ThreadNum = TBlockX * TBlockY; + // Shared memory for storing fragments: each wave needs TileM*TileN*sizeof(float) per tile + // With WarpTileM=2, WarpTileN=2, and NumWarps=4 waves: + // Size = NumWarps * WarpTileM * WarpTileN * TileM * TileN * sizeof(float) + // = 4 * 2 * 2 * 16 * 16 * 4 = 16384 bytes + static constexpr int ShmSize = NumWarps * WarpTileM * WarpTileN * TileM * TileN * sizeof(float); + static constexpr int MaxThreadsPerBlock = ThreadNum; + static constexpr int MinBlocksPerMultiprocessor = 1; + + static bool can_implement(int64_t m, int64_t n, int64_t k) { + if (k % BlockSize != 0) return false; + if ((n * sizeof(OutputDtype)) % 16 != 0) return false; + return true; + } + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + TURBO_HOST_DEVICE + static int64_t cdiv(int64_t a, int64_t b) { + return (a + b - 1) / b; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + int64_t grid_m = cdiv(m, MacroTileM); + int64_t grid_n = cdiv(n, MacroTileN); + return dim3(grid_m * grid_n); + } + + TURBO_DEVICE + void operator()(Params const& params, char* smem_data) { + int64_t const m = params.m; + int64_t const n = params.n; + int64_t const k = params.k; + + // Wave and lane indices + int waveId = threadIdx.x / WaveSize; + int laneId = threadIdx.x % WaveSize; + + // Grid dimensions + int64_t grid_m = cdiv(m, MacroTileM); + int64_t grid_n = cdiv(n, MacroTileN); + + // Block coordinates (linear to 2D) + int64_t block_m = blockIdx.x % grid_m; + int64_t block_n = blockIdx.x / grid_m; + + // Base coordinates for this wave's output tiles + int64_t wave_m_base = block_m * MacroTileM + waveId * WarpTileM * TileM; + int64_t wave_n_base = block_n * MacroTileN; + + // Pointers + ElementA const* A = reinterpret_cast(params.Aptr); + ElementB const* B = reinterpret_cast(params.Bptr); + float const* AS = reinterpret_cast(params.ASptr); + float const* BS = reinterpret_cast(params.BSptr); + + // Number of quantization blocks in K dimension + int64_t num_quant_blocks_k = k / BlockSize; + + // Float accumulators for dequantized results + float floatAcc[WarpTileM][WarpTileN][FragAcc::num_elements]; + + // Initialize accumulators + TURBO_PRAGMA_UNROLL + for (int wm = 0; wm < WarpTileM; ++wm) { + TURBO_PRAGMA_UNROLL + for (int wn = 0; wn < WarpTileN; ++wn) { + TURBO_PRAGMA_UNROLL + for (int i = 0; i < FragAcc::num_elements; ++i) { + floatAcc[wm][wn][i] = 0.0f; + } + } + } + + // Process each quantization block + for (int64_t qb = 0; qb < num_quant_blocks_k; ++qb) { + int64_t k_start = qb * BlockSize; + int64_t k_end = k_start + BlockSize; + + // Integer accumulators for this quant block + FragAcc fragAcc[WarpTileM][WarpTileN]; + TURBO_PRAGMA_UNROLL + for (int wm = 0; wm < WarpTileM; ++wm) { + TURBO_PRAGMA_UNROLL + for (int wn = 0; wn < WarpTileN; ++wn) { + fill_fragment(fragAcc[wm][wn], static_cast(0)); + } + } + + // K-loop within quantization block + for (int64_t kk = k_start; kk < k_end; kk += TileK) { + // Load and compute for each tile in warp tile + TURBO_PRAGMA_UNROLL + for (int wm 
= 0; wm < WarpTileM; ++wm) { + int64_t tile_m = wave_m_base + wm * TileM; + + FragA fragA; + if (tile_m < m) { + // A is row-major: A[m, k] + load_matrix_sync(fragA, A + tile_m * k + kk, k); + } else { + fill_fragment(fragA, static_cast(0)); + } + + TURBO_PRAGMA_UNROLL + for (int wn = 0; wn < WarpTileN; ++wn) { + int64_t tile_n = wave_n_base + wn * TileN; + + FragB fragB; + if (tile_n < n) { + // B is stored as [N, K] in col-major (K changes fastest when reading B[n, :]) + load_matrix_sync(fragB, B + tile_n * k + kk, k); + } else { + fill_fragment(fragB, static_cast(0)); + } + + // Matrix multiply-accumulate + mma_sync(fragAcc[wm][wn], fragA, fragB, fragAcc[wm][wn]); + } + } + } + + // Dequantize this block's contribution + TURBO_PRAGMA_UNROLL + for (int wm = 0; wm < WarpTileM; ++wm) { + int64_t tile_m = wave_m_base + wm * TileM; + int64_t qblock_m = tile_m / BlockSize; + + // Get scale for A + float scale_a = 1.0f; + if (qblock_m < cdiv(m, BlockSize) && qb < num_quant_blocks_k) { + scale_a = AS[qblock_m * num_quant_blocks_k + qb]; + } + + TURBO_PRAGMA_UNROLL + for (int wn = 0; wn < WarpTileN; ++wn) { + int64_t tile_n = wave_n_base + wn * TileN; + int64_t qblock_n = tile_n / BlockSize; + + // Get scale for B + float scale_b = 1.0f; + if (qblock_n < cdiv(n, BlockSize) && qb < num_quant_blocks_k) { + scale_b = BS[qblock_n * num_quant_blocks_k + qb]; + } + + float scale = scale_a * scale_b; + + // Accumulate dequantized values + TURBO_PRAGMA_UNROLL + for (int i = 0; i < FragAcc::num_elements; ++i) { + floatAcc[wm][wn][i] += static_cast(fragAcc[wm][wn].x[i]) * scale; + } + } + } + } + + // Store final results using store_matrix_sync to shared memory temp buffer + // This ensures correct fragment layout interpretation + OutputDtype* D = reinterpret_cast(params.Dptr); + + // Each wave gets its own section of shared memory + float* smem_temp = reinterpret_cast(smem_data); + float* wave_smem = smem_temp + waveId * WarpTileM * WarpTileN * TileM * TileN; + + TURBO_PRAGMA_UNROLL + for (int wm = 0; wm < WarpTileM; ++wm) { + int64_t tile_m = wave_m_base + wm * TileM; + + TURBO_PRAGMA_UNROLL + for (int wn = 0; wn < WarpTileN; ++wn) { + int64_t tile_n = wave_n_base + wn * TileN; + + // Create a float fragment from the accumulated values + fragment fragFloat; + TURBO_PRAGMA_UNROLL + for (int i = 0; i < FragAcc::num_elements; ++i) { + fragFloat.x[i] = floatAcc[wm][wn][i]; + } + + // Store to wave's temp buffer using rocWMMA (row-major layout) + float* tile_buf = wave_smem + (wm * WarpTileN + wn) * TileM * TileN; + store_matrix_sync(tile_buf, fragFloat, TileN, mem_row_major); + + __syncthreads(); + + // Now read from tile_buf with linear indexing + for (int e = laneId; e < TileM * TileN; e += WaveSize) { + int local_row = e / TileN; + int local_col = e % TileN; + + int64_t global_row = tile_m + local_row; + int64_t global_col = tile_n + local_col; + + if (global_row < m && global_col < n) { + D[global_row * n + global_col] = float_to_output(tile_buf[e]); + } + } + + __syncthreads(); + } + } + } +}; diff --git a/turbodiffusion/ops/gemm/launch_hip.hpp b/turbodiffusion/ops/gemm/launch_hip.hpp new file mode 100644 index 0000000..ab4f332 --- /dev/null +++ b/turbodiffusion/ops/gemm/launch_hip.hpp @@ -0,0 +1,66 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +/* + * Copyright (c) 2025 by TurboDiffusion team. 
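For readers of the rocWMMA kernel above: each K block is multiplied in INT8/INT32, and the integer partial sum is immediately rescaled by the product of the A and B block scales before being folded into the float accumulator. Below is a minimal host-side reference sketch of that contraction (the helper name is hypothetical; the 128-element block size and the `[rows/128, k/128]` scale layout are assumptions read off the kernel's indexing, not part of this diff):

```cpp
#include <cstdint>
#include <vector>

// Reference: D[i][j] = sum over K-blocks qb of
//   AS[i/128][qb] * BS[j/128][qb] * (int32 dot product of the 128-wide slice).
// A is [m,k] row-major int8, B is [n,k] with K contiguous (NT layout).
void reference_blockwise_int8_gemm(
    const std::vector<int8_t>& A,   // [m, k]
    const std::vector<float>&  AS,  // [ceil(m/128), k/128]
    const std::vector<int8_t>& B,   // [n, k]
    const std::vector<float>&  BS,  // [ceil(n/128), k/128]
    std::vector<float>&        D,   // [m, n]
    int64_t m, int64_t n, int64_t k)
{
    const int64_t BlockSize = 128;
    const int64_t kb = k / BlockSize;           // number of K quant blocks
    for (int64_t i = 0; i < m; ++i) {
        for (int64_t j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (int64_t qb = 0; qb < kb; ++qb) {
                int32_t iacc = 0;               // integer accumulator per block
                for (int64_t kk = qb * BlockSize; kk < (qb + 1) * BlockSize; ++kk)
                    iacc += int32_t(A[i * k + kk]) * int32_t(B[j * k + kk]);
                const float scale = AS[(i / BlockSize) * kb + qb] *
                                    BS[(j / BlockSize) * kb + qb];
                acc += float(iacc) * scale;     // dequantize per K block
            }
            D[i * n + j] = acc;
        }
    }
}
```

Dequantizing per K block rather than once at the end keeps the INT32 accumulators small; this mirrors the `floatAcc[...] += fragAcc * scale_a * scale_b` step in the kernel.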
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + */ + +#pragma once + +#include "common/common_hip.hpp" +#include "common/launch_hip.hpp" +#include "gemm/kernel_hip.hpp" + + +template +bool int8_gemm_( + int8_t const *Aptr, float const *ASptr, + int8_t const *Bptr, float const *BSptr, + OutputDtype* Dptr, int64_t m, int64_t n, int64_t k, + int swizzle_dir = 1, int swizzle_size_log = 0, + hipStream_t stream = nullptr +) { + BOOL_SWITCH(m % 128 == 0, IsEvenM, [&] { + BOOL_SWITCH(n % 128 == 0, IsEvenN, [&] { + using Kernel = GemmKernel; + if (!Kernel::can_implement(m, n, k)) + return false; + using Args = typename Kernel::Arguments; + Args args { + (void*)Aptr, (void*)ASptr, + (void*)Bptr, (void*)BSptr, (void*)Dptr, + m, n, k, swizzle_dir, + swizzle_size_log + }; + + auto params = Kernel::to_underlying_arguments(args); + + static constexpr size_t ShmSize = Kernel::ShmSize; + dim3 grid_shape = Kernel::get_grid_size(m, n); + dim3 block_shape = dim3(Kernel::ThreadNum); + auto func = device_kernel; + if (ShmSize >= 48 * 1024) { + hipFuncSetAttribute( + func, + hipFuncAttributeMaxDynamicSharedMemorySize, + ShmSize + ); + } + hipLaunchKernelGGL(( func), dim3(grid_shape), dim3(block_shape), ShmSize, stream, + params + ); + return true; + }); + }); + return true; +} diff --git a/turbodiffusion/ops/gemm/launch_rocwmma.hpp b/turbodiffusion/ops/gemm/launch_rocwmma.hpp new file mode 100644 index 0000000..9b975a7 --- /dev/null +++ b/turbodiffusion/ops/gemm/launch_rocwmma.hpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * rocWMMA GEMM kernel launch wrapper for AMD GPUs. + */ + +#pragma once + +#include +#include "common/platform.hpp" +#include "gemm/kernel_rocwmma.hpp" + +#define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ +[&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return (__VA_ARGS__)(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return (__VA_ARGS__)(); \ + } \ +}() + +template +__global__ void rocwmma_gemm_kernel( + typename Kernel::Params const params +) { + extern __shared__ char smem[]; + Kernel op; + op(params, smem); +} + +template +bool int8_gemm_rocwmma( + int8_t const* Aptr, float const* ASptr, + int8_t const* Bptr, float const* BSptr, + OutputDtype* Dptr, int64_t m, int64_t n, int64_t k, + int swizzle_dir = 1, int swizzle_size_log = 0, + hipStream_t stream = nullptr +) { + BOOL_SWITCH(m % 128 == 0, IsEvenM, [&] { + BOOL_SWITCH(n % 128 == 0, IsEvenN, [&] { + using Kernel = GemmKernelRocWMMA; + + if (!Kernel::can_implement(m, n, k)) { + return false; + } + + using Args = typename Kernel::Arguments; + Args args{ + (void*)Aptr, (void*)ASptr, + (void*)Bptr, (void*)BSptr, (void*)Dptr, + m, n, k, swizzle_dir, + swizzle_size_log + }; + + auto params = Kernel::to_underlying_arguments(args); + + static constexpr size_t ShmSize = Kernel::ShmSize; + dim3 grid_shape = Kernel::get_grid_size(m, n); + dim3 block_shape = dim3(Kernel::ThreadNum); + + auto func = rocwmma_gemm_kernel; + if (ShmSize >= 48 * 1024) { + hipFuncSetAttribute( + func, + hipFuncAttributeMaxDynamicSharedMemorySize, + ShmSize + ); + } + + hipLaunchKernelGGL(func, grid_shape, block_shape, ShmSize, stream, params); + return true; + }); + }); + return true; +} + diff --git a/turbodiffusion/ops/gemm/utils_hip.hpp b/turbodiffusion/ops/gemm/utils_hip.hpp new file mode 100644 index 0000000..76356b4 --- /dev/null +++ b/turbodiffusion/ops/gemm/utils_hip.hpp @@ -0,0 +1,130 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Copyright (c) 2025 by TurboDiffusion team. 
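A sketch of how the `int8_gemm_rocwmma` launcher above might be driven from host code. This is an illustration only: the function/variable names are hypothetical, buffer filling is elided, m/n/k are assumed to be multiples of 128 for brevity, and the explicit `<float>` instantiation assumes the launcher's (stripped) template parameter list is just the output dtype:

```cpp
#include <hip/hip_runtime.h>
#include <cstdint>

// Buffer shapes follow the per-128-block quantization scheme used elsewhere in
// this diff: A [m,k] int8, AS [m/128, k/128] float, B [n,k] int8,
// BS [n/128, k/128] float, D [m,n] float. Error handling is omitted.
void run_int8_gemm(int64_t m, int64_t n, int64_t k, hipStream_t stream) {
    int8_t *dA = nullptr, *dB = nullptr;
    float  *dAS = nullptr, *dBS = nullptr, *dD = nullptr;

    hipMalloc(&dA,  m * k * sizeof(int8_t));
    hipMalloc(&dB,  n * k * sizeof(int8_t));
    hipMalloc(&dAS, (m / 128) * (k / 128) * sizeof(float));
    hipMalloc(&dBS, (n / 128) * (k / 128) * sizeof(float));
    hipMalloc(&dD,  m * n * sizeof(float));

    // ... fill dA/dAS/dB/dBS with quantized data (e.g. via the quant kernel) ...

    bool ok = int8_gemm_rocwmma<float>(dA, dAS, dB, dBS, dD, m, n, k,
                                       /*swizzle_dir=*/1, /*swizzle_size_log=*/0,
                                       stream);
    (void)ok;
    hipStreamSynchronize(stream);

    hipFree(dA); hipFree(dB); hipFree(dAS); hipFree(dBS); hipFree(dD);
}
```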
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + */ + +#pragma once + +#include +#include "cute/tensor_hip.hpp" + +template < + bool IsEven, + class TiledCopy, + class SrcTensor, + class DstTensor, + class PrdTensor +> +CUTLASS_DEVICE void +copy_AB( + TiledCopy const& _copy, + SrcTensor const &S, + DstTensor &D, + PrdTensor const &ID, + const int64_t &i_read, + const int64_t &i_write, + const int64_t &limit +) { + using namespace cute; + if constexpr (IsEven) + cute::copy(_copy, S(_, _, _, i_read), D(_, _, _, i_write)); + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(ID); ++i) + if (get<0>(ID(0, i, 0)) < limit) + cute::copy(_copy, S(_, i, _, i_read), D(_, i, _, i_write)); + } +} + +template +CUTLASS_DEVICE void copy_async( + void const* gmem_src, + void* smem_dst +) { + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_dst));; + asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_src), + "n"(N)); +} + +template +CUTLASS_DEVICE void copy_aligned(const void* src, void* dst, size_t N, int64_t thread_idx) { + static constexpr int NumElementPerLoad = sizeof(LoadType) / sizeof(T); + for (int64_t i = thread_idx * NumElementPerLoad; i < N; i += NumElementPerLoad * NumThreads) { + if (i + NumElementPerLoad <= N) { + copy_async( + (void*)((T*)src + i), + (void*)((T*)dst + i) + ); + } else { + for (int64_t j = 0; j < N - i; ++j) + copy_async( + (void*)((T*)src + i + j), + (void*)((T*)dst + i + j) + ); + } + } +} + +template +CUTLASS_DEVICE void g2s_vector_copy(const void* src, void* dst, size_t N, int64_t thread_idx) { + + uintptr_t src_addr = reinterpret_cast(src); + + if (src_addr % 16 == 0) { + copy_aligned(src, dst, N, thread_idx); + } else if (src_addr % 8 == 0) { + copy_aligned(src, dst, N, thread_idx); + } else if (src_addr % 4 == 0) { + copy_aligned(src, dst, N, thread_idx); + } else { + assert(0); + } + if constexpr (Commit) { + asm volatile("cp.async.commit_group;\n" ::); + } + if constexpr (Wait) { + asm volatile("cp.async.wait_all;\n" ::); + } +} + +template +CUTLASS_DEVICE +static void dequant( + T* mma_accum_ptr, + float* float_accum_ptr, + float scale +) { + static int const ic = 0x4B400000; + if constexpr (FastInt2Float && std::is_same_v) { + CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < N; ++i) { + *(float_accum_ptr + i) += (__int_as_float(*(mma_accum_ptr + i)) - __int_as_float(ic)) * scale; + *(mma_accum_ptr + i) = ic; + } + } else if constexpr (std::is_same_v) { + CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < N; ++i) { + *(float_accum_ptr + i) += __int2float_rn(*(mma_accum_ptr + i)) * scale; + *(mma_accum_ptr + i) = 0; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < N; ++i) { + *(float_accum_ptr + i) += (*(mma_accum_ptr + i)) * scale; + *(mma_accum_ptr + i) = 0; + } + } +} \ No newline at end of file diff --git a/turbodiffusion/ops/norm/layernorm.hip b/turbodiffusion/ops/norm/layernorm.hip new file mode 100644 index 0000000..83c78d9 --- /dev/null +++ b/turbodiffusion/ops/norm/layernorm.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. 
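The `dequant` helper above optionally converts INT32 MMA accumulators to float with a bit trick: the accumulator is pre-seeded with `0x4B400000` (the bit pattern of 12582912.0f = 1.5 * 2^23), so after integer accumulation the bits, reinterpreted as a float, equal 12582912 + x, and subtracting 12582912.0f recovers float(x) without an int-to-float instruction. This only holds while the accumulated value stays small (roughly |x| < 2^22). A small standalone host-side sketch of the identity, using `memcpy` in place of the device's `__int_as_float`:

```cpp
#include <cstdint>
#include <cstring>
#include <cassert>

// Illustrates the FastInt2Float path in dequant(): seed the int accumulator
// with 0x4B400000, accumulate, then reinterpret and subtract 12582912.0f.
// Valid only for roughly |x| < 2^22, where float spacing is exactly 1.
float fast_int_to_float(int32_t x) {
    const int32_t ic = 0x4B400000;          // bits of 12582912.0f
    int32_t biased = ic + x;                // what the MMA accumulator would hold
    float f;
    std::memcpy(&f, &biased, sizeof(f));    // __int_as_float on the device
    return f - 12582912.0f;
}

int main() {
    assert(fast_int_to_float(0)     == 0.0f);
    assert(fast_int_to_float(12345) == 12345.0f);
    assert(fast_int_to_float(-7)    == -7.0f);
    return 0;
}
```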
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * HIP/ROCm LayerNorm kernel. + */ + +#include +#include +#include +#include +#include +#include + +#include "common/common_hip.hpp" +#include "norm/layernorm_hip.hpp" + +auto layer_norm( + at::Tensor const Input, + float eps, + std::optional W, + std::optional const B, + std::optional Output +) { + using ElementIn = float; + using ElementOut = float; + using ElementWeight = float; + + int64_t const m = Input.size(0); + int64_t const n = Input.size(1); + torch::Device const input_device = Input.device(); + + if (!Output.has_value()) { + Output.emplace( + torch::empty( + {m, n}, + torch::TensorOptions().device(input_device).dtype(torch::kFloat32) + ) + ); + } + + void *Iptr = Input.data_ptr(); + void *Wptr = W.has_value() ? W.value().data_ptr() : nullptr; + void *Bptr = B.has_value() ? B.value().data_ptr() : nullptr; + void *Optr = Output.value().data_ptr(); + + BOOL_SWITCH(B.has_value(), BIAS, [&]{ + BOOL_SWITCH(W.has_value(), AFFINE, [&]{ + CONFIG_SWITCH(n, [&]{ + layernorm< + ElementIn, ElementOut, ElementWeight, + AFFINE, BIAS, + MAX_HIDDEN_SIZE, NUM_THR_PER_CTA + >( + Iptr, Wptr, Bptr, + Optr, eps, m, n, + at::hip::getCurrentHIPStream().stream() + ); + }); + }); + }); + + return Output; +} + +void register_layer_norm(pybind11::module_ &m) { + m.def("layer_norm_cuda", &layer_norm); +} + diff --git a/turbodiffusion/ops/norm/layernorm_hip.hpp b/turbodiffusion/ops/norm/layernorm_hip.hpp new file mode 100644 index 0000000..5c5e0f5 --- /dev/null +++ b/turbodiffusion/ops/norm/layernorm_hip.hpp @@ -0,0 +1,221 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +#pragma once + +#include "common/common_hip.hpp" +#include "common/load.hpp" +#include "common/store_hip.hpp" +#include "common/launch_hip.hpp" + +// Helper for output type conversion +namespace turbo_layernorm { +template +__device__ __forceinline__ T from_float(float val) { + return static_cast(val); +} +template <> +__device__ __forceinline__ __half from_float<__half>(float val) { + return __float2half(val); +} +template <> +__device__ __forceinline__ hip_bfloat16 from_float(float val) { + return hip_bfloat16(val); +} +} // namespace turbo_layernorm + + +template < + class InputDtype_, + class OutputDtype_, + class WeightDtype_, + bool Affine_, + bool Bias_, + int MaxHiddenSize_, + int NumThrPerCta_, + bool IsEven +> +class LayerNorm { +public: + using InputDtype = InputDtype_; + using OutputDtype = OutputDtype_; + using WeightDtype = WeightDtype_; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int MaxHiddenSize = MaxHiddenSize_; + static constexpr bool Affine = Affine_; + static constexpr bool Bias = Bias_; + + static constexpr size_t ShmSize = 32; + static constexpr int NumElementPerThread = MaxHiddenSize / NumThrPerCta; + + static_assert(MaxHiddenSize % NumThrPerCta == 0); + + struct Params { + void const *Iptr; + void const *Wptr; + void const *Bptr; + void *Optr; + float eps; + int64_t m; + int64_t n; + }; + + using Arguments = Params; + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3(m); + } + + static dim3 get_cta_size(int64_t m, int64_t n) { + return dim3(NumThrPerCta, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char *shared_data) { + int const blk_m = blockIdx.x; + int const blk_n = 1; + int tidx = threadIdx.x; + float x[NumElementPerThread]; + + // load + 
Loader loader; + loader.load(params.Iptr, x, params.m, params.n, blk_m, 0, tidx); + + // mean reduction + float u = _reduce_sum(x, shared_data) / params.n; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + x[i] -= u; + + __syncthreads(); + // var reduction + float v = sqrtf(_reduce_square(x, shared_data) / params.n + params.eps); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + x[i] /= v; + + if constexpr (Affine) { + // load weight + Loader weight_loader; + float w[NumElementPerThread]; + weight_loader.load(params.Wptr, w, 1, params.n, 0, 0, tidx); + if constexpr (Bias) { + float b[NumElementPerThread]; + weight_loader.load(params.Bptr, b, 1, params.n, 0, 0, tidx); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + x[i] = x[i] * w[i] + b[i]; + } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + x[i] = x[i] * w[i]; + } + } + + // save y + { + Saver saver; + if constexpr (std::is_same_v) { + saver.store(params.Optr, nullptr, x, 0, params.m, params.n, blk_m, 0, tidx); + } else { + OutputDtype tmp[NumElementPerThread]; + for (int i = 0; i < NumElementPerThread; ++i) + tmp[i] = turbo_layernorm::from_float(x[i]); + saver.store(params.Optr, nullptr, tmp, 0, params.m, params.n, blk_m, 0, tidx); + } + } + + } + +private: + CUTLASS_DEVICE + float _reduce_square(float *reg, char *shared_data) { + // thread + float sum_square = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + sum_square += reg[i] * reg[i]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 16; i >= 1; i >>= 1) { + sum_square += __shfl_down(sum_square, i, 32); + } + if (threadIdx.x == 0) { + *(float*)shared_data = 0; + } + __syncthreads(); + + if (threadIdx.x % 32 == 0) { + atomicAdd((float*)shared_data, sum_square); + } + + __syncthreads(); + sum_square = *(float*)shared_data; + return sum_square; + } + + CUTLASS_DEVICE + float _reduce_sum(float *reg, char *shared_data) { + // thread + float sum = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + sum += reg[i]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 16; i >= 1; i >>= 1) { + sum += __shfl_down(sum, i, 32); + } + if (threadIdx.x == 0) { + *(float*)shared_data = 0; + } + __syncthreads(); + + if (threadIdx.x % 32 == 0) { + atomicAdd((float*)shared_data, sum); + } + + __syncthreads(); + sum = *(float*)shared_data; + return sum; + } +}; + +template < + class InputDtype, + class OutputDtype, + class WeightDtype, + bool Affine, + bool Bias, + int MaxHiddenSize, + int NumThrPerCta +> +bool layernorm( + void const *Iptr, void const *Wptr, void const *Bptr, + void *Optr, float eps, int64_t m, int64_t n, + hipStream_t stream = nullptr +) { + BOOL_SWITCH(n % MaxHiddenSize == 0, IsEven, [&] { + using Kernel = LayerNorm< + InputDtype, OutputDtype, WeightDtype, + Affine, Bias, + MaxHiddenSize, NumThrPerCta, + IsEven>; + using Arguments = typename Kernel::Arguments; + Arguments args = { + Iptr, Wptr, Bptr, Optr, + eps, m, n + }; + auto params = Kernel::to_underlying_arguments(args); + auto grid_shape = Kernel::get_grid_size(m, n); + auto cta_shape = Kernel::get_cta_size(m, n); + static constexpr size_t ShmSize = Kernel::ShmSize; + launch_kernel(params, grid_shape, cta_shape, ShmSize, stream); + }); + return true; +} \ No newline at end of file diff --git a/turbodiffusion/ops/norm/norm_rocm.hpp b/turbodiffusion/ops/norm/norm_rocm.hpp new file mode 100644 index 0000000..1dd636c --- /dev/null +++ b/turbodiffusion/ops/norm/norm_rocm.hpp @@ -0,0 +1,428 
@@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Normalization kernels for AMD GPUs using HIP. + */ + +#pragma once + +#include +#include +#include "common/platform.hpp" + +TURBO_HOST_DEVICE inline int64_t cdiv_norm(int64_t a, int64_t b) { + return (a + b - 1) / b; +} + +#define MIN_NORM(a, b) ((a) > (b) ? (b) : (a)) +#define MAX_NORM(a, b) ((a) > (b) ? (a) : (b)) + +#define BOOL_SWITCH_NORM(COND, CONST_NAME, ...) \ +[&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return (__VA_ARGS__)(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return (__VA_ARGS__)(); \ + } \ +}() + +// RMSNorm Kernel +template < + class InputDtype_, + class OutputDtype_, + class WeightDtype_, + int MaxHiddenSize_, + int NumThrPerCta_, + bool IsEven +> +class RMSNormHIP { +public: + using InputDtype = InputDtype_; + using OutputDtype = OutputDtype_; + using WeightDtype = WeightDtype_; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int MaxHiddenSize = MaxHiddenSize_; + static constexpr size_t ShmSize = 32; + static constexpr int NumElementPerThread = MaxHiddenSize / NumThrPerCta; + + static_assert(MaxHiddenSize % NumThrPerCta == 0); + + struct Params { + void const* Iptr; + void const* Wptr; + void* Optr; + float eps; + int64_t m; + int64_t n; + }; + + using Arguments = Params; + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3(m); + } + + static dim3 get_cta_size(int64_t m, int64_t n) { + return dim3(NumThrPerCta, 1, 1); + } + + TURBO_DEVICE + void operator()(Params const& params, char* shared_data) { + int blk_m = blockIdx.x; + int tidx = threadIdx.x; + float x[NumElementPerThread]; + float w[NumElementPerThread]; + + // Load input + load_input(params.Iptr, x, params.m, params.n, blk_m, tidx); + + // RMS reduction + float rms = sqrtf(reduce_square(x, shared_data) / params.n + params.eps); + + // Load weight + load_weight(params.Wptr, w, params.n, tidx); + + // Normalize + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + x[i] = w[i] * x[i] / rms; + } + + // Store output + store_output(params.Optr, x, params.m, params.n, blk_m, tidx); + } + +private: + TURBO_DEVICE + void load_input(void const* input_ptr, float* reg, int64_t m, int64_t n, int blk_m, int tidx) { + InputDtype const* input = reinterpret_cast(input_ptr); + int64_t offset = blk_m * n + tidx * NumElementPerThread; + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEven || (tidx * NumElementPerThread + i) < n) { + reg[i] = static_cast(input[offset + i]); + } else { + reg[i] = 0.0f; + } + } + } + + TURBO_DEVICE + void load_weight(void const* weight_ptr, float* reg, int64_t n, int tidx) { + if (weight_ptr == nullptr) { + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + reg[i] = 1.0f; + } + return; + } + + WeightDtype const* weight = reinterpret_cast(weight_ptr); + int64_t offset = tidx * NumElementPerThread; + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEven || (tidx * NumElementPerThread + i) < n) { + reg[i] = static_cast(weight[offset + i]); + } else { + reg[i] = 1.0f; + } + } + } + + TURBO_DEVICE + void store_output(void* output_ptr, float* reg, int64_t m, int64_t n, int blk_m, int tidx) { + OutputDtype* output = reinterpret_cast(output_ptr); + int64_t offset = blk_m * n + tidx * 
NumElementPerThread; + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEven || (tidx * NumElementPerThread + i) < n) { + output[offset + i] = static_cast(reg[i]); + } + } + } + + TURBO_DEVICE + float reduce_square(float* reg, char* shared_data) { + float sum_square = 0; + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + sum_square += reg[i] * reg[i]; + } + + TURBO_PRAGMA_UNROLL + for (int i = 16; i >= 1; i >>= 1) { + sum_square += __shfl_down(sum_square, i, 32); + } + + if (threadIdx.x == 0) { + *(float*)shared_data = 0; + } + __syncthreads(); + + if (threadIdx.x % 32 == 0) { + atomicAdd((float*)shared_data, sum_square); + } + + __syncthreads(); + sum_square = *(float*)shared_data; + return sum_square; + } +}; + +// LayerNorm Kernel +template < + class InputDtype_, + class OutputDtype_, + class WeightDtype_, + bool Affine, + bool Bias, + int MaxHiddenSize_, + int NumThrPerCta_, + bool IsEven +> +class LayerNormHIP { +public: + using InputDtype = InputDtype_; + using OutputDtype = OutputDtype_; + using WeightDtype = WeightDtype_; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int MaxHiddenSize = MaxHiddenSize_; + static constexpr size_t ShmSize = 64; + static constexpr int NumElementPerThread = MaxHiddenSize / NumThrPerCta; + + static_assert(MaxHiddenSize % NumThrPerCta == 0); + + struct Params { + void const* Iptr; + void const* Wptr; + void const* Bptr; + void* Optr; + float eps; + int64_t m; + int64_t n; + }; + + using Arguments = Params; + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3(m); + } + + static dim3 get_cta_size(int64_t m, int64_t n) { + return dim3(NumThrPerCta, 1, 1); + } + + TURBO_DEVICE + void operator()(Params const& params, char* shared_data) { + int blk_m = blockIdx.x; + int tidx = threadIdx.x; + float x[NumElementPerThread]; + float w[NumElementPerThread]; + float b[NumElementPerThread]; + + // Load input + load_input(params.Iptr, x, params.m, params.n, blk_m, tidx); + + // Compute mean and variance + float mean, var; + reduce_mean_var(x, shared_data, params.n, mean, var); + float rstd = rsqrtf(var + params.eps); + + // Load weight and bias + if constexpr (Affine) { + load_weight(params.Wptr, w, params.n, tidx); + } + if constexpr (Bias) { + load_weight(params.Bptr, b, params.n, tidx); + } + + // Normalize + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + float val = (x[i] - mean) * rstd; + if constexpr (Affine) { + val *= w[i]; + } + if constexpr (Bias) { + val += b[i]; + } + x[i] = val; + } + + // Store output + store_output(params.Optr, x, params.m, params.n, blk_m, tidx); + } + +private: + TURBO_DEVICE + void load_input(void const* input_ptr, float* reg, int64_t m, int64_t n, int blk_m, int tidx) { + InputDtype const* input = reinterpret_cast(input_ptr); + int64_t offset = blk_m * n + tidx * NumElementPerThread; + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEven || (tidx * NumElementPerThread + i) < n) { + reg[i] = static_cast(input[offset + i]); + } else { + reg[i] = 0.0f; + } + } + } + + TURBO_DEVICE + void load_weight(void const* weight_ptr, float* reg, int64_t n, int tidx) { + if (weight_ptr == nullptr) { + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + reg[i] = 1.0f; + } + return; + } + + WeightDtype const* weight = reinterpret_cast(weight_ptr); + int64_t offset = tidx * NumElementPerThread; + + 
TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEven || (tidx * NumElementPerThread + i) < n) { + reg[i] = static_cast(weight[offset + i]); + } else { + reg[i] = 0.0f; + } + } + } + + TURBO_DEVICE + void store_output(void* output_ptr, float* reg, int64_t m, int64_t n, int blk_m, int tidx) { + OutputDtype* output = reinterpret_cast(output_ptr); + int64_t offset = blk_m * n + tidx * NumElementPerThread; + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEven || (tidx * NumElementPerThread + i) < n) { + output[offset + i] = static_cast(reg[i]); + } + } + } + + TURBO_DEVICE + void reduce_mean_var(float* reg, char* shared_data, int64_t n, float& mean, float& var) { + float sum = 0; + float sum_sq = 0; + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + sum += reg[i]; + sum_sq += reg[i] * reg[i]; + } + + // Warp reduction + TURBO_PRAGMA_UNROLL + for (int i = 16; i >= 1; i >>= 1) { + sum += __shfl_down(sum, i, 32); + sum_sq += __shfl_down(sum_sq, i, 32); + } + + float* smem = (float*)shared_data; + if (threadIdx.x == 0) { + smem[0] = 0; + smem[1] = 0; + } + __syncthreads(); + + if (threadIdx.x % 32 == 0) { + atomicAdd(&smem[0], sum); + atomicAdd(&smem[1], sum_sq); + } + + __syncthreads(); + + sum = smem[0]; + sum_sq = smem[1]; + + mean = sum / n; + var = sum_sq / n - mean * mean; + } +}; + +// Kernel launchers +template +__global__ void norm_kernel_hip(typename Kernel::Params const params) { + extern __shared__ char smem[]; + Kernel op; + op(params, smem); +} + +template < + class InputDtype, + class OutputDtype, + class WeightDtype, + int MaxHiddenSize, + int NumThrPerCta +> +bool rmsnorm_hip( + void const* Iptr, void const* Wptr, + void* Optr, float eps, + int64_t m, int64_t n, + hipStream_t stream = nullptr +) { + BOOL_SWITCH_NORM(n % MaxHiddenSize == 0, IsEven, [&] { + using Kernel = RMSNormHIP; + using Arguments = typename Kernel::Arguments; + + Arguments args = {Iptr, Wptr, Optr, eps, m, n}; + auto params = Kernel::to_underlying_arguments(args); + auto grid_shape = Kernel::get_grid_size(m, n); + auto cta_shape = Kernel::get_cta_size(m, n); + static constexpr size_t ShmSize = Kernel::ShmSize; + + hipLaunchKernelGGL(norm_kernel_hip, grid_shape, cta_shape, ShmSize, stream, params); + }); + return true; +} + +template < + class InputDtype, + class OutputDtype, + class WeightDtype, + bool Affine, + bool Bias, + int MaxHiddenSize, + int NumThrPerCta +> +bool layernorm_hip( + void const* Iptr, void const* Wptr, void const* Bptr, + void* Optr, float eps, + int64_t m, int64_t n, + hipStream_t stream = nullptr +) { + BOOL_SWITCH_NORM(n % MaxHiddenSize == 0, IsEven, [&] { + using Kernel = LayerNormHIP; + using Arguments = typename Kernel::Arguments; + + Arguments args = {Iptr, Wptr, Bptr, Optr, eps, m, n}; + auto params = Kernel::to_underlying_arguments(args); + auto grid_shape = Kernel::get_grid_size(m, n); + auto cta_shape = Kernel::get_cta_size(m, n); + static constexpr size_t ShmSize = Kernel::ShmSize; + + hipLaunchKernelGGL(norm_kernel_hip, grid_shape, cta_shape, ShmSize, stream, params); + }); + return true; +} + diff --git a/turbodiffusion/ops/norm/rmsnorm.hip b/turbodiffusion/ops/norm/rmsnorm.hip new file mode 100644 index 0000000..fc8af49 --- /dev/null +++ b/turbodiffusion/ops/norm/rmsnorm.hip @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * HIP/ROCm RMSNorm kernel. 
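Both HIP norm kernels above reduce sums with warp shuffles plus a shared-memory `atomicAdd`, but the arithmetic itself is the standard one: LayerNorm derives the variance as E[x^2] - mean^2, RMSNorm skips the mean entirely. A compact host reference of the same per-row math (helper names are hypothetical; double accumulation here is only for a tighter reference, the kernels accumulate in float):

```cpp
#include <cmath>
#include <vector>

// Matches LayerNormHIP for one row: var = E[x^2] - mean^2, y = (x - mean)*rstd*w + b.
void layer_norm_row(const std::vector<float>& x, const std::vector<float>& w,
                    const std::vector<float>& b, std::vector<float>& y, float eps) {
    const size_t n = x.size();
    double sum = 0.0, sum_sq = 0.0;
    for (float v : x) { sum += v; sum_sq += double(v) * v; }
    const float mean = float(sum / n);
    const float var  = float(sum_sq / n) - mean * mean;
    const float rstd = 1.0f / std::sqrt(var + eps);
    for (size_t i = 0; i < n; ++i)
        y[i] = (x[i] - mean) * rstd * w[i] + b[i];
}

// Matches RMSNormHIP for one row: y = w * x / sqrt(mean(x^2) + eps).
void rms_norm_row(const std::vector<float>& x, const std::vector<float>& w,
                  std::vector<float>& y, float eps) {
    const size_t n = x.size();
    double sum_sq = 0.0;
    for (float v : x) sum_sq += double(v) * v;
    const float rms = std::sqrt(float(sum_sq / n) + eps);
    for (size_t i = 0; i < n; ++i)
        y[i] = w[i] * x[i] / rms;
}
```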
+ */ + +#include +#include +#include +#include +#include +#include + +#include "common/common_hip.hpp" +#include "norm/rmsnorm_hip.hpp" + +auto rms_norm( + at::Tensor const& Input, + float eps, + const std::optional& Weight, + std::optional& Output +) { + using ElementIn = float; + using ElementOut = float; + using ElementWeight = float; + + int64_t const m = Input.size(0); + int64_t const n = Input.size(1); + torch::Device const input_device = Input.device(); + + if (!Output.has_value()) { + Output.emplace( + torch::empty( + {m, n}, + torch::TensorOptions().device(input_device).dtype(torch::kFloat32) + ) + ); + } + + void *Iptr = Input.data_ptr(); + void *Wptr = Weight.has_value() ? Weight.value().data_ptr() : nullptr; + void *Optr = Output.value().data_ptr(); + + CONFIG_SWITCH(n, [&]{ + rmsnorm< + ElementIn, ElementOut, ElementWeight, + MAX_HIDDEN_SIZE, NUM_THR_PER_CTA + >( + Iptr, Wptr, + Optr, + eps, m, n, + at::hip::getCurrentHIPStream().stream() + ); + }); + + return Output; +} + +void register_rms_norm(pybind11::module_ &m) { + m.def("rms_norm_cuda", &rms_norm); +} + diff --git a/turbodiffusion/ops/norm/rmsnorm_hip.hpp b/turbodiffusion/ops/norm/rmsnorm_hip.hpp new file mode 100644 index 0000000..b4dbda3 --- /dev/null +++ b/turbodiffusion/ops/norm/rmsnorm_hip.hpp @@ -0,0 +1,166 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +#pragma once + +#include "common/common_hip.hpp" +#include "common/load.hpp" +#include "common/store_hip.hpp" +#include "common/launch_hip.hpp" + +// Helper for output type conversion +namespace turbo_norm { +template +__device__ __forceinline__ T from_float(float val) { + return static_cast(val); +} +template <> +__device__ __forceinline__ __half from_float<__half>(float val) { + return __float2half(val); +} +template <> +__device__ __forceinline__ hip_bfloat16 from_float(float val) { + return hip_bfloat16(val); +} +} // namespace turbo_norm + + +template < + class InputDtype_, + class OutputDtype_, + class WeightDtype_, + int MaxHiddenSize_, + int NumThrPerCta_, + bool IsEven +> +class RMSNorm { +public: + using InputDtype = InputDtype_; + using OutputDtype = OutputDtype_; + using WeightDtype = WeightDtype_; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int MaxHiddenSize = MaxHiddenSize_; + + static constexpr size_t ShmSize = 32; + static constexpr int NumElementPerThread = MaxHiddenSize / NumThrPerCta; + + static_assert(MaxHiddenSize % NumThrPerCta == 0); + + struct Params { + void const *Iptr; + void const *Wptr; + void *Optr; + float eps; + int64_t m; + int64_t n; + }; + + using Arguments = Params; + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3(m); + } + + static dim3 get_cta_size(int64_t m, int64_t n) { + return dim3(NumThrPerCta, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char *shared_data) { + int const blk_m = blockIdx.x; + int const blk_n = 1; + int tidx = threadIdx.x; + float x[NumElementPerThread]; + + // load + Loader loader; + loader.load(params.Iptr, x, params.m, params.n, blk_m, 0, tidx); + + // rms reduction + float rms = sqrtf(_reduce_square(x, shared_data) / params.n + params.eps); + + // load weight + Loader weight_loader; + float w[NumElementPerThread]; + loader.load(params.Wptr, w, 1, params.n, 0, 0, tidx); + + // norm + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + x[i] = w[i] * x[i] / rms ; + + // save y + 
OutputDtype *output_reg = (OutputDtype*)x; + if constexpr (!std::is_same_v) { + output_reg = (OutputDtype*)w; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + output_reg[i] = turbo_norm::from_float(x[i]); + } + Saver saver; + saver.store(params.Optr, nullptr, output_reg, 0, params.m, params.n, blk_m, 0, tidx); + + } + +private: + CUTLASS_DEVICE + float _reduce_square(float *reg, char *shared_data) { + // thread + float sum_square = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + sum_square += reg[i] * reg[i]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 16; i >= 1; i >>= 1) { + sum_square += __shfl_down(sum_square, i, 32); + } + if (threadIdx.x == 0) { + *(float*)shared_data = 0; + } + __syncthreads(); + + if (threadIdx.x % 32 == 0) { + atomicAdd((float*)shared_data, sum_square); + } + + __syncthreads(); + sum_square = *(float*)shared_data; + return sum_square; + } +}; + + +template < + class InputDtype, + class OutputDtype, + class WeightDtype, + int MaxHiddenSize, + int NumThrPerCta +> +bool rmsnorm( + void const *Iptr, void const *Wptr, + void *Optr, float eps, + int64_t m, int64_t n, + hipStream_t stream = nullptr +) { + BOOL_SWITCH(n % MaxHiddenSize == 0, IsEven, [&] { + using Kernel = RMSNorm< + InputDtype, OutputDtype, WeightDtype, + MaxHiddenSize, NumThrPerCta, + IsEven>; + using Arguments = typename Kernel::Arguments; + Arguments args = { + Iptr, Wptr, Optr, eps, m, n + }; + auto params = Kernel::to_underlying_arguments(args); + auto grid_shape = Kernel::get_grid_size(m, n); + auto cta_shape = Kernel::get_cta_size(m, n); + static constexpr size_t ShmSize = Kernel::ShmSize; + launch_kernel(params, grid_shape, cta_shape, ShmSize, stream); + }); + return true; +} \ No newline at end of file diff --git a/turbodiffusion/ops/quant/quant.hip b/turbodiffusion/ops/quant/quant.hip new file mode 100644 index 0000000..878d9c3 --- /dev/null +++ b/turbodiffusion/ops/quant/quant.hip @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + * + * HIP/ROCm quantization kernel. 
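`_reduce_square` above (and its siblings in the other norm kernels) follows the same two-level pattern: each thread reduces its own registers, `__shfl_down` folds the value within a 32-lane group, and group leaders combine results through a single shared-memory `atomicAdd`. A standalone HIP sketch of that pattern, reducing a sum of squares over one block (kernel and buffer names are hypothetical; the 32-lane width mirrors the source's wave32/RDNA assumption):

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

// One block reduces sum(x[i]^2) over n elements into out[0]:
// registers -> __shfl_down within a 32-lane group -> shared-memory atomicAdd.
__global__ void block_sum_squares(const float* x, float* out, int n) {
    __shared__ float s_total;
    float partial = 0.0f;
    for (int i = threadIdx.x; i < n; i += blockDim.x)
        partial += x[i] * x[i];

    for (int offset = 16; offset >= 1; offset >>= 1)  // fold 32-lane group
        partial += __shfl_down(partial, offset, 32);

    if (threadIdx.x == 0) s_total = 0.0f;
    __syncthreads();
    if (threadIdx.x % 32 == 0)                        // one add per lane group
        atomicAdd(&s_total, partial);
    __syncthreads();

    if (threadIdx.x == 0) out[0] = s_total;
}

int main() {
    const int n = 4096;
    std::vector<float> hx(n, 1.0f);                   // sum of squares = 4096
    float *dx = nullptr, *dout = nullptr;
    hipMalloc(&dx, n * sizeof(float));
    hipMalloc(&dout, sizeof(float));
    hipMemcpy(dx, hx.data(), n * sizeof(float), hipMemcpyHostToDevice);

    hipLaunchKernelGGL(block_sum_squares, dim3(1), dim3(256), 0, nullptr, dx, dout, n);

    float result = 0.0f;
    hipMemcpy(&result, dout, sizeof(float), hipMemcpyDeviceToHost);
    printf("sum of squares = %f\n", result);          // expect 4096.0
    hipFree(dx); hipFree(dout);
    return 0;
}
```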
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/common_hip.hpp" +#include "quant/quant_hip.hpp" + +auto quant( + torch::Tensor const& Input, + std::optional& Output, + std::optional& Output_S +) { + using ElementOut = int8_t; + static constexpr int BlockSize = 128; + static constexpr int NumThrPerCta = 256; + + int64_t m = Input.size(0); + int64_t n = Input.size(1); + torch::Device const input_device = Input.device(); + + create_tensor(input_device, Output, Output_S, m, n); + + ElementOut *Optr = (ElementOut*)Output.value().data_ptr(); + float *OSptr = Output_S.value().data_ptr(); + + switch (Input.scalar_type()) { + case torch::kHalf: { + __half *Iptr = (__half*)Input.data_ptr(); + quantization<__half, BlockSize, NumThrPerCta>( + Iptr, Optr, OSptr, m, n, at::hip::getCurrentHIPStream().stream() + ); + break; + } + + case torch::kBFloat16: { + hip_bfloat16 *Iptr = (hip_bfloat16*)Input.data_ptr(); + quantization( + Iptr, Optr, OSptr, m, n, at::hip::getCurrentHIPStream().stream() + ); + break; + } + + default: { + std::cerr << "Observing: " << Input.scalar_type() << " for the input datatype which is invalid"; + throw std::runtime_error("Unsupported input data type for quantize_to_fp4."); + } + } + + return std::make_tuple(Output, Output_S); +} + +void register_quant(pybind11::module_ &m) { + m.def("quant_cuda", &quant); +} + diff --git a/turbodiffusion/ops/quant/quant_hip.hpp b/turbodiffusion/ops/quant/quant_hip.hpp new file mode 100644 index 0000000..60d9bdd --- /dev/null +++ b/turbodiffusion/ops/quant/quant_hip.hpp @@ -0,0 +1,194 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + */ + +#pragma once + +#include +#include +#include "common/numeric_conversion_hip.hpp" + +#include "common/load.hpp" +#include "common/store_hip.hpp" +#include "common/launch_hip.hpp" + +template < + class InputDtype_, + int NumThrPerCta_, + bool IsEvenM, + bool IsEvenN +> +class Quantization { +public: + using InputDtype = InputDtype_; + using OutputDtype = int8_t; + using FPConverter = cutlass::NumericConverter; + + static constexpr int BlockSize = 128; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int NumElementPerThread = BlockSize * BlockSize / NumThrPerCta; + static constexpr int NumThrPerRow = BlockSize / NumElementPerThread; + + static_assert(BlockSize * BlockSize % NumThrPerCta == 0); + static_assert(NumThrPerCta % BlockSize == 0); + + static constexpr size_t ShmSize = 32; + + static constexpr float int8_max = 128.f; + + struct Params { + void const *Iptr; + void *Optr; + void *OSptr; + int64_t const m; + int64_t const n; + }; + + using Arguments = Params; + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3( + cdiv(n, BlockSize), + cdiv(m, BlockSize) + ); + } + + static dim3 get_cta_size(int64_t m, int64_t n) { + return dim3( + NumThrPerCta, 1, 1 + ); + } + + CUTLASS_DEVICE + void quantization( + float 
*float_reg, + void *Optr, void *OSptr, + int64_t const m, int64_t const n, + int blk_m, int blk_n, int tidx, + char *shared_data + ) { + + OutputDtype output_reg[NumElementPerThread]; + + + Saver saver; + + float amax = _reduce_amax(float_reg, (float*)shared_data); + + + _quantization(float_reg, output_reg, int8_max / amax); + + float scale_inv = amax / int8_max; + + saver.store(Optr, OSptr, output_reg, scale_inv, m, n, blk_m, blk_n, tidx); + + __syncthreads(); + } + + + CUTLASS_DEVICE + void operator()(Params const& params, char *shared_data) { + int blk_m = blockIdx.y; + int blk_n = blockIdx.x; + int tidx = threadIdx.x; + + float float_reg[NumElementPerThread]; + + // load float32 data + Loader loader; + loader.load(params.Iptr, float_reg, params.m, params.n, blk_m, blk_n, tidx); + quantization( + float_reg, params.Optr, params.OSptr, params.m, params.n, blk_m, blk_n, tidx, shared_data + ); + } + +private: + + CUTLASS_DEVICE float + _reduce_amax(float *reg, float *smem_ptr) { + float amax = 1e-8f; + // thread reduction + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + amax = fmaxf(amax, fabsf(reg[i])); + + // warp reduction - use __shfl_xor for HIP (no sync needed on AMD) + CUTLASS_PRAGMA_UNROLL + for (int i = 16; i >= 1; i /= 2) { + amax = fmaxf( + __shfl_xor(amax, i, 32), + amax + ); + } + + // cta reduction + if (threadIdx.x == 0) { + *smem_ptr = 0; + } + __syncthreads(); + + atomicMax((uint32_t*)smem_ptr, reinterpret_cast(amax)); + + __syncthreads(); + + amax = *smem_ptr; + + return amax; + } + + CUTLASS_DEVICE void + _quantization(float *float_reg, OutputDtype *out_reg, float scale) { + FPConverter converter; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + out_reg[i] = converter(float_reg[i] * scale); + } + + } +}; + +template < + class InputDtype, + int BlockSize, + int NumThrPerCta +> +bool quantization( + void const *Iptr, void *Optr, void *OSptr, + int64_t m, int64_t n, + hipStream_t stream = nullptr +) { + BOOL_SWITCH(m % BlockSize == 0, IsEvenM, [&] { + BOOL_SWITCH(n % BlockSize == 0, IsEvenN, [&] { + using Kernel = Quantization< + InputDtype, NumThrPerCta, IsEvenM, IsEvenN>; + using Arguments = typename Kernel::Arguments; + Arguments args = { + Iptr, Optr, OSptr, + m, n + }; + auto params = Kernel::to_underlying_arguments(args); + auto grid_shape = Kernel::get_grid_size(m, n); + auto cta_shape = Kernel::get_cta_size(m, n); + static constexpr size_t ShmSize = Kernel::ShmSize; + launch_kernel(params, grid_shape, cta_shape, ShmSize, stream); + }); + }); + + return true; +} diff --git a/turbodiffusion/ops/quant/quant_rocm.hpp b/turbodiffusion/ops/quant/quant_rocm.hpp new file mode 100644 index 0000000..4647b7a --- /dev/null +++ b/turbodiffusion/ops/quant/quant_rocm.hpp @@ -0,0 +1,261 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Quantization kernel for AMD GPUs using HIP. 
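The quantization kernels above (both the CUTLASS-style and the plain HIP variant) work per 128x128 tile: find the tile's absolute maximum, scale by 128/amax, round and clamp to int8, and record `amax/128` as the per-tile dequantization scale that the GEMM consumes. A host reference of that per-tile step (hypothetical helper; the explicit clamp to [-128, 127] follows the HIP variant):

```cpp
#include <cstdint>
#include <cmath>
#include <algorithm>
#include <vector>

// Quantize one BlockSize x BlockSize tile (already gathered into `tile`)
// to int8 and return the dequantization scale amax / 128, mirroring the
// reduce_amax + quantize steps of the kernel.
float quantize_tile_int8(const std::vector<float>& tile, std::vector<int8_t>& q) {
    constexpr float kInt8Max = 128.0f;
    float amax = 1e-8f;                              // same floor as the kernel
    for (float v : tile) amax = std::max(amax, std::fabs(v));

    const float scale = kInt8Max / amax;
    q.resize(tile.size());
    for (size_t i = 0; i < tile.size(); ++i) {
        float v = std::rint(tile[i] * scale);
        v = std::max(-128.0f, std::min(127.0f, v));  // clamp to int8 range
        q[i] = static_cast<int8_t>(v);
    }
    return amax / kInt8Max;                          // scale_inv stored via OSptr
}
```

The GEMM side then multiplies the A-tile and B-tile `scale_inv` values together when dequantizing each K block, which is the `scale_a * scale_b` product in the rocWMMA kernel.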
+ */ + +#pragma once + +#include +#include +#include +#include "common/platform.hpp" + +// Helper for input type to float conversion (handles half types properly) +template +TURBO_DEVICE TURBO_INLINE float input_to_float(T val); + +template <> +TURBO_DEVICE TURBO_INLINE float input_to_float<__half>(__half val) { + return __half2float(val); +} + +template <> +TURBO_DEVICE TURBO_INLINE float input_to_float(hip_bfloat16 val) { + return static_cast(val); +} + +template <> +TURBO_DEVICE TURBO_INLINE float input_to_float(float val) { + return val; +} + +TURBO_HOST_DEVICE int64_t cdiv(int64_t a, int64_t b) { + return (a + b - 1) / b; +} + +#define MIN(a, b) ((a) > (b) ? (b) : (a)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ +[&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return (__VA_ARGS__)(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return (__VA_ARGS__)(); \ + } \ +}() + +template < + class InputDtype_, + int NumThrPerCta_, + bool IsEvenM, + bool IsEvenN +> +class QuantizationHIP { +public: + using InputDtype = InputDtype_; + using OutputDtype = int8_t; + + static constexpr int BlockSize = 128; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int NumElementPerThread = BlockSize * BlockSize / NumThrPerCta; + static constexpr int NumThrPerRow = BlockSize / NumElementPerThread; + + static_assert(BlockSize * BlockSize % NumThrPerCta == 0); + static_assert(NumThrPerCta % BlockSize == 0); + + static constexpr size_t ShmSize = 32; + static constexpr float int8_max = 128.f; + + struct Params { + void const* Iptr; + void* Optr; + void* OSptr; + int64_t const m; + int64_t const n; + }; + + using Arguments = Params; + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3( + cdiv(n, BlockSize), + cdiv(m, BlockSize) + ); + } + + static dim3 get_cta_size(int64_t m, int64_t n) { + return dim3(NumThrPerCta, 1, 1); + } + + TURBO_DEVICE + void operator()(Params const& params, char* shared_data) { + int blk_m = blockIdx.y; + int blk_n = blockIdx.x; + int tidx = threadIdx.x; + + float float_reg[NumElementPerThread]; + + // Load input data + load_input(params.Iptr, float_reg, params.m, params.n, blk_m, blk_n, tidx); + + // Quantize + quantize(float_reg, params.Optr, params.OSptr, params.m, params.n, blk_m, blk_n, tidx, shared_data); + } + +private: + TURBO_DEVICE + void load_input(void const* input_ptr, float* thr_output_reg, + int64_t m, int64_t n, int blk_m, int blk_n, int tid) { + int thr_m_offset = tid / NumThrPerRow; + int thr_n_offset = (tid % NumThrPerRow) * NumElementPerThread; + + int64_t global_m = blk_m * BlockSize + thr_m_offset; + int64_t global_n = blk_n * BlockSize + thr_n_offset; + + InputDtype const* input = reinterpret_cast(input_ptr); + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEvenM && IsEvenN) { + thr_output_reg[i] = input_to_float(input[global_m * n + global_n + i]); + } else { + if (global_m < m && (global_n + i) < n) { + thr_output_reg[i] = input_to_float(input[global_m * n + global_n + i]); + } else { + thr_output_reg[i] = 0.0f; + } + } + } + } + + TURBO_DEVICE + void quantize(float* float_reg, void* Optr, void* OSptr, + int64_t m, int64_t n, int blk_m, int blk_n, int tidx, char* shared_data) { + OutputDtype output_reg[NumElementPerThread]; + + float amax = reduce_amax(float_reg, (float*)shared_data); + + float scale = int8_max / amax; + + 
TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + float val = float_reg[i] * scale; + val = fmaxf(-128.0f, fminf(127.0f, rintf(val))); + output_reg[i] = static_cast(val); + } + + float scale_inv = amax / int8_max; + + // Store output + store_output(Optr, OSptr, output_reg, scale_inv, m, n, blk_m, blk_n, tidx); + + __syncthreads(); + } + + TURBO_DEVICE + float reduce_amax(float* reg, float* smem_ptr) { + float amax = 1e-8f; + + // Thread reduction + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + amax = fmaxf(amax, fabsf(reg[i])); + } + + __syncwarp(); + + // Warp reduction + TURBO_PRAGMA_UNROLL + for (int i = 16; i >= 1; i /= 2) { + amax = fmaxf(__shfl_xor(amax, i, 32), amax); + } + + // CTA reduction + if (threadIdx.x == 0) { + *smem_ptr = 0; + } + __syncthreads(); + + atomicMax((unsigned int*)smem_ptr, __float_as_uint(amax)); + + __syncthreads(); + + amax = __uint_as_float(*(unsigned int*)smem_ptr); + return amax; + } + + TURBO_DEVICE + void store_output(void* Optr, void* OSptr, OutputDtype* reg, float scale_inv, + int64_t m, int64_t n, int blk_m, int blk_n, int tid) { + int thr_m_offset = tid / NumThrPerRow; + int thr_n_offset = (tid % NumThrPerRow) * NumElementPerThread; + + int64_t global_m = blk_m * BlockSize + thr_m_offset; + int64_t padded_n = cdiv(n, BlockSize) * BlockSize; + int64_t global_n = blk_n * BlockSize + thr_n_offset; + + OutputDtype* output = reinterpret_cast(Optr); + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEvenM && IsEvenN) { + output[global_m * padded_n + global_n + i] = reg[i]; + } else { + if (global_m < m && (global_n + i) < n) { + output[global_m * padded_n + global_n + i] = reg[i]; + } + } + } + + if (tid == 0) { + float* scale_ptr = reinterpret_cast(OSptr); + scale_ptr[blk_m * cdiv(n, BlockSize) + blk_n] = scale_inv; + } + } +}; + +template +__global__ void quant_kernel_hip(typename Kernel::Params const params) { + extern __shared__ char smem[]; + Kernel op; + op(params, smem); +} + +template < + class InputDtype, + int BlockSize, + int NumThrPerCta +> +bool quantization_hip( + void const* Iptr, void* Optr, void* OSptr, + int64_t m, int64_t n, + hipStream_t stream = nullptr +) { + BOOL_SWITCH(m % BlockSize == 0, IsEvenM, [&] { + BOOL_SWITCH(n % BlockSize == 0, IsEvenN, [&] { + using Kernel = QuantizationHIP; + using Arguments = typename Kernel::Arguments; + + Arguments args = {Iptr, Optr, OSptr, m, n}; + auto params = Kernel::to_underlying_arguments(args); + auto grid_shape = Kernel::get_grid_size(m, n); + auto cta_shape = Kernel::get_cta_size(m, n); + static constexpr size_t ShmSize = Kernel::ShmSize; + + hipLaunchKernelGGL(quant_kernel_hip, grid_shape, cta_shape, ShmSize, stream, params); + }); + }); + + return true; +} + diff --git a/turbodiffusion/rcm/networks/wan2pt1.py b/turbodiffusion/rcm/networks/wan2pt1.py index 1f55c09..a0b28a3 100644 --- a/turbodiffusion/rcm/networks/wan2pt1.py +++ b/turbodiffusion/rcm/networks/wan2pt1.py @@ -29,14 +29,17 @@ flash_apply_rotary_emb = None print("flash_attn is not installed.") -from torch.distributed import ProcessGroup, get_process_group_ranks -from torch.distributed._composable.fsdp import fully_shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper +from torch.distributed import ProcessGroup + +if torch.distributed.is_available(): + from torch.distributed import ProcessGroup, get_process_group_ranks + from torch.distributed._composable.fsdp import fully_shard + from 
rcm.utils.context_parallel import split_inputs_cp, cat_outputs_cp, cat_outputs_cp_with_grad, broadcast from imaginaire.utils import log from rcm.utils.a2a_cp import MinimalA2AAttnOp from rcm.utils.selective_activation_checkpoint import CheckpointMode, SACConfig -from rcm.utils.context_parallel import split_inputs_cp, cat_outputs_cp, cat_outputs_cp_with_grad, broadcast T5_CONTEXT_TOKEN_NUMBER = 512 FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2 @@ -543,7 +546,7 @@ def __init__( Epsilon value for normalization layers """ - super().__init__() + super().__init__() assert model_type in ["t2v", "i2v", "flf2v"] self.model_type = model_type diff --git a/turbodiffusion/rcm/networks/wan2pt1_jvp.py b/turbodiffusion/rcm/networks/wan2pt1_jvp.py index cd06cf8..cb521c4 100644 --- a/turbodiffusion/rcm/networks/wan2pt1_jvp.py +++ b/turbodiffusion/rcm/networks/wan2pt1_jvp.py @@ -21,10 +21,12 @@ import torch.nn as nn from einops import rearrange, repeat from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb -from torch.distributed import ProcessGroup, get_process_group_ranks -from torch.distributed._composable.fsdp import fully_shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper +if torch.distributed.is_available(): + from torch.distributed import ProcessGroup, get_process_group_ranks + from torch.distributed._composable.fsdp import fully_shard + from imaginaire.utils import log from rcm.utils.a2a_cp import MinimalA2AAttnOp from rcm.utils.selective_activation_checkpoint import CheckpointMode, SACConfig diff --git a/turbodiffusion/rcm/networks/wan2pt2.py b/turbodiffusion/rcm/networks/wan2pt2.py index fc41564..55b6620 100644 --- a/turbodiffusion/rcm/networks/wan2pt2.py +++ b/turbodiffusion/rcm/networks/wan2pt2.py @@ -29,14 +29,17 @@ flash_apply_rotary_emb = None print("flash_attn is not installed.") -from torch.distributed import ProcessGroup, get_process_group_ranks -from torch.distributed._composable.fsdp import fully_shard +from torch.distributed import ProcessGroup from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper +if torch.distributed.is_available(): + from torch.distributed import get_process_group_ranks + from torch.distributed._composable.fsdp import fully_shard + from rcm.utils.context_parallel import split_inputs_cp, cat_outputs_cp, cat_outputs_cp_with_grad, broadcast + from imaginaire.utils import log from rcm.utils.a2a_cp import MinimalA2AAttnOp from rcm.utils.selective_activation_checkpoint import CheckpointMode, SACConfig -from rcm.utils.context_parallel import split_inputs_cp, cat_outputs_cp, cat_outputs_cp_with_grad, broadcast T5_CONTEXT_TOKEN_NUMBER = 512 FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2 diff --git a/turbodiffusion/rcm/utils/context_parallel.py b/turbodiffusion/rcm/utils/context_parallel.py index af27ffa..06f1b0b 100644 --- a/turbodiffusion/rcm/utils/context_parallel.py +++ b/turbodiffusion/rcm/utils/context_parallel.py @@ -16,8 +16,11 @@ import torch from torch import Tensor -from torch.distributed import ProcessGroup, all_gather, broadcast_object_list, get_process_group_ranks, get_world_size -from torch.distributed.utils import _verify_param_shape_across_processes +if torch.distributed.is_available(): + from torch.distributed import ProcessGroup, all_gather, broadcast_object_list, get_process_group_ranks, get_world_size + from torch.distributed.utils import _verify_param_shape_across_processes +else: + from 
torch.distributed import ProcessGroup from imaginaire.utils import distributed