diff --git a/src/rewrite/normal_form.rs b/src/rewrite/normal_form.rs index 9cce9d8..b5be586 100644 --- a/src/rewrite/normal_form.rs +++ b/src/rewrite/normal_form.rs @@ -656,6 +656,27 @@ impl Predicate { fn insert_binary_expr(&mut self, left: &Expr, op: Operator, right: &Expr) -> Result<()> { match (left, op, right) { (Expr::Column(c), op, Expr::Literal(v, _)) => { + // Normalize boolean expressions to canonical form: + // col = false -> NOT col + // col != true -> NOT col + // col = true -> col + // col != false -> col + // This ensures semantic equivalence matching (e.g., "active = false" matches "NOT active") + if let ScalarValue::Boolean(Some(b)) = v { + match (op, b) { + (Operator::Eq, false) | (Operator::NotEq, true) => { + self.residuals + .insert(Expr::Not(Box::new(Expr::Column(c.clone())))); + return Ok(()); + } + (Operator::Eq, true) | (Operator::NotEq, false) => { + self.residuals.insert(Expr::Column(c.clone())); + return Ok(()); + } + _ => {} + } + } + if let Err(e) = self.add_range(c, &op, v) { // Add a range can fail in some cases, so just fallthrough log::debug!("failed to add range filter: {e}"); @@ -1368,4 +1389,191 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_boolean_expression_normalization() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + let ctx = SessionContext::new(); + + // Create table with boolean column + ctx.sql( + "CREATE TABLE bool_test ( + id INT, + active BOOLEAN, + name VARCHAR + )", + ) + .await? + .collect() + .await?; + + ctx.sql("INSERT INTO bool_test VALUES (1, true, 'a'), (2, false, 'b')") + .await? + .collect() + .await?; + + // MV: uses "active = false" + let mv_plan = ctx + .sql("SELECT * FROM bool_test WHERE active = false") + .await? + .into_optimized_plan()?; + let mv_normal_form = SpjNormalForm::new(&mv_plan)?; + + ctx.sql("CREATE TABLE mv AS SELECT * FROM bool_test WHERE active = false") + .await? + .collect() + .await?; + + // Query: uses "NOT active" (semantically equivalent to "active = false") + let query_plan = ctx + .sql("SELECT id, name FROM bool_test WHERE NOT active") + .await? + .into_optimized_plan()?; + let query_normal_form = SpjNormalForm::new(&query_plan)?; + + let table_ref = TableReference::bare("mv"); + let rewritten = query_normal_form.rewrite_from( + &mv_normal_form, + table_ref.clone(), + provider_as_source(ctx.table_provider(table_ref).await?), + )?; + + assert!( + rewritten.is_some(), + "Expected MV with 'active = false' to match query with 'NOT active'" + ); + + // Also test the reverse: MV with "NOT active", query with "active = false" + let mv_plan2 = ctx + .sql("SELECT * FROM bool_test WHERE NOT active") + .await? + .into_optimized_plan()?; + let mv_normal_form2 = SpjNormalForm::new(&mv_plan2)?; + + ctx.sql("CREATE TABLE mv2 AS SELECT * FROM bool_test WHERE NOT active") + .await? + .collect() + .await?; + + let query_plan2 = ctx + .sql("SELECT id FROM bool_test WHERE active = false") + .await? + .into_optimized_plan()?; + let query_normal_form2 = SpjNormalForm::new(&query_plan2)?; + + let table_ref2 = TableReference::bare("mv2"); + let rewritten2 = query_normal_form2.rewrite_from( + &mv_normal_form2, + table_ref2.clone(), + provider_as_source(ctx.table_provider(table_ref2).await?), + )?; + + assert!( + rewritten2.is_some(), + "Expected MV with 'NOT active' to match query with 'active = false'" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_boolean_column_normalization() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + let ctx = SessionContext::new(); + + ctx.sql( + "CREATE TABLE bool_test ( + id INT, + active BOOLEAN, + name VARCHAR + )", + ) + .await? + .collect() + .await?; + + // Test: MV with "active = false" should match query with "NOT active" + let mv_plan = ctx + .sql("SELECT * FROM bool_test WHERE active = false") + .await? + .into_optimized_plan()?; + let mv_normal_form = SpjNormalForm::new(&mv_plan)?; + + ctx.sql("CREATE TABLE mv AS SELECT * FROM bool_test WHERE active = false") + .await? + .collect() + .await?; + + let query_plan = ctx + .sql("SELECT id, name FROM bool_test WHERE NOT active") + .await? + .into_optimized_plan()?; + let query_normal_form = SpjNormalForm::new(&query_plan)?; + + let table_ref = TableReference::bare("mv"); + let rewritten = query_normal_form.rewrite_from( + &mv_normal_form, + table_ref.clone(), + provider_as_source(ctx.table_provider(table_ref).await?), + )?; + + // Should successfully rewrite + assert!( + rewritten.is_some(), + "Expected MV with 'active = false' to match query with 'NOT active'" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_boolean_true_normalization() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + let ctx = SessionContext::new(); + + ctx.sql( + "CREATE TABLE bool_test2 ( + id INT, + enabled BOOLEAN + )", + ) + .await? + .collect() + .await?; + + // Test: MV with "enabled = true" should match query with just "enabled" + let mv_plan = ctx + .sql("SELECT * FROM bool_test2 WHERE enabled = true") + .await? + .into_optimized_plan()?; + let mv_normal_form = SpjNormalForm::new(&mv_plan)?; + + ctx.sql("CREATE TABLE mv2 AS SELECT * FROM bool_test2 WHERE enabled = true") + .await? + .collect() + .await?; + + let query_plan = ctx + .sql("SELECT id FROM bool_test2 WHERE enabled") + .await? + .into_optimized_plan()?; + let query_normal_form = SpjNormalForm::new(&query_plan)?; + + let table_ref = TableReference::bare("mv2"); + let rewritten = query_normal_form.rewrite_from( + &mv_normal_form, + table_ref.clone(), + provider_as_source(ctx.table_provider(table_ref).await?), + )?; + + assert!( + rewritten.is_some(), + "Expected MV with 'enabled = true' to match query with 'enabled'" + ); + + Ok(()) + } }