diff --git a/crates/polars-io/src/arrow_predicate.rs b/crates/polars-io/src/arrow_predicate.rs new file mode 100644 index 000000000000..9c1c6f442b30 --- /dev/null +++ b/crates/polars-io/src/arrow_predicate.rs @@ -0,0 +1,57 @@ +use polars_core::datatypes::{TimeUnit, TimeZone}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +// This is meant to mimic `pyarrow.compute.Expression` API as closely as possible to make it +// easier to convert directly to pyarrow predicates applied at the python/scan level. +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ArrowPredicate { + Column(String), + Literal(LiteralValue), + // Binary ops + Comparison { + left: Box<ArrowPredicate>, + op: ComparisonOp, + right: Box<ArrowPredicate>, + }, + // Logicals + And(Box<ArrowPredicate>, Box<ArrowPredicate>), + Or(Box<ArrowPredicate>, Box<ArrowPredicate>), + Xor(Box<ArrowPredicate>, Box<ArrowPredicate>), + Not(Box<ArrowPredicate>), + // Methods that turn into masks + IsNull(Box<ArrowPredicate>), + IsIn { + expr: Box<ArrowPredicate>, + values: Vec<LiteralValue>, + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ComparisonOp { + Eq, + NotEq, + Lt, + Lte, + Gt, + Gte, +} + +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum LiteralValue { + Null, + Int(i64), + Float(f64), + String(String), + Bool(bool), + Date(i32), + // Should mimic python datetime module semantics + Datetime { + value: i64, + time_unit: TimeUnit, + time_zone: Option<TimeZone>, + }, +} diff --git a/crates/polars-io/src/lib.rs b/crates/polars-io/src/lib.rs index d239f5ebb5a1..49c1a862d75c 100644 --- a/crates/polars-io/src/lib.rs +++ b/crates/polars-io/src/lib.rs @@ -7,6 +7,7 @@ #![allow(ambiguous_glob_reexports)] extern crate core; +pub mod arrow_predicate; #[cfg(feature = "avro")] pub mod avro; #[cfg(feature = "catalog")] diff --git a/crates/polars-mem-engine/src/arrow_predicate_pyo3.rs b/crates/polars-mem-engine/src/arrow_predicate_pyo3.rs new file mode 100644 index 000000000000..dcf6a1f349fe --- /dev/null 
+++ b/crates/polars-mem-engine/src/arrow_predicate_pyo3.rs @@ -0,0 +1,121 @@ +use polars_core::datatypes::TimeUnit; +use polars_io::arrow_predicate::{ArrowPredicate, ComparisonOp, LiteralValue}; +use pyo3::IntoPyObjectExt; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyList}; + +pub fn arrow_predicate_to_pyobject<'py>( + py: Python<'py>, + pred: &ArrowPredicate, +) -> PyResult<Bound<'py, PyAny>> { + let pc = py.import("pyarrow.compute")?; + build_expr(py, &pc, pred) +} + +// Conversion of ArrowPredicate enums to pyarrow expressions +fn build_expr<'py>( + py: Python<'py>, + pc: &Bound<'py, PyAny>, + p: &ArrowPredicate, +) -> PyResult<Bound<'py, PyAny>> { + match p { + ArrowPredicate::Column(name) => pc.call_method1("field", (name.as_str(),)), + ArrowPredicate::Literal(lv) => { + let val = literal_to_py(py, lv)?; + pc.call_method1("scalar", (val,)) + }, + ArrowPredicate::Comparison { left, op, right } => { + let l = build_expr(py, pc, left)?; + let r = build_expr(py, pc, right)?; + // This would stop working if pyarrow changed their overloaded operators, btw. 
+ let method = match op { + ComparisonOp::Eq => "__eq__", + ComparisonOp::NotEq => "__ne__", + ComparisonOp::Lt => "__lt__", + ComparisonOp::Lte => "__le__", + ComparisonOp::Gt => "__gt__", + ComparisonOp::Gte => "__ge__", + }; + l.call_method1(method, (r,)) + }, + ArrowPredicate::And(l, r) => { + let l = build_expr(py, pc, l)?; + let r = build_expr(py, pc, r)?; + l.call_method1("__and__", (r,)) + }, + ArrowPredicate::Or(l, r) => { + let l = build_expr(py, pc, l)?; + let r = build_expr(py, pc, r)?; + l.call_method1("__or__", (r,)) + }, + ArrowPredicate::Xor(l, r) => { + let l = build_expr(py, pc, l)?; + let r = build_expr(py, pc, r)?; + l.call_method1("__xor__", (r,)) + }, + ArrowPredicate::Not(inner) => { + let i = build_expr(py, pc, inner)?; + i.call_method0("__invert__") + }, + ArrowPredicate::IsNull(inner) => { + let i = build_expr(py, pc, inner)?; + i.call_method0("is_null") + }, + ArrowPredicate::IsIn { expr, values } => { + if values.is_empty() { + return pc.call_method1("scalar", (false,)); + } + let e = build_expr(py, pc, expr)?; + let py_values: Vec<Py<PyAny>> = values + .iter() + .map(|v| literal_to_py(py, v).map(|b| b.unbind())) + .collect::<PyResult<Vec<_>>>()?; + let list = PyList::new(py, &py_values)?; + e.call_method1("isin", (list,)) + }, + } +} + +fn literal_to_py<'py>(py: Python<'py>, lv: &LiteralValue) -> PyResult<Bound<'py, PyAny>> { + match lv { + LiteralValue::Null => Ok(py.None().into_bound(py)), + LiteralValue::Int(v) => Ok(v.into_pyobject(py)?.into_any()), + LiteralValue::Float(v) => Ok(v.into_pyobject(py)?.into_any()), + LiteralValue::String(s) => Ok(s.into_pyobject(py)?.into_any()), + LiteralValue::Bool(v) => v.into_bound_py_any(py), + LiteralValue::Date(days) => { + let dt_mod = py.import("datetime")?; + let epoch = dt_mod.getattr("date")?.call1((1970i32, 1i32, 1i32))?; + let delta = dt_mod.getattr("timedelta")?.call1((*days as i64,))?; + epoch.call_method1("__add__", (delta,)) + }, + LiteralValue::Datetime { + value, + time_unit, + time_zone, + } => { + // This conversion does 
not feel idiomatic but it's probably the best way to do this + // within the confines of the python standard lib. + let dt_mod = py.import("datetime")?; + let epoch = dt_mod.getattr("datetime")?.call1((1970i32, 1i32, 1i32))?; + let micros: i64 = match time_unit { + TimeUnit::Nanoseconds => value / 1000, + TimeUnit::Microseconds => *value, + TimeUnit::Milliseconds => value * 1000, + }; + let kwargs = PyDict::new(py); + kwargs.set_item("microseconds", micros)?; + let delta = dt_mod.getattr("timedelta")?.call((), Some(&kwargs))?; + let naive = epoch.call_method1("__add__", (delta,))?; + if let Some(tz) = time_zone { + let zi = py.import("zoneinfo")?; + let tz_obj = zi.getattr("ZoneInfo")?.call1((tz.to_string(),))?; + let kw = PyDict::new(py); + kw.set_item("tzinfo", tz_obj)?; + naive.call_method("replace", (), Some(&kw)) + } else { + Ok(naive) + } + }, + } +} diff --git a/crates/polars-mem-engine/src/executors/scan/python_scan.rs b/crates/polars-mem-engine/src/executors/scan/python_scan.rs index 38a52228c0f4..42411ec8ed10 100644 --- a/crates/polars-mem-engine/src/executors/scan/python_scan.rs +++ b/crates/polars-mem-engine/src/executors/scan/python_scan.rs @@ -71,7 +71,11 @@ impl Executor for PythonScanExec { let mut could_serialize_predicate = true; let predicate = match &self.options.predicate { - PythonPredicate::PyArrow(s) => s.into_bound_py_any(py).unwrap(), + PythonPredicate::PyArrow(pred) => { + crate::arrow_predicate_pyo3::arrow_predicate_to_pyobject(py, pred).map_err( + |e| polars_err!(ComputeError: "failed to build pyarrow predicate: {}", e), + )? 
+ }, PythonPredicate::None => None::<()>.into_bound_py_any(py).unwrap(), PythonPredicate::Polars(_) => { assert!(self.predicate.is_some(), "should be set"); diff --git a/crates/polars-mem-engine/src/lib.rs b/crates/polars-mem-engine/src/lib.rs index 19ce74cbb234..d400733821a3 100644 --- a/crates/polars-mem-engine/src/lib.rs +++ b/crates/polars-mem-engine/src/lib.rs @@ -2,6 +2,8 @@ feature = "allow_unused", allow(unused, dead_code, irrefutable_let_patterns) )] // Maybe be caused by some feature +#[cfg(feature = "python")] +pub mod arrow_predicate_pyo3; mod executors; mod planner; mod prelude; diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index b6d4734c80f4..43a7c6d94a07 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -196,19 +196,18 @@ pub fn python_scan_predicate( // Convert to a pyarrow eval string. if matches!(options.python_source, PythonScanSource::Pyarrow) { use polars_core::config::verbose_print_sensitive; + use polars_io::arrow_predicate::ArrowPredicate; use polars_plan::plans::MintermIter; // If there is a `head`, that comes before the filter and we post-apply - // the predicate in the engine. No need to transpile to `pyarrow.predicate` + // the predicate in the engine. let residual_predicate_expr_ir = if options.n_rows.is_none() { // Split into AND-minterms and convert each independently. 
let mut residual_predicate_nodes: Vec<Node> = vec![]; - let parts: Vec<String> = MintermIter::new(e.node(), expr_arena) + let parts: Vec<ArrowPredicate> = MintermIter::new(e.node(), expr_arena) .filter_map(|node| { - let result = polars_plan::plans::python::pyarrow::predicate_to_pa( - node, - expr_arena, - Default::default(), + let result = polars_plan::plans::python::pyarrow::predicate_to_arrow_pred( + node, expr_arena, ); if result.is_none() { residual_predicate_nodes.push(node); @@ -220,11 +219,17 @@ pub fn python_scan_predicate( let predicate_pa = match parts.len() { 0 => None, 1 => Some(parts.into_iter().next().unwrap()), - _ => Some(format!("({})", parts.join(" & "))), + _ => { + let combined = parts + .into_iter() + .reduce(|acc, p| ArrowPredicate::And(Box::new(acc), Box::new(p))) + .unwrap(); + Some(combined) + }, }; - if let Some(eval_str) = predicate_pa { - options.predicate = PythonPredicate::PyArrow(eval_str); + if let Some(pred) = predicate_pa { + options.predicate = PythonPredicate::PyArrow(pred); residual_predicate_nodes .into_iter() @@ -247,8 +252,8 @@ pub fn python_scan_predicate( verbose_print_sensitive(|| { let predicate_pa_verbose_msg = match &options.predicate { - PythonPredicate::PyArrow(p) => p, - _ => "", + PythonPredicate::PyArrow(p) => format!("{:?}", p), + _ => "".to_string(), }; format!( diff --git a/crates/polars-plan/src/plans/ir/dot.rs b/crates/polars-plan/src/plans/ir/dot.rs index 603dcfcaf69e..cc0efbd4f64e 100644 --- a/crates/polars-plan/src/plans/ir/dot.rs +++ b/crates/polars-plan/src/plans/ir/dot.rs @@ -139,7 +139,7 @@ impl<'a> IRDotDisplay<'a> { PythonScan { options } => { let predicate = match &options.predicate { PythonPredicate::Polars(e) => format!("{}", self.display_expr(e)), - PythonPredicate::PyArrow(s) => s.clone(), + PythonPredicate::PyArrow(p) => format!("{:?}", p), PythonPredicate::None => "none".to_string(), }; let with_columns = NumColumns(options.with_columns.as_ref().map(|s| s.as_ref())); diff --git 
a/crates/polars-plan/src/plans/python/pyarrow.rs b/crates/polars-plan/src/plans/python/pyarrow.rs index 103dbc793014..19c0d2f9d013 100644 --- a/crates/polars-plan/src/plans/python/pyarrow.rs +++ b/crates/polars-plan/src/plans/python/pyarrow.rs @@ -3,6 +3,7 @@ use std::fmt::Write; use polars_core::datatypes::AnyValue; use polars_core::prelude::{DataType, TimeUnit, TimeZone}; use polars_core::series::Series; +use polars_io::arrow_predicate::{ArrowPredicate, ComparisonOp, LiteralValue as ArrowLiteralValue}; use polars_utils::pl_str::PlSmallStr; use crate::prelude::*; @@ -279,3 +280,222 @@ pub fn predicate_to_pa( _ => None, } } + +fn comparison_op(op: &Operator) -> Option<ComparisonOp> { + Some(match op { + Operator::Eq => ComparisonOp::Eq, + Operator::NotEq => ComparisonOp::NotEq, + Operator::Lt => ComparisonOp::Lt, + Operator::LtEq => ComparisonOp::Lte, + Operator::Gt => ComparisonOp::Gt, + Operator::GtEq => ComparisonOp::Gte, + _ => return None, + }) +} + +fn anyvalue_to_arrow_literal(av: AnyValue<'_>) -> Option<ArrowLiteralValue> { + let dtype = av.dtype(); + match av.as_borrowed() { + AnyValue::Null => Some(ArrowLiteralValue::Null), + AnyValue::Boolean(v) => Some(ArrowLiteralValue::Bool(v)), + AnyValue::String(s) => { + let s = sanitize(s)?; + Some(ArrowLiteralValue::String(s.to_string())) + }, + #[cfg(feature = "dtype-date")] + AnyValue::Date(v) => Some(ArrowLiteralValue::Date(v)), + #[cfg(feature = "dtype-datetime")] + AnyValue::Datetime(v, tu, tz) => Some(ArrowLiteralValue::Datetime { + value: v, + time_unit: tu, + time_zone: tz.cloned(), + }), + AnyValue::Binary(_) | AnyValue::List(_) => None, + #[cfg(feature = "dtype-array")] + AnyValue::Array(_, _) => None, + #[cfg(feature = "dtype-struct")] + AnyValue::Struct(_, _, _) => None, + av => { + if dtype.is_float() { + av.extract::<f64>().map(ArrowLiteralValue::Float) + } else if dtype.is_integer() { + av.extract::<i64>().map(ArrowLiteralValue::Int) + } else { + None + } + }, + } +} + +fn series_to_arrow_literal_list(s: &Series) -> Option<Vec<ArrowLiteralValue>> { + let mut out 
= Vec::with_capacity(s.len()); + for av in s.iter() { + out.push(anyvalue_to_arrow_literal(av)?); + } + Some(out) +} + +pub fn predicate_to_arrow_pred( + predicate: Node, + expr_arena: &Arena<AExpr>, +) -> Option<ArrowPredicate> { + match expr_arena.get(predicate) { + AExpr::BinaryExpr { left, right, op } => { + if !op.is_comparison_or_bitwise() { + return None; + } + let left = predicate_to_arrow_pred(*left, expr_arena)?; + let right = predicate_to_arrow_pred(*right, expr_arena)?; + if let Some(cmp) = comparison_op(op) { + Some(ArrowPredicate::Comparison { + left: Box::new(left), + op: cmp, + right: Box::new(right), + }) + } else { + match op { + Operator::And | Operator::LogicalAnd => { + Some(ArrowPredicate::And(Box::new(left), Box::new(right))) + }, + Operator::Or | Operator::LogicalOr => { + Some(ArrowPredicate::Or(Box::new(left), Box::new(right))) + }, + Operator::Xor => Some(ArrowPredicate::Xor(Box::new(left), Box::new(right))), + _ => None, + } + } + }, + AExpr::Column(name) => { + let name = sanitize(name)?; + Some(ArrowPredicate::Column(name.to_string())) + }, + AExpr::Literal(LiteralValue::Series(_)) => None, + AExpr::Literal(lv) => { + let av = lv.to_any_value()?; + let lit = anyvalue_to_arrow_literal(av)?; + Some(ArrowPredicate::Literal(lit)) + }, + #[cfg(feature = "is_in")] + AExpr::Function { + function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { nulls_equal }), + input, + .. + } => { + let col = predicate_to_arrow_pred(input.first()?.node(), expr_arena)?; + let rhs_node = input.get(1)?.node(); + + let values = if let AExpr::Literal(lv) = expr_arena.get(rhs_node) + && lv.get_datatype().is_list() + { + use polars_core::prelude::ExplodeOptions; + + let mut haystack_series = if let LiteralValue::Series(s) = lv + && s.dtype().is_list() + && s.len() == 1 + { + if s.null_count() == 0 { + s.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: false, + }) + .ok()? 
+ } else { + Series::full_null(PlSmallStr::EMPTY, 0, &DataType::Null) + } + } else if let Some(AnyValue::List(s)) = lv.to_any_value() { + s + } else if lv.is_null() { + Series::full_null(PlSmallStr::EMPTY, 0, &DataType::Null) + } else { + return None; + }; + + let converted_len = haystack_series.len() + - if *nulls_equal { + 0 + } else { + haystack_series.null_count() + }; + + if converted_len > LIST_ITEM_LIMIT { + return None; + } + + if !*nulls_equal { + haystack_series = haystack_series.drop_nulls(); + } + + series_to_arrow_literal_list(&haystack_series)? + } else { + return None; + }; + + // We can simplify this to a false mask if the list is empty. + if values.is_empty() { + Some(ArrowPredicate::Literal(ArrowLiteralValue::Bool(false))) + } else { + Some(ArrowPredicate::IsIn { + expr: Box::new(col), + values, + }) + } + }, + #[cfg(feature = "is_between")] + AExpr::Function { + function: IRFunctionExpr::Boolean(IRBooleanFunction::IsBetween { closed }), + input, + .. + } => { + if !matches!(expr_arena.get(input.first()?.node()), AExpr::Column(_)) { + return None; + } + let col = predicate_to_arrow_pred(input.first()?.node(), expr_arena)?; + let left_cmp_op = match closed { + ClosedInterval::None | ClosedInterval::Right => ComparisonOp::Gt, + ClosedInterval::Both | ClosedInterval::Left => ComparisonOp::Gte, + }; + let right_cmp_op = match closed { + ClosedInterval::None | ClosedInterval::Left => ComparisonOp::Lt, + ClosedInterval::Both | ClosedInterval::Right => ComparisonOp::Lte, + }; + + let lower = predicate_to_arrow_pred(input.get(1)?.node(), expr_arena)?; + let upper = predicate_to_arrow_pred(input.get(2)?.node(), expr_arena)?; + + let lower_cmp = ArrowPredicate::Comparison { + left: Box::new(col.clone()), + op: left_cmp_op, + right: Box::new(lower), + }; + let upper_cmp = ArrowPredicate::Comparison { + left: Box::new(col), + op: right_cmp_op, + right: Box::new(upper), + }; + Some(ArrowPredicate::And( + Box::new(lower_cmp), + Box::new(upper_cmp), + )) + }, + 
AExpr::Function { + function, input, .. + } => { + let input = input.first().unwrap().node(); + let input = predicate_to_arrow_pred(input, expr_arena)?; + + match function { + IRFunctionExpr::Boolean(IRBooleanFunction::Not) => { + Some(ArrowPredicate::Not(Box::new(input))) + }, + IRFunctionExpr::Boolean(IRBooleanFunction::IsNull) => { + Some(ArrowPredicate::IsNull(Box::new(input))) + }, + IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull) => Some(ArrowPredicate::Not( + Box::new(ArrowPredicate::IsNull(Box::new(input))), + )), + _ => None, + } + }, + _ => None, + } +} diff --git a/crates/polars-plan/src/plans/python/source.rs b/crates/polars-plan/src/plans/python/source.rs index 243f8d1272b9..988f8acdec2b 100644 --- a/crates/polars-plan/src/plans/python/source.rs +++ b/crates/polars-plan/src/plans/python/source.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use polars_core::schema::SchemaRef; +use polars_io::arrow_predicate::ArrowPredicate; use polars_utils::python_function::PythonFunction; #[cfg(feature = "ir_serde")] use serde::{Deserialize, Serialize}; @@ -8,7 +9,7 @@ use serde::{Deserialize, Serialize}; use crate::dsl::python_dsl::PythonScanSource; use crate::plans::{ExprIR, PlSmallStr}; -#[derive(Clone, PartialEq, Eq, Debug, Default)] +#[derive(Clone, PartialEq, Debug, Default)] #[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] pub struct PythonOptions { /// A function that returns a Python Generator. 
@@ -32,12 +33,10 @@ pub struct PythonOptions { pub is_pure: bool, } -#[derive(Clone, PartialEq, Eq, Debug, Default)] +#[derive(Clone, PartialEq, Debug, Default)] #[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] pub enum PythonPredicate { - // A pyarrow predicate python expression - // can be evaluated with python.eval - PyArrow(String), + PyArrow(ArrowPredicate), Polars(ExprIR), #[default] None, diff --git a/crates/polars-plan/src/plans/visitor/hash.rs b/crates/polars-plan/src/plans/visitor/hash.rs index 8362ad96f22b..0cf51cfd8640 100644 --- a/crates/polars-plan/src/plans/visitor/hash.rs +++ b/crates/polars-plan/src/plans/visitor/hash.rs @@ -63,7 +63,7 @@ fn hash_python_predicate( std::mem::discriminant(pred).hash(state); match pred { PythonPredicate::None => {}, - PythonPredicate::PyArrow(s) => s.hash(state), + PythonPredicate::PyArrow(p) => format!("{:?}", p).hash(state), PythonPredicate::Polars(e) => e.traverse_and_hash(expr_arena, state), } } diff --git a/crates/polars-python/src/lazyframe/visitor/nodes.rs b/crates/polars-python/src/lazyframe/visitor/nodes.rs index 416c7c81a797..7dbf7a4c1819 100644 --- a/crates/polars-python/src/lazyframe/visitor/nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/nodes.rs @@ -378,7 +378,9 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult> { python_src, match &options.predicate { PythonPredicate::None => py.None(), - PythonPredicate::PyArrow(s) => ("pyarrow", s).into_py_any(py)?, + PythonPredicate::PyArrow(p) => { + ("pyarrow", format!("{:?}", p)).into_py_any(py)? 
+ }, PythonPredicate::Polars(e) => ("polars", e.node().0).into_py_any(py)?, }, options diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index 49840b67928b..b999df03451f 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -1414,7 +1414,10 @@ fn to_graph_rec<'a>( let mut could_serialize_predicate = true; let predicate = match &options.predicate { - PythonPredicate::PyArrow(s) => s.into_bound_py_any(py).unwrap(), + PythonPredicate::PyArrow(pred) => { + polars_mem_engine::arrow_predicate_pyo3::arrow_predicate_to_pyobject(py, pred) + .map_err(|e| polars_err!(ComputeError: "failed to build pyarrow predicate: {}", e))? + }, PythonPredicate::None => None::<()>.into_bound_py_any(py).unwrap(), PythonPredicate::Polars(_) => { assert!(pl_predicate.is_some(), "should be set"); diff --git a/py-polars/src/polars/io/pyarrow_dataset/anonymous_scan.py b/py-polars/src/polars/io/pyarrow_dataset/anonymous_scan.py index 3f98e465c9fc..d5db4b007e93 100644 --- a/py-polars/src/polars/io/pyarrow_dataset/anonymous_scan.py +++ b/py-polars/src/polars/io/pyarrow_dataset/anonymous_scan.py @@ -4,12 +4,12 @@ from typing import TYPE_CHECKING, Any import polars as pl -from polars._dependencies import pyarrow as pa if TYPE_CHECKING: from collections.abc import Iterator from polars import DataFrame, LazyFrame + from polars._dependencies import pyarrow as pa def _scan_pyarrow_dataset( @@ -51,7 +51,7 @@ def _scan_pyarrow_dataset( def _scan_pyarrow_dataset_impl( ds: pa.dataset.Dataset, with_columns: list[str] | None, - predicate: str | bytes | None, + predicate: pa.compute.Expression | None, n_rows: int | None, batch_size: int | None = None, *, @@ -68,7 +68,7 @@ def _scan_pyarrow_dataset_impl( with_columns Columns that are projected. 
predicate - pyarrow expression string (when `allow_pyarrow_filter=True`) or + pyarrow expression (when `allow_pyarrow_filter=True`) or serialized Polars predicate bytes (when `allow_pyarrow_filter=False`). n_rows: Materialize only `n` rows from the arrow dataset. @@ -80,12 +80,6 @@ def _scan_pyarrow_dataset_impl( user_batch_size User-specified `batch_size` (takes precedence over Rust-provided `batch_size`). - Warnings - -------- - Don't use this if you accept untrusted user inputs. Predicates will be evaluated - with python 'eval'. There is sanitation in place, but it is a possible attack - vector. - Returns ------- tuple[Iterator[DataFrame], bool] @@ -98,30 +92,8 @@ def _scan_pyarrow_dataset_impl( filter_ = None if allow_pyarrow_filter and predicate is not None: - from polars._utils.convert import ( - to_py_date, - to_py_datetime, - to_py_time, - to_py_timedelta, - ) - from polars.datatypes import Date, Datetime, Duration - - v = eval( - predicate, - { - "pa": pa, - "Date": Date, - "Datetime": Datetime, - "Duration": Duration, - "to_py_date": to_py_date, - "to_py_datetime": to_py_datetime, - "to_py_time": to_py_time, - "to_py_timedelta": to_py_timedelta, - }, - ) - if n_rows is None: - filter_ = v + filter_ = predicate common_params: dict[str, Any] = {"columns": with_columns, "filter": filter_} batch_size = user_batch_size if user_batch_size is not None else batch_size diff --git a/py-polars/tests/unit/io/test_pyarrow_dataset.py b/py-polars/tests/unit/io/test_pyarrow_dataset.py index 7e1d1ef11f38..0fe633869a8c 100644 --- a/py-polars/tests/unit/io/test_pyarrow_dataset.py +++ b/py-polars/tests/unit/io/test_pyarrow_dataset.py @@ -237,11 +237,13 @@ def test_pyarrow_dataset_partial_predicate_pushdown( capture = capfd.readouterr().err # Verify: partial predicate was pushed to pyarrow - assert "(pa.compute.field('a') > 1)" in capture - assert ( + binop_pred = 'converted pyarrow predicate: Comparison { left: Column("a"), op: Gt, right: Literal(Int(1)) }' + assert 
binop_pred in capture + + resid_pred = ( 'residual predicate: Some([([(col("a").cast(Float64)) * (col("b"))]) > (25.0)])' - in capture ) + assert resid_pred in capture # Verify: correctness expected = ( df.lazy().filter((pl.col("a") > 1) & (pl.col("a") * pl.col("b") > 25)).collect() @@ -266,8 +268,11 @@ def test_pyarrow_dataset_is_in_predicate_pushdown( result = q.collect() capture = capfd.readouterr().err - assert "(pa.compute.field('id')).isin([1,3])" in capture - assert "residual predicate: None" in capture + isin_pred = 'predicate node: col("id").is_in([[1, 3]])' + assert isin_pred in capture + + resid_pred = "residual predicate: None" + assert resid_pred in capture assert_frame_equal(result, expected) assert_frame_equal(df.filter(pred), expected) @@ -282,8 +287,11 @@ def test_pyarrow_dataset_is_in_predicate_pushdown( plmonkeypatch.setenv("POLARS_VERBOSE_SENSITIVE", "0") - assert "(pa.compute.field('id')).isin([1,2,3])" in capture - assert "residual predicate: None" in capture + isin_pred = 'predicate node: col("id").is_in([[1, 2, 3]])' + assert isin_pred in capture + + resid_pred = "residual predicate: None" + assert resid_pred in capture assert_frame_equal(result, expected) assert_frame_equal(df.filter(pred), expected) @@ -306,7 +314,10 @@ def test_pyarrow_dataset_is_in_predicate_pushdown_nulls_equality( result = q.collect() capture = capfd.readouterr().err - assert "(pa.compute.field('id')).isin([1,3])" in capture + isin_pred = 'converted pyarrow predicate: IsIn { expr: Column("id"), values: [Int(1), Int(3)] }' + assert isin_pred in capture + + resid_pred = "residual predicate: None" assert "residual predicate: None" in capture assert_frame_equal(result, expected) @@ -320,7 +331,7 @@ def test_pyarrow_dataset_is_in_predicate_pushdown_nulls_equality( result = q.collect() capture = capfd.readouterr().err - assert "(pa.compute.field('id')).isin([1,None,3])" in capture + assert 'IsIn { expr: Column("id"), values: [Int(1), Null, Int(3)] }' in capture assert 
"residual predicate: None" in capture assert_frame_equal(result, expected) @@ -334,7 +345,7 @@ def test_pyarrow_dataset_is_in_predicate_pushdown_nulls_equality( result = q.collect() capture = capfd.readouterr().err - assert "converted pyarrow predicate: pa.compute.scalar(False)" in capture + assert "converted pyarrow predicate: Literal(Bool(false))" in capture assert "residual predicate: None" in capture assert_frame_equal(q.collect(), expected) @@ -350,7 +361,7 @@ def test_pyarrow_dataset_is_in_predicate_pushdown_nulls_equality( plmonkeypatch.setenv("POLARS_VERBOSE_SENSITIVE", "0") - assert "converted pyarrow predicate: pa.compute.scalar(False)" in capture + assert "converted pyarrow predicate: Literal(Bool(false))" in capture assert "residual predicate: None" in capture assert_frame_equal(q.collect(), expected) @@ -403,7 +414,7 @@ def test_pyarrow_dataset_predicate_verbose_log( assert ( "[SENSITIVE]: python_scan_predicate: " 'predicate node: [(col("a")) < (3)], ' - "converted pyarrow predicate: (pa.compute.field('a') < 3), " + 'converted pyarrow predicate: Comparison { left: Column("a"), op: Lt, right: Literal(Int(3)) }, ' "residual predicate: None" ) in capture @@ -511,6 +522,8 @@ def test_scan_pyarrow_dataset_filter_slice_order() -> None: pl.DataFrame({"index": 1, "year": 2026, "month": 0}), ) + import pyarrow.compute as pc + import polars.io.pyarrow_dataset.anonymous_scan # Test post-filter in engine: this tests the correct result. @@ -520,12 +533,14 @@ def test_scan_pyarrow_dataset_filter_slice_order() -> None: pl.DataFrame({"index": 1, "year": 2026, "month": 0}), ) + year_eq_2026 = pc.field("year") == 2026 + # Test post-filter in engine: this tests that the filter is not applied pyarrow. 
assert_frame_equal( polars.io.pyarrow_dataset.anonymous_scan._scan_pyarrow_dataset_impl( dataset, n_rows=2, - predicate="pa.compute.field('year') == 2026", + predicate=year_eq_2026, with_columns=None, )[0].__next__(), pl.DataFrame({"index": [0, 1], "year": [2025, 2026], "month": [0, 0]}), @@ -536,7 +551,7 @@ def test_scan_pyarrow_dataset_filter_slice_order() -> None: polars.io.pyarrow_dataset.anonymous_scan._scan_pyarrow_dataset_impl( dataset, n_rows=0, - predicate="pa.compute.field('year') == 2026", + predicate=year_eq_2026, with_columns=None, )[0].__next__(), pl.DataFrame(schema={"index": pl.Int64, "year": pl.Int64, "month": pl.Int64}), @@ -559,12 +574,109 @@ def test_scan_pyarrow_dataset_filter_slice_order() -> None: assert not polars.io.pyarrow_dataset.anonymous_scan._scan_pyarrow_dataset_impl( dataset, n_rows=0, - predicate="pa.compute.field('year') == 2026", + predicate=year_eq_2026, with_columns=None, allow_pyarrow_filter=False, )[1] +@pytest.mark.write_disk +def test_arrow_predicate_conversions(tmp_path: Path) -> None: + """Test that various arrow predicates are correctly converted and pushed down.""" + # Create test data with various data types + df = pl.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "value": [10, 20, 30, 40, 50], + "name": ["a", "b", "c", "d", "e"], + "is_active": [True, False, True, False, True], + } + ) + + file_path = tmp_path / "test_predicates.ipc" + df.write_ipc(file_path) + + # Test simple equality comparison + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("id") == 2), + n_expected=1, + check_predicate_pushdown=True, + ) + + # Test greater than comparison + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("value") > 25), + n_expected=3, + check_predicate_pushdown=True, + ) + + # Test less than or equal comparison + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("value") <= 20), + n_expected=2, + check_predicate_pushdown=True, + ) + + # Test boolean column filter + helper_dataset_test( + 
file_path, + lambda lf: lf.filter(pl.col("is_active")), + n_expected=3, + check_predicate_pushdown=True, + ) + + # Test NOT filter + helper_dataset_test( + file_path, + lambda lf: lf.filter(~pl.col("is_active")), + n_expected=2, + check_predicate_pushdown=True, + ) + + # Test AND logic + helper_dataset_test( + file_path, + lambda lf: lf.filter((pl.col("id") > 2) & (pl.col("value") < 45)), + n_expected=2, + check_predicate_pushdown=True, + ) + + # Test OR logic + helper_dataset_test( + file_path, + lambda lf: lf.filter((pl.col("id") == 1) | (pl.col("id") == 5)), + n_expected=2, + check_predicate_pushdown=True, + ) + + # Test is_null + df_with_nulls = pl.DataFrame( + { + "id": [1, 2, None, 4], + "value": [10, None, 30, 40], + } + ) + file_path_nulls = tmp_path / "test_nulls.ipc" + df_with_nulls.write_ipc(file_path_nulls) + + helper_dataset_test( + file_path_nulls, + lambda lf: lf.filter(pl.col("id").is_null()), + n_expected=1, + check_predicate_pushdown=True, + ) + + helper_dataset_test( + file_path_nulls, + lambda lf: lf.filter(pl.col("id").is_not_null()), + n_expected=3, + check_predicate_pushdown=True, + ) + + def test_pyarrow_dataset_streaming_source() -> None: df = pl.DataFrame({"item": ["foo", "bar", "baz"], "price": [10.0, 20.0, 30.0]}) dataset = pl.scan_pyarrow_dataset(