Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions examples/jax/collective_gemm/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,6 @@ def _initialize_distributed(args):
)

_distributed_initialized = True
jax.clear_caches()
jax.config.update(
"jax_use_shardy_partitioner", False
) # CollectiveGEMM does not work with Shardy yet

assert jax.local_device_count() == 1, (
f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"
Expand Down
5 changes: 1 addition & 4 deletions examples/jax/collective_gemm/test_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_shard
def run_gemm_tests(args, mesh=None):
"""Execute GEMM tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)

# Initialize distributed with provided arguments
_initialize_distributed(args)
Expand Down Expand Up @@ -137,8 +135,7 @@ def run_gemm_tests(args, mesh=None):
bias_sharded,
contracting_dims=((2,), (0,)),
collective_op=collective_op,
# CollectiveGEMM output should have a correct sharding without applying sharding constraint
output_sharding=None,
output_sharding=output_sharding,
)
assert (
ref_output.sharding == output.sharding
Expand Down
2 changes: 0 additions & 2 deletions examples/jax/collective_gemm/test_layernorm_mlp_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ def _value_and_grad_layernorm_mlp(
def run_layernorm_mlp_grad_tests(args, mesh=None):
"""Execute LayerNorm MLP gradient tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)

# Initialize distributed with provided arguments
_initialize_distributed(args)
Expand Down
13 changes: 10 additions & 3 deletions transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,9 +1172,16 @@ def shardy_sharding_rule(
del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer

if not collective_op.is_none:
raise NotImplementedError(
"CollectiveGEMM with Shardy propagation is not supported yet! Please turn off"
" Shardy by exporting env var JAX_USE_SHARDY_PARTITIONER=false"
warnings.warn(
"CollectiveGEMM with Shardy propagation may produce an incorrect sharding pattern"
" for the output.\n To resolve this, apply a sharding constraint on the output"
" using one of the following options:\n"
" - TE `dense` vjp: set `output_axes`.\n"
" - TE `layernorm_mlp` vjp: set `dot_2_input_axes`.\n"
" - TE `transformer_engine.jax.cpp_extensions.gemm`: apply"
" `jax.lax.with_sharding_constraint` on the output.\n"
" - TE via MaxText: no action needed.",
UserWarning,
)

prefix = "Gemm_"
Expand Down
7 changes: 7 additions & 0 deletions transformer_engine/jax/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ def dense(
if transpose_batch_sequence:
warnings.warn("transpose_batch_sequence is not well tested, use with caution!")

if collective_op_set != tex.noop_collective_op_set and not output_axes:
warnings.warn(
"Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
" for the output. Set `output_axes` to apply the correct sharding constraint.",
UserWarning,
)

if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
Expand Down
8 changes: 8 additions & 0 deletions transformer_engine/jax/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial
import warnings

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -275,6 +276,13 @@ def _layernorm_mlp_fwd_rule(
assert not collective_op_set_1.forward.is_reduce_scatter
assert not collective_op_set_2.forward.is_all_gather

if collective_op_set_1 != tex.noop_collective_op_set and not dot_2_input_axes:
warnings.warn(
"Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
" for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.",
UserWarning,
)

# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
# Kernel_2 should be in shape of (intermediate, hidden_in)
Expand Down
Loading