diff --git a/qwix/_src/flax_util.py b/qwix/_src/flax_util.py index 422356f..ca0d34e 100644 --- a/qwix/_src/flax_util.py +++ b/qwix/_src/flax_util.py @@ -394,7 +394,7 @@ def update_boxed( axes = update_sharding( axes, shape=shape, split=split, merge=merge, transpose=transpose ) - boxed.set_metadata(sharding_names=axes) + boxed.set_metadata('sharding_names', axes) elif isinstance(boxed, jax.Array): # not boxed. if value is not None: boxed = value