Incorrect sharded matmul results when resharding unreduced axes to mixed sharded/replicated axes #40034
Description
I'm seeing incorrect gradients when training a model, and they only manifest for particular device counts. The evidence points to the SPMD partitioner not inserting an all-reduce when resharding from unreduced axes to a sharding with mixed sharded/replicated axes.
This is the smallest reproducer I could make that triggers the bug (it requires 4k GPUs to run, but you can still get the dumps via cross-compilation): https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-repro_script-py
When running on 4k GPUs, the script prints `DP replicas DIFFER: max_diff = 0.35546875`.
Buggy HLO dumps (with DP=2, doesn't insert all-reduce):
before_optimizations: https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-hlo_dump_buggy_dp2_before_optimizations
after_spmd_partitioner: https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-hlo_dump_buggy_dp2_after_spmd_partitioner
after_optimizations: https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-hlo_dump_buggy_dp2_after_optimizations
Correct (fallback) HLO dumps (with DP=4, hits a fallback that rematerializes the tensor):
before_optimizations: https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-hlo_dump_fallback_dp4_before_optimizations
after_spmd_partitioner: https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-hlo_dump_fallback_dp4_after_spmd_partitioner
after_optimizations: https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-hlo_dump_fallback_dp4_after_optimizations
JAX/Jaxlib versions:
jax: 0.9.2.dev20260323+059c83841 (I am building JAX from source, but it's unmodified)
jaxlib: 0.9.1
numpy: 2.4.1
The following is an AI-generated analysis of the issue (I'm not familiar with XLA internals, so I can't fully vouch for it, but it seems plausible). It suggests that in the buggy case, an all-reduce across the "dp" axis is not inserted because the compiler skips the all-reduce whenever any unreduced axes remain. In the case that produces correct results, XLA can't find an efficient way to shard the matmul, so it hits a fallback that fully replicates the result and repartitions, which implicitly performs the reduction.
The einsum `bsh,bsf->hf` with `out_sharding=PS(("tp",), unreduced={"dp","fsdp","sp"})` produces a partial sum. In XLA's representation:
`devices=[1,4,1024]<=[1024,4]T(1,0) last_tile_dims={unreduced}`
- dim 0 (H): unpartitioned (size 1 in device grid)
- dim 1 (F): partitioned across tp=4
- remaining 1024 devices (dp×fsdp×sp = 2×256×2): unreduced — each holds a partial sum
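To make the "partial sum" representation concrete, here is a NumPy sketch (with made-up small sizes, not the model's real dimensions) of what an unreduced axis means: each device along the unreduced mesh axes holds a partial contraction, and only the sum over those devices (i.e. an all-reduce) recovers the true einsum result.

```python
import numpy as np

# Hypothetical small sizes; the real B, S, H, F are much larger.
B, S, H, F = 4, 8, 6, 10
n_unreduced = 4  # stand-in for the dp*fsdp*sp product

rng = np.random.default_rng(0)
x = rng.normal(size=(B, S, H))
y = rng.normal(size=(B, S, F))

# Full einsum: contract over the b and s axes.
full = np.einsum("bsh,bsf->hf", x, y)

# "Unreduced" sharding: split the contracted (b) axis across devices;
# each device computes only its partial sum of the einsum.
x_shards = np.split(x, n_unreduced, axis=0)
y_shards = np.split(y, n_unreduced, axis=0)
partials = [np.einsum("bsh,bsf->hf", xs, ys)
            for xs, ys in zip(x_shards, y_shards)]

# A single partial is NOT the answer; only the sum (an all-reduce) is.
assert not np.allclose(partials[0], full)
assert np.allclose(sum(partials), full)
```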
The reshard to `W_PS = PS(("fsdp","sp"), ("tp",))` targets:
`devices=[512,4,2]<=[2,2048]T(1,0) last_tile_dim_replicate`
- dim 0 (H): partitioned across fsdp×sp=512
- dim 1 (F): partitioned across tp=4
- dp=2: last_tile_dim_replicate (replicated)
This transition requires two collectives:
1. Reduce-scatter across fsdp×sp: sum the partial sums and distribute H-shards. This happens (we see it as dynamic-slice + collective-permute in the HLO)
2. All-reduce across dp: sum so both DP replicas have the same value. This is dropped at dp=2
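To sanity-check that both collectives are required, here is a NumPy simulation of the transition on a made-up tiny mesh (dp=2, with a single "rs" axis standing in for the fsdp×sp product): performing only the reduce-scatter leaves the two dp replicas holding different values, which matches the observed divergence.

```python
import numpy as np

# Hypothetical tiny mesh: dp=2, rs=2 (stand-in for fsdp*sp).
dp, rs, H, F = 2, 2, 8, 4
rng = np.random.default_rng(1)
# partial[d, r] is the partial sum held by device (dp=d, rs=r).
partial = rng.normal(size=(dp, rs, H, F))
expected = partial.sum(axis=(0, 1))  # the mathematically correct hf result

# Step 1: reduce-scatter across rs: sum partials within each dp replica
# and hand each rs device a distinct H-shard of that sum.
after_rs = partial.sum(axis=1)           # shape (dp, H, F)
shards = np.split(after_rs, rs, axis=1)  # rs shards of H, per dp replica

# Without step 2, the dp replicas hold different values -> the max_diff.
assert not np.allclose(after_rs[0], after_rs[1])

# Step 2 (the one that is dropped): all-reduce across dp.
reduced_shards = [s.sum(axis=0) for s in shards]
result = np.concatenate(reduced_shards, axis=0)
assert np.allclose(result, expected)
```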
At dp=4, the partitioner can't find an efficient decomposition for this transition and falls back to full rematerialization (replicate everything then repartition), which implicitly performs the all-reduce.
At dp=2, it finds a "more efficient" path that handles the reduce-scatter but silently skips the DP all-reduce, because `kHasUnreducedAxes` causes `AllReduceAlongShardingDimsInternal` to return early.
In spmd_partitioner.cc, `AllReduceAlongShardingDimsInternal` skips the all-reduce entirely when the operand has the `kHasUnreducedAxes` frontend attribute:
xla/xla/service/spmd/spmd_partitioner.cc, lines 5642 to 5647 at 8fccf39:

```cpp
// Skip adding AR if the operand has unreduced axes in its sharding,
// represented by the frontend attribute.
if (operand->frontend_attributes().map().contains(sdy::kHasUnreducedAxes)) {
  return operand;
}
```
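The effect of this guard can be modeled with a toy Python function (an illustrative simplification I wrote, not XLA's actual code; the attribute name here is a stand-in for `sdy::kHasUnreducedAxes`): the early return fires whenever the unreduced-axes attribute is present, regardless of whether the axis being reduced over (dp here) is one the target sharding replicates rather than leaves unreduced.

```python
# Toy model of the guard (assumption: simplified, not XLA's real code;
# "unreduced_axes_attr" is a stand-in for sdy::kHasUnreducedAxes).
def all_reduce_along_dims(operand_attrs, partials):
    """partials: per-replica partial sums along the axis to be reduced."""
    if "unreduced_axes_attr" in operand_attrs:
        # Early return: the all-reduce is skipped whenever ANY unreduced
        # axes remain, even if the requested reduction axis is one the
        # target sharding replicates rather than keeps unreduced.
        return partials
    return [sum(partials)] * len(partials)

# With the attribute set, replicas keep their diverging partial sums:
assert all_reduce_along_dims({"unreduced_axes_attr"}, [1, 2, 3]) == [1, 2, 3]
# Without it, every replica receives the reduced value:
assert all_reduce_along_dims(set(), [1, 2, 3]) == [6, 6, 6]
```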
Let me know if this makes sense! If not, I can try to make a better reproducer, but hopefully this is enough to debug the issue.