Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 251 additions & 0 deletions kernels-v1/attention-int8/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
cmake_minimum_required(VERSION 3.26)

# Set Intel SYCL compiler before project() call.
# The toolchain is locked in at project() time, so CMAKE_<LANG>_COMPILER must
# be chosen here; changing it afterwards has no effect.
find_program(ICX_COMPILER icx)
find_program(ICPX_COMPILER icpx)

# Only assign a compiler when its binary was actually found. The original code
# assigned unconditionally inside `if(ICX_COMPILER OR ICPX_COMPILER)`, so with
# only one of the two installed, the other variable expanded to
# "ICX_COMPILER-NOTFOUND"/"ICPX_COMPILER-NOTFOUND" and CMake tried to use that
# string as a compiler path.
if(ICX_COMPILER)
  set(CMAKE_C_COMPILER ${ICX_COMPILER})
endif()

# On Windows, icx (MSVC-compatible driver) is used for C++ as well so the
# Ninja generator works; elsewhere icpx (GNU-compatible) is the C++ driver.
if(WIN32)
  if(ICX_COMPILER)
    set(CMAKE_CXX_COMPILER ${ICX_COMPILER})
  endif()
elseif(ICPX_COMPILER)
  set(CMAKE_CXX_COMPILER ${ICPX_COMPILER})
endif()

# CXX only at project() time; backend-specific languages (HIP, OBJC/OBJCXX)
# are enabled later in the per-backend branches below.
project(attention-int8 LANGUAGES CXX)

# NOTE(review): presumably this keeps `make install` from descending into
# dependency subdirectories pulled in by FetchContent — confirm.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

include(FetchContent)
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")

# ROCm target architectures this kernel may be built for; used by
# override_gpu_arches() in the HIP branch below.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")

# Project-local helper modules: python/query utilities, kernel component and
# extension-target helpers, and GPU language autodetection.
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
include(${CMAKE_CURRENT_LIST_DIR}/cmake/kernel.cmake)
include(${CMAKE_CURRENT_LIST_DIR}/cmake/get_gpu_lang.cmake)

# Locate the Python interpreter plus development headers, including the
# stable-ABI (SABI) module component used by USE_SABI below.
if(DEFINED Python3_EXECUTABLE)
  # Allow passing through the interpreter (e.g. from setup.py).
  find_package(Python3 COMPONENTS Development Development.SABIModule Interpreter)
  if(NOT Python3_FOUND)
    # Fixed: the original message interpolated an undefined ${EXECUTABLE}
    # variable, producing "Unable to find python matching: ." — report the
    # interpreter that was actually requested.
    message(FATAL_ERROR "Unable to find python matching: ${Python3_EXECUTABLE}.")
  endif()
