3 changes: 3 additions & 0 deletions scripts/python_test.py
Collaborator: Let's leave these disabled for now; the CPU version isn't implemented, so CI can't pass.

Collaborator: Actually it doesn't look like it's unimplemented... is it that the precision isn't good enough? _(:з」∠)_

Collaborator (author): It's probably because I enabled them all while testing the operators, which triggered the CI run.

@@ -39,6 +39,9 @@ def run_tests(args):
"topkrouter.py",
"topksoftmax.py",
"zeros.py",
# "paged_attention.py",
# "paged_caching.py",
# "paged_attention_prefill.py"
]:
result = subprocess.run(
f"python {test} {args} --debug", text=True, encoding="utf-8", shell=True
2 changes: 2 additions & 0 deletions src/infiniop/ops/layer_norm/nvidia/layer_norm_nvidia.cu
@@ -255,6 +255,8 @@ infiniStatus_t Descriptor::calculate(
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_2048)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
3 changes: 3 additions & 0 deletions src/infiniop/ops/layer_norm/operator.cc
@@ -174,6 +174,9 @@ infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
5 changes: 5 additions & 0 deletions src/infiniop/ops/logsoftmax/nvidia/logsoftmax_nvidia.cu
@@ -117,6 +117,11 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
y, x, _info.x_dtype, _info.y_dtype, _info.batch_size, _info.probs_size, _info.ndim, _info.seq_len,
_info.y_stride_b, _info.y_stride_p, _info.x_stride_b, _info.x_stride_p,
_info.y_stride_0, _info.y_stride_1, _info.x_stride_0, _info.x_stride_1, stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_2048>(
y, x, _info.x_dtype, _info.y_dtype, _info.batch_size, _info.probs_size, _info.ndim, _info.seq_len,
_info.y_stride_b, _info.y_stride_p, _info.x_stride_b, _info.x_stride_p,
_info.y_stride_0, _info.y_stride_1, _info.x_stride_0, _info.x_stride_1, stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(
y, x, _info.x_dtype, _info.y_dtype, _info.batch_size, _info.probs_size, _info.ndim, _info.seq_len,
8 changes: 4 additions & 4 deletions src/infiniop/ops/logsoftmax/operator.cc
@@ -40,7 +40,7 @@ __C infiniStatus_t infiniopCreateLogSoftmaxDescriptor(
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
@@ -73,7 +73,7 @@ __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescript
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// GET(INFINI_DEVICE_ILUVATAR, nvidia);
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
@@ -111,7 +111,7 @@ __C infiniStatus_t infiniopLogSoftmax(
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
@@ -144,7 +144,7 @@ __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescrip
DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia);
2 changes: 2 additions & 0 deletions src/infiniop/ops/lp_norm/nvidia/lp_norm_nvidia.cu
@@ -155,6 +155,8 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
CALCULATE_LP_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
CALCULATE_LP_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
CALCULATE_LP_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_2048)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CALCULATE_LP_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
} else {
49 changes: 49 additions & 0 deletions src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh
@@ -16,17 +16,66 @@ struct OnlineSoftmaxState {
}
};
__device__ __forceinline__ float warpReduceSum(float x) {
#if defined(ENABLE_ILUVATAR_API)
// Iluvatar may use warp size 64; __shfl_sync(0xffffffff) only covers 32 threads.
// Use shared-memory tree reduce for portability across warp sizes.
constexpr int kMaxWarps = 16;
__shared__ float _reduce_buf[kMaxWarps * 32];
const int lane = threadIdx.x & 31;
const int warp_id = threadIdx.x / 32;
_reduce_buf[threadIdx.x] = x;
__syncthreads();
for (int offset = 16; offset > 0; offset >>= 1) {
if (lane < offset) {
_reduce_buf[warp_id * 32 + lane] += _reduce_buf[warp_id * 32 + lane + offset];
}
__syncthreads();
}
return _reduce_buf[warp_id * 32];
#else
for (int offset = 16; offset > 0; offset >>= 1) {
x += __shfl_down_sync(0xffffffff, x, offset);
}
return x;
#endif
}

__device__ __forceinline__ float warpBroadcast(float x, int src_lane) {
#if defined(ENABLE_ILUVATAR_API)
__shared__ float _bcast_buf[16];
const int warp_id = threadIdx.x / 32;
if ((threadIdx.x & 31) == src_lane) {
_bcast_buf[warp_id] = x;
}
__syncthreads();
return _bcast_buf[warp_id];
#else
return __shfl_sync(0xffffffff, x, src_lane);
#endif
}

__device__ __forceinline__ float warpReduceMax(float x) {
#if defined(ENABLE_ILUVATAR_API)
__shared__ float _reduce_buf[16 * 32];
const int lane = threadIdx.x & 31;
const int warp_id = threadIdx.x / 32;
_reduce_buf[threadIdx.x] = x;
__syncthreads();
for (int offset = 16; offset > 0; offset >>= 1) {
if (lane < offset) {
float other = _reduce_buf[warp_id * 32 + lane + offset];
float cur = _reduce_buf[warp_id * 32 + lane];
_reduce_buf[warp_id * 32 + lane] = fmaxf(cur, other);
}
__syncthreads();
}
return _reduce_buf[warp_id * 32];
#else
for (int offset = 16; offset > 0; offset >>= 1) {
x = fmaxf(x, __shfl_down_sync(0xffffffff, x, offset));
}
return x;
#endif
}

__device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) {
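A note on the Iluvatar fallback added above (hedged, since only this header is shown): because warpReduceSum / warpReduceMax / warpBroadcast stage values through block-shared buffers and call __syncthreads(), every thread of the block has to reach them together, and blockDim.x must stay within the 16 * 32 scratch size. The sketch below is illustrative only; the kernel name and output layout are made up, and it assumes kernel_v2.cuh is included.

__global__ void rowSumDemo(const float *in, float *out, int row_len) {
    // One row per block; every thread participates in the strided load loop.
    float partial = 0.0f;
    for (int i = threadIdx.x; i < row_len; i += blockDim.x) {
        partial += in[blockIdx.x * row_len + i];
    }
    // Called unconditionally by the whole block, so the __syncthreads() inside
    // the Iluvatar shared-memory path cannot deadlock. Launch with blockDim.x
    // a multiple of 32 and no larger than 512.
    float warp_sum = op::paged_attention::cuda::warpReduceSum(partial);
    if ((threadIdx.x & 31) == 0) {
        atomicAdd(&out[blockIdx.x], warp_sum); // out assumed zero-initialized
    }
}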
32 changes: 18 additions & 14 deletions src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
@@ -1,7 +1,7 @@
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__

#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
@@ -194,8 +194,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
l = l * alpha + beta;
m = m_new;
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = op::paged_attention::cuda::warpBroadcast(beta, 0);

#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
@@ -233,7 +233,7 @@ __device__ void PagedAttentionPrefillWarpKernel(
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);

#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
@@ -411,8 +411,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
l = l * alpha + beta;
m = m_new;
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = op::paged_attention::cuda::warpBroadcast(beta, 0);

#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
@@ -450,7 +450,11 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
#ifdef ENABLE_ILUVATAR_API
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
#else
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#endif

#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
@@ -785,8 +789,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
l = l * alpha + beta;
m = m_new;
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = op::paged_attention::cuda::warpBroadcast(beta, 0);

#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
@@ -826,7 +830,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);

#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
@@ -1270,7 +1274,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);

#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
@@ -1961,8 +1965,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
l = l * alpha + beta;
m = m_new;
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
beta = op::paged_attention::cuda::warpBroadcast(beta, 0);

#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
@@ -2002,7 +2006,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);

#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
@@ -2131,7 +2135,7 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow(
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);

const int64_t q_token = q_start + static_cast<int64_t>(q_token_local);
half *out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
@@ -21,6 +21,11 @@ constexpr size_t ceilDiv(size_t a, size_t b) {
}

inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) {
// Iluvatar: use warp (stable). Users can override via INFINIOP_FLASH_PREFILL_KERNEL.
#ifdef ENABLE_ILUVATAR_API
(void)info;
return "warp";
#endif
// Heuristic auto-dispatch (v0.4):
// - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256.
// - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80).
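Side note on the comment above: it says the Iluvatar default can be overridden via INFINIOP_FLASH_PREFILL_KERNEL. Below is a minimal sketch of how such a lookup could be wired, assuming the override is read from the environment at runtime; the helper name is hypothetical and the operator's actual lookup may differ.

#include <cstdlib>

// Hypothetical helper (illustration only): prefer the user's
// INFINIOP_FLASH_PREFILL_KERNEL override, else the per-device default
// chosen above ("warp" on Iluvatar).
inline const char *prefill_kernel_or_default(const PagedAttentionPrefillInfo &info) {
    if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL")) {
        return env;
    }
    return default_prefill_kernel(info);
}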
23 changes: 20 additions & 3 deletions src/infiniop/ops/rearrange/nvidia/rearrange_kernel.cuh
Collaborator: Let me ask you about this change in a bit.

Collaborator (author): [image]

@@ -8,8 +8,8 @@
#define ARRAY_TYPE_SIZE size_t

// Coupled with DEFINE_KERNELS_BY_CONSTRAINT; modify both together
#define MAX_BLOCK_ARRAY_SIZE 5
#define MAX_GRID_ARRAY_SIZE 5
#define MAX_BLOCK_ARRAY_SIZE 6
#define MAX_GRID_ARRAY_SIZE 6

template <int ArrSize, typename ArrayType>
struct ArrayStruct {
@@ -185,32 +185,43 @@ struct Constraint {
DEFINE_REARRANGE_KERNEL(double4, constraint_num, block_array_size, grid_array_size)

// Coupled with MAX_BLOCK_ARRAY_SIZE and MAX_GRID_ARRAY_SIZE; modify both together
// Generate kernels for all combinations of 1-5 and 1-5
// Generate kernels for all combinations of 1-6 and 1-6
DEFINE_KERNELS_BY_CONSTRAINT(1, 1)
DEFINE_KERNELS_BY_CONSTRAINT(1, 2)
DEFINE_KERNELS_BY_CONSTRAINT(1, 3)
DEFINE_KERNELS_BY_CONSTRAINT(1, 4)
DEFINE_KERNELS_BY_CONSTRAINT(1, 5)
DEFINE_KERNELS_BY_CONSTRAINT(1, 6)
DEFINE_KERNELS_BY_CONSTRAINT(2, 1)
DEFINE_KERNELS_BY_CONSTRAINT(2, 2)
DEFINE_KERNELS_BY_CONSTRAINT(2, 3)
DEFINE_KERNELS_BY_CONSTRAINT(2, 4)
DEFINE_KERNELS_BY_CONSTRAINT(2, 5)
DEFINE_KERNELS_BY_CONSTRAINT(2, 6)
DEFINE_KERNELS_BY_CONSTRAINT(3, 1)
DEFINE_KERNELS_BY_CONSTRAINT(3, 2)
DEFINE_KERNELS_BY_CONSTRAINT(3, 3)
DEFINE_KERNELS_BY_CONSTRAINT(3, 4)
DEFINE_KERNELS_BY_CONSTRAINT(3, 5)
DEFINE_KERNELS_BY_CONSTRAINT(3, 6)
DEFINE_KERNELS_BY_CONSTRAINT(4, 1)
DEFINE_KERNELS_BY_CONSTRAINT(4, 2)
DEFINE_KERNELS_BY_CONSTRAINT(4, 3)
DEFINE_KERNELS_BY_CONSTRAINT(4, 4)
DEFINE_KERNELS_BY_CONSTRAINT(4, 5)
DEFINE_KERNELS_BY_CONSTRAINT(4, 6)
DEFINE_KERNELS_BY_CONSTRAINT(5, 1)
DEFINE_KERNELS_BY_CONSTRAINT(5, 2)
DEFINE_KERNELS_BY_CONSTRAINT(5, 3)
DEFINE_KERNELS_BY_CONSTRAINT(5, 4)
DEFINE_KERNELS_BY_CONSTRAINT(5, 5)
DEFINE_KERNELS_BY_CONSTRAINT(5, 6)
DEFINE_KERNELS_BY_CONSTRAINT(6, 1)
DEFINE_KERNELS_BY_CONSTRAINT(6, 2)
DEFINE_KERNELS_BY_CONSTRAINT(6, 3)
DEFINE_KERNELS_BY_CONSTRAINT(6, 4)
DEFINE_KERNELS_BY_CONSTRAINT(6, 5)
DEFINE_KERNELS_BY_CONSTRAINT(6, 6)

// Prepare the parameters struct
struct RearrangeParams {
@@ -294,6 +305,9 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
case 5: \
GET_REARRANGE_KERNEL_BY_CONSTRAINT(block_array_size, 5); \
break; \
case 6: \
GET_REARRANGE_KERNEL_BY_CONSTRAINT(block_array_size, 6); \
break; \
}

#define GET_REARRANGE_KERNEL_BY_BLOCK_NUM \
@@ -313,6 +327,9 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
case 5: \
GET_REARRANGE_KERNEL_BY_GRID_NUM(5); \
break; \
case 6: \
GET_REARRANGE_KERNEL_BY_GRID_NUM(6); \
break; \
}

GET_REARRANGE_KERNEL_BY_BLOCK_NUM
2 changes: 2 additions & 0 deletions src/infiniop/ops/scaled_mm/nvidia/int8_gemm_nvidia.cu
@@ -64,6 +64,7 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS;
}

#ifdef ENABLE_QY_API
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t Descriptor::launchKernel(const I8GemmInfo &info, Tdata *y, const Tdata *bias, const int8_t *x_packed, const float *x_scale, const int8_t *w_packed, const float *w_scale, void *stream_, void *workspace) const {
cudaStream_t stream = (cudaStream_t)stream_;
@@ -112,6 +113,7 @@ infiniStatus_t Descriptor::launchKernel(const I8GemmInfo &info, Tdata *y, const

return INFINI_STATUS_SUCCESS;
}
#endif

infiniStatus_t Descriptor::calculate(
void *workspace,