Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,3 @@ repos:
- id: ruff
args: [--fix]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.8
hooks:
- id: clang-format
types_or: [c, c++, cuda]
110 changes: 57 additions & 53 deletions kernels-v1/attention-int8/attention_int8_cuda/attention_int8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,59 +113,63 @@ __device__ __forceinline__ void pad_Q_rows(int8_t *Q_i8, int q_size, int tid) {
// ============================================================================
// Main kernel
// ============================================================================
template <int HEAD_DIM, bool CAUSAL>
__global__ void __launch_bounds__(THREADS, OccHint<HEAD_DIM>::MIN_BLOCKS)
int8_attention_kernel(const float16 *__restrict__ Q,
const float16 *__restrict__ K,
const float16 *__restrict__ V,
float16 *__restrict__ O,
const float *__restrict__ timestep_scales,
int64_t timestep, int B, int H, int kv_H, int N) {
constexpr int BQ = BlockConfig<HEAD_DIM>::BQ;
constexpr int BK = BlockConfig<HEAD_DIM>::BK;

static_assert(BQ % 16 == 0, "BQ must be multiple of 16");
static_assert(BK % 16 == 0, "BK must be multiple of 16");
static_assert(HEAD_DIM % 16 == 0, "HEAD_DIM must be multiple of 16");

const int b = blockIdx.x;
const int h = blockIdx.y;
const int q_tile = blockIdx.z;
const int tid = threadIdx.x;
const int wid = tid >> 5;

const int q_start = q_tile * BQ;
if (q_start >= N)
return;
const int q_size = min(BQ, N - q_start);

const int kv_h = h % kv_H;

const size_t q_off = ((size_t)b * H + h) * N * HEAD_DIM;
const size_t kv_off = ((size_t)b * kv_H + kv_h) * N * HEAD_DIM;

const float16 *Q_head = Q + q_off;
const float16 *K_head = K + kv_off;
const float16 *V_head = V + kv_off;
float16 *O_head = O + q_off;

// Shared memory layout
extern __shared__ char smem[];

int8_t *Q_i8 = reinterpret_cast<int8_t *>(smem);
int8_t *K_i8_T = Q_i8 + BQ * HEAD_DIM;
float16 *V_tile = reinterpret_cast<float16 *>(K_i8_T + HEAD_DIM * BK);
int32_t *QK_i32 = reinterpret_cast<int32_t *>(V_tile + BK * HEAD_DIM);
float *warp_scr = reinterpret_cast<float *>(QK_i32 + BQ * BK);
float *row_max = warp_scr + WARPS;
float *row_sum = row_max + BQ;
float *out_acc = row_sum + BQ;

// Q scale [F1][F4]
const float ts = timestep_scales ? timestep_scales[timestep] : 1.f;

float lqmax = 0.f;
for (int i = tid; i < q_size * HEAD_DIM; i += THREADS) {
template<int HEAD_DIM, bool CAUSAL>
__global__ void
__launch_bounds__(THREADS, OccHint<HEAD_DIM>::MIN_BLOCKS)
int8_attention_kernel(
const float16* __restrict__ Q,
const float16* __restrict__ K,
const float16* __restrict__ V,
float16* __restrict__ O,
const float* __restrict__ timestep_scales,
int64_t timestep,
int B, int H, int kv_H, int N)
{
constexpr int BQ = BlockConfig<HEAD_DIM>::BQ;
constexpr int BK = BlockConfig<HEAD_DIM>::BK;

static_assert(BQ % 16 == 0, "BQ must be multiple of 16");
static_assert(BK % 16 == 0, "BK must be multiple of 16");
static_assert(HEAD_DIM % 16 == 0, "HEAD_DIM must be multiple of 16");

const int b = blockIdx.x;
const int h = blockIdx.y;
const int q_tile = blockIdx.z;
const int tid = threadIdx.x;
const int wid = tid >> 5;

const int q_start = q_tile * BQ;
if (q_start >= N) return;
const int q_size = min(BQ, N - q_start);

const int kv_h = h % kv_H;

const size_t q_off = ((size_t)b * H + h) * N * HEAD_DIM;
const size_t kv_off = ((size_t)b * kv_H + kv_h) * N * HEAD_DIM;

const float16* Q_head = Q + q_off;
const float16* K_head = K + kv_off;
const float16* V_head = V + kv_off;
float16* O_head = O + q_off;

// Shared memory layout
extern __shared__ char smem[];

int8_t* Q_i8 = reinterpret_cast<int8_t*>(smem);
int8_t* K_i8_T = Q_i8 + BQ * HEAD_DIM;
float16* V_tile = reinterpret_cast<float16*>(K_i8_T + HEAD_DIM * BK);
int32_t* QK_i32 = reinterpret_cast<int32_t*>(V_tile + BK * HEAD_DIM);
float* warp_scr = reinterpret_cast<float*>(QK_i32 + BQ * BK);
float* row_max = warp_scr + WARPS;
float* row_sum = row_max + BQ;
float* out_acc = row_sum + BQ;

// Q scale [F1][F4]
const float ts = timestep_scales ? timestep_scales[timestep] : 1.f;

float lqmax = 0.f;
for (int i = tid; i < q_size * HEAD_DIM; i += THREADS)
{
int qi = i / HEAD_DIM;
int di = i % HEAD_DIM;
lqmax = fmaxf(lqmax,
Expand Down
16 changes: 9 additions & 7 deletions kernels-v1/attention-int8/torch-ext/torch_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ void validate_shapes(const torch::Tensor &Q, const torch::Tensor &K,
void validate_head_dim(int64_t D) {
TORCH_CHECK(D % 16 == 0, "HEAD_DIM must be multiple of 16, got ", D);

TORCH_CHECK(is_head_dim_supported(D), "Unsupported HEAD_DIM=", D,
". Supported: 32, 64, 80, 96, 128, 160, 256");
TORCH_CHECK(is_head_dim_supported(D),
"Unsupported HEAD_DIM=", D,
". Supported: 32, 64, 80, 96, 128, 160, 256");
}

void validate_kv_constraint(int64_t H, int64_t kv_H) {
Expand Down Expand Up @@ -109,9 +110,9 @@ void validate_timestep_scales(const c10::optional<torch::Tensor> &ts,
TORCH_CHECK(t.dim() == 1, "timestep_scales must be 1D, got shape ",
t.sizes());

TORCH_CHECK(t.size(0) == batch_size,
"timestep_scales batch size mismatch: expected ", batch_size,
", got ", t.size(0));
TORCH_CHECK(t.size(0) == batch_size,
"timestep_scales batch size mismatch: expected ", batch_size,
", got ", t.size(0));

TORCH_CHECK(timestep >= 0, "timestep must be >= 0, got ", timestep);
}
Expand Down Expand Up @@ -226,6 +227,7 @@ TORCH_LIBRARY(int8_attn, m) {
"Tensor? timestep_scales=None,"
"int timestep=0,"
"bool causal=False"
") -> Tensor");
m.impl("int8_attention_forward", torch::kCUDA, &int8_attention_forward);
") -> Tensor"
);
m.impl("int8_attention_forward", torch::kCUDA, &int8_attention_forward);
}