@@ -16,6 +16,7 @@ use crate::column::Column;
1616use crate :: data_type:: DataType ;
1717use crate :: error:: ErrorMode ;
1818use crate :: expr:: { Expr , SchemaContext } ;
19+ use crate :: operator:: Operator ;
1920
2021#[ derive( Debug , Clone , PartialEq , Eq ) ]
2122pub enum Device {
@@ -40,15 +41,27 @@ async fn evaluate_gpu<B: BitBlock>(
4041 . collect ( ) ;
4142
4243 let schema_context = SchemaContext :: new ( ) . with_columns ( & column_map) . with_error_mode ( error_mode) ;
44+ match expr {
45+ Expr :: Aggregate { op, arg, distinct } => {
46+ evaluate_gpu_aggregate :: < B > ( * op, arg. as_ref ( ) , * distinct, & schema_context, columns)
47+ . await
48+ }
49+ _ => evaluate_gpu_row :: < B > ( expr, & schema_context, columns) . await ,
50+ }
51+ }
52+
53+ async fn evaluate_gpu_row < B : BitBlock > (
54+ expr : & Expr , schema_context : & SchemaContext , columns : & [ Column < B > ] ,
55+ ) -> Result < Vec < Column < B > > , Box < dyn Error > > {
4356 let mut code_block = CodeBlock :: default ( ) ;
44- expr. to_nvrtc :: < B > ( & schema_context, & mut code_block) ?;
57+ expr. to_nvrtc :: < B > ( schema_context, & mut code_block) ?;
4558
4659 let size = columns[ 0 ] . len ( ) ;
4760 let input_cols =
4861 columns. iter ( ) . map ( |col| col. to_gpu_column ( ) ) . collect :: < Result < Vec < _ > , _ > > ( ) ?;
4962
5063 let mut output_cols = Vec :: < gpu_column:: Column > :: new ( ) ;
51- let result_type = expr. infer_type ( & schema_context) ?;
64+ let result_type = expr. infer_type ( schema_context) ?;
5265
5366 let gpu_col = gpu_column:: Column :: new_uninitialized :: < B > (
5467 size * result_type. native_size ( ) ,
@@ -68,3 +81,80 @@ async fn evaluate_gpu<B: BitBlock>(
6881
6982 Ok ( result_cols)
7083}
84+
85+ async fn evaluate_gpu_aggregate < B : BitBlock > (
86+ op : Operator , arg : & Expr , distinct : bool , schema_context : & SchemaContext , columns : & [ Column < B > ] ,
87+ ) -> Result < Vec < Column < B > > , Box < dyn Error > > {
88+ if distinct {
89+ return Err ( "DISTINCT aggregates are not supported yet" . into ( ) ) ;
90+ }
91+
92+ let ( col_idx, col_type) = match arg {
93+ Expr :: Column ( col_name) => schema_context
94+ . lookup ( col_name)
95+ . copied ( )
96+ . ok_or_else ( || format ! ( "unknown column: {}" , col_name) ) ?,
97+ _ => return Err ( "Aggregate argument must be a column reference" . into ( ) ) ,
98+ } ;
99+
100+ let result_type =
101+ Expr :: Aggregate { op, arg : Box :: new ( arg. clone ( ) ) , distinct } . infer_type ( schema_context) ?;
102+ let code = build_aggregate_nvrtc_code :: < B > (
103+ op,
104+ col_idx,
105+ col_type,
106+ result_type,
107+ schema_context. error_mode ( ) == ErrorMode :: Ansi ,
108+ ) ?;
109+ let input_cols =
110+ columns. iter ( ) . map ( |col| col. to_gpu_column ( ) ) . collect :: < Result < Vec < _ > , _ > > ( ) ?;
111+
112+ let mut output_cols = Vec :: < gpu_column:: Column > :: new ( ) ;
113+ let size = 1usize ;
114+ let gpu_col = gpu_column:: Column :: new_uninitialized :: < B > (
115+ size * result_type. native_size ( ) ,
116+ size. div_ceil ( B :: BITS ) ,
117+ size,
118+ ) ?;
119+ output_cols. push ( gpu_col) ;
120+
121+ cuda_launcher:: launch_aggregate :: < B > ( & code, & input_cols, & output_cols) . await ?;
122+
123+ let result_cols = output_cols
124+ . into_iter ( )
125+ . map ( |col| -> Result < _ , Box < dyn Error > > {
126+ Column :: from_gpu_column ( & col, "r0" , result_type)
127+ } )
128+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
129+
130+ Ok ( result_cols)
131+ }
132+
133+ fn build_aggregate_nvrtc_code < B : BitBlock > (
134+ op : Operator , col_idx : u16 , col_type : DataType , result_type : DataType , ansi_error_mode : bool ,
135+ ) -> Result < String , Box < dyn Error > > {
136+ let input_kernel_type = col_type. kernel_type ( ) ;
137+ let output_kernel_type = result_type. kernel_type ( ) ;
138+ let bits_type = B :: C_TYPE ;
139+
140+ let code = match op {
141+ Operator :: Min => format ! (
142+ "\t aggregate_codegen::min<TypeKind::{input_kernel_type}, TypeKind::{output_kernel_type}, {bits_type}>(ctx, input, output, num_rows, {col_idx});\n "
143+ ) ,
144+ Operator :: Max => format ! (
145+ "\t aggregate_codegen::max<TypeKind::{input_kernel_type}, TypeKind::{output_kernel_type}, {bits_type}>(ctx, input, output, num_rows, {col_idx});\n "
146+ ) ,
147+ Operator :: Sum => format ! (
148+ "\t aggregate_codegen::sum<{ansi_error_mode}, TypeKind::{input_kernel_type}, TypeKind::{output_kernel_type}, {bits_type}>(ctx, input, output, num_rows, {col_idx});\n "
149+ ) ,
150+ Operator :: Avg => format ! (
151+ "\t aggregate_codegen::avg<{ansi_error_mode}, TypeKind::{input_kernel_type}, TypeKind::{output_kernel_type}, {bits_type}>(ctx, input, output, num_rows, {col_idx});\n "
152+ ) ,
153+ Operator :: Count => format ! (
154+ "\t aggregate_codegen::count<TypeKind::{input_kernel_type}, {bits_type}>(ctx, input, output, num_rows, {col_idx});\n "
155+ ) ,
156+ _ => return Err ( format ! ( "Unsupported aggregate operator: {:?}" , op) . into ( ) ) ,
157+ } ;
158+
159+ Ok ( code)
160+ }
0 commit comments