-
Notifications
You must be signed in to change notification settings - Fork 199
Description
In ConfigContainer there is _calculate_scheduler_steps, which computes samples/steps for warmup/training/decay. The problem is that it doesn't account for batch-size warmup (ramp-up): steps are computed as iters × global_batch_size. Because of this, warmup lasts longer in terms of iters, and decay starts and ends later. For the warmup and cosine schedulers it's not a big issue, because you can just adjust those to last less or stop sooner. For the WSD scheduler, though, it is a big problem, because the start of the decay is computed from the end. Here is an example training setup:
train_iters: 100, global_batch_size: 16, rampup_batch_size: [8, 8, 160], lr_decay_style: WSD, lr_wsd_decay_style: linear, lr_warmup_iters: 10, lr_wsd_decay_iters: 20
The resulting lr schedule looks like this:

As you can see, warmup lasts 20 iters instead of 10, because during those 20 iters the batch size is half of global_batch_size; decay starts 10 iters later and doesn't get a chance to decay completely, stopping halfway. Increasing train_iters to 110 would just shift the decay to the right again, because for WSD the start of the decay is computed from the end.
Here is a quick fix that copies Megatron LM batch size warmup logic in order to compute correct steps/samples.
def _get_consumed_samples_for_iterations(self, iterations: int) -> int:
"""Calculates consumed samples after a given number of iterations."""
if iterations <= 0:
return 0
global_batch_size = self.train.global_batch_size
rampup_batch_size = self.train.rampup_batch_size
if rampup_batch_size is None:
return iterations * global_batch_size
# Same behavior as in Megatron LM
# megatron/core/num_microbatches_calculator.py::RampupBatchsizeNumMicroBatchesCalculator
assert len(rampup_batch_size) == 3, (
'expected the following '
'format: --rampup-batch-size <start batch size> '
'<batch size incerement> <ramp-up samples>'
)
start_batch_size = int(rampup_batch_size[0])
batch_size_increment = int(rampup_batch_size[1])
rampup_samples = int(rampup_batch_size[2])
diff_batch_size = global_batch_size - start_batch_size
assert diff_batch_size >= 0, (
'expected global batch size to be greater than or equal to start batch size, '
f'got {global_batch_size} and {start_batch_size}'
)
assert diff_batch_size % batch_size_increment == 0, (
'expected '
f'global batch size interval ({diff_batch_size}) to be divisible by global batch '
f'size increment ({batch_size_increment})'
)
num_increments = diff_batch_size // batch_size_increment
rampup_samples_per_increment = rampup_samples / num_increments
consumed_samples = 0
for _ in range(iterations):
if consumed_samples > rampup_samples:
current_batch_size = global_batch_size
else:
steps = int(consumed_samples / rampup_samples_per_increment)
current_batch_size = start_batch_size + steps * batch_size_increment
assert current_batch_size <= global_batch_size
consumed_samples += current_batch_size
return consumed_samples
def _calculate_scheduler_steps(self) -> None:
"""Calculate scheduler steps for both iteration-based and sample-based training."""
is_sample_based = self.train.train_samples is not None
if is_sample_based:
if self.scheduler.lr_decay_samples is None:
self.scheduler.lr_decay_samples = self.train.train_samples
self.scheduler.lr_decay_steps = self.scheduler.lr_decay_samples
self.scheduler.wd_incr_steps = self.train.train_samples
if self.scheduler.lr_wsd_decay_samples is not None:
self.scheduler.wsd_decay_steps = self.scheduler.lr_wsd_decay_samples
# Warmup calculation for sample-based training
if self.scheduler.lr_warmup_fraction is not None:
self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_fraction * self.scheduler.lr_decay_steps
else:
self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_samples
else:
# Iteration-based training
if self.scheduler.lr_decay_iters is None:
self.scheduler.lr_decay_iters = self.train.train_iters
self.scheduler.lr_decay_steps = self._get_consumed_samples_for_iterations(self.scheduler.lr_decay_iters)
self.scheduler.wd_incr_steps = self._get_consumed_samples_for_iterations(self.train.train_iters)
if self.scheduler.lr_wsd_decay_iters is not None:
wsd_start_iters = max(self.scheduler.lr_decay_iters - self.scheduler.lr_wsd_decay_iters, 0)
self.scheduler.wsd_decay_steps = (
self.scheduler.lr_decay_steps - self._get_consumed_samples_for_iterations(wsd_start_iters)
)
if self.scheduler.lr_warmup_fraction is not None:
self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_fraction * self.scheduler.lr_decay_steps
else:
self.scheduler.lr_warmup_steps = self._get_consumed_samples_for_iterations(self.scheduler.lr_warmup_iters)