From ea0fa9565a1a4ac669d8a3813156a7885ac631fc Mon Sep 17 00:00:00 2001
From: Nicolas PASCAL <344493+haricot@users.noreply.github.com>
Date: Wed, 8 Jan 2025 03:17:11 +0100
Subject: [PATCH 1/2] cuda fallback bf16 for compute_cap < 8.0 (#57)

---
 candle-kernels/src/binary.cu         |  5 +++++
 candle-kernels/src/cast.cu           |  8 ++++++++
 candle-kernels/src/compatibility.cuh | 30 ++++++++++++++--------------
 candle-kernels/src/cuda_utils.cuh    |  4 ++++
 candle-kernels/src/fill.cu           |  4 ++++
 candle-kernels/src/fused_rope.cu     |  5 +++--
 candle-kernels/src/indexing.cu       |  4 ++++
 candle-kernels/src/kvconcat.cu       |  4 ++++
 candle-kernels/src/reduce.cu         |  6 ++++++
 candle-kernels/src/ternary.cu        |  4 ++++
 candle-kernels/src/unary.cu          | 13 ++++++++++++
 11 files changed, 70 insertions(+), 17 deletions(-)

diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu
index 67dced483f..5c6f5ca5d2 100644
--- a/candle-kernels/src/binary.cu
+++ b/candle-kernels/src/binary.cu
@@ -35,6 +35,11 @@ BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ge_f8_e4m3, F8E4M3_TO_FLOAT(x) >= F8E4M3_T
 #endif
 
 #if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ < 800
+#include "cuda_bf16.h"
+BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
+BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
+#endif
 BINARY_OP(__half, badd_f16, x + y)
 BINARY_OP(__half, bdiv_f16, x / y)
 BINARY_OP(__half, bmul_f16, x * y)
diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu
index bd2a9723a7..010f40979a 100644
--- a/candle-kernels/src/cast.cu
+++ b/candle-kernels/src/cast.cu
@@ -177,6 +177,14 @@ CAST_OP_FP8_INTO(__nv_bfloat16, __nv_fp8_e4m3, cast_bf16_f8_e4m3)
 #endif
 
 #if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ < 800
+#include "cuda_bf16.h"
+CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
+CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
+CAST_OP(float, __nv_bfloat16, cast_f16_bf16)
+CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
+#endif
+
 CAST_OP(__half, __half, cast_f16_f16)
 
 CAST_THROUGH_OP(__half, uint8_t, float, cast_f16_u8)
diff --git a/candle-kernels/src/compatibility.cuh b/candle-kernels/src/compatibility.cuh
index 73d8bc1bc4..18618841fe 100644
--- a/candle-kernels/src/compatibility.cuh
+++ b/candle-kernels/src/compatibility.cuh
@@ -39,21 +39,21 @@ __device__ double atomicAdd(double* address, double val) {
 // The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
 // Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
 __device__ __half atomicAdd(__half *address, __half val) {
-    // unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
-    // unsigned int old = *address_as_ui;
-    // unsigned int assumed;
-    // bool unaligned = (size_t) address & 2;
-    // do {
-    //     assumed = old;
-    //     unsigned int hsum;
-    //     hsum = unaligned ? (old >> 16) : (old & 0xffff);
-    //     hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
-    //     old = atomicCAS(address_as_ui, assumed,
-    //         unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
-    //     );
-
-    // } while (assumed != old);
-    // return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
+    unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+    bool unaligned = (size_t) address & 2;
+    do {
+        assumed = old;
+        unsigned int hsum;
+        hsum = unaligned ? (old >> 16) : (old & 0xffff);
+        hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
+        old = atomicCAS(address_as_ui, assumed,
+            unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
+        );
+
+    } while (assumed != old);
+    return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
 }
 #endif
 
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index 5325b71d67..f0ca68ea39 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -192,6 +192,10 @@ __device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a,
 __device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
 __device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
 #if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ < 800
+#include "cuda_bf16.h"
+__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
+#endif
 __device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
 __device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
 __device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu
index eeea8d4cd4..07ce877463 100644
--- a/candle-kernels/src/fill.cu
+++ b/candle-kernels/src/fill.cu
@@ -41,6 +41,10 @@ COPY2D_OP(int32_t, copy2d_i32)
 COPY2D_OP(int64_t, copy2d_i64)
 
 #if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ < 800
+#include <cuda_bf16.h>
+COPY2D_OP(__nv_bfloat16, copy2d_bf16)
+#endif
 extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
 COPY2D_OP(__half, copy2d_f16)
 #endif
diff --git a/candle-kernels/src/fused_rope.cu b/candle-kernels/src/fused_rope.cu
index 9f7873cca7..65e0087da2 100644
--- a/candle-kernels/src/fused_rope.cu
+++ b/candle-kernels/src/fused_rope.cu
@@ -189,7 +189,7 @@ extern "C" __global__ void rotary_embedding_kernel_neox_f64(
   apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
 }
 
-#if __CUDA_ARCH__ >= 800
+#if __CUDA_ARCH__ >= 530
 #include <cuda_bf16.h>
 extern "C" __global__ void rotary_embedding_kernel_bf16(
   const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
@@ -228,4 +228,5 @@ extern "C" __global__ void rotary_embedding_kernel_neox_bf16(
 
   apply_rotary_embedding<__nv_bfloat16, true>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
 }
-#endif
\ No newline at end of file
+#endif
+
diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu
index 9005fedaa5..ff90e78a1a 100644
--- a/candle-kernels/src/indexing.cu
+++ b/candle-kernels/src/indexing.cu
@@ -268,6 +268,10 @@ SA_OP_F8(__nv_fp8_e4m3, uint8_t, sa_u8_f8_e4m3)
 #endif
 
 #if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ < 800
+#include "cuda_bf16.h"
+IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
+#endif
 IS_OP(__half, int16_t, is_i16_f16)
 IS_OP(__half, int32_t, is_i32_f16)
 IS_OP(__half, int64_t, is_i64_f16)
diff --git a/candle-kernels/src/kvconcat.cu b/candle-kernels/src/kvconcat.cu
index 9b594e63d0..2fc01e2787 100644
--- a/candle-kernels/src/kvconcat.cu
+++ b/candle-kernels/src/kvconcat.cu
@@ -45,6 +45,10 @@ KVCONCAT_OP(uint8_t, kvconcat_u8)
 KVCONCAT_OP(double, kvconcat_f64)
 KVCONCAT_OP(float, kvconcat_f32)
 
 #if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ < 800
+#include "cuda_bf16.h"
+KVCONCAT_OP(__nv_bfloat16, kvconcat_bf16)
+#endif
 KVCONCAT_OP(__half, kvconcat_f16)
 #endif
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 782457eb23..d96bc2e99d 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -590,6 +590,12 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
 #endif
 
 #if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ < 800
+#include "cuda_bf16.h"
+ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
+SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
+RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
+#endif
 SOFTMAX_OP(__half, float, softmax_f16)
 RMSNORM_OP(__half, rmsnorm_f16)
 LAYERNORM_OP(__half, layernorm_f16)
diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu
index 81350dadd8..6051f00d52 100644
--- a/candle-kernels/src/ternary.cu
+++ b/candle-kernels/src/ternary.cu
@@ -50,6 +50,10 @@ WHERE_OP(__nv_fp8_e4m3, uint8_t, where_u8_fp8_e4m3)
 #endif
 
 #if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ < 800
+#include "cuda_bf16.h"
+WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)
+#endif
 WHERE_OP(__half, int16_t, where_i16_f16)
 WHERE_OP(__half, int32_t, where_i32_f16)
 WHERE_OP(__half, int64_t, where_i64_f16)
diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu
index e95b1c1b22..0792b43795 100644
--- a/candle-kernels/src/unary.cu
+++ b/candle-kernels/src/unary.cu
@@ -155,6 +155,19 @@ UNARY_OP(__nv_fp8_e4m3, usigmoid_fp8_e4m3, __nv_fp8_e4m3(sigmoid_fwd(F8E4M3_TO_F
 #endif
 
 #if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ < 800
+#include "cuda_bf16.h"
+template <typename T>
+__device__ __forceinline__ T silu_fwd_fallback(T x) {
+    const T one = T(1.0f);
+    const T neg_x = -x;
+    const T exp_neg_x = expg(neg_x);
+    return x / (one + exp_neg_x);
+}
+
+UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
+UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd_fallback(x))
+#endif
 UNARY_OP(__half, ucopy_f16, x)
 UNARY_OP(__half, uneg_f16, -x)
 UNARY_OP(__half, urecip_f16, recipg(x))

From 2042704a8c1f20a9bac0485a28ba1f3fa6f86549 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Tue, 7 Jan 2025 21:49:30 -0500
Subject: [PATCH 2/2] Better bf16 support

---
 candle-kernels/src/affine.cu         |  4 ----
 candle-kernels/src/binary.cu         |  9 ---------
 candle-kernels/src/cast.cu           | 15 +--------------
 candle-kernels/src/conv.cu           |  4 ----
 candle-kernels/src/cuda_utils.cuh    | 15 +++------------
 candle-kernels/src/fill.cu           | 13 +++----------
 candle-kernels/src/fused_rms_norm.cu |  4 +---
 candle-kernels/src/fused_rope.cu     |  2 --
 candle-kernels/src/indexing.cu       | 10 ----------
 candle-kernels/src/kvconcat.cu       | 12 ++----------
 candle-kernels/src/reduce.cu         | 10 ----------
 candle-kernels/src/sort.cu           |  4 ----
 candle-kernels/src/ternary.cu        |  8 --------
 candle-kernels/src/unary.cu          | 17 -----------------
 14 files changed, 10 insertions(+), 117 deletions(-)

diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu
index 7f9e061c38..650e5dd54d 100644
--- a/candle-kernels/src/affine.cu
+++ b/candle-kernels/src/affine.cu
@@ -28,7 +28,6 @@ extern "C" __global__ void FN_NAME( \
     } \
 } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
 
@@ -37,11 +36,8 @@ AFFINE_OP(__nv_bfloat16, affine_bf16, x * mul + add)
 
 #define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))
 AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add)))
