From b2fbc89af78b21e1cde0e3f338ec06a9d0dd59e6 Mon Sep 17 00:00:00 2001 From: gml9812 Date: Wed, 17 Nov 2021 00:25:47 +0900 Subject: [PATCH] =?UTF-8?q?interaction=20=EA=B8=B0=EB=8A=A5=20=EC=B6=94?= =?UTF-8?q?=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crslab/data/dataloader/kgsf.py | 14 ++++++ crslab/system/kgsf.py | 84 +++++++++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/crslab/data/dataloader/kgsf.py b/crslab/data/dataloader/kgsf.py index 6bbcac4..820deb8 100644 --- a/crslab/data/dataloader/kgsf.py +++ b/crslab/data/dataloader/kgsf.py @@ -127,3 +127,17 @@ def conv_batchify(self, batch): def policy_batchify(self, *args, **kwargs): pass + + def conv_interact(self,data): + context_tokens = [truncate(merge_utt(data['context_tokens']), self.context_truncate, truncate_tail=False)] + context_entities = [truncate(data['context_entities'], self.entity_truncate, truncate_tail=False)] + context_words = [truncate(data['context_words'], self.word_truncate, truncate_tail=False)] + response = [add_start_end_token_idx(truncate(data['response'], self.response_truncate - 2), + start_token_idx=self.start_token_idx, + end_token_idx=self.end_token_idx)] + + return (padded_tensor(context_tokens, self.pad_token_idx, pad_tail=False), + padded_tensor(context_entities, self.pad_entity_idx, pad_tail=False), + padded_tensor(context_words, self.pad_word_idx, pad_tail=False), + padded_tensor(response, self.pad_token_idx)) + diff --git a/crslab/system/kgsf.py b/crslab/system/kgsf.py index 7f7b2a6..907ae5f 100644 --- a/crslab/system/kgsf.py +++ b/crslab/system/kgsf.py @@ -186,4 +186,86 @@ def fit(self): self.train_conversation() def interact(self): - pass + self.init_interact() + input_text = self.get_input("en") + while not self.finished: + # Process the natural-language input and build the model input tensors + KGSF_input = self.process_input(input_text, 'conv') + + # Feed the processed tensors to the model and take its output. 
+ preds = self.model.forward(KGSF_input, 'conv', 'test').tolist()[0] + + # Translate the model output back into natural language + p_str = ind2txt(preds, self.ind2tok, self.end_token_idx) + + # A lot of time is spent here -- why? + token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id(p_str, 'conv') + + # Revise later if needed + self.update_context('conv', token_ids, entity_ids, movie_ids, word_ids) + + print(f"[Response]:\n{p_str}") + input_text = self.get_input("en") + + def init_interact(self): + self.finished = False + self.context = { + 'rec': {}, + 'conv': {} + } + for key in self.context: + self.context[key]['context_tokens'] = [] + self.context[key]['response'] = [] + self.context[key]['context_entities'] = [] + self.context[key]['context_words'] = [] + self.context[key]['context_items'] = [] + self.context[key]['items'] = [] + self.context[key]['entity_set'] = set() + self.context[key]['word_set'] = set() + + def process_input(self, input_text, stage): + token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id(input_text, stage) + + self.update_context(stage, token_ids, entity_ids, movie_ids, word_ids) + + data = {'role': 'Seeker', 'context_tokens': self.context[stage]['context_tokens'], + 'response': self.context[stage]['response'], + 'context_entities': self.context[stage]['context_entities'], + 'context_words': self.context[stage]['context_words'], + 'context_items': self.context[stage]['context_items'], + 'items': self.context[stage]['context_items']} + + dataloader = get_dataloader(self.opt, data, self.vocab) + + if stage == 'conv': + data = dataloader.conv_interact(data) + + data = [ele.to(self.device) if isinstance(ele, torch.Tensor) else ele for ele in data] + return data + + def convert_to_id(self, text, stage): + # The tokens are the text split into individual words. 
ex: 'jack is having dinner' => ['jac','is','having','dinner'] + tokens = self.tokenize(text, 'nltk') + + # Temporary + """ + if tokens[0] == '__start__': + del tokens[0] + """ + + # When 'i like the movie avengers' is entered: + #entities:['', ''] + #words:['juliet', 'saintlike', 'buy_presents_for_others', 'movie', 'avengers'] + + # The link function takes an enormous amount of time, proportional to the number of tokens. Why? + entities = self.link(tokens, self.side_data['entity_kg']['entity']) + words = self.link(tokens, self.side_data['word_kg']['entity']) + + token_ids = [self.vocab['tok2ind'].get(token, self.vocab['unk']) for token in tokens] + entity_ids = [self.vocab['entity2id'][entity] for entity in entities if + entity in self.vocab['entity2id']] + + movie_ids = [entity_id for entity_id in entity_ids if entity_id in self.item_ids] + word_ids = [self.vocab['word2id'][word] for word in words if word in self.vocab['word2id']] + + return token_ids, entity_ids, movie_ids, word_ids \ No newline at end of file