From 035aa7e577f98f0590b9b8eaaddada2e51ec06cc Mon Sep 17 00:00:00 2001 From: Error <65428515+ErrorBot1122@users.noreply.github.com> Date: Tue, 1 Nov 2022 14:22:09 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20Changed=20GPU=20Requirements?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added compatibility for non-Cuda devices --- rpunct/punctuate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rpunct/punctuate.py b/rpunct/punctuate.py index c0143c08..17d43b35 100644 --- a/rpunct/punctuate.py +++ b/rpunct/punctuate.py @@ -7,6 +7,7 @@ import logging from langdetect import detect from simpletransformers.ner import NERModel +from torch.cuda import is_available class RestorePuncts: @@ -15,7 +16,7 @@ def __init__(self, wrds_per_pred=250): self.overlap_wrds = 30 self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U'] self.model = NERModel("bert", "felflare/bert-restore-punctuation", labels=self.valid_labels, - args={"silent": True, "max_seq_length": 512}) + args={"silent": True, "max_seq_length": 512}, has_cuda=is_available()) def punctuate(self, text: str, lang:str=''): """