-#endif
 
-#if __CUDA_ARCH__ >= 530
 AFFINE_OP(__half, affine_f16, x * mul + add)
-#endif
 
 AFFINE_OP(float, affine_f32, x * mul + add)
 AFFINE_OP(double, affine_f64, x * mul + add)
diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu
index 5c6f5ca5d2..b9dd936870 100644
--- a/candle-kernels/src/binary.cu
+++ b/candle-kernels/src/binary.cu
@@ -1,7 +1,6 @@
 #include "binary_op_macros.cuh"
 #include <stdint.h>
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
 
@@ -32,14 +31,7 @@ BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, lt_f8_e4m3, F8E4M3_TO_FLOAT(x) < F8E4M3_TO
 BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, le_f8_e4m3, F8E4M3_TO_FLOAT(x) <= F8E4M3_TO_FLOAT(y))
 BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, gt_f8_e4m3, F8E4M3_TO_FLOAT(x) > F8E4M3_TO_FLOAT(y))
 BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ge_f8_e4m3, F8E4M3_TO_FLOAT(x) >= F8E4M3_TO_FLOAT(y))
-#endif
 
-#if __CUDA_ARCH__ >= 530
-#if __CUDA_ARCH__ < 800
-#include "cuda_bf16.h"
-BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
-BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
-#endif
 BINARY_OP(__half, badd_f16, x + y)
 BINARY_OP(__half, bdiv_f16, x / y)
 BINARY_OP(__half, bmul_f16, x * y)
@@ -52,7 +44,6 @@ BINARY_OP_OUT(__half, uint8_t, lt_f16, x < y)
 BINARY_OP_OUT(__half, uint8_t, le_f16, x <= y)
 BINARY_OP_OUT(__half, uint8_t, gt_f16, x > y)
 BINARY_OP_OUT(__half, uint8_t, ge_f16, x >= y)
-#endif
 
 BINARY_OP(float, badd_f32, x + y)
 BINARY_OP(double, badd_f64, x + y);
diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu
index 010f40979a..d2756c53c0 100644
--- a/candle-kernels/src/cast.cu
+++ b/candle-kernels/src/cast.cu
@@ -24,7 +24,6 @@ __device__ void cast_(
     }
 }
 
-#if __CUDA_ARCH__ >= 800
 #define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))
 
 template
@@ -71,7 +70,6 @@ __device__ void cast_fp8_into_(
         }
     }
 }
-#endif
 
 template
 __device__ void cast_through(
@@ -143,9 +141,9 @@ extern "C" __global__ void FN_NAME( \
     cast_through(numel, num_dims, info, inp, out); \
 } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
+
 CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
 
 CAST_OP(__nv_fp8_e4m3, __nv_fp8_e4m3, cast_f8_e4m3_f8_e4m3)
@@ -174,16 +172,6 @@ CAST_OP_FP8_INTO(int32_t, __nv_fp8_e4m3, cast_i32_f8_e4m3)
 CAST_OP_FP8(__nv_fp8_e4m3, int32_t, cast_f8_e4m3_i32)
 CAST_OP_FP8(__nv_fp8_e4m3, __nv_bfloat16, cast_f8_e4m3_bf16)
 CAST_OP_FP8_INTO(__nv_bfloat16, __nv_fp8_e4m3, cast_bf16_f8_e4m3)
-#endif
-
-#if __CUDA_ARCH__ >= 530
-#if __CUDA_ARCH__ < 800
-#include "cuda_bf16.h"
-CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
-CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
-CAST_OP(float, __nv_bfloat16, cast_f16_bf16)
-CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
-#endif
 
 CAST_OP(__half, __half, cast_f16_f16)
 
@@ -197,7 +185,6 @@ CAST_OP(float, __half, cast_f32_f16)
 CAST_OP(double, __half, cast_f64_f16)
 CAST_OP(int32_t, __half, cast_i32_f16 )
 CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32)