else()
  find_package(Python3 REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
endif()

# Autodetect the GPU toolchain (CUDA / HIP / SYCL / METAL / CPU) and derive
# the backend name that is embedded in the extension module name.
get_gpu_lang(DETECTED_GPU_LANG)
set(GPU_LANG "${DETECTED_GPU_LANG}" CACHE STRING "GPU language")
gpu_lang_to_backend(BACKEND "${GPU_LANG}")
message(STATUS "Using backend: ${BACKEND}, GPU language: ${GPU_LANG}")

set(KERNEL_REVISION "dba582b_dirty" CACHE STRING "Kernel revision, must be unique")
# Fixed: use the KERNEL_REVISION cache variable instead of a hard-coded copy
# of the revision string, so overriding -DKERNEL_REVISION=... actually changes
# the ops name (previously the cache variable was defined but never read).
set(OPS_NAME "_attention_int8_${BACKEND}_${KERNEL_REVISION}")

# Make torch's bundled CMake package config discoverable, then import it.
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")

find_package(Torch REQUIRED)

# Torch version without any local build suffix (e.g. "2.4.0+cu121" -> "2.4.0").
run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")



option(BUILD_ALL_SUPPORTED_ARCHS "Build all supported architectures" off)

# Default CUDA kernel architectures, keyed on the available CUDA compiler
# version: newer toolkits drop the oldest compute capabilities and add new
# ones. Start from the oldest-toolkit list and upgrade when the version is
# known to be recent enough.
set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0+PTX")
if(DEFINED CMAKE_CUDA_COMPILER_VERSION)
  if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
    set(CUDA_DEFAULT_KERNEL_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0+PTX")
  elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
    set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0+PTX")
  endif()
endif()

# Basic checks and per-backend configuration for each GPU language.
if(GPU_LANG STREQUAL "CUDA")
  if(NOT CUDA_FOUND)
    message(FATAL_ERROR "GPU language is set to CUDA, but cannot find CUDA toolkit")
  endif()

  # This clears out -gencode arguments from `CMAKE_CUDA_FLAGS`, which we need
  # to set our own set of capabilities.
  clear_gencode_flags()

  # Get the capabilities without +PTX suffixes, so that we can use them as
  # the target archs in the loose intersection with a kernel's capabilities.
  cuda_remove_ptx_suffixes(CUDA_ARCHS "${CUDA_DEFAULT_KERNEL_ARCHS}")
  message(STATUS "CUDA supported base architectures: ${CUDA_ARCHS}")

  if(BUILD_ALL_SUPPORTED_ARCHS)
    set(CUDA_KERNEL_ARCHS "${CUDA_DEFAULT_KERNEL_ARCHS}")
  else()
    # Build only for the locally visible GPU's compute capability; fall back
    # to the full default list when no device can be queried.
    try_run_python(CUDA_KERNEL_ARCHS SUCCESS "import torch; cc=torch.cuda.get_device_capability(); print(f\"{cc[0]}.{cc[1]}\")" "Failed to get CUDA capability")
    if(NOT SUCCESS)
      message(WARNING "Failed to detect CUDA capability, using default capabilities.")
      set(CUDA_KERNEL_ARCHS "${CUDA_DEFAULT_KERNEL_ARCHS}")
    endif()
  endif()

  message(STATUS "CUDA supported kernel architectures: ${CUDA_KERNEL_ARCHS}")

  # Parallel device compilation; the GPU_LANG check is redundant inside this
  # branch but kept for safety if the condition is ever hoisted.
  if(NVCC_THREADS AND GPU_LANG STREQUAL "CUDA")
    list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")
  endif()

  # TODO: deprecate one of these settings.
  add_compile_definitions(USE_CUDA=1)
  add_compile_definitions(CUDA_KERNEL)
elseif(GPU_LANG STREQUAL "HIP")
  if(NOT HIP_FOUND)
    message(FATAL_ERROR "GPU language is set to HIP, but cannot find ROCm toolkit")
  endif()

  # Importing torch recognizes and sets up some HIP/ROCm configuration but does
  # not let cmake recognize .hip files. In order to get cmake to understand the
  # .hip extension automatically, HIP must be enabled explicitly.
  enable_language(HIP)

  override_gpu_arches(GPU_ARCHES HIP ${HIP_SUPPORTED_ARCHS})
  set(ROCM_ARCHS ${GPU_ARCHES})
  message(STATUS "ROCM supported target architectures: ${ROCM_ARCHS}")

  # TODO: deprecate one of these settings.
  add_compile_definitions(USE_ROCM=1)
  add_compile_definitions(ROCM_KERNEL)
elseif(GPU_LANG STREQUAL "CPU")
  add_compile_definitions(CPU_KERNEL)
  set(CMAKE_OSX_DEPLOYMENT_TARGET "15.0" CACHE STRING "Minimum macOS deployment version")
elseif(GPU_LANG STREQUAL "METAL")
  set(CMAKE_OSX_DEPLOYMENT_TARGET "26.0" CACHE STRING "Minimum macOS deployment version")
  enable_language(C OBJC OBJCXX)

  add_compile_definitions(METAL_KERNEL)

  # Initialize lists for Metal shader sources and their include directories
  set(ALL_METAL_SOURCES)
  set(METAL_INCLUDE_DIRS)
elseif(GPU_LANG STREQUAL "SYCL")
  if(NOT ICX_COMPILER AND NOT ICPX_COMPILER)
    message(FATAL_ERROR "Intel SYCL C++ compiler (icpx) and/or C compiler (icx) not found. Please install Intel oneAPI toolkit.")
  endif()

  # Fixed: query the version from a compiler that actually exists. The
  # original unconditionally invoked ${ICPX_COMPILER}, which is NOTFOUND when
  # only icx is installed (a supported setup on Windows), leaving
  # DPCPP_VERSION empty. Prefer icpx, fall back to icx.
  if(ICPX_COMPILER)
    set(_sycl_version_compiler ${ICPX_COMPILER})
  else()
    set(_sycl_version_compiler ${ICX_COMPILER})
  endif()
  execute_process(
    COMMAND ${_sycl_version_compiler} --version
    OUTPUT_VARIABLE ICPX_VERSION_OUTPUT
    OUTPUT_STRIP_TRAILING_WHITESPACE
  )
  # First "major.minor" in the version banner, e.g. "2024.2".
  string(REGEX MATCH "[0-9]+\\.[0-9]+" DPCPP_VERSION "${ICPX_VERSION_OUTPUT}")
  set(DPCPP_VERSION "${DPCPP_VERSION}" CACHE STRING "DPCPP major.minor version")

  # On Windows, use icx (MSVC-compatible) for C++ to work with Ninja generator
  # On Linux, use icpx (GNU-compatible) for C++
  if(WIN32)
    message(STATUS "Using Intel SYCL C++ compiler: ${ICX_COMPILER} and C compiler: ${ICX_COMPILER} Version: ${DPCPP_VERSION} (Windows MSVC-compatible mode)")
  else()
    message(STATUS "Using Intel SYCL C++ compiler: ${ICPX_COMPILER} and C compiler: ${ICX_COMPILER} Version: ${DPCPP_VERSION}")
  endif()

  # Compile/link flag sets. AOT-compiles for spir64_gen devices
  # (pvc, xe-lpg, ats-m150) plus generic spir64; sycl_link_flags is applied to
  # the target after it is defined, further below.
  set(sycl_link_flags "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required';")
  set(sycl_flags "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;")
  set(GPU_FLAGS "${sycl_flags}")
  set(GPU_ARCHES "")

  add_compile_definitions(XPU_KERNEL)
  add_compile_definitions(USE_XPU)
else()
  message(FATAL_ERROR "Unsupported GPU language: ${GPU_LANG}")
endif()

# Initialize SRC list for kernel and binding sources
set(SRC "")

include(${CMAKE_CURRENT_LIST_DIR}/cmake/build-variants.cmake)

# Generate build variant name. It combines the torch version, backend name and
# toolkit version string and is used by the install targets at the bottom of
# this file.
if(GPU_LANG STREQUAL "CUDA")
generate_build_name(BUILD_VARIANT_NAME "${TORCH_VERSION}" "cuda" "${CUDA_VERSION}")
elseif(GPU_LANG STREQUAL "HIP")
# major.minor of the HIP runtime torch was built against.
run_python(ROCM_VERSION "import torch.version; print(torch.version.hip.split('.')[0] + '.' + torch.version.hip.split('.')[1])" "Failed to get ROCm version")
generate_build_name(BUILD_VARIANT_NAME "${TORCH_VERSION}" "rocm" "${ROCM_VERSION}")
elseif(GPU_LANG STREQUAL "SYCL")
generate_build_name(BUILD_VARIANT_NAME "${TORCH_VERSION}" "xpu" "${DPCPP_VERSION}")
elseif(GPU_LANG STREQUAL "METAL")
generate_build_name(BUILD_VARIANT_NAME "${TORCH_VERSION}" "metal" "")
elseif(GPU_LANG STREQUAL "CPU")
generate_build_name(BUILD_VARIANT_NAME "${TORCH_VERSION}" "cpu" "")
else()
message(FATAL_ERROR "Cannot generate build name for unknown GPU_LANG: ${GPU_LANG}")
endif()

# Render cmake/_ops.py.in with @OPS_NAME@ substituted (@ONLY keeps stray ${}
# in the template untouched).
# NOTE(review): this writes into the SOURCE tree (torch-ext/attention_int8/).
# Generated files normally belong in the binary dir; presumably intentional
# here because _ops.py ships as part of the python package — confirm.
configure_file(
${CMAKE_CURRENT_LIST_DIR}/cmake/_ops.py.in
${CMAKE_CURRENT_SOURCE_DIR}/torch-ext/attention_int8/_ops.py
@ONLY
)

# CUDA builds additionally inherit torch's GPU compiler flags.
if(GPU_LANG STREQUAL "CUDA")
get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})
endif()

