From b15334ccace429b258d6922fdf5780d92c63f8cb Mon Sep 17 00:00:00 2001
From: Ashish Tripathi
Date: Fri, 17 Oct 2025 19:55:54 -0500
Subject: [PATCH] parallel.py: fix all_reduce failing when index-sliced
 buffers have different lengths across ranks

---
 src/ptychi/parallel.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/ptychi/parallel.py b/src/ptychi/parallel.py
index 53689f8..19388a7 100644
--- a/src/ptychi/parallel.py
+++ b/src/ptychi/parallel.py
@@ -104,11 +104,20 @@ def sync_buffer(
         buffer = buffer.unsqueeze(0)
         unsqueezed = True
 
-    slicer = slice(None) if indices is None else indices
+    # For all_reduce, we need consistent shapes across all ranks,
+    # so we must operate on the full buffer, not on sliced views.
     if source_rank is not None:
+        slicer = slice(None) if indices is None else indices
         dist.broadcast(buffer[slicer], src=source_rank)
     else:
+        if indices is not None:
+            # Reduce a full-size temporary holding only the indexed elements,
+            # then copy the reduced values back into those positions.
+            temp_buffer = torch.zeros_like(buffer)
+            temp_buffer[indices] = buffer[indices]
+            dist.all_reduce(temp_buffer, op=op)
+            buffer[indices] = temp_buffer[indices]
+        else:
-        dist.all_reduce(buffer[slicer], op=op)
+            dist.all_reduce(buffer, op=op)
 
     if unsqueezed:
         buffer = buffer.squeeze(0)
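
Note (not part of the patch): below is a minimal standalone sketch of the same pattern, assuming torch.distributed with the gloo backend and a two-rank torchrun launch. The script name demo.py, the helper sync_indexed, and the example index sets are illustrative only, not the ptychi API. The point is that each rank reduces a full-size temporary buffer, so all_reduce sees identical tensor shapes on every rank even when the ranks index different numbers of elements.

    # demo.py -- run with: torchrun --nproc_per_node=2 demo.py
    import torch
    import torch.distributed as dist

    def sync_indexed(buffer: torch.Tensor, indices: torch.Tensor, op=dist.ReduceOp.SUM):
        # all_reduce on buffer[indices] would fail here because the ranks
        # pass tensors of different lengths; reducing a full-size temporary
        # keeps the shape identical across ranks.
        temp = torch.zeros_like(buffer)
        temp[indices] = buffer[indices]
        dist.all_reduce(temp, op=op)
        buffer[indices] = temp[indices]

    if __name__ == "__main__":
        dist.init_process_group(backend="gloo")
        rank = dist.get_rank()
        buf = torch.arange(8, dtype=torch.float32)
        # Ranks own different (and differently sized) index sets.
        idx = torch.tensor([0, 1, 2]) if rank == 0 else torch.tensor([2, 3])
        sync_indexed(buf, idx)
        print(f"rank {rank}: {buf.tolist()}")
        dist.destroy_process_group()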