diff --git a/rutransform/transformations/transformations/distraction/sentence_additions.py b/rutransform/transformations/transformations/distraction/sentence_additions.py index 295b684..fd52e7a 100644 --- a/rutransform/transformations/transformations/distraction/sentence_additions.py +++ b/rutransform/transformations/transformations/distraction/sentence_additions.py @@ -175,8 +175,8 @@ def get_model_path(self) -> str: path to model in the HuggingFace library """ model_dict = { - "gpt2": "sberbank-ai/rugpt2_large", - "gpt3": "sberbank-ai/rugpt3large_based_on_gpt2", + "gpt2": "ai-forever/rugpt2_large", + "gpt3": "ai-forever/rugpt3large_based_on_gpt2", "mt5-base": "google/mt5-base", "mt5-small": "google/mt5-small", "mt5-large": "google/mt5-large", @@ -253,7 +253,6 @@ def generate( outputs = self.generator( sentence, max_length=self.args.max_length, - skip_special_tokens=True, num_return_sequences=1, num_beams=self.args.num_beams, early_stopping=self.args.early_stopping, diff --git a/rutransform/utils/args.py b/rutransform/utils/args.py index c83398f..e9a4967 100644 --- a/rutransform/utils/args.py +++ b/rutransform/utils/args.py @@ -53,8 +53,8 @@ class TransformArguments: generator: str = field( default="gpt3", metadata={ - "help": "generator model: 'gpt2' = sberbank-ai/rugpt2large, " - "'gpt3' = sberbank-ai/rugpt3small_based_on_gpt2, " + "help": "generator model: 'gpt2' = ai-forever/rugpt2_large, " + "'gpt3' = ai-forever/rugpt3large_based_on_gpt2, " "'mt5-small' = google/mt5-small, 'mt5-base' = google/mt5-base, " "'mt5-large' = google/mt5-large" },