# Torch binding sources shared by every backend.
set(TORCH_attention-int8_SRC
torch-ext/torch_binding.cpp torch-ext/torch_binding.h
)


list(APPEND SRC "${TORCH_attention-int8_SRC}")
# NOTE(review): invoked unconditionally for all backends; presumably
# cuda_kernel_component (cmake/kernel.cmake) only appends the .cu source when
# the backend is CUDA — confirm, otherwise non-CUDA builds receive a CUDA
# source file.
cuda_kernel_component(SRC
SOURCES "attention_int8_cuda/attention_int8.cu"
)
# Include Metal shader compilation utilities if needed
if(GPU_LANG STREQUAL "METAL")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/compile-metal.cmake)
endif()

# Define the extension target with unified parameters.
# Helper from cmake/kernel.cmake; USE_SABI 3 presumably builds against the
# CPython stable ABI and WITH_SOABI tags the module filename — confirm against
# the helper's definition.
define_gpu_extension_target(
${OPS_NAME}
${OPS_NAME}
DESTINATION ${OPS_NAME}
LANGUAGE ${GPU_LANG}
SOURCES ${SRC}
COMPILE_FLAGS ${GPU_FLAGS}
ARCHITECTURES ${GPU_ARCHES}
USE_SABI 3
WITH_SOABI)

