Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion libcity/executor/contra_mlm_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,74 @@ def train(self, train_dataloader, eval_dataloader, test_dataloader=None):
eval_loss = []
eval_acc = []
lr_list = []
start_epoch = 0


exp_id = self.config["exp_id"]
# Check the cache directory for previously saved checkpoints so training can resume
last_best_epoch_id = -1
if exp_id is not None:
self._logger.info('Check if there exist trained model')
"""
First check whether previously saved model files already exist.
If any exist, evaluate them and load the best one.
"""
exp_id = self.config["exp_id"]
directory = self.cache_dir
fileNameList = []
pattern_template = r'^{}_{}_epoch(\d+)\.tar$'
epoches_check = []
pattern = pattern_template.format(self.config['model'], self.config['dataset'])
for _, _, files in os.walk(directory):
for fileName in files:
fileNameList.append(fileName)
for fileName in fileNameList:
match = re.match(pattern, fileName)
if match:
epoches_check.append(int(match.group(1)))

if len(epoches_check) == 0:
self._logger.info("There exists no trained model")
else:
self._logger.info("Trained model with epoches:{}".format(epoches_check))
# Evaluate these previously trained models
self._logger.info("Start evaluating trained models")
trained_eval_time = []
trained_eval_loss = []
trained_eval_acc = []
trained_min_val_loss = float("inf")

for epoch_idx in epoches_check:
self.load_model_with_epoch(epoch_idx)
t2 = time.time()
trained_eval_avg_loss, trained_eval_avg_acc = self._valid_epoch(eval_dataloader, epoch_idx,
mode='Eval')
end_time = time.time()
trained_eval_time.append(end_time - t2)
trained_eval_loss.append(trained_eval_avg_loss)
trained_eval_acc.append(trained_eval_avg_acc)

if trained_eval_avg_loss < trained_min_val_loss:
self._logger.info("Trained Models : decrease from {:.4f} to {:.4f},epoch from {} to {}". \
format(trained_min_val_loss, trained_eval_avg_loss, last_best_epoch_id,epoch_idx))
trained_min_val_loss = trained_eval_avg_loss
last_best_epoch_id = epoch_idx
else:
self._logger.warning("Trained Models : last best epoch is {}".format(last_best_epoch_id))
best_epoch = last_best_epoch_id
min_val_loss = trained_min_val_loss
eval_time.append(trained_eval_time)
eval_loss.append(trained_eval_loss)
eval_acc.append(trained_eval_acc)
max_epoch = max(epoches_check)
start_epoch = max_epoch + 1
self.load_model_with_epoch(best_epoch)


num_batches = len(train_dataloader)
self._logger.info("Num_batches: train={}, eval={}".format(num_batches, len(eval_dataloader)))

for epoch_idx in range(self.epochs):
for epoch_idx in range(start_epoch, self.epochs):
start_time = time.time()
train_avg_loss, train_avg_acc = self._train_epoch(train_dataloader, epoch_idx)
t1 = time.time()
Expand Down