11package Expr
22
33import (
4+ "bytes"
45 "context"
6+ "errors"
57 "fmt"
68 "opti-sql-go/operators"
9+ "regexp"
710 "strings"
811
912 "github.com/apache/arrow/go/v17/arrow"
1619 ErrUnsupportedExpression = func (info string ) error {
1720 return fmt .Errorf ("unsupported expression passed to EvalExpression: %s" , info )
1821 }
22+ ErrCantCompareDifferentTypes = func (leftType , rightType arrow.DataType ) error {
23+ return fmt .Errorf ("cannot compare different data types: %s and %s" , leftType , rightType )
24+ }
1925)
2026
2127type binaryOperator int
@@ -36,7 +42,8 @@ const (
3642 // logical
3743 And binaryOperator = 12
3844 Or binaryOperator = 13
39- Not binaryOperator = 14
45+ // RegEx expressions
46+ Like binaryOperator = 14 // where column_name like "patte%n_with_wi%dcard_"
4047)
4148
4249type supportedFunctions int
@@ -103,36 +110,45 @@ func EvalExpression(expr Expression, batch *operators.RecordBatch) (arrow.Array,
103110 }
104111}
105112
106- func ExprDataType (e Expression , inputSchema * arrow.Schema ) arrow.DataType {
113+ func ExprDataType (e Expression , inputSchema * arrow.Schema ) ( arrow.DataType , error ) {
107114 switch ex := e .(type ) {
108115
109116 case * LiteralResolve :
110- return ex .Type
117+ return ex .Type , nil
111118
112119 case * ColumnResolve :
113120 idx := inputSchema .FieldIndices (ex .Name )
114121 if len (idx ) == 0 {
115- panic ( fmt .Sprintf ("exprDataType: unknown column %q" , ex .Name ) )
122+ return nil , fmt .Errorf ("exprDataType: unknown column %q" , ex .Name )
116123 }
117- return inputSchema .Field (idx [0 ]).Type
124+ return inputSchema .Field (idx [0 ]).Type , nil
118125 case * Alias :
119126 // alias does NOT change type
120127 return ExprDataType (ex .Expr , inputSchema )
121128
122129 case * CastExpr :
123- return ex .TargetType
130+ return ex .TargetType , nil
124131
125132 case * BinaryExpr :
126- leftType := ExprDataType (ex .Left , inputSchema )
127- rightType := ExprDataType (ex .Right , inputSchema )
128- return inferBinaryType (leftType , ex .Op , rightType )
133+ leftType , err := ExprDataType (ex .Left , inputSchema )
134+ if err != nil {
135+ return nil , err
136+ }
137+ rightType , err := ExprDataType (ex .Right , inputSchema )
138+ if err != nil {
139+ return nil , err
140+ }
141+ return inferBinaryType (leftType , ex .Op , rightType ), nil
129142
130143 case * ScalarFunction :
131- argType := ExprDataType (ex .Arguments , inputSchema )
132- return inferScalarFunctionType (ex .Function , argType )
144+ argType , err := ExprDataType (ex .Arguments , inputSchema )
145+ if err != nil {
146+ return nil , err
147+ }
148+ return inferScalarFunctionType (ex .Function , argType ), nil
133149
134150 default :
135- panic ( fmt . Sprintf ( "unsupported expr type %T" , ex ))
151+ return nil , ErrUnsupportedExpression ( ex . String ( ))
136152 }
137153}
138154func NewExpressions (exprs ... Expression ) []Expression {
@@ -403,25 +419,95 @@ func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error
403419 return unpackDatum (datum )
404420
405421 // comparisions TODO:
422+ // These return a boolean array
406423 case Equal :
407- return nil , fmt .Errorf ("operator Equal (%d) not yet implemented" , b .Op )
424+ if leftArr .DataType () != rightArr .DataType () {
425+ return nil , ErrCantCompareDifferentTypes (leftArr .DataType (), rightArr .DataType ())
426+ }
427+ datum , err := compute .CallFunction (context .Background (), "equal" , compute .DefaultFilterOptions (), compute .NewDatum (leftArr ), compute .NewDatum (rightArr ))
428+ if err != nil {
429+ return nil , err
430+ }
431+ return unpackDatum (datum )
408432 case NotEqual :
409- return nil , fmt .Errorf ("operator NotEqual (%d) not yet implemented" , b .Op )
433+ if leftArr .DataType () != rightArr .DataType () {
434+ return nil , ErrCantCompareDifferentTypes (leftArr .DataType (), rightArr .DataType ())
435+ }
436+ datum , err := compute .CallFunction (context .Background (), "not_equal" , compute .DefaultFilterOptions (), compute .NewDatum (leftArr ), compute .NewDatum (rightArr ))
437+ if err != nil {
438+ return nil , err
439+ }
440+ return unpackDatum (datum )
410441 case LessThan :
411- return nil , fmt .Errorf ("operator LessThan (%d) not yet implemented" , b .Op )
442+ if leftArr .DataType () != rightArr .DataType () {
443+ return nil , ErrCantCompareDifferentTypes (leftArr .DataType (), rightArr .DataType ())
444+ }
445+ datum , err := compute .CallFunction (context .Background (), "less" , compute .DefaultFilterOptions (), compute .NewDatum (leftArr ), compute .NewDatum (rightArr ))
446+ if err != nil {
447+ return nil , err
448+ }
449+ return unpackDatum (datum )
412450 case LessThanOrEqual :
413- return nil , fmt .Errorf ("operator LessThanOrEqual (%d) not yet implemented" , b .Op )
451+ if leftArr .DataType () != rightArr .DataType () {
452+ return nil , ErrCantCompareDifferentTypes (leftArr .DataType (), rightArr .DataType ())
453+ }
454+ datum , err := compute .CallFunction (context .Background (), "less_equal" , compute .DefaultFilterOptions (), compute .NewDatum (leftArr ), compute .NewDatum (rightArr ))
455+ if err != nil {
456+ return nil , err
457+ }
458+ return unpackDatum (datum )
414459 case GreaterThan :
415- return nil , fmt .Errorf ("operator GreaterThan (%d) not yet implemented" , b .Op )
460+ if leftArr .DataType () != rightArr .DataType () {
461+ return nil , ErrCantCompareDifferentTypes (leftArr .DataType (), rightArr .DataType ())
462+ }
463+ datum , err := compute .CallFunction (context .Background (), "greater" , compute .DefaultFilterOptions (), compute .NewDatum (leftArr ), compute .NewDatum (rightArr ))
464+ if err != nil {
465+ return nil , err
466+ }
467+ return unpackDatum (datum )
416468 case GreaterThanOrEqual :
417- return nil , fmt .Errorf ("operator GreaterThanOrEqual (%d) not yet implemented" , b .Op )
469+ if leftArr .DataType () != rightArr .DataType () {
470+ return nil , ErrCantCompareDifferentTypes (leftArr .DataType (), rightArr .DataType ())
471+ }
472+ datum , err := compute .CallFunction (context .Background (), "greater_equal" , compute .DefaultFilterOptions (), compute .NewDatum (leftArr ), compute .NewDatum (rightArr ))
473+ if err != nil {
474+ return nil , err
475+ }
476+ return unpackDatum (datum )
418477 // logical
419478 case And :
420- return nil , fmt .Errorf ("operator And (%d) not yet implemented" , b .Op )
479+ if leftArr .DataType () != rightArr .DataType () {
480+ return nil , ErrCantCompareDifferentTypes (leftArr .DataType (), rightArr .DataType ())
481+ }
482+ datum , err := compute .CallFunction (context .Background (), "and" , compute .DefaultFilterOptions (), compute .NewDatum (leftArr ), compute .NewDatum (rightArr ))
483+ if err != nil {
484+ return nil , err
485+ }
486+ return unpackDatum (datum )
421487 case Or :
422- return nil , fmt .Errorf ("operator Or (%d) not yet implemented" , b .Op )
423- case Not :
424- return nil , fmt .Errorf ("operator Not (%d) not yet implemented" , b .Op )
488+ if leftArr .DataType () != rightArr .DataType () {
489+ return nil , ErrCantCompareDifferentTypes (leftArr .DataType (), rightArr .DataType ())
490+ }
491+ datum , err := compute .CallFunction (context .Background (), "or" , compute .DefaultFilterOptions (), compute .NewDatum (leftArr ), compute .NewDatum (rightArr ))
492+ if err != nil {
493+ return nil , err
494+ }
495+ return unpackDatum (datum )
496+ case Like :
497+ if leftArr .DataType () != arrow .BinaryTypes .String || rightArr .DataType () != arrow .BinaryTypes .String {
498+ // regEx runs only on strings
499+ return nil , errors .New ("binary operator Like only works on arrays of strings" )
500+ }
501+ var compiledRegEx = compileSqlRegEx (rightArr .ValueStr (0 ))
502+ filterBuilder := array .NewBooleanBuilder (memory .NewGoAllocator ())
503+ leftStrArray := leftArr .(* array.String )
504+ for i := 0 ; i < leftStrArray .Len (); i ++ {
505+ valid := validRegEx (leftStrArray .Value (i ), compiledRegEx )
506+ fmt .Printf ("does %s match %s: %v\n " , leftStrArray .Value (i ), compiledRegEx , valid )
507+ filterBuilder .Append (valid )
508+ }
509+ return filterBuilder .NewArray (), nil
510+
425511 }
426512 return nil , fmt .Errorf ("binary operator %d not supported" , b .Op )
427513}
@@ -436,28 +522,6 @@ func unpackDatum(d compute.Datum) (arrow.Array, error) {
436522 }
437523 return array .MakeArray (), nil
438524}
439- func inferBinaryType (left arrow.DataType , op binaryOperator , right arrow.DataType ) arrow.DataType {
440- switch op {
441-
442- case Addition , Subtraction , Multiplication , Division :
443- // numeric → numeric promotion rules
444- return numericPromotion (left , right )
445-
446- case Equal , NotEqual , LessThan , LessThanOrEqual , GreaterThan , GreaterThanOrEqual :
447- return arrow .FixedWidthTypes .Boolean
448-
449- case And , Or :
450- return arrow .FixedWidthTypes .Boolean
451-
452- default :
453- panic (fmt .Sprintf ("inferBinaryType: unsupported operator %v" , op ))
454- }
455- }
456- func numericPromotion (a , b arrow.DataType ) arrow.DataType {
457- // simplest version: return float64 for any mixed numeric types.
458- // expand later when needed.
459- return arrow .PrimitiveTypes .Float64
460- }
461525
462526type ScalarFunction struct {
463527 Function supportedFunctions
@@ -513,6 +577,44 @@ func (s *ScalarFunction) ExprNode() {}
513577func (s * ScalarFunction ) String () string {
514578 return fmt .Sprintf ("ScalarFunction(%d, %v)" , s .Function , s .Arguments )
515579}
580+
581+ // If cast succeeds → return the casted value
582+ // If cast fails → throw a runtime error
583+ type CastExpr struct {
584+ Expr Expression // can be a Literal or Column (check for datatype when you resolve)
585+ TargetType arrow.DataType
586+ }
587+
588+ func NewCastExpr (expr Expression , targetType arrow.DataType ) * CastExpr {
589+ return & CastExpr {
590+ Expr : expr ,
591+ TargetType : targetType ,
592+ }
593+ }
594+
595+ func EvalCast (c * CastExpr , batch * operators.RecordBatch ) (arrow.Array , error ) {
596+ arr , err := EvalExpression (c .Expr , batch )
597+ if err != nil {
598+ return nil , err
599+ }
600+
601+ // Use Arrow compute kernel to cast
602+ castOpts := compute .SafeCastOptions (c .TargetType )
603+ out , err := compute .CastArray (context .TODO (), arr , castOpts )
604+ if err != nil {
605+ // This is a runtime cast error
606+ return nil , fmt .Errorf ("cast error: cannot cast %s to %s: %w" ,
607+ arr .DataType (), c .TargetType , err )
608+ }
609+
610+ return out , nil
611+ }
612+
613+ func (c * CastExpr ) ExprNode () {}
614+ func (c * CastExpr ) String () string {
615+ return fmt .Sprintf ("Cast(%s AS %s)" , c .Expr , c .TargetType )
616+ }
617+
516618func upperImpl (arr arrow.Array ) (arrow.Array , error ) {
517619 strArr , ok := arr .(* array.String )
518620 if ! ok {
@@ -564,39 +666,66 @@ func inferScalarFunctionType(fn supportedFunctions, argType arrow.DataType) arro
564666 }
565667}
566668
567- // If cast succeeds → return the casted value
568- // If cast fails → throw a runtime error
569- type CastExpr struct {
570- Expr Expression // can be a Literal or Column (check for datatype when you resolve)
571- TargetType arrow.DataType
572- }
669+ func inferBinaryType (left arrow.DataType , op binaryOperator , right arrow.DataType ) arrow.DataType {
670+ switch op {
573671
574- func NewCastExpr (expr Expression , targetType arrow.DataType ) * CastExpr {
575- return & CastExpr {
576- Expr : expr ,
577- TargetType : targetType ,
672+ case Addition , Subtraction , Multiplication , Division :
673+ // numeric → numeric promotion rules
674+ return numericPromotion (left , right )
675+
676+ case Equal , NotEqual , LessThan , LessThanOrEqual , GreaterThan , GreaterThanOrEqual :
677+ return arrow .FixedWidthTypes .Boolean
678+
679+ case And , Or :
680+ return arrow .FixedWidthTypes .Boolean
681+
682+ default :
683+ panic (fmt .Sprintf ("inferBinaryType: unsupported operator %v" , op ))
578684 }
579685}
686+ func numericPromotion (a , b arrow.DataType ) arrow.DataType {
687+ // simplest version: return float64 for any mixed numeric types.
688+ return arrow .PrimitiveTypes .Float64
689+ }
580690
581- func EvalCast (c * CastExpr , batch * operators.RecordBatch ) (arrow.Array , error ) {
582- arr , err := EvalExpression (c .Expr , batch )
583- if err != nil {
584- return nil , err
691+ func compileSqlRegEx (s string ) string {
692+ var buf bytes.Buffer
693+
694+ // Track anchoring rules
695+ startsWithWildcard := len (s ) > 0 && s [0 ] == '%'
696+ endsWithWildcard := len (s ) > 0 && s [len (s )- 1 ] == '%'
697+
698+ // Build body
699+ for i := 0 ; i < len (s ); i ++ {
700+ switch s [i ] {
701+ case '_' :
702+ buf .WriteString ("." )
703+ case '%' :
704+ buf .WriteString (".*" )
705+ default :
706+ // Escape regex meta chars
707+ if strings .ContainsRune (`.^$|()[]*+?{}` , rune (s [i ])) {
708+ buf .WriteByte ('\\' )
709+ }
710+ buf .WriteByte (s [i ])
711+ }
585712 }
586713
587- // Use Arrow compute kernel to cast
588- castOpts := compute .SafeCastOptions (c .TargetType )
589- out , err := compute .CastArray (context .TODO (), arr , castOpts )
590- if err != nil {
591- // This is a runtime cast error
592- return nil , fmt .Errorf ("cast error: cannot cast %s to %s: %w" ,
593- arr .DataType (), c .TargetType , err )
714+ regex := buf .String ()
715+
716+ // Apply anchoring
717+ if ! startsWithWildcard {
718+ regex = "^" + regex
719+ }
720+ if ! endsWithWildcard {
721+ regex = regex + "$"
594722 }
595723
596- return out , nil
724+ return regex
597725}
598726
599- func (c * CastExpr ) ExprNode () {}
600- func (c * CastExpr ) String () string {
601- return fmt .Sprintf ("Cast(%s AS %s)" , c .Expr , c .TargetType )
727+ func validRegEx (columnValue , regExExpr string ) bool {
728+ ok , _ := regexp .MatchString (regExExpr , columnValue )
729+ return ok
730+
602731}
0 commit comments