diff --git a/.coveragerc b/.coveragerc
index d6710d6e9..d89c0698f 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -25,7 +25,7 @@ omit =
     zetta_utils/alignment/*.py
     zetta_utils/segmentation/*.py
     zetta_utils/segmentation/**/*.py
-    zetta_utils/training/lightning/*.py
+    zetta_utils/training/lightning/train.py
     zetta_utils/cloud_management/*.py
     zetta_utils/cloud_management/**/*.py
     zetta_utils/run/*.py
diff --git a/tests/unit/training/lightning/test_default_trainer.py b/tests/unit/training/lightning/test_default_trainer.py
index 8535c520e..5ec12e209 100644
--- a/tests/unit/training/lightning/test_default_trainer.py
+++ b/tests/unit/training/lightning/test_default_trainer.py
@@ -1,11 +1,276 @@
+# pylint: disable=redefined-outer-name
+import os
+import tempfile
+from unittest.mock import MagicMock, PropertyMock, patch
+
 import lightning.pytorch as pl
+import pytest
+import torch
 
-from zetta_utils import training
+from zetta_utils.training.lightning.trainers.default import (
+    ZettaDefaultTrainer,
+    jit_trace_export,
+    onnx_export,
+)
 
 
 def test_default_trainer():
-    result = training.lightning.trainers.ZettaDefaultTrainer(
+    result = ZettaDefaultTrainer(
         experiment_name="unit_test",
         experiment_version="x0",
     )
     assert isinstance(result, pl.Trainer)
+
+
+class MockModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(10, 5)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
+@pytest.fixture
+def mock_model():
+    return MockModel()
+
+
+@pytest.fixture
+def mock_trace_input():
+    return torch.randn(1, 10)
+
+
+@pytest.fixture
+def mock_lightning_module():
+    """Fixture providing a mock lightning module."""
+    mock_module = MagicMock()
+    mock_module._modules = {}  # pylint: disable=protected-access
+    return mock_module
+
+
+@pytest.fixture
+def trainer_mocks(mocker):
+    """Fixture providing common mocks for trainer tests."""
+    return {
+        "super_save": mocker.patch.object(pl.Trainer, "save_checkpoint"),
+        "jit_export": mocker.patch(
+            "zetta_utils.training.lightning.trainers.default.jit_trace_export"
+        ),
+        "onnx_export": mocker.patch("zetta_utils.training.lightning.trainers.default.onnx_export"),
+        "is_global_zero": mocker.patch.object(
+            ZettaDefaultTrainer, "is_global_zero", new_callable=PropertyMock
+        ),
+        "lightning_module": mocker.patch.object(
+            ZettaDefaultTrainer, "lightning_module", new_callable=PropertyMock
+        ),
+    }
+
+
+def test_save_checkpoint_calls_exports_when_enabled(
+    trainer_mocks, mock_model, mock_trace_input, mock_lightning_module
+):
+    """Test that save_checkpoint calls export functions when exports are enabled."""
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        trainer = ZettaDefaultTrainer(
+            experiment_name="unit_test",
+            experiment_version="x0",
+            enable_jit_export=True,
+            enable_onnx_export=True,
+            default_root_dir=tmp_dir,
+        )
+
+        trainer.trace_configuration = {
+            "test_model": {"model": mock_model, "trace_input": (mock_trace_input,)}
+        }
+
+        filepath = os.path.join(tmp_dir, "test_checkpoint.ckpt")
+
+        trainer_mocks["is_global_zero"].return_value = True
+        trainer_mocks["lightning_module"].return_value = mock_lightning_module
+
+        trainer.save_checkpoint(filepath)
+
+        trainer_mocks["onnx_export"].assert_called_once_with(
+            mock_model, (mock_trace_input,), filepath, "test_model"
+        )
+        trainer_mocks["jit_export"].assert_called_once_with(
+            mock_model, (mock_trace_input,), filepath, "test_model"
+        )
+        trainer_mocks["super_save"].assert_called_once()
+
+
+def test_save_checkpoint_skips_jit_when_disabled(
+    trainer_mocks, mock_model, mock_trace_input, mock_lightning_module
+):
+    """Test that save_checkpoint skips JIT export when disabled."""
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        trainer = ZettaDefaultTrainer(
+            experiment_name="unit_test",
+            experiment_version="x0",
+            enable_jit_export=False,
+            enable_onnx_export=True,
+            default_root_dir=tmp_dir,
+        )
+
+        trainer.trace_configuration = {
+            "test_model": {"model": mock_model, "trace_input": (mock_trace_input,)}
+        }
+
+        filepath = os.path.join(tmp_dir, "test_checkpoint.ckpt")
+
+        trainer_mocks["is_global_zero"].return_value = True
+        trainer_mocks["lightning_module"].return_value = mock_lightning_module
+
+        trainer.save_checkpoint(filepath)
+
+        trainer_mocks["onnx_export"].assert_called_once_with(
+            mock_model, (mock_trace_input,), filepath, "test_model"
+        )
+        trainer_mocks["jit_export"].assert_not_called()
+
+
+def test_save_checkpoint_skips_onnx_when_disabled(
+    trainer_mocks, mock_model, mock_trace_input, mock_lightning_module
+):
+    """Test that save_checkpoint skips ONNX export when disabled."""
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        trainer = ZettaDefaultTrainer(
+            experiment_name="unit_test",
+            experiment_version="x0",
+            enable_jit_export=True,
+            enable_onnx_export=False,
+            default_root_dir=tmp_dir,
+        )
+
+        trainer.trace_configuration = {
+            "test_model": {"model": mock_model, "trace_input": (mock_trace_input,)}
+        }
+
+        filepath = os.path.join(tmp_dir, "test_checkpoint.ckpt")
+
+        trainer_mocks["is_global_zero"].return_value = True
+        trainer_mocks["lightning_module"].return_value = mock_lightning_module
+
+        trainer.save_checkpoint(filepath)
+
+        trainer_mocks["jit_export"].assert_called_once_with(
+            mock_model, (mock_trace_input,), filepath, "test_model"
+        )
+        trainer_mocks["onnx_export"].assert_not_called()
+
+
+def test_save_checkpoint_skips_exports_non_global_zero(trainer_mocks):
+    """Test that save_checkpoint skips exports when not on the global-zero rank."""
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        trainer = ZettaDefaultTrainer(
+            experiment_name="unit_test",
+            experiment_version="x0",
+            enable_jit_export=True,
+            enable_onnx_export=True,
+            default_root_dir=tmp_dir,
+        )
+
+        filepath = os.path.join(tmp_dir, "test_checkpoint.ckpt")
+
+        trainer_mocks["is_global_zero"].return_value = False
+
+        trainer.save_checkpoint(filepath)
+
+        trainer_mocks["jit_export"].assert_not_called()
+        trainer_mocks["onnx_export"].assert_not_called()
+        trainer_mocks["super_save"].assert_called_once()
+
+
+@patch("zetta_utils.training.lightning.trainers.default.logger")
+def test_jit_trace_export_failure(mock_logger, mock_model, mock_trace_input):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        filepath = os.path.join(tmp_dir, "test_model")
+
+        with patch("torch.multiprocessing.get_context") as mock_ctx:
+            mock_ctx.side_effect = RuntimeError("Mock export failure")
+
+            jit_trace_export(mock_model, mock_trace_input, filepath, "test_model")
+
+            mock_logger.warning.assert_called_once()
+            assert "JIT trace export failed" in mock_logger.warning.call_args[0][0]
+
+
+@patch("zetta_utils.training.lightning.trainers.default.logger")
+def test_onnx_export_failure(mock_logger, mock_model, mock_trace_input):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        filepath = os.path.join(tmp_dir, "test_model")
+
+        with patch("torch.onnx.export") as mock_torch_onnx:
+            mock_torch_onnx.side_effect = RuntimeError("Mock ONNX export failure")
+
+            with patch("fsspec.open", create=True):
+                onnx_export(mock_model, mock_trace_input, filepath, "test_model")
+
+        mock_logger.warning.assert_called_once()
+        assert "ONNX export failed" in mock_logger.warning.call_args[0][0]
+
+
+def test_export_functions_preserve_training_mode(mock_model, mock_trace_input):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        filepath = os.path.join(tmp_dir, "test_model")
+
+        # Test with training mode enabled
+        mock_model.train()
+        original_mode = mock_model.training
+        assert original_mode is True
+
+        # Test ONNX export preserves training mode
+        with patch("torch.onnx.export"):
+            with patch("fsspec.open", create=True):
+                onnx_export(mock_model, mock_trace_input, filepath, "test_model")
+
+        assert mock_model.training == original_mode
+
+        # Test JIT export preserves training mode
+        with patch("torch.multiprocessing.get_context") as mock_ctx:
+            mock_process = MagicMock()
+            mock_ctx.return_value.Process.return_value = mock_process
+
+            jit_trace_export(mock_model, mock_trace_input, filepath, "test_model")
+
+        assert mock_model.training == original_mode
+
+
+def test_multiple_models_export(trainer_mocks, mock_trace_input, mock_lightning_module):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        trainer = ZettaDefaultTrainer(
+            experiment_name="unit_test",
+            experiment_version="x0",
+            enable_jit_export=True,
+            enable_onnx_export=True,
+            default_root_dir=tmp_dir,
+        )
+
+        model1 = MockModel()
+        model2 = MockModel()
+
+        trainer.trace_configuration = {
+            "model1": {"model": model1, "trace_input": (mock_trace_input,)},
+            "model2": {"model": model2, "trace_input": (mock_trace_input,)},
+        }
+
+        filepath = os.path.join(tmp_dir, "test_checkpoint.ckpt")
+
+        trainer_mocks["is_global_zero"].return_value = True
+        trainer_mocks["lightning_module"].return_value = mock_lightning_module
+
+        trainer.save_checkpoint(filepath)
+
+        # Verify both models were exported
+        assert trainer_mocks["jit_export"].call_count == 2
+        assert trainer_mocks["onnx_export"].call_count == 2
+
+        # Check that each model was exported under its own name
+        jit_calls = [call[0] for call in trainer_mocks["jit_export"].call_args_list]
+        onnx_calls = [call[0] for call in trainer_mocks["onnx_export"].call_args_list]
+
+        assert any(call[0] is model1 and call[3] == "model1" for call in jit_calls)
+        assert any(call[0] is model2 and call[3] == "model2" for call in jit_calls)
+        assert any(call[0] is model1 and call[3] == "model1" for call in onnx_calls)
+        assert any(call[0] is model2 and call[3] == "model2" for call in onnx_calls)
diff --git a/zetta_utils/api/v0.py b/zetta_utils/api/v0.py
index 0c9a1d6ad..651675dfe 100644
--- a/zetta_utils/api/v0.py
+++ b/zetta_utils/api/v0.py
@@ -360,7 +360,8 @@
     ZettaDefaultTrainer,
     get_checkpointing_callbacks,
     get_progress_bar_callbacks,
-    trace_and_save_model,
+    jit_trace_export,
+    onnx_export,
 )
 from zetta_utils.typing import (
     ArithmeticOperand,
diff --git a/zetta_utils/convnet/architecture/convblock.py b/zetta_utils/convnet/architecture/convblock.py
index 4c210814c..3bcbf7d06 100644
--- a/zetta_utils/convnet/architecture/convblock.py
+++ b/zetta_utils/convnet/architecture/convblock.py
@@ -18,11 +18,7 @@
 
 
 def _get_size(data: torch.Tensor) -> Sequence[int]:
-    # In tracing mode, shapes obtained from tensor.shape are traced as tensors
-    if isinstance(data.shape[0], torch.Tensor):  # type: ignore[unreachable] # pragma: no cover
-        size = list(map(lambda x: x.item(), data.shape))  # type: ignore[unreachable]
-    else:
-        size = data.shape
+    size = data.shape
     return size
diff --git a/zetta_utils/convnet/architecture/unet.py b/zetta_utils/convnet/architecture/unet.py
index fcf912ddd..f08a042d4 100644
--- a/zetta_utils/convnet/architecture/unet.py
+++ b/zetta_utils/convnet/architecture/unet.py
@@ -152,11 +152,7 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
 
         for i, layer in enumerate(self.layers):
             if i in skip_data_for:
-                # In tracing mode, shapes obtained from tensor.shape are traced as tensors
-                if isinstance(result.shape[0], torch.Tensor):  # type: ignore # pragma: no cover
-                    size = list(map(lambda x: x.item(), result.shape))  # type: ignore
-                else:
-                    size = result.shape
+                size = result.shape
                 skip_data = crop_center(skip_data_for[i], size)
                 if self.unet_skip_mode == "sum":
                     result = result + skip_data
diff --git a/zetta_utils/tensor_ops/common.py b/zetta_utils/tensor_ops/common.py
index cede22ce9..85ff4d121 100644
--- a/zetta_utils/tensor_ops/common.py
+++ b/zetta_utils/tensor_ops/common.py
@@ -646,7 +646,7 @@ def crop(
 @typechecked
 def crop_center(
     data: TensorTypeVar,
-    size: Sequence[int],  # pylint: disable=redefined-outer-name
+    size: Union[torch.Size, Sequence[int]],  # pylint: disable=redefined-outer-name
 ) -> TensorTypeVar:
     """
     Crop a multidimensional tensor to the center.
@@ -659,16 +659,12 @@
 
     ndim = len(size)
     slices = [slice(0, None) for _ in range(data.ndim - ndim)]
     for insz, outsz in zip(data.shape[-ndim:], size):
-        if isinstance(insz, torch.Tensor):  # pragma: no cover # only occurs for JIT
-            insz = insz.item()
-        assert insz >= outsz
-        lcrop = (insz - outsz) // 2
-        rcrop = (insz - outsz) - lcrop
-        if rcrop != 0:
-            slices.append(slice(lcrop, -rcrop))
-        else:
-            assert lcrop == 0
-            slices.append(slice(0, None))
+        if not torch.jit.is_tracing():
+            assert insz >= outsz
+        start_idx = (insz - outsz) // 2
+        end_idx = start_idx + outsz
+        slices.append(slice(start_idx, end_idx))
+
     result = data[tuple(slices)]
     return result
diff --git a/zetta_utils/training/lightning/trainers/default.py b/zetta_utils/training/lightning/trainers/default.py
index fd204ad35..2e3355884 100644
--- a/zetta_utils/training/lightning/trainers/default.py
+++ b/zetta_utils/training/lightning/trainers/default.py
@@ -25,28 +25,47 @@
 os.environ["MKL_THREADING_LAYER"] = "GNU"
 
 
-"""
-Separate function to work around the jit.trace memory leak
-"""
+def _jit_trace_export_in_subprocess(args_packed):  # pragma: no cover
+    model, trace_input, filepath_jit = args_packed
+    model.eval()
+    with torch.inference_mode():
+        trace = torch.jit.trace(model, trace_input)
+    with fsspec.open(filepath_jit, "wb") as f:
+        torch.jit.save(trace, f)
+    logger.info(f"JIT trace export: {filepath_jit}")
+
+
+def jit_trace_export(model, trace_input, filepath, name):
+    # Separate process to avoid memory leak: https://github.com/pytorch/pytorch/issues/35600
+    original_training_mode = model.training
+    filepath_jit = f"{filepath}.static-{torch.__version__}-{name}.jit"
+    try:
+        ctx = torch.multiprocessing.get_context("spawn")
+        p = ctx.Process(
+            target=_jit_trace_export_in_subprocess,
+            args=[(model, trace_input, filepath_jit)],
+        )
+        p.start()
+        p.join()
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        logger.warning(f"JIT trace export failed for {name}: {type(e).__name__}: {e}")
+    finally:
+        model.train(original_training_mode)
 
 
-def trace_and_save_model(
-    args_packed,
-):  # pragma: no cover # pylint: disable=broad-except, used-before-assignment
-    model, trace_input, filepath, name = args_packed
-    trace = torch.jit.trace(model, trace_input)
-    filepath_jit = f"{filepath}.static-{torch.__version__}-{name}.jit"
-    with fsspec.open(filepath_jit, "wb") as f:
-        torch.jit.save(trace, f)
+def onnx_export(model, trace_input, filepath, name):
+    original_training_mode = model.training
+    filepath_onnx = f"{filepath}.static-{torch.__version__}-{name}.onnx"
     try:
-        filepath_onnx = f"{filepath}.static-{torch.__version__}-{name}.onnx"
-        with fsspec.open(filepath_onnx, "wb") as f:
-            filesystem = f.fs
-            torch.onnx.export(model, trace_input, f, opset_version=ONNX_OPSET_VERSION)
-        return None
-    except Exception as e:
-        filesystem.delete(filepath_onnx)
-        return type(e).__name__, e.args[0]
+        model.eval()
+        with torch.inference_mode():
+            with fsspec.open(filepath_onnx, "wb") as f:
+                torch.onnx.export(model, trace_input, f, opset_version=ONNX_OPSET_VERSION)
+        logger.info(f"ONNX export: {filepath_onnx}")
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        logger.warning(f"ONNX export failed for {name}: {type(e).__name__}: {e}")
+    finally:
+        model.train(original_training_mode)
 
 
 @builder.register("ZettaDefaultTrainer")
@@ -58,6 +77,8 @@ def __init__(
         *args,
         checkpointing_kwargs: Optional[dict] = None,
         progress_bar_kwargs: Optional[dict] = None,
+        enable_onnx_export: bool = True,
+        enable_jit_export: bool = False,
         **kwargs,
     ):
         assert "callbacks" not in kwargs
@@ -100,6 +121,8 @@ def __init__(
         # to resume training with ckpt_path='last' when storing
         # checkpoints on GCP.
         self._ckpt_path = os.path.join(log_dir, "last.ckpt")
+        self.enable_onnx_export = enable_onnx_export
+        self.enable_jit_export = enable_jit_export
 
     def save_checkpoint(
         self, filepath, weights_only: bool = False, storage_options: Optional[Any] = None
@@ -142,12 +165,12 @@
         for name, val in self.trace_configuration.items():
             model = val["model"]
             trace_input = val["trace_input"]
-            ctx = torch.multiprocessing.get_context("spawn")
-            with ctx.Pool(processes=1) as pool:
-                # See https://github.com/pytorch/pytorch/issues/35600
-                res = pool.map(trace_and_save_model, [(model, trace_input, filepath, name)])[0]
-            if res is not None:
-                logger.warning(f"Exception while saving the model as ONNX: {res[0]}: {res[1]}")
+
+            if self.enable_onnx_export:
+                onnx_export(model, trace_input, filepath, name)
+
+            if self.enable_jit_export:
+                jit_trace_export(model, trace_input, filepath, name)
 
 
 @typeguard.typechecked