Why is self.bank.clear() commented out inside the if MODE == "read": block in mutual_self_attention.py, with the clear operation moved into the if accelerator.sync_gradients: block in train_stage_1.py instead?
Won’t this cause the bank to accumulate hidden states from several forward passes between optimizer updates? For instance, with the gradient accumulation steps set to 3:
During the second forward pass, the bank would still contain (and attention would use) the hidden state from the first pass.
During the third forward pass, it would still contain the hidden states from the first two passes.
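To make the concern concrete, here is a toy simulation. It is not the repository's code. The bank is modeled as a plain list, each forward pass appends one hypothetical "hidden state" during the write step, and the read step records how many cached states attention would see, including the current pass's own state. The function name bank_sizes_seen and the clear_after_read flag are hypothetical. The sync rule (step % grad_accum_steps == 0) is a simplification of when accelerator.sync_gradients becomes True.

```python
from typing import List


def bank_sizes_seen(grad_accum_steps: int, num_passes: int, clear_after_read: bool) -> List[int]:
    """Return how many cached states the read step sees on each forward pass.

    Hypothetical toy model: the bank is a list of placeholder strings, not tensors.
    """
    bank: List[str] = []
    sizes: List[int] = []
    for step in range(1, num_passes + 1):
        # Write step: the reference pass appends its hidden state to the bank.
        bank.append(f"hidden_state_{step}")
        # Read step: attention would use every state currently in the bank.
        sizes.append(len(bank))
        if clear_after_read:
            # Behavior with self.bank.clear() left active in the read branch.
            bank.clear()
        elif step % grad_accum_steps == 0:
            # Simplified stand-in for clearing only when accelerator.sync_gradients is True.
            bank.clear()
    return sizes


print(bank_sizes_seen(3, 6, clear_after_read=False))
print(bank_sizes_seen(3, 6, clear_after_read=True))
```

Under these assumptions, clearing only on gradient sync gives [1, 2, 3, 1, 2, 3], so the second and third passes also see stale states. Clearing after every read gives [1, 1, 1, 1, 1, 1].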