
Minor issues that keep annoying me #70

@Drenderer

Description


Bad variable names

Some variable names in klax are currently badly chosen:

  • The Loss Protocol requires the second argument to be named data. It really should be called batch. To retain backward compatibility, we should consider relaxing the Loss Protocol so that it only requires positional (positional-only) arguments for the loss function. The user can then name the arguments whatever fits best.
  • batch_axis is used all over the place, but it should be called batch_axes. Besides being grammatically correct in most cases where the argument is actually passed, this would make it consistent with jax.vmap(..., in_axes) and many other APIs.
  • ...
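A relaxed Loss Protocol could look roughly like the minimal sketch below, written in plain Python. The names `Loss`, `mse`, `evaluate`, and `linear` are hypothetical and not klax's current API. Because the parameters of `__call__` are positional-only (the `/`), their names are not part of the protocol, so a type checker accepts any callable taking two positional arguments: an existing loss that names its argument `data` keeps working, while new code can use `batch`. Making the protocol generic over the model and batch types also lets users attach their own argument types.

```python
from typing import Callable, Protocol, Sequence, Tuple, TypeVar

# Contravariant, because model and batch only appear as inputs of __call__.
M_contra = TypeVar("M_contra", contravariant=True)
B_contra = TypeVar("B_contra", contravariant=True)

# Plain type variables for ordinary generic functions.
M = TypeVar("M")
B = TypeVar("B")


class Loss(Protocol[M_contra, B_contra]):
    """Hypothetical relaxed Loss protocol with two positional-only arguments.

    The '/' marks both parameters as positional-only, so their names are not
    part of the protocol: (model, batch), (net, data), ... all match.
    """

    def __call__(self, model: M_contra, batch: B_contra, /) -> float: ...


ScalarModel = Callable[[float], float]
XYBatch = Tuple[Sequence[float], Sequence[float]]


def mse(net: ScalarModel, data: XYBatch) -> float:
    """A user loss that keeps the old argument name 'data'; it still satisfies Loss."""
    xs, ys = data
    return sum((net(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def evaluate(loss_fn: Loss[M, B], model: M, batch: B) -> float:
    # Library-side call sites must pass the arguments positionally, never by keyword.
    return loss_fn(model, batch)


def linear(x: float) -> float:
    return 2.0 * x


print(evaluate(mse, linear, ([1.0, 2.0], [2.0, 5.0])))  # 0.5
```

The price of this design is that klax itself can no longer call loss functions with keyword arguments such as `data=...`; every internal call has to be positional.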

Typing

There are some basic typing oversights:

  • klax.finalize is missing a return type annotation. Without having given this much thought, I suspect that a generic type will not be 100% correct:
    def finalize[T: PyTree](tree: T) -> T: ...
    While this should work perfectly fine for many use cases, we should consider possible edge cases where it might be incorrect.
  • klax.split_data is not generic; it is annotated with PyTree, which effectively means Any.
  • The already mentioned Loss protocol is overly restrictive. It does not allow the loss function arguments to be renamed, and it does not allow custom type annotations on the arguments.
  • klax.fit requires an eqx.Module as input type, but in practice the model could be any pytree, e.g., a tuple of modules. Therefore, the type annotation should be changed to a generic pytree.
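One possible edge case for the `T -> T` signature of finalize, sketched under an assumption: suppose finalize replaces wrapper nodes by their underlying values, as `paramax.unwrap` does for `AbstractUnwrappable` nodes. Then the container type is preserved, but the leaf types are not. The sketch below uses plain Python containers and dataclasses instead of JAX pytrees; `Positive`, `Model`, and this `finalize` are hypothetical stand-ins, not klax code.

```python
import dataclasses
from typing import TypeVar

T = TypeVar("T")


@dataclasses.dataclass(frozen=True)
class Positive:
    """Hypothetical wrapper node: stores a raw value, exposes a positive one."""

    raw: float

    def unwrap(self) -> float:
        return abs(self.raw)


def finalize(tree: T) -> T:
    """Hypothetical stand-in for klax.finalize with the proposed T -> T signature.

    Recursively replaces every Positive node by its unwrapped float.
    """
    if isinstance(tree, Positive):
        return tree.unwrap()  # Positive -> float: the result is NOT a T here
    if isinstance(tree, (list, tuple)):
        return type(tree)(finalize(x) for x in tree)
    if isinstance(tree, dict):
        return {key: finalize(value) for key, value in tree.items()}
    if dataclasses.is_dataclass(tree) and not isinstance(tree, type):
        # Rebuild the container with finalized fields; its class stays the same.
        changes = {f.name: finalize(getattr(tree, f.name)) for f in dataclasses.fields(tree)}
        return dataclasses.replace(tree, **changes)
    return tree


@dataclasses.dataclass(frozen=True)
class Model:
    weight: Positive
    bias: float


model = Model(weight=Positive(-2.0), bias=0.5)
final = finalize(model)  # statically typed as Model
print(type(final).__name__, type(final.weight).__name__)  # Model float
```

A type checker accepts `final` as a `Model`, so it also believes `final.weight` is a `Positive`, while at runtime it is a `float`. Whether this matters in practice depends on whether users annotate fields with wrapper types.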

Other

  • The dependency on Paramax is unnecessary (?)
