Skip to content

Conversation

@brandon-b-miller
Copy link
Contributor

This PR allows a user to register extra arg handling for any type they wish to pass to a kernel. Arg handlers must inherit from ArgHandlerBase and implement the required handling. Then, a user may pass it to register_arg_handler, potentially upon the import of a library that expects to pass a custom object to a numba kernel.

@copy-pr-bot
Copy link

copy-pr-bot bot commented Oct 3, 2025

Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@gmarkall
Copy link
Contributor

gmarkall commented Oct 3, 2025

Adding to the pre-populated list of extensions will increase the launch time for all kernels whether they use the extension or not - I'm not sure how great the impact will be, but I am concerned it could increase kernel launch time by a noticeable percentage. If we go with this route (as opposed to globally registered extensions being tried right before the NotImplementedError in _prepare_args) can we make some measurements to determine the impact on launch time, so we can understand the tradeoff we're making in ease of use vs. speed?

@gmarkall gmarkall added the 2 - In Progress Currently a work in progress label Oct 7, 2025
@ZzEeKkAa
Copy link
Contributor

ZzEeKkAa commented Oct 21, 2025

Regarding performance - could we add support through the type's method instead of global registry? It should not impact existing kernel launches and will have O(1) overhead for those types that actually need custom support.

UPD: I looked at the implementation - it is map lookup, so should be also O(1) overhead per argument. We should greatly improve performance by switching existing logic to register_arg_handler

@brandon-b-miller
Copy link
Contributor Author

/ok to test

@brandon-b-miller
Copy link
Contributor Author

Regarding performance - could we add support through the type's method instead of global registry? It should not impact existing kernel launches and will have O(1) overhead for those types that actually need custom support.

UPD: I looked at the implementation - it is map lookup, so should be also O(1) overhead per argument. We should greatly improve performance by switching existing logic to register_arg_handler

Agree. @gmarkall does the lookup approach alleviate your concern about perf?

@brandon-b-miller
Copy link
Contributor Author

/ok to test

@brandon-b-miller brandon-b-miller added 3 - Ready for Review Ready for review by team and removed 2 - In Progress Currently a work in progress labels Oct 27, 2025
@brandon-b-miller
Copy link
Contributor Author

/ok to test

Copy link
Contributor

@gmarkall gmarkall left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've had another look and made some comments towards simplifying the API - I think it was always overcomplicated, so it would be good not to propagate that needless complication into a new API.

I think the testing is a bit light, and have some suggestions.

I also note that there's no documentation, but I don't think we need to block this PR for it because there's absolutely no documentation for argument handling extensions at all already. Once we're in a place where we have an API we're happy with, then I think it will be good to document it.

One other thought, which I'm not suggesting is a good idea, but could be a basis for further thought on the API: for a short while I considered that perhaps typeof implementations should register arg handlers, so that it's not necessary to manually register them at all. However, typeof is in the critical path for kernel launch, so it should not be fiddling around registering things. I wonder if there's another way we can avoid explicit registration at all, more closely integrated in the way typing works (similar to how it's not necessary to tell the jit decorator what to link if the extension adds its required files to the link during lowering).

@gmarkall gmarkall added 4 - Waiting on author Waiting for author to respond to review and removed 3 - Ready for Review Ready for review by team labels Nov 18, 2025
@brandon-b-miller
Copy link
Contributor Author

/ok to test

def __init__(self, arr):
self.arr = arr

