Skip to content

Commit 76a1973

Browse files
committed
feat[array]: refactor zip compute kernel to lazy zip via ScalarFnArray
Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 493dbb5 commit 76a1973

File tree

12 files changed

+356
-199
lines changed

12 files changed

+356
-199
lines changed

vortex-array/src/arrays/chunked/compute/rules.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::arrays::ConstantArray;
1313
use crate::arrays::ConstantVTable;
1414
use crate::arrays::ScalarFnArray;
1515
use crate::compute::FillNullReduceAdaptor;
16+
use crate::compute::ZipReduceAdaptor;
1617
use crate::optimizer::ArrayOptimizer;
1718
use crate::optimizer::rules::ArrayParentReduceRule;
1819
use crate::optimizer::rules::ParentRuleSet;
@@ -21,6 +22,7 @@ pub(crate) const PARENT_RULES: ParentRuleSet<ChunkedVTable> = ParentRuleSet::new
2122
ParentRuleSet::lift(&ChunkedUnaryScalarFnPushDownRule),
2223
ParentRuleSet::lift(&ChunkedConstantScalarFnPushDownRule),
2324
ParentRuleSet::lift(&FillNullReduceAdaptor(ChunkedVTable)),
25+
ParentRuleSet::lift(&ZipReduceAdaptor(ChunkedVTable)),
2426
]);
2527

2628
/// Push down any unary scalar function through chunked arrays.

vortex-array/src/arrays/chunked/compute/zip.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,14 @@ use crate::Array;
88
use crate::ArrayRef;
99
use crate::arrays::ChunkedArray;
1010
use crate::arrays::ChunkedVTable;
11-
use crate::compute::ZipKernel;
12-
use crate::compute::ZipKernelAdapter;
11+
use crate::compute::ZipReduce;
1312
use crate::compute::zip;
14-
use crate::register_kernel;
1513

