From a946f7e6e1bc093049d6edde646fe1d0da8b996c Mon Sep 17 00:00:00 2001 From: Xuezhi-Liang Date: Mon, 4 Apr 2022 14:26:15 +0400 Subject: [PATCH 01/10] stage1 --- adaptdl/adaptdl/torch/torch/__init__.py | 142 +++++ adaptdl/adaptdl/torch/torch/_metrics.py | 199 +++++++ adaptdl/adaptdl/torch/torch/_metrics_test.py | 158 ++++++ adaptdl/adaptdl/torch/torch/accumulator.py | 312 +++++++++++ .../adaptdl/torch/torch/accumulator_test.py | 60 +++ adaptdl/adaptdl/torch/torch/context.py | 98 ++++ adaptdl/adaptdl/torch/torch/data.py | 492 ++++++++++++++++++ adaptdl/adaptdl/torch/torch/data_test.py | 168 ++++++ adaptdl/adaptdl/torch/torch/epoch.py | 178 +++++++ adaptdl/adaptdl/torch/torch/epoch_test.py | 42 ++ .../torch/torch/gradient_noise_scale.py | 330 ++++++++++++ .../torch/torch/gradient_noise_scale_test.py | 60 +++ adaptdl/adaptdl/torch/torch/iterator.py | 121 +++++ adaptdl/adaptdl/torch/torch/parallel.py | 232 +++++++++ adaptdl/adaptdl/torch/torch/parallel_test.py | 67 +++ adaptdl/adaptdl/torch/torch/scaling_rules.py | 200 +++++++ .../adaptdl/torch/torch/scaling_rules_test.py | 253 +++++++++ ...testcase_for_adaptdldataloader_refactor.py | 156 ++++++ 18 files changed, 3268 insertions(+) create mode 100644 adaptdl/adaptdl/torch/torch/__init__.py create mode 100644 adaptdl/adaptdl/torch/torch/_metrics.py create mode 100644 adaptdl/adaptdl/torch/torch/_metrics_test.py create mode 100644 adaptdl/adaptdl/torch/torch/accumulator.py create mode 100644 adaptdl/adaptdl/torch/torch/accumulator_test.py create mode 100644 adaptdl/adaptdl/torch/torch/context.py create mode 100644 adaptdl/adaptdl/torch/torch/data.py create mode 100644 adaptdl/adaptdl/torch/torch/data_test.py create mode 100644 adaptdl/adaptdl/torch/torch/epoch.py create mode 100644 adaptdl/adaptdl/torch/torch/epoch_test.py create mode 100644 adaptdl/adaptdl/torch/torch/gradient_noise_scale.py create mode 100644 adaptdl/adaptdl/torch/torch/gradient_noise_scale_test.py create mode 100644 adaptdl/adaptdl/torch/torch/iterator.py create mode 100644 adaptdl/adaptdl/torch/torch/parallel.py create mode 100644 adaptdl/adaptdl/torch/torch/parallel_test.py create mode 100644 adaptdl/adaptdl/torch/torch/scaling_rules.py create mode 100644 adaptdl/adaptdl/torch/torch/scaling_rules_test.py create mode 100644 tutorial/testcase_for_adaptdldataloader_refactor.py diff --git a/adaptdl/adaptdl/torch/torch/__init__.py b/adaptdl/adaptdl/torch/torch/__init__.py new file mode 100644 index 000000000..c9832e600 --- /dev/null +++ b/adaptdl/adaptdl/torch/torch/__init__.py @@ -0,0 +1,142 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
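+
+# Example (illustrative sketch, not part of the API; assumes `model`,
+# `optimizer`, and `dataset` are defined by the application):
+#
+#     import adaptdl.torch as adl
+#     adl.init_process_group("nccl" if torch.cuda.is_available() else "gloo")
+#     model = adl.AdaptiveDataParallel(model, optimizer)
+#     dataloader = adl.AdaptiveDataLoader(dataset, batch_size=128)
+#     dataloader.autoscale_batch_size(1024)  # Optional adaptive batch size.
+#     for epoch in adl.remaining_epochs_until(30):
+#         for batch in dataloader:
+#             ...  # Forward/backward/step as usual.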
+
+
+import sys
+import os
+if "darwin" in sys.platform.lower():
+    # To avoid multiple runs of the model code
+    # https://pythonspeed.com/articles/python-multiprocessing/
+    import multiprocessing
+    multiprocessing.set_start_method('fork')
+
+import logging
+import portpicker
+import requests
+import torch.distributed
+import pkg_resources
+
+import adaptdl.collective
+import adaptdl.env
+import semver
+from .epoch import current_epoch, finished_epochs, remaining_epochs_until
+from .data import current_dataloader, AdaptiveDataLoader, ElasticSampler
+from .parallel import AdaptiveDataParallel
+from .accumulator import Accumulator
+
+logging.basicConfig(level=logging.INFO)
+LOG = logging.getLogger(__name__)
+LOG.setLevel(logging.INFO)
+
+
+def version_check(version):
+    return semver.VersionInfo.isvalid(version) and version != "0.0.0"
+
+
+def init_process_group(backend,
+                       init_method=None,
+                       world_size=None,
+                       rank=None):
+    """
+    Initializes the default distributed process group and the AdaptDL
+    collectives module.
+
+    Args:
+        backend (str or Backend): The backend to use. Use "nccl" for
+            multi-GPU training, otherwise "gloo".
+        init_method (str, optional): URL specifying how to initialize the
+            process group.
+        world_size (int, optional): Number of processes participating in
+            the job.
+        rank (int, optional): Rank of the current process (it should be a
+            number between 0 and ``world_size``-1).
+
+    If ``init_method``, ``world_size``, and ``rank`` are NOT provided,
+    typically in the Kubernetes environment, AdaptDL will try to infer them
+    through the environment variables ADAPTDL_MASTER_ADDR,
+    ADAPTDL_NUM_REPLICAS and ADAPTDL_REPLICA_RANK respectively.
+    """
+    if adaptdl.env.from_ray():
+        from adaptdl_ray.adaptdl.utils import unique_nodes_pg
+        assert init_method is not None
+        assert world_size is not None
+        assert rank is not None
+        os.environ["ADAPTDL_NUM_NODES"] = str(unique_nodes_pg())
+        os.environ["ADAPTDL_REPLICA_RANK"] = str(rank)
+        os.environ["ADAPTDL_NUM_REPLICAS"] = str(world_size)
+
+    url = adaptdl.env.supervisor_url()
+    master_port = adaptdl.env.master_port()
+    if rank is None:
+        rank = adaptdl.env.replica_rank()
+
+    if world_size is None:
+        world_size = adaptdl.env.num_replicas()
+
+    if init_method is not None:
+        _, master_addr, master_port = init_method.split(":")
+        master_addr = master_addr[2:]
+        master_port = int(master_port)
+    elif url:
+        key = adaptdl.env.job_id()
+        group = adaptdl.env.num_restarts()
+        while True:
+            response = requests.get(url=f"{url}/discover/{key}/{group}")
+            if response.status_code != 408:  # Timeout.
+                break
+        response.raise_for_status()
+        master_addr = response.json()[0]
+        sched_version = adaptdl.env.adaptdl_sched_version()
+        trainer_version = pkg_resources.get_distribution("adaptdl").version
+        if version_check(sched_version) and version_check(trainer_version):
+            trainer_ver_maj = semver.VersionInfo.parse(trainer_version).major
+            sched_ver_maj = semver.VersionInfo.parse(sched_version).major
+            if trainer_ver_maj != sched_ver_maj:
+                raise Exception('adaptdl version {} is incompatible with '
+                                'scheduler version {}'.format(trainer_version,
+                                                              sched_version))
+    else:
+        master_addr = adaptdl.env.master_addr()
+
+    # Initialize collective module.
+    adaptdl.collective.initialize(master_addr,
+                                  master_port,
+                                  rank,
+                                  world_size)
+
+    # Initialize torch.distributed.
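+    # Rank 0 picks a free port with portpicker; the value is broadcast to
+    # all replicas through the collectives module so that every replica
+    # constructs an identical tcp:// init_method below.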
+ torch_port = adaptdl.collective.broadcast(portpicker.pick_unused_port()) + init_method = "tcp://{}:{}?rank={}&world_size={}".format( + master_addr, torch_port, rank, world_size) + LOG.info("Initializing torch.distributed using %s", init_method) + torch.distributed.init_process_group(backend, init_method) + + LOG.info("torch.distributed initialized") + + +__all__ = [ + "init_process_group", + "current_epoch", + "finished_epochs", + "remaining_epochs_until", + "current_dataloader", + "AdaptiveDataLoader", + "ElasticSampler", + "AdaptiveDataParallel", + "Accumulator", +] diff --git a/adaptdl/adaptdl/torch/torch/_metrics.py b/adaptdl/adaptdl/torch/torch/_metrics.py new file mode 100644 index 000000000..038c22b35 --- /dev/null +++ b/adaptdl/adaptdl/torch/torch/_metrics.py @@ -0,0 +1,199 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +import pickle +import time + +import numpy as np + +import adaptdl.checkpoint +import adaptdl.collective +import adaptdl.env +from adaptdl.goodput import GoodputFunction, fit_perf_params +from adaptdl.sched_hints import SCHED_HINTS, PERF_PARAMS, post_sched_hints + + +def profile_step_start(atomic_bsz): + state = _metrics_state() + state.atomic_bsz = atomic_bsz + state.step_start = time.time() + state.sync_time = 0.0 + + +def profile_sync_time(sync_time): + _metrics_state().sync_time += sync_time + + +_PREV_REPORT = None + + +def profile_step_commit(accumulation_step=False): + global _PREV_REPORT + state = _metrics_state() + step_time = time.time() - state.step_start + num_nodes = adaptdl.env.num_nodes() + num_replicas = adaptdl.env.num_replicas() + key = (num_nodes, num_replicas, state.atomic_bsz) + if accumulation_step: + state.profile[key]["accum_step_time"] += step_time + state.profile[key]["accum_count"] += 1 + else: + state.profile[key]["optim_step_time"] += step_time + state.profile[key]["optim_sync_time"] += state.sync_time + state.profile[key]["optim_count"] += 1 + del state.atomic_bsz + del state.step_start + del state.sync_time + if not accumulation_step: + if _PREV_REPORT is None: + _PREV_REPORT = time.time() + if adaptdl.env.replica_rank() == 0 and time.time() - _PREV_REPORT > 30: + _fit_perf_params() + _report_sched_hints() + _PREV_REPORT = time.time() + + +_GRAD_PARAM_DICT = {} + + +def update_grad_params(edp_key, grad_norm_sqr, grad_variance): + global _GRAD_PARAM_DICT + _GRAD_PARAM_DICT[edp_key] = np.asarray([grad_norm_sqr, grad_variance]) + grad_params = sum(_GRAD_PARAM_DICT.values()) + _metrics_state().grad_params = (grad_params[0], grad_params[1]) + + +def update_progress(progress): + _metrics_state().progress = progress + + +def get_progress(): + return _metrics_state().progress + + +def set_batch_size(init_batch_size, max_batch_size, local_bsz_bounds, + gradient_accumulation): + state = _metrics_state() + state.init_batch_size = init_batch_size + state.max_batch_size = max_batch_size + state.local_bsz_bounds = local_bsz_bounds + state.gradient_accumulation = 
gradient_accumulation + + +def get_goodput_fn(): + state = _metrics_state() + if state.grad_params is None or state.perf_params is None: + return None + return GoodputFunction(state.perf_params, state.grad_params, + state.init_batch_size) + + +def _fit_perf_params(): + state = _metrics_state() + profile = {k: v for k, v in state.profile.items() if v.get("optim_count")} + # Convert profile into numpy arrays. + num_nodes, num_replicas, atomic_bsz = ( + np.array(k) for k in zip(*profile.keys())) + accum_step_time = np.array([v.get("accum_step_time", 0.0) + for v in profile.values()]) + accum_count = np.array([v.get("accum_count", 0) for v in profile.values()]) + optim_step_time = np.array([v.get("optim_step_time", 0.0) + for v in profile.values()]) + optim_sync_time = np.array([v.get("optim_sync_time", 0.0) + for v in profile.values()]) + optim_count = np.array([v.get("optim_count", 0) for v in profile.values()]) + assert np.all(optim_count > 0) + # Non-sync time during optimization steps should be approximately equal to + # accumulation step time, combine those data points. + assert np.all(optim_step_time >= optim_sync_time) + accum_step_time += optim_step_time - optim_sync_time + accum_count += optim_count + accum_step_time /= accum_count + optim_step_time /= optim_count + state.perf_params = fit_perf_params(num_nodes, num_replicas, atomic_bsz, + accum_step_time, optim_step_time) + + +def _get_sched_hints(): + state = _metrics_state() + if len(state.profile) == 0: + return None + _fit_perf_params() + return _metrics_state() + + +def _report_sched_hints(): + assert adaptdl.env.replica_rank() == 0 + state = _metrics_state() + # Scheduling hints + sched_hints = SCHED_HINTS.copy() + sched_hints["perfParams"] = {k: v for (k, v) in + zip(PERF_PARAMS.keys(), + state.perf_params)} + sched_hints["maxBatchSize"] = state.max_batch_size + sched_hints["localBszBounds"] = state.local_bsz_bounds + sched_hints["initBatchSize"] = state.init_batch_size + if state.grad_params: + sched_hints["gradParams"] = {} + sched_hints["gradParams"]["norm"] = state.grad_params[0] + sched_hints["gradParams"]["var"] = state.grad_params[1] + sched_hints["maxProfiledReplicas"] = max(key[1] for key in state.profile) + sched_hints["gradientAccumulation"] = state.gradient_accumulation + post_sched_hints(sched_hints, adaptdl.env.job_id()) + + +class _MetricsState(adaptdl.checkpoint.State): + def __init__(self): + super().__init__("adaptdl-metrics") + self.profile = collections.defaultdict(collections.Counter) + self.perf_params = None + self.grad_params = None + self.init_batch_size = None + self.max_batch_size = None + self.local_bsz_bounds = None + self.gradient_accumulation = False + self.progress = 0.0 # Progress in scale-invariant iterations. 
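+
+    # Note: save() and load() below pickle and unpickle the fields above in
+    # exactly the same order; update both methods together whenever a field
+    # is added or removed, otherwise checkpoints will deserialize incorrectly.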
+ + def save(self, fileobj): + pickle.dump(self.profile, fileobj) + pickle.dump(self.perf_params, fileobj) + pickle.dump(self.grad_params, fileobj) + pickle.dump(self.init_batch_size, fileobj) + pickle.dump(self.max_batch_size, fileobj) + pickle.dump(self.local_bsz_bounds, fileobj) + pickle.dump(self.gradient_accumulation, fileobj) + pickle.dump(self.progress, fileobj) + + def load(self, fileobj): + self.profile = pickle.load(fileobj) + self.perf_params = pickle.load(fileobj) + self.grad_params = pickle.load(fileobj) + self.init_batch_size = pickle.load(fileobj) + self.max_batch_size = pickle.load(fileobj) + self.local_bsz_bounds = pickle.load(fileobj) + self.gradient_accumulation = pickle.load(fileobj) + self.progress = pickle.load(fileobj) + + +def _metrics_state(): + global _METRICS_STATE + if _METRICS_STATE is None: + _METRICS_STATE = _MetricsState() + adaptdl.checkpoint.load_state(_METRICS_STATE) + return _METRICS_STATE + + +_METRICS_STATE = None diff --git a/adaptdl/adaptdl/torch/torch/_metrics_test.py b/adaptdl/adaptdl/torch/torch/_metrics_test.py new file mode 100644 index 000000000..ee69d9a2f --- /dev/null +++ b/adaptdl/adaptdl/torch/torch/_metrics_test.py @@ -0,0 +1,158 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from adaptdl.conftest import elastic_multiprocessing + + +@pytest.mark.parametrize("num_replicas", [1, 2, 3, 4]) +@elastic_multiprocessing +def test_profile(num_replicas): + import adaptdl.checkpoint + from adaptdl.env import num_restarts + from adaptdl.torch._metrics import ( + profile_step_start, profile_sync_time, + profile_step_commit, _metrics_state) + if num_restarts() == 0: + profile = _metrics_state().profile + assert len(profile) == 0 + # Profile local_bsz=1 but don't commit. + profile_step_start(1) + profile_sync_time(1.0) + # Profile local_bsz=2 and commit. + profile_step_start(2) + profile_sync_time(1.0) + profile_sync_time(2.0) + profile_step_commit() + # Ensure profile is updated correctly. + profile = _metrics_state().profile + key = (1, 1, 2) + assert len(profile) == 1 + assert profile[key]["accum_count"] == 0 + assert profile[key]["optim_count"] == 1 + assert profile[key]["optim_sync_time"] == 3.0 + assert profile[key]["optim_step_time"] > 0.0 + # Checkpoint and restart. + adaptdl.checkpoint.save_all_states() + return num_replicas + elif num_restarts() == 1: + profile = _metrics_state().profile + # Ensure checkpoint is loaded correctly. + key = (1, 1, 2) + assert len(profile) == 1 + assert profile[key]["accum_count"] == 0 + assert profile[key]["optim_count"] == 1 + assert profile[key]["optim_sync_time"] == 3.0 + assert profile[key]["optim_step_time"] > 0.0 + # Profile local_bsz=3 and commit twice. 
+ profile_step_start(3) + profile_sync_time(2.0) + profile_sync_time(3.0) + profile_step_commit() + key = (1, num_replicas, 3) + old_step_time = profile[key]["optim_step_time"] + profile_step_start(3) + profile_sync_time(3.0) + profile_sync_time(4.0) + profile_step_commit() + # Ensure profile is updated correctly. + assert len(profile) == 2 + assert profile[key]["accum_count"] == 0 + assert profile[key]["optim_count"] == 2 + assert profile[key]["optim_sync_time"] == 12.0 + assert profile[key]["optim_step_time"] > old_step_time > 0.0 + + +@pytest.mark.parametrize("num_replicas", [1, 2, 3, 4]) +@elastic_multiprocessing +def test_profile_accumulation(num_replicas): + import adaptdl.checkpoint + from adaptdl.env import num_restarts + from adaptdl.torch._metrics import ( + profile_step_start, profile_sync_time, + profile_step_commit, _metrics_state, _fit_perf_params) + if num_restarts() == 0: + profile = _metrics_state().profile + assert len(profile) == 0 + # Profile local_bsz=1 but don't commit. + profile_step_start(1) + profile_sync_time(1.0) + # Profile local_bsz=2 and commit. + profile_step_start(2) + profile_step_commit(accumulation_step=True) + profile_step_start(2) + profile_step_commit(accumulation_step=True) + profile_step_start(2) + profile_sync_time(4.0) + profile_step_commit(accumulation_step=False) + profile_step_start(5) + profile_step_commit(accumulation_step=True) + profile_step_start(5) + profile_step_commit(accumulation_step=True) + profile_step_start(5) + profile_sync_time(6.0) + profile_step_commit(accumulation_step=False) + # Ensure profile is updated correctly. + profile = _metrics_state().profile + key = (1, 1, 2) + assert len(profile) == 2 + assert profile[key]["accum_count"] == 2 + assert profile[key]["optim_count"] == 1 + assert profile[key]["optim_sync_time"] == 4.0 + assert profile[key]["accum_step_time"] > 0.0 + assert profile[key]["optim_step_time"] > 0.0 + profile_step_start(3) + profile_step_commit(accumulation_step=True) + profile_step_start(3) + profile_step_commit(accumulation_step=True) + # Check that fitting parameters works even + # without a final accumulation_step=False commit + for val in profile.values(): + # Ensure step time is at least sync time. + val["optim_step_time"] += val["optim_sync_time"] + _fit_perf_params() + # Checkpoint and restart. + adaptdl.checkpoint.save_all_states() + return num_replicas + elif num_restarts() == 1: + profile = _metrics_state().profile + # Ensure checkpoint is loaded correctly. + key = (1, 1, 2) + assert len(profile) == 3 + assert profile[key]["accum_count"] == 2 + assert profile[key]["optim_count"] == 1 + assert profile[key]["optim_sync_time"] == 4.0 + assert profile[key]["optim_step_time"] > 0.0 + # Profile local_bsz=3 and commit twice. + profile_step_start(3) + profile_sync_time(2.0) + profile_sync_time(3.0) + profile_step_commit() + key = (1, num_replicas, 3) + old_step_time = profile[key]["optim_step_time"] + profile_step_start(3) + profile_sync_time(3.0) + profile_sync_time(4.0) + profile_step_commit() + # Ensure profile is updated correctly. 
+        if num_replicas == 1:
+            assert len(profile) == 3
+        else:
+            assert len(profile) == 4
+        assert profile[key]["accum_count"] == (0 if num_replicas > 1 else 2)
+        assert profile[key]["optim_count"] == 2
+        assert profile[key]["optim_sync_time"] == 12.0
+        assert profile[key]["optim_step_time"] > old_step_time > 0.0
diff --git a/adaptdl/adaptdl/torch/torch/accumulator.py b/adaptdl/adaptdl/torch/torch/accumulator.py
new file mode 100644
index 000000000..6851368ad
--- /dev/null
+++ b/adaptdl/adaptdl/torch/torch/accumulator.py
@@ -0,0 +1,312 @@
+# Copyright 2020 Petuum, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import collections.abc
+import contextlib
+import copy
+import pickle
+
+import adaptdl.checkpoint
+import adaptdl.collective
+from adaptdl.torch.epoch import current_epoch
+from adaptdl.torch.data import current_dataloader
+
+
+class Accumulator(collections.abc.MutableMapping):
+    """
+    This class helps aggregate simple statistics across all replicas in the
+    current job, and across any number of checkpoint-restarts. Can be used to
+    compute metrics like loss and accuracy, synchronized across each replica.
+
+    Accumulators imitate python dictionaries, but with a few key differences
+    described below. Primarily, their usage and behavior depend on whether
+    the accumulator is set to *accumulation mode* or to *synchronized mode*.
+
+    1. **Accumulation mode:** the accumulator is being updated on all
+       replicas. Operations like ``accum["key"] += val`` or
+       ``accum.update(key=val)`` will aggregate the updates locally on each
+       replica, which are lazily synchronized in the background (either upon
+       a checkpoint or a switch to synchronized mode). Each replica may make
+       different updates, which are summed together when synchronized. While
+       accumulation mode is enabled, all read operations on the accumulator
+       will behave as if they were performed on an empty ``dict``, i.e.
+       ``len(accum)`` will always return ``0``. By default, all accumulators
+       are set to accumulation mode.
+    2. **Synchronized mode:** the accumulator contains the same data on every
+       replica, and the application must ensure that all write operations are
+       exactly the same across all replicas. While in synchronized mode, the
+       accumulator may be used as if it were a native python ``dict``, and
+       all read/write operations are supported. :meth:`Accumulator.synchronized`
+       may be used to enter synchronized mode. Upon entering synchronized
+       mode, the accumulator will automatically sum all updates from all
+       replicas to ensure the same data is available to each replica.
+
+    Using accumulators, many training/validation metrics can be computed
+    easily and correctly in an elastic distributed setting. For example, a
+    simple validation step which calculates a loss and accuracy can be
+    implemented as follows:
+
+    .. code-block:: python
+
+        accum = Accumulator()  # New accumulator starts in accumulation mode.
+
+        for epoch in remaining_epochs_until(60):
+
+            for batch in validloader:
+                ...
+                accum["loss_sum"] += <loss summed within the batch>
+                accum["correct"] += <number of correct predictions>
+                accum["total"] += <total number of samples in the batch>
+
+            with accum.synchronized():  # Enter synchronized mode.
+                accum["loss_avg"] = accum["loss_sum"] / accum["total"]
+                accum["accuracy"] = accum["correct"] / accum["total"]
+                print("Loss: {}, Accuracy: {}".format(
+                    accum["loss_avg"], accum["accuracy"]))
+                accum.clear()
+            # Back to accumulation mode.
+
+    Arguments:
+        args: Positional arguments same as ``dict``.
+        kwargs: Keyword arguments same as ``dict``.
+
+    .. automethod:: __iadd__
+    .. automethod:: __isub__
+    .. automethod:: __getitem__
+    """
+    def __init__(self, *args, **kwargs):
+        self._sync_count = collections.Counter()
+        self._synchronized = None
+        self._state = _AccumulatorState(*args, **kwargs)
+        adaptdl.checkpoint.load_state(self._state)
+
+    @contextlib.contextmanager
+    def synchronized(self):
+        """
+        A context manager which can be used to define the code to execute in
+        *synchronized* mode. Within the context manager, any code can
+        interact with this accumulator as if it were a regular Python
+        ``dict``. The application must ensure that whatever operations are
+        performed within this context block are the same across all replicas.
+
+        .. warning::
+            Entering this context manager is a distributed synchronization
+            point! Please ensure that all replicas enter this context manager
+            at the same point in their code.
+        """
+        if self._synchronized is not None:
+            # Already synchronized, don't need to do anything.
+            yield self
+            return
+        epoch = current_epoch()
+        # Remove saved results from all finished epochs. Since finished
+        # epochs are never replayed, they should never be needed again.
+        for key in list(self._state.results_history.keys()):
+            if key is not None and key < epoch:
+                self._state.results_history.pop(key)
+        # Get the number of synchronizations so far in the current epoch.
+        count = self._sync_count[epoch]
+        self._sync_count[epoch] += 1
+        results_list = self._state.results_history[epoch]
+        assert count <= len(results_list)
+        if count < len(results_list):
+            # Results for this synchronization are saved in the history.
+            self._synchronized = results_list[count]
+            self._state.updates.clear()
+        else:
+            self._state.sync()  # Sync results and updates across replicas.
+            if current_dataloader() is None:
+                # Only save into results history if outside of a dataloader
+                # iteration, since code inside iterations is not replayed.
+                results_list.append(copy.deepcopy(self._state.results))
+            self._synchronized = self._state.results
+        try:
+            yield self
+        finally:
+            self._synchronized = None
+
+    def update(self, *args, **kwargs):
+        """
+        Apply a collection of key-update pairs. Unlike ``dict.update``, this
+        method *additively* applies the updates to the accumulated values.
+
+        Arguments:
+            args: Positional arguments same as ``dict.update``. Can be a
+                mapping object or an iterable of key-update pairs.
+            kwargs: Keyword arguments same as ``dict.update``. Each keyword
+                is the string key corresponding to the provided update.
+        """
+        for key, val in dict(*args, **kwargs).items():
+            self[key] += val
+
+    def subtract(self, *args, **kwargs):
+        """
+        Apply a collection of key-update pairs. Unlike
+        :meth:`Accumulator.update`, this method *subtracts* the updates from
+        the accumulated values.
+
+        Arguments:
+            args: Positional arguments same as :meth:`Accumulator.update`.
+            kwargs: Keyword arguments same as :meth:`Accumulator.update`.
+        """
+        for key, val in dict(*args, **kwargs).items():
+            self[key] -= val
+
+    def __iadd__(self, other):
+        """
+        Supports the += operation, e.g.
``accum += {key1: val1, key2: val2}``. + Behaves the same way as ``accum.update({key1: val1, key2: val2})``. + + Arguments: + other: Mapping object or an iterable of key-update pairs. + """ + self.update(other) + return self + + def __isub__(self, other): + """ + Supports the -= operation, e.g. ``accum -= {key1: val1, key2: val2}``. + Behaves the same way as ``accum.subtract({key1: val1, key2: val2})``. + + Arguments: + other: Mapping object or an iterable of key-update pairs. + """ + self.subtract(other) + return self + + def __getitem__(self, key): + """ + Supports indexing, e.g. ``val = accum[key]`` and ``accum[key] += 1``. + The former (read access) should only be used when the accumulator is in + synchronized mode. + + Arguments: + other: Key used to access a value in the accumulator. + """ + if self._synchronized is not None: + return self._synchronized.__getitem__(key) + # In accumulation mode, return a dummy object which captures all + # updates performed on it, to be later applied by __setitem__. + return _Value(self, key) + + def __setitem__(self, key, value): + if self._synchronized is not None: + return self._synchronized.__setitem__(key, value) + # Whenever an in-place addition or subtraction is done, like a[k] += v, + # python will essentially perform 3 steps: (1) tmp = a.__getitem__(k), + # (2) tmp += v, (3) a.__setitem__(k, tmp). In order to obtain the + # update v, we let a.__getitem__(k) return an opaque object which + # implements the __add__ operator to capture the update v in step (2). + # Then, a.__setitem__(k, tmp) can pull v out of tmp and accumulate it. + if not isinstance(value, _Value): + raise TypeError("invalid value type: {}".format(type(value))) + if value.accum is not self: + raise ValueError("incompatible {}".format(self.__class__.__name__)) + if key != value.key: + raise ValueError("incompatible key: {}".format(value.key)) + self._state.updates.setdefault(key, 0) + self._state.updates[key] += value.update + + # Rest of the abstract methods needed by collections.MutableMapping + + def __contains__(self, key): + if self._synchronized is not None: + return self._synchronized.__contains__(key) + return {}.__contains__(key) + + def __delitem__(self, key): + if self._synchronized is not None: + return self._synchronized.__delitem__(key) + return {}.__delitem__(key) + + def __iter__(self): + if self._synchronized is not None: + return self._synchronized.__iter__() + return {}.__iter__() + + def __len__(self): + if self._synchronized is not None: + return self._synchronized.__len__() + return {}.__len__() + + def __repr__(self): + if self._synchronized is not None: + return self._synchronized.__repr__() + return {}.__repr__() + + +class _Value(object): + __slots__ = ["accum", "key", "update"] + + def __init__(self, accum, key): + # Initialize the opaque object used for supporting "accum[k] += v" and + # "accum[k] -= v" operations. + self.accum = accum + self.key = key + self.update = 0 + + def __add__(self, update): + if isinstance(update, _Value): + raise TypeError("invalid update type: {}".format(type(update))) + self.update += update + return self + + def __sub__(self, update): + if isinstance(update, _Value): + raise TypeError("invalid update type: {}".format(type(update))) + self.update -= update + return self + + +class _AccumulatorState(adaptdl.checkpoint.State): + + # Assume accumulators are initialized in the same order in every replica. 
+ # Keep a map of epoch -> number of accumulators initialized so far in that + # epoch, and use that count to construct a unique name for the state. + init_count = collections.Counter() + + def __init__(self, *args, **kwargs): + if current_dataloader() is not None: + raise RuntimeError("accumulator may not be initialized during " + "dataloader iteration") + epoch = current_epoch() + count = _AccumulatorState.init_count[epoch] + super().__init__("adaptdl-accumulator-epoch{}-{}".format(epoch, count)) + _AccumulatorState.init_count[epoch] += 1 + + self.results_history = collections.defaultdict(list) + self.results = dict(*args, **kwargs) + self.updates = {} + + def save(self, fileobj): + pickle.dump((self.results_history, self.results), fileobj) + + def load(self, fileobj): + self.results_history, self.results = pickle.load(fileobj) + + def sync(self): + # Aggregate pending updates across all replicas and apply them. + updates = adaptdl.collective.allreduce(self.updates, _dict_iadd) + _dict_iadd(self.results, updates) + self.updates.clear() + + +def _dict_iadd(a, b): + for k, v in b.items(): + if k in a: + a[k] += v + else: + a[k] = v + return a diff --git a/adaptdl/adaptdl/torch/torch/accumulator_test.py b/adaptdl/adaptdl/torch/torch/accumulator_test.py new file mode 100644 index 000000000..6dc49bb58 --- /dev/null +++ b/adaptdl/adaptdl/torch/torch/accumulator_test.py @@ -0,0 +1,60 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from adaptdl.conftest import elastic_multiprocessing +from adaptdl.torch.accumulator import Accumulator + + +@elastic_multiprocessing +def test_accumulator_restarts(): + import adaptdl.checkpoint + import adaptdl.collective + from adaptdl.env import num_restarts, replica_rank + adaptdl.collective.initialize("0.0.0.0") + accum = Accumulator() + + if num_restarts() == 0: + accum["a"] += 15 # Idempotent. + assert "a" not in accum + with accum.synchronized(): + assert "a" in accum + assert accum["a"] == 15 + assert "a" not in accum + if num_restarts() == 0: + accum["a"] -= 5 # Idempotent. + adaptdl.checkpoint.save_all_states() + return 4 # Restart with 4 replicas. + + if num_restarts() == 1: # Idempotent. + accum.update({"a": replica_rank(), "b": replica_rank()}) + assert len(accum) == 0 + with accum.synchronized(): + assert len(accum) == 2 + assert accum["a"] == 16 + assert accum["b"] == 6 + assert len(accum) == 0 + if num_restarts() == 1: + adaptdl.checkpoint.save_all_states() + return 2 # Restart with 2 replicas. + + if num_restarts() == 2: # Idempotent. + accum -= {"b": 5, "c": 5} + with accum.synchronized(): + assert accum["a"] == 16 + assert accum["b"] == -4 + assert accum["c"] == -10 + accum.clear() + with accum.synchronized(): + assert not accum diff --git a/adaptdl/adaptdl/torch/torch/context.py b/adaptdl/adaptdl/torch/torch/context.py new file mode 100644 index 000000000..a130a592d --- /dev/null +++ b/adaptdl/adaptdl/torch/torch/context.py @@ -0,0 +1,98 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import math
+import adaptdl.checkpoint
+import adaptdl.collective
+import adaptdl.env
+from adaptdl.torch._metrics import get_goodput_fn
+import adaptdl.torch.data
+from adaptdl.torch.scaling_rules import ScalingRuleBase
+
+
+class AdaptiveDLContext(object):
+    """
+    This class provides a context for retrieving the parameters suggested by
+    AdaptDL, such as batch_size, accum_steps and lr_scale.
+    """
+
+    def __init__(self, batch_size):
+        self._elastic = \
+            adaptdl.torch.data.AdaptiveDataLoaderHelper(batch_size)
+        # Autoscale batch size fields.
+        self._speedup_threshold = 1.05
+        self.adapt_batch_size = None
+        self.adapt_accum_steps = None
+        self.adapt_lr_scale = None
+
+    def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None,
+                             gradient_accumulation=False):
+        self._elastic.autoscale_batch_size(max_batch_size, local_bsz_bounds,
+                                           gradient_accumulation)
+
+    def get_batch_size(self):
+        _, self.adapt_batch_size, _ = self._sync_local_bsz()
+        return self.adapt_batch_size
+
+    def get_accum_steps(self):
+        _, _, self.adapt_accum_steps = self._sync_local_bsz()
+        return self.adapt_accum_steps
+
+    def get_lr_scale(self):
+        self.adapt_lr_scale = ScalingRuleBase._get_adapt_lr_scale()
+        return float(self.adapt_lr_scale)
+
+    def _sync_local_bsz(self):
+        goodput_fn = get_goodput_fn()
+        if self._elastic.max_batch_size is None or goodput_fn is None:
+            # No autoscale batch size, just divide batch size evenly.
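+            # Each replica uses the same local batch size, so with ceil
+            # division the effective total may slightly exceed the requested
+            # batch size (self._elastic.batch_size).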
+ self._elastic._state.current_local_bsz = math.ceil( + self._elastic.batch_size / adaptdl.env.num_replicas()) + self._elastic._state.accumulation_steps = 0 + elif not self._elastic._state.current_local_bsz: + # if init, use the batch size suggested + _, atomic_bsz, accum_steps = goodput_fn.optimize( + adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), + max_batch_size=self._elastic._max_batch_size, + atomic_bsz_range=self._elastic._local_bsz_bounds, + accumulation=self._elastic._gradient_accumulation) + self._elastic._state.current_local_bsz = atomic_bsz + self._elastic._state.accumulation_steps = accum_steps + else: + # if not first time, we check against the relative speedup + suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize( + adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), + max_batch_size=self._elastic._max_batch_size, + atomic_bsz_range=self._elastic._local_bsz_bounds, + accumulation=self._elastic._gradient_accumulation) + # get current goodput + current_goodput = goodput_fn( + adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), + self._elastic.current_local_bsz, self._elastic.accumulation_steps) + # use only if speedup is significant + speedup = suggest_goodput / max(current_goodput, 1e-8) + if speedup > self._speedup_threshold: + self._elastic._state.current_local_bsz = atomic_bsz + self._elastic._state.accumulation_steps = accum_steps + self._elastic._state.current_local_bsz, self._elastic._state.accumulation_steps = \ + adaptdl.collective.broadcast((self._elastic._state.current_local_bsz, + self._elastic._state.accumulation_steps)) + return self._elastic.current_local_bsz, self._elastic._state.current_local_bsz, self._elastic._state.accumulation_steps + + @property + def training(self): + return self._elastic.training + + def to_tensorboard(self, writer, global_step, tag_prefix=""): + self._elastic.to_tensorboard(writer, global_step, tag_prefix) + # to_tensorboard.__doc__ = adaptdl.torch.data.AdaptiveDataLoaderHelper.to_tensorboard.__doc__ \ No newline at end of file diff --git a/adaptdl/adaptdl/torch/torch/data.py b/adaptdl/adaptdl/torch/torch/data.py new file mode 100644 index 000000000..8da7ce7a0 --- /dev/null +++ b/adaptdl/adaptdl/torch/torch/data.py @@ -0,0 +1,492 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
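+
+# Worked example of how ElasticSampler (defined below) partitions indices
+# (illustrative values): with len(dataset) == 10, num_replicas == 2, and
+# index == 0, the unshuffled ordering gives rank 0 -> [0, 2, 4, 6, 8] and
+# rank 1 -> [1, 3, 5, 7, 9]. Resuming after a restart at index == 4 leaves
+# [4..9], split as rank 0 -> [4, 6, 8] and rank 1 -> [5, 7, 9].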
+ + +from contextlib import contextmanager +import collections +import functools +import logging +import math +import numpy as np +import pickle +import random +import torch +from torch.utils.data import DataLoader, Sampler + +import adaptdl.checkpoint +import adaptdl.collective +import adaptdl.env +from adaptdl.torch.epoch import current_epoch +from adaptdl.torch._metrics import ( + profile_step_start, profile_step_commit, + set_batch_size, get_goodput_fn, get_progress) +from adaptdl._signal import get_exit_flag +from adaptdl.torch.context import AdaptiveDLContext + +logging.basicConfig(level=logging.INFO) +LOG = logging.getLogger(__name__) +LOG.setLevel(logging.INFO) + + +class ElasticSampler(Sampler): + """ + A PyTorch Sampler which partitions data samples across multiple replicas, + and supports deterministic continuing across checkpoint-restarts. Shuffling + is deterministic for each epoch, and :meth:`ElasticSampler.set_epoch` + should be invoked to obtain different orderings in different epochs. + + Arguments: + dataset (torch.util.data.Dataset): The dataset to sample from. + shuffle (bool): Whether the data samples should be shuffled. + + .. automethod:: __iter__ + .. automethod:: __len__ + """ + def __init__(self, dataset, shuffle=True): + self.dataset = dataset + self.shuffle = shuffle + self.num_replicas = adaptdl.env.num_replicas() + self.rank = adaptdl.env.replica_rank() + self.epoch = 0 + self.index = 0 + + def __iter__(self): + """ + Iterate through the samples in the dataset, in the order defined for a + set epoch, starting at a set index. Produces only the indices for the + local replica. + + Returns: Iterator over data sample indices. + """ + if self.shuffle: + # Deterministically shuffle based on epoch. + g = torch.Generator() + g.manual_seed(hash((self.epoch, self.index // len(self.dataset)))) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + base_index = self.index % len(self.dataset) + + # Subsample. + local_indices = indices[base_index + self.rank::self.num_replicas] + + # Add extra samples to make it evenly divisible. + if len(local_indices) < len(self): + local_indices.append(indices[self.rank]) + assert len(local_indices) == len(self) + return iter(local_indices) + + def __len__(self): + """ + The total number of samples to be iterated through, starting at the set + index, for the local replica. + + Returns (int): Number of samples. + """ + base_index = self.index % len(self.dataset) + return math.ceil((len(self.dataset) - base_index) / self.num_replicas) + + def set_epoch(self, epoch, index=0): + """ + Set the epoch to derive samples from. Optional argument ``index`` can + be specified to start sampling from a particular index, e.g. after a + checkpoint-restart. + + Arguments: + epoch (int): The epoch to sample from. + index (int): The index to start sampling from. + """ + self.epoch = epoch + self.index = index + + +def current_dataloader(): + """ + Reference to the data loader currently being iterated. + + Returns (AdaptiveDataLoaderHelper): Current data loader. + """ + return AdaptiveDataLoaderHelper._current + + +class AdaptiveDataLoaderHelper(object): + """ + This class provides fine-grained control over adaptive training loops. It + can be used for building more user-friendly custom data loaders, such as + :class:`AdaptiveDataLoader`. + + Arguments: + batch_size (int): The target total batch size across all replicas. 
The + actual total batch size may be different due to rounding (each + replica must have the same local batch size), or being scaled up + using adaptive batch sizes. + """ + + # Epoch -> the number of dataloader loops completed so far in that epoch, + # across all AdaptiveDataLoader objects. + _position = collections.Counter() + _training = None # The AdaptiveDataLoader which loads training data. + _current = None # The AdaptiveDataLoader which is currently iterating. + + def __init__(self, batch_size=1): + # Autoscale batch size fields. + self._max_batch_size = None + self._local_bsz_bounds = None + # Create and load state. + self._state = _AdaptiveDataLoaderState() + adaptdl.checkpoint.load_state(self._state) + self.batch_size = batch_size + self.future_exit = None + self._gradient_accumulation = False + self._speedup_threshold = 1.05 + self._accum_count = 0 + + @property + def current_index(self): + """ + The total number of data samples processed so far in the current loop. + Includes the data processed by all replicas. ``None`` if this data + loader is not currently being iterated. + """ + if AdaptiveDataLoaderHelper._current is not self: + return None + return self._state.current_index + + @current_index.setter + def current_index(self, index): + if AdaptiveDataLoaderHelper._current is not self: + return + self._state.current_index = index + + @property + def end_index(self): + """ + (Optional) Can be used to track the end index of dataset across + restarts. + """ + return self._state.end_index + + @end_index.setter + def end_index(self, index): + """ + (Optional) Supports mutations of end_index + """ + self._state.end_index = index + + @property + def max_batch_size(self): + """ + The maximum total batch size allowed for adaptive batch size. ``None`` + if adaptive batch size is disabled. + """ + return self._max_batch_size + + @property + def local_bsz_bounds(self): + """ + The local batch size bounds on each replica. A pair of integers, + (min_local_bsz, max_local_bsz). + """ + return self._local_bsz_bounds + + @property + def current_local_bsz(self): + """ + The current logical local batch size used by the dataloader. + The batch size returned by the dataloader may be smaller if + gradient accumulation is used + """ + return self._state.current_local_bsz + + @property + def accumulation_steps(self): + """ + The number of batches returned by the dataloader before a + step is taken. + """ + return self._state.accumulation_steps + + def is_accum_step(self): + """ + Whether the current step's gradient will be accumulated. + """ + return self._accum_count < self._state.accumulation_steps + + def is_optim_step(self): + """ + Whether the optimizer step will be invoked in this step. + """ + return not self.is_accum_step() + + def train(self): + """ + Set this data loader to be the one used for training. Only one data + loader may be used for training. + """ + if AdaptiveDataLoaderHelper._training is None: + AdaptiveDataLoaderHelper._training = self + set_batch_size(self.batch_size, self.max_batch_size, + self.local_bsz_bounds, self._gradient_accumulation) + + def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, + gradient_accumulation=False): + """ + Enables adaptive batch size. Should be invoked once after the data + loader object is created. + + Arguments: + max_batch_size (int): Maximum total batch size allowed. + local_bsz_bounds (tuple): A pair of (min_local_bsz, max_local_bsz), + the min and max local batch sizes allowed on each replica. 
+ + Raises: + ValueError: If any of the provided batch size bounds are invalid. + """ + if not isinstance(max_batch_size, int) or \ + max_batch_size < self.batch_size: + raise ValueError("invalid max_batch_size") + if local_bsz_bounds is not None and ( + local_bsz_bounds[0] is not None and + local_bsz_bounds[0] > self.batch_size or + local_bsz_bounds[1] is not None and + local_bsz_bounds[1] < self.batch_size): + raise ValueError("invalid local_bsz_bounds") + self._max_batch_size = max_batch_size + self._local_bsz_bounds = local_bsz_bounds + self._gradient_accumulation = gradient_accumulation + self.train() + + @property + def training(self): + return self is AdaptiveDataLoaderHelper._training + + @contextmanager + def profile(self, commit): + """ + Every iteration of every epoch should be profiled under this context. + Note that, custom DataLoader writers should make sure that it gets + called equal number of times on each replica. + + Arguments: + commit (bool): Whether to commit the profiled results. + """ + # Synchronize the exit signal so all replicas exit after + # the same iteration. Do this asynchronously to prevent + # unnecessary blocking on the network. + if self.future_exit is not None and self.future_exit.result(): + adaptdl.checkpoint.save_all_states() + exit(143) # Standard exit code response to SIGTERM. + self.future_exit = adaptdl.collective.allreduce_async( + get_exit_flag(), lambda a, b: a or b) + profile_step_start(self.current_local_bsz) + yield + if commit: + profile_step_commit(self.is_accum_step()) + self._accum_count = (0 if self.is_optim_step() + else self._accum_count + 1) + + @contextmanager + def context(self): + """ + All iterators should be iterated under this context. It ensures + proper cleanup of elastic context at the end of each epoch. + """ + epoch = current_epoch() + try: + if AdaptiveDataLoaderHelper._current is not None: + raise RuntimeError("overlapping dataloader \ + iterations detected") + AdaptiveDataLoaderHelper._current = self + yield + finally: + self._state.current_index = 0 + self._state.end_index = 0 + self._state.last_position[epoch] = self._position[epoch] + self._position[epoch] += 1 + AdaptiveDataLoaderHelper._current = None + + @property + def current_batch_size(self): + return (self.current_local_bsz * (self.accumulation_steps + 1) * + adaptdl.env.num_replicas()) + + def skipdone(self): + """ + Should be called just after entering the `_elastic` context to make + sure that the dataloader loop is not replayed if has already finished + before a restart. + """ + + epoch = current_epoch() + position = self._position[epoch] + if position <= self._state.last_position.get(epoch, -1): + # Already completed the dataloader loop at the current + # position, skip this loop and keep replaying the application + # code. + LOG.info("skipping %s loop at position %s in epoch %s", + self.__class__.__name__, position, epoch) + self._position[epoch] += 1 + return True + else: + return False + + def to_tensorboard(self, writer, global_step, tag_prefix=""): + """ + Output some useful metrics to TensorBoard. + + Arguments: + writer (torch.utils.tensorboard.SummaryWriter): ``SummaryWriter`` + object to output metrics to. + global_step (int): Global step value to record. + tag_prefix (str): Prefix added to each metric's tag. 
+ """ + if tag_prefix and not tag_prefix.endswith("/"): + tag_prefix += "/" + writer.add_scalar(tag_prefix + "Total_Batch_Size", + self.current_batch_size, global_step) + writer.add_scalar(tag_prefix + "Local_Batch_Size", + self.current_local_bsz, global_step) + writer.add_scalar(tag_prefix + "Accumulation_Steps", + self.accumulation_steps, global_step) + + +def _worker_init_wrapper(worker_init_fn, num_workers): + # Set globally-unique python and numpy seeds for each worker. + + @functools.wraps(worker_init_fn) + def wrapper(worker_id): + nonlocal num_workers + num_workers = num_workers or 1 + # https://pytorch.org/docs/master/data.html#randomness-in-multi-process-data-loading. + seed = torch.initial_seed() + adaptdl.env.replica_rank() * num_workers + torch.manual_seed(seed) + np.random.seed(seed % 2 ** 32) + random.seed(seed) + if worker_init_fn is not None: + return worker_init_fn(worker_id) + return wrapper + + +class AdaptiveDataLoader(DataLoader, AdaptiveDLContext): + """ + This class is a PyTorch DataLoader that also supports adaptive batch sizes + and checkpoint-restart elasticity. Applications can typically use objects + of this class as direct replacements for PyTorch DataLoaders. However, some + notable differences are: + + 1. The ``batch_size`` argument defines the target total batch size across + all replicas, rather than the local batch size on each replica. + 2. Custom ``sampler`` and ``batch_sampler`` are not supported. + 3. Iterating through the dataloader is only allowed from within an epoch + loop (see :mod:`adaptdl.torch.epoch`), and only one dataloader loop is + allowed at any given time. + + Arguments: + dataset (torch.util.data.Dataset): Dataset from which to load the data. + batch_size (int): The target total batch size across all replicas. The + actual total batch size may be different due to rounding (each + replica must have the same local batch size), or being scaled up + using adaptive batch sizes. + shuffle (bool): Whether the data is reshuffled at every epoch. + **kwargs: Keyword arguments passed to ``torch.util.data.Dataloader``. + + Raises: + ValueError: If ``sampler`` or ``batch_sampler`` are not ``None``. + + .. automethod:: __iter__ + """ + def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): + if kwargs.get("batch_sampler") is not None \ + or kwargs.get("sampler") is not None: + raise ValueError("AdaptiveDataLoader does not support " + "custom 'sampler' or 'batch_sampler'") + # Custom sampler is incompatible with shuffle=True, so we always set + # shuffle=False in __init__ and let our own sampler do the shuffling. + kwargs["sampler"] = ElasticSampler(dataset, shuffle=shuffle) + kwargs["worker_init_fn"] = _worker_init_wrapper( + kwargs.get("worker_init_fn"), kwargs.get("num_workers")) + super().__init__(dataset, batch_size, shuffle=False, **kwargs) + AdaptiveDLContext.__init__(self, batch_size) + + def __iter__(self): + """ + Iterate over batches of data. When adaptive batch size is disabled, + stops after the entire dataset has been processed once in total by all + replicas. This means if there are K replicas, then this method will + iterate over ~1/K of the dataset. When adaptive batch size is enabled, + stops after making enough statistical progress roughly equivalent to + one pass over the dataset with non-adaptive batch size. In this case, + the dataset may be processed more than once. + + A checkpoint-restart may be triggered in-between each batch. 
In this + case, the current iteration state will be saved and restored after the + restart, and continue where it left off. + """ + epoch = current_epoch() + num_replicas = adaptdl.env.num_replicas() + with self._elastic.context(): + if self._elastic.skipdone(): + return + done = False + while not done: + self.sampler.set_epoch( + epoch, index=self._elastic.current_index) + self.batch_sampler.batch_size = self.get_batch_size() + for idx, batch in enumerate(super().__iter__()): + with self._elastic.profile(self.training and idx >= 1): + yield batch + # Increment by the number of data samples processed + self._elastic.current_index += \ + num_replicas * self.batch_sampler.batch_size + if self._elastic.max_batch_size is not None and \ + get_progress() >= len(self.dataset) * \ + (epoch + 1) / self.batch_size: + done = True + break + if self._elastic.max_batch_size is None: + done = True + self._elastic.current_index -= \ + self._elastic.current_index % -len(self.dataset) + + +class _AdaptiveDataLoaderState(adaptdl.checkpoint.State): + + # Assume dataloaders are initialized in the same order in every replica. + # Keep a map of epoch -> number of dataloaders initialized so far in that + # epoch, and use that count to construct a unique name for the state. + init_count = collections.Counter() + + def __init__(self): + if current_dataloader() is not None: + raise RuntimeError("dataloader may not be initialized during " + "dataloader iteration") + epoch = current_epoch() + count = _AdaptiveDataLoaderState.init_count[epoch] + super().__init__("adaptdl-dataloader-epoch{}-{}".format(epoch, count)) + _AdaptiveDataLoaderState.init_count[epoch] += 1 + + self.current_index = 0 # Index within the current dataloader loop. + self.end_index = 0 # End index of the current DataLoader loop. + self.last_position = {} # Epoch -> position of last completed loop. + self.current_local_bsz = 0 + self.accumulation_steps = 0 + + def save(self, fileobj): + pickle.dump((self.current_index, self.end_index, + self.last_position), fileobj) + + def load(self, fileobj): + self.current_index, self.end_index, self.last_position = \ + pickle.load(fileobj) diff --git a/adaptdl/adaptdl/torch/torch/data_test.py b/adaptdl/adaptdl/torch/torch/data_test.py new file mode 100644 index 000000000..e4742263f --- /dev/null +++ b/adaptdl/adaptdl/torch/torch/data_test.py @@ -0,0 +1,168 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
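+
+# Note: the elastic_multiprocessing decorator (from adaptdl.conftest) runs
+# each test body in subprocesses; returning an integer N from the test body
+# simulates a checkpoint-restart with N replicas, and num_restarts() reports
+# how many restarts have occurred so far.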
+ + +import collections +import math + +import pytest +import torch +import torchtext +from torch.utils.data import TensorDataset +from torchtext.data.utils import get_tokenizer + +from adaptdl.conftest import elastic_multiprocessing +from adaptdl.torch.data import (ElasticSampler, AdaptiveDataLoader, + current_dataloader) +from adaptdl.torch.iterator import AdaptiveBPTTIterator + + +@pytest.mark.parametrize("num_replicas", [1, 3, 5]) +@pytest.mark.parametrize("dataset_size", [9, 15, 25]) +def test_sampler_epoch(num_replicas, dataset_size, + epoch=0, index=0, shuffle=True): + dataset = TensorDataset(torch.rand(dataset_size)) + sampler = ElasticSampler(dataset, shuffle=shuffle) + sampler.num_replicas = num_replicas + sampler.set_epoch(epoch, index) + epoch_samples = [] + sample_counts = collections.Counter() + for rank in range(num_replicas): + sampler.rank = rank + epoch_samples.append(list(sampler)) + # Check indices are split evenly between replicas. + assert len(sampler) == \ + math.ceil((dataset_size - index % dataset_size) / num_replicas) + # Check the actual samples obey the length. + assert len(sampler) == len(epoch_samples[rank]) + # Check ordering is the same within the same epoch. + assert list(sampler) == epoch_samples[rank] + sample_counts.update(epoch_samples[rank]) + # Check all indices are present. + assert len(sample_counts) >= dataset_size - index % dataset_size + assert all(0 <= key < dataset_size for key in sample_counts) + # Check each index is counted roughly the same number of times. + assert max(sample_counts.values()) - min(sample_counts.values()) <= 1 + return epoch_samples + + +@pytest.mark.parametrize("num_replicas", [1, 3, 5]) +@pytest.mark.parametrize("dataset_size", [9, 15, 25]) +def test_sampler_shuffle(num_replicas, dataset_size): + epoch0_samples = test_sampler_epoch(num_replicas, dataset_size, epoch=0) + epoch1_samples = test_sampler_epoch(num_replicas, dataset_size, epoch=1) + assert epoch0_samples != epoch1_samples # Shuffle is on. + epoch0_samples = test_sampler_epoch(num_replicas, dataset_size, + epoch=0, shuffle=False) + epoch1_samples = test_sampler_epoch(num_replicas, dataset_size, + epoch=1, shuffle=False) + assert epoch0_samples == epoch1_samples # Shuffle is off. + + +@pytest.mark.parametrize("num_replicas", [1, 3, 5]) +@pytest.mark.parametrize("dataset_size", [9, 15, 25]) +def test_sampler_index(num_replicas, dataset_size): + index = dataset_size // 2 # Set index to halfway through the dataset. + epoch_samples = test_sampler_epoch(num_replicas, dataset_size, + index=index, shuffle=False) + samples = sum(epoch_samples, []) + # Check contains second half of dataset. + for idx in range(index, dataset_size): + assert idx in samples + + index = 2 * dataset_size # Test sampler wrap-around. 
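+    # An index beyond the dataset length wraps around modulo the dataset
+    # size, so sampling still covers the full dataset.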
+    epoch_samples = test_sampler_epoch(num_replicas, dataset_size,
+                                       index=index, shuffle=False)
+    assert set(sum(epoch_samples, [])) == set(range(dataset_size))
+
+
+@elastic_multiprocessing
+def test_dataloader_restarts():
+    import adaptdl.checkpoint
+    import adaptdl.collective
+    from adaptdl.env import num_restarts, num_replicas
+    adaptdl.collective.initialize("0.0.0.0")
+    dataset_size = 100
+    init_batch_size = 10
+    dataset = TensorDataset(torch.rand(dataset_size))
+    dataloader = AdaptiveDataLoader(dataset, batch_size=init_batch_size)
+    # Load data samples in the following order:
+    # 2 batches (20 samples) using 1 replica (local_bsz = 10, batch_size = 10)
+    # 5 batches (60 samples) using 4 replicas (local_bsz = 3, batch_size = 12)
+    # 2 batches (20 samples) using 2 replicas (local_bsz = 5, batch_size = 10)
+    assert current_dataloader() is None
+    for idx, batch in enumerate(dataloader):
+        if num_restarts() == 0 and idx == 2:
+            adaptdl.checkpoint.save_all_states()
+            return 4  # Restart with 4 replicas.
+        if num_restarts() == 1 and idx == 5:
+            adaptdl.checkpoint.save_all_states()
+            return 2  # Restart with 2 replicas.
+        assert current_dataloader() is dataloader._elastic
+        local_bsz = batch[0].size(0)
+        assert dataloader.current_local_bsz == local_bsz
+        assert local_bsz == math.ceil(init_batch_size / num_replicas())
+        assert dataloader.current_batch_size == num_replicas() * local_bsz
+    # After the last 2 batches.
+    assert idx == 1
+
+
+@elastic_multiprocessing
+def test_dataloader_break():
+    import adaptdl.checkpoint
+    import adaptdl.collective
+    from adaptdl.env import num_restarts
+    if num_restarts() == 0:
+        return 2
+    adaptdl.collective.initialize("0.0.0.0")
+    dataset = TensorDataset(torch.rand(100))
+    dataloader = AdaptiveDataLoader(dataset, batch_size=10)
+    # Break in the middle of the first for-loop, and start another for-loop.
+    assert current_dataloader() is None
+    for idx, batch in enumerate(dataloader):
+        assert current_dataloader() is dataloader._elastic
+        if idx == 5:
+            break
+    assert current_dataloader() is None
+    for idx, batch in enumerate(dataloader):
+        assert current_dataloader() is dataloader._elastic
+    assert idx == 9  # Run 10 batches total.
+
+
+@elastic_multiprocessing
+def test_bptt_iterator():
+    import adaptdl.checkpoint
+    import adaptdl.collective
+    import adaptdl.env
+    from adaptdl.env import num_restarts
+    adaptdl.collective.initialize("0.0.0.0")
+    # Load the iterator with 500 words.
+    # 1 batch (5x10) using 1 replica. Restart after one iteration.
+    # 9 batches (5x5) using 2 replicas to consume remaining batches.
+    TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
+                                init_token='<sos>',
+                                eos_token='<eos>')
+    fields = [('text', TEXT)]
+    examples = [torchtext.data.Example.fromlist([['The'] * 500], fields)]
+    dataset = torchtext.data.Dataset(examples, fields)
+    TEXT.build_vocab(dataset)
+    bptt_iter = AdaptiveBPTTIterator(dataset, batch_size=10, bptt_len=5)
+    for idx, batch in enumerate(bptt_iter):
+        if num_restarts() == 0 and idx == 1:
+            assert batch.text.shape == (5, 10)
+            adaptdl.checkpoint.save_all_states()
+            return 2
+        if adaptdl.env.num_replicas() == 2:
+            assert batch.text.shape == (5, 5) or batch.text.shape == (4, 5)
+    if adaptdl.env.num_replicas() == 2:
+        assert idx == 8
diff --git a/adaptdl/adaptdl/torch/torch/epoch.py b/adaptdl/adaptdl/torch/torch/epoch.py
new file mode 100644
index 000000000..08b58ede1
--- /dev/null
+++ b/adaptdl/adaptdl/torch/torch/epoch.py
@@ -0,0 +1,178 @@
+# Copyright 2020 Petuum, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This module provides tools for the top-level loop over epochs during training.
+AdaptDL expects the training program to be implemented as a loop over several
+epochs, each containing a series of loops over datasets (e.g. one loop over
+the training set followed by one loop over the validation set). The program
+can be interrupted between iterations of any dataset loop, triggering a
+checkpoint to be taken, and then restarted using a different set of replicas.
+
+**Due to checkpoint-restarts, parts of the training program may be executed
+multiple times (e.g. once after each restart)!** To avoid incorrect execution,
+ensure that your code is idempotent_ in the following locations:
+
+1. Immediately before any epoch loop (using :func:`remaining_epochs_until`).
+2. Immediately before any dataset loop (using
+   :class:`adaptdl.torch.data.AdaptiveDataLoader`).
+
+Your code may be non-idempotent in other locations.
+
+.. code-block:: python
+
+    ### IDEMPOTENT CODE ONLY ###
+
+    for epoch in remaining_epochs_until(30):
+
+        ### IDEMPOTENT CODE ONLY ###
+
+        for batch in train_loader:
+            # ... any code ...
+
+        ### IDEMPOTENT CODE ONLY ###
+
+        for batch in valid_loader:
+            # ... any code ...
+
+        # ... any code ...
+
+    # ... any code ...
+
+    ### END PROGRAM ###
+
+For example, a common non-idempotent operation is learning-rate annealing:
+
+.. code-block:: python
+
+    for epoch in remaining_epochs_until(30):
+
+        lr_scheduler.step()  # (A) WRONG!
+
+        for batch in train_loader:
+            # ...
+
+        lr_scheduler.step()  # (B) WRONG!
+
+        for batch in valid_loader:
+            # ...
+
+        lr_scheduler.step()  # (C) OK!
+
+Location (A) will be executed again after any checkpoint-restart during either
+the training or validation loop, resulting in the learning rate being annealed
+several times in one epoch! Similarly for location (B), if a
+checkpoint-restart happens during the validation loop.
+
+Location (C) results in the correct behavior, because (1) an epoch will not be
+repeated once it has finished, and (2) no checkpoint-restarts can occur between
+the learning rate annealing and the end of the epoch.
+
+.. _idempotent: https://stackoverflow.com/a/1077421
+"""
+
+import logging
+import pickle
+
+import adaptdl.checkpoint
+
+
+logging.basicConfig(level=logging.INFO)
+LOG = logging.getLogger(__name__)
+LOG.setLevel(logging.INFO)
+
+
+def remaining_epochs_until(epoch):
+    """
+    Iterate over epochs in a way that is consistent with checkpoint-restarts.
+    For example:
+
+    .. code-block:: python
+
+        for epoch in remaining_epochs_until(30):
+            print(current_epoch())  # Should print 0 through 29
+
+        for epoch in remaining_epochs_until(60):
+            print(current_epoch())  # Should print 30 through 59
+
+    If a checkpoint-restart happens during an epoch, all previous epochs will
+    be skipped after the program restarts.
+
+    Arguments:
+        epoch (int): The epoch number to end at (exclusive).
+
+    Raises:
+        RuntimeError: If invoked before a previous epoch loop has ended.
+
+ """ + if current_epoch() is not None: + raise RuntimeError("overlapping epoch loops detected") + if finished_epochs() < epoch: + LOG.info("starting at epoch %s", finished_epochs()) + else: + LOG.info("skipping all epochs up to %s", epoch) + while finished_epochs() < epoch: + _epoch_state().current_epoch = finished_epochs() + try: + yield current_epoch() + finally: + # Try to catch any exits from epoch loop, including breaks and + # Exceptions. See https://www.peterbe.com/plog/generatorexit. + _epoch_state().finished_epochs += 1 + _epoch_state().current_epoch = None + + +def current_epoch(): + """ + Get the current epoch while iterating with :func:`remaining_epochs_until`. + + Returns: + int or None: The current epoch number if called from within a + :func:`remaining_epochs_until` iteration, ``None`` otherwise. + """ + return _epoch_state().current_epoch + + +def finished_epochs(): + """ + Get the number of epochs finished using :func:`remaining_epochs_until`. + + Returns: + int: The number of finished epochs. Equal to :func:`current_epoch` + if called from within a :func:`remaining_epochs_until` iteration. + """ + return _epoch_state().finished_epochs + + +class _EpochState(adaptdl.checkpoint.State): + def __init__(self): + super().__init__(".adaptdl-epoch") + self.finished_epochs = 0 + self.current_epoch = None + + def save(self, fileobj): + pickle.dump(self.finished_epochs, fileobj) + + def load(self, fileobj): + self.finished_epochs = pickle.load(fileobj) + + +def _epoch_state(): + global _EPOCH_STATE + if _EPOCH_STATE is None: + _EPOCH_STATE = _EpochState() + adaptdl.checkpoint.load_state(_EPOCH_STATE) + return _EPOCH_STATE + + +_EPOCH_STATE = None diff --git a/adaptdl/adaptdl/torch/torch/epoch_test.py b/adaptdl/adaptdl/torch/torch/epoch_test.py new file mode 100644 index 000000000..464bb240c --- /dev/null +++ b/adaptdl/adaptdl/torch/torch/epoch_test.py @@ -0,0 +1,42 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from adaptdl.conftest import elastic_multiprocessing + + +@elastic_multiprocessing +def test_epoch(): + import adaptdl.checkpoint + from adaptdl.env import num_restarts + from adaptdl.torch.epoch import (remaining_epochs_until, + current_epoch, finished_epochs) + total_epochs = 10 + restart_epoch = 5 + assert current_epoch() is None + if num_restarts() == 0: + assert finished_epochs() == 0 + expected_epochs = list(range(restart_epoch + 1)) + elif num_restarts() == 1: + assert finished_epochs() == restart_epoch + expected_epochs = list(range(restart_epoch, total_epochs)) + else: + assert False + for idx, epoch in enumerate(remaining_epochs_until(10)): + assert epoch == expected_epochs[idx] + assert current_epoch() == epoch + assert finished_epochs() == epoch + if num_restarts() == 0 and epoch == restart_epoch: + adaptdl.checkpoint.save_all_states() + return 5 # Restart with 5 replicas. 
diff --git a/adaptdl/adaptdl/torch/torch/gradient_noise_scale.py b/adaptdl/adaptdl/torch/torch/gradient_noise_scale.py
new file mode 100644
index 000000000..2687644b9
--- /dev/null
+++ b/adaptdl/adaptdl/torch/torch/gradient_noise_scale.py
@@ -0,0 +1,330 @@
+import functools
+import logging
+import math
+import numpy as np
+import torch.distributed
+import torch.optim
+
+from torch.autograd import Variable
+
+import adaptdl.utils
+
+__all__ = ["GradientNoiseScale"]
+
+logging.basicConfig(level=logging.INFO)
+LOG = logging.getLogger(__name__)
+LOG.setLevel(logging.INFO)
+
+
+def _average_groups(grads1, grads2):
+    ret = []
+    for group1, group2 in zip(grads1, grads2):
+        ret.append([])
+        for g1, g2 in zip(group1, group2):
+            if g1 is None:
+                ret[-1].append(g2)
+            elif g2 is None:
+                ret[-1].append(g1)
+            else:
+                ret[-1].append((g1 + g2) / 2)
+    return ret
+
+
+def _normsqr_groups(grads, pinvs):
+    ret = []
+    for group, pinv_group in zip(grads, pinvs):
+        normsqr = [(g / pinv).pow(2).sum(dtype=torch.float64)
+                   for g, pinv in zip(group, pinv_group) if g is not None]
+        ret.append(sum(normsqr).item() if normsqr else 0.0)
+    return np.array(ret)
+
+
+class GradientNoiseScale(object):
+    """This class tracks gradient-related stats and takes care of gradient
+    accumulation."""
+    def __init__(self, adp, optimizer,
+                 mp_scaler=None,
+                 num_replicas=None,
+                 accum_scale=None):
+        self._adp = adp
+        self._optimizer = optimizer
+        self._orig_optimizer_zero_grad = optimizer.zero_grad
+        self._should_zero_grad = True
+        self._mp_scaler = mp_scaler
+        self._local_sqr = None
+        self._num_replicas = (num_replicas if num_replicas is not None
+                              else torch.distributed.get_world_size())
+        self._accum_scale = accum_scale or self._num_replicas
+        self._prev_grads = None
+
+        self.reset_accumulation()
+
+        self._optimizer.state.setdefault("gns", {
+            "progress": 0.0,
+            "prev_scale": 0.0,
+            # Running averages of the squared-norm and variance estimates.
+            "sqr_avg": np.ones(len(optimizer.param_groups)),
+            "var_avg": np.zeros(len(optimizer.param_groups)),
+            # Whether estimates are biased (using differenced estimator).
+            "biased": False,
+        })
+
+        for idx, param_group in enumerate(self._optimizer.param_groups):
+            for param in param_group["params"]:
+                param.register_hook(
+                    functools.partial(self._backward_hook, idx, param))
+        self._callback_queued = False
+        self._smoothing = 0.999
+
+    @property
+    def _state(self):
+        return self._optimizer.state["gns"]
+
+    def reset_accumulation(self):
+        """Reset accumulation calculations and gradients."""
+        self._orig_optimizer_zero_grad()
+        self._local_sqr = None
+        self._accum_count = 0
+
+    @property
+    def should_zero_grad(self):
+        return self._should_zero_grad
+
+    @property
+    def accum_scale(self):
+        return self._accum_scale
+
+    @property
+    def accum_count(self):
+        return self._accum_count
+
+    def set_accum_scale(self, accum_scale):
+        if not np.isclose(self._accum_scale, accum_scale):
+            self.reset_accumulation()
+            self._accum_scale = accum_scale
+
+    @property
+    def raw_sqr_avg(self):
+        view = self._state["sqr_avg"].view()
+        view.flags.writeable = False
+        return view
+
+    def sqr_avg(self):
+        """
+        Current estimate of the squared l2-norm of the true gradient (mu
+        squared in the AdaScale paper).
+
+        Returns (float): Estimate of squared l2-norm.
+        """
+        return float(np.sum(np.maximum(self._state["sqr_avg"], 0.0)))
+
+    @property
+    def raw_var_avg(self):
+        view = self._state["var_avg"].view()
+        view.flags.writeable = False
+        return view
+
+    def var_avg(self):
+        """
+        Current estimate of the trace of the covariance of the true gradient
+        (sigma squared in the AdaScale paper).
+ + Returns (float): Estimate of trace of the covariance. + """ + return float(np.sum(np.maximum(self._state["var_avg"], 1e-6))) + + def get_progress(self): + return self._state["progress"] + + def set_progress(self, progress): + self._state["progress"] = progress + + def gain(self, scale): + """ + Current estimate of the GradientNoiseScale gain ratio. + + Arguments: + scale (float): The total scale to estimate the gain ratio for. + + Returns (float): Estimate of gain ratio. + """ + var = self.var_avg() + norm = self.sqr_avg() + return (var + norm) / (var / scale + norm) + + def _update_avg(self, param_name, value, factor): + biased = self._state.get(param_name + "_biased", 0.0) + unbias = self._state.get(param_name + "_unbias", 0.0) + biased = factor * biased + (1.0 - factor) * value + unbias = factor * unbias + (1.0 - factor) + self._state[param_name + "_biased"] = biased + self._state[param_name + "_unbias"] = unbias + self._state[param_name] = biased / unbias + + def _reset_avg(self, param_name): + self._state.pop(param_name + "_biased", None) + self._state.pop(param_name + "_unbias", None) + + @adaptdl.utils.print_exc + def _backward_hook(self, idx, param, grad): + # This method should be invoked once for each parameter during the + # backward pass, before gradients are synchronized between replicas. + if self._local_sqr is None: + self._local_sqr = torch.zeros(len(self._optimizer.param_groups), + device=grad.device, + dtype=torch.float64) + + # Get the preconditioning matrix for the optimizer + preconditioner = self._calculate_preconditioner(idx, param) + + # Update the local gradient square sum + self._local_sqr[idx] += \ + (grad.detach() / preconditioner).pow(2).sum(dtype=torch.float64) + if not self._callback_queued: + Variable._execution_engine.queue_callback(self._queue_callback) + self._callback_queued = True + + @adaptdl.utils.print_exc + def _queue_callback(self): + # This method should be invoked after the entire backward pass. We want + # to make sure self._final_callback is invoked once, only after all + # gradients have been synchronized between each replica. However, the + # synchronization code in DistributedDataParallel is also done in a + # callback, which might not yet be executed. Therefore, we enqueue + # self._final_callback from this method, which should ensure it is + # invoked after the gradient synchronization callback. + self._callback_queued = False + self._accum_count += 1 + if self._adp.require_backward_grad_sync: + # Asynchronously sum the local squared-gradient statistics. The + # actual gradient averaging should also be happening at the same + # time, until self._final_callback is invoked. + if self._num_replicas > 1: + self._async_op = torch.distributed.all_reduce(self._local_sqr, + async_op=True) + Variable._execution_engine.queue_callback(self._final_callback) + self._should_zero_grad = True + else: + # Keep on accumulating gradients, should not zero grad. + self._should_zero_grad = False + + @adaptdl.utils.print_exc + def _final_callback(self): + # This method should be invoked once the gradients have been + # synchronized between all replicas and accumulation steps. 
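+        # At this point param.grad on each replica holds the synchronized
+        # (averaged) gradient, and self._local_sqr holds the sum over
+        # replicas of each replica's squared local gradient norm; the
+        # noise-scale statistics below are estimated from these two
+        # quantities.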
+        if self._num_replicas > 1:
+            self._async_op.wait()
+        grads = []
+        if self._mp_scaler is not None:
+            mixed_precision_scale = self._mp_scaler.get_scale()
+        else:
+            mixed_precision_scale = 1.0
+        for group in self._optimizer.param_groups:
+            grads.append([])
+            for param in group["params"]:
+                if param.grad is None:
+                    grads[-1].append(None)
+                    continue
+                grad = param.grad.detach().float()
+                grads[-1].append(
+                    grad / mixed_precision_scale / self._accum_count)
+        preconditioner = self._get_preconditioner()
+
+        # Note: mixed precision can result in nan/inf gradients,
+        # which propagate into our norm and variance estimates.
+        # Mixed precision autoscaling skips the optimizer step when
+        # there are nan/inf values, so we also skip the update here.
+        grads_normsqr = _normsqr_groups(grads, preconditioner)
+        if not np.all(np.isfinite(grads_normsqr)):
+            LOG.warning("GradientNoiseScale detected invalid gradient! "
+                        "Skipping step.")
+            return
+        count = self._num_replicas * self._accum_count
+        scale = self._accum_scale * self._accum_count
+        if count > 1:
+            # Average local squared-norm samples.
+            local_sqr = self._local_sqr.cpu().numpy() / count
+            # Gradient is squared in local_sqr, so need to square the
+            # mixed precision scale as well.
+            local_sqr = (local_sqr / mixed_precision_scale ** 2)
+            total_sqr = grads_normsqr
+            if self._state["biased"]:
+                self._reset_avg("sqr_avg")
+                self._reset_avg("var_avg")
+            self._state["biased"] = False
+            self._prev_grads = None
+        else:
+            # Single gradient datapoint, use difference estimation.
+            if self._prev_grads is not None:
+                local_sqr = (_normsqr_groups(self._prev_grads, preconditioner)
+                             + grads_normsqr) / 2
+                avg_grads = _average_groups(grads, self._prev_grads)
+                total_sqr = _normsqr_groups(avg_grads, preconditioner)
+                count = 2
+                scale = 2 * self._accum_scale
+            self._state["biased"] = True
+            self._prev_grads = [[g.clone() if g is not None else None
+                                 for g in group] for group in grads]
+        if count > 1:
+            grad_sqr = (count * total_sqr - local_sqr) / (count - 1)
+            grad_var = (local_sqr - total_sqr) * scale / (count - 1)
+            theta = self._smoothing ** scale
+            self._update_avg('sqr_avg', grad_sqr, theta)
+            self._update_avg('var_avg', grad_var, theta)
+
+    def _get_preconditioner(self):
+        out = []
+        for idx, group in enumerate(self._optimizer.param_groups):
+            pinvs = []
+            for param in group["params"]:
+                pinv = self._calculate_preconditioner(idx, param)
+                pinvs.append(pinv)
+            out.append(pinvs)
+        return out
+
+    def _calculate_preconditioner(self, idx, param):
+        return torch.ones_like(param, memory_format=torch.preserve_format)
+
+
+class AdamGradientNoiseScale(GradientNoiseScale):
+    def __init__(self, adp, optimizer,
+                 mp_scaler=None,
+                 num_replicas=None,
+                 accum_scale=None):
+        self._adam_param_group = {'beta': [], 'eps': []}
+        super().__init__(adp, optimizer, mp_scaler, num_replicas, accum_scale)
+        for idx, param_group in enumerate(self._optimizer.param_groups):
+            self._adam_param_group['beta'].append(param_group['betas'][1])
+            self._adam_param_group['eps'].append(param_group['eps'])
+
+    def _calculate_preconditioner(self, idx, param):
+        state = self._optimizer.state[param]
+        if state.get('step', 0) < 5:
+            return torch.ones_like(param,
+                                   memory_format=torch.preserve_format)
+        # Defensive copy: the in-place add_() below operates on the new
+        # tensor returned by sqrt(), not on the optimizer state.
+        exp_avg_sq = state["exp_avg_sq"].clone()
+        beta2 = self._adam_param_group['beta'][idx]
+        eps = self._adam_param_group['eps'][idx]
+        correction = 1 - beta2 ** state['step']
+        pinv = (exp_avg_sq.sqrt() / math.sqrt(correction)).add_(eps)
+        return pinv
+
+    def _reset_adam_state(self, step=0):
+        # Rescale Adam's moment estimates as if only `step` steps had been
+        # taken, keeping the bias-correction factors consistent when the
+        # batch-size scale changes.
+        for group in self._optimizer.param_groups:
+            beta1, beta2 = group["betas"]
+            for param in group["params"]:
+                state = self._optimizer.state[param]
+                if state.get("step", 0) > 0:
+                    state["exp_avg"].mul_(
+                        (1 - beta1 ** step) / (1 - beta1 ** state["step"]))
+                    state["exp_avg_sq"].mul_(
+                        (1 - beta2 ** step) / (1 - beta2 ** state["step"]))
+                    state["step"] = step
+
+    def _final_callback(self):
+        scale = self._accum_scale * self._accum_count
+        if not np.isclose(scale, self._state["prev_scale"]):
+            # Reset the Adam state whenever the scale changes.
+            self._reset_adam_state()
+            self._state["prev_scale"] = scale
+        return super()._final_callback()
diff --git a/adaptdl/adaptdl/torch/torch/gradient_noise_scale_test.py b/adaptdl/adaptdl/torch/torch/gradient_noise_scale_test.py
new file mode 100644
index 000000000..5f5112a36
--- /dev/null
+++ b/adaptdl/adaptdl/torch/torch/gradient_noise_scale_test.py
@@ -0,0 +1,60 @@
+import numpy as np
+import pytest
+import torch
+import random
+
+from unittest.mock import Mock
+
+from adaptdl.torch.gradient_noise_scale import GradientNoiseScale
+
+
+def test_object():
+    params = [torch.tensor([[1., -1.], [2., 3.]], requires_grad=True),
+              torch.tensor([[2., 3.]], requires_grad=True)]
+    sgd = torch.optim.SGD(params, lr=0.1)
+    adp = Mock(require_backward_grad_sync=True)
+    obj = GradientNoiseScale(adp, sgd, accum_scale=1.0, num_replicas=1)
+    assert(obj._accum_scale == 1.0)
+    obj._num_replicas = 8
+    obj.set_accum_scale(3.0)
+    assert(obj.accum_scale == 3.0)
+    obj._num_replicas = 4
+    obj.set_accum_scale(3.0)
+    assert(obj.accum_scale == 3.0)
+    assert(np.isclose(obj.gain(2.0), 1.0))
+    obj._state['var_avg'] = 3.0
+    obj._state['sqr_avg'] = 1.0
+    assert(np.isclose(obj.gain(3.0), 2.0))
+
+
+ATOL = 0.01
+
+
+def test_nan():
+    def nan_objective(tensor):
+        if random.random() > 0.5:
+            target = float("Nan")
+        else:
+            target = 4.0
+        return (tensor - target)**2
+
+    params_t = torch.Tensor([1.0])
+    params = torch.autograd.Variable(params_t, requires_grad=True)
+    sgd = torch.optim.SGD([params], lr=0.1)
+    adp = Mock(require_backward_grad_sync=True)
+    gns = GradientNoiseScale(adp, sgd, accum_scale=1.0, num_replicas=1)
+    adp.gns = gns
+    for i in range(100):
+        gns.reset_accumulation()
+        loss = nan_objective(params)
+        loss.backward()
+        if np.all(np.isfinite(loss.detach().numpy())):
+            sgd.step()
+        if params.allclose(torch.tensor([4.0]), atol=ATOL):
+            break
+    else:
+        pytest.fail(f"Did not converge: {params}")
+    if not (np.all(np.isfinite(gns.sqr_avg())) and
+            np.all(np.isfinite(gns.var_avg()))):
+        pytest.fail(f"non-finite AdaScale parameters: "
+                    f"{gns.sqr_avg()}, {gns.var_avg()}")
diff --git a/adaptdl/adaptdl/torch/torch/iterator.py b/adaptdl/adaptdl/torch/torch/iterator.py
new file mode 100644
index 000000000..337292ea6
--- /dev/null
+++ b/adaptdl/adaptdl/torch/torch/iterator.py
@@ -0,0 +1,121 @@
+# Copyright 2020 Petuum, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ + +import math +import logging + +from torchtext.data import BPTTIterator +from torchtext.data.dataset import Dataset +from torchtext.data.batch import Batch + +import adaptdl.checkpoint +import adaptdl.collective +import adaptdl.env +from adaptdl.torch.data import AdaptiveDataLoaderMixin + +logging.basicConfig(level=logging.INFO) +LOG = logging.getLogger(__name__) +LOG.setLevel(logging.INFO) + + +class AdaptiveBPTTIterator(BPTTIterator, AdaptiveDataLoaderMixin): + def __init__(self, dataset, batch_size, bptt_len, **kwargs): + max_batch_size = kwargs.pop("max_batch_size", None) + local_bsz_bounds = kwargs.pop("local_bsz_bounds", None) + + BPTTIterator.__init__(self, dataset=dataset, batch_size=batch_size, + bptt_len=bptt_len, **kwargs) + AdaptiveDataLoaderMixin.__init__(self, batch_size) + + self.num_replicas = adaptdl.env.num_replicas() + self.rank = adaptdl.env.replica_rank() + + if max_batch_size and local_bsz_bounds: + self._elastic.autoscale_batch_size(max_batch_size, + local_bsz_bounds) + + # The start index changes when there is a rescaling. We recompute a new + # start index based on how much we covered before the restart + def _recompute_start(self, prev_curr, prev_end, curr_end): + if prev_end == 0: + return prev_curr + return math.ceil(prev_curr * curr_end / prev_end) + + def __iter__(self): + with self._elastic.context(): + if self._elastic.skipdone(): + return + + self.batch_size = self._elastic._sync_local_bsz() + + text = self.dataset[0].text + TEXT = self.dataset.fields['text'] + TEXT.eos_token = None + text = text + ([TEXT.pad_token] * + int(math.ceil(len(text) / self.batch_size) * + self.batch_size - len(text))) + data = TEXT.numericalize( + [text], device=self.device) + data = data.view(self.batch_size, -1).t().contiguous() + dataset = Dataset(examples=self.dataset.examples, fields=[ + ('text', TEXT), ('target', TEXT)]) + end = data.size(0) # current length of dataset + + # Change in current batch size changes the dimensions of dataset + # which changes the starting position in the reshaped dataset. The + # local batch size is also a function of number of replicas, so + # when we rescale we need to recalculate where to start the + # iterations from for the new batch size. + self._elastic.current_index = \ + self._recompute_start(self._elastic.current_index, + self._elastic.end_index, end) + self._elastic.end_index = end + + # Every replica reads data strided by bptt_len + start = self._elastic.current_index + (self.bptt_len * self.rank) + step = self.bptt_len * self.num_replicas + + # The starting index of the highest rank + highest_start = self._elastic.current_index + \ + (self.bptt_len * (self.num_replicas - 1)) + + # Number of steps we will take on the highest rank. We limit + # iterations on all replicas by this number to prevent asymmetric + # reduce ops which would result in a deadlock. 
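+            # A worked example with hypothetical numbers: if end=100,
+            # bptt_len=5, num_replicas=2 and current_index=0, then rank 0
+            # reads offsets 0, 10, 20, ... and rank 1 reads 5, 15, 25, ...,
+            # so min_steps_in_epoch = ceil((100 - 5) / 10) = 10.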
+ min_steps_in_epoch = max(math.ceil((end - highest_start) / step), 0) # noqa: E501 + self.iterations = 0 + while True: + for i in range(start, end, step): + self.iterations += 1 + # Make sure that _elastic.profile is called equal number of + # times on all replicas + if self.iterations > min_steps_in_epoch: + break + with self._elastic.profile(self.training and i > 0): + seq_len = min(self.bptt_len, data.size(0) - i - 1) + assert seq_len > 0 + batch_text = data[i:i + seq_len] + batch_target = data[i + 1:i + 1 + seq_len] + if TEXT.batch_first: + batch_text = batch_text.t().contiguous() + batch_target = batch_target.t().contiguous() + yield Batch.fromvars( + dataset, self.batch_size, + text=batch_text, + target=batch_target) + self._elastic.current_index += step + + if not self.repeat: + break diff --git a/adaptdl/adaptdl/torch/torch/parallel.py b/adaptdl/adaptdl/torch/torch/parallel.py new file mode 100644 index 000000000..218ae2981 --- /dev/null +++ b/adaptdl/adaptdl/torch/torch/parallel.py @@ -0,0 +1,232 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools +import numpy as np +import time +import warnings +from typing import Optional + +import torch +import torch.cuda +import torch.distributed +from torch.autograd import Variable +from torch.nn.parallel import DistributedDataParallel + +import adaptdl.checkpoint +import adaptdl.env +import adaptdl.utils +from adaptdl.torch.data import current_dataloader +from adaptdl.torch.scaling_rules import AdaScale, AdamScale, ScalingRuleBase +from adaptdl.torch.gradient_noise_scale import GradientNoiseScale,\ + AdamGradientNoiseScale +from adaptdl.torch._metrics import profile_sync_time, update_grad_params,\ + update_progress + + +class AdaptiveDataParallel(DistributedDataParallel): + """ + This class extends PyTorch DistributedDataParallel with support for + adaptive batch sizes and checkpoint-restart elasticity. It automatically + saves the given model, optimizer, and (optionally) LR scheduler whenever a + checkpoint is triggered, and restores their states after restart. The + optimizer is automatically patched with the chosen scaling rule. + + Arguments: + model (torch.nn.Module): Model to be distributed. + optimizer (torch.optim.Optimizer): Optimizer used to update the given + model's parameters, will be patched using subclass of + :class:`adaptdl.torch.scaling_rules.ScalingRuleBase`. + scaling_rule (ScalingRuleBase): Scaling rule used to + patch the given optimizer, default to AdaScale. + lr_scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used + to anneal the learning rate for the given optimizer. + name (string): Unique name for each instance of this class, needed only + if multiple instances exist. 
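+
+    A minimal usage sketch (``optimizer``, ``lr_scheduler``, ``criterion``
+    and ``train_loader`` are placeholders for the user's own objects):
+
+    .. code-block:: python
+
+        adaptdl.torch.init_process_group("gloo")
+        model = AdaptiveDataParallel(model, optimizer, lr_scheduler)
+
+        for epoch in remaining_epochs_until(30):
+            for inputs, targets in train_loader:  # an AdaptiveDataLoader
+                optimizer.zero_grad()
+                loss = criterion(model(inputs), targets)
+                loss.backward()
+                optimizer.step()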
+ """ + def __init__(self, model, optimizer, lr_scheduler=None, mp_scaler=None, + scaling_rule: Optional[ScalingRuleBase] = None, + name="adaptdl-dataparallel", **kwargs): + super().__init__(model, **kwargs) + self._key = id(self) + # Register backward hooks on model parameters. Depends on these hooks + # being invoked before gradients are averaged. This is technically an + # internal behavior of DistributedDataParallel, but seems to be abused + # pretty widely so there should be little chance of it changing. + # https://discuss.pytorch.org/t/59291 + for param in model.parameters(): + param.register_hook(functools.partial(self._backward_hook, param)) + + # Setup for the scaling_rule, must be after registering backward hooks + # because some of them need to register their own backward hooks. + if not scaling_rule and (isinstance(optimizer, torch.optim.Adam) or + isinstance(optimizer, torch.optim.AdamW)): + self.scaling_rule = AdamScale() + else: + self.scaling_rule = scaling_rule or AdaScale() + + if isinstance(scaling_rule, AdamScale): + self.gns = AdamGradientNoiseScale(self, optimizer, + mp_scaler=mp_scaler) + else: + self.gns = GradientNoiseScale(self, optimizer, mp_scaler=mp_scaler) + self.scaling_rule.initialize(self, optimizer, patch_optimizer=True) + + self._state = _AdaptiveDataParallelState( + model, optimizer, lr_scheduler, mp_scaler, name) + adaptdl.checkpoint.load_state(self._state) + + self._sync_start = None + + def forward(self, *args, **kwargs): + # Do not do gradient synchronization during gradient accumulation. + dataloader = current_dataloader() + if dataloader is not None and dataloader.training: + self.require_backward_grad_sync = dataloader.is_optim_step() + accum_scale = (dataloader.current_local_bsz * + adaptdl.env.num_replicas() / dataloader.batch_size) + self.gns.set_accum_scale(accum_scale) + return super().forward(*args, **kwargs) + + @adaptdl.utils.print_exc + def _backward_hook(self, param, grad): + # This method should be invoked once for each parameter during the + # backward pass, before gradients are synchronized between replicas. + if grad.device.type.startswith("cuda"): + self._sync_start = torch.cuda.Event(enable_timing=True) + self._sync_start.record() + else: + self._sync_start = time.time() + self._final_callback_queued = False + Variable._execution_engine.queue_callback(self._queue_callback) + + @adaptdl.utils.print_exc + def _queue_callback(self): + # This method should be invoked after the entire backward pass. We want + # to make sure self._final_callback is invoked once, only after all + # gradients have been synchronized between each replica. However, the + # synchronization code in DistributedDataParallel is also done in a + # callback, which might not yet be executed. Therefore, we enqueue + # self._final_callback from this method, which should ensure it is + # invoked after the gradient synchronization callback. + if self._final_callback_queued: + return + self._final_callback_queued = True + Variable._execution_engine.queue_callback(self._final_callback) + + @adaptdl.utils.print_exc + def _final_callback(self): + # This method should be invoked once for each backward pass, after + # gradients have been synchronized between each replica. + self._final_callback_queued = False + # self._sync_start should mark the last time any local gradient + # from this module was produced. We assume the duration until now was + # primarily spent waiting for gradient synchronization. 
+ # TODO: Depends on the internal behavior of DistributedDataParallel, + # which might break with future versions of PyTorch. Any better + # and well-supported way to measure the synchronization time? + if isinstance(self._sync_start, torch.cuda.Event): + sync_end = torch.cuda.Event(enable_timing=True) + sync_end.record() + sync_end.synchronize() + profile_sync_time(self._sync_start.elapsed_time(sync_end) / 1e3) + else: + profile_sync_time(time.time() - self._sync_start) + + dataloader = current_dataloader() + if dataloader is None: + # Don't allow backpropagation outside of a dataloader loop, because + # the batch size would be unknown. + raise RuntimeError("backpropagation outside AdaptiveDataLoader") + dataloader.train() + + scale = dataloader.current_batch_size / dataloader.batch_size + self._state.gain = self.gns.gain(scale) + self._state.lr_factor = \ + np.average(self.scaling_rule.scale_lr(scale)) + update_progress(self.gns.get_progress()) + if dataloader.max_batch_size and \ + dataloader.max_batch_size > dataloader.batch_size: + update_grad_params(self._key, self.gns.sqr_avg(), + self.gns.var_avg()) + self._sync_start = None + + def zero_grad(self, *args, **kwargs): + warnings.warn("zero_grad has no effect with AdaptiveDataParallel") + + @property + def gain(self): # TODO: should be tracked in the metrics module instead. + """ + Current estimate of the AdaScale gain (r_t) value. + """ + return self._state.gain + + def to_tensorboard(self, writer, global_step, tag_prefix=""): + """ + Output some useful metrics to TensorBoard. + + Arguments: + writer (torch.utils.tensorboard.SummaryWriter): ``SummaryWriter`` + object to output metrics to. + global_step (int): Global step value to record. + tag_prefix (str): Prefix added to each metric's tag. + """ + if tag_prefix and not tag_prefix.endswith("/"): + tag_prefix += "/" + writer.add_scalar(tag_prefix + "Gradient_Norm_Sqr", + self.gns.sqr_avg(), global_step) + writer.add_scalar(tag_prefix + "Gradient_Variance", + self.gns.var_avg(), global_step) + writer.add_scalar(tag_prefix + "Gain", + self._state.gain, global_step) + writer.add_scalar(tag_prefix + "Learning_Rate_Factor", + self._state.lr_factor, global_step) + + +class _AdaptiveDataParallelState(adaptdl.checkpoint.State): + def __init__(self, model, optimizer, lr_scheduler, mp_scaler, + name="adaptdl-dataparallel"): + super().__init__(name) + self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.mp_scaler = mp_scaler + # TODO: Gain/goodput should be tracked in the metrics module instead. 
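+        # gain holds the most recent AdaScale gain estimate (r_t) and
+        # lr_factor the average learning-rate factor applied at the last
+        # optimizer step; both are checkpointed so they survive restarts.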
+ self.gain = 1.0 + # lr_factor summary + self.lr_factor = 1.0 + + def save(self, fileobj): + state_dicts = [self.model.state_dict(), self.optimizer.state_dict()] + + if self.lr_scheduler is not None: + state_dicts.append(self.lr_scheduler.state_dict()) + else: + state_dicts.append(None) + + if self.mp_scaler is not None: + state_dicts.append(self.mp_scaler.state_dict()) + else: + state_dicts.append(None) + torch.save((state_dicts, self.gain, self.lr_factor), fileobj) + + def load(self, fileobj): + state_dicts, self.gain, self.lr_factor = torch.load(fileobj) + self.model.load_state_dict(state_dicts[0]) + self.optimizer.load_state_dict(state_dicts[1]) + if state_dicts[2] is not None: + self.lr_scheduler.load_state_dict(state_dicts[2]) + if state_dicts[3] is not None: + self.mp_scaler.load_state_dict(state_dicts[3]) diff --git a/adaptdl/adaptdl/torch/torch/parallel_test.py b/adaptdl/adaptdl/torch/torch/parallel_test.py new file mode 100644 index 000000000..3e67eb551 --- /dev/null +++ b/adaptdl/adaptdl/torch/torch/parallel_test.py @@ -0,0 +1,67 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch + +from torch.utils.data import Dataset +import adaptdl.torch as adl + + +class LRIterableDataset(Dataset): + def __init__(self, size, true_values, noise): + input_values = np.random.uniform(-5.0, 5.0, size) + bias_input_values = np.stack([np.ones(size), input_values]) + target_values = ( + np.dot(true_values, bias_input_values) + + np.random.normal(0.0, noise, size=(size,))) + self._values = list(zip(input_values, target_values)) + self._len = size + + def __getitem__(self, index): + return self._values[index] + + def __len__(self): + return self._len + + +def test_single_replica_parallel(): + adl.init_process_group("gloo") + true_values = np.asarray([3.0, 4.0]) + dataset = LRIterableDataset(1000, true_values, 1.0) + dataloader = adl.AdaptiveDataLoader( + dataset, batch_size=32, shuffle=False, num_workers=1) + model = torch.nn.Linear(1, 1, bias=True) + params = [model.bias, model.weight] + sgd = torch.optim.SGD( + [{"params": [param]} for param in params], + lr=0.01) + schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, [50]) + model = adl.AdaptiveDataParallel(model, sgd, schedule) + loss = torch.nn.MSELoss() + for epoch in adl.remaining_epochs_until(100): + for inputs, targets in dataloader: + inputs = inputs.float() + targets = targets.float() + sgd.zero_grad() + output = model(torch.reshape(inputs, (-1, 1))) + targets = torch.reshape(targets, (-1, 1)) + loss_value = loss(output, targets) + loss_value.backward() + sgd.step() + schedule.step() + params = np.asarray([param.item() for param in params]) + assert(np.all(np.isclose(params, true_values, atol=0.1))), \ + (params, true_values) diff --git a/adaptdl/adaptdl/torch/torch/scaling_rules.py b/adaptdl/adaptdl/torch/torch/scaling_rules.py new file mode 100644 index 000000000..ac0f2a1c0 --- /dev/null +++ b/adaptdl/adaptdl/torch/torch/scaling_rules.py @@ -0,0 +1,200 
@@
+# Copyright 2020 Petuum, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import math
+import numpy as np
+import warnings
+
+from types import MethodType
+
+from adaptdl.torch.data import current_dataloader
+
+
+__all__ = ["ScalingRuleBase", "AdaScale", "LinearScale", "SqrtScale",
+           "LEGWScale"]
+
+
+class ScalingRuleBase(object):
+    """
+    Base class for scaling rules that have the ability to track gradient
+    noise scale calculations. Its subclasses can be used in combination with
+    ``adaptdl.torch.parallel.AdaptiveDataParallel`` and ``torch.optim.SGD``.
+
+    .. code-block:: python
+
+        optim = torch.optim.SGD(model, lr=0.001)
+        adascale = AdaScale()
+        model = AdaptiveDataParallel(model, optim, adascale)
+
+        for epoch in ...:
+            for batch in ...:
+                optim.zero_grad()
+                loss = ...
+                loss.backward()
+                adascale.step()
+    """
+
+    _adaptlr = None
+
+    def __init__(self):
+        # Instance of AdaptiveDataParallel; needs to be set before any of the
+        # methods can be used.
+        self.adp = None
+        self._optimizer = None
+        self._orig_optimizer_step = None
+
+    def scale_lr(self, scale):
+        raise NotImplementedError
+
+    def zero_grad(self, *args, **kwargs):
+        if self.adp.gns.should_zero_grad:
+            self.adp.gns.reset_accumulation(*args, **kwargs)
+        else:
+            warnings.warn("skipping zero_grad for accumulated gradient")
+
+    def step(self, *args, **kwargs):
+        """
+        Run one optimizer step. Essentially just invokes
+        ``optimizer.step(*args, **kwargs)`` with a scaled learning rate.
+
+        Arguments:
+            args: Positional arguments passed to ``optimizer.step``.
+            kwargs: Keyword arguments passed to ``optimizer.step``.
+        """
+        if not self.adp:
+            raise ValueError("AdaptiveDataParallel instance is not set!")
+        if not self.adp.require_backward_grad_sync:
+            return
+        scale = self.adp.gns.accum_scale * self.adp.gns.accum_count
+        initial_lr = [pg["lr"] for pg in self._optimizer.param_groups]
+        scaled_lr = np.multiply(self.scale_lr(scale), initial_lr)
+        ScalingRuleBase._adaptlr = scaled_lr
+        for lr, pg in zip(scaled_lr, self._optimizer.param_groups):
+            pg["lr"] = lr
+        self._orig_optimizer_step(*args, **kwargs)
+        for lr, pg in zip(initial_lr, self._optimizer.param_groups):
+            pg["lr"] = lr
+        self.adp.gns.set_progress(self.adp.gns.get_progress() +
+                                  self.adp.gns.gain(scale))
+
+    def _patch_optimizer(self):
+        """
+        Monkey-patch the optimizer's step function with
+        :meth:`ScalingRuleBase.step`.
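+
+        After patching, plain ``optimizer.step()`` and
+        ``optimizer.zero_grad()`` calls made by an existing training loop are
+        routed through this scaling rule, so the loop itself needs no
+        changes.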
+ """ + @functools.wraps(self._optimizer.step) + def step_wrapper(optim, *args, **kwargs): + return self.step(*args, **kwargs) + + @functools.wraps(self._optimizer.zero_grad) + def zero_wrapper(optim, *args, **kwargs): + return self.zero_grad(*args, **kwargs) + self._optimizer.step = MethodType(step_wrapper, self._optimizer) + self._optimizer.zero_grad = MethodType(zero_wrapper, self._optimizer) + + def initialize(self, adp, optimizer, patch_optimizer=False): + self.adp = adp + self._optimizer = optimizer + self._orig_optimizer_step = optimizer.step + if patch_optimizer: + self._patch_optimizer() + + @staticmethod + def _get_adapt_lr_scale(): + return ScalingRuleBase._adaptlr + + +class AdaScale(ScalingRuleBase): + """ + Implements the AdaScale_ algorithm for scaling the learning rate for + distributed and large batch size training. + + .. _AdaScale: https://proceedings.icml.cc/static/paper_files/icml/2020/4682-Supplemental.pdf + """ # noqa: E501 + + def scale_lr(self, scale): + """Calculate factors to be applied to lr for each parameter group.""" + var = self.adp.gns.raw_var_avg + sqr = self.adp.gns.raw_sqr_avg + var = np.maximum(var, 1e-6) + sqr = np.maximum(sqr, 0.0) + return (var + sqr) / (var / scale + sqr) + + +class AdamScale(AdaScale): + """ + Implements the variant of AdaScale_ that supports Adam, AdamW and RMSProp + """ + + def scale_lr(self, scale, power=0.5): + return np.power(super().scale_lr(scale=scale), power) + + +class LinearScale(ScalingRuleBase): + + def scale_lr(self, scale): + return scale + + +class SqrtScale(ScalingRuleBase): + + def scale_lr(self, scale): + return math.sqrt(scale) + + +class LEGWScale(ScalingRuleBase): + """ + Implements the LEGWScale algorithm for scaling the learning rate. + + Essentially, with LEGWScale, lr_factor is calculated based on + training progress as follows: + - when current_step < base_warmup_epoch * scale * steps_per_epoch: + `lr_factor = sqrt(scale) * progress_ratio` where + `progress_ratio = current_step / + (scale * base_warmup_epochs * steps_per_epoch)` + - when current_step >= base_warmup_epoch * scale * steps_per_epoch: + `lr_factor = sqrt(scale)` + + In order to adapt LEGWScale to AdaptDL, `progress_ratio` is + calculated differently as: + `progress / (scale * base_warmup_epochs * steps_per_epoch)` where + `progress` is the effective steps trained based on AdaptDL's + estimation. + + Argmuents: + base_warmup_epochs: Base warmup epochs + data_size: total number of samples in the dataset + + .. _LEGWScale: https://arxiv.org/pdf/1901.08256.pdf + """ + + def __init__(self, base_warmup_epochs, data_size): + super().__init__() + self._base_warmup_epochs = base_warmup_epochs + self._data_size = data_size + + def scale_lr(self, scale): + dataloader = current_dataloader() + # total training steps for warm up + total_steps = self._base_warmup_epochs * scale * \ + self._data_size / dataloader.batch_size + max_lr_multiplier = math.sqrt(scale) + # effective training steps taken + progress = self.adp.gns.get_progress() + if progress < total_steps: + lr_factor = max_lr_multiplier * (progress / total_steps) + else: + lr_factor = max_lr_multiplier + return lr_factor diff --git a/adaptdl/adaptdl/torch/torch/scaling_rules_test.py b/adaptdl/adaptdl/torch/torch/scaling_rules_test.py new file mode 100644 index 000000000..072adf666 --- /dev/null +++ b/adaptdl/adaptdl/torch/torch/scaling_rules_test.py @@ -0,0 +1,253 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import pytest +import torch + +from unittest.mock import Mock, patch + +from adaptdl.torch.gradient_noise_scale import GradientNoiseScale +from adaptdl.torch.scaling_rules import AdaScale, LinearScale,\ + LEGWScale, SqrtScale + + +def test_scaling_rules_1(): + """test AdaScale lr factors""" + adp = Mock(require_backward_grad_sync=True) + opm = Mock(param_groups=[1, 0, 2, -1]) + gns = Mock(raw_var_avg=np.asarray([1, 0, 0, 2]), + raw_sqr_avg=np.asarray([-1, 0, -1, 1])) + adp.gns = gns + adascale = AdaScale() + adascale.initialize(adp, opm) + input_scales = [0.5, 1, 2, 4, 10] + expected_ans = [[0.5, 0.5, 0.5, 0.6], [1., 1., 1., 1.], [2., 2., 2., 1.5], + [4., 4., 4., 2.], [10., 10., 10., 2.5]] + for scale, ans in zip(input_scales, expected_ans): + np.testing.assert_equal(adascale.scale_lr(scale), ans) + + +def test_scaling_rules_2(): + """test LinearScale lr factors""" + adp = Mock(require_backward_grad_sync=True) + opm = Mock(param_groups=[1, 0, 2, -1]) + gns = Mock(optimizer=opm) + adp.gns = gns + linearscale = LinearScale() + linearscale.initialize(adp, opm) + input_scales = [0.5, 1, 2, 4, 10] + expected_ans = [0.5, 1., 2., 4., 10.] + for scale, ans in zip(input_scales, expected_ans): + np.testing.assert_equal(linearscale.scale_lr(scale), ans) + + +def test_scaling_rules_3(): + """test SqrtScale lr factors""" + adp = Mock(require_backward_grad_sync=True) + opm = Mock(param_groups=[1, 0, 2, -1]) + gns = Mock(optimizer=opm) + adp.gns = gns + sqrtscale = SqrtScale() + sqrtscale.initialize(adp, opm) + input_scales = [1, 4, 9, 16, 25] + expected_ans = [1., 2., 3., 4., 5.] 
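+    # SqrtScale multiplies the learning rate by sqrt(scale), so scales
+    # 1, 4, 9, 16, 25 should yield factors 1 through 5.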
+ for scale, ans in zip(input_scales, expected_ans): + np.testing.assert_equal(sqrtscale.scale_lr(scale), ans) + + +def test_scaling_rules_4(): + """test LEGWScale lr factors""" + with patch("adaptdl.torch.scaling_rules.current_dataloader", + return_value=Mock(batch_size=100)): + adp = Mock(require_backward_grad_sync=True) + opm = Mock(param_groups=[1, 0, 2, -1]) + gns = Mock(optimizer=opm, get_progress=Mock(return_value=5)) + adp.gns = gns + legwscale = LEGWScale(10, 1000) + legwscale.initialize(adp, opm) + input_scales = [1, 4, 9, 16, 25] + expected_ans = [1/20, 1/40, 1/60, 1/80, 1/100] + for scale, ans in zip(input_scales, expected_ans): + np.testing.assert_equal(legwscale.scale_lr(scale), ans) + with patch("adaptdl.torch.scaling_rules.current_dataloader", + return_value=Mock(batch_size=50)): + gns = Mock(optimizer=opm, get_progress=Mock(return_value=400)) + adp.gns = gns + input_scales = [1, 4, 9, 16, 25] + expected_ans = [1., 1., 2/3, 0.5, 0.4] + for scale, ans in zip(input_scales, expected_ans): + np.testing.assert_equal(legwscale.scale_lr(scale), ans) + gns = Mock(optimizer=opm, get_progress=Mock(return_value=400)) + adp.gns = gns + input_scales = [1, 4, 9, 16, 25] + expected_ans = [1., 2., 4/3, 1., 0.8] + for scale, ans in zip(input_scales, expected_ans): + np.testing.assert_equal(legwscale.scale_lr(scale), ans) + + +LR = 0.001 +STEP_SCHEDULE = [1000] +ATOL = 0.01 + + +def test_optimization_1(): + # See torch.test.test_optim + # Also see Rosenbrock/banana function + def rosenbrock(tensor): + x, y = tensor + return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2 + + params_t = torch.Tensor([1.0, 1.5]) + + params = torch.autograd.Variable(params_t, requires_grad=True) + sgd = torch.optim.SGD([params], lr=LR) + schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, STEP_SCHEDULE) + adp = Mock(require_backward_grad_sync=True) + gns = GradientNoiseScale(adp, sgd, accum_scale=1.0, num_replicas=1) + adp.gns = gns + adascale = AdaScale() + adascale.initialize(adp, sgd, patch_optimizer=True) + for i in range(100000): + sgd.zero_grad() + loss = rosenbrock(params) + loss.backward() + sgd.step() + schedule.step() + if params.allclose(torch.tensor([1.0, 1.0]), atol=ATOL): + break + else: + pytest.fail(f"Did not converge: {params}") + + +def test_optimization_2(): + + def rosenbrock_noisy(tensor): + x, y = tensor + return (np.random.normal(1.0, 0.2) * (1 - x) ** 2 + + np.random.normal(1.0, 0.2) * 100 * (y - x ** 2) ** 2) + + params_t = torch.Tensor([1.0, 1.5]) + + params = torch.autograd.Variable(params_t, requires_grad=True) + sgd = torch.optim.SGD([params], lr=LR) + schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, STEP_SCHEDULE) + adp = Mock(require_backward_grad_sync=True) + gns = GradientNoiseScale(adp, sgd, accum_scale=2.0, num_replicas=1) + adp.gns = gns + adascale = AdaScale() + adascale.initialize(adp, sgd, patch_optimizer=True) + for i in range(100000): + sgd.zero_grad() + loss = sum([rosenbrock_noisy(params) for i in range(2)]) / 2.0 + loss.backward() + sgd.step() + schedule.step() + if params.allclose(torch.tensor([1.0, 1.0]), atol=ATOL): + break + else: + pytest.fail(f"Did not converge: {params}") + + +def test_optimization_3(): + def rosenbrock(x, y): + return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2 + + params_t = [ + {"params": [torch.autograd.Variable(torch.Tensor([1.0]), + requires_grad=True)]}, + {"params": [torch.autograd.Variable(torch.Tensor([1.5]), + requires_grad=True)]}] + + sgd = torch.optim.SGD(params_t, lr=LR) + schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, 
STEP_SCHEDULE) + adp = Mock(require_backward_grad_sync=True) + gns = GradientNoiseScale(adp, sgd, accum_scale=1.0, num_replicas=1) + adp.gns = gns + adascale = AdaScale() + adascale.initialize(adp, sgd, patch_optimizer=True) + for i in range(100000): + sgd.zero_grad() + loss = rosenbrock(params_t[0]['params'][0], params_t[1]['params'][0]) + loss.backward() + sgd.step() + schedule.step() + if params_t[0]['params'][0].allclose(torch.tensor([1.0]), atol=ATOL) \ + and params_t[1]['params'][0].allclose(torch.tensor([1.0]), + atol=ATOL): + break + else: + pytest.fail(f"Did not converge: {params_t}") + + +def test_gradient_accumulation_optimization_1(): + + def rosenbrock(tensor): + x, y = tensor + return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2 + + params_t = torch.Tensor([1.0, 1.5]) + + params = torch.autograd.Variable(params_t, requires_grad=True) + sgd = torch.optim.SGD([params], lr=LR) + schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, STEP_SCHEDULE) + adp = Mock(require_backward_grad_sync=False) + gns = GradientNoiseScale(adp, sgd, accum_scale=1.0, num_replicas=1) + adp.gns = gns + adascale = AdaScale() + adascale.initialize(adp, sgd, patch_optimizer=True) + for i in range(100000): + adp.require_backward_grad_sync = i % 2 == 1 + sgd.zero_grad() + loss = rosenbrock(params) + loss.backward() + sgd.step() + if adp.require_backward_grad_sync: + schedule.step() + if params.allclose(torch.tensor([1.0, 1.0]), atol=10 * ATOL): + break + else: + pytest.fail(f"Did not converge: {params}") + + +def test_gradient_accumulation_optimization_2(): + + def rosenbrock_noisy(tensor): + x, y = tensor + return (np.random.normal(1.0, 0.2) * (1 - x) ** 2 + + np.random.normal(1.0, 0.2) * 100 * (y - x ** 2) ** 2) + + params_t = torch.Tensor([1.0, 1.5]) + + params = torch.autograd.Variable(params_t, requires_grad=True) + sgd = torch.optim.SGD([params], lr=LR) + schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, STEP_SCHEDULE) + adp = Mock(require_backward_grad_sync=False) + gns = GradientNoiseScale(adp, sgd, accum_scale=1.0, num_replicas=1) + adp.gns = gns + adascale = AdaScale() + adascale.initialize(adp, sgd, patch_optimizer=True) + for i in range(1000000): + adp.require_backward_grad_sync = i % 2 == 1 + sgd.zero_grad() + loss = rosenbrock_noisy(params) + loss.backward() + sgd.step() + if adp.require_backward_grad_sync: + schedule.step() + if params.allclose(torch.tensor([1.0, 1.0]), atol=ATOL): + break + else: + pytest.fail(f"Did not converge: {params}") diff --git a/tutorial/testcase_for_adaptdldataloader_refactor.py b/tutorial/testcase_for_adaptdldataloader_refactor.py new file mode 100644 index 000000000..038386404 --- /dev/null +++ b/tutorial/testcase_for_adaptdldataloader_refactor.py @@ -0,0 +1,156 @@ +from __future__ import print_function +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR + +import adaptdl # Changed in step 1 +import adaptdl.torch # Changed in step 1 +from adaptdl.torch.data import AdaptiveDLContext # For test AdaptiveDLContext only, users do not need to call this + +from adaptdl.torch.scaling_rules import ScalingRuleBase + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout2d(0.25) + self.dropout2 = nn.Dropout2d(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, 
x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def train(args, model, device, train_loader, optimizer, epoch, adacontext): # For test AdaptiveDLContext only, users do not need to call this + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}' + '\treal_batch_size:{}\treal_lr:{}' + '\t ada_batch_size:{}\tada_accum:{}\tada_lr_scale:{}' + .format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.item(), + len(data), optimizer.param_groups[0]['lr'], + adacontext.get_batch_size(), adacontext.get_accum_steps(), adacontext.get_lr_scale(),# For test AdaptiveDLContext only, users do not need to call this + )) + if args.dry_run: + break + + +def tst(model, device, test_loader): + model.eval() + stats = adaptdl.torch.Accumulator() # Changed in step 5 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + stats["test_loss"] += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss # Changed in step 5 + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + stats["correct"] += pred.eq(target.view_as(pred)).sum().item() # Changed in step 5 + + with stats.synchronized(): # Changed in step 5 + test_loss = stats["test_loss"] / len(test_loader.dataset) # Changed in step 5 + correct = stats["correct"] # Changed in step 5 + + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. 
* correct / len(test_loader.dataset))) # Changed in step 5 + + +def main(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=14, metavar='N', + help='number of epochs to train (default: 14)') + parser.add_argument('--lr', type=float, default=1.0, metavar='LR', + help='learning rate (default: 1.0)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--dry-run', action='store_true', default=False, + help='quickly check a single pass') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument('--save-model', action='store_true', default=False, + help='For Saving the current Model') + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + + torch.manual_seed(args.seed) + + device = torch.device("cuda" if use_cuda else "cpu") + + kwargs = {'batch_size': args.batch_size} + if use_cuda: + kwargs.update({'num_workers': 1, + 'pin_memory': True, + 'shuffle': True}, + ) + + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + dataset1 = datasets.MNIST('../data', train=True, download=True, + transform=transform) + dataset2 = datasets.MNIST('../data', train=False, + transform=transform) + adacontext = AdaptiveDLContext(args.batch_size) # For test AdaptiveDLContext only, users do not need to call this + train_loader = adaptdl.torch.AdaptiveDataLoader(dataset1, drop_last=True, **kwargs) # Changed in step 2 + test_loader = adaptdl.torch.AdaptiveDataLoader(dataset2, **kwargs) # Changed in step 2 + + train_loader.autoscale_batch_size(1028, local_bsz_bounds=(32, 128)) # Changed in step 3, optional + + model = Net().to(device) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() + else "gloo") # Changed in step 1 + model = adaptdl.torch.AdaptiveDataParallel(model, optimizer, scheduler) # Changed in step 1 + # adacontext = AdaptiveDLContext(args.batch_size) # For test AdaptiveDLContext only, users do not need to call this + + for epoch in adaptdl.torch.remaining_epochs_until(args.epochs): # Changed in step 4 + train(args, model, device, train_loader, optimizer, epoch, adacontext) # For test AdaptiveDLContext only, users do not need to call this + tst(model, device, test_loader) + scheduler.step() + + if args.save_model: + torch.save(model.state_dict(), "mnist_cnn.pt") + + +if __name__ == '__main__': + main() From c5513f0edea6d7304306560ba807500efa574bcb Mon Sep 17 00:00:00 2001 From: Xuezhi-Liang Date: Mon, 4 Apr 2022 14:29:55 +0400 Subject: [PATCH 02/10] stage1 --- adaptdl/adaptdl/torch/{torch => }/context.py | 0 adaptdl/adaptdl/torch/data.py | 91 +--- adaptdl/adaptdl/torch/scaling_rules.py | 10 +- 
adaptdl/adaptdl/torch/torch/__init__.py | 142 ----- adaptdl/adaptdl/torch/torch/_metrics.py | 199 ------- adaptdl/adaptdl/torch/torch/_metrics_test.py | 158 ------ adaptdl/adaptdl/torch/torch/accumulator.py | 312 ----------- .../adaptdl/torch/torch/accumulator_test.py | 60 --- adaptdl/adaptdl/torch/torch/data.py | 492 ------------------ adaptdl/adaptdl/torch/torch/data_test.py | 168 ------ adaptdl/adaptdl/torch/torch/epoch.py | 178 ------- adaptdl/adaptdl/torch/torch/epoch_test.py | 42 -- .../torch/torch/gradient_noise_scale.py | 330 ------------ .../torch/torch/gradient_noise_scale_test.py | 60 --- adaptdl/adaptdl/torch/torch/iterator.py | 121 ----- adaptdl/adaptdl/torch/torch/parallel.py | 232 --------- adaptdl/adaptdl/torch/torch/parallel_test.py | 67 --- adaptdl/adaptdl/torch/torch/scaling_rules.py | 200 ------- .../adaptdl/torch/torch/scaling_rules_test.py | 253 --------- 19 files changed, 13 insertions(+), 3102 deletions(-) rename adaptdl/adaptdl/torch/{torch => }/context.py (100%) delete mode 100644 adaptdl/adaptdl/torch/torch/__init__.py delete mode 100644 adaptdl/adaptdl/torch/torch/_metrics.py delete mode 100644 adaptdl/adaptdl/torch/torch/_metrics_test.py delete mode 100644 adaptdl/adaptdl/torch/torch/accumulator.py delete mode 100644 adaptdl/adaptdl/torch/torch/accumulator_test.py delete mode 100644 adaptdl/adaptdl/torch/torch/data.py delete mode 100644 adaptdl/adaptdl/torch/torch/data_test.py delete mode 100644 adaptdl/adaptdl/torch/torch/epoch.py delete mode 100644 adaptdl/adaptdl/torch/torch/epoch_test.py delete mode 100644 adaptdl/adaptdl/torch/torch/gradient_noise_scale.py delete mode 100644 adaptdl/adaptdl/torch/torch/gradient_noise_scale_test.py delete mode 100644 adaptdl/adaptdl/torch/torch/iterator.py delete mode 100644 adaptdl/adaptdl/torch/torch/parallel.py delete mode 100644 adaptdl/adaptdl/torch/torch/parallel_test.py delete mode 100644 adaptdl/adaptdl/torch/torch/scaling_rules.py delete mode 100644 adaptdl/adaptdl/torch/torch/scaling_rules_test.py diff --git a/adaptdl/adaptdl/torch/torch/context.py b/adaptdl/adaptdl/torch/context.py similarity index 100% rename from adaptdl/adaptdl/torch/torch/context.py rename to adaptdl/adaptdl/torch/context.py diff --git a/adaptdl/adaptdl/torch/data.py b/adaptdl/adaptdl/torch/data.py index 90a8767ac..8da7ce7a0 100644 --- a/adaptdl/adaptdl/torch/data.py +++ b/adaptdl/adaptdl/torch/data.py @@ -32,6 +32,7 @@ profile_step_start, profile_step_commit, set_batch_size, get_goodput_fn, get_progress) from adaptdl._signal import get_exit_flag +from adaptdl.torch.context import AdaptiveDLContext logging.basicConfig(level=logging.INFO) LOG = logging.getLogger(__name__) @@ -267,43 +268,6 @@ def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, self._gradient_accumulation = gradient_accumulation self.train() - def _sync_local_bsz(self): - goodput_fn = get_goodput_fn() - if self.max_batch_size is None or goodput_fn is None: - # No autoscale batch size, just divide batch size evenly. 
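# (A worked example of the even division below, assuming a hypothetical
#  target batch_size of 128 across 3 replicas:
#      math.ceil(128 / 3) == 43 samples per replica per step,
#  so the effective total batch becomes 43 * 3 == 129, just over the target.)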
- self._state.current_local_bsz = math.ceil( - self.batch_size / adaptdl.env.num_replicas()) - self._state.accumulation_steps = 0 - elif not self._state.current_local_bsz: - # if init, use the batch size suggested - _, atomic_bsz, accum_steps = goodput_fn.optimize( - adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), - max_batch_size=self._max_batch_size, - atomic_bsz_range=self._local_bsz_bounds, - accumulation=self._gradient_accumulation) - self._state.current_local_bsz = atomic_bsz - self._state.accumulation_steps = accum_steps - else: - # if not first time, we check against the relative speedup - suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize( - adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), - max_batch_size=self._max_batch_size, - atomic_bsz_range=self._local_bsz_bounds, - accumulation=self._gradient_accumulation) - # get current goodput - current_goodput = goodput_fn( - adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), - self.current_local_bsz, self.accumulation_steps) - # use only if speedup is significant - speedup = suggest_goodput / max(current_goodput, 1e-8) - if speedup > self._speedup_threshold: - self._state.current_local_bsz = atomic_bsz - self._state.accumulation_steps = accum_steps - self._state.current_local_bsz, self._state.accumulation_steps = \ - adaptdl.collective.broadcast((self._state.current_local_bsz, - self._state.accumulation_steps)) - return self.current_local_bsz - @property def training(self): return self is AdaptiveDataLoaderHelper._training @@ -398,53 +362,6 @@ def to_tensorboard(self, writer, global_step, tag_prefix=""): self.accumulation_steps, global_step) -class AdaptiveDataLoaderMixin(object): - """ - This class provides elastic functionality to any custom DataLoader which - inherits it. It defines a member _elastic of type - :class:`AdaptiveDataLoaderHelper` which has useful methods and members to - implement restart-safe, elastic DataLoaders. It also exposes public methods - which can be used inside training loops directly from - :class:`AdaptiveDataLoader`. - """ - - def __init__(self, batch_size): - self._elastic = AdaptiveDataLoaderHelper(batch_size) - - def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, - gradient_accumulation=False): - self._elastic.autoscale_batch_size(max_batch_size, local_bsz_bounds, - gradient_accumulation) - - @property - def current_local_bsz(self): - if AdaptiveDataLoaderHelper._current is not self._elastic: - return None - return self._elastic.current_local_bsz - - @property - def accumulation_steps(self): - """ - The number of batches returned by the dataloader before a - step is taken. - """ - return self._elastic.accumulation_steps - - @property - def training(self): - return self._elastic.training - - @property - def current_batch_size(self): - if AdaptiveDataLoaderHelper._current is not self._elastic: - return None - return self._elastic.current_batch_size - - def to_tensorboard(self, writer, global_step, tag_prefix=""): - self._elastic.to_tensorboard(writer, global_step, tag_prefix) - to_tensorboard.__doc__ = AdaptiveDataLoaderHelper.to_tensorboard.__doc__ - - def _worker_init_wrapper(worker_init_fn, num_workers): # Set globally-unique python and numpy seeds for each worker. @@ -462,7 +379,7 @@ def wrapper(worker_id): return wrapper -class AdaptiveDataLoader(DataLoader, AdaptiveDataLoaderMixin): +class AdaptiveDataLoader(DataLoader, AdaptiveDLContext): """ This class is a PyTorch DataLoader that also supports adaptive batch sizes and checkpoint-restart elasticity. 
Applications can typically use objects @@ -501,7 +418,7 @@ def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): kwargs["worker_init_fn"] = _worker_init_wrapper( kwargs.get("worker_init_fn"), kwargs.get("num_workers")) super().__init__(dataset, batch_size, shuffle=False, **kwargs) - AdaptiveDataLoaderMixin.__init__(self, batch_size) + AdaptiveDLContext.__init__(self, batch_size) def __iter__(self): """ @@ -526,7 +443,7 @@ def __iter__(self): while not done: self.sampler.set_epoch( epoch, index=self._elastic.current_index) - self.batch_sampler.batch_size = self._elastic._sync_local_bsz() + self.batch_sampler.batch_size = self.get_batch_size() for idx, batch in enumerate(super().__iter__()): with self._elastic.profile(self.training and idx >= 1): yield batch diff --git a/adaptdl/adaptdl/torch/scaling_rules.py b/adaptdl/adaptdl/torch/scaling_rules.py index a1300232f..ac0f2a1c0 100644 --- a/adaptdl/adaptdl/torch/scaling_rules.py +++ b/adaptdl/adaptdl/torch/scaling_rules.py @@ -19,7 +19,7 @@ from types import MethodType -from adaptdl.torch.data import current_dataloader +# from adaptdl.torch.data import current_dataloader __all__ = ["ScalingRuleBase", "AdaScale", "LinearScale", "SqrtScale", @@ -45,6 +45,9 @@ class ScalingRuleBase(object): loss.backward() adascale.step() """ + + _adaptlr = None + def __init__(self): # instance of AdaptiveDataParallel, needs to be set before any of the # methods can be used @@ -77,6 +80,7 @@ def step(self, *args, **kwargs): scale = self.adp.gns.accum_scale * self.adp.gns.accum_count initial_lr = [pg["lr"] for pg in self._optimizer.param_groups] scaled_lr = np.multiply(self.scale_lr(scale), initial_lr) + ScalingRuleBase._adaptlr = scaled_lr for lr, pg in zip(scaled_lr, self._optimizer.param_groups): pg["lr"] = lr self._orig_optimizer_step(*args, **kwargs) @@ -107,6 +111,10 @@ def initialize(self, adp, optimizer, patch_optimizer=False): if patch_optimizer: self._patch_optimizer() + @staticmethod + def _get_adapt_lr_scale(): + return ScalingRuleBase._adaptlr + class AdaScale(ScalingRuleBase): """ diff --git a/adaptdl/adaptdl/torch/torch/__init__.py b/adaptdl/adaptdl/torch/torch/__init__.py deleted file mode 100644 index c9832e600..000000000 --- a/adaptdl/adaptdl/torch/torch/__init__.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
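# A minimal, runnable sketch of the lr rescaling done by ScalingRuleBase.step()
# in the scaling_rules.py hunk above; `LinearRule`, the lr values, and the
# scale factor are illustrative stand-ins, not part of this patch:
#
#     import numpy as np
#
#     class LinearRule:
#         def scale_lr(self, scale):
#             return scale  # scale the lr linearly with the batch scale
#
#     rule = LinearRule()
#     initial_lr = [0.1, 0.05]  # base lr of each optimizer param group
#     scale = 4.0               # accum_scale * accum_count at step time
#     scaled_lr = np.multiply(rule.scale_lr(scale), initial_lr)
#     print(scaled_lr)          # [0.4 0.2]; step() stores this in _adaptlr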
- - -import sys -import os -if "darwin" in sys.platform.lower(): - # To avoid multiple runs of the model code - # https://pythonspeed.com/articles/python-multiprocessing/ - import multiprocessing - multiprocessing.set_start_method('fork') - -import logging -import portpicker -import requests -import torch.distributed -import pkg_resources - -import adaptdl.collective -import adaptdl.env -import semver -from .epoch import current_epoch, finished_epochs, remaining_epochs_until -from .data import current_dataloader, AdaptiveDataLoader, ElasticSampler -from .parallel import AdaptiveDataParallel -from .accumulator import Accumulator - -logging.basicConfig(level=logging.INFO) -LOG = logging.getLogger(__name__) -LOG.setLevel(logging.INFO) - - -def version_check(version): - if semver.VersionInfo.isvalid(version) and \ - version != "0.0.0": - return True - else: - return False - - -def init_process_group(backend, - init_method=None, - world_size=None, - rank=None): - """ - Initializes the default distributed process group and the AdaptDL - collectives module. - - Args: - backend (str or Backend): The backend to use. Use "nccl" for multi-GPU - training else "gloo". - init_method (str, optional): URL specifying how to initialize the - process group. - world_size (int, optional): Number of processes participating in - the job - rank (int, optional): Rank of the current process (it should be a - number between 0 and ``world_size``-1). - - If init_method, world_size and rank is NOT provided, typically in the - Kubernetes environment, AdaptDL will try to infer them through environment - variables ADAPTDL_MASTER_ADDR, ADAPTDL_NUM_REPLICAS and - ADAPTDL_REPLICA_RANK respectively. - """ - if adaptdl.env.from_ray(): - from adaptdl_ray.adaptdl.utils import unique_nodes_pg - assert init_method is not None - assert world_size is not None - assert rank is not None - os.environ["ADAPTDL_NUM_NODES"] = str(unique_nodes_pg()) - os.environ["ADAPTDL_REPLICA_RANK"] = str(rank) - os.environ["ADAPTDL_NUM_REPLICAS"] = str(world_size) - - url = adaptdl.env.supervisor_url() - master_port = adaptdl.env.master_port() - if rank is None: - rank = adaptdl.env.replica_rank() - - if world_size is None: - world_size = adaptdl.env.num_replicas() - - if init_method is not None: - _, master_addr, master_port = init_method.split(":") - master_addr = master_addr[2:] - master_port = int(master_port) - elif url: - key = adaptdl.env.job_id() - group = adaptdl.env.num_restarts() - while True: - response = requests.get(url=f"{url}/discover/{key}/{group}") - if response.status_code != 408: # Timeout. - break - response.raise_for_status() - master_addr = response.json()[0] - sched_version = adaptdl.env.adaptdl_sched_version() - trainer_version = pkg_resources.get_distribution("adaptdl").version - if version_check(sched_version) and version_check(trainer_version): - trainer_ver_maj = semver.VersionInfo.parse(trainer_version).major - sched_ver_maj = semver.VersionInfo.parse(sched_version).major - if trainer_ver_maj != sched_ver_maj: - raise Exception('adaptdl version {} is incompatible with' - 'scheduler version {}'.format(trainer_version, - sched_version)) - else: - master_addr = adaptdl.env.master_addr() - - # Initialize collective module. - adaptdl.collective.initialize(master_addr, - master_port, - rank, - world_size) - - # Initialize torch.distributed. 
- torch_port = adaptdl.collective.broadcast(portpicker.pick_unused_port()) - init_method = "tcp://{}:{}?rank={}&world_size={}".format( - master_addr, torch_port, rank, world_size) - LOG.info("Initializing torch.distributed using %s", init_method) - torch.distributed.init_process_group(backend, init_method) - - LOG.info("torch.distributed initialized") - - -__all__ = [ - "init_process_group", - "current_epoch", - "finished_epochs", - "remaining_epochs_until", - "current_dataloader", - "AdaptiveDataLoader", - "ElasticSampler", - "AdaptiveDataParallel", - "Accumulator", -] diff --git a/adaptdl/adaptdl/torch/torch/_metrics.py b/adaptdl/adaptdl/torch/torch/_metrics.py deleted file mode 100644 index 038c22b35..000000000 --- a/adaptdl/adaptdl/torch/torch/_metrics.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import collections -import pickle -import time - -import numpy as np - -import adaptdl.checkpoint -import adaptdl.collective -import adaptdl.env -from adaptdl.goodput import GoodputFunction, fit_perf_params -from adaptdl.sched_hints import SCHED_HINTS, PERF_PARAMS, post_sched_hints - - -def profile_step_start(atomic_bsz): - state = _metrics_state() - state.atomic_bsz = atomic_bsz - state.step_start = time.time() - state.sync_time = 0.0 - - -def profile_sync_time(sync_time): - _metrics_state().sync_time += sync_time - - -_PREV_REPORT = None - - -def profile_step_commit(accumulation_step=False): - global _PREV_REPORT - state = _metrics_state() - step_time = time.time() - state.step_start - num_nodes = adaptdl.env.num_nodes() - num_replicas = adaptdl.env.num_replicas() - key = (num_nodes, num_replicas, state.atomic_bsz) - if accumulation_step: - state.profile[key]["accum_step_time"] += step_time - state.profile[key]["accum_count"] += 1 - else: - state.profile[key]["optim_step_time"] += step_time - state.profile[key]["optim_sync_time"] += state.sync_time - state.profile[key]["optim_count"] += 1 - del state.atomic_bsz - del state.step_start - del state.sync_time - if not accumulation_step: - if _PREV_REPORT is None: - _PREV_REPORT = time.time() - if adaptdl.env.replica_rank() == 0 and time.time() - _PREV_REPORT > 30: - _fit_perf_params() - _report_sched_hints() - _PREV_REPORT = time.time() - - -_GRAD_PARAM_DICT = {} - - -def update_grad_params(edp_key, grad_norm_sqr, grad_variance): - global _GRAD_PARAM_DICT - _GRAD_PARAM_DICT[edp_key] = np.asarray([grad_norm_sqr, grad_variance]) - grad_params = sum(_GRAD_PARAM_DICT.values()) - _metrics_state().grad_params = (grad_params[0], grad_params[1]) - - -def update_progress(progress): - _metrics_state().progress = progress - - -def get_progress(): - return _metrics_state().progress - - -def set_batch_size(init_batch_size, max_batch_size, local_bsz_bounds, - gradient_accumulation): - state = _metrics_state() - state.init_batch_size = init_batch_size - state.max_batch_size = max_batch_size - state.local_bsz_bounds = local_bsz_bounds - state.gradient_accumulation = 
gradient_accumulation - - -def get_goodput_fn(): - state = _metrics_state() - if state.grad_params is None or state.perf_params is None: - return None - return GoodputFunction(state.perf_params, state.grad_params, - state.init_batch_size) - - -def _fit_perf_params(): - state = _metrics_state() - profile = {k: v for k, v in state.profile.items() if v.get("optim_count")} - # Convert profile into numpy arrays. - num_nodes, num_replicas, atomic_bsz = ( - np.array(k) for k in zip(*profile.keys())) - accum_step_time = np.array([v.get("accum_step_time", 0.0) - for v in profile.values()]) - accum_count = np.array([v.get("accum_count", 0) for v in profile.values()]) - optim_step_time = np.array([v.get("optim_step_time", 0.0) - for v in profile.values()]) - optim_sync_time = np.array([v.get("optim_sync_time", 0.0) - for v in profile.values()]) - optim_count = np.array([v.get("optim_count", 0) for v in profile.values()]) - assert np.all(optim_count > 0) - # Non-sync time during optimization steps should be approximately equal to - # accumulation step time, combine those data points. - assert np.all(optim_step_time >= optim_sync_time) - accum_step_time += optim_step_time - optim_sync_time - accum_count += optim_count - accum_step_time /= accum_count - optim_step_time /= optim_count - state.perf_params = fit_perf_params(num_nodes, num_replicas, atomic_bsz, - accum_step_time, optim_step_time) - - -def _get_sched_hints(): - state = _metrics_state() - if len(state.profile) == 0: - return None - _fit_perf_params() - return _metrics_state() - - -def _report_sched_hints(): - assert adaptdl.env.replica_rank() == 0 - state = _metrics_state() - # Scheduling hints - sched_hints = SCHED_HINTS.copy() - sched_hints["perfParams"] = {k: v for (k, v) in - zip(PERF_PARAMS.keys(), - state.perf_params)} - sched_hints["maxBatchSize"] = state.max_batch_size - sched_hints["localBszBounds"] = state.local_bsz_bounds - sched_hints["initBatchSize"] = state.init_batch_size - if state.grad_params: - sched_hints["gradParams"] = {} - sched_hints["gradParams"]["norm"] = state.grad_params[0] - sched_hints["gradParams"]["var"] = state.grad_params[1] - sched_hints["maxProfiledReplicas"] = max(key[1] for key in state.profile) - sched_hints["gradientAccumulation"] = state.gradient_accumulation - post_sched_hints(sched_hints, adaptdl.env.job_id()) - - -class _MetricsState(adaptdl.checkpoint.State): - def __init__(self): - super().__init__("adaptdl-metrics") - self.profile = collections.defaultdict(collections.Counter) - self.perf_params = None - self.grad_params = None - self.init_batch_size = None - self.max_batch_size = None - self.local_bsz_bounds = None - self.gradient_accumulation = False - self.progress = 0.0 # Progress in scale-invariant iterations. 
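# A minimal sketch of how a goodput suggestion is adopted, mirroring the
# _sync_local_bsz logic removed from data.py earlier in this patch; the
# goodput values and batch sizes below are made-up numbers for illustration:
#
#     suggest_goodput, atomic_bsz, accum_steps = 1.30, 96, 1  # optimize(...)
#     current_goodput = 1.20                                  # goodput_fn(...)
#     speedup = suggest_goodput / max(current_goodput, 1e-8)  # ~1.083
#     if speedup > 1.05:  # only adopt the suggestion if clearly faster
#         current_local_bsz, accumulation_steps = atomic_bsz, accum_steps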
- - def save(self, fileobj): - pickle.dump(self.profile, fileobj) - pickle.dump(self.perf_params, fileobj) - pickle.dump(self.grad_params, fileobj) - pickle.dump(self.init_batch_size, fileobj) - pickle.dump(self.max_batch_size, fileobj) - pickle.dump(self.local_bsz_bounds, fileobj) - pickle.dump(self.gradient_accumulation, fileobj) - pickle.dump(self.progress, fileobj) - - def load(self, fileobj): - self.profile = pickle.load(fileobj) - self.perf_params = pickle.load(fileobj) - self.grad_params = pickle.load(fileobj) - self.init_batch_size = pickle.load(fileobj) - self.max_batch_size = pickle.load(fileobj) - self.local_bsz_bounds = pickle.load(fileobj) - self.gradient_accumulation = pickle.load(fileobj) - self.progress = pickle.load(fileobj) - - -def _metrics_state(): - global _METRICS_STATE - if _METRICS_STATE is None: - _METRICS_STATE = _MetricsState() - adaptdl.checkpoint.load_state(_METRICS_STATE) - return _METRICS_STATE - - -_METRICS_STATE = None diff --git a/adaptdl/adaptdl/torch/torch/_metrics_test.py b/adaptdl/adaptdl/torch/torch/_metrics_test.py deleted file mode 100644 index ee69d9a2f..000000000 --- a/adaptdl/adaptdl/torch/torch/_metrics_test.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import pytest - -from adaptdl.conftest import elastic_multiprocessing - - -@pytest.mark.parametrize("num_replicas", [1, 2, 3, 4]) -@elastic_multiprocessing -def test_profile(num_replicas): - import adaptdl.checkpoint - from adaptdl.env import num_restarts - from adaptdl.torch._metrics import ( - profile_step_start, profile_sync_time, - profile_step_commit, _metrics_state) - if num_restarts() == 0: - profile = _metrics_state().profile - assert len(profile) == 0 - # Profile local_bsz=1 but don't commit. - profile_step_start(1) - profile_sync_time(1.0) - # Profile local_bsz=2 and commit. - profile_step_start(2) - profile_sync_time(1.0) - profile_sync_time(2.0) - profile_step_commit() - # Ensure profile is updated correctly. - profile = _metrics_state().profile - key = (1, 1, 2) - assert len(profile) == 1 - assert profile[key]["accum_count"] == 0 - assert profile[key]["optim_count"] == 1 - assert profile[key]["optim_sync_time"] == 3.0 - assert profile[key]["optim_step_time"] > 0.0 - # Checkpoint and restart. - adaptdl.checkpoint.save_all_states() - return num_replicas - elif num_restarts() == 1: - profile = _metrics_state().profile - # Ensure checkpoint is loaded correctly. - key = (1, 1, 2) - assert len(profile) == 1 - assert profile[key]["accum_count"] == 0 - assert profile[key]["optim_count"] == 1 - assert profile[key]["optim_sync_time"] == 3.0 - assert profile[key]["optim_step_time"] > 0.0 - # Profile local_bsz=3 and commit twice. 
- profile_step_start(3) - profile_sync_time(2.0) - profile_sync_time(3.0) - profile_step_commit() - key = (1, num_replicas, 3) - old_step_time = profile[key]["optim_step_time"] - profile_step_start(3) - profile_sync_time(3.0) - profile_sync_time(4.0) - profile_step_commit() - # Ensure profile is updated correctly. - assert len(profile) == 2 - assert profile[key]["accum_count"] == 0 - assert profile[key]["optim_count"] == 2 - assert profile[key]["optim_sync_time"] == 12.0 - assert profile[key]["optim_step_time"] > old_step_time > 0.0 - - -@pytest.mark.parametrize("num_replicas", [1, 2, 3, 4]) -@elastic_multiprocessing -def test_profile_accumulation(num_replicas): - import adaptdl.checkpoint - from adaptdl.env import num_restarts - from adaptdl.torch._metrics import ( - profile_step_start, profile_sync_time, - profile_step_commit, _metrics_state, _fit_perf_params) - if num_restarts() == 0: - profile = _metrics_state().profile - assert len(profile) == 0 - # Profile local_bsz=1 but don't commit. - profile_step_start(1) - profile_sync_time(1.0) - # Profile local_bsz=2 and commit. - profile_step_start(2) - profile_step_commit(accumulation_step=True) - profile_step_start(2) - profile_step_commit(accumulation_step=True) - profile_step_start(2) - profile_sync_time(4.0) - profile_step_commit(accumulation_step=False) - profile_step_start(5) - profile_step_commit(accumulation_step=True) - profile_step_start(5) - profile_step_commit(accumulation_step=True) - profile_step_start(5) - profile_sync_time(6.0) - profile_step_commit(accumulation_step=False) - # Ensure profile is updated correctly. - profile = _metrics_state().profile - key = (1, 1, 2) - assert len(profile) == 2 - assert profile[key]["accum_count"] == 2 - assert profile[key]["optim_count"] == 1 - assert profile[key]["optim_sync_time"] == 4.0 - assert profile[key]["accum_step_time"] > 0.0 - assert profile[key]["optim_step_time"] > 0.0 - profile_step_start(3) - profile_step_commit(accumulation_step=True) - profile_step_start(3) - profile_step_commit(accumulation_step=True) - # Check that fitting parameters works even - # without a final accumulation_step=False commit - for val in profile.values(): - # Ensure step time is at least sync time. - val["optim_step_time"] += val["optim_sync_time"] - _fit_perf_params() - # Checkpoint and restart. - adaptdl.checkpoint.save_all_states() - return num_replicas - elif num_restarts() == 1: - profile = _metrics_state().profile - # Ensure checkpoint is loaded correctly. - key = (1, 1, 2) - assert len(profile) == 3 - assert profile[key]["accum_count"] == 2 - assert profile[key]["optim_count"] == 1 - assert profile[key]["optim_sync_time"] == 4.0 - assert profile[key]["optim_step_time"] > 0.0 - # Profile local_bsz=3 and commit twice. - profile_step_start(3) - profile_sync_time(2.0) - profile_sync_time(3.0) - profile_step_commit() - key = (1, num_replicas, 3) - old_step_time = profile[key]["optim_step_time"] - profile_step_start(3) - profile_sync_time(3.0) - profile_sync_time(4.0) - profile_step_commit() - # Ensure profile is updated correctly. 
- if num_replicas == 1: - assert len(profile) == 3 - else: - assert len(profile) == 4 - assert profile[key]["accum_count"] == 0 if num_replicas > 1 else 2 - assert profile[key]["optim_count"] == 2 - assert profile[key]["optim_sync_time"] == 12.0 - assert profile[key]["optim_step_time"] > old_step_time > 0.0 diff --git a/adaptdl/adaptdl/torch/torch/accumulator.py b/adaptdl/adaptdl/torch/torch/accumulator.py deleted file mode 100644 index 6851368ad..000000000 --- a/adaptdl/adaptdl/torch/torch/accumulator.py +++ /dev/null @@ -1,312 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import collections.abc -import contextlib -import copy -import pickle - -import adaptdl.checkpoint -import adaptdl.collective -from adaptdl.torch.epoch import current_epoch -from adaptdl.torch.data import current_dataloader - - -class Accumulator(collections.abc.MutableMapping): - """ - This class helps aggregate simple statistics across all replicas in the - current job, and across any number of checkpoint-restarts. Can be used to - compute metrics like loss and accuracy, synchronized across each replica. - - Accumulators imitate python dictionaries, but with a few key differences - described below. Primarily, its usage and behavior depend on whether it is - set to *accumulation mode* or to *synchronized mode*. - - 1. **Accumulation mode:** the accumulator is being updated on all - replicas. Operations like ``accum["key"] += val`` or - ``accum.update(key=val)`` will aggregate the updates locally on each - replica, which are lazily synchronized in the background (either upon a - checkpoint or a switch to synchronized mode). Each replica may make - different updates, which are summed together when synchronized. While - accumulation mode is enabled, all read operations on the accumulator - will behave as if they were performed on an empty ``dict``, ie. - ``len(accum)`` will always return ``0``. By default, all accumulators - are set to accumulation mode. - 2. **Synchronized mode:** the accumulator contains the same data on every - replica, and the application must ensure that all write operations are - exactly the same across all replicas. While in synchronized mode, the - accumulator may be used as if it were a native python ``dict``, and all - read/write operations are supported. :meth:`Accumulator.synchronized` - may be used to enter synchronized mode. Upon entering synchronized - mode, the accumulator will automatically sum all updates from all - replicas to ensure the same data is available to each replica. - - Using accumulators, many training/validation metrics can be computed - easily and correctly in an elastic distributed setting. For example, a - simple validation step which calculates a loss and accuracy can be - implemented as follows: - - .. code-block:: python - - accum = Accumulator() # New accumulator starts in accumulation mode. - - for epoch in remaining_epochs_until(60): - - for batch in validloader: - ... 
- accum["loss_sum"] += - accum["correct"] += - accum["total"] += - - with accum.synchronized(): # Enter synchronized mode. - accum["loss_avg"] = accum["loss_sum"] / accum["total"] - accum["accuracy"] = accum["correct"] / accum["total"] - print("Loss: {}, Accuracy: {}".format( - accum["loss_avg"], accum["accuracy"])) - accum.clear() - # Back to accumulation mode. - - Arguments: - args: Positional arguments same as ``dict``. - kwargs: Keyword arguments same as ``dict``. - - .. automethod:: __iadd__ - .. automethod:: __isub__ - .. automethod:: __getitem__ - """ - def __init__(self, *args, **kwargs): - self._sync_count = collections.Counter() - self._synchronized = None - self._state = _AccumulatorState(*args, **kwargs) - adaptdl.checkpoint.load_state(self._state) - - @contextlib.contextmanager - def synchronized(self): - """ - A context manager which can be used to define the code to execute in - *synchronized* mode. Within the context manager, any code can interact - with this accumulator as if it were a regular Python ``dict``. The - application must ensure that whatever operations performed within this - context block are the same across all replicas. - - .. warning:: - Entering this context manager is a distributed synchronization - point! Please ensure that all replicas enter this context manager - at the same point in their code. - """ - if self._synchronized is not None: - # Already synchronized, don't need to do anything. - yield self - return - epoch = current_epoch() - # Remove saved results from all finished epochs. Since finished - # epochs are never replayed, they should never be needed again. - for key in list(self._state.results_history.keys()): - if key is not None and key < epoch: - self._state.results_history.pop(key) - # Get the number of synchronizations so far in the current epoch. - count = self._sync_count[epoch] - self._sync_count[epoch] += 1 - results_list = self._state.results_history[epoch] - assert count <= len(results_list) - if count < len(results_list): - # Results for this synchronization are saved in the history. - self._synchronized = results_list[count] - self._state.updates.clear() - else: - self._state.sync() # Sync results and updates across replicas. - if current_dataloader() is None: - # Only save into results history if outside of a dataloader - # iteration, since code inside iterations are not replayed. - results_list.append(copy.deepcopy(self._state.results)) - self._synchronized = self._state.results - try: - yield self - finally: - self._synchronized = None - - def update(self, *args, **kwargs): - """ - Apply a collection of key-update pairs. Unlike ``dict.update``, this - method *additively* applies the updates to the accumulated values. - - Arguments: - args: Positional arguments same as ``dict.update``. Can be a - mapping object or an iterable of key-update pairs. - kwargs: Keyword arguments same as ``dict.update``. Each keyword is - the string key corresponding to the provided update. - """ - for key, val in dict(*args, **kwargs).items(): - self[key] += val - - def subtract(self, *args, **kwargs): - """ - Apply a collection of key-update pairs. Unlike - :meth:`Accumulator.update`, this method *subtracts* the updates from - the accumulated values. - - Arguments: - args: Positional arguments same as :meth:`Accumulator.update`. - kwargs: Keyword arguments same as :meth:`Accumulator.update`. - """ - for key, val in dict(*args, **kwargs).items(): - self[key] -= val - - def __iadd__(self, other): - """ - Supports the += operation, e.g. 
``accum += {key1: val1, key2: val2}``. - Behaves the same way as ``accum.update({key1: val1, key2: val2})``. - - Arguments: - other: Mapping object or an iterable of key-update pairs. - """ - self.update(other) - return self - - def __isub__(self, other): - """ - Supports the -= operation, e.g. ``accum -= {key1: val1, key2: val2}``. - Behaves the same way as ``accum.subtract({key1: val1, key2: val2})``. - - Arguments: - other: Mapping object or an iterable of key-update pairs. - """ - self.subtract(other) - return self - - def __getitem__(self, key): - """ - Supports indexing, e.g. ``val = accum[key]`` and ``accum[key] += 1``. - The former (read access) should only be used when the accumulator is in - synchronized mode. - - Arguments: - other: Key used to access a value in the accumulator. - """ - if self._synchronized is not None: - return self._synchronized.__getitem__(key) - # In accumulation mode, return a dummy object which captures all - # updates performed on it, to be later applied by __setitem__. - return _Value(self, key) - - def __setitem__(self, key, value): - if self._synchronized is not None: - return self._synchronized.__setitem__(key, value) - # Whenever an in-place addition or subtraction is done, like a[k] += v, - # python will essentially perform 3 steps: (1) tmp = a.__getitem__(k), - # (2) tmp += v, (3) a.__setitem__(k, tmp). In order to obtain the - # update v, we let a.__getitem__(k) return an opaque object which - # implements the __add__ operator to capture the update v in step (2). - # Then, a.__setitem__(k, tmp) can pull v out of tmp and accumulate it. - if not isinstance(value, _Value): - raise TypeError("invalid value type: {}".format(type(value))) - if value.accum is not self: - raise ValueError("incompatible {}".format(self.__class__.__name__)) - if key != value.key: - raise ValueError("incompatible key: {}".format(value.key)) - self._state.updates.setdefault(key, 0) - self._state.updates[key] += value.update - - # Rest of the abstract methods needed by collections.MutableMapping - - def __contains__(self, key): - if self._synchronized is not None: - return self._synchronized.__contains__(key) - return {}.__contains__(key) - - def __delitem__(self, key): - if self._synchronized is not None: - return self._synchronized.__delitem__(key) - return {}.__delitem__(key) - - def __iter__(self): - if self._synchronized is not None: - return self._synchronized.__iter__() - return {}.__iter__() - - def __len__(self): - if self._synchronized is not None: - return self._synchronized.__len__() - return {}.__len__() - - def __repr__(self): - if self._synchronized is not None: - return self._synchronized.__repr__() - return {}.__repr__() - - -class _Value(object): - __slots__ = ["accum", "key", "update"] - - def __init__(self, accum, key): - # Initialize the opaque object used for supporting "accum[k] += v" and - # "accum[k] -= v" operations. - self.accum = accum - self.key = key - self.update = 0 - - def __add__(self, update): - if isinstance(update, _Value): - raise TypeError("invalid update type: {}".format(type(update))) - self.update += update - return self - - def __sub__(self, update): - if isinstance(update, _Value): - raise TypeError("invalid update type: {}".format(type(update))) - self.update -= update - return self - - -class _AccumulatorState(adaptdl.checkpoint.State): - - # Assume accumulators are initialized in the same order in every replica. 
- # Keep a map of epoch -> number of accumulators initialized so far in that - # epoch, and use that count to construct a unique name for the state. - init_count = collections.Counter() - - def __init__(self, *args, **kwargs): - if current_dataloader() is not None: - raise RuntimeError("accumulator may not be initialized during " - "dataloader iteration") - epoch = current_epoch() - count = _AccumulatorState.init_count[epoch] - super().__init__("adaptdl-accumulator-epoch{}-{}".format(epoch, count)) - _AccumulatorState.init_count[epoch] += 1 - - self.results_history = collections.defaultdict(list) - self.results = dict(*args, **kwargs) - self.updates = {} - - def save(self, fileobj): - pickle.dump((self.results_history, self.results), fileobj) - - def load(self, fileobj): - self.results_history, self.results = pickle.load(fileobj) - - def sync(self): - # Aggregate pending updates across all replicas and apply them. - updates = adaptdl.collective.allreduce(self.updates, _dict_iadd) - _dict_iadd(self.results, updates) - self.updates.clear() - - -def _dict_iadd(a, b): - for k, v in b.items(): - if k in a: - a[k] += v - else: - a[k] = v - return a diff --git a/adaptdl/adaptdl/torch/torch/accumulator_test.py b/adaptdl/adaptdl/torch/torch/accumulator_test.py deleted file mode 100644 index 6dc49bb58..000000000 --- a/adaptdl/adaptdl/torch/torch/accumulator_test.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from adaptdl.conftest import elastic_multiprocessing -from adaptdl.torch.accumulator import Accumulator - - -@elastic_multiprocessing -def test_accumulator_restarts(): - import adaptdl.checkpoint - import adaptdl.collective - from adaptdl.env import num_restarts, replica_rank - adaptdl.collective.initialize("0.0.0.0") - accum = Accumulator() - - if num_restarts() == 0: - accum["a"] += 15 # Idempotent. - assert "a" not in accum - with accum.synchronized(): - assert "a" in accum - assert accum["a"] == 15 - assert "a" not in accum - if num_restarts() == 0: - accum["a"] -= 5 # Idempotent. - adaptdl.checkpoint.save_all_states() - return 4 # Restart with 4 replicas. - - if num_restarts() == 1: # Idempotent. - accum.update({"a": replica_rank(), "b": replica_rank()}) - assert len(accum) == 0 - with accum.synchronized(): - assert len(accum) == 2 - assert accum["a"] == 16 - assert accum["b"] == 6 - assert len(accum) == 0 - if num_restarts() == 1: - adaptdl.checkpoint.save_all_states() - return 2 # Restart with 2 replicas. - - if num_restarts() == 2: # Idempotent. - accum -= {"b": 5, "c": 5} - with accum.synchronized(): - assert accum["a"] == 16 - assert accum["b"] == -4 - assert accum["c"] == -10 - accum.clear() - with accum.synchronized(): - assert not accum diff --git a/adaptdl/adaptdl/torch/torch/data.py b/adaptdl/adaptdl/torch/torch/data.py deleted file mode 100644 index 8da7ce7a0..000000000 --- a/adaptdl/adaptdl/torch/torch/data.py +++ /dev/null @@ -1,492 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from contextlib import contextmanager -import collections -import functools -import logging -import math -import numpy as np -import pickle -import random -import torch -from torch.utils.data import DataLoader, Sampler - -import adaptdl.checkpoint -import adaptdl.collective -import adaptdl.env -from adaptdl.torch.epoch import current_epoch -from adaptdl.torch._metrics import ( - profile_step_start, profile_step_commit, - set_batch_size, get_goodput_fn, get_progress) -from adaptdl._signal import get_exit_flag -from adaptdl.torch.context import AdaptiveDLContext - -logging.basicConfig(level=logging.INFO) -LOG = logging.getLogger(__name__) -LOG.setLevel(logging.INFO) - - -class ElasticSampler(Sampler): - """ - A PyTorch Sampler which partitions data samples across multiple replicas, - and supports deterministic continuing across checkpoint-restarts. Shuffling - is deterministic for each epoch, and :meth:`ElasticSampler.set_epoch` - should be invoked to obtain different orderings in different epochs. - - Arguments: - dataset (torch.util.data.Dataset): The dataset to sample from. - shuffle (bool): Whether the data samples should be shuffled. - - .. automethod:: __iter__ - .. automethod:: __len__ - """ - def __init__(self, dataset, shuffle=True): - self.dataset = dataset - self.shuffle = shuffle - self.num_replicas = adaptdl.env.num_replicas() - self.rank = adaptdl.env.replica_rank() - self.epoch = 0 - self.index = 0 - - def __iter__(self): - """ - Iterate through the samples in the dataset, in the order defined for a - set epoch, starting at a set index. Produces only the indices for the - local replica. - - Returns: Iterator over data sample indices. - """ - if self.shuffle: - # Deterministically shuffle based on epoch. - g = torch.Generator() - g.manual_seed(hash((self.epoch, self.index // len(self.dataset)))) - indices = torch.randperm(len(self.dataset), generator=g).tolist() - else: - indices = list(range(len(self.dataset))) - - base_index = self.index % len(self.dataset) - - # Subsample. - local_indices = indices[base_index + self.rank::self.num_replicas] - - # Add extra samples to make it evenly divisible. - if len(local_indices) < len(self): - local_indices.append(indices[self.rank]) - assert len(local_indices) == len(self) - return iter(local_indices) - - def __len__(self): - """ - The total number of samples to be iterated through, starting at the set - index, for the local replica. - - Returns (int): Number of samples. - """ - base_index = self.index % len(self.dataset) - return math.ceil((len(self.dataset) - base_index) / self.num_replicas) - - def set_epoch(self, epoch, index=0): - """ - Set the epoch to derive samples from. Optional argument ``index`` can - be specified to start sampling from a particular index, e.g. after a - checkpoint-restart. - - Arguments: - epoch (int): The epoch to sample from. - index (int): The index to start sampling from. 
- """ - self.epoch = epoch - self.index = index - - -def current_dataloader(): - """ - Reference to the data loader currently being iterated. - - Returns (AdaptiveDataLoaderHelper): Current data loader. - """ - return AdaptiveDataLoaderHelper._current - - -class AdaptiveDataLoaderHelper(object): - """ - This class provides fine-grained control over adaptive training loops. It - can be used for building more user-friendly custom data loaders, such as - :class:`AdaptiveDataLoader`. - - Arguments: - batch_size (int): The target total batch size across all replicas. The - actual total batch size may be different due to rounding (each - replica must have the same local batch size), or being scaled up - using adaptive batch sizes. - """ - - # Epoch -> the number of dataloader loops completed so far in that epoch, - # across all AdaptiveDataLoader objects. - _position = collections.Counter() - _training = None # The AdaptiveDataLoader which loads training data. - _current = None # The AdaptiveDataLoader which is currently iterating. - - def __init__(self, batch_size=1): - # Autoscale batch size fields. - self._max_batch_size = None - self._local_bsz_bounds = None - # Create and load state. - self._state = _AdaptiveDataLoaderState() - adaptdl.checkpoint.load_state(self._state) - self.batch_size = batch_size - self.future_exit = None - self._gradient_accumulation = False - self._speedup_threshold = 1.05 - self._accum_count = 0 - - @property - def current_index(self): - """ - The total number of data samples processed so far in the current loop. - Includes the data processed by all replicas. ``None`` if this data - loader is not currently being iterated. - """ - if AdaptiveDataLoaderHelper._current is not self: - return None - return self._state.current_index - - @current_index.setter - def current_index(self, index): - if AdaptiveDataLoaderHelper._current is not self: - return - self._state.current_index = index - - @property - def end_index(self): - """ - (Optional) Can be used to track the end index of dataset across - restarts. - """ - return self._state.end_index - - @end_index.setter - def end_index(self, index): - """ - (Optional) Supports mutations of end_index - """ - self._state.end_index = index - - @property - def max_batch_size(self): - """ - The maximum total batch size allowed for adaptive batch size. ``None`` - if adaptive batch size is disabled. - """ - return self._max_batch_size - - @property - def local_bsz_bounds(self): - """ - The local batch size bounds on each replica. A pair of integers, - (min_local_bsz, max_local_bsz). - """ - return self._local_bsz_bounds - - @property - def current_local_bsz(self): - """ - The current logical local batch size used by the dataloader. - The batch size returned by the dataloader may be smaller if - gradient accumulation is used - """ - return self._state.current_local_bsz - - @property - def accumulation_steps(self): - """ - The number of batches returned by the dataloader before a - step is taken. - """ - return self._state.accumulation_steps - - def is_accum_step(self): - """ - Whether the current step's gradient will be accumulated. - """ - return self._accum_count < self._state.accumulation_steps - - def is_optim_step(self): - """ - Whether the optimizer step will be invoked in this step. - """ - return not self.is_accum_step() - - def train(self): - """ - Set this data loader to be the one used for training. Only one data - loader may be used for training. 
- """ - if AdaptiveDataLoaderHelper._training is None: - AdaptiveDataLoaderHelper._training = self - set_batch_size(self.batch_size, self.max_batch_size, - self.local_bsz_bounds, self._gradient_accumulation) - - def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, - gradient_accumulation=False): - """ - Enables adaptive batch size. Should be invoked once after the data - loader object is created. - - Arguments: - max_batch_size (int): Maximum total batch size allowed. - local_bsz_bounds (tuple): A pair of (min_local_bsz, max_local_bsz), - the min and max local batch sizes allowed on each replica. - - Raises: - ValueError: If any of the provided batch size bounds are invalid. - """ - if not isinstance(max_batch_size, int) or \ - max_batch_size < self.batch_size: - raise ValueError("invalid max_batch_size") - if local_bsz_bounds is not None and ( - local_bsz_bounds[0] is not None and - local_bsz_bounds[0] > self.batch_size or - local_bsz_bounds[1] is not None and - local_bsz_bounds[1] < self.batch_size): - raise ValueError("invalid local_bsz_bounds") - self._max_batch_size = max_batch_size - self._local_bsz_bounds = local_bsz_bounds - self._gradient_accumulation = gradient_accumulation - self.train() - - @property - def training(self): - return self is AdaptiveDataLoaderHelper._training - - @contextmanager - def profile(self, commit): - """ - Every iteration of every epoch should be profiled under this context. - Note that, custom DataLoader writers should make sure that it gets - called equal number of times on each replica. - - Arguments: - commit (bool): Whether to commit the profiled results. - """ - # Synchronize the exit signal so all replicas exit after - # the same iteration. Do this asynchronously to prevent - # unnecessary blocking on the network. - if self.future_exit is not None and self.future_exit.result(): - adaptdl.checkpoint.save_all_states() - exit(143) # Standard exit code response to SIGTERM. - self.future_exit = adaptdl.collective.allreduce_async( - get_exit_flag(), lambda a, b: a or b) - profile_step_start(self.current_local_bsz) - yield - if commit: - profile_step_commit(self.is_accum_step()) - self._accum_count = (0 if self.is_optim_step() - else self._accum_count + 1) - - @contextmanager - def context(self): - """ - All iterators should be iterated under this context. It ensures - proper cleanup of elastic context at the end of each epoch. - """ - epoch = current_epoch() - try: - if AdaptiveDataLoaderHelper._current is not None: - raise RuntimeError("overlapping dataloader \ - iterations detected") - AdaptiveDataLoaderHelper._current = self - yield - finally: - self._state.current_index = 0 - self._state.end_index = 0 - self._state.last_position[epoch] = self._position[epoch] - self._position[epoch] += 1 - AdaptiveDataLoaderHelper._current = None - - @property - def current_batch_size(self): - return (self.current_local_bsz * (self.accumulation_steps + 1) * - adaptdl.env.num_replicas()) - - def skipdone(self): - """ - Should be called just after entering the `_elastic` context to make - sure that the dataloader loop is not replayed if has already finished - before a restart. - """ - - epoch = current_epoch() - position = self._position[epoch] - if position <= self._state.last_position.get(epoch, -1): - # Already completed the dataloader loop at the current - # position, skip this loop and keep replaying the application - # code. 
- LOG.info("skipping %s loop at position %s in epoch %s", - self.__class__.__name__, position, epoch) - self._position[epoch] += 1 - return True - else: - return False - - def to_tensorboard(self, writer, global_step, tag_prefix=""): - """ - Output some useful metrics to TensorBoard. - - Arguments: - writer (torch.utils.tensorboard.SummaryWriter): ``SummaryWriter`` - object to output metrics to. - global_step (int): Global step value to record. - tag_prefix (str): Prefix added to each metric's tag. - """ - if tag_prefix and not tag_prefix.endswith("/"): - tag_prefix += "/" - writer.add_scalar(tag_prefix + "Total_Batch_Size", - self.current_batch_size, global_step) - writer.add_scalar(tag_prefix + "Local_Batch_Size", - self.current_local_bsz, global_step) - writer.add_scalar(tag_prefix + "Accumulation_Steps", - self.accumulation_steps, global_step) - - -def _worker_init_wrapper(worker_init_fn, num_workers): - # Set globally-unique python and numpy seeds for each worker. - - @functools.wraps(worker_init_fn) - def wrapper(worker_id): - nonlocal num_workers - num_workers = num_workers or 1 - # https://pytorch.org/docs/master/data.html#randomness-in-multi-process-data-loading. - seed = torch.initial_seed() + adaptdl.env.replica_rank() * num_workers - torch.manual_seed(seed) - np.random.seed(seed % 2 ** 32) - random.seed(seed) - if worker_init_fn is not None: - return worker_init_fn(worker_id) - return wrapper - - -class AdaptiveDataLoader(DataLoader, AdaptiveDLContext): - """ - This class is a PyTorch DataLoader that also supports adaptive batch sizes - and checkpoint-restart elasticity. Applications can typically use objects - of this class as direct replacements for PyTorch DataLoaders. However, some - notable differences are: - - 1. The ``batch_size`` argument defines the target total batch size across - all replicas, rather than the local batch size on each replica. - 2. Custom ``sampler`` and ``batch_sampler`` are not supported. - 3. Iterating through the dataloader is only allowed from within an epoch - loop (see :mod:`adaptdl.torch.epoch`), and only one dataloader loop is - allowed at any given time. - - Arguments: - dataset (torch.util.data.Dataset): Dataset from which to load the data. - batch_size (int): The target total batch size across all replicas. The - actual total batch size may be different due to rounding (each - replica must have the same local batch size), or being scaled up - using adaptive batch sizes. - shuffle (bool): Whether the data is reshuffled at every epoch. - **kwargs: Keyword arguments passed to ``torch.util.data.Dataloader``. - - Raises: - ValueError: If ``sampler`` or ``batch_sampler`` are not ``None``. - - .. automethod:: __iter__ - """ - def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): - if kwargs.get("batch_sampler") is not None \ - or kwargs.get("sampler") is not None: - raise ValueError("AdaptiveDataLoader does not support " - "custom 'sampler' or 'batch_sampler'") - # Custom sampler is incompatible with shuffle=True, so we always set - # shuffle=False in __init__ and let our own sampler do the shuffling. - kwargs["sampler"] = ElasticSampler(dataset, shuffle=shuffle) - kwargs["worker_init_fn"] = _worker_init_wrapper( - kwargs.get("worker_init_fn"), kwargs.get("num_workers")) - super().__init__(dataset, batch_size, shuffle=False, **kwargs) - AdaptiveDLContext.__init__(self, batch_size) - - def __iter__(self): - """ - Iterate over batches of data. 
When adaptive batch size is disabled, - stops after the entire dataset has been processed once in total by all - replicas. This means if there are K replicas, then this method will - iterate over ~1/K of the dataset. When adaptive batch size is enabled, - stops after making enough statistical progress roughly equivalent to - one pass over the dataset with non-adaptive batch size. In this case, - the dataset may be processed more than once. - - A checkpoint-restart may be triggered in-between each batch. In this - case, the current iteration state will be saved and restored after the - restart, and continue where it left off. - """ - epoch = current_epoch() - num_replicas = adaptdl.env.num_replicas() - with self._elastic.context(): - if self._elastic.skipdone(): - return - done = False - while not done: - self.sampler.set_epoch( - epoch, index=self._elastic.current_index) - self.batch_sampler.batch_size = self.get_batch_size() - for idx, batch in enumerate(super().__iter__()): - with self._elastic.profile(self.training and idx >= 1): - yield batch - # Increment by the number of data samples processed - self._elastic.current_index += \ - num_replicas * self.batch_sampler.batch_size - if self._elastic.max_batch_size is not None and \ - get_progress() >= len(self.dataset) * \ - (epoch + 1) / self.batch_size: - done = True - break - if self._elastic.max_batch_size is None: - done = True - self._elastic.current_index -= \ - self._elastic.current_index % -len(self.dataset) - - -class _AdaptiveDataLoaderState(adaptdl.checkpoint.State): - - # Assume dataloaders are initialized in the same order in every replica. - # Keep a map of epoch -> number of dataloaders initialized so far in that - # epoch, and use that count to construct a unique name for the state. - init_count = collections.Counter() - - def __init__(self): - if current_dataloader() is not None: - raise RuntimeError("dataloader may not be initialized during " - "dataloader iteration") - epoch = current_epoch() - count = _AdaptiveDataLoaderState.init_count[epoch] - super().__init__("adaptdl-dataloader-epoch{}-{}".format(epoch, count)) - _AdaptiveDataLoaderState.init_count[epoch] += 1 - - self.current_index = 0 # Index within the current dataloader loop. - self.end_index = 0 # End index of the current DataLoader loop. - self.last_position = {} # Epoch -> position of last completed loop. - self.current_local_bsz = 0 - self.accumulation_steps = 0 - - def save(self, fileobj): - pickle.dump((self.current_index, self.end_index, - self.last_position), fileobj) - - def load(self, fileobj): - self.current_index, self.end_index, self.last_position = \ - pickle.load(fileobj) diff --git a/adaptdl/adaptdl/torch/torch/data_test.py b/adaptdl/adaptdl/torch/torch/data_test.py deleted file mode 100644 index e4742263f..000000000 --- a/adaptdl/adaptdl/torch/torch/data_test.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
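For orientation before the tests below, here is a minimal end-to-end sketch of the dataloader API documented above. It is illustrative only: it assumes a toy TensorDataset, a single-node "gloo" setup, and that AdaptiveDataLoader forwards autoscale_batch_size from its elastic helper as described in the docstrings; the sizes and bounds are made up.

import torch
from torch.utils.data import TensorDataset

import adaptdl.torch as adl

adl.init_process_group("gloo")  # use "nccl" for multi-GPU training

dataset = TensorDataset(torch.rand(1000, 8))
# batch_size is the target *total* batch size across all replicas.
loader = adl.AdaptiveDataLoader(dataset, batch_size=32, shuffle=True)
# Allow scaling the total batch size up to 256, with between 4 and 64
# samples per replica (hypothetical bounds, for illustration only).
loader.autoscale_batch_size(256, local_bsz_bounds=(4, 64))

for epoch in adl.remaining_epochs_until(30):
    for (batch,) in loader:
        pass  # training step; a checkpoint-restart may occur between batches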
- - -import collections -import math - -import pytest -import torch -import torchtext -from torch.utils.data import TensorDataset -from torchtext.data.utils import get_tokenizer - -from adaptdl.conftest import elastic_multiprocessing -from adaptdl.torch.data import (ElasticSampler, AdaptiveDataLoader, - current_dataloader) -from adaptdl.torch.iterator import AdaptiveBPTTIterator - - -@pytest.mark.parametrize("num_replicas", [1, 3, 5]) -@pytest.mark.parametrize("dataset_size", [9, 15, 25]) -def test_sampler_epoch(num_replicas, dataset_size, - epoch=0, index=0, shuffle=True): - dataset = TensorDataset(torch.rand(dataset_size)) - sampler = ElasticSampler(dataset, shuffle=shuffle) - sampler.num_replicas = num_replicas - sampler.set_epoch(epoch, index) - epoch_samples = [] - sample_counts = collections.Counter() - for rank in range(num_replicas): - sampler.rank = rank - epoch_samples.append(list(sampler)) - # Check indices are split evenly between replicas. - assert len(sampler) == \ - math.ceil((dataset_size - index % dataset_size) / num_replicas) - # Check the actual samples obey the length. - assert len(sampler) == len(epoch_samples[rank]) - # Check ordering is the same within the same epoch. - assert list(sampler) == epoch_samples[rank] - sample_counts.update(epoch_samples[rank]) - # Check all indices are present. - assert len(sample_counts) >= dataset_size - index % dataset_size - assert all(0 <= key < dataset_size for key in sample_counts) - # Check each index is counted roughly the same number of times. - assert max(sample_counts.values()) - min(sample_counts.values()) <= 1 - return epoch_samples - - -@pytest.mark.parametrize("num_replicas", [1, 3, 5]) -@pytest.mark.parametrize("dataset_size", [9, 15, 25]) -def test_sampler_shuffle(num_replicas, dataset_size): - epoch0_samples = test_sampler_epoch(num_replicas, dataset_size, epoch=0) - epoch1_samples = test_sampler_epoch(num_replicas, dataset_size, epoch=1) - assert epoch0_samples != epoch1_samples # Shuffle is on. - epoch0_samples = test_sampler_epoch(num_replicas, dataset_size, - epoch=0, shuffle=False) - epoch1_samples = test_sampler_epoch(num_replicas, dataset_size, - epoch=1, shuffle=False) - assert epoch0_samples == epoch1_samples # Shuffle is off. - - -@pytest.mark.parametrize("num_replicas", [1, 3, 5]) -@pytest.mark.parametrize("dataset_size", [9, 15, 25]) -def test_sampler_index(num_replicas, dataset_size): - index = dataset_size // 2 # Set index to halfway through the dataset. - epoch_samples = test_sampler_epoch(num_replicas, dataset_size, - index=index, shuffle=False) - samples = sum(epoch_samples, []) - # Check contains second half of dataset. - for idx in range(index, dataset_size): - assert idx in samples - - index = 2 * dataset_size # Test sampler wrap-around. 
- epoch_samples = test_sampler_epoch(num_replicas, dataset_size, - index=index, shuffle=False) - assert set(sum(epoch_samples, [])) == set(range(dataset_size)) - - -@elastic_multiprocessing -def test_dataloader_restarts(): - import adaptdl.checkpoint - import adaptdl.collective - from adaptdl.env import num_restarts, num_replicas - adaptdl.collective.initialize("0.0.0.0") - dataset_size = 100 - init_batch_size = 10 - dataset = TensorDataset(torch.rand(dataset_size)) - dataloader = AdaptiveDataLoader(dataset, batch_size=init_batch_size) - # Load data samples in the following order: - # 2 batches (20 samples) using 1 replica (local_bsz = 10, batch_size = 10) - # 5 batches (60 samples) using 4 replica (local_bsz = 3, batch_size = 12) - # 2 batches (20 samples) using 2 replica (local_bsz = 5, batch_size = 10) - assert current_dataloader() is None - for idx, batch in enumerate(dataloader): - if num_restarts() == 0 and idx == 2: - adaptdl.checkpoint.save_all_states() - return 4 # Restart with 4 replicas. - if num_restarts() == 1 and idx == 5: - adaptdl.checkpoint.save_all_states() - return 2 # Restart with 2 replicas. - assert current_dataloader() is dataloader._elastic - local_bsz = batch[0].size(0) - assert dataloader.current_local_bsz == local_bsz - assert local_bsz == math.ceil(init_batch_size / num_replicas()) - assert dataloader.current_batch_size == num_replicas() * local_bsz - # After the last 2 batches. - assert idx == 1 - - -@elastic_multiprocessing -def test_dataloader_break(): - import adaptdl.checkpoint - import adaptdl.collective - from adaptdl.env import num_restarts - if num_restarts() == 0: - return 2 - adaptdl.collective.initialize("0.0.0.0") - dataset = TensorDataset(torch.rand(100)) - dataloader = AdaptiveDataLoader(dataset, batch_size=10) - # Break in the middle of the first for-loop, and start another for-loop. - assert current_dataloader() is None - for idx, batch in enumerate(dataloader): - assert current_dataloader() is dataloader._elastic - if idx == 5: - break - assert current_dataloader() is None - for idx, batch in enumerate(dataloader): - assert current_dataloader() is dataloader._elastic - assert idx == 9 # Run 10 batches total. - - -@elastic_multiprocessing -def test_bptt_iterator(): - import adaptdl.checkpoint - import adaptdl.collective - from adaptdl.env import num_restarts - adaptdl.collective.initialize("0.0.0.0") - # Load the iterator with 500 words - # 1 batch (5x10) using 1 replica. Restart after one iteration. - # 9 batches (5x5) using 2 replicas to consume remaining batches. - TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"), - init_token='', - eos_token='') - fields = [('text', TEXT)] - examples = [torchtext.data.Example.fromlist([['The'] * 500], fields)] - dataset = torchtext.data.Dataset(examples, fields) - TEXT.build_vocab(dataset) - bptt_iter = AdaptiveBPTTIterator(dataset, batch_size=10, bptt_len=5) - for idx, batch in enumerate(bptt_iter): - if num_restarts() == 0 and idx == 1: - assert batch.text.shape == (5, 10) - adaptdl.checkpoint.save_all_states() - return 2 - if adaptdl.env.num_replicas() == 2: - assert batch.text.shape == (5, 5) or batch.text.shape == (4, 5) - if adaptdl.env.num_replicas() == 2: - assert idx == 8 diff --git a/adaptdl/adaptdl/torch/torch/epoch.py b/adaptdl/adaptdl/torch/torch/epoch.py deleted file mode 100644 index 08b58ede1..000000000 --- a/adaptdl/adaptdl/torch/torch/epoch.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. 
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This module provides tools for the top-level loop over epochs during training.
-AdaptDL expects the training program to be implemented as a loop over several
-epochs, each containing a series of loops over datasets (e.g. one loop over the
-training set followed by one loop over the validation set). The program can be
-interrupted between any two iterations of a dataset loop, triggering a
-checkpoint to be taken, and then restarted using a different set of replicas.
-
-**Due to checkpoint-restarts, parts of the training program may be executed
-multiple times (e.g. once after each restart)!** To avoid incorrect execution,
-ensure that your code is idempotent_ in the following locations:
-
-1. Immediately before any epoch loop (using :func:`remaining_epochs_until`).
-2. Immediately before any dataset loop (using
-   :class:`adaptdl.torch.data.AdaptiveDataLoader`).
-
-Your code may be non-idempotent in other locations.
-
-.. code-block:: python
-
-    ### IDEMPOTENT CODE ONLY ###
-
-    for epoch in remaining_epochs_until(30):
-
-        ### IDEMPOTENT CODE ONLY ###
-
-        for batch in train_loader:
-            # ... any code ...
-
-        ### IDEMPOTENT CODE ONLY ###
-
-        for batch in valid_loader:
-            # ... any code ...
-
-        # ... any code ...
-
-    # ... any code ...
-
-    ### END PROGRAM ###
-
-For example, a common non-idempotent operation is learning-rate annealing:
-
-.. code-block:: python
-
-    for epoch in remaining_epochs_until(30):
-
-        lr_scheduler.step()  # (A) WRONG!
-
-        for batch in train_loader:
-            # ...
-
-        lr_scheduler.step()  # (B) WRONG!
-
-        for batch in valid_loader:
-            # ...
-
-        lr_scheduler.step()  # (C) OK!
-
-Location (A) will be executed again after any checkpoint-restart during either
-the training or validation loop, resulting in the learning rate being annealed
-several times in one epoch! Similarly with location (B), if a
-checkpoint-restart happens during the validation loop.
-
-Location (C) results in the correct behavior, because (1) an epoch will not be
-repeated once it has finished, and (2) no checkpoint-restarts can occur between
-the learning rate annealing and the end of the epoch.
-
-.. _idempotent: https://stackoverflow.com/a/1077421
-"""
-
-import logging
-import pickle
-
-import adaptdl.checkpoint
-
-
-logging.basicConfig(level=logging.INFO)
-LOG = logging.getLogger(__name__)
-LOG.setLevel(logging.INFO)
-
-
-def remaining_epochs_until(epoch):
-    """
-    Iterate over epochs in a way that is consistent with checkpoint-restarts.
-    For example:
-
-    .. code-block:: python
-
-        for epoch in remaining_epochs_until(30):
-            print(current_epoch())  # Should print 0 through 29
-
-        for epoch in remaining_epochs_until(60):
-            print(current_epoch())  # Should print 30 through 59
-
-    If a checkpoint-restart happens during an epoch, all previous epochs will
-    be skipped after the program restarts.
-
-    Arguments:
-        epoch (int): The epoch number to end at (exclusive).
-
-    Raises:
-        RuntimeError: If invoked before a previous epoch loop has ended.
- """ - if current_epoch() is not None: - raise RuntimeError("overlapping epoch loops detected") - if finished_epochs() < epoch: - LOG.info("starting at epoch %s", finished_epochs()) - else: - LOG.info("skipping all epochs up to %s", epoch) - while finished_epochs() < epoch: - _epoch_state().current_epoch = finished_epochs() - try: - yield current_epoch() - finally: - # Try to catch any exits from epoch loop, including breaks and - # Exceptions. See https://www.peterbe.com/plog/generatorexit. - _epoch_state().finished_epochs += 1 - _epoch_state().current_epoch = None - - -def current_epoch(): - """ - Get the current epoch while iterating with :func:`remaining_epochs_until`. - - Returns: - int or None: The current epoch number if called from within a - :func:`remaining_epochs_until` iteration, ``None`` otherwise. - """ - return _epoch_state().current_epoch - - -def finished_epochs(): - """ - Get the number of epochs finished using :func:`remaining_epochs_until`. - - Returns: - int: The number of finished epochs. Equal to :func:`current_epoch` - if called from within a :func:`remaining_epochs_until` iteration. - """ - return _epoch_state().finished_epochs - - -class _EpochState(adaptdl.checkpoint.State): - def __init__(self): - super().__init__(".adaptdl-epoch") - self.finished_epochs = 0 - self.current_epoch = None - - def save(self, fileobj): - pickle.dump(self.finished_epochs, fileobj) - - def load(self, fileobj): - self.finished_epochs = pickle.load(fileobj) - - -def _epoch_state(): - global _EPOCH_STATE - if _EPOCH_STATE is None: - _EPOCH_STATE = _EpochState() - adaptdl.checkpoint.load_state(_EPOCH_STATE) - return _EPOCH_STATE - - -_EPOCH_STATE = None diff --git a/adaptdl/adaptdl/torch/torch/epoch_test.py b/adaptdl/adaptdl/torch/torch/epoch_test.py deleted file mode 100644 index 464bb240c..000000000 --- a/adaptdl/adaptdl/torch/torch/epoch_test.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from adaptdl.conftest import elastic_multiprocessing - - -@elastic_multiprocessing -def test_epoch(): - import adaptdl.checkpoint - from adaptdl.env import num_restarts - from adaptdl.torch.epoch import (remaining_epochs_until, - current_epoch, finished_epochs) - total_epochs = 10 - restart_epoch = 5 - assert current_epoch() is None - if num_restarts() == 0: - assert finished_epochs() == 0 - expected_epochs = list(range(restart_epoch + 1)) - elif num_restarts() == 1: - assert finished_epochs() == restart_epoch - expected_epochs = list(range(restart_epoch, total_epochs)) - else: - assert False - for idx, epoch in enumerate(remaining_epochs_until(10)): - assert epoch == expected_epochs[idx] - assert current_epoch() == epoch - assert finished_epochs() == epoch - if num_restarts() == 0 and epoch == restart_epoch: - adaptdl.checkpoint.save_all_states() - return 5 # Restart with 5 replicas. 
diff --git a/adaptdl/adaptdl/torch/torch/gradient_noise_scale.py b/adaptdl/adaptdl/torch/torch/gradient_noise_scale.py deleted file mode 100644 index 2687644b9..000000000 --- a/adaptdl/adaptdl/torch/torch/gradient_noise_scale.py +++ /dev/null @@ -1,330 +0,0 @@ -import functools -import logging -import math -import numpy as np -import torch.distributed -import torch.optim - -from torch.autograd import Variable - -import adaptdl.utils - -__all__ = ["GradientNoiseScale"] - -logging.basicConfig(level=logging.INFO) -LOG = logging.getLogger(__name__) -LOG.setLevel(logging.INFO) - - -def _average_groups(grads1, grads2): - ret = [] - for group1, group2 in zip(grads1, grads2): - ret.append([]) - for g1, g2 in zip(group1, group2): - if g1 is None: - ret[-1].append(g2) - elif g2 is None: - ret[-1].append(g1) - else: - ret[-1].append((g1 + g2) / 2) - return ret - - -def _normsqr_groups(grads, pinvs): - ret = [] - for group, pinv_group in zip(grads, pinvs): - normsqr = [(g / pinv).pow(2).sum(dtype=torch.float64) - for g, pinv in zip(group, pinv_group) if g is not None] - ret.append(sum(normsqr).item() if normsqr else 0.0) - return np.array(ret) - - -class GradientNoiseScale(object): - """This class tracks gradient related stats and takes care of gradient - accumulation.""" - def __init__(self, adp, optimizer, - mp_scaler=None, - num_replicas=None, - accum_scale=None): - self._adp = adp - self._optimizer = optimizer - self._orig_optimizer_zero_grad = optimizer.zero_grad - self._should_zero_grad = True - self._mp_scaler = mp_scaler - self._local_sqr = None - self._num_replicas = (num_replicas if num_replicas is not None - else torch.distributed.get_world_size()) - self._accum_scale = accum_scale or self._num_replicas - self._prev_grads = None - - self.reset_accumulation() - - self._optimizer.state.setdefault("gns", { - "progress": 0.0, - "prev_scale": 0.0, - # Averages of n and v - "sqr_avg": np.ones(len(optimizer.param_groups)), - "var_avg": np.zeros(len(optimizer.param_groups)), - # Whether estimates are biased (using differenced estimator). - "biased": False, - }) - - for idx, param_group in enumerate(self._optimizer.param_groups): - for param in param_group["params"]: - param.register_hook( - functools.partial(self._backward_hook, idx, param)) - self._callback_queued = False - self._smoothing = 0.999 - - @property - def _state(self): - return self._optimizer.state["gns"] - - def reset_accumulation(self): - """reset accumulation calculations and gradients.""" - self._orig_optimizer_zero_grad() - self._local_sqr = None - self._accum_count = 0 - - @property - def should_zero_grad(self): - return self._should_zero_grad - - @property - def accum_scale(self): - return self._accum_scale - - @property - def accum_count(self): - return self._accum_count - - def set_accum_scale(self, accum_scale): - if not np.isclose(self._accum_scale, accum_scale): - self.reset_accumulation() - self._accum_scale = accum_scale - - @property - def raw_sqr_avg(self): - view = self._state["sqr_avg"].view() - view.flags.writeable = False - return view - - def sqr_avg(self): - """ - Current estimate of the squared l2-norm of the true gradient (sigma - squared). - - Returns (float): Estimate of squared l2-norm. - """ - return float(np.sum(np.maximum(self._state["sqr_avg"], 0.0))) - - @property - def raw_var_avg(self): - view = self._state["var_avg"].view() - view.flags.writeable = False - return view - - def var_avg(self): - """ - Current estimate of the trace of the covariance of the true gradient - (mu squared). 
- - Returns (float): Estimate of trace of the covariance. - """ - return float(np.sum(np.maximum(self._state["var_avg"], 1e-6))) - - def get_progress(self): - return self._state["progress"] - - def set_progress(self, progress): - self._state["progress"] = progress - - def gain(self, scale): - """ - Current estimate of the GradientNoiseScale gain ratio. - - Arguments: - scale (float): The total scale to estimate the gain ratio for. - - Returns (float): Estimate of gain ratio. - """ - var = self.var_avg() - norm = self.sqr_avg() - return (var + norm) / (var / scale + norm) - - def _update_avg(self, param_name, value, factor): - biased = self._state.get(param_name + "_biased", 0.0) - unbias = self._state.get(param_name + "_unbias", 0.0) - biased = factor * biased + (1.0 - factor) * value - unbias = factor * unbias + (1.0 - factor) - self._state[param_name + "_biased"] = biased - self._state[param_name + "_unbias"] = unbias - self._state[param_name] = biased / unbias - - def _reset_avg(self, param_name): - self._state.pop(param_name + "_biased", None) - self._state.pop(param_name + "_unbias", None) - - @adaptdl.utils.print_exc - def _backward_hook(self, idx, param, grad): - # This method should be invoked once for each parameter during the - # backward pass, before gradients are synchronized between replicas. - if self._local_sqr is None: - self._local_sqr = torch.zeros(len(self._optimizer.param_groups), - device=grad.device, - dtype=torch.float64) - - # Get the preconditioning matrix for the optimizer - preconditioner = self._calculate_preconditioner(idx, param) - - # Update the local gradient square sum - self._local_sqr[idx] += \ - (grad.detach() / preconditioner).pow(2).sum(dtype=torch.float64) - if not self._callback_queued: - Variable._execution_engine.queue_callback(self._queue_callback) - self._callback_queued = True - - @adaptdl.utils.print_exc - def _queue_callback(self): - # This method should be invoked after the entire backward pass. We want - # to make sure self._final_callback is invoked once, only after all - # gradients have been synchronized between each replica. However, the - # synchronization code in DistributedDataParallel is also done in a - # callback, which might not yet be executed. Therefore, we enqueue - # self._final_callback from this method, which should ensure it is - # invoked after the gradient synchronization callback. - self._callback_queued = False - self._accum_count += 1 - if self._adp.require_backward_grad_sync: - # Asynchronously sum the local squared-gradient statistics. The - # actual gradient averaging should also be happening at the same - # time, until self._final_callback is invoked. - if self._num_replicas > 1: - self._async_op = torch.distributed.all_reduce(self._local_sqr, - async_op=True) - Variable._execution_engine.queue_callback(self._final_callback) - self._should_zero_grad = True - else: - # Keep on accumulating gradients, should not zero grad. - self._should_zero_grad = False - - @adaptdl.utils.print_exc - def _final_callback(self): - # This method should be invoked once the gradients have been - # synchronized between all replicas and accumulation steps. 
-        if self._num_replicas > 1:
-            self._async_op.wait()
-        grads = []
-        if self._mp_scaler is not None:
-            mixed_precision_scale = self._mp_scaler.get_scale()
-        else:
-            mixed_precision_scale = 1.0
-        for group in self._optimizer.param_groups:
-            grads.append([])
-            for param in group["params"]:
-                if param.grad is None:
-                    grads[-1].append(None)
-                    continue
-                grad = param.grad.detach().float()
-                grads[-1].append(
-                    grad / mixed_precision_scale / self._accum_count)
-        preconditioner = self._get_preconditioner()
-
-        # Note: mixed precision can result in nan/inf gradients,
-        # which propagate into our norm and variance estimates.
-        # Mixed precision autoscaling skips the step when there
-        # are nan/inf values, so we also skip the update here.
-        grads_normsqr = _normsqr_groups(grads, preconditioner)
-        if not np.all(np.isfinite(grads_normsqr)):
-            LOG.warning("GradientNoiseScale detected invalid gradient! "
-                        "Skipping step.")
-            return
-        count = self._num_replicas * self._accum_count
-        scale = self._accum_scale * self._accum_count
-        if count > 1:
-            # Average local squared-norm samples.
-            local_sqr = self._local_sqr.cpu().numpy() / count
-            # Gradient is squared in local_sqr, so need to square the
-            # mixed precision scale as well.
-            local_sqr = (local_sqr / mixed_precision_scale ** 2)
-            total_sqr = grads_normsqr
-            if self._state["biased"]:
-                self._reset_avg("sqr_avg")
-                self._reset_avg("var_avg")
-            self._state["biased"] = False
-            self._prev_grads = None
-        else:
-            # Single gradient datapoint, use difference estimation.
-            if self._prev_grads is not None:
-                local_sqr = (_normsqr_groups(self._prev_grads, preconditioner)
-                             + grads_normsqr) / 2
-                avg_grads = _average_groups(grads, self._prev_grads)
-                total_sqr = _normsqr_groups(avg_grads, preconditioner)
-                count = 2
-                scale = 2 * self._accum_scale
-            self._state["biased"] = True
-            self._prev_grads = [[g.clone() if g is not None else None
-                                 for g in group] for group in grads]
-        if count > 1:
-            grad_sqr = (count * total_sqr - local_sqr) / (count - 1)
-            grad_var = (local_sqr - total_sqr) * scale / (count - 1)
-            theta = self._smoothing ** scale
-            self._update_avg('sqr_avg', grad_sqr, theta)
-            self._update_avg('var_avg', grad_var, theta)
-
-    def _get_preconditioner(self):
-        out = []
-        for idx, group in enumerate(self._optimizer.param_groups):
-            pinvs = []
-            for param in group["params"]:
-                pinv = self._calculate_preconditioner(idx, param)
-                pinvs.append(pinv)
-            out.append(pinvs)
-        return out
-
-    def _calculate_preconditioner(self, idx, param):
-        return torch.ones_like(param, memory_format=torch.preserve_format)
-
-
-class AdamGradientNoiseScale(GradientNoiseScale):
-    def __init__(self, adp, optimizer,
-                 mp_scaler=None,
-                 num_replicas=None,
-                 accum_scale=None):
-        self._adam_param_group = {'beta': [], 'eps': []}
-        super().__init__(adp, optimizer, mp_scaler, num_replicas, accum_scale)
-        for idx, param_group in enumerate(self._optimizer.param_groups):
-            self._adam_param_group['beta'].append(param_group['betas'][1])
-            self._adam_param_group['eps'].append(param_group['eps'])
-
-    def _calculate_preconditioner(self, idx, param):
-        state = self._optimizer.state[param]
-        if state.get('step', 0) < 5:
-            return torch.ones_like(param, memory_format=torch.preserve_format)
-
-        exp_avg_sq = state["exp_avg_sq"].clone()  # not sure if clone is needed
-        beta2 = self._adam_param_group['beta'][idx]
-        eps = self._adam_param_group['eps'][idx]
-        correction = 1 - beta2 ** state['step']
-        pinv = (exp_avg_sq.sqrt() / math.sqrt(correction)).add_(eps)
-        return pinv
-
-    def _reset_adam_state(self, step=0):
for group in self._optimizer.param_groups: - beta1, beta2 = group["betas"] - for param in group["params"]: - state = self._optimizer.state[param] - if state.get("step", 0) > 0: - state["exp_avg"].mul_( - (1 - beta1 ** step) / (1 - beta1 ** state["step"])) - state["exp_avg_sq"].mul_( - (1 - beta2 ** step) / (1 - beta2 ** state["step"])) - state["step"] = step - - def _final_callback(self): - scale = self._accum_scale * self._accum_count - if not np.isclose(scale, self._state["prev_scale"]): - self._reset_adam_state() - # reset Adam states when scale is changed - self._state["prev_scale"] = scale - return super()._final_callback() diff --git a/adaptdl/adaptdl/torch/torch/gradient_noise_scale_test.py b/adaptdl/adaptdl/torch/torch/gradient_noise_scale_test.py deleted file mode 100644 index 5f5112a36..000000000 --- a/adaptdl/adaptdl/torch/torch/gradient_noise_scale_test.py +++ /dev/null @@ -1,60 +0,0 @@ -import numpy as np -import pytest -import torch -import random - -from unittest.mock import Mock - -from adaptdl.torch.gradient_noise_scale import GradientNoiseScale - - -def test_object(): - params = [torch.tensor([[1., -1.], [2., 3.]], requires_grad=True), - torch.tensor([[2., 3.]], requires_grad=True)] - sgd = torch.optim.SGD(params, lr=0.1) - adp = Mock(require_backward_grad_sync=True) - obj = GradientNoiseScale(adp, sgd, accum_scale=1.0, num_replicas=1) - assert(obj._accum_scale == 1.0) - obj._num_replicas = 8 - obj.set_accum_scale(3.0) - assert(obj.accum_scale == 3.0) - obj._num_replicas = 4 - obj.set_accum_scale(3.0) - assert(obj.accum_scale == 3.0) - assert(np.isclose(obj.gain(2.0), 1.0)) - obj._state['var_avg'] = 3.0 - obj._state['norm_avg'] = 1.0 - assert(np.isclose(obj.gain(3.0), 2.0)) - - -ATOL = 0.01 - - -def test_nan(): - def nan_objective(tensor): - if random.random() > 0.5: - target = float("Nan") - else: - target = 4.0 - return (tensor - target)**2 - - params_t = torch.Tensor([1.0]) - params = torch.autograd.Variable(params_t, requires_grad=True) - sgd = torch.optim.SGD([params], lr=0.1) - adp = Mock(require_backward_grad_sync=True) - gns = GradientNoiseScale(adp, sgd, accum_scale=1.0, num_replicas=1) - adp.gns = gns - for i in range(100): - gns.reset_accumulation() - loss = nan_objective(params) - loss.backward() - if np.all(np.isfinite(loss.detach().numpy())): - sgd.step() - if params.allclose(torch.tensor([4.0]), atol=ATOL): - break - else: - pytest.fail(f"Did not converge: {params}") - if not (np.all(np.isfinite(gns.sqr_avg())) and - np.all(np.isfinite(gns.var_avg()))): - pytest.fail(f"non-finite adascale parameters:" - f"{gns.sqr_avg()}, {gns.var_avg()}") diff --git a/adaptdl/adaptdl/torch/torch/iterator.py b/adaptdl/adaptdl/torch/torch/iterator.py deleted file mode 100644 index 337292ea6..000000000 --- a/adaptdl/adaptdl/torch/torch/iterator.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
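As a quick arithmetic check of the gain ratio exercised by test_object above: gain(scale) = (var + sqr) / (var / scale + sqr). The test sets the variance estimate to 3.0 and leaves the squared-norm estimate at its default of 1.0, so gain(3.0) = (3 + 1) / (3/3 + 1) = 2.0, matching the assertion. A standalone sketch of the formula:

def gain(var, sqr, scale):
    # Predicted speedup from running at `scale` times the base batch size.
    return (var + sqr) / (var / scale + sqr)

assert abs(gain(3.0, 1.0, 3.0) - 2.0) < 1e-12
assert abs(gain(3.0, 1.0, 1.0) - 1.0) < 1e-12  # scale 1 always gives gain 1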
- - -import math -import logging - -from torchtext.data import BPTTIterator -from torchtext.data.dataset import Dataset -from torchtext.data.batch import Batch - -import adaptdl.checkpoint -import adaptdl.collective -import adaptdl.env -from adaptdl.torch.data import AdaptiveDataLoaderMixin - -logging.basicConfig(level=logging.INFO) -LOG = logging.getLogger(__name__) -LOG.setLevel(logging.INFO) - - -class AdaptiveBPTTIterator(BPTTIterator, AdaptiveDataLoaderMixin): - def __init__(self, dataset, batch_size, bptt_len, **kwargs): - max_batch_size = kwargs.pop("max_batch_size", None) - local_bsz_bounds = kwargs.pop("local_bsz_bounds", None) - - BPTTIterator.__init__(self, dataset=dataset, batch_size=batch_size, - bptt_len=bptt_len, **kwargs) - AdaptiveDataLoaderMixin.__init__(self, batch_size) - - self.num_replicas = adaptdl.env.num_replicas() - self.rank = adaptdl.env.replica_rank() - - if max_batch_size and local_bsz_bounds: - self._elastic.autoscale_batch_size(max_batch_size, - local_bsz_bounds) - - # The start index changes when there is a rescaling. We recompute a new - # start index based on how much we covered before the restart - def _recompute_start(self, prev_curr, prev_end, curr_end): - if prev_end == 0: - return prev_curr - return math.ceil(prev_curr * curr_end / prev_end) - - def __iter__(self): - with self._elastic.context(): - if self._elastic.skipdone(): - return - - self.batch_size = self._elastic._sync_local_bsz() - - text = self.dataset[0].text - TEXT = self.dataset.fields['text'] - TEXT.eos_token = None - text = text + ([TEXT.pad_token] * - int(math.ceil(len(text) / self.batch_size) * - self.batch_size - len(text))) - data = TEXT.numericalize( - [text], device=self.device) - data = data.view(self.batch_size, -1).t().contiguous() - dataset = Dataset(examples=self.dataset.examples, fields=[ - ('text', TEXT), ('target', TEXT)]) - end = data.size(0) # current length of dataset - - # Change in current batch size changes the dimensions of dataset - # which changes the starting position in the reshaped dataset. The - # local batch size is also a function of number of replicas, so - # when we rescale we need to recalculate where to start the - # iterations from for the new batch size. - self._elastic.current_index = \ - self._recompute_start(self._elastic.current_index, - self._elastic.end_index, end) - self._elastic.end_index = end - - # Every replica reads data strided by bptt_len - start = self._elastic.current_index + (self.bptt_len * self.rank) - step = self.bptt_len * self.num_replicas - - # The starting index of the highest rank - highest_start = self._elastic.current_index + \ - (self.bptt_len * (self.num_replicas - 1)) - - # Number of steps we will take on the highest rank. We limit - # iterations on all replicas by this number to prevent asymmetric - # reduce ops which would result in a deadlock. 
- min_steps_in_epoch = max(math.ceil((end - highest_start) / step), 0) # noqa: E501 - self.iterations = 0 - while True: - for i in range(start, end, step): - self.iterations += 1 - # Make sure that _elastic.profile is called equal number of - # times on all replicas - if self.iterations > min_steps_in_epoch: - break - with self._elastic.profile(self.training and i > 0): - seq_len = min(self.bptt_len, data.size(0) - i - 1) - assert seq_len > 0 - batch_text = data[i:i + seq_len] - batch_target = data[i + 1:i + 1 + seq_len] - if TEXT.batch_first: - batch_text = batch_text.t().contiguous() - batch_target = batch_target.t().contiguous() - yield Batch.fromvars( - dataset, self.batch_size, - text=batch_text, - target=batch_target) - self._elastic.current_index += step - - if not self.repeat: - break diff --git a/adaptdl/adaptdl/torch/torch/parallel.py b/adaptdl/adaptdl/torch/torch/parallel.py deleted file mode 100644 index 218ae2981..000000000 --- a/adaptdl/adaptdl/torch/torch/parallel.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import functools -import numpy as np -import time -import warnings -from typing import Optional - -import torch -import torch.cuda -import torch.distributed -from torch.autograd import Variable -from torch.nn.parallel import DistributedDataParallel - -import adaptdl.checkpoint -import adaptdl.env -import adaptdl.utils -from adaptdl.torch.data import current_dataloader -from adaptdl.torch.scaling_rules import AdaScale, AdamScale, ScalingRuleBase -from adaptdl.torch.gradient_noise_scale import GradientNoiseScale,\ - AdamGradientNoiseScale -from adaptdl.torch._metrics import profile_sync_time, update_grad_params,\ - update_progress - - -class AdaptiveDataParallel(DistributedDataParallel): - """ - This class extends PyTorch DistributedDataParallel with support for - adaptive batch sizes and checkpoint-restart elasticity. It automatically - saves the given model, optimizer, and (optionally) LR scheduler whenever a - checkpoint is triggered, and restores their states after restart. The - optimizer is automatically patched with the chosen scaling rule. - - Arguments: - model (torch.nn.Module): Model to be distributed. - optimizer (torch.optim.Optimizer): Optimizer used to update the given - model's parameters, will be patched using subclass of - :class:`adaptdl.torch.scaling_rules.ScalingRuleBase`. - scaling_rule (ScalingRuleBase): Scaling rule used to - patch the given optimizer, default to AdaScale. - lr_scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used - to anneal the learning rate for the given optimizer. - name (string): Unique name for each instance of this class, needed only - if multiple instances exist. 
- """ - def __init__(self, model, optimizer, lr_scheduler=None, mp_scaler=None, - scaling_rule: Optional[ScalingRuleBase] = None, - name="adaptdl-dataparallel", **kwargs): - super().__init__(model, **kwargs) - self._key = id(self) - # Register backward hooks on model parameters. Depends on these hooks - # being invoked before gradients are averaged. This is technically an - # internal behavior of DistributedDataParallel, but seems to be abused - # pretty widely so there should be little chance of it changing. - # https://discuss.pytorch.org/t/59291 - for param in model.parameters(): - param.register_hook(functools.partial(self._backward_hook, param)) - - # Setup for the scaling_rule, must be after registering backward hooks - # because some of them need to register their own backward hooks. - if not scaling_rule and (isinstance(optimizer, torch.optim.Adam) or - isinstance(optimizer, torch.optim.AdamW)): - self.scaling_rule = AdamScale() - else: - self.scaling_rule = scaling_rule or AdaScale() - - if isinstance(scaling_rule, AdamScale): - self.gns = AdamGradientNoiseScale(self, optimizer, - mp_scaler=mp_scaler) - else: - self.gns = GradientNoiseScale(self, optimizer, mp_scaler=mp_scaler) - self.scaling_rule.initialize(self, optimizer, patch_optimizer=True) - - self._state = _AdaptiveDataParallelState( - model, optimizer, lr_scheduler, mp_scaler, name) - adaptdl.checkpoint.load_state(self._state) - - self._sync_start = None - - def forward(self, *args, **kwargs): - # Do not do gradient synchronization during gradient accumulation. - dataloader = current_dataloader() - if dataloader is not None and dataloader.training: - self.require_backward_grad_sync = dataloader.is_optim_step() - accum_scale = (dataloader.current_local_bsz * - adaptdl.env.num_replicas() / dataloader.batch_size) - self.gns.set_accum_scale(accum_scale) - return super().forward(*args, **kwargs) - - @adaptdl.utils.print_exc - def _backward_hook(self, param, grad): - # This method should be invoked once for each parameter during the - # backward pass, before gradients are synchronized between replicas. - if grad.device.type.startswith("cuda"): - self._sync_start = torch.cuda.Event(enable_timing=True) - self._sync_start.record() - else: - self._sync_start = time.time() - self._final_callback_queued = False - Variable._execution_engine.queue_callback(self._queue_callback) - - @adaptdl.utils.print_exc - def _queue_callback(self): - # This method should be invoked after the entire backward pass. We want - # to make sure self._final_callback is invoked once, only after all - # gradients have been synchronized between each replica. However, the - # synchronization code in DistributedDataParallel is also done in a - # callback, which might not yet be executed. Therefore, we enqueue - # self._final_callback from this method, which should ensure it is - # invoked after the gradient synchronization callback. - if self._final_callback_queued: - return - self._final_callback_queued = True - Variable._execution_engine.queue_callback(self._final_callback) - - @adaptdl.utils.print_exc - def _final_callback(self): - # This method should be invoked once for each backward pass, after - # gradients have been synchronized between each replica. - self._final_callback_queued = False - # self._sync_start should mark the last time any local gradient - # from this module was produced. We assume the duration until now was - # primarily spent waiting for gradient synchronization. 
- # TODO: Depends on the internal behavior of DistributedDataParallel, - # which might break with future versions of PyTorch. Any better - # and well-supported way to measure the synchronization time? - if isinstance(self._sync_start, torch.cuda.Event): - sync_end = torch.cuda.Event(enable_timing=True) - sync_end.record() - sync_end.synchronize() - profile_sync_time(self._sync_start.elapsed_time(sync_end) / 1e3) - else: - profile_sync_time(time.time() - self._sync_start) - - dataloader = current_dataloader() - if dataloader is None: - # Don't allow backpropagation outside of a dataloader loop, because - # the batch size would be unknown. - raise RuntimeError("backpropagation outside AdaptiveDataLoader") - dataloader.train() - - scale = dataloader.current_batch_size / dataloader.batch_size - self._state.gain = self.gns.gain(scale) - self._state.lr_factor = \ - np.average(self.scaling_rule.scale_lr(scale)) - update_progress(self.gns.get_progress()) - if dataloader.max_batch_size and \ - dataloader.max_batch_size > dataloader.batch_size: - update_grad_params(self._key, self.gns.sqr_avg(), - self.gns.var_avg()) - self._sync_start = None - - def zero_grad(self, *args, **kwargs): - warnings.warn("zero_grad has no effect with AdaptiveDataParallel") - - @property - def gain(self): # TODO: should be tracked in the metrics module instead. - """ - Current estimate of the AdaScale gain (r_t) value. - """ - return self._state.gain - - def to_tensorboard(self, writer, global_step, tag_prefix=""): - """ - Output some useful metrics to TensorBoard. - - Arguments: - writer (torch.utils.tensorboard.SummaryWriter): ``SummaryWriter`` - object to output metrics to. - global_step (int): Global step value to record. - tag_prefix (str): Prefix added to each metric's tag. - """ - if tag_prefix and not tag_prefix.endswith("/"): - tag_prefix += "/" - writer.add_scalar(tag_prefix + "Gradient_Norm_Sqr", - self.gns.sqr_avg(), global_step) - writer.add_scalar(tag_prefix + "Gradient_Variance", - self.gns.var_avg(), global_step) - writer.add_scalar(tag_prefix + "Gain", - self._state.gain, global_step) - writer.add_scalar(tag_prefix + "Learning_Rate_Factor", - self._state.lr_factor, global_step) - - -class _AdaptiveDataParallelState(adaptdl.checkpoint.State): - def __init__(self, model, optimizer, lr_scheduler, mp_scaler, - name="adaptdl-dataparallel"): - super().__init__(name) - self.model = model - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - self.mp_scaler = mp_scaler - # TODO: Gain/goodput should be tracked in the metrics module instead. 
- self.gain = 1.0 - # lr_factor summary - self.lr_factor = 1.0 - - def save(self, fileobj): - state_dicts = [self.model.state_dict(), self.optimizer.state_dict()] - - if self.lr_scheduler is not None: - state_dicts.append(self.lr_scheduler.state_dict()) - else: - state_dicts.append(None) - - if self.mp_scaler is not None: - state_dicts.append(self.mp_scaler.state_dict()) - else: - state_dicts.append(None) - torch.save((state_dicts, self.gain, self.lr_factor), fileobj) - - def load(self, fileobj): - state_dicts, self.gain, self.lr_factor = torch.load(fileobj) - self.model.load_state_dict(state_dicts[0]) - self.optimizer.load_state_dict(state_dicts[1]) - if state_dicts[2] is not None: - self.lr_scheduler.load_state_dict(state_dicts[2]) - if state_dicts[3] is not None: - self.mp_scaler.load_state_dict(state_dicts[3]) diff --git a/adaptdl/adaptdl/torch/torch/parallel_test.py b/adaptdl/adaptdl/torch/torch/parallel_test.py deleted file mode 100644 index 3e67eb551..000000000 --- a/adaptdl/adaptdl/torch/torch/parallel_test.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import numpy as np -import torch - -from torch.utils.data import Dataset -import adaptdl.torch as adl - - -class LRIterableDataset(Dataset): - def __init__(self, size, true_values, noise): - input_values = np.random.uniform(-5.0, 5.0, size) - bias_input_values = np.stack([np.ones(size), input_values]) - target_values = ( - np.dot(true_values, bias_input_values) - + np.random.normal(0.0, noise, size=(size,))) - self._values = list(zip(input_values, target_values)) - self._len = size - - def __getitem__(self, index): - return self._values[index] - - def __len__(self): - return self._len - - -def test_single_replica_parallel(): - adl.init_process_group("gloo") - true_values = np.asarray([3.0, 4.0]) - dataset = LRIterableDataset(1000, true_values, 1.0) - dataloader = adl.AdaptiveDataLoader( - dataset, batch_size=32, shuffle=False, num_workers=1) - model = torch.nn.Linear(1, 1, bias=True) - params = [model.bias, model.weight] - sgd = torch.optim.SGD( - [{"params": [param]} for param in params], - lr=0.01) - schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, [50]) - model = adl.AdaptiveDataParallel(model, sgd, schedule) - loss = torch.nn.MSELoss() - for epoch in adl.remaining_epochs_until(100): - for inputs, targets in dataloader: - inputs = inputs.float() - targets = targets.float() - sgd.zero_grad() - output = model(torch.reshape(inputs, (-1, 1))) - targets = torch.reshape(targets, (-1, 1)) - loss_value = loss(output, targets) - loss_value.backward() - sgd.step() - schedule.step() - params = np.asarray([param.item() for param in params]) - assert(np.all(np.isclose(params, true_values, atol=0.1))), \ - (params, true_values) diff --git a/adaptdl/adaptdl/torch/torch/scaling_rules.py b/adaptdl/adaptdl/torch/torch/scaling_rules.py deleted file mode 100644 index ac0f2a1c0..000000000 --- a/adaptdl/adaptdl/torch/torch/scaling_rules.py +++ /dev/null @@ 
-1,200 +0,0 @@
-# Copyright 2020 Petuum, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import functools
-import math
-import numpy as np
-import warnings
-
-from types import MethodType
-
-# from adaptdl.torch.data import current_dataloader
-
-
-__all__ = ["ScalingRuleBase", "AdaScale", "LinearScale", "SqrtScale",
-           "LEGWScale"]
-
-
-class ScalingRuleBase(object):
-    """
-    Base class for scaling rules that has the ability to track gradient noise
-    scale calculations. Its subclasses can be used in combination with
-    ``adaptdl.torch.parallel.AdaptiveDataParallel`` and ``torch.optim.SGD``.
-
-    .. code-block:: python
-
-        optim = torch.optim.SGD(model.parameters(), lr=0.001)
-        adascale = AdaScale()
-        model = AdaptiveDataParallel(model, optim, scaling_rule=adascale)
-
-        for epoch in ...:
-            for batch in ...:
-                optim.zero_grad()
-                loss = ...
-                loss.backward()
-                adascale.step()
-    """
-
-    _adaptlr = None
-
-    def __init__(self):
-        # instance of AdaptiveDataParallel, needs to be set before any of the
-        # methods can be used
-        self.adp = None
-        self._optimizer = None
-        self._orig_optimizer_step = None
-
-    def scale_lr(self, scale):
-        raise NotImplementedError
-
-    def zero_grad(self, *args, **kwargs):
-        if self.adp.gns.should_zero_grad:
-            self.adp.gns.reset_accumulation(*args, **kwargs)
-        else:
-            warnings.warn("skipping zero_grad for accumulated gradient")
-
-    def step(self, *args, **kwargs):
-        """
-        Run one optimizer step. Essentially just invokes
-        ``optimizer.step(*args, **kwargs)`` with a scaled learning rate.
-
-        Arguments:
-            args: Positional arguments passed to ``optimizer.step``.
-            kwargs: Keyword arguments passed to ``optimizer.step``.
-        """
-        if not self.adp:
-            raise ValueError("AdaptiveDataParallel instance is not set!")
-        if not self.adp.require_backward_grad_sync:
-            return
-        scale = self.adp.gns.accum_scale * self.adp.gns.accum_count
-        initial_lr = [pg["lr"] for pg in self._optimizer.param_groups]
-        scaled_lr = np.multiply(self.scale_lr(scale), initial_lr)
-        ScalingRuleBase._adaptlr = scaled_lr
-        for lr, pg in zip(scaled_lr, self._optimizer.param_groups):
-            pg["lr"] = lr
-        self._orig_optimizer_step(*args, **kwargs)
-        for lr, pg in zip(initial_lr, self._optimizer.param_groups):
-            pg["lr"] = lr
-        self.adp.gns.set_progress(self.adp.gns.get_progress()
-                                  + self.adp.gns.gain(scale))
-
-    def _patch_optimizer(self):
-        """
-        Monkey-patch the optimizer's step function with
-        :meth:`ScalingRuleBase.step`.
- """ - @functools.wraps(self._optimizer.step) - def step_wrapper(optim, *args, **kwargs): - return self.step(*args, **kwargs) - - @functools.wraps(self._optimizer.zero_grad) - def zero_wrapper(optim, *args, **kwargs): - return self.zero_grad(*args, **kwargs) - self._optimizer.step = MethodType(step_wrapper, self._optimizer) - self._optimizer.zero_grad = MethodType(zero_wrapper, self._optimizer) - - def initialize(self, adp, optimizer, patch_optimizer=False): - self.adp = adp - self._optimizer = optimizer - self._orig_optimizer_step = optimizer.step - if patch_optimizer: - self._patch_optimizer() - - @staticmethod - def _get_adapt_lr_scale(): - return ScalingRuleBase._adaptlr - - -class AdaScale(ScalingRuleBase): - """ - Implements the AdaScale_ algorithm for scaling the learning rate for - distributed and large batch size training. - - .. _AdaScale: https://proceedings.icml.cc/static/paper_files/icml/2020/4682-Supplemental.pdf - """ # noqa: E501 - - def scale_lr(self, scale): - """Calculate factors to be applied to lr for each parameter group.""" - var = self.adp.gns.raw_var_avg - sqr = self.adp.gns.raw_sqr_avg - var = np.maximum(var, 1e-6) - sqr = np.maximum(sqr, 0.0) - return (var + sqr) / (var / scale + sqr) - - -class AdamScale(AdaScale): - """ - Implements the variant of AdaScale_ that supports Adam, AdamW and RMSProp - """ - - def scale_lr(self, scale, power=0.5): - return np.power(super().scale_lr(scale=scale), power) - - -class LinearScale(ScalingRuleBase): - - def scale_lr(self, scale): - return scale - - -class SqrtScale(ScalingRuleBase): - - def scale_lr(self, scale): - return math.sqrt(scale) - - -class LEGWScale(ScalingRuleBase): - """ - Implements the LEGWScale algorithm for scaling the learning rate. - - Essentially, with LEGWScale, lr_factor is calculated based on - training progress as follows: - - when current_step < base_warmup_epoch * scale * steps_per_epoch: - `lr_factor = sqrt(scale) * progress_ratio` where - `progress_ratio = current_step / - (scale * base_warmup_epochs * steps_per_epoch)` - - when current_step >= base_warmup_epoch * scale * steps_per_epoch: - `lr_factor = sqrt(scale)` - - In order to adapt LEGWScale to AdaptDL, `progress_ratio` is - calculated differently as: - `progress / (scale * base_warmup_epochs * steps_per_epoch)` where - `progress` is the effective steps trained based on AdaptDL's - estimation. - - Argmuents: - base_warmup_epochs: Base warmup epochs - data_size: total number of samples in the dataset - - .. _LEGWScale: https://arxiv.org/pdf/1901.08256.pdf - """ - - def __init__(self, base_warmup_epochs, data_size): - super().__init__() - self._base_warmup_epochs = base_warmup_epochs - self._data_size = data_size - - def scale_lr(self, scale): - dataloader = current_dataloader() - # total training steps for warm up - total_steps = self._base_warmup_epochs * scale * \ - self._data_size / dataloader.batch_size - max_lr_multiplier = math.sqrt(scale) - # effective training steps taken - progress = self.adp.gns.get_progress() - if progress < total_steps: - lr_factor = max_lr_multiplier * (progress / total_steps) - else: - lr_factor = max_lr_multiplier - return lr_factor diff --git a/adaptdl/adaptdl/torch/torch/scaling_rules_test.py b/adaptdl/adaptdl/torch/torch/scaling_rules_test.py deleted file mode 100644 index 072adf666..000000000 --- a/adaptdl/adaptdl/torch/torch/scaling_rules_test.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright 2020 Petuum, Inc. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import numpy as np -import pytest -import torch - -from unittest.mock import Mock, patch - -from adaptdl.torch.gradient_noise_scale import GradientNoiseScale -from adaptdl.torch.scaling_rules import AdaScale, LinearScale,\ - LEGWScale, SqrtScale - - -def test_scaling_rules_1(): - """test AdaScale lr factors""" - adp = Mock(require_backward_grad_sync=True) - opm = Mock(param_groups=[1, 0, 2, -1]) - gns = Mock(raw_var_avg=np.asarray([1, 0, 0, 2]), - raw_sqr_avg=np.asarray([-1, 0, -1, 1])) - adp.gns = gns - adascale = AdaScale() - adascale.initialize(adp, opm) - input_scales = [0.5, 1, 2, 4, 10] - expected_ans = [[0.5, 0.5, 0.5, 0.6], [1., 1., 1., 1.], [2., 2., 2., 1.5], - [4., 4., 4., 2.], [10., 10., 10., 2.5]] - for scale, ans in zip(input_scales, expected_ans): - np.testing.assert_equal(adascale.scale_lr(scale), ans) - - -def test_scaling_rules_2(): - """test LinearScale lr factors""" - adp = Mock(require_backward_grad_sync=True) - opm = Mock(param_groups=[1, 0, 2, -1]) - gns = Mock(optimizer=opm) - adp.gns = gns - linearscale = LinearScale() - linearscale.initialize(adp, opm) - input_scales = [0.5, 1, 2, 4, 10] - expected_ans = [0.5, 1., 2., 4., 10.] - for scale, ans in zip(input_scales, expected_ans): - np.testing.assert_equal(linearscale.scale_lr(scale), ans) - - -def test_scaling_rules_3(): - """test SqrtScale lr factors""" - adp = Mock(require_backward_grad_sync=True) - opm = Mock(param_groups=[1, 0, 2, -1]) - gns = Mock(optimizer=opm) - adp.gns = gns - sqrtscale = SqrtScale() - sqrtscale.initialize(adp, opm) - input_scales = [1, 4, 9, 16, 25] - expected_ans = [1., 2., 3., 4., 5.] 
- for scale, ans in zip(input_scales, expected_ans): - np.testing.assert_equal(sqrtscale.scale_lr(scale), ans) - - -def test_scaling_rules_4(): - """test LEGWScale lr factors""" - with patch("adaptdl.torch.scaling_rules.current_dataloader", - return_value=Mock(batch_size=100)): - adp = Mock(require_backward_grad_sync=True) - opm = Mock(param_groups=[1, 0, 2, -1]) - gns = Mock(optimizer=opm, get_progress=Mock(return_value=5)) - adp.gns = gns - legwscale = LEGWScale(10, 1000) - legwscale.initialize(adp, opm) - input_scales = [1, 4, 9, 16, 25] - expected_ans = [1/20, 1/40, 1/60, 1/80, 1/100] - for scale, ans in zip(input_scales, expected_ans): - np.testing.assert_equal(legwscale.scale_lr(scale), ans) - with patch("adaptdl.torch.scaling_rules.current_dataloader", - return_value=Mock(batch_size=50)): - gns = Mock(optimizer=opm, get_progress=Mock(return_value=400)) - adp.gns = gns - input_scales = [1, 4, 9, 16, 25] - expected_ans = [1., 1., 2/3, 0.5, 0.4] - for scale, ans in zip(input_scales, expected_ans): - np.testing.assert_equal(legwscale.scale_lr(scale), ans) - gns = Mock(optimizer=opm, get_progress=Mock(return_value=400)) - adp.gns = gns - input_scales = [1, 4, 9, 16, 25] - expected_ans = [1., 2., 4/3, 1., 0.8] - for scale, ans in zip(input_scales, expected_ans): - np.testing.assert_equal(legwscale.scale_lr(scale), ans) - - -LR = 0.001 -STEP_SCHEDULE = [1000] -ATOL = 0.01 - - -def test_optimization_1(): - # See torch.test.test_optim - # Also see Rosenbrock/banana function - def rosenbrock(tensor): - x, y = tensor - return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2 - - params_t = torch.Tensor([1.0, 1.5]) - - params = torch.autograd.Variable(params_t, requires_grad=True) - sgd = torch.optim.SGD([params], lr=LR) - schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, STEP_SCHEDULE) - adp = Mock(require_backward_grad_sync=True) - gns = GradientNoiseScale(adp, sgd, accum_scale=1.0, num_replicas=1) - adp.gns = gns - adascale = AdaScale() - adascale.initialize(adp, sgd, patch_optimizer=True) - for i in range(100000): - sgd.zero_grad() - loss = rosenbrock(params) - loss.backward() - sgd.step() - schedule.step() - if params.allclose(torch.tensor([1.0, 1.0]), atol=ATOL): - break - else: - pytest.fail(f"Did not converge: {params}") - - -def test_optimization_2(): - - def rosenbrock_noisy(tensor): - x, y = tensor - return (np.random.normal(1.0, 0.2) * (1 - x) ** 2 + - np.random.normal(1.0, 0.2) * 100 * (y - x ** 2) ** 2) - - params_t = torch.Tensor([1.0, 1.5]) - - params = torch.autograd.Variable(params_t, requires_grad=True) - sgd = torch.optim.SGD([params], lr=LR) - schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, STEP_SCHEDULE) - adp = Mock(require_backward_grad_sync=True) - gns = GradientNoiseScale(adp, sgd, accum_scale=2.0, num_replicas=1) - adp.gns = gns - adascale = AdaScale() - adascale.initialize(adp, sgd, patch_optimizer=True) - for i in range(100000): - sgd.zero_grad() - loss = sum([rosenbrock_noisy(params) for i in range(2)]) / 2.0 - loss.backward() - sgd.step() - schedule.step() - if params.allclose(torch.tensor([1.0, 1.0]), atol=ATOL): - break - else: - pytest.fail(f"Did not converge: {params}") - - -def test_optimization_3(): - def rosenbrock(x, y): - return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2 - - params_t = [ - {"params": [torch.autograd.Variable(torch.Tensor([1.0]), - requires_grad=True)]}, - {"params": [torch.autograd.Variable(torch.Tensor([1.5]), - requires_grad=True)]}] - - sgd = torch.optim.SGD(params_t, lr=LR) - schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, 
STEP_SCHEDULE) - adp = Mock(require_backward_grad_sync=True) - gns = GradientNoiseScale(adp, sgd, accum_scale=1.0, num_replicas=1) - adp.gns = gns - adascale = AdaScale() - adascale.initialize(adp, sgd, patch_optimizer=True) - for i in range(100000): - sgd.zero_grad() - loss = rosenbrock(params_t[0]['params'][0], params_t[1]['params'][0]) - loss.backward() - sgd.step() - schedule.step() - if params_t[0]['params'][0].allclose(torch.tensor([1.0]), atol=ATOL) \ - and params_t[1]['params'][0].allclose(torch.tensor([1.0]), - atol=ATOL): - break - else: - pytest.fail(f"Did not converge: {params_t}") - - -def test_gradient_accumulation_optimization_1(): - - def rosenbrock(tensor): - x, y = tensor - return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2 - - params_t = torch.Tensor([1.0, 1.5]) - - params = torch.autograd.Variable(params_t, requires_grad=True) - sgd = torch.optim.SGD([params], lr=LR) - schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, STEP_SCHEDULE) - adp = Mock(require_backward_grad_sync=False) - gns = GradientNoiseScale(adp, sgd, accum_scale=1.0, num_replicas=1) - adp.gns = gns - adascale = AdaScale() - adascale.initialize(adp, sgd, patch_optimizer=True) - for i in range(100000): - adp.require_backward_grad_sync = i % 2 == 1 - sgd.zero_grad() - loss = rosenbrock(params) - loss.backward() - sgd.step() - if adp.require_backward_grad_sync: - schedule.step() - if params.allclose(torch.tensor([1.0, 1.0]), atol=10 * ATOL): - break - else: - pytest.fail(f"Did not converge: {params}") - - -def test_gradient_accumulation_optimization_2(): - - def rosenbrock_noisy(tensor): - x, y = tensor - return (np.random.normal(1.0, 0.2) * (1 - x) ** 2 + - np.random.normal(1.0, 0.2) * 100 * (y - x ** 2) ** 2) - - params_t = torch.Tensor([1.0, 1.5]) - - params = torch.autograd.Variable(params_t, requires_grad=True) - sgd = torch.optim.SGD([params], lr=LR) - schedule = torch.optim.lr_scheduler.MultiStepLR(sgd, STEP_SCHEDULE) - adp = Mock(require_backward_grad_sync=False) - gns = GradientNoiseScale(adp, sgd, accum_scale=1.0, num_replicas=1) - adp.gns = gns - adascale = AdaScale() - adascale.initialize(adp, sgd, patch_optimizer=True) - for i in range(1000000): - adp.require_backward_grad_sync = i % 2 == 1 - sgd.zero_grad() - loss = rosenbrock_noisy(params) - loss.backward() - sgd.step() - if adp.require_backward_grad_sync: - schedule.step() - if params.allclose(torch.tensor([1.0, 1.0]), atol=ATOL): - break - else: - pytest.fail(f"Did not converge: {params}") From 4d8afa53423c72dd1b95aa3658426831b63a1ccc Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Fri, 8 Apr 2022 13:23:14 -0400 Subject: [PATCH 03/10] handle default value of preemptible flag (#116) --- sched/adaptdl_sched/allocator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sched/adaptdl_sched/allocator.py b/sched/adaptdl_sched/allocator.py index 4c8036939..d762d313d 100644 --- a/sched/adaptdl_sched/allocator.py +++ b/sched/adaptdl_sched/allocator.py @@ -62,7 +62,7 @@ async def _allocate_one_loop(self): # We only consider newly-added preemptible jobs # because this allocation may not be final. 
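+            # Treat an absent "preemptible" field as preemptible by default.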
if (event["type"] == "ADDED" and - event["object"]["spec"]["preemptible"]): + event["object"]["spec"].get("preemptible", True)): async with self._lock: await self._allocate_one(event) From 16bc63ce654098f93698158515d6885f93d52704 Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Tue, 19 Apr 2022 11:07:26 -0400 Subject: [PATCH 04/10] Support apiextensions.k8s.io/v1 and admissionregistration.k8s.io/v1 (#118) * first conversion to v1 * disable CR pruning * add adaptdl SA * add status to schema * comply to https://github.com/ray-project/ray/pull/21852 --- helm/adaptdl-sched/templates/adaptdl-crd.yaml | 85 ++++++++++--------- .../templates/validator-deployment.yaml | 1 + .../templates/validator-webhook.yaml | 6 +- ray/adaptdl_ray/tune/adaptdl_trial_test.py | 17 ---- 4 files changed, 50 insertions(+), 59 deletions(-) diff --git a/helm/adaptdl-sched/templates/adaptdl-crd.yaml b/helm/adaptdl-sched/templates/adaptdl-crd.yaml index 2262df52b..630559314 100644 --- a/helm/adaptdl-sched/templates/adaptdl-crd.yaml +++ b/helm/adaptdl-sched/templates/adaptdl-crd.yaml @@ -1,59 +1,66 @@ -apiVersion: apiextensions.k8s.io/v1beta1 +apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: name: adaptdljobs.adaptdl.petuum.com spec: group: adaptdl.petuum.com - versions: - - name: v1 - served: true - storage: true scope: Namespaced names: plural: adaptdljobs singular: adaptdljob kind: AdaptDLJob - additionalPrinterColumns: + shortNames: + - adljob + - adljobs + versions: + - name: v1 + served: true + storage: true + schema: + openAPIV3Schema: + type: object + required: ["spec"] + properties: + metadata: + type: object + properties: + # Name is used as label values which have a 63 character limit. + name: + type: string + maxLength: 63 + spec: + type: object + required: ["template"] + properties: + maxReplicas: + type: integer + minimum: 1 + minReplicas: + type: integer + minimum: 0 + preemptible: + type: boolean + template: + type: object + x-kubernetes-preserve-unknown-fields: true + status: + type: object + x-kubernetes-preserve-unknown-fields: true + subresources: + status: {} + additionalPrinterColumns: - name: Ready type: integer - JSONPath: .status.readyReplicas + jsonPath: .status.readyReplicas - name: Replicas type: string - JSONPath: .status.replicas + jsonPath: .status.replicas - name: Restarts type: integer - JSONPath: .status.group + jsonPath: .status.group - name: Status type: string - JSONPath: .status.phase + jsonPath: .status.phase - name: Age type: date - JSONPath: .metadata.creationTimestamp - subresources: - status: {} - validation: - openAPIV3Schema: - type: object - properties: - metadata: - type: object - properties: - # Name is used as label values which have a 63 character limit. 
- name: - type: string - maxLength: 63 - spec: - type: object - properties: - maxReplicas: - type: integer - minimum: 1 - minReplicas: - type: integer - minimum: 0 - preemptible: - type: boolean - template: - type: object - required: ["template"] - required: ["spec"] + jsonPath: .metadata.creationTimestamp diff --git a/helm/adaptdl-sched/templates/validator-deployment.yaml b/helm/adaptdl-sched/templates/validator-deployment.yaml index 3a74091c8..8e34dc819 100644 --- a/helm/adaptdl-sched/templates/validator-deployment.yaml +++ b/helm/adaptdl-sched/templates/validator-deployment.yaml @@ -17,6 +17,7 @@ spec: app: adaptdl-validator release: {{ .Release.Name }} spec: + serviceAccountName: adaptdl volumes: - name: tls secret: diff --git a/helm/adaptdl-sched/templates/validator-webhook.yaml b/helm/adaptdl-sched/templates/validator-webhook.yaml index 506c021ad..287f0538c 100644 --- a/helm/adaptdl-sched/templates/validator-webhook.yaml +++ b/helm/adaptdl-sched/templates/validator-webhook.yaml @@ -12,7 +12,7 @@ data: tls.crt: {{ $cert.Cert | b64enc }} tls.key: {{ $cert.Key | b64enc }} --- -apiVersion: admissionregistration.k8s.io/v1beta1 +apiVersion: admissionregistration.k8s.io/v1 kind: ValidatingWebhookConfiguration metadata: name: {{ .Release.Name }}-validator @@ -32,8 +32,8 @@ webhooks: rules: - operations: ["CREATE", "UPDATE"] apiGroups: ["adaptdl.petuum.com"] - apiVersions: ["v1beta1"] + apiVersions: ["v1"] resources: ["adaptdljobs"] admissionReviewVersions: - - v1beta1 + - v1 sideEffects: None diff --git a/ray/adaptdl_ray/tune/adaptdl_trial_test.py b/ray/adaptdl_ray/tune/adaptdl_trial_test.py index 234f6a258..6b3f250c7 100644 --- a/ray/adaptdl_ray/tune/adaptdl_trial_test.py +++ b/ray/adaptdl_ray/tune/adaptdl_trial_test.py @@ -20,9 +20,7 @@ from ray import tune from ray.tune.ray_trial_executor import RayTrialExecutor -from ray.tune.trial import Trial from ray.tune.suggest import BasicVariantGenerator -from .adaptdl_trial import AdaptDLTrial from .adaptdl_trainable import AdaptDLTrainableCreator, _train_simple @@ -37,19 +35,6 @@ def setUp(self): def tearDown(self): ray.shutdown() - def testTrialStatus(self): - ray.init(num_cpus=2) - trainable_cls = AdaptDLTrainableCreator(_train_simple, num_workers=2) - trial = AdaptDLTrial(trainable_cls.__name__, trial_id="0") - trial_executor = RayTrialExecutor() - assert trial.status == Trial.PENDING - trial_executor.start_trial(trial) - assert trial.status == Trial.RUNNING - trial_executor.stop_trial(trial) - assert trial.status == Trial.TERMINATED - trial_executor.stop_trial(trial, error=True) - assert trial.status == Trial.ERROR - def testExperimentTagTruncation(self): ray.init(num_cpus=2) trainable_cls = AdaptDLTrainableCreator(_train_simple, num_workers=1) @@ -72,7 +57,5 @@ def testExperimentTagTruncation(self): if not trial: break trial_executor.start_trial(trial) - assert trial.status == Trial.RUNNING assert len(os.path.basename(trial.logdir)) <= 200 trial_executor.stop_trial(trial) - assert trial.status == Trial.TERMINATED From c0bb45fc144e8f961c1e50f9341495fba357deb2 Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Wed, 20 Apr 2022 01:14:51 -0400 Subject: [PATCH 05/10] fix adaptdl-ray release ver (#119) --- .github/workflows/release_pypi.yaml | 8 ++++---- ray/setup.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release_pypi.yaml b/.github/workflows/release_pypi.yaml index 947867429..cd3bc6301 100644 --- a/.github/workflows/release_pypi.yaml +++ b/.github/workflows/release_pypi.yaml @@ -35,7 +35,7 @@ 
jobs: cd adaptdl HOME=$(pwd) python setup.py sdist bdist_wheel ls -ltr dist/ - python -m twine upload dist/* + python -m twine upload --verbose dist/* - name: Build and push sched package env: TWINE_USERNAME: "__token__" @@ -45,7 +45,7 @@ jobs: cd sched HOME=$(pwd) python setup.py sdist bdist_wheel ls -ltr dist/ - python -m twine upload dist/* + python -m twine upload --verbose dist/* - name: Build and push ray package env: TWINE_USERNAME: "__token__" @@ -55,7 +55,7 @@ jobs: cd ray HOME=$(pwd) python setup.py sdist bdist_wheel ls -ltr dist/ - python -m twine upload dist/* + python -m twine upload --verbose dist/* - name: Build and push cli package env: TWINE_USERNAME: "__token__" @@ -65,4 +65,4 @@ jobs: cd cli HOME=$(pwd) python setup.py sdist bdist_wheel ls -ltr dist/ - python -m twine upload dist/* + python -m twine upload --verbose dist/* diff --git a/ray/setup.py b/ray/setup.py index eba2ba52f..7c7ec9e6a 100644 --- a/ray/setup.py +++ b/ray/setup.py @@ -32,7 +32,7 @@ def read_requirements(file_path): if __name__ == "__main__": setuptools.setup( name="adaptdl-ray", - version=os.getenv("ADAPTDL_RAY_VERSION", "0.0.0"), + version=os.getenv("ADAPTDL_VERSION", "0.0.0"), author="Petuum Inc. & The AdaptDL Authors", author_email="aurick.qiao@petuum.com", description="Dynamic-resource trainer and scheduler for deep learning", From 995f9104c486e44b779348f29ec60b4e9edcd8d7 Mon Sep 17 00:00:00 2001 From: Xuezhi-Liang Date: Mon, 30 May 2022 16:28:46 +0400 Subject: [PATCH 06/10] stage1_1.5 --- tutorial/mnist_step_5.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorial/mnist_step_5.py b/tutorial/mnist_step_5.py index 0b7b27025..b862d2e89 100644 --- a/tutorial/mnist_step_5.py +++ b/tutorial/mnist_step_5.py @@ -118,6 +118,8 @@ def main(): transform=transform) dataset2 = datasets.MNIST('../data', train=False, transform=transform) + adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() + else "gloo") # Changed in step 1 train_loader = adaptdl.torch.AdaptiveDataLoader(dataset1, drop_last=True, **kwargs) # Changed in step 2 test_loader = adaptdl.torch.AdaptiveDataLoader(dataset2, **kwargs) # Changed in step 2 @@ -127,8 +129,6 @@ def main(): optimizer = optim.Adadelta(model.parameters(), lr=args.lr) scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() - else "gloo") # Changed in step 1 model = adaptdl.torch.AdaptiveDataParallel(model, optimizer, scheduler) # Changed in step 1 for epoch in adaptdl.torch.remaining_epochs_until(args.epochs): # Changed in step 4 From 6bd60b7da9a274cd015629bf3d8b0c1c5aee37e7 Mon Sep 17 00:00:00 2001 From: Xuezhi-Liang Date: Mon, 30 May 2022 16:33:40 +0400 Subject: [PATCH 07/10] stage1_1.5 --- adaptdl/adaptdl/torch/__init__.py | 4 + adaptdl/adaptdl/torch/context.py | 135 ++++++++++++------ adaptdl/adaptdl/torch/data.py | 129 +++++++++++------ adaptdl/adaptdl/torch/parallel.py | 6 +- adaptdl/adaptdl/torch/scaling_rules.py | 14 +- ...testcase_for_adaptdldataloader_refactor.py | 21 +-- 6 files changed, 205 insertions(+), 104 deletions(-) diff --git a/adaptdl/adaptdl/torch/__init__.py b/adaptdl/adaptdl/torch/__init__.py index c9832e600..07ece407a 100644 --- a/adaptdl/adaptdl/torch/__init__.py +++ b/adaptdl/adaptdl/torch/__init__.py @@ -29,6 +29,7 @@ import adaptdl.collective import adaptdl.env +import adaptdl.torch.data import semver from .epoch import current_epoch, finished_epochs, remaining_epochs_until from .data import current_dataloader, AdaptiveDataLoader, 
ElasticSampler @@ -119,6 +120,9 @@ def init_process_group(backend, rank, world_size) + # Initialize Context module. + adaptdl.torch.data.context_initialize(batch_size=32) + # Initialize torch.distributed. torch_port = adaptdl.collective.broadcast(portpicker.pick_unused_port()) init_method = "tcp://{}:{}?rank={}&world_size={}".format( diff --git a/adaptdl/adaptdl/torch/context.py b/adaptdl/adaptdl/torch/context.py index a130a592d..6ece877b1 100644 --- a/adaptdl/adaptdl/torch/context.py +++ b/adaptdl/adaptdl/torch/context.py @@ -18,81 +18,136 @@ import adaptdl.collective import adaptdl.env from adaptdl.torch._metrics import get_goodput_fn -import adaptdl.torch.data -from adaptdl.torch.scaling_rules import ScalingRuleBase +import adaptdl.torch.data as data +import numpy as np -class AdaptiveDLContext(object): +class Context(object): """ This class provides context tool to get AdaptDL-suggest parameters, such as batch_size, accum_steps and lr_scale. """ - def __init__(self, batch_size): - self._elastic = adaptdl.torch.data.AdaptiveDataLoaderHelper(batch_size) + def __init__(self, batch_size=32): # Autoscale batch size fields. self._speedup_threshold = 1.05 self.adapt_batch_size = None self.adapt_accum_steps = None self.adapt_lr_scale = None - def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, - gradient_accumulation=False): - self._elastic.autoscale_batch_size(max_batch_size, local_bsz_bounds, - gradient_accumulation) + self._max_batch_size = None + self._local_bsz_bounds = None + # Create and load state. + self._state = data._AdaptiveDataLoaderState() + adaptdl.checkpoint.load_state(self._state) + self.batch_size = batch_size + # self.state_batch_size = 1 + self._gradient_accumulation = False def get_batch_size(self): - _, self.adapt_batch_size, _ = self._sync_local_bsz() + self.adapt_batch_size, _ = self._get_local_bsz() return self.adapt_batch_size def get_accum_steps(self): - _, _, self.adapt_accum_steps = self._sync_local_bsz() + _, self.adapt_accum_steps = self._get_local_bsz() return self.adapt_accum_steps - def get_lr_scale(self): - self.adapt_lr_scale = ScalingRuleBase._get_adapt_lr_scale() - return float(self.adapt_lr_scale) + @staticmethod + def get_lr_scale(scale_lr, gns, optimizer): + scale = gns.accum_scale * gns.accum_count + initial_lr = [pg["lr"] for pg in optimizer.param_groups] + return scale, np.multiply(scale_lr(scale), initial_lr), initial_lr - def _sync_local_bsz(self): + def _get_local_bsz(self): goodput_fn = get_goodput_fn() - if self._elastic.max_batch_size is None or goodput_fn is None: + if self.max_batch_size is None or goodput_fn is None: # No autoscale batch size, just divide batch size evenly. 
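+            # Fallback: every replica takes ceil(batch_size / num_replicas) samples.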
- self._elastic._state.current_local_bsz = math.ceil( - self._elastic.batch_size / adaptdl.env.num_replicas()) - self._elastic._state.accumulation_steps = 0 - elif not self._elastic._state.current_local_bsz: + self._state.current_local_bsz = math.ceil( + self.batch_size / adaptdl.env.num_replicas()) + self._state.accumulation_steps = 0 + elif not self._state.current_local_bsz: # if init, use the batch size suggested _, atomic_bsz, accum_steps = goodput_fn.optimize( adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), - max_batch_size=self._elastic._max_batch_size, - atomic_bsz_range=self._elastic._local_bsz_bounds, - accumulation=self._elastic._gradient_accumulation) - self._elastic._state.current_local_bsz = atomic_bsz - self._elastic._state.accumulation_steps = accum_steps + max_batch_size=self._max_batch_size, + atomic_bsz_range=self._local_bsz_bounds, + accumulation=self._gradient_accumulation) + self._state.current_local_bsz = atomic_bsz + self._state.accumulation_steps = accum_steps else: # if not first time, we check against the relative speedup suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize( adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), - max_batch_size=self._elastic._max_batch_size, - atomic_bsz_range=self._elastic._local_bsz_bounds, - accumulation=self._elastic._gradient_accumulation) + max_batch_size=self._max_batch_size, + atomic_bsz_range=self._local_bsz_bounds, + accumulation=self._gradient_accumulation) # get current goodput current_goodput = goodput_fn( adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), - self._elastic.current_local_bsz, self._elastic.accumulation_steps) + self.current_local_bsz, self.accumulation_steps) # use only if speedup is significant speedup = suggest_goodput / max(current_goodput, 1e-8) if speedup > self._speedup_threshold: - self._elastic._state.current_local_bsz = atomic_bsz - self._elastic._state.accumulation_steps = accum_steps - self._elastic._state.current_local_bsz, self._elastic._state.accumulation_steps = \ - adaptdl.collective.broadcast((self._elastic._state.current_local_bsz, - self._elastic._state.accumulation_steps)) - return self._elastic.current_local_bsz, self._elastic._state.current_local_bsz, self._elastic._state.accumulation_steps + self._state.current_local_bsz = atomic_bsz + self._state.accumulation_steps = accum_steps + return self._state.current_local_bsz, self._state.accumulation_steps + + @property + def max_batch_size(self): + """ + The maximum total batch size allowed for adaptive batch size. ``None`` + if adaptive batch size is disabled. + """ + return self._max_batch_size @property - def training(self): - return self._elastic.training + def local_bsz_bounds(self): + """ + The local batch size bounds on each replica. A pair of integers, + (min_local_bsz, max_local_bsz). + """ + return self._local_bsz_bounds + + @property + def current_local_bsz(self): + """ + The current logical local batch size used by the dataloader. + The batch size returned by the dataloader may be smaller if + gradient accumulation is used + """ + return self._state.current_local_bsz + + @property + def accumulation_steps(self): + """ + The number of batches returned by the dataloader before a + step is taken. + """ + return self._state.accumulation_steps + + def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, + gradient_accumulation=False): + """ + Enables adaptive batch size. Should be invoked once after the data + loader object is created. 
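+        If ``gradient_accumulation`` is enabled, the goodput optimizer may
+        additionally suggest gradient accumulation steps for each replica.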
+ + Arguments: + max_batch_size (int): Maximum total batch size allowed. + local_bsz_bounds (tuple): A pair of (min_local_bsz, max_local_bsz), + the min and max local batch sizes allowed on each replica. + + Raises: + ValueError: If any of the provided batch size bounds are invalid. + """ + if not isinstance(max_batch_size, int) or \ + max_batch_size < self.batch_size: + raise ValueError("invalid max_batch_size") + if local_bsz_bounds is not None and ( + local_bsz_bounds[0] is not None and + local_bsz_bounds[0] > self.batch_size or + local_bsz_bounds[1] is not None and + local_bsz_bounds[1] < self.batch_size): + raise ValueError("invalid local_bsz_bounds") + self._max_batch_size = max_batch_size + self._local_bsz_bounds = local_bsz_bounds + self._gradient_accumulation = gradient_accumulation - def to_tensorboard(self, writer, global_step, tag_prefix=""): - self._elastic.to_tensorboard(writer, global_step, tag_prefix) - # to_tensorboard.__doc__ = adaptdl.torch.data.AdaptiveDataLoaderHelper.to_tensorboard.__doc__ \ No newline at end of file diff --git a/adaptdl/adaptdl/torch/data.py b/adaptdl/adaptdl/torch/data.py index 8da7ce7a0..165571205 100644 --- a/adaptdl/adaptdl/torch/data.py +++ b/adaptdl/adaptdl/torch/data.py @@ -32,7 +32,6 @@ profile_step_start, profile_step_commit, set_batch_size, get_goodput_fn, get_progress) from adaptdl._signal import get_exit_flag -from adaptdl.torch.context import AdaptiveDLContext logging.basicConfig(level=logging.INFO) LOG = logging.getLogger(__name__) @@ -121,6 +120,24 @@ def current_dataloader(): return AdaptiveDataLoaderHelper._current +Context_obj = None +def context_initialize(batch_size): + """ + Initialize this module, must be invoked before calling any other functions. + This function will block until it has been invoked from all replicas. + + Arguments: + batch_size: batch_size of the context. + + Raises: + RuntimeError: If this module had already been initialized. + """ + global Context_obj + if Context_obj is not None: + raise RuntimeError("{} is already initialized".format(__name__)) + Context_obj = adaptdl.torch.context.Context(batch_size) + return Context_obj + class AdaptiveDataLoaderHelper(object): """ This class provides fine-grained control over adaptive training loops. It @@ -140,14 +157,15 @@ class AdaptiveDataLoaderHelper(object): _training = None # The AdaptiveDataLoader which loads training data. _current = None # The AdaptiveDataLoader which is currently iterating. - def __init__(self, batch_size=1): + def __init__(self, batch_size=32): + self._context = Context_obj # Autoscale batch size fields. self._max_batch_size = None self._local_bsz_bounds = None # Create and load state. - self._state = _AdaptiveDataLoaderState() - adaptdl.checkpoint.load_state(self._state) - self.batch_size = batch_size + self._state = self._context._state + # adaptdl.checkpoint.load_state(self._state) + self._context.batch_size = batch_size self.future_exit = None self._gradient_accumulation = False self._speedup_threshold = 1.05 @@ -199,7 +217,7 @@ def local_bsz_bounds(self): The local batch size bounds on each replica. A pair of integers, (min_local_bsz, max_local_bsz). 
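+
+    A hypothetical subclass would follow the same pattern as
+    :class:`AdaptiveDataLoader`::
+
+        class MyDataLoader(DataLoader, AdaptiveDataLoaderMixin):
+            def __init__(self, dataset, batch_size=32, **kwargs):
+                super().__init__(dataset, batch_size, **kwargs)
+                AdaptiveDataLoaderMixin.__init__(self, batch_size)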
""" - return self._local_bsz_bounds + return self._context._local_bsz_bounds @property def current_local_bsz(self): @@ -208,7 +226,7 @@ def current_local_bsz(self): The batch size returned by the dataloader may be smaller if gradient accumulation is used """ - return self._state.current_local_bsz + return self._context.get_batch_size() @property def accumulation_steps(self): @@ -216,7 +234,7 @@ def accumulation_steps(self): The number of batches returned by the dataloader before a step is taken. """ - return self._state.accumulation_steps + return self._context.get_accum_steps() def is_accum_step(self): """ @@ -237,36 +255,17 @@ def train(self): """ if AdaptiveDataLoaderHelper._training is None: AdaptiveDataLoaderHelper._training = self - set_batch_size(self.batch_size, self.max_batch_size, + set_batch_size(self._context.batch_size, self.max_batch_size, self.local_bsz_bounds, self._gradient_accumulation) - def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, - gradient_accumulation=False): - """ - Enables adaptive batch size. Should be invoked once after the data - loader object is created. - Arguments: - max_batch_size (int): Maximum total batch size allowed. - local_bsz_bounds (tuple): A pair of (min_local_bsz, max_local_bsz), - the min and max local batch sizes allowed on each replica. - - Raises: - ValueError: If any of the provided batch size bounds are invalid. - """ - if not isinstance(max_batch_size, int) or \ - max_batch_size < self.batch_size: - raise ValueError("invalid max_batch_size") - if local_bsz_bounds is not None and ( - local_bsz_bounds[0] is not None and - local_bsz_bounds[0] > self.batch_size or - local_bsz_bounds[1] is not None and - local_bsz_bounds[1] < self.batch_size): - raise ValueError("invalid local_bsz_bounds") - self._max_batch_size = max_batch_size - self._local_bsz_bounds = local_bsz_bounds - self._gradient_accumulation = gradient_accumulation - self.train() + def _sync_local_bsz(self): + self._state.current_local_bsz, self._state.accumulation_steps = \ + self._context._get_local_bsz() + self._state.current_local_bsz, self._state.accumulation_steps = \ + adaptdl.collective.broadcast((self._state.current_local_bsz, + self._state.accumulation_steps)) + return self.current_local_bsz, self._state.current_local_bsz, self._state.accumulation_steps @property def training(self): @@ -319,7 +318,7 @@ def context(self): @property def current_batch_size(self): - return (self.current_local_bsz * (self.accumulation_steps + 1) * + return (self._context.get_batch_size() * (self._context.get_accum_steps() + 1) * adaptdl.env.num_replicas()) def skipdone(self): @@ -362,6 +361,54 @@ def to_tensorboard(self, writer, global_step, tag_prefix=""): self.accumulation_steps, global_step) +class AdaptiveDataLoaderMixin(object): + """ + This class provides elastic functionality to any custom DataLoader which + inherits it. It defines a member _elastic of type + :class:`AdaptiveDataLoaderHelper` which has useful methods and members to + implement restart-safe, elastic DataLoaders. It also exposes public methods + which can be used inside training loops directly from + :class:`AdaptiveDataLoader`. 
+ """ + + def __init__(self, batch_size): + self._elastic = AdaptiveDataLoaderHelper(batch_size) + + def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, + gradient_accumulation=False): + self._elastic._context.autoscale_batch_size(max_batch_size, local_bsz_bounds, + gradient_accumulation) + self._elastic.train() + + @property + def current_local_bsz(self): + # if AdaptiveDataLoaderHelper._current is not self._elastic: + # return None + return self._elastic._context.current_local_bsz + + @property + def accumulation_steps(self): + """ + The number of batches returned by the dataloader before a + step is taken. + """ + return self._elastic._context.accumulation_steps + + @property + def training(self): + return self._elastic.training + + @property + def current_batch_size(self): + if AdaptiveDataLoaderHelper._current is not self._elastic: + return None + return self._elastic.current_batch_size + + def to_tensorboard(self, writer, global_step, tag_prefix=""): + self._elastic.to_tensorboard(writer, global_step, tag_prefix) + to_tensorboard.__doc__ = AdaptiveDataLoaderHelper.to_tensorboard.__doc__ + + def _worker_init_wrapper(worker_init_fn, num_workers): # Set globally-unique python and numpy seeds for each worker. @@ -379,7 +426,7 @@ def wrapper(worker_id): return wrapper -class AdaptiveDataLoader(DataLoader, AdaptiveDLContext): +class AdaptiveDataLoader(DataLoader, AdaptiveDataLoaderMixin): """ This class is a PyTorch DataLoader that also supports adaptive batch sizes and checkpoint-restart elasticity. Applications can typically use objects @@ -418,7 +465,7 @@ def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): kwargs["worker_init_fn"] = _worker_init_wrapper( kwargs.get("worker_init_fn"), kwargs.get("num_workers")) super().__init__(dataset, batch_size, shuffle=False, **kwargs) - AdaptiveDLContext.__init__(self, batch_size) + AdaptiveDataLoaderMixin.__init__(self, batch_size) def __iter__(self): """ @@ -443,19 +490,19 @@ def __iter__(self): while not done: self.sampler.set_epoch( epoch, index=self._elastic.current_index) - self.batch_sampler.batch_size = self.get_batch_size() + self.batch_sampler.batch_size = self._elastic._context.get_batch_size() for idx, batch in enumerate(super().__iter__()): with self._elastic.profile(self.training and idx >= 1): yield batch # Increment by the number of data samples processed self._elastic.current_index += \ num_replicas * self.batch_sampler.batch_size - if self._elastic.max_batch_size is not None and \ + if self._elastic._context.max_batch_size is not None and \ get_progress() >= len(self.dataset) * \ (epoch + 1) / self.batch_size: done = True break - if self._elastic.max_batch_size is None: + if self._elastic._context.max_batch_size is None: done = True self._elastic.current_index -= \ self._elastic.current_index % -len(self.dataset) diff --git a/adaptdl/adaptdl/torch/parallel.py b/adaptdl/adaptdl/torch/parallel.py index 218ae2981..99bff445f 100644 --- a/adaptdl/adaptdl/torch/parallel.py +++ b/adaptdl/adaptdl/torch/parallel.py @@ -96,7 +96,7 @@ def forward(self, *args, **kwargs): if dataloader is not None and dataloader.training: self.require_backward_grad_sync = dataloader.is_optim_step() accum_scale = (dataloader.current_local_bsz * - adaptdl.env.num_replicas() / dataloader.batch_size) + adaptdl.env.num_replicas() / dataloader._context.batch_size) self.gns.set_accum_scale(accum_scale) return super().forward(*args, **kwargs) @@ -152,13 +152,13 @@ def _final_callback(self): raise RuntimeError("backpropagation outside 
AdaptiveDataLoader") dataloader.train() - scale = dataloader.current_batch_size / dataloader.batch_size + scale = dataloader.current_batch_size / dataloader._context.batch_size self._state.gain = self.gns.gain(scale) self._state.lr_factor = \ np.average(self.scaling_rule.scale_lr(scale)) update_progress(self.gns.get_progress()) if dataloader.max_batch_size and \ - dataloader.max_batch_size > dataloader.batch_size: + dataloader.max_batch_size > dataloader._context.batch_size: update_grad_params(self._key, self.gns.sqr_avg(), self.gns.var_avg()) self._sync_start = None diff --git a/adaptdl/adaptdl/torch/scaling_rules.py b/adaptdl/adaptdl/torch/scaling_rules.py index ac0f2a1c0..1eeef1941 100644 --- a/adaptdl/adaptdl/torch/scaling_rules.py +++ b/adaptdl/adaptdl/torch/scaling_rules.py @@ -19,8 +19,9 @@ from types import MethodType -# from adaptdl.torch.data import current_dataloader - +from adaptdl.torch.data import current_dataloader +from adaptdl.torch.context import Context +from adaptdl.torch import data __all__ = ["ScalingRuleBase", "AdaScale", "LinearScale", "SqrtScale", "LEGWScale"] @@ -77,10 +78,7 @@ def step(self, *args, **kwargs): raise ValueError("AdaptiveDataParallel instance is not set!") if not self.adp.require_backward_grad_sync: return - scale = self.adp.gns.accum_scale * self.adp.gns.accum_count - initial_lr = [pg["lr"] for pg in self._optimizer.param_groups] - scaled_lr = np.multiply(self.scale_lr(scale), initial_lr) - ScalingRuleBase._adaptlr = scaled_lr + scale, scaled_lr, initial_lr = data.Context_obj.get_lr_scale(self.scale_lr, self.adp.gns, self._optimizer) for lr, pg in zip(scaled_lr, self._optimizer.param_groups): pg["lr"] = lr self._orig_optimizer_step(*args, **kwargs) @@ -111,10 +109,6 @@ def initialize(self, adp, optimizer, patch_optimizer=False): if patch_optimizer: self._patch_optimizer() - @staticmethod - def _get_adapt_lr_scale(): - return ScalingRuleBase._adaptlr - class AdaScale(ScalingRuleBase): """ diff --git a/tutorial/testcase_for_adaptdldataloader_refactor.py b/tutorial/testcase_for_adaptdldataloader_refactor.py index 038386404..d0356beb1 100644 --- a/tutorial/testcase_for_adaptdldataloader_refactor.py +++ b/tutorial/testcase_for_adaptdldataloader_refactor.py @@ -9,9 +9,8 @@ import adaptdl # Changed in step 1 import adaptdl.torch # Changed in step 1 -from adaptdl.torch.data import AdaptiveDLContext # For test AdaptiveDLContext only, users do not need to call this +from adaptdl.torch import data -from adaptdl.torch.scaling_rules import ScalingRuleBase class Net(nn.Module): def __init__(self): @@ -39,7 +38,7 @@ def forward(self, x): return output -def train(args, model, device, train_loader, optimizer, epoch, adacontext): # For test AdaptiveDLContext only, users do not need to call this +def train(args, model, device, train_loader, optimizer, epoch, adacontext): # For test Context only, users do not need to call this model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) @@ -57,7 +56,8 @@ def train(args, model, device, train_loader, optimizer, epoch, adacontext): # Fo epoch, batch_idx * len(data), len(train_loader.dataset), 100. 
* batch_idx / len(train_loader), loss.item(), len(data), optimizer.param_groups[0]['lr'], - adacontext.get_batch_size(), adacontext.get_accum_steps(), adacontext.get_lr_scale(),# For test AdaptiveDLContext only, users do not need to call this + adacontext.get_batch_size(), adacontext.get_accum_steps(), + adacontext.get_lr_scale(model.scaling_rule.scale_lr, model.gns,optimizer)[1],# For test Context only, users do not need to call this )) if args.dry_run: break @@ -85,6 +85,7 @@ def tst(model, device, test_loader): def main(): # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)') @@ -128,23 +129,23 @@ def main(): transform=transform) dataset2 = datasets.MNIST('../data', train=False, transform=transform) - adacontext = AdaptiveDLContext(args.batch_size) # For test AdaptiveDLContext only, users do not need to call this + + adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() + else "gloo") # Changed in step 1 + adacontext = data.Context_obj # For test Context only, users do not need to call this train_loader = adaptdl.torch.AdaptiveDataLoader(dataset1, drop_last=True, **kwargs) # Changed in step 2 test_loader = adaptdl.torch.AdaptiveDataLoader(dataset2, **kwargs) # Changed in step 2 - train_loader.autoscale_batch_size(1028, local_bsz_bounds=(32, 128)) # Changed in step 3, optional + train_loader.autoscale_batch_size(1028, local_bsz_bounds=(64, 128)) # Changed in step 3, optional model = Net().to(device) optimizer = optim.Adadelta(model.parameters(), lr=args.lr) scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() - else "gloo") # Changed in step 1 model = adaptdl.torch.AdaptiveDataParallel(model, optimizer, scheduler) # Changed in step 1 - # adacontext = AdaptiveDLContext(args.batch_size) # For test AdaptiveDLContext only, users do not need to call this for epoch in adaptdl.torch.remaining_epochs_until(args.epochs): # Changed in step 4 - train(args, model, device, train_loader, optimizer, epoch, adacontext) # For test AdaptiveDLContext only, users do not need to call this + train(args, model, device, train_loader, optimizer, epoch, adacontext) # For test Context only, users do not need to call this tst(model, device, test_loader) scheduler.step() From 3d74699021735af8a171e2d77c17dd5e7a6803f9 Mon Sep 17 00:00:00 2001 From: Xuezhi-Liang Date: Mon, 30 May 2022 16:55:00 +0400 Subject: [PATCH 08/10] stage1_1.5 --- adaptdl/adaptdl/torch/__init__.py | 4 + adaptdl/adaptdl/torch/context.py | 153 +++++++++++++++++ adaptdl/adaptdl/torch/data.py | 116 +++++-------- adaptdl/adaptdl/torch/parallel.py | 6 +- adaptdl/adaptdl/torch/scaling_rules.py | 10 +- tutorial/mnist_step_5.py | 4 +- ...testcase_for_adaptdldataloader_refactor.py | 157 ++++++++++++++++++ 7 files changed, 365 insertions(+), 85 deletions(-) create mode 100644 adaptdl/adaptdl/torch/context.py create mode 100644 tutorial/testcase_for_adaptdldataloader_refactor.py diff --git a/adaptdl/adaptdl/torch/__init__.py b/adaptdl/adaptdl/torch/__init__.py index c9832e600..07ece407a 100644 --- a/adaptdl/adaptdl/torch/__init__.py +++ b/adaptdl/adaptdl/torch/__init__.py @@ -29,6 +29,7 @@ import adaptdl.collective import adaptdl.env +import adaptdl.torch.data import semver from .epoch import current_epoch, finished_epochs, remaining_epochs_until from .data import current_dataloader, AdaptiveDataLoader, 
ElasticSampler @@ -119,6 +120,9 @@ def init_process_group(backend, rank, world_size) + # Initialize Context module. + adaptdl.torch.data.context_initialize(batch_size=32) + # Initialize torch.distributed. torch_port = adaptdl.collective.broadcast(portpicker.pick_unused_port()) init_method = "tcp://{}:{}?rank={}&world_size={}".format( diff --git a/adaptdl/adaptdl/torch/context.py b/adaptdl/adaptdl/torch/context.py new file mode 100644 index 000000000..6ece877b1 --- /dev/null +++ b/adaptdl/adaptdl/torch/context.py @@ -0,0 +1,153 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import math +import adaptdl.checkpoint +import adaptdl.collective +import adaptdl.env +from adaptdl.torch._metrics import get_goodput_fn +import adaptdl.torch.data as data +import numpy as np + +class Context(object): + """ + This class provides context tool to get AdaptDL-suggest parameters, + such as batch_size, accum_steps and lr_scale. + """ + + def __init__(self, batch_size=32): + # Autoscale batch size fields. + self._speedup_threshold = 1.05 + self.adapt_batch_size = None + self.adapt_accum_steps = None + self.adapt_lr_scale = None + + self._max_batch_size = None + self._local_bsz_bounds = None + # Create and load state. + self._state = data._AdaptiveDataLoaderState() + adaptdl.checkpoint.load_state(self._state) + self.batch_size = batch_size + # self.state_batch_size = 1 + self._gradient_accumulation = False + + def get_batch_size(self): + self.adapt_batch_size, _ = self._get_local_bsz() + return self.adapt_batch_size + + def get_accum_steps(self): + _, self.adapt_accum_steps = self._get_local_bsz() + return self.adapt_accum_steps + + @staticmethod + def get_lr_scale(scale_lr, gns, optimizer): + scale = gns.accum_scale * gns.accum_count + initial_lr = [pg["lr"] for pg in optimizer.param_groups] + return scale, np.multiply(scale_lr(scale), initial_lr), initial_lr + + def _get_local_bsz(self): + goodput_fn = get_goodput_fn() + if self.max_batch_size is None or goodput_fn is None: + # No autoscale batch size, just divide batch size evenly. 
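+            # Example: batch_size=64 across 4 replicas gives ceil(64 / 4) = 16
+            # samples per replica and leaves accumulation_steps at 0.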
+ self._state.current_local_bsz = math.ceil( + self.batch_size / adaptdl.env.num_replicas()) + self._state.accumulation_steps = 0 + elif not self._state.current_local_bsz: + # if init, use the batch size suggested + _, atomic_bsz, accum_steps = goodput_fn.optimize( + adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), + max_batch_size=self._max_batch_size, + atomic_bsz_range=self._local_bsz_bounds, + accumulation=self._gradient_accumulation) + self._state.current_local_bsz = atomic_bsz + self._state.accumulation_steps = accum_steps + else: + # if not first time, we check against the relative speedup + suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize( + adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), + max_batch_size=self._max_batch_size, + atomic_bsz_range=self._local_bsz_bounds, + accumulation=self._gradient_accumulation) + # get current goodput + current_goodput = goodput_fn( + adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), + self.current_local_bsz, self.accumulation_steps) + # use only if speedup is significant + speedup = suggest_goodput / max(current_goodput, 1e-8) + if speedup > self._speedup_threshold: + self._state.current_local_bsz = atomic_bsz + self._state.accumulation_steps = accum_steps + return self._state.current_local_bsz, self._state.accumulation_steps + + @property + def max_batch_size(self): + """ + The maximum total batch size allowed for adaptive batch size. ``None`` + if adaptive batch size is disabled. + """ + return self._max_batch_size + + @property + def local_bsz_bounds(self): + """ + The local batch size bounds on each replica. A pair of integers, + (min_local_bsz, max_local_bsz). + """ + return self._local_bsz_bounds + + @property + def current_local_bsz(self): + """ + The current logical local batch size used by the dataloader. + The batch size returned by the dataloader may be smaller if + gradient accumulation is used + """ + return self._state.current_local_bsz + + @property + def accumulation_steps(self): + """ + The number of batches returned by the dataloader before a + step is taken. + """ + return self._state.accumulation_steps + + def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, + gradient_accumulation=False): + """ + Enables adaptive batch size. Should be invoked once after the data + loader object is created. + + Arguments: + max_batch_size (int): Maximum total batch size allowed. + local_bsz_bounds (tuple): A pair of (min_local_bsz, max_local_bsz), + the min and max local batch sizes allowed on each replica. + + Raises: + ValueError: If any of the provided batch size bounds are invalid. + """ + if not isinstance(max_batch_size, int) or \ + max_batch_size < self.batch_size: + raise ValueError("invalid max_batch_size") + if local_bsz_bounds is not None and ( + local_bsz_bounds[0] is not None and + local_bsz_bounds[0] > self.batch_size or + local_bsz_bounds[1] is not None and + local_bsz_bounds[1] < self.batch_size): + raise ValueError("invalid local_bsz_bounds") + self._max_batch_size = max_batch_size + self._local_bsz_bounds = local_bsz_bounds + self._gradient_accumulation = gradient_accumulation + diff --git a/adaptdl/adaptdl/torch/data.py b/adaptdl/adaptdl/torch/data.py index 90a8767ac..165571205 100644 --- a/adaptdl/adaptdl/torch/data.py +++ b/adaptdl/adaptdl/torch/data.py @@ -120,6 +120,24 @@ def current_dataloader(): return AdaptiveDataLoaderHelper._current +Context_obj = None +def context_initialize(batch_size): + """ + Initialize this module, must be invoked before calling any other functions. 
+ This function will block until it has been invoked from all replicas. + + Arguments: + batch_size: batch_size of the context. + + Raises: + RuntimeError: If this module had already been initialized. + """ + global Context_obj + if Context_obj is not None: + raise RuntimeError("{} is already initialized".format(__name__)) + Context_obj = adaptdl.torch.context.Context(batch_size) + return Context_obj + class AdaptiveDataLoaderHelper(object): """ This class provides fine-grained control over adaptive training loops. It @@ -139,14 +157,15 @@ class AdaptiveDataLoaderHelper(object): _training = None # The AdaptiveDataLoader which loads training data. _current = None # The AdaptiveDataLoader which is currently iterating. - def __init__(self, batch_size=1): + def __init__(self, batch_size=32): + self._context = Context_obj # Autoscale batch size fields. self._max_batch_size = None self._local_bsz_bounds = None # Create and load state. - self._state = _AdaptiveDataLoaderState() - adaptdl.checkpoint.load_state(self._state) - self.batch_size = batch_size + self._state = self._context._state + # adaptdl.checkpoint.load_state(self._state) + self._context.batch_size = batch_size self.future_exit = None self._gradient_accumulation = False self._speedup_threshold = 1.05 @@ -198,7 +217,7 @@ def local_bsz_bounds(self): The local batch size bounds on each replica. A pair of integers, (min_local_bsz, max_local_bsz). """ - return self._local_bsz_bounds + return self._context._local_bsz_bounds @property def current_local_bsz(self): @@ -207,7 +226,7 @@ def current_local_bsz(self): The batch size returned by the dataloader may be smaller if gradient accumulation is used """ - return self._state.current_local_bsz + return self._context.get_batch_size() @property def accumulation_steps(self): @@ -215,7 +234,7 @@ def accumulation_steps(self): The number of batches returned by the dataloader before a step is taken. """ - return self._state.accumulation_steps + return self._context.get_accum_steps() def is_accum_step(self): """ @@ -236,73 +255,17 @@ def train(self): """ if AdaptiveDataLoaderHelper._training is None: AdaptiveDataLoaderHelper._training = self - set_batch_size(self.batch_size, self.max_batch_size, + set_batch_size(self._context.batch_size, self.max_batch_size, self.local_bsz_bounds, self._gradient_accumulation) - def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, - gradient_accumulation=False): - """ - Enables adaptive batch size. Should be invoked once after the data - loader object is created. - - Arguments: - max_batch_size (int): Maximum total batch size allowed. - local_bsz_bounds (tuple): A pair of (min_local_bsz, max_local_bsz), - the min and max local batch sizes allowed on each replica. - - Raises: - ValueError: If any of the provided batch size bounds are invalid. - """ - if not isinstance(max_batch_size, int) or \ - max_batch_size < self.batch_size: - raise ValueError("invalid max_batch_size") - if local_bsz_bounds is not None and ( - local_bsz_bounds[0] is not None and - local_bsz_bounds[0] > self.batch_size or - local_bsz_bounds[1] is not None and - local_bsz_bounds[1] < self.batch_size): - raise ValueError("invalid local_bsz_bounds") - self._max_batch_size = max_batch_size - self._local_bsz_bounds = local_bsz_bounds - self._gradient_accumulation = gradient_accumulation - self.train() def _sync_local_bsz(self): - goodput_fn = get_goodput_fn() - if self.max_batch_size is None or goodput_fn is None: - # No autoscale batch size, just divide batch size evenly. 
- self._state.current_local_bsz = math.ceil( - self.batch_size / adaptdl.env.num_replicas()) - self._state.accumulation_steps = 0 - elif not self._state.current_local_bsz: - # if init, use the batch size suggested - _, atomic_bsz, accum_steps = goodput_fn.optimize( - adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), - max_batch_size=self._max_batch_size, - atomic_bsz_range=self._local_bsz_bounds, - accumulation=self._gradient_accumulation) - self._state.current_local_bsz = atomic_bsz - self._state.accumulation_steps = accum_steps - else: - # if not first time, we check against the relative speedup - suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize( - adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), - max_batch_size=self._max_batch_size, - atomic_bsz_range=self._local_bsz_bounds, - accumulation=self._gradient_accumulation) - # get current goodput - current_goodput = goodput_fn( - adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), - self.current_local_bsz, self.accumulation_steps) - # use only if speedup is significant - speedup = suggest_goodput / max(current_goodput, 1e-8) - if speedup > self._speedup_threshold: - self._state.current_local_bsz = atomic_bsz - self._state.accumulation_steps = accum_steps + self._state.current_local_bsz, self._state.accumulation_steps = \ + self._context._get_local_bsz() self._state.current_local_bsz, self._state.accumulation_steps = \ adaptdl.collective.broadcast((self._state.current_local_bsz, self._state.accumulation_steps)) - return self.current_local_bsz + return self.current_local_bsz, self._state.current_local_bsz, self._state.accumulation_steps @property def training(self): @@ -355,7 +318,7 @@ def context(self): @property def current_batch_size(self): - return (self.current_local_bsz * (self.accumulation_steps + 1) * + return (self._context.get_batch_size() * (self._context.get_accum_steps() + 1) * adaptdl.env.num_replicas()) def skipdone(self): @@ -413,14 +376,15 @@ def __init__(self, batch_size): def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, gradient_accumulation=False): - self._elastic.autoscale_batch_size(max_batch_size, local_bsz_bounds, + self._elastic._context.autoscale_batch_size(max_batch_size, local_bsz_bounds, gradient_accumulation) + self._elastic.train() @property def current_local_bsz(self): - if AdaptiveDataLoaderHelper._current is not self._elastic: - return None - return self._elastic.current_local_bsz + # if AdaptiveDataLoaderHelper._current is not self._elastic: + # return None + return self._elastic._context.current_local_bsz @property def accumulation_steps(self): @@ -428,7 +392,7 @@ def accumulation_steps(self): The number of batches returned by the dataloader before a step is taken. 
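+        After the refactor this value is read from the shared ``Context``.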
""" - return self._elastic.accumulation_steps + return self._elastic._context.accumulation_steps @property def training(self): @@ -526,19 +490,19 @@ def __iter__(self): while not done: self.sampler.set_epoch( epoch, index=self._elastic.current_index) - self.batch_sampler.batch_size = self._elastic._sync_local_bsz() + self.batch_sampler.batch_size = self._elastic._context.get_batch_size() for idx, batch in enumerate(super().__iter__()): with self._elastic.profile(self.training and idx >= 1): yield batch # Increment by the number of data samples processed self._elastic.current_index += \ num_replicas * self.batch_sampler.batch_size - if self._elastic.max_batch_size is not None and \ + if self._elastic._context.max_batch_size is not None and \ get_progress() >= len(self.dataset) * \ (epoch + 1) / self.batch_size: done = True break - if self._elastic.max_batch_size is None: + if self._elastic._context.max_batch_size is None: done = True self._elastic.current_index -= \ self._elastic.current_index % -len(self.dataset) diff --git a/adaptdl/adaptdl/torch/parallel.py b/adaptdl/adaptdl/torch/parallel.py index 218ae2981..99bff445f 100644 --- a/adaptdl/adaptdl/torch/parallel.py +++ b/adaptdl/adaptdl/torch/parallel.py @@ -96,7 +96,7 @@ def forward(self, *args, **kwargs): if dataloader is not None and dataloader.training: self.require_backward_grad_sync = dataloader.is_optim_step() accum_scale = (dataloader.current_local_bsz * - adaptdl.env.num_replicas() / dataloader.batch_size) + adaptdl.env.num_replicas() / dataloader._context.batch_size) self.gns.set_accum_scale(accum_scale) return super().forward(*args, **kwargs) @@ -152,13 +152,13 @@ def _final_callback(self): raise RuntimeError("backpropagation outside AdaptiveDataLoader") dataloader.train() - scale = dataloader.current_batch_size / dataloader.batch_size + scale = dataloader.current_batch_size / dataloader._context.batch_size self._state.gain = self.gns.gain(scale) self._state.lr_factor = \ np.average(self.scaling_rule.scale_lr(scale)) update_progress(self.gns.get_progress()) if dataloader.max_batch_size and \ - dataloader.max_batch_size > dataloader.batch_size: + dataloader.max_batch_size > dataloader._context.batch_size: update_grad_params(self._key, self.gns.sqr_avg(), self.gns.var_avg()) self._sync_start = None diff --git a/adaptdl/adaptdl/torch/scaling_rules.py b/adaptdl/adaptdl/torch/scaling_rules.py index a1300232f..1eeef1941 100644 --- a/adaptdl/adaptdl/torch/scaling_rules.py +++ b/adaptdl/adaptdl/torch/scaling_rules.py @@ -20,7 +20,8 @@ from types import MethodType from adaptdl.torch.data import current_dataloader - +from adaptdl.torch.context import Context +from adaptdl.torch import data __all__ = ["ScalingRuleBase", "AdaScale", "LinearScale", "SqrtScale", "LEGWScale"] @@ -45,6 +46,9 @@ class ScalingRuleBase(object): loss.backward() adascale.step() """ + + _adaptlr = None + def __init__(self): # instance of AdaptiveDataParallel, needs to be set before any of the # methods can be used @@ -74,9 +78,7 @@ def step(self, *args, **kwargs): raise ValueError("AdaptiveDataParallel instance is not set!") if not self.adp.require_backward_grad_sync: return - scale = self.adp.gns.accum_scale * self.adp.gns.accum_count - initial_lr = [pg["lr"] for pg in self._optimizer.param_groups] - scaled_lr = np.multiply(self.scale_lr(scale), initial_lr) + scale, scaled_lr, initial_lr = data.Context_obj.get_lr_scale(self.scale_lr, self.adp.gns, self._optimizer) for lr, pg in zip(scaled_lr, self._optimizer.param_groups): pg["lr"] = lr 
self._orig_optimizer_step(*args, **kwargs) diff --git a/tutorial/mnist_step_5.py b/tutorial/mnist_step_5.py index 0b7b27025..b862d2e89 100644 --- a/tutorial/mnist_step_5.py +++ b/tutorial/mnist_step_5.py @@ -118,6 +118,8 @@ def main(): transform=transform) dataset2 = datasets.MNIST('../data', train=False, transform=transform) + adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() + else "gloo") # Changed in step 1 train_loader = adaptdl.torch.AdaptiveDataLoader(dataset1, drop_last=True, **kwargs) # Changed in step 2 test_loader = adaptdl.torch.AdaptiveDataLoader(dataset2, **kwargs) # Changed in step 2 @@ -127,8 +129,6 @@ def main(): optimizer = optim.Adadelta(model.parameters(), lr=args.lr) scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() - else "gloo") # Changed in step 1 model = adaptdl.torch.AdaptiveDataParallel(model, optimizer, scheduler) # Changed in step 1 for epoch in adaptdl.torch.remaining_epochs_until(args.epochs): # Changed in step 4 diff --git a/tutorial/testcase_for_adaptdldataloader_refactor.py b/tutorial/testcase_for_adaptdldataloader_refactor.py new file mode 100644 index 000000000..d0356beb1 --- /dev/null +++ b/tutorial/testcase_for_adaptdldataloader_refactor.py @@ -0,0 +1,157 @@ +from __future__ import print_function +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR + +import adaptdl # Changed in step 1 +import adaptdl.torch # Changed in step 1 +from adaptdl.torch import data + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout2d(0.25) + self.dropout2 = nn.Dropout2d(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def train(args, model, device, train_loader, optimizer, epoch, adacontext): # For test Context only, users do not need to call this + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}' + '\treal_batch_size:{}\treal_lr:{}' + '\t ada_batch_size:{}\tada_accum:{}\tada_lr_scale:{}' + .format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. 
* batch_idx / len(train_loader), loss.item(), + len(data), optimizer.param_groups[0]['lr'], + adacontext.get_batch_size(), adacontext.get_accum_steps(), + adacontext.get_lr_scale(model.scaling_rule.scale_lr, model.gns,optimizer)[1],# For test Context only, users do not need to call this + )) + if args.dry_run: + break + + +def tst(model, device, test_loader): + model.eval() + stats = adaptdl.torch.Accumulator() # Changed in step 5 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + stats["test_loss"] += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss # Changed in step 5 + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + stats["correct"] += pred.eq(target.view_as(pred)).sum().item() # Changed in step 5 + + with stats.synchronized(): # Changed in step 5 + test_loss = stats["test_loss"] / len(test_loader.dataset) # Changed in step 5 + correct = stats["correct"] # Changed in step 5 + + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. * correct / len(test_loader.dataset))) # Changed in step 5 + + +def main(): + # Training settings + + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=14, metavar='N', + help='number of epochs to train (default: 14)') + parser.add_argument('--lr', type=float, default=1.0, metavar='LR', + help='learning rate (default: 1.0)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--dry-run', action='store_true', default=False, + help='quickly check a single pass') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument('--save-model', action='store_true', default=False, + help='For Saving the current Model') + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + + torch.manual_seed(args.seed) + + device = torch.device("cuda" if use_cuda else "cpu") + + kwargs = {'batch_size': args.batch_size} + if use_cuda: + kwargs.update({'num_workers': 1, + 'pin_memory': True, + 'shuffle': True}, + ) + + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + dataset1 = datasets.MNIST('../data', train=True, download=True, + transform=transform) + dataset2 = datasets.MNIST('../data', train=False, + transform=transform) + + adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() + else "gloo") # Changed in step 1 + adacontext = data.Context_obj # For test Context only, users do not need to call this + train_loader = adaptdl.torch.AdaptiveDataLoader(dataset1, drop_last=True, **kwargs) # Changed in step 2 + test_loader = adaptdl.torch.AdaptiveDataLoader(dataset2, **kwargs) # Changed in step 2 + + train_loader.autoscale_batch_size(1028, 
local_bsz_bounds=(64, 128)) # Changed in step 3, optional + + model = Net().to(device) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + model = adaptdl.torch.AdaptiveDataParallel(model, optimizer, scheduler) # Changed in step 1 + + for epoch in adaptdl.torch.remaining_epochs_until(args.epochs): # Changed in step 4 + train(args, model, device, train_loader, optimizer, epoch, adacontext) # For test Context only, users do not need to call this + tst(model, device, test_loader) + scheduler.step() + + if args.save_model: + torch.save(model.state_dict(), "mnist_cnn.pt") + + +if __name__ == '__main__': + main() From 8d7e96dc95a0226fe36b8c95cd9161b1cb94a776 Mon Sep 17 00:00:00 2001 From: Xuezhi-Liang Date: Mon, 30 May 2022 17:01:43 +0400 Subject: [PATCH 09/10] test --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index fe768883e..6494d64dd 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -# Set ADAPTDL_DEV_REPO to use an external docker registry. +### Set ADAPTDL_DEV_REPO to use an external docker registry. # Set ADAPTDL_DEV_REPO_CREDS to the name of registry secret. RELEASE_NAME = adaptdl LOCAL_PORT = 59283 From 64fb57b80a8861da97d6cfd6b740d785cf67cb52 Mon Sep 17 00:00:00 2001 From: Xuezhi-Liang Date: Mon, 30 May 2022 17:03:21 +0400 Subject: [PATCH 10/10] test --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 6494d64dd..fe768883e 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -### Set ADAPTDL_DEV_REPO to use an external docker registry. +# Set ADAPTDL_DEV_REPO to use an external docker registry. # Set ADAPTDL_DEV_REPO_CREDS to the name of registry secret. RELEASE_NAME = adaptdl LOCAL_PORT = 59283