Hello, and thank you for sharing this great code.
In the `contrast_matrix` function (`src/train_plus_utils.py`), there seems to be a mistake around line 218:
```python
matrix_02 = torch.zeros((245, 4))
for i in range(4):
    matrix_02[i * 49:(i + 1) * 49, i] = 0.5
for i in range(4):
    if i != 3:
        matrix_20[196 + i * 12:196 + (i + 1) * 13, i] = 0.5  # <- possibly should be matrix_02
    else:
        matrix_20[196 + i * 12:, i] = 0.5  # <- possibly should be matrix_02
```
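In the second loop, the code writes to `matrix_20` again, but judging by the surrounding logic, I think it is meant to keep filling `matrix_02`. The first loop only fills rows 0 to 195 (4 × 49 = 196 rows), so as written, rows 196 to 244 of `matrix_02` stay all zeros, at least as far as this snippet shows.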
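To illustrate the effect, here is a minimal sketch that mirrors the row/column assignments above with plain Python lists, so it runs without PyTorch. `build_matrix_02`, its `apply_fix` flag and `zero_rows` are hypothetical names for illustration, not part of the repository. With `apply_fix=False` the second loop is skipped, which is what `matrix_02` sees when those writes go to `matrix_20` instead.

```python
def build_matrix_02(apply_fix=True, rows=245, cols=4):
    """Hypothetical plain-list stand-in for the matrix_02 construction.

    apply_fix=True  -> the second loop writes into matrix_02 (the suggested change)
    apply_fix=False -> the second loop is skipped, mimicking writes that go to matrix_20
    """
    m = [[0.0] * cols for _ in range(rows)]

    def fill(start, stop, col):
        # Emulates `tensor[start:stop, col] = 0.5`; slicing a range clamps
        # out-of-bounds indices the same way tensor slicing does.
        for r in range(rows)[start:stop]:
            m[r][col] = 0.5

    # First loop: four blocks of 49 rows cover rows 0-195.
    for i in range(4):
        fill(i * 49, (i + 1) * 49, i)

    # Second loop: the remaining 49 rows (196-244), slicing copied as quoted.
    if apply_fix:
        for i in range(4):
            if i != 3:
                fill(196 + i * 12, 196 + (i + 1) * 13, i)
            else:
                fill(196 + i * 12, None, i)
    return m


def zero_rows(m):
    """Indices of rows that contain no non-zero entry."""
    return [r for r, row in enumerate(m) if not any(row)]


print(len(zero_rows(build_matrix_02(apply_fix=False))))  # 49 (rows 196-244)
print(len(zero_rows(build_matrix_02(apply_fix=True))))   # 0
```

With the change applied, every row of `matrix_02` receives at least one 0.5 entry; without it, the last 49 rows remain zero.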