-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm_utils.py
More file actions
116 lines (99 loc) · 4.14 KB
/
llm_utils.py
File metadata and controls
116 lines (99 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import json
import torch
from pathlib import Path
from typing import Union, Any
from vllm import LLM, SamplingParams
from litellm import completion
from dotenv import load_dotenv
import os
def load_config(config_path: Union[str, Path]) -> dict:
    """Load configuration from a JSON file.

    Args:
        config_path: Path to the JSON file. Accepts str or pathlib.Path —
            callers in this module pass Path objects, so the old ``str``
            annotation was wrong.

    Returns:
        The parsed JSON document (a dict for the config files used here).

    Raises:
        json.JSONDecodeError: If the file is not valid JSON.
        OSError: If the file cannot be opened or read.
    """
    try:
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        print(f"Loaded config from {config_path}: {config}")
        return config
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON from {config_path}: {str(e)}")
        raise
    except Exception as e:
        # Log-and-reraise: callers decide whether a missing/unreadable
        # config is fatal (see get_model_config's per-file try/except).
        print(f"Error loading config from {config_path}: {str(e)}")
        raise
def get_model_config(model_name: str) -> dict:
    """Get model-specific configuration from JSON files.

    Builds the engine config in layers (later layer wins on key conflicts):
    1. ``model_configs/base.json`` defaults, with ``tensor_parallel_size``
       forced to the number of visible CUDA devices.
    2. A per-model config file whose filename stem (lowercased) appears as a
       substring of the lowercased model name.

    Search order / precedence for step 2: files in ``model_configs_local/``
    are scanned before ``model_configs/``. A match from the local directory
    is final (the loop breaks); a match from the base directory does NOT
    break, so a later base-dir match overwrites an earlier one — the last
    matching base-dir file wins when no local file matches.

    Args:
        model_name: Model identifier; matched case-insensitively against
            config filename stems.

    Returns:
        dict: base config merged with the per-model config, per-model keys
        taking precedence.
    """
    base_config_dir = Path(__file__).parent / "model_configs"
    local_config_dir = Path(__file__).parent / "model_configs_local"
    base_config = load_config(base_config_dir / "base.json")
    # Default to spreading the engine across every visible GPU.
    base_config["tensor_parallel_size"] = torch.cuda.device_count()
    model_config = {}
    model_name_lower = model_name.lower()
    # Check both directories for matching config files
    for config_path in list(local_config_dir.glob('*.json')) + list(base_config_dir.glob('*.json')):
        if config_path.stem.lower() in model_name_lower:
            try:
                model_config = load_config(config_path)
                print(f"Using configuration from {config_path}")
                # Only a local-directory match ends the search; base-dir
                # matches keep scanning and may be overwritten.
                if local_config_dir in config_path.parents:
                    break
            except Exception as e:
                # A broken per-model file is non-fatal: skip it and keep
                # looking for another match.
                print(f"Error loading config from {config_path}: {str(e)}")
                continue
    if not model_config:
        print("Using default configuration")
        model_config = {}
    # Merge configs with model config taking precedence
    return {**base_config, **model_config}
class LiteLLMWrapper:
    """Adapter that gives LiteLLM's `completion` a vLLM-like interface.

    Lets callers treat API-backed models and local vLLM engines uniformly:
    both expose ``get_default_sampling_params`` and ``chat``.
    """

    def __init__(self, model_name: str):
        # LiteLLM model identifier, passed straight through to completion().
        self.model = model_name
        # Baseline generation settings; callers always receive a copy.
        self.default_sampling_params = {'max_completion_tokens': 1024}

    def get_default_sampling_params(self) -> dict:
        """Return a fresh copy of the default sampling parameters."""
        return dict(self.default_sampling_params)

    def chat(self, messages: list, sampling_params: dict, **kwargs) -> Any:
        """Send one chat request via LiteLLM.

        Extra keyword arguments are accepted for signature parity with
        vLLM's ``LLM.chat`` but are ignored here.
        """
        result = completion(
            model=self.model,
            messages=messages,
            **sampling_params,
        )
        # Wrap in a list to mirror vLLM's batch-style return value.
        return [result]
def load_model(model_name: str, use_api: bool = False) -> Union[LLM, LiteLLMWrapper]:
    """Initialize vLLM engine or LiteLLM client with the specified model.

    Args:
        model_name: Name of the model to load.
        use_api: Whether to use API-based models via LiteLLM instead of a
            local vLLM engine.

    Returns:
        A ``LiteLLMWrapper`` when ``use_api`` is true, otherwise a vLLM
        ``LLM`` engine configured via :func:`get_model_config`.
    """
    if use_api:
        print(f"Initializing LiteLLM client for model: {model_name}")
        # API keys and related settings come from a .env file, if present.
        load_dotenv()  # Load environment variables from .env file
        return LiteLLMWrapper(model_name)

    print(f"Loading local model with vLLM: {model_name}")
    engine_kwargs = get_model_config(model_name)

    # Pull HuggingFace-specific overrides out of the engine config: nested
    # under model_kwargs.config, plus any top-level config_overrides key.
    overrides = engine_kwargs.pop('model_kwargs', {}).get('config', {})
    if 'config_overrides' in engine_kwargs:
        overrides.update(engine_kwargs.pop('config_overrides'))

    # Everything remaining in engine_kwargs goes straight to the engine.
    return LLM(
        model=model_name,
        generation_config="auto",
        hf_overrides=overrides if overrides else None,
        **engine_kwargs,
    )
def get_sampling_params(engine: Union[LLM, LiteLLMWrapper], max_tokens: int, temperature: float) -> Union[SamplingParams, dict]:
    """Build sampling parameters for either backend.

    For a vLLM engine this mutates and returns its default
    ``SamplingParams`` object; for a LiteLLM wrapper it returns a plain
    dict of keyword arguments for ``completion``. In both cases
    ``temperature`` is only applied when it is not ``None``.
    """
    if isinstance(engine, LLM):
        # vLLM path: attribute-style SamplingParams object.
        params = engine.get_default_sampling_params()
        params.max_tokens = max_tokens
        if temperature is not None:
            params.temperature = temperature
        return params

    # LiteLLMWrapper path: plain dict of completion() keyword arguments.
    options = engine.get_default_sampling_params()
    options['max_completion_tokens'] = max_tokens
    if temperature is not None:
        options['temperature'] = temperature
    return options