From f516013db6f89607d53e36d0fe1624562331b4b3 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Mon, 24 Nov 2025 14:53:16 -0500 Subject: [PATCH 01/10] feat:Implement single column aggregations, no support for group by yet --- src/Backend/opti-sql-go/Expr/expr.go | 1 - .../opti-sql-go/operators/aggr/avgExec.go | 1 - .../operators/aggr/avgExec_test.go | 7 - .../opti-sql-go/operators/aggr/basicAggr.go | 5 - .../operators/aggr/basicAggr_test.go | 7 - .../opti-sql-go/operators/aggr/singleAggr.go | 266 +++++++++ .../operators/aggr/singleAggr_test.go | 529 ++++++++++++++++++ src/Backend/opti-sql-go/operators/aggr/sum.go | 1 - .../opti-sql-go/operators/aggr/sum_test.go | 7 - 9 files changed, 795 insertions(+), 29 deletions(-) delete mode 100644 src/Backend/opti-sql-go/operators/aggr/avgExec.go delete mode 100644 src/Backend/opti-sql-go/operators/aggr/avgExec_test.go delete mode 100644 src/Backend/opti-sql-go/operators/aggr/basicAggr.go delete mode 100644 src/Backend/opti-sql-go/operators/aggr/basicAggr_test.go create mode 100644 src/Backend/opti-sql-go/operators/aggr/singleAggr.go create mode 100644 src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go delete mode 100644 src/Backend/opti-sql-go/operators/aggr/sum.go delete mode 100644 src/Backend/opti-sql-go/operators/aggr/sum_test.go diff --git a/src/Backend/opti-sql-go/Expr/expr.go b/src/Backend/opti-sql-go/Expr/expr.go index b3eed34..4ae10bb 100644 --- a/src/Backend/opti-sql-go/Expr/expr.go +++ b/src/Backend/opti-sql-go/Expr/expr.go @@ -602,7 +602,6 @@ func EvalCast(c *CastExpr, batch *operators.RecordBatch) (arrow.Array, error) { castOpts := compute.SafeCastOptions(c.TargetType) out, err := compute.CastArray(context.TODO(), arr, castOpts) if err != nil { - // This is a runtime cast error return nil, fmt.Errorf("cast error: cannot cast %s to %s: %w", arr.DataType(), c.TargetType, err) } diff --git a/src/Backend/opti-sql-go/operators/aggr/avgExec.go b/src/Backend/opti-sql-go/operators/aggr/avgExec.go deleted file 
mode 100644 index abd1ad5..0000000 --- a/src/Backend/opti-sql-go/operators/aggr/avgExec.go +++ /dev/null @@ -1 +0,0 @@ -package aggr diff --git a/src/Backend/opti-sql-go/operators/aggr/avgExec_test.go b/src/Backend/opti-sql-go/operators/aggr/avgExec_test.go deleted file mode 100644 index 67671d0..0000000 --- a/src/Backend/opti-sql-go/operators/aggr/avgExec_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package aggr - -import "testing" - -func TestAvgExec(t *testing.T) { - // Simple passing test -} diff --git a/src/Backend/opti-sql-go/operators/aggr/basicAggr.go b/src/Backend/opti-sql-go/operators/aggr/basicAggr.go deleted file mode 100644 index 0ffa1f3..0000000 --- a/src/Backend/opti-sql-go/operators/aggr/basicAggr.go +++ /dev/null @@ -1,5 +0,0 @@ -package aggr - -// Min -//Max -//Count diff --git a/src/Backend/opti-sql-go/operators/aggr/basicAggr_test.go b/src/Backend/opti-sql-go/operators/aggr/basicAggr_test.go deleted file mode 100644 index 7a59206..0000000 --- a/src/Backend/opti-sql-go/operators/aggr/basicAggr_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package aggr - -import "testing" - -func TestBasicAggr(t *testing.T) { - // Simple passing test -} diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go new file mode 100644 index 0000000..f59da08 --- /dev/null +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go @@ -0,0 +1,266 @@ +package aggr + +import ( + "context" + "errors" + "fmt" + "io" + "opti-sql-go/Expr" + "opti-sql-go/operators" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/compute" +) + +// TODO: next steps are to deal with group by statments but that can be dealt with after basic aggr that return just 1 global value +var ( + ErrUnsupportedAggrFunc = func(aggr int) error { + return fmt.Errorf("%d is an unsupported aggregate function", aggr) + } + ErrInvalidAggrColumnType = func(value any) error { + return 
fmt.Errorf("%v of type %T cannot be cast to float64 so it is not a valid column type to aggragate on", value, value) + } +) + +type AggrFunc int + +const ( + Min AggrFunc = iota + Max + Count + Sum + Avg +) + +var ( + _ = (Accumulator)(&MinAggrAccumulator{}) + _ = (Accumulator)(&MaxAggrAccumulator{}) + _ = (Accumulator)(&CountAggrAccumulator{}) + _ = (Accumulator)(&SumAggrAccumulator{}) + _ = (Accumulator)(&AvgAggrAccumulator{}) + _ = (operators.Operator)(&AggrExec{}) +) + +// Min +//Max +//Count +// Sum +// Avg + +// for now just focus on single-column aggregation without group by +type AggregateFunctions struct { + AggrFunc AggrFunc // switch to deal with seperate aggregation functions + Child Expr.Expression // resolves to a column generally +} +type Accumulator interface { + Update(value float64) + Finalize() float64 +} + +func newMinAggr() Accumulator { + return &MinAggrAccumulator{} +} + +type MinAggrAccumulator struct { + minV float64 + firstValue bool +} + +func (m *MinAggrAccumulator) Update(value float64) { + if !m.firstValue { + m.minV = value + m.firstValue = true + return + } + m.minV = min(m.minV, value) + +} +func (m *MinAggrAccumulator) Finalize() float64 { return m.minV } +func newMaxAggr() Accumulator { + return &MaxAggrAccumulator{} +} + +type MaxAggrAccumulator struct { + maxV float64 + firstValue bool +} + +func (m *MaxAggrAccumulator) Update(value float64) { + if !m.firstValue { + m.maxV = value + m.firstValue = true + return + } + m.maxV = max(m.maxV, value) +} +func (m *MaxAggrAccumulator) Finalize() float64 { return m.maxV } + +func NewCountAggr() Accumulator { + return &CountAggrAccumulator{} +} + +type CountAggrAccumulator struct { + count float64 +} + +func (c *CountAggrAccumulator) Update(_ float64) { + c.count++ +} +func (c *CountAggrAccumulator) Finalize() float64 { return c.count } + +func NewSumAggr() Accumulator { + return &SumAggrAccumulator{} +} + +type SumAggrAccumulator struct { + summation float64 +} + +func (s 
*SumAggrAccumulator) Update(value float64) { + s.summation += value +} +func (s *SumAggrAccumulator) Finalize() float64 { return s.summation } +func newAvgAggr() Accumulator { + return &AvgAggrAccumulator{} +} + +type AvgAggrAccumulator struct { + values float64 + count float64 +} + +func (a *AvgAggrAccumulator) Update(value float64) { + a.values += value + a.count++ +} +func (a *AvgAggrAccumulator) Finalize() float64 { return float64(a.values / a.count) } + +// =================== +// Aggregator Operator +// =================== +type AggrExec struct { + child operators.Operator // child operator + schema *arrow.Schema // output schema + aggExpressions []AggregateFunctions // list of wanted aggregate expressions + accumulators []Accumulator // list of accumulators corresponding to aggExpressions, these will actually work to compute the aggregation + done bool // know when to return io.EOF +} + +func NewAggrExec(child operators.Operator, aggExprs []AggregateFunctions) (*AggrExec, error) { + accs := make([]Accumulator, len(aggExprs)) + fields := make([]arrow.Field, len(aggExprs)) + for i, agg := range aggExprs { + dt, err := Expr.ExprDataType(agg.Child, child.Schema()) + if err != nil || !validAggrType(dt) { + return nil, ErrInvalidAggrColumnType(dt) + } + var fieldName string + switch agg.AggrFunc { + case Min: + fieldName = fmt.Sprintf("min_%s", agg.Child.String()) + accs[i] = newMinAggr() + case Max: + fieldName = fmt.Sprintf("max_%s", agg.Child.String()) + accs[i] = newMaxAggr() + case Count: + fieldName = fmt.Sprintf("count_%s", agg.Child.String()) + accs[i] = NewCountAggr() + case Sum: + fieldName = fmt.Sprintf("sum_%s", agg.Child.String()) + accs[i] = NewSumAggr() + case Avg: + fieldName = fmt.Sprintf("avg_%s", agg.Child.String()) + accs[i] = newAvgAggr() + + default: + return nil, ErrUnsupportedAggrFunc(int(agg.AggrFunc)) + } + fields[i] = arrow.Field{ + Name: fieldName, + Type: arrow.PrimitiveTypes.Float64, + Nullable: true, + } + } + return &AggrExec{ + 
child: child, + schema: arrow.NewSchema(fields, nil), + aggExpressions: aggExprs, + accumulators: accs, + }, nil +} + +// check for io.EOF with flag +// read in all record batches +// for each batch, run Expr.Evaluate, to get the column you want for the expression (cast to float64) +// +// for each element of that column grab the values you want using the accumulator interface +// +// build output batch, for now its just 1 of everything straight forward +func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { + if a.done { + return nil, io.EOF + } + for { + childBatch, err := a.child.Next(n) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, err + } + for i, aggExpr := range a.aggExpressions { + agrArray, err := Expr.EvalExpression(aggExpr.Child, childBatch) + if err != nil { + return nil, err + } + agrArray, err = castArrayToFloat64(agrArray) + if err != nil { + return nil, err + } + valueArray := agrArray.(*array.Float64) + accumulator := a.accumulators[i] + for i := 0; i < valueArray.Len(); i++ { + accumulator.Update(valueArray.Value(i)) + } + + } + } + // build array with just the result of the column + resultColumns := make([]arrow.Array, len(a.accumulators)) + for i := range a.accumulators { + resultColumns[i] = operators.NewRecordBatchBuilder().GenFloatArray(a.accumulators[i].Finalize()) + } + return &operators.RecordBatch{ + Schema: a.schema, + Columns: resultColumns, + RowCount: uint64(len(a.aggExpressions)), + }, io.EOF + // this is a pipeline breaker so it will always consume all of the input which means this needs to return an io.EOF +} + +func (a *AggrExec) Schema() *arrow.Schema { + return a.schema +} +func (a *AggrExec) Close() error { + return a.child.Close() +} + +func validAggrType(dt arrow.DataType) bool { + switch dt.ID() { + case arrow.UINT8, arrow.UINT16, arrow.UINT32, arrow.UINT64, + arrow.INT8, arrow.INT16, arrow.INT32, arrow.INT64, arrow.FLOAT16, arrow.FLOAT32, arrow.FLOAT64: + return true + default: + 
return false + } +} + +func castArrayToFloat64(arr arrow.Array) (arrow.Array, error) { + outDatum, err := compute.CastArray(context.TODO(), arr, compute.NewCastOptions(&arrow.Float64Type{}, true)) + if err != nil { + return nil, err + } + + return outDatum, nil +} diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go new file mode 100644 index 0000000..36fe974 --- /dev/null +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go @@ -0,0 +1,529 @@ +package aggr + +import ( + "errors" + "fmt" + "io" + "math" + "opti-sql-go/Expr" + "opti-sql-go/operators/project" + "testing" + + "github.com/apache/arrow/go/v15/arrow/memory" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +func generateAggTestColumns() ([]string, []any) { + names := []string{ + "id", + "name", + "age", + "salary", + } + + columns := []any{ + // id: 1 to 25 + []int32{ + 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, + }, + + // name: 25 people + []string{ + "Alice", "Bob", "Charlie", "David", "Eve", + "Frank", "Grace", "Hannah", "Ivy", "Jake", + "Karen", "Leo", "Mona", "Nate", "Olive", + "Paul", "Quinn", "Rita", "Sam", "Tina", + "Uma", "Victor", "Wendy", "Xavier", "Yara", + }, + + // age: 25 numeric values + []int32{ + 28, 34, 45, 22, 31, + 29, 40, 36, 50, 26, + 33, 41, 27, 38, 24, + 46, 30, 35, 43, 32, + 39, 48, 29, 37, 42, + }, + + // salary: 25 numeric values + []float64{ + 70000.0, 82000.5, 54000.0, 91000.0, 60000.0, + 75000.0, 66000.0, 88000.0, 45000.0, 99000.0, + 72000.0, 81000.0, 53000.0, 86000.0, 64000.0, + 93000.0, 68000.0, 76000.0, 89000.0, 71000.0, + 83000.0, 94000.0, 55000.0, 87000.0, 91500.0, + }, + } + + return names, columns +} +func aggProject() *project.InMemorySource { + names, cols := generateAggTestColumns() + p, _ := project.NewInMemoryProjectExec(names, cols) + return p +} + +func col(name string) 
Expr.Expression { + return Expr.NewColumnResolve(name) +} + +func TestNewAggrExec(t *testing.T) { + + // ----------------------------------------------------------------- + t.Run("valid_single_min", func(t *testing.T) { + child := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("age")}, + } + + exec, err := NewAggrExec(child, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if exec.Schema().NumFields() != 1 { + t.Fatalf("expected 1 schema field, got %d", exec.Schema().NumFields()) + } + + expectedName := "min_Column(age)" + if exec.Schema().Field(0).Name != expectedName { + t.Fatalf("expected name %s, got %s", + expectedName, exec.Schema().Field(0).Name) + } + }) + + // ----------------------------------------------------------------- + t.Run("multiple_aggregations_schema_names", func(t *testing.T) { + child := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("id")}, + {AggrFunc: Max, Child: col("salary")}, + {AggrFunc: Avg, Child: col("age")}, + } + + exec, err := NewAggrExec(child, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + schema := exec.Schema() + + expected := []string{ + "min_Column(id)", + "max_Column(salary)", + "avg_Column(age)", + } + + for i, f := range schema.Fields() { + if f.Name != expected[i] { + t.Fatalf("expected field %s, got %s", expected[i], f.Name) + } + } + }) + + // ----------------------------------------------------------------- + t.Run("invalid_type_detection_string_column", func(t *testing.T) { + child := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("name")}, // "name" is string → invalid + } + + _, err := NewAggrExec(child, agg) + if err == nil { + t.Fatalf("expected type error, got nil") + } + t.Logf("================\n invalid column err %v \n ============", err) + }) + + // ----------------------------------------------------------------- + t.Run("unsupported_aggregate_function", func(t *testing.T) { + 
child := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: 9999, Child: col("age")}, + } + + _, err := NewAggrExec(child, agg) + if err == nil { + t.Fatalf("expected unsupported aggr error") + } + }) + + // ----------------------------------------------------------------- + t.Run("schema_type_float64_for_all_numeric_aggs", func(t *testing.T) { + child := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("id")}, + {AggrFunc: Max, Child: col("salary")}, + {AggrFunc: Sum, Child: col("age")}, + {AggrFunc: Avg, Child: col("salary")}, + {AggrFunc: Count, Child: col("age")}, + } + + exec, err := NewAggrExec(child, agg) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + + for _, f := range exec.Schema().Fields() { + if f.Type.ID() != arrow.FLOAT64 { + t.Fatalf("expected float64 output type, got %s", f.Type) + } + } + if err := exec.Close(); err != nil { + t.Fatalf("unexpected close error: %v", err) + } + }) + + // ----------------------------------------------------------------- + t.Run("check_all_valid_numeric_types_pass", func(t *testing.T) { + + // all numeric arrow types accepted by validAggrType() + validTypes := []arrow.DataType{ + arrow.PrimitiveTypes.Uint8, + arrow.PrimitiveTypes.Uint16, + arrow.PrimitiveTypes.Uint32, + arrow.PrimitiveTypes.Uint64, + arrow.PrimitiveTypes.Int8, + arrow.PrimitiveTypes.Int16, + arrow.PrimitiveTypes.Int32, + arrow.PrimitiveTypes.Int64, + arrow.PrimitiveTypes.Float32, + arrow.PrimitiveTypes.Float64, + } + + fieldNames := make([]string, len(validTypes)) + colData := make([]any, len(validTypes)) + + for i, dt := range validTypes { + name := fmt.Sprintf("col_%d", i) + fieldNames[i] = name + + switch dt.ID() { + case arrow.UINT8: + colData[i] = []uint8{1} + case arrow.UINT16: + colData[i] = []uint16{1} + case arrow.UINT32: + colData[i] = []uint32{1} + case arrow.UINT64: + colData[i] = []uint64{1} + case arrow.INT8: + colData[i] = []int8{1} + case arrow.INT16: + colData[i] = []int16{1} + case 
arrow.INT32: + colData[i] = []int32{1} + case arrow.INT64: + colData[i] = []int64{1} + case arrow.FLOAT16: + // float16 stored as float32 in Go + colData[i] = []float32{1} + case arrow.FLOAT32: + colData[i] = []float32{1} + case arrow.FLOAT64: + colData[i] = []float64{1} + } + } + + src, _ := project.NewInMemoryProjectExec(fieldNames, colData) + + for i := range fieldNames { + agg := []AggregateFunctions{ + {AggrFunc: Sum, Child: col(fieldNames[i])}, + } + + _, err := NewAggrExec(src, agg) + if err != nil { + t.Fatalf("unexpected error for type %s: %v", validTypes[i], err) + } + } + }) +} + +func TestCastArrayToFloat64(t *testing.T) { + + alloc := memory.NewGoAllocator + + // -------------------------------------------------------- + t.Run("cast_int32_to_float64", func(t *testing.T) { + b := array.NewInt32Builder(alloc()) + b.AppendValues([]int32{1, 2, 3, 4}, nil) + arr := b.NewArray() + + out, err := castArrayToFloat64(arr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + farr, ok := out.(*array.Float64) + if !ok { + t.Fatalf("expected Float64 array, got %T", out) + } + + expected := []float64{1, 2, 3, 4} + for i := range expected { + if farr.Value(i) != expected[i] { + t.Fatalf("expected %v at %d, got %v", expected[i], i, farr.Value(i)) + } + } + }) + + // -------------------------------------------------------- + t.Run("cast_float32_to_float64", func(t *testing.T) { + b := array.NewFloat32Builder(alloc()) + b.AppendValues([]float32{10.5, 20.5, 30.5}, nil) + arr := b.NewArray() + + out, err := castArrayToFloat64(arr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + farr, ok := out.(*array.Float64) + if !ok { + t.Fatalf("expected Float64 array, got %T", out) + } + + expected := []float64{10.5, 20.5, 30.5} + for i := range expected { + if farr.Value(i) != expected[i] { + t.Fatalf("expected %v at %d, got %v", expected[i], i, farr.Value(i)) + } + } + }) + + // -------------------------------------------------------- + 
t.Run("invalid_string_cast", func(t *testing.T) { + b := array.NewStringBuilder(alloc()) + b.AppendValues([]string{"a", "b", "c"}, nil) + arr := b.NewArray() + + _, err := castArrayToFloat64(arr) + if err == nil { + t.Fatalf("expected error when casting string array to float64") + } + }) + + // -------------------------------------------------------- + t.Run("empty_array_cast", func(t *testing.T) { + b := array.NewInt32Builder(alloc()) + // no values appended + arr := b.NewArray() + + out, err := castArrayToFloat64(arr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + _, ok := out.(*array.Float64) + if !ok { + t.Fatalf("expected Float64 array for empty cast, got %T", out) + } + + if out.Len() != 0 { + t.Fatalf("expected empty array, got length %d", out.Len()) + } + }) + +} + +func TestAggregateExecNext(t *testing.T) { + t.Run("validating done case early", func(t *testing.T) { + proj := aggProject() + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("id")}} + aggrExec, err := NewAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + aggrExec.done = true + _, err = aggrExec.Next(10) + if err == nil || !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF error, got nil") + } + }) + t.Run("Aggr minimum value on age", func(t *testing.T) { + proj := aggProject() + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("age")}} + aggrExec, err := NewAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + resultBatch, err := aggrExec.Next(100) + if err == nil || !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF error, got nil") + } + t.Logf("record batch: %v\n", resultBatch) + if resultBatch.Columns[0].(*array.Float64).Value(0) != 22 { + t.Fatalf("expected minimum age 22, got %v", resultBatch.Columns[0].(*array.Float64).Value(0)) + } + + }) + t.Run("Aggr maximum salary", func(t *testing.T) { + proj := aggProject() + agg := []AggregateFunctions{ + {AggrFunc: Max, Child: 
col("salary")}, + } + + aggrExec, err := NewAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + resultBatch, err := aggrExec.Next(100) + if err == nil || !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF, got %v", err) + } + + maxSalary := resultBatch.Columns[0].(*array.Float64).Value(0) + if maxSalary != 99000.0 && maxSalary != 94000.0 && maxSalary != 93000.0 { + // Real max is 99000 (Jake has 99000) + t.Fatalf("expected max salary 99000, got %v", maxSalary) + } + }) + t.Run("Aggr sum of id column", func(t *testing.T) { + proj := aggProject() + agg := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("id")}, + } + + aggrExec, err := NewAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + resultBatch, err := aggrExec.Next(200) + if err == nil || !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF, got %v", err) + } + + sumIDs := resultBatch.Columns[0].(*array.Float64).Value(0) + expected := float64((25 * 26) / 2) // sum(1..25) = 325 + if sumIDs != expected { + t.Fatalf("expected sum 325, got %v", sumIDs) + } + }) + t.Run("Aggr count of age column", func(t *testing.T) { + proj := aggProject() + agg := []AggregateFunctions{ + {AggrFunc: Count, Child: col("age")}, + } + + aggrExec, err := NewAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + resultBatch, err := aggrExec.Next(300) + if err == nil || !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF, got %v", err) + } + + count := resultBatch.Columns[0].(*array.Float64).Value(0) + if count != 25 { + t.Fatalf("expected count 25, got %v", count) + } + }) + t.Run("Aggr average of salary (⚠ your AVG is wrong)", func(t *testing.T) { + proj := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Avg, Child: col("salary")}, + } + + aggrExec, err := NewAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + resultBatch, err := aggrExec.Next(500) + if err == nil || !errors.Is(err, 
io.EOF) { + t.Fatalf("expected io.EOF, got %v", err) + } + + avg := resultBatch.Columns[0].(*array.Float64).Value(0) + expected := 75740.02 + + if math.Abs(avg-expected) > 0.001 { + t.Fatalf("expected avg %v, got %v", expected, avg) + } + + }) + t.Run("Multiple aggregators in a single request", func(t *testing.T) { + proj := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("age")}, + {AggrFunc: Max, Child: col("salary")}, + {AggrFunc: Count, Child: col("id")}, + } + + aggrExec, err := NewAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + resultBatch, err := aggrExec.Next(1000) + if err == nil || !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF, got %v", err) + } + + minAge := resultBatch.Columns[0].(*array.Float64).Value(0) + maxSalary := resultBatch.Columns[1].(*array.Float64).Value(0) + countIDs := resultBatch.Columns[2].(*array.Float64).Value(0) + + if minAge != 22 { + t.Fatalf("expected min age 22, got %v", minAge) + } + if maxSalary != 99000.0 { + t.Fatalf("expected max salary 99000, got %v", maxSalary) + } + if countIDs != 25 { + t.Fatalf("expected count 25, got %v", countIDs) + } + }) + + // ========================================================== + t.Run("Schema correctness for multiple aggregates", func(t *testing.T) { + proj := aggProject() + + agg := []AggregateFunctions{ + {AggrFunc: Min, Child: col("id")}, + {AggrFunc: Sum, Child: col("age")}, + {AggrFunc: Count, Child: col("salary")}, + } + + aggrExec, err := NewAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + + s := aggrExec.Schema() + + expectedNames := []string{ + "min_Column(id)", + "sum_Column(age)", + "count_Column(salary)", + } + + for i, f := range s.Fields() { + if f.Name != expectedNames[i] { + t.Fatalf("expected field %s, got %s", expectedNames[i], f.Name) + } + if f.Type.ID() != arrow.FLOAT64 { + t.Fatalf("expected float64 fields only") + } + } + }) +} diff --git 
a/src/Backend/opti-sql-go/operators/aggr/sum.go b/src/Backend/opti-sql-go/operators/aggr/sum.go deleted file mode 100644 index abd1ad5..0000000 --- a/src/Backend/opti-sql-go/operators/aggr/sum.go +++ /dev/null @@ -1 +0,0 @@ -package aggr diff --git a/src/Backend/opti-sql-go/operators/aggr/sum_test.go b/src/Backend/opti-sql-go/operators/aggr/sum_test.go deleted file mode 100644 index 485b9bb..0000000 --- a/src/Backend/opti-sql-go/operators/aggr/sum_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package aggr - -import "testing" - -func TestSum(t *testing.T) { - // Simple passing test -} From d4b5538da5aae2a70359af964aa1e2db1309522f Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Mon, 24 Nov 2025 23:44:31 -0500 Subject: [PATCH 02/10] Feat: ground work for group bys, Consuming child record batching and producing result are next steps --- .../opti-sql-go/operators/aggr/groupBy.go | 120 ++++++ .../operators/aggr/groupBy_test.go | 400 +++++++++++++++++- .../opti-sql-go/operators/aggr/singleAggr.go | 30 +- .../operators/aggr/singleAggr_test.go | 28 +- .../opti-sql-go/operators/aggr/sort.go | 2 + 5 files changed, 560 insertions(+), 20 deletions(-) diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy.go b/src/Backend/opti-sql-go/operators/aggr/groupBy.go index abd1ad5..4a8d24b 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy.go @@ -1 +1,121 @@ package aggr + +import ( + "fmt" + "io" + "opti-sql-go/Expr" + "opti-sql-go/operators" + "strings" + + "github.com/apache/arrow/go/v17/arrow" +) + +/* +rules for group by: +1.Every non-aggregated column in SELECT must be in GROUP BY +2.You can group by multiple columns - creates groups for each unique combination +3.Use HAVING to filter groups (WHERE filters before grouping, HAVING filters after) +*/ +var ( + _ = (operators.Operator)(&GroupByExec{}) +) + +// place all unique elements of the group by column into a hash table, each element gets their own Accumulator 
instance +type GroupByExec struct { + child operators.Operator + schema *arrow.Schema + groupExpr []AggregateFunctions + groupByExpr []Expr.Expression // column names + + groups map[string][]Accumulator // maps group by key to its accumulator + keys map[string][]any // key → original values for output + done bool +} + +func NewGroupByExec(child operators.Operator, groupExpr []AggregateFunctions, groupBy []Expr.Expression) (*GroupByExec, error) { + s, err := buildGroupBySchema(child.Schema(), groupBy, groupExpr) + if err != nil { + return nil, err + } + + return &GroupByExec{ + child: child, + schema: s, + groupExpr: groupExpr, + groupByExpr: groupBy, + keys: make(map[string][]any), + groups: make(map[string][]Accumulator), + }, nil +} +func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { + if g.done { + return nil, io.EOF + } + return nil, nil +} +func (g *GroupByExec) Schema() *arrow.Schema { + return g.schema +} +func (g *GroupByExec) Close() error { + return g.child.Close() +} + +// handles validation and building of schema for group by +func buildGroupBySchema(childSchema *arrow.Schema, groupByExpr []Expr.Expression, aggrExprs []AggregateFunctions) (*arrow.Schema, error) { + + fields := make([]arrow.Field, 0, len(groupByExpr)+len(aggrExprs)) + + // 1. Add group-by columns + for _, expr := range groupByExpr { + dt, err := Expr.ExprDataType(expr, childSchema) + if err != nil { + return nil, fmt.Errorf("group-by expr %s has invalid type: %w", expr.String(), err) + } + + fields = append(fields, arrow.Field{ + Name: fmt.Sprintf("group_%s", expr.String()), + Type: dt, + Nullable: false, + }) + } + + // 2. 
Add aggregate columns + for _, agg := range aggrExprs { + + // All aggregates produce float64 in your design + fieldName := fmt.Sprintf("%s_%s", + strings.ToLower(aggrToString(int(agg.AggrFunc))), + agg.Child.String(), + ) + + fields = append(fields, arrow.Field{ + Name: fieldName, + Type: arrow.PrimitiveTypes.Float64, + Nullable: false, + }) + } + + return arrow.NewSchema(fields, nil), nil +} + +/* +TODO: use this in Next loop to skip boil plate creation code +func (g *GroupByExec) createAccumulators() []Accumulator { + accumulators := make([]Accumulator, len(g.groupExpr)) + for i, expr := range g.groupExpr { + switch expr.AggrFunc { + case Min: + accumulators[i] = newMinAggr() + case Max: + accumulators[i] = newMaxAggr() + case Count: + accumulators[i] = NewCountAggr() + case Sum: + accumulators[i] = NewSumAggr() + case Avg: + accumulators[i] = newAvgAggr() + } + } + return accumulators +} +*/ diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go index 3313b3e..57855d7 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go @@ -1,7 +1,401 @@ package aggr -import "testing" +import ( + "errors" + "fmt" + "io" + "opti-sql-go/Expr" + "opti-sql-go/operators/project" + "strings" + "testing" -func TestGroupBy(t *testing.T) { - // Simple passing test + "github.com/apache/arrow/go/v17/arrow" +) + +func generateGroupByTestColumns() ([]string, []any) { + names := []string{ + "id", + "name", + "department", + "region", + "seniority", + "salary", + "age", + } + + // 40 IDs + ids := make([]int32, 40) + for i := range ids { + ids[i] = int32(i + 1) + } + + // Names – 40 names + namesArr := []string{ + "Alice", "Bob", "Charlie", "David", "Eve", + "Frank", "Grace", "Hannah", "Ivy", "Jake", + "Karen", "Leo", "Mona", "Nate", "Olive", + "Paul", "Quinn", "Rita", "Sam", "Tina", + "Uma", "Victor", "Wendy", "Xavier", "Yara", + "Zane", "Becky", 
"Carlos", "Dora", "Elias", + "Fiona", "Gabe", "Helena", "Isaac", "Julia", + "Kevin", "Lara", "Miles", "Nora", "Owen", + } + + // Randomized but balanced departments (5 groups) + departments := []string{ + "Engineering", "HR", "Sales", "Engineering", "Finance", + "Support", "Sales", "Engineering", "Support", "Finance", + "HR", "Engineering", "Sales", "Support", "Finance", + "Engineering", "Sales", "HR", "Support", "Engineering", + "Finance", "Sales", "Engineering", "Support", "HR", + "Support", "Engineering", "Finance", "Sales", "HR", + "Engineering", "Support", "Finance", "Sales", "Engineering", + "HR", "Finance", "Support", "Engineering", "Sales", + } + + // Randomized but balanced regions (4 groups) + regions := []string{ + "North", "East", "South", "West", "South", + "North", "West", "East", "North", "South", + "West", "East", "North", "South", "West", + "North", "East", "West", "South", "North", + "East", "West", "South", "North", "East", + "South", "North", "West", "East", "South", + "West", "North", "East", "South", "West", + "North", "South", "East", "West", "North", + } + + // Randomized seniority (3 groups) + seniority := []string{ + "Junior", "Senior", "Mid", "Junior", "Mid", + "Senior", "Junior", "Mid", "Senior", "Junior", + "Mid", "Senior", "Junior", "Mid", "Senior", + "Junior", "Mid", "Senior", "Junior", "Mid", + "Senior", "Junior", "Mid", "Senior", "Junior", + "Mid", "Senior", "Junior", "Mid", "Senior", + "Junior", "Mid", "Senior", "Junior", "Mid", + "Senior", "Junior", "Mid", "Senior", "Junior", + } + + // Salaries (same as before) + salaries := []float64{ + 70000, 82000, 54000, 91000, 60000, + 75000, 66000, 88000, 45000, 99000, + 72000, 81000, 53000, 86000, 64000, + 93000, 68000, 76000, 89000, 71000, + 83000, 94000, 55000, 87000, 91500, + 72000, 69000, 58000, 84000, 79000, + 81000, 78000, 62000, 97000, 82000, + 95000, 76000, 88000, 91000, 64000, + } + + // Ages with some repetition + ages := []int32{ + 28, 34, 45, 22, 31, + 29, 40, 36, 50, 26, + 33, 
41, 27, 38, 24, + 46, 30, 35, 43, 32, + 39, 48, 29, 37, 42, + 28, 34, 45, 22, 31, + 29, 40, 36, 50, 26, + 39, 48, 29, 37, 42, + } + + columns := []any{ + ids, + namesArr, + departments, + regions, + seniority, + salaries, + ages, + } + + return names, columns +} + +func groupByProject() *project.InMemorySource { + names, cols := generateGroupByTestColumns() + p, _ := project.NewInMemoryProjectExec(names, cols) + return p +} + +func TestGroupByInit(t *testing.T) { + p := groupByProject() + rc, _ := p.Next(12) + fmt.Printf("rc:%v \n", rc) +} + +func TestNewGroupByExecAndSchema(t *testing.T) { + // convenience builder + col := func(name string) Expr.Expression { + return Expr.NewColumnResolve(name) + } + + t.Run("single group-by single aggregate", func(t *testing.T) { + child := groupByProject() + + groupBy := []Expr.Expression{col("department")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + } + + gb, err := NewGroupByExec(child, aggs, groupBy) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + schema := gb.Schema() + if schema == nil { + t.Fatalf("schema should not be nil") + } + fmt.Println(schema) + + // group-by + 1 agg = 2 fields + if got, want := schema.NumFields(), 2; got != want { + t.Fatalf("expected %d fields, got %d", want, got) + } + + // group field + f0 := schema.Field(0) + expName := "group_" + groupBy[0].String() + if f0.Name != expName { + t.Fatalf("expected group field name %q, got %q", expName, f0.Name) + } + + // aggregate field + f1 := schema.Field(1) + properAggName := fmt.Sprintf("%s_%s", + strings.ToLower(aggrToString(int(aggs[0].AggrFunc))), + aggs[0].Child.String(), + ) + if f1.Name != properAggName { + t.Fatalf("expected agg field %q, got %q", properAggName, f1.Name) + } + + if gb.groups == nil { + t.Fatalf("groups map not initialized") + } + if gb.keys == nil { + t.Fatalf("keys map not initialized") + } + }) + + t.Run("multiple group-by and multiple aggregates", func(t *testing.T) { + child := 
groupByProject() + + groupBy := []Expr.Expression{col("region"), col("seniority")} + aggs := []AggregateFunctions{ + {AggrFunc: Min, Child: col("age")}, + {AggrFunc: Max, Child: col("salary")}, + {AggrFunc: Count, Child: col("id")}, + } + + gb, err := NewGroupByExec(child, aggs, groupBy) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + + schema := gb.Schema() + fmt.Printf("schema: %v\n", schema) + wantFields := len(groupBy) + len(aggs) + if schema.NumFields() != wantFields { + t.Fatalf("expected %d fields, got %d", wantFields, schema.NumFields()) + } + + // group fields first + for i, gexpr := range groupBy { + f := schema.Field(i) + exp := "group_" + gexpr.String() + if f.Name != exp { + t.Fatalf("group field[%d] mismatch: want %q got %q", i, exp, f.Name) + } + } + + // aggregate fields next + offset := len(groupBy) + for j, agg := range aggs { + f := schema.Field(offset + j) + expAggName := fmt.Sprintf("%s_%s", + strings.ToLower(aggrToString(int(agg.AggrFunc))), + agg.Child.String(), + ) + if f.Name != expAggName { + t.Fatalf("agg field name mismatch: want %q got %q", expAggName, f.Name) + } + } + }) + + t.Run("invalid group-by column triggers error", func(t *testing.T) { + child := groupByProject() + + invalidGB := []Expr.Expression{col("not_a_col")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + } + + // direct schema builder test + _, err := buildGroupBySchema(child.Schema(), invalidGB, aggs) + if err == nil { + t.Fatalf("expected error for invalid group-by expr") + } + + // NewGroupByExec should also fail + if _, err := NewGroupByExec(child, aggs, invalidGB); err == nil { + t.Fatalf("expected NewGroupByExec error for invalid group-by") + } + }) + + t.Run("no aggregates - schema should only contain group-by columns", func(t *testing.T) { + child := groupByProject() + + groupBy := []Expr.Expression{col("region")} + var aggs []AggregateFunctions + + gb, err := NewGroupByExec(child, aggs, groupBy) + if err != nil { + 
t.Fatalf("unexpected: %v", err) + } + + schema := gb.Schema() + + if schema.NumFields() != 1 { + t.Fatalf("expected 1 field, got %d", schema.NumFields()) + } + + f := schema.Field(0) + exp := "group_" + groupBy[0].String() + if f.Name != exp { + t.Fatalf("wrong group field name: want %q got %q", exp, f.Name) + } + }) + + t.Run("multiple aggregates produce float64 regardless of source type", func(t *testing.T) { + child := groupByProject() + + groupBy := []Expr.Expression{col("department")} + aggs := []AggregateFunctions{ + {AggrFunc: Avg, Child: col("age")}, // int32 → float64 + {AggrFunc: Sum, Child: col("salary")}, // float64 → float64 + } + + gb, err := NewGroupByExec(child, aggs, groupBy) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + + schema := gb.Schema() + + // group-by + 2 aggregates = 3 + if schema.NumFields() != 3 { + t.Fatalf("expected 3 fields, got %d", schema.NumFields()) + } + + for idx := 1; idx < 3; idx++ { + f := schema.Field(idx) + if f.Type.ID() != arrow.FLOAT64 { + t.Fatalf("expected field[%d] to be float64, got %v", idx, f.Type) + } + } + }) + + t.Run("schema names must match exact string() output of expressions", func(t *testing.T) { + child := groupByProject() + + gbExpr := []Expr.Expression{ + Expr.NewColumnResolve("seniority"), + Expr.NewColumnResolve("region"), + } + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: Expr.NewColumnResolve("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + schema := gb.Schema() + + expected0 := "group_" + gbExpr[0].String() // group_Column(seniority) + expected1 := "group_" + gbExpr[1].String() // group_Column(region) + + if schema.Field(0).Name != expected0 { + t.Fatalf("wrong field[0] name: want %q got %q", expected0, schema.Field(0).Name) + } + if schema.Field(1).Name != expected1 { + t.Fatalf("wrong field[1] name: want %q got %q", expected1, schema.Field(1).Name) + } + + // count column + expectedAgg := 
"count_" + aggs[0].Child.String() + if schema.Field(2).Name != expectedAgg { + t.Fatalf("wrong agg field name: want %q got %q", expectedAgg, schema.Field(2).Name) + } + }) + t.Run("basic close check", func(t *testing.T) { + child := groupByProject() + + gbExpr := []Expr.Expression{ + Expr.NewColumnResolve("seniority"), + Expr.NewColumnResolve("region"), + } + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: Expr.NewColumnResolve("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if gb.Close() != nil { + t.Fatalf("unexpected error on close") + } + + }) +} +func TestBasicOperatorCasesGroupBy(t *testing.T) { + + t.Run("basic close check", func(t *testing.T) { + child := groupByProject() + + gbExpr := []Expr.Expression{ + Expr.NewColumnResolve("seniority"), + Expr.NewColumnResolve("region"), + } + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: Expr.NewColumnResolve("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if gb.Close() != nil { + t.Fatalf("unexpected error on close") + } + + }) + t.Run("done case", func(t *testing.T) { + child := groupByProject() + + gbExpr := []Expr.Expression{ + Expr.NewColumnResolve("seniority"), + Expr.NewColumnResolve("region"), + } + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: Expr.NewColumnResolve("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + gb.done = true + _, err = gb.Next(100) + if err == nil || !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF but recieved %v", err) + } + + }) } diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go index f59da08..6ac7cad 100644 --- a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go @@ -49,6 +49,13 @@ var ( // Avg // 
for now just focus on single-column aggregation without group by +func NewAggregateFunctions(aggrFunc AggrFunc, child Expr.Expression) AggregateFunctions { + return AggregateFunctions{ + AggrFunc: aggrFunc, + Child: child, + } +} + type AggregateFunctions struct { AggrFunc AggrFunc // switch to deal with seperate aggregation functions Child Expr.Expression // resolves to a column generally @@ -147,7 +154,7 @@ type AggrExec struct { done bool // know when to return io.EOF } -func NewAggrExec(child operators.Operator, aggExprs []AggregateFunctions) (*AggrExec, error) { +func NewGlobalAggrExec(child operators.Operator, aggExprs []AggregateFunctions) (*AggrExec, error) { accs := make([]Accumulator, len(aggExprs)) fields := make([]arrow.Field, len(aggExprs)) for i, agg := range aggExprs { @@ -220,8 +227,8 @@ func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { } valueArray := agrArray.(*array.Float64) accumulator := a.accumulators[i] - for i := 0; i < valueArray.Len(); i++ { - accumulator.Update(valueArray.Value(i)) + for j := 0; j < valueArray.Len(); j++ { + accumulator.Update(valueArray.Value(j)) } } @@ -231,6 +238,7 @@ func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { for i := range a.accumulators { resultColumns[i] = operators.NewRecordBatchBuilder().GenFloatArray(a.accumulators[i].Finalize()) } + a.done = true return &operators.RecordBatch{ Schema: a.schema, Columns: resultColumns, @@ -264,3 +272,19 @@ func castArrayToFloat64(arr arrow.Array) (arrow.Array, error) { return outDatum, nil } +func aggrToString(t int) string { + switch AggrFunc(t) { + case Min: + return "MIN" + case Max: + return "MAX" + case Count: + return "COUNT" + case Sum: + return "SUM" + case Avg: + return "AVG" + default: + return "UNKNOWN_AGGREGATE_FUNCTION" + } +} diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go index 36fe974..ea89bac 100644 --- 
a/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go @@ -82,7 +82,7 @@ func TestNewAggrExec(t *testing.T) { {AggrFunc: Min, Child: col("age")}, } - exec, err := NewAggrExec(child, agg) + exec, err := NewGlobalAggrExec(child, agg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -108,7 +108,7 @@ func TestNewAggrExec(t *testing.T) { {AggrFunc: Avg, Child: col("age")}, } - exec, err := NewAggrExec(child, agg) + exec, err := NewGlobalAggrExec(child, agg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -136,7 +136,7 @@ func TestNewAggrExec(t *testing.T) { {AggrFunc: Min, Child: col("name")}, // "name" is string → invalid } - _, err := NewAggrExec(child, agg) + _, err := NewGlobalAggrExec(child, agg) if err == nil { t.Fatalf("expected type error, got nil") } @@ -151,7 +151,7 @@ func TestNewAggrExec(t *testing.T) { {AggrFunc: 9999, Child: col("age")}, } - _, err := NewAggrExec(child, agg) + _, err := NewGlobalAggrExec(child, agg) if err == nil { t.Fatalf("expected unsupported aggr error") } @@ -169,7 +169,7 @@ func TestNewAggrExec(t *testing.T) { {AggrFunc: Count, Child: col("age")}, } - exec, err := NewAggrExec(child, agg) + exec, err := NewGlobalAggrExec(child, agg) if err != nil { t.Fatalf("unexpected: %v", err) } @@ -242,7 +242,7 @@ func TestNewAggrExec(t *testing.T) { {AggrFunc: Sum, Child: col(fieldNames[i])}, } - _, err := NewAggrExec(src, agg) + _, err := NewGlobalAggrExec(src, agg) if err != nil { t.Fatalf("unexpected error for type %s: %v", validTypes[i], err) } @@ -342,7 +342,7 @@ func TestAggregateExecNext(t *testing.T) { proj := aggProject() agg := []AggregateFunctions{ {AggrFunc: Min, Child: col("id")}} - aggrExec, err := NewAggrExec(proj, agg) + aggrExec, err := NewGlobalAggrExec(proj, agg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -356,7 +356,7 @@ func TestAggregateExecNext(t *testing.T) { proj := aggProject() agg := []AggregateFunctions{ {AggrFunc: 
Min, Child: col("age")}} - aggrExec, err := NewAggrExec(proj, agg) + aggrExec, err := NewGlobalAggrExec(proj, agg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -376,7 +376,7 @@ func TestAggregateExecNext(t *testing.T) { {AggrFunc: Max, Child: col("salary")}, } - aggrExec, err := NewAggrExec(proj, agg) + aggrExec, err := NewGlobalAggrExec(proj, agg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -398,7 +398,7 @@ func TestAggregateExecNext(t *testing.T) { {AggrFunc: Sum, Child: col("id")}, } - aggrExec, err := NewAggrExec(proj, agg) + aggrExec, err := NewGlobalAggrExec(proj, agg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -420,7 +420,7 @@ func TestAggregateExecNext(t *testing.T) { {AggrFunc: Count, Child: col("age")}, } - aggrExec, err := NewAggrExec(proj, agg) + aggrExec, err := NewGlobalAggrExec(proj, agg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -442,7 +442,7 @@ func TestAggregateExecNext(t *testing.T) { {AggrFunc: Avg, Child: col("salary")}, } - aggrExec, err := NewAggrExec(proj, agg) + aggrExec, err := NewGlobalAggrExec(proj, agg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -469,7 +469,7 @@ func TestAggregateExecNext(t *testing.T) { {AggrFunc: Count, Child: col("id")}, } - aggrExec, err := NewAggrExec(proj, agg) + aggrExec, err := NewGlobalAggrExec(proj, agg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -504,7 +504,7 @@ func TestAggregateExecNext(t *testing.T) { {AggrFunc: Count, Child: col("salary")}, } - aggrExec, err := NewAggrExec(proj, agg) + aggrExec, err := NewGlobalAggrExec(proj, agg) if err != nil { t.Fatalf("unexpected: %v", err) } diff --git a/src/Backend/opti-sql-go/operators/aggr/sort.go b/src/Backend/opti-sql-go/operators/aggr/sort.go index abd1ad5..d5a469b 100644 --- a/src/Backend/opti-sql-go/operators/aggr/sort.go +++ b/src/Backend/opti-sql-go/operators/aggr/sort.go @@ -1 +1,3 @@ package aggr + +// order by col asc, col 2 desc .... 
ect From be9c8131b91992cbee9cf4a68d7579027cbcc9e7 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Tue, 25 Nov 2025 15:37:02 -0500 Subject: [PATCH 03/10] feat: Implement agrregations with dynamic group by clause --- .../opti-sql-go/operators/aggr/groupBy.go | 364 ++++++++++++++++-- .../operators/aggr/groupBy_test.go | 292 +++++++++++++- .../opti-sql-go/operators/aggr/singleAggr.go | 58 +-- .../operators/aggr/singleAggr_test.go | 130 +++++-- .../opti-sql-go/operators/project/custom.go | 29 ++ src/Backend/opti-sql-go/operators/record.go | 113 ++++++ 6 files changed, 908 insertions(+), 78 deletions(-) diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy.go b/src/Backend/opti-sql-go/operators/aggr/groupBy.go index 4a8d24b..686ae3a 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy.go @@ -1,13 +1,16 @@ package aggr import ( + "errors" "fmt" "io" "opti-sql-go/Expr" "opti-sql-go/operators" "strings" + "github.com/apache/arrow/go/v15/arrow/memory" "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" ) /* @@ -27,8 +30,8 @@ type GroupByExec struct { groupExpr []AggregateFunctions groupByExpr []Expr.Expression // column names - groups map[string][]Accumulator // maps group by key to its accumulator - keys map[string][]any // key → original values for output + groups map[string][]accumulator // maps group by key to its accumulator + keys map[string][]string // key → original values for output done bool } @@ -43,16 +46,95 @@ func NewGroupByExec(child operators.Operator, groupExpr []AggregateFunctions, gr schema: s, groupExpr: groupExpr, groupByExpr: groupBy, - keys: make(map[string][]any), - groups: make(map[string][]Accumulator), + keys: make(map[string][]string), + groups: make(map[string][]accumulator), }, nil } + +/* +grab child rows +*/ func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { if g.done { return nil, io.EOF } - return nil, nil + + 
for { + childBatch, err := g.child.Next(batchSize) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, err + } + + rowCount := int(childBatch.RowCount) + + // 1. evaluate all group-by expressions into arrays + groupArrays := make([]arrow.Array, len(g.groupByExpr)) + for i, expr := range g.groupByExpr { + arr, err := Expr.EvalExpression(expr, childBatch) + if err != nil { + return nil, err + } + groupArrays[i] = arr + } + + // 2. evaluate all aggregation child expressions + aggrArrays := make([]arrow.Array, len(g.groupExpr)) + for i, agg := range g.groupExpr { + arr, err := Expr.EvalExpression(agg.Child, childBatch) + if err != nil { + return nil, err + } + arr, err = castArrayToFloat64(arr) + if err != nil { + return nil, err + } + aggrArrays[i] = arr + } + + // 3. process rows + for row := 0; row < rowCount; row++ { + + // Build group key + keyParts := make([]string, len(groupArrays)) + for j, arr := range groupArrays { + if arr.IsNull(row) { + keyParts[j] = "NULL" + } else { + keyParts[j] = fmt.Sprintf("%v", getValue(arr, row)) + } + } + key := strings.Join(keyParts, "|") + fmt.Printf("key: %v\n", key) + // Allocate accumulator list if new group + if _, exists := g.groups[key]; !exists { + g.groups[key] = make([]accumulator, len(g.groupExpr)) + for i, agg := range g.groupExpr { + g.groups[key][i] = createAccumulator(agg.AggrFunc) + } + g.keys[key] = keyParts // store original values + } + + // UPDATE accumulators + for i, arr := range aggrArrays { + if arr.IsNull(row) { + continue + } + val := arr.(*array.Float64).Value(row) + g.groups[key][i].Update(val) + } + } + } + + // 4. Build output RecordBatch + batch := buildGroupByOutput(g) + + g.done = true + return batch, io.EOF } + func (g *GroupByExec) Schema() *arrow.Schema { return g.schema } @@ -81,8 +163,11 @@ func buildGroupBySchema(childSchema *arrow.Schema, groupByExpr []Expr.Expression // 2. 
Add aggregate columns for _, agg := range aggrExprs { - - // All aggregates produce float64 in your design + dt, err := Expr.ExprDataType(agg.Child, childSchema) + if err != nil || !validAggrType(dt) { + return nil, ErrInvalidAggrColumnType(dt) + } + // All aggregates produce float64 fieldName := fmt.Sprintf("%s_%s", strings.ToLower(aggrToString(int(agg.AggrFunc))), agg.Child.String(), @@ -98,24 +183,253 @@ func buildGroupBySchema(childSchema *arrow.Schema, groupByExpr []Expr.Expression return arrow.NewSchema(fields, nil), nil } -/* -TODO: use this in Next loop to skip boil plate creation code -func (g *GroupByExec) createAccumulators() []Accumulator { - accumulators := make([]Accumulator, len(g.groupExpr)) - for i, expr := range g.groupExpr { - switch expr.AggrFunc { - case Min: - accumulators[i] = newMinAggr() - case Max: - accumulators[i] = newMaxAggr() - case Count: - accumulators[i] = NewCountAggr() - case Sum: - accumulators[i] = NewSumAggr() - case Avg: - accumulators[i] = newAvgAggr() +func getValue(arr arrow.Array, row int) any { + switch col := arr.(type) { + case *array.Int32: + return col.Value(row) + case *array.Int64: + return col.Value(row) + case *array.Float32: + return col.Value(row) + case *array.Float64: + return col.Value(row) + case *array.String: + return col.Value(row) + case *array.Boolean: + return col.Value(row) + default: + // fallback – debug only + return fmt.Sprintf("%v", col) + } +} +func createAccumulator(fn AggrFunc) accumulator { + switch fn { + case Min: + return newMinAggr() + case Max: + return newMaxAggr() + case Sum: + return NewSumAggr() + case Count: + return NewCountAggr() + case Avg: + return newAvgAggr() + default: + panic(fmt.Sprintf("unsupported aggregate function: %v", fn)) + } +} + +func buildGroupByOutput(g *GroupByExec) *operators.RecordBatch { + alloc := memory.NewGoAllocator() + + rowCount := len(g.groups) + if rowCount == 0 { + // return empty batch (0 groups) + return &operators.RecordBatch{ + Schema: g.schema, 
+ Columns: []arrow.Array{}, + RowCount: 0, + } + } + + // Prepare column builders + colBuilders := make([]arrow.Array, len(g.schema.Fields())) + + // Temporary storage for columns + groupCols := make([][]any, len(g.groupByExpr)) // group columns + aggrCols := make([][]float64, len(g.groupExpr)) // aggregate columns + + for i := range groupCols { + groupCols[i] = make([]any, 0, rowCount) + } + for i := range aggrCols { + aggrCols[i] = make([]float64, 0, rowCount) + } + + // Iterate groups in stable order + i := 0 + for key, accs := range g.groups { + // Add group-by (dimension) values + dims := g.keys[key] + for j, v := range dims { + groupCols[j] = append(groupCols[j], v) + } + + // Add aggregated values + for j, acc := range accs { + aggrCols[j] = append(aggrCols[j], acc.Finalize()) } + + i++ + } + + // Now build Arrow arrays in correct schema order + fieldIndex := 0 + + // Build group-by columns first + for j := range g.groupByExpr { + colBuilders[fieldIndex] = buildDynamicArray(alloc, g.schema.Field(fieldIndex).Type, groupCols[j]) + fieldIndex++ + } + + // Build aggregate columns + for j := range g.groupExpr { + colBuilders[fieldIndex] = buildFloatArray(alloc, aggrCols[j]) + fieldIndex++ + } + + return &operators.RecordBatch{ + Schema: g.schema, + Columns: colBuilders, + RowCount: uint64(rowCount), } - return accumulators } -*/ +func buildDynamicArray(mem memory.Allocator, dt arrow.DataType, values []any) arrow.Array { + switch dt.ID() { + + // =========================== + // STRING (UTF8) + // =========================== + case arrow.STRING: + sb := array.NewStringBuilder(mem) + for _, v := range values { + if v == nil { + sb.AppendNull() + } else { + sb.Append(fmt.Sprintf("%v", v)) + } + } + return sb.NewArray() + + // =========================== + // SIGNED INTEGERS + // =========================== + case arrow.INT8: + b := array.NewInt8Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(int8)) + } + } + return 
b.NewArray() + + case arrow.INT16: + b := array.NewInt16Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(int16)) + } + } + return b.NewArray() + + case arrow.INT32: + b := array.NewInt32Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(int32)) + } + } + return b.NewArray() + + case arrow.INT64: + b := array.NewInt64Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(int64)) + } + } + return b.NewArray() + + // =========================== + // UNSIGNED INTEGERS + // =========================== + case arrow.UINT8: + b := array.NewUint8Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(uint8)) + } + } + return b.NewArray() + + case arrow.UINT16: + b := array.NewUint16Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(uint16)) + } + } + return b.NewArray() + + case arrow.UINT32: + b := array.NewUint32Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(uint32)) + } + } + return b.NewArray() + + case arrow.UINT64: + b := array.NewUint64Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(uint64)) + } + } + return b.NewArray() + + // =========================== + // FLOATS + // =========================== + case arrow.FLOAT32: + b := array.NewFloat32Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(float32)) + } + } + return b.NewArray() + + case arrow.FLOAT64: + b := array.NewFloat64Builder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(v.(float64)) + } + } + return b.NewArray() + + // =========================== + // UNSUPPORTED TYPE + // =========================== + default: + panic(fmt.Sprintf("unsupported dynamic array type: %v", dt)) + } 
+} + +func buildFloatArray(mem memory.Allocator, values []float64) arrow.Array { + b := array.NewFloat64Builder(mem) + b.AppendValues(values, nil) + return b.NewArray() +} diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go index 57855d7..23803dc 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go @@ -9,7 +9,9 @@ import ( "strings" "testing" + "github.com/apache/arrow/go/v15/arrow/memory" "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" ) func generateGroupByTestColumns() ([]string, []any) { @@ -394,8 +396,296 @@ func TestBasicOperatorCasesGroupBy(t *testing.T) { gb.done = true _, err = gb.Next(100) if err == nil || !errors.Is(err, io.EOF) { - t.Fatalf("expected EOF but recieved %v", err) + t.Fatalf("expected EOF but received %v", err) } }) } +func TestGroupByNext_SingleColumnCount(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + + child := groupByProject() + + gbExpr := []Expr.Expression{col("region")} + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: col("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + batch, err := gb.Next(1000) + if err == nil || !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF, got %v", err) + } + + if batch == nil || batch.RowCount == 0 { + t.Fatalf("expected non-empty grouped result") + } + + // Validate schema + if batch.Schema.NumFields() != 2 { + t.Fatalf("expected 2 fields, got %d", batch.Schema.NumFields()) + } + + // Validate that group keys exist and aggregates exist + if batch.Columns[0].Len() == 0 { + t.Fatalf("expected region groups") + } + + if batch.Columns[1].Len() == 0 { + t.Fatalf("expected aggregated counts") + } +} + +func TestGroupByNext_MultipleGroupBy_MultipleAggs(t *testing.T) { + col := func(n string) 
Expr.Expression { return Expr.NewColumnResolve(n) } + + child := groupByProject() + + gbExpr := []Expr.Expression{ + col("seniority"), + col("region"), + } + + aggs := []AggregateFunctions{ + {AggrFunc: Min, Child: col("age")}, + {AggrFunc: Max, Child: col("salary")}, + {AggrFunc: Count, Child: col("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatal(err) + } + + batch, err := gb.Next(50) + if err == nil || !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF, got %v", err) + } + + if batch.RowCount == 0 { + t.Fatalf("expected non-zero grouped rows") + } + + if batch.Schema.NumFields() != 5 { + t.Fatalf("expected 5 fields (2 group-by + 3 aggr), got %d", batch.Schema.NumFields()) + } +} + +func TestGroupByNext_MultipleNextCalls(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + + child := groupByProject() + + gbExpr := []Expr.Expression{col("region")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatal(err) + } + + // First call returns batch + EOF + _, err = gb.Next(100) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF on first return, got %v", err) + } + + // Second call MUST return EOF immediately + _, err = gb.Next(100) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF on second call, got %v", err) + } +} + +func TestBuildGroupBySchema_AllBranches(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + + child := groupByProject() + + groupBy := []Expr.Expression{col("region"), col("seniority")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + {AggrFunc: Count, Child: col("id")}, + } + + schema, err := buildGroupBySchema(child.Schema(), groupBy, aggs) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if schema.NumFields() != 4 { + t.Fatalf("expected 4 fields got %d", schema.NumFields()) + } + 
+ // test group-by fields + if schema.Field(0).Type.ID() != arrow.STRING { + t.Fatalf("expected STRING for region") + } + + // aggregated fields always float64 + if schema.Field(2).Type.ID() != arrow.FLOAT64 { + t.Fatalf("expected FLOAT64 for aggregate field") + } +} + +func TestBuildGroupBySchema_InvalidColumn(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + child := groupByProject() + + _, err := buildGroupBySchema(child.Schema(), []Expr.Expression{col("doesnotexist")}, nil) + if err == nil { + t.Fatalf("expected error but got none") + } +} + +func TestBuildGroupBySchema_InvalidAggType(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + child := groupByProject() + + aggs := []AggregateFunctions{ + // Boolean type or unsupported type + {AggrFunc: Sum, Child: col("name")}, + } + + _, err := buildGroupBySchema(child.Schema(), nil, aggs) + if err == nil { + t.Fatalf("expected invalid agg type error") + } +} +func TestGetValue_AllTypes(t *testing.T) { + mem := memory.NewGoAllocator() + + // int32 + i32 := array.NewInt32Builder(mem) + i32.Append(42) + arr32 := i32.NewArray() + if getValue(arr32, 0).(int32) != 42 { + t.Fatal("failed int32 case") + } + + // int64 + i64 := array.NewInt64Builder(mem) + i64.Append(99) + arr64 := i64.NewArray() + if getValue(arr64, 0).(int64) != 99 { + t.Fatal("failed int64 case") + } + + // float32 + f32 := array.NewFloat32Builder(mem) + f32.Append(3.5) + arrf32 := f32.NewArray() + if getValue(arrf32, 0).(float32) != 3.5 { + t.Fatal("failed float32 case") + } + + // float64 + f64 := array.NewFloat64Builder(mem) + f64.Append(9.1) + arrf64 := f64.NewArray() + if getValue(arrf64, 0).(float64) != 9.1 { + t.Fatal("failed float64 case") + } + + // string + sb := array.NewStringBuilder(mem) + sb.Append("hello") + sarr := sb.NewArray() + if getValue(sarr, 0).(string) != "hello" { + t.Fatal("failed string case") + } + + // boolean + bb := 
array.NewBooleanBuilder(mem) + bb.Append(true) + barr := bb.NewArray() + if getValue(barr, 0).(bool) != true { + t.Fatal("failed boolean case") + } +} + +func TestBuildDynamicArray_AllPrimitiveTypes(t *testing.T) { + mem := memory.NewGoAllocator() + + tests := []struct { + dt arrow.DataType + val []any + }{ + {arrow.PrimitiveTypes.Int8, []any{int8(1), nil, int8(3)}}, + {arrow.PrimitiveTypes.Int16, []any{int16(2), int16(5)}}, + {arrow.PrimitiveTypes.Int32, []any{int32(10), nil, int32(12)}}, + {arrow.PrimitiveTypes.Int64, []any{int64(20), int64(40)}}, + + {arrow.PrimitiveTypes.Uint8, []any{uint8(7), nil}}, + {arrow.PrimitiveTypes.Uint16, []any{uint16(100)}}, + {arrow.PrimitiveTypes.Uint32, []any{uint32(2000)}}, + {arrow.PrimitiveTypes.Uint64, []any{uint64(99999)}}, + + {arrow.PrimitiveTypes.Float32, []any{float32(2.2), nil}}, + {arrow.PrimitiveTypes.Float64, []any{float64(9.9)}}, + + {arrow.BinaryTypes.String, []any{"a", "b", nil}}, + } + + for _, tc := range tests { + arr := buildDynamicArray(mem, tc.dt, tc.val) + if arr.Len() != len(tc.val) { + t.Fatalf("wrong length for type %v", tc.dt) + } + } +} + +func TestCreateAccumulator_AllCases(t *testing.T) { + funcs := []AggrFunc{Min, Max, Sum, Count, Avg} + + for _, fn := range funcs { + acc := createAccumulator(fn) + if acc == nil { + t.Fatalf("expected accumulator for fn=%v", fn) + } + } +} + +func TestCreateAccumulator_PanicOnInvalid(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatalf("expected panic for invalid function") + } + }() + + createAccumulator(AggrFunc(9999)) // invalid +} + +func TestBuildGroupByOutput_Basic(t *testing.T) { + col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } + child := groupByProject() + + gbExpr := []Expr.Expression{col("region")} + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: col("id")}, + } + + gb, err := NewGroupByExec(child, aggs, gbExpr) + if err != nil { + t.Fatal(err) + } + + // invoke Next (fills accumulators) + _, _ = 
gb.Next(100) + + batch := buildGroupByOutput(gb) + + if batch.RowCount == 0 { + t.Fatalf("expected grouped rows") + } + + if len(batch.Columns) != 2 { + t.Fatalf("expected 2 columns, got %d", len(batch.Columns)) + } +} diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go index 6ac7cad..df9d3fa 100644 --- a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go @@ -13,16 +13,16 @@ import ( "github.com/apache/arrow/go/v17/arrow/compute" ) -// TODO: next steps are to deal with group by statments but that can be dealt with after basic aggr that return just 1 global value var ( ErrUnsupportedAggrFunc = func(aggr int) error { return fmt.Errorf("%d is an unsupported aggregate function", aggr) } ErrInvalidAggrColumnType = func(value any) error { - return fmt.Errorf("%v of type %T cannot be cast to float64 so it is not a valid column type to aggragate on", value, value) + return fmt.Errorf("%v of type %T cannot be cast to float64 so it is not a valid column type to aggregate on", value, value) } ) +// AggrFunc represents the type of aggregation function to be performed. 
type AggrFunc int const ( @@ -34,21 +34,14 @@ const ( ) var ( - _ = (Accumulator)(&MinAggrAccumulator{}) - _ = (Accumulator)(&MaxAggrAccumulator{}) - _ = (Accumulator)(&CountAggrAccumulator{}) - _ = (Accumulator)(&SumAggrAccumulator{}) - _ = (Accumulator)(&AvgAggrAccumulator{}) + _ = (accumulator)(&MinAggrAccumulator{}) + _ = (accumulator)(&MaxAggrAccumulator{}) + _ = (accumulator)(&CountAggrAccumulator{}) + _ = (accumulator)(&SumAggrAccumulator{}) + _ = (accumulator)(&AvgAggrAccumulator{}) _ = (operators.Operator)(&AggrExec{}) ) -// Min -//Max -//Count -// Sum -// Avg - -// for now just focus on single-column aggregation without group by func NewAggregateFunctions(aggrFunc AggrFunc, child Expr.Expression) AggregateFunctions { return AggregateFunctions{ AggrFunc: aggrFunc, @@ -57,15 +50,15 @@ func NewAggregateFunctions(aggrFunc AggrFunc, child Expr.Expression) AggregateFu } type AggregateFunctions struct { - AggrFunc AggrFunc // switch to deal with seperate aggregation functions + AggrFunc AggrFunc // switch to deal with separate aggregate functions Child Expr.Expression // resolves to a column generally } -type Accumulator interface { +type accumulator interface { Update(value float64) Finalize() float64 } -func newMinAggr() Accumulator { +func newMinAggr() accumulator { return &MinAggrAccumulator{} } @@ -84,7 +77,7 @@ func (m *MinAggrAccumulator) Update(value float64) { } func (m *MinAggrAccumulator) Finalize() float64 { return m.minV } -func newMaxAggr() Accumulator { +func newMaxAggr() accumulator { return &MaxAggrAccumulator{} } @@ -103,7 +96,7 @@ func (m *MaxAggrAccumulator) Update(value float64) { } func (m *MaxAggrAccumulator) Finalize() float64 { return m.maxV } -func NewCountAggr() Accumulator { +func NewCountAggr() accumulator { return &CountAggrAccumulator{} } @@ -116,7 +109,7 @@ func (c *CountAggrAccumulator) Update(_ float64) { } func (c *CountAggrAccumulator) Finalize() float64 { return c.count } -func NewSumAggr() Accumulator { +func NewSumAggr() 
accumulator { return &SumAggrAccumulator{} } @@ -128,34 +121,43 @@ func (s *SumAggrAccumulator) Update(value float64) { s.summation += value } func (s *SumAggrAccumulator) Finalize() float64 { return s.summation } -func newAvgAggr() Accumulator { +func newAvgAggr() accumulator { return &AvgAggrAccumulator{} } type AvgAggrAccumulator struct { + used bool values float64 count float64 } func (a *AvgAggrAccumulator) Update(value float64) { + a.used = true a.values += value a.count++ } -func (a *AvgAggrAccumulator) Finalize() float64 { return float64(a.values / a.count) } +func (a *AvgAggrAccumulator) Finalize() float64 { + // handles divide by zero + if !a.used { + return 0.0 + } + return a.values / a.count +} // =================== // Aggregator Operator // =================== +// handles global aggregations without group by type AggrExec struct { child operators.Operator // child operator schema *arrow.Schema // output schema aggExpressions []AggregateFunctions // list of wanted aggregate expressions - accumulators []Accumulator // list of accumulators corresponding to aggExpressions, these will actually work to compute the aggregation + accumulators []accumulator // list of accumulators corresponding to aggExpressions, these will actually work to compute the aggregation done bool // know when to return io.EOF } func NewGlobalAggrExec(child operators.Operator, aggExprs []AggregateFunctions) (*AggrExec, error) { - accs := make([]Accumulator, len(aggExprs)) + accs := make([]accumulator, len(aggExprs)) fields := make([]arrow.Field, len(aggExprs)) for i, agg := range aggExprs { dt, err := Expr.ExprDataType(agg.Child, child.Schema()) @@ -210,6 +212,7 @@ func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { } for { childBatch, err := a.child.Next(n) + fmt.Printf("child batch: %v\n", childBatch) if err != nil { if errors.Is(err, io.EOF) { break @@ -228,6 +231,9 @@ func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { valueArray := 
agrArray.(*array.Float64) accumulator := a.accumulators[i] for j := 0; j < valueArray.Len(); j++ { + if valueArray.IsNull(j) { + continue + } accumulator.Update(valueArray.Value(j)) } @@ -242,8 +248,8 @@ func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { return &operators.RecordBatch{ Schema: a.schema, Columns: resultColumns, - RowCount: uint64(len(a.aggExpressions)), - }, io.EOF + RowCount: 1, + }, nil // this is a pipeline breaker so it will always consume all of the input which means this needs to return an io.EOF } diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go index ea89bac..192630d 100644 --- a/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go @@ -62,12 +62,81 @@ func generateAggTestColumns() ([]string, []any) { return names, columns } +func generateAggTestColumnsWithNulls(mem memory.Allocator) ([]string, []arrow.Array) { + names := []string{"id", "name", "age", "salary"} + + // ------------------------- + // id column (int32) + // ------------------------- + idB := array.NewInt32Builder(mem) + idVals := []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + idValid := []bool{ + true, true, false, true, true, + false, true, true, true, false, + } + idB.AppendValues(idVals, idValid) + idArr := idB.NewArray() + + // ------------------------- + // name column (string) + // ------------------------- + nameB := array.NewStringBuilder(mem) + nameVals := []string{ + "Alice", "Bob", "Charlie", "David", "Eve", + "Frank", "Grace", "Hannah", "Ivy", "Jake", + } + nameValid := []bool{ + true, true, true, false, true, + true, true, true, false, true, + } + nameB.AppendValues(nameVals, nameValid) + nameArr := nameB.NewArray() + + // ------------------------- + // age column (int32) + // ------------------------- + ageB := array.NewInt32Builder(mem) + ageVals := []int32{28, 34, 45, 22, 31, 29, 40, 36, 50, 26} + 
ageValid := []bool{ + true, false, true, true, true, + true, false, true, true, true, + } + ageB.AppendValues(ageVals, ageValid) + ageArr := ageB.NewArray() + + // ------------------------- + // salary column (float64) + // ------------------------- + salB := array.NewFloat64Builder(mem) + salVals := []float64{ + 70000, 82000, 54000, 91000, 60000, + 75000, 66000, 0, 45000, 99000, + } + + salaryValid := []bool{ + true, true, true, true, true, + true, true, false, true, true, + } + + salB.AppendValues(salVals, salaryValid) + salaryArr := salB.NewArray() + + return names, []arrow.Array{idArr, nameArr, ageArr, salaryArr} +} + func aggProject() *project.InMemorySource { names, cols := generateAggTestColumns() p, _ := project.NewInMemoryProjectExec(names, cols) return p } +// TODO: add test that check for null +func aggProjectNull() *project.InMemorySource { + names, arr := generateAggTestColumnsWithNulls(memory.NewGoAllocator()) + p, _ := project.NewInMemoryProjectExecFromArrays(names, arr) + return p +} + func col(name string) Expr.Expression { return Expr.NewColumnResolve(name) } @@ -360,10 +429,7 @@ func TestAggregateExecNext(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - resultBatch, err := aggrExec.Next(100) - if err == nil || !errors.Is(err, io.EOF) { - t.Fatalf("expected io.EOF error, got nil") - } + resultBatch, _ := aggrExec.Next(100) t.Logf("record batch: %v\n", resultBatch) if resultBatch.Columns[0].(*array.Float64).Value(0) != 22 { t.Fatalf("expected minimum age 22, got %v", resultBatch.Columns[0].(*array.Float64).Value(0)) @@ -381,10 +447,7 @@ func TestAggregateExecNext(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - resultBatch, err := aggrExec.Next(100) - if err == nil || !errors.Is(err, io.EOF) { - t.Fatalf("expected io.EOF, got %v", err) - } + resultBatch, _ := aggrExec.Next(100) maxSalary := resultBatch.Columns[0].(*array.Float64).Value(0) if maxSalary != 99000.0 && maxSalary != 94000.0 && maxSalary != 93000.0 { @@ 
-403,10 +466,7 @@ func TestAggregateExecNext(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - resultBatch, err := aggrExec.Next(200) - if err == nil || !errors.Is(err, io.EOF) { - t.Fatalf("expected io.EOF, got %v", err) - } + resultBatch, _ := aggrExec.Next(200) sumIDs := resultBatch.Columns[0].(*array.Float64).Value(0) expected := float64((25 * 26) / 2) // sum(1..25) = 325 @@ -417,7 +477,7 @@ func TestAggregateExecNext(t *testing.T) { t.Run("Aggr count of age column", func(t *testing.T) { proj := aggProject() agg := []AggregateFunctions{ - {AggrFunc: Count, Child: col("age")}, + NewAggregateFunctions(Count, col("age")), } aggrExec, err := NewGlobalAggrExec(proj, agg) @@ -425,17 +485,14 @@ func TestAggregateExecNext(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - resultBatch, err := aggrExec.Next(300) - if err == nil || !errors.Is(err, io.EOF) { - t.Fatalf("expected io.EOF, got %v", err) - } + resultBatch, _ := aggrExec.Next(300) count := resultBatch.Columns[0].(*array.Float64).Value(0) if count != 25 { t.Fatalf("expected count 25, got %v", count) } }) - t.Run("Aggr average of salary (⚠ your AVG is wrong)", func(t *testing.T) { + t.Run("Aggr average of salary ", func(t *testing.T) { proj := aggProject() agg := []AggregateFunctions{ @@ -447,10 +504,7 @@ func TestAggregateExecNext(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - resultBatch, err := aggrExec.Next(500) - if err == nil || !errors.Is(err, io.EOF) { - t.Fatalf("expected io.EOF, got %v", err) - } + resultBatch, _ := aggrExec.Next(500) avg := resultBatch.Columns[0].(*array.Float64).Value(0) expected := 75740.02 @@ -474,10 +528,7 @@ func TestAggregateExecNext(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - resultBatch, err := aggrExec.Next(1000) - if err == nil || !errors.Is(err, io.EOF) { - t.Fatalf("expected io.EOF, got %v", err) - } + resultBatch, _ := aggrExec.Next(1000) minAge := resultBatch.Columns[0].(*array.Float64).Value(0) maxSalary := 
resultBatch.Columns[1].(*array.Float64).Value(0) @@ -527,3 +578,30 @@ func TestAggregateExecNext(t *testing.T) { } }) } + +func TestAggregateExecNull(t *testing.T) { + + t.Run("Aggr count of age column", func(t *testing.T) { + proj := aggProjectNull() + agg := []AggregateFunctions{ + NewAggregateFunctions(Count, col("age")), + NewAggregateFunctions(Sum, col("id")), + } + + aggrExec, err := NewGlobalAggrExec(proj, agg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + resultBatch, _ := aggrExec.Next(100) + t.Logf("rb:%v\n", resultBatch) + count := resultBatch.Columns[0].(*array.Float64).Value(0) + if count != 8 { + t.Fatalf("expected count 7, got %v", count) + } + sumIDs := resultBatch.Columns[1].(*array.Float64).Value(0) + expectedSum := float64(1 + 2 + 4 + 5 + 7 + 8 + 9) // only non-null ids + if sumIDs != expectedSum { + t.Fatalf("expected sum %v, got %v", expectedSum, sumIDs) + } + }) +} diff --git a/src/Backend/opti-sql-go/operators/project/custom.go b/src/Backend/opti-sql-go/operators/project/custom.go index e36fa0c..0816600 100644 --- a/src/Backend/opti-sql-go/operators/project/custom.go +++ b/src/Backend/opti-sql-go/operators/project/custom.go @@ -73,6 +73,35 @@ func (ms *InMemorySource) withFields(names ...string) error { ms.columns = cols return nil } +func NewInMemoryProjectExecFromArrays(names []string, arrays []arrow.Array) (*InMemorySource, error) { + if len(names) != len(arrays) { + return nil, operators.ErrInvalidSchema("number of column names and arrays do not match") + } + + fields := make([]arrow.Field, len(names)) + fieldToColIdx := make(map[string]int, len(names)) + + for i, arr := range arrays { + if arr == nil { + return nil, operators.ErrInvalidSchema(fmt.Sprintf("nil array for column %s", names[i])) + } + + fields[i] = arrow.Field{ + Name: names[i], + Type: arr.DataType(), + Nullable: true, // Arrow arrays may have null bitmaps + } + + fieldToColIdx[names[i]] = i + } + + return &InMemorySource{ + schema: arrow.NewSchema(fields, 
nil), + columns: arrays, + fieldToColIDx: fieldToColIdx, + }, nil +} + func (ms *InMemorySource) Next(n uint16) (*operators.RecordBatch, error) { if len(ms.columns) == 0 || ms.pos >= uint16(ms.columns[0].Len()) { return nil, io.EOF // EOF diff --git a/src/Backend/opti-sql-go/operators/record.go b/src/Backend/opti-sql-go/operators/record.go index 60f695b..d1f81a6 100644 --- a/src/Backend/opti-sql-go/operators/record.go +++ b/src/Backend/opti-sql-go/operators/record.go @@ -129,6 +129,7 @@ func (rb *RecordBatch) ColumnByName(name string) (arrow.Array, error) { } return rb.Columns[indices[0]], nil } + func (rbb *RecordBatchBuilder) GenIntArray(values ...int) arrow.Array { mem := memory.NewGoAllocator() builder := array.NewInt32Builder(mem) @@ -289,3 +290,115 @@ func (rbb *RecordBatchBuilder) GenLargeBinaryArray(values ...[]byte) arrow.Array } return builder.NewArray() } + +func (rb *RecordBatch) PrettyPrint() string { + if rb == nil { + return "" + } + + // ------------------------------- + // 1. Extract column names + // ------------------------------- + colNames := make([]string, len(rb.Schema.Fields())) + for i, f := range rb.Schema.Fields() { + colNames[i] = f.Name + } + + // ------------------------------- + // 2. Extract rows into [][]string + // ------------------------------- + rows := make([][]string, rb.RowCount) + for r := 0; r < int(rb.RowCount); r++ { + row := make([]string, len(rb.Columns)) + for c, arr := range rb.Columns { + row[c] = formatValue(arr, r) + } + rows[r] = row + } + + // ------------------------------- + // 3. Compute column widths + // ------------------------------- + colWidths := make([]int, len(colNames)) + for i, name := range colNames { + colWidths[i] = len(name) + } + for _, row := range rows { + for i, v := range row { + if len(v) > colWidths[i] { + colWidths[i] = len(v) + } + } + } + + // ------------------------------- + // 4. 
Build horizontal border line + // ------------------------------- + border := "+" + for _, w := range colWidths { + border += strings.Repeat("-", w+2) + "+" + } + + // ------------------------------- + // 5. Build the final output + // ------------------------------- + var b strings.Builder + + b.WriteString(border + "\n") + + // Header + b.WriteString("|") + for i, name := range colNames { + b.WriteString(" " + padRight(name, colWidths[i]) + " |") + } + b.WriteString("\n") + + b.WriteString(border + "\n") + + // Rows + for _, row := range rows { + b.WriteString("|") + for i, v := range row { + b.WriteString(" " + padRight(v, colWidths[i]) + " |") + } + b.WriteString("\n") + } + + b.WriteString(border) + + return b.String() +} + +// ------------------------------- +// Helper Functions +// ------------------------------- + +func padRight(s string, width int) string { + if len(s) >= width { + return s + } + return s + strings.Repeat(" ", width-len(s)) +} + +func formatValue(arr arrow.Array, row int) string { + if arr.IsNull(row) { + return "NULL" + } + + switch col := arr.(type) { + case *array.Int32: + return fmt.Sprintf("%d", col.Value(row)) + case *array.Int64: + return fmt.Sprintf("%d", col.Value(row)) + case *array.Float32: + return fmt.Sprintf("%g", col.Value(row)) + case *array.Float64: + return fmt.Sprintf("%g", col.Value(row)) + case *array.String: + return col.Value(row) + case *array.Boolean: + return fmt.Sprintf("%t", col.Value(row)) + default: + return "" + } +} From c1da3cb84d3a6b5eedad4d651d9afc14fa5e2054 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Tue, 25 Nov 2025 22:47:40 -0500 Subject: [PATCH 04/10] feat: Implement having operator --- src/Backend/opti-sql-go/Expr/expr.go | 1 + .../opti-sql-go/operators/aggr/groupBy.go | 4 +- .../operators/aggr/groupBy_test.go | 21 +- .../opti-sql-go/operators/aggr/having.go | 78 +++++++ .../opti-sql-go/operators/aggr/having_test.go | 213 ++++++++++++++++++ .../opti-sql-go/operators/aggr/singleAggr.go | 1 - 
.../opti-sql-go/operators/aggr/sort.go | 2 +- .../opti-sql-go/operators/filter/filter.go | 4 +- .../opti-sql-go/operators/test/t1_test.go | 1 + 9 files changed, 302 insertions(+), 23 deletions(-) create mode 100644 src/Backend/opti-sql-go/operators/aggr/having.go create mode 100644 src/Backend/opti-sql-go/operators/aggr/having_test.go create mode 100644 src/Backend/opti-sql-go/operators/test/t1_test.go diff --git a/src/Backend/opti-sql-go/Expr/expr.go b/src/Backend/opti-sql-go/Expr/expr.go index 4ae10bb..e27d179 100644 --- a/src/Backend/opti-sql-go/Expr/expr.go +++ b/src/Backend/opti-sql-go/Expr/expr.go @@ -387,6 +387,7 @@ func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error } rightArr, err := EvalExpression(b.Right, batch) if err != nil { + fmt.Printf("right side evaluation failed with %v", err) return nil, err } opt := compute.ArithmeticOptions{} diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy.go b/src/Backend/opti-sql-go/operators/aggr/groupBy.go index 686ae3a..c958cac 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy.go @@ -107,7 +107,6 @@ func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { } } key := strings.Join(keyParts, "|") - fmt.Printf("key: %v\n", key) // Allocate accumulator list if new group if _, exists := g.groups[key]; !exists { g.groups[key] = make([]accumulator, len(g.groupExpr)) @@ -132,7 +131,7 @@ func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { batch := buildGroupByOutput(g) g.done = true - return batch, io.EOF + return batch, nil } func (g *GroupByExec) Schema() *arrow.Schema { @@ -224,7 +223,6 @@ func buildGroupByOutput(g *GroupByExec) *operators.RecordBatch { rowCount := len(g.groups) if rowCount == 0 { - // return empty batch (0 groups) return &operators.RecordBatch{ Schema: g.schema, Columns: []arrow.Array{}, diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go 
b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go index 23803dc..0482870 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go @@ -124,8 +124,7 @@ func groupByProject() *project.InMemorySource { func TestGroupByInit(t *testing.T) { p := groupByProject() - rc, _ := p.Next(12) - fmt.Printf("rc:%v \n", rc) + _, _ = p.Next(12) } func TestNewGroupByExecAndSchema(t *testing.T) { @@ -416,10 +415,7 @@ func TestGroupByNext_SingleColumnCount(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - batch, err := gb.Next(1000) - if err == nil || !errors.Is(err, io.EOF) { - t.Fatalf("expected EOF, got %v", err) - } + batch, _ := gb.Next(1000) if batch == nil || batch.RowCount == 0 { t.Fatalf("expected non-empty grouped result") @@ -461,10 +457,7 @@ func TestGroupByNext_MultipleGroupBy_MultipleAggs(t *testing.T) { t.Fatal(err) } - batch, err := gb.Next(50) - if err == nil || !errors.Is(err, io.EOF) { - t.Fatalf("expected EOF, got %v", err) - } + batch, _ := gb.Next(50) if batch.RowCount == 0 { t.Fatalf("expected non-zero grouped rows") @@ -491,16 +484,12 @@ func TestGroupByNext_MultipleNextCalls(t *testing.T) { } // First call returns batch + EOF + _, _ = gb.Next(100) _, err = gb.Next(100) if !errors.Is(err, io.EOF) { - t.Fatalf("expected EOF on first return, got %v", err) + t.Fatalf("expected EOF on second return, got %v", err) } - // Second call MUST return EOF immediately - _, err = gb.Next(100) - if !errors.Is(err, io.EOF) { - t.Fatalf("expected EOF on second call, got %v", err) - } } func TestBuildGroupBySchema_AllBranches(t *testing.T) { diff --git a/src/Backend/opti-sql-go/operators/aggr/having.go b/src/Backend/opti-sql-go/operators/aggr/having.go new file mode 100644 index 0000000..72a5a91 --- /dev/null +++ b/src/Backend/opti-sql-go/operators/aggr/having.go @@ -0,0 +1,78 @@ +package aggr + +import ( + "errors" + "io" + "opti-sql-go/Expr" + "opti-sql-go/operators" + 
"opti-sql-go/operators/filter" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +// carbon copy of filter.go with minor changes to fit having semantics +var ( + _ = (operators.Operator)(&HavingExec{}) +) + +type HavingClone = filter.FilterExec + +type HavingExec struct { + input operators.Operator + schema *arrow.Schema + + havingExpr Expr.Expression + done bool +} + +func NewHavingExec(input operators.Operator, havingFilter Expr.Expression) (*HavingExec, error) { + + return &HavingExec{ + input: input, + schema: input.Schema(), + havingExpr: havingFilter, + }, nil +} + +func (h *HavingExec) Next(n uint16) (*operators.RecordBatch, error) { + if h.done { + return nil, io.EOF + } + batch, err := h.input.Next(n) + if err != nil { + return nil, err + } + booleanMask, err := Expr.EvalExpression(h.havingExpr, batch) + if err != nil { + return nil, err + } + boolArr, ok := booleanMask.(*array.Boolean) // impossible for this to not be a boolean array,assuming validPredicates works as it should + if !ok { + return nil, errors.New("predicate did not evaluate to boolean array") + } + filteredCol := make([]arrow.Array, len(batch.Columns)) + for i, col := range batch.Columns { + filteredCol[i], err = filter.ApplyBooleanMask(col, boolArr) + if err != nil { + return nil, err + } + } + // release old columns + for _, c := range batch.Columns { + c.Release() + } + size := uint64(filteredCol[0].Len()) + + return &operators.RecordBatch{ + Schema: batch.Schema, + Columns: filteredCol, + RowCount: size, + }, nil +} +func (h *HavingExec) Schema() *arrow.Schema { + return h.schema +} +func (h *HavingExec) Close() error { + return h.input.Close() +} diff --git a/src/Backend/opti-sql-go/operators/aggr/having_test.go b/src/Backend/opti-sql-go/operators/aggr/having_test.go new file mode 100644 index 0000000..9321639 --- /dev/null +++ b/src/Backend/opti-sql-go/operators/aggr/having_test.go @@ -0,0 +1,213 @@ +package aggr + +import ( + "errors" + 
"io" + "strings" + "testing" + + "opti-sql-go/Expr" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +func TestHavingExec_OnGroupBy(t *testing.T) { + + // ============================================================= + // 1) HAVING SUM(salary) > 600000 + // ============================================================= + t.Run("having_sum_salary_gt_600k", func(t *testing.T) { + + child := groupByProject() + + groupBy := []Expr.Expression{col("department")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + } + + gb, err := NewGroupByExec(child, aggs, groupBy) + if err != nil { + t.Fatalf("unexpected GroupBy error: %v", err) + } + + sumCol := "sum_Column(salary)" + + // SUM(salary) > 600000 + havingExpr := Expr.NewBinaryExpr( + Expr.NewColumnResolve(sumCol), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, float64(600000)), + ) + + having, err := NewHavingExec(gb, havingExpr) + if err != nil { + t.Fatalf("unexpected HavingExec init error: %v", err) + } + + batch, err := having.Next(1024) + if err != nil { + t.Fatalf("unexpected error running Next: %v", err) + } + t.Logf("batch : %v\n", batch.PrettyPrint()) + sumValues := batch.Columns[1].(*array.Float64) + for i := 0; i < sumValues.Len(); i++ { + if sumValues.Value(i) <= 600000 { + t.Fatalf("expected sum(salary) > 600000, got %f", sumValues.Value(i)) + } + } + + }) + + // ============================================================= + // 2) HAVING COUNT(id) >= 10 + // ============================================================= + t.Run("having_count_id_ge_10", func(t *testing.T) { + + child := groupByProject() + + groupBy := []Expr.Expression{col("region")} + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: col("id")}, + } + + gb, err := NewGroupByExec(child, aggs, groupBy) + if err != nil { + t.Fatalf("unexpected GroupBy err: %v", err) + } + + countCol := "count_Column(id)" + + havingExpr := Expr.NewBinaryExpr( + 
Expr.NewColumnResolve(countCol), + Expr.GreaterThanOrEqual, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, float64(10)), + ) + + having, err := NewHavingExec(gb, havingExpr) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + batch, err := having.Next(200) + if err != nil { + t.Fatalf("unexpected Next error: %v", err) + } + + if batch.RowCount != 3 { // North, South, West ≥ 10 + t.Fatalf("expected 3 regions with >=10 rows, got %d", batch.RowCount) + } + }) + + // ============================================================= + // 3) HAVING filters all groups out + // ============================================================= + t.Run("having_filters_all", func(t *testing.T) { + + child := groupByProject() + + groupBy := []Expr.Expression{col("department")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + } + + gb, _ := NewGroupByExec(child, aggs, groupBy) + + sumCol := "sum_Column(salary)" + + // Impossible condition + havingExpr := Expr.NewBinaryExpr( + Expr.NewColumnResolve(sumCol), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, float64(1_000_000_000)), + ) + + having, _ := NewHavingExec(gb, havingExpr) + + batch, err := having.Next(1024) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + if batch.RowCount != 0 { + t.Fatalf("expected all rows to be filtered out, got %d", batch.RowCount) + } + }) + + // ============================================================= + // 4) Non-boolean predicate → error + // ============================================================= + t.Run("having_non_boolean_predicate", func(t *testing.T) { + + child := groupByProject() + groupBy := []Expr.Expression{col("department")} + aggs := []AggregateFunctions{ + {AggrFunc: Sum, Child: col("salary")}, + } + + gb, _ := NewGroupByExec(child, aggs, groupBy) + + // invalid: resolves to float, not boolean + invalidExpr := Expr.NewColumnResolve("sum_Column(salary)") + + having, _ := NewHavingExec(gb, 
invalidExpr) + + _, err := having.Next(100) + if err == nil { + t.Fatalf("expected non-boolean error, got nil") + } + if !strings.Contains(err.Error(), "boolean") { + t.Fatalf("expected boolean error, got: %v", err) + } + }) + + // ============================================================= + // 5) done = true returns EOF + // ============================================================= + t.Run("done_returns_eof", func(t *testing.T) { + + child := groupByProject() + + groupBy := []Expr.Expression{col("region")} + aggs := []AggregateFunctions{ + {AggrFunc: Count, Child: col("id")}, + } + + gb, _ := NewGroupByExec(child, aggs, groupBy) + + countCol := "count_Column(id)" + + havingExpr := Expr.NewBinaryExpr( + Expr.NewColumnResolve(countCol), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, float64(0)), + ) + + h, _ := NewHavingExec(gb, havingExpr) + h.done = true + + _, err := h.Next(10) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF, got: %v", err) + } + }) + + // ============================================================= + // 6) Close forwards to child.Close() + // ============================================================= + t.Run("close_propagates", func(t *testing.T) { + + child := groupByProject() + + gb, _ := NewGroupByExec(child, []AggregateFunctions{ + {AggrFunc: Count, Child: col("id")}, + }, []Expr.Expression{col("region")}) + + h, _ := NewHavingExec(gb, Expr.NewLiteralResolve(arrow.FixedWidthTypes.Boolean, true)) + + if err := h.Close(); err != nil { + t.Fatalf("Close returned error: %v", err) + } + t.Log(h.Schema()) + }) +} diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go index df9d3fa..3e1f4e6 100644 --- a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go @@ -212,7 +212,6 @@ func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { } for { childBatch, err := 
a.child.Next(n) - fmt.Printf("child batch: %v\n", childBatch) if err != nil { if errors.Is(err, io.EOF) { break diff --git a/src/Backend/opti-sql-go/operators/aggr/sort.go b/src/Backend/opti-sql-go/operators/aggr/sort.go index d5a469b..ed342a8 100644 --- a/src/Backend/opti-sql-go/operators/aggr/sort.go +++ b/src/Backend/opti-sql-go/operators/aggr/sort.go @@ -1,3 +1,3 @@ package aggr -// order by col asc, col 2 desc .... ect +// order by col asc, col 2 desc .... etc diff --git a/src/Backend/opti-sql-go/operators/filter/filter.go b/src/Backend/opti-sql-go/operators/filter/filter.go index ddd8c1b..645eeeb 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter.go +++ b/src/Backend/opti-sql-go/operators/filter/filter.go @@ -55,7 +55,7 @@ func (f *FilterExec) Next(n uint16) (*operators.RecordBatch, error) { } filteredCol := make([]arrow.Array, len(batch.Columns)) for i, col := range batch.Columns { - filteredCol[i], err = applyBooleanMask(col, boolArr) + filteredCol[i], err = ApplyBooleanMask(col, boolArr) if err != nil { return nil, err } @@ -80,7 +80,7 @@ func (f *FilterExec) Close() error { return f.input.Close() } -func applyBooleanMask(col arrow.Array, mask *array.Boolean) (arrow.Array, error) { +func ApplyBooleanMask(col arrow.Array, mask *array.Boolean) (arrow.Array, error) { datum, err := compute.Filter( context.TODO(), compute.NewDatum(col), diff --git a/src/Backend/opti-sql-go/operators/test/t1_test.go b/src/Backend/opti-sql-go/operators/test/t1_test.go new file mode 100644 index 0000000..56e5404 --- /dev/null +++ b/src/Backend/opti-sql-go/operators/test/t1_test.go @@ -0,0 +1 @@ +package test From 825732b448c67c9e2495de2e5e6f31bf603ca94a Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Wed, 26 Nov 2025 13:05:18 -0500 Subject: [PATCH 05/10] fixed PR comments --- .../opti-sql-go/operators/aggr/groupBy.go | 11 ++- .../operators/aggr/groupBy_test.go | 2 +- .../opti-sql-go/operators/aggr/having.go | 7 +- .../opti-sql-go/operators/aggr/singleAggr.go | 69 
+++++++++---------- .../opti-sql-go/operators/test/t1_test.go | 2 + 5 files changed, 43 insertions(+), 48 deletions(-) diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy.go b/src/Backend/opti-sql-go/operators/aggr/groupBy.go index c958cac..7c57b28 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy.go @@ -8,9 +8,9 @@ import ( "opti-sql-go/operators" "strings" - "github.com/apache/arrow/go/v15/arrow/memory" "github.com/apache/arrow/go/v17/arrow" "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" ) /* @@ -156,7 +156,7 @@ func buildGroupBySchema(childSchema *arrow.Schema, groupByExpr []Expr.Expression fields = append(fields, arrow.Field{ Name: fmt.Sprintf("group_%s", expr.String()), Type: dt, - Nullable: false, + Nullable: true, }) } @@ -208,9 +208,9 @@ func createAccumulator(fn AggrFunc) accumulator { case Max: return newMaxAggr() case Sum: - return NewSumAggr() + return newSumAggr() case Count: - return NewCountAggr() + return newCountAggr() case Avg: return newAvgAggr() default: @@ -244,8 +244,6 @@ func buildGroupByOutput(g *GroupByExec) *operators.RecordBatch { aggrCols[i] = make([]float64, 0, rowCount) } - // Iterate groups in stable order - i := 0 for key, accs := range g.groups { // Add group-by (dimension) values dims := g.keys[key] @@ -258,7 +256,6 @@ func buildGroupByOutput(g *GroupByExec) *operators.RecordBatch { aggrCols[j] = append(aggrCols[j], acc.Finalize()) } - i++ } // Now build Arrow arrays in correct schema order diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go index 0482870..10756f0 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go @@ -9,9 +9,9 @@ import ( "strings" "testing" - "github.com/apache/arrow/go/v15/arrow/memory" "github.com/apache/arrow/go/v17/arrow" 
"github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" ) func generateGroupByTestColumns() ([]string, []any) { diff --git a/src/Backend/opti-sql-go/operators/aggr/having.go b/src/Backend/opti-sql-go/operators/aggr/having.go index 72a5a91..3f47233 100644 --- a/src/Backend/opti-sql-go/operators/aggr/having.go +++ b/src/Backend/opti-sql-go/operators/aggr/having.go @@ -16,8 +16,6 @@ var ( _ = (operators.Operator)(&HavingExec{}) ) -type HavingClone = filter.FilterExec - type HavingExec struct { input operators.Operator schema *arrow.Schema @@ -41,6 +39,9 @@ func (h *HavingExec) Next(n uint16) (*operators.RecordBatch, error) { } batch, err := h.input.Next(n) if err != nil { + if errors.Is(err, io.EOF) { + h.done = true + } return nil, err } booleanMask, err := Expr.EvalExpression(h.havingExpr, batch) @@ -49,7 +50,7 @@ func (h *HavingExec) Next(n uint16) (*operators.RecordBatch, error) { } boolArr, ok := booleanMask.(*array.Boolean) // impossible for this to not be a boolean array,assuming validPredicates works as it should if !ok { - return nil, errors.New("predicate did not evaluate to boolean array") + return nil, errors.New("having predicate did not evaluate to boolean array") } filteredCol := make([]arrow.Array, len(batch.Columns)) for i, col := range batch.Columns { diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go index 3e1f4e6..9593ca3 100644 --- a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go @@ -34,11 +34,11 @@ const ( ) var ( - _ = (accumulator)(&MinAggrAccumulator{}) - _ = (accumulator)(&MaxAggrAccumulator{}) - _ = (accumulator)(&CountAggrAccumulator{}) - _ = (accumulator)(&SumAggrAccumulator{}) - _ = (accumulator)(&AvgAggrAccumulator{}) + _ = (accumulator)(&minAggrAccumulator{}) + _ = (accumulator)(&maxAggrAccumulator{}) + _ = (accumulator)(&countAggrAccumulator{}) + _ = 
(accumulator)(&sumAggrAccumulator{}) + _ = (accumulator)(&avgAggrAccumulator{}) _ = (operators.Operator)(&AggrExec{}) ) @@ -59,15 +59,15 @@ type accumulator interface { } func newMinAggr() accumulator { - return &MinAggrAccumulator{} + return &minAggrAccumulator{} } -type MinAggrAccumulator struct { +type minAggrAccumulator struct { minV float64 firstValue bool } -func (m *MinAggrAccumulator) Update(value float64) { +func (m *minAggrAccumulator) Update(value float64) { if !m.firstValue { m.minV = value m.firstValue = true @@ -76,17 +76,17 @@ func (m *MinAggrAccumulator) Update(value float64) { m.minV = min(m.minV, value) } -func (m *MinAggrAccumulator) Finalize() float64 { return m.minV } +func (m *minAggrAccumulator) Finalize() float64 { return m.minV } func newMaxAggr() accumulator { - return &MaxAggrAccumulator{} + return &maxAggrAccumulator{} } -type MaxAggrAccumulator struct { +type maxAggrAccumulator struct { maxV float64 firstValue bool } -func (m *MaxAggrAccumulator) Update(value float64) { +func (m *maxAggrAccumulator) Update(value float64) { if !m.firstValue { m.maxV = value m.firstValue = true @@ -94,49 +94,49 @@ func (m *MaxAggrAccumulator) Update(value float64) { } m.maxV = max(m.maxV, value) } -func (m *MaxAggrAccumulator) Finalize() float64 { return m.maxV } +func (m *maxAggrAccumulator) Finalize() float64 { return m.maxV } -func NewCountAggr() accumulator { - return &CountAggrAccumulator{} +func newCountAggr() accumulator { + return &countAggrAccumulator{} } -type CountAggrAccumulator struct { +type countAggrAccumulator struct { count float64 } -func (c *CountAggrAccumulator) Update(_ float64) { +func (c *countAggrAccumulator) Update(_ float64) { c.count++ } -func (c *CountAggrAccumulator) Finalize() float64 { return c.count } +func (c *countAggrAccumulator) Finalize() float64 { return c.count } -func NewSumAggr() accumulator { - return &SumAggrAccumulator{} +func newSumAggr() accumulator { + return &sumAggrAccumulator{} } -type SumAggrAccumulator 
struct { +type sumAggrAccumulator struct { summation float64 } -func (s *SumAggrAccumulator) Update(value float64) { +func (s *sumAggrAccumulator) Update(value float64) { s.summation += value } -func (s *SumAggrAccumulator) Finalize() float64 { return s.summation } +func (s *sumAggrAccumulator) Finalize() float64 { return s.summation } func newAvgAggr() accumulator { - return &AvgAggrAccumulator{} + return &avgAggrAccumulator{} } -type AvgAggrAccumulator struct { +type avgAggrAccumulator struct { used bool values float64 count float64 } -func (a *AvgAggrAccumulator) Update(value float64) { +func (a *avgAggrAccumulator) Update(value float64) { a.used = true a.values += value a.count++ } -func (a *AvgAggrAccumulator) Finalize() float64 { +func (a *avgAggrAccumulator) Finalize() float64 { // handles divide by zero if !a.used { return 0.0 @@ -174,10 +174,10 @@ func NewGlobalAggrExec(child operators.Operator, aggExprs []AggregateFunctions) accs[i] = newMaxAggr() case Count: fieldName = fmt.Sprintf("count_%s", agg.Child.String()) - accs[i] = NewCountAggr() + accs[i] = newCountAggr() case Sum: fieldName = fmt.Sprintf("sum_%s", agg.Child.String()) - accs[i] = NewSumAggr() + accs[i] = newSumAggr() case Avg: fieldName = fmt.Sprintf("avg_%s", agg.Child.String()) accs[i] = newAvgAggr() @@ -199,13 +199,9 @@ func NewGlobalAggrExec(child operators.Operator, aggExprs []AggregateFunctions) }, nil } -// check for io.EOF with flag -// read in all record batches -// for each batch, run Expr.Evaluate, to get the column you want for the expression (cast to float64) -// -// for each element of that column grab the values you want using the accumulator interface -// -// build output batch, for now its just 1 of everything straight forward +// Next consumes all batches from the child operator, evaluates the aggregate expressions, +// updates the accumulators for each value, and returns a single output batch containing +// the final aggregation results. 
It returns io.EOF after producing the result batch. func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { if a.done { return nil, io.EOF @@ -249,7 +245,6 @@ func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { Columns: resultColumns, RowCount: 1, }, nil - // this is a pipeline breaker so it will always consume all of the input which means this needs to return an io.EOF } func (a *AggrExec) Schema() *arrow.Schema { diff --git a/src/Backend/opti-sql-go/operators/test/t1_test.go b/src/Backend/opti-sql-go/operators/test/t1_test.go index 56e5404..a571421 100644 --- a/src/Backend/opti-sql-go/operators/test/t1_test.go +++ b/src/Backend/opti-sql-go/operators/test/t1_test.go @@ -1 +1,3 @@ package test + +// test for all operators together From 7352d28e00088c53310808d42462dd46b4a59ef5 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Wed, 26 Nov 2025 13:20:39 -0500 Subject: [PATCH 06/10] fix:removed array memory leaks --- src/Backend/opti-sql-go/Expr/expr.go | 3 --- src/Backend/opti-sql-go/operators/aggr/groupBy.go | 12 ++++++++++++ src/Backend/opti-sql-go/operators/aggr/having.go | 4 +--- src/Backend/opti-sql-go/operators/aggr/singleAggr.go | 1 + .../opti-sql-go/operators/aggr/singleAggr_test.go | 1 - src/Backend/opti-sql-go/operators/filter/filter.go | 5 ++--- src/Backend/opti-sql-go/operators/record.go | 7 +++++++ 7 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/Backend/opti-sql-go/Expr/expr.go b/src/Backend/opti-sql-go/Expr/expr.go index e27d179..f9d88de 100644 --- a/src/Backend/opti-sql-go/Expr/expr.go +++ b/src/Backend/opti-sql-go/Expr/expr.go @@ -387,7 +387,6 @@ func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error } rightArr, err := EvalExpression(b.Right, batch) if err != nil { - fmt.Printf("right side evaluation failed with %v", err) return nil, err } opt := compute.ArithmeticOptions{} @@ -496,7 +495,6 @@ func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error return 
unpackDatum(datum) case Like: if leftArr.DataType() != arrow.BinaryTypes.String || rightArr.DataType() != arrow.BinaryTypes.String { - // regEx runs only on strings return nil, errors.New("binary operator Like only works on arrays of strings") } var compiledRegEx = compileSqlRegEx(rightArr.ValueStr(0)) @@ -504,7 +502,6 @@ func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error leftStrArray := leftArr.(*array.String) for i := 0; i < leftStrArray.Len(); i++ { valid := validRegEx(leftStrArray.Value(i), compiledRegEx) - fmt.Printf("does %s match %s: %v\n", leftStrArray.Value(i), compiledRegEx, valid) filterBuilder.Append(valid) } return filterBuilder.NewArray(), nil diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy.go b/src/Backend/opti-sql-go/operators/aggr/groupBy.go index 7c57b28..5e65bfb 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy.go @@ -75,6 +75,8 @@ func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { for i, expr := range g.groupByExpr { arr, err := Expr.EvalExpression(expr, childBatch) if err != nil { + operators.ReleaseArrays(groupArrays) + operators.ReleaseArrays(childBatch.Columns) return nil, err } groupArrays[i] = arr @@ -85,10 +87,16 @@ func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { for i, agg := range g.groupExpr { arr, err := Expr.EvalExpression(agg.Child, childBatch) if err != nil { + operators.ReleaseArrays(aggrArrays) + operators.ReleaseArrays(groupArrays) + operators.ReleaseArrays(childBatch.Columns) return nil, err } arr, err = castArrayToFloat64(arr) if err != nil { + operators.ReleaseArrays(aggrArrays) + operators.ReleaseArrays(groupArrays) + operators.ReleaseArrays(childBatch.Columns) return nil, err } aggrArrays[i] = arr @@ -125,6 +133,10 @@ func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { g.groups[key][i].Update(val) } } + // 4. 
release temp arrays + operators.ReleaseArrays(aggrArrays) + operators.ReleaseArrays(groupArrays) + operators.ReleaseArrays(childBatch.Columns) } // 4. Build output RecordBatch diff --git a/src/Backend/opti-sql-go/operators/aggr/having.go b/src/Backend/opti-sql-go/operators/aggr/having.go index 3f47233..a2aeb63 100644 --- a/src/Backend/opti-sql-go/operators/aggr/having.go +++ b/src/Backend/opti-sql-go/operators/aggr/having.go @@ -60,9 +60,7 @@ func (h *HavingExec) Next(n uint16) (*operators.RecordBatch, error) { } } // release old columns - for _, c := range batch.Columns { - c.Release() - } + operators.ReleaseArrays(batch.Columns) size := uint64(filteredCol[0].Len()) return &operators.RecordBatch{ diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go index 9593ca3..0d1db36 100644 --- a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go @@ -233,6 +233,7 @@ func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { } } + operators.ReleaseArrays(childBatch.Columns) } // build array with just the result of the column resultColumns := make([]arrow.Array, len(a.accumulators)) diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go index 192630d..9b5af24 100644 --- a/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr_test.go @@ -130,7 +130,6 @@ func aggProject() *project.InMemorySource { return p } -// TODO: add test that check for null func aggProjectNull() *project.InMemorySource { names, arr := generateAggTestColumnsWithNulls(memory.NewGoAllocator()) p, _ := project.NewInMemoryProjectExecFromArrays(names, arr) diff --git a/src/Backend/opti-sql-go/operators/filter/filter.go b/src/Backend/opti-sql-go/operators/filter/filter.go index 645eeeb..e93a1c8 100644 --- 
a/src/Backend/opti-sql-go/operators/filter/filter.go +++ b/src/Backend/opti-sql-go/operators/filter/filter.go @@ -60,10 +60,9 @@ func (f *FilterExec) Next(n uint16) (*operators.RecordBatch, error) { return nil, err } } + booleanMask.Release() // release old columns - for _, c := range batch.Columns { - c.Release() - } + operators.ReleaseArrays(batch.Columns) size := uint64(filteredCol[0].Len()) return &operators.RecordBatch{ diff --git a/src/Backend/opti-sql-go/operators/record.go b/src/Backend/opti-sql-go/operators/record.go index d1f81a6..24c6da7 100644 --- a/src/Backend/opti-sql-go/operators/record.go +++ b/src/Backend/opti-sql-go/operators/record.go @@ -290,6 +290,13 @@ func (rbb *RecordBatchBuilder) GenLargeBinaryArray(values ...[]byte) arrow.Array } return builder.NewArray() } +func ReleaseArrays(a []arrow.Array) { + for _, col := range a { + if col != nil { + col.Release() + } + } +} func (rb *RecordBatch) PrettyPrint() string { if rb == nil { From 139c88c9899b758c44279739ec517ff021904e98 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Wed, 26 Nov 2025 13:22:24 -0500 Subject: [PATCH 07/10] fix:added naming convention for child input record batch --- src/Backend/opti-sql-go/operators/project/projectExec.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/Backend/opti-sql-go/operators/project/projectExec.go b/src/Backend/opti-sql-go/operators/project/projectExec.go index 9d93d96..3df1fee 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec.go @@ -94,9 +94,7 @@ func (p *ProjectExec) Next(n uint16) (*operators.RecordBatch, error) { outPutCols[i] = arr arr.Retain() } - for _, c := range childBatch.Columns { - c.Release() - } + operators.ReleaseArrays(childBatch.Columns) return &operators.RecordBatch{ Schema: &p.outputschema, Columns: outPutCols, From 1e48e9da1f1ff100fa95d6b1c5e74377ed47dde2 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Wed, 26 Nov 2025 
13:26:15 -0500 Subject: [PATCH 08/10] closes #25 and closes #24 --- src/Backend/opti-sql-go/operators/filter/filter.go | 12 ++++++------ src/Backend/opti-sql-go/operators/record.go | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Backend/opti-sql-go/operators/filter/filter.go b/src/Backend/opti-sql-go/operators/filter/filter.go index e93a1c8..6c30c8f 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter.go +++ b/src/Backend/opti-sql-go/operators/filter/filter.go @@ -41,11 +41,11 @@ func (f *FilterExec) Next(n uint16) (*operators.RecordBatch, error) { if f.done { return nil, io.EOF } - batch, err := f.input.Next(n) + childBatch, err := f.input.Next(n) if err != nil { return nil, err } - booleanMask, err := Expr.EvalExpression(f.predicate, batch) + booleanMask, err := Expr.EvalExpression(f.predicate, childBatch) if err != nil { return nil, err } @@ -53,8 +53,8 @@ func (f *FilterExec) Next(n uint16) (*operators.RecordBatch, error) { if !ok { return nil, errors.New("predicate did not evaluate to boolean array") } - filteredCol := make([]arrow.Array, len(batch.Columns)) - for i, col := range batch.Columns { + filteredCol := make([]arrow.Array, len(childBatch.Columns)) + for i, col := range childBatch.Columns { filteredCol[i], err = ApplyBooleanMask(col, boolArr) if err != nil { return nil, err @@ -62,11 +62,11 @@ func (f *FilterExec) Next(n uint16) (*operators.RecordBatch, error) { } booleanMask.Release() // release old columns - operators.ReleaseArrays(batch.Columns) + operators.ReleaseArrays(childBatch.Columns) size := uint64(filteredCol[0].Len()) return &operators.RecordBatch{ - Schema: batch.Schema, + Schema: childBatch.Schema, Columns: filteredCol, RowCount: size, }, nil diff --git a/src/Backend/opti-sql-go/operators/record.go b/src/Backend/opti-sql-go/operators/record.go index 24c6da7..6678ef4 100644 --- a/src/Backend/opti-sql-go/operators/record.go +++ b/src/Backend/opti-sql-go/operators/record.go @@ -24,7 +24,7 @@ type Operator 
interface { type RecordBatch struct { Schema *arrow.Schema Columns []arrow.Array - RowCount uint64 // TODO: update to actually use this, in all operators + RowCount uint64 // } type SchemaBuilder struct { From 413f2062b3b7f9e1b4cc2efd244edd9e2e3da1fe Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Thu, 27 Nov 2025 13:48:18 -0500 Subject: [PATCH 09/10] feat:implement basic sort operator --- .../opti-sql-go/operators/aggr/groupBy.go | 8 +- .../opti-sql-go/operators/aggr/having.go | 12 +- .../opti-sql-go/operators/aggr/singleAggr.go | 8 +- .../opti-sql-go/operators/aggr/sort.go | 338 ++++++++++++++++++ .../opti-sql-go/operators/aggr/sort_test.go | 126 ++++++- .../opti-sql-go/operators/project/parquet.go | 8 +- .../operators/project/projectExec.go | 8 +- 7 files changed, 483 insertions(+), 25 deletions(-) diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy.go b/src/Backend/opti-sql-go/operators/aggr/groupBy.go index 5e65bfb..962a450 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy.go @@ -25,7 +25,7 @@ var ( // place all unique elements of the group by column into a hash table, each element gets their own Accumulator instance type GroupByExec struct { - child operators.Operator + input operators.Operator schema *arrow.Schema groupExpr []AggregateFunctions groupByExpr []Expr.Expression // column names @@ -42,7 +42,7 @@ func NewGroupByExec(child operators.Operator, groupExpr []AggregateFunctions, gr } return &GroupByExec{ - child: child, + input: child, schema: s, groupExpr: groupExpr, groupByExpr: groupBy, @@ -60,7 +60,7 @@ func (g *GroupByExec) Next(batchSize uint16) (*operators.RecordBatch, error) { } for { - childBatch, err := g.child.Next(batchSize) + childBatch, err := g.input.Next(batchSize) if err != nil { if errors.Is(err, io.EOF) { break @@ -150,7 +150,7 @@ func (g *GroupByExec) Schema() *arrow.Schema { return g.schema } func (g *GroupByExec) Close() error { - return g.child.Close() + 
return g.input.Close() } // handles validation and building of schema for group by diff --git a/src/Backend/opti-sql-go/operators/aggr/having.go b/src/Backend/opti-sql-go/operators/aggr/having.go index a2aeb63..a2a559f 100644 --- a/src/Backend/opti-sql-go/operators/aggr/having.go +++ b/src/Backend/opti-sql-go/operators/aggr/having.go @@ -37,14 +37,14 @@ func (h *HavingExec) Next(n uint16) (*operators.RecordBatch, error) { if h.done { return nil, io.EOF } - batch, err := h.input.Next(n) + childBatch, err := h.input.Next(n) if err != nil { if errors.Is(err, io.EOF) { h.done = true } return nil, err } - booleanMask, err := Expr.EvalExpression(h.havingExpr, batch) + booleanMask, err := Expr.EvalExpression(h.havingExpr, childBatch) if err != nil { return nil, err } @@ -52,19 +52,19 @@ func (h *HavingExec) Next(n uint16) (*operators.RecordBatch, error) { if !ok { return nil, errors.New("having predicate did not evaluate to boolean array") } - filteredCol := make([]arrow.Array, len(batch.Columns)) - for i, col := range batch.Columns { + filteredCol := make([]arrow.Array, len(childBatch.Columns)) + for i, col := range childBatch.Columns { filteredCol[i], err = filter.ApplyBooleanMask(col, boolArr) if err != nil { return nil, err } } // release old columns - operators.ReleaseArrays(batch.Columns) + operators.ReleaseArrays(childBatch.Columns) size := uint64(filteredCol[0].Len()) return &operators.RecordBatch{ - Schema: batch.Schema, + Schema: childBatch.Schema, Columns: filteredCol, RowCount: size, }, nil diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go index 0d1db36..1fcccdd 100644 --- a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go @@ -149,7 +149,7 @@ func (a *avgAggrAccumulator) Finalize() float64 { // =================== // handles global aggregations without group by type AggrExec struct { - child operators.Operator // child operator 
+ input operators.Operator // child operator schema *arrow.Schema // output schema aggExpressions []AggregateFunctions // list of wanted aggregate expressions accumulators []accumulator // list of accumulators corresponding to aggExpressions, these will actually work to compute the aggregation @@ -192,7 +192,7 @@ func NewGlobalAggrExec(child operators.Operator, aggExprs []AggregateFunctions) } } return &AggrExec{ - child: child, + input: child, schema: arrow.NewSchema(fields, nil), aggExpressions: aggExprs, accumulators: accs, @@ -207,7 +207,7 @@ func (a *AggrExec) Next(n uint16) (*operators.RecordBatch, error) { return nil, io.EOF } for { - childBatch, err := a.child.Next(n) + childBatch, err := a.input.Next(n) if err != nil { if errors.Is(err, io.EOF) { break @@ -252,7 +252,7 @@ func (a *AggrExec) Schema() *arrow.Schema { return a.schema } func (a *AggrExec) Close() error { - return a.child.Close() + return a.input.Close() } func validAggrType(dt arrow.DataType) bool { diff --git a/src/Backend/opti-sql-go/operators/aggr/sort.go b/src/Backend/opti-sql-go/operators/aggr/sort.go index ed342a8..11ab431 100644 --- a/src/Backend/opti-sql-go/operators/aggr/sort.go +++ b/src/Backend/opti-sql-go/operators/aggr/sort.go @@ -1,3 +1,341 @@ package aggr +import ( + "context" + "errors" + "fmt" + "io" + "opti-sql-go/Expr" + "opti-sql-go/operators" + "sort" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/compute" + "github.com/apache/arrow/go/v17/arrow/memory" +) + // order by col asc, col 2 desc .... 
etc +var ( + _ = (operators.Operator)(&SortExec{}) + _ = (operators.Operator)(&TopKSortExec{}) +) + +type SortKey struct { + Expr Expr.Expression + Ascending bool // by default false -- DESC (highest values first -> smaller values) + NullFirst bool // by default false -- nulls last +} + +func NewSortKey(expr Expr.Expression, options ...bool) *SortKey { + var asc, nullF bool + switch len(options) { + case 2: + asc = options[0] + nullF = options[1] + case 1: + asc = options[0] + } + return &SortKey{ + Expr: expr, + Ascending: asc, + NullFirst: nullF, + } +} +func CombineSortKeys(sk ...*SortKey) []SortKey { + var res []SortKey + for _, s := range sk { + res = append(res, *s) + } + return res +} + +type SortExec struct { + child operators.Operator + schema *arrow.Schema + done bool + sortKeys []SortKey // resolves to columns +} + +func NewSortExec(child operators.Operator, sortKeys []SortKey) (*SortExec, error) { + fmt.Printf("sorts Keys %v\n", sortKeys) + return &SortExec{ + child: child, + schema: child.Schema(), + sortKeys: sortKeys, + }, nil +} + +// for now read everything into memory and sort -- next steps will be to do external merge +func (s *SortExec) Next(n uint16) (*operators.RecordBatch, error) { + if s.done { + return nil, io.EOF + } + allColumns := make([]arrow.Array, len(s.schema.Fields())) // concated columns + mem := memory.NewGoAllocator() + fmt.Printf("all columns init %v\n", allColumns) + var count uint64 + for { + childBatch, err := s.child.Next(n) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, err + } + for i := range childBatch.Columns { + if allColumns[i] == nil { + allColumns[i] = childBatch.Columns[i] + continue + } + largerArray, err := concatarr(allColumns[i], childBatch.Columns[i], mem) + if err != nil { + return nil, err + } + allColumns[i] = largerArray + } + } + if len(allColumns) > 0 { + count = uint64(allColumns[0].Len()) + } + idx := sortBatches(&operators.RecordBatch{ + Schema: s.schema, + Columns: 
allColumns, + RowCount: count, + }, s.sortKeys) + // now update all mappings + for i := range len(allColumns) { + tmpDatum, err := compute.Take(context.TODO(), *compute.DefaultTakeOptions(), compute.NewDatum(allColumns[i]), compute.NewDatum(toDatumFormat(idx, mem))) + if err != nil { + return nil, err + } + array, ok := tmpDatum.(*compute.ArrayDatum) + if !ok { + return nil, fmt.Errorf("non datum was returned from take") + } + allColumns[i] = array.MakeArray() + } + // TOOD: break this uo into N chunks + return &operators.RecordBatch{ + Schema: s.schema, + Columns: allColumns, + RowCount: count, + }, nil +} +func (s *SortExec) Schema() *arrow.Schema { + return s.schema +} +func (s *SortExec) Close() error { + return s.child.Close() +} + +/* +only sort and keep the top k elements in memory +*/ +type TopKSortExec struct { + child operators.Operator + schema *arrow.Schema + done bool + sortKeys []SortKey // resolves to columns + k uint16 // top k +} + +func NewTopKSortExec(child operators.Operator, sortKeys []SortKey, k uint16) (*TopKSortExec, error) { + fmt.Printf("sort keys %v\n", sortKeys) + return &TopKSortExec{ + child: child, + schema: child.Schema(), + sortKeys: sortKeys, + k: k, + }, nil +} + +// for now read everything into memory and sort -- next steps will be to do external merge +func (t *TopKSortExec) Next(n uint16) (*operators.RecordBatch, error) { + if t.done { + return nil, io.EOF + } + return nil, nil +} +func (t *TopKSortExec) Schema() *arrow.Schema { + return t.schema +} +func (t *TopKSortExec) Close() error { + return t.child.Close() +} + +/* +shared functions +*/ +func sortBatches(fullRC *operators.RecordBatch, sortKeys []SortKey) []uint64 { + keyColumns := make([]arrow.Array, len(sortKeys)) + for i, sk := range sortKeys { + arr, err := Expr.EvalExpression(sk.Expr, fullRC) + if err != nil { + panic(fmt.Sprintf("sort batches: failed to eval sort expression: %v", err)) + } + keyColumns[i] = arr + } + fmt.Printf("columns\n") + for i, k := range 
keyColumns { + fmt.Printf("%d:%v\n", i, k) + } + idVector := make([]uint64, fullRC.RowCount) + for i := 0; uint64(i) < fullRC.RowCount; i++ { + idVector[i] = uint64(i) + } + sortIndexVector(idVector, keyColumns, sortKeys) + fmt.Printf("old Id Vec:%v\n", idVector) + fmt.Printf("new ID vec: %v\n", idVector) + return idVector +} +func toRC() []arrow.Array { + return nil +} + +func concatarr(a arrow.Array, b arrow.Array, mem memory.Allocator) (arrow.Array, error) { + return array.Concatenate([]arrow.Array{a, b}, mem) + +} + +// sortIndexVector sorts idVec based on keyColumns + sortKeys. +// keyColumns[i] corresponds to sortKeys[i]. +func sortIndexVector(idVec []uint64, keyColumns []arrow.Array, sortKeys []SortKey) { + sort.Slice(idVec, func(a, b int) bool { + i := idVec[a] + j := idVec[b] + + // lexicographic: go through each sort key + for k, col := range keyColumns { + sk := sortKeys[k] + cmp := compareArrowValues(col, i, j) + + if cmp == 0 { + continue // equal → move to next key + } + + if sk.Ascending { + return cmp < 0 + } else { + return cmp > 0 + } + } + + // completely equal for all keys + return false + }) +} + +func compareArrowValues(col arrow.Array, i, j uint64) int { + // Handle nulls (treat as lowest value for now) + if col.IsNull(int(i)) && col.IsNull(int(j)) { + return 0 + } + if col.IsNull(int(i)) { + return -1 + } + if col.IsNull(int(j)) { + return 1 + } + + switch arr := col.(type) { + + case *array.String: + vi := arr.Value(int(i)) + vj := arr.Value(int(j)) + switch { + case vi < vj: + return -1 + case vi > vj: + return 1 + default: + return 0 + } + + case *array.Int8: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Int16: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Int32: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Int64: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return 
compareNumeric(vi, vj) + + case *array.Uint8: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Uint16: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Uint32: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Uint64: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareNumeric(vi, vj) + + case *array.Float32: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareFloat(vi, vj) + + case *array.Float64: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + return compareFloat(vi, vj) + + case *array.Boolean: + vi, vj := arr.Value(int(i)), arr.Value(int(j)) + if vi == vj { + return 0 + } + if !vi && vj { + return -1 + } + return 1 + + default: + panic("unsupported Arrow type in compareArrowValues") + } +} + +func compareNumeric[T int64 | int32 | int16 | int8 | uint64 | uint32 | uint16 | uint8](a, b T) int { + switch { + case a < b: + return -1 + case a > b: + return 1 + default: + return 0 + } +} + +func compareFloat[T float32 | float64](a, b T) int { + switch { + case a < b: + return -1 + case a > b: + return 1 + default: + return 0 + } +} +func toDatumFormat(v []uint64, mem memory.Allocator) compute.Datum { + // turn to array first + b := array.NewUint64Builder(mem) + defer b.Release() + for _, val := range v { + b.Append(val) + } + arr := b.NewArray() + defer arr.Release() + return compute.NewDatum(arr) +} diff --git a/src/Backend/opti-sql-go/operators/aggr/sort_test.go b/src/Backend/opti-sql-go/operators/aggr/sort_test.go index b919b31..9ae02ab 100644 --- a/src/Backend/opti-sql-go/operators/aggr/sort_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/sort_test.go @@ -1,7 +1,129 @@ package aggr -import "testing" +import ( + "context" + "fmt" + "io" + "opti-sql-go/Expr" + "testing" -func TestSort(t *testing.T) { + "github.com/apache/arrow/go/v17/arrow/compute" +) + +func TestSortInit(t *testing.T) { // 
Simple passing test + t.Run("sort Exec init", func(t *testing.T) { + proj := aggProject() + sortExec, err := NewSortExec(proj, nil) + if err != nil { + t.Fatal(err) + } + if !sortExec.Schema().Equal(proj.Schema()) { + t.Fatalf("expected schema %v, got %v", proj.Schema(), sortExec.schema) + } + sortExec.done = true + _, err = sortExec.Next(100) + if err != io.EOF { + t.Fatalf("expected io.EOF error on done sortExec but got %v", err) + } + if sortExec.Close() != nil { + t.Fatalf("expected nil error on close but got %v", sortExec.Close()) + } + + }) + t.Run("tok k sort exec init", func(t *testing.T) { + proj := aggProject() + topKVal := 5 + topK, err := NewTopKSortExec(proj, nil, uint16(topKVal)) + if err != nil { + t.Fatal(err) + } + if !topK.Schema().Equal(proj.Schema()) { + t.Fatalf("expected schema %v, got %v", proj.Schema(), topK.schema) + } + if topK.k != 5 { + t.Fatalf("expected %v for top k but got %v", topKVal, topK.k) + } + topK.done = true + _, err = topK.Next(100) + if err != io.EOF { + t.Fatalf("expected io.EOF error on done topK but got %v", err) + } + if topK.Close() != nil { + t.Fatalf("expected nil error on close but got %v", topK.Close()) + } + + }) +} + +func TestBasicSortExpr(t *testing.T) { + t.Run("Sort", func(t *testing.T) { + proj := aggProject() + nameExpr := Expr.NewColumnResolve("name") + nameSK := NewSortKey(nameExpr, true) + ageExpr := Expr.NewColumnResolve("age") + ageSK := NewSortKey(ageExpr, false) + _, err := NewSortExec(proj, CombineSortKeys(nameSK, ageSK)) + if err != nil { + t.Fatalf("unexpected error from NewSortExec : %v\n", err) + } + //t.Logf("%v\n", sortExec) + }) + t.Run("Basic Next operation", func(t *testing.T) { + proj := aggProject() + nameExpr := Expr.NewColumnResolve("name") + nameSK := NewSortKey(nameExpr, true) + ageExpr := Expr.NewColumnResolve("age") + ageSK := NewSortKey(ageExpr, false) + sortExec, err := NewSortExec(proj, CombineSortKeys(ageSK, nameSK)) + if err != nil { + t.Fatalf("unexpected error from 
NewSortExec : %v\n", err) + } + sortedBatch, err := sortExec.Next(10) + if err != nil { + t.Fatalf("unexpected error from sortExec Next : %v\n", err) + } + fmt.Println(sortedBatch.PrettyPrint()) + + }) +} +func TestBasicTopKSortExpr(t *testing.T) { + t.Run("TopK Sort", func(t *testing.T) { + proj := aggProject() + nameExpr := Expr.NewColumnResolve("name") + nameSK := NewSortKey(nameExpr, true) + ageExpr := Expr.NewColumnResolve("age") + ageSK := NewSortKey(ageExpr, false) + sortExec, err := NewTopKSortExec(proj, CombineSortKeys(nameSK, ageSK), 5) + if err != nil { + t.Fatalf("unexpected error from NewTopKSortExec : %v\n", err) + } + t.Logf("%v\n", sortExec) + + }) +} + +func TestOne(t *testing.T) { + v := compute.GetExecCtx(context.Background()) + names := v.Registry.GetFunctionNames() + for i, name := range names { + fmt.Printf("%d: %v\n", i, name) + } + /* + mem := memory.NewGoAllocator() + floatB := array.NewFloat64Builder(mem) + floatB.AppendValues([]float64{10.5, 20.3, 30.1, 40.7, 50.2}, []bool{true, true, true, true, true}) + pos := array.NewInt32Builder(mem) + pos.AppendValues([]int32{1, 3, 4}, []bool{true, true, true}) + + dat, err := compute.Take(context.TODO(), *compute.DefaultTakeOptions(), compute.NewDatum(floatB.NewArray()), compute.NewDatum(pos.NewArray())) + if err != nil { + t.Fatalf("Take failed: %v", err) + } + array, ok := dat.(*compute.ArrayDatum) + if !ok { + t.Logf("expected an array to be returned but got something else %T\n", dat) + } + t.Logf("data: %v\n", array.MakeArray()) + */ } diff --git a/src/Backend/opti-sql-go/operators/project/parquet.go b/src/Backend/opti-sql-go/operators/project/parquet.go index 94b6e1d..42d5c14 100644 --- a/src/Backend/opti-sql-go/operators/project/parquet.go +++ b/src/Backend/opti-sql-go/operators/project/parquet.go @@ -22,12 +22,10 @@ var ( ) type ParquetSource struct { - // existing fields schema *arrow.Schema projectionPushDown []string // columns to project up reader pqarrow.RecordReader - // for internal 
reading - done bool // if set to true always return io.EOF + done bool // if set to true always return io.EOF } func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { @@ -45,7 +43,7 @@ func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { arrowReader, err := pqarrow.NewFileReader( filerReader, - pqarrow.ArrowReadProperties{Parallel: true, BatchSize: int64(Config.Batch.Size)}, // TODO: Read in from config for this stuff + pqarrow.ArrowReadProperties{Parallel: true, BatchSize: int64(Config.Batch.Size)}, allocator, ) if err != nil { @@ -84,7 +82,7 @@ func NewParquetSourcePushDown(r parquet.ReaderAtSeeker, columns []string) (*Parq arrowReader, err := pqarrow.NewFileReader( filerReader, - pqarrow.ArrowReadProperties{Parallel: true, BatchSize: int64(Config.Batch.Size)}, // TODO: Read in from config for this stuff + pqarrow.ArrowReadProperties{Parallel: true, BatchSize: int64(Config.Batch.Size)}, allocator, ) if err != nil { diff --git a/src/Backend/opti-sql-go/operators/project/projectExec.go b/src/Backend/opti-sql-go/operators/project/projectExec.go index 3df1fee..033a58c 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec.go @@ -20,7 +20,7 @@ var ( ) type ProjectExec struct { - child operators.Operator + input operators.Operator outputschema arrow.Schema expr []Expr.Expression done bool @@ -60,7 +60,7 @@ func NewProjectExec(input operators.Operator, exprs []Expr.Expression) (*Project outputschema := arrow.NewSchema(fields, nil) // return new exec return &ProjectExec{ - child: input, + input: input, outputschema: *outputschema, expr: exprs, }, nil @@ -73,7 +73,7 @@ func (p *ProjectExec) Next(n uint16) (*operators.RecordBatch, error) { return nil, io.EOF } - childBatch, err := p.child.Next(n) + childBatch, err := p.input.Next(n) if err != nil { return nil, err } @@ -102,7 +102,7 @@ func (p *ProjectExec) Next(n uint16) (*operators.RecordBatch, error) { }, 
nil } func (p *ProjectExec) Close() error { - return p.child.Close() + return p.input.Close() } func (p *ProjectExec) Schema() *arrow.Schema { return &p.outputschema From eb30dec9a8f31ddbddb968fcfcaa93300be2dfaa Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Thu, 27 Nov 2025 15:01:03 -0500 Subject: [PATCH 10/10] feat:Full-Sort operator returns results in batches| TODO:Top K sort operator --- src/Backend/opti-sql-go/go.mod | 3 +- src/Backend/opti-sql-go/go.sum | 2 + .../opti-sql-go/operators/aggr/sort.go | 153 +++--- .../opti-sql-go/operators/aggr/sort_test.go | 490 +++++++++++++++++- 4 files changed, 585 insertions(+), 63 deletions(-) diff --git a/src/Backend/opti-sql-go/go.mod b/src/Backend/opti-sql-go/go.mod index c9ee239..5b872b6 100644 --- a/src/Backend/opti-sql-go/go.mod +++ b/src/Backend/opti-sql-go/go.mod @@ -1,6 +1,6 @@ module opti-sql-go -go 1.23 +go 1.24.0 require ( github.com/apache/arrow/go/v15 v15.0.2 @@ -28,6 +28,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.13 // indirect github.com/aws/smithy-go v1.23.2 // indirect github.com/go-ini/ini v1.67.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect diff --git a/src/Backend/opti-sql-go/go.sum b/src/Backend/opti-sql-go/go.sum index 9c4220d..7c4ee5c 100644 --- a/src/Backend/opti-sql-go/go.sum +++ b/src/Backend/opti-sql-go/go.sum @@ -37,6 +37,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod 
h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= diff --git a/src/Backend/opti-sql-go/operators/aggr/sort.go b/src/Backend/opti-sql-go/operators/aggr/sort.go index 11ab431..60d0cb5 100644 --- a/src/Backend/opti-sql-go/operators/aggr/sort.go +++ b/src/Backend/opti-sql-go/operators/aggr/sort.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "math" "opti-sql-go/Expr" "opti-sql-go/operators" "sort" @@ -53,8 +54,13 @@ func CombineSortKeys(sk ...*SortKey) []SortKey { type SortExec struct { child operators.Operator schema *arrow.Schema - done bool sortKeys []SortKey // resolves to columns + // internal book keeping + totalColumns []arrow.Array + consumedOffset uint64 + totalRows uint64 + consumed bool // did we finish reading all of the child record batches? + done bool // have we already produced all the sorted record batches? 
} func NewSortExec(child operators.Operator, sortKeys []SortKey) (*SortExec, error) { @@ -67,59 +73,80 @@ func NewSortExec(child operators.Operator, sortKeys []SortKey) (*SortExec, error } // for now read everything into memory and sort -- next steps will be to do external merge + +// n is the number of records we will return, sortExec will read in 2^16-1 column entries from its child, this is more efficient than trusting the caller to pass in a reasonable +// n so that we avoid small/frequent IO operations func (s *SortExec) Next(n uint16) (*operators.RecordBatch, error) { if s.done { return nil, io.EOF } - allColumns := make([]arrow.Array, len(s.schema.Fields())) // concated columns - mem := memory.NewGoAllocator() - fmt.Printf("all columns init %v\n", allColumns) - var count uint64 - for { - childBatch, err := s.child.Next(n) - if err != nil { - if errors.Is(err, io.EOF) { - break + if !s.consumed { + allColumns := make([]arrow.Array, len(s.schema.Fields())) // concated columns + mem := memory.NewGoAllocator() + var count uint64 + for { + childBatch, err := s.child.Next(math.MaxUint16) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, err + } + for i := range childBatch.Columns { + if allColumns[i] == nil { + allColumns[i] = childBatch.Columns[i] + continue + } + largerArray, err := array.Concatenate([]arrow.Array{allColumns[i], childBatch.Columns[i]}, mem) + if err != nil { + return nil, err + } + allColumns[i] = largerArray } + } + s.consumed = true + if len(allColumns) > 0 { + count = uint64(allColumns[0].Len()) + } + idx, err := sortBatches(&operators.RecordBatch{ + Schema: s.schema, + Columns: allColumns, + RowCount: count, + }, s.sortKeys) + if err != nil { return nil, err } - for i := range childBatch.Columns { - if allColumns[i] == nil { - allColumns[i] = childBatch.Columns[i] - continue - } - largerArray, err := concatarr(allColumns[i], childBatch.Columns[i], mem) + // now update all mappings + for i := range len(allColumns) { + 
arr, err := compute.TakeArray(context.TODO(), allColumns[i], idxToArrowArray(idx, mem)) if err != nil { return nil, err } - allColumns[i] = largerArray + allColumns[i] = arr } + s.totalColumns = allColumns + s.totalRows = count } - if len(allColumns) > 0 { - count = uint64(allColumns[0].Len()) + var readSize uint64 + remaining := s.totalRows - s.consumedOffset + if remaining < uint64(n) { + // if n is more than we have left just read up to remaining + readSize = uint64(remaining) + s.done = true + } else { + // remaining > n or remaining = n then just read n and return + readSize = uint64(n) } - idx := sortBatches(&operators.RecordBatch{ - Schema: s.schema, - Columns: allColumns, - RowCount: count, - }, s.sortKeys) - // now update all mappings - for i := range len(allColumns) { - tmpDatum, err := compute.Take(context.TODO(), *compute.DefaultTakeOptions(), compute.NewDatum(allColumns[i]), compute.NewDatum(toDatumFormat(idx, mem))) - if err != nil { - return nil, err - } - array, ok := tmpDatum.(*compute.ArrayDatum) - if !ok { - return nil, fmt.Errorf("non datum was returned from take") - } - allColumns[i] = array.MakeArray() + mem := memory.NewGoAllocator() + sortedColumns, err := s.consumeSortedBatch(readSize, mem) + if err != nil { + return nil, err } - // TOOD: break this uo into N chunks + return &operators.RecordBatch{ Schema: s.schema, - Columns: allColumns, - RowCount: count, + Columns: sortedColumns, + RowCount: readSize, }, nil } func (s *SortExec) Schema() *arrow.Schema { @@ -128,6 +155,22 @@ func (s *SortExec) Schema() *arrow.Schema { func (s *SortExec) Close() error { return s.child.Close() } +func (s *SortExec) consumeSortedBatch(readsize uint64, mem memory.Allocator) ([]arrow.Array, error) { + ctx := context.TODO() + resultColumns := make([]arrow.Array, len(s.schema.Fields())) + offsetArray := genoffsetTakeIdx(s.consumedOffset, readsize, mem) + for i := range s.totalColumns { + sortArr := s.totalColumns[i] + arr, err := compute.TakeArray(ctx, sortArr, 
offsetArray) + if err != nil { + return nil, err + } + resultColumns[i] = arr + + } + s.consumedOffset += readsize + return resultColumns, nil +} /* only sort and keep the top k elements in memory @@ -167,35 +210,21 @@ func (t *TopKSortExec) Close() error { /* shared functions */ -func sortBatches(fullRC *operators.RecordBatch, sortKeys []SortKey) []uint64 { +func sortBatches(fullRC *operators.RecordBatch, sortKeys []SortKey) ([]uint64, error) { keyColumns := make([]arrow.Array, len(sortKeys)) for i, sk := range sortKeys { arr, err := Expr.EvalExpression(sk.Expr, fullRC) if err != nil { - panic(fmt.Sprintf("sort batches: failed to eval sort expression: %v", err)) + return nil, fmt.Errorf("sort batches: failed to eval sort expression: %v", err) } keyColumns[i] = arr } - fmt.Printf("columns\n") - for i, k := range keyColumns { - fmt.Printf("%d:%v\n", i, k) - } idVector := make([]uint64, fullRC.RowCount) for i := 0; uint64(i) < fullRC.RowCount; i++ { idVector[i] = uint64(i) } sortIndexVector(idVector, keyColumns, sortKeys) - fmt.Printf("old Id Vec:%v\n", idVector) - fmt.Printf("new ID vec: %v\n", idVector) - return idVector -} -func toRC() []arrow.Array { - return nil -} - -func concatarr(a arrow.Array, b arrow.Array, mem memory.Allocator) (arrow.Array, error) { - return array.Concatenate([]arrow.Array{a, b}, mem) - + return idVector, nil } // sortIndexVector sorts idVec based on keyColumns + sortKeys. 
@@ -328,14 +357,20 @@ func compareFloat[T float32 | float64](a, b T) int { return 0 } } -func toDatumFormat(v []uint64, mem memory.Allocator) compute.Datum { +func idxToArrowArray(v []uint64, mem memory.Allocator) arrow.Array { // turn to array first b := array.NewUint64Builder(mem) - defer b.Release() for _, val := range v { b.Append(val) } arr := b.NewArray() - defer arr.Release() - return compute.NewDatum(arr) + return arr +} +func genoffsetTakeIdx(offset, size uint64, mem memory.Allocator) arrow.Array { + b := array.NewUint64Builder(mem) + for i := range size { + b.Append(offset + i) + } + arr := b.NewArray() + return arr } diff --git a/src/Backend/opti-sql-go/operators/aggr/sort_test.go b/src/Backend/opti-sql-go/operators/aggr/sort_test.go index 9ae02ab..95754c8 100644 --- a/src/Backend/opti-sql-go/operators/aggr/sort_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/sort_test.go @@ -2,14 +2,63 @@ package aggr import ( "context" + "errors" "fmt" "io" "opti-sql-go/Expr" + "opti-sql-go/operators" + "opti-sql-go/operators/project" "testing" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" "github.com/apache/arrow/go/v17/arrow/compute" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/go-jose/go-jose/v4/testutils/require" ) +func buildAggTestRecordBatch(t *testing.T) *operators.RecordBatch { + names, cols := generateAggTestColumns() + mem := memory.NewGoAllocator() + + arrowCols := make([]arrow.Array, len(cols)) + fields := make([]arrow.Field, len(cols)) + + for i, col := range cols { + switch v := col.(type) { + + case []int32: + b := array.NewInt32Builder(mem) + defer b.Release() + b.AppendValues(v, nil) + arrowCols[i] = b.NewArray() + + case []string: + b := array.NewStringBuilder(mem) + defer b.Release() + b.AppendValues(v, nil) + arrowCols[i] = b.NewArray() + + case []float64: + b := array.NewFloat64Builder(mem) + defer b.Release() + b.AppendValues(v, nil) + arrowCols[i] = b.NewArray() + + default: + 
t.Fatalf("unsupported type in generateAggTestColumns") + } + + fields[i] = arrow.Field{Name: names[i], Type: arrowCols[i].DataType()} + } + + return &operators.RecordBatch{ + Schema: arrow.NewSchema(fields, nil), + Columns: arrowCols, + RowCount: uint64(len(cols[0].([]int32))), + } +} + func TestSortInit(t *testing.T) { // Simple passing test t.Run("sort Exec init", func(t *testing.T) { @@ -30,6 +79,14 @@ func TestSortInit(t *testing.T) { t.Fatalf("expected nil error on close but got %v", sortExec.Close()) } + }) + t.Run("SortKey options", func(t *testing.T) { + proj := aggProject() + _, err := NewSortExec(proj, []SortKey{*NewSortKey(col("-"), false, false)}) + if err != nil { + t.Fatal(err) + } + }) t.Run("tok k sort exec init", func(t *testing.T) { proj := aggProject() @@ -79,14 +136,441 @@ func TestBasicSortExpr(t *testing.T) { if err != nil { t.Fatalf("unexpected error from NewSortExec : %v\n", err) } - sortedBatch, err := sortExec.Next(10) + for { + sortedBatch, err := sortExec.Next(5) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Fatalf("unexpected error from sortExec Next : %v\n", err) + } + fmt.Println(sortedBatch.PrettyPrint()) + } + }) +} +func TestFullSortOverNetwork(t *testing.T) { + t.Run("Full Sort of large file", func(t *testing.T) { + const fileName = "country_full.csv" + nr, err := project.NewStreamReader(fileName) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + pj, err := project.NewProjectCSVLeaf(nr.Stream()) if err != nil { - t.Fatalf("unexpected error from sortExec Next : %v\n", err) + t.Fatalf("failed to create csv project source from s3 object: %v", err) + } + nameExpr := Expr.NewColumnResolve("name") + nameSK := NewSortKey(nameExpr, true) + sortExec, err := NewSortExec(pj, CombineSortKeys(nameSK)) + if err != nil { + t.Fatalf("unexpected error %v\n", err) + } + rc, err := sortExec.Next(10) + if err != nil { + t.Fatalf("unexpected error %v\n", err) + } + fmt.Println(rc.PrettyPrint()) + + }) + +} + 
+func TestFullSortExec_Next(t *testing.T) { + t.Parallel() + + t.Run("sort_age_DESC", func(t *testing.T) { + proj := aggProject() + + ageExpr := Expr.NewColumnResolve("age") + ageSK := NewSortKey(ageExpr, false) // DESC + + sortExec, err := NewSortExec(proj, CombineSortKeys(ageSK)) + require.NoError(t, err) + + batch, err := sortExec.Next(5) + require.NoError(t, err) + require.Equal(t, uint64(5), batch.RowCount) + + ages := batch.Columns[2].(*array.Int32) + got := []int32{ + ages.Value(0), + ages.Value(1), + ages.Value(2), + ages.Value(3), + ages.Value(4), + } + + expected := []int32{50, 48, 46, 45, 43} + for i, v := range expected { + if got[i] != v { + t.Fatalf("expected %v at index %d, but got %v", v, i, got[i]) + } + } + }) + + t.Run("sort_name_ASC", func(t *testing.T) { + proj := aggProject() + + nameExpr := Expr.NewColumnResolve("name") + nameSK := NewSortKey(nameExpr, true) + + sortExec, err := NewSortExec(proj, CombineSortKeys(nameSK)) + require.NoError(t, err) + + batch, err := sortExec.Next(3) + require.NoError(t, err) + + names := batch.Columns[1].(*array.String) + got := []string{ + names.Value(0), + names.Value(1), + names.Value(2), + } + + expected := []string{"Alice", "Bob", "Charlie"} + for i, v := range expected { + if got[i] != v { + t.Fatalf("expected %v at index %d, but got %v", v, i, got[i]) + } + } + }) +} + +// ----------------------------------------------------------------------------- +// TEST 2: sortIndexVector() +// ----------------------------------------------------------------------------- + +func TestSortIndexVector(t *testing.T) { + t.Parallel() + + mem := memory.NewGoAllocator() + + t.Run("single_key_int", func(t *testing.T) { + b := array.NewInt32Builder(mem) + b.AppendValues([]int32{30, 10, 20}, nil) + arr := b.NewArray() + defer arr.Release() + + keys := []arrow.Array{arr} + idVec := []uint64{0, 1, 2} + + sks := []SortKey{ + {Expr: nil, Ascending: true}, + } + + sortIndexVector(idVec, keys, sks) + + expected := []uint64{1, 2, 0} 
+ for i, v := range expected { + if idVec[i] != v { + t.Fatalf("expected %v at index %d, but got %v", v, i, idVec[i]) + } + } + }) + + t.Run("single_key_string", func(t *testing.T) { + b := array.NewStringBuilder(mem) + b.AppendValues([]string{"Charlie", "Alice", "Bob"}, nil) + arr := b.NewArray() + defer arr.Release() + + keys := []arrow.Array{arr} + idVec := []uint64{0, 1, 2} + + sks := []SortKey{{Ascending: true}} + + sortIndexVector(idVec, keys, sks) + + expected := []uint64{1, 2, 0} + for i, v := range expected { + if idVec[i] != v { + t.Fatalf("expected %v at index %d, but got %v", v, i, idVec[i]) + } } - fmt.Println(sortedBatch.PrettyPrint()) + }) +} + +// ----------------------------------------------------------------------------- +// TEST 3: compareArrowValues() +// ----------------------------------------------------------------------------- + +func TestCompareArrowValues(t *testing.T) { + t.Parallel() + + mem := memory.NewGoAllocator() + t.Run("int", func(t *testing.T) { + b := array.NewInt32Builder(mem) + b.AppendValues([]int32{10, 20}, nil) + arr := b.NewArray() + defer arr.Release() + + require.Equal(t, -1, compareArrowValues(arr, 0, 1)) + require.Equal(t, 1, compareArrowValues(arr, 1, 0)) + require.Equal(t, 0, compareArrowValues(arr, 0, 0)) }) + + t.Run("uint", func(t *testing.T) { + b := array.NewUint32Builder(mem) + b.AppendValues([]uint32{5, 7}, nil) + arr := b.NewArray() + defer arr.Release() + + require.Equal(t, -1, compareArrowValues(arr, 0, 1)) + require.Equal(t, 1, compareArrowValues(arr, 1, 0)) + }) + + t.Run("float", func(t *testing.T) { + b := array.NewFloat64Builder(mem) + b.AppendValues([]float64{1.5, 1.7}, nil) + arr := b.NewArray() + defer arr.Release() + + require.Equal(t, -1, compareArrowValues(arr, 0, 1)) + require.Equal(t, 1, compareArrowValues(arr, 1, 0)) + }) + + t.Run("string", func(t *testing.T) { + b := array.NewStringBuilder(mem) + b.AppendValues([]string{"a", "b"}, nil) + arr := b.NewArray() + defer arr.Release() + + 
require.Equal(t, -1, compareArrowValues(arr, 0, 1)) + require.Equal(t, 1, compareArrowValues(arr, 1, 0)) + }) + + t.Run("bool", func(t *testing.T) { + b := array.NewBooleanBuilder(mem) + b.AppendValues([]bool{false, true}, nil) + arr := b.NewArray() + defer arr.Release() + + require.Equal(t, -1, compareArrowValues(arr, 0, 1)) + require.Equal(t, 1, compareArrowValues(arr, 1, 0)) + }) +} +func TestCompareArrowValues_AllTypes(t *testing.T) { + mem := memory.NewGoAllocator() + + // helper to assert cmp result + assert := func(name string, got, want int) { + if got != want { + t.Fatalf("%s: expected %d, got %d", name, want, got) + } + } + + // ---- STRING ---- + strB := array.NewStringBuilder(mem) + strB.Append("apple") + strB.Append("banana") + strArr := strB.NewArray().(*array.String) + + assert("string lt", compareArrowValues(strArr, 0, 1), -1) + assert("string gt", compareArrowValues(strArr, 1, 0), 1) + assert("string eq", compareArrowValues(strArr, 0, 0), 0) + + strArr.Release() + strB.Release() + + // ---- INT TYPES ---- + int8Arr := buildInt8(mem, []int8{1, 3}) + assert("int8 lt", compareArrowValues(int8Arr, 0, 1), -1) + assert("int8 gt", compareArrowValues(int8Arr, 1, 0), 1) + assert("int8 eq", compareArrowValues(int8Arr, 0, 0), 0) + int8Arr.Release() + + int16Arr := buildInt16(mem, []int16{5, 2}) + assert("int16 gt", compareArrowValues(int16Arr, 0, 1), 1) + int16Arr.Release() + + int32Arr := buildInt32(mem, []int32{10, 10}) + assert("int32 eq", compareArrowValues(int32Arr, 0, 1), 0) + int32Arr.Release() + + int64Arr := buildInt64(mem, []int64{-5, 7}) + assert("int64 lt", compareArrowValues(int64Arr, 0, 1), -1) + int64Arr.Release() + + // ---- UINT TYPES ---- + u8Arr := buildUint8(mem, []uint8{9, 3}) + assert("uint8 gt", compareArrowValues(u8Arr, 0, 1), 1) + u8Arr.Release() + + u16Arr := buildUint16(mem, []uint16{3, 3}) + assert("uint16 eq", compareArrowValues(u16Arr, 0, 1), 0) + u16Arr.Release() + + u32Arr := buildUint32(mem, []uint32{3, 10}) + assert("uint32 
lt", compareArrowValues(u32Arr, 0, 1), -1) + u32Arr.Release() + + u64Arr := buildUint64(mem, []uint64{100, 2}) + assert("uint64 gt", compareArrowValues(u64Arr, 0, 1), 1) + u64Arr.Release() + + // ---- FLOAT TYPES ---- + f32Arr := buildFloat32(mem, []float32{1.5, 1.5}) + assert("float32 eq", compareArrowValues(f32Arr, 0, 1), 0) + f32Arr.Release() + + f64Arr := buildFloat64(mem, []float64{-1.0, 2.3}) + assert("float64 lt", compareArrowValues(f64Arr, 0, 1), -1) + f64Arr.Release() + + // ---- BOOLEAN ---- + boolArr := buildBool(mem, []bool{false, true}) + assert("bool lt", compareArrowValues(boolArr, 0, 1), -1) + assert("bool gt", compareArrowValues(boolArr, 1, 0), 1) + assert("bool eq", compareArrowValues(boolArr, 1, 1), 0) + boolArr.Release() + + // ---- NULL CASES ---- + nullB := array.NewInt32Builder(mem) + nullB.AppendNull() + nullB.Append(10) + nullArr := nullB.NewArray().(*array.Int32) + + assert("null < value", compareArrowValues(nullArr, 0, 1), -1) + assert("value > null", compareArrowValues(nullArr, 1, 0), 1) + assert("null == null", compareArrowValues(nullArr, 0, 0), 0) + + nullArr.Release() + nullB.Release() + + // ---- UNSUPPORTED TYPE PANIC ---- + // Build a fixed-size binary array to trigger panic + fsb := array.NewFixedSizeBinaryBuilder(mem, &arrow.FixedSizeBinaryType{ByteWidth: 2}) + fsb.Append([]byte{1, 2}) + fsb.Append([]byte{3, 4}) + fsArr := fsb.NewArray() + + didPanic := false + func() { + defer func() { + if recover() != nil { + didPanic = true + } + }() + _ = compareArrowValues(fsArr, 0, 1) + }() + if !didPanic { + t.Fatalf("expected panic for unsupported Arrow type") + } + + fsArr.Release() + fsb.Release() +} +func buildInt8(mem memory.Allocator, vals []int8) *array.Int8 { + b := array.NewInt8Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Int8) + b.Release() + return arr +} + +func buildInt16(mem memory.Allocator, vals []int16) *array.Int16 { + b := array.NewInt16Builder(mem) + for _, v := range vals { + 
b.Append(v) + } + arr := b.NewArray().(*array.Int16) + b.Release() + return arr +} + +func buildInt32(mem memory.Allocator, vals []int32) *array.Int32 { + b := array.NewInt32Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Int32) + b.Release() + return arr +} + +func buildInt64(mem memory.Allocator, vals []int64) *array.Int64 { + b := array.NewInt64Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Int64) + b.Release() + return arr +} + +func buildUint8(mem memory.Allocator, vals []uint8) *array.Uint8 { + b := array.NewUint8Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Uint8) + b.Release() + return arr } + +func buildUint16(mem memory.Allocator, vals []uint16) *array.Uint16 { + b := array.NewUint16Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Uint16) + b.Release() + return arr +} + +func buildUint32(mem memory.Allocator, vals []uint32) *array.Uint32 { + b := array.NewUint32Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Uint32) + b.Release() + return arr +} + +func buildUint64(mem memory.Allocator, vals []uint64) *array.Uint64 { + b := array.NewUint64Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Uint64) + b.Release() + return arr +} + +func buildFloat32(mem memory.Allocator, vals []float32) *array.Float32 { + b := array.NewFloat32Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Float32) + b.Release() + return arr +} + +func buildFloat64(mem memory.Allocator, vals []float64) *array.Float64 { + b := array.NewFloat64Builder(mem) + for _, v := range vals { + b.Append(v) + } + arr := b.NewArray().(*array.Float64) + b.Release() + return arr +} + +func buildBool(mem memory.Allocator, vals []bool) *array.Boolean { + b := array.NewBooleanBuilder(mem) + for _, v := range vals { + b.Append(v) + } + 
arr := b.NewArray().(*array.Boolean) + b.Release() + return arr +} + func TestBasicTopKSortExpr(t *testing.T) { t.Run("TopK Sort", func(t *testing.T) { proj := aggProject()