From faf9955593f6f323c86ea96539103cb648b4627a Mon Sep 17 00:00:00 2001 From: Nick Doiron Date: Mon, 16 Jan 2023 19:27:34 -0700 Subject: [PATCH 1/2] time control tests --- .gitignore | 3 +- README.md | 5 +- decoder_ring/all-demo.py | 113 +++++ decoder_ring/basic-demo.py | 65 --- decoder_ring/decoders.py | 25 + decoder_ring/timecontrol/decode.py | 728 +++++++++++++++++++++++++++++ decoder_ring/timecontrol/encode.py | 625 +++++++++++++++++++++++++ scratch.md | 9 +- 8 files changed, 1503 insertions(+), 70 deletions(-) create mode 100644 decoder_ring/all-demo.py delete mode 100644 decoder_ring/basic-demo.py create mode 100644 decoder_ring/timecontrol/decode.py create mode 100644 decoder_ring/timecontrol/encode.py diff --git a/.gitignore b/.gitignore index 4c3cdee..9dfa17e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ __pycache__ build/ dist/ -*.egg-info \ No newline at end of file +*.egg-info +lightning_logs/ diff --git a/README.md b/README.md index 617b737..f781c80 100644 --- a/README.md +++ b/README.md @@ -23,8 +23,11 @@ I would like to expand on the documentation in all of the decoder options, links - ContrastiveSearch (params: random_seed, penalty_alpha, top_k) - GreedyDecoder - RandomSampling (params: random_seed) +- TimeControl (params: trained_encoder, random_seed) - TypicalDecoder (params: random_seed, typical_p) +Modules: TimeControl requires: Datasets, PyTorch Lightning, TQDM + ### Writer Examples (text input and output) ```python @@ -73,4 +76,4 @@ typical_output_2 = decoder3.generate_text( ## License -Apache license for compatibility with the Transformers library \ No newline at end of file +Apache license for compatibility with the Transformers library diff --git a/decoder_ring/all-demo.py b/decoder_ring/all-demo.py new file mode 100644 index 0000000..931b182 --- /dev/null +++ b/decoder_ring/all-demo.py @@ -0,0 +1,113 @@ +import os + +from transformers import AutoModelForCausalLM, AutoTokenizer +import pytorch_lightning as pl # for TimeControl + +from decoders import ( + BasicWriter, + BeamSearch, + ContrastiveSearch, + GreedyDecoder, + RandomSampling, + TimeControl, + TypicalDecoder, +) +from timecontrol.decode import GPT2TimeLMHeadModel +from timecontrol.encode import BrownianBridgeSystem + +## Writer test +basic = BasicWriter("gpt2", RandomSampling) +writer_output = basic.write_text( + prompt="Hello, my name is", + max_length=20, +) +print(writer_output) + +## HuggingFace tokenized string +# model = AutoModelForCausalLM.from_pretrained("gpt2") +tokenizer = AutoTokenizer.from_pretrained("gpt2") +content = tokenizer.encode("Hello, my name is", return_tensors="pt") + +## Decoder tests +# decoder1 = GreedyDecoder(model) +# greedy_output = decoder1.generate_text( +# prompt=content, +# max_length=20, +# ) +# txtop = tokenizer.decode(greedy_output[0], skip_special_tokens=True) +# print(txtop) + +# without random_seed, or set_random_seed, this will call logging.warn +# decoder2 = RandomSampling(model, random_seed=603) +# sampling_output = decoder2.generate_text( +# prompt=content, +# max_length=20, +# ) +# txtop2 = tokenizer.decode(sampling_output[0], skip_special_tokens=True) +# print(txtop2) +# +# decoder3 = TypicalDecoder(model, random_seed=603, typical_p=0.4) +# typical_output = decoder3.generate_text( +# prompt=content, +# max_length=20, +# ) +# txtop3 = tokenizer.decode(typical_output[0], skip_special_tokens=True) +# print(txtop3) +# +# decoder4 = ContrastiveSearch(model, random_seed=603, penalty_alpha=0.4, top_k=4) +# contrastive_output = decoder4.generate_text( +# 
prompt=content, +# max_length=20, +# ) +# txtop4 = tokenizer.decode(contrastive_output[0], skip_special_tokens=True) +# print(txtop4) +# +# decoder5 = BeamSearch(model, early_stopping=True, num_beams=3) +# beam_output = decoder5.generate_text( +# prompt=content, +# max_length=20, +# ) +# txtop5 = tokenizer.decode(beam_output[0], skip_special_tokens=True) +# print(txtop5) + +tc_model = GPT2TimeLMHeadModel.from_pretrained("gpt2") +trainer = pl.Trainer( + # gpus=1, + max_epochs=1, + min_epochs=1, +) +trainer.fit( + BrownianBridgeSystem( + # params via https://github.com/rosewang2008/language_modeling_via_stochastic_processes/blob/main/language_modeling_via_stochastic_processes/config/encoder/brownian.yaml + { + "data_params": { + # "k": 5, + "name": "recipe", + "path": "~/Downloads/recipe/", + }, + "loss_params": { + "name": "simclr", + }, + "model_params": { + "eps": 1e-6, + "hidden_size": 128, + "latent_dim": 32, + "name": "gpt2", + }, + "optim_params": { + "batch_size": 32, + "learning_rate": 0.0001, + "momentum": 0.9, + }, + } + ) +) +trainer.save("./tcencoder") + +decoder6 = TimeControl(tc_model, encoder=trainer.model, random_seed=22) +tc_output = decoder6.generate_text( + prompt=content, + max_length=20, +) +txtop6 = tokenizer.decode(tc_output[0], skip_special_tokens=True) +print(txtop6) diff --git a/decoder_ring/basic-demo.py b/decoder_ring/basic-demo.py deleted file mode 100644 index aa8a7b8..0000000 --- a/decoder_ring/basic-demo.py +++ /dev/null @@ -1,65 +0,0 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer - -from decoders import ( - BasicWriter, - BeamSearch, - ContrastiveSearch, - GreedyDecoder, - RandomSampling, - TypicalDecoder, -) - -## Writer test -basic = BasicWriter("gpt2", RandomSampling) -writer_output = basic.write_text( - prompt="Hello, my name is", - max_length=20, -) -print(writer_output) - -## HuggingFace tokenized string -model = AutoModelForCausalLM.from_pretrained("gpt2") -tokenizer = AutoTokenizer.from_pretrained("gpt2") -content = tokenizer.encode("Hello, my name is", return_tensors="pt") - -## Decoder tests -decoder1 = GreedyDecoder(model) -greedy_output = decoder1.generate_text( - prompt=content, - max_length=20, -) -txtop = tokenizer.decode(greedy_output[0], skip_special_tokens=True) -print(txtop) - -# without random_seed, or set_random_seed, this will call logging.warn -decoder2 = RandomSampling(model, random_seed=603) -sampling_output = decoder2.generate_text( - prompt=content, - max_length=20, -) -txtop2 = tokenizer.decode(sampling_output[0], skip_special_tokens=True) -print(txtop2) - -decoder3 = TypicalDecoder(model, random_seed=603, typical_p=0.4) -typical_output = decoder3.generate_text( - prompt=content, - max_length=20, -) -txtop3 = tokenizer.decode(typical_output[0], skip_special_tokens=True) -print(txtop3) - -decoder4 = ContrastiveSearch(model, random_seed=603, penalty_alpha=0.4, top_k=4) -contrastive_output = decoder4.generate_text( - prompt=content, - max_length=20, -) -txtop4 = tokenizer.decode(contrastive_output[0], skip_special_tokens=True) -print(txtop4) - -decoder5 = BeamSearch(model, early_stopping=True, num_beams=3) -beam_output = decoder5.generate_text( - prompt=content, - max_length=20, -) -txtop5 = tokenizer.decode(beam_output[0], skip_special_tokens=True) -print(txtop5) diff --git a/decoder_ring/decoders.py b/decoder_ring/decoders.py index 17c0553..6cf293a 100644 --- a/decoder_ring/decoders.py +++ b/decoder_ring/decoders.py @@ -2,7 +2,9 @@ from random import randint from typing import Optional, Type, Union +from 
timecontrol.decode import GPT2TimeLMHeadModel from torch import Tensor, LongTensor +from torch.nn import Module from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationMixin, set_seed GenerationOutput = LongTensor @@ -248,6 +250,29 @@ def generate_text( ) +class TimeControl(RandomSampling): + def __init__( + self, + model: GPT2TimeLMHeadModel, + trained_encoder: Module, + random_seed: Optional[int] = None, + ) -> None: + super().__init__(model, random_seed=random_seed) + self.encoder = trained_encoder + + def generate_text( + self, + prompt: Optional[Tensor], + max_length: int = 100, + ) -> GenerationOutput: + self.validate_params() + return self.model.generate( + prompt, + do_sample=True, + max_length=max_length, + ) + + AnyDecoderMagicClass = Union[ Type[BeamSearch], Type[BeamSearchWithSampling], diff --git a/decoder_ring/timecontrol/decode.py b/decoder_ring/timecontrol/decode.py new file mode 100644 index 0000000..8317403 --- /dev/null +++ b/decoder_ring/timecontrol/decode.py @@ -0,0 +1,728 @@ +# Code copied from https://github.com/rosewang2008/language_modeling_via_stochastic_processes +# Language modeling via stochastic processes (ICLR Oral 2022) +# https://arxiv.org/abs/2203.11370 +""" +@misc{https://doi.org/10.48550/arxiv.2203.11370, + doi = {10.48550/ARXIV.2203.11370}, + url = {https://arxiv.org/abs/2203.11370}, + author = {Wang, Rose E and Durmus, Esin and Goodman, Noah and Hashimoto, Tatsunori}, + keywords = {Computation and Language (cs.CL), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences}, + title = {Language modeling via stochastic processes}, + publisher = {arXiv}, + year = {2022}, + copyright = {Creative Commons Attribution 4.0 International} +} +""" + +from typing import Tuple + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss +from transformers import GPT2PreTrainedModel, GPT2Tokenizer +from transformers.models.gpt2.modeling_gpt2 import GPT2Block + + +class GPT2TimeModel(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = ["attn.masked_bias"] + + def __init__(self, config, model_name): + super().__init__(config) + self.embed_dim = config.hidden_size + + if not hasattr(config, "use_contrastive_embeddings"): + config.use_contrastive_embeddings = False + + if hasattr(config, "cl_latent_dim") and config.cl_latent_dim is not None: + self.cl2e = nn.Linear(config.cl_latent_dim, self.embed_dim) + + self.cl_tokenizer = GPT2Tokenizer.from_pretrained(model_name) + self.cl_tokenizer.pad_token = self.cl_tokenizer.eos_token + self.cl_end_token = self.cl_tokenizer.eos_token_id + + try: + MAX_NUM_SECTIONS = config.max_num_sections + except: + # Error is hit for older models for toy wikisection setup. + MAX_NUM_SECTIONS = 4 + # NOTE: batch_size 1 + self.section_onehot = torch.FloatTensor(1, MAX_NUM_SECTIONS) + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.section2e = nn.Embedding(MAX_NUM_SECTIONS, self.embed_dim) + # num sections + 1 null embedding. 
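+        # (the extra row gives section_ids a spare "null" index; forward() selects
+        # sectionNull2e instead of section2e when config.use_section_null is set)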
+ self.sectionNull2e = nn.Embedding(MAX_NUM_SECTIONS + 1, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList( + [GPT2Block(config) for _ in range(config.num_hidden_layers)] + ) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + self._config = config + + # For doing eval generation + self._transition_cl = False + self._cur_cl_idx = 0 + self._has_reset = False + # For doing eval generation + self._transition_section = False + self._cur_section_idx = 0 + + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = ( + "cpu" + if "cpu" in self.device_map.keys() + else "cuda:" + str(min(self.device_map.keys())) + ) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + self.section2e = self.section2e.to(self.first_device) + + if self._config.use_contrastive_embeddings: + self.cl2e = self.cl2e.to(self.first_device) + # self.cl_model = self.cl_model.to(self.first_device) + self.cl_tokenizer = self.cl_tokenizer.to(self.first_device) + + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + self.section2e = self.section2e.to("cpu") + if self._config.use_contrastive_embeddings: + self.cl2e = self.cl2e.to("cpu") + # self.cl_model = self.cl_model.to("cpu") + self.cl_tokenizer = self.cl_tokenizer.to("cpu") + + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def cl_tokenize_text(self, text): + output = self.cl_tokenizer( + text, + padding=True, + return_tensors="pt", + ) + input_ids = output["input_ids"] # .squeeze(0) + attention_mask = output["attention_mask"] # .squeeze(0) + eos_input_ids = torch.tensor([[self.cl_end_token] * input_ids.shape[0]]) + eos_attention = torch.tensor([[0] * input_ids.shape[0]]) + input_ids = torch.cat((input_ids, eos_input_ids.T), dim=1) + attention_mask = torch.cat((attention_mask, eos_attention.T), dim=1) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + return input_ids.to(device), attention_mask.to(device) + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + def _get_cl_embeddings(self, raw_text, cl_feats, seq_cl_feats, input_ids, seq_len): + # NOTE assuming batch size 1 + generated_by_raw_text = False + + if seq_cl_feats is not None and cl_feats is not None: # Used in evaluation + if input_ids.shape[0] > 1 and ( + not isinstance(self._transition_cl, list) + or not isinstance(self._cur_cl_idx, list) + ): # beam + self._transition_cl = [False] * input_ids.shape[0] + self._cur_cl_idx = [0] * input_ids.shape[0] + cl_feats = cl_feats.expand(input_ids.shape[0], cl_feats.shape[-1]) + self._last_beam_cl_feats = cl_feats + + else: # non beam + if ( + "wikisection" in self.config.dataset_name + or "wikihow" in self.config.dataset_name + or "stories" in self.config.dataset_name + ): + if input_ids[0][0] == self.special_tokens[-1]: # " . " token + try: + num_feats = seq_cl_feats.shape[0] - 1 + except: + num_feats = len(seq_cl_feats) - 1 + self._cur_cl_idx = min(self._cur_cl_idx + 1, num_feats) + cl_feats = seq_cl_feats[self._cur_cl_idx] + elif "taskmaster" in self.config.dataset_name: + if input_ids[0][0] in self.special_tokens: + if not self._has_reset: # don't iterate past + self._has_reset = True + else: + self._cur_cl_idx = min( + self._cur_cl_idx + 1, seq_cl_feats.shape[0] - 1 + ) + cl_feats = seq_cl_feats[self._cur_cl_idx] + else: + self._cur_cl_idx = min( + self._cur_cl_idx + 1, seq_cl_feats.shape[0] - 1 + ) + cl_feats = seq_cl_feats[self._cur_cl_idx] + + if cl_feats is None and seq_cl_feats is not None: + cl_feats = seq_cl_feats + + cl_embeds = self.cl2e(cl_feats) + + if generated_by_raw_text: + cl_embeds = cl_embeds.unsqueeze(0) + else: + if input_ids.shape[0] == 1: + cl_embeds = cl_embeds.expand(1, seq_len, 768) + else: # beam + cl_embeds = cl_embeds.unsqueeze(1) + + return cl_embeds + + def _get_section_ids(self, input_ids, section_ids, seq_section_ids): + # desired shape: section ids = [batch_size, 1] + seq_len = input_ids.shape[1] + if seq_section_ids is not None: # Used in evaluation + if input_ids.shape[0] > 1 and ( + not isinstance(self._transition_section, list) + or not isinstance(self._cur_section_idx, list) + ): # beam + self._transition_section = [False] * input_ids.shape[0] + self._cur_section_idx = [0] * input_ids.shape[0] + section_ids = section_ids.expand( + input_ids.shape[0], section_ids.shape[-1] + ) + self._last_section_ids = section_ids + + if input_ids.shape[0] > 1: + # TODO off by one - need to start replacing on the second 764 mention + section_ids = torch.clone(self._last_section_ids) + for seq_idx, beam_seq in enumerate(input_ids): + if beam_seq[-1] == 764: # eos + self._transition_section[seq_idx] = True + elif self._transition_section[seq_idx]: # last id was eos + if ( + self._cur_section_idx[seq_idx] + 1 + < seq_section_ids.shape[0] + ): + self._cur_section_idx[seq_idx] += 1 + section_ids[seq_idx] = seq_section_ids[ + self._cur_section_idx[seq_idx] + ] + self._last_beam_section_ids = torch.clone(section_ids) + + else: # non beam + section_ids = seq_section_ids[self._cur_section_idx] + if len(seq_section_ids) != self._cur_section_idx + 1: + self._cur_section_idx += 1 + section_ids = section_ids.expand(1, seq_len) + + elif section_ids is None: + section_ids = torch.zeros((1, seq_len)) + end_idx = seq_len + _break = False + for section_num, section_token in enumerate(self.special_tokens[::-1]): + if section_token == self.special_tokens[0]: # first 
section + start_idx = 0 + else: + start_idx = (input_ids == section_token).nonzero(as_tuple=True)[-1] + if not start_idx.shape[0]: # empty, could not be found. + start_idx = 0 + _break = True + section_ids[:, start_idx:end_idx] = 3 - section_num + if _break: + break + end_idx = start_idx + section_ids = section_ids.to(self.device).long() + else: # have section_ids (1, 1) + section_ids = section_ids.expand(-1, seq_len) # (1, seq_len) + + return section_ids + + def forward( + self, + input_ids=None, + raw_text=None, + cl_feats=None, + seq_cl_feats=None, + seq_section_ids=None, + section_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + # fulldoc & wikisection + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + assert batch_size > 0, "batch_size has to be defined and > 0" + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + + hidden_states = inputs_embeds + position_embeds + + # Use embeddings from CL training. + if self._config.use_contrastive_embeddings: + if self._config.use_section_ids: + raise ValueError( + "contrastive embeddings should not be used at same time as section ids" + ) + cl_embeds = self._get_cl_embeddings( + raw_text=raw_text, + cl_feats=cl_feats, + seq_cl_feats=seq_cl_feats, + input_ids=input_ids, + seq_len=inputs_embeds.shape[1], + ) + hidden_states = hidden_states + cl_embeds + + # Do section embeddings + if self._config.use_section_ids: + if self._config.use_contrastive_embeddings: + raise ValueError( + "contrastive embeddings should not be used at same time as section ids" + ) + # section ids = [batch_size, 1] + section_ids = self._get_section_ids( + input_ids=input_ids, + section_ids=section_ids, + seq_section_ids=seq_section_ids, + ) + if ( + hasattr(self._config, "use_section_null") + and self._config.use_section_null + ): + section_embeds = self.sectionNull2e(section_ids) + else: + section_embeds = self.section2e(section_ids) + hidden_states = hidden_states + section_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple( + past_state.to(hidden_states.device) for past_state in layer_past + ) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warning( + 
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + ( + outputs[2 if use_cache else 1], + ) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + ( + outputs[3 if use_cache else 2], + ) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class GPT2TimeLMHeadModel(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"attn.masked_bias", + r"attn.bias", + r"lm_head.weight", + ] + + def __init__(self, config, model_name="gpt2"): + super().__init__(config) + self.transformer = GPT2TimeModel(config, model_name) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + 
attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + result = { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + if "section_ids" in kwargs: + result["section_ids"] = kwargs["section_ids"] + if "raw_text" in kwargs: + result["raw_text"] = kwargs["raw_text"] + if "cl_feats" in kwargs: + result["cl_feats"] = kwargs["cl_feats"] + if "seq_cl_feats" in kwargs: + result["seq_cl_feats"] = kwargs["seq_cl_feats"] + if "seq_section_ids" in kwargs: + result["seq_section_ids"] = kwargs["seq_section_ids"] + + return result + + def forward( + self, + input_ids=None, + raw_text=None, + seq_cl_feats=None, + seq_section_ids=None, + cl_feats=None, + section_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to + ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.transformer( + input_ids, + raw_text=raw_text, + cl_feats=cl_feats, + seq_cl_feats=seq_cl_feats, + seq_section_ids=seq_section_ids, + section_ids=section_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + 
attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ) + for layer_past in past + ) diff --git a/decoder_ring/timecontrol/encode.py b/decoder_ring/timecontrol/encode.py new file mode 100644 index 0000000..5253df4 --- /dev/null +++ b/decoder_ring/timecontrol/encode.py @@ -0,0 +1,625 @@ +# Code copied from https://github.com/rosewang2008/language_modeling_via_stochastic_processes +# Language modeling via stochastic processes (ICLR Oral 2022) +# https://arxiv.org/abs/2203.11370 +""" +@misc{https://doi.org/10.48550/arxiv.2203.11370, + doi = {10.48550/ARXIV.2203.11370}, + url = {https://arxiv.org/abs/2203.11370}, + author = {Wang, Rose E and Durmus, Esin and Goodman, Noah and Hashimoto, Tatsunori}, + keywords = {Computation and Language (cs.CL), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences}, + title = {Language modeling via stochastic processes}, + publisher = {arXiv}, + year = {2022}, + copyright = {Creative Commons Attribution 4.0 International} +} +""" + +import os, pickle, random + +from tqdm import tqdm + +import datasets +import pytorch_lightning as pl +import torch +from torch import nn +import torch.utils.data as data +from torch.utils.data import DataLoader +from transformers import GPT2Model, GPT2Tokenizer, BertTokenizer + +# torch.backends.cudnn.benchmark = True + + +def weights_init(m): + if isinstance(m, nn.Linear): + torch.nn.init.zeros_(m.bias) + m.bias.requires_grad = False + + +def create_dataloader(dataset, config, shuffle=True): + loader = DataLoader( + dataset, + batch_size=config["optim_params"]["batch_size"], + shuffle=shuffle, + pin_memory=True, + drop_last=shuffle, + # num_workers=config.experiment_params.data_loader_workers, + ) + return loader + + +class GPT2OUEncoder(nn.Module): + def __init__(self, hidden_dim, latent_dim, model_name, finetune_gpt2=False): + super(GPT2OUEncoder, self).__init__() + self.hidden_dim = hidden_dim + self.latent_dim = latent_dim + self.finetune = finetune_gpt2 + self.model_name = model_name + self._init_model() + + def _init_model(self): + self.model = GPT2Model.from_pretrained(self.model_name) + self.model = self.model.eval() + # turn off all the gradients + for param in self.model.parameters(): + param.requires_grad = self.finetune + self.mlp = nn.Linear(self.model.wte.embedding_dim, self.hidden_dim) + self.feature_extractor = ( + self.create_feature_extractor() + ) # data_dim -> hidden_dim + self.log_q = self.create_log_q() + self.C_eta = nn.Linear(1, 1) + + ## NEW AUG 19, turn off bias training. 
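+        # weights_init (defined at the top of this module) zeroes each nn.Linear bias
+        # and sets requires_grad=False on it, so the biases of mlp, feature_extractor,
+        # log_q and C_eta stay fixed at zero while their weights train.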
+ self.mlp.apply(weights_init) + self.feature_extractor.apply(weights_init) + self.log_q.apply(weights_init) + self.C_eta.apply(weights_init) + + def create_feature_extractor(self): + return nn.Sequential( + *[ + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.latent_dim), + ] + ) + + def create_log_q(self): + return nn.Sequential( + *[ + nn.Linear(self.latent_dim, self.latent_dim), + nn.Linear(self.latent_dim, self.latent_dim), + nn.Linear(self.latent_dim, 1), + ] + ) + + def get_gpt2_embeddings(self, input_ids, attention_mask): + gpt_emb = self.model(input_ids=input_ids, attention_mask=attention_mask)[0] + # Index into the last hidden state of the sentence (last non-EOS token) + gpt_emb = self.compute_masked_means(gpt_emb, attention_mask) + return gpt_emb + + def get_log_q(self, x): + return self.log_q(x) + + def set_to_train(self): + pass + + def compute_masked_means(self, outputs, masks): + # we don't want to include padding tokens + # outputs : B x T x D + # masks : B x T + dim = outputs.size(2) + masks_dim = masks.unsqueeze(2).repeat(1, 1, dim) + # masked_outputs : B x T x D + masked_outputs = outputs * masks_dim # makes the masked entries 0 + # masked_outputs: B x D / B x 1 => B x D + partition = torch.sum(masks, dim=1, keepdim=True) + masked_outputs = torch.sum(masked_outputs, dim=1) / partition + return masked_outputs + + def projection(self, gpt_emb): + z = self.mlp(gpt_emb) # 32, 100 + z = self.feature_extractor(z) + return z + + def forward(self, input_ids, attention_mask): + gpt_emb = self.model(input_ids=input_ids, attention_mask=attention_mask)[0] + # Index into the last hidden state of the sentence (last non-EOS token) + gpt_emb = self.compute_masked_means(gpt_emb, attention_mask) + # Albert lang embedding -> feature embedding space + return self.projection(gpt_emb) + + +class RecipeNLGData(data.Dataset): + """WikiSection data""" + + def __init__( + self, + train, + all_dataset, + config, + tokenizer_name="GPT2", + filepath=None, + seed=1, + ): + """ """ + super().__init__() + self.train = train + self.all_dataset = all_dataset + self.config = config + + if self.train: + self.start_idx, self.end_idx = 0, 1_000 + else: + self.start_idx, self.end_idx = 500_000, 500_100 + self.seed = seed + self.tokenizer_name = tokenizer_name + self._set_tokenizer() + + self._process_data() + print("Done loading dataset.") + + print("Example: ", self.processed_data[0]["sentence"]) + print("Example: ", self.processed_data[10]["sentence"]) + + def _process_data(self): + self.processed_data = [] + for doc_id in tqdm(range(self.start_idx, self.end_idx)): + doc = self.all_dataset[doc_id] + doc_info = [] + sentence_counter = 0 + # Put all the document sentences together. + title = [self.section_ids[0] + " " + doc["title"] + " . "] + ingredients = [ + self.section_ids[1] + " " + (", ".join(doc["ner"]) + " . ").capitalize() + ] + directions = [d[:-1] + " . 
" for d in doc["directions"]] + directions[0] = self.section_ids[2] + " " + directions[0] + gpt2_text = title + ingredients + directions + gpt2_text = [s for s in gpt2_text if s] + all_sentences = gpt2_text + # gpt2_text = "".join(gpt2_text) + # all_sentences = title + ingredients + directions + if not all( + [len(self.tokenizer(s)["input_ids"]) < 1024 for s in all_sentences] + ): + continue + for sentence in all_sentences: + if not sentence: + continue + sentence_info = { + "sentence": sentence, + "sentence_id": sentence_counter, + "doc_id": doc_id, + } + doc_info.append(sentence_info) + sentence_counter += 1 + + # Track total number of sentences in a document + for info in doc_info: + info["total_doc_sentences"] = sentence_counter + + self.processed_data += doc_info + + def _set_tokenizer(self): + if self.tokenizer_name == "GPT2": + self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + self.tokenizer.pad_token = self.tokenizer.eos_token + self.end_token = self.tokenizer.eos_token_id + self.max_length = 1024 + # elif self.tokenizer_name == "BERT": + # self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + # self.max_length = 512 + else: + raise ValueError("Dont recognize name {}".format(self.tokenizer_name)) + + self.section_ids = ["[ TITLE ]", "[ INGREDIENTS ]", "[ DIRECTIONS ]"] + self.section_names = self.section_ids + self.cl_eos_str = " . " + self.tokenizer.add_tokens(self.section_ids + [self.cl_eos_str]) + self.special_tokens = [ + _[0] for _ in self.tokenizer(self.section_ids)["input_ids"] + ] + self.cl_eos_id = self.tokenizer(self.cl_eos_str)["input_ids"][0] + print("CL EOS ID", self.cl_eos_id) + + def tokenize_caption(self, caption, device): + if self.tokenizer_name == "GPT2": + output = self.tokenizer( + caption, + padding=True, + return_tensors="pt", + ) + input_ids = output["input_ids"].squeeze(0) + attention_mask = output["attention_mask"].squeeze(0) + eos_input_ids = torch.tensor([[self.end_token] * input_ids.shape[0]]) + eos_attention = torch.tensor([[0] * input_ids.shape[0]]) + input_ids = torch.cat((input_ids, eos_input_ids.T), dim=1) + attention_mask = torch.cat((attention_mask, eos_attention.T), dim=1) + elif self.tokenizer_name == "BERT": + # Prepend [CLS] so I can use the first embedding + output = self.tokenizer( + caption, + padding=True, + return_tensors="pt", + ) + input_ids = output["input_ids"].squeeze(0) + attention_mask = output["attention_mask"].squeeze(0) + + return input_ids.to(device), attention_mask.to(device) + + def __len__(self): + return len(self.processed_data) - 1 + + +class RecipeDiscourse(RecipeNLGData): + def __init__( + self, + train, + all_dataset, + config, + tokenizer_name="GPT2", + seed=1, + ): + """ """ + super(RecipeDiscourse, self).__init__( + train=train, + all_dataset=all_dataset, + config=config, + tokenizer_name=tokenizer_name, + seed=seed, + ) + + def __getitem__(self, index): + label = random.randint(0, 1) # either in- or out-of-order + + if label: # in-order + if ( + self.processed_data[index]["doc_id"] + != self.processed_data[index + 1]["doc_id"] + ): + index -= 1 + y_t = self.processed_data[index]["sentence"] + y_tp1 = self.processed_data[index + 1]["sentence"] + else: + y_t = self.processed_data[index]["sentence"] + random_idx = random.randint( + 0, len(self.processed_data) - 1 + ) # either in- or out-of-order + y_tp1 = self.processed_data[random_idx]["sentence"] + + if self.one_hot_labels: + labels = torch.zeros(2) + labels[label] = 1.0 + label = labels + + result = {"y_t": y_t, "y_tp1": y_tp1, "label": label, 
"idx": index} + return result + + +class RecipeTriplet(RecipeNLGData): + def __init__( + self, + train, + all_dataset, + config, + tokenizer_name="GPT2", + seed=1, + ): + """ """ + super(RecipeTriplet, self).__init__( + train=train, + all_dataset=all_dataset, + config=config, + tokenizer_name=tokenizer_name, + seed=seed, + ) + + def __getitem__(self, index): + utterance = self.processed_data[index] + sentence_num = utterance["sentence_id"] + + # Check if index is start of a seq. If so -> +2 + if sentence_num == 0: + index += 2 + if sentence_num == 1: + index += 1 + + # Update + utterance = self.processed_data[index] + sentence_num = utterance["sentence_id"] + + # TRIAL 2: Sample all random points, t, t', t'' + T = sentence_num + # t is a random point in between + nums = list(range(T)) + t1 = random.choice(nums) + nums.remove(t1) + t2 = random.choice(nums) + if t2 < t1: + t = t2 + t2 = t1 + t1 = t + + assert t1 < t2 and t2 < T + y_0 = self.processed_data[index - T + t1]["sentence"] + y_t = self.processed_data[index - T + t2]["sentence"] + y_T = self.processed_data[index]["sentence"] + + t_ = t1 + t = t2 + + total_doc = utterance["total_doc_sentences"] + result = { + "y_0": y_0, + "y_t": y_t, + "y_T": y_T, + "t_": t_, + "t": t, + "T": T, + "total_t": total_doc, + } + return result + + +NAME2DATASET = { + # 'wikisection': wikisection.WikisectionTPK, + "recipe": RecipeTriplet, + # 'wikihow': wikihow.WikihowTPK, + # 'roc_stories': roc_stories.ROCStoriesTPK, + # 'tm2': tm2.TM2TPK, + # 'tickettalk': tickettalk.TicketTalkTPK, +} + + +class BrownianBridgeLoss(object): + """Everything is a brownian bridge... + p(z_t | mu_0, mu_T) = \mathcal{N}(mu_0 * t/T + mu_T * (1-t/T), I t*(T-t)/T) + normalization constant: -1/(2 * t*(T-t)/T) + """ + + def __init__( + self, + z_0, + z_t, + z_T, + t_, + t, + T, + alpha, + var, + log_q_y_T, + loss_type, + eps, + max_seq_len, + C_eta=None, + label=None, + ): + super().__init__() + self.log_q_y_T = log_q_y_T + self.z_0 = z_0 + self.z_t = z_t + self.z_T = z_T + self.t_ = t_ + self.t = t + self.T = T + self.alpha = alpha + self.var = var + NAME2LOSS = { + "simclr": self.simclr_loss, + } + self.loss_f = NAME2LOSS[loss_type] + self.eps = eps + self.max_seq_len = max_seq_len + self.sigmoid = nn.Sigmoid() + self.label = label + + if C_eta is None: + C_eta = 0.0 + self.C_eta = C_eta + self.end_pin_val = 1.0 + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def _log_p(self, z_0, z_t, z_T, t_0, t_1, t_2): + T = t_2 - t_0 + t = t_1 - t_0 + + alpha = (t / (T + self.eps)).view(-1, 1) + delta = z_0 * (1 - alpha) + z_T * (alpha) - z_t + var = t * (T - t) / (T + self.eps) + log_p = ( + -1 / (2 * var + self.eps) * (delta * delta).sum(-1) + self.C_eta + ) # (512,) + if len(log_p.shape) > 1: # (1, bsz) + log_p = log_p.squeeze(0) + return log_p + + def _logit(self, z_0, z_T, z_t, t_, t, T): + """ + Calculating log p(z_tp1, z_t) = -|| h(z_{t+dt}) - h(z_t)(1-dt)||^2_2 + """ + log_p = self._log_p(z_0=z_0, z_t=z_t, z_T=z_T, t_0=t_, t_1=t, t_2=T) + log_p = log_p.unsqueeze(-1) + log_q = self.log_q_y_T + logit = log_p # - log_q + return logit # should be (bsz, 1) + + def reg_loss(self): + loss = 0.0 + mse_loss_f = nn.MSELoss() + # start reg + start_idxs = torch.where((self.t_) == 0)[0] + if start_idxs.nelement(): + vals = self.z_0[start_idxs, :] + start_reg = mse_loss_f(vals, torch.zeros(vals.shape, device=self.device)) + loss += start_reg + # end reg + end_idxs = torch.where((self.T) == self.max_seq_len - 1)[0] + if end_idxs.nelement(): + vals = 
torch.abs(self.z_T[end_idxs, :]) + end_reg = mse_loss_f( + vals, torch.ones(vals.shape, device=self.device) * self.end_pin_val + ) + loss += end_reg + return loss + + def simclr_loss(self): + """ + log p = -1/(2*eta) \| x' - x - \mu(x) \|^2_2 + C_{\eta} + logit = log p - log q + """ + loss = 0.0 + # Positive pair + pos_logit = self._logit( + z_0=self.z_0, z_T=self.z_T, z_t=self.z_t, t_=self.t_, t=self.t, T=self.T + ) + pos_probs = torch.exp(pos_logit) # (bsz,1) + for idx in range(self.z_T.shape[0]): + # Negative pair: logits over all possible contrasts + # Nominal contrast for random triplet - contrast from in between + neg_i_logit = self._logit( + z_0=self.z_0, + z_T=self.z_T, + z_t=self.z_t[idx], + t_=self.t_, + t=self.t[idx], + T=self.T, + ) + neg_i_probs = torch.exp(neg_i_logit) # (bsz,1) + loss_i = -(pos_logit[idx] - torch.log(neg_i_probs.sum() + self.eps)) + loss += loss_i + + loss = loss / self.z_T.shape[0] + # Regularization for pinning start and end of bridge + reg_loss = self.reg_loss() + loss += reg_loss + return loss + + def get_loss(self): + return self.loss_f() + + +class BrownianBridgeSystem(pl.LightningModule): + def __init__(self, config): + super().__init__() + self.config = config + self.model_name = config["model_params"]["name"] + self._set_dataset() + self._set_language_encoder() + + def configure_optimizers(self): + optimizer = torch.optim.SGD( + self.parameters(), + lr=self.config["optim_params"]["learning_rate"], + momentum=self.config["optim_params"]["momentum"], + ) + return [optimizer], [] + + def train_dataloader(self): + return create_dataloader(self.train_dataset, self.config) + + def test_dataloader(self): + return create_dataloader(self.test_dataset, self.config, shuffle=False) + + def _set_dataset(self): + dname = self.config["data_params"]["name"] + if "recipe" == dname: + self.data_dir = self.config["data_params"][ + "path" + ] # constants.PATH2RECIPENLG + self.all_dataset = datasets.load_dataset( + "recipe_nlg", data_dir=self.data_dir + )["train"] + elif "wikihow" == dname: + self.data_name = constants.PATH2WIKIHOW + with open(self.data_name, "rb") as f: + self.all_dataset = pickle.load(f) + else: + self.all_dataset = None + + dataset = NAME2DATASET[dname] + self.train_dataset = dataset( + train=True, + # seed=self.config['data_params']['data_seed'], + all_dataset=self.all_dataset, + config=self.config, + ) + self.test_dataset = dataset( + train=False, + # seed=self.config['data_params']['data_seed'], + all_dataset=self.all_dataset, + config=self.config, + ) + + def set_to_train(self): + pass + + def _set_language_encoder(self): + self.model = GPT2OUEncoder( + hidden_dim=self.config["model_params"]["hidden_size"], + latent_dim=self.config["model_params"]["latent_dim"], + model_name=self.model_name, + finetune_gpt2=False, + ) + + self.model.model.resize_token_embeddings(len(self.train_dataset.tokenizer)) + for p in self.model.model.parameters(): + p.requires_grad = False + + def forward(self, input_ids, attention_mask): + feats = self.model.forward(input_ids=input_ids, attention_mask=attention_mask) + return feats + + def get_feats(self, obs): + input_ids_i, attention_mask_i = self.train_dataset.tokenize_caption( + obs, device=self.device + ) + input_ids_i = input_ids_i[:, : self.train_dataset.max_length] + attention_mask_i = attention_mask_i[:, : self.train_dataset.max_length] + feats_i = self.forward(input_ids=input_ids_i, attention_mask=attention_mask_i) + return feats_i + + def get_losses_for_batch(self, batch, batch_idx): + torch.cuda.empty_cache() 
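+        # Batches come from the triplet dataset (e.g. RecipeTriplet): sentences
+        # y_0 / y_t / y_T sampled from one document, their positions t_ / t / T,
+        # and total_t (sentences in the document), which feed the bridge loss below.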
+ if "y_0" in batch: + obs_0 = batch["y_0"] + else: + obs_0 = "" + obs_t = batch["y_t"] + obs_T = batch["y_T"] + t_s = batch["t_"].float() + ts = batch["t"].float() + Ts = batch["T"].float() + feats_0 = self.get_feats(obs_0) + feats_t = self.get_feats(obs_t) + feats_T = self.get_feats(obs_T) + log_q_y_tp1 = self.model.get_log_q(feats_t) + loss_fn = BrownianBridgeLoss( + z_0=feats_0, + z_t=feats_t, + z_T=feats_T, + t_=t_s, + t=ts, + T=Ts, + alpha=0, + var=0, + log_q_y_T=log_q_y_tp1, + loss_type=self.config["loss_params"]["name"], + eps=self.config["model_params"]["eps"], + max_seq_len=batch["total_t"].float(), + ) + loss = loss_fn.get_loss() + return loss + + def training_step(self, batch, batch_idx): + loss = self.get_losses_for_batch(batch, batch_idx) + self.log("train_loss", loss, prog_bar=True, on_step=True) + return loss + + def test_step(self, batch, i): + loss = self.get_losses_for_batch(batch=batch, batch_idx=i) + self.log("test_loss", loss.cpu().detach().numpy(), prog_bar=True, on_step=True) + return loss diff --git a/scratch.md b/scratch.md index b18e55c..168fdc7 100644 --- a/scratch.md +++ b/scratch.md @@ -2,11 +2,14 @@ Interesting params in https://github.com/mapmeld/transformers/blob/main/src/tran is_constraint_gen_mode -Unsupported: +Not yet supported: -- https://arxiv.org/abs/2203.11370 time control -- https://github.com/rosewang2008/language_modeling_via_stochastic_processes - https://github.com/XiangLi1999/ContrastiveDecoding +- https://arxiv.org/abs/2104.05336 + +Partially supported: + +- time control https://arxiv.org/abs/2203.11370 == https://github.com/rosewang2008/language_modeling_via_stochastic_processes Issues: From 4584d3dc69870a6a903bd1fe7794ec41ed76c957 Mon Sep 17 00:00:00 2001 From: Nick Doiron Date: Mon, 16 Jan 2023 20:40:03 -0700 Subject: [PATCH 2/2] a little cleaner, but encoder model does not save very much data, so loader is not picking up the right stuff --- .gitignore | 1 + decoder_ring/all-demo.py | 213 ++++++++++++++++++++--------- decoder_ring/timecontrol/encode.py | 7 + 3 files changed, 154 insertions(+), 67 deletions(-) diff --git a/.gitignore b/.gitignore index 9dfa17e..6961e49 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ build/ dist/ *.egg-info lightning_logs/ +*.pt diff --git a/decoder_ring/all-demo.py b/decoder_ring/all-demo.py index 931b182..aa3072e 100644 --- a/decoder_ring/all-demo.py +++ b/decoder_ring/all-demo.py @@ -1,5 +1,7 @@ import os +import torch +import transformers from transformers import AutoModelForCausalLM, AutoTokenizer import pytorch_lightning as pl # for TimeControl @@ -13,7 +15,7 @@ TypicalDecoder, ) from timecontrol.decode import GPT2TimeLMHeadModel -from timecontrol.encode import BrownianBridgeSystem +from timecontrol.encode import BrownianBridgeSystem, GPT2OUEncoder ## Writer test basic = BasicWriter("gpt2", RandomSampling) @@ -24,87 +26,164 @@ print(writer_output) ## HuggingFace tokenized string -# model = AutoModelForCausalLM.from_pretrained("gpt2") +model = AutoModelForCausalLM.from_pretrained("gpt2") tokenizer = AutoTokenizer.from_pretrained("gpt2") content = tokenizer.encode("Hello, my name is", return_tensors="pt") ## Decoder tests -# decoder1 = GreedyDecoder(model) -# greedy_output = decoder1.generate_text( -# prompt=content, -# max_length=20, -# ) -# txtop = tokenizer.decode(greedy_output[0], skip_special_tokens=True) -# print(txtop) +decoder1 = GreedyDecoder(model) +greedy_output = decoder1.generate_text( + prompt=content, + max_length=20, +) +txtop = tokenizer.decode(greedy_output[0], 
skip_special_tokens=True) +print(txtop) # without random_seed, or set_random_seed, this will call logging.warn -# decoder2 = RandomSampling(model, random_seed=603) -# sampling_output = decoder2.generate_text( -# prompt=content, -# max_length=20, -# ) -# txtop2 = tokenizer.decode(sampling_output[0], skip_special_tokens=True) -# print(txtop2) -# -# decoder3 = TypicalDecoder(model, random_seed=603, typical_p=0.4) -# typical_output = decoder3.generate_text( -# prompt=content, -# max_length=20, -# ) -# txtop3 = tokenizer.decode(typical_output[0], skip_special_tokens=True) -# print(txtop3) -# -# decoder4 = ContrastiveSearch(model, random_seed=603, penalty_alpha=0.4, top_k=4) -# contrastive_output = decoder4.generate_text( -# prompt=content, -# max_length=20, -# ) -# txtop4 = tokenizer.decode(contrastive_output[0], skip_special_tokens=True) -# print(txtop4) -# -# decoder5 = BeamSearch(model, early_stopping=True, num_beams=3) -# beam_output = decoder5.generate_text( -# prompt=content, -# max_length=20, -# ) -# txtop5 = tokenizer.decode(beam_output[0], skip_special_tokens=True) -# print(txtop5) +decoder2 = RandomSampling(model, random_seed=603) +sampling_output = decoder2.generate_text( + prompt=content, + max_length=20, +) +txtop2 = tokenizer.decode(sampling_output[0], skip_special_tokens=True) +print(txtop2) + +decoder3 = TypicalDecoder(model, random_seed=603, typical_p=0.4) +typical_output = decoder3.generate_text( + prompt=content, + max_length=20, +) +txtop3 = tokenizer.decode(typical_output[0], skip_special_tokens=True) +print(txtop3) + +decoder4 = ContrastiveSearch(model, random_seed=603, penalty_alpha=0.4, top_k=4) +contrastive_output = decoder4.generate_text( + prompt=content, + max_length=20, +) +txtop4 = tokenizer.decode(contrastive_output[0], skip_special_tokens=True) +print(txtop4) +decoder5 = BeamSearch(model, early_stopping=True, num_beams=3) +beam_output = decoder5.generate_text( + prompt=content, + max_length=20, +) +txtop5 = tokenizer.decode(beam_output[0], skip_special_tokens=True) +print(txtop5) + +# based on https://github.com/rosewang2008/language_modeling_via_stochastic_processes/blob/main/language_modeling_via_stochastic_processes/scripts/train_encoder.py tc_model = GPT2TimeLMHeadModel.from_pretrained("gpt2") trainer = pl.Trainer( # gpus=1, max_epochs=1, min_epochs=1, ) -trainer.fit( - BrownianBridgeSystem( - # params via https://github.com/rosewang2008/language_modeling_via_stochastic_processes/blob/main/language_modeling_via_stochastic_processes/config/encoder/brownian.yaml - { - "data_params": { - # "k": 5, - "name": "recipe", - "path": "~/Downloads/recipe/", - }, - "loss_params": { - "name": "simclr", - }, - "model_params": { - "eps": 1e-6, - "hidden_size": 128, - "latent_dim": 32, - "name": "gpt2", - }, - "optim_params": { - "batch_size": 32, - "learning_rate": 0.0001, - "momentum": 0.9, - }, - } - ) +sys = BrownianBridgeSystem( + # params via https://github.com/rosewang2008/language_modeling_via_stochastic_processes/blob/main/language_modeling_via_stochastic_processes/config/encoder/brownian.yaml + { + "data_params": { + # "k": 5, + "name": "recipe", + "path": "~/Downloads/recipe/", + }, + "loss_params": { + "name": "simclr", + }, + "model_params": { + "eps": 1e-6, + "hidden_size": 128, + "latent_dim": 32, + "name": "gpt2", + }, + "optim_params": { + "batch_size": 32, + "learning_rate": 0.0001, + "momentum": 0.9, + }, + } ) -trainer.save("./tcencoder") +if os.path.isfile("./mlp.pt"): + HIDDEN_DIM = 128 + + def load_cl_model(filepath, latent_dim, base_model, 
use_section_ids, token_size): + model = GPT2OUEncoder( + model_name="gpt2", + hidden_dim=HIDDEN_DIM, + latent_dim=latent_dim, + finetune_gpt2=False, + ) + if use_section_ids: + model.model.resize_token_embeddings(token_size) + + transformers.__spec__ = "gpt2" # Avoid bug + state_dict = torch.load(filepath) + new_dict = {} + for k, v in state_dict["state_dict"].items(): + if any([i in k for i in ["model.model.g_ar", "model.model.W_k"]]): + new_dict[k[6:]] = v + elif any([i in k for i in ["model.g_ar", "model.W_k", "time_model"]]): + continue + elif "model." in k: + new_dict[k[6:]] = v + else: + new_dict[k] = v -decoder6 = TimeControl(tc_model, encoder=trainer.model, random_seed=22) + if any(["g_ar" in k for k in new_dict.keys()]): + model.g_ar = nn.GRU( + input_size=latent_dim, + hidden_size=2400, # default number in infoNCE for langauge + num_layers=3, + batch_first=True, + ) + model.W_k = nn.Linear(2400, latent_dim) + elif any(["time_model" in k for k in state_dict["state_dict"].keys()]): + model.fc_mu = nn.Linear(latent_dim, latent_dim) + model.fc_var = nn.Linear(latent_dim, latent_dim) + + model.load_state_dict(new_dict) + for p in model.parameters(): + p.requires_grad = False + model.eval() + return model + + def get_checkpoint( + dataset_name, + latent_dim, + base_model="gpt2", + sec_id=False, + token_size=None, + filepath=None, + ): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model = load_cl_model( + filepath, + latent_dim, + base_model, + use_section_ids=sec_id, + token_size=token_size, + ) + model.to(device) + model = model.eval() + return model + + CL_MODEL = get_checkpoint( + dataset_name="recipe", + latent_dim=32, + sec_id=True, + token_size=len(tokenizer), + base_model="gpt2", + filepath="./feature_extractor.pt", + ) + # CL_MODEL.to(args.device) +else: + trainer.fit(sys) + sys.save(directory="./") + +# https://github.com/rosewang2008/language_modeling_via_stochastic_processes/blob/main/language_modeling_via_stochastic_processes/transformers/examples/pytorch/text-generation/run_decoding_from_embeddings.py +tc_model = GPT2TimeLMHeadModel.from_pretrained("gpt2") +tc_model._config.use_contrastive_embeddings = True +decoder6 = TimeControl(tc_model, sys.model, random_seed=22) tc_output = decoder6.generate_text( prompt=content, max_length=20, diff --git a/decoder_ring/timecontrol/encode.py b/decoder_ring/timecontrol/encode.py index 5253df4..7dd3dd0 100644 --- a/decoder_ring/timecontrol/encode.py +++ b/decoder_ring/timecontrol/encode.py @@ -623,3 +623,10 @@ def test_step(self, batch, i): loss = self.get_losses_for_batch(batch=batch, batch_idx=i) self.log("test_loss", loss.cpu().detach().numpy(), prog_bar=True, on_step=True) return loss + + def save(self, directory): + torch.save(self.model.mlp.state_dict(), os.path.join(directory, "mlp.pt")) + torch.save( + self.model.feature_extractor.state_dict(), + os.path.join(directory, "feature_extractor.pt"), + )