diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f6479d4..ec94bee 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,7 +10,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v2 diff --git a/README.md b/README.md index dec8886..30663cb 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ |[**Quick Example**](#QuickExample) ![Tests](https://github.com/ASEM000/serket/actions/workflows/tests.yml/badge.svg) -![pyver](https://img.shields.io/badge/python-3.9%203.9%203.10%203.12-blue) +![pyver](https://img.shields.io/badge/python-3.10%203.11%203.12%203.13-blue) ![codestyle](https://img.shields.io/badge/codestyle-black-black) [![codecov](https://codecov.io/gh/ASEM000/serket/branch/main/graph/badge.svg?token=C6NXOK9EVS)](https://codecov.io/gh/ASEM000/serket) [![Documentation Status](https://readthedocs.org/projects/serket/badge/?version=latest)](https://serket.readthedocs.io/?badge=latest) diff --git a/serket/_src/nn/activation.py b/serket/_src/nn/activation.py index a0274fc..d6f4c5e 100644 --- a/serket/_src/nn/activation.py +++ b/serket/_src/nn/activation.py @@ -341,7 +341,7 @@ def resolve_act(act): return act if isinstance(act, str): try: - return jax.tree_map(lambda x: x, act_map[act]) + return jax.tree_util.tree_map(lambda x: x, act_map[act]) except KeyError: raise ValueError(f"Unknown {act=}, available activations: {list(act_map)}") raise TypeError(f"Unknown activation type {type(act)}.") diff --git a/serket/_src/nn/initialization.py b/serket/_src/nn/initialization.py index 975a603..053c06f 100644 --- a/serket/_src/nn/initialization.py +++ b/serket/_src/nn/initialization.py @@ -14,34 +14,36 @@ from __future__ import annotations from collections.abc import Callable as ABCCallable -from typing import get_args +from typing import Callable, get_args import jax import jax.nn.initializers as ji +import jax.numpy as jnp import jax.tree_util as jtu from serket._src.utils.typing import InitFuncType, InitLiteral, InitType -inits: list[InitFuncType] = [ - ji.he_normal(), - ji.he_uniform(), - ji.glorot_normal(), - ji.glorot_uniform(), - ji.lecun_normal(), - ji.lecun_uniform(), +inits: list[InitType] = [ + ji.he_normal(in_axis=1, out_axis=0), + ji.he_uniform(in_axis=1, out_axis=0), + ji.glorot_normal(in_axis=1, out_axis=0), + ji.glorot_uniform(in_axis=1, out_axis=0), + ji.lecun_normal(in_axis=1, out_axis=0), + ji.lecun_uniform(in_axis=1, out_axis=0), ji.normal(), ji.uniform(), ji.ones, ji.zeros, - ji.xavier_normal(), - ji.xavier_uniform(), + ji.xavier_normal(in_axis=1, out_axis=0), + ji.xavier_uniform(in_axis=1, out_axis=0), ji.orthogonal(), ] -init_map: dict[str, InitType] = dict(zip(get_args(InitLiteral), inits)) +init_map: dict[str, Callable[..., InitType]] = dict(zip(get_args(InitLiteral), inits)) -def resolve_init(init): + +def resolve_init(init) -> jtu.Partial[InitFuncType]: if isinstance(init, str): try: return jtu.Partial(jax.tree_map(lambda x: x, init_map[init])) diff --git a/serket/_src/nn/linear.py b/serket/_src/nn/linear.py index c35329a..6facba2 100644 --- a/serket/_src/nn/linear.py +++ b/serket/_src/nn/linear.py @@ -15,6 +15,7 @@ from __future__ import annotations import functools as ft +import math from typing import Sequence import jax @@ -209,9 +210,15 @@ def __init__( k1, k2 = jr.split(key) - weight_shape = (*out_features, *in_features) + weight_shape = (math.prod(out_features), math.prod(in_features)) self.weight = resolve_init(weight_init)(k1, weight_shape, dtype) - self.bias = resolve_init(bias_init)(k2, out_features, dtype) + self.bias = resolve_init(bias_init)(k2, (math.prod(out_features),), dtype) + + if self.weight is not None: + self.weight = self.weight.reshape(out_features + in_features) + + if self.bias is not None: + self.bias = self.bias.reshape(out_features) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) def __call__(self, input: jax.Array) -> jax.Array: @@ -405,10 +412,14 @@ def __init__( self.in_bias = resolve_init(bias_init)(k2, (hidden_features,), dtype) k3, k4 = jr.split(mid_key) - mid_weight_shape = (num_hidden_layers, hidden_features, hidden_features) - self.mid_weight = resolve_init(weight_init)(k3, mid_weight_shape, dtype) - mid_bias_shape = (num_hidden_layers, hidden_features) - self.mid_bias = resolve_init(bias_init)(k4, mid_bias_shape, dtype) + + init_func = jax.vmap(resolve_init(weight_init), in_axes=(0, None, None)) + k3s = jr.split(k3, num_hidden_layers) + self.mid_weight = init_func(k3s, (hidden_features, hidden_features), dtype) + + init_func = jax.vmap(resolve_init(bias_init), in_axes=(0, None, None)) + k4s = jr.split(k4, num_hidden_layers) + self.mid_bias = init_func(k4s, (hidden_features,), dtype) k5, k6 = jr.split(out_key) out_weight_shape = (out_features, hidden_features) diff --git a/serket/_src/nn/recurrent.py b/serket/_src/nn/recurrent.py index 0ec104e..8bdd138 100644 --- a/serket/_src/nn/recurrent.py +++ b/serket/_src/nn/recurrent.py @@ -383,7 +383,7 @@ def __init__( self.in_hidden_to_hidden = Linear( in_features=in_features + hidden_features, - out_features=hidden_features, + out_features=hidden_features * 4, weight_init=lambda *_: jnp.concatenate([i2h.weight, h2h.weight], axis=-1), bias_init=lambda *_: i2h.bias, dtype=dtype, diff --git a/serket/_src/utils/typing.py b/serket/_src/utils/typing.py index 995866c..9a07aeb 100644 --- a/serket/_src/utils/typing.py +++ b/serket/_src/utils/typing.py @@ -61,7 +61,7 @@ Shape = Tuple[int, ...] DType = Union[np.dtype, str, Any] -InitFuncType = Callable[[jax.Array, Shape, DType], jax.Array] +InitFuncType = Callable[[jax.Array, Shape, DType], jax.Array | None] InitType = Union[InitLiteral, InitFuncType] MethodKind = Literal["nearest", "linear", "cubic", "lanczos3", "lanczos5"] Weight = Annotated[jax.Array, "OI..."]