From 83d80c5a19326a516b4ab54339c27f8cac349693 Mon Sep 17 00:00:00 2001 From: Raul Puri Date: Thu, 14 Jun 2018 14:31:35 -0700 Subject: [PATCH 1/4] Update lazy_loader.py --- data_utils/lazy_loader.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/data_utils/lazy_loader.py b/data_utils/lazy_loader.py index 0049636..3fdc876 100644 --- a/data_utils/lazy_loader.py +++ b/data_utils/lazy_loader.py @@ -29,9 +29,13 @@ def make_lazy(path, strs, data_type='data'): datapath = os.path.join(lazypath, data_type) lenpath = os.path.join(lazypath, data_type+'.len.pkl') if not torch.distributed._initialized or torch.distributed.get_rank() == 0: - with open(datapath, 'w') as f: - f.write(''.join(strs)) - str_ends = list(accumulate(map(len, strs))) + with open(datapath, 'wb') as f: + str_ends = [] + str_cnt = 0 + for s in strs: + f.write(s.encode('utf-8')) + str_cnt += len(s) + str_ends.append(str_cnt) pkl.dump(str_ends, open(lenpath, 'wb')) else: while not os.path.exists(lenpath): @@ -53,7 +57,7 @@ def __init__(self, path, data_type='data', mem_map=False): lazypath = get_lazy_path(path) datapath = os.path.join(lazypath, data_type) #get file where array entries are concatenated into one big string - self._file = open(datapath, 'r') + self._file = open(datapath, 'rb') self.file = self._file #memory map file if necessary self.mem_map = mem_map @@ -61,6 +65,7 @@ def __init__(self, path, data_type='data', mem_map=False): self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ) lenpath = os.path.join(lazypath, data_type+'.len.pkl') self.ends = pkl.load(open(lenpath, 'rb')) + self.read_lock = Lock() def __getitem__(self, index): """read file and splice strings based on string ending array `ends` """ @@ -88,6 +93,7 @@ def file_read(self, start=0, end=None): """read specified portion of file""" #TODO: Solve race condition #Seek to start of file read + self.read_lock.acquire() self.file.seek(start) ##### Getting context-switched here #read to end of file if no end point provided @@ -96,8 +102,10 @@ def file_read(self, start=0, end=None): #else read amount needed to reach end point else: rtn = self.file.read(end-start) + self.read_lock.release() #TODO: @raulp figure out mem map byte string bug #if mem map'd need to decode byte string to string + rtn = rtn.decode('utf-8') if self.mem_map: rtn = rtn.decode('unicode_escape') return rtn From 41ea6800fcc63f188192da76edae13efecd5bcbd Mon Sep 17 00:00:00 2001 From: Raul Puri Date: Fri, 15 Jun 2018 07:04:01 -0700 Subject: [PATCH 2/4] added back import statement --- data_utils/lazy_loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/data_utils/lazy_loader.py b/data_utils/lazy_loader.py index 3fdc876..6a67bec 100644 --- a/data_utils/lazy_loader.py +++ b/data_utils/lazy_loader.py @@ -3,6 +3,7 @@ import pickle as pkl import time from itertools import accumulate +from threading import Lock import torch From fe970f1c9e3e2a9971655451a5fb1771377ddda3 Mon Sep 17 00:00:00 2001 From: Raul Puri Date: Fri, 15 Jun 2018 11:14:17 -0700 Subject: [PATCH 3/4] change lazy loader read decoding --- data_utils/lazy_loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/data_utils/lazy_loader.py b/data_utils/lazy_loader.py index 6a67bec..110159c 100644 --- a/data_utils/lazy_loader.py +++ b/data_utils/lazy_loader.py @@ -106,7 +106,8 @@ def file_read(self, start=0, end=None): self.read_lock.release() #TODO: @raulp figure out mem map byte string bug #if mem map'd need to decode byte string to string - rtn = rtn.decode('utf-8') + #rtn = rtn.decode('utf-8') + rtn = str(rtn) if self.mem_map: rtn = rtn.decode('unicode_escape') return rtn From dbeb18746d898c8b152dd3cfeaaf88f9b9ac765f Mon Sep 17 00:00:00 2001 From: Khabbab Nazar Date: Wed, 20 Jun 2018 15:04:51 +0530 Subject: [PATCH 4/4] Real-time word by word heatmap generation of user input. --- generate.py | 83 +++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 61 insertions(+), 22 deletions(-) diff --git a/generate.py b/generate.py index 94c1d2d..44363f2 100644 --- a/generate.py +++ b/generate.py @@ -24,6 +24,7 @@ import seaborn as sns sns.set_style({'font.family': 'monospace'}) +import sys, termios, tty, cv2 parser = argparse.ArgumentParser(description='PyTorch Sentiment Discovery Generation/Visualization') @@ -43,11 +44,11 @@ parser.add_argument('--tied', action='store_true', help='tie the word embedding and softmax weights') parser.add_argument('--load_model', type=str, default='model.pt', - help='model checkpoint to use') + help='model checkpoint to use') #use imdb_clf.pt model provided in the readme. parser.add_argument('--save', type=str, default='generated.txt', help='output file for generated text') parser.add_argument('--gen_length', type=int, default='1000', - help='number of tokens to generate') + help='number of tokens to generate') #use --gen_length -1 parser.add_argument('--seed', type=int, default=-1, help='random seed') parser.add_argument('--temperature', type=float, default=1.0, @@ -63,8 +64,9 @@ help='generates heatmap of main neuron activation [not working yet]') parser.add_argument('--overwrite', type=float, default=None, help='Overwrite value of neuron s.t. generated text reads as a +1/-1 classification') -parser.add_argument('--text', default='', - help='warm up generation with specified text first') +#dont need --text arg. +#parser.add_argument('--text', default='', +# help='warm up generation with specified text first') args = parser.parse_args() args.data_size = 256 @@ -109,7 +111,7 @@ def get_neuron_and_polarity(sd, neuron): return neuron, 1 if neuron is None: val, neuron = torch.max(torch.abs(weight[0].float()), 0) - neuron = neuron[0] + neuron = neuron.item() val = weight[0][neuron] if val >= 0: polarity = 1 @@ -196,13 +198,24 @@ def make_heatmap(text, values, save=None, polarity=1): plt.figure(figsize=(cell_width*n_limit, cell_height*num_rows)) hmap=sns.heatmap(values, annot=text, mask=mask, fmt='', vmin=-1, vmax=1, cmap='RdYlGn', xticklabels=False, yticklabels=False, cbar=False) - plt.tight_layout() + #plt.tight_layout() if save is not None: plt.savefig(save) # clear plot for next graph since we returned `hmap` plt.clf() + plt.close() return hmap +#return each character entered by the user +def getchar(): + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(sys.stdin.fileno()) + ch = sys.stdin.read(1) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + return ch neuron, polarity = get_neuron_and_polarity(sd, args.neuron) neuron = neuron if args.visualize or args.overwrite is not None else None @@ -224,19 +237,45 @@ def make_heatmap(text, values, save=None, polarity=1): outchrs = [] outvals = [] -#with open(args.save, 'w') as outf: -with torch.no_grad(): - if args.text != '': - chrs, vals = process_text(args.text, model, input, args.temperature, neuron, mask, args.overwrite, polarity) - outchrs += chrs - outvals += vals - chrs, vals = generate(args.gen_length, model, input, args.temperature, neuron, mask, args.overwrite, polarity) - outchrs += chrs - outvals += vals -outstr = ''.join(outchrs) -print(outstr) -with open(args.save, 'w') as f: - f.write(outstr) - -if args.visualize: - make_heatmap(outchrs, outvals, os.path.splitext(args.save)[0]+'.png', polarity) +input_chars = [] +text = "" +print("Enter Text:") + +#In this loop, word by word user input is processed for heatmap generation. +#To exit from this loop, press esc. +while True: + sys.stdout.flush() + c = getchar() + + if (c == "\x1b"): + print("\n") + exit(0) + elif (c == "\x7f" and len(input_chars) > 0): + input_chars.pop() + sys.stdout.write("\b \b") + continue + elif (c=='\r'): + print() + continue + print(c,end='') + input_chars.append(c) + text = ''.join(input_chars) + + if (c == " " or c == "." or c == "!" or c == "@" or c == "#" or c == "$" or c == "%" or c == "&" or c == "*" or c == "?"): + #with open(args.save, 'w') as outf: + with torch.no_grad(): + if text != '': + chrs, vals = process_text(text, model, input, args.temperature, neuron, mask, args.overwrite, polarity) + outchrs += chrs + outvals += vals + del input_chars[:] + chrs, vals = generate(args.gen_length, model, input, args.temperature, neuron, mask, args.overwrite, polarity) + outchrs += chrs + outvals += vals + if args.visualize: + make_heatmap(outchrs, outvals, os.path.splitext(args.save)[0]+'.png', polarity) + output_img = cv2.imread(os.path.splitext(args.save)[0]+'.png') + cv2.imshow("output",output_img) + cv2.waitKey(1) + if 0xFF == ord('q'): + sys.exit()