diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..86f98d1 --- /dev/null +++ b/.flake8 @@ -0,0 +1,7 @@ +[flake8] +max-line-length = 99 +# E203: black conflict +# E701: black conflict +# F821: lot of issues regarding type annotations +# F722: syntax error in forward annotations (jaxtyping, etc.) +extend-ignore = E203,E701,F821,F722 \ No newline at end of file diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..24fde43 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,46 @@ +name: Lint + +on: + pull_request: + types: [opened, synchronize, reopened] + +# To cancel a currently running workflow from the same PR, branch or tag when a new workflow is triggered +# https://stackoverflow.com/a/72408109 +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install linters + run: pip install autopep8 flake8 + + - name: Check formatting with autopep8 + run: autopep8 --diff --recursive --exit-code beast tests + # Reads config from [tool.autopep8] in pyproject.toml + + - name: Lint with flake8 (critical errors only) + run: flake8 beast tests --select=E9,F63,F7,F82 + # Reads config from .flake8 file + + - name: Show fix instructions if formatting needed + if: failure() + run: | + echo "" + echo "Linting failed!" + echo "" + echo "To fix formatting issues locally, run:" + echo " autopep8 --in-place --recursive beast tests" + echo "" + echo "To check for flake8 errors locally, run:" + echo " flake8 beast tests --select=E9,F63,F7,F82" + echo "" \ No newline at end of file diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 813cb31..ca499bb 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -4,6 +4,10 @@ on: release: types: [published] +permissions: + id-token: write # OIDC for trusted publishing + contents: read + jobs: deploy: runs-on: ubuntu-latest @@ -64,7 +68,7 @@ jobs: python -m zipfile -l dist/*.whl | head -20 - name: Publish to PyPI - run: poetry publish --username __token__ --password ${{ secrets.PYPI_API_TOKEN }} + uses: pypa/gh-action-pypi-publish@release/v1 - name: Verify publication run: | diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index c6c8019..93c0c31 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -1,5 +1,5 @@ # See documentation in scripts/buildbot/README -name: SSH and Execute Build Script +name: Run tests on: pull_request_target: diff --git a/beast/data/samplers.py b/beast/data/samplers.py index bda7064..a2187f9 100644 --- a/beast/data/samplers.py +++ b/beast/data/samplers.py @@ -107,7 +107,7 @@ def __init__(self, dataset, batch_size, idx_offset=1, shuffle=True, seed=42): # Calculate samples per replica self.samples_per_replica = self.num_samples // self.num_replicas self.total_samples = self.samples_per_replica * self.num_replicas - + # Calculate batches for this replica self.num_batches = self.samples_per_replica // self.batch_size @@ -142,11 +142,11 @@ def __iter__(self): indices_per_replica = len(anchor_indices_) // self.num_replicas start_idx = self.rank * indices_per_replica end_idx = start_idx + indices_per_replica - + # Handle remainder indices for the last replica if self.rank == self.num_replicas - 1: end_idx = len(anchor_indices_) - + # Get this replica's subset of anchor indices for this epoch anchor_indices = anchor_indices_[start_idx:end_idx] self.anchor_indices = anchor_indices # for testing and debugging @@ -173,15 +173,15 @@ def __iter__(self): # Find valid positive indices valid_positives = [ - p for p in self.pos_indices[i] + p for p in self.pos_indices[i] if p in self.dataset_indices and p not in used ] - + if not valid_positives: used.add(i) # Mark this anchor as used even if no valid positives idx_cursor += 1 continue - + # Choose random positive i_p = np.random.choice(valid_positives) diff --git a/beast/train.py b/beast/train.py index 8258d4a..ee851c9 100644 --- a/beast/train.py +++ b/beast/train.py @@ -165,7 +165,7 @@ def train(config: dict, model, output_dir: str | Path): def get_callbacks( checkpointing: bool = True, lr_monitor: bool = True, - ckpt_every_n_epochs: int | None= None, + ckpt_every_n_epochs: int | None = None, ) -> list: callbacks = [] diff --git a/pyproject.toml b/pyproject.toml index 94d6eb1..26b3cf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [project] name = "beast-backbones" -version = "1.1.2" # Update the version according to your source +version = "1.1.3" # Update the version according to your source description = "Behavioral analysis via self-supervised pretraining of transformers" license = "MIT" readme = "README.md" @@ -60,6 +60,7 @@ beast = "beast.cli.main:main" [project.optional-dependencies] dev = [ + "autopep8", "flake8-pyproject", "isort", "pytest", diff --git a/tests/data/test_samplers.py b/tests/data/test_samplers.py index 17df7a3..a972134 100644 --- a/tests/data/test_samplers.py +++ b/tests/data/test_samplers.py @@ -200,14 +200,14 @@ def test_distributed_anchor_distribution(self): with patch('torch.distributed.get_rank', return_value=rank): sampler = ContrastBatchSampler(dataset, batch_size=4, seed=42) samplers.append(sampler) - + # Test that all samplers have access to the same full set of anchor indices # (before per-epoch distribution) all_anchors_rank0 = set(samplers[0].all_anchor_indices) all_anchors_rank1 = set(samplers[1].all_anchor_indices) assert all_anchors_rank0 == all_anchors_rank1, \ "All ranks should start with same anchor indices" - + # Test anchor distribution during iteration (first epoch) epoch_1_batches = [] for sampler in samplers: @@ -216,7 +216,7 @@ def test_distributed_anchor_distribution(self): for batch in sampler: epoch_indices.extend(batch[::2]) # just take anchor indices epoch_1_batches.append(set(epoch_indices)) - + # Verify no overlap between ranks in epoch 1 anchors_rank0_epoch1 = epoch_1_batches[0] anchors_rank1_epoch1 = epoch_1_batches[1] @@ -224,12 +224,12 @@ def test_distributed_anchor_distribution(self): print(anchors_rank1_epoch1) assert anchors_rank0_epoch1.isdisjoint(anchors_rank1_epoch1), \ "Ranks should have non-overlapping anchors in the same epoch" - + # Verify roughly equal distribution per epoch total_epoch_anchors = len(anchors_rank0_epoch1) + len(anchors_rank1_epoch1) assert abs(len(anchors_rank0_epoch1) - len(anchors_rank1_epoch1)) <= 2, \ "Anchors should be roughly evenly distributed per epoch" - + # Test that distribution changes across epochs epoch_2_batches = [] for sampler in samplers: @@ -238,16 +238,16 @@ def test_distributed_anchor_distribution(self): for batch in sampler: epoch_indices.extend(batch[::2]) # just take anchor indices epoch_2_batches.append(set(epoch_indices)) - + anchors_rank0_epoch2 = epoch_2_batches[0] anchors_rank1_epoch2 = epoch_2_batches[1] - + # Verify that each rank gets different data across epochs assert anchors_rank0_epoch1 != anchors_rank0_epoch2, \ "Rank 0 should get different anchor indices across epochs" assert anchors_rank1_epoch1 != anchors_rank1_epoch2, \ "Rank 1 should get different anchor indices across epochs" - + # Verify no overlap within each epoch assert anchors_rank0_epoch2.isdisjoint(anchors_rank1_epoch2), \ "Ranks should have non-overlapping anchors in epoch 2" @@ -271,11 +271,11 @@ def create_rank_samplers(seed): sampler = ContrastBatchSampler(dataset, batch_size=4, seed=seed) samplers.append(sampler.all_anchor_indices.copy()) return samplers - + # Test determinism: same seed should give same results anchors_run1 = create_rank_samplers(seed=42) anchors_run2 = create_rank_samplers(seed=42) - + assert anchors_run1[0] == anchors_run2[0], "Rank 0 should be deterministic with same seed" assert anchors_run1[1] == anchors_run2[1], "Rank 1 should be deterministic with same seed" @@ -291,7 +291,7 @@ def test_anchor_redistribution_across_epochs(self): # Create sampler for single GPU with patch('torch.distributed.is_initialized', return_value=False): sampler = ContrastBatchSampler(dataset, batch_size=4, seed=42) - + # Collect indices from multiple epochs epoch_data = [] for epoch in range(3): @@ -299,27 +299,27 @@ def test_anchor_redistribution_across_epochs(self): for batch in sampler: epoch_indices.extend(batch) epoch_data.append(set(epoch_indices)) - + # Verify that different epochs use different anchor orderings # (though they may overlap since it's the same dataset) epoch1_indices, epoch2_indices, epoch3_indices = epoch_data - + # Convert to lists to check ordering epoch1_list = [] epoch2_list = [] epoch3_list = [] - + # Reset sampler and collect ordered indices sampler.epoch = 0 for batch in sampler: epoch1_list.extend(batch) - + for batch in sampler: epoch2_list.extend(batch) - + for batch in sampler: epoch3_list.extend(batch) - + # Verify that the ordering is different across epochs assert epoch1_list != epoch2_list, "Epoch 1 and 2 should have different anchor orderings" assert epoch2_list != epoch3_list, "Epoch 2 and 3 should have different anchor orderings" @@ -340,21 +340,21 @@ def test_reproducible_epoch_distribution(self): with patch('torch.distributed.is_initialized', return_value=False): sampler = ContrastBatchSampler(dataset, batch_size=4, seed=42) samplers.append(sampler) - + # Collect indices from first epoch for both samplers epoch1_indices_sampler1 = [] epoch1_indices_sampler2 = [] - + for batch in samplers[0]: epoch1_indices_sampler1.extend(batch[::2]) # just take anchor indices - + for batch in samplers[1]: epoch1_indices_sampler2.extend(batch[::2]) # just take anchor indices - + # Verify same seed produces same results assert epoch1_indices_sampler1 == epoch1_indices_sampler2, \ "Same seed should produce identical anchor distribution" - + def test_epoch_based_shuffling(self): """Test that different epochs produce different orderings within each rank""" n_frames = 52 @@ -387,13 +387,13 @@ def test_positive_indices_validity(self): subdataset = Mock() # Create a realistic scenario with consecutive frames subdataset.image_list = [f"video1/frame_{i:03d}.png" for i in range(15)] + \ - [f"video2/frame_{i:03d}.png" for i in range(15)] + [f"video2/frame_{i:03d}.png" for i in range(15)] dataset.indices = list(range(30)) dataset.dataset = subdataset with patch('torch.distributed.is_initialized', return_value=False): sampler = ContrastBatchSampler(dataset, batch_size=4, idx_offset=1) - + # Check that pos_indices relationships are valid for anchor_idx, pos_list in sampler.pos_indices.items(): for pos_idx in pos_list: @@ -414,16 +414,16 @@ def test_batch_anchor_positive_relationships(self): with patch('torch.distributed.is_initialized', return_value=False): sampler = ContrastBatchSampler(dataset, batch_size=4, idx_offset=1) - + batches = list(sampler) - + for batch in batches: # Process pairs (assuming even indices are anchors, odd are positives) for i in range(0, len(batch), 2): if i + 1 < len(batch): anchor_idx = batch[i] pos_idx = batch[i + 1] - + # Verify this is a valid anchor-positive relationship assert anchor_idx in sampler.pos_indices, f"Anchor {anchor_idx} should have positives" assert pos_idx in sampler.pos_indices[anchor_idx], \