diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index ae3888cf04..6eb588c849 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -189,12 +189,12 @@ def _warn_gspmd_deprecation_once(): global _gspmd_deprecation_warned if not _gspmd_deprecation_warned: warnings.warn( - "GSPMD sharding propagation is planned to be removed in June 2026." - " It is no longer maintained or tested. Use it at your own risk." - " Please use Shardy partitioner instead." + "GSPMD sharding propagation rules in TE-JAX are planned to be removed in June 2026." + " They are no longer maintained or tested. Use them at your own risk." + " Please use Shardy propagation instead." " In case you cannot upgrade to a JAX version that supports Shardy, please reach out!", DeprecationWarning, - stacklevel=3, + stacklevel=2, ) _gspmd_deprecation_warned = True @@ -234,12 +234,24 @@ def name_of_wrapper_p(): outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) + if _JAX_GSPMD_SUPPORTED: - if "infer_sharding_from_operands" in cls.__dict__: - _warn_gspmd_deprecation_once() - gspmd_kwargs = {"infer_sharding_from_operands": cls.infer_sharding_from_operands} + fn = cls.__dict__.get("infer_sharding_from_operands") + if fn is not None: + actual_fn = ( + cls.infer_sharding_from_operands + ) # Use descriptor protocol to unwrap staticmethod + + def _gspmd_wrapper(*args, **kwargs): + _warn_gspmd_deprecation_once() + return actual_fn(*args, **kwargs) + + gspmd_kwargs = {"infer_sharding_from_operands": _gspmd_wrapper} + else: + gspmd_kwargs = {"infer_sharding_from_operands": cls.infer_sharding_from_operands} else: gspmd_kwargs = {} + outer_p_lower.def_partition( partition=cls.partition, sharding_rule=cls.shardy_sharding_rule,