Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/polars-plan/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ pub mod dsl;
pub mod frame;
pub mod plans;
pub mod prelude;
pub mod traversal;
pub mod utils;
92 changes: 73 additions & 19 deletions crates/polars-plan/src/plans/aexpr/projection_height.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use std::marker::PhantomData;
use std::ops::ControlFlow;

use polars_utils::arena::{Arena, Node};
use polars_utils::collection::{Collection, CollectionWrap};
use polars_utils::scratch_vec::ScratchVec;

use crate::dsl::WindowMapping;
use crate::plans::{AExpr, aexpr_postvisit_traversal};
use crate::plans::{AExpr, aexpr_tree_traversal};
use crate::traversal::visitor::{NodeVisitor, SubtreeVisit};

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ExprProjectionHeight {
Expand Down Expand Up @@ -39,27 +44,75 @@ pub fn aexpr_projection_height_rec(
ae_node: Node,
mut expr_arena: &Arena<AExpr>,
stack: &mut ScratchVec<Node>,
inputs_stack: &mut ScratchVec<ExprProjectionHeight>,
edges_stack: &mut ScratchVec<ExprProjectionHeight>,
) -> ExprProjectionHeight {
aexpr_postvisit_traversal(
let mut visitor = ExprHeightVisitor::default();

aexpr_tree_traversal(
ae_node,
&mut expr_arena,
stack.get(),
inputs_stack.get(),
&mut |ae_node, input_heights, expr_arena| {
aexpr_projection_height(expr_arena.get(ae_node), input_heights)
},
edges_stack.get(),
&mut visitor,
)
.continue_value()
.unwrap()
}

#[derive(Default)]
pub struct ExprHeightVisitor<'a>(PhantomData<&'a ()>);

impl<'a> NodeVisitor for ExprHeightVisitor<'a> {
type Key = Node;
type Edge = ExprProjectionHeight;
type Storage = &'a Arena<AExpr>;
type BreakValue = ();

fn default_edge(&mut self) -> Self::Edge {
ExprProjectionHeight::Unknown
}

fn pre_visit(
&mut self,
key: Self::Key,
storage: &mut Self::Storage,
edges: &mut dyn crate::traversal::edge_provider::NodeEdgesProvider<Self::Edge>,
) -> ControlFlow<Self::BreakValue, SubtreeVisit> {
ControlFlow::Continue(
if let Some(height) = aexpr_projection_height(storage.get(key), None) {
edges.outputs()[0] = height;
SubtreeVisit::Skip
} else {
SubtreeVisit::Visit
},
)
}

fn post_visit(
&mut self,
key: Self::Key,
storage: &mut Self::Storage,
edges: &mut dyn crate::traversal::edge_provider::NodeEdgesProvider<Self::Edge>,
) -> ControlFlow<Self::BreakValue> {
edges.outputs()[0] =
aexpr_projection_height(storage.get(key), Some(&mut *edges.inputs())).unwrap();
ControlFlow::Continue(())
}
}

/// # Returns
/// Returns `None` if the output height is dependent on input heights and input heights were not
/// provided.
pub fn aexpr_projection_height(
aexpr: &AExpr,
input_heights: &[ExprProjectionHeight],
) -> ExprProjectionHeight {
input_heights: Option<&mut dyn Collection<ExprProjectionHeight>>,
) -> Option<ExprProjectionHeight> {
use AExpr::*;
use ExprProjectionHeight as H;

match aexpr {
let input_heights = input_heights.map(CollectionWrap::<ExprProjectionHeight, _>::new);

Some(match aexpr {
Column(_) => H::Column,

Element => H::Column,
Expand All @@ -73,37 +126,37 @@ pub fn aexpr_projection_height(
}
},

Eval { .. } => input_heights[0],
Eval { .. } => input_heights?[0],
#[cfg(feature = "dtype-struct")]
StructEval { .. } => input_heights[0],
StructEval { .. } => input_heights?[0],

Filter { .. } | Slice { .. } | Explode { .. } => H::Unknown,

Agg(_) | AnonymousAgg { .. } => H::Scalar,
Len => H::Scalar,

BinaryExpr { .. } => {
let [l, r] = input_heights.try_into().unwrap();
let [l, r] = input_heights?.try_into().unwrap();
l.zip_with(r)
},
Ternary { .. } => {
let [pred, truthy, falsy] = input_heights.try_into().unwrap();
let [pred, truthy, falsy] = input_heights?.try_into().unwrap();
pred.zip_with(truthy).zip_with(falsy)
},

Cast { .. } | Sort { .. } => {
let [h] = input_heights.try_into().unwrap();
let [h] = input_heights?.try_into().unwrap();
h
},

SortBy { .. } => H::zipped_projection_height(input_heights.iter().copied()),
SortBy { .. } => H::zipped_projection_height(input_heights?.iter().copied()),

Gather { returns_scalar, .. } => {
if *returns_scalar {
// This is `get()` from the API
H::Scalar
} else {
let indices_height = input_heights[1];
let indices_height = input_heights?[1];

match indices_height {
H::Column => H::Column,
Expand All @@ -116,20 +169,21 @@ pub fn aexpr_projection_height(
if options.flags.returns_scalar() {
H::Scalar
} else if options.flags.is_elementwise() || options.flags.is_length_preserving() {
H::zipped_projection_height(input_heights.iter().copied())
H::zipped_projection_height(input_heights?.iter().copied())
} else {
H::Unknown
}
},

#[cfg(feature = "dynamic_group_by")]
Rolling { .. } => H::Column,

Over { mapping, .. } => {
if matches!(mapping, WindowMapping::Explode) {
H::Unknown
} else {
H::Column
}
},
}
})
}
58 changes: 27 additions & 31 deletions crates/polars-plan/src/plans/aexpr/traverse.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::ops::ControlFlow;

use super::*;
use crate::traversal::tree_traversal::{GetNodeInputs, tree_traversal};
use crate::traversal::visitor::NodeVisitor;

impl AExpr {
/// Push the inputs of this node to the given container, in field declaration order.
Expand Down Expand Up @@ -603,41 +607,33 @@ impl NodeInputs {
}
}

#[recursive::recursive]
pub fn aexpr_postvisit_traversal<F, State, ArenaT>(
ae_node: Node,
pub fn aexpr_tree_traversal<ArenaT, Edge, BreakValue>(
root_ae_node: Node,
expr_arena: &mut ArenaT,
stack: &mut Vec<Node>,
inputs_stack: &mut Vec<State>,
visit: &mut F,
) -> State
visit_stack: &mut Vec<Node>,
edges: &mut Vec<Edge>,
visitor: &mut dyn NodeVisitor<Key = Node, Storage = ArenaT, Edge = Edge, BreakValue = BreakValue>,
) -> ControlFlow<BreakValue, Edge>
where
F: FnMut(Node, &mut [State], &mut ArenaT) -> State,
State: Default,
ArenaT: AsRef<Arena<AExpr>>,
ArenaT: GetNodeInputs<Node>,
{
let ae = expr_arena.as_ref().get(ae_node);

let base_stack_len = stack.len();
let base_inputs_stack_len = inputs_stack.len();
ae.inputs(stack);
let num_inputs = stack.len() - base_stack_len;
tree_traversal(root_ae_node, expr_arena, visit_stack, edges, visitor)
}

for i in base_stack_len..stack.len() {
let h = aexpr_postvisit_traversal(stack[i], expr_arena, stack, inputs_stack, visit);
inputs_stack.push(h);
impl GetNodeInputs<Node> for Arena<AExpr> {
fn push_inputs_for_key<C>(&self, key: Node, container: &mut C)
where
C: Extend<Node>,
{
self.get(key).inputs(container);
}
}

assert_eq!(stack.len(), base_stack_len + num_inputs);
stack.truncate(base_stack_len);

assert_eq!(inputs_stack.len(), base_inputs_stack_len + num_inputs);
let state = visit(
ae_node,
&mut inputs_stack[base_inputs_stack_len..],
expr_arena,
);
inputs_stack.truncate(base_inputs_stack_len);

state
impl GetNodeInputs<Node> for &Arena<AExpr> {
fn push_inputs_for_key<C>(&self, key: Node, container: &mut C)
where
C: Extend<Node>,
{
self.get(key).inputs(container);
}
}
11 changes: 11 additions & 0 deletions crates/polars-plan/src/traversal/edge_provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use polars_utils::collection::{Collection, CollectionWrap};

pub trait NodeEdgesProvider<Edge> {
fn inputs<'a>(&'a mut self) -> CollectionWrap<Edge, &'a mut dyn Collection<Edge>>
where
Edge: 'a;

fn outputs<'a>(&'a mut self) -> CollectionWrap<Edge, &'a mut dyn Collection<Edge>>
where
Edge: 'a;
}
3 changes: 3 additions & 0 deletions crates/polars-plan/src/traversal/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod edge_provider;
pub mod tree_traversal;
pub mod visitor;
Loading
Loading