Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions rust/lance-graph/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,9 @@ pub enum ValueExpression {
},
/// Parameter reference for query parameters (e.g., $query_vector)
Parameter(String),
/// Vector literal: [0.1, 0.2, 0.3]
/// Represents an inline vector for similarity search
VectorLiteral(Vec<f32>),
}

/// Arithmetic operators
Expand Down
19 changes: 19 additions & 0 deletions rust/lance-graph/src/datafusion_planner/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,25 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr {
vec![left_expr, right_expr],
))
}
VE::VectorLiteral(values) => {
// Convert Vec<f32> to DataFusion scalar FixedSizeList
use arrow::array::{FixedSizeListArray, Float32Array};
use arrow::datatypes::{DataType, Field};
use datafusion::scalar::ScalarValue;
use std::sync::Arc;

let dim = values.len() as i32;
let field = Arc::new(Field::new("item", DataType::Float32, true));
let float_array = Arc::new(Float32Array::from(values.clone()));

let list_array = FixedSizeListArray::try_new(field.clone(), dim, float_array, None)
.expect("Failed to create FixedSizeListArray for vector literal");

let scalar = ScalarValue::try_from_array(&list_array, 0)
.expect("Failed to create scalar from array");

lit(scalar)
}
VE::Parameter(name) => {
// TODO: Implement proper parameter resolution
// Parameters ($param) should be resolved to literal values from the query's
Expand Down
133 changes: 129 additions & 4 deletions rust/lance-graph/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use crate::error::{GraphError, Result};
use nom::{
branch::alt,
bytes::complete::{tag, tag_no_case, take_while1},
character::complete::{char, multispace0, multispace1},
combinator::{map, opt, peek, recognize},
character::complete::{char, digit0, digit1, multispace0, multispace1, one_of},
combinator::{map, map_res, opt, peek, recognize},
multi::{many0, separated_list0, separated_list1},
sequence::{delimited, pair, preceded, tuple},
IResult,
Expand Down Expand Up @@ -438,8 +438,9 @@ fn comparison_operator(input: &str) -> IResult<&str, ComparisonOperator> {
// Parse a basic value expression (without vector functions to avoid circular dependency)
fn basic_value_expression(input: &str) -> IResult<&str, ValueExpression> {
alt((
parse_parameter, // Try $parameter
function_call, // Regular function calls
parse_vector_literal, // Try vector literal first [0.1, 0.2]
parse_parameter, // Try $parameter
function_call, // Regular function calls
map(property_value, ValueExpression::Literal), // Try literals BEFORE property references
map(property_reference, ValueExpression::Property),
map(identifier, |id| ValueExpression::Variable(id.to_string())),
Expand Down Expand Up @@ -603,6 +604,44 @@ fn value_expression_list(input: &str) -> IResult<&str, Vec<ValueExpression>> {
)(input)
}

// Parse a float32 literal for vectors
fn float32_literal(input: &str) -> IResult<&str, f32> {
map_res(
recognize(tuple((
opt(char('-')),
alt((
// Scientific notation: 1e-3, 2.5e2
recognize(tuple((
digit1,
opt(tuple((char('.'), digit0))),
one_of("eE"),
opt(one_of("+-")),
digit1,
))),
// Regular float: 1.23 or integer: 123
recognize(tuple((digit1, opt(tuple((char('.'), digit0)))))),
)),
))),
|s: &str| s.parse::<f32>(),
)(input)
}

// Parse vector literal: [0.1, 0.2, 0.3]
fn parse_vector_literal(input: &str) -> IResult<&str, ValueExpression> {
let (input, _) = char('[')(input)?;
let (input, _) = multispace0(input)?;

let (input, values) = separated_list1(
tuple((multispace0, char(','), multispace0)),
float32_literal,
)(input)?;

let (input, _) = multispace0(input)?;
let (input, _) = char(']')(input)?;

Ok((input, ValueExpression::VectorLiteral(values)))
}

// Parse a property reference: variable.property
fn property_reference(input: &str) -> IResult<&str, PropertyRef> {
let (input, variable) = identifier(input)?;
Expand Down Expand Up @@ -1597,4 +1636,90 @@ mod tests {
_ => panic!("Expected AND expression"),
}
}

#[test]
fn test_parse_vector_literal() {
let result = parse_vector_literal("[0.1, 0.2, 0.3]");
assert!(result.is_ok());
let (_, expr) = result.unwrap();
match expr {
ValueExpression::VectorLiteral(vec) => {
assert_eq!(vec.len(), 3);
assert_eq!(vec[0], 0.1);
assert_eq!(vec[1], 0.2);
assert_eq!(vec[2], 0.3);
}
_ => panic!("Expected VectorLiteral"),
}
}

#[test]
fn test_parse_vector_literal_with_negative_values() {
let result = parse_vector_literal("[-0.1, 0.2, -0.3]");
assert!(result.is_ok());
let (_, expr) = result.unwrap();
match expr {
ValueExpression::VectorLiteral(vec) => {
assert_eq!(vec.len(), 3);
assert_eq!(vec[0], -0.1);
assert_eq!(vec[2], -0.3);
}
_ => panic!("Expected VectorLiteral"),
}
}

#[test]
fn test_parse_vector_literal_scientific_notation() {
let result = parse_vector_literal("[1e-3, 2.5e2, -3e-1]");
assert!(result.is_ok());
let (_, expr) = result.unwrap();
match expr {
ValueExpression::VectorLiteral(vec) => {
assert_eq!(vec.len(), 3);
assert!((vec[0] - 0.001).abs() < 1e-6);
assert!((vec[1] - 250.0).abs() < 1e-6);
assert!((vec[2] - (-0.3)).abs() < 1e-6);
}
_ => panic!("Expected VectorLiteral"),
}
}

#[test]
fn test_vector_distance_with_literal() {
let query =
"MATCH (p:Person) WHERE vector_distance(p.embedding, [0.1, 0.2], l2) < 0.5 RETURN p";
let result = parse_cypher_query(query);
assert!(result.is_ok());

let ast = result.unwrap();
let where_clause = ast.where_clause.expect("Expected WHERE clause");

match where_clause.expression {
BooleanExpression::Comparison { left, operator, .. } => {
match left {
ValueExpression::VectorDistance {
left,
right,
metric,
} => {
// Left should be property reference
assert!(matches!(*left, ValueExpression::Property(_)));
// Right should be vector literal
match *right {
ValueExpression::VectorLiteral(vec) => {
assert_eq!(vec.len(), 2);
assert_eq!(vec[0], 0.1);
assert_eq!(vec[1], 0.2);
}
_ => panic!("Expected VectorLiteral"),
}
assert_eq!(metric, DistanceMetric::L2);
}
_ => panic!("Expected VectorDistance"),
}
assert_eq!(operator, ComparisonOperator::LessThan);
}
_ => panic!("Expected comparison"),
}
}
}
12 changes: 12 additions & 0 deletions rust/lance-graph/src/semantic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,18 @@ impl SemanticAnalyzer {
});
}
}
ValueExpression::VectorLiteral(values) => {
// Validate non-empty
if values.is_empty() {
return Err(GraphError::PlanError {
message: "Vector literal cannot be empty".to_string(),
location: snafu::Location::new(file!(), line!(), column!()),
});
}

// Note: Very large vectors (>4096 dimensions) may impact performance
// but we don't enforce a hard limit here
}
ValueExpression::Parameter(_) => {
// Parameters are always valid (resolved at runtime)
}
Expand Down
1 change: 1 addition & 0 deletions rust/lance-graph/src/simple_executor/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,5 +172,6 @@ pub(crate) fn to_df_value_expr_simple(
VE::VectorDistance { .. } => lit(0.0f32),
VE::VectorSimilarity { .. } => lit(1.0f32),
VE::Parameter(_) => lit(0),
VE::VectorLiteral(_) => lit(0.0f32),
}
}
128 changes: 128 additions & 0 deletions rust/lance-graph/tests/test_vector_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,36 @@ async fn test_vector_distance_l2() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_vector_distance_l2_with_literal() -> Result<()> {
let (config, datasets) = create_person_graph_with_embeddings();

// Same test as above but using vector literal instead of cross product
// Find people similar to [1, 0, 0] (Alice's embedding)
let query = CypherQuery::new(
"MATCH (p:Person) \
WHERE vector_distance(p.embedding, [1.0, 0.0, 0.0], l2) < 0.2 \
RETURN p.name ORDER BY p.name",
)?
.with_config(config);

let result = query
.execute(datasets, Some(ExecutionStrategy::DataFusion))
.await?;

// Should return Alice (exact match) and Bob (very similar)
assert_eq!(result.num_rows(), 2);
let names = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(names.value(0), "Alice");
assert_eq!(names.value(1), "Bob");

Ok(())
}