# Statically link libstdc++ (presumably for portability of the shipped
# binary — confirm); not applicable with MSVC and skipped for SYCL builds.
if(NOT (MSVC OR GPU_LANG STREQUAL "SYCL"))
target_link_options(${OPS_NAME} PRIVATE -static-libstdc++)
endif()

# SYCL builds need the device-link flags defined in the SYCL branch above,
# and link against oneDNN.
if(GPU_LANG STREQUAL "SYCL")
target_link_options(${OPS_NAME} PRIVATE ${sycl_link_flags})
target_link_libraries(${OPS_NAME} PRIVATE dnnl)
endif()

# Compile Metal shaders if any were found
if(GPU_LANG STREQUAL "METAL")
if(ALL_METAL_SOURCES)
compile_metal_shaders(${OPS_NAME} "${ALL_METAL_SOURCES}" "${METAL_INCLUDE_DIRS}")
endif()
endif()


# Add kernels_install target for huggingface/kernels library layout
# NOTE(review): ALL_GPU_ARCHS is never set in this file; presumably a helper
# (override_gpu_arches / build-variants.cmake) defines it — confirm, otherwise
# GPU_ARCHS expands empty here.
add_kernels_install_target(${OPS_NAME} "attention_int8" "${BUILD_VARIANT_NAME}"
DATA_EXTENSIONS ""
GPU_ARCHS "${ALL_GPU_ARCHS}")

# Add local_install target for local development with get_local_kernel()
add_local_install_target(${OPS_NAME} "attention_int8" "${BUILD_VARIANT_NAME}"
DATA_EXTENSIONS ""
GPU_ARCHS "${ALL_GPU_ARCHS}")
60 changes: 28 additions & 32 deletions kernels-v1/attention-int8/attention_int8_cuda/attention_int8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -172,30 +172,25 @@ int8_attention_kernel(
{
int qi = i / HEAD_DIM;
int di = i % HEAD_DIM;
lqmax = fmaxf(lqmax,
fabsf(__half2float(Q_head[(q_start + qi) * HEAD_DIM + di])));
}
lqmax = fmaxf(lqmax, fabsf(__half2float(Q_head[(q_start + qi) * HEAD_DIM + di])));
}

float abs_max_Q = block_reduce_max(lqmax, warp_scr);
const float inv_Q = 127.f / fmaxf(abs_max_Q * ts, 1e-6f);
const float scl_Q = 1.f / inv_Q;
float abs_max_Q = block_reduce_max(lqmax, warp_scr);
const float inv_Q = 127.f / fmaxf(abs_max_Q, 1e-6f);
const float scl_Q = 1.f / inv_Q;

// Quantize Q tile
for (int i = tid; i < q_size * HEAD_DIM; i += THREADS) {
int qi = i / HEAD_DIM, di = i % HEAD_DIM;
Q_i8[qi * HEAD_DIM + di] = quantize_f32(
__half2float(Q_head[(q_start + qi) * HEAD_DIM + di]), inv_Q);
}
pad_Q_rows<HEAD_DIM, BQ>(Q_i8, q_size, tid);
// Quantize Q tile
for (int i = tid; i < q_size * HEAD_DIM; i += THREADS) {
int qi = i / HEAD_DIM, di = i % HEAD_DIM;
Q_i8[qi * HEAD_DIM + di] =
quantize_f32(__half2float(Q_head[(q_start + qi) * HEAD_DIM + di]), inv_Q);
}
pad_Q_rows<HEAD_DIM, BQ>(Q_i8, q_size, tid);

// Initialise per-row accumulators [F3]
for (int qi = tid; qi < BQ; qi += THREADS) {
row_max[qi] = -1e30f;
row_sum[qi] = 0.f;
}
for (int i = tid; i < BQ * HEAD_DIM; i += THREADS)
out_acc[i] = 0.f;
__syncthreads();
// Initialise per-row accumulators [F3]
for (int qi = tid; qi < BQ; qi += THREADS) { row_max[qi] = -1e30f; row_sum[qi] = 0.f; }
for (int i = tid; i < BQ * HEAD_DIM; i += THREADS) out_acc[i] = 0.f;
__syncthreads();

