Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 45 additions & 23 deletions classy_vision/hooks/progress_bar_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,79 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional
from typing import Optional

from classy_vision.generic.distributed_util import is_primary
from classy_vision.hooks import register_hook
from classy_vision.hooks.classy_hook import ClassyHook


try:
import progressbar
import tqdm

progressbar_available = True
tqdm_available = True
except ImportError:
progressbar_available = False
tqdm_available = False


@register_hook("progress_bar")
class ProgressBarHook(ClassyHook):
"""
Displays a progress bar to show progress in processing batches.
"""
Displays progress bars to show progress in processing batches.

The permanent main progress bar tracks the overall progress in the main task.
The nested progress bar tracks the progress in the current phase.

on_start = ClassyHook._noop
on_end = ClassyHook._noop
This hook assumes that the task passed as argument contains the
following fields (e.g. ``classy_vision.tasks.ClassificationTask``):

- ``phases``: a list of train and test phases
- ``last_batch``: to access the last labels
"""

def __init__(self) -> None:
"""The constructor method of ProgressBarHook."""
super().__init__()
self.progress_bar: Optional[progressbar.ProgressBar] = None
self.bar_size: int = 0
self.batches: int = 0
self.progress_bar: Optional[tqdm.tqdm] = None
self.phase_bar: Optional[tqdm.tqdm] = None

def on_phase_start(self, task) -> None:
def on_start(self, task) -> None:
"""Create and display a progress bar with 0 progress."""
if not progressbar_available:
raise RuntimeError(
"progressbar module not installed, cannot use ProgressBarHook"
if not tqdm_available:
raise RuntimeError("tqdm module not installed, cannot use ProgressBarHook")
if is_primary():
# Compute the total number of images processed
total_images = 0
for phase in task.phases:
phase_type = "train" if phase["train"] else "test"
total_images += len(task.datasets[phase_type])
# Create the main task progress bar
self.progress_bar = tqdm.tqdm(
total=total_images, desc="task", unit="images"
)

def on_phase_start(self, task) -> None:
if is_primary():
self.bar_size = task.num_batches_per_phase
self.batches = 0
self.progress_bar = progressbar.ProgressBar(self.bar_size)
self.progress_bar.start()
phase_images = len(task.datasets[task.phase_type])
self.phase_bar = tqdm.tqdm(
total=phase_images, desc=task.phase_type, unit="images", leave=False
)

def on_step(self, task) -> None:
"""Update the progress bar with the batch size."""
if task.train and is_primary() and self.progress_bar is not None:
self.batches += 1
self.progress_bar.update(min(self.batches, self.bar_size))
if is_primary():
batch_size = task.last_batch.output.size(0)
if self.progress_bar is not None:
self.progress_bar.update(batch_size)
if self.phase_bar is not None:
self.phase_bar.update(batch_size)

def on_phase_end(self, task) -> None:
"""Clear the progress bar at the end of the phase."""
if is_primary() and self.phase_bar is not None:
self.phase_bar.close()

def on_end(self, task) -> None:
"""Clear the progress bar at the end of the task."""
if is_primary() and self.progress_bar is not None:
self.progress_bar.finish()
self.progress_bar.close()
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ line_length=88
multi_line_output=3
use_parentheses=True
lines_after_imports=2
known_third_party=classy_vision,fvcore,numpy,parameterized,PIL,progressbar,torch,torchelastic,torchvision,visdom
known_third_party=classy_vision,fvcore,numpy,parameterized,PIL,torch,torchelastic,torchvision,tqdm,visdom
49 changes: 24 additions & 25 deletions test/manual/hooks_progress_bar_hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import unittest.mock as mock
from test.generic.config_utils import get_test_classy_task
from test.generic.hook_test_utils import HookTestBase

import progressbar
import torch
import tqdm
from classy_vision.hooks import ProgressBarHook
from classy_vision.tasks.classification_task import LastBatchInfo


class TestProgressBarHook(HookTestBase):
Expand All @@ -23,16 +24,16 @@ def test_constructors(self) -> None:
config=config, hook_type=ProgressBarHook, hook_registry_name="progress_bar"
)

@mock.patch("classy_vision.hooks.progress_bar_hook.progressbar")
@mock.patch("classy_vision.hooks.progress_bar_hook.tqdm")
@mock.patch("classy_vision.hooks.progress_bar_hook.is_primary")
def test_progress_bar(
self, mock_is_primary: mock.MagicMock, mock_progressbar_pkg: mock.MagicMock
self, mock_is_primary: mock.MagicMock, mock_tqdm_pkg: mock.MagicMock
) -> None:
"""
Tests that the progress bar is created, updated and destroyed correctly.
"""
mock_progress_bar = mock.create_autospec(progressbar.ProgressBar, instance=True)
mock_progressbar_pkg.ProgressBar.return_value = mock_progress_bar
mock_progress_bar = mock.create_autospec(tqdm.tqdm, instance=True)
mock_tqdm_pkg.tqdm.return_value = mock_progress_bar

mock_is_primary.return_value = True

Expand All @@ -49,29 +50,27 @@ def test_progress_bar(

# progressbar.ProgressBar should be init-ed with num_batches
progress_bar_hook.on_phase_start(task)
mock_progressbar_pkg.ProgressBar.assert_called_once_with(num_batches)
mock_progress_bar.start.assert_called_once_with()
mock_progress_bar.start.reset_mock()
mock_progressbar_pkg.ProgressBar.reset_mock()
phase_images = len(task.datasets[task.phase_type])
mock_tqdm_pkg.tqdm.assert_called_once_with(
total=phase_images, desc=task.phase_type, unit="images", leave=False
)
mock_tqdm_pkg.tqdm.reset_mock()

# on_step should update the progress bar correctly
for i in range(num_batches):
# Fake a batch
batch_size = 32
task.last_batch = LastBatchInfo(
loss=torch.empty(batch_size),
output=torch.empty(batch_size),
target=torch.empty(batch_size),
sample={},
step_data={},
)
progress_bar_hook.on_step(task)
mock_progress_bar.update.assert_called_once_with(i + 1)
mock_progress_bar.update.reset_mock()

# check that even if on_step is called again, the progress bar is
# only updated with num_batches
for _ in range(num_batches):
progress_bar_hook.on_step(task)
mock_progress_bar.update.assert_called_once_with(num_batches)
mock_progress_bar.update.assert_called_once_with(batch_size)
mock_progress_bar.update.reset_mock()

# finish should be called on the progress bar
progress_bar_hook.on_phase_end(task)
mock_progress_bar.finish.assert_called_once_with()
mock_progress_bar.finish.reset_mock()

# check that even if the progress bar isn't created, the code doesn't
# crash
progress_bar_hook = ProgressBarHook()
Expand All @@ -82,7 +81,7 @@ def test_progress_bar(
self.fail(
"Received Exception when on_phase_start() isn't called: {}".format(e)
)
mock_progressbar_pkg.ProgressBar.assert_not_called()
mock_tqdm_pkg.ProgressBar.assert_not_called()

# check that a progress bar is not created if is_primary() returns False
mock_is_primary.return_value = False
Expand All @@ -94,4 +93,4 @@ def test_progress_bar(
except Exception as e:
self.fail("Received Exception when is_primary() is False: {}".format(e))
self.assertIsNone(progress_bar_hook.progress_bar)
mock_progressbar_pkg.ProgressBar.assert_not_called()
mock_tqdm_pkg.ProgressBar.assert_not_called()