-
Notifications
You must be signed in to change notification settings - Fork 101
Refine JIT wrappers for new JAX for comaptiblity with jax>=0.8.2
#809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Reviewer's GuideRefactors BrainPy to be compatible with newer JAX (>=0.8.2) by updating the JIT wrappers, cleaning up imports and style across many modules, and adding a small smoke test entry point, without changing core numerical logic. Sequence diagram for updated jit wrapper behaviorsequenceDiagram
participant User as UserCode
participant BPJit as brainpy_math_object_transform_jit
participant BSTransform as brainstate_transform
participant compiled_func as compiled_func
User->>BPJit: jit(func, static_argnums, static_argnames, donate_argnums, inline, keep_unused, **kwargs)
activate BPJit
BPJit->>BPJit: warp_to_no_state_input_output(func)
BPJit->>BSTransform: jit(wrapped_func, static_argnums, static_argnames, donate_argnums, inline, keep_unused, **kwargs)
activate BSTransform
BSTransform-->>BPJit: compiled_func
deactivate BSTransform
BPJit-->>User: compiled_func
deactivate BPJit
User->>compiled_func: call_compiled(*args, **kwargs)
activate compiled_func
compiled_func-->>User: results
deactivate compiled_func
Sequence diagram for cls_jit method wrappingsequenceDiagram
participant User as UserClassDefinition
participant BPJit as brainpy_math_object_transform_jit
participant BSTransform as brainstate_transform
User->>BPJit: cls_jit(func, static_argnums, static_argnames, inline, keep_unused, **kwargs)
activate BPJit
BPJit->>BPJit: wrap func to bind self as first argument
BPJit->>BPJit: call jit(wrapped_method, static_argnums, static_argnames, donate_argnums, inline, keep_unused, **kwargs)
BPJit->>BSTransform: jit(wrapped_method, ...)
activate BSTransform
BSTransform-->>BPJit: compiled_method
deactivate BSTransform
BPJit-->>User: compiled_method (to be attached as bound method)
deactivate BPJit
User->>User: attach compiled_method to class instances
Class diagram for jit and ProgressBar related utilitiesclassDiagram
class brainpy_math_object_transform_controls {
+_convert_progress_bar_to_pbar(progress_bar) brainstate_transform_ProgressBar
}
class brainpy_math_object_transform_jit {
+jit(func, static_argnums, static_argnames, donate_argnums, inline, keep_unused, kwargs) Callable
+cls_jit(func, static_argnums, static_argnames, inline, keep_unused, kwargs) Callable
}
class brainpy_math_object_transform__utils {
+warp_to_no_state_input_output(func) Callable
}
class brainstate_transform_ProgressBar {
+freq
}
class brainstate_transform {
+jit(func, static_argnums, static_argnames, donate_argnums, inline, keep_unused, kwargs) Callable
}
brainpy_math_object_transform_controls ..> brainstate_transform_ProgressBar : uses
brainpy_math_object_transform_jit ..> brainpy_math_object_transform__utils : uses warp_to_no_state_input_output
brainpy_math_object_transform_jit ..> brainstate_transform : wraps jit
brainpy_math_object_transform_controls ..> brainstate_transform : creates ProgressBar instances
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
|
@sourcery-ai title |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey - I've found 1 issue, and left some high level feedback:
- In
brainpy/dyn/projections/utils.py,_get_return()still referencesbmbut thefrom brainpy import math as bmimport was removed, which will cause a NameError at runtime; restore or replace this import. - The new
brainpy/test_main.pytest function only imports and printsbrainpy.statewithout any assertions or checks; consider turning this into an actual smoke test by asserting basic properties of the import instead of printing. - Several added
if __name__ == '__main__': ...blocks (e.g., inmath/object_transform/__init__.py,brainpy/__init__.py,channels.py) simply reference symbols without doing anything and can be removed to avoid dead, confusing code paths.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `brainpy/dyn/projections/utils.py`, `_get_return()` still references `bm` but the `from brainpy import math as bm` import was removed, which will cause a NameError at runtime; restore or replace this import.
- The new `brainpy/test_main.py` test function only imports and prints `brainpy.state` without any assertions or checks; consider turning this into an actual smoke test by asserting basic properties of the import instead of printing.
- Several added `if __name__ == '__main__': ...` blocks (e.g., in `math/object_transform/__init__.py`, `brainpy/__init__.py`, `channels.py`) simply reference symbols without doing anything and can be removed to avoid dead, confusing code paths.
## Individual Comments
### Comment 1
<location> `brainpy/math/object_transform/__init__.py:43-44` </location>
<code_context>
-
-
if __name__ == '__main__':
connect
initialize, # weight initialization
</code_context>
<issue_to_address>
**suggestion:** The `__main__` block referencing `ProgressBar` appears to be dead code.
This `if __name__ == '__main__':` block has no side effects (the expressions are evaluated and discarded) and won’t run in normal library usage. If it isn’t needed, consider removing it; if it is, replace it with explicit calls or move the relevant logic to import time to make its purpose clear.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| if __name__ == '__main__': | ||
| ProgressBar |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: The __main__ block referencing ProgressBar appears to be dead code.
This if __name__ == '__main__': block has no side effects (the expressions are evaluated and discarded) and won’t run in normal library usage. If it isn’t needed, consider removing it; if it is, replace it with explicit calls or move the relevant logic to import time to make its purpose clear.
jax>=0.8.2
Summary by Sourcery
Make minor refactors and formatting cleanups across the codebase while adjusting JIT wrappers for newer JAX versions and adding a basic import test.
Enhancements:
Tests: