diff --git a/crslab/data/dataset/redial/redial.py b/crslab/data/dataset/redial/redial.py index c84565b..9089a95 100644 --- a/crslab/data/dataset/redial/redial.py +++ b/crslab/data/dataset/redial/redial.py @@ -108,6 +108,14 @@ def _load_raw_data(self): def _load_vocab(self): self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) + + # Add role tokens + # last_index = len(self.tok2ind) + # self.role_seeker_token_idx = last_index + # self.role_recommender_token_idx = last_index + 1 + # self.tok2ind["__Seeker__"] = self.role_seeker_token_idx + # self.tok2ind["__Recommender__"] = self.role_recommender_token_idx + self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") @@ -185,6 +193,15 @@ def _augment_and_add(self, raw_conv_dict): entity_set, word_set = set(), set() for i, conv in enumerate(raw_conv_dict): text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["movie"], conv["word"] + + # Add role token in front of text_tokens + self.role_seeker_token_idx = 2459 # "seeker" + self.role_recommender_token_idx = 1755 # "recommender" + if conv['role'] == 'Seeker': + text_tokens.insert(0, self.role_seeker_token_idx) + else: + text_tokens.insert(0, self.role_recommender_token_idx) + if len(context_tokens) > 0: conv_dict = { "role": conv['role'],