diff --git a/README.md b/README.md index fe60bfe..201dbc1 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,14 @@ response = client.embeddings.create( ) embedding = response.data[0].embedding + +# Optional: Specify embedding dimensions (requires Ollama v0.11.11+) +# Example with a model that supports dynamic dimensions +response = client.embeddings.create( + input="The quick brown fox jumps over the lazy dog", + model="qwen3-embedding", # Use a model with MRL support + dimensions=512 # Custom dimension size +) ``` **Response (OpenAI format):** @@ -93,6 +101,11 @@ embedding = response.data[0].embedding } ``` +**Note on the `dimensions` parameter:** The `dimensions` parameter requires Ollama v0.11.11 or later, and model support varies: +- **Models with MRL (Matryoshka Representation Learning)**, such as Qwen3 Embedding, support dynamic dimensions. +- **Fixed-dimension models**, such as `nomic-embed-text` (768) and `all-minilm` (384), use their default dimensions regardless of this parameter. +- Check your model's documentation to see whether it supports custom dimensions. + The server waits up to 30 seconds for the result. If processing takes longer, it returns a task ID for polling. #### Chat Completions @@ -165,6 +178,7 @@ All endpoints require `Authorization: Bearer your-secret-token` header. |-------|------|---------|-------------| | `input` | string | required | Text to embed | | `model` | string | `nomic-embed-text` | Model name (optional) | +| `dimensions` | integer | `null` | Embedding dimension size (optional, must be positive). Requires Ollama v0.11.11+. Support varies by model. 
| `POST /v1/chat/completions` diff --git a/server/main.py b/server/main.py index 09a085b..05f3db7 100644 --- a/server/main.py +++ b/server/main.py @@ -164,6 +164,8 @@ async def openai_embeddings(request: OpenAIEmbeddingRequest, token: str = Depend {"id": "task-id"} - poll GET /tasks/{id} for result """ payload = {"text": request.input, "model": request.model} + if request.dimensions is not None: + payload["dimensions"] = request.dimensions task_id = await database.create_task("embedding", payload) max_wait = 30 # Fixed wait time for embeddings deadline = time.time() + max_wait diff --git a/server/models.py b/server/models.py index 0615062..8636a99 100644 --- a/server/models.py +++ b/server/models.py @@ -19,6 +19,7 @@ class OpenAIEmbeddingResponse(BaseModel): class OpenAIEmbeddingRequest(BaseModel): input: str # text to embed model: str # Remove default value + dimensions: Optional[int] = Field(default=None, gt=0) # Optional embedding dimension size # Chat completion models diff --git a/worker/embedder.py b/worker/embedder.py index 87f0400..9341d67 100644 --- a/worker/embedder.py +++ b/worker/embedder.py @@ -2,14 +2,18 @@ from config import OLLAMA_URL, EMBEDDING_MODEL -def get_embedding(text: str, model: str = None) -> list[float]: +def get_embedding(text: str, model: str = None, dimensions: int = None) -> list[float]: """Get embedding from Ollama API.""" + payload = { + "model": model or EMBEDDING_MODEL, + "prompt": text, + } + if dimensions is not None: + payload["dimensions"] = dimensions + response = requests.post( f"{OLLAMA_URL}/api/embeddings", - json={ - "model": model or EMBEDDING_MODEL, - "prompt": text, - }, + json=payload, timeout=60, ) response.raise_for_status() diff --git a/worker/worker.py b/worker/worker.py index 8371761..4138245 100644 --- a/worker/worker.py +++ b/worker/worker.py @@ -50,9 +50,10 @@ def process_embedding_task(task_id: str, payload: dict): if not isinstance(text, str): raise ValueError("Embedding task payload must include a 'text' 
field of type string.") model = payload.get("model") + dimensions = payload.get("dimensions") print(f"[embedding] {task_id}: {text[:50]}...") - embedding = get_embedding(text, model) + embedding = get_embedding(text, model, dimensions) complete_task(task_id, embedding) print(f"[embedding] {task_id} completed (dim: {len(embedding)})")