Bad variable names
Some variable names in klax are currently poorly chosen:
- The `Loss` protocol requires an argument named `data`, which really should be called `batch`. To retain backwards compatibility, we should consider relaxing the `Loss` protocol so that it only requires positional arguments for the loss function. Then users can name the arguments whatever fits best.
- `batch_axis` is used all over the place and should be called `batch_axes`. Besides being grammatically correct in most cases where one actually specifies a `batch_axes`, this would make it consistent with `jax.vmap(..., in_axes)` and many other JAX functions.
- ...
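A minimal sketch of the relaxed protocol, assuming a loss receives a model, a batch, and its batch axes in that order (klax's actual `Loss` signature may differ). The trailing `/` makes the parameters positional-only, so a static type checker accepts implementations regardless of what they name their parameters, and the type variables allow custom model and batch types. `squared_error` and `evaluate` are hypothetical helpers for illustration only.

```python
from typing import Any, Protocol, TypeVar

import jax
import jax.numpy as jnp

# Type variables used only in argument positions of a protocol are contravariant.
ModelT = TypeVar("ModelT", contravariant=True)
BatchT = TypeVar("BatchT", contravariant=True)
M = TypeVar("M")
B = TypeVar("B")


class Loss(Protocol[ModelT, BatchT]):
    """Hypothetical relaxed loss protocol (not the current klax API)."""

    # The "/" makes all parameters positional-only, so implementations
    # may name them freely (e.g. `data` or `batch`).
    def __call__(self, model: ModelT, batch: BatchT, batch_axes: Any, /) -> jax.Array: ...


def squared_error(model, data, batch_axis):
    # Old-style parameter names still match the positional-only protocol.
    x, y = data
    # The batch axis is forwarded directly to `in_axes`, mirroring jax.vmap.
    y_pred = jax.vmap(model, in_axes=batch_axis)(x)
    return jnp.mean((y_pred - y) ** 2)


def evaluate(loss: Loss[M, B], model: M, batch: B, batch_axes: Any) -> jax.Array:
    # Call the loss positionally, never by keyword.
    return loss(model, batch, batch_axes)


x = jnp.arange(3.0)
print(evaluate(squared_error, lambda xi: 2.0 * xi, (x, 2.0 * x), 0))
```

As long as klax itself only ever calls losses positionally, renaming `data` to `batch` in its own losses would not break user-defined losses that still use the old names.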
Typing
There are some basic typing oversights:
- `klax.finalize` is missing an output type. Without having put much thought into this, I fear that the obvious generic signature will not be 100% correct:
  ```python
  def finalize[T: PyTree](tree: T) -> T: ...
  ```
  While this should work perfectly fine for many use cases, we should consider possible edge cases where it might be incorrect.
- `klax.split_data` does not have a generic type but is annotated with `PyTree`, meaning it is effectively `Any`.
- The already mentioned `Loss` protocol is overly restrictive: it does not allow the loss function's arguments to be renamed, and it does not allow applying custom types to the arguments.
- `klax.fit` demands an `eqx.Module` as input type, but in practice the input could be any pytree, e.g., a tuple of modules. Therefore, the type annotation should be changed to a generic pytree.
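A minimal sketch of a generic `fit` signature, assuming a plain `TypeVar` in place of a `PyTree` bound. `fit_sketch` is a hypothetical stand-in for `klax.fit` (plain gradient descent, not klax's actual training loop); the point is only that the return type follows the input type, so a tuple of parameters stays a tuple for the type checker. The same pattern would apply to `klax.split_data`.

```python
from typing import Callable, TypeVar

import jax
import jax.numpy as jnp

T = TypeVar("T")  # any pytree: an eqx.Module, a tuple of modules, ...


def fit_sketch(model: T, loss_fn: Callable[[T], jax.Array], steps: int = 500, lr: float = 0.1) -> T:
    """Hypothetical stand-in for klax.fit: gradient descent on every leaf."""

    @jax.jit
    def step(m: T) -> T:
        grads = jax.grad(loss_fn)(m)
        # tree_map preserves the pytree structure, hence the type T.
        return jax.tree_util.tree_map(lambda p, g: p - lr * g, m, grads)

    for _ in range(steps):
        model = step(model)
    return model


x = jnp.arange(5.0)
y = 3.0 * x + 1.0


def loss_fn(params: tuple[jax.Array, jax.Array]) -> jax.Array:
    # Mean squared error of a line a * x + b.
    a, b = params
    return jnp.mean((a * x + b - y) ** 2)


a, b = fit_sketch((jnp.array(0.0), jnp.array(0.0)), loss_fn)
print(float(a), float(b))
```

The `T -> T` signature is only accurate when the function preserves the pytree structure and leaf types, as `tree_map` does here.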
Other
- The dependency on Paramax is unnecessary (?)