-#endif
 
 CAST_OP(uint32_t, uint32_t, cast_u32_u32)
 CAST_OP(uint32_t, uint8_t, cast_u32_u8 )
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 436e8a9fd4..f8caafc147 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -691,7 +691,6 @@ extern "C" __global__ void FN_NAME( \
     upsample_nearest2d(w_out, h_out, w_scale, h_scale, info, src, dst); \
 } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_bf16.h"
 
 CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
@@ -716,9 +715,7 @@ COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
 // IM2COL_OP(__nv_fp8_e4m3, im2col_f8_e5m)
 // IM2COL1D_OP(__nv_fp8_e4m3, im2col1d_f8_e5m)
 // COL2IM1D_OP(__nv_fp8_e4m3, col2im1d_f8_e5m)
-#endif
 
-#if __CUDA_ARCH__ >= 530
 CONV1D_OP(__half, float, conv1d_f16)
 CONV2D_OP(__half, float, conv2d_f16)
 CONVT1D_OP(__half, float, conv_transpose1d_f16)
@@ -729,7 +726,6 @@ UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
 IM2COL_OP(__half, im2col_f16)
 IM2COL1D_OP(__half, im2col1d_f16)
 COL2IM1D_OP(__half, col2im1d_f16)
-#endif
 
 CONV1D_OP(float, float, conv1d_f32)
 CONV1D_OP(double, double, conv1d_f64)
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index f0ca68ea39..e37d930c04 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -2,6 +2,9 @@
 #include <stdint.h>
 #include <cmath>
 
+#include "cuda_fp8.h"
+#include "cuda_bf16.h"
+
 // TODO: This is often used to check that the data is contiguous so that
 // kernels can be easily mapped. However this only returns true for row
 // major, if all the inputs are column major, we could apply the fast path
@@ -191,11 +194,6 @@ __device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a,
 __device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }
 __device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
 __device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
-#if __CUDA_ARCH__ >= 530
-#if __CUDA_ARCH__ < 800
-#include "cuda_bf16.h"
-__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
-#endif
 __device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
 __device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
 __device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
@@ -214,11 +212,6 @@ __device__ __forceinline__ __half logg(__half a) { return hlog(a); }
 __device__ __forceinline__ __half expg(__half a) { return hexp(a); }
 __device__ __forceinline__ __half absg(__half a) { return __habs(a); }
 __device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); }
