From 57000786a8d5b821a79f1f0ec151d15db078310a Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Thu, 22 Jan 2026 13:20:00 -0800 Subject: [PATCH] Rename sharding_names to out_sharding in NNX Variable metadata This CL renames the sharding_names attribute to out_sharding for better consistency with the sharding API. The new name more clearly indicates the purpose of this metadata field. ## Changes - Bump Flax version to 0.12.4 - Core changes in variablelib.py: - Add sharding_names to out_sharding metadata remapping for backward compatibility - Add deprecated sharding_names property that returns out_sharding with a warning - Update nnx/spmd.py, core/spmd.py, core/meta.py, linen/spmd.py to use out_sharding - Update all NNX tests to use the new attribute name - Update qwix flax_util.py to check for out_sharding first, with fallback to sharding_names - Update maxtext initializers.py to check for out_sharding first - Update documentation and examples to use out_sharding ## Backward Compatibility Existing code using sharding_names will continue to work via: - Metadata remapping during Variable creation - Deprecated Variable.sharding_names property PiperOrigin-RevId: 859745972 --- qwix/_src/flax_util.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/qwix/_src/flax_util.py b/qwix/_src/flax_util.py index 422356f..4ea5d80 100644 --- a/qwix/_src/flax_util.py +++ b/qwix/_src/flax_util.py @@ -389,12 +389,18 @@ def update_boxed( if value is not None: boxed = boxed.replace(value) shape = boxed.value.shape - axes = boxed.get_metadata().get('sharding_names', None) + metadata = boxed.get_metadata() + # Check for out_sharding first (Flax >= 0.12.4), then sharding_names + if 'out_sharding' in metadata: + sharding_key = 'out_sharding' + else: + sharding_key = 'sharding_names' + axes = metadata.get(sharding_key, None) if isinstance(axes, (list, tuple)): axes = update_sharding( axes, shape=shape, split=split, merge=merge, transpose=transpose ) - boxed.set_metadata(sharding_names=axes) + boxed.set_metadata(sharding_key, axes) elif isinstance(boxed, jax.Array): # not boxed. if value is not None: boxed = value