Using rampup_batch_size breaks the learning rate schedule #2609

@ezKoya

Description

In ConfigContainer, _calculate_scheduler_steps computes the samples/steps for warmup, training, and decay. The problem is that it does not account for the batch size ramp-up: steps are computed as iters × global_batch_size. Because of this, warmup lasts longer in terms of iterations, and decay starts and ends later. For the warmup and cosine schedulers this is not a big issue, since you can just shorten them or stop training sooner. For the WSD scheduler, however, it is a serious problem, because the start of the decay is computed from the end. Here is an example training setup:

  • train_iters: 100
  • global_batch_size: 16
  • rampup_batch_size: [8, 8, 160]
  • lr_decay_style: WSD
  • lr_wsd_decay_style: linear
  • lr_warmup_iters: 10
  • lr_wsd_decay_iters: 20

The resulting LR schedule looks like this:

[Image: plot of the resulting LR schedule]

As you can see, warmup lasts 20 iterations instead of 10, because during those 20 iterations the batch size is half of global_batch_size; decay therefore starts 10 iterations later, never gets a chance to finish, and stops halfway. Increasing train_iters to 110 would just shift the decay to the right again, because for WSD the start of the decay is computed from the end.
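To make the mismatch concrete, here is a small sketch (the variable names are mine; the numbers come from the config above) comparing the naive iters × global_batch_size thresholds with the samples actually consumed under the ramp-up:

```python
# Numbers from the issue's config: train_iters=100, global_batch_size=16,
# rampup_batch_size=[8, 8, 160], lr_warmup_iters=10, lr_wsd_decay_iters=20.
TRAIN_ITERS, GBS = 100, 16
START_BS, RAMPUP_SAMPLES = 8, 160

# Naive thresholds, computed as iters * global_batch_size:
warmup_end = 10 * GBS             # 160 samples
decay_end = TRAIN_ITERS * GBS     # 1600 samples
wsd_start = decay_end - 20 * GBS  # 1280 samples; WSD decay start is counted from the end

# Samples actually consumed: the ramp-up covers 160 samples at batch size 8
# (20 iterations), then batch size 16 for the remaining 80 iterations.
consumed = [0]
for _ in range(TRAIN_ITERS):
    bs = START_BS if consumed[-1] < RAMPUP_SAMPLES else GBS
    consumed.append(consumed[-1] + bs)

# The 160-sample warmup budget is only exhausted at iteration 20, not 10:
warmup_iters_actual = next(i for i, c in enumerate(consumed) if c >= warmup_end)

# Training ends at 1440 samples, short of the 1600-sample decay endpoint,
# so the WSD decay (samples 1280 -> 1600) stops halfway through:
total_samples = consumed[-1]
```

Bumping train_iters to 110 moves decay_end to 1760 and wsd_start to 1440 while training only reaches 1600 samples, so the decay window just shifts right instead of completing.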
Here is a quick fix that copies the Megatron-LM batch size ramp-up logic in order to compute the correct steps/samples:

    def _get_consumed_samples_for_iterations(self, iterations: int) -> int:
        """Calculates consumed samples after a given number of iterations."""
        if iterations <= 0:
            return 0

        global_batch_size = self.train.global_batch_size

        rampup_batch_size = self.train.rampup_batch_size
        if rampup_batch_size is None:
            return iterations * global_batch_size

        # Same behavior as in Megatron LM
        # megatron/core/num_microbatches_calculator.py::RampupBatchsizeNumMicroBatchesCalculator
        assert len(rampup_batch_size) == 3, (
            'expected the following '
            'format: --rampup-batch-size <start batch size> '
            '<batch size increment> <ramp-up samples>'
        )
        start_batch_size = int(rampup_batch_size[0])
        batch_size_increment = int(rampup_batch_size[1])
        rampup_samples = int(rampup_batch_size[2])

        diff_batch_size = global_batch_size - start_batch_size
        assert diff_batch_size >= 0, (
            'expected global batch size to be greater than or equal to start batch size, '
            f'got {global_batch_size} and {start_batch_size}'
        )
        assert diff_batch_size % batch_size_increment == 0, (
            'expected '
            f'global batch size interval ({diff_batch_size}) to be divisible by global batch '
            f'size increment ({batch_size_increment})'
        )

        num_increments = diff_batch_size // batch_size_increment
        rampup_samples_per_increment = rampup_samples / num_increments

        consumed_samples = 0
        for _ in range(iterations):
            if consumed_samples > rampup_samples:
                current_batch_size = global_batch_size
            else:
                steps = int(consumed_samples / rampup_samples_per_increment)
                current_batch_size = start_batch_size + steps * batch_size_increment
                assert current_batch_size <= global_batch_size

            consumed_samples += current_batch_size

        return consumed_samples

    def _calculate_scheduler_steps(self) -> None:
        """Calculate scheduler steps for both iteration-based and sample-based training."""
        is_sample_based = self.train.train_samples is not None

        if is_sample_based:
            if self.scheduler.lr_decay_samples is None:
                self.scheduler.lr_decay_samples = self.train.train_samples
            self.scheduler.lr_decay_steps = self.scheduler.lr_decay_samples
            self.scheduler.wd_incr_steps = self.train.train_samples

            if self.scheduler.lr_wsd_decay_samples is not None:
                self.scheduler.wsd_decay_steps = self.scheduler.lr_wsd_decay_samples

            # Warmup calculation for sample-based training
            if self.scheduler.lr_warmup_fraction is not None:
                self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_fraction * self.scheduler.lr_decay_steps
            else:
                self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_samples
        else:
            # Iteration-based training
            if self.scheduler.lr_decay_iters is None:
                self.scheduler.lr_decay_iters = self.train.train_iters
            self.scheduler.lr_decay_steps = self._get_consumed_samples_for_iterations(self.scheduler.lr_decay_iters)
            self.scheduler.wd_incr_steps = self._get_consumed_samples_for_iterations(self.train.train_iters)

            if self.scheduler.lr_wsd_decay_iters is not None:
                wsd_start_iters = max(self.scheduler.lr_decay_iters - self.scheduler.lr_wsd_decay_iters, 0)
                self.scheduler.wsd_decay_steps = (
                    self.scheduler.lr_decay_steps - self._get_consumed_samples_for_iterations(wsd_start_iters)
                )

            if self.scheduler.lr_warmup_fraction is not None:
                self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_fraction * self.scheduler.lr_decay_steps
            else:
                self.scheduler.lr_warmup_steps = self._get_consumed_samples_for_iterations(self.scheduler.lr_warmup_iters)
