diff --git a/libcity/executor/contra_mlm_executor.py b/libcity/executor/contra_mlm_executor.py index 1fd1c75..9d2839c 100644 --- a/libcity/executor/contra_mlm_executor.py +++ b/libcity/executor/contra_mlm_executor.py @@ -39,11 +39,74 @@ def train(self, train_dataloader, eval_dataloader, test_dataloader=None): eval_loss = [] eval_acc = [] lr_list = [] + start_epoch = 0 + + + exp_id = self.config["exp_id"] + # Start Checking + last_best_epoch_id = -1 + if exp_id is not None: + self._logger.info('Check if there exist trained model') + """ + 首先检查是否存在已有的模型文件 + 有的话加载一个最好的模型,并eval验证 + """ + exp_id = self.config["exp_id"] + directory = self.cache_dir + fileNameList = [] + pattern_template = r'^{}_{}_epoch(\d+)\.tar$' + epoches_check = [] + pattern = pattern_template.format(self.config['model'], self.config['dataset']) + for _, _, files in os.walk(directory): + for fileName in files: + fileNameList.append(fileName) + for fileName in fileNameList: + match = re.match(pattern, fileName) + if match: + epoches_check.append(int(match.group(1))) + + if len(epoches_check) == 0: + self._logger.info("There exists no trained model") + else: + self._logger.info("Trained model with epoches:{}".format(epoches_check)) + # eval这些模型 + self._logger.info("Start evaluating trained models") + trained_eval_time = [] + trained_eval_loss = [] + trained_eval_acc = [] + trained_min_val_loss = float("inf") + + for epoch_idx in epoches_check: + self.load_model_with_epoch(epoch_idx) + t2 = time.time() + trained_eval_avg_loss, trained_eval_avg_acc = self._valid_epoch(eval_dataloader, epoch_idx, + mode='Eval') + end_time = time.time() + trained_eval_time.append(end_time - t2) + trained_eval_loss.append(trained_eval_avg_loss) + trained_eval_acc.append(trained_eval_avg_acc) + + if trained_eval_avg_loss < trained_min_val_loss: + self._logger.info("Trained Models : decrease from {:.4f} to {:.4f},epoch from {} to {}". \ + format(trained_min_val_loss, trained_eval_avg_loss, last_best_epoch_id,epoch_idx)) + trained_min_val_loss = trained_eval_avg_loss + last_best_epoch_id = epoch_idx + else: + self._logger.warning("Trained Models : last best epoch is {}".format(last_best_epoch_id)) + best_epoch = last_best_epoch_id + min_val_loss = trained_min_val_loss + eval_time.append(trained_eval_time) + eval_loss.append(trained_eval_loss) + eval_acc.append(trained_eval_acc) + max_epoch = max(epoches_check) + start_epoch = max_epoch + 1 + self.load_model_with_epoch(best_epoch) + num_batches = len(train_dataloader) self._logger.info("Num_batches: train={}, eval={}".format(num_batches, len(eval_dataloader))) - for epoch_idx in range(self.epochs): + for epoch_idx in range(start_epoch, self.epochs): start_time = time.time() train_avg_loss, train_avg_acc = self._train_epoch(train_dataloader, epoch_idx) t1 = time.time()