#if __CUDA_ARCH__ >= 750
// WMMA fragment types — INT8 WMMA requires sm_75+ (Turing and above) [G1]
Expand All @@ -208,20 +203,21 @@ int8_attention_kernel(

const float inv_sqrt_d = rsqrtf((float)HEAD_DIM);

// Stream K tiles
for (int k_start = 0; k_start < N; k_start += BK) {
const int k_size = min(BK, N - k_start);
float lkmax_global = 0.f;

float lkmax = 0.f;
for (int i = tid; i < k_size * HEAD_DIM; i += THREADS)
lkmax = fmaxf(lkmax, fabsf(__half2float(K_head[k_start * HEAD_DIM + i])));
float abs_max_K = block_reduce_max(lkmax, warp_scr);
const float inv_K = 127.f / fmaxf(abs_max_K * ts, 1e-6f);
for (int i = tid; i < N * HEAD_DIM; i += THREADS) lkmax_global = fmaxf(lkmax_global, fabsf(__half2float(K_head[i])));

float abs_max_K_global = block_reduce_max(lkmax_global, warp_scr);
const float inv_K = 127.f / fmaxf(abs_max_K_global * ts, 1e-6f);
const float scl_K = 1.f / inv_K;

// [F5][G1] Fused quantize + transpose K
load_and_quantize_K_transposed<HEAD_DIM, BK>(K_head + k_start * HEAD_DIM,
K_i8_T, k_size, tid, inv_K);
// Stream K tiles
for (int k_start = 0; k_start < N; k_start += BK) {
const int k_size = min(BK, N - k_start);

// [F5][G1] Fused quantize + transpose K
load_and_quantize_K_transposed<HEAD_DIM, BK>(
K_head + k_start * HEAD_DIM, K_i8_T, k_size, tid, inv_K);

// Load V tile
for (int i = tid; i < k_size * HEAD_DIM; i += THREADS) {
Expand Down
9 changes: 9 additions & 0 deletions kernels-v1/attention-int8/cmake/_ops.py.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Template configured by CMake with @ONLY substitution (see configure_file in
# CMakeLists.txt): @OPS_NAME@ is replaced with the backend- and
# revision-specific extension module name.
import torch
from . import @OPS_NAME@

ops = torch.ops.@OPS_NAME@


def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"@OPS_NAME@::{op_name}"
55 changes: 55 additions & 0 deletions kernels-v1/attention-int8/cmake/add_gpu_arch_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import argparse
import json
import sys


def main():
    """CLI entry point: copy a metadata JSON file, injecting GPU arch info.

    Reads the input metadata JSON, attaches a ``backend`` object holding the
    backend type and a sorted, de-duplicated architecture list, and writes the
    result (pretty-printed, trailing newline) to the destination path.
    Exits with status 1 if the input file is missing or not valid JSON.
    """
    parser = argparse.ArgumentParser(
        description="Write a metadata JSON file with GPU architecture information, "
        "reading from a source file and writing to a destination."
    )
    parser.add_argument(
        "input",
        help="Path to the source metadata JSON file to read from.",
    )
    parser.add_argument(
        "destination",
        help="Path to write the output metadata JSON file to.",
    )
    parser.add_argument(
        "--backend",
        required=True,
        choices=["cuda", "rocm"],
        help="GPU backend type.",
    )
    parser.add_argument(
        "--archs",
        required=True,
        help="Semicolon-separated list of GPU architectures/capabilities.",
    )
    opts = parser.parse_args()

    # Drop empty entries, de-duplicate via a set, and order deterministically.
    arch_list = sorted({entry for entry in opts.archs.split(";") if entry})

    try:
        with open(opts.input) as src:
            metadata = json.load(src)
    except FileNotFoundError:
        print(f"Error: input metadata file not found: {opts.input}", file=sys.stderr)
        sys.exit(1)
    except json.JSONDecodeError as exc:
        print(f"Error: failed to parse input metadata JSON: {exc}", file=sys.stderr)
        sys.exit(1)

    metadata["backend"] = {"type": opts.backend, "archs": arch_list}

    with open(opts.destination, "w") as dst:
        json.dump(metadata, dst, indent=2)
        dst.write("\n")


# Allow use both as a standalone script (invoked by the build system) and as
# an importable module.
if __name__ == "__main__":
    main()
Loading