diff --git a/returnn/datasets/basic.py b/returnn/datasets/basic.py index 1be9a49858..42b8cd505a 100644 --- a/returnn/datasets/basic.py +++ b/returnn/datasets/basic.py @@ -24,7 +24,7 @@ from returnn.log import log from returnn.engine.batch import Batch, BatchSetGenerator from returnn.datasets.util.vocabulary import Vocabulary -from returnn.util.basic import try_run, NumbersDict, OptionalNotImplementedError +from returnn.util.basic import get_fwd_compat_kwargs, try_run, NumbersDict, OptionalNotImplementedError from returnn.util import file_cache from returnn.tensor import TensorDict @@ -1050,7 +1050,7 @@ def iterate_seqs(self, recurrent_net=True, used_data_keys=None): :rtype: list[(int,NumbersDict,NumbersDict)] """ if self.custom_chunking_func: - sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None} + sentinel_kw = get_fwd_compat_kwargs() for seq_idx, t_start, t_end in self.custom_chunking_func( dataset=self, seq_idx_start=0, recurrent_net=recurrent_net, used_data_keys=used_data_keys, **sentinel_kw ): diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 5175d3123d..43a60fce1c 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -12,6 +12,7 @@ from returnn.datasets.util.vocabulary import Vocabulary from returnn.tensor import Tensor, TensorDict from returnn.tensor.dim import Dim +from returnn.util.basic import get_fwd_compat_kwargs from .basic import init_dataset from .cached2 import CachedDataset2 @@ -205,9 +206,7 @@ def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDi data_iter = self._iterate_dataset() if self._map_seq_stream is not None: - data_iter = self._map_seq_stream( - data_iter, rng=self._rng, **{f"fwd_compatible_random_kwarg_{self._rng.randint(0, 1000)}": None} - ) + data_iter = self._map_seq_stream(data_iter, rng=self._rng, **get_fwd_compat_kwargs()) assert isinstance( data_iter, Iterator ), f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}" @@ -226,9 +225,7 @@ def _iterate_dataset(self) -> Iterator[TensorDict]: for data_key in data_keys: tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key) if self._map_seq is not None: - tensor_dict = self._map_seq( - tensor_dict, rng=self._rng, **{f"fwd_compatible_random_kwarg_{self._rng.randint(0, 1000)}": None} - ) + tensor_dict = self._map_seq(tensor_dict, rng=self._rng, **get_fwd_compat_kwargs()) assert isinstance( tensor_dict, TensorDict ), f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}" diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py index 3bb2cf8258..83633dfac6 100644 --- a/returnn/torch/distributed.py +++ b/returnn/torch/distributed.py @@ -3,20 +3,77 @@ """ from __future__ import annotations -from typing import Optional, Any, Dict +from abc import abstractmethod, ABC +import logging import os import socket -import logging +from typing import Callable, Optional, Any, Dict, Type, Union import torch from torch.nn.parallel import DistributedDataParallel -from returnn.config import Config -from returnn.util.basic import CollectionReadCheckCovered +from returnn.util.basic import CollectionReadCheckCovered, OptionalNotImplementedError, get_fwd_compat_kwargs _logger = logging.getLogger("returnn.torch.distributed") +class ParamSynchronizer(ABC): + """ + Custom parameter synchronization primitive. 
+
+    Contains a callback that is called after every train step to synchronize model parameters
+    across processes/nodes.
+    """
+
+    @abstractmethod
+    def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **_kwargs):
+        """
+        `__init__` is called after the default global process group is created.
+        Can be used to initialize any additional custom process (sub)groups.
+
+        Note that `__init__` is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
+
+        :param rank: global rank of the current process across all nodes
+        :param size: global world size across all nodes
+        :param local_rank: local rank of the current process on the current node
+        :param local_size: local world size on the current node
+        :param _kwargs: any additional kwargs
+        """
+        super().__init__()
+
+        self.rank = rank
+        self.size = size
+        self.local_rank = local_rank
+        self.local_size = local_size
+
+    def make_distributed_model(self, *, module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
+        """
+        Creates an associated `DistributedDataParallel` for the given module for gradient synchronization.
+
+        This function can be left unimplemented if no gradient synchronization is done.
+
+        Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
+        """
+        raise OptionalNotImplementedError
+
+    @abstractmethod
+    def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs):
+        """
+        Parameter synchronization callback called after every train step with updated model parameters.
+
+        Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
+
+        :param module: the NN being trained
+        :param train_step_idx: the current train step
+        :param kwargs: any additional kwargs
+        """
+        raise NotImplementedError
+
+    def __call__(self, *args, **kwargs):
+        """forwards to :func:`step`"""
+        return self.step(*args, **kwargs)
+
+
 class DistributedContext:
     """
     This class setups some helper functions for torch distributed training
@@ -42,8 +99,13 @@ def __init__(self, options: Dict[str, Any]):
             % (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size)
         )
 
+        self._custom_sync_class: Optional[Union[Callable, Type[ParamSynchronizer]]] = self._opts.get(
+            "synchronizer", None
+        )
+        self._custom_sync: Optional[Callable] = None
         self._reduce_type = self._opts.get("reduce_type", "grad")
         self._param_sync_step: Optional[int] = self._opts.get("param_sync_step", None)
+
         if self._reduce_type == "param":
             assert isinstance(self._param_sync_step, int) and self._param_sync_step > 0, (
                 f"reduce_type param: param_sync_step must be a positive int,"
@@ -52,6 +114,23 @@ def __init__(self, options: Dict[str, Any]):
             _logger.info(f"reduce_type param: param_sync_step {self._param_sync_step}")
         elif self._reduce_type == "grad":
             _logger.info("reduce_type grad")
+        elif self._reduce_type == "custom":
+            if isinstance(self._custom_sync_class, type) and issubclass(self._custom_sync_class, ParamSynchronizer):
+                self._custom_sync = self._custom_sync_class(
+                    rank=self._rank,
+                    size=self._size,
+                    local_rank=self._local_rank,
+                    local_size=self._local_size,
+                    **get_fwd_compat_kwargs(),
+                )
+            elif isinstance(self._custom_sync_class, Callable):
+                self._custom_sync = self._custom_sync_class
+            else:
+                raise ValueError(
+                    f"synchronizer must either be a callable or a class inheriting from {ParamSynchronizer.__name__}"
+                )
+
+            _logger.info(f"reduce_type custom: {type(self._custom_sync)}")
         else:
             raise ValueError(f"invalid reduce_type {self._reduce_type!r}")
 
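Usage note (illustrative, not part of the patch): a minimal sketch of how the new `reduce_type: "custom"` mode could be wired up. The config key `torch_distributed`, the class name, and the plain parameter averaging below are assumptions for the example; only the `ParamSynchronizer` API and the `reduce_type`/`synchronizer` options come from this diff.

```python
# Minimal sketch of a custom synchronizer (assumptions as noted above).
import torch
import torch.distributed as dist

from returnn.torch.distributed import ParamSynchronizer


class AvgParamSynchronizer(ParamSynchronizer):
    """Hypothetical example: average all parameters across ranks after every train step."""

    def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **_kwargs):
        # **_kwargs absorbs the randomly named forward-compat kwarg that RETURNN passes in
        super().__init__(rank=rank, size=size, local_rank=local_rank, local_size=local_size)

    def step(self, *, module: torch.nn.Module, train_step_idx: int, **_kwargs):
        # DistributedContext.step_after_param_update calls this under torch.no_grad()
        for param in module.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)  # SUM + divide also works on Gloo
            param.data /= self.size


# In the RETURNN config (config key name assumed):
torch_distributed = {"reduce_type": "custom", "synchronizer": AvgParamSynchronizer}
```

The callable short form is also accepted: `synchronizer` may be a plain function taking `module`, `train_step_idx`, and a trailing `**kwargs`, in which case no `DistributedDataParallel` wrapping is attempted (see `maybe_make_distributed_module` below).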
@@ -70,6 +149,8 @@ def _check_no_unknown_opts(self): self._opts.get("options") if self._reduce_type == "param": self._opts.get("sync_on_cpu") + if self._reduce_type == "custom": + self._opts.get("synchronizer") self._opts.assert_all_read() @@ -102,7 +183,22 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis """ if self._reduce_type == "param": return None - assert self._reduce_type == "grad" + assert self._reduce_type in ["custom", "grad"] + + if self._reduce_type == "custom": + assert isinstance(self._custom_sync, (ParamSynchronizer, Callable)) + + if isinstance(self._custom_sync, ParamSynchronizer): + try: + return self._custom_sync.make_distributed_model(module=module, **get_fwd_compat_kwargs()) + except OptionalNotImplementedError: + pass + else: + # callable short form does not have support for DistributedDataParallel + pass + + return None + cls = self._opts.get("class", DistributedDataParallel) if cls is not DistributedDataParallel: _logger.warning(f"Using custom class {cls} instead of DistributedDataParallel, might be unsupported.") @@ -115,7 +211,14 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int): """one train step""" - if self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)): + if self._reduce_type == "custom": + with torch.no_grad(): # TODO: do we want this for all syncers? + self._custom_sync( + module=module, + train_step_idx=epoch_step_idx, + **get_fwd_compat_kwargs(), + ) + elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)): _sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False)) @@ -127,7 +230,7 @@ def get_ctx(config=None) -> Optional[DistributedContext]: """ :param Config|None config: :returns: the global context if Torch distributed is enabled, or None otherwise. - If we did not setup the context yet, it will automatically create it. + If we did not set up the context yet, it will automatically create it. """ global _is_set_up, _ctx if _is_set_up: @@ -155,7 +258,7 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False): if sync_on_cpu: for param in module.parameters(): - # Separately move each param to CPU (instead of the whole module), to safe CPU memory. + # Separately move each param to CPU (instead of the whole module), to save CPU memory. param_cpu = param.to(torch.device("cpu")) # On CPU, we are likely using Gloo, and Gloo does not support AVG dist.all_reduce(param_cpu.data, op=dist.ReduceOp.SUM) @@ -166,12 +269,11 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False): if dist.get_backend() == "gloo": # Gloo does not support AVG reduce_op = dist.ReduceOp.SUM + elif hasattr(dist.ReduceOp, "AVG"): + reduce_op = dist.ReduceOp.AVG else: - if hasattr(dist.ReduceOp, "AVG"): - reduce_op = dist.ReduceOp.AVG - else: - # Older PyTorch versions do not have ReduceOp.AVG. - reduce_op = dist.ReduceOp.SUM + # Older PyTorch versions do not have ReduceOp.AVG. 
+ reduce_op = dist.ReduceOp.SUM for param in module.parameters(): dist.all_reduce(param.data, op=reduce_op) diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py index 7b2c81d6fb..49896aecc8 100644 --- a/returnn/torch/engine.py +++ b/returnn/torch/engine.py @@ -17,7 +17,6 @@ from torch.utils.data import DataLoader from torch import autocast from torch.cuda import amp -from random import random import math import returnn @@ -680,7 +679,7 @@ def _run_step( if self._use_autocast else nullcontext() ), rf.set_default_device_ctx(self._device): - sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None} + sentinel_kw = util.get_fwd_compat_kwargs() if train_func: self._train_step_func(model=self._orig_model, extern_data=extern_data, **sentinel_kw) else: @@ -846,7 +845,7 @@ def _load_model(self): if self._use_autocast else nullcontext() ), rf.set_default_device_ctx(self._device): - sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None} + sentinel_kw = util.get_fwd_compat_kwargs() for hook in load_model_post_hooks: hook(model=self._orig_model, **sentinel_kw) @@ -876,7 +875,7 @@ def _create_model(self, *, epoch: int, step: int): get_model_func = self.config.typed_value("get_model") assert get_model_func, "get_model not defined in config" - sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None} + sentinel_kw = util.get_fwd_compat_kwargs() model = get_model_func(epoch=epoch, step=step, **sentinel_kw) self._orig_model = model if isinstance(model, rf.Module): diff --git a/returnn/util/basic.py b/returnn/util/basic.py index 26fb244844..a696d72422 100644 --- a/returnn/util/basic.py +++ b/returnn/util/basic.py @@ -4586,3 +4586,15 @@ def override_env_var(var_name: str, value: str): os.environ[var_name] = cur_val else: os.environ.pop(var_name) + + +_fwd_compat_rng = np.random.default_rng() + + +def get_fwd_compat_kwargs() -> Dict[str, Any]: + """ + Returns a dictionary suitable for passing as kwargs for any RETURNN userland + function where forwards compatibility wrt. additional arguments must be + ensured. + """ + return {f"fwd_compatible_random_kwarg_{_fwd_compat_rng.integers(0, 100)}": None} diff --git a/tools/torch_export_to_onnx.py b/tools/torch_export_to_onnx.py index 154387bcb4..99139acc19 100644 --- a/tools/torch_export_to_onnx.py +++ b/tools/torch_export_to_onnx.py @@ -35,7 +35,6 @@ from typing import Callable, Optional, Dict, List import argparse import os -from random import random import _setup_returnn_env # noqa from returnn.config import Config @@ -204,7 +203,7 @@ def main(): get_model_func = config.typed_value("get_model") assert get_model_func, "get_model() isn't specified in the config passed as a parameter." - sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None} + sentinel_kw = util.get_fwd_compat_kwargs() model = get_model_func(epoch=epoch, step=step, **sentinel_kw) is_rf_module = isinstance(model, rf.Module)
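Usage note (not part of the patch): what `get_fwd_compat_kwargs` implies for userland code. The call sites above pass the randomly named kwarg into config functions such as `get_model` and `train_step`, so those functions must accept a trailing `**kwargs`. A minimal sketch follows; the placeholder model is an assumption, while the keyword names `epoch`/`step` and `model`/`extern_data` come from the engine code in this diff.

```python
# Sketch of forward-compatible config functions. The trailing **_kwargs absorbs the
# randomly named sentinel, e.g. {"fwd_compatible_random_kwarg_37": None}, as well as
# any keyword arguments RETURNN may add in the future.
import torch
from returnn.tensor import TensorDict


def get_model(*, epoch: int, step: int, **_kwargs) -> torch.nn.Module:
    """Placeholder model for illustration; a real config builds the actual network here."""
    return torch.nn.Linear(10, 10)


def train_step(*, model: torch.nn.Module, extern_data: TensorDict, **_kwargs):
    """Placeholder train step; a real config computes losses from extern_data here."""
```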