Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions crates/polars-io/src/arrow_predicate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use polars_core::datatypes::{TimeUnit, TimeZone};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/// Tree representation of a filter predicate destined for pyarrow.
///
/// This is meant to mimic the `pyarrow.compute.Expression` API as closely as possible to make
/// it easier to convert directly to pyarrow predicates applied at the python/scan level.
/// Sub-expressions are boxed because the type is recursive.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ArrowPredicate {
/// Reference to a column/field by name.
Column(String),
/// A constant value.
Literal(LiteralValue),
// Binary ops
/// `left <op> right`, where `op` is one of the [`ComparisonOp`] operators.
Comparison {
left: Box<ArrowPredicate>,
op: ComparisonOp,
right: Box<ArrowPredicate>,
},
// Logicals
/// Logical conjunction of the two operands.
And(Box<ArrowPredicate>, Box<ArrowPredicate>),
/// Logical disjunction of the two operands.
Or(Box<ArrowPredicate>, Box<ArrowPredicate>),
/// Logical exclusive-or of the two operands.
Xor(Box<ArrowPredicate>, Box<ArrowPredicate>),
/// Logical negation of the operand.
Not(Box<ArrowPredicate>),
// Methods that turn into masks
/// Null test on the inner expression.
IsNull(Box<ArrowPredicate>),
/// Membership test of `expr` against a fixed list of literal `values`.
IsIn {
expr: Box<ArrowPredicate>,
values: Vec<LiteralValue>,
},
}

/// Comparison operator used by `ArrowPredicate::Comparison`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ComparisonOp {
/// Equal (`==`).
Eq,
/// Not equal (`!=`).
NotEq,
/// Strictly less than (`<`).
Lt,
/// Less than or equal (`<=`).
Lte,
/// Strictly greater than (`>`).
Gt,
/// Greater than or equal (`>=`).
Gte,
}

/// Constant value that can appear inside an `ArrowPredicate`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum LiteralValue {
/// The null / missing value.
Null,
/// 64-bit signed integer.
Int(i64),
/// 64-bit float.
Float(f64),
/// UTF-8 string.
String(String),
/// Boolean.
Bool(bool),
/// Calendar date, expressed as a number of days relative to 1970-01-01.
Date(i32),
// Should mimic python datetime module semantics
/// Point in time: `value` counts `time_unit` ticks relative to 1970-01-01T00:00:00,
/// optionally tagged with a `time_zone`.
// NOTE(review): presumably `value` is UTC-based when `time_zone` is set (as for the polars
// Datetime dtype) — confirm against the code that produces these literals.
Datetime {
value: i64,
time_unit: TimeUnit,
time_zone: Option<TimeZone>,
},
}
1 change: 1 addition & 0 deletions crates/polars-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#![allow(ambiguous_glob_reexports)]
extern crate core;

pub mod arrow_predicate;
#[cfg(feature = "avro")]
pub mod avro;
#[cfg(feature = "catalog")]
Expand Down
121 changes: 121 additions & 0 deletions crates/polars-mem-engine/src/arrow_predicate_pyo3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use polars_core::datatypes::TimeUnit;
use polars_io::arrow_predicate::{ArrowPredicate, ComparisonOp, LiteralValue};
use pyo3::IntoPyObjectExt;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};

/// Lowers an [`ArrowPredicate`] tree into a Python object built through `pyarrow.compute`.
///
/// # Errors
///
/// Propagates the Python exception raised when `pyarrow.compute` cannot be imported or when
/// any step of the expression construction fails.
pub fn arrow_predicate_to_pyobject<'py>(
    py: Python<'py>,
    pred: &ArrowPredicate,
) -> PyResult<Bound<'py, PyAny>> {
    py.import("pyarrow.compute")
        .and_then(|compute| build_expr(py, &compute, pred))
}

/// Lowers `lhs` and then `rhs` with [`build_expr`] and calls the Python method `dunder`
/// of the lowered left operand with the lowered right operand as its only argument.
fn apply_binary<'py>(
    py: Python<'py>,
    pc: &Bound<'py, PyAny>,
    lhs: &ArrowPredicate,
    rhs: &ArrowPredicate,
    dunder: &str,
) -> PyResult<Bound<'py, PyAny>> {
    let lhs = build_expr(py, pc, lhs)?;
    let rhs = build_expr(py, pc, rhs)?;
    lhs.call_method1(dunder, (rhs,))
}

