From 0d92b5070c05f6303edec0c92a1b2532d65d5ae8 Mon Sep 17 00:00:00 2001 From: jeanjerome Date: Mon, 8 Jan 2024 15:46:00 +0100 Subject: [PATCH] refacto(silicon): Refactor Ferret LLM for Apple Silicon --- .gitignore | 5 +++++ ferret/constants.py | 1 + ferret/model/builder.py | 4 +++- ferret/model/device_selector.py | 13 +++++++++++++ ferret/model/ferret_arch.py | 3 ++- ferret/serve/controller.py | 6 +++--- ferret/serve/gradio_web_server.py | 4 ++-- ferret/serve/model_worker.py | 17 +++++++++++------ ferret/utils.py | 4 ++-- pyproject.toml | 2 +- 10 files changed, 43 insertions(+), 16 deletions(-) create mode 100644 .gitignore create mode 100644 ferret/model/device_selector.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5e87969 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/delta/ +/model/ +/.idea/ +/ferret.egg-info/ +/serve_images/ diff --git a/ferret/constants.py b/ferret/constants.py index be8cf02..281d7df 100644 --- a/ferret/constants.py +++ b/ferret/constants.py @@ -1,5 +1,6 @@ CONTROLLER_HEART_BEAT_EXPIRATION = 30 WORKER_HEART_BEAT_INTERVAL = 15 +TIMEOUT = 300 LOGDIR = "." diff --git a/ferret/model/builder.py b/ferret/model/builder.py index f901f53..11a21d1 100644 --- a/ferret/model/builder.py +++ b/ferret/model/builder.py @@ -18,6 +18,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig import torch from ferret.model import * +from ferret.model.device_selector import get_device from ferret.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN DEFAULT_REGION_FEA_TOKEN = "" @@ -128,7 +129,8 @@ def load_from_hf(repo_id, filename, subfolder=None): else: vision_tower.load_model() - vision_tower.to(device='cuda', dtype=torch.float16) + device = get_device() + vision_tower.to(device=device, dtype=torch.float16) image_processor = vision_tower.image_processor if hasattr(model.config, "max_sequence_length"): diff --git a/ferret/model/device_selector.py b/ferret/model/device_selector.py new file mode 100644 index 0000000..a74cf6f --- /dev/null +++ b/ferret/model/device_selector.py @@ -0,0 +1,13 @@ +import torch + +def get_device(): + """ + Determines the best available device (CUDA, MPS, or CPU) + and returns it for later use with PyTorch. + """ + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + return torch.device("cpu") diff --git a/ferret/model/ferret_arch.py b/ferret/model/ferret_arch.py index a5bbc5f..c519ae5 100644 --- a/ferret/model/ferret_arch.py +++ b/ferret/model/ferret_arch.py @@ -659,7 +659,8 @@ def initialize_vision_tokenizer(self, model_args, tokenizer, add_region_feature= p.requires_grad = False if model_args.pretrain_mm_mlp_adapter: - mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') + device = get_device() + mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location=device) embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] if add_region_feature: num_new_tokens = num_new_tokens - num_region_fea_tokens diff --git a/ferret/serve/controller.py b/ferret/serve/controller.py index e01cc3a..3792417 100644 --- a/ferret/serve/controller.py +++ b/ferret/serve/controller.py @@ -18,7 +18,7 @@ import requests import uvicorn -from ferret.constants import CONTROLLER_HEART_BEAT_EXPIRATION +from ferret.constants import CONTROLLER_HEART_BEAT_EXPIRATION, TIMEOUT from ferret.utils import build_logger, server_error_msg @@ -87,7 +87,7 @@ def register_worker(self, worker_name: str, check_heart_beat: bool, def get_worker_status(self, worker_name: str): try: - r = requests.post(worker_name + "/worker_get_status", timeout=5) + r = requests.post(worker_name + "/worker_get_status", timeout=TIMEOUT) except requests.exceptions.RequestException as e: logger.error(f"Get status fails: {worker_name}, {e}") return None @@ -202,7 +202,7 @@ def worker_api_generate_stream(self, params): try: response = requests.post(worker_addr + "/worker_generate_stream", - json=params, stream=True, timeout=5) + json=params, stream=True, timeout=TIMEOUT) for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: yield chunk + b"\0" diff --git a/ferret/serve/gradio_web_server.py b/ferret/serve/gradio_web_server.py index 67576c8..3ad85ff 100644 --- a/ferret/serve/gradio_web_server.py +++ b/ferret/serve/gradio_web_server.py @@ -14,7 +14,7 @@ from ferret.conversation import (default_conversation, conv_templates, SeparatorStyle) -from ferret.constants import LOGDIR +from ferret.constants import LOGDIR, TIMEOUT from ferret.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg) import hashlib @@ -443,7 +443,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in try: # Stream output response = requests.post(worker_addr + "/worker_generate_stream", - headers=headers, json=pload, stream=True, timeout=10) + headers=headers, json=pload, stream=True, timeout=TIMEOUT) for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: data = json.loads(chunk.decode()) diff --git a/ferret/serve/model_worker.py b/ferret/serve/model_worker.py index 7557890..9c872fe 100644 --- a/ferret/serve/model_worker.py +++ b/ferret/serve/model_worker.py @@ -21,9 +21,10 @@ import uvicorn from functools import partial -from ferret.constants import WORKER_HEART_BEAT_INTERVAL +from ferret.constants import WORKER_HEART_BEAT_INTERVAL, TIMEOUT from ferret.utils import (build_logger, server_error_msg, pretty_print_semaphore) +from ferret.model.device_selector import get_device from ferret.model.builder import load_pretrained_model from ferret.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria from ferret.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN @@ -80,6 +81,10 @@ def __init__(self, controller_addr, worker_addr, self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( model_path, model_base, self.model_name, load_8bit, load_4bit) self.is_multimodal = 'llava' in self.model_name.lower() or 'ferret' in self.model_name.lower() + self.device = get_device() + + logger.info(f"Device is {self.device}") + logger.info(f"Model device is {self.model.device}.") if not no_register: self.register_to_controller() @@ -110,7 +115,7 @@ def send_heart_beat(self): try: ret = requests.post(url, json={ "worker_name": self.worker_addr, - "queue_length": self.get_queue_length()}, timeout=5) + "queue_length": self.get_queue_length()}, timeout=TIMEOUT) exist = ret.json()["exist"] break except requests.exceptions.RequestException as e: @@ -177,7 +182,7 @@ def generate_stream(self, params): if region_masks is not None: assert self.add_region_feature - region_masks = [[torch.Tensor(region_mask_i).cuda().half() for region_mask_i in region_masks]] + region_masks = [[torch.Tensor(region_mask_i).to(self.device).half() for region_mask_i in region_masks]] image_args["region_masks"] = region_masks logger.info("Add region_masks to image_args.") else: @@ -211,15 +216,15 @@ def generate_stream(self, params): for i in range(max_new_tokens): if i == 0: out = model( - torch.as_tensor([input_ids]).cuda(), + torch.as_tensor([input_ids]).to(self.device), use_cache=True, **image_args) logits = out.logits past_key_values = out.past_key_values else: attention_mask = torch.ones( - 1, past_key_values[0][0].shape[-2] + 1, device="cuda") - out = model(input_ids=torch.as_tensor([[token]], device="cuda"), + 1, past_key_values[0][0].shape[-2] + 1, device=self.device) + out = model(input_ids=torch.as_tensor([[token]], device=self.device), use_cache=True, attention_mask=attention_mask, past_key_values=past_key_values, diff --git a/ferret/utils.py b/ferret/utils.py index a681120..4cd3981 100644 --- a/ferret/utils.py +++ b/ferret/utils.py @@ -6,7 +6,7 @@ import requests -from ferret.constants import LOGDIR +from ferret.constants import LOGDIR, TIMEOUT server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." @@ -110,7 +110,7 @@ def violates_moderation(text): data = "{" + '"input": ' + f'"{text}"' + "}" data = data.encode("utf-8") try: - ret = requests.post(url, headers=headers, data=data, timeout=5) + ret = requests.post(url, headers=headers, data=data, timeout=TIMEOUT) flagged = ret.json()["results"][0]["flagged"] except requests.exceptions.RequestException as e: flagged = False diff --git a/pyproject.toml b/pyproject.toml index aa285b7..ff6e804 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "deepspeed==0.9.5", "peft==0.4.0", "transformers @ git+https://github.com/huggingface/transformers.git@cae78c46", - "accelerate==0.21.0", + "accelerate==0.25.0", "bitsandbytes==0.41.0", "scikit-learn==1.2.2", "sentencepiece==0.1.99",