4 changes: 4 additions & 0 deletions adaptdl/adaptdl/torch/__init__.py
@@ -29,6 +29,7 @@

import adaptdl.collective
import adaptdl.env
import adaptdl.torch.data
import semver
from .epoch import current_epoch, finished_epochs, remaining_epochs_until
from .data import current_dataloader, AdaptiveDataLoader, ElasticSampler
@@ -119,6 +120,9 @@ def init_process_group(backend,
rank,
world_size)

# Initialize Context module.
adaptdl.torch.data.context_initialize()

# Initialize torch.distributed.
torch_port = adaptdl.collective.broadcast(portpicker.pick_unused_port())
init_method = "tcp://{}:{}?rank={}&world_size={}".format(
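Not part of the diff above, just for orientation: after this change, a single call to init_process_group() also creates the global context. A minimal sketch, assuming an AdaptDL environment (replica count, master address, etc.) is already configured:

import adaptdl.torch

# init_process_group() now also calls adaptdl.torch.data.context_initialize(),
# so the shared Context_obj exists before any data loader is constructed.
# The backend string is passed through to torch.distributed (e.g. "gloo" or "nccl").
adaptdl.torch.init_process_group("gloo")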
153 changes: 153 additions & 0 deletions adaptdl/adaptdl/torch/context.py
@@ -0,0 +1,153 @@
# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import math
import adaptdl.checkpoint
import adaptdl.collective
import adaptdl.env
from adaptdl.torch._metrics import get_goodput_fn
import adaptdl.torch.data as data
import numpy as np

class Context(object):
"""
    This class provides a context for retrieving AdaptDL-suggested parameters,
    such as batch_size, accum_steps, and lr_scale.
"""

def __init__(self, batch_size=32):
# Autoscale batch size fields.
self._speedup_threshold = 1.05
self.adapt_batch_size = None
self.adapt_accum_steps = None
self.adapt_lr_scale = None

self._max_batch_size = None
self._local_bsz_bounds = None
# Create and load state.
self._state = data._AdaptiveDataLoaderState()
adaptdl.checkpoint.load_state(self._state)
self.batch_size = batch_size
# self.state_batch_size = 1
self._gradient_accumulation = False

def get_batch_size(self):
self.adapt_batch_size, _ = self._get_local_bsz()
return self.adapt_batch_size

def get_accum_steps(self):
_, self.adapt_accum_steps = self._get_local_bsz()
return self.adapt_accum_steps

@staticmethod
def get_lr_scale(scale_lr, gns, optimizer):
scale = gns.accum_scale * gns.accum_count
initial_lr = [pg["lr"] for pg in optimizer.param_groups]
return scale, np.multiply(scale_lr(scale), initial_lr), initial_lr

def _get_local_bsz(self):
goodput_fn = get_goodput_fn()
if self.max_batch_size is None or goodput_fn is None:
# No autoscale batch size, just divide batch size evenly.
self._state.current_local_bsz = math.ceil(
self.batch_size / adaptdl.env.num_replicas())
self._state.accumulation_steps = 0
elif not self._state.current_local_bsz:
# if init, use the batch size suggested
_, atomic_bsz, accum_steps = goodput_fn.optimize(
adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
max_batch_size=self._max_batch_size,
atomic_bsz_range=self._local_bsz_bounds,
accumulation=self._gradient_accumulation)
self._state.current_local_bsz = atomic_bsz
self._state.accumulation_steps = accum_steps
else:
# if not first time, we check against the relative speedup
suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize(
adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
max_batch_size=self._max_batch_size,
atomic_bsz_range=self._local_bsz_bounds,
accumulation=self._gradient_accumulation)
# get current goodput
current_goodput = goodput_fn(
adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
self.current_local_bsz, self.accumulation_steps)
# use only if speedup is significant
speedup = suggest_goodput / max(current_goodput, 1e-8)
if speedup > self._speedup_threshold:
self._state.current_local_bsz = atomic_bsz
self._state.accumulation_steps = accum_steps
return self._state.current_local_bsz, self._state.accumulation_steps

@property
def max_batch_size(self):
"""
The maximum total batch size allowed for adaptive batch size. ``None``
if adaptive batch size is disabled.
"""
return self._max_batch_size

@property
def local_bsz_bounds(self):
"""
The local batch size bounds on each replica. A pair of integers,
(min_local_bsz, max_local_bsz).
"""
return self._local_bsz_bounds

@property
def current_local_bsz(self):
"""
The current logical local batch size used by the dataloader.
The batch size returned by the dataloader may be smaller if
        gradient accumulation is used.
"""
return self._state.current_local_bsz

@property
def accumulation_steps(self):
"""
The number of batches returned by the dataloader before a
step is taken.
"""
return self._state.accumulation_steps

def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None,
gradient_accumulation=False):
"""
Enables adaptive batch size. Should be invoked once after the data
loader object is created.

Arguments:
max_batch_size (int): Maximum total batch size allowed.
local_bsz_bounds (tuple): A pair of (min_local_bsz, max_local_bsz),
the min and max local batch sizes allowed on each replica.

Raises:
ValueError: If any of the provided batch size bounds are invalid.
"""
if not isinstance(max_batch_size, int) or \
max_batch_size < self.batch_size:
raise ValueError("invalid max_batch_size")
if local_bsz_bounds is not None and (
local_bsz_bounds[0] is not None and
local_bsz_bounds[0] > self.batch_size or
local_bsz_bounds[1] is not None and
local_bsz_bounds[1] < self.batch_size):
raise ValueError("invalid local_bsz_bounds")
self._max_batch_size = max_batch_size
self._local_bsz_bounds = local_bsz_bounds
self._gradient_accumulation = gradient_accumulation

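Before the next file, a hedged usage sketch of the new Context API (illustrative only, not part of the PR; in practice the object is created by context_initialize() inside init_process_group, and the values below fall back to an even split of the batch size until a goodput model has been fitted):

from adaptdl.torch.context import Context

ctx = Context(batch_size=128)
# Allow the total batch size to grow up to 4096, keep each replica's atomic
# batch size between 32 and 512, and permit gradient accumulation.
ctx.autoscale_batch_size(4096, local_bsz_bounds=(32, 512),
                         gradient_accumulation=True)

local_bsz = ctx.get_batch_size()   # suggested per-replica (atomic) batch size
accum = ctx.get_accum_steps()      # suggested gradient accumulation steps
# With no goodput function fitted yet, local_bsz == ceil(128 / num_replicas)
# and accum == 0, per the first branch of _get_local_bsz().

Note the guard in _get_local_bsz: once a configuration is in place, a newly suggested one is only adopted when its predicted goodput beats the current goodput by more than 5% (speedup > self._speedup_threshold == 1.05); for example, 108 vs. 100 samples/s would switch, while 103 vs. 100 would not.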
118 changes: 41 additions & 77 deletions adaptdl/adaptdl/torch/data.py
@@ -120,6 +120,24 @@ def current_dataloader():
return AdaptiveDataLoaderHelper._current


Context_obj = None
def context_initialize():
"""
    Initialize this module; it must be invoked before calling any other functions.
This function will block until it has been invoked from all replicas.

[Inline review comment]
Collaborator: How's this enforced?
Author reply: Dear Omkar, thank you for asking. Since we made the Context global, it is
first initialized in init_process_group as Context_obj, following Aurick's suggestion.
All subsequent uses of the Context go through Context_obj, so initialization is enforced
at the very beginning.


    Raises:
        RuntimeError: If this module has already been initialized.
"""
global Context_obj
if Context_obj is not None:
raise RuntimeError("{} is already initialized".format(__name__))
Context_obj = adaptdl.torch.context.Context()
return Context_obj

class AdaptiveDataLoaderHelper(object):
"""
This class provides fine-grained control over adaptive training loops. It
@@ -139,14 +157,15 @@ class AdaptiveDataLoaderHelper(object):
_training = None # The AdaptiveDataLoader which loads training data.
_current = None # The AdaptiveDataLoader which is currently iterating.

def __init__(self, batch_size=1):
def __init__(self, batch_size=32):
self._context = Context_obj
# Autoscale batch size fields.
self._max_batch_size = None
self._local_bsz_bounds = None
# Create and load state.
self._state = _AdaptiveDataLoaderState()
adaptdl.checkpoint.load_state(self._state)
self.batch_size = batch_size
self._state = self._context._state
# adaptdl.checkpoint.load_state(self._state)
self._context.batch_size = batch_size
self.future_exit = None
self._gradient_accumulation = False
self._speedup_threshold = 1.05
Expand Down Expand Up @@ -198,7 +217,7 @@ def local_bsz_bounds(self):
The local batch size bounds on each replica. A pair of integers,
(min_local_bsz, max_local_bsz).
"""
return self._local_bsz_bounds
return self._context._local_bsz_bounds

@property
def current_local_bsz(self):
@@ -207,15 +226,15 @@ def current_local_bsz(self):
The batch size returned by the dataloader may be smaller if
gradient accumulation is used
"""
return self._state.current_local_bsz
return self._context.get_batch_size()

@property
def accumulation_steps(self):
"""
The number of batches returned by the dataloader before a
step is taken.
"""
return self._state.accumulation_steps
return self._context.get_accum_steps()

def is_accum_step(self):
"""
@@ -236,73 +255,17 @@ def train(self):
"""
if AdaptiveDataLoaderHelper._training is None:
AdaptiveDataLoaderHelper._training = self
set_batch_size(self.batch_size, self.max_batch_size,
set_batch_size(self._context.batch_size, self.max_batch_size,
self.local_bsz_bounds, self._gradient_accumulation)

def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None,
gradient_accumulation=False):
"""
Enables adaptive batch size. Should be invoked once after the data
loader object is created.

Arguments:
max_batch_size (int): Maximum total batch size allowed.
local_bsz_bounds (tuple): A pair of (min_local_bsz, max_local_bsz),
the min and max local batch sizes allowed on each replica.

Raises:
ValueError: If any of the provided batch size bounds are invalid.
"""
if not isinstance(max_batch_size, int) or \
max_batch_size < self.batch_size:
raise ValueError("invalid max_batch_size")
if local_bsz_bounds is not None and (
local_bsz_bounds[0] is not None and
local_bsz_bounds[0] > self.batch_size or
local_bsz_bounds[1] is not None and
local_bsz_bounds[1] < self.batch_size):
raise ValueError("invalid local_bsz_bounds")
self._max_batch_size = max_batch_size
self._local_bsz_bounds = local_bsz_bounds
self._gradient_accumulation = gradient_accumulation
self.train()

def _sync_local_bsz(self):
goodput_fn = get_goodput_fn()
if self.max_batch_size is None or goodput_fn is None:
# No autoscale batch size, just divide batch size evenly.
self._state.current_local_bsz = math.ceil(
self.batch_size / adaptdl.env.num_replicas())
self._state.accumulation_steps = 0
elif not self._state.current_local_bsz:
# if init, use the batch size suggested
_, atomic_bsz, accum_steps = goodput_fn.optimize(
adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
max_batch_size=self._max_batch_size,
atomic_bsz_range=self._local_bsz_bounds,
accumulation=self._gradient_accumulation)
self._state.current_local_bsz = atomic_bsz
self._state.accumulation_steps = accum_steps
else:
# if not first time, we check against the relative speedup
suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize(
adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
max_batch_size=self._max_batch_size,
atomic_bsz_range=self._local_bsz_bounds,
accumulation=self._gradient_accumulation)
# get current goodput
current_goodput = goodput_fn(
adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
self.current_local_bsz, self.accumulation_steps)
# use only if speedup is significant
speedup = suggest_goodput / max(current_goodput, 1e-8)
if speedup > self._speedup_threshold:
self._state.current_local_bsz = atomic_bsz
self._state.accumulation_steps = accum_steps
self._state.current_local_bsz, self._state.accumulation_steps = \
self._context._get_local_bsz()
self._state.current_local_bsz, self._state.accumulation_steps = \
adaptdl.collective.broadcast((self._state.current_local_bsz,
self._state.accumulation_steps))
return self.current_local_bsz
return self.current_local_bsz, self._state.current_local_bsz, self._state.accumulation_steps

@property
def training(self):
@@ -355,8 +318,8 @@ def context(self):

@property
def current_batch_size(self):
return (self.current_local_bsz * (self.accumulation_steps + 1) *
adaptdl.env.num_replicas())
return (self._context.current_local_bsz * (self.accumulation_steps + 1) *
adaptdl.env.num_replicas())

def skipdone(self):
"""
@@ -413,22 +376,23 @@ def __init__(self, batch_size):

def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None,
gradient_accumulation=False):
self._elastic.autoscale_batch_size(max_batch_size, local_bsz_bounds,
self._elastic._context.autoscale_batch_size(max_batch_size, local_bsz_bounds,
gradient_accumulation)
self._elastic.train()

@property
def current_local_bsz(self):
if AdaptiveDataLoaderHelper._current is not self._elastic:
return None
return self._elastic.current_local_bsz
# if AdaptiveDataLoaderHelper._current is not self._elastic:
# return None
return self._elastic._context.current_local_bsz

@property
def accumulation_steps(self):
"""
The number of batches returned by the dataloader before a
step is taken.
"""
return self._elastic.accumulation_steps
return self._elastic._context.accumulation_steps

@property
def training(self):
@@ -526,19 +490,19 @@ def __iter__(self):
while not done:
self.sampler.set_epoch(
epoch, index=self._elastic.current_index)
self.batch_sampler.batch_size = self._elastic._sync_local_bsz()
self.batch_sampler.batch_size, _, _ = self._elastic._sync_local_bsz()
for idx, batch in enumerate(super().__iter__()):
with self._elastic.profile(self.training and idx >= 1):
yield batch
# Increment by the number of data samples processed
self._elastic.current_index += \
num_replicas * self.batch_sampler.batch_size
if self._elastic.max_batch_size is not None and \
if self._elastic._context.max_batch_size is not None and \
get_progress() >= len(self.dataset) * \
(epoch + 1) / self.batch_size:
done = True
break
if self._elastic.max_batch_size is None:
if self._elastic._context.max_batch_size is None:
done = True
self._elastic.current_index -= \
self._elastic.current_index % -len(self.dataset)
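To close the loop, a rough end-to-end sketch of how a training script exercises these paths after the refactor (hedged: the dataset is a placeholder, AdaptiveDataLoader and remaining_epochs_until are used as already exported from adaptdl.torch, and this code is not part of the PR):

import torch
from torch.utils.data import TensorDataset

import adaptdl.torch

adaptdl.torch.init_process_group("gloo")  # also runs context_initialize()

# Placeholder dataset purely for illustration.
dataset = TensorDataset(torch.randn(1024, 8), torch.randint(0, 2, (1024,)))

dataloader = adaptdl.torch.AdaptiveDataLoader(dataset, batch_size=128,
                                              shuffle=True)
# With this PR, autoscale_batch_size() forwards to the shared Context
# (self._elastic._context.autoscale_batch_size) rather than storing the
# bounds on the data loader helper itself.
dataloader.autoscale_batch_size(4096, local_bsz_bounds=(32, 512),
                                gradient_accumulation=True)

for epoch in adaptdl.torch.remaining_epochs_until(10):
    for inputs, targets in dataloader:
        pass  # forward/backward/optimizer step would go here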