From bec1b3b71931746a48bb9db49259b2e2d669ee5e Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Sat, 15 Nov 2025 02:30:26 -0500 Subject: [PATCH 01/19] feat(ProjectExecCSV) Implement csv reader that transforms csv file into recordBatches --- .gitignore | 6 +- CONTRIBUTING.md | 6 + Makefile | 1 - src/Backend/opti-sql-go/go.mod | 8 + src/Backend/opti-sql-go/go.sum | 14 + .../operators/project/source/csv.go | 212 ++++ .../operators/project/source/csv_test.go | 970 +++++++++++++++++- src/Backend/opti-sql-go/operators/record.go | 3 + ...ealth_and_Social_Media_Balance_Dataset.csv | 501 +++++++++ 9 files changed, 1716 insertions(+), 5 deletions(-) create mode 100644 src/Backend/test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv diff --git a/.gitignore b/.gitignore index 18bb522..b8fbd54 100644 --- a/.gitignore +++ b/.gitignore @@ -102,4 +102,8 @@ src/Backend/test_data/json # Allow s3_source directory !src/Backend/test_data/s3_source/ -!src/Backend/test_data/s3_source/** \ No newline at end of file +!src/Backend/test_data/s3_source/** + +# Allow a specific CSV dataset that we want tracked despite the general csv ignores +!src/Backend/test_data/csv/ +!src/Backend/test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7e5fd6c..cdc3be3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -19,6 +19,12 @@ We use a Makefile to simplify common development tasks. All commands should be r ```bash make go-test-coverage ``` +- Run test with html coverage +```bash + +go tool cover -html=coverage.out +``` + ### Rust Tests - Run all tests diff --git a/Makefile b/Makefile index 0afcce1..045191c 100644 --- a/Makefile +++ b/Makefile @@ -29,7 +29,6 @@ go-test-coverage: @echo "Running Go tests with coverage..." cd src/Backend/opti-sql-go && go test -v -coverprofile=coverage.out ./... cd src/Backend/opti-sql-go && go tool cover -func=coverage.out - go-run: @echo "Running Go application..." 
cd src/Backend/opti-sql-go && go run main.go diff --git a/src/Backend/opti-sql-go/go.mod b/src/Backend/opti-sql-go/go.mod index 49182e3..426538f 100644 --- a/src/Backend/opti-sql-go/go.mod +++ b/src/Backend/opti-sql-go/go.mod @@ -3,6 +3,7 @@ module opti-sql-go go 1.24.0 require ( + github.com/apache/arrow/go/v15 v15.0.2 github.com/apache/arrow/go/v17 v17.0.0 google.golang.org/grpc v1.63.2 google.golang.org/protobuf v1.34.2 @@ -10,9 +11,16 @@ require ( ) require ( + github.com/andybalholm/brotli v1.1.0 // indirect + github.com/apache/thrift v0.20.0 // indirect github.com/goccy/go-json v0.10.3 // indirect + github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect + github.com/klauspost/asmfmt v1.3.2 // indirect + github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect + github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 // indirect + github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect golang.org/x/mod v0.18.0 // indirect diff --git a/src/Backend/opti-sql-go/go.sum b/src/Backend/opti-sql-go/go.sum index 8839c2d..d8d9111 100644 --- a/src/Backend/opti-sql-go/go.sum +++ b/src/Backend/opti-sql-go/go.sum @@ -1,17 +1,31 @@ +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/apache/arrow/go/v15 v15.0.2 h1:60IliRbiyTWCWjERBCkO1W4Qun9svcYoZrSLcyOsMLE= +github.com/apache/arrow/go/v15 v15.0.2/go.mod h1:DGXsR3ajT524njufqf95822i+KTh+yea1jass9YXgjA= github.com/apache/arrow/go/v17 v17.0.0 h1:RRR2bdqKcdbss9Gxy2NS/hK8i4LDMh23L6BbkN5+F54= github.com/apache/arrow/go/v17 v17.0.0/go.mod h1:jR7QHkODl15PfYyjM2nU+yTLScZ/qfj7OSUZmJ8putc= +github.com/apache/thrift v0.20.0 h1:631+KvYbsBZxmuJjYwhezVsrfc/TbqtZV4QcxOX1fOI= 
+github.com/apache/thrift v0.20.0/go.mod h1:hOk1BQqcp2OLzGsyVXdfMk7YFlMxK3aoEVhjD06QhB8= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI= github.com/google/flatbuffers v24.3.25+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= +github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= +github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= +github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI= +github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= github.com/pierrec/lz4/v4 v4.1.21 
h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/src/Backend/opti-sql-go/operators/project/source/csv.go b/src/Backend/opti-sql-go/operators/project/source/csv.go index d150341..5780f97 100644 --- a/src/Backend/opti-sql-go/operators/project/source/csv.go +++ b/src/Backend/opti-sql-go/operators/project/source/csv.go @@ -1 +1,213 @@ package source + +import ( + "encoding/csv" + "fmt" + "io" + "opti-sql-go/operators" + "strconv" + "strings" + + "github.com/apache/arrow/go/v15/arrow/memory" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +type ProjectCSVLeaf struct { + r *csv.Reader + schema *arrow.Schema // columns to project as well as types to cast to + colPosition map[string]int + firstDataRow []string + done bool // if this is set in Next, we have reached EOF +} + +// assume everything is on disk for now +func NewProjectCSVLeaf(source io.Reader) (*ProjectCSVLeaf, error) { + r := csv.NewReader(source) + proj := &ProjectCSVLeaf{ + r: r, + colPosition: make(map[string]int), + } + var err error + // construct the schema from the header + proj.schema, err = proj.parseHeader() + return proj, err +} + +func (pcsv *ProjectCSVLeaf) Next(n uint64) (*operators.RecordBatch, error) { + if pcsv.done { + return nil, io.EOF + } + + // 1. 
Create builders + builders := pcsv.initBuilders() + + rowsRead := uint64(0) + + // Process stored first row (from parseHeader) --- + if pcsv.firstDataRow != nil && rowsRead < n { + if err := pcsv.processRow(pcsv.firstDataRow, builders); err != nil { + return nil, err + } + pcsv.firstDataRow = nil // consume it once + rowsRead++ + } + + // Stream remaining rows from CSV reader --- + for rowsRead < n { + row, err := pcsv.r.Read() + if err == io.EOF { + if rowsRead == 0 { + pcsv.done = true + return nil, io.EOF + } + break + } + if err != nil { + return nil, err + } + + // append to builders + if err := pcsv.processRow(row, builders); err != nil { + return nil, err + } + + rowsRead++ + } + + // Freeze into Arrow arrays + columns := pcsv.finalizeBuilders(builders) + + return &operators.RecordBatch{ + Schema: pcsv.schema, + Columns: columns, + }, nil +} + +func (pcsv *ProjectCSVLeaf) initBuilders() []array.Builder { + fields := pcsv.schema.Fields() + builders := make([]array.Builder, len(fields)) + + for i, f := range fields { + builders[i] = array.NewBuilder(memory.DefaultAllocator, f.Type) + } + + return builders +} +func (pcsv *ProjectCSVLeaf) processRow( + content []string, + builders []array.Builder, +) error { + fields := pcsv.schema.Fields() + + for i, f := range fields { + colIdx := pcsv.colPosition[f.Name] + cell := content[colIdx] + + switch b := builders[i].(type) { + + case *array.Int64Builder: + if cell == "" || cell == "NULL" { + b.AppendNull() + } else { + v, _ := strconv.ParseInt(cell, 10, 64) + b.Append(v) + } + + case *array.Float64Builder: + if cell == "" || cell == "NULL" { + b.AppendNull() + } else { + v, _ := strconv.ParseFloat(cell, 64) + b.Append(v) + } + + case *array.StringBuilder: + if cell == "" || cell == "NULL" { + b.AppendNull() + } else { + b.Append(cell) + } + + case *array.BooleanBuilder: + if cell == "" || cell == "NULL" { + b.AppendNull() + } else { + b.Append(cell == "true") + } + + default: + return fmt.Errorf("unsupported Arrow 
type: %s", f.Type) + } + } + + return nil +} +func (pcsv *ProjectCSVLeaf) finalizeBuilders(builders []array.Builder) []arrow.Array { + columns := make([]arrow.Array, len(builders)) + + for i, b := range builders { + columns[i] = b.NewArray() + b.Release() + } + + return columns +} + +// first call to csv.Reader +func (pscv *ProjectCSVLeaf) parseHeader() (*arrow.Schema, error) { + header, err := pscv.r.Read() + if err != nil { + return nil, err + } + firstDataRow, err := pscv.r.Read() + if err != nil { + return nil, err + } + pscv.firstDataRow = firstDataRow + newFields := make([]arrow.Field, 0, len(header)) + for i, colName := range header { + sampleValue := firstDataRow[i] + newFields = append(newFields, arrow.Field{ + Name: colName, + Type: parseDataType(sampleValue), + Nullable: true, + }) + pscv.colPosition[colName] = i + } + return arrow.NewSchema(newFields, nil), nil +} +func parseDataType(sample string) arrow.DataType { + sample = strings.TrimSpace(sample) + + // Nulls or empty fields → treat as nullable string in inference + if sample == "" || strings.EqualFold(sample, "NULL") { + return arrow.BinaryTypes.String + } + + // Boolean + if sample == "true" || sample == "false" { + return arrow.FixedWidthTypes.Boolean + } + + // Try int + if _, err := strconv.Atoi(sample); err == nil { + return arrow.PrimitiveTypes.Int64 + } + + // Try float + if _, err := strconv.ParseFloat(sample, 64); err == nil { + return arrow.PrimitiveTypes.Float64 + } + + // Fallback to string + return arrow.BinaryTypes.String +} + +/* +Integers (int8, int16, int32, int64) - whole numbers like 42, -100 +Floating point (float32, float64) - decimal numbers like 3.14, -0.5 +Booleans - true/false values (often represented as "true"/"false", "1"/"0", or "yes"/"no") +Strings (text) - any text like "hello", "John Doe" +Nulls +*/ diff --git a/src/Backend/opti-sql-go/operators/project/source/csv_test.go b/src/Backend/opti-sql-go/operators/project/source/csv_test.go index c00d2dd..3c87e32 100644 
--- a/src/Backend/opti-sql-go/operators/project/source/csv_test.go +++ b/src/Backend/opti-sql-go/operators/project/source/csv_test.go @@ -1,7 +1,971 @@ package source -import "testing" +import ( + "fmt" + "io" + "os" + "strings" + "testing" -func TestCsv(t *testing.T) { - // Simple passing test + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +const csvFilePath = "../../../../test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv" + +//const csvFilePathLarger = "../../../../test_data/csv/stats.csv" + +func getTestFile() *os.File { + v, err := os.Open(csvFilePath) + if err != nil { + panic(err) + } + return v +} + +/* + func getTestFile2() *os.File { + v, err := os.Open(csvFilePathLarger) + if err != nil { + panic(err) + } + return v + } +*/ +func TestCsvInit(t *testing.T) { + v := getTestFile() + defer func() { + if err := v.Close(); err != nil { + t.Fatalf("failed to close: %v", err) + } + }() + p, err := NewProjectCSVLeaf(v) + if err != nil { + t.Errorf("Failed to create ProjectCSVLeaf: %v", err) + } + fmt.Printf("schema -> %v\n", p.schema) + fmt.Printf("columns Mapping -> %v\n", p.colPosition) +} +func TestProjectComponents(t *testing.T) { + v := getTestFile() + defer func() { + if err := v.Close(); err != nil { + t.Fatalf("failed to close: %v", err) + } + }() + p, err := NewProjectCSVLeaf(v) + if err != nil { + t.Errorf("Failed to create ProjectCSVLeaf: %v", err) + } + if p.schema == nil { + t.Errorf("Schema is nil") + } + if len(p.colPosition) == 0 { + t.Errorf("Column position mapping is empty") + } +} +func TestCsvNext(t *testing.T) { + v := getTestFile() + defer func() { + if err := v.Close(); err != nil { + t.Fatalf("failed to close: %v", err) + } + }() + + csvLeaf, err := NewProjectCSVLeaf(v) + if err != nil { + t.Errorf("Failed to create ProjectCSVLeaf: %v", err) + } + rBatch, err := csvLeaf.Next(10) + if err != nil { + t.Errorf("Failed to read next batch from CSV: %v", err) + } + fmt.Printf("Batch: 
%v\n", rBatch) +} + +// TestParseDataType tests every branch of the parseDataType function +func TestParseDataType(t *testing.T) { + tests := []struct { + name string + input string + expected arrow.DataType + }{ + // Empty and NULL cases + { + name: "Empty string", + input: "", + expected: arrow.BinaryTypes.String, + }, + { + name: "NULL uppercase", + input: "NULL", + expected: arrow.BinaryTypes.String, + }, + { + name: "NULL lowercase", + input: "null", + expected: arrow.BinaryTypes.String, + }, + { + name: "NULL mixed case", + input: "NuLl", + expected: arrow.BinaryTypes.String, + }, + { + name: "Empty string with whitespace", + input: " ", + expected: arrow.BinaryTypes.String, + }, + { + name: "NULL with whitespace", + input: " NULL ", + expected: arrow.BinaryTypes.String, + }, + + // Boolean cases + { + name: "Boolean true", + input: "true", + expected: arrow.FixedWidthTypes.Boolean, + }, + { + name: "Boolean false", + input: "false", + expected: arrow.FixedWidthTypes.Boolean, + }, + { + name: "Boolean true with whitespace", + input: " true ", + expected: arrow.FixedWidthTypes.Boolean, + }, + { + name: "Boolean false with whitespace", + input: " false ", + expected: arrow.FixedWidthTypes.Boolean, + }, + + // Integer cases + { + name: "Positive integer", + input: "123", + expected: arrow.PrimitiveTypes.Int64, + }, + { + name: "Negative integer", + input: "-456", + expected: arrow.PrimitiveTypes.Int64, + }, + { + name: "Zero", + input: "0", + expected: arrow.PrimitiveTypes.Int64, + }, + { + name: "Integer with whitespace", + input: " 789 ", + expected: arrow.PrimitiveTypes.Int64, + }, + + // Float cases + { + name: "Positive float", + input: "3.14", + expected: arrow.PrimitiveTypes.Float64, + }, + { + name: "Negative float", + input: "-2.71", + expected: arrow.PrimitiveTypes.Float64, + }, + { + name: "Float with leading zero", + input: "0.5", + expected: arrow.PrimitiveTypes.Float64, + }, + { + name: "Float with trailing zero", + input: "1.0", + expected: 
arrow.PrimitiveTypes.Float64, + }, + { + name: "Float with whitespace", + input: " 9.99 ", + expected: arrow.PrimitiveTypes.Float64, + }, + { + name: "Scientific notation", + input: "1.23e10", + expected: arrow.PrimitiveTypes.Float64, + }, + + // String fallback cases + { + name: "Regular string", + input: "hello", + expected: arrow.BinaryTypes.String, + }, + { + name: "String with spaces", + input: "hello world", + expected: arrow.BinaryTypes.String, + }, + { + name: "String with numbers", + input: "abc123", + expected: arrow.BinaryTypes.String, + }, + { + name: "Boolean-like but not exact", + input: "True", + expected: arrow.BinaryTypes.String, + }, + { + name: "Boolean-like but not exact 2", + input: "FALSE", + expected: arrow.BinaryTypes.String, + }, + { + name: "Invalid number", + input: "12.34.56", + expected: arrow.BinaryTypes.String, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseDataType(tt.input) + if result != tt.expected { + t.Errorf("parseDataType(%q) = %v, expected %v", tt.input, result, tt.expected) + } + }) + } +} + +// TestParseHeader tests the parseHeader function +func TestParseHeader(t *testing.T) { + t.Run("Valid header with all data types", func(t *testing.T) { + csvData := `id,name,age,salary,active +123,John,30,50000.50,true` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + // Check schema was created + if proj.schema == nil { + t.Fatal("Schema is nil") + } + + // Check correct number of fields + fields := proj.schema.Fields() + if len(fields) != 5 { + t.Errorf("Expected 5 fields, got %d", len(fields)) + } + + // Check field names and types + expectedFields := map[string]arrow.DataType{ + "id": arrow.PrimitiveTypes.Int64, + "name": arrow.BinaryTypes.String, + "age": arrow.PrimitiveTypes.Int64, + "salary": arrow.PrimitiveTypes.Float64, + "active": arrow.FixedWidthTypes.Boolean, + } + + for _, 
field := range fields { + expectedType, exists := expectedFields[field.Name] + if !exists { + t.Errorf("Unexpected field name: %s", field.Name) + continue + } + if field.Type != expectedType { + t.Errorf("Field %s: expected type %v, got %v", field.Name, expectedType, field.Type) + } + if !field.Nullable { + t.Errorf("Field %s: expected nullable=true, got false", field.Name) + } + } + + // Check column position mapping + if len(proj.colPosition) != 5 { + t.Errorf("Expected 5 column positions, got %d", len(proj.colPosition)) + } + + expectedPositions := map[string]int{ + "id": 0, + "name": 1, + "age": 2, + "salary": 3, + "active": 4, + } + + for name, expectedPos := range expectedPositions { + actualPos, exists := proj.colPosition[name] + if !exists { + t.Errorf("Column position for %s not found", name) + continue + } + if actualPos != expectedPos { + t.Errorf("Column %s: expected position %d, got %d", name, expectedPos, actualPos) + } + } + }) + + t.Run("Header with NULL values", func(t *testing.T) { + csvData := `col1,col2,col3 +NULL,,value` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + fields := proj.schema.Fields() + // All should be inferred as string + for _, field := range fields { + if field.Type != arrow.BinaryTypes.String { + t.Errorf("Field %s: expected String type for NULL/empty value, got %v", field.Name, field.Type) + } + } + }) + + t.Run("Empty file - header only", func(t *testing.T) { + csvData := `col1,col2` + reader := strings.NewReader(csvData) + _, err := NewProjectCSVLeaf(reader) + if err == nil { + t.Error("Expected error for CSV with header but no data rows") + } + }) + + t.Run("Completely empty file", func(t *testing.T) { + csvData := `` + reader := strings.NewReader(csvData) + _, err := NewProjectCSVLeaf(reader) + if err == nil { + t.Error("Expected error for completely empty CSV") + } + }) +} + +// TestNewProjectCSVLeaf tests the 
constructor +func TestNewProjectCSVLeaf(t *testing.T) { + t.Run("Valid CSV initialization", func(t *testing.T) { + csvData := `name,value +test,123` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + if proj == nil { + t.Fatal("ProjectCSVLeaf is nil") + } + if proj.r == nil { + t.Error("CSV reader is nil") + } + if proj.schema == nil { + t.Error("Schema is nil") + } + if proj.colPosition == nil { + t.Error("Column position map is nil") + } + if proj.done { + t.Error("done flag should be false initially") + } + }) + + t.Run("Error during header parsing", func(t *testing.T) { + csvData := `only_header` + reader := strings.NewReader(csvData) + _, err := NewProjectCSVLeaf(reader) + if err == nil { + t.Error("Expected error when no data rows present") + } + }) +} + +// TestNextFunction tests the Next function comprehensively +func TestNextFunction(t *testing.T) { + t.Run("Read single batch with all data types", func(t *testing.T) { + csvData := `id,name,score,active +1,Alice,95.5,true +2,Bob,87.3,false +3,Charlie,92.1,true` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + batch, err := proj.Next(10) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + + if batch == nil { + t.Fatal("Batch is nil") + } + + // Check schema + if batch.Schema == nil { + t.Fatal("Batch schema is nil") + } + + // Check columns + if len(batch.Columns) != 4 { + t.Fatalf("Expected 4 columns, got %d", len(batch.Columns)) + } + + // Verify each column has 3 rows + for i, col := range batch.Columns { + if col.Len() != 3 { + t.Errorf("Column %d: expected 3 rows, got %d", i, col.Len()) + } + } + fmt.Printf("col0: %v\n", batch.Columns[0]) + // Check Int64 column (id) + idCol, ok := batch.Columns[0].(*array.Int64) + if !ok { + t.Errorf("Column 0 (id): expected *array.Int64, got %T", 
batch.Columns[0]) + } else { + if idCol.Value(0) != 1 || idCol.Value(1) != 2 || idCol.Value(2) != 3 { + t.Errorf("ID column values incorrect: got [%d, %d, %d]", idCol.Value(0), idCol.Value(1), idCol.Value(2)) + } + } + + // Check String column (name) + nameCol, ok := batch.Columns[1].(*array.String) + if !ok { + t.Errorf("Column 1 (name): expected *array.String, got %T", batch.Columns[1]) + } else { + if nameCol.Value(0) != "Alice" || nameCol.Value(1) != "Bob" || nameCol.Value(2) != "Charlie" { + t.Errorf("Name column values incorrect") + } + } + + // Check Float64 column (score) + scoreCol, ok := batch.Columns[2].(*array.Float64) + if !ok { + t.Errorf("Column 2 (score): expected *array.Float64, got %T", batch.Columns[2]) + } else { + if scoreCol.Value(0) != 95.5 || scoreCol.Value(1) != 87.3 || scoreCol.Value(2) != 92.1 { + t.Errorf("Score column values incorrect") + } + } + + // Check Boolean column (active) + activeCol, ok := batch.Columns[3].(*array.Boolean) + if !ok { + t.Errorf("Column 3 (active): expected *array.Boolean, got %T", batch.Columns[3]) + } else { + if !activeCol.Value(0) || activeCol.Value(1) || !activeCol.Value(2) { + t.Errorf("Active column values incorrect") + } + } + }) + + t.Run("Read with NULL values - Int64", func(t *testing.T) { + csvData := `id,value +1,100 +,200 +3,` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + batch, err := proj.Next(10) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + + // Check id column for NULLs + idCol, ok := batch.Columns[0].(*array.Int64) + if !ok { + t.Fatalf("Column 0: expected *array.Int64, got %T", batch.Columns[0]) + } + + if !idCol.IsNull(1) { + t.Error("Expected NULL at index 1 in id column") + } + if idCol.IsNull(0) || idCol.IsNull(2) { + t.Error("Unexpected NULL in id column") + } + + // Check value column for NULLs + valueCol, ok := batch.Columns[1].(*array.Int64) + if !ok { + 
t.Fatalf("Column 1: expected *array.Int64, got %T", batch.Columns[1]) + } + + if !valueCol.IsNull(2) { + t.Error("Expected NULL at index 2 in value column") + } + }) + + t.Run("Read with NULL values - Float64", func(t *testing.T) { + csvData := `price +99.99 +NULL +` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + batch, err := proj.Next(10) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + + priceCol, ok := batch.Columns[0].(*array.Float64) + if !ok { + t.Fatalf("Expected *array.Float64, got %T", batch.Columns[0]) + } + + if !priceCol.IsNull(1) || !priceCol.IsNull(2) { + t.Error("Expected NULL values in price column") + } + }) + + t.Run("Read with NULL values - String", func(t *testing.T) { + csvData := `name +Alice +NULL +` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + batch, err := proj.Next(10) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + + nameCol, ok := batch.Columns[0].(*array.String) + if !ok { + t.Fatalf("Expected *array.String, got %T", batch.Columns[0]) + } + + if !nameCol.IsNull(1) || !nameCol.IsNull(2) { + t.Error("Expected NULL values in name column") + } + }) + + t.Run("Read with NULL values - Boolean", func(t *testing.T) { + csvData := `flag +true +NULL +false +` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + batch, err := proj.Next(10) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + + flagCol, ok := batch.Columns[0].(*array.Boolean) + if !ok { + t.Fatalf("Expected *array.Boolean, got %T", batch.Columns[0]) + } + fmt.Printf("flagCol : %v\n", flagCol) + + if !flagCol.IsNull(1) { + t.Error("Expected NULL values in flag column") + } + }) + + t.Run("Read multiple batches", func(t *testing.T) { + csvData 
:= `id +1 +2 +3 +4 +5 +6` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + // First batch of 2 + batch1, err := proj.Next(2) + if err != nil { + t.Fatalf("First Next failed: %v", err) + } + if batch1.Columns[0].Len() != 2 { + t.Errorf("First batch: expected 2 rows, got %d", batch1.Columns[0].Len()) + } + + // Second batch of 3 + batch2, err := proj.Next(3) + if err != nil { + t.Fatalf("Second Next failed: %v", err) + } + if batch2.Columns[0].Len() != 3 { + t.Errorf("Second batch: expected 3 rows, got %d", batch2.Columns[0].Len()) + } + + // Third batch - should get remaining 1 row + batch3, err := proj.Next(10) + if err != nil { + t.Fatalf("Third Next failed: %v", err) + } + if batch3.Columns[0].Len() != 1 { + t.Errorf("Third batch: expected 1 row, got %d", batch3.Columns[0].Len()) + } + + // Fourth batch - should return EOF + _, err = proj.Next(10) + if err != io.EOF { + t.Errorf("Expected EOF, got: %v", err) + } + }) + + t.Run("Read exact batch size", func(t *testing.T) { + csvData := `num +10 +20 +30` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + batch, err := proj.Next(3) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + if batch.Columns[0].Len() != 3 { + t.Errorf("Expected 3 rows, got %d", batch.Columns[0].Len()) + } + + // Next call should return EOF + _, err = proj.Next(1) + if err != io.EOF { + t.Errorf("Expected EOF after reading all data, got: %v", err) + } + }) + + t.Run("EOF on first Next call - empty data", func(t *testing.T) { + csvData := `col1 +val1` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + // Read the only row + _, err = proj.Next(10) + if err != nil { + t.Fatalf("First Next failed: %v", err) + } + + // Second call when 
no data remains and rowsRead == 0 + _, err = proj.Next(10) + if err != io.EOF { + t.Errorf("Expected EOF when no data left, got: %v", err) + } + + // Verify done flag is set + if !proj.done { + t.Error("Expected done flag to be true after EOF") + } + }) + + t.Run("Subsequent calls after done is set", func(t *testing.T) { + csvData := `val +1` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + // Read all data + _, _ = proj.Next(10) + + // Hit EOF and set done + _, err = proj.Next(10) + if err != io.EOF { + t.Fatalf("Expected EOF, got: %v", err) + } + + // Call again - should immediately return EOF due to done flag + _, err = proj.Next(10) + if err != io.EOF { + t.Errorf("Expected EOF on subsequent call when done=true, got: %v", err) + } + }) + + t.Run("Batch size of 1", func(t *testing.T) { + csvData := `x +a +b +c` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + // Read one row at a time + for i := 0; i < 3; i++ { + batch, err := proj.Next(1) + if err != nil { + t.Fatalf("Next call %d failed: %v", i+1, err) + } + if batch.Columns[0].Len() != 1 { + t.Errorf("Batch %d: expected 1 row, got %d", i+1, batch.Columns[0].Len()) + } + } + }) + + t.Run("Large batch size with fewer rows", func(t *testing.T) { + csvData := `num +1 +2` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + batch, err := proj.Next(1000) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + if batch.Columns[0].Len() != 2 { + t.Errorf("Expected 2 rows, got %d", batch.Columns[0].Len()) + } + }) + + t.Run("EOF mid-batch breaks correctly", func(t *testing.T) { + csvData := `id +1 +2 +3` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + 
t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + // Request 10 rows, but only 3 exist + batch, err := proj.Next(10) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + + // Should get 3 rows (break on EOF, not error) + if batch.Columns[0].Len() != 3 { + t.Errorf("Expected 3 rows when hitting EOF mid-batch, got %d", batch.Columns[0].Len()) + } + }) + + t.Run("Boolean false value handling", func(t *testing.T) { + csvData := `active +false +true +false` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + batch, err := proj.Next(10) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + + boolCol, ok := batch.Columns[0].(*array.Boolean) + if !ok { + t.Fatalf("Expected *array.Boolean, got %T", batch.Columns[0]) + } + + // Verify false values are correctly stored + if boolCol.Value(0) != false { + t.Error("Expected false at index 0") + } + if boolCol.Value(1) != true { + t.Error("Expected true at index 1") + } + if boolCol.Value(2) != false { + t.Error("Expected false at index 2") + } + }) + + t.Run("Column ordering matches schema", func(t *testing.T) { + csvData := `z,y,x +1,2,3 +4,5,6` + reader := strings.NewReader(csvData) + proj, err := NewProjectCSVLeaf(reader) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + batch, err := proj.Next(10) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + + // Verify schema field order + fields := batch.Schema.Fields() + if fields[0].Name != "z" || fields[1].Name != "y" || fields[2].Name != "x" { + t.Error("Schema field order doesn't match CSV header order") + } + + // Verify data is in correct columns + zCol := batch.Columns[0].(*array.Int64) + yCol := batch.Columns[1].(*array.Int64) + xCol := batch.Columns[2].(*array.Int64) + + if zCol.Value(0) != 1 || yCol.Value(0) != 2 || xCol.Value(0) != 3 { + t.Error("First row data not in correct column order") + } + if zCol.Value(1) != 4 || 
yCol.Value(1) != 5 || xCol.Value(1) != 6 { + t.Error("Second row data not in correct column order") + } + }) +} + +// TestIntegrationWithRealFile tests with the actual test file +func TestIntegrationWithRealFile(t *testing.T) { + t.Run("Real file - multiple batches", func(t *testing.T) { + v := getTestFile() + defer func() { + if err := v.Close(); err != nil { + t.Fatalf("failed to close: %v", err) + } + }() + + proj, err := NewProjectCSVLeaf(v) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + totalRows := 0 + batchCount := 0 + + for { + batch, err := proj.Next(10) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Next failed on batch %d: %v", batchCount+1, err) + } + + batchCount++ + if len(batch.Columns) > 0 { + totalRows += batch.Columns[0].Len() + } + + // Verify all columns have same length + expectedLen := batch.Columns[0].Len() + for i, col := range batch.Columns { + if col.Len() != expectedLen { + t.Errorf("Batch %d, Column %d: length mismatch, expected %d, got %d", + batchCount, i, expectedLen, col.Len()) + } + } + } + + if batchCount == 0 { + t.Error("Expected at least one batch from real file") + } + if totalRows == 0 { + t.Error("Expected at least one row from real file") + } + + t.Logf("Read %d batches with total of %d rows", batchCount, totalRows) + }) + + t.Run("Real file - schema validation", func(t *testing.T) { + v := getTestFile() + defer func() { + if err := v.Close(); err != nil { + t.Fatalf("failed to close: %v", err) + } + }() + + proj, err := NewProjectCSVLeaf(v) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + + if proj.schema == nil { + t.Fatal("Schema is nil") + } + + fields := proj.schema.Fields() + if len(fields) == 0 { + t.Error("Schema has no fields") + } + + // Verify all fields are nullable + for _, field := range fields { + if !field.Nullable { + t.Errorf("Field %s is not nullable", field.Name) + } + } + + // Verify colPosition map matches schema + if 
len(proj.colPosition) != len(fields) { + t.Errorf("Column position map size (%d) doesn't match schema field count (%d)", + len(proj.colPosition), len(fields)) + } + + for i, field := range fields { + pos, exists := proj.colPosition[field.Name] + if !exists { + t.Errorf("Field %s not found in column position map", field.Name) + } + if pos != i { + t.Errorf("Field %s: expected position %d, got %d", field.Name, i, pos) + } + } + }) +} + +/* +func TestLargercsvFile(t *testing.T) { + f1 := getTestFile2() + + project, err := NewProjectCSVLeaf(f1) + if err != nil { + t.Fatalf("NewProjectCSVLeaf failed: %v", err) + } + defer func() { + if err := f1.Close(); err != nil { + t.Fatalf("failed to close: %v", err) + } + }() + for { + rc, err := project.Next(1024 * 8) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Next failed: %v", err) + } + fmt.Printf("rc : %v\n", rc.Columns) + } } +*/ diff --git a/src/Backend/opti-sql-go/operators/record.go b/src/Backend/opti-sql-go/operators/record.go index 70a8a2e..54ccf60 100644 --- a/src/Backend/opti-sql-go/operators/record.go +++ b/src/Backend/opti-sql-go/operators/record.go @@ -15,6 +15,9 @@ var ( } ) +type Operator interface { + Next(uint16) (*RecordBatch, error) +} type RecordBatch struct { Schema *arrow.Schema Columns []arrow.Array diff --git a/src/Backend/test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv b/src/Backend/test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv new file mode 100644 index 0000000..7457ea2 --- /dev/null +++ b/src/Backend/test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv @@ -0,0 +1,501 @@ +User_ID,Age,Gender,Daily_Screen_Time(hrs),Sleep_Quality(1-10),Stress_Level(1-10),Days_Without_Social_Media,Exercise_Frequency(week),Social_Media_Platform,Happiness_Index(1-10) +U001,44,Male,3.1,7.0,6.0,2.0,5.0,Facebook,10.0 +U002,30,Other,5.1,7.0,8.0,5.0,3.0,LinkedIn,10.0 +U003,23,Other,7.4,6.0,7.0,1.0,3.0,YouTube,6.0 +U004,36,Female,5.7,7.0,8.0,1.0,1.0,TikTok,8.0 
+U005,34,Female,7.0,4.0,7.0,5.0,1.0,X (Twitter),8.0 +U006,38,Male,6.6,5.0,7.0,4.0,3.0,LinkedIn,8.0 +U007,26,Female,7.8,4.0,8.0,2.0,0.0,TikTok,7.0 +U008,26,Female,7.4,5.0,6.0,1.0,4.0,Instagram,7.0 +U009,39,Male,4.7,7.0,7.0,6.0,1.0,YouTube,9.0 +U010,39,Female,6.6,6.0,8.0,0.0,2.0,Facebook,7.0 +U011,18,Female,2.8,7.0,6.0,2.0,0.0,Instagram,7.0 +U012,37,Other,5.4,5.0,7.0,3.0,2.0,Instagram,9.0 +U013,17,Female,7.0,7.0,10.0,7.0,1.0,YouTube,8.0 +U014,39,Female,5.7,5.0,7.0,4.0,0.0,Facebook,8.0 +U015,45,Male,6.3,7.0,7.0,4.0,3.0,X (Twitter),9.0 +U016,17,Female,5.1,7.0,6.0,2.0,5.0,LinkedIn,10.0 +U017,36,Female,7.5,5.0,8.0,4.0,4.0,Facebook,7.0 +U018,48,Male,5.4,6.0,4.0,3.0,4.0,TikTok,10.0 +U019,27,Male,4.7,6.0,6.0,0.0,2.0,Instagram,9.0 +U020,37,Male,7.1,6.0,6.0,5.0,4.0,TikTok,10.0 +U021,40,Female,3.0,8.0,5.0,5.0,4.0,X (Twitter),10.0 +U022,42,Other,6.4,7.0,6.0,5.0,3.0,TikTok,10.0 +U023,43,Male,3.6,8.0,5.0,3.0,1.0,Facebook,9.0 +U024,31,Female,7.8,5.0,8.0,1.0,2.0,TikTok,7.0 +U025,30,Male,5.0,6.0,6.0,0.0,5.0,X (Twitter),8.0 +U026,18,Other,5.7,6.0,7.0,3.0,2.0,X (Twitter),8.0 +U027,22,Female,6.2,5.0,8.0,5.0,0.0,X (Twitter),6.0 +U028,36,Female,2.0,8.0,4.0,5.0,0.0,Facebook,10.0 +U029,24,Male,4.3,6.0,6.0,5.0,2.0,X (Twitter),10.0 +U030,33,Male,7.4,6.0,10.0,3.0,4.0,Instagram,8.0 +U031,19,Male,7.1,5.0,8.0,5.0,2.0,X (Twitter),9.0 +U032,40,Male,6.0,6.0,7.0,5.0,4.0,TikTok,7.0 +U033,29,Female,3.1,8.0,4.0,2.0,2.0,LinkedIn,10.0 +U034,24,Female,6.6,6.0,7.0,4.0,0.0,TikTok,7.0 +U035,41,Other,6.8,8.0,7.0,2.0,2.0,LinkedIn,9.0 +U036,17,Other,5.6,6.0,7.0,4.0,3.0,TikTok,7.0 +U037,35,Female,7.7,4.0,8.0,0.0,3.0,Facebook,8.0 +U038,43,Female,4.3,8.0,4.0,1.0,1.0,TikTok,10.0 +U039,22,Other,4.6,4.0,6.0,5.0,2.0,Facebook,9.0 +U040,23,Female,1.0,9.0,5.0,5.0,7.0,Facebook,10.0 +U041,29,Male,7.2,5.0,7.0,4.0,1.0,X (Twitter),10.0 +U042,32,Male,6.1,5.0,6.0,3.0,3.0,Instagram,7.0 +U043,19,Male,2.5,9.0,3.0,0.0,0.0,LinkedIn,10.0 +U044,17,Other,2.3,9.0,3.0,3.0,4.0,X (Twitter),10.0 +U045,21,Female,3.9,6.0,6.0,4.0,1.0,X 
(Twitter),8.0 +U046,19,Male,6.9,5.0,7.0,2.0,3.0,Facebook,9.0 +U047,44,Female,5.0,4.0,7.0,2.0,1.0,Facebook,8.0 +U048,33,Female,4.1,8.0,7.0,6.0,3.0,LinkedIn,9.0 +U049,41,Male,3.7,8.0,6.0,3.0,1.0,TikTok,9.0 +U050,49,Male,3.6,7.0,5.0,4.0,1.0,TikTok,9.0 +U051,25,Female,2.0,9.0,4.0,5.0,3.0,X (Twitter),10.0 +U052,29,Female,3.3,7.0,5.0,1.0,2.0,YouTube,10.0 +U053,46,Female,3.6,7.0,5.0,3.0,3.0,YouTube,10.0 +U054,30,Female,6.7,6.0,8.0,6.0,2.0,TikTok,7.0 +U055,23,Female,5.2,6.0,6.0,7.0,3.0,YouTube,8.0 +U056,29,Female,4.0,7.0,3.0,3.0,3.0,TikTok,10.0 +U057,38,Female,9.7,3.0,9.0,3.0,2.0,Facebook,4.0 +U058,36,Female,4.3,7.0,6.0,4.0,1.0,X (Twitter),9.0 +U059,31,Male,7.6,5.0,10.0,3.0,2.0,Instagram,6.0 +U060,33,Male,8.4,4.0,8.0,6.0,4.0,X (Twitter),7.0 +U061,39,Male,2.7,9.0,5.0,1.0,3.0,TikTok,10.0 +U062,41,Female,5.6,6.0,7.0,2.0,1.0,TikTok,9.0 +U063,40,Female,4.6,7.0,7.0,1.0,3.0,X (Twitter),10.0 +U064,44,Female,4.4,6.0,7.0,1.0,3.0,LinkedIn,9.0 +U065,30,Male,6.5,4.0,9.0,4.0,3.0,TikTok,5.0 +U066,16,Male,4.4,8.0,4.0,5.0,1.0,YouTube,10.0 +U067,40,Male,4.6,7.0,6.0,2.0,1.0,YouTube,10.0 +U068,22,Male,5.2,9.0,7.0,5.0,1.0,TikTok,10.0 +U069,24,Female,4.8,6.0,7.0,5.0,2.0,TikTok,7.0 +U070,39,Male,3.4,8.0,6.0,7.0,3.0,LinkedIn,10.0 +U071,16,Male,6.9,5.0,8.0,3.0,3.0,YouTube,5.0 +U072,23,Female,7.4,6.0,7.0,2.0,3.0,TikTok,8.0 +U073,39,Male,6.1,5.0,5.0,1.0,5.0,TikTok,10.0 +U074,26,Female,4.8,6.0,6.0,1.0,4.0,LinkedIn,8.0 +U075,32,Female,4.2,8.0,8.0,5.0,2.0,LinkedIn,8.0 +U076,23,Male,4.2,7.0,5.0,0.0,1.0,X (Twitter),10.0 +U077,48,Female,9.7,4.0,8.0,5.0,1.0,LinkedIn,5.0 +U078,20,Female,7.2,6.0,9.0,1.0,2.0,Instagram,6.0 +U079,43,Other,4.0,7.0,7.0,2.0,4.0,YouTube,9.0 +U080,22,Male,5.8,6.0,8.0,3.0,2.0,X (Twitter),7.0 +U081,24,Female,4.8,6.0,7.0,2.0,2.0,Instagram,8.0 +U082,23,Female,6.6,6.0,8.0,5.0,1.0,TikTok,8.0 +U083,27,Male,6.5,7.0,9.0,7.0,2.0,TikTok,8.0 +U084,49,Male,7.5,4.0,7.0,2.0,4.0,X (Twitter),8.0 +U085,48,Other,4.0,7.0,5.0,2.0,5.0,LinkedIn,10.0 +U086,38,Female,5.7,5.0,6.0,3.0,2.0,TikTok,9.0 
+U087,39,Male,2.4,7.0,4.0,4.0,2.0,Facebook,10.0 +U088,37,Female,3.2,8.0,7.0,3.0,1.0,X (Twitter),10.0 +U089,42,Male,7.1,3.0,8.0,6.0,5.0,Instagram,6.0 +U090,16,Male,3.8,10.0,6.0,5.0,0.0,X (Twitter),10.0 +U091,29,Male,4.0,6.0,4.0,1.0,4.0,YouTube,9.0 +U092,18,Male,5.5,6.0,6.0,4.0,2.0,LinkedIn,9.0 +U093,16,Male,6.8,7.0,6.0,6.0,2.0,LinkedIn,9.0 +U094,20,Male,3.6,7.0,7.0,2.0,1.0,X (Twitter),8.0 +U095,41,Female,1.0,7.0,5.0,2.0,6.0,TikTok,10.0 +U096,29,Male,3.9,7.0,6.0,2.0,3.0,TikTok,10.0 +U097,42,Male,7.7,4.0,7.0,2.0,3.0,Facebook,8.0 +U098,24,Female,6.1,4.0,9.0,0.0,4.0,Facebook,6.0 +U099,30,Female,7.9,5.0,8.0,0.0,2.0,LinkedIn,8.0 +U100,30,Female,3.5,8.0,4.0,3.0,3.0,Facebook,10.0 +U101,41,Male,2.5,9.0,4.0,6.0,3.0,Instagram,10.0 +U102,28,Male,6.3,5.0,7.0,2.0,2.0,Facebook,7.0 +U103,47,Male,5.9,4.0,7.0,2.0,4.0,Instagram,7.0 +U104,47,Male,6.0,6.0,7.0,6.0,2.0,Facebook,9.0 +U105,19,Male,7.2,5.0,6.0,1.0,0.0,LinkedIn,9.0 +U106,45,Male,6.3,7.0,8.0,0.0,2.0,LinkedIn,6.0 +U107,38,Male,6.2,7.0,7.0,1.0,3.0,Instagram,9.0 +U108,30,Male,6.8,7.0,7.0,2.0,1.0,X (Twitter),10.0 +U109,44,Male,6.4,6.0,8.0,0.0,1.0,LinkedIn,7.0 +U110,28,Male,7.0,6.0,6.0,3.0,2.0,X (Twitter),8.0 +U111,47,Male,7.3,4.0,6.0,5.0,4.0,TikTok,8.0 +U112,22,Female,8.3,5.0,10.0,0.0,2.0,Instagram,6.0 +U113,37,Female,4.3,6.0,8.0,3.0,5.0,LinkedIn,9.0 +U114,43,Male,6.4,6.0,7.0,4.0,2.0,YouTube,7.0 +U115,17,Female,4.8,7.0,6.0,3.0,3.0,TikTok,10.0 +U116,21,Female,4.0,9.0,5.0,3.0,2.0,X (Twitter),10.0 +U117,43,Male,4.3,7.0,6.0,3.0,4.0,LinkedIn,9.0 +U118,43,Female,6.4,6.0,7.0,1.0,1.0,Instagram,9.0 +U119,35,Female,6.3,6.0,7.0,6.0,4.0,Facebook,9.0 +U120,45,Male,6.2,5.0,8.0,5.0,3.0,YouTube,8.0 +U121,26,Male,5.1,6.0,5.0,5.0,2.0,X (Twitter),9.0 +U122,43,Female,5.5,8.0,6.0,5.0,2.0,TikTok,10.0 +U123,40,Female,7.0,4.0,6.0,0.0,5.0,Instagram,8.0 +U124,48,Male,2.8,7.0,5.0,5.0,3.0,YouTube,10.0 +U125,16,Male,5.5,7.0,7.0,4.0,1.0,Instagram,9.0 +U126,42,Male,5.7,7.0,7.0,0.0,1.0,LinkedIn,9.0 +U127,28,Female,6.9,5.0,8.0,3.0,3.0,Facebook,8.0 
+U128,18,Female,3.1,7.0,6.0,4.0,3.0,TikTok,9.0 +U129,21,Male,8.2,5.0,7.0,6.0,4.0,Instagram,7.0 +U130,23,Male,5.9,5.0,5.0,4.0,0.0,X (Twitter),9.0 +U131,42,Female,6.0,5.0,8.0,1.0,1.0,LinkedIn,7.0 +U132,24,Male,4.8,8.0,4.0,6.0,4.0,LinkedIn,10.0 +U133,48,Male,6.5,6.0,8.0,3.0,3.0,TikTok,8.0 +U134,39,Female,5.6,8.0,6.0,2.0,3.0,LinkedIn,9.0 +U135,30,Female,7.7,4.0,10.0,2.0,1.0,X (Twitter),8.0 +U136,47,Female,2.9,8.0,5.0,0.0,1.0,Instagram,10.0 +U137,47,Female,5.8,6.0,7.0,3.0,4.0,TikTok,10.0 +U138,39,Male,6.2,5.0,8.0,4.0,4.0,LinkedIn,7.0 +U139,27,Male,5.2,7.0,5.0,2.0,1.0,LinkedIn,10.0 +U140,17,Female,4.8,8.0,5.0,4.0,1.0,X (Twitter),10.0 +U141,18,Male,4.8,6.0,5.0,1.0,2.0,Facebook,10.0 +U142,32,Male,5.8,6.0,6.0,3.0,1.0,Facebook,10.0 +U143,17,Male,2.2,10.0,5.0,4.0,1.0,LinkedIn,10.0 +U144,17,Male,7.2,7.0,8.0,3.0,1.0,Instagram,8.0 +U145,43,Male,4.6,8.0,7.0,0.0,2.0,TikTok,9.0 +U146,38,Male,5.7,7.0,5.0,2.0,1.0,LinkedIn,10.0 +U147,47,Male,6.2,5.0,7.0,6.0,1.0,YouTube,9.0 +U148,48,Male,4.6,7.0,5.0,2.0,3.0,TikTok,10.0 +U149,16,Female,5.3,7.0,6.0,3.0,2.0,Facebook,10.0 +U150,34,Female,1.8,9.0,7.0,3.0,3.0,LinkedIn,10.0 +U151,17,Female,6.8,5.0,8.0,6.0,0.0,X (Twitter),6.0 +U152,41,Female,7.4,4.0,9.0,6.0,3.0,YouTube,5.0 +U153,47,Female,2.4,7.0,7.0,5.0,6.0,LinkedIn,9.0 +U154,21,Male,4.8,5.0,7.0,1.0,2.0,Instagram,8.0 +U155,47,Female,3.3,8.0,5.0,4.0,4.0,YouTube,10.0 +U156,19,Female,5.2,7.0,5.0,4.0,4.0,LinkedIn,10.0 +U157,26,Female,6.7,5.0,7.0,2.0,4.0,Instagram,8.0 +U158,32,Male,7.3,5.0,7.0,1.0,3.0,YouTube,8.0 +U159,39,Male,7.5,5.0,6.0,6.0,3.0,Instagram,7.0 +U160,20,Male,7.9,5.0,9.0,2.0,4.0,Instagram,5.0 +U161,49,Male,5.1,5.0,7.0,6.0,3.0,YouTube,8.0 +U162,21,Female,3.4,9.0,7.0,5.0,1.0,YouTube,8.0 +U163,37,Female,5.3,7.0,6.0,4.0,0.0,Instagram,10.0 +U164,26,Male,6.6,6.0,8.0,4.0,0.0,Facebook,8.0 +U165,31,Other,3.3,9.0,5.0,0.0,2.0,YouTube,10.0 +U166,48,Female,8.5,4.0,10.0,3.0,3.0,LinkedIn,8.0 +U167,24,Male,7.1,6.0,8.0,6.0,1.0,TikTok,7.0 +U168,21,Male,5.1,6.0,5.0,4.0,4.0,X (Twitter),10.0 
+U169,31,Male,4.2,8.0,6.0,3.0,2.0,Instagram,10.0 +U170,44,Male,4.3,8.0,6.0,6.0,2.0,Instagram,10.0 +U171,18,Male,5.6,5.0,7.0,3.0,1.0,X (Twitter),6.0 +U172,35,Male,5.1,7.0,7.0,2.0,0.0,Facebook,7.0 +U173,34,Male,7.2,6.0,6.0,3.0,3.0,X (Twitter),8.0 +U174,41,Male,4.9,7.0,5.0,5.0,5.0,YouTube,10.0 +U175,18,Female,6.4,6.0,8.0,0.0,2.0,X (Twitter),8.0 +U176,34,Male,5.4,6.0,6.0,4.0,0.0,Instagram,10.0 +U177,35,Female,6.0,7.0,7.0,5.0,4.0,X (Twitter),8.0 +U178,47,Male,5.1,7.0,4.0,2.0,0.0,TikTok,9.0 +U179,22,Other,6.8,5.0,10.0,5.0,1.0,LinkedIn,7.0 +U180,48,Male,6.8,4.0,7.0,4.0,2.0,YouTube,7.0 +U181,33,Male,6.3,6.0,8.0,7.0,3.0,X (Twitter),6.0 +U182,16,Other,6.6,6.0,7.0,1.0,1.0,Instagram,8.0 +U183,26,Female,3.8,10.0,6.0,5.0,2.0,LinkedIn,10.0 +U184,43,Female,6.1,7.0,9.0,1.0,4.0,YouTube,7.0 +U185,40,Male,7.7,3.0,9.0,0.0,4.0,Instagram,5.0 +U186,38,Male,6.4,6.0,7.0,0.0,4.0,Facebook,8.0 +U187,46,Female,3.6,9.0,5.0,4.0,5.0,TikTok,10.0 +U188,45,Female,5.4,7.0,8.0,3.0,4.0,Facebook,7.0 +U189,22,Female,6.2,6.0,9.0,2.0,2.0,TikTok,6.0 +U190,31,Female,4.1,8.0,6.0,0.0,4.0,X (Twitter),10.0 +U191,41,Male,3.1,7.0,7.0,2.0,3.0,YouTube,9.0 +U192,17,Female,4.9,7.0,6.0,3.0,3.0,X (Twitter),10.0 +U193,16,Male,4.2,7.0,7.0,5.0,3.0,TikTok,8.0 +U194,27,Male,5.5,6.0,8.0,3.0,1.0,TikTok,7.0 +U195,20,Male,2.6,9.0,4.0,4.0,1.0,TikTok,10.0 +U196,47,Female,6.5,5.0,7.0,2.0,2.0,LinkedIn,8.0 +U197,24,Male,4.4,8.0,7.0,6.0,3.0,Instagram,10.0 +U198,34,Male,7.3,5.0,7.0,1.0,1.0,YouTube,6.0 +U199,31,Female,5.0,7.0,5.0,2.0,3.0,X (Twitter),10.0 +U200,18,Male,5.5,5.0,6.0,2.0,6.0,X (Twitter),9.0 +U201,35,Female,5.2,7.0,6.0,4.0,0.0,Facebook,10.0 +U202,39,Male,4.3,9.0,7.0,1.0,3.0,TikTok,9.0 +U203,48,Female,10.0,3.0,10.0,3.0,2.0,TikTok,4.0 +U204,39,Female,4.6,6.0,7.0,3.0,3.0,TikTok,6.0 +U205,26,Male,7.3,5.0,8.0,2.0,2.0,X (Twitter),7.0 +U206,23,Male,3.5,9.0,5.0,4.0,2.0,Facebook,10.0 +U207,35,Male,6.2,4.0,8.0,1.0,3.0,Instagram,7.0 +U208,40,Male,6.4,5.0,8.0,1.0,2.0,LinkedIn,7.0 +U209,40,Other,5.5,6.0,7.0,3.0,0.0,YouTube,9.0 
+U210,44,Male,4.5,6.0,6.0,0.0,2.0,Instagram,9.0 +U211,33,Female,5.0,6.0,6.0,0.0,0.0,X (Twitter),8.0 +U212,33,Male,4.8,5.0,6.0,4.0,3.0,LinkedIn,9.0 +U213,17,Female,9.8,4.0,9.0,5.0,0.0,LinkedIn,6.0 +U214,31,Female,5.5,7.0,7.0,2.0,4.0,LinkedIn,9.0 +U215,48,Female,6.4,4.0,6.0,2.0,3.0,TikTok,7.0 +U216,19,Female,8.7,6.0,7.0,5.0,3.0,YouTube,10.0 +U217,48,Male,3.9,9.0,5.0,1.0,0.0,X (Twitter),10.0 +U218,29,Male,5.6,6.0,6.0,0.0,2.0,X (Twitter),8.0 +U219,36,Female,5.0,8.0,6.0,4.0,4.0,LinkedIn,10.0 +U220,35,Female,5.3,6.0,6.0,3.0,2.0,X (Twitter),9.0 +U221,23,Other,5.2,7.0,6.0,5.0,1.0,YouTube,9.0 +U222,22,Male,4.2,6.0,6.0,0.0,2.0,YouTube,8.0 +U223,18,Female,5.9,8.0,6.0,0.0,2.0,Facebook,8.0 +U224,32,Female,3.6,7.0,5.0,3.0,1.0,Facebook,10.0 +U225,48,Male,7.3,5.0,9.0,4.0,4.0,Instagram,5.0 +U226,27,Male,6.6,5.0,7.0,5.0,2.0,YouTube,7.0 +U227,37,Female,4.6,9.0,6.0,3.0,2.0,YouTube,8.0 +U228,37,Female,4.7,9.0,6.0,3.0,1.0,YouTube,10.0 +U229,45,Female,6.0,8.0,7.0,2.0,3.0,Facebook,9.0 +U230,23,Male,4.1,8.0,5.0,6.0,3.0,LinkedIn,10.0 +U231,42,Female,8.8,4.0,8.0,1.0,1.0,LinkedIn,6.0 +U232,42,Male,6.1,7.0,6.0,4.0,0.0,X (Twitter),8.0 +U233,49,Male,6.2,6.0,8.0,1.0,3.0,LinkedIn,8.0 +U234,36,Male,6.7,7.0,7.0,4.0,2.0,TikTok,8.0 +U235,45,Female,4.5,8.0,6.0,1.0,3.0,Facebook,10.0 +U236,48,Female,7.5,4.0,9.0,4.0,5.0,YouTube,5.0 +U237,43,Female,5.7,7.0,7.0,3.0,2.0,LinkedIn,9.0 +U238,48,Male,7.2,6.0,8.0,6.0,2.0,Facebook,9.0 +U239,20,Female,5.1,4.0,5.0,2.0,1.0,LinkedIn,9.0 +U240,34,Male,7.9,6.0,8.0,4.0,2.0,Instagram,7.0 +U241,19,Female,6.2,5.0,7.0,1.0,3.0,YouTube,8.0 +U242,32,Male,9.4,3.0,9.0,2.0,3.0,Instagram,5.0 +U243,43,Male,4.8,6.0,5.0,3.0,5.0,YouTube,10.0 +U244,45,Male,5.9,6.0,8.0,0.0,3.0,Facebook,7.0 +U245,44,Male,3.6,8.0,5.0,0.0,2.0,X (Twitter),10.0 +U246,21,Female,4.5,6.0,6.0,4.0,1.0,LinkedIn,9.0 +U247,39,Female,6.6,6.0,7.0,3.0,3.0,LinkedIn,7.0 +U248,44,Male,7.5,6.0,9.0,2.0,1.0,TikTok,6.0 +U249,46,Female,10.8,5.0,10.0,2.0,3.0,Instagram,4.0 +U250,48,Male,5.3,5.0,5.0,2.0,3.0,LinkedIn,8.0 
+U251,36,Female,3.8,7.0,5.0,3.0,2.0,Facebook,9.0 +U252,47,Female,6.9,6.0,8.0,4.0,2.0,X (Twitter),8.0 +U253,38,Female,3.5,9.0,4.0,0.0,0.0,Facebook,10.0 +U254,48,Male,6.5,5.0,8.0,6.0,4.0,Instagram,8.0 +U255,18,Female,4.6,8.0,7.0,6.0,3.0,TikTok,9.0 +U256,33,Female,5.9,4.0,7.0,0.0,3.0,LinkedIn,5.0 +U257,40,Male,7.5,4.0,8.0,5.0,0.0,Facebook,7.0 +U258,46,Male,5.3,7.0,8.0,5.0,5.0,YouTube,9.0 +U259,18,Female,7.2,4.0,7.0,3.0,2.0,X (Twitter),6.0 +U260,39,Female,7.1,4.0,8.0,0.0,2.0,Facebook,7.0 +U261,47,Female,2.1,8.0,4.0,5.0,1.0,X (Twitter),10.0 +U262,37,Female,3.9,8.0,6.0,3.0,0.0,LinkedIn,10.0 +U263,38,Female,2.7,6.0,3.0,4.0,4.0,YouTube,10.0 +U264,17,Male,9.1,4.0,9.0,7.0,4.0,TikTok,4.0 +U265,42,Male,5.3,7.0,7.0,4.0,2.0,TikTok,8.0 +U266,17,Male,8.7,4.0,8.0,1.0,1.0,YouTube,7.0 +U267,41,Female,2.6,8.0,4.0,2.0,2.0,X (Twitter),10.0 +U268,32,Male,4.5,7.0,6.0,1.0,2.0,LinkedIn,8.0 +U269,48,Female,6.9,5.0,9.0,4.0,1.0,Facebook,6.0 +U270,24,Female,5.8,8.0,6.0,1.0,2.0,YouTube,10.0 +U271,44,Female,6.5,5.0,6.0,2.0,3.0,Instagram,8.0 +U272,41,Male,3.2,8.0,6.0,0.0,4.0,X (Twitter),10.0 +U273,40,Female,8.0,5.0,8.0,3.0,3.0,TikTok,6.0 +U274,39,Male,6.9,5.0,7.0,3.0,0.0,YouTube,9.0 +U275,28,Female,6.2,3.0,8.0,0.0,3.0,TikTok,5.0 +U276,22,Male,4.3,7.0,5.0,4.0,2.0,X (Twitter),10.0 +U277,35,Male,6.1,7.0,8.0,3.0,5.0,TikTok,9.0 +U278,16,Female,5.0,7.0,6.0,1.0,5.0,YouTube,10.0 +U279,23,Male,2.5,9.0,5.0,7.0,5.0,X (Twitter),10.0 +U280,31,Male,5.4,5.0,8.0,5.0,2.0,Instagram,7.0 +U281,29,Female,5.2,7.0,6.0,1.0,5.0,YouTube,10.0 +U282,27,Female,7.4,4.0,9.0,5.0,3.0,Facebook,5.0 +U283,38,Female,1.5,9.0,4.0,5.0,5.0,LinkedIn,10.0 +U284,30,Female,5.9,6.0,9.0,3.0,2.0,YouTube,7.0 +U285,43,Female,3.4,7.0,7.0,4.0,2.0,X (Twitter),10.0 +U286,49,Male,6.1,5.0,6.0,4.0,3.0,LinkedIn,9.0 +U287,17,Male,6.2,6.0,8.0,0.0,3.0,YouTube,7.0 +U288,47,Male,5.3,6.0,6.0,3.0,4.0,Instagram,9.0 +U289,38,Female,7.0,4.0,8.0,3.0,2.0,Instagram,6.0 +U290,37,Female,3.3,7.0,5.0,6.0,4.0,Facebook,9.0 +U291,40,Male,7.5,5.0,7.0,4.0,4.0,TikTok,9.0 
+U292,37,Male,4.2,9.0,7.0,2.0,2.0,LinkedIn,8.0 +U293,37,Other,2.7,8.0,5.0,3.0,3.0,X (Twitter),10.0 +U294,21,Male,5.0,7.0,6.0,5.0,2.0,X (Twitter),10.0 +U295,30,Female,1.7,9.0,4.0,3.0,6.0,TikTok,10.0 +U296,48,Female,6.0,5.0,7.0,2.0,0.0,Instagram,7.0 +U297,23,Female,5.0,5.0,7.0,2.0,4.0,Instagram,8.0 +U298,20,Male,7.7,4.0,8.0,5.0,0.0,YouTube,5.0 +U299,19,Female,6.7,6.0,5.0,4.0,2.0,Facebook,9.0 +U300,21,Female,3.0,8.0,5.0,3.0,2.0,LinkedIn,10.0 +U301,47,Female,4.0,8.0,4.0,3.0,5.0,Facebook,10.0 +U302,45,Female,6.4,6.0,7.0,2.0,1.0,Facebook,7.0 +U303,31,Female,7.0,5.0,6.0,3.0,4.0,YouTube,8.0 +U304,28,Female,7.7,4.0,7.0,2.0,1.0,Facebook,7.0 +U305,45,Male,7.7,6.0,9.0,2.0,5.0,TikTok,6.0 +U306,34,Female,5.3,6.0,6.0,6.0,2.0,YouTube,10.0 +U307,32,Female,6.2,5.0,7.0,1.0,3.0,TikTok,7.0 +U308,34,Male,7.8,5.0,9.0,2.0,2.0,YouTube,6.0 +U309,43,Female,5.0,5.0,6.0,0.0,4.0,Facebook,8.0 +U310,41,Female,6.2,8.0,9.0,0.0,4.0,TikTok,7.0 +U311,41,Female,4.9,7.0,6.0,3.0,3.0,TikTok,9.0 +U312,38,Male,4.5,7.0,6.0,2.0,1.0,TikTok,9.0 +U313,24,Male,5.6,5.0,8.0,7.0,1.0,YouTube,7.0 +U314,27,Other,4.3,7.0,7.0,2.0,1.0,X (Twitter),8.0 +U315,16,Female,5.3,6.0,6.0,6.0,0.0,Facebook,9.0 +U316,16,Male,3.6,7.0,5.0,5.0,4.0,TikTok,10.0 +U317,49,Female,5.2,7.0,8.0,3.0,2.0,LinkedIn,8.0 +U318,47,Female,6.3,5.0,8.0,7.0,2.0,X (Twitter),8.0 +U319,40,Female,4.4,8.0,7.0,4.0,5.0,LinkedIn,8.0 +U320,16,Female,5.5,9.0,5.0,6.0,2.0,Facebook,10.0 +U321,31,Male,5.8,6.0,8.0,6.0,1.0,Facebook,7.0 +U322,20,Male,6.3,6.0,6.0,4.0,2.0,TikTok,8.0 +U323,37,Male,5.3,7.0,7.0,5.0,1.0,TikTok,8.0 +U324,44,Female,7.5,7.0,8.0,4.0,3.0,YouTube,9.0 +U325,18,Male,1.7,9.0,3.0,1.0,2.0,X (Twitter),10.0 +U326,27,Male,10.8,2.0,9.0,3.0,2.0,X (Twitter),5.0 +U327,41,Male,3.9,6.0,4.0,5.0,2.0,TikTok,9.0 +U328,31,Female,6.2,6.0,7.0,6.0,2.0,Facebook,8.0 +U329,37,Female,3.8,7.0,4.0,2.0,3.0,TikTok,10.0 +U330,44,Male,6.7,8.0,6.0,4.0,2.0,Facebook,10.0 +U331,29,Male,6.0,8.0,7.0,3.0,2.0,Facebook,7.0 +U332,43,Female,4.3,6.0,3.0,3.0,2.0,TikTok,10.0 
+U333,20,Male,7.4,4.0,9.0,4.0,2.0,Instagram,6.0 +U334,45,Male,5.2,7.0,4.0,3.0,3.0,X (Twitter),10.0 +U335,20,Male,6.5,7.0,7.0,4.0,1.0,LinkedIn,7.0 +U336,27,Male,3.8,8.0,6.0,6.0,4.0,Facebook,10.0 +U337,31,Male,5.4,6.0,7.0,5.0,3.0,Facebook,8.0 +U338,41,Female,9.1,4.0,9.0,4.0,2.0,Instagram,7.0 +U339,41,Male,7.2,5.0,10.0,7.0,3.0,LinkedIn,7.0 +U340,36,Female,6.1,7.0,7.0,5.0,2.0,YouTube,9.0 +U341,48,Male,4.7,6.0,7.0,3.0,0.0,Instagram,8.0 +U342,45,Female,5.6,6.0,6.0,1.0,3.0,LinkedIn,7.0 +U343,38,Female,5.9,7.0,7.0,1.0,1.0,Instagram,8.0 +U344,25,Male,6.3,5.0,9.0,5.0,4.0,Facebook,8.0 +U345,20,Female,3.6,6.0,5.0,2.0,2.0,X (Twitter),10.0 +U346,49,Male,6.7,6.0,8.0,1.0,4.0,Instagram,8.0 +U347,46,Other,3.5,7.0,6.0,1.0,2.0,YouTube,10.0 +U348,25,Male,8.4,5.0,8.0,2.0,3.0,TikTok,8.0 +U349,34,Male,4.1,8.0,4.0,5.0,4.0,LinkedIn,10.0 +U350,47,Male,6.0,6.0,7.0,3.0,2.0,Instagram,8.0 +U351,16,Other,7.3,5.0,9.0,0.0,2.0,Instagram,8.0 +U352,20,Female,5.0,6.0,5.0,3.0,3.0,YouTube,9.0 +U353,19,Female,6.1,6.0,8.0,3.0,3.0,X (Twitter),9.0 +U354,31,Male,8.1,5.0,9.0,2.0,2.0,Instagram,6.0 +U355,39,Female,6.5,5.0,5.0,5.0,0.0,Instagram,10.0 +U356,31,Male,6.1,6.0,6.0,2.0,2.0,X (Twitter),8.0 +U357,17,Female,2.5,9.0,4.0,5.0,4.0,YouTube,10.0 +U358,43,Male,3.8,8.0,5.0,4.0,1.0,LinkedIn,10.0 +U359,47,Male,7.4,6.0,7.0,5.0,2.0,LinkedIn,7.0 +U360,42,Male,6.5,5.0,7.0,1.0,2.0,Instagram,8.0 +U361,35,Female,6.8,4.0,9.0,3.0,0.0,X (Twitter),7.0 +U362,39,Female,6.6,6.0,7.0,3.0,1.0,Instagram,7.0 +U363,27,Male,7.8,5.0,8.0,0.0,4.0,TikTok,6.0 +U364,48,Female,5.0,8.0,6.0,5.0,0.0,LinkedIn,10.0 +U365,48,Male,5.7,7.0,7.0,1.0,3.0,Facebook,9.0 +U366,27,Female,3.6,8.0,5.0,2.0,1.0,TikTok,10.0 +U367,18,Male,3.2,9.0,5.0,1.0,3.0,Instagram,10.0 +U368,16,Male,2.5,8.0,4.0,4.0,4.0,YouTube,10.0 +U369,48,Female,3.9,6.0,5.0,0.0,2.0,LinkedIn,10.0 +U370,25,Male,8.3,5.0,8.0,4.0,2.0,Facebook,7.0 +U371,44,Male,2.7,9.0,5.0,5.0,1.0,TikTok,10.0 +U372,28,Female,5.0,7.0,5.0,4.0,0.0,LinkedIn,9.0 +U373,27,Other,3.2,9.0,6.0,4.0,5.0,X (Twitter),10.0 
+U374,46,Male,4.4,7.0,7.0,5.0,1.0,YouTube,8.0 +U375,17,Male,5.8,6.0,7.0,1.0,5.0,TikTok,9.0 +U376,38,Female,4.2,8.0,5.0,6.0,1.0,X (Twitter),10.0 +U377,32,Male,7.5,4.0,7.0,2.0,3.0,Instagram,6.0 +U378,41,Male,8.3,4.0,10.0,3.0,1.0,X (Twitter),6.0 +U379,23,Male,1.5,9.0,3.0,2.0,3.0,TikTok,10.0 +U380,44,Female,4.3,8.0,7.0,4.0,2.0,TikTok,10.0 +U381,41,Female,5.6,6.0,7.0,2.0,2.0,Instagram,9.0 +U382,25,Male,7.5,4.0,8.0,3.0,3.0,X (Twitter),7.0 +U383,41,Male,4.7,6.0,6.0,1.0,2.0,X (Twitter),9.0 +U384,49,Male,5.5,7.0,5.0,5.0,0.0,Instagram,10.0 +U385,22,Female,3.4,8.0,6.0,4.0,4.0,X (Twitter),10.0 +U386,19,Male,2.7,8.0,5.0,2.0,1.0,LinkedIn,9.0 +U387,26,Male,8.1,5.0,7.0,2.0,3.0,Instagram,7.0 +U388,44,Female,2.2,9.0,4.0,4.0,4.0,YouTube,10.0 +U389,40,Male,9.7,4.0,10.0,2.0,2.0,TikTok,6.0 +U390,36,Female,6.8,5.0,8.0,3.0,4.0,YouTube,8.0 +U391,25,Male,6.7,7.0,7.0,5.0,3.0,LinkedIn,8.0 +U392,24,Male,6.5,7.0,6.0,1.0,1.0,LinkedIn,7.0 +U393,39,Male,5.0,5.0,6.0,4.0,5.0,Facebook,8.0 +U394,33,Female,6.2,6.0,7.0,4.0,1.0,Facebook,9.0 +U395,47,Female,5.2,7.0,7.0,3.0,2.0,Facebook,9.0 +U396,39,Female,6.5,5.0,7.0,2.0,2.0,Instagram,8.0 +U397,38,Male,5.4,8.0,9.0,5.0,3.0,X (Twitter),8.0 +U398,47,Female,3.8,8.0,4.0,2.0,4.0,LinkedIn,10.0 +U399,27,Male,3.9,7.0,5.0,0.0,4.0,YouTube,10.0 +U400,28,Male,3.7,8.0,5.0,4.0,3.0,YouTube,10.0 +U401,38,Female,7.2,5.0,8.0,1.0,2.0,YouTube,9.0 +U402,40,Male,4.4,7.0,5.0,4.0,0.0,YouTube,10.0 +U403,45,Male,2.1,10.0,5.0,6.0,1.0,Facebook,10.0 +U404,32,Male,6.2,6.0,8.0,4.0,5.0,YouTube,7.0 +U405,35,Male,8.2,4.0,9.0,5.0,0.0,X (Twitter),6.0 +U406,40,Male,6.9,3.0,8.0,7.0,1.0,Facebook,6.0 +U407,37,Female,4.4,7.0,6.0,5.0,4.0,YouTube,10.0 +U408,28,Male,2.6,10.0,3.0,4.0,3.0,Instagram,10.0 +U409,34,Female,5.3,6.0,6.0,6.0,4.0,TikTok,8.0 +U410,27,Female,5.2,4.0,7.0,7.0,3.0,TikTok,9.0 +U411,34,Male,3.8,7.0,6.0,7.0,5.0,TikTok,10.0 +U412,27,Male,6.5,7.0,7.0,2.0,3.0,Facebook,10.0 +U413,24,Male,8.8,5.0,9.0,0.0,3.0,YouTube,6.0 +U414,22,Female,5.9,5.0,6.0,2.0,1.0,LinkedIn,9.0 
+U415,43,Male,5.8,6.0,7.0,4.0,3.0,LinkedIn,9.0 +U416,29,Female,3.8,7.0,6.0,3.0,2.0,LinkedIn,8.0 +U417,46,Male,8.0,4.0,8.0,3.0,2.0,Facebook,9.0 +U418,34,Female,5.4,4.0,5.0,4.0,4.0,X (Twitter),10.0 +U419,31,Male,4.5,8.0,5.0,0.0,4.0,Instagram,10.0 +U420,20,Male,5.6,8.0,8.0,3.0,5.0,Instagram,8.0 +U421,27,Other,6.3,4.0,9.0,5.0,5.0,TikTok,6.0 +U422,40,Female,5.8,7.0,8.0,4.0,4.0,TikTok,8.0 +U423,36,Male,7.7,5.0,8.0,4.0,0.0,LinkedIn,8.0 +U424,38,Female,6.2,7.0,6.0,6.0,5.0,LinkedIn,10.0 +U425,31,Female,5.0,7.0,6.0,3.0,4.0,TikTok,10.0 +U426,29,Female,6.2,6.0,9.0,2.0,1.0,YouTube,8.0 +U427,46,Female,6.4,6.0,8.0,5.0,4.0,X (Twitter),8.0 +U428,20,Male,3.2,7.0,5.0,5.0,1.0,X (Twitter),9.0 +U429,38,Male,7.6,6.0,9.0,4.0,3.0,X (Twitter),8.0 +U430,44,Male,5.0,6.0,5.0,3.0,1.0,X (Twitter),10.0 +U431,26,Female,5.0,6.0,6.0,3.0,1.0,TikTok,10.0 +U432,33,Female,6.9,6.0,7.0,0.0,4.0,Instagram,7.0 +U433,27,Male,6.8,5.0,8.0,5.0,4.0,TikTok,7.0 +U434,24,Male,4.2,8.0,7.0,5.0,3.0,X (Twitter),7.0 +U435,25,Female,6.9,7.0,7.0,0.0,2.0,TikTok,9.0 +U436,32,Other,4.5,8.0,5.0,5.0,0.0,X (Twitter),10.0 +U437,22,Female,8.8,2.0,9.0,4.0,2.0,Facebook,7.0 +U438,28,Male,4.0,7.0,5.0,1.0,3.0,LinkedIn,9.0 +U439,24,Male,5.9,7.0,6.0,1.0,2.0,Instagram,8.0 +U440,42,Male,4.1,6.0,6.0,4.0,2.0,Facebook,9.0 +U441,17,Female,9.5,4.0,9.0,2.0,2.0,YouTube,4.0 +U442,20,Male,3.3,8.0,4.0,7.0,2.0,TikTok,10.0 +U443,44,Female,6.4,5.0,6.0,5.0,3.0,YouTube,6.0 +U444,34,Male,4.2,8.0,4.0,4.0,2.0,TikTok,10.0 +U445,23,Female,4.9,7.0,7.0,4.0,5.0,LinkedIn,7.0 +U446,16,Female,6.3,6.0,7.0,6.0,0.0,Facebook,8.0 +U447,37,Female,7.7,4.0,8.0,5.0,2.0,LinkedIn,8.0 +U448,32,Male,8.4,5.0,9.0,4.0,1.0,X (Twitter),7.0 +U449,22,Male,2.7,6.0,6.0,1.0,3.0,LinkedIn,8.0 +U450,40,Male,6.0,4.0,6.0,9.0,3.0,LinkedIn,8.0 +U451,19,Female,8.1,5.0,9.0,1.0,5.0,X (Twitter),7.0 +U452,21,Female,2.5,9.0,3.0,2.0,5.0,TikTok,10.0 +U453,46,Female,6.1,6.0,8.0,0.0,1.0,X (Twitter),6.0 +U454,34,Female,5.5,6.0,5.0,4.0,3.0,Facebook,10.0 +U455,42,Female,6.2,6.0,6.0,3.0,3.0,LinkedIn,6.0 
+U456,25,Male,4.1,6.0,5.0,5.0,1.0,X (Twitter),9.0 +U457,41,Female,2.9,8.0,3.0,0.0,3.0,LinkedIn,10.0 +U458,34,Male,2.1,9.0,4.0,5.0,2.0,Facebook,10.0 +U459,18,Male,7.3,5.0,8.0,1.0,1.0,YouTube,7.0 +U460,28,Female,7.5,6.0,8.0,3.0,2.0,Facebook,6.0 +U461,43,Male,6.0,6.0,5.0,5.0,1.0,Instagram,9.0 +U462,35,Male,4.5,7.0,8.0,4.0,4.0,Instagram,8.0 +U463,43,Female,2.6,9.0,5.0,4.0,5.0,Instagram,10.0 +U464,23,Male,6.5,5.0,10.0,5.0,3.0,LinkedIn,5.0 +U465,16,Female,3.5,8.0,5.0,4.0,3.0,TikTok,10.0 +U466,18,Male,4.2,7.0,7.0,2.0,3.0,TikTok,10.0 +U467,28,Male,4.4,7.0,7.0,3.0,4.0,TikTok,8.0 +U468,43,Male,5.1,6.0,6.0,3.0,4.0,LinkedIn,9.0 +U469,40,Male,4.1,6.0,7.0,1.0,2.0,Facebook,10.0 +U470,48,Female,5.6,6.0,6.0,5.0,3.0,TikTok,8.0 +U471,21,Male,4.4,5.0,6.0,4.0,3.0,YouTube,7.0 +U472,47,Female,6.8,5.0,4.0,3.0,1.0,Instagram,10.0 +U473,36,Female,7.4,3.0,8.0,5.0,4.0,Facebook,5.0 +U474,31,Female,5.7,7.0,5.0,3.0,5.0,Instagram,10.0 +U475,36,Female,6.3,7.0,7.0,3.0,4.0,TikTok,8.0 +U476,26,Female,9.3,3.0,8.0,2.0,1.0,Facebook,6.0 +U477,34,Male,3.4,8.0,5.0,4.0,0.0,X (Twitter),10.0 +U478,35,Male,3.3,9.0,6.0,4.0,4.0,YouTube,10.0 +U479,33,Male,6.8,6.0,8.0,2.0,1.0,TikTok,9.0 +U480,29,Female,9.0,4.0,10.0,3.0,3.0,TikTok,4.0 +U481,30,Male,5.0,7.0,8.0,2.0,5.0,YouTube,7.0 +U482,46,Female,8.6,4.0,8.0,2.0,2.0,YouTube,4.0 +U483,16,Female,5.7,6.0,6.0,0.0,2.0,X (Twitter),8.0 +U484,18,Male,4.3,7.0,7.0,3.0,3.0,TikTok,10.0 +U485,31,Female,3.7,7.0,7.0,0.0,1.0,LinkedIn,9.0 +U486,38,Male,6.6,6.0,8.0,2.0,1.0,TikTok,7.0 +U487,26,Male,3.9,8.0,5.0,2.0,3.0,Instagram,10.0 +U488,27,Female,5.6,7.0,7.0,5.0,2.0,Facebook,8.0 +U489,25,Female,8.4,5.0,9.0,1.0,3.0,Facebook,7.0 +U490,47,Female,4.9,7.0,7.0,1.0,7.0,Instagram,10.0 +U491,31,Male,7.9,5.0,8.0,2.0,3.0,Facebook,6.0 +U492,23,Male,3.3,10.0,4.0,2.0,1.0,YouTube,10.0 +U493,27,Female,4.5,5.0,7.0,6.0,5.0,Facebook,9.0 +U494,39,Female,3.0,7.0,2.0,1.0,0.0,Facebook,10.0 +U495,43,Female,5.6,8.0,6.0,2.0,0.0,Instagram,10.0 +U496,23,Male,6.9,5.0,7.0,4.0,2.0,X (Twitter),10.0 
+U497,43,Female,5.6,7.0,6.0,5.0,2.0,Facebook,9.0 +U498,41,Male,7.7,5.0,7.0,2.0,2.0,LinkedIn,8.0 +U499,23,Male,4.2,9.0,7.0,0.0,2.0,Facebook,9.0 +U500,43,Female,5.9,5.0,8.0,3.0,3.0,X (Twitter),7.0 From 334c0c5761af939006fc9b2d6d57f0d0484c45fb Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Sat, 15 Nov 2025 15:39:49 -0500 Subject: [PATCH 02/19] feat:InmemoryProject Operator.this provide a base for easy mock testing of source nodes --- CONTRIBUTING.md | 8 +- .../operators/project/projectExec.go | 37 + .../operators/project/source/custom.go | 223 +++++ .../operators/project/source/custom_test.go | 914 +++++++++++++++++- .../operators/project/source/json.go | 3 + 5 files changed, 1179 insertions(+), 6 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index cdc3be3..f705081 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,10 +20,10 @@ We use a Makefile to simplify common development tasks. All commands should be r make go-test-coverage ``` - Run test with html coverage -```bash - -go tool cover -html=coverage.out -``` + ```bash + go tool cover -func=coverage.out + go tool cover -html=coverage.out + ``` ### Rust Tests diff --git a/src/Backend/opti-sql-go/operators/project/projectExec.go b/src/Backend/opti-sql-go/operators/project/projectExec.go index 7dac6f1..259c9eb 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec.go @@ -1 +1,38 @@ package project + +import ( + "errors" + + "github.com/apache/arrow/go/v17/arrow" +) + +// handle keeping only the request columsn but make sure the schema and columns are also aligned +// returns error if a column doesnt exist +func ProjectSchemaFilterDown(schema *arrow.Schema, cols []arrow.Array, keepCols ...string) (*arrow.Schema, []arrow.Array, error) { + if len(keepCols) == 0 { + return arrow.NewSchema([]arrow.Field{}, nil), nil, errors.New("no columns passed in") + } + + // Build map: columnName -> original index + fieldIndex := 
make(map[string]int) + for i, f := range schema.Fields() { + fieldIndex[f.Name] = i + } + + newFields := make([]arrow.Field, 0, len(keepCols)) + newCols := make([]arrow.Array, 0, len(keepCols)) + + // Preserve order from keepCols, not schema order + for _, name := range keepCols { + idx, exists := fieldIndex[name] + if !exists { + return arrow.NewSchema([]arrow.Field{}, nil), []arrow.Array{}, errors.New("invalid column passed in to be pruned") + } + + newFields = append(newFields, schema.Field(idx)) + newCols = append(newCols, cols[idx]) + } + + newSchema := arrow.NewSchema(newFields, nil) + return newSchema, newCols, nil +} diff --git a/src/Backend/opti-sql-go/operators/project/source/custom.go b/src/Backend/opti-sql-go/operators/project/source/custom.go index 9a323c3..22b75b5 100644 --- a/src/Backend/opti-sql-go/operators/project/source/custom.go +++ b/src/Backend/opti-sql-go/operators/project/source/custom.go @@ -1,4 +1,227 @@ package source +import ( + "fmt" + "io" + "opti-sql-go/operators" + "opti-sql-go/operators/project" + + "github.com/apache/arrow/go/v15/arrow/memory" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + // in memory format just for the ease of testing // same as other sources, we can use structs/slices here + +// thankfully we already covered most of this in record.go +// add a couple utility functions for go types and this should be good to go +var ( + ErrInvalidInMemoryDataType = func(Type any) error { + return fmt.Errorf("%T is not a supported in memory dataType for InMemoryProjectExec", Type) + } +) + +type InMemoryProjectExec struct { + schema *arrow.Schema + columns []arrow.Array + pos uint64 + fieldToColIDx map[string]int +} + +func NewInMemoryProjectExec(names []string, columns []any) (*InMemoryProjectExec, error) { + if len(names) != len(columns) { + return nil, operators.ErrInvalidSchema("number of column names and columns do not match") + } + fields := make([]arrow.Field, 0, len(names)) + 
arrays := make([]arrow.Array, 0, len(names)) + fieldToColIDx := make(map[string]int) + // parse schema from each of the columns + for i, col := range columns { + if !supportedType(col) { + return nil, operators.ErrInvalidSchema(fmt.Sprintf("unsupported column type for column %s", names[i])) + } + field, arr, err := unpackColumm(names[i], col) + if err != nil { + return nil, ErrInvalidInMemoryDataType(col) + } + fields = append(fields, field) + arrays = append(arrays, arr) + fieldToColIDx[field.Name] = i + } + return &InMemoryProjectExec{ + schema: arrow.NewSchema(fields, nil), + columns: arrays, + fieldToColIDx: fieldToColIDx, + }, nil +} +func (ime *InMemoryProjectExec) withFields(names ...string) error { + + newSchema, cols, err := project.ProjectSchemaFilterDown(ime.schema, ime.columns, names...) + if err != nil { + return err + } + newMap := make(map[string]int) + for i, f := range newSchema.Fields() { + newMap[f.Name] = i + fmt.Printf("%s:%d", f.Name, i) + } + ime.schema = newSchema + ime.fieldToColIDx = newMap + ime.columns = cols + return nil +} +func (ime *InMemoryProjectExec) Next(n uint64) (*operators.RecordBatch, error) { + if ime.pos >= uint64(ime.columns[0].Len()) { + return nil, io.EOF // EOF + } + var currRows uint64 = 0 + outPutCols := make([]arrow.Array, len(ime.schema.Fields())) + + for i, field := range ime.schema.Fields() { + col := ime.columns[ime.fieldToColIDx[field.Name]] + colLen := uint64(col.Len()) + remaining := colLen - ime.pos + toRead := n + if remaining < n { + toRead = remaining + } + slice := array.NewSlice(col, int64(ime.pos), int64(ime.pos+toRead)) + outPutCols[i] = slice + currRows = toRead + } + ime.pos += currRows + + return &operators.RecordBatch{ + Schema: ime.schema, + Columns: outPutCols, + }, nil +} +func unpackColumm(name string, col any) (arrow.Field, arrow.Array, error) { + // need to not only build the array; but also need the schema + var field arrow.Field + field.Name = name + field.Nullable = true // default to 
nullable for now + switch colType := col.(type) { + case []int: + field.Type = arrow.PrimitiveTypes.Int64 + data := colType + b := array.NewInt64Builder(memory.DefaultAllocator) + defer b.Release() + for _, v := range data { + b.Append(int64(v)) + } + return field, b.NewArray(), nil + case []int8: + // build int8 array + field.Type = arrow.PrimitiveTypes.Int8 + data := colType + b := array.NewInt8Builder(memory.DefaultAllocator) + defer b.Release() + b.AppendValues(data, nil) + return field, b.NewArray(), nil + // build int8 array + case []int16: + field.Type = arrow.PrimitiveTypes.Int16 + data := colType + b := array.NewInt16Builder(memory.DefaultAllocator) + defer b.Release() + b.AppendValues(data, nil) + return field, b.NewArray(), nil + case []int32: + field.Type = arrow.PrimitiveTypes.Int32 + data := colType + b := array.NewInt32Builder(memory.DefaultAllocator) + defer b.Release() + b.AppendValues(data, nil) + return field, b.NewArray(), nil + // build int32 array + case []int64: + field.Type = arrow.PrimitiveTypes.Int64 + data := colType + b := array.NewInt64Builder(memory.DefaultAllocator) + defer b.Release() + b.AppendValues(data, nil) + return field, b.NewArray(), nil + case []uint: + field.Type = arrow.PrimitiveTypes.Uint64 + data := colType + b := array.NewUint64Builder(memory.DefaultAllocator) + defer b.Release() + for _, v := range data { + b.Append(uint64(v)) + } + return field, b.NewArray(), nil + case []uint8: + field.Type = arrow.PrimitiveTypes.Uint8 + data := colType + b := array.NewUint8Builder(memory.DefaultAllocator) + defer b.Release() + b.AppendValues(data, nil) + return field, b.NewArray(), nil + case []uint16: + field.Type = arrow.PrimitiveTypes.Uint16 + data := colType + b := array.NewUint16Builder(memory.DefaultAllocator) + defer b.Release() + b.AppendValues(data, nil) + return field, b.NewArray(), nil + case []uint32: + field.Type = arrow.PrimitiveTypes.Uint32 + data := colType + b := array.NewUint32Builder(memory.DefaultAllocator) + 
defer b.Release() + b.AppendValues(data, nil) + return field, b.NewArray(), nil + case []uint64: + field.Type = arrow.PrimitiveTypes.Uint64 + data := colType + b := array.NewUint64Builder(memory.DefaultAllocator) + defer b.Release() + b.AppendValues(data, nil) + return field, b.NewArray(), nil + case []float32: + field.Type = arrow.PrimitiveTypes.Float32 + data := colType + b := array.NewFloat32Builder(memory.DefaultAllocator) + defer b.Release() + b.AppendValues(data, nil) + return field, b.NewArray(), nil + case []float64: + field.Type = arrow.PrimitiveTypes.Float64 + data := colType + b := array.NewFloat64Builder(memory.DefaultAllocator) + defer b.Release() + b.AppendValues(data, nil) + return field, b.NewArray(), nil + case []string: + field.Type = arrow.BinaryTypes.String + data := colType + b := array.NewStringBuilder(memory.DefaultAllocator) + defer b.Release() + b.AppendValues(data, nil) + return field, b.NewArray(), nil + // build string array + case []bool: + field.Type = arrow.FixedWidthTypes.Boolean + data := colType + b := array.NewBooleanBuilder(memory.DefaultAllocator) + defer b.Release() + b.AppendValues(data, nil) + return field, b.NewArray(), nil + + } + return arrow.Field{}, nil, fmt.Errorf("unsupported column type for column %s", name) +} +func supportedType(col any) bool { + switch col.(type) { + case []int, []int8, []int16, []int32, []int64, + []uint, []uint8, []uint16, []uint32, []uint64, + []float32, []float64, + []string, + []bool: + return true + default: + return false + } +} diff --git a/src/Backend/opti-sql-go/operators/project/source/custom_test.go b/src/Backend/opti-sql-go/operators/project/source/custom_test.go index 53df34f..57a19ff 100644 --- a/src/Backend/opti-sql-go/operators/project/source/custom_test.go +++ b/src/Backend/opti-sql-go/operators/project/source/custom_test.go @@ -1,7 +1,917 @@ package source -import "testing" +import ( + "fmt" + "io" + "testing" -func TestCustom(t *testing.T) { + 
"github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +// generateTestColumns returns 8 column names and matching columns, +// each column containing ~10 entries for testing purposes. +func generateTestColumns() ([]string, []any) { + names := []string{ + "id", + "name", + "age", + "salary", + "is_active", + "department", + "rating", + "years_experience", + } + + columns := []any{ + []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + []string{ + "Alice", "Bob", "Charlie", "David", "Eve", + "Frank", "Grace", "Hannah", "Ivy", "Jake", + }, + []int32{28, 34, 45, 22, 31, 29, 40, 36, 50, 26}, + []float64{ + 70000.0, 82000.5, 54000.0, 91000.0, 60000.0, + 75000.0, 66000.0, 88000.0, 45000.0, 99000.0, + }, + []bool{true, false, true, true, false, false, true, true, false, true}, + []string{ + "Engineering", "HR", "Engineering", "Sales", "Finance", + "Sales", "Support", "Engineering", "HR", "Finance", + }, + []float32{4.5, 3.8, 4.2, 2.9, 5.0, 4.3, 3.7, 4.9, 4.1, 3.5}, + []int32{1, 5, 10, 2, 7, 3, 6, 12, 4, 8}, + } + + return names, columns +} + +func TestInMemoryBatchInit(t *testing.T) { // Simple passing test + names := []string{"id", "name", "age", "salary", "is_active"} + columns := []any{ + []int32{1, 2, 3, 4, 5}, + []string{"Alice", "Bob", "Charlie", "David", "Eve"}, + []int32{30, 25, 35, 28, 40}, + []float64{70000.0, 50000.0, 80000.0, 60000.0, 90000.0}, + []bool{true, false, true, true, false}, + } + projC, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Errorf("Failed to create InMemoryProjectExec: %v", err) + } + if projC.schema == nil { + t.Error("Schema is nil") + } + if projC.columns == nil { + t.Error("Columns are nil") + } + if projC.schema.NumFields() != len(names) { + t.Errorf("Schema field count mismatch: got %d, want %d", projC.schema.NumFields(), len(names)) + } + if len(projC.columns) != len(columns) { + t.Errorf("Columns count mismatch: got %d, want %d", len(projC.columns), len(columns)) + } + if 
len(projC.columns) != projC.schema.NumFields() { + t.Errorf("Columns and schema field count mismatch: got %d and %d", len(projC.columns), projC.schema.NumFields()) + } + fmt.Printf("schema: %v\n", projC.schema) +} + +// ==================== COMPREHENSIVE TESTS FOR 100% CODE COVERAGE ==================== + +// TestSupportedType tests every branch of the supportedType function +func TestSupportedType(t *testing.T) { + tests := []struct { + name string + input any + expected bool + }{ + // Supported integer types + {"[]int", []int{1, 2, 3}, true}, + {"[]int8", []int8{1, 2, 3}, true}, + {"[]int16", []int16{1, 2, 3}, true}, + {"[]int32", []int32{1, 2, 3}, true}, + {"[]int64", []int64{1, 2, 3}, true}, + + // Supported unsigned integer types + {"[]uint", []uint{1, 2, 3}, true}, + {"[]uint8", []uint8{1, 2, 3}, true}, + {"[]uint16", []uint16{1, 2, 3}, true}, + {"[]uint32", []uint32{1, 2, 3}, true}, + {"[]uint64", []uint64{1, 2, 3}, true}, + + // Supported float types + {"[]float32", []float32{1.1, 2.2, 3.3}, true}, + {"[]float64", []float64{1.1, 2.2, 3.3}, true}, + + // Supported string type + {"[]string", []string{"a", "b", "c"}, true}, + + // Supported boolean type + {"[]bool", []bool{true, false, true}, true}, + + // Unsupported types + //{"[]byte", []byte{1, 2, 3}, false}, alias for uint8 + //{"[]rune", []rune{'a', 'b', 'c'}, false}, alias for int32 + {"[]interface{}", []interface{}{1, "a", true}, false}, + {"map[string]int", map[string]int{"a": 1}, false}, + {"string", "not a slice", false}, + {"int", 123, false}, + {"struct", struct{ x int }{x: 1}, false}, + {"nil", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := supportedType(tt.input) + if result != tt.expected { + t.Errorf("supportedType(%v) = %v, expected %v", tt.name, result, tt.expected) + } + }) + } +} + +// TestUnpackColumn tests every branch of the unpackColumm function +func TestUnpackColumn(t *testing.T) { + t.Run("[]int type", func(t *testing.T) { + field, 
arr, err := unpackColumm("test_int", []int{1, 2, 3, 4, 5}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Name != "test_int" { + t.Errorf("Expected field name 'test_int', got '%s'", field.Name) + } + if field.Type != arrow.PrimitiveTypes.Int64 { + t.Errorf("Expected Int64 type, got %v", field.Type) + } + if !field.Nullable { + t.Error("Expected field to be nullable") + } + int64Arr, ok := arr.(*array.Int64) + if !ok { + t.Fatalf("Expected *array.Int64, got %T", arr) + } + if int64Arr.Len() != 5 { + t.Errorf("Expected 5 elements, got %d", int64Arr.Len()) + } + for i := 0; i < 5; i++ { + if int64Arr.Value(i) != int64(i+1) { + t.Errorf("Element %d: expected %d, got %d", i, i+1, int64Arr.Value(i)) + } + } + }) + + t.Run("[]int8 type", func(t *testing.T) { + field, arr, err := unpackColumm("test_int8", []int8{-1, 0, 1, 127}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.PrimitiveTypes.Int8 { + t.Errorf("Expected Int8 type, got %v", field.Type) + } + int8Arr, ok := arr.(*array.Int8) + if !ok { + t.Fatalf("Expected *array.Int8, got %T", arr) + } + if int8Arr.Len() != 4 { + t.Errorf("Expected 4 elements, got %d", int8Arr.Len()) + } + }) + + t.Run("[]int16 type", func(t *testing.T) { + field, arr, err := unpackColumm("test_int16", []int16{-100, 0, 100, 32767}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.PrimitiveTypes.Int16 { + t.Errorf("Expected Int16 type, got %v", field.Type) + } + int16Arr, ok := arr.(*array.Int16) + if !ok { + t.Fatalf("Expected *array.Int16, got %T", arr) + } + if int16Arr.Len() != 4 { + t.Errorf("Expected 4 elements, got %d", int16Arr.Len()) + } + }) + + t.Run("[]int32 type", func(t *testing.T) { + field, arr, err := unpackColumm("test_int32", []int32{-1000, 0, 1000, 2147483647}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.PrimitiveTypes.Int32 { + t.Errorf("Expected Int32 type, got %v", 
field.Type) + } + int32Arr, ok := arr.(*array.Int32) + if !ok { + t.Fatalf("Expected *array.Int32, got %T", arr) + } + if int32Arr.Len() != 4 { + t.Errorf("Expected 4 elements, got %d", int32Arr.Len()) + } + }) + + t.Run("[]int64 type", func(t *testing.T) { + field, arr, err := unpackColumm("test_int64", []int64{-9223372036854775808, 0, 9223372036854775807}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.PrimitiveTypes.Int64 { + t.Errorf("Expected Int64 type, got %v", field.Type) + } + int64Arr, ok := arr.(*array.Int64) + if !ok { + t.Fatalf("Expected *array.Int64, got %T", arr) + } + if int64Arr.Len() != 3 { + t.Errorf("Expected 3 elements, got %d", int64Arr.Len()) + } + }) + + t.Run("[]uint type", func(t *testing.T) { + field, arr, err := unpackColumm("test_uint", []uint{0, 1, 100, 1000}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.PrimitiveTypes.Uint64 { + t.Errorf("Expected Uint64 type, got %v", field.Type) + } + uint64Arr, ok := arr.(*array.Uint64) + if !ok { + t.Fatalf("Expected *array.Uint64, got %T", arr) + } + if uint64Arr.Len() != 4 { + t.Errorf("Expected 4 elements, got %d", uint64Arr.Len()) + } + expected := []uint64{0, 1, 100, 1000} + for i, exp := range expected { + if uint64Arr.Value(i) != exp { + t.Errorf("Element %d: expected %d, got %d", i, exp, uint64Arr.Value(i)) + } + } + }) + + t.Run("[]uint8 type", func(t *testing.T) { + field, arr, err := unpackColumm("test_uint8", []uint8{0, 1, 255}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.PrimitiveTypes.Uint8 { + t.Errorf("Expected Uint8 type, got %v", field.Type) + } + uint8Arr, ok := arr.(*array.Uint8) + if !ok { + t.Fatalf("Expected *array.Uint8, got %T", arr) + } + if uint8Arr.Len() != 3 { + t.Errorf("Expected 3 elements, got %d", uint8Arr.Len()) + } + }) + + t.Run("[]uint16 type", func(t *testing.T) { + field, arr, err := unpackColumm("test_uint16", []uint16{0, 
100, 65535}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.PrimitiveTypes.Uint16 { + t.Errorf("Expected Uint16 type, got %v", field.Type) + } + uint16Arr, ok := arr.(*array.Uint16) + if !ok { + t.Fatalf("Expected *array.Uint16, got %T", arr) + } + if uint16Arr.Len() != 3 { + t.Errorf("Expected 3 elements, got %d", uint16Arr.Len()) + } + }) + + t.Run("[]uint32 type", func(t *testing.T) { + field, arr, err := unpackColumm("test_uint32", []uint32{0, 1000, 4294967295}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.PrimitiveTypes.Uint32 { + t.Errorf("Expected Uint32 type, got %v", field.Type) + } + uint32Arr, ok := arr.(*array.Uint32) + if !ok { + t.Fatalf("Expected *array.Uint32, got %T", arr) + } + if uint32Arr.Len() != 3 { + t.Errorf("Expected 3 elements, got %d", uint32Arr.Len()) + } + }) + + t.Run("[]uint64 type", func(t *testing.T) { + field, arr, err := unpackColumm("test_uint64", []uint64{0, 1000, 18446744073709551615}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.PrimitiveTypes.Uint64 { + t.Errorf("Expected Uint64 type, got %v", field.Type) + } + uint64Arr, ok := arr.(*array.Uint64) + if !ok { + t.Fatalf("Expected *array.Uint64, got %T", arr) + } + if uint64Arr.Len() != 3 { + t.Errorf("Expected 3 elements, got %d", uint64Arr.Len()) + } + }) + + t.Run("[]float32 type", func(t *testing.T) { + field, arr, err := unpackColumm("test_float32", []float32{-1.5, 0.0, 1.5, 3.14159}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.PrimitiveTypes.Float32 { + t.Errorf("Expected Float32 type, got %v", field.Type) + } + float32Arr, ok := arr.(*array.Float32) + if !ok { + t.Fatalf("Expected *array.Float32, got %T", arr) + } + if float32Arr.Len() != 4 { + t.Errorf("Expected 4 elements, got %d", float32Arr.Len()) + } + }) + + t.Run("[]float64 type", func(t *testing.T) { + field, arr, err := 
unpackColumm("test_float64", []float64{-2.718281828, 0.0, 3.141592653589793}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.PrimitiveTypes.Float64 { + t.Errorf("Expected Float64 type, got %v", field.Type) + } + float64Arr, ok := arr.(*array.Float64) + if !ok { + t.Fatalf("Expected *array.Float64, got %T", arr) + } + if float64Arr.Len() != 3 { + t.Errorf("Expected 3 elements, got %d", float64Arr.Len()) + } + }) + + t.Run("[]string type", func(t *testing.T) { + field, arr, err := unpackColumm("test_string", []string{"hello", "world", "test", ""}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.BinaryTypes.String { + t.Errorf("Expected String type, got %v", field.Type) + } + stringArr, ok := arr.(*array.String) + if !ok { + t.Fatalf("Expected *array.String, got %T", arr) + } + if stringArr.Len() != 4 { + t.Errorf("Expected 4 elements, got %d", stringArr.Len()) + } + expected := []string{"hello", "world", "test", ""} + for i, exp := range expected { + if stringArr.Value(i) != exp { + t.Errorf("Element %d: expected '%s', got '%s'", i, exp, stringArr.Value(i)) + } + } + }) + + t.Run("[]bool type", func(t *testing.T) { + field, arr, err := unpackColumm("test_bool", []bool{true, false, true, false, true}) + if err != nil { + t.Fatalf("unpackColumm failed: %v", err) + } + if field.Type != arrow.FixedWidthTypes.Boolean { + t.Errorf("Expected Boolean type, got %v", field.Type) + } + boolArr, ok := arr.(*array.Boolean) + if !ok { + t.Fatalf("Expected *array.Boolean, got %T", arr) + } + if boolArr.Len() != 5 { + t.Errorf("Expected 5 elements, got %d", boolArr.Len()) + } + expected := []bool{true, false, true, false, true} + for i, exp := range expected { + if boolArr.Value(i) != exp { + t.Errorf("Element %d: expected %v, got %v", i, exp, boolArr.Value(i)) + } + } + }) + + t.Run("Unsupported type - default case", func(t *testing.T) { + _, _, err := unpackColumm("test_unsupported", []byte{1, 2, 
3}) + if err != nil { + t.Error("unexpected error for unsupported type") + } + + }) + + t.Run("Empty slices", func(t *testing.T) { + field, arr, err := unpackColumm("empty_int", []int{}) + if err != nil { + t.Fatalf("unpackColumm failed for empty slice: %v", err) + } + if arr.Len() != 0 { + t.Errorf("Expected 0 elements for empty slice, got %d", arr.Len()) + } + if field.Name != "empty_int" { + t.Errorf("Expected field name 'empty_int', got '%s'", field.Name) + } + }) +} + +// TestNewInMemoryProjectExec tests the constructor comprehensively +func TestNewInMemoryProjectExec(t *testing.T) { + t.Run("Valid construction with all types", func(t *testing.T) { + names := []string{ + "col_int", "col_int8", "col_int16", "col_int32", "col_int64", + "col_uint", "col_uint8", "col_uint16", "col_uint32", "col_uint64", + "col_float32", "col_float64", "col_string", "col_bool", + } + columns := []any{ + []int{1, 2}, + []int8{1, 2}, + []int16{1, 2}, + []int32{1, 2}, + []int64{1, 2}, + []uint{1, 2}, + []uint8{1, 2}, + []uint16{1, 2}, + []uint32{1, 2}, + []uint64{1, 2}, + []float32{1.1, 2.2}, + []float64{1.1, 2.2}, + []string{"a", "b"}, + []bool{true, false}, + } + + proj, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Fatalf("NewInMemoryProjectExec failed: %v", err) + } + + if proj == nil { + t.Fatal("InMemoryProjectExec is nil") + } + if proj.schema == nil { + t.Fatal("Schema is nil") + } + if proj.columns == nil { + t.Fatal("Columns are nil") + } + if proj.schema.NumFields() != len(names) { + t.Errorf("Expected %d fields, got %d", len(names), proj.schema.NumFields()) + } + if len(proj.columns) != len(columns) { + t.Errorf("Expected %d columns, got %d", len(columns), len(proj.columns)) + } + + // Verify each field name matches + fields := proj.schema.Fields() + for i, expectedName := range names { + if fields[i].Name != expectedName { + t.Errorf("Field %d: expected name '%s', got '%s'", i, expectedName, fields[i].Name) + } + if !fields[i].Nullable { + 
t.Errorf("Field %d (%s): expected nullable=true", i, expectedName) + } + } + + // Verify each column has correct length + for i, col := range proj.columns { + if col.Len() != 2 { + t.Errorf("Column %d: expected length 2, got %d", i, col.Len()) + } + } + }) + + t.Run("Mismatched names and columns count", func(t *testing.T) { + names := []string{"col1", "col2"} + columns := []any{[]int{1, 2, 3}} + + _, err := NewInMemoryProjectExec(names, columns) + if err == nil { + t.Error("Expected error for mismatched names and columns, got nil") + } + }) + + t.Run("Unsupported type - supportedType returns false", func(t *testing.T) { + names := []string{"col1"} + columns := []any{[]byte{1, 2, 3}} // byte slice is not supported + + _, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Error("unexpected error for unsupported type in NewInMemoryProjectExec") + } + }) + + t.Run("Single column", func(t *testing.T) { + names := []string{"only_col"} + columns := []any{[]int{10, 20, 30}} + + proj, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Fatalf("NewInMemoryProjectExec failed: %v", err) + } + + if proj.schema.NumFields() != 1 { + t.Errorf("Expected 1 field, got %d", proj.schema.NumFields()) + } + if len(proj.columns) != 1 { + t.Errorf("Expected 1 column, got %d", len(proj.columns)) + } + if proj.columns[0].Len() != 3 { + t.Errorf("Expected column length 3, got %d", proj.columns[0].Len()) + } + }) + + t.Run("Empty columns", func(t *testing.T) { + names := []string{} + columns := []any{} + + proj, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Fatalf("NewInMemoryProjectExec failed for empty input: %v", err) + } + + if proj.schema.NumFields() != 0 { + t.Errorf("Expected 0 fields, got %d", proj.schema.NumFields()) + } + if len(proj.columns) != 0 { + t.Errorf("Expected 0 columns, got %d", len(proj.columns)) + } + }) + + t.Run("Columns with different lengths - valid construction", func(t *testing.T) { + // Note: The function 
doesn't validate that all columns have the same length + // This is valid construction even though columns have different lengths + names := []string{"col1", "col2"} + columns := []any{ + []int{1, 2, 3}, + []string{"a", "b"}, + } + + proj, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Fatalf("NewInMemoryProjectExec failed: %v", err) + } + + if proj.columns[0].Len() != 3 { + t.Errorf("Column 0: expected length 3, got %d", proj.columns[0].Len()) + } + if proj.columns[1].Len() != 2 { + t.Errorf("Column 1: expected length 2, got %d", proj.columns[1].Len()) + } + }) + + t.Run("Complex field names", func(t *testing.T) { + names := []string{"Column_1", "column-2", "Column.3", "column 4"} + columns := []any{ + []int{1}, + []int{2}, + []int{3}, + []int{4}, + } + + proj, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Fatalf("NewInMemoryProjectExec failed: %v", err) + } + + fields := proj.schema.Fields() + for i, expectedName := range names { + if fields[i].Name != expectedName { + t.Errorf("Field %d: expected name '%s', got '%s'", i, expectedName, fields[i].Name) + } + } + }) +} + +// TestErrInvalidInMemoryDataType tests the error constructor +func TestErrInvalidInMemoryDataType(t *testing.T) { + testType := []byte{1, 2, 3} + err := ErrInvalidInMemoryDataType(testType) + + if err == nil { + t.Fatal("ErrInvalidInMemoryDataType returned nil") + } + + expectedMsg := "[]uint8 is not a supported in memory dataType for InMemoryProjectExec" + if err.Error() != expectedMsg { + t.Errorf("Expected error message '%s', got '%s'", expectedMsg, err.Error()) + } + + // Test with different type + testType2 := map[string]int{"key": 1} + err2 := ErrInvalidInMemoryDataType(testType2) + expectedMsg2 := "map[string]int is not a supported in memory dataType for InMemoryProjectExec" + if err2.Error() != expectedMsg2 { + t.Errorf("Expected error message '%s', got '%s'", expectedMsg2, err2.Error()) + } +} + +// TestSchemaFieldTypes verifies the correct 
Arrow types are assigned +func TestSchemaFieldTypes(t *testing.T) { + names := []string{ + "int", "int8", "int16", "int32", "int64", + "uint", "uint8", "uint16", "uint32", "uint64", + "float32", "float64", "string", "bool", + } + columns := []any{ + []int{1}, []int8{1}, []int16{1}, []int32{1}, []int64{1}, + []uint{1}, []uint8{1}, []uint16{1}, []uint32{1}, []uint64{1}, + []float32{1.0}, []float64{1.0}, []string{"a"}, []bool{true}, + } + + expectedTypes := []arrow.DataType{ + arrow.PrimitiveTypes.Int64, // []int -> Int64 + arrow.PrimitiveTypes.Int8, // []int8 -> Int8 + arrow.PrimitiveTypes.Int16, // []int16 -> Int16 + arrow.PrimitiveTypes.Int32, // []int32 -> Int32 + arrow.PrimitiveTypes.Int64, // []int64 -> Int64 + arrow.PrimitiveTypes.Uint64, // []uint -> Uint64 + arrow.PrimitiveTypes.Uint8, // []uint8 -> Uint8 + arrow.PrimitiveTypes.Uint16, // []uint16 -> Uint16 + arrow.PrimitiveTypes.Uint32, // []uint32 -> Uint32 + arrow.PrimitiveTypes.Uint64, // []uint64 -> Uint64 + arrow.PrimitiveTypes.Float32, // []float32 -> Float32 + arrow.PrimitiveTypes.Float64, // []float64 -> Float64 + arrow.BinaryTypes.String, // []string -> String + arrow.FixedWidthTypes.Boolean, // []bool -> Boolean + } + + proj, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Fatalf("NewInMemoryProjectExec failed: %v", err) + } + + fields := proj.schema.Fields() + for i, expectedType := range expectedTypes { + if fields[i].Type != expectedType { + t.Errorf("Field %d (%s): expected type %v, got %v", + i, names[i], expectedType, fields[i].Type) + } + } +} + +func TestPrunceSchema(t *testing.T) { + names, columns := generateTestColumns() + + t.Run("Select subset of fields", func(t *testing.T) { + // Create a fresh instance for this test + testProj, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Fatalf("NewInMemoryProjectExec failed: %v", err) + } + + // Original schema should have 8 fields + originalFieldCount := testProj.schema.NumFields() + if originalFieldCount 
!= 8 { + t.Errorf("Expected 8 fields in original schema, got %d", originalFieldCount) + } + + // Select only a subset of fields + selectedFields := []string{"id", "name", "salary"} + err = testProj.withFields(selectedFields...) + if err != nil { + t.Error("unexpected error when pruning columns") + } + + // After pruning, schema should have only 3 fields + prunedFieldCount := testProj.schema.NumFields() + if prunedFieldCount != 3 { + t.Errorf("Expected 3 fields after pruning, got %d", prunedFieldCount) + } + + // Verify the field names match + fields := testProj.schema.Fields() + for i, expectedName := range selectedFields { + if fields[i].Name != expectedName { + t.Errorf("Field %d: expected name '%s', got '%s'", i, expectedName, fields[i].Name) + } + } + + // Verify field order is preserved + if fields[0].Name != "id" || fields[1].Name != "name" || fields[2].Name != "salary" { + t.Error("Field order not preserved after pruning") + } + }) + + t.Run("Select single field", func(t *testing.T) { + // Create a fresh instance for this test + testProj, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Fatalf("NewInMemoryProjectExec failed: %v", err) + } + + // Select only one field + err = testProj.withFields("department") + if err != nil { + t.Error("unexpected error when pruning columns") + } + + // After pruning, schema should have only 1 field + prunedFieldCount := testProj.schema.NumFields() + if prunedFieldCount != 1 { + t.Errorf("Expected 1 field after pruning, got %d", prunedFieldCount) + } + + // Verify the field name + fields := testProj.schema.Fields() + if fields[0].Name != "department" { + t.Errorf("Expected field name 'department', got '%s'", fields[0].Name) + } + + // Verify the field type is preserved (should be String since department is []string) + if fields[0].Type != arrow.BinaryTypes.String { + t.Errorf("Expected String type, got %v", fields[0].Type) + } + }) +} + +// TestNext tests the Next function with projection and iteration 
+func TestNext(t *testing.T) { + t.Run("Read all data in single batch", func(t *testing.T) { + names, columns := generateTestColumns() + proj, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Fatalf("NewInMemoryProjectExec failed: %v", err) + } + + // Read all 10 rows in one batch + batch, err := proj.Next(100) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + + if batch == nil { + t.Fatal("Expected batch, got nil") + } + + // Verify we got all 10 rows + if len(batch.Columns) != 8 { + t.Errorf("Expected 8 columns, got %d", len(batch.Columns)) + } + if batch.Columns[0].Len() != 10 { + t.Errorf("Expected 10 rows, got %d", batch.Columns[0].Len()) + } + + // Next call should return EOF + _, err = proj.Next(1) + if err != io.EOF { + t.Errorf("Expected EOF after reading all data, got: %v", err) + } + }) + + t.Run("Read with projection and iterate to EOF", func(t *testing.T) { + names, columns := generateTestColumns() + proj, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Fatalf("NewInMemoryProjectExec failed: %v", err) + } + + // Project to only 3 columns + err = proj.withFields("id", "name", "salary") + if err != nil { + t.Error("unexpected error when pruning columns") + } + totalRowsRead := 0 + batchCount := 0 + + // Iterate until EOF + for { + batch, err := proj.Next(3) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Next failed on batch %d: %v", batchCount+1, err) + } + + batchCount++ + totalRowsRead += batch.Columns[0].Len() + + // Verify projected schema has only 3 fields + if len(batch.Columns) != 3 { + t.Errorf("Batch %d: expected 3 columns after projection, got %d", batchCount, len(batch.Columns)) + } + + // Verify field names + fields := batch.Schema.Fields() + expectedNames := []string{"id", "name", "salary"} + for i, expectedName := range expectedNames { + if fields[i].Name != expectedName { + t.Errorf("Batch %d, Field %d: expected '%s', got '%s'", batchCount, i, expectedName, fields[i].Name) + } 
+ } + } + + // Verify we read all 10 rows total + if totalRowsRead != 10 { + t.Errorf("Expected to read 10 total rows, got %d", totalRowsRead) + } + + // Verify we got 4 batches (3+3+3+1) + if batchCount != 4 { + t.Errorf("Expected 4 batches, got %d", batchCount) + } + }) + + t.Run("Multiple Next calls with small batch size", func(t *testing.T) { + names, columns := generateTestColumns() + proj, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Fatalf("NewInMemoryProjectExec failed: %v", err) + } + + // Project to 2 columns + err = proj.withFields("age", "is_active") + if err != nil { + t.Error("unexpected error when pruning columns") + } + + // Read 2 rows at a time + batch1, err := proj.Next(2) + if err != nil { + t.Fatalf("First Next failed: %v", err) + } + if batch1.Columns[0].Len() != 2 { + t.Errorf("First batch: expected 2 rows, got %d", batch1.Columns[0].Len()) + } + + batch2, err := proj.Next(2) + if err != nil { + t.Fatalf("Second Next failed: %v", err) + } + if batch2.Columns[0].Len() != 2 { + t.Errorf("Second batch: expected 2 rows, got %d", batch2.Columns[0].Len()) + } + + // Continue reading until EOF + rowsRemaining := 0 + for { + batch, err := proj.Next(2) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Next failed: %v", err) + } + rowsRemaining += batch.Columns[0].Len() + } + + // We read 4 rows in first two batches, so 6 should remain + if rowsRemaining != 6 { + t.Errorf("Expected 6 remaining rows, got %d", rowsRemaining) + } + }) + + t.Run("Single field projection with iteration", func(t *testing.T) { + names, columns := generateTestColumns() + proj, err := NewInMemoryProjectExec(names, columns) + if err != nil { + t.Fatalf("NewInMemoryProjectExec failed: %v", err) + } + + // Project to just the department column + err = proj.withFields("department") + if err != nil { + t.Error("unexpected error when pruning columns") + } + fmt.Printf("updated: %s\n", proj.schema) + fmt.Printf("new Mapping: %v\n", 
proj.fieldToColIDx) + fmt.Printf("new columns: %v\n", proj.columns) + + totalRows := 0 + for { + batch, err := proj.Next(5) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Next failed: %v", err) + } + fmt.Printf("Batche schema: %v\n", batch.Schema) + fmt.Printf("Batch data: %v\n", batch.Columns) + + // Verify only 1 column + if len(batch.Columns) != 1 { + t.Errorf("Expected 1 column, got %d", len(batch.Columns)) + } + + // Verify it's a string array + if _, ok := batch.Columns[0].(*array.String); !ok { + t.Errorf("Expected *array.String, got %T", batch.Columns[0]) + } + + totalRows += batch.Columns[0].Len() + } + + if totalRows != 10 { + t.Errorf("Expected 10 total rows, got %d", totalRows) + } + }) } diff --git a/src/Backend/opti-sql-go/operators/project/source/json.go b/src/Backend/opti-sql-go/operators/project/source/json.go index d150341..1c0d24f 100644 --- a/src/Backend/opti-sql-go/operators/project/source/json.go +++ b/src/Backend/opti-sql-go/operators/project/source/json.go @@ -1 +1,4 @@ package source + +// take entire json file and rewrite it as parquet file and then use +// ProjectParquet source to read it From 7159b86b3b4910a1b4cd5bb78c0c106a160a56e0 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Sat, 15 Nov 2025 16:11:24 -0500 Subject: [PATCH 03/19] removed json as datasource --- CONTRIBUTING.md | 2 +- src/Backend/opti-sql-go/config/config.go | 12 ++++++++++++ .../operators/project/source/custom_test.go | 13 ++++++++++--- .../opti-sql-go/operators/project/source/json.go | 4 ---- .../operators/project/source/json_test.go | 7 ------- 5 files changed, 23 insertions(+), 15 deletions(-) delete mode 100644 src/Backend/opti-sql-go/operators/project/source/json.go delete mode 100644 src/Backend/opti-sql-go/operators/project/source/json_test.go diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f705081..d4f60d4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -21,7 +21,7 @@ We use a Makefile to simplify common development tasks. 
All commands should be r ``` - Run test with html coverage ```bash - go tool cover -func=coverage.out + go test ./... -coverprofile=coverage.out go tool cover -html=coverage.out ``` diff --git a/src/Backend/opti-sql-go/config/config.go b/src/Backend/opti-sql-go/config/config.go index d5943e6..5b976c9 100644 --- a/src/Backend/opti-sql-go/config/config.go +++ b/src/Backend/opti-sql-go/config/config.go @@ -32,6 +32,8 @@ type batchConfig struct { EnableParallelRead bool `yaml:"enable_parallel_read"` MaxMemoryBeforeSpill uint64 `yaml:"max_memory_before_spill"` MaxFileSizeMB int `yaml:"max_file_size_mb"` // max size of a single file + ShouldDowndload bool `yaml:"should_download"` + MaxDownloadSizeMB int `yaml:"max_download_size_mb"` // max size to download from external sources like S3 } type queryConfig struct { // should results be cached, server side? if so how long @@ -64,6 +66,10 @@ var configInstance *Config = &Config{ EnableParallelRead: true, MaxMemoryBeforeSpill: uint64(gigaByte) * 2, // 2GB MaxFileSizeMB: 500, // 500MB + // should we download files from external sources like S3 + // if so whats the max size to download, if its greater than dont download the file locally + ShouldDowndload: true, + MaxDownloadSizeMB: 10, // 10MB }, Query: queryConfig{ EnableCache: true, @@ -138,6 +144,12 @@ func mergeConfig(dst *Config, src map[string]interface{}) { if v, ok := batch["max_file_size_mb"].(int); ok { dst.Batch.MaxFileSizeMB = v } + if v, ok := batch["should_download"].(bool); ok { + dst.Batch.ShouldDowndload = v + } + if v, ok := batch["max_download_size_mb"].(int); ok { + dst.Batch.MaxDownloadSizeMB = v + } } // ============================= diff --git a/src/Backend/opti-sql-go/operators/project/source/custom_test.go b/src/Backend/opti-sql-go/operators/project/source/custom_test.go index 57a19ff..1af0bf5 100644 --- a/src/Backend/opti-sql-go/operators/project/source/custom_test.go +++ b/src/Backend/opti-sql-go/operators/project/source/custom_test.go @@ -499,13 
+499,20 @@ func TestNewInMemoryProjectExec(t *testing.T) { }) t.Run("Unsupported type - supportedType returns false", func(t *testing.T) { + // Custom struct type is not supported + type CustomStruct struct { + ID int + Name string + } + names := []string{"col1"} - columns := []any{[]byte{1, 2, 3}} // byte slice is not supported + columns := []any{[]CustomStruct{{1, "test"}, {2, "data"}}} _, err := NewInMemoryProjectExec(names, columns) - if err != nil { - t.Error("unexpected error for unsupported type in NewInMemoryProjectExec") + if err == nil { + t.Error("Expected error for unsupported type, got nil") } + }) t.Run("Single column", func(t *testing.T) { diff --git a/src/Backend/opti-sql-go/operators/project/source/json.go b/src/Backend/opti-sql-go/operators/project/source/json.go deleted file mode 100644 index 1c0d24f..0000000 --- a/src/Backend/opti-sql-go/operators/project/source/json.go +++ /dev/null @@ -1,4 +0,0 @@ -package source - -// take entire json file and rewrite it as parquet file and then use -// ProjectParquet source to read it diff --git a/src/Backend/opti-sql-go/operators/project/source/json_test.go b/src/Backend/opti-sql-go/operators/project/source/json_test.go deleted file mode 100644 index 482bb43..0000000 --- a/src/Backend/opti-sql-go/operators/project/source/json_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package source - -import "testing" - -func TestJson(t *testing.T) { - // Simple passing test -} From a9de3dbb04ae26bf0d19d2b3d51082a10046cc66 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Sun, 16 Nov 2025 01:33:20 -0500 Subject: [PATCH 04/19] feature:ParquetSource operator is implemented --- .gitignore | 5 +- src/Backend/opti-sql-go/config/config.go | 5 +- src/Backend/opti-sql-go/go.mod | 2 + src/Backend/opti-sql-go/go.sum | 2 + .../opti-sql-go/operators/filter/filter.go | 25 +- .../operators/project/source/csv.go | 55 +- .../operators/project/source/csv_test.go | 24 + .../operators/project/source/custom.go | 35 +- 
.../operators/project/source/parquet.go | 383 +++++++++ .../operators/project/source/parquet_test.go | 724 +++++++++++++++++- .../source/{s3_test.go => source_test.go} | 1 + src/Backend/opti-sql-go/operators/record.go | 7 +- .../test_data/parquet/capitals_clean.parquet | Bin 0 -> 12340 bytes 13 files changed, 1212 insertions(+), 56 deletions(-) rename src/Backend/opti-sql-go/operators/project/source/{s3_test.go => source_test.go} (55%) create mode 100644 src/Backend/test_data/parquet/capitals_clean.parquet diff --git a/.gitignore b/.gitignore index b8fbd54..3970ccd 100644 --- a/.gitignore +++ b/.gitignore @@ -106,4 +106,7 @@ src/Backend/test_data/json # Allow a specific CSV dataset that we want tracked despite the general csv ignores !src/Backend/test_data/csv/ -!src/Backend/test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv \ No newline at end of file +!src/Backend/test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv +# allow parquet file +!src/Backend/test_data/parquet/ +!src/Backend/test_data/parquet/capitals_clean.parquet \ No newline at end of file diff --git a/src/Backend/opti-sql-go/config/config.go b/src/Backend/opti-sql-go/config/config.go index 5b976c9..ba02bcf 100644 --- a/src/Backend/opti-sql-go/config/config.go +++ b/src/Backend/opti-sql-go/config/config.go @@ -32,8 +32,9 @@ type batchConfig struct { EnableParallelRead bool `yaml:"enable_parallel_read"` MaxMemoryBeforeSpill uint64 `yaml:"max_memory_before_spill"` MaxFileSizeMB int `yaml:"max_file_size_mb"` // max size of a single file - ShouldDowndload bool `yaml:"should_download"` - MaxDownloadSizeMB int `yaml:"max_download_size_mb"` // max size to download from external sources like S3 + // TODO: add test for these two fields, just add to existing test + ShouldDowndload bool `yaml:"should_download"` + MaxDownloadSizeMB int `yaml:"max_download_size_mb"` // max size to download from external sources like S3 } type queryConfig struct { // should results be cached, server side? 
if so how long diff --git a/src/Backend/opti-sql-go/go.mod b/src/Backend/opti-sql-go/go.mod index 426538f..5012de3 100644 --- a/src/Backend/opti-sql-go/go.mod +++ b/src/Backend/opti-sql-go/go.mod @@ -11,6 +11,7 @@ require ( ) require ( + github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/apache/thrift v0.20.0 // indirect github.com/goccy/go-json v0.10.3 // indirect @@ -21,6 +22,7 @@ require ( github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 // indirect github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect + github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect golang.org/x/mod v0.18.0 // indirect diff --git a/src/Backend/opti-sql-go/go.sum b/src/Backend/opti-sql-go/go.sum index d8d9111..afb9b78 100644 --- a/src/Backend/opti-sql-go/go.sum +++ b/src/Backend/opti-sql-go/go.sum @@ -1,3 +1,5 @@ +github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU= +github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/apache/arrow/go/v15 v15.0.2 h1:60IliRbiyTWCWjERBCkO1W4Qun9svcYoZrSLcyOsMLE= diff --git a/src/Backend/opti-sql-go/operators/filter/filter.go b/src/Backend/opti-sql-go/operators/filter/filter.go index 32326a3..0fc5e5c 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter.go +++ b/src/Backend/opti-sql-go/operators/filter/filter.go @@ -1,7 +1,24 @@ package filter -// handle Bitwise operations here as well +import ( + "github.com/apache/arrow/go/v17/arrow" + 
"github.com/apache/arrow/go/v17/arrow/array" +) -// OR -// AND -// NOT +// FilterExpr takes in a field and column and yields a function that takes in an index and returns a bool indicating whether the row at that index satisfies the filter condition. +type FilterExpr func(filed arrow.Field, col arrow.Array) func(i int) bool + +// example +func ExampleFilterExpr(field arrow.Field, col arrow.Array) func(i int) bool { + { + if field.Name == "age" && col.DataType().ID() == arrow.INT32 { + return func(i int) bool { + val := col.(*array.Int32).Value(i) + return val > 30 + } + } + return func(i int) bool { + return true + } + } +} diff --git a/src/Backend/opti-sql-go/operators/project/source/csv.go b/src/Backend/opti-sql-go/operators/project/source/csv.go index 5780f97..b5c48df 100644 --- a/src/Backend/opti-sql-go/operators/project/source/csv.go +++ b/src/Backend/opti-sql-go/operators/project/source/csv.go @@ -13,7 +13,9 @@ import ( "github.com/apache/arrow/go/v17/arrow/array" ) -type ProjectCSVLeaf struct { +// TODO: change the leaf stuff to be called scans instead + +type CSVSource struct { r *csv.Reader schema *arrow.Schema // columns to project as well as types to cast to colPosition map[string]int @@ -22,9 +24,9 @@ type ProjectCSVLeaf struct { } // assume everything is on disk for now -func NewProjectCSVLeaf(source io.Reader) (*ProjectCSVLeaf, error) { +func NewProjectCSVLeaf(source io.Reader) (*CSVSource, error) { r := csv.NewReader(source) - proj := &ProjectCSVLeaf{ + proj := &CSVSource{ r: r, colPosition: make(map[string]int), } @@ -34,31 +36,32 @@ func NewProjectCSVLeaf(source io.Reader) (*ProjectCSVLeaf, error) { return proj, err } -func (pcsv *ProjectCSVLeaf) Next(n uint64) (*operators.RecordBatch, error) { - if pcsv.done { +func (csvS *CSVSource) Next(n uint64) (*operators.RecordBatch, error) { + if csvS.done { return nil, io.EOF } // 1. 
Create builders - builders := pcsv.initBuilders() + builders := csvS.initBuilders() rowsRead := uint64(0) // Process stored first row (from parseHeader) --- - if pcsv.firstDataRow != nil && rowsRead < n { - if err := pcsv.processRow(pcsv.firstDataRow, builders); err != nil { + if csvS.firstDataRow != nil && rowsRead < n { + fmt.Printf("First row: %v\n", csvS.firstDataRow) + if err := csvS.processRow(csvS.firstDataRow, builders); err != nil { return nil, err } - pcsv.firstDataRow = nil // consume it once + csvS.firstDataRow = nil // consume it once rowsRead++ } // Stream remaining rows from CSV reader --- for rowsRead < n { - row, err := pcsv.r.Read() + row, err := csvS.r.Read() if err == io.EOF { if rowsRead == 0 { - pcsv.done = true + csvS.done = true return nil, io.EOF } break @@ -68,7 +71,7 @@ func (pcsv *ProjectCSVLeaf) Next(n uint64) (*operators.RecordBatch, error) { } // append to builders - if err := pcsv.processRow(row, builders); err != nil { + if err := csvS.processRow(row, builders); err != nil { return nil, err } @@ -76,16 +79,16 @@ func (pcsv *ProjectCSVLeaf) Next(n uint64) (*operators.RecordBatch, error) { } // Freeze into Arrow arrays - columns := pcsv.finalizeBuilders(builders) + columns := csvS.finalizeBuilders(builders) return &operators.RecordBatch{ - Schema: pcsv.schema, + Schema: csvS.schema, Columns: columns, }, nil } -func (pcsv *ProjectCSVLeaf) initBuilders() []array.Builder { - fields := pcsv.schema.Fields() +func (csvS *CSVSource) initBuilders() []array.Builder { + fields := csvS.schema.Fields() builders := make([]array.Builder, len(fields)) for i, f := range fields { @@ -94,14 +97,14 @@ func (pcsv *ProjectCSVLeaf) initBuilders() []array.Builder { return builders } -func (pcsv *ProjectCSVLeaf) processRow( +func (csvS *CSVSource) processRow( content []string, builders []array.Builder, ) error { - fields := pcsv.schema.Fields() - + fields := csvS.schema.Fields() + fmt.Printf("content : %v\n", content) for i, f := range fields { - colIdx := 
pcsv.colPosition[f.Name] + colIdx := csvS.colPosition[f.Name] cell := content[colIdx] switch b := builders[i].(type) { @@ -143,7 +146,7 @@ func (pcsv *ProjectCSVLeaf) processRow( return nil } -func (pcsv *ProjectCSVLeaf) finalizeBuilders(builders []array.Builder) []arrow.Array { +func (csvS *CSVSource) finalizeBuilders(builders []array.Builder) []arrow.Array { columns := make([]arrow.Array, len(builders)) for i, b := range builders { @@ -155,16 +158,16 @@ func (pcsv *ProjectCSVLeaf) finalizeBuilders(builders []array.Builder) []arrow.A } // first call to csv.Reader -func (pscv *ProjectCSVLeaf) parseHeader() (*arrow.Schema, error) { - header, err := pscv.r.Read() +func (csvS *CSVSource) parseHeader() (*arrow.Schema, error) { + header, err := csvS.r.Read() if err != nil { return nil, err } - firstDataRow, err := pscv.r.Read() + firstDataRow, err := csvS.r.Read() if err != nil { return nil, err } - pscv.firstDataRow = firstDataRow + csvS.firstDataRow = firstDataRow newFields := make([]arrow.Field, 0, len(header)) for i, colName := range header { sampleValue := firstDataRow[i] @@ -173,7 +176,7 @@ func (pscv *ProjectCSVLeaf) parseHeader() (*arrow.Schema, error) { Type: parseDataType(sampleValue), Nullable: true, }) - pscv.colPosition[colName] = i + csvS.colPosition[colName] = i } return arrow.NewSchema(newFields, nil), nil } diff --git a/src/Backend/opti-sql-go/operators/project/source/csv_test.go b/src/Backend/opti-sql-go/operators/project/source/csv_test.go index 3c87e32..3d08625 100644 --- a/src/Backend/opti-sql-go/operators/project/source/csv_test.go +++ b/src/Backend/opti-sql-go/operators/project/source/csv_test.go @@ -9,6 +9,7 @@ import ( "github.com/apache/arrow/go/v17/arrow" "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" ) const csvFilePath = "../../../../test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv" @@ -943,6 +944,29 @@ func TestIntegrationWithRealFile(t *testing.T) { } }) } +func 
TestProccessFirstLine(t *testing.T) { + v := getTestFile() + p, err := NewProjectCSVLeaf(v) + if err != nil { + t.Errorf("Failed to create ProjectCSVLeaf: %v", err) + } + defer func() { + if err := v.Close(); err != nil { + t.Fatalf("failed to close: %v", err) + } + }() + var builders []array.Builder + for range len(p.schema.Fields()) { + builder := array.NewBuilder(memory.DefaultAllocator, &arrow.Date64Type{}) + defer builder.Release() + builders = append(builders, builder) + } + err = p.processRow([]string{"1", "alice", "95.5", "true"}, builders) + if err == nil { + t.Errorf("Expected error for empty row, got nil") + } + +} /* func TestLargercsvFile(t *testing.T) { diff --git a/src/Backend/opti-sql-go/operators/project/source/custom.go b/src/Backend/opti-sql-go/operators/project/source/custom.go index 22b75b5..f7fa96e 100644 --- a/src/Backend/opti-sql-go/operators/project/source/custom.go +++ b/src/Backend/opti-sql-go/operators/project/source/custom.go @@ -22,14 +22,14 @@ var ( } ) -type InMemoryProjectExec struct { +type InMemorySource struct { schema *arrow.Schema columns []arrow.Array pos uint64 fieldToColIDx map[string]int } -func NewInMemoryProjectExec(names []string, columns []any) (*InMemoryProjectExec, error) { +func NewInMemoryProjectExec(names []string, columns []any) (*InMemorySource, error) { if len(names) != len(columns) { return nil, operators.ErrInvalidSchema("number of column names and columns do not match") } @@ -49,51 +49,50 @@ func NewInMemoryProjectExec(names []string, columns []any) (*InMemoryProjectExec arrays = append(arrays, arr) fieldToColIDx[field.Name] = i } - return &InMemoryProjectExec{ + return &InMemorySource{ schema: arrow.NewSchema(fields, nil), columns: arrays, fieldToColIDx: fieldToColIDx, }, nil } -func (ime *InMemoryProjectExec) withFields(names ...string) error { +func (ms *InMemorySource) withFields(names ...string) error { - newSchema, cols, err := project.ProjectSchemaFilterDown(ime.schema, ime.columns, names...) 
+ newSchema, cols, err := project.ProjectSchemaFilterDown(ms.schema, ms.columns, names...) if err != nil { return err } newMap := make(map[string]int) for i, f := range newSchema.Fields() { newMap[f.Name] = i - fmt.Printf("%s:%d", f.Name, i) } - ime.schema = newSchema - ime.fieldToColIDx = newMap - ime.columns = cols + ms.schema = newSchema + ms.fieldToColIDx = newMap + ms.columns = cols return nil } -func (ime *InMemoryProjectExec) Next(n uint64) (*operators.RecordBatch, error) { - if ime.pos >= uint64(ime.columns[0].Len()) { +func (ms *InMemorySource) Next(n uint64) (*operators.RecordBatch, error) { + if ms.pos >= uint64(ms.columns[0].Len()) { return nil, io.EOF // EOF } var currRows uint64 = 0 - outPutCols := make([]arrow.Array, len(ime.schema.Fields())) + outPutCols := make([]arrow.Array, len(ms.schema.Fields())) - for i, field := range ime.schema.Fields() { - col := ime.columns[ime.fieldToColIDx[field.Name]] + for i, field := range ms.schema.Fields() { + col := ms.columns[ms.fieldToColIDx[field.Name]] colLen := uint64(col.Len()) - remaining := colLen - ime.pos + remaining := colLen - ms.pos toRead := n if remaining < n { toRead = remaining } - slice := array.NewSlice(col, int64(ime.pos), int64(ime.pos+toRead)) + slice := array.NewSlice(col, int64(ms.pos), int64(ms.pos+toRead)) outPutCols[i] = slice currRows = toRead } - ime.pos += currRows + ms.pos += currRows return &operators.RecordBatch{ - Schema: ime.schema, + Schema: ms.schema, Columns: outPutCols, }, nil } diff --git a/src/Backend/opti-sql-go/operators/project/source/parquet.go b/src/Backend/opti-sql-go/operators/project/source/parquet.go index d150341..8486823 100644 --- a/src/Backend/opti-sql-go/operators/project/source/parquet.go +++ b/src/Backend/opti-sql-go/operators/project/source/parquet.go @@ -1 +1,384 @@ package source + +import ( + "context" + "errors" + "fmt" + "io" + "opti-sql-go/operators" + "opti-sql-go/operators/filter" + + "github.com/apache/arrow/go/v17/arrow" + 
"github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/apache/arrow/go/v17/parquet" + "github.com/apache/arrow/go/v17/parquet/file" + "github.com/apache/arrow/go/v17/parquet/pqarrow" +) + +type ParquetSource struct { + // existing fields + Schema *arrow.Schema + projectionPushDown []string // columns to project up + predicatePushDown []filter.FilterExpr // simple predicate push down for now + reader pqarrow.RecordReader + // for internal reading + done bool // if set to true always return io.EOF +} + +// source, columns you want to be pushed up the tree, any filters +func NewParquetSource(r parquet.ReaderAtSeeker, columns []string, filters []filter.FilterExpr) (*ParquetSource, error) { + if len(columns) == 0 { + return nil, errors.New("no columns were provided for projection push down") + } + allocator := memory.NewGoAllocator() + filerReader, err := file.NewParquetReader(r) + if err != nil { + return nil, err + } + + defer func() { + if err := filerReader.Close(); err != nil { + fmt.Printf("warning: failed to close parquet reader: %v\n", err) + + } + }() + + arrowReader, err := pqarrow.NewFileReader( + filerReader, + pqarrow.ArrowReadProperties{Parallel: true, BatchSize: 5}, // TODO: Read in from config for this stuff + allocator, + ) + if err != nil { + return nil, err + } + var wantedColumnsIDX []int + s, _ := arrowReader.Schema() + for _, col := range columns { + idx_array := s.FieldIndices(col) + if len(idx_array) == 0 { + return nil, errors.New("unknown column passed in to be project push down") + } + wantedColumnsIDX = append(wantedColumnsIDX, idx_array...) 
+ } + + rdr, err := arrowReader.GetRecordReader(context.TODO(), wantedColumnsIDX, nil) + if err != nil { + return nil, err + } + + return &ParquetSource{ + Schema: rdr.Schema(), + projectionPushDown: columns, + predicatePushDown: filters, + reader: rdr, + }, nil +} + +// This should be 1 +func (ps *ParquetSource) Next(n uint16) (*operators.RecordBatch, error) { + if ps.reader == nil || ps.done || !ps.reader.Next() { + return nil, io.EOF + } + columns := make([]arrow.Array, len(ps.Schema.Fields())) + curRow := 0 + for curRow < int(n) && ps.reader.Next() { + err := ps.reader.Err() + if err != nil { + return nil, err + } + record := ps.reader.Record() + numCols := int(record.NumCols()) + numRows := int(record.NumRows()) + + fmt.Printf("numCols=%d numRows=%d columns=%v\n", + numCols, numRows, record.Columns(), + ) + for colIdx := 0; colIdx < numCols; colIdx++ { + + batchCol := record.Column(colIdx) + existing := columns[colIdx] + fmt.Printf("columns:%v\n", columns) + fmt.Printf("existing:%v\n", existing) + fmt.Printf("batchCol:%v\n", batchCol) + // First time seeing this column → just assign it + if existing == nil { + batchCol.Retain() + columns[colIdx] = batchCol + continue + } + + // Otherwise combine existing + new batch column + combined := CombineArray(existing, batchCol) + + // Replace + columns[colIdx] = combined + + // VERY IMPORTANT: + // Release the old existing array to avoid leaks + existing.Release() + } + record.Release() + + curRow += numRows + } + return &operators.RecordBatch{ + Schema: ps.Schema, // Remove the pointer as ps.Schema is already of type arrow.Schema + Columns: columns, + RowCount: uint64(curRow), + }, nil +} +func (ps *ParquetSource) Close() error { + ps.reader.Release() + ps.reader = nil + return nil +} + +// append arr2 to arr1 so (arr1 + arr2) = arr1-arr2 +func CombineArray(a1, a2 arrow.Array) arrow.Array { + if a1 == nil { + return a2 + } + if a2 == nil { + return a1 + } + + mem := memory.NewGoAllocator() + dt := a1.DataType() + + 
switch dt.ID() { + + // -------------------- INT TYPES -------------------- + case arrow.INT8: + b := array.NewInt8Builder(mem) + appendInt8(b, a1.(*array.Int8)) + appendInt8(b, a2.(*array.Int8)) + return b.NewArray() + + case arrow.INT16: + b := array.NewInt16Builder(mem) + appendInt16(b, a1.(*array.Int16)) + appendInt16(b, a2.(*array.Int16)) + return b.NewArray() + + case arrow.INT32: + b := array.NewInt32Builder(mem) + appendInt32(b, a1.(*array.Int32)) + appendInt32(b, a2.(*array.Int32)) + return b.NewArray() + + case arrow.INT64: + b := array.NewInt64Builder(mem) + appendInt64(b, a1.(*array.Int64)) + appendInt64(b, a2.(*array.Int64)) + return b.NewArray() + + // -------------------- UINT TYPES -------------------- + case arrow.UINT8: + b := array.NewUint8Builder(mem) + appendUint8(b, a1.(*array.Uint8)) + appendUint8(b, a2.(*array.Uint8)) + return b.NewArray() + + case arrow.UINT16: + b := array.NewUint16Builder(mem) + appendUint16(b, a1.(*array.Uint16)) + appendUint16(b, a2.(*array.Uint16)) + return b.NewArray() + + case arrow.UINT32: + b := array.NewUint32Builder(mem) + appendUint32(b, a1.(*array.Uint32)) + appendUint32(b, a2.(*array.Uint32)) + return b.NewArray() + + case arrow.UINT64: + b := array.NewUint64Builder(mem) + appendUint64(b, a1.(*array.Uint64)) + appendUint64(b, a2.(*array.Uint64)) + return b.NewArray() + + // -------------------- FLOAT TYPES -------------------- + case arrow.FLOAT32: + b := array.NewFloat32Builder(mem) + appendFloat32(b, a1.(*array.Float32)) + appendFloat32(b, a2.(*array.Float32)) + return b.NewArray() + + case arrow.FLOAT64: + b := array.NewFloat64Builder(mem) + appendFloat64(b, a1.(*array.Float64)) + appendFloat64(b, a2.(*array.Float64)) + return b.NewArray() + + // -------------------- BOOLEAN -------------------- + case arrow.BOOL: + b := array.NewBooleanBuilder(mem) + appendBool(b, a1.(*array.Boolean)) + appendBool(b, a2.(*array.Boolean)) + return b.NewArray() + + // -------------------- STRING TYPES -------------------- + 
case arrow.STRING: + b := array.NewStringBuilder(mem) + appendString(b, a1.(*array.String)) + appendString(b, a2.(*array.String)) + return b.NewArray() + + case arrow.LARGE_STRING: + b := array.NewLargeStringBuilder(mem) + appendLargeString(b, a1.(*array.LargeString)) + appendLargeString(b, a2.(*array.LargeString)) + return b.NewArray() + + // -------------------- BINARY TYPES -------------------- + case arrow.BINARY: + b := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) + appendBinary(b, a1.(*array.Binary)) + appendBinary(b, a2.(*array.Binary)) + return b.NewArray() + + default: + panic(fmt.Sprintf("unsupported datatype in CombineArray: %v", dt)) + } +} + +func appendInt8(b *array.Int8Builder, c *array.Int8) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendInt16(b *array.Int16Builder, c *array.Int16) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendInt32(b *array.Int32Builder, c *array.Int32) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendInt64(b *array.Int64Builder, c *array.Int64) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendUint8(b *array.Uint8Builder, c *array.Uint8) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendUint16(b *array.Uint16Builder, c *array.Uint16) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendUint32(b *array.Uint32Builder, c *array.Uint32) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendUint64(b *array.Uint64Builder, c *array.Uint64) { + for i := 0; i < c.Len(); i++ { + if 
c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendFloat32(b *array.Float32Builder, c *array.Float32) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendFloat64(b *array.Float64Builder, c *array.Float64) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendBool(b *array.BooleanBuilder, c *array.Boolean) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendString(b *array.StringBuilder, c *array.String) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendLargeString(b *array.LargeStringBuilder, c *array.LargeString) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} + +func appendBinary(b *array.BinaryBuilder, c *array.Binary) { + for i := 0; i < c.Len(); i++ { + if c.IsNull(i) { + b.AppendNull() + continue + } + b.Append(c.Value(i)) + } +} diff --git a/src/Backend/opti-sql-go/operators/project/source/parquet_test.go b/src/Backend/opti-sql-go/operators/project/source/parquet_test.go index f677b80..1b5daa1 100644 --- a/src/Backend/opti-sql-go/operators/project/source/parquet_test.go +++ b/src/Backend/opti-sql-go/operators/project/source/parquet_test.go @@ -1,7 +1,725 @@ package source -import "testing" +import ( + "fmt" + "io" + "os" + "testing" -func TestParquet(t *testing.T) { - // Simple passing test + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" +) + +const ParquetTestDatafile = "../../../../test_data/parquet/capitals_clean.parquet" + +func getTestParuqetFile() *os.File { + file, err := os.Open(ParquetTestDatafile) + if err != nil { + panic(err) + } + return file +} + 
+/* +schema: + + fields: 5 + - country: type=utf8, nullable + metadata: ["PARQUET:field_id": "-1"] + - country_alpha2: type=utf8, nullable + metadata: ["PARQUET:field_id": "-1"] + - capital: type=utf8, nullable + metadata: ["PARQUET:field_id": "-1"] + - lat: type=float64, nullable + metadata: ["PARQUET:field_id": "-1"] + - lon: type=float64, nullable +*/ +// TODO: move to their own files later down the line +func existIn(str string, arr []string) bool { + for _, a := range arr { + if a == str { + return true + } + } + return false +} +func sameStringSlice(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} +func TestParquetInit(t *testing.T) { + t.Run("Test No names pass in", func(t *testing.T) { + f := getTestParuqetFile() + + _, err := NewParquetSource(f, []string{}, nil) + if err == nil { + t.Errorf("Expected error when no columns are passed in, but got nil") + } + }) + + t.Run("Test invalid names are passed in", func(t *testing.T) { + f := getTestParuqetFile() + _, err := NewParquetSource(f, []string{"non_existent_column"}, nil) + if err == nil { + t.Errorf("Expected error when invalid column names are passed in, but got nil") + } + }) + + t.Run("Test correct schema is returned", func(t *testing.T) { + f := getTestParuqetFile() + columns := []string{"country", "capital", "lat"} + source, err := NewParquetSource(f, columns, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + schema := source.Schema + if len(schema.Fields()) != len(columns) { + t.Errorf("Expected schema to have %d fields, got %d", len(columns), len(schema.Fields())) + } + for _, field := range schema.Fields() { + if !existIn(field.Name, columns) { + t.Errorf("Field %s not found in expected columns %v", field.Name, columns) + } + } + + }) + + t.Run("Test input columns and filters were passed back out", func(t *testing.T) { + f := getTestParuqetFile() + columns := []string{"country", 
"capital", "lat"} + source, err := NewParquetSource(f, columns, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if len(source.projectionPushDown) != len(columns) { + t.Errorf("Expected projectionPushDown to have %d columns, got %d", len(columns), len(source.projectionPushDown)) + } + if !sameStringSlice(source.projectionPushDown, columns) || source.predicatePushDown != nil { + t.Errorf("Expected projectionPushDown to be %v and predicatePushDown to be nil, got %v and %v", columns, source.projectionPushDown, source.predicatePushDown) + } + }) + + t.Run("Check reader isnt null", func(t *testing.T) { + + f := getTestParuqetFile() + columns := []string{"country", "capital", "lat"} + source, err := NewParquetSource(f, columns, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if source.reader == nil { + t.Errorf("Expected reader to be initialized, but got nil") + } + + }) + +} +func TestParquetClose(t *testing.T) { + f := getTestParuqetFile() + columns := []string{"country", "capital", "lat"} + source, err := NewParquetSource(f, columns, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + err = source.Close() + if err != nil { + t.Errorf("Unexpected error on Close: %v", err) + } + if source.reader != nil { + t.Errorf("Expected reader to be nil after Close, but it is not") + } + _, err = source.Next(1) + if err != io.EOF { + t.Error("expected reader to return io.EOF") + } + +} +func TestRunToEnd(t *testing.T) { + f := getTestParuqetFile() + columns := []string{"country", "capital", "lat"} + source, err := NewParquetSource(f, columns, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + for { + rc, err := source.Next(1024 * 8) + if err != nil { + if err == io.EOF { + break + } + t.Fatalf("Unexpected error on Next: %v", err) + } + t.Log("RecordBatch: ", rc) + } +} + +func TestParquetRead(t *testing.T) { + f := getTestParuqetFile() + columns := []string{"country", "capital", "lat"} + source, err := 
NewParquetSource(f, columns, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + // batchSize := uint16(10) + rc, err := source.Next(uint16(15)) + if err != nil { + t.Fatalf("Unexpected error on Next: %v", err) + } + if rc == nil { + t.Fatalf("Expected RecordBatch, got nil") + } + if len(rc.Columns) != len(columns) { + t.Errorf("Expected %d columns, got %d", len(columns), len(rc.Columns)) + } + if rc.Schema.NumFields() != len(columns) { + t.Errorf("Expected schema to have %d fields, got %d", len(columns), rc.Schema.NumFields()) + } + fmt.Printf("columns:%v\n", rc.Columns) + fmt.Printf("count:%d\n", rc.RowCount) +} + +// CombineArray tests: cover primitive, uint, float, bool, string, binary and nil-handling +func TestCombineArray_Cases(t *testing.T) { + mem := memory.NewGoAllocator() + + t.Run("INT8", func(t *testing.T) { + ib1 := array.NewInt8Builder(mem) + ib1.Append(1) + ib1.AppendNull() + a1 := ib1.NewArray().(*array.Int8) + ib2 := array.NewInt8Builder(mem) + ib2.Append(2) + ib2.Append(3) + a2 := ib2.NewArray().(*array.Int8) + comb := CombineArray(a1, a2).(*array.Int8) + if comb.Len() != a1.Len()+a2.Len() { + t.Fatalf("int8 combined length wrong") + } + if comb.Value(0) != 1 || !comb.IsNull(1) || comb.Value(2) != 2 { + t.Fatalf("int8 values unexpected") + } + a1.Release() + a2.Release() + comb.Release() + }) + + t.Run("INT16", func(t *testing.T) { + i16b1 := array.NewInt16Builder(mem) + i16b1.Append(10) + i16b1.Append(20) + ia1 := i16b1.NewArray().(*array.Int16) + i16b2 := array.NewInt16Builder(mem) + i16b2.Append(30) + ia2 := i16b2.NewArray().(*array.Int16) + i16c := CombineArray(ia1, ia2).(*array.Int16) + if i16c.Len() != ia1.Len()+ia2.Len() { + t.Fatalf("int16 combined length") + } + ia1.Release() + ia2.Release() + i16c.Release() + }) + + t.Run("INT32", func(t *testing.T) { + i32b1 := array.NewInt32Builder(mem) + i32b1.Append(1) + ia32_1 := i32b1.NewArray().(*array.Int32) + i32b2 := array.NewInt32Builder(mem) + i32b2.Append(2) + ia32_2 := 
i32b2.NewArray().(*array.Int32) + i32c := CombineArray(ia32_1, ia32_2).(*array.Int32) + if i32c.Len() != 2 { + t.Fatalf("int32 combined length") + } + ia32_1.Release() + ia32_2.Release() + i32c.Release() + }) + + t.Run("INT64", func(t *testing.T) { + i64b1 := array.NewInt64Builder(mem) + i64b1.Append(100) + ia64_1 := i64b1.NewArray().(*array.Int64) + i64b2 := array.NewInt64Builder(mem) + i64b2.Append(200) + ia64_2 := i64b2.NewArray().(*array.Int64) + i64c := CombineArray(ia64_1, ia64_2).(*array.Int64) + if i64c.Len() != 2 { + t.Fatalf("int64 combined length") + } + ia64_1.Release() + ia64_2.Release() + i64c.Release() + }) + + t.Run("UINT8", func(t *testing.T) { + u8b1 := array.NewUint8Builder(mem) + u8b1.Append(8) + ua8_1 := u8b1.NewArray().(*array.Uint8) + u8b2 := array.NewUint8Builder(mem) + u8b2.Append(9) + ua8_2 := u8b2.NewArray().(*array.Uint8) + u8c := CombineArray(ua8_1, ua8_2).(*array.Uint8) + if u8c.Len() != 2 { + t.Fatalf("uint8 combined length") + } + ua8_1.Release() + ua8_2.Release() + u8c.Release() + }) + + t.Run("UINT16", func(t *testing.T) { + u16b1 := array.NewUint16Builder(mem) + u16b1.Append(16) + ua16_1 := u16b1.NewArray().(*array.Uint16) + u16b2 := array.NewUint16Builder(mem) + u16b2.Append(32) + ua16_2 := u16b2.NewArray().(*array.Uint16) + u16c := CombineArray(ua16_1, ua16_2).(*array.Uint16) + if u16c.Len() != 2 { + t.Fatalf("uint16 combined length") + } + ua16_1.Release() + ua16_2.Release() + u16c.Release() + }) + + t.Run("UINT32", func(t *testing.T) { + u32b1 := array.NewUint32Builder(mem) + u32b1.Append(1000) + ua32_1 := u32b1.NewArray().(*array.Uint32) + u32b2 := array.NewUint32Builder(mem) + u32b2.Append(2000) + ua32_2 := u32b2.NewArray().(*array.Uint32) + u32c := CombineArray(ua32_1, ua32_2).(*array.Uint32) + if u32c.Len() != 2 { + t.Fatalf("uint32 combined length") + } + ua32_1.Release() + ua32_2.Release() + u32c.Release() + }) + + t.Run("UINT64", func(t *testing.T) { + u64b1 := array.NewUint64Builder(mem) + u64b1.Append(10000) + ua64_1 
:= u64b1.NewArray().(*array.Uint64) + u64b2 := array.NewUint64Builder(mem) + u64b2.Append(20000) + ua64_2 := u64b2.NewArray().(*array.Uint64) + u64c := CombineArray(ua64_1, ua64_2).(*array.Uint64) + if u64c.Len() != 2 { + t.Fatalf("uint64 combined length") + } + ua64_1.Release() + ua64_2.Release() + u64c.Release() + }) + + t.Run("FLOAT32", func(t *testing.T) { + f32b1 := array.NewFloat32Builder(mem) + f32b1.Append(1.25) + fa32_1 := f32b1.NewArray().(*array.Float32) + f32b2 := array.NewFloat32Builder(mem) + f32b2.Append(2.5) + fa32_2 := f32b2.NewArray().(*array.Float32) + f32c := CombineArray(fa32_1, fa32_2).(*array.Float32) + if f32c.Len() != 2 { + t.Fatalf("float32 combined length") + } + fa32_1.Release() + fa32_2.Release() + f32c.Release() + }) + + t.Run("FLOAT64", func(t *testing.T) { + f64b1 := array.NewFloat64Builder(mem) + f64b1.Append(3.14) + fa64_1 := f64b1.NewArray().(*array.Float64) + f64b2 := array.NewFloat64Builder(mem) + f64b2.Append(6.28) + fa64_2 := f64b2.NewArray().(*array.Float64) + f64c := CombineArray(fa64_1, fa64_2).(*array.Float64) + if f64c.Len() != 2 { + t.Fatalf("float64 combined length") + } + fa64_1.Release() + fa64_2.Release() + f64c.Release() + }) + + t.Run("BOOL", func(t *testing.T) { + bb1 := array.NewBooleanBuilder(mem) + bb1.Append(true) + bb1.AppendNull() + ba1 := bb1.NewArray().(*array.Boolean) + bb2 := array.NewBooleanBuilder(mem) + bb2.Append(false) + ba2 := bb2.NewArray().(*array.Boolean) + bc := CombineArray(ba1, ba2).(*array.Boolean) + if bc.Len() != ba1.Len()+ba2.Len() { + t.Fatalf("bool combined length") + } + ba1.Release() + ba2.Release() + bc.Release() + }) + + t.Run("STRING", func(t *testing.T) { + sb1 := array.NewStringBuilder(mem) + sb1.Append("one") + sb1.AppendNull() + sa1 := sb1.NewArray().(*array.String) + sb2 := array.NewStringBuilder(mem) + sb2.Append("two") + sa2 := sb2.NewArray().(*array.String) + sc := CombineArray(sa1, sa2).(*array.String) + if sc.Len() != sa1.Len()+sa2.Len() { + t.Fatalf("string combined 
length") + } + sa1.Release() + sa2.Release() + sc.Release() + }) + + t.Run("LARGE_STRING", func(t *testing.T) { + lsb1 := array.NewLargeStringBuilder(mem) + lsb1.Append("big1") + la1 := lsb1.NewArray().(*array.LargeString) + lsb2 := array.NewLargeStringBuilder(mem) + lsb2.Append("big2") + la2 := lsb2.NewArray().(*array.LargeString) + lc := CombineArray(la1, la2).(*array.LargeString) + if lc.Len() != la1.Len()+la2.Len() { + t.Fatalf("large string combined length") + } + la1.Release() + la2.Release() + lc.Release() + }) + + t.Run("BINARY", func(t *testing.T) { + bbld := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) + bbld.Append([]byte("a")) + baBb1 := bbld.NewArray().(*array.Binary) + bbld2 := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) + bbld2.Append([]byte("b")) + baBb2 := bbld2.NewArray().(*array.Binary) + bcbin := CombineArray(baBb1, baBb2).(*array.Binary) + if bcbin.Len() != baBb1.Len()+baBb2.Len() { + t.Fatalf("binary combined length") + } + baBb1.Release() + baBb2.Release() + bcbin.Release() + }) + + t.Run("NIL_A1", func(t *testing.T) { + // build a small binary array to pass as second + bbld := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) + bbld.Append([]byte("z")) + sec := bbld.NewArray().(*array.Binary) + got := CombineArray(nil, sec) + if got == nil { + t.Fatalf("expected non-nil when a1 is nil") + } + if got != sec { // CombineArray will return sec directly when a1 is nil + got.Release() + } + sec.Release() + }) + + t.Run("NIL_A2", func(t *testing.T) { + bbld := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) + bbld.Append([]byte("y")) + first := bbld.NewArray().(*array.Binary) + got := CombineArray(first, nil) + if got == nil { + t.Fatalf("expected non-nil when a2 is nil") + } + if got != first { // CombineArray will return first directly when a2 is nil + got.Release() + } + first.Release() + }) +} + +// includes null values so append* helpers take the AppendNull branch. 
+func TestCombineArray_PerTypeNulls(t *testing.T) { + mem := memory.NewGoAllocator() + + t.Run("AppendUint16_nulls", func(t *testing.T) { + b1 := array.NewUint16Builder(mem) + b1.Append(11) + b1.AppendNull() + b1.Append(13) + a1 := b1.NewArray().(*array.Uint16) + + b2 := array.NewUint16Builder(mem) + b2.AppendNull() + b2.Append(15) + a2 := b2.NewArray().(*array.Uint16) + + out := CombineArray(a1, a2).(*array.Uint16) + if out.Len() != 5 { + t.Fatalf("uint16 expected len 5 got %d", out.Len()) + } + if !out.IsNull(1) || !out.IsNull(3) { + t.Fatalf("uint16 nulls not preserved") + } + a1.Release() + a2.Release() + out.Release() + }) + + t.Run("AppendInt16_nulls", func(t *testing.T) { + b1 := array.NewInt16Builder(mem) + b1.Append(21) + b1.AppendNull() + a1 := b1.NewArray().(*array.Int16) + b2 := array.NewInt16Builder(mem) + b2.AppendNull() + b2.Append(23) + a2 := b2.NewArray().(*array.Int16) + out := CombineArray(a1, a2).(*array.Int16) + if out.Len() != 4 { + t.Fatalf("int16 expected len 4 got %d", out.Len()) + } + if !out.IsNull(1) || !out.IsNull(2) { + t.Fatalf("int16 nulls not present") + } + a1.Release() + a2.Release() + out.Release() + }) + + t.Run("AppendInt32_nulls", func(t *testing.T) { + b1 := array.NewInt32Builder(mem) + b1.Append(31) + b1.AppendNull() + a1 := b1.NewArray().(*array.Int32) + b2 := array.NewInt32Builder(mem) + b2.AppendNull() + b2.Append(33) + a2 := b2.NewArray().(*array.Int32) + out := CombineArray(a1, a2).(*array.Int32) + if !out.IsNull(1) || !out.IsNull(2) { + t.Fatalf("int32 nulls not present") + } + a1.Release() + a2.Release() + out.Release() + }) + t.Run("AppendUint32_nulls", func(t *testing.T) { + b1 := array.NewUint32Builder(mem) + b1.AppendNull() + b1.Append(22) + a1 := b1.NewArray().(*array.Uint32) + b2 := array.NewUint32Builder(mem) + b2.Append(23) + b2.AppendNull() + a2 := b2.NewArray().(*array.Uint32) + out := CombineArray(a1, a2).(*array.Uint32) + if !out.IsNull(0) || !out.IsNull(3) { + t.Fatalf("uint32 nulls not present") + } + 
a1.Release() + a2.Release() + out.Release() + }) + + t.Run("AppendInt64_nulls", func(t *testing.T) { + b1 := array.NewInt64Builder(mem) + b1.AppendNull() + b1.Append(41) + a1 := b1.NewArray().(*array.Int64) + b2 := array.NewInt64Builder(mem) + b2.Append(42) + b2.AppendNull() + a2 := b2.NewArray().(*array.Int64) + out := CombineArray(a1, a2).(*array.Int64) + if !out.IsNull(0) || !out.IsNull(3) { + t.Fatalf("int64 nulls not present") + } + a1.Release() + a2.Release() + out.Release() + }) + + t.Run("AppendUint64_nulls", func(t *testing.T) { + b1 := array.NewUint64Builder(mem) + b1.AppendNull() + b1.Append(41) + a1 := b1.NewArray().(*array.Uint64) + b2 := array.NewUint64Builder(mem) + b2.Append(42) + b2.AppendNull() + a2 := b2.NewArray().(*array.Uint64) + out := CombineArray(a1, a2).(*array.Uint64) + if !out.IsNull(0) || !out.IsNull(3) { + t.Fatalf("Uint64 nulls not present") + } + a1.Release() + a2.Release() + out.Release() + + }) + + t.Run("AppendUint8_nulls", func(t *testing.T) { + b1 := array.NewUint8Builder(mem) + b1.AppendNull() + b1.Append(2) + a1 := b1.NewArray().(*array.Uint8) + b2 := array.NewUint8Builder(mem) + b2.Append(3) + b2.AppendNull() + a2 := b2.NewArray().(*array.Uint8) + out := CombineArray(a1, a2).(*array.Uint8) + if !out.IsNull(0) || !out.IsNull(3) { + t.Fatalf("uint8 nulls not present") + } + a1.Release() + a2.Release() + out.Release() + }) + + t.Run("AppendFloat32_nulls", func(t *testing.T) { + b1 := array.NewFloat32Builder(mem) + b1.Append(1.5) + b1.AppendNull() + a1 := b1.NewArray().(*array.Float32) + b2 := array.NewFloat32Builder(mem) + b2.AppendNull() + b2.Append(2.5) + a2 := b2.NewArray().(*array.Float32) + out := CombineArray(a1, a2).(*array.Float32) + if !out.IsNull(1) || !out.IsNull(2) { + t.Fatalf("float32 nulls not present") + } + a1.Release() + a2.Release() + out.Release() + }) + + t.Run("AppendFloat64_nulls", func(t *testing.T) { + b1 := array.NewFloat64Builder(mem) + b1.AppendNull() + b1.Append(3.14) + a1 := 
b1.NewArray().(*array.Float64) + b2 := array.NewFloat64Builder(mem) + b2.Append(4.14) + b2.AppendNull() + a2 := b2.NewArray().(*array.Float64) + out := CombineArray(a1, a2).(*array.Float64) + if !out.IsNull(0) || !out.IsNull(3) { + t.Fatalf("float64 nulls not present") + } + a1.Release() + a2.Release() + out.Release() + }) + + t.Run("AppendBool_nulls", func(t *testing.T) { + b1 := array.NewBooleanBuilder(mem) + b1.Append(true) + b1.AppendNull() + a1 := b1.NewArray().(*array.Boolean) + b2 := array.NewBooleanBuilder(mem) + b2.AppendNull() + b2.Append(false) + a2 := b2.NewArray().(*array.Boolean) + out := CombineArray(a1, a2).(*array.Boolean) + if !out.IsNull(1) || !out.IsNull(2) { + t.Fatalf("bool nulls not present") + } + a1.Release() + a2.Release() + out.Release() + }) + + t.Run("AppendString_nulls", func(t *testing.T) { + b1 := array.NewStringBuilder(mem) + b1.Append("s1") + b1.AppendNull() + a1 := b1.NewArray().(*array.String) + b2 := array.NewStringBuilder(mem) + b2.AppendNull() + b2.Append("s2") + a2 := b2.NewArray().(*array.String) + out := CombineArray(a1, a2).(*array.String) + if !out.IsNull(1) || !out.IsNull(2) { + t.Fatalf("string nulls not present") + } + a1.Release() + a2.Release() + out.Release() + }) + + t.Run("AppendLargeString_nulls", func(t *testing.T) { + b1 := array.NewLargeStringBuilder(mem) + b1.AppendNull() + b1.Append("L1") + a1 := b1.NewArray().(*array.LargeString) + b2 := array.NewLargeStringBuilder(mem) + b2.Append("L2") + b2.AppendNull() + a2 := b2.NewArray().(*array.LargeString) + out := CombineArray(a1, a2).(*array.LargeString) + if !out.IsNull(0) || !out.IsNull(3) { + t.Fatalf("large string nulls not present") + } + a1.Release() + a2.Release() + out.Release() + }) + + t.Run("AppendBinary_nulls", func(t *testing.T) { + b1 := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) + b1.AppendNull() + b1.Append([]byte("bb1")) + a1 := b1.NewArray().(*array.Binary) + b2 := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) + 
b2.Append([]byte("bb2")) + b2.AppendNull() + a2 := b2.NewArray().(*array.Binary) + out := CombineArray(a1, a2).(*array.Binary) + if !out.IsNull(0) || !out.IsNull(3) { + t.Fatalf("binary nulls not present") + } + a1.Release() + a2.Release() + out.Release() + }) + +} +func TestCombineArray_UnsupportedType(t *testing.T) { + mem := memory.NewGoAllocator() + + // Build a FixedSizeBinary array (NOT supported in your switch) + dt := &arrow.FixedSizeBinaryType{ByteWidth: 4} + b := array.NewFixedSizeBinaryBuilder(mem, dt) + b.Append([]byte{0, 1, 2, 3}) + b.Append([]byte{4, 5, 6, 7}) + arr := b.NewArray() + b.Release() + + defer arr.Release() + + // Expect panic + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic for unsupported datatype") + } + }() + + // Call CombineArray with unsupported type + _ = CombineArray(arr, arr) } diff --git a/src/Backend/opti-sql-go/operators/project/source/s3_test.go b/src/Backend/opti-sql-go/operators/project/source/source_test.go similarity index 55% rename from src/Backend/opti-sql-go/operators/project/source/s3_test.go rename to src/Backend/opti-sql-go/operators/project/source/source_test.go index f62698f..40d2170 100644 --- a/src/Backend/opti-sql-go/operators/project/source/s3_test.go +++ b/src/Backend/opti-sql-go/operators/project/source/source_test.go @@ -2,6 +2,7 @@ package source import "testing" +// test s3 as a source first then run test for other source files here func TestS3(t *testing.T) { // Simple passing test } diff --git a/src/Backend/opti-sql-go/operators/record.go b/src/Backend/opti-sql-go/operators/record.go index 54ccf60..5b1b814 100644 --- a/src/Backend/opti-sql-go/operators/record.go +++ b/src/Backend/opti-sql-go/operators/record.go @@ -17,10 +17,13 @@ var ( type Operator interface { Next(uint16) (*RecordBatch, error) + // Call Operator.Close() after Next retruns an io.EOF to clean up resources + Close() error } type RecordBatch struct { - Schema *arrow.Schema - Columns []arrow.Array + Schema 
*arrow.Schema + Columns []arrow.Array + RowCount uint64 // TODO: update to actaully use this, in all operators } type SchemaBuilder struct { diff --git a/src/Backend/test_data/parquet/capitals_clean.parquet b/src/Backend/test_data/parquet/capitals_clean.parquet new file mode 100644 index 0000000000000000000000000000000000000000..5593be6b9ebcccdb562292e78c72846896104436 GIT binary patch literal 12340 zcmd^mdwdgB_IL&|B`ElL=;q9*Y$fQQ;_e!^l)eoC%58jJcrMh6TSE#9Cs6Q3%mo zoZ++%qY$BC0g^PTMZeZ!5Myxxvx8cRK^g|95*A|*E%-6huD#P}!(jn49Fb9M!(52P zK{^)J07XQH4hQ)d@tDCzvse(*2I@kL7*XDt9-4^jZGeL#hpEHjppzyW&oZ419!^y1 z_!!4%rJM>{=%GV=C&OVfXyD_qF3gc%8EgXXVOY|s7etO`NV7>~Ls($2jwsI&g$@{P zF2^yR{Ec3O5kYdC(T*d2KFE+hxkd{-PcGNnF^7ZX1!Tuj+T277LD~pJ+iA`+(qX3B z1SWtvb3Gju{VWqu6DCa<2Z=xF?O}$ciNyvx&+-v}Zn@FUM|c6YxgyWb^Brb)j0K*? zEVImZ9)w}8X9CuhM!Ue{IIy!0xJiC!v?rp17#mCkbUkz+oGuod0kwmU@Bsma1u{nD z@NMRJn1(gD5P8n%pt%SZI*5CWj&{b+i*fSrW(OZ(I3RK;;TojNIBdx^Is+mwTp-@o zJ41P0aHN8_{;-DY8Se)l_=Np~CVGuvL z06F7fhL19M5X>vb0#6sFsm2g#)94@#B+@~%FAhbFc+2Er0s_y`vD}fma;BXj)@jNa z&f>ECpf49N7jXJ?F6`&&@p>1{3!&Uwj4oP;U@k!%)VV-oou?{Xj30C>j)fFJ<`QT+ zK;tn)D%nI}4ZKoX(r5(A=?G@6_h`=$uV_AqqK5y16imN$U>sc;EoBHX2}`4 zIA#-;`*Yl#ypRAP;8>i!5!g+4TNK670o5^ zNZ3Xd!-ax;#3E33Lk$h06o?I~7&Y5e#R!Zam`p4*RPiyslP6R4Vi#uO)Imd4LI@># zh(_WCT{Q-kGD8vX%)MP-O~Y}{eb87<$M|@dCpKxS8UMn)#8MrU32Re9H3NDSr@2_1 z2Ik+auV&(5ap6TGoK-D$)1dsK5F!SfJUEC$I2On~p!Z-FcR7PZsyY#KFMu94$aI20 z*XW?i5D}Hf@?)MJcH+T8ER0#UZn3seCxX5ZHxM_gJmg?CBa(kKc<62>z!TdI9-a;I zoz66x*-vdXIxk04K3UvbGM~Q*@8XRH# zn^r&x$ClWnuc4!uwR~i(p}WkjG`6fHdQCN;@B$8r))~4QCPZs(gq9fxRaj?QF)zf! 
z<|pj9adxL=<3B-O#3H8Js7hfCZ^^2z<9>ZRZKh{ zGoJ)D(49<-xJ8>`sulwbHmg*Fh8UX|E&!I99FNs^_9q^*4Al8>gm*5^@_~^82{9bM z?g*T9G02!H0dM(I<+FUN!h3$B^MRdq4>tSgL?BGFEWKwyj*p?i8JMeSo?Fsw_JOA2 zC~%GR63-eu7z=UQy4>Jn`A*#7ex1;1z(`yFpz^&MQ1e{fcFLyl@sX1|$Z{PRqVS<7 z$#bOMsmU0k^MMnwp3?ZDm|;ymfiYKOu4B*04YyeY6R??%foa(gBugxtOg>N~;5cnr zVDNQ;P2tE^T^AGYxwF(-qDz2@I$4(G+x-}CZr~%kH(BSIyu$7+T!4v^v-J&_>B37M z)HUFC#(fjvhYHtYIYH`-VL%K_C4yNEU^~Gt;RvRmA%>`2xPg2xr-1=WoeyUyPE&8| z8u$=T+zAC28ah&IY5;E4?s|IC^xB9I+jE#Z_ zE;q5zqHFB&)46J5p0NpYVA(|c4R|1q?h(OJl12kV$ALM-d%6}nLX!`xTCHTB>`_~p zh#&jA=yX;3v;EVz_s^4K0DXE&dWva^Do1mv+c@xwU;w7QL!3L-AIjeTgr-DA^y{xC zQwC_XsRMO`1`pAv4IO4kA3nmEF>;h?^q6bRmdvqP*6eXPx#K6~uEv?tiz3%#X^RYh=q#+n#+B?_? z$474vVsWvvt2@zi5IU8OoT;|iNS&$$+MQ&IHcM4W>c}b3F<)(?$f0BdRNK_H8mM-wZRIf3O{A!* z$oVQexfo!I+TKJMp^^pbcB}2xOFkmCU! z1(53%YNwsB0R<~nK2lJ->Hu90yq=C%V&f%)=Aj{#(< zT-q#hFVuPg8dNT;0U#Ud+kpyW9~@S@eMGOy4WzmOy&INvlUo5;fL4H+K%xMg*r=}b z5;LK)9d`UWK!K_TpbBW8V>Ls$SXE^mtg42-YQR)&{Tk|v0lEP;0?Y&ul64Q%J+K!K z(CC4EcmSCP*73l+WV?2XY7X%z)Q>^iuc7=LU?ad?0NVi$1MZ;!Hh>cVV*pkHOaZtH zU_Q_zul=#A#;O6~SE#%&-U}P^0w=s6DQ^q05_Yy2UTOfI1K68{I@pvQpdJR+17Y>h zUzH^ObyX8Ex7j)y1knvJ6Cg`fuQkHc2Pl0o!v|aNfhc`Iy3cw$l&q>1$gHyVgSsqL z&`em}3(DiOdR4xpi0pvdmlW_`(44L6hDzd`s=?}jmn_gZ8^~-c)O-8jW>tfBCdj=3 zkTz(?ThD+hj|X^EMFC;5d>et3M)GC=R@G!J0>ZNZYKmK&^7&}>B%jS!wK#6sB9w5@i8;S2y{$dq=gpfP6$78y|u7zJ9r`-lnp!N zB7X46qz4>5&2d;wlxc8>NX!F&44wcJ59l0W+(Cx3YyuMu(Q&?uBSz>w3>WJl*3Psu z@q`&N1{)+5F`A7LbLEsP!bi28Szd@o`1~3OG>JbO8p2E@8YTmVYHar69%7Qo2jNsZ z4V^l(9Hz^Wyq37z3=#D1Ei8kF5rc?x1%er(2*zxqL%0i*ODlYQ8-wMfvsQ#~GRTV| zp1fY~V&Zg&dLW)1F{E?O&&Eh?A|7a(VdWtLMX&^FalMm@rj+*Oj89&jNv*M;tmKy z1tvxbSrFBOqA(Dy5!V?S{19fwtdk&H7FfnA>8LNug4SapYv`d7^|`mX(jid@lq-0h*eU z)j)?J01aU_O71n;AhwHwz7jLC>ghyBJMP@QrNeU6EC=Tfnhi5V2joW`5Ii~y^$j%K z2w2liHXI5=NR?uFPA_NC95K~g388ZgV!;UUx()(jns_~@1tM9{TVCuCoEKz~g&-TX z>LCp4NLY94sfZk~do)xmUN_zHXO-XikeP}=3sYy<2w+(DCn!r$saxghU-P|buOb+OTmAjXMqkY)_XqShKAQxV|0!JV%c zm?+O$Kh-zkpx9#>VrXJmjuCUeHvgPX-KVBFVuy|cnK8$W)eH-fWS5mxSBaM28>+=f 
zR1~NZjhsyt>uqp|JFG@S6Q~It&kY00c2VFhYs@vgh&uvdK5qT3jVuQ9f#qz2_=C|-!|vN@VwJ8Q z`*FTj7Zu~Y7z@ealfc2SB}e7Ji!Dom95y82AXx)uIw0Sq(64LC;lAQ~qV?xA$~rtNnHR?V zkaLqOP4bNn+zEBkGY!@4a17cbHyGfSg24i5gwz=9C#a`rac9|Bo{qH1x$>QR^gfyw zS&KnaDf;uL>w;kQopsZaS5CZqqMdLfVP2g^?8vEtZN_;q0?ub$i2>9Mt{0r`CSRNn zbcA_!$uu394UU*(a6n%yz%b^^dDWhS2FRtNaD|qo5o6(#?^>QTIC!|Vz{KsZ9@TX4 z9Yngi2Cp*_sXDOM?b;XB4u&}lQMK1HHm64H5D)IapaJ+*##w3r{{xwy_AAJa;2M%} z=qV1&ft3e&hC?9wII+Z7E^@d{jBm(;JCg`x_d~DK=&BM-r)9YbG|CJ<&cA0MSWxh= z+~7)$z8S~D44rMrGdm{QCt8NXr39_(WP6CBPzy<_xj56vg zFvl~NWg03-Z$LAv!Qyma*_~Iw2;pC(PbrEsF={VS!oY<8_UPy7jP!>xE~ie%bQcB_ z!^~%f6geoA`{RWH8Jis_wOjkb{D2KTvmvnY40585-G5A(zQrQV`@%7nxTX*_3}5%` z?rKWfdb(ViI>LcA<=N;vw7Do!mv?N@21;r_ziIR4krin5n=h?8JJEsezYjj8)6q-w zy`5{ePet`Fjal;7!ttmu(f{=Y=PS_@cfI!Od8``&_f;xXi2hpq+1ayaGtjh)yMmj$ zXCUs@FV3F5xj-`1ZpnV=3nx12`17LtjnkyMna|(f-8)HoY|6u%H*X$>QYWN0G_*L; z`}4##+h5K>bl(1y>5+lb;oPZ_N5+mr2k*aj&Gz>Pqq9#U`n+kHsa@3yt@lUvYda zn&fIM9(uYQ;WK%?f6ttXwz^LC7GT|0ETdNA5_=yBNIM6}d> zLEp3$>bzJw# zZMxMLb~+0j6DyH3GP|*{@td+Er+iR4bB}4}&qAcx!8w(^{H+jyq_(;e)$aWlhV}aj zAz1)(YiCI>+837ho1KByfBbf7N)(}mKgE9c<0_ZJ`sONheEF28vfrpsM0UDJ67P?9 zb$vHMGWmK&|8}BQVc9mlBC^U{RJ7qJ{`_kSwTyW2E!FiVBtG%!`tQ<<&`Uc%FA#H! 
zrO74Q6DwOME8`ZqloNE^iNf96Qoj7gg>cnA?>mJ~^!e_&o|*kA)P3>W^|hauN=bTm zPF6&=YKr7udu#Tt+fVNNHV1U+`59<^+GDkUTknuw-uYP~&6|TN` zCKB&4s+$H+mXfNn$c~=se00^VKRHnRpl8(S3!XkvO;zZ82cbO29nvb*81&Vk-zQ4G zvY_azkDMv^CR0kT>nfEdcq*wE&)d;8aDZMKlOqi&9-Nj|n6HR_bS3(W@xYjMt#-6} zWBY3kcPUy?zHIhq!_uVqz4hwPip!Ow^kEg6wLJ2>AE}>qCM`mBp`vn{JhUQXq{JH%OLxx)6b1U5B%ed>#qG}$n(lOr}h?4Lec(gS69~v zG_t*V=7h~2v~1mj;pT)39r*CBSVdPKon>PE2#mZ=xSrYRO) za--d#{woAKy7Q@n!@h}8Qh4>an+`wjK;P%RdE1U1g$i9NE4ocOxXzhU(r3bQQ4;^# z6D3=SayX6@N>{mBFlrQg=tV}&CAAz5SC_wRh0Xte;C$C#>3kRT#+m+WE9+40gQ1UF z-kOZQX&*OI+pR%fsp@I++KXjc{e^zyc=YkH-uJFKQzK3La_aKG6nLbhU7*U*iN=pd zR4uBL1ojVK3|a42?AW{lWZKs2{`B(-#Xh**h zG~4q#^_)dhCAP8cSo+E$MV~SzD&?j+=}A*-aK@5)>Db}IhR#Qv=(FrX&65`c(&)N( zf=04YntWtUchiMQ=+u-9?Zw^(DQS1MB2@I^fePtp9-3AC=)@Nvo`}Y17loW-bJ6B& z`!6{EYp3Ebo~ekWJI$kfCfJeF4xnBlIQr`H1xVR0AlTj zj@*B0^y8;96?;gRqvUycZ9Ga}R{6QQ_`uF&DA1NJF?7v47$T&k>pt$5PSfvR6a277 zT30nF{_Z19(qF$G9o)!HkS4ay+jyp6lJs7|a|>&Rd8E3pqNaW6_0p7n4WDEjYLr?I zZx|F$ovi4?mnBltZh3sld4H}Iy-S>L|L7SvDl3J6Xks~fb?!Op_S6jN@53^Q4<4k@ zwo}7@JMp_5RCn{O)}6+5>4lu(?5jmSY0+LwW@aZwQ>9MudlVWqf5(m;@6{=K5wj09Y)39?(SRnV}zoebGcjC7DaF8*DJ^A;x$r|zbW}> z@GAxPjb~|;4CO9NlCH_FyZ61F)#%~QiyhDIpCbh;pW1!o?MorbXhiRt^v(I-wJH(i znIa|foHI&_p5RUU+4oUjn4zz99SE-*o zS2aq>5b&G|ZK)j${2h%Tpght&i zM~PT%^=Q`}L-W7?(|q*C4|o4%+DATV)mJZ+?GO4Dzx-CUWM6gmQ^UyX6u)TmDWT`= z>rwyf7v&%8SAyP}_1BU4u1sm6=BE{}?RKNP9&du!bRxR+CN~{1E1t5a%**=86VW%eJ@ui zB@KPNSn-3~tB^S3U0}d?DQWc2x1di3%=!9POKPPVXRjeXFcl)(-uM1t?w_yF{f_|J z*8kMouRT$YlJV8l>1f5mW15}w(2|Yk@ai(OX5J5_kw*qeyz7DHVc+?XdCg<{zFbS8 zUiHRJv>hp@c6F6x`Ag)Mm1#xD{*!vdgoC3M?`^g9DlH-~;N; zoIhGx(B|`B9(tnDr-pJ^HA{X+{XvS&zHIS4VogJqLmP8SB*4x1+mGT zg|`2GePL?{(pH=k3O#7tvv6(#ZD{{Q`}30vjMdWcgv##5pQQ^rqaqe+u@al8=u|&w%Dqiqs8S&Gx81y$!97% z8&rm_>VGd0VbolrS@yrbZ2j-|zZqSZJ{Z#PKL!nY#x!^xhsI0qY#PGlAvo{uI zZ_JVxzVwaGzhCLnhLk1l8bAE+W3!hxWs=)Mf&tlC2rOay0?p=x5{^ zQ=;;hJ2B%P^2SWKIHkMi1$b7BaIwtk=4-FZGiNHV^OAk^KeZ-bZ_Jc`>j1-HXi~j0 zVOASMv%z^+=}^$kizlMVDR7$>;DfY0g@a4w#N`g&-%bbOa(je^?~&mC7cg)l%ife{ 
z{sqQ9<^7k4uju~|<5y0|RS^H&@t?&AO7}}_eU1Nwsf~r7Po{6_q7!MqZ~i+S=URF{h{`>~C&_^(m&MsI<_})hB45TVAWHrMaY| zWll|7u!8OCTcgk)VF5=ytE^E{=$~UNnKQY*B;Zok@E29y(Ar#6p!65AL037fQ6H9R z^m9!ye+AH4ado=p6qUyUu5j4HB=JXZb5pFfyi{yyE@b8T^4^2ZB?SSFZEI=n?!qo- zB8i8yg@G60Kw)VCJ;$bu<(fp?T+`VKbXD6a25`p%73`IJ>1>Ha*_O$5N`G@xK}&Oe zIOuZ9ycUDa-LW>{H=U@w5`Q@0VqsiYa!zZ6W&N)DwkvtuGRGE$J;(mv($yUf+W!^L zF85tco)N#RG#qra^vUa&@@$QCcec!_E2>e}kGFy>TV3S^EwWCzV7#llwAwD~?9X}x zYl@2ef2m`a`>rOlNNdbrR9;Zi-Y1h^%Iu%$y_$X%24#7-Os-@>?>Ykv&?)Pi3x2U2 zDVKFyJ_}uRb9o{VDV?mysg(_IHBo<2aq=8^8ahOey|~-WCC}8)e2X+C0!3`6UtZJh zPE@zMMNgu#R6e?37bvQtPY>ak+vM|JBR}lPlivz8${|0JHic^Rf_;(AL7gYr4@O0P zNj{SgxzE+z=?^)YWm5!icB-NsYSWnlb Date: Sun, 16 Nov 2025 19:02:54 -0500 Subject: [PATCH 05/19] feat: Implement s3 file reader & file downloader --- src/Backend/opti-sql-go/config/config.go | 22 +- src/Backend/opti-sql-go/go.mod | 18 ++ src/Backend/opti-sql-go/go.sum | 40 +++ .../operators/project/source/csv.go | 6 +- .../operators/project/source/parquet.go | 38 ++- .../operators/project/source/parquet_test.go | 16 +- .../operators/project/source/s3.go | 111 +++++++ .../operators/project/source/source_test.go | 299 +++++++++++++++++- 8 files changed, 535 insertions(+), 15 deletions(-) diff --git a/src/Backend/opti-sql-go/config/config.go b/src/Backend/opti-sql-go/config/config.go index ba02bcf..7f70612 100644 --- a/src/Backend/opti-sql-go/config/config.go +++ b/src/Backend/opti-sql-go/config/config.go @@ -16,10 +16,11 @@ var ( ) type Config struct { - Server serverConfig `yaml:"server"` - Batch batchConfig `yaml:"batch"` - Query queryConfig `yaml:"query"` - Metrics metricsConfig `yaml:"metrics"` + Server serverConfig `yaml:"server"` + Batch batchConfig `yaml:"batch"` + Query queryConfig `yaml:"query"` + Metrics metricsConfig `yaml:"metrics"` + Secretes secretesConfig // do not read these from yaml } type serverConfig struct { Port int `yaml:"port"` @@ -54,6 +55,12 @@ type metricsConfig struct { // memory usage over time EnableMemoryStats bool 
`yaml:"enable_memory_stats"` } +type secretesConfig struct { + AccessKey string `yaml:"access_key"` + SecretKey string `yaml:"secret_key"` + EndpointURL string `yaml:"endpoint_url"` + BucketName string `yaml:"bucket_name"` +} var configInstance *Config = &Config{ Server: serverConfig{ @@ -86,6 +93,13 @@ var configInstance *Config = &Config{ EnableQueryStats: true, EnableMemoryStats: true, }, + // TODO: remove hardcoded secretes before production. we are just testing for now + Secretes: secretesConfig{ + AccessKey: "DO8013ZT6VDHJ2EM94RN", + SecretKey: "kPvQSMt6naiwe/FhDnzXpYmVE5yzJUsIR0/OJpsUNzo", + EndpointURL: "atl1.digitaloceanspaces.com", + BucketName: "test-bucket-pull-down", + }, } func GetConfig() *Config { diff --git a/src/Backend/opti-sql-go/go.mod b/src/Backend/opti-sql-go/go.mod index 5012de3..184caaa 100644 --- a/src/Backend/opti-sql-go/go.mod +++ b/src/Backend/opti-sql-go/go.mod @@ -5,6 +5,10 @@ go 1.24.0 require ( github.com/apache/arrow/go/v15 v15.0.2 github.com/apache/arrow/go/v17 v17.0.0 + github.com/aws/aws-sdk-go v1.55.8 + github.com/aws/aws-sdk-go-v2 v1.39.6 + github.com/aws/aws-sdk-go-v2/service/s3 v1.90.2 + github.com/joho/godotenv v1.5.1 google.golang.org/grpc v1.63.2 google.golang.org/protobuf v1.34.2 gopkg.in/yaml.v3 v3.0.1 @@ -14,16 +18,30 @@ require ( github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/apache/thrift v0.20.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.13 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 // 
indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.13 // indirect + github.com/aws/smithy-go v1.23.2 // indirect + github.com/go-ini/ini v1.67.0 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/klauspost/asmfmt v1.3.2 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 // indirect github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect + github.com/minio/minio-go v6.0.14+incompatible // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect + golang.org/x/crypto v0.24.0 // indirect golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect golang.org/x/mod v0.18.0 // indirect golang.org/x/net v0.26.0 // indirect diff --git a/src/Backend/opti-sql-go/go.sum b/src/Backend/opti-sql-go/go.sum index afb9b78..9c4220d 100644 --- a/src/Backend/opti-sql-go/go.sum +++ b/src/Backend/opti-sql-go/go.sum @@ -8,8 +8,35 @@ github.com/apache/arrow/go/v17 v17.0.0 h1:RRR2bdqKcdbss9Gxy2NS/hK8i4LDMh23L6BbkN github.com/apache/arrow/go/v17 v17.0.0/go.mod h1:jR7QHkODl15PfYyjM2nU+yTLScZ/qfj7OSUZmJ8putc= github.com/apache/thrift v0.20.0 h1:631+KvYbsBZxmuJjYwhezVsrfc/TbqtZV4QcxOX1fOI= github.com/apache/thrift v0.20.0/go.mod h1:hOk1BQqcp2OLzGsyVXdfMk7YFlMxK3aoEVhjD06QhB8= +github.com/aws/aws-sdk-go v1.55.8 h1:JRmEUbU52aJQZ2AjX4q4Wu7t4uZjOu71uyNmaWlUkJQ= +github.com/aws/aws-sdk-go v1.55.8/go.mod h1:ZkViS9AqA6otK+JBBNH2++sx1sgxrPKcSzPPvQkUtXk= +github.com/aws/aws-sdk-go-v2 v1.39.6 h1:2JrPCVgWJm7bm83BDwY5z8ietmeJUbh3O2ACnn+Xsqk= +github.com/aws/aws-sdk-go-v2 v1.39.6/go.mod h1:c9pm7VwuW0UPxAEYGyTmyurVcNrbF6Rt/wixFqDhcjE= 
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.3 h1:DHctwEM8P8iTXFxC/QK0MRjwEpWQeM9yzidCRjldUz0= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.3/go.mod h1:xdCzcZEtnSTKVDOmUZs4l/j3pSV6rpo1WXl5ugNsL8Y= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 h1:a+8/MLcWlIxo1lF9xaGt3J/u3yOZx+CdSveSNwjhD40= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13/go.mod h1:oGnKwIYZ4XttyU2JWxFrwvhF6YKiK/9/wmE3v3Iu9K8= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 h1:HBSI2kDkMdWz4ZM7FjwE7e/pWDEZ+nR95x8Ztet1ooY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13/go.mod h1:YE94ZoDArI7awZqJzBAZ3PDD2zSfuP7w6P2knOzIn8M= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.13 h1:eg/WYAa12vqTphzIdWMzqYRVKKnCboVPRlvaybNCqPA= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.13/go.mod h1:/FDdxWhz1486obGrKKC1HONd7krpk38LBt+dutLcN9k= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 h1:x2Ibm/Af8Fi+BH+Hsn9TXGdT+hKbDd5XOTZxTMxDk7o= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3/go.mod h1:IW1jwyrQgMdhisceG8fQLmQIydcT/jWY21rFhzgaKwo= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.4 h1:NvMjwvv8hpGUILarKw7Z4Q0w1H9anXKsesMxtw++MA4= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.4/go.mod h1:455WPHSwaGj2waRSpQp7TsnpOnBfw8iDfPfbwl7KPJE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 h1:kDqdFvMY4AtKoACfzIGD8A0+hbT41KTKF//gq7jITfM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13/go.mod h1:lmKuogqSU3HzQCwZ9ZtcqOc5XGMqtDK7OIc2+DxiUEg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.13 h1:zhBJXdhWIFZ1acfDYIhu4+LCzdUS2Vbcum7D01dXlHQ= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.13/go.mod h1:JaaOeCE368qn2Hzi3sEzY6FgAZVCIYcC2nwbro2QCh8= +github.com/aws/aws-sdk-go-v2/service/s3 v1.90.2 h1:DhdbtDl4FdNlj31+xiRXANxEE+eC7n8JQz+/ilwQ8Uc= +github.com/aws/aws-sdk-go-v2/service/s3 v1.90.2/go.mod 
h1:+wArOOrcHUevqdto9k1tKOF5++YTe9JEcPSc9Tx2ZSw= +github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM= +github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= +github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= @@ -18,6 +45,11 @@ github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81A github.com/google/flatbuffers v24.3.25+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= @@ -28,16 
+60,23 @@ github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpsp github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= +github.com/minio/minio-go v6.0.14+incompatible h1:fnV+GD28LeqdN6vT2XdGKW8Qe/IfjJDswNVuni6km9o= +github.com/minio/minio-go v6.0.14+incompatible/go.mod h1:7guKYtitv8dktvNUGrhzmNlA5wrAABTQXCoesZdFQO8= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 
h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= @@ -64,5 +103,6 @@ google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDom google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/src/Backend/opti-sql-go/operators/project/source/csv.go b/src/Backend/opti-sql-go/operators/project/source/csv.go index b5c48df..6975f3b 100644 --- a/src/Backend/opti-sql-go/operators/project/source/csv.go +++ b/src/Backend/opti-sql-go/operators/project/source/csv.go @@ -86,6 +86,11 @@ func (csvS *CSVSource) Next(n uint64) (*operators.RecordBatch, error) { Columns: columns, }, nil } +func (csvS *CSVSource) Close() error { + csvS.r = nil + csvS.done = true + return nil +} func (csvS *CSVSource) initBuilders() []array.Builder { fields := csvS.schema.Fields() @@ -102,7 +107,6 @@ func (csvS *CSVSource) processRow( builders []array.Builder, ) error { fields := csvS.schema.Fields() - fmt.Printf("content : %v\n", content) for i, f := range fields { colIdx := csvS.colPosition[f.Name] cell := content[colIdx] diff --git a/src/Backend/opti-sql-go/operators/project/source/parquet.go b/src/Backend/opti-sql-go/operators/project/source/parquet.go index 8486823..b2e848f 100644 --- a/src/Backend/opti-sql-go/operators/project/source/parquet.go +++ b/src/Backend/opti-sql-go/operators/project/source/parquet.go @@ -26,8 +26,44 @@ type ParquetSource 
struct { done bool // if set to true always return io.EOF } +func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { + allocator := memory.NewGoAllocator() + filerReader, err := file.NewParquetReader(r) + if err != nil { + return nil, err + } + + defer func() { + if err := filerReader.Close(); err != nil { + fmt.Printf("warning: failed to close parquet reader: %v\n", err) + + } + }() + + arrowReader, err := pqarrow.NewFileReader( + filerReader, + pqarrow.ArrowReadProperties{Parallel: true, BatchSize: 5}, // TODO: Read in from config for this stuff + allocator, + ) + if err != nil { + return nil, err + } + rdr, err := arrowReader.GetRecordReader(context.TODO(), nil, nil) + if err != nil { + return nil, err + } + + return &ParquetSource{ + Schema: rdr.Schema(), + projectionPushDown: []string{}, + predicatePushDown: nil, + reader: rdr, + }, nil + +} + // source, columns you want to be push up the tree, any filters -func NewParquetSource(r parquet.ReaderAtSeeker, columns []string, filters []filter.FilterExpr) (*ParquetSource, error) { +func NewParquetSourcePushDown(r parquet.ReaderAtSeeker, columns []string, filters []filter.FilterExpr) (*ParquetSource, error) { if len(columns) == 0 { return nil, errors.New("no columns were provided for projection push down") } diff --git a/src/Backend/opti-sql-go/operators/project/source/parquet_test.go b/src/Backend/opti-sql-go/operators/project/source/parquet_test.go index 1b5daa1..15ad8c5 100644 --- a/src/Backend/opti-sql-go/operators/project/source/parquet_test.go +++ b/src/Backend/opti-sql-go/operators/project/source/parquet_test.go @@ -59,7 +59,7 @@ func TestParquetInit(t *testing.T) { t.Run("Test No names pass in", func(t *testing.T) { f := getTestParuqetFile() - _, err := NewParquetSource(f, []string{}, nil) + _, err := NewParquetSourcePushDown(f, []string{}, nil) if err == nil { t.Errorf("Expected error when no columns are passed in, but got nil") } @@ -67,7 +67,7 @@ func TestParquetInit(t *testing.T) { 
t.Run("Test invalid names are passed in", func(t *testing.T) { f := getTestParuqetFile() - _, err := NewParquetSource(f, []string{"non_existent_column"}, nil) + _, err := NewParquetSourcePushDown(f, []string{"non_existent_column"}, nil) if err == nil { t.Errorf("Expected error when invalid column names are passed in, but got nil") } @@ -76,7 +76,7 @@ func TestParquetInit(t *testing.T) { t.Run("Test correct schema is returned", func(t *testing.T) { f := getTestParuqetFile() columns := []string{"country", "capital", "lat"} - source, err := NewParquetSource(f, columns, nil) + source, err := NewParquetSourcePushDown(f, columns, nil) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -95,7 +95,7 @@ func TestParquetInit(t *testing.T) { t.Run("Test input columns and filters were passed back out", func(t *testing.T) { f := getTestParuqetFile() columns := []string{"country", "capital", "lat"} - source, err := NewParquetSource(f, columns, nil) + source, err := NewParquetSourcePushDown(f, columns, nil) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -111,7 +111,7 @@ func TestParquetInit(t *testing.T) { f := getTestParuqetFile() columns := []string{"country", "capital", "lat"} - source, err := NewParquetSource(f, columns, nil) + source, err := NewParquetSourcePushDown(f, columns, nil) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -125,7 +125,7 @@ func TestParquetInit(t *testing.T) { func TestParquetClose(t *testing.T) { f := getTestParuqetFile() columns := []string{"country", "capital", "lat"} - source, err := NewParquetSource(f, columns, nil) + source, err := NewParquetSourcePushDown(f, columns, nil) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -145,7 +145,7 @@ func TestParquetClose(t *testing.T) { func TestRunToEnd(t *testing.T) { f := getTestParuqetFile() columns := []string{"country", "capital", "lat"} - source, err := NewParquetSource(f, columns, nil) + source, err := NewParquetSourcePushDown(f, columns, nil) if err != nil { 
t.Fatalf("Unexpected error: %v", err) } @@ -164,7 +164,7 @@ func TestRunToEnd(t *testing.T) { func TestParquetRead(t *testing.T) { f := getTestParuqetFile() columns := []string{"country", "capital", "lat"} - source, err := NewParquetSource(f, columns, nil) + source, err := NewParquetSourcePushDown(f, columns, nil) if err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/src/Backend/opti-sql-go/operators/project/source/s3.go b/src/Backend/opti-sql-go/operators/project/source/s3.go index d150341..73e5379 100644 --- a/src/Backend/opti-sql-go/operators/project/source/s3.go +++ b/src/Backend/opti-sql-go/operators/project/source/s3.go @@ -1 +1,112 @@ package source + +import ( + "fmt" + "io" + "opti-sql-go/config" + "os" + "time" + + "github.com/minio/minio-go" +) + +var secretes = config.GetConfig().Secretes + +type mime string + +var ( + MimeCSV mime = "csv" + MimeParquet mime = "parquet" +) + +type NetworkResource struct { + client *minio.Client + bucket string + key string + + // raw streaming object for CSV + stream *minio.Object + // for clean up-testing + fileName string +} + +func NewStreamReader(fileName string) (*NetworkResource, error) { + accessKey := secretes.AccessKey + secretKey := secretes.SecretKey + endpoint := secretes.EndpointURL + bucket := secretes.BucketName + useSSL := true + + client, err := minio.New(endpoint, accessKey, secretKey, useSSL) + if err != nil { + return nil, err + } + + obj, err := client.GetObject(bucket, fileName, minio.GetObjectOptions{}) + if err != nil { + return nil, err + } + + return &NetworkResource{ + client: client, + bucket: bucket, + key: fileName, + fileName: fileName, + stream: obj, // CSV reads this directly + }, nil +} + +func (n *NetworkResource) Stream() io.Reader { + return n.stream +} + +// S3ReaderAt implements io.ReaderAt for Parquet readers +func (n *NetworkResource) ReadAt(p []byte, off int64) (int, error) { + opts := minio.GetObjectOptions{} + _ = opts.SetRange(off, off+int64(len(p))-1) + + obj, 
err := n.client.GetObject(n.bucket, n.key, opts) + if err != nil { + return 0, err + } + return io.ReadFull(obj, p) +} + +func (n *NetworkResource) Seek(offset int64, whence int) (int64, error) { + switch whence { + case io.SeekStart: + return offset, nil + case io.SeekEnd: + // Need to return total object size + info, err := n.client.StatObject(n.bucket, n.key, minio.StatObjectOptions{}) + if err != nil { + return 0, fmt.Errorf("failed to stat object: %w", err) + } + return info.Size, nil + default: + return 0, fmt.Errorf("unsupported seek mode for S3: %d", whence) + } +} +func (n *NetworkResource) DownloadLocally() (*os.File, error) { + f, err := os.Create(fmt.Sprintf("%s-%d", n.key, time.Now().UnixNano())) + if err != nil { + return nil, err + } + + // Read entire stream once + content, err := io.ReadAll(n.stream) + if err != nil { + return nil, err + } + + if _, err := f.Write(content); err != nil { + return nil, err + } + + // Rewind so CSV readers can start from beginning + if _, err := f.Seek(0, io.SeekStart); err != nil { + return nil, err + } + + return f, nil +} diff --git a/src/Backend/opti-sql-go/operators/project/source/source_test.go b/src/Backend/opti-sql-go/operators/project/source/source_test.go index 40d2170..05cba84 100644 --- a/src/Backend/opti-sql-go/operators/project/source/source_test.go +++ b/src/Backend/opti-sql-go/operators/project/source/source_test.go @@ -1,8 +1,305 @@ package source -import "testing" +import ( + "fmt" + "io" + "os" + "strings" + "testing" +) + +const ( + s3CSVFile = "country_full.csv" + s3ParquetFile = "userdata.parquet" + s3TxtFile = "example.txt" +) // test s3 as a source first then run test for other source files here func TestS3(t *testing.T) { // Simple passing test + _, err := NewStreamReader(s3CSVFile) + if err != nil { + t.Fatalf("failed to create s3 stream reader: %v", err) + } +} + +// test for +// (1) reading files from network (s3) should provide exact same abstraction as a local file +func TestS3BasicRead(t 
*testing.T) { + t.Run("csv read", func(t *testing.T) { + nr, err := NewStreamReader(s3CSVFile) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + firstKB := make([]byte, 1024) + n, err := nr.stream.Read(firstKB) + if err != nil { + t.Fatalf("failed to read from s3 object: %v", err) + } + if n != 1024 { + t.Fatalf("expected to read 1024 bytes, but read %d bytes", n) + } + fmt.Printf("returned contents %s\n", firstKB[:n]) + + }) + t.Run("parquet read", func(t *testing.T) { + nr, err := NewStreamReader(s3ParquetFile) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + firstKB := make([]byte, 1024) + n, err := nr.stream.Read(firstKB) + if err != nil { + t.Fatalf("failed to read from s3 object: %v", err) + } + if n != 1024 { + t.Fatalf("expected to read 1024 bytes, but read %d bytes", n) + } + fmt.Printf("returned contents %v\n", firstKB[:n]) + + }) + t.Run("txt read", func(t *testing.T) { + nr, err := NewStreamReader(s3TxtFile) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + firstKB := make([]byte, 1024) + n, err := nr.stream.Read(firstKB) + if err != nil { + t.Fatalf("failed to read from s3 object: %v", err) + } + if n != 1024 { + t.Fatalf("expected to read 1024 bytes, but read %d bytes", n) + } + fmt.Printf("returned contents %s\n", firstKB[:n]) + + }) +} + +// (2) download entire file before reading +func TestS3Download(t *testing.T) { + t.Run("Download CSV locally", func(t *testing.T) { + nr, err := NewStreamReader(s3CSVFile) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + newFile, err := nr.DownloadLocally() + if err != nil { + t.Fatalf("failed to download file locally %v", err) + } + defer func() { + _ = newFile.Close() + if err := os.Remove(newFile.Name()); err != nil { + t.Fatalf("error closing file %v", newFile.Name()) + } + }() + // validate stats about file + info, err := newFile.Stat() + if err != nil { + t.Fatalf("failed to get file stats %v", err) + } + if 
info.IsDir() { + t.Fatalf("expected regular file, found directory: %s", info.Name()) + } + + if info.Size() < 100 { + t.Fatalf("file is too small (%d bytes), expected >= 100 bytes", info.Size()) + } + + if !strings.HasPrefix(info.Name(), nr.fileName) { + t.Fatalf("filename mismatch: got %s, expected prefix %s", info.Name(), nr.fileName) + } + + }) + t.Run("Download parquet locally", func(t *testing.T) { + nr, err := NewStreamReader(s3ParquetFile) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + newFile, err := nr.DownloadLocally() + if err != nil { + t.Fatalf("failed to download file locally %v", err) + } + defer func() { + _ = newFile.Close() + if err := os.Remove(newFile.Name()); err != nil { + t.Fatalf("error closing file %v", newFile.Name()) + } + }() + // validate stats about file + info, err := newFile.Stat() + if err != nil { + t.Fatalf("failed to get file stats %v", err) + } + if info.IsDir() { + t.Fatalf("expected regular file, found directory: %s", info.Name()) + } + + if info.Size() < 100 { + t.Fatalf("file is too small (%d bytes), expected >= 100 bytes", info.Size()) + } + + if !strings.HasPrefix(info.Name(), nr.fileName) { + t.Fatalf("filename mismatch: got %s, expected prefix %s", info.Name(), nr.fileName) + } + + }) + t.Run("Download txt locally", func(t *testing.T) { + nr, err := NewStreamReader(s3TxtFile) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + newFile, err := nr.DownloadLocally() + if err != nil { + t.Fatalf("failed to download file locally %v", err) + } + defer func() { + _ = newFile.Close() + if err := os.Remove(newFile.Name()); err != nil { + t.Fatalf("error closing file %v", newFile.Name()) + } + }() + // validate stats about file + info, err := newFile.Stat() + if err != nil { + t.Fatalf("failed to get file stats %v", err) + } + if info.IsDir() { + t.Fatalf("expected regular file, found directory: %s", info.Name()) + } + + if info.Size() < 100 { + t.Fatalf("file is too small (%d 
bytes), expected >= 100 bytes", info.Size()) + } + + if !strings.HasPrefix(info.Name(), nr.fileName) { + t.Fatalf("filename mismatch: got %s, expected prefix %s", info.Name(), nr.fileName) + } + + }) +} + +// (3) add s3 variant of existing operaor sources (csv,parquet) and write minimal test here to check they work the same +func TestS3ForSource(t *testing.T) { + t.Run("csv from s3 source", func(t *testing.T) { + nr, err := NewStreamReader(s3CSVFile) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + pj, err := NewProjectCSVLeaf(nr.stream) + if err != nil { + t.Fatalf("failed to create csv project source from s3 object: %v", err) + } + rc, err := pj.Next(5) + if err != nil { + t.Fatalf("failed to read record batch from s3 csv source: %v", err) + } + fmt.Printf("returned record batch from s3 csv source: %v\n", rc) + + }) + t.Run("parquet from s3 source", func(t *testing.T) { + nr, err := NewStreamReader(s3ParquetFile) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + pj, err := NewParquetSource(nr) + if err != nil { + t.Fatalf("failed to create parquet project source from s3 object: %v", err) + } + rc, err := pj.Next(5) + if err != nil { + t.Fatalf("failed to read record batch from s3 csv source: %v", err) + } + fmt.Printf("returned record batch from s3 csv source: %v\n", rc) + + }) + t.Run("csv from s3 source then downloaded", func(t *testing.T) { + nr, err := NewStreamReader(s3CSVFile) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + f, err := nr.DownloadLocally() + if err != nil { + t.Fatalf("failed to download s3 object locally: %v", err) + } + defer func() { + fmt.Println("deleting downloaded file...") + _ = f.Close() + if err := os.Remove(f.Name()); err != nil { + t.Fatalf("error closing file %v", f.Name()) + } + }() + pj, err := NewProjectCSVLeaf(f) + if err != nil { + t.Fatalf("failed to create csv project source from s3 object: %v", err) + } + rc, err := pj.Next(5) + if err != nil { + 
t.Fatalf("failed to read record batch from s3 csv source: %v", err) + } + err = pj.Close() + if err != nil { + t.Fatalf("failed to close csv project source: %v", err) + } + fmt.Printf("returned record batch from s3 csv source: %v\n", rc) + + }) + t.Run("parquet from s3 source then downloaded", func(t *testing.T) { + nr, err := NewStreamReader(s3ParquetFile) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + f, err := nr.DownloadLocally() + if err != nil { + t.Fatalf("failed to download s3 object locally: %v", err) + } + defer func() { + fmt.Println("deleting downloaded file...") + _ = f.Close() + if err := os.Remove(f.Name()); err != nil { + t.Fatalf("error closing file %v", f.Name()) + } + }() + pj, err := NewParquetSource(f) + if err != nil { + t.Fatalf("failed to create csv project source from s3 object: %v", err) + } + rc, err := pj.Next(5) + if err != nil { + t.Fatalf("failed to read record batch from s3 csv source: %v", err) + } + fmt.Printf("returned record batch from s3 csv source: %v\n", rc) + + }) +} + +func TestS3Source(t *testing.T) { + nr, err := NewStreamReader(s3CSVFile) + if err != nil { + t.Fatalf("failed to create s3 object: %v", err) + } + t.Run("test SeekStart", func(t *testing.T) { + _, err := nr.Seek(0, io.SeekStart) + if err != nil { + t.Fatalf("failed to seek to start of s3 object: %v", err) + } + }) + t.Run("invalidSeek ", func(t *testing.T) { + _, err := nr.Seek(4, 4) + if err == nil { + t.Fatalf("expected error when seeking with invalid whence, but got none") + } + }) + t.Run("test stream read", func(t *testing.T) { + stream := nr.Stream() + buf := make([]byte, 512) + n, err := stream.Read(buf) + if err != nil { + t.Fatalf("failed to read from s3 object stream: %v", err) + } + if n == 0 { + t.Fatalf("expected to read some bytes from s3 object stream, but read 0 bytes") + } + fmt.Printf("read %d bytes from s3 object stream: %s\n", n, string(buf[:n])) + }) } From 6ee517899581924a63be4a2f2b369394ff8e1024 Mon Sep 17 
00:00:00 2001 From: Richard Baah Date: Mon, 17 Nov 2025 01:59:54 -0500 Subject: [PATCH 06/19] feat:Implement Project operator | tested on top source operators --- .../operators/project/{source => }/csv.go | 13 +- .../project/{source => }/csv_test.go | 4 +- .../operators/project/{source => }/custom.go | 28 +- .../project/{source => }/custom_test.go | 2 +- .../operators/project/{source => }/parquet.go | 19 +- .../project/{source => }/parquet_test.go | 6 +- .../operators/project/projectExec.go | 97 ++++++- .../operators/project/projectExec_test.go | 273 +++++++++++++++++- .../operators/project/{source => }/s3.go | 2 +- .../project/{source => }/source_test.go | 2 +- src/Backend/opti-sql-go/operators/record.go | 1 + 11 files changed, 416 insertions(+), 31 deletions(-) rename src/Backend/opti-sql-go/operators/project/{source => }/csv.go (95%) rename src/Backend/opti-sql-go/operators/project/{source => }/csv_test.go (99%) rename src/Backend/opti-sql-go/operators/project/{source => }/custom.go (92%) rename src/Backend/opti-sql-go/operators/project/{source => }/custom_test.go (99%) rename src/Backend/opti-sql-go/operators/project/{source => }/parquet.go (96%) rename src/Backend/opti-sql-go/operators/project/{source => }/parquet_test.go (99%) rename src/Backend/opti-sql-go/operators/project/{source => }/s3.go (99%) rename src/Backend/opti-sql-go/operators/project/{source => }/source_test.go (99%) diff --git a/src/Backend/opti-sql-go/operators/project/source/csv.go b/src/Backend/opti-sql-go/operators/project/csv.go similarity index 95% rename from src/Backend/opti-sql-go/operators/project/source/csv.go rename to src/Backend/opti-sql-go/operators/project/csv.go index 6975f3b..b5fc714 100644 --- a/src/Backend/opti-sql-go/operators/project/source/csv.go +++ b/src/Backend/opti-sql-go/operators/project/csv.go @@ -1,4 +1,4 @@ -package source +package project import ( "encoding/csv" @@ -13,6 +13,10 @@ import ( "github.com/apache/arrow/go/v17/arrow/array" ) +var ( + _ = 
(operators.Operator)(&CSVSource{}) +) + // TODO: change the leaf stuff to be called scans instead type CSVSource struct { @@ -36,7 +40,7 @@ func NewProjectCSVLeaf(source io.Reader) (*CSVSource, error) { return proj, err } -func (csvS *CSVSource) Next(n uint64) (*operators.RecordBatch, error) { +func (csvS *CSVSource) Next(n uint16) (*operators.RecordBatch, error) { if csvS.done { return nil, io.EOF } @@ -44,7 +48,7 @@ func (csvS *CSVSource) Next(n uint64) (*operators.RecordBatch, error) { // 1. Create builders builders := csvS.initBuilders() - rowsRead := uint64(0) + rowsRead := uint16(0) // Process stored first row (from parseHeader) --- if csvS.firstDataRow != nil && rowsRead < n { @@ -92,6 +96,9 @@ func (csvS *CSVSource) Close() error { return nil } +func (csvS *CSVSource) Schema() *arrow.Schema { + return csvS.schema +} func (csvS *CSVSource) initBuilders() []array.Builder { fields := csvS.schema.Fields() builders := make([]array.Builder, len(fields)) diff --git a/src/Backend/opti-sql-go/operators/project/source/csv_test.go b/src/Backend/opti-sql-go/operators/project/csv_test.go similarity index 99% rename from src/Backend/opti-sql-go/operators/project/source/csv_test.go rename to src/Backend/opti-sql-go/operators/project/csv_test.go index 3d08625..07a38d5 100644 --- a/src/Backend/opti-sql-go/operators/project/source/csv_test.go +++ b/src/Backend/opti-sql-go/operators/project/csv_test.go @@ -1,4 +1,4 @@ -package source +package project import ( "fmt" @@ -12,7 +12,7 @@ import ( "github.com/apache/arrow/go/v17/arrow/memory" ) -const csvFilePath = "../../../../test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv" +const csvFilePath = "../../../test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv" //const csvFilePathLarger = "../../../../test_data/csv/stats.csv" diff --git a/src/Backend/opti-sql-go/operators/project/source/custom.go b/src/Backend/opti-sql-go/operators/project/custom.go similarity index 92% rename from 
src/Backend/opti-sql-go/operators/project/source/custom.go rename to src/Backend/opti-sql-go/operators/project/custom.go index f7fa96e..28b56ea 100644 --- a/src/Backend/opti-sql-go/operators/project/source/custom.go +++ b/src/Backend/opti-sql-go/operators/project/custom.go @@ -1,16 +1,19 @@ -package source +package project import ( "fmt" "io" "opti-sql-go/operators" - "opti-sql-go/operators/project" "github.com/apache/arrow/go/v15/arrow/memory" "github.com/apache/arrow/go/v17/arrow" "github.com/apache/arrow/go/v17/arrow/array" ) +var ( + _ = (operators.Operator)(&InMemorySource{}) +) + // in memory format just for the ease of testing // same as other sources, we can use structs/slices here @@ -25,7 +28,7 @@ var ( type InMemorySource struct { schema *arrow.Schema columns []arrow.Array - pos uint64 + pos uint16 fieldToColIDx map[string]int } @@ -57,7 +60,7 @@ func NewInMemoryProjectExec(names []string, columns []any) (*InMemorySource, err } func (ms *InMemorySource) withFields(names ...string) error { - newSchema, cols, err := project.ProjectSchemaFilterDown(ms.schema, ms.columns, names...) + newSchema, cols, err := ProjectSchemaFilterDown(ms.schema, ms.columns, names...) 
if err != nil { return err } @@ -70,16 +73,16 @@ func (ms *InMemorySource) withFields(names ...string) error { ms.columns = cols return nil } -func (ms *InMemorySource) Next(n uint64) (*operators.RecordBatch, error) { - if ms.pos >= uint64(ms.columns[0].Len()) { +func (ms *InMemorySource) Next(n uint16) (*operators.RecordBatch, error) { + if ms.pos >= uint16(ms.columns[0].Len()) { return nil, io.EOF // EOF } - var currRows uint64 = 0 + var currRows uint16 = 0 outPutCols := make([]arrow.Array, len(ms.schema.Fields())) for i, field := range ms.schema.Fields() { col := ms.columns[ms.fieldToColIDx[field.Name]] - colLen := uint64(col.Len()) + colLen := uint16(col.Len()) remaining := colLen - ms.pos toRead := n if remaining < n { @@ -96,6 +99,15 @@ func (ms *InMemorySource) Next(n uint64) (*operators.RecordBatch, error) { Columns: outPutCols, }, nil } +func (ms *InMemorySource) Close() error { + for _, c := range ms.columns { + c.Release() + } + return nil +} +func (ms *InMemorySource) Schema() *arrow.Schema { + return ms.schema +} func unpackColumm(name string, col any) (arrow.Field, arrow.Array, error) { // need to not only build the array; but also need the schema var field arrow.Field diff --git a/src/Backend/opti-sql-go/operators/project/source/custom_test.go b/src/Backend/opti-sql-go/operators/project/custom_test.go similarity index 99% rename from src/Backend/opti-sql-go/operators/project/source/custom_test.go rename to src/Backend/opti-sql-go/operators/project/custom_test.go index 1af0bf5..84e9db2 100644 --- a/src/Backend/opti-sql-go/operators/project/source/custom_test.go +++ b/src/Backend/opti-sql-go/operators/project/custom_test.go @@ -1,4 +1,4 @@ -package source +package project import ( "fmt" diff --git a/src/Backend/opti-sql-go/operators/project/source/parquet.go b/src/Backend/opti-sql-go/operators/project/parquet.go similarity index 96% rename from src/Backend/opti-sql-go/operators/project/source/parquet.go rename to 
src/Backend/opti-sql-go/operators/project/parquet.go index b2e848f..c9e9940 100644 --- a/src/Backend/opti-sql-go/operators/project/source/parquet.go +++ b/src/Backend/opti-sql-go/operators/project/parquet.go @@ -1,4 +1,4 @@ -package source +package project import ( "context" @@ -16,9 +16,13 @@ import ( "github.com/apache/arrow/go/v17/parquet/pqarrow" ) +var ( + _ = (operators.Operator)(&ParquetSource{}) +) + type ParquetSource struct { // existing fields - Schema *arrow.Schema + schema *arrow.Schema projectionPushDown []string // columns to project up predicatePushDown []filter.FilterExpr // simple predicate push down for now reader pqarrow.RecordReader @@ -54,7 +58,7 @@ func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { } return &ParquetSource{ - Schema: rdr.Schema(), + schema: rdr.Schema(), projectionPushDown: []string{}, predicatePushDown: nil, reader: rdr, @@ -104,7 +108,7 @@ func NewParquetSourcePushDown(r parquet.ReaderAtSeeker, columns []string, filter } return &ParquetSource{ - Schema: rdr.Schema(), + schema: rdr.Schema(), projectionPushDown: columns, predicatePushDown: filters, reader: rdr, @@ -116,7 +120,7 @@ func (ps *ParquetSource) Next(n uint16) (*operators.RecordBatch, error) { if ps.reader == nil || ps.done || !ps.reader.Next() { return nil, io.EOF } - columns := make([]arrow.Array, len(ps.Schema.Fields())) + columns := make([]arrow.Array, len(ps.schema.Fields())) curRow := 0 for curRow < int(n) && ps.reader.Next() { err := ps.reader.Err() @@ -159,7 +163,7 @@ func (ps *ParquetSource) Next(n uint16) (*operators.RecordBatch, error) { curRow += numRows } return &operators.RecordBatch{ - Schema: ps.Schema, // Remove the pointer as ps.Schema is already of type arrow.Schema + Schema: ps.schema, // Remove the pointer as ps.Schema is already of type arrow.Schema Columns: columns, RowCount: uint64(curRow), }, nil @@ -169,6 +173,9 @@ func (ps *ParquetSource) Close() error { ps.reader = nil return nil } +func (ps *ParquetSource) Schema() 
*arrow.Schema { + return ps.schema +} // append arr2 to arr1 so (arr1 + arr2) = arr1-arr2 func CombineArray(a1, a2 arrow.Array) arrow.Array { diff --git a/src/Backend/opti-sql-go/operators/project/source/parquet_test.go b/src/Backend/opti-sql-go/operators/project/parquet_test.go similarity index 99% rename from src/Backend/opti-sql-go/operators/project/source/parquet_test.go rename to src/Backend/opti-sql-go/operators/project/parquet_test.go index 15ad8c5..3345c36 100644 --- a/src/Backend/opti-sql-go/operators/project/source/parquet_test.go +++ b/src/Backend/opti-sql-go/operators/project/parquet_test.go @@ -1,4 +1,4 @@ -package source +package project import ( "fmt" @@ -11,7 +11,7 @@ import ( "github.com/apache/arrow/go/v17/arrow/memory" ) -const ParquetTestDatafile = "../../../../test_data/parquet/capitals_clean.parquet" +const ParquetTestDatafile = "../../../test_data/parquet/capitals_clean.parquet" func getTestParuqetFile() *os.File { file, err := os.Open(ParquetTestDatafile) @@ -80,7 +80,7 @@ func TestParquetInit(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - schema := source.Schema + schema := source.Schema() if len(schema.Fields()) != len(columns) { t.Errorf("Expected schema to have %d fields, got %d", len(columns), len(schema.Fields())) } diff --git a/src/Backend/opti-sql-go/operators/project/projectExec.go b/src/Backend/opti-sql-go/operators/project/projectExec.go index 259c9eb..f98c443 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec.go @@ -2,19 +2,87 @@ package project import ( "errors" + "io" + "opti-sql-go/operators" "github.com/apache/arrow/go/v17/arrow" ) +var ( + _ = (operators.Operator)(&ProjectExec{}) +) + +var ( + ErrProjectColumnNotFound = errors.New("project: column not found") + ErrEmptyColumnsToProject = errors.New("project: no columns to project") +) + +type ProjectExec struct { + child operators.Operator + outputschema arrow.Schema + 
columnsToKeep []string + // if done always return eof + done bool +} + +// columns to keep and existing scehma +// TODO: double check that calling operator.Schema here is fine, this means make sure all operators have their schema in order as soons as possible +func NewProjectExec(projectColumns []string, input operators.Operator) (*ProjectExec, error) { + newSchema, err := prunedSchema(input.Schema(), projectColumns) + if err != nil { + return nil, err + } + // return new exec + return &ProjectExec{ + child: input, + outputschema: *newSchema, + columnsToKeep: projectColumns, + }, nil +} + +// pretty simple, read from child operator and prune columns +// pass through error && handles EOF alike +func (p *ProjectExec) Next(n uint16) (*operators.RecordBatch, error) { + if p.done { + return nil, io.EOF + } + + rc, err := p.child.Next(n) + if err != nil { + return nil, err + } + _, orderCols, err := ProjectSchemaFilterDown(rc.Schema, rc.Columns, p.columnsToKeep...) + if err != nil { + return nil, err + } + for _, c := range rc.Columns { + c.Release() + } + if rc.RowCount == 0 { + p.done = true + } + return &operators.RecordBatch{ + Schema: &p.outputschema, + Columns: orderCols, + RowCount: rc.RowCount, + }, nil +} +func (p *ProjectExec) Close() error { + return nil +} +func (p *ProjectExec) Schema() *arrow.Schema { + return &p.outputschema +} + // handle keeping only the request columsn but make sure the schema and columns are also aligned // returns error if a column doesnt exist func ProjectSchemaFilterDown(schema *arrow.Schema, cols []arrow.Array, keepCols ...string) (*arrow.Schema, []arrow.Array, error) { if len(keepCols) == 0 { - return arrow.NewSchema([]arrow.Field{}, nil), nil, errors.New("no columns passed in") + return arrow.NewSchema([]arrow.Field{}, nil), nil, ErrEmptyColumnsToProject } // Build map: columnName -> original index - fieldIndex := make(map[string]int) + fieldIndex := make(map[string]int) // age -> 0 for i, f := range schema.Fields() { 
fieldIndex[f.Name] = i } @@ -26,13 +94,34 @@ func ProjectSchemaFilterDown(schema *arrow.Schema, cols []arrow.Array, keepCols for _, name := range keepCols { idx, exists := fieldIndex[name] if !exists { - return arrow.NewSchema([]arrow.Field{}, nil), []arrow.Array{}, errors.New("invalid column passed in to be pruned") + return arrow.NewSchema([]arrow.Field{}, nil), []arrow.Array{}, ErrProjectColumnNotFound } newFields = append(newFields, schema.Field(idx)) - newCols = append(newCols, cols[idx]) + col := cols[idx] + col.Retain() + newCols = append(newCols, col) } newSchema := arrow.NewSchema(newFields, nil) return newSchema, newCols, nil } + +// passing in a coulumn to keep that doesnt exist returns error +// passing in no columns returns and error +func prunedSchema(schema *arrow.Schema, keepCols []string) (*arrow.Schema, error) { + if len(keepCols) == 0 { + return arrow.NewSchema([]arrow.Field{}, nil), ErrEmptyColumnsToProject + } + newFields := make([]arrow.Field, 0) + for _, colName := range keepCols { + idx := schema.FieldIndices(colName) + if len(idx) == 0 { + return nil, ErrProjectColumnNotFound + } + // append the field + newFields = append(newFields, schema.Field(idx[0])) + } + newSchema := arrow.NewSchema(newFields, nil) + return newSchema, nil +} diff --git a/src/Backend/opti-sql-go/operators/project/projectExec_test.go b/src/Backend/opti-sql-go/operators/project/projectExec_test.go index 404af84..b7981bb 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec_test.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec_test.go @@ -1,7 +1,276 @@ package project -import "testing" +import ( + "errors" + "fmt" + "io" + "testing" -func TestProjectExec(t *testing.T) { + "github.com/apache/arrow/go/v17/arrow" +) + +func TestProjectExecInit(t *testing.T) { // Simple passing test } + +func TestProjectPrune(t *testing.T) { + fields := []arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int64}, + {Name: "name", Type: arrow.BinaryTypes.String}, 
+ {Name: "age", Type: arrow.PrimitiveTypes.Int64}, + {Name: "country", Type: arrow.BinaryTypes.String}, + {Name: "email", Type: arrow.BinaryTypes.String}, + {Name: "signup_date", Type: arrow.FixedWidthTypes.Date32}, + } + schema := arrow.NewSchema(fields, nil) + t.Run("validate prune 1", func(t *testing.T) { + keepCols := []string{"id", "name", "email"} + newSchema, err := prunedSchema(schema, keepCols) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if newSchema.NumFields() != len(keepCols) { + t.Fatalf("expected %d fields, got %d", len(keepCols), newSchema.NumFields()) + } + for i, field := range newSchema.Fields() { + if field.Name != keepCols[i] { + t.Fatalf("expected field %s, got %s", keepCols[i], field.Name) + } + } + fmt.Printf("%s\n", newSchema) + }) + t.Run("validate prune 2", func(t *testing.T) { + keeptCols := []string{"age", "country", "signup_date"} + newSchema, err := prunedSchema(schema, keeptCols) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if newSchema.NumFields() != len(keeptCols) { + t.Fatalf("expected %d fields, got %d", len(keeptCols), newSchema.NumFields()) + } + for i, field := range newSchema.Fields() { + if field.Name != keeptCols[i] { + t.Fatalf("expected field %s, got %s", keeptCols[i], field.Name) + } + } + fmt.Printf("%s\n", newSchema) + + }) + t.Run("prune non-existant column", func(t *testing.T) { + keepCols := []string{"id", "non_existing_column"} + _, err := prunedSchema(schema, keepCols) + if err == nil { + t.Fatalf("expected error for non-existing column, got nil") + } + if !errors.Is(err, ErrProjectColumnNotFound) { + t.Fatalf("expected ErrProjectColumnNotFound, got %v", err) + } + + }) + t.Run("Prune empty input keepcols", func(t *testing.T) { + keepCols := []string{} + _, err := prunedSchema(schema, keepCols) + if err == nil { + t.Fatalf("expected error for empty keepcols, got nil") + } + if !errors.Is(err, ErrEmptyColumnsToProject) { + t.Fatalf("expected ErrEmptyColumnsToProject, 
got %v", err) + } + }) + +} +func TestProjectExec(t *testing.T) { + names, col := generateTestColumns() + memorySource, err := NewInMemoryProjectExec(names, col) + if err != nil { + t.Fatalf("failed to create in memory source: %v", err) + } + fmt.Printf("original schema %v\n", memorySource.Schema()) + projectExec, err := NewProjectExec([]string{"id", "name", "age"}, memorySource) + if err != nil { + t.Fatalf("failed to create project exec: %v", err) + } + rc, err := projectExec.Next(3) + if err != nil { + t.Fatalf("failed to get next record batch: %v", err) + } + fmt.Printf("rc:%v\n", rc) + +} + +// NewProjectExec, pruned schema errors and iteration behavior. +func TestProjectExec_Subtests(t *testing.T) { + names, cols := generateTestColumns() + + t.Run("ValidProjection", func(t *testing.T) { + memSrc, err := NewInMemoryProjectExec(names, cols) + if err != nil { + t.Fatalf("failed to create in memory source: %v", err) + } + projCols := []string{"id", "name", "age"} + projExec, err := NewProjectExec(projCols, memSrc) + if err != nil { + t.Fatalf("failed to create project exec: %v", err) + } + rb, err := projExec.Next(4) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + if rb == nil { + t.Fatalf("expected a record batch, got nil") + } + if len(rb.Columns) != len(projCols) { + t.Fatalf("expected %d columns, got %d", len(projCols), len(rb.Columns)) + } + for _, c := range rb.Columns { + c.Release() + } + }) + + t.Run("EmptyColumns", func(t *testing.T) { + memSrc, err := NewInMemoryProjectExec(names, cols) + if err != nil { + t.Fatalf("failed to create in memory source: %v", err) + } + _, err = NewProjectExec([]string{}, memSrc) + if err == nil { + t.Fatalf("expected error for empty project columns, got nil") + } + if !errors.Is(err, ErrEmptyColumnsToProject) { + t.Fatalf("expected ErrEmptyColumnsToProject, got %v", err) + } + }) + + t.Run("NonExistentColumn", func(t *testing.T) { + memSrc, err := NewInMemoryProjectExec(names, cols) + if err != nil { + 
t.Fatalf("failed to create in memory source: %v", err) + } + _, err = NewProjectExec([]string{"id", "nope"}, memSrc) + if err == nil { + t.Fatalf("expected error for non-existent column, got nil") + } + if !errors.Is(err, ErrProjectColumnNotFound) { + t.Fatalf("expected ErrProjectColumnNotFound, got %v", err) + } + }) + + t.Run("SchemaMatch", func(t *testing.T) { + memSrc, err := NewInMemoryProjectExec(names, cols) + if err != nil { + t.Fatalf("failed to create in memory source: %v", err) + } + projCols := []string{"id", "name"} + projExec, err := NewProjectExec(projCols, memSrc) + if err != nil { + t.Fatalf("failed to create project exec: %v", err) + } + execSchema := projExec.Schema() + pruned, err := prunedSchema(memSrc.Schema(), projCols) + if err != nil { + t.Fatalf("prunedSchema failed: %v", err) + } + if !execSchema.Equal(pruned) { + t.Fatalf("expected exec schema %v, got %v", pruned, execSchema) + } + _ = projExec + }) + + t.Run("IterateUntilEOF", func(t *testing.T) { + memSrc, err := NewInMemoryProjectExec(names, cols) + if err != nil { + t.Fatalf("failed to create in memory source: %v", err) + } + projExec, err := NewProjectExec([]string{"id", "name"}, memSrc) + if err != nil { + t.Fatalf("failed to create project exec: %v", err) + } + total := 0 + batches := 0 + for { + rb, err := projExec.Next(3) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Fatalf("Next returned unexpected error: %v", err) + } + if rb == nil { + t.Fatalf("expected record batch, got nil") + } + total += int(rb.Columns[0].Len()) + batches++ + for _, c := range rb.Columns { + c.Release() + } + } + if batches == 0 { + t.Fatalf("expected at least 1 batch, got 0") + } + }) + + t.Run("SingleColumnProjection", func(t *testing.T) { + memSrc, err := NewInMemoryProjectExec(names, cols) + if err != nil { + t.Fatalf("failed to create in memory source: %v", err) + } + projExec, err := NewProjectExec([]string{"department"}, memSrc) + if err != nil { + t.Fatalf("failed to create 
project exec: %v", err) + } + total := 0 + for { + rb, err := projExec.Next(5) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Fatalf("Next returned unexpected error: %v", err) + } + if len(rb.Columns) != 1 { + t.Fatalf("expected 1 column, got %d", len(rb.Columns)) + } + total += int(rb.Columns[0].Len()) + for _, c := range rb.Columns { + c.Release() + } + } + }) + t.Run("Check Close", func(t *testing.T) { + memSrc, err := NewInMemoryProjectExec(names, cols) + if err != nil { + t.Fatalf("failed to create in memory source: %v", err) + } + projExec, err := NewProjectExec([]string{"department"}, memSrc) + if err != nil { + t.Fatalf("failed to create project exec: %v", err) + } + err = projExec.Close() + if err != nil { + t.Fatalf("expected no error on Close, got %v", err) + } + + }) + t.Run("Empty ProjectFilter", func(t *testing.T) { + memSrc, err := NewInMemoryProjectExec(names, cols) + if err != nil { + t.Fatalf("failed to create in memory source: %v", err) + } + _, _, err = ProjectSchemaFilterDown(memSrc.Schema(), memSrc.columns, []string{}...) + if err == nil { + t.Fatalf("expected error for empty project filter, got nil") + } + if !errors.Is(err, ErrEmptyColumnsToProject) { + t.Fatalf("expected ErrEmptyColumnsToProject, got %v", err) + } + _, _, err = ProjectSchemaFilterDown(memSrc.Schema(), memSrc.columns, []string{"This column doesnt exist"}...) 
+ if err == nil { + t.Fatalf("expected error for non-existent column in project filter, got nil") + } + if !errors.Is(err, ErrProjectColumnNotFound) { + t.Fatalf("expected ErrProjectColumnNotFound, got %v", err) + } + + }) + +} diff --git a/src/Backend/opti-sql-go/operators/project/source/s3.go b/src/Backend/opti-sql-go/operators/project/s3.go similarity index 99% rename from src/Backend/opti-sql-go/operators/project/source/s3.go rename to src/Backend/opti-sql-go/operators/project/s3.go index 73e5379..cb47a62 100644 --- a/src/Backend/opti-sql-go/operators/project/source/s3.go +++ b/src/Backend/opti-sql-go/operators/project/s3.go @@ -1,4 +1,4 @@ -package source +package project import ( "fmt" diff --git a/src/Backend/opti-sql-go/operators/project/source/source_test.go b/src/Backend/opti-sql-go/operators/project/source_test.go similarity index 99% rename from src/Backend/opti-sql-go/operators/project/source/source_test.go rename to src/Backend/opti-sql-go/operators/project/source_test.go index 05cba84..debdbf4 100644 --- a/src/Backend/opti-sql-go/operators/project/source/source_test.go +++ b/src/Backend/opti-sql-go/operators/project/source_test.go @@ -1,4 +1,4 @@ -package source +package project import ( "fmt" diff --git a/src/Backend/opti-sql-go/operators/record.go b/src/Backend/opti-sql-go/operators/record.go index 5b1b814..2059d3a 100644 --- a/src/Backend/opti-sql-go/operators/record.go +++ b/src/Backend/opti-sql-go/operators/record.go @@ -17,6 +17,7 @@ var ( type Operator interface { Next(uint16) (*RecordBatch, error) + Schema() *arrow.Schema // Call Operator.Close() after Next retruns an io.EOF to clean up resources Close() error } From e7e9d82928a8876e425729a2e26e87c49ad05794 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Mon, 17 Nov 2025 02:17:19 -0500 Subject: [PATCH 07/19] chore: add next steps. 
Looking at Expr and Filter operators --- src/Backend/opti-sql-go/main.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Backend/opti-sql-go/main.go b/src/Backend/opti-sql-go/main.go index 82e1eb8..f277de6 100644 --- a/src/Backend/opti-sql-go/main.go +++ b/src/Backend/opti-sql-go/main.go @@ -6,6 +6,8 @@ import ( "os" ) +// TODO: in the project operators make sure the record batches account for the RowCount field properly. + func main() { if len(os.Args) > 1 { if err := config.Decode(os.Args[1]); err != nil { From f97e73b437378f5710edcb48e3f81c5280bf5f67 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Mon, 17 Nov 2025 02:17:45 -0500 Subject: [PATCH 08/19] chore: add next steps. Looking at Expr and Filter operators --- src/Backend/opti-sql-go/operators/Expr/expr.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/Backend/opti-sql-go/operators/Expr/expr.go b/src/Backend/opti-sql-go/operators/Expr/expr.go index 25c8d3c..5334beb 100644 --- a/src/Backend/opti-sql-go/operators/Expr/expr.go +++ b/src/Backend/opti-sql-go/operators/Expr/expr.go @@ -4,3 +4,21 @@ package Expr // for example Column + Literal // Column - Column // Literal / Literal + +//1. Arithmetic Expressions +// SELECT salary * 1.2, price + tax, -(discount) +//2.Alias Expressions +//SELECT name AS employee_name, age AS employee_age +//SELECT salary * 1.2 AS new_salary +//3.String Expressions +//first_name || ' ' || last_name +//UPPER(name) +//LOWER(email) +//SUBSTRING(name, 1, 3) +//4. Function calls +//ABS(x) +//ROUND(salary, 2) +//LENGTH(name) +//COALESCE(a, b) +//5. 
Constants +//SELECT 1, 'hello', 3.14 From 86b1e3ccda3aadff366f6e691c79fc85c694735b Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Mon, 17 Nov 2025 10:51:48 -0500 Subject: [PATCH 09/19] fix:revised PR --- src/Backend/opti-sql-go/config/config.go | 9 ++-- .../opti-sql-go/operators/filter/filter.go | 2 +- .../opti-sql-go/operators/project/csv.go | 26 ++++++---- .../opti-sql-go/operators/project/csv_test.go | 48 ++----------------- .../opti-sql-go/operators/project/custom.go | 16 +++---- .../operators/project/custom_test.go | 47 +++++++++--------- .../opti-sql-go/operators/project/parquet.go | 7 --- .../operators/project/parquet_test.go | 23 +++++---- .../operators/project/projectExec.go | 10 ++-- .../operators/project/projectExec_test.go | 9 ++-- .../opti-sql-go/operators/project/s3.go | 7 --- .../operators/project/source_test.go | 21 ++++---- src/Backend/opti-sql-go/operators/record.go | 4 +- .../opti-sql-go/operators/serialize_test.go | 8 ++-- 14 files changed, 91 insertions(+), 146 deletions(-) diff --git a/src/Backend/opti-sql-go/config/config.go b/src/Backend/opti-sql-go/config/config.go index 7f70612..627136b 100644 --- a/src/Backend/opti-sql-go/config/config.go +++ b/src/Backend/opti-sql-go/config/config.go @@ -33,9 +33,8 @@ type batchConfig struct { EnableParallelRead bool `yaml:"enable_parallel_read"` MaxMemoryBeforeSpill uint64 `yaml:"max_memory_before_spill"` MaxFileSizeMB int `yaml:"max_file_size_mb"` // max size of a single file - // TODO: add test for these two fileds, just add to existing test - ShouldDowndload bool `yaml:"should_download"` - MaxDownloadSizeMB int `yaml:"max_download_size_mb"` // max size to download from external sources like S3 + ShouldDownload bool `yaml:"should_download"` + MaxDownloadSizeMB int `yaml:"max_download_size_mb"` // max size to download from external sources like S3 } type queryConfig struct { // should results be cached, server side? 
if so how long @@ -76,7 +75,7 @@ var configInstance *Config = &Config{ MaxFileSizeMB: 500, // 500MB // should we download files from external sources like S3 // if so whats the max size to download, if its greater than dont download the file locally - ShouldDowndload: true, + ShouldDownload: true, MaxDownloadSizeMB: 10, // 10MB }, Query: queryConfig{ @@ -160,7 +159,7 @@ func mergeConfig(dst *Config, src map[string]interface{}) { dst.Batch.MaxFileSizeMB = v } if v, ok := batch["should_download"].(bool); ok { - dst.Batch.ShouldDowndload = v + dst.Batch.ShouldDownload = v } if v, ok := batch["max_download_size_mb"].(int); ok { dst.Batch.MaxDownloadSizeMB = v diff --git a/src/Backend/opti-sql-go/operators/filter/filter.go b/src/Backend/opti-sql-go/operators/filter/filter.go index 0fc5e5c..4195cdd 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter.go +++ b/src/Backend/opti-sql-go/operators/filter/filter.go @@ -5,7 +5,7 @@ import ( "github.com/apache/arrow/go/v17/arrow/array" ) -// FilterExpr takes in a field and column and yeildss a function that takes in an index and returns a bool indicating whether the row at that index satisfies the filter condition. +// FilterExpr takes in a field and column and yeilds a function that takes in an index and returns a bool indicating whether the row at that index satisfies the filter condition. 
type FilterExpr func(filed arrow.Field, col arrow.Array) func(i int) bool // example diff --git a/src/Backend/opti-sql-go/operators/project/csv.go b/src/Backend/opti-sql-go/operators/project/csv.go index b5fc714..1a021c6 100644 --- a/src/Backend/opti-sql-go/operators/project/csv.go +++ b/src/Backend/opti-sql-go/operators/project/csv.go @@ -17,8 +17,6 @@ var ( _ = (operators.Operator)(&CSVSource{}) ) -// TODO: change the leaf stuff to be called scans instead - type CSVSource struct { r *csv.Reader schema *arrow.Schema // columns to project as well as types to cast to @@ -52,7 +50,6 @@ func (csvS *CSVSource) Next(n uint16) (*operators.RecordBatch, error) { // Process stored first row (from parseHeader) --- if csvS.firstDataRow != nil && rowsRead < n { - fmt.Printf("First row: %v\n", csvS.firstDataRow) if err := csvS.processRow(csvS.firstDataRow, builders); err != nil { return nil, err } @@ -86,8 +83,9 @@ func (csvS *CSVSource) Next(n uint16) (*operators.RecordBatch, error) { columns := csvS.finalizeBuilders(builders) return &operators.RecordBatch{ - Schema: csvS.schema, - Columns: columns, + Schema: csvS.schema, + Columns: columns, + RowCount: uint64(rowsRead), }, nil } func (csvS *CSVSource) Close() error { @@ -124,16 +122,26 @@ func (csvS *CSVSource) processRow( if cell == "" || cell == "NULL" { b.AppendNull() } else { - v, _ := strconv.ParseInt(cell, 10, 64) - b.Append(v) + v, err := strconv.ParseInt(cell, 10, 64) + if err != nil { + fmt.Printf("failed to parse cell: %v with error: %v\n", cell, err) + b.AppendNull() + } else { + b.Append(v) + } } case *array.Float64Builder: if cell == "" || cell == "NULL" { b.AppendNull() } else { - v, _ := strconv.ParseFloat(cell, 64) - b.Append(v) + v, err := strconv.ParseFloat(cell, 64) + if err != nil { + fmt.Printf("failed to parse cell: %v with error: %v\n", cell, err) + b.AppendNull() + } else { + b.Append(v) + } } case *array.StringBuilder: diff --git a/src/Backend/opti-sql-go/operators/project/csv_test.go 
b/src/Backend/opti-sql-go/operators/project/csv_test.go index 07a38d5..8a119b5 100644 --- a/src/Backend/opti-sql-go/operators/project/csv_test.go +++ b/src/Backend/opti-sql-go/operators/project/csv_test.go @@ -1,7 +1,6 @@ package project import ( - "fmt" "io" "os" "strings" @@ -14,8 +13,6 @@ import ( const csvFilePath = "../../../test_data/csv/Mental_Health_and_Social_Media_Balance_Dataset.csv" -//const csvFilePathLarger = "../../../../test_data/csv/stats.csv" - func getTestFile() *os.File { v, err := os.Open(csvFilePath) if err != nil { @@ -24,15 +21,6 @@ func getTestFile() *os.File { return v } -/* - func getTestFile2() *os.File { - v, err := os.Open(csvFilePathLarger) - if err != nil { - panic(err) - } - return v - } -*/ func TestCsvInit(t *testing.T) { v := getTestFile() defer func() { @@ -44,8 +32,8 @@ func TestCsvInit(t *testing.T) { if err != nil { t.Errorf("Failed to create ProjectCSVLeaf: %v", err) } - fmt.Printf("schema -> %v\n", p.schema) - fmt.Printf("columns Mapping -> %v\n", p.colPosition) + t.Logf("schema -> %v\n", p.schema) + t.Logf("columns Mapping -> %v\n", p.colPosition) } func TestProjectComponents(t *testing.T) { v := getTestFile() @@ -81,7 +69,7 @@ func TestCsvNext(t *testing.T) { if err != nil { t.Errorf("Failed to read next batch from CSV: %v", err) } - fmt.Printf("Batch: %v\n", rBatch) + t.Logf("Batch: %v\n", rBatch) } // TestParseDataType tests every branch of the parseDataType function @@ -425,7 +413,7 @@ func TestNextFunction(t *testing.T) { t.Errorf("Column %d: expected 3 rows, got %d", i, col.Len()) } } - fmt.Printf("col0: %v\n", batch.Columns[0]) + t.Logf("col0: %v\n", batch.Columns[0]) // Check Int64 column (id) idCol, ok := batch.Columns[0].(*array.Int64) if !ok { @@ -580,7 +568,7 @@ false if !ok { t.Fatalf("Expected *array.Boolean, got %T", batch.Columns[0]) } - fmt.Printf("flagCol : %v\n", flagCol) + t.Logf("flagCol : %v\n", flagCol) if !flagCol.IsNull(1) { t.Error("Expected NULL values in flag column") @@ -967,29 +955,3 @@ func 
TestProccessFirstLine(t *testing.T) { } } - -/* -func TestLargercsvFile(t *testing.T) { - f1 := getTestFile2() - - project, err := NewProjectCSVLeaf(f1) - if err != nil { - t.Fatalf("NewProjectCSVLeaf failed: %v", err) - } - defer func() { - if err := f1.Close(); err != nil { - t.Fatalf("failed to close: %v", err) - } - }() - for { - rc, err := project.Next(1024 * 8) - if err == io.EOF { - break - } - if err != nil { - t.Fatalf("Next failed: %v", err) - } - fmt.Printf("rc : %v\n", rc.Columns) - } -} -*/ diff --git a/src/Backend/opti-sql-go/operators/project/custom.go b/src/Backend/opti-sql-go/operators/project/custom.go index 28b56ea..38f2a94 100644 --- a/src/Backend/opti-sql-go/operators/project/custom.go +++ b/src/Backend/opti-sql-go/operators/project/custom.go @@ -44,7 +44,7 @@ func NewInMemoryProjectExec(names []string, columns []any) (*InMemorySource, err if !supportedType(col) { return nil, operators.ErrInvalidSchema(fmt.Sprintf("unsupported column type for column %s", names[i])) } - field, arr, err := unpackColumm(names[i], col) + field, arr, err := unpackColumn(names[i], col) if err != nil { return nil, ErrInvalidInMemoryDataType(col) } @@ -74,7 +74,7 @@ func (ms *InMemorySource) withFields(names ...string) error { return nil } func (ms *InMemorySource) Next(n uint16) (*operators.RecordBatch, error) { - if ms.pos >= uint16(ms.columns[0].Len()) { + if len(ms.columns) == 0 || ms.pos >= uint16(ms.columns[0].Len()) { return nil, io.EOF // EOF } var currRows uint16 = 0 @@ -95,8 +95,9 @@ func (ms *InMemorySource) Next(n uint16) (*operators.RecordBatch, error) { ms.pos += currRows return &operators.RecordBatch{ - Schema: ms.schema, - Columns: outPutCols, + Schema: ms.schema, + Columns: outPutCols, + RowCount: uint64(currRows), }, nil } func (ms *InMemorySource) Close() error { @@ -108,11 +109,11 @@ func (ms *InMemorySource) Close() error { func (ms *InMemorySource) Schema() *arrow.Schema { return ms.schema } -func unpackColumm(name string, col any) (arrow.Field, 
arrow.Array, error) { +func unpackColumn(name string, col any) (arrow.Field, arrow.Array, error) { // need to not only build the array; but also need the schema var field arrow.Field field.Name = name - field.Nullable = true // default to nullable for now + field.Nullable = true // default to nullable switch colType := col.(type) { case []int: field.Type = arrow.PrimitiveTypes.Int64 @@ -124,14 +125,12 @@ func unpackColumm(name string, col any) (arrow.Field, arrow.Array, error) { } return field, b.NewArray(), nil case []int8: - // build int8 array field.Type = arrow.PrimitiveTypes.Int8 data := colType b := array.NewInt8Builder(memory.DefaultAllocator) defer b.Release() b.AppendValues(data, nil) return field, b.NewArray(), nil - // build int8 array case []int16: field.Type = arrow.PrimitiveTypes.Int16 data := colType @@ -212,7 +211,6 @@ func unpackColumm(name string, col any) (arrow.Field, arrow.Array, error) { defer b.Release() b.AppendValues(data, nil) return field, b.NewArray(), nil - // build string array case []bool: field.Type = arrow.FixedWidthTypes.Boolean data := colType diff --git a/src/Backend/opti-sql-go/operators/project/custom_test.go b/src/Backend/opti-sql-go/operators/project/custom_test.go index 84e9db2..ef08946 100644 --- a/src/Backend/opti-sql-go/operators/project/custom_test.go +++ b/src/Backend/opti-sql-go/operators/project/custom_test.go @@ -1,7 +1,6 @@ package project import ( - "fmt" "io" "testing" @@ -75,7 +74,7 @@ func TestInMemoryBatchInit(t *testing.T) { if len(projC.columns) != projC.schema.NumFields() { t.Errorf("Columns and schema field count mismatch: got %d and %d", len(projC.columns), projC.schema.NumFields()) } - fmt.Printf("schema: %v\n", projC.schema) + t.Logf("schema: %v\n", projC.schema) } // ==================== COMPREHENSIVE TESTS FOR 100% CODE COVERAGE ==================== @@ -135,7 +134,7 @@ func TestSupportedType(t *testing.T) { // TestUnpackColumn tests every branch of the unpackColumm function func TestUnpackColumn(t 
*testing.T) { t.Run("[]int type", func(t *testing.T) { - field, arr, err := unpackColumm("test_int", []int{1, 2, 3, 4, 5}) + field, arr, err := unpackColumn("test_int", []int{1, 2, 3, 4, 5}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -163,7 +162,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]int8 type", func(t *testing.T) { - field, arr, err := unpackColumm("test_int8", []int8{-1, 0, 1, 127}) + field, arr, err := unpackColumn("test_int8", []int8{-1, 0, 1, 127}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -180,7 +179,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]int16 type", func(t *testing.T) { - field, arr, err := unpackColumm("test_int16", []int16{-100, 0, 100, 32767}) + field, arr, err := unpackColumn("test_int16", []int16{-100, 0, 100, 32767}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -197,7 +196,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]int32 type", func(t *testing.T) { - field, arr, err := unpackColumm("test_int32", []int32{-1000, 0, 1000, 2147483647}) + field, arr, err := unpackColumn("test_int32", []int32{-1000, 0, 1000, 2147483647}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -214,7 +213,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]int64 type", func(t *testing.T) { - field, arr, err := unpackColumm("test_int64", []int64{-9223372036854775808, 0, 9223372036854775807}) + field, arr, err := unpackColumn("test_int64", []int64{-9223372036854775808, 0, 9223372036854775807}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -231,7 +230,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]uint type", func(t *testing.T) { - field, arr, err := unpackColumm("test_uint", []uint{0, 1, 100, 1000}) + field, arr, err := unpackColumn("test_uint", []uint{0, 1, 100, 1000}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -254,7 +253,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]uint8 type", func(t *testing.T) { - field, arr, 
err := unpackColumm("test_uint8", []uint8{0, 1, 255}) + field, arr, err := unpackColumn("test_uint8", []uint8{0, 1, 255}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -271,7 +270,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]uint16 type", func(t *testing.T) { - field, arr, err := unpackColumm("test_uint16", []uint16{0, 100, 65535}) + field, arr, err := unpackColumn("test_uint16", []uint16{0, 100, 65535}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -288,7 +287,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]uint32 type", func(t *testing.T) { - field, arr, err := unpackColumm("test_uint32", []uint32{0, 1000, 4294967295}) + field, arr, err := unpackColumn("test_uint32", []uint32{0, 1000, 4294967295}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -305,7 +304,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]uint64 type", func(t *testing.T) { - field, arr, err := unpackColumm("test_uint64", []uint64{0, 1000, 18446744073709551615}) + field, arr, err := unpackColumn("test_uint64", []uint64{0, 1000, 18446744073709551615}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -322,7 +321,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]float32 type", func(t *testing.T) { - field, arr, err := unpackColumm("test_float32", []float32{-1.5, 0.0, 1.5, 3.14159}) + field, arr, err := unpackColumn("test_float32", []float32{-1.5, 0.0, 1.5, 3.14159}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -339,7 +338,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]float64 type", func(t *testing.T) { - field, arr, err := unpackColumm("test_float64", []float64{-2.718281828, 0.0, 3.141592653589793}) + field, arr, err := unpackColumn("test_float64", []float64{-2.718281828, 0.0, 3.141592653589793}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -356,7 +355,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]string type", func(t *testing.T) { - field, arr, err := 
unpackColumm("test_string", []string{"hello", "world", "test", ""}) + field, arr, err := unpackColumn("test_string", []string{"hello", "world", "test", ""}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -379,7 +378,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("[]bool type", func(t *testing.T) { - field, arr, err := unpackColumm("test_bool", []bool{true, false, true, false, true}) + field, arr, err := unpackColumn("test_bool", []bool{true, false, true, false, true}) if err != nil { t.Fatalf("unpackColumm failed: %v", err) } @@ -402,7 +401,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("Unsupported type - default case", func(t *testing.T) { - _, _, err := unpackColumm("test_unsupported", []byte{1, 2, 3}) + _, _, err := unpackColumn("test_unsupported", []byte{1, 2, 3}) if err != nil { t.Error("unexpected error for unsupported type") } @@ -410,7 +409,7 @@ func TestUnpackColumn(t *testing.T) { }) t.Run("Empty slices", func(t *testing.T) { - field, arr, err := unpackColumm("empty_int", []int{}) + field, arr, err := unpackColumn("empty_int", []int{}) if err != nil { t.Fatalf("unpackColumm failed for empty slice: %v", err) } @@ -664,7 +663,7 @@ func TestSchemaFieldTypes(t *testing.T) { } } -func TestPrunceSchema(t *testing.T) { +func TestPruneSchema(t *testing.T) { names, columns := generateTestColumns() t.Run("Select subset of fields", func(t *testing.T) { @@ -888,9 +887,9 @@ func TestNext(t *testing.T) { if err != nil { t.Error("unexpected error when pruning columns") } - fmt.Printf("updated: %s\n", proj.schema) - fmt.Printf("new Mapping: %v\n", proj.fieldToColIDx) - fmt.Printf("new columns: %v\n", proj.columns) + t.Logf("updated: %s\n", proj.schema) + t.Logf("new Mapping: %v\n", proj.fieldToColIDx) + t.Logf("new columns: %v\n", proj.columns) totalRows := 0 for { @@ -901,8 +900,8 @@ func TestNext(t *testing.T) { if err != nil { t.Fatalf("Next failed: %v", err) } - fmt.Printf("Batche schema: %v\n", batch.Schema) - fmt.Printf("Batch data: %v\n", 
batch.Columns) + t.Logf("Batche schema: %v\n", batch.Schema) + t.Logf("Batch data: %v\n", batch.Columns) // Verify only 1 column if len(batch.Columns) != 1 { diff --git a/src/Backend/opti-sql-go/operators/project/parquet.go b/src/Backend/opti-sql-go/operators/project/parquet.go index c9e9940..50b04a4 100644 --- a/src/Backend/opti-sql-go/operators/project/parquet.go +++ b/src/Backend/opti-sql-go/operators/project/parquet.go @@ -40,7 +40,6 @@ func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { defer func() { if err := filerReader.Close(); err != nil { fmt.Printf("warning: failed to close parquet reader: %v\n", err) - } }() @@ -131,16 +130,10 @@ func (ps *ParquetSource) Next(n uint16) (*operators.RecordBatch, error) { numCols := int(record.NumCols()) numRows := int(record.NumRows()) - fmt.Printf("numCols=%d numRows=%d columns=%v\n", - numCols, numRows, record.Columns(), - ) for colIdx := 0; colIdx < numCols; colIdx++ { batchCol := record.Column(colIdx) existing := columns[colIdx] - fmt.Printf("columns:%v\n", columns) - fmt.Printf("existing:%v\n", existing) - fmt.Printf("batchCol:%v\n", batchCol) // First time seeing this column → just assign it if existing == nil { batchCol.Retain() diff --git a/src/Backend/opti-sql-go/operators/project/parquet_test.go b/src/Backend/opti-sql-go/operators/project/parquet_test.go index 3345c36..c383a07 100644 --- a/src/Backend/opti-sql-go/operators/project/parquet_test.go +++ b/src/Backend/opti-sql-go/operators/project/parquet_test.go @@ -1,7 +1,6 @@ package project import ( - "fmt" "io" "os" "testing" @@ -13,7 +12,7 @@ import ( const ParquetTestDatafile = "../../../test_data/parquet/capitals_clean.parquet" -func getTestParuqetFile() *os.File { +func getTestParquetFile() *os.File { file, err := os.Open(ParquetTestDatafile) if err != nil { panic(err) @@ -57,7 +56,7 @@ func sameStringSlice(a, b []string) bool { } func TestParquetInit(t *testing.T) { t.Run("Test No names pass in", func(t *testing.T) { - f := 
getTestParuqetFile() + f := getTestParquetFile() _, err := NewParquetSourcePushDown(f, []string{}, nil) if err == nil { @@ -66,7 +65,7 @@ func TestParquetInit(t *testing.T) { }) t.Run("Test invalid names are passed in", func(t *testing.T) { - f := getTestParuqetFile() + f := getTestParquetFile() _, err := NewParquetSourcePushDown(f, []string{"non_existent_column"}, nil) if err == nil { t.Errorf("Expected error when invalid column names are passed in, but got nil") @@ -74,7 +73,7 @@ func TestParquetInit(t *testing.T) { }) t.Run("Test correct schema is returned", func(t *testing.T) { - f := getTestParuqetFile() + f := getTestParquetFile() columns := []string{"country", "capital", "lat"} source, err := NewParquetSourcePushDown(f, columns, nil) if err != nil { @@ -93,7 +92,7 @@ func TestParquetInit(t *testing.T) { }) t.Run("Test input columns and filters were passed back out", func(t *testing.T) { - f := getTestParuqetFile() + f := getTestParquetFile() columns := []string{"country", "capital", "lat"} source, err := NewParquetSourcePushDown(f, columns, nil) if err != nil { @@ -109,7 +108,7 @@ func TestParquetInit(t *testing.T) { t.Run("Check reader isnt null", func(t *testing.T) { - f := getTestParuqetFile() + f := getTestParquetFile() columns := []string{"country", "capital", "lat"} source, err := NewParquetSourcePushDown(f, columns, nil) if err != nil { @@ -123,7 +122,7 @@ func TestParquetInit(t *testing.T) { } func TestParquetClose(t *testing.T) { - f := getTestParuqetFile() + f := getTestParquetFile() columns := []string{"country", "capital", "lat"} source, err := NewParquetSourcePushDown(f, columns, nil) if err != nil { @@ -143,7 +142,7 @@ func TestParquetClose(t *testing.T) { } func TestRunToEnd(t *testing.T) { - f := getTestParuqetFile() + f := getTestParquetFile() columns := []string{"country", "capital", "lat"} source, err := NewParquetSourcePushDown(f, columns, nil) if err != nil { @@ -162,7 +161,7 @@ func TestRunToEnd(t *testing.T) { } func TestParquetRead(t 
*testing.T) { - f := getTestParuqetFile() + f := getTestParquetFile() columns := []string{"country", "capital", "lat"} source, err := NewParquetSourcePushDown(f, columns, nil) if err != nil { @@ -182,8 +181,8 @@ func TestParquetRead(t *testing.T) { if rc.Schema.NumFields() != len(columns) { t.Errorf("Expected schema to have %d fields, got %d", len(columns), rc.Schema.NumFields()) } - fmt.Printf("columns:%v\n", rc.Columns) - fmt.Printf("count:%d\n", rc.RowCount) + t.Logf("columns:%v\n", rc.Columns) + t.Logf("count:%d\n", rc.RowCount) } // CombineArray tests: cover primitive, uint, float, bool, string, binary and nil-handling diff --git a/src/Backend/opti-sql-go/operators/project/projectExec.go b/src/Backend/opti-sql-go/operators/project/projectExec.go index f98c443..f633416 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec.go @@ -21,12 +21,10 @@ type ProjectExec struct { child operators.Operator outputschema arrow.Schema columnsToKeep []string - // if done always return eof - done bool + done bool } -// columns to keep and existing scehma -// TODO: double check that calling operator.Schema here is fine, this means make sure all operators have their schema in order as soons as possible +// columns to keep and existing schema func NewProjectExec(projectColumns []string, input operators.Operator) (*ProjectExec, error) { newSchema, err := prunedSchema(input.Schema(), projectColumns) if err != nil { @@ -74,7 +72,7 @@ func (p *ProjectExec) Schema() *arrow.Schema { return &p.outputschema } -// handle keeping only the request columsn but make sure the schema and columns are also aligned +// handle keeping only the request columns but make sure the schema and columns are also aligned // returns error if a column doesnt exist func ProjectSchemaFilterDown(schema *arrow.Schema, cols []arrow.Array, keepCols ...string) (*arrow.Schema, []arrow.Array, error) { if len(keepCols) == 0 { @@ -107,8 +105,6 @@ 
func ProjectSchemaFilterDown(schema *arrow.Schema, cols []arrow.Array, keepCols return newSchema, newCols, nil } -// passing in a coulumn to keep that doesnt exist returns error -// passing in no columns returns and error func prunedSchema(schema *arrow.Schema, keepCols []string) (*arrow.Schema, error) { if len(keepCols) == 0 { return arrow.NewSchema([]arrow.Field{}, nil), ErrEmptyColumnsToProject diff --git a/src/Backend/opti-sql-go/operators/project/projectExec_test.go b/src/Backend/opti-sql-go/operators/project/projectExec_test.go index b7981bb..04a0ecd 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec_test.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec_test.go @@ -2,7 +2,6 @@ package project import ( "errors" - "fmt" "io" "testing" @@ -37,7 +36,7 @@ func TestProjectPrune(t *testing.T) { t.Fatalf("expected field %s, got %s", keepCols[i], field.Name) } } - fmt.Printf("%s\n", newSchema) + t.Logf("%s\n", newSchema) }) t.Run("validate prune 2", func(t *testing.T) { keeptCols := []string{"age", "country", "signup_date"} @@ -53,7 +52,7 @@ func TestProjectPrune(t *testing.T) { t.Fatalf("expected field %s, got %s", keeptCols[i], field.Name) } } - fmt.Printf("%s\n", newSchema) + t.Logf("%s\n", newSchema) }) t.Run("prune non-existant column", func(t *testing.T) { @@ -85,7 +84,7 @@ func TestProjectExec(t *testing.T) { if err != nil { t.Fatalf("failed to create in memory source: %v", err) } - fmt.Printf("original schema %v\n", memorySource.Schema()) + t.Logf("original schema %v\n", memorySource.Schema()) projectExec, err := NewProjectExec([]string{"id", "name", "age"}, memorySource) if err != nil { t.Fatalf("failed to create project exec: %v", err) @@ -94,7 +93,7 @@ func TestProjectExec(t *testing.T) { if err != nil { t.Fatalf("failed to get next record batch: %v", err) } - fmt.Printf("rc:%v\n", rc) + t.Logf("rc:%v\n", rc) } diff --git a/src/Backend/opti-sql-go/operators/project/s3.go b/src/Backend/opti-sql-go/operators/project/s3.go index 
cb47a62..b418503 100644 --- a/src/Backend/opti-sql-go/operators/project/s3.go +++ b/src/Backend/opti-sql-go/operators/project/s3.go @@ -12,13 +12,6 @@ import ( var secretes = config.GetConfig().Secretes -type mime string - -var ( - MimeCSV mime = "csv" - MimeParquet mime = "parquet" -) - type NetworkResource struct { client *minio.Client bucket string diff --git a/src/Backend/opti-sql-go/operators/project/source_test.go b/src/Backend/opti-sql-go/operators/project/source_test.go index debdbf4..facce88 100644 --- a/src/Backend/opti-sql-go/operators/project/source_test.go +++ b/src/Backend/opti-sql-go/operators/project/source_test.go @@ -1,7 +1,6 @@ package project import ( - "fmt" "io" "os" "strings" @@ -39,7 +38,7 @@ func TestS3BasicRead(t *testing.T) { if n != 1024 { t.Fatalf("expected to read 1024 bytes, but read %d bytes", n) } - fmt.Printf("returned contents %s\n", firstKB[:n]) + t.Logf("returned contents %s\n", firstKB[:n]) }) t.Run("parquet read", func(t *testing.T) { @@ -55,7 +54,7 @@ func TestS3BasicRead(t *testing.T) { if n != 1024 { t.Fatalf("expected to read 1024 bytes, but read %d bytes", n) } - fmt.Printf("returned contents %v\n", firstKB[:n]) + t.Logf("returned contents %v\n", firstKB[:n]) }) t.Run("txt read", func(t *testing.T) { @@ -71,7 +70,7 @@ func TestS3BasicRead(t *testing.T) { if n != 1024 { t.Fatalf("expected to read 1024 bytes, but read %d bytes", n) } - fmt.Printf("returned contents %s\n", firstKB[:n]) + t.Logf("returned contents %s\n", firstKB[:n]) }) } @@ -194,7 +193,7 @@ func TestS3ForSource(t *testing.T) { if err != nil { t.Fatalf("failed to read record batch from s3 csv source: %v", err) } - fmt.Printf("returned record batch from s3 csv source: %v\n", rc) + t.Logf("returned record batch from s3 csv source: %v\n", rc) }) t.Run("parquet from s3 source", func(t *testing.T) { @@ -210,7 +209,7 @@ func TestS3ForSource(t *testing.T) { if err != nil { t.Fatalf("failed to read record batch from s3 csv source: %v", err) } - fmt.Printf("returned 
record batch from s3 csv source: %v\n", rc) + t.Logf("returned record batch from s3 csv source: %v\n", rc) }) t.Run("csv from s3 source then downloaded", func(t *testing.T) { @@ -223,7 +222,7 @@ func TestS3ForSource(t *testing.T) { t.Fatalf("failed to download s3 object locally: %v", err) } defer func() { - fmt.Println("deleting downloaded file...") + t.Log("deleting downloaded file...") _ = f.Close() if err := os.Remove(f.Name()); err != nil { t.Fatalf("error closing file %v", f.Name()) @@ -241,7 +240,7 @@ func TestS3ForSource(t *testing.T) { if err != nil { t.Fatalf("failed to close csv project source: %v", err) } - fmt.Printf("returned record batch from s3 csv source: %v\n", rc) + t.Logf("returned record batch from s3 csv source: %v\n", rc) }) t.Run("parquet from s3 source then downloaded", func(t *testing.T) { @@ -254,7 +253,7 @@ func TestS3ForSource(t *testing.T) { t.Fatalf("failed to download s3 object locally: %v", err) } defer func() { - fmt.Println("deleting downloaded file...") + t.Log("deleting downloaded file...") _ = f.Close() if err := os.Remove(f.Name()); err != nil { t.Fatalf("error closing file %v", f.Name()) @@ -268,7 +267,7 @@ func TestS3ForSource(t *testing.T) { if err != nil { t.Fatalf("failed to read record batch from s3 csv source: %v", err) } - fmt.Printf("returned record batch from s3 csv source: %v\n", rc) + t.Logf("returned record batch from s3 csv source: %v\n", rc) }) } @@ -300,6 +299,6 @@ func TestS3Source(t *testing.T) { if n == 0 { t.Fatalf("expected to read some bytes from s3 object stream, but read 0 bytes") } - fmt.Printf("read %d bytes from s3 object stream: %s\n", n, string(buf[:n])) + t.Logf("read %d bytes from s3 object stream: %s\n", n, string(buf[:n])) }) } diff --git a/src/Backend/opti-sql-go/operators/record.go b/src/Backend/opti-sql-go/operators/record.go index 2059d3a..ba2c37d 100644 --- a/src/Backend/opti-sql-go/operators/record.go +++ b/src/Backend/opti-sql-go/operators/record.go @@ -18,13 +18,13 @@ var ( type Operator 
interface { Next(uint16) (*RecordBatch, error) Schema() *arrow.Schema - // Call Operator.Close() after Next retruns an io.EOF to clean up resources + // Call Operator.Close() after Next returns an io.EOF to clean up resources Close() error } type RecordBatch struct { Schema *arrow.Schema Columns []arrow.Array - RowCount uint64 // TODO: update to actaully use this, in all operators + RowCount uint64 // TODO: update to actually use this, in all operators } type SchemaBuilder struct { diff --git a/src/Backend/opti-sql-go/operators/serialize_test.go b/src/Backend/opti-sql-go/operators/serialize_test.go index b8b3cac..f7d19e3 100644 --- a/src/Backend/opti-sql-go/operators/serialize_test.go +++ b/src/Backend/opti-sql-go/operators/serialize_test.go @@ -136,7 +136,7 @@ func TestSerializerInit(t *testing.T) { // TestSchemaOnlySerialization tests standalone schema serialization/deserialization func TestSchemaOnlySerialization(t *testing.T) { recordBatch := generateDummyRecordBatch1() - fmt.Printf("original schema before serialization: %v\n", recordBatch.Schema) + t.Logf("original schema before serialization: %v\n", recordBatch.Schema) ss, err := NewSerializer(recordBatch.Schema) if err != nil { @@ -148,7 +148,7 @@ func TestSchemaOnlySerialization(t *testing.T) { if err != nil { t.Fatalf("Schema serialization failed: %v", err) } - fmt.Printf("serialized schema bytes length: %d\n", len(serializedSchema)) + t.Logf("serialized schema bytes length: %d\n", len(serializedSchema)) // Deserialize schema deserializedSchema, err := ss.schemaFromDisk(bytes.NewBuffer(serializedSchema)) @@ -160,7 +160,7 @@ func TestSchemaOnlySerialization(t *testing.T) { if !deserializedSchema.Equal(recordBatch.Schema) { t.Fatal("Deserialized schema does not match the original schema") } - fmt.Printf("schema after serialization & deserialization: %v\n", deserializedSchema) + t.Logf("schema after serialization & deserialization: %v\n", deserializedSchema) // Validate field properties for i := 0; i < 
recordBatch.Schema.NumFields(); i++ { @@ -745,7 +745,7 @@ func TestSerializationWithDifferentTypes(t *testing.T) { func TestNullSchemaSerialize(t *testing.T) { rb := generateNullableRecordBatch() for i := range rb.Schema.Fields() { - fmt.Printf("is nullable? : %v\n", rb.Schema.Field(i).Nullable) + t.Logf("is nullable? : %v\n", rb.Schema.Field(i).Nullable) } serializer, err := NewSerializer(rb.Schema) if err != nil { From ee7d4a2a77090c889d72b7e4304a3ada905c926d Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Tue, 18 Nov 2025 00:45:35 -0500 Subject: [PATCH 10/19] define Expr Interface & Expr Operators --- .../opti-sql-go/operators/Expr/expr.go | 125 +++++++++++++++--- .../opti-sql-go/operators/Expr/info.go | 0 src/Backend/opti-sql-rs/Cargo.toml | 2 + src/Backend/opti-sql-rs/src/main.rs | 65 ++++++++- 4 files changed, 166 insertions(+), 26 deletions(-) create mode 100644 src/Backend/opti-sql-go/operators/Expr/info.go diff --git a/src/Backend/opti-sql-go/operators/Expr/expr.go b/src/Backend/opti-sql-go/operators/Expr/expr.go index 5334beb..5a91a0b 100644 --- a/src/Backend/opti-sql-go/operators/Expr/expr.go +++ b/src/Backend/opti-sql-go/operators/Expr/expr.go @@ -1,24 +1,105 @@ package Expr -// evaluate expressions -// for example Column + Literal -// Column - Column -// Literal / Literal - -//1. Arithmetic Expressions -// SELECT salary * 1.2, price + tax, -(discount) -//2.Alias Expressions -//SELECT name AS employee_name, age AS employee_age -//SELECT salary * 1.2 AS new_salary -//3.String Expressions -//first_name || ' ' || last_name -//UPPER(name) -//LOWER(email) -//SUBSTRING(name, 1, 3) -//4. Function calls -//ABS(x) -//ROUND(salary, 2) -//LENGTH(name) -//COALESCE(a, b) -//5. 
Constants -//SELECT 1, 'hello', 3.14 +import "github.com/apache/arrow/go/v17/arrow" + +type binaryOperator int + +const ( + // arithmetic + addition binaryOperator = 1 + subtraction binaryOperator = 2 + multiplication binaryOperator = 3 + division binaryOperator = 4 + modulous binaryOperator = 5 + // comparison + equal binaryOperator = 6 + notEqual binaryOperator = 7 + lessThan binaryOperator = 8 + lessThanOrEqual binaryOperator = 9 + greaterThan binaryOperator = 10 + greaterThanOrEqual binaryOperator = 11 + // logical + and binaryOperator = 12 + or binaryOperator = 13 + not binaryOperator = 14 +) + +type supportedFunctions int + +const ( + upper supportedFunctions = 1 + lower supportedFunctions = 2 + abs supportedFunctions = 3 + round supportedFunctions = 4 +) + +type aggFunctions = int + +const ( + Sum aggFunctions = 1 + Count aggFunctions = 2 + Avg aggFunctions = 3 + Min aggFunctions = 4 + Max aggFunctions = 5 +) + +/* +Eval(expr): + + match expr: + Literal(x) -> return x + Column(name) -> return array of that column + BinaryExpr(left > right) -> eval left, eval right, apply operator + ScalarFunction(upper(name)) -> evaluate function + Alias(expr, name) -> just a name wrapper +*/ +type Expr interface { + ExprNode() // empty method, only for the sake of polymophism +} + +/* +Alias | sql: select col1 as new_name from table_source +updates the column name in the output schema. +*/ +type Alias struct { + expr []Expr + columnName string + name string +} + +// return batch.Columns[fieldIndex["age"]] +// resolves the arrow array corresponding to name passed in +// sql: select age +type ColumnResolve struct { + name string +} + +// Evaluates to a column of length = batch-size, filled with this literal. 
+// sql: select 1 +type LiteralResolve struct { + Type arrow.DataType + value any +} + +type Operator struct { + o binaryOperator +} +type BinaryExpr struct { + left []Expr + op Operator + right []Expr +} + +type ScalarFunction struct { + function supportedFunctions + input []Expr // resolve to something you can procees IE, literal/coloumn Resolve +} +type AggregateFunction struct { + function aggFunctions + args []Expr +} + +type CastExpr struct { + expr []Expr // can be a Literal or Column (check for datatype then) + targetType arrow.DataType +} diff --git a/src/Backend/opti-sql-go/operators/Expr/info.go b/src/Backend/opti-sql-go/operators/Expr/info.go new file mode 100644 index 0000000..e69de29 diff --git a/src/Backend/opti-sql-rs/Cargo.toml b/src/Backend/opti-sql-rs/Cargo.toml index 5e0c609..ffbedf7 100644 --- a/src/Backend/opti-sql-rs/Cargo.toml +++ b/src/Backend/opti-sql-rs/Cargo.toml @@ -5,3 +5,5 @@ edition = "2021" [dependencies] datafusion = "50.3.0" +datafusion-substrait = "50.3.0" +tokio = { version = "1.28.2", features = ["full"] } diff --git a/src/Backend/opti-sql-rs/src/main.rs b/src/Backend/opti-sql-rs/src/main.rs index 31841df..aa02e30 100644 --- a/src/Backend/opti-sql-rs/src/main.rs +++ b/src/Backend/opti-sql-rs/src/main.rs @@ -1,8 +1,65 @@ //use datafusion::arrow::record_batch::RecordBatch; +use datafusion::{arrow::{array::RecordBatch, util::pretty::print_batches}, functions::crypto::basic, prelude::*}; +use datafusion_substrait::*; mod project; #[allow(dead_code)] -fn main() { - println!("Hello, world!"); - project::project_exec::project_execute(); - project::source::csv::read_csv(); +#[tokio::main] +async fn main() { + let mut ctx = SessionContext::new(); + + ctx.register_csv( + "example", + "example.csv", + CsvReadOptions::new() + ).await.unwrap(); + // basic projections + //basic_project("Basic column projection",&mut ctx,"select name,salary from example").await; + //basic_project("reorder and duplicate projection",&mut ctx,"select 
salary,name,salary as s1 from example").await; // this is an error, expression names must be unique/ must use alias to get around error + // Literal projections + //basic_project("select int",&mut ctx,"select 1").await; + //basic_project("select string",&mut ctx,"select 'hello'").await; + //basic_project("reorder float",&mut ctx,"select 3.14").await; + // Literal + column + //basic_project("column plus literal",&mut ctx,"select salary + 10 from example").await; + //basic_project("column times literal",&mut ctx,"select salary * 2.4 from example").await; + // fully literal + //basic_project("two literals",&mut ctx,"select 5 + 10").await; + // column by column + // basic_project("column plus column",&mut ctx,"select salary + age from example").await; + //basic_project("Column-literal + nested arithmetic",&mut ctx,"SELECT (salary - age) * 1.08 FROM example;").await; + // alias operators + //basic_project("alias_1",&mut ctx,"SELECT age + 5 AS new_age FROM example").await; + //basic_project("alias_2 constant",&mut ctx,"SELECT 1 as greatest_number").await; + //String Expressions + //basic_project("select upper()",&mut ctx,"SELECT upper('richard')").await; + //basic_project("select lower()",&mut ctx,"SELECT lower(name)").await; + //basic_project("select substring()",&mut ctx,"SELECT substring(name,1,3)").await; + // mixed expressions + //basic_project("mixed expressions",&mut ctx,"SELECT upper(name) AS upper_name, salary * 1.1 AS increased_salary FROM example").await; + // function calls + let (l,r) = basic_project("function call with Abs()",&mut ctx,"SELECT ABS(age) FROM example").await; + //println!("Logical Plan:\n{}", l); + print_batches(&r).unwrap();ß + //basic_project("function call Round()",&mut ctx,"SELECT Round(age) FROM example").await; + //basic_project("function call Length()",&mut ctx,"SELECT LENGTH(name) FROM example").await; + +} + +pub async fn basic_project(name : &str,ctx : &mut SessionContext,sql : &str) -> (String,Vec) { + println!("Running project: 
{}",name); + let df1 = ctx.sql(sql) + .await + .unwrap(); + + + let logical_plan = df1.logical_plan().clone(); + let substrait_plan = logical_plan::producer::to_substrait_plan(&logical_plan, &ctx.state()).unwrap(); + print!("Substrait Plan :\n{:?}", substrait_plan); + + + let display = format!("{}",logical_plan.display_indent()); + + // Running will create the physical plan automatically + return (display,df1.collect().await.unwrap()); + } From 0ee601a863cfc926eb07a57eb03186dce70de9bc Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Wed, 19 Nov 2025 15:31:24 -0500 Subject: [PATCH 11/19] feat:defined and intergrated Expression interface to query execution enginee --- src/Backend/opti-sql-go/Expr/expr.go | 601 +++++++++ src/Backend/opti-sql-go/Expr/expr_test.go | 1194 +++++++++++++++++ src/Backend/opti-sql-go/Expr/info.go | 25 + .../opti-sql-go/operators/Expr/expr.go | 105 -- .../opti-sql-go/operators/Expr/expr_test.go | 7 - .../opti-sql-go/operators/Expr/info.go | 0 .../opti-sql-go/operators/project/csv.go | 2 - .../opti-sql-go/operators/project/custom.go | 33 + .../operators/project/projectExec.go | 100 +- .../operators/project/projectExecExpr_test.go | 843 ++++++++++++ .../operators/project/projectExec_test.go | 439 +++--- 11 files changed, 2968 insertions(+), 381 deletions(-) create mode 100644 src/Backend/opti-sql-go/Expr/expr.go create mode 100644 src/Backend/opti-sql-go/Expr/expr_test.go create mode 100644 src/Backend/opti-sql-go/Expr/info.go delete mode 100644 src/Backend/opti-sql-go/operators/Expr/expr.go delete mode 100644 src/Backend/opti-sql-go/operators/Expr/expr_test.go delete mode 100644 src/Backend/opti-sql-go/operators/Expr/info.go create mode 100644 src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go diff --git a/src/Backend/opti-sql-go/Expr/expr.go b/src/Backend/opti-sql-go/Expr/expr.go new file mode 100644 index 0000000..e278093 --- /dev/null +++ b/src/Backend/opti-sql-go/Expr/expr.go @@ -0,0 +1,601 @@ +package Expr + +import ( + 
"context" + "fmt" + "opti-sql-go/operators" + "strings" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/compute" + "github.com/apache/arrow/go/v17/arrow/memory" +) + +var ( + ErrUnsupportedExpression = func(info string) error { + return fmt.Errorf("unsupported expression passed to EvalExpression: %s", info) + } +) + +type binaryOperator int + +const ( + // arithmetic + Addition binaryOperator = 1 + Subtraction binaryOperator = 2 + Multiplication binaryOperator = 3 + Division binaryOperator = 4 + // comparison + Equal binaryOperator = 6 + NotEqual binaryOperator = 7 + LessThan binaryOperator = 8 + LessThanOrEqual binaryOperator = 9 + GreaterThan binaryOperator = 10 + GreaterThanOrEqual binaryOperator = 11 + // logical + And binaryOperator = 12 + Or binaryOperator = 13 + Not binaryOperator = 14 +) + +type supportedFunctions int + +const ( + Upper supportedFunctions = 1 + Lower supportedFunctions = 2 + Abs supportedFunctions = 3 + Round supportedFunctions = 4 +) + +type aggFunctions = int + +const ( + Sum aggFunctions = 1 + Count aggFunctions = 2 + Avg aggFunctions = 3 + Min aggFunctions = 4 + Max aggFunctions = 5 +) + +var ( + _ = (Expression)(&Alias{}) + _ = (Expression)(&ColumnResolve{}) + _ = (Expression)(&LiteralResolve{}) + _ = (Expression)(&BinaryExpr{}) + _ = (Expression)(&ScalarFunction{}) + _ = (Expression)(&CastExpr{}) +) + +// TODO: create nice wrapper functions for creating expressions +/* +Eval(expr): + + match expr: + Literal(x) -> return x + Column(name) -> return array of that column + BinaryExpr(left > right) -> eval left, eval right, apply operator + ScalarFunction(upper(name)) -> evaluate function + Alias(expr, name) -> just a name wrapper +*/ +type Expression interface { + //ExprNode(expr Expr, batch *operators.RecordBatch) (arrow.Array, error) + // empty method, only for the sake of polymorphism + ExprNode() + fmt.Stringer +} + +func EvalExpression(expr Expression, 
batch *operators.RecordBatch) (arrow.Array, error) { + switch e := expr.(type) { + case *Alias: + return EvalAlias(e, batch) + case *ColumnResolve: + return EvalColumn(e, batch) + case *LiteralResolve: + return EvalLiteral(e, batch) + case *BinaryExpr: + return EvalBinary(e, batch) + case *ScalarFunction: + return EvalScalarFunction(e, batch) + case *CastExpr: + return EvalCast(e, batch) + default: + return nil, ErrUnsupportedExpression(expr.String()) + } +} + +func ExprDataType(e Expression, inputSchema *arrow.Schema) arrow.DataType { + switch ex := e.(type) { + + case *LiteralResolve: + return ex.Type + + case *ColumnResolve: + idx := inputSchema.FieldIndices(ex.Name) + if len(idx) == 0 { + panic(fmt.Sprintf("exprDataType: unknown column %q", ex.Name)) + } + return inputSchema.Field(idx[0]).Type + case *Alias: + // alias does NOT change type + return ExprDataType(ex.Expr, inputSchema) + + case *CastExpr: + return ex.TargetType + + case *BinaryExpr: + leftType := ExprDataType(ex.Left, inputSchema) + rightType := ExprDataType(ex.Right, inputSchema) + return inferBinaryType(leftType, ex.Op, rightType) + + case *ScalarFunction: + argType := ExprDataType(ex.Arguments, inputSchema) + return inferScalarFunctionType(ex.Function, argType) + + default: + panic(fmt.Sprintf("unsupported expr type %T", ex)) + } +} +func NewExpressions(exprs ...Expression) []Expression { + return exprs +} + +/* +Alias | sql: select col1 as new_name from table_source +updates the column name in the output schema. 
+*/ +type Alias struct { + Expr Expression + Name string +} + +func NewAlias(expr Expression, name string) *Alias { + return &Alias{ + Expr: expr, + Name: name, + } +} + +func EvalAlias(a *Alias, batch *operators.RecordBatch) (arrow.Array, error) { + return EvalExpression(a.Expr, batch) +} +func (a *Alias) ExprNode() {} +func (a *Alias) String() string { + return fmt.Sprintf("Alias(%s AS %s)", a.Expr, a.Name) + +} + +// resolves the arrow array corresponding to name passed in +// sql: select age +type ColumnResolve struct { + Name string +} + +func NewColumnResolve(name string) *ColumnResolve { + return &ColumnResolve{Name: name} +} + +func EvalColumn(c *ColumnResolve, batch *operators.RecordBatch) (arrow.Array, error) { + // schema and columns are always aligned + for i, f := range batch.Schema.Fields() { + if f.Name == c.Name { + col := batch.Columns[i] + col.Retain() + return col, nil + } + } + return nil, fmt.Errorf("column %s not found", c.Name) +} +func (c *ColumnResolve) ExprNode() {} +func (c *ColumnResolve) String() string { + return fmt.Sprintf("Column(%s)", c.Name) +} + +// Evaluates to a column of length = batch-size, filled with this literal. +// sql: select 1 +type LiteralResolve struct { + Type arrow.DataType + // dont forget to cast the value. 
so string("hello") not just "hello" + Value any +} + +func NewLiteralResolve(Type arrow.DataType, Value any) *LiteralResolve { + return &LiteralResolve{Type: Type, Value: Value} +} +func EvalLiteral(l *LiteralResolve, batch *operators.RecordBatch) (arrow.Array, error) { + n := int(batch.RowCount) + + switch l.Type.ID() { + + // ------------------------------ + // BOOL + // ------------------------------ + case arrow.BOOL: + val := l.Value.(bool) + b := array.NewBooleanBuilder(memory.DefaultAllocator) + defer b.Release() + + for i := 0; i < n; i++ { + b.Append(val) + } + return b.NewArray(), nil + + // ------------------------------ + // INT / UINT (8/16/32/64) + // ------------------------------ + case arrow.INT8: + v := int8(l.Value.(int8)) + + b := array.NewInt8Builder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < n; i++ { + b.Append(v) + } + return b.NewArray(), nil + + case arrow.UINT8: + v := l.Value.(uint8) + + b := array.NewUint8Builder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < n; i++ { + b.Append(v) + } + return b.NewArray(), nil + + case arrow.INT16: + v := int16(l.Value.(int16)) + b := array.NewInt16Builder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < n; i++ { + b.Append(v) + } + return b.NewArray(), nil + + case arrow.UINT16: + v := uint16(l.Value.(uint16)) + b := array.NewUint16Builder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < n; i++ { + b.Append(v) + } + return b.NewArray(), nil + + case arrow.INT32: + v := int32(l.Value.(int32)) + b := array.NewInt32Builder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < n; i++ { + b.Append(v) + } + return b.NewArray(), nil + + case arrow.UINT32: + v := uint32(l.Value.(uint32)) + b := array.NewUint32Builder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < n; i++ { + b.Append(v) + } + return b.NewArray(), nil + // correct jump + case arrow.INT64: + v := int64(l.Value.(int64)) + b := 
array.NewInt64Builder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < n; i++ { + b.Append(v) + } + return b.NewArray(), nil + + case arrow.UINT64: + v := uint64(l.Value.(uint64)) + b := array.NewUint64Builder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < n; i++ { + b.Append(v) + } + return b.NewArray(), nil + + // ------------------------------ + // FLOATS + // ------------------------------ + case arrow.FLOAT32: + v := float32(l.Value.(float32)) + b := array.NewFloat32Builder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < n; i++ { + b.Append(v) + } + return b.NewArray(), nil + + case arrow.FLOAT64: + v := float64(l.Value.(float64)) + b := array.NewFloat64Builder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < n; i++ { + b.Append(v) + } + return b.NewArray(), nil + + // ------------------------------ + // STRING + // ------------------------------ + case arrow.STRING: + v := l.Value.(string) + b := array.NewStringBuilder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < n; i++ { + b.Append(v) + } + return b.NewArray(), nil + + // ------------------------------ + // BINARY + // ------------------------------ + case arrow.BINARY: + v := l.Value.([]byte) + b := array.NewBinaryBuilder(memory.DefaultAllocator, arrow.BinaryTypes.Binary) + defer b.Release() + for i := 0; i < n; i++ { + b.Append(v) + } + return b.NewArray(), nil + + default: + return nil, fmt.Errorf("literal type %s not supported", l.Type) + } +} + +func (l *LiteralResolve) ExprNode() {} +func (l *LiteralResolve) String() string { + return fmt.Sprintf("Literal(%v)", l.Value) +} + +type BinaryExpr struct { + Left Expression + Op binaryOperator + Right Expression +} + +func NewBinaryExpr(left Expression, op binaryOperator, right Expression) *BinaryExpr { + return &BinaryExpr{ + Left: left, + Op: op, + Right: right, + } +} + +func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error) { + leftArr, err := 
EvalExpression(b.Left, batch) + if err != nil { + return nil, err + } + rightArr, err := EvalExpression(b.Right, batch) + if err != nil { + return nil, err + } + opt := compute.ArithmeticOptions{} + switch b.Op { + // arithmetic + case Addition: + datum, err := compute.Add(context.TODO(), opt, compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) + case Subtraction: + datum, err := compute.Subtract(context.TODO(), opt, compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) + + case Multiplication: + datum, err := compute.Multiply(context.TODO(), opt, compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) + case Division: + datum, err := compute.Divide(context.TODO(), opt, compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) + + // comparisions TODO: + case Equal: + case NotEqual: + case LessThan: + case LessThanOrEqual: + case GreaterThan: + case GreaterThanOrEqual: + // logical + case And: + case Or: + case Not: + } + return nil, fmt.Errorf("binary operator %d not supported", b.Op) +} +func (b *BinaryExpr) ExprNode() {} +func (b *BinaryExpr) String() string { + return fmt.Sprintf("BinaryExpr(%s %d %s)", b.Left, b.Op, b.Right) +} +func unpackDatum(d compute.Datum) (arrow.Array, error) { + array, ok := d.(*compute.ArrayDatum) + if !ok { + return nil, fmt.Errorf("datum %v is not of type array", d) + } + return array.MakeArray(), nil +} +func inferBinaryType(left arrow.DataType, op binaryOperator, right arrow.DataType) arrow.DataType { + switch op { + + case Addition, Subtraction, Multiplication, Division: + // numeric → numeric promotion rules + return numericPromotion(left, right) + + case Equal, NotEqual, LessThan, LessThanOrEqual, GreaterThan, GreaterThanOrEqual: + return 
arrow.FixedWidthTypes.Boolean + + case And, Or: + return arrow.FixedWidthTypes.Boolean + + default: + panic(fmt.Sprintf("inferBinaryType: unsupported operator %v", op)) + } +} +func numericPromotion(a, b arrow.DataType) arrow.DataType { + // simplest version: return float64 for any mixed numeric types. + // expand later when needed. + return arrow.PrimitiveTypes.Float64 +} + +type ScalarFunction struct { + Function supportedFunctions + Arguments Expression // resolve to something you can process IE, literal/coloumn Resolve +} + +func NewScalarFunction(function supportedFunctions, Argument Expression) *ScalarFunction { + return &ScalarFunction{ + Function: function, + Arguments: Argument, + } +} + +func EvalScalarFunction(s *ScalarFunction, batch *operators.RecordBatch) (arrow.Array, error) { + switch s.Function { + case Upper: + arr, err := EvalExpression(s.Arguments, batch) + if err != nil { + return nil, err + } + return upperImpl(arr) + + case Lower: + arr, err := EvalExpression(s.Arguments, batch) + if err != nil { + return nil, err + } + return lowerImpl(arr) + case Abs: + arr, err := EvalExpression(s.Arguments, batch) + if err != nil { + return nil, err + } + datum, err := compute.AbsoluteValue(context.TODO(), compute.ArithmeticOptions{}, compute.NewDatum(arr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) + case Round: + arr, err := EvalExpression(s.Arguments, batch) + if err != nil { + return nil, err + } + datum, err := compute.Round(context.TODO(), compute.DefaultRoundOptions, compute.NewDatum(arr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) + } + return nil, fmt.Errorf("unsupported scalar function %v", s.Function) +} +func (s *ScalarFunction) ExprNode() {} +func (s *ScalarFunction) String() string { + return fmt.Sprintf("ScalarFunction(%d, %v)", s.Function, s.Arguments) +} +func upperImpl(arr arrow.Array) (arrow.Array, error) { + strArr, ok := arr.(*array.String) + if !ok { + return nil, fmt.Errorf("upper 
function only supports string arrays, got %s", arr.DataType()) + } + b := array.NewStringBuilder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < strArr.Len(); i++ { + if strArr.IsNull(i) { + b.AppendNull() + } else { + b.Append(strings.ToUpper(strArr.Value(i))) + } + } + return b.NewArray(), nil +} +func lowerImpl(arr arrow.Array) (arrow.Array, error) { + { + strArr, ok := arr.(*array.String) + if !ok { + return nil, fmt.Errorf("lower function only supports string arrays, got %s", arr.DataType()) + } + b := array.NewStringBuilder(memory.DefaultAllocator) + defer b.Release() + for i := 0; i < strArr.Len(); i++ { + if strArr.IsNull(i) { + b.AppendNull() + } else { + b.Append(strings.ToLower(strArr.Value(i))) + } + } + return b.NewArray(), nil + } +} +func inferScalarFunctionType(fn supportedFunctions, argType arrow.DataType) arrow.DataType { + switch fn { + + case Upper, Lower: + if argType.ID() != arrow.STRING { + panic("upper/lower only support string types") + } + return arrow.BinaryTypes.String + + case Abs, Round: + return argType // numeric-in numeric-out + + default: + panic(fmt.Sprintf("unknown scalar function %v", fn)) + } +} + +// not a priority at all +type AggregateFunction struct { + Function aggFunctions + Args Expression +} + +// If cast succeeds → return the casted value +// If cast fails → throw a runtime error +type CastExpr struct { + Expr Expression // can be a Literal or Column (check for datatype when you resolve) + TargetType arrow.DataType +} + +func NewCastExpr(expr Expression, targetType arrow.DataType) *CastExpr { + return &CastExpr{ + Expr: expr, + TargetType: targetType, + } +} + +func EvalCast(c *CastExpr, batch *operators.RecordBatch) (arrow.Array, error) { + arr, err := EvalExpression(c.Expr, batch) + if err != nil { + return nil, err + } + + // Use Arrow compute kernel to cast + castOpts := compute.SafeCastOptions(c.TargetType) + out, err := compute.CastArray(context.TODO(), arr, castOpts) + if err != nil { + // This is 
a runtime cast error + return nil, fmt.Errorf("cast error: cannot cast %s to %s: %w", + arr.DataType(), c.TargetType, err) + } + + return out, nil +} + +func (c *CastExpr) ExprNode() {} +func (c *CastExpr) String() string { + return fmt.Sprintf("Cast(%s AS %s)", c.Expr, c.TargetType) +} diff --git a/src/Backend/opti-sql-go/Expr/expr_test.go b/src/Backend/opti-sql-go/Expr/expr_test.go new file mode 100644 index 0000000..6fe9115 --- /dev/null +++ b/src/Backend/opti-sql-go/Expr/expr_test.go @@ -0,0 +1,1194 @@ +package Expr + +import ( + "log" + "opti-sql-go/operators" + "testing" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/compute" + "github.com/apache/arrow/go/v17/arrow/memory" +) + +// (removed helper that constructed a record builder - not needed in Expr tests) + +// Do not change this +// generateTestColumns returns a RecordBatch containing the first 4 rows of the +// commonly used test data. Returning a RecordBatch allows Expr tests to avoid +// depending on the `project` package (no import cycle). 
+func generateTestColumns() *operators.RecordBatch { + mem := memory.DefaultAllocator + + // id + idB := array.NewInt32Builder(mem) + defer idB.Release() + idB.AppendValues([]int32{1, 2, 3, 4}, nil) + idArr := idB.NewArray() + + // name + nameB := array.NewStringBuilder(mem) + defer nameB.Release() + nameB.AppendValues([]string{"Alice", "Bob", "Charlie", "David"}, nil) + nameArr := nameB.NewArray() + + // age + ageB := array.NewInt32Builder(mem) + defer ageB.Release() + ageB.AppendValues([]int32{28, 34, 45, 22}, nil) + ageArr := ageB.NewArray() + + // salary + salB := array.NewFloat64Builder(mem) + defer salB.Release() + salB.AppendValues([]float64{70000.0, 82000.5, 54000.0, 91000.0}, nil) + salArr := salB.NewArray() + + // is_active + actB := array.NewBooleanBuilder(mem) + defer actB.Release() + actB.AppendValues([]bool{true, false, true, true}, nil) + actArr := actB.NewArray() + + fields := []arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "name", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "age", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "salary", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + {Name: "is_active", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, + } + schema := arrow.NewSchema(fields, nil) + + cols := []arrow.Array{idArr, nameArr, ageArr, salArr, actArr} + return &operators.RecordBatch{Schema: schema, Columns: cols, RowCount: 4} +} +func GenerateColumsNull() arrow.Array { + b := array.NewStringBuilder(memory.DefaultAllocator) + defer b.Release() + b.AppendNull() + b.Append("first-value") + b.AppendNull() + b.Append("second-value") + + return b.NewArray() +} + +func TestAliasExpr(t *testing.T) { + rc := generateTestColumns() // 4 + // alias should return the underlying expression with a new name + // the name swap out occures in project so for now just validate evaluation is as expected + t.Run("Alias on Name", func(t *testing.T) { + a := Alias{ + Expr: 
&ColumnResolve{Name: "name"}, + Name: "employee_name", + } + _ = a.String() + arr, err := EvalExpression(&a, rc) + if err != nil { + t.Fatalf("failed to evaluate alias expression: %v", err) + } + expectedArr := []string{"Alice", "Bob", "Charlie", "David"} + if len(expectedArr) != arr.Len() { + t.Fatalf("expected length %d, got %d", len(expectedArr), arr.Len()) + } + for i := 0; i < arr.Len(); i++ { + expected := expectedArr[i] + actual := arr.(*array.String).Value(i) + if expected != actual { + t.Fatalf("at index %d, expected %s, got %s", i, expected, actual) + } + } + + }) + t.Run("Alias on Salary", func(t *testing.T) { + a := Alias{ + Expr: &ColumnResolve{Name: "salary"}, + Name: "employee_salary", + } + _ = a.String() + arr, err := EvalExpression(&a, rc) + if err != nil { + t.Fatalf("failed to evaluate alias expression: %v", err) + } + expectedArr := []float64{70000.0, 82000.5, 54000.0, 91000.0} + if len(expectedArr) != arr.Len() { + t.Fatalf("expected length %d, got %d", len(expectedArr), arr.Len()) + } + for i := 0; i < arr.Len(); i++ { + expected := expectedArr[i] + actual := arr.(*array.Float64).Value(i) + if expected != actual { + t.Fatalf("at index %d, expected %f, got %f", i, expected, actual) + } + } + + }) + t.Run("Alias on is_active", func(t *testing.T) { + a := Alias{ + Expr: &ColumnResolve{Name: "is_active"}, + Name: "active_status", + } + _ = a.String() + arr, err := EvalExpression(&a, rc) + if err != nil { + t.Fatalf("failed to evaluate alias expression: %v", err) + } + expectedArr := []bool{true, false, true, true} + if len(expectedArr) != arr.Len() { + t.Fatalf("expected length %d, got %d", len(expectedArr), arr.Len()) + } + for i := 0; i < arr.Len(); i++ { + expected := expectedArr[i] + actual := arr.(*array.Boolean).Value(i) + if expected != actual { + t.Fatalf("at index %d, expected %v, got %v", i, expected, actual) + } + } + }) + // interface validation + a := Alias{Name: "New_Name"} + a.ExprNode() + t.Logf("%s", a.String()) +} + +func 
TestColumnResolve(t *testing.T) { + rc := generateTestColumns() // + t.Run("ColumnResolve on age", func(t *testing.T) { + cr := ColumnResolve{Name: "age"} + arr, err := EvalExpression(&cr, rc) + if err != nil { + t.Fatalf("failed to evaluate column resolve expression: %v", err) + } + expectedArr := []int32{28, 34, 45, 22} + if len(expectedArr) != arr.Len() { + t.Fatalf("expected length %d, got %d", len(expectedArr), arr.Len()) + } + for i := 0; i < arr.Len(); i++ { + expected := expectedArr[i] + actual := arr.(*array.Int32).Value(i) + if expected != actual { + t.Fatalf("at index %d, expected %d, got %d", i, expected, actual) + } + } + }) + t.Run("ColumnResolve on ID", func(t *testing.T) { + cr := ColumnResolve{Name: "id"} + arr, err := EvalExpression(&cr, rc) + if err != nil { + t.Fatalf("failed to evaluate column resolve expression: %v", err) + } + expectedArr := []int32{1, 2, 3, 4} + if len(expectedArr) != arr.Len() { + t.Fatalf("expected length %d, got %d", len(expectedArr), arr.Len()) + } + for i := 0; i < arr.Len(); i++ { + expected := expectedArr[i] + actual := arr.(*array.Int32).Value(i) + if expected != actual { + t.Fatalf("at index %d, expected %d, got %d", i, expected, actual) + } + } + }) + t.Run("ColumnResolve on non-existant column", func(t *testing.T) { + cr := ColumnResolve{Name: "doesnt Exist"} + _, err := EvalExpression(&cr, rc) + if err == nil { + t.Fatalf("expected error for non existant column") + } + }) + // interface Validation + cr := ColumnResolve{Name: "--"} + cr.ExprNode() + t.Logf("%s\n", cr.String()) +} + +func TestLiteralResolve(t *testing.T) { + t.Run("EvalLiteral", func(t *testing.T) { + rc := generateTestColumns() // + + // ------------------------- + // BOOLEAN + // ------------------------- + t.Run("BOOL", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.FixedWidthTypes.Boolean, + Value: true, + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + b := arr.(*array.Boolean) 
+ for i := 0; i < arr.Len(); i++ { + if !b.Value(i) { + t.Fatalf("expected true at index %d", i) + } + } + }) + // ------------------------- + // Int8 + // ------------------------- + t.Run("INT8", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.PrimitiveTypes.Int8, + Value: int8(-5), + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out := arr.(*array.Int8) + if out.Len() != 4 { + t.Fatalf("expected 4, got %d", out.Len()) + } + for i := 0; i < out.Len(); i++ { + if out.Value(i) != -5 { + t.Fatalf("expected -5 at %d, got %d", i, out.Value(i)) + } + } + }) + + // ------------------------- + // Uint8 + // ------------------------- + t.Run("UINT8", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.PrimitiveTypes.Uint8, + Value: uint8(7), + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out := arr.(*array.Uint8) + if out.Len() != 4 { + t.Fatalf("expected 4, got %d", out.Len()) + } + for i := 0; i < out.Len(); i++ { + if out.Value(i) != 7 { + t.Fatalf("expected 7 at %d, got %d", i, out.Value(i)) + } + } + }) + // ------------------------- + // int16 + // ------------------------- + + t.Run("INT16", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.PrimitiveTypes.Int16, + Value: int16(1234), + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out := arr.(*array.Int16) + if out.Len() != 4 { + t.Fatalf("expected 4, got %d", out.Len()) + } + for i := 0; i < out.Len(); i++ { + if out.Value(i) != 1234 { + t.Fatalf("expected 1234 at %d, got %d", i, out.Value(i)) + } + } + }) + // ------------------------- + // Uint16 + // ------------------------- + t.Run("UINT16", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.PrimitiveTypes.Uint16, + Value: uint16(60000), + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out := 
arr.(*array.Uint16) + if out.Len() != 4 { + t.Fatalf("expected 4, got %d", out.Len()) + } + for i := 0; i < out.Len(); i++ { + if out.Value(i) != 60000 { + t.Fatalf("expected 60000 at %d, got %d", i, out.Value(i)) + } + } + }) + + // ------------------------- + // INT32 + // ------------------------- + t.Run("INT32", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.PrimitiveTypes.Int32, + Value: int32(99), + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + col := arr.(*array.Int32) + for i := 0; i < arr.Len(); i++ { + if col.Value(i) != 99 { + t.Fatalf("expected 99, got %d", col.Value(i)) + } + } + }) + // ------------------------- + // UINT32 + // ------------------------- + t.Run("UINT32", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.PrimitiveTypes.Uint32, + Value: uint32(4000000000), + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out := arr.(*array.Uint32) + for i := 0; i < out.Len(); i++ { + if out.Value(i) != 4000000000 { + t.Fatalf("expected 4000000000 at %d, got %d", i, out.Value(i)) + } + } + }) + + // ------------------------- + // INT64 + // ------------------------- + t.Run("INT64", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.PrimitiveTypes.Int64, + Value: int64(123456), + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + col := arr.(*array.Int64) + for i := 0; i < arr.Len(); i++ { + if col.Value(i) != 123456 { + t.Fatalf("expected 123456, got %d", col.Value(i)) + } + } + }) + // ------------------------- + // UINT64 + // ------------------------- + t.Run("UINT64", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.PrimitiveTypes.Uint64, + Value: uint64(9999999999), + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out := arr.(*array.Uint64) + for i := 0; i < out.Len(); i++ { + if 
out.Value(i) != 9999999999 { + t.Fatalf("expected 9999999999 at %d, got %d", i, out.Value(i)) + } + } + }) + // ------------------------- + // FLOAT32 + // ------------------------- + + t.Run("FLOAT32", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.PrimitiveTypes.Float32, + Value: float32(3.14), + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out := arr.(*array.Float32) + for i := 0; i < out.Len(); i++ { + if out.Value(i) != float32(3.14) { + t.Fatalf("expected 3.14 at %d, got %f", i, out.Value(i)) + } + } + }) + + // ------------------------- + // FLOAT64 + // ------------------------- + t.Run("FLOAT64", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.PrimitiveTypes.Float64, + Value: float64(3.14), + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + col := arr.(*array.Float64) + for i := 0; i < arr.Len(); i++ { + if col.Value(i) != 3.14 { + t.Fatalf("expected 3.14, got %f", col.Value(i)) + } + } + }) + + // ------------------------- + // STRING + // ------------------------- + t.Run("STRING", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.BinaryTypes.String, + Value: "hello", + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + col := arr.(*array.String) + for i := 0; i < arr.Len(); i++ { + if col.Value(i) != "hello" { + t.Fatalf("expected 'hello', got '%s'", col.Value(i)) + } + } + }) + + // ------------------------- + // BINARY + // ------------------------- + t.Run("BINARY", func(t *testing.T) { + bval := []byte{1, 2, 3} + lit := &LiteralResolve{ + Type: arrow.BinaryTypes.Binary, + Value: bval, + } + arr, err := EvalLiteral(lit, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + col := arr.(*array.Binary) + for i := 0; i < arr.Len(); i++ { + v := col.Value(i) + if string(v) != string(bval) { + t.Fatalf("expected %v, got %v", bval, v) + } + } + }) + + 
// ------------------------- + // ERROR CASE + // ------------------------- + t.Run("UnsupportedType", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.FixedWidthTypes.Duration_s, // something you did NOT implement + Value: int64(10), + } + _, err := EvalLiteral(lit, rc) + if err == nil { + t.Fatalf("expected error for unsupported type, got nil") + } + }) + + // ------------------------- + // Validate .String() and .ExprNode() + // ------------------------- + t.Run("Interface methods", func(t *testing.T) { + lit := &LiteralResolve{ + Type: arrow.PrimitiveTypes.Int32, + Value: int32(1), + } + + // Just call, like your other tests + t.Logf("%s\n", lit.String()) + lit.ExprNode() + }) + }) + +} + +func TestBinaryExpr(t *testing.T) { + t.Run("BinaryExpr Arithmetic", func(t *testing.T) { + rc := generateTestColumns() //4 + + makeLit := func(v int32) Expression { + return &LiteralResolve{ + Type: arrow.PrimitiveTypes.Int32, + Value: v, + } + } + + t.Run("Addition", func(t *testing.T) { + b := &BinaryExpr{ + Left: makeLit(10), + Op: Addition, + Right: makeLit(5), + } + + arr, err := EvalBinary(b, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := arr.(*array.Int32) + for i := 0; i < out.Len(); i++ { + got := out.Value(i) + if got != 15 { + t.Fatalf("expected 15, got %d (index %d)", got, i) + } + } + }) + + t.Run("Subtraction", func(t *testing.T) { + b := &BinaryExpr{ + Left: makeLit(20), + Op: Subtraction, + Right: makeLit(3), + } + + arr, err := EvalExpression(b, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := arr.(*array.Int32) + for i := 0; i < out.Len(); i++ { + if out.Value(i) != 17 { + t.Fatalf("expected 17, got %d", out.Value(i)) + } + } + }) + + t.Run("Multiplication", func(t *testing.T) { + b := &BinaryExpr{ + Left: makeLit(7), + Op: Multiplication, + Right: makeLit(6), + } + + arr, err := EvalExpression(b, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := 
arr.(*array.Int32) + for i := 0; i < out.Len(); i++ { + if out.Value(i) != 42 { + t.Fatalf("expected 42, got %d", out.Value(i)) + } + } + }) + + t.Run("Division", func(t *testing.T) { + b := &BinaryExpr{ + Left: makeLit(20), + Op: Division, + Right: makeLit(4), + } + + arr, err := EvalExpression(b, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := arr.(*array.Int32) + for i := 0; i < out.Len(); i++ { + if out.Value(i) != 5 { + t.Fatalf("expected 5, got %d", out.Value(i)) + } + } + }) + + t.Run("UnsupportedOperator", func(t *testing.T) { + b := &BinaryExpr{ + Left: makeLit(1), + Op: 9999, // unsupported + Right: makeLit(1), + } + + _, err := EvalExpression(b, rc) + if err == nil { + t.Fatalf("expected error for unsupported operator, got nil") + } + }) + t.Run("Invalid datum", func(t *testing.T) { + datum := compute.NewDatum(4) + t.Logf("datum:%v\n", datum) + _, err := unpackDatum(datum) + if err == nil { + t.Fatalf("expected error for invalid datum, got nil") + } + }) + + // -- interface / string validation --------------------------------------- + + be := &BinaryExpr{ + Left: makeLit(1), + Op: Addition, + Right: makeLit(2), + } + log.Printf("%s", be.String()) + be.ExprNode() + + t.Logf("BinaryExpr String(): %s\n", be.String()) + }) + +} + +func TestScalarFunctions(t *testing.T) { + t.Run("ScalarFunction", func(t *testing.T) { + + rc := generateTestColumns() //4 + + // Utility literal helper for ints + makeInt := func(v int32) Expression { + return &LiteralResolve{ + Type: arrow.PrimitiveTypes.Int32, + Value: v, + } + } + + // Utility literal helper for strings + makeStr := func(v string) Expression { + return &LiteralResolve{ + Type: arrow.BinaryTypes.String, + Value: v, + } + } + + // ------------------------- + // UPPER + // ------------------------- + t.Run("Upper", func(t *testing.T) { + sf := &ScalarFunction{ + Function: Upper, + Arguments: makeStr("hello"), + } + + arr, err := EvalExpression(sf, rc) + if err != nil { + 
t.Fatalf("unexpected err: %v", err) + } + + out := arr.(*array.String) + for i := 0; i < out.Len(); i++ { + if out.Value(i) != "HELLO" { + t.Fatalf("expected HELLO, got %s", out.Value(i)) + } + } + }) + + // ------------------------- + // LOWER + // ------------------------- + t.Run("Lower", func(t *testing.T) { + sf := &ScalarFunction{ + Function: Lower, + Arguments: makeStr("HeLLo"), + } + + arr, err := EvalExpression(sf, rc) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + out := arr.(*array.String) + for i := 0; i < out.Len(); i++ { + if out.Value(i) != "hello" { + t.Fatalf("expected hello, got %s", out.Value(i)) + } + } + }) + + // ------------------------- + // ABS + // ------------------------- + t.Run("Abs", func(t *testing.T) { + sf := &ScalarFunction{ + Function: Abs, + Arguments: makeInt(-9), + } + + arr, err := EvalExpression(sf, rc) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + out := arr.(*array.Int32) + for i := 0; i < out.Len(); i++ { + if out.Value(i) != 9 { + t.Fatalf("expected 9, got %d", out.Value(i)) + } + } + }) + + // ------------------------- + // ROUND + // ------------------------- + t.Run("Round", func(t *testing.T) { + sf := &ScalarFunction{ + Function: Round, + Arguments: &LiteralResolve{ + Type: arrow.PrimitiveTypes.Float64, + Value: 3.67, + }, + } + + arr, err := EvalExpression(sf, rc) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + out := arr.(*array.Float64) + for i := 0; i < out.Len(); i++ { + if out.Value(i) != 4 { + t.Fatalf("expected 4, got %f", out.Value(i)) + } + } + }) + + // ------------------------- + // TYPE ERROR for UPPER + // ------------------------- + t.Run("UpperTypeError", func(t *testing.T) { + sf := &ScalarFunction{ + Function: Upper, + Arguments: makeInt(99), // not a string + } + + _, err := EvalExpression(sf, rc) + if err == nil { + t.Fatalf("expected type error, got nil") + } + }) + t.Run("LowerTypeError", func(t *testing.T) { + sf := &ScalarFunction{ + 
Function: Lower, + Arguments: makeInt(99), // not a string + } + + _, err := EvalExpression(sf, rc) + if err == nil { + t.Fatalf("expected type error, got nil") + } + }) + t.Run("Upper with Nulls", func(t *testing.T) { + nullArr := GenerateColumsNull() + schema := arrow.NewSchema([]arrow.Field{ + {Name: "col1", Type: arrow.BinaryTypes.String, Nullable: true}, + }, nil) + defer nullArr.Release() + cr := ColumnResolve{Name: "col1"} + sf := &ScalarFunction{ + Function: Upper, + Arguments: &cr, + } + a, err := EvalExpression(sf, &operators.RecordBatch{Schema: schema, Columns: []arrow.Array{nullArr}}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + t.Logf("array:%v", a) + + }) + t.Run("Lower with Nulls", func(t *testing.T) { + nullArr := GenerateColumsNull() + schema := arrow.NewSchema([]arrow.Field{ + {Name: "col1", Type: arrow.BinaryTypes.String, Nullable: true}, + }, nil) + defer nullArr.Release() + cr := ColumnResolve{Name: "col1"} + sf := &ScalarFunction{ + Function: Lower, + Arguments: &cr, + } + a, err := EvalExpression(sf, &operators.RecordBatch{Schema: schema, Columns: []arrow.Array{nullArr}}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + t.Logf("array:%v", a) + + }) + // ------------------------- + // Unsupported function + // ------------------------- + t.Run("UnsupportedFunction", func(t *testing.T) { + sf := &ScalarFunction{ + Function: 999, + Arguments: makeStr("hi"), + } + + _, err := EvalExpression(sf, rc) + if err == nil { + t.Fatalf("expected unsupported function error, got nil") + } + }) + + // ---------------------------------------- + // INTERFACE VALIDATION + // ---------------------------------------- + s := &ScalarFunction{ + Function: Upper, + Arguments: makeStr("ok"), + } + + // Should not panic + s.ExprNode() + + // Print the string representation + t.Logf("%s\n", s.String()) + }) + +} + +func TestCastResolve(t *testing.T) { + t.Run("CastExpr", func(t *testing.T) { + + rc := generateTestColumns() //4 + + // 
---- Helpers ----- + makeLitInt := func(v int32) Expression { + return &LiteralResolve{ + Type: arrow.PrimitiveTypes.Int32, + Value: v, + } + } + + makeLitStr := func(v string) Expression { + return &LiteralResolve{ + Type: arrow.BinaryTypes.String, + Value: v, + } + } + + // ---------------------------------------- + // 1. CAST Literal Int32 -> Float64 + // ---------------------------------------- + t.Run("Literal_Int32_to_Float64", func(t *testing.T) { + ce := &CastExpr{ + Expr: makeLitInt(42), + TargetType: arrow.PrimitiveTypes.Float64, + } + + arr, err := EvalExpression(ce, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := arr.(*array.Float64) + for i := 0; i < out.Len(); i++ { + if out.Value(i) != 42.0 { + t.Fatalf("expected 42.0, got %f", out.Value(i)) + } + } + }) + + // ---------------------------------------- + // 2. CAST Literal String -> Int32 (invalid) + // ---------------------------------------- + t.Run("Literal_String_to_Int32_error", func(t *testing.T) { + ce := &CastExpr{ + Expr: makeLitStr("hello"), + TargetType: arrow.PrimitiveTypes.Int32, + } + + _, err := EvalCast(ce, rc) + if err == nil { + t.Fatalf("expected cast error, got nil") + } + t.Logf("Cast string->int32 error : %v\n", err) + }) + + // ---------------------------------------- + // 3. 
CAST Column (age int32) -> Float64 + // ---------------------------------------- + t.Run("Column_age_to_Float64", func(t *testing.T) { + col := &ColumnResolve{Name: "age"} + + ce := &CastExpr{ + Expr: col, + TargetType: arrow.PrimitiveTypes.Float64, + } + + arr, err := EvalExpression(ce, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + orig := rc.Columns[rc.Schema.FieldIndices("age")[0]].(*array.Int32) + out := arr.(*array.Float64) + + if orig.Len() != out.Len() { + t.Fatalf("length mismatch: %d vs %d", orig.Len(), out.Len()) + } + + for i := 0; i < orig.Len(); i++ { + if float64(orig.Value(i)) != out.Value(i) { + t.Fatalf("at %d expected %f, got %f", + i, float64(orig.Value(i)), out.Value(i)) + } + } + }) + + // ---------------------------------------- + // 4. Invalid target type (like trying cast to LargeBinary) + // ---------------------------------------- + t.Run("InvalidTargetType", func(t *testing.T) { + ce := &CastExpr{ + Expr: makeLitInt(5), + TargetType: arrow.BinaryTypes.LargeBinary, + } + + _, err := EvalExpression(ce, rc) + if err == nil { + t.Fatalf("expected error for invalid cast, got nil") + } + }) + + // ---------------------------------------- + // 5. Interface + String check + // ---------------------------------------- + ce := &CastExpr{ + Expr: makeLitInt(1), + TargetType: arrow.PrimitiveTypes.Float64, + } + + // no panic = success + ce.ExprNode() + + t.Logf("%s\n", ce.String()) + }) + +} + +// InvariantExpr is a dummy expression that always returns an error when evaluated +type InvariantExpr struct{} + +func (ie *InvariantExpr) ExprNode() {} +func (ie *InvariantExpr) String() string { + return "InvariantExpr" +} +func TestInvariantExpr(t *testing.T) { + t.Run("InvariantExpr", func(t *testing.T) { + _, err := EvalExpression(&InvariantExpr{}, nil) + if err == nil { + t.Fatalf("expected error for invariant expr eval, got nil") + } + }) +} + +// Tests for ExprDataType and the helper type-inference functions. 
+func TestExprDataType(t *testing.T) { + // simple schema used for column resolution tests + schema := arrow.NewSchema([]arrow.Field{{Name: "age", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "name", Type: arrow.BinaryTypes.String, Nullable: true}}, nil) + + t.Run("Literal_ReturnsType", func(t *testing.T) { + lit := &LiteralResolve{Type: arrow.PrimitiveTypes.Int32} + got := ExprDataType(lit, nil) + if got.ID() != arrow.INT32 { + t.Fatalf("expected INT32, got %s", got) + } + }) + + t.Run("Column_ReturnsSchemaType", func(t *testing.T) { + col := &ColumnResolve{Name: "age"} + got := ExprDataType(col, schema) + if got.ID() != arrow.INT32 { + t.Fatalf("expected INT32, got %s", got) + } + }) + + t.Run("Column_Unknown_Panics", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic for unknown column, got none") + } + }() + _ = ExprDataType(&ColumnResolve{Name: "missing"}, schema) + }) + + t.Run("Alias_PreservesType", func(t *testing.T) { + a := &Alias{Expr: &LiteralResolve{Type: arrow.PrimitiveTypes.Float64}, Name: "f"} + got := ExprDataType(a, schema) + if got.ID() != arrow.FLOAT64 { + t.Fatalf("expected FLOAT64, got %s", got) + } + }) + + t.Run("Cast_ReturnsTargetType", func(t *testing.T) { + c := &CastExpr{Expr: &LiteralResolve{Type: arrow.PrimitiveTypes.Int32}, TargetType: arrow.PrimitiveTypes.Float64} + got := ExprDataType(c, schema) + if got.ID() != arrow.FLOAT64 { + t.Fatalf("expected FLOAT64, got %s", got) + } + }) + + t.Run("Binary_Arithmetic_PromotesToFloat64", func(t *testing.T) { + be := &BinaryExpr{Left: &LiteralResolve{Type: arrow.PrimitiveTypes.Int32}, Op: Addition, Right: &LiteralResolve{Type: arrow.PrimitiveTypes.Int32}} + got := ExprDataType(be, schema) + if got.ID() != arrow.FLOAT64 { + t.Fatalf("expected FLOAT64 from numericPromotion, got %s", got) + } + }) + + t.Run("Binary_Comparison_ReturnsBoolean", func(t *testing.T) { + be := &BinaryExpr{Left: &LiteralResolve{Type: 
arrow.PrimitiveTypes.Int32}, Op: Equal, Right: &LiteralResolve{Type: arrow.PrimitiveTypes.Int32}} + got := ExprDataType(be, schema) + if got.ID() != arrow.BOOL { + t.Fatalf("expected BOOL from comparison, got %s", got) + } + }) + + t.Run("ScalarFunction_Upper_String", func(t *testing.T) { + sf := &ScalarFunction{Function: Upper, Arguments: &LiteralResolve{Type: arrow.BinaryTypes.String}} + got := ExprDataType(sf, schema) + if got.ID() != arrow.STRING { + t.Fatalf("expected STRING from Upper, got %s", got) + } + }) + + t.Run("UnsupportedExpr_Panics", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic for unsupported expr type, got none") + } + }() + // InvariantExpr is a test-only expr that should hit the default case + _ = ExprDataType(&InvariantExpr{}, schema) + }) +} + +func TestInferBinaryType(t *testing.T) { + t.Run("Arithmetic_ReturnsNumericPromotion", func(t *testing.T) { + got := inferBinaryType(arrow.PrimitiveTypes.Int32, Addition, arrow.PrimitiveTypes.Int32) + if got.ID() != arrow.FLOAT64 { + t.Fatalf("expected FLOAT64 from numericPromotion, got %s", got) + } + }) + + t.Run("Comparison_ReturnsBoolean", func(t *testing.T) { + got := inferBinaryType(arrow.PrimitiveTypes.Int32, Equal, arrow.PrimitiveTypes.Int32) + if got.ID() != arrow.BOOL { + t.Fatalf("expected BOOL for comparison, got %s", got) + } + }) + + t.Run("Logical_ReturnsBoolean", func(t *testing.T) { + got := inferBinaryType(arrow.FixedWidthTypes.Boolean, And, arrow.FixedWidthTypes.Boolean) + if got.ID() != arrow.BOOL { + t.Fatalf("expected BOOL for logical op, got %s", got) + } + }) + + t.Run("UnsupportedOperator_Panics", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic for unsupported operator, got none") + } + }() + // Use 'Not' which is not handled in inferBinaryType switch and should panic + _ = inferBinaryType(arrow.PrimitiveTypes.Int32, Not, arrow.PrimitiveTypes.Int32) + }) +} + +func 
TestNumericPromotion(t *testing.T) { + t.Run("Int32_Int32_ToFloat64", func(t *testing.T) { + got := numericPromotion(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int32) + if got.ID() != arrow.FLOAT64 { + t.Fatalf("expected FLOAT64 from numericPromotion, got %s", got) + } + }) +} + +func TestInferScalarFunctionType(t *testing.T) { + t.Run("Upper_Lower_String", func(t *testing.T) { + got := inferScalarFunctionType(Upper, arrow.BinaryTypes.String) + if got.ID() != arrow.STRING { + t.Fatalf("expected STRING for Upper/Lower, got %s", got) + } + }) + + t.Run("Upper_NonString_Panics", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic for Upper on non-string, got none") + } + }() + _ = inferScalarFunctionType(Upper, arrow.PrimitiveTypes.Int32) + }) + + t.Run("Abs_Round_ReturnSameType", func(t *testing.T) { + got := inferScalarFunctionType(Abs, arrow.PrimitiveTypes.Int32) + if got.ID() != arrow.INT32 { + t.Fatalf("expected same input type for Abs/Round, got %s", got) + } + }) + + t.Run("UnknownFunction_Panics", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic for unknown function, got none") + } + }() + _ = inferScalarFunctionType(supportedFunctions(9999), arrow.PrimitiveTypes.Int32) + }) +} + +// test constructor methods for expressions +func TestExprInitMethods(t *testing.T) { + t.Run("New Alias", func(t *testing.T) { + literal := NewLiteralResolve(arrow.BinaryTypes.String, string("the golfer")) + a := NewAlias(literal, "nickname") + if a == nil { + t.Fatalf("failed to create Alias expression") + } + }) + t.Run("New ColumnResolve", func(t *testing.T) { + cr := NewColumnResolve("age") + if cr == nil { + t.Fatalf("failed to create ColumnResolve expression") + } + }) + t.Run("New LiteralResolve", func(t *testing.T) { + lit := NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(42)) + if lit == nil { + t.Fatalf("failed to create LiteralResolve expression") + } + }) + t.Run("New 
BinaryExpr", func(t *testing.T) { + left := NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(10)) + right := NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(5)) + be := NewBinaryExpr(left, Addition, right) + if be == nil { + t.Fatalf("failed to create BinaryExpr expression") + } + }) + t.Run("New ScalarFunc", func(t *testing.T) { + arg := NewLiteralResolve(arrow.BinaryTypes.String, string("hello")) + sf := NewScalarFunction(Upper, arg) + if sf == nil { + t.Fatalf("failed to create ScalarFunction expression") + } + }) + t.Run("New CastExpr", func(t *testing.T) { + expr := NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(100)) + ce := NewCastExpr(expr, arrow.PrimitiveTypes.Float64) + if ce == nil { + t.Fatalf("failed to create CastExpr expression") + } + }) + t.Run("New Expressions", func(t *testing.T) { + literal := NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(7)) + cr := NewColumnResolve("score") + left := NewBinaryExpr(literal, Multiplication, cr) + sf := NewScalarFunction(Abs, left) + ce := NewCastExpr(sf, arrow.PrimitiveTypes.Float64) + if ce == nil { + t.Fatalf("failed to create complex expression") + } + exprs := NewExpressions(literal, cr, left, sf, ce) + if len(exprs) != 5 { + t.Fatalf("expected 5 expressions, got %d", len(exprs)) + } + }) +} diff --git a/src/Backend/opti-sql-go/Expr/info.go b/src/Backend/opti-sql-go/Expr/info.go new file mode 100644 index 0000000..eda4866 --- /dev/null +++ b/src/Backend/opti-sql-go/Expr/info.go @@ -0,0 +1,25 @@ +package Expr + +// evaluate expressions +// for example Column + Literal +// Column - Column +// Literal / Literal + +//1. Arithmetic Expressions +// SELECT salary * 1.2, price + tax, -(discount) +//2.Alias Expressions +//SELECT name AS employee_name, age AS employee_age +//SELECT salary * 1.2 AS new_salary +//3.String Expressions +//first_name || ' ' || last_name +// 3.1 String functions +//UPPER(name) +//LOWER(email) +//SUBSTRING(name, 1, 3) +//4. 
Function calls +//ABS(x) +//ROUND(salary, 2) +//LENGTH(name) +//COALESCE(a, b) +//5. Constants +//SELECT 1, 'hello', 3.14 diff --git a/src/Backend/opti-sql-go/operators/Expr/expr.go b/src/Backend/opti-sql-go/operators/Expr/expr.go deleted file mode 100644 index 5a91a0b..0000000 --- a/src/Backend/opti-sql-go/operators/Expr/expr.go +++ /dev/null @@ -1,105 +0,0 @@ -package Expr - -import "github.com/apache/arrow/go/v17/arrow" - -type binaryOperator int - -const ( - // arithmetic - addition binaryOperator = 1 - subtraction binaryOperator = 2 - multiplication binaryOperator = 3 - division binaryOperator = 4 - modulous binaryOperator = 5 - // comparison - equal binaryOperator = 6 - notEqual binaryOperator = 7 - lessThan binaryOperator = 8 - lessThanOrEqual binaryOperator = 9 - greaterThan binaryOperator = 10 - greaterThanOrEqual binaryOperator = 11 - // logical - and binaryOperator = 12 - or binaryOperator = 13 - not binaryOperator = 14 -) - -type supportedFunctions int - -const ( - upper supportedFunctions = 1 - lower supportedFunctions = 2 - abs supportedFunctions = 3 - round supportedFunctions = 4 -) - -type aggFunctions = int - -const ( - Sum aggFunctions = 1 - Count aggFunctions = 2 - Avg aggFunctions = 3 - Min aggFunctions = 4 - Max aggFunctions = 5 -) - -/* -Eval(expr): - - match expr: - Literal(x) -> return x - Column(name) -> return array of that column - BinaryExpr(left > right) -> eval left, eval right, apply operator - ScalarFunction(upper(name)) -> evaluate function - Alias(expr, name) -> just a name wrapper -*/ -type Expr interface { - ExprNode() // empty method, only for the sake of polymophism -} - -/* -Alias | sql: select col1 as new_name from table_source -updates the column name in the output schema. 
-*/ -type Alias struct { - expr []Expr - columnName string - name string -} - -// return batch.Columns[fieldIndex["age"]] -// resolves the arrow array corresponding to name passed in -// sql: select age -type ColumnResolve struct { - name string -} - -// Evaluates to a column of length = batch-size, filled with this literal. -// sql: select 1 -type LiteralResolve struct { - Type arrow.DataType - value any -} - -type Operator struct { - o binaryOperator -} -type BinaryExpr struct { - left []Expr - op Operator - right []Expr -} - -type ScalarFunction struct { - function supportedFunctions - input []Expr // resolve to something you can procees IE, literal/coloumn Resolve -} -type AggregateFunction struct { - function aggFunctions - args []Expr -} - -type CastExpr struct { - expr []Expr // can be a Literal or Column (check for datatype then) - targetType arrow.DataType -} diff --git a/src/Backend/opti-sql-go/operators/Expr/expr_test.go b/src/Backend/opti-sql-go/operators/Expr/expr_test.go deleted file mode 100644 index 0f383f6..0000000 --- a/src/Backend/opti-sql-go/operators/Expr/expr_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package Expr - -import "testing" - -func TestExpr(t *testing.T) { - // Simple passing test -} diff --git a/src/Backend/opti-sql-go/operators/Expr/info.go b/src/Backend/opti-sql-go/operators/Expr/info.go deleted file mode 100644 index e69de29..0000000 diff --git a/src/Backend/opti-sql-go/operators/project/csv.go b/src/Backend/opti-sql-go/operators/project/csv.go index 1a021c6..7f57686 100644 --- a/src/Backend/opti-sql-go/operators/project/csv.go +++ b/src/Backend/opti-sql-go/operators/project/csv.go @@ -124,7 +124,6 @@ func (csvS *CSVSource) processRow( } else { v, err := strconv.ParseInt(cell, 10, 64) if err != nil { - fmt.Printf("failed to parse cell: %v with error: %v\n", cell, err) b.AppendNull() } else { b.Append(v) @@ -137,7 +136,6 @@ func (csvS *CSVSource) processRow( } else { v, err := strconv.ParseFloat(cell, 64) if err != nil { - 
fmt.Printf("failed to parse cell: %v with error: %v\n", cell, err) b.AppendNull() } else { b.Append(v) diff --git a/src/Backend/opti-sql-go/operators/project/custom.go b/src/Backend/opti-sql-go/operators/project/custom.go index 38f2a94..e36fa0c 100644 --- a/src/Backend/opti-sql-go/operators/project/custom.go +++ b/src/Backend/opti-sql-go/operators/project/custom.go @@ -234,3 +234,36 @@ func supportedType(col any) bool { return false } } + +// handle keeping only the request columns but make sure the schema and columns are also aligned +// returns error if a column doesnt exist +func ProjectSchemaFilterDown(schema *arrow.Schema, cols []arrow.Array, keepCols ...string) (*arrow.Schema, []arrow.Array, error) { + if len(keepCols) == 0 { + return arrow.NewSchema([]arrow.Field{}, nil), nil, ErrEmptyColumnsToProject + } + + // Build map: columnName -> original index + fieldIndex := make(map[string]int) // age -> 0 + for i, f := range schema.Fields() { + fieldIndex[f.Name] = i + } + + newFields := make([]arrow.Field, 0, len(keepCols)) + newCols := make([]arrow.Array, 0, len(keepCols)) + + // Preserve order from keepCols, not schema order + for _, name := range keepCols { + idx, exists := fieldIndex[name] + if !exists { + return arrow.NewSchema([]arrow.Field{}, nil), []arrow.Array{}, ErrProjectColumnNotFound + } + + newFields = append(newFields, schema.Field(idx)) + col := cols[idx] + col.Retain() + newCols = append(newCols, col) + } + + newSchema := arrow.NewSchema(newFields, nil) + return newSchema, newCols, nil +} diff --git a/src/Backend/opti-sql-go/operators/project/projectExec.go b/src/Backend/opti-sql-go/operators/project/projectExec.go index f633416..5dcd3bf 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec.go @@ -2,7 +2,9 @@ package project import ( "errors" + "fmt" "io" + "opti-sql-go/Expr" "opti-sql-go/operators" "github.com/apache/arrow/go/v17/arrow" @@ -18,23 +20,40 @@ var ( ) type 
ProjectExec struct { - child operators.Operator - outputschema arrow.Schema - columnsToKeep []string - done bool + child operators.Operator + outputschema arrow.Schema + expr []Expr.Expression + done bool } // columns to keep and existing schema -func NewProjectExec(projectColumns []string, input operators.Operator) (*ProjectExec, error) { - newSchema, err := prunedSchema(input.Schema(), projectColumns) - if err != nil { - return nil, err +func NewProjectExec(input operators.Operator, exprs []Expr.Expression) (*ProjectExec, error) { + fields := make([]arrow.Field, len(exprs)) + for i, e := range exprs { + switch ex := e.(type) { + case *Expr.Alias: + fields[i] = arrow.Field{ + Name: ex.Name, + Type: Expr.ExprDataType(ex.Expr, input.Schema()), + Nullable: true, + } + default: + name := fmt.Sprintf("col_%d", i) + Type := Expr.ExprDataType(e, input.Schema()) + fields[i] = arrow.Field{ + Name: name, + Type: Type, + Nullable: true, + } + } } + + outputschema := arrow.NewSchema(fields, nil) // return new exec return &ProjectExec{ - child: input, - outputschema: *newSchema, - columnsToKeep: projectColumns, + child: input, + outputschema: *outputschema, + expr: exprs, }, nil } @@ -45,24 +64,33 @@ func (p *ProjectExec) Next(n uint16) (*operators.RecordBatch, error) { return nil, io.EOF } - rc, err := p.child.Next(n) + childBatch, err := p.child.Next(n) if err != nil { return nil, err } - _, orderCols, err := ProjectSchemaFilterDown(rc.Schema, rc.Columns, p.columnsToKeep...) 
- if err != nil { - return nil, err + if childBatch.RowCount == 0 { + p.done = true + return &operators.RecordBatch{ + Schema: &p.outputschema, + Columns: []arrow.Array{}, + RowCount: 0, + }, nil } - for _, c := range rc.Columns { - c.Release() + outPutCols := make([]arrow.Array, len(p.expr)) + for i, e := range p.expr { + arr, err := Expr.EvalExpression(e, childBatch) + if err != nil { + return nil, fmt.Errorf("project eval expression failed for expr %d: %w", i, err) + } + outPutCols[i] = arr } - if rc.RowCount == 0 { - p.done = true + for _, c := range childBatch.Columns { + c.Release() } return &operators.RecordBatch{ Schema: &p.outputschema, - Columns: orderCols, - RowCount: rc.RowCount, + Columns: outPutCols, + RowCount: childBatch.RowCount, }, nil } func (p *ProjectExec) Close() error { @@ -74,36 +102,6 @@ func (p *ProjectExec) Schema() *arrow.Schema { // handle keeping only the request columns but make sure the schema and columns are also aligned // returns error if a column doesnt exist -func ProjectSchemaFilterDown(schema *arrow.Schema, cols []arrow.Array, keepCols ...string) (*arrow.Schema, []arrow.Array, error) { - if len(keepCols) == 0 { - return arrow.NewSchema([]arrow.Field{}, nil), nil, ErrEmptyColumnsToProject - } - - // Build map: columnName -> original index - fieldIndex := make(map[string]int) // age -> 0 - for i, f := range schema.Fields() { - fieldIndex[f.Name] = i - } - - newFields := make([]arrow.Field, 0, len(keepCols)) - newCols := make([]arrow.Array, 0, len(keepCols)) - - // Preserve order from keepCols, not schema order - for _, name := range keepCols { - idx, exists := fieldIndex[name] - if !exists { - return arrow.NewSchema([]arrow.Field{}, nil), []arrow.Array{}, ErrProjectColumnNotFound - } - - newFields = append(newFields, schema.Field(idx)) - col := cols[idx] - col.Retain() - newCols = append(newCols, col) - } - - newSchema := arrow.NewSchema(newFields, nil) - return newSchema, newCols, nil -} func prunedSchema(schema *arrow.Schema, 
keepCols []string) (*arrow.Schema, error) { if len(keepCols) == 0 { diff --git a/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go b/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go new file mode 100644 index 0000000..68eeb61 --- /dev/null +++ b/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go @@ -0,0 +1,843 @@ +package project + +import ( + "math" + "opti-sql-go/Expr" + "strings" + "testing" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +func generateData() ([]string, []any) { + names := []string{"id", "name", "age", "active", "score"} + values := []any{ + []int{1, 2, 3, 4, 5, 6}, + []string{"Ainsley Coffey", "Kody Frazier", "Octavia Truong", "Ayan Gonzalez", "Abigail Castro", "Clay McDaniel"}, + []int8{10, 12, 35, 76, 42, 63}, + []bool{false, true, false, true, true, true}, + []float32{98.6, 75.4, 88.1, 92.3, 79.5, 85.0}, + } + return names, values + +} + +/* +project: column +project: column1, column2, column3 +sql : select id,age,name from table +*/ +func TestProjectExec_Column_sql(t *testing.T) { + names, cols := generateData() + t.Run("select id", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + exprs := []Expr.Expression{ + &Expr.ColumnResolve{Name: "id"}, + } + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(6) + if err != nil { + t.Fatalf("err:%v\n", err) + } + t.Logf("rc:%v\n", rc) + column := rc.Columns[0] + if column.DataType() != arrow.PrimitiveTypes.Int64 { + t.Fatalf("expected int64, got %s", column.DataType().Name()) + } + idCol, ok := column.(*array.Int64) + if !ok { + t.Fatalf("expected Int64 array, got %T", column) + } + expectedValues := []int64{1, 2, 3, 4, 5, 6} + for i, v := range expectedValues { + if idCol.Value(i) != v { + t.Fatalf("at index %d: expected %d, got %d", i, v, idCol.Value(i)) + } + } + + }) + t.Run("select name", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + 
exprs := []Expr.Expression{ + &Expr.ColumnResolve{Name: "name"}, + } + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(6) + if err != nil { + t.Fatalf("err:%v\n", err) + } + t.Logf("rc:%v\n", rc) + column := rc.Columns[0] + if column.DataType() != arrow.BinaryTypes.String { + t.Fatalf("expected string, got %s", column.DataType().Name()) + } + nameCol, ok := column.(*array.String) + if !ok { + t.Fatalf("expected String array, got %T", column) + } + expectedValues := []string{"Ainsley Coffey", "Kody Frazier", "Octavia Truong", "Ayan Gonzalez", "Abigail Castro", "Clay McDaniel"} + for i, v := range expectedValues { + if nameCol.Value(i) != v { + t.Fatalf("at index %d: expected %s, got %s", i, v, nameCol.Value(i)) + } + } + }) + t.Run("select age", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + exprs := []Expr.Expression{ + &Expr.ColumnResolve{Name: "age"}, + } + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(6) + if err != nil { + t.Fatalf("err:%v\n", err) + } + t.Logf("rc:%v\n", rc) + column := rc.Columns[0] + if column.DataType() != arrow.PrimitiveTypes.Int8 { + t.Fatalf("expected int8, got %s", column.DataType().Name()) + } + ageCol, ok := column.(*array.Int8) + if !ok { + t.Fatalf("expected Int8 array, got %T", column) + } + expectedValues := []int8{10, 12, 35, 76, 42, 63} + for i, v := range expectedValues { + if ageCol.Value(i) != v { + t.Fatalf("at index %d: expected %d, got %d", i, v, ageCol.Value(i)) + } + } + }) + +} + +// these test with no base table +// select 1 from table +func TestProjectExec_Literal_sql(t *testing.T) { + names, cols := generateData() + t.Run("select 1", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + exprs := []Expr.Expression{ + &Expr.LiteralResolve{Type: arrow.PrimitiveTypes.Int64, Value: int64(4)}, + } + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(1) + if err != nil { + t.Fatalf("err:%v\n", err) + } + column := rc.Columns[0] + 
if column.DataType() != arrow.PrimitiveTypes.Int64 { + t.Fatalf("expected int64, got %s", column.DataType().Name()) + } + if column.ValueStr(0) != "4" { + t.Fatalf("expected 4, got %s", column.ValueStr(0)) + } + + }) + t.Run("select 'hello'", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + exprs := []Expr.Expression{ + &Expr.LiteralResolve{Type: arrow.BinaryTypes.String, Value: string("hello")}, + } + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(1) + if err != nil { + t.Fatalf("err:%v\n", err) + } + column := rc.Columns[0] + if column.DataType() != arrow.BinaryTypes.String { + t.Fatalf("expected column types to be of type string but recieved %s\n", column.DataType()) + } + if column.ValueStr(0) != "hello" { + t.Fatalf("expected hello, got %s", column.ValueStr(0)) + } + t.Logf("rc:%v\n", rc) + + }) + t.Run("select 3.14", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + exprs := []Expr.Expression{ + &Expr.LiteralResolve{Type: arrow.PrimitiveTypes.Float32, Value: float32(3.14)}, + } + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(2) + if err != nil { + t.Fatalf("err:%v\n", err) + } + column := rc.Columns[0] + if column.DataType() != arrow.PrimitiveTypes.Float32 { + t.Fatalf("expected column types to be of type string but recieved %s\n", column.DataType()) + } + floatArr, ok := column.(*array.Float32) + if !ok { + t.Fatalf("expected Float32 array, got %T", column) + } + for i := 0; i < floatArr.Len(); i++ { + if floatArr.Value(i) != float32(3.14) { + t.Fatalf("at index %d: expected %f, got %f", i, float32(3.14), floatArr.Value(i)) + } + } + t.Logf("rc:%v\n", rc) + + }) + +} + +/* +Project: Literal |Operator| Literal +*/ +func TestProjectExec_Literal_Literal(t *testing.T) { + names, cols := generateData() + t.Run("Age plus constant", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + exprs := []Expr.Expression{ + &Expr.BinaryExpr{Left: &Expr.ColumnResolve{Name: 
"age"}, Op: Expr.Addition, Right: &Expr.LiteralResolve{Type: arrow.PrimitiveTypes.Int64, Value: int64(10)}}, + } + for _, e := range exprs { + t.Log(e.String()) + } + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(4) + if err != nil { + t.Fatalf("unexpected error: %v\n", err) + } + ageCol, ok := rc.Columns[0].(*array.Int64) + if !ok { + t.Fatalf("expected column to be of type Int64 but got %T\n", rc.Columns[0]) + } + expected := []int64{20, 22, 45, 86} + if ageCol.Len() != len(expected) { + t.Fatalf("mismatch in expected column length, recieved column of len %d", ageCol.Len()) + } + for i := 0; i < len(expected); i++ { + if ageCol.Value(i) != expected[i] { + t.Fatalf("expected %d at position %d, but recieved %d", expected[i], i, ageCol.Value(i)) + } + } + }) + + t.Run("id multiplied by constant", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := []Expr.Expression{ + &Expr.BinaryExpr{ + Left: &Expr.ColumnResolve{Name: "id"}, + Op: Expr.Multiplication, + Right: &Expr.LiteralResolve{Type: arrow.PrimitiveTypes.Int64, Value: int64(3)}, + }, + } + + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(4) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + idCol, ok := rc.Columns[0].(*array.Int64) + if !ok { + t.Fatalf("expected Int64 column, got %T", rc.Columns[0]) + } + + expected := []int64{3, 6, 9, 12} + + if idCol.Len() != len(expected) { + t.Fatalf("expected %d rows, got %d", len(expected), idCol.Len()) + } + + for i := 0; i < len(expected); i++ { + if idCol.Value(i) != expected[i] { + t.Fatalf("at index %d: expected %d, got %d", i, expected[i], idCol.Value(i)) + } + } + }) + // column |operator| nestedLiteralExpr + t.Run("select score - (5+4)", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + nested := &Expr.BinaryExpr{ + Left: &Expr.LiteralResolve{Type: arrow.PrimitiveTypes.Int64, Value: int64(5)}, + Op: Expr.Addition, + Right: &Expr.LiteralResolve{Type: 
arrow.PrimitiveTypes.Int64, Value: int64(4)}, + } + + exprs := []Expr.Expression{ + &Expr.BinaryExpr{ + Left: &Expr.ColumnResolve{Name: "score"}, + Op: Expr.Subtraction, + Right: nested, + }, + } + + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(6) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + scoreCol, ok := rc.Columns[0].(*array.Float32) + if !ok { + t.Fatalf("expected Float32 column, got %T", rc.Columns[0]) + } + + expected := []float32{ + 98.6 - 9, + 75.4 - 9, + 88.1 - 9, + 92.3 - 9, + 79.5 - 9, + 85.0 - 9, + } + + if scoreCol.Len() != len(expected) { + t.Fatalf("expected %d rows, got %d", len(expected), scoreCol.Len()) + } + + for i := 0; i < len(expected); i++ { + if diff := scoreCol.Value(i) - expected[i]; diff > 1e-5 || diff < -1e-5 { + t.Fatalf("expected %f at index %d, got %f", expected[i], i, scoreCol.Value(i)) + } + } + + }) + t.Run("select age / (2*3)", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + mult := &Expr.BinaryExpr{ + Left: &Expr.LiteralResolve{Type: arrow.PrimitiveTypes.Int64, Value: int64(2)}, + Op: Expr.Multiplication, + Right: &Expr.LiteralResolve{Type: arrow.PrimitiveTypes.Int64, Value: int64(3)}, + } + + exprs := []Expr.Expression{ + &Expr.BinaryExpr{ + Left: &Expr.ColumnResolve{Name: "age"}, + Op: Expr.Division, + Right: mult, + }, + } + + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(4) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ageCol, ok := rc.Columns[0].(*array.Int64) + if !ok { + t.Fatalf("expected Int64 column, got %T", rc.Columns[0]) + } + + // doesnt cast by default + // age / 6 + expected := []int64{10 / 6, 12 / 6, 35 / 6, 76 / 6} + + if ageCol.Len() != len(expected) { + t.Fatalf("expected %d rows, got %d", len(expected), ageCol.Len()) + } + + for i := 0; i < len(expected); i++ { + if ageCol.Value(i) != expected[i] { + t.Fatalf("expected %d at index %d, got %d", expected[i], i, ageCol.Value(i)) + } + } + }) +} + +/* 
+Project: cast literal +project: cast column +*/ +func TestProjectExec_CastLiteral_Column(t *testing.T) { + names, cols := generateData() + t.Run("select age as float32", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := []Expr.Expression{ + Expr.NewCastExpr(Expr.NewColumnResolve("age"), arrow.PrimitiveTypes.Float32), + } + + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(6) + if err != nil { + t.Fatalf("error: %v", err) + } + + col, ok := rc.Columns[0].(*array.Float32) + if !ok { + t.Fatalf("expected Float32 column, got %T", rc.Columns[0]) + } + + expected := []float32{10, 12, 35, 76, 42, 63} + for i := 0; i < len(expected); i++ { + if col.Value(i) != expected[i] { + t.Fatalf("expected %f at %d, got %f", expected[i], i, col.Value(i)) + } + } + }) + t.Run("select age as int16", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := []Expr.Expression{ + Expr.NewCastExpr(Expr.NewColumnResolve("age"), arrow.PrimitiveTypes.Int16), + } + + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(6) + if err != nil { + t.Fatalf("error: %v", err) + } + + col, ok := rc.Columns[0].(*array.Int16) + if !ok { + t.Fatalf("expected Int16 column, got %T", rc.Columns[0]) + } + + expected := []int16{10, 12, 35, 76, 42, 63} + for i := 0; i < len(expected); i++ { + if col.Value(i) != expected[i] { + t.Fatalf("expected %d at %d, got %d", expected[i], i, col.Value(i)) + } + } + }) + // should fail + t.Run("select name as int32", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := []Expr.Expression{ + Expr.NewCastExpr(Expr.NewColumnResolve("name"), arrow.PrimitiveTypes.Int32), + } + + proj, _ := NewProjectExec(memSrc, exprs) + _, err := proj.Next(6) + + if err == nil { + t.Fatalf("expected cast error but got nil") + } + }) + t.Run("select 4 as float64", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := []Expr.Expression{ + 
Expr.NewCastExpr( + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int64, int64(4)), + arrow.PrimitiveTypes.Float64, + ), + } + + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(6) + if err != nil { + t.Fatalf("error: %v", err) + } + + col, ok := rc.Columns[0].(*array.Float64) + if !ok { + t.Fatalf("expected Float64 column, got %T", rc.Columns[0]) + } + + for i := 0; i < col.Len(); i++ { + if col.Value(i) != 4.0 { + t.Fatalf("expected 4.0 at %d, got %f", i, col.Value(i)) + } + } + + }) + // should be a no op + t.Run("select 'richard' as string", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := []Expr.Expression{ + Expr.NewCastExpr( + &Expr.LiteralResolve{Type: arrow.BinaryTypes.String, Value: "richard"}, + arrow.BinaryTypes.String, + ), + } + + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(3) + if err != nil { + t.Fatalf("error: %v", err) + } + + col, ok := rc.Columns[0].(*array.String) + if !ok { + t.Fatalf("expected String column, got %T", rc.Columns[0]) + } + + for i := 0; i < col.Len(); i++ { + if col.Value(i) != "richard" { + t.Fatalf("expected 'richard' at %d, got %s", i, col.Value(i)) + } + } + }) +} + +/* +Column Name | (Operator) | type(value) +Value is applied to every element of column +*/ +func TestProjectExec_Column_Literal(t *testing.T) { + names, cols := generateData() + t.Run("age + 10", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := Expr.NewExpressions( + Expr.NewBinaryExpr( + Expr.NewColumnResolve("age"), + Expr.Addition, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int8, int8(10)), + ), + ) + + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(6) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + out := rc.Columns[0].(*array.Int8) + expected := []int8{20, 22, 45, 86, 52, 73} + for i := 0; i < len(expected); i++ { + if out.Value(i) != expected[i] { + t.Fatalf("expected %d got %d at %d", expected[i], out.Value(i), i) 
+ } + } + }) + t.Run("score - 5.0", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := Expr.NewExpressions( + Expr.NewBinaryExpr( + Expr.NewColumnResolve("score"), + Expr.Subtraction, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float32, float32(5.0)), + ), + ) + + proj, _ := NewProjectExec(memSrc, exprs) + rc, err := proj.Next(6) + if err != nil { + t.Fatalf("err: %v", err) + } + + out := rc.Columns[0].(*array.Float32) + expected := []float32{93.6, 70.4, 83.1, 87.3, 74.5, 80.0} + for i := range expected { + if math.Abs(float64(out.Value(i)-expected[i])) > 0.0001 { + t.Fatalf("expected %f got %f", expected[i], out.Value(i)) + } + } + }) + + t.Run("id * 2", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + exprs := Expr.NewExpressions( + Expr.NewBinaryExpr( + Expr.NewColumnResolve("id"), + Expr.Multiplication, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int64, int64(2)), + ), + ) + + proj, _ := NewProjectExec(memSrc, exprs) + rc, _ := proj.Next(6) + out := rc.Columns[0].(*array.Int64) + + expected := []int64{2, 4, 6, 8, 10, 12} + for i := range expected { + if out.Value(i) != expected[i] { + t.Fatalf("expected %d got %d", expected[i], out.Value(i)) + } + } + }) + + t.Run("score / 2", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + exprs := Expr.NewExpressions( + Expr.NewBinaryExpr( + Expr.NewColumnResolve("score"), + Expr.Division, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float32, float32(2)), + ), + ) + + proj, _ := NewProjectExec(memSrc, exprs) + rc, _ := proj.Next(6) + out := rc.Columns[0].(*array.Float32) + + expected := []float32{49.3, 37.7, 44.05, 46.15, 39.75, 42.5} + for i := range expected { + if math.Abs(float64(out.Value(i)-expected[i])) > 0.0001 { + t.Fatalf("expected %f got %f", expected[i], out.Value(i)) + } + } + }) +} + +/* +Alias(column |operator| literal) +*/ +func TestProjectExec_AliasExpr(t *testing.T) { + names, cols := generateData() + + t.Run("alias 
column id → identifier", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := Expr.NewExpressions( + Expr.NewAlias( + Expr.NewColumnResolve("id"), + "identifier", + ), + ) + + proj, _ := NewProjectExec(memSrc, exprs) + schema := proj.Schema() + + if schema.Field(0).Name != "identifier" { + t.Fatalf("expected alias name identifier, got %s", schema.Field(0).Name) + } + }) + + t.Run("alias expression (age + 10) → boosted_age", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := Expr.NewExpressions( + Expr.NewAlias( + Expr.NewBinaryExpr( + Expr.NewColumnResolve("age"), + Expr.Addition, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int8, int8(10)), + ), + "boosted_age", + ), + ) + + proj, _ := NewProjectExec(memSrc, exprs) + rc, _ := proj.Next(3) + + if proj.Schema().Field(0).Name != "boosted_age" { + t.Fatalf("alias not applied") + } + + out := rc.Columns[0].(*array.Int8) + expected := []int8{20, 22, 45} + + for i := range expected { + if out.Value(i) != expected[i] { + t.Fatalf("expected %d got %d", expected[i], out.Value(i)) + } + } + }) + + t.Run("alias literal → constant_value", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := Expr.NewExpressions( + Expr.NewAlias( + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(7)), + "constant_value", + ), + ) + + proj, _ := NewProjectExec(memSrc, exprs) + rc, _ := proj.Next(3) + + if proj.Schema().Field(0).Name != "constant_value" { + t.Fatalf("alias not applied") + } + + out := rc.Columns[0].(*array.Int32) + for i := 0; i < out.Len(); i++ { + if out.Value(i) != 7 { + t.Fatalf("expected literal 7, got %d", out.Value(i)) + } + } + }) + + t.Run("alias nested expr (score - (2+3)) → final_score", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + inner := Expr.NewBinaryExpr( + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(2)), + Expr.Addition, + 
Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(3)), + ) + + exprs := Expr.NewExpressions( + Expr.NewAlias( + Expr.NewBinaryExpr( + Expr.NewColumnResolve("score"), + Expr.Subtraction, + inner, + ), + "final_score", + ), + ) + + proj, _ := NewProjectExec(memSrc, exprs) + rc, _ := proj.Next(3) + + if proj.Schema().Field(0).Name != "final_score" { + t.Fatalf("alias not applied") + } + + out := rc.Columns[0].(*array.Float32) + expected := []float32{ + 98.6 - 5, + 75.4 - 5, + 88.1 - 5, + } + + for i := range expected { + if math.Abs(float64(out.Value(i)-expected[i])) > 0.0001 { + t.Fatalf("expected %f got %f", expected[i], out.Value(i)) + } + } + }) +} + +/* +function(column/literal) +function(column |operator| literal) +function(columh/literal |operator| literal/column) +*/ +func TestProjectExec_FunctionExpr(t *testing.T) { + names, cols := generateData() // id, name, age, active, score + + // ------------------------------------------------------------ + // 1. UPPER(column) + // ------------------------------------------------------------ + t.Run("UPPER(name)", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := Expr.NewExpressions( + Expr.NewScalarFunction( + Expr.Upper, + Expr.NewColumnResolve("name"), + ), + ) + + proj, _ := NewProjectExec(memSrc, exprs) + rb, err := proj.Next(3) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := rb.Columns[0].(*array.String) + expected := []string{ + "AINSLEY COFFEY", + "KODY FRAZIER", + "OCTAVIA TRUONG", + } + + for i := range expected { + if out.Value(i) != expected[i] { + t.Fatalf("expected %s got %s", expected[i], out.Value(i)) + } + } + }) + + t.Run("LOWER('MonKey_x')", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + expr := Expr.NewLiteralResolve(arrow.BinaryTypes.String, string("MoNKey_X")) + + exprs := Expr.NewExpressions( + Expr.NewScalarFunction( + Expr.Lower, + expr, + ), + ) + + proj, _ := NewProjectExec(memSrc, exprs) + rb, _ 
:= proj.Next(2) + + out := rb.Columns[0].(*array.String) + t.Logf("columns: %v\n", out) + expected := []string{ + strings.ToLower("monkey_x"), + strings.ToLower("monkey_x"), + } + + for i := range expected { + if out.Value(i) != expected[i] { + t.Fatalf("expected %s got %s", expected[i], out.Value(i)) + } + } + }) + + // ------------------------------------------------------------ + // 3. ABS(column |operator| literal) + // ABS(score - 100.0) + // ------------------------------------------------------------ + t.Run("ABS(score - 100)", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + expr := Expr.NewScalarFunction( + Expr.Abs, + Expr.NewBinaryExpr( + Expr.NewColumnResolve("score"), + Expr.Subtraction, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float32, float32(100)), + ), + ) + + proj, _ := NewProjectExec(memSrc, Expr.NewExpressions(expr)) + rb, _ := proj.Next(3) + + out := rb.Columns[0].(*array.Float32) + + expected := []float32{ + float32(math.Abs(98.6 - 100)), + float32(math.Abs(75.4 - 100)), + float32(math.Abs(88.1 - 100)), + } + + for i := range expected { + if math.Abs(float64(out.Value(i)-expected[i])) > 0.0001 { + t.Fatalf("expected %f got %f", expected[i], out.Value(i)) + } + } + }) + + // ------------------------------------------------------------ + // 4. 
ROUND(literal |operator| column) + // ROUND(2.5 * score) + // ------------------------------------------------------------ + t.Run("ROUND(2.5 * score)", func(t *testing.T) { + memSrc, _ := NewInMemoryProjectExec(names, cols) + + expr := Expr.NewScalarFunction( + Expr.Round, + Expr.NewBinaryExpr( + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, float64(2.5)), + Expr.Multiplication, + Expr.NewColumnResolve("score"), + ), + ) + + proj, _ := NewProjectExec(memSrc, Expr.NewExpressions(expr)) + rb, _ := proj.Next(3) + + out := rb.Columns[0].(*array.Float64) + expected := []float64{ + math.Round(2.5 * 98.6), + math.Round(2.5 * 75.4), + math.Round(2.5 * 88.1), + } + + for i := range expected { + if math.Abs(out.Value(i)-expected[i]) > 1 { + t.Fatalf("expected %f got %f", expected[i], out.Value(i)) + } + } + }) +} + +/* +complex expr +ex: alias(function(column |operator| literal) |operator| literal) +TODO: not the most imporatnt thing right now since we know basic expression are fine +*/ +func TestProjectExec_ComplexExpr(t *testing.T) {} diff --git a/src/Backend/opti-sql-go/operators/project/projectExec_test.go b/src/Backend/opti-sql-go/operators/project/projectExec_test.go index 04a0ecd..3415d47 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec_test.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec_test.go @@ -3,273 +3,280 @@ package project import ( "errors" "io" + "opti-sql-go/Expr" "testing" "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" ) -func TestProjectExecInit(t *testing.T) { - // Simple passing test +func TestProjectExec_Init(t *testing.T) { + names, cols := generateTestColumns() + memSrc, err := NewInMemoryProjectExec(names, cols) + if err != nil { + t.Fatalf("failed to create in memory source: %v", err) + } + + exprs := []Expr.Expression{ + &Expr.ColumnResolve{Name: "id"}, + &Expr.ColumnResolve{Name: "name"}, + } + + proj, err := NewProjectExec(memSrc, exprs) + if err != nil { + 
t.Fatalf("failed to create project exec: %v", err) + } + + schema := proj.Schema() + if schema.NumFields() != len(exprs) { + t.Fatalf("expected %d fields, got %d", len(exprs), schema.NumFields()) + } } -func TestProjectPrune(t *testing.T) { - fields := []arrow.Field{ - {Name: "id", Type: arrow.PrimitiveTypes.Int64}, - {Name: "name", Type: arrow.BinaryTypes.String}, - {Name: "age", Type: arrow.PrimitiveTypes.Int64}, - {Name: "country", Type: arrow.BinaryTypes.String}, - {Name: "email", Type: arrow.BinaryTypes.String}, - {Name: "signup_date", Type: arrow.FixedWidthTypes.Date32}, +func TestProjectExec_BasicColumns(t *testing.T) { + names, cols := generateTestColumns() + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := []Expr.Expression{ + &Expr.ColumnResolve{Name: "id"}, + &Expr.ColumnResolve{Name: "name"}, } - schema := arrow.NewSchema(fields, nil) - t.Run("validate prune 1", func(t *testing.T) { - keepCols := []string{"id", "name", "email"} - newSchema, err := prunedSchema(schema, keepCols) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if newSchema.NumFields() != len(keepCols) { - t.Fatalf("expected %d fields, got %d", len(keepCols), newSchema.NumFields()) - } - for i, field := range newSchema.Fields() { - if field.Name != keepCols[i] { - t.Fatalf("expected field %s, got %s", keepCols[i], field.Name) - } - } - t.Logf("%s\n", newSchema) - }) - t.Run("validate prune 2", func(t *testing.T) { - keeptCols := []string{"age", "country", "signup_date"} - newSchema, err := prunedSchema(schema, keeptCols) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if newSchema.NumFields() != len(keeptCols) { - t.Fatalf("expected %d fields, got %d", len(keeptCols), newSchema.NumFields()) - } - for i, field := range newSchema.Fields() { - if field.Name != keeptCols[i] { - t.Fatalf("expected field %s, got %s", keeptCols[i], field.Name) - } - } - t.Logf("%s\n", newSchema) - }) - t.Run("prune non-existant column", func(t *testing.T) { - 
keepCols := []string{"id", "non_existing_column"} - _, err := prunedSchema(schema, keepCols) - if err == nil { - t.Fatalf("expected error for non-existing column, got nil") - } - if !errors.Is(err, ErrProjectColumnNotFound) { - t.Fatalf("expected ErrProjectColumnNotFound, got %v", err) - } + projExec, err := NewProjectExec(memSrc, exprs) + if err != nil { + t.Fatalf("failed to create project exec: %v", err) + } - }) - t.Run("Prune empty input keepcols", func(t *testing.T) { - keepCols := []string{} - _, err := prunedSchema(schema, keepCols) - if err == nil { - t.Fatalf("expected error for empty keepcols, got nil") - } - if !errors.Is(err, ErrEmptyColumnsToProject) { - t.Fatalf("expected ErrEmptyColumnsToProject, got %v", err) - } - }) + rb, err := projExec.Next(3) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + + if len(rb.Columns) != 2 { + t.Fatalf("expected 2 columns, got %d", len(rb.Columns)) + } + for _, c := range rb.Columns { + c.Release() + } + t.Logf("record batch: %+v", rb) } -func TestProjectExec(t *testing.T) { - names, col := generateTestColumns() - memorySource, err := NewInMemoryProjectExec(names, col) - if err != nil { - t.Fatalf("failed to create in memory source: %v", err) + +func TestProjectExec_Alias(t *testing.T) { + names, cols := generateTestColumns() + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := []Expr.Expression{ + &Expr.Alias{ + Expr: &Expr.ColumnResolve{Name: "id"}, + Name: "identifier", + }, + &Expr.Alias{ + Expr: &Expr.ColumnResolve{Name: "name"}, + Name: "full_name", + }, } - t.Logf("original schema %v\n", memorySource.Schema()) - projectExec, err := NewProjectExec([]string{"id", "name", "age"}, memorySource) + + projExec, err := NewProjectExec(memSrc, exprs) if err != nil { t.Fatalf("failed to create project exec: %v", err) } - rc, err := projectExec.Next(3) + + schema := projExec.Schema() + + if schema.Field(0).Name != "identifier" { + t.Fatalf("expected identifier, got %s", schema.Field(0).Name) + } + if 
schema.Field(1).Name != "full_name" { + t.Fatalf("expected full_name, got %s", schema.Field(1).Name) + } + t.Logf("schema %v\n", schema) +} + +func TestProjectExec_Literal(t *testing.T) { + names, cols := generateTestColumns() + memSrc, _ := NewInMemoryProjectExec(names, cols) + + lit := &Expr.LiteralResolve{ + Type: arrow.PrimitiveTypes.Int64, + Value: int64(99), + } + projExec, err := NewProjectExec(memSrc, []Expr.Expression{lit}) if err != nil { - t.Fatalf("failed to get next record batch: %v", err) + t.Fatalf("failed to init project exec: %v", err) } - t.Logf("rc:%v\n", rc) + rb, err := projExec.Next(5) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + + col := rb.Columns[0].(*array.Int64) + for i := 0; i < col.Len(); i++ { + if col.Value(i) != 99 { + t.Fatalf("expected 99, got %d", col.Value(i)) + } + } + + for _, c := range rb.Columns { + c.Release() + } } -// NewProjectExec, pruned schema errors and iteration behavior. -func TestProjectExec_Subtests(t *testing.T) { +func TestProjectExec_BinaryAdd(t *testing.T) { names, cols := generateTestColumns() + memSrc, _ := NewInMemoryProjectExec(names, cols) - t.Run("ValidProjection", func(t *testing.T) { - memSrc, err := NewInMemoryProjectExec(names, cols) - if err != nil { - t.Fatalf("failed to create in memory source: %v", err) - } - projCols := []string{"id", "name", "age"} - projExec, err := NewProjectExec(projCols, memSrc) - if err != nil { - t.Fatalf("failed to create project exec: %v", err) + expr := &Expr.BinaryExpr{ + Left: &Expr.ColumnResolve{Name: "age"}, + Op: Expr.Addition, + Right: &Expr.LiteralResolve{Type: arrow.PrimitiveTypes.Int64, Value: int64(1)}, + } + + projExec, err := NewProjectExec(memSrc, []Expr.Expression{expr}) + if err != nil { + t.Fatalf("failed: %v", err) + } + + rb, err := projExec.Next(3) + if err != nil { + t.Fatalf("Next failed: %v", err) + } + + if len(rb.Columns) != 1 { + t.Fatalf("expected 1 column, got %d", len(rb.Columns)) + } + + for _, c := range rb.Columns { + 
c.Release() + } + t.Logf("column: %+v", rb.Columns[0]) +} + +func TestProjectExec_IterateEOF(t *testing.T) { + names, cols := generateTestColumns() + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := []Expr.Expression{ + &Expr.ColumnResolve{Name: "id"}, + } + + projExec, _ := NewProjectExec(memSrc, exprs) + + count := 0 + for { + rb, err := projExec.Next(2) + if errors.Is(err, io.EOF) { + break } - rb, err := projExec.Next(4) if err != nil { - t.Fatalf("Next failed: %v", err) - } - if rb == nil { - t.Fatalf("expected a record batch, got nil") - } - if len(rb.Columns) != len(projCols) { - t.Fatalf("expected %d columns, got %d", len(projCols), len(rb.Columns)) + t.Fatalf("unexpected err: %v", err) } + count += rb.Columns[0].Len() for _, c := range rb.Columns { c.Release() } - }) + } - t.Run("EmptyColumns", func(t *testing.T) { - memSrc, err := NewInMemoryProjectExec(names, cols) - if err != nil { - t.Fatalf("failed to create in memory source: %v", err) - } - _, err = NewProjectExec([]string{}, memSrc) + if count == 0 { + t.Fatalf("expected some rows, got 0") + } +} + +func TestProjectExec_Close(t *testing.T) { + names, cols := generateTestColumns() + memSrc, _ := NewInMemoryProjectExec(names, cols) + + exprs := []Expr.Expression{ + &Expr.ColumnResolve{Name: "id"}, + } + + projExec, _ := NewProjectExec(memSrc, exprs) + if err := projExec.Close(); err != nil { + t.Fatalf("close returned error: %v", err) + } +} +func TestEOFBehavior(t *testing.T) { + names, cols := []string{"name"}, []any{ + []string{"richard"}, + } + + memSrc, err := NewInMemoryProjectExec(names, cols) + if err != nil { + t.Fatalf("failed to create in memory source: %v", err) + } + proj, err := NewProjectExec(memSrc, []Expr.Expression{&Expr.ColumnResolve{Name: "name"}}) + if err != nil { + t.Fatalf("failed to create project exec: %v", err) + } + rc, err := proj.Next(1) + if err != nil { + t.Fatalf("unexpected error on first Next: %v", err) + } + nameCol, ok := rc.Columns[0].(*array.String) 
+ if !ok { + t.Fatalf("expected String array, got %T", rc.Columns[0]) + } + if nameCol.Value(0) != "richard" { + t.Fatalf("expected 'richard', got '%s'", nameCol.Value(0)) + } + + _, err = proj.Next(10) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF error on second Next, got %v", err) + } +} + +func TestPruneSchema_cvg(t *testing.T) { + // build a sample schema + fields := []arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "name", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "age", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + } + schema := arrow.NewSchema(fields, nil) + + t.Run("EmptyKeepCols_ReturnsErrAndEmptySchema", func(t *testing.T) { + s, err := prunedSchema(schema, []string{}) if err == nil { - t.Fatalf("expected error for empty project columns, got nil") + t.Fatalf("expected ErrEmptyColumnsToProject, got nil") } if !errors.Is(err, ErrEmptyColumnsToProject) { t.Fatalf("expected ErrEmptyColumnsToProject, got %v", err) } - }) - - t.Run("NonExistentColumn", func(t *testing.T) { - memSrc, err := NewInMemoryProjectExec(names, cols) - if err != nil { - t.Fatalf("failed to create in memory source: %v", err) - } - _, err = NewProjectExec([]string{"id", "nope"}, memSrc) - if err == nil { - t.Fatalf("expected error for non-existent column, got nil") - } - if !errors.Is(err, ErrProjectColumnNotFound) { - t.Fatalf("expected ErrProjectColumnNotFound, got %v", err) - } - }) - - t.Run("SchemaMatch", func(t *testing.T) { - memSrc, err := NewInMemoryProjectExec(names, cols) - if err != nil { - t.Fatalf("failed to create in memory source: %v", err) - } - projCols := []string{"id", "name"} - projExec, err := NewProjectExec(projCols, memSrc) - if err != nil { - t.Fatalf("failed to create project exec: %v", err) + if s == nil { + t.Fatalf("expected non-nil schema even on empty columns") } - execSchema := projExec.Schema() - pruned, err := prunedSchema(memSrc.Schema(), projCols) - if err != nil { - 
t.Fatalf("prunedSchema failed: %v", err) - } - if !execSchema.Equal(pruned) { - t.Fatalf("expected exec schema %v, got %v", pruned, execSchema) + if s.NumFields() != 0 { + t.Fatalf("expected 0 fields, got %d", s.NumFields()) } - _ = projExec }) - t.Run("IterateUntilEOF", func(t *testing.T) { - memSrc, err := NewInMemoryProjectExec(names, cols) + t.Run("ValidKeepCols_PreservesOrderAndTypes", func(t *testing.T) { + keep := []string{"name", "id"} + s, err := prunedSchema(schema, keep) if err != nil { - t.Fatalf("failed to create in memory source: %v", err) - } - projExec, err := NewProjectExec([]string{"id", "name"}, memSrc) - if err != nil { - t.Fatalf("failed to create project exec: %v", err) - } - total := 0 - batches := 0 - for { - rb, err := projExec.Next(3) - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Fatalf("Next returned unexpected error: %v", err) - } - if rb == nil { - t.Fatalf("expected record batch, got nil") - } - total += int(rb.Columns[0].Len()) - batches++ - for _, c := range rb.Columns { - c.Release() - } + t.Fatalf("unexpected error: %v", err) } - if batches == 0 { - t.Fatalf("expected at least 1 batch, got 0") + if s.NumFields() != 2 { + t.Fatalf("expected 2 fields, got %d", s.NumFields()) } - }) - - t.Run("SingleColumnProjection", func(t *testing.T) { - memSrc, err := NewInMemoryProjectExec(names, cols) - if err != nil { - t.Fatalf("failed to create in memory source: %v", err) - } - projExec, err := NewProjectExec([]string{"department"}, memSrc) - if err != nil { - t.Fatalf("failed to create project exec: %v", err) + if s.Field(0).Name != "name" || s.Field(0).Type.ID() != arrow.STRING { + t.Fatalf("field 0 mismatch, got %v", s.Field(0)) } - total := 0 - for { - rb, err := projExec.Next(5) - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Fatalf("Next returned unexpected error: %v", err) - } - if len(rb.Columns) != 1 { - t.Fatalf("expected 1 column, got %d", len(rb.Columns)) - } - total += int(rb.Columns[0].Len()) - 
for _, c := range rb.Columns { - c.Release() - } + if s.Field(1).Name != "id" || s.Field(1).Type.ID() != arrow.INT32 { + t.Fatalf("field 1 mismatch, got %v", s.Field(1)) } }) - t.Run("Check Close", func(t *testing.T) { - memSrc, err := NewInMemoryProjectExec(names, cols) - if err != nil { - t.Fatalf("failed to create in memory source: %v", err) - } - projExec, err := NewProjectExec([]string{"department"}, memSrc) - if err != nil { - t.Fatalf("failed to create project exec: %v", err) - } - err = projExec.Close() - if err != nil { - t.Fatalf("expected no error on Close, got %v", err) - } - }) - t.Run("Empty ProjectFilter", func(t *testing.T) { - memSrc, err := NewInMemoryProjectExec(names, cols) - if err != nil { - t.Fatalf("failed to create in memory source: %v", err) - } - _, _, err = ProjectSchemaFilterDown(memSrc.Schema(), memSrc.columns, []string{}...) - if err == nil { - t.Fatalf("expected error for empty project filter, got nil") - } - if !errors.Is(err, ErrEmptyColumnsToProject) { - t.Fatalf("expected ErrEmptyColumnsToProject, got %v", err) - } - _, _, err = ProjectSchemaFilterDown(memSrc.Schema(), memSrc.columns, []string{"This column doesnt exist"}...) 
+ t.Run("MissingColumn_ReturnsErrProjectColumnNotFound", func(t *testing.T) { + _, err := prunedSchema(schema, []string{"missing_col"}) if err == nil { - t.Fatalf("expected error for non-existent column in project filter, got nil") + t.Fatalf("expected ErrProjectColumnNotFound, got nil") } if !errors.Is(err, ErrProjectColumnNotFound) { t.Fatalf("expected ErrProjectColumnNotFound, got %v", err) } - }) } From 19d436c88f9b531e639415ec3c0fbf3e62cd8361 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Wed, 19 Nov 2025 15:34:36 -0500 Subject: [PATCH 12/19] fixed linting/pre-push rules --- src/Backend/opti-sql-rs/src/main.rs | 46 +++++++++++-------- .../opti-sql-rs/src/project/project_exec.rs | 1 + .../opti-sql-rs/src/project/source/csv.rs | 1 + 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/Backend/opti-sql-rs/src/main.rs b/src/Backend/opti-sql-rs/src/main.rs index aa02e30..32f0023 100644 --- a/src/Backend/opti-sql-rs/src/main.rs +++ b/src/Backend/opti-sql-rs/src/main.rs @@ -1,5 +1,8 @@ //use datafusion::arrow::record_batch::RecordBatch; -use datafusion::{arrow::{array::RecordBatch, util::pretty::print_batches}, functions::crypto::basic, prelude::*}; +use datafusion::{ + arrow::{array::RecordBatch, util::pretty::print_batches}, + prelude::*, +}; use datafusion_substrait::*; mod project; #[allow(dead_code)] @@ -7,11 +10,9 @@ mod project; async fn main() { let mut ctx = SessionContext::new(); - ctx.register_csv( - "example", - "example.csv", - CsvReadOptions::new() - ).await.unwrap(); + ctx.register_csv("example", "example.csv", CsvReadOptions::new()) + .await + .unwrap(); // basic projections //basic_project("Basic column projection",&mut ctx,"select name,salary from example").await; //basic_project("reorder and duplicate projection",&mut ctx,"select salary,name,salary as s1 from example").await; // this is an error, expression names must be unique/ must use alias to get around error @@ -37,29 +38,34 @@ async fn main() { // mixed expressions 
//basic_project("mixed expressions",&mut ctx,"SELECT upper(name) AS upper_name, salary * 1.1 AS increased_salary FROM example").await; // function calls - let (l,r) = basic_project("function call with Abs()",&mut ctx,"SELECT ABS(age) FROM example").await; + let (_l, r) = basic_project( + "function call with Abs()", + &mut ctx, + "SELECT ABS(age) FROM example", + ) + .await; //println!("Logical Plan:\n{}", l); - print_batches(&r).unwrap();ß + print_batches(&r).unwrap(); + //basic_project("function call Round()",&mut ctx,"SELECT Round(age) FROM example").await; //basic_project("function call Length()",&mut ctx,"SELECT LENGTH(name) FROM example").await; - } -pub async fn basic_project(name : &str,ctx : &mut SessionContext,sql : &str) -> (String,Vec) { - println!("Running project: {}",name); - let df1 = ctx.sql(sql) - .await - .unwrap(); - +pub async fn basic_project( + name: &str, + ctx: &mut SessionContext, + sql: &str, +) -> (String, Vec) { + println!("Running project: {}", name); + let df1 = ctx.sql(sql).await.unwrap(); let logical_plan = df1.logical_plan().clone(); - let substrait_plan = logical_plan::producer::to_substrait_plan(&logical_plan, &ctx.state()).unwrap(); + let substrait_plan = + logical_plan::producer::to_substrait_plan(&logical_plan, &ctx.state()).unwrap(); print!("Substrait Plan :\n{:?}", substrait_plan); + let display = format!("{}", logical_plan.display_indent()); - let display = format!("{}",logical_plan.display_indent()); - // Running will create the physical plan automatically - return (display,df1.collect().await.unwrap()); - + (display, df1.collect().await.unwrap()) } diff --git a/src/Backend/opti-sql-rs/src/project/project_exec.rs b/src/Backend/opti-sql-rs/src/project/project_exec.rs index 2447902..618b4b7 100644 --- a/src/Backend/opti-sql-rs/src/project/project_exec.rs +++ b/src/Backend/opti-sql-rs/src/project/project_exec.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] pub fn project_execute() { println!("Reading CSV..."); } diff --git 
a/src/Backend/opti-sql-rs/src/project/source/csv.rs b/src/Backend/opti-sql-rs/src/project/source/csv.rs index 021581e..57abaa9 100644 --- a/src/Backend/opti-sql-rs/src/project/source/csv.rs +++ b/src/Backend/opti-sql-rs/src/project/source/csv.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] pub fn read_csv() { println!("Reading CSV..."); } From 3f6de414284eee7f538c464a0f8b1386366219c2 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Wed, 19 Nov 2025 17:14:17 -0500 Subject: [PATCH 13/19] resolved PR comments --- src/Backend/opti-sql-go/Expr/expr.go | 35 ++++++++++--------- src/Backend/opti-sql-go/Expr/expr_test.go | 6 ++++ .../operators/project/projectExec.go | 3 ++ .../operators/project/projectExecExpr_test.go | 10 +++--- .../operators/project/projectExec_test.go | 20 +++++++++++ 5 files changed, 52 insertions(+), 22 deletions(-) diff --git a/src/Backend/opti-sql-go/Expr/expr.go b/src/Backend/opti-sql-go/Expr/expr.go index e278093..a7c9daf 100644 --- a/src/Backend/opti-sql-go/Expr/expr.go +++ b/src/Backend/opti-sql-go/Expr/expr.go @@ -67,7 +67,6 @@ var ( _ = (Expression)(&CastExpr{}) ) -// TODO: create nice wrapper functions for creating expressions /* Eval(expr): @@ -224,7 +223,7 @@ func EvalLiteral(l *LiteralResolve, batch *operators.RecordBatch) (arrow.Array, // INT / UINT (8/16/32/64) // ------------------------------ case arrow.INT8: - v := int8(l.Value.(int8)) + v := l.Value.(int8) b := array.NewInt8Builder(memory.DefaultAllocator) defer b.Release() @@ -244,7 +243,7 @@ func EvalLiteral(l *LiteralResolve, batch *operators.RecordBatch) (arrow.Array, return b.NewArray(), nil case arrow.INT16: - v := int16(l.Value.(int16)) + v := l.Value.(int16) b := array.NewInt16Builder(memory.DefaultAllocator) defer b.Release() for i := 0; i < n; i++ { @@ -253,7 +252,7 @@ func EvalLiteral(l *LiteralResolve, batch *operators.RecordBatch) (arrow.Array, return b.NewArray(), nil case arrow.UINT16: - v := uint16(l.Value.(uint16)) + v := l.Value.(uint16) b := 
array.NewUint16Builder(memory.DefaultAllocator) defer b.Release() for i := 0; i < n; i++ { @@ -262,7 +261,7 @@ func EvalLiteral(l *LiteralResolve, batch *operators.RecordBatch) (arrow.Array, return b.NewArray(), nil case arrow.INT32: - v := int32(l.Value.(int32)) + v := l.Value.(int32) b := array.NewInt32Builder(memory.DefaultAllocator) defer b.Release() for i := 0; i < n; i++ { @@ -271,16 +270,15 @@ func EvalLiteral(l *LiteralResolve, batch *operators.RecordBatch) (arrow.Array, return b.NewArray(), nil case arrow.UINT32: - v := uint32(l.Value.(uint32)) + v := l.Value.(uint32) b := array.NewUint32Builder(memory.DefaultAllocator) defer b.Release() for i := 0; i < n; i++ { b.Append(v) } return b.NewArray(), nil - // correct jump case arrow.INT64: - v := int64(l.Value.(int64)) + v := l.Value.(int64) b := array.NewInt64Builder(memory.DefaultAllocator) defer b.Release() for i := 0; i < n; i++ { @@ -289,7 +287,7 @@ func EvalLiteral(l *LiteralResolve, batch *operators.RecordBatch) (arrow.Array, return b.NewArray(), nil case arrow.UINT64: - v := uint64(l.Value.(uint64)) + v := l.Value.(uint64) b := array.NewUint64Builder(memory.DefaultAllocator) defer b.Release() for i := 0; i < n; i++ { @@ -301,7 +299,7 @@ func EvalLiteral(l *LiteralResolve, batch *operators.RecordBatch) (arrow.Array, // FLOATS // ------------------------------ case arrow.FLOAT32: - v := float32(l.Value.(float32)) + v := l.Value.(float32) b := array.NewFloat32Builder(memory.DefaultAllocator) defer b.Release() for i := 0; i < n; i++ { @@ -310,7 +308,7 @@ func EvalLiteral(l *LiteralResolve, batch *operators.RecordBatch) (arrow.Array, return b.NewArray(), nil case arrow.FLOAT64: - v := float64(l.Value.(float64)) + v := l.Value.(float64) b := array.NewFloat64Builder(memory.DefaultAllocator) defer b.Release() for i := 0; i < n; i++ { @@ -406,15 +404,24 @@ func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error // comparisions TODO: case Equal: + return nil, fmt.Errorf("operator Equal 
(%d) not yet implemented", b.Op) case NotEqual: + return nil, fmt.Errorf("operator NotEqual (%d) not yet implemented", b.Op) case LessThan: + return nil, fmt.Errorf("operator LessThan (%d) not yet implemented", b.Op) case LessThanOrEqual: + return nil, fmt.Errorf("operator LessThanOrEqual (%d) not yet implemented", b.Op) case GreaterThan: + return nil, fmt.Errorf("operator GreaterThan (%d) not yet implemented", b.Op) case GreaterThanOrEqual: + return nil, fmt.Errorf("operator GreaterThanOrEqual (%d) not yet implemented", b.Op) // logical case And: + return nil, fmt.Errorf("operator And (%d) not yet implemented", b.Op) case Or: + return nil, fmt.Errorf("operator Or (%d) not yet implemented", b.Op) case Not: + return nil, fmt.Errorf("operator Not (%d) not yet implemented", b.Op) } return nil, fmt.Errorf("binary operator %d not supported", b.Op) } @@ -557,12 +564,6 @@ func inferScalarFunctionType(fn supportedFunctions, argType arrow.DataType) arro } } -// not a priority at all -type AggregateFunction struct { - Function aggFunctions - Args Expression -} - // If cast succeeds → return the casted value // If cast fails → throw a runtime error type CastExpr struct { diff --git a/src/Backend/opti-sql-go/Expr/expr_test.go b/src/Backend/opti-sql-go/Expr/expr_test.go index 6fe9115..f5555d7 100644 --- a/src/Backend/opti-sql-go/Expr/expr_test.go +++ b/src/Backend/opti-sql-go/Expr/expr_test.go @@ -1192,3 +1192,9 @@ func TestExprInitMethods(t *testing.T) { } }) } + +func TestUnimplemntedOperators(t *testing.T) { + for i := Equal; i <= Or; i++ { + NewBinaryExpr(NewLiteralResolve(arrow.PrimitiveTypes.Int16, int16(10)), 3, NewLiteralResolve(arrow.PrimitiveTypes.Int16, int16(5))) + } +} diff --git a/src/Backend/opti-sql-go/operators/project/projectExec.go b/src/Backend/opti-sql-go/operators/project/projectExec.go index 5dcd3bf..e85c73e 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec.go @@ -47,6 +47,8 
@@ func NewProjectExec(input operators.Operator, exprs []Expr.Expression) (*Project } } } + // Use a generic column naming pattern ("col_%d") when an expression doesn't have an explicit alias. + // This ensures every projected column has a name in the output schema. outputschema := arrow.NewSchema(fields, nil) // return new exec @@ -83,6 +85,7 @@ func (p *ProjectExec) Next(n uint16) (*operators.RecordBatch, error) { return nil, fmt.Errorf("project eval expression failed for expr %d: %w", i, err) } outPutCols[i] = arr + arr.Retain() } for _, c := range childBatch.Columns { c.Release() diff --git a/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go b/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go index 68eeb61..25b1b94 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go +++ b/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go @@ -147,7 +147,7 @@ func TestProjectExec_Literal_sql(t *testing.T) { } column := rc.Columns[0] if column.DataType() != arrow.BinaryTypes.String { - t.Fatalf("expected column types to be of type string but recieved %s\n", column.DataType()) + t.Fatalf("expected column types to be of type string but received %s\n", column.DataType()) } if column.ValueStr(0) != "hello" { t.Fatalf("expected hello, got %s", column.ValueStr(0)) @@ -167,7 +167,7 @@ func TestProjectExec_Literal_sql(t *testing.T) { } column := rc.Columns[0] if column.DataType() != arrow.PrimitiveTypes.Float32 { - t.Fatalf("expected column types to be of type string but recieved %s\n", column.DataType()) + t.Fatalf("expected column types to be of type string but received %s\n", column.DataType()) } floatArr, ok := column.(*array.Float32) if !ok { @@ -208,7 +208,7 @@ func TestProjectExec_Literal_Literal(t *testing.T) { } expected := []int64{20, 22, 45, 86} if ageCol.Len() != len(expected) { - t.Fatalf("mismatch in expected column length, recieved column of len %d", ageCol.Len()) + t.Fatalf("mismatch in expected 
column length, received column of len %d", ageCol.Len()) } for i := 0; i < len(expected); i++ { if ageCol.Value(i) != expected[i] { @@ -700,7 +700,7 @@ func TestProjectExec_AliasExpr(t *testing.T) { /* function(column/literal) function(column |operator| literal) -function(columh/literal |operator| literal/column) +function(column/literal |operator| literal/column) */ func TestProjectExec_FunctionExpr(t *testing.T) { names, cols := generateData() // id, name, age, active, score @@ -838,6 +838,6 @@ func TestProjectExec_FunctionExpr(t *testing.T) { /* complex expr ex: alias(function(column |operator| literal) |operator| literal) -TODO: not the most imporatnt thing right now since we know basic expression are fine +TODO: not the most important thing right now since we know basic expression are fine */ func TestProjectExec_ComplexExpr(t *testing.T) {} diff --git a/src/Backend/opti-sql-go/operators/project/projectExec_test.go b/src/Backend/opti-sql-go/operators/project/projectExec_test.go index 3415d47..b180b2c 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec_test.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec_test.go @@ -153,6 +153,26 @@ func TestProjectExec_BinaryAdd(t *testing.T) { t.Logf("column: %+v", rb.Columns[0]) } +// TODO: once your implement the other operators this test will fail +func TestUnimplemntedOperators(t *testing.T) { + names, cols := generateTestColumns() + memSrc, _ := NewInMemoryProjectExec(names, cols) + for i := Expr.Equal; i <= Expr.Or; i++ { + br := Expr.NewBinaryExpr(Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int16, int16(10)), i, Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int16, int16(5))) + + proj, err := NewProjectExec(memSrc, []Expr.Expression{br}) + if err != nil { + t.Fatalf("failed to create project exec: %v", err) + } + _, err = proj.Next(1) + if err == nil { + t.Fatalf("expected error for unimplemented operator %d, got nil", i) + + } + t.Logf("error: %v", err) + } +} + func 
TestProjectExec_IterateEOF(t *testing.T) { names, cols := generateTestColumns() memSrc, _ := NewInMemoryProjectExec(names, cols) From 36622452e234c882ca40bcdf8198ce4b1ae5097a Mon Sep 17 00:00:00 2001 From: Marco Ferreira Date: Wed, 19 Nov 2025 20:19:59 -0500 Subject: [PATCH 14/19] feat(http-server) Added http server, tests, updated readme, added workflow --- .env | 10 ++ .env.example | 10 ++ .github/workflows/frontend-test.yml | 35 ++++ Makefile | 51 +++++- README.md | 16 +- src/Backend/opti-sql-go/go.mod | 2 +- src/FrontEnd/.dockerignore | 38 +++++ src/FrontEnd/.env | 10 ++ src/FrontEnd/Dockerfile | 17 ++ src/FrontEnd/README.md | 186 +++++++++++++++++++++ src/FrontEnd/app/__init__.py | 0 src/FrontEnd/app/api/__init__.py | 0 src/FrontEnd/app/api/v1/__init__.py | 0 src/FrontEnd/app/api/v1/routes/__init__.py | 0 src/FrontEnd/app/api/v1/routes/health.py | 17 ++ src/FrontEnd/app/api/v1/routes/query.py | 79 +++++++++ src/FrontEnd/app/core/__init__.py | 0 src/FrontEnd/app/core/config.py | 27 +++ src/FrontEnd/app/core/logging.py | 44 +++++ src/FrontEnd/app/main.py | 51 ++++++ src/FrontEnd/app/models/__init__.py | 0 src/FrontEnd/app/models/schemas.py | 32 ++++ src/FrontEnd/docker-compose.yml | 24 +++ src/FrontEnd/generate_swagger.py | 12 ++ src/FrontEnd/main.cpp | 6 - src/FrontEnd/pytest.ini | 9 + src/FrontEnd/requirements.txt | 10 ++ src/FrontEnd/swagger.yml | 155 +++++++++++++++++ src/FrontEnd/tests/README.md | 119 +++++++++++++ src/FrontEnd/tests/__init__.py | 0 src/FrontEnd/tests/conftest.py | 26 +++ src/FrontEnd/tests/test_health.py | 28 ++++ src/FrontEnd/tests/test_integration.py | 116 +++++++++++++ src/FrontEnd/tests/test_query.py | 124 ++++++++++++++ 34 files changed, 1243 insertions(+), 11 deletions(-) create mode 100644 .env create mode 100644 .env.example create mode 100644 .github/workflows/frontend-test.yml create mode 100644 src/FrontEnd/.dockerignore create mode 100644 src/FrontEnd/.env create mode 100644 src/FrontEnd/Dockerfile create mode 100644 
src/FrontEnd/README.md create mode 100644 src/FrontEnd/app/__init__.py create mode 100644 src/FrontEnd/app/api/__init__.py create mode 100644 src/FrontEnd/app/api/v1/__init__.py create mode 100644 src/FrontEnd/app/api/v1/routes/__init__.py create mode 100644 src/FrontEnd/app/api/v1/routes/health.py create mode 100644 src/FrontEnd/app/api/v1/routes/query.py create mode 100644 src/FrontEnd/app/core/__init__.py create mode 100644 src/FrontEnd/app/core/config.py create mode 100644 src/FrontEnd/app/core/logging.py create mode 100644 src/FrontEnd/app/main.py create mode 100644 src/FrontEnd/app/models/__init__.py create mode 100644 src/FrontEnd/app/models/schemas.py create mode 100644 src/FrontEnd/docker-compose.yml create mode 100644 src/FrontEnd/generate_swagger.py delete mode 100644 src/FrontEnd/main.cpp create mode 100644 src/FrontEnd/pytest.ini create mode 100644 src/FrontEnd/requirements.txt create mode 100644 src/FrontEnd/swagger.yml create mode 100644 src/FrontEnd/tests/README.md create mode 100644 src/FrontEnd/tests/__init__.py create mode 100644 src/FrontEnd/tests/conftest.py create mode 100644 src/FrontEnd/tests/test_health.py create mode 100644 src/FrontEnd/tests/test_integration.py create mode 100644 src/FrontEnd/tests/test_query.py diff --git a/.env b/.env new file mode 100644 index 0000000..0070cc1 --- /dev/null +++ b/.env @@ -0,0 +1,10 @@ +# Server Configuration +PORT=8000 +HOST=0.0.0.0 + +# Logging Configuration +# Options: prod, info, debug +LOGGING_MODE=info + +# Testing Configuration +TEST_SERVER_URL=http://localhost:8000 diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..0070cc1 --- /dev/null +++ b/.env.example @@ -0,0 +1,10 @@ +# Server Configuration +PORT=8000 +HOST=0.0.0.0 + +# Logging Configuration +# Options: prod, info, debug +LOGGING_MODE=info + +# Testing Configuration +TEST_SERVER_URL=http://localhost:8000 diff --git a/.github/workflows/frontend-test.yml b/.github/workflows/frontend-test.yml new file mode 100644 
index 0000000..9208ef6 --- /dev/null +++ b/.github/workflows/frontend-test.yml @@ -0,0 +1,35 @@ +name: Frontend Tests + +on: + push: + branches: [ main, pre-release ] + paths: + - 'src/FrontEnd/**' + - '.github/workflows/frontend-test.yml' + pull_request: + branches: [ main, pre-release ] + paths: + - 'src/FrontEnd/**' + - '.github/workflows/frontend-test.yml' + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r src/FrontEnd/requirements.txt + + - name: Run tests + run: | + cd src/FrontEnd + pytest -v -m "not integration" diff --git a/Makefile b/Makefile index 0afcce1..0c24562 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help go-test rust-test go-run rust-run go-lint rust-lint go-fmt rust-fmt test-all lint-all fmt-all pre-push +.PHONY: help go-test rust-test go-run rust-run go-lint rust-lint go-fmt rust-fmt frontend-test frontend-run frontend-docker-build frontend-docker-run frontend-docker-down frontend-setup test-all lint-all fmt-all pre-push # Default target help: @@ -11,7 +11,13 @@ help: @echo " make rust-lint - Run Rust linter and formatter check" @echo " make go-fmt - Format Go code" @echo " make rust-fmt - Format Rust code" - @echo " make test-all - Run all tests (Go + Rust)" + @echo " make frontend-test - Run Python/Frontend tests" + @echo " make frontend-run - Run Frontend server (without Docker)" + @echo " make frontend-setup - Setup Python virtual environment and install dependencies" + @echo " make frontend-docker-build - Build Frontend Docker image" + @echo " make frontend-docker-run - Run Frontend using Docker Compose" + @echo " make frontend-docker-down - Stop Frontend Docker containers" + @echo " make test-all - Run all tests (Go + Rust + Frontend)" @echo " make lint-all - Run all linters (Go + Rust)" @echo " make 
fmt-all - Format all code (Go + Rust)" @echo " make pre-push - Run fmt, lint, and test (use before pushing)" @@ -71,8 +77,47 @@ rust-fmt-check: @echo "Checking Rust formatting..." cd src/Backend/opti-sql-rs && cargo fmt --check +# Frontend targets +frontend-setup: + @echo "Setting up Python virtual environment..." + rm -rf src/FrontEnd/venv + cd src/FrontEnd && python3.12 -m venv --without-pip venv + @echo "Installing pip..." + cd src/FrontEnd && . venv/bin/activate && curl -sS https://bootstrap.pypa.io/get-pip.py | python + @echo "Installing dependencies..." + cd src/FrontEnd && . venv/bin/activate && pip install --upgrade pip && pip install -r requirements.txt + @echo "Frontend setup completed! Activate with: cd src/FrontEnd && source venv/bin/activate" + +frontend-test: frontend-setup + @echo "Running Frontend/Python tests..." + cd src/FrontEnd && . venv/bin/activate && pytest -m "not integration" + +frontend-run: + @echo "Running Frontend server..." + cd src/FrontEnd && . venv/bin/activate && python -m uvicorn app.main:app --reload --host 0.0.0.0 --port 8005 + +frontend-docker-build: + @echo "Building Frontend Docker image..." + @if [ ! -f src/FrontEnd/.env ]; then \ + echo "Creating .env file from root .env..."; \ + cp .env src/FrontEnd/.env; \ + fi + cd src/FrontEnd && docker compose build + +frontend-docker-run: + @echo "Running Frontend with Docker Compose..." + @if [ ! -f src/FrontEnd/.env ]; then \ + echo "Creating .env file from root .env..."; \ + cp .env src/FrontEnd/.env; \ + fi + cd src/FrontEnd && docker compose up -d + +frontend-docker-down: + @echo "Stopping Frontend Docker containers..." + cd src/FrontEnd && docker compose down + # Combined targets -test-all: go-test rust-test +test-all: go-test rust-test frontend-test @echo "All tests completed!" lint-all: go-lint rust-lint diff --git a/README.md b/README.md index 061a643..52dedf9 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ A high-performance, in-memory query execution engine. 
![Go Tests](https://github.com/Rich-T-kid/OptiSQL/actions/workflows/go-test.yml/badge.svg) ![Rust Tests](https://github.com/Rich-T-kid/OptiSQL/actions/workflows/rust-test.yml/badge.svg) +![Frontend Tests](https://github.com/Rich-T-kid/OptiSQL/actions/workflows/frontend-test.yml/badge.svg) + ## Overview @@ -20,6 +22,8 @@ OptiSQL is a custom in-memory query execution engine. The backend (physical exec - Go 1.24+ - Rust 1.70+ - C++23 +- Python 3.11+ +- Docker 29+ - Make - git @@ -36,6 +40,13 @@ make go-run # Build and run Rust backend make rust-run +# Frontend setup and run +make frontend-setup # Create venv and install dependencies +make frontend-run # Run locally without Docker +# OR with Docker +make frontend-docker-build +make frontend-docker-run + # Run all tests make test-all @@ -58,7 +69,10 @@ OptiSQL/ │ │ └── opti-sql-rs/ # Rust implementation (Go clone for learning) │ │ ├── src/project/ # Core project logic │ │ └── src/ # Query processing modules -│ └── FrontEnd/ # C++ frontend (in development) +│ └── FrontEnd/ # Python/FastAPI HTTP server (C++ query processing in progress) +│ ├── app/ # API endpoints and logic +│ ├── tests/ # Frontend tests +│ └── Dockerfile # Docker configuration ├── .github/workflows/ # CI/CD pipelines ├── Makefile # Development commands └── CONTRIBUTING.md # Contribution guidelines diff --git a/src/Backend/opti-sql-go/go.mod b/src/Backend/opti-sql-go/go.mod index 49182e3..665dc8c 100644 --- a/src/Backend/opti-sql-go/go.mod +++ b/src/Backend/opti-sql-go/go.mod @@ -1,6 +1,6 @@ module opti-sql-go -go 1.24.0 +go 1.23 require ( github.com/apache/arrow/go/v17 v17.0.0 diff --git a/src/FrontEnd/.dockerignore b/src/FrontEnd/.dockerignore new file mode 100644 index 0000000..99162ba --- /dev/null +++ b/src/FrontEnd/.dockerignore @@ -0,0 +1,38 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +env/ +venv/ +ENV/ +*.egg-info/ +dist/ +build/ + +# Testing +.pytest_cache/ +.coverage +htmlcov/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# 
Git +.git/ +.gitignore + +# Documentation +*.md +tests/ + +# Other +.DS_Store +*.log +swagger.yml +generate_swagger.py +.env.example diff --git a/src/FrontEnd/.env b/src/FrontEnd/.env new file mode 100644 index 0000000..0070cc1 --- /dev/null +++ b/src/FrontEnd/.env @@ -0,0 +1,10 @@ +# Server Configuration +PORT=8000 +HOST=0.0.0.0 + +# Logging Configuration +# Options: prod, info, debug +LOGGING_MODE=info + +# Testing Configuration +TEST_SERVER_URL=http://localhost:8000 diff --git a/src/FrontEnd/Dockerfile b/src/FrontEnd/Dockerfile new file mode 100644 index 0000000..282a56c --- /dev/null +++ b/src/FrontEnd/Dockerfile @@ -0,0 +1,17 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Install dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . +COPY .env . + +# Expose port (will be read from .env) +EXPOSE 8000 + +# Run the application +CMD ["python", "-m", "app.main"] diff --git a/src/FrontEnd/README.md b/src/FrontEnd/README.md new file mode 100644 index 0000000..2131f8e --- /dev/null +++ b/src/FrontEnd/README.md @@ -0,0 +1,186 @@ +# OptiSQL Frontend API + +FastAPI server for SQL query processing and optimization.
+ +## Features + +- **Health Check Endpoint**: `/api/v1/health` +- **SQL Query Processing**: `/api/v1/query` + - Supports file uploads (CSV, JSON, Parquet, Excel) + - Supports file URIs (HTTP/HTTPS) + - Configurable logging levels + +## Project Structure + +``` +FrontEnd/ +├── app/ +│ ├── main.py # FastAPI application +│ ├── api/v1/routes/ # API endpoints +│ ├── core/ # Config and logging +│ └── models/ # Pydantic schemas +├── tests/ # Test suite +├── .env # Environment configuration +├── requirements.txt # Python dependencies +├── Dockerfile # Docker image +├── docker-compose.yml # Docker Compose config +└── pytest.ini # Pytest configuration +``` + +## Configuration + +Create a `.env` file (or copy from `.env.example`): + +```bash +cp .env.example .env +``` + +Edit `.env` to configure: + +```env +# Server Configuration +PORT=8000 +HOST=0.0.0.0 + +# Logging Configuration +# Options: prod, info, debug +LOGGING_MODE=info +``` + +## Running Locally + +### Install Dependencies + +```bash +pip install -r requirements.txt +``` + +### Run the Server + +```bash +python -m app.main +``` + +The server will start on the port specified in `.env` (default: 8000). + +### Access API Documentation + +- Swagger UI: http://localhost:8000/docs +- ReDoc: http://localhost:8000/redoc + +## Running with Docker + +### Build and Run + +```bash +docker compose up --build +``` + +### Run in Background + +```bash +docker compose up -d +``` + +### View Logs + +```bash +docker compose logs -f +``` + +### Stop the Service + +```bash +docker compose down +``` + +Note: Use `docker compose` (without hyphen) for Docker Compose V2.
+ +## Testing + + +### Run Unit Tests Only (No Server Required) + +```bash +pytest -m "not integration" +``` + +### Run Integration Tests (Requires Running Server) + +```bash +# Start server first +docker compose up -d + +# Run integration tests +pytest -m integration +``` + +### Run All Tests + +```bash +pytest +``` + +### Run with Verbose Output + +```bash +pytest -v +``` + +### Run Specific Tests + +```bash +pytest tests/test_health.py +pytest tests/test_query.py +pytest tests/test_integration.py +``` + +### Test Against Different Server + +```bash +TEST_SERVER_URL=http://localhost:9000 pytest -m integration +``` + +## API Endpoints + +### Health Check + +```bash +curl http://localhost:8000/api/v1/health +``` + +### SQL Query Processing (File Upload) + +```bash +curl -X POST http://localhost:8000/api/v1/query \ + -F "sql_query=SELECT * FROM data" \ + -F "file=@data.csv" +``` + +### SQL Query Processing (File URI) + +```bash +curl -X POST http://localhost:8000/api/v1/query \ + -F "sql_query=SELECT * FROM data" \ + -F "file_uri=https://example.com/data.csv" +``` + +## Logging Modes + +- **prod**: WARNING level, minimal output +- **info**: INFO level, standard logging +- **debug**: DEBUG level, detailed logs with file/line numbers + +## Development + +### Hot Reload + +When running with Docker Compose, the app directory is mounted as a volume, enabling hot-reload during development. + +### Generate Swagger YAML + +```bash +python generate_swagger.py +``` + +This generates `swagger.yml` with the OpenAPI specification. 
diff --git a/src/FrontEnd/app/__init__.py b/src/FrontEnd/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/FrontEnd/app/api/__init__.py b/src/FrontEnd/app/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/FrontEnd/app/api/v1/__init__.py b/src/FrontEnd/app/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/FrontEnd/app/api/v1/routes/__init__.py b/src/FrontEnd/app/api/v1/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/FrontEnd/app/api/v1/routes/health.py b/src/FrontEnd/app/api/v1/routes/health.py new file mode 100644 index 0000000..c46ba37 --- /dev/null +++ b/src/FrontEnd/app/api/v1/routes/health.py @@ -0,0 +1,17 @@ +from fastapi import APIRouter +from app.models.schemas import HealthResponse +import logging + +router = APIRouter() +logger = logging.getLogger(__name__) + +@router.get("/health", response_model=HealthResponse, tags=["Health"]) +async def health_check(): + """ + Health check endpoint to verify the service is running. 
+ + Returns: + HealthResponse: Current health status of the service + """ + logger.info("Health check requested") + return HealthResponse(status="healthy", version="0.1.0") diff --git a/src/FrontEnd/app/api/v1/routes/query.py b/src/FrontEnd/app/api/v1/routes/query.py new file mode 100644 index 0000000..9914824 --- /dev/null +++ b/src/FrontEnd/app/api/v1/routes/query.py @@ -0,0 +1,79 @@ +from fastapi import APIRouter, UploadFile, File, Form, HTTPException +from app.models.schemas import SQLQueryResponse +import logging +import time +from typing import Optional + +router = APIRouter() +logger = logging.getLogger(__name__) + +@router.post("/query", response_model=SQLQueryResponse, tags=["Query Processing"]) +async def process_sql_query( + sql_query: str = Form(..., description="SQL query to process"), + file_uri: Optional[str] = Form(None, description="URI to a remote file (optional if file is provided)"), + file: Optional[UploadFile] = File(default=None, description="Data file to process (optional if file_uri is provided)") +): + """ + Process a SQL query against an uploaded file or a file from a URI. + + Args: + sql_query: SQL query string to execute + file: Uploaded file (CSV, JSON, Parquet, etc.) - optional + file_uri: URI to a remote file (e.g., s3://bucket/file.csv, https://example.com/data.csv) - optional + + Returns: + SQLQueryResponse: Query processing results + + Note: + Either 'file' or 'file_uri' must be provided, but not both. 
+ """ + start_time = time.time() + + logger.info(f"Processing SQL query: {sql_query[:100]}...") + + # Normalize inputs: treat empty strings as None + if file_uri is not None and file_uri.strip() == "": + file_uri = None + + # Check if file is actually empty (no filename means no file uploaded) + if file is not None and (not file.filename or file.filename == ""): + file = None + + # Validate input: must have either file or file_uri + if file is None and file_uri is None: + raise HTTPException( + status_code=400, + detail="Either 'file' (uploaded file) or 'file_uri' (file URI) must be provided" + ) + + if file is not None and file_uri is not None: + raise HTTPException( + status_code=400, + detail="Cannot provide both 'file' and 'file_uri'. Please provide only one." + ) + + try: + + # TODO: Implement actual SQL query processing logic + + result = { + "query_length": len(sql_query), + "message": "Query processing not yet implemented" + } + + execution_time = (time.time() - start_time) * 1000 + + logger.info(f"Query processed successfully in {execution_time:.2f}ms") + + return SQLQueryResponse( + status="success", + query=sql_query, + result=result, + execution_time_ms=execution_time + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error processing query: {str(e)}") + raise HTTPException(status_code=500, detail=f"Query processing failed: {str(e)}") \ No newline at end of file diff --git a/src/FrontEnd/app/core/__init__.py b/src/FrontEnd/app/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/FrontEnd/app/core/config.py b/src/FrontEnd/app/core/config.py new file mode 100644 index 0000000..e59f6b4 --- /dev/null +++ b/src/FrontEnd/app/core/config.py @@ -0,0 +1,27 @@ +import os +from typing import Literal +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + +class Config: + def __init__(self): + pass + + @property + def port(self) -> int: + return int(os.getenv('PORT', 
'8000')) + + @property + def host(self) -> str: + return os.getenv('HOST', '0.0.0.0') + + @property + def logging_mode(self) -> Literal['prod', 'info', 'debug']: + mode = os.getenv('LOGGING_MODE', 'info').lower() + if mode not in ['prod', 'info', 'debug']: + return 'info' + return mode + +config = Config() diff --git a/src/FrontEnd/app/core/logging.py b/src/FrontEnd/app/core/logging.py new file mode 100644 index 0000000..1a0671f --- /dev/null +++ b/src/FrontEnd/app/core/logging.py @@ -0,0 +1,44 @@ +import logging +import sys +from typing import Literal + +def setup_logging(mode: Literal['prod', 'info', 'debug']) -> None: + """ + Configure logging based on the mode specified in the .env configuration + + Args: + mode: Logging mode - 'prod', 'info', or 'debug' + """ + # Define logging levels + level_map = { + 'prod': logging.WARNING, + 'info': logging.INFO, + 'debug': logging.DEBUG + } + + log_level = level_map.get(mode, logging.INFO) + + # Configure format based on mode + if mode == 'debug': + log_format = '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s' + elif mode == 'info': + log_format = '%(asctime)s - %(levelname)s - %(message)s' + else: # prod + log_format = '%(asctime)s - %(levelname)s - %(message)s' + + # Configure root logger + logging.basicConfig( + level=log_level, + format=log_format, + handlers=[ + logging.StreamHandler(sys.stdout) + ] + ) + + # Set uvicorn logger levels + logging.getLogger("uvicorn").setLevel(log_level) + logging.getLogger("uvicorn.access").setLevel(log_level) + logging.getLogger("uvicorn.error").setLevel(log_level) + + logger = logging.getLogger(__name__) + logger.info(f"Logging configured with mode: {mode} (level: {logging.getLevelName(log_level)})") diff --git a/src/FrontEnd/app/main.py b/src/FrontEnd/app/main.py new file mode 100644 index 0000000..f8fa204 --- /dev/null +++ b/src/FrontEnd/app/main.py @@ -0,0 +1,51 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from
app.core.config import config +from app.core.logging import setup_logging +from app.api.v1.routes import health, query +import logging + +# Setup logging +setup_logging(config.logging_mode) +logger = logging.getLogger(__name__) + +# Create FastAPI app +app = FastAPI( + title="OptiSQL Frontend API", + description="FastAPI server for SQL query processing and optimization", + version="0.1.0", + docs_url="/docs", + redoc_url="/redoc" +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Include routers +app.include_router(health.router, prefix="/api/v1") +app.include_router(query.router, prefix="/api/v1") + +@app.on_event("startup") +async def startup_event(): + logger.info("Starting OptiSQL Frontend API") + logger.info(f"Server configuration - Host: {config.host}, Port: {config.port}") + logger.info(f"Logging mode: {config.logging_mode}") + +@app.on_event("shutdown") +async def shutdown_event(): + logger.info("Shutting down OptiSQL Frontend API") + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app.main:app", + host=config.host, + port=config.port, + reload=True if config.logging_mode == "debug" else False + ) diff --git a/src/FrontEnd/app/models/__init__.py b/src/FrontEnd/app/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/FrontEnd/app/models/schemas.py b/src/FrontEnd/app/models/schemas.py new file mode 100644 index 0000000..ec626cb --- /dev/null +++ b/src/FrontEnd/app/models/schemas.py @@ -0,0 +1,32 @@ +from pydantic import BaseModel, Field +from typing import Optional, Any + +class HealthResponse(BaseModel): + status: str = Field(..., description="Health status of the service") + version: str = Field(default="0.1.0", description="API version") + +class SQLQueryRequest(BaseModel): + sql_query: str = Field(..., description="SQL query to process") + + class Config: + json_schema_extra = { + "example": { + 
"sql_query": "SELECT * FROM users WHERE age > 25" + } + } + +class SQLQueryResponse(BaseModel): + status: str = Field(..., description="Processing status") + query: str = Field(..., description="Original SQL query") + result: Optional[Any] = Field(None, description="Query processing result") + execution_time_ms: Optional[float] = Field(None, description="Execution time in milliseconds") + + class Config: + json_schema_extra = { + "example": { + "status": "success", + "query": "SELECT * FROM users WHERE age > 25", + "result": {"rows": 42}, + "execution_time_ms": 123.45 + } + } diff --git a/src/FrontEnd/docker-compose.yml b/src/FrontEnd/docker-compose.yml new file mode 100644 index 0000000..2afaa98 --- /dev/null +++ b/src/FrontEnd/docker-compose.yml @@ -0,0 +1,24 @@ +services: + fastapi: + build: + context: . + dockerfile: Dockerfile + container_name: optisql-frontend + ports: + - "${PORT:-8000}:${PORT:-8000}" + volumes: + # Mount the app directory + # TODO: need to add cpp library + - ./app:/app/app + - ./.env:/app/.env + env_file: + - .env + environment: + - PYTHONUNBUFFERED=1 + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:${PORT:-8000}/api/v1/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s diff --git a/src/FrontEnd/generate_swagger.py b/src/FrontEnd/generate_swagger.py new file mode 100644 index 0000000..7e0725c --- /dev/null +++ b/src/FrontEnd/generate_swagger.py @@ -0,0 +1,12 @@ +import yaml +import json +from app.main import app + +# Generate OpenAPI schema +openapi_schema = app.openapi() + +# Write to YAML file +with open('swagger.yml', 'w') as f: + yaml.dump(openapi_schema, f, default_flow_style=False, sort_keys=False) + +print("Swagger YAML file generated: swagger.yml") diff --git a/src/FrontEnd/main.cpp b/src/FrontEnd/main.cpp deleted file mode 100644 index fa5c04d..0000000 --- a/src/FrontEnd/main.cpp +++ /dev/null @@ -1,6 +0,0 @@ -#include - -int main() { - std::cout << "Hello World" << 
std::endl; - return 0; -} diff --git a/src/FrontEnd/pytest.ini b/src/FrontEnd/pytest.ini new file mode 100644 index 0000000..7193e46 --- /dev/null +++ b/src/FrontEnd/pytest.ini @@ -0,0 +1,9 @@ +[pytest] +testpaths = tests +python_files = test_*.py +python_functions = test_* +addopts = -v --tb=short + +markers = + integration: marks tests as integration tests (requires running server) + unit: marks tests as unit tests (no server required) diff --git a/src/FrontEnd/requirements.txt b/src/FrontEnd/requirements.txt new file mode 100644 index 0000000..c3d97d2 --- /dev/null +++ b/src/FrontEnd/requirements.txt @@ -0,0 +1,10 @@ +fastapi==0.109.0 +uvicorn[standard]==0.27.0 +pydantic==2.5.3 +python-multipart==0.0.6 +python-dotenv==1.0.0 + +# Testing +pytest==7.4.4 +pytest-asyncio==0.23.3 +httpx==0.26.0 diff --git a/src/FrontEnd/swagger.yml b/src/FrontEnd/swagger.yml new file mode 100644 index 0000000..5a5163a --- /dev/null +++ b/src/FrontEnd/swagger.yml @@ -0,0 +1,155 @@ +openapi: 3.1.0 +info: + title: OptiSQL Frontend API + description: FastAPI server for SQL query processing and optimization + version: 0.1.0 +paths: + /api/v1/health: + get: + tags: + - Health + summary: Health Check + description: "Health check endpoint to verify the service is running.\n\nReturns:\n\ + \ HealthResponse: Current health status of the service" + operationId: health_check_api_v1_health_get + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/HealthResponse' + /api/v1/query: + post: + tags: + - Query Processing + summary: Process Sql Query + description: "Process a SQL query against an uploaded file or a file from a\ + \ URI.\n\nArgs:\n sql_query: SQL query string to execute\n file: Uploaded\ + \ file (CSV, JSON, Parquet, etc.) 
- optional\n file_uri: URI to a remote\ + \ file (e.g., s3://bucket/file.csv, https://example.com/data.csv) - optional\n\ + \nReturns:\n SQLQueryResponse: Query processing results\n\nNote:\n Either\ + \ 'file' or 'file_uri' must be provided, but not both." + operationId: process_sql_query_api_v1_query_post + requestBody: + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_process_sql_query_api_v1_query_post' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/SQLQueryResponse' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' +components: + schemas: + Body_process_sql_query_api_v1_query_post: + properties: + sql_query: + type: string + title: Sql Query + description: SQL query to process + file_uri: + anyOf: + - type: string + - type: 'null' + title: File Uri + description: URI to a remote file (optional if file is provided) + file: + anyOf: + - type: string + format: binary + - type: 'null' + title: File + description: Data file to process (optional if file_uri is provided) + type: object + required: + - sql_query + title: Body_process_sql_query_api_v1_query_post + HTTPValidationError: + properties: + detail: + items: + $ref: '#/components/schemas/ValidationError' + type: array + title: Detail + type: object + title: HTTPValidationError + HealthResponse: + properties: + status: + type: string + title: Status + description: Health status of the service + version: + type: string + title: Version + description: API version + default: 0.1.0 + type: object + required: + - status + title: HealthResponse + SQLQueryResponse: + properties: + status: + type: string + title: Status + description: Processing status + query: + type: string + title: Query + description: Original SQL query + result: + anyOf: + - {} + - type: 'null' + title: Result + description: Query processing 
result + execution_time_ms: + anyOf: + - type: number + - type: 'null' + title: Execution Time Ms + description: Execution time in milliseconds + type: object + required: + - status + - query + title: SQLQueryResponse + example: + execution_time_ms: 123.45 + query: SELECT * FROM users WHERE age > 25 + result: + rows: 42 + status: success + ValidationError: + properties: + loc: + items: + anyOf: + - type: string + - type: integer + type: array + title: Location + msg: + type: string + title: Message + type: + type: string + title: Error Type + type: object + required: + - loc + - msg + - type + title: ValidationError diff --git a/src/FrontEnd/tests/README.md b/src/FrontEnd/tests/README.md new file mode 100644 index 0000000..f0d94b0 --- /dev/null +++ b/src/FrontEnd/tests/README.md @@ -0,0 +1,119 @@ +# Tests + +This directory contains tests for the OptiSQL FastAPI server. + +## Test Types + +### Unit Tests +Unit tests use FastAPI's `TestClient` and don't require a running server. They are fast and can be run anywhere. + +- `test_health.py` - Health endpoint tests +- `test_query.py` - Query endpoint tests + +### Integration Tests +Integration tests connect to a real running server. They test the full stack including networking, Docker, etc. + +- `test_integration.py` - Tests against a running server + +## Running Tests + +### Run all tests (unit only, no server required) +```bash +pytest -m "not integration" +``` + +### Run only unit tests +```bash +pytest tests/test_health.py tests/test_query.py +``` + +### Run integration tests (requires running server) + +First, start the server: +```bash +# Using Docker +docker compose up -d + +# Or locally +python -m app.main +``` + +Then run integration tests: +```bash +pytest -m integration +``` + +### Run all tests (unit + integration) +```bash +# Make sure server is running first! 
+docker compose up -d + +# Run all tests +pytest +``` + +### Run specific test file +```bash +pytest tests/test_health.py +pytest tests/test_query.py +pytest tests/test_integration.py +``` + +### Run with verbose output +```bash +pytest -v +``` + +### Run with coverage +```bash +pytest --cov=app --cov-report=html +``` + +## Environment Variables + +### `TEST_SERVER_URL` +Set this to test against a different server URL (default: `http://localhost:8000`) + +```bash +TEST_SERVER_URL=http://localhost:9000 pytest -m integration +``` + +## Complete Test Workflow + +```bash +# 1. Run unit tests (no server needed) +pytest -m "not integration" + +# 2. Start the server +docker compose up -d + +# 3. Wait for server to be ready +sleep 3 + +# 4. Run integration tests +pytest -m integration + +# 5. Run all tests together +pytest + +# 6. Stop the server +docker compose down +``` + +## Test Structure + +- `conftest.py` - Shared fixtures for all tests +- `test_health.py` - Unit tests for the health endpoint +- `test_query.py` - Unit tests for the SQL query processing endpoint +- `test_integration.py` - Integration tests against running server + +## Test Fixtures + +### `client` +A TestClient instance for making requests to the API without running the server. + +### `sample_csv_file` +A sample CSV file fixture for testing file uploads. + +### `sample_json_file` +A sample JSON file fixture for testing file uploads. diff --git a/src/FrontEnd/tests/__init__.py b/src/FrontEnd/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/FrontEnd/tests/conftest.py b/src/FrontEnd/tests/conftest.py new file mode 100644 index 0000000..b52365c --- /dev/null +++ b/src/FrontEnd/tests/conftest.py @@ -0,0 +1,26 @@ +import pytest +from fastapi.testclient import TestClient +from app.main import app + + +@pytest.fixture +def client(): + """ + Create a test client for the FastAPI application. + This client can be used to make requests to the API without running the server. 
+ """ + return TestClient(app) + + +@pytest.fixture +def sample_csv_file(): + """Create a sample CSV file content for testing.""" + content = b"name,age,city\nJohn,30,NYC\nJane,25,LA\nBob,35,Chicago" + return ("test_data.csv", content, "text/csv") + + +@pytest.fixture +def sample_json_file(): + """Create a sample JSON file content for testing.""" + content = b'[{"name":"John","age":30,"city":"NYC"},{"name":"Jane","age":25,"city":"LA"}]' + return ("test_data.json", content, "application/json") diff --git a/src/FrontEnd/tests/test_health.py b/src/FrontEnd/tests/test_health.py new file mode 100644 index 0000000..bda593d --- /dev/null +++ b/src/FrontEnd/tests/test_health.py @@ -0,0 +1,28 @@ +import pytest +from fastapi.testclient import TestClient + + +def test_health_endpoint(client: TestClient): + """Test that the health endpoint returns a successful response.""" + response = client.get("/api/v1/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "version" in data + + +def test_health_endpoint_structure(client: TestClient): + """Test that the health endpoint returns the correct structure.""" + response = client.get("/api/v1/health") + + assert response.status_code == 200 + data = response.json() + + # Check required fields + assert "status" in data + assert "version" in data + + # Check data types + assert isinstance(data["status"], str) + assert isinstance(data["version"], str) diff --git a/src/FrontEnd/tests/test_integration.py b/src/FrontEnd/tests/test_integration.py new file mode 100644 index 0000000..9c8ddaa --- /dev/null +++ b/src/FrontEnd/tests/test_integration.py @@ -0,0 +1,116 @@ +import pytest +import httpx +import os +import io + + +# Get server URL from environment or use default +SERVER_URL = os.getenv("TEST_SERVER_URL", "http://localhost:8000") + + +@pytest.mark.integration +def test_server_is_running(): + """Test that the server is accessible.""" + try: + response = 
httpx.get(f"{SERVER_URL}/api/v1/health", timeout=5.0) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + except httpx.ConnectError: + pytest.fail(f"Could not connect to server at {SERVER_URL}. Is it running?") + + +@pytest.mark.integration +def test_health_endpoint_live(): + """Test the health endpoint on a running server.""" + response = httpx.get(f"{SERVER_URL}/api/v1/health", timeout=5.0) + + assert response.status_code == 200 + data = response.json() + + assert data["status"] == "healthy" + assert "version" in data + assert isinstance(data["version"], str) + + +@pytest.mark.integration +def test_query_endpoint_with_file_live(): + """Test the query endpoint with file upload on a running server.""" + # Create a test CSV file + csv_content = b"name,age,city\nJohn,30,NYC\nJane,25,LA" + + files = {"file": ("test.csv", io.BytesIO(csv_content), "text/csv")} + data = {"sql_query": "SELECT * FROM data"} + + response = httpx.post( + f"{SERVER_URL}/api/v1/query", + files=files, + data=data, + timeout=10.0 + ) + + assert response.status_code == 200 + result = response.json() + + assert result["status"] == "success" + assert result["query"] == "SELECT * FROM data" + assert "execution_time_ms" in result + + +@pytest.mark.integration +def test_query_endpoint_with_uri_live(): + """Test the query endpoint with file URI on a running server.""" + data = { + "sql_query": "SELECT * FROM data", + "file_uri": "https://example.com/data.csv" + } + + response = httpx.post( + f"{SERVER_URL}/api/v1/query", + data=data, + timeout=10.0 + ) + + assert response.status_code == 200 + result = response.json() + + assert result["status"] == "success" + assert result["query"] == "SELECT * FROM data" + + +@pytest.mark.integration +def test_query_endpoint_validation_live(): + """Test that validation works on a running server.""" + # Missing both file and file_uri + response = httpx.post( + f"{SERVER_URL}/api/v1/query", + data={"sql_query": "SELECT * 
FROM data"}, + timeout=10.0 + ) + + assert response.status_code == 400 + error = response.json() + assert "detail" in error + + +@pytest.mark.integration +def test_openapi_docs_accessible(): + """Test that the OpenAPI/Swagger documentation is accessible.""" + response = httpx.get(f"{SERVER_URL}/docs", timeout=5.0) + assert response.status_code == 200 + + response = httpx.get(f"{SERVER_URL}/redoc", timeout=5.0) + assert response.status_code == 200 + + +@pytest.mark.integration +def test_openapi_schema_available(): + """Test that the OpenAPI schema is available.""" + response = httpx.get(f"{SERVER_URL}/openapi.json", timeout=5.0) + assert response.status_code == 200 + + schema = response.json() + assert "openapi" in schema + assert "paths" in schema + assert "/api/v1/health" in schema["paths"] + assert "/api/v1/query" in schema["paths"] diff --git a/src/FrontEnd/tests/test_query.py b/src/FrontEnd/tests/test_query.py new file mode 100644 index 0000000..cf99cda --- /dev/null +++ b/src/FrontEnd/tests/test_query.py @@ -0,0 +1,124 @@ +import pytest +from fastapi.testclient import TestClient +import io + + +def test_query_endpoint_with_file(client: TestClient, sample_csv_file): + """Test the query endpoint with a file upload.""" + filename, content, content_type = sample_csv_file + + response = client.post( + "/api/v1/query", + data={"sql_query": "SELECT * FROM data"}, + files={"file": (filename, io.BytesIO(content), content_type)} + ) + + assert response.status_code == 200 + data = response.json() + + assert data["status"] == "success" + assert data["query"] == "SELECT * FROM data" + assert "result" in data + assert "execution_time_ms" in data + + +def test_query_endpoint_with_file_uri(client: TestClient): + """Test the query endpoint with a file URI.""" + response = client.post( + "/api/v1/query", + data={ + "sql_query": "SELECT * FROM data", + "file_uri": "https://example.com/data.csv" + } + ) + + assert response.status_code == 200 + data = response.json() + + assert 
data["status"] == "success" + assert data["query"] == "SELECT * FROM data" + assert "result" in data + assert "execution_time_ms" in data + + +def test_query_endpoint_missing_both_file_and_uri(client: TestClient): + """Test that the endpoint rejects requests without file or file_uri.""" + response = client.post( + "/api/v1/query", + data={"sql_query": "SELECT * FROM data"} + ) + + assert response.status_code == 400 + data = response.json() + assert "detail" in data + assert "file" in data["detail"].lower() or "uri" in data["detail"].lower() + + +def test_query_endpoint_with_both_file_and_uri(client: TestClient, sample_csv_file): + """Test that the endpoint rejects requests with both file and file_uri.""" + filename, content, content_type = sample_csv_file + + response = client.post( + "/api/v1/query", + data={ + "sql_query": "SELECT * FROM data", + "file_uri": "https://example.com/data.csv" + }, + files={"file": (filename, io.BytesIO(content), content_type)} + ) + + assert response.status_code == 400 + data = response.json() + assert "detail" in data + + +def test_query_endpoint_missing_sql_query(client: TestClient, sample_csv_file): + """Test that the endpoint requires sql_query parameter.""" + filename, content, content_type = sample_csv_file + + response = client.post( + "/api/v1/query", + files={"file": (filename, io.BytesIO(content), content_type)} + ) + + assert response.status_code == 422 # Validation error + + +def test_query_endpoint_response_structure(client: TestClient, sample_csv_file): + """Test that the query endpoint returns the correct response structure.""" + filename, content, content_type = sample_csv_file + + response = client.post( + "/api/v1/query", + data={"sql_query": "SELECT * FROM data WHERE age > 25"}, + files={"file": (filename, io.BytesIO(content), content_type)} + ) + + assert response.status_code == 200 + data = response.json() + + # Check required fields + assert "status" in data + assert "query" in data + assert "result" in data + 
assert "execution_time_ms" in data + + # Check data types + assert isinstance(data["status"], str) + assert isinstance(data["query"], str) + assert isinstance(data["execution_time_ms"], (int, float)) + + +def test_query_endpoint_with_json_file(client: TestClient, sample_json_file): + """Test the query endpoint with a JSON file upload.""" + filename, content, content_type = sample_json_file + + response = client.post( + "/api/v1/query", + data={"sql_query": "SELECT * FROM data"}, + files={"file": (filename, io.BytesIO(content), content_type)} + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "success" From 2ff516ea6bb4703e4cded3a57f5cbcc833247d4a Mon Sep 17 00:00:00 2001 From: Marco Ferreira Date: Wed, 19 Nov 2025 20:23:02 -0500 Subject: [PATCH 15/19] docs(Contributing) added frontend test information to Contributing.md --- CONTRIBUTING.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7e5fd6c..61a39a1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -26,7 +26,13 @@ We use a Makefile to simplify common development tasks. 
All commands should be r make rust-test ``` -### Run All Tests (Go + Rust) +### Frontend Tests +- Run all tests + ```bash + make frontend-test + ``` + +### Run All Tests (Go + Rust + Frontend) - Run tests for both backends ```bash make test-all From a32b36645521aed4c0282e6e0ef8c18951dfaa93 Mon Sep 17 00:00:00 2001 From: Marco Ferreira Date: Wed, 19 Nov 2025 20:32:21 -0500 Subject: [PATCH 16/19] docs(Contributing) added example information to Contributing.md --- CONTRIBUTING.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 61a39a1..0e212de 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -116,6 +116,11 @@ make go-run make rust-run ``` +### Build and Run Frontend +```bash +make frontend-run +``` + ### Run All Tests ```bash make test-all @@ -125,6 +130,7 @@ Or run individually: ```bash make go-test make rust-test +make frontend-test ``` ### Run Linters From 0d2369c649a6c297216201e5f6e72e00607f48ed Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Wed, 19 Nov 2025 21:54:57 -0500 Subject: [PATCH 17/19] feat:Implement limitExec operator & include test --- .../opti-sql-go/operators/filter/filter.go | 21 -- .../opti-sql-go/operators/filter/limit.go | 72 +++++++ .../operators/filter/limit_test.go | 204 +++++++++++++++++- 3 files changed, 274 insertions(+), 23 deletions(-) diff --git a/src/Backend/opti-sql-go/operators/filter/filter.go b/src/Backend/opti-sql-go/operators/filter/filter.go index 4195cdd..745349e 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter.go +++ b/src/Backend/opti-sql-go/operators/filter/filter.go @@ -1,24 +1,3 @@ package filter -import ( - "github.com/apache/arrow/go/v17/arrow" - "github.com/apache/arrow/go/v17/arrow/array" -) - // FilterExpr takes in a field and column and yeilds a function that takes in an index and returns a bool indicating whether the row at that index satisfies the filter condition. 
-type FilterExpr func(filed arrow.Field, col arrow.Array) func(i int) bool - -// example -func ExampleFilterExpr(field arrow.Field, col arrow.Array) func(i int) bool { - { - if field.Name == "age" && col.DataType().ID() == arrow.INT32 { - return func(i int) bool { - val := col.(*array.Int32).Value(i) - return val > 30 - } - } - return func(i int) bool { - return true - } - } -} diff --git a/src/Backend/opti-sql-go/operators/filter/limit.go b/src/Backend/opti-sql-go/operators/filter/limit.go index 4a28b11..68bbb41 100644 --- a/src/Backend/opti-sql-go/operators/filter/limit.go +++ b/src/Backend/opti-sql-go/operators/filter/limit.go @@ -1 +1,73 @@ package filter + +import ( + "io" + "opti-sql-go/operators" + + "github.com/apache/arrow/go/v17/arrow" +) + +var ( + _ = (operators.Operator)(&LimitExec{}) +) + +// TODO: (1) Implement Limter Exec Operator | pretty straightforward +type LimitExec struct { + input operators.Operator + schema *arrow.Schema + remaining uint16 + done bool +} + +func NewLimitExec(input operators.Operator, count uint16) (*LimitExec, error) { + return &LimitExec{ + input: input, + schema: input.Schema(), + remaining: count, + }, nil +} + +func (l *LimitExec) Next(n uint16) (*operators.RecordBatch, error) { + if n == 0 { + return &operators.RecordBatch{ + Schema: l.schema, + Columns: []arrow.Array{}, + RowCount: 0, + }, nil + } + if l.remaining == 0 { + return nil, io.EOF + } + var childN uint16 + switch { + case n < l.remaining: + // We can fulfill the request fully + childN = n + l.remaining -= n + + case n == l.remaining: + // Exact request - done afterwards + childN = n + l.remaining = 0 + l.done = true + + case n > l.remaining: + // Only have l.remaining left + childN = l.remaining + l.remaining = 0 + l.done = true + } + childBatch, err := l.input.Next(childN) + if err != nil { + return nil, err + } + return childBatch, nil +} +func (l *LimitExec) Schema() *arrow.Schema { + return l.schema +} + +// nothing to close +func (l *LimitExec) Close() 
error { + return nil +} diff --git a/src/Backend/opti-sql-go/operators/filter/limit_test.go b/src/Backend/opti-sql-go/operators/filter/limit_test.go index dc4683b..480faeb 100644 --- a/src/Backend/opti-sql-go/operators/filter/limit_test.go +++ b/src/Backend/opti-sql-go/operators/filter/limit_test.go @@ -1,7 +1,207 @@ package filter -import "testing" +import ( + "errors" + "io" + "opti-sql-go/operators/project" + "testing" +) -func TestLimit(t *testing.T) { +func generateTestColumns() ([]string, []any) { + names := []string{ + "id", + "name", + "age", + "salary", + "is_active", + "department", + "rating", + "years_experience", + } + + columns := []any{ + []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + []string{ + "Alice", "Bob", "Charlie", "David", "Eve", + "Frank", "Grace", "Hannah", "Ivy", "Jake", + }, + []int32{28, 34, 45, 22, 31, 29, 40, 36, 50, 26}, + []float64{ + 70000.0, 82000.5, 54000.0, 91000.0, 60000.0, + 75000.0, 66000.0, 88000.0, 45000.0, 99000.0, + }, + []bool{true, false, true, true, false, false, true, true, false, true}, + []string{ + "Engineering", "HR", "Engineering", "Sales", "Finance", + "Sales", "Support", "Engineering", "HR", "Finance", + }, + []float32{4.5, 3.8, 4.2, 2.9, 5.0, 4.3, 3.7, 4.9, 4.1, 3.5}, + []int32{1, 5, 10, 2, 7, 3, 6, 12, 4, 8}, + } + + return names, columns +} +func basicProject() *project.InMemorySource { + names, col := generateTestColumns() + v, _ := project.NewInMemoryProjectExec(names, col) + return v +} +func TestLimitInit(t *testing.T) { // Simple passing test + trialProject := basicProject() + _, err := NewLimitExec(trialProject, 4) + if err != nil { + t.Fatalf("error creating LimitExec :%v", err) + } +} + +func TestLimitExec_InitAndSchema(t *testing.T) { + t.Run("Init OK", func(t *testing.T) { + proj := basicProject() + lim, err := NewLimitExec(proj, 5) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if lim.Schema() == nil { + t.Fatalf("expected non-nil schema") + } + }) + + t.Run("Init Zero Limit", func(t 
*testing.T) { + proj := basicProject() + lim, err := NewLimitExec(proj, 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + _, err = lim.Next(3) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF for zero limit, got %v", err) + } + }) + + t.Run("Init DoesNotModifyUnderlyingSchema", func(t *testing.T) { + proj := basicProject() + origSchema := proj.Schema() + lim, _ := NewLimitExec(proj, 10) + + if !lim.Schema().Equal(origSchema) { + t.Fatalf("schema mismatch: expected %v got %v", origSchema, lim.Schema()) + } + }) +} + +func TestLimitExec_NextBehavior(t *testing.T) { + names, cols := generateTestColumns() + memSrc, err := project.NewInMemoryProjectExec(names, cols) + if err != nil { + t.Fatalf("failed to init memory source: %v", err) + } + + t.Run("n < remaining", func(t *testing.T) { + lim, _ := NewLimitExec(memSrc, 5) + rb, err := lim.Next(3) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rb.RowCount != 3 { + t.Fatalf("expected 3 rows, got %d", rb.RowCount) + } + }) + + t.Run("n == remaining", func(t *testing.T) { + memSrc2, _ := project.NewInMemoryProjectExec(names, cols) + lim, _ := NewLimitExec(memSrc2, 4) + rb, err := lim.Next(4) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rb.RowCount != 4 { + t.Fatalf("expected 4 rows, got %d", rb.RowCount) + } + _, err = lim.Next(2) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF after exact match, got %v", err) + } + }) + + t.Run("n > remaining", func(t *testing.T) { + memSrc3, _ := project.NewInMemoryProjectExec(names, cols) + lim, _ := NewLimitExec(memSrc3, 3) + rb, err := lim.Next(10) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + if rb.RowCount != 3 { + t.Fatalf("expected 3 rows, got %d", rb.RowCount) + } + _, err = lim.Next(10) + if !errors.Is(err, io.EOF) { + t.Fatalf("was expecting io.EOF but recieved %v", err) + } + }) +} +func TestLimitExec_IterationUntilEOF(t *testing.T) { + names, cols := generateTestColumns() + memSrc, _ := 
project.NewInMemoryProjectExec(names, cols) + + t.Run("ConsumeInMultipleBatches", func(t *testing.T) { + lim, _ := NewLimitExec(memSrc, 7) + + total := 0 + for { + rb, err := lim.Next(3) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Fatalf("unexpected error: %v", err) + } + + total += int(rb.RowCount) + + for _, c := range rb.Columns { + c.Release() + } + } + if total != 7 { + t.Fatalf("expected 7 total rows, got %d", total) + } + lim.Close() + }) + + t.Run("RequestZeroDoesNotChangeLimit", func(t *testing.T) { + memSrc2, _ := project.NewInMemoryProjectExec(names, cols) + lim, _ := NewLimitExec(memSrc2, 5) + + rb, err := lim.Next(0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rb.RowCount != 0 { + t.Fatalf("expected zero rowcount, got %d", rb.RowCount) + } + + rb2, err := lim.Next(2) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rb2.RowCount != 2 { + t.Fatalf("expected 2 rows, got %d", rb2.RowCount) + } + lim.Close() + }) + + t.Run("AfterEOFAlwaysEOF", func(t *testing.T) { + memSrc3, _ := project.NewInMemoryProjectExec(names, cols) + lim, _ := NewLimitExec(memSrc3, 2) + + lim.Next(3) // exhaust + + _, err := lim.Next(1) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF, got %v", err) + } + lim.Close() + }) } From 0f8a49ffcc12de00154dd18c516672fa7de951fb Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Thu, 20 Nov 2025 11:38:18 -0500 Subject: [PATCH 18/19] feat: Implement filter operator with support for expressions --- CONTRIBUTING.md | 2 +- src/Backend/opti-sql-go/Expr/expr.go | 110 ++++-- src/Backend/opti-sql-go/Expr/expr_test.go | 318 +++++++++++++++-- .../opti-sql-go/operators/filter/filter.go | 134 ++++++++ .../operators/filter/filter_test.go | 320 +++++++++++++++++- .../operators/filter/limit_test.go | 12 +- .../opti-sql-go/operators/project/parquet.go | 41 ++- .../operators/project/parquet_test.go | 20 +- .../operators/project/projectExec.go | 11 +- 
.../operators/project/projectExecExpr_test.go | 2 +- .../operators/project/projectExec_test.go | 20 -- src/Backend/opti-sql-go/operators/record.go | 7 + 12 files changed, 893 insertions(+), 104 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d4f60d4..559fd67 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -21,7 +21,7 @@ We use a Makefile to simplify common development tasks. All commands should be r ``` - Run test with html coverage ```bash - go test ./... -coverprofile=coverage.out + go test -count=1 ./... -coverprofile=coverage.out go tool cover -html=coverage.out ``` diff --git a/src/Backend/opti-sql-go/Expr/expr.go b/src/Backend/opti-sql-go/Expr/expr.go index a7c9daf..0829819 100644 --- a/src/Backend/opti-sql-go/Expr/expr.go +++ b/src/Backend/opti-sql-go/Expr/expr.go @@ -16,6 +16,9 @@ var ( ErrUnsupportedExpression = func(info string) error { return fmt.Errorf("unsupported expression passed to EvalExpression: %s", info) } + ErrCantCompareDifferentTypes = func(leftType, rightType arrow.DataType) error { + return fmt.Errorf("cannot compare different data types: %s and %s", leftType, rightType) + } ) type binaryOperator int @@ -36,7 +39,6 @@ const ( // logical And binaryOperator = 12 Or binaryOperator = 13 - Not binaryOperator = 14 ) type supportedFunctions int @@ -103,36 +105,45 @@ func EvalExpression(expr Expression, batch *operators.RecordBatch) (arrow.Array, } } -func ExprDataType(e Expression, inputSchema *arrow.Schema) arrow.DataType { +func ExprDataType(e Expression, inputSchema *arrow.Schema) (arrow.DataType, error) { switch ex := e.(type) { case *LiteralResolve: - return ex.Type + return ex.Type, nil case *ColumnResolve: idx := inputSchema.FieldIndices(ex.Name) if len(idx) == 0 { - panic(fmt.Sprintf("exprDataType: unknown column %q", ex.Name)) + return nil, fmt.Errorf("exprDataType: unknown column %q", ex.Name) } - return inputSchema.Field(idx[0]).Type + return inputSchema.Field(idx[0]).Type, nil case *Alias: // alias does NOT 
change type return ExprDataType(ex.Expr, inputSchema) case *CastExpr: - return ex.TargetType + return ex.TargetType, nil case *BinaryExpr: - leftType := ExprDataType(ex.Left, inputSchema) - rightType := ExprDataType(ex.Right, inputSchema) - return inferBinaryType(leftType, ex.Op, rightType) + leftType, err := ExprDataType(ex.Left, inputSchema) + if err != nil { + return nil, err + } + rightType, err := ExprDataType(ex.Right, inputSchema) + if err != nil { + return nil, err + } + return inferBinaryType(leftType, ex.Op, rightType), nil case *ScalarFunction: - argType := ExprDataType(ex.Arguments, inputSchema) - return inferScalarFunctionType(ex.Function, argType) + argType, err := ExprDataType(ex.Arguments, inputSchema) + if err != nil { + return nil, err + } + return inferScalarFunctionType(ex.Function, argType), nil default: - panic(fmt.Sprintf("unsupported expr type %T", ex)) + return nil, ErrUnsupportedExpression(ex.String()) } } func NewExpressions(exprs ...Expression) []Expression { @@ -403,25 +414,80 @@ func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error return unpackDatum(datum) // comparisions TODO: + // These return a boolean array case Equal: - return nil, fmt.Errorf("operator Equal (%d) not yet implemented", b.Op) + if leftArr.DataType() != rightArr.DataType() { + return nil, ErrCantCompareDifferentTypes(leftArr.DataType(), rightArr.DataType()) + } + datum, err := compute.CallFunction(context.Background(), "equal", compute.DefaultFilterOptions(), compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) case NotEqual: - return nil, fmt.Errorf("operator NotEqual (%d) not yet implemented", b.Op) + if leftArr.DataType() != rightArr.DataType() { + return nil, ErrCantCompareDifferentTypes(leftArr.DataType(), rightArr.DataType()) + } + datum, err := compute.CallFunction(context.Background(), "not_equal", compute.DefaultFilterOptions(), compute.NewDatum(leftArr), 
compute.NewDatum(rightArr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) case LessThan: - return nil, fmt.Errorf("operator LessThan (%d) not yet implemented", b.Op) + if leftArr.DataType() != rightArr.DataType() { + return nil, ErrCantCompareDifferentTypes(leftArr.DataType(), rightArr.DataType()) + } + datum, err := compute.CallFunction(context.Background(), "less", compute.DefaultFilterOptions(), compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) case LessThanOrEqual: - return nil, fmt.Errorf("operator LessThanOrEqual (%d) not yet implemented", b.Op) + if leftArr.DataType() != rightArr.DataType() { + return nil, ErrCantCompareDifferentTypes(leftArr.DataType(), rightArr.DataType()) + } + datum, err := compute.CallFunction(context.Background(), "less_equal", compute.DefaultFilterOptions(), compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) case GreaterThan: - return nil, fmt.Errorf("operator GreaterThan (%d) not yet implemented", b.Op) + if leftArr.DataType() != rightArr.DataType() { + return nil, ErrCantCompareDifferentTypes(leftArr.DataType(), rightArr.DataType()) + } + datum, err := compute.CallFunction(context.Background(), "greater", compute.DefaultFilterOptions(), compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) case GreaterThanOrEqual: - return nil, fmt.Errorf("operator GreaterThanOrEqual (%d) not yet implemented", b.Op) + if leftArr.DataType() != rightArr.DataType() { + return nil, ErrCantCompareDifferentTypes(leftArr.DataType(), rightArr.DataType()) + } + datum, err := compute.CallFunction(context.Background(), "greater_equal", compute.DefaultFilterOptions(), compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) // logical case And: - return nil, 
fmt.Errorf("operator And (%d) not yet implemented", b.Op) + if leftArr.DataType() != rightArr.DataType() { + return nil, ErrCantCompareDifferentTypes(leftArr.DataType(), rightArr.DataType()) + } + datum, err := compute.CallFunction(context.Background(), "and", compute.DefaultFilterOptions(), compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) case Or: - return nil, fmt.Errorf("operator Or (%d) not yet implemented", b.Op) - case Not: - return nil, fmt.Errorf("operator Not (%d) not yet implemented", b.Op) + if leftArr.DataType() != rightArr.DataType() { + return nil, ErrCantCompareDifferentTypes(leftArr.DataType(), rightArr.DataType()) + } + datum, err := compute.CallFunction(context.Background(), "or", compute.DefaultFilterOptions(), compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + if err != nil { + return nil, err + } + return unpackDatum(datum) } return nil, fmt.Errorf("binary operator %d not supported", b.Op) } diff --git a/src/Backend/opti-sql-go/Expr/expr_test.go b/src/Backend/opti-sql-go/Expr/expr_test.go index f5555d7..54aaa24 100644 --- a/src/Backend/opti-sql-go/Expr/expr_test.go +++ b/src/Backend/opti-sql-go/Expr/expr_test.go @@ -984,7 +984,7 @@ func TestExprDataType(t *testing.T) { t.Run("Literal_ReturnsType", func(t *testing.T) { lit := &LiteralResolve{Type: arrow.PrimitiveTypes.Int32} - got := ExprDataType(lit, nil) + got, _ := ExprDataType(lit, nil) if got.ID() != arrow.INT32 { t.Fatalf("expected INT32, got %s", got) } @@ -992,24 +992,22 @@ func TestExprDataType(t *testing.T) { t.Run("Column_ReturnsSchemaType", func(t *testing.T) { col := &ColumnResolve{Name: "age"} - got := ExprDataType(col, schema) + got, _ := ExprDataType(col, schema) if got.ID() != arrow.INT32 { t.Fatalf("expected INT32, got %s", got) } }) t.Run("Column_Unknown_Panics", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Fatalf("expected panic for unknown column, got none") - } - }() 
- _ = ExprDataType(&ColumnResolve{Name: "missing"}, schema) + _, err := ExprDataType(&ColumnResolve{Name: "missing"}, schema) + if err == nil { + t.Fatalf("expected error for unknown column, got none") + } }) t.Run("Alias_PreservesType", func(t *testing.T) { a := &Alias{Expr: &LiteralResolve{Type: arrow.PrimitiveTypes.Float64}, Name: "f"} - got := ExprDataType(a, schema) + got, _ := ExprDataType(a, schema) if got.ID() != arrow.FLOAT64 { t.Fatalf("expected FLOAT64, got %s", got) } @@ -1017,7 +1015,7 @@ func TestExprDataType(t *testing.T) { t.Run("Cast_ReturnsTargetType", func(t *testing.T) { c := &CastExpr{Expr: &LiteralResolve{Type: arrow.PrimitiveTypes.Int32}, TargetType: arrow.PrimitiveTypes.Float64} - got := ExprDataType(c, schema) + got, _ := ExprDataType(c, schema) if got.ID() != arrow.FLOAT64 { t.Fatalf("expected FLOAT64, got %s", got) } @@ -1025,7 +1023,7 @@ func TestExprDataType(t *testing.T) { t.Run("Binary_Arithmetic_PromotesToFloat64", func(t *testing.T) { be := &BinaryExpr{Left: &LiteralResolve{Type: arrow.PrimitiveTypes.Int32}, Op: Addition, Right: &LiteralResolve{Type: arrow.PrimitiveTypes.Int32}} - got := ExprDataType(be, schema) + got, _ := ExprDataType(be, schema) if got.ID() != arrow.FLOAT64 { t.Fatalf("expected FLOAT64 from numericPromotion, got %s", got) } @@ -1033,7 +1031,7 @@ func TestExprDataType(t *testing.T) { t.Run("Binary_Comparison_ReturnsBoolean", func(t *testing.T) { be := &BinaryExpr{Left: &LiteralResolve{Type: arrow.PrimitiveTypes.Int32}, Op: Equal, Right: &LiteralResolve{Type: arrow.PrimitiveTypes.Int32}} - got := ExprDataType(be, schema) + got, _ := ExprDataType(be, schema) if got.ID() != arrow.BOOL { t.Fatalf("expected BOOL from comparison, got %s", got) } @@ -1041,20 +1039,18 @@ func TestExprDataType(t *testing.T) { t.Run("ScalarFunction_Upper_String", func(t *testing.T) { sf := &ScalarFunction{Function: Upper, Arguments: &LiteralResolve{Type: arrow.BinaryTypes.String}} - got := ExprDataType(sf, schema) + got, _ := 
ExprDataType(sf, schema) if got.ID() != arrow.STRING { t.Fatalf("expected STRING from Upper, got %s", got) } }) t.Run("UnsupportedExpr_Panics", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Fatalf("expected panic for unsupported expr type, got none") - } - }() // InvariantExpr is a test-only expr that should hit the default case - _ = ExprDataType(&InvariantExpr{}, schema) + _, err := ExprDataType(&InvariantExpr{}, schema) + if err == nil { + t.Fatalf("expected error for unsupported expr type, got none") + } }) } @@ -1080,15 +1076,6 @@ func TestInferBinaryType(t *testing.T) { } }) - t.Run("UnsupportedOperator_Panics", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Fatalf("expected panic for unsupported operator, got none") - } - }() - // Use 'Not' which is not handled in inferBinaryType switch and should panic - _ = inferBinaryType(arrow.PrimitiveTypes.Int32, Not, arrow.PrimitiveTypes.Int32) - }) } func TestNumericPromotion(t *testing.T) { @@ -1192,9 +1179,276 @@ func TestExprInitMethods(t *testing.T) { } }) } +func TestFilterBinaryExpr(t *testing.T) { + t.Run("age == 22", func(t *testing.T) { + rc := generateTestColumns() //4 + literal := NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(22)) + col := NewColumnResolve("age") + be := NewBinaryExpr(col, Equal, literal) + arr, err := EvalExpression(be, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out := arr.(*array.Boolean) + t.Logf("out:%v\n", out) + expected := []bool{false, false, false, true} + if len(expected) != out.Len() { + t.Fatalf("length mismatch: expected %d, got %d", len(expected), out.Len()) + } + for i := 0; i < out.Len(); i++ { + if out.Value(i) != expected[i] { + t.Fatalf("at index %d: expected %v, got %v", i, expected[i], out.Value(i)) + } + } + }) + t.Run("age != 22", func(t *testing.T) { + rc := generateTestColumns() + literal := NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(22)) + col := NewColumnResolve("age") 
+ be := NewBinaryExpr(col, NotEqual, literal) -func TestUnimplemntedOperators(t *testing.T) { - for i := Equal; i <= Or; i++ { - NewBinaryExpr(NewLiteralResolve(arrow.PrimitiveTypes.Int16, int16(10)), 3, NewLiteralResolve(arrow.PrimitiveTypes.Int16, int16(5))) - } + arr, err := EvalExpression(be, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := arr.(*array.Boolean) + expected := []bool{true, true, true, false} + + if out.Len() != len(expected) { + t.Fatalf("length mismatch: expected %d, got %d", len(expected), out.Len()) + } + for i := 0; i < out.Len(); i++ { + if out.Value(i) != expected[i] { + t.Fatalf("at index %d: expected %v, got %v", i, expected[i], out.Value(i)) + } + } + }) + t.Run("age < 34", func(t *testing.T) { + rc := generateTestColumns() + literal := NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(34)) + col := NewColumnResolve("age") + be := NewBinaryExpr(col, LessThan, literal) + + arr, err := EvalExpression(be, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := arr.(*array.Boolean) + expected := []bool{true, false, false, true} + + if out.Len() != len(expected) { + t.Fatalf("length mismatch: expected %d, got %d", len(expected), out.Len()) + } + for i := 0; i < out.Len(); i++ { + if out.Value(i) != expected[i] { + t.Fatalf("index %d expected %v got %v", i, expected[i], out.Value(i)) + } + } + }) + t.Run("age <= 34", func(t *testing.T) { + rc := generateTestColumns() + literal := NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(34)) + col := NewColumnResolve("age") + be := NewBinaryExpr(col, LessThanOrEqual, literal) + + arr, err := EvalExpression(be, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := arr.(*array.Boolean) + expected := []bool{true, true, false, true} + + if out.Len() != len(expected) { + t.Fatalf("length mismatch: expected %d, got %d", len(expected), out.Len()) + } + for i := 0; i < out.Len(); i++ { + if out.Value(i) != expected[i] { + 
t.Fatalf("index %d expected %v got %v", i, expected[i], out.Value(i)) + } + } + }) + t.Run("age > 30", func(t *testing.T) { + rc := generateTestColumns() + literal := NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(30)) + col := NewColumnResolve("age") + be := NewBinaryExpr(col, GreaterThan, literal) + + arr, err := EvalExpression(be, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := arr.(*array.Boolean) + expected := []bool{false, true, true, false} + + if out.Len() != len(expected) { + t.Fatalf("length mismatch: expected %d, got %d", len(expected), out.Len()) + } + for i := 0; i < out.Len(); i++ { + if out.Value(i) != expected[i] { + t.Fatalf("index %d expected %v got %v", i, expected[i], out.Value(i)) + } + } + }) + t.Run("age >= 34", func(t *testing.T) { + rc := generateTestColumns() + literal := NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(34)) + col := NewColumnResolve("age") + be := NewBinaryExpr(col, GreaterThanOrEqual, literal) + + arr, err := EvalExpression(be, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := arr.(*array.Boolean) + expected := []bool{false, true, true, false} + + if out.Len() != len(expected) { + t.Fatalf("length mismatch: expected %d, got %d", len(expected), out.Len()) + } + for i := 0; i < out.Len(); i++ { + if out.Value(i) != expected[i] { + t.Fatalf("index %d expected %v got %v", i, expected[i], out.Value(i)) + } + } + }) + t.Run("logical AND: (age > 30) AND is_active", func(t *testing.T) { + rc := generateTestColumns() + + left := NewBinaryExpr( + NewColumnResolve("age"), + GreaterThan, + NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(30)), + ) + + right := NewBinaryExpr( + NewColumnResolve("is_active"), + Equal, + NewLiteralResolve(arrow.FixedWidthTypes.Boolean, true), + ) + + be := NewBinaryExpr(left, And, right) + arr, err := EvalExpression(be, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := arr.(*array.Boolean) + expected := 
[]bool{false, false, true, false} + + for i := range expected { + if out.Value(i) != expected[i] { + t.Fatalf("index %d: expected %v got %v", i, expected[i], out.Value(i)) + } + } + }) + t.Run("logical OR: (age < 30) OR is_active", func(t *testing.T) { + rc := generateTestColumns() + + left := NewBinaryExpr( + NewColumnResolve("age"), + LessThan, + NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(30)), + ) + + right := NewBinaryExpr( + NewColumnResolve("is_active"), + Equal, + NewLiteralResolve(arrow.FixedWidthTypes.Boolean, true), + ) + + be := NewBinaryExpr(left, Or, right) + arr, err := EvalExpression(be, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + out := arr.(*array.Boolean) + expected := []bool{true, false, true, true} + + for i := range expected { + if out.Value(i) != expected[i] { + t.Fatalf("index %d: expected %v got %v", i, expected[i], out.Value(i)) + } + } + }) + +} + +func TestFilterBinaryExpr_InvalidTypes(t *testing.T) { + rc := generateTestColumns() // 4 rows + + // LEFT = age (int32) + left := NewColumnResolve("age") + + // RIGHT = name (string) → mismatched type + right := NewColumnResolve("name") + + t.Run("invalid Equal", func(t *testing.T) { + be := NewBinaryExpr(left, Equal, right) + _, err := EvalExpression(be, rc) + if err == nil { + t.Fatalf("expected error for mismatched datatypes (Equal), got nil") + } + }) + + t.Run("invalid NotEqual", func(t *testing.T) { + be := NewBinaryExpr(left, NotEqual, right) + _, err := EvalExpression(be, rc) + if err == nil { + t.Fatalf("expected error for mismatched datatypes (NotEqual), got nil") + } + }) + + t.Run("invalid LessThan", func(t *testing.T) { + be := NewBinaryExpr(left, LessThan, right) + _, err := EvalExpression(be, rc) + if err == nil { + t.Fatalf("expected error for mismatched datatypes (LessThan), got nil") + } + }) + + t.Run("invalid LessThanOrEqual", func(t *testing.T) { + be := NewBinaryExpr(left, LessThanOrEqual, right) + _, err := EvalExpression(be, rc) + if 
err == nil { + t.Fatalf("expected error for mismatched datatypes (LessThanOrEqual), got nil") + } + }) + + t.Run("invalid GreaterThan", func(t *testing.T) { + be := NewBinaryExpr(left, GreaterThan, right) + _, err := EvalExpression(be, rc) + if err == nil { + t.Fatalf("expected error for mismatched datatypes (GreaterThan), got nil") + } + }) + + t.Run("invalid GreaterThanOrEqual", func(t *testing.T) { + be := NewBinaryExpr(left, GreaterThanOrEqual, right) + _, err := EvalExpression(be, rc) + if err == nil { + t.Fatalf("expected error for mismatched datatypes (GreaterThanOrEqual), got nil") + } + }) + + t.Run("invalid AND", func(t *testing.T) { + be := NewBinaryExpr(left, And, right) + _, err := EvalExpression(be, rc) + if err == nil { + t.Fatalf("expected error for mismatched datatypes (AND), got nil") + } + }) + + t.Run("invalid OR", func(t *testing.T) { + be := NewBinaryExpr(left, Or, right) + _, err := EvalExpression(be, rc) + if err == nil { + t.Fatalf("expected error for mismatched datatypes (OR), got nil") + } + }) } diff --git a/src/Backend/opti-sql-go/operators/filter/filter.go b/src/Backend/opti-sql-go/operators/filter/filter.go index 745349e..8d839a3 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter.go +++ b/src/Backend/opti-sql-go/operators/filter/filter.go @@ -1,3 +1,137 @@ package filter +import ( + "context" + "errors" + "io" + "opti-sql-go/Expr" + "opti-sql-go/operators" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/compute" +) + +var ( + _ = (operators.Operator)(&FilterExec{}) +) + // FilterExpr takes in a field and column and yeilds a function that takes in an index and returns a bool indicating whether the row at that index satisfies the filter condition. 
+type FilterExec struct { + input operators.Operator + schema *arrow.Schema + predicate Expr.Expression + done bool +} + +func NewFilterExec(input operators.Operator, pred Expr.Expression) (*FilterExec, error) { + if !validPredicates(pred, input.Schema()) { + return nil, errors.New("predicates passed to FilterExec are invalid") + } + return &FilterExec{ + input: input, + predicate: pred, + schema: input.Schema(), + }, nil +} +func (f *FilterExec) Next(n uint16) (*operators.RecordBatch, error) { + if n == 0 { + return nil, errors.New("must pass in wanted batch size > 0") + } + if f.done { + return nil, io.EOF + } + batch, err := f.input.Next(n) + if err != nil { + return nil, err + } + booleanMask, err := Expr.EvalExpression(f.predicate, batch) + if err != nil { + return nil, err + } + boolArr := booleanMask.(*array.Boolean) // impossible for this to not be a boolean array,assuming validPredicates works as it should + filteredCol := make([]arrow.Array, len(batch.Columns)) + for i, col := range batch.Columns { + filteredCol[i], err = applyBooleanMask(col, boolArr) + if err != nil { + return nil, err + } + } + // release old columns + for _, c := range batch.Columns { + c.Release() + } + size := uint64(filteredCol[0].Len()) + + return &operators.RecordBatch{ + Schema: batch.Schema, + Columns: filteredCol, + RowCount: size, + }, nil +} +func (f *FilterExec) Schema() *arrow.Schema { + return f.schema +} + +// TODO: check if this pattern is good +func (f *FilterExec) Close() error { + return f.input.Close() +} + +func applyBooleanMask(col arrow.Array, mask *array.Boolean) (arrow.Array, error) { + datum, err := compute.Filter( + context.TODO(), + compute.NewDatum(col), + compute.NewDatum(mask), + *compute.DefaultFilterOptions(), + ) + if err != nil { + return nil, err + } + + arr := datum.(*compute.ArrayDatum).MakeArray() + return arr, nil +} +func validPredicates(pred Expr.Expression, schema *arrow.Schema) bool { + switch p := pred.(type) { + case *Expr.ColumnResolve: + 
idx := schema.FieldIndices(p.Name) + if len(idx) == 0 { + return false + } + return true + + case *Expr.BinaryExpr: + // Check valid operator + // these return boolean arrays + switch p.Op { + case Expr.Equal, Expr.NotEqual, + Expr.GreaterThan, Expr.GreaterThanOrEqual, + Expr.LessThan, Expr.LessThanOrEqual, + Expr.And, Expr.Or: + // supported + default: + return false + } + dt1, err := Expr.ExprDataType(p.Left, schema) + if err != nil { + return false + } + dt2, err := Expr.ExprDataType(p.Right, schema) + if err != nil { + return false + } + if !arrow.TypeEqual(dt1, dt2) { + return false + } + // recursively validate children + return validPredicates(p.Left, schema) && + validPredicates(p.Right, schema) + + case *Expr.LiteralResolve: + return true + + default: + return false + } +} diff --git a/src/Backend/opti-sql-go/operators/filter/filter_test.go b/src/Backend/opti-sql-go/operators/filter/filter_test.go index c71325a..698f2b6 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter_test.go +++ b/src/Backend/opti-sql-go/operators/filter/filter_test.go @@ -1,7 +1,321 @@ package filter -import "testing" +import ( + "errors" + "io" + "opti-sql-go/Expr" + "testing" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +func TestFilterInit(t *testing.T) { + proj := basicProject() + predicate := Expr.NewBinaryExpr(Expr.NewColumnResolve("age"), Expr.GreaterThan, Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, 30)) + _, err := NewFilterExec(proj, predicate) + if err != nil { + t.Fatalf("failed to create filter exec: %v", err) + } +} +func TestFilterInit_1(t *testing.T) { + t.Run("simple greater-than predicate", func(t *testing.T) { + proj := basicProject() + predicate := Expr.NewBinaryExpr( + Expr.NewColumnResolve("age"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(30)), + ) + _, err := NewFilterExec(proj, predicate) + if err != nil { + t.Fatalf("failed to create filter exec: %v", err) + } 
+ }) + + t.Run("equals predicate", func(t *testing.T) { + proj := basicProject() + predicate := Expr.NewBinaryExpr( + Expr.NewColumnResolve("is_active"), + Expr.Equal, + Expr.NewLiteralResolve(arrow.FixedWidthTypes.Boolean, true), + ) + _, err := NewFilterExec(proj, predicate) + if err != nil { + t.Fatalf("failed to create filter exec: %v", err) + } + }) + + t.Run("invalid column name", func(t *testing.T) { + proj := basicProject() + predicate := Expr.NewBinaryExpr( + Expr.NewColumnResolve("does_not_exist"), + Expr.Equal, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(1)), + ) + _, err := NewFilterExec(proj, predicate) + if err == nil { + t.Fatalf("expected error for missing column, got nil") + } + }) + + t.Run("nil predicate should fail", func(t *testing.T) { + proj := basicProject() + _, err := NewFilterExec(proj, nil) + if err == nil { + t.Fatalf("expected error for nil predicate") + } + }) +} + +func TestFilterExec_BasicPredicates(t *testing.T) { + + t.Run("age > 30 returns correct rows", func(t *testing.T) { + proj := basicProject() + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("age"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(30)), + ) + + f, _ := NewFilterExec(proj, pred) + + rb, err := f.Next(10) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // --- use ColumnByName + correct type assertion --- + raw, _ := rb.ColumnByName("age") + ageCol, ok := raw.(*array.Int32) + if !ok { + t.Fatalf("expected Int32 column, got %T", raw) + } + + expected := []int32{34, 45, 31, 40, 36, 50} + + for i := range expected { + if ageCol.Value(i) != expected[i] { + t.Fatalf("index %d expected %d got %d", i, expected[i], ageCol.Value(i)) + } + } + }) + + t.Run("is_active == true", func(t *testing.T) { + proj := basicProject() + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("is_active"), + Expr.Equal, + Expr.NewLiteralResolve(arrow.FixedWidthTypes.Boolean, true), + ) + + f, _ := NewFilterExec(proj, 
pred) + + rb, err := f.Next(20) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + raw, _ := rb.ColumnByName("is_active") + boolCol, ok := raw.(*array.Boolean) + if !ok { + t.Fatalf("expected Boolean column, got %T", raw) + } + + for i := 0; i < boolCol.Len(); i++ { + if !boolCol.Value(i) { + t.Fatalf("expected all rows to be true") + } + } + }) + + t.Run("salary < 60000", func(t *testing.T) { + proj := basicProject() + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("salary"), + Expr.LessThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, float64(60000)), + ) + + f, _ := NewFilterExec(proj, pred) + + rb, err := f.Next(10) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + raw, _ := rb.ColumnByName("salary") + salCol, ok := raw.(*array.Float64) + if !ok { + t.Fatalf("expected Float64 column, got %T", raw) + } + + expected := []float64{54000.0, 45000.0} + + for i := range expected { + if salCol.Value(i) != expected[i] { + t.Fatalf("expected %v got %v", expected[i], salCol.Value(i)) + } + } + }) + + t.Run("department == 'Engineering'", func(t *testing.T) { + proj := basicProject() + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("department"), + Expr.Equal, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, "Engineering"), + ) + + f, _ := NewFilterExec(proj, pred) + + rb, err := f.Next(10) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + raw, _ := rb.ColumnByName("department") + deptCol, ok := raw.(*array.String) + if !ok { + t.Fatalf("expected String column, got %T", raw) + } + + for i := 0; i < deptCol.Len(); i++ { + if deptCol.Value(i) != "Engineering" { + t.Fatalf("expected Engineering got %s", deptCol.Value(i)) + } + } + }) +} + +func TestFilterExec_EdgeCases(t *testing.T) { + + t.Run("Next(0) returns empty batch", func(t *testing.T) { + proj := basicProject() + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("age"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, 
int32(20)), + ) + + f, _ := NewFilterExec(proj, pred) + + _, err := f.Next(0) + if err == nil { + t.Fatalf("exepected error but got %v", err) + } + }) + + t.Run("EOF after consuming all rows", func(t *testing.T) { + proj := basicProject() + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("age"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(0)), + ) + + f, _ := NewFilterExec(proj, pred) + + _, _ = f.Next(50) + _, err := f.Next(10) + + if !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF, got %v", err) + } + }) + + t.Run("predicate that always returns false", func(t *testing.T) { + proj := basicProject() + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("age"), + Expr.Equal, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(-1)), + ) + + f, _ := NewFilterExec(proj, pred) + + rb, err := f.Next(20) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + + if rb.RowCount != 0 { + t.Fatalf("expected 0 rows, got %d", rb.RowCount) + } + }) + + t.Run("incompatible predicate types → error", func(t *testing.T) { + proj := basicProject() + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("age"), // int32 + Expr.Equal, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, "bad"), // string + ) + + _, err := NewFilterExec(proj, pred) + if err == nil { + t.Fatalf("expected type error for invalid predicate") + } + }) +} + +func TestFilterExecVariantCase(t *testing.T) { + t.Run("filter done", func(t *testing.T) { + proj := basicProject() + predicate := Expr.NewBinaryExpr(Expr.NewColumnResolve("age"), Expr.GreaterThan, Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(30))) + f, _ := NewFilterExec(proj, predicate) + _, err := f.Next(1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + f.done = true + _, err = f.Next(1) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF error, got %v", err) + } + + }) + t.Run("filter schema ", func(t *testing.T) { + proj := basicProject() + predicate := 
Expr.NewBinaryExpr(Expr.NewColumnResolve("age"), Expr.GreaterThan, Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(30))) + f, _ := NewFilterExec(proj, predicate) + t.Logf("%s", f.Schema()) + if !f.schema.Equal(proj.Schema()) { + t.Fatalf("expected schema to match input schema") + } + + }) + t.Run("filter close ", func(t *testing.T) { + proj := basicProject() + predicate := Expr.NewBinaryExpr(Expr.NewColumnResolve("age"), Expr.GreaterThan, Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(30))) + f, _ := NewFilterExec(proj, predicate) + if f.Close() != nil { + t.Fatalf("expected nil error on close") + } + }) + t.Run("filter unsupported binary operator ", func(t *testing.T) { + proj := basicProject() + predicate := Expr.NewBinaryExpr(Expr.NewColumnResolve("age"), Expr.Addition, Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int32, int32(30))) + _, err := NewFilterExec(proj, predicate) + if err == nil { + t.Fatalf("expected error for unsupported binary operator") + } + }) + t.Run("filter empty column resolve ", func(t *testing.T) { + proj := basicProject() + predicate := Expr.NewColumnResolve("doesnt-exist") + _, err := NewFilterExec(proj, predicate) + if err == nil { + t.Fatalf("expected error for empty column resolve") + } + + }) -func TestFilter(t *testing.T) { - // Simple passing test } diff --git a/src/Backend/opti-sql-go/operators/filter/limit_test.go b/src/Backend/opti-sql-go/operators/filter/limit_test.go index 480faeb..337d423 100644 --- a/src/Backend/opti-sql-go/operators/filter/limit_test.go +++ b/src/Backend/opti-sql-go/operators/filter/limit_test.go @@ -167,7 +167,9 @@ func TestLimitExec_IterationUntilEOF(t *testing.T) { if total != 7 { t.Fatalf("expected 7 total rows, got %d", total) } - lim.Close() + if err := lim.Close(); err != nil { + t.Fatalf("unexpected error on close: %v", err) + } }) t.Run("RequestZeroDoesNotChangeLimit", func(t *testing.T) { @@ -189,19 +191,21 @@ func TestLimitExec_IterationUntilEOF(t *testing.T) { if rb2.RowCount 
!= 2 { t.Fatalf("expected 2 rows, got %d", rb2.RowCount) } - lim.Close() + if err := lim.Close(); err != nil { + t.Fatalf("unexpected error on close: %v", err) + } }) t.Run("AfterEOFAlwaysEOF", func(t *testing.T) { memSrc3, _ := project.NewInMemoryProjectExec(names, cols) lim, _ := NewLimitExec(memSrc3, 2) - lim.Next(3) // exhaust + _, _ = lim.Next(3) // exhaust _, err := lim.Next(1) if !errors.Is(err, io.EOF) { t.Fatalf("expected EOF, got %v", err) } - lim.Close() + _ = lim.Close() }) } diff --git a/src/Backend/opti-sql-go/operators/project/parquet.go b/src/Backend/opti-sql-go/operators/project/parquet.go index 50b04a4..ec70c3e 100644 --- a/src/Backend/opti-sql-go/operators/project/parquet.go +++ b/src/Backend/opti-sql-go/operators/project/parquet.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "opti-sql-go/operators" - "opti-sql-go/operators/filter" "github.com/apache/arrow/go/v17/arrow" "github.com/apache/arrow/go/v17/arrow/array" @@ -23,8 +22,7 @@ var ( type ParquetSource struct { // existing fields schema *arrow.Schema - projectionPushDown []string // columns to project up - predicatePushDown []filter.FilterExpr // simple predicate push down for now + projectionPushDown []string // columns to project up reader pqarrow.RecordReader // for internal reading done bool // if set to true always return io.EOF @@ -45,7 +43,7 @@ func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { arrowReader, err := pqarrow.NewFileReader( filerReader, - pqarrow.ArrowReadProperties{Parallel: true, BatchSize: 5}, // TODO: Read in from config for this stuff + pqarrow.ArrowReadProperties{Parallel: true, BatchSize: 1}, // TODO: Read in from config for this stuff allocator, ) if err != nil { @@ -59,14 +57,13 @@ func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { return &ParquetSource{ schema: rdr.Schema(), projectionPushDown: []string{}, - predicatePushDown: nil, reader: rdr, }, nil } // source, columns you want to be push up the tree, any filters -func 
NewParquetSourcePushDown(r parquet.ReaderAtSeeker, columns []string, filters []filter.FilterExpr) (*ParquetSource, error) { +func NewParquetSourcePushDown(r parquet.ReaderAtSeeker, columns []string) (*ParquetSource, error) { if len(columns) == 0 { return nil, errors.New("no columns were provided for projection push down") } @@ -85,7 +82,7 @@ func NewParquetSourcePushDown(r parquet.ReaderAtSeeker, columns []string, filter arrowReader, err := pqarrow.NewFileReader( filerReader, - pqarrow.ArrowReadProperties{Parallel: true, BatchSize: 5}, // TODO: Read in from config for this stuff + pqarrow.ArrowReadProperties{Parallel: true, BatchSize: 1}, // TODO: Read in from config for this stuff allocator, ) if err != nil { @@ -109,12 +106,11 @@ func NewParquetSourcePushDown(r parquet.ReaderAtSeeker, columns []string, filter return &ParquetSource{ schema: rdr.Schema(), projectionPushDown: columns, - predicatePushDown: filters, reader: rdr, }, nil } -// This should be 1 +// double check that this return exactly n rows in a column. 
func (ps *ParquetSource) Next(n uint16) (*operators.RecordBatch, error) { if ps.reader == nil || ps.done || !ps.reader.Next() { return nil, io.EOF @@ -273,6 +269,33 @@ func CombineArray(a1, a2 arrow.Array) arrow.Array { appendBinary(b, a1.(*array.Binary)) appendBinary(b, a2.(*array.Binary)) return b.NewArray() + // -------------------- TIMESTAMP[TZ] -------------------- + case arrow.TIMESTAMP: + tsType := dt.(*arrow.TimestampType) // keeps unit + timezone + + b := array.NewTimestampBuilder(mem, tsType) + arr1 := a1.(*array.Timestamp) + arr2 := a2.(*array.Timestamp) + + // append arr1 + for i := 0; i < arr1.Len(); i++ { + if arr1.IsNull(i) { + b.AppendNull() + } else { + b.Append(arr1.Value(i)) + } + } + + // append arr2 + for i := 0; i < arr2.Len(); i++ { + if arr2.IsNull(i) { + b.AppendNull() + } else { + b.Append(arr2.Value(i)) + } + } + + return b.NewArray() default: panic(fmt.Sprintf("unsupported datatype in CombineArray: %v", dt)) diff --git a/src/Backend/opti-sql-go/operators/project/parquet_test.go b/src/Backend/opti-sql-go/operators/project/parquet_test.go index c383a07..ff28535 100644 --- a/src/Backend/opti-sql-go/operators/project/parquet_test.go +++ b/src/Backend/opti-sql-go/operators/project/parquet_test.go @@ -58,7 +58,7 @@ func TestParquetInit(t *testing.T) { t.Run("Test No names pass in", func(t *testing.T) { f := getTestParquetFile() - _, err := NewParquetSourcePushDown(f, []string{}, nil) + _, err := NewParquetSourcePushDown(f, []string{}) if err == nil { t.Errorf("Expected error when no columns are passed in, but got nil") } @@ -66,7 +66,7 @@ func TestParquetInit(t *testing.T) { t.Run("Test invalid names are passed in", func(t *testing.T) { f := getTestParquetFile() - _, err := NewParquetSourcePushDown(f, []string{"non_existent_column"}, nil) + _, err := NewParquetSourcePushDown(f, []string{"non_existent_column"}) if err == nil { t.Errorf("Expected error when invalid column names are passed in, but got nil") } @@ -75,7 +75,7 @@ func 
TestParquetInit(t *testing.T) { t.Run("Test correct schema is returned", func(t *testing.T) { f := getTestParquetFile() columns := []string{"country", "capital", "lat"} - source, err := NewParquetSourcePushDown(f, columns, nil) + source, err := NewParquetSourcePushDown(f, columns) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -94,15 +94,15 @@ func TestParquetInit(t *testing.T) { t.Run("Test input columns and filters were passed back out", func(t *testing.T) { f := getTestParquetFile() columns := []string{"country", "capital", "lat"} - source, err := NewParquetSourcePushDown(f, columns, nil) + source, err := NewParquetSourcePushDown(f, columns) if err != nil { t.Fatalf("Unexpected error: %v", err) } if len(source.projectionPushDown) != len(columns) { t.Errorf("Expected projectionPushDown to have %d columns, got %d", len(columns), len(source.projectionPushDown)) } - if !sameStringSlice(source.projectionPushDown, columns) || source.predicatePushDown != nil { - t.Errorf("Expected projectionPushDown to be %v and predicatePushDown to be nil, got %v and %v", columns, source.projectionPushDown, source.predicatePushDown) + if !sameStringSlice(source.projectionPushDown, columns) { + t.Errorf("Expected projectionPushDown to be %v and predicatePushDown to be nil, got %v and ", columns, source.projectionPushDown) } }) @@ -110,7 +110,7 @@ func TestParquetInit(t *testing.T) { f := getTestParquetFile() columns := []string{"country", "capital", "lat"} - source, err := NewParquetSourcePushDown(f, columns, nil) + source, err := NewParquetSourcePushDown(f, columns) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -124,7 +124,7 @@ func TestParquetInit(t *testing.T) { func TestParquetClose(t *testing.T) { f := getTestParquetFile() columns := []string{"country", "capital", "lat"} - source, err := NewParquetSourcePushDown(f, columns, nil) + source, err := NewParquetSourcePushDown(f, columns) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -144,7 +144,7 @@ 
func TestParquetClose(t *testing.T) { func TestRunToEnd(t *testing.T) { f := getTestParquetFile() columns := []string{"country", "capital", "lat"} - source, err := NewParquetSourcePushDown(f, columns, nil) + source, err := NewParquetSourcePushDown(f, columns) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -163,7 +163,7 @@ func TestRunToEnd(t *testing.T) { func TestParquetRead(t *testing.T) { f := getTestParquetFile() columns := []string{"country", "capital", "lat"} - source, err := NewParquetSourcePushDown(f, columns, nil) + source, err := NewParquetSourcePushDown(f, columns) if err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/src/Backend/opti-sql-go/operators/project/projectExec.go b/src/Backend/opti-sql-go/operators/project/projectExec.go index e85c73e..5181d58 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec.go @@ -32,14 +32,21 @@ func NewProjectExec(input operators.Operator, exprs []Expr.Expression) (*Project for i, e := range exprs { switch ex := e.(type) { case *Expr.Alias: + tp, err := Expr.ExprDataType(ex.Expr, input.Schema()) + if err != nil { + return nil, fmt.Errorf("project exec: failed to get expression data type for expr %d: %w", i, err) + } fields[i] = arrow.Field{ Name: ex.Name, - Type: Expr.ExprDataType(ex.Expr, input.Schema()), + Type: tp, Nullable: true, } default: name := fmt.Sprintf("col_%d", i) - Type := Expr.ExprDataType(e, input.Schema()) + Type, err := Expr.ExprDataType(e, input.Schema()) + if err != nil { + return nil, fmt.Errorf("project exec: failed to get expression data type for expr %d: %w", i, err) + } fields[i] = arrow.Field{ Name: name, Type: Type, diff --git a/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go b/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go index 25b1b94..354db56 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go +++ 
b/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go @@ -212,7 +212,7 @@ func TestProjectExec_Literal_Literal(t *testing.T) { } for i := 0; i < len(expected); i++ { if ageCol.Value(i) != expected[i] { - t.Fatalf("expected %d at position %d, but recieved %d", expected[i], i, ageCol.Value(i)) + t.Fatalf("expected %d at position %d, but received %d", expected[i], i, ageCol.Value(i)) } } }) diff --git a/src/Backend/opti-sql-go/operators/project/projectExec_test.go b/src/Backend/opti-sql-go/operators/project/projectExec_test.go index b180b2c..3415d47 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec_test.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec_test.go @@ -153,26 +153,6 @@ func TestProjectExec_BinaryAdd(t *testing.T) { t.Logf("column: %+v", rb.Columns[0]) } -// TODO: once your implement the other operators this test will fail -func TestUnimplemntedOperators(t *testing.T) { - names, cols := generateTestColumns() - memSrc, _ := NewInMemoryProjectExec(names, cols) - for i := Expr.Equal; i <= Expr.Or; i++ { - br := Expr.NewBinaryExpr(Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int16, int16(10)), i, Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int16, int16(5))) - - proj, err := NewProjectExec(memSrc, []Expr.Expression{br}) - if err != nil { - t.Fatalf("failed to create project exec: %v", err) - } - _, err = proj.Next(1) - if err == nil { - t.Fatalf("expected error for unimplemented operator %d, got nil", i) - - } - t.Logf("error: %v", err) - } -} - func TestProjectExec_IterateEOF(t *testing.T) { names, cols := generateTestColumns() memSrc, _ := NewInMemoryProjectExec(names, cols) diff --git a/src/Backend/opti-sql-go/operators/record.go b/src/Backend/opti-sql-go/operators/record.go index ba2c37d..60f695b 100644 --- a/src/Backend/opti-sql-go/operators/record.go +++ b/src/Backend/opti-sql-go/operators/record.go @@ -122,6 +122,13 @@ func (rb *RecordBatch) DeepEqual(other *RecordBatch) bool { } return true } +func (rb 
*RecordBatch) ColumnByName(name string) (arrow.Array, error) { + indices := rb.Schema.FieldIndices(name) + if len(indices) == 0 { + return nil, fmt.Errorf("column with name '%s' not found in schema", name) + } + return rb.Columns[indices[0]], nil +} func (rbb *RecordBatchBuilder) GenIntArray(values ...int) arrow.Array { mem := memory.NewGoAllocator() builder := array.NewInt32Builder(mem) From 1907dd37cc76e678c1d28fbb332d37cdb6f6285a Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Fri, 21 Nov 2025 01:54:53 -0500 Subject: [PATCH 19/19] feat:Implement support for Sql-compatible regEx --- src/Backend/opti-sql-go/Expr/expr.go | 157 ++++++--- src/Backend/opti-sql-go/Expr/expr_test.go | 221 ++++++++++++ src/Backend/opti-sql-go/Expr/info.go | 4 + .../opti-sql-go/operators/filter/filter.go | 8 +- .../operators/filter/filter_test.go | 2 +- .../opti-sql-go/operators/filter/limit.go | 4 +- .../operators/filter/limit_test.go | 329 +++++++++++++++++- .../opti-sql-go/operators/filter/wildCard.go | 1 - .../operators/filter/wildCard_test.go | 7 - .../opti-sql-go/operators/project/parquet.go | 11 +- .../operators/project/projectExec.go | 2 +- 11 files changed, 677 insertions(+), 69 deletions(-) delete mode 100644 src/Backend/opti-sql-go/operators/filter/wildCard.go delete mode 100644 src/Backend/opti-sql-go/operators/filter/wildCard_test.go diff --git a/src/Backend/opti-sql-go/Expr/expr.go b/src/Backend/opti-sql-go/Expr/expr.go index 0829819..b3eed34 100644 --- a/src/Backend/opti-sql-go/Expr/expr.go +++ b/src/Backend/opti-sql-go/Expr/expr.go @@ -1,9 +1,12 @@ package Expr import ( + "bytes" "context" + "errors" "fmt" "opti-sql-go/operators" + "regexp" "strings" "github.com/apache/arrow/go/v17/arrow" @@ -39,6 +42,8 @@ const ( // logical And binaryOperator = 12 Or binaryOperator = 13 + // RegEx expressions + Like binaryOperator = 14 // where column_name like "patte%n_with_wi%dcard_" ) type supportedFunctions int @@ -488,6 +493,21 @@ func EvalBinary(b *BinaryExpr, batch 
*operators.RecordBatch) (arrow.Array, error return nil, err } return unpackDatum(datum) + case Like: + if leftArr.DataType() != arrow.BinaryTypes.String || rightArr.DataType() != arrow.BinaryTypes.String { + // regEx runs only on strings + return nil, errors.New("binary operator Like only works on arrays of strings") + } + var compiledRegEx = compileSqlRegEx(rightArr.ValueStr(0)) + filterBuilder := array.NewBooleanBuilder(memory.NewGoAllocator()) + leftStrArray := leftArr.(*array.String) + for i := 0; i < leftStrArray.Len(); i++ { + valid := validRegEx(leftStrArray.Value(i), compiledRegEx) + fmt.Printf("does %s match %s: %v\n", leftStrArray.Value(i), compiledRegEx, valid) + filterBuilder.Append(valid) + } + return filterBuilder.NewArray(), nil + } return nil, fmt.Errorf("binary operator %d not supported", b.Op) } @@ -502,28 +522,6 @@ func unpackDatum(d compute.Datum) (arrow.Array, error) { } return array.MakeArray(), nil } -func inferBinaryType(left arrow.DataType, op binaryOperator, right arrow.DataType) arrow.DataType { - switch op { - - case Addition, Subtraction, Multiplication, Division: - // numeric → numeric promotion rules - return numericPromotion(left, right) - - case Equal, NotEqual, LessThan, LessThanOrEqual, GreaterThan, GreaterThanOrEqual: - return arrow.FixedWidthTypes.Boolean - - case And, Or: - return arrow.FixedWidthTypes.Boolean - - default: - panic(fmt.Sprintf("inferBinaryType: unsupported operator %v", op)) - } -} -func numericPromotion(a, b arrow.DataType) arrow.DataType { - // simplest version: return float64 for any mixed numeric types. - // expand later when needed. 
- return arrow.PrimitiveTypes.Float64 -} type ScalarFunction struct { Function supportedFunctions @@ -579,6 +577,44 @@ func (s *ScalarFunction) ExprNode() {} func (s *ScalarFunction) String() string { return fmt.Sprintf("ScalarFunction(%d, %v)", s.Function, s.Arguments) } + +// If cast succeeds → return the casted value +// If cast fails → throw a runtime error +type CastExpr struct { + Expr Expression // can be a Literal or Column (check for datatype when you resolve) + TargetType arrow.DataType +} + +func NewCastExpr(expr Expression, targetType arrow.DataType) *CastExpr { + return &CastExpr{ + Expr: expr, + TargetType: targetType, + } +} + +func EvalCast(c *CastExpr, batch *operators.RecordBatch) (arrow.Array, error) { + arr, err := EvalExpression(c.Expr, batch) + if err != nil { + return nil, err + } + + // Use Arrow compute kernel to cast + castOpts := compute.SafeCastOptions(c.TargetType) + out, err := compute.CastArray(context.TODO(), arr, castOpts) + if err != nil { + // This is a runtime cast error + return nil, fmt.Errorf("cast error: cannot cast %s to %s: %w", + arr.DataType(), c.TargetType, err) + } + + return out, nil +} + +func (c *CastExpr) ExprNode() {} +func (c *CastExpr) String() string { + return fmt.Sprintf("Cast(%s AS %s)", c.Expr, c.TargetType) +} + func upperImpl(arr arrow.Array) (arrow.Array, error) { strArr, ok := arr.(*array.String) if !ok { @@ -630,39 +666,66 @@ func inferScalarFunctionType(fn supportedFunctions, argType arrow.DataType) arro } } -// If cast succeeds → return the casted value -// If cast fails → throw a runtime error -type CastExpr struct { - Expr Expression // can be a Literal or Column (check for datatype when you resolve) - TargetType arrow.DataType -} +func inferBinaryType(left arrow.DataType, op binaryOperator, right arrow.DataType) arrow.DataType { + switch op { -func NewCastExpr(expr Expression, targetType arrow.DataType) *CastExpr { - return &CastExpr{ - Expr: expr, - TargetType: targetType, + case Addition, 
Subtraction, Multiplication, Division: + // numeric → numeric promotion rules + return numericPromotion(left, right) + + case Equal, NotEqual, LessThan, LessThanOrEqual, GreaterThan, GreaterThanOrEqual: + return arrow.FixedWidthTypes.Boolean + + case And, Or: + return arrow.FixedWidthTypes.Boolean + + default: + panic(fmt.Sprintf("inferBinaryType: unsupported operator %v", op)) } } +func numericPromotion(a, b arrow.DataType) arrow.DataType { + // simplest version: return float64 for any mixed numeric types. + return arrow.PrimitiveTypes.Float64 +} -func EvalCast(c *CastExpr, batch *operators.RecordBatch) (arrow.Array, error) { - arr, err := EvalExpression(c.Expr, batch) - if err != nil { - return nil, err +func compileSqlRegEx(s string) string { + var buf bytes.Buffer + + // Track anchoring rules + startsWithWildcard := len(s) > 0 && s[0] == '%' + endsWithWildcard := len(s) > 0 && s[len(s)-1] == '%' + + // Build body + for i := 0; i < len(s); i++ { + switch s[i] { + case '_': + buf.WriteString(".") + case '%': + buf.WriteString(".*") + default: + // Escape regex meta chars + if strings.ContainsRune(`.^$|()[]*+?{}`, rune(s[i])) { + buf.WriteByte('\\') + } + buf.WriteByte(s[i]) + } } - // Use Arrow compute kernel to cast - castOpts := compute.SafeCastOptions(c.TargetType) - out, err := compute.CastArray(context.TODO(), arr, castOpts) - if err != nil { - // This is a runtime cast error - return nil, fmt.Errorf("cast error: cannot cast %s to %s: %w", - arr.DataType(), c.TargetType, err) + regex := buf.String() + + // Apply anchoring + if !startsWithWildcard { + regex = "^" + regex + } + if !endsWithWildcard { + regex = regex + "$" } - return out, nil + return regex } -func (c *CastExpr) ExprNode() {} -func (c *CastExpr) String() string { - return fmt.Sprintf("Cast(%s AS %s)", c.Expr, c.TargetType) +func validRegEx(columnValue, regExExpr string) bool { + ok, _ := regexp.MatchString(regExExpr, columnValue) + return ok + } diff --git 
a/src/Backend/opti-sql-go/Expr/expr_test.go b/src/Backend/opti-sql-go/Expr/expr_test.go index 54aaa24..487e8f2 100644 --- a/src/Backend/opti-sql-go/Expr/expr_test.go +++ b/src/Backend/opti-sql-go/Expr/expr_test.go @@ -1452,3 +1452,224 @@ func TestFilterBinaryExpr_InvalidTypes(t *testing.T) { } }) } + +func TestCompileRegEx(t *testing.T) { + t.Run("starts with abc", func(t *testing.T) { + sqlString := "abc%" + expectedRegEx := "^abc.*" + res := compileSqlRegEx(sqlString) + if res != expectedRegEx { + t.Fatalf("expected %v, received %s", expectedRegEx, res) + } + }) + + t.Run("only % (matches anything)", func(t *testing.T) { + sqlString := "%" + expectedRegEx := ".*" + res := compileSqlRegEx(sqlString) + if res != expectedRegEx { + t.Fatalf("expected %v, received %s", expectedRegEx, res) + } + }) + + t.Run("starts with foo", func(t *testing.T) { + sqlString := "foo%" + expectedRegEx := "^foo.*" + res := compileSqlRegEx(sqlString) + if res != expectedRegEx { + t.Fatalf("expected %v, received %s", expectedRegEx, res) + } + }) + + t.Run("ends with xyz", func(t *testing.T) { + sqlString := "%xyz" + expectedRegEx := ".*xyz$" + res := compileSqlRegEx(sqlString) + if res != expectedRegEx { + t.Fatalf("expected %v, received %s", expectedRegEx, res) + } + }) + + t.Run("contains dog", func(t *testing.T) { + sqlString := "%dog%" + expectedRegEx := ".*dog.*" + res := compileSqlRegEx(sqlString) + if res != expectedRegEx { + t.Fatalf("expected %v, received %s", expectedRegEx, res) + } + }) + + t.Run("exactly 3 chars", func(t *testing.T) { + sqlString := "___" + expectedRegEx := "^...$" + res := compileSqlRegEx(sqlString) + if res != expectedRegEx { + t.Fatalf("expected %v, received %s", expectedRegEx, res) + } + }) + + t.Run("a_z pattern", func(t *testing.T) { + sqlString := "a_z" + expectedRegEx := "^a.z$" + res := compileSqlRegEx(sqlString) + if res != expectedRegEx { + t.Fatalf("expected %v, received %s", expectedRegEx, res) + } + }) + + t.Run("error-__", func(t *testing.T) { + 
sqlString := "error-__" + expectedRegEx := "^error-..$" + res := compileSqlRegEx(sqlString) + if res != expectedRegEx { + t.Fatalf("expected %v, received %s", expectedRegEx, res) + } + }) + + t.Run("3 chars then log", func(t *testing.T) { + sqlString := "___log" + expectedRegEx := "^...log$" + res := compileSqlRegEx(sqlString) + if res != expectedRegEx { + t.Fatalf("expected %v, received %s", expectedRegEx, res) + } + }) + + t.Run("file-%.txt", func(t *testing.T) { + sqlString := "%file-%.txt" + expectedRegEx := ".*file-.*\\.txt$" + res := compileSqlRegEx(sqlString) + if res != expectedRegEx { + t.Fatalf("expected %v, received %s", expectedRegEx, res) + } + }) +} + +func TestLikeOperatorSQL(t *testing.T) { + t.Run("name starts with a", func(t *testing.T) { + rc := generateTestColumns() + sqlStatment := "A%" + whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, string(sqlStatment))) + boolMask, err := EvalExpression(whereStatment, rc) + if err != nil { + t.Fatalf("unexpected error from EvalExpression") + } + mask, ok := boolMask.(*array.Boolean) + if !ok { + t.Fatalf("expected array type to be of type boolean but got %T, error:%v", mask, err) + } + expectedMask := []bool{true, false, false, false} + if mask.Len() != len(expectedMask) { + t.Fatalf("expected boolean array len to be %d but got %d", len(expectedMask), mask.Len()) + } + for i := 0; i < mask.Len(); i++ { + if mask.Value(i) != expectedMask[i] { + t.Fatalf("expected mask[%d] to be %v but got %v", i, expectedMask[i], mask.Value(i)) + } + } + }) + t.Run("name contains li", func(t *testing.T) { + rc := generateTestColumns() + sqlStatment := "%li%" + whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, string(sqlStatment))) + + boolMask, err := EvalExpression(whereStatment, rc) + if err != nil { + t.Fatalf("unexpected error from EvalExpression") + } + + mask, ok := boolMask.(*array.Boolean) + if !ok { + 
t.Fatalf("expected array type to be boolean, got %T, error:%v", mask, err) + } + + expectedMask := []bool{true, false, true, false} // Alice, Charlie + + if mask.Len() != len(expectedMask) { + t.Fatalf("expected mask len %d, got %d", len(expectedMask), mask.Len()) + } + for i := 0; i < mask.Len(); i++ { + if mask.Value(i) != expectedMask[i] { + t.Fatalf("expected mask[%d]=%v but got %v", i, expectedMask[i], mask.Value(i)) + } + } + }) + t.Run("name ends with d", func(t *testing.T) { + rc := generateTestColumns() + sqlStatment := "%d" + whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, string(sqlStatment))) + + boolMask, err := EvalExpression(whereStatment, rc) + if err != nil { + t.Fatalf("unexpected error from EvalExpression") + } + + mask, ok := boolMask.(*array.Boolean) + if !ok { + t.Fatalf("expected array type boolean, got %T, error:%v", mask, err) + } + + expectedMask := []bool{false, false, false, true} // only David ends with d + + if mask.Len() != len(expectedMask) { + t.Fatalf("expected mask len %d, got %d", len(expectedMask), mask.Len()) + } + for i := 0; i < mask.Len(); i++ { + if mask.Value(i) != expectedMask[i] { + t.Fatalf("expected mask[%d]=%v but got %v", i, expectedMask[i], mask.Value(i)) + } + } + }) + t.Run("name is exactly 5 letters", func(t *testing.T) { + rc := generateTestColumns() + sqlStatment := "_____" + whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, string(sqlStatment))) + + boolMask, err := EvalExpression(whereStatment, rc) + if err != nil { + t.Fatalf("unexpected error from EvalExpression") + } + + mask, ok := boolMask.(*array.Boolean) + if !ok { + t.Fatalf("expected boolean array got %T, error:%v", mask, err) + } + + expectedMask := []bool{true, false, false, true} // Alice (5), David (5) + + if mask.Len() != len(expectedMask) { + t.Fatalf("expected mask len %d, got %d", len(expectedMask), mask.Len()) + } + for i := 0; 
i < mask.Len(); i++ { + if mask.Value(i) != expectedMask[i] { + t.Fatalf("expected mask[%d]=%v but got %v", i, expectedMask[i], mask.Value(i)) + } + } + }) + t.Run("name starts with Ch", func(t *testing.T) { + rc := generateTestColumns() + sqlStatment := "Ch%" + whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, string(sqlStatment))) + + boolMask, err := EvalExpression(whereStatment, rc) + if err != nil { + t.Fatalf("unexpected error from EvalExpression") + } + + mask, ok := boolMask.(*array.Boolean) + if !ok { + t.Fatalf("expected boolean array got %T, error:%v", mask, err) + } + + expectedMask := []bool{false, false, true, false} // only Charlie starts with Ch + + if mask.Len() != len(expectedMask) { + t.Fatalf("expected mask len %d, got %d", len(expectedMask), mask.Len()) + } + for i := 0; i < mask.Len(); i++ { + if mask.Value(i) != expectedMask[i] { + t.Fatalf("expected mask[%d]=%v but got %v", i, expectedMask[i], mask.Value(i)) + } + } + }) +} diff --git a/src/Backend/opti-sql-go/Expr/info.go b/src/Backend/opti-sql-go/Expr/info.go index eda4866..df83f8d 100644 --- a/src/Backend/opti-sql-go/Expr/info.go +++ b/src/Backend/opti-sql-go/Expr/info.go @@ -23,3 +23,7 @@ package Expr //COALESCE(a, b) //5. Constants //SELECT 1, 'hello', 3.14 + +// ======================= +// These are all implemented +// ======================= diff --git a/src/Backend/opti-sql-go/operators/filter/filter.go b/src/Backend/opti-sql-go/operators/filter/filter.go index 8d839a3..ddd8c1b 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter.go +++ b/src/Backend/opti-sql-go/operators/filter/filter.go @@ -16,7 +16,7 @@ var ( _ = (operators.Operator)(&FilterExec{}) ) -// FilterExpr takes in a field and column and yeilds a function that takes in an index and returns a bool indicating whether the row at that index satisfies the filter condition. 
+// FilterExec is an operator that filters input records according to a predicate expression. type FilterExec struct { input operators.Operator schema *arrow.Schema @@ -49,7 +49,10 @@ func (f *FilterExec) Next(n uint16) (*operators.RecordBatch, error) { if err != nil { return nil, err } - boolArr := booleanMask.(*array.Boolean) // impossible for this to not be a boolean array,assuming validPredicates works as it should + boolArr, ok := booleanMask.(*array.Boolean) // impossible for this to not be a boolean array,assuming validPredicates works as it should + if !ok { + return nil, errors.New("predicate did not evaluate to boolean array") + } filteredCol := make([]arrow.Array, len(batch.Columns)) for i, col := range batch.Columns { filteredCol[i], err = applyBooleanMask(col, boolArr) @@ -73,7 +76,6 @@ func (f *FilterExec) Schema() *arrow.Schema { return f.schema } -// TODO: check if this pattern is good func (f *FilterExec) Close() error { return f.input.Close() } diff --git a/src/Backend/opti-sql-go/operators/filter/filter_test.go b/src/Backend/opti-sql-go/operators/filter/filter_test.go index 698f2b6..8e531c9 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter_test.go +++ b/src/Backend/opti-sql-go/operators/filter/filter_test.go @@ -206,7 +206,7 @@ func TestFilterExec_EdgeCases(t *testing.T) { _, err := f.Next(0) if err == nil { - t.Fatalf("exepected error but got %v", err) + t.Fatalf("expected error but got %v", err) } }) diff --git a/src/Backend/opti-sql-go/operators/filter/limit.go b/src/Backend/opti-sql-go/operators/filter/limit.go index 68bbb41..e4c93a5 100644 --- a/src/Backend/opti-sql-go/operators/filter/limit.go +++ b/src/Backend/opti-sql-go/operators/filter/limit.go @@ -11,7 +11,6 @@ var ( _ = (operators.Operator)(&LimitExec{}) ) -// TODO: (1) Implement Limter Exec Operator | pretty straightforward type LimitExec struct { input operators.Operator schema *arrow.Schema @@ -67,7 +66,6 @@ func (l *LimitExec) Schema() *arrow.Schema { return l.schema } 
-// nothing to close func (l *LimitExec) Close() error { - return nil + return l.input.Close() } diff --git a/src/Backend/opti-sql-go/operators/filter/limit_test.go b/src/Backend/opti-sql-go/operators/filter/limit_test.go index 337d423..64cd006 100644 --- a/src/Backend/opti-sql-go/operators/filter/limit_test.go +++ b/src/Backend/opti-sql-go/operators/filter/limit_test.go @@ -3,8 +3,12 @@ package filter import ( "errors" "io" + "opti-sql-go/Expr" "opti-sql-go/operators/project" "testing" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" ) func generateTestColumns() ([]string, []any) { @@ -46,6 +50,43 @@ func basicProject() *project.InMemorySource { v, _ := project.NewInMemoryProjectExec(names, col) return v } +func maskAny(t *testing.T, src *project.InMemorySource, expr Expr.Expression, expected []bool) { + t.Helper() + + // 1. Pull the record batch from the project source + batch, err := src.Next(10) + if err != nil { + t.Fatalf("failed to fetch record batch: %v", err) + } + if batch == nil { + t.Fatalf("expected non-nil record batch from project source") + } + + // 2. Evaluate expression against the batch + out, err := Expr.EvalExpression(expr, batch) + if err != nil { + t.Fatalf("EvalExpression error: %v", err) + } + + // 3. Extract boolean mask + mask, ok := out.(*array.Boolean) + if !ok { + t.Fatalf("expected output to be *array.Boolean, got %T", out) + } + + // 4. Validate length matches + if mask.Len() != len(expected) { + t.Fatalf("expected mask length %d, got %d", len(expected), mask.Len()) + } + + // 5. 
Validate each element + for i := 0; i < mask.Len(); i++ { + if mask.Value(i) != expected[i] { + t.Fatalf("mask[%d]: expected %v, got %v", i, expected[i], mask.Value(i)) + } + } +} + func TestLimitInit(t *testing.T) { // Simple passing test trialProject := basicProject() @@ -137,7 +178,7 @@ func TestLimitExec_NextBehavior(t *testing.T) { } _, err = lim.Next(10) if !errors.Is(err, io.EOF) { - t.Fatalf("was expecting io.EOF but recieved %v", err) + t.Fatalf("was expecting io.EOF but received %v", err) } }) } @@ -209,3 +250,289 @@ func TestLimitExec_IterationUntilEOF(t *testing.T) { _ = lim.Close() }) } + +/* +============================================== +// Wild Card Test +============================================== +*/ +func TestLikePercentWildcards(t *testing.T) { + + t.Run("name starts with A (A%)", func(t *testing.T) { + src := basicProject() + sql := "A%" + + expected := []bool{ + true, false, false, false, false, + false, false, false, false, false, + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) + + t.Run("name ends with e (%e)", func(t *testing.T) { + src := basicProject() + sql := "%e" + + expected := []bool{ + true, // Alice + false, + true, // Charlie + false, + true, // Eve + false, + true, // Grace + false, + false, + true, // Jake + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) + + t.Run("name contains 'an' (%an%)", func(t *testing.T) { + src := basicProject() + sql := "%an%" + + expected := []bool{ + false, false, false, false, false, + true, // Frank + false, + true, // Hannah + false, + false, + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) + + 
t.Run("wildcard only (%) matches all rows", func(t *testing.T) { + src := basicProject() + sql := "%" + + expected := []bool{ + true, true, true, true, true, + true, true, true, true, true, + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) +} + +func TestLikeSingleUnderscore(t *testing.T) { + + t.Run("name is exactly 5 characters (_____)", func(t *testing.T) { + src := basicProject() + sql := "_____" + + // Alice, David, Grace + expected := []bool{ + true, // Alice (5) + false, + false, + true, // David (5) + false, + true, // Frank (5) + true, // Grace (5) + false, + false, + false, + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) + + t.Run("name starts with H and length is 6 (H_____)", func(t *testing.T) { + src := basicProject() + sql := "H_____" + + // Hannah is 6 letters + expected := []bool{ + false, false, false, false, false, + false, false, + true, // Hannah + false, false, + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) + + t.Run("fourth letter is r (___r%)", func(t *testing.T) { + src := basicProject() + sql := "___r%" + + // Charlie → C h a r … + expected := []bool{ + false, false, + true, // Charlie + false, false, + false, false, false, false, false, + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) +} + +func TestLikeMixedWildcards(t *testing.T) { + + t.Run("starts with C and exactly 7 chars (C______)", func(t *testing.T) { + src := basicProject() + sql := "C______" + + // Charlie (7 letters) + expected := []bool{ + 
false, false, + true, // Charlie + false, false, + false, false, false, false, false, + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) + + t.Run("ends with ake (_ake)", func(t *testing.T) { + src := basicProject() + sql := "_ake" + + // Jake → J a k e, matches _ake + expected := []bool{ + false, false, false, false, false, + false, false, false, false, + true, // Jake + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) + + t.Run("starts with H and contains ah (H%ah%)", func(t *testing.T) { + src := basicProject() + sql := "H%ah%" + + // Hannah contains "ah" twice + expected := []bool{ + false, false, false, false, false, + false, false, + true, // Hannah + false, false, + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) +} + +func TestLikeEdgeCases(t *testing.T) { + + t.Run("empty pattern matches nothing", func(t *testing.T) { + src := basicProject() + sql := "" + + expected := []bool{ + false, false, false, false, false, + false, false, false, false, false, + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) + + t.Run("no names end with zz (%zz)", func(t *testing.T) { + src := basicProject() + sql := "%zz" + + expected := []bool{ + false, false, false, false, false, + false, false, false, false, false, + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) + + t.Run("single underscore (_) matches 1-char names only", 
func(t *testing.T) { + src := basicProject() + sql := "_" + + expected := []bool{ + false, false, false, false, false, + false, false, false, false, false, + } + + expr := Expr.NewBinaryExpr( + Expr.NewColumnResolve("name"), + Expr.Like, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, sql), + ) + + maskAny(t, src, expr, expected) + }) +} diff --git a/src/Backend/opti-sql-go/operators/filter/wildCard.go b/src/Backend/opti-sql-go/operators/filter/wildCard.go deleted file mode 100644 index 4a28b11..0000000 --- a/src/Backend/opti-sql-go/operators/filter/wildCard.go +++ /dev/null @@ -1 +0,0 @@ -package filter diff --git a/src/Backend/opti-sql-go/operators/filter/wildCard_test.go b/src/Backend/opti-sql-go/operators/filter/wildCard_test.go deleted file mode 100644 index 8a341a0..0000000 --- a/src/Backend/opti-sql-go/operators/filter/wildCard_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package filter - -import "testing" - -func TestWildCard(t *testing.T) { - // Simple passing test -} diff --git a/src/Backend/opti-sql-go/operators/project/parquet.go b/src/Backend/opti-sql-go/operators/project/parquet.go index ec70c3e..94b6e1d 100644 --- a/src/Backend/opti-sql-go/operators/project/parquet.go +++ b/src/Backend/opti-sql-go/operators/project/parquet.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "opti-sql-go/config" "opti-sql-go/operators" "github.com/apache/arrow/go/v17/arrow" @@ -16,7 +17,8 @@ import ( ) var ( - _ = (operators.Operator)(&ParquetSource{}) + _ = (operators.Operator)(&ParquetSource{}) + Config = config.GetConfig() ) type ParquetSource struct { @@ -43,7 +45,7 @@ func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { arrowReader, err := pqarrow.NewFileReader( filerReader, - pqarrow.ArrowReadProperties{Parallel: true, BatchSize: 1}, // TODO: Read in from config for this stuff + pqarrow.ArrowReadProperties{Parallel: true, BatchSize: int64(Config.Batch.Size)}, // TODO: Read in from config for this stuff allocator, ) if err != nil { @@ -82,7 +84,7 
@@ func NewParquetSourcePushDown(r parquet.ReaderAtSeeker, columns []string) (*Parq arrowReader, err := pqarrow.NewFileReader( filerReader, - pqarrow.ArrowReadProperties{Parallel: true, BatchSize: 1}, // TODO: Read in from config for this stuff + pqarrow.ArrowReadProperties{Parallel: true, BatchSize: int64(Config.Batch.Size)}, // TODO: Read in from config for this stuff allocator, ) if err != nil { @@ -143,7 +145,6 @@ func (ps *ParquetSource) Next(n uint16) (*operators.RecordBatch, error) { // Replace columns[colIdx] = combined - // VERY IMPORTANT: // Release the old existing array to avoid leaks existing.Release() } @@ -152,7 +153,7 @@ func (ps *ParquetSource) Next(n uint16) (*operators.RecordBatch, error) { curRow += numRows } return &operators.RecordBatch{ - Schema: ps.schema, // Remove the pointer as ps.Schema is already of type arrow.Schema + Schema: ps.schema, Columns: columns, RowCount: uint64(curRow), }, nil diff --git a/src/Backend/opti-sql-go/operators/project/projectExec.go b/src/Backend/opti-sql-go/operators/project/projectExec.go index 5181d58..9d93d96 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExec.go +++ b/src/Backend/opti-sql-go/operators/project/projectExec.go @@ -104,7 +104,7 @@ func (p *ProjectExec) Next(n uint16) (*operators.RecordBatch, error) { }, nil } func (p *ProjectExec) Close() error { - return nil + return p.child.Close() } func (p *ProjectExec) Schema() *arrow.Schema { return &p.outputschema