diff --git a/clt/models/activations.py b/clt/models/activations.py index b48df5d..ca0de1a 100644 --- a/clt/models/activations.py +++ b/clt/models/activations.py @@ -1,9 +1,9 @@ import torch from typing import Optional, Tuple, Dict, List -import torch.distributed as dist import logging from clt.config import CLTConfig from torch.distributed import ProcessGroup +from clt.parallel import ops as dist_ops class BatchTopK(torch.autograd.Function): @@ -234,9 +234,7 @@ def _apply_batch_topk_helper( ) -> Dict[int, torch.Tensor]: """Helper to apply BatchTopK globally across concatenated layer pre-activations.""" - world_size = 1 - if process_group is not None and dist.is_initialized(): - world_size = dist.get_world_size(process_group) + world_size = dist_ops.get_world_size(process_group) if not preactivations_dict: logger_helpers.warning(f"Rank {rank}: _apply_batch_topk_helper received empty preactivations_dict.") @@ -310,9 +308,9 @@ def _apply_batch_topk_helper( concatenated_preactivations_original, k_val, concatenated_preactivations_normalized ) mask.copy_(local_mask) - dist.broadcast(mask, src=0, group=process_group) + dist_ops.broadcast(mask, src=0, group=process_group) else: - dist.broadcast(mask, src=0, group=process_group) + dist_ops.broadcast(mask, src=0, group=process_group) else: mask = BatchTopK._compute_mask( concatenated_preactivations_original, k_val, concatenated_preactivations_normalized @@ -340,9 +338,7 @@ def _apply_token_topk_helper( process_group: Optional[ProcessGroup], ) -> Dict[int, torch.Tensor]: """Helper to apply TokenTopK globally across concatenated layer pre-activations.""" - world_size = 1 - if process_group is not None and dist.is_initialized(): - world_size = dist.get_world_size(process_group) + world_size = dist_ops.get_world_size(process_group) if not preactivations_dict: logger_helpers.warning(f"Rank {rank}: _apply_token_topk_helper received empty preactivations_dict.") @@ -418,9 +414,9 @@ def _apply_token_topk_helper( concatenated_preactivations_normalized, ) mask.copy_(local_mask) - dist.broadcast(mask, src=0, group=process_group) + dist_ops.broadcast(mask, src=0, group=process_group) else: - dist.broadcast(mask, src=0, group=process_group) + dist_ops.broadcast(mask, src=0, group=process_group) else: mask = TokenTopK._compute_mask( concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized diff --git a/clt/models/clt.py b/clt/models/clt.py index 92306ca..3de28fd 100644 --- a/clt/models/clt.py +++ b/clt/models/clt.py @@ -1,7 +1,6 @@ import torch from typing import Dict, Optional, Union, Tuple, List import logging -import torch.distributed as dist from clt.config import CLTConfig from clt.models.base import BaseTranscoder @@ -12,6 +11,7 @@ from clt.models.theta import ThetaManager from clt.activations.registry import get_activation_fn +from clt.parallel import ops as dist_ops from torch.distributed import ProcessGroup @@ -34,13 +34,8 @@ def __init__( ): super().__init__(config) self.process_group = process_group - if process_group is None or not dist.is_initialized(): - self.world_size = 1 - self.rank = 0 - self.process_group = None - else: - self.world_size = dist.get_world_size(process_group) - self.rank = dist.get_rank(process_group) + self.world_size = dist_ops.get_world_size(process_group) + self.rank = dist_ops.get_rank(process_group) self.dtype = self._resolve_dtype(config.clt_dtype) if device is not None: diff --git a/clt/models/decoder.py b/clt/models/decoder.py index 1c0e8d3..68fd5c5 100644 --- a/clt/models/decoder.py +++ b/clt/models/decoder.py @@ -2,11 +2,11 @@ import torch.nn as nn from typing import Dict, Optional import logging -import torch.distributed as dist -from torch.distributed import ProcessGroup from clt.config import CLTConfig from clt.models.parallel import RowParallelLinear +from clt.parallel import ops as dist_ops +from torch.distributed import ProcessGroup logger = logging.getLogger(__name__) @@ -31,12 +31,12 @@ def __init__( self.device = device self.dtype = dtype - if process_group is None or not dist.is_initialized(): + if process_group is None or not dist_ops.is_dist_initialized_and_available(): self.world_size = 1 self.rank = 0 else: - self.world_size = dist.get_world_size(process_group) - self.rank = dist.get_rank(process_group) + self.world_size = dist_ops.get_world_size(process_group) + self.rank = dist_ops.get_rank(process_group) self.decoders = nn.ModuleDict( { @@ -175,8 +175,8 @@ def get_decoder_norms(self) -> torch.Tensor: f"Valid norms shape {valid_norms_sq.shape}, expected size {actual_local_dim}." ) - if self.process_group is not None and dist.is_initialized(): - dist.all_reduce(local_norms_sq_accum, op=dist.ReduceOp.SUM, group=self.process_group) + if self.process_group is not None and dist_ops.is_dist_initialized_and_available(): + dist_ops.all_reduce(local_norms_sq_accum, op=dist_ops.SUM, group=self.process_group) full_decoder_norms[src_layer] = torch.sqrt(local_norms_sq_accum).to(self.dtype) diff --git a/clt/models/encoder.py b/clt/models/encoder.py index 474f729..b9032e1 100644 --- a/clt/models/encoder.py +++ b/clt/models/encoder.py @@ -2,10 +2,10 @@ import torch.nn as nn from typing import Dict, List, Tuple, Optional import logging -import torch.distributed as dist from clt.config import CLTConfig from clt.models.parallel import ColumnParallelLinear +from clt.parallel import ops as dist_ops from torch.distributed import ProcessGroup logger = logging.getLogger(__name__) @@ -31,12 +31,8 @@ def __init__( self.device = device self.dtype = dtype - if process_group is None or not dist.is_initialized(): - self.world_size = 1 - self.rank = 0 - else: - self.world_size = dist.get_world_size(process_group) - self.rank = dist.get_rank(process_group) + self.world_size = dist_ops.get_world_size(process_group) + self.rank = dist_ops.get_rank(process_group) self.encoders = nn.ModuleList( [ diff --git a/clt/models/parallel.py b/clt/models/parallel.py index c657613..13b6760 100644 --- a/clt/models/parallel.py +++ b/clt/models/parallel.py @@ -1,12 +1,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torch.distributed as dist from torch.distributed import ProcessGroup import math from typing import Callable, Optional, cast, Tuple from . import mark_replicated +from clt.parallel import ops as dist_ops class _ParallelLinear(nn.Module): @@ -32,14 +32,12 @@ def __init__( super().__init__() self.process_group = process_group - # Handle non-distributed case - if process_group is None or not dist.is_initialized(): - self.world_size = 1 - self.rank = 0 + # Handle non-distributed case using new utility functions + self.world_size = dist_ops.get_world_size(process_group) + self.rank = dist_ops.get_rank(process_group) + # If world_size is 1, process_group should effectively be None for logic below + if self.world_size == 1: self.process_group = None - else: - self.world_size = dist.get_world_size(process_group) - self.rank = dist.get_rank(process_group) self.partition_dim = partition_dim self.input_is_parallel = input_is_parallel @@ -108,15 +106,16 @@ class _Gather(torch.autograd.Function): @staticmethod def forward(ctx, input_: torch.Tensor, process_group: ProcessGroup, dim: int, full_dim_size: Optional[int]): - if process_group is None or not dist.is_initialized() or dist.get_world_size(process_group) == 1: + # Use new utility functions + if not dist_ops.is_dist_initialized_and_available() or dist_ops.get_world_size(process_group) == 1: ctx.dim = dim ctx.local_dim = input_.size(dim) ctx.full_dim_size = full_dim_size or input_.size(dim) ctx.process_group = None # Mark non-distributed case return input_ - world_size = dist.get_world_size(process_group) - rank = dist.get_rank(process_group) + world_size = dist_ops.get_world_size(process_group) + rank = dist_ops.get_rank(process_group) ctx.dim = dim ctx.local_dim = input_.size(dim) @@ -131,8 +130,8 @@ def forward(ctx, input_: torch.Tensor, process_group: ProcessGroup, dim: int, fu # can track the dependency (no copy!). gathered[rank] = input_contig - # Perform the collective. - dist.all_gather(gathered, input_contig, group=process_group) + # Perform the collective using new utility function wrapper + dist_ops.all_gather(gathered, input_contig, group=process_group) output = torch.cat(gathered, dim=dim) @@ -150,10 +149,15 @@ def backward(ctx, *grad_outputs): grad_output = grad_outputs[0] # Non-distributed: gradient flows straight through. - if ctx.process_group is None or not dist.is_initialized() or dist.get_world_size(ctx.process_group) == 1: + # Use new utility functions + if ( + ctx.process_group is None + or not dist_ops.is_dist_initialized_and_available() + or dist_ops.get_world_size(ctx.process_group) == 1 + ): return grad_output, None, None, None - rank = dist.get_rank(ctx.process_group) + rank = dist_ops.get_rank(ctx.process_group) # Compute start/end indices for this rank's slice along the gather dim. local_dim_padded = ctx.local_dim # Already accounts for padding in weight shape. @@ -179,25 +183,28 @@ class _Reduce(torch.autograd.Function): @staticmethod def forward(ctx, input_: torch.Tensor, process_group: Optional[ProcessGroup]): - if process_group is None or not dist.is_initialized() or dist.get_world_size(process_group) == 1: + # Use new utility functions + if not dist_ops.is_dist_initialized_and_available() or dist_ops.get_world_size(process_group) == 1: ctx.process_group = None # Mark non-distributed case return input_ ctx.process_group = process_group input_contig = input_.contiguous() # Ensure contiguous before collective - # Perform the all-reduce with SUM operation. - # The operation is in-place on input_contig if it's the same object for all_reduce's output internally, - # or if all_reduce returns a new tensor, that's what we return. - # For clarity, let's assume all_reduce modifies input_contig or we assign its result. - dist.all_reduce(input_contig, op=dist.ReduceOp.SUM, group=process_group) + # Perform the all-reduce with SUM operation using new utility function wrapper. + dist_ops.all_reduce(input_contig, op=dist_ops.SUM, group=process_group) # The tensor input_contig now holds the sum. return input_contig @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: # Non-distributed: gradient flows straight through. - if ctx.process_group is None or not dist.is_initialized() or dist.get_world_size(ctx.process_group) == 1: + # Use new utility functions + if ( + ctx.process_group is None + or not dist_ops.is_dist_initialized_and_available() + or dist_ops.get_world_size(ctx.process_group) == 1 + ): # Match the number of forward inputs in return for consistency return grad_output.contiguous() if grad_output is not None else None, None @@ -220,10 +227,11 @@ def _reduce(input_, process_group): and broken optimisation. The caller can always divide afterwards if an average is truly desired, but for the core TP math we need the raw sum. """ - if process_group is None or not dist.is_initialized(): + # Use new utility functions + if not dist_ops.is_dist_initialized_and_available(): return input_ # No-op if not distributed - world_size = dist.get_world_size(process_group) + world_size = dist_ops.get_world_size(process_group) if world_size == 1: return input_ @@ -239,14 +247,15 @@ def _split(input_, process_group, dim=-1): Assumes uniform padding, so each rank gets ceil(full_dim / world_size). Handles truncation for ranks that would exceed the original full dimension. """ - if process_group is None or not dist.is_initialized(): + # Use new utility functions + if not dist_ops.is_dist_initialized_and_available(): return input_ # No-op if not distributed - world_size = dist.get_world_size(process_group) + world_size = dist_ops.get_world_size(process_group) if world_size == 1: return input_ - rank = dist.get_rank(process_group) + rank = dist_ops.get_rank(process_group) full_dim_size = input_.size(dim) # Calculate the size of each slice (using ceil for uniform distribution) @@ -402,10 +411,11 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor: # Add bias *after* reduction if self.bias and self.bias_param is not None: - # Cast bias_param for type checker; runtime None already guarded. + # The runtime check `self.bias_param is not None` is the primary guard. + # Casting `self.bias_param` to `torch.Tensor` helps the type checker. reduced_output = reduced_output + cast(torch.Tensor, self.bias_param) - return cast(torch.Tensor, reduced_output) # Cast to ensure Tensor type + return cast(torch.Tensor, reduced_output) # --------------------------- Public helper --------------------------- # diff --git a/clt/parallel/ops.py b/clt/parallel/ops.py new file mode 100644 index 0000000..a035b23 --- /dev/null +++ b/clt/parallel/ops.py @@ -0,0 +1,117 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp, Work +from typing import Optional, List + +# Re-export common ReduceOp for convenience. +# Users can import these from this module (e.g., from clt.parallel.ops import SUM) +SUM = ReduceOp.SUM +AVG = ReduceOp.AVG +PRODUCT = ReduceOp.PRODUCT +MIN = ReduceOp.MIN +MAX = ReduceOp.MAX +BAND = ReduceOp.BAND +BOR = ReduceOp.BOR +BXOR = ReduceOp.BXOR + + +def is_dist_initialized_and_available() -> bool: + """Checks if torch.distributed is available and initialized.""" + return dist.is_available() and dist.is_initialized() + + +def get_rank(group: Optional[ProcessGroup] = None) -> int: + """Returns the rank of the current process in the group. + Returns 0 if distributed is not initialized or not available. + """ + if not is_dist_initialized_and_available(): + return 0 + return dist.get_rank(group=group) + + +def get_world_size(group: Optional[ProcessGroup] = None) -> int: + """Returns the world size of the given process group. + Returns 1 if distributed is not initialized or not available. + """ + if not is_dist_initialized_and_available(): + return 1 + return dist.get_world_size(group=group) + + +def is_main_process(group: Optional[ProcessGroup] = None) -> bool: + """Checks if the current process is the main process (rank 0).""" + return get_rank(group=group) == 0 + + +def all_reduce( + tensor: torch.Tensor, + op: ReduceOp = SUM, # Default to SUM + group: Optional[ProcessGroup] = None, + async_op: bool = False, +) -> Optional[Work]: + """Reduces the tensor data across all machines. + + Args: + tensor: Input and output of the collective. The function operates in-place. + op: The reduction operation (e.g., ReduceOp.SUM, ReduceOp.PRODUCT). + group: The process group to work on. If None, the default process group will be used. + async_op: Whether this op should be an async op. + + Returns: + A Work object if async_op is True, otherwise None. + Returns None if distributed is not initialized or world_size is 1, as no actual communication occurs. + """ + if not is_dist_initialized_and_available() or get_world_size(group=group) == 1: + return None + return dist.all_reduce(tensor, op=op, group=group, async_op=async_op) + + +def broadcast( + tensor: torch.Tensor, + src: int, + group: Optional[ProcessGroup] = None, + async_op: bool = False, +) -> Optional[Work]: + """Broadcasts the tensor to the whole group. + + Args: + tensor: Data to be sent if src is the rank of current process, + or tensor to be used to save received data otherwise. + src: Source rank. + group: The process group to work on. If None, the default process group will be used. + async_op: Whether this op should be an async op. + + Returns: + A Work object if async_op is True, otherwise None. + Returns None if distributed is not initialized or world_size is 1, as no actual communication occurs. + """ + if not is_dist_initialized_and_available() or get_world_size(group=group) == 1: + return None + return dist.broadcast(tensor, src=src, group=group, async_op=async_op) + + +def all_gather( + tensor_list: List[torch.Tensor], + tensor: torch.Tensor, + group: Optional[ProcessGroup] = None, + async_op: bool = False, +) -> Optional[Work]: + """Gathers tensors from the whole group in a list. + + Args: + tensor_list: Output list. It should contain correctly-sized tensors to be used for output of the collective. + tensor: Tensor to be broadcast from current process. + group: The process group to work on. If None, the default process group will be used. + async_op: Whether this op should be an async op. + + Returns: + A Work object if async_op is True, otherwise None. + If distributed is not initialized, it places the input tensor into tensor_list[0] (assuming single process context). + """ + if not is_dist_initialized_and_available(): + rank = get_rank(group) + if rank < len(tensor_list): + tensor_list[rank] = tensor # pyright: ignore[reportGeneralTypeIssues] + return None + + return dist.all_gather(tensor_list, tensor, group=group, async_op=async_op) diff --git a/clt/training/distributed_utils.py b/clt/training/distributed_utils.py index a1908b6..f0b8179 100644 --- a/clt/training/distributed_utils.py +++ b/clt/training/distributed_utils.py @@ -1,5 +1,5 @@ -import torch.distributed as dist from typing import TYPE_CHECKING +from clt.parallel import ops as dist_ops if TYPE_CHECKING: from clt.models.clt import CrossLayerTranscoder # For type hinting model parameters @@ -30,5 +30,5 @@ def average_shared_parameter_grads(model: "CrossLayerTranscoder", world_size: in # Only average if explicitly marked as replicated. # The p.dim() == 1 heuristic was too broad and could incorrectly average sharded 1D parameters (e.g., encoder biases). if is_rep: - dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + dist_ops.all_reduce(p.grad, op=dist_ops.SUM) p.grad /= world_size diff --git a/tests/unit/test_parallel_ops.py b/tests/unit/test_parallel_ops.py new file mode 100644 index 0000000..3e21b6d --- /dev/null +++ b/tests/unit/test_parallel_ops.py @@ -0,0 +1,177 @@ +import torch +import pytest +from unittest.mock import patch + +from clt.parallel import ops as dist_ops + + +# Fixture to simulate dist not being initialized +@pytest.fixture +def mock_dist_not_initialized(): + with patch("torch.distributed.is_available", return_value=True), patch( + "torch.distributed.is_initialized", return_value=False + ): + yield + + +# Fixture to simulate dist being initialized, single process (world_size=1, rank=0) +@pytest.fixture +def mock_dist_initialized_single_process(): + with patch("torch.distributed.is_available", return_value=True), patch( + "torch.distributed.is_initialized", return_value=True + ), patch("torch.distributed.get_rank", return_value=0), patch("torch.distributed.get_world_size", return_value=1): + yield + + +# Fixture to simulate dist being initialized, multi-process (world_size=2, rank=1 as example) +@pytest.fixture +def mock_dist_initialized_multi_process(): + with patch("torch.distributed.is_available", return_value=True), patch( + "torch.distributed.is_initialized", return_value=True + ), patch("torch.distributed.get_rank", return_value=1), patch("torch.distributed.get_world_size", return_value=2): + yield + + +def test_is_dist_initialized_and_available_not_initialized(mock_dist_not_initialized): + assert not dist_ops.is_dist_initialized_and_available() + + +def test_is_dist_initialized_and_available_initialized(mock_dist_initialized_single_process): + assert dist_ops.is_dist_initialized_and_available() + + +def test_get_rank_not_initialized(mock_dist_not_initialized): + assert dist_ops.get_rank() == 0 + + +def test_get_rank_initialized_single_process(mock_dist_initialized_single_process): + assert dist_ops.get_rank() == 0 + + +def test_get_rank_initialized_multi_process(mock_dist_initialized_multi_process): + # Our mock torch.distributed.get_rank is set to return 1 + assert dist_ops.get_rank() == 1 + + +def test_get_world_size_not_initialized(mock_dist_not_initialized): + assert dist_ops.get_world_size() == 1 + + +def test_get_world_size_initialized_single_process(mock_dist_initialized_single_process): + assert dist_ops.get_world_size() == 1 + + +def test_get_world_size_initialized_multi_process(mock_dist_initialized_multi_process): + # Our mock torch.distributed.get_world_size is set to return 2 + assert dist_ops.get_world_size() == 2 + + +def test_is_main_process_not_initialized(mock_dist_not_initialized): + assert dist_ops.is_main_process() + + +def test_is_main_process_initialized_rank_0(mock_dist_initialized_single_process): + # This fixture sets rank to 0 + assert dist_ops.is_main_process() + + +def test_is_main_process_initialized_rank_1(mock_dist_initialized_multi_process): + # This fixture sets rank to 1 + assert not dist_ops.is_main_process() + + +# Tests for collective wrappers in non-initialized state + + +def test_all_reduce_not_initialized(mock_dist_not_initialized): + tensor = torch.tensor([1.0, 2.0]) + original_tensor = tensor.clone() + work_obj = dist_ops.all_reduce(tensor) + assert work_obj is None + assert torch.equal(tensor, original_tensor) # Should be a no-op + + +def test_all_reduce_initialized_single_process(mock_dist_initialized_single_process): + tensor = torch.tensor([1.0, 2.0]) + original_tensor = tensor.clone() + # We need to mock the actual dist.all_reduce since it might be called if initialized + with patch("torch.distributed.all_reduce") as mock_actual_all_reduce: + work_obj = dist_ops.all_reduce(tensor) + assert work_obj is None # Our wrapper returns None for world_size = 1 + mock_actual_all_reduce.assert_not_called() # Should not call actual dist op + assert torch.equal(tensor, original_tensor) + + +def test_broadcast_not_initialized(mock_dist_not_initialized): + tensor = torch.tensor([1.0, 2.0]) + original_tensor = tensor.clone() + work_obj = dist_ops.broadcast(tensor, src=0) + assert work_obj is None + assert torch.equal(tensor, original_tensor) # Should be a no-op + + +def test_broadcast_initialized_single_process(mock_dist_initialized_single_process): + tensor = torch.tensor([1.0, 2.0]) + original_tensor = tensor.clone() + with patch("torch.distributed.broadcast") as mock_actual_broadcast: + work_obj = dist_ops.broadcast(tensor, src=0) + assert work_obj is None + mock_actual_broadcast.assert_not_called() + assert torch.equal(tensor, original_tensor) + + +def test_all_gather_not_initialized(mock_dist_not_initialized): + tensor = torch.tensor([1.0, 2.0]) + tensor_list = [torch.empty_like(tensor) for _ in range(2)] # Example list + + work_obj = dist_ops.all_gather(tensor_list, tensor) + assert work_obj is None + # In non-initialized case, tensor_list[0] should contain the tensor + assert torch.equal(tensor_list[0], tensor) + # Other elements of tensor_list should remain unchanged if not rank 0 + # (assuming rank is 0 in non-initialized state, as per get_rank logic) + assert torch.equal(tensor_list[1], torch.empty_like(tensor)) # Or its original value + + +def test_all_gather_initialized_single_process(mock_dist_initialized_single_process): + tensor = torch.tensor([1.0, 2.0]) + # For world_size = 1, tensor_list should have at least one element + tensor_list = [torch.empty_like(tensor)] + + with patch("torch.distributed.all_gather") as mock_actual_all_gather: + dist_ops.all_gather(tensor_list, tensor) + # In single process, dist.all_gather may or may not be called by the underlying + # torch.distributed.all_gather depending on its implementation. + # Our wrapper for world_size=1 would try to call it. + # The critical part for our wrapper is that it *should* call the underlying if initialized. + # However, for world_size=1, dist.all_gather itself should effectively + # behave like a copy from input to tensor_list[0]. + + # If dist_ops.all_gather directly handles world_size=1 by not calling dist.all_gather: + # mock_actual_all_gather.assert_not_called() + # assert torch.equal(tensor_list[0], tensor) + + # If dist_ops.all_gather calls dist.all_gather which handles world_size=1: + mock_actual_all_gather.assert_called_once() + # We can't easily assert tensor_list[0] without knowing mock_actual_all_gather's behavior + # For now, just ensure our wrapper attempts the call. + # The behavior of actual dist.all_gather in ws=1 is that it populates tensor_list[rank] + + # Let's refine the logic in dist_ops.all_gather for ws=1 if it's not calling the backend. + # Current `dist_ops.all_gather` calls `dist.all_gather` if initialized. + # So, mock_actual_all_gather *should* be called. + # To test the outcome, we can make the mock_actual_all_gather simulate the copy. + def mock_all_gather_side_effect(out_list, in_tensor, group=None, async_op=False): + out_list[0] = in_tensor # Simulate behavior for rank 0, world_size 1 + return None # Simulate no Work object for sync op + + with patch("torch.distributed.all_gather", side_effect=mock_all_gather_side_effect) as mock_actual_all_gather_ws1: + work_obj_ws1 = dist_ops.all_gather(tensor_list, tensor) + assert work_obj_ws1 is None + mock_actual_all_gather_ws1.assert_called_once() + assert torch.equal(tensor_list[0], tensor) + + +# Example of how one might test a specific ReduceOp re-export +def test_sum_op_export(): + assert dist_ops.SUM == torch.distributed.ReduceOp.SUM