diff --git a/main.py b/main.py
index 63b4a7b..01121ad 100644
--- a/main.py
+++ b/main.py
@@ -10,17 +10,14 @@
 from src.core.services.embedding import EmbeddingService
 from src.core.services.db_service import DatabaseService
 from src.config.settings import settings
-from openai import AsyncOpenAI
+from src.config.model_provider import ModelProvider
 from src.utils.logging import logger
 
 async def process_documents(base_dir: str):
     """Process markdown documents to embeddings"""
-    openai_client = AsyncOpenAI(
-        api_key=settings.OPENAI_API_KEY,
-        base_url=settings.OPENAI_API_BASE
-    )
+    model_provider = ModelProvider()
     db_service = DatabaseService()
-    embedding_service = EmbeddingService(openai_client)
+    embedding_service = EmbeddingService(model_provider.embed_model)
     processor = DocumentProcessor(db_service, embedding_service)
     await processor.process_directory(base_dir)
 
@@ -50,13 +47,10 @@ async def check_updates(raw_dir: str, markdown_dir: str):
     Returns:
         Added, modified, removed files
     """
-    openai_client = AsyncOpenAI(
-        api_key=settings.OPENAI_API_KEY,
-        base_url=settings.OPENAI_API_BASE
-    )
+    model_provider = ModelProvider()
     db_service = DatabaseService()
-    embedding_service = EmbeddingService(openai_client)
+    embedding_service = EmbeddingService(model_provider.embed_model)
     document_processor = DocumentProcessor(db_service, embedding_service)
     markdown_converter = MarkdownConverter()
 
diff --git a/src/api/routes/chat.py b/src/api/routes/chat.py
index 9e48469..52aeacc 100644
--- a/src/api/routes/chat.py
+++ b/src/api/routes/chat.py
@@ -1,6 +1,6 @@
 from fastapi import APIRouter, Depends, HTTPException
 from fastapi.responses import StreamingResponse
-from openai import AsyncOpenAI
+from src.config.model_provider import ModelProvider
 from src.core.services.db_service import DatabaseService
 from src.api.models.chat import ChatRequest, ChatResponse
 from src.api.dependencies.auth import verify_token
@@ -13,12 +13,9 @@
 
 # Create dependency for services
 async def get_services():
-    openai_client = AsyncOpenAI(
-        api_key=settings.OPENAI_API_KEY,
-        base_url=settings.OPENAI_API_BASE
-    )
+    model_provider = ModelProvider()
     db_service = DatabaseService()
-    embedding_service = EmbeddingService(openai_client)
-    chat_service = ChatService(openai_client, db_service, embedding_service)
+    embedding_service = EmbeddingService(model_provider.embed_model)
+    chat_service = ChatService(model_provider.llm, db_service, embedding_service)
     return chat_service
 
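For orientation, here is a minimal sketch of how a route might consume the reworked `get_services` dependency. The endpoint path, the `query` field on `ChatRequest`, and the keyword usage of `generate_response` are illustrative assumptions, not part of this diff:

```python
# Hypothetical endpoint wired to the provider-agnostic ChatService.
from fastapi import APIRouter, Depends

from src.api.models.chat import ChatRequest
from src.core.services.chat_service import ChatService

router = APIRouter()

@router.post("/chat")  # illustrative path
async def chat(
    request: ChatRequest,
    chat_service: ChatService = Depends(get_services),  # defined in the hunk above
):
    # With stream=False the service returns the full completion string
    # (see the chat_service.py hunk later in this patch).
    answer = await chat_service.generate_response(
        query=request.query,  # assumed field name on ChatRequest
        context="",
        stream=False,
    )
    return {"answer": answer}
```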
diff --git a/src/config/model_provider.py b/src/config/model_provider.py
new file mode 100644
index 0000000..5f3d05d
--- /dev/null
+++ b/src/config/model_provider.py
@@ -0,0 +1,225 @@
+from typing import Dict
+from src.config.settings import settings
+from src.utils.logging import logger
+
+
+class ModelProvider:
+    def __init__(self):
+        model_provider = settings.MODEL_PROVIDER
+        logger.info(f"Model provider: {model_provider}")
+        logger.info(f"Embedding model: {settings.EMBEDDING_MODEL}")
+        logger.info(f"LLM: {settings.MODEL}")
+        match model_provider:
+            case "openai":
+                self.init_openai()
+            case "groq":
+                self.init_groq()
+            case "ollama":
+                self.init_ollama()
+            case "anthropic":
+                self.init_anthropic()
+            case "gemini":
+                self.init_gemini()
+            case "mistral":
+                self.init_mistral()
+            case "azure-openai":
+                self.init_azure_openai()
+            case "huggingface":
+                self.init_huggingface()
+            case _:
+                raise ValueError(f"Invalid model provider: {model_provider}")
+
+    def init_ollama(self):
+        try:
+            from llama_index.embeddings.ollama import OllamaEmbedding
+            from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
+        except ImportError:
+            raise ImportError(
+                "Ollama support is not installed. Please install it with `poetry add llama-index-llms-ollama` and `poetry add llama-index-embeddings-ollama`"
+            )
+
+        base_url = settings.OLLAMA_BASE_URL or "http://127.0.0.1:11434"
+        request_timeout = float(
+            settings.OLLAMA_REQUEST_TIMEOUT or DEFAULT_REQUEST_TIMEOUT
+        )
+        self.embed_model = OllamaEmbedding(
+            base_url=base_url,
+            model_name=settings.EMBEDDING_MODEL,
+        )
+        self.llm = Ollama(
+            base_url=base_url, model=settings.MODEL, request_timeout=request_timeout
+        )
+
+    def init_openai(self):
+        from llama_index.core.constants import DEFAULT_TEMPERATURE
+        from llama_index.embeddings.openai import OpenAIEmbedding
+        from llama_index.llms.openai import OpenAI
+
+        max_tokens = settings.LLM_MAX_TOKENS
+        model_name = settings.MODEL or "gpt-4o"
+        self.llm = OpenAI(
+            model=model_name,
+            temperature=float(settings.LLM_TEMPERATURE or DEFAULT_TEMPERATURE),
+            max_tokens=int(max_tokens) if max_tokens is not None else None,
+        )
+
+        dimensions = settings.EMBEDDING_DIM
+        self.embed_model = OpenAIEmbedding(
+            model=settings.EMBEDDING_MODEL or "text-embedding-3-small",
+            dimensions=int(dimensions) if dimensions is not None else None,
+        )
+
+    def init_azure_openai(self):
+        from llama_index.core.constants import DEFAULT_TEMPERATURE
+
+        try:
+            from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
+            from llama_index.llms.azure_openai import AzureOpenAI
+        except ImportError:
+            raise ImportError(
+                "Azure OpenAI support is not installed. Please install it with `poetry add llama-index-llms-azure-openai` and `poetry add llama-index-embeddings-azure-openai`"
+            )
+
+        llm_deployment = settings.AZURE_OPENAI_LLM_DEPLOYMENT
+        embedding_deployment = settings.AZURE_OPENAI_EMBEDDING_DEPLOYMENT
+        max_tokens = settings.LLM_MAX_TOKENS
+        temperature = settings.LLM_TEMPERATURE or DEFAULT_TEMPERATURE
+        dimensions = settings.EMBEDDING_DIM
+
+        azure_config = {
+            "api_key": settings.AZURE_OPENAI_API_KEY,
+            "azure_endpoint": settings.AZURE_OPENAI_ENDPOINT,
+            "api_version": settings.AZURE_OPENAI_API_VERSION or settings.OPENAI_API_VERSION,
+        }
+
+        self.llm = AzureOpenAI(
+            model=settings.MODEL,
+            max_tokens=int(max_tokens) if max_tokens is not None else None,
+            temperature=float(temperature),
+            deployment_name=llm_deployment,
+            **azure_config,
+        )
+
+        self.embed_model = AzureOpenAIEmbedding(
+            model=settings.EMBEDDING_MODEL,
+            dimensions=int(dimensions) if dimensions is not None else None,
+            deployment_name=embedding_deployment,
+            **azure_config,
+        )
+
+    def init_fastembed(self):
+        try:
+            from llama_index.embeddings.fastembed import FastEmbedEmbedding
+        except ImportError:
+            raise ImportError(
+                "FastEmbed support is not installed. Please install it with `poetry add llama-index-embeddings-fastembed`"
+            )
+
+        embed_model_map: Dict[str, str] = {
+            # Small and multilingual
+            "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
+            # Large and multilingual
+            "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+        }
+
+        embedding_model = settings.EMBEDDING_MODEL
+        if embedding_model is None:
+            raise ValueError("EMBEDDING_MODEL environment variable is not set")
+
+        # This will download the model automatically if it is not already downloaded.
+        # Fall back to the raw name so unmapped FastEmbed models still work.
+        self.embed_model = FastEmbedEmbedding(
+            model_name=embed_model_map.get(embedding_model, embedding_model)
+        )
+
+    def init_huggingface_embedding(self):
+        try:
+            from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+        except ImportError:
+            raise ImportError(
+                "Hugging Face support is not installed. Please install it with `poetry add llama-index-embeddings-huggingface`"
+            )
+
+        embedding_model = settings.EMBEDDING_MODEL or "all-MiniLM-L6-v2"
+        backend = settings.EMBEDDING_BACKEND or "onnx"  # "torch", "onnx", or "openvino"
+        trust_remote_code = (
+            (settings.EMBEDDING_TRUST_REMOTE_CODE or "false").lower() == "true"
+        )
+
+        self.embed_model = HuggingFaceEmbedding(
+            model_name=embedding_model,
+            trust_remote_code=trust_remote_code,
+            backend=backend,
+        )
+
+    def init_huggingface(self):
+        try:
+            from llama_index.llms.huggingface import HuggingFaceLLM
+        except ImportError:
+            raise ImportError(
+                "Hugging Face support is not installed. Please install it with `poetry add llama-index-llms-huggingface` and `poetry add llama-index-embeddings-huggingface`"
+            )
+
+        self.llm = HuggingFaceLLM(
+            model_name=settings.MODEL,
+            tokenizer_name=settings.MODEL,
+        )
+        self.init_huggingface_embedding()
+
+    def init_groq(self):
+        try:
+            from llama_index.llms.groq import Groq
+        except ImportError:
+            raise ImportError(
+                "Groq support is not installed. Please install it with `poetry add llama-index-llms-groq`"
+            )
+
+        self.llm = Groq(model=settings.MODEL)
+        # Groq does not provide embeddings, so we use FastEmbed instead
+        self.init_fastembed()
+
+    def init_anthropic(self):
+        try:
+            from llama_index.llms.anthropic import Anthropic
+        except ImportError:
+            raise ImportError(
+                "Anthropic support is not installed. Please install it with `poetry add llama-index-llms-anthropic`"
+            )
+
+        model_map: Dict[str, str] = {
+            "claude-3-opus": "claude-3-opus-20240229",
+            "claude-3-sonnet": "claude-3-sonnet-20240229",
+            "claude-3-haiku": "claude-3-haiku-20240307",
+            "claude-2.1": "claude-2.1",
+            "claude-instant-1.2": "claude-instant-1.2",
+        }
+
+        # Fall back to the raw name so fully qualified model IDs also work.
+        self.llm = Anthropic(model=model_map.get(settings.MODEL, settings.MODEL))
+        # Anthropic does not provide embeddings, so we use FastEmbed instead
+        self.init_fastembed()
+
+    def init_gemini(self):
+        try:
+            from llama_index.embeddings.gemini import GeminiEmbedding
+            from llama_index.llms.gemini import Gemini
+        except ImportError:
+            raise ImportError(
+                "Gemini support is not installed. Please install it with `poetry add llama-index-llms-gemini` and `poetry add llama-index-embeddings-gemini`"
+            )
+
+        model_name = f"models/{settings.MODEL}"
+        embed_model_name = f"models/{settings.EMBEDDING_MODEL}"
+
+        self.llm = Gemini(model=model_name)
+        self.embed_model = GeminiEmbedding(model_name=embed_model_name)
+
+    def init_mistral(self):
+        try:
+            from llama_index.embeddings.mistralai import MistralAIEmbedding
+            from llama_index.llms.mistralai import MistralAI
+        except ImportError:
+            raise ImportError(
+                "Mistral support is not installed. Please install it with `poetry add llama-index-llms-mistralai` and `poetry add llama-index-embeddings-mistralai`"
+            )
+
+        self.llm = MistralAI(model=settings.MODEL)
+        self.embed_model = MistralAIEmbedding(model_name=settings.EMBEDDING_MODEL)
\ No newline at end of file
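A quick way to sanity-check this wiring is a small script like the one below. It is a sketch, not part of the patch: it assumes `MODEL_PROVIDER=ollama` with a reachable local Ollama server, and relies only on the public `aget_text_embedding`/`achat` methods that llama-index exposes on every embedding model and LLM:

```python
# Hypothetical smoke test for ModelProvider; any configured provider
# works the same way since .llm and .embed_model share common interfaces.
import asyncio

from llama_index.core.llms import ChatMessage, MessageRole
from src.config.model_provider import ModelProvider

async def main() -> None:
    provider = ModelProvider()  # dispatches on settings.MODEL_PROVIDER

    vector = await provider.embed_model.aget_text_embedding("hello world")
    print(f"embedding dimension: {len(vector)}")

    reply = await provider.llm.achat(
        [ChatMessage(role=MessageRole.USER, content="Say hi in one word.")]
    )
    print(reply.message.content)

asyncio.run(main())
```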
diff --git a/src/config/settings.py b/src/config/settings.py
index d1419f3..6d44d21 100644
--- a/src/config/settings.py
+++ b/src/config/settings.py
@@ -7,11 +7,42 @@ class Settings(BaseSettings):
     API_VERSION: str = "0.0.1"
     API_TITLE: str = "Odoo Expert API"
     API_DESCRIPTION: str = "API for querying Odoo documentation with RAG-powered responses"
-    
+
+    # The provider for the AI models to use.
+    MODEL_PROVIDER: str = "openai"
+
+    # Ollama Settings (defaults so non-Ollama deployments validate cleanly)
+    OLLAMA_BASE_URL: str = "http://127.0.0.1:11434"
+    OLLAMA_REQUEST_TIMEOUT: int = 120
+
+    # The name of the LLM model to use.
+    MODEL: str = "gpt-4o"
+
+    # Name of the embedding model to use.
+    EMBEDDING_MODEL: str = "text-embedding-3-small"
+
+    # Dimension of the embedding model to use.
+    EMBEDDING_DIM: int = 1024
+
+    # Optional generation settings read by ModelProvider.
+    LLM_TEMPERATURE: float | None = None
+    LLM_MAX_TOKENS: int | None = None
+
+    # Hugging Face embedding settings.
+    EMBEDDING_BACKEND: str | None = None
+    EMBEDDING_TRUST_REMOTE_CODE: str | None = None
+
+    # Azure OpenAI Settings
+    AZURE_OPENAI_API_KEY: str | None = None
+    AZURE_OPENAI_ENDPOINT: str | None = None
+    AZURE_OPENAI_API_VERSION: str | None = None
+    AZURE_OPENAI_LLM_DEPLOYMENT: str | None = None
+    AZURE_OPENAI_EMBEDDING_DEPLOYMENT: str | None = None
+    OPENAI_API_VERSION: str | None = None
+
     # OpenAI Settings
-    OPENAI_API_KEY: str
-    OPENAI_API_BASE: str
-    LLM_MODEL: str = "gpt-4o"
+    OPENAI_API_KEY: str = ""
+    OPENAI_API_BASE: str = ""
 
     # PostgreSQL Settings
     POSTGRES_USER: str = "postgres"
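Because `Settings` is a pydantic `BaseSettings` subclass, switching providers is purely an environment change. An illustrative example (the model names are assumptions, not defaults from this patch):

```python
# Illustrative: configure the app for Ollama entirely via environment
# variables, set before the settings module is imported.
import os

os.environ["MODEL_PROVIDER"] = "ollama"
os.environ["MODEL"] = "llama3.1"                    # any model pulled into Ollama
os.environ["EMBEDDING_MODEL"] = "nomic-embed-text"  # illustrative embedding model
os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434"

from src.config.settings import Settings

settings = Settings()
assert settings.MODEL_PROVIDER == "ollama"
```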
diff --git a/src/core/services/chat_service.py b/src/core/services/chat_service.py
index b0ca780..619ee0d 100644
--- a/src/core/services/chat_service.py
+++ b/src/core/services/chat_service.py
@@ -1,5 +1,5 @@
 from typing import List, Dict, Optional, Tuple
-from openai import AsyncOpenAI
+from llama_index.core.llms import LLM, ChatMessage, MessageRole
 from src.core.services.embedding import EmbeddingService
 from src.core.services.db_service import DatabaseService
 from src.config.settings import settings
@@ -8,11 +8,11 @@
 class ChatService:
     def __init__(
         self,
-        openai_client: AsyncOpenAI,
+        llm: LLM,
         db_service: DatabaseService,
         embedding_service: EmbeddingService
     ):
-        self.openai_client = openai_client
+        self.llm = llm
         self.db_service = db_service
         self.embedding_service = embedding_service
 
@@ -63,36 +63,29 @@ async def generate_response(
     ):
         """Generate AI response based on query and context."""
         try:
-            messages = [
-                {
-                    "role": "system",
-                    "content": settings.SYSTEM_PROMPT
-                }
-            ]
+            messages = [ChatMessage(
+                role=MessageRole.SYSTEM,
+                content=settings.SYSTEM_PROMPT,
+            )]
 
             if conversation_history:
                 history_text = "\n".join([
                     f"User: {msg['user']}\nAssistant: {msg['assistant']}"
                     for msg in conversation_history[-3:]
                 ])
-                messages.append({
-                    "role": "user",
-                    "content": f"Previous conversation:\n{history_text}"
-                })
-
-            messages.append({
-                "role": "user",
-                "content": f"Question: {query}\n\nRelevant documentation:\n{context}"
-            })
-
-            response = await self.openai_client.chat.completions.create(
-                model=settings.LLM_MODEL,
-                messages=messages,
-                stream=stream
-            )
+                messages.append(ChatMessage(
+                    role=MessageRole.USER,
+                    content=f"Previous conversation:\n{history_text}",
+                ))
+
+            messages.append(ChatMessage(
+                role=MessageRole.USER,
+                content=f"Question: {query}\n\nRelevant documentation:\n{context}",
+            ))
 
             if stream:
-                return response
-            return response.choices[0].message.content
+                return await self.llm.astream_chat(messages)
+            response = await self.llm.achat(messages)
+            return response.message.content
         except Exception as e:
             logger.error(f"Error generating response: {e}")
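With `stream=True` the method now returns a llama-index `ChatResponseAsyncGen` rather than an OpenAI stream, so callers read `chunk.delta` instead of `chunk.choices[0].delta.content`. A minimal consumer sketch (the service construction is elided and the keyword names are taken from the hunk above):

```python
# Sketch: consuming the streaming return shape of generate_response.
async def print_answer(chat_service, query: str, context: str) -> None:
    # An async generator of ChatResponse chunks; each chunk carries the
    # newest tokens in .delta and the accumulated text in .message.content.
    stream = await chat_service.generate_response(
        query=query, context=context, stream=True
    )
    async for chunk in stream:
        print(chunk.delta or "", end="", flush=True)
    print()
```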
diff --git a/src/core/services/embedding.py b/src/core/services/embedding.py
index f982c3e..2e6a141 100644
--- a/src/core/services/embedding.py
+++ b/src/core/services/embedding.py
@@ -1,10 +1,10 @@
 from typing import List
-from openai import AsyncOpenAI
+from llama_index.core.embeddings import BaseEmbedding
 from src.utils.logging import logger
 
 class EmbeddingService:
-    def __init__(self, client: AsyncOpenAI):
-        self.client = client
+    def __init__(self, embed_model: BaseEmbedding):
+        self.embed_model = embed_model
 
     async def get_embedding(self, text: str) -> List[float]:
         try:
@@ -12,11 +12,7 @@ async def get_embedding(self, text: str) -> List[float]:
             if len(text) > 8000:
                 text = text[:8000] + "..."
 
-            response = await self.client.embeddings.create(
-                model="text-embedding-3-small",
-                input=text
-            )
-            return response.data[0].embedding
+            return await self.embed_model.aget_text_embedding(text)
         except Exception as e:
             logger.error(f"Error getting embedding: {e}")
             raise
\ No newline at end of file
diff --git a/src/ui/streamlit_app.py b/src/ui/streamlit_app.py
index c76ca47..c6bb2d5 100644
--- a/src/ui/streamlit_app.py
+++ b/src/ui/streamlit_app.py
@@ -1,5 +1,6 @@
 import sys
 from pathlib import Path
+from typing import Dict
 
 # Add project root to Python path
 project_root = Path(__file__).parent.parent.parent
@@ -11,24 +12,22 @@
 from src.core.services.chat_service import ChatService
 from src.core.services.embedding import EmbeddingService
 from src.config.settings import settings
+from src.config.model_provider import ModelProvider
 from src.utils.logging import logger
-from openai import AsyncOpenAI
 from src.core.services.db_service import DatabaseService
 
 class StreamlitUI:
     def __init__(self):
-        self.openai_client = AsyncOpenAI(
-            api_key=settings.OPENAI_API_KEY,
-            base_url=settings.OPENAI_API_BASE
-        )
+        model_provider = ModelProvider()
+
         self.db_service = DatabaseService()
-        self.embedding_service = EmbeddingService(self.openai_client)
+        self.embedding_service = EmbeddingService(model_provider.embed_model)
         self.chat_service = ChatService(
-            self.openai_client,
+            model_provider.llm,
             self.db_service,
             self.embedding_service
         )
-        
+
     async def cleanup(self):
        """Cleanup resources."""
        if hasattr(self, 'db_service'):
@@ -87,14 +86,10 @@ async def process_query(self, query: str, version: int):
                 conversation_history=st.session_state.conversation_history,
                 stream=True
             )
-            
+
             async for chunk in response:
-                # Add more robust error checking
-                if chunk and hasattr(chunk, 'choices') and chunk.choices:
-                    delta = chunk.choices[0].delta
-                    if hasattr(delta, 'content') and delta.content:
-                        full_response += delta.content
-                        response_placeholder.markdown(full_response)
+                full_response += chunk.delta
+                response_placeholder.markdown(full_response)
 
             if full_response:
                 # Add to conversation history only if we got a valid response
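Finally, since `EmbeddingService` now depends only on the `BaseEmbedding` interface, it can be exercised in isolation with any llama-index backend. A sketch using the OpenAI backend (assumes `OPENAI_API_KEY` is set in the environment; the query text is illustrative):

```python
# Sketch: EmbeddingService decoupled from any single provider.
import asyncio

from llama_index.embeddings.openai import OpenAIEmbedding
from src.core.services.embedding import EmbeddingService

async def main() -> None:
    embed_model = OpenAIEmbedding(model="text-embedding-3-small")
    service = EmbeddingService(embed_model)
    vector = await service.get_embedding("How do I configure taxes in Odoo?")
    print(f"dimension: {len(vector)}")

asyncio.run(main())
```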