From 152eac8fb1498244bda62d8a2bcb8100b6ab8a79 Mon Sep 17 00:00:00 2001 From: JoshuaTang <1240604020@qq.com> Date: Mon, 5 Jan 2026 20:45:17 -0800 Subject: [PATCH 1/2] feat: support vector literals --- rust/lance-graph/src/ast.rs | 3 + .../src/datafusion_planner/expression.rs | 19 +++ rust/lance-graph/src/parser.rs | 127 ++++++++++++++++- rust/lance-graph/src/semantic.rs | 12 ++ rust/lance-graph/src/simple_executor/expr.rs | 1 + rust/lance-graph/tests/test_vector_search.rs | 128 ++++++++++++++++++ 6 files changed, 288 insertions(+), 2 deletions(-) diff --git a/rust/lance-graph/src/ast.rs b/rust/lance-graph/src/ast.rs index 20234a7..94c1438 100644 --- a/rust/lance-graph/src/ast.rs +++ b/rust/lance-graph/src/ast.rs @@ -308,6 +308,9 @@ pub enum ValueExpression { }, /// Parameter reference for query parameters (e.g., $query_vector) Parameter(String), + /// Vector literal: [0.1, 0.2, 0.3] + /// Represents an inline vector for similarity search + VectorLiteral(Vec<f32>), } /// Arithmetic operators diff --git a/rust/lance-graph/src/datafusion_planner/expression.rs b/rust/lance-graph/src/datafusion_planner/expression.rs index a8b5d06..586155b 100644 --- a/rust/lance-graph/src/datafusion_planner/expression.rs +++ b/rust/lance-graph/src/datafusion_planner/expression.rs @@ -241,6 +241,25 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { vec![left_expr, right_expr], )) } + VE::VectorLiteral(values) => { + // Convert Vec<f32> to DataFusion scalar FixedSizeList + use arrow::array::{Float32Array, FixedSizeListArray}; + use arrow::datatypes::{DataType, Field}; + use datafusion::scalar::ScalarValue; + use std::sync::Arc; + + let dim = values.len() as i32; + let field = Arc::new(Field::new("item", DataType::Float32, true)); + let float_array = Arc::new(Float32Array::from(values.clone())); + + let list_array = FixedSizeListArray::try_new(field.clone(), dim, float_array, None) + .expect("Failed to create FixedSizeListArray for vector literal"); + + let scalar = 
ScalarValue::try_from_array(&list_array, 0) + .expect("Failed to create scalar from array"); + + lit(scalar) + } VE::Parameter(name) => { // TODO: Implement proper parameter resolution // Parameters ($param) should be resolved to literal values from the query's diff --git a/rust/lance-graph/src/parser.rs b/rust/lance-graph/src/parser.rs index 63c0963..140cceb 100644 --- a/rust/lance-graph/src/parser.rs +++ b/rust/lance-graph/src/parser.rs @@ -11,8 +11,8 @@ use crate::error::{GraphError, Result}; use nom::{ branch::alt, bytes::complete::{tag, tag_no_case, take_while1}, - character::complete::{char, multispace0, multispace1}, - combinator::{map, opt, peek, recognize}, + character::complete::{char, digit0, digit1, multispace0, multispace1, one_of}, + combinator::{map, map_res, opt, peek, recognize}, multi::{many0, separated_list0, separated_list1}, sequence::{delimited, pair, preceded, tuple}, IResult, @@ -438,6 +438,7 @@ fn comparison_operator(input: &str) -> IResult<&str, ComparisonOperator> { // Parse a basic value expression (without vector functions to avoid circular dependency) fn basic_value_expression(input: &str) -> IResult<&str, ValueExpression> { alt(( + parse_vector_literal, // Try vector literal first [0.1, 0.2] parse_parameter, // Try $parameter function_call, // Regular function calls map(property_value, ValueExpression::Literal), // Try literals BEFORE property references @@ -603,6 +604,47 @@ fn value_expression_list(input: &str) -> IResult<&str, Vec> { )(input) } +// Parse a float32 literal for vectors +fn float32_literal(input: &str) -> IResult<&str, f32> { + map_res( + recognize(tuple(( + opt(char('-')), + alt(( + // Scientific notation: 1e-3, 2.5e2 + recognize(tuple(( + digit1, + opt(tuple((char('.'), digit0))), + one_of("eE"), + opt(one_of("+-")), + digit1, + ))), + // Regular float: 1.23 or integer: 123 + recognize(tuple(( + digit1, + opt(tuple((char('.'), digit0))), + ))), + )), + ))), + |s: &str| s.parse::<f32>(), + )(input) +} + +// Parse vector 
literal: [0.1, 0.2, 0.3] +fn parse_vector_literal(input: &str) -> IResult<&str, ValueExpression> { + let (input, _) = char('[')(input)?; + let (input, _) = multispace0(input)?; + + let (input, values) = separated_list1( + tuple((multispace0, char(','), multispace0)), + float32_literal, + )(input)?; + + let (input, _) = multispace0(input)?; + let (input, _) = char(']')(input)?; + + Ok((input, ValueExpression::VectorLiteral(values))) +} + // Parse a property reference: variable.property fn property_reference(input: &str) -> IResult<&str, PropertyRef> { let (input, variable) = identifier(input)?; @@ -1597,4 +1639,85 @@ mod tests { _ => panic!("Expected AND expression"), } } + + #[test] + fn test_parse_vector_literal() { + let result = parse_vector_literal("[0.1, 0.2, 0.3]"); + assert!(result.is_ok()); + let (_, expr) = result.unwrap(); + match expr { + ValueExpression::VectorLiteral(vec) => { + assert_eq!(vec.len(), 3); + assert_eq!(vec[0], 0.1); + assert_eq!(vec[1], 0.2); + assert_eq!(vec[2], 0.3); + } + _ => panic!("Expected VectorLiteral"), + } + } + + #[test] + fn test_parse_vector_literal_with_negative_values() { + let result = parse_vector_literal("[-0.1, 0.2, -0.3]"); + assert!(result.is_ok()); + let (_, expr) = result.unwrap(); + match expr { + ValueExpression::VectorLiteral(vec) => { + assert_eq!(vec.len(), 3); + assert_eq!(vec[0], -0.1); + assert_eq!(vec[2], -0.3); + } + _ => panic!("Expected VectorLiteral"), + } + } + + #[test] + fn test_parse_vector_literal_scientific_notation() { + let result = parse_vector_literal("[1e-3, 2.5e2, -3e-1]"); + assert!(result.is_ok()); + let (_, expr) = result.unwrap(); + match expr { + ValueExpression::VectorLiteral(vec) => { + assert_eq!(vec.len(), 3); + assert!((vec[0] - 0.001).abs() < 1e-6); + assert!((vec[1] - 250.0).abs() < 1e-6); + assert!((vec[2] - (-0.3)).abs() < 1e-6); + } + _ => panic!("Expected VectorLiteral"), + } + } + + #[test] + fn test_vector_distance_with_literal() { + let query = "MATCH (p:Person) WHERE 
vector_distance(p.embedding, [0.1, 0.2], l2) < 0.5 RETURN p"; + let result = parse_cypher_query(query); + assert!(result.is_ok()); + + let ast = result.unwrap(); + let where_clause = ast.where_clause.expect("Expected WHERE clause"); + + match where_clause.expression { + BooleanExpression::Comparison { left, operator, .. } => { + match left { + ValueExpression::VectorDistance { left, right, metric } => { + // Left should be property reference + assert!(matches!(*left, ValueExpression::Property(_))); + // Right should be vector literal + match *right { + ValueExpression::VectorLiteral(vec) => { + assert_eq!(vec.len(), 2); + assert_eq!(vec[0], 0.1); + assert_eq!(vec[1], 0.2); + } + _ => panic!("Expected VectorLiteral"), + } + assert_eq!(metric, DistanceMetric::L2); + } + _ => panic!("Expected VectorDistance"), + } + assert_eq!(operator, ComparisonOperator::LessThan); + } + _ => panic!("Expected comparison"), + } + } } diff --git a/rust/lance-graph/src/semantic.rs b/rust/lance-graph/src/semantic.rs index 6f50a15..a34dace 100644 --- a/rust/lance-graph/src/semantic.rs +++ b/rust/lance-graph/src/semantic.rs @@ -378,6 +378,18 @@ impl SemanticAnalyzer { }); } } + ValueExpression::VectorLiteral(values) => { + // Validate non-empty + if values.is_empty() { + return Err(GraphError::PlanError { + message: "Vector literal cannot be empty".to_string(), + location: snafu::Location::new(file!(), line!(), column!()), + }); + } + + // Note: Very large vectors (>4096 dimensions) may impact performance + // but we don't enforce a hard limit here + } ValueExpression::Parameter(_) => { // Parameters are always valid (resolved at runtime) } diff --git a/rust/lance-graph/src/simple_executor/expr.rs b/rust/lance-graph/src/simple_executor/expr.rs index 7f1042c..ec1fb69 100644 --- a/rust/lance-graph/src/simple_executor/expr.rs +++ b/rust/lance-graph/src/simple_executor/expr.rs @@ -172,5 +172,6 @@ pub(crate) fn to_df_value_expr_simple( VE::VectorDistance { .. 
} => lit(0.0f32), VE::VectorSimilarity { .. } => lit(1.0f32), VE::Parameter(_) => lit(0), + VE::VectorLiteral(_) => lit(0.0f32), } } diff --git a/rust/lance-graph/tests/test_vector_search.rs b/rust/lance-graph/tests/test_vector_search.rs index af594fe..6975269 100644 --- a/rust/lance-graph/tests/test_vector_search.rs +++ b/rust/lance-graph/tests/test_vector_search.rs @@ -154,6 +154,36 @@ async fn test_vector_distance_l2() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_vector_distance_l2_with_literal() -> Result<()> { + let (config, datasets) = create_person_graph_with_embeddings(); + + // Same test as above but using vector literal instead of cross product + // Find people similar to [1, 0, 0] (Alice's embedding) + let query = CypherQuery::new( + "MATCH (p:Person) \ + WHERE vector_distance(p.embedding, [1.0, 0.0, 0.0], l2) < 0.2 \ + RETURN p.name ORDER BY p.name", + )? + .with_config(config); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await?; + + // Should return Alice (exact match) and Bob (very similar) + assert_eq!(result.num_rows(), 2); + let names = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(names.value(0), "Alice"); + assert_eq!(names.value(1), "Bob"); + + Ok(()) +} + #[tokio::test] async fn test_vector_distance_cosine() -> Result<()> { let (config, datasets) = create_person_graph_with_embeddings(); @@ -261,6 +291,37 @@ async fn test_vector_distance_order_by() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_vector_distance_order_by_with_literal() -> Result<()> { + let (config, datasets) = create_person_graph_with_embeddings(); + + // Same as above but using vector literal - order by distance to [1,0,0] (Alice's vector) + let query = CypherQuery::new( + "MATCH (p:Person) \ + RETURN p.name \ + ORDER BY vector_distance(p.embedding, [1.0, 0.0, 0.0], cosine) ASC \ + LIMIT 3", + )? 
+ .with_config(config); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await?; + + // Should return Alice (closest), Bob (second), Eve (third) + assert_eq!(result.num_rows(), 3); + let names = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(names.value(0), "Alice"); + assert_eq!(names.value(1), "Bob"); + assert_eq!(names.value(2), "Eve"); + + Ok(()) +} + #[tokio::test] async fn test_vector_similarity_order_by() -> Result<()> { let (config, datasets) = create_person_graph_with_embeddings(); @@ -350,6 +411,36 @@ async fn test_hybrid_query_property_and_vector() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_hybrid_query_with_vector_literal() -> Result<()> { + let (config, datasets) = create_person_graph_with_embeddings(); + + // Combine property filter with vector literal search + let query = CypherQuery::new( + "MATCH (p:Person) \ + WHERE p.age > 25 \ + AND vector_distance(p.embedding, [1.0, 0.0, 0.0], l2) < 0.3 \ + RETURN p.name ORDER BY p.name", + )? 
+ .with_config(config); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await?; + + // Should return only Alice (age 30 > 25, close to [1,0,0]) + // Bob is age 25, not > 25 + assert_eq!(result.num_rows(), 1); + let names = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(names.value(0), "Alice"); + + Ok(()) +} + #[tokio::test] async fn test_vector_distance_dot_product() -> Result<()> { let (config, datasets) = create_person_graph_with_embeddings(); @@ -621,3 +712,40 @@ async fn test_vector_similarity_self_comparison() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_vector_literal_in_return_clause() -> Result<()> { + let (config, datasets) = create_person_graph_with_embeddings(); + + // Use vector literal in RETURN to compute distances + let query = CypherQuery::new( + "MATCH (p:Person) \ + RETURN p.name, vector_distance(p.embedding, [0.5, 0.5, 0.0], l2) AS dist \ + ORDER BY dist ASC \ + LIMIT 1", + )? + .with_config(config); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await?; + + // Should return Eve (closest to [0.5, 0.5, 0.0]) + assert_eq!(result.num_rows(), 1); + let names = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(names.value(0), "Eve"); + + // Distance should be 0 (exact match) + let distances = result + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(distances.value(0) < 0.001); + + Ok(()) +} From 9b450016bd65d46e8948d7e7ae3a69ceec67f23b Mon Sep 17 00:00:00 2001 From: JoshuaTang <1240604020@qq.com> Date: Mon, 5 Jan 2026 20:48:38 -0800 Subject: [PATCH 2/2] format code --- .../src/datafusion_planner/expression.rs | 2 +- rust/lance-graph/src/parser.rs | 20 ++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/rust/lance-graph/src/datafusion_planner/expression.rs b/rust/lance-graph/src/datafusion_planner/expression.rs index 586155b..3c8d981 100644 --- 
a/rust/lance-graph/src/datafusion_planner/expression.rs +++ b/rust/lance-graph/src/datafusion_planner/expression.rs @@ -243,7 +243,7 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { } VE::VectorLiteral(values) => { // Convert Vec<f32> to DataFusion scalar FixedSizeList - use arrow::array::{Float32Array, FixedSizeListArray}; + use arrow::array::{FixedSizeListArray, Float32Array}; use arrow::datatypes::{DataType, Field}; use datafusion::scalar::ScalarValue; use std::sync::Arc; diff --git a/rust/lance-graph/src/parser.rs b/rust/lance-graph/src/parser.rs index 140cceb..7427437 100644 --- a/rust/lance-graph/src/parser.rs +++ b/rust/lance-graph/src/parser.rs @@ -438,9 +438,9 @@ fn comparison_operator(input: &str) -> IResult<&str, ComparisonOperator> { // Parse a basic value expression (without vector functions to avoid circular dependency) fn basic_value_expression(input: &str) -> IResult<&str, ValueExpression> { alt(( - parse_vector_literal, // Try vector literal first [0.1, 0.2] - parse_parameter, // Try $parameter - function_call, // Regular function calls + parse_vector_literal, // Try vector literal first [0.1, 0.2] + parse_parameter, // Try $parameter + function_call, // Regular function calls map(property_value, ValueExpression::Literal), // Try literals BEFORE property references map(property_reference, ValueExpression::Property), map(identifier, |id| ValueExpression::Variable(id.to_string())), @@ -619,10 +619,7 @@ fn float32_literal(input: &str) -> IResult<&str, f32> { digit1, ))), // Regular float: 1.23 or integer: 123 - recognize(tuple(( - digit1, - opt(tuple((char('.'), digit0))), - ))), + recognize(tuple((digit1, opt(tuple((char('.'), digit0)))))), )), ))), |s: &str| s.parse::<f32>(), @@ -1689,7 +1686,8 @@ mod tests { #[test] fn test_vector_distance_with_literal() { - let query = "MATCH (p:Person) WHERE vector_distance(p.embedding, [0.1, 0.2], l2) < 0.5 RETURN p"; + let query = + "MATCH (p:Person) WHERE vector_distance(p.embedding, [0.1, 0.2], l2) < 
0.5 RETURN p"; let result = parse_cypher_query(query); assert!(result.is_ok()); @@ -1699,7 +1697,11 @@ mod tests { match where_clause.expression { BooleanExpression::Comparison { left, operator, .. } => { match left { - ValueExpression::VectorDistance { left, right, metric } => { + ValueExpression::VectorDistance { + left, + right, + metric, + } => { // Left should be property reference assert!(matches!(*left, ValueExpression::Property(_))); // Right should be vector literal