diff --git a/Cargo.lock b/Cargo.lock index bf9a66447..8156837d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1453,6 +1453,15 @@ dependencies = [ "tokio", ] +[[package]] +name = "backon" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cffb0e931875b666fc4fcb20fee52e9bbd1ef836fd9e9e04ec21555f9f85f7ef" +dependencies = [ + "fastrand", +] + [[package]] name = "base16ct" version = "0.2.0" @@ -9935,6 +9944,30 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "redis" +version = "0.27.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09d8f99a4090c89cc489a94833c901ead69bfbf3877b4867d5482e321ee875bc" +dependencies = [ + "arc-swap", + "async-trait", + "backon", + "bytes", + "combine", + "futures", + "futures-util", + "itertools 0.13.0", + "itoa", + "num-bigint", + "percent-encoding", + "pin-project-lite", + "ryu", + "tokio", + "tokio-util", + "url", +] + [[package]] name = "redis" version = "1.0.3" @@ -10401,7 +10434,7 @@ dependencies = [ "pin-project-lite", "quick-xml", "rayon", - "redis", + "redis 1.0.3", "reqwest 0.13.2", "reqwest-middleware", "reqwest-retry", @@ -10605,6 +10638,22 @@ dependencies = [ "uuid 1.20.0", ] +[[package]] +name = "rig-redis" +version = "0.1.0" +dependencies = [ + "anyhow", + "httpmock", + "redis 0.27.6", + "rig-core 0.31.0", + "serde", + "serde_json", + "testcontainers", + "tokio", + "tracing-subscriber", + "uuid 1.20.0", +] + [[package]] name = "rig-s3vectors" version = "0.1.20" diff --git a/Cargo.toml b/Cargo.toml index 658ee8156..309c60b76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ qdrant-client = { version = "1.14.0", default-features = false, features = [ quick-xml = "0.38.0" quote = "1.0.40" rayon = "1.10.0" +redis = { version = "0.27", default-features = false } reqwest = { version = "0.13", default-features = false } url = "2.5" rusqlite = "0.32" diff --git a/rig-integrations/rig-redis/Cargo.toml b/rig-integrations/rig-redis/Cargo.toml new file mode 100644 index 000000000..2c1f84fe1 --- /dev/null +++ b/rig-integrations/rig-redis/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "rig-redis" +version = "0.1.0" +edition = { workspace = true } +license = "MIT" +readme = "README.md" +description = "Redis vector store implementation for the rig framework" +repository = "https://github.com/0xPlaygrounds/rig" + +[lints] +workspace = true + +[dependencies] +rig-core = { path = "../../rig/rig-core", version = "0.31.0", default-features = false } +redis = { workspace = true, features = ["tokio-comp", "connection-manager"] } +serde = { workspace = true } +serde_json = { workspace = true } +uuid = { workspace = true, features = ["v4"] } + +[dev-dependencies] +rig-core = { path = "../../rig/rig-core", version = "0.31.0", features = ["derive"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } +anyhow = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } +testcontainers = { workspace = true } +httpmock = { workspace = true } + +[[example]] +name = "vector_search_redis" +required-features = ["rig-core/derive"] diff --git a/rig-integrations/rig-redis/README.md b/rig-integrations/rig-redis/README.md new file mode 100644 index 000000000..cabef2c91 --- /dev/null +++ b/rig-integrations/rig-redis/README.md @@ -0,0 +1,136 @@ +# Rig-Redis + +Vector store index integration for [Redis](https://redis.io/) using RediSearch vector similarity search. 
This integration supports dense vector retrieval using Rig's embedding providers and leverages Redis's FT.SEARCH command with KNN queries for efficient similarity search. + +## Features + +- Vector similarity search using Redis's RediSearch module +- Support for KNN (k-nearest neighbors) queries +- Metadata filtering with Redis query syntax +- Document insertion with automatic embedding storage +- Compatible with Redis 7.2+ or Redis Stack + +## Prerequisites + +You need a Redis instance with RediSearch module enabled. This can be: +- [Redis Stack](https://redis.io/docs/stack/) +- Redis 7.2+ with RediSearch module loaded +- Redis Cloud with RediSearch enabled + +## Creating a Vector Index + +Before using the vector store, you need to create a RediSearch index with a vector field. Here's an example using redis-cli: + +```bash +FT.CREATE word_idx + ON HASH + PREFIX 1 doc: + SCHEMA + document TEXT + embedded_text TEXT + embedding VECTOR FLAT 6 + TYPE FLOAT32 + DIM 1536 + DISTANCE_METRIC COSINE +``` + +Replace `1536` with your embedding model's dimensionality. + +## Usage Example + +```rust +use rig::providers::openai; +use rig::vector_store::{InsertDocuments, VectorStoreIndex}; +use rig_redis::RedisVectorStore; + +// Create embedding model +let openai_client = openai::Client::from_env(); +let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_3_SMALL); + +// Create Redis client +let redis_client = redis::Client::open("redis://127.0.0.1:6379")?; + +// Create vector store +let vector_store = RedisVectorStore::new( + model, + redis_client, + "word_idx".to_string(), // index name + "embedding".to_string(), // vector field name +); + +// Insert documents +vector_store.insert_documents(documents).await?; + +// Search +let results = vector_store + .top_n::( + VectorSearchRequest::builder() + .query("your search query") + .samples(5) + .build()? + ) + .await?; +``` + +You can find complete examples [here](https://github.com/0xPlaygrounds/rig/tree/main/rig-integrations/rig-redis/examples). + +## Distance Metrics + +Redis supports three distance metrics: +- **COSINE** - Cosine similarity (default, recommended) +- **L2** - Euclidean distance +- **IP** - Inner product + +Choose the metric that matches your embedding model when creating the index. + +## Limitations + +- Requires pre-created RediSearch index +- Vector dimensionality must match the index definition +- Embeddings are stored as FLOAT32 (converted from FLOAT64) + +## Testing + +### Prerequisites + +Integration tests require Docker to be running, as they use testcontainers to spin up a Redis Stack instance. 
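The tests first look for a Redis instance already listening on `localhost:6379` and reuse it if found; otherwise they start a `redis/redis-stack` container via testcontainers. A quick way to sanity-check your environment beforehand (the `MODULE LIST` step only applies if you want the tests to reuse a local instance):

```bash
# Docker must be reachable for testcontainers to start Redis Stack
docker info

# Optional: if a local Redis is already running, confirm the search module is loaded
redis-cli -h 127.0.0.1 -p 6379 MODULE LIST
```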
+ +### Running Tests + +```bash +# Run all tests (unit + integration) +cargo test + +# Run only unit tests +cargo test --lib + +# Run only integration tests +cargo test --test integration_tests + +# Or use the Makefile +make test # All tests +make test-unit # Unit tests only +make test-integration # Integration tests only +``` + +### Manual Testing with Local Redis + +You can start a local Redis Stack instance for manual testing: + +```bash +# Start Redis Stack +make redis-local +# or +docker run -d --name redis-stack -p 6379:6379 redis/redis-stack:latest + +# Create a test index +redis-cli FT.CREATE word_idx ON HASH SCHEMA document TEXT embedded_text TEXT embedding VECTOR FLAT 6 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE + +# Run the example +make run-example +# or +cargo run --example vector_search_redis + +# Stop Redis Stack +make redis-stop +``` diff --git a/rig-integrations/rig-redis/examples/vector_search_redis.rs b/rig-integrations/rig-redis/examples/vector_search_redis.rs new file mode 100644 index 000000000..aaa244658 --- /dev/null +++ b/rig-integrations/rig-redis/examples/vector_search_redis.rs @@ -0,0 +1,82 @@ +use rig::client::ProviderClient; +use rig::vector_store::InsertDocuments; +use rig::vector_store::request::VectorSearchRequest; +use rig::{ + Embed, client::EmbeddingsClient, embeddings::EmbeddingsBuilder, vector_store::VectorStoreIndex, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Embed, Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Default)] +struct WordDefinition { + word: String, + #[serde(skip)] + #[embed] + definition: String, +} + +impl std::fmt::Display for WordDefinition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.word) + } +} + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Create OpenAI client + let openai_client = rig::providers::openai::Client::from_env(); + let model = openai_client.embedding_model(rig::providers::openai::TEXT_EMBEDDING_3_SMALL); + + let redis_url = + std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string()); + let redis_client = redis::Client::open(redis_url)?; + + let vector_store = rig_redis::RedisVectorStore::new( + model.clone(), + redis_client, + "word_idx".to_string(), + "embedding".to_string(), + ); + + // Create test documents with embeddings + let words = vec![ + WordDefinition { + word: "flurbo".to_string(), + definition: "1. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + }, + WordDefinition { + word: "glarb-glarb".to_string(), + definition: "1. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + }, + WordDefinition { + word: "linglingdong".to_string(), + definition: "1. 
*linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + } + ]; + + let documents = EmbeddingsBuilder::new(model.clone()) + .documents(words) + .unwrap() + .build() + .await + .expect("Failed to create embeddings"); + + vector_store.insert_documents(documents).await?; + + // Query vector store + let query = "What does \"glarb-glarb\" mean?"; + + let req = VectorSearchRequest::builder() + .query(query) + .samples(2) + .build() + .expect("VectorSearchRequest should not fail to build here"); + + let results = vector_store.top_n::(req).await?; + + println!("#{} results for query: {}", results.len(), query); + for (score, _id, doc) in results.iter() { + println!("Result score {score} for word: {doc}"); + } + + Ok(()) +} diff --git a/rig-integrations/rig-redis/src/filter.rs b/rig-integrations/rig-redis/src/filter.rs new file mode 100644 index 000000000..efc055336 --- /dev/null +++ b/rig-integrations/rig-redis/src/filter.rs @@ -0,0 +1,183 @@ +use rig::vector_store::request::{Filter as CoreFilter, FilterError, SearchFilter}; +use serde::{Deserialize, Serialize}; + +/// Redis filter value type#[derive(Debug, Clone, PartialEq)] +pub enum RedisValue { + Number(f64), + String(String), + Bool(bool), +} + +impl RedisValue { + fn to_redis_expr(&self) -> String { + match self { + RedisValue::Number(n) => n.to_string(), + RedisValue::String(s) => format!("{{{}}}", s), + RedisValue::Bool(b) => { + if *b { + "1".to_string() + } else { + "0".to_string() + } + } + } + } +} + +impl From for RedisValue { + fn from(value: i64) -> Self { + Self::Number(value as f64) + } +} + +impl From for RedisValue { + fn from(value: u64) -> Self { + Self::Number(value as f64) + } +} + +impl From for RedisValue { + fn from(value: f64) -> Self { + Self::Number(value) + } +} + +impl From for RedisValue { + fn from(value: bool) -> Self { + Self::Bool(value) + } +} + +impl From for RedisValue { + fn from(value: String) -> Self { + Self::String(value) + } +} + +impl TryFrom for RedisValue { + type Error = FilterError; + + fn try_from(value: serde_json::Value) -> Result { + match value { + serde_json::Value::Bool(b) => Ok(RedisValue::Bool(b)), + serde_json::Value::Number(n) => { + let num = n.as_f64().ok_or_else(|| FilterError::Expected { + expected: "Valid 64-bit float".into(), + got: "Invalid 64-bit float".into(), + })?; + Ok(RedisValue::Number(num)) + } + serde_json::Value::String(s) => Ok(RedisValue::String(s)), + serde_json::Value::Null + | serde_json::Value::Array(_) + | serde_json::Value::Object(_) => Err(FilterError::TypeError( + "Redis filter does not currently support null values, arrays or objects".into(), + )), + } + } +} + +/// Redis filter for FT.SEARCH queries +/// +/// Redis uses a query syntax like: `@field:[min max]` for numeric ranges, +/// `@field:{value}` for tags, etc. 
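///
/// For example, `Filter::eq("category", RedisValue::String("electronics".into()))` renders as
/// `@category:{electronics}`, and `Filter::gt("price", RedisValue::Number(50.0))` renders as
/// `@price:[(50 +inf]`. The `and`/`or` combinators wrap their operands in parentheses, so
/// `a.and(b)` becomes `(a b)` and `a.or(b)` becomes `(a | b)`.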
+#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Filter(String); + +impl SearchFilter for Filter { + type Value = RedisValue; + + fn eq(key: impl AsRef, value: Self::Value) -> Self { + Self(format!("@{}:{}", key.as_ref(), value.to_redis_expr())) + } + + fn gt(key: impl AsRef, value: Self::Value) -> Self { + match value { + RedisValue::Number(n) => Self(format!("@{}:[({} +inf]", key.as_ref(), n)), + _ => Self(format!("@{}:{}", key.as_ref(), value.to_redis_expr())), + } + } + + fn lt(key: impl AsRef, value: Self::Value) -> Self { + match value { + RedisValue::Number(n) => Self(format!("@{}:[-inf ({}]", key.as_ref(), n)), + _ => Self(format!("@{}:{}", key.as_ref(), value.to_redis_expr())), + } + } + + fn and(self, rhs: Self) -> Self { + Self(format!("({} {})", self.0, rhs.0)) + } + + fn or(self, rhs: Self) -> Self { + Self(format!("({} | {})", self.0, rhs.0)) + } +} + +impl Filter { + #[allow(clippy::should_implement_trait)] + pub fn not(self) -> Self { + Self(format!("-{}", self.0)) + } + + /// Greater than or equal + pub fn gte(key: String, value: ::Value) -> Self { + match value { + RedisValue::Number(n) => Self(format!("@{}:[{} +inf]", key, n)), + _ => Self(format!("@{}:{}", key, value.to_redis_expr())), + } + } + + /// Less than or equal + pub fn lte(key: String, value: ::Value) -> Self { + match value { + RedisValue::Number(n) => Self(format!("@{}:[-inf {}]", key, n)), + _ => Self(format!("@{}:{}", key, value.to_redis_expr())), + } + } + + /// Range filter (inclusive) + pub fn range(key: String, min: f64, max: f64) -> Self { + Self(format!("@{}:[{} {}]", key, min, max)) + } + + /// Range filter (exclusive) + pub fn range_exclusive(key: String, min: f64, max: f64) -> Self { + Self(format!("@{}:[({} ({}]", key, min, max)) + } + + /// Tag filter for multiple values (OR) + pub fn tag_in(key: String, values: Vec) -> Self { + let tags = values + .into_iter() + .map(|v| format!("{{{}}}", v)) + .collect::>() + .join("|"); + Self(format!("@{}:{}", key, tags)) + } + + /// Text search in field + pub fn text_contains(key: String, text: String) -> Self { + Self(format!("@{}:{}", key, text)) + } + + pub fn into_inner(self) -> String { + self.0 + } +} + +impl TryFrom> for Filter { + type Error = FilterError; + + fn try_from(value: CoreFilter) -> Result { + let filter = match value { + CoreFilter::Eq(k, val) => Filter::eq(k, val.try_into()?), + CoreFilter::Gt(k, val) => Filter::gt(k, val.try_into()?), + CoreFilter::Lt(k, val) => Filter::lt(k, val.try_into()?), + CoreFilter::And(l, r) => Self::try_from(*l)?.and(Self::try_from(*r)?), + CoreFilter::Or(l, r) => Self::try_from(*l)?.or(Self::try_from(*r)?), + }; + + Ok(filter) + } +} diff --git a/rig-integrations/rig-redis/src/lib.rs b/rig-integrations/rig-redis/src/lib.rs new file mode 100644 index 000000000..ed1165c74 --- /dev/null +++ b/rig-integrations/rig-redis/src/lib.rs @@ -0,0 +1,359 @@ +pub mod filter; + +pub use filter::Filter; +use redis::{AsyncCommands, Client}; +use rig::{ + Embed, OneOrMany, + embeddings::{Embedding, EmbeddingModel}, + vector_store::{ + InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn, + request::{Filter as CoreFilter, VectorSearchRequest}, + }, + wasm_compat::WasmBoxedFuture, +}; +use serde::{Deserialize, Serialize}; + +/// Redis vector store implementation using RediSearch vector similarity search. +/// +/// This implementation uses Redis's FT.SEARCH command with KNN vector queries +/// for similarity search operations. 
+pub struct RedisVectorStore +where + M: EmbeddingModel, +{ + /// Model used to generate embeddings for queries + model: M, + /// Redis client + client: Client, + /// Name of the RediSearch index + index_name: String, + /// Name of the vector field in the index + vector_field: String, + /// Optional key prefix for document keys + key_prefix: Option, +} + +impl RedisVectorStore +where + M: EmbeddingModel, +{ + /// Creates a new Redis vector store instance. + /// + /// # Arguments + /// * `model` - Embedding model for query vectorization + /// * `client` - Redis client instance + /// * `index_name` - Name of the RediSearch index to query + /// * `vector_field` - Name of the vector field in the index (default: "embedding") + pub fn new(model: M, client: Client, index_name: String, vector_field: String) -> Self { + Self { + model, + client, + index_name, + vector_field, + key_prefix: None, + } + } + + /// Sets a key prefix for document keys + pub fn with_key_prefix(mut self, prefix: String) -> Self { + self.key_prefix = Some(prefix); + self + } + + /// Converts embedding vector to bytes for Redis + fn embedding_to_bytes(embedding: &[f64]) -> Vec { + embedding + .iter() + .flat_map(|&x| (x as f32).to_le_bytes()) + .collect() + } + + /// Extracts string value from Redis value + fn extract_string(value: &redis::Value) -> Option { + match value { + redis::Value::BulkString(bytes) => Some(String::from_utf8_lossy(bytes).to_string()), + redis::Value::SimpleString(s) => Some(s.clone()), + _ => None, + } + } + + /// Extracts score from Redis value + fn extract_score(value: &redis::Value) -> f64 { + match value { + redis::Value::BulkString(bytes) => { + String::from_utf8_lossy(bytes).parse().unwrap_or(0.0) + } + redis::Value::SimpleString(s) => s.parse().unwrap_or(0.0), + _ => 0.0, + } + } + + /// Parses FT.SEARCH response into results with documents + fn parse_search_response( + response: redis::Value, + ) -> Result, VectorStoreError> + where + T: for<'a> Deserialize<'a>, + { + Self::parse_response_generic(response, true).and_then(|items| { + items + .into_iter() + .map(|(score, id, doc_json)| { + let doc = serde_json::from_str::(&doc_json)?; + Ok((score, id, doc)) + }) + .collect() + }) + } + + /// Parses FT.SEARCH response for IDs only + fn parse_search_response_ids( + response: redis::Value, + ) -> Result, VectorStoreError> { + Self::parse_response_generic(response, false).map(|items| { + items + .into_iter() + .map(|(score, id, _)| (score, id)) + .collect() + }) + } + + /// Generic response parser for both full and ID-only results + fn parse_response_generic( + response: redis::Value, + include_document: bool, + ) -> Result, VectorStoreError> { + match response { + redis::Value::Array(ref items) if !items.is_empty() => { + let count = match &items[0] { + redis::Value::Int(n) => *n as usize, + _ => { + return Err(VectorStoreError::DatastoreError( + "Invalid response format: expected count as first element".into(), + )); + } + }; + + if count == 0 { + return Ok(Vec::new()); + } + + let mut results = Vec::new(); + + for chunk in items[1..].chunks(2) { + if chunk.len() != 2 { + continue; + } + + let id = match Self::extract_string(&chunk[0]) { + Some(id) => id, + None => continue, + }; + + if let redis::Value::Array(fields) = &chunk[1] { + let mut score = 0.0; + let mut document_json = String::new(); + + for field_chunk in fields.chunks(2) { + if field_chunk.len() != 2 { + continue; + } + + let field_name = match Self::extract_string(&field_chunk[0]) { + Some(name) => name, + None => continue, + }; + + 
if field_name == "__vector_score" { + score = Self::extract_score(&field_chunk[1]); + if !include_document { + break; + } + } else if include_document + && field_name == "document" + && let Some(json) = Self::extract_string(&field_chunk[1]) + { + document_json = json; + } + } + + results.push((score, id, document_json)); + } + } + + Ok(results) + } + _ => Err(VectorStoreError::DatastoreError( + "Invalid FT.SEARCH response format".into(), + )), + } + } + + /// Builds and executes FT.SEARCH command, optionally including document field + async fn execute_search( + &self, + vector_bytes: Vec, + req: &VectorSearchRequest, + include_document: bool, + ) -> Result { + let mut con = self + .client + .get_multiplexed_async_connection() + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + let filter_str = req + .filter() + .as_ref() + .map(|f| f.clone().into_inner()) + .unwrap_or_else(|| "*".to_string()); + + let knn_query = format!( + "{}=>[KNN {} @{} $vec AS __vector_score]", + filter_str, + req.samples(), + self.vector_field + ); + + let mut cmd = redis::cmd("FT.SEARCH"); + cmd.arg(&self.index_name) + .arg(&knn_query) + .arg("PARAMS") + .arg(2) + .arg("vec") + .arg(vector_bytes) + .arg("SORTBY") + .arg("__vector_score") + .arg("RETURN"); + + if include_document { + cmd.arg(2).arg("__vector_score").arg("document"); + } else { + cmd.arg(1).arg("__vector_score"); + } + + cmd.arg("DIALECT").arg(2); + + if req.threshold().is_some() { + cmd.arg("LIMIT").arg(0).arg(req.samples()); + } + + cmd.query_async(&mut con) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e))) + } +} + +impl InsertDocuments for RedisVectorStore +where + Model: EmbeddingModel + Send + Sync, +{ + async fn insert_documents( + &self, + documents: Vec<(Doc, OneOrMany)>, + ) -> Result<(), VectorStoreError> { + let mut con = self + .client + .get_multiplexed_async_connection() + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + for (document, embeddings) in documents { + let json_document = serde_json::to_string(&document)?; + + for embedding in embeddings.into_iter() { + let id = if let Some(ref prefix) = self.key_prefix { + format!("{}{}", prefix, uuid::Uuid::new_v4()) + } else { + uuid::Uuid::new_v4().to_string() + }; + let embedding_bytes = Self::embedding_to_bytes(&embedding.vec); + + con.hset_multiple::<_, _, _, ()>( + &id, + &[ + ("document", json_document.as_bytes()), + ("embedded_text", embedding.document.as_bytes()), + (&self.vector_field, &embedding_bytes), + ], + ) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + } + } + + Ok(()) + } +} + +impl VectorStoreIndex for RedisVectorStore +where + M: EmbeddingModel + Send + Sync, +{ + type Filter = Filter; + + async fn top_n Deserialize<'a> + Send>( + &self, + req: VectorSearchRequest, + ) -> Result, VectorStoreError> { + let embedding = self.model.embed_text(req.query()).await?; + let vector_bytes = Self::embedding_to_bytes(&embedding.vec); + + let response = self.execute_search(vector_bytes, &req, true).await?; + + let mut results = Self::parse_search_response(response)?; + + if let Some(threshold) = req.threshold() { + results.retain(|(score, _, _)| *score >= threshold); + } + + Ok(results) + } + + async fn top_n_ids( + &self, + req: VectorSearchRequest, + ) -> Result, VectorStoreError> { + let embedding = self.model.embed_text(req.query()).await?; + let vector_bytes = Self::embedding_to_bytes(&embedding.vec); + + let response = self.execute_search(vector_bytes, &req, false).await?; + + 
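        // Only (distance, key) pairs are parsed here: the search above asked Redis to
        // RETURN just the __vector_score field, so no document JSON is present.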
let mut results = Self::parse_search_response_ids(response)?; + + if let Some(threshold) = req.threshold() { + results.retain(|(score, _)| *score >= threshold); + } + + Ok(results) + } +} + +impl VectorStoreIndexDyn for RedisVectorStore +where + M: EmbeddingModel + Sync + Send, +{ + fn top_n<'a>( + &'a self, + req: VectorSearchRequest>, + ) -> WasmBoxedFuture<'a, TopNResults> { + Box::pin(async move { + let req = req.try_map_filter(Filter::try_from)?; + let results = ::top_n::(self, req).await?; + + Ok(results) + }) + } + + fn top_n_ids<'a>( + &'a self, + req: VectorSearchRequest>, + ) -> WasmBoxedFuture<'a, Result, VectorStoreError>> { + Box::pin(async move { + let req = req.try_map_filter(Filter::try_from)?; + let results = ::top_n_ids(self, req).await?; + + Ok(results) + }) + } +} diff --git a/rig-integrations/rig-redis/tests/filter_tests.rs b/rig-integrations/rig-redis/tests/filter_tests.rs new file mode 100644 index 000000000..9d260f5f6 --- /dev/null +++ b/rig-integrations/rig-redis/tests/filter_tests.rs @@ -0,0 +1,146 @@ +use rig::vector_store::request::{Filter as CoreFilter, SearchFilter}; +use rig_redis::filter::{Filter, RedisValue}; + +#[test] +fn test_filter_eq_string() { + let filter = Filter::eq("category", RedisValue::String("electronics".to_string())); + assert_eq!(filter.into_inner(), "@category:{electronics}"); +} + +#[test] +fn test_filter_eq_number() { + let filter = Filter::eq("price", RedisValue::Number(99.99)); + assert_eq!(filter.into_inner(), "@price:99.99"); +} + +#[test] +fn test_filter_gt() { + let filter = Filter::gt("price", RedisValue::Number(50.0)); + assert_eq!(filter.into_inner(), "@price:[(50 +inf]"); +} + +#[test] +fn test_filter_lt() { + let filter = Filter::lt("price", RedisValue::Number(100.0)); + assert_eq!(filter.into_inner(), "@price:[-inf (100]"); +} + +#[test] +fn test_filter_gte() { + let filter = Filter::gte("price".to_string(), RedisValue::Number(50.0)); + assert_eq!(filter.into_inner(), "@price:[50 +inf]"); +} + +#[test] +fn test_filter_lte() { + let filter = Filter::lte("price".to_string(), RedisValue::Number(100.0)); + assert_eq!(filter.into_inner(), "@price:[-inf 100]"); +} + +#[test] +fn test_filter_range() { + let filter = Filter::range("price".to_string(), 50.0, 100.0); + assert_eq!(filter.into_inner(), "@price:[50 100]"); +} + +#[test] +fn test_filter_range_exclusive() { + let filter = Filter::range_exclusive("price".to_string(), 50.0, 100.0); + assert_eq!(filter.into_inner(), "@price:[(50 (100]"); +} + +#[test] +fn test_filter_and() { + let filter1 = Filter::eq("category", RedisValue::String("electronics".to_string())); + let filter2 = Filter::gt("price", RedisValue::Number(50.0)); + let combined = filter1.and(filter2); + assert_eq!( + combined.into_inner(), + "(@category:{electronics} @price:[(50 +inf])" + ); +} + +#[test] +fn test_filter_or() { + let filter1 = Filter::eq("category", RedisValue::String("electronics".to_string())); + let filter2 = Filter::eq("category", RedisValue::String("books".to_string())); + let combined = filter1.or(filter2); + assert_eq!( + combined.into_inner(), + "(@category:{electronics} | @category:{books})" + ); +} + +#[test] +fn test_filter_not() { + let filter = Filter::eq("category", RedisValue::String("electronics".to_string())); + let negated = filter.not(); + assert_eq!(negated.into_inner(), "-@category:{electronics}"); +} + +#[test] +fn test_filter_tag_in() { + let filter = Filter::tag_in( + "tags".to_string(), + vec!["new".to_string(), "sale".to_string()], + ); + assert_eq!(filter.into_inner(), 
"@tags:{new}|{sale}"); +} + +#[test] +fn test_filter_text_contains() { + let filter = Filter::text_contains("description".to_string(), "laptop".to_string()); + assert_eq!(filter.into_inner(), "@description:laptop"); +} + +#[test] +fn test_complex_filter() { + let category_filter = Filter::eq("category", RedisValue::String("electronics".to_string())); + let price_min = Filter::gte("price".to_string(), RedisValue::Number(50.0)); + let price_max = Filter::lte("price".to_string(), RedisValue::Number(200.0)); + + let combined = category_filter.and(price_min).and(price_max); + + assert_eq!( + combined.into_inner(), + "((@category:{electronics} @price:[50 +inf]) @price:[-inf 200])" + ); +} + +#[test] +fn test_core_filter_conversion() { + let core_filter: CoreFilter = + CoreFilter::eq("category", serde_json::json!("electronics")); + + let redis_filter = Filter::try_from(core_filter).unwrap(); + assert_eq!(redis_filter.into_inner(), "@category:{electronics}"); +} + +#[test] +fn test_core_filter_gt_conversion() { + let core_filter: CoreFilter = + CoreFilter::gt("price", serde_json::json!(50.0)); + + let redis_filter = Filter::try_from(core_filter).unwrap(); + assert_eq!(redis_filter.into_inner(), "@price:[(50 +inf]"); +} + +#[test] +fn test_core_filter_and_conversion() { + let filter1: CoreFilter = + CoreFilter::eq("category", serde_json::json!("electronics")); + let filter2: CoreFilter = CoreFilter::gt("price", serde_json::json!(50.0)); + let combined = CoreFilter::and(filter1, filter2); + + let redis_filter = Filter::try_from(combined).unwrap(); + assert_eq!( + redis_filter.into_inner(), + "(@category:{electronics} @price:[(50 +inf])" + ); +} + +#[test] +fn test_redis_value_bool() { + let filter = Filter::eq("in_stock", RedisValue::Bool(true)); + assert_eq!(filter.into_inner(), "@in_stock:1"); +} diff --git a/rig-integrations/rig-redis/tests/integration_tests.rs b/rig-integrations/rig-redis/tests/integration_tests.rs new file mode 100644 index 000000000..aee67e8d1 --- /dev/null +++ b/rig-integrations/rig-redis/tests/integration_tests.rs @@ -0,0 +1,624 @@ +use rig::client::EmbeddingsClient; +use rig::{ + Embed, + embeddings::EmbeddingsBuilder, + providers::openai, + vector_store::{InsertDocuments, VectorStoreIndex, request::VectorSearchRequest}, +}; +use rig_redis::RedisVectorStore; +use serde_json::json; +use testcontainers::{ + GenericImage, + core::{IntoContainerPort, WaitFor}, + runners::AsyncRunner, +}; +use tokio::time::{Duration, sleep}; + +const REDIS_PORT: u16 = 6379; +const VECTOR_FIELD: &str = "embedding"; + +#[derive(Embed, Clone, serde::Deserialize, serde::Serialize, Debug, PartialEq)] +struct Word { + id: String, + #[embed] + definition: String, +} + +/// Check if Redis is already running on localhost:6379 +async fn is_redis_running() -> bool { + match redis::Client::open("redis://127.0.0.1:6379") { + Ok(client) => client.get_multiplexed_async_connection().await.is_ok(), + Err(_) => false, + } +} + +/// Get Redis connection info - either from existing instance or new container +async fn get_redis_connection() -> ( + String, + u16, + Option>, +) { + if is_redis_running().await { + println!("Using existing Redis instance on localhost:6379"); + ("127.0.0.1".to_string(), REDIS_PORT, None) + } else { + println!("Starting new Redis Stack container"); + let container = GenericImage::new("redis/redis-stack", "latest") + .with_exposed_port(REDIS_PORT.tcp()) + .with_wait_for(WaitFor::Duration { + length: std::time::Duration::from_secs(3), + }) + .start() + .await + .expect("Failed to start 
Redis Stack container"); + + let port = container.get_host_port_ipv4(REDIS_PORT).await.unwrap(); + let host = container.get_host().await.unwrap().to_string(); + + (host, port, Some(container)) + } +} + +async fn setup_redis_index( + client: &redis::Client, + index_name: &str, + dimensions: usize, +) -> Result<(), Box> { + let mut con = client.get_multiplexed_async_connection().await?; + + // Drop existing index if it exists (DD flag deletes associated documents) + let _: Result = redis::cmd("FT.DROPINDEX") + .arg(index_name) + .arg("DD") + .query_async(&mut con) + .await; + + // Create vector index with PREFIX to associate documents with this index + let prefix = format!("{index_name}:"); + let _: String = redis::cmd("FT.CREATE") + .arg(index_name) + .arg("ON") + .arg("HASH") + .arg("PREFIX") + .arg(1) + .arg(&prefix) + .arg("SCHEMA") + .arg("document") + .arg("TEXT") + .arg("embedded_text") + .arg("TEXT") + .arg(VECTOR_FIELD) + .arg("VECTOR") + .arg("FLAT") + .arg(6) + .arg("TYPE") + .arg("FLOAT32") + .arg("DIM") + .arg(dimensions) + .arg("DISTANCE_METRIC") + .arg("COSINE") + .query_async(&mut con) + .await?; + + // Wait for index to be ready + sleep(Duration::from_millis(1000)).await; + + Ok(()) +} + +async fn cleanup_redis_index( + client: &redis::Client, + index_name: &str, +) -> Result<(), Box> { + let mut con = client.get_multiplexed_async_connection().await?; + + // Drop index and associated documents + let _: Result = redis::cmd("FT.DROPINDEX") + .arg(index_name) + .arg("DD") + .query_async(&mut con) + .await; + + // Wait for cleanup to complete + sleep(Duration::from_millis(100)).await; + + Ok(()) +} + +#[tokio::test] +async fn test_vector_search_basic() { + let (host, port, _container) = get_redis_connection().await; + let index_name = "test_vector_search_basic"; + + let server = httpmock::MockServer::start(); + + server.mock(|when, then| { + when.method(httpmock::Method::POST) + .path("/embeddings") + .header("Authorization", "Bearer TEST") + .json_body(json!({ + "input": [ + "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets", + "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.", + "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans." 
+ ], + "model": "text-embedding-ada-002", + })); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "object": "list", + "data": [ + {"object": "embedding", "embedding": (0..1536).map(|i| if i < 512 { 1.0 } else { 0.0 }).collect::>(), "index": 0}, + {"object": "embedding", "embedding": (0..1536).map(|i| if (512..1024).contains(&i) { 1.0 } else { 0.0 }).collect::>(), "index": 1}, + {"object": "embedding", "embedding": (0..1536).map(|i| if i >= 1024 { 1.0 } else { 0.0 }).collect::>(), "index": 2} + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 8, "total_tokens": 8} + })); + }); + + server.mock(|when, then| { + when.method(httpmock::Method::POST) + .path("/embeddings") + .header("Authorization", "Bearer TEST") + .json_body(json!({ + "input": ["What is a linglingdong?"], + "model": "text-embedding-ada-002", + })); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "object": "list", + "data": [ + {"object": "embedding", "embedding": (0..1536).map(|i| if i >= 1024 { 1.0 } else { 0.0 }).collect::>(), "index": 0} + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 8, "total_tokens": 8} + })); + }); + + let openai_client = openai::Client::builder() + .api_key("TEST") + .base_url(server.base_url()) + .build() + .unwrap(); + + let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); + + let redis_url = format!("redis://{host}:{port}"); + let redis_client = redis::Client::open(redis_url).unwrap(); + + setup_redis_index(&redis_client, index_name, 1536) + .await + .unwrap(); + + let vector_store = RedisVectorStore::new( + model.clone(), + redis_client.clone(), + index_name.to_string(), + VECTOR_FIELD.to_string(), + ) + .with_key_prefix(format!("{}:", index_name)); + + let words = vec![ + Word { + id: "doc0".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + }, + Word { + id: "doc1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + }, + Word { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + } + ]; + + let documents = EmbeddingsBuilder::new(model.clone()) + .documents(words) + .unwrap() + .build() + .await + .unwrap(); + + vector_store.insert_documents(documents).await.unwrap(); + + sleep(Duration::from_millis(500)).await; + + let req = VectorSearchRequest::builder() + .query("What is a linglingdong?") + .samples(1) + .build() + .unwrap(); + + let results = vector_store.top_n::(req).await.unwrap(); + + assert_eq!(results.len(), 1); + let (score, _, doc) = &results[0]; + // Redis returns cosine distance (0 = identical, higher = more different) + // So we just check it's a valid number + assert!(score.is_finite()); + assert!(doc.definition.contains("linglingdong")); + + cleanup_redis_index(&redis_client, index_name) + .await + .unwrap(); +} + +#[tokio::test] +async fn test_top_n_ids() { + let (host, port, _container) = get_redis_connection().await; + let index_name = "test_top_n_ids"; + + let server = httpmock::MockServer::start(); + + server.mock(|when, then| { + when.method(httpmock::Method::POST) + .path("/embeddings") + .json_body(json!({ + "input": [ + "First test document", + "Second test document" + ], + "model": "text-embedding-ada-002", + })); + 
then.status(200).json_body(json!({ + "object": "list", + "data": [ + {"object": "embedding", "embedding": vec![0.5; 1536], "index": 0}, + {"object": "embedding", "embedding": vec![0.6; 1536], "index": 1} + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 4, "total_tokens": 4} + })); + }); + + server.mock(|when, then| { + when.method(httpmock::Method::POST) + .path("/embeddings") + .json_body(json!({ + "input": ["test query"], + "model": "text-embedding-ada-002", + })); + then.status(200).json_body(json!({ + "object": "list", + "data": [ + {"object": "embedding", "embedding": vec![0.55; 1536], "index": 0} + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 2, "total_tokens": 2} + })); + }); + + let openai_client = openai::Client::builder() + .api_key("TEST") + .base_url(server.base_url()) + .build() + .unwrap(); + + let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); + + let redis_url = format!("redis://{host}:{port}"); + let redis_client = redis::Client::open(redis_url).unwrap(); + + setup_redis_index(&redis_client, index_name, 1536) + .await + .unwrap(); + + let vector_store = RedisVectorStore::new( + model.clone(), + redis_client.clone(), + index_name.to_string(), + VECTOR_FIELD.to_string(), + ) + .with_key_prefix(format!("{}:", index_name)); + + let words = vec![ + Word { + id: "test1".to_string(), + definition: "First test document".to_string(), + }, + Word { + id: "test2".to_string(), + definition: "Second test document".to_string(), + }, + ]; + + let documents = EmbeddingsBuilder::new(model.clone()) + .documents(words) + .unwrap() + .build() + .await + .unwrap(); + + vector_store.insert_documents(documents).await.unwrap(); + + sleep(Duration::from_millis(500)).await; + + let req = VectorSearchRequest::builder() + .query("test query") + .samples(2) + .build() + .unwrap(); + + let results = vector_store.top_n_ids(req).await.unwrap(); + + assert_eq!(results.len(), 2); + // Redis returns cosine distance, so scores can be 0 or positive + assert!(results[0].0.is_finite()); + assert!(!results[0].1.is_empty()); + + cleanup_redis_index(&redis_client, index_name) + .await + .unwrap(); +} + +#[tokio::test] +async fn test_threshold_filtering() { + let (host, port, _container) = get_redis_connection().await; + let index_name = "test_threshold_filtering"; + + let server = httpmock::MockServer::start(); + + server.mock(|when, then| { + when.method(httpmock::Method::POST) + .path("/embeddings") + .json_body(json!({ + "input": [ + "Document with low similarity", + "Document with high similarity" + ], + "model": "text-embedding-ada-002", + })); + then.status(200).json_body(json!({ + "object": "list", + "data": [ + {"object": "embedding", "embedding": vec![0.1; 1536], "index": 0}, + {"object": "embedding", "embedding": vec![0.9; 1536], "index": 1} + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 4, "total_tokens": 4} + })); + }); + + server.mock(|when, then| { + when.method(httpmock::Method::POST) + .path("/embeddings") + .json_body(json!({ + "input": ["test query"], + "model": "text-embedding-ada-002", + })); + then.status(200).json_body(json!({ + "object": "list", + "data": [ + {"object": "embedding", "embedding": vec![0.85; 1536], "index": 0} + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 2, "total_tokens": 2} + })); + }); + + let openai_client = openai::Client::builder() + .api_key("TEST") + .base_url(server.base_url()) + .build() + .unwrap(); + + let model = 
openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); + + let redis_url = format!("redis://{host}:{port}"); + let redis_client = redis::Client::open(redis_url).unwrap(); + + setup_redis_index(&redis_client, index_name, 1536) + .await + .unwrap(); + + let vector_store = RedisVectorStore::new( + model.clone(), + redis_client.clone(), + index_name.to_string(), + VECTOR_FIELD.to_string(), + ) + .with_key_prefix(format!("{}:", index_name)); + + let words = vec![ + Word { + id: "low_score".to_string(), + definition: "Document with low similarity".to_string(), + }, + Word { + id: "high_score".to_string(), + definition: "Document with high similarity".to_string(), + }, + ]; + + let documents = EmbeddingsBuilder::new(model.clone()) + .documents(words) + .unwrap() + .build() + .await + .unwrap(); + + vector_store.insert_documents(documents).await.unwrap(); + + sleep(Duration::from_millis(500)).await; + + let req = VectorSearchRequest::builder() + .query("test query") + .samples(10) + .threshold(0.5) + .build() + .unwrap(); + + let results = vector_store.top_n::(req).await.unwrap(); + + for (score, _, _) in &results { + assert!(score >= &0.5, "All results should meet threshold"); + } + + cleanup_redis_index(&redis_client, index_name) + .await + .unwrap(); +} + +#[tokio::test] +async fn test_insert_multiple_embeddings() { + let (host, port, _container) = get_redis_connection().await; + let index_name = "test_insert_multiple_embeddings"; + + let server = httpmock::MockServer::start(); + + server.mock(|when, then| { + when.method(httpmock::Method::POST) + .path("/embeddings") + .json_body(json!({ + "input": [ + "First batch document", + "Second batch document", + "Third batch document" + ], + "model": "text-embedding-ada-002", + })); + then.status(200).json_body(json!({ + "object": "list", + "data": [ + {"object": "embedding", "embedding": vec![0.1; 1536], "index": 0}, + {"object": "embedding", "embedding": vec![0.2; 1536], "index": 1}, + {"object": "embedding", "embedding": vec![0.3; 1536], "index": 2} + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 6, "total_tokens": 6} + })); + }); + + let openai_client = openai::Client::builder() + .api_key("TEST") + .base_url(server.base_url()) + .build() + .unwrap(); + + let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); + + let redis_url = format!("redis://{host}:{port}"); + let redis_client = redis::Client::open(redis_url).unwrap(); + + setup_redis_index(&redis_client, index_name, 1536) + .await + .unwrap(); + + let vector_store = RedisVectorStore::new( + model.clone(), + redis_client.clone(), + index_name.to_string(), + VECTOR_FIELD.to_string(), + ) + .with_key_prefix(format!("{}:", index_name)); + + let words = vec![ + Word { + id: "batch1".to_string(), + definition: "First batch document".to_string(), + }, + Word { + id: "batch2".to_string(), + definition: "Second batch document".to_string(), + }, + Word { + id: "batch3".to_string(), + definition: "Third batch document".to_string(), + }, + ]; + + let documents = EmbeddingsBuilder::new(model.clone()) + .documents(words) + .unwrap() + .build() + .await + .unwrap(); + + vector_store.insert_documents(documents).await.unwrap(); + + sleep(Duration::from_millis(500)).await; + + // Verify documents were inserted + let mut con = redis_client + .get_multiplexed_async_connection() + .await + .unwrap(); + let keys: Vec = redis::cmd("KEYS") + .arg("*") + .query_async(&mut con) + .await + .unwrap(); + + // Should have at least 3 documents (one per embedding) + 
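    // (each embedding is written under a fresh "<index_name>:<uuid>" key because the store
    // was built with `with_key_prefix`, so KEYS * should report them)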
assert!(keys.len() >= 3, "Should have inserted at least 3 documents"); + + cleanup_redis_index(&redis_client, index_name) + .await + .unwrap(); +} + +#[tokio::test] +async fn test_empty_results() { + let (host, port, _container) = get_redis_connection().await; + let index_name = "test_empty_results"; + + let server = httpmock::MockServer::start(); + + server.mock(|when, then| { + when.method(httpmock::Method::POST) + .path("/embeddings") + .json_body(json!({ + "input": ["query with no results"], + "model": "text-embedding-ada-002", + })); + then.status(200).json_body(json!({ + "object": "list", + "data": [ + {"object": "embedding", "embedding": vec![0.5; 1536], "index": 0} + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 2, "total_tokens": 2} + })); + }); + + let openai_client = openai::Client::builder() + .api_key("TEST") + .base_url(server.base_url()) + .build() + .unwrap(); + + let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); + + let redis_url = format!("redis://{host}:{port}"); + let redis_client = redis::Client::open(redis_url).unwrap(); + + setup_redis_index(&redis_client, index_name, 1536) + .await + .unwrap(); + + let vector_store = RedisVectorStore::new( + model.clone(), + redis_client.clone(), + index_name.to_string(), + VECTOR_FIELD.to_string(), + ) + .with_key_prefix(format!("{}:", index_name)); + + let req = VectorSearchRequest::builder() + .query("query with no results") + .samples(5) + .build() + .unwrap(); + + let results = vector_store.top_n::(req).await.unwrap(); + + assert_eq!(results.len(), 0); + + cleanup_redis_index(&redis_client, index_name) + .await + .unwrap(); +}