Conversation

@haricot
Contributor

@haricot haricot commented Jan 7, 2025

tested and works with:

nvidia-smi   --query-gpu="compute_cap"  --format=csv
compute_cap
6.1
cargo run -F "cuda,cudnn" --example llama  --release -- --model-id meta-llama/Llama-3.2-1B-Instruct --temperature 0.1 --which v32-1b-instruct --seed 42 --dtype bf16 --prompt

related EricLBuehler#57

@LaurentMazare
Collaborator

Shouldn't the new code be in some else blocks for the #if __CUDA_ARCH__ >= 800 blocks that contain the proper bf16 implementation?
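
Concretely, the structure being suggested looks something like this (an illustrative sketch, not the PR's actual code; `badd` is a made-up name):

```
#include <cuda_bf16.h>

// Illustrative guard structure: native bf16 intrinsic on sm_80+, float
// round-trip emulation in the else branch on older archs.
__device__ __nv_bfloat16 badd(__nv_bfloat16 a, __nv_bfloat16 b) {
#if __CUDA_ARCH__ >= 800
    return __hadd(a, b);  // native bf16 add
#else
    // bf16<->float conversions are available on all archs in recent toolkits;
    // only the bf16 arithmetic intrinsics require sm_80+.
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b));
#endif
}
```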

@haricot
Contributor Author

haricot commented Jan 7, 2025

You are certainly right. In my case I never hit the __CUDA_ARCH__ >= 800 condition in this control flow, but this could well cause a CUDA function error where a proper implementation is already present. I will review this.

@LaurentMazare
Collaborator

I think the current code is likely to result in lots of compile failures with cuda compute cap >= 8.0.
Would it work to just replace the checks for compute cap 800 with 530? If you can make the changes on an old cuda toolkit, I can also test them on hardware that is past 800.

@haricot
Contributor Author

haricot commented Jan 16, 2025

I hope the latest additions will make this work for __CUDA_ARCH__ >= 800.
After that, I think it makes sense to add more tests, aiming if possible at full support of the BF16 type for __CUDA_ARCH__ >= 530.
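
The idea would be to give every missing intrinsic the same float round-trip treatment, e.g. (hypothetical sketch, invented function names, not the PR's code):

```
#include <cuda_bf16.h>

// Hypothetical emulation of further bf16 ops for 530 <= __CUDA_ARCH__ < 800,
// where the native intrinsics are unavailable.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__device__ __nv_bfloat16 bfma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
    return __float2bfloat16(
        fmaf(__bfloat162float(a), __bfloat162float(b), __bfloat162float(c)));
}
__device__ __nv_bfloat16 bmax(__nv_bfloat16 a, __nv_bfloat16 b) {
    return __float2bfloat16(fmaxf(__bfloat162float(a), __bfloat162float(b)));
}
#endif
```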

@haricot haricot marked this pull request as draft January 18, 2025 09:31
@haricot haricot force-pushed the bf16_candle branch 8 times, most recently from ec25c81 to 09e3d0b Compare January 23, 2025 15:58
@haricot haricot changed the title from "add cuda fallback bf16 for compute_cap < 8.0" to "add cuda fallback bf16 for compute_cap >=530 <800" Jan 23, 2025
@haricot
Contributor Author

haricot commented Oct 12, 2025

With this, tokens/s is similar in f16 and bf16 (for short sentences), and the bf16 results are identical to those of an original bf16 model despite the fallbacks: numerical fidelity is preserved and model behavior is unchanged.

Tests:
I used a macro_rules! assert_tensor macro to run the test functions across the different dtypes and verify that everything works. I did not think it was necessary to include it in the PR; as an example, it looked like this:

ex: assert_tensor! avg_pool2d (f32, bf16, f16)

```
assert_tensor!(dev, (t:Tensor), (res):(Tensor)|{
        t.avg_pool2d(3)?.squeeze(0)?
    },[
    res/eq_round|(F32:to_vec3_round, each:4)(
        [[[0.085]], [[0.0078]]]
    ),
    res/eq_max|(F32-F16=>f32, cpu:4e-5, cuda:11e-5, metal:11e-5),
    res/eq_max|(F32-BF16=>f32, cpu:3e-5, cuda:5e-5, metal:5e-5),
]);
```

Issue:
Without this PR, candle no longer compiles on my old laptop, which is stuck on CUDA 12.9.

compile log
```
Caused by:
  process didn't exit successfully: `/tmp/candle/target/release/build/candle-kernels-ce3a0757978ddb55/build-script-build` (exit status: 101)
  --- stdout
  cargo:rerun-if-changed=build.rs
  cargo:rerun-if-changed=src/compatibility.cuh
  cargo:rerun-if-changed=src/cuda_utils.cuh
  cargo:rerun-if-changed=src/binary_op_macros.cuh
  cargo:info=["/usr", "/usr/local/cuda", "/opt/cuda", "/usr/lib/cuda", "C:/Program Files/NVIDIA GPU Computing Toolkit", "C:/CUDA"]
  cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP
  cargo:rustc-env=CUDA_COMPUTE_CAP=61
  cargo:info=Builder { cuda_root: Some("/opt/cuda"), kernel_paths: ["src/affine.cu", "src/binary.cu", "src/cast.cu", "src/conv.cu", "src/fill.cu", "src/indexing.cu", "src/quantized.cu", "src/reduce.cu", "src/sort.cu", "src/ternary.cu", "src/unary.cu"], watch: [], include_paths: ["src/binary_op_macros.cuh", "src/compatibility.cuh", "src/cuda_utils.cuh"], compute_cap: Some(61), out_dir: "/tmp/candle/target/release/build/candle-kernels-4037943e9d714c54/out", extra_args: [] }
  cargo:rustc-env=CUDA_INCLUDE_DIR=/opt/cuda/include
  cargo:rerun-if-changed=src/binary_op_macros.cuh
  cargo:rerun-if-changed=src/compatibility.cuh
  cargo:rerun-if-changed=src/cuda_utils.cuh
  cargo:rerun-if-env-changed=NVCC_CCBIN
  cargo:rerun-if-changed=src/cast.cu
  cargo:rerun-if-changed=src/indexing.cu
  cargo:rerun-if-changed=src/affine.cu
  cargo:rerun-if-changed=src/conv.cu
  cargo:rerun-if-changed=src/binary.cu
  cargo:rerun-if-changed=src/sort.cu
  cargo:rerun-if-changed=src/quantized.cu
  cargo:rerun-if-changed=src/fill.cu
  cargo:rerun-if-changed=src/ternary.cu
  cargo:rerun-if-changed=src/reduce.cu
  cargo:rerun-if-changed=src/unary.cu

  --- stderr
  nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
  nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
  nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
  nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
  nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
  nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
  nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
  nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
  nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
  nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
  src/reduce.cu(612): error: no instance of overloaded function "atomicAdd" matches the argument list
              argument types are: (__half *, const __half)
    extern "C" __attribute__((global)) void sum_f16( const size_t numel, const size_t num_dims, const size_t num_sum_dims, const size_t *info, const __half *inp, __half *out) { const size_t *dims = info; const size_t *strides = info + num_dims; const size_t *sum_dims_l = info + 2 * num_dims; const size_t *sum_dims_s = info + 2 * num_dims + num_sum_dims; if (is_contiguous(num_dims, dims, strides)) { for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { size_t dst_index = i; for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { size_t stride = sum_dims_s[nd]; size_t pre = dst_index / stride; size_t post = dst_index % stride; dst_index = (pre / sum_dims_l[nd]) * stride + post; } atomicAdd(out + dst_index, inp[i]); } } else { for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { unsigned strided_i = get_strided_index(i, num_dims, dims, strides); size_t dst_index = i; for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { size_t stride = sum_dims_s[nd]; size_t pre = dst_index / stride; size_t post = dst_index % stride; dst_index = (pre / sum_dims_l[nd]) * stride + post; } atomicAdd(out + dst_index, inp[strided_i]); } } }
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             ^
  /opt/cuda/include/cuda_bf16.hpp(3802): note #3326-D: function "atomicAdd(__nv_bfloat162 *, __nv_bfloat162)" does not match because argument #1 does not match parameter
    static __attribute__((device)) __inline__ __nv_bfloat162 atomicAdd(__nv_bfloat162 *const address, const __nv_bfloat162 val)
                                                             ^
  /opt/cuda/include/cuda_fp16.hpp(3426): note #3326-D: function "atomicAdd(__half2 *, __half2)" does not match because argument #1 does not match parameter
    static __attribute__((device)) __inline__ __half2 atomicAdd(__half2 *const address, const __half2 val) {
                                                      ^
  /opt/cuda/include/sm_60_atomic_functions.hpp(292): note #3326-D: function "atomicAdd(double *, double)" does not match because argument #1 does not match parameter
    static __inline__ __attribute__((device)) double atomicAdd(double *address, double val)
                                                     ^
  /opt/cuda/include/sm_20_atomic_functions.hpp(82): note #3326-D: function "atomicAdd(float *, float)" does not match because argument #1 does not match parameter
    static __inline__ __attribute__((device)) float atomicAdd(float *address, float val)
                                                    ^
  /opt/cuda/include/device_atomic_functions.hpp(224): note #3326-D: function "atomicAdd(unsigned long long *, unsigned long long)" does not match because argument #1 does not match parameter
    static __inline__ __attribute__((device)) unsigned long long int atomicAdd(unsigned long long int *address, unsigned long long int val)
                                                                     ^
  /opt/cuda/include/device_atomic_functions.hpp(110): note #3326-D: function "atomicAdd(unsigned int *, unsigned int)" does not match because argument #1 does not match parameter
    static __inline__ __attribute__((device)) unsigned int atomicAdd(unsigned int *address, unsigned int val)
                                                           ^
  /opt/cuda/include/device_atomic_functions.hpp(105): note #3326-D: function "atomicAdd(int *, int)" does not match because argument #1 does not match parameter
    static __inline__ __attribute__((device)) int atomicAdd(int *address, int val)
                                                  ^

  src/reduce.cu(612): error: no instance of overloaded function "atomicAdd" matches the argument list
              argument types are: (__half *, const __half)
    extern "C" __attribute__((global)) void sum_f16( const size_t numel, const size_t num_dims, const size_t num_sum_dims, const size_t *info, const __half *inp, __half *out) { const size_t *dims = info; const size_t *strides = info + num_dims; const size_t *sum_dims_l = info + 2 * num_dims; const size_t *sum_dims_s = info + 2 * num_dims + num_sum_dims; if (is_contiguous(num_dims, dims, strides)) { for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { size_t dst_index = i; for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { size_t stride = sum_dims_s[nd]; size_t pre = dst_index / stride; size_t post = dst_index % stride; dst_index = (pre / sum_dims_l[nd]) * stride + post; } atomicAdd(out + dst_index, inp[i]); } } else { for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { unsigned strided_i = get_strided_index(i, num_dims, dims, strides); size_t dst_index = i; for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { size_t stride = sum_dims_s[nd]; size_t pre = dst_index / stride; size_t post = dst_index % stride; dst_index = (pre / sum_dims_l[nd]) * stride + post; } atomicAdd(out + dst_index, inp[strided_i]); } } }
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           ^
  /opt/cuda/include/cuda_bf16.hpp(3802): note #3326-D: function "atomicAdd(__nv_bfloat162 *, __nv_bfloat162)" does not match because argument #1 does not match parameter
    static __attribute__((device)) __inline__ __nv_bfloat162 atomicAdd(__nv_bfloat162 *const address, const __nv_bfloat162 val)
                                                             ^
  /opt/cuda/include/cuda_fp16.hpp(3426): note #3326-D: function "atomicAdd(__half2 *, __half2)" does not match because argument #1 does not match parameter
    static __attribute__((device)) __inline__ __half2 atomicAdd(__half2 *const address, const __half2 val) {
                                                      ^
  /opt/cuda/include/sm_60_atomic_functions.hpp(292): note #3326-D: function "atomicAdd(double *, double)" does not match because argument #1 does not match parameter
    static __inline__ __attribute__((device)) double atomicAdd(double *address, double val)
                                                     ^
  /opt/cuda/include/sm_20_atomic_functions.hpp(82): note #3326-D: function "atomicAdd(float *, float)" does not match because argument #1 does not match parameter
    static __inline__ __attribute__((device)) float atomicAdd(float *address, float val)
                                                    ^
  /opt/cuda/include/device_atomic_functions.hpp(224): note #3326-D: function "atomicAdd(unsigned long long *, unsigned long long)" does not match because argument #1 does not match parameter
    static __inline__ __attribute__((device)) unsigned long long int atomicAdd(unsigned long long int *address, unsigned long long int val)
                                                                     ^
  /opt/cuda/include/device_atomic_functions.hpp(110): note #3326-D: function "atomicAdd(unsigned int *, unsigned int)" does not match because argument #1 does not match parameter
    static __inline__ __attribute__((device)) unsigned int atomicAdd(unsigned int *address, unsigned int val)
                                                           ^
  /opt/cuda/include/device_atomic_functions.hpp(105): note #3326-D: function "atomicAdd(int *, int)" does not match because argument #1 does not match parameter
    static __inline__ __attribute__((device)) int atomicAdd(int *address, int val)
                                                  ^

  2 errors detected in the compilation of "src/reduce.cu".
  nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).

  thread 'main' panicked at /home/np/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/bindgen_cuda-0.1.5/src/lib.rs:391:13:
  nvcc error while compiling "src/reduce.cu":

  # CLI "nvcc" "--gpu-architecture=sm_61" "--ptx" "--default-stream" "per-thread" "--output-directory" "/tmp/candle/target/release/build/candle-kernels-4037943e9d714c54/out" "-Isrc" "-I/opt/cuda/include" "-allow-unsupported-compiler" "-ccbin" "/usr/bin/g++-14" "src/reduce.cu" 

  # stdout


  # stderr

  note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
```
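
The root cause is that the atomicAdd(__half *, __half) overload only exists for sm_70+; on older archs the usual workaround is a 32-bit CAS loop over the aligned word containing the half. A minimal sketch of that classic compatibility-header trick (not necessarily this PR's exact code):

```
#include <cuda_fp16.h>

// Emulate atomicAdd on __half with a 32-bit compare-and-swap loop over the
// aligned word that contains the 16-bit value, updating only that half.
__device__ void atomic_add_half(__half *address, __half val) {
    unsigned int *base = (unsigned int *)((size_t)address & ~(size_t)2);
    bool high = ((size_t)address & 2) != 0;  // upper or lower 16 bits?
    unsigned int old = *base, assumed;
    do {
        assumed = old;
        __half cur = __ushort_as_half(
            (unsigned short)(high ? (assumed >> 16) : (assumed & 0xffffu)));
        __half sum = __float2half(__half2float(cur) + __half2float(val));
        unsigned short s = __half_as_ushort(sum);
        unsigned int repl = high
            ? ((assumed & 0x0000ffffu) | ((unsigned int)s << 16))
            : ((assumed & 0xffff0000u) | s);
        old = atomicCAS(base, assumed, repl);
    } while (old != assumed);
}
```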

While resolving conflicts, I also added an FP8 fallback; related: #2989
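
For context, e4m3 FP8 has no hardware support below sm_89, but it can always be decoded in software (sign bit, 4 exponent bits with bias 7, 3 mantissa bits). A hypothetical sketch of the decode direction:

```
// Software decode of an e4m3 FP8 byte; pure integer/float math, so it runs
// on any arch. NaN handling (e4m3fn: exp == 15, mantissa == 7) is omitted.
__device__ float fp8_e4m3_to_float(unsigned char v) {
    float sign = (v & 0x80) ? -1.0f : 1.0f;
    int exp = (v >> 3) & 0xf;  // 4 exponent bits
    int man = v & 0x7;         // 3 mantissa bits
    if (exp == 0) {
        // subnormal: mantissa * 2^(1 - 7 - 3) = mantissa * 2^-9
        return sign * (float)man * 0.001953125f;
    }
    // normal: (1 + mantissa/8) * 2^(exp - 7)
    return sign * (1.0f + (float)man / 8.0f) * exp2f((float)(exp - 7));
}
```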

@haricot haricot changed the title from "add cuda fallback bf16 for compute_cap >=530 <800" to "CUDA: Add full BF16 and FP8 fallback support for compute >=530" Oct 12, 2025
@haricot haricot marked this pull request as ready for review October 12, 2025 11:48
@haricot haricot changed the title from "CUDA: Add full BF16 and FP8 fallback support for compute >=530" to "CUDA: Add (full BF16) and FP8 fallback support for compute >=530" Oct 15, 2025
@haricot haricot marked this pull request as draft October 30, 2025 12:27
@haricot haricot changed the title from "CUDA: Add (full BF16) and FP8 fallback support for compute >=530" to "CUDA: Add (full BF16) and FP8 support for CC < 700" Jan 26, 2026
@haricot haricot force-pushed the bf16_candle branch 3 times, most recently from ba56fc9 to d5f31ea Compare January 26, 2026 19:28
@haricot haricot changed the title from "CUDA: Add (full BF16) and FP8 support for CC < 700" to "CUDA: Add (full BF16) and FP8 support for CC < 700 and 800" Feb 1, 2026
@haricot haricot marked this pull request as ready for review February 1, 2026 17:48
@haricot haricot changed the title from "CUDA: Add (full BF16) and FP8 support for CC < 700 and 800" to "CUDA: Add (full BF16) and FP8 support for CC < 700-800" Feb 1, 2026
@haricot
Contributor Author

haricot commented Feb 1, 2026

I added ALLOW_LEGACY_BF16 and ALLOW_LEGACY_FP8, as well as moe_hfma2 (a wmma fallback solution; tests pass, but testing under real-world conditions is still needed) for CC < 700-800. Related: #3331
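
A sketch of how such an opt-in define could gate the emulated paths at compile time (illustrative wiring only; BF16_EMULATED is an invented name, and the PR's actual plumbing may differ):

```
// Emulation only compiles in when the build passes -DALLOW_LEGACY_BF16
// and the target arch lacks native bf16 arithmetic.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#  ifdef ALLOW_LEGACY_BF16
#    define BF16_EMULATED 1  // route bf16 ops through float round-trips
#  else
#    error "bf16 requires sm_80+; define ALLOW_LEGACY_BF16 to opt into emulation"
#  endif
#endif
```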

@haricot haricot changed the title from "CUDA: Add (full BF16) and FP8 support for CC < 700-800" to "Add CC register capabilities in Rust and in the CUDA builder and optional emulation bf16 fp8" Feb 12, 2026
@haricot haricot marked this pull request as draft February 12, 2026 21:30