diff --git a/config/crs/kgsf/redial.yaml b/config/crs/kgsf/redial.yaml index b6c1de0..05682bd 100644 --- a/config/crs/kgsf/redial.yaml +++ b/config/crs/kgsf/redial.yaml @@ -31,12 +31,14 @@ pretrain: rec: epoch: 9 batch_size: 128 + test_print_every_batch: true optimizer: name: Adam lr: !!float 1e-3 conv: epoch: 90 batch_size: 128 + test_print_every_batch: true optimizer: name: Adam lr: !!float 1e-3 diff --git a/crslab/config/__init__.py b/crslab/config/__init__.py index f7556ef..740da5c 100644 --- a/crslab/config/__init__.py +++ b/crslab/config/__init__.py @@ -30,3 +30,4 @@ MODEL_PATH = os.path.join(DATA_PATH, 'model') PRETRAIN_PATH = os.path.join(MODEL_PATH, 'pretrain') EMBEDDING_PATH = os.path.join(DATA_PATH, 'embedding') +CSV_PATH = os.path.join(ROOT_PATH, 'test_result') diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index ca7a27b..97450ec 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -82,7 +82,7 @@ def get_dataset(opt, tokenize, restore, save) -> BaseDataset: if dataset in dataset_register_table: return dataset_register_table[dataset](opt, tokenize, restore, save) else: - raise NotImplementedError(f'The dataloader [{dataset}] has not been implemented') + raise NotImplementedError(f'The dataset [{dataset}] has not been implemented') def get_dataloader(opt, dataset, vocab) -> BaseDataLoader: diff --git a/crslab/data/dataloader/base.py b/crslab/data/dataloader/base.py index c98479a..fa7147b 100644 --- a/crslab/data/dataloader/base.py +++ b/crslab/data/dataloader/base.py @@ -23,7 +23,7 @@ class BaseDataLoader(ABC): """ - def __init__(self, opt, dataset): + def __init__(self, opt, dataset, vocab=None): """ Args: opt (Config or dict): config for dataloader or the whole system. @@ -32,10 +32,11 @@ def __init__(self, opt, dataset): """ self.opt = opt self.dataset = dataset + self.vocab = vocab self.scale = opt.get('scale', 1) assert 0 < self.scale <= 1 - def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None): + def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None, file=None): """Collate batch data for system to fit Args: @@ -43,6 +44,7 @@ def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None): batch_size (int): shuffle (bool, optional): Defaults to True. process_fn (func, optional): function to process dataset before batchify. Defaults to None. + file (file descriptor, optional): file descriptor where to print input data for each batch. Yields: tuple or dict of torch.Tensor: batch data for system to fit @@ -63,9 +65,30 @@ def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None): for start_idx in tqdm(range(batch_num)): batch_idx = idx_list[start_idx * batch_size: (start_idx + 1) * batch_size] batch = [dataset[idx] for idx in batch_idx] + + if file: + for conv_dict in batch: + file.write('"') + for sentence_in_index in conv_dict['context_tokens']: + sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index]) + file.write(f'{sentence}\n') + file.write('"\t') + + file.write('"') + entities = "\n".join( + [self.vocab['id2entity'][entity_index] for entity_index in conv_dict['context_entities']] + ) + file.write(f'{entities}"\t') + + file.write('"') + words = "\n".join( + [self.vocab['id2word'][word_index] for word_index in conv_dict['context_words']] + ) + file.write(f'{words}"\t') + yield batch_fn(batch) - def get_conv_data(self, batch_size, shuffle=True): + def get_conv_data(self, batch_size, shuffle=True, file=None): """get_data wrapper for conversation. You can implement your own process_fn in ``conv_process_fn``, batch_fn in ``conv_batchify``. @@ -78,9 +101,9 @@ def get_conv_data(self, batch_size, shuffle=True): tuple or dict of torch.Tensor: batch data for conversation. """ - return self.get_data(self.conv_batchify, batch_size, shuffle, self.conv_process_fn) + return self.get_data(self.conv_batchify, batch_size, shuffle, self.conv_process_fn, file) - def get_rec_data(self, batch_size, shuffle=True): + def get_rec_data(self, batch_size, shuffle=True, file=None): """get_data wrapper for recommendation. You can implement your own process_fn in ``rec_process_fn``, batch_fn in ``rec_batchify``. @@ -93,7 +116,7 @@ def get_rec_data(self, batch_size, shuffle=True): tuple or dict of torch.Tensor: batch data for recommendation. """ - return self.get_data(self.rec_batchify, batch_size, shuffle, self.rec_process_fn) + return self.get_data(self.rec_batchify, batch_size, shuffle, self.rec_process_fn, file) def get_policy_data(self, batch_size, shuffle=True): """get_data wrapper for policy. diff --git a/crslab/data/dataloader/kgsf.py b/crslab/data/dataloader/kgsf.py index 6bbcac4..f69c9ba 100644 --- a/crslab/data/dataloader/kgsf.py +++ b/crslab/data/dataloader/kgsf.py @@ -50,7 +50,7 @@ def __init__(self, opt, dataset, vocab): vocab (dict): all kinds of useful size, idx and map between token and idx. """ - super().__init__(opt, dataset) + super().__init__(opt, dataset, vocab) self.n_entity = vocab['n_entity'] self.pad_token_idx = vocab['pad'] self.start_token_idx = vocab['start'] diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index cb6e47b..c84565b 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -83,6 +83,7 @@ def _load_data(self): 'entity2id': self.entity2id, 'id2entity': self.id2entity, 'word2id': self.word2id, + 'id2word': self.id2word, 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, @@ -127,6 +128,7 @@ def _load_other_data(self): # conceptNet # {concept: concept_id} self.word2id = json.load(open(os.path.join(self.dpath, 'concept2id.json'), 'r', encoding='utf-8')) + self.id2word = {idx: word for word, idx in self.word2id.items()} self.n_word = max(self.word2id.values()) + 1 # {relation\t concept \t concept} self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), 'r', encoding='utf-8') diff --git a/crslab/evaluator/standard.py b/crslab/evaluator/standard.py index 7341aba..f7f43dc 100644 --- a/crslab/evaluator/standard.py +++ b/crslab/evaluator/standard.py @@ -87,15 +87,23 @@ def gen_evaluate(self, hyp, refs): self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) - def report(self, epoch=-1, mode='test'): + def report(self, epoch=-1, mode='test', file=None): for k, v in self.dist_set.items(): self.gen_metrics.add(k, AverageMetric(len(v) / self.dist_cnt)) reports = [self.rec_metrics.report(), self.gen_metrics.report(), self.optim_metrics.report()] + if self.tensorboard and mode != 'test': for idx, task_report in enumerate(reports): for each_metric, value in task_report.items(): self.writer.add_scalars(f'{self.reports_name[idx]}/{each_metric}', {mode: value.value()}, epoch) - logger.info('\n' + nice_report(aggregate_unnamed_reports(reports))) + + if file: + for each_metric, value in self.rec_metrics.report().items(): + file.write(f'{str(value.value())}\t') + for each_metric, value in self.gen_metrics.report().items(): + file.write(f'{str(value.value())}\t') + else: + logger.info('\n' + nice_report(aggregate_unnamed_reports(reports))) def reset_metrics(self): # rec diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index 9181271..02ef3cf 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -14,7 +14,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, restore_system=False, - interact=False, debug=False, tensorboard=False): + interact=False, test=False, debug=False, tensorboard=False): """A fast running api, which includes the complete process of training and testing models on specified datasets. Args: @@ -26,6 +26,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r save_system (bool): whether to save system. Defaults to False. restore_system (bool): whether to restore system. Defaults to False. interact (bool): whether to interact with the system. Defaults to False. + test (bool): whether to test with the saved system. Defaults to False. debug (bool): whether to debug the system. Defaults to False. .. _Github repo: @@ -64,11 +65,18 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r train_dataloader[task] = get_dataloader(config, train_data, vocab[task]) valid_dataloader[task] = get_dataloader(config, valid_data, vocab[task]) test_dataloader[task] = get_dataloader(config, test_data, vocab[task]) + + # test need saved system + if test: + restore_system = True + # system CRS = get_system(config, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system, interact, debug, tensorboard) if interact: CRS.interact() + elif test: + CRS.test() else: CRS.fit() if save_system: diff --git a/crslab/system/base.py b/crslab/system/base.py index 7ee41e0..9135103 100644 --- a/crslab/system/base.py +++ b/crslab/system/base.py @@ -361,3 +361,6 @@ def link(self, tokens, entities): if entity: linked_entities.append(entity[0]) return linked_entities + + def test(self): + raise NotImplementedError('Method test is not implemented.') diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 7f7b2a6..5e8ed4e 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -12,6 +12,7 @@ import torch from loguru import logger +from crslab.config import CSV_PATH from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem @@ -55,6 +56,8 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc self.rec_batch_size = self.rec_optim_opt['batch_size'] self.conv_batch_size = self.conv_optim_opt['batch_size'] + self.csv_path = os.path.join(CSV_PATH, "kgsf", opt['dataset']) + def rec_evaluate(self, rec_predict, item_label): rec_predict = rec_predict.cpu() rec_predict = rec_predict[:, self.item_ids] @@ -65,15 +68,17 @@ def rec_evaluate(self, rec_predict, item_label): item = self.item_ids.index(item) self.evaluator.rec_evaluate(rec_rank, item) - def conv_evaluate(self, prediction, response): + def conv_evaluate(self, prediction, response, file=None): prediction = prediction.tolist() response = response.tolist() for p, r in zip(prediction, response): p_str = ind2txt(p, self.ind2tok, self.end_token_idx) r_str = ind2txt(r, self.ind2tok, self.end_token_idx) + if file: + file.write(f'{p_str}\t{r_str}\t') self.evaluator.gen_evaluate(p_str, [r_str]) - def step(self, batch, stage, mode): + def step(self, batch, stage, mode, file=None): batch = [ele.to(self.device) for ele in batch] if stage == 'pretrain': info_loss = self.model.forward(batch, stage, mode) @@ -108,7 +113,7 @@ def step(self, batch, stage, mode): self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss)) else: pred = self.model.forward(batch, stage, mode) - self.conv_evaluate(pred, batch[-1]) + self.conv_evaluate(pred, batch[-1], file) else: raise @@ -180,10 +185,68 @@ def train_conversation(self): self.step(batch, stage='conv', mode='test') self.evaluator.report(mode='test') + def test_recommendation(self): + self.init_optim(self.rec_optim_opt, self.model.parameters()) + + logger.info('[Recommendation Test]') + with torch.no_grad(): + if self.rec_optim_opt.get('test_print_every_batch'): + rec_test_result_file_name = os.path.join(self.csv_path, 'rec.csv') + os.makedirs(os.path.dirname(rec_test_result_file_name), exist_ok=True) + with open(rec_test_result_file_name, 'w', encoding='utf-8', newline='') as f: + f.write('input context\tentities\twords\thit@1\tndcg@1\tmrr@1\t' + 'hit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n') + logger.info(f"[Write {rec_test_result_file_name}]") + + for batch in self.test_dataloader.get_rec_data(1, shuffle=False, file=f): + self.evaluator.reset_metrics() + self.step(batch, stage='rec', mode='test') + self.evaluator.report(mode='test', file=f) + f.write('\n') + f.close() + else: + self.evaluator.reset_metrics() + for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False): + self.step(batch, stage='rec', mode='test') + self.evaluator.report(mode='test') + + def test_conversation(self): + if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': + self.model.freeze_parameters() + else: + self.model.module.freeze_parameters() + self.init_optim(self.conv_optim_opt, self.model.parameters()) + + logger.info('[Conversation Test]') + with torch.no_grad(): + if self.conv_optim_opt.get('test_print_every_batch'): + conv_test_result_file_name = os.path.join(self.csv_path, 'conv.csv') + os.makedirs(os.path.dirname(conv_test_result_file_name), exist_ok=True) + with open(conv_test_result_file_name, 'w', encoding='utf-8', newline='') as f: + f.write('input context\tentities\twords\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\t' + 'greedy\taverage\textreme\tdist@1\tdist@2\tdist@3\tdist@4\n') + logger.info(f"[Write {conv_test_result_file_name}]") + + for batch in self.test_dataloader.get_conv_data(1, shuffle=False, file=f): + self.evaluator.reset_metrics() + self.step(batch, stage='conv', mode='test', file=f) + self.evaluator.report(mode='test', file=f) + f.write('\n') + f.close() + else: + self.evaluator.reset_metrics() + for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False): + self.step(batch, stage='conv', mode='test') + self.evaluator.report(mode='test') + def fit(self): self.pretrain() self.train_recommender() self.train_conversation() + def test(self): + self.test_recommendation() + self.test_conversation() + def interact(self): pass diff --git a/run_crslab.py b/run_crslab.py index b0f7f73..d2e9ac9 100644 --- a/run_crslab.py +++ b/run_crslab.py @@ -33,6 +33,8 @@ help='use valid dataset to debug your system') parser.add_argument('-i', '--interact', action='store_true', help='interact with your system instead of training') + parser.add_argument('-t', '--test', action='store_true', + help='test with your saved system instead of training. Note: saved system is required') parser.add_argument('-tb', '--tensorboard', action='store_true', help='enable tensorboard to monitor train performance') args, _ = parser.parse_known_args() @@ -41,4 +43,4 @@ from crslab.quick_start import run_crslab run_crslab(config, args.save_data, args.restore_data, args.save_system, args.restore_system, args.interact, - args.debug, args.tensorboard) + args.test, args.debug, args.tensorboard)