Hi. Maybe I'm misunderstanding this, but in lines 178–180 of run_pplm.py, a window mask is constructed to select only a recent window of the past hidden states to update:
```python
window_mask = torch.cat(
    (ones_mask, torch.zeros(zeros_key_val_shape)),
    dim=-2
)
```
Shouldn't we instead concatenate in the order (zeros, ones), since we want the mask to cover the most recent latents rather than those at the very beginning?
Any response to this would be greatly appreciated!
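For concreteness, here is a small sketch of the difference between the two concatenation orders. It uses NumPy in place of torch purely for brevity, and the shapes (batch, heads, seq_len, head_dim) and the values of curr_length and window_length are illustrative assumptions, not taken from run_pplm.py:

```python
import numpy as np

# Hypothetical shapes standing in for past[0].shape:
# (batch, heads, seq_len, head_dim), with seq_len = curr_length.
curr_length, window_length, head_dim = 5, 2, 4

# A block of ones covering the window, zeros covering the rest.
ones_block = np.ones((1, 1, window_length, head_dim))
zeros_block = np.zeros((1, 1, curr_length - window_length, head_dim))

# Order currently in run_pplm.py: ones first, so along the sequence
# axis (dim=-2) the window covers the OLDEST positions, 0..window_length-1.
mask_ones_first = np.concatenate((ones_block, zeros_block), axis=-2)

# Order proposed in this issue: zeros first, so the window covers the
# MOST RECENT positions, curr_length-window_length..curr_length-1.
mask_zeros_first = np.concatenate((zeros_block, ones_block), axis=-2)

print(mask_ones_first[0, 0, :, 0])   # [1. 1. 0. 0. 0.]
print(mask_zeros_first[0, 0, :, 0])  # [0. 0. 0. 1. 1.]
```

If the most recent tokens sit at the end of the sequence axis, only the (zeros, ones) order places the ones over them; that is exactly the discrepancy being asked about.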