Skip to content

Commit fb042ec

Browse files
authored
Merge pull request #17 from ryusudol/dev
Dev
2 parents 1859398 + 545b27f commit fb042ec

6 files changed

Lines changed: 502 additions & 14 deletions

File tree

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name: CI
22

33
on:
44
push:
5-
branches: [main, dev]
5+
branches: [main]
66
paths:
77
- "cka/**"
88
- "tests/**"

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,5 @@ cka_checkpoint.pt
5757
.claude/
5858
CLAUDE.md
5959
uv.lock
60+
61+
*.pdf

cka/cka.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def compare(
242242
dataloader: DataLoader,
243243
dataloader2: DataLoader | None = None,
244244
progress: bool = True,
245-
callback: Callable[[int, int, torch.Tensor], None] | None = None,
245+
verbose: bool = False,
246246
) -> torch.Tensor:
247247
if not self._hook_handles:
248248
raise RuntimeError("Hooks not registered.")
@@ -296,9 +296,11 @@ def compare(
296296

297297
self._accumulate_hsic(hsic_xy, hsic_xx, hsic_yy)
298298

299-
if callback is not None:
299+
if verbose:
300300
current_cka = self._compute_cka_matrix(hsic_xy, hsic_xx, hsic_yy)
301-
callback(batch_idx, total_batches, current_cka)
301+
print(
302+
f"Batch {batch_idx + 1}/{total_batches} - Mean CKA: {current_cka.mean().item():.4f}"
303+
)
302304

303305
return self._compute_cka_matrix(hsic_xy, hsic_xx, hsic_yy)
304306

File renamed without changes.

examples/memory_efficiency_comparison.ipynb

Lines changed: 487 additions & 0 deletions
Large diffs are not rendered by default.

tests/test_cka_extended.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -336,19 +336,16 @@ def test_progress_true(self, model1, model2, dataloader):
336336

337337
assert result.shape == (3, 3)
338338

339-
def test_callback_called(self, model1, model2, dataloader):
339+
def test_verbose_output(self, model1, model2, dataloader, capsys):
340340
cka = CKA(model1, model2)
341-
callback_calls = []
342341

343-
def callback(batch_idx, total_batches, cka_matrix):
344-
callback_calls.append((batch_idx, total_batches, cka_matrix.clone()))
342+
result = cka(dataloader, verbose=True, progress=False)
345343

346-
result = cka(dataloader, callback=callback)
347-
348-
assert len(callback_calls) == 4
349-
assert callback_calls[0][0] == 0
350-
assert callback_calls[0][1] == 4
351-
assert callback_calls[-1][2].shape == result.shape
344+
captured = capsys.readouterr()
345+
assert "Batch 1/4" in captured.out
346+
assert "Batch 4/4" in captured.out
347+
assert "Mean CKA:" in captured.out
348+
assert result.shape[0] > 0
352349

353350
def test_two_dataloaders(self, model1, model2, dataloader):
354351
x2 = torch.randn(32, 10)

0 commit comments

Comments
 (0)