diff --git a/adaptdl/adaptdl/torch/__init__.py b/adaptdl/adaptdl/torch/__init__.py index c9832e600..50f430acd 100644 --- a/adaptdl/adaptdl/torch/__init__.py +++ b/adaptdl/adaptdl/torch/__init__.py @@ -29,6 +29,7 @@ import adaptdl.collective import adaptdl.env +import adaptdl.torch.data import semver from .epoch import current_epoch, finished_epochs, remaining_epochs_until from .data import current_dataloader, AdaptiveDataLoader, ElasticSampler @@ -119,6 +120,9 @@ def init_process_group(backend, rank, world_size) + # Initialize Context module. + adaptdl.torch.data.context_initialize() + # Initialize torch.distributed. torch_port = adaptdl.collective.broadcast(portpicker.pick_unused_port()) init_method = "tcp://{}:{}?rank={}&world_size={}".format( diff --git a/adaptdl/adaptdl/torch/context.py b/adaptdl/adaptdl/torch/context.py new file mode 100644 index 000000000..6ece877b1 --- /dev/null +++ b/adaptdl/adaptdl/torch/context.py @@ -0,0 +1,153 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import math +import adaptdl.checkpoint +import adaptdl.collective +import adaptdl.env +from adaptdl.torch._metrics import get_goodput_fn +import adaptdl.torch.data as data +import numpy as np + +class Context(object): + """ + This class provides context tool to get AdaptDL-suggest parameters, + such as batch_size, accum_steps and lr_scale. + """ + + def __init__(self, batch_size=32): + # Autoscale batch size fields. + self._speedup_threshold = 1.05 + self.adapt_batch_size = None + self.adapt_accum_steps = None + self.adapt_lr_scale = None + + self._max_batch_size = None + self._local_bsz_bounds = None + # Create and load state. + self._state = data._AdaptiveDataLoaderState() + adaptdl.checkpoint.load_state(self._state) + self.batch_size = batch_size + # self.state_batch_size = 1 + self._gradient_accumulation = False + + def get_batch_size(self): + self.adapt_batch_size, _ = self._get_local_bsz() + return self.adapt_batch_size + + def get_accum_steps(self): + _, self.adapt_accum_steps = self._get_local_bsz() + return self.adapt_accum_steps + + @staticmethod + def get_lr_scale(scale_lr, gns, optimizer): + scale = gns.accum_scale * gns.accum_count + initial_lr = [pg["lr"] for pg in optimizer.param_groups] + return scale, np.multiply(scale_lr(scale), initial_lr), initial_lr + + def _get_local_bsz(self): + goodput_fn = get_goodput_fn() + if self.max_batch_size is None or goodput_fn is None: + # No autoscale batch size, just divide batch size evenly. + self._state.current_local_bsz = math.ceil( + self.batch_size / adaptdl.env.num_replicas()) + self._state.accumulation_steps = 0 + elif not self._state.current_local_bsz: + # if init, use the batch size suggested + _, atomic_bsz, accum_steps = goodput_fn.optimize( + adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), + max_batch_size=self._max_batch_size, + atomic_bsz_range=self._local_bsz_bounds, + accumulation=self._gradient_accumulation) + self._state.current_local_bsz = atomic_bsz + self._state.accumulation_steps = accum_steps + else: + # if not first time, we check against the relative speedup + suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize( + adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), + max_batch_size=self._max_batch_size, + atomic_bsz_range=self._local_bsz_bounds, + accumulation=self._gradient_accumulation) + # get current goodput + current_goodput = goodput_fn( + adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), + self.current_local_bsz, self.accumulation_steps) + # use only if speedup is significant + speedup = suggest_goodput / max(current_goodput, 1e-8) + if speedup > self._speedup_threshold: + self._state.current_local_bsz = atomic_bsz + self._state.accumulation_steps = accum_steps + return self._state.current_local_bsz, self._state.accumulation_steps + + @property + def max_batch_size(self): + """ + The maximum total batch size allowed for adaptive batch size. ``None`` + if adaptive batch size is disabled. + """ + return self._max_batch_size + + @property + def local_bsz_bounds(self): + """ + The local batch size bounds on each replica. A pair of integers, + (min_local_bsz, max_local_bsz). + """ + return self._local_bsz_bounds + + @property + def current_local_bsz(self): + """ + The current logical local batch size used by the dataloader. + The batch size returned by the dataloader may be smaller if + gradient accumulation is used + """ + return self._state.current_local_bsz + + @property + def accumulation_steps(self): + """ + The number of batches returned by the dataloader before a + step is taken. + """ + return self._state.accumulation_steps + + def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, + gradient_accumulation=False): + """ + Enables adaptive batch size. Should be invoked once after the data + loader object is created. + + Arguments: + max_batch_size (int): Maximum total batch size allowed. + local_bsz_bounds (tuple): A pair of (min_local_bsz, max_local_bsz), + the min and max local batch sizes allowed on each replica. + + Raises: + ValueError: If any of the provided batch size bounds are invalid. + """ + if not isinstance(max_batch_size, int) or \ + max_batch_size < self.batch_size: + raise ValueError("invalid max_batch_size") + if local_bsz_bounds is not None and ( + local_bsz_bounds[0] is not None and + local_bsz_bounds[0] > self.batch_size or + local_bsz_bounds[1] is not None and + local_bsz_bounds[1] < self.batch_size): + raise ValueError("invalid local_bsz_bounds") + self._max_batch_size = max_batch_size + self._local_bsz_bounds = local_bsz_bounds + self._gradient_accumulation = gradient_accumulation + diff --git a/adaptdl/adaptdl/torch/data.py b/adaptdl/adaptdl/torch/data.py index 90a8767ac..8427ab776 100644 --- a/adaptdl/adaptdl/torch/data.py +++ b/adaptdl/adaptdl/torch/data.py @@ -120,6 +120,24 @@ def current_dataloader(): return AdaptiveDataLoaderHelper._current +Context_obj = None +def context_initialize(): + """ + Initialize this module, must be invoked before calling any other functions. + This function will block until it has been invoked from all replicas. + + Arguments: + batch_size: batch_size of the context. + + Raises: + RuntimeError: If this module had already been initialized. + """ + global Context_obj + if Context_obj is not None: + raise RuntimeError("{} is already initialized".format(__name__)) + Context_obj = adaptdl.torch.context.Context() + return Context_obj + class AdaptiveDataLoaderHelper(object): """ This class provides fine-grained control over adaptive training loops. It @@ -139,14 +157,15 @@ class AdaptiveDataLoaderHelper(object): _training = None # The AdaptiveDataLoader which loads training data. _current = None # The AdaptiveDataLoader which is currently iterating. - def __init__(self, batch_size=1): + def __init__(self, batch_size=32): + self._context = Context_obj # Autoscale batch size fields. self._max_batch_size = None self._local_bsz_bounds = None # Create and load state. - self._state = _AdaptiveDataLoaderState() - adaptdl.checkpoint.load_state(self._state) - self.batch_size = batch_size + self._state = self._context._state + # adaptdl.checkpoint.load_state(self._state) + self._context.batch_size = batch_size self.future_exit = None self._gradient_accumulation = False self._speedup_threshold = 1.05 @@ -198,7 +217,7 @@ def local_bsz_bounds(self): The local batch size bounds on each replica. A pair of integers, (min_local_bsz, max_local_bsz). """ - return self._local_bsz_bounds + return self._context._local_bsz_bounds @property def current_local_bsz(self): @@ -207,7 +226,7 @@ def current_local_bsz(self): The batch size returned by the dataloader may be smaller if gradient accumulation is used """ - return self._state.current_local_bsz + return self._context.get_batch_size() @property def accumulation_steps(self): @@ -215,7 +234,7 @@ def accumulation_steps(self): The number of batches returned by the dataloader before a step is taken. """ - return self._state.accumulation_steps + return self._context.get_accum_steps() def is_accum_step(self): """ @@ -236,73 +255,17 @@ def train(self): """ if AdaptiveDataLoaderHelper._training is None: AdaptiveDataLoaderHelper._training = self - set_batch_size(self.batch_size, self.max_batch_size, + set_batch_size(self._context.batch_size, self.max_batch_size, self.local_bsz_bounds, self._gradient_accumulation) - def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, - gradient_accumulation=False): - """ - Enables adaptive batch size. Should be invoked once after the data - loader object is created. - - Arguments: - max_batch_size (int): Maximum total batch size allowed. - local_bsz_bounds (tuple): A pair of (min_local_bsz, max_local_bsz), - the min and max local batch sizes allowed on each replica. - - Raises: - ValueError: If any of the provided batch size bounds are invalid. - """ - if not isinstance(max_batch_size, int) or \ - max_batch_size < self.batch_size: - raise ValueError("invalid max_batch_size") - if local_bsz_bounds is not None and ( - local_bsz_bounds[0] is not None and - local_bsz_bounds[0] > self.batch_size or - local_bsz_bounds[1] is not None and - local_bsz_bounds[1] < self.batch_size): - raise ValueError("invalid local_bsz_bounds") - self._max_batch_size = max_batch_size - self._local_bsz_bounds = local_bsz_bounds - self._gradient_accumulation = gradient_accumulation - self.train() def _sync_local_bsz(self): - goodput_fn = get_goodput_fn() - if self.max_batch_size is None or goodput_fn is None: - # No autoscale batch size, just divide batch size evenly. - self._state.current_local_bsz = math.ceil( - self.batch_size / adaptdl.env.num_replicas()) - self._state.accumulation_steps = 0 - elif not self._state.current_local_bsz: - # if init, use the batch size suggested - _, atomic_bsz, accum_steps = goodput_fn.optimize( - adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), - max_batch_size=self._max_batch_size, - atomic_bsz_range=self._local_bsz_bounds, - accumulation=self._gradient_accumulation) - self._state.current_local_bsz = atomic_bsz - self._state.accumulation_steps = accum_steps - else: - # if not first time, we check against the relative speedup - suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize( - adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), - max_batch_size=self._max_batch_size, - atomic_bsz_range=self._local_bsz_bounds, - accumulation=self._gradient_accumulation) - # get current goodput - current_goodput = goodput_fn( - adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), - self.current_local_bsz, self.accumulation_steps) - # use only if speedup is significant - speedup = suggest_goodput / max(current_goodput, 1e-8) - if speedup > self._speedup_threshold: - self._state.current_local_bsz = atomic_bsz - self._state.accumulation_steps = accum_steps + self._state.current_local_bsz, self._state.accumulation_steps = \ + self._context._get_local_bsz() self._state.current_local_bsz, self._state.accumulation_steps = \ adaptdl.collective.broadcast((self._state.current_local_bsz, self._state.accumulation_steps)) - return self.current_local_bsz + return self.current_local_bsz, self._state.current_local_bsz, self._state.accumulation_steps @property def training(self): @@ -355,8 +318,8 @@ def context(self): @property def current_batch_size(self): - return (self.current_local_bsz * (self.accumulation_steps + 1) * - adaptdl.env.num_replicas()) + return (self._context.current_local_bsz * (self.accumulation_steps + 1) * + adaptdl.env.num_replicas()) def skipdone(self): """ @@ -413,14 +376,15 @@ def __init__(self, batch_size): def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, gradient_accumulation=False): - self._elastic.autoscale_batch_size(max_batch_size, local_bsz_bounds, + self._elastic._context.autoscale_batch_size(max_batch_size, local_bsz_bounds, gradient_accumulation) + self._elastic.train() @property def current_local_bsz(self): - if AdaptiveDataLoaderHelper._current is not self._elastic: - return None - return self._elastic.current_local_bsz + # if AdaptiveDataLoaderHelper._current is not self._elastic: + # return None + return self._elastic._context.current_local_bsz @property def accumulation_steps(self): @@ -428,7 +392,7 @@ def accumulation_steps(self): The number of batches returned by the dataloader before a step is taken. """ - return self._elastic.accumulation_steps + return self._elastic._context.accumulation_steps @property def training(self): @@ -526,19 +490,19 @@ def __iter__(self): while not done: self.sampler.set_epoch( epoch, index=self._elastic.current_index) - self.batch_sampler.batch_size = self._elastic._sync_local_bsz() + self.batch_sampler.batch_size, _, _ = self._elastic._sync_local_bsz() for idx, batch in enumerate(super().__iter__()): with self._elastic.profile(self.training and idx >= 1): yield batch # Increment by the number of data samples processed self._elastic.current_index += \ num_replicas * self.batch_sampler.batch_size - if self._elastic.max_batch_size is not None and \ + if self._elastic._context.max_batch_size is not None and \ get_progress() >= len(self.dataset) * \ (epoch + 1) / self.batch_size: done = True break - if self._elastic.max_batch_size is None: + if self._elastic._context.max_batch_size is None: done = True self._elastic.current_index -= \ self._elastic.current_index % -len(self.dataset) diff --git a/adaptdl/adaptdl/torch/parallel.py b/adaptdl/adaptdl/torch/parallel.py index 218ae2981..99bff445f 100644 --- a/adaptdl/adaptdl/torch/parallel.py +++ b/adaptdl/adaptdl/torch/parallel.py @@ -96,7 +96,7 @@ def forward(self, *args, **kwargs): if dataloader is not None and dataloader.training: self.require_backward_grad_sync = dataloader.is_optim_step() accum_scale = (dataloader.current_local_bsz * - adaptdl.env.num_replicas() / dataloader.batch_size) + adaptdl.env.num_replicas() / dataloader._context.batch_size) self.gns.set_accum_scale(accum_scale) return super().forward(*args, **kwargs) @@ -152,13 +152,13 @@ def _final_callback(self): raise RuntimeError("backpropagation outside AdaptiveDataLoader") dataloader.train() - scale = dataloader.current_batch_size / dataloader.batch_size + scale = dataloader.current_batch_size / dataloader._context.batch_size self._state.gain = self.gns.gain(scale) self._state.lr_factor = \ np.average(self.scaling_rule.scale_lr(scale)) update_progress(self.gns.get_progress()) if dataloader.max_batch_size and \ - dataloader.max_batch_size > dataloader.batch_size: + dataloader.max_batch_size > dataloader._context.batch_size: update_grad_params(self._key, self.gns.sqr_avg(), self.gns.var_avg()) self._sync_start = None diff --git a/adaptdl/adaptdl/torch/scaling_rules.py b/adaptdl/adaptdl/torch/scaling_rules.py index a1300232f..1eeef1941 100644 --- a/adaptdl/adaptdl/torch/scaling_rules.py +++ b/adaptdl/adaptdl/torch/scaling_rules.py @@ -20,7 +20,8 @@ from types import MethodType from adaptdl.torch.data import current_dataloader - +from adaptdl.torch.context import Context +from adaptdl.torch import data __all__ = ["ScalingRuleBase", "AdaScale", "LinearScale", "SqrtScale", "LEGWScale"] @@ -45,6 +46,9 @@ class ScalingRuleBase(object): loss.backward() adascale.step() """ + + _adaptlr = None + def __init__(self): # instance of AdaptiveDataParallel, needs to be set before any of the # methods can be used @@ -74,9 +78,7 @@ def step(self, *args, **kwargs): raise ValueError("AdaptiveDataParallel instance is not set!") if not self.adp.require_backward_grad_sync: return - scale = self.adp.gns.accum_scale * self.adp.gns.accum_count - initial_lr = [pg["lr"] for pg in self._optimizer.param_groups] - scaled_lr = np.multiply(self.scale_lr(scale), initial_lr) + scale, scaled_lr, initial_lr = data.Context_obj.get_lr_scale(self.scale_lr, self.adp.gns, self._optimizer) for lr, pg in zip(scaled_lr, self._optimizer.param_groups): pg["lr"] = lr self._orig_optimizer_step(*args, **kwargs) diff --git a/tutorial/mnist_step_5.py b/tutorial/mnist_step_5.py index 0b7b27025..b862d2e89 100644 --- a/tutorial/mnist_step_5.py +++ b/tutorial/mnist_step_5.py @@ -118,6 +118,8 @@ def main(): transform=transform) dataset2 = datasets.MNIST('../data', train=False, transform=transform) + adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() + else "gloo") # Changed in step 1 train_loader = adaptdl.torch.AdaptiveDataLoader(dataset1, drop_last=True, **kwargs) # Changed in step 2 test_loader = adaptdl.torch.AdaptiveDataLoader(dataset2, **kwargs) # Changed in step 2 @@ -127,8 +129,6 @@ def main(): optimizer = optim.Adadelta(model.parameters(), lr=args.lr) scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() - else "gloo") # Changed in step 1 model = adaptdl.torch.AdaptiveDataParallel(model, optimizer, scheduler) # Changed in step 1 for epoch in adaptdl.torch.remaining_epochs_until(args.epochs): # Changed in step 4 diff --git a/tutorial/testcase_for_adaptdldataloader_refactor.py b/tutorial/testcase_for_adaptdldataloader_refactor.py new file mode 100644 index 000000000..d0356beb1 --- /dev/null +++ b/tutorial/testcase_for_adaptdldataloader_refactor.py @@ -0,0 +1,157 @@ +from __future__ import print_function +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR + +import adaptdl # Changed in step 1 +import adaptdl.torch # Changed in step 1 +from adaptdl.torch import data + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout2d(0.25) + self.dropout2 = nn.Dropout2d(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def train(args, model, device, train_loader, optimizer, epoch, adacontext): # For test Context only, users do not need to call this + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}' + '\treal_batch_size:{}\treal_lr:{}' + '\t ada_batch_size:{}\tada_accum:{}\tada_lr_scale:{}' + .format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.item(), + len(data), optimizer.param_groups[0]['lr'], + adacontext.get_batch_size(), adacontext.get_accum_steps(), + adacontext.get_lr_scale(model.scaling_rule.scale_lr, model.gns,optimizer)[1],# For test Context only, users do not need to call this + )) + if args.dry_run: + break + + +def tst(model, device, test_loader): + model.eval() + stats = adaptdl.torch.Accumulator() # Changed in step 5 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + stats["test_loss"] += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss # Changed in step 5 + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + stats["correct"] += pred.eq(target.view_as(pred)).sum().item() # Changed in step 5 + + with stats.synchronized(): # Changed in step 5 + test_loss = stats["test_loss"] / len(test_loader.dataset) # Changed in step 5 + correct = stats["correct"] # Changed in step 5 + + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. * correct / len(test_loader.dataset))) # Changed in step 5 + + +def main(): + # Training settings + + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=14, metavar='N', + help='number of epochs to train (default: 14)') + parser.add_argument('--lr', type=float, default=1.0, metavar='LR', + help='learning rate (default: 1.0)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--dry-run', action='store_true', default=False, + help='quickly check a single pass') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument('--save-model', action='store_true', default=False, + help='For Saving the current Model') + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + + torch.manual_seed(args.seed) + + device = torch.device("cuda" if use_cuda else "cpu") + + kwargs = {'batch_size': args.batch_size} + if use_cuda: + kwargs.update({'num_workers': 1, + 'pin_memory': True, + 'shuffle': True}, + ) + + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + dataset1 = datasets.MNIST('../data', train=True, download=True, + transform=transform) + dataset2 = datasets.MNIST('../data', train=False, + transform=transform) + + adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() + else "gloo") # Changed in step 1 + adacontext = data.Context_obj # For test Context only, users do not need to call this + train_loader = adaptdl.torch.AdaptiveDataLoader(dataset1, drop_last=True, **kwargs) # Changed in step 2 + test_loader = adaptdl.torch.AdaptiveDataLoader(dataset2, **kwargs) # Changed in step 2 + + train_loader.autoscale_batch_size(1028, local_bsz_bounds=(64, 128)) # Changed in step 3, optional + + model = Net().to(device) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + model = adaptdl.torch.AdaptiveDataParallel(model, optimizer, scheduler) # Changed in step 1 + + for epoch in adaptdl.torch.remaining_epochs_until(args.epochs): # Changed in step 4 + train(args, model, device, train_loader, optimizer, epoch, adacontext) # For test Context only, users do not need to call this + tst(model, device, test_loader) + scheduler.step() + + if args.save_model: + torch.save(model.state_dict(), "mnist_cnn.pt") + + +if __name__ == '__main__': + main()