From 679183fcf3dd867e5d4ccdf72300b72036a883c5 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 6 Nov 2024 12:43:32 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 693827690 --- vmoe/initialization/rules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vmoe/initialization/rules.py b/vmoe/initialization/rules.py index 3e1825c..cdb6373 100644 --- a/vmoe/initialization/rules.py +++ b/vmoe/initialization/rules.py @@ -263,7 +263,7 @@ class ReshapeTransformation(Transformation): shape: Tuple[int, ...] = flax.struct.field(pytree_node=False) def __call__(self) -> Array: - return jnp.reshape(self.array, newshape=self.shape) + return jnp.reshape(self.array, shape=self.shape) class SqueezeTransformation(Transformation):