Skip to content

Commit 34509f2

Browse files
authored
Optimize L2Norm for ConstantArray (#7495)
## Summary Optimizes L2Norm for ConstantArray, and makes sure we dont hit that tracing info case. ## Testing Some basic tests. Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 7e0af3c commit 34509f2

File tree

1 file changed

+111
-4
lines changed

1 file changed

+111
-4
lines changed

vortex-tensor/src/scalar_fns/l2_norm.rs

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ use prost::Message;
1010
use vortex_array::ArrayRef;
1111
use vortex_array::ExecutionCtx;
1212
use vortex_array::IntoArray;
13+
use vortex_array::arrays::Constant;
14+
use vortex_array::arrays::ConstantArray;
1315
use vortex_array::arrays::ExtensionArray;
1416
use vortex_array::arrays::PrimitiveArray;
1517
use vortex_array::arrays::ScalarFnArray;
@@ -26,6 +28,7 @@ use vortex_array::dtype::Nullability;
2628
use vortex_array::dtype::proto::dtype as pb;
2729
use vortex_array::expr::Expression;
2830
use vortex_array::match_each_float_ptype;
31+
use vortex_array::scalar::Scalar;
2932
use vortex_array::scalar_fn::Arity;
3033
use vortex_array::scalar_fn::ChildName;
3134
use vortex_array::scalar_fn::EmptyOptions;
@@ -131,6 +134,8 @@ impl ScalarFnVTable for L2Norm {
131134
let tensor_flat_size = tensor_match.list_size();
132135
let element_ptype = tensor_match.element_ptype();
133136

137+
let norm_dtype = DType::Primitive(element_ptype, ext.nullability());
138+
134139
// L2Norm(L2Denorm(normalized, norms)) is defined to read back the authoritative stored
135140
// norms. Exact callers of lossy encodings like TurboQuant opt into that storage semantics
136141
// instead of forcing a decode-and-recompute path here.
@@ -139,14 +144,37 @@ impl ScalarFnVTable for L2Norm {
139144
.nth_child(1)
140145
.vortex_expect("L2Denom must have at 2 children");
141146

142-
vortex_ensure_eq!(
143-
norms.dtype(),
144-
&DType::Primitive(element_ptype, input_ref.dtype().nullability())
145-
);
147+
vortex_ensure_eq!(norms.dtype(), &norm_dtype);
146148

147149
return Ok(norms);
148150
}
149151

152+
// Optimize for the constant array case.
153+
if let Some(array) = input_ref.as_opt::<Constant>() {
154+
let scalar = array.scalar().as_extension().to_storage_scalar();
155+
156+
let Some(elements) = scalar.as_list().elements() else {
157+
return Ok(ConstantArray::new(Scalar::null(norm_dtype), row_count).into_array());
158+
};
159+
160+
let norm_scalar = match_each_float_ptype!(element_ptype, |T| {
161+
let values: Vec<T> = elements
162+
.iter()
163+
.map(|s| {
164+
s.as_primitive()
165+
.as_::<T>()
166+
.vortex_expect("element was somehow not the correct float")
167+
})
168+
.collect();
169+
let norm = l2_norm_row::<T>(&values);
170+
171+
Scalar::try_new(norm_dtype, Some(norm.into()))
172+
})?;
173+
174+
let norms = ConstantArray::new(norm_scalar, row_count).into_array();
175+
return Ok(norms);
176+
}
177+
150178
let input: ExtensionArray = input_ref.execute(ctx)?;
151179
let validity = input.as_ref().validity()?;
152180

@@ -244,10 +272,18 @@ mod tests {
244272
use vortex_array::ArrayRef;
245273
use vortex_array::IntoArray;
246274
use vortex_array::VortexSessionExecute;
275+
use vortex_array::arrays::Constant;
276+
use vortex_array::arrays::ConstantArray;
247277
use vortex_array::arrays::MaskedArray;
248278
use vortex_array::arrays::PrimitiveArray;
249279
use vortex_array::arrays::ScalarFnArray;
250280
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin;
281+
use vortex_array::dtype::DType;
282+
use vortex_array::dtype::Nullability;
283+
use vortex_array::dtype::PType;
284+
use vortex_array::dtype::extension::ExtDType;
285+
use vortex_array::extension::EmptyMetadata;
286+
use vortex_array::scalar::Scalar;
251287
use vortex_array::validity::Validity;
252288
use vortex_error::VortexResult;
253289

@@ -256,6 +292,7 @@ mod tests {
256292
use crate::utils::test_helpers::assert_close;
257293
use crate::utils::test_helpers::tensor_array;
258294
use crate::utils::test_helpers::vector_array;
295+
use crate::vector::Vector;
259296

260297
/// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec<f64>`.
261298
fn eval_l2_norm(input: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
@@ -326,6 +363,76 @@ mod tests {
326363
Ok(())
327364
}
328365

366+
/// Builds a [`ConstantArray`] whose scalar is a [`Vector`] extension scalar wrapping a
367+
/// fixed-size list of `elements`, broadcast to `len` rows.
368+
fn constant_vector_extension_array(elements: &[f64], len: usize) -> ArrayRef {
369+
let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
370+
let children: Vec<Scalar> = elements
371+
.iter()
372+
.map(|&v| Scalar::primitive(v, Nullability::NonNullable))
373+
.collect();
374+
let storage_scalar =
375+
Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
376+
let ext_scalar = Scalar::extension::<Vector>(EmptyMetadata, storage_scalar);
377+
ConstantArray::new(ext_scalar, len).into_array()
378+
}
379+
380+
/// A constant input whose scalar is a non-null tensor should short-circuit to a
381+
/// [`ConstantArray`] output whose scalar is the precomputed norm. Uses [`execute_until`] so
382+
/// execution stops at the [`Constant`] encoding instead of canonicalizing into a
383+
/// [`PrimitiveArray`].
384+
#[test]
385+
fn constant_non_null_input_yields_constant_output() -> VortexResult<()> {
386+
let input = constant_vector_extension_array(&[3.0, 4.0], 4);
387+
388+
let scalar_fn = L2Norm::new().erased();
389+
let result = ScalarFnArray::try_new(scalar_fn, vec![input], 4)?.into_array();
390+
let mut ctx = SESSION.create_execution_ctx();
391+
let output = result.execute_until::<Constant>(&mut ctx)?;
392+
393+
let constant = output
394+
.as_opt::<Constant>()
395+
.expect("L2Norm over a constant input must produce a constant output");
396+
assert_eq!(constant.len(), 4);
397+
let norm = constant
398+
.scalar()
399+
.as_primitive()
400+
.as_::<f64>()
401+
.expect("norm scalar must be a non-null primitive");
402+
assert_close(&[norm], &[5.0]);
403+
Ok(())
404+
}
405+
406+
/// A constant input whose scalar is null should short-circuit to a null [`ConstantArray`] of
407+
/// the correct primitive dtype and length.
408+
#[test]
409+
fn constant_null_input_yields_null_constant_output() -> VortexResult<()> {
410+
let storage_dtype = DType::FixedSizeList(
411+
DType::Primitive(PType::F64, Nullability::NonNullable).into(),
412+
2,
413+
Nullability::Nullable,
414+
);
415+
let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, storage_dtype)?.erased();
416+
let null_scalar = Scalar::null(DType::Extension(ext_dtype));
417+
let input = ConstantArray::new(null_scalar, 3).into_array();
418+
419+
let scalar_fn = L2Norm::new().erased();
420+
let result = ScalarFnArray::try_new(scalar_fn, vec![input], 3)?.into_array();
421+
let mut ctx = SESSION.create_execution_ctx();
422+
let output = result.execute_until::<Constant>(&mut ctx)?;
423+
424+
let constant = output
425+
.as_opt::<Constant>()
426+
.expect("null constant input must produce a constant output");
427+
assert_eq!(constant.len(), 3);
428+
assert!(constant.scalar().is_null());
429+
assert_eq!(
430+
constant.dtype(),
431+
&DType::Primitive(PType::F64, Nullability::Nullable)
432+
);
433+
Ok(())
434+
}
435+
329436
#[rstest]
330437
#[case::fixed_shape_tensor(tensor_array(&[3], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), 2)]
331438
#[case::vector(vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), 2)]

0 commit comments

Comments
 (0)