/// Recursively converts an [`ArrowPredicate`] into a pyarrow expression object.
///
/// `pc` is the Python object on which `field` and `scalar` are invoked; the caller passes
/// the imported `pyarrow.compute` module.
///
/// # Errors
///
/// Returns the first Python exception raised while building any sub-expression.
fn build_expr<'py>(
    py: Python<'py>,
    pc: &Bound<'py, PyAny>,
    p: &ArrowPredicate,
) -> PyResult<Bound<'py, PyAny>> {
    match p {
        ArrowPredicate::Column(name) => pc.call_method1("field", (name.as_str(),)),
        ArrowPredicate::Literal(lv) => pc.call_method1("scalar", (literal_to_py(py, lv)?,)),
        ArrowPredicate::Comparison { left, op, right } => {
            // Relies on the operator overloads of pyarrow expressions: this stops working
            // if pyarrow ever changes them.
            let dunder = match op {
                ComparisonOp::Eq => "__eq__",
                ComparisonOp::NotEq => "__ne__",
                ComparisonOp::Lt => "__lt__",
                ComparisonOp::Lte => "__le__",
                ComparisonOp::Gt => "__gt__",
                ComparisonOp::Gte => "__ge__",
            };
            apply_binary(py, pc, left, right, dunder)
        },
        ArrowPredicate::And(l, r) => apply_binary(py, pc, l, r, "__and__"),
        ArrowPredicate::Or(l, r) => apply_binary(py, pc, l, r, "__or__"),
        ArrowPredicate::Xor(l, r) => apply_binary(py, pc, l, r, "__xor__"),
        ArrowPredicate::Not(inner) => build_expr(py, pc, inner)?.call_method0("__invert__"),
        ArrowPredicate::IsNull(inner) => build_expr(py, pc, inner)?.call_method0("is_null"),
        // Membership in an empty set is a constant `false`; the tested expression is not
        // lowered at all in that case.
        ArrowPredicate::IsIn { values, .. } if values.is_empty() => {
            pc.call_method1("scalar", (false,))
        },
        ArrowPredicate::IsIn { expr, values } => {
            let target = build_expr(py, pc, expr)?;
            let items = values
                .iter()
                .map(|literal| literal_to_py(py, literal))
                .collect::<PyResult<Vec<_>>>()?;
            target.call_method1("isin", (PyList::new(py, items)?,))
        },
    }
}

