
Incorrect sharded matmul results when resharding unreduced axes to mixed sharded/replicated axes #40034

@justinjfu

Description


I'm seeing incorrect gradients when training a model, which only manifest for particular device counts. The evidence points to the SPMD partitioner not inserting an all-reduce when we reshard from unreduced axes to a sharding with mixed sharded/replicated axes.

This is the smallest reproducer I could make that triggers the bug (it requires 4k GPUs to run, but you can still get the dumps with cross-compile): https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-repro_script-py

When running on 4k GPUs, we get the output: DP replicas DIFFER: max_diff = 0.35546875

Buggy HLO dumps (with DP=2, doesn't insert all-reduce):
before_optimizations: https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-hlo_dump_buggy_dp2_before_optimizations
after_spmd_partitioner: https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-hlo_dump_buggy_dp2_after_spmd_partitioner
after_optimizations: https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-hlo_dump_buggy_dp2_after_optimizations

Correct (fallback) HLO dumps (with DP=4, hits a fallback that rematerializes the tensor):
before_optimizations: https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-hlo_dump_fallback_dp4_before_optimizations
after_spmd_partitioner: https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-hlo_dump_fallback_dp4_after_spmd_partitioner
after_optimizations: https://gist.github.com/justinjfu/7b5134ff4332024c11594324103ae725#file-hlo_dump_fallback_dp4_after_optimizations

JAX/Jaxlib versions:

jax:    0.9.2.dev20260323+059c83841 (I am building JAX from source, but it's unmodified)
jaxlib: 0.9.1
numpy:  2.4.1

The following is an AI-assisted analysis of the issue (I am not familiar with XLA internals, so I'm not 100% sure it's correct, but it seems plausible). It suggests that in the buggy case, an all-reduce across the "dp" axis is not inserted because the compiler has logic to skip the all-reduce whenever there are any unreduced axes at all. In the case that produces correct results, XLA can't find an efficient way to shard the matmul, so it hits a fallback that fully replicates the result and then repartitions.

The einsum bsh,bsf->hf with out_sharding=PS(("tp",), unreduced={"dp","fsdp","sp"}) produces a partial sum. In XLA's representation:
  devices=[1,4,1024]<=[1024,4]T(1,0) last_tile_dims={unreduced}
  - dim 0 (H): unpartitioned (size 1 in device grid)
  - dim 1 (F): partitioned across tp=4
  - remaining 1024 devices (dp×fsdp×sp = 2×256×2): unreduced — each holds a partial sum
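The "each device holds a partial sum" claim can be checked with a toy NumPy model of the contraction (sizes here are arbitrary stand-ins, not the real dp×fsdp×sp = 1024 unreduced devices):

```python
import numpy as np

rng = np.random.default_rng(0)
b, s, h, f = 4, 8, 6, 10
x = rng.normal(size=(b, s, h))
y = rng.normal(size=(b, s, f))

# The full einsum bsh,bsf->hf.
full = np.einsum("bsh,bsf->hf", x, y)

# Split the contracted batch dim across 4 simulated "devices": each computes
# a partial sum over its slice. No single partial equals the full result;
# only the sum across devices (i.e., an all-reduce) does.
partials = [np.einsum("bsh,bsf->hf", x[i::4], y[i::4]) for i in range(4)]
assert not np.allclose(partials[0], full)
assert np.allclose(sum(partials), full)
```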

  The reshard to W_PS = PS(("fsdp","sp"), ("tp",)) targets:
  devices=[512,4,2]<=[2,2048]T(1,0) last_tile_dim_replicate
  - dim 0 (H): partitioned across fsdp×sp=512
  - dim 1 (F): partitioned across tp=4
  - dp=2: last_tile_dim_replicate (replicated)

  This transition requires two collectives:
  1. Reduce-scatter across fsdp×sp: sum the partial sums and distribute H-shards. This happens (we see it as dynamic-slice + collective-permute in the HLO)
  2. All-reduce across dp: sum so both DP replicas have the same value. This is dropped at dp=2
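The two-collective transition above can be simulated in plain NumPy (with hypothetical tiny mesh sizes, dp=2 and fsdp×sp=4, and H-sharding elided), which shows why dropping the dp all-reduce leaves the DP replicas with different values:

```python
import numpy as np

rng = np.random.default_rng(0)
dp, rs, h, f = 2, 4, 8, 8

# One partial sum of the same (h, f) result per device.
partials = rng.normal(size=(dp, rs, h, f))
full = partials.sum(axis=(0, 1))  # the mathematically correct result

# Collective 1: reduce-scatter across the fsdp*sp group. Within each dp
# replica, sum the partials (H-shard distribution elided here). Each dp
# replica still lacks the other replica's contribution.
after_rs = partials.sum(axis=1)   # shape (dp, h, f)

# Collective 2 (the one that is dropped at dp=2): all-reduce across dp.
after_ar = after_rs.sum(axis=0)

assert not np.allclose(after_rs[0], after_rs[1])  # replicas differ without it
assert np.allclose(after_ar, full)                # with it, both match the full sum
```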

At dp=4, the partitioner can't find an efficient decomposition for this transition and falls back to full rematerialization (replicate everything, then repartition), which implicitly performs the all-reduce.
At dp=2, it finds a "more efficient" path that handles the reduce-scatter but silently skips the DP all-reduce, because kHasUnreducedAxes causes AllReduceAlongShardingDimsInternal to return early.
In spmd_partitioner.cc, AllReduceAlongShardingDimsInternal skips the all-reduce entirely when the operand carries the kHasUnreducedAxes frontend attribute:

```c++
// Skip adding AR if the operand has unreduced axes in its sharding,
// represented by the frontend attribute.
if (operand->frontend_attributes().map().contains(sdy::kHasUnreducedAxes)) {
  return operand;
}
```


Let me know if this makes any sense! If not I can try to make a better reproducer for the issue, but hopefully this is enough to debug the issue.