def numpy_array_wrapper_int32_arg_handler(ty, val, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here ty will be the Numba type from typeof, which will be types.int32[::1] in this case, but would also match any other type that comes out of numpy_array_wrapper_int32_typeof_impl. Would it make a more exemplar test to just return ty, val.arr, here?

(If there's a flaw in my logic / understanding, please do let me know)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, why make these specific to int32 types at all? Apart from the name, and the hardcoded types.int32[::1], it looks like these could be fully generic for a wrapper holding any dtype / shape NumPy array?

Copy link
Contributor Author

@brandon-b-miller brandon-b-miller Nov 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is test code, I prefer to spell things out in some cases. Here we're testing the effect of one or more handlers on one or more classes, it makes sense to me to write them separately and somewhat verbosely for readability as to what we're testing, even if things are repeated.

Copy link
Contributor

@gmarkall gmarkall left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for changing this - I think the API design looks much nicer now.

A couple of other notes:

  • I don't think we should allow registering multiple handlers for the same Python type. This seems like it would most likely be a user error, or cause user surprise.
  • There are some other notes on the diff on the tests.

@brandon-b-miller
Copy link
Contributor Author

/ok to test

@brandon-b-miller
Copy link
Contributor Author

/ok to test

ty, val = extension.prepare_args(
ty, val, stream=stream, retr=retr
)
elif handler := _arg_handlers.get(type(val)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure the if / elif logic here is correct. It seems to implement "If there are any extensions registered for this kernel, ignore all globally-registered extensions". But I could imagine a scenario where there is a globally-registered extension for one argument, and a "locally"-registered one for another argument - I think this situation is precluded from working in the current implementation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this section of the code might have been what pushed me towards having a class that the user supplies with a particular interface in the original implementation. The existing API requires we pass a class with a _prepare_args method. If we expect the same thing from the new API, it lets us assemble a single container of handlers we end up passing along the existing codepath, and we don't have to have deal with two kinds / can error if we find a duplicate along the way.

@brandon-b-miller
Copy link
Contributor Author

@gmarkall should we instead be changing the existing extensions interface to be more like the one we're moving towards in this PR now? I note that the existing extensions API does not require a typeof_impl, meaning it only works for the case where a signature is provided as well.

Either way we go, I think these two APIs should expect to be passed the same things, and there may be some flexibility due to the lack of existing docs. We can go with the function vs the class if we desire, but the more I think about it the more I think the typeof_impl requirement should be optional. It doesn't buy the person registering the handler much over just calling typeof_impl.register on import in their own library, but it does make it hard to normalize the extensions api with the register_arg_handler API. What do you think?

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Nov 21, 2025

Greptile Summary

This PR introduces a new register_arg_handler function that allows users to register custom argument handlers for kernel invocations. The implementation adds a global _arg_handlers dictionary and modifies _prepare_args to check for registered handlers before processing arguments.

Key changes:

  • Added global _arg_handlers registry for type-to-handler mappings
  • Modified _prepare_args method to support both registered handlers and passed extensions with proper conflict detection
  • Implemented mutation tracking to prevent both registered handlers and extensions from processing the same argument
  • Added collision detection in register_arg_handler to prevent multiple handlers for the same type
  • Comprehensive test suite covering basic registration, multiple handlers, collision detection, and extension interaction

Previous review concerns addressed:

  • The dispatcher now properly supports both registered handlers and passed extensions working together (previous elif issue resolved)
  • Collision checking happens before any registration occurs, preventing partial registration state
  • Test extension properly validates argument types before processing

Confidence Score: 5/5

  • This PR is safe to merge with well-designed architecture and thorough testing
  • The implementation correctly addresses all concerns from previous reviews. The collision detection logic properly checks all types before registering any, preventing partial registration. The mutation tracking mechanism cleanly prevents conflicts between registered handlers and passed extensions. The test suite comprehensively covers edge cases including collision detection and handler/extension interaction.
  • No files require special attention

Important Files Changed

Filename Overview
numba_cuda/numba/cuda/dispatcher.py Adds global arg handler registration with proper collision detection and extension compatibility through mutation tracking
numba_cuda/numba/cuda/tests/cudapy/test_extending.py Comprehensive test coverage for arg handler registration including collision detection and extension interaction

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines 2131 to 2137
for ty in handled_types:
if _arg_handlers.get(ty, None):
raise ValueError(
f"A handler for args of type {ty} is already registered."
)
typeof_impl.register(ty)(impl)
_arg_handlers[ty] = handler
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: If handled_types contains multiple types and a collision occurs on a later type, earlier types will already be registered in both _arg_handlers and typeof_impl, leaving partial registration state

Suggested change
for ty in handled_types:
if _arg_handlers.get(ty, None):
raise ValueError(
f"A handler for args of type {ty} is already registered."
)
typeof_impl.register(ty)(impl)
_arg_handlers[ty] = handler
# Check all types for collisions first
for ty in handled_types:
if _arg_handlers.get(ty, None):
raise ValueError(
f"A handler for args of type {ty} is already registered."
)
# If all checks pass, register all types
for ty in handled_types:
typeof_impl.register(ty)(impl)
_arg_handlers[ty] = handler

@brandon-b-miller
Copy link
Contributor Author

@gmarkall I've made some simplifications and would like to circle back to this if possible. I have:

  • For now removed the no-signature cases. The current registration API doesn't support this either, since it adds a significant number of test cases and complexity I would like to peel this off into a follow up.
  • I'd like to make a decision about if we want to register using a function or a class with _prepare_args. I agree with your previous comment that there's not been any state or inheritance used historically with these, but I do think it's important to either always pass classes or always pass functions. In the API we're adding, we now pass functions. Should the old API change? Should it be deprecated? Should we just be consistent and go back to passing classes in this PR?

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

Comment on lines 539 to 545
if self.extensions:
for extension in reversed(self.extensions):
ty, val = extension.prepare_args(
ty, val, stream=stream, retr=retr
)
elif handler := _arg_handlers.get(type(val)):
ty, val = handler(ty, val, stream=stream, retr=retr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: The elif prevents registered arg handlers from being used when self.extensions exists. This breaks the intended behavior in test_register_arg_handler_and_pass where both registered handlers and passed extensions should work together.

Change elif to a separate if block, or add fallback logic within the extensions loop:

Suggested change
if self.extensions:
for extension in reversed(self.extensions):
ty, val = extension.prepare_args(
ty, val, stream=stream, retr=retr
)
elif handler := _arg_handlers.get(type(val)):
ty, val = handler(ty, val, stream=stream, retr=retr)
if self.extensions:
for extension in reversed(self.extensions):
ty, val = extension.prepare_args(
ty, val, stream=stream, retr=retr
)
if handler := _arg_handlers.get(type(val)):
ty, val = handler(ty, val, stream=stream, retr=retr)

Comment on lines 2120 to 2125
for ty in handled_types:
if _arg_handlers.get(ty, None):
raise ValueError(
f"A handler for args of type {ty} is already registered."
)
_arg_handlers[ty] = handler
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: If handled_types contains multiple types and collision check fails on a later type, earlier types remain registered in _arg_handlers, leaving partial registration state.

Check all types for collisions before registering any:

Suggested change
for ty in handled_types:
if _arg_handlers.get(ty, None):
raise ValueError(
f"A handler for args of type {ty} is already registered."
)
_arg_handlers[ty] = handler
for ty in handled_types:
if _arg_handlers.get(ty, None):
raise ValueError(
f"A handler for args of type {ty} is already registered."
)
for ty in handled_types:
_arg_handlers[ty] = handler

Comment on lines 497 to 499
class PassedArgHandler_float32:
def prepare_args(self, ty, val, **kwargs):
return types.float32[::1], val.arr
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: This extension doesn't check the type of val before processing. With the current dispatcher logic using elif at dispatcher.py:544, registered handlers are never called when extensions exist, so this extension would be invoked for ALL argument types (including NumpyArrayWrapper_int32), causing it to incorrectly return float32[::1] type for int32 arrays.

Add type checking:

Suggested change
class PassedArgHandler_float32:
def prepare_args(self, ty, val, **kwargs):
return types.float32[::1], val.arr
class PassedArgHandler_float32:
def prepare_args(self, ty, val, **kwargs):
if isinstance(val, TestArgHandlerRegistration.NumpyArrayWrapper_float32):
return types.float32[::1], val.arr
return ty, val

@brandon-b-miller
Copy link
Contributor Author

/ok to test

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

4 - Waiting on author Waiting for author to respond to review

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants