diff --git a/distml/operator/base_operator.py b/distml/operator/base_operator.py index 892134b..9af342d 100644 --- a/distml/operator/base_operator.py +++ b/distml/operator/base_operator.py @@ -1,6 +1,7 @@ """Abstract class for framework-specific training operators.""" from abc import ABCMeta from abc import abstractmethod +from typing import Optional class TrainingOperator(metaclass=ABCMeta): @@ -90,7 +91,7 @@ def load_custom_states(self, states, *args, **kwargs): pass @abstractmethod - def save_states(self, checkpoint): + def save_states(self, checkpoint: str): """Save the states to a file path. This function shall be instantiated in framework-specific operator @@ -104,7 +105,10 @@ def get_states(self): raise NotImplementedError() @abstractmethod - def load_states(self, checkpoint): + def load_states(self, + states=None, + checkpoint: Optional[str] = None, + keys: Optional[bool] = None): """Load the states from a file path. This functions shall be instantiated in framework-specific operators diff --git a/distml/operator/jax_operator.py b/distml/operator/jax_operator.py index 05a700d..a738640 100644 --- a/distml/operator/jax_operator.py +++ b/distml/operator/jax_operator.py @@ -1,5 +1,12 @@ +import os +import pickle +import warnings + +from typing import Any, Mapping, Optional, List, Dict + import numpy as np import cupy as cp + import jax from jax import value_and_grad import jax.numpy as jnp @@ -13,7 +20,7 @@ class JAXTrainingOperator(TrainingOperator): - def __init__(self, operator_config): + def __init__(self, *, operator_config: Optional[Mapping[str, Any]]): super(JAXTrainingOperator, self).__init__(operator_config) # Should be set by users in the `register` function. # model methods @@ -26,11 +33,14 @@ def __init__(self, operator_config): self.get_params = None self.criterion = None + self.lr_scheduler = None # Data loaders for training and validation, registered by users. self._train_loader = None self._validation_loader = None + self._custom_states = None + self.setup(operator_config) if hasattr(operator_config, "jit_mode"): @@ -63,7 +73,7 @@ def setup(self, *args, **kwargs): raise NotImplementedError("Please override this function to register " "your model, optimizer, and criterion.") - def register(self, *, model, optimizer, criterion, jit_mode=False): + def register(self, *, model, optimizer, criterion, jit_mode: bool = False): """Register a few critical information about the model to operator. Args: @@ -93,7 +103,7 @@ def register(self, *, model, optimizer, criterion, jit_mode=False): "'opt_init', 'opt_update' and 'get_params'." "Got: {} {}".format(type(optimizer), len(optimizer))) - if not hasattr(criterion, "__call__"): + if not callable(criterion): raise RuntimeError( "The `criterion` must be callable function that " "feed logits and target, return the loss value. " @@ -113,12 +123,12 @@ def _register_model(self, model): "`opt_states` return from optimizer `opt_init`. " "Got: {}".format(type(model[0]))) - if not hasattr(model[1], "__call__"): + if not callable(model[1]): raise RuntimeError("The second elemente of `model` must be the " "`init_fun` return from model. " "Got: {}".format(type(model[1]))) - if not hasattr(model[2], "__call__"): + if not callable(model[2]): raise RuntimeError("The third elemente of `model` must be the " "`predict_fun` return from model. 
" "Got: {}".format(type(model[2]))) @@ -129,18 +139,18 @@ def _register_model(self, model): def _register_optimizer(self, optimizer): """register optimizer components.""" - if not hasattr(optimizer[0], "__call__"): + if not callable(optimizer[0]): raise RuntimeError("The fist elemente of `optimizer` must be the " "`opt_init` return from optimizer. " "Got: {}".format(type(optimizer[1]))) - if not hasattr(optimizer[1], "__call__"): + if not callable(optimizer[1]): raise RuntimeError( "The second elemente of `optimizer` must be the " "`opt_update` return from optimizer. " "Got: {}".format(type(optimizer[1]))) - if not hasattr(optimizer[2], "__call__"): + if not callable(optimizer[2]): raise RuntimeError("The third elemente of `optimizer` must be the " "`get_params` return from optimizer. " "Got: {}".format(type(optimizer[2]))) @@ -264,15 +274,15 @@ def validate_batch(self, batch): targets_class = jnp.argmax(targets, axis=1) acc = jnp.mean(prediction_class == targets_class) - samples_num = targets.shape[0] + num_sample = targets.shape[0] return { "val_loss": loss.item(), "val_accuracy": acc.item(), - "samples_num": samples_num + "num_sample": num_sample } - def get_parameters(self, cpu): + def get_parameters(self, cpu: bool) -> List: """get the flatten parameters.""" params = self.get_params(self.opt_state) flatten_params, tree = tree_flatten(params) @@ -281,9 +291,11 @@ def get_parameters(self, cpu): if cpu: flatten_params = list(map(np.asarray, flatten_params)) + else: + flatten_params = list(map(jnp.asarray, flatten_params)) return flatten_params - def get_named_parameters(self, cpu): + def get_named_parameters(self, cpu: bool) -> Dict: """Get the named parameters. In jax, we need to construct a dict to contain the parameters. @@ -296,6 +308,7 @@ def get_named_parameters(self, cpu): } else: dict_params = {f"{idx}": p for idx, p in enumerate(params)} + return dict_params # TODO(HUI): used in load states or load parameters @@ -309,6 +322,9 @@ def set_parameters(self, new_params): """ assert isinstance(new_params, dict) + # make sure all params in GPU. Should be controlled of use_gpu. + new_params = {k: jax.device_put(v) for k, v in new_params.items()} + keys, new_params = unzip2( sorted(new_params.items(), key=lambda d: int(d[0]))) self.preset_keys = keys @@ -334,7 +350,7 @@ def update(param, state): zip(subtrees, new_subtrees)): if new_subtree != subtree: msg = ( - "input structur did not match the save params struture. " + "input structure did not match the save params structure. " "input {} and output {}.") raise TypeError(msg.format(subtree, new_subtree)) @@ -346,29 +362,153 @@ def reset_optimizer_for_params(self, params): "Got {}".format(type(params))) keys, params = unzip2(sorted(params.items(), key=lambda d: int(d[0]))) + + self.preset_keys = keys # The keys to index the params. 
self.tree = tree_structure(params) self.opt_state = self.opt_init(params) + def ones(self, shape, cpu: bool = True): + if cpu: + return np.ones(shape) + else: + return jnp.ones(shape) + + def zeros(self, shape, cpu: bool = True): + if cpu: + return np.zeros(shape) + else: + return jnp.zeros(shape) + + def ones_like(self, x, cpu: bool = True): + if cpu: + return np.ones_like(x) + else: + return jnp.ones_like(x) + + def zeros_like(self, x, cpu: bool = True): + if cpu: + return np.zeros_like(x) + else: + return jnp.zeros_like(x) + + def numel(self, v): + return np.size(v) + + def asarray(self, v): + return jnp.asarray(v) + def clean_redundancy(self): - del self._train_loader - del self._validation_loader + if self._train_loader: + del self._train_loader + self._train_loader = None + if self._validation_loader: + del self._validation_loader + self._validation_loader = None - # TODO(HUI): use pickle to serialize parameters or states and save it. - def save_parameters(self, checkpoint): - raise NotImplementedError( - "save_parameters is not support in jax operator.") + def register_custom_states(self, custom_states): + self._custom_states = custom_states - def load_parameters(self, checkpoint): - raise NotImplementedError( - "load_parameters is not support in jax operator.") + def get_custom_states(self): + return self._custom_states - def save_states(self, checkpoint): - raise NotImplementedError( - "save_states is not support in jax operator.") + def get_states(self) -> Dict: + """Return the states of this training operator.""" - def get_states(self): - raise NotImplementedError("get_states is not support in jax operator.") + states_flat, tree, subtrees = self.opt_state + + states_unflat = map(tree_unflatten, subtrees, states_flat) + + states_unflat_dict = { + str(idx): value + for idx, value in enumerate(states_unflat) + } - def load_states(self, checkpoint): - raise NotImplementedError( - "load_states is not support in jax operator.") + states = { + "opt_state": states_unflat_dict, + } + + if self._custom_states: + states.update({"custom": self.get_custom_states()}) + + if self.lr_scheduler and hasattr(self.lr_scheduler, + "get_state_dict()"): + states.update({"lr_scheduler": self.lr_scheduler.get_state_dict()}) + + return states + + def save_states(self, checkpoint: str): + states = self.get_states() + with open(checkpoint, "wb") as f: + pickle.dump(states, f) + + def load_states(self, + states=None, + checkpoint: Optional[str] = None, + keys: Optional[bool] = None): + if checkpoint: + assert ".pkl" in checkpoint, \ + "checkpoint should be a .pkl file. 
Got {}".format(checkpoint) + if not os.path.exists(checkpoint): + raise RuntimeError("Checkpoint file doesn't exists.") + with open(checkpoint, "rb") as f: + states = pickle.load(f) + + if states: + new_opt_states = states.get("opt_state", None) + custom_states = states.get("custom_states", None) + lr_scheduler_states = states.get("lr_scheduler", None) + + if not new_opt_states: + raise RuntimeError("subtrees of new params is empty.") + + assert isinstance(new_opt_states, dict) + + if not keys: + keys = tuple([ + str(idx) + for idx in range(len(self.get_parameters(cpu=False))) + ]) + else: + # construct_opt_states_dict = OrderedDict() + construct_opt_states_dict = dict() + for key in keys: + construct_opt_states_dict[key] = new_opt_states[key] + new_opt_states = construct_opt_states_dict + + new_keys, new_opt_states = unzip2( + sorted(new_opt_states.items(), key=lambda d: int(d[0]))) + + keys = tuple(keys) + new_keys = tuple(new_keys) + assert keys == new_keys, \ + "checkpoint key doesn't match the model params." + + states_flat, tree, subtrees = self.opt_state + states_flat_2, subtrees_2 = unzip2( + map(tree_flatten, new_opt_states)) + + if not subtrees_2: + raise RuntimeError("subtrees of new params is empty.") + for idx, (subtree, subtree_2) in enumerate( + zip(subtrees, subtrees_2)): + if subtree_2 != subtree: + msg = ("input structure did not match the save params " + "structure. input {} and output {}.") + raise TypeError(msg.format(subtree, subtree_2)) + + self.opt_state = OptimizerState(states_flat_2, tree, subtrees_2) + + if custom_states: + self._custom_states.update(custom_states) + + if lr_scheduler_states: + if hasattr(self.lr_scheduler, "set_states_dict"): + self.lr_scheduler.set_states_dict(lr_scheduler_states) + else: + warnings.warn( + "lr scheduler must have `set_states_dict` method" + " to support loading lr scheduler states.") + else: + raise RuntimeError("This checkpoint is empty." + "Got checkpoint {}, states {}".format( + checkpoint, states)) diff --git a/distml/operator/torch_operator.py b/distml/operator/torch_operator.py index 3ddd667..d0a98f8 100644 --- a/distml/operator/torch_operator.py +++ b/distml/operator/torch_operator.py @@ -168,9 +168,55 @@ def validate_batch(self, batch): loss = criterion(output, target) # Todo(Hao): report accuracy instead loss here. - batch_metric = {"val_loss": loss.item()} + batch_metric = {"val_loss": loss.item(), "num_sample": target.size(0)} return batch_metric + def get_named_parameters(self, cpu): + named_params = self._model.named_parameters() + is_cuda = next(self._model.parameters()).is_cuda + output_params = {} + + if cpu: + if is_cuda: + for key, p in named_params: + output_params[key] = p.cpu() + else: + for key, p in named_params: + output_params[key] = p + else: + if not is_cuda: + for key, p in named_params: + # TODO(HUI): should put in specific device. + named_params[key] = p.cuda() + else: + for key, p in named_params: + output_params[key] = p + + return output_params + + def get_parameters(self, cpu): + params = self._model.parameters() + is_cuda = next(self._model.parameters()).is_cuda + output_params = [] + + if cpu: + if is_cuda: + for p in params: + output_params.append(p.cpu()) + else: + for p in params: + output_params.append(p) + else: + if not is_cuda: + for idx, p in enumerate(params): + # TODO(HUI): should put in specific device. 
+ output_params(p.cuda()) + else: + for p in params: + output_params.append(p) + + return output_params + def get_states(self): """Return the states of this training operator.""" states = { @@ -196,12 +242,47 @@ def load_states(self, states=None, checkpoint=None): self._lr_scheduler.load_state_dict(states["lr_scheduler"]) self.load_custom_states(states["custom"]) + def _load_from_checkpoint(self, checkpoint): + return torch.load(checkpoint) + def save_states(self, checkpoint): """Save the states to a file path.""" states = self.get_states() # TODO(Hao): test this. torch.save(states, checkpoint) + def clean_redundancy(self): + del self._train_loader + del self._validation_loader + + def set_parameters(self, params): + if isinstance(params, dict): + self._model.load_state_dict(params) + else: + raise RuntimeError("params is not dict." + "Got {}".format(type(params))) + + def reset_optimizer_for_params(self, params): + if isinstance(params, dict): + params_list = [] + + for k, v in params.items(): + params_list.append(v) + params = params_list + + _optimizer = self._optimizer + + _optimizer.param_groups = [] + + param_groups = list(params) + if len(param_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + if not isinstance(param_groups[0], dict): + param_groups = [{'params': param_groups}] + + for param_group in param_groups: + _optimizer.add_param_group(param_group) + @staticmethod def _get_gradients(model): """Return the gradient updates of the model as a Python dict. @@ -241,3 +322,26 @@ def _set_gradients(model, grads): # to(p.grad.device) # else: # p.grad = torch.from_numpy(gradients[name]) + + def ones(self, shape, cpu: bool = True): + tensor = torch.ones(shape) + return tensor if cpu else tensor.cuda() + + def zeros(self, shape, cpu: bool = True): + tensor = torch.zeros(shape) + return tensor if cpu else tensor.cuda() + + def ones_like(self, x, cpu: bool = True): + tensor = torch.ones_like(x) + return tensor if cpu else tensor.cuda() + + def zeros_like(self, x, cpu: bool = True): + tensor = torch.zeros_like(x) + return tensor if cpu else tensor.cuda() + + @staticmethod + def numel(tensor): + return tensor.numel() + + def asarray(self, v): + return torch.as_tensor(v) diff --git a/distml/strategy/allreduce_strategy.py b/distml/strategy/allreduce_strategy.py index 5eb92b8..1cc3d82 100644 --- a/distml/strategy/allreduce_strategy.py +++ b/distml/strategy/allreduce_strategy.py @@ -1,12 +1,14 @@ import logging +from typing import List, Callable, Mapping, Any, Optional, Dict import ray import ray.util.collective as col -from distml.strategy.base_strategy import BaseStrategy -from distml.util import ThroughputCollection +from ray.util.sgd.utils import AverageMeterCollection import numpy as np +from distml.strategy.base_strategy import BaseStrategy, BaseDataParallelGroup + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -29,30 +31,32 @@ class AllReduceStrategy(BaseStrategy): def __init__(self, *, training_operator_cls, - operator_config=None, - initialization_hook=None, - world_size=2, - num_cpus_per_worker=1, - num_gpus_per_worker=1, + operator_config: Optional[Mapping[str, Any]] = None, + initialization_hook: Optional[Callable] = None, + world_size: int = 2, + backend: str = "nccl", + group_name: str = "default", + num_cpus_per_worker: int = 1, + num_gpus_per_worker: int = 1, **kwargs): super(AllReduceStrategy, self). 
\ __init__(training_operator_cls=training_operator_cls, operator_config=operator_config, initialization_hook=initialization_hook, world_size=world_size, + backend=backend, + group_name=group_name, num_cpus_per_worker=num_cpus_per_worker, num_gpus_per_worker=num_gpus_per_worker, **kwargs) self._global_batch_size = None + if operator_config and operator_config.get("batch_size"): self._global_batch_size = operator_config.get("batch_size") - if self._global_batch_size: - self._collector = ThroughputCollection( - batch_size=self._global_batch_size) - else: - self._collector = ThroughputCollection() - def train(self, num_steps=None): + self._init_strategy() + + def train(self, num_steps: Optional[int] = None) -> Dict: """Run the training on parallel workers. Args: @@ -60,8 +64,10 @@ def train(self, num_steps=None): function will simply train for one epoch. Returns: - None + metric (dict): metric of the training set. """ + # TODO(HUI): metric use hook to control. + # TODO (Hao): add fault tolerance using `max_retries`. steps = num_steps if num_steps \ else self.data_parallel_group.get_data_loader_len() @@ -69,30 +75,46 @@ def train(self, num_steps=None): # TODO(Hao): this call should be hidden inside Replica. self.data_parallel_group.make_iterator() for idx in range(steps): - with self._collector.record("train"): - metrics = self.data_parallel_group.train_batch() + metric = self.data_parallel_group.train_batch() print("Step: {}/{}".format(idx, steps)) - return metrics + return metric - def validate(self, num_steps=None): + def validate(self, num_steps: Optional[int] = None) -> Dict: """Evaluates the model on the validation data. Args: num_steps (int): number of batches to evaluate. If None, the function will simply validate across the entire validation dataset. + + Returns: + metric (dict): metric of the validate set. """ steps = num_steps if num_steps \ else self.data_parallel_group.get_data_loader_len(training=False) + + metrics = [ + AverageMeterCollection() + for _ in range(len(self.data_parallel_group.replicas)) + ] + self.data_parallel_group.make_iterator(training=False) for idx in range(steps): - with self._collector.record("validate"): - batch_metrics = self.data_parallel_group.validate_batch() - self._collector.update( - "validate", val_acc=batch_metrics[0]["val_loss"]) - self._collector.save("validate") + batch_metrics = self.data_parallel_group.validate_batch() + + for metric_idx, metric in enumerate(batch_metrics): + num_sample = metric.pop("num_sample") + metrics[metric_idx].update(metric, n=num_sample) + # TODO: validate result should be the same in all workers - return batch_metrics + return metrics[0].summary() + + def _init_strategy(self): + """Do initialization for the distributed strategy.""" + # All sync with replica 0 + init_weights = self.data_parallel_group.get_named_parameters(cpu=True) + # all replicas get synced + self.data_parallel_group.set_parameters(init_weights) def _start_workers(self): """Create distributed workers on the Ray cluster for distributed training. @@ -111,30 +133,30 @@ def _start_workers(self): # (2) params for setting up collective group and strategy prep-ups. 
         dist_params = dict(
             strategy="allreduce",
-            backend="nccl",
-            group_name="default",
+            backend=self.backend,
+            group_name=self.group_name,
         )
         group_init_args = dict(
-            replica_params=replica_params,
+            actor_params=replica_params,
             dist_params=dist_params,
             initialization_hook=self.initialization_hook,
-            num_cpus_per_worker=self.num_cpus_per_worker,
-            num_gpus_per_worker=self.num_gpus_per_worker)
+            num_cpus_per_actor=self.num_cpus_per_worker,
+            num_gpus_per_actor=self.num_gpus_per_worker)
         self.data_parallel_group = DataParallelGroup(**group_init_args)

         # Once the group is created, we start it.
         self.data_parallel_group.start_replicas(self.world_size)

-    def shutdown(self, force=False):
+    def shutdown(self, force: bool = False):
         self.data_parallel_group.shutdown(force=force)

-    def save_parameters(self, checkpoint):
-        self.data_parallel_group.save_parameters(checkpoint)
+    def get_states(self):
+        return self.data_parallel_group.get_states()

-    def load_parameters(self, checkpoint):
-        self.data_parallel_group.load_parameters(checkpoint)
+    def save_states(self, checkpoint: str):
+        self.data_parallel_group.save_states(checkpoint)

-    def _init_strategy(self):
-        pass
+    def load_states(self, states=None, checkpoint: Optional[str] = None):
+        self.data_parallel_group.load_states(states, checkpoint)

 class Replica:
@@ -144,7 +166,8 @@ class Replica:
     and Ray collective group setup.
     """

-    def __init__(self, training_operator_cls, operator_config):
+    def __init__(self, training_operator_cls,
+                 operator_config: Optional[Mapping[str, Any]]):
         self.training_operator_cls = training_operator_cls
         self.operator_config = operator_config
         # Training operator
@@ -165,17 +188,17 @@ def setup_operator(self):
             operator_config=self.operator_config)

     def setup_collective_group(self,
-                               rank,
-                               world_size,
-                               backend,
-                               group_name="default"):
+                               rank: int,
+                               world_size: int,
+                               backend: str,
+                               group_name: str = "default"):
         self._rank = rank
         self._group_name = group_name
         self._world_size = world_size
         col.init_collective_group(
             world_size, rank, backend=backend, group_name=group_name)

-    def make_iterator(self, training=True):
+    def make_iterator(self, training: bool = True):
         """Convert loader to be an iterator at the start of an epoch."""
         # TODO(Hao): need to check whether reaching the boundary of iterator
         # instead of making a new one every time.
@@ -184,7 +207,7 @@
         else:
             self.validation_iterator = iter(self.validation_loader)

-    def get_data_loader_len(self, training=True):
+    def get_data_loader_len(self, training: bool = True) -> int:
         """Return the number of batches in the data loader."""
         loader = self.train_loader if training \
             else self.validation_loader
@@ -195,11 +218,11 @@
                 "Data loader has no attribute `__len__`. "
                 "Please set `num_steps` in `train()` or `validate()`.")

-    def train_batch(self):
+    def train_batch(self) -> Dict:
         metrics = {}
         try:
             batch = next(self.train_iterator)
-        except StopIteration and NameError:
+        except (StopIteration, NameError):
             self.make_iterator()
             batch = next(self.train_iterator)
         loss_val, updates = self.derive_updates(batch)
@@ -208,17 +231,15 @@
         metrics["train_loss"] = loss_val
         for _, g in updates.items():
             cg = self.training_operator.to_cupy(g)
-            col.allreduce(cg)
-            # TODO(Hao): this is conflicting with Runhui's code though.
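+            # Sum each gradient across replicas with allreduce, then scale by
+            # 1/world_size to average it.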
+ col.allreduce(cg, self.group_name) cg = cg / float(self.world_size) self.apply_updates(updates) return metrics - def derive_updates(self, batch): + def derive_updates(self, batch) -> Dict: return self.training_operator.derive_updates(batch) def apply_updates(self, updates): - # TODO(Hao): conflicting with Runhui's code on averaging grads self.training_operator.apply_updates(updates) def updates_transform(self, updates): @@ -227,7 +248,7 @@ def updates_transform(self, updates): def validate_batch(self): try: batch = next(self.validation_iterator) - except StopIteration and NameError: + except StopIteration or NameError: self.make_iterator(training=False) batch = next(self.validation_iterator) batch_metric = self.training_operator.validate_batch(batch) @@ -240,13 +261,25 @@ def shutdown(self): del self.training_operator return 1 - def save_parameters(self, checkpoint): - self.training_operator.save_parameters(checkpoint) + def get_states(self): + return self.training_operator.get_states() - def load_parameters(self, checkpoint): - self.training_operator.load_parameters(checkpoint) + def save_states(self, checkpoint: str): + self.training_operator.save_states(checkpoint) - def apply(self, fn): + def load_states(self, states=None, checkpoint: Optional[str] = None): + self.training_operator.load_states(states, checkpoint) + + def get_parameters(self, cpu: bool) -> List: + return self.training_operator.get_parameters(cpu) + + def get_named_parameters(self, cpu: bool) -> Dict: + return self.training_operator.get_named_parameters(cpu) + + def set_parameters(self, params): + self.training_operator.set_parameters(params) + + def apply(self, fn: Callable): """Apply a function in the replica process.""" return fn() @@ -271,45 +304,41 @@ def group_name(self): return self._group_name -class DataParallelGroup: - """Spawn a group a replicas for data-parallel training.""" +class DataParallelGroup(BaseDataParallelGroup): + """Spawn a replica group for data-parallel training.""" - def __init__(self, replica_params, dist_params, initialization_hook, - num_cpus_per_worker, num_gpus_per_worker): - self._replica_params = replica_params - self._dist_params = dist_params - - # try to unroll the dist_params - self._backend = self._dist_params["backend"] - self._group_name = self._dist_params["group_name"] - - self._initialization_hook = initialization_hook - self._num_cpus_per_worker = num_cpus_per_worker - self._num_gpus_per_worker = num_gpus_per_worker - self._replicas = None - - @property - def replicas(self): - return self._replicas + def __init__(self, actor_params: Mapping[str, Any], + dist_params: Mapping[str, Any], num_cpus_per_actor: int, + num_gpus_per_actor: int, + initialization_hook: Optional[Callable]): + super(DataParallelGroup, self).__init__( + actor_params=actor_params, + dist_params=dist_params, + num_cpus_per_actor=num_cpus_per_actor, + num_gpus_per_actor=num_gpus_per_actor, + initialization_hook=initialization_hook) @property - def world_size(self): - return len(self._replicas) + def _replica_params(self): + return self._actor_params @property - def backend(self): - return self._backend + def replicas(self): + return self._actors @property def group_name(self): return self._group_name - def start_replicas(self, num_replicas): + def start_replicas(self, num_replicas: int): + self._start_actors(num_replicas) + + def _start_actors(self, num_replicas: int): assert num_replicas > 1 RemoteReplica = ray.remote( - num_cpus=self._num_cpus_per_worker, - num_gpus=self._num_gpus_per_worker)(Replica) - 
self._replicas = [ + num_cpus=self._num_cpus_per_actor, + num_gpus=self._num_gpus_per_actor)(Replica) + self._actors = [ RemoteReplica.remote(**self._replica_params) for _ in range(num_replicas) ] @@ -327,16 +356,16 @@ def start_replicas(self, num_replicas): operator_setups = self._setup_operator() ray.get(operator_setups) - def _make_iterator(self, training): + def _make_iterator(self, training: bool): return [ replica.make_iterator.remote(training=training) for replica in self.replicas ] - def make_iterator(self, training=True): + def make_iterator(self, training: bool = True): ray.get(self._make_iterator(training=training)) - def get_data_loader_len(self, training=True): + def get_data_loader_len(self, training: bool = True): """Return the number of batches in the data loader.""" lens = ray.get([ replica.get_data_loader_len.remote(training=training) @@ -361,7 +390,7 @@ def validate_batch(self): stats = ray.get(rets) return stats - def shutdown(self, force=False): + def shutdown(self, force: bool = False): rets = [replica.shutdown.remote() for replica in self.replicas] stats = ray.get(rets) return stats @@ -369,14 +398,18 @@ def shutdown(self, force=False): def reset(self): pass - def save_parameters(self, checkpoint): - rets = [self.replicas[0].save_parameters.remote(checkpoint)] + def get_states(self): + rets = [self.replicas[0].get_states.remote()] + return ray.get(rets)[0] + + def save_states(self, checkpoint: str): + rets = [self.replicas[0].save_states.remote(checkpoint)] ray.get(rets) - def load_parameters(self, checkpoint): + def load_states(self, states=None, checkpoint: Optional[str] = None): rets = [ - replica.load_parameters.remote(checkpoint) - for _, replica in enumerate(self.replicas) + replica.load_states.remote(states, checkpoint) + for replica in self.replicas ] ray.get(rets) @@ -387,15 +420,15 @@ def set_parameters(self, params): ] ray.get(rets) - def get_parameters(self, cpu=False): + def get_parameters(self, cpu: bool = False): ret = self.replicas[0].get_parameters.remote(cpu) return ray.get(ret)[0] - def get_named_parameters(self, cpu=False): + def get_named_parameters(self, cpu: bool = False): ret = self.replicas[0].get_named_parameters.remote(cpu) return ray.get([ret])[0] - def apply_all_replicas(self, fn): + def apply_all_replicas(self, fn: Callable): """Apply fn in all replica processes and wait until completion.""" return ray.get(self._apply_all_replicas(fn)) @@ -404,13 +437,13 @@ def _apply_all_replicas(self, fn): return [replica.apply.remote(fn) for replica in self.replicas] def _setup_collective_group(self, - world_size, - backend, - group_name="default"): + group_size: int, + backend: int, + group_name: str = "default"): refs = [ replica.setup_collective_group.remote( rank=i, - world_size=world_size, + world_size=group_size, backend=backend, group_name=group_name) for i, replica in enumerate(self.replicas) diff --git a/distml/strategy/base_strategy.py b/distml/strategy/base_strategy.py index 69e3b0a..3f31237 100644 --- a/distml/strategy/base_strategy.py +++ b/distml/strategy/base_strategy.py @@ -1,6 +1,7 @@ from abc import ABCMeta from abc import abstractmethod import logging +from typing import Callable, Any, Mapping, Optional, Sequence import ray @@ -11,11 +12,13 @@ class BaseStrategy(metaclass=ABCMeta): def __init__(self, *, training_operator_cls, - operator_config=None, - initialization_hook=None, - world_size=2, - num_cpus_per_worker=1, - num_gpus_per_worker=1, + operator_config: Optional[Mapping[str, Any]] = None, + initialization_hook: 
Optional[Callable] = None, + world_size: int = 2, + backend: str = "nccl", + group_name: str = "default", + num_cpus_per_worker: int = 1, + num_gpus_per_worker: int = 1, **kwargs): self.training_operator_cls = training_operator_cls self.initialization_hook = initialization_hook @@ -24,6 +27,8 @@ def __init__(self, "ray.util.distml does not support single-process training " "at this moment.") self.world_size = world_size + self.backend = backend + self.group_name = group_name self.num_cpus_per_worker = num_cpus_per_worker self.num_gpus_per_worker = num_gpus_per_worker self._operator_config = {} if not operator_config \ @@ -47,7 +52,7 @@ def validate(self): raise NotImplementedError() @abstractmethod - def save_parameters(self, checkpoint): + def save_states(self, checkpoint: str): """Saves the Trainer state to the provided checkpoint path. Args: @@ -56,7 +61,18 @@ def save_parameters(self, checkpoint): raise NotImplementedError() @abstractmethod - def load_parameters(self, checkpoint): + def load_states(self, + states=None, + checkpoint: Optional[str] = None, + keys: Optional[Sequence[str]] = None): + """Saves the Trainer state to the provided checkpoint path. + + Args: + states: States to load. + checkpoint (str): Path to target checkpoint file. + keys (str): Keys of the params to load. + If None, using all states. + """ raise NotImplementedError() @abstractmethod @@ -70,6 +86,117 @@ def _init_strategy(self): raise NotImplementedError() @abstractmethod - def shutdown(self, force=False): + def shutdown(self, force: bool = False): """Kill all workers.""" raise NotImplementedError() + + +class BaseDataParallelGroup: + """Spawn a actor group for data-parallel training.""" + + def __init__(self, actor_params: Mapping[str, Any], + dist_params: Mapping[str, Any], num_cpus_per_actor: int, + num_gpus_per_actor: int, + initialization_hook: Optional[Callable], **kwargs): + self._actor_params = actor_params + self._dist_params = dist_params + self._backend = self._dist_params["backend"] + self._group_name = self._dist_params["group_name"] + self._num_cpus_per_actor = num_cpus_per_actor + self._num_gpus_per_actor = num_gpus_per_actor + self._initialization_hook = initialization_hook + + # try to unroll the dist_params + self._backend = self._dist_params["backend"] + self._group_name = self._dist_params["group_name"] + + @property + def world_size(self): + return len(self._actors) + + @property + def backend(self): + return self._backend + + @property + def group_name(self): + return self._group_name + + @abstractmethod + def _setup_collective_group(self, *args, **kwargs): + """All actors setup operators.""" + raise NotImplementedError() + + @abstractmethod + def setup_operator(self): + """All actors setup operators.""" + raise NotImplementedError() + + @abstractmethod + def _start_actors(self, num_actors): + """Start all actors.""" + raise NotImplementedError() + + @abstractmethod + def make_iterator(self, training: bool = True): + """Make iterator.""" + raise NotImplementedError() + + @abstractmethod + def get_data_loader_len(self, training: bool = True): + """Return the number of batches in the data loader.""" + raise NotImplementedError() + + @abstractmethod + def validate_batch(self): + """Validate one batch and return batch metrics.""" + raise NotImplementedError() + + @abstractmethod + def shutdown(self, force: bool = False): + """Shutdown all actors.""" + raise NotImplementedError() + + @abstractmethod + def reset(self): + """Reset group.""" + raise NotImplementedError() + + @abstractmethod + 
def save_states(self, checkpoint: str): + """Saves the Trainer state to the provided checkpoint path. + + Args: + checkpoint (str): Path to target checkpoint file. + """ + raise NotImplementedError() + + @abstractmethod + def load_states(self, + states=None, + checkpoint: Optional[str] = None, + keys: Optional[Sequence[str]] = None): + """Saves the Trainer state to the provided checkpoint path. + + Args: + states: States to load. + checkpoint (str): Path to target checkpoint file. + keys (str): Keys of the params to load. + If None, using all states. + """ + raise NotImplementedError() + + @abstractmethod + def set_parameters(self, params): + """Input params and replace the model parameters.""" + raise NotImplementedError() + + @abstractmethod + def get_parameters(self, cpu: bool = False): + """Return parameters from the first actor.""" + raise NotImplementedError() + + @abstractmethod + def get_named_parameters(self, cpu: bool = False): + """Return named parameters from the first actor.""" + raise NotImplementedError() diff --git a/distml/strategy/ps_strategy.py b/distml/strategy/ps_strategy.py new file mode 100644 index 0000000..6709069 --- /dev/null +++ b/distml/strategy/ps_strategy.py @@ -0,0 +1,808 @@ +import logging +from typing import List, Callable, Mapping, Any, Optional, Sequence, Dict + +import ray +import ray.util.collective as col +from ray.util.sgd.utils import AverageMeterCollection + +import numpy as np + +import distml.util as util +from distml.strategy.base_strategy import BaseStrategy, BaseDataParallelGroup + +logger = logging.getLogger(__name__) + + +class ParameterServerStrategy(BaseStrategy): + """Strategy that trains a model via parameter server. + + Args: + training_operator_cls (TrainingOperator): + Custom training operator class. + operator_config (dict): operator config specified by users. + initialization_hook (function): A function to call on all training + workers when they are first initialized. This could be useful to + set environment variables for all the worker processes. + num_worker (int): The number of workers. + num_ps (int): The number of parameter servers. + num_cpus_per_worker (int): number of CPUs allocated per worker. + num_gpus_per_worker (int): number of GPUs allocated per worker. + num_cpus_per_server (int): number of CPUs allocated per server. + num_gpus_per_server (int): number of GPUs allocated per server. + """ + + def __init__(self, + *, + training_operator_cls, + operator_config: Optional[Mapping[str, Any]] = None, + initialization_hook: Optional[Callable] = None, + world_size: int = 2, + num_worker: int = 1, + num_ps: int = 1, + backend: str = "nccl", + group_name: str = "default", + num_cpus_per_worker: int = 1, + num_gpus_per_worker: int = 1, + num_cpus_per_server: int = 1, + num_gpus_per_server: int = 1, + **kwargs): + + assert world_size == num_ps + num_worker, \ + "'world_size' should be equal to 'num_ps' plus 'num_worker'" + + self.assignments = None + self.num_ps = num_ps + self.num_worker = num_worker + self.num_cpus_per_server = num_cpus_per_server + self.num_gpus_per_server = num_gpus_per_server + + super(ParameterServerStrategy, self). \ + __init__(training_operator_cls=training_operator_cls, + operator_config=operator_config, + initialization_hook=initialization_hook, + world_size=world_size, + backend=backend, + group_name=group_name, + num_cpus_per_worker=num_cpus_per_worker, + num_gpus_per_worker=num_gpus_per_worker, + **kwargs) + + # PS strategy needs some other prep-up. 
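+        # _init_strategy syncs the initial weights from worker 0 to all other
+        # workers and shards them across the parameter servers.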
+ self._init_strategy() + + if operator_config and operator_config.get("batch_size"): + self._global_batch_size = operator_config.get("batch_size") + + def _init_strategy(self): + """Do initialization for the distributed strategy.""" + # All sync with worker 0 + init_weights_id = self.worker_group.get_named_parameters(cpu=True) + + self._round_robin_sharding() + + # set assignments to every worker + self.worker_group.set_assignments(self.assignments) + + # all workers get synced + for i, worker in enumerate(self.worker_group.actors): + if i != 0: + ray.get([worker.set_parameters.remote(init_weights_id)]) + + # now spawn parameter server actors + shard_ids = self.worker_group.split_parameters(self.assignments) + + # TODO(HUI): use scatter to send parameters + for server_idx, server in enumerate(self.server_group.actors): + this_shard_ref = self.worker_group.actors[0].index_shard.remote( + shard_ids, server_idx) + ray.get([server.set_params.remote(this_shard_ref)]) + + def _start_workers(self): + """Start worker group and server group.""" + # so here we get two set of params that will be passed around: + # (1) Those for setting up training logic in training_operator, + # including: batch size, user defined operator_config. + operator_config = self._operator_config.copy() + params = dict( + training_operator_cls=self.training_operator_cls, + operator_config=operator_config) + # (2) params for setting up collective group + # and the strategy-related things; + + # For now, we do not have many of them though. + dist_params_worker = dict( + strategy="ps", + is_server=False, + backend=self.backend, + group_name=self.group_name, + num_ps=self.num_ps, + num_worker=self.num_worker, + ) + + dist_params_server = dict( + strategy="ps", + is_server=True, + backend=self.backend, + group_name=self.group_name, + num_ps=self.num_ps, + num_worker=self.num_worker, + ) + + # (3) other arguments that used to init the DataParallelGrup + worker_group_init_args = dict( + actor_params=params, + dist_params=dist_params_worker, + num_cpus_per_actor=self.num_cpus_per_worker, + num_gpus_per_actor=self.num_gpus_per_worker, + initialization_hook=self.initialization_hook, + ) + + server_group_init_args = dict( + actor_params=params, + dist_params=dist_params_server, + num_cpus_per_actor=self.num_cpus_per_server, + num_gpus_per_actor=self.num_gpus_per_server, + initialization_hook=self.initialization_hook, + ) + + # Should we make two groups for worker and server? + self.worker_group = DataParallelGroup(**worker_group_init_args) + self.server_group = DataParallelGroup(**server_group_init_args) + + # Once the group is created, we start it. + self.worker_group._start_actors(self.num_worker) + # server at the last num_ps processes. + self.server_group._start_actors(self.num_ps) + + # worker_rets = self.worker_group.test_connection() + # server_rets = self.server_group.test_connection() + # ray.get(worker_rets + server_rets) + ray.get(self.worker_group.setup_operator()) + ray.get(self.server_group.setup_operator()) + + self.server_group.clean_redundancy() + + def shutdown(self, force: bool = False): + self.worker_group.shutdown(force=force) + self.server_group.shutdown(force=force) + + def get_states(self): + # worker0 pull latest params and return states. + for server_idx, server in enumerate(self.server_group.actors): + # every server sends its shard to the worker0 + server.send_params.remote(0) + # the worker0 receives shards from ps. 
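+        # send_params and recv_params are matching collective calls, so both
+        # sides are launched before blocking on the result.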
+ ret = self.worker_group.actors[0].recv_params.remote() + ray.get([ret]) + + return self.worker_group.get_states() + + def save_states(self, checkpoint: str): + # worker0 pull latest params. + for server_idx, server in enumerate(self.server_group.actors): + server.send_params.remote(0) + ret = self.worker_group.actors[0].recv_params.remote() + ray.get([ret]) + # Then, worker0 save parameters. + self.worker_group.save_states(checkpoint) + + def load_states(self, states=None, checkpoint: Optional[str] = None): + self.server_group.load_states(states=states, checkpoint=checkpoint) + + def _round_robin_sharding(self): + """Generate the assignment of variable to servers.""" + parameter_distribution = ray.get( + self.worker_group.actors[0].params_distribution.remote()) + assignments = [0 for _ in parameter_distribution] + loads = [0 for _ in range(self.num_ps)] + for i, var_size in enumerate(parameter_distribution): + min_ps_index = loads.index(min(loads)) + loads[min_ps_index] += var_size + assignments[i] = min_ps_index + print("Load of each ps {}".format(loads)) + self.assignments = assignments + + def train(self, num_steps: Optional[int] = None) -> Dict: + """Run the training on parallel workers. + + Args: + num_steps (int): number of steps to train. If none, the + function will simply train for one epoch. + + Returns: + metrics (dict): metrics of training result. + """ + # TODO (Hao): add fault tolerance using `max_retries`. + steps = num_steps if num_steps \ + else self.worker_group.get_data_loader_len() + + # TODO(HUI): Record server rank instead of using num_ps. + # TODO(Hao): this call should be hidden inside Replica. + # train one epoch + self.worker_group.make_iterator() + for idx in range(steps): + metrics = self.train_batch() + print("Step: {}/{}".format(idx, steps)) + return metrics + + def validate(self, num_steps: Optional[int] = None) -> Dict: + """Evaluates the model on the validation data. + + Args: + num_steps (int): number of batches to evaluate. If None, the + function will simply validate across the entire validation + dataset. + """ + steps = num_steps if num_steps \ + else self.worker_group.get_data_loader_len(training=False) + self.worker_group.make_iterator(training=False) + + # Worker group pull latest params. + rets = [] + for worker_idx, worker in enumerate(self.worker_group.actors): + for server_idx, server in enumerate(self.server_group.actors): + # every server sends its shard to the worker + server.send_params.remote(worker_idx) + # the worker receives shards from ps, compute loss, gradients + # and sends these gradients to every server + ret = worker.recv_params.remote() + rets.append(ret) + ray.get(rets) + + metrics = [ + AverageMeterCollection() + for _ in range(len(self.worker_group.actors)) + ] + + # TODO(HUI): Construct a better tool to save validate results. 
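+        # Aggregate the per-batch metrics with AverageMeterCollection,
+        # weighting each batch by its number of samples.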
+ for idx in range(steps): + batch_metrics = self.worker_group.validate_batch() + for metric_idx, metric in enumerate(batch_metrics): + num_sample = metric.pop("num_sample") + metrics[metric_idx].update(metric, n=num_sample) + # Validate results should be the same in all workers + return metrics[0].summary() + + def train_batch(self) -> Dict: + loss_vals = [] + rets = [] + metrics = {} + + for worker_idx, worker in enumerate(self.worker_group.actors): + for server_idx, server in enumerate(self.server_group.actors): + # every server sends its shard to the worker + server.send_params.remote(worker_idx) + # the worker receives shards from ps, compute loss, gradients + # and sends these gradients to every server + loss_val = worker.compute.remote() + loss_vals.append(loss_val) + + for worker_idx, worker in enumerate(self.worker_group.actors): + for server in self.server_group.actors: + rets.append(server.update.remote(worker_idx)) + + loss_vals = ray.get(loss_vals) + ray.get(rets) + train_loss_list = [d["train_loss"] for d in loss_vals] + metrics["train_loss"] = np.mean(train_loss_list) + return metrics + + +class PS(object): + def __init__(self, training_operator_cls, + operator_config: Optional[Mapping[str, Any]]): + self.training_operator_cls = training_operator_cls + self.operator_config = operator_config + + self.grad_counts = None + self.params = dict() + + def setup_operator(self): + """Instantiate the training operator.""" + self.training_operator = self.training_operator_cls( + operator_config=self.operator_config) + + def setup_collective_group(self, + rank: int, + num_ps: int, + num_worker: int, + backend: str = "nccl", + group_name: str = "default"): + # rank has already plus num_worker. + self.rank = rank + self.num_ps = num_ps + self.num_worker = num_worker + self.group_name = group_name + self.group_size = num_ps + num_worker + self._init_grad_counts() + # the last num_ps processes are servers. + col.init_collective_group( + num_ps + num_worker, rank, backend=backend, group_name=group_name) + + def apply(self, fn: Callable): + """Apply a function in the replica process.""" + return fn() + + def test_connection(self): + for i in range(self.num_worker): + recv = util.zeros((1, ), cpu=False) + col.recv(recv, i, group_name=self.group_name) + assert recv == 1 + for i in range(self.num_worker): + send = util.ones((1, ), cpu=False) + col.send(send, i, group_name=self.group_name) + + def _init_grad_counts(self): + self.grad_counts = [0] * self.num_worker + + def _init_grad_buffer(self): + self.grad_buffer = { + k: self.training_operator.zeros_like(v, cpu=False) + for k, v in self.params.items() + } + + def get_params(self) -> dict: + return self.params + + def set_params(self, params): + # params should in GPU when calling this function. + for k, v in params.items(): + self.params[k] = self.training_operator.asarray(v) + + # param is a dict, if needed list, should convert in operator. + self.training_operator.reset_optimizer_for_params(self.params) + self._init_grad_buffer() + + def load_states(self, states=None, checkpoint: Optional[str] = None): + self.training_operator.load_states( + states=states, + checkpoint=checkpoint, + keys=tuple(self.params.keys())) + + # # Update the params in actor aspect. 
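+        # Refresh this server's cached shard with the freshly loaded
+        # parameters so that later send_params calls use the new values.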
+ latest_params = self.training_operator.get_named_parameters(cpu=False) + + assert self.params.keys() == latest_params.keys() + + for key in latest_params.keys(): + self.params[key] = latest_params[key] + + def apply_updates(self, grad_buffer): + self.training_operator.apply_updates(grad_buffer) + self.params = self.training_operator.get_named_parameters(cpu=False) + + def _inc_gradients(self, gradients): + for name, p in self.get_params().items(): + if gradients[name] is not None: + self.grad_buffer[name] += gradients[name] + + def send_params(self, dst_rank: int): + """ Send this param shard to the destination worker """ + for name, v in self.params.items(): + cv = self.training_operator.to_cupy(v) + col.send(cv, dst_rank, group_name=self.group_name) + + def update(self, src_rank: int): + """Receive gradients and update""" + keys = list(self.params.keys()) + grads = dict() + recv_list = [] + + for key in keys: + to_recv = self.params[key] + recv_list.append( + self.training_operator.zeros(to_recv.shape, cpu=False)) + + for i in range(len(keys)): + v = self.training_operator.to_cupy(recv_list[i]) + col.recv(v, src_rank, self.group_name) + + for i in range(len(keys)): + grads[keys[i]] = recv_list[i] + + self._inc_gradients(grads) + if not self.grad_counts[src_rank]: + self.grad_counts[src_rank] = 1 + else: + raise RuntimeError(f"This worker {src_rank} send gradients again.") + if sum(self.grad_counts) == self.num_worker: + self.apply_updates(self.grad_buffer) + + self._init_grad_buffer() + self._init_grad_counts() + return True + + def clean_redundancy(self): + self.training_operator.clean_redundancy() + + def shutdown(self): + # destroy the collective group resources on this process + col.destroy_collective_group(self.group_name) + if self.training_operator: + del self.training_operator + return 1 + + +class Worker(object): + def __init__(self, training_operator_cls, + operator_config: Optional[Mapping[str, Any]]): + self.training_operator_cls = training_operator_cls + self.operator_config = operator_config + + # collective-related information + self.group_size = None + self.rank = None + self.group_name = None + self.assignments = None + + def setup_operator(self): + # figure out the signature of training_operator_cls later. + self.training_operator = self.training_operator_cls( + operator_config=self.operator_config) + + def setup_collective_group(self, + rank: int, + num_ps: int, + num_worker: int, + backend: str = "nccl", + group_name: str = "default"): + self.rank = rank + self.num_ps = num_ps + self.num_worker = num_worker + self.group_name = group_name + self.group_size = num_ps + num_worker + self.name_list = [[] for i in range(num_ps)] + + # the last num_ps processes are servers. 
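+        # Workers take ranks [0, num_worker); servers take ranks
+        # [num_worker, num_worker + num_ps).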
+ col.init_collective_group( + num_ps + num_worker, rank, backend=backend, group_name=group_name) + + def apply(self, fn: Callable): + """Apply a function in the replica process.""" + return fn() + + def test_connection(self): + for i in range(self.num_ps): + send = util.ones((1, ), cpu=False) + col.send(send, self.num_worker + i, group_name=self.group_name) + for i in range(self.num_ps): + recv = util.zeros((1, ), cpu=False) + col.recv(recv, self.num_worker + i, group_name=self.group_name) + assert recv == 1 + return + + def params_distribution(self) -> List: + distribution = [] + weights = self.get_named_parameters(cpu=True) + for k, v in weights.items(): + distribution.append(self.training_operator.numel(v)) + return distribution + + def make_iterator(self, training: bool = True): + """Convert loader to be an iterator at the start of an epoch.""" + # TODO(Hao): need to check whether reaching the boundary of iterator + # instead of making a new one every time. + if training: + self.training_iterator = iter( + self.training_operator._get_train_loader()) + else: + self.validation_iterator = iter( + self.training_operator._get_validation_loader()) + + def get_data_loader_len(self, training: bool = True) -> int: + """Return the number of batches in the data loader.""" + loader = self.training_operator._get_train_loader() if training \ + else self.training_operator._get_validation_loader() + if hasattr(loader, "__len__"): + return len(loader) + else: + raise RuntimeError( + "Data loader has no attribute `__len__`. " + "Please set `num_steps` in `train()` or `validate()`.") + + def derive_updates(self, batch: Sequence[Any]) -> Dict: + # TODO (Hao): handling data loader next. + return self.training_operator.derive_updates(batch) + + def compute_gradients(self): + """ + Update worker parameters that received from server. + Compute gradients and return named gradients. + """ + + try: + batch = next(self.training_iterator) + except StopIteration or NameError: + self.make_iterator() + batch = next(self.training_iterator) + + # different from original core ps. + # Here derive_updates return loss_val and graident in order. + loss_val, grads = self.training_operator.derive_updates(batch) + assert isinstance(grads, dict) + + return loss_val, grads + + def split_gradients(self, grad, assignments) -> List: + """Splitting gradients according to assignments.""" + # assuming messages are gradients or parameters + # this grad is ready to be called by apply_gradients in ParameterServer + num_shards = np.unique(np.array(assignments)).size + shards = [dict() for i in range(num_shards)] + for i, (k, v) in enumerate(grad.items()): + shards[assignments[i]][k] = v + return shards + + def split_parameters(self, assignments) -> List: + """Splitting parameters according to assignments.""" + params = self.get_named_parameters(cpu=False) + num_shards = np.unique(np.array(assignments)).size + shards = [dict() for i in range(num_shards)] + for i, (k, v) in enumerate(params.items()): + shards[assignments[i]][k] = v + return shards + + def index_shard(self, shards, index: int): + return shards[index] + + def recv_params(self): + weights = self.get_named_parameters(cpu=False) + params = dict() + + # 1. Create the receive lists to group collective calls + recv_list = [] + for i in range(self.num_ps): + recv_list.append([]) + param_shard_keys = self.name_list[i] + for key in param_shard_keys: + to_recv = weights[key] + recv_list[-1].append( + self.training_operator.ones(to_recv.shape, cpu=False)) + + # 2. 
Receive params from servers + for i in range(self.num_ps): + for j in range(len(self.name_list[i])): + v = self.training_operator.to_cupy(recv_list[i][j]) + col.recv(v, self.num_worker + i, group_name=self.group_name) + + # 3. Set params in workers. + for i in range(self.num_ps): + param_shard_keys = self.name_list[i] + for j in range(len(param_shard_keys)): + params[param_shard_keys[j]] = recv_list[i][j] + + self.set_parameters(params) + + def set_parameters(self, params): + self.training_operator.set_parameters(params) + + def get_parameters(self, cpu: bool) -> List: + return self.training_operator.get_parameters(cpu) + + def get_named_parameters(self, cpu: bool) -> Dict: + return self.training_operator.get_named_parameters(cpu) + + def get_gradients(self): + # training_operator call gradients or we save gradient in replica + # when derive_updates. + return self.training_operator.get_gradients() + + def get_states(self): + return self.training_operator.get_states() + + def save_states(self, checkpoint: str): + self.training_operator.save_states(checkpoint) + + def set_assignments(self, assignments): + self.assignments = assignments + keys = list(self.get_named_parameters(cpu=False).keys()) + for i, a in enumerate(self.assignments): + self.name_list[a].append(keys[i]) + + def compute(self): + """Returns the loss, and send gradients to servers""" + metrics = {} + + self.recv_params() + + loss_val, grad = self.compute_gradients() + metrics["train_loss"] = loss_val + + # Shard gradients and send to servers. + split_grad = self.split_gradients(grad, self.assignments) + for i in range(self.num_ps): + this_shard = self.index_shard(split_grad, i) + for _, v in this_shard.items(): + cv = self.training_operator.to_cupy(v) + col.send(cv, self.num_worker + i, group_name=self.group_name) + return metrics + + def validate_batch(self): + try: + batch = next(self.validation_iterator) + except StopIteration and TypeError: + self.make_iterator(training=False) + batch = next(self.validation_iterator) + batch_metric = self.training_operator.validate_batch(batch) + return batch_metric + + def shutdown(self): + # destroy the collective group resources on this process + col.destroy_collective_group(self.group_name) + if self.training_operator: + del self.training_operator + return 1 + + +class DataParallelGroup(BaseDataParallelGroup): + """Spawn a actor group for data-parallel training.""" + + def __init__(self, actor_params: Mapping[str, Any], + dist_params: Mapping[str, Any], num_cpus_per_actor: int, + num_gpus_per_actor: int, + initialization_hook: Optional[Callable]): + super(DataParallelGroup, self).__init__( + actor_params=actor_params, + dist_params=dist_params, + num_cpus_per_actor=num_cpus_per_actor, + num_gpus_per_actor=num_gpus_per_actor, + initialization_hook=initialization_hook) + self.is_server = self._dist_params["is_server"] + self.num_ps = self._dist_params["num_ps"] + self.num_worker = self._dist_params["num_worker"] + + self._distributed_actors = None + + def _setup_collective_group(self, + num_ps: int, + num_worker: int, + backend: int, + group_name: str = "default"): + if self._dist_params["strategy"] == "ps": + is_server = self.is_server + + rets = [ + actor.setup_collective_group.remote( + rank=i + is_server * num_worker, + num_worker=num_worker, + num_ps=num_ps, + backend=backend, + group_name=group_name) + for i, actor in enumerate(self._distributed_actors) + ] + else: # this can be extend for allreduce. 
+ raise RuntimeError("Unrecognized strategy.") + return rets + + def setup_operator(self): + setups = [ + actor.setup_operator.remote() + for i, actor in enumerate(self._distributed_actors) + ] + return setups + + def _start_actors(self, num_actors: int): + if self.is_server: + RemoteActor = ray.remote( + num_cpus=self._num_cpus_per_actor, + num_gpus=self._num_gpus_per_actor)(PS) + else: + RemoteActor = ray.remote( + num_cpus=self._num_cpus_per_actor, + num_gpus=self._num_gpus_per_actor)(Worker) + + self._distributed_actors = [ + RemoteActor.remote(**self._actor_params) for _ in range(num_actors) + ] + + # apply init_hook + if self._initialization_hook: + self.apply_all_replicas(self._initialization_hook) + + # setup the rank and group in each replica + ray.get( + self._setup_collective_group(self.num_ps, self.num_worker, + self.backend, self.group_name)) + + def test_connection(self): + rets = [ + actor.test_connection.remote() + for _, actor in enumerate(self.actors) + ] + return rets + + def set_assignments(self, assignments): + rets = [ + actor.set_assignments.remote(assignments) + for _, actor in enumerate(self.actors) + ] + return rets + + def apply_all_replicas(self, fn: Callable): + """Apply fn in all replica processes and wait until completion.""" + return ray.get(self._apply_all_replicas(fn)) + + def _apply_all_replicas(self, fn): + """Apply a function fn in all replica processes.""" + return [actor.apply.remote(fn) for actor in self.actors] + + def _make_iterator(self, training: bool): + return [actor.make_iterator.remote(training) for actor in self.actors] + + def make_iterator(self, training: bool = True): + ray.get(self._make_iterator(training)) + + def get_data_loader_len(self, training: bool = True): + """Return the number of batches in the data loader.""" + lens = ray.get([ + actor.get_data_loader_len.remote(training=training) + for actor in self.actors + ]) + + if len(set(lens)) != 1: + # TODO(Hao): is this correct after we add distributed data loader? + raise RuntimeError( + "All actors should have the same dataloader len.") + return lens[0] + + def validate_batch(self): + rets = [ + actor.validate_batch.remote() + for _, actor in enumerate(self.actors) + ] + stats = ray.get(rets) + return stats + + def shutdown(self, force: bool = False): + rets = [actor.shutdown.remote() for _, actor in enumerate(self.actors)] + stats = ray.get(rets) + return stats + + def reset(self): + pass + + @property + def actors(self): + return self._distributed_actors + + def get_states(self): + ret = self.actors[0].get_states.remote() + return ray.get([ret])[0] + + def save_states(self, checkpoint: str): + rets = [self.actors[0].save_states.remote(checkpoint)] + ray.get(rets) + + def load_states(self, states=None, checkpoint: Optional[str] = None): + rets = [ + actor.load_states.remote(states=states, checkpoint=checkpoint) + for _, actor in enumerate(self.actors) + ] + ray.get(rets) + + def set_parameters(self, params): + rets = [ + actor.set_parameters.remote(params) + for _, actor in enumerate(self.actors) + ] + ray.get(rets) + + def get_parameters(self, cpu: bool = False): + ret = self.actors[0].get_parameters.remote(cpu) + return ray.get([ret])[0] + + def get_named_parameters(self, cpu: bool = False): + ret = self.actors[0].get_named_parameters.remote(cpu) + return ray.get([ret])[0] + + def split_parameters(self, assignments): + ret = self.actors[0].split_parameters.remote(assignments) + return ray.get([ret])[0] + + def clean_redundancy(self): + """Clean dataloader. 
Only for servers""" + rets = [ + actor.clean_redundancy.remote() + for _, actor in enumerate(self.actors) + ] + ray.get(rets) diff --git a/examples/jax/Bert-base-1node.png b/examples/jax/Bert-base-1node.png new file mode 100644 index 0000000..be7b717 Binary files /dev/null and b/examples/jax/Bert-base-1node.png differ diff --git a/examples/jax/benchmark.md b/examples/jax/benchmark.md new file mode 100644 index 0000000..6a73724 --- /dev/null +++ b/examples/jax/benchmark.md @@ -0,0 +1,67 @@ +## TestBed + +| Name | GPU | Bandwidth | +|------|---------|-----------| +| lm1 | 2080 Ti | 1 Gbps | +| lm6 | 2080 Ti | 1 Gbps | +| room | V100 | 100 Gbps | +| AWS | V100 | 100 Gbps | + + +| | CUDA | cuDNN | jaxlib | cupy | +|---------|------|-------|--------|-------| +| Version | 10.1 | 7.6.5 | 0.1.56 | 8.3.0 | + + +- ResNet18 batch size (per replica): 128 +- ResNet101 batch size (per replica): 128 +- Bert-base batch size (per relica): 8, sentence length 128. + +## Baseline + +| | ResNet18 (images/s) | ResNet101 (images/s) | Bert-base (words/s) | +|------|---------------------|----------------------|---------------------| +| lm1 | 74.31 | 19.70 | 1798.73 | +| lm6 | 72.07 | 21.06 | 1906.47 | +| room | 90.94 | 22.97 | 1710.08 | +| AWS | 90.20 | 23.94 | 2269.57 | + +## Results + +### Setting: 1x node, up to 8x GPUs, on lm1 and lm6 + +- Green: AllReduce +- Orange: PS +- Blue: Ideal + +#### Bert-base +![img.png](Bert-base-1node.png) + +#### ResNet-18 +![img.png](resnet18-1node.png) + +#### ResNet-101 +![img_1.png](resnet101-1node.png) + +### Setting: up to 3x nodes, up to 12x GPUs + +- Blue: AllReduce +- Red: ideal + +Note: we observe bandwidth bottleneck. + +#### 1 Gbps bandwidth, ResNet-18 +![img.png](resnet18-distributed-1gbps.png) + +#### 1 Gbps bandwidth, ResNet-101 +![img.png](resnet101-distributed-1gbps.png) + +#### 1 Gbps bandwidth, Bert-base +![img.png](bert-base-distributed-1gbps.png) + +#### 100 Gbps bandwidth +| | Throughput (12 nodes) | Scalability (x) | Throughput (16 nodes) | Scalability (x) | +|-----------|-----------------------|-----------------|-----------------------|-----------------| +| ResNet18 | 915.43 | 10.15x | 1191.73 | 13.21x | +| ResNet101 | 248.76 | 10.43x | 326.49 | 13.69x | +| Bert-base | 20479.66 | 9.08x | 26798.02 | 11.88x | diff --git a/examples/jax/bert-base-distributed-1gbps.png b/examples/jax/bert-base-distributed-1gbps.png new file mode 100644 index 0000000..638a473 Binary files /dev/null and b/examples/jax/bert-base-distributed-1gbps.png differ diff --git a/examples/jax/jax_util/datasets.py b/examples/jax/jax_util/datasets.py index 72f136b..254d56c 100644 --- a/examples/jax/jax_util/datasets.py +++ b/examples/jax/jax_util/datasets.py @@ -128,6 +128,7 @@ def load_CIFAR_batch(root, mode="train"): datadict = pickle.load(f, encoding="bytes") X = datadict[b"data"] Y = datadict[b"fine_labels"] + if mode == "train": X = X.reshape(50000, 3, 32, 32) else: diff --git a/examples/jax/mnist_jax_example.py b/examples/jax/mnist_jax_example.py index 5a4c061..324a99a 100644 --- a/examples/jax/mnist_jax_example.py +++ b/examples/jax/mnist_jax_example.py @@ -6,6 +6,7 @@ import ray from distml.operator.jax_operator import JAXTrainingOperator from distml.strategy.allreduce_strategy import AllReduceStrategy +from distml.strategy.ps_strategy import ParameterServerStrategy from ray.util.sgd.utils import override @@ -15,12 +16,14 @@ from jax_util.resnet import ResNet18, ResNet50, ResNet101 from jax_util.datasets import mnist, Dataloader +import numpy as np def initialization_hook(): # Need 
     # Need this for avoiding a connection restart issue on AWS.
     os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
     os.environ["NCCL_LL_THRESHOLD"] = "0"
     os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
+    # set the below if needed
     # print("NCCL DEBUG SET")
     # os.environ["NCCL_DEBUG"] = "INFO"
@@ -53,9 +56,14 @@ def setup(self, config):
         with FileLock(".ray.lock"):
            train_images, train_labels, test_images, test_labels = mnist()
+        if config.get("test_mode", False):
+            # Subsample 1000 examples by index so images and labels stay aligned.
+            train_idx = np.random.choice(train_images.shape[0], 1000)
+            test_idx = np.random.choice(test_images.shape[0], 1000)
+            train_images, train_labels = train_images[train_idx], train_labels[train_idx]
+            test_images, test_labels = test_images[test_idx], test_labels[test_idx]
+
         train_images = train_images.reshape(train_images.shape[0], 1, 28, 28).transpose(2, 3, 1, 0)
-
         test_images = test_images.reshape(test_images.shape[0], 1, 28, 28).transpose(2, 3, 1, 0)
@@ -76,6 +84,36 @@ def criterion(logits, targets):
             train_loader=train_loader, validation_loader=test_loader)
 
 
+def make_ar_strategy(args):
+    strategy = AllReduceStrategy(
+        training_operator_cls=MnistTrainingOperator,
+        world_size=args.num_worker,
+        operator_config={
+            "lr": 0.01,
+            "batch_size": 128,
+            "num_worker": args.num_worker,
+            "num_classes": 10,
+            "model_name": args.model_name
+        },
+        initialization_hook=initialization_hook)
+    return strategy
+
+
+def make_ps_strategy(args):
+    strategy = ParameterServerStrategy(
+        training_operator_cls=MnistTrainingOperator,
+        world_size=args.num_worker,
+        num_worker=args.num_worker - args.num_ps,
+        num_ps=args.num_ps,
+        operator_config={
+            "lr": 0.01,
+            "batch_size": 128,
+            "num_classes": 10,
+            "model_name": args.model_name
+        })
+    return strategy
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -84,11 +122,16 @@ def criterion(logits, targets):
         type=str,
         help="the address to use for connecting to the Ray cluster")
     parser.add_argument(
-        "--num-workers",
+        "--num-worker",
         "-n",
         type=int,
         default=2,
         help="Sets number of workers for training.")
+    parser.add_argument(
+        "--num-ps",
+        type=int,
+        default=1,
+        help="Sets number of servers for training. Only for ps_strategy.")
     parser.add_argument(
         "--num-epochs",
         type=int,
@@ -104,6 +147,8 @@ def criterion(logits, targets):
         type=str,
         default="resnet18",
         help="model, Optional: resnet18, resnet50, resnet101.")
+    parser.add_argument(
+        "--strategy", type=str, default="ar", help="strategy, Optional: ar, ps.")
 
     args, _ = parser.parse_known_args()
 
@@ -111,21 +156,17 @@ def criterion(logits, targets):
         ray.init(args.address)
     else:
         ray.init(
-            num_gpus=args.num_workers,
-            num_cpus=args.num_workers * 2,
+            num_gpus=args.num_worker,
+            num_cpus=args.num_worker * 2,
             log_to_driver=True)
 
-    strategy = AllReduceStrategy(
-        training_operator_cls=MnistTrainingOperator,
-        world_size=args.num_workers,
-        operator_config={
-            "lr": 0.01,
-            "batch_size": 128,
-            "num_workers": args.num_workers,
-            "num_classes": 10,
-            "model_name": args.model_name
-        },
-        initialization_hook=initialization_hook)
+    if args.strategy == "ar":
+        strategy = make_ar_strategy(args)
+    elif args.strategy == "ps":
+        strategy = make_ps_strategy(args)
+    else:
+        raise RuntimeError("Unrecognized trainer type. Expected 'ar' or 'ps'. "
Except 'ar' or 'ps'" + "Got {}".format(args.strategy)) for i in range(args.num_epochs): strategy.train() diff --git a/examples/jax/resnet101-1node.png b/examples/jax/resnet101-1node.png new file mode 100644 index 0000000..9c263d7 Binary files /dev/null and b/examples/jax/resnet101-1node.png differ diff --git a/examples/jax/resnet101-distributed-1gbps.png b/examples/jax/resnet101-distributed-1gbps.png new file mode 100644 index 0000000..ffc76e6 Binary files /dev/null and b/examples/jax/resnet101-distributed-1gbps.png differ diff --git a/examples/jax/resnet18-1node.png b/examples/jax/resnet18-1node.png new file mode 100644 index 0000000..78d8e9c Binary files /dev/null and b/examples/jax/resnet18-1node.png differ diff --git a/examples/jax/resnet18-distributed-1gbps.png b/examples/jax/resnet18-distributed-1gbps.png new file mode 100644 index 0000000..4e22025 Binary files /dev/null and b/examples/jax/resnet18-distributed-1gbps.png differ
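For reference, a minimal driver for the parameter-server path added above; it is a sketch assembled only from pieces that appear in this diff (the ParameterServerStrategy constructor arguments mirror make_ps_strategy, and the loop mirrors the example's epoch loop). The direct import of MnistTrainingOperator from the example module and the concrete worker/server/GPU counts are illustrative assumptions, not part of the change:

import ray

from distml.strategy.ps_strategy import ParameterServerStrategy

# Assumes the example module is importable from the working directory.
from mnist_jax_example import MnistTrainingOperator

# 2 workers + 1 parameter server, so 3 actors in total (the world size).
ray.init(num_gpus=3, num_cpus=6)

strategy = ParameterServerStrategy(
    training_operator_cls=MnistTrainingOperator,
    world_size=3,
    num_worker=2,
    num_ps=1,
    operator_config={
        "lr": 0.01,
        "batch_size": 128,
        "num_classes": 10,
        "model_name": "resnet18"
    })

for _ in range(5):
    strategy.train()

The same configuration corresponds to launching the script itself with --strategy ps --num-worker 3 --num-ps 1.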