diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu
index 7f9e061c38..650e5dd54d 100644
--- a/candle-kernels/src/affine.cu
+++ b/candle-kernels/src/affine.cu
@@ -28,7 +28,6 @@ extern "C" __global__ void FN_NAME( \
     } \
 } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
@@ -37,11 +36,8 @@ AFFINE_OP(__nv_bfloat16, affine_bf16, x * mul + add)
 #define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))
 
 AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add)))
-#endif
 
-#if __CUDA_ARCH__ >= 530
 AFFINE_OP(__half, affine_f16, x * mul + add)
-#endif
 
 AFFINE_OP(float, affine_f32, x * mul + add)
 AFFINE_OP(double, affine_f64, x * mul + add)
diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu
index 67dced483f..b9dd936870 100644
--- a/candle-kernels/src/binary.cu
+++ b/candle-kernels/src/binary.cu
@@ -1,7 +1,6 @@
 #include "binary_op_macros.cuh"
 #include
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
@@ -32,9 +31,7 @@ BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, lt_f8_e4m3, F8E4M3_TO_FLOAT(x) < F8E4M3_TO
 BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, le_f8_e4m3, F8E4M3_TO_FLOAT(x) <= F8E4M3_TO_FLOAT(y))
 BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, gt_f8_e4m3, F8E4M3_TO_FLOAT(x) > F8E4M3_TO_FLOAT(y))
 BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ge_f8_e4m3, F8E4M3_TO_FLOAT(x) >= F8E4M3_TO_FLOAT(y))
-#endif
 
-#if __CUDA_ARCH__ >= 530
 BINARY_OP(__half, badd_f16, x + y)
 BINARY_OP(__half, bdiv_f16, x / y)
 BINARY_OP(__half, bmul_f16, x * y)
@@ -47,7 +44,6 @@ BINARY_OP_OUT(__half, uint8_t, lt_f16, x < y)
 BINARY_OP_OUT(__half, uint8_t, le_f16, x <= y)
 BINARY_OP_OUT(__half, uint8_t, gt_f16, x > y)
 BINARY_OP_OUT(__half, uint8_t, ge_f16, x >= y)
-#endif
 
 BINARY_OP(float, badd_f32, x + y)
 BINARY_OP(double, badd_f64, x + y);
diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu
index bd2a9723a7..d2756c53c0 100644
--- a/candle-kernels/src/cast.cu
+++ b/candle-kernels/src/cast.cu
@@ -24,7 +24,6 @@ __device__ void cast_(
     }
 }
 
-#if __CUDA_ARCH__ >= 800
 #define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))
 
 template
@@ -71,7 +70,6 @@ __device__ void cast_fp8_into_(
         }
     }
 }
-#endif
 
 template
 __device__ void cast_through(
@@ -143,9 +141,9 @@ extern "C" __global__ void FN_NAME( \
     cast_through(numel, num_dims, info, inp, out); \
 } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
+
 CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
 CAST_OP(__nv_fp8_e4m3, __nv_fp8_e4m3, cast_f8_e4m3_f8_e4m3)
@@ -174,9 +172,7 @@ CAST_OP_FP8_INTO(int32_t, __nv_fp8_e4m3, cast_i32_f8_e4m3)
 CAST_OP_FP8(__nv_fp8_e4m3, int32_t, cast_f8_e4m3_i32)
 CAST_OP_FP8(__nv_fp8_e4m3, __nv_bfloat16, cast_f8_e4m3_bf16)
 CAST_OP_FP8_INTO(__nv_bfloat16, __nv_fp8_e4m3, cast_bf16_f8_e4m3)
-#endif
 
-#if __CUDA_ARCH__ >= 530
 CAST_OP(__half, __half, cast_f16_f16)
 
 CAST_THROUGH_OP(__half, uint8_t, float, cast_f16_u8)
@@ -189,7 +185,6 @@ CAST_OP(float, __half, cast_f32_f16)
 CAST_OP(double, __half, cast_f64_f16)
 CAST_OP(int32_t, __half, cast_i32_f16 )
 CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32)
-#endif
 
 CAST_OP(uint32_t, uint32_t, cast_u32_u32)
 CAST_OP(uint32_t, uint8_t, cast_u32_u8 )
diff --git a/candle-kernels/src/compatibility.cuh b/candle-kernels/src/compatibility.cuh
index 73d8bc1bc4..18618841fe 100644
--- a/candle-kernels/src/compatibility.cuh
+++ b/candle-kernels/src/compatibility.cuh
@@ -39,21 +39,21 @@ __device__ double atomicAdd(double* address, double val) {
 
 // The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
 // Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
 __device__ __half atomicAdd(__half *address, __half val) {
-    // unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
-    // unsigned int old = *address_as_ui;
-    // unsigned int assumed;
-    // bool unaligned = (size_t) address & 2;
-    // do {
-    //     assumed = old;
-    //     unsigned int hsum;
-    //     hsum = unaligned ? (old >> 16) : (old & 0xffff);
-    //     hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
-    //     old = atomicCAS(address_as_ui, assumed,
-    //         unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
-    //     );
-
-    // } while (assumed != old);
-    // return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
+    unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+    bool unaligned = (size_t) address & 2;
+    do {
+        assumed = old;
+        unsigned int hsum;
+        hsum = unaligned ? (old >> 16) : (old & 0xffff);
+        hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
+        old = atomicCAS(address_as_ui, assumed,
+            unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
+        );
+
+    } while (assumed != old);
+    return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
 }
 #endif
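The compatibility.cuh hunk above re-enables a compare-and-swap emulation of `atomicAdd` for `__half`: the 32-bit word containing the target half is read, the relevant 16-bit lane is updated, and `atomicCAS` retries until no other thread has raced the write. As the file's own comment notes, a native `atomicAdd(__half*, __half)` only exists on compute capability 7.x and higher. The snippet below is a minimal standalone sketch of the same pattern, not part of the patch; the names `atomic_add_half_emulated` and `accumulate_f16` are invented for illustration.

```cuda
#include <cuda_fp16.h>

// Same CAS loop as compatibility.cuh: locate the 32-bit word that contains
// the target __half, add into the right 16-bit lane, and retry if another
// thread modified the word in between.
__device__ __half atomic_add_half_emulated(__half *address, __half val) {
    unsigned int *address_as_ui =
        (unsigned int *)((char *)address - ((size_t)address & 2));
    bool unaligned = (size_t)address & 2;   // true -> target is the upper 16 bits
    unsigned int old = *address_as_ui, assumed;
    do {
        assumed = old;
        unsigned short h = unaligned ? (old >> 16) : (old & 0xffff);
        h = __half_as_ushort(__ushort_as_half(h) + val);
        old = atomicCAS(address_as_ui, assumed,
                        unaligned ? (old & 0xffff) | ((unsigned int)h << 16)
                                  : (old & 0xffff0000) | h);
    } while (assumed != old);               // lost the race: reload and retry
    return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
}

// Toy consumer: every thread accumulates one element into *sum.
extern "C" __global__ void accumulate_f16(__half *sum, const __half *xs, size_t n) {
    size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) atomic_add_half_emulated(sum, xs[i]);
}
```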
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 436e8a9fd4..f8caafc147 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -691,7 +691,6 @@ extern "C" __global__ void FN_NAME( \
     upsample_nearest2d(w_out, h_out, w_scale, h_scale, info, src, dst); \
 } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_bf16.h"
 
 CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
@@ -716,9 +715,7 @@ COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
 // IM2COL_OP(__nv_fp8_e4m3, im2col_f8_e5m)
 // IM2COL1D_OP(__nv_fp8_e4m3, im2col1d_f8_e5m)
 // COL2IM1D_OP(__nv_fp8_e4m3, col2im1d_f8_e5m)
-#endif
 
-#if __CUDA_ARCH__ >= 530
 CONV1D_OP(__half, float, conv1d_f16)
 CONV2D_OP(__half, float, conv2d_f16)
 CONVT1D_OP(__half, float, conv_transpose1d_f16)
@@ -729,7 +726,6 @@ UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
 IM2COL_OP(__half, im2col_f16)
 IM2COL1D_OP(__half, im2col1d_f16)
 COL2IM1D_OP(__half, col2im1d_f16)
-#endif
 
 CONV1D_OP(float, float, conv1d_f32)
 CONV1D_OP(double, double, conv1d_f64)
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index 5325b71d67..e37d930c04 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -2,6 +2,9 @@
 #include
 #include
 
+#include "cuda_fp8.h"
+#include "cuda_bf16.h"
+
 // TODO: This is often used to check that the data is contiguous so that
 // kernels can be easily mapped. However this only returns true for row
 // major, if all the inputs are column major, we could apply the fast path
@@ -191,7 +194,6 @@ __device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a,
 __device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }
 __device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
 __device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
-#if __CUDA_ARCH__ >= 530
 __device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
 __device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
 __device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
@@ -210,11 +212,6 @@ __device__ __forceinline__ __half logg(__half a) { return hlog(a); }
 __device__ __forceinline__ __half expg(__half a) { return hexp(a); }
 __device__ __forceinline__ __half absg(__half a) { return __habs(a); }
 __device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); }
-#endif
-
-#if __CUDA_ARCH__ >= 800
-#include "cuda_fp8.h"
-#include "cuda_bf16.h"
 
 __device__ __forceinline__ __nv_bfloat16 powg(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b))); }
 __device__ __forceinline__ bool isnang(__nv_bfloat16 a) { return __hisnan(a); }
@@ -256,5 +253,3 @@ __device__ __forceinline__ __nv_fp8_e4m3 expg(__nv_fp8_e4m3 a) { return __nv_fp8
 __device__ __forceinline__ __nv_fp8_e4m3 absg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(fabsf(F8E4M3_TO_FLOAT(a))); }
 __device__ __forceinline__ __nv_fp8_e4m3 copysigng(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(copysignf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); }
-
-#endif
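The cuda_utils.cuh change above hoists the `cuda_fp8.h`/`cuda_bf16.h` includes to the top of the header and keeps the `__half`/`__nv_bfloat16` overloads of helpers such as `powg` available unconditionally. Those overloads exist so that a single templated kernel body can be instantiated for every element type. The sketch below illustrates that usage and is not part of the patch: the `pow_elemwise`/`upow_*` names are invented, and the `float` overload is assumed to match the one already present in the header.

```cuda
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Per-type overloads: half/bf16 have no device pow intrinsic, so compute in
// f32 and convert back, exactly as the overloads shown in the hunks above do.
__device__ __forceinline__ float powg(float a, float b) { return powf(a, b); }
__device__ __forceinline__ __half powg(__half a, __half b) {
    return __float2half(powf(__half2float(a), __half2float(b)));
}
__device__ __forceinline__ __nv_bfloat16 powg(__nv_bfloat16 a, __nv_bfloat16 b) {
    return __float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b)));
}

// One templated body; overload resolution picks the right powg for each T.
template <typename T>
__device__ void pow_elemwise(const T *in, T *out, const T param, const size_t n) {
    for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += (size_t)blockDim.x * gridDim.x) {
        out[i] = powg(in[i], param);
    }
}

// extern "C" wrappers give each instantiation a stable, unmangled symbol,
// mirroring how the *_OP macros in these files expose one kernel per dtype.
extern "C" __global__ void upow_f32(const float *in, float *out, const float p, const size_t n) { pow_elemwise(in, out, p, n); }
extern "C" __global__ void upow_f16(const __half *in, __half *out, const __half p, const size_t n) { pow_elemwise(in, out, p, n); }
extern "C" __global__ void upow_bf16(const __nv_bfloat16 *in, __nv_bfloat16 *out, const __nv_bfloat16 p, const size_t n) { pow_elemwise(in, out, p, n); }
```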
diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu
index eeea8d4cd4..ab14180b05 100644
--- a/candle-kernels/src/fill.cu
+++ b/candle-kernels/src/fill.cu
@@ -40,18 +40,15 @@ COPY2D_OP(int16_t, copy2d_i16)
 COPY2D_OP(int32_t, copy2d_i32)
 COPY2D_OP(int64_t, copy2d_i64)
 
-#if __CUDA_ARCH__ >= 530
-extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
-COPY2D_OP(__half, copy2d_f16)
-#endif
-
-#if __CUDA_ARCH__ >= 800
 #include <cuda_fp8.h>
 #include <cuda_bf16.h>
 
+extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
+COPY2D_OP(__half, copy2d_f16)
+
 extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
 COPY2D_OP(__nv_bfloat16, copy2d_bf16)
 
 extern "C" __global__ void fill_f8_e4m3(__nv_fp8_e4m3 *buf, __nv_fp8_e4m3 value, const size_t numel) { fill_with(buf, value, numel); }
 COPY2D_OP(__nv_fp8_e4m3, copy2d_f8_e4m3)
-#endif
+
diff --git a/candle-kernels/src/fused_rms_norm.cu b/candle-kernels/src/fused_rms_norm.cu
index f012e002ad..2100a69043 100644
--- a/candle-kernels/src/fused_rms_norm.cu
+++ b/candle-kernels/src/fused_rms_norm.cu
@@ -76,7 +76,5 @@ extern "C" __global__ void FN_NAME(\
 RMS_NORM_OP(rms_norm_f32, float)
 RMS_NORM_OP(rms_norm_f16, __half)
 
-#if __CUDA_ARCH__ >= 800
 #include <cuda_bf16.h>
-RMS_NORM_OP(rms_norm_bf16, __nv_bfloat16)
-#endif
\ No newline at end of file
+RMS_NORM_OP(rms_norm_bf16, __nv_bfloat16)
\ No newline at end of file
diff --git a/candle-kernels/src/fused_rope.cu b/candle-kernels/src/fused_rope.cu
index 9f7873cca7..872d447444 100644
--- a/candle-kernels/src/fused_rope.cu
+++ b/candle-kernels/src/fused_rope.cu
@@ -189,7 +189,6 @@ extern "C" __global__ void rotary_embedding_kernel_neox_f64(
   apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
 }
 
-#if __CUDA_ARCH__ >= 800
 #include <cuda_bf16.h>
 extern "C" __global__ void rotary_embedding_kernel_bf16(
   const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
@@ -228,4 +227,4 @@ extern "C" __global__ void rotary_embedding_kernel_neox_bf16(
   apply_rotary_embedding<__nv_bfloat16, true>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
 }
 
-#endif
\ No newline at end of file
+
diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu
index 9005fedaa5..963c1e0c02 100644
--- a/candle-kernels/src/indexing.cu
+++ b/candle-kernels/src/indexing.cu
@@ -99,7 +99,6 @@ __device__ void index_add(
     }
 }
 
-#if __CUDA_ARCH__ >= 800
 #define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))
 
 template
@@ -148,7 +147,6 @@ __device__ void index_add_f8(
         }
     }
 }
-#endif
 
 #define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
@@ -220,7 +218,6 @@ extern "C" __global__ void FN_NAME( \
 ) { scatter_add_f8(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
 
@@ -265,9 +262,7 @@ SA_OP_F8(__nv_fp8_e4m3, int32_t, sa_i32_f8_e4m3)
 SA_OP_F8(__nv_fp8_e4m3, int64_t, sa_i64_f8_e4m3)
 SA_OP_F8(__nv_fp8_e4m3, uint32_t, sa_u32_f8_e4m3)
 SA_OP_F8(__nv_fp8_e4m3, uint8_t, sa_u8_f8_e4m3)
-#endif
 
-#if __CUDA_ARCH__ >= 530
 IS_OP(__half, int16_t, is_i16_f16)
 IS_OP(__half, int32_t, is_i32_f16)
 IS_OP(__half, int64_t, is_i64_f16)
@@ -288,7 +283,6 @@ SA_OP(__half, int32_t, sa_i32_f16)
 SA_OP(__half, int64_t, sa_i64_f16)
 SA_OP(__half, uint32_t, sa_u32_f16)
 SA_OP(__half, uint8_t, sa_u8_f16)
-#endif
 
 IS_OP(float, int16_t, is_i16_f32)
 IS_OP(double, int16_t, is_i16_f64)
diff --git a/candle-kernels/src/kvconcat.cu b/candle-kernels/src/kvconcat.cu
index 9b594e63d0..bbf19e1a52 100644
--- a/candle-kernels/src/kvconcat.cu
+++ b/candle-kernels/src/kvconcat.cu
@@ -44,14 +44,10 @@ KVCONCAT_OP(uint8_t, kvconcat_u8)
 KVCONCAT_OP(double, kvconcat_f64)
 KVCONCAT_OP(float, kvconcat_f32)
 
-#if __CUDA_ARCH__ >= 530
-KVCONCAT_OP(__half, kvconcat_f16)
-#endif
-
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
-
 KVCONCAT_OP(__nv_bfloat16, kvconcat_bf16)
+
+KVCONCAT_OP(__half, kvconcat_f16)
+
 KVCONCAT_OP(__nv_fp8_e4m3, kvconcat_f8_e4m3)
-#endif
\ No newline at end of file
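A note on the fp8 kernels that recur throughout this patch (affine, cast, indexing, kvconcat, unary): `__nv_fp8_e4m3` carries no arithmetic operators, so every op widens to `float` through the `F8E4M3_TO_FLOAT` macro, computes there, and narrows back on store. The following self-contained sketch shows that pattern; the kernel name `scale_f8_e4m3` and its flat launch shape are invented for illustration, while the conversion macro is the one the patch defines.

```cuda
#include <cuda_fp8.h>
#include <cuda_fp16.h>

// fp8 storage -> half raw bits -> float; same macro as in affine.cu / cast.cu / indexing.cu.
#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))

// y[i] = x[i] * mul + add, computed in f32 and rounded back to e4m3 on store.
extern "C" __global__ void scale_f8_e4m3(const __nv_fp8_e4m3 *x, __nv_fp8_e4m3 *y,
                                         const __nv_fp8_e4m3 mul, const __nv_fp8_e4m3 add,
                                         const size_t n) {
    size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        float v = F8E4M3_TO_FLOAT(x[i]) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add);
        y[i] = __nv_fp8_e4m3(v);  // explicit constructor performs the e4m3 rounding
    }
}
```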
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 782457eb23..6013f45859 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -571,7 +571,6 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
     rope_thd(src, cos, sin, dst, b, t, h, d); \
   } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_bf16.h"
 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
 RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
@@ -587,16 +586,13 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
 // LAYERNORM_OP(__nv_fp8_e4m3, layernorm_fp8_e4m3)
 // ROPE_OP(__nv_fp8_e4m3, rope_fp8_e4m3, rope_i_fp8_e4m3, rope_thd_fp8_e4m3)
 // FAST_OP(__nv_fp8_e4m3, fast_min_fp8_e4m3, fast_max_fp8_e4m3, fast_argmin_fp8_e4m3, fast_argmax_fp8_e4m3, fast_sum_fp8_e4m3)
-#endif
 
-#if __CUDA_ARCH__ >= 530
 SOFTMAX_OP(__half, float, softmax_f16)
 RMSNORM_OP(__half, rmsnorm_f16)
 LAYERNORM_OP(__half, layernorm_f16)
 ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
 SUM_OP(__half, sum_f16)
 FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
-#endif
 
 SUM_OP(float, sum_f32)
 SUM_OP(double, sum_f64)
diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu
index 71ac5e381e..149e0045c7 100644
--- a/candle-kernels/src/sort.cu
+++ b/candle-kernels/src/sort.cu
@@ -73,17 +73,13 @@ extern "C" __global__ void asort_desc_##RUST_NAME( \
     k_argsort(x, dst, ncols, ncols_pad); \
 } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_bf16.h"
 ASORT_OP(__nv_bfloat16, bf16)
 
 // NOTE: No sort ops for f8
 // ASORT_OP(__nv_fp8_e4m3, fp8_e4m3)
-#endif
 
-#if __CUDA_ARCH__ >= 530
 ASORT_OP(__half, f16)
-#endif
 
 ASORT_OP(float, f32)
 ASORT_OP(double, f64)
diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu
index 81350dadd8..2f22a1e62d 100644
--- a/candle-kernels/src/ternary.cu
+++ b/candle-kernels/src/ternary.cu
@@ -32,7 +32,6 @@ extern "C" __global__ void FN_NAME( \
     } \
 } \
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
@@ -47,15 +46,12 @@ WHERE_OP(__nv_fp8_e4m3, int32_t, where_i32_fp8_e4m3)
 WHERE_OP(__nv_fp8_e4m3, int64_t, where_i64_fp8_e4m3)
 WHERE_OP(__nv_fp8_e4m3, uint32_t, where_u32_fp8_e4m3)
 WHERE_OP(__nv_fp8_e4m3, uint8_t, where_u8_fp8_e4m3)
-#endif
 
-#if __CUDA_ARCH__ >= 530
 WHERE_OP(__half, int16_t, where_i16_f16)
 WHERE_OP(__half, int32_t, where_i32_f16)
 WHERE_OP(__half, int64_t, where_i64_f16)
 WHERE_OP(__half, uint32_t, where_u32_f16)
 WHERE_OP(__half, uint8_t, where_u8_f16)
-#endif
 
 WHERE_OP(float, int16_t, where_i16_f32)
 WHERE_OP(double, int16_t, where_i16_f64)
diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu
index e95b1c1b22..808987de10 100644
--- a/candle-kernels/src/unary.cu
+++ b/candle-kernels/src/unary.cu
@@ -97,7 +97,6 @@ __device__ T sign_(T t) {
 }
 
 
-#if __CUDA_ARCH__ >= 800
 #include "cuda_fp8.h"
 #include "cuda_bf16.h"
@@ -152,9 +151,7 @@ UNARY_OP(__nv_fp8_e4m3, usilu_fp8_e4m3, __nv_fp8_e4m3(silu_fwd(F8E4M3_TO_FLOAT(x
 UNARY_OP1(__nv_fp8_e4m3, upowf_fp8_e4m3, powg(x, param))
 UNARY_OP(__nv_fp8_e4m3, usign_fp8_e4m3, __nv_fp8_e4m3(sign_(F8E4M3_TO_FLOAT(x))))
 UNARY_OP(__nv_fp8_e4m3, usigmoid_fp8_e4m3, __nv_fp8_e4m3(sigmoid_fwd(F8E4M3_TO_FLOAT(x))))
-#endif
 
-#if __CUDA_ARCH__ >= 530
 UNARY_OP(__half, ucopy_f16, x)
 UNARY_OP(__half, uneg_f16, -x)
 UNARY_OP(__half, urecip_f16, recipg(x))
@@ -179,7 +176,6 @@ UNARY_OP(__half, usilu_f16, silu_fwd(x))
 UNARY_OP1(__half, upowf_f16, powg(x, param))
 UNARY_OP(__half, usign_f16, sign_(x))
 UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x))
-#endif
 
 UNARY_OP(uint8_t, ucopy_u8, x)
 UNARY_OP(uint32_t, ucopy_u32, x)