-#endif
-
-#if __CUDA_ARCH__ >= 800
-#include "cuda_fp8.h"
-#include "cuda_bf16.h"
 
 __device__ __forceinline__ __nv_bfloat16 powg(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b))); }
 __device__ __forceinline__ bool isnang(__nv_bfloat16 a) { return __hisnan(a); }
@@ -260,5 +253,3 @@ __device__ __forceinline__ __nv_fp8_e4m3 expg(__nv_fp8_e4m3 a) { return __nv_fp8
 
 __device__ __forceinline__ __nv_fp8_e4m3 absg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(fabsf(F8E4M3_TO_FLOAT(a))); }
 __device__ __forceinline__ __nv_fp8_e4m3 copysigng(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(copysignf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); }
-
-#endif
diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu
index 07ce877463..ab14180b05 100644
--- a/candle-kernels/src/fill.cu
+++ b/candle-kernels/src/fill.cu
@@ -40,22 +40,15 @@ COPY2D_OP(int16_t, copy2d_i16)
 COPY2D_OP(int32_t, copy2d_i32)
 COPY2D_OP(int64_t, copy2d_i64)
 
-#if __CUDA_ARCH__ >= 530
-#if __CUDA_ARCH__ < 800
 #include <cuda_bf16.h>
-COPY2D_OP(__nv_bfloat16, copy2d_bf16)
-#endif
+#include <cuda_fp8.h>
+
 extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
 COPY2D_OP(__half, copy2d_f16)
-#endif
-
-#if __CUDA_ARCH__ >= 800
-#include <cuda_fp8.h>
-#include <cuda_bf16.h>
 
 extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
 COPY2D_OP(__nv_bfloat16, copy2d_bf16)
 
 extern "C" __global__ void fill_f8_e4m3(__nv_fp8_e4m3 *buf, __nv_fp8_e4m3 value, const size_t numel) { fill_with(buf, value, numel); }
 COPY2D_OP(__nv_fp8_e4m3, copy2d_f8_e4m3)
-#endif
+
diff --git a/candle-kernels/src/fused_rms_norm.cu b/candle-kernels/src/fused_rms_norm.cu
index f012e002ad..2100a69043 100644
--- a/candle-kernels/src/fused_rms_norm.cu
+++ b/candle-kernels/src/fused_rms_norm.cu
@@ -76,7 +76,5 @@ extern "C" __global__ void FN_NAME(\
 
 RMS_NORM_OP(rms_norm_f32, float)
 RMS_NORM_OP(rms_norm_f16, __half)
-#if __CUDA_ARCH__ >= 800
 #include <cuda_bf16.h>
-RMS_NORM_OP(rms_norm_bf16, __nv_bfloat16)
-#endif
\ No newline at end of file
+RMS_NORM_OP(rms_norm_bf16, __nv_bfloat16)
\ No newline at end of file
diff --git a/candle-kernels/src/fused_rope.cu b/candle-kernels/src/fused_rope.cu
index 65e0087da2..872d447444 100644
--- a/candle-kernels/src/fused_rope.cu
+++ b/candle-kernels/src/fused_rope.cu
@@ -189,7 +189,6 @@ extern "C" __global__ void rotary_embedding_kernel_neox_f64(
   apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
 }
 
-#if __CUDA_ARCH__ >= 530
 #include <cuda_bf16.h>
 extern "C" __global__ void rotary_embedding_kernel_bf16(
   const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
@@ -228,5 +227,4 @@ extern "C" __global__ void rotary_embedding_kernel_neox_bf16(
 
   apply_rotary_embedding<__nv_bfloat16, true>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
 }
-#endif
 
diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu
index ff90e78a1a..963c1e0c02 100644
--- a/candle-kernels/src/indexing.cu
+++ b/candle-kernels/src/indexing.cu
@@ -99,7 +99,6 @@ __device__ void index_add(
     }
 }
 
-#if __CUDA_ARCH__ >= 800
 #define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))
 
 template
@@ -148,7 +147,6 @@ __device__ void index_add_f8(
         }
     }
 }
