Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 2 additions & 15 deletions src/utils/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum

from dotenv import load_dotenv
from langchain.chat_models import ChatAnthropic, ChatOpenAI
from langchain.chat_models import ChatAnthropic, ChatOpenAI, ChatLiteLLM
from langchain.chat_models.base import BaseChatModel
from langchain.llms import OpenAI
from langchain.schema import BaseMessage
Expand All @@ -20,20 +20,7 @@ def get_chat_model(name: ChatModelName, **kwargs) -> BaseChatModel:
del kwargs["model_name"]
if "model" in kwargs:
del kwargs["model"]

if name == ChatModelName.TURBO:
return ChatOpenAI(model_name=name.value, **kwargs)
elif name == ChatModelName.GPT4:
return ChatOpenAI(model_name=name.value, **kwargs)
elif name == ChatModelName.CLAUDE:
return ChatAnthropic(model=name.value, **kwargs)
elif name == ChatModelName.CLAUDE_INSTANT:
return ChatAnthropic(model=name.value, **kwargs)
elif name == ChatModelName.WINDOW:
return ChatWindowAI(model_name=name.value, **kwargs)
else:
raise ValueError(f"Invalid model name: {name}")

return ChatLiteLLM(model_name=name.value, **kwargs)

class ChatModel:
"""Wrapper around the ChatModel class."""
Expand Down