From 0ffacca6cef60d0e9e4664ded54e71a5100d7785 Mon Sep 17 00:00:00 2001 From: McTr0 <1334853459@qq.com> Date: Fri, 13 Mar 2026 22:07:02 +0800 Subject: [PATCH] fix(gemini): resolve embedding dimensions dynamically instead of hardcoding Previously ndims() was hardcoded to return 768 regardless of the actual model or user configuration. This caused dimension mismatches when using gemini-embedding-001 (native default: 3072) with downstream vector stores. Changes: - Add model_default_ndims() lookup (EMBEDDING_001=3072, EMBEDDING_004=768) - Change ndims field from Option<usize> to usize - Resolve dimensions in make(): user-specified > model default > 768 - ndims() now returns self.ndims directly (matching OpenAI pattern) BREAKING CHANGE: EmbeddingModel::new() and with_model() now take usize instead of Option<usize> for the ndims parameter. Closes #452 --- .../src/providers/gemini/embedding.rs | 79 +++++++++++++++++-- 1 file changed, 72 insertions(+), 7 deletions(-) diff --git a/rig/rig-core/src/providers/gemini/embedding.rs b/rig/rig-core/src/providers/gemini/embedding.rs index 86a8bae7b..4a6b242d9 100644 --- a/rig/rig-core/src/providers/gemini/embedding.rs +++ b/rig/rig-core/src/providers/gemini/embedding.rs @@ -12,20 +12,31 @@ use crate::{ wasm_compat::WasmCompatSend, }; -/// `embedding-001` embedding model +/// `gemini-embedding-001` embedding model (3072 dimensions by default) pub const EMBEDDING_001: &str = "gemini-embedding-001"; -/// `text-embedding-004` embedding model +/// `text-embedding-004` embedding model (768 dimensions by default) pub const EMBEDDING_004: &str = "text-embedding-004"; +/// Returns the default output dimensionality for known Gemini embedding models. 
+/// +/// See +fn model_default_ndims(model: &str) -> Option { + match model { + EMBEDDING_001 => Some(3072), + EMBEDDING_004 => Some(768), + _ => None, + } +} + #[derive(Clone)] pub struct EmbeddingModel { client: Client, model: String, - ndims: Option, + ndims: usize, } impl EmbeddingModel { - pub fn new(client: Client, model: impl Into, ndims: Option) -> Self { + pub fn new(client: Client, model: impl Into, ndims: usize) -> Self { Self { client, model: model.into(), @@ -33,7 +44,7 @@ impl EmbeddingModel { } } - pub fn with_model(client: Client, model: &str, ndims: Option) -> Self { + pub fn with_model(client: Client, model: &str, ndims: usize) -> Self { Self { client, model: model.to_string(), @@ -51,11 +62,13 @@ where const MAX_DOCUMENTS: usize = 1024; fn make(client: &Self::Client, model: impl Into, dims: Option) -> Self { - Self::new(client.clone(), model, dims) + let model = model.into(); + let ndims = dims.or_else(|| model_default_ndims(&model)).unwrap_or(768); + Self::new(client.clone(), model, ndims) } fn ndims(&self) -> usize { - 768 + self.ndims } /// @@ -238,3 +251,55 @@ mod gemini_api_types { pub values: Vec, } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_model_default_ndims_lookup() { + assert_eq!(model_default_ndims(EMBEDDING_001), Some(3072)); + assert_eq!(model_default_ndims(EMBEDDING_004), Some(768)); + assert_eq!(model_default_ndims("unknown-model"), None); + } + + #[test] + fn test_make_resolves_default_dims() { + let client = Client::new("test_key").unwrap(); + + // EMBEDDING_001 defaults to 3072 + let model = + ::make(&client, EMBEDDING_001, None); + assert_eq!(embeddings::EmbeddingModel::ndims(&model), 3072); + + // EMBEDDING_004 defaults to 768 + let model = + ::make(&client, EMBEDDING_004, None); + assert_eq!(embeddings::EmbeddingModel::ndims(&model), 768); + + // Unknown model falls back to 768 + let model = ::make( + &client, + "some-future-model", + None, + ); + 
assert_eq!(embeddings::EmbeddingModel::ndims(&model), 768); + } + + #[test] + fn test_make_respects_explicit_dims() { + let client = Client::new("test_key").unwrap(); + + let model = + ::make(&client, EMBEDDING_001, Some(256)); + assert_eq!(embeddings::EmbeddingModel::ndims(&model), 256); + } + + #[test] + fn test_new_uses_provided_ndims() { + let client = Client::new("test_key").unwrap(); + + let model = EmbeddingModel::new(client, EMBEDDING_001, 512); + assert_eq!(embeddings::EmbeddingModel::ndims(&model), 512); + } +}