diff --git a/ferret/model/builder.py b/ferret/model/builder.py index f901f53..9877e29 100644 --- a/ferret/model/builder.py +++ b/ferret/model/builder.py @@ -15,98 +15,156 @@ import shutil import pdb -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + AutoConfig, + BitsAndBytesConfig, +) import torch from ferret.model import * -from ferret.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from ferret.constants import ( + DEFAULT_IMAGE_PATCH_TOKEN, + DEFAULT_IM_START_TOKEN, + DEFAULT_IM_END_TOKEN, +) +from ferret.model.utils import DEVICE + DEFAULT_REGION_FEA_TOKEN = "" -def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"): + +def load_pretrained_model( + model_path, + model_base, + model_name, + load_8bit=False, + load_4bit=False, + device_map="auto", +): kwargs = {"device_map": device_map} if load_8bit: - kwargs['load_in_8bit'] = True + kwargs["load_in_8bit"] = True elif load_4bit: - kwargs['load_in_4bit'] = True - kwargs['quantization_config'] = BitsAndBytesConfig( + kwargs["load_in_4bit"] = True + kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type='nf4' + bnb_4bit_quant_type="nf4", ) else: - kwargs['torch_dtype'] = torch.float16 + kwargs["torch_dtype"] = torch.float16 - if 'llava' in model_name.lower() or 'ferret' in model_name.lower(): + if "llava" in model_name.lower() or "ferret" in model_name.lower(): # Load LLaVA/FERRET model - if 'lora' in model_name.lower() and model_base is not None: + if "lora" in model_name.lower() and model_base is not None: lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) - print('Loading LLaVA/FERRET from base model...') - model = FERRETLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) + print("Loading LLaVA/FERRET from base model...") + model = FERRETLlamaForCausalLM.from_pretrained( + model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs + ) token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features if model.lm_head.weight.shape[0] != token_num: - model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) - model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) + model.lm_head.weight = torch.nn.Parameter( + torch.empty( + token_num, tokem_dim, device=model.device, dtype=model.dtype + ) + ) + model.model.embed_tokens.weight = torch.nn.Parameter( + torch.empty( + token_num, tokem_dim, device=model.device, dtype=model.dtype + ) + ) - print('Loading additional LLaVA/FERRET weights...') - if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): - non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') + print("Loading additional LLaVA/FERRET weights...") + if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")): + non_lora_trainables = torch.load( + os.path.join(model_path, "non_lora_trainables.bin"), + map_location="cpu", + ) else: # this is probably from HF Hub from huggingface_hub import hf_hub_download + def load_from_hf(repo_id, filename, subfolder=None): cache_file = hf_hub_download( - repo_id=repo_id, - filename=filename, - subfolder=subfolder) - return torch.load(cache_file, map_location='cpu') - non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') - non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} - if any(k.startswith('model.model.') for k in non_lora_trainables): - non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} + repo_id=repo_id, filename=filename, subfolder=subfolder + ) + return torch.load(cache_file, map_location="cpu") + + non_lora_trainables = load_from_hf( + model_path, "non_lora_trainables.bin" + ) + non_lora_trainables = { + (k[11:] if k.startswith("base_model.") else k): v + for k, v in non_lora_trainables.items() + } + if any(k.startswith("model.model.") for k in non_lora_trainables): + non_lora_trainables = { + (k[6:] if k.startswith("model.") else k): v + for k, v in non_lora_trainables.items() + } model.load_state_dict(non_lora_trainables, strict=False) from peft import PeftModel - print('Loading LoRA weights...') + + print("Loading LoRA weights...") model = PeftModel.from_pretrained(model, model_path) - print('Merging LoRA weights...') + print("Merging LoRA weights...") model = model.merge_and_unload() - print('Model is loaded...') + print("Model is loaded...") elif model_base is not None: # this may be mm projector only - print('Loading LLaVA/FERRET from base model...') + print("Loading LLaVA/FERRET from base model...") tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) cfg_pretrained = AutoConfig.from_pretrained(model_path) - model = FERRETLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + model = FERRETLlamaForCausalLM.from_pretrained( + model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs + ) - mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') - mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} + mm_projector_weights = torch.load( + os.path.join(model_path, "mm_projector.bin"), map_location="cpu" + ) + mm_projector_weights = { + k: v.to(torch.float16) for k, v in mm_projector_weights.items() + } model.load_state_dict(mm_projector_weights, strict=False) else: tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) - model = FERRETLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + model = FERRETLlamaForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **kwargs + ) else: # Load language model if model_base is not None: # PEFT model from peft import PeftModel + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) - model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") + model = AutoModelForCausalLM.from_pretrained( + model_base, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + device_map="auto", + ) print(f"Loading LoRA weights from {model_path}") model = PeftModel.from_pretrained(model, model_path) print(f"Merging weights") model = model.merge_and_unload() - print('Convert to FP16...') + print("Convert to FP16...") model.to(torch.float16) else: use_fast = False tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) - model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **kwargs + ) image_processor = None - if 'llava' in model_name.lower() or 'ferret' in model_name.lower(): + if "llava" in model_name.lower() or "ferret" in model_name.lower(): mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) mm_im_region_fea_token = getattr(model.config, "im_region_fea_token", None) @@ -115,20 +173,22 @@ def load_from_hf(repo_id, filename, subfolder=None): if mm_im_region_fea_token is not None: tokenizer.add_tokens([DEFAULT_REGION_FEA_TOKEN], special_tokens=True) if mm_use_im_start_end: - tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) model.resize_token_embeddings(len(tokenizer)) vision_tower = model.get_vision_tower() - vision_tower_path = os.path.join(model_path, 'vision_tower') + vision_tower_path = os.path.join(model_path, "vision_tower") if not vision_tower.is_loaded or os.path.exists(vision_tower_path): if os.path.exists(vision_tower_path): - print(f'Start Loading vision tower from {vision_tower_path}') + print(f"Start Loading vision tower from {vision_tower_path}") vision_tower.load_model(vision_tower_path=vision_tower_path) - print(f'Finish Loading vision tower from {vision_tower_path}') + print(f"Finish Loading vision tower from {vision_tower_path}") else: vision_tower.load_model() - vision_tower.to(device='cuda', dtype=torch.float16) + vision_tower.to(device=DEVICE, dtype=torch.float16) image_processor = vision_tower.image_processor if hasattr(model.config, "max_sequence_length"): diff --git a/ferret/model/utils.py b/ferret/model/utils.py index bbdf3b2..13a1c07 100644 --- a/ferret/model/utils.py +++ b/ferret/model/utils.py @@ -1,18 +1,32 @@ +import torch +from typing import Literal from transformers import AutoConfig +DEVICE: Literal["cpu", "cuda", "mps"] = None +if torch.cuda.is_available(): + DEVICE = "cuda" +elif torch.backends.mps.is_available(): + DEVICE = "mps" +else: + DEVICE = "cpu" + def auto_upgrade(config): cfg = AutoConfig.from_pretrained(config) - if 'llava' in config and 'llava' not in cfg.model_type: - assert cfg.model_type == 'llama' - print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") - print("You must upgrade the checkpoint to the new code base (this can be done automatically).") + if "llava" in config and "llava" not in cfg.model_type: + assert cfg.model_type == "llama" + print( + "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base." + ) + print( + "You must upgrade the checkpoint to the new code base (this can be done automatically)." + ) confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") if confirm.lower() in ["y", "yes"]: print("Upgrading checkpoint...") assert len(cfg.architectures) == 1 setattr(cfg.__class__, "model_type", "llava") - cfg.architectures[0] = 'FERRETLlamaForCausalLM' + cfg.architectures[0] = "FERRETLlamaForCausalLM" cfg.save_pretrained(config) print("Checkpoint upgraded.") else: diff --git a/ferret/serve/model_worker.py b/ferret/serve/model_worker.py index 7557890..637789b 100644 --- a/ferret/serve/model_worker.py +++ b/ferret/serve/model_worker.py @@ -29,6 +29,7 @@ from ferret.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN # from transformers import TextIteratorStreamer from threading import Thread +from ferret.model.utils import DEVICE GB = 1 << 30 @@ -177,7 +178,7 @@ def generate_stream(self, params): if region_masks is not None: assert self.add_region_feature - region_masks = [[torch.Tensor(region_mask_i).cuda().half() for region_mask_i in region_masks]] + region_masks = [[torch.Tensor(region_mask_i).to(DEVICE).half() for region_mask_i in region_masks]] image_args["region_masks"] = region_masks logger.info("Add region_masks to image_args.") else: @@ -211,15 +212,15 @@ def generate_stream(self, params): for i in range(max_new_tokens): if i == 0: out = model( - torch.as_tensor([input_ids]).cuda(), + torch.as_tensor([input_ids]).to(DEVICE), use_cache=True, **image_args) logits = out.logits past_key_values = out.past_key_values else: attention_mask = torch.ones( - 1, past_key_values[0][0].shape[-2] + 1, device="cuda") - out = model(input_ids=torch.as_tensor([[token]], device="cuda"), + 1, past_key_values[0][0].shape[-2] + 1, device=DEVICE) + out = model(input_ids=torch.as_tensor([[token]], device=DEVICE), use_cache=True, attention_mask=attention_mask, past_key_values=past_key_values,