2 changes: 1 addition & 1 deletion .coveragerc
@@ -25,7 +25,7 @@ omit =
zetta_utils/alignment/*.py
zetta_utils/segmentation/*.py
zetta_utils/segmentation/**/*.py
-    zetta_utils/training/lightning/*.py
+    zetta_utils/training/lightning/train.py
zetta_utils/cloud_management/*.py
zetta_utils/cloud_management/**/*.py
zetta_utils/run/*.py
269 changes: 267 additions & 2 deletions tests/unit/training/lightning/test_default_trainer.py
@@ -1,11 +1,276 @@
# pylint: disable=redefined-outer-name
import os
import tempfile
from unittest.mock import MagicMock, PropertyMock, patch

import lightning.pytorch as pl
import pytest
import torch

from zetta_utils import training
from zetta_utils.training.lightning.trainers.default import (
ZettaDefaultTrainer,
jit_trace_export,
onnx_export,
)


def test_default_trainer():
-    result = training.lightning.trainers.ZettaDefaultTrainer(
+    result = ZettaDefaultTrainer(
experiment_name="unit_test",
experiment_version="x0",
)
assert isinstance(result, pl.Trainer)


class MockModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 5)

def forward(self, x):
return self.linear(x)


@pytest.fixture
def mock_model():
return MockModel()


@pytest.fixture
def mock_trace_input():
return torch.randn(1, 10)


@pytest.fixture
def mock_lightning_module():
"""Fixture providing a mock lightning module."""
mock_module = MagicMock()
mock_module._modules = {} # pylint: disable=protected-access
return mock_module


@pytest.fixture
def trainer_mocks(mocker):
"""Fixture providing common mocks for trainer tests."""
return {
"super_save": mocker.patch.object(pl.Trainer, "save_checkpoint"),
"jit_export": mocker.patch(
"zetta_utils.training.lightning.trainers.default.jit_trace_export"
),
"onnx_export": mocker.patch("zetta_utils.training.lightning.trainers.default.onnx_export"),
"is_global_zero": mocker.patch.object(
ZettaDefaultTrainer, "is_global_zero", new_callable=PropertyMock
),
"lightning_module": mocker.patch.object(
ZettaDefaultTrainer, "lightning_module", new_callable=PropertyMock
),
}


def test_save_checkpoint_calls_exports_when_enabled(
trainer_mocks, mock_model, mock_trace_input, mock_lightning_module
):
"""Test that save_checkpoint calls export functions when exports are enabled."""
Contributor: Is this test necessary? Isn't it simpler just to add an assert line in the code that checks that at least one export function is enabled?

Collaborator (Author): That's different. I am OK with the user disabling both exports, so that assert would be too restrictive anyway. Here I am just testing that the parameters got passed down the chain and didn't get lost while unpacking/modifying **kwargs somewhere.

Contributor:
> That's different. I am OK with the user disabling both exports, so that assert would be too restrictive anyway.

Right, I misread the test. But it still seems trivial? Is it really doing much more than a simple line-coverage test (i.e., call the save function and check that the saved file exists)? Making sure the functions are called seems only marginally better.

with tempfile.TemporaryDirectory() as tmp_dir:
trainer = ZettaDefaultTrainer(
experiment_name="unit_test",
experiment_version="x0",
enable_jit_export=True,
enable_onnx_export=True,
default_root_dir=tmp_dir,
)

trainer.trace_configuration = {
"test_model": {"model": mock_model, "trace_input": (mock_trace_input,)}
}

filepath = os.path.join(tmp_dir, "test_checkpoint.ckpt")

trainer_mocks["is_global_zero"].return_value = True
trainer_mocks["lightning_module"].return_value = mock_lightning_module

trainer.save_checkpoint(filepath)

trainer_mocks["onnx_export"].assert_called_once_with(
mock_model, (mock_trace_input,), filepath, "test_model"
)
trainer_mocks["jit_export"].assert_called_once_with(
mock_model, (mock_trace_input,), filepath, "test_model"
)
trainer_mocks["super_save"].assert_called_once()


def test_save_checkpoint_skips_jit_when_disabled(
trainer_mocks, mock_model, mock_trace_input, mock_lightning_module
):
"""Test that save_checkpoint skips JIT export when disabled."""
Contributor: This seems unnecessary. If jit.trace is disabled in the input param, it's clear that it's disabled.

Collaborator (Author): The test verifies that the export logic actually respects the parameter that was set. Here is a copy-and-paste error that a future refactoring could introduce:

    if self.enable_onnx_export:
        jit_trace_export(...)

Contributor: A simple line-coverage test would have caught this?

Collaborator (Author): OK, but onnx_export might get called from a different location (either correctly or incorrectly), too. My point is: 100% line coverage is nice, but it is easy to achieve while cheating ourselves. Ideally I would like to ensure the method is called from all the expected code paths (and only those). 100% code coverage ensures no dead code and no broken code, but it does not ensure correct results.

But also, something is wrong here: I should not have been able to achieve 100% code coverage without these tests in the first place. Especially the error-handling paths can't possibly have been covered before...

with tempfile.TemporaryDirectory() as tmp_dir:
trainer = ZettaDefaultTrainer(
experiment_name="unit_test",
experiment_version="x0",
enable_jit_export=False,
enable_onnx_export=True,
default_root_dir=tmp_dir,
)

trainer.trace_configuration = {
"test_model": {"model": mock_model, "trace_input": (mock_trace_input,)}
}

filepath = os.path.join(tmp_dir, "test_checkpoint.ckpt")

trainer_mocks["is_global_zero"].return_value = True
trainer_mocks["lightning_module"].return_value = mock_lightning_module

trainer.save_checkpoint(filepath)

trainer_mocks["onnx_export"].assert_called_once_with(
mock_model, (mock_trace_input,), filepath, "test_model"
)
trainer_mocks["jit_export"].assert_not_called()


def test_save_checkpoint_skips_onnx_when_disabled(
trainer_mocks, mock_model, mock_trace_input, mock_lightning_module
):
"""Test that save_checkpoint skips ONNX export when disabled."""
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = ZettaDefaultTrainer(
experiment_name="unit_test",
experiment_version="x0",
enable_jit_export=True,
enable_onnx_export=False,
default_root_dir=tmp_dir,
)

trainer.trace_configuration = {
"test_model": {"model": mock_model, "trace_input": (mock_trace_input,)}
}

filepath = os.path.join(tmp_dir, "test_checkpoint.ckpt")

trainer_mocks["is_global_zero"].return_value = True
trainer_mocks["lightning_module"].return_value = mock_lightning_module

trainer.save_checkpoint(filepath)

trainer_mocks["jit_export"].assert_called_once_with(
mock_model, (mock_trace_input,), filepath, "test_model"
)
trainer_mocks["onnx_export"].assert_not_called()


def test_save_checkpoint_skips_exports_non_global_zero(trainer_mocks):
"""Test that save_checkpoint skips exports when not global zero rank."""
Contributor: This seems useful, but it would miss additional exports added in the future.

with tempfile.TemporaryDirectory() as tmp_dir:
trainer = ZettaDefaultTrainer(
experiment_name="unit_test",
experiment_version="x0",
enable_jit_export=True,
enable_onnx_export=True,
default_root_dir=tmp_dir,
)

filepath = os.path.join(tmp_dir, "test_checkpoint.ckpt")

trainer_mocks["is_global_zero"].return_value = False

trainer.save_checkpoint(filepath)

trainer_mocks["jit_export"].assert_not_called()
trainer_mocks["onnx_export"].assert_not_called()
trainer_mocks["super_save"].assert_called_once()


@patch("zetta_utils.training.lightning.trainers.default.logger")
def test_jit_trace_export_failure(mock_logger, mock_model, mock_trace_input):
with tempfile.TemporaryDirectory() as tmp_dir:
filepath = os.path.join(tmp_dir, "test_model")

with patch("torch.multiprocessing.get_context") as mock_ctx:
mock_ctx.side_effect = RuntimeError("Mock export failure")

jit_trace_export(mock_model, mock_trace_input, filepath, "test_model")

mock_logger.warning.assert_called_once()
assert "JIT trace export failed" in mock_logger.warning.call_args[0][0]


@patch("zetta_utils.training.lightning.trainers.default.logger")
def test_onnx_export_failure(mock_logger, mock_model, mock_trace_input):
with tempfile.TemporaryDirectory() as tmp_dir:
filepath = os.path.join(tmp_dir, "test_model")

with patch("torch.onnx.export") as mock_torch_onnx:
mock_torch_onnx.side_effect = RuntimeError("Mock ONNX export failure")

with patch("fsspec.open", create=True):
onnx_export(mock_model, mock_trace_input, filepath, "test_model")

mock_logger.warning.assert_called_once()
assert "ONNX export failed" in mock_logger.warning.call_args[0][0]


def test_export_functions_preserve_training_mode(mock_model, mock_trace_input):
with tempfile.TemporaryDirectory() as tmp_dir:
filepath = os.path.join(tmp_dir, "test_model")

# Test with training mode enabled
mock_model.train()
original_mode = mock_model.training
assert original_mode is True

# Test ONNX export preserves training mode
with patch("torch.onnx.export"):
with patch("fsspec.open", create=True):
onnx_export(mock_model, mock_trace_input, filepath, "test_model")

assert mock_model.training == original_mode

# Test JIT export preserves training mode
with patch("torch.multiprocessing.get_context") as mock_ctx:
mock_process = MagicMock()
mock_ctx.return_value.Process.return_value = mock_process

jit_trace_export(mock_model, mock_trace_input, filepath, "test_model")

assert mock_model.training == original_mode


def test_multiple_models_export(trainer_mocks, mock_trace_input, mock_lightning_module):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = ZettaDefaultTrainer(
experiment_name="unit_test",
experiment_version="x0",
enable_jit_export=True,
enable_onnx_export=True,
default_root_dir=tmp_dir,
)

model1 = MockModel()
model2 = MockModel()

trainer.trace_configuration = {
"model1": {"model": model1, "trace_input": (mock_trace_input,)},
"model2": {"model": model2, "trace_input": (mock_trace_input,)},
}

filepath = os.path.join(tmp_dir, "test_checkpoint.ckpt")

trainer_mocks["is_global_zero"].return_value = True
trainer_mocks["lightning_module"].return_value = mock_lightning_module

trainer.save_checkpoint(filepath)

# Verify both models were exported
assert trainer_mocks["jit_export"].call_count == 2
assert trainer_mocks["onnx_export"].call_count == 2

# Check that both models were called with correct names
jit_calls = [call[0] for call in trainer_mocks["jit_export"].call_args_list]
onnx_calls = [call[0] for call in trainer_mocks["onnx_export"].call_args_list]

assert any(call[0] is model1 and call[3] == "model1" for call in jit_calls)
assert any(call[0] is model2 and call[3] == "model2" for call in jit_calls)
assert any(call[0] is model1 and call[3] == "model1" for call in onnx_calls)
assert any(call[0] is model2 and call[3] == "model2" for call in onnx_calls)
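
As the review threads above discuss, these tests pin down behavior rather than raw line coverage: each export flag must be respected, the entries in trace_configuration must reach the export functions unchanged, and exports must run only on the global-zero rank while the regular checkpoint is always written. A minimal sketch of the dispatch the tests assert; "SketchTrainer" is a hypothetical stand-in, not the actual source of ZettaDefaultTrainer.save_checkpoint:

```python
# Sketch only: the behavior the tests above assert, not the actual
# implementation of ZettaDefaultTrainer.save_checkpoint.
import lightning.pytorch as pl

from zetta_utils.training.lightning.trainers.default import (
    jit_trace_export,
    onnx_export,
)


class SketchTrainer(pl.Trainer):  # hypothetical name for illustration
    enable_jit_export: bool = True
    enable_onnx_export: bool = True
    trace_configuration: dict = {}

    def save_checkpoint(self, filepath, *args, **kwargs):
        # Exports happen only on the global-zero rank ...
        if self.is_global_zero:
            for name, cfg in self.trace_configuration.items():
                model, trace_input = cfg["model"], cfg["trace_input"]
                # ... and each export format respects its own enable flag.
                if self.enable_jit_export:
                    jit_trace_export(model, trace_input, filepath, name)
                if self.enable_onnx_export:
                    onnx_export(model, trace_input, filepath, name)
        # The regular Lightning checkpoint is always written.
        super().save_checkpoint(filepath, *args, **kwargs)
```

Because the tests patch jit_trace_export and onnx_export at the module level, they exercise exactly this dispatch (flags, rank gating, and the trace_configuration fan-out) without running a real trace or ONNX export.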
3 changes: 2 additions & 1 deletion zetta_utils/api/v0.py
@@ -360,7 +360,8 @@
ZettaDefaultTrainer,
get_checkpointing_callbacks,
get_progress_bar_callbacks,
-    trace_and_save_model,
+    jit_trace_export,
+    onnx_export,
)
from zetta_utils.typing import (
ArithmeticOperand,
6 changes: 1 addition & 5 deletions zetta_utils/convnet/architecture/convblock.py
@@ -18,11 +18,7 @@


def _get_size(data: torch.Tensor) -> Sequence[int]:
-    # In tracing mode, shapes obtained from tensor.shape are traced as tensors
-    if isinstance(data.shape[0], torch.Tensor):  # type: ignore[unreachable] # pragma: no cover
-        size = list(map(lambda x: x.item(), data.shape))  # type: ignore[unreachable]
-    else:
-        size = data.shape
+    size = data.shape
return size


6 changes: 1 addition & 5 deletions zetta_utils/convnet/architecture/unet.py
@@ -152,11 +152,7 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:

for i, layer in enumerate(self.layers):
if i in skip_data_for:
-                # In tracing mode, shapes obtained from tensor.shape are traced as tensors
-                if isinstance(result.shape[0], torch.Tensor):  # type: ignore # pragma: no cover
-                    size = list(map(lambda x: x.item(), result.shape))  # type: ignore
-                else:
-                    size = result.shape
+                size = result.shape
skip_data = crop_center(skip_data_for[i], size)
if self.unet_skip_mode == "sum":
result = result + skip_data
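
The convblock.py and unet.py changes above drop the tracing-time special case and pass result.shape (a torch.Size) straight to crop_center, whose signature is widened in the next file to accept torch.Size. A small sketch of the eager-mode semantics this relies on; it uses only standard PyTorch behavior and is not zetta_utils-specific:

```python
# Standard PyTorch behavior (not zetta_utils-specific): in eager mode,
# tensor.shape is a torch.Size, a tuple subclass holding plain Python ints,
# so it can be handed to a crop helper that accepts torch.Size directly.
import torch

result = torch.randn(1, 8, 32, 32)
size = result.shape
assert isinstance(size, torch.Size)
assert all(isinstance(s, int) for s in size)
print(list(size))  # [1, 8, 32, 32]
```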
18 changes: 7 additions & 11 deletions zetta_utils/tensor_ops/common.py
@@ -646,7 +646,7 @@ def crop(
@typechecked
def crop_center(
data: TensorTypeVar,
-    size: Sequence[int],  # pylint: disable=redefined-outer-name
+    size: Union[torch.Size, Sequence[int]],  # pylint: disable=redefined-outer-name
) -> TensorTypeVar:
"""
Crop a multidimensional tensor to the center.
@@ -659,16 +659,12 @@ ndim = len(size)
ndim = len(size)
slices = [slice(0, None) for _ in range(data.ndim - ndim)]
for insz, outsz in zip(data.shape[-ndim:], size):
-        if isinstance(insz, torch.Tensor):  # pragma: no cover # only occurs for JIT
-            insz = insz.item()
-        assert insz >= outsz
-        lcrop = (insz - outsz) // 2
-        rcrop = (insz - outsz) - lcrop
-        if rcrop != 0:
-            slices.append(slice(lcrop, -rcrop))
-        else:
-            assert lcrop == 0
-            slices.append(slice(0, None))
+        if not torch.jit.is_tracing():
+            assert insz >= outsz
+        start_idx = (insz - outsz) // 2
+        end_idx = start_idx + outsz
+        slices.append(slice(start_idx, end_idx))

result = data[tuple(slices)]
return result
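
The rewritten loop above builds each slice from an explicit start and end index instead of the previous lcrop/rcrop pair, and it only asserts insz >= outsz outside of tracing. A small worked example of the center-crop arithmetic, in plain PyTorch and for illustration only:

```python
# Worked example of the center-crop indexing (illustration only).
import torch

data = torch.arange(10)           # one dimension of size 10
insz, outsz = data.shape[-1], 6   # crop that dimension down to 6
start_idx = (insz - outsz) // 2   # (10 - 6) // 2 = 2
end_idx = start_idx + outsz       # 2 + 6 = 8
assert data[start_idx:end_idx].tolist() == [2, 3, 4, 5, 6, 7]
```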
