From 043a3c9990a6225ccdbfcf42dc750adf91355b2a Mon Sep 17 00:00:00 2001 From: Debjyoti Mondal Date: Wed, 12 Jun 2024 19:18:25 +0530 Subject: [PATCH 1/2] Added resume ability. This PR allows users to resume training their model from an existing checkpoint: if the save directory already contains a "checkpoint-*" entry, training is resumed from a checkpoint; otherwise training starts from scratch. --- main.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 52f45e0..7744fdc 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import numpy as np import torch import os +import pathlib import re import json import argparse @@ -266,7 +267,10 @@ def compute_metrics_rougel(eval_preds): ) if args.evaluate_dir is None: - trainer.train() + if list(pathlib.Path(save_dir).glob("checkpoint-*")) : + trainer.train(resume_from_checkpoint = True) + else : + train.train() trainer.save_model(save_dir) metrics = trainer.evaluate(eval_dataset = test_set, max_length=args.output_len) From d6cc3c4047ffd82ad02f0468bd4389ed508d4d33 Mon Sep 17 00:00:00 2001 From: Debjyoti Mondal Date: Tue, 10 Sep 2024 21:42:52 +0530 Subject: [PATCH 2/2] Update main.py Follow-up to the added resume ability: fix a typo in the fallback (no checkpoint) branch so that it calls trainer.train() instead of train.train(). --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 7744fdc..2febe3a 100644 --- a/main.py +++ b/main.py @@ -270,7 +270,7 @@ def compute_metrics_rougel(eval_preds): if list(pathlib.Path(save_dir).glob("checkpoint-*")) : trainer.train(resume_from_checkpoint = True) else : - train.train() + trainer.train() trainer.save_model(save_dir) metrics = trainer.evaluate(eval_dataset = test_set, max_length=args.output_len)