diff --git a/src/ptychi/parallel.py b/src/ptychi/parallel.py index 53689f8..19388a7 100644 --- a/src/ptychi/parallel.py +++ b/src/ptychi/parallel.py @@ -104,11 +104,20 @@ def sync_buffer( buffer = buffer.unsqueeze(0) unsqueezed = True - slicer = slice(None) if indices is None else indices + # For all_reduce, we need consistent shapes across all ranks + # so we must operate on the full buffer, not sliced views if source_rank is not None: + slicer = slice(None) if indices is None else indices dist.broadcast(buffer[slicer], src=source_rank) else: - dist.all_reduce(buffer[slicer], op=op) + if indices is not None: + # Create a temporary buffer for the indexed elements + temp_buffer = torch.zeros_like(buffer) + temp_buffer[indices] = buffer[indices] + dist.all_reduce(temp_buffer, op=op) + buffer[indices] = temp_buffer[indices] + else: + dist.all_reduce(buffer, op=op) if unsqueezed: buffer = buffer.squeeze(0)