From 3fe7fce8e7708756e532bde68018e52959a61fdb Mon Sep 17 00:00:00 2001 From: valeriylo Date: Tue, 27 Jun 2023 15:46:30 +0300 Subject: [PATCH 1/2] Huggingface repo name changed; skip_special_tokens=True removed as it throws error for gtp-3 model --- .../transformations/distraction/sentence_additions.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/rutransform/transformations/transformations/distraction/sentence_additions.py b/rutransform/transformations/transformations/distraction/sentence_additions.py index 295b684..fd52e7a 100644 --- a/rutransform/transformations/transformations/distraction/sentence_additions.py +++ b/rutransform/transformations/transformations/distraction/sentence_additions.py @@ -175,8 +175,8 @@ def get_model_path(self) -> str: path to model in the HuggingFace library """ model_dict = { - "gpt2": "sberbank-ai/rugpt2_large", - "gpt3": "sberbank-ai/rugpt3large_based_on_gpt2", + "gpt2": "ai-forever/rugpt2_large", + "gpt3": "ai-forever/rugpt3large_based_on_gpt2", "mt5-base": "google/mt5-base", "mt5-small": "google/mt5-small", "mt5-large": "google/mt5-large", @@ -253,7 +253,6 @@ def generate( outputs = self.generator( sentence, max_length=self.args.max_length, - skip_special_tokens=True, num_return_sequences=1, num_beams=self.args.num_beams, early_stopping=self.args.early_stopping, From 8e18697019b542b7372b86d8de6a5329065171c2 Mon Sep 17 00:00:00 2001 From: valeriylo Date: Tue, 27 Jun 2023 15:48:16 +0300 Subject: [PATCH 2/2] Huggingface repo name changed --- rutransform/utils/args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rutransform/utils/args.py b/rutransform/utils/args.py index c83398f..e9a4967 100644 --- a/rutransform/utils/args.py +++ b/rutransform/utils/args.py @@ -53,8 +53,8 @@ class TransformArguments: generator: str = field( default="gpt3", metadata={ - "help": "generator model: 'gpt2' = sberbank-ai/rugpt2large, " - "'gpt3' = sberbank-ai/rugpt3small_based_on_gpt2, " + "help": "generator model: 'gpt2' = ai-forever/rugpt2large, " + "'gpt3' = ai-forever/rugpt3small_based_on_gpt2, " "'mt5-small' = google/mt5-small, 'mt5-base' = google/mt5-base, " "'mt5-large' = google/mt5-large" },