@@ -336,19 +336,16 @@ def test_progress_true(self, model1, model2, dataloader):
336336
337337 assert result .shape == (3 , 3 )
338338
339- def test_callback_called (self , model1 , model2 , dataloader ):
339+ def test_verbose_output (self , model1 , model2 , dataloader , capsys ):
340340 cka = CKA (model1 , model2 )
341- callback_calls = []
342341
343- def callback (batch_idx , total_batches , cka_matrix ):
344- callback_calls .append ((batch_idx , total_batches , cka_matrix .clone ()))
342+ result = cka (dataloader , verbose = True , progress = False )
345343
346- result = cka (dataloader , callback = callback )
347-
348- assert len (callback_calls ) == 4
349- assert callback_calls [0 ][0 ] == 0
350- assert callback_calls [0 ][1 ] == 4
351- assert callback_calls [- 1 ][2 ].shape == result .shape
344+ captured = capsys .readouterr ()
345+ assert "Batch 1/4" in captured .out
346+ assert "Batch 4/4" in captured .out
347+ assert "Mean CKA:" in captured .out
348+ assert result .shape [0 ] > 0
352349
353350 def test_two_dataloaders (self , model1 , model2 , dataloader ):
354351 x2 = torch .randn (32 , 10 )
0 commit comments