Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
/delta/
/model/
/.idea/
/ferret.egg-info/
/serve_images/
1 change: 1 addition & 0 deletions ferret/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
TIMEOUT = 300

LOGDIR = "."

Expand Down
4 changes: 3 additions & 1 deletion ferret/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from ferret.model import *
from ferret.model.device_selector import get_device
from ferret.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
DEFAULT_REGION_FEA_TOKEN = "<region_fea>"

Expand Down Expand Up @@ -128,7 +129,8 @@ def load_from_hf(repo_id, filename, subfolder=None):
else:
vision_tower.load_model()

vision_tower.to(device='cuda', dtype=torch.float16)
device = get_device()
vision_tower.to(device=device, dtype=torch.float16)
image_processor = vision_tower.image_processor

if hasattr(model.config, "max_sequence_length"):
Expand Down
13 changes: 13 additions & 0 deletions ferret/model/device_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch

def get_device():
    """
    Return the best available ``torch.device``.

    Preference order: CUDA, then Apple-Silicon MPS, then CPU.

    Returns:
        torch.device: the selected device, suitable for ``.to(device=...)``
        and ``torch.load(..., map_location=...)`` calls.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps only exists in PyTorch >= 1.12; probing it
    # unguarded raises AttributeError on older builds, so check first.
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
3 changes: 2 additions & 1 deletion ferret/model/ferret_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,8 @@ def initialize_vision_tokenizer(self, model_args, tokenizer, add_region_feature=
p.requires_grad = False

if model_args.pretrain_mm_mlp_adapter:
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
device = get_device()
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location=device)
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
if add_region_feature:
num_new_tokens = num_new_tokens - num_region_fea_tokens
Expand Down
6 changes: 3 additions & 3 deletions ferret/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import requests
import uvicorn

from ferret.constants import CONTROLLER_HEART_BEAT_EXPIRATION
from ferret.constants import CONTROLLER_HEART_BEAT_EXPIRATION, TIMEOUT
from ferret.utils import build_logger, server_error_msg


Expand Down Expand Up @@ -87,7 +87,7 @@ def register_worker(self, worker_name: str, check_heart_beat: bool,

def get_worker_status(self, worker_name: str):
try:
r = requests.post(worker_name + "/worker_get_status", timeout=5)
r = requests.post(worker_name + "/worker_get_status", timeout=TIMEOUT)
except requests.exceptions.RequestException as e:
logger.error(f"Get status fails: {worker_name}, {e}")
return None
Expand Down Expand Up @@ -202,7 +202,7 @@ def worker_api_generate_stream(self, params):

try:
response = requests.post(worker_addr + "/worker_generate_stream",
json=params, stream=True, timeout=5)
json=params, stream=True, timeout=TIMEOUT)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
yield chunk + b"\0"
Expand Down
4 changes: 2 additions & 2 deletions ferret/serve/gradio_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ferret.conversation import (default_conversation, conv_templates,
SeparatorStyle)
from ferret.constants import LOGDIR
from ferret.constants import LOGDIR, TIMEOUT
from ferret.utils import (build_logger, server_error_msg,
violates_moderation, moderation_msg)
import hashlib
Expand Down Expand Up @@ -443,7 +443,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
try:
# Stream output
response = requests.post(worker_addr + "/worker_generate_stream",
headers=headers, json=pload, stream=True, timeout=10)
headers=headers, json=pload, stream=True, timeout=TIMEOUT)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
Expand Down
17 changes: 11 additions & 6 deletions ferret/serve/model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
import uvicorn
from functools import partial

from ferret.constants import WORKER_HEART_BEAT_INTERVAL
from ferret.constants import WORKER_HEART_BEAT_INTERVAL, TIMEOUT
from ferret.utils import (build_logger, server_error_msg,
pretty_print_semaphore)
from ferret.model.device_selector import get_device
from ferret.model.builder import load_pretrained_model
from ferret.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
from ferret.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
Expand Down Expand Up @@ -80,6 +81,10 @@ def __init__(self, controller_addr, worker_addr,
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
model_path, model_base, self.model_name, load_8bit, load_4bit)
self.is_multimodal = 'llava' in self.model_name.lower() or 'ferret' in self.model_name.lower()
self.device = get_device()

logger.info(f"Device is {self.device}")
logger.info(f"Model device is {self.model.device}.")

if not no_register:
self.register_to_controller()
Expand Down Expand Up @@ -110,7 +115,7 @@ def send_heart_beat(self):
try:
ret = requests.post(url, json={
"worker_name": self.worker_addr,
"queue_length": self.get_queue_length()}, timeout=5)
"queue_length": self.get_queue_length()}, timeout=TIMEOUT)
exist = ret.json()["exist"]
break
except requests.exceptions.RequestException as e:
Expand Down Expand Up @@ -177,7 +182,7 @@ def generate_stream(self, params):

if region_masks is not None:
assert self.add_region_feature
region_masks = [[torch.Tensor(region_mask_i).cuda().half() for region_mask_i in region_masks]]
region_masks = [[torch.Tensor(region_mask_i).to(self.device).half() for region_mask_i in region_masks]]
image_args["region_masks"] = region_masks
logger.info("Add region_masks to image_args.")
else:
Expand Down Expand Up @@ -211,15 +216,15 @@ def generate_stream(self, params):
for i in range(max_new_tokens):
if i == 0:
out = model(
torch.as_tensor([input_ids]).cuda(),
torch.as_tensor([input_ids]).to(self.device),
use_cache=True,
**image_args)
logits = out.logits
past_key_values = out.past_key_values
else:
attention_mask = torch.ones(
1, past_key_values[0][0].shape[-2] + 1, device="cuda")
out = model(input_ids=torch.as_tensor([[token]], device="cuda"),
1, past_key_values[0][0].shape[-2] + 1, device=self.device)
out = model(input_ids=torch.as_tensor([[token]], device=self.device),
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values,
Expand Down
4 changes: 2 additions & 2 deletions ferret/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import requests

from ferret.constants import LOGDIR
from ferret.constants import LOGDIR, TIMEOUT

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
Expand Down Expand Up @@ -110,7 +110,7 @@ def violates_moderation(text):
data = "{" + '"input": ' + f'"{text}"' + "}"
data = data.encode("utf-8")
try:
ret = requests.post(url, headers=headers, data=data, timeout=5)
ret = requests.post(url, headers=headers, data=data, timeout=TIMEOUT)
flagged = ret.json()["results"][0]["flagged"]
except requests.exceptions.RequestException as e:
flagged = False
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"deepspeed==0.9.5",
"peft==0.4.0",
"transformers @ git+https://github.com/huggingface/transformers.git@cae78c46",
"accelerate==0.21.0",
"accelerate==0.25.0",
"bitsandbytes==0.41.0",
"scikit-learn==1.2.2",
"sentencepiece==0.1.99",
Expand Down