Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,45 @@

Nevise is a Persian spelling-checker developed by Dadmatech Co based on deep learning. Nevise is available in two versions. The second version has greater accuracy, the ability to correct errors based on spaces, and a better understanding of special characters like half space. These versions can be accessed via web services and as demos. We provide public access to the code and model checkpoint of the first version here.

## packages Installation
## Quick Start

Use the package manager [pip](https://pip.pypa.io/en/stable/) to install packages.
1. **Clone the Repository:**
```bash
git clone https://github.com/Dadmatech/Nevise.git
```

```bash
pip install -r requirements.txt
```
## Download model checkpoint and vocab and put them on "model" directory
2. **Install Dependencies:**
```bash
pip install -r Nevise/requirements.txt
```

3. **Download Model Checkpoint and Vocab:**
Create a directory for the model and the vocab:
```bash
mkdir Nevise/model
```

```bash
pip install gdown
mkdir model
cd model
gdown https://drive.google.com/uc?id=1Ki5WGR4yxftDEjROQLf9Br8KHef95k1F
gdown https://drive.google.com/uc?id=1nKeMdDnxIJpOv-OeFj00UnhoChuaY5Ns
```
## run
Download the model and the vocab:
```bash
gdown -O Nevise/model/model.pth.tar "https://drive.google.com/uc?id=1Ki5WGR4yxftDEjROQLf9Br8KHef95k1F"
gdown -O Nevise/model/vocab.pkl "https://drive.google.com/uc?id=1nKeMdDnxIJpOv-OeFj00UnhoChuaY5Ns"
```

In case of download failure, you can obtain the download links from your download manager and use the following commands:
```bash
wget -O Nevise/model/model.pth.tar "[DOWNLOAD URL OF model.pth.tar]"
wget -O Nevise/model/vocab.pkl "[DOWNLOAD URL OF vocab.pkl]"
```

4. **Run Spell-Checking:**
You can then use the model for spelling correction using the following command in which you must replace `"input.txt"` with the path to your input text and `"output.txt"` with the path to the file for writing the corrected text.

```bash
python Nevise/nevise.py --input-file "input.txt" --output-file "output.txt" --vocab-path "Nevise/model/vocab.pkl" --model-checkpoint-path "Nevise/model/model.pth.tar"
```

**NOTE:** For a more expressive output of the model including the pairs of wrong and corrected sentences along with the words that had an error and their corrected form, see `main.py`.

```bash
python main.py
```
# Demo

[Nevise(both versions)](https://dadmatech.ir/#/products/SpellChecker)
Expand Down
161 changes: 161 additions & 0 deletions nevise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import argparse
import os
from tqdm import tqdm
import re
import time
import torch
import utils
from helpers import load_vocab_dict
from helpers import batch_iter, labelize, bert_tokenize_for_valid_examples
from helpers import untokenize_without_unks, untokenize_without_unks2, get_model_nparams
from hazm import Normalizer
from models import SubwordBert
from utils import get_sentences_splitters

def model_inference(model, data, topk, DEVICE, BATCH_SIZE=16, vocab_=None):
    """Run batched spelling-correction inference over *data*.

    Args:
        model: an instance of SubwordBert.
        data: list of tuples, each consisting of a correct and an incorrect
            sentence string (split at whitespace downstream).
        topk: how many of the top softmax predictions to decode (1 -> greedy).
        DEVICE: torch device string the model and batches are moved to.
        BATCH_SIZE: number of sentence pairs per batch.
        vocab_: vocabulary dict used to labelize targets and untokenize
            predictions; required.

    Returns:
        List of dicts, one per predicted line, with keys ``id``, ``original``,
        ``noised``, ``predicted`` plus empty ``topk`` bookkeeping lists.

    Raises:
        ValueError: if no vocabulary is supplied.
        RuntimeError: if the model forward pass fails.
    """
    if vocab_ is None:
        # The original code left ``vocab`` unbound in this case and crashed
        # later with a confusing NameError; fail fast with a clear message.
        raise ValueError("model_inference requires a vocabulary (got vocab_=None)")
    vocab = vocab_
    print("###############################################")
    inference_st_time = time.time()
    VALID_BATCH_SIZE = BATCH_SIZE
    valid_loss = 0.
    print("data size: {}".format(len(data)))
    data_iter = batch_iter(data, batch_size=VALID_BATCH_SIZE, shuffle=False)
    model.eval()
    model.to(DEVICE)
    results = []
    line_index = 0
    num_batches = 0  # guards the final loss averaging when data is empty
    for batch_id, (batch_labels, batch_sentences) in tqdm(enumerate(data_iter)):
        num_batches = batch_id + 1
        torch.cuda.empty_cache()
        # Re-tokenize with BERT; lines whose subword tokenization does not
        # line up with the whitespace split are dropped (reported below).
        batch_labels_, batch_sentences_, batch_bert_inp, batch_bert_splits = bert_tokenize_for_valid_examples(batch_labels, batch_sentences)
        if len(batch_labels_) == 0:
            print("################")
            print("Not predicting the following lines due to pre-processing mismatch: \n")
            print([(a, b) for a, b in zip(batch_labels, batch_sentences)])
            print("################")
            continue
        batch_labels, batch_sentences = batch_labels_, batch_sentences_
        batch_bert_inp = {k: v.to(DEVICE) for k, v in batch_bert_inp.items()}
        # Target label ids for the loss computation.
        batch_labels_ids, batch_lengths = labelize(batch_labels, vocab)
        batch_lengths = batch_lengths.to(DEVICE)
        batch_labels_ids = batch_labels_ids.to(DEVICE)

        try:
            with torch.no_grad():
                # batch_predictions is (batch, max_seq_len, topk) if topk > 1,
                # else (batch, max_seq_len).
                batch_loss, batch_predictions = model(batch_bert_inp, batch_bert_splits, targets=batch_labels_ids, topk=topk)
        except RuntimeError as exc:
            # Keep the diagnostics print, but chain the original exception
            # instead of raising a bare, message-less Exception("").
            print(f"batch_bert_inp:{len(batch_bert_inp.keys())},batch_labels_ids:{batch_labels_ids.shape}")
            raise RuntimeError("model forward pass failed") from exc
        valid_loss += batch_loss
        batch_lengths = batch_lengths.cpu().detach().numpy()
        if topk == 1:
            batch_predictions = untokenize_without_unks(batch_predictions, batch_lengths, vocab, batch_sentences)
        else:
            batch_predictions = untokenize_without_unks2(batch_predictions, batch_lengths, vocab, batch_sentences, topk=None)

        for i, (clean, corrupt, pred) in enumerate(zip(batch_labels, batch_sentences, batch_predictions)):
            results.append({"id": line_index + i, "original": clean, "noised": corrupt, "predicted": pred, "topk": [], "topk_prediction_probs": [], "topk_reranker_losses": []})
        line_index += len(batch_labels)

    if num_batches:
        print(f"\nEpoch {None} valid_loss: {valid_loss / num_batches}")
    print("total inference time for this data is: {:4f} secs".format(time.time() - inference_st_time))
    print("###############################################")
    return results


def load_model(vocab):
    """Construct a SubwordBert model sized to the given vocabulary dict."""
    char_vocab_size = 3 * len(vocab["chartoken2idx"])
    pad_token_idx = vocab["token2idx"][vocab["pad_token"]]
    output_vocab_size = len(vocab["token_freq"])
    model = SubwordBert(char_vocab_size, pad_token_idx, output_vocab_size)
    print(model)
    print(get_model_nparams(model))
    return model


def load_pretrained(model, checkpoint_path, optimizer=None, device='cuda'):
    """Load model (and optionally optimizer) weights from a checkpoint file.

    Args:
        model: torch module whose state dict is populated (strict=False, so
            missing/unexpected keys in the checkpoint are tolerated).
        checkpoint_path: path to a ``torch.save``-d dict containing at least
            ``model_state_dict``.
        optimizer: if given, its state is restored from
            ``optimizer_state_dict`` and training-progress metadata is
            returned alongside the model.
        device: any value other than ``"cpu"`` maps tensors onto CUDA when
            it is available; otherwise everything is loaded to CPU.

    Returns:
        ``model`` when ``optimizer`` is None, else
        ``(model, optimizer, max_dev_acc, argmax_dev_acc)``.
    """
    if torch.cuda.is_available() and device != "cpu":
        map_location = lambda storage, loc: storage.cuda()
    else:
        map_location = 'cpu'
    print(f"Loading model params from checkpoint dir: {checkpoint_path}")
    checkpoint_data = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint_data['model_state_dict'], strict=False)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])
        # Only read training-progress metadata when it is actually returned.
        # The original read these keys unconditionally, raising KeyError on
        # inference-only checkpoints even when no optimizer was passed.
        max_dev_acc = checkpoint_data["max_dev_acc"]
        argmax_dev_acc = checkpoint_data["argmax_dev_acc"]
        return model, optimizer, max_dev_acc, argmax_dev_acc
    return model


