@@ -10,6 +10,8 @@ use prost::Message;
1010use vortex_array:: ArrayRef ;
1111use vortex_array:: ExecutionCtx ;
1212use vortex_array:: IntoArray ;
13+ use vortex_array:: arrays:: Constant ;
14+ use vortex_array:: arrays:: ConstantArray ;
1315use vortex_array:: arrays:: ExtensionArray ;
1416use vortex_array:: arrays:: PrimitiveArray ;
1517use vortex_array:: arrays:: ScalarFnArray ;
@@ -26,6 +28,7 @@ use vortex_array::dtype::Nullability;
2628use vortex_array:: dtype:: proto:: dtype as pb;
2729use vortex_array:: expr:: Expression ;
2830use vortex_array:: match_each_float_ptype;
31+ use vortex_array:: scalar:: Scalar ;
2932use vortex_array:: scalar_fn:: Arity ;
3033use vortex_array:: scalar_fn:: ChildName ;
3134use vortex_array:: scalar_fn:: EmptyOptions ;
@@ -131,6 +134,8 @@ impl ScalarFnVTable for L2Norm {
131134 let tensor_flat_size = tensor_match. list_size ( ) ;
132135 let element_ptype = tensor_match. element_ptype ( ) ;
133136
137+ let norm_dtype = DType :: Primitive ( element_ptype, ext. nullability ( ) ) ;
138+
134139 // L2Norm(L2Denorm(normalized, norms)) is defined to read back the authoritative stored
135140 // norms. Exact callers of lossy encodings like TurboQuant opt into that storage semantics
136141 // instead of forcing a decode-and-recompute path here.
@@ -139,14 +144,37 @@ impl ScalarFnVTable for L2Norm {
139144 . nth_child ( 1 )
140145 . vortex_expect ( "L2Denom must have at 2 children" ) ;
141146
142- vortex_ensure_eq ! (
143- norms. dtype( ) ,
144- & DType :: Primitive ( element_ptype, input_ref. dtype( ) . nullability( ) )
145- ) ;
147+ vortex_ensure_eq ! ( norms. dtype( ) , & norm_dtype) ;
146148
147149 return Ok ( norms) ;
148150 }
149151
152+ // Optimize for the constant array case.
153+ if let Some ( array) = input_ref. as_opt :: < Constant > ( ) {
154+ let scalar = array. scalar ( ) . as_extension ( ) . to_storage_scalar ( ) ;
155+
156+ let Some ( elements) = scalar. as_list ( ) . elements ( ) else {
157+ return Ok ( ConstantArray :: new ( Scalar :: null ( norm_dtype) , row_count) . into_array ( ) ) ;
158+ } ;
159+
160+ let norm_scalar = match_each_float_ptype ! ( element_ptype, |T | {
161+ let values: Vec <T > = elements
162+ . iter( )
163+ . map( |s| {
164+ s. as_primitive( )
165+ . as_:: <T >( )
166+ . vortex_expect( "element was somehow not the correct float" )
167+ } )
168+ . collect( ) ;
169+ let norm = l2_norm_row:: <T >( & values) ;
170+
171+ Scalar :: try_new( norm_dtype, Some ( norm. into( ) ) )
172+ } ) ?;
173+
174+ let norms = ConstantArray :: new ( norm_scalar, row_count) . into_array ( ) ;
175+ return Ok ( norms) ;
176+ }
177+
150178 let input: ExtensionArray = input_ref. execute ( ctx) ?;
151179 let validity = input. as_ref ( ) . validity ( ) ?;
152180
@@ -244,10 +272,18 @@ mod tests {
244272 use vortex_array:: ArrayRef ;
245273 use vortex_array:: IntoArray ;
246274 use vortex_array:: VortexSessionExecute ;
275+ use vortex_array:: arrays:: Constant ;
276+ use vortex_array:: arrays:: ConstantArray ;
247277 use vortex_array:: arrays:: MaskedArray ;
248278 use vortex_array:: arrays:: PrimitiveArray ;
249279 use vortex_array:: arrays:: ScalarFnArray ;
250280 use vortex_array:: arrays:: scalar_fn:: plugin:: ScalarFnArrayPlugin ;
281+ use vortex_array:: dtype:: DType ;
282+ use vortex_array:: dtype:: Nullability ;
283+ use vortex_array:: dtype:: PType ;
284+ use vortex_array:: dtype:: extension:: ExtDType ;
285+ use vortex_array:: extension:: EmptyMetadata ;
286+ use vortex_array:: scalar:: Scalar ;
251287 use vortex_array:: validity:: Validity ;
252288 use vortex_error:: VortexResult ;
253289
@@ -256,6 +292,7 @@ mod tests {
256292 use crate :: utils:: test_helpers:: assert_close;
257293 use crate :: utils:: test_helpers:: tensor_array;
258294 use crate :: utils:: test_helpers:: vector_array;
295+ use crate :: vector:: Vector ;
259296
260297 /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec<f64>`.
261298 fn eval_l2_norm ( input : ArrayRef , len : usize ) -> VortexResult < Vec < f64 > > {
@@ -326,6 +363,76 @@ mod tests {
326363 Ok ( ( ) )
327364 }
328365
366+ /// Builds a [`ConstantArray`] whose scalar is a [`Vector`] extension scalar wrapping a
367+ /// fixed-size list of `elements`, broadcast to `len` rows.
368+ fn constant_vector_extension_array ( elements : & [ f64 ] , len : usize ) -> ArrayRef {
369+ let element_dtype = DType :: Primitive ( PType :: F64 , Nullability :: NonNullable ) ;
370+ let children: Vec < Scalar > = elements
371+ . iter ( )
372+ . map ( |& v| Scalar :: primitive ( v, Nullability :: NonNullable ) )
373+ . collect ( ) ;
374+ let storage_scalar =
375+ Scalar :: fixed_size_list ( element_dtype, children, Nullability :: NonNullable ) ;
376+ let ext_scalar = Scalar :: extension :: < Vector > ( EmptyMetadata , storage_scalar) ;
377+ ConstantArray :: new ( ext_scalar, len) . into_array ( )
378+ }
379+
380+ /// A constant input whose scalar is a non-null tensor should short-circuit to a
381+ /// [`ConstantArray`] output whose scalar is the precomputed norm. Uses [`execute_until`] so
382+ /// execution stops at the [`Constant`] encoding instead of canonicalizing into a
383+ /// [`PrimitiveArray`].
384+ #[ test]
385+ fn constant_non_null_input_yields_constant_output ( ) -> VortexResult < ( ) > {
386+ let input = constant_vector_extension_array ( & [ 3.0 , 4.0 ] , 4 ) ;
387+
388+ let scalar_fn = L2Norm :: new ( ) . erased ( ) ;
389+ let result = ScalarFnArray :: try_new ( scalar_fn, vec ! [ input] , 4 ) ?. into_array ( ) ;
390+ let mut ctx = SESSION . create_execution_ctx ( ) ;
391+ let output = result. execute_until :: < Constant > ( & mut ctx) ?;
392+
393+ let constant = output
394+ . as_opt :: < Constant > ( )
395+ . expect ( "L2Norm over a constant input must produce a constant output" ) ;
396+ assert_eq ! ( constant. len( ) , 4 ) ;
397+ let norm = constant
398+ . scalar ( )
399+ . as_primitive ( )
400+ . as_ :: < f64 > ( )
401+ . expect ( "norm scalar must be a non-null primitive" ) ;
402+ assert_close ( & [ norm] , & [ 5.0 ] ) ;
403+ Ok ( ( ) )
404+ }
405+
406+ /// A constant input whose scalar is null should short-circuit to a null [`ConstantArray`] of
407+ /// the correct primitive dtype and length.
408+ #[ test]
409+ fn constant_null_input_yields_null_constant_output ( ) -> VortexResult < ( ) > {
410+ let storage_dtype = DType :: FixedSizeList (
411+ DType :: Primitive ( PType :: F64 , Nullability :: NonNullable ) . into ( ) ,
412+ 2 ,
413+ Nullability :: Nullable ,
414+ ) ;
415+ let ext_dtype = ExtDType :: < Vector > :: try_new ( EmptyMetadata , storage_dtype) ?. erased ( ) ;
416+ let null_scalar = Scalar :: null ( DType :: Extension ( ext_dtype) ) ;
417+ let input = ConstantArray :: new ( null_scalar, 3 ) . into_array ( ) ;
418+
419+ let scalar_fn = L2Norm :: new ( ) . erased ( ) ;
420+ let result = ScalarFnArray :: try_new ( scalar_fn, vec ! [ input] , 3 ) ?. into_array ( ) ;
421+ let mut ctx = SESSION . create_execution_ctx ( ) ;
422+ let output = result. execute_until :: < Constant > ( & mut ctx) ?;
423+
424+ let constant = output
425+ . as_opt :: < Constant > ( )
426+ . expect ( "null constant input must produce a constant output" ) ;
427+ assert_eq ! ( constant. len( ) , 3 ) ;
428+ assert ! ( constant. scalar( ) . is_null( ) ) ;
429+ assert_eq ! (
430+ constant. dtype( ) ,
431+ & DType :: Primitive ( PType :: F64 , Nullability :: Nullable )
432+ ) ;
433+ Ok ( ( ) )
434+ }
435+
329436 #[ rstest]
330437 #[ case:: fixed_shape_tensor( tensor_array( & [ 3 ] , & [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ) . unwrap( ) , 2 ) ]
331438 #[ case:: vector( vector_array( 3 , & [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ) . unwrap( ) , 2 ) ]
0 commit comments