-#endif
 
 #define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
@@ -220,7 +218,6 @@ extern "C" __global__ void FN_NAME( \
 ) { scatter_add_f8(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
 
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
 
@@ -265,13 +262,7 @@ SA_OP_F8(__nv_fp8_e4m3, int32_t, sa_i32_f8_e4m3)
 SA_OP_F8(__nv_fp8_e4m3, int64_t, sa_i64_f8_e4m3)
 SA_OP_F8(__nv_fp8_e4m3, uint32_t, sa_u32_f8_e4m3)
 SA_OP_F8(__nv_fp8_e4m3, uint8_t, sa_u8_f8_e4m3)
-#endif
 
-#if __CUDA_ARCH__ >= 530
-#if __CUDA_ARCH__ < 800
-#include "cuda_bf16.h"
-IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
-#endif
 IS_OP(__half, int16_t, is_i16_f16)
 IS_OP(__half, int32_t, is_i32_f16)
 IS_OP(__half, int64_t, is_i64_f16)
@@ -292,7 +283,6 @@ SA_OP(__half, int32_t, sa_i32_f16)
 SA_OP(__half, int64_t, sa_i64_f16)
 SA_OP(__half, uint32_t, sa_u32_f16)
 SA_OP(__half, uint8_t, sa_u8_f16)
-#endif
 
 IS_OP(float, int16_t, is_i16_f32)
 IS_OP(double, int16_t, is_i16_f64)
diff --git a/candle-kernels/src/kvconcat.cu b/candle-kernels/src/kvconcat.cu
index 2fc01e2787..bbf19e1a52 100644
--- a/candle-kernels/src/kvconcat.cu
+++ b/candle-kernels/src/kvconcat.cu
@@ -44,18 +44,10 @@ KVCONCAT_OP(uint8_t, kvconcat_u8)
 KVCONCAT_OP(double, kvconcat_f64)
 KVCONCAT_OP(float, kvconcat_f32)
 
-#if __CUDA_ARCH__ >= 530
-#if __CUDA_ARCH__ < 800
+#include "cuda_fp8.h"
 #include "cuda_bf16.h"
 KVCONCAT_OP(__nv_bfloat16, kvconcat_bf16)
-#endif
-KVCONCAT_OP(__half, kvconcat_f16)
-#endif
 
-#if __CUDA_ARCH__ >= 800
-#include "cuda_fp8.h"
-#include "cuda_bf16.h"
+KVCONCAT_OP(__half, kvconcat_f16)
 
-KVCONCAT_OP(__nv_bfloat16, kvconcat_bf16)
 KVCONCAT_OP(__nv_fp8_e4m3, kvconcat_f8_e4m3)
-#endif
\ No newline at end of file
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index d96bc2e99d..6013f45859 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -571,7 +571,6 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
     rope_thd(src, cos, sin, dst, b, t, h, d); \
 } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_bf16.h"
 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
 RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
@@ -587,22 +586,13 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
 // LAYERNORM_OP(__nv_fp8_e4m3, layernorm_fp8_e4m3)
 // ROPE_OP(__nv_fp8_e4m3, rope_fp8_e4m3, rope_i_fp8_e4m3, rope_thd_fp8_e4m3)
 // FAST_OP(__nv_fp8_e4m3, fast_min_fp8_e4m3, fast_max_fp8_e4m3, fast_argmin_fp8_e4m3, fast_argmax_fp8_e4m3, fast_sum_fp8_e4m3)
-#endif
 
-#if __CUDA_ARCH__ >= 530
-#if __CUDA_ARCH__ < 800
-#include "cuda_bf16.h"
-ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
-SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
-RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
-#endif
 SOFTMAX_OP(__half, float, softmax_f16)
 RMSNORM_OP(__half, rmsnorm_f16)
 LAYERNORM_OP(__half, layernorm_f16)
 ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
 SUM_OP(__half, sum_f16)
 FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
-#endif
 
 SUM_OP(float, sum_f32)
 SUM_OP(double, sum_f64)
diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu
index 71ac5e381e..149e0045c7 100644
--- a/candle-kernels/src/sort.cu
+++ b/candle-kernels/src/sort.cu
@@ -73,17 +73,13 @@ extern "C" __global__ void asort_desc_##RUST_NAME( \
     k_argsort(x, dst, ncols, ncols_pad); \
 } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_bf16.h"
 ASORT_OP(__nv_bfloat16, bf16)
 
 // NOTE: No sort ops for f8
 // ASORT_OP(__nv_fp8_e4m3, fp8_e4m3)
