From 6c9c308ec556df50be09dfef9c1d01409b55a67f Mon Sep 17 00:00:00 2001 From: krta2 Date: Sat, 13 Nov 2021 00:56:19 +0900 Subject: [PATCH 01/23] Fix typo --- crslab/data/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crslab/data/__init__.py b/crslab/data/__init__.py index ca7a27b..97450ec 100644 --- a/crslab/data/__init__.py +++ b/crslab/data/__init__.py @@ -82,7 +82,7 @@ def get_dataset(opt, tokenize, restore, save) -> BaseDataset: if dataset in dataset_register_table: return dataset_register_table[dataset](opt, tokenize, restore, save) else: - raise NotImplementedError(f'The dataloader [{dataset}] has not been implemented') + raise NotImplementedError(f'The dataset [{dataset}] has not been implemented') def get_dataloader(opt, dataset, vocab) -> BaseDataLoader: From 38d0b40b6a8539b4a061b895d221dc72bf82411c Mon Sep 17 00:00:00 2001 From: krta2 Date: Sat, 13 Nov 2021 00:59:10 +0900 Subject: [PATCH 02/23] Update kgsf/redial config Add test_print_every_batch option --- config/crs/kgsf/redial.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/crs/kgsf/redial.yaml b/config/crs/kgsf/redial.yaml index b6c1de0..71d64c3 100644 --- a/config/crs/kgsf/redial.yaml +++ b/config/crs/kgsf/redial.yaml @@ -31,6 +31,7 @@ pretrain: rec: epoch: 9 batch_size: 128 + test_print_every_batch: true optimizer: name: Adam lr: !!float 1e-3 From 2eef4472b9a9a8fb5da6bd768fdfa69d8e6919a7 Mon Sep 17 00:00:00 2001 From: krta2 Date: Sat, 13 Nov 2021 01:01:59 +0900 Subject: [PATCH 03/23] Add vocab for base dataloader init parameter --- crslab/data/dataloader/kgsf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crslab/data/dataloader/kgsf.py b/crslab/data/dataloader/kgsf.py index 6bbcac4..f69c9ba 100644 --- a/crslab/data/dataloader/kgsf.py +++ b/crslab/data/dataloader/kgsf.py @@ -50,7 +50,7 @@ def __init__(self, opt, dataset, vocab): vocab (dict): all kinds of useful size, idx and map between token and idx. """ - super().__init__(opt, dataset) + super().__init__(opt, dataset, vocab) self.n_entity = vocab['n_entity'] self.pad_token_idx = vocab['pad'] self.start_token_idx = vocab['start'] From 62fdd7200b15d6e56018d3a32eff927ef9802073 Mon Sep 17 00:00:00 2001 From: krta2 Date: Sat, 13 Nov 2021 01:03:51 +0900 Subject: [PATCH 04/23] Update BaseDataLoader - add print every batch feature --- crslab/data/dataloader/base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/crslab/data/dataloader/base.py b/crslab/data/dataloader/base.py index c98479a..ea6950c 100644 --- a/crslab/data/dataloader/base.py +++ b/crslab/data/dataloader/base.py @@ -23,7 +23,7 @@ class BaseDataLoader(ABC): """ - def __init__(self, opt, dataset): + def __init__(self, opt, dataset, vocab=None): """ Args: opt (Config or dict): config for dataloader or the whole system. @@ -32,6 +32,7 @@ def __init__(self, opt, dataset): """ self.opt = opt self.dataset = dataset + self.vocab = vocab self.scale = opt.get('scale', 1) assert 0 < self.scale <= 1 @@ -63,6 +64,11 @@ def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None): for start_idx in tqdm(range(batch_num)): batch_idx = idx_list[start_idx * batch_size: (start_idx + 1) * batch_size] batch = [dataset[idx] for idx in batch_idx] + if batch_fn == self.rec_batchify and self.opt['rec'].get('test_print_every_batch'): + for conv_dict in batch: + for sentence_in_index in conv_dict['context_tokens']: + sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index]) + logger.info(sentence) yield batch_fn(batch) def get_conv_data(self, batch_size, shuffle=True): From 4ec605554a657655b852d94a140d3ed95d97e9c3 Mon Sep 17 00:00:00 2001 From: krta2 Date: Sat, 13 Nov 2021 01:04:55 +0900 Subject: [PATCH 05/23] Update KGSFSystem - add print every batch in rec test --- crslab/system/kgsf.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 7f7b2a6..11ca5bf 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -148,8 +148,14 @@ def train_recommender(self): with torch.no_grad(): self.evaluator.reset_metrics() for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False): + if self.rec_optim_opt.get('test_print_every_batch'): + self.evaluator.reset_metrics() + # logger.info(batch) self.step(batch, stage='rec', mode='test') - self.evaluator.report(mode='test') + if self.rec_optim_opt.get('test_print_every_batch'): + self.evaluator.report(mode='test') + if self.rec_optim_opt.get('test_print_every_batch'): + self.evaluator.report(mode='test') def train_conversation(self): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': From 4eb6610a5f0c52dc412d612d188510eef2085735 Mon Sep 17 00:00:00 2001 From: krta2 Date: Sat, 13 Nov 2021 01:05:43 +0900 Subject: [PATCH 06/23] Fix error --- crslab/system/kgsf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 11ca5bf..5cebb3c 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -154,7 +154,7 @@ def train_recommender(self): self.step(batch, stage='rec', mode='test') if self.rec_optim_opt.get('test_print_every_batch'): self.evaluator.report(mode='test') - if self.rec_optim_opt.get('test_print_every_batch'): + if not self.rec_optim_opt.get('test_print_every_batch'): self.evaluator.report(mode='test') def train_conversation(self): From a0dc8a032470aec18c0e9f646b31d53605e1406c Mon Sep 17 00:00:00 2001 From: Woohyun Lee Date: Tue, 16 Nov 2021 15:28:39 +0900 Subject: [PATCH 07/23] Test with saved model (#3) * Update KGSFSystem - add test method * Update QuickStart - add test for restore_system option Co-authored-by: airotod --- crslab/quick_start/quick_start.py | 2 ++ crslab/system/kgsf.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index 9181271..81d9174 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -69,6 +69,8 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r interact, debug, tensorboard) if interact: CRS.interact() + elif restore_system: + CRS.test() else: CRS.fit() if save_system: diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 5cebb3c..0a604eb 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -193,3 +193,29 @@ def fit(self): def interact(self): pass + + def test_recommendation(self): + logger.info('[Recommendation Test]') + self.init_optim(self.rec_optim_opt, self.model.parameters()) + with torch.no_grad(): + self.evaluator.reset_metrics() + for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False): + self.step(batch, stage='rec', mode='test') + self.evaluator.report(mode='test') + + def test_conversation(self): + logger.info('[Conversation Test]') + if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': + self.model.freeze_parameters() + else: + self.model.module.freeze_parameters() + self.init_optim(self.conv_optim_opt, self.model.parameters()) + with torch.no_grad(): + self.evaluator.reset_metrics() + for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False): + self.step(batch, stage='conv', mode='test') + self.evaluator.report(mode='test') + + def test(self): + self.test_recommendation() + self.test_conversation() \ No newline at end of file From a9e04845f83a352b30ff3bcc4633628306f52712 Mon Sep 17 00:00:00 2001 From: Woohyun Lee Date: Tue, 16 Nov 2021 15:29:06 +0900 Subject: [PATCH 08/23] Add print every batch conv (#4) * Update kgsf/redial config - add test_print_every_batch option for conv * Update BaseDataLoader - add print every batch feature of conv * Update KGSFSystem - add print every batch in conv test - add print conv response and prediction Co-authored-by: airotod --- config/crs/kgsf/redial.yaml | 1 + crslab/data/dataloader/base.py | 2 +- crslab/system/kgsf.py | 9 ++++++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/config/crs/kgsf/redial.yaml b/config/crs/kgsf/redial.yaml index 71d64c3..05682bd 100644 --- a/config/crs/kgsf/redial.yaml +++ b/config/crs/kgsf/redial.yaml @@ -38,6 +38,7 @@ rec: conv: epoch: 90 batch_size: 128 + test_print_every_batch: true optimizer: name: Adam lr: !!float 1e-3 diff --git a/crslab/data/dataloader/base.py b/crslab/data/dataloader/base.py index ea6950c..933555e 100644 --- a/crslab/data/dataloader/base.py +++ b/crslab/data/dataloader/base.py @@ -64,7 +64,7 @@ def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None): for start_idx in tqdm(range(batch_num)): batch_idx = idx_list[start_idx * batch_size: (start_idx + 1) * batch_size] batch = [dataset[idx] for idx in batch_idx] - if batch_fn == self.rec_batchify and self.opt['rec'].get('test_print_every_batch'): + if (batch_fn == self.rec_batchify and self.opt['rec'].get('test_print_every_batch')) or (batch_fn == self.conv_batchify and self.opt['conv'].get('test_print_every_batch')): for conv_dict in batch: for sentence_in_index in conv_dict['context_tokens']: sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index]) diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 0a604eb..c5efaad 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -71,6 +71,7 @@ def conv_evaluate(self, prediction, response): for p, r in zip(prediction, response): p_str = ind2txt(p, self.ind2tok, self.end_token_idx) r_str = ind2txt(r, self.ind2tok, self.end_token_idx) + logger.info(f'\n prediction: {p_str}\n response: {r_str}') self.evaluator.gen_evaluate(p_str, [r_str]) def step(self, batch, stage, mode): @@ -183,8 +184,14 @@ def train_conversation(self): with torch.no_grad(): self.evaluator.reset_metrics() for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False): + if self.conv_optim_opt.get('test_print_every_batch'): + self.evaluator.reset_metrics() + # logger.info(batch) self.step(batch, stage='conv', mode='test') - self.evaluator.report(mode='test') + if self.conv_optim_opt.get('test_print_every_batch'): + self.evaluator.report(mode='test') + if not self.conv_optim_opt.get('test_print_every_batch'): + self.evaluator.report(mode='test') def fit(self): self.pretrain() From 03e317ae8bfee1606507c95efc136e3dfd65c8b9 Mon Sep 17 00:00:00 2001 From: krta2 Date: Tue, 16 Nov 2021 15:59:42 +0900 Subject: [PATCH 09/23] Add --test argument --- crslab/quick_start/quick_start.py | 8 ++++++-- crslab/system/base.py | 3 +++ run_crslab.py | 4 +++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index 81d9174..2b0d4d4 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -14,7 +14,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, restore_system=False, - interact=False, debug=False, tensorboard=False): + interact=False, test=False, debug=False, tensorboard=False): """A fast running api, which includes the complete process of training and testing models on specified datasets. Args: @@ -26,6 +26,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r save_system (bool): whether to save system. Defaults to False. restore_system (bool): whether to restore system. Defaults to False. interact (bool): whether to interact with the system. Defaults to False. + test (bool): wheter to test with the saved system. Defaults to False. debug (bool): whether to debug the system. Defaults to False. .. _Github repo: @@ -69,7 +70,10 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r interact, debug, tensorboard) if interact: CRS.interact() - elif restore_system: + elif test: + if not restore_system: + print('Need to restore saved model by argument --restore_system or -rs for test.') + return CRS.test() else: CRS.fit() diff --git a/crslab/system/base.py b/crslab/system/base.py index 7ee41e0..9135103 100644 --- a/crslab/system/base.py +++ b/crslab/system/base.py @@ -361,3 +361,6 @@ def link(self, tokens, entities): if entity: linked_entities.append(entity[0]) return linked_entities + + def test(self): + raise NotImplementedError('Method test is not implemented.') diff --git a/run_crslab.py b/run_crslab.py index b0f7f73..e99a3f6 100644 --- a/run_crslab.py +++ b/run_crslab.py @@ -33,6 +33,8 @@ help='use valid dataset to debug your system') parser.add_argument('-i', '--interact', action='store_true', help='interact with your system instead of training') + parser.add_argument('-t', '--test', action='store_true', + help='test with your saved system instead of training. Note: argument --resotre_system or -rs required') parser.add_argument('-tb', '--tensorboard', action='store_true', help='enable tensorboard to monitor train performance') args, _ = parser.parse_known_args() @@ -41,4 +43,4 @@ from crslab.quick_start import run_crslab run_crslab(config, args.save_data, args.restore_data, args.save_system, args.restore_system, args.interact, - args.debug, args.tensorboard) + args.test, args.debug, args.tensorboard) From 549c7e78fc4718866da67815c3d7ba85c5bc0c89 Mon Sep 17 00:00:00 2001 From: krta2 Date: Tue, 16 Nov 2021 16:00:20 +0900 Subject: [PATCH 10/23] Update KGSFSystem - remove print every batch feature in train --- crslab/system/kgsf.py | 56 +++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index c5efaad..7585a52 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -149,14 +149,8 @@ def train_recommender(self): with torch.no_grad(): self.evaluator.reset_metrics() for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False): - if self.rec_optim_opt.get('test_print_every_batch'): - self.evaluator.reset_metrics() - # logger.info(batch) self.step(batch, stage='rec', mode='test') - if self.rec_optim_opt.get('test_print_every_batch'): - self.evaluator.report(mode='test') - if not self.rec_optim_opt.get('test_print_every_batch'): - self.evaluator.report(mode='test') + self.evaluator.report(mode='test') def train_conversation(self): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': @@ -184,45 +178,51 @@ def train_conversation(self): with torch.no_grad(): self.evaluator.reset_metrics() for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False): - if self.conv_optim_opt.get('test_print_every_batch'): - self.evaluator.reset_metrics() - # logger.info(batch) self.step(batch, stage='conv', mode='test') - if self.conv_optim_opt.get('test_print_every_batch'): - self.evaluator.report(mode='test') - if not self.conv_optim_opt.get('test_print_every_batch'): - self.evaluator.report(mode='test') - - def fit(self): - self.pretrain() - self.train_recommender() - self.train_conversation() - - def interact(self): - pass + self.evaluator.report(mode='test') def test_recommendation(self): - logger.info('[Recommendation Test]') self.init_optim(self.rec_optim_opt, self.model.parameters()) + + logger.info('[Recommendation Test]') with torch.no_grad(): self.evaluator.reset_metrics() for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False): + if self.rec_optim_opt.get('test_print_every_batch'): + self.evaluator.reset_metrics() self.step(batch, stage='rec', mode='test') - self.evaluator.report(mode='test') + if self.rec_optim_opt.get('test_print_every_batch'): + self.evaluator.report(mode='test') + if not self.rec_optim_opt.get('test_print_every_batch'): + self.evaluator.report(mode='test') def test_conversation(self): - logger.info('[Conversation Test]') if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': self.model.freeze_parameters() else: self.model.module.freeze_parameters() self.init_optim(self.conv_optim_opt, self.model.parameters()) + + logger.info('[Conversation Test]') with torch.no_grad(): self.evaluator.reset_metrics() for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False): + if self.conv_optim_opt.get('test_print_every_batch'): + self.evaluator.reset_metrics() self.step(batch, stage='conv', mode='test') - self.evaluator.report(mode='test') - + if self.conv_optim_opt.get('test_print_every_batch'): + self.evaluator.report(mode='test') + if not self.conv_optim_opt.get('test_print_every_batch'): + self.evaluator.report(mode='test') + + def fit(self): + self.pretrain() + self.train_recommender() + self.train_conversation() + def test(self): self.test_recommendation() - self.test_conversation() \ No newline at end of file + self.test_conversation() + + def interact(self): + pass From 3ba5126eacde0dbb42af43cc54eaabfc63a88377 Mon Sep 17 00:00:00 2001 From: krta2 Date: Tue, 16 Nov 2021 17:27:11 +0900 Subject: [PATCH 11/23] Update test feature - test option enable restore_system automatically --- crslab/quick_start/quick_start.py | 10 ++++++---- run_crslab.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/crslab/quick_start/quick_start.py b/crslab/quick_start/quick_start.py index 2b0d4d4..02ef3cf 100644 --- a/crslab/quick_start/quick_start.py +++ b/crslab/quick_start/quick_start.py @@ -26,7 +26,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r save_system (bool): whether to save system. Defaults to False. restore_system (bool): whether to restore system. Defaults to False. interact (bool): whether to interact with the system. Defaults to False. - test (bool): wheter to test with the saved system. Defaults to False. + test (bool): whether to test with the saved system. Defaults to False. debug (bool): whether to debug the system. Defaults to False. .. _Github repo: @@ -65,15 +65,17 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r train_dataloader[task] = get_dataloader(config, train_data, vocab[task]) valid_dataloader[task] = get_dataloader(config, valid_data, vocab[task]) test_dataloader[task] = get_dataloader(config, test_data, vocab[task]) + + # test need saved system + if test: + restore_system = True + # system CRS = get_system(config, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system, interact, debug, tensorboard) if interact: CRS.interact() elif test: - if not restore_system: - print('Need to restore saved model by argument --restore_system or -rs for test.') - return CRS.test() else: CRS.fit() diff --git a/run_crslab.py b/run_crslab.py index e99a3f6..d2e9ac9 100644 --- a/run_crslab.py +++ b/run_crslab.py @@ -34,7 +34,7 @@ parser.add_argument('-i', '--interact', action='store_true', help='interact with your system instead of training') parser.add_argument('-t', '--test', action='store_true', - help='test with your saved system instead of training. Note: argument --resotre_system or -rs required') + help='test with your saved system instead of training. Note: saved system is required') parser.add_argument('-tb', '--tensorboard', action='store_true', help='enable tensorboard to monitor train performance') args, _ = parser.parse_known_args() From d331558dc4b731c0e4c6efd369677fbe7b5745ee Mon Sep 17 00:00:00 2001 From: krta2 Date: Tue, 16 Nov 2021 18:05:03 +0900 Subject: [PATCH 12/23] Refactor if statement of test_print_every_batch --- crslab/system/kgsf.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 7585a52..d4f9bf7 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -186,14 +186,15 @@ def test_recommendation(self): logger.info('[Recommendation Test]') with torch.no_grad(): - self.evaluator.reset_metrics() - for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False): - if self.rec_optim_opt.get('test_print_every_batch'): + if self.rec_optim_opt.get('test_print_every_batch'): + for batch in self.test_dataloader.get_rec_data(1, shuffle=False): self.evaluator.reset_metrics() - self.step(batch, stage='rec', mode='test') - if self.rec_optim_opt.get('test_print_every_batch'): + self.step(batch, stage='rec', mode='test') self.evaluator.report(mode='test') - if not self.rec_optim_opt.get('test_print_every_batch'): + else: + self.evaluator.reset_metrics() + for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False): + self.step(batch, stage='rec', mode='test') self.evaluator.report(mode='test') def test_conversation(self): @@ -205,14 +206,15 @@ def test_conversation(self): logger.info('[Conversation Test]') with torch.no_grad(): - self.evaluator.reset_metrics() - for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False): - if self.conv_optim_opt.get('test_print_every_batch'): + if self.conv_optim_opt.get('test_print_every_batch'): + for batch in self.test_dataloader.get_conv_data(1, shuffle=False): self.evaluator.reset_metrics() - self.step(batch, stage='conv', mode='test') - if self.conv_optim_opt.get('test_print_every_batch'): + self.step(batch, stage='conv', mode='test') self.evaluator.report(mode='test') - if not self.conv_optim_opt.get('test_print_every_batch'): + else: + self.evaluator.reset_metrics() + for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False): + self.step(batch, stage='conv', mode='test') self.evaluator.report(mode='test') def fit(self): From 0f915635719baada295e0ec1b20ec7e0351a2f98 Mon Sep 17 00:00:00 2001 From: krta2 Date: Tue, 16 Nov 2021 18:21:26 +0900 Subject: [PATCH 13/23] Update KGSFSystem print conversation prediction/response only when test_print_every_batch --- crslab/system/kgsf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index d4f9bf7..6c2f53e 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -71,7 +71,8 @@ def conv_evaluate(self, prediction, response): for p, r in zip(prediction, response): p_str = ind2txt(p, self.ind2tok, self.end_token_idx) r_str = ind2txt(r, self.ind2tok, self.end_token_idx) - logger.info(f'\n prediction: {p_str}\n response: {r_str}') + if self.conv_optim_opt.get('test_print_every_batch'): + logger.info(f'\n prediction: {p_str}\n response: {r_str}') self.evaluator.gen_evaluate(p_str, [r_str]) def step(self, batch, stage, mode): From fb0fc55963002df07e0320cc0c15a71a084b3b6b Mon Sep 17 00:00:00 2001 From: Seunghee Han Date: Wed, 17 Nov 2021 12:32:20 +0900 Subject: [PATCH 14/23] =?UTF-8?q?=EA=B2=B0=EA=B3=BC=EB=A5=BC=20CSV=20?= =?UTF-8?q?=ED=8C=8C=EC=9D=BC=EB=A1=9C=20=EC=B6=9C=EB=A0=A5=20(#5)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update Config - add path for saving csv * Update BaseDataLoader - add write sentence into csv file * Update StandardEvaluator - add write metric into csv file * Update KGSFSystem - add write test result - add write prediction and response for conv into csv file * Update Config - update path for saving csv * Update BaseDataLoader - not log when writing csv file * Update StandardEvaluator - not log when writing csv file * Update KGSFSystem - update csv path variable name - not log when writing csv file - remove unnecessary comment --- crslab/config/__init__.py | 1 + crslab/data/dataloader/base.py | 23 +++++++++++++++-------- crslab/evaluator/standard.py | 10 ++++++++-- crslab/system/kgsf.py | 33 ++++++++++++++++++++++++--------- 4 files changed, 48 insertions(+), 19 deletions(-) diff --git a/crslab/config/__init__.py b/crslab/config/__init__.py index f7556ef..740da5c 100644 --- a/crslab/config/__init__.py +++ b/crslab/config/__init__.py @@ -30,3 +30,4 @@ MODEL_PATH = os.path.join(DATA_PATH, 'model') PRETRAIN_PATH = os.path.join(MODEL_PATH, 'pretrain') EMBEDDING_PATH = os.path.join(DATA_PATH, 'embedding') +CSV_PATH = os.path.join(ROOT_PATH, 'test_result') diff --git a/crslab/data/dataloader/base.py b/crslab/data/dataloader/base.py index 933555e..7508019 100644 --- a/crslab/data/dataloader/base.py +++ b/crslab/data/dataloader/base.py @@ -36,7 +36,7 @@ def __init__(self, opt, dataset, vocab=None): self.scale = opt.get('scale', 1) assert 0 < self.scale <= 1 - def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None): + def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None, file=None): """Collate batch data for system to fit Args: @@ -66,12 +66,19 @@ def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None): batch = [dataset[idx] for idx in batch_idx] if (batch_fn == self.rec_batchify and self.opt['rec'].get('test_print_every_batch')) or (batch_fn == self.conv_batchify and self.opt['conv'].get('test_print_every_batch')): for conv_dict in batch: - for sentence_in_index in conv_dict['context_tokens']: - sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index]) - logger.info(sentence) + if file: + file.write('"') + for sentence_in_index in conv_dict['context_tokens']: + sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index]) + file.write(f'{sentence}\n') + file.write('"\t') + else: + for sentence_in_index in conv_dict['context_tokens']: + sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index]) + logger.info(sentence) yield batch_fn(batch) - def get_conv_data(self, batch_size, shuffle=True): + def get_conv_data(self, batch_size, shuffle=True, file=None): """get_data wrapper for conversation. You can implement your own process_fn in ``conv_process_fn``, batch_fn in ``conv_batchify``. @@ -84,9 +91,9 @@ def get_conv_data(self, batch_size, shuffle=True): tuple or dict of torch.Tensor: batch data for conversation. """ - return self.get_data(self.conv_batchify, batch_size, shuffle, self.conv_process_fn) + return self.get_data(self.conv_batchify, batch_size, shuffle, self.conv_process_fn, file) - def get_rec_data(self, batch_size, shuffle=True): + def get_rec_data(self, batch_size, shuffle=True, file=None): """get_data wrapper for recommendation. You can implement your own process_fn in ``rec_process_fn``, batch_fn in ``rec_batchify``. @@ -99,7 +106,7 @@ def get_rec_data(self, batch_size, shuffle=True): tuple or dict of torch.Tensor: batch data for recommendation. """ - return self.get_data(self.rec_batchify, batch_size, shuffle, self.rec_process_fn) + return self.get_data(self.rec_batchify, batch_size, shuffle, self.rec_process_fn, file) def get_policy_data(self, batch_size, shuffle=True): """get_data wrapper for policy. diff --git a/crslab/evaluator/standard.py b/crslab/evaluator/standard.py index 7341aba..181a1e6 100644 --- a/crslab/evaluator/standard.py +++ b/crslab/evaluator/standard.py @@ -87,15 +87,21 @@ def gen_evaluate(self, hyp, refs): self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs)) self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs)) - def report(self, epoch=-1, mode='test'): + def report(self, epoch=-1, mode='test', file=None): for k, v in self.dist_set.items(): self.gen_metrics.add(k, AverageMetric(len(v) / self.dist_cnt)) reports = [self.rec_metrics.report(), self.gen_metrics.report(), self.optim_metrics.report()] + if file: + for each_metric, value in self.rec_metrics.report().items(): + file.write(f'{str(value.value())}\t') + for each_metric, value in self.gen_metrics.report().items(): + file.write(f'{str(value.value())}\t') if self.tensorboard and mode != 'test': for idx, task_report in enumerate(reports): for each_metric, value in task_report.items(): self.writer.add_scalars(f'{self.reports_name[idx]}/{each_metric}', {mode: value.value()}, epoch) - logger.info('\n' + nice_report(aggregate_unnamed_reports(reports))) + if not file: + logger.info('\n' + nice_report(aggregate_unnamed_reports(reports))) def reset_metrics(self): # rec diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 6c2f53e..dd69dc9 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -12,6 +12,7 @@ import torch from loguru import logger +from crslab.config import CSV_PATH from crslab.evaluator.metrics.base import AverageMetric from crslab.evaluator.metrics.gen import PPLMetric from crslab.system.base import BaseSystem @@ -55,6 +56,8 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc self.rec_batch_size = self.rec_optim_opt['batch_size'] self.conv_batch_size = self.conv_optim_opt['batch_size'] + self.csv_path = os.path.join(CSV_PATH, "kgsf", opt['dataset']) + def rec_evaluate(self, rec_predict, item_label): rec_predict = rec_predict.cpu() rec_predict = rec_predict[:, self.item_ids] @@ -65,17 +68,19 @@ def rec_evaluate(self, rec_predict, item_label): item = self.item_ids.index(item) self.evaluator.rec_evaluate(rec_rank, item) - def conv_evaluate(self, prediction, response): + def conv_evaluate(self, prediction, response, file=None): prediction = prediction.tolist() response = response.tolist() for p, r in zip(prediction, response): p_str = ind2txt(p, self.ind2tok, self.end_token_idx) r_str = ind2txt(r, self.ind2tok, self.end_token_idx) - if self.conv_optim_opt.get('test_print_every_batch'): + if file: + file.write(f'{p_str}\t{r_str}\t') + elif self.conv_optim_opt.get('test_print_every_batch'): logger.info(f'\n prediction: {p_str}\n response: {r_str}') self.evaluator.gen_evaluate(p_str, [r_str]) - def step(self, batch, stage, mode): + def step(self, batch, stage, mode, file=None): batch = [ele.to(self.device) for ele in batch] if stage == 'pretrain': info_loss = self.model.forward(batch, stage, mode) @@ -110,7 +115,7 @@ def step(self, batch, stage, mode): self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss)) else: pred = self.model.forward(batch, stage, mode) - self.conv_evaluate(pred, batch[-1]) + self.conv_evaluate(pred, batch[-1], file) else: raise @@ -186,12 +191,17 @@ def test_recommendation(self): self.init_optim(self.rec_optim_opt, self.model.parameters()) logger.info('[Recommendation Test]') + f = open(os.path.join(self.csv_path, 'rec.csv'), 'w', encoding='utf-8', newline='') + f.write('sentences\thit@1\tndcg@1\tmrr@1\thit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n') + logger.info(f"[Write {os.path.join(self.csv_path, 'rec.csv')}]") with torch.no_grad(): if self.rec_optim_opt.get('test_print_every_batch'): - for batch in self.test_dataloader.get_rec_data(1, shuffle=False): + for batch in self.test_dataloader.get_rec_data(1, shuffle=False, file=f): self.evaluator.reset_metrics() self.step(batch, stage='rec', mode='test') - self.evaluator.report(mode='test') + self.evaluator.report(mode='test', file=f) + f.write('\n') + f.close() else: self.evaluator.reset_metrics() for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False): @@ -206,12 +216,17 @@ def test_conversation(self): self.init_optim(self.conv_optim_opt, self.model.parameters()) logger.info('[Conversation Test]') + f = open(os.path.join(self.csv_path, 'conv.csv'), 'w', encoding='utf-8', newline='') + f.write('sentences\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\tgreedy\taverage\textreme\tdist@1\tdist@2\tdist@3\tdist@4\n') + logger.info(f"[Write {os.path.join(self.csv_path, 'conv.csv')}]") with torch.no_grad(): if self.conv_optim_opt.get('test_print_every_batch'): - for batch in self.test_dataloader.get_conv_data(1, shuffle=False): + for batch in self.test_dataloader.get_conv_data(1, shuffle=False, file=f): self.evaluator.reset_metrics() - self.step(batch, stage='conv', mode='test') - self.evaluator.report(mode='test') + self.step(batch, stage='conv', mode='test', file=f) + self.evaluator.report(mode='test', file=f) + f.write('\n') + f.close() else: self.evaluator.reset_metrics() for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False): From 1b59d7b3d7d17110ef597bbdbc0686757e6ce01a Mon Sep 17 00:00:00 2001 From: krta2 Date: Thu, 18 Nov 2021 00:32:05 +0900 Subject: [PATCH 15/23] Refactor test_recommendation, test_conversation - Open file only when needed --- crslab/system/kgsf.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index dd69dc9..90667c0 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -191,11 +191,13 @@ def test_recommendation(self): self.init_optim(self.rec_optim_opt, self.model.parameters()) logger.info('[Recommendation Test]') - f = open(os.path.join(self.csv_path, 'rec.csv'), 'w', encoding='utf-8', newline='') - f.write('sentences\thit@1\tndcg@1\tmrr@1\thit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n') - logger.info(f"[Write {os.path.join(self.csv_path, 'rec.csv')}]") with torch.no_grad(): if self.rec_optim_opt.get('test_print_every_batch'): + rec_test_result_file_name = 'rec.csv' + f = open(os.path.join(self.csv_path, rec_test_result_file_name), 'w', encoding='utf-8', newline='') + f.write('sentences\thit@1\tndcg@1\tmrr@1\thit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n') + logger.info(f"[Write {os.path.join(self.csv_path, rec_test_result_file_name)}]") + for batch in self.test_dataloader.get_rec_data(1, shuffle=False, file=f): self.evaluator.reset_metrics() self.step(batch, stage='rec', mode='test') @@ -216,11 +218,14 @@ def test_conversation(self): self.init_optim(self.conv_optim_opt, self.model.parameters()) logger.info('[Conversation Test]') - f = open(os.path.join(self.csv_path, 'conv.csv'), 'w', encoding='utf-8', newline='') - f.write('sentences\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\tgreedy\taverage\textreme\tdist@1\tdist@2\tdist@3\tdist@4\n') - logger.info(f"[Write {os.path.join(self.csv_path, 'conv.csv')}]") with torch.no_grad(): if self.conv_optim_opt.get('test_print_every_batch'): + conv_test_result_file_name = 'rec.csv' + f = open(os.path.join(self.csv_path, conv_test_result_file_name), 'w', encoding='utf-8', newline='') + f.write('sentences\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\tgreedy\taverage' + '\textreme\tdist@1\tdist@2\tdist@3\tdist@4\n') + logger.info(f"[Write {os.path.join(self.csv_path, conv_test_result_file_name)}]") + for batch in self.test_dataloader.get_conv_data(1, shuffle=False, file=f): self.evaluator.reset_metrics() self.step(batch, stage='conv', mode='test', file=f) From 5042095db2dad069f1c0c1c4696514834d49dea8 Mon Sep 17 00:00:00 2001 From: krta2 Date: Thu, 18 Nov 2021 01:02:23 +0900 Subject: [PATCH 16/23] Update file printing behavior - print file only in test mode (--test) and test_print_every_batch flag on. --- crslab/data/dataloader/base.py | 22 +++++++++++----------- crslab/evaluator/standard.py | 12 +++++++----- crslab/system/kgsf.py | 2 -- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/crslab/data/dataloader/base.py b/crslab/data/dataloader/base.py index 7508019..cf0c1a4 100644 --- a/crslab/data/dataloader/base.py +++ b/crslab/data/dataloader/base.py @@ -44,6 +44,7 @@ def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None, file=Non batch_size (int): shuffle (bool, optional): Defaults to True. process_fn (func, optional): function to process dataset before batchify. Defaults to None. + file (file descriptor, optional): file descriptor where to print input data for each batch. Yields: tuple or dict of torch.Tensor: batch data for system to fit @@ -64,18 +65,17 @@ def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None, file=Non for start_idx in tqdm(range(batch_num)): batch_idx = idx_list[start_idx * batch_size: (start_idx + 1) * batch_size] batch = [dataset[idx] for idx in batch_idx] - if (batch_fn == self.rec_batchify and self.opt['rec'].get('test_print_every_batch')) or (batch_fn == self.conv_batchify and self.opt['conv'].get('test_print_every_batch')): + + if file: for conv_dict in batch: - if file: - file.write('"') - for sentence_in_index in conv_dict['context_tokens']: - sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index]) - file.write(f'{sentence}\n') - file.write('"\t') - else: - for sentence_in_index in conv_dict['context_tokens']: - sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index]) - logger.info(sentence) + file.write('"') + + for sentence_in_index in conv_dict['context_tokens']: + sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index]) + file.write(f'{sentence}\n') + + file.write('"\t') + yield batch_fn(batch) def get_conv_data(self, batch_size, shuffle=True, file=None): diff --git a/crslab/evaluator/standard.py b/crslab/evaluator/standard.py index 181a1e6..f7f43dc 100644 --- a/crslab/evaluator/standard.py +++ b/crslab/evaluator/standard.py @@ -91,16 +91,18 @@ def report(self, epoch=-1, mode='test', file=None): for k, v in self.dist_set.items(): self.gen_metrics.add(k, AverageMetric(len(v) / self.dist_cnt)) reports = [self.rec_metrics.report(), self.gen_metrics.report(), self.optim_metrics.report()] + + if self.tensorboard and mode != 'test': + for idx, task_report in enumerate(reports): + for each_metric, value in task_report.items(): + self.writer.add_scalars(f'{self.reports_name[idx]}/{each_metric}', {mode: value.value()}, epoch) + if file: for each_metric, value in self.rec_metrics.report().items(): file.write(f'{str(value.value())}\t') for each_metric, value in self.gen_metrics.report().items(): file.write(f'{str(value.value())}\t') - if self.tensorboard and mode != 'test': - for idx, task_report in enumerate(reports): - for each_metric, value in task_report.items(): - self.writer.add_scalars(f'{self.reports_name[idx]}/{each_metric}', {mode: value.value()}, epoch) - if not file: + else: logger.info('\n' + nice_report(aggregate_unnamed_reports(reports))) def reset_metrics(self): diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 90667c0..fe827f4 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -76,8 +76,6 @@ def conv_evaluate(self, prediction, response, file=None): r_str = ind2txt(r, self.ind2tok, self.end_token_idx) if file: file.write(f'{p_str}\t{r_str}\t') - elif self.conv_optim_opt.get('test_print_every_batch'): - logger.info(f'\n prediction: {p_str}\n response: {r_str}') self.evaluator.gen_evaluate(p_str, [r_str]) def step(self, batch, stage, mode, file=None): From db7957d70571b3ae85bdd4dee6c785b1357691f8 Mon Sep 17 00:00:00 2001 From: krta2 Date: Thu, 18 Nov 2021 01:48:57 +0900 Subject: [PATCH 17/23] Add input entities in batch print --- crslab/data/dataloader/base.py | 8 ++++++-- crslab/system/kgsf.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/crslab/data/dataloader/base.py b/crslab/data/dataloader/base.py index cf0c1a4..60913a9 100644 --- a/crslab/data/dataloader/base.py +++ b/crslab/data/dataloader/base.py @@ -69,13 +69,17 @@ def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None, file=Non if file: for conv_dict in batch: file.write('"') - for sentence_in_index in conv_dict['context_tokens']: sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index]) file.write(f'{sentence}\n') - file.write('"\t') + file.write('"') + entities = "\n".join( + [self.vocab['id2entity'][entity_index] for entity_index in conv_dict['context_entities']] + ) + file.write(f'{entities}"\t') + yield batch_fn(batch) def get_conv_data(self, batch_size, shuffle=True, file=None): diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index fe827f4..f37ce1d 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -193,7 +193,7 @@ def test_recommendation(self): if self.rec_optim_opt.get('test_print_every_batch'): rec_test_result_file_name = 'rec.csv' f = open(os.path.join(self.csv_path, rec_test_result_file_name), 'w', encoding='utf-8', newline='') - f.write('sentences\thit@1\tndcg@1\tmrr@1\thit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n') + f.write('sentences\tentities\thit@1\tndcg@1\tmrr@1\thit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n') logger.info(f"[Write {os.path.join(self.csv_path, rec_test_result_file_name)}]") for batch in self.test_dataloader.get_rec_data(1, shuffle=False, file=f): @@ -220,7 +220,7 @@ def test_conversation(self): if self.conv_optim_opt.get('test_print_every_batch'): conv_test_result_file_name = 'rec.csv' f = open(os.path.join(self.csv_path, conv_test_result_file_name), 'w', encoding='utf-8', newline='') - f.write('sentences\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\tgreedy\taverage' + f.write('sentences\tentities\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\tgreedy\taverage' '\textreme\tdist@1\tdist@2\tdist@3\tdist@4\n') logger.info(f"[Write {os.path.join(self.csv_path, conv_test_result_file_name)}]") From ac3827fe08dfefa8d379bcb69ca44a5ba432125b Mon Sep 17 00:00:00 2001 From: krta2 Date: Thu, 18 Nov 2021 02:11:31 +0900 Subject: [PATCH 18/23] Add role at front of sentence --- crslab/data/dataloader/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crslab/data/dataloader/base.py b/crslab/data/dataloader/base.py index 60913a9..869f10f 100644 --- a/crslab/data/dataloader/base.py +++ b/crslab/data/dataloader/base.py @@ -70,7 +70,9 @@ def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None, file=Non for conv_dict in batch: file.write('"') for sentence_in_index in conv_dict['context_tokens']: - sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index]) + sentence = f'{conv_dict["role"]}: ' + " ".join( + [self.vocab['ind2tok'][index] for index in sentence_in_index] + ) file.write(f'{sentence}\n') file.write('"\t') From 51a9769f938911145b3483e1d1dc828d930fcf75 Mon Sep 17 00:00:00 2001 From: krta2 Date: Thu, 18 Nov 2021 02:20:24 +0900 Subject: [PATCH 19/23] Update csv header term --- crslab/system/kgsf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index f37ce1d..4cd3b94 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -193,7 +193,8 @@ def test_recommendation(self): if self.rec_optim_opt.get('test_print_every_batch'): rec_test_result_file_name = 'rec.csv' f = open(os.path.join(self.csv_path, rec_test_result_file_name), 'w', encoding='utf-8', newline='') - f.write('sentences\tentities\thit@1\tndcg@1\tmrr@1\thit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n') + f.write('input context\tentities\thit@1\tndcg@1\tmrr@1\t' + 'hit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n') logger.info(f"[Write {os.path.join(self.csv_path, rec_test_result_file_name)}]") for batch in self.test_dataloader.get_rec_data(1, shuffle=False, file=f): @@ -220,8 +221,8 @@ def test_conversation(self): if self.conv_optim_opt.get('test_print_every_batch'): conv_test_result_file_name = 'rec.csv' f = open(os.path.join(self.csv_path, conv_test_result_file_name), 'w', encoding='utf-8', newline='') - f.write('sentences\tentities\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\tgreedy\taverage' - '\textreme\tdist@1\tdist@2\tdist@3\tdist@4\n') + f.write('input context\tentities\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\tgreedy\t' + 'average\textreme\tdist@1\tdist@2\tdist@3\tdist@4\n') logger.info(f"[Write {os.path.join(self.csv_path, conv_test_result_file_name)}]") for batch in self.test_dataloader.get_conv_data(1, shuffle=False, file=f): From 104b2651607a00bed75b98eef74b7a7d4b260d55 Mon Sep 17 00:00:00 2001 From: krta2 Date: Thu, 18 Nov 2021 11:11:07 +0900 Subject: [PATCH 20/23] Update KGSFSystem - automatically make file output dir when not exist --- crslab/system/kgsf.py | 64 +++++++++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 4cd3b94..c25aeba 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -191,18 +191,26 @@ def test_recommendation(self): logger.info('[Recommendation Test]') with torch.no_grad(): if self.rec_optim_opt.get('test_print_every_batch'): - rec_test_result_file_name = 'rec.csv' - f = open(os.path.join(self.csv_path, rec_test_result_file_name), 'w', encoding='utf-8', newline='') - f.write('input context\tentities\thit@1\tndcg@1\tmrr@1\t' - 'hit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n') - logger.info(f"[Write {os.path.join(self.csv_path, rec_test_result_file_name)}]") - - for batch in self.test_dataloader.get_rec_data(1, shuffle=False, file=f): - self.evaluator.reset_metrics() - self.step(batch, stage='rec', mode='test') - self.evaluator.report(mode='test', file=f) - f.write('\n') - f.close() + rec_test_result_file_name = os.path.join(self.csv_path, 'rec.csv') + if not os.path.exists(os.path.dirname(rec_test_result_file_name)): + try: + os.makedirs(os.path.dirname(rec_test_result_file_name)) + except OSError as exc: # Guard against race condition + import errno + if exc.errno != errno.EEXIST: + raise + + with open(rec_test_result_file_name, 'w', encoding='utf-8', newline='') as f: + f.write('input context\tentities\thit@1\tndcg@1\tmrr@1\t' + 'hit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n') + logger.info(f"[Write {rec_test_result_file_name}]") + + for batch in self.test_dataloader.get_rec_data(1, shuffle=False, file=f): + self.evaluator.reset_metrics() + self.step(batch, stage='rec', mode='test') + self.evaluator.report(mode='test', file=f) + f.write('\n') + f.close() else: self.evaluator.reset_metrics() for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False): @@ -219,18 +227,26 @@ def test_conversation(self): logger.info('[Conversation Test]') with torch.no_grad(): if self.conv_optim_opt.get('test_print_every_batch'): - conv_test_result_file_name = 'rec.csv' - f = open(os.path.join(self.csv_path, conv_test_result_file_name), 'w', encoding='utf-8', newline='') - f.write('input context\tentities\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\tgreedy\t' - 'average\textreme\tdist@1\tdist@2\tdist@3\tdist@4\n') - logger.info(f"[Write {os.path.join(self.csv_path, conv_test_result_file_name)}]") - - for batch in self.test_dataloader.get_conv_data(1, shuffle=False, file=f): - self.evaluator.reset_metrics() - self.step(batch, stage='conv', mode='test', file=f) - self.evaluator.report(mode='test', file=f) - f.write('\n') - f.close() + conv_test_result_file_name = os.path.join(self.csv_path, 'conv.csv') + if not os.path.exists(os.path.dirname(conv_test_result_file_name)): + try: + os.makedirs(os.path.dirname(conv_test_result_file_name)) + except OSError as exc: # Guard against race condition + import errno + if exc.errno != errno.EEXIST: + raise + + with open(conv_test_result_file_name, 'w', encoding='utf-8', newline='') as f: + f.write('input context\tentities\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\t' + 'greedy\taverage\textreme\tdist@1\tdist@2\tdist@3\tdist@4\n') + logger.info(f"[Write {conv_test_result_file_name}]") + + for batch in self.test_dataloader.get_conv_data(1, shuffle=False, file=f): + self.evaluator.reset_metrics() + self.step(batch, stage='conv', mode='test', file=f) + self.evaluator.report(mode='test', file=f) + f.write('\n') + f.close() else: self.evaluator.reset_metrics() for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False): From f947dc3acfedd0066f3bdb727831ec5c62fdf262 Mon Sep 17 00:00:00 2001 From: krta2 Date: Thu, 18 Nov 2021 11:14:42 +0900 Subject: [PATCH 21/23] Refactor dir making --- crslab/system/kgsf.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index c25aeba..71c78b4 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -192,14 +192,7 @@ def test_recommendation(self): with torch.no_grad(): if self.rec_optim_opt.get('test_print_every_batch'): rec_test_result_file_name = os.path.join(self.csv_path, 'rec.csv') - if not os.path.exists(os.path.dirname(rec_test_result_file_name)): - try: - os.makedirs(os.path.dirname(rec_test_result_file_name)) - except OSError as exc: # Guard against race condition - import errno - if exc.errno != errno.EEXIST: - raise - + os.makedirs(os.path.dirname(rec_test_result_file_name), exist_ok=True) with open(rec_test_result_file_name, 'w', encoding='utf-8', newline='') as f: f.write('input context\tentities\thit@1\tndcg@1\tmrr@1\t' 'hit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n') @@ -228,14 +221,7 @@ def test_conversation(self): with torch.no_grad(): if self.conv_optim_opt.get('test_print_every_batch'): conv_test_result_file_name = os.path.join(self.csv_path, 'conv.csv') - if not os.path.exists(os.path.dirname(conv_test_result_file_name)): - try: - os.makedirs(os.path.dirname(conv_test_result_file_name)) - except OSError as exc: # Guard against race condition - import errno - if exc.errno != errno.EEXIST: - raise - + os.makedirs(os.path.dirname(conv_test_result_file_name), exist_ok=True) with open(conv_test_result_file_name, 'w', encoding='utf-8', newline='') as f: f.write('input context\tentities\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\t' 'greedy\taverage\textreme\tdist@1\tdist@2\tdist@3\tdist@4\n') From 64a8fa0ed90250a445bce2ebce0cae84ae33d80f Mon Sep 17 00:00:00 2001 From: krta2 Date: Thu, 18 Nov 2021 12:00:43 +0900 Subject: [PATCH 22/23] Fix printing useless role info --- crslab/data/dataloader/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/crslab/data/dataloader/base.py b/crslab/data/dataloader/base.py index 869f10f..60913a9 100644 --- a/crslab/data/dataloader/base.py +++ b/crslab/data/dataloader/base.py @@ -70,9 +70,7 @@ def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None, file=Non for conv_dict in batch: file.write('"') for sentence_in_index in conv_dict['context_tokens']: - sentence = f'{conv_dict["role"]}: ' + " ".join( - [self.vocab['ind2tok'][index] for index in sentence_in_index] - ) + sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index]) file.write(f'{sentence}\n') file.write('"\t') From 7d85ae524f7f65432a922dd3b8074c747b1dd5b3 Mon Sep 17 00:00:00 2001 From: krta2 Date: Thu, 18 Nov 2021 12:45:17 +0900 Subject: [PATCH 23/23] Add words in output --- crslab/data/dataloader/base.py | 6 ++++++ crslab/data/dataset/redial/redial.py | 2 ++ crslab/system/kgsf.py | 4 ++-- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/crslab/data/dataloader/base.py b/crslab/data/dataloader/base.py index 60913a9..fa7147b 100644 --- a/crslab/data/dataloader/base.py +++ b/crslab/data/dataloader/base.py @@ -80,6 +80,12 @@ def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None, file=Non ) file.write(f'{entities}"\t') + file.write('"') + words = "\n".join( + [self.vocab['id2word'][word_index] for word_index in conv_dict['context_words']] + ) + file.write(f'{words}"\t') + yield batch_fn(batch) def get_conv_data(self, batch_size, shuffle=True, file=None): diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index cb6e47b..c84565b 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -83,6 +83,7 @@ def _load_data(self): 'entity2id': self.entity2id, 'id2entity': self.id2entity, 'word2id': self.word2id, + 'id2word': self.id2word, 'vocab_size': len(self.tok2ind), 'n_entity': self.n_entity, 'n_word': self.n_word, @@ -127,6 +128,7 @@ def _load_other_data(self): # conceptNet # {concept: concept_id} self.word2id = json.load(open(os.path.join(self.dpath, 'concept2id.json'), 'r', encoding='utf-8')) + self.id2word = {idx: word for word, idx in self.word2id.items()} self.n_word = max(self.word2id.values()) + 1 # {relation\t concept \t concept} self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), 'r', encoding='utf-8') diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 71c78b4..5e8ed4e 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -194,7 +194,7 @@ def test_recommendation(self): rec_test_result_file_name = os.path.join(self.csv_path, 'rec.csv') os.makedirs(os.path.dirname(rec_test_result_file_name), exist_ok=True) with open(rec_test_result_file_name, 'w', encoding='utf-8', newline='') as f: - f.write('input context\tentities\thit@1\tndcg@1\tmrr@1\t' + f.write('input context\tentities\twords\thit@1\tndcg@1\tmrr@1\t' 'hit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n') logger.info(f"[Write {rec_test_result_file_name}]") @@ -223,7 +223,7 @@ def test_conversation(self): conv_test_result_file_name = os.path.join(self.csv_path, 'conv.csv') os.makedirs(os.path.dirname(conv_test_result_file_name), exist_ok=True) with open(conv_test_result_file_name, 'w', encoding='utf-8', newline='') as f: - f.write('input context\tentities\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\t' + f.write('input context\tentities\twords\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\t' 'greedy\taverage\textreme\tdist@1\tdist@2\tdist@3\tdist@4\n') logger.info(f"[Write {conv_test_result_file_name}]")