Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions openviking/models/embedder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
- Jina AI: Dense only
- Voyage AI: Dense only
- Google Gemini: Dense only
- LiteLLM: Dense only (bridges to OpenRouter, Ollama, vLLM, and many others)
"""

from openviking.models.embedder.base import (
Expand All @@ -30,6 +31,11 @@
except ImportError:
GeminiDenseEmbedder = None # google-genai not installed
from openviking.models.embedder.jina_embedders import JinaDenseEmbedder

try:
from openviking.models.embedder.litellm_embedders import LiteLLMDenseEmbedder
except ImportError:
LiteLLMDenseEmbedder = None # litellm not installed
from openviking.models.embedder.minimax_embedders import MinimaxDenseEmbedder
from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder
from openviking.models.embedder.vikingdb_embedders import (
Expand All @@ -56,6 +62,8 @@
"GeminiDenseEmbedder",
# Jina AI implementations
"JinaDenseEmbedder",
# LiteLLM implementations
"LiteLLMDenseEmbedder",
# MiniMax implementations
"MinimaxDenseEmbedder",
# OpenAI implementations
Expand Down
198 changes: 198 additions & 0 deletions openviking/models/embedder/litellm_embedders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: Apache-2.0
"""LiteLLM Embedder Implementation

Uses litellm to provide a unified embedding interface across many providers
(OpenRouter, Ollama, vLLM, and any OpenAI-compatible endpoint).
"""

import logging
import os
from typing import Any, Dict, List, Optional

import litellm

from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult
from openviking.telemetry import get_current_telemetry

logger = logging.getLogger(__name__)


class LiteLLMDenseEmbedder(DenseEmbedderBase):
    """Dense embedder backed by the litellm routing layer.

    litellm exposes one embedding API over dozens of backends (OpenRouter,
    Ollama, vLLM, and any OpenAI-compatible endpoint). Model identifiers
    follow litellm's "provider/model" convention, e.g.
    "openai/text-embedding-3-small" or "ollama/nomic-embed-text".

    Example:
        >>> # OpenRouter embeddings
        >>> embedder = LiteLLMDenseEmbedder(
        ...     model_name="openai/text-embedding-3-small",
        ...     api_key="sk-or-...",
        ...     api_base="https://openrouter.ai/api/v1",
        ...     dimension=1536,
        ... )
        >>> result = embedder.embed("Hello world")

        >>> # Local Ollama embeddings
        >>> embedder = LiteLLMDenseEmbedder(
        ...     model_name="ollama/nomic-embed-text",
        ...     api_base="http://localhost:11434",
        ...     dimension=768,
        ... )
        >>> result = embedder.embed("Hello world")
    """

    def __init__(
        self,
        model_name: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        dimension: Optional[int] = None,
        query_param: Optional[str] = None,
        document_param: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Create a litellm-backed dense embedder.

        Args:
            model_name: litellm-style model id (e.g. "openai/text-embedding-3-small").
            api_key: Provider API key; litellm also honors provider env vars.
            api_base: Override for the provider endpoint (e.g. an OpenRouter URL).
            dimension: Output vector width. Mandatory — there is no reliable way
                to discover it for arbitrary backends.
            query_param: Provider parameter applied to query-side requests
                (non-symmetric embedding mode).
            document_param: Provider parameter applied to document-side requests
                (non-symmetric embedding mode).
            extra_headers: Additional HTTP headers sent with every request.
            config: Extra configuration forwarded to the base class.

        Raises:
            ValueError: If ``dimension`` is not provided.
        """
        super().__init__(model_name, config)

        # Opt in to litellm's local model-cost map unless the caller already
        # configured this environment variable themselves.
        os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True")

        self.api_key = api_key
        self.api_base = api_base
        self.dimension = dimension
        self.query_param = query_param
        self.document_param = document_param
        self.extra_headers = extra_headers

        if dimension is None:
            raise ValueError(
                "LiteLLM embedding provider requires 'dimension' to be set explicitly. "
                "Check your embedding model's documentation for the correct dimension."
            )
        self._dimension = dimension

    @staticmethod
    def _parse_param_spec(spec: str) -> Dict[str, str]:
        """Parse a "k1=v1,k2=v2" spec into a dict; fragments without '=' are skipped."""
        parsed: Dict[str, str] = {}
        for fragment in spec.split(","):
            fragment = fragment.strip()
            if "=" not in fragment:
                continue
            name, raw_value = fragment.split("=", 1)
            parsed[name.strip()] = raw_value.strip()
        return parsed

    def _build_kwargs(self, is_query: bool = False) -> Dict[str, Any]:
        """Assemble the keyword arguments for a litellm.embedding() call."""
        call_args: Dict[str, Any] = {"model": self.model_name}

        # Only forward values the caller actually supplied.
        for key, value in (
            ("api_key", self.api_key),
            ("api_base", self.api_base),
            ("extra_headers", self.extra_headers),
        ):
            if value:
                call_args[key] = value
        if self.dimension:
            call_args["dimensions"] = self.dimension

        # Non-symmetric embedding: pick the side-specific parameter spec.
        side_param = self.query_param if is_query else self.document_param

        if side_param:
            if "=" in side_param:
                # "key=value[,key=value...]" form goes into the request body.
                extra_body = self._parse_param_spec(side_param)
                if extra_body:
                    call_args["extra_body"] = extra_body
            else:
                # Bare value form maps onto litellm's input_type argument.
                call_args["input_type"] = side_param

        return call_args

    def _update_telemetry_token_usage(self, response) -> None:
        """Record token counts from a litellm response on the active telemetry."""
        usage = getattr(response, "usage", None)
        if not usage:
            return

        def _as_int(key: str, fallback: int = 0) -> int:
            # usage may be a plain dict or an attribute-style object.
            if isinstance(usage, dict):
                raw = usage.get(key, fallback)
            else:
                raw = getattr(usage, key, fallback)
            return int(raw or fallback)

        prompt = _as_int("prompt_tokens", 0)
        total = _as_int("total_tokens", prompt)
        # Treat any tokens beyond the prompt as output tokens, clamped at zero.
        completion = total - prompt
        if completion < 0:
            completion = 0
        get_current_telemetry().add_token_usage_by_source(
            "embedding",
            prompt,
            completion,
        )

    def embed(self, text: str, is_query: bool = False) -> EmbedResult:
        """Embed a single text string through litellm.

        Args:
            text: Input text
            is_query: Flag to indicate if this is a query embedding

        Returns:
            EmbedResult: Result containing dense_vector

        Raises:
            RuntimeError: When embedding call fails
        """
        try:
            call_args = self._build_kwargs(is_query=is_query)
            call_args["input"] = [text]
            response = litellm.embedding(**call_args)
            self._update_telemetry_token_usage(response)
            return EmbedResult(dense_vector=response.data[0]["embedding"])
        except Exception as e:
            raise RuntimeError(f"LiteLLM embedding failed: {e}") from e

    def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
        """Embed a batch of texts in one litellm call.

        Args:
            texts: List of texts
            is_query: Flag to indicate if these are query embeddings

        Returns:
            List[EmbedResult]: List of embedding results

        Raises:
            RuntimeError: When embedding call fails
        """
        # Nothing to embed: short-circuit without touching the provider.
        if not texts:
            return []

        try:
            call_args = self._build_kwargs(is_query=is_query)
            call_args["input"] = texts
            response = litellm.embedding(**call_args)
            self._update_telemetry_token_usage(response)
            return [EmbedResult(dense_vector=entry["embedding"]) for entry in response.data]
        except Exception as e:
            raise RuntimeError(f"LiteLLM batch embedding failed: {e}") from e

    def get_dimension(self) -> int:
        """Get embedding dimension.

        Returns:
            int: Vector dimension
        """
        return self._dimension
