Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[flake8]
max-line-length = 99
# E203: whitespace before ':' — black/autopep8 insert this around slices; conflicts with the formatter
# E701: multiple statements on one line (colon) — conflicts with formatter-produced one-liners
# F821: undefined name — presumably false positives from string/forward type annotations; TODO confirm
# F722: syntax error in forward annotations (jaxtyping, etc.)
extend-ignore = E203,E701,F821,F722
46 changes: 46 additions & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Lint workflow: checks formatting (autopep8) and critical flake8 errors on PRs.
# NOTE(review): indentation was lost in the pasted diff; structure reconstructed
# from the GitHub Actions workflow schema — confirm against the original file.
name: Lint

on:
  pull_request:
    types: [opened, synchronize, reopened]

# To cancel a currently running workflow from the same PR, branch or tag when a new workflow is triggered
# https://stackoverflow.com/a/72408109
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      - name: Install linters
        run: pip install autopep8 flake8

      - name: Check formatting with autopep8
        # --exit-code makes autopep8 fail the job when a diff would be produced.
        # Reads config from [tool.autopep8] in pyproject.toml
        run: autopep8 --diff --recursive --exit-code beast tests

      - name: Lint with flake8 (critical errors only)
        # E9/F63/F7/F82: syntax errors, invalid comparisons, undefined names.
        # Reads config from .flake8 file
        run: flake8 beast tests --select=E9,F63,F7,F82

      - name: Show fix instructions if formatting needed
        if: failure()
        run: |
          echo ""
          echo "Linting failed!"
          echo ""
          echo "To fix formatting issues locally, run:"
          echo "  autopep8 --in-place --recursive beast tests"
          echo ""
          echo "To check for flake8 errors locally, run:"
          echo "  flake8 beast tests --select=E9,F63,F7,F82"
          echo ""
6 changes: 5 additions & 1 deletion .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ on:
release:
types: [published]

permissions:
id-token: write # OIDC for trusted publishing
contents: read

