From 79b410d280e3365e3a958b144274bb899e6bf620 Mon Sep 17 00:00:00 2001 From: Ezra-H Date: Thu, 29 Apr 2021 01:14:25 +0800 Subject: [PATCH 01/12] add jax operator --- distml/operator/__init__.py | 3 --- format.sh | 8 ++++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/distml/operator/__init__.py b/distml/operator/__init__.py index 5fe7080..e69de29 100644 --- a/distml/operator/__init__.py +++ b/distml/operator/__init__.py @@ -1,3 +0,0 @@ -from distml.operator.torch_operator import TorchTrainingOperator - -__all__ = ["TorchTrainingOperator"] diff --git a/format.sh b/format.sh index c03b7b9..bad8cb9 100755 --- a/format.sh +++ b/format.sh @@ -46,7 +46,7 @@ builtin cd "$ROOT" || exit 1 # Add the upstream remote if it doesn't exist if ! git remote -v | grep -q upstream; then - git remote add 'upstream' 'https://github.com/ray-project/distml.git' + git remote add 'upstream' 'https://github.com/ray-project/distml.git' fi FLAKE8_VERSION=$(flake8 --version | awk '{print $1}') @@ -106,14 +106,14 @@ format_changed() { yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" if which flake8 >/dev/null; then git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ - flake8 --inline-quotes '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 + flake8 --inline-quotes '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 fi fi if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then if which flake8 >/dev/null; then git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \ - flake8 --inline-quotes '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 + flake8 --inline-quotes '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 fi fi } @@ -121,7 +121,7 @@ format_changed() { # Format all files, and print the diff to stdout for travis. format_all() { yapf --diff "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" distml - flake8 --inline-quotes '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 distml + flake8 --inline-quotes '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 distml } # This flag formats individual files.
--files *must* be the first command line From dac98abb101848125769079265e77d0ab90a787d Mon Sep 17 00:00:00 2001 From: Ezra-H Date: Thu, 29 Apr 2021 01:14:50 +0800 Subject: [PATCH 02/12] add jax operator --- distml/operator/jax_operator.py | 331 ++++++++++++++++++++++++++++++ examples/jax/.ray.lock | 0 examples/jax/default_train.csv | 2 + examples/jax/jax_util/__init__.py | 2 + examples/jax/jax_util/datasets.py | 281 +++++++++++++++++++++++++ examples/jax/jax_util/resnet.py | 221 ++++++++++++++++++++ examples/jax/mnist_jax_example.py | 120 +++++++++++ 7 files changed, 957 insertions(+) create mode 100644 distml/operator/jax_operator.py create mode 100755 examples/jax/.ray.lock create mode 100644 examples/jax/default_train.csv create mode 100644 examples/jax/jax_util/__init__.py create mode 100644 examples/jax/jax_util/datasets.py create mode 100644 examples/jax/jax_util/resnet.py create mode 100644 examples/jax/mnist_jax_example.py diff --git a/distml/operator/jax_operator.py b/distml/operator/jax_operator.py new file mode 100644 index 0000000..3ac8cd4 --- /dev/null +++ b/distml/operator/jax_operator.py @@ -0,0 +1,331 @@ +import numpy as np +import cupy as cp +import jax +from jax import grad, value_and_grad +import jax.numpy as jnp +from jax.lib import xla_client +from jax.dlpack import from_dlpack, to_dlpack +from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, tree_map, build_tree +from jax._src.util import unzip2 +from jax.experimental.optimizers import OptimizerState + +from .base_operator import TrainingOperator +from distml.util import ThroughputCollection, func_timer +from ray.util.sgd.utils import TimerCollection, AverageMeterCollection + +import time + + +class JAXTrainingOperator(TrainingOperator): + + def __init__(self, operator_config): + super(JAXTrainingOperator, self).__init__(operator_config) + # Should be set by users in the `register` function. + # model methods + self.opt_state = None + self.init_fun = None + self.predict_fun = None + # optimizer methods + self.opt_init = None + self.opt_update = None + self.get_params = None + + self.criterion = None + + # Data loaders for training and validation, registered by users. + self._train_loader = None + self._validation_loader = None + + self.setup(operator_config) + + if hasattr(operator_config, "jit_mode"): + assert not operator_config["jit_mode"], "Not support jit in jax operator." + + self.train_step_num = 0 + + def setup(self, *args, **kwargs): + """Function that needs to be override by users. + + Example: + # some code is the same for all users, maybe we can put it in register. + rng_key = random.PRNGKey(0) + input_shape = (28, 28, 1, batch_size) + lr=0.01 + init_fun, predict_fun = ResNet18(num_classes) + _, init_params = init_fun(rng_key, input_shape) + + opt_init, opt_update, get_params = optimizers.adam(lr) + opt_state = opt_init(init_params) + + self.register(model=(opt_state, get_params, predict_fun), optimizer=opt_update, criterion=lambda logits, targets:-jnp.sum(logits * targets)) + + """ + + pass + + def register(self, + *, + model, + optimizer, + criterion, + lr_schedulers=None, + jit_mode=False): + """Register a few critical information about the model to operator.""" + self.criterion = criterion + if lr_schedulers: + self.lr_schedulers = lr_schedulers + print("WARNING: jax not support learning rate scheduler." + "This will not work.") + + self._register_model(model) + self._register_optimizer(optimizer) + + def _register_model(self, model): + """register model components. 
+ + This function shall be instantiated in framework-specific operator + implementations. + """ + self.opt_state = model[0] + self.init_fun = model[1] + self.predict_fun = model[2] + + def _register_optimizer(self, optimizer): + self.opt_init = optimizer[0] + self.opt_update = optimizer[1] + self.get_params = optimizer[2] + + def register_data(self, *, train_loader=None, validation_loader=None): + self._train_loader = train_loader + self._validation_loader = validation_loader + + def _get_train_loader(self): + return self._train_loader + + def _get_validation_loader(self): + return self._validation_loader + + def loss_func(self, params, batch): + """A function to calculate predictions and loss value. + + This function is going to be decorated by `grad` in Jax to calculate gradients. + + Args: + batch (tuple): a data batch containing a feature/target pair. + """ + inputs, targets = batch + logits = self.predict_fun(params, inputs) + return self.criterion(logits, targets) + + def derive_updates(self, batch): + """Compute the parameter updates on a given batch of data. + + The `derive_updates` function should be called in conjunction with + the next `apply_updates` function in order to finish one iteration + of training. + + Args: + batch (tuple): a data batch containing a feature/target pair. + """ + loss_val, gradient = self._calculate_gradient(self.opt_state, batch) + + gradient_dict, tree = tree_flatten(gradient) + assert tree == self.opt_state[1] + + if hasattr(self, "preset_keys"): + gradient_dict = {k:g for k, g in zip(self.preset_keys, gradient_dict)} + else: + gradient_dict = {f"{idx}":g for idx, g in enumerate(gradient_dict)} + return loss_val.item(), gradient_dict + + def apply_updates(self, updates): + """Set and apply the updates using the opt_update in Jax. + + Args: + updates (dict): a dictionary of parameter name and updates. + """ + keys, updates = unzip2(sorted(updates.items(), key=lambda d: int(d[0]))) + updates = tree_unflatten(self.opt_state[1], updates) + self.opt_state = self.opt_update(self.train_step_num, updates, self.opt_state) + self.train_step_num += 1 + + def to_cupy(self, tensor): + """Convert a torch GPU tensor to cupy tensor.""" + if isinstance(tensor, list): + return list(map(self.to_cupy, tensor)) + ctensor = cp.fromDlpack(self.get_jax_dlpack(tensor)) + return ctensor + + def to_operator_tensor(self, tensor): + """Convert a cupy tensor to jax tensor. + + There comes a bug. The layouts of tensor explained by cupy and jax are different. But dlpack doesn't convert the layout. + """ + if isinstance(tensor, list): + return list(map(self.to_operator_tensor, tensor)) + return from_dlpack(tensor.toDlpack()) + + # TODO(HUI): support return logits by adding use_aux in `value_and_grad` + def _calculate_gradient(self, opt_state, batch): + params = self.get_params(opt_state) + loss_val, gradient = value_and_grad(self.loss_func)(params, batch) + return loss_val, gradient + + def get_jax_dlpack(self, tensor): + """Get the dlpack of a jax tensor. + + Jax api might cause different pointer address after the conversion. + We use the xla api to avoid this bug in jax api. + """ + return xla_client._xla.buffer_to_dlpack_managed_tensor(tensor.device_buffer, + take_ownership=False) + + def validate_batch(self, batch): + """Perform validation over a data batch. + + Args: + batch (tuple): a data batch containing a feature/target pair. 
+ """ + params = self.get_params(self.opt_state) + criterion = self.criterion + predict_fun = self.predict_fun + + # unpack features into list to support multiple inputs model + features, targets = batch + + outputs = predict_fun(params, features) + loss = criterion(outputs, targets) + prediction_class = jnp.argmax(outputs, axis=1) + targets_class = jnp.argmax(targets, axis=1) + + acc = jnp.mean(prediction_class == targets_class) + samples_num = targets.shape[0] + + return { + "val_loss": loss.item(), + "val_accuracy": acc.item(), + "samples_num": samples_num + } + + def get_parameters(self, cpu): + """get the flatten parameters.""" + params = self.get_params(self.opt_state) + flatten_params, tree = tree_flatten(params) + if not hasattr(self, "tree"): + self.tree = tree + + if cpu: + flatten_params = list(map(np.asarray, flatten_params)) + return flatten_params + + def get_named_parameters(self, cpu): + """Get the named parameters. + + In jax, we need to construct a dict to contain the parameters. + """ + params = self.get_parameters(cpu) + if hasattr(self, "preset_keys"): + dict_params = {name:p for name, p in zip(self.preset_keys, params)} + else: + dict_params = {f"{idx}":p for idx, p in enumerate(params)} + return dict_params + + # TODO(HUI): used in load states or load parameters + def set_parameters(self, new_params): + """Use new parameters to replace model parameters. + + In jax, we need to construct a dict to contain the parameters. + + Args: + new_params (dict): New parameters to updates the current model. + """ + assert isinstance(new_params, dict) + + keys, new_params = unzip2(sorted(new_params.items(), key=lambda d: int(d[0]))) + self.preset_keys = keys + + if not hasattr(self, "tree"): + self.tree = tree_structure(self.get_params(self.opt_state)) + + states_flat, tree, subtrees = self.opt_state + + states = map(tree_unflatten, subtrees, states_flat) + + def update(param, state): + new_state = param, *state[1:] + return new_state + + new_states = map(update, new_params, states) + + new_states_flat, new_subtrees = unzip2(map(tree_flatten, new_states)) + + if not new_subtrees: + raise RuntimeError("subtrees of new params is empty.") + for idx, (subtree, new_subtree) in enumerate(zip(subtrees, new_subtrees)): + if new_subtree != subtree: + msg = ("input structur did not match the save params struture. " + "input {} and output {}.") + raise TypeError(msg.format(subtree, new_subtree)) + + self.opt_state = OptimizerState(new_states_flat, tree, new_subtrees) + + def reset_optimizer_for_params(self, params): + keys, params = unzip2(sorted(params.items(), key=lambda d: int(d[0]))) + self.tree = tree_structure(params) + self.opt_state = self.opt_init(params) + + def ones(self, shape, cpu=True): + if cpu: + return np.ones(shape) + else: + return jnp.ones(shape) + + def zeros(self, shape, cpu=True): + if cpu: + return np.zeros(shape) + else: + return jnp.zeros(shape) + + def ones_like(self, x, cpu=True): + if cpu: + return np.ones_like(x) + else: + return jnp.ones_like(x) + + def zeros_like(self, x, cpu=True): + if cpu: + return np.zeros_like(x) + else: + return jnp.zeros_like(x) + + def numel(self, v): + return np.size(v) + + def asarray(self, v): + return jnp.asarray(v) + + def clean_redundancy(self): + del self._train_loader + del self._validation_loader + + # TODO(HUI): use pickle to serialize parameters or states and save it. 
+ def save_parameters(self, checkpoint): + raise NotImplementedError( + "save_parameters is not supported in jax operator.") + + def load_parameters(self, checkpoint): + raise NotImplementedError( + "load_parameters is not supported in jax operator.") + + def save_states(self, states): + raise NotImplementedError( + "save_states is not supported in jax operator.") + + def get_states(self, states): + raise NotImplementedError( + "get_states is not supported in jax operator.") + + def load_states(self, checkpoint): + raise NotImplementedError( + "load_states is not supported in jax operator.") + diff --git a/examples/jax/.ray.lock b/examples/jax/.ray.lock new file mode 100755 index 0000000..e69de29 diff --git a/examples/jax/default_train.csv b/examples/jax/default_train.csv new file mode 100644 index 0000000..0d23b47 --- /dev/null +++ b/examples/jax/default_train.csv @@ -0,0 +1,2 @@ +count_train,mean_train_s,last_train_s,total_train_s,pass_data_train,throughout_train_d +50,2.456741285324097,2.3998360633850098,164.64879870414734,6400,38.87061460739823 diff --git a/examples/jax/jax_util/__init__.py b/examples/jax/jax_util/__init__.py new file mode 100644 index 0000000..955ceb9 --- /dev/null +++ b/examples/jax/jax_util/__init__.py @@ -0,0 +1,2 @@ +from .datasets import mnist, Dataloader +from .resnet import ResNet18, ResNet50, ResNet101 \ No newline at end of file diff --git a/examples/jax/jax_util/datasets.py b/examples/jax/jax_util/datasets.py new file mode 100644 index 0000000..6bc9385 --- /dev/null +++ b/examples/jax/jax_util/datasets.py @@ -0,0 +1,281 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +"""Datasets used in examples.""" + + +import array +import gzip +import os +from os import path +import struct +import urllib.request +from jax.api import F +import jax.numpy as jnp +from jax import jit + +import numpy as np +import numpy.random as npr +import pickle +from functools import partial + +_DATA = "/tmp/jax_example_data/" + + +def _download(url, filename, dataset_name="mnist"): + """Download a url to a file in the JAX data temp directory.""" + root = os.path.join(_DATA,dataset_name) + if not path.exists(root): + os.makedirs(root) + out_file = path.join(root, filename) + if not path.isfile(out_file): + urllib.request.urlretrieve(url, out_file) + print("downloaded {} to {}".format(url, root)) + + +def _partial_flatten(x): + """Flatten all but the first dimension of an ndarray.""" + return np.reshape(x, (x.shape[0], -1)) + + +def _one_hot(x, k, dtype=np.float32): + """Create a one-hot encoding of x of size k.""" + return np.asarray(x[:, None] == np.arange(k), dtype) + +# @partial(jit, static_argnums=1) +def _one_hot_jit(x, k, dtype=np.float32): + """Create a one-hot encoding of x of size k.""" + return jnp.asarray(x[:, None] == jnp.arange(0, k), dtype) + +def mnist_raw(): + """Download and parse the raw MNIST dataset.""" + # CVDF mirror of http://yann.lecun.com/exdb/mnist/ + base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" + + def parse_labels(filename): + with gzip.open(filename, "rb") as fh: + _ = struct.unpack(">II", fh.read(8)) + return np.array(array.array("B", fh.read()), dtype=np.uint8) + + def parse_images(filename): + with gzip.open(filename, "rb") as fh: + _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) + return np.array(array.array("B", fh.read()), + dtype=np.uint8).reshape(num_data, rows, cols) + + for filename in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", + "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]: + _download(base_url + filename, filename) + + train_images = parse_images(path.join(_DATA, "mnist", "train-images-idx3-ubyte.gz")) + train_labels = parse_labels(path.join(_DATA, "mnist", "train-labels-idx1-ubyte.gz")) + test_images = parse_images(path.join(_DATA, "mnist", "t10k-images-idx3-ubyte.gz")) + test_labels = parse_labels(path.join(_DATA, "mnist", "t10k-labels-idx1-ubyte.gz")) + + return train_images, train_labels, test_images, test_labels + + +def mnist(permute_train=False): + """Download, parse and process MNIST data to unit scale and one-hot labels.""" + train_images, train_labels, test_images, test_labels = mnist_raw() + + train_images = _partial_flatten(train_images) / np.float32(255.) + test_images = _partial_flatten(test_images) / np.float32(255.) 
+ train_labels = _one_hot(train_labels, 10) + test_labels = _one_hot(test_labels, 10) + + if permute_train: + perm = np.random.RandomState(0).permutation(train_images.shape[0]) + train_images = train_images[perm] + train_labels = train_labels[perm] + + return train_images, train_labels, test_images, test_labels + +def cifa100_raw(): + """Download and parse the raw MNIST dataset.""" + base_url = "http://www.cs.toronto.edu/~kriz/" + + def load_CIFAR_batch(root, mode="train"): + """ load single batch of cifar """ + if mode == "train": + filename = path.join(root, "train") + elif mode == "test": + filename = path.join(root, "test") + else: + raise RuntimeError("Error: unrecognized mode", + " Got {}".format(mode)) + + with open(filename, 'rb')as f: + datadict = pickle.load(f,encoding='bytes') + X = datadict[b'data'] + Y = datadict[b'fine_labels'] + if mode == "train": + X = X.reshape(50000, 3, 32, 32) + else: + X = X.reshape(10000, 3, 32, 32) + return np.array(X), np.array(Y) + + for filename in ["cifar-100-python.tar.gz"]: + _download(base_url + filename, filename, dataset_name="cifa100") + + root = path.join(_DATA, "cifa100") + + if not os.path.exists(path.join(root, "cifar-100-python.tar.gz")): + os.system("tar xvf {} -C {}".format(path.join(root, "cifar-100-python.tar.gz"), + root)) + + train_images, train_labels = load_CIFAR_batch(path.join(root, "cifar-100-python"), + mode="train") + test_images, test_labels = load_CIFAR_batch(path.join(root, "cifar-100-python"), + mode="test") + + # b"fine_label_names" b"coarse_label_names" + # meta_path = path.join(root, "cifar-100-python", "meta") + return train_images, train_labels, test_images, test_labels + +def cifa100(permute_train=False): + """Download, parse and process cida100 data to unit scale and one-hot labels.""" + train_images, train_labels, test_images, test_labels = cifa100_raw() + + train_images = _partial_flatten(train_images) / np.float32(255.) + test_images = _partial_flatten(test_images) / np.float32(255.) 
+ train_labels = _one_hot(train_labels, 100) + test_labels = _one_hot(test_labels, 100) + + if permute_train: + perm = np.random.RandomState(0).permutation(train_images.shape[0]) + train_images = train_images[perm] + train_labels = train_labels[perm] + + return train_images, train_labels, test_images, test_labels + + +def cifa10_raw(): + """Download and parse the raw MNIST dataset.""" + base_url = "http://www.cs.toronto.edu/~kriz/" + + def load_CIFAR_batch(root, mode="train"): + """ load single batch of cifar """ + if mode == "train": + filenames = [] + for i in range(1,6): + filenames.append(path.join(root, f"data_batch_{i}")) + elif mode == "test": + filenames = [path.join(root, "test_batch")] + else: + raise RuntimeError("Error: unrecognized mode", + " Got {}".format(mode)) + print(filenames) + datas = [] + labels = [] + for filename in filenames: + with open(filename, 'rb')as f: + datadict = pickle.load(f,encoding='bytes') + X = datadict[b'data'] + Y = datadict[b'labels'] + X = X.reshape(10000, 3, 32, 32) + datas.append(X) + labels.append(Y) + return np.concatenate(datas, axis=0), np.concatenate(labels) + + for filename in ["cifar-10-python.tar.gz"]: + _download(base_url + filename, filename, dataset_name="cifa10") + + root = path.join(_DATA, "cifa10") + + if not os.path.exists(path.join(root, "cifar-10-batches-py")): + os.system("tar xvf {} -C {}".format(path.join(root, "cifar-10-python.tar.gz"), + root)) + + train_images, train_labels = load_CIFAR_batch(path.join(root, "cifar-10-batches-py"), + mode="train") + test_images, test_labels = load_CIFAR_batch(path.join(root, "cifar-10-batches-py"), + mode="test") + print(test_images.shape) + + # b"fine_label_names" b"coarse_label_names" + # meta_path = path.join(root, "cifar-100-python", "meta") + return train_images, train_labels, test_images, test_labels + + +def cifa10(permute_train=False): + """Download, parse and process cida100 data to unit scale and one-hot labels.""" + train_images, train_labels, test_images, test_labels = cifa10_raw() + + train_images = _partial_flatten(train_images) / np.float32(255.) + test_images = _partial_flatten(test_images) / np.float32(255.) 
+ train_labels = _one_hot(train_labels, 10) + test_labels = _one_hot(test_labels, 10) + + if permute_train: + perm = np.random.RandomState(0).permutation(train_images.shape[0]) + train_images = train_images[perm] + train_labels = train_labels[perm] + + return train_images, train_labels, test_images, test_labels + + +class Dataloader: + def __init__(self, data, target, batch_size=128, shuffle=False): + ''' + data: shape(width, height, channel, num) + target: shape(num, num_classes) + ''' + self.data = data + self.target = target + self.batch_size = batch_size + num_data = self.target.shape[0] + num_complete_batches, leftover = divmod(num_data, batch_size) + self.num_batches = num_complete_batches + bool(leftover) + self.shuffle = shuffle + + def synth_batches(self): + num_imgs = self.target.shape[0] + rng = npr.RandomState(npr.randint(10)) + perm = rng.permutation(num_imgs) if self.shuffle else np.arange(num_imgs) + for i in range(self.num_batches): + batch_idx = perm[i * self.batch_size:(i + 1) * self.batch_size] + img_batch = self.data[:, :, :, batch_idx] + label_batch = self.target[batch_idx] + yield img_batch, label_batch + + def __iter__(self): + return self.synth_batches() + + def __len__(self): + return self.num_batches + + +if __name__ == "__main__": + train_images, train_labels, test_images, test_labels = cifa10() + + print(type(train_images), type(train_labels)) + print(train_images.shape, train_labels.shape) + print(type(test_images), type(test_labels)) + print(test_images.shape, test_labels.shape) + + train_images, train_labels, test_images, test_labels = cifa100() + + print(type(train_images), type(train_labels)) + print(train_images.shape, train_labels.shape) + print(type(test_images), type(test_labels)) + print(test_images.shape, test_labels.shape) + + # cifa10_filepath = path.join(_DATA, "cifa10", "cifar-10-batches-py/test_batch") + # with open(cifa10_filepath, 'rb')as f: + # datadict = pickle.load(f,encoding='bytes') + # print(datadict.keys()) + # print(datadict[b"data"]) + # print(type(datadict[b"data"])) + # print(len(datadict[b"labels"])) \ No newline at end of file diff --git a/examples/jax/jax_util/resnet.py b/examples/jax/jax_util/resnet.py new file mode 100644 index 0000000..7387b10 --- /dev/null +++ b/examples/jax/jax_util/resnet.py @@ -0,0 +1,221 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A mock-up showing a ResNet50 network with training on synthetic data. + +This file uses the stax neural network definition library and the optimizers +optimization library. 
+""" + +import numpy.random as npr + +import jax.numpy as jnp +from jax import jit, grad, random +from jax.experimental import optimizers +from jax.experimental import stax +from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum, + FanOut, Flatten, GeneralConv, Identity, + MaxPool, Relu, LogSoftmax) + + +# ResNet blocks compose other layers + +def ConvBlock(kernel_size, filters, strides=(2, 2)): + ks = kernel_size + filters1, filters2, filters3 = filters + Main = stax.serial( + Conv(filters1, (1, 1), strides), BatchNorm(), Relu, + Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu, + Conv(filters3, (1, 1)), BatchNorm()) + Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm()) + return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) + + +def IdentityBlock(kernel_size, filters): + ks = kernel_size + filters1, filters2 = filters + def make_main(input_shape): + # the number of output channels depends on the number of input channels + return stax.serial( + Conv(filters1, (1, 1)), BatchNorm(), Relu, + Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu, + Conv(input_shape[3], (1, 1)), BatchNorm()) + Main = stax.shape_dependent(make_main) + return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu) + + +def BasicBlock(kernel_size, filters, strides=(1, 1)): + ks = kernel_size + filters1, filters2 = filters + Main = stax.serial( + Conv(filters1, (ks, ks), strides, padding='SAME'), BatchNorm(), Relu, + Conv(filters2, (ks, ks), strides, padding='SAME'), BatchNorm()) + + Shortcut = stax.serial(Conv(filters2, (1, 1), strides), BatchNorm()) + return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) + +def BasicBlock_withoutBN(kernel_size, filters, strides=(1, 1)): + ks = kernel_size + filters1, filters2 = filters + Main = stax.serial( + Conv(filters1, (ks, ks), strides, padding='SAME'), Relu, + Conv(filters2, (ks, ks), strides, padding='SAME')) + + Shortcut = stax.serial(Conv(filters2, (1, 1), strides)) + return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) + + +def IdentityBlock_withoutBN(kernel_size, filters): + ks = kernel_size + filters1, filters2 = filters + def make_main(input_shape): + # the number of output channels depends on the number of input channels + return stax.serial( + Conv(filters1, (1, 1)), Relu, + Conv(filters2, (ks, ks), padding='SAME'), Relu, + Conv(input_shape[3], (1, 1))) + Main = stax.shape_dependent(make_main) + return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu) + +# ResNet architectures compose layers and ResNet blocks + +def ResNet101(num_classes): + return stax.serial( + GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'), + BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)), + ConvBlock(3, [64, 64, 256], strides=(1, 1)), + IdentityBlock(3, [64, 64]), + IdentityBlock(3, [64, 64]), + ConvBlock(3, [128, 128, 512]), + IdentityBlock(3, [128, 128]), + IdentityBlock(3, [128, 128]), + IdentityBlock(3, [128, 128]), + ConvBlock(3, [256, 256, 1024]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, 
[256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + ConvBlock(3, [512, 512, 2048]), + IdentityBlock(3, [512, 512]), + IdentityBlock(3, [512, 512]), + AvgPool((7, 7), padding="SAME"), Flatten, Dense(num_classes), LogSoftmax) + + +def ResNet50(num_classes): + return stax.serial( + GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'), + BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)), + ConvBlock(3, [64, 64, 256], strides=(1, 1)), + IdentityBlock(3, [64, 64]), + IdentityBlock(3, [64, 64]), + ConvBlock(3, [128, 128, 512]), + IdentityBlock(3, [128, 128]), + IdentityBlock(3, [128, 128]), + IdentityBlock(3, [128, 128]), + ConvBlock(3, [256, 256, 1024]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + ConvBlock(3, [512, 512, 2048]), + IdentityBlock(3, [512, 512]), + IdentityBlock(3, [512, 512]), + AvgPool((7, 7), padding="SAME"), Flatten, Dense(num_classes), LogSoftmax) + + +def ResNet18(num_classes): + return stax.serial( + GeneralConv(('HWCN', 'OIHW', 'NHWC'), 1, (7, 7), (2, 2), 'SAME'), + BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)), + BasicBlock(3, [64, 64]), + IdentityBlock(3, [64, 64]), + BasicBlock(3, [128, 128]), + IdentityBlock(3, [128, 128]), + BasicBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), + BasicBlock(3, [512, 512]), + IdentityBlock(3, [512, 512]), + AvgPool((7, 7), padding="SAME"), Flatten, Dense(num_classes), LogSoftmax) + + +def MLP(num_classes): + return stax.serial( + Flatten, + Dense(32), BatchNorm(), Relu, + Dense(128), BatchNorm(), Relu, + Dense(num_classes), LogSoftmax) + + +if __name__ == "__main__": + rng_key = random.PRNGKey(0) + + batch_size = 8 + num_classes = 1001 + input_shape = (224, 224, 3, batch_size) + step_size = 0.1 + num_steps = 10 + + init_fun, predict_fun = ResNet50(num_classes) + _, init_params = init_fun(rng_key, input_shape) + + def loss(params, batch): + inputs, targets = batch + logits = predict_fun(params, inputs) + return -jnp.sum(logits * targets) + + def accuracy(params, batch): + inputs, targets = batch + target_class = jnp.argmax(targets, axis=-1) + predicted_class = jnp.argmax(predict_fun(params, inputs), axis=-1) + return jnp.mean(predicted_class == target_class) + + def synth_batches(): + rng = npr.RandomState(0) + while True: + images = rng.rand(*input_shape).astype('float32') + labels = rng.randint(num_classes, size=(batch_size, 1)) + onehot_labels = labels == jnp.arange(num_classes) + yield images, onehot_labels + + opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9) + batches = synth_batches() + + @jit + def update(i, opt_state, batch): + params = get_params(opt_state) + return opt_update(i, grad(loss)(params, batch), opt_state) + + opt_state = opt_init(init_params) + for i in range(num_steps): + opt_state = update(i, opt_state, next(batches)) + trained_params = get_params(opt_state) diff --git a/examples/jax/mnist_jax_example.py b/examples/jax/mnist_jax_example.py new file mode 100644 index 0000000..09a36ad --- /dev/null +++ b/examples/jax/mnist_jax_example.py @@ -0,0 +1,120 @@ +import os +import argparse + +from filelock import FileLock + +from tqdm import trange + +import ray +from distml.operator.jax_operator import JAXTrainingOperator +from distml.strategy.allreduce_strategy import 
AllReduceStrategy + +from ray.util.sgd.utils import BATCH_SIZE, override + +import numpy as np +import numpy.random as npr +import jax +from jax import jit, grad, random +from jax.tree_util import tree_flatten +from jax.experimental import optimizers +from jax.lib import xla_client +import jax.numpy as jnp +from jax_util.resnet import ResNet18, ResNet50, ResNet101 +from jax_util.datasets import mnist, Dataloader + + +def initialization_hook(): + # Need this for avoiding a connection restart issue on AWS. + os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" + os.environ["NCCL_LL_THRESHOLD"] = "0" + + # set the below if needed + # print("NCCL DEBUG SET") + # os.environ["NCCL_DEBUG"] = "INFO" + + +class MnistTrainingOperator(JAXTrainingOperator): + @override(JAXTrainingOperator) + def setup(self, config): + batch_size = config["batch_size"] + rng_key = random.PRNGKey(0) + input_shape = (28, 28, 1, batch_size) + lr = config["lr"] + model_name = config["model_name"] + num_classes = config["num_classes"] + + if model_name == "resnet18": + init_fun, predict_fun = ResNet18(num_classes) + elif model_name == "resnet50": + init_fun, predict_fun = ResNet50(num_classes) + elif model_name == "resnet101": + init_fun, predict_fun = ResNet101(num_classes) + else: + raise RuntimeError("Unrecognized model name") + + _, init_params = init_fun(rng_key, input_shape) + + opt_init, opt_update, get_params = optimizers.adam(lr) + opt_state = opt_init(init_params) + + with FileLock(".ray.lock"): + train_images, train_labels, test_images, test_labels = mnist() + + train_images = train_images.reshape(train_images.shape[0], 1, 28, 28).transpose(2, 3, 1, 0) + test_images = test_images.reshape(test_images.shape[0], 1, 28, 28).transpose(2, 3, 1, 0) + + train_loader = Dataloader(train_images, train_labels, batch_size=batch_size, shuffle=True) + test_loader = Dataloader(test_images, test_labels, batch_size=batch_size) + + self.register(model=[opt_state, init_fun, predict_fun], optimizer=[opt_init, opt_update, get_params], criterion=lambda logits, targets:-jnp.sum(logits * targets)) + + self.register_data(train_loader=train_loader, validation_loader=test_loader) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--address", + required=False, + type=str, + help="the address to use for connecting to the Ray cluster") + parser.add_argument( + "--num-workers", + "-n", + type=int, + default=2, + help="Sets number of workers for training.") + parser.add_argument( + "--num-epochs", type=int, default=20, help="Number of epochs to train.") + parser.add_argument( + "--fp16", + action="store_true", + default=False, + help="Enables FP16 training with apex. 
Requires `use-gpu`.") + parser.add_argument( + "--model-name", type=str, default="resnet18", help="model, Optional: resnet18, resnet50, resnet101.") + + args, _ = parser.parse_known_args() + + if args.address: + ray.init(args.address) + else: + ray.init(num_gpus=args.num_workers, num_cpus=args.num_workers * 2, log_to_driver=True) + + strategy = AllReduceStrategy( + training_operator_cls=MnistTrainingOperator, + world_size=args.num_workers, + operator_config={ + "lr": 0.01, + "batch_size": 128 , + "num_workers": args.num_workers, + "num_classes": 10, + "model_name": args.model_name + }) + + for i in range(args.num_epochs): + strategy.train() + print(strategy.validate()) + + strategy.shutdown() + print("success!") From da251fbd218ac415fb39b09b4ef10f054abff740 Mon Sep 17 00:00:00 2001 From: Ezra-H Date: Thu, 29 Apr 2021 01:58:10 +0800 Subject: [PATCH 03/12] lint --- distml/operator/jax_operator.py | 112 ++++++----- examples/jax/jax_util/__init__.py | 4 +- examples/jax/jax_util/datasets.py | 95 +++++----- examples/jax/jax_util/resnet.py | 299 ++++++++++++++---------------- examples/jax/mnist_jax_example.py | 60 +++--- format.sh | 6 +- 6 files changed, 294 insertions(+), 282 deletions(-) diff --git a/distml/operator/jax_operator.py b/distml/operator/jax_operator.py index 3ac8cd4..b560617 100644 --- a/distml/operator/jax_operator.py +++ b/distml/operator/jax_operator.py @@ -1,23 +1,17 @@ import numpy as np import cupy as cp -import jax -from jax import grad, value_and_grad +from jax import value_and_grad import jax.numpy as jnp from jax.lib import xla_client -from jax.dlpack import from_dlpack, to_dlpack -from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, tree_map, build_tree +from jax.dlpack import from_dlpack +from jax.tree_util import tree_flatten, tree_unflatten, tree_structure from jax._src.util import unzip2 from jax.experimental.optimizers import OptimizerState from .base_operator import TrainingOperator -from distml.util import ThroughputCollection, func_timer -from ray.util.sgd.utils import TimerCollection, AverageMeterCollection - -import time class JAXTrainingOperator(TrainingOperator): - def __init__(self, operator_config): super(JAXTrainingOperator, self).__init__(operator_config) # Should be set by users in the `register` function. @@ -25,7 +19,7 @@ def __init__(self, operator_config): self.opt_state = None self.init_fun = None self.predict_fun = None - # optimizer methods + # optimizer methods self.opt_init = None self.opt_update = None self.get_params = None @@ -39,47 +33,51 @@ def __init__(self, operator_config): self.setup(operator_config) if hasattr(operator_config, "jit_mode"): - assert not operator_config["jit_mode"], "Not support jit in jax operator." + if operator_config["jit_mode"]: + raise NotImplementedError("Not support jit in jax operator.") self.train_step_num = 0 def setup(self, *args, **kwargs): """Function that needs to be override by users. - + Example: - # some code is the same for all users, maybe we can put it in register. + # some code is the same for all users, + # maybe we can put it in register. 
rng_key = random.PRNGKey(0) input_shape = (28, 28, 1, batch_size) lr=0.01 init_fun, predict_fun = ResNet18(num_classes) _, init_params = init_fun(rng_key, input_shape) - + opt_init, opt_update, get_params = optimizers.adam(lr) opt_state = opt_init(init_params) - - self.register(model=(opt_state, get_params, predict_fun), optimizer=opt_update, criterion=lambda logits, targets:-jnp.sum(logits * targets)) - + + self.register(model=(opt_state, get_params, predict_fun), + optimizer=opt_update, + criterion=lambda logits, \ + targets:-jnp.sum(logits * targets)) + """ - pass - def register(self, + def register(self, *, - model, - optimizer, - criterion, - lr_schedulers=None, + model, + optimizer, + criterion, + lr_schedulers=None, jit_mode=False): """Register a few critical information about the model to operator.""" self.criterion = criterion if lr_schedulers: self.lr_schedulers = lr_schedulers - print("WARNING: jax not support learning rate scheduler." + print("WARNING: jax not support learning rate scheduler." "This will not work.") - + self._register_model(model) self._register_optimizer(optimizer) - + def _register_model(self, model): """register model components. @@ -108,7 +106,8 @@ def _get_validation_loader(self): def loss_func(self, params, batch): """A function to calculate predictions and loss value. - This function is going to be decorated by `grad` in Jax to calculate gradients. + This function is going to be decorated by + `grad` in Jax to calculate gradients. Args: batch (tuple): a data batch containing a feature/target pair. @@ -133,20 +132,28 @@ def derive_updates(self, batch): assert tree == self.opt_state[1] if hasattr(self, "preset_keys"): - gradient_dict = {k:g for k, g in zip(self.preset_keys, gradient_dict)} + gradient_dict = { + k: g + for k, g in zip(self.preset_keys, gradient_dict) + } else: - gradient_dict = {f"{idx}":g for idx, g in enumerate(gradient_dict)} + gradient_dict = { + f"{idx}": g + for idx, g in enumerate(gradient_dict) + } return loss_val.item(), gradient_dict - + def apply_updates(self, updates): """Set and apply the updates using the opt_update in Jax. Args: updates (dict): a dictionary of parameter name and updates. """ - keys, updates = unzip2(sorted(updates.items(), key=lambda d: int(d[0]))) + keys, updates = unzip2( + sorted(updates.items(), key=lambda d: int(d[0]))) updates = tree_unflatten(self.opt_state[1], updates) - self.opt_state = self.opt_update(self.train_step_num, updates, self.opt_state) + self.opt_state = self.opt_update(self.train_step_num, updates, + self.opt_state) self.train_step_num += 1 def to_cupy(self, tensor): @@ -158,8 +165,9 @@ def to_cupy(self, tensor): def to_operator_tensor(self, tensor): """Convert a cupy tensor to jax tensor. - - There comes a bug. The layouts of tensor explained by cupy and jax are different. But dlpack doesn't convert the layout. + + There comes a bug. The layouts of tensor explained by cupy + and jax are different. But dlpack doesn't convert the layout. """ if isinstance(tensor, list): return list(map(self.to_operator_tensor, tensor)) @@ -177,8 +185,8 @@ def get_jax_dlpack(self, tensor): Jax api might cause different pointer address after the conversion. We use the xla api to avoid this bug in jax api. """ - return xla_client._xla.buffer_to_dlpack_managed_tensor(tensor.device_buffer, - take_ownership=False) + return xla_client._xla.buffer_to_dlpack_managed_tensor( + tensor.device_buffer, take_ownership=False) def validate_batch(self, batch): """Perform validation over a data batch. 
@@ -220,32 +228,36 @@ def get_parameters(self, cpu): def get_named_parameters(self, cpu): """Get the named parameters. - + In jax, we need to construct a dict to contain the parameters. """ params = self.get_parameters(cpu) if hasattr(self, "preset_keys"): - dict_params = {name:p for name, p in zip(self.preset_keys, params)} + dict_params = { + name: p + for name, p in zip(self.preset_keys, params) + } else: - dict_params = {f"{idx}":p for idx, p in enumerate(params)} + dict_params = {f"{idx}": p for idx, p in enumerate(params)} return dict_params - # TODO(HUI): used in load states or load parameters + # TODO(HUI): used in load states or load parameters def set_parameters(self, new_params): """Use new parameters to replace model parameters. - + In jax, we need to construct a dict to contain the parameters. - + Args: new_params (dict): New parameters to update the current model. """ assert isinstance(new_params, dict) - keys, new_params = unzip2(sorted(new_params.items(), key=lambda d: int(d[0]))) + keys, new_params = unzip2( + sorted(new_params.items(), key=lambda d: int(d[0]))) self.preset_keys = keys if not hasattr(self, "tree"): - self.tree = tree_structure(self.get_params(self.opt_state)) + self.tree = tree_structure(self.get_params(self.opt_state)) states_flat, tree, subtrees = self.opt_state @@ -261,10 +273,12 @@ def update(param, state): if not new_subtrees: raise RuntimeError("subtrees of new params are empty.") - for idx, (subtree, new_subtree) in enumerate(zip(subtrees, new_subtrees)): + for idx, (subtree, new_subtree) in enumerate( + zip(subtrees, new_subtrees)): if new_subtree != subtree: - msg = ("input structure did not match the saved params structure. " - "input {} and output {}.") + msg = ( + "input structure did not match the saved params structure. " + "input {} and output {}.") raise TypeError(msg.format(subtree, new_subtree)) self.opt_state = OptimizerState(new_states_flat, tree, new_subtrees) @@ -322,10 +336,8 @@ def save_states(self, states): "save_states is not supported in jax operator.") def get_states(self, states): - raise NotImplementedError( - "get_states is not supported in jax operator.") + raise NotImplementedError("get_states is not supported in jax operator.") def load_states(self, checkpoint): raise NotImplementedError( "load_states is not supported in jax operator.") - diff --git a/examples/jax/jax_util/__init__.py b/examples/jax/jax_util/__init__.py index 955ceb9..bcf82e0 100644 --- a/examples/jax/jax_util/__init__.py +++ b/examples/jax/jax_util/__init__.py @@ -1,2 +1,2 @@ -from .datasets import mnist, Dataloader -from .resnet import ResNet18, ResNet50, ResNet101 \ No newline at end of file +from .datasets import mnist, Dataloader # noqa: F401 +from .resnet import ResNet18, ResNet50, ResNet101 # noqa: F401 diff --git a/examples/jax/jax_util/datasets.py b/examples/jax/jax_util/datasets.py index 6bc9385..a2b8cf0 100644 --- a/examples/jax/jax_util/datasets.py +++ b/examples/jax/jax_util/datasets.py @@ -11,31 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
- """Datasets used in examples.""" - import array import gzip import os from os import path import struct import urllib.request -from jax.api import F import jax.numpy as jnp -from jax import jit import numpy as np import numpy.random as npr import pickle -from functools import partial _DATA = "/tmp/jax_example_data/" def _download(url, filename, dataset_name="mnist"): """Download a url to a file in the JAX data temp directory.""" - root = os.path.join(_DATA,dataset_name) + root = os.path.join(_DATA, dataset_name) if not path.exists(root): os.makedirs(root) out_file = path.join(root, filename) @@ -53,11 +48,13 @@ def _one_hot(x, k, dtype=np.float32): """Create a one-hot encoding of x of size k.""" return np.asarray(x[:, None] == np.arange(k), dtype) + # @partial(jit, static_argnums=1) def _one_hot_jit(x, k, dtype=np.float32): """Create a one-hot encoding of x of size k.""" return jnp.asarray(x[:, None] == jnp.arange(0, k), dtype) + def mnist_raw(): """Download and parse the raw MNIST dataset.""" # CVDF mirror of http://yann.lecun.com/exdb/mnist/ @@ -71,23 +68,33 @@ def parse_labels(filename): def parse_images(filename): with gzip.open(filename, "rb") as fh: _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) - return np.array(array.array("B", fh.read()), - dtype=np.uint8).reshape(num_data, rows, cols) - - for filename in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", - "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]: + return np.array( + array.array("B", fh.read()), dtype=np.uint8).reshape( + num_data, rows, cols) + + for filename in [ + "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", + "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz" + ]: _download(base_url + filename, filename) - train_images = parse_images(path.join(_DATA, "mnist", "train-images-idx3-ubyte.gz")) - train_labels = parse_labels(path.join(_DATA, "mnist", "train-labels-idx1-ubyte.gz")) - test_images = parse_images(path.join(_DATA, "mnist", "t10k-images-idx3-ubyte.gz")) - test_labels = parse_labels(path.join(_DATA, "mnist", "t10k-labels-idx1-ubyte.gz")) + train_images = parse_images( + path.join(_DATA, "mnist", "train-images-idx3-ubyte.gz")) + train_labels = parse_labels( + path.join(_DATA, "mnist", "train-labels-idx1-ubyte.gz")) + test_images = parse_images( + path.join(_DATA, "mnist", "t10k-images-idx3-ubyte.gz")) + test_labels = parse_labels( + path.join(_DATA, "mnist", "t10k-labels-idx1-ubyte.gz")) return train_images, train_labels, test_images, test_labels def mnist(permute_train=False): - """Download, parse and process MNIST data to unit scale and one-hot labels.""" + """ + Download, parse and process MNIST data + to unit scale and one-hot labels. + """ train_images, train_labels, test_images, test_labels = mnist_raw() train_images = _partial_flatten(train_images) / np.float32(255.) 
@@ -102,6 +109,7 @@ def mnist(permute_train=False): return train_images, train_labels, test_images, test_labels + def cifa100_raw(): """Download and parse the raw MNIST dataset.""" base_url = "http://www.cs.toronto.edu/~kriz/" @@ -116,8 +124,8 @@ def load_CIFAR_batch(root, mode="train"): raise RuntimeError("Error: unrecognized mode", " Got {}".format(mode)) - with open(filename, 'rb')as f: - datadict = pickle.load(f,encoding='bytes') + with open(filename, 'rb') as f: + datadict = pickle.load(f, encoding='bytes') X = datadict[b'data'] Y = datadict[b'fine_labels'] if mode == "train": @@ -132,20 +140,23 @@ def load_CIFAR_batch(root, mode="train"): root = path.join(_DATA, "cifa100") if not os.path.exists(path.join(root, "cifar-100-python.tar.gz")): - os.system("tar xvf {} -C {}".format(path.join(root, "cifar-100-python.tar.gz"), - root)) + os.system("tar xvf {} -C {}".format( + path.join(root, "cifar-100-python.tar.gz"), root)) - train_images, train_labels = load_CIFAR_batch(path.join(root, "cifar-100-python"), - mode="train") - test_images, test_labels = load_CIFAR_batch(path.join(root, "cifar-100-python"), - mode="test") + train_images, train_labels = load_CIFAR_batch( + path.join(root, "cifar-100-python"), mode="train") + test_images, test_labels = load_CIFAR_batch( + path.join(root, "cifar-100-python"), mode="test") # b"fine_label_names" b"coarse_label_names" # meta_path = path.join(root, "cifar-100-python", "meta") return train_images, train_labels, test_images, test_labels + def cifa100(permute_train=False): - """Download, parse and process cida100 data to unit scale and one-hot labels.""" + """ + Download, parse and process cida100 data to unit scale and one-hot labels. + """ train_images, train_labels, test_images, test_labels = cifa100_raw() train_images = _partial_flatten(train_images) / np.float32(255.) 
@@ -169,7 +180,7 @@ def load_CIFAR_batch(root, mode="train"): """ load single batch of cifar """ if mode == "train": filenames = [] - for i in range(1,6): + for i in range(1, 6): filenames.append(path.join(root, f"data_batch_{i}")) elif mode == "test": filenames = [path.join(root, "test_batch")] @@ -180,8 +191,8 @@ def load_CIFAR_batch(root, mode="train"): datas = [] labels = [] for filename in filenames: - with open(filename, 'rb')as f: - datadict = pickle.load(f,encoding='bytes') + with open(filename, 'rb') as f: + datadict = pickle.load(f, encoding='bytes') X = datadict[b'data'] Y = datadict[b'labels'] X = X.reshape(10000, 3, 32, 32) @@ -195,13 +206,13 @@ def load_CIFAR_batch(root, mode="train"): root = path.join(_DATA, "cifa10") if not os.path.exists(path.join(root, "cifar-10-batches-py")): - os.system("tar xvf {} -C {}".format(path.join(root, "cifar-10-python.tar.gz"), - root)) + os.system("tar xvf {} -C {}".format( + path.join(root, "cifar-10-python.tar.gz"), root)) - train_images, train_labels = load_CIFAR_batch(path.join(root, "cifar-10-batches-py"), - mode="train") - test_images, test_labels = load_CIFAR_batch(path.join(root, "cifar-10-batches-py"), - mode="test") + train_images, train_labels = load_CIFAR_batch( + path.join(root, "cifar-10-batches-py"), mode="train") + test_images, test_labels = load_CIFAR_batch( + path.join(root, "cifar-10-batches-py"), mode="test") print(test_images.shape) # b"fine_label_names" b"coarse_label_names" @@ -210,7 +221,10 @@ def load_CIFAR_batch(root, mode="train"): def cifa10(permute_train=False): - """Download, parse and process cida100 data to unit scale and one-hot labels.""" + """ + Download, parse and process cida100 data + to unit scale and one-hot labels. + """ train_images, train_labels, test_images, test_labels = cifa10_raw() train_images = _partial_flatten(train_images) / np.float32(255.) @@ -243,7 +257,8 @@ def __init__(self, data, target, batch_size=128, shuffle=False): def synth_batches(self): num_imgs = self.target.shape[0] rng = npr.RandomState(npr.randint(10)) - perm = rng.permutation(num_imgs) if self.shuffle else np.arange(num_imgs) + perm = rng.permutation(num_imgs) if self.shuffle else np.arange( + num_imgs) for i in range(self.num_batches): batch_idx = perm[i * self.batch_size:(i + 1) * self.batch_size] img_batch = self.data[:, :, :, batch_idx] @@ -271,11 +286,3 @@ def __len__(self): print(train_images.shape, train_labels.shape) print(type(test_images), type(test_labels)) print(test_images.shape, test_labels.shape) - - # cifa10_filepath = path.join(_DATA, "cifa10", "cifar-10-batches-py/test_batch") - # with open(cifa10_filepath, 'rb')as f: - # datadict = pickle.load(f,encoding='bytes') - # print(datadict.keys()) - # print(datadict[b"data"]) - # print(type(datadict[b"data"])) - # print(len(datadict[b"labels"])) \ No newline at end of file diff --git a/examples/jax/jax_util/resnet.py b/examples/jax/jax_util/resnet.py index 7387b10..f3e89f5 100644 --- a/examples/jax/jax_util/resnet.py +++ b/examples/jax/jax_util/resnet.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """A mock-up showing a ResNet50 network with training on synthetic data. 
This file uses the stax neural network definition library and the optimizers @@ -28,194 +27,178 @@ FanOut, Flatten, GeneralConv, Identity, MaxPool, Relu, LogSoftmax) - # ResNet blocks compose other layers + def ConvBlock(kernel_size, filters, strides=(2, 2)): - ks = kernel_size - filters1, filters2, filters3 = filters - Main = stax.serial( - Conv(filters1, (1, 1), strides), BatchNorm(), Relu, - Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu, - Conv(filters3, (1, 1)), BatchNorm()) - Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm()) - return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) + ks = kernel_size + filters1, filters2, filters3 = filters + Main = stax.serial( + Conv(filters1, (1, 1), strides), BatchNorm(), Relu, + Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu, + Conv(filters3, (1, 1)), BatchNorm()) + Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm()) + return stax.serial( + FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) def IdentityBlock(kernel_size, filters): - ks = kernel_size - filters1, filters2 = filters - def make_main(input_shape): - # the number of output channels depends on the number of input channels + ks = kernel_size + filters1, filters2 = filters + + def make_main(input_shape): + # the number of output channels depends on the number of input channels + return stax.serial( + Conv(filters1, (1, 1)), BatchNorm(), Relu, + Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu, + Conv(input_shape[3], (1, 1)), BatchNorm()) + + Main = stax.shape_dependent(make_main) return stax.serial( - Conv(filters1, (1, 1)), BatchNorm(), Relu, - Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu, - Conv(input_shape[3], (1, 1)), BatchNorm()) - Main = stax.shape_dependent(make_main) - return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu) + FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu) def BasicBlock(kernel_size, filters, strides=(1, 1)): - ks = kernel_size - filters1, filters2 = filters - Main = stax.serial( - Conv(filters1, (ks, ks), strides, padding='SAME'), BatchNorm(), Relu, - Conv(filters2, (ks, ks), strides, padding='SAME'), BatchNorm()) + ks = kernel_size + filters1, filters2 = filters + Main = stax.serial( + Conv(filters1, (ks, ks), strides, padding='SAME'), BatchNorm(), Relu, + Conv(filters2, (ks, ks), strides, padding='SAME'), BatchNorm()) + + Shortcut = stax.serial(Conv(filters2, (1, 1), strides), BatchNorm()) + return stax.serial( + FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) - Shortcut = stax.serial(Conv(filters2, (1, 1), strides), BatchNorm()) - return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) def BasicBlock_withoutBN(kernel_size, filters, strides=(1, 1)): - ks = kernel_size - filters1, filters2 = filters - Main = stax.serial( - Conv(filters1, (ks, ks), strides, padding='SAME'), Relu, - Conv(filters2, (ks, ks), strides, padding='SAME')) + ks = kernel_size + filters1, filters2 = filters + Main = stax.serial( + Conv(filters1, (ks, ks), strides, padding='SAME'), Relu, + Conv(filters2, (ks, ks), strides, padding='SAME')) - Shortcut = stax.serial(Conv(filters2, (1, 1), strides)) - return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) + Shortcut = stax.serial(Conv(filters2, (1, 1), strides)) + return stax.serial( + FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) def IdentityBlock_withoutBN(kernel_size, filters): - ks = kernel_size - filters1, filters2 = filters - 
def make_main(input_shape): - # the number of output channels depends on the number of input channels + ks = kernel_size + filters1, filters2 = filters + + def make_main(input_shape): + # the number of output channels depends on the number of input channels + return stax.serial( + Conv(filters1, (1, 1)), Relu, + Conv(filters2, (ks, ks), padding='SAME'), Relu, + Conv(input_shape[3], (1, 1))) + + Main = stax.shape_dependent(make_main) return stax.serial( - Conv(filters1, (1, 1)), Relu, - Conv(filters2, (ks, ks), padding='SAME'), Relu, - Conv(input_shape[3], (1, 1))) - Main = stax.shape_dependent(make_main) - return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu) + FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu) + # ResNet architectures compose layers and ResNet blocks + def ResNet101(num_classes): - return stax.serial( - GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'), - BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)), - ConvBlock(3, [64, 64, 256], strides=(1, 1)), - IdentityBlock(3, [64, 64]), - IdentityBlock(3, [64, 64]), - ConvBlock(3, [128, 128, 512]), - IdentityBlock(3, [128, 128]), - IdentityBlock(3, [128, 128]), - IdentityBlock(3, [128, 128]), - ConvBlock(3, [256, 256, 1024]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - ConvBlock(3, [512, 512, 2048]), - IdentityBlock(3, [512, 512]), - IdentityBlock(3, [512, 512]), - AvgPool((7, 7), padding="SAME"), Flatten, Dense(num_classes), LogSoftmax) + return stax.serial( + GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'), + BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)), + ConvBlock(3, [64, 64, 256], strides=(1, + 1)), IdentityBlock(3, [64, 64]), + IdentityBlock(3, [64, 64]), ConvBlock(3, [128, 128, 512]), + IdentityBlock(3, [128, 128]), IdentityBlock(3, [128, 128]), + IdentityBlock(3, [128, 128]), ConvBlock(3, [256, 256, 1024]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + ConvBlock(3, [512, 512, 2048]), IdentityBlock(3, [512, 512]), + IdentityBlock(3, [512, 512]), AvgPool((7, 7), padding="SAME"), Flatten, + Dense(num_classes), LogSoftmax) def ResNet50(num_classes): - return stax.serial( - GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'), - BatchNorm(), Relu, MaxPool((3, 3), 
strides=(2, 2)), - ConvBlock(3, [64, 64, 256], strides=(1, 1)), - IdentityBlock(3, [64, 64]), - IdentityBlock(3, [64, 64]), - ConvBlock(3, [128, 128, 512]), - IdentityBlock(3, [128, 128]), - IdentityBlock(3, [128, 128]), - IdentityBlock(3, [128, 128]), - ConvBlock(3, [256, 256, 1024]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - ConvBlock(3, [512, 512, 2048]), - IdentityBlock(3, [512, 512]), - IdentityBlock(3, [512, 512]), - AvgPool((7, 7), padding="SAME"), Flatten, Dense(num_classes), LogSoftmax) + return stax.serial( + GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'), + BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)), + ConvBlock(3, [64, 64, 256], strides=(1, 1)), IdentityBlock( + 3, [64, 64]), IdentityBlock(3, [64, 64]), + ConvBlock(3, [128, 128, 512]), IdentityBlock(3, [128, 128]), + IdentityBlock(3, [128, 128]), IdentityBlock(3, [128, 128]), + ConvBlock(3, [256, 256, 1024]), IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + ConvBlock(3, [512, 512, 2048]), IdentityBlock(3, [512, 512]), + IdentityBlock(3, [512, 512]), AvgPool((7, 7), padding="SAME"), Flatten, + Dense(num_classes), LogSoftmax) def ResNet18(num_classes): - return stax.serial( - GeneralConv(('HWCN', 'OIHW', 'NHWC'), 1, (7, 7), (2, 2), 'SAME'), - BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)), - BasicBlock(3, [64, 64]), - IdentityBlock(3, [64, 64]), - BasicBlock(3, [128, 128]), - IdentityBlock(3, [128, 128]), - BasicBlock(3, [256, 256]), - IdentityBlock(3, [256, 256]), - BasicBlock(3, [512, 512]), - IdentityBlock(3, [512, 512]), - AvgPool((7, 7), padding="SAME"), Flatten, Dense(num_classes), LogSoftmax) + return stax.serial( + GeneralConv(('HWCN', 'OIHW', 'NHWC'), 1, (7, 7), (2, 2), 'SAME'), + BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)), + BasicBlock(3, [64, 64]), IdentityBlock(3, [64, 64]), + BasicBlock(3, [128, 128]), IdentityBlock(3, [128, 128]), + BasicBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), + BasicBlock(3, [512, 512]), IdentityBlock(3, [512, 512]), + AvgPool((7, 7), padding="SAME"), Flatten, Dense(num_classes), + LogSoftmax) def MLP(num_classes): - return stax.serial( - Flatten, - Dense(32), BatchNorm(), Relu, - Dense(128), BatchNorm(), Relu, - Dense(num_classes), LogSoftmax) + return stax.serial(Flatten, Dense(32), BatchNorm(), Relu, Dense(128), + BatchNorm(), Relu, Dense(num_classes), LogSoftmax) if __name__ == "__main__": - rng_key = random.PRNGKey(0) - - batch_size = 8 - num_classes = 1001 - input_shape = (224, 224, 3, batch_size) - step_size = 0.1 - num_steps = 10 - - init_fun, predict_fun = ResNet50(num_classes) - _, init_params = init_fun(rng_key, input_shape) - - def loss(params, batch): - inputs, targets = batch - logits = predict_fun(params, inputs) - return -jnp.sum(logits * targets) - - def accuracy(params, batch): - inputs, targets = batch - target_class = jnp.argmax(targets, axis=-1) - predicted_class = jnp.argmax(predict_fun(params, inputs), axis=-1) - return jnp.mean(predicted_class == target_class) - - def synth_batches(): - rng = npr.RandomState(0) - while True: - images = rng.rand(*input_shape).astype('float32') - labels = rng.randint(num_classes, size=(batch_size, 1)) - onehot_labels = labels == jnp.arange(num_classes) - yield images, onehot_labels - - opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9) - batches 
= synth_batches() - - @jit - def update(i, opt_state, batch): - params = get_params(opt_state) - return opt_update(i, grad(loss)(params, batch), opt_state) - - opt_state = opt_init(init_params) - for i in range(num_steps): - opt_state = update(i, opt_state, next(batches)) - trained_params = get_params(opt_state) + rng_key = random.PRNGKey(0) + + batch_size = 8 + num_classes = 1001 + input_shape = (224, 224, 3, batch_size) + step_size = 0.1 + num_steps = 10 + + init_fun, predict_fun = ResNet50(num_classes) + _, init_params = init_fun(rng_key, input_shape) + + def loss(params, batch): + inputs, targets = batch + logits = predict_fun(params, inputs) + return -jnp.sum(logits * targets) + + def accuracy(params, batch): + inputs, targets = batch + target_class = jnp.argmax(targets, axis=-1) + predicted_class = jnp.argmax(predict_fun(params, inputs), axis=-1) + return jnp.mean(predicted_class == target_class) + + def synth_batches(): + rng = npr.RandomState(0) + while True: + images = rng.rand(*input_shape).astype('float32') + labels = rng.randint(num_classes, size=(batch_size, 1)) + onehot_labels = labels == jnp.arange(num_classes) + yield images, onehot_labels + + opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9) + batches = synth_batches() + + @jit + def update(i, opt_state, batch): + params = get_params(opt_state) + return opt_update(i, grad(loss)(params, batch), opt_state) + + opt_state = opt_init(init_params) + for i in range(num_steps): + opt_state = update(i, opt_state, next(batches)) + trained_params = get_params(opt_state) diff --git a/examples/jax/mnist_jax_example.py b/examples/jax/mnist_jax_example.py index 09a36ad..2335b5d 100644 --- a/examples/jax/mnist_jax_example.py +++ b/examples/jax/mnist_jax_example.py @@ -3,23 +3,16 @@ from filelock import FileLock -from tqdm import trange - import ray from distml.operator.jax_operator import JAXTrainingOperator from distml.strategy.allreduce_strategy import AllReduceStrategy -from ray.util.sgd.utils import BATCH_SIZE, override +from ray.util.sgd.utils import override -import numpy as np -import numpy.random as npr -import jax -from jax import jit, grad, random -from jax.tree_util import tree_flatten +from jax import random from jax.experimental import optimizers -from jax.lib import xla_client import jax.numpy as jnp -from jax_util.resnet import ResNet18, ResNet50, ResNet101 +from jax_util.resnet import ResNet18, ResNet50, ResNet101 from jax_util.datasets import mnist, Dataloader @@ -53,22 +46,30 @@ def setup(self, config): raise RuntimeError("Unrecognized model name") _, init_params = init_fun(rng_key, input_shape) - + opt_init, opt_update, get_params = optimizers.adam(lr) opt_state = opt_init(init_params) - + with FileLock(".ray.lock"): train_images, train_labels, test_images, test_labels = mnist() - - train_images = train_images.reshape(train_images.shape[0], 1, 28, 28).transpose(2, 3, 1, 0) - test_images = test_images.reshape(test_images.shape[0], 1, 28, 28).transpose(2, 3, 1, 0) - train_loader = Dataloader(train_images, train_labels, batch_size=batch_size, shuffle=True) - test_loader = Dataloader(test_images, test_labels, batch_size=batch_size) - - self.register(model=[opt_state, init_fun, predict_fun], optimizer=[opt_init, opt_update, get_params], criterion=lambda logits, targets:-jnp.sum(logits * targets)) - - self.register_data(train_loader=train_loader, validation_loader=test_loader) + train_images = train_images.reshape(train_images.shape[0], 1, 28, + 28).transpose(2, 3, 1, 0) + test_images = 
test_images.reshape(test_images.shape[0], 1, 28, + 28).transpose(2, 3, 1, 0) + + train_loader = Dataloader( + train_images, train_labels, batch_size=batch_size, shuffle=True) + test_loader = Dataloader( + test_images, test_labels, batch_size=batch_size) + + self.register( + model=[opt_state, init_fun, predict_fun], + optimizer=[opt_init, opt_update, get_params], + criterion=lambda logits, targets: -jnp.sum(logits * targets)) + + self.register_data( + train_loader=train_loader, validation_loader=test_loader) if __name__ == "__main__": @@ -85,28 +86,37 @@ def setup(self, config): default=2, help="Sets number of workers for training.") parser.add_argument( - "--num-epochs", type=int, default=20, help="Number of epochs to train.") + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.") parser.add_argument( "--fp16", action="store_true", default=False, help="Enables FP16 training with apex. Requires `use-gpu`.") parser.add_argument( - "--model-name", type=str, default="resnet18", help="model, Optional: resnet18, resnet50, resnet101.") + "--model-name", + type=str, + default="resnet18", + help="model, Optional: resnet18, resnet50, resnet101.") args, _ = parser.parse_known_args() if args.address: ray.init(args.address) else: - ray.init(num_gpus=args.num_workers, num_cpus=args.num_workers * 2, log_to_driver=True) + ray.init( + num_gpus=args.num_workers, + num_cpus=args.num_workers * 2, + log_to_driver=True) strategy = AllReduceStrategy( training_operator_cls=MnistTrainingOperator, world_size=args.num_workers, operator_config={ "lr": 0.01, - "batch_size": 128 , + "batch_size": 128, "num_workers": args.num_workers, "num_classes": 10, "model_name": args.model_name diff --git a/format.sh b/format.sh index bad8cb9..d5abce5 100755 --- a/format.sh +++ b/format.sh @@ -106,14 +106,14 @@ format_changed() { yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" if which flake8 >/dev/null; then git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ - flake8 '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 + flake8 '"' --ignore=N,I,C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 fi fi if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then if which flake8 >/dev/null; then git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \ - flake8 '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 + flake8 '"' --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 fi fi } @@ -121,7 +121,7 @@ format_changed() { # Format all files, and print the diff to stdout for travis. format_all() { yapf --diff "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" distml - flake8 '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 distml + flake8 '"' --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 distml } # This flag formats individual files. 
--files *must* be the first command line From 225abfd2208cc7a2076fd4ed4f6ac8efb6ca59c6 Mon Sep 17 00:00:00 2001 From: Ezra-H Date: Thu, 29 Apr 2021 02:10:14 +0800 Subject: [PATCH 04/12] setup string --- distml/operator/jax_operator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distml/operator/jax_operator.py b/distml/operator/jax_operator.py index b560617..a96e40f 100644 --- a/distml/operator/jax_operator.py +++ b/distml/operator/jax_operator.py @@ -53,11 +53,11 @@ def setup(self, *args, **kwargs): opt_init, opt_update, get_params = optimizers.adam(lr) opt_state = opt_init(init_params) - self.register(model=(opt_state, get_params, predict_fun), - optimizer=opt_update, - criterion=lambda logits, \ - targets:-jnp.sum(logits * targets)) + criterion = lambda logits, targets:-jnp.sum(logits * targets) + self.register(model=(opt_state, init_fun, predict_fun), + optimizer=(opt_init, opt_update, get_params), + criterion=criterion) """ pass From 560dd7594724f56daf167034aea628a87f6d3b99 Mon Sep 17 00:00:00 2001 From: Ezra-H Date: Fri, 30 Apr 2021 00:21:26 +0800 Subject: [PATCH 05/12] reset format.sh --- format.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/format.sh b/format.sh index d5abce5..c03b7b9 100755 --- a/format.sh +++ b/format.sh @@ -46,7 +46,7 @@ builtin cd "$ROOT" || exit 1 # Add the upstream remote if it doesn't exist if ! git remote -v | grep -q upstream; then - git remote add 'upstream' 'https://yuan.cm/https://github.com/ray-project/distml.git' + git remote add 'upstream' 'https://github.com/ray-project/distml.git' fi FLAKE8_VERSION=$(flake8 --version | awk '{print $1}') @@ -106,14 +106,14 @@ format_changed() { yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" if which flake8 >/dev/null; then git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ - flake8 '"' --ignore=N,I,C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 + flake8 --inline-quotes '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 fi fi if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then if which flake8 >/dev/null; then git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \ - flake8 '"' --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 + flake8 --inline-quotes '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 fi fi } @@ -121,7 +121,7 @@ format_changed() { # Format all files, and print the diff to stdout for travis. format_all() { yapf --diff "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" distml - flake8 '"' --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 distml + flake8 --inline-quotes '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 distml } # This flag formats individual files. 
--files *must* be the first command line From b8d2ec91b099adec168394183dc6bf221088e33b Mon Sep 17 00:00:00 2001 From: Ezra-H Date: Fri, 30 Apr 2021 02:01:59 +0800 Subject: [PATCH 06/12] init ps strategy --- distml/strategy/ps_strategy.py | 702 +++++++++++++++++++++++++++++++++ examples/jax/default_train.csv | 2 - 2 files changed, 702 insertions(+), 2 deletions(-) create mode 100644 distml/strategy/ps_strategy.py delete mode 100644 examples/jax/default_train.csv diff --git a/distml/strategy/ps_strategy.py b/distml/strategy/ps_strategy.py new file mode 100644 index 0000000..382daeb --- /dev/null +++ b/distml/strategy/ps_strategy.py @@ -0,0 +1,702 @@ +import ray +import ray.util.collective as col + +import numpy as np + +import distml.strategy.util as util +from distml.strategy.base_trainer import BaseTrainer +from .util import ThroughputCollection + +import logging + +logger = logging.getLogger(__name__) + + +class ParameterServerStrategy(BaseTrainer): + """Strategy that trains a model via collective AllReduce. + + Args: + training_operator_cls (TrainingOperator): + Custom training operator class. + operator_config (dict): operator config specified by users. + initialization_hook (function): A function to call on all training + workers when they are first initialized. This could be useful to + set environment variables for all the worker processes. + num_workers (int): The number of workers. + num_ps (int): The number of parameter servers. + num_cpus_per_worker (int): number of CPUs allocated per worker. + num_gpus_per_worker (int): number of GPUs allocated per worker. + num_cpus_per_server (int): number of CPUs allocated per server. + num_gpus_per_server (int): number of GPUs allocated per server. + """ + + def __init__(self, + *, + training_operator_cls, + operator_config=None, + initialization_hook=None, + num_workers=1, + num_ps=1, + num_cpus_per_worker=1, + num_gpus_per_worker=1, + num_cpus_per_server=1, + num_gpus_per_server=1, + **kwargs): + self.assignments = None + + assert num_ps + self.num_ps = num_ps + self.num_workers = num_workers + self.num_cpus_per_server = num_cpus_per_server + self.num_gpus_per_server = num_gpus_per_server + + super(ParameterServerStrategy, self).\ + __init__(training_operator_cls=training_operator_cls, + operator_config=operator_config, + initialization_hook=initialization_hook, + num_cpus_per_worker=num_cpus_per_worker, + num_gpus_per_worker=num_gpus_per_worker, + **kwargs) + + # PS strategy needs some other prep-up. 
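As a rough usage sketch of the constructor above (the operator class MyJaxOperator and the config values are placeholders, not part of this patch):

    from distml.strategy.ps_strategy import ParameterServerStrategy

    strategy = ParameterServerStrategy(
        training_operator_cls=MyJaxOperator,   # any user-defined training operator
        operator_config={"lr": 0.01, "batch_size": 128},
        num_workers=2,                         # training actors
        num_ps=1,                              # parameter-server actors
        num_gpus_per_worker=1,
        num_gpus_per_server=1)
    strategy.train()      # one pass over the training loader
    strategy.validate()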
+ self._init_strategy() + + if operator_config and operator_config.get("batch_size"): + self._global_batch_size = operator_config.get("batch_size") + if self._global_batch_size: + self._collector = ThroughputCollection( + batch_size=self._global_batch_size) + else: + self._collector = ThroughputCollection() + + def _init_strategy(self): + """Do initialization for the distributed strategy.""" + # All sync with worker 0 + init_weights_id = self.worker_group.get_named_parameters(cpu=True) + + self._round_robin_sharding() + + # set assignments to every worker + self.worker_group.set_assignments(self.assignments) + + # all workers get synced + for i, worker in enumerate(self.worker_group.actors): + if i != 0: + ray.get([worker.set_parameters.remote(init_weights_id)]) + + # now spawn parameter server actors + shard_ids = self.worker_group.split_parameters(self.assignments) + + # TODO(HUI): use scatter to send parameters + for server_idx, server in enumerate(self.server_group.actors): + this_shard_ref = self.worker_group.actors[0].index_shard.remote( + shard_ids, server_idx) + ray.get([server.set_params.remote(this_shard_ref)]) + + def _start_workers(self): + """Create worker(actor), maybe need worker group to manager these workers. + Or, send these workers to strategy to manager? + + set workers or worker group + set worker info, record rank, backend, use_num_gpus? + """ + # TODO (Hao): infer the per-replica batch size here... + + # so here we get two set of params that will be passed around: + # (1) Those for setting up training logic in training_operator, + # including: batchsize, use_tqdm, user defined operator_config. + operator_config = self._operator_config.copy() + params = dict( + training_operator_cls=self.training_operator_cls, + operator_config=operator_config) + # (2) params for setting up collective group + # and the strategy-related things; + + # For now, we do not have many of them though. + dist_params_worker = dict( + strategy="ps", + is_server=False, + group_name="default", + num_ps=self.num_ps, + num_workers=self.num_workers, + ) + + dist_params_server = dict( + strategy="ps", + is_server=True, + group_name="default", + num_ps=self.num_ps, + num_workers=self.num_workers, + ) + + # (3) other arguments that used to init the DataParallelGrup + workergroup_init_args = { + "params": params, + "dist_params": dist_params_worker, + "num_cpus_per_actor": self.num_cpus_per_worker, + "num_gpus_per_actor": self.num_gpus_per_worker, + } + + servergroup_init_args = { + "params": params, + "dist_params": dist_params_server, + "num_cpus_per_actor": self.num_cpus_per_server, + "num_gpus_per_actor": self.num_gpus_per_server, + } + + # Should we make two groups for worker and server? + self.worker_group = DataParallelGroup(**workergroup_init_args) + self.server_group = DataParallelGroup(**servergroup_init_args) + + # Once the group is created, we start it. + self.worker_group.start_actors(self.num_workers) + self.server_group.start_actors( + self.num_ps) # server at the last num_ps processes. + + worker_rets = self.worker_group.test_connection() + server_rets = self.server_group.test_connection() + ray.get(worker_rets + server_rets) + ray.get(self.worker_group.setup_operator()) + ray.get(self.server_group.setup_operator()) + + self.server_group.clean_redundancy() + + def shutdown(self, force=False): + self.worker_group.shutdown(force=force) + self.server_group.shutdown(force=force) + + def save_parameters(self, checkpoint): + # TODO(HUI): ps save parameters. 
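Each training step pairs every worker with every server. A condensed sketch of the control flow that train_batch, Worker.compute and PS.update implement further down; here workers and servers stand for the handles in worker_group.actors and server_group.actors:

    import ray

    loss_refs, update_refs = [], []
    for worker_rank, worker in enumerate(workers):
        for server in servers:
            server.send_params.remote(worker_rank)   # each server pushes its parameter shard
        # the worker assembles the shards, computes the loss and
        # sends its gradient shards back to the servers
        loss_refs.append(worker.compute.remote())
    for worker_rank, _ in enumerate(workers):
        for server in servers:
            # a server applies its buffered gradients once every worker has reported
            update_refs.append(server.update.remote(worker_rank))
    losses = ray.get(loss_refs)
    ray.get(update_refs)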
+ # First, worker rank 0 should pull the latest parameter from servers + # Then, worker rank 0 save parameters + self.worker_group.save_parameters(checkpoint) + + def load_parameters(self, checkpoint): + # TODO(HUI): ps load parameters. + # shard parameters and send to all servers. + self.server_group.load_parameters(checkpoint) + + def _round_robin_sharding(self): + """Generate the assignment of variable to servers.""" + parameter_distribution = ray.get( + self.worker_group.actors[0].params_distribution.remote()) + assignments = [0 for _ in parameter_distribution] + loads = [0 for _ in range(self.num_ps)] + for i, var_size in enumerate(parameter_distribution): + min_ps_index = loads.index(min(loads)) + loads[min_ps_index] += var_size + assignments[i] = min_ps_index + print("Load of each ps {}".format(loads)) + self.assignments = assignments + + def train(self, num_steps=None): + # TODO (Hao): add fault tolerance using `max_retries`. + steps = num_steps if num_steps \ + else self.worker_group.get_data_loader_len() + + # TODO(HUI): Record server rank instead of using num_ps. + # TODO(Hao): this call should be hidden inside Replica. + # train one epoch + self.worker_group.make_iterator() + for idx in range(steps): + with self._collector.record("train"): + metrics = self.train_batch() + logger.info("Step: {}/{}".format(idx, steps)) + return metrics + + def validate(self, num_steps=None): + steps = num_steps if num_steps \ + else self.worker_group.get_data_loader_len(training=False) + self.worker_group.make_iterator(training=False) + + # TODO(HUI): Construct a better tool to save validate results. + for idx in range(steps): + batch_metrics = self.worker_group.validate_batch() + # Validate results should be the same in all workers + return batch_metrics + + def train_batch(self): + loss_vals = [] + rets = [] + metrics = {} + + for worker_idx, worker in enumerate(self.worker_group.actors): + for server_idx, server in enumerate(self.server_group.actors): + # every server sends its shard to the worker + server.send_params.remote(worker_idx) + # the worker receives shards from ps, compute loss, gradients + # and sends these gradients to every server + loss_val = worker.compute.remote() + loss_vals.append(loss_val) + + for worker_idx, worker in enumerate(self.worker_group.actors): + for server in self.server_group.actors: + rets.append(server.update.remote(worker_idx)) + + loss_vals = ray.get(loss_vals) + ray.get(rets) + train_loss_list = [d["train_loss"] for d in loss_vals] + metrics["train_loss"] = np.mean(train_loss_list) + return metrics + + +class PS(object): + def __init__(self, training_operator_cls, operator_config): + self.training_operator_cls = training_operator_cls + self.operator_config = operator_config + + self.grad_counts = None + self.params = dict() + + def setup_operator(self): + # figure out the signature of training_operator_cls later. + self.training_operator = self.training_operator_cls( + self.operator_config) + + def setup_collective_group(self, + rank, + num_ps, + num_workers, + backend="nccl", + group_name="default"): + # rank should be true rank. means, rank has already plus num_worker. + self.rank = rank + self.num_ps = num_ps + self.num_workers = num_workers + self.group_name = group_name + self.group_size = num_ps + num_workers + self._init_grad_counts() + # the last num_ps processes are servers. 
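Despite its name, _round_robin_sharding above performs a greedy least-loaded assignment of parameters to servers. A small worked example with made-up parameter sizes:

    parameter_distribution = [100, 50, 80, 30]   # numel of each named parameter
    num_ps = 2
    assignments = [0] * len(parameter_distribution)
    loads = [0] * num_ps
    for i, var_size in enumerate(parameter_distribution):
        min_ps_index = loads.index(min(loads))
        loads[min_ps_index] += var_size
        assignments[i] = min_ps_index
    # loads == [130, 130], assignments == [0, 1, 1, 0]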
+ col.init_collective_group( + num_ps + num_workers, rank, backend=backend, group_name=group_name) + + def test_connection(self): + for i in range(self.num_workers): + recv = util.zeros((1, ), cpu=False) + col.recv(recv, i, self.group_name) + assert recv == 1 + for i in range(self.num_workers): + send = util.ones((1, ), cpu=False) + col.send(send, i, self.group_name) + return + + def _init_grad_counts(self): + self.grad_counts = [0] * self.num_workers + + def _init_grad_buffer(self): + self.grad_buffer = { + k: self.training_operator.zeros_like(v, cpu=False) + for k, v in self.params.items() + } + + def get_params(self): + return self.params + + def set_params(self, params): + # params should in GPU when calling this function. + for k, v in params.items(): + self.params[k] = self.training_operator.asarray(v) + + # param is a dict, if needed list, should convert in operator. + self.training_operator.reset_optimizer_for_params(self.params) + self._init_grad_buffer() + + def apply_updates(self, grad_buffer): + # TODO(HUI): gradient divide by num_workers + self.training_operator.apply_updates(grad_buffer) + self.params = self.training_operator.get_named_parameters(cpu=False) + + def _inc_gradients(self, gradients): + for name, p in self.get_params().items(): + if gradients[name] is not None: + self.grad_buffer[name] += gradients[name] + + def send_params(self, dst_rank): + """ Send this param shard to the destination worker """ + for name, v in self.params.items(): + cv = self.training_operator.to_cupy(v) + col.send(cv, dst_rank, self.group_name) + + def update(self, src_rank): + """Receive gradients and update""" + keys = list(self.params.keys()) + grads = dict() + recv_list = [] + + for key in keys: + to_recv = self.params[key] + recv_list.append( + self.training_operator.zeros(to_recv.shape, cpu=False)) + + for i in range(len(keys)): + v = self.training_operator.to_cupy(recv_list[i]) + col.recv(v, src_rank, self.group_name) + + for i in range(len(keys)): + grads[keys[i]] = recv_list[i] + + self._inc_gradients(grads) + if not self.grad_counts[src_rank]: + self.grad_counts[src_rank] = 1 + else: + raise RuntimeError(f"This worker {src_rank} send gradients again.") + if sum(self.grad_counts) == self.num_workers: + self.apply_updates(self.grad_buffer) + + self._init_grad_buffer() + self._init_grad_counts() + return True + + def clean_redundancy(self): + self.training_operator.clean_redundancy() + + def shutdown(self): + # destroy the collective group resources on this process + col.destroy_collective_group(self.group_name) + if self.training_operator: + del self.training_operator + return 1 + + +class Worker(object): + def __init__(self, training_operator_cls, operator_config): + self.training_operator_cls = training_operator_cls + self.operator_config = operator_config + + # collective-related information + self.group_size = None + self.rank = None + self.group_name = None + self.assignments = None + + def setup_operator(self): + # figure out the signature of training_operator_cls later. + self.training_operator = self.training_operator_cls( + self.operator_config) + + def setup_collective_group(self, + rank, + num_ps, + num_workers, + backend="nccl", + group_name="default"): + self.rank = rank + self.num_ps = num_ps + self.num_workers = num_workers + self.group_name = group_name + self.group_size = num_ps + num_workers + self.name_list = [[] for i in range(num_ps)] + + # the last num_ps processes are servers. 
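The comment above pins down the rank layout of the collective group: workers take the first num_workers ranks and servers take the last num_ps ranks, which is why workers address server i as rank num_workers + i. With illustrative sizes:

    num_workers, num_ps = 2, 1
    worker_ranks = list(range(num_workers))                  # [0, 1]
    server_ranks = [num_workers + i for i in range(num_ps)]  # [2]
    group_size = num_workers + num_ps                        # 3, the size passed to init_collective_group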
+ col.init_collective_group( + num_ps + num_workers, rank, backend=backend, group_name=group_name) + + def test_connection(self): + for i in range(self.num_ps): + send = util.ones((1, ), cpu=False) + col.send(send, self.num_workers + i, self.group_name) + for i in range(self.num_ps): + recv = util.zeros((1, ), cpu=False) + col.recv(recv, self.num_workers + i, self.group_name) + assert recv == 1 + return + + def params_distribution(self): + distribution = [] + weights = self.get_named_parameters(cpu=True) + for k, v in weights.items(): + distribution.append(self.training_operator.numel(v)) + return distribution + + def make_iterator(self, training=True): + """Convert loader to be an iterator at the start of an epoch.""" + # TODO(Hao): need to check whether reaching the boundary of iterator + # instead of making a new one every time. + if training: + self.training_iterator = iter( + self.training_operator._get_train_loader()) + else: + self.validation_iterator = iter( + self.training_operator._get_validation_loader()) + + def get_data_loader_len(self, training=True): + """Return the number of batches in the data loader.""" + loader = self.training_operator._get_train_loader() if training \ + else self.training_operator._get_validation_loader() + if hasattr(loader, "__len__"): + return len(loader) + else: + raise RuntimeError( + "Data loader has no attribute `__len__`. " + "Please set `num_steps` in `train()` or `validate()`.") + + def derive_updates(self, batch): + # TODO (Hao): handling data loader next. + # TODO (Hao): change it to derive_update and apply_update. + return self.training_operator.derive_updates(batch) + + def compute_gradients(self, params): + """ + Update worker parameters that received from server. + Compute gradients and return named gradients. + """ + self.set_parameters(params) + + try: + batch = next(self.training_iterator) + except StopIteration and NameError: + self.make_iterator() + batch = next(self.training_iterator) + + # different from original core ps. + # Here derive_updates return loss_val and graident in order. + loss_val, grads = self.training_operator.derive_updates(batch) + assert isinstance(grads, dict) + + return loss_val, grads + + def split_gradients(self, grad, assignments): + # assuming messages are gradients or parameters + # this grad is ready to be called by apply_gradients in ParameterServer + num_shards = np.unique(np.array(assignments)).size + shards = [dict() for i in range(num_shards)] + for i, (k, v) in enumerate(grad.items()): + shards[assignments[i]][k] = v + return shards + + def split_parameters(self, assignments): + params = self.get_named_parameters(cpu=False) + num_shards = np.unique(np.array(assignments)).size + shards = [dict() for i in range(num_shards)] + for i, (k, v) in enumerate(params.items()): + shards[assignments[i]][k] = v + return shards + + def index_shard(self, shards, index): + return shards[index] + + def set_parameters(self, params): + return self.training_operator.set_parameters(params) + + def get_parameters(self, cpu): + return self.training_operator.get_parameters(cpu) + + def get_named_parameters(self, cpu): + return self.training_operator.get_named_parameters(cpu) + + def get_gradients(self): + # training_operator call gradients or we save gradient in replica + # when derive_updates. 
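split_gradients above shards a flat gradient dict by the assignments vector, producing one dict per server. A minimal illustration with made-up names and values:

    grad = {"w1": 1.0, "w2": 2.0, "b1": 3.0, "b2": 4.0}
    assignments = [0, 1, 1, 0]        # one entry per parameter, computed by the strategy
    num_shards = len(set(assignments))
    shards = [dict() for _ in range(num_shards)]
    for i, (k, v) in enumerate(grad.items()):
        shards[assignments[i]][k] = v
    # shards == [{"w1": 1.0, "b2": 4.0}, {"w2": 2.0, "b1": 3.0}]
    # shard 0 is sent to the server at rank num_workers + 0, shard 1 to num_workers + 1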
+ return self.training_operator.get_gradients() + + def set_assignments(self, assignments): + self.assignments = assignments + keys = list(self.get_named_parameters(cpu=False).keys()) + for i, a in enumerate(self.assignments): + self.name_list[a].append(keys[i]) + + def compute(self): + """Returns the loss, and send gradients to servers""" + metrics = {} + + weights = self.get_named_parameters(cpu=False) + params = dict() + + # 1. Create the receive lists to group collective calls + recv_list = [] + for i in range(self.num_ps): + recv_list.append([]) + param_shard_keys = self.name_list[i] + for key in param_shard_keys: + to_recv = weights[key] + recv_list[-1].append( + self.training_operator.ones(to_recv.shape, cpu=False)) + + # 2. Receive params from servers + for i in range(self.num_ps): + for j in range(len(self.name_list[i])): + v = self.training_operator.to_cupy(recv_list[i][j]) + col.recv(v, self.num_workers + i, self.group_name) + + # 3. Set params in workers and compute gradients. + for i in range(self.num_ps): + param_shard_keys = self.name_list[i] + for j in range(len(param_shard_keys)): + params[param_shard_keys[j]] = recv_list[i][j] + + loss_val, grad = self.compute_gradients(params) + metrics["train_loss"] = loss_val + + # 4. Shard gradients and send to servers. + split_grad = self.split_gradients(grad, self.assignments) + for i in range(self.num_ps): + this_shard = self.index_shard(split_grad, i) + for _, v in this_shard.items(): + cv = self.training_operator.to_cupy(v) + col.send(cv, self.num_workers + i, self.group_name) + return metrics + + def validate_batch(self): + try: + batch = next(self.validation_iterator) + except StopIteration and TypeError: + self.make_iterator(training=False) + batch = next(self.validation_iterator) + batch_metric = self.training_operator.validate_batch(batch) + return batch_metric + + def shutdown(self): + # destroy the collective group resources on this process + col.destroy_collective_group(self.group_name) + if self.training_operator: + del self.training_operator + return 1 + + +class DataParallelGroup: + """Spawn a group a replicas for data-parallel training.""" + + def __init__(self, params, dist_params, num_cpus_per_actor, + num_gpus_per_actor): + self._params = params + self._dist_params = dist_params + self._num_cpus_per_actor = num_cpus_per_actor + self._num_gpus_per_actor = num_gpus_per_actor + + self.is_server = self._dist_params["is_server"] + self.num_ps = self._dist_params["num_ps"] + self.num_workers = self._dist_params["num_workers"] + + self._distributed_actors = None + + def _setup_collective_group(self, num_replicas): + if self._dist_params["strategy"] == "ps": + num_ps = self._dist_params["num_ps"] + num_workers = self._dist_params["num_workers"] + is_server = self.is_server + rets = [ + actor.setup_collective_group.remote( + rank=i + is_server * num_workers, + num_workers=num_workers, + num_ps=num_ps, + backend="nccl") + for i, actor in enumerate(self._distributed_actors) + ] + else: # this can be extend for allreduce. 
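Worker.set_assignments above inverts the same assignments into name_list, the per-server list of parameter names that compute() later uses to know which tensors to receive from each server. Continuing the toy example:

    keys = ["w1", "w2", "b1", "b2"]
    assignments = [0, 1, 1, 0]
    name_list = [[] for _ in range(2)]   # one list per server
    for i, a in enumerate(assignments):
        name_list[a].append(keys[i])
    # name_list == [["w1", "b2"], ["w2", "b1"]]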
+ raise RuntimeError("Unrecognized strategy.") + return rets + + def setup_operator(self): + setups = [ + actor.setup_operator.remote() + for i, actor in enumerate(self._distributed_actors) + ] + return setups + + def start_actors(self, num_actors): + if self.is_server: + RemoteActor = ray.remote( + num_cpus=self._num_cpus_per_actor, + num_gpus=self._num_gpus_per_actor)(PS) + else: + RemoteActor = ray.remote( + num_cpus=self._num_cpus_per_actor, + num_gpus=self._num_gpus_per_actor)(Worker) + + self._distributed_actors = [ + RemoteActor.remote(**self._params) for _ in range(num_actors) + ] + + # setup the rank and group in each replica + ray.get(self._setup_collective_group(len(self._distributed_actors))) + + def test_connection(self): + rets = [ + actor.test_connection.remote() + for _, actor in enumerate(self.actors) + ] + return rets + + def set_assignments(self, assignments): + rets = [ + actor.set_assignments.remote(assignments) + for _, actor in enumerate(self.actors) + ] + return rets + + def _make_iterator(self, training): + return [actor.make_iterator.remote(training) for actor in self.actors] + + def make_iterator(self, training=True): + ray.get(self._make_iterator(training)) + + def get_data_loader_len(self, training=True): + """Return the number of batches in the data loader.""" + lens = ray.get([ + actor.get_data_loader_len.remote(training=training) + for actor in self.actors + ]) + + if len(set(lens)) != 1: + # TODO(Hao): is this correct after we add distributed data loader? + raise RuntimeError( + "All actors should have the same dataloader len.") + return lens[0] + + def validate_batch(self): + rets = [ + actor.validate_batch.remote() + for _, actor in enumerate(self.actors) + ] + stats = ray.get(rets) + return stats + + def shutdown(self, force=False): + rets = [actor.shutdown.remote() for _, actor in enumerate(self.actors)] + stats = ray.get(rets) + return stats + + def reset(self): + pass + + @property + def actors(self): + return self._distributed_actors + + def save_parameters(self, checkpoint): + rets = [self.actors[0].save_parameters.remote(checkpoint)] + ray.get(rets) + + def load_parameters(self, checkpoint): + rets = [ + actor.load_parameters.remote(checkpoint) + for _, actor in enumerate(self.actors) + ] + ray.get(rets) + + def set_parameters(self, params): + rets = [ + actor.set_parameters.remote(params) + for _, actor in enumerate(self.actors) + ] + ray.get(rets) + + def get_parameters(self, cpu=False): + ret = self.actors[0].get_parameters.remote(cpu) + return ray.get([ret])[0] + + def get_named_parameters(self, cpu=False): + ret = self.actors[0].get_named_parameters.remote(cpu) + return ray.get([ret])[0] + + def split_parameters(self, assignments): + ret = self.actors[0].split_parameters.remote(assignments) + return ray.get([ret])[0] + + def clean_redundancy(self): + """Clean dataloader. 
Only for servers""" + rets = [ + actor.clean_redundancy.remote() + for _, actor in enumerate(self.actors) + ] + ray.get(rets) diff --git a/examples/jax/default_train.csv b/examples/jax/default_train.csv deleted file mode 100644 index 0d23b47..0000000 --- a/examples/jax/default_train.csv +++ /dev/null @@ -1,2 +0,0 @@ -count_train,mean_train_s,last_train_s,total_train_s,pass_data_train,throughout_train_d -50,2.456741285324097,2.3998360633850098,164.64879870414734,6400,38.87061460739823 From a68dbf3ffcadcee6e21ac002de83b4e6dd5f57fe Mon Sep 17 00:00:00 2001 From: Ezra-H Date: Fri, 30 Apr 2021 02:03:15 +0800 Subject: [PATCH 07/12] delete some trash files --- examples/jax/.ray.lock | 0 examples/jax/default_train.csv | 2 -- 2 files changed, 2 deletions(-) delete mode 100755 examples/jax/.ray.lock delete mode 100644 examples/jax/default_train.csv diff --git a/examples/jax/.ray.lock b/examples/jax/.ray.lock deleted file mode 100755 index e69de29..0000000 diff --git a/examples/jax/default_train.csv b/examples/jax/default_train.csv deleted file mode 100644 index 0d23b47..0000000 --- a/examples/jax/default_train.csv +++ /dev/null @@ -1,2 +0,0 @@ -count_train,mean_train_s,last_train_s,total_train_s,pass_data_train,throughout_train_d -50,2.456741285324097,2.3998360633850098,164.64879870414734,6400,38.87061460739823 From e3792ab33b03a5b3a62cf1322db89c09df4cf075 Mon Sep 17 00:00:00 2001 From: Ezra-H Date: Fri, 30 Apr 2021 02:27:12 +0800 Subject: [PATCH 08/12] jax ps example --- distml/strategy/ps_strategy.py | 10 +++--- examples/jax/mnist_jax_example.py | 54 +++++++++++++++++++++++++------ 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/distml/strategy/ps_strategy.py b/distml/strategy/ps_strategy.py index 382daeb..99fef65 100644 --- a/distml/strategy/ps_strategy.py +++ b/distml/strategy/ps_strategy.py @@ -3,16 +3,16 @@ import numpy as np -import distml.strategy.util as util -from distml.strategy.base_trainer import BaseTrainer -from .util import ThroughputCollection +import distml.util as util +from distml.strategy.base_strategy import BaseStrategy +from distml.util import ThroughputCollection import logging logger = logging.getLogger(__name__) -class ParameterServerStrategy(BaseTrainer): +class ParameterServerStrategy(BaseStrategy): """Strategy that trains a model via collective AllReduce. 
Args: @@ -240,7 +240,7 @@ def train_batch(self): return metrics -class PS(object): +class PS(object): def __init__(self, training_operator_cls, operator_config): self.training_operator_cls = training_operator_cls self.operator_config = operator_config diff --git a/examples/jax/mnist_jax_example.py b/examples/jax/mnist_jax_example.py index 2335b5d..cc265ae 100644 --- a/examples/jax/mnist_jax_example.py +++ b/examples/jax/mnist_jax_example.py @@ -6,6 +6,7 @@ import ray from distml.operator.jax_operator import JAXTrainingOperator from distml.strategy.allreduce_strategy import AllReduceStrategy +from distml.strategy.ps_strategy import ParameterServerStrategy from ray.util.sgd.utils import override @@ -72,6 +73,35 @@ def setup(self, config): train_loader=train_loader, validation_loader=test_loader) +def make_ar_strategy(args): + strategy = AllReduceStrategy( + training_operator_cls=MnistTrainingOperator, + world_size=args.num_workers, + operator_config={ + "lr": 0.01, + "batch_size": 128, + "num_workers": args.num_workers, + "num_classes": 10, + "model_name": args.model_name + }) + return strategy + + +def make_ps_strategy(args): + strategy = ParameterServerStrategy( + training_operator_cls=MnistTrainingOperator, + world_size=args.num_workers, + num_workers=args.num_workers - args.num_ps, + num_ps=args.num_ps, + operator_config={ + "lr": 0.01, + "batch_size": 128, + "num_classes": 10, + "model_name": args.model_name + }) + return strategy + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -85,6 +115,11 @@ def setup(self, config): type=int, default=2, help="Sets number of workers for training.") + parser.add_argument( + "--num-ps", + type=int, + default=1, + help="Sets number of servers for training. Only for ps_strategy.") parser.add_argument( "--num-epochs", type=int, @@ -100,6 +135,8 @@ def setup(self, config): type=str, default="resnet18", help="model, Optional: resnet18, resnet50, resnet101.") + parser.add_argument( + "--strategy", type=str, default="ar", help="model, Optional: ar, ps.") args, _ = parser.parse_known_args() @@ -111,16 +148,13 @@ def setup(self, config): num_cpus=args.num_workers * 2, log_to_driver=True) - strategy = AllReduceStrategy( - training_operator_cls=MnistTrainingOperator, - world_size=args.num_workers, - operator_config={ - "lr": 0.01, - "batch_size": 128, - "num_workers": args.num_workers, - "num_classes": 10, - "model_name": args.model_name - }) + if args.strategy == "ar": + strategy = make_ar_strategy(args) + elif args.strategy == "ps": + strategy = make_ps_strategy(args) + else: + raise RuntimeError("Unrecognized trainer type. 
Except 'ar' or 'ps'" + "Got {}".format(args.strategy)) for i in range(args.num_epochs): strategy.train() From 21340f4310592c20eee0457330c68f6418d19afc Mon Sep 17 00:00:00 2001 From: Ezra-H Date: Sat, 15 May 2021 01:50:55 +0800 Subject: [PATCH 09/12] base_data_parallel_group and add typing in function params --- distml/strategy/allreduce_strategy.py | 139 +++++++------- distml/strategy/base_strategy.py | 126 ++++++++++++- distml/strategy/ps_strategy.py | 258 ++++++++++++++------------ examples/jax/mnist_jax_example.py | 14 +- 4 files changed, 333 insertions(+), 204 deletions(-) diff --git a/distml/strategy/allreduce_strategy.py b/distml/strategy/allreduce_strategy.py index 5eb92b8..3016625 100644 --- a/distml/strategy/allreduce_strategy.py +++ b/distml/strategy/allreduce_strategy.py @@ -1,8 +1,9 @@ import logging +from typing import Callable, Mapping, Any, Optional import ray import ray.util.collective as col -from distml.strategy.base_strategy import BaseStrategy +from distml.strategy.base_strategy import BaseStrategy, BaseDataParallelGroup from distml.util import ThroughputCollection import numpy as np @@ -29,17 +30,21 @@ class AllReduceStrategy(BaseStrategy): def __init__(self, *, training_operator_cls, - operator_config=None, - initialization_hook=None, - world_size=2, - num_cpus_per_worker=1, - num_gpus_per_worker=1, + operator_config: Optional[Mapping[str, Any]] = None, + initialization_hook: Optional[Callable] = None, + world_size: int = 2, + backend: str = "nccl", + group_name: str = "default", + num_cpus_per_worker: int = 1, + num_gpus_per_worker: int = 1, **kwargs): super(AllReduceStrategy, self). \ __init__(training_operator_cls=training_operator_cls, operator_config=operator_config, initialization_hook=initialization_hook, world_size=world_size, + backend=backend, + group_name=group_name, num_cpus_per_worker=num_cpus_per_worker, num_gpus_per_worker=num_gpus_per_worker, **kwargs) @@ -52,7 +57,7 @@ def __init__(self, else: self._collector = ThroughputCollection() - def train(self, num_steps=None): + def train(self, num_steps: Optional[int] = None): """Run the training on parallel workers. Args: @@ -74,7 +79,7 @@ def train(self, num_steps=None): print("Step: {}/{}".format(idx, steps)) return metrics - def validate(self, num_steps=None): + def validate(self, num_steps: Optional[int] = None): """Evaluates the model on the validation data. Args: @@ -111,26 +116,26 @@ def _start_workers(self): # (2) params for setting up collective group and strategy prep-ups. dist_params = dict( strategy="allreduce", - backend="nccl", - group_name="default", + backend=self.backend, + group_name=self.group_name, ) group_init_args = dict( - replica_params=replica_params, + actor_params=replica_params, dist_params=dist_params, initialization_hook=self.initialization_hook, - num_cpus_per_worker=self.num_cpus_per_worker, - num_gpus_per_worker=self.num_gpus_per_worker) + num_cpus_per_actor=self.num_cpus_per_worker, + num_gpus_per_actor=self.num_gpus_per_worker) self.data_parallel_group = DataParallelGroup(**group_init_args) # Once the group is created, we start it. 
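The backend and group_name arguments added above are forwarded into dist_params and ultimately to the collective group setup. A construction sketch under the new signature (the operator class and config values are placeholders):

    from distml.strategy.allreduce_strategy import AllReduceStrategy

    strategy = AllReduceStrategy(
        training_operator_cls=MyJaxOperator,   # any user-defined training operator
        world_size=2,
        backend="nccl",
        group_name="default",
        operator_config={"lr": 0.01, "batch_size": 128})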
self.data_parallel_group.start_replicas(self.world_size) - def shutdown(self, force=False): + def shutdown(self, force: bool = False): self.data_parallel_group.shutdown(force=force) - def save_parameters(self, checkpoint): + def save_parameters(self, checkpoint: str): self.data_parallel_group.save_parameters(checkpoint) - def load_parameters(self, checkpoint): + def load_parameters(self, checkpoint: str): self.data_parallel_group.load_parameters(checkpoint) def _init_strategy(self): @@ -144,7 +149,7 @@ class Replica: and Ray collective group setup. """ - def __init__(self, training_operator_cls, operator_config): + def __init__(self, training_operator_cls, operator_config: Optional[Mapping[str, Any]]): self.training_operator_cls = training_operator_cls self.operator_config = operator_config # Training operator @@ -165,17 +170,17 @@ def setup_operator(self): operator_config=self.operator_config) def setup_collective_group(self, - rank, - world_size, - backend, - group_name="default"): + rank: int, + world_size: str, + backend: str, + group_name: str = "default"): self._rank = rank self._group_name = group_name self._world_size = world_size col.init_collective_group( world_size, rank, backend=backend, group_name=group_name) - def make_iterator(self, training=True): + def make_iterator(self, training: bool = True): """Convert loader to be an iterator at the start of an epoch.""" # TODO(Hao): need to check whether reaching the boundary of iterator # instead of making a new one every time. @@ -184,7 +189,7 @@ def make_iterator(self, training=True): else: self.validation_iterator = iter(self.validation_loader) - def get_data_loader_len(self, training=True): + def get_data_loader_len(self, training: bool = True): """Return the number of batches in the data loader.""" loader = self.train_loader if training \ else self.validation_loader @@ -195,7 +200,7 @@ def get_data_loader_len(self, training=True): "Data loader has no attribute `__len__`. " "Please set `num_steps` in `train()` or `validate()`.") - def train_batch(self): + def train_batch(self) -> dict: metrics = {} try: batch = next(self.train_iterator) @@ -209,16 +214,14 @@ def train_batch(self): for _, g in updates.items(): cg = self.training_operator.to_cupy(g) col.allreduce(cg) - # TODO(Hao): this is conflicting with Runhui's code though. 
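The division that follows turns the summed gradients into a mean, assuming col.allreduce uses the default sum reduction. Numerically, with two replicas:

    world_size = 2
    local_grads = [1.0, 3.0]           # the gradient value each replica computed
    summed = sum(local_grads)          # 4.0, what allreduce leaves on every replica
    averaged = summed / world_size     # 2.0, the mean gradient that is applied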
cg = cg / float(self.world_size) self.apply_updates(updates) return metrics - def derive_updates(self, batch): + def derive_updates(self, batch) -> dict: return self.training_operator.derive_updates(batch) def apply_updates(self, updates): - # TODO(Hao): conflicting with Runhui's code on averaging grads self.training_operator.apply_updates(updates) def updates_transform(self, updates): @@ -240,13 +243,13 @@ def shutdown(self): del self.training_operator return 1 - def save_parameters(self, checkpoint): + def save_parameters(self, checkpoint: str): self.training_operator.save_parameters(checkpoint) - def load_parameters(self, checkpoint): + def load_parameters(self, checkpoint: str): self.training_operator.load_parameters(checkpoint) - def apply(self, fn): + def apply(self, fn: Callable): """Apply a function in the replica process.""" return fn() @@ -271,45 +274,43 @@ def group_name(self): return self._group_name -class DataParallelGroup: - """Spawn a group a replicas for data-parallel training.""" - - def __init__(self, replica_params, dist_params, initialization_hook, - num_cpus_per_worker, num_gpus_per_worker): - self._replica_params = replica_params - self._dist_params = dist_params - - # try to unroll the dist_params - self._backend = self._dist_params["backend"] - self._group_name = self._dist_params["group_name"] +class DataParallelGroup(BaseDataParallelGroup): + """Spawn a replica group for data-parallel training.""" - self._initialization_hook = initialization_hook - self._num_cpus_per_worker = num_cpus_per_worker - self._num_gpus_per_worker = num_gpus_per_worker + def __init__(self, + actor_params: Mapping[str, Any], + dist_params: Mapping[str, Any], + num_cpus_per_actor: int, + num_gpus_per_actor: int, + initialization_hook: Optional[Callable]): + super(DataParallelGroup, self).__init__(actor_params=actor_params, + dist_params=dist_params, + num_cpus_per_actor=num_cpus_per_actor, + num_gpus_per_actor=num_gpus_per_actor, + initialization_hook=initialization_hook) self._replicas = None @property - def replicas(self): - return self._replicas - - @property - def world_size(self): - return len(self._replicas) + def _replica_params(self): + return self._actor_params @property - def backend(self): - return self._backend + def replicas(self): + return self._actors @property def group_name(self): return self._group_name - def start_replicas(self, num_replicas): + def start_replicas(self, num_replicas: int): + self._start_actors(num_replicas) + + def _start_actors(self, num_replicas: int): assert num_replicas > 1 RemoteReplica = ray.remote( - num_cpus=self._num_cpus_per_worker, - num_gpus=self._num_gpus_per_worker)(Replica) - self._replicas = [ + num_cpus=self._num_cpus_per_actor, + num_gpus=self._num_gpus_per_actor)(Replica) + self._actors = [ RemoteReplica.remote(**self._replica_params) for _ in range(num_replicas) ] @@ -327,16 +328,16 @@ def start_replicas(self, num_replicas): operator_setups = self._setup_operator() ray.get(operator_setups) - def _make_iterator(self, training): + def _make_iterator(self, training: bool): return [ replica.make_iterator.remote(training=training) for replica in self.replicas ] - def make_iterator(self, training=True): + def make_iterator(self, training: bool = True): ray.get(self._make_iterator(training=training)) - def get_data_loader_len(self, training=True): + def get_data_loader_len(self, training: bool = True): """Return the number of batches in the data loader.""" lens = ray.get([ replica.get_data_loader_len.remote(training=training) @@ -361,7 
+362,7 @@ def validate_batch(self): stats = ray.get(rets) return stats - def shutdown(self, force=False): + def shutdown(self, force: bool = False): rets = [replica.shutdown.remote() for replica in self.replicas] stats = ray.get(rets) return stats @@ -369,11 +370,11 @@ def shutdown(self, force=False): def reset(self): pass - def save_parameters(self, checkpoint): + def save_parameters(self, checkpoint: str): rets = [self.replicas[0].save_parameters.remote(checkpoint)] ray.get(rets) - def load_parameters(self, checkpoint): + def load_parameters(self, checkpoint: str): rets = [ replica.load_parameters.remote(checkpoint) for _, replica in enumerate(self.replicas) @@ -387,15 +388,15 @@ def set_parameters(self, params): ] ray.get(rets) - def get_parameters(self, cpu=False): + def get_parameters(self, cpu: bool = False): ret = self.replicas[0].get_parameters.remote(cpu) return ray.get(ret)[0] - def get_named_parameters(self, cpu=False): + def get_named_parameters(self, cpu: bool = False): ret = self.replicas[0].get_named_parameters.remote(cpu) return ray.get([ret])[0] - def apply_all_replicas(self, fn): + def apply_all_replicas(self, fn: Callable): """Apply fn in all replica processes and wait until completion.""" return ray.get(self._apply_all_replicas(fn)) @@ -404,13 +405,13 @@ def _apply_all_replicas(self, fn): return [replica.apply.remote(fn) for replica in self.replicas] def _setup_collective_group(self, - world_size, - backend, - group_name="default"): + group_size: int, + backend: int, + group_name: str = "default"): refs = [ replica.setup_collective_group.remote( rank=i, - world_size=world_size, + world_size=group_size, backend=backend, group_name=group_name) for i, replica in enumerate(self.replicas) diff --git a/distml/strategy/base_strategy.py b/distml/strategy/base_strategy.py index 69e3b0a..27a92d0 100644 --- a/distml/strategy/base_strategy.py +++ b/distml/strategy/base_strategy.py @@ -1,6 +1,7 @@ from abc import ABCMeta from abc import abstractmethod import logging +from typing import AbstractSet, Callable, Any, Mapping, Optional import ray @@ -11,11 +12,13 @@ class BaseStrategy(metaclass=ABCMeta): def __init__(self, *, training_operator_cls, - operator_config=None, - initialization_hook=None, - world_size=2, - num_cpus_per_worker=1, - num_gpus_per_worker=1, + operator_config: Optional[Mapping[str, Any]] = None, + initialization_hook: Optional[Callable] = None, + world_size: int = 2, + backend: str = "nccl", + group_name: str = "default", + num_cpus_per_worker: int = 1, + num_gpus_per_worker: int = 1, **kwargs): self.training_operator_cls = training_operator_cls self.initialization_hook = initialization_hook @@ -24,6 +27,8 @@ def __init__(self, "ray.util.distml does not support single-process training " "at this moment.") self.world_size = world_size + self.backend = backend + self.group_name = group_name self.num_cpus_per_worker = num_cpus_per_worker self.num_gpus_per_worker = num_gpus_per_worker self._operator_config = {} if not operator_config \ @@ -47,7 +52,7 @@ def validate(self): raise NotImplementedError() @abstractmethod - def save_parameters(self, checkpoint): + def save_parameters(self, checkpoint: str): """Saves the Trainer state to the provided checkpoint path. Args: @@ -56,7 +61,12 @@ def save_parameters(self, checkpoint): raise NotImplementedError() @abstractmethod - def load_parameters(self, checkpoint): + def load_parameters(self, checkpoint: str): + """Loads the Trainer state to the provided checkpoint path. 
+ + Args: + checkpoint (str): Path to target checkpoint file. + """ raise NotImplementedError() @abstractmethod @@ -70,6 +80,106 @@ def _init_strategy(self): raise NotImplementedError() @abstractmethod - def shutdown(self, force=False): + def shutdown(self, force: bool = False): """Kill all workers.""" raise NotImplementedError() + + +class BaseDataParallelGroup: + """Spawn a actor group for data-parallel training.""" + + def __init__(self, + actor_params: Mapping[str, Any], + dist_params: Mapping[str, Any], + num_cpus_per_actor: int, + num_gpus_per_actor: int, + initialization_hook: Optional[Callable], + **kwargs): + self._actor_params = actor_params + self._dist_params = dist_params + self._backend = self._dist_params["backend"] + self._group_name = self._dist_params["group_name"] + self._num_cpus_per_actor = num_cpus_per_actor + self._num_gpus_per_actor = num_gpus_per_actor + self._initialization_hook = initialization_hook + + # try to unroll the dist_params + self._backend = self._dist_params["backend"] + self._group_name = self._dist_params["group_name"] + + @property + def world_size(self): + return len(self._actors) + + @property + def backend(self): + return self._backend + + @property + def group_name(self): + return self._group_name + + @abstractmethod + def _setup_collective_group(self, *args, **kwargs): + """All actors setup operators.""" + raise NotImplementedError() + + @abstractmethod + def setup_operator(self): + """All actors setup operators.""" + raise NotImplementedError() + + @abstractmethod + def _start_actors(self, num_actors): + """Start all actors.""" + raise NotImplementedError() + + @abstractmethod + def make_iterator(self, training: bool = True): + """Make iterator.""" + raise NotImplementedError() + + @abstractmethod + def get_data_loader_len(self, training: bool = True): + """Return the number of batches in the data loader.""" + raise NotImplementedError() + + @abstractmethod + def validate_batch(self): + """Validate one batch and return batch metrics.""" + raise NotImplementedError() + + @abstractmethod + def shutdown(self, force: bool = False): + """Shutdown all actors.""" + raise NotImplementedError() + + @abstractmethod + def reset(self): + """Reset group.""" + raise NotImplementedError() + + @abstractmethod + def save_parameters(self, checkpoint: str): + """Let the first actor save parameters.""" + raise NotImplementedError() + + @abstractmethod + def load_parameters(self, checkpoint: str): + """All actor load parameters from checkpoint.""" + raise NotImplementedError() + + @abstractmethod + def set_parameters(self, params): + """Input params and replace the model parameters.""" + raise NotImplementedError() + + @abstractmethod + def get_parameters(self, cpu: bool = False): + """Return parameters from the first actor.""" + raise NotImplementedError() + + @abstractmethod + def get_named_parameters(self, cpu: bool = False): + """Return named parameters from the first actor.""" + raise NotImplementedError() diff --git a/distml/strategy/ps_strategy.py b/distml/strategy/ps_strategy.py index 99fef65..0143365 100644 --- a/distml/strategy/ps_strategy.py +++ b/distml/strategy/ps_strategy.py @@ -1,10 +1,11 @@ +from typing import Tuple, List, Callable, Mapping, Union, Any, Optional, Sequence import ray import ray.util.collective as col import numpy as np import distml.util as util -from distml.strategy.base_strategy import BaseStrategy +from distml.strategy.base_strategy import BaseStrategy, BaseDataParallelGroup from distml.util import ThroughputCollection import 
logging @@ -13,7 +14,7 @@ class ParameterServerStrategy(BaseStrategy): - """Strategy that trains a model via collective AllReduce. + """Strategy that trains a model via parameter server. Args: training_operator_cls (TrainingOperator): @@ -22,7 +23,7 @@ class ParameterServerStrategy(BaseStrategy): initialization_hook (function): A function to call on all training workers when they are first initialized. This could be useful to set environment variables for all the worker processes. - num_workers (int): The number of workers. + num_worker (int): The number of workers. num_ps (int): The number of parameter servers. num_cpus_per_worker (int): number of CPUs allocated per worker. num_gpus_per_worker (int): number of GPUs allocated per worker. @@ -33,27 +34,35 @@ class ParameterServerStrategy(BaseStrategy): def __init__(self, *, training_operator_cls, - operator_config=None, - initialization_hook=None, - num_workers=1, - num_ps=1, - num_cpus_per_worker=1, - num_gpus_per_worker=1, - num_cpus_per_server=1, - num_gpus_per_server=1, + operator_config: Optional[Mapping[str, Any]] = None, + initialization_hook: Optional[Callable] = None, + world_size: int = 2, + num_worker: int = 1, + num_ps: int = 1, + backend="nccl", + group_name="default", + num_cpus_per_worker: int = 1, + num_gpus_per_worker: int = 1, + num_cpus_per_server: int = 1, + num_gpus_per_server: int = 1, **kwargs): - self.assignments = None - assert num_ps + assert world_size == num_ps + num_worker, \ + "'world_size' should be equal to 'num_ps' plus 'num_worker'" + + self.assignments = None self.num_ps = num_ps - self.num_workers = num_workers + self.num_worker = num_worker self.num_cpus_per_server = num_cpus_per_server self.num_gpus_per_server = num_gpus_per_server - super(ParameterServerStrategy, self).\ + super(ParameterServerStrategy, self). \ __init__(training_operator_cls=training_operator_cls, operator_config=operator_config, initialization_hook=initialization_hook, + world_size=world_size, + backend=backend, + group_name=group_name, num_cpus_per_worker=num_cpus_per_worker, num_gpus_per_worker=num_gpus_per_worker, **kwargs) @@ -94,17 +103,13 @@ def _init_strategy(self): ray.get([server.set_params.remote(this_shard_ref)]) def _start_workers(self): - """Create worker(actor), maybe need worker group to manager these workers. - Or, send these workers to strategy to manager? - - set workers or worker group - set worker info, record rank, backend, use_num_gpus? + """Start worker group and server group. """ # TODO (Hao): infer the per-replica batch size here... # so here we get two set of params that will be passed around: # (1) Those for setting up training logic in training_operator, - # including: batchsize, use_tqdm, user defined operator_config. + # including: batch size, user defined operator_config. 
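Under the new signature the group sizes must be consistent: world_size has to equal num_worker + num_ps, as the assertion above enforces. An updated construction sketch (operator class and config values are placeholders):

    strategy = ParameterServerStrategy(
        training_operator_cls=MyJaxOperator,   # placeholder operator class
        world_size=3,                          # must equal num_worker + num_ps
        num_worker=2,
        num_ps=1,
        operator_config={"lr": 0.01, "batch_size": 128})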
operator_config = self._operator_config.copy() params = dict( training_operator_cls=self.training_operator_cls, @@ -116,62 +121,66 @@ def _start_workers(self): dist_params_worker = dict( strategy="ps", is_server=False, - group_name="default", + backend=self.backend, + group_name=self.group_name, num_ps=self.num_ps, - num_workers=self.num_workers, + num_worker=self.num_worker, ) dist_params_server = dict( strategy="ps", is_server=True, - group_name="default", + backend=self.backend, + group_name=self.group_name, num_ps=self.num_ps, - num_workers=self.num_workers, + num_worker=self.num_worker, ) # (3) other arguments that used to init the DataParallelGrup - workergroup_init_args = { - "params": params, - "dist_params": dist_params_worker, - "num_cpus_per_actor": self.num_cpus_per_worker, - "num_gpus_per_actor": self.num_gpus_per_worker, - } + worker_group_init_args = dict( + actor_params=params, + dist_params=dist_params_worker, + num_cpus_per_actor=self.num_cpus_per_worker, + num_gpus_per_actor=self.num_gpus_per_worker, + initialization_hook=self.initialization_hook, + ) - servergroup_init_args = { - "params": params, - "dist_params": dist_params_server, - "num_cpus_per_actor": self.num_cpus_per_server, - "num_gpus_per_actor": self.num_gpus_per_server, - } + server_group_init_args = dict( + actor_params=params, + dist_params=dist_params_server, + num_cpus_per_actor=self.num_cpus_per_server, + num_gpus_per_actor=self.num_gpus_per_server, + initialization_hook=self.initialization_hook, + ) # Should we make two groups for worker and server? - self.worker_group = DataParallelGroup(**workergroup_init_args) - self.server_group = DataParallelGroup(**servergroup_init_args) + self.worker_group = DataParallelGroup(**worker_group_init_args) + self.server_group = DataParallelGroup(**server_group_init_args) # Once the group is created, we start it. - self.worker_group.start_actors(self.num_workers) - self.server_group.start_actors( - self.num_ps) # server at the last num_ps processes. + self.worker_group._start_actors(self.num_worker) + # server at the last num_ps processes. + self.server_group._start_actors(self.num_ps) - worker_rets = self.worker_group.test_connection() - server_rets = self.server_group.test_connection() - ray.get(worker_rets + server_rets) + # worker_rets = self.worker_group.test_connection() + # server_rets = self.server_group.test_connection() + # ray.get(worker_rets + server_rets) ray.get(self.worker_group.setup_operator()) ray.get(self.server_group.setup_operator()) self.server_group.clean_redundancy() - def shutdown(self, force=False): + def shutdown(self, force: bool = False): self.worker_group.shutdown(force=force) self.server_group.shutdown(force=force) - def save_parameters(self, checkpoint): + def save_parameters(self, checkpoint: str): # TODO(HUI): ps save parameters. # First, worker rank 0 should pull the latest parameter from servers # Then, worker rank 0 save parameters self.worker_group.save_parameters(checkpoint) - def load_parameters(self, checkpoint): + def load_parameters(self, checkpoint: str): # TODO(HUI): ps load parameters. # shard parameters and send to all servers. self.server_group.load_parameters(checkpoint) @@ -189,7 +198,7 @@ def _round_robin_sharding(self): print("Load of each ps {}".format(loads)) self.assignments = assignments - def train(self, num_steps=None): + def train(self, num_steps: Optional[int] = None) -> dict: # TODO (Hao): add fault tolerance using `max_retries`. 
steps = num_steps if num_steps \ else self.worker_group.get_data_loader_len() @@ -201,10 +210,10 @@ def train(self, num_steps=None): for idx in range(steps): with self._collector.record("train"): metrics = self.train_batch() - logger.info("Step: {}/{}".format(idx, steps)) + print("Step: {}/{}".format(idx, steps)) return metrics - def validate(self, num_steps=None): + def validate(self, num_steps: Optional[int] = None): steps = num_steps if num_steps \ else self.worker_group.get_data_loader_len(training=False) self.worker_group.make_iterator(training=False) @@ -215,7 +224,7 @@ def validate(self, num_steps=None): # Validate results should be the same in all workers return batch_metrics - def train_batch(self): + def train_batch(self) -> dict: loss_vals = [] rets = [] metrics = {} @@ -241,7 +250,7 @@ def train_batch(self): class PS(object): - def __init__(self, training_operator_cls, operator_config): + def __init__(self, training_operator_cls, operator_config: Optional[Mapping[str, Any]]): self.training_operator_cls = training_operator_cls self.operator_config = operator_config @@ -254,34 +263,34 @@ def setup_operator(self): self.operator_config) def setup_collective_group(self, - rank, - num_ps, - num_workers, - backend="nccl", - group_name="default"): + rank: int, + num_ps: int, + num_worker: int, + backend: str = "nccl", + group_name: str = "default"): # rank should be true rank. means, rank has already plus num_worker. self.rank = rank self.num_ps = num_ps - self.num_workers = num_workers + self.num_worker = num_worker self.group_name = group_name - self.group_size = num_ps + num_workers + self.group_size = num_ps + num_worker self._init_grad_counts() # the last num_ps processes are servers. col.init_collective_group( - num_ps + num_workers, rank, backend=backend, group_name=group_name) + num_ps + num_worker, rank, backend=backend, group_name=group_name) def test_connection(self): - for i in range(self.num_workers): - recv = util.zeros((1, ), cpu=False) + for i in range(self.num_worker): + recv = util.zeros((1,), cpu=False) col.recv(recv, i, self.group_name) assert recv == 1 - for i in range(self.num_workers): - send = util.ones((1, ), cpu=False) + for i in range(self.num_worker): + send = util.ones((1,), cpu=False) col.send(send, i, self.group_name) return def _init_grad_counts(self): - self.grad_counts = [0] * self.num_workers + self.grad_counts = [0] * self.num_worker def _init_grad_buffer(self): self.grad_buffer = { @@ -289,7 +298,7 @@ def _init_grad_buffer(self): for k, v in self.params.items() } - def get_params(self): + def get_params(self) -> dict: return self.params def set_params(self, params): @@ -302,7 +311,7 @@ def set_params(self, params): self._init_grad_buffer() def apply_updates(self, grad_buffer): - # TODO(HUI): gradient divide by num_workers + # TODO(HUI): gradient divide by num_worker self.training_operator.apply_updates(grad_buffer) self.params = self.training_operator.get_named_parameters(cpu=False) @@ -311,13 +320,13 @@ def _inc_gradients(self, gradients): if gradients[name] is not None: self.grad_buffer[name] += gradients[name] - def send_params(self, dst_rank): + def send_params(self, dst_rank: int): """ Send this param shard to the destination worker """ for name, v in self.params.items(): cv = self.training_operator.to_cupy(v) col.send(cv, dst_rank, self.group_name) - def update(self, src_rank): + def update(self, src_rank: int): """Receive gradients and update""" keys = list(self.params.keys()) grads = dict() @@ -340,7 +349,7 @@ def update(self, src_rank): 
self.grad_counts[src_rank] = 1 else: raise RuntimeError(f"This worker {src_rank} send gradients again.") - if sum(self.grad_counts) == self.num_workers: + if sum(self.grad_counts) == self.num_worker: self.apply_updates(self.grad_buffer) self._init_grad_buffer() @@ -359,7 +368,7 @@ def shutdown(self): class Worker(object): - def __init__(self, training_operator_cls, operator_config): + def __init__(self, training_operator_cls, operator_config: Optional[Mapping[str, Any]]): self.training_operator_cls = training_operator_cls self.operator_config = operator_config @@ -375,29 +384,29 @@ def setup_operator(self): self.operator_config) def setup_collective_group(self, - rank, - num_ps, - num_workers, - backend="nccl", - group_name="default"): + rank: int, + num_ps: int, + num_worker: int, + backend: str = "nccl", + group_name: str = "default"): self.rank = rank self.num_ps = num_ps - self.num_workers = num_workers + self.num_worker = num_worker self.group_name = group_name - self.group_size = num_ps + num_workers + self.group_size = num_ps + num_worker self.name_list = [[] for i in range(num_ps)] # the last num_ps processes are servers. col.init_collective_group( - num_ps + num_workers, rank, backend=backend, group_name=group_name) + num_ps + num_worker, rank, backend=backend, group_name=group_name) def test_connection(self): for i in range(self.num_ps): - send = util.ones((1, ), cpu=False) - col.send(send, self.num_workers + i, self.group_name) + send = util.ones((1,), cpu=False) + col.send(send, self.num_worker + i, self.group_name) for i in range(self.num_ps): - recv = util.zeros((1, ), cpu=False) - col.recv(recv, self.num_workers + i, self.group_name) + recv = util.zeros((1,), cpu=False) + col.recv(recv, self.num_worker + i, self.group_name) assert recv == 1 return @@ -408,7 +417,7 @@ def params_distribution(self): distribution.append(self.training_operator.numel(v)) return distribution - def make_iterator(self, training=True): + def make_iterator(self, training: bool = True): """Convert loader to be an iterator at the start of an epoch.""" # TODO(Hao): need to check whether reaching the boundary of iterator # instead of making a new one every time. @@ -419,7 +428,7 @@ def make_iterator(self, training=True): self.validation_iterator = iter( self.training_operator._get_validation_loader()) - def get_data_loader_len(self, training=True): + def get_data_loader_len(self, training: bool = True): """Return the number of batches in the data loader.""" loader = self.training_operator._get_train_loader() if training \ else self.training_operator._get_validation_loader() @@ -430,9 +439,8 @@ def get_data_loader_len(self, training=True): "Data loader has no attribute `__len__`. " "Please set `num_steps` in `train()` or `validate()`.") - def derive_updates(self, batch): + def derive_updates(self, batch: Sequence[Any]): # TODO (Hao): handling data loader next. - # TODO (Hao): change it to derive_update and apply_update. 
return self.training_operator.derive_updates(batch) def compute_gradients(self, params): @@ -472,16 +480,16 @@ def split_parameters(self, assignments): shards[assignments[i]][k] = v return shards - def index_shard(self, shards, index): + def index_shard(self, shards, index: int): return shards[index] def set_parameters(self, params): return self.training_operator.set_parameters(params) - def get_parameters(self, cpu): + def get_parameters(self, cpu: bool): return self.training_operator.get_parameters(cpu) - def get_named_parameters(self, cpu): + def get_named_parameters(self, cpu: bool): return self.training_operator.get_named_parameters(cpu) def get_gradients(self): @@ -516,7 +524,7 @@ def compute(self): for i in range(self.num_ps): for j in range(len(self.name_list[i])): v = self.training_operator.to_cupy(recv_list[i][j]) - col.recv(v, self.num_workers + i, self.group_name) + col.recv(v, self.num_worker + i, self.group_name) # 3. Set params in workers and compute gradients. for i in range(self.num_ps): @@ -533,7 +541,7 @@ def compute(self): this_shard = self.index_shard(split_grad, i) for _, v in this_shard.items(): cv = self.training_operator.to_cupy(v) - col.send(cv, self.num_workers + i, self.group_name) + col.send(cv, self.num_worker + i, self.group_name) return metrics def validate_batch(self): @@ -553,33 +561,42 @@ def shutdown(self): return 1 -class DataParallelGroup: - """Spawn a group a replicas for data-parallel training.""" - - def __init__(self, params, dist_params, num_cpus_per_actor, - num_gpus_per_actor): - self._params = params - self._dist_params = dist_params - self._num_cpus_per_actor = num_cpus_per_actor - self._num_gpus_per_actor = num_gpus_per_actor +class DataParallelGroup(BaseDataParallelGroup): + """Spawn a actor group for data-parallel training.""" + def __init__(self, + actor_params: Mapping[str, Any], + dist_params: Mapping[str, Any], + num_cpus_per_actor: int, + num_gpus_per_actor: int, + initialization_hook: Optional[Callable]): + super(DataParallelGroup, self).__init__(actor_params=actor_params, + dist_params=dist_params, + num_cpus_per_actor=num_cpus_per_actor, + num_gpus_per_actor=num_gpus_per_actor, + initialization_hook=initialization_hook) self.is_server = self._dist_params["is_server"] self.num_ps = self._dist_params["num_ps"] - self.num_workers = self._dist_params["num_workers"] + self.num_worker = self._dist_params["num_worker"] self._distributed_actors = None - def _setup_collective_group(self, num_replicas): + def _setup_collective_group(self, + num_ps: int, + num_worker: int, + backend: int, + group_name: str = "default"): if self._dist_params["strategy"] == "ps": - num_ps = self._dist_params["num_ps"] - num_workers = self._dist_params["num_workers"] is_server = self.is_server + rets = [ actor.setup_collective_group.remote( - rank=i + is_server * num_workers, - num_workers=num_workers, + rank=i + is_server * num_worker, + num_worker=num_worker, num_ps=num_ps, - backend="nccl") + backend=backend, + group_name=group_name + ) for i, actor in enumerate(self._distributed_actors) ] else: # this can be extend for allreduce. 
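As a side note on the group setup above, the collective group places workers on the first ranks and parameter servers on the last num_ps ranks, matching the rank = i + is_server * num_worker computation in the hunk. A minimal sketch of that layout (plain Python, illustrative only, not part of the patch):

# Workers occupy ranks [0, num_worker); servers occupy the last num_ps ranks,
# i.e. [num_worker, num_worker + num_ps).
def ps_rank_layout(num_worker, num_ps):
    worker_ranks = list(range(num_worker))
    server_ranks = [num_worker + i for i in range(num_ps)]
    return worker_ranks, server_ranks

# For example, two workers and one server give ranks ([0, 1], [2]).
assert ps_rank_layout(2, 1) == ([0, 1], [2])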
@@ -593,7 +610,7 @@ def setup_operator(self): ] return setups - def start_actors(self, num_actors): + def _start_actors(self, num_actors: int): if self.is_server: RemoteActor = ray.remote( num_cpus=self._num_cpus_per_actor, @@ -604,11 +621,12 @@ def start_actors(self, num_actors): num_gpus=self._num_gpus_per_actor)(Worker) self._distributed_actors = [ - RemoteActor.remote(**self._params) for _ in range(num_actors) + RemoteActor.remote(**self._actor_params) for _ in range(num_actors) ] # setup the rank and group in each replica - ray.get(self._setup_collective_group(len(self._distributed_actors))) + ray.get(self._setup_collective_group( + self.num_ps, self.num_worker, self.backend, self.group_name)) def test_connection(self): rets = [ @@ -624,13 +642,13 @@ def set_assignments(self, assignments): ] return rets - def _make_iterator(self, training): + def _make_iterator(self, training: bool): return [actor.make_iterator.remote(training) for actor in self.actors] - def make_iterator(self, training=True): + def make_iterator(self, training: bool = True): ray.get(self._make_iterator(training)) - def get_data_loader_len(self, training=True): + def get_data_loader_len(self, training: bool = True): """Return the number of batches in the data loader.""" lens = ray.get([ actor.get_data_loader_len.remote(training=training) @@ -651,7 +669,7 @@ def validate_batch(self): stats = ray.get(rets) return stats - def shutdown(self, force=False): + def shutdown(self, force: bool = False): rets = [actor.shutdown.remote() for _, actor in enumerate(self.actors)] stats = ray.get(rets) return stats @@ -663,11 +681,11 @@ def reset(self): def actors(self): return self._distributed_actors - def save_parameters(self, checkpoint): + def save_parameters(self, checkpoint: str): rets = [self.actors[0].save_parameters.remote(checkpoint)] ray.get(rets) - def load_parameters(self, checkpoint): + def load_parameters(self, checkpoint: str): rets = [ actor.load_parameters.remote(checkpoint) for _, actor in enumerate(self.actors) @@ -681,11 +699,11 @@ def set_parameters(self, params): ] ray.get(rets) - def get_parameters(self, cpu=False): + def get_parameters(self, cpu: bool = False): ret = self.actors[0].get_parameters.remote(cpu) return ray.get([ret])[0] - def get_named_parameters(self, cpu=False): + def get_named_parameters(self, cpu: bool = False): ret = self.actors[0].get_named_parameters.remote(cpu) return ray.get([ret])[0] diff --git a/examples/jax/mnist_jax_example.py b/examples/jax/mnist_jax_example.py index 494c8fa..db7abe2 100644 --- a/examples/jax/mnist_jax_example.py +++ b/examples/jax/mnist_jax_example.py @@ -81,11 +81,11 @@ def criterion(logits, targets): def make_ar_strategy(args): strategy = AllReduceStrategy( training_operator_cls=MnistTrainingOperator, - world_size=args.num_workers, + world_size=args.num_worker, operator_config={ "lr": 0.01, "batch_size": 128, - "num_workers": args.num_workers, + "num_worker": args.num_worker, "num_classes": 10, "model_name": args.model_name }, @@ -96,8 +96,8 @@ def make_ar_strategy(args): def make_ps_strategy(args): strategy = ParameterServerStrategy( training_operator_cls=MnistTrainingOperator, - world_size=args.num_workers, - num_workers=args.num_workers - args.num_ps, + world_size=args.num_worker, + num_worker=args.num_worker - args.num_ps, num_ps=args.num_ps, operator_config={ "lr": 0.01, @@ -116,7 +116,7 @@ def make_ps_strategy(args): type=str, help="the address to use for connecting to the Ray cluster") parser.add_argument( - "--num-workers", + "--num-worker", "-n", 
type=int, default=2, @@ -150,8 +150,8 @@ def make_ps_strategy(args): ray.init(args.address) else: ray.init( - num_gpus=args.num_workers, - num_cpus=args.num_workers * 2, + num_gpus=args.num_worker, + num_cpus=args.num_worker * 2, log_to_driver=True) if args.strategy == "ar": From 21d9a3503ac9e263ba6e0a033403b683c8e42592 Mon Sep 17 00:00:00 2001 From: Ezra-H Date: Tue, 18 May 2021 00:39:41 +0800 Subject: [PATCH 10/12] lint --- distml/operator/jax_operator.py | 28 +++++---- distml/strategy/allreduce_strategy.py | 32 +++++----- distml/strategy/base_strategy.py | 11 ++-- distml/strategy/ps_strategy.py | 88 ++++++++++++++++----------- examples/jax/mnist_jax_example.py | 1 - 5 files changed, 86 insertions(+), 74 deletions(-) diff --git a/distml/operator/jax_operator.py b/distml/operator/jax_operator.py index 640e6f0..2ae070b 100644 --- a/distml/operator/jax_operator.py +++ b/distml/operator/jax_operator.py @@ -1,3 +1,5 @@ +from typing import Any, Mapping, Optional + import numpy as np import cupy as cp @@ -14,7 +16,7 @@ class JAXTrainingOperator(TrainingOperator): - def __init__(self, operator_config): + def __init__(self, operator_config: Optional[Mapping[str, Any]]): super(JAXTrainingOperator, self).__init__(operator_config) # Should be set by users in the `register` function. # model methods @@ -64,7 +66,7 @@ def setup(self, *args, **kwargs): raise NotImplementedError("Please override this function to register " "your model, optimizer, and criterion.") - def register(self, *, model, optimizer, criterion, jit_mode=False): + def register(self, *, model, optimizer, criterion, jit_mode: bool = False): """Register a few critical information about the model to operator. Args: @@ -273,7 +275,7 @@ def validate_batch(self, batch): "samples_num": samples_num } - def get_parameters(self, cpu): + def get_parameters(self, cpu: bool): """get the flatten parameters.""" params = self.get_params(self.opt_state) flatten_params, tree = tree_flatten(params) @@ -284,7 +286,7 @@ def get_parameters(self, cpu): flatten_params = list(map(np.asarray, flatten_params)) return flatten_params - def get_named_parameters(self, cpu): + def get_named_parameters(self, cpu: bool): """Get the named parameters. In jax, we need to construct a dict to contain the parameters. @@ -335,7 +337,7 @@ def update(param, state): zip(subtrees, new_subtrees)): if new_subtree != subtree: msg = ( - "input structur did not match the save params struture. " + "input structure did not match the save params structure. " "input {} and output {}.") raise TypeError(msg.format(subtree, new_subtree)) @@ -350,25 +352,25 @@ def reset_optimizer_for_params(self, params): self.tree = tree_structure(params) self.opt_state = self.opt_init(params) - def ones(self, shape, cpu=True): + def ones(self, shape, cpu: bool = True): if cpu: return np.ones(shape) else: return jnp.ones(shape) - def zeros(self, shape, cpu=True): + def zeros(self, shape, cpu: bool = True): if cpu: return np.zeros(shape) else: return jnp.zeros(shape) - def ones_like(self, x, cpu=True): + def ones_like(self, x, cpu: bool = True): if cpu: return np.ones_like(x) else: return jnp.ones_like(x) - def zeros_like(self, x, cpu=True): + def zeros_like(self, x, cpu: bool = True): if cpu: return np.zeros_like(x) else: @@ -385,21 +387,21 @@ def clean_redundancy(self): del self._validation_loader # TODO(HUI): use pickle to serialize parameters or states and save it. 
- def save_parameters(self, checkpoint): + def save_parameters(self, checkpoint: str): raise NotImplementedError( "save_parameters is not support in jax operator.") - def load_parameters(self, checkpoint): + def load_parameters(self, checkpoint: str): raise NotImplementedError( "load_parameters is not support in jax operator.") - def save_states(self, checkpoint): + def save_states(self, checkpoint: str): raise NotImplementedError( "save_states is not support in jax operator.") def get_states(self): raise NotImplementedError("get_states is not support in jax operator.") - def load_states(self, checkpoint): + def load_states(self, checkpoint: str): raise NotImplementedError( "load_states is not support in jax operator.") diff --git a/distml/strategy/allreduce_strategy.py b/distml/strategy/allreduce_strategy.py index 3016625..dc557a8 100644 --- a/distml/strategy/allreduce_strategy.py +++ b/distml/strategy/allreduce_strategy.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Mapping, Any, Optional +from typing import Callable, Mapping, Any, Optional, Dict import ray import ray.util.collective as col @@ -57,7 +57,7 @@ def __init__(self, else: self._collector = ThroughputCollection() - def train(self, num_steps: Optional[int] = None): + def train(self, num_steps: Optional[int] = None) -> Dict: """Run the training on parallel workers. Args: @@ -79,7 +79,7 @@ def train(self, num_steps: Optional[int] = None): print("Step: {}/{}".format(idx, steps)) return metrics - def validate(self, num_steps: Optional[int] = None): + def validate(self, num_steps: Optional[int] = None) -> Dict: """Evaluates the model on the validation data. Args: @@ -149,7 +149,8 @@ class Replica: and Ray collective group setup. """ - def __init__(self, training_operator_cls, operator_config: Optional[Mapping[str, Any]]): + def __init__(self, training_operator_cls, + operator_config: Optional[Mapping[str, Any]]): self.training_operator_cls = training_operator_cls self.operator_config = operator_config # Training operator @@ -189,7 +190,7 @@ def make_iterator(self, training: bool = True): else: self.validation_iterator = iter(self.validation_loader) - def get_data_loader_len(self, training: bool = True): + def get_data_loader_len(self, training: bool = True) -> int: """Return the number of batches in the data loader.""" loader = self.train_loader if training \ else self.validation_loader @@ -200,7 +201,7 @@ def get_data_loader_len(self, training: bool = True): "Data loader has no attribute `__len__`. 
" "Please set `num_steps` in `train()` or `validate()`.") - def train_batch(self) -> dict: + def train_batch(self) -> Dict: metrics = {} try: batch = next(self.train_iterator) @@ -218,7 +219,7 @@ def train_batch(self) -> dict: self.apply_updates(updates) return metrics - def derive_updates(self, batch) -> dict: + def derive_updates(self, batch) -> Dict: return self.training_operator.derive_updates(batch) def apply_updates(self, updates): @@ -277,17 +278,16 @@ def group_name(self): class DataParallelGroup(BaseDataParallelGroup): """Spawn a replica group for data-parallel training.""" - def __init__(self, - actor_params: Mapping[str, Any], - dist_params: Mapping[str, Any], - num_cpus_per_actor: int, + def __init__(self, actor_params: Mapping[str, Any], + dist_params: Mapping[str, Any], num_cpus_per_actor: int, num_gpus_per_actor: int, initialization_hook: Optional[Callable]): - super(DataParallelGroup, self).__init__(actor_params=actor_params, - dist_params=dist_params, - num_cpus_per_actor=num_cpus_per_actor, - num_gpus_per_actor=num_gpus_per_actor, - initialization_hook=initialization_hook) + super(DataParallelGroup, self).__init__( + actor_params=actor_params, + dist_params=dist_params, + num_cpus_per_actor=num_cpus_per_actor, + num_gpus_per_actor=num_gpus_per_actor, + initialization_hook=initialization_hook) self._replicas = None @property diff --git a/distml/strategy/base_strategy.py b/distml/strategy/base_strategy.py index 27a92d0..4ad1e80 100644 --- a/distml/strategy/base_strategy.py +++ b/distml/strategy/base_strategy.py @@ -1,7 +1,7 @@ from abc import ABCMeta from abc import abstractmethod import logging -from typing import AbstractSet, Callable, Any, Mapping, Optional +from typing import Callable, Any, Mapping, Optional import ray @@ -88,13 +88,10 @@ def shutdown(self, force: bool = False): class BaseDataParallelGroup: """Spawn a actor group for data-parallel training.""" - def __init__(self, - actor_params: Mapping[str, Any], - dist_params: Mapping[str, Any], - num_cpus_per_actor: int, + def __init__(self, actor_params: Mapping[str, Any], + dist_params: Mapping[str, Any], num_cpus_per_actor: int, num_gpus_per_actor: int, - initialization_hook: Optional[Callable], - **kwargs): + initialization_hook: Optional[Callable], **kwargs): self._actor_params = actor_params self._dist_params = dist_params self._backend = self._dist_params["backend"] diff --git a/distml/strategy/ps_strategy.py b/distml/strategy/ps_strategy.py index 0143365..a3486b8 100644 --- a/distml/strategy/ps_strategy.py +++ b/distml/strategy/ps_strategy.py @@ -1,4 +1,4 @@ -from typing import Tuple, List, Callable, Mapping, Union, Any, Optional, Sequence +from typing import List, Callable, Mapping, Any, Optional, Sequence, Dict import ray import ray.util.collective as col @@ -39,8 +39,8 @@ def __init__(self, world_size: int = 2, num_worker: int = 1, num_ps: int = 1, - backend="nccl", - group_name="default", + backend: str = "nccl", + group_name: str = "default", num_cpus_per_worker: int = 1, num_gpus_per_worker: int = 1, num_cpus_per_server: int = 1, @@ -103,10 +103,7 @@ def _init_strategy(self): ray.get([server.set_params.remote(this_shard_ref)]) def _start_workers(self): - """Start worker group and server group. - """ - # TODO (Hao): infer the per-replica batch size here... - + """Start worker group and server group.""" # so here we get two set of params that will be passed around: # (1) Those for setting up training logic in training_operator, # including: batch size, user defined operator_config. 
@@ -198,7 +195,16 @@ def _round_robin_sharding(self): print("Load of each ps {}".format(loads)) self.assignments = assignments - def train(self, num_steps: Optional[int] = None) -> dict: + def train(self, num_steps: Optional[int] = None) -> Dict: + """Run the training on parallel workers. + + Args: + num_steps (int): number of steps to train. If none, the + function will simply train for one epoch. + + Returns: + None + """ # TODO (Hao): add fault tolerance using `max_retries`. steps = num_steps if num_steps \ else self.worker_group.get_data_loader_len() @@ -213,7 +219,14 @@ def train(self, num_steps: Optional[int] = None) -> dict: print("Step: {}/{}".format(idx, steps)) return metrics - def validate(self, num_steps: Optional[int] = None): + def validate(self, num_steps: Optional[int] = None) -> Dict: + """Evaluates the model on the validation data. + + Args: + num_steps (int): number of batches to evaluate. If None, the + function will simply validate across the entire validation + dataset. + """ steps = num_steps if num_steps \ else self.worker_group.get_data_loader_len(training=False) self.worker_group.make_iterator(training=False) @@ -224,7 +237,7 @@ def validate(self, num_steps: Optional[int] = None): # Validate results should be the same in all workers return batch_metrics - def train_batch(self) -> dict: + def train_batch(self) -> Dict: loss_vals = [] rets = [] metrics = {} @@ -250,7 +263,8 @@ def train_batch(self) -> dict: class PS(object): - def __init__(self, training_operator_cls, operator_config: Optional[Mapping[str, Any]]): + def __init__(self, training_operator_cls, + operator_config: Optional[Mapping[str, Any]]): self.training_operator_cls = training_operator_cls self.operator_config = operator_config @@ -281,11 +295,11 @@ def setup_collective_group(self, def test_connection(self): for i in range(self.num_worker): - recv = util.zeros((1,), cpu=False) + recv = util.zeros((1, ), cpu=False) col.recv(recv, i, self.group_name) assert recv == 1 for i in range(self.num_worker): - send = util.ones((1,), cpu=False) + send = util.ones((1, ), cpu=False) col.send(send, i, self.group_name) return @@ -368,7 +382,8 @@ def shutdown(self): class Worker(object): - def __init__(self, training_operator_cls, operator_config: Optional[Mapping[str, Any]]): + def __init__(self, training_operator_cls, + operator_config: Optional[Mapping[str, Any]]): self.training_operator_cls = training_operator_cls self.operator_config = operator_config @@ -402,15 +417,15 @@ def setup_collective_group(self, def test_connection(self): for i in range(self.num_ps): - send = util.ones((1,), cpu=False) + send = util.ones((1, ), cpu=False) col.send(send, self.num_worker + i, self.group_name) for i in range(self.num_ps): - recv = util.zeros((1,), cpu=False) + recv = util.zeros((1, ), cpu=False) col.recv(recv, self.num_worker + i, self.group_name) assert recv == 1 return - def params_distribution(self): + def params_distribution(self) -> List: distribution = [] weights = self.get_named_parameters(cpu=True) for k, v in weights.items(): @@ -428,7 +443,7 @@ def make_iterator(self, training: bool = True): self.validation_iterator = iter( self.training_operator._get_validation_loader()) - def get_data_loader_len(self, training: bool = True): + def get_data_loader_len(self, training: bool = True) -> int: """Return the number of batches in the data loader.""" loader = self.training_operator._get_train_loader() if training \ else self.training_operator._get_validation_loader() @@ -439,7 +454,7 @@ def get_data_loader_len(self, 
training: bool = True): "Data loader has no attribute `__len__`. " "Please set `num_steps` in `train()` or `validate()`.") - def derive_updates(self, batch: Sequence[Any]): + def derive_updates(self, batch: Sequence[Any]) -> Dict: # TODO (Hao): handling data loader next. return self.training_operator.derive_updates(batch) @@ -463,7 +478,7 @@ def compute_gradients(self, params): return loss_val, grads - def split_gradients(self, grad, assignments): + def split_gradients(self, grad, assignments) -> List: # assuming messages are gradients or parameters # this grad is ready to be called by apply_gradients in ParameterServer num_shards = np.unique(np.array(assignments)).size @@ -472,7 +487,7 @@ def split_gradients(self, grad, assignments): shards[assignments[i]][k] = v return shards - def split_parameters(self, assignments): + def split_parameters(self, assignments) -> List: params = self.get_named_parameters(cpu=False) num_shards = np.unique(np.array(assignments)).size shards = [dict() for i in range(num_shards)] @@ -484,12 +499,12 @@ def index_shard(self, shards, index: int): return shards[index] def set_parameters(self, params): - return self.training_operator.set_parameters(params) + self.training_operator.set_parameters(params) - def get_parameters(self, cpu: bool): + def get_parameters(self, cpu: bool) -> List: return self.training_operator.get_parameters(cpu) - def get_named_parameters(self, cpu: bool): + def get_named_parameters(self, cpu: bool) -> Dict: return self.training_operator.get_named_parameters(cpu) def get_gradients(self): @@ -564,17 +579,16 @@ def shutdown(self): class DataParallelGroup(BaseDataParallelGroup): """Spawn a actor group for data-parallel training.""" - def __init__(self, - actor_params: Mapping[str, Any], - dist_params: Mapping[str, Any], - num_cpus_per_actor: int, + def __init__(self, actor_params: Mapping[str, Any], + dist_params: Mapping[str, Any], num_cpus_per_actor: int, num_gpus_per_actor: int, initialization_hook: Optional[Callable]): - super(DataParallelGroup, self).__init__(actor_params=actor_params, - dist_params=dist_params, - num_cpus_per_actor=num_cpus_per_actor, - num_gpus_per_actor=num_gpus_per_actor, - initialization_hook=initialization_hook) + super(DataParallelGroup, self).__init__( + actor_params=actor_params, + dist_params=dist_params, + num_cpus_per_actor=num_cpus_per_actor, + num_gpus_per_actor=num_gpus_per_actor, + initialization_hook=initialization_hook) self.is_server = self._dist_params["is_server"] self.num_ps = self._dist_params["num_ps"] self.num_worker = self._dist_params["num_worker"] @@ -595,8 +609,7 @@ def _setup_collective_group(self, num_worker=num_worker, num_ps=num_ps, backend=backend, - group_name=group_name - ) + group_name=group_name) for i, actor in enumerate(self._distributed_actors) ] else: # this can be extend for allreduce. 
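The split_gradients and split_parameters helpers above shard a dict of named tensors according to the per-parameter assignments list. A minimal sketch of that behavior with illustrative values (not part of the patch):

import numpy as np

def split_by_assignment(named_tensors, assignments):
    # One shard per server; parameter i goes to shard assignments[i].
    num_shards = np.unique(np.array(assignments)).size
    shards = [dict() for _ in range(num_shards)]
    for i, (k, v) in enumerate(named_tensors.items()):
        shards[assignments[i]][k] = v
    return shards

# With two servers and assignments [0, 1, 0], server 0 holds w0 and w2.
grads = {"w0": 0.1, "w1": 0.2, "w2": 0.3}
assert split_by_assignment(grads, [0, 1, 0]) == [{"w0": 0.1, "w2": 0.3}, {"w1": 0.2}]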
@@ -625,8 +638,9 @@ def _start_actors(self, num_actors: int): ] # setup the rank and group in each replica - ray.get(self._setup_collective_group( - self.num_ps, self.num_worker, self.backend, self.group_name)) + ray.get( + self._setup_collective_group(self.num_ps, self.num_worker, + self.backend, self.group_name)) def test_connection(self): rets = [ diff --git a/examples/jax/mnist_jax_example.py b/examples/jax/mnist_jax_example.py index db7abe2..1a189b3 100644 --- a/examples/jax/mnist_jax_example.py +++ b/examples/jax/mnist_jax_example.py @@ -8,7 +8,6 @@ from distml.strategy.allreduce_strategy import AllReduceStrategy from distml.strategy.ps_strategy import ParameterServerStrategy - from ray.util.sgd.utils import override from jax import random From 43498f8af7a901fc98ac3891095110efb1f8fabd Mon Sep 17 00:00:00 2001 From: Ezra-H Date: Tue, 8 Jun 2021 01:00:23 +0800 Subject: [PATCH 11/12] strategy load/save states --- distml/operator/base_operator.py | 8 +- distml/operator/jax_operator.py | 153 ++++++++++++++++--- distml/operator/torch_operator.py | 106 ++++++++++++- distml/strategy/allreduce_strategy.py | 106 ++++++++----- distml/strategy/base_strategy.py | 36 ++++- distml/strategy/ps_strategy.py | 206 +++++++++++++++++--------- examples/jax/mnist_jax_example.py | 7 + 7 files changed, 485 insertions(+), 137 deletions(-) diff --git a/distml/operator/base_operator.py b/distml/operator/base_operator.py index 892134b..9af342d 100644 --- a/distml/operator/base_operator.py +++ b/distml/operator/base_operator.py @@ -1,6 +1,7 @@ """Abstract class for framework-specific training operators.""" from abc import ABCMeta from abc import abstractmethod +from typing import Optional class TrainingOperator(metaclass=ABCMeta): @@ -90,7 +91,7 @@ def load_custom_states(self, states, *args, **kwargs): pass @abstractmethod - def save_states(self, checkpoint): + def save_states(self, checkpoint: str): """Save the states to a file path. This function shall be instantiated in framework-specific operator @@ -104,7 +105,10 @@ def get_states(self): raise NotImplementedError() @abstractmethod - def load_states(self, checkpoint): + def load_states(self, + states=None, + checkpoint: Optional[str] = None, + keys: Optional[bool] = None): """Load the states from a file path. This functions shall be instantiated in framework-specific operators diff --git a/distml/operator/jax_operator.py b/distml/operator/jax_operator.py index 2ae070b..6feff06 100644 --- a/distml/operator/jax_operator.py +++ b/distml/operator/jax_operator.py @@ -1,4 +1,8 @@ -from typing import Any, Mapping, Optional +import os +import pickle +import warnings + +from typing import Any, Mapping, Optional, List, Dict import numpy as np import cupy as cp @@ -16,7 +20,7 @@ class JAXTrainingOperator(TrainingOperator): - def __init__(self, operator_config: Optional[Mapping[str, Any]]): + def __init__(self, *, operator_config: Optional[Mapping[str, Any]]): super(JAXTrainingOperator, self).__init__(operator_config) # Should be set by users in the `register` function. # model methods @@ -29,11 +33,14 @@ def __init__(self, operator_config: Optional[Mapping[str, Any]]): self.get_params = None self.criterion = None + self.lr_scheduler = None # Data loaders for training and validation, registered by users. 
self._train_loader = None self._validation_loader = None + self._custom_states = None + self.setup(operator_config) if hasattr(operator_config, "jit_mode"): @@ -267,15 +274,15 @@ def validate_batch(self, batch): targets_class = jnp.argmax(targets, axis=1) acc = jnp.mean(prediction_class == targets_class) - samples_num = targets.shape[0] + num_sample = targets.shape[0] return { "val_loss": loss.item(), "val_accuracy": acc.item(), - "samples_num": samples_num + "num_sample": num_sample } - def get_parameters(self, cpu: bool): + def get_parameters(self, cpu: bool) -> List: """get the flatten parameters.""" params = self.get_params(self.opt_state) flatten_params, tree = tree_flatten(params) @@ -284,9 +291,11 @@ def get_parameters(self, cpu: bool): if cpu: flatten_params = list(map(np.asarray, flatten_params)) + else: + flatten_params = list(map(jnp.asarray, flatten_params)) return flatten_params - def get_named_parameters(self, cpu: bool): + def get_named_parameters(self, cpu: bool) -> Dict: """Get the named parameters. In jax, we need to construct a dict to contain the parameters. @@ -299,6 +308,7 @@ def get_named_parameters(self, cpu: bool): } else: dict_params = {f"{idx}": p for idx, p in enumerate(params)} + return dict_params # TODO(HUI): used in load states or load parameters @@ -312,6 +322,9 @@ def set_parameters(self, new_params): """ assert isinstance(new_params, dict) + # make sure all params in GPU. Should be controlled of use_gpu. + new_params = {k: jax.device_put(v) for k, v in new_params.items()} + keys, new_params = unzip2( sorted(new_params.items(), key=lambda d: int(d[0]))) self.preset_keys = keys @@ -349,6 +362,8 @@ def reset_optimizer_for_params(self, params): "Got {}".format(type(params))) keys, params = unzip2(sorted(params.items(), key=lambda d: int(d[0]))) + + self.preset_keys = keys # The keys to index the params. self.tree = tree_structure(params) self.opt_state = self.opt_init(params) @@ -383,25 +398,117 @@ def asarray(self, v): return jnp.asarray(v) def clean_redundancy(self): - del self._train_loader - del self._validation_loader + if self._train_loader: + del self._train_loader + self._train_loader = None + if self._validation_loader: + del self._validation_loader + self._validation_loader = None - # TODO(HUI): use pickle to serialize parameters or states and save it. 
- def save_parameters(self, checkpoint: str): - raise NotImplementedError( - "save_parameters is not support in jax operator.") + def register_custom_states(self, custom_states): + self._custom_states = custom_states - def load_parameters(self, checkpoint: str): - raise NotImplementedError( - "load_parameters is not support in jax operator.") + def get_custom_states(self): + return self._custom_states - def save_states(self, checkpoint: str): - raise NotImplementedError( - "save_states is not support in jax operator.") + def get_states(self) -> Dict: + """Return the states of this training operator.""" + + states_flat, tree, subtrees = self.opt_state - def get_states(self): - raise NotImplementedError("get_states is not support in jax operator.") + states_unflat = map(tree_unflatten, subtrees, states_flat) - def load_states(self, checkpoint: str): - raise NotImplementedError( - "load_states is not support in jax operator.") + states_unflat_dict = { + str(idx): value + for idx, value in enumerate(states_unflat) + } + + states = { + "opt_state": states_unflat_dict, + } + + if self._custom_states: + states.update({"custom": self.get_custom_states()}) + + if self.lr_scheduler and hasattr(self.lr_scheduler, + "get_state_dict"): + states.update({"lr_scheduler": self.lr_scheduler.get_state_dict()}) + + return states + + def save_states(self, checkpoint: str): + states = self.get_states() + with open(checkpoint, "wb") as f: + pickle.dump(states, f) + + def load_states(self, + states=None, + checkpoint: Optional[str] = None, + keys: Optional[bool] = None): + if checkpoint: + assert ".pkl" in checkpoint, \ + "checkpoint should be a .pkl file. Got {}".format(checkpoint) + if not os.path.exists(checkpoint): + raise RuntimeError("Checkpoint file doesn't exist.") + with open(checkpoint, "rb") as f: + states = pickle.load(f) + + if states: + new_opt_states = states.get("opt_state", None) + custom_states = states.get("custom", None) + lr_scheduler_states = states.get("lr_scheduler", None) + + if not new_opt_states: + raise RuntimeError("No opt_state found in the given states.") + + assert isinstance(new_opt_states, dict) + + if not keys: + keys = tuple([ + str(idx) + for idx in range(len(self.get_parameters(cpu=False))) + ]) + else: + # construct_opt_states_dict = OrderedDict() + construct_opt_states_dict = dict() + for key in keys: + construct_opt_states_dict[key] = new_opt_states[key] + new_opt_states = construct_opt_states_dict + + new_keys, new_opt_states = unzip2( + sorted(new_opt_states.items(), key=lambda d: int(d[0]))) + + keys = tuple(keys) + new_keys = tuple(new_keys) + assert keys == new_keys, \ + "checkpoint key doesn't match the model params." + + states_flat, tree, subtrees = self.opt_state + states_flat_2, subtrees_2 = unzip2( + map(tree_flatten, new_opt_states)) + + if not subtrees_2: + raise RuntimeError("subtrees of new params is empty.") + for idx, (subtree, subtree_2) in enumerate( + zip(subtrees, subtrees_2)): + if subtree_2 != subtree: + msg = ("input structure did not match the save params " + "structure. 
input {} and output {}.") + raise TypeError(msg.format(subtree, subtree_2)) + + self.opt_state = OptimizerState(states_flat_2, tree, subtrees_2) + + if custom_states: + self._custom_states.update(custom_states) + + if lr_scheduler_states: + if hasattr(self.lr_scheduler, "set_states_dict"): + self.lr_scheduler.set_states_dict(lr_scheduler_states) + else: + warnings.warn( + "lr scheduler must have `set_states_dict` method" + " to support loading lr scheduler states.") + else: + raise RuntimeError("This checkpoint is empty. " + "Got checkpoint {}, states {}".format( + checkpoint, states)) diff --git a/distml/operator/torch_operator.py index 3ddd667..d0a98f8 100644 --- a/distml/operator/torch_operator.py +++ b/distml/operator/torch_operator.py @@ -168,9 +168,55 @@ def validate_batch(self, batch): loss = criterion(output, target) # Todo(Hao): report accuracy instead loss here. - batch_metric = {"val_loss": loss.item()} + batch_metric = {"val_loss": loss.item(), "num_sample": target.size(0)} return batch_metric + def get_named_parameters(self, cpu): + named_params = self._model.named_parameters() + is_cuda = next(self._model.parameters()).is_cuda + output_params = {} + + if cpu: + if is_cuda: + for key, p in named_params: + output_params[key] = p.cpu() + else: + for key, p in named_params: + output_params[key] = p + else: + if not is_cuda: + for key, p in named_params: + # TODO(HUI): should put in specific device. + output_params[key] = p.cuda() + else: + for key, p in named_params: + output_params[key] = p + + return output_params + + def get_parameters(self, cpu): + params = self._model.parameters() + is_cuda = next(self._model.parameters()).is_cuda + output_params = [] + + if cpu: + if is_cuda: + for p in params: + output_params.append(p.cpu()) + else: + for p in params: + output_params.append(p) + else: + if not is_cuda: + for p in params: + # TODO(HUI): should put in specific device. + output_params.append(p.cuda()) + else: + for p in params: + output_params.append(p) + + return output_params + def get_states(self): """Return the states of this training operator.""" states = { @@ -196,12 +242,47 @@ def load_states(self, states=None, checkpoint=None): self._lr_scheduler.load_state_dict(states["lr_scheduler"]) self.load_custom_states(states["custom"]) + def _load_from_checkpoint(self, checkpoint): + return torch.load(checkpoint) + def save_states(self, checkpoint): """Save the states to a file path.""" states = self.get_states() # TODO(Hao): test this. torch.save(states, checkpoint) + def clean_redundancy(self): + del self._train_loader + del self._validation_loader + + def set_parameters(self, params): + if isinstance(params, dict): + self._model.load_state_dict(params) + else: + raise RuntimeError("params is not a dict. " + "Got {}".format(type(params))) + + def reset_optimizer_for_params(self, params): + if isinstance(params, dict): + params_list = [] + + for k, v in params.items(): + params_list.append(v) + params = params_list + + _optimizer = self._optimizer + + _optimizer.param_groups = [] + + param_groups = list(params) + if len(param_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + if not isinstance(param_groups[0], dict): + param_groups = [{"params": param_groups}] + + for param_group in param_groups: + _optimizer.add_param_group(param_group) + @staticmethod def _get_gradients(model): """Return the gradient updates of the model as a Python dict. 
@@ -241,3 +322,26 @@ def _set_gradients(model, grads): # to(p.grad.device) # else: # p.grad = torch.from_numpy(gradients[name]) + + def ones(self, shape, cpu: bool = True): + tensor = torch.ones(shape) + return tensor if cpu else tensor.cuda() + + def zeros(self, shape, cpu: bool = True): + tensor = torch.zeros(shape) + return tensor if cpu else tensor.cuda() + + def ones_like(self, x, cpu: bool = True): + tensor = torch.ones_like(x) + return tensor if cpu else tensor.cuda() + + def zeros_like(self, x, cpu: bool = True): + tensor = torch.zeros_like(x) + return tensor if cpu else tensor.cuda() + + @staticmethod + def numel(tensor): + return tensor.numel() + + def asarray(self, v): + return torch.as_tensor(v) diff --git a/distml/strategy/allreduce_strategy.py b/distml/strategy/allreduce_strategy.py index dc557a8..1cc3d82 100644 --- a/distml/strategy/allreduce_strategy.py +++ b/distml/strategy/allreduce_strategy.py @@ -1,13 +1,14 @@ import logging -from typing import Callable, Mapping, Any, Optional, Dict +from typing import List, Callable, Mapping, Any, Optional, Dict import ray import ray.util.collective as col -from distml.strategy.base_strategy import BaseStrategy, BaseDataParallelGroup -from distml.util import ThroughputCollection +from ray.util.sgd.utils import AverageMeterCollection import numpy as np +from distml.strategy.base_strategy import BaseStrategy, BaseDataParallelGroup + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -49,13 +50,11 @@ def __init__(self, num_gpus_per_worker=num_gpus_per_worker, **kwargs) self._global_batch_size = None + if operator_config and operator_config.get("batch_size"): self._global_batch_size = operator_config.get("batch_size") - if self._global_batch_size: - self._collector = ThroughputCollection( - batch_size=self._global_batch_size) - else: - self._collector = ThroughputCollection() + + self._init_strategy() def train(self, num_steps: Optional[int] = None) -> Dict: """Run the training on parallel workers. @@ -65,8 +64,10 @@ def train(self, num_steps: Optional[int] = None) -> Dict: function will simply train for one epoch. Returns: - None + metric (dict): metric of the training set. """ + # TODO(HUI): metric use hook to control. + # TODO (Hao): add fault tolerance using `max_retries`. steps = num_steps if num_steps \ else self.data_parallel_group.get_data_loader_len() @@ -74,10 +75,9 @@ def train(self, num_steps: Optional[int] = None) -> Dict: # TODO(Hao): this call should be hidden inside Replica. self.data_parallel_group.make_iterator() for idx in range(steps): - with self._collector.record("train"): - metrics = self.data_parallel_group.train_batch() + metric = self.data_parallel_group.train_batch() print("Step: {}/{}".format(idx, steps)) - return metrics + return metric def validate(self, num_steps: Optional[int] = None) -> Dict: """Evaluates the model on the validation data. @@ -86,18 +86,35 @@ def validate(self, num_steps: Optional[int] = None) -> Dict: num_steps (int): number of batches to evaluate. If None, the function will simply validate across the entire validation dataset. + + Returns: + metric (dict): metric of the validate set. 
""" steps = num_steps if num_steps \ else self.data_parallel_group.get_data_loader_len(training=False) + + metrics = [ + AverageMeterCollection() + for _ in range(len(self.data_parallel_group.replicas)) + ] + self.data_parallel_group.make_iterator(training=False) for idx in range(steps): - with self._collector.record("validate"): - batch_metrics = self.data_parallel_group.validate_batch() - self._collector.update( - "validate", val_acc=batch_metrics[0]["val_loss"]) - self._collector.save("validate") + batch_metrics = self.data_parallel_group.validate_batch() + + for metric_idx, metric in enumerate(batch_metrics): + num_sample = metric.pop("num_sample") + metrics[metric_idx].update(metric, n=num_sample) + # TODO: validate result should be the same in all workers - return batch_metrics + return metrics[0].summary() + + def _init_strategy(self): + """Do initialization for the distributed strategy.""" + # All sync with replica 0 + init_weights = self.data_parallel_group.get_named_parameters(cpu=True) + # all replicas get synced + self.data_parallel_group.set_parameters(init_weights) def _start_workers(self): """Create distributed workers on the Ray cluster for distributed training. @@ -132,14 +149,14 @@ def _start_workers(self): def shutdown(self, force: bool = False): self.data_parallel_group.shutdown(force=force) - def save_parameters(self, checkpoint: str): - self.data_parallel_group.save_parameters(checkpoint) + def get_states(self): + return self.data_parallel_group.get_states() - def load_parameters(self, checkpoint: str): - self.data_parallel_group.load_parameters(checkpoint) + def save_states(self, checkpoint: str): + self.data_parallel_group.save_states(checkpoint) - def _init_strategy(self): - pass + def load_states(self, states=None, checkpoint: Optional[str] = None): + self.data_parallel_group.load_states(states, checkpoint) class Replica: @@ -205,7 +222,7 @@ def train_batch(self) -> Dict: metrics = {} try: batch = next(self.train_iterator) - except StopIteration and NameError: + except StopIteration or NameError: self.make_iterator() batch = next(self.train_iterator) loss_val, updates = self.derive_updates(batch) @@ -214,7 +231,7 @@ def train_batch(self) -> Dict: metrics["train_loss"] = loss_val for _, g in updates.items(): cg = self.training_operator.to_cupy(g) - col.allreduce(cg) + col.allreduce(cg, self.group_name) cg = cg / float(self.world_size) self.apply_updates(updates) return metrics @@ -231,7 +248,7 @@ def updates_transform(self, updates): def validate_batch(self): try: batch = next(self.validation_iterator) - except StopIteration and NameError: + except StopIteration or NameError: self.make_iterator(training=False) batch = next(self.validation_iterator) batch_metric = self.training_operator.validate_batch(batch) @@ -244,11 +261,23 @@ def shutdown(self): del self.training_operator return 1 - def save_parameters(self, checkpoint: str): - self.training_operator.save_parameters(checkpoint) + def get_states(self): + return self.training_operator.get_states() + + def save_states(self, checkpoint: str): + self.training_operator.save_states(checkpoint) + + def load_states(self, states=None, checkpoint: Optional[str] = None): + self.training_operator.load_states(states, checkpoint) - def load_parameters(self, checkpoint: str): - self.training_operator.load_parameters(checkpoint) + def get_parameters(self, cpu: bool) -> List: + return self.training_operator.get_parameters(cpu) + + def get_named_parameters(self, cpu: bool) -> Dict: + return 
self.training_operator.get_named_parameters(cpu) + + def set_parameters(self, params): + self.training_operator.set_parameters(params) def apply(self, fn: Callable): """Apply a function in the replica process.""" @@ -288,7 +317,6 @@ def __init__(self, actor_params: Mapping[str, Any], num_cpus_per_actor=num_cpus_per_actor, num_gpus_per_actor=num_gpus_per_actor, initialization_hook=initialization_hook) - self._replicas = None @property def _replica_params(self): @@ -370,14 +398,18 @@ def shutdown(self, force: bool = False): def reset(self): pass - def save_parameters(self, checkpoint: str): - rets = [self.replicas[0].save_parameters.remote(checkpoint)] + def get_states(self): + rets = [self.replicas[0].get_states.remote()] + return ray.get(rets)[0] + + def save_states(self, checkpoint: str): + rets = [self.replicas[0].save_states.remote(checkpoint)] ray.get(rets) - def load_parameters(self, checkpoint: str): + def load_states(self, states=None, checkpoint: Optional[str] = None): rets = [ - replica.load_parameters.remote(checkpoint) - for _, replica in enumerate(self.replicas) + replica.load_states.remote(states, checkpoint) + for replica in self.replicas ] ray.get(rets) diff --git a/distml/strategy/base_strategy.py index 4ad1e80..3f31237 100644 --- a/distml/strategy/base_strategy.py +++ b/distml/strategy/base_strategy.py @@ -1,7 +1,7 @@ from abc import ABCMeta from abc import abstractmethod import logging -from typing import Callable, Any, Mapping, Optional +from typing import Callable, Any, Mapping, Optional, Sequence import ray @@ -52,7 +52,7 @@ def validate(self): raise NotImplementedError() @abstractmethod - def save_parameters(self, checkpoint: str): + def save_states(self, checkpoint: str): """Saves the Trainer state to the provided checkpoint path. Args: @@ -61,11 +61,17 @@ raise NotImplementedError() @abstractmethod - def load_parameters(self, checkpoint: str): - """Loads the Trainer state to the provided checkpoint path. + def load_states(self, + states=None, + checkpoint: Optional[str] = None, + keys: Optional[Sequence[str]] = None): + """Loads the Trainer state from the provided states or checkpoint path. Args: + states: States to load. checkpoint (str): Path to target checkpoint file. + keys (Sequence[str]): Keys of the params to load. + If None, all states are loaded. """ raise NotImplementedError() @@ -157,13 +163,27 @@ def reset(self): raise NotImplementedError() @abstractmethod - def save_parameters(self, checkpoint: str): - """Let the first actor save parameters.""" + def save_states(self, checkpoint: str): + """Saves the Trainer state to the provided checkpoint path. + + Args: + checkpoint (str): Path to target checkpoint file. + """ raise NotImplementedError() @abstractmethod - def load_parameters(self, checkpoint: str): - """All actor load parameters from checkpoint.""" + def load_states(self, + states=None, + checkpoint: Optional[str] = None, + keys: Optional[Sequence[str]] = None): + """Loads the Trainer state from the provided states or checkpoint path. + + Args: + states: States to load. + checkpoint (str): Path to target checkpoint file. + keys (Sequence[str]): Keys of the params to load. + If None, all states are loaded. 
+ """ raise NotImplementedError() @abstractmethod diff --git a/distml/strategy/ps_strategy.py b/distml/strategy/ps_strategy.py index a3486b8..6709069 100644 --- a/distml/strategy/ps_strategy.py +++ b/distml/strategy/ps_strategy.py @@ -1,14 +1,14 @@ +import logging from typing import List, Callable, Mapping, Any, Optional, Sequence, Dict + import ray import ray.util.collective as col +from ray.util.sgd.utils import AverageMeterCollection import numpy as np import distml.util as util from distml.strategy.base_strategy import BaseStrategy, BaseDataParallelGroup -from distml.util import ThroughputCollection - -import logging logger = logging.getLogger(__name__) @@ -72,11 +72,6 @@ def __init__(self, if operator_config and operator_config.get("batch_size"): self._global_batch_size = operator_config.get("batch_size") - if self._global_batch_size: - self._collector = ThroughputCollection( - batch_size=self._global_batch_size) - else: - self._collector = ThroughputCollection() def _init_strategy(self): """Do initialization for the distributed strategy.""" @@ -171,16 +166,28 @@ def shutdown(self, force: bool = False): self.worker_group.shutdown(force=force) self.server_group.shutdown(force=force) - def save_parameters(self, checkpoint: str): - # TODO(HUI): ps save parameters. - # First, worker rank 0 should pull the latest parameter from servers - # Then, worker rank 0 save parameters - self.worker_group.save_parameters(checkpoint) + def get_states(self): + # worker0 pull latest params and return states. + for server_idx, server in enumerate(self.server_group.actors): + # every server sends its shard to the worker0 + server.send_params.remote(0) + # the worker0 receives shards from ps. + ret = self.worker_group.actors[0].recv_params.remote() + ray.get([ret]) - def load_parameters(self, checkpoint: str): - # TODO(HUI): ps load parameters. - # shard parameters and send to all servers. - self.server_group.load_parameters(checkpoint) + return self.worker_group.get_states() + + def save_states(self, checkpoint: str): + # worker0 pull latest params. + for server_idx, server in enumerate(self.server_group.actors): + server.send_params.remote(0) + ret = self.worker_group.actors[0].recv_params.remote() + ray.get([ret]) + # Then, worker0 save parameters. + self.worker_group.save_states(checkpoint) + + def load_states(self, states=None, checkpoint: Optional[str] = None): + self.server_group.load_states(states=states, checkpoint=checkpoint) def _round_robin_sharding(self): """Generate the assignment of variable to servers.""" @@ -203,7 +210,7 @@ def train(self, num_steps: Optional[int] = None) -> Dict: function will simply train for one epoch. Returns: - None + metrics (dict): metrics of training result. """ # TODO (Hao): add fault tolerance using `max_retries`. steps = num_steps if num_steps \ @@ -214,8 +221,7 @@ def train(self, num_steps: Optional[int] = None) -> Dict: # train one epoch self.worker_group.make_iterator() for idx in range(steps): - with self._collector.record("train"): - metrics = self.train_batch() + metrics = self.train_batch() print("Step: {}/{}".format(idx, steps)) return metrics @@ -231,11 +237,31 @@ def validate(self, num_steps: Optional[int] = None) -> Dict: else self.worker_group.get_data_loader_len(training=False) self.worker_group.make_iterator(training=False) + # Worker group pull latest params. 
+ rets = [] + for worker_idx, worker in enumerate(self.worker_group.actors): + for server_idx, server in enumerate(self.server_group.actors): + # every server sends its shard to the worker + server.send_params.remote(worker_idx) + # the worker receives shards from ps, compute loss, gradients + # and sends these gradients to every server + ret = worker.recv_params.remote() + rets.append(ret) + ray.get(rets) + + metrics = [ + AverageMeterCollection() + for _ in range(len(self.worker_group.actors)) + ] + # TODO(HUI): Construct a better tool to save validate results. for idx in range(steps): batch_metrics = self.worker_group.validate_batch() + for metric_idx, metric in enumerate(batch_metrics): + num_sample = metric.pop("num_sample") + metrics[metric_idx].update(metric, n=num_sample) # Validate results should be the same in all workers - return batch_metrics + return metrics[0].summary() def train_batch(self) -> Dict: loss_vals = [] @@ -272,9 +298,9 @@ def __init__(self, training_operator_cls, self.params = dict() def setup_operator(self): - # figure out the signature of training_operator_cls later. + """Instantiate the training operator.""" self.training_operator = self.training_operator_cls( - self.operator_config) + operator_config=self.operator_config) def setup_collective_group(self, rank: int, @@ -282,7 +308,7 @@ def setup_collective_group(self, num_worker: int, backend: str = "nccl", group_name: str = "default"): - # rank should be true rank. means, rank has already plus num_worker. + # rank has already plus num_worker. self.rank = rank self.num_ps = num_ps self.num_worker = num_worker @@ -293,15 +319,18 @@ def setup_collective_group(self, col.init_collective_group( num_ps + num_worker, rank, backend=backend, group_name=group_name) + def apply(self, fn: Callable): + """Apply a function in the replica process.""" + return fn() + def test_connection(self): for i in range(self.num_worker): recv = util.zeros((1, ), cpu=False) - col.recv(recv, i, self.group_name) + col.recv(recv, i, group_name=self.group_name) assert recv == 1 for i in range(self.num_worker): send = util.ones((1, ), cpu=False) - col.send(send, i, self.group_name) - return + col.send(send, i, group_name=self.group_name) def _init_grad_counts(self): self.grad_counts = [0] * self.num_worker @@ -324,8 +353,21 @@ def set_params(self, params): self.training_operator.reset_optimizer_for_params(self.params) self._init_grad_buffer() + def load_states(self, states=None, checkpoint: Optional[str] = None): + self.training_operator.load_states( + states=states, + checkpoint=checkpoint, + keys=tuple(self.params.keys())) + + # # Update the params in actor aspect. 
+        latest_params = self.training_operator.get_named_parameters(cpu=False)
+
+        assert self.params.keys() == latest_params.keys()
+
+        for key in latest_params.keys():
+            self.params[key] = latest_params[key]
+
     def apply_updates(self, grad_buffer):
-        # TODO(HUI): gradient divide by num_worker
         self.training_operator.apply_updates(grad_buffer)
         self.params = self.training_operator.get_named_parameters(cpu=False)
 
@@ -338,7 +380,7 @@ def send_params(self, dst_rank: int):
         """ Send this param shard to the destination worker """
         for name, v in self.params.items():
             cv = self.training_operator.to_cupy(v)
-            col.send(cv, dst_rank, self.group_name)
+            col.send(cv, dst_rank, group_name=self.group_name)
 
     def update(self, src_rank: int):
         """Receive gradients and update"""
@@ -396,7 +438,7 @@ def __init__(self, training_operator_cls,
     def setup_operator(self):
         # figure out the signature of training_operator_cls later.
         self.training_operator = self.training_operator_cls(
-            self.operator_config)
+            operator_config=self.operator_config)
 
     def setup_collective_group(self,
                                rank: int,
@@ -415,13 +457,17 @@ def setup_collective_group(self,
         col.init_collective_group(
             num_ps + num_worker, rank, backend=backend, group_name=group_name)
 
+    def apply(self, fn: Callable):
+        """Apply a function in the replica process."""
+        return fn()
+
     def test_connection(self):
         for i in range(self.num_ps):
             send = util.ones((1, ), cpu=False)
-            col.send(send, self.num_worker + i, self.group_name)
+            col.send(send, self.num_worker + i, group_name=self.group_name)
         for i in range(self.num_ps):
             recv = util.zeros((1, ), cpu=False)
-            col.recv(recv, self.num_worker + i, self.group_name)
+            col.recv(recv, self.num_worker + i, group_name=self.group_name)
             assert recv == 1
         return
 
@@ -458,16 +504,15 @@ def derive_updates(self, batch: Sequence[Any]) -> Dict:
         # TODO (Hao): handling data loader next.
         return self.training_operator.derive_updates(batch)
 
-    def compute_gradients(self, params):
+    def compute_gradients(self):
         """ Update worker parameters that received from server.
 
             Compute gradients and return named gradients.
         """
-        self.set_parameters(params)
         try:
             batch = next(self.training_iterator)
-        except StopIteration and NameError:
+        except (StopIteration, NameError):
             self.make_iterator()
             batch = next(self.training_iterator)
 
@@ -479,6 +524,7 @@ def compute_gradients(self, params):
         return loss_val, grads
 
     def split_gradients(self, grad, assignments) -> List:
+        """Split gradients into per-server shards according to assignments."""
         # assuming messages are gradients or parameters
         # this grad is ready to be called by apply_gradients in ParameterServer
         num_shards = np.unique(np.array(assignments)).size
@@ -488,6 +534,7 @@ def split_gradients(self, grad, assignments) -> List:
         return shards
 
     def split_parameters(self, assignments) -> List:
+        """Split parameters into per-server shards according to assignments."""
        params = self.get_named_parameters(cpu=False)
         num_shards = np.unique(np.array(assignments)).size
         shards = [dict() for i in range(num_shards)]
@@ -498,6 +545,34 @@ def split_parameters(self, assignments) -> List:
     def index_shard(self, shards, index: int):
         return shards[index]
 
+    def recv_params(self):
+        weights = self.get_named_parameters(cpu=False)
+        params = dict()
+
+        # 1. Create the receive lists to group collective calls
+        recv_list = []
+        for i in range(self.num_ps):
+            recv_list.append([])
+            param_shard_keys = self.name_list[i]
+            for key in param_shard_keys:
+                to_recv = weights[key]
+                recv_list[-1].append(
+                    self.training_operator.ones(to_recv.shape, cpu=False))
+
+        # 2.
Receive params from servers + for i in range(self.num_ps): + for j in range(len(self.name_list[i])): + v = self.training_operator.to_cupy(recv_list[i][j]) + col.recv(v, self.num_worker + i, group_name=self.group_name) + + # 3. Set params in workers. + for i in range(self.num_ps): + param_shard_keys = self.name_list[i] + for j in range(len(param_shard_keys)): + params[param_shard_keys[j]] = recv_list[i][j] + + self.set_parameters(params) + def set_parameters(self, params): self.training_operator.set_parameters(params) @@ -512,6 +587,12 @@ def get_gradients(self): # when derive_updates. return self.training_operator.get_gradients() + def get_states(self): + return self.training_operator.get_states() + + def save_states(self, checkpoint: str): + self.training_operator.save_states(checkpoint) + def set_assignments(self, assignments): self.assignments = assignments keys = list(self.get_named_parameters(cpu=False).keys()) @@ -522,41 +603,18 @@ def compute(self): """Returns the loss, and send gradients to servers""" metrics = {} - weights = self.get_named_parameters(cpu=False) - params = dict() + self.recv_params() - # 1. Create the receive lists to group collective calls - recv_list = [] - for i in range(self.num_ps): - recv_list.append([]) - param_shard_keys = self.name_list[i] - for key in param_shard_keys: - to_recv = weights[key] - recv_list[-1].append( - self.training_operator.ones(to_recv.shape, cpu=False)) - - # 2. Receive params from servers - for i in range(self.num_ps): - for j in range(len(self.name_list[i])): - v = self.training_operator.to_cupy(recv_list[i][j]) - col.recv(v, self.num_worker + i, self.group_name) - - # 3. Set params in workers and compute gradients. - for i in range(self.num_ps): - param_shard_keys = self.name_list[i] - for j in range(len(param_shard_keys)): - params[param_shard_keys[j]] = recv_list[i][j] - - loss_val, grad = self.compute_gradients(params) + loss_val, grad = self.compute_gradients() metrics["train_loss"] = loss_val - # 4. Shard gradients and send to servers. + # Shard gradients and send to servers. 
split_grad = self.split_gradients(grad, self.assignments) for i in range(self.num_ps): this_shard = self.index_shard(split_grad, i) for _, v in this_shard.items(): cv = self.training_operator.to_cupy(v) - col.send(cv, self.num_worker + i, self.group_name) + col.send(cv, self.num_worker + i, group_name=self.group_name) return metrics def validate_batch(self): @@ -637,6 +695,10 @@ def _start_actors(self, num_actors: int): RemoteActor.remote(**self._actor_params) for _ in range(num_actors) ] + # apply init_hook + if self._initialization_hook: + self.apply_all_replicas(self._initialization_hook) + # setup the rank and group in each replica ray.get( self._setup_collective_group(self.num_ps, self.num_worker, @@ -656,6 +718,14 @@ def set_assignments(self, assignments): ] return rets + def apply_all_replicas(self, fn: Callable): + """Apply fn in all replica processes and wait until completion.""" + return ray.get(self._apply_all_replicas(fn)) + + def _apply_all_replicas(self, fn): + """Apply a function fn in all replica processes.""" + return [actor.apply.remote(fn) for actor in self.actors] + def _make_iterator(self, training: bool): return [actor.make_iterator.remote(training) for actor in self.actors] @@ -695,13 +765,17 @@ def reset(self): def actors(self): return self._distributed_actors - def save_parameters(self, checkpoint: str): - rets = [self.actors[0].save_parameters.remote(checkpoint)] + def get_states(self): + ret = self.actors[0].get_states.remote() + return ray.get([ret])[0] + + def save_states(self, checkpoint: str): + rets = [self.actors[0].save_states.remote(checkpoint)] ray.get(rets) - def load_parameters(self, checkpoint: str): + def load_states(self, states=None, checkpoint: Optional[str] = None): rets = [ - actor.load_parameters.remote(checkpoint) + actor.load_states.remote(states=states, checkpoint=checkpoint) for _, actor in enumerate(self.actors) ] ray.get(rets) diff --git a/examples/jax/mnist_jax_example.py b/examples/jax/mnist_jax_example.py index 1a189b3..324a99a 100644 --- a/examples/jax/mnist_jax_example.py +++ b/examples/jax/mnist_jax_example.py @@ -16,6 +16,7 @@ from jax_util.resnet import ResNet18, ResNet50, ResNet101 from jax_util.datasets import mnist, Dataloader +import numpy as np def initialization_hook(): # Need this for avoiding a connection restart issue on AWS. @@ -55,6 +56,12 @@ def setup(self, config): with FileLock(".ray.lock"): train_images, train_labels, test_images, test_labels = mnist() + if config.get("test_mode", False): + train_images = np.random.choice(train_images, 1000) + train_labels = np.random.choice(train_labels, 1000) + test_images = np.random.choice(test_images, 1000) + test_labels = np.random.choice(test_labels, 1000) + train_images = train_images.reshape(train_images.shape[0], 1, 28, 28).transpose(2, 3, 1, 0) test_images = test_images.reshape(test_images.shape[0], 1, 28, From 9a4acb1bc718eeb36be7a49cfede2b657daee7be Mon Sep 17 00:00:00 2001 From: Ezra-H Date: Sat, 19 Jun 2021 22:44:46 +0800 Subject: [PATCH 12/12] callable check --- distml/operator/jax_operator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/distml/operator/jax_operator.py b/distml/operator/jax_operator.py index 6feff06..a738640 100644 --- a/distml/operator/jax_operator.py +++ b/distml/operator/jax_operator.py @@ -103,7 +103,7 @@ def register(self, *, model, optimizer, criterion, jit_mode: bool = False): "'opt_init', 'opt_update' and 'get_params'." 
"Got: {} {}".format(type(optimizer), len(optimizer))) - if not hasattr(criterion, "__call__"): + if not callable(criterion): raise RuntimeError( "The `criterion` must be callable function that " "feed logits and target, return the loss value. " @@ -123,12 +123,12 @@ def _register_model(self, model): "`opt_states` return from optimizer `opt_init`. " "Got: {}".format(type(model[0]))) - if not hasattr(model[1], "__call__"): + if not callable(model[1]): raise RuntimeError("The second elemente of `model` must be the " "`init_fun` return from model. " "Got: {}".format(type(model[1]))) - if not hasattr(model[2], "__call__"): + if not callable(model[2]): raise RuntimeError("The third elemente of `model` must be the " "`predict_fun` return from model. " "Got: {}".format(type(model[2]))) @@ -139,18 +139,18 @@ def _register_model(self, model): def _register_optimizer(self, optimizer): """register optimizer components.""" - if not hasattr(optimizer[0], "__call__"): + if not callable(optimizer[0]): raise RuntimeError("The fist elemente of `optimizer` must be the " "`opt_init` return from optimizer. " "Got: {}".format(type(optimizer[1]))) - if not hasattr(optimizer[1], "__call__"): + if not callable(optimizer[1]): raise RuntimeError( "The second elemente of `optimizer` must be the " "`opt_update` return from optimizer. " "Got: {}".format(type(optimizer[1]))) - if not hasattr(optimizer[2], "__call__"): + if not callable(optimizer[2]): raise RuntimeError("The third elemente of `optimizer` must be the " "`get_params` return from optimizer. " "Got: {}".format(type(optimizer[2])))