diff --git a/src/utils/models.py b/src/utils/models.py
index cec4059..861c61e 100644
--- a/src/utils/models.py
+++ b/src/utils/models.py
@@ -1,7 +1,7 @@
 from enum import Enum
 
 from dotenv import load_dotenv
-from langchain.chat_models import ChatAnthropic, ChatOpenAI
+from langchain.chat_models import ChatAnthropic, ChatOpenAI, ChatLiteLLM
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms import OpenAI
 from langchain.schema import BaseMessage
@@ -20,20 +20,7 @@ def get_chat_model(name: ChatModelName, **kwargs) -> BaseChatModel:
         del kwargs["model_name"]
     if "model" in kwargs:
         del kwargs["model"]
-
-    if name == ChatModelName.TURBO:
-        return ChatOpenAI(model_name=name.value, **kwargs)
-    elif name == ChatModelName.GPT4:
-        return ChatOpenAI(model_name=name.value, **kwargs)
-    elif name == ChatModelName.CLAUDE:
-        return ChatAnthropic(model=name.value, **kwargs)
-    elif name == ChatModelName.CLAUDE_INSTANT:
-        return ChatAnthropic(model=name.value, **kwargs)
-    elif name == ChatModelName.WINDOW:
-        return ChatWindowAI(model_name=name.value, **kwargs)
-    else:
-        raise ValueError(f"Invalid model name: {name}")
-
+    return ChatLiteLLM(model_name=name.value, **kwargs)
 
 class ChatModel:
     """Wrapper around the ChatModel class."""
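
Not part of the patch: a minimal usage sketch of the consolidated helper. It assumes the ChatModelName values are provider model strings that LiteLLM can route on its own (they were previously passed straight to ChatOpenAI / ChatAnthropic), and that extra kwargs such as temperature are forwarded to ChatLiteLLM unchanged.

from langchain.schema import HumanMessage

# Hypothetical usage sketch: ChatLiteLLM is expected to dispatch the model
# string carried by name.value (e.g. a GPT-4 or Claude identifier) to the
# matching provider, so callers of get_chat_model are unchanged.
chat = get_chat_model(ChatModelName.GPT4, temperature=0)
reply = chat([HumanMessage(content="Reply with a single word: ready?")])
print(reply.content)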