Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 208 additions & 0 deletions src/rewrite/normal_form.rs
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,27 @@ impl Predicate {
fn insert_binary_expr(&mut self, left: &Expr, op: Operator, right: &Expr) -> Result<()> {
match (left, op, right) {
(Expr::Column(c), op, Expr::Literal(v, _)) => {
// Normalize boolean expressions to canonical form:
// col = false -> NOT col
// col != true -> NOT col
// col = true -> col
// col != false -> col
// This ensures semantic equivalence matching (e.g., "active = false" matches "NOT active")
if let ScalarValue::Boolean(Some(b)) = v {
match (op, b) {
(Operator::Eq, false) | (Operator::NotEq, true) => {
self.residuals
.insert(Expr::Not(Box::new(Expr::Column(c.clone()))));
return Ok(());
}
(Operator::Eq, true) | (Operator::NotEq, false) => {
self.residuals.insert(Expr::Column(c.clone()));
return Ok(());
}
_ => {}
}
}

if let Err(e) = self.add_range(c, &op, v) {
// Adding a range can fail in some cases, so just fall through
log::debug!("failed to add range filter: {e}");
Expand Down Expand Up @@ -1368,4 +1389,191 @@ mod test {

Ok(())
}

/// Checks that a filter written as `active = false` and one written as `NOT active` are
/// accepted as a match between a materialized view and a query, in both directions.
#[tokio::test]
async fn test_boolean_expression_normalization() -> Result<()> {
    let _ = env_logger::builder().is_test(true).try_init();

    let session = SessionContext::new();

    // Base table with a boolean column, seeded with one row per boolean value.
    let create_base = "CREATE TABLE bool_test (
id INT,
active BOOLEAN,
name VARCHAR
)";
    session.sql(create_base).await?.collect().await?;

    let seed_rows = "INSERT INTO bool_test VALUES (1, true, 'a'), (2, false, 'b')";
    session.sql(seed_rows).await?.collect().await?;

    // Each scenario is:
    // (view table name, view definition, statement that materializes the view,
    //  query to rewrite, message reported when no rewrite is produced).
    let scenarios = [
        // View uses `active = false`, query uses `NOT active`.
        (
            "mv",
            "SELECT * FROM bool_test WHERE active = false",
            "CREATE TABLE mv AS SELECT * FROM bool_test WHERE active = false",
            "SELECT id, name FROM bool_test WHERE NOT active",
            "Expected MV with 'active = false' to match query with 'NOT active'",
        ),
        // The reverse: view uses `NOT active`, query uses `active = false`.
        (
            "mv2",
            "SELECT * FROM bool_test WHERE NOT active",
            "CREATE TABLE mv2 AS SELECT * FROM bool_test WHERE NOT active",
            "SELECT id FROM bool_test WHERE active = false",
            "Expected MV with 'NOT active' to match query with 'active = false'",
        ),
    ];

    for (view_name, view_sql, materialize_sql, query_sql, failure_msg) in scenarios {
        // Normal form of the view definition, taken before the view table exists.
        let view_plan = session.sql(view_sql).await?.into_optimized_plan()?;
        let view_form = SpjNormalForm::new(&view_plan)?;

        session.sql(materialize_sql).await?.collect().await?;

        // Normal form of the query that should be answerable from the view.
        let lookup_plan = session.sql(query_sql).await?.into_optimized_plan()?;
        let lookup_form = SpjNormalForm::new(&lookup_plan)?;

        let view_ref = TableReference::bare(view_name);
        let view_source = provider_as_source(session.table_provider(view_ref.clone()).await?);
        let rewritten = lookup_form.rewrite_from(&view_form, view_ref, view_source)?;

        assert!(rewritten.is_some(), "{failure_msg}");
    }

    Ok(())
}

/// Checks that a view defined with `active = false` is accepted as a rewrite target for a
/// query filtered with `NOT active`. No rows are inserted; only plan matching is exercised.
#[tokio::test]
async fn test_boolean_column_normalization() -> Result<()> {
    let _ = env_logger::builder().is_test(true).try_init();

    let session = SessionContext::new();

    let base_ddl = "CREATE TABLE bool_test (
id INT,
active BOOLEAN,
name VARCHAR
)";
    session.sql(base_ddl).await?.collect().await?;

    // View side: the filter is spelled as an equality against `false`.
    let view_frame = session
        .sql("SELECT * FROM bool_test WHERE active = false")
        .await?;
    let view_plan = view_frame.into_optimized_plan()?;
    let view_form = SpjNormalForm::new(&view_plan)?;

    // Create the table named `mv` that the rewrite will point at.
    let materialize = session
        .sql("CREATE TABLE mv AS SELECT * FROM bool_test WHERE active = false")
        .await?;
    materialize.collect().await?;

    // Query side: the same condition spelled as a negation.
    let lookup_frame = session
        .sql("SELECT id, name FROM bool_test WHERE NOT active")
        .await?;
    let lookup_plan = lookup_frame.into_optimized_plan()?;
    let lookup_form = SpjNormalForm::new(&lookup_plan)?;

    let view_ref = TableReference::bare("mv");
    let view_source = provider_as_source(session.table_provider(view_ref.clone()).await?);
    let outcome = lookup_form.rewrite_from(&view_form, view_ref, view_source)?;

    // A rewrite must have been produced.
    let matched = outcome.is_some();
    assert!(
        matched,
        "Expected MV with 'active = false' to match query with 'NOT active'"
    );

    Ok(())
}

/// Checks that a view defined with `enabled = true` is accepted as a rewrite target for a
/// query that filters on the bare boolean column `enabled`.
#[tokio::test]
async fn test_boolean_true_normalization() -> Result<()> {
    let _ = env_logger::builder().is_test(true).try_init();

    let session = SessionContext::new();

    let base_ddl = "CREATE TABLE bool_test2 (
id INT,
enabled BOOLEAN
)";
    session.sql(base_ddl).await?.collect().await?;

    // View side: explicit comparison against `true`.
    let view_frame = session
        .sql("SELECT * FROM bool_test2 WHERE enabled = true")
        .await?;
    let view_plan = view_frame.into_optimized_plan()?;
    let view_form = SpjNormalForm::new(&view_plan)?;

    // Create the table named `mv2` that the rewrite will point at.
    let materialize = session
        .sql("CREATE TABLE mv2 AS SELECT * FROM bool_test2 WHERE enabled = true")
        .await?;
    materialize.collect().await?;

    // Query side: the column itself is the predicate.
    let lookup_frame = session
        .sql("SELECT id FROM bool_test2 WHERE enabled")
        .await?;
    let lookup_plan = lookup_frame.into_optimized_plan()?;
    let lookup_form = SpjNormalForm::new(&lookup_plan)?;

    let view_ref = TableReference::bare("mv2");
    let view_source = provider_as_source(session.table_provider(view_ref.clone()).await?);
    let outcome = lookup_form.rewrite_from(&view_form, view_ref, view_source)?;

    let matched = outcome.is_some();
    assert!(
        matched,
        "Expected MV with 'enabled = true' to match query with 'enabled'"
    );

    Ok(())
}
}