Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import jax, jax.numpy as jnp
import serket as sk

x_train, y_train = ..., ...
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
k1, k2 = jax.random.split(jax.random.key(0))

net = sk.tree_mask(sk.Sequential(
jnp.ravel,
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Install from github::
import serket as sk

x_train, y_train = ..., ...
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
k1, k2 = jax.random.split(jax.random.key(0))

net = sk.tree_mask(sk.Sequential(
jnp.ravel,
Expand Down
8 changes: 4 additions & 4 deletions docs/notebooks/[guides][core]checkpointing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@
"import optax\n",
"\n",
"net = sk.Sequential(\n",
" sk.nn.Linear(1, 128, key=jr.PRNGKey(0)),\n",
" sk.nn.Linear(1, 128, key=jr.key(0)),\n",
" jax.nn.relu,\n",
" sk.nn.Linear(128, 1, key=jr.PRNGKey(1)),\n",
" sk.nn.Linear(128, 1, key=jr.key(1)),\n",
")\n",
"\n",
"# exclude non-parameters\n",
Expand Down Expand Up @@ -131,8 +131,8 @@
" return net, optim_state, loss\n",
"\n",
"\n",
"x = jax.random.uniform(jax.random.PRNGKey(0), (100, 1))\n",
"y = jnp.sin(x) + jax.random.normal(jax.random.PRNGKey(0), (100, 1)) * 0.1\n",
"x = jax.random.uniform(jax.random.key(0), (100, 1))\n",
"y = jnp.sin(x) + jax.random.normal(jax.random.key(0), (100, 1)) * 0.1\n",
"\n",
"# should save step [0, 2, 4, 6, 8], and keep the last 3 checkpoints\n",
"# namely step [4, 6, 8]\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/notebooks/[guides][core]distributed_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
],
"source": [
"x = jnp.linspace(-jnp.pi, jnp.pi, 8 * 10).reshape(-1, 1)\n",
"y = jnp.sin(x) + jr.normal(jr.PRNGKey(0), shape=x.shape) * 0.1\n",
"y = jnp.sin(x) + jr.normal(jr.key(0), shape=x.shape) * 0.1\n",
"\n",
"plt.plot(x, y, \"o\", label=\"data\")\n",
"plt.plot(x, jnp.sin(x), label=\"sin(x)\")\n",
Expand Down Expand Up @@ -165,7 +165,7 @@
" return x\n",
"\n",
"\n",
"net = sk.tree_mask(Net(key=jr.PRNGKey(0)))\n",
"net = sk.tree_mask(Net(key=jr.key(0)))\n",
"print(sk.tree_summary(net))"
]
},
Expand Down Expand Up @@ -555,7 +555,7 @@
"config.optim.learning_rate = 1e-3\n",
"config.epochs = 1_000\n",
"config.net = None\n",
"config.key = jr.PRNGKey(0)\n",
"config.key = jr.key(0)\n",
"config.optim.optim_state = None\n",
"\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/[guides][core]evaluation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
" return input\n",
"\n",
"\n",
"net = sk.tree_mask(Net(key=jr.PRNGKey(0)))\n",
"net = sk.tree_mask(Net(key=jr.key(0)))\n",
"x = jnp.linspace(-1, 1, 100)[..., None]\n",
"y = jnp.sin(x * 3.14)\n",
"\n",
Expand All @@ -149,7 +149,7 @@
" return net\n",
"\n",
"\n",
"key = jax.random.PRNGKey(0)\n",
"key = jax.random.key(0)\n",
"\n",
"print(\"Before training\", \"-\" * 80)\n",
"print(repr(net))\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/[guides][core]mixed_precision.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
"\n",
"half = jnp.float16 # On TPU this should be jnp.bfloat16.\n",
"full = jnp.float32\n",
"k1, k2 = jax.random.split(jax.random.PRNGKey(0), 2)\n",
"k1, k2 = jax.random.split(jax.random.key(0), 2)\n",
"mp_policy = jmp.Policy(compute_dtype=half, param_dtype=full, output_dtype=half)\n",
"\n",
"net = sk.Sequential(\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/[guides][core]subset_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
" return self.linear2(jax.nn.relu(self.linear1(x)))\n",
"\n",
"\n",
"net = Net(key=jax.random.PRNGKey(0))\n",
"net = Net(key=jax.random.key(0))\n",
"net = sk.tree_mask(net)\n",
"x = jnp.linspace(-1, 1, 100)[..., None]\n",
"y = jnp.sin(x * 3.14)\n",
Expand Down Expand Up @@ -247,7 +247,7 @@
}
],
"source": [
"net = Net(key=jax.random.PRNGKey(0))\n",
"net = Net(key=jax.random.key(0))\n",
"net = sk.tree_mask(net)\n",
"\n",
"print(\"linear1.weight before training\")\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/notebooks/[guides][inter]keras.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,9 @@
],
"source": [
"sk_model = sk.Sequential(\n",
" Linear(1, 20, key=jr.PRNGKey(0)),\n",
" Linear(1, 20, key=jr.key(0)),\n",
" jax.nn.tanh,\n",
" Linear(20, 20, key=jr.PRNGKey(1)),\n",
" Linear(20, 20, key=jr.key(1)),\n",
" jax.nn.tanh,\n",
")\n",
"\n",
Expand All @@ -372,7 +372,7 @@
")\n",
"\n",
"x = jnp.linspace(-1, 1, 100)[:, None]\n",
"y = x**2 + jr.normal(jr.PRNGKey(0), (100, 1)) * 0.01\n",
"y = x**2 + jr.normal(jr.key(0), (100, 1)) * 0.01\n",
"model.fit(x, y, epochs=100)"
]
},
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/[guides][inter]tensorflow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
" return x\n",
"\n",
"\n",
"net = Net(jax.random.PRNGKey(0))"
"net = Net(jax.random.key(0))"
]
},
{
Expand Down Expand Up @@ -319,7 +319,7 @@
" return net\n",
"\n",
"\n",
"net = Net(jax.random.PRNGKey(0))\n",
"net = Net(jax.random.key(0))\n",
"\n",
"x_ = x.numpy()\n",
"y_ = y.numpy()\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/notebooks/[guides][other]augmentations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
" sk.image.FFTMotionBlur2D(kernel_size=15, angle=30),\n",
")\n",
"\n",
"tut_aug = net(tut, key=jr.PRNGKey(0))\n",
"tut_aug = net(tut, key=jr.key(0))\n",
"tut_aug = (tut_aug * 255).astype(jnp.uint8)\n",
"plt.imshow(to_channel_last(tut_aug))\n",
"plt.axis(\"off\")"
Expand Down Expand Up @@ -156,7 +156,7 @@
],
"source": [
"# lets create 10 different augmented images from the same image\n",
"keys = jr.split(jr.PRNGKey(1), 10)\n",
"keys = jr.split(jr.key(1), 10)\n",
"tut_augs = jax.vmap(lambda key: net(tut, key=key))(keys)\n",
"tut_augs = (tut_augs * 255).astype(jnp.uint8)\n",
"\n",
Expand Down Expand Up @@ -201,7 +201,7 @@
"batch = jnp.stack([tut, nef], axis=0)\n",
"\n",
"# create a key for each image in the batch\n",
"keys = jr.split(jr.PRNGKey(1), len(batch))\n",
"keys = jr.split(jr.key(1), len(batch))\n",
"\n",
"\n",
"@jax.jit\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/notebooks/[guides][other]custom_convolutions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
" conv_op = staticmethod(my_conv)\n",
"\n",
"\n",
"k1, k2 = jr.split(jr.PRNGKey(0), 2)\n",
"k1, k2 = jr.split(jr.key(0), 2)\n",
"\n",
"basic_conv = sk.nn.Conv2D(\n",
" in_features=1,\n",
Expand Down Expand Up @@ -168,7 +168,7 @@
" conv_op = staticmethod(my_depthwise_conv)\n",
"\n",
"\n",
"k1, k2 = jr.split(jr.PRNGKey(0), 2)\n",
"k1, k2 = jr.split(jr.key(0), 2)\n",
"\n",
"basic_conv = sk.nn.DepthwiseConv2D(\n",
" in_features=1,\n",
Expand Down Expand Up @@ -277,7 +277,7 @@
" conv_op = staticmethod(my_custom_conv)\n",
"\n",
"\n",
"k1, k2 = jr.split(jr.PRNGKey(0), 2)\n",
"k1, k2 = jr.split(jr.key(0), 2)\n",
"\n",
"\n",
"custom_conv = CustomConv2D(\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/[guides][other]deep_ensembles.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
],
"source": [
"x = jnp.linspace(-1, 1, 100).reshape(-1, 1) # 100x1\n",
"y = jnp.sin(2 * jnp.pi * x) + 0.2 * jr.normal(jr.PRNGKey(0), (100, 1)) # 100x1\n",
"y = jnp.sin(2 * jnp.pi * x) + 0.2 * jr.normal(jr.key(0), (100, 1)) # 100x1\n",
"y_clean = jnp.sin(2 * jnp.pi * x) # 100x1\n",
"plt.plot(x, y, \"o\")\n",
"plt.plot(x, y_clean, \"-\")\n",
Expand Down Expand Up @@ -152,7 +152,7 @@
" return sk.tree_unmask(jax.vmap(build_single_mlp)(keys))\n",
"\n",
"\n",
"keys = jr.split(jr.PRNGKey(0), NUM_ENSEMBLES)\n",
"keys = jr.split(jr.key(0), NUM_ENSEMBLES)\n",
"mlps: Batched[sk.nn.MLP] = build_ensemble(keys)\n",
"\n",
"# note that each array in the ensemble has a batch dimension of 10\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/[guides][other]hyperparam.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@
"metadata": {},
"outputs": [],
"source": [
"key = jax.random.PRNGKey(42)\n",
"key = jax.random.key(42)\n",
"x = jnp.linspace(-1, 1, 100)[..., None]\n",
"y = jnp.sin(x * 3.14) + jax.random.normal(key, (100, 1)) * 0.01 + 0.5\n",
"EPOCHS = 1000\n",
"\n",
"\n",
"def objective(trial):\n",
" k1, k2 = jax.random.split(jax.random.PRNGKey(0), 2)\n",
" k1, k2 = jax.random.split(jax.random.key(0), 2)\n",
"\n",
" # hidden features selection from 1 to 50\n",
" features = trial.suggest_int(\"hidden_features\", 1, 50)\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/[guides][other]loss_landscape.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@
],
"source": [
"# train with and without skip connections\n",
"key = jr.PRNGKey(config.seed)\n",
"key = jr.key(config.seed)\n",
"res_net = ResConvNet(1, 32, key=key)\n",
"sans_net = SansConvNet(1, 32, key=key)\n",
"\n",
Expand Down Expand Up @@ -485,7 +485,7 @@
}
],
"source": [
"key = jr.PRNGKey(1)\n",
"key = jr.key(1)\n",
"x_sample = x_test[:10]\n",
"y_sample = y_test[:10]\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/[guides][other]optimlib.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@
}
],
"source": [
"net = MLP(key=jr.PRNGKey(0))\n",
"net = MLP(key=jr.key(0))\n",
"optim = Optim(net)\n",
"\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/[guides][train]bilstm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"# 101 samples of 1D data\n",
"x = jnp.linspace(0, 1, 101).reshape(-1, 1)\n",
"y = jnp.sin(2 * jnp.pi * x)\n",
"y += 0.1 * jr.normal(jr.PRNGKey(0), y.shape)\n",
"y += 0.1 * jr.normal(jr.key(0), y.shape)\n",
"# we will use 2 time steps to predict the next time step\n",
"x_train = jnp.stack([x[:-1], x[1:]], axis=1)\n",
"# 100 minibatches x 1 sample x 2 time steps x 1D data\n",
Expand Down Expand Up @@ -112,7 +112,7 @@
" return output[-1]\n",
"\n",
"\n",
"key = jax.random.PRNGKey(0)\n",
"key = jax.random.key(0)\n",
"net = BiLstm(1, 64, 1, key=key)\n",
"# 1) mask the non-jaxtype parameters\n",
"net = sk.tree_mask(net)\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/notebooks/[guides][train]convlstm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,13 @@
"config.frame_size = 32\n",
"\n",
"x_shifts = jr.randint(\n",
" jr.PRNGKey(0),\n",
" jr.key(0),\n",
" (config.samples_count,),\n",
" -(config.frame_size // 2),\n",
" config.frame_size // 2,\n",
")\n",
"y_shifts = jr.randint(\n",
" jr.PRNGKey(1),\n",
" jr.key(1),\n",
" (config.samples_count,),\n",
" -(config.frame_size // 2),\n",
" config.frame_size // 2,\n",
Expand Down Expand Up @@ -272,7 +272,7 @@
"source": [
"config.features = 32\n",
"config.epochs = 1000\n",
"config.key = jr.PRNGKey(0)\n",
"config.key = jr.key(0)\n",
"config.optim = ConfigDict()\n",
"config.optim.kind = \"adam\"\n",
"config.optim.init_value = 1e-2\n",
Expand Down
8 changes: 4 additions & 4 deletions docs/notebooks/[guides][train]fourier_features_network.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@
"source": [
"# define net\n",
"net = sk.Sequential(\n",
" sk.nn.FNN([2] + [M] * 4 + [3], act=\"relu\", key=jax.random.PRNGKey(0)),\n",
" sk.nn.FNN([2] + [M] * 4 + [3], act=\"relu\", key=jax.random.key(0)),\n",
" jax.nn.sigmoid,\n",
")\n",
"# pass non-jaxtype through jax transformation boundaries\n",
Expand Down Expand Up @@ -285,7 +285,7 @@
"source": [
"# define net\n",
"net = sk.Sequential(\n",
" sk.nn.FNN([2 * 2] + [M] * 4 + [3], act=\"relu\", key=jax.random.PRNGKey(0)),\n",
" sk.nn.FNN([2 * 2] + [M] * 4 + [3], act=\"relu\", key=jax.random.key(0)),\n",
" jax.nn.sigmoid,\n",
")\n",
"# pass non-jaxtype through jax transformation boundaries\n",
Expand Down Expand Up @@ -334,7 +334,7 @@
}
],
"source": [
"b1, b2, b3 = jax.vmap(jr.normal, in_axes=(0, None))(jr.split(jr.PRNGKey(0), 3), [M, D])\n",
"b1, b2, b3 = jax.vmap(jr.normal, in_axes=(0, None))(jr.split(jr.key(0), 3), [M, D])\n",
"b1 = b1 * 1\n",
"b2 = b2 * 10\n",
"b3 = b3 * 100\n",
Expand All @@ -345,7 +345,7 @@
"for index, bi in enumerate([b1, b2, b3]):\n",
" # define net\n",
" net = sk.Sequential(\n",
" sk.nn.FNN([2 * M] + [M] * 4 + [3], act=\"relu\", key=jax.random.PRNGKey(0)),\n",
" sk.nn.FNN([2 * M] + [M] * 4 + [3], act=\"relu\", key=jax.random.key(0)),\n",
" jax.nn.sigmoid,\n",
" )\n",
" # get only parameters\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/[guides][train]mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@
"\n",
"def train(config: ConfigDict, x_train, y_train):\n",
" # 1) create net and mask out all non-inexact parameters\n",
" net = sk.tree_mask(ConvNet(key=jax.random.PRNGKey(config.seed)))\n",
" net = sk.tree_mask(ConvNet(key=jax.random.key(config.seed)))\n",
"\n",
" # visualize the network\n",
" print(sk.tree_summary(net, depth=1))\n",
Expand Down Expand Up @@ -286,7 +286,7 @@
"print(f\"test accuracy: {test_accuracy}\")\n",
"\n",
"# create 2x5 grid of images\n",
"key = jax.random.PRNGKey(0)\n",
"key = jax.random.key(0)\n",
"fig, axes = plt.subplots(2, 5, figsize=(10, 4))\n",
"idxs = jax.random.randint(key, shape=(10,), minval=0, maxval=x_train[0].shape[0])\n",
"\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/notebooks/[guides][train]pinn_burgers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
"\n",
"\"\"\"boundary conditions\"\"\"\n",
"\n",
"keys = jax.random.split(jax.random.PRNGKey(0), 5)\n",
"keys = jax.random.split(jax.random.key(0), 5)\n",
"\n",
"# u[0,x] = -sin(pi*x)\n",
"t_0 = jnp.ones([N_0, 1], dtype=\"float32\") * 0.0\n",
Expand Down Expand Up @@ -378,7 +378,7 @@
}
],
"source": [
"pinn = PINN(key=jr.PRNGKey(0))\n",
"pinn = PINN(key=jr.key(0))\n",
"\n",
"# mask the network parameters to use it across jax transformations\n",
"pinn = sk.tree_mask(pinn)\n",
Expand Down Expand Up @@ -498,7 +498,7 @@
}
],
"source": [
"pinn = PINN(key=jr.PRNGKey(0))\n",
"pinn = PINN(key=jr.key(0))\n",
"\n",
"# mask the network parameters to use it across jax transformations\n",
"pinn = sk.tree_mask(pinn)\n",
Expand Down
Loading
Loading