diff --git a/trainer.py b/trainer.py index 7ca58c85..2e98f03e 100644 --- a/trainer.py +++ b/trainer.py @@ -248,7 +248,7 @@ def edit_step(self, batch, training: bool): info_dict["acc/post"] = post_loc_dict["acc"].item() info_dict["nll/pre"] = pre_loc_dict["nll"].item() info_dict["nll/post"] = post_loc_dict["nll"].item() - info_dict["n_tokens/pre"] = post_loc_dict["n_tokens"] + info_dict["n_tokens/pre"] = pre_loc_dict["n_tokens"] info_dict["n_tokens/post"] = post_loc_dict["n_tokens"] info_dict["time/edit"] = edit_time