16-
// Push down the zip call to the chunks. Without this kernel
14+
// Push down the zip call to the chunks. Without this rule
1715
// the default implementation canonicalises the chunked array
1816
// then zips once.
19-
impl ZipKernel for ChunkedVTable {
17+
impl ZipReduce for ChunkedVTable {
2018
fn zip(
21-
&self,
2219
if_true: &ChunkedArray,
2320
if_false: &dyn Array,
2421
mask: &Mask,
@@ -72,8 +69,6 @@ impl ZipKernel for ChunkedVTable {
7269
}
7370
}
7471

75-
register_kernel!(ZipKernelAdapter(ChunkedVTable).lift());
76-
7772
#[cfg(test)]
7873
mod tests {
7974
use vortex_buffer::buffer;

vortex-array/src/arrays/struct_/compute/zip.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,20 @@ use vortex_mask::Mask;
1010

1111
use crate::Array;
1212
use crate::ArrayRef;
13+
use crate::ExecutionCtx;
1314
use crate::arrays::StructArray;
1415
use crate::arrays::StructVTable;
1516
use crate::compute::ZipKernel;
16-
use crate::compute::ZipKernelAdapter;
1717
use crate::compute::zip;
18-
use crate::register_kernel;
1918
use crate::validity::Validity;
2019
use crate::vtable::ValidityHelper;
2120

2221
impl ZipKernel for StructVTable {
2322
fn zip(
24-
&self,
2523
if_true: &StructArray,
2624
if_false: &dyn Array,
2725
mask: &Mask,
26+
_ctx: &mut ExecutionCtx,
2827
) -> VortexResult<Option<ArrayRef>> {
2928
let Some(if_false) = if_false.as_opt::<StructVTable>() else {
3029
return Ok(None);
@@ -66,8 +65,6 @@ impl ZipKernel for StructVTable {
6665
}
6766
}
6867

69-
register_kernel!(ZipKernelAdapter(StructVTable).lift());
70-
7168
#[cfg(test)]
7269
mod tests {
7370
use vortex_dtype::FieldNames;

vortex-array/src/arrays/struct_/vtable/kernel.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33

44
use crate::arrays::StructVTable;
55
use crate::arrays::TakeExecuteAdaptor;
6+
use crate::compute::ZipExecuteAdaptor;
67
use crate::kernel::ParentKernelSet;
78

8-
pub(super) const PARENT_KERNELS: ParentKernelSet<StructVTable> =
9-
ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(StructVTable))]);
9+
pub(super) const PARENT_KERNELS: ParentKernelSet<StructVTable> = ParentKernelSet::new(&[
10+
ParentKernelSet::lift(&TakeExecuteAdaptor(StructVTable)),
11+
ParentKernelSet::lift(&ZipExecuteAdaptor(StructVTable)),
12+
]);

vortex-array/src/arrays/varbinview/compute/zip.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,21 @@ use vortex_vector::binaryview::BinaryView;
1212

1313
use crate::Array;
1414
use crate::ArrayRef;
15+
use crate::ExecutionCtx;
1516
use crate::arrays::VarBinViewArray;
1617
use crate::arrays::VarBinViewVTable;
1718
use crate::builders::DeduplicatedBuffers;
1819
use crate::builders::LazyBitBufferBuilder;
1920
use crate::compute::ZipKernel;
20-
use crate::compute::ZipKernelAdapter;
21-
use crate::register_kernel;
2221

2322
// A dedicated VarBinView zip kernel that builds the result directly by adjusting views and validity,
2423
// instead of routing through the generic builder (which would redo buffer lookups per mask slice).
2524
impl ZipKernel for VarBinViewVTable {
2625
fn zip(
27-
&self,
2826
if_true: &VarBinViewArray,
2927
if_false: &dyn Array,
3028
mask: &Mask,
29+
_ctx: &mut ExecutionCtx,
3130
) -> VortexResult<Option<ArrayRef>> {
3231
let Some(if_false) = if_false.as_opt::<VarBinViewVTable>() else {
3332
return Ok(None);
@@ -37,7 +36,6 @@ impl ZipKernel for VarBinViewVTable {
3736
vortex_bail!("input arrays to zip must have the same dtype");
3837
}
3938

40-
// compute fn already asserts if_true.len() == if_false.len()
4139
let len = if_true.len();
4240
let dtype = if_true
4341
.dtype()
@@ -205,8 +203,6 @@ fn push_view(
205203
validity_builder.append_non_null();
206204
}
207205

208-
register_kernel!(ZipKernelAdapter(VarBinViewVTable).lift());
209-
210206
#[cfg(test)]
211207
mod tests {
212208
use vortex_dtype::DType;

vortex-array/src/arrays/varbinview/vtable/kernel.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33

44
use crate::arrays::TakeExecuteAdaptor;
55
use crate::arrays::VarBinViewVTable;
6+
use crate::compute::ZipExecuteAdaptor;
67
use crate::kernel::ParentKernelSet;
78

8-
pub(super) const PARENT_KERNELS: ParentKernelSet<VarBinViewVTable> =
9-
ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(VarBinViewVTable))]);
9+
pub(super) const PARENT_KERNELS: ParentKernelSet<VarBinViewVTable> = ParentKernelSet::new(&[
10+
ParentKernelSet::lift(&TakeExecuteAdaptor(VarBinViewVTable)),
11+
ParentKernelSet::lift(&ZipExecuteAdaptor(VarBinViewVTable)),
12+
]);

vortex-array/src/builtins.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ use vortex_dtype::DType;
1313
use vortex_dtype::FieldName;
1414
use vortex_error::VortexResult;
1515
use vortex_scalar::Scalar;
16+
use vortex_session::VortexSession;
1617

1718
use crate::Array;
1819
use crate::ArrayRef;
20+
use crate::ExecutionCtx;
1921
use crate::IntoArray;
2022
use crate::arrays::ConstantArray;
2123
use crate::arrays::ScalarFnArrayExt;
@@ -28,6 +30,7 @@ use crate::expr::IsNull;
2830
use crate::expr::Mask;
2931
use crate::expr::Not;
3032
use crate::expr::VTableExt;
33+
use crate::expr::Zip;
3134
use crate::optimizer::ArrayOptimizer;
3235

3336
/// A collection of built-in scalar functions that can be applied to expressions or arrays.
@@ -51,6 +54,9 @@ pub trait ExprBuiltins: Sized {
5154

5255
/// Boolean negation.
5356
fn not(&self) -> VortexResult<Expression>;
57+
58+
/// Conditional selection: `result[i] = if mask[i] then self[i] else if_false[i]`.
59+
fn zip(&self, if_false: Expression, mask: Expression) -> VortexResult<Expression>;
5460
}
5561

5662
impl ExprBuiltins for Expression {
@@ -77,6 +83,10 @@ impl ExprBuiltins for Expression {
7783
fn not(&self) -> VortexResult<Expression> {
7884
Not.try_new_expr(EmptyOptions, [self.clone()])
7985
}
86+
87+
fn zip(&self, if_false: Expression, mask: Expression) -> VortexResult<Expression> {
88+
Zip.try_new_expr(EmptyOptions, [self.clone(), if_false, mask])
89+
}
8090
}
8191

8292
pub trait ArrayBuiltins: Sized {
@@ -99,6 +109,9 @@ pub trait ArrayBuiltins: Sized {
99109

100110
/// Boolean negation.
101111
fn not(&self) -> VortexResult<ArrayRef>;
112+
113+
/// Conditional selection: `result[i] = if mask[i] then self[i] else if_false[i]`.
114+
fn zip(&self, if_false: ArrayRef, mask: ArrayRef) -> VortexResult<ArrayRef>;
102115
}
103116

104117
impl ArrayBuiltins for ArrayRef {
@@ -141,4 +154,11 @@ impl ArrayBuiltins for ArrayRef {
141154
Not.try_new_array(self.len(), EmptyOptions, [self.clone()])?
142155
.optimize()
143156
}
157+
158+
fn zip(&self, if_false: ArrayRef, mask: ArrayRef) -> VortexResult<ArrayRef> {
159+
let scalar_fn =
160+
Zip.try_new_array(self.len(), EmptyOptions, [self.clone(), if_false, mask])?;
161+
let mut ctx = ExecutionCtx::new(VortexSession::empty());
162+
scalar_fn.execute::<ArrayRef>(&mut ctx)
163+
}
144164
}

vortex-array/src/compute/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ pub use crate::expr::FillNullExecuteAdaptor;
4949
pub use crate::expr::FillNullKernel;
5050
pub use crate::expr::FillNullReduce;
5151
pub use crate::expr::FillNullReduceAdaptor;
52+
pub use crate::expr::ZipExecuteAdaptor;
53+
pub use crate::expr::ZipKernel;
54+
pub use crate::expr::ZipReduce;
55+
pub use crate::expr::ZipReduceAdaptor;
5256

5357
#[cfg(feature = "arbitrary")]
5458
mod arbitrary;
@@ -99,7 +103,6 @@ pub fn warm_up_vtables() {
99103
nan_count::warm_up_vtable();
100104
numeric::warm_up_vtable();
101105
sum::warm_up_vtable();
102-
zip::warm_up_vtable();
103106
}
104107

105108
impl ComputeFn {

0 commit comments

Comments
 (0)