-
Notifications
You must be signed in to change notification settings - Fork 687
Add Sentence Transformers for embeddings #142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,4 +20,7 @@ providers: | |
| api_key_env: SNOWFLAKE_PAT | ||
| api_endpoint_env: SNOWFLAKE_ACCOUNT_URL | ||
| api_version_env: "2024-10-01" | ||
| model: snowflake-arctic-embed-m-v1.5 | ||
| model: snowflake-arctic-embed-m-v1.5 | ||
|
|
||
| sentence-transformers: | ||
| model: all-MiniLM-L6-v2 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe the name should be: "sentence-transformers/all-MiniLM-L6-v2" instead of just "all-MiniLM-L6-v2" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I dont think so, it seems to work fine like this ""all-MiniLM-L6-v2"" |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| # Copyright (c) 2025 Microsoft Corporation. | ||
| # Licensed under the MIT License | ||
|
|
||
| """ | ||
| SentenceTransformer-based local embedding implementation. | ||
|
|
||
| WARNING: This code is under development and may undergo changes in future releases. | ||
| Backwards compatibility is not guaranteed at this time. | ||
| """ | ||
|
|
||
| import threading | ||
| from typing import List, Optional | ||
| import asyncio | ||
|
|
||
| from sentence_transformers import SentenceTransformer | ||
|
|
||
| from config.config import CONFIG | ||
| from utils.logging_config_helper import get_configured_logger, LogLevel | ||
|
|
||
| logger = get_configured_logger("sentence_transformer_embedding") | ||
|
|
||
| # Thread-safe singleton initialization | ||
| _model_lock = threading.Lock() | ||
| embedding_model = None | ||
|
|
||
| def get_model_name() -> str: | ||
| """ | ||
| Retrieve the embedding model name from configuration or default. | ||
| """ | ||
| provider_config = CONFIG.get_embedding_provider("sentence_transformers") | ||
| if provider_config and provider_config.model: | ||
| return provider_config.model | ||
| return "all-MiniLM-L6-v2" # Default lightweight model | ||
|
|
||
| def get_embedding_model(model_override: Optional[str] = None) -> SentenceTransformer: | ||
| """ | ||
| Load and return a singleton SentenceTransformer model. | ||
| """ | ||
| global embedding_model | ||
| with _model_lock: | ||
| if embedding_model is None: | ||
| # Use override model if provided, otherwise use configured model | ||
| model_name = model_override or get_model_name() | ||
| try: | ||
| embedding_model = SentenceTransformer(model_name) | ||
| logger.info(f"Loaded SentenceTransformer model: {model_name}") | ||
| except Exception as e: | ||
| logger.exception(f"Failed to load SentenceTransformer model: {model_name}") | ||
| raise | ||
| return embedding_model | ||
|
|
||
| async def get_sentence_transformer_embedding( | ||
| text: str, | ||
| model: Optional[str] = None, | ||
|
||
| timeout: float = 30.0 | ||
| ) -> List[float]: | ||
| """ | ||
| Generate a single embedding using SentenceTransformer. | ||
|
|
||
| Args: | ||
| text: The input text to embed. | ||
| model: Optional model name to override config. | ||
| timeout: Unused, for compatibility. | ||
|
|
||
| Returns: | ||
| Embedding vector as list of floats. | ||
| """ | ||
| try: | ||
| model_instance = get_embedding_model(model) | ||
|
|
||
| # Run the blocking encode operation in a thread pool | ||
| loop = asyncio.get_running_loop() | ||
| embedding = await loop.run_in_executor( | ||
| None, | ||
| lambda: model_instance.encode(text.replace("\n", " "), convert_to_numpy=True).tolist() | ||
| ) | ||
|
|
||
| logger.debug(f"Generated embedding (dim={len(embedding)})") | ||
| return embedding | ||
| except Exception as e: | ||
| logger.exception("Error generating SentenceTransformer embedding") | ||
| raise | ||
|
|
||
| async def get_sentence_transformer_batch_embeddings( | ||
| texts: List[str], | ||
| model: Optional[str] = None, | ||
| timeout: float = 60.0 | ||
| ) -> List[List[float]]: | ||
| """ | ||
| Generate batch embeddings using SentenceTransformer. | ||
|
|
||
| Args: | ||
| texts: List of input texts. | ||
| model: Optional model name to override config. | ||
| timeout: Unused, for compatibility. | ||
|
|
||
| Returns: | ||
| List of embedding vectors. | ||
| """ | ||
| try: | ||
| model_instance = get_embedding_model() | ||
| cleaned_texts = [t.replace("\n", " ") for t in texts] | ||
|
|
||
| # Run the blocking encode operation in a thread pool | ||
| loop = asyncio.get_running_loop() | ||
| embeddings = await loop.run_in_executor( | ||
| None, | ||
| lambda: model_instance.encode(cleaned_texts, convert_to_numpy=True).tolist() | ||
| ) | ||
|
|
||
| logger.debug(f"Generated {len(embeddings)} embeddings (dim={len(embeddings[0])})") | ||
| return embeddings | ||
| except Exception as e: | ||
| logger.exception("Error generating batch embeddings with SentenceTransformer") | ||
| raise | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,3 +22,4 @@ feedparser>=6.0.1 | |
| httpx>=0.28.1 | ||
| huggingface_hub>=0.31.0 | ||
| seaborn>=0.13.0 | ||
| sentence-transformers>=4.1.0 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| # Sentence Transformers Embedding Framework | ||
|
|
||
| The `sentence_transformers` framework provides a unified interface for working with embedding and reranker models. It is used by the `db_load` tool to compute vector embeddings when inserting documents into the database. | ||
|
|
||
| We use the `all-MiniLM-L6-v2` model as the default embedding model, which offers a strong balance between speed and embedding quality for general-purpose use. The resulting vectors are 384-dimensional. | ||
|
|
||
| A wide range of models is supported through the framework. See the full list at [sentence-transformers on Hugging Face](https://huggingface.co/sentence-transformers). | ||
|
|
||
| **Note**: Embedding vector size is defined by the model. If you change models or providers and encounter a vector size mismatch error, you may need to delete your existing embeddings database and regenerate it using the `db_load` tool. | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Config key
sentence-transformersdoes not match the lookup inget_model_name()which usessentence_transformer; unify the provider identifier to ensure overrides work.