def load_pre_model(vocab_path, model_checkpoint_path):
    """Load the vocabulary and pretrained Nevise model.

    Returns a ``(model, vocab, device)`` triple, where ``device`` is
    ``"cuda:0"`` when CUDA is available and ``"cpu"`` otherwise.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print(f"loading vocab from {vocab_path}")
    vocab = load_vocab_dict(vocab_path)
    model = load_pretrained(load_model(vocab), model_checkpoint_path)
    return model, vocab, device


def spell_checking_on_sents(model, vocab, device, normalizer, txt):
    """Spell-check raw text *txt* and return the corrected text.

    The text is split into sentences plus the splitter strings between them
    (so the output can be re-assembled in order), special characters are
    space-padded, and each sentence is run through greedy (topk=1) inference.
    """
    sents, splitters = get_sentences_splitters(txt)
    sents = [utils.space_special_chars(s) for s in sents]
    # NOTE(review): dropping empty/whitespace sentences here can desynchronise
    # `sents` from `splitters` in the zip below, pairing predictions with the
    # wrong splitter — TODO confirm against get_sentences_splitters' contract.
    sents = list(filter(lambda txt: (txt != '' and txt != ' '), sents))
    # The model consumes (label, noisy) pairs; at inference time both sides
    # are the same normalized sentence.
    test_data = [(normalizer.normalize(t), normalizer.normalize(t)) for t in sents]
    greedy_results = model_inference(model, test_data, topk=1, DEVICE=device, BATCH_SIZE=1, vocab_=vocab)

    # Re-interleave predictions with their splitters; a trailing '\n' serves
    # as the splitter for the final sentence.
    corrected_text = ''.join([line['predicted'] + splitter for line, splitter in zip(greedy_results, splitters + ['\n'])])

    return corrected_text


def main():
    """CLI entry point: spell-check an input text file and write or print the result."""
    parser = argparse.ArgumentParser(description="Spell checking script.")
    parser.add_argument("--input-file", required=True, help="Path to the input text file for spell checking.")
    parser.add_argument("--output-file", default=None, help="Path to the output file. If not provided, the result will be printed.")
    parser.add_argument("--vocab-path", default="model/vocab.pkl", help="Path to the vocabulary file.")
    parser.add_argument("--model-checkpoint-path", default="model/model.pth.tar", help="Path to the model checkpoint.")
    args = parser.parse_args()

    with open(args.input_file, 'r', encoding='utf-8') as input_file:
        input_text = input_file.read()

    normalizer = Normalizer()
    # The original wrapped each path in a single-argument os.path.join(),
    # which is a no-op; the parsed arguments are used directly.
    model, vocab, device = load_pre_model(vocab_path=args.vocab_path,
                                          model_checkpoint_path=args.model_checkpoint_path)

    corrected_text = spell_checking_on_sents(model, vocab, device, normalizer, input_text)
    if args.output_file:
        with open(args.output_file, 'w', encoding='utf-8') as output_file:
            output_file.write(corrected_text)
        print("Output successfully written to file!")
    else:
        print(corrected_text)


if __name__ == '__main__':
    main()
12 changes: 6 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
clean-text==0.3.0
hazm==0.7.0
numpy<1.20
torch==1.7.1
tqdm==4.46.1
transformers==4.3.3
clean-text==0.6.0
hazm==0.10.0
numpy==1.24.3
torch
tqdm==4.66.2
transformers==4.38.2