# Either record this step's hidden states in the bank (write) or read the
# banked entry back and merge it with the current hidden states (read).
if write:
    # Store the hidden states as two slices, split along the first
    # dimension at `self.id_length`.
    # NOTE(review): presumably the first dimension indexes samples/images
    # rather than tokens -- confirm against the caller.
    self.id_bank[cur_step] = [
        hidden_states[: self.id_length],
        hidden_states[self.id_length :],
    ]
else:
    # Concatenate along dim 0 (the torch.cat default) in this order:
    #   banked slice 0, first row of the current hidden states,
    #   banked slice 1, remaining rows of the current hidden states.
    # Banked tensors are moved to `self.device` before concatenation.
    encoder_hidden_states = torch.cat(
        (
            self.id_bank[cur_step][0].to(self.device),
            hidden_states[:1],
            self.id_bank[cur_step][1].to(self.device),
            hidden_states[1:],
        )
    )
# Either record this step's hidden states in the bank (write) or read the
# banked entry back and prepend it to the current hidden states (read).
if write:
    # Store the hidden states for this step unmodified, as a single tensor.
    self.id_bank[cur_step] = hidden_states
else:
    # Concatenate along dim 0 (the torch.cat default): the banked tensor,
    # moved to `self.device`, followed by the current hidden states.
    encoder_hidden_states = torch.cat(
        (
            self.id_bank[cur_step].to(self.device),
            hidden_states,
        )
    )
I'd like to know why the implementation looks like the first snippet above.
Why not write it like the second snippet instead?