2 changes: 2 additions & 0 deletions config/crs/kgsf/redial.yaml
@@ -31,12 +31,14 @@ pretrain:
rec:
epoch: 9
batch_size: 128
test_print_every_batch: true
optimizer:
name: Adam
lr: !!float 1e-3
conv:
epoch: 90
batch_size: 128
test_print_every_batch: true
optimizer:
name: Adam
lr: !!float 1e-3
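The two added test_print_every_batch flags are what the KGSF system checks before writing per-batch CSV rows. As a minimal sketch (not part of the PR), assuming rec and conv are top-level keys as the hunk suggests and the file is loaded with PyYAML:

import yaml

with open('config/crs/kgsf/redial.yaml', 'r', encoding='utf-8') as fh:
    opt = yaml.safe_load(fh)

# The KGSF system calls opt['rec'].get('test_print_every_batch') and the
# conv counterpart; a missing key is falsy, so old configs keep the old behaviour.
print(opt['rec'].get('test_print_every_batch'))   # True
print(opt['conv'].get('test_print_every_batch'))  # True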
1 change: 1 addition & 0 deletions crslab/config/__init__.py
@@ -30,3 +30,4 @@
MODEL_PATH = os.path.join(DATA_PATH, 'model')
PRETRAIN_PATH = os.path.join(MODEL_PATH, 'pretrain')
EMBEDDING_PATH = os.path.join(DATA_PATH, 'embedding')
CSV_PATH = os.path.join(ROOT_PATH, 'test_result')
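The new CSV_PATH is the root for all per-batch test output. A short sketch of how it is composed downstream (mirroring the kgsf.py change later in this diff; the dataset name is illustrative):

import os

from crslab.config import CSV_PATH  # added above: <ROOT_PATH>/test_result

# KGSFSystem (further down in this diff) nests results per model and dataset:
csv_path = os.path.join(CSV_PATH, 'kgsf', 'ReDial')
rec_file = os.path.join(csv_path, 'rec.csv')    # recommendation test rows
conv_file = os.path.join(csv_path, 'conv.csv')  # conversation test rows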
2 changes: 1 addition & 1 deletion crslab/data/__init__.py
@@ -82,7 +82,7 @@ def get_dataset(opt, tokenize, restore, save) -> BaseDataset:
if dataset in dataset_register_table:
return dataset_register_table[dataset](opt, tokenize, restore, save)
else:
raise NotImplementedError(f'The dataloader [{dataset}] has not been implemented')
raise NotImplementedError(f'The dataset [{dataset}] has not been implemented')


def get_dataloader(opt, dataset, vocab) -> BaseDataLoader:
35 changes: 29 additions & 6 deletions crslab/data/dataloader/base.py
@@ -23,7 +23,7 @@ class BaseDataLoader(ABC):

"""

def __init__(self, opt, dataset):
def __init__(self, opt, dataset, vocab=None):
"""
Args:
opt (Config or dict): config for dataloader or the whole system.
@@ -32,17 +32,19 @@ def __init__(self, opt, dataset):
"""
self.opt = opt
self.dataset = dataset
self.vocab = vocab
self.scale = opt.get('scale', 1)
assert 0 < self.scale <= 1

def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None):
def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None, file=None):
"""Collate batch data for system to fit

Args:
batch_fn (func): function to collate data
batch_size (int):
shuffle (bool, optional): Defaults to True.
process_fn (func, optional): function to process dataset before batchify. Defaults to None.
file (file descriptor, optional): file descriptor to which the raw input data of each batch is written. Defaults to None.

Yields:
tuple or dict of torch.Tensor: batch data for system to fit
@@ -63,9 +65,30 @@ def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None):
for start_idx in tqdm(range(batch_num)):
batch_idx = idx_list[start_idx * batch_size: (start_idx + 1) * batch_size]
batch = [dataset[idx] for idx in batch_idx]

if file:
for conv_dict in batch:
file.write('"')
for sentence_in_index in conv_dict['context_tokens']:
sentence = " ".join([self.vocab['ind2tok'][index] for index in sentence_in_index])
file.write(f'{sentence}\n')
file.write('"\t')

file.write('"')
entities = "\n".join(
[self.vocab['id2entity'][entity_index] for entity_index in conv_dict['context_entities']]
)
file.write(f'{entities}"\t')

file.write('"')
words = "\n".join(
[self.vocab['id2word'][word_index] for word_index in conv_dict['context_words']]
)
file.write(f'{words}"\t')

yield batch_fn(batch)

def get_conv_data(self, batch_size, shuffle=True):
def get_conv_data(self, batch_size, shuffle=True, file=None):
"""get_data wrapper for conversation.

You can implement your own process_fn in ``conv_process_fn``, batch_fn in ``conv_batchify``.
@@ -78,9 +101,9 @@ def get_conv_data(self, batch_size, shuffle=True):
tuple or dict of torch.Tensor: batch data for conversation.

"""
return self.get_data(self.conv_batchify, batch_size, shuffle, self.conv_process_fn)
return self.get_data(self.conv_batchify, batch_size, shuffle, self.conv_process_fn, file)

def get_rec_data(self, batch_size, shuffle=True):
def get_rec_data(self, batch_size, shuffle=True, file=None):
"""get_data wrapper for recommendation.

You can implement your own process_fn in ``rec_process_fn``, batch_fn in ``rec_batchify``.
@@ -93,7 +116,7 @@ def get_rec_data(self, batch_size, shuffle=True):
tuple or dict of torch.Tensor: batch data for recommendation.

"""
return self.get_data(self.rec_batchify, batch_size, shuffle, self.rec_process_fn)
return self.get_data(self.rec_batchify, batch_size, shuffle, self.rec_process_fn, file)

def get_policy_data(self, batch_size, shuffle=True):
"""get_data wrapper for policy.
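With the new file argument, get_data writes each conversation's context tokens, entities and words as quoted, tab-separated fields before yielding the batch. A hedged sketch of the intended call pattern (dump_rec_inputs is a hypothetical helper, not in the PR):

def dump_rec_inputs(dataloader, path='rec_debug.tsv'):
    """Stream each test conversation's raw inputs to a TSV-style file.

    `dataloader` is assumed to be a KGSFDataLoader built with a vocab, as in
    this PR; get_rec_data(file=...) writes context tokens, entities and words
    as quoted, tab-separated fields before yielding each single-sample batch.
    """
    with open(path, 'w', encoding='utf-8', newline='') as f:
        for _batch in dataloader.get_rec_data(batch_size=1, shuffle=False, file=f):
            f.write('\n')  # one conversation per row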
2 changes: 1 addition & 1 deletion crslab/data/dataloader/kgsf.py
@@ -50,7 +50,7 @@ def __init__(self, opt, dataset, vocab):
vocab (dict): all kinds of useful size, idx and map between token and idx.

"""
super().__init__(opt, dataset)
super().__init__(opt, dataset, vocab)
self.n_entity = vocab['n_entity']
self.pad_token_idx = vocab['pad']
self.start_token_idx = vocab['start']
2 changes: 2 additions & 0 deletions crslab/data/dataset/redial/redial.py
@@ -83,6 +83,7 @@ def _load_data(self):
'entity2id': self.entity2id,
'id2entity': self.id2entity,
'word2id': self.word2id,
'id2word': self.id2word,
'vocab_size': len(self.tok2ind),
'n_entity': self.n_entity,
'n_word': self.n_word,
@@ -127,6 +128,7 @@ def _load_other_data(self):
# conceptNet
# {concept: concept_id}
self.word2id = json.load(open(os.path.join(self.dpath, 'concept2id.json'), 'r', encoding='utf-8'))
self.id2word = {idx: word for word, idx in self.word2id.items()}
self.n_word = max(self.word2id.values()) + 1
# {relation\t concept \t concept}
self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), 'r', encoding='utf-8')
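id2word is simply the inverse of the ConceptNet concept2id map and is what the dataloader uses to turn context_words back into text. A tiny self-contained sketch with made-up data:

# Illustrative sample data; in the dataset, word2id comes from concept2id.json.
word2id = {'movie': 0, 'comedy': 1, 'scary': 2}
id2word = {idx: word for word, idx in word2id.items()}

context_words = [2, 1]  # word ids from one conversation turn
print(" ".join(id2word[i] for i in context_words))  # "scary comedy"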
12 changes: 10 additions & 2 deletions crslab/evaluator/standard.py
@@ -87,15 +87,23 @@ def gen_evaluate(self, hyp, refs):
self.gen_metrics.add('average', EmbeddingAverage.compute(hyp_emb, ref_embs))
self.gen_metrics.add('extreme', VectorExtrema.compute(hyp_emb, ref_embs))

def report(self, epoch=-1, mode='test'):
def report(self, epoch=-1, mode='test', file=None):
for k, v in self.dist_set.items():
self.gen_metrics.add(k, AverageMetric(len(v) / self.dist_cnt))
reports = [self.rec_metrics.report(), self.gen_metrics.report(), self.optim_metrics.report()]

if self.tensorboard and mode != 'test':
for idx, task_report in enumerate(reports):
for each_metric, value in task_report.items():
self.writer.add_scalars(f'{self.reports_name[idx]}/{each_metric}', {mode: value.value()}, epoch)
logger.info('\n' + nice_report(aggregate_unnamed_reports(reports)))

if file:
for each_metric, value in self.rec_metrics.report().items():
file.write(f'{str(value.value())}\t')
for each_metric, value in self.gen_metrics.report().items():
file.write(f'{str(value.value())}\t')
else:
logger.info('\n' + nice_report(aggregate_unnamed_reports(reports)))

def reset_metrics(self):
# rec
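When report receives a file handle it emits the recommendation and generation metric values tab-separated on the current row instead of logging a table. A hypothetical helper (not in the PR) showing the per-row pattern:

def report_row(evaluator, f):
    # With a file handle, report() writes the rec and gen metric values
    # tab-separated; the caller ends the row, as KGSFSystem does below.
    evaluator.report(mode='test', file=f)
    f.write('\n')
    evaluator.reset_metrics()  # so the next batch gets its own row of metrics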
10 changes: 9 additions & 1 deletion crslab/quick_start/quick_start.py
@@ -14,7 +14,7 @@


def run_crslab(config, save_data=False, restore_data=False, save_system=False, restore_system=False,
interact=False, debug=False, tensorboard=False):
interact=False, test=False, debug=False, tensorboard=False):
"""A fast running api, which includes the complete process of training and testing models on specified datasets.

Args:
@@ -26,6 +26,7 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r
save_system (bool): whether to save system. Defaults to False.
restore_system (bool): whether to restore system. Defaults to False.
interact (bool): whether to interact with the system. Defaults to False.
test (bool): whether to test with the saved system. Defaults to False.
debug (bool): whether to debug the system. Defaults to False.

.. _Github repo:
@@ -64,11 +65,18 @@ def run_crslab(config, save_data=False, restore_data=False, save_system=False, r
train_dataloader[task] = get_dataloader(config, train_data, vocab[task])
valid_dataloader[task] = get_dataloader(config, valid_data, vocab[task])
test_dataloader[task] = get_dataloader(config, test_data, vocab[task])

# testing requires a previously saved system
if test:
restore_system = True

# system
CRS = get_system(config, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system,
interact, debug, tensorboard)
if interact:
CRS.interact()
elif test:
CRS.test()
else:
CRS.fit()
if save_system:
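Passing test=True forces restore_system on and routes execution to CRS.test() instead of fit(). A sketch of driving this path from Python (the Config call mirrors run_crslab.py and its exact arguments are an assumption here):

from crslab.config import Config
from crslab.quick_start import run_crslab

config = Config('config/crs/kgsf/redial.yaml')  # gpu/debug left at their defaults
run_crslab(config, test=True)  # restore_system is forced to True internally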
3 changes: 3 additions & 0 deletions crslab/system/base.py
@@ -361,3 +361,6 @@ def link(self, tokens, entities):
if entity:
linked_entities.append(entity[0])
return linked_entities

def test(self):
raise NotImplementedError('Method test is not implemented.')
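Systems that support the new flow are expected to override test(); a rough sketch of the minimal shape, with attribute names borrowed from the KGSF system below (assumptions for other systems; remaining abstract methods omitted):

import torch

from crslab.system.base import BaseSystem

class MySystem(BaseSystem):
    def test(self):
        # Evaluate the saved recommender on the test split without training.
        with torch.no_grad():
            self.evaluator.reset_metrics()
            for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):
                self.step(batch, stage='rec', mode='test')
            self.evaluator.report(mode='test')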
69 changes: 66 additions & 3 deletions crslab/system/kgsf.py
@@ -12,6 +12,7 @@
import torch
from loguru import logger

from crslab.config import CSV_PATH
from crslab.evaluator.metrics.base import AverageMetric
from crslab.evaluator.metrics.gen import PPLMetric
from crslab.system.base import BaseSystem
@@ -55,6 +56,8 @@ def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, voc
self.rec_batch_size = self.rec_optim_opt['batch_size']
self.conv_batch_size = self.conv_optim_opt['batch_size']

self.csv_path = os.path.join(CSV_PATH, "kgsf", opt['dataset'])

def rec_evaluate(self, rec_predict, item_label):
rec_predict = rec_predict.cpu()
rec_predict = rec_predict[:, self.item_ids]
@@ -65,15 +68,17 @@ def rec_evaluate(self, rec_predict, item_label):
item = self.item_ids.index(item)
self.evaluator.rec_evaluate(rec_rank, item)

def conv_evaluate(self, prediction, response):
def conv_evaluate(self, prediction, response, file=None):
prediction = prediction.tolist()
response = response.tolist()
for p, r in zip(prediction, response):
p_str = ind2txt(p, self.ind2tok, self.end_token_idx)
r_str = ind2txt(r, self.ind2tok, self.end_token_idx)
if file:
file.write(f'{p_str}\t{r_str}\t')
self.evaluator.gen_evaluate(p_str, [r_str])

def step(self, batch, stage, mode):
def step(self, batch, stage, mode, file=None):
batch = [ele.to(self.device) for ele in batch]
if stage == 'pretrain':
info_loss = self.model.forward(batch, stage, mode)
@@ -108,7 +113,7 @@ def step(self, batch, stage, mode):
self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss))
else:
pred = self.model.forward(batch, stage, mode)
self.conv_evaluate(pred, batch[-1])
self.conv_evaluate(pred, batch[-1], file)
else:
raise

@@ -180,10 +185,68 @@ def train_conversation(self):
self.step(batch, stage='conv', mode='test')
self.evaluator.report(mode='test')

def test_recommendation(self):
self.init_optim(self.rec_optim_opt, self.model.parameters())

logger.info('[Recommendation Test]')
with torch.no_grad():
if self.rec_optim_opt.get('test_print_every_batch'):
rec_test_result_file_name = os.path.join(self.csv_path, 'rec.csv')
os.makedirs(os.path.dirname(rec_test_result_file_name), exist_ok=True)
with open(rec_test_result_file_name, 'w', encoding='utf-8', newline='') as f:
f.write('input context\tentities\twords\thit@1\tndcg@1\tmrr@1\t'
'hit@10\tndcg@10\tmrr@10\thit@50\tndcg@50\tmrr@50\n')
logger.info(f"[Write {rec_test_result_file_name}]")

for batch in self.test_dataloader.get_rec_data(1, shuffle=False, file=f):
self.evaluator.reset_metrics()
self.step(batch, stage='rec', mode='test')
self.evaluator.report(mode='test', file=f)
f.write('\n')
f.close()
else:
self.evaluator.reset_metrics()
for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):
self.step(batch, stage='rec', mode='test')
self.evaluator.report(mode='test')

def test_conversation(self):
if os.environ["CUDA_VISIBLE_DEVICES"] == '-1':
self.model.freeze_parameters()
else:
self.model.module.freeze_parameters()
self.init_optim(self.conv_optim_opt, self.model.parameters())

logger.info('[Conversation Test]')
with torch.no_grad():
if self.conv_optim_opt.get('test_print_every_batch'):
conv_test_result_file_name = os.path.join(self.csv_path, 'conv.csv')
os.makedirs(os.path.dirname(conv_test_result_file_name), exist_ok=True)
with open(conv_test_result_file_name, 'w', encoding='utf-8', newline='') as f:
f.write('input context\tentities\twords\tprediction\tresponse\tf1\tbleu@1\tbleu@2\tbleu@3\tbleu@4\t'
'greedy\taverage\textreme\tdist@1\tdist@2\tdist@3\tdist@4\n')
logger.info(f"[Write {conv_test_result_file_name}]")

for batch in self.test_dataloader.get_conv_data(1, shuffle=False, file=f):
self.evaluator.reset_metrics()
self.step(batch, stage='conv', mode='test', file=f)
self.evaluator.report(mode='test', file=f)
f.write('\n')
f.close()
else:
self.evaluator.reset_metrics()
for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):
self.step(batch, stage='conv', mode='test')
self.evaluator.report(mode='test')

def fit(self):
self.pretrain()
self.train_recommender()
self.train_conversation()

def test(self):
self.test_recommendation()
self.test_conversation()

def interact(self):
pass
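The result files land under test_result/<model>/<dataset>/ with one test conversation per row and the columns named in the headers above. A rough reader sketch (the ReDial path is illustrative; the quoting follows what get_data writes):

import csv

with open('test_result/kgsf/ReDial/rec.csv', encoding='utf-8', newline='') as fh:
    reader = csv.reader(fh, delimiter='\t')
    header = next(reader)  # column names written by test_recommendation
    for row in reader:
        context, entities, words = row[0], row[1], row[2]
        metrics = dict(zip(header[3:], row[3:]))  # hit@1, ndcg@1, mrr@1, ...
        print(metrics.get('hit@50'), context.splitlines()[:1])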
4 changes: 3 additions & 1 deletion run_crslab.py
@@ -33,6 +33,8 @@
help='use valid dataset to debug your system')
parser.add_argument('-i', '--interact', action='store_true',
help='interact with your system instead of training')
parser.add_argument('-t', '--test', action='store_true',
help='test with your saved system instead of training. Note: saved system is required')
parser.add_argument('-tb', '--tensorboard', action='store_true',
help='enable tensorboard to monitor train performance')
args, _ = parser.parse_known_args()
@@ -41,4 +43,4 @@
from crslab.quick_start import run_crslab

run_crslab(config, args.save_data, args.restore_data, args.save_system, args.restore_system, args.interact,
args.debug, args.tensorboard)
args.test, args.debug, args.tensorboard)
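In short: after a run that saved the system, re-running run_crslab.py with the new -t/--test flag switches into the test-only path and writes the per-conversation CSVs described above; the saved system is restored automatically because run_crslab forces restore_system to True whenever test is set.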