From 9b7c73a21720266b105f3aad51dc171a60edafba Mon Sep 17 00:00:00 2001 From: Michael Fox Date: Thu, 13 Jan 2022 11:14:51 -0600 Subject: [PATCH 1/4] add use_cuda parameter --- rpunct/punctuate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rpunct/punctuate.py b/rpunct/punctuate.py index c0143c08..6b335f6b 100644 --- a/rpunct/punctuate.py +++ b/rpunct/punctuate.py @@ -10,11 +10,11 @@ class RestorePuncts: - def __init__(self, wrds_per_pred=250): + def __init__(self, wrds_per_pred=250, use_cuda=False): self.wrds_per_pred = wrds_per_pred self.overlap_wrds = 30 self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U'] - self.model = NERModel("bert", "felflare/bert-restore-punctuation", labels=self.valid_labels, + self.model = NERModel("bert", "felflare/bert-restore-punctuation", labels=self.valid_labels, use_cuda=use_cuda, args={"silent": True, "max_seq_length": 512}) def punctuate(self, text: str, lang:str=''): From 0731029daf0a13cf70f7441dac4fcba254a846e5 Mon Sep 17 00:00:00 2001 From: Michael Fox Date: Fri, 14 Jan 2022 09:16:31 -0600 Subject: [PATCH 2/4] add silent True/False parameter --- rpunct/punctuate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rpunct/punctuate.py b/rpunct/punctuate.py index 6b335f6b..b83bb48d 100644 --- a/rpunct/punctuate.py +++ b/rpunct/punctuate.py @@ -10,14 +10,14 @@ class RestorePuncts: - def __init__(self, wrds_per_pred=250, use_cuda=False): + def __init__(self, wrds_per_pred=250, use_cuda=False, silent=False): self.wrds_per_pred = wrds_per_pred self.overlap_wrds = 30 self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U'] self.model = NERModel("bert", "felflare/bert-restore-punctuation", labels=self.valid_labels, use_cuda=use_cuda, - args={"silent": True, "max_seq_length": 512}) + args={"silent": silent, "max_seq_length": 512}) - def punctuate(self, text: str, lang:str=''): + def punctuate(self, text: str, lang: str = ''): """ Performs punctuation restoration on arbitrarily large text. Detects if input is not English, if non-English was detected terminates predictions. From 4ba590c29c9a7402c63aee611eebb4ac800e40c1 Mon Sep 17 00:00:00 2001 From: Michael Fox Date: Fri, 14 Jan 2022 10:35:45 -0600 Subject: [PATCH 3/4] add diagnostic prints --- rpunct/punctuate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rpunct/punctuate.py b/rpunct/punctuate.py index b83bb48d..057b23e3 100644 --- a/rpunct/punctuate.py +++ b/rpunct/punctuate.py @@ -38,8 +38,10 @@ def punctuate(self, text: str, lang: str = ''): splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds) # predict slices # full_preds_lst contains tuple of labels and logits + print(f'predicting {len(splits)} slices') full_preds_lst = [self.predict(i['text']) for i in splits] # extract predictions, and discard logits + print(f'combining predictions') preds_lst = [i[0][0] for i in full_preds_lst] # join text slices combined_preds = self.combine_results(text, preds_lst) From f36df08bf411f01e663672db35d908c5137d22d7 Mon Sep 17 00:00:00 2001 From: Michael Fox Date: Thu, 3 Feb 2022 12:13:32 -0600 Subject: [PATCH 4/4] add local model capabilities --- rpunct/punctuate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rpunct/punctuate.py b/rpunct/punctuate.py index 057b23e3..e5f61eb8 100644 --- a/rpunct/punctuate.py +++ b/rpunct/punctuate.py @@ -10,11 +10,11 @@ class RestorePuncts: - def __init__(self, wrds_per_pred=250, use_cuda=False, silent=False): + def __init__(self, model='felflare/bert-restore-punctuation', wrds_per_pred=250, use_cuda=False, silent=False): self.wrds_per_pred = wrds_per_pred self.overlap_wrds = 30 self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U'] - self.model = NERModel("bert", "felflare/bert-restore-punctuation", labels=self.valid_labels, use_cuda=use_cuda, + self.model = NERModel("bert", f"{model}", labels=self.valid_labels, use_cuda=use_cuda, args={"silent": silent, "max_seq_length": 512}) def punctuate(self, text: str, lang: str = ''):