Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ by our SIGDIAL 2018 paper: [Zero-Shot Dialog Generation with Cross-Domain Latent
See paper for details. The source code and data used for the paper can be found at [here](https://github.com/snakeztc/NeuralDialog-ZSDG).

## Prerequisites
- Python 2.7
- Python 3.6
- Numpy
- NLTK
- progressbar
- progressbar2


## Usage
Expand Down
8 changes: 4 additions & 4 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ for idx in 0 1 2 3 4 5 6 7 8 9; do


# # combine OTGY(slot values) file
if [ -f "../1500_data_fixed_0/r_w_b-OTGY.json"]; then
cp ../1500_data_fixed_0/r_w_b-OTGY.json ./
if [ -f "../1500_data_fixed_0/r_w_b-OTGY.json" ]; then
cp -f ../1500_data_fixed_0/r_w_b-OTGY.json ./
else
python3 ../combine_domain.py -data_path ./ -target_domain ''
fi

if [ -f "../1500_data_fixed_0/r_w_b_${data_size}m-OTGY.json"]; then
cp ../1500_data_fixed_0/r_w_b_${data_size}m-OTGY.json ./r_w_b_${data_size}m-OTGY.json
if [ -f "../1500_data_fixed_0/r_w_b_${data_size}m-OTGY.json" ]; then
cp -f ../1500_data_fixed_0/r_w_b_${data_size}m-OTGY.json ./r_w_b_${data_size}m-OTGY.json
else
python3 ../combine_domain.py -data_path ./
fi
Expand Down
8 changes: 7 additions & 1 deletion simdial/agent/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import OrderedDict
import numpy as np
import copy
import sys


class BeliefSlot(object):
Expand Down Expand Up @@ -59,7 +60,12 @@ def add_grounding(self, confirm_conf, disconfirm_conf, turn_id, target_value=Non
def get_maxconf_value(self):
if len(self.value_map) == 0:
return None
max_s, max_v = max([(s, v) for v, s in self.value_map.items()])
# Substitute 'None' with 0, as 'None' in Python 2 is smaller than any int
value_map = copy.deepcopy(self.value_map)
value_map = {k if k is not None else -1: v for k, v in value_map.items()}
_, max_v = max([(s, v) for v, s in value_map.items()])
if max_v == -1:
return None
return max_v

def max_conf(self):
Expand Down
12 changes: 6 additions & 6 deletions simdial/agent/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def reset_goal(self, sys_goals):

def __init__(self, domain, complexity):
super(User, self).__init__(domain, complexity)
self.goal_cnt = np.random.choice(complexity.multi_goals.keys(), p=complexity.multi_goals.values())
self.goal_cnt = np.random.choice(np.asarray(list(complexity.multi_goals.keys())), p=list(complexity.multi_goals.values()))
self.goal_ptr = 0
self.usr_constrains, self.sys_goals = self._sample_goal()
self.state = self.DialogState(self.sys_goals)
Expand Down Expand Up @@ -122,7 +122,7 @@ def _increment_goal(self):
else:
self.goal_ptr += 1
_, self.sys_goals = self._sample_goal()
change_key = np.random.choice(self.usr_constrains.keys())
change_key = np.random.choice(np.asarray(list(self.usr_constrains.keys())))
change_slot = self.domain.get_usr_slot(change_key)
old_value = self.usr_constrains[change_key]
old_value = -1 if old_value is None else old_value
Expand Down Expand Up @@ -163,8 +163,8 @@ def policy(self):
if slot_val == self.usr_constrains[slot_type] or self.usr_constrains[slot_type] is None:
return None
else:
strategy = np.random.choice(self.complexity.reject_style.keys(),
p=self.complexity.reject_style.values())
strategy = np.random.choice(np.asarray(list(self.complexity.reject_style.keys())),
p=list(self.complexity.reject_style.values()))
if strategy == "reject":
return Action(UserAct.DISCONFIRM, (slot_type, slot_val))
elif strategy == "reject+inform":
Expand Down Expand Up @@ -237,8 +237,8 @@ def policy(self):

elif self.domain.is_usr_slot(slot_type):
if len(self.domain.usr_slots) > 1:
num_informs = np.random.choice(self.complexity.multi_slots.keys(),
p=self.complexity.multi_slots.values(),
num_informs = np.random.choice(np.asarray(list(self.complexity.multi_slots.keys())),
p=list(self.complexity.multi_slots.values()),
replace=False)
if num_informs > 1:
candidates = [k for k, v in self.usr_constrains.items() if k != slot_type and v is not None]
Expand Down
2 changes: 1 addition & 1 deletion simdial/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def transmit(self, actions):
elif a.act == UserAct.INFORM:
if np.random.rand() > conf:
slot, value = a.parameters[0]
choices = range(self.dim_map[slot]) + [None]
choices = list(range(self.dim_map[slot])) + [None]
a.parameters[0] = (slot, np.random.choice(choices))

noisy_actions.append(a)
Expand Down
2 changes: 1 addition & 1 deletion simdial/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def sample_unique_row(self):
:return: a unique row in the searchable table
"""
unique_rows = np.unique(self.table, axis=0)
idxes = range(len(unique_rows))
idxes = list(range(len(unique_rows)))
np.random.shuffle(idxes)
return unique_rows[idxes[0]]

Expand Down
6 changes: 3 additions & 3 deletions simdial/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def pprint(dialogs, in_json, domain_spec, output_file=None):
:param dialogs: a list of dialogs generated
:param output_file: None if print to STDOUT. Otherwise write the file in the path
"""
f = sys.stdout if output_file is None else open(output_file, "wb")
f = sys.stdout if output_file is None else open(output_file, "w")

if in_json:
# combo = {'dialogs': dialogs, 'meta': domain_spec.to_dict()}
Expand Down Expand Up @@ -68,7 +68,7 @@ def print_db(database, in_json, domain_spec, output_file=None):

:param database: a database class generated in database.py
"""
f = sys.stdout if output_file is None else open(output_file, "wb")
f = sys.stdout if output_file is None else open(output_file, "w")

if in_json:
combo = []
Expand All @@ -93,7 +93,7 @@ def print_db(database, in_json, domain_spec, output_file=None):

@staticmethod
def print_OTGY(domain_spec, in_json, output_file=None):
f = sys.stdout if output_file is None else open(output_file, "wb")
f = sys.stdout if output_file is None else open(output_file, "w")

if in_json:
info_dict = {}
Expand Down