Conversation

@isVoid isVoid commented Dec 26, 2025

Summary

Adds FP8 support in numba-cuda by introducing generated bindings for CUDA FP8 types and conversion intrinsics (including saturation + rounding-mode variants), wiring them into CUDA typing/target registries, and adding a comprehensive test suite. Also adds a public runtime check for FP8 hardware acceleration.

Motivation / Context

  • Enable first-class use of CUDA FP8 types (__nv_fp8_*) and conversion intrinsics from Numba CUDA kernels.
  • Provide a capability API (cuda.is_fp8_accelerated()) so users can branch on HW-accelerated FP8 vs software-emulated behavior.
  • Lock down behavior with thorough tests (constructors, conversions, NaNs, saturation, rounding modes).

What’s in this PR

FP8 bindings + API surface

  • Autogenerated bindings: numba_cuda/numba/cuda/_internal/cuda_fp8.py
  • Public registry shim: numba_cuda/numba/cuda/fp8.py (exports typing_registry + target_registry)
  • Toolkit headers vendored:
    • numba_cuda/numba/cuda/include/12/cuda_fp8.h
    • numba_cuda/numba/cuda/include/12/cuda_fp8.hpp
    • numba_cuda/numba/cuda/include/13/cuda_fp8.h
    • numba_cuda/numba/cuda/include/13/cuda_fp8.hpp
  • Binding generator config: configs/cuda_fp8.yml

Compiler integration

  • Registers FP8 typing/target lowering via numba_cuda/numba/cuda/target.py (fp8.typing_registry, fp8.target_registry).

Public capability API

  • New: numba_cuda/numba/cuda/api.py::is_fp8_accelerated()
  • Backed by:
    • numba_cuda/numba/cuda/cudadrv/driver.py::Device.accelerates_fp8 (CC >= 8.9)
    • numba_cuda/numba/cuda/simulator/api.py::is_fp8_accelerated() → always False
    • Simulator detect() output reports FP8 HW acceleration status.
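
The capability check reduces to a compute-capability comparison. Here is a minimal standalone sketch of that logic (the function name is a hypothetical stand-in for illustration; in the PR the check lives on the Device class as the accelerates_fp8 property):

```python
def accelerates_fp8(compute_capability):
    """Hypothetical standalone version of Device.accelerates_fp8.

    Ada Lovelace (8.9), Hopper (9.0), and newer have native FP8
    hardware paths; older devices fall back to software emulation.
    Compute capability is a (major, minor) tuple, so Python's
    lexicographic tuple comparison gives the right ordering.
    """
    return tuple(compute_capability) >= (8, 9)

# Ampere A100 (8.0) emulates FP8; Ada L4 (8.9) accelerates it.
assert not accelerates_fp8((8, 0))
assert accelerates_fp8((8, 9))
assert accelerates_fp8((9, 0))
```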

Tests

  • New: numba_cuda/numba/cuda/tests/cudapy/test_fp8_bindings.py
  • Covers:
    • Constructors from float/int types
    • Conversions to/from float types
    • NaN handling
    • Saturation (saturation_t)
    • Rounding-mode variants using cuda.bindings.runtime.cudaRoundMode (notably for E8M0 paths)
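
For readers unfamiliar with E8M0: unlike E5M2 and E4M3 it is a scale-only format with 8 exponent bits and no sign or mantissa. Assuming the bias-127 power-of-two encoding with 0xFF reserved for NaN (the OCP Microscaling convention that, to my understanding, CUDA's __nv_fp8_e8m0 follows), a rough illustrative decoder looks like this; it is not part of the binding's API:

```python
import math

def e8m0_decode(bits):
    """Illustrative E8M0 decoder (assumption: bias-127 power-of-two
    encoding, 0xFF = NaN). Returns the scale 2**(bits - 127)."""
    if bits == 0xFF:
        return math.nan
    return 2.0 ** (bits - 127)

assert e8m0_decode(127) == 1.0    # bias point: exponent 0
assert e8m0_decode(129) == 4.0    # 2**2
assert math.isnan(e8m0_decode(0xFF))
```

Because there is no mantissa, every representable value is an exact power of two, which is why the rounding-mode variants matter when converting arbitrary floats into E8M0.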

User-facing API changes

  • New: numba.cuda.is_fp8_accelerated() returns True when the device supports FP8 hardware acceleration; otherwise False.
  • New internal bindings: numba.cuda._internal.cuda_fp8 providing FP8 types and conversion intrinsics (available for advanced use).

Testing

Added test coverage via numba_cuda/numba/cuda/tests/cudapy/test_fp8_bindings.py.

python -m pytest -q numba_cuda/numba/cuda/tests/cudapy/test_fp8_bindings.py

Notes

  • numba_cuda/numba/cuda/_internal/cuda_fp8.py is autogenerated (see header for generator/version).
  • FP8 operations still work on older GPUs via software emulation, but cuda.is_fp8_accelerated() is the recommended guard for performance-sensitive paths.

Follow ups

  • Provide a "raw 8-bit" constructor from the uint8 type; I intend to combine this with raw bf16 construction in a single PR.
  • Better error handling (a clearer failure message for unsupported rounding modes)
  • High-level API design and documentation

copy-pr-bot bot commented Dec 26, 2025

Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually.


greptile-apps bot commented Dec 26, 2025

Greptile Summary

This PR successfully adds comprehensive FP8 type support to numba-cuda by introducing auto-generated bindings for CUDA FP8 types (E5M2, E4M3, E8M0), proper compiler integration via typing and target registries, and a public hardware-acceleration detection API. The implementation closely follows established patterns from existing fp16/bf16 support and includes thorough test coverage with proper edge case handling (NaN, saturation modes, type conversions).

Key Changes:

  • Auto-generated FP8 bindings module with three FP8 types and conversion intrinsics
  • Public capability API cuda.is_fp8_accelerated() for users to branch on HW acceleration (CC >= 8.9)
  • Device property accelerates_fp8 and simulator stub with appropriate defaults
  • Compiler integration via registry installation in typing and target contexts
  • Comprehensive test suite covering constructors, conversions, rounding modes, and special values
  • Proper CUDA header inclusion and binding configuration

Integration Quality:

  • Follows established numba-cuda patterns for custom CUDA types (consistent with bf16/fp16)
  • Clean separation between auto-generated internal bindings and public wrapper API
  • Simulator mode correctly reports FP8 as non-accelerated
  • Device detection output updated appropriately

Confidence Score: 5/5

  • This PR is safe to merge with high confidence - well-architected FP8 support following established patterns, comprehensive testing, and no identified issues.
  • Score reflects: (1) Auto-generated bindings validated against established fp16/bf16 patterns; (2) Proper registry integration in compiler pipeline with clean separation of concerns; (3) Hardware acceleration detection correctly implemented with appropriate compute capability threshold (8.9); (4) Simulator mode properly handles FP8 as non-accelerated; (5) Comprehensive test coverage with constructors, conversions, NaN handling, and saturation modes; (6) No breaking changes to existing APIs; (7) Configuration and headers properly vendored; (8) Device detection output appropriately updated.
  • No files require special attention. All changes are properly integrated and follow established patterns.

Important Files Changed

  • numba_cuda/numba/cuda/_internal/cuda_fp8.py: Auto-generated FP8 bindings with proper type definitions (fp8_e5m2, fp8_e4m3, fp8_e8m0), data models, and intrinsic function lowering using FunctionCallConv. Follows established patterns from fp16/bf16 modules.
  • numba_cuda/numba/cuda/api.py: Added is_fp8_accelerated() function that correctly delegates to the device's accelerates_fp8 property (CC >= 8.9). Consistent with the existing is_bfloat16_supported() pattern. Includes a proper docstring.
  • numba_cuda/numba/cuda/cudadrv/driver.py: Added accelerates_fp8 property to the Device class with the correct CC threshold (8.9). Properly documented with a note about software emulation on older devices.
  • numba_cuda/numba/cuda/target.py: Correctly imports the fp8 module and registers both typing and target registries in CUDATypingContext and CUDATargetContext. Follows the established pattern from fp16/bf16.
  • numba_cuda/numba/cuda/tests/cudapy/test_fp8_bindings.py: Comprehensive test coverage for FP8 types including constructors, conversions, NaN handling, and saturation modes. Tests are well-structured but lack some edge-case tests (e.g., overflow behavior, corner-case rounding modes). Proper use of cuda.jit and np.testing patterns.

@greptile-apps greptile-apps bot left a comment

Additional Comments (4)

  1. numba_cuda/numba/cuda/include/13/cuda_fp8.h, line 65 (link)

    syntax: typo in macro condition comment - should be 'CUDACC' not '_CUDACC'

  2. numba_cuda/numba/cuda/tests/cudapy/test_fp8_bindings.py, line 368 (link)

    logic: using np.array([float("nan")] * 9) creates array of NaN values but assertion may not behave as expected since NaN != NaN

  3. numba_cuda/numba/cuda/tests/cudapy/test_fp8_bindings.py, line 294 (link)

    logic: result array declared as uint64 but should be int64 or float64 since zero conversions include both int64(zero) and float32(zero)

  4. numba_cuda/numba/cuda/_internal/cuda_fp8.py, line 832-853 (link)

    logic: Another duplicate function definition for int8 conversion with different C++ type mapping (char vs signed char)

13 files reviewed, 4 comments


isVoid commented Dec 28, 2025

I addressed points 3 and 4.

  1. Will not address, since it's vendored code.
  2. numpy.testing.assert_array_equal does not raise when NaN appears at the same index in both arrays, even though elementwise comparison of nan == nan is false.
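
For reference, a small demonstration of the NaN behavior described in point 2 above: assert_array_equal treats NaNs at matching indices as equal, unlike the elementwise == operator.

```python
import numpy as np

a = np.array([1.0, float("nan"), 3.0])
b = np.array([1.0, float("nan"), 3.0])

# Elementwise comparison: NaN != NaN, so position 1 compares False.
assert not (a == b).all()

# assert_array_equal treats NaNs at the same index as equal: no error.
np.testing.assert_array_equal(a, b)

# A NaN paired with an ordinary number does raise, however.
c = np.array([1.0, 2.0, 3.0])
try:
    np.testing.assert_array_equal(a, c)
    raise RuntimeError("expected a mismatch")
except AssertionError:
    pass
```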

@greptile-apps greptile-apps bot left a comment

Additional Comments (3)

  1. numba_cuda/numba/cuda/tests/cudapy/test_fp8_bindings.py, line 242-250 (link)

    style: Commented-out test cases for E5M2 and E4M3 negative values should be removed or implemented if they're valid test scenarios. Are these commented test cases intentionally disabled due to known limitations or should they be implemented?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  2. numba_cuda/numba/cuda/tests/cudapy/test_fp8_bindings.py, line 368 (link)

    logic: Using np.array([float('nan')] * 9) for NaN comparison may not work reliably since NaN != NaN - consider using np.testing.assert_array_equal with np.isnan() checks instead

  3. numba_cuda/numba/cuda/_internal/cuda_fp8.py, line 784-825 (link)

    logic: Duplicate function definition detected - both _from___nv_fp8_e5m2_to_uint64_lower functions register the same @lower_cast(_type_fp8_e5m2, uint64) decorator but with different mangled names (_ZNK13__nv_fp8_e5m2cvmEv vs _ZNK13__nv_fp8_e5m2cvyEv). The second registration will override the first, potentially causing runtime issues. Is this intentional behavior where unsigned long and unsigned long long both map to uint64, or should these have different target types?

13 files reviewed, 3 comments


greptile-apps bot commented Dec 28, 2025

Greptile found no issues!

From now on, if a review finishes and we haven't found any issues, we will not post anything, but you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

isVoid commented Dec 28, 2025

/ok to test

isVoid commented Jan 8, 2026

/ok to test

@greptile-apps greptile-apps bot left a comment

Greptile Summary

This PR adds comprehensive FP8 type support to numba-cuda, introducing bindings for CUDA's native FP8 types (E5M2, E4M3, E8M0) along with conversion intrinsics and hardware acceleration detection.

Key Changes

  • New public API: cuda.is_fp8_accelerated() returns whether the current device has native FP8 hardware support (CC >= 8.9)
  • Autogenerated bindings: ~3900 lines of generated code providing type definitions, constructors, conversions, and intrinsics for all FP8 variants
  • Compiler integration: FP8 types registered in both typing and target contexts, enabling seamless use in JIT kernels
  • Comprehensive tests: 550+ lines covering constructors, type conversions, NaN handling, saturation modes, and rounding behaviors
  • Simulator support: FP8 API works in simulation mode (always returns False for hardware acceleration)
  • Header vendoring: CUDA Toolkit 12 and 13 FP8 headers included for binding generation

Implementation Quality

The implementation follows established patterns from existing FP16/BF16 support. The autogenerated bindings are excluded from linting (appropriate for generated code). Tests are thorough and cover edge cases including overflow behavior with saturation modes and rounding mode variants for E8M0 conversions.

The hardware detection logic correctly identifies Ada Lovelace and Hopper architectures (CC 8.9+) as having native FP8 acceleration, while older devices fall back to software emulation.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • Well-structured implementation following existing patterns for FP16/BF16, comprehensive test coverage, autogenerated bindings reduce manual error risk, and clear separation of concerns across files
  • No files require special attention

Important Files Changed

File Analysis

  • numba_cuda/numba/cuda/api.py (5/5): Adds is_fp8_accelerated() function to check FP8 hardware support, and updates detect() to display FP8 acceleration status
  • numba_cuda/numba/cuda/cudadrv/driver.py (5/5): Adds accelerates_fp8 property to the Device class that checks compute capability >= 8.9 for hardware FP8 acceleration
  • numba_cuda/numba/cuda/fp8.py (5/5): Simple registry shim that exports typing and target registries from autogenerated FP8 bindings
  • numba_cuda/numba/cuda/target.py (5/5): Integrates FP8 type registries into CUDA typing and target contexts for compiler support
  • configs/cuda_fp8.yml (5/5): Configuration file for the FP8 binding generator specifying types, data models, and code generation parameters
  • numba_cuda/numba/cuda/_internal/cuda_fp8.py (5/5): Autogenerated bindings (~3900 lines) for CUDA FP8 types (E5M2/E4M3/E8M0) and conversion intrinsics
  • numba_cuda/numba/cuda/tests/cudapy/test_fp8_bindings.py (5/5): Comprehensive test suite covering FP8 constructors, conversions, NaN handling, saturation modes, and rounding
  • numba_cuda/numba/cuda/tests/cudapy/test_fp8_e4m3_int8_debug.py (5/5): Debug test for fp8_e4m3 to int8 conversions with comprehensive assertions and diagnostic output

isVoid commented Jan 20, 2026

/ok to test

@gmarkall gmarkall left a comment

The changes look good to me. Obviously we need to get to the bottom of the A100 issues, but otherwise I think everything is good.

Edited: Originally I mentioned "L4 issues" as well, but then I saw that that failure was just a network error. Only the A100 shows issues, not the L4.

@gmarkall

I had a quick look at the FP8 headers, and I think that things are going wrong in __nv_cvt_fp8_to_halfraw, but I haven't yet looked into why... I think it should be possible to debug by compiling for the host, feeding in the -6 input value, and inspecting what's going on.

That said, I wonder if it's acceptable and more straightforward to only support FP8 on Hopper / Ada and beyond?

Implementation of the function, lifted from cuda_fp8.hpp (from line 472), for reference below:

__CUDA_HOSTDEVICE_FP8_DECL__ __half_raw
__nv_cvt_fp8_to_halfraw(const __nv_fp8_storage_t x,
                        const __nv_fp8_interpretation_t fp8_interpretation) {
    __half_raw res;
    res.x = 0U;
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
    res.x =
        __nv_cvt_fp8x2_to_halfraw2((__nv_fp8x2_storage_t)x, fp8_interpretation)
            .x;
#else
    unsigned short int ur = (unsigned short int)x;
    ur = (unsigned short int)(ur << 8U);

    if (fp8_interpretation == __NV_E5M2) {
        if ((ur & 0x7FFFU) > 0x7C00U) {
            /* If NaN, return canonical NaN */
            ur = 0x7FFFU;
        }
    } else { // __NV_E4M3
        unsigned short int sign = ur & 0x8000U;
        unsigned short int exponent =
            (unsigned short int)(((ur & 0x7800U) >> 1U) + 0x2000U);
        unsigned short int mantissa = (ur & 0x0700U) >> 1U;
        unsigned char absx = 0x7FU & (unsigned char)x;

        if (absx == 0x7FU) // NaN
        {
            ur = 0x7FFFU; // fp16 canonical NaN, discard sign
        } else if (exponent == 0x2000U) {
            // zero or denormal
            if (mantissa != 0U) {
                // normalize
                mantissa = (unsigned short int)(mantissa << 1U);
                while ((mantissa & 0x0400U) == 0U) {
                    mantissa = (unsigned short int)(mantissa << 1U);
                    exponent = (unsigned short int)(exponent - 0x0400U);
                }
                // discard implicit leading bit
                mantissa &= 0x03FFU;
            } else { // Zero
                exponent = 0U;
            }

            ur = (sign | exponent) | mantissa;
        } else {
            ur = (sign | exponent) | mantissa;
        }
    }
    res.x = ur;
#endif
    return res;
}
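
As a quick host-side cross-check of the emulation path above, here is a Python port of the E4M3 branch (a debugging sketch, not part of the PR). Feeding it the E4M3 encoding of -6 (0xCC: sign 1, exponent 9, mantissa 0b100) reproduces the expected fp16 bit pattern for -6.0 (0xC600), so any divergence on device would point at the compiled path rather than the algorithm:

```python
def e4m3_to_half_bits(x):
    """Python port of the software-emulation (pre-SM89) E4M3 branch of
    __nv_cvt_fp8_to_halfraw. Input: raw 8-bit E4M3 storage value.
    Output: raw 16-bit half-precision bit pattern."""
    ur = (x << 8) & 0xFFFF
    sign = ur & 0x8000
    exponent = (((ur & 0x7800) >> 1) + 0x2000) & 0xFFFF
    mantissa = (ur & 0x0700) >> 1
    absx = x & 0x7F

    if absx == 0x7F:            # NaN: return fp16 canonical NaN
        return 0x7FFF
    if exponent == 0x2000:      # zero or denormal
        if mantissa != 0:
            mantissa = (mantissa << 1) & 0xFFFF   # normalize
            while (mantissa & 0x0400) == 0:
                mantissa = (mantissa << 1) & 0xFFFF
                exponent = (exponent - 0x0400) & 0xFFFF
            mantissa &= 0x03FF  # discard implicit leading bit
        else:                   # zero
            exponent = 0
    return sign | exponent | mantissa

assert e4m3_to_half_bits(0xCC) == 0xC600  # -6.0 round-trips correctly
assert e4m3_to_half_bits(0x7F) == 0x7FFF  # NaN -> canonical half NaN
assert e4m3_to_half_bits(0x00) == 0x0000  # +0.0
assert e4m3_to_half_bits(0x01) == 0x1800  # smallest denormal, 2**-9
```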

@gmarkall gmarkall added the 4 - Waiting on author Waiting for author to respond to review label Jan 21, 2026
