From 7772a2bb88c2366718f62bb475aa553617dcfeac Mon Sep 17 00:00:00 2001 From: Jannes Stubbemann Date: Sat, 21 Mar 2026 22:53:39 +0100 Subject: [PATCH 1/2] feat(embedding): add litellm as embedding provider Adds LiteLLM as a new embedding provider, bringing embedding parity with the VLM layer which already supports litellm. This enables users to route embedding requests through OpenRouter, Ollama, vLLM, and any other OpenAI-compatible endpoint via litellm's unified interface. Closes #847 Co-Authored-By: Claude Opus 4.6 (1M context) --- openviking/models/embedder/__init__.py | 8 + .../models/embedder/litellm_embedders.py | 203 +++++++++++++ .../utils/config/embedding_config.py | 30 +- tests/unit/test_litellm_embedder.py | 276 ++++++++++++++++++ 4 files changed, 511 insertions(+), 6 deletions(-) create mode 100644 openviking/models/embedder/litellm_embedders.py create mode 100644 tests/unit/test_litellm_embedder.py diff --git a/openviking/models/embedder/__init__.py b/openviking/models/embedder/__init__.py index 52fe8ce7..f69a74e3 100644 --- a/openviking/models/embedder/__init__.py +++ b/openviking/models/embedder/__init__.py @@ -14,6 +14,7 @@ - Jina AI: Dense only - Voyage AI: Dense only - Google Gemini: Dense only +- LiteLLM: Dense only (bridges to OpenRouter, Ollama, vLLM, and many others) """ from openviking.models.embedder.base import ( @@ -30,6 +31,11 @@ except ImportError: GeminiDenseEmbedder = None # google-genai not installed from openviking.models.embedder.jina_embedders import JinaDenseEmbedder + +try: + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder +except ImportError: + LiteLLMDenseEmbedder = None # litellm not installed from openviking.models.embedder.minimax_embedders import MinimaxDenseEmbedder from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder from openviking.models.embedder.vikingdb_embedders import ( @@ -56,6 +62,8 @@ "GeminiDenseEmbedder", # Jina AI implementations "JinaDenseEmbedder", + # LiteLLM implementations + "LiteLLMDenseEmbedder", # MiniMax implementations "MinimaxDenseEmbedder", # OpenAI implementations diff --git a/openviking/models/embedder/litellm_embedders.py b/openviking/models/embedder/litellm_embedders.py new file mode 100644 index 00000000..7cce60da --- /dev/null +++ b/openviking/models/embedder/litellm_embedders.py @@ -0,0 +1,203 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""LiteLLM Embedder Implementation + +Uses litellm to provide a unified embedding interface across many providers +(OpenRouter, Ollama, vLLM, and any OpenAI-compatible endpoint). +""" + +import logging +import os + +os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") + +from typing import Any, Dict, List, Optional + +import litellm + +from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult +from openviking.telemetry import get_current_telemetry + +logger = logging.getLogger(__name__) + + +class LiteLLMDenseEmbedder(DenseEmbedderBase): + """LiteLLM Dense Embedder Implementation + + Routes embedding requests through litellm, supporting dozens of providers + via a unified interface. Model names use litellm's provider/model format + (e.g., "openai/text-embedding-3-small", "ollama/nomic-embed-text"). + + Example: + >>> # OpenRouter embeddings + >>> embedder = LiteLLMDenseEmbedder( + ... model_name="openai/text-embedding-3-small", + ... api_key="sk-or-...", + ... api_base="https://openrouter.ai/api/v1", + ... dimension=1536, + ... 
) + >>> result = embedder.embed("Hello world") + + >>> # Local Ollama embeddings + >>> embedder = LiteLLMDenseEmbedder( + ... model_name="ollama/nomic-embed-text", + ... api_base="http://localhost:11434", + ... ) + >>> result = embedder.embed("Hello world") + """ + + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + dimension: Optional[int] = None, + query_param: Optional[str] = None, + document_param: Optional[str] = None, + extra_headers: Optional[Dict[str, str]] = None, + config: Optional[Dict[str, Any]] = None, + ): + """Initialize LiteLLM Dense Embedder + + Args: + model_name: Model name in litellm format (e.g., "openai/text-embedding-3-small"). + api_key: API key for the provider. Falls back to provider-specific env vars. + api_base: Custom API base URL (e.g., "https://openrouter.ai/api/v1"). + dimension: Embedding vector dimension. If None, auto-detected via a probe call. + query_param: Parameter value for query-side embeddings (non-symmetric mode). + document_param: Parameter value for document-side embeddings (non-symmetric mode). + extra_headers: Extra HTTP headers for API requests. + config: Additional configuration dict. + """ + super().__init__(model_name, config) + + self.api_key = api_key + self.api_base = api_base + self.dimension = dimension + self.query_param = query_param + self.document_param = document_param + self.extra_headers = extra_headers + + self._dimension = dimension + if self._dimension is None: + self._dimension = self._detect_dimension() + + def _detect_dimension(self) -> int: + """Detect dimension by making a probe embedding call.""" + try: + result = self.embed("test") + return len(result.dense_vector) if result.dense_vector else 1536 + except Exception: + return 1536 + + def _build_kwargs(self, is_query: bool = False) -> Dict[str, Any]: + """Build kwargs dict for litellm.embedding() call.""" + kwargs: Dict[str, Any] = {"model": self.model_name} + + if self.api_key: + kwargs["api_key"] = self.api_key + if self.api_base: + kwargs["api_base"] = self.api_base + if self.extra_headers: + kwargs["extra_headers"] = self.extra_headers + if self.dimension: + kwargs["dimensions"] = self.dimension + + # Non-symmetric embedding support + active_param = None + if is_query and self.query_param is not None: + active_param = self.query_param + elif not is_query and self.document_param is not None: + active_param = self.document_param + + if active_param: + if "=" in active_param: + # Parse key=value format (e.g., "input_type=query,task=search") + extra_body = {} + for part in active_param.split(","): + part = part.strip() + if "=" in part: + key, value = part.split("=", 1) + extra_body[key.strip()] = value.strip() + if extra_body: + kwargs["extra_body"] = extra_body + else: + kwargs["input_type"] = active_param + + return kwargs + + def _update_telemetry_token_usage(self, response) -> None: + """Update telemetry with token usage from response.""" + usage = getattr(response, "usage", None) + if not usage: + return + + def _usage_value(key: str, default: int = 0) -> int: + if isinstance(usage, dict): + return int(usage.get(key, default) or default) + return int(getattr(usage, key, default) or default) + + prompt_tokens = _usage_value("prompt_tokens", 0) + total_tokens = _usage_value("total_tokens", prompt_tokens) + output_tokens = max(total_tokens - prompt_tokens, 0) + get_current_telemetry().add_token_usage_by_source( + "embedding", + prompt_tokens, + output_tokens, + ) + + def embed(self, text: str, is_query: 
bool = False) -> EmbedResult: + """Perform dense embedding on text via litellm. + + Args: + text: Input text + is_query: Flag to indicate if this is a query embedding + + Returns: + EmbedResult: Result containing dense_vector + + Raises: + RuntimeError: When embedding call fails + """ + try: + kwargs = self._build_kwargs(is_query=is_query) + kwargs["input"] = [text] + response = litellm.embedding(**kwargs) + self._update_telemetry_token_usage(response) + vector = response.data[0]["embedding"] + return EmbedResult(dense_vector=vector) + except Exception as e: + raise RuntimeError(f"LiteLLM embedding failed: {e}") from e + + def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: + """Batch embedding via litellm. + + Args: + texts: List of texts + is_query: Flag to indicate if these are query embeddings + + Returns: + List[EmbedResult]: List of embedding results + + Raises: + RuntimeError: When embedding call fails + """ + if not texts: + return [] + + try: + kwargs = self._build_kwargs(is_query=is_query) + kwargs["input"] = texts + response = litellm.embedding(**kwargs) + self._update_telemetry_token_usage(response) + return [EmbedResult(dense_vector=item["embedding"]) for item in response.data] + except Exception as e: + raise RuntimeError(f"LiteLLM batch embedding failed: {e}") from e + + def get_dimension(self) -> int: + """Get embedding dimension. + + Returns: + int: Vector dimension + """ + return self._dimension diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py index 14435438..9ae01d00 100644 --- a/openviking_cli/utils/config/embedding_config.py +++ b/openviking_cli/utils/config/embedding_config.py @@ -37,9 +37,9 @@ class EmbeddingModelConfig(BaseModel): provider: Optional[str] = Field( default="volcengine", description=( - "Provider type: 'openai', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage'. " - "For OpenRouter or other OpenAI-compatible providers, use 'openai' with " - "api_base and extra_headers." + "Provider type: 'openai', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'litellm'. " + "For OpenRouter or other OpenAI-compatible providers, use 'litellm' with " + "api_base and api_key, or 'openai' with api_base and extra_headers." ), ) backend: Optional[str] = Field( @@ -103,10 +103,11 @@ def validate_config(self): "gemini", "voyage", "minimax", + "litellm", ]: raise ValueError( f"Invalid embedding provider: '{self.provider}'. Must be one of: " - "'openai', 'azure', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'minimax'" + "'openai', 'azure', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'minimax', 'litellm'" ) # Provider-specific validation @@ -178,6 +179,10 @@ def validate_config(self): if not self.api_key: raise ValueError("MiniMax provider requires 'api_key' to be set") + elif self.provider == "litellm": + # litellm handles auth via env vars or explicit api_key; no strict requirement + pass + return self def get_effective_dimension(self) -> int: @@ -203,7 +208,7 @@ def get_effective_dimension(self) -> int: class EmbeddingConfig(BaseModel): """ - Embedding configuration, supports OpenAI, VolcEngine, VikingDB, Jina, Gemini, or Voyage APIs. + Embedding configuration, supports OpenAI, VolcEngine, VikingDB, Jina, Gemini, Voyage, or LiteLLM APIs. Structure: - dense: Configuration for dense embedder @@ -241,7 +246,7 @@ def _create_embedder( """Factory method to create embedder instance based on provider and type. 
Args: - provider: Provider type ('openai', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage') + provider: Provider type ('openai', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'litellm') embedder_type: Embedder type ('dense', 'sparse', 'hybrid') config: EmbeddingModelConfig instance @@ -254,6 +259,7 @@ def _create_embedder( from openviking.models.embedder import ( GeminiDenseEmbedder, JinaDenseEmbedder, + LiteLLMDenseEmbedder, MinimaxDenseEmbedder, OpenAIDenseEmbedder, VikingDBDenseEmbedder, @@ -414,6 +420,18 @@ def _create_embedder( **({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}), }, ), + ("litellm", "dense"): ( + LiteLLMDenseEmbedder, + lambda cfg: { + "model_name": cfg.model, + "api_key": cfg.api_key, + "api_base": cfg.api_base, + "dimension": cfg.dimension, + **({"query_param": cfg.query_param} if cfg.query_param else {}), + **({"document_param": cfg.document_param} if cfg.document_param else {}), + **({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}), + }, + ), } key = (provider, embedder_type) diff --git a/tests/unit/test_litellm_embedder.py b/tests/unit/test_litellm_embedder.py new file mode 100644 index 00000000..5fd4a95e --- /dev/null +++ b/tests/unit/test_litellm_embedder.py @@ -0,0 +1,276 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for LiteLLM Embedder and factory integration.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from openviking_cli.utils.config.embedding_config import EmbeddingConfig, EmbeddingModelConfig + + +def _mock_litellm_response(vectors=None, usage=None): + """Create a mock litellm embedding response.""" + if vectors is None: + vectors = [[0.1] * 1536] + response = MagicMock() + response.data = [{"embedding": v} for v in vectors] + response.usage = usage + return response + + +class TestLiteLLMDenseEmbedder: + """Test cases for LiteLLMDenseEmbedder.""" + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_embed_basic(self, mock_litellm): + """Basic embedding should return a dense vector.""" + mock_litellm.embedding.return_value = _mock_litellm_response() + + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + embedder = LiteLLMDenseEmbedder( + model_name="openai/text-embedding-3-small", + api_key="test-key", + dimension=1536, + ) + result = embedder.embed("Hello world") + + assert result.dense_vector is not None + assert len(result.dense_vector) == 1536 + mock_litellm.embedding.assert_called() + call_kwargs = mock_litellm.embedding.call_args[1] + assert call_kwargs["model"] == "openai/text-embedding-3-small" + assert call_kwargs["api_key"] == "test-key" + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_embed_with_api_base(self, mock_litellm): + """api_base should be forwarded to litellm.""" + mock_litellm.embedding.return_value = _mock_litellm_response() + + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + embedder = LiteLLMDenseEmbedder( + model_name="openai/text-embedding-3-small", + api_key="test-key", + api_base="https://openrouter.ai/api/v1", + dimension=1536, + ) + embedder.embed("Hello") + + call_kwargs = mock_litellm.embedding.call_args[1] + assert call_kwargs["api_base"] == "https://openrouter.ai/api/v1" + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_embed_batch(self, mock_litellm): + """Batch embedding should return multiple results.""" + 
vectors = [[0.1] * 1536, [0.2] * 1536] + mock_litellm.embedding.return_value = _mock_litellm_response(vectors) + + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + embedder = LiteLLMDenseEmbedder( + model_name="openai/text-embedding-3-small", + api_key="test-key", + dimension=1536, + ) + results = embedder.embed_batch(["Hello", "World"]) + + assert len(results) == 2 + assert results[0].dense_vector[0] == 0.1 + assert results[1].dense_vector[0] == 0.2 + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_embed_batch_empty(self, mock_litellm): + """Empty batch should return empty list without API call.""" + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + embedder = LiteLLMDenseEmbedder( + model_name="openai/text-embedding-3-small", + api_key="test-key", + dimension=1536, + ) + results = embedder.embed_batch([]) + + assert results == [] + mock_litellm.embedding.assert_not_called() + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_embed_non_symmetric_query(self, mock_litellm): + """Query param should be forwarded as input_type.""" + mock_litellm.embedding.return_value = _mock_litellm_response() + + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + embedder = LiteLLMDenseEmbedder( + model_name="openai/text-embedding-3-small", + api_key="test-key", + dimension=1536, + query_param="query", + ) + embedder.embed("search query", is_query=True) + + call_kwargs = mock_litellm.embedding.call_args[1] + assert call_kwargs["input_type"] == "query" + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_embed_non_symmetric_document(self, mock_litellm): + """Document param should be forwarded as input_type.""" + mock_litellm.embedding.return_value = _mock_litellm_response() + + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + embedder = LiteLLMDenseEmbedder( + model_name="openai/text-embedding-3-small", + api_key="test-key", + dimension=1536, + document_param="passage", + ) + embedder.embed("document text", is_query=False) + + call_kwargs = mock_litellm.embedding.call_args[1] + assert call_kwargs["input_type"] == "passage" + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_embed_no_extra_body_when_symmetric(self, mock_litellm): + """No input_type or extra_body when symmetric mode.""" + mock_litellm.embedding.return_value = _mock_litellm_response() + + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + embedder = LiteLLMDenseEmbedder( + model_name="openai/text-embedding-3-small", + api_key="test-key", + dimension=1536, + ) + embedder.embed("Hello world") + + call_kwargs = mock_litellm.embedding.call_args[1] + assert "input_type" not in call_kwargs + assert "extra_body" not in call_kwargs + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_embed_key_value_param(self, mock_litellm): + """Key=value format params should be sent as extra_body.""" + mock_litellm.embedding.return_value = _mock_litellm_response() + + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + embedder = LiteLLMDenseEmbedder( + model_name="openai/text-embedding-3-small", + api_key="test-key", + dimension=1536, + query_param="input_type=query,task=search", + ) + embedder.embed("query text", is_query=True) + + call_kwargs = mock_litellm.embedding.call_args[1] + assert call_kwargs["extra_body"] == {"input_type": "query", "task": 
"search"} + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_get_dimension(self, mock_litellm): + """get_dimension should return the configured dimension.""" + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + embedder = LiteLLMDenseEmbedder( + model_name="openai/text-embedding-3-small", + api_key="test-key", + dimension=1024, + ) + assert embedder.get_dimension() == 1024 + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_no_api_key_allowed(self, mock_litellm): + """litellm allows no api_key (uses env vars).""" + mock_litellm.embedding.return_value = _mock_litellm_response() + + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + embedder = LiteLLMDenseEmbedder( + model_name="ollama/nomic-embed-text", + dimension=768, + ) + embedder.embed("test") + + call_kwargs = mock_litellm.embedding.call_args[1] + assert "api_key" not in call_kwargs + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_extra_headers_forwarded(self, mock_litellm): + """Extra headers should be forwarded to litellm.""" + mock_litellm.embedding.return_value = _mock_litellm_response() + + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + headers = {"HTTP-Referer": "https://mysite.com", "X-Title": "MyApp"} + embedder = LiteLLMDenseEmbedder( + model_name="openai/text-embedding-3-small", + api_key="test-key", + dimension=1536, + extra_headers=headers, + ) + embedder.embed("test") + + call_kwargs = mock_litellm.embedding.call_args[1] + assert call_kwargs["extra_headers"] == headers + + +class TestLiteLLMEmbeddingFactory: + """Test the factory creates LiteLLMDenseEmbedder correctly.""" + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_factory_creates_litellm_embedder(self, mock_litellm): + """EmbeddingConfig factory should create LiteLLMDenseEmbedder for provider='litellm'.""" + mock_litellm.embedding.return_value = _mock_litellm_response() + + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + cfg = EmbeddingModelConfig( + provider="litellm", + model="openai/text-embedding-3-small", + api_key="test-key", + api_base="https://openrouter.ai/api/v1", + dimension=1536, + ) + embedder = EmbeddingConfig(dense=cfg)._create_embedder("litellm", "dense", cfg) + + assert isinstance(embedder, LiteLLMDenseEmbedder) + assert embedder.model_name == "openai/text-embedding-3-small" + assert embedder.api_key == "test-key" + assert embedder.api_base == "https://openrouter.ai/api/v1" + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_factory_forwards_query_document_params(self, mock_litellm): + """Factory should forward query_param and document_param.""" + mock_litellm.embedding.return_value = _mock_litellm_response() + + cfg = EmbeddingModelConfig( + provider="litellm", + model="openai/text-embedding-3-small", + api_key="test-key", + dimension=1536, + query_param="query", + document_param="passage", + ) + embedder = EmbeddingConfig(dense=cfg)._create_embedder("litellm", "dense", cfg) + + assert embedder.query_param == "query" + assert embedder.document_param == "passage" + + def test_config_validation_accepts_litellm(self): + """EmbeddingModelConfig should accept 'litellm' as a valid provider.""" + cfg = EmbeddingModelConfig( + provider="litellm", + model="openai/text-embedding-3-small", + dimension=1536, + ) + assert cfg.provider == "litellm" + + def 
test_config_validation_litellm_no_api_key_ok(self): + """litellm provider should not require api_key.""" + cfg = EmbeddingModelConfig( + provider="litellm", + model="ollama/nomic-embed-text", + dimension=768, + ) + assert cfg.api_key is None From 457f47e3e6f43994419dbed816c691adf4cb02d3 Mon Sep 17 00:00:00 2001 From: Jannes Stubbemann Date: Sun, 22 Mar 2026 09:16:33 +0100 Subject: [PATCH 2/2] fix: address review feedback on litellm embedding provider MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Move os.environ.setdefault from module-level to __init__ to avoid mutating the process environment on import 2. Require dimension as mandatory — removes the probe API call during construction that caused surprise billable requests and silent fallbacks 3. Add None guard for LiteLLMDenseEmbedder in factory to give a clear error when litellm is not installed Co-Authored-By: Claude Opus 4.6 (1M context) --- .../models/embedder/litellm_embedders.py | 23 ++++++-------- .../utils/config/embedding_config.py | 11 ++++++- tests/unit/test_litellm_embedder.py | 30 +++++++++++++++++++ 3 files changed, 49 insertions(+), 15 deletions(-) diff --git a/openviking/models/embedder/litellm_embedders.py b/openviking/models/embedder/litellm_embedders.py index 7cce60da..4f10f99c 100644 --- a/openviking/models/embedder/litellm_embedders.py +++ b/openviking/models/embedder/litellm_embedders.py @@ -8,9 +8,6 @@ import logging import os - -os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") - from typing import Any, Dict, List, Optional import litellm @@ -42,6 +39,7 @@ class LiteLLMDenseEmbedder(DenseEmbedderBase): >>> embedder = LiteLLMDenseEmbedder( ... model_name="ollama/nomic-embed-text", ... api_base="http://localhost:11434", + ... dimension=768, ... ) >>> result = embedder.embed("Hello world") """ @@ -63,7 +61,7 @@ def __init__( model_name: Model name in litellm format (e.g., "openai/text-embedding-3-small"). api_key: API key for the provider. Falls back to provider-specific env vars. api_base: Custom API base URL (e.g., "https://openrouter.ai/api/v1"). - dimension: Embedding vector dimension. If None, auto-detected via a probe call. + dimension: Embedding vector dimension (required). query_param: Parameter value for query-side embeddings (non-symmetric mode). document_param: Parameter value for document-side embeddings (non-symmetric mode). extra_headers: Extra HTTP headers for API requests. @@ -71,6 +69,8 @@ def __init__( """ super().__init__(model_name, config) + os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") + self.api_key = api_key self.api_base = api_base self.dimension = dimension @@ -78,17 +78,12 @@ def __init__( self.document_param = document_param self.extra_headers = extra_headers + if dimension is None: + raise ValueError( + "LiteLLM embedding provider requires 'dimension' to be set explicitly. " + "Check your embedding model's documentation for the correct dimension." 
+ ) self._dimension = dimension - if self._dimension is None: - self._dimension = self._detect_dimension() - - def _detect_dimension(self) -> int: - """Detect dimension by making a probe embedding call.""" - try: - result = self.embed("test") - return len(result.dense_vector) if result.dense_vector else 1536 - except Exception: - return 1536 def _build_kwargs(self, is_query: bool = False) -> Dict[str, Any]: """Build kwargs dict for litellm.embedding() call.""" diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py index 9ae01d00..9039bbe6 100644 --- a/openviking_cli/utils/config/embedding_config.py +++ b/openviking_cli/utils/config/embedding_config.py @@ -181,7 +181,11 @@ def validate_config(self): elif self.provider == "litellm": # litellm handles auth via env vars or explicit api_key; no strict requirement - pass + if not self.dimension: + raise ValueError( + "LiteLLM provider requires 'dimension' to be set explicitly. " + "Check your embedding model's documentation for the correct dimension." + ) return self @@ -271,6 +275,11 @@ def _create_embedder( VoyageDenseEmbedder, ) + if provider == "litellm" and LiteLLMDenseEmbedder is None: + raise ValueError( + "LiteLLM is not installed. Install it with: pip install litellm" + ) + # Factory registry: (provider, type) -> (embedder_class, param_builder) factory_registry = { ("openai", "dense"): ( diff --git a/tests/unit/test_litellm_embedder.py b/tests/unit/test_litellm_embedder.py index 5fd4a95e..a664e1a2 100644 --- a/tests/unit/test_litellm_embedder.py +++ b/tests/unit/test_litellm_embedder.py @@ -274,3 +274,33 @@ def test_config_validation_litellm_no_api_key_ok(self): dimension=768, ) assert cfg.api_key is None + + def test_config_validation_litellm_requires_dimension(self): + """litellm provider should require dimension to be set.""" + with pytest.raises(ValueError, match="dimension"): + EmbeddingModelConfig( + provider="litellm", + model="openai/text-embedding-3-small", + ) + + @patch("openviking.models.embedder.litellm_embedders.litellm") + def test_dimension_required_in_embedder(self, mock_litellm): + """LiteLLMDenseEmbedder should raise ValueError when dimension is None.""" + from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder + + with pytest.raises(ValueError, match="dimension"): + LiteLLMDenseEmbedder( + model_name="openai/text-embedding-3-small", + api_key="test-key", + ) + + def test_factory_raises_when_litellm_not_installed(self): + """Factory should raise clear error when litellm is not installed.""" + cfg = EmbeddingModelConfig( + provider="litellm", + model="openai/text-embedding-3-small", + dimension=1536, + ) + with patch("openviking.models.embedder.LiteLLMDenseEmbedder", None): + with pytest.raises(ValueError, match="not installed"): + EmbeddingConfig(dense=cfg)._create_embedder("litellm", "dense", cfg)
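
A minimal end-to-end usage sketch of the provider these two patches add, for context. It only uses names that appear in the diffs above (EmbeddingModelConfig, EmbeddingConfig._create_embedder, LiteLLMDenseEmbedder.embed / embed_batch); the OpenRouter endpoint, API key, and model names are illustrative placeholders, and building the embedder through _create_embedder mirrors the unit tests rather than a documented public entry point.

    from openviking_cli.utils.config.embedding_config import (
        EmbeddingConfig,
        EmbeddingModelConfig,
    )

    # OpenRouter (or any other OpenAI-compatible endpoint) routed through litellm.
    # Per PATCH 2/2, 'dimension' must be set explicitly for provider="litellm".
    cfg = EmbeddingModelConfig(
        provider="litellm",
        model="openai/text-embedding-3-small",
        api_key="sk-or-...",                      # placeholder; litellm can also read provider env vars
        api_base="https://openrouter.ai/api/v1",
        dimension=1536,
        # query_param / document_param are optional and only useful for providers
        # that distinguish query vs. document embeddings (e.g. "input_type=query").
    )

    # Construction via the factory, mirroring tests/unit/test_litellm_embedder.py.
    embedder = EmbeddingConfig(dense=cfg)._create_embedder("litellm", "dense", cfg)

    result = embedder.embed("What is a dense embedding?", is_query=True)
    assert len(result.dense_vector) == 1536

    batch = embedder.embed_batch(["first document", "second document"])
    assert len(batch) == 2

    # Local Ollama: no api_key needed, dimension still required.
    local_cfg = EmbeddingModelConfig(
        provider="litellm",
        model="ollama/nomic-embed-text",
        api_base="http://localhost:11434",
        dimension=768,
    )

Whether callers should go through the factory or instantiate LiteLLMDenseEmbedder directly is left to the existing embedder conventions; the sketch follows the tests only to show which fields the new provider consumes.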