From 4e1f54b7d3c47cd8b6eaa81038545545f8584843 Mon Sep 17 00:00:00 2001
From: yhyu13
Date: Wed, 22 Nov 2023 16:11:23 +0800
Subject: [PATCH 1/2] Code change to use flash attn 2

---
 toolbench/train/train.py      |  3 ++-
 toolbench/train/train_lora.py | 11 ++++++-----
 toolbench/train/train_mem.py  | 10 +++++-----
 3 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/toolbench/train/train.py b/toolbench/train/train.py
index c89aa871..68094355 100644
--- a/toolbench/train/train.py
+++ b/toolbench/train/train.py
@@ -276,7 +276,8 @@ def train():
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=training_args.cache_dir,
-        device_map=device_map
+        device_map=device_map,
+        use_flash_attention_2=True,
     )
     model.config.use_cache = False
     trainer = Trainer(
diff --git a/toolbench/train/train_lora.py b/toolbench/train/train_lora.py
index ad72cfb3..251d185e 100644
--- a/toolbench/train/train_lora.py
+++ b/toolbench/train/train_lora.py
@@ -33,11 +33,11 @@
     make_supervised_data_module,
 )
 
-from toolbench.train.llama_flash_attn_monkey_patch import (
-    replace_llama_attn_with_flash_attn,
-)
+# from toolbench.train.llama_flash_attn_monkey_patch import (
+#     replace_llama_attn_with_flash_attn,
+# )
 from toolbench.train.llama_condense_monkey_patch import replace_llama_with_condense
-replace_llama_attn_with_flash_attn()
+# replace_llama_attn_with_flash_attn()
 
 
 @dataclass
@@ -110,7 +110,8 @@ def train():
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=training_args.cache_dir,
-        device_map=device_map
+        device_map=device_map,
+        use_flash_attention_2=True,
     )
     lora_config = LoraConfig(
         r=lora_args.lora_r,
diff --git a/toolbench/train/train_mem.py b/toolbench/train/train_mem.py
index fbf76648..20821865 100644
--- a/toolbench/train/train_mem.py
+++ b/toolbench/train/train_mem.py
@@ -1,11 +1,11 @@
 # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
 
-# Need to call this before importing transformers.
-from toolbench.train.llama_flash_attn_monkey_patch import (
-    replace_llama_attn_with_flash_attn,
-)
+# # Need to call this before importing transformers.
+# from toolbench.train.llama_flash_attn_monkey_patch import (
+#     replace_llama_attn_with_flash_attn,
+# )
 
-replace_llama_attn_with_flash_attn()
+# replace_llama_attn_with_flash_attn()
 
 from toolbench.train.train import train
 

From f13f9c8f5ca9b1845eb822f77f8ccb9833273208 Mon Sep 17 00:00:00 2001
From: yhyu13
Date: Thu, 23 Nov 2023 11:46:21 +0800
Subject: [PATCH 2/2] Use QLoRA alongside flash attn2

---
 toolbench/train/train.py      | 7 +++++++
 toolbench/train/train_lora.py | 8 ++++++++
 2 files changed, 15 insertions(+)

diff --git a/toolbench/train/train.py b/toolbench/train/train.py
index 68094355..5e28a51e 100644
--- a/toolbench/train/train.py
+++ b/toolbench/train/train.py
@@ -273,11 +273,18 @@ def train():
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
     device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
+    bnb_config = transformers.BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16
+    )
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=training_args.cache_dir,
         device_map=device_map,
         use_flash_attention_2=True,
+        quantization_config=bnb_config,
     )
     model.config.use_cache = False
     trainer = Trainer(
diff --git a/toolbench/train/train_lora.py b/toolbench/train/train_lora.py
index 251d185e..96e683e4 100644
--- a/toolbench/train/train_lora.py
+++ b/toolbench/train/train_lora.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import torch
 from dataclasses import dataclass, field
 import logging
 import pathlib
@@ -107,11 +108,18 @@ def train():
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
     device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
+    bnb_config = transformers.BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16
+    )
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=training_args.cache_dir,
         device_map=device_map,
         use_flash_attention_2=True,
+        quantization_config=bnb_config,
     )
     lora_config = LoraConfig(
         r=lora_args.lora_r,
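
For reference, here is a minimal sketch of what the model-loading path in toolbench/train/train_lora.py amounts to once both patches are applied. It assumes a transformers version that accepts use_flash_attention_2 (roughly 4.34+) with flash-attn 2 and bitsandbytes installed; the model path and device_map below are placeholders standing in for model_args.model_name_or_path and the script's DDP-aware device_map.

    # Sketch only -- mirrors the post-patch load path, not a verbatim copy of train_lora.py.
    import torch
    import transformers

    bnb_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,                      # QLoRA: store base weights in 4-bit
        bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
        bnb_4bit_quant_type="nf4",              # NF4 quantization
        bnb_4bit_compute_dtype=torch.bfloat16,  # run the actual matmuls in bf16
    )

    model = transformers.AutoModelForCausalLM.from_pretrained(
        "path/to/base-model",                   # placeholder for model_args.model_name_or_path
        device_map={"": 0},                     # placeholder; the script derives this from LOCAL_RANK under DDP
        use_flash_attention_2=True,             # built-in FlashAttention-2 path
        quantization_config=bnb_config,
    )
    model.config.use_cache = False              # disabled for training, as in train.py

With this setup the 4-bit base weights stay frozen and only the LoRA adapters defined by the existing LoraConfig are trained (the usual QLoRA recipe), while use_flash_attention_2=True takes over the role of the removed llama_flash_attn_monkey_patch import.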