Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
|[**Quick Example**](#QuickExample)

![Tests](https://github.com/ASEM000/serket/actions/workflows/tests.yml/badge.svg)
![pyver](https://img.shields.io/badge/python-3.9%203.9%203.10%203.12-blue)
![pyver](https://img.shields.io/badge/python-3.10%203.11%203.12%203.13-blue)
![codestyle](https://img.shields.io/badge/codestyle-black-black)
[![codecov](https://codecov.io/gh/ASEM000/serket/branch/main/graph/badge.svg?token=C6NXOK9EVS)](https://codecov.io/gh/ASEM000/serket)
[![Documentation Status](https://readthedocs.org/projects/serket/badge/?version=latest)](https://serket.readthedocs.io/?badge=latest)
Expand Down
2 changes: 1 addition & 1 deletion serket/_src/nn/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def resolve_act(act):
return act
if isinstance(act, str):
try:
return jax.tree_map(lambda x: x, act_map[act])
return jax.tree_util.tree_map(lambda x: x, act_map[act])
except KeyError:
raise ValueError(f"Unknown {act=}, available activations: {list(act_map)}")
raise TypeError(f"Unknown activation type {type(act)}.")
26 changes: 14 additions & 12 deletions serket/_src/nn/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,36 @@
from __future__ import annotations

from collections.abc import Callable as ABCCallable
from typing import get_args
from typing import Callable, get_args

import jax
import jax.nn.initializers as ji
import jax.numpy as jnp
import jax.tree_util as jtu

from serket._src.utils.typing import InitFuncType, InitLiteral, InitType

inits: list[InitFuncType] = [
ji.he_normal(),
ji.he_uniform(),
ji.glorot_normal(),
ji.glorot_uniform(),
ji.lecun_normal(),
ji.lecun_uniform(),
inits: list[InitType] = [
ji.he_normal(in_axis=1, out_axis=0),
ji.he_uniform(in_axis=1, out_axis=0),
ji.glorot_normal(in_axis=1, out_axis=0),
ji.glorot_uniform(in_axis=1, out_axis=0),
ji.lecun_normal(in_axis=1, out_axis=0),
ji.lecun_uniform(in_axis=1, out_axis=0),
ji.normal(),
ji.uniform(),
ji.ones,
ji.zeros,
ji.xavier_normal(),
ji.xavier_uniform(),
ji.xavier_normal(in_axis=1, out_axis=0),
ji.xavier_uniform(in_axis=1, out_axis=0),
ji.orthogonal(),
]

init_map: dict[str, InitType] = dict(zip(get_args(InitLiteral), inits))

init_map: dict[str, Callable[..., InitType]] = dict(zip(get_args(InitLiteral), inits))

def resolve_init(init):

def resolve_init(init) -> jtu.Partial[InitFuncType]:
if isinstance(init, str):
try:
return jtu.Partial(jax.tree_map(lambda x: x, init_map[init]))
Expand Down
23 changes: 17 additions & 6 deletions serket/_src/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import functools as ft
import math
from typing import Sequence

import jax
Expand Down Expand Up @@ -209,9 +210,15 @@ def __init__(

k1, k2 = jr.split(key)

weight_shape = (*out_features, *in_features)
weight_shape = (math.prod(out_features), math.prod(in_features))
self.weight = resolve_init(weight_init)(k1, weight_shape, dtype)
self.bias = resolve_init(bias_init)(k2, out_features, dtype)
self.bias = resolve_init(bias_init)(k2, (math.prod(out_features),), dtype)

if self.weight is not None:
self.weight = self.weight.reshape(out_features + in_features)

if self.bias is not None:
self.bias = self.bias.reshape(out_features)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
def __call__(self, input: jax.Array) -> jax.Array:
Expand Down Expand Up @@ -405,10 +412,14 @@ def __init__(
self.in_bias = resolve_init(bias_init)(k2, (hidden_features,), dtype)

k3, k4 = jr.split(mid_key)
mid_weight_shape = (num_hidden_layers, hidden_features, hidden_features)
self.mid_weight = resolve_init(weight_init)(k3, mid_weight_shape, dtype)
mid_bias_shape = (num_hidden_layers, hidden_features)
self.mid_bias = resolve_init(bias_init)(k4, mid_bias_shape, dtype)

init_func = jax.vmap(resolve_init(weight_init), in_axes=(0, None, None))
k3s = jr.split(k3, num_hidden_layers)
self.mid_weight = init_func(k3s, (hidden_features, hidden_features), dtype)

init_func = jax.vmap(resolve_init(bias_init), in_axes=(0, None, None))
k4s = jr.split(k4, num_hidden_layers)
self.mid_bias = init_func(k4s, (hidden_features,), dtype)

k5, k6 = jr.split(out_key)
out_weight_shape = (out_features, hidden_features)
Expand Down
2 changes: 1 addition & 1 deletion serket/_src/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def __init__(

self.in_hidden_to_hidden = Linear(
in_features=in_features + hidden_features,
out_features=hidden_features,
out_features=hidden_features * 4,
weight_init=lambda *_: jnp.concatenate([i2h.weight, h2h.weight], axis=-1),
bias_init=lambda *_: i2h.bias,
dtype=dtype,
Expand Down
2 changes: 1 addition & 1 deletion serket/_src/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@

Shape = Tuple[int, ...]
DType = Union[np.dtype, str, Any]
InitFuncType = Callable[[jax.Array, Shape, DType], jax.Array]
InitFuncType = Callable[[jax.Array, Shape, DType], jax.Array | None]
InitType = Union[InitLiteral, InitFuncType]
MethodKind = Literal["nearest", "linear", "cubic", "lanczos3", "lanczos5"]
Weight = Annotated[jax.Array, "OI..."]
Expand Down
Loading