fn literal_to_py<'py>(py: Python<'py>, lv: &LiteralValue) -> PyResult<Bound<'py, PyAny>> {
match lv {
LiteralValue::Null => Ok(py.None().into_bound(py)),
LiteralValue::Int(v) => Ok(v.into_pyobject(py)?.into_any()),
LiteralValue::Float(v) => Ok(v.into_pyobject(py)?.into_any()),
LiteralValue::String(s) => Ok(s.into_pyobject(py)?.into_any()),
LiteralValue::Bool(v) => v.into_bound_py_any(py),
LiteralValue::Date(days) => {
let dt_mod = py.import("datetime")?;
let epoch = dt_mod.getattr("date")?.call1((1970i32, 1i32, 1i32))?;
let delta = dt_mod.getattr("timedelta")?.call1((*days as i64,))?;
epoch.call_method1("__add__", (delta,))
},
LiteralValue::Datetime {
value,
time_unit,
time_zone,
} => {
// This conversion does not feel idiomatic but it's probably the best way to do this
// within the confines of the python standard lib.
let dt_mod = py.import("datetime")?;
let epoch = dt_mod.getattr("datetime")?.call1((1970i32, 1i32, 1i32))?;
let micros: i64 = match time_unit {
TimeUnit::Nanoseconds => value / 1000,
TimeUnit::Microseconds => *value,
TimeUnit::Milliseconds => value * 1000,
};
let kwargs = PyDict::new(py);
kwargs.set_item("microseconds", micros)?;
let delta = dt_mod.getattr("timedelta")?.call((), Some(&kwargs))?;
let naive = epoch.call_method1("__add__", (delta,))?;
if let Some(tz) = time_zone {
let zi = py.import("zoneinfo")?;
let tz_obj = zi.getattr("ZoneInfo")?.call1((tz.to_string(),))?;
let kw = PyDict::new(py);
kw.set_item("tzinfo", tz_obj)?;
naive.call_method("replace", (), Some(&kw))
} else {
Ok(naive)
}
},
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ impl Executor for PythonScanExec {
let mut could_serialize_predicate = true;

let predicate = match &self.options.predicate {
PythonPredicate::PyArrow(s) => s.into_bound_py_any(py).unwrap(),
PythonPredicate::PyArrow(pred) => {
crate::arrow_predicate_pyo3::arrow_predicate_to_pyobject(py, pred).map_err(
|e| polars_err!(ComputeError: "failed to build pyarrow predicate: {}", e),
)?
},
PythonPredicate::None => None::<()>.into_bound_py_any(py).unwrap(),
PythonPredicate::Polars(_) => {
assert!(self.predicate.is_some(), "should be set");
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-mem-engine/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
feature = "allow_unused",
allow(unused, dead_code, irrefutable_let_patterns)
)] // Maybe be caused by some feature
#[cfg(feature = "python")]
pub mod arrow_predicate_pyo3;
mod executors;
mod planner;
mod prelude;
Expand Down
27 changes: 16 additions & 11 deletions crates/polars-mem-engine/src/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,19 +196,18 @@ pub fn python_scan_predicate(
// Convert to an `ArrowPredicate`, which is later lowered to a pyarrow expression.
if matches!(options.python_source, PythonScanSource::Pyarrow) {
use polars_core::config::verbose_print_sensitive;
use polars_io::arrow_predicate::ArrowPredicate;
use polars_plan::plans::MintermIter;

// If there is a `head`, that comes before the filter and we post-apply
// the predicate in the engine. No need to transpile to `pyarrow.predicate`
// the predicate in the engine.
let residual_predicate_expr_ir = if options.n_rows.is_none() {
// Split into AND-minterms and convert each independently.
let mut residual_predicate_nodes: Vec<Node> = vec![];
let parts: Vec<String> = MintermIter::new(e.node(), expr_arena)
let parts: Vec<ArrowPredicate> = MintermIter::new(e.node(), expr_arena)
.filter_map(|node| {
let result = polars_plan::plans::python::pyarrow::predicate_to_pa(
node,
expr_arena,
Default::default(),
let result = polars_plan::plans::python::pyarrow::predicate_to_arrow_pred(
node, expr_arena,
);
if result.is_none() {
residual_predicate_nodes.push(node);
Expand All @@ -220,11 +219,17 @@ pub fn python_scan_predicate(
let predicate_pa = match parts.len() {
0 => None,
1 => Some(parts.into_iter().next().unwrap()),
_ => Some(format!("({})", parts.join(" & "))),
_ => {
let combined = parts
.into_iter()
.reduce(|acc, p| ArrowPredicate::And(Box::new(acc), Box::new(p)))
.unwrap();
Some(combined)
},
};

if let Some(eval_str) = predicate_pa {
options.predicate = PythonPredicate::PyArrow(eval_str);
if let Some(pred) = predicate_pa {
options.predicate = PythonPredicate::PyArrow(pred);

residual_predicate_nodes
.into_iter()
Expand All @@ -247,8 +252,8 @@ pub fn python_scan_predicate(

verbose_print_sensitive(|| {
let predicate_pa_verbose_msg = match &options.predicate {
PythonPredicate::PyArrow(p) => p,
_ => "<conversion failed>",
PythonPredicate::PyArrow(p) => format!("{:?}", p),
_ => "<conversion failed>".to_string(),
};

format!(
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/ir/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl<'a> IRDotDisplay<'a> {
PythonScan { options } => {
let predicate = match &options.predicate {
PythonPredicate::Polars(e) => format!("{}", self.display_expr(e)),
PythonPredicate::PyArrow(s) => s.clone(),
PythonPredicate::PyArrow(p) => format!("{:?}", p),
PythonPredicate::None => "none".to_string(),
};
let with_columns = NumColumns(options.with_columns.as_ref().map(|s| s.as_ref()));
Expand Down
Loading
Loading