From 27989f5c8883f20eb729234780961e1a7442ea87 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Wed, 12 Nov 2025 12:45:29 -0500 Subject: [PATCH 01/10] feat(GRPC-spec): define grpc contract between Query parser and query execution enginee --- .gitignore | 9 +++++ README.md | 7 +--- src/Backend/test_data/s3_source/source.json | 17 ++++++++ src/Contract/operation.proto | 44 +++++++++++++++++++++ 4 files changed, 72 insertions(+), 5 deletions(-) create mode 100644 src/Backend/test_data/s3_source/source.json create mode 100644 src/Contract/operation.proto diff --git a/.gitignore b/.gitignore index 5e5e5b7..18bb522 100644 --- a/.gitignore +++ b/.gitignore @@ -94,3 +94,12 @@ flycheck0/ # Cache directories .cache/ node_modules/ + +# dont push large files to git +src/Backend/test_data/parquet +src/Backend/test_data/csv +src/Backend/test_data/json + +# Allow s3_source directory +!src/Backend/test_data/s3_source/ +!src/Backend/test_data/s3_source/** \ No newline at end of file diff --git a/README.md b/README.md index 4702126..215a8c6 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ OptiSQL is a custom in-memory query execution engine. The backend (physical exec ### Prerequisites - Go 1.24+ - Rust 1.70+ -- C++ (marco update this) +- C++ 23.0 - Make - git @@ -72,7 +72,6 @@ Initial development is done in **Go** (`opti-sql-go`), which serves as the prima - `/operators` - SQL operator implementations (filter, join, aggregation, project) - `/physical-optimizer` - Query plan parsing and optimization - `/substrait` - Substrait plan integration -- `/project` - [Add description] ## Branching Model @@ -130,6 +129,4 @@ Want to contribute? Check out [CONTRIBUTING.md](CONTRIBUTING.md) for detailed gu - Build and run instructions ## License - -This project is licensed under the terms specified in [LICENSE.txt](LICENSE.txt). - +This project is licensed under the terms specified in [LICENSE.txt](LICENSE.txt). 
\ No newline at end of file diff --git a/src/Backend/test_data/s3_source/source.json b/src/Backend/test_data/s3_source/source.json new file mode 100644 index 0000000..0cb53bb --- /dev/null +++ b/src/Backend/test_data/s3_source/source.json @@ -0,0 +1,17 @@ +{ + "meta_data":"names of s3 files", + "csv_files":[ + "s3://my-bucket/data/file1.csv", + "s3://my-bucket/data/file2.csv", + "s3://my-bucket/data/file3.csv" + ], + "json_files":[ + "s3://my-bucket/data/file1.json", + "s3://my-bucket/data/file2.json" + ], + "parquet_files":[ + "s3://my-bucket/data/file1.parquet", + "s3://my-bucket/data/file2.parquet" + ] + +} \ No newline at end of file diff --git a/src/Contract/operation.proto b/src/Contract/operation.proto new file mode 100644 index 0000000..6c2fdd5 --- /dev/null +++ b/src/Contract/operation.proto @@ -0,0 +1,44 @@ +syntax = "proto3"; + +package contract; + +// The service definition. +service SSOperation { + rpc ExecuteQuery(QueryExecutionRequest) returns (QueryExecutionResponse); +} + +// The request message containing the operation details. +message QueryExecutionRequest { + bytes substraight_logical = 1; //SS logical plan + string sql_statment = 2; // original sql statment + string id = 3; // unique id for this client + SourceType source = 4; // (s3 link| base64 data) +} + +// The response message containing the result. 
+message QueryExecutionResponse { + string s3_result_link = 1; // s3 link to the result data + ErrorDetails error_type = 2; // error type if any +} + +message SourceType{ + string s3_source = 1; // s3 link to the source data + string mime = 2; +} + +enum returnTypes{ + Success = 0; + ParseError = 1; + ExecutionError = 2; + SourceError = 3; + UploadError = 4; + OutOfMemory = 5; + UnknownError = 6; +} +message ErrorDetails{ + returnTypes error_type = 1; + string message = 2; +} + +// Flow: Upload source file to s3 GRPC call write results to s3 +//(client) -> sql statement + source data -> Query parser -> (SS,sql,id,s3 link) -> PerformOperation() -> execution engine -> (proccess data) -> s3 result link -> Query parser -> (s3 result link) -> (client) \ No newline at end of file From 6e472feb2027ff2030e43eb2d1391bb22fd9783f Mon Sep 17 00:00:00 2001 From: RIchard Baah <137434454+Rich-T-kid@users.noreply.github.com> Date: Wed, 12 Nov 2025 14:28:45 -0500 Subject: [PATCH 02/10] Update README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 215a8c6..d3d3986 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ OptiSQL is a custom in-memory query execution engine. 
The backend (physical exec ### Prerequisites - Go 1.24+ - Rust 1.70+ -- C++ 23.0 +- C++23 - Make - git From 8f7a122ea3f89b155f76603af979fea34d5136b7 Mon Sep 17 00:00:00 2001 From: RIchard Baah <137434454+Rich-T-kid@users.noreply.github.com> Date: Wed, 12 Nov 2025 14:28:53 -0500 Subject: [PATCH 03/10] Update src/Contract/operation.proto Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/Contract/operation.proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Contract/operation.proto b/src/Contract/operation.proto index 6c2fdd5..b941ae7 100644 --- a/src/Contract/operation.proto +++ b/src/Contract/operation.proto @@ -9,7 +9,7 @@ service SSOperation { // The request message containing the operation details. message QueryExecutionRequest { - bytes substraight_logical = 1; //SS logical plan + bytes substrait_logical = 1; //SS logical plan string sql_statment = 2; // original sql statment string id = 3; // unique id for this client SourceType source = 4; // (s3 link| base64 data) From 5b384d9362e584de1bbac10289e3429478417311 Mon Sep 17 00:00:00 2001 From: RIchard Baah <137434454+Rich-T-kid@users.noreply.github.com> Date: Wed, 12 Nov 2025 14:29:00 -0500 Subject: [PATCH 04/10] Update src/Contract/operation.proto Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/Contract/operation.proto | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Contract/operation.proto b/src/Contract/operation.proto index b941ae7..874a8c9 100644 --- a/src/Contract/operation.proto +++ b/src/Contract/operation.proto @@ -27,13 +27,13 @@ message SourceType{ } enum returnTypes{ - Success = 0; - ParseError = 1; - ExecutionError = 2; - SourceError = 3; - UploadError = 4; - OutOfMemory = 5; - UnknownError = 6; + SUCCESS = 0; + PARSE_ERROR = 1; + EXECUTION_ERROR = 2; + SOURCE_ERROR = 3; + UPLOAD_ERROR = 4; + OUT_OF_MEMORY = 5; + UNKNOWN_ERROR = 6; } message ErrorDetails{ returnTypes error_type = 1; 
From 398b246ae0eda70e58ef47e4de6a042de01667f5 Mon Sep 17 00:00:00 2001 From: RIchard Baah <137434454+Rich-T-kid@users.noreply.github.com> Date: Wed, 12 Nov 2025 14:29:13 -0500 Subject: [PATCH 05/10] Update src/Contract/operation.proto Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/Contract/operation.proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Contract/operation.proto b/src/Contract/operation.proto index 874a8c9..563a4f0 100644 --- a/src/Contract/operation.proto +++ b/src/Contract/operation.proto @@ -10,7 +10,7 @@ service SSOperation { // The request message containing the operation details. message QueryExecutionRequest { bytes substrait_logical = 1; //SS logical plan - string sql_statment = 2; // original sql statment + string sql_statement = 2; // original sql statement string id = 3; // unique id for this client SourceType source = 4; // (s3 link| base64 data) } From 050c21159d839f27bda371aa97406904b35268c7 Mon Sep 17 00:00:00 2001 From: RIchard Baah <137434454+Rich-T-kid@users.noreply.github.com> Date: Wed, 12 Nov 2025 14:29:24 -0500 Subject: [PATCH 06/10] Update src/Contract/operation.proto Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/Contract/operation.proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Contract/operation.proto b/src/Contract/operation.proto index 563a4f0..040a0b1 100644 --- a/src/Contract/operation.proto +++ b/src/Contract/operation.proto @@ -41,4 +41,4 @@ message ErrorDetails{ } // Flow: Upload source file to s3 GRPC call write results to s3 -//(client) -> sql statement + source data -> Query parser -> (SS,sql,id,s3 link) -> PerformOperation() -> execution engine -> (proccess data) -> s3 result link -> Query parser -> (s3 result link) -> (client) \ No newline at end of file +//(client) -> sql statement + source data -> Query parser -> (SS,sql,id,s3 link) -> PerformOperation() -> execution engine -> (process data) -> s3 
result link -> Query parser -> (s3 result link) -> (client) \ No newline at end of file From 8ee0e823569372e65aeca61d8134e2d187691fb3 Mon Sep 17 00:00:00 2001 From: RIchard Baah <137434454+Rich-T-kid@users.noreply.github.com> Date: Wed, 12 Nov 2025 14:29:37 -0500 Subject: [PATCH 07/10] Update README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d3d3986..061a643 100644 --- a/README.md +++ b/README.md @@ -129,4 +129,4 @@ Want to contribute? Check out [CONTRIBUTING.md](CONTRIBUTING.md) for detailed gu - Build and run instructions ## License -This project is licensed under the terms specified in [LICENSE.txt](LICENSE.txt). \ No newline at end of file +This project is licensed under the terms specified in [LICENSE.txt](LICENSE.txt). \ No newline at end of file From e5ea405453fdaf1b5bbaaaa8344be837d3fb269c Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Fri, 14 Nov 2025 05:38:14 -0500 Subject: [PATCH 08/10] feat(RecordType):Implement Record serialization/deserilization format (supporting test included) --- src/Backend/opti-sql-go/coverage.out | 2 - src/Backend/opti-sql-go/go.mod | 15 + src/Backend/opti-sql-go/go.sum | 41 + src/Backend/opti-sql-go/main.go | 3 - src/Backend/opti-sql-go/operators/record.go | 276 ++++++- .../opti-sql-go/operators/record_test.go | 775 +++++++++++++++++- .../opti-sql-go/operators/serialize.go | 412 +++++++++- .../opti-sql-go/operators/serialize_test.go | 757 ++++++++++++++++- src/Backend/test_data/s3_source/source.json | 1 - src/Contract/operation.proto | 18 +- 10 files changed, 2277 insertions(+), 23 deletions(-) delete mode 100644 src/Backend/opti-sql-go/coverage.out create mode 100644 src/Backend/opti-sql-go/go.sum diff --git a/src/Backend/opti-sql-go/coverage.out b/src/Backend/opti-sql-go/coverage.out deleted file mode 100644 index b5e11ef..0000000 --- a/src/Backend/opti-sql-go/coverage.out +++ /dev/null 
@@ -1,2 +0,0 @@ -mode: set -opti-sql-go/main.go:5.13,7.2 1 0 diff --git a/src/Backend/opti-sql-go/go.mod b/src/Backend/opti-sql-go/go.mod index afe0dde..f4c6898 100644 --- a/src/Backend/opti-sql-go/go.mod +++ b/src/Backend/opti-sql-go/go.mod @@ -1,3 +1,18 @@ module opti-sql-go go 1.24.0 + +require github.com/apache/arrow/go/v17 v17.0.0 + +require ( + github.com/goccy/go-json v0.10.3 // indirect + github.com/google/flatbuffers v24.3.25+incompatible // indirect + github.com/klauspost/cpuid/v2 v2.2.8 // indirect + github.com/zeebo/xxh3 v1.0.2 // indirect + golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect + golang.org/x/mod v0.18.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.21.0 // indirect + golang.org/x/tools v0.22.0 // indirect + golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect +) diff --git a/src/Backend/opti-sql-go/go.sum b/src/Backend/opti-sql-go/go.sum new file mode 100644 index 0000000..1e7a8ee --- /dev/null +++ b/src/Backend/opti-sql-go/go.sum @@ -0,0 +1,41 @@ +github.com/apache/arrow/go/v17 v17.0.0 h1:RRR2bdqKcdbss9Gxy2NS/hK8i4LDMh23L6BbkN5+F54= +github.com/apache/arrow/go/v17 v17.0.0/go.mod h1:jR7QHkODl15PfYyjM2nU+yTLScZ/qfj7OSUZmJ8putc= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI= +github.com/google/flatbuffers v24.3.25+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/compress v1.17.9 
h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= +github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= +github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= +github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= +golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/tools v0.22.0 
h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= +golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= +golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= +gonum.org/v1/gonum v0.15.0 h1:2lYxjRbTYyxkJxlhC+LvJIx3SsANPdRybu1tGj9/OrQ= +gonum.org/v1/gonum v0.15.0/go.mod h1:xzZVBJBtS+Mz4q0Yl2LJTk+OxOg4jiXZ7qBoM0uISGo= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/src/Backend/opti-sql-go/main.go b/src/Backend/opti-sql-go/main.go index 91e7378..da29a2c 100644 --- a/src/Backend/opti-sql-go/main.go +++ b/src/Backend/opti-sql-go/main.go @@ -1,7 +1,4 @@ package main -import "fmt" - func main() { - fmt.Println("Hello World") } diff --git a/src/Backend/opti-sql-go/operators/record.go b/src/Backend/opti-sql-go/operators/record.go index 755491d..70a8a2e 100644 --- a/src/Backend/opti-sql-go/operators/record.go +++ b/src/Backend/opti-sql-go/operators/record.go @@ -1,3 +1,277 @@ package operators -// This is what everything is going to be working off of +import ( + "fmt" + "strings" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" +) + +var ( + ErrInvalidSchema = func(info string) error { + return fmt.Errorf("invalid schema was provided. 
context: %s", info) + } +) + +type RecordBatch struct { + Schema *arrow.Schema + Columns []arrow.Array +} + +type SchemaBuilder struct { + fields []arrow.Field +} + +type RecordBatchBuilder struct { + SchemaBuilder *SchemaBuilder +} + +func NewRecordBatchBuilder() *RecordBatchBuilder { + return &RecordBatchBuilder{ + SchemaBuilder: &SchemaBuilder{ + fields: make([]arrow.Field, 0, 10), + }, + } +} + +func (sb *SchemaBuilder) WithField(name string, dtype arrow.DataType, nullable bool) *SchemaBuilder { + sb.fields = append(sb.fields, arrow.Field{ + Name: name, + Type: dtype, + Nullable: nullable, + }) + return sb +} +func (sb *SchemaBuilder) WithoutField(names ...string) *SchemaBuilder { + nameSet := make(map[string]struct{}, len(names)) + for _, n := range names { + nameSet[n] = struct{}{} + } + + newFields := make([]arrow.Field, 0, len(sb.fields)) + for _, field := range sb.fields { + _, found := nameSet[field.Name] + if !found { + newFields = append(newFields, field) + } + } + sb.fields = newFields + return sb + +} + +func (sb *SchemaBuilder) Build() *arrow.Schema { + return arrow.NewSchema(sb.fields, nil) +} +func (rbb *RecordBatchBuilder) Schema() *arrow.Schema { + return arrow.NewSchema(rbb.SchemaBuilder.fields, nil) +} + +// schema is always right in case of type mismatches +func (rbb *RecordBatchBuilder) validate(schema *arrow.Schema, columns []arrow.Array) error { + if len(schema.Fields()) != len(columns) { + return ErrInvalidSchema("schema fields and column count do not match") + } + // make sure that the array data types line up with whats expected of the schema + // Ensure array data types align with schema expectations. 
+ var errors []string + for i := 0; i < len(columns); i++ { + field := schema.Field(i) + colType := columns[i].DataType() + + if !arrow.TypeEqual(colType, field.Type) { + errors = append(errors, + fmt.Sprintf("Type mismatch at position %d: column '%s' has type '%s', but schema expects '%s'.", + i, field.Name, colType, field.Type)) + } + } + if len(errors) > 0 { + return ErrInvalidSchema(strings.Join(errors, " ")) + } + return nil +} +func (rbb *RecordBatchBuilder) NewRecordBatch(schema *arrow.Schema, columns []arrow.Array) (*RecordBatch, error) { + if err := rbb.validate(schema, columns); err != nil { + return nil, err + } + return &RecordBatch{ + Schema: schema, + Columns: columns, + }, nil +} +func (rb *RecordBatch) DeepEqual(other *RecordBatch) bool { + if !rb.Schema.Equal(other.Schema) { + return false + } + if len(rb.Columns) != len(other.Columns) { + return false + } + for i := 0; i < len(rb.Columns); i++ { + if !array.Equal(rb.Columns[i], other.Columns[i]) { + return false + } + } + return true +} +func (rbb *RecordBatchBuilder) GenIntArray(values ...int) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewInt32Builder(mem) + defer builder.Release() + for _, v := range values { + builder.Append(int32(v)) + } + return builder.NewArray() +} + +func (rbb *RecordBatchBuilder) GenFloatArray(values ...float64) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewFloat64Builder(mem) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +func (rbb *RecordBatchBuilder) GenStringArray(values ...string) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewStringBuilder(mem) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +func (rbb *RecordBatchBuilder) GenBoolArray(values ...bool) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewBooleanBuilder(mem) + defer builder.Release() + for 
_, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +// GenInt8Array generates an Int8 array +func (rbb *RecordBatchBuilder) GenInt8Array(values ...int8) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewInt8Builder(mem) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +// GenInt16Array generates an Int16 array +func (rbb *RecordBatchBuilder) GenInt16Array(values ...int16) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewInt16Builder(mem) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +// GenInt64Array generates an Int64 array +func (rbb *RecordBatchBuilder) GenInt64Array(values ...int64) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewInt64Builder(mem) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +// GenUint8Array generates a Uint8 array +func (rbb *RecordBatchBuilder) GenUint8Array(values ...uint8) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewUint8Builder(mem) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +// GenUint16Array generates a Uint16 array +func (rbb *RecordBatchBuilder) GenUint16Array(values ...uint16) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewUint16Builder(mem) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +// GenUint32Array generates a Uint32 array +func (rbb *RecordBatchBuilder) GenUint32Array(values ...uint32) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewUint32Builder(mem) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +// GenUint64Array generates a Uint64 array +func (rbb *RecordBatchBuilder) 
GenUint64Array(values ...uint64) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewUint64Builder(mem) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +// GenFloat32Array generates a Float32 array +func (rbb *RecordBatchBuilder) GenFloat32Array(values ...float32) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewFloat32Builder(mem) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +// GenBinaryArray generates a Binary array +func (rbb *RecordBatchBuilder) GenBinaryArray(values ...[]byte) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +// GenLargeStringArray generates a LargeString array +func (rbb *RecordBatchBuilder) GenLargeStringArray(values ...string) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewLargeStringBuilder(mem) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} + +// GenLargeBinaryArray generates a LargeBinary array +func (rbb *RecordBatchBuilder) GenLargeBinaryArray(values ...[]byte) arrow.Array { + mem := memory.NewGoAllocator() + builder := array.NewBinaryBuilder(mem, arrow.BinaryTypes.LargeBinary) + defer builder.Release() + for _, v := range values { + builder.Append(v) + } + return builder.NewArray() +} diff --git a/src/Backend/opti-sql-go/operators/record_test.go b/src/Backend/opti-sql-go/operators/record_test.go index 5f1e095..5fe24f1 100644 --- a/src/Backend/opti-sql-go/operators/record_test.go +++ b/src/Backend/opti-sql-go/operators/record_test.go @@ -1,7 +1,776 @@ package operators -import "testing" +import ( + "testing" -func TestRecord(t *testing.T) { - // Simple passing test + "github.com/apache/arrow/go/v17/arrow" + 
"github.com/apache/arrow/go/v17/arrow/array" +) + +// Test 1: SchemaBuilder.WithField and WithoutField +func TestSchemaBuilderWithField(t *testing.T) { + sb := &SchemaBuilder{ + fields: make([]arrow.Field, 0, 10), + } + + // Add fields + sb.WithField("age", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("salary", arrow.PrimitiveTypes.Float64, true) + + // Verify fields were added + if len(sb.fields) != 3 { + t.Errorf("Expected 3 fields, got %d", len(sb.fields)) + } + + // Check field names + expectedNames := []string{"age", "name", "salary"} + for i, expected := range expectedNames { + if sb.fields[i].Name != expected { + t.Errorf("Field %d: expected name '%s', got '%s'", i, expected, sb.fields[i].Name) + } + } + + // Check field types + if !arrow.TypeEqual(sb.fields[0].Type, arrow.PrimitiveTypes.Int32) { + t.Errorf("Field 'age': expected Int32 type, got %s", sb.fields[0].Type) + } + if !arrow.TypeEqual(sb.fields[1].Type, arrow.BinaryTypes.String) { + t.Errorf("Field 'name': expected String type, got %s", sb.fields[1].Type) + } + if !arrow.TypeEqual(sb.fields[2].Type, arrow.PrimitiveTypes.Float64) { + t.Errorf("Field 'salary': expected Float64 type, got %s", sb.fields[2].Type) + } + + // Check nullable flags + if sb.fields[0].Nullable != false { + t.Errorf("Field 'age': expected nullable=false, got %v", sb.fields[0].Nullable) + } + if sb.fields[2].Nullable != true { + t.Errorf("Field 'salary': expected nullable=true, got %v", sb.fields[2].Nullable) + } +} + +func TestSchemaBuilderWithoutField(t *testing.T) { + sb := &SchemaBuilder{ + fields: make([]arrow.Field, 0, 10), + } + + // Add fields + sb.WithField("age", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("salary", arrow.PrimitiveTypes.Float64, true). 
+ WithField("active", arrow.FixedWidthTypes.Boolean, false) + + // Verify 4 fields + if len(sb.fields) != 4 { + t.Errorf("Expected 4 fields, got %d", len(sb.fields)) + } + + // Remove fields + sb.WithoutField("name", "active") + + // Verify only 2 fields remain + if len(sb.fields) != 2 { + t.Errorf("Expected 2 fields after removal, got %d", len(sb.fields)) + } + + // Verify remaining fields are age and salary + if sb.fields[0].Name != "age" { + t.Errorf("Expected first field to be 'age', got '%s'", sb.fields[0].Name) + } + if sb.fields[1].Name != "salary" { + t.Errorf("Expected second field to be 'salary', got '%s'", sb.fields[1].Name) + } +} + +// Test 1.5: SchemaBuilder.Build with WithField and WithoutField +func TestSchemaBuilderBuildWithFields(t *testing.T) { + sb := &SchemaBuilder{ + fields: make([]arrow.Field, 0, 10), + } + + // Add fields and build + sb.WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("email", arrow.BinaryTypes.String, false). + WithField("score", arrow.PrimitiveTypes.Float64, true) + + schema := sb.Build() + + // Validate schema has 3 fields + if schema.NumFields() != 3 { + t.Errorf("Expected schema with 3 fields, got %d", schema.NumFields()) + } + + // Validate field names + field0 := schema.Field(0) + if field0.Name != "id" { + t.Errorf("Expected field 0 name 'id', got '%s'", field0.Name) + } + + field1 := schema.Field(1) + if field1.Name != "email" { + t.Errorf("Expected field 1 name 'email', got '%s'", field1.Name) + } + + field2 := schema.Field(2) + if field2.Name != "score" { + t.Errorf("Expected field 2 name 'score', got '%s'", field2.Name) + } + + // Validate types + if !arrow.TypeEqual(field0.Type, arrow.PrimitiveTypes.Int32) { + t.Errorf("Expected field 'id' type Int32, got %s", field0.Type) + } + if !arrow.TypeEqual(field1.Type, arrow.BinaryTypes.String) { + t.Errorf("Expected field 'email' type String, got %s", field1.Type) + } + if !arrow.TypeEqual(field2.Type, arrow.PrimitiveTypes.Float64) { + t.Errorf("Expected 
field 'score' type Float64, got %s", field2.Type) + } +} + +func TestSchemaBuilderBuildWithFieldsRemoved(t *testing.T) { + sb := &SchemaBuilder{ + fields: make([]arrow.Field, 0, 10), + } + + // Add fields, remove some, then build + sb.WithField("a", arrow.PrimitiveTypes.Int32, false). + WithField("b", arrow.BinaryTypes.String, false). + WithField("c", arrow.PrimitiveTypes.Float64, false). + WithField("d", arrow.FixedWidthTypes.Boolean, false). + WithoutField("b", "d") + schema := sb.Build() + + // Validate schema has only 2 fields (a and c) + if schema.NumFields() != 2 { + t.Errorf("Expected schema with 2 fields after removal, got %d", schema.NumFields()) + } + + // Validate remaining field names + if schema.Field(0).Name != "a" { + t.Errorf("Expected field 0 name 'a', got '%s'", schema.Field(0).Name) + } + if schema.Field(1).Name != "c" { + t.Errorf("Expected field 1 name 'c', got '%s'", schema.Field(1).Name) + } +} + +// Test 2: GenDataTypeArray functions +func TestGenIntArray(t *testing.T) { + rbb := NewRecordBatchBuilder() + + arr := rbb.GenIntArray(10, 20, 30, 40) + defer arr.Release() + + // Check length + if arr.Len() != 4 { + t.Errorf("Expected array length 4, got %d", arr.Len()) + } + + // Check type + if !arrow.TypeEqual(arr.DataType(), arrow.PrimitiveTypes.Int32) { + t.Errorf("Expected Int32 type, got %s", arr.DataType()) + } + + // Check values + int32Arr := arr.(*array.Int32) + expectedValues := []int32{10, 20, 30, 40} + for i, expected := range expectedValues { + if int32Arr.Value(i) != expected { + t.Errorf("Index %d: expected %d, got %d", i, expected, int32Arr.Value(i)) + } + } +} + +func TestGenFloatArray(t *testing.T) { + rbb := NewRecordBatchBuilder() + + arr := rbb.GenFloatArray(1.5, 2.7, 3.14, 9.99) + defer arr.Release() + + // Check length + if arr.Len() != 4 { + t.Errorf("Expected array length 4, got %d", arr.Len()) + } + + // Check type + if !arrow.TypeEqual(arr.DataType(), arrow.PrimitiveTypes.Float64) { + t.Errorf("Expected Float64 type, 
got %s", arr.DataType()) + } + + // Check values + float64Arr := arr.(*array.Float64) + expectedValues := []float64{1.5, 2.7, 3.14, 9.99} + for i, expected := range expectedValues { + if float64Arr.Value(i) != expected { + t.Errorf("Index %d: expected %f, got %f", i, expected, float64Arr.Value(i)) + } + } +} + +func TestGenStringArray(t *testing.T) { + rbb := NewRecordBatchBuilder() + + arr := rbb.GenStringArray("Alice", "Bob", "Charlie") + defer arr.Release() + + // Check length + if arr.Len() != 3 { + t.Errorf("Expected array length 3, got %d", arr.Len()) + } + + // Check type + if !arrow.TypeEqual(arr.DataType(), arrow.BinaryTypes.String) { + t.Errorf("Expected String type, got %s", arr.DataType()) + } + + // Check values + stringArr := arr.(*array.String) + expectedValues := []string{"Alice", "Bob", "Charlie"} + for i, expected := range expectedValues { + if stringArr.Value(i) != expected { + t.Errorf("Index %d: expected '%s', got '%s'", i, expected, stringArr.Value(i)) + } + } +} + +func TestGenBoolArray(t *testing.T) { + rbb := NewRecordBatchBuilder() + + arr := rbb.GenBoolArray(true, false, true, true) + defer arr.Release() + + // Check length + if arr.Len() != 4 { + t.Errorf("Expected array length 4, got %d", arr.Len()) + } + + // Check type + if !arrow.TypeEqual(arr.DataType(), arrow.FixedWidthTypes.Boolean) { + t.Errorf("Expected Boolean type, got %s", arr.DataType()) + } + + // Check values + boolArr := arr.(*array.Boolean) + expectedValues := []bool{true, false, true, true} + for i, expected := range expectedValues { + if boolArr.Value(i) != expected { + t.Errorf("Index %d: expected %v, got %v", i, expected, boolArr.Value(i)) + } + } +} + +// Test 3: Validate function +func TestValidateIncorrectColumnTypes(t *testing.T) { + rbb := NewRecordBatchBuilder() + + // Create schema expecting Int32 and String + schema := rbb.SchemaBuilder. + WithField("age", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). 
+ Build() + + // Create columns with wrong types (Float64 instead of Int32) + wrongCol := rbb.GenFloatArray(1.5, 2.5, 3.5) + defer wrongCol.Release() + nameCol := rbb.GenStringArray("Alice", "Bob", "Charlie") + defer nameCol.Release() + + // Validate should fail + err := rbb.validate(schema, []arrow.Array{wrongCol, nameCol}) + if err == nil { + t.Error("Expected validation error for incorrect column type, got nil") + } +} + +func TestValidateMismatchedColumnCount(t *testing.T) { + rbb := NewRecordBatchBuilder() + + // Create schema with 2 fields + schema := rbb.SchemaBuilder. + WithField("age", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + Build() + + // Create only 1 column + ageCol := rbb.GenIntArray(25, 30, 22) + defer ageCol.Release() + + // Validate should fail + err := rbb.validate(schema, []arrow.Array{ageCol}) + if err == nil { + t.Error("Expected validation error for column count mismatch, got nil") + } +} + +func TestValidateCorrectSchemaAndColumns(t *testing.T) { + rbb := NewRecordBatchBuilder() + + // Create schema + schema := rbb.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("score", arrow.PrimitiveTypes.Float64, false). + Build() + + // Create matching columns + idCol := rbb.GenIntArray(1, 2, 3) + defer idCol.Release() + scoreCol := rbb.GenFloatArray(95.5, 87.3, 92.1) + defer scoreCol.Release() + + // Validate should pass + err := rbb.validate(schema, []arrow.Array{idCol, scoreCol}) + if err != nil { + t.Errorf("Expected validation to pass, got error: %v", err) + } +} + +// Test 4: NewRecordBatch function +func TestNewRecordBatchSuccess(t *testing.T) { + rbb := NewRecordBatchBuilder() + + // Create schema + schema := rbb.SchemaBuilder. + WithField("age", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). 
+ Build() + + // Create matching columns + ageCol := rbb.GenIntArray(25, 30, 22) + defer ageCol.Release() + nameCol := rbb.GenStringArray("Alice", "Bob", "Charlie") + defer nameCol.Release() + + // Create RecordBatch + rb, err := rbb.NewRecordBatch(schema, []arrow.Array{ageCol, nameCol}) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + + // Validate RecordBatch + if rb == nil { + t.Fatal("Expected non-nil RecordBatch") + } + if rb.Schema.NumFields() != 2 { + t.Errorf("Expected schema with 2 fields, got %d", rb.Schema.NumFields()) + } + if len(rb.Columns) != 2 { + t.Errorf("Expected 2 columns, got %d", len(rb.Columns)) + } +} + +func TestNewRecordBatchMisalignedSchema(t *testing.T) { + rbb := NewRecordBatchBuilder() + + // Create schema with 3 fields + schema := rbb.SchemaBuilder. + WithField("a", arrow.PrimitiveTypes.Int32, false). + WithField("b", arrow.BinaryTypes.String, false). + WithField("c", arrow.PrimitiveTypes.Float64, false). + Build() + + // Create only 2 columns + col1 := rbb.GenIntArray(1, 2, 3) + defer col1.Release() + col2 := rbb.GenStringArray("x", "y", "z") + defer col2.Release() + + // Should fail + _, err := rbb.NewRecordBatch(schema, []arrow.Array{col1, col2}) + if err == nil { + t.Error("Expected error for misaligned schema and columns, got nil") + } +} + +func TestNewRecordBatchIncorrectDataTypes(t *testing.T) { + rbb := NewRecordBatchBuilder() + + // Create schema expecting Int32 and Float64 + schema := rbb.SchemaBuilder. + WithField("count", arrow.PrimitiveTypes.Int32, false). + WithField("value", arrow.PrimitiveTypes.Float64, false). 
+ Build() + + // Create columns with wrong types (both String) + col1 := rbb.GenStringArray("1", "2", "3") + defer col1.Release() + col2 := rbb.GenStringArray("4.5", "5.5", "6.5") + defer col2.Release() + + // Should fail + _, err := rbb.NewRecordBatch(schema, []arrow.Array{col1, col2}) + if err == nil { + t.Error("Expected error for incorrect data types, got nil") + } +} + +// Test 5: Integration test - Full workflow +func TestRecordBatchBuilderIntegration(t *testing.T) { + // Create builder + rbb := NewRecordBatchBuilder() + + // Build schema with multiple fields + rbb.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("score", arrow.PrimitiveTypes.Float64, true). + WithField("active", arrow.FixedWidthTypes.Boolean, false). + WithField("temp", arrow.BinaryTypes.String, false). // Will be removed + WithoutField("temp") // Remove it + + schema := rbb.Schema() + + // Verify schema has 4 fields (temp was removed) + if schema.NumFields() != 4 { + t.Fatalf("Expected 4 fields in schema, got %d", schema.NumFields()) + } + + // Generate data arrays + idCol := rbb.GenIntArray(1, 2, 3, 4, 5) + defer idCol.Release() + + nameCol := rbb.GenStringArray("Alice", "Bob", "Charlie", "David", "Eve") + defer nameCol.Release() + + scoreCol := rbb.GenFloatArray(95.5, 87.3, 92.1, 88.0, 91.5) + defer scoreCol.Release() + + activeCol := rbb.GenBoolArray(true, false, true, true, false) + defer activeCol.Release() + + // Verify array lengths + if idCol.Len() != 5 { + t.Errorf("Expected idCol length 5, got %d", idCol.Len()) + } + + // Create RecordBatch + rb, err := rbb.NewRecordBatch(schema, []arrow.Array{idCol, nameCol, scoreCol, activeCol}) + if err != nil { + t.Fatalf("Failed to create RecordBatch: %v", err) + } + + // Validate RecordBatch structure + if rb.Schema.NumFields() != 4 { + t.Errorf("Expected RecordBatch schema with 4 fields, got %d", rb.Schema.NumFields()) + } + if len(rb.Columns) != 4 { + 
t.Errorf("Expected RecordBatch with 4 columns, got %d", len(rb.Columns)) + } + + // Validate field names in order + expectedFieldNames := []string{"id", "name", "score", "active"} + for i, expectedName := range expectedFieldNames { + actualName := rb.Schema.Field(i).Name + if actualName != expectedName { + t.Errorf("Field %d: expected name '%s', got '%s'", i, expectedName, actualName) + } + } + + // Validate column data types + if !arrow.TypeEqual(rb.Columns[0].DataType(), arrow.PrimitiveTypes.Int32) { + t.Errorf("Column 0: expected Int32, got %s", rb.Columns[0].DataType()) + } + if !arrow.TypeEqual(rb.Columns[1].DataType(), arrow.BinaryTypes.String) { + t.Errorf("Column 1: expected String, got %s", rb.Columns[1].DataType()) + } + if !arrow.TypeEqual(rb.Columns[2].DataType(), arrow.PrimitiveTypes.Float64) { + t.Errorf("Column 2: expected Float64, got %s", rb.Columns[2].DataType()) + } + if !arrow.TypeEqual(rb.Columns[3].DataType(), arrow.FixedWidthTypes.Boolean) { + t.Errorf("Column 3: expected Boolean, got %s", rb.Columns[3].DataType()) + } + + // Validate some actual values + idArr := rb.Columns[0].(*array.Int32) + if idArr.Value(0) != 1 || idArr.Value(4) != 5 { + t.Errorf("ID column values incorrect") + } + + nameArr := rb.Columns[1].(*array.String) + if nameArr.Value(0) != "Alice" || nameArr.Value(2) != "Charlie" { + t.Errorf("Name column values incorrect") + } + + scoreArr := rb.Columns[2].(*array.Float64) + if scoreArr.Value(0) != 95.5 || scoreArr.Value(1) != 87.3 { + t.Errorf("Score column values incorrect") + } + + activeArr := rb.Columns[3].(*array.Boolean) + if activeArr.Value(0) != true || activeArr.Value(1) != false { + t.Errorf("Active column values incorrect") + } + + // Test error case: try to create RecordBatch with mismatched columns + wrongCol := rbb.GenFloatArray(1.1, 2.2, 3.3, 4.4, 5.5) + defer wrongCol.Release() + + _, err = rbb.NewRecordBatch(schema, []arrow.Array{wrongCol, nameCol, scoreCol, activeCol}) + if err == nil { + t.Error("Expected 
error when creating RecordBatch with wrong column type, got nil") + } +} + +// TestRecordBatchDeepEqual tests every branch of the DeepEqual method +func TestRecordBatchDeepEqual(t *testing.T) { + rbb := NewRecordBatchBuilder() + + // Create a base schema and RecordBatch for testing + schema1 := rbb.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("score", arrow.PrimitiveTypes.Float64, false). + Build() + + idCol := rbb.GenIntArray(1, 2, 3) + nameCol := rbb.GenStringArray("Alice", "Bob", "Charlie") + scoreCol := rbb.GenFloatArray(95.5, 87.3, 92.1) + + rb1, err := rbb.NewRecordBatch(schema1, []arrow.Array{idCol, nameCol, scoreCol}) + if err != nil { + t.Fatalf("Failed to create base RecordBatch: %v", err) + } + + // Mini test 1: Schema inequality - should hit "if !rb.Schema.Equal(other.Schema) { return false }" + t.Run("MiniTest_SchemaNotEqual", func(t *testing.T) { + builderA := NewRecordBatchBuilder() + diffSchema := builderA.SchemaBuilder. + WithField("different", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("score", arrow.PrimitiveTypes.Float64, false). + Build() + colA1 := builderA.GenIntArray(1, 2, 3) + colA2 := builderA.GenStringArray("Alice", "Bob", "Charlie") + colA3 := builderA.GenFloatArray(95.5, 87.3, 92.1) + rbA, _ := builderA.NewRecordBatch(diffSchema, []arrow.Array{colA1, colA2, colA3}) + if rb1.DeepEqual(rbA) { + t.Error("DeepEqual should return false when schemas differ") + } + }) + + // Mini test 2: Column count inequality - should hit "if len(rb.Columns) != len(other.Columns) { return false }" + t.Run("MiniTest_ColumnCountNotEqual", func(t *testing.T) { + builderB := NewRecordBatchBuilder() + schemaB := builderB.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). 
+ Build() + colB1 := builderB.GenIntArray(1, 2, 3) + colB2 := builderB.GenStringArray("Alice", "Bob", "Charlie") + rbB, _ := builderB.NewRecordBatch(schemaB, []arrow.Array{colB1, colB2}) + if rb1.DeepEqual(rbB) { + t.Error("DeepEqual should return false when column counts differ") + } + }) + + // Mini test 3: Array inequality in loop - should hit "if !array.Equal(rb.Columns[i], other.Columns[i]) { return false }" + t.Run("MiniTest_ArrayNotEqual", func(t *testing.T) { + builderC := NewRecordBatchBuilder() + schemaC := builderC.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("score", arrow.PrimitiveTypes.Float64, false). + Build() + colC1 := builderC.GenIntArray(999, 888, 777) // Different data + colC2 := builderC.GenStringArray("Alice", "Bob", "Charlie") + colC3 := builderC.GenFloatArray(95.5, 87.3, 92.1) + rbC, _ := builderC.NewRecordBatch(schemaC, []arrow.Array{colC1, colC2, colC3}) + if rb1.DeepEqual(rbC) { + t.Error("DeepEqual should return false when array data differs") + } + }) + + // Mini test 4: All conditions pass - should hit final "return true" + t.Run("MiniTest_AllEqual", func(t *testing.T) { + builderD := NewRecordBatchBuilder() + schemaD := builderD.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("score", arrow.PrimitiveTypes.Float64, false). 
+ Build() + colD1 := builderD.GenIntArray(1, 2, 3) + colD2 := builderD.GenStringArray("Alice", "Bob", "Charlie") + colD3 := builderD.GenFloatArray(95.5, 87.3, 92.1) + rbD, _ := builderD.NewRecordBatch(schemaD, []arrow.Array{colD1, colD2, colD3}) + if !rb1.DeepEqual(rbD) { + t.Error("DeepEqual should return true when all conditions match") + } + }) + + // Test case 1: Identical RecordBatches should be equal (tests the happy path - return true) + t.Run("Identical RecordBatches", func(t *testing.T) { + builder1 := NewRecordBatchBuilder() + schema1a := builder1.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("score", arrow.PrimitiveTypes.Float64, false). + Build() + + idCol1a := builder1.GenIntArray(1, 2, 3) + nameCol1a := builder1.GenStringArray("Alice", "Bob", "Charlie") + scoreCol1a := builder1.GenFloatArray(95.5, 87.3, 92.1) + + rb1a, err := builder1.NewRecordBatch(schema1a, []arrow.Array{idCol1a, nameCol1a, scoreCol1a}) + if err != nil { + t.Fatalf("Failed to create identical RecordBatch: %v", err) + } + + if !rb1.DeepEqual(rb1a) { + t.Error("Expected identical RecordBatches to be equal") + } + }) + + // Test case 2: Different schemas should return false (tests first if branch) + t.Run("Different Schemas", func(t *testing.T) { + builder2 := NewRecordBatchBuilder() + // Different schema: different field name + differentSchema := builder2.SchemaBuilder. + WithField("user_id", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("score", arrow.PrimitiveTypes.Float64, false). 
+ Build() + + idCol2 := builder2.GenIntArray(1, 2, 3) + nameCol2 := builder2.GenStringArray("Alice", "Bob", "Charlie") + scoreCol2 := builder2.GenFloatArray(95.5, 87.3, 92.1) + + rb2, err := builder2.NewRecordBatch(differentSchema, []arrow.Array{idCol2, nameCol2, scoreCol2}) + if err != nil { + t.Fatalf("Failed to create RecordBatch with different schema: %v", err) + } + + if rb1.DeepEqual(rb2) { + t.Error("Expected RecordBatches with different schemas to not be equal") + } + }) + + // Test case 3: Different number of columns should return false (tests second if branch) + t.Run("Different Column Count", func(t *testing.T) { + builder3 := NewRecordBatchBuilder() + // Schema with only 2 fields instead of 3 + fewerFieldsSchema := builder3.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + Build() + + idCol3 := builder3.GenIntArray(1, 2, 3) + nameCol3 := builder3.GenStringArray("Alice", "Bob", "Charlie") + + rb3, err := builder3.NewRecordBatch(fewerFieldsSchema, []arrow.Array{idCol3, nameCol3}) + if err != nil { + t.Fatalf("Failed to create RecordBatch with fewer columns: %v", err) + } + + if rb1.DeepEqual(rb3) { + t.Error("Expected RecordBatches with different column counts to not be equal") + } + }) + + // Test case 4: Same schema and column count, but different column data should return false + // (tests the for loop and array.Equal returning false) + t.Run("Different Column Data", func(t *testing.T) { + builder4 := NewRecordBatchBuilder() + schema4 := builder4.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("score", arrow.PrimitiveTypes.Float64, false). 
+ Build() + + // Different data in the first column + idCol4 := builder4.GenIntArray(10, 20, 30) // Different values + nameCol4 := builder4.GenStringArray("Alice", "Bob", "Charlie") + scoreCol4 := builder4.GenFloatArray(95.5, 87.3, 92.1) + + rb4, err := builder4.NewRecordBatch(schema4, []arrow.Array{idCol4, nameCol4, scoreCol4}) + if err != nil { + t.Fatalf("Failed to create RecordBatch with different data: %v", err) + } + + if rb1.DeepEqual(rb4) { + t.Error("Expected RecordBatches with different column data to not be equal") + } + }) + + // Test case 5: Different data in middle column (tests for loop continues to check all columns) + t.Run("Different Middle Column Data", func(t *testing.T) { + builder5 := NewRecordBatchBuilder() + schema5 := builder5.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("score", arrow.PrimitiveTypes.Float64, false). + Build() + + idCol5 := builder5.GenIntArray(1, 2, 3) + nameCol5 := builder5.GenStringArray("Dave", "Eve", "Frank") // Different names + scoreCol5 := builder5.GenFloatArray(95.5, 87.3, 92.1) + + rb5, err := builder5.NewRecordBatch(schema5, []arrow.Array{idCol5, nameCol5, scoreCol5}) + if err != nil { + t.Fatalf("Failed to create RecordBatch with different middle column: %v", err) + } + + if rb1.DeepEqual(rb5) { + t.Error("Expected RecordBatches with different middle column to not be equal") + } + }) + + // Test case 6: Different data in last column (tests for loop completes and finds inequality) + t.Run("Different Last Column Data", func(t *testing.T) { + builder6 := NewRecordBatchBuilder() + schema6 := builder6.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("score", arrow.PrimitiveTypes.Float64, false). 
+ Build() + + idCol6 := builder6.GenIntArray(1, 2, 3) + nameCol6 := builder6.GenStringArray("Alice", "Bob", "Charlie") + scoreCol6 := builder6.GenFloatArray(100.0, 100.0, 100.0) // Different scores + + rb6, err := builder6.NewRecordBatch(schema6, []arrow.Array{idCol6, nameCol6, scoreCol6}) + if err != nil { + t.Fatalf("Failed to create RecordBatch with different last column: %v", err) + } + + if rb1.DeepEqual(rb6) { + t.Error("Expected RecordBatches with different last column to not be equal") + } + }) + + // Test case 7: Same RecordBatch compared to itself (tests reflexivity) + t.Run("Same RecordBatch Instance", func(t *testing.T) { + if !rb1.DeepEqual(rb1) { + t.Error("Expected RecordBatch to be equal to itself") + } + }) + + // Test case 8: Empty RecordBatches should be equal + t.Run("Empty RecordBatches", func(t *testing.T) { + builder8a := NewRecordBatchBuilder() + emptySchema8a := builder8a.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). + Build() + emptyCol8a := builder8a.GenIntArray() + rb_empty8a, err := builder8a.NewRecordBatch(emptySchema8a, []arrow.Array{emptyCol8a}) + if err != nil { + t.Fatalf("Failed to create empty RecordBatch 1: %v", err) + } + + builder8b := NewRecordBatchBuilder() + emptySchema8b := builder8b.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). 
+ Build() + emptyCol8b := builder8b.GenIntArray() + + rb_empty8b, err := builder8b.NewRecordBatch(emptySchema8b, []arrow.Array{emptyCol8b}) + if err != nil { + t.Fatalf("Failed to create empty RecordBatch 2: %v", err) + } + + if !rb_empty8a.DeepEqual(rb_empty8b) { + t.Error("Expected empty RecordBatches to be equal") + } + }) } diff --git a/src/Backend/opti-sql-go/operators/serialize.go b/src/Backend/opti-sql-go/operators/serialize.go index 1a68f5f..7336a63 100644 --- a/src/Backend/opti-sql-go/operators/serialize.go +++ b/src/Backend/opti-sql-go/operators/serialize.go @@ -1,3 +1,413 @@ package operators -// turn records into something we can read and write to disk +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" +) + +/* +Protocol | +Schema +┌──────────────────────────────────────────┐ +│ uint32 numberOfFields │ +├──────────────────────────────────────────┤ +│ uint32 field1NameLength │ +│ bytes[...] field1Name │ +│ uint32 field1TypeLength │ +│ bytes[...] field1TypeString │ +│ uint8 field1Nullable │ +├──────────────────────────────────────────┤ +│ uint32 field2NameLength │ +│ bytes[...] field2Name │ +│ uint32 field2TypeLength │ +│ bytes[...] field2TypeString │ +│ uint8 field2Nullable │ +├──────────────────────────────────────────┤ +│ ... repeated for N fields ... │ +└──────────────────────────────────────────┘ +Example: +[int32 nameLength][‘age’] +[int32 typeLength][‘int32’] +[byte nullable] + + +Column Data +┌──────────────────────────────────────────┐ +│ int64 lengthOfArray (num rows) │ +├──────────────────────────────────────────┤ +│ uint32 numBuffers │ +├──────────────────────────────────────────┤ +│ uint64 buffer0Length │ +│ bytes[] buffer0Bytes │ +├──────────────────────────────────────────┤ +│ uint64 buffer1Length │ +│ bytes[] buffer1Bytes │ +├──────────────────────────────────────────┤ +│ ... 
repeated for N buffers ... │ +└──────────────────────────────────────────┘ +Int32 example (5 elements, validity + values buffers): +[5] // array length +[2] // numBuffers = 2 + +// validity buffer +[1] // length (in bytes) +[0b11111...] // raw bitmap byte + +// values buffer +[20] // length = 5 × 4 bytes +[raw binary ints…] + + +Record Batch +┌──────────────────────────────────────────┐ +schemaBlock +column0Block +column1Block +column2Block +... +(column blocks for batch 1) +column0Block +column1Block +column2Block +... +(column blocks for batch 2) +... +└──────────────────────────────────────────┘ + + + + +What are we looking for? +Serialization for intermediate record batches. +The main use case is for pipeline-breaking operators where it's unsafe to assume that all the records will fit in RAM +This means that all the inputs will have the same schema | for example sort(col) -> each element is the exact same | join(col1 == col2) keep the left side in memory/in a separate file and the right side will have the exact same schema + +This means that we can just work with the data directly since we'll know the schema. Just to be safe we can keep the schema in memory attached directly to the object/class handling the serialization + +V1 +Format on disk -> dataTypeSize|dataType|BatchSize|BatchElements|BatchSize|BatchElements..... + +but this assumes we are only dealing with one column. +what about sort order? +what do we do with the other columns as we're writing this single column to disk? + + +We have to write the entire record batch to disk + +V2 +write the schema out to disk as well but this is only to validate with the in-memory schema +!! in between each column being read into memory check with the in-memory schema for their data type for correct encoding +!!
schema will also tell u how many columns u have to read in for that specefic record batch +format on disk -> **schema|ColumnNsize|columnNData|columnN+1size|columnN+1data|columnN+2size|columnN+2data...EndOfrecordBatch|columnNSize|columnNData|.... +format on disk for schema -> number of fields | field1NameLength|field1Name|field1TypeLength|field1Type|field1Nullable| field2NameLength|field2Name|field2TypeLength|field2Type|field2Nullable... +were going to be writing more data but theres not much we can do about that. for now this is fine + +optimizations +(1)wont need to add the schema each time +*/ +// saves allocations to have a struct that implements this interface than to have this attached directly to RecordBatch +// especially in the cases where we have to spill to disk multiple times (sort,hash join, aggregation) +type serializer struct { + schema *arrow.Schema // schema is always attached to the serializer +} + +func NewSerializer(schema *arrow.Schema) (*serializer, error) { + return &serializer{ + schema: schema, + }, nil +} +func (s *serializer) Schema() *arrow.Schema { + return s.schema +} + +// overwrite the input schema here. it should be the same but in the case where its not id rather the input schema be the source of truth +func (ss *serializer) SerializeBatchColumns(r RecordBatch) ([]byte, error) { + if !ss.schema.Equal(r.Schema) { + return nil, ErrInvalidSchema("serializer schema and record batch schema are not aligned") + } + columnContent, err := ss.columnsTodisk(r.Columns) + if err != nil { + return nil, err + } + return columnContent, nil +} +func (ss *serializer) SerializeSchema(s *arrow.Schema) ([]byte, error) { + buf := new(bytes.Buffer) + + // 1. 
number of fields + if err := binary.Write(buf, binary.LittleEndian, uint32(len(s.Fields()))); err != nil { + return nil, err + } + + for _, f := range s.Fields() { + // --- Field Name --- + nameBytes := []byte(f.Name) + if err := binary.Write(buf, binary.LittleEndian, uint32(len(nameBytes))); err != nil { + return nil, err + } + if _, err := buf.Write(nameBytes); err != nil { + return nil, err + } + + // --- Field Type (use Arrow's string representation) --- + typeBytes := []byte(f.Type.String()) + if err := binary.Write(buf, binary.LittleEndian, uint32(len(typeBytes))); err != nil { + return nil, err + } + if _, err := buf.Write(typeBytes); err != nil { + return nil, err + } + + // --- Nullable --- + var nullable uint8 + if f.Nullable { + nullable = 1 + fmt.Println("case 1") + } else { + nullable = 0 + fmt.Println("case 0") + } + if err := binary.Write(buf, binary.LittleEndian, nullable); err != nil { + return nil, err + } + } + + return buf.Bytes(), nil +} +func (ss *serializer) columnsTodisk(columns []arrow.Array) ([]byte, error) { + buf := new(bytes.Buffer) + + for _, col := range columns { + data := col.Data() + + // Write array length (number of rows) + if err := binary.Write(buf, binary.LittleEndian, int64(data.Len())); err != nil { + return nil, err + } + + // Number of buffers for this column + buffers := data.Buffers() + if err := binary.Write(buf, binary.LittleEndian, uint32(len(buffers))); err != nil { + return nil, err + } + + // Write each buffer + for _, b := range buffers { + if b == nil || b.Len() == 0 { + // Write 0 length + if err := binary.Write(buf, binary.LittleEndian, uint64(0)); err != nil { + return nil, err + } + continue + } + + // Write length of buffer + if err := binary.Write(buf, binary.LittleEndian, uint64(b.Len())); err != nil { + return nil, err + } + + // Write buffer contents + if _, err := buf.Write(b.Bytes()); err != nil { + return nil, err + } + } + } + + return buf.Bytes(), nil +} + +func (ss *serializer) 
DeserializeSchema(data io.Reader) (*arrow.Schema, error) { + // read in the schema first + return ss.schemaFromDisk(data) +} + +// DeserializeNextColumn reads one column block (int64 row count, uint32 buffer count, then uint64-length-prefixed buffers) +// from r and wraps it in an array of type dt. It is meant to be called once per column after the schema has been read. +// io.EOF is returned only when r is exhausted before the column starts; input that ends part-way through a column +// is reported as io.ErrUnexpectedEOF, and a negative row count is rejected. +func (ss *serializer) DeserializeNextColumn(r io.Reader, dt arrow.DataType) (arrow.Array, error) { + // 1. Read the number of elements in this column batch + var length int64 + if err := binary.Read(r, binary.LittleEndian, &length); err != nil { + return nil, err + } + if length < 0 { + return nil, fmt.Errorf("invalid column length %d", length) + } + + // once the column header has started, running out of input means the data is truncated + truncated := func(err error) error { + if err == io.EOF { + return io.ErrUnexpectedEOF + } + return err + } + + // 2. Read number of buffers for this column + var numBuffers uint32 + if err := binary.Read(r, binary.LittleEndian, &numBuffers); err != nil { + return nil, truncated(err) + } + + buffers := make([]*memory.Buffer, numBuffers) + + // 3. Read each buffer in order + for i := uint32(0); i < numBuffers; i++ { + // buffer length + var size uint64 + if err := binary.Read(r, binary.LittleEndian, &size); err != nil { + return nil, truncated(err) + } + + if size == 0 { + // Null / empty buffer + buffers[i] = nil + continue + } + + // Read raw bytes + raw := make([]byte, size) + if _, err := io.ReadFull(r, raw); err != nil { + return nil, truncated(err) + } + + buffers[i] = memory.NewBufferBytes(raw) + } + + // 4. Construct Arrow ArrayData + arrData := array.NewData( + dt, + int(length), + buffers, // buffers + nil, // children (none for primitive) + -1, // null count (setting it to -1 lets Arrow compute it lazily) + 0, // offset + ) + // the array built below takes its own reference to arrData, so drop ours when we return + defer arrData.Release() + + // 5.
Wrap into Array type + return array.MakeFromData(arrData), nil +} + +// DecodeRecordBatch reads one column per schema field from r, in field order, and returns the decoded arrays. +// schema must equal the serializer's schema, otherwise an ErrInvalidSchema error is returned. +// The schema block must already have been consumed from r (see DeserializeSchema) or the column reads will be misaligned. +// io.EOF is returned only when r is exhausted before the first column; running out between columns is io.ErrUnexpectedEOF. +func (ss *serializer) DecodeRecordBatch(r io.Reader, schema *arrow.Schema) ([]arrow.Array, error) { + if !ss.schema.Equal(schema) { + return nil, ErrInvalidSchema("serializer schema and provided schema do not match") + } + fields := schema.Fields() + arrays := make([]arrow.Array, len(fields)) + + for i, field := range fields { + arr, err := ss.DeserializeNextColumn(r, field.Type) + if err != nil { + // columns decoded so far are not handed to the caller, so release them here + for _, done := range arrays[:i] { + done.Release() + } + // a clean EOF after at least one column means the batch is incomplete + if err == io.EOF && i > 0 { + err = io.ErrUnexpectedEOF + } + return nil, err + } + arrays[i] = arr + } + + return arrays, nil +} + +// schemaFromDisk decodes a schema block: a uint32 field count followed, per field, by a uint32-length-prefixed name, +// a uint32-length-prefixed Arrow type string and a one-byte nullable flag (all little-endian). +// An error reading the field count is returned unchanged (io.EOF on empty input); once the count has been read, +// input that ends early is reported as io.ErrUnexpectedEOF wrapped with the field index. +func (ss *serializer) schemaFromDisk(data io.Reader) (*arrow.Schema, error) { + + // number of fields + var num uint32 + if err := binary.Read(data, binary.LittleEndian, &num); err != nil { + return nil, err + } + + // readFull fills p completely; a bare data.Read may return fewer bytes than requested + readFull := func(p []byte) error { + _, err := io.ReadFull(data, p) + if err == io.EOF { + return io.ErrUnexpectedEOF + } + return err + } + // readPrefixed reads a uint32 length followed by that many bytes + readPrefixed := func() ([]byte, error) { + var lenBuf [4]byte + if err := readFull(lenBuf[:]); err != nil { + return nil, err + } + payload := make([]byte, binary.LittleEndian.Uint32(lenBuf[:])) + if err := readFull(payload); err != nil { + return nil, err + } + return payload, nil + } + + fields := make([]arrow.Field, 0, num) + + for i := uint32(0); i < num; i++ { + // read name + nameBytes, err := readPrefixed() + if err != nil { + return nil, fmt.Errorf("reading name of field %d: %w", i, err) + } + + // read type + typeBytes, err := readPrefixed() + if err != nil { + return nil, fmt.Errorf("reading type of field %d: %w", i, err) + } + typ, err := BasicArrowTypeFromString(string(typeBytes)) + if err != nil { + return nil, err + } + + // read nullable + var nullable [1]byte + if err := readFull(nullable[:]); err != nil { + return nil, fmt.Errorf("reading nullable flag of field %d: %w", i, err) + } + + fields = append(fields, arrow.Field{ + Name: string(nameBytes), + Type: typ, + Nullable: nullable[0] == 1, + }) + } + + return arrow.NewSchema(fields, nil), nil +} + +// BasicArrowTypeFromString maps an Arrow type string to the matching arrow.DataType. +func BasicArrowTypeFromString(s string) (arrow.DataType, error) { + switch s { + case "null": + return arrow.Null, nil + case "bool": + return arrow.FixedWidthTypes.Boolean, nil + + case "int8": + return arrow.PrimitiveTypes.Int8, nil + case "int16": + return arrow.PrimitiveTypes.Int16, nil + case "int32": + return arrow.PrimitiveTypes.Int32, nil + case "int64": + return
arrow.PrimitiveTypes.Int64, nil + + case "uint8": + return arrow.PrimitiveTypes.Uint8, nil + case "uint16": + return arrow.PrimitiveTypes.Uint16, nil + case "uint32": + return arrow.PrimitiveTypes.Uint32, nil + case "uint64": + return arrow.PrimitiveTypes.Uint64, nil + + case "float32": + return arrow.PrimitiveTypes.Float32, nil + case "float64": + return arrow.PrimitiveTypes.Float64, nil + + case "string", "utf8": + return arrow.BinaryTypes.String, nil + case "large_string", "large_utf8": + return arrow.BinaryTypes.LargeString, nil + + case "binary": + return arrow.BinaryTypes.Binary, nil + case "large_binary": + return arrow.BinaryTypes.LargeBinary, nil + } + + return nil, fmt.Errorf("unsupported arrow type: %s", s) +} + +/* + +FILE: +┌────────────────────────┐ +│ SCHEMA BLOCK │ +│ numberOfFields │ +│ (field entries...) │ +├────────────────────────┤ +│ RECORD BATCH #1 │ +│ COLUMN 0 │ +│ arrayLength │ +│ numBuffers │ +│ buffers[...] │ +│ COLUMN 1 │ +│ COLUMN 2 │ +│ ... │ +├────────────────────────┤ +│ RECORD BATCH #2 │ +│ COLUMN 0 │ +│ COLUMN 1 │ +│ COLUMN 2 │ +│ ... │ +└────────────────────────┘ +EOF +*/ diff --git a/src/Backend/opti-sql-go/operators/serialize_test.go b/src/Backend/opti-sql-go/operators/serialize_test.go index a384b1c..4115bfb 100644 --- a/src/Backend/opti-sql-go/operators/serialize_test.go +++ b/src/Backend/opti-sql-go/operators/serialize_test.go @@ -1,7 +1,758 @@ package operators -import "testing" +import ( + "bytes" + "fmt" + "io" + "testing" -func TestSerialize(t *testing.T) { - // Simple passing test + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" +) + +// generateDummyRecordBatch1 creates a test RecordBatch with employee data +func generateDummyRecordBatch1() RecordBatch { + dummyBuilder := NewRecordBatchBuilder() + dummyBuilder.SchemaBuilder.WithField("id", arrow.PrimitiveTypes.Int32, false). + WithField("name", arrow.BinaryTypes.String, false). 
+ WithField("age", arrow.PrimitiveTypes.Int32, false). + WithField("salary", arrow.PrimitiveTypes.Float64, false) + + colums := []arrow.Array{ + dummyBuilder.GenIntArray(1, 2, 3, 4, 5), + dummyBuilder.GenStringArray("Alice", "Bob", "Charlie", "David", "Eve"), + dummyBuilder.GenIntArray(25, 30, 35, 40, 45), + dummyBuilder.GenFloatArray(50000.0, 60000.0, 70000.0, 80000.0, 90000.0), + } + RecordBatch, _ := dummyBuilder.NewRecordBatch(dummyBuilder.Schema(), colums) + return *RecordBatch +} + +// generateDummyRecordBatch2 creates a test RecordBatch with product data +func generateDummyRecordBatch2() RecordBatch { + dummyBuilder := NewRecordBatchBuilder() + + // Define a different schema + dummyBuilder.SchemaBuilder. + WithField("product_id", arrow.PrimitiveTypes.Int32, false). + WithField("product_name", arrow.BinaryTypes.String, false). + WithField("quantity", arrow.PrimitiveTypes.Int32, false). + WithField("price", arrow.PrimitiveTypes.Float64, false). + WithField("in_stock", arrow.FixedWidthTypes.Boolean, false) + + // Generate dummy columns + columns := []arrow.Array{ + dummyBuilder.GenIntArray(101, 102, 103, 104, 105), + dummyBuilder.GenStringArray("Keyboard", "Mouse", "Monitor", "Laptop", "Headphones"), + dummyBuilder.GenIntArray(10, 50, 15, 5, 20), + dummyBuilder.GenFloatArray(49.99, 19.99, 199.99, 999.99, 79.99), + dummyBuilder.GenBoolArray(true, true, false, true, true), + } + + // Build the record batch + recordBatch, _ := dummyBuilder.NewRecordBatch(dummyBuilder.Schema(), columns) + return *recordBatch +} + +// generateEmptyRecordBatch creates a RecordBatch with schema but no rows +func generateEmptyRecordBatch() RecordBatch { + builder := NewRecordBatchBuilder() + builder.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, false). 
+ WithField("name", arrow.BinaryTypes.String, false) + + columns := []arrow.Array{ + builder.GenIntArray(), + builder.GenStringArray(), + } + + recordBatch, _ := builder.NewRecordBatch(builder.Schema(), columns) + return *recordBatch +} + +// generateNullableRecordBatch creates a RecordBatch with nullable fields containing nulls +func generateNullableRecordBatch() RecordBatch { + builder := NewRecordBatchBuilder() + builder.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int32, true). + WithField("value", arrow.PrimitiveTypes.Float64, true) + + // Build arrays with null values manually + mem := memory.NewGoAllocator() + + // Int32 array with nulls + intBuilder := array.NewInt32Builder(mem) + intBuilder.AppendValues([]int32{1, 2, 3}, []bool{true, false, true}) + intArray := intBuilder.NewArray() + + // Float64 array with nulls + floatBuilder := array.NewFloat64Builder(mem) + floatBuilder.AppendValues([]float64{1.1, 2.2, 3.3}, []bool{true, true, false}) + floatArray := floatBuilder.NewArray() + + columns := []arrow.Array{intArray, floatArray} + recordBatch, _ := builder.NewRecordBatch(builder.Schema(), columns) + return *recordBatch +} + +// TestSerializerInit verifies serializer creation and schema validation +func TestSerializerInit(t *testing.T) { + recordBatch := generateDummyRecordBatch1() + serializer, err := NewSerializer(recordBatch.Schema) + if err != nil { + t.Fatalf("Failed to initialize serializer: %v", err) + } + + // Validate schema matches + if !serializer.schema.Equal(recordBatch.Schema) { + t.Fatal("Serializer schema does not match the provided schema") + } + + // Validate schema field count + if serializer.schema.NumFields() != 4 { + t.Fatalf("Expected 4 fields, got %d", serializer.schema.NumFields()) + } + + // Validate field names + expectedFields := []string{"id", "name", "age", "salary"} + for i, expected := range expectedFields { + if serializer.schema.Field(i).Name != expected { + t.Errorf("Field %d: expected name %q, got %q", i, 
expected, serializer.schema.Field(i).Name) + } + } + + // Validate field types + if serializer.schema.Field(0).Type.ID() != arrow.INT32 { + t.Errorf("Field 'id': expected INT32 type, got %v", serializer.schema.Field(0).Type) + } + if serializer.schema.Field(1).Type.ID() != arrow.STRING { + t.Errorf("Field 'name': expected STRING type, got %v", serializer.schema.Field(1).Type) + } +} + +// TestSchemaOnlySerialization tests standalone schema serialization/deserialization +func TestSchemaOnlySerialization(t *testing.T) { + recordBatch := generateDummyRecordBatch1() + fmt.Printf("original schema before serialization: %v\n", recordBatch.Schema) + + ss, err := NewSerializer(recordBatch.Schema) + if err != nil { + t.Fatalf("Failed to initialize serializer: %v", err) + } + + // Serialize schema + serializedSchema, err := ss.SerializeSchema(recordBatch.Schema) + if err != nil { + t.Fatalf("Schema serialization failed: %v", err) + } + fmt.Printf("serialized schema bytes length: %d\n", len(serializedSchema)) + + // Deserialize schema + deserializedSchema, err := ss.schemaFromDisk(bytes.NewBuffer(serializedSchema)) + if err != nil { + t.Fatalf("Schema deserialization failed: %v", err) + } + + // Validate schemas match + if !deserializedSchema.Equal(recordBatch.Schema) { + t.Fatal("Deserialized schema does not match the original schema") + } + fmt.Printf("schema after serialization & deserialization: %v\n", deserializedSchema) + + // Validate field properties + for i := 0; i < recordBatch.Schema.NumFields(); i++ { + origField := recordBatch.Schema.Field(i) + deserField := deserializedSchema.Field(i) + + if origField.Name != deserField.Name { + t.Errorf("Field %d name mismatch: expected %q, got %q", i, origField.Name, deserField.Name) + } + if origField.Type.ID() != deserField.Type.ID() { + t.Errorf("Field %d type mismatch: expected %v, got %v", i, origField.Type, deserField.Type) + } + if origField.Nullable != deserField.Nullable { + t.Errorf("Field %d nullable mismatch: 
expected %v, got %v", i, origField.Nullable, deserField.Nullable) + } + } +} + +// TestSerializerSchemaValidationFails verifies schema mismatch detection +func TestSerializerSchemaValidationFails(t *testing.T) { + // RecordBatch 1 uses schema A + rb1 := generateDummyRecordBatch1() + + // RecordBatch 2 uses schema B (intentionally different) + rb2 := generateDummyRecordBatch2() + + // Initialize serializer with schema from rb1 + serializer, err := NewSerializer(rb1.Schema) + if err != nil { + t.Fatalf("Failed to initialize serializer: %v", err) + } + + // Verify serializer schema is correct before test + if !serializer.schema.Equal(rb1.Schema) { + t.Fatal("Serializer schema does not match the initial schema") + } + + // Attempt to serialize a record batch with a DIFFERENT schema + _, err = serializer.SerializeBatchColumns(rb2) + if err == nil { + t.Fatal("Expected schema validation error, but got nil") + } + + // Make sure Schema() still returns the original serializer schema + decoded := serializer.Schema() + if !decoded.Equal(rb1.Schema) { + t.Fatal("Schema() returned an incorrect schema after validation failure") + } +} + +// TestSerializeDeserializeRoundTrip performs full round-trip test with sub-tests +func TestSerializeDeserializeRoundTrip(t *testing.T) { + rb1 := generateDummyRecordBatch1() + + t.Run("Schema Serialization", func(t *testing.T) { + serializer, err := NewSerializer(rb1.Schema) + if err != nil { + t.Fatalf("Failed to create serializer: %v", err) + } + + // Serialize schema + schemaBytes, err := serializer.SerializeSchema(rb1.Schema) + if err != nil { + t.Fatalf("Failed to serialize schema: %v", err) + } + + // Deserialize schema + buf := bytes.NewBuffer(schemaBytes) + deserializedSchema, err := serializer.schemaFromDisk(buf) + if err != nil { + t.Fatalf("Failed to deserialize schema: %v", err) + } + + // Validate + if !deserializedSchema.Equal(rb1.Schema) { + t.Errorf("Schemas do not match after round-trip") + } + }) + + t.Run("Columns 
Serialization", func(t *testing.T) { + serializer, err := NewSerializer(rb1.Schema) + if err != nil { + t.Fatalf("Failed to create serializer: %v", err) + } + + // Serialize columns + columnsBytes, err := serializer.SerializeBatchColumns(rb1) + if err != nil { + t.Fatalf("Failed to serialize columns: %v", err) + } + + // Deserialize columns + buf := bytes.NewBuffer(columnsBytes) + deserializedColumns, err := serializer.DecodeRecordBatch(buf, rb1.Schema) + if err != nil { + t.Fatalf("Failed to deserialize columns: %v", err) + } + + // Validate column count + if len(deserializedColumns) != len(rb1.Columns) { + t.Fatalf("Expected %d columns, got %d", len(rb1.Columns), len(deserializedColumns)) + } + + // Validate each column + for i, origCol := range rb1.Columns { + deserCol := deserializedColumns[i] + + // Check length + if deserCol.Len() != origCol.Len() { + t.Errorf("Column %d: length mismatch: expected %d, got %d", i, origCol.Len(), deserCol.Len()) + continue + } + + // Check data type + if deserCol.DataType().ID() != origCol.DataType().ID() { + t.Errorf("Column %d: type mismatch: expected %v, got %v", i, origCol.DataType(), deserCol.DataType()) + continue + } + + // Validate data values based on type + switch origCol.DataType().ID() { + case arrow.INT32: + origData := origCol.(*array.Int32).Int32Values() + deserData := deserCol.(*array.Int32).Int32Values() + for j := 0; j < len(origData); j++ { + if origData[j] != deserData[j] { + t.Errorf("Column %d, row %d: expected %d, got %d", i, j, origData[j], deserData[j]) + } + } + case arrow.FLOAT64: + origData := origCol.(*array.Float64).Float64Values() + deserData := deserCol.(*array.Float64).Float64Values() + for j := 0; j < len(origData); j++ { + if origData[j] != deserData[j] { + t.Errorf("Column %d, row %d: expected %f, got %f", i, j, origData[j], deserData[j]) + } + } + case arrow.STRING: + origArr := origCol.(*array.String) + deserArr := deserCol.(*array.String) + for j := 0; j < origArr.Len(); j++ { + if 
origArr.Value(j) != deserArr.Value(j) { + t.Errorf("Column %d, row %d: expected %q, got %q", i, j, origArr.Value(j), deserArr.Value(j)) + } + } + } + } + }) + + t.Run("Full RecordBatch Round-Trip", func(t *testing.T) { + serializer, err := NewSerializer(rb1.Schema) + if err != nil { + t.Fatalf("Failed to create serializer: %v", err) + } + + // Create buffer for full serialization + var buf bytes.Buffer + + // Write schema + schemaBytes, err := serializer.SerializeSchema(rb1.Schema) + if err != nil { + t.Fatalf("Failed to serialize schema: %v", err) + } + buf.Write(schemaBytes) + + // Write columns + columnsBytes, err := serializer.SerializeBatchColumns(rb1) + if err != nil { + t.Fatalf("Failed to serialize columns: %v", err) + } + buf.Write(columnsBytes) + + // Read everything back + reader := bytes.NewReader(buf.Bytes()) + + // Deserialize schema + deserSchema, err := serializer.schemaFromDisk(reader) + if err != nil { + t.Fatalf("Failed to deserialize schema: %v", err) + } + + // Deserialize columns + deserColumns, err := serializer.DecodeRecordBatch(reader, deserSchema) + if err != nil { + t.Fatalf("Failed to deserialize columns: %v", err) + } + + // Create new RecordBatch + builder := NewRecordBatchBuilder() + deserBatch, err := builder.NewRecordBatch(deserSchema, deserColumns) + if err != nil { + t.Fatalf("Failed to create deserialized RecordBatch: %v", err) + } + + // Validate + if !deserBatch.Schema.Equal(rb1.Schema) { + t.Errorf("Deserialized schema does not match original") + } + if len(deserBatch.Columns) != len(rb1.Columns) { + t.Errorf("Column count mismatch: expected %d, got %d", len(rb1.Columns), len(deserBatch.Columns)) + } + // Check row count by checking first column length + if len(deserBatch.Columns) > 0 && len(rb1.Columns) > 0 { + if deserBatch.Columns[0].Len() != rb1.Columns[0].Len() { + t.Errorf("Row count mismatch: expected %d, got %d", rb1.Columns[0].Len(), deserBatch.Columns[0].Len()) + } + } + }) +} + +// TestEmptyRecordBatchSerialization 
tests edge case with zero rows +func TestEmptyRecordBatchSerialization(t *testing.T) { + rb := generateEmptyRecordBatch() + + serializer, err := NewSerializer(rb.Schema) + if err != nil { + t.Fatalf("Failed to create serializer: %v", err) + } + + // Serialize + columnsBytes, err := serializer.SerializeBatchColumns(rb) + if err != nil { + t.Fatalf("Failed to serialize empty batch: %v", err) + } + + // Deserialize + buf := bytes.NewBuffer(columnsBytes) + deserColumns, err := serializer.DecodeRecordBatch(buf, rb.Schema) + if err != nil { + t.Fatalf("Failed to deserialize empty batch: %v", err) + } + + // Validate + if len(deserColumns) != len(rb.Columns) { + t.Fatalf("Column count mismatch: expected %d, got %d", len(rb.Columns), len(deserColumns)) + } + + for i, col := range deserColumns { + if col.Len() != 0 { + t.Errorf("Column %d: expected length 0, got %d", i, col.Len()) + } + } +} + +// TestNullValuesSerialization tests nullable fields with null values +func TestNullValuesSerialization(t *testing.T) { + rb := generateNullableRecordBatch() + + serializer, err := NewSerializer(rb.Schema) + if err != nil { + t.Fatalf("Failed to create serializer: %v", err) + } + + // Serialize + columnsBytes, err := serializer.SerializeBatchColumns(rb) + if err != nil { + t.Fatalf("Failed to serialize nullable batch: %v", err) + } + + // Deserialize + buf := bytes.NewBuffer(columnsBytes) + deserColumns, err := serializer.DecodeRecordBatch(buf, rb.Schema) + if err != nil { + t.Fatalf("Failed to deserialize nullable batch: %v", err) + } + + // Validate null bitmap preservation + for i, origCol := range rb.Columns { + deserCol := deserColumns[i] + + if origCol.NullN() != deserCol.NullN() { + t.Errorf("Column %d: null count mismatch: expected %d, got %d", + i, origCol.NullN(), deserCol.NullN()) + } + + // Check each row's nullness + for j := 0; j < origCol.Len(); j++ { + if origCol.IsNull(j) != deserCol.IsNull(j) { + t.Errorf("Column %d, row %d: null status mismatch: expected %v, got 
%v", + i, j, origCol.IsNull(j), deserCol.IsNull(j)) + } + } + } +} + +// TestMultipleBatchesSerialization tests writing/reading multiple batches to same buffer +func TestMultipleBatchesSerialization(t *testing.T) { + rb1 := generateDummyRecordBatch1() + rb2 := generateDummyRecordBatch1() // Same schema, different instance + + serializer, err := NewSerializer(rb1.Schema) + if err != nil { + t.Fatalf("Failed to create serializer: %v", err) + } + + var buf bytes.Buffer + + // Write schema once + schemaBytes, err := serializer.SerializeSchema(rb1.Schema) + if err != nil { + t.Fatalf("Failed to serialize schema: %v", err) + } + buf.Write(schemaBytes) + + // Write first batch + batch1Bytes, err := serializer.SerializeBatchColumns(rb1) + if err != nil { + t.Fatalf("Failed to serialize batch 1: %v", err) + } + buf.Write(batch1Bytes) + + // Write second batch + batch2Bytes, err := serializer.SerializeBatchColumns(rb2) + if err != nil { + t.Fatalf("Failed to serialize batch 2: %v", err) + } + buf.Write(batch2Bytes) + + // Read back + reader := bytes.NewReader(buf.Bytes()) + + // Read schema + schema, err := serializer.schemaFromDisk(reader) + if err != nil { + t.Fatalf("Failed to deserialize schema: %v", err) + } + + // Read first batch + cols1, err := serializer.DecodeRecordBatch(reader, schema) + if err != nil { + t.Fatalf("Failed to deserialize batch 1: %v", err) + } + + // Read second batch + cols2, err := serializer.DecodeRecordBatch(reader, schema) + if err != nil { + t.Fatalf("Failed to deserialize batch 2: %v", err) + } + + // Validate both batches + if len(cols1) != len(rb1.Columns) { + t.Errorf("Batch 1: column count mismatch") + } + if len(cols2) != len(rb2.Columns) { + t.Errorf("Batch 2: column count mismatch") + } + + // Verify EOF after reading all batches + _, err = serializer.DecodeRecordBatch(reader, schema) + if err != io.EOF { + t.Errorf("Expected EOF after reading all batches, got: %v", err) + } +} + +// TestBasicArrowTypeFromString tests type string 
parsing +func TestBasicArrowTypeFromString(t *testing.T) { + // cover all supported branches plus an unsupported case + testCases := []struct { + typeStr string + expectType arrow.Type + expectErr bool + }{ + {"null", arrow.NULL, false}, + {"bool", arrow.BOOL, false}, + + {"int8", arrow.INT8, false}, + {"int16", arrow.INT16, false}, + {"int32", arrow.INT32, false}, + {"int64", arrow.INT64, false}, + + {"uint8", arrow.UINT8, false}, + {"uint16", arrow.UINT16, false}, + {"uint32", arrow.UINT32, false}, + {"uint64", arrow.UINT64, false}, + + {"float32", arrow.FLOAT32, false}, + {"float64", arrow.FLOAT64, false}, + + {"string", arrow.STRING, false}, + {"utf8", arrow.STRING, false}, + {"large_string", arrow.LARGE_STRING, false}, + {"large_utf8", arrow.LARGE_STRING, false}, + + {"binary", arrow.BINARY, false}, + {"large_binary", arrow.LARGE_BINARY, false}, + + // unsupported type should return an error + {"not_a_type", arrow.Type(0), true}, + } + + for _, tc := range testCases { + t.Run(tc.typeStr, func(t *testing.T) { + dt, err := BasicArrowTypeFromString(tc.typeStr) + if tc.expectErr { + if err == nil { + t.Fatalf("expected error for %q but got nil and dt=%v", tc.typeStr, dt) + } + return + } + if err != nil { + t.Fatalf("unexpected error parsing %q: %v", tc.typeStr, err) + } + if dt == nil { + t.Fatalf("type parsing returned nil for %q", tc.typeStr) + } + if dt.ID() != tc.expectType { + t.Fatalf("for %q expected type %v but got %v", tc.typeStr, tc.expectType, dt.ID()) + } + }) + } +} + +// TestSerializeRecordBatchDeepEqual writes a record batch to an in-memory buffer +// (schema + columns), reads the schema back using DeserializeSchema, then +// reads the record batch and verifies DeepEqual between original and round-tripped +// RecordBatch. 
+func TestSerializeRecordBatchDeepEqual(t *testing.T) { + rb := generateDummyRecordBatch1() + + serializer, err := NewSerializer(rb.Schema) + if err != nil { + t.Fatalf("failed to create serializer: %v", err) + } + + var buf bytes.Buffer + + // write schema first + schemaBytes, err := serializer.SerializeSchema(rb.Schema) + if err != nil { + t.Fatalf("failed to serialize schema: %v", err) + } + buf.Write(schemaBytes) + + // write columns + colsBytes, err := serializer.SerializeBatchColumns(rb) + if err != nil { + t.Fatalf("failed to serialize columns: %v", err) + } + buf.Write(colsBytes) + + // now read back + reader := bytes.NewReader(buf.Bytes()) + + // DeserializeSchema first to validate schema round-trip + deserializedSchema, err := serializer.DeserializeSchema(reader) + if err != nil { + t.Fatalf("DeserializeSchema failed: %v", err) + } + if !deserializedSchema.Equal(rb.Schema) { + t.Fatalf("schema mismatch after DeserializeSchema: expected %v got %v", rb.Schema, deserializedSchema) + } + + // DecodeRecordBatch reads columns from the same reader + arrays, err := serializer.DecodeRecordBatch(reader, deserializedSchema) + if err != nil { + t.Fatalf("DecodeRecordBatch failed: %v", err) + } + + // Build RecordBatch from deserialized arrays + builder := NewRecordBatchBuilder() + gotRB, err := builder.NewRecordBatch(deserializedSchema, arrays) + if err != nil { + t.Fatalf("failed to construct RecordBatch from deserialized arrays: %v", err) + } + + if !rb.DeepEqual(gotRB) { + t.Fatalf("original and deserialized RecordBatch differ") + } +} + +// TestDecodeRecordBatchInvalidSchema ensures DecodeRecordBatch fails when the +// provided schema does not match the serializer's schema. 
+func TestDecodeRecordBatchInvalidSchema(t *testing.T) { + rb := generateDummyRecordBatch1() + + serializer, err := NewSerializer(rb.Schema) + if err != nil { + t.Fatalf("failed to create serializer: %v", err) + } + + // Prepare buffer that contains only columns for rb + colsBytes, err := serializer.SerializeBatchColumns(rb) + if err != nil { + t.Fatalf("failed to serialize columns: %v", err) + } + reader := bytes.NewReader(colsBytes) + + // Create a deliberately different schema (swap a field type) + wrongBuilder := NewRecordBatchBuilder() + wrongBuilder.SchemaBuilder. + WithField("id", arrow.PrimitiveTypes.Int64, false). + WithField("name", arrow.BinaryTypes.String, false). + WithField("age", arrow.PrimitiveTypes.Int32, false). + WithField("salary", arrow.PrimitiveTypes.Float64, false) + wrongSchema := wrongBuilder.Schema() + + // Expect an ErrInvalidSchema because serializer.schema != wrongSchema + _, err = serializer.DecodeRecordBatch(reader, wrongSchema) + if err == nil { + t.Fatalf("expected DecodeRecordBatch to fail due to invalid schema, but it succeeded") + } +} + +// TestSerializationWithDifferentTypes tests all supported Arrow types +func TestSerializationWithDifferentTypes(t *testing.T) { + builder := NewRecordBatchBuilder() + builder.SchemaBuilder. + WithField("int32_col", arrow.PrimitiveTypes.Int32, false). + WithField("int64_col", arrow.PrimitiveTypes.Int64, false). + WithField("float32_col", arrow.PrimitiveTypes.Float32, false). + WithField("float64_col", arrow.PrimitiveTypes.Float64, false). + WithField("string_col", arrow.BinaryTypes.String, false). 
+ WithField("bool_col", arrow.FixedWidthTypes.Boolean, false) + + mem := memory.NewGoAllocator() + + int32Builder := array.NewInt32Builder(mem) + int32Builder.AppendValues([]int32{1, 2, 3}, nil) + int32Array := int32Builder.NewArray() + + int64Builder := array.NewInt64Builder(mem) + int64Builder.AppendValues([]int64{100, 200, 300}, nil) + int64Array := int64Builder.NewArray() + + float32Builder := array.NewFloat32Builder(mem) + float32Builder.AppendValues([]float32{1.1, 2.2, 3.3}, nil) + float32Array := float32Builder.NewArray() + + float64Builder := array.NewFloat64Builder(mem) + float64Builder.AppendValues([]float64{10.1, 20.2, 30.3}, nil) + float64Array := float64Builder.NewArray() + + stringBuilder := array.NewStringBuilder(mem) + stringBuilder.AppendValues([]string{"a", "b", "c"}, nil) + stringArray := stringBuilder.NewArray() + + boolBuilder := array.NewBooleanBuilder(mem) + boolBuilder.AppendValues([]bool{true, false, true}, nil) + boolArray := boolBuilder.NewArray() + + columns := []arrow.Array{ + int32Array, int64Array, float32Array, + float64Array, stringArray, boolArray, + } + + rb, err := builder.NewRecordBatch(builder.Schema(), columns) + if err != nil { + t.Fatalf("Failed to create RecordBatch: %v", err) + } + + // Serialize and deserialize + serializer, err := NewSerializer(rb.Schema) + if err != nil { + t.Fatalf("Failed to create serializer: %v", err) + } + + columnsBytes, err := serializer.SerializeBatchColumns(*rb) + if err != nil { + t.Fatalf("Failed to serialize: %v", err) + } + + buf := bytes.NewBuffer(columnsBytes) + deserColumns, err := serializer.DecodeRecordBatch(buf, rb.Schema) + if err != nil { + t.Fatalf("Failed to deserialize: %v", err) + } + + // Validate all columns + if len(deserColumns) != len(columns) { + t.Fatalf("Column count mismatch: expected %d, got %d", len(columns), len(deserColumns)) + } + + // Validate each type + for i, deserCol := range deserColumns { + if deserCol.DataType().ID() != columns[i].DataType().ID() { + 
t.Errorf("Column %d: type mismatch: expected %v, got %v", + i, columns[i].DataType(), deserCol.DataType()) + } + } +} + +func TestNullSchemaSerialize(t *testing.T) { + rb := generateNullableRecordBatch() + for i := range rb.Schema.Fields() { + fmt.Printf("is nullable? : %v\n", rb.Schema.Field(i).Nullable) + } + serializer, err := NewSerializer(rb.Schema) + if err != nil { + t.Fatalf("Failed to create serializer: %v", err) + } + + // Serialize schema + _, err = serializer.SerializeSchema(rb.Schema) + if err != nil { + t.Fatalf("Schema serialization failed: %v", err) + } } diff --git a/src/Backend/test_data/s3_source/source.json b/src/Backend/test_data/s3_source/source.json index 0cb53bb..c7cd269 100644 --- a/src/Backend/test_data/s3_source/source.json +++ b/src/Backend/test_data/s3_source/source.json @@ -13,5 +13,4 @@ "s3://my-bucket/data/file1.parquet", "s3://my-bucket/data/file2.parquet" ] - } \ No newline at end of file diff --git a/src/Contract/operation.proto b/src/Contract/operation.proto index 6c2fdd5..850057b 100644 --- a/src/Contract/operation.proto +++ b/src/Contract/operation.proto @@ -10,8 +10,8 @@ service SSOperation { // The request message containing the operation details. 
message QueryExecutionRequest { bytes substraight_logical = 1; //SS logical plan - string sql_statment = 2; // original sql statment - string id = 3; // unique id for this client + string sql_statment = 2; // original sql statement + string id = 3; // unique id for this client/request this is how we track requests and identify s3 links SourceType source = 4; // (s3 link| base64 data) } @@ -27,13 +27,13 @@ message SourceType{ } enum returnTypes{ - Success = 0; - ParseError = 1; - ExecutionError = 2; - SourceError = 3; - UploadError = 4; - OutOfMemory = 5; - UnknownError = 6; + SUCCESS = 0; + PARSE_ERROR = 1; + EXECUTION_ERROR = 2; + SOURCE_ERROR = 3; + UPLOAD_ERROR = 4; + OUT_OF_MEMORY = 5; + UNKNOWN_ERROR = 6; } message ErrorDetails{ returnTypes error_type = 1; From abb0348eeb32bca10dea8bb8a04de793e9c134ae Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Fri, 14 Nov 2025 16:02:36 -0500 Subject: [PATCH 09/10] Feat:GRPC-server. Generated grpc code for substrait server, with basic logic included --- src/Backend/opti-sql-go/go.mod | 9 +- src/Backend/opti-sql-go/go.sum | 10 + src/Backend/opti-sql-go/main.go | 3 + .../opti-sql-go/operators/serialize_test.go | 43 ++ .../opti-sql-go/substrait/operation.pb.go | 401 ++++++++++++++++++ .../substrait/operation_grpc.pb.go | 126 ++++++ src/Backend/opti-sql-go/substrait/server.go | 49 +++ .../opti-sql-go/substrait/substrait.go | 3 - src/Contract/operation.proto | 1 + 9 files changed, 641 insertions(+), 4 deletions(-) create mode 100644 src/Backend/opti-sql-go/substrait/operation.pb.go create mode 100644 src/Backend/opti-sql-go/substrait/operation_grpc.pb.go create mode 100644 src/Backend/opti-sql-go/substrait/server.go diff --git a/src/Backend/opti-sql-go/go.mod b/src/Backend/opti-sql-go/go.mod index f4c6898..c6ae03c 100644 --- a/src/Backend/opti-sql-go/go.mod +++ b/src/Backend/opti-sql-go/go.mod @@ -2,7 +2,11 @@ module opti-sql-go go 1.24.0 -require github.com/apache/arrow/go/v17 v17.0.0 +require ( + 
github.com/apache/arrow/go/v17 v17.0.0 + google.golang.org/grpc v1.63.2 + google.golang.org/protobuf v1.34.2 +) require ( github.com/goccy/go-json v0.10.3 // indirect @@ -11,8 +15,11 @@ require ( github.com/zeebo/xxh3 v1.0.2 // indirect golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect golang.org/x/mod v0.18.0 // indirect + golang.org/x/net v0.26.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.21.0 // indirect + golang.org/x/text v0.16.0 // indirect golang.org/x/tools v0.22.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect ) diff --git a/src/Backend/opti-sql-go/go.sum b/src/Backend/opti-sql-go/go.sum index 1e7a8ee..7046923 100644 --- a/src/Backend/opti-sql-go/go.sum +++ b/src/Backend/opti-sql-go/go.sum @@ -26,16 +26,26 @@ golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUF golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/tools v0.22.0 
h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= gonum.org/v1/gonum v0.15.0 h1:2lYxjRbTYyxkJxlhC+LvJIx3SsANPdRybu1tGj9/OrQ= gonum.org/v1/gonum v0.15.0/go.mod h1:xzZVBJBtS+Mz4q0Yl2LJTk+OxOg4jiXZ7qBoM0uISGo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de h1:cZGRis4/ot9uVm639a+rHCUaG0JJHEsdyzSQTMX+suY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:H4O17MA/PE9BsGx3w+a+W2VOLLD1Qf7oJneAoU6WktY= +google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM= +google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/src/Backend/opti-sql-go/main.go b/src/Backend/opti-sql-go/main.go index da29a2c..6a0958d 100644 --- a/src/Backend/opti-sql-go/main.go +++ b/src/Backend/opti-sql-go/main.go @@ -1,4 +1,7 @@ package main +import "opti-sql-go/substrait" + func main() { + substrait.Start() } diff --git a/src/Backend/opti-sql-go/operators/serialize_test.go b/src/Backend/opti-sql-go/operators/serialize_test.go index 4115bfb..76c960d 100644 --- a/src/Backend/opti-sql-go/operators/serialize_test.go +++ b/src/Backend/opti-sql-go/operators/serialize_test.go @@ -4,7 +4,9 @@ import ( "bytes" "fmt" "io" + "os" "testing" + "time" "github.com/apache/arrow/go/v17/arrow" "github.com/apache/arrow/go/v17/arrow/array" @@ -756,3 +758,44 @@ func 
TestNullSchemaSerialize(t *testing.T) { t.Fatalf("Schema serialization failed: %v", err) } } + +func TestSeralizeToDisk(t *testing.T) { + r1 := generateDummyRecordBatch1() + serializer, err := NewSerializer(r1.Schema) + if err != nil { + t.Fatalf("Failed to create serializer: %v", err) + } + randStr := time.Now().Unix() + tmpFile, err := os.Create("serialized_data_" + fmt.Sprintf("%d", randStr) + ".bin") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + defer tmpFile.Close() + schemaContent, _ := serializer.SerializeSchema(r1.Schema) + columnContent, _ := serializer.SerializeBatchColumns(r1) + schemaContent = append(schemaContent, columnContent...) + _, err = tmpFile.Write(schemaContent) + if err != nil { + t.Fatalf("Failed to write serialized data to disk: %v", err) + } + // now decode from disk + _, err = tmpFile.Seek(0, io.SeekStart) + if err != nil { + t.Fatalf("Failed to seek to start of file: %v", err) + } + deserSchema, err := serializer.DeserializeSchema(tmpFile) + if err != nil { + t.Fatalf("Failed to deserialize schema from disk: %v", err) + } + if !deserSchema.Equal(r1.Schema) { + t.Fatalf("Deserialized schema does not match original schema") + } + deserColumns, err := serializer.DecodeRecordBatch(tmpFile, deserSchema) + if err != nil { + t.Fatalf("Failed to deserialize columns from disk: %v", err) + } + if len(deserColumns) != len(r1.Columns) { + t.Fatalf("Column count mismatch after deserialization from disk") + } +} diff --git a/src/Backend/opti-sql-go/substrait/operation.pb.go b/src/Backend/opti-sql-go/substrait/operation.pb.go new file mode 100644 index 0000000..00f49f2 --- /dev/null +++ b/src/Backend/opti-sql-go/substrait/operation.pb.go @@ -0,0 +1,401 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. 
+// versions: +// protoc-gen-go v1.36.10 +// protoc v6.32.0 +// source: operation.proto + +package substrait + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ReturnTypes int32 + +const ( + ReturnTypes_SUCCESS ReturnTypes = 0 + ReturnTypes_PARSE_ERROR ReturnTypes = 1 + ReturnTypes_EXECUTION_ERROR ReturnTypes = 2 + ReturnTypes_SOURCE_ERROR ReturnTypes = 3 + ReturnTypes_UPLOAD_ERROR ReturnTypes = 4 + ReturnTypes_OUT_OF_MEMORY ReturnTypes = 5 + ReturnTypes_UNKNOWN_ERROR ReturnTypes = 6 +) + +// Enum value maps for ReturnTypes. +var ( + ReturnTypes_name = map[int32]string{ + 0: "SUCCESS", + 1: "PARSE_ERROR", + 2: "EXECUTION_ERROR", + 3: "SOURCE_ERROR", + 4: "UPLOAD_ERROR", + 5: "OUT_OF_MEMORY", + 6: "UNKNOWN_ERROR", + } + ReturnTypes_value = map[string]int32{ + "SUCCESS": 0, + "PARSE_ERROR": 1, + "EXECUTION_ERROR": 2, + "SOURCE_ERROR": 3, + "UPLOAD_ERROR": 4, + "OUT_OF_MEMORY": 5, + "UNKNOWN_ERROR": 6, + } +) + +func (x ReturnTypes) Enum() *ReturnTypes { + p := new(ReturnTypes) + *p = x + return p +} + +func (x ReturnTypes) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (ReturnTypes) Descriptor() protoreflect.EnumDescriptor { + return file_operation_proto_enumTypes[0].Descriptor() +} + +func (ReturnTypes) Type() protoreflect.EnumType { + return &file_operation_proto_enumTypes[0] +} + +func (x ReturnTypes) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use ReturnTypes.Descriptor instead. 
+func (ReturnTypes) EnumDescriptor() ([]byte, []int) { + return file_operation_proto_rawDescGZIP(), []int{0} +} + +// The request message containing the operation details. +type QueryExecutionRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + SubstraitLogical []byte `protobuf:"bytes,1,opt,name=substrait_logical,json=substraitLogical,proto3" json:"substrait_logical,omitempty"` //SS logical plan + SqlStatement string `protobuf:"bytes,2,opt,name=sql_statement,json=sqlStatement,proto3" json:"sql_statement,omitempty"` // original sql statement + Id string `protobuf:"bytes,3,opt,name=id,proto3" json:"id,omitempty"` // unique id for this client + Source *SourceType `protobuf:"bytes,4,opt,name=source,proto3" json:"source,omitempty"` // (s3 link| base64 data) + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *QueryExecutionRequest) Reset() { + *x = QueryExecutionRequest{} + mi := &file_operation_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *QueryExecutionRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*QueryExecutionRequest) ProtoMessage() {} + +func (x *QueryExecutionRequest) ProtoReflect() protoreflect.Message { + mi := &file_operation_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use QueryExecutionRequest.ProtoReflect.Descriptor instead. 
+func (*QueryExecutionRequest) Descriptor() ([]byte, []int) { + return file_operation_proto_rawDescGZIP(), []int{0} +} + +func (x *QueryExecutionRequest) GetSubstraitLogical() []byte { + if x != nil { + return x.SubstraitLogical + } + return nil +} + +func (x *QueryExecutionRequest) GetSqlStatement() string { + if x != nil { + return x.SqlStatement + } + return "" +} + +func (x *QueryExecutionRequest) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *QueryExecutionRequest) GetSource() *SourceType { + if x != nil { + return x.Source + } + return nil +} + +// The response message containing the result. +type QueryExecutionResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + S3ResultLink string `protobuf:"bytes,1,opt,name=s3_result_link,json=s3ResultLink,proto3" json:"s3_result_link,omitempty"` // s3 link to the result data + ErrorType *ErrorDetails `protobuf:"bytes,2,opt,name=error_type,json=errorType,proto3" json:"error_type,omitempty"` // error type if any + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *QueryExecutionResponse) Reset() { + *x = QueryExecutionResponse{} + mi := &file_operation_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *QueryExecutionResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*QueryExecutionResponse) ProtoMessage() {} + +func (x *QueryExecutionResponse) ProtoReflect() protoreflect.Message { + mi := &file_operation_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use QueryExecutionResponse.ProtoReflect.Descriptor instead. 
+func (*QueryExecutionResponse) Descriptor() ([]byte, []int) { + return file_operation_proto_rawDescGZIP(), []int{1} +} + +func (x *QueryExecutionResponse) GetS3ResultLink() string { + if x != nil { + return x.S3ResultLink + } + return "" +} + +func (x *QueryExecutionResponse) GetErrorType() *ErrorDetails { + if x != nil { + return x.ErrorType + } + return nil +} + +type SourceType struct { + state protoimpl.MessageState `protogen:"open.v1"` + S3Source string `protobuf:"bytes,1,opt,name=s3_source,json=s3Source,proto3" json:"s3_source,omitempty"` // s3 link to the source data + Mime string `protobuf:"bytes,2,opt,name=mime,proto3" json:"mime,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SourceType) Reset() { + *x = SourceType{} + mi := &file_operation_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SourceType) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SourceType) ProtoMessage() {} + +func (x *SourceType) ProtoReflect() protoreflect.Message { + mi := &file_operation_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SourceType.ProtoReflect.Descriptor instead. 
+func (*SourceType) Descriptor() ([]byte, []int) { + return file_operation_proto_rawDescGZIP(), []int{2} +} + +func (x *SourceType) GetS3Source() string { + if x != nil { + return x.S3Source + } + return "" +} + +func (x *SourceType) GetMime() string { + if x != nil { + return x.Mime + } + return "" +} + +type ErrorDetails struct { + state protoimpl.MessageState `protogen:"open.v1"` + ErrorType ReturnTypes `protobuf:"varint,1,opt,name=error_type,json=errorType,proto3,enum=contract.ReturnTypes" json:"error_type,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ErrorDetails) Reset() { + *x = ErrorDetails{} + mi := &file_operation_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ErrorDetails) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ErrorDetails) ProtoMessage() {} + +func (x *ErrorDetails) ProtoReflect() protoreflect.Message { + mi := &file_operation_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ErrorDetails.ProtoReflect.Descriptor instead. 
+func (*ErrorDetails) Descriptor() ([]byte, []int) { + return file_operation_proto_rawDescGZIP(), []int{3} +} + +func (x *ErrorDetails) GetErrorType() ReturnTypes { + if x != nil { + return x.ErrorType + } + return ReturnTypes_SUCCESS +} + +func (x *ErrorDetails) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +var File_operation_proto protoreflect.FileDescriptor + +const file_operation_proto_rawDesc = "" + + "\n" + + "\x0foperation.proto\x12\bcontract\"\xa7\x01\n" + + "\x15QueryExecutionRequest\x12+\n" + + "\x11substrait_logical\x18\x01 \x01(\fR\x10substraitLogical\x12#\n" + + "\rsql_statement\x18\x02 \x01(\tR\fsqlStatement\x12\x0e\n" + + "\x02id\x18\x03 \x01(\tR\x02id\x12,\n" + + "\x06source\x18\x04 \x01(\v2\x14.contract.SourceTypeR\x06source\"u\n" + + "\x16QueryExecutionResponse\x12$\n" + + "\x0es3_result_link\x18\x01 \x01(\tR\fs3ResultLink\x125\n" + + "\n" + + "error_type\x18\x02 \x01(\v2\x16.contract.ErrorDetailsR\terrorType\"=\n" + + "\n" + + "SourceType\x12\x1b\n" + + "\ts3_source\x18\x01 \x01(\tR\bs3Source\x12\x12\n" + + "\x04mime\x18\x02 \x01(\tR\x04mime\"^\n" + + "\fErrorDetails\x124\n" + + "\n" + + "error_type\x18\x01 \x01(\x0e2\x15.contract.returnTypesR\terrorType\x12\x18\n" + + "\amessage\x18\x02 \x01(\tR\amessage*\x8a\x01\n" + + "\vreturnTypes\x12\v\n" + + "\aSUCCESS\x10\x00\x12\x0f\n" + + "\vPARSE_ERROR\x10\x01\x12\x13\n" + + "\x0fEXECUTION_ERROR\x10\x02\x12\x10\n" + + "\fSOURCE_ERROR\x10\x03\x12\x10\n" + + "\fUPLOAD_ERROR\x10\x04\x12\x11\n" + + "\rOUT_OF_MEMORY\x10\x05\x12\x11\n" + + "\rUNKNOWN_ERROR\x10\x062`\n" + + "\vSSOperation\x12Q\n" + + "\fExecuteQuery\x12\x1f.contract.QueryExecutionRequest\x1a .contract.QueryExecutionResponseB+Z)opti-sql-go/Backend/opti-sql-go/substraitb\x06proto3" + +var ( + file_operation_proto_rawDescOnce sync.Once + file_operation_proto_rawDescData []byte +) + +func file_operation_proto_rawDescGZIP() []byte { + file_operation_proto_rawDescOnce.Do(func() { + file_operation_proto_rawDescData = 
protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_operation_proto_rawDesc), len(file_operation_proto_rawDesc))) + }) + return file_operation_proto_rawDescData +} + +var file_operation_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_operation_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_operation_proto_goTypes = []any{ + (ReturnTypes)(0), // 0: contract.returnTypes + (*QueryExecutionRequest)(nil), // 1: contract.QueryExecutionRequest + (*QueryExecutionResponse)(nil), // 2: contract.QueryExecutionResponse + (*SourceType)(nil), // 3: contract.SourceType + (*ErrorDetails)(nil), // 4: contract.ErrorDetails +} +var file_operation_proto_depIdxs = []int32{ + 3, // 0: contract.QueryExecutionRequest.source:type_name -> contract.SourceType + 4, // 1: contract.QueryExecutionResponse.error_type:type_name -> contract.ErrorDetails + 0, // 2: contract.ErrorDetails.error_type:type_name -> contract.returnTypes + 1, // 3: contract.SSOperation.ExecuteQuery:input_type -> contract.QueryExecutionRequest + 2, // 4: contract.SSOperation.ExecuteQuery:output_type -> contract.QueryExecutionResponse + 4, // [4:5] is the sub-list for method output_type + 3, // [3:4] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name +} + +func init() { file_operation_proto_init() } +func file_operation_proto_init() { + if File_operation_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_operation_proto_rawDesc), len(file_operation_proto_rawDesc)), + NumEnums: 1, + NumMessages: 4, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_operation_proto_goTypes, + DependencyIndexes: file_operation_proto_depIdxs, + EnumInfos: file_operation_proto_enumTypes, + MessageInfos: 
file_operation_proto_msgTypes, + }.Build() + File_operation_proto = out.File + file_operation_proto_goTypes = nil + file_operation_proto_depIdxs = nil +} diff --git a/src/Backend/opti-sql-go/substrait/operation_grpc.pb.go b/src/Backend/opti-sql-go/substrait/operation_grpc.pb.go new file mode 100644 index 0000000..3b87fab --- /dev/null +++ b/src/Backend/opti-sql-go/substrait/operation_grpc.pb.go @@ -0,0 +1,126 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v6.32.0 +// source: operation.proto + +package substrait + +import ( + context "context" + + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +//const _ = grpc.SupportPackageIsVersion9 + +const ( + SSOperation_ExecuteQuery_FullMethodName = "/contract.SSOperation/ExecuteQuery" +) + +// SSOperationClient is the client API for SSOperation service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// The service definition. +type SSOperationClient interface { + ExecuteQuery(ctx context.Context, in *QueryExecutionRequest, opts ...grpc.CallOption) (*QueryExecutionResponse, error) +} + +type sSOperationClient struct { + cc grpc.ClientConnInterface +} + +func NewSSOperationClient(cc grpc.ClientConnInterface) SSOperationClient { + return &sSOperationClient{cc} +} + +func (c *sSOperationClient) ExecuteQuery(ctx context.Context, in *QueryExecutionRequest, opts ...grpc.CallOption) (*QueryExecutionResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(QueryExecutionResponse) + err := c.cc.Invoke(ctx, SSOperation_ExecuteQuery_FullMethodName, in, out, cOpts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +// SSOperationServer is the server API for SSOperation service. +// All implementations must embed UnimplementedSSOperationServer +// for forward compatibility. +// +// The service definition. +type SSOperationServer interface { + ExecuteQuery(context.Context, *QueryExecutionRequest) (*QueryExecutionResponse, error) + mustEmbedUnimplementedSSOperationServer() +} + +// UnimplementedSSOperationServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedSSOperationServer struct{} + +func (UnimplementedSSOperationServer) ExecuteQuery(context.Context, *QueryExecutionRequest) (*QueryExecutionResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ExecuteQuery not implemented") +} +func (UnimplementedSSOperationServer) mustEmbedUnimplementedSSOperationServer() {} +func (UnimplementedSSOperationServer) testEmbeddedByValue() {} + +// UnsafeSSOperationServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to SSOperationServer will +// result in compilation errors. +type UnsafeSSOperationServer interface { + mustEmbedUnimplementedSSOperationServer() +} + +func RegisterSSOperationServer(s grpc.ServiceRegistrar, srv SSOperationServer) { + // If the following call pancis, it indicates UnimplementedSSOperationServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. 
+ if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&SSOperation_ServiceDesc, srv) +} + +func _SSOperation_ExecuteQuery_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(QueryExecutionRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(SSOperationServer).ExecuteQuery(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: SSOperation_ExecuteQuery_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(SSOperationServer).ExecuteQuery(ctx, req.(*QueryExecutionRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// SSOperation_ServiceDesc is the grpc.ServiceDesc for SSOperation service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var SSOperation_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "contract.SSOperation", + HandlerType: (*SSOperationServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "ExecuteQuery", + Handler: _SSOperation_ExecuteQuery_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "operation.proto", +} diff --git a/src/Backend/opti-sql-go/substrait/server.go b/src/Backend/opti-sql-go/substrait/server.go new file mode 100644 index 0000000..eefba9f --- /dev/null +++ b/src/Backend/opti-sql-go/substrait/server.go @@ -0,0 +1,49 @@ +package substrait + +import ( + "context" + "fmt" + "log" + "net" + + "google.golang.org/grpc" +) + +// SubstraitServer receives the substrait plan (gRPC) and sends out the optimized substrait plan (gRPC) +type SubstraitServer struct { + UnimplementedSSOperationServer +} + +func newSubstraitServer() *SubstraitServer { + return &SubstraitServer{} +} + +// ExecuteQuery implements the gRPC service method +func (s *SubstraitServer) 
ExecuteQuery(ctx context.Context, req *QueryExecutionRequest) (*QueryExecutionResponse, error) { + fmt.Printf("Received query request: logical_plan:%v\n sql:%s\n id:%v\n source: %v\n", req.SubstraitLogical, req.SqlStatement, req.Id, req.Source) + + // Placeholder response + return &QueryExecutionResponse{ + S3ResultLink: "", + ErrorType: &ErrorDetails{ + ErrorType: ReturnTypes_SUCCESS, + Message: "Query executed successfully", + }, + }, nil +} + +func Start() { + port := 8000 + listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + log.Fatalf("Failed to listen on port %d: %v", port, err) + } + + grpcServer := grpc.NewServer() + RegisterSSOperationServer(grpcServer, newSubstraitServer()) + + log.Printf("Substrait server listening on port %d", port) + if err := grpcServer.Serve(listener); err != nil { + log.Fatalf("Failed to serve: %v", err) + } +} diff --git a/src/Backend/opti-sql-go/substrait/substrait.go b/src/Backend/opti-sql-go/substrait/substrait.go index 792434f..a809ba7 100644 --- a/src/Backend/opti-sql-go/substrait/substrait.go +++ b/src/Backend/opti-sql-go/substrait/substrait.go @@ -1,4 +1 @@ package substrait - -// recieve the substraight plan (GRPC) -//send out the optmized substrait plan (GRPC)// s3 diff --git a/src/Contract/operation.proto b/src/Contract/operation.proto index 040a0b1..6e5261f 100644 --- a/src/Contract/operation.proto +++ b/src/Contract/operation.proto @@ -1,6 +1,7 @@ syntax = "proto3"; package contract; +option go_package = "opti-sql-go/Backend/opti-sql-go/substrait"; // The service definition. 
service SSOperation { From 6b744fe9470b80251ac7874d84dc9b3db99fa6f1 Mon Sep 17 00:00:00 2001 From: Richard Baah Date: Fri, 14 Nov 2025 18:06:13 -0500 Subject: [PATCH 10/10] feat(config file): add config file for frequently changing values --- src/Backend/opti-sql-go/config/config.go | 184 ++++ src/Backend/opti-sql-go/config/config_test.go | 862 ++++++++++++++++++ .../opti-sql-go/config/example_config.yml | 25 + .../config/testdata/empty_config.yaml | 1 + .../config/testdata/full_override.yaml | 25 + .../opti-sql-go/config/testdata/invalid.yaml | 4 + .../config/testdata/mixed_partial.yaml | 14 + .../config/testdata/partial_batch.yaml | 3 + .../config/testdata/partial_metrics.yaml | 4 + .../config/testdata/partial_query.yaml | 3 + .../config/testdata/partial_server.yaml | 3 + .../config/testdata/zero_values.yaml | 27 + src/Backend/opti-sql-go/go.mod | 1 + src/Backend/opti-sql-go/go.sum | 1 + src/Backend/opti-sql-go/main.go | 14 +- .../opti-sql-go/operators/record_test.go | 164 ++++ .../opti-sql-go/operators/serialize.go | 111 ++- .../opti-sql-go/operators/serialize_test.go | 12 +- src/Backend/opti-sql-go/substrait/server.go | 52 +- .../opti-sql-go/substrait/substrait_test.go | 52 +- src/Contract/operation.proto | 2 +- 21 files changed, 1520 insertions(+), 44 deletions(-) create mode 100644 src/Backend/opti-sql-go/config/config.go create mode 100644 src/Backend/opti-sql-go/config/config_test.go create mode 100644 src/Backend/opti-sql-go/config/example_config.yml create mode 100644 src/Backend/opti-sql-go/config/testdata/empty_config.yaml create mode 100644 src/Backend/opti-sql-go/config/testdata/full_override.yaml create mode 100644 src/Backend/opti-sql-go/config/testdata/invalid.yaml create mode 100644 src/Backend/opti-sql-go/config/testdata/mixed_partial.yaml create mode 100644 src/Backend/opti-sql-go/config/testdata/partial_batch.yaml create mode 100644 src/Backend/opti-sql-go/config/testdata/partial_metrics.yaml create mode 100644 
src/Backend/opti-sql-go/config/testdata/partial_query.yaml create mode 100644 src/Backend/opti-sql-go/config/testdata/partial_server.yaml create mode 100644 src/Backend/opti-sql-go/config/testdata/zero_values.yaml diff --git a/src/Backend/opti-sql-go/config/config.go b/src/Backend/opti-sql-go/config/config.go new file mode 100644 index 0000000..d5943e6 --- /dev/null +++ b/src/Backend/opti-sql-go/config/config.go @@ -0,0 +1,184 @@ +package config + +import ( + "errors" + "fmt" + "os" + "strings" + + "gopkg.in/yaml.v3" +) + +var ( + kiloByte = 1024 + megaByte = 1024 * kiloByte + gigaByte = 1024 * megaByte +) + +type Config struct { + Server serverConfig `yaml:"server"` + Batch batchConfig `yaml:"batch"` + Query queryConfig `yaml:"query"` + Metrics metricsConfig `yaml:"metrics"` +} +type serverConfig struct { + Port int `yaml:"port"` + Host string `yaml:"host"` + Timeout int `yaml:"timeout"` + MaxRequestSizeMB uint64 `yaml:"max_request_size_mb"` // max size of a file upload. passed in by grpc request +} +type batchConfig struct { + Size int `yaml:"size"` + EnableParallelRead bool `yaml:"enable_parallel_read"` + MaxMemoryBeforeSpill uint64 `yaml:"max_memory_before_spill"` + MaxFileSizeMB int `yaml:"max_file_size_mb"` // max size of a single file +} +type queryConfig struct { + // should results be cached, server side? if so how long + EnableCache bool `yaml:"enable_cache"` + CacheTTLSeconds int `yaml:"cache_ttl_seconds"` + // run queries concurrently? 
if so what the max before blocking + EnableConcurrentExecution bool `yaml:"enable_concurrent_execution"` + MaxConcurrentQueries int `yaml:"max_concurrent_queries"` // blocks after this many concurrent queries until one finishes +} +type metricsConfig struct { + EnableMetrics bool `yaml:"enable_metrics"` + MetricsPort int `yaml:"metrics_port"` + MetricsHost string `yaml:"metrics_host"` + ExportIntervalSecs int `yaml:"export_interval_secs"` + // what queries have been sent + EnableQueryStats bool `yaml:"enable_query_stats"` + // memory usage over time + EnableMemoryStats bool `yaml:"enable_memory_stats"` +} + +var configInstance *Config = &Config{ + Server: serverConfig{ + Port: 8080, + Host: "localhost", + Timeout: 30, + MaxRequestSizeMB: 15, + }, + Batch: batchConfig{ + Size: 1024 * 8, // rows per batch + EnableParallelRead: true, + MaxMemoryBeforeSpill: uint64(gigaByte) * 2, // 2GB + MaxFileSizeMB: 500, // 500MB + }, + Query: queryConfig{ + EnableCache: true, + CacheTTLSeconds: 600, // 10 minutes + EnableConcurrentExecution: true, + MaxConcurrentQueries: 2, // 2 concurrent queries + }, + Metrics: metricsConfig{ + EnableMetrics: true, + MetricsPort: 9999, + MetricsHost: "localhost", + ExportIntervalSecs: 60, // 1 minute + EnableQueryStats: true, + EnableMemoryStats: true, + }, +} + +func GetConfig() *Config { + return configInstance +} + +// Decode reads the .yaml/.yml file at filePath and merges its values into the global config instance. +func Decode(filePath string) error { + if !strings.HasSuffix(filePath, ".yaml") && !strings.HasSuffix(filePath, ".yml") { + return errors.New("file must be a .yaml or .yml file") + } + r, err := os.Open(filePath) + if err != nil { + return err + } + defer r.Close() // release the file handle on every return path + config := make(map[string]interface{}) + decoder := yaml.NewDecoder(r) + if err := decoder.Decode(config); err != nil { + return fmt.Errorf("failed to decode config: %w", err) + } + mergeConfig(configInstance, config) + return nil +} +func mergeConfig(dst *Config, src map[string]interface{}) { + //
============================= + // SERVER + // ============================= + if server, ok := src["server"].(map[string]interface{}); ok { + if v, ok := server["port"].(int); ok { + dst.Server.Port = v + } + if v, ok := server["host"].(string); ok { + dst.Server.Host = v + } + if v, ok := server["timeout"].(int); ok { + dst.Server.Timeout = v + } + if v, ok := server["max_request_size_mb"].(int); ok { + dst.Server.MaxRequestSizeMB = uint64(v) + } + } + + // ============================= + // BATCH + // ============================= + if batch, ok := src["batch"].(map[string]interface{}); ok { + if v, ok := batch["size"].(int); ok { + dst.Batch.Size = v + } + if v, ok := batch["enable_parallel_read"].(bool); ok { + dst.Batch.EnableParallelRead = v + } + if v, ok := batch["max_memory_before_spill"].(int); ok { + dst.Batch.MaxMemoryBeforeSpill = uint64(v) + } + if v, ok := batch["max_file_size_mb"].(int); ok { + dst.Batch.MaxFileSizeMB = v + } + } + + // ============================= + // QUERY + // ============================= + if query, ok := src["query"].(map[string]interface{}); ok { + if v, ok := query["enable_cache"].(bool); ok { + dst.Query.EnableCache = v + } + if v, ok := query["cache_ttl_seconds"].(int); ok { + dst.Query.CacheTTLSeconds = v + } + if v, ok := query["enable_concurrent_execution"].(bool); ok { + dst.Query.EnableConcurrentExecution = v + } + if v, ok := query["max_concurrent_queries"].(int); ok { + dst.Query.MaxConcurrentQueries = v + } + } + + // ============================= + // METRICS + // ============================= + if metrics, ok := src["metrics"].(map[string]interface{}); ok { + if v, ok := metrics["enable_metrics"].(bool); ok { + dst.Metrics.EnableMetrics = v + } + if v, ok := metrics["metrics_port"].(int); ok { + dst.Metrics.MetricsPort = v + } + if v, ok := metrics["metrics_host"].(string); ok { + dst.Metrics.MetricsHost = v + } + if v, ok := metrics["export_interval_secs"].(int); ok { + dst.Metrics.ExportIntervalSecs = v + } 
+ if v, ok := metrics["enable_query_stats"].(bool); ok { + dst.Metrics.EnableQueryStats = v + } + if v, ok := metrics["enable_memory_stats"].(bool); ok { + dst.Metrics.EnableMemoryStats = v + } + } +} diff --git a/src/Backend/opti-sql-go/config/config_test.go b/src/Backend/opti-sql-go/config/config_test.go new file mode 100644 index 0000000..052db1f --- /dev/null +++ b/src/Backend/opti-sql-go/config/config_test.go @@ -0,0 +1,862 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +// resetConfig resets the singleton to defaults between tests +func resetConfig() { + configInstance = &Config{ + Server: serverConfig{ + Port: 8080, + Host: "localhost", + Timeout: 30, + MaxRequestSizeMB: 15, + }, + Batch: batchConfig{ + Size: 8192, + EnableParallelRead: true, + MaxMemoryBeforeSpill: 2147483648, + MaxFileSizeMB: 500, + }, + Query: queryConfig{ + EnableCache: true, + CacheTTLSeconds: 600, + EnableConcurrentExecution: true, + MaxConcurrentQueries: 2, + }, + Metrics: metricsConfig{ + EnableMetrics: true, + MetricsPort: 9999, + MetricsHost: "localhost", + ExportIntervalSecs: 60, + EnableQueryStats: true, + EnableMemoryStats: true, + }, + } +} + +// TestGetConfig tests the singleton pattern +func TestGetConfig(t *testing.T) { + resetConfig() + + config1 := GetConfig() + config2 := GetConfig() + + // Should return the same instance + if config1 != config2 { + t.Error("GetConfig should return the same singleton instance") + } + + // Verify default values + if config1.Server.Port != 8080 { + t.Errorf("Expected default port 8080, got %d", config1.Server.Port) + } + if config1.Server.Host != "localhost" { + t.Errorf("Expected default host 'localhost', got %s", config1.Server.Host) + } +} + +// TestDecodeInvalidExtension tests file extension validation +func TestDecodeInvalidExtension(t *testing.T) { + resetConfig() + + tests := []struct { + name string + filename string + }{ + {"JSON extension", "config.json"}, + {"TXT extension", "config.txt"}, + {"No 
extension", "config"}, + {"Wrong extension", "config.xml"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := Decode(tt.filename) + if err == nil { + t.Fatalf("Expected error for %s, got nil", tt.filename) // stop here: err.Error() below would panic on a nil error + } + expectedMsg := "file must be a .yaml or .yml file" + if err.Error() != expectedMsg { + t.Errorf("Expected error '%s', got '%s'", expectedMsg, err.Error()) + } + }) + } +} + +// TestDecodeMissingFile tests handling of non-existent files +func TestDecodeMissingFile(t *testing.T) { + resetConfig() + + err := Decode("nonexistent.yaml") + if err == nil { + t.Error("Expected error for missing file, got nil") + } +} + +// TestDecodeInvalidYAML tests handling of malformed YAML +func TestDecodeInvalidYAML(t *testing.T) { + resetConfig() + + invalidPath := filepath.Join("testdata", "invalid.yaml") + err := Decode(invalidPath) + if err == nil { + t.Error("Expected error for invalid YAML, got nil") + } +} + +// TestDecodeEmptyConfig tests that decoding an empty file fails and leaves all defaults untouched +func TestDecodeEmptyConfig(t *testing.T) { + resetConfig() + + emptyPath := filepath.Join("testdata", "empty_config.yaml") + err := Decode(emptyPath) + if err == nil { + t.Fatal("Expected an error (EOF) when decoding an empty config, got nil") + } + + config := GetConfig() + + // Verify all defaults are preserved + if config.Server.Port != 8080 { + t.Errorf("Expected port 8080, got %d", config.Server.Port) + } + if config.Server.Host != "localhost" { + t.Errorf("Expected host 'localhost', got %s", config.Server.Host) + } + if config.Server.Timeout != 30 { + t.Errorf("Expected timeout 30, got %d", config.Server.Timeout) + } + if config.Server.MaxRequestSizeMB != 15 { + t.Errorf("Expected max request size 15, got %d", config.Server.MaxRequestSizeMB) + } + + if config.Batch.Size != 8192 { + t.Errorf("Expected batch size 8192, got %d", config.Batch.Size) + } + if !config.Batch.EnableParallelRead { + t.Error("Expected enable parallel read true") + } + if
config.Batch.MaxMemoryBeforeSpill != 2147483648 { + t.Errorf("Expected max memory 2147483648, got %d", config.Batch.MaxMemoryBeforeSpill) + } + if config.Batch.MaxFileSizeMB != 500 { + t.Errorf("Expected max file size 500, got %d", config.Batch.MaxFileSizeMB) + } + + if !config.Query.EnableCache { + t.Error("Expected enable cache true") + } + if config.Query.CacheTTLSeconds != 600 { + t.Errorf("Expected cache TTL 600, got %d", config.Query.CacheTTLSeconds) + } + if !config.Query.EnableConcurrentExecution { + t.Error("Expected enable concurrent execution true") + } + if config.Query.MaxConcurrentQueries != 2 { + t.Errorf("Expected max concurrent queries 2, got %d", config.Query.MaxConcurrentQueries) + } + + if !config.Metrics.EnableMetrics { + t.Error("Expected enable metrics true") + } + if config.Metrics.MetricsPort != 9999 { + t.Errorf("Expected metrics port 9999, got %d", config.Metrics.MetricsPort) + } + if config.Metrics.MetricsHost != "localhost" { + t.Errorf("Expected metrics host 'localhost', got %s", config.Metrics.MetricsHost) + } + if config.Metrics.ExportIntervalSecs != 60 { + t.Errorf("Expected export interval 60, got %d", config.Metrics.ExportIntervalSecs) + } + if !config.Metrics.EnableQueryStats { + t.Error("Expected enable query stats true") + } + if !config.Metrics.EnableMemoryStats { + t.Error("Expected enable memory stats true") + } +} + +// TestDecodeFullOverride tests that all values can be overridden +func TestDecodeFullOverride(t *testing.T) { + resetConfig() + + fullPath := filepath.Join("testdata", "full_override.yaml") + err := Decode(fullPath) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + config := GetConfig() + + // Verify all values are overridden + if config.Server.Port != 9090 { + t.Errorf("Expected port 9090, got %d", config.Server.Port) + } + if config.Server.Host != "0.0.0.0" { + t.Errorf("Expected host '0.0.0.0', got %s", config.Server.Host) + } + if config.Server.Timeout != 60 { + t.Errorf("Expected timeout 
60, got %d", config.Server.Timeout) + } + if config.Server.MaxRequestSizeMB != 25 { + t.Errorf("Expected max request size 25, got %d", config.Server.MaxRequestSizeMB) + } + + if config.Batch.Size != 16384 { + t.Errorf("Expected batch size 16384, got %d", config.Batch.Size) + } + if config.Batch.EnableParallelRead { + t.Error("Expected enable parallel read false") + } + if config.Batch.MaxMemoryBeforeSpill != 4294967296 { + t.Errorf("Expected max memory 4294967296, got %d", config.Batch.MaxMemoryBeforeSpill) + } + if config.Batch.MaxFileSizeMB != 1000 { + t.Errorf("Expected max file size 1000, got %d", config.Batch.MaxFileSizeMB) + } + + if config.Query.EnableCache { + t.Error("Expected enable cache false") + } + if config.Query.CacheTTLSeconds != 1200 { + t.Errorf("Expected cache TTL 1200, got %d", config.Query.CacheTTLSeconds) + } + if config.Query.EnableConcurrentExecution { + t.Error("Expected enable concurrent execution false") + } + if config.Query.MaxConcurrentQueries != 4 { + t.Errorf("Expected max concurrent queries 4, got %d", config.Query.MaxConcurrentQueries) + } + + if config.Metrics.EnableMetrics { + t.Error("Expected enable metrics false") + } + if config.Metrics.MetricsPort != 8888 { + t.Errorf("Expected metrics port 8888, got %d", config.Metrics.MetricsPort) + } + if config.Metrics.MetricsHost != "127.0.0.1" { + t.Errorf("Expected metrics host '127.0.0.1', got %s", config.Metrics.MetricsHost) + } + if config.Metrics.ExportIntervalSecs != 120 { + t.Errorf("Expected export interval 120, got %d", config.Metrics.ExportIntervalSecs) + } + if config.Metrics.EnableQueryStats { + t.Error("Expected enable query stats false") + } + if config.Metrics.EnableMemoryStats { + t.Error("Expected enable memory stats false") + } +} + +// TestMergeConfigServerPartial tests partial server config merge +func TestMergeConfigServerPartial(t *testing.T) { + resetConfig() + + partialPath := filepath.Join("testdata", "partial_server.yaml") + err := Decode(partialPath) + if 
err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + config := GetConfig() + + // Verify overridden values + if config.Server.Port != 3000 { + t.Errorf("Expected port 3000, got %d", config.Server.Port) + } + if config.Server.Host != "192.168.1.1" { + t.Errorf("Expected host '192.168.1.1', got %s", config.Server.Host) + } + + // Verify preserved defaults + if config.Server.Timeout != 30 { + t.Errorf("Expected timeout 30 (default), got %d", config.Server.Timeout) + } + if config.Server.MaxRequestSizeMB != 15 { + t.Errorf("Expected max request size 15 (default), got %d", config.Server.MaxRequestSizeMB) + } + + // Verify other sections untouched + if config.Batch.Size != 8192 { + t.Errorf("Expected batch size 8192 (default), got %d", config.Batch.Size) + } +} + +// TestMergeConfigBatchPartial tests partial batch config merge +func TestMergeConfigBatchPartial(t *testing.T) { + resetConfig() + + partialPath := filepath.Join("testdata", "partial_batch.yaml") + err := Decode(partialPath) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + config := GetConfig() + + // Verify overridden values + if config.Batch.Size != 4096 { + t.Errorf("Expected batch size 4096, got %d", config.Batch.Size) + } + if config.Batch.EnableParallelRead { + t.Error("Expected enable parallel read false (overridden)") + } + + // Verify preserved defaults + if config.Batch.MaxMemoryBeforeSpill != 2147483648 { + t.Errorf("Expected max memory 2147483648 (default), got %d", config.Batch.MaxMemoryBeforeSpill) + } + if config.Batch.MaxFileSizeMB != 500 { + t.Errorf("Expected max file size 500 (default), got %d", config.Batch.MaxFileSizeMB) + } + + // Verify other sections untouched + if config.Server.Port != 8080 { + t.Errorf("Expected port 8080 (default), got %d", config.Server.Port) + } +} + +// TestMergeConfigQueryPartial tests partial query config merge +func TestMergeConfigQueryPartial(t *testing.T) { + resetConfig() + + partialPath := filepath.Join("testdata", 
"partial_query.yaml") + err := Decode(partialPath) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + config := GetConfig() + + // Verify overridden values + if config.Query.EnableCache { + t.Error("Expected enable cache false (overridden)") + } + if config.Query.MaxConcurrentQueries != 8 { + t.Errorf("Expected max concurrent queries 8, got %d", config.Query.MaxConcurrentQueries) + } + + // Verify preserved defaults + if config.Query.CacheTTLSeconds != 600 { + t.Errorf("Expected cache TTL 600 (default), got %d", config.Query.CacheTTLSeconds) + } + if !config.Query.EnableConcurrentExecution { + t.Error("Expected enable concurrent execution true (default)") + } + + // Verify other sections untouched + if config.Server.Port != 8080 { + t.Errorf("Expected port 8080 (default), got %d", config.Server.Port) + } +} + +// TestMergeConfigMetricsPartial tests partial metrics config merge +func TestMergeConfigMetricsPartial(t *testing.T) { + resetConfig() + + partialPath := filepath.Join("testdata", "partial_metrics.yaml") + err := Decode(partialPath) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + config := GetConfig() + + // Verify overridden values + if config.Metrics.EnableMetrics { + t.Error("Expected enable metrics false (overridden)") + } + if config.Metrics.MetricsPort != 7777 { + t.Errorf("Expected metrics port 7777, got %d", config.Metrics.MetricsPort) + } + if config.Metrics.ExportIntervalSecs != 30 { + t.Errorf("Expected export interval 30, got %d", config.Metrics.ExportIntervalSecs) + } + + // Verify preserved defaults + if config.Metrics.MetricsHost != "localhost" { + t.Errorf("Expected metrics host 'localhost' (default), got %s", config.Metrics.MetricsHost) + } + if !config.Metrics.EnableQueryStats { + t.Error("Expected enable query stats true (default)") + } + if !config.Metrics.EnableMemoryStats { + t.Error("Expected enable memory stats true (default)") + } + + // Verify other sections untouched + if config.Server.Port != 8080 
{ + t.Errorf("Expected port 8080 (default), got %d", config.Server.Port) + } +} + +// TestMergeConfigMixedPartial tests a realistic mixed partial config +func TestMergeConfigMixedPartial(t *testing.T) { + resetConfig() + + mixedPath := filepath.Join("testdata", "mixed_partial.yaml") + err := Decode(mixedPath) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + config := GetConfig() + + // Server: only timeout should be overridden + if config.Server.Port != 8080 { + t.Errorf("Expected port 8080 (default), got %d", config.Server.Port) + } + if config.Server.Host != "localhost" { + t.Errorf("Expected host 'localhost' (default), got %s", config.Server.Host) + } + if config.Server.Timeout != 45 { + t.Errorf("Expected timeout 45 (overridden), got %d", config.Server.Timeout) + } + if config.Server.MaxRequestSizeMB != 15 { + t.Errorf("Expected max request size 15 (default), got %d", config.Server.MaxRequestSizeMB) + } + + // Batch: only max memory should be overridden + if config.Batch.Size != 8192 { + t.Errorf("Expected batch size 8192 (default), got %d", config.Batch.Size) + } + if !config.Batch.EnableParallelRead { + t.Error("Expected enable parallel read true (default)") + } + if config.Batch.MaxMemoryBeforeSpill != 3221225472 { + t.Errorf("Expected max memory 3221225472 (overridden), got %d", config.Batch.MaxMemoryBeforeSpill) + } + if config.Batch.MaxFileSizeMB != 500 { + t.Errorf("Expected max file size 500 (default), got %d", config.Batch.MaxFileSizeMB) + } + + // Query: cache TTL and concurrent execution should be overridden + if !config.Query.EnableCache { + t.Error("Expected enable cache true (default)") + } + if config.Query.CacheTTLSeconds != 900 { + t.Errorf("Expected cache TTL 900 (overridden), got %d", config.Query.CacheTTLSeconds) + } + if config.Query.EnableConcurrentExecution { + t.Error("Expected enable concurrent execution false (overridden)") + } + if config.Query.MaxConcurrentQueries != 2 { + t.Errorf("Expected max concurrent queries 2 
(default), got %d", config.Query.MaxConcurrentQueries) + } + + // Metrics: host and query stats should be overridden + if !config.Metrics.EnableMetrics { + t.Error("Expected enable metrics true (default)") + } + if config.Metrics.MetricsPort != 9999 { + t.Errorf("Expected metrics port 9999 (default), got %d", config.Metrics.MetricsPort) + } + if config.Metrics.MetricsHost != "metrics.example.com" { + t.Errorf("Expected metrics host 'metrics.example.com' (overridden), got %s", config.Metrics.MetricsHost) + } + if config.Metrics.ExportIntervalSecs != 60 { + t.Errorf("Expected export interval 60 (default), got %d", config.Metrics.ExportIntervalSecs) + } + if config.Metrics.EnableQueryStats { + t.Error("Expected enable query stats false (overridden)") + } + if !config.Metrics.EnableMemoryStats { + t.Error("Expected enable memory stats true (default)") + } +} + +// TestMergeConfigAllServerFields tests every server field individually +func TestMergeConfigAllServerFields(t *testing.T) { + tests := []struct { + name string + yaml string + checkFn func(*Config) error + }{ + { + name: "Port only", + yaml: "server:\n port: 5000\n", + checkFn: func(c *Config) error { + if c.Server.Port != 5000 { + t.Errorf("Expected port 5000, got %d", c.Server.Port) + } + if c.Server.Host != "localhost" { + t.Errorf("Expected host 'localhost', got %s", c.Server.Host) + } + return nil + }, + }, + { + name: "Host only", + yaml: "server:\n host: \"example.com\"\n", + checkFn: func(c *Config) error { + if c.Server.Port != 8080 { + t.Errorf("Expected port 8080, got %d", c.Server.Port) + } + if c.Server.Host != "example.com" { + t.Errorf("Expected host 'example.com', got %s", c.Server.Host) + } + return nil + }, + }, + { + name: "Timeout only", + yaml: "server:\n timeout: 120\n", + checkFn: func(c *Config) error { + if c.Server.Timeout != 120 { + t.Errorf("Expected timeout 120, got %d", c.Server.Timeout) + } + if c.Server.Port != 8080 { + t.Errorf("Expected port 8080, got %d", c.Server.Port) + } + 
return nil + }, + }, + { + name: "MaxRequestSizeMB only", + yaml: "server:\n max_request_size_mb: 50\n", + checkFn: func(c *Config) error { + if c.Server.MaxRequestSizeMB != 50 { + t.Errorf("Expected max request size 50, got %d", c.Server.MaxRequestSizeMB) + } + if c.Server.Port != 8080 { + t.Errorf("Expected port 8080, got %d", c.Server.Port) + } + return nil + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetConfig() + + // Create temporary file + tmpFile, err := os.CreateTemp("", "test_*.yaml") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + if _, err := tmpFile.WriteString(tt.yaml); err != nil { + t.Fatalf("Failed to write temp file: %v", err) + } + if err := tmpFile.Close(); err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } + + if err := Decode(tmpFile.Name()); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if err := tt.checkFn(GetConfig()); err != nil { + t.Errorf("Check function failed: %v", err) + } + if err := os.Remove(tmpFile.Name()); err != nil { + t.Fatalf("Failed to remove temp file: %v", err) + } + }) + } +} + +// TestMergeConfigAllBatchFields tests every batch field individually +func TestMergeConfigAllBatchFields(t *testing.T) { + tests := []struct { + name string + yaml string + checkFn func(*Config) error + }{ + { + name: "Size only", + yaml: "batch:\n size: 1024\n", + checkFn: func(c *Config) error { + if c.Batch.Size != 1024 { + t.Errorf("Expected size 1024, got %d", c.Batch.Size) + } + if !c.Batch.EnableParallelRead { + t.Error("Expected enable parallel read true") + } + return nil + }, + }, + { + name: "EnableParallelRead only", + yaml: "batch:\n enable_parallel_read: false\n", + checkFn: func(c *Config) error { + if c.Batch.EnableParallelRead { + t.Error("Expected enable parallel read false") + } + if c.Batch.Size != 8192 { + t.Errorf("Expected size 8192, got %d", c.Batch.Size) + } + return nil + }, + }, + { + name: "MaxMemoryBeforeSplill only", 
+ yaml: "batch:\n max_memory_before_spill: 1073741824\n", + checkFn: func(c *Config) error { + if c.Batch.MaxMemoryBeforeSpill != 1073741824 { + t.Errorf("Expected max memory 1073741824, got %d", c.Batch.MaxMemoryBeforeSpill) + } + if c.Batch.Size != 8192 { + t.Errorf("Expected size 8192, got %d", c.Batch.Size) + } + return nil + }, + }, + { + name: "MaxFileSizeMB only", + yaml: "batch:\n max_file_size_mb: 250\n", + checkFn: func(c *Config) error { + if c.Batch.MaxFileSizeMB != 250 { + t.Errorf("Expected max file size 250, got %d", c.Batch.MaxFileSizeMB) + } + if c.Batch.Size != 8192 { + t.Errorf("Expected size 8192, got %d", c.Batch.Size) + } + return nil + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetConfig() + + tmpFile, err := os.CreateTemp("", "test_*.yaml") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + if _, err := tmpFile.WriteString(tt.yaml); err != nil { + t.Fatalf("Failed to write temp file: %v", err) + } + if err := tmpFile.Close(); err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } + + if err := Decode(tmpFile.Name()); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if err := tt.checkFn(GetConfig()); err != nil { + t.Errorf("Check function failed: %v", err) + } + + if err := os.Remove(tmpFile.Name()); err != nil { + t.Fatalf("Failed to remove temp file: %v", err) + } + }) + } +} + +// TestMergeConfigAllQueryFields tests every query field individually +func TestMergeConfigAllQueryFields(t *testing.T) { + tests := []struct { + name string + yaml string + checkFn func(*Config) error + }{ + { + name: "EnableCache only", + yaml: "query:\n enable_cache: false\n", + checkFn: func(c *Config) error { + if c.Query.EnableCache { + t.Error("Expected enable cache false") + } + if c.Query.CacheTTLSeconds != 600 { + t.Errorf("Expected cache TTL 600, got %d", c.Query.CacheTTLSeconds) + } + return nil + }, + }, + { + name: "CacheTTLSeconds only", + yaml: "query:\n 
cache_ttl_seconds: 300\n", + checkFn: func(c *Config) error { + if c.Query.CacheTTLSeconds != 300 { + t.Errorf("Expected cache TTL 300, got %d", c.Query.CacheTTLSeconds) + } + if !c.Query.EnableCache { + t.Error("Expected enable cache true") + } + return nil + }, + }, + { + name: "EnableConcurrentExecution only", + yaml: "query:\n enable_concurrent_execution: false\n", + checkFn: func(c *Config) error { + if c.Query.EnableConcurrentExecution { + t.Error("Expected enable concurrent execution false") + } + if c.Query.MaxConcurrentQueries != 2 { + t.Errorf("Expected max concurrent queries 2, got %d", c.Query.MaxConcurrentQueries) + } + return nil + }, + }, + { + name: "MaxConcurrentQueries only", + yaml: "query:\n max_concurrent_queries: 10\n", + checkFn: func(c *Config) error { + if c.Query.MaxConcurrentQueries != 10 { + t.Errorf("Expected max concurrent queries 10, got %d", c.Query.MaxConcurrentQueries) + } + if !c.Query.EnableConcurrentExecution { + t.Error("Expected enable concurrent execution true") + } + return nil + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetConfig() + + tmpFile, err := os.CreateTemp("", "test_*.yaml") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + if _, err := tmpFile.WriteString(tt.yaml); err != nil { + t.Fatalf("Failed to write temp file: %v", err) + } + if err := tmpFile.Close(); err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } + + if err := Decode(tmpFile.Name()); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if err := tt.checkFn(GetConfig()); err != nil { + t.Errorf("Check function failed: %v", err) + } + if err := os.Remove(tmpFile.Name()); err != nil { + t.Fatalf("Failed to remove temp file: %v", err) + } + }) + } +} + +// TestMergeConfigAllMetricsFields tests every metrics field individually +func TestMergeConfigAllMetricsFields(t *testing.T) { + tests := []struct { + name string + yaml string + checkFn func(*Config) error + }{ + { 
+ name: "EnableMetrics only", + yaml: "metrics:\n enable_metrics: false\n", + checkFn: func(c *Config) error { + if c.Metrics.EnableMetrics { + t.Error("Expected enable metrics false") + } + if c.Metrics.MetricsPort != 9999 { + t.Errorf("Expected metrics port 9999, got %d", c.Metrics.MetricsPort) + } + return nil + }, + }, + { + name: "MetricsPort only", + yaml: "metrics:\n metrics_port: 5555\n", + checkFn: func(c *Config) error { + if c.Metrics.MetricsPort != 5555 { + t.Errorf("Expected metrics port 5555, got %d", c.Metrics.MetricsPort) + } + if !c.Metrics.EnableMetrics { + t.Error("Expected enable metrics true") + } + return nil + }, + }, + { + name: "MetricsHost only", + yaml: "metrics:\n metrics_host: \"0.0.0.0\"\n", + checkFn: func(c *Config) error { + if c.Metrics.MetricsHost != "0.0.0.0" { + t.Errorf("Expected metrics host '0.0.0.0', got %s", c.Metrics.MetricsHost) + } + if c.Metrics.MetricsPort != 9999 { + t.Errorf("Expected metrics port 9999, got %d", c.Metrics.MetricsPort) + } + return nil + }, + }, + { + name: "ExportIntervalSecs only", + yaml: "metrics:\n export_interval_secs: 15\n", + checkFn: func(c *Config) error { + if c.Metrics.ExportIntervalSecs != 15 { + t.Errorf("Expected export interval 15, got %d", c.Metrics.ExportIntervalSecs) + } + if c.Metrics.MetricsPort != 9999 { + t.Errorf("Expected metrics port 9999, got %d", c.Metrics.MetricsPort) + } + return nil + }, + }, + { + name: "EnableQueryStats only", + yaml: "metrics:\n enable_query_stats: false\n", + checkFn: func(c *Config) error { + if c.Metrics.EnableQueryStats { + t.Error("Expected enable query stats false") + } + if !c.Metrics.EnableMemoryStats { + t.Error("Expected enable memory stats true") + } + return nil + }, + }, + { + name: "EnableMemoryStats only", + yaml: "metrics:\n enable_memory_stats: false\n", + checkFn: func(c *Config) error { + if c.Metrics.EnableMemoryStats { + t.Error("Expected enable memory stats false") + } + if !c.Metrics.EnableQueryStats { + t.Error("Expected enable 
query stats true") + } + return nil + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetConfig() + + tmpFile, err := os.CreateTemp("", "test_*.yaml") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + if _, err := tmpFile.WriteString(tt.yaml); err != nil { + t.Fatalf("Failed to write temp file: %v", err) + } + if err := tmpFile.Close(); err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } + + if err := Decode(tmpFile.Name()); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if err := tt.checkFn(GetConfig()); err != nil { + t.Errorf("Check function failed: %v", err) + } + if err := os.Remove(tmpFile.Name()); err != nil { + t.Fatalf("Failed to remove temp file: %v", err) + } + }) + } +} diff --git a/src/Backend/opti-sql-go/config/example_config.yml b/src/Backend/opti-sql-go/config/example_config.yml new file mode 100644 index 0000000..d99353a --- /dev/null +++ b/src/Backend/opti-sql-go/config/example_config.yml @@ -0,0 +1,25 @@ +server: + port: 9090 + host: "0.0.0.0" + timeout: 45 + maxRequestSizeMB: 25 + +batch: + size: 16384 # 16k rows per batch + enableParallelRead: true + maxMemoryBeforeSplill: 4294967296 # 4GB + maxFileSizeMB: 256 + +query: + enableCache: true + cacheTTLSeconds: 300 # 5 mins + enableConcurrentExecution: true + maxConcurrentQueries: 4 + +metrics: + enableMetrics: true + metricsPort: 9100 + metricsHost: "localhost" + exportIntervalSecs: 30 + enableQueryStats: true + enableMemoryStats: true diff --git a/src/Backend/opti-sql-go/config/testdata/empty_config.yaml b/src/Backend/opti-sql-go/config/testdata/empty_config.yaml new file mode 100644 index 0000000..8963338 --- /dev/null +++ b/src/Backend/opti-sql-go/config/testdata/empty_config.yaml @@ -0,0 +1 @@ +# Empty config - should use all defaults diff --git a/src/Backend/opti-sql-go/config/testdata/full_override.yaml b/src/Backend/opti-sql-go/config/testdata/full_override.yaml new file mode 100644 index 
0000000..0650e05 --- /dev/null +++ b/src/Backend/opti-sql-go/config/testdata/full_override.yaml @@ -0,0 +1,25 @@ +server: + port: 9090 + host: "0.0.0.0" + timeout: 60 + max_request_size_mb: 25 + +batch: + size: 16384 + enable_parallel_read: false + max_memory_before_spill: 4294967296 + max_file_size_mb: 1000 + +query: + enable_cache: false + cache_ttl_seconds: 1200 + enable_concurrent_execution: false + max_concurrent_queries: 4 + +metrics: + enable_metrics: false + metrics_port: 8888 + metrics_host: "127.0.0.1" + export_interval_secs: 120 + enable_query_stats: false + enable_memory_stats: false diff --git a/src/Backend/opti-sql-go/config/testdata/invalid.yaml b/src/Backend/opti-sql-go/config/testdata/invalid.yaml new file mode 100644 index 0000000..650d297 --- /dev/null +++ b/src/Backend/opti-sql-go/config/testdata/invalid.yaml @@ -0,0 +1,4 @@ +server: + port: "not a number" + invalid_field: true +[this is not valid yaml diff --git a/src/Backend/opti-sql-go/config/testdata/mixed_partial.yaml b/src/Backend/opti-sql-go/config/testdata/mixed_partial.yaml new file mode 100644 index 0000000..4552f4f --- /dev/null +++ b/src/Backend/opti-sql-go/config/testdata/mixed_partial.yaml @@ -0,0 +1,14 @@ +# Mix of different sections with some fields +server: + timeout: 45 + +batch: + max_memory_before_spill: 3221225472 + +query: + cache_ttl_seconds: 900 + enable_concurrent_execution: false + +metrics: + metrics_host: "metrics.example.com" + enable_query_stats: false diff --git a/src/Backend/opti-sql-go/config/testdata/partial_batch.yaml b/src/Backend/opti-sql-go/config/testdata/partial_batch.yaml new file mode 100644 index 0000000..a88d937 --- /dev/null +++ b/src/Backend/opti-sql-go/config/testdata/partial_batch.yaml @@ -0,0 +1,3 @@ +batch: + size: 4096 + enable_parallel_read: false diff --git a/src/Backend/opti-sql-go/config/testdata/partial_metrics.yaml b/src/Backend/opti-sql-go/config/testdata/partial_metrics.yaml new file mode 100644 index 0000000..13b716b --- /dev/null +++ 
b/src/Backend/opti-sql-go/config/testdata/partial_metrics.yaml @@ -0,0 +1,4 @@ +metrics: + enable_metrics: false + metrics_port: 7777 + export_interval_secs: 30 diff --git a/src/Backend/opti-sql-go/config/testdata/partial_query.yaml b/src/Backend/opti-sql-go/config/testdata/partial_query.yaml new file mode 100644 index 0000000..b596170 --- /dev/null +++ b/src/Backend/opti-sql-go/config/testdata/partial_query.yaml @@ -0,0 +1,3 @@ +query: + enable_cache: false + max_concurrent_queries: 8 diff --git a/src/Backend/opti-sql-go/config/testdata/partial_server.yaml b/src/Backend/opti-sql-go/config/testdata/partial_server.yaml new file mode 100644 index 0000000..346220b --- /dev/null +++ b/src/Backend/opti-sql-go/config/testdata/partial_server.yaml @@ -0,0 +1,3 @@ +server: + port: 3000 + host: "192.168.1.1" diff --git a/src/Backend/opti-sql-go/config/testdata/zero_values.yaml b/src/Backend/opti-sql-go/config/testdata/zero_values.yaml new file mode 100644 index 0000000..2599968 --- /dev/null +++ b/src/Backend/opti-sql-go/config/testdata/zero_values.yaml @@ -0,0 +1,27 @@ +# Test zero values - should NOT override defaults for numbers/strings +# but SHOULD override for booleans +server: + port: 0 + host: "" + timeout: 0 + max_request_size_mb: 0 + +batch: + size: 0 + enable_parallel_read: false # Should override (boolean) + max_memory_before_spill: 0 + max_file_size_mb: 0 + +query: + enable_cache: false # Should override (boolean) + cache_ttl_seconds: 0 + enable_concurrent_execution: false # Should override (boolean) + max_concurrent_queries: 0 + +metrics: + enable_metrics: false # Should override (boolean) + metrics_port: 0 + metrics_host: "" + export_interval_secs: 0 + enable_query_stats: false # Should override (boolean) + enable_memory_stats: false # Should override (boolean) diff --git a/src/Backend/opti-sql-go/go.mod b/src/Backend/opti-sql-go/go.mod index c6ae03c..49182e3 100644 --- a/src/Backend/opti-sql-go/go.mod +++ b/src/Backend/opti-sql-go/go.mod @@ -6,6 +6,7 @@ 
require ( github.com/apache/arrow/go/v17 v17.0.0 google.golang.org/grpc v1.63.2 google.golang.org/protobuf v1.34.2 + gopkg.in/yaml.v3 v3.0.1 ) require ( diff --git a/src/Backend/opti-sql-go/go.sum b/src/Backend/opti-sql-go/go.sum index 7046923..8839c2d 100644 --- a/src/Backend/opti-sql-go/go.sum +++ b/src/Backend/opti-sql-go/go.sum @@ -47,5 +47,6 @@ google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM= google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/src/Backend/opti-sql-go/main.go b/src/Backend/opti-sql-go/main.go index 6a0958d..82e1eb8 100644 --- a/src/Backend/opti-sql-go/main.go +++ b/src/Backend/opti-sql-go/main.go @@ -1,7 +1,17 @@ package main -import "opti-sql-go/substrait" +import ( + "opti-sql-go/config" + QueryExecuter "opti-sql-go/substrait" + "os" +) func main() { - substrait.Start() + if len(os.Args) > 1 { + if err := config.Decode(os.Args[1]); err != nil { + panic(err) + } + } + <-QueryExecuter.Start() + os.Exit(0) } diff --git a/src/Backend/opti-sql-go/operators/record_test.go b/src/Backend/opti-sql-go/operators/record_test.go index 5fe24f1..d78b1fe 100644 --- a/src/Backend/opti-sql-go/operators/record_test.go +++ b/src/Backend/opti-sql-go/operators/record_test.go @@ -1,6 +1,7 @@ package operators import ( + "fmt" "testing" "github.com/apache/arrow/go/v17/arrow" @@ -774,3 +775,166 @@ func TestRecordBatchDeepEqual(t *testing.T) { } }) } + +// TestGenTypedArrays exercises the remaining Gen*Array helper functions +// in a single test with subtests for each 
type (including GenInt8Array). +func TestGenTypedArrays(t *testing.T) { + rbb := NewRecordBatchBuilder() + + tests := []struct { + name string + gen func() arrow.Array + dtype arrow.DataType + validate func(arr arrow.Array) error + }{ + { + name: "GenInt8Array", + gen: func() arrow.Array { return rbb.GenInt8Array(1, -2, 3) }, + dtype: arrow.PrimitiveTypes.Int8, + validate: func(arr arrow.Array) error { + a := arr.(*array.Int8) + if a.Value(0) != 1 || a.Value(1) != -2 || a.Value(2) != 3 { + return fmt.Errorf("unexpected int8 values: %v,%v,%v", a.Value(0), a.Value(1), a.Value(2)) + } + return nil + }, + }, + { + name: "GenInt16Array", + gen: func() arrow.Array { return rbb.GenInt16Array(1000, -2000) }, + dtype: arrow.PrimitiveTypes.Int16, + validate: func(arr arrow.Array) error { + a := arr.(*array.Int16) + if a.Value(0) != 1000 || a.Value(1) != -2000 { + return fmt.Errorf("unexpected int16 values") + } + return nil + }, + }, + { + name: "GenInt64Array", + gen: func() arrow.Array { return rbb.GenInt64Array(1000000000, -5) }, + dtype: arrow.PrimitiveTypes.Int64, + validate: func(arr arrow.Array) error { + a := arr.(*array.Int64) + if a.Value(0) != 1000000000 || a.Value(1) != -5 { + return fmt.Errorf("unexpected int64 values") + } + return nil + }, + }, + { + name: "GenUint8Array", + gen: func() arrow.Array { return rbb.GenUint8Array(1, 2, 3) }, + dtype: arrow.PrimitiveTypes.Uint8, + validate: func(arr arrow.Array) error { + a := arr.(*array.Uint8) + if a.Value(0) != 1 || a.Value(2) != 3 { + return fmt.Errorf("unexpected uint8 values") + } + return nil + }, + }, + { + name: "GenUint16Array", + gen: func() arrow.Array { return rbb.GenUint16Array(100, 200) }, + dtype: arrow.PrimitiveTypes.Uint16, + validate: func(arr arrow.Array) error { + a := arr.(*array.Uint16) + if a.Value(0) != 100 || a.Value(1) != 200 { + return fmt.Errorf("unexpected uint16 values") + } + return nil + }, + }, + { + name: "GenUint32Array", + gen: func() arrow.Array { return rbb.GenUint32Array(10, 
20) }, + dtype: arrow.PrimitiveTypes.Uint32, + validate: func(arr arrow.Array) error { + a := arr.(*array.Uint32) + if a.Value(0) != 10 || a.Value(1) != 20 { + return fmt.Errorf("unexpected uint32 values") + } + return nil + }, + }, + { + name: "GenUint64Array", + gen: func() arrow.Array { return rbb.GenUint64Array(1234567890, 42) }, + dtype: arrow.PrimitiveTypes.Uint64, + validate: func(arr arrow.Array) error { + a := arr.(*array.Uint64) + if a.Value(0) != 1234567890 || a.Value(1) != 42 { + return fmt.Errorf("unexpected uint64 values") + } + return nil + }, + }, + { + name: "GenFloat32Array", + gen: func() arrow.Array { return rbb.GenFloat32Array(1.5, -2.25) }, + dtype: arrow.PrimitiveTypes.Float32, + validate: func(arr arrow.Array) error { + a := arr.(*array.Float32) + if a.Value(0) != float32(1.5) || a.Value(1) != float32(-2.25) { + return fmt.Errorf("unexpected float32 values") + } + return nil + }, + }, + { + name: "GenBinaryArray", + gen: func() arrow.Array { return rbb.GenBinaryArray([]byte("a"), []byte("bb")) }, + dtype: arrow.BinaryTypes.Binary, + validate: func(arr arrow.Array) error { + a := arr.(*array.Binary) + if string(a.Value(0)) != "a" || string(a.Value(1)) != "bb" { + return fmt.Errorf("unexpected binary values") + } + return nil + }, + }, + { + name: "GenLargeStringArray", + gen: func() arrow.Array { return rbb.GenLargeStringArray("x", "y") }, + dtype: arrow.BinaryTypes.LargeString, + validate: func(arr arrow.Array) error { + a := arr.(*array.LargeString) + if a.Value(0) != "x" || a.Value(1) != "y" { + return fmt.Errorf("unexpected large string values") + } + return nil + }, + }, + { + name: "GenLargeBinaryArray", + gen: func() arrow.Array { return rbb.GenLargeBinaryArray([]byte("z"), []byte("zz")) }, + dtype: arrow.BinaryTypes.LargeBinary, + validate: func(arr arrow.Array) error { + a := arr.(*array.LargeBinary) + if string(a.Value(0)) != "z" || string(a.Value(1)) != "zz" { + return fmt.Errorf("unexpected large binary values") + } + return nil + 
}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + arr := tt.gen() + defer arr.Release() + + if arr.Len() == 0 { + t.Fatalf("%s produced empty array", tt.name) + } + if !arrow.TypeEqual(arr.DataType(), tt.dtype) { + t.Fatalf("%s: expected dtype %s, got %s", tt.name, tt.dtype, arr.DataType()) + } + if err := tt.validate(arr); err != nil { + t.Fatalf("%s validation failed: %v", tt.name, err) + } + }) + } +} diff --git a/src/Backend/opti-sql-go/operators/serialize.go b/src/Backend/opti-sql-go/operators/serialize.go index d7b807d..4bacbb2 100644 --- a/src/Backend/opti-sql-go/operators/serialize.go +++ b/src/Backend/opti-sql-go/operators/serialize.go @@ -51,7 +51,7 @@ Column Data ├──────────────────────────────────────────┤ │ ... repeated for N buffers ... │ └──────────────────────────────────────────┘ -Int32 exameple (5 elements, validity + values buffers): +Int32 example (5 elements, validity + values buffers): [5] // array length [2] // numBuffers = 2 @@ -81,38 +81,67 @@ column2Block └──────────────────────────────────────────┘ +Record Batch Serialization Protocol +PURPOSE +This protocol defines how to serialize intermediate record batches to disk for pipeline-breaking +operators (sort, join, aggregation) where all records won't fit in RAM. +KEY ASSUMPTIONS +All batches share the SAME SCHEMA within a single operation +Examples: -What are we look for? -for serilization for intermediate record batches. 
-The main use case is for pipeline breaking operators where its unsafe to assume that all the records wil fit in ram -This means that all the inputs will the same schema | for example sort(col) -> each element is the exact same | join(col1 == col2) keep the left side in memory/in a seperate file and the right side will have the exact same schema +sort(col) - each element is the exact same type +join(col1 == col2) - keep left side in memory/separate file, right side has same schema -This means that we can just work with the data directly since well know the schema. Just to be safe we can keep the schema in memory attached directly to the object/class handling the serilization -V1 -Format on disk -> dataTypeSize|dataType|BatchSize|BatchElements|BatchSize|BatchElements..... +Schema is kept IN MEMORY attached to the serialization handler for validation -but this assumes were are only dealing with one column. -what about sort order? -what do we do with the other columns as were writing this single column to disk? -We have to write the entire record batch to disk +READING PROCEDURE -V2 -write the schema out to disk as well but this is only to validate with the in memory schema -!! inbetween each column being read into memory check with in memory schema for their data type for correct encoding -!! schema will also tell u how many columns u have to read in for that specefic record batch -format on disk -> **schema|ColumnNsize|columnNData|columnN+1size|columnN+1data|columnN+2size|columnN+2data...EndOfrecordBatch|columnNSize|columnNData|.... -format on disk for schema -> number of fields | field1NameLength|field1Name|field1TypeLength|field1Type|field1Nullable| field2NameLength|field2Name|field2TypeLength|field2Type|field2Nullable... -were going to be writing more data but theres not much we can do about that. 
for now this is fine +Read schema block first (done once at start) +For each record batch: -optimizations -(1)wont need to add the schema each time +Read numberOfFields from in-memory schema +For each column: + +Read columnSize +Read columnData +Validate data type against in-memory schema for correct encoding + + + + +Schema tells you exactly how many columns to read per batch + +IMPORTANT NOTES + +Schema is written to disk ONLY for validation against in-memory schema +Between reading each column, check in-memory schema for data type encoding +This trades more disk space for safety and clarity +The interface is implemented by a struct rather than attached to RecordBatch directly +to save allocations (especially for multiple spills: sort, hash join, aggregation) + +================================== +LOOSE NOTES / DEVELOPMENT HISTORY +V1 Issues: +Format was: dataTypeSize|dataType|BatchSize|BatchElements|BatchSize|BatchElements... +Problems: + +Only handled single column +What about sort order? +What happens to other columns while writing single column to disk? 
+Conclusion: Must write entire record batch to disk + +V2 Improvements: + +Write schema to disk for validation +Check schema between reading each column +Schema indicates how many columns per batch +Accept that we're writing more data - it's worth it for correctness */ -// saves allocations to have a struct that implements this interface than to have this attached directly to RecordBatch -// especially in the cases where we have to spill to disk multiple times (sort,hash join, aggregation) + type serializer struct { schema *arrow.Schema // schema is always attached to the serializer } @@ -168,10 +197,8 @@ func (ss *serializer) SerializeSchema(s *arrow.Schema) ([]byte, error) { var nullable uint8 if f.Nullable { nullable = 1 - fmt.Println("case 1") } else { nullable = 0 - fmt.Println("case 0") } if err := binary.Write(buf, binary.LittleEndian, nullable); err != nil { return nil, err @@ -426,3 +453,37 @@ FILE: └────────────────────────┘ EOF */ + +// loose notes / development history +/* + +What are we looking for? +for serialization for intermediate record batches. +The main use case is for pipeline breaking operators where it's unsafe to assume that all the records will fit in ram +This means that all the inputs will have the same schema | for example sort(col) -> each element is the exact same | join(col1 == col2) keep the left side in memory/in a separate file and the right side will have the exact same schema + +This means that we can just work with the data directly since we'll know the schema. Just to be safe we can keep the schema in memory attached directly to the object/class handling the serialization + +V1 +Format on disk -> dataTypeSize|dataType|BatchSize|BatchElements|BatchSize|BatchElements..... + +but this assumes we are only dealing with one column. +what about sort order? +what do we do with the other columns as we're writing this single column to disk?
+ + +We have to write the entire record batch to disk + +V2 +write the schema out to disk as well but this is only to validate with the in memory schema +!! in between each column being read into memory check with in memory schema for their data type for correct encoding +!! schema will also tell you how many columns you have to read in for that specific record batch +format on disk -> **schema|ColumnNsize|columnNData|columnN+1size|columnN+1data|columnN+2size|columnN+2data...EndOfrecordBatch|columnNSize|columnNData|.... +format on disk for schema -> number of fields | field1NameLength|field1Name|field1TypeLength|field1Type|field1Nullable| field2NameLength|field2Name|field2TypeLength|field2Type|field2Nullable... +we're going to be writing more data but there's not much we can do about that. for now this is fine + +optimizations +(1) won't need to add the schema each time +*/ +// saves allocations to have a struct that implements this interface than to have this attached directly to RecordBatch +// especially in the cases where we have to spill to disk multiple times (sort,hash join, aggregation) diff --git a/src/Backend/opti-sql-go/operators/serialize_test.go b/src/Backend/opti-sql-go/operators/serialize_test.go index 76c960d..b8b3cac 100644 --- a/src/Backend/opti-sql-go/operators/serialize_test.go +++ b/src/Backend/opti-sql-go/operators/serialize_test.go @@ -21,13 +21,13 @@ func generateDummyRecordBatch1() RecordBatch { WithField("age", arrow.PrimitiveTypes.Int32, false).
WithField("salary", arrow.PrimitiveTypes.Float64, false) - colums := []arrow.Array{ + columns := []arrow.Array{ dummyBuilder.GenIntArray(1, 2, 3, 4, 5), dummyBuilder.GenStringArray("Alice", "Bob", "Charlie", "David", "Eve"), dummyBuilder.GenIntArray(25, 30, 35, 40, 45), dummyBuilder.GenFloatArray(50000.0, 60000.0, 70000.0, 80000.0, 90000.0), } - RecordBatch, _ := dummyBuilder.NewRecordBatch(dummyBuilder.Schema(), colums) + RecordBatch, _ := dummyBuilder.NewRecordBatch(dummyBuilder.Schema(), columns) return *RecordBatch } @@ -770,8 +770,6 @@ func TestSeralizeToDisk(t *testing.T) { if err != nil { t.Fatalf("Failed to create temp file: %v", err) } - defer os.Remove(tmpFile.Name()) - defer tmpFile.Close() schemaContent, _ := serializer.SerializeSchema(r1.Schema) columnContent, _ := serializer.SerializeBatchColumns(r1) schemaContent = append(schemaContent, columnContent...) @@ -798,4 +796,10 @@ func TestSeralizeToDisk(t *testing.T) { if len(deserColumns) != len(r1.Columns) { t.Fatalf("Column count mismatch after deserialization from disk") } + if err := tmpFile.Close(); err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } + if err := os.Remove(tmpFile.Name()); err != nil { + t.Fatalf("Failed to remove temp file: %v", err) + } } diff --git a/src/Backend/opti-sql-go/substrait/server.go b/src/Backend/opti-sql-go/substrait/server.go index eefba9f..5fe5107 100644 --- a/src/Backend/opti-sql-go/substrait/server.go +++ b/src/Backend/opti-sql-go/substrait/server.go @@ -5,6 +5,10 @@ import ( "fmt" "log" "net" + "opti-sql-go/config" + "os" + "os/signal" + "syscall" "google.golang.org/grpc" ) @@ -12,10 +16,13 @@ import ( // SubstraitServer receives the substrait plan (gRPC) and sends out the optimized substrait plan (gRPC) type SubstraitServer struct { UnimplementedSSOperationServer + listener *net.Listener } -func newSubstraitServer() *SubstraitServer { - return &SubstraitServer{} +func newSubstraitServer(l *net.Listener) *SubstraitServer { + return &SubstraitServer{ 
+ listener: l, + } } // ExecuteQuery implements the gRPC service method @@ -32,18 +39,43 @@ func (s *SubstraitServer) ExecuteQuery(ctx context.Context, req *QueryExecutionR }, nil } -func Start() { - port := 8000 - listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) +func Start() chan struct{} { + c := config.GetConfig() + listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.Server.Host, c.Server.Port)) if err != nil { - log.Fatalf("Failed to listen on port %d: %v", port, err) + log.Fatalf("Failed to listen on port %d: %v", c.Server.Port, err) } grpcServer := grpc.NewServer() - RegisterSSOperationServer(grpcServer, newSubstraitServer()) + ss := newSubstraitServer(&listener) + RegisterSSOperationServer(grpcServer, ss) + + stopChan := make(chan struct{}) + + log.Printf("Substrait server listening on port %d", c.Server.Port) + go unifiedShutdownHandler(ss, grpcServer, stopChan) + go func() { + if err := grpcServer.Serve(*ss.listener); err != nil { + log.Fatalf("Failed to serve: %v", err) + } + }() + return stopChan +} +func unifiedShutdownHandler(s *SubstraitServer, grpcServer *grpc.Server, stopChan chan struct{}) { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - log.Printf("Substrait server listening on port %d", port) - if err := grpcServer.Serve(listener); err != nil { - log.Fatalf("Failed to serve: %v", err) + select { + case <-stopChan: + fmt.Println("Shutdown requested by caller.") + case sig := <-sigChan: + fmt.Printf("Received signal: %v\n", sig) } + + l := *s.listener + _ = l.Close() + + grpcServer.GracefulStop() + + fmt.Println("Server shutdown complete") } diff --git a/src/Backend/opti-sql-go/substrait/substrait_test.go b/src/Backend/opti-sql-go/substrait/substrait_test.go index 49d87ee..122ad0b 100644 --- a/src/Backend/opti-sql-go/substrait/substrait_test.go +++ b/src/Backend/opti-sql-go/substrait/substrait_test.go @@ -1,7 +1,55 @@ package substrait -import "testing" +import ( + 
"context" + "net" + "testing" +) -func TestSubstrait(t *testing.T) { +func TestInitServer(t *testing.T) { // Simple passing test + l, err := net.Listen("tcp", "0.0.0.0:1212") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + + ss := newSubstraitServer(&l) + if ss == nil { + t.Errorf("Expected non-nil Substrait server") + } +} +func TestDummyInput(t *testing.T) { + l, err := net.Listen("tcp", "0.0.0.0:1213") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + + ss := newSubstraitServer(&l) + if ss == nil { + t.Errorf("Expected non-nil Substrait server") + } + dummyRequest := &QueryExecutionRequest{ + SqlStatement: "SELECT * FROM table", + SubstraitLogical: []byte("CgJTUxIMCgpTZWxlY3QgKiBGUk9NIHRhYmxl"), + Id: "GenerateDTODOHaasdavdasvasdvada", + Source: &SourceType{ + S3Source: "s3://my-bucket/data/table.parquet", + Mime: "application/vnd.apache.parquet", + }, + } + resp, err := ss.ExecuteQuery(context.TODO(), dummyRequest) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if resp.ErrorType.ErrorType != ReturnTypes_SUCCESS { + t.Errorf("Expected SUCCESS, got %v", resp.ErrorType.ErrorType) + } +} + +func TestStartServer(t *testing.T) { + stopChan := Start() + if stopChan == nil { + t.Errorf("Expected non-nil stop channel") + } + } diff --git a/src/Contract/operation.proto b/src/Contract/operation.proto index 6e5261f..598386b 100644 --- a/src/Contract/operation.proto +++ b/src/Contract/operation.proto @@ -10,7 +10,7 @@ service SSOperation { // The request message containing the operation details. message QueryExecutionRequest { - bytes substrait_logical = 1; //SS logical plan + bytes substrait_logical = 1; // Substrait logical plan: serialized representation of the query execution string sql_statement = 2; // original sql statement string id = 3; // unique id for this client SourceType source = 4; // (s3 link| base64 data)