-#endif
 
-#if __CUDA_ARCH__ >= 530
 ASORT_OP(__half, f16)
-#endif
 
 ASORT_OP(float, f32)
 ASORT_OP(double, f64)
diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu
index 6051f00d52..2f22a1e62d 100644
--- a/candle-kernels/src/ternary.cu
+++ b/candle-kernels/src/ternary.cu
@@ -32,7 +32,6 @@ extern "C" __global__ void FN_NAME( \
     } \
 } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
 
@@ -47,19 +46,12 @@ WHERE_OP(__nv_fp8_e4m3, int32_t, where_i32_fp8_e4m3)
 WHERE_OP(__nv_fp8_e4m3, int64_t, where_i64_fp8_e4m3)
 WHERE_OP(__nv_fp8_e4m3, uint32_t, where_u32_fp8_e4m3)
 WHERE_OP(__nv_fp8_e4m3, uint8_t, where_u8_fp8_e4m3)
-#endif
 
-#if __CUDA_ARCH__ >= 530
-#if __CUDA_ARCH__ < 800
-#include "cuda_bf16.h"
-WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)
-#endif
 WHERE_OP(__half, int16_t, where_i16_f16)
 WHERE_OP(__half, int32_t, where_i32_f16)
 WHERE_OP(__half, int64_t, where_i64_f16)
 WHERE_OP(__half, uint32_t, where_u32_f16)
 WHERE_OP(__half, uint8_t, where_u8_f16)
-#endif
 
 WHERE_OP(float, int16_t, where_i16_f32)
 WHERE_OP(double, int16_t, where_i16_f64)
diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu
index 0792b43795..808987de10 100644
--- a/candle-kernels/src/unary.cu
+++ b/candle-kernels/src/unary.cu
@@ -97,7 +97,6 @@ __device__ T sign_(T t) {
 }
 
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
 
@@ -152,22 +151,7 @@ UNARY_OP(__nv_fp8_e4m3, usilu_fp8_e4m3, __nv_fp8_e4m3(silu_fwd(F8E4M3_TO_FLOAT(x
 UNARY_OP1(__nv_fp8_e4m3, upowf_fp8_e4m3, powg(x, param))
 UNARY_OP(__nv_fp8_e4m3, usign_fp8_e4m3, __nv_fp8_e4m3(sign_(F8E4M3_TO_FLOAT(x))))
 UNARY_OP(__nv_fp8_e4m3, usigmoid_fp8_e4m3, __nv_fp8_e4m3(sigmoid_fwd(F8E4M3_TO_FLOAT(x))))
-#endif
 
-#if __CUDA_ARCH__ >= 530
-#if __CUDA_ARCH__ < 800
-#include "cuda_bf16.h"
-template <typename T>
-__device__ __forceinline__ T silu_fwd_fallback(T x) {
-    const T one = T(1.0f);
-    const T neg_x = -x;
-    const T exp_neg_x = expg(neg_x);
-    return x / (one + exp_neg_x);
-}
-
-UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
-UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd_fallback(x))
-#endif
 UNARY_OP(__half, ucopy_f16, x)
 UNARY_OP(__half, uneg_f16, -x)
 UNARY_OP(__half, urecip_f16, recipg(x))
@@ -192,7 +176,6 @@ UNARY_OP(__half, usilu_f16, silu_fwd(x))
 UNARY_OP1(__half, upowf_f16, powg(x, param))
 UNARY_OP(__half, usign_f16, sign_(x))
 UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x))
-#endif
 
 UNARY_OP(uint8_t, ucopy_u8, x)
 UNARY_OP(uint32_t, ucopy_u32, x)