Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions server/services/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from services.google_ai_client import GoogleAIClient
from services.deepseek_client import DeepSeekClient
from services.dashscope_client import DashScopeClient
from services.runpod_client import RunPodClient

from utilities.constants.LLM_enums import LLMType, ModelType
from utilities.constants.response_messages import ERROR_UNSUPPORTED_CLIENT_TYPE
Expand All @@ -26,5 +27,7 @@ def get_client(type: LLMType, model: Optional[ModelType] = None, temperature: Op
return DeepSeekClient(model=model, temperature=temperature, max_tokens=max_tokens)
elif type == LLMType.DASHSCOPE:
return DashScopeClient(model=model, temperature=temperature, max_tokens=max_tokens)
elif type == LLMType.RUNPOD:
return RunPodClient(model=model, temperature=temperature, max_tokens=max_tokens)
else:
raise ValueError(ERROR_UNSUPPORTED_CLIENT_TYPE)
52 changes: 52 additions & 0 deletions server/services/runpod_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from services.base_client import Client
from utilities.constants.LLM_enums import LLMType, ModelType, RUNPOD_ENPOINTS
from openai import OpenAI
from typing import Optional
from utilities.config import RUNPOD_API_KEY
from utilities.utility_functions import format_chat
from utilities.constants.response_messages import (
ERROR_API_FAILURE
)
import re

class RunPodClient(Client):
    """Client for models served via RunPod's OpenAI-compatible serverless endpoints.

    Each supported model maps to its own RunPod endpoint id (see RUNPOD_ENPOINTS);
    requests are sent through the standard OpenAI SDK pointed at that endpoint.
    """

    def __init__(self, model: Optional[ModelType], max_tokens: Optional[int], temperature: Optional[float]):
        endpoint_id = RUNPOD_ENPOINTS.get(model)
        if endpoint_id is None:
            # Fail fast: without an endpoint id the base_url would embed the
            # literal string "None" and every request would fail opaquely later.
            raise ValueError(f"No RunPod endpoint configured for model: {model}")
        client = OpenAI(api_key=RUNPOD_API_KEY, base_url=f'https://api.runpod.ai/v2/{endpoint_id}/openai/v1')
        super().__init__(model=model, temperature=temperature, max_tokens=max_tokens, client=client)

    @staticmethod
    def _extract_sql(content: str) -> str:
        """Return the SQL inside a ```sql fenced block, or the raw content if no fence is present."""
        match = re.search(r"```sql\n(.*?)```", content, re.DOTALL)
        return match.group(1).strip() if match else content

    def _chat_completion(self, messages) -> str:
        """Send *messages* to the configured model and return the (SQL-extracted) reply.

        Raises:
            RuntimeError: wrapping any API/SDK failure, formatted with ERROR_API_FAILURE.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model.value,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                stream=False
            )
            return self._extract_sql(response.choices[0].message.content)
        except Exception as e:
            raise RuntimeError(ERROR_API_FAILURE.format(llm_type=LLMType.RUNPOD.value, error=str(e)))

    def execute_prompt(self, prompt: str) -> str:
        """Execute a single-turn prompt and return the model's reply (SQL fence stripped if present)."""
        return self._chat_completion([{"role": "user", "content": prompt}])

    def execute_chat(self, chat):
        """Execute a multi-turn chat; roles are normalized to the OpenAI schema before sending."""
        messages = format_chat(chat, {'system':'system', 'user':'user', 'model':'assistant', 'content':'content'})
        return self._chat_completion(messages)
3 changes: 3 additions & 0 deletions server/utilities/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
GOOGLE_AI_API_KEY = os.getenv("GOOGLE_AI_API_KEY")
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
RUNPOD_API_KEY = os.getenv("RUNPOD_API_KEY")
ALL_GOOGLE_KEYS = os.getenv('ALL_GOOGLE_API_KEYS')
ALL_GOOGLE_KEYS = ALL_GOOGLE_KEYS.split(" ")

Expand All @@ -32,6 +33,8 @@
raise RuntimeError(ERROR_API_KEY_MISSING.format(api_key="DEEPSEEK_API_KEY"))
if not DASHSCOPE_API_KEY:
raise RuntimeError(ERROR_API_KEY_MISSING.format(api_key="DASHSCOPE_API_KEY"))
if not RUNPOD_API_KEY:
raise RuntimeError(ERROR_API_KEY_MISSING.format(api_key="RUNPOD_API_KEY"))

class ChromadbClient:
CHROMADB_CLIENT=chromadb.PersistentClient() # Default path is "./chroma"
15 changes: 15 additions & 0 deletions server/utilities/constants/LLM_enums.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from enum import Enum
import os
from dotenv import load_dotenv

load_dotenv()

class LLMType(Enum):
    """Identifiers for the supported LLM provider backends."""
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GOOGLE_AI = "google_ai"
    DEEPSEEK = "deepseek"
    DASHSCOPE = "dashscope"
    # Models hosted on RunPod serverless endpoints (OpenAI-compatible API).
    RUNPOD = "runpod"

class ModelType(Enum):
# OpenAI Models
Expand Down Expand Up @@ -55,6 +60,8 @@ class ModelType(Enum):
DASHSCOPE_QWEN1_5_14B_CHAT = 'qwen1.5-14b-chat'
DASHSCOPE_QWEN1_5_7B_CHAT = 'qwen1.5-7b-chat'

RUNPOD_LLAMA_70B = 'meta-llama/Llama-3.3-70B-Instruct'


VALID_LLM_MODELS = {
LLMType.OPENAI: [
Expand Down Expand Up @@ -100,9 +107,17 @@ class ModelType(Enum):
ModelType.DASHSCOPE_QWEN1_5_32B_CHAT,
ModelType.DASHSCOPE_QWEN1_5_14B_CHAT,
ModelType.DASHSCOPE_QWEN1_5_7B_CHAT
],
LLMType.RUNPOD: [
ModelType.RUNPOD_LLAMA_70B
]
}

# Maps each RunPod-hosted model to its serverless endpoint id, read from the
# environment. NOTE: a value is None when the corresponding env var is unset.
RUNPOD_ENDPOINTS = {
    ModelType.RUNPOD_LLAMA_70B: os.getenv("RUNPOD_LLAMA_70B_ENDPOINT")
}

# Backward-compatible alias: the original name was misspelled ("ENPOINTS") and
# is still referenced by existing callers; prefer RUNPOD_ENDPOINTS in new code.
RUNPOD_ENPOINTS = RUNPOD_ENDPOINTS


MODEL_COST = {
# Pricing per 1K tokens
ModelType.OPENAI_GPT3_5_TURBO_0125: {
Expand Down