jobs:
deploy:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -64,7 +68,7 @@ jobs:
python -m zipfile -l dist/*.whl | head -20

- name: Publish to PyPI
run: poetry publish --username __token__ --password ${{ secrets.PYPI_API_TOKEN }}
uses: pypa/gh-action-pypi-publish@release/v1

- name: Verify publication
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# See documentation in scripts/buildbot/README
name: SSH and Execute Build Script
name: Run tests

on:
pull_request_target:
Expand Down
12 changes: 6 additions & 6 deletions beast/data/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(self, dataset, batch_size, idx_offset=1, shuffle=True, seed=42):
# Calculate samples per replica
self.samples_per_replica = self.num_samples // self.num_replicas
self.total_samples = self.samples_per_replica * self.num_replicas

# Calculate batches for this replica
self.num_batches = self.samples_per_replica // self.batch_size

Expand Down Expand Up @@ -142,11 +142,11 @@ def __iter__(self):
indices_per_replica = len(anchor_indices_) // self.num_replicas
start_idx = self.rank * indices_per_replica
end_idx = start_idx + indices_per_replica

# Handle remainder indices for the last replica
if self.rank == self.num_replicas - 1:
end_idx = len(anchor_indices_)

# Get this replica's subset of anchor indices for this epoch
anchor_indices = anchor_indices_[start_idx:end_idx]
self.anchor_indices = anchor_indices # for testing and debugging
Expand All @@ -173,15 +173,15 @@ def __iter__(self):

# Find valid positive indices
valid_positives = [
p for p in self.pos_indices[i]
p for p in self.pos_indices[i]
if p in self.dataset_indices and p not in used
]

if not valid_positives:
used.add(i) # Mark this anchor as used even if no valid positives
idx_cursor += 1
continue

# Choose random positive
i_p = np.random.choice(valid_positives)

Expand Down
2 changes: 1 addition & 1 deletion beast/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def train(config: dict, model, output_dir: str | Path):
def get_callbacks(
checkpointing: bool = True,
lr_monitor: bool = True,
ckpt_every_n_epochs: int | None= None,
ckpt_every_n_epochs: int | None = None,
) -> list:

callbacks = []
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[project]
name = "beast-backbones"
version = "1.1.2" # Update the version according to your source
version = "1.1.3" # Update the version according to your source
description = "Behavioral analysis via self-supervised pretraining of transformers"
license = "MIT"
readme = "README.md"
Expand Down Expand Up @@ -60,6 +60,7 @@ beast = "beast.cli.main:main"

[project.optional-dependencies]
dev = [
"autopep8",
"flake8-pyproject",
"isort",
"pytest",
Expand Down
54 changes: 27 additions & 27 deletions tests/data/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,14 @@ def test_distributed_anchor_distribution(self):
with patch('torch.distributed.get_rank', return_value=rank):
sampler = ContrastBatchSampler(dataset, batch_size=4, seed=42)
samplers.append(sampler)

# Test that all samplers have access to the same full set of anchor indices
# (before per-epoch distribution)
all_anchors_rank0 = set(samplers[0].all_anchor_indices)
all_anchors_rank1 = set(samplers[1].all_anchor_indices)
assert all_anchors_rank0 == all_anchors_rank1, \
"All ranks should start with same anchor indices"

# Test anchor distribution during iteration (first epoch)
epoch_1_batches = []
for sampler in samplers:
Expand All @@ -216,20 +216,20 @@ def test_distributed_anchor_distribution(self):
for batch in sampler:
epoch_indices.extend(batch[::2]) # just take anchor indices
epoch_1_batches.append(set(epoch_indices))

# Verify no overlap between ranks in epoch 1
anchors_rank0_epoch1 = epoch_1_batches[0]
anchors_rank1_epoch1 = epoch_1_batches[1]
print(anchors_rank0_epoch1)
print(anchors_rank1_epoch1)
assert anchors_rank0_epoch1.isdisjoint(anchors_rank1_epoch1), \
"Ranks should have non-overlapping anchors in the same epoch"

# Verify roughly equal distribution per epoch
total_epoch_anchors = len(anchors_rank0_epoch1) + len(anchors_rank1_epoch1)
assert abs(len(anchors_rank0_epoch1) - len(anchors_rank1_epoch1)) <= 2, \
"Anchors should be roughly evenly distributed per epoch"

# Test that distribution changes across epochs
epoch_2_batches = []
for sampler in samplers:
Expand All @@ -238,16 +238,16 @@ def test_distributed_anchor_distribution(self):
for batch in sampler:
epoch_indices.extend(batch[::2]) # just take anchor indices
epoch_2_batches.append(set(epoch_indices))

anchors_rank0_epoch2 = epoch_2_batches[0]
anchors_rank1_epoch2 = epoch_2_batches[1]

# Verify that each rank gets different data across epochs
assert anchors_rank0_epoch1 != anchors_rank0_epoch2, \
"Rank 0 should get different anchor indices across epochs"
assert anchors_rank1_epoch1 != anchors_rank1_epoch2, \
"Rank 1 should get different anchor indices across epochs"

# Verify no overlap within each epoch
assert anchors_rank0_epoch2.isdisjoint(anchors_rank1_epoch2), \
"Ranks should have non-overlapping anchors in epoch 2"
Expand All @@ -271,11 +271,11 @@ def create_rank_samplers(seed):
sampler = ContrastBatchSampler(dataset, batch_size=4, seed=seed)
samplers.append(sampler.all_anchor_indices.copy())
return samplers

# Test determinism: same seed should give same results
anchors_run1 = create_rank_samplers(seed=42)
anchors_run2 = create_rank_samplers(seed=42)

assert anchors_run1[0] == anchors_run2[0], "Rank 0 should be deterministic with same seed"
assert anchors_run1[1] == anchors_run2[1], "Rank 1 should be deterministic with same seed"

Expand All @@ -291,35 +291,35 @@ def test_anchor_redistribution_across_epochs(self):
# Create sampler for single GPU
with patch('torch.distributed.is_initialized', return_value=False):
sampler = ContrastBatchSampler(dataset, batch_size=4, seed=42)

# Collect indices from multiple epochs
epoch_data = []
for epoch in range(3):
epoch_indices = []
for batch in sampler:
epoch_indices.extend(batch)
epoch_data.append(set(epoch_indices))

# Verify that different epochs use different anchor orderings
# (though they may overlap since it's the same dataset)
epoch1_indices, epoch2_indices, epoch3_indices = epoch_data

# Convert to lists to check ordering
epoch1_list = []
epoch2_list = []
epoch3_list = []

# Reset sampler and collect ordered indices
sampler.epoch = 0
for batch in sampler:
epoch1_list.extend(batch)

for batch in sampler:
epoch2_list.extend(batch)

for batch in sampler:
epoch3_list.extend(batch)

# Verify that the ordering is different across epochs
assert epoch1_list != epoch2_list, "Epoch 1 and 2 should have different anchor orderings"
assert epoch2_list != epoch3_list, "Epoch 2 and 3 should have different anchor orderings"
Expand All @@ -340,21 +340,21 @@ def test_reproducible_epoch_distribution(self):
with patch('torch.distributed.is_initialized', return_value=False):
sampler = ContrastBatchSampler(dataset, batch_size=4, seed=42)
samplers.append(sampler)

# Collect indices from first epoch for both samplers
epoch1_indices_sampler1 = []
epoch1_indices_sampler2 = []

for batch in samplers[0]:
epoch1_indices_sampler1.extend(batch[::2]) # just take anchor indices

for batch in samplers[1]:
epoch1_indices_sampler2.extend(batch[::2]) # just take anchor indices

# Verify same seed produces same results
assert epoch1_indices_sampler1 == epoch1_indices_sampler2, \
"Same seed should produce identical anchor distribution"

def test_epoch_based_shuffling(self):
"""Test that different epochs produce different orderings within each rank"""
n_frames = 52
Expand Down Expand Up @@ -387,13 +387,13 @@ def test_positive_indices_validity(self):
subdataset = Mock()
# Create a realistic scenario with consecutive frames
subdataset.image_list = [f"video1/frame_{i:03d}.png" for i in range(15)] + \
[f"video2/frame_{i:03d}.png" for i in range(15)]
[f"video2/frame_{i:03d}.png" for i in range(15)]
dataset.indices = list(range(30))
dataset.dataset = subdataset

with patch('torch.distributed.is_initialized', return_value=False):
sampler = ContrastBatchSampler(dataset, batch_size=4, idx_offset=1)

# Check that pos_indices relationships are valid
for anchor_idx, pos_list in sampler.pos_indices.items():
for pos_idx in pos_list:
Expand All @@ -414,16 +414,16 @@ def test_batch_anchor_positive_relationships(self):

with patch('torch.distributed.is_initialized', return_value=False):
sampler = ContrastBatchSampler(dataset, batch_size=4, idx_offset=1)

batches = list(sampler)

for batch in batches:
# Process pairs (assuming even indices are anchors, odd are positives)
for i in range(0, len(batch), 2):
if i + 1 < len(batch):
anchor_idx = batch[i]
pos_idx = batch[i + 1]

# Verify this is a valid anchor-positive relationship
assert anchor_idx in sampler.pos_indices, f"Anchor {anchor_idx} should have positives"
assert pos_idx in sampler.pos_indices[anchor_idx], \
Expand Down
Loading