diff --git a/AGENTS.md b/AGENTS.md index d6d795f..df0ba88 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -19,6 +19,7 @@ Miso is a query engine over semi-structured (JSON) logs that processes KQL (Kust - Be mindful of allocations in hot paths - Prefer structured logging - Provide helpful error messages +- Use #[test_case] when writing tests, and use snake_case for naming the tests ## Running / Testing diff --git a/miso-optimizations/src/dynamic_filter.rs b/miso-optimizations/src/dynamic_filter.rs index 3f91004..739e3be 100644 --- a/miso-optimizations/src/dynamic_filter.rs +++ b/miso-optimizations/src/dynamic_filter.rs @@ -19,7 +19,10 @@ use miso_common::watch::Watch; use miso_workflow::{WorkflowStep, scan::Scan}; +use miso_workflow_types::expr::Expr; +use miso_workflow_types::field::Field; use miso_workflow_types::join::JoinType; +use miso_workflow_types::project::ProjectField; use crate::pattern; @@ -39,7 +42,7 @@ impl DynamicFilter { impl Optimization for DynamicFilter { fn pattern(&self) -> Pattern { - pattern!(Scan [Count Limit TopN Summarize Sort Filter]*? Join) + pattern!(Scan [Count Limit TopN Summarize Sort Filter Project Extend Rename]*? Join) } fn apply(&self, steps: &[WorkflowStep], _groups: &[Group]) -> OptimizationResult { @@ -66,14 +69,11 @@ impl Optimization for DynamicFilter { &[] }; - let left_dcount = calculate_max_distinct_count( - left_join_field.to_string(), - left_scan, - left_steps_after_scan, - ) - .unwrap_or(self.max_distinct_values); + let left_dcount = + calculate_max_distinct_count(left_join_field.clone(), left_scan, left_steps_after_scan) + .unwrap_or(self.max_distinct_values); let right_dcount = calculate_max_distinct_count( - right_join_field.to_string(), + right_join_field.clone(), right_scan, right_steps_after_scan, ) @@ -137,8 +137,26 @@ impl Optimization for DynamicFilter { } } +fn resolve_fields( + fields: &mut [Field], + project_fields: &[ProjectField], + is_extend: bool, +) -> Option<()> { + for f in fields.iter_mut() { + if let Some(pf) = project_fields.iter().find(|pf| pf.to == *f) { + match &pf.from { + Expr::Field(source) => *f = source.clone(), + _ => return None, + } + } else if !is_extend { + return None; + } + } + Some(()) +} + fn calculate_max_distinct_count( - join_field: String, + join_field: Field, scan: &Scan, steps_after_scan: &[WorkflowStep], ) -> Option { @@ -160,19 +178,33 @@ fn calculate_max_distinct_count( return None; } prev_dcount = dcount.take(); - fields = summarize.by.iter().map(|bf| bf.name.to_string()).collect(); + fields = summarize.by.iter().map(|bf| bf.name.clone()).collect(); + } + + WorkflowStep::Project(pf) => { + resolve_fields(&mut fields, pf, false)?; + } + WorkflowStep::Extend(pf) => { + resolve_fields(&mut fields, pf, true)?; + } + + WorkflowStep::Rename(renames) => { + for f in fields.iter_mut() { + if let Some((from, _)) = renames.iter().find(|(_, to)| to == f) { + *f = from.clone(); + } + } } WorkflowStep::Sort(..) | WorkflowStep::Filter(..) => {} - // Unsupported (need to think about project & extend): _ => return None, } } let dcounts: Vec<_> = fields .iter() - .flat_map(|f| scan.get_field_stats(f)?.distinct_count) + .flat_map(|f| scan.get_field_stats(&f.to_string())?.distinct_count) .collect(); if dcounts.len() == fields.len() { @@ -192,7 +224,6 @@ mod tests { use async_trait::async_trait; use color_eyre::Result; - use std::collections::BTreeMap; use hashbrown::HashMap; use miso_connectors::{ @@ -200,7 +231,7 @@ mod tests { stats::{CollectionStats, ConnectorStats, FieldStats}, }; use miso_workflow::{Workflow, WorkflowStep, scan::Scan}; - use miso_workflow_types::{expr::Expr, join::Join, summarize::Summarize}; + use miso_workflow_types::{join::Join, value::Value}; use parking_lot::Mutex; use serde::{Deserialize, Serialize}; use test_case::test_case; @@ -208,7 +239,7 @@ mod tests { use super::{ DynamicFilter, JoinType, Optimization, OptimizationResult, calculate_max_distinct_count, }; - use crate::test_utils::{field, summarize_by}; + use crate::test_utils::{field, literal_project, rename_project, sort_asc, summarize_by}; #[derive(Debug, Clone, Serialize, Deserialize, Default)] struct TestHandle; @@ -284,65 +315,96 @@ mod tests { } fn calc(scan: &Scan, steps: &[WorkflowStep]) -> Option { - calculate_max_distinct_count("id".to_string(), scan, steps) + calculate_max_distinct_count(field("id"), scan, steps) + } + + #[test_case(vec![("id", 100)], &[] => Some(100); "field_stats_only")] + #[test_case(vec![], &[] => None; "no_stats_returns_none")] + #[test_case(vec![("other", 50)], &[] => None; "unrelated_field_stats")] + fn calc_stats_only(stats: Vec<(&str, u64)>, steps: &[WorkflowStep]) -> Option { + calc(&scan(stats), steps) + } + + #[test_case(vec![("id", 100)], 10 => Some(10); "limit_below_stats")] + #[test_case(vec![("id", 5)], 100 => Some(5); "stats_below_limit")] + #[test_case(vec![], 25 => Some(25); "limit_without_stats")] + fn calc_limit(stats: Vec<(&str, u64)>, limit: u64) -> Option { + calc(&scan(stats), &[WorkflowStep::Limit(limit)]) + } + + #[test_case(vec![("id", 100)], 10 => Some(10); "topn_below_stats")] + #[test_case(vec![("id", 5)], 100 => Some(5); "stats_below_topn")] + #[test_case(vec![], 25 => Some(25); "topn_without_stats")] + fn calc_topn(stats: Vec<(&str, u64)>, limit: u64) -> Option { + calc( + &scan(stats), + &[WorkflowStep::TopN(vec![sort_asc(field("id"))], limit)], + ) + } + + #[test_case(vec![("id", 100)] => Some(1); "count_returns_one")] + #[test_case(vec![] => Some(1); "count_without_stats")] + fn calc_count(stats: Vec<(&str, u64)>) -> Option { + calc(&scan(stats), &[WorkflowStep::Count]) + } + + #[test_case(&["a", "b"], vec![("a", 5), ("b", 7)] => Some(35); "multiplies_group_by_dcounts")] + #[test_case(&[], vec![("id", 100)] => Some(1); "empty_group_by_returns_one")] + #[test_case(&["a", "b"], vec![("a", 0), ("b", 100)] => Some(0); "zero_dcount_propagates")] + #[test_case(&["a", "b"], vec![("a", u64::MAX), ("b", 2)] => None; "overflow_returns_none")] + #[test_case(&["a", "b"], vec![("a", 10)] => None; "incomplete_stats_returns_none")] + fn calc_summarize(by: &[&str], stats: Vec<(&str, u64)>) -> Option { + calc(&scan(stats), &[summarize_by(by)]) } #[test] - fn limit_takes_minimum_of_limit_and_stats() { + fn summarize_with_limit_takes_minimum() { + let s = scan(vec![("cat", 100)]); assert_eq!( - calc(&scan(vec![("id", 100)]), &[WorkflowStep::Limit(10)]), + calc(&s, &[summarize_by(&["cat"]), WorkflowStep::Limit(10)]), Some(10) ); assert_eq!( - calc(&scan(vec![("id", 5)]), &[WorkflowStep::Limit(100)]), - Some(5) + calc(&s, &[WorkflowStep::Limit(10), summarize_by(&["cat"])]), + Some(10) ); - assert_eq!(calc(&scan(vec![]), &[WorkflowStep::Limit(25)]), Some(25)); } #[test] - fn summarize_multiplies_group_by_distinct_counts() { + fn summarize_falls_back_to_limit_when_stats_incomplete() { assert_eq!( calc( - &scan(vec![("a", 5), ("b", 7)]), - &[summarize_by(&["a", "b"])] + &scan(vec![("a", 10)]), + &[summarize_by(&["a", "b"]), WorkflowStep::Limit(50)] ), - Some(35) - ); - } - - #[test] - fn summarize_with_empty_group_by_returns_one() { - assert_eq!( - calc(&scan(vec![("id", 100)]), &[summarize_by(&[])]), - Some(1) + Some(50) ); } #[test] - fn summarize_with_limit_takes_minimum() { - let s = scan(vec![("category", 100)]); - assert_eq!( - calc(&s, &[summarize_by(&["category"]), WorkflowStep::Limit(10)]), - Some(10) - ); + fn overflow_with_limit_falls_back_to_limit() { assert_eq!( - calc(&s, &[WorkflowStep::Limit(10), summarize_by(&["category"])]), - Some(10) + calc( + &scan(vec![("a", u64::MAX), ("b", 2)]), + &[summarize_by(&["a", "b"]), WorkflowStep::Limit(100)] + ), + Some(100) ); } #[test] - fn summarize_falls_back_to_limit_when_stats_incomplete() { - let s = scan(vec![("a", 10)]); + fn count_after_summarize_returns_one() { assert_eq!( - calc(&s, &[summarize_by(&["a", "b"]), WorkflowStep::Limit(50)]), - Some(50) + calc( + &scan(vec![("cat", 100)]), + &[summarize_by(&["cat"]), WorkflowStep::Count] + ), + Some(1) ); } #[test] - fn two_summarizes_without_limit_uses_outer() { + fn two_summarizes_uses_outer() { let s = scan(vec![("a", 10), ("b", 20)]); assert_eq!( calc(&s, &[summarize_by(&["a"]), summarize_by(&["b"])]), @@ -351,7 +413,7 @@ mod tests { } #[test] - fn two_summarizes_with_limit_after_bails_out() { + fn two_summarizes_with_limit_after_second_bails_out() { let s = scan(vec![("a", 10), ("b", 20)]); assert_eq!( calc( @@ -382,71 +444,131 @@ mod tests { ); } - #[test] - fn count_after_summarize_returns_one() { - let s = scan(vec![("category", 100)]); - assert_eq!( - calc(&s, &[summarize_by(&["category"]), WorkflowStep::Count]), - Some(1) - ); + #[test_case( + vec![("original_id", 50)], + vec![WorkflowStep::Project(vec![rename_project("id", "original_id")])] + => Some(50); "project_rename_resolves_to_source_stats" + )] + #[test_case( + vec![("id", 100)], + vec![WorkflowStep::Project(vec![literal_project("id", Value::Int(1))])] + => None; "project_with_non_field_expr_bails" + )] + #[test_case( + vec![("id", 100)], + vec![WorkflowStep::Project(vec![rename_project("other", "id")])] + => None; "project_drops_join_field" + )] + #[test_case( + vec![("id", 100)], + vec![WorkflowStep::Project(vec![])] + => None; "empty_project_drops_join_field" + )] + #[test_case( + vec![("id", 100)], + vec![WorkflowStep::Extend(vec![literal_project("extra", Value::Int(42))])] + => Some(100); "extend_passthrough_when_field_not_in_extend_list" + )] + #[test_case( + vec![("raw", 30)], + vec![WorkflowStep::Extend(vec![rename_project("id", "raw")])] + => Some(30); "extend_rename_resolves_to_source_stats" + )] + #[test_case( + vec![("id", 100)], + vec![WorkflowStep::Extend(vec![literal_project("id", Value::Int(1))])] + => None; "extend_overwrites_join_field_with_literal_bails" + )] + fn calc_project_extend(stats: Vec<(&str, u64)>, steps: Vec) -> Option { + calc(&scan(stats), &steps) } - #[test] - fn overflow_in_group_by_multiplication() { - let s = scan(vec![("a", u64::MAX), ("b", 2)]); - assert_eq!(calc(&s, &[summarize_by(&["a", "b"])]), None); - assert_eq!( - calc(&s, &[summarize_by(&["a", "b"]), WorkflowStep::Limit(100)]), - Some(100) - ); + #[test_case( + vec![("original_id", 50)], + vec![WorkflowStep::Rename(vec![(field("original_id"), field("id"))])] + => Some(50); "rename_resolves_to_source_stats" + )] + #[test_case( + vec![("id", 100)], + vec![WorkflowStep::Rename(vec![(field("other"), field("renamed"))])] + => Some(100); "rename_passthrough_when_field_not_renamed" + )] + #[test_case( + vec![("raw", 30)], + vec![ + WorkflowStep::Rename(vec![(field("raw"), field("mid"))]), + WorkflowStep::Rename(vec![(field("mid"), field("id"))]), + ] + => Some(30); "chained_renames_resolve_through" + )] + fn calc_rename(stats: Vec<(&str, u64)>, steps: Vec) -> Option { + calc(&scan(stats), &steps) } #[test] - fn unsupported_steps_return_none() { - let s = scan(vec![("id", 100)]); - assert_eq!(calc(&s, &[WorkflowStep::Project(vec![])]), None); - assert_eq!(calc(&s, &[WorkflowStep::Extend(vec![])]), None); - assert_eq!(calc(&s, &[WorkflowStep::MuxLimit(10)]), None); + fn rename_then_summarize_resolves_through() { assert_eq!( calc( - &s, - &[WorkflowStep::MuxSummarize(Summarize { - aggs: BTreeMap::new(), - by: vec![] - })] + &scan(vec![("raw_key", 25)]), + &[ + WorkflowStep::Rename(vec![(field("raw_key"), field("mapped"))]), + summarize_by(&["mapped"]), + ] ), - None + Some(25) ); } #[test] - fn unsupported_step_in_middle_returns_none() { - let s = scan(vec![("id", 100)]); + fn extend_then_summarize_resolves_through() { assert_eq!( calc( - &s, + &scan(vec![("raw_key", 25)]), &[ - WorkflowStep::Filter(Expr::Literal(true.into())), - WorkflowStep::Project(vec![]), - WorkflowStep::Limit(10), + WorkflowStep::Extend(vec![rename_project("mapped", "raw_key")]), + summarize_by(&["mapped"]), ] ), - None + Some(25) ); } + #[test_case( + vec![WorkflowStep::Sort(vec![sort_asc(field("id"))])] + => Some(50); "sort_is_passthrough" + )] + #[test_case( + vec![WorkflowStep::Filter(miso_workflow_types::expr::Expr::Literal(true.into()))] + => Some(50); "filter_is_passthrough" + )] + #[test_case( + vec![ + WorkflowStep::Filter(miso_workflow_types::expr::Expr::Literal(true.into())), + WorkflowStep::Sort(vec![sort_asc(field("id"))]), + WorkflowStep::Limit(10), + ] + => Some(10); "mixed_passthrough_steps_with_limit" + )] + fn calc_passthrough(steps: Vec) -> Option { + calc(&scan(vec![("id", 50)]), &steps) + } + #[test] - fn zero_distinct_count_results_in_zero() { + fn unsupported_step_returns_none() { assert_eq!( - calc( - &scan(vec![("a", 0), ("b", 100)]), - &[summarize_by(&["a", "b"])] - ), - Some(0) + calc(&scan(vec![("id", 100)]), &[WorkflowStep::MuxLimit(10)]), + None ); } - fn apply_opt(left: u64, right: u64, join_type: JoinType) -> (bool, bool) { + struct ApplyResult { + changed: bool, + add_not: bool, + left_has_rx: bool, + right_has_rx: bool, + } + + fn apply_opt(left: u64, right: u64, join_type: JoinType) -> ApplyResult { let join_step = WorkflowStep::Join( Join { on: (field("id"), field("id")), @@ -459,29 +581,120 @@ mod tests { match DynamicFilter::new(100).apply(&steps, &[]) { OptimizationResult::Changed(s) => { + let WorkflowStep::Scan(left_scan) = &s[0] else { + unreachable!() + }; let WorkflowStep::Join(_, w) = &s[1] else { unreachable!() }; - let WorkflowStep::Scan(s) = &w.steps[0] else { + let WorkflowStep::Scan(right_scan) = &w.steps[0] else { unreachable!() }; - (true, s.add_not_to_dynamic_filter) + ApplyResult { + changed: true, + add_not: right_scan.add_not_to_dynamic_filter, + left_has_rx: left_scan.dynamic_filter_rx.is_some(), + right_has_rx: right_scan.dynamic_filter_rx.is_some(), + } } - _ => (false, false), + _ => ApplyResult { + changed: false, + add_not: false, + left_has_rx: false, + right_has_rx: false, + }, } } - #[test_case(10, 20, JoinType::Inner => (true, false); "inner_both_small")] - #[test_case(10, 200, JoinType::Inner => (true, false); "inner_left_small")] - #[test_case(200, 10, JoinType::Inner => (true, false); "inner_right_small")] - #[test_case(200, 200, JoinType::Inner => (false, false); "inner_both_large")] - #[test_case(10, 200, JoinType::Left => (true, false); "left_preferred_small")] - #[test_case(200, 10, JoinType::Left => (true, true); "left_opposite_small_adds_not")] - #[test_case(200, 200, JoinType::Left => (false, false); "left_both_large")] - #[test_case(200, 10, JoinType::Right => (true, false); "right_preferred_small")] - #[test_case(10, 200, JoinType::Right => (true, true); "right_opposite_small_adds_not")] - #[test_case(200, 200, JoinType::Right => (false, false); "right_both_large")] - fn dynamic_filter_join_types(left: u64, right: u64, jt: JoinType) -> (bool, bool) { - apply_opt(left, right, jt) + #[test_case(10, 20, JoinType::Inner; "inner_both_small_picks_left_as_producer")] + #[test_case(10, 200, JoinType::Inner; "inner_left_small")] + fn inner_left_produces(left: u64, right: u64, jt: JoinType) { + let r = apply_opt(left, right, jt); + assert!(r.changed); + assert!(!r.add_not); + assert!(!r.left_has_rx, "left is producer, should not have rx"); + assert!(r.right_has_rx, "right is consumer, should have rx"); + } + + #[test] + fn inner_right_small_produces_from_right() { + let r = apply_opt(200, 10, JoinType::Inner); + assert!(r.changed); + assert!(!r.add_not); + assert!(r.left_has_rx, "left is consumer, should have rx"); + assert!(!r.right_has_rx, "right is producer, should not have rx"); + } + + #[test_case(200, 200, JoinType::Inner; "inner_both_large")] + #[test_case(200, 200, JoinType::Left; "left_both_large")] + #[test_case(200, 200, JoinType::Right; "right_both_large")] + fn both_large_unchanged(left: u64, right: u64, jt: JoinType) { + let r = apply_opt(left, right, jt); + assert!(!r.changed); + } + + #[test] + fn left_join_preferred_side_small() { + let r = apply_opt(10, 200, JoinType::Left); + assert!(r.changed); + assert!(!r.add_not); + assert!(!r.left_has_rx, "left is producer"); + assert!(r.right_has_rx, "right is consumer"); + } + + #[test] + fn left_join_opposite_side_small_adds_not() { + let r = apply_opt(200, 10, JoinType::Left); + assert!(r.changed); + assert!(r.add_not); + } + + #[test] + fn right_join_preferred_side_small() { + let r = apply_opt(200, 10, JoinType::Right); + assert!(r.changed); + assert!(!r.add_not); + assert!(r.left_has_rx, "left is consumer"); + assert!(!r.right_has_rx, "right is producer"); + } + + #[test] + fn right_join_opposite_side_small_adds_not() { + let r = apply_opt(10, 200, JoinType::Right); + assert!(r.changed); + assert!(r.add_not); + } + + #[test] + fn outer_join_small_side_adds_not() { + let r = apply_opt(10, 200, JoinType::Outer); + assert!(r.changed); + assert!(r.add_not); + } + + #[test] + fn apply_with_intermediate_steps_between_scan_and_join() { + let right_workflow = Workflow::new(vec![WorkflowStep::Scan(scan(vec![("id", 200)]))]); + let steps = vec![ + WorkflowStep::Scan(scan(vec![("id", 10)])), + WorkflowStep::Limit(5), + WorkflowStep::Join( + Join { + on: (field("id"), field("id")), + type_: JoinType::Inner, + partitions: 1, + }, + right_workflow, + ), + ]; + + let result = DynamicFilter::new(100).apply(&steps, &[]); + match result { + OptimizationResult::Changed(s) => { + assert_eq!(s.len(), 3); + assert!(matches!(&s[1], WorkflowStep::Limit(5))); + } + _ => panic!("expected Changed"), + } } } diff --git a/miso-workflow/src/tests.rs b/miso-workflow/src/tests.rs index 0fd8c94..1033a43 100644 --- a/miso-workflow/src/tests.rs +++ b/miso-workflow/src/tests.rs @@ -243,8 +243,8 @@ async fn check_multi_connectors( expected: &str, should_cancel: bool, apply_filter_tx: Option>, - run_only_once: bool, workflow_limits: WorkflowLimits, + max_dynamic_filter_distinct_values: Option, ) -> Result<()> { let expected_logs = { let mut v: Vec<_> = serde_json::from_str::>(expected) @@ -291,7 +291,10 @@ async fn check_multi_connectors( let steps = to_workflow_steps(&connectors, &views, parse(query).expect("parse query")) .expect("workflow steps to compile"); - let optimizer = Optimizer::default(); + let optimizer = match max_dynamic_filter_distinct_values { + Some(max) => Optimizer::with_dynamic_filtering(max), + None => Optimizer::default(), + }; let steps_cloned = steps.clone(); let optimized_steps = spawn_blocking(move || optimizer.optimize(steps_cloned)).await?; @@ -306,6 +309,7 @@ async fn check_multi_connectors( .and_then(|test_run_str| test_run_str.parse().ok()) .unwrap_or(1); + let run_only_once = apply_filter_tx.is_some(); if run_only_once || test_runs == 1 { assert_workflows( no_optimizations_workflow, @@ -339,8 +343,8 @@ async fn check_multi_collection( expect: &str, cancel: Option, apply_filter_tx: Option>, - run_only_once: Option, workflow_limits: Option, + max_dynamic_filter_distinct_values: Option, ) -> Result<()> { check_multi_connectors( query, @@ -349,8 +353,8 @@ async fn check_multi_collection( expect, cancel.unwrap_or(false), apply_filter_tx, - run_only_once.unwrap_or(false), workflow_limits.unwrap_or_default(), + max_dynamic_filter_distinct_values, ) .await } @@ -1160,6 +1164,42 @@ async fn project_summarize_bin() -> Result<()> { .await } +fn assert_dynamic_filter( + rx: std::sync::mpsc::Receiver, + expected_values: Vec, + expect_not: bool, +) { + let ast = rx.recv().expect("recv() apply dynamic filter"); + + let (field_box, mut actual_vec) = match (expect_not, ast) { + (true, Expr::Not(inner)) => match *inner { + Expr::In(f, v) => (f, v), + other => panic!("Expected Expr::Not(Expr::In(...)), but inner was: {other:?}"), + }, + (false, Expr::In(f, v)) => (f, v), + (true, other) => panic!("Expected Expr::Not(...), but got: {other:?}"), + (false, other) => panic!("Expected Expr::In(...), but got: {other:?}"), + }; + + assert_eq!(field_box, Box::new(Expr::Field(field_unwrap!("id")))); + + let cmp = |a: &Expr, b: &Expr| -> Ordering { + match (a, b) { + (Expr::Literal(a), Expr::Literal(b)) => a.cmp(b), + _ => panic!("Unexpected Expr variants: {a:?} vs {b:?}"), + } + }; + let mut expected_vec: Vec = expected_values.into_iter().map(Expr::Literal).collect(); + actual_vec.sort_by(cmp); + expected_vec.sort_by(cmp); + assert_eq!(actual_vec, expected_vec); + + assert!(matches!( + rx.try_recv(), + Err(std::sync::mpsc::TryRecvError::Empty | std::sync::mpsc::TryRecvError::Disconnected) + )); +} + #[tokio::test] #[test_case(1)] #[test_case(10)] @@ -1183,46 +1223,11 @@ async fn join_inner(partitions: usize) -> Result<()> { ]"# ) .apply_filter_tx(tx) - .run_only_once(true) .call() .await .context("check multi collection")?; - let ast = rx.recv().context("recv() apply dynamic filter")?; - match ast { - Expr::In(id_box, mut actual_vec) => { - assert_eq!(id_box, Box::new(Expr::Field(field_unwrap!("id")))); - - let mut expected_vec = vec![ - Expr::Literal(1.into()), - Expr::Literal(2.into()), - Expr::Literal(3.into()), - ]; - - let compare_filter_asts = |a: &Expr, b: &Expr| -> Ordering { - match (a, b) { - (Expr::Literal(val_a), Expr::Literal(val_b)) => val_a.cmp(val_b), - _ => { - panic!("Unexpected Expr variants in Vec during comparison: {a:?} vs {b:?}") - } - } - }; - - actual_vec.sort_by(compare_filter_asts); - expected_vec.sort_by(compare_filter_asts); - - assert_eq!(actual_vec, expected_vec); - } - _ => { - panic!("Expected Expr::In variant, but got: {ast:?}"); - } - } - - assert!(matches!( - rx.try_recv(), - Err(std::sync::mpsc::TryRecvError::Empty | std::sync::mpsc::TryRecvError::Disconnected) - )); - + assert_dynamic_filter(rx, vec![1.into(), 2.into(), 3.into()], false); Ok(()) } @@ -1256,6 +1261,8 @@ async fn join_outer(partitions: usize) -> Result<()> { #[test_case(1)] #[test_case(10)] async fn join_left(partitions: usize) -> Result<()> { + let (tx, rx) = std::sync::mpsc::channel(); + check_multi_collection() .query( &format!(r#"test.left | join kind=left hint.partitions={partitions} (test.right) on id"#) @@ -1271,14 +1278,21 @@ async fn join_left(partitions: usize) -> Result<()> { {"id": 3, "value": "three"} ]"# ) + .apply_filter_tx(tx) .call() .await + .context("check multi collection")?; + + assert_dynamic_filter(rx, vec![1.into(), 2.into(), 3.into()], false); + Ok(()) } #[tokio::test] #[test_case(1)] #[test_case(10)] async fn join_right(partitions: usize) -> Result<()> { + let (tx, rx) = std::sync::mpsc::channel(); + check_multi_collection() .query( &format!(r#"test.left | join kind=right hint.partitions={partitions} (test.right) on id"#) @@ -1294,8 +1308,75 @@ async fn join_right(partitions: usize) -> Result<()> { {"id": 4, "value": "FOUR"} ]"# ) + .apply_filter_tx(tx) + .call() + .await + .context("check multi collection")?; + + assert_dynamic_filter(rx, vec![1.into(), 2.into(), 4.into()], false); + Ok(()) +} + +#[tokio::test] +#[test_case(1)] +#[test_case(10)] +async fn join_left_dynamic_filter_not(partitions: usize) -> Result<()> { + let (tx, rx) = std::sync::mpsc::channel(); + + check_multi_collection() + .query( + &format!(r#"test.left | join kind=left hint.partitions={partitions} (test.right) on id"#) + ) + .input(btreemap!{ + "left" => r#"[{"id": 1, "value": "one"}, {"id": 2, "value": "two"}, {"id": 3, "value": "three"}]"#, + "right" => r#"[{"id": 1, "value": "ONE"}, {"id": 2, "value": "TWO"}]"#, + }) + .expect( + r#"[ + {"id": 1, "value": "one"}, + {"id": 2, "value": "two"}, + {"id": 3, "value": "three"} + ]"# + ) + .apply_filter_tx(tx) + .max_dynamic_filter_distinct_values(3) .call() .await + .context("check multi collection")?; + + assert_dynamic_filter(rx, vec![1.into(), 2.into()], true); + Ok(()) +} + +#[tokio::test] +#[test_case(1)] +#[test_case(10)] +async fn join_right_dynamic_filter_not(partitions: usize) -> Result<()> { + let (tx, rx) = std::sync::mpsc::channel(); + + check_multi_collection() + .query( + &format!(r#"test.left | join kind=right hint.partitions={partitions} (test.right) on id"#) + ) + .input(btreemap!{ + "left" => r#"[{"id": 1, "value": "one"}, {"id": 2, "value": "TWO"}]"#, + "right" => r#"[{"id": 1, "value": "ONE"}, {"id": 2, "value": "two"}, {"id": 3, "value": "three"}]"#, + }) + .expect( + r#"[ + {"id": 1, "value": "ONE"}, + {"id": 2, "value": "two"}, + {"id": 3, "value": "three"} + ]"# + ) + .apply_filter_tx(tx) + .max_dynamic_filter_distinct_values(3) + .call() + .await + .context("check multi collection")?; + + assert_dynamic_filter(rx, vec![1.into(), 2.into()], true); + Ok(()) } #[tokio::test]