diff --git a/rpunct/punctuate.py b/rpunct/punctuate.py index c0143c08..17d43b35 100644 --- a/rpunct/punctuate.py +++ b/rpunct/punctuate.py @@ -7,6 +7,7 @@ import logging from langdetect import detect from simpletransformers.ner import NERModel +from torch.cuda import is_available class RestorePuncts: @@ -15,7 +16,7 @@ def __init__(self, wrds_per_pred=250): self.overlap_wrds = 30 self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U'] self.model = NERModel("bert", "felflare/bert-restore-punctuation", labels=self.valid_labels, - args={"silent": True, "max_seq_length": 512}) + args={"silent": True, "max_seq_length": 512}, has_cuda=is_available()) def punctuate(self, text: str, lang:str=''): """