diff --git a/crates/polars-plan/src/lib.rs b/crates/polars-plan/src/lib.rs index 1b962a643314..de23da47e699 100644 --- a/crates/polars-plan/src/lib.rs +++ b/crates/polars-plan/src/lib.rs @@ -16,4 +16,5 @@ pub mod dsl; pub mod frame; pub mod plans; pub mod prelude; +pub mod traversal; pub mod utils; diff --git a/crates/polars-plan/src/plans/aexpr/projection_height.rs b/crates/polars-plan/src/plans/aexpr/projection_height.rs index b2d23cb83350..aaf8a73bdae6 100644 --- a/crates/polars-plan/src/plans/aexpr/projection_height.rs +++ b/crates/polars-plan/src/plans/aexpr/projection_height.rs @@ -1,8 +1,13 @@ +use std::marker::PhantomData; +use std::ops::ControlFlow; + use polars_utils::arena::{Arena, Node}; +use polars_utils::collection::{Collection, CollectionWrap}; use polars_utils::scratch_vec::ScratchVec; use crate::dsl::WindowMapping; -use crate::plans::{AExpr, aexpr_postvisit_traversal}; +use crate::plans::{AExpr, aexpr_tree_traversal}; +use crate::traversal::visitor::{NodeVisitor, SubtreeVisit}; #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] pub enum ExprProjectionHeight { @@ -39,27 +44,75 @@ pub fn aexpr_projection_height_rec( ae_node: Node, mut expr_arena: &Arena, stack: &mut ScratchVec, - inputs_stack: &mut ScratchVec, + edges_stack: &mut ScratchVec, ) -> ExprProjectionHeight { - aexpr_postvisit_traversal( + let mut visitor = ExprHeightVisitor::default(); + + aexpr_tree_traversal( ae_node, &mut expr_arena, stack.get(), - inputs_stack.get(), - &mut |ae_node, input_heights, expr_arena| { - aexpr_projection_height(expr_arena.get(ae_node), input_heights) - }, + edges_stack.get(), + &mut visitor, ) + .continue_value() + .unwrap() } +#[derive(Default)] +pub struct ExprHeightVisitor<'a>(PhantomData<&'a ()>); + +impl<'a> NodeVisitor for ExprHeightVisitor<'a> { + type Key = Node; + type Edge = ExprProjectionHeight; + type Storage = &'a Arena; + type BreakValue = (); + + fn default_edge(&mut self) -> Self::Edge { + ExprProjectionHeight::Unknown + } + + fn pre_visit( + &mut self, + key: Self::Key, + storage: &mut Self::Storage, + edges: &mut dyn crate::traversal::edge_provider::NodeEdgesProvider, + ) -> ControlFlow { + ControlFlow::Continue( + if let Some(height) = aexpr_projection_height(storage.get(key), None) { + edges.outputs()[0] = height; + SubtreeVisit::Skip + } else { + SubtreeVisit::Visit + }, + ) + } + + fn post_visit( + &mut self, + key: Self::Key, + storage: &mut Self::Storage, + edges: &mut dyn crate::traversal::edge_provider::NodeEdgesProvider, + ) -> ControlFlow { + edges.outputs()[0] = + aexpr_projection_height(storage.get(key), Some(&mut *edges.inputs())).unwrap(); + ControlFlow::Continue(()) + } +} + +/// # Returns +/// Returns `None` if the output height is dependent on input heights and input heights were not +/// provided. pub fn aexpr_projection_height( aexpr: &AExpr, - input_heights: &[ExprProjectionHeight], -) -> ExprProjectionHeight { + input_heights: Option<&mut dyn Collection>, +) -> Option { use AExpr::*; use ExprProjectionHeight as H; - match aexpr { + let input_heights = input_heights.map(CollectionWrap::::new); + + Some(match aexpr { Column(_) => H::Column, Element => H::Column, @@ -73,9 +126,9 @@ pub fn aexpr_projection_height( } }, - Eval { .. } => input_heights[0], + Eval { .. } => input_heights?[0], #[cfg(feature = "dtype-struct")] - StructEval { .. } => input_heights[0], + StructEval { .. } => input_heights?[0], Filter { .. } | Slice { .. } | Explode { .. } => H::Unknown, @@ -83,27 +136,27 @@ pub fn aexpr_projection_height( Len => H::Scalar, BinaryExpr { .. } => { - let [l, r] = input_heights.try_into().unwrap(); + let [l, r] = input_heights?.try_into().unwrap(); l.zip_with(r) }, Ternary { .. } => { - let [pred, truthy, falsy] = input_heights.try_into().unwrap(); + let [pred, truthy, falsy] = input_heights?.try_into().unwrap(); pred.zip_with(truthy).zip_with(falsy) }, Cast { .. } | Sort { .. } => { - let [h] = input_heights.try_into().unwrap(); + let [h] = input_heights?.try_into().unwrap(); h }, - SortBy { .. } => H::zipped_projection_height(input_heights.iter().copied()), + SortBy { .. } => H::zipped_projection_height(input_heights?.iter().copied()), Gather { returns_scalar, .. } => { if *returns_scalar { // This is `get()` from the API H::Scalar } else { - let indices_height = input_heights[1]; + let indices_height = input_heights?[1]; match indices_height { H::Column => H::Column, @@ -116,7 +169,7 @@ pub fn aexpr_projection_height( if options.flags.returns_scalar() { H::Scalar } else if options.flags.is_elementwise() || options.flags.is_length_preserving() { - H::zipped_projection_height(input_heights.iter().copied()) + H::zipped_projection_height(input_heights?.iter().copied()) } else { H::Unknown } @@ -124,6 +177,7 @@ pub fn aexpr_projection_height( #[cfg(feature = "dynamic_group_by")] Rolling { .. } => H::Column, + Over { mapping, .. } => { if matches!(mapping, WindowMapping::Explode) { H::Unknown @@ -131,5 +185,5 @@ pub fn aexpr_projection_height( H::Column } }, - } + }) } diff --git a/crates/polars-plan/src/plans/aexpr/traverse.rs b/crates/polars-plan/src/plans/aexpr/traverse.rs index 25c38f474650..676043f59722 100644 --- a/crates/polars-plan/src/plans/aexpr/traverse.rs +++ b/crates/polars-plan/src/plans/aexpr/traverse.rs @@ -1,4 +1,8 @@ +use std::ops::ControlFlow; + use super::*; +use crate::traversal::tree_traversal::{GetNodeInputs, tree_traversal}; +use crate::traversal::visitor::NodeVisitor; impl AExpr { /// Push the inputs of this node to the given container, in field declaration order. @@ -603,41 +607,33 @@ impl NodeInputs { } } -#[recursive::recursive] -pub fn aexpr_postvisit_traversal( - ae_node: Node, +pub fn aexpr_tree_traversal( + root_ae_node: Node, expr_arena: &mut ArenaT, - stack: &mut Vec, - inputs_stack: &mut Vec, - visit: &mut F, -) -> State + visit_stack: &mut Vec, + edges: &mut Vec, + visitor: &mut dyn NodeVisitor, +) -> ControlFlow where - F: FnMut(Node, &mut [State], &mut ArenaT) -> State, - State: Default, - ArenaT: AsRef>, + ArenaT: GetNodeInputs, { - let ae = expr_arena.as_ref().get(ae_node); - - let base_stack_len = stack.len(); - let base_inputs_stack_len = inputs_stack.len(); - ae.inputs(stack); - let num_inputs = stack.len() - base_stack_len; + tree_traversal(root_ae_node, expr_arena, visit_stack, edges, visitor) +} - for i in base_stack_len..stack.len() { - let h = aexpr_postvisit_traversal(stack[i], expr_arena, stack, inputs_stack, visit); - inputs_stack.push(h); +impl GetNodeInputs for Arena { + fn push_inputs_for_key(&self, key: Node, container: &mut C) + where + C: Extend, + { + self.get(key).inputs(container); } +} - assert_eq!(stack.len(), base_stack_len + num_inputs); - stack.truncate(base_stack_len); - - assert_eq!(inputs_stack.len(), base_inputs_stack_len + num_inputs); - let state = visit( - ae_node, - &mut inputs_stack[base_inputs_stack_len..], - expr_arena, - ); - inputs_stack.truncate(base_inputs_stack_len); - - state +impl GetNodeInputs for &Arena { + fn push_inputs_for_key(&self, key: Node, container: &mut C) + where + C: Extend, + { + self.get(key).inputs(container); + } } diff --git a/crates/polars-plan/src/traversal/edge_provider.rs b/crates/polars-plan/src/traversal/edge_provider.rs new file mode 100644 index 000000000000..811fa92defa6 --- /dev/null +++ b/crates/polars-plan/src/traversal/edge_provider.rs @@ -0,0 +1,11 @@ +use polars_utils::collection::{Collection, CollectionWrap}; + +pub trait NodeEdgesProvider { + fn inputs<'a>(&'a mut self) -> CollectionWrap> + where + Edge: 'a; + + fn outputs<'a>(&'a mut self) -> CollectionWrap> + where + Edge: 'a; +} diff --git a/crates/polars-plan/src/traversal/mod.rs b/crates/polars-plan/src/traversal/mod.rs new file mode 100644 index 000000000000..b910df134b15 --- /dev/null +++ b/crates/polars-plan/src/traversal/mod.rs @@ -0,0 +1,3 @@ +pub mod edge_provider; +pub mod tree_traversal; +pub mod visitor; diff --git a/crates/polars-plan/src/traversal/tree_traversal.rs b/crates/polars-plan/src/traversal/tree_traversal.rs new file mode 100644 index 000000000000..5437abc4c168 --- /dev/null +++ b/crates/polars-plan/src/traversal/tree_traversal.rs @@ -0,0 +1,187 @@ +use std::marker::PhantomData; +use std::ops::{ControlFlow, Range}; + +use polars_utils::collection::{Collection, CollectionWrap}; + +use crate::traversal::edge_provider::NodeEdgesProvider; +use crate::traversal::visitor::{NodeVisitor, SubtreeVisit}; + +pub trait GetNodeInputs { + fn push_inputs_for_key(&self, key: Key, container: &mut C) + where + C: Extend; + + fn num_inputs(&self, key: Key) -> usize { + struct Counter(usize, PhantomData); + + impl Extend for Counter { + fn extend>(&mut self, iter: I) { + iter.into_iter().for_each(|_| self.0 += 1); + } + } + + let mut c = Counter::(0, PhantomData); + self.push_inputs_for_key(key, &mut c); + + c.0 + } +} + +pub fn tree_traversal( + root_key: Key, + storage: &mut Storage, + visit_stack: &mut Vec, + edges: &mut Vec, + visitor: &mut dyn NodeVisitor, +) -> ControlFlow +where + Key: Clone, + Storage: GetNodeInputs, +{ + let root_edge_idx = edges.len(); + edges.push(visitor.default_edge()); + + tree_traversal_impl::( + root_key, + root_edge_idx, + storage, + visit_stack, + edges, + visitor, + )?; + + assert_eq!(edges.len(), root_edge_idx + 1); + ControlFlow::Continue(edges.pop().unwrap()) +} + +#[recursive::recursive] +pub fn tree_traversal_impl( + current_key: Key, + current_key_out_edge_idx: usize, + storage: &mut Storage, + visit_stack: &mut Vec, + edges: &mut Vec, + visitor: &mut dyn NodeVisitor, +) -> ControlFlow +where + Key: Clone, + Storage: GetNodeInputs, +{ + let base_visit_stack_len = visit_stack.len(); + let base_edges_len = edges.len(); + + storage.push_inputs_for_key(current_key.clone(), visit_stack); + + let num_inputs = visit_stack.len() - base_visit_stack_len; + + edges.extend((0..num_inputs).map(|_| visitor.default_edge())); + + match visitor.pre_visit( + current_key.clone(), + storage, + &mut SliceEdgeProvider { + edges, + input_range: base_edges_len..base_edges_len + num_inputs, + output_idx: current_key_out_edge_idx, + }, + )? { + SubtreeVisit::Visit => { + for i in 0..num_inputs { + tree_traversal_impl( + visit_stack[base_visit_stack_len + i].clone(), + base_edges_len + i, + storage, + visit_stack, + edges, + visitor, + )?; + } + }, + SubtreeVisit::Skip => {}, + } + + assert_eq!(visit_stack.len(), base_visit_stack_len + num_inputs); + visit_stack.truncate(base_visit_stack_len); + + assert_eq!(edges.len(), base_edges_len + num_inputs); + + let control_flow = visitor.post_visit( + current_key, + storage, + &mut SliceEdgeProvider { + edges, + input_range: base_edges_len..base_edges_len + num_inputs, + output_idx: current_key_out_edge_idx, + }, + ); + + edges.truncate(base_edges_len); + + control_flow +} + +struct SliceEdgeProvider<'a, Edge> { + edges: &'a mut [Edge], + input_range: Range, + output_idx: usize, +} + +impl<'provider, Edge> NodeEdgesProvider for SliceEdgeProvider<'provider, Edge> { + fn inputs<'a>(&'a mut self) -> CollectionWrap> + where + Edge: 'a, + { + CollectionWrap::new(unsafe { + std::mem::transmute::< + &'a mut SliceEdgeProvider<'provider, Edge>, + &'a mut Inputs>, + >(self) + }) + } + + fn outputs<'a>(&'a mut self) -> CollectionWrap> + where + Edge: 'a, + { + CollectionWrap::new(unsafe { + std::mem::transmute::< + &'a mut SliceEdgeProvider<'provider, Edge>, + &'a mut Outputs>, + >(self) + }) + } +} + +#[repr(transparent)] +struct Inputs(T); + +impl<'a, Edge> Collection for Inputs> { + fn len(&self) -> usize { + self.0.input_range.len() + } + + fn get(&self, idx: usize) -> Option<&Edge> { + (idx < self.0.input_range.len()).then(|| &self.0.edges[self.0.input_range.start + idx]) + } + + fn get_mut(&mut self, idx: usize) -> Option<&mut Edge> { + (idx < self.0.input_range.len()).then(|| &mut self.0.edges[self.0.input_range.start + idx]) + } +} + +#[repr(transparent)] +struct Outputs(T); + +impl<'a, Edge> Collection for Outputs> { + fn len(&self) -> usize { + 1 + } + + fn get(&self, idx: usize) -> Option<&Edge> { + (idx == 0).then(|| &self.0.edges[self.0.output_idx]) + } + + fn get_mut(&mut self, idx: usize) -> Option<&mut Edge> { + (idx == 0).then(|| &mut self.0.edges[self.0.output_idx]) + } +} diff --git a/crates/polars-plan/src/traversal/visitor.rs b/crates/polars-plan/src/traversal/visitor.rs new file mode 100644 index 000000000000..1daabda68a6d --- /dev/null +++ b/crates/polars-plan/src/traversal/visitor.rs @@ -0,0 +1,31 @@ +use std::ops::ControlFlow; + +use crate::traversal::edge_provider::NodeEdgesProvider; + +pub enum SubtreeVisit { + Visit, + Skip, +} + +pub trait NodeVisitor { + type Key; + type Storage; + type Edge; + type BreakValue; + + fn default_edge(&mut self) -> Self::Edge; + + fn pre_visit( + &mut self, + key: Self::Key, + storage: &mut Self::Storage, + edges: &mut dyn NodeEdgesProvider, + ) -> ControlFlow; + + fn post_visit( + &mut self, + key: Self::Key, + storage: &mut Self::Storage, + edges: &mut dyn NodeEdgesProvider, + ) -> ControlFlow; +} diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index 34b1e9bc9410..8ebb0b8bd059 100644 --- a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -1,3 +1,4 @@ +use std::ops::ControlFlow; use std::sync::Arc; use polars_core::chunked_array::cast::CastOptions; @@ -14,10 +15,13 @@ use polars_ops::frame::{JoinArgs, JoinType}; use polars_ops::series::{RLE_LENGTH_COLUMN_NAME, RLE_VALUE_COLUMN_NAME}; use polars_plan::plans::AExpr; use polars_plan::plans::expr_ir::{ExprIR, OutputName}; +use polars_plan::plans::projection_height::{ExprHeightVisitor, ExprProjectionHeight}; use polars_plan::prelude::*; +use polars_plan::traversal::visitor::{NodeVisitor, SubtreeVisit}; use polars_utils::arena::{Arena, Node}; use polars_utils::itertools::Itertools; use polars_utils::pl_str::PlSmallStr; +use polars_utils::scratch_vec::ScratchVec; use polars_utils::{unique_column_name, unitvec}; use slotmap::SlotMap; @@ -53,6 +57,8 @@ struct LowerExprContext<'a> { expr_arena: &'a mut Arena, phys_sm: &'a mut SlotMap, cache: &'a mut ExprCache, + node_scratch: &'a mut ScratchVec, + ae_height_scratch: &'a mut ScratchVec, } impl<'a> From> for StreamingLowerIRContext<'a> { @@ -303,124 +309,75 @@ fn build_input_independent_node_with_ctx( ))) } -#[recursive::recursive] -pub fn is_length_preserving_rec( - expr_key: ExprNodeKey, - arena: &Arena, - cache: &mut PlHashMap, -) -> bool { - if let Some(ret) = cache.get(&expr_key) { - return *ret; +fn is_length_preserving_ctx(expr_key: ExprNodeKey, ctx: &mut LowerExprContext) -> bool { + struct CachedLengthPreservingVisitor<'a> { + height_resolver: &'a mut ExprHeightVisitor<'a>, + cache: &'a mut PlHashMap, } - let ret = match arena.get(expr_key) { - // Handled separately in `Eval`. - AExpr::Element => unreachable!(), - // Mapped to `Column` in `StructEval`. - AExpr::StructField(_) => unreachable!(), + impl<'a> NodeVisitor for CachedLengthPreservingVisitor<'a> { + type Edge = ExprProjectionHeight; + type Key = Node; + type Storage = &'a Arena; + type BreakValue = (); - AExpr::Gather { .. } - | AExpr::Explode { .. } - | AExpr::Filter { .. } - | AExpr::Agg(_) - | AExpr::Slice { .. } - | AExpr::Len - | AExpr::Literal(_) => false, + fn default_edge(&mut self) -> Self::Edge { + self.height_resolver.default_edge() + } - AExpr::Column(_) => true, + fn pre_visit( + &mut self, + key: Self::Key, + storage: &mut Self::Storage, + edges: &mut dyn polars_plan::traversal::edge_provider::NodeEdgesProvider, + ) -> ControlFlow { + use ControlFlow as CF; + + if let Some(length_preserving) = self.cache.get(&key) { + edges.outputs()[0] = if *length_preserving { + ExprProjectionHeight::Column + } else { + ExprProjectionHeight::Unknown + }; - AExpr::Cast { - expr: inner, - dtype: _, - options: _, - } - | AExpr::Sort { - expr: inner, - options: _, + CF::Continue(SubtreeVisit::Skip) + } else { + self.height_resolver.pre_visit(key, storage, edges) + } } - | AExpr::SortBy { - expr: inner, - by: _, - sort_options: _, - } => is_length_preserving_rec(*inner, arena, cache), - AExpr::BinaryExpr { left, op: _, right } => { - // As long as at least one input is length-preserving the other side - // should either broadcast or have the same length. - is_length_preserving_rec(*left, arena, cache) - || is_length_preserving_rec(*right, arena, cache) - }, - AExpr::Ternary { - predicate, - truthy, - falsy, - } => { - is_length_preserving_rec(*predicate, arena, cache) - || is_length_preserving_rec(*truthy, arena, cache) - || is_length_preserving_rec(*falsy, arena, cache) - }, - AExpr::AnonymousAgg { .. } => false, - AExpr::AnonymousFunction { - input, - function: _, - options, - fmt_str: _, - } - | AExpr::Function { - input, - function: _, - options, - } => { - // TODO: actually inspect the functions? This is overly conservative. - options.is_length_preserving() - && input - .iter() - .all(|expr| is_length_preserving_rec(expr.node(), arena, cache)) - }, - AExpr::Eval { - expr, - evaluation: _, - variant: _, - } => is_length_preserving_rec(*expr, arena, cache), - #[cfg(feature = "dynamic_group_by")] - AExpr::Rolling { - function: _, - index_column: _, - period: _, - offset: _, - closed_window: _, - } => true, - AExpr::StructEval { - expr, - evaluation: _, - } => is_length_preserving_rec(*expr, arena, cache), - AExpr::Over { - function: _, // Actually shouldn't matter for window functions. - partition_by: _, - order_by: _, - mapping, - } => !matches!(mapping, WindowMapping::Explode), - }; + fn post_visit( + &mut self, + key: Self::Key, + storage: &mut Self::Storage, + edges: &mut dyn polars_plan::traversal::edge_provider::NodeEdgesProvider, + ) -> ControlFlow { + let control_flow = self.height_resolver.post_visit(key, storage, edges); - cache.insert(expr_key, ret); - ret -} + let length_preserving = matches!(edges.outputs()[0], ExprProjectionHeight::Column); -#[expect(dead_code)] -pub fn is_length_preserving( - expr_key: ExprNodeKey, - expr_arena: &Arena, - cache: &mut ExprCache, -) -> bool { - is_length_preserving_rec(expr_key, expr_arena, &mut cache.is_length_preserving) -} + self.cache.insert(key, length_preserving); -fn is_length_preserving_ctx(expr_key: ExprNodeKey, ctx: &mut LowerExprContext) -> bool { - is_length_preserving_rec( + control_flow + } + } + + let mut visitor = CachedLengthPreservingVisitor { + cache: &mut ctx.cache.is_length_preserving, + height_resolver: &mut Default::default(), + }; + + let height = aexpr_tree_traversal( expr_key, - ctx.expr_arena, - &mut ctx.cache.is_length_preserving, + &mut &*ctx.expr_arena, + ctx.node_scratch.get(), + ctx.ae_height_scratch.get(), + &mut visitor, ) + .continue_value() + .unwrap(); + + matches!(height, ExprProjectionHeight::Column) } fn build_fallback_node_with_ctx( @@ -2745,6 +2702,8 @@ pub fn lower_exprs( cache: expr_cache, prepare_visualization: ctx.prepare_visualization, sortedness: ctx.sortedness, + node_scratch: &mut Default::default(), + ae_height_scratch: &mut Default::default(), }; let node_exprs = exprs.iter().map(|e| e.node()).collect_vec(); let (transformed_input, transformed_exprs) = @@ -2773,6 +2732,8 @@ pub fn build_select_stream( cache: expr_cache, prepare_visualization: ctx.prepare_visualization, sortedness: ctx.sortedness, + node_scratch: &mut Default::default(), + ae_height_scratch: &mut Default::default(), }; build_select_stream_with_ctx(input, exprs, &mut ctx) } @@ -2852,6 +2813,8 @@ pub fn build_length_preserving_select_stream( cache: expr_cache, prepare_visualization: ctx.prepare_visualization, sortedness: ctx.sortedness, + node_scratch: &mut Default::default(), + ae_height_scratch: &mut Default::default(), }; let already_length_preserving = exprs .iter() diff --git a/crates/polars-utils/src/aliases.rs b/crates/polars-utils/src/aliases.rs index b5b5d82053cb..89cde41b5c71 100644 --- a/crates/polars-utils/src/aliases.rs +++ b/crates/polars-utils/src/aliases.rs @@ -10,6 +10,30 @@ pub type PlHashSet = hashbrown::HashSet; pub type PlIndexMap = indexmap::IndexMap; pub type PlIndexSet = indexmap::IndexSet; +/// HashMap container with a getter that clears the HashMap. +#[derive(Default)] +pub struct ScratchHashMap(PlHashMap); + +impl ScratchHashMap { + /// Clear the HashMap and return a mutable reference to it. + pub fn get(&mut self) -> &mut PlHashMap { + self.0.clear(); + &mut self.0 + } +} + +/// HashSet container with a getter that clears the HashSet. +#[derive(Default)] +pub struct ScratchHashSet(PlHashSet); + +impl ScratchHashSet { + /// Clear the HashSet and return a mutable reference to it. + pub fn get(&mut self) -> &mut PlHashSet { + self.0.clear(); + &mut self.0 + } +} + pub trait SeedableFromU64SeedExt { fn seed_from_u64(seed: u64) -> Self; } diff --git a/crates/polars-utils/src/arena.rs b/crates/polars-utils/src/arena.rs index 0c37df93b26d..c9878709d94e 100644 --- a/crates/polars-utils/src/arena.rs +++ b/crates/polars-utils/src/arena.rs @@ -46,18 +46,6 @@ impl Default for Arena { } } -impl AsRef> for &Arena { - fn as_ref(&self) -> &Arena { - self - } -} - -impl AsRef> for &mut Arena { - fn as_ref(&self) -> &Arena { - self - } -} - /// Simple Arena implementation /// Allocates memory and stores item in a Vec. Only deallocates when being dropped itself. impl Arena { diff --git a/crates/polars-utils/src/collection.rs b/crates/polars-utils/src/collection.rs new file mode 100644 index 000000000000..d653fcaa9da1 --- /dev/null +++ b/crates/polars-utils/src/collection.rs @@ -0,0 +1,164 @@ +use std::marker::PhantomData; +use std::ops::{Deref, DerefMut, Index, IndexMut}; + +pub trait Collection { + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn len(&self) -> usize; + fn get(&self, idx: usize) -> Option<&T>; + fn get_mut(&mut self, idx: usize) -> Option<&mut T>; +} + +/// Wrapper that implements indexing. +pub struct CollectionWrap> { + inner: C, + phantom: PhantomData, +} + +impl> CollectionWrap { + pub fn new(inner: C) -> Self { + Self { + inner, + phantom: PhantomData, + } + } + + pub fn iter(&self) -> CollectionIter<'_, T, C> { + CollectionIter { + idx: 0, + collection: &self.inner, + phantom: PhantomData, + } + } + + pub fn for_each_mut(&mut self, mut f: F) + where + F: for<'b> FnMut(&'b mut T), + { + (0..self.len()).for_each(move |i| f(self.get_mut(i).unwrap())) + } + + pub fn map_mut<'a, B, F>(&'a mut self, mut f: F) -> impl Iterator + where + F: for<'b> FnMut(&'b mut T) -> B + 'a, + { + (0..self.len()).map(move |i| f(self.get_mut(i).unwrap())) + } + + pub fn into_inner(self) -> C { + self.inner + } +} + +impl> Deref for CollectionWrap { + type Target = C; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl> DerefMut for CollectionWrap { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl> Index for CollectionWrap { + type Output = T; + + fn index(&self, index: usize) -> &Self::Output { + self.get(index).unwrap() + } +} + +impl> IndexMut for CollectionWrap { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + self.get_mut(index).unwrap() + } +} + +impl, const N: usize> TryFrom> for [T; N] { + type Error = (); + + fn try_from(value: CollectionWrap) -> Result { + if value.len() != N { + return Err(()); + } + + Ok(std::array::from_fn(|i| value.get(i).unwrap().clone())) + } +} + +impl> From for CollectionWrap { + fn from(value: C) -> Self { + Self { + inner: value, + phantom: PhantomData, + } + } +} + +impl Collection for &mut dyn Collection { + fn len(&self) -> usize { + (**self).len() + } + + fn get(&self, idx: usize) -> Option<&T> { + (**self).get(idx) + } + + fn get_mut(&mut self, idx: usize) -> Option<&mut T> { + (**self).get_mut(idx) + } +} + +pub struct CollectionIter<'a, T: 'a, C: Collection> { + idx: usize, + collection: &'a C, + phantom: PhantomData, +} + +impl<'a, T, C: Collection> Iterator for CollectionIter<'a, T, C> { + type Item = &'a T; + + fn next(&mut self) -> Option { + let item = self.collection.get(self.idx); + + if item.is_some() { + self.idx += 1; + } + + item + } +} + +impl Collection for [T] { + fn len(&self) -> usize { + <[T]>::len(self) + } + + fn get(&self, idx: usize) -> Option<&T> { + <[T]>::get(self, idx) + } + + fn get_mut(&mut self, idx: usize) -> Option<&mut T> { + <[T]>::get_mut(self, idx) + } +} + +impl Collection for &mut [T] { + fn len(&self) -> usize { + <[T]>::len(self) + } + + fn get(&self, idx: usize) -> Option<&T> { + <[T]>::get(self, idx) + } + + fn get_mut(&mut self, idx: usize) -> Option<&mut T> { + <[T]>::get_mut(self, idx) + } +} diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs index 552315d0c1d9..8606b02cd68e 100644 --- a/crates/polars-utils/src/lib.rs +++ b/crates/polars-utils/src/lib.rs @@ -92,5 +92,6 @@ pub mod kahan_sum; pub use either; pub use idx_vec::UnitVec; pub mod chunked_bytes_cursor; +pub mod collection; pub mod concat_vec; pub mod scratch_vec;