39 changes: 33 additions & 6 deletions openviking_cli/utils/config/embedding_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class EmbeddingModelConfig(BaseModel):
provider: Optional[str] = Field(
default="volcengine",
description=(
"Provider type: 'openai', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage'. "
"For OpenRouter or other OpenAI-compatible providers, use 'openai' with "
"api_base and extra_headers."
"Provider type: 'openai', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'litellm'. "
"For OpenRouter or other OpenAI-compatible providers, use 'litellm' with "
"api_base and api_key, or 'openai' with api_base and extra_headers."
),
)
backend: Optional[str] = Field(
Expand Down Expand Up @@ -103,10 +103,11 @@ def validate_config(self):
"gemini",
"voyage",
"minimax",
"litellm",
]:
raise ValueError(
f"Invalid embedding provider: '{self.provider}'. Must be one of: "
"'openai', 'azure', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'minimax'"
"'openai', 'azure', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'minimax', 'litellm'"
)

# Provider-specific validation
Expand Down Expand Up @@ -178,6 +179,14 @@ def validate_config(self):
if not self.api_key:
raise ValueError("MiniMax provider requires 'api_key' to be set")

elif self.provider == "litellm":
# litellm handles auth via env vars or explicit api_key; no strict requirement
if not self.dimension:
raise ValueError(
"LiteLLM provider requires 'dimension' to be set explicitly. "
"Check your embedding model's documentation for the correct dimension."
)

return self

def get_effective_dimension(self) -> int:
Expand All @@ -203,7 +212,7 @@ def get_effective_dimension(self) -> int:

class EmbeddingConfig(BaseModel):
"""
Embedding configuration, supports OpenAI, VolcEngine, VikingDB, Jina, Gemini, or Voyage APIs.
Embedding configuration, supports OpenAI, VolcEngine, VikingDB, Jina, Gemini, Voyage, or LiteLLM APIs.

Structure:
- dense: Configuration for dense embedder
Expand Down Expand Up @@ -241,7 +250,7 @@ def _create_embedder(
"""Factory method to create embedder instance based on provider and type.

Args:
provider: Provider type ('openai', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage')
provider: Provider type ('openai', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'litellm')
embedder_type: Embedder type ('dense', 'sparse', 'hybrid')
config: EmbeddingModelConfig instance

Expand All @@ -254,6 +263,7 @@ def _create_embedder(
from openviking.models.embedder import (
GeminiDenseEmbedder,
JinaDenseEmbedder,
LiteLLMDenseEmbedder,
MinimaxDenseEmbedder,
OpenAIDenseEmbedder,
VikingDBDenseEmbedder,
Expand All @@ -265,6 +275,11 @@ def _create_embedder(
VoyageDenseEmbedder,
)

if provider == "litellm" and LiteLLMDenseEmbedder is None:
raise ValueError(
"LiteLLM is not installed. Install it with: pip install litellm"
)

# Factory registry: (provider, type) -> (embedder_class, param_builder)
factory_registry = {
("openai", "dense"): (
Expand Down Expand Up @@ -414,6 +429,18 @@ def _create_embedder(
**({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}),
},
),
("litellm", "dense"): (
LiteLLMDenseEmbedder,
lambda cfg: {
"model_name": cfg.model,
"api_key": cfg.api_key,
"api_base": cfg.api_base,
"dimension": cfg.dimension,
**({"query_param": cfg.query_param} if cfg.query_param else {}),
**({"document_param": cfg.document_param} if cfg.document_param else {}),
**({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}),
},
),
}

key = (provider, embedder_type)
Expand Down
Loading
Loading