diff --git a/classy_vision/hooks/progress_bar_hook.py b/classy_vision/hooks/progress_bar_hook.py index f2539604ec..f92b5b664b 100644 --- a/classy_vision/hooks/progress_bar_hook.py +++ b/classy_vision/hooks/progress_bar_hook.py @@ -4,7 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Optional from classy_vision.generic.distributed_util import is_primary from classy_vision.hooks import register_hook @@ -12,49 +12,71 @@ try: - import progressbar + import tqdm - progressbar_available = True + tqdm_available = True except ImportError: - progressbar_available = False + tqdm_available = False @register_hook("progress_bar") class ProgressBarHook(ClassyHook): """ - Displays a progress bar to show progress in processing batches. - """ + Displays progress bars to show progress in processing batches. + + The permanent main progress bar tracks the overall progress in the main task. + The nested progress bar tracks the progress in the current phase. - on_start = ClassyHook._noop - on_end = ClassyHook._noop + This hook assumes that the task passed as argument contains the + following fields (e.g. ``classy_vision.tasks.ClassificationTask``): + + - ``phases``: a list of train and test phases + - ``last_batch``: to access the last labels + """ def __init__(self) -> None: """The constructor method of ProgressBarHook.""" super().__init__() - self.progress_bar: Optional[progressbar.ProgressBar] = None - self.bar_size: int = 0 - self.batches: int = 0 + self.progress_bar: Optional[tqdm.tqdm] = None + self.phase_bar: Optional[tqdm.tqdm] = None - def on_phase_start(self, task) -> None: + def on_start(self, task) -> None: """Create and display a progress bar with 0 progress.""" - if not progressbar_available: - raise RuntimeError( - "progressbar module not installed, cannot use ProgressBarHook" + if not tqdm_available: + raise RuntimeError("tqdm module not installed, cannot use ProgressBarHook") + if is_primary(): + # Compute the total number of images processed + total_images = 0 + for phase in task.phases: + phase_type = "train" if phase["train"] else "test" + total_images += len(task.datasets[phase_type]) + # Create the main task progress bar + self.progress_bar = tqdm.tqdm( + total=total_images, desc="task", unit="images" ) + def on_phase_start(self, task) -> None: if is_primary(): - self.bar_size = task.num_batches_per_phase - self.batches = 0 - self.progress_bar = progressbar.ProgressBar(self.bar_size) - self.progress_bar.start() + phase_images = len(task.datasets[task.phase_type]) + self.phase_bar = tqdm.tqdm( + total=phase_images, desc=task.phase_type, unit="images", leave=False + ) def on_step(self, task) -> None: """Update the progress bar with the batch size.""" - if task.train and is_primary() and self.progress_bar is not None: - self.batches += 1 - self.progress_bar.update(min(self.batches, self.bar_size)) + if is_primary(): + batch_size = task.last_batch.output.size(0) + if self.progress_bar is not None: + self.progress_bar.update(batch_size) + if self.phase_bar is not None: + self.phase_bar.update(batch_size) def on_phase_end(self, task) -> None: """Clear the progress bar at the end of the phase.""" + if is_primary() and self.phase_bar is not None: + self.phase_bar.close() + + def on_end(self, task) -> None: + """Clear the progress bar at the end of the task.""" if is_primary() and self.progress_bar is not None: - self.progress_bar.finish() + self.progress_bar.close() diff --git a/setup.cfg b/setup.cfg index d504f55be3..fcf52cca17 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,4 +6,4 @@ line_length=88 multi_line_output=3 use_parentheses=True lines_after_imports=2 -known_third_party=classy_vision,fvcore,numpy,parameterized,PIL,progressbar,torch,torchelastic,torchvision,visdom +known_third_party=classy_vision,fvcore,numpy,parameterized,PIL,torch,torchelastic,torchvision,tqdm,visdom diff --git a/test/manual/hooks_progress_bar_hook_test.py b/test/manual/hooks_progress_bar_hook_test.py index 78bb95a79c..47703e2f57 100644 --- a/test/manual/hooks_progress_bar_hook_test.py +++ b/test/manual/hooks_progress_bar_hook_test.py @@ -4,13 +4,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import unittest import unittest.mock as mock from test.generic.config_utils import get_test_classy_task from test.generic.hook_test_utils import HookTestBase -import progressbar +import torch +import tqdm from classy_vision.hooks import ProgressBarHook +from classy_vision.tasks.classification_task import LastBatchInfo class TestProgressBarHook(HookTestBase): @@ -23,16 +24,16 @@ def test_constructors(self) -> None: config=config, hook_type=ProgressBarHook, hook_registry_name="progress_bar" ) - @mock.patch("classy_vision.hooks.progress_bar_hook.progressbar") + @mock.patch("classy_vision.hooks.progress_bar_hook.tqdm") @mock.patch("classy_vision.hooks.progress_bar_hook.is_primary") def test_progress_bar( - self, mock_is_primary: mock.MagicMock, mock_progressbar_pkg: mock.MagicMock + self, mock_is_primary: mock.MagicMock, mock_tqdm_pkg: mock.MagicMock ) -> None: """ Tests that the progress bar is created, updated and destroyed correctly. """ - mock_progress_bar = mock.create_autospec(progressbar.ProgressBar, instance=True) - mock_progressbar_pkg.ProgressBar.return_value = mock_progress_bar + mock_progress_bar = mock.create_autospec(tqdm.tqdm, instance=True) + mock_tqdm_pkg.tqdm.return_value = mock_progress_bar mock_is_primary.return_value = True @@ -49,29 +50,27 @@ def test_progress_bar( # progressbar.ProgressBar should be init-ed with num_batches progress_bar_hook.on_phase_start(task) - mock_progressbar_pkg.ProgressBar.assert_called_once_with(num_batches) - mock_progress_bar.start.assert_called_once_with() - mock_progress_bar.start.reset_mock() - mock_progressbar_pkg.ProgressBar.reset_mock() + phase_images = len(task.datasets[task.phase_type]) + mock_tqdm_pkg.tqdm.assert_called_once_with( + total=phase_images, desc=task.phase_type, unit="images", leave=False + ) + mock_tqdm_pkg.tqdm.reset_mock() # on_step should update the progress bar correctly for i in range(num_batches): + # Fake a batch + batch_size = 32 + task.last_batch = LastBatchInfo( + loss=torch.empty(batch_size), + output=torch.empty(batch_size), + target=torch.empty(batch_size), + sample={}, + step_data={}, + ) progress_bar_hook.on_step(task) - mock_progress_bar.update.assert_called_once_with(i + 1) - mock_progress_bar.update.reset_mock() - - # check that even if on_step is called again, the progress bar is - # only updated with num_batches - for _ in range(num_batches): - progress_bar_hook.on_step(task) - mock_progress_bar.update.assert_called_once_with(num_batches) + mock_progress_bar.update.assert_called_once_with(batch_size) mock_progress_bar.update.reset_mock() - # finish should be called on the progress bar - progress_bar_hook.on_phase_end(task) - mock_progress_bar.finish.assert_called_once_with() - mock_progress_bar.finish.reset_mock() - # check that even if the progress bar isn't created, the code doesn't # crash progress_bar_hook = ProgressBarHook() @@ -82,7 +81,7 @@ def test_progress_bar( self.fail( "Received Exception when on_phase_start() isn't called: {}".format(e) ) - mock_progressbar_pkg.ProgressBar.assert_not_called() + mock_tqdm_pkg.ProgressBar.assert_not_called() # check that a progress bar is not created if is_primary() returns False mock_is_primary.return_value = False @@ -94,4 +93,4 @@ def test_progress_bar( except Exception as e: self.fail("Received Exception when is_primary() is False: {}".format(e)) self.assertIsNone(progress_bar_hook.progress_bar) - mock_progressbar_pkg.ProgressBar.assert_not_called() + mock_tqdm_pkg.ProgressBar.assert_not_called()