From a2c2c22a80758689d1130e817d3f15d533d0a04c Mon Sep 17 00:00:00 2001 From: Darko Lukic Date: Tue, 12 Aug 2025 22:06:08 +0200 Subject: [PATCH] Add linter tests --- .github/workflows/pytest.yml | 1 + alignit/visualize.py | 6 +----- tests/test_alignnet.py | 22 +++++++++++++++++++--- tests/test_dataset_loader.py | 6 ++---- tests/test_lint.py | 18 ++++++++++++++++++ tests/test_tfs.py | 2 +- 6 files changed, 42 insertions(+), 13 deletions(-) create mode 100644 tests/test_lint.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index e6a80ed..03de18c 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -30,6 +30,7 @@ jobs: run: | python -m pip install --upgrade pip pip install . + pip install black - name: Run tests run: | diff --git a/alignit/visualize.py b/alignit/visualize.py index 26ee81f..86532e0 100644 --- a/alignit/visualize.py +++ b/alignit/visualize.py @@ -25,11 +25,7 @@ def get_data(index): outputs=[gr.Image(type="pil", label="Image"), gr.Text(label="Label")], title="Dataset Image Viewer", live=True, - ).launch( - share=cfg.share, - server_name=cfg.server_name, - server_port=cfg.server_port - ) + ).launch(share=cfg.share, server_name=cfg.server_name, server_port=cfg.server_port) if __name__ == "__main__": diff --git a/tests/test_alignnet.py b/tests/test_alignnet.py index aa77cb6..77bcddb 100644 --- a/tests/test_alignnet.py +++ b/tests/test_alignnet.py @@ -4,7 +4,12 @@ def test_alignnet_forward_shapes_cpu(): - model = AlignNet(backbone_name="resnet18", backbone_weights=None, use_vector_input=False, output_dim=7) + model = AlignNet( + backbone_name="resnet18", + backbone_weights=None, + use_vector_input=False, + output_dim=7, + ) model.eval() x = torch.randn(2, 3, 3, 64, 64) # B=2, N=3 views with torch.no_grad(): @@ -13,7 +18,12 @@ def test_alignnet_forward_shapes_cpu(): def test_alignnet_with_vector_input(): - model = AlignNet(backbone_name="resnet18", backbone_weights=None, use_vector_input=True, output_dim=7) + model = AlignNet( + backbone_name="resnet18", + backbone_weights=None, + use_vector_input=True, + output_dim=7, + ) model.eval() x = torch.randn(1, 2, 3, 64, 64) vecs = [torch.randn(5)] @@ -23,12 +33,18 @@ def test_alignnet_with_vector_input(): def test_alignnet_performance(): - model = AlignNet(backbone_name="efficientnet_b0", backbone_weights=None, use_vector_input=True, output_dim=7) + model = AlignNet( + backbone_name="efficientnet_b0", + backbone_weights=None, + use_vector_input=True, + output_dim=7, + ) model.eval() x = torch.randn(1, 3, 3, 224, 224) # B=1, N=3 views vecs = [torch.randn(5)] with torch.no_grad(): import time + start_time = time.time() for _ in range(10): y = model(x, vecs) diff --git a/tests/test_dataset_loader.py b/tests/test_dataset_loader.py index 90a6b5a..e534a9d 100644 --- a/tests/test_dataset_loader.py +++ b/tests/test_dataset_loader.py @@ -1,7 +1,3 @@ -import os - -import pytest - from alignit.utils.dataset import load_dataset @@ -19,6 +15,7 @@ def fake_load_from_disk(p): # Patch inside module import alignit.utils.dataset as ds + ds.load_from_disk = fake_load_from_disk d = load_dataset(str(tmp_path)) @@ -35,6 +32,7 @@ def fake_hf_load_dataset(name): return Dummy(name) import alignit.utils.dataset as ds + ds.hf_load_dataset = fake_hf_load_dataset d = load_dataset("my-dataset/name") diff --git a/tests/test_lint.py b/tests/test_lint.py new file mode 100644 index 0000000..b0bea21 --- /dev/null +++ b/tests/test_lint.py @@ -0,0 +1,18 @@ +import subprocess +import sys +from pathlib import Path + +PROJECT_ROOT = Path(__file__).resolve().parent.parent + + +def run(cmd: list[str]): + result = subprocess.run(cmd, cwd=PROJECT_ROOT, capture_output=True, text=True) + if result.returncode != 0: + print(result.stdout) + print(result.stderr, file=sys.stderr) + return result.returncode == 0 + + +def test_black(): + result = run([sys.executable, "-m", "black", "--check", "--diff", "--verbose", "."]) + assert result, "Black formatting issues found" diff --git a/tests/test_tfs.py b/tests/test_tfs.py index 805c5c6..65bc87e 100644 --- a/tests/test_tfs.py +++ b/tests/test_tfs.py @@ -21,5 +21,5 @@ def test_get_pose_str_format(): pose = np.eye(4) s = get_pose_str(pose) assert isinstance(s, str) - parts = s.split(',') + parts = s.split(",") assert len(parts) == 6