diff --git a/toolbench/train/train.py b/toolbench/train/train.py
index c89aa871..5e28a51e 100644
--- a/toolbench/train/train.py
+++ b/toolbench/train/train.py
@@ -273,10 +273,18 @@ def train():
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
     device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
+    bnb_config = transformers.BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16
+    )
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=training_args.cache_dir,
-        device_map=device_map
+        device_map=device_map,
+        use_flash_attention_2=True,
+        quantization_config=bnb_config,
     )
     model.config.use_cache = False
     trainer = Trainer(
diff --git a/toolbench/train/train_lora.py b/toolbench/train/train_lora.py
index ad72cfb3..96e683e4 100644
--- a/toolbench/train/train_lora.py
+++ b/toolbench/train/train_lora.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import torch
 from dataclasses import dataclass, field
 import logging
 import pathlib
@@ -33,11 +34,11 @@
     make_supervised_data_module,
 )
 
-from toolbench.train.llama_flash_attn_monkey_patch import (
-    replace_llama_attn_with_flash_attn,
-)
+# from toolbench.train.llama_flash_attn_monkey_patch import (
+#     replace_llama_attn_with_flash_attn,
+# )
 from toolbench.train.llama_condense_monkey_patch import replace_llama_with_condense
 
-replace_llama_attn_with_flash_attn()
+# replace_llama_attn_with_flash_attn()
 
 @dataclass
@@ -107,10 +108,18 @@ def train():
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
     device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
+    bnb_config = transformers.BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16
+    )
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=training_args.cache_dir,
-        device_map=device_map
+        device_map=device_map,
+        use_flash_attention_2=True,
+        quantization_config=bnb_config,
     )
     lora_config = LoraConfig(
         r=lora_args.lora_r,
diff --git a/toolbench/train/train_mem.py b/toolbench/train/train_mem.py
index fbf76648..20821865 100644
--- a/toolbench/train/train_mem.py
+++ b/toolbench/train/train_mem.py
@@ -1,11 +1,11 @@
 # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
 
-# Need to call this before importing transformers.
-from toolbench.train.llama_flash_attn_monkey_patch import (
-    replace_llama_attn_with_flash_attn,
-)
+# # Need to call this before importing transformers.
+# from toolbench.train.llama_flash_attn_monkey_patch import (
+#     replace_llama_attn_with_flash_attn,
+# )
 
-replace_llama_attn_with_flash_attn()
+# replace_llama_attn_with_flash_attn()
 
 from toolbench.train.train import train
 
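
For context, the sketch below shows roughly what this patch enables: loading the base model in 4-bit NF4 with double quantization and bf16 compute, using FlashAttention-2 through transformers instead of the old monkey patch, and then attaching LoRA adapters for fine-tuning. It is a minimal illustration only, assuming recent transformers, peft, bitsandbytes, and flash-attn installs; the checkpoint name, LoRA hyperparameters, and the prepare_model_for_kbit_training step are placeholders and assumptions, not code taken from this repository.

```python
# Minimal sketch (assumptions: recent transformers/peft/bitsandbytes, flash-attn installed;
# checkpoint name and LoRA hyperparameters are placeholders, not values from this repo).
import torch
import transformers
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bf16
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",                  # placeholder checkpoint
    quantization_config=bnb_config,
    use_flash_attention_2=True,             # FlashAttention-2 via transformers, no monkey patch
    device_map={"": 0},
)
model.config.use_cache = False              # caching is not useful during training

# Freeze the quantized base weights and prepare the model for k-bit training,
# then attach LoRA adapters so only the adapter weights are trained.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```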