diff --git a/README.md b/README.md index 84d2468..6d20098 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,10 @@ by our SIGDIAL 2018 paper: [Zero-Shot Dialog Generation with Cross-Domain Latent See paper for details. The source code and data used for the paper can be found at [here](https://github.com/snakeztc/NeuralDialog-ZSDG). ## Prerequisites - - Python 2.7 + - Python 3.6 - Numpy - NLTK - - progressbar + - progressbar2 ## Usage diff --git a/run.sh b/run.sh index bb7768d..f8be446 100755 --- a/run.sh +++ b/run.sh @@ -27,14 +27,14 @@ for idx in 0 1 2 3 4 5 6 7 8 9; do # # combine OTGY(slot values) file - if [ -f "../1500_data_fixed_0/r_w_b-OTGY.json"]; then - cp ../1500_data_fixed_0/r_w_b-OTGY.json ./ + if [ -f "../1500_data_fixed_0/r_w_b-OTGY.json" ]; then + cp -f ../1500_data_fixed_0/r_w_b-OTGY.json ./ else python3 ../combine_domain.py -data_path ./ -target_domain '' fi - if [ -f "../1500_data_fixed_0/r_w_b_${data_size}m-OTGY.json"]; then - cp ../1500_data_fixed_0/r_w_b_${data_size}m-OTGY.json ./r_w_b_${data_size}m-OTGY.json + if [ -f "../1500_data_fixed_0/r_w_b_${data_size}m-OTGY.json" ]; then + cp -f ../1500_data_fixed_0/r_w_b_${data_size}m-OTGY.json ./r_w_b_${data_size}m-OTGY.json else python3 ../combine_domain.py -data_path ./ fi diff --git a/simdial/agent/system.py b/simdial/agent/system.py index efbe6b8..c6f3da1 100644 --- a/simdial/agent/system.py +++ b/simdial/agent/system.py @@ -6,6 +6,7 @@ from collections import OrderedDict import numpy as np import copy +import sys class BeliefSlot(object): @@ -59,7 +60,12 @@ def add_grounding(self, confirm_conf, disconfirm_conf, turn_id, target_value=Non def get_maxconf_value(self): if len(self.value_map) == 0: return None - max_s, max_v = max([(s, v) for v, s in self.value_map.items()]) + # Substitute 'None' with 0, as 'None' in Python 2 is smaller than any int + value_map = copy.deepcopy(self.value_map) + value_map = {k if k is not None else -1: v for k, v in value_map.items()} + _, max_v = max([(s, v) for v, s in value_map.items()]) + if max_v == -1: + return None return max_v def max_conf(self): diff --git a/simdial/agent/user.py b/simdial/agent/user.py index 2415101..35ad758 100644 --- a/simdial/agent/user.py +++ b/simdial/agent/user.py @@ -73,7 +73,7 @@ def reset_goal(self, sys_goals): def __init__(self, domain, complexity): super(User, self).__init__(domain, complexity) - self.goal_cnt = np.random.choice(complexity.multi_goals.keys(), p=complexity.multi_goals.values()) + self.goal_cnt = np.random.choice(np.asarray(list(complexity.multi_goals.keys())), p=list(complexity.multi_goals.values())) self.goal_ptr = 0 self.usr_constrains, self.sys_goals = self._sample_goal() self.state = self.DialogState(self.sys_goals) @@ -122,7 +122,7 @@ def _increment_goal(self): else: self.goal_ptr += 1 _, self.sys_goals = self._sample_goal() - change_key = np.random.choice(self.usr_constrains.keys()) + change_key = np.random.choice(np.asarray(list(self.usr_constrains.keys()))) change_slot = self.domain.get_usr_slot(change_key) old_value = self.usr_constrains[change_key] old_value = -1 if old_value is None else old_value @@ -163,8 +163,8 @@ def policy(self): if slot_val == self.usr_constrains[slot_type] or self.usr_constrains[slot_type] is None: return None else: - strategy = np.random.choice(self.complexity.reject_style.keys(), - p=self.complexity.reject_style.values()) + strategy = np.random.choice(np.asarray(list(self.complexity.reject_style.keys())), + p=list(self.complexity.reject_style.values())) if strategy == "reject": return Action(UserAct.DISCONFIRM, (slot_type, slot_val)) elif strategy == "reject+inform": @@ -237,8 +237,8 @@ def policy(self): elif self.domain.is_usr_slot(slot_type): if len(self.domain.usr_slots) > 1: - num_informs = np.random.choice(self.complexity.multi_slots.keys(), - p=self.complexity.multi_slots.values(), + num_informs = np.random.choice(np.asarray(list(self.complexity.multi_slots.keys())), + p=list(self.complexity.multi_slots.values()), replace=False) if num_informs > 1: candidates = [k for k, v in self.usr_constrains.items() if k != slot_type and v is not None] diff --git a/simdial/channel.py b/simdial/channel.py index 741966d..674fe2b 100644 --- a/simdial/channel.py +++ b/simdial/channel.py @@ -45,7 +45,7 @@ def transmit(self, actions): elif a.act == UserAct.INFORM: if np.random.rand() > conf: slot, value = a.parameters[0] - choices = range(self.dim_map[slot]) + [None] + choices = list(range(self.dim_map[slot])) + [None] a.parameters[0] = (slot, np.random.choice(choices)) noisy_actions.append(a) diff --git a/simdial/database.py b/simdial/database.py index e0959a0..a56f543 100644 --- a/simdial/database.py +++ b/simdial/database.py @@ -70,7 +70,7 @@ def sample_unique_row(self): :return: a unique row in the searchable table """ unique_rows = np.unique(self.table, axis=0) - idxes = range(len(unique_rows)) + idxes = list(range(len(unique_rows))) np.random.shuffle(idxes) return unique_rows[idxes[0]] diff --git a/simdial/generator.py b/simdial/generator.py index cecc944..ee41bbc 100644 --- a/simdial/generator.py +++ b/simdial/generator.py @@ -39,7 +39,7 @@ def pprint(dialogs, in_json, domain_spec, output_file=None): :param dialogs: a list of dialogs generated :param output_file: None if print to STDOUT. Otherwise write the file in the path """ - f = sys.stdout if output_file is None else open(output_file, "wb") + f = sys.stdout if output_file is None else open(output_file, "w") if in_json: # combo = {'dialogs': dialogs, 'meta': domain_spec.to_dict()} @@ -68,7 +68,7 @@ def print_db(database, in_json, domain_spec, output_file=None): :param database: a database class generated in database.py """ - f = sys.stdout if output_file is None else open(output_file, "wb") + f = sys.stdout if output_file is None else open(output_file, "w") if in_json: combo = [] @@ -93,7 +93,7 @@ def print_db(database, in_json, domain_spec, output_file=None): @staticmethod def print_OTGY(domain_spec, in_json, output_file=None): - f = sys.stdout if output_file is None else open(output_file, "wb") + f = sys.stdout if output_file is None else open(output_file, "w") if in_json: info_dict = {}