Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions crslab/data/dataset/redial/redial.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ def _load_raw_data(self):

def _load_vocab(self):
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))

# Add role tokens
# last_index = len(self.tok2ind)
# self.role_seeker_token_idx = last_index
# self.role_recommender_token_idx = last_index + 1
# self.tok2ind["__Seeker__"] = self.role_seeker_token_idx
# self.tok2ind["__Recommender__"] = self.role_recommender_token_idx

self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}

logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
Expand Down Expand Up @@ -185,6 +193,15 @@ def _augment_and_add(self, raw_conv_dict):
entity_set, word_set = set(), set()
for i, conv in enumerate(raw_conv_dict):
text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["movie"], conv["word"]

# Add role token in front of text_tokens
self.role_seeker_token_idx = 2459 # "seeker"
self.role_recommender_token_idx = 1755 # "recommender"
if conv['role'] == 'Seeker':
text_tokens.insert(0, self.role_seeker_token_idx)
else:
text_tokens.insert(0, self.role_recommender_token_idx)

if len(context_tokens) > 0:
conv_dict = {
"role": conv['role'],
Expand Down