#[tokio::test]
async fn test_vector_distance_cosine() -> Result<()> {
let (config, datasets) = create_person_graph_with_embeddings();
Expand Down Expand Up @@ -261,6 +291,37 @@ async fn test_vector_distance_order_by() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_vector_distance_order_by_with_literal() -> Result<()> {
let (config, datasets) = create_person_graph_with_embeddings();

// Same as above but using vector literal - order by distance to [1,0,0] (Alice's vector)
let query = CypherQuery::new(
"MATCH (p:Person) \
RETURN p.name \
ORDER BY vector_distance(p.embedding, [1.0, 0.0, 0.0], cosine) ASC \
LIMIT 3",
)?
.with_config(config);

let result = query
.execute(datasets, Some(ExecutionStrategy::DataFusion))
.await?;

// Should return Alice (closest), Bob (second), Eve (third)
assert_eq!(result.num_rows(), 3);
let names = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(names.value(0), "Alice");
assert_eq!(names.value(1), "Bob");
assert_eq!(names.value(2), "Eve");

Ok(())
}

#[tokio::test]
async fn test_vector_similarity_order_by() -> Result<()> {
let (config, datasets) = create_person_graph_with_embeddings();
Expand Down Expand Up @@ -350,6 +411,36 @@ async fn test_hybrid_query_property_and_vector() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_hybrid_query_with_vector_literal() -> Result<()> {
let (config, datasets) = create_person_graph_with_embeddings();

// Combine property filter with vector literal search
let query = CypherQuery::new(
"MATCH (p:Person) \
WHERE p.age > 25 \
AND vector_distance(p.embedding, [1.0, 0.0, 0.0], l2) < 0.3 \
RETURN p.name ORDER BY p.name",
)?
.with_config(config);

let result = query
.execute(datasets, Some(ExecutionStrategy::DataFusion))
.await?;

// Should return only Alice (age 30 > 25, close to [1,0,0])
// Bob is age 25, not > 25
assert_eq!(result.num_rows(), 1);
let names = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(names.value(0), "Alice");

Ok(())
}

#[tokio::test]
async fn test_vector_distance_dot_product() -> Result<()> {
let (config, datasets) = create_person_graph_with_embeddings();
Expand Down Expand Up @@ -621,3 +712,40 @@ async fn test_vector_similarity_self_comparison() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn test_vector_literal_in_return_clause() -> Result<()> {
let (config, datasets) = create_person_graph_with_embeddings();

// Use vector literal in RETURN to compute distances
let query = CypherQuery::new(
"MATCH (p:Person) \
RETURN p.name, vector_distance(p.embedding, [0.5, 0.5, 0.0], l2) AS dist \
ORDER BY dist ASC \
LIMIT 1",
)?
.with_config(config);

let result = query
.execute(datasets, Some(ExecutionStrategy::DataFusion))
.await?;

// Should return Eve (closest to [0.5, 0.5, 0.0])
assert_eq!(result.num_rows(), 1);
let names = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(names.value(0), "Eve");

// Distance should be 0 (exact match)
let distances = result
.column(1)
.as_any()
.downcast_ref::<Float32Array>()
.unwrap();
assert!(distances.value(0) < 0.001);

Ok(())
}
Loading