From b15d85a9724a1e20c774dc59d752e589353cd3f5 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Tue, 13 Sep 2022 17:01:16 -0700 Subject: [PATCH 01/11] ia3 adaptors --- tango/integrations/transformers/ia3.py | 208 ++++++++++++++++++++ tests/integrations/transformers/ia3_test.py | 41 ++++ 2 files changed, 249 insertions(+) create mode 100644 tango/integrations/transformers/ia3.py create mode 100644 tests/integrations/transformers/ia3_test.py diff --git a/tango/integrations/transformers/ia3.py b/tango/integrations/transformers/ia3.py new file mode 100644 index 000000000..9d02c7b69 --- /dev/null +++ b/tango/integrations/transformers/ia3.py @@ -0,0 +1,208 @@ +from dataclasses import dataclass +from typing import Optional +import re +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.modeling_utils import Conv1D +from transformers import PreTrainedModel + +@dataclass +class WithIA3Config: + attention_modules: str = None + fused_qkv_layers: Optional[str] = None + k_layers: Optional[str] = None + v_layers: Optional[str] = None + mlp_modules: str = None + mlp_layers: str = None + ia3_param_names: str = None + +GPT_J_IA3_CONFIG = WithIA3Config( + attention_modules = ".*attn", + k_layers="k_proj", + v_layers="v_proj", + mlp_modules = ".*mlp", + mlp_layers = "fc_in", + ia3_param_names = "ia3" + ) + +GPT_2_IA3_CONFIG = WithIA3Config( + attention_modules = ".*attn", + fused_qkv_layers = "c_attn", + mlp_modules = ".*mlp", + mlp_layers = "c_fc", + ia3_param_names = "ia3" + ) + +OPT_IA3_CONFIG = WithIA3Config( + attention_modules = ".*self_attn", + k_layers="k_proj", + v_layers="v_proj", + mlp_modules = ".*layers\.\d*", + mlp_layers = "fc1", + ia3_param_names = "ia3" + ) + +BLOOM_IA3_CONFIG = WithIA3Config( + attention_modules = ".*self_attention", + fused_qkv_layers = "query_key_value", + mlp_modules = ".*mlp", + mlp_layers = "dense_h_to_4h", + ia3_param_names = "ia3" + ) + +MODEL_NAME_TO_CONFIG = { + 'sshleifer/tiny-gpt2': GPT_2_IA3_CONFIG, + 'gpt2': GPT_2_IA3_CONFIG, + 'gpt2-medium': GPT_2_IA3_CONFIG, + 'gpt2-large': GPT_2_IA3_CONFIG, + 'gpt2-xl': GPT_2_IA3_CONFIG, + 'bigscience/bloom-560m': BLOOM_IA3_CONFIG, + 'bigscience/bloom-1b1': BLOOM_IA3_CONFIG, + 'bigscience/bloom-1b7': BLOOM_IA3_CONFIG, + 'bigscience/bloom-3b': BLOOM_IA3_CONFIG, + 'bigscience/bloom-7b1': BLOOM_IA3_CONFIG, + 'bigscience/bloom': BLOOM_IA3_CONFIG, + 'facebook/opt-125m': OPT_IA3_CONFIG, + 'facebook/opt-350m': OPT_IA3_CONFIG, + 'facebook/opt-1.3b': OPT_IA3_CONFIG, + 'facebook/opt-2.7b': OPT_IA3_CONFIG, + 'facebook/opt-6.7b': OPT_IA3_CONFIG, + 'facebook/opt-13b': OPT_IA3_CONFIG, + 'facebook/opt-30b': OPT_IA3_CONFIG, + 'facebook/opt-66b': OPT_IA3_CONFIG, + 'EleutherAI/gpt-j-6B': GPT_J_IA3_CONFIG, +} + +class LinearWithIA3(nn.Module): + def __init__(self, linear_layer, ia3_param_names, unfuse_size: int = None): + super().__init__() + + self.in_features = linear_layer.in_features + self.out_features = linear_layer.out_features + self.unfuse_size = unfuse_size + + self.weight = linear_layer.weight + self.bias = linear_layer.bias + + self.ia3_param_names = ia3_param_names + + # if (q,k,v) are stacked into one layer + if unfuse_size is not None: + assert linear_layer.out_features == unfuse_size * 3 + # IA3 only operates on k and v (not q), thus the "* 2" + setattr(self, ia3_param_names, nn.Parameter(torch.ones(unfuse_size * 2, 1))) + else: + setattr(self, ia3_param_names, nn.Parameter(torch.ones(self.out_features, 1))) + + def forward(self, x): + x = F.linear(x, self.weight, self.bias) + + ia3_params = getattr(self, self.ia3_param_names) + + if ia3_params.requires_grad: + if self.unfuse_size is not None: + # non_q means k and v + q, non_q = x[:, :, :self.unfuse_size], x[:, :, self.unfuse_size:] + ia3_params = getattr(self, self.ia3_param_names) + non_q = non_q * ia3_params.flatten() + x = torch.cat([q, non_q], dim=2) + else: + x = x * ia3_params.flatten() + + return x + +class Conv1DWithIA3(nn.Module): + def __init__(self, conv1d_layer, ia3_param_names, unfuse_size: int = None): + super().__init__() + + # nf: number of output features; nx: number of input features + self.nf = conv1d_layer.nf + self.unfuse_size = unfuse_size + + self.weight = conv1d_layer.weight + self.bias = conv1d_layer.bias + + self.ia3_param_names = ia3_param_names + + # in c_att parameters, (q,k,v) linear layers are stacked into one Conv1D layer + if unfuse_size is not None: + assert conv1d_layer.nf == unfuse_size * 3 + # but IA3 only operates on k and v (not q), thus the "* 2" + setattr(self, ia3_param_names, nn.Parameter(torch.ones(unfuse_size * 2, 1))) + else: + setattr(self, ia3_param_names, nn.Parameter(torch.ones(self.nf, 1))) + + def forward(self, x): + # copied and pasted from the original Conv1D implemnetation + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) # ... * self.nf + + ia3_params = getattr(self, self.ia3_param_names) + + if ia3_params.requires_grad: + if self.unfuse_size is not None: + # non_q means k and v + q, non_q = x[:, :, :self.unfuse_size], x[:, :, self.unfuse_size:] + ia3_params = getattr(self, self.ia3_param_names) + non_q = non_q * ia3_params.flatten() + x = torch.cat([q, non_q], dim=2) + else: + x = x * ia3_params.flatten() + + return x + +def modify_with_ia3(transformer: PreTrainedModel, config: WithIA3Config, *, only_ia3_requires_grad: bool = True): + """ + A function to add ia3 adaptors to the given transformer. Code modified from + [t-few](https://github.com/r-three/t-few/blob/217cfa3b73aa66a07594826e4ebbbc516b331461/src/models/lora.py) and Qinyuan Ye + + :param model: + A :class:`~transformers.PreTrainedModel` to modify. + :param config: + A :class:`~tango.integrations.transformers.ia3.WithIA3Config` that specifies the layers to modify. + :param only_ia3_requires_grad: + A `bool`, `True` if `requires_grad` should only be set on ia3 paramenters in the output model. + + Examples + -------- + + You can use this as a :class:`~tango.integrations.torch.Model` constructor from a config/params + like this: + + .. testcode:: + + from transformers import AutoModelForCausalLM, AutoTokenizer + from tango.integrations.transformers.ia3 import modify_with_ia3, GPT_2_IA3_CONFIG + + model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2") + model = modify_with_ia3(model, GPT_2_IA3_CONFIG) + """ + for m_name, module in dict(transformer.named_modules()).items(): + if re.fullmatch(config.attention_modules, m_name) or re.fullmatch(config.mlp_modules, m_name): + attn_layers = [regex for regex in (config.fused_qkv_layers, config.k_layers, config.v_layers) if regex is not None] + layers_to_change = "|".join(attn_layers) \ + if re.fullmatch(config.attention_modules, m_name) \ + else config.mlp_layers + for c_name, layer in dict(module.named_children()).items(): + if re.fullmatch(layers_to_change, c_name): + assert isinstance(layer, Conv1D) or isinstance(layer, nn.Linear), f"This code only supports Conv1D and nn.Linear" + adaptor_class = Conv1DWithIA3 if isinstance(layer, Conv1D) else LinearWithIA3 + new_module = adaptor_class( + layer, + config.ia3_param_names, + unfuse_size=transformer.config.hidden_size \ + if config.fused_qkv_layers and re.fullmatch(config.fused_qkv_layers, c_name) \ + else None + ) + setattr(module, c_name, new_module) + + if only_ia3_requires_grad: + transformer.requires_grad_(False) + for p_name, v in dict(transformer.named_parameters()).items(): + if re.fullmatch('.*' + config.ia3_param_names + '.*', p_name): + v.requires_grad_(True) + + return transformer \ No newline at end of file diff --git a/tests/integrations/transformers/ia3_test.py b/tests/integrations/transformers/ia3_test.py new file mode 100644 index 000000000..cfe72a528 --- /dev/null +++ b/tests/integrations/transformers/ia3_test.py @@ -0,0 +1,41 @@ +from tango.integrations.transformers.ia3 import modify_with_ia3, GPT_2_IA3_CONFIG +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +import re + + +def test_ia3(): + + config = GPT_2_IA3_CONFIG + model_name = "sshleifer/tiny-gpt2" + + tokenizer = AutoTokenizer.from_pretrained(model_name) + + input_seq = tokenizer( + ["A tiny test on a tiny model."], + return_tensors="pt" + ) + + model = AutoModelForCausalLM.from_pretrained(model_name) + + with torch.no_grad(): + old_outputs = model( + input_ids=input_seq.input_ids, + attention_mask=input_seq.attention_mask, + labels=input_seq.input_ids, + ) + + model = modify_with_ia3(model, config) + + with torch.no_grad(): + new_outputs = model( + input_ids=input_seq.input_ids, + attention_mask=input_seq.attention_mask, + labels=input_seq.input_ids, + ) + + logits_diff = torch.abs(old_outputs.logits - new_outputs.logits).mean() + assert logits_diff < 1e-10 + + loss_diff = torch.abs(old_outputs.loss - new_outputs.loss) + assert loss_diff < 1e-10 \ No newline at end of file From 90642ec56eb0590eea9d2acc1085c4f25c867721 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Tue, 13 Sep 2022 18:43:51 -0700 Subject: [PATCH 02/11] docs and style fixes --- docs/source/api/integrations/transformers.rst | 2 + tango/integrations/transformers/ia3.py | 194 +++++++++++------- tests/integrations/transformers/ia3_test.py | 15 +- 3 files changed, 123 insertions(+), 88 deletions(-) diff --git a/docs/source/api/integrations/transformers.rst b/docs/source/api/integrations/transformers.rst index 59a003326..376b89143 100644 --- a/docs/source/api/integrations/transformers.rst +++ b/docs/source/api/integrations/transformers.rst @@ -17,3 +17,5 @@ Reference .. autoclass:: tango.integrations.transformers.RunGenerationDataset :members: + +.. autofunction:: tango.integrations.transformers.ia3.modify_with_ia3 \ No newline at end of file diff --git a/tango/integrations/transformers/ia3.py b/tango/integrations/transformers/ia3.py index 9d02c7b69..bbcbf3d1b 100644 --- a/tango/integrations/transformers/ia3.py +++ b/tango/integrations/transformers/ia3.py @@ -1,80 +1,83 @@ +import re from dataclasses import dataclass from typing import Optional -import re + import torch import torch.nn as nn import torch.nn.functional as F - -from transformers.modeling_utils import Conv1D from transformers import PreTrainedModel +from transformers.modeling_utils import Conv1D + @dataclass class WithIA3Config: - attention_modules: str = None + ia3_param_names: str + attention_modules: str + mlp_modules: str + mlp_layers: str fused_qkv_layers: Optional[str] = None k_layers: Optional[str] = None v_layers: Optional[str] = None - mlp_modules: str = None - mlp_layers: str = None - ia3_param_names: str = None + GPT_J_IA3_CONFIG = WithIA3Config( - attention_modules = ".*attn", - k_layers="k_proj", - v_layers="v_proj", - mlp_modules = ".*mlp", - mlp_layers = "fc_in", - ia3_param_names = "ia3" - ) + attention_modules=".*attn", + k_layers="k_proj", + v_layers="v_proj", + mlp_modules=".*mlp", + mlp_layers="fc_in", + ia3_param_names="ia3", +) GPT_2_IA3_CONFIG = WithIA3Config( - attention_modules = ".*attn", - fused_qkv_layers = "c_attn", - mlp_modules = ".*mlp", - mlp_layers = "c_fc", - ia3_param_names = "ia3" - ) + attention_modules=".*attn", + fused_qkv_layers="c_attn", + mlp_modules=".*mlp", + mlp_layers="c_fc", + ia3_param_names="ia3", +) OPT_IA3_CONFIG = WithIA3Config( - attention_modules = ".*self_attn", - k_layers="k_proj", - v_layers="v_proj", - mlp_modules = ".*layers\.\d*", - mlp_layers = "fc1", - ia3_param_names = "ia3" - ) + attention_modules=".*self_attn", + k_layers="k_proj", + v_layers="v_proj", + mlp_modules=r".*layers\.\d*", + mlp_layers="fc1", + ia3_param_names="ia3", +) BLOOM_IA3_CONFIG = WithIA3Config( - attention_modules = ".*self_attention", - fused_qkv_layers = "query_key_value", - mlp_modules = ".*mlp", - mlp_layers = "dense_h_to_4h", - ia3_param_names = "ia3" - ) + attention_modules=".*self_attention", + fused_qkv_layers="query_key_value", + mlp_modules=".*mlp", + mlp_layers="dense_h_to_4h", + ia3_param_names="ia3", +) MODEL_NAME_TO_CONFIG = { - 'sshleifer/tiny-gpt2': GPT_2_IA3_CONFIG, - 'gpt2': GPT_2_IA3_CONFIG, - 'gpt2-medium': GPT_2_IA3_CONFIG, - 'gpt2-large': GPT_2_IA3_CONFIG, - 'gpt2-xl': GPT_2_IA3_CONFIG, - 'bigscience/bloom-560m': BLOOM_IA3_CONFIG, - 'bigscience/bloom-1b1': BLOOM_IA3_CONFIG, - 'bigscience/bloom-1b7': BLOOM_IA3_CONFIG, - 'bigscience/bloom-3b': BLOOM_IA3_CONFIG, - 'bigscience/bloom-7b1': BLOOM_IA3_CONFIG, - 'bigscience/bloom': BLOOM_IA3_CONFIG, - 'facebook/opt-125m': OPT_IA3_CONFIG, - 'facebook/opt-350m': OPT_IA3_CONFIG, - 'facebook/opt-1.3b': OPT_IA3_CONFIG, - 'facebook/opt-2.7b': OPT_IA3_CONFIG, - 'facebook/opt-6.7b': OPT_IA3_CONFIG, - 'facebook/opt-13b': OPT_IA3_CONFIG, - 'facebook/opt-30b': OPT_IA3_CONFIG, - 'facebook/opt-66b': OPT_IA3_CONFIG, - 'EleutherAI/gpt-j-6B': GPT_J_IA3_CONFIG, + "sshleifer/tiny-gpt2": GPT_2_IA3_CONFIG, + "gpt2": GPT_2_IA3_CONFIG, + "gpt2-medium": GPT_2_IA3_CONFIG, + "gpt2-large": GPT_2_IA3_CONFIG, + "gpt2-xl": GPT_2_IA3_CONFIG, + "bigscience/bloom-560m": BLOOM_IA3_CONFIG, + "bigscience/bloom-1b1": BLOOM_IA3_CONFIG, + "bigscience/bloom-1b7": BLOOM_IA3_CONFIG, + "bigscience/bloom-3b": BLOOM_IA3_CONFIG, + "bigscience/bloom-7b1": BLOOM_IA3_CONFIG, + "bigscience/bloom": BLOOM_IA3_CONFIG, + "facebook/opt-125m": OPT_IA3_CONFIG, + "facebook/opt-350m": OPT_IA3_CONFIG, + "facebook/opt-1.3b": OPT_IA3_CONFIG, + "facebook/opt-2.7b": OPT_IA3_CONFIG, + "facebook/opt-6.7b": OPT_IA3_CONFIG, + "facebook/opt-13b": OPT_IA3_CONFIG, + "facebook/opt-30b": OPT_IA3_CONFIG, + "facebook/opt-66b": OPT_IA3_CONFIG, + "EleutherAI/gpt-j-6B": GPT_J_IA3_CONFIG, } + class LinearWithIA3(nn.Module): def __init__(self, linear_layer, ia3_param_names, unfuse_size: int = None): super().__init__() @@ -104,7 +107,7 @@ def forward(self, x): if ia3_params.requires_grad: if self.unfuse_size is not None: # non_q means k and v - q, non_q = x[:, :, :self.unfuse_size], x[:, :, self.unfuse_size:] + q, non_q = x[:, :, : self.unfuse_size], x[:, :, self.unfuse_size :] ia3_params = getattr(self, self.ia3_param_names) non_q = non_q * ia3_params.flatten() x = torch.cat([q, non_q], dim=2) @@ -113,6 +116,7 @@ def forward(self, x): return x + class Conv1DWithIA3(nn.Module): def __init__(self, conv1d_layer, ia3_param_names, unfuse_size: int = None): super().__init__() @@ -138,14 +142,14 @@ def forward(self, x): # copied and pasted from the original Conv1D implemnetation size_out = x.size()[:-1] + (self.nf,) x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) - x = x.view(size_out) # ... * self.nf + x = x.view(size_out) # ... * self.nf ia3_params = getattr(self, self.ia3_param_names) if ia3_params.requires_grad: if self.unfuse_size is not None: # non_q means k and v - q, non_q = x[:, :, :self.unfuse_size], x[:, :, self.unfuse_size:] + q, non_q = x[:, :, : self.unfuse_size], x[:, :, self.unfuse_size :] ia3_params = getattr(self, self.ia3_param_names) non_q = non_q * ia3_params.flatten() x = torch.cat([q, non_q], dim=2) @@ -154,10 +158,15 @@ def forward(self, x): return x -def modify_with_ia3(transformer: PreTrainedModel, config: WithIA3Config, *, only_ia3_requires_grad: bool = True): + +def modify_with_ia3( + transformer: PreTrainedModel, config: WithIA3Config, *, only_ia3_requires_grad: bool = True +) -> PreTrainedModel: """ - A function to add ia3 adaptors to the given transformer. Code modified from - [t-few](https://github.com/r-three/t-few/blob/217cfa3b73aa66a07594826e4ebbbc516b331461/src/models/lora.py) and Qinyuan Ye + A function to add ia3 adaptors to the given transformer. Code modified from + `t-few `_ + and Qinyuan Ye + :param model: A :class:`~transformers.PreTrainedModel` to modify. @@ -169,8 +178,7 @@ def modify_with_ia3(transformer: PreTrainedModel, config: WithIA3Config, *, only Examples -------- - You can use this as a :class:`~tango.integrations.torch.Model` constructor from a config/params - like this: + You can use the provided configurations: .. testcode:: @@ -179,30 +187,58 @@ def modify_with_ia3(transformer: PreTrainedModel, config: WithIA3Config, *, only model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2") model = modify_with_ia3(model, GPT_2_IA3_CONFIG) + + Or you can write your own configuration with regex matching the layers to modify and their parents: + + .. testcode:: + + from transformers import AutoModelForCausalLM, AutoTokenizer + from tango.integrations.transformers.ia3 import modify_with_ia3 + + my_config = WithIA3Config( + attention_modules=".*attn", + fused_qkv_layers="c_attn", + mlp_modules=".*mlp", + mlp_layers="c_fc", + ia3_param_names="ia3", + ) + + model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2") + model = modify_with_ia3(model, my_config) """ - for m_name, module in dict(transformer.named_modules()).items(): - if re.fullmatch(config.attention_modules, m_name) or re.fullmatch(config.mlp_modules, m_name): - attn_layers = [regex for regex in (config.fused_qkv_layers, config.k_layers, config.v_layers) if regex is not None] - layers_to_change = "|".join(attn_layers) \ - if re.fullmatch(config.attention_modules, m_name) \ + for m_name, module in dict(transformer.named_modules()).items(): # type: ignore + if re.fullmatch(config.attention_modules, m_name) or re.fullmatch( + config.mlp_modules, m_name + ): + attn_layers = [ + regex + for regex in (config.fused_qkv_layers, config.k_layers, config.v_layers) + if regex is not None + ] + layers_to_change = ( + "|".join(attn_layers) + if re.fullmatch(config.attention_modules, m_name) else config.mlp_layers + ) for c_name, layer in dict(module.named_children()).items(): if re.fullmatch(layers_to_change, c_name): - assert isinstance(layer, Conv1D) or isinstance(layer, nn.Linear), f"This code only supports Conv1D and nn.Linear" + assert isinstance(layer, Conv1D) or isinstance( + layer, nn.Linear + ), "This code only supports Conv1D and nn.Linear" adaptor_class = Conv1DWithIA3 if isinstance(layer, Conv1D) else LinearWithIA3 new_module = adaptor_class( - layer, - config.ia3_param_names, - unfuse_size=transformer.config.hidden_size \ - if config.fused_qkv_layers and re.fullmatch(config.fused_qkv_layers, c_name) \ - else None - ) + layer, + config.ia3_param_names, + unfuse_size=transformer.config.hidden_size # type: ignore + if config.fused_qkv_layers and re.fullmatch(config.fused_qkv_layers, c_name) + else None, + ) setattr(module, c_name, new_module) - + if only_ia3_requires_grad: - transformer.requires_grad_(False) - for p_name, v in dict(transformer.named_parameters()).items(): - if re.fullmatch('.*' + config.ia3_param_names + '.*', p_name): + transformer.requires_grad_(False) # type: ignore + for p_name, v in dict(transformer.named_parameters()).items(): # type: ignore + if re.fullmatch(".*" + config.ia3_param_names + ".*", p_name): v.requires_grad_(True) - - return transformer \ No newline at end of file + + return transformer diff --git a/tests/integrations/transformers/ia3_test.py b/tests/integrations/transformers/ia3_test.py index cfe72a528..a8de10934 100644 --- a/tests/integrations/transformers/ia3_test.py +++ b/tests/integrations/transformers/ia3_test.py @@ -1,7 +1,7 @@ -from tango.integrations.transformers.ia3 import modify_with_ia3, GPT_2_IA3_CONFIG -from transformers import AutoModelForCausalLM, AutoTokenizer import torch -import re +from transformers import AutoModelForCausalLM, AutoTokenizer + +from tango.integrations.transformers.ia3 import GPT_2_IA3_CONFIG, modify_with_ia3 def test_ia3(): @@ -11,10 +11,7 @@ def test_ia3(): tokenizer = AutoTokenizer.from_pretrained(model_name) - input_seq = tokenizer( - ["A tiny test on a tiny model."], - return_tensors="pt" - ) + input_seq = tokenizer(["A tiny test on a tiny model."], return_tensors="pt") model = AutoModelForCausalLM.from_pretrained(model_name) @@ -24,7 +21,7 @@ def test_ia3(): attention_mask=input_seq.attention_mask, labels=input_seq.input_ids, ) - + model = modify_with_ia3(model, config) with torch.no_grad(): @@ -38,4 +35,4 @@ def test_ia3(): assert logits_diff < 1e-10 loss_diff = torch.abs(old_outputs.loss - new_outputs.loss) - assert loss_diff < 1e-10 \ No newline at end of file + assert loss_diff < 1e-10 From 2d9ecc1db28b60faeda3df8a37a856f09cf6c889 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Tue, 13 Sep 2022 18:48:26 -0700 Subject: [PATCH 03/11] update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d7917c345..9cf4f818f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Adds a function to modify a Hugging Face transformer with IA3 adaptors + ### Fixed - Made `BeakerExecutor` more robust to connection, timeout, SSL, and other recoverable HTTP errors. From 34e2ab10f2335c0d0e1a3b89115440b48be3bc90 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Wed, 14 Sep 2022 22:48:12 -0700 Subject: [PATCH 04/11] docs and refactoring --- tango/integrations/transformers/ia3.py | 119 ++++++++++++++++--------- 1 file changed, 76 insertions(+), 43 deletions(-) diff --git a/tango/integrations/transformers/ia3.py b/tango/integrations/transformers/ia3.py index bbcbf3d1b..ca6ad9100 100644 --- a/tango/integrations/transformers/ia3.py +++ b/tango/integrations/transformers/ia3.py @@ -11,6 +11,27 @@ @dataclass class WithIA3Config: + """ + A class for configuring which layers to modify with IA3 adaptors. + + + :param ia3_param_names: + A string used as the name for all ia3 parameters + :param attention_modules: + A regex that matches all attention modules which are parents to the keys and value layers to modify. + :param mlp_modules: + A regex that matches all modules that are parents to the feed forward layer to modify. + :param mlp_layers: + A regex that matches the feed forward layer in the modules specified by `mlp_modles`. + :param fused_qkv_layers: + A regex that matches the combined query, key, and value layer in the modules specified + by `attention_modules`. + :param k_layers: + A regex that matches the key layer in the modules specified by `attention_modules`. + :param v_layers: + A regex that matches the value layer in the modules specified by `attention_modules`. + """ + ia3_param_names: str attention_modules: str mlp_modules: str @@ -78,30 +99,19 @@ class WithIA3Config: } -class LinearWithIA3(nn.Module): - def __init__(self, linear_layer, ia3_param_names, unfuse_size: int = None): +class WithIA3(nn.Module): + def __init__(self, ia3_param_names: str, unfuse_size: int = None): super().__init__() - - self.in_features = linear_layer.in_features - self.out_features = linear_layer.out_features - self.unfuse_size = unfuse_size - - self.weight = linear_layer.weight - self.bias = linear_layer.bias - self.ia3_param_names = ia3_param_names # if (q,k,v) are stacked into one layer if unfuse_size is not None: - assert linear_layer.out_features == unfuse_size * 3 # IA3 only operates on k and v (not q), thus the "* 2" setattr(self, ia3_param_names, nn.Parameter(torch.ones(unfuse_size * 2, 1))) else: - setattr(self, ia3_param_names, nn.Parameter(torch.ones(self.out_features, 1))) - - def forward(self, x): - x = F.linear(x, self.weight, self.bias) + setattr(self, ia3_param_names, nn.Parameter(torch.ones(self.out_features, 1))) # type: ignore + def scale_by_ia3(self, x): ia3_params = getattr(self, self.ia3_param_names) if ia3_params.requires_grad: @@ -117,46 +127,69 @@ def forward(self, x): return x -class Conv1DWithIA3(nn.Module): - def __init__(self, conv1d_layer, ia3_param_names, unfuse_size: int = None): - super().__init__() +class LinearWithIA3(WithIA3): + def __init__(self, linear_layer: nn.Linear, ia3_param_names: str, unfuse_size: int = None): + """ + A replacement for :class:`~torch.nn.Linear` modified with an IA3 adaptor + + + :param linear_layer: + A :class:`~torch.nn.Linear` layer to adapt. + :param ia3_param_names: + A `str` to use as the name of ia3 parameters. + :param unfuse_size: + An `int` indicating hidden dimension of the query, key, and value vectors. + To be used only when the layer to modify is a fused projection of query, + key, and value vectors in an attention mechanism. + """ + assert unfuse_size is None or (linear_layer.out_features == unfuse_size * 3) + self.in_features = linear_layer.in_features + self.out_features = linear_layer.out_features + self.unfuse_size = unfuse_size + + super().__init__(ia3_param_names, unfuse_size) + + self.weight = linear_layer.weight + self.bias = linear_layer.bias + + def forward(self, x): + x = F.linear(x, self.weight, self.bias) + return self.scale_by_ia3(x) + + +class Conv1DWithIA3(WithIA3): + def __init__(self, conv1d_layer: Conv1D, ia3_param_names: str, unfuse_size: int = None): + """ + A replacement for :class:`~transformers.modeling_utils.Conv1D` modified with an IA3 adaptor + + + :param conv1d_layer: + A :class:`~transformers.modeling_utils.Conv1D` layer to adapt. + :param ia3_param_names: + A `str` to use as the name of ia3 parameters. + :param unfuse_size: + An `int` indicating hidden dimension of the query, key, and value vectors. + To be used only when the layer to modify is a fused projection of query, + key, and value vectors in an attention mechanism. + """ + assert unfuse_size is None or (conv1d_layer.nf == unfuse_size * 3) # nf: number of output features; nx: number of input features - self.nf = conv1d_layer.nf + self.out_features = conv1d_layer.nf self.unfuse_size = unfuse_size + super().__init__(ia3_param_names, unfuse_size) + self.weight = conv1d_layer.weight self.bias = conv1d_layer.bias - self.ia3_param_names = ia3_param_names - - # in c_att parameters, (q,k,v) linear layers are stacked into one Conv1D layer - if unfuse_size is not None: - assert conv1d_layer.nf == unfuse_size * 3 - # but IA3 only operates on k and v (not q), thus the "* 2" - setattr(self, ia3_param_names, nn.Parameter(torch.ones(unfuse_size * 2, 1))) - else: - setattr(self, ia3_param_names, nn.Parameter(torch.ones(self.nf, 1))) - def forward(self, x): # copied and pasted from the original Conv1D implemnetation - size_out = x.size()[:-1] + (self.nf,) + size_out = x.size()[:-1] + (self.out_features,) x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) x = x.view(size_out) # ... * self.nf - ia3_params = getattr(self, self.ia3_param_names) - - if ia3_params.requires_grad: - if self.unfuse_size is not None: - # non_q means k and v - q, non_q = x[:, :, : self.unfuse_size], x[:, :, self.unfuse_size :] - ia3_params = getattr(self, self.ia3_param_names) - non_q = non_q * ia3_params.flatten() - x = torch.cat([q, non_q], dim=2) - else: - x = x * ia3_params.flatten() - - return x + return self.scale_by_ia3(x) def modify_with_ia3( From b825a4177ca239e9d1f1a20caf950c01694c6bb5 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Thu, 15 Sep 2022 08:55:02 -0700 Subject: [PATCH 05/11] style fix --- tango/integrations/transformers/ia3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tango/integrations/transformers/ia3.py b/tango/integrations/transformers/ia3.py index ca6ad9100..728eb50c6 100644 --- a/tango/integrations/transformers/ia3.py +++ b/tango/integrations/transformers/ia3.py @@ -109,7 +109,7 @@ def __init__(self, ia3_param_names: str, unfuse_size: int = None): # IA3 only operates on k and v (not q), thus the "* 2" setattr(self, ia3_param_names, nn.Parameter(torch.ones(unfuse_size * 2, 1))) else: - setattr(self, ia3_param_names, nn.Parameter(torch.ones(self.out_features, 1))) # type: ignore + setattr(self, ia3_param_names, nn.Parameter(torch.ones(self.out_features, 1))) # type: ignore def scale_by_ia3(self, x): ia3_params = getattr(self, self.ia3_param_names) From c1e27f1c84d6813044f7f22cf8f0a58dbf0611f3 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 16 Sep 2022 17:25:16 -0700 Subject: [PATCH 06/11] style fixes --- tango/integrations/transformers/ia3.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tango/integrations/transformers/ia3.py b/tango/integrations/transformers/ia3.py index 728eb50c6..3f86fb808 100644 --- a/tango/integrations/transformers/ia3.py +++ b/tango/integrations/transformers/ia3.py @@ -193,7 +193,10 @@ def forward(self, x): def modify_with_ia3( - transformer: PreTrainedModel, config: WithIA3Config, *, only_ia3_requires_grad: bool = True + transformer: PreTrainedModel, + *, + config: WithIA3Config = None, + only_ia3_requires_grad: bool = True, ) -> PreTrainedModel: """ A function to add ia3 adaptors to the given transformer. Code modified from @@ -239,6 +242,13 @@ def modify_with_ia3( model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2") model = modify_with_ia3(model, my_config) """ + if config is None: + model_name = transformer.config._name_or_path # type: ignore + assert ( + model_name in MODEL_NAME_TO_CONFIG + ), f"{model_name} does not have an pre made configuration; please make your own." + config = MODEL_NAME_TO_CONFIG[model_name] + for m_name, module in dict(transformer.named_modules()).items(): # type: ignore if re.fullmatch(config.attention_modules, m_name) or re.fullmatch( config.mlp_modules, m_name From 22fb7f81ca8ffaa983a7ec9c954d0fd4627bdab7 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 16 Sep 2022 17:28:01 -0700 Subject: [PATCH 07/11] more style fixes --- tango/integrations/transformers/ia3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tango/integrations/transformers/ia3.py b/tango/integrations/transformers/ia3.py index 3f86fb808..8148f54e4 100644 --- a/tango/integrations/transformers/ia3.py +++ b/tango/integrations/transformers/ia3.py @@ -243,7 +243,7 @@ def modify_with_ia3( model = modify_with_ia3(model, my_config) """ if config is None: - model_name = transformer.config._name_or_path # type: ignore + model_name = transformer.config._name_or_path # type: ignore assert ( model_name in MODEL_NAME_TO_CONFIG ), f"{model_name} does not have an pre made configuration; please make your own." From 75b24d92348cdfc63eb0fbe38f33f621c91d0735 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 16 Sep 2022 17:36:43 -0700 Subject: [PATCH 08/11] update test --- tests/integrations/transformers/ia3_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integrations/transformers/ia3_test.py b/tests/integrations/transformers/ia3_test.py index a8de10934..2617df269 100644 --- a/tests/integrations/transformers/ia3_test.py +++ b/tests/integrations/transformers/ia3_test.py @@ -22,7 +22,7 @@ def test_ia3(): labels=input_seq.input_ids, ) - model = modify_with_ia3(model, config) + model = modify_with_ia3(model, config=config) with torch.no_grad(): new_outputs = model( From 6be3b6f20ca77985afaa95ea15ed34f0713c90ac Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 16 Sep 2022 17:44:08 -0700 Subject: [PATCH 09/11] more fixes --- tango/integrations/transformers/ia3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tango/integrations/transformers/ia3.py b/tango/integrations/transformers/ia3.py index 8148f54e4..3e542da54 100644 --- a/tango/integrations/transformers/ia3.py +++ b/tango/integrations/transformers/ia3.py @@ -222,7 +222,7 @@ def modify_with_ia3( from tango.integrations.transformers.ia3 import modify_with_ia3, GPT_2_IA3_CONFIG model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2") - model = modify_with_ia3(model, GPT_2_IA3_CONFIG) + model = modify_with_ia3(model, config=GPT_2_IA3_CONFIG) Or you can write your own configuration with regex matching the layers to modify and their parents: @@ -240,7 +240,7 @@ def modify_with_ia3( ) model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2") - model = modify_with_ia3(model, my_config) + model = modify_with_ia3(model, config=my_config) """ if config is None: model_name = transformer.config._name_or_path # type: ignore From 7c24fdc7889fa3536993179f14bfa684c6949bf5 Mon Sep 17 00:00:00 2001 From: Ian Magnusson <40903802+jagnusson@users.noreply.github.com> Date: Fri, 16 Sep 2022 18:14:56 -0700 Subject: [PATCH 10/11] Apply suggestions from code review Co-authored-by: Dirk Groeneveld --- tango/integrations/transformers/ia3.py | 2 +- tests/integrations/transformers/ia3_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tango/integrations/transformers/ia3.py b/tango/integrations/transformers/ia3.py index 3e542da54..e5013805d 100644 --- a/tango/integrations/transformers/ia3.py +++ b/tango/integrations/transformers/ia3.py @@ -246,7 +246,7 @@ def modify_with_ia3( model_name = transformer.config._name_or_path # type: ignore assert ( model_name in MODEL_NAME_TO_CONFIG - ), f"{model_name} does not have an pre made configuration; please make your own." + ), f"{model_name} does not have a pre made configuration; please make your own." config = MODEL_NAME_TO_CONFIG[model_name] for m_name, module in dict(transformer.named_modules()).items(): # type: ignore diff --git a/tests/integrations/transformers/ia3_test.py b/tests/integrations/transformers/ia3_test.py index 2617df269..766097395 100644 --- a/tests/integrations/transformers/ia3_test.py +++ b/tests/integrations/transformers/ia3_test.py @@ -15,7 +15,7 @@ def test_ia3(): model = AutoModelForCausalLM.from_pretrained(model_name) - with torch.no_grad(): + with torch.inference_mode(): old_outputs = model( input_ids=input_seq.input_ids, attention_mask=input_seq.attention_mask, @@ -24,7 +24,7 @@ def test_ia3(): model = modify_with_ia3(model, config=config) - with torch.no_grad(): + with torch.inference_mode(): new_outputs = model( input_ids=input_seq.input_ids, attention_mask=input_seq.attention_mask, From 1fe816ebffe740c7e47d3cbba68235a7ad3187d0 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Mon, 19 Sep 2022 10:45:28 -0700 Subject: [PATCH 11/11] Set model to eval mode --- tests/integrations/transformers/ia3_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integrations/transformers/ia3_test.py b/tests/integrations/transformers/ia3_test.py index 766097395..fa2df2aa8 100644 --- a/tests/integrations/transformers/ia3_test.py +++ b/tests/integrations/transformers/ia3_test.py @@ -13,7 +13,7 @@ def test_ia3(): input_seq = tokenizer(["A tiny test on a tiny model."], return_tensors="pt") - model = AutoModelForCausalLM.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name).eval() with torch.inference_mode(): old_outputs = model(