diff --git a/qwix/_src/flax_util.py b/qwix/_src/flax_util.py index 422356f..4ea5d80 100644 --- a/qwix/_src/flax_util.py +++ b/qwix/_src/flax_util.py @@ -389,12 +389,18 @@ def update_boxed( if value is not None: boxed = boxed.replace(value) shape = boxed.value.shape - axes = boxed.get_metadata().get('sharding_names', None) + metadata = boxed.get_metadata() + # Check for out_sharding first (Flax >= 0.12.4), then sharding_names + if 'out_sharding' in metadata: + sharding_key = 'out_sharding' + else: + sharding_key = 'sharding_names' + axes = metadata.get(sharding_key, None) if isinstance(axes, (list, tuple)): axes = update_sharding( axes, shape=shape, split=split, merge=merge, transpose=transpose ) - boxed.set_metadata(sharding_names=axes) + boxed.set_metadata(sharding_key, axes) elif isinstance(boxed, jax.Array): # not boxed. if value is not None: boxed = value