From dd5de770b34773bd3dbf5436f96fa6bfcae586f6 Mon Sep 17 00:00:00 2001 From: LissaGreense Date: Thu, 16 May 2024 20:54:10 +0200 Subject: [PATCH 01/21] e2e version matrix update --- .github/workflows/end2end-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/end2end-tests.yml b/.github/workflows/end2end-tests.yml index 11dc859..43e111b 100644 --- a/.github/workflows/end2end-tests.yml +++ b/.github/workflows/end2end-tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.16.15', '1.17.11' ] + go: [ '1.16.15', '1.17.11', '1.18.10', '1.19.13', '1.20.14', '1.21.9', '1.22.3' ] steps: - uses: actions/checkout@v3 From 55752874498846fa2f973613af59e11b62713f87 Mon Sep 17 00:00:00 2001 From: LissaGreense Date: Thu, 16 May 2024 23:39:02 +0200 Subject: [PATCH 02/21] Refactor that makes code shiny --- ast/ast.go | 42 ++++++++----------- engine/column.go | 12 +++--- engine/engine.go | 86 +++++++++++-------------------------- engine/engine_utils.go | 4 +- engine/generic_value.go | 16 +++---- engine/row.go | 22 ++++++++++ engine/table.go | 12 +++--- main.go | 93 ++--------------------------------------- modes/handler.go | 89 ++++++++++++++++++++++++++++++++++----- parser/parser.go | 72 ++++++++++++------------------- parser/parser_test.go | 2 +- 11 files changed, 196 insertions(+), 254 deletions(-) create mode 100644 engine/row.go diff --git a/ast/ast.go b/ast/ast.go index 10244cf..7b18d17 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -29,12 +29,9 @@ type Command interface { // // Methods: // -// ExpressionNode: Abstraction needed for creating tree abstraction in order to optimise evaluating -// GetIdentifiers - Return array of pointers for all Identifiers within expression +// GetIdentifiers - Return array for all Identifiers within expression type Expression interface { - // ExpressionNode TODO: Check if ExpressionNode is needed - ExpressionNode() - GetIdentifiers() []*Identifier + GetIdentifiers() []Identifier } // Tifier - Interface that represent Token with string value @@ -73,7 +70,7 @@ type Anonymitifier struct { func (ls Anonymitifier) IsIdentifier() bool { return false } func (ls Anonymitifier) GetToken() token.Token { return ls.Token } -// BooleanExpression - Type of Expression that represent single boolean value +// BooleanExpression - TokenType of Expression that represent single boolean value // // Example: // TRUE @@ -81,13 +78,12 @@ type BooleanExpression struct { Boolean token.Token // example: token.TRUE } -func (ls BooleanExpression) ExpressionNode() {} -func (ls BooleanExpression) GetIdentifiers() []*Identifier { - var identifiers []*Identifier +func (ls BooleanExpression) GetIdentifiers() []Identifier { + var identifiers []Identifier return identifiers } -// ConditionExpression - Type of Expression that represent condition that is comparing value from column to static one +// ConditionExpression - TokenType of Expression that represent condition that is comparing value from column to static one // // Example: // column1 EQUAL 123 @@ -97,22 +93,21 @@ type ConditionExpression struct { Condition token.Token // example: token.EQUAL } -func (ls ConditionExpression) ExpressionNode() {} -func (ls ConditionExpression) GetIdentifiers() []*Identifier { - var identifiers []*Identifier +func (ls ConditionExpression) GetIdentifiers() []Identifier { + var identifiers []Identifier if ls.Left.IsIdentifier() { - identifiers = append(identifiers, &Identifier{ls.Left.GetToken()}) + identifiers = append(identifiers, Identifier{ls.Left.GetToken()}) } if ls.Right.IsIdentifier() { - identifiers = append(identifiers, &Identifier{ls.Right.GetToken()}) + identifiers = append(identifiers, Identifier{ls.Right.GetToken()}) } return identifiers } -// OperationExpression - Type of Expression that represent 2 other Expressions and conditional operation +// OperationExpression - TokenType of Expression that represent 2 other Expressions and conditional operation // // Example: // TRUE OR FALSE @@ -122,9 +117,8 @@ type OperationExpression struct { Operation token.Token // example: token.AND } -func (ls OperationExpression) ExpressionNode() {} -func (ls OperationExpression) GetIdentifiers() []*Identifier { - var identifiers []*Identifier +func (ls OperationExpression) GetIdentifiers() []Identifier { + var identifiers []Identifier identifiers = append(identifiers, ls.Left.GetIdentifiers()...) identifiers = append(identifiers, ls.Right.GetIdentifiers()...) @@ -138,7 +132,7 @@ func (ls OperationExpression) GetIdentifiers() []*Identifier { // CREATE TABLE table1( one TEXT , two INT); type CreateCommand struct { Token token.Token - Name *Identifier // name of the table + Name Identifier // name of the table ColumnNames []string ColumnTypes []token.Token } @@ -149,10 +143,10 @@ func (ls CreateCommand) TokenLiteral() string { return ls.Token.Literal } // InsertCommand - Part of Command that represent insertion of values into columns // // Example: -// INSERT INTO table1 VALUES( 'hello', 1); +// INSERT INTO table1 VALUES('hello', 1); type InsertCommand struct { Token token.Token - Name *Identifier // name of the table + Name Identifier // name of the table Values []token.Token } @@ -165,7 +159,7 @@ func (ls InsertCommand) TokenLiteral() string { return ls.Token.Literal } // SELECT one, two FROM table1; type SelectCommand struct { Token token.Token - Name *Identifier + Name Identifier Space []token.Token // ex. column names } @@ -190,7 +184,7 @@ func (ls WhereCommand) TokenLiteral() string { return ls.Token.Literal } // DELETE FROM tb1 WHERE two EQUAL 3; type DeleteCommand struct { Token token.Token - Name *Identifier // name of the table + Name Identifier // name of the table } func (ls DeleteCommand) CommandNode() {} diff --git a/engine/column.go b/engine/column.go index 2837ee1..c8633d4 100644 --- a/engine/column.go +++ b/engine/column.go @@ -13,22 +13,22 @@ type Column struct { Values []ValueInterface } -func extractColumnContent(columns []*Column, wantedColumnNames []string) *Table { +func extractColumnContent(columns []*Column, wantedColumnNames *[]string) *Table { selectedTable := &Table{Columns: make([]*Column, 0)} mappedIndexes := make([]int, 0) - for wantedColumnIndex := 0; wantedColumnIndex < len(wantedColumnNames); wantedColumnIndex++ { - for columnNameIndex := 0; columnNameIndex < len(columns); columnNameIndex++ { - if columns[columnNameIndex].Name == wantedColumnNames[wantedColumnIndex] { + for wantedColumnIndex := range *wantedColumnNames { + for columnNameIndex := range columns { + if columns[columnNameIndex].Name == (*wantedColumnNames)[wantedColumnIndex] { mappedIndexes = append(mappedIndexes, columnNameIndex) break } if columnNameIndex == len(columns)-1 { - log.Fatal("Provided column name: " + wantedColumnNames[wantedColumnIndex] + " doesn't exist") + log.Fatal("Provided column name: " + (*wantedColumnNames)[wantedColumnIndex] + " doesn't exist") } } } - for i := 0; i < len(mappedIndexes); i++ { + for i := range mappedIndexes { selectedTable.Columns = append(selectedTable.Columns, &Column{ Name: columns[mappedIndexes[i]].Name, Type: columns[mappedIndexes[i]].Type, diff --git a/engine/engine.go b/engine/engine.go index c644f68..ce21b57 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -12,13 +12,15 @@ import ( ) type DbEngine struct { - Tables map[string]*Table + Tables Tables } +type Tables map[string]*Table // New Return new DbEngine struct func New() *DbEngine { engine := &DbEngine{} - engine.Tables = make(map[string]*Table) + engine.Tables = make(Tables) + return engine } @@ -54,10 +56,10 @@ func (engine *DbEngine) InsertIntoTable(command *ast.InsertCommand) { log.Fatal("Invalid number of parameters in insert, should be: " + strconv.Itoa(len(columns)) + ", but got: " + strconv.Itoa(len(columns))) } - for i := 0; i < len(columns); i++ { + for i := range columns { expectedToken := tokenMapper(columns[i].Type.Type) if expectedToken != command.Values[i].Type { - log.Fatal("Invalid Token Type in Insert Command, expecting: " + expectedToken + ", got: " + command.Values[i].Type) + log.Fatal("Invalid Token TokenType in Insert Command, expecting: " + expectedToken + ", got: " + command.Values[i].Type) } columns[i].Values = append(columns[i].Values, getInterfaceValue(command.Values[i])) } @@ -82,7 +84,7 @@ func (engine *DbEngine) selectFromProvidedTable(command *ast.SelectCommand, tabl for i := 0; i < len(columns); i++ { wantedColumnNames = append(wantedColumnNames, columns[i].Name) } - return extractColumnContent(columns, wantedColumnNames) + return extractColumnContent(columns, &wantedColumnNames) } else { for i := 0; i < len(command.Space); i++ { wantedColumnNames = append(wantedColumnNames, command.Space[i].Literal) @@ -153,7 +155,7 @@ func (engine *DbEngine) SelectFromTableWithOrderBy(selectCommand *ast.SelectComm func (engine *DbEngine) getSortedTable(orderByCommand *ast.OrderByCommand, filteredTable *Table, copyOfTable *Table) *Table { sortPatterns := orderByCommand.SortPatterns - rows := mapTableToRows(filteredTable) + rows := MapTableToRows(filteredTable).rows sort.Slice(rows, func(i, j int) bool { howDeepWeSort := 0 @@ -189,10 +191,7 @@ func (engine *DbEngine) getSortedTable(orderByCommand *ast.OrderByCommand, filte func (engine *DbEngine) getFilteredTable(table *Table, whereCommand *ast.WhereCommand, negation bool) *Table { filteredTable := getCopyOfTableWithoutRows(table) - //TODO: maybe rows should have separate structure, so it would would have it's on methods - rows := mapTableToRows(table) - - for _, row := range rows { + for _, row := range MapTableToRows(table).rows { fulfilledFilters, err := isFulfillingFilters(row, whereCommand.Expression) if err != nil { log.Fatal(err.Error()) @@ -226,38 +225,17 @@ func getCopyOfTableWithoutRows(table *Table) *Table { return filteredTable } -func mapTableToRows(table *Table) []map[string]ValueInterface { - rows := make([]map[string]ValueInterface, 0) - - numberOfRows := len(table.Columns[0].Values) - - for rowIndex := 0; rowIndex < numberOfRows; rowIndex++ { - row := make(map[string]ValueInterface) - for _, column := range table.Columns { - row[column.Name] = column.Values[rowIndex] - } - rows = append(rows, row) - } - return rows -} - func isFulfillingFilters(row map[string]ValueInterface, expressionTree ast.Expression) (bool, error) { - operationExpression, operationExpressionIsValid := expressionTree.(*ast.OperationExpression) - if operationExpressionIsValid { - return processOperationExpression(row, operationExpression) - } - - booleanExpression, booleanExpressionIsValid := expressionTree.(*ast.BooleanExpression) - if booleanExpressionIsValid { - return processBooleanExpression(booleanExpression) - } - - conditionExpression, conditionExpressionIsValid := expressionTree.(*ast.ConditionExpression) - if conditionExpressionIsValid { - return processConditionExpression(row, conditionExpression) + switch mappedExpression := expressionTree.(type) { + case *ast.OperationExpression: + return processOperationExpression(row, mappedExpression) + case *ast.BooleanExpression: + return processBooleanExpression(mappedExpression) + case *ast.ConditionExpression: + return processConditionExpression(row, mappedExpression) + default: + return false, fmt.Errorf("unsupported expression has been used in WHERE command: %v", expressionTree.GetIdentifiers()) } - - return false, fmt.Errorf("unsupported expression has been used in WHERE command: %v", expressionTree.GetIdentifiers()) } func processConditionExpression(row map[string]ValueInterface, conditionExpression *ast.ConditionExpression) (bool, error) { @@ -313,26 +291,12 @@ func processBooleanExpression(booleanExpression *ast.BooleanExpression) (bool, e } func getTifierValue(tifier ast.Tifier, row map[string]ValueInterface) (ValueInterface, error) { - identifier, identifierIsValid := tifier.(ast.Identifier) - - if identifierIsValid { - return row[identifier.GetToken().Literal], nil - } - - anonymitifier, anonymitifierIsValid := tifier.(ast.Anonymitifier) - if anonymitifierIsValid { - return getInterfaceValue(anonymitifier.GetToken()), nil - } - - // TODO: Maybe information in which table this column doesn't exist is needed - return nil, errors.New("Column name:'" + tifier.GetToken().Literal + "' doesn't exist!") -} - -func getColumnIndexByName(columns []*Column, columName string) (int, error) { - for i, column := range columns { - if column.Name == columName { - return i, nil - } + switch mappedTifier := tifier.(type) { + case ast.Identifier: + return row[mappedTifier.GetToken().Literal], nil + case ast.Anonymitifier: + return getInterfaceValue(mappedTifier.GetToken()), nil + default: + return nil, errors.New("Couldn't map interface to any implementation of it: " + tifier.GetToken().Literal) } - return -1, errors.New("Column name:'" + columName + "' doesn't exist!") } diff --git a/engine/engine_utils.go b/engine/engine_utils.go index 83c1036..5e27542 100644 --- a/engine/engine_utils.go +++ b/engine/engine_utils.go @@ -31,7 +31,7 @@ func tokenMapper(inputToken token.Type) token.Type { } } -func unique(arr []string) []string { +func unique(arr []string) *[]string { occurred := map[string]bool{} var result []string @@ -41,5 +41,5 @@ func unique(arr []string) []string { result = append(result, arr[e]) } } - return result + return &result } diff --git a/engine/generic_value.go b/engine/generic_value.go index 7adab43..284a912 100644 --- a/engine/generic_value.go +++ b/engine/generic_value.go @@ -48,43 +48,43 @@ func (value StringValue) IsEqual(valueInterface ValueInterface) bool { } // isSmallerThan implementations -func (firstValue IntegerValue) isSmallerThan(secondValue ValueInterface) bool { +func (value IntegerValue) isSmallerThan(secondValue ValueInterface) bool { secondValueAsInteger, isInteger := secondValue.(IntegerValue) if !isInteger { log.Fatal("Can't compare Integer with other type") } - return firstValue.Value < secondValueAsInteger.Value + return value.Value < secondValueAsInteger.Value } -func (firstValue StringValue) isSmallerThan(secondValue ValueInterface) bool { +func (value StringValue) isSmallerThan(secondValue ValueInterface) bool { secondValueAsString, isString := secondValue.(StringValue) if !isString { log.Fatal("Can't compare String with other type") } - return firstValue.Value < secondValueAsString.Value + return value.Value < secondValueAsString.Value } // isGreaterThan implementations -func (firstValue IntegerValue) isGreaterThan(secondValue ValueInterface) bool { +func (value IntegerValue) isGreaterThan(secondValue ValueInterface) bool { secondValueAsInteger, isInteger := secondValue.(IntegerValue) if !isInteger { log.Fatal("Can't compare Integer with other type") } - return firstValue.Value > secondValueAsInteger.Value + return value.Value > secondValueAsInteger.Value } -func (firstValue StringValue) isGreaterThan(secondValue ValueInterface) bool { +func (value StringValue) isGreaterThan(secondValue ValueInterface) bool { secondValueAsString, isString := secondValue.(StringValue) if !isString { log.Fatal("Can't compare String with other type") } - return firstValue.Value > secondValueAsString.Value + return value.Value > secondValueAsString.Value } func areEqual(first ValueInterface, second ValueInterface) bool { diff --git a/engine/row.go b/engine/row.go new file mode 100644 index 0000000..35fc707 --- /dev/null +++ b/engine/row.go @@ -0,0 +1,22 @@ +package engine + +// Rows - Contain rows that store values, alternative to Table, some operations are easier +type Rows struct { + rows []map[string]ValueInterface +} + +// MapTableToRows - transform Table struct into Rows +func MapTableToRows(table *Table) Rows { + rows := make([]map[string]ValueInterface, 0) + + numberOfRows := len(table.Columns[0].Values) + + for rowIndex := 0; rowIndex < numberOfRows; rowIndex++ { + row := make(map[string]ValueInterface) + for _, column := range table.Columns { + row[column.Name] = column.Values[rowIndex] + } + rows = append(rows, row) + } + return Rows{rows: rows} +} diff --git a/engine/table.go b/engine/table.go index 26f4b5f..2babc71 100644 --- a/engine/table.go +++ b/engine/table.go @@ -12,7 +12,7 @@ func (table *Table) isEqual(secondTable *Table) bool { return false } - for i := 0; i < len(table.Columns); i++ { + for i := range table.Columns { if table.Columns[i].Name != secondTable.Columns[i].Name { return false } @@ -25,7 +25,7 @@ func (table *Table) isEqual(secondTable *Table) bool { if len(table.Columns[i].Values) != len(secondTable.Columns[i].Values) { return false } - for j := 0; j < len(table.Columns[i].Values); j++ { + for j := range table.Columns[i].Values { if table.Columns[i].Values[j].ToString() != secondTable.Columns[i].Values[j].ToString() { return false } @@ -42,7 +42,7 @@ func (table *Table) ToString() string { result := bar + "\n" result += "|" - for i := 0; i < len(table.Columns); i++ { + for i := range table.Columns { result += " " for j := 0; j < columWidths[i]-len(table.Columns[i].Name); j++ { result += " " @@ -57,7 +57,7 @@ func (table *Table) ToString() string { for iRow := 0; iRow < rowsCount; iRow++ { result += "|" - for iColumn := 0; iColumn < len(table.Columns); iColumn++ { + for iColumn := range table.Columns { result += " " printedValue := table.Columns[iColumn].Values[iRow].ToString() @@ -94,9 +94,9 @@ func getBar(columWidths []int) string { func getColumWidths(columns []*Column) []int { widths := make([]int, 0) - for iColumn := 0; iColumn < len(columns); iColumn++ { + for iColumn := range columns { maxLength := len(columns[iColumn].Name) - for iRow := 0; iRow < len(columns[iColumn].Values); iRow++ { + for iRow := range columns[iColumn].Values { valueLength := len(columns[iColumn].Values[iRow].ToString()) if columns[iColumn].Type.Literal == token.TEXT { valueLength += 2 // double "'" diff --git a/main.go b/main.go index d6e6497..ff127e0 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package main import ( "flag" - "github.com/LissaGreense/GO4SQL/ast" "github.com/LissaGreense/GO4SQL/engine" "github.com/LissaGreense/GO4SQL/modes" "log" @@ -18,98 +17,12 @@ func main() { engineSQL := engine.New() if len(*filePath) > 0 { - modes.HandleFileMode(*filePath, engineSQL, evaluateInEngine) + modes.HandleFileMode(*filePath, engineSQL) } else if *streamMode { - modes.HandleStreamMode(engineSQL, evaluateInEngine) + modes.HandleStreamMode(engineSQL) } else if *socketMode { - modes.HandleSocketMode(*port, engineSQL, evaluateInEngine) + modes.HandleSocketMode(*port, engineSQL) } else { log.Println("No mode has been providing. Exiting.") } } - -func evaluateInEngine(sequences *ast.Sequence, engineSQL *engine.DbEngine) string { - commands := sequences.Commands - - result := "" - for commandIndex, command := range commands { - - // TODO: Check if those statements are necessary - _, whereCommandIsValid := command.(*ast.WhereCommand) - if whereCommandIsValid { - continue - } - - _, orderByCommandIsValid := command.(*ast.OrderByCommand) - if orderByCommandIsValid { - continue - } - - createCommand, createCommandIsValid := command.(*ast.CreateCommand) - if createCommandIsValid { - engineSQL.CreateTable(createCommand) - result += "Table '" + createCommand.Name.GetToken().Literal + "' has been created\n" - continue - } - - insertCommand, insertCommandIsValid := command.(*ast.InsertCommand) - if insertCommandIsValid { - engineSQL.InsertIntoTable(insertCommand) - result += "Data Inserted\n" - continue - } - - selectCommand, selectCommandIsValid := command.(*ast.SelectCommand) - if selectCommandIsValid { - result += getSelectResponse(commandIndex, commands, engineSQL, selectCommand) + "\n" - continue - } - - deleteCommand, deleteCommandIsValid := command.(*ast.DeleteCommand) - if deleteCommandIsValid { - nextCommandIndex := commandIndex + 1 - - if nextCommandIndex != len(commands) { - whereCommand, whereCommandIsValid := commands[nextCommandIndex].(*ast.WhereCommand) - - if whereCommandIsValid { - engineSQL.DeleteFromTable(deleteCommand, whereCommand) - } - } - result += "Data from '" + deleteCommand.Name.GetToken().Literal + "' has been deleted\n" - continue - } - - } - - return result -} - -func getSelectResponse(commandIndex int, commands []ast.Command, engineSQL *engine.DbEngine, selectCommand *ast.SelectCommand) string { - nextCommandIndex := commandIndex + 1 - - if nextCommandIndex != len(commands) { - whereCommand, whereCommandIsValid := commands[nextCommandIndex].(*ast.WhereCommand) - - // TODO: It cannot be like that. Have to be refactored to tree structure. - if whereCommandIsValid { - if nextCommandIndex+1 < len(commands) { - orderByCommand, orderByCommandIsValid := commands[nextCommandIndex+1].(*ast.OrderByCommand) - - if orderByCommandIsValid { - return engineSQL.SelectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand).ToString() - } - } - - return engineSQL.SelectFromTableWithWhere(selectCommand, whereCommand).ToString() - } - - orderByCommand, orderByCommandIsValid := commands[nextCommandIndex].(*ast.OrderByCommand) - - if orderByCommandIsValid { - return engineSQL.SelectFromTableWithOrderBy(selectCommand, orderByCommand).ToString() - } - } - - return engineSQL.SelectFromTable(selectCommand).ToString() -} diff --git a/modes/handler.go b/modes/handler.go index 4e4116f..7ad50a0 100644 --- a/modes/handler.go +++ b/modes/handler.go @@ -7,7 +7,6 @@ import ( "github.com/LissaGreense/GO4SQL/engine" "github.com/LissaGreense/GO4SQL/lexer" "github.com/LissaGreense/GO4SQL/parser" - "io/ioutil" "log" "net" "os" @@ -15,22 +14,22 @@ import ( ) // HandleFileMode - Handle GO4SQL use case where client sends input via text file -func HandleFileMode(filePath string, engine *engine.DbEngine, evaluate func(sequences *ast.Sequence, engineSQL *engine.DbEngine) string) { - content, err := ioutil.ReadFile(filePath) +func HandleFileMode(filePath string, engine *engine.DbEngine) { + content, err := os.ReadFile(filePath) if err != nil { log.Fatal(err) } sequences := bytesToSequences(content) - fmt.Print(evaluate(sequences, engine)) + fmt.Print(evaluateInEngine(sequences, engine)) } // HandleStreamMode - Handle GO4SQL use case where client sends input via stdin -func HandleStreamMode(engine *engine.DbEngine, evaluate func(sequences *ast.Sequence, engineSQL *engine.DbEngine) string) { +func HandleStreamMode(engine *engine.DbEngine) { reader := bufio.NewScanner(os.Stdin) for reader.Scan() { sequences := bytesToSequences(reader.Bytes()) - fmt.Print(evaluate(sequences, engine)) + fmt.Print(evaluateInEngine(sequences, engine)) } err := reader.Err() if err != nil { @@ -39,7 +38,7 @@ func HandleStreamMode(engine *engine.DbEngine, evaluate func(sequences *ast.Sequ } // HandleSocketMode - Handle GO4SQL use case where client sends input via socket protocol -func HandleSocketMode(port int, engine *engine.DbEngine, evaluate func(sequences *ast.Sequence, engineSQL *engine.DbEngine) string) { +func HandleSocketMode(port int, engine *engine.DbEngine) { listener, err := net.Listen("tcp", "localhost:"+strconv.Itoa(port)) log.Printf("Starting Socket Server on %d port\n", port) @@ -61,9 +60,79 @@ func HandleSocketMode(port int, engine *engine.DbEngine, evaluate func(sequences continue } - go handleSocketClient(conn, engine, evaluate) + go handleSocketClient(conn, engine) } } +func evaluateInEngine(sequences *ast.Sequence, engineSQL *engine.DbEngine) string { + commands := sequences.Commands + + result := "" + for commandIndex, command := range commands { + + switch mappedCommand := command.(type) { + case *ast.WhereCommand: + continue + case *ast.OrderByCommand: + continue + case *ast.CreateCommand: + engineSQL.CreateTable(mappedCommand) + result += "Table '" + mappedCommand.Name.GetToken().Literal + "' has been created\n" + continue + case *ast.InsertCommand: + engineSQL.InsertIntoTable(mappedCommand) + result += "Data Inserted\n" + continue + case *ast.SelectCommand: + result += getSelectResponse(commandIndex, &commands, engineSQL, mappedCommand) + "\n" + continue + case *ast.DeleteCommand: + nextCommandIndex := commandIndex + 1 + if nextCommandIndex != len(commands) { + whereCommand, whereCommandIsValid := commands[nextCommandIndex].(*ast.WhereCommand) + + if whereCommandIsValid { + engineSQL.DeleteFromTable(mappedCommand, whereCommand) + } + } + result += "Data from '" + mappedCommand.Name.GetToken().Literal + "' has been deleted\n" + continue + default: + log.Fatalf("Unsupported Command detected: %v", command) + } + } + + return result +} + +func getSelectResponse(commandIndex int, commands *[]ast.Command, engineSQL *engine.DbEngine, selectCommand *ast.SelectCommand) string { + // TODO: this function should be a method of ast.SelectCommand + nextCommandIndex := commandIndex + 1 + + if nextCommandIndex != len(*commands) { + whereCommand, whereCommandIsValid := (*commands)[nextCommandIndex].(*ast.WhereCommand) + + // TODO: It cannot be like that. Have to be refactored to tree structure. + if whereCommandIsValid { + if nextCommandIndex+1 < len(*commands) { + orderByCommand, orderByCommandIsValid := (*commands)[nextCommandIndex+1].(*ast.OrderByCommand) + + if orderByCommandIsValid { + return engineSQL.SelectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand).ToString() + } + } + + return engineSQL.SelectFromTableWithWhere(selectCommand, whereCommand).ToString() + } + + orderByCommand, orderByCommandIsValid := (*commands)[nextCommandIndex].(*ast.OrderByCommand) + + if orderByCommandIsValid { + return engineSQL.SelectFromTableWithOrderBy(selectCommand, orderByCommand).ToString() + } + } + + return engineSQL.SelectFromTable(selectCommand).ToString() +} func bytesToSequences(content []byte) *ast.Sequence { lex := lexer.RunLexer(string(content)) @@ -73,7 +142,7 @@ func bytesToSequences(content []byte) *ast.Sequence { return sequences } -func handleSocketClient(conn net.Conn, engine *engine.DbEngine, evaluate func(sequences *ast.Sequence, engineSQL *engine.DbEngine) string) { +func handleSocketClient(conn net.Conn, engine *engine.DbEngine) { defer func(conn net.Conn) { err := conn.Close() if err != nil { @@ -89,7 +158,7 @@ func handleSocketClient(conn net.Conn, engine *engine.DbEngine, evaluate func(se log.Fatal("Error:", err) } sequences := bytesToSequences(buffer) - commandResult := evaluate(sequences, engine) + commandResult := evaluateInEngine(sequences, engine) if len(commandResult) > 0 { _, err = conn.Write([]byte(commandResult)) diff --git a/parser/parser.go b/parser/parser.go index dc8d42d..15e49ec 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -11,17 +11,19 @@ import ( // Parser - Contain token that is currently analyzed by parser and the next one. Lexer is used to tokenize the client // text input. type Parser struct { - lexer *lexer.Lexer + lexer lexer.Lexer currentToken token.Token peekToken token.Token } // New - Return new Parser struct func New(lexer *lexer.Lexer) *Parser { - p := &Parser{lexer: lexer} + p := &Parser{lexer: *lexer} + // Read two tokens, so curToken and peekToken are both set p.nextToken() p.nextToken() + return p } @@ -65,7 +67,7 @@ func validateToken(tokenType token.Type, expectedTokens []token.Type) { // // Example of input parsable to the ast.CreateCommand: // create table tbl( one TEXT , two INT ); -func (parser *Parser) parseCreateCommand() ast.Command { // TODO make it return the pointer +func (parser *Parser) parseCreateCommand() ast.Command { // token.CREATE already at current position in parser createCommand := &ast.CreateCommand{Token: parser.currentToken} @@ -75,7 +77,7 @@ func (parser *Parser) parseCreateCommand() ast.Command { // TODO make it return validateTokenAndSkip(parser, []token.Type{token.TABLE}) validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) - createCommand.Name = &ast.Identifier{Token: parser.currentToken} + createCommand.Name = ast.Identifier{Token: parser.currentToken} // Skip token.IDENT parser.nextToken() @@ -134,7 +136,7 @@ func (parser *Parser) parseInsertCommand() ast.Command { validateTokenAndSkip(parser, []token.Type{token.INTO}) validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) - insertCommand.Name = &ast.Identifier{Token: parser.currentToken} + insertCommand.Name = ast.Identifier{Token: parser.currentToken} // Ignore token.INDENT parser.nextToken() @@ -142,13 +144,13 @@ func (parser *Parser) parseInsertCommand() ast.Command { validateTokenAndSkip(parser, []token.Type{token.LPAREN}) for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.APOSTROPHE { - // TODO: Add apostrophe validation parser.skipIfCurrentTokenIsApostrophe() validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL}) insertCommand.Values = append(insertCommand.Values, parser.currentToken) // Ignore token.IDENT or token.LITERAL parser.nextToken() + parser.skipIfCurrentTokenIsApostrophe() if parser.currentToken.Type != token.COMMA { @@ -196,7 +198,7 @@ func (parser *Parser) parseSelectCommand() ast.Command { validateTokenAndSkip(parser, []token.Type{token.FROM}) - selectCommand.Name = &ast.Identifier{Token: parser.currentToken} + selectCommand.Name = ast.Identifier{Token: parser.currentToken} // Ignore token.INDENT parser.nextToken() @@ -249,7 +251,7 @@ func (parser *Parser) parseDeleteCommand() ast.Command { validateTokenAndSkip(parser, []token.Type{token.FROM}) validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) - deleteCommand.Name = &ast.Identifier{Token: parser.currentToken} + deleteCommand.Name = ast.Identifier{Token: parser.currentToken} // token.IDENT no longer needed parser.nextToken() @@ -377,32 +379,21 @@ func (parser *Parser) getBooleanExpression() (bool, *ast.BooleanExpression) { // getConditionalExpression - Return ast.ConditionExpression created from tokens and validate the syntax func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression) { - // TODO REFACTOR THIS conditionalExpression := &ast.ConditionExpression{} - if parser.currentToken.Type == token.IDENT { - conditionalExpression.Left = ast.Identifier{ - Token: parser.currentToken, - } + switch parser.currentToken.Type { + case token.IDENT: + conditionalExpression.Left = ast.Identifier{Token: parser.currentToken} parser.nextToken() - - } else if parser.currentToken.Type == token.APOSTROPHE { + case token.APOSTROPHE: parser.skipIfCurrentTokenIsApostrophe() - - conditionalExpression.Left = ast.Anonymitifier{ - Token: parser.currentToken, - } - + conditionalExpression.Left = ast.Anonymitifier{Token: parser.currentToken} parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.APOSTROPHE}) - } else if parser.currentToken.Type == token.LITERAL { - conditionalExpression.Left = ast.Anonymitifier{ - Token: parser.currentToken, - } + case token.LITERAL: + conditionalExpression.Left = ast.Anonymitifier{Token: parser.currentToken} parser.nextToken() - - } else { + default: return false, conditionalExpression } @@ -410,30 +401,19 @@ func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression conditionalExpression.Condition = parser.currentToken parser.nextToken() - if parser.currentToken.Type == token.IDENT { - conditionalExpression.Right = ast.Identifier{ - Token: parser.currentToken, - } + switch parser.currentToken.Type { + case token.IDENT: + conditionalExpression.Right = ast.Identifier{Token: parser.currentToken} parser.nextToken() - - } else if parser.currentToken.Type == token.APOSTROPHE { + case token.APOSTROPHE: parser.skipIfCurrentTokenIsApostrophe() - - conditionalExpression.Right = ast.Anonymitifier{ - Token: parser.currentToken, - } - + conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.APOSTROPHE}) - - } else if parser.currentToken.Type == token.LITERAL { - conditionalExpression.Right = ast.Anonymitifier{ - Token: parser.currentToken, - } + case token.LITERAL: + conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} parser.nextToken() - - } else { + default: log.Fatal("Syntax error, expecting: ", token.APOSTROPHE, ",", token.IDENT, ",", token.LITERAL, ", got: ", parser.currentToken.Literal) } diff --git a/parser/parser_test.go b/parser/parser_test.go index 0e4a1e1..472d9a0 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -390,7 +390,7 @@ func tokenArrayEquals(a []token.Token, b []token.Token) bool { func testOrderByCommands(t *testing.T, expectedOrderByCommand ast.OrderByCommand, actualOrderByCommand *ast.OrderByCommand) { if expectedOrderByCommand.Token.Type != actualOrderByCommand.Token.Type { - t.Errorf("Expecting Token Type: %q, got: %q", expectedOrderByCommand.Token.Type, actualOrderByCommand.Token.Type) + t.Errorf("Expecting Token TokenType: %q, got: %q", expectedOrderByCommand.Token.Type, actualOrderByCommand.Token.Type) } if expectedOrderByCommand.Token.Literal != actualOrderByCommand.Token.Literal { t.Errorf("Expecting Token Literal: %s, got: %s", expectedOrderByCommand.Token.Literal, actualOrderByCommand.Token.Literal) From ecb20457787e6a4543410e5f09ebfde3f15eceac Mon Sep 17 00:00:00 2001 From: LissaGreense Date: Thu, 16 May 2024 23:41:36 +0200 Subject: [PATCH 03/21] Fix parser test pointer --- parser/parser_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parser/parser_test.go b/parser/parser_test.go index 472d9a0..6ca3640 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -188,7 +188,7 @@ func TestParseDeleteCommand(t *testing.T) { input := "DELETE FROM colName1 WHERE colName2 EQUAL 6462389;" expectedDeleteCommand := ast.DeleteCommand{ Token: token.Token{Type: token.DELETE, Literal: "DELETE"}, - Name: &ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName1"}}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName1"}}, } expectedWhereCommand := ast.ConditionExpression{ Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName2"}}, From c224496d41adfc657e5672eed0d239c9b2f124fa Mon Sep 17 00:00:00 2001 From: LissaGreense Date: Wed, 22 May 2024 23:48:57 +0200 Subject: [PATCH 04/21] Final refactor - improve select structure --- README.md | 2 +- ast/ast.go | 58 ++++++++++++++++++++++++++--- engine/engine.go | 87 ++++++++++++++++++++++++++++++++++++------- engine/engine_test.go | 34 +++++------------ modes/handler.go | 76 ++----------------------------------- parser/parser.go | 35 ++++++++++------- parser/parser_test.go | 54 +++++++++++++++++---------- 7 files changed, 197 insertions(+), 149 deletions(-) diff --git a/README.md b/README.md index db942e2..b111801 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ docker build -t go4sql:test . ``` ### Run docker in interactive stream mode -To run this docker image in interactive stream mode mode use this command: +To run this docker image in interactive stream mode use this command: ```shell docker run -i go4sql:test -stream diff --git a/ast/ast.go b/ast/ast.go index 7b18d17..3fd6148 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -158,14 +158,46 @@ func (ls InsertCommand) TokenLiteral() string { return ls.Token.Literal } // Example: // SELECT one, two FROM table1; type SelectCommand struct { - Token token.Token - Name Identifier - Space []token.Token // ex. column names + Token token.Token + Name Identifier // ex. name of table + Space []token.Token // ex. column names + WhereCommand *WhereCommand // optional + OrderByCommand *OrderByCommand // optional } func (ls SelectCommand) CommandNode() {} func (ls SelectCommand) TokenLiteral() string { return ls.Token.Literal } +// HasWhereCommand - returns true if optional HasWhereCommand is present in SelectCommand +// +// Example: +// SELECT * FROM table WHERE column1 NOT 'hi'; +// Returns true +// +// SELECT * FROM table; +// Returns false +func (ls SelectCommand) HasWhereCommand() bool { + if ls.WhereCommand == nil { + return false + } + return true +} + +// HasOrderByCommand - returns true if optional OrderByCommand is present in SelectCommand +// +// Example: +// SELECT * FROM table ORDER BY column1 ASC; +// Returns true +// +// SELECT * FROM table; +// Returns false +func (ls SelectCommand) HasOrderByCommand() bool { + if ls.OrderByCommand == nil { + return false + } + return true +} + // WhereCommand - Part of Command that represent Where statement with expression that will qualify values from Select // // Example: @@ -183,13 +215,29 @@ func (ls WhereCommand) TokenLiteral() string { return ls.Token.Literal } // Example: // DELETE FROM tb1 WHERE two EQUAL 3; type DeleteCommand struct { - Token token.Token - Name Identifier // name of the table + Token token.Token + Name Identifier // name of the table + WhereCommand *WhereCommand // optional } func (ls DeleteCommand) CommandNode() {} func (ls DeleteCommand) TokenLiteral() string { return ls.Token.Literal } +// HasWhereCommand - returns true if optional HasWhereCommand is present in SelectCommand +// +// Example: +// SELECT * FROM table WHERE column1 NOT 'hi'; +// Returns true +// +// SELECT * FROM table; +// Returns false +func (ls DeleteCommand) HasWhereCommand() bool { + if ls.WhereCommand == nil { + return false + } + return true +} + // OrderByCommand - Part of Command that ordering columns from SelectCommand // // Example: diff --git a/engine/engine.go b/engine/engine.go index ce21b57..ee46ff9 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -24,8 +24,67 @@ func New() *DbEngine { return engine } -// CreateTable - initialize new table in engine with specified name -func (engine *DbEngine) CreateTable(command *ast.CreateCommand) { +// Evaluate - it takes sequences, map them to specific implementation and then process it in SQL engine +func (engine *DbEngine) Evaluate(sequences *ast.Sequence) string { + commands := sequences.Commands + + result := "" + for commandIndex, command := range commands { + + switch mappedCommand := command.(type) { + case *ast.WhereCommand: + continue + case *ast.OrderByCommand: + continue + case *ast.CreateCommand: + engine.createTable(mappedCommand) + result += "Table '" + mappedCommand.Name.GetToken().Literal + "' has been created\n" + continue + case *ast.InsertCommand: + engine.insertIntoTable(mappedCommand) + result += "Data Inserted\n" + continue + case *ast.SelectCommand: + result += engine.GetSelectResponse(mappedCommand) + "\n" + continue + case *ast.DeleteCommand: + nextCommandIndex := commandIndex + 1 + if nextCommandIndex != len(commands) { + whereCommand, whereCommandIsValid := commands[nextCommandIndex].(*ast.WhereCommand) + + if whereCommandIsValid { + engine.deleteFromTable(mappedCommand, whereCommand) + } + } + result += "Data from '" + mappedCommand.Name.GetToken().Literal + "' has been deleted\n" + continue + default: + log.Fatalf("Unsupported Command detected: %v", command) + } + } + + return result +} + +// GetSelectResponse - Returns Select response basing on ast.OrderByCommand and ast.WhereCommand included in this Select +func (engine *DbEngine) GetSelectResponse(selectCommand *ast.SelectCommand) string { + if selectCommand.HasWhereCommand() { + whereCommand := selectCommand.WhereCommand + if selectCommand.HasOrderByCommand() { + orderByCommand := selectCommand.OrderByCommand + return engine.selectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand).ToString() + } + return engine.selectFromTableWithWhere(selectCommand, whereCommand).ToString() + } + if selectCommand.HasOrderByCommand() { + orderByCommand := selectCommand.OrderByCommand + return engine.selectFromTableWithOrderBy(selectCommand, orderByCommand).ToString() + } + return engine.selectFromTable(selectCommand).ToString() +} + +// createTable - initialize new table in engine with specified name +func (engine *DbEngine) createTable(command *ast.CreateCommand) { _, exist := engine.Tables[command.Name.Token.Literal] if exist { @@ -43,8 +102,8 @@ func (engine *DbEngine) CreateTable(command *ast.CreateCommand) { } } -// InsertIntoTable - Insert row of values into the table -func (engine *DbEngine) InsertIntoTable(command *ast.InsertCommand) { +// insertIntoTable - Insert row of values into the table +func (engine *DbEngine) insertIntoTable(command *ast.InsertCommand) { table, exist := engine.Tables[command.Name.Token.Literal] if !exist { log.Fatal("Table with the name of " + command.Name.Token.Literal + " doesn't exist!") @@ -65,8 +124,8 @@ func (engine *DbEngine) InsertIntoTable(command *ast.InsertCommand) { } } -// SelectFromTable - Return Table containing all values requested by SelectCommand -func (engine *DbEngine) SelectFromTable(command *ast.SelectCommand) *Table { +// selectFromTable - Return Table containing all values requested by SelectCommand +func (engine *DbEngine) selectFromTable(command *ast.SelectCommand) *Table { table, exist := engine.Tables[command.Name.Token.Literal] if !exist { @@ -93,8 +152,8 @@ func (engine *DbEngine) selectFromProvidedTable(command *ast.SelectCommand, tabl } } -// DeleteFromTable - Delete all rows of data from table that match given condition -func (engine *DbEngine) DeleteFromTable(deleteCommand *ast.DeleteCommand, whereCommand *ast.WhereCommand) { +// deleteFromTable - Delete all rows of data from table that match given condition +func (engine *DbEngine) deleteFromTable(deleteCommand *ast.DeleteCommand, whereCommand *ast.WhereCommand) { table, exist := engine.Tables[deleteCommand.Name.Token.Literal] if !exist { @@ -104,8 +163,8 @@ func (engine *DbEngine) DeleteFromTable(deleteCommand *ast.DeleteCommand, whereC engine.Tables[deleteCommand.Name.Token.Literal] = engine.getFilteredTable(table, whereCommand, true) } -// SelectFromTableWithWhere - Return Table containing all values requested by SelectCommand and filtered by WhereCommand -func (engine *DbEngine) SelectFromTableWithWhere(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand) *Table { +// selectFromTableWithWhere - Return Table containing all values requested by SelectCommand and filtered by WhereCommand +func (engine *DbEngine) selectFromTableWithWhere(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand) *Table { table, exist := engine.Tables[selectCommand.Name.Token.Literal] if !exist { @@ -121,9 +180,9 @@ func (engine *DbEngine) SelectFromTableWithWhere(selectCommand *ast.SelectComman return engine.selectFromProvidedTable(selectCommand, filteredTable) } -// SelectFromTableWithWhereAndOrderBy - Return Table containing all values requested by SelectCommand, +// selectFromTableWithWhereAndOrderBy - Return Table containing all values requested by SelectCommand, // filtered by WhereCommand and sorted by OrderByCommand -func (engine *DbEngine) SelectFromTableWithWhereAndOrderBy(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand, orderByCommand *ast.OrderByCommand) *Table { +func (engine *DbEngine) selectFromTableWithWhereAndOrderBy(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand, orderByCommand *ast.OrderByCommand) *Table { table, exist := engine.Tables[selectCommand.Name.Token.Literal] if !exist { @@ -137,8 +196,8 @@ func (engine *DbEngine) SelectFromTableWithWhereAndOrderBy(selectCommand *ast.Se return engine.selectFromProvidedTable(selectCommand, engine.getSortedTable(orderByCommand, filteredTable, emptyTable)) } -// SelectFromTableWithOrderBy - Return Table containing all values requested by SelectCommand and sorted by OrderByCommand -func (engine *DbEngine) SelectFromTableWithOrderBy(selectCommand *ast.SelectCommand, orderByCommand *ast.OrderByCommand) *Table { +// selectFromTableWithOrderBy - Return Table containing all values requested by SelectCommand and sorted by OrderByCommand +func (engine *DbEngine) selectFromTableWithOrderBy(selectCommand *ast.SelectCommand, orderByCommand *ast.OrderByCommand) *Table { table, exist := engine.Tables[selectCommand.Name.Token.Literal] if !exist { diff --git a/engine/engine_test.go b/engine/engine_test.go index 8e7f54c..b6380b1 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -305,9 +305,6 @@ func (engineTestSuite *engineTestSuite) runTestSuite(t *testing.T) { input += engineTestSuite.createInputs[inputIndex] + "\n" } for inputIndex := 0; inputIndex < len(engineTestSuite.insertAndDeleteInputs); inputIndex++ { - if strings.HasPrefix(engineTestSuite.insertAndDeleteInputs[inputIndex], "DELETE") { - expectedSequencesNumber++ - } input += engineTestSuite.insertAndDeleteInputs[inputIndex] + "\n" } input += engineTestSuite.selectInput @@ -319,42 +316,31 @@ func (engineTestSuite *engineTestSuite) runTestSuite(t *testing.T) { expectedSequencesNumber += len(engineTestSuite.createInputs) + len(engineTestSuite.insertAndDeleteInputs) + 1 var actualTable *Table - - if strings.Contains(engineTestSuite.selectInput, "ORDER BY") { - expectedSequencesNumber++ - } + engine := engineTestSuite.getEngineWithInsertedValues(sequences) + selectCommand := sequences.Commands[len(sequences.Commands)-1].(*ast.SelectCommand) if strings.Contains(engineTestSuite.selectInput, " WHERE ") { // WHERE CONDITION - - expectedSequencesNumber++ if len(sequences.Commands) != expectedSequencesNumber { t.Fatalf("sequences does not contain %d statements. got=%d", expectedSequencesNumber, len(sequences.Commands)) } - - engine := engineTestSuite.getEngineWithInsertedValues(sequences) - if strings.Contains(engineTestSuite.selectInput, "ORDER BY") { - actualTable = engine.SelectFromTableWithWhereAndOrderBy(sequences.Commands[len(sequences.Commands)-3].(*ast.SelectCommand), sequences.Commands[len(sequences.Commands)-2].(*ast.WhereCommand), sequences.Commands[len(sequences.Commands)-1].(*ast.OrderByCommand)) + actualTable = engine.selectFromTableWithWhereAndOrderBy(selectCommand, selectCommand.WhereCommand, selectCommand.OrderByCommand) } else { - actualTable = engine.SelectFromTableWithWhere(sequences.Commands[len(sequences.Commands)-2].(*ast.SelectCommand), sequences.Commands[len(sequences.Commands)-1].(*ast.WhereCommand)) + actualTable = engine.selectFromTableWithWhere(selectCommand, selectCommand.WhereCommand) } } else { // NO WHERE CONDITION - if len(sequences.Commands) != expectedSequencesNumber { t.Fatalf("sequences does not contain %d statements. got=%d", expectedSequencesNumber, len(sequences.Commands)) } - - engine := engineTestSuite.getEngineWithInsertedValues(sequences) - if strings.Contains(engineTestSuite.selectInput, "ORDER BY") { - actualTable = engine.SelectFromTableWithOrderBy(sequences.Commands[len(sequences.Commands)-2].(*ast.SelectCommand), sequences.Commands[len(sequences.Commands)-1].(*ast.OrderByCommand)) + actualTable = engine.selectFromTableWithOrderBy(selectCommand, selectCommand.OrderByCommand) } else { - actualTable = engine.SelectFromTable(sequences.Commands[len(sequences.Commands)-1].(*ast.SelectCommand)) + actualTable = engine.selectFromTable(selectCommand) } } @@ -386,14 +372,14 @@ func (engineTestSuite *engineTestSuite) getEngineWithInsertedValues(sequences *a engine := New() for commandIndex := 0; commandIndex < len(sequences.Commands); commandIndex++ { if createCommand, ok := sequences.Commands[commandIndex].(*ast.CreateCommand); ok { - engine.CreateTable(createCommand) + engine.createTable(createCommand) } if insertCommand, ok := sequences.Commands[commandIndex].(*ast.InsertCommand); ok { - engine.InsertIntoTable(insertCommand) + engine.insertIntoTable(insertCommand) } if deleteCommand, ok := sequences.Commands[commandIndex].(*ast.DeleteCommand); ok { - whereCommand := sequences.Commands[commandIndex+1].(*ast.WhereCommand) - engine.DeleteFromTable(deleteCommand, whereCommand) + whereCommand := deleteCommand.WhereCommand + engine.deleteFromTable(deleteCommand, whereCommand) } } return engine diff --git a/modes/handler.go b/modes/handler.go index 7ad50a0..13a0710 100644 --- a/modes/handler.go +++ b/modes/handler.go @@ -21,7 +21,7 @@ func HandleFileMode(filePath string, engine *engine.DbEngine) { } sequences := bytesToSequences(content) - fmt.Print(evaluateInEngine(sequences, engine)) + fmt.Print(engine.Evaluate(sequences)) } // HandleStreamMode - Handle GO4SQL use case where client sends input via stdin @@ -29,7 +29,7 @@ func HandleStreamMode(engine *engine.DbEngine) { reader := bufio.NewScanner(os.Stdin) for reader.Scan() { sequences := bytesToSequences(reader.Bytes()) - fmt.Print(evaluateInEngine(sequences, engine)) + fmt.Print(engine.Evaluate(sequences)) } err := reader.Err() if err != nil { @@ -63,76 +63,6 @@ func HandleSocketMode(port int, engine *engine.DbEngine) { go handleSocketClient(conn, engine) } } -func evaluateInEngine(sequences *ast.Sequence, engineSQL *engine.DbEngine) string { - commands := sequences.Commands - - result := "" - for commandIndex, command := range commands { - - switch mappedCommand := command.(type) { - case *ast.WhereCommand: - continue - case *ast.OrderByCommand: - continue - case *ast.CreateCommand: - engineSQL.CreateTable(mappedCommand) - result += "Table '" + mappedCommand.Name.GetToken().Literal + "' has been created\n" - continue - case *ast.InsertCommand: - engineSQL.InsertIntoTable(mappedCommand) - result += "Data Inserted\n" - continue - case *ast.SelectCommand: - result += getSelectResponse(commandIndex, &commands, engineSQL, mappedCommand) + "\n" - continue - case *ast.DeleteCommand: - nextCommandIndex := commandIndex + 1 - if nextCommandIndex != len(commands) { - whereCommand, whereCommandIsValid := commands[nextCommandIndex].(*ast.WhereCommand) - - if whereCommandIsValid { - engineSQL.DeleteFromTable(mappedCommand, whereCommand) - } - } - result += "Data from '" + mappedCommand.Name.GetToken().Literal + "' has been deleted\n" - continue - default: - log.Fatalf("Unsupported Command detected: %v", command) - } - } - - return result -} - -func getSelectResponse(commandIndex int, commands *[]ast.Command, engineSQL *engine.DbEngine, selectCommand *ast.SelectCommand) string { - // TODO: this function should be a method of ast.SelectCommand - nextCommandIndex := commandIndex + 1 - - if nextCommandIndex != len(*commands) { - whereCommand, whereCommandIsValid := (*commands)[nextCommandIndex].(*ast.WhereCommand) - - // TODO: It cannot be like that. Have to be refactored to tree structure. - if whereCommandIsValid { - if nextCommandIndex+1 < len(*commands) { - orderByCommand, orderByCommandIsValid := (*commands)[nextCommandIndex+1].(*ast.OrderByCommand) - - if orderByCommandIsValid { - return engineSQL.SelectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand).ToString() - } - } - - return engineSQL.SelectFromTableWithWhere(selectCommand, whereCommand).ToString() - } - - orderByCommand, orderByCommandIsValid := (*commands)[nextCommandIndex].(*ast.OrderByCommand) - - if orderByCommandIsValid { - return engineSQL.SelectFromTableWithOrderBy(selectCommand, orderByCommand).ToString() - } - } - - return engineSQL.SelectFromTable(selectCommand).ToString() -} func bytesToSequences(content []byte) *ast.Sequence { lex := lexer.RunLexer(string(content)) @@ -158,7 +88,7 @@ func handleSocketClient(conn net.Conn, engine *engine.DbEngine) { log.Fatal("Error:", err) } sequences := bytesToSequences(buffer) - commandResult := evaluateInEngine(sequences, engine) + commandResult := engine.Evaluate(sequences) if len(commandResult) > 0 { _, err = conn.Write([]byte(commandResult)) diff --git a/parser/parser.go b/parser/parser.go index 15e49ec..d6d4433 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -439,23 +439,24 @@ func (parser *Parser) ParseSequence() *ast.Sequence { case token.DELETE: command = parser.parseDeleteCommand() case token.WHERE: - if len(sequence.Commands) == 0 { - log.Fatal("Syntax error, Where Command can't be used without predecessor") - } - lastStartingToken := sequence.Commands[len(sequence.Commands)-1].TokenLiteral() - if lastStartingToken != token.SELECT && lastStartingToken != token.DELETE { + lastCommand := parser.getLastCommand(sequence) + + if lastCommand.TokenLiteral() == token.SELECT { + lastCommand.(*ast.SelectCommand).WhereCommand = parser.parseWhereCommand().(*ast.WhereCommand) + } else if lastCommand.TokenLiteral() == token.DELETE { + lastCommand.(*ast.DeleteCommand).WhereCommand = parser.parseWhereCommand().(*ast.WhereCommand) + } else { log.Fatal("Syntax error, WHERE command needs SELECT or DELETE command before") } - command = parser.parseWhereCommand() case token.ORDER: - if len(sequence.Commands) == 0 { - log.Fatal("Syntax error, Order Command can't be used without predecessor") - } - lastStartingToken := sequence.Commands[len(sequence.Commands)-1].TokenLiteral() - if lastStartingToken != token.SELECT && lastStartingToken != token.WHERE { - log.Fatal("Syntax error, WHERE command needs SELECT or WHERE command before") + lastCommand := parser.getLastCommand(sequence) + + if lastCommand.TokenLiteral() != token.SELECT { + log.Fatal("Syntax error, ORDER BY command needs SELECT command before") } - command = parser.parseOrderByCommand() + + selectCommand := lastCommand.(*ast.SelectCommand) + selectCommand.OrderByCommand = parser.parseOrderByCommand().(*ast.OrderByCommand) default: log.Fatal("Syntax error, invalid command found: ", parser.currentToken.Type) } @@ -468,3 +469,11 @@ func (parser *Parser) ParseSequence() *ast.Sequence { return sequence } + +func (parser *Parser) getLastCommand(sequence *ast.Sequence) ast.Command { + if len(sequence.Commands) == 0 { + log.Fatal("Syntax error, Where Command can't be used without predecessor") + } + lastCommand := sequence.Commands[len(sequence.Commands)-1] + return lastCommand +} diff --git a/parser/parser_test.go b/parser/parser_test.go index 6ca3640..2ce2ad0 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -174,11 +174,16 @@ func TestParseWhereCommand(t *testing.T) { parserInstance := New(lexer) sequences := parserInstance.ParseSequence() - if len(sequences.Commands) != 2 { - t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements, got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + if !selectCommand.HasWhereCommand() { + t.Fatalf("sequences does not contain where command") } - if !whereStatementIsValid(t, sequences.Commands[1], tt.expectedExpression) { + if !whereStatementIsValid(t, selectCommand.WhereCommand, tt.expectedExpression) { return } } @@ -200,8 +205,8 @@ func TestParseDeleteCommand(t *testing.T) { parserInstance := New(lexer) sequences := parserInstance.ParseSequence() - if len(sequences.Commands) != 2 { - t.Fatalf("sequences does not contain 2 statements. got=%d", len(sequences.Commands)) + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) } actualDeleteCommand, ok := sequences.Commands[0].(*ast.DeleteCommand) @@ -217,7 +222,11 @@ func TestParseDeleteCommand(t *testing.T) { t.Errorf("Table name of DeleteCommand is not %s. got=%s", expectedDeleteCommand.Name.GetToken().Literal, actualDeleteCommand.Name.GetToken().Literal) } - if !whereStatementIsValid(t, sequences.Commands[1], expectedWhereCommand) { + if !actualDeleteCommand.HasWhereCommand() { + t.Fatalf("sequences does not contain where command") + } + + if !whereStatementIsValid(t, actualDeleteCommand.WhereCommand, expectedWhereCommand) { return } } @@ -239,20 +248,21 @@ func TestSelectWithOrderByCommand(t *testing.T) { parserInstance := New(lexer) sequences := parserInstance.ParseSequence() - if len(sequences.Commands) != 2 { - t.Fatalf("sequences does not contain 2 statements. got=%d", len(sequences.Commands)) + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) } - if !testSelectStatement(t, sequences.Commands[0], expectedTableName, expectedColumnName) { + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName) { return } - actualOrderByCommand, orderByCommandIsOk := sequences.Commands[1].(*ast.OrderByCommand) - if !orderByCommandIsOk { - t.Errorf("actualDeleteCommand is not %T. got=%T", &ast.OrderByCommand{}, sequences.Commands[0]) + if !selectCommand.HasOrderByCommand() { + t.Fatalf("sequences does not contain where command") } - testOrderByCommands(t, expectedOrderByCommand, actualOrderByCommand) + testOrderByCommands(t, expectedOrderByCommand, selectCommand.OrderByCommand) } func TestParseLogicOperatorsInCommand(t *testing.T) { @@ -308,11 +318,17 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { parserInstance := New(lexer) sequences := parserInstance.ParseSequence() - if len(sequences.Commands) != 2 { - t.Fatalf("sequences does not contain 2 statements. got=%d", len(sequences.Commands)) + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !selectCommand.HasWhereCommand() { + t.Fatalf("sequences does not contain where command") } - if !whereStatementIsValid(t, sequences.Commands[1], tt.expectedExpression) { + if !whereStatementIsValid(t, selectCommand.WhereCommand, tt.expectedExpression) { t.Fatalf("Actual expression and expected one are different") } } @@ -469,13 +485,13 @@ func validateConditionExpression(second ast.Expression, conditionExpression *ast } func validateBooleanExpressions(second ast.Expression, booleanExpression *ast.BooleanExpression) bool { - secondBooleanExpresion, secondBooleanExpresionIsValid := second.(ast.BooleanExpression) + secondBooleanExpression, secondBooleanExpressionIsValid := second.(ast.BooleanExpression) - if !secondBooleanExpresionIsValid { + if !secondBooleanExpressionIsValid { return false } - if booleanExpression.Boolean.Literal != secondBooleanExpresion.Boolean.Literal { + if booleanExpression.Boolean.Literal != secondBooleanExpression.Boolean.Literal { return false } From 414c56cec86c5ac4d577bd7e6780243979665226 Mon Sep 17 00:00:00 2001 From: LissaGreense Date: Wed, 22 May 2024 23:55:19 +0200 Subject: [PATCH 05/21] Fix delete command --- engine/engine.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/engine/engine.go b/engine/engine.go index ee46ff9..98ec3b3 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -29,7 +29,7 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) string { commands := sequences.Commands result := "" - for commandIndex, command := range commands { + for _, command := range commands { switch mappedCommand := command.(type) { case *ast.WhereCommand: @@ -48,13 +48,9 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) string { result += engine.GetSelectResponse(mappedCommand) + "\n" continue case *ast.DeleteCommand: - nextCommandIndex := commandIndex + 1 - if nextCommandIndex != len(commands) { - whereCommand, whereCommandIsValid := commands[nextCommandIndex].(*ast.WhereCommand) - - if whereCommandIsValid { - engine.deleteFromTable(mappedCommand, whereCommand) - } + deleteCommand := command.(*ast.DeleteCommand) + if deleteCommand.HasWhereCommand() { + engine.deleteFromTable(mappedCommand, deleteCommand.WhereCommand) } result += "Data from '" + mappedCommand.Name.GetToken().Literal + "' has been deleted\n" continue From 2a6a1a07374d18e15a2ac8eb4555f3065cd70e6c Mon Sep 17 00:00:00 2001 From: ixior462 Date: Wed, 29 May 2024 00:38:19 +0200 Subject: [PATCH 06/21] Add Drop Function --- .github/expected_results/end2end.txt | 1 + ast/ast.go | 12 ++ engine/engine.go | 23 ++-- engine/engine_test.go | 160 +++++++++++++++------------ lexer/lexer_test.go | 104 +++++------------ parser/parser.go | 28 ++++- parser/parser_test.go | 29 +++++ test_file | 1 + token/token.go | 2 + 9 files changed, 208 insertions(+), 152 deletions(-) diff --git a/.github/expected_results/end2end.txt b/.github/expected_results/end2end.txt index 5a89e9a..2604fbd 100644 --- a/.github/expected_results/end2end.txt +++ b/.github/expected_results/end2end.txt @@ -35,3 +35,4 @@ Data from 'tbl' has been deleted | 'goodbye' | | 'hello' | +-----------+ +Table: 'tbl' has been dropped diff --git a/ast/ast.go b/ast/ast.go index 3fd6148..21f029f 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -223,6 +223,18 @@ type DeleteCommand struct { func (ls DeleteCommand) CommandNode() {} func (ls DeleteCommand) TokenLiteral() string { return ls.Token.Literal } +// DropCommand - Part of Command that represent dropping table +// +// Example: +// DROP TABLE table; +type DropCommand struct { + Token token.Token + Name Identifier // name of the table +} + +func (ls DropCommand) CommandNode() {} +func (ls DropCommand) TokenLiteral() string { return ls.Token.Literal } + // HasWhereCommand - returns true if optional HasWhereCommand is present in SelectCommand // // Example: diff --git a/engine/engine.go b/engine/engine.go index 98ec3b3..7271d7c 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -45,7 +45,7 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) string { result += "Data Inserted\n" continue case *ast.SelectCommand: - result += engine.GetSelectResponse(mappedCommand) + "\n" + result += engine.getSelectResponse(mappedCommand).ToString() + "\n" continue case *ast.DeleteCommand: deleteCommand := command.(*ast.DeleteCommand) @@ -54,6 +54,10 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) string { } result += "Data from '" + mappedCommand.Name.GetToken().Literal + "' has been deleted\n" continue + case *ast.DropCommand: + engine.dropTable(mappedCommand) + result += "Table: '" + mappedCommand.Name.GetToken().Literal + "' has been dropped\n" + continue default: log.Fatalf("Unsupported Command detected: %v", command) } @@ -62,21 +66,21 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) string { return result } -// GetSelectResponse - Returns Select response basing on ast.OrderByCommand and ast.WhereCommand included in this Select -func (engine *DbEngine) GetSelectResponse(selectCommand *ast.SelectCommand) string { +// getSelectResponse - Returns Select response basing on ast.OrderByCommand and ast.WhereCommand included in this Select +func (engine *DbEngine) getSelectResponse(selectCommand *ast.SelectCommand) *Table { if selectCommand.HasWhereCommand() { whereCommand := selectCommand.WhereCommand if selectCommand.HasOrderByCommand() { orderByCommand := selectCommand.OrderByCommand - return engine.selectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand).ToString() + return engine.selectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand) } - return engine.selectFromTableWithWhere(selectCommand, whereCommand).ToString() + return engine.selectFromTableWithWhere(selectCommand, whereCommand) } if selectCommand.HasOrderByCommand() { orderByCommand := selectCommand.OrderByCommand - return engine.selectFromTableWithOrderBy(selectCommand, orderByCommand).ToString() + return engine.selectFromTableWithOrderBy(selectCommand, orderByCommand) } - return engine.selectFromTable(selectCommand).ToString() + return engine.selectFromTable(selectCommand) } // createTable - initialize new table in engine with specified name @@ -159,6 +163,11 @@ func (engine *DbEngine) deleteFromTable(deleteCommand *ast.DeleteCommand, whereC engine.Tables[deleteCommand.Name.Token.Literal] = engine.getFilteredTable(table, whereCommand, true) } +// dropTable - Drop table with given name +func (engine *DbEngine) dropTable(dropCommand *ast.DropCommand) { + delete(engine.Tables, dropCommand.Name.GetToken().Literal) +} + // selectFromTableWithWhere - Return Table containing all values requested by SelectCommand and filtered by WhereCommand func (engine *DbEngine) selectFromTableWithWhere(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand) *Table { table, exist := engine.Tables[selectCommand.Name.Token.Literal] diff --git a/engine/engine_test.go b/engine/engine_test.go index b6380b1..fff7e6f 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -1,7 +1,6 @@ package engine import ( - "strings" "testing" "github.com/LissaGreense/GO4SQL/ast" @@ -9,8 +8,39 @@ import ( "github.com/LissaGreense/GO4SQL/parser" ) +func TestCreate(t *testing.T) { + simpleCreateCase := engineDBContentTestSuite{ + inputs: []string{"CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );"}, + expectedTableNames: []string{"tb1"}, + } + + simpleCreateCase.runTestSuite(t) + + multiplyCreationCase := engineDBContentTestSuite{ + inputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + "CREATE TABLE tb2( one TEXT, two INT, three INT, four TEXT );", + }, + expectedTableNames: []string{"tb1", "tb2"}, + } + + multiplyCreationCase.runTestSuite(t) + +} + +func TestDrop(t *testing.T) { + simpleDropCase := engineDBContentTestSuite{ + inputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + "DROP TABLE tb1;", + }, + expectedTableNames: []string{}, + } + simpleDropCase.runTestSuite(t) +} + func TestSelectCommand(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -32,7 +62,7 @@ func TestSelectCommand(t *testing.T) { } func TestSelectWithColumnNamesCommand(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -54,7 +84,7 @@ func TestSelectWithColumnNamesCommand(t *testing.T) { } func TestSelectWithWhereEqual(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -75,7 +105,7 @@ func TestSelectWithWhereEqual(t *testing.T) { func TestSelectWithWhereNotEqual(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -97,7 +127,7 @@ func TestSelectWithWhereNotEqual(t *testing.T) { func TestSelectWithWhereLogicalOperationAnd(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -118,7 +148,7 @@ func TestSelectWithWhereLogicalOperationAnd(t *testing.T) { func TestSelectWithWhereLogicalOperationOR(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -140,7 +170,7 @@ func TestSelectWithWhereLogicalOperationOR(t *testing.T) { func TestSelectWithWhereLogicalOperationOROperationAND(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -162,7 +192,7 @@ func TestSelectWithWhereLogicalOperationOROperationAND(t *testing.T) { func TestSelectWithWhereEqualToTrue(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -185,7 +215,7 @@ func TestSelectWithWhereEqualToTrue(t *testing.T) { func TestSelectWithWhereEqualToFalse(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -203,7 +233,7 @@ func TestSelectWithWhereEqualToFalse(t *testing.T) { func TestDelete(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -225,7 +255,7 @@ func TestDelete(t *testing.T) { } func TestOrderBy(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -247,7 +277,7 @@ func TestOrderBy(t *testing.T) { } func TestOrderByWithWhere(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -268,7 +298,7 @@ func TestOrderByWithWhere(t *testing.T) { } func TestOrderByWithMultipleSorts(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -291,59 +321,52 @@ func TestOrderByWithMultipleSorts(t *testing.T) { engineTestSuite.runTestSuite(t) } -type engineTestSuite struct { +type engineDBContentTestSuite struct { + inputs []string + expectedTableNames []string +} + +func (engineTestSuite *engineDBContentTestSuite) runTestSuite(t *testing.T) { + sequences := getSequences(inputsToString(engineTestSuite.inputs)) + engine := New() + engine.Evaluate(sequences) + + if len(engine.Tables) != len(engineTestSuite.expectedTableNames) { + t.Fatalf("Number of tables is incorrect, should be %d, got %d", len(engineTestSuite.expectedTableNames), len(engine.Tables)) + } + + for _, tableName := range engineTestSuite.expectedTableNames { + if engine.Tables[tableName] == nil { + t.Fatalf("Expected table '%s' does not exist", tableName) + } + } +} + +type engineTableContentTestSuite struct { createInputs []string insertAndDeleteInputs []string selectInput string expectedOutput [][]string } -func (engineTestSuite *engineTestSuite) runTestSuite(t *testing.T) { - input := "" +func (engineTestSuite *engineTableContentTestSuite) runTestSuite(t *testing.T) { expectedSequencesNumber := 0 - for inputIndex := 0; inputIndex < len(engineTestSuite.createInputs); inputIndex++ { - input += engineTestSuite.createInputs[inputIndex] + "\n" - } - for inputIndex := 0; inputIndex < len(engineTestSuite.insertAndDeleteInputs); inputIndex++ { - input += engineTestSuite.insertAndDeleteInputs[inputIndex] + "\n" - } - input += engineTestSuite.selectInput - lexerInstance := lexer.RunLexer(input) - parserInstance := parser.New(lexerInstance) - sequences := parserInstance.ParseSequence() - - expectedSequencesNumber += len(engineTestSuite.createInputs) + len(engineTestSuite.insertAndDeleteInputs) + 1 - - var actualTable *Table - engine := engineTestSuite.getEngineWithInsertedValues(sequences) - selectCommand := sequences.Commands[len(sequences.Commands)-1].(*ast.SelectCommand) - - if strings.Contains(engineTestSuite.selectInput, " WHERE ") { + input := inputsToString(engineTestSuite.createInputs) + inputsToString(engineTestSuite.insertAndDeleteInputs) - // WHERE CONDITION - if len(sequences.Commands) != expectedSequencesNumber { - t.Fatalf("sequences does not contain %d statements. got=%d", expectedSequencesNumber, len(sequences.Commands)) - } - if strings.Contains(engineTestSuite.selectInput, "ORDER BY") { - actualTable = engine.selectFromTableWithWhereAndOrderBy(selectCommand, selectCommand.WhereCommand, selectCommand.OrderByCommand) - } else { - actualTable = engine.selectFromTableWithWhere(selectCommand, selectCommand.WhereCommand) - } + sequencesWithoutSelect := getSequences(input) + selectCommand := getSequences(engineTestSuite.selectInput) - } else { + expectedSequencesNumber += len(engineTestSuite.createInputs) + len(engineTestSuite.insertAndDeleteInputs) + 1 - // NO WHERE CONDITION - if len(sequences.Commands) != expectedSequencesNumber { - t.Fatalf("sequences does not contain %d statements. got=%d", expectedSequencesNumber, len(sequences.Commands)) - } - if strings.Contains(engineTestSuite.selectInput, "ORDER BY") { - actualTable = engine.selectFromTableWithOrderBy(selectCommand, selectCommand.OrderByCommand) - } else { - actualTable = engine.selectFromTable(selectCommand) - } + if len(sequencesWithoutSelect.Commands)+len(selectCommand.Commands) != expectedSequencesNumber { + t.Fatalf("sequences does not contain %d statements. got=%d", expectedSequencesNumber, len(sequencesWithoutSelect.Commands)) } + engine := New() + engine.Evaluate(sequencesWithoutSelect) + actualTable := engine.getSelectResponse(selectCommand.Commands[0].(*ast.SelectCommand)) + if len(engineTestSuite.expectedOutput) == 0 { if len(actualTable.Columns[0].Values) != 0 { t.Fatalf("Number of rows is incorrect, should be 0, got %d", len(actualTable.Columns)) @@ -368,19 +391,20 @@ func (engineTestSuite *engineTestSuite) runTestSuite(t *testing.T) { } -func (engineTestSuite *engineTestSuite) getEngineWithInsertedValues(sequences *ast.Sequence) *DbEngine { - engine := New() - for commandIndex := 0; commandIndex < len(sequences.Commands); commandIndex++ { - if createCommand, ok := sequences.Commands[commandIndex].(*ast.CreateCommand); ok { - engine.createTable(createCommand) - } - if insertCommand, ok := sequences.Commands[commandIndex].(*ast.InsertCommand); ok { - engine.insertIntoTable(insertCommand) - } - if deleteCommand, ok := sequences.Commands[commandIndex].(*ast.DeleteCommand); ok { - whereCommand := deleteCommand.WhereCommand - engine.deleteFromTable(deleteCommand, whereCommand) - } +func inputsToString(inputs []string) string { + input := "" + + for inputIndex := 0; inputIndex < len(inputs); inputIndex++ { + input += inputs[inputIndex] + "\n" } - return engine + + return input +} + +func getSequences(input string) *ast.Sequence { + lexerInstance := lexer.RunLexer(input) + parserInstance := parser.New(lexerInstance) + sequences := parserInstance.ParseSequence() + + return sequences } diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index 74a1010..7ef4016 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -55,21 +55,7 @@ func TestLexer(t *testing.T) { {token.EOF, ""}, } - l := RunLexer(input) - - for i, tt := range tests { - tok := l.NextToken() - - if tok.Type != tt.expectedType { - t.Fatalf("tests[%d] - tokentype wrong. expected=%q, got=%q", - i, tt.expectedType, tok.Type) - } - - if tok.Literal != tt.expectedLiteral { - t.Fatalf("tests[%d] - literal wrong. expected=%q, got=%q", - i, tt.expectedLiteral, tok.Literal) - } - } + runLexerTestSuite(t, input, tests) } func TestLexerWithNumbersMixedInLitterals(t *testing.T) { @@ -121,21 +107,7 @@ func TestLexerWithNumbersMixedInLitterals(t *testing.T) { {token.EOF, ""}, } - l := RunLexer(input) - - for i, tt := range tests { - tok := l.NextToken() - - if tok.Type != tt.expectedType { - t.Fatalf("tests[%d] - tokentype wrong. expected=%q, got=%q", - i, tt.expectedType, tok.Type) - } - - if tok.Literal != tt.expectedLiteral { - t.Fatalf("tests[%d] - literal wrong. expected=%q, got=%q", - i, tt.expectedLiteral, tok.Literal) - } - } + runLexerTestSuite(t, input, tests) } func TestLexerWithNumbersWithWhitespacesIdentifier(t *testing.T) { @@ -187,21 +159,7 @@ func TestLexerWithNumbersWithWhitespacesIdentifier(t *testing.T) { {token.EOF, ""}, } - l := RunLexer(input) - - for i, tt := range tests { - tok := l.NextToken() - - if tok.Type != tt.expectedType { - t.Fatalf("tests[%d] - tokentype wrong. expected=%q, got=%q", - i, tt.expectedType, tok.Type) - } - - if tok.Literal != tt.expectedLiteral { - t.Fatalf("tests[%d] - literal wrong. expected=%q, got=%q", - i, tt.expectedLiteral, tok.Literal) - } - } + runLexerTestSuite(t, input, tests) } func TestLogicalStatements(t *testing.T) { @@ -231,21 +189,7 @@ func TestLogicalStatements(t *testing.T) { {token.EOF, ""}, } - l := RunLexer(input) - - for i, tt := range tests { - tok := l.NextToken() - - if tok.Type != tt.expectedType { - t.Fatalf("tests[%d] - tokentype wrong. expected=%q, got=%q", - i, tt.expectedType, tok.Type) - } - - if tok.Literal != tt.expectedLiteral { - t.Fatalf("tests[%d] - literal wrong. expected=%q, got=%q", - i, tt.expectedLiteral, tok.Literal) - } - } + runLexerTestSuite(t, input, tests) } func TestDeleteStatement(t *testing.T) { @@ -267,21 +211,7 @@ func TestDeleteStatement(t *testing.T) { {token.EOF, ""}, } - l := RunLexer(input) - - for i, tt := range tests { - tok := l.NextToken() - - if tok.Type != tt.expectedType { - t.Fatalf("tests[%d] - tokentype wrong. expected=%q, got=%q", - i, tt.expectedType, tok.Type) - } - - if tok.Literal != tt.expectedLiteral { - t.Fatalf("tests[%d] - literal wrong. expected=%q, got=%q", - i, tt.expectedLiteral, tok.Literal) - } - } + runLexerTestSuite(t, input, tests) } func TestOrderByStatement(t *testing.T) { @@ -302,6 +232,29 @@ func TestOrderByStatement(t *testing.T) { {token.EOF, ""}, } + runLexerTestSuite(t, input, tests) +} + +func TestDropStatement(t *testing.T) { + input := `DROP TABLE table;` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.DROP, "DROP"}, + {token.TABLE, "TABLE"}, + {token.IDENT, "table"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func runLexerTestSuite(t *testing.T, input string, tests []struct { + expectedType token.Type + expectedLiteral string +}) { l := RunLexer(input) for i, tt := range tests { @@ -317,5 +270,4 @@ func TestOrderByStatement(t *testing.T) { i, tt.expectedLiteral, tok.Literal) } } - } diff --git a/parser/parser.go b/parser/parser.go index d6d4433..280731e 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -240,7 +240,7 @@ func (parser *Parser) parseWhereCommand() ast.Command { // parseDeleteCommand - Return ast.DeleteCommand created from tokens and validate the syntax // // Example of input parsable to the ast.DeleteCommand: -// DELETE FROM table +// DELETE FROM table; func (parser *Parser) parseDeleteCommand() ast.Command { // token.DELETE already at current position in parser deleteCommand := &ast.DeleteCommand{Token: parser.currentToken} @@ -262,6 +262,30 @@ func (parser *Parser) parseDeleteCommand() ast.Command { return deleteCommand } +// parseDropCommand - Return ast.DropCommand created from tokens and validate the syntax +// +// Example of input parsable to the ast.DropCommand: +// DROP TABLE table; +func (parser *Parser) parseDropCommand() ast.Command { + // token.DROP already at current position in parser + dropCommand := &ast.DropCommand{Token: parser.currentToken} + + // token.DROP no longer needed + parser.nextToken() + + validateTokenAndSkip(parser, []token.Type{token.TABLE}) + + validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + dropCommand.Name = ast.Identifier{Token: parser.currentToken} + + // token.IDENT no longer needed + parser.nextToken() + + parser.skipIfCurrentTokenIsSemicolon() + + return dropCommand +} + // parseOrderByCommand - Return ast.OrderByCommand created from tokens and validate the syntax // // Example of input parsable to the ast.OrderByCommand: @@ -438,6 +462,8 @@ func (parser *Parser) ParseSequence() *ast.Sequence { command = parser.parseSelectCommand() case token.DELETE: command = parser.parseDeleteCommand() + case token.DROP: + command = parser.parseDropCommand() case token.WHERE: lastCommand := parser.getLastCommand(sequence) diff --git a/parser/parser_test.go b/parser/parser_test.go index 2ce2ad0..0b61c2c 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -231,6 +231,35 @@ func TestParseDeleteCommand(t *testing.T) { } } +func TestParseDropCommand(t *testing.T) { + input := "DROP TABLE table;" + expectedDropCommand := ast.DropCommand{ + Token: token.Token{Type: token.DROP, Literal: "DROP"}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "table"}}, + } + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences := parserInstance.ParseSequence() + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + actualDropCommand, ok := sequences.Commands[0].(*ast.DropCommand) + if !ok { + t.Errorf("actualDropCommand is not %T. got=%T", &ast.DropCommand{}, sequences.Commands[0]) + } + + if expectedDropCommand.TokenLiteral() != actualDropCommand.TokenLiteral() { + t.Errorf("TokenLiteral of DropCommand is not %s. got=%s", expectedDropCommand.TokenLiteral(), actualDropCommand.TokenLiteral()) + } + + if expectedDropCommand.Name.GetToken().Literal != actualDropCommand.Name.GetToken().Literal { + t.Errorf("Table name of DropCommand is not %s. got=%s", expectedDropCommand.Name.GetToken().Literal, actualDropCommand.Name.GetToken().Literal) + } +} + func TestSelectWithOrderByCommand(t *testing.T) { input := "SELECT * FROM tableName ORDER BY colName1 DESC;" expectedSortPattern := ast.SortPattern{ diff --git a/test_file b/test_file index e69a5c3..5f4c31d 100644 --- a/test_file +++ b/test_file @@ -9,3 +9,4 @@ DELETE FROM tbl WHERE one EQUAL 'byebye'; SELECT * FROM tbl; SELECT one FROM tbl WHERE TRUE ORDER BY two ASC, four DESC; + DROP TABLE tbl; diff --git a/token/token.go b/token/token.go index 69e5956..5643cab 100644 --- a/token/token.go +++ b/token/token.go @@ -30,6 +30,7 @@ const ( // CREATE - Keywords CREATE = "CREATE" + DROP = "DROP" TABLE = "TABLE" INSERT = "INSERT" INTO = "INTO" @@ -63,6 +64,7 @@ var keywords = map[string]Type{ "TEXT": TEXT, "INT": INT, "CREATE": CREATE, + "DROP": DROP, "TABLE": TABLE, "INSERT": INSERT, "INTO": INTO, From 78246a2723ae5f41e70be92fff0e14188433dade Mon Sep 17 00:00:00 2001 From: ixior462 Date: Tue, 11 Jun 2024 18:25:54 +0200 Subject: [PATCH 07/21] Add error handling for parser --- engine/engine_test.go | 7 +- modes/handler.go | 5 +- parser/parser.go | 298 +++++++++++++++++++++++++++++------------- parser/parser_test.go | 40 ++++-- 4 files changed, 251 insertions(+), 99 deletions(-) diff --git a/engine/engine_test.go b/engine/engine_test.go index fff7e6f..ade7f6b 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -1,6 +1,7 @@ package engine import ( + "log" "testing" "github.com/LissaGreense/GO4SQL/ast" @@ -404,7 +405,9 @@ func inputsToString(inputs []string) string { func getSequences(input string) *ast.Sequence { lexerInstance := lexer.RunLexer(input) parserInstance := parser.New(lexerInstance) - sequences := parserInstance.ParseSequence() - + sequences, err := parserInstance.ParseSequence() + if err != nil { + log.Fatal(err) + } return sequences } diff --git a/modes/handler.go b/modes/handler.go index 13a0710..40cb1a5 100644 --- a/modes/handler.go +++ b/modes/handler.go @@ -67,7 +67,10 @@ func HandleSocketMode(port int, engine *engine.DbEngine) { func bytesToSequences(content []byte) *ast.Sequence { lex := lexer.RunLexer(string(content)) parserInstance := parser.New(lex) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + log.Fatal(err) + } return sequences } diff --git a/parser/parser.go b/parser/parser.go index 280731e..801a46f 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -1,8 +1,7 @@ package parser import ( - "log" - + "errors" "github.com/LissaGreense/GO4SQL/ast" "github.com/LissaGreense/GO4SQL/lexer" "github.com/LissaGreense/GO4SQL/token" @@ -34,19 +33,23 @@ func (parser *Parser) nextToken() { } // validateTokenAndSkip - Check if current token type is appearing in provided expectedTokens array then move to the next token -func validateTokenAndSkip(parser *Parser, expectedTokens []token.Type) { - validateToken(parser.currentToken.Type, expectedTokens) +func validateTokenAndSkip(parser *Parser, expectedTokens []token.Type) error { + err := validateToken(parser.currentToken.Type, expectedTokens) + + if err != nil { + return err + } // Ignore validated token parser.nextToken() + return nil } // validateToken - Check if current token type is appearing in provided expectedTokens array -func validateToken(tokenType token.Type, expectedTokens []token.Type) { +func validateToken(tokenType token.Type, expectedTokens []token.Type) error { var contains = false var tokensPrintMessage = "" for i, x := range expectedTokens { - if i == 0 { tokensPrintMessage += string(x) } else { @@ -59,34 +62,49 @@ func validateToken(tokenType token.Type, expectedTokens []token.Type) { } } if !contains { - log.Fatal("Syntax error, expecting: ", tokensPrintMessage, ", got: ", tokenType) + return errors.New("Syntax error, expecting: " + tokensPrintMessage + ", got: " + string(tokenType)) } + return nil } // parseCreateCommand - Return ast.CreateCommand created from tokens and validate the syntax // // Example of input parsable to the ast.CreateCommand: // create table tbl( one TEXT , two INT ); -func (parser *Parser) parseCreateCommand() ast.Command { +func (parser *Parser) parseCreateCommand() (ast.Command, error) { // token.CREATE already at current position in parser createCommand := &ast.CreateCommand{Token: parser.currentToken} // Skip token.CREATE parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.TABLE}) + err := validateTokenAndSkip(parser, []token.Type{token.TABLE}) + if err != nil { + return nil, err + } + + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) createCommand.Name = ast.Identifier{Token: parser.currentToken} // Skip token.IDENT parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + err = validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + if err != nil { + return nil, err + } // Begin of inside Paren for parser.currentToken.Type == token.IDENT { - validateToken(parser.peekToken.Type, []token.Type{token.TEXT, token.INT}) + err = validateToken(parser.peekToken.Type, []token.Type{token.TEXT, token.INT}) + if err != nil { + return nil, err + } + createCommand.ColumnNames = append(createCommand.ColumnNames, parser.currentToken.Literal) createCommand.ColumnTypes = append(createCommand.ColumnTypes, parser.peekToken) @@ -104,10 +122,16 @@ func (parser *Parser) parseCreateCommand() ast.Command { } // End of inside Paren - validateTokenAndSkip(parser, []token.Type{token.RPAREN}) - validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) + err = validateTokenAndSkip(parser, []token.Type{token.RPAREN}) + if err != nil { + return nil, err + } + err = validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) + if err != nil { + return nil, err + } - return createCommand + return createCommand, nil } func (parser *Parser) skipIfCurrentTokenIsApostrophe() { @@ -126,27 +150,42 @@ func (parser *Parser) skipIfCurrentTokenIsSemicolon() { // // Example of input parsable to the ast.InsertCommand: // insert into tbl values( 'hello', 10 ); -func (parser *Parser) parseInsertCommand() ast.Command { +func (parser *Parser) parseInsertCommand() (ast.Command, error) { // token.INSERT already at current position in parser insertCommand := &ast.InsertCommand{Token: parser.currentToken} // Ignore token.INSERT parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.INTO}) + err := validateTokenAndSkip(parser, []token.Type{token.INTO}) + if err != nil { + return nil, err + } - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } insertCommand.Name = ast.Identifier{Token: parser.currentToken} // Ignore token.INDENT parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.VALUES}) - validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + err = validateTokenAndSkip(parser, []token.Type{token.VALUES}) + if err != nil { + return nil, err + } + err = validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + if err != nil { + return nil, err + } for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.APOSTROPHE { parser.skipIfCurrentTokenIsApostrophe() - validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL}) + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL}) + if err != nil { + return nil, err + } insertCommand.Values = append(insertCommand.Values, parser.currentToken) // Ignore token.IDENT or token.LITERAL parser.nextToken() @@ -161,16 +200,22 @@ func (parser *Parser) parseInsertCommand() ast.Command { parser.nextToken() } - validateTokenAndSkip(parser, []token.Type{token.RPAREN}) - validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) - return insertCommand + err = validateTokenAndSkip(parser, []token.Type{token.RPAREN}) + if err != nil { + return nil, err + } + err = validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) + if err != nil { + return nil, err + } + return insertCommand, nil } // parseSelectCommand - Return ast.SelectCommand created from tokens and validate the syntax // // Example of input parsable to the ast.SelectCommand: // SELECT col1, col2, col3 FROM tbl; -func (parser *Parser) parseSelectCommand() ast.Command { +func (parser *Parser) parseSelectCommand() (ast.Command, error) { // token.SELECT already at current position in parser selectCommand := &ast.SelectCommand{Token: parser.currentToken} @@ -184,7 +229,10 @@ func (parser *Parser) parseSelectCommand() ast.Command { } else { for parser.currentToken.Type == token.IDENT { // Get column name - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + err := validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } selectCommand.Space = append(selectCommand.Space, parser.currentToken) parser.nextToken() @@ -196,86 +244,110 @@ func (parser *Parser) parseSelectCommand() ast.Command { } } - validateTokenAndSkip(parser, []token.Type{token.FROM}) + err := validateTokenAndSkip(parser, []token.Type{token.FROM}) + if err != nil { + return nil, err + } selectCommand.Name = ast.Identifier{Token: parser.currentToken} // Ignore token.INDENT parser.nextToken() // expect SEMICOLON or WHERE - validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.WHERE, token.ORDER}) + err = validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.WHERE, token.ORDER}) + if err != nil { + return nil, err + } if parser.currentToken.Type == token.SEMICOLON { parser.nextToken() } - return selectCommand + return selectCommand, nil } // parseWhereCommand - Return ast.WhereCommand created from tokens and validate the syntax // // Example of input parsable to the ast.WhereCommand: // WHERE colName EQUAL 'potato' -func (parser *Parser) parseWhereCommand() ast.Command { +func (parser *Parser) parseWhereCommand() (ast.Command, error) { // token.WHERE already at current position in parser whereCommand := &ast.WhereCommand{Token: parser.currentToken} expressionIsValid := false // Ignore token.WHERE parser.nextToken() - - expressionIsValid, whereCommand.Expression = parser.getExpression() + var err error + expressionIsValid, whereCommand.Expression, err = parser.getExpression() + if err != nil { + return nil, err + } if !expressionIsValid { - log.Fatal("Expression withing Where statement couldn't be parsed correctly") + return nil, errors.New("Expression withing Where statement couldn't be parsed correctly") } - validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.ORDER}) + err = validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.ORDER}) + if err != nil { + return nil, err + } parser.skipIfCurrentTokenIsSemicolon() - return whereCommand + return whereCommand, nil } // parseDeleteCommand - Return ast.DeleteCommand created from tokens and validate the syntax // // Example of input parsable to the ast.DeleteCommand: // DELETE FROM table; -func (parser *Parser) parseDeleteCommand() ast.Command { +func (parser *Parser) parseDeleteCommand() (ast.Command, error) { // token.DELETE already at current position in parser deleteCommand := &ast.DeleteCommand{Token: parser.currentToken} // token.DELETE no longer needed parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.FROM}) + err := validateTokenAndSkip(parser, []token.Type{token.FROM}) + if err != nil { + return nil, err + } - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } deleteCommand.Name = ast.Identifier{Token: parser.currentToken} // token.IDENT no longer needed parser.nextToken() // expect WHERE - validateToken(parser.currentToken.Type, []token.Type{token.WHERE}) + err = validateToken(parser.currentToken.Type, []token.Type{token.WHERE}) - return deleteCommand + return deleteCommand, err } // parseDropCommand - Return ast.DropCommand created from tokens and validate the syntax // // Example of input parsable to the ast.DropCommand: // DROP TABLE table; -func (parser *Parser) parseDropCommand() ast.Command { +func (parser *Parser) parseDropCommand() (ast.Command, error) { // token.DROP already at current position in parser dropCommand := &ast.DropCommand{Token: parser.currentToken} // token.DROP no longer needed parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.TABLE}) + err := validateTokenAndSkip(parser, []token.Type{token.TABLE}) + if err != nil { + return nil, err + } - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } dropCommand.Name = ast.Identifier{Token: parser.currentToken} // token.IDENT no longer needed @@ -283,34 +355,43 @@ func (parser *Parser) parseDropCommand() ast.Command { parser.skipIfCurrentTokenIsSemicolon() - return dropCommand + return dropCommand, nil } // parseOrderByCommand - Return ast.OrderByCommand created from tokens and validate the syntax // // Example of input parsable to the ast.OrderByCommand: // ORDER BY colName ASC -func (parser *Parser) parseOrderByCommand() ast.Command { +func (parser *Parser) parseOrderByCommand() (ast.Command, error) { // token.ORDER already at current position in parser orderCommand := &ast.OrderByCommand{Token: parser.currentToken} // token.ORDER no longer needed parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.BY}) + err := validateTokenAndSkip(parser, []token.Type{token.BY}) + if err != nil { + return nil, err + } // ensure that loop below will execute at least once - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) // array of SortPattern for parser.currentToken.Type == token.IDENT { // Get column name - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } columnName := parser.currentToken parser.nextToken() // Get ASC or DESC - validateToken(parser.currentToken.Type, []token.Type{token.ASC, token.DESC}) + err = validateToken(parser.currentToken.Type, []token.Type{token.ASC, token.DESC}) + if err != nil { + return nil, err + } order := parser.currentToken parser.nextToken() @@ -324,9 +405,9 @@ func (parser *Parser) parseOrderByCommand() ast.Command { parser.nextToken() } - validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) + err = validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) - return orderCommand + return orderCommand, err } // getExpression - Return proper structure of ast.Expression and validate the syntax @@ -335,30 +416,36 @@ func (parser *Parser) parseOrderByCommand() ast.Command { // - ast.OperationExpression // - ast.BooleanExpression // - ast.ConditionExpression -func (parser *Parser) getExpression() (bool, ast.Expression) { +func (parser *Parser) getExpression() (bool, ast.Expression, error) { booleanExpressionExists, booleanExpression := parser.getBooleanExpression() - conditionalExpressionExists, conditionalExpression := parser.getConditionalExpression() + conditionalExpressionExists, conditionalExpression, err := parser.getConditionalExpression() + if err != nil { + return false, nil, err + } - operationExpressionExists, operationExpression := parser.getOperationExpression(booleanExpressionExists, conditionalExpressionExists, booleanExpression, conditionalExpression) + operationExpressionExists, operationExpression, err := parser.getOperationExpression(booleanExpressionExists, conditionalExpressionExists, booleanExpression, conditionalExpression) + if err != nil { + return false, nil, err + } if operationExpressionExists { - return true, operationExpression + return true, operationExpression, err } if conditionalExpressionExists { - return true, conditionalExpression + return true, conditionalExpression, err } if booleanExpressionExists { - return true, booleanExpression + return true, booleanExpression, err } - return false, nil + return false, nil, err } // getOperationExpression - Return ast.OperationExpression created from tokens and validate the syntax -func (parser *Parser) getOperationExpression(booleanExpressionExists bool, conditionalExpressionExists bool, booleanExpression *ast.BooleanExpression, conditionalExpression *ast.ConditionExpression) (bool, *ast.OperationExpression) { +func (parser *Parser) getOperationExpression(booleanExpressionExists bool, conditionalExpressionExists bool, booleanExpression *ast.BooleanExpression, conditionalExpression *ast.ConditionExpression) (bool, *ast.OperationExpression, error) { operationExpression := &ast.OperationExpression{} if (booleanExpressionExists || conditionalExpressionExists) && (parser.currentToken.Type == token.OR || parser.currentToken.Type == token.AND) { @@ -373,18 +460,21 @@ func (parser *Parser) getOperationExpression(booleanExpressionExists bool, condi operationExpression.Operation = parser.currentToken parser.nextToken() - expressionIsValid, expression := parser.getExpression() + expressionIsValid, expression, err := parser.getExpression() + if err != nil { + return false, nil, err + } if !expressionIsValid { - log.Fatal("Couldn't parse right side of the OperationExpression after ", operationExpression.Operation.Literal, " token.") + return false, nil, errors.New("Couldn't parse right side of the OperationExpression after " + operationExpression.Operation.Literal + " token.") } operationExpression.Right = expression - return true, operationExpression + return true, operationExpression, nil } - return false, operationExpression + return false, operationExpression, nil } // getBooleanExpression - Return ast.BooleanExpression created from tokens and validate the syntax @@ -402,7 +492,7 @@ func (parser *Parser) getBooleanExpression() (bool, *ast.BooleanExpression) { } // getConditionalExpression - Return ast.ConditionExpression created from tokens and validate the syntax -func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression) { +func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression, error) { conditionalExpression := &ast.ConditionExpression{} switch parser.currentToken.Type { @@ -413,15 +503,21 @@ func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression parser.skipIfCurrentTokenIsApostrophe() conditionalExpression.Left = ast.Anonymitifier{Token: parser.currentToken} parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.APOSTROPHE}) + err := validateTokenAndSkip(parser, []token.Type{token.APOSTROPHE}) + if err != nil { + return false, nil, err + } case token.LITERAL: conditionalExpression.Left = ast.Anonymitifier{Token: parser.currentToken} parser.nextToken() default: - return false, conditionalExpression + return false, conditionalExpression, nil } - validateToken(parser.currentToken.Type, []token.Type{token.EQUAL, token.NOT}) + err := validateToken(parser.currentToken.Type, []token.Type{token.EQUAL, token.NOT}) + if err != nil { + return false, nil, err + } conditionalExpression.Condition = parser.currentToken parser.nextToken() @@ -433,58 +529,84 @@ func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression parser.skipIfCurrentTokenIsApostrophe() conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.APOSTROPHE}) + err := validateTokenAndSkip(parser, []token.Type{token.APOSTROPHE}) + if err != nil { + return false, nil, err + } case token.LITERAL: conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} parser.nextToken() default: - log.Fatal("Syntax error, expecting: ", token.APOSTROPHE, ",", token.IDENT, ",", token.LITERAL, ", got: ", parser.currentToken.Literal) + return false, nil, errors.New("Syntax error, expecting: " + token.APOSTROPHE + "," + token.IDENT + "," + token.LITERAL + ", got: " + parser.currentToken.Literal) } - return true, conditionalExpression + return true, conditionalExpression, nil } // ParseSequence - Return ast.Sequence (sequence of commands) created from client input after tokenization // // Parse tokens returned by lexer to structures defines in ast package, and it's responsible for syntax validation. -func (parser *Parser) ParseSequence() *ast.Sequence { +func (parser *Parser) ParseSequence() (*ast.Sequence, error) { // Create variable holding sequence/commands sequence := &ast.Sequence{} for parser.currentToken.Type != token.EOF { var command ast.Command + var err error switch parser.currentToken.Type { case token.CREATE: - command = parser.parseCreateCommand() + command, err = parser.parseCreateCommand() case token.INSERT: - command = parser.parseInsertCommand() + command, err = parser.parseInsertCommand() case token.SELECT: - command = parser.parseSelectCommand() + command, err = parser.parseSelectCommand() case token.DELETE: - command = parser.parseDeleteCommand() + command, err = parser.parseDeleteCommand() case token.DROP: - command = parser.parseDropCommand() + command, err = parser.parseDropCommand() case token.WHERE: - lastCommand := parser.getLastCommand(sequence) + lastCommand, parserError := parser.getLastCommand(sequence) + if parserError != nil { + return nil, parserError + } if lastCommand.TokenLiteral() == token.SELECT { - lastCommand.(*ast.SelectCommand).WhereCommand = parser.parseWhereCommand().(*ast.WhereCommand) + newCommand, err := parser.parseWhereCommand() + if err != nil { + return nil, err + } + lastCommand.(*ast.SelectCommand).WhereCommand = newCommand.(*ast.WhereCommand) } else if lastCommand.TokenLiteral() == token.DELETE { - lastCommand.(*ast.DeleteCommand).WhereCommand = parser.parseWhereCommand().(*ast.WhereCommand) + newCommand, err := parser.parseWhereCommand() + if err != nil { + return nil, err + } + lastCommand.(*ast.DeleteCommand).WhereCommand = newCommand.(*ast.WhereCommand) } else { - log.Fatal("Syntax error, WHERE command needs SELECT or DELETE command before") + return nil, errors.New("Syntax error, WHERE command needs SELECT or DELETE command before") } case token.ORDER: - lastCommand := parser.getLastCommand(sequence) + lastCommand, parserError := parser.getLastCommand(sequence) + if parserError != nil { + return nil, parserError + } if lastCommand.TokenLiteral() != token.SELECT { - log.Fatal("Syntax error, ORDER BY command needs SELECT command before") + return nil, errors.New("Syntax error, ORDER BY command needs SELECT command before") } selectCommand := lastCommand.(*ast.SelectCommand) - selectCommand.OrderByCommand = parser.parseOrderByCommand().(*ast.OrderByCommand) + newCommand, err := parser.parseOrderByCommand() + if err != nil { + return nil, err + } + selectCommand.OrderByCommand = newCommand.(*ast.OrderByCommand) default: - log.Fatal("Syntax error, invalid command found: ", parser.currentToken.Type) + return nil, errors.New("Syntax error, invalid command found: " + parser.currentToken.Literal) + } + + if err != nil { + return nil, err } // Add command to the list of parsed commands @@ -493,13 +615,13 @@ func (parser *Parser) ParseSequence() *ast.Sequence { } } - return sequence + return sequence, nil } -func (parser *Parser) getLastCommand(sequence *ast.Sequence) ast.Command { +func (parser *Parser) getLastCommand(sequence *ast.Sequence) (ast.Command, error) { if len(sequence.Commands) == 0 { - log.Fatal("Syntax error, Where Command can't be used without predecessor") + return nil, errors.New("Syntax error, Where Command can't be used without predecessor") } lastCommand := sequence.Commands[len(sequence.Commands)-1] - return lastCommand + return lastCommand, nil } diff --git a/parser/parser_test.go b/parser/parser_test.go index 0b61c2c..a59ce72 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -23,7 +23,10 @@ func TestParserCreateCommand(t *testing.T) { for _, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } if len(sequences.Commands) != 1 { t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) @@ -79,7 +82,10 @@ func TestParseInsertCommand(t *testing.T) { for _, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } if len(sequences.Commands) != 1 { t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) @@ -130,7 +136,10 @@ func TestParseSelectCommand(t *testing.T) { for _, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } if len(sequences.Commands) != 1 { t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) @@ -172,7 +181,10 @@ func TestParseWhereCommand(t *testing.T) { for _, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } if len(sequences.Commands) != 1 { t.Fatalf("sequences does not contain 1 statements, got=%d", len(sequences.Commands)) @@ -203,7 +215,10 @@ func TestParseDeleteCommand(t *testing.T) { lexer := lexer.RunLexer(input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } if len(sequences.Commands) != 1 { t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) @@ -240,7 +255,10 @@ func TestParseDropCommand(t *testing.T) { lexer := lexer.RunLexer(input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } if len(sequences.Commands) != 1 { t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) @@ -275,7 +293,10 @@ func TestSelectWithOrderByCommand(t *testing.T) { lexer := lexer.RunLexer(input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } if len(sequences.Commands) != 1 { t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) @@ -345,7 +366,10 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { for _, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } if len(sequences.Commands) != 1 { t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) From 0e8cd0630c7c1588037ee5f53eeecc1783227cfc Mon Sep 17 00:00:00 2001 From: LissaGreense Date: Tue, 11 Jun 2024 21:32:18 +0200 Subject: [PATCH 08/21] second part of error handling --- engine/column.go | 9 ++- engine/engine.go | 121 ++++++++++++++++++++++++++--------------- engine/engine_test.go | 10 +++- engine/engine_utils.go | 9 ++- main.go | 12 +++- modes/handler.go | 66 +++++++++++++--------- parser/parser.go | 18 +++--- 7 files changed, 151 insertions(+), 94 deletions(-) diff --git a/engine/column.go b/engine/column.go index c8633d4..0b26de7 100644 --- a/engine/column.go +++ b/engine/column.go @@ -1,8 +1,7 @@ package engine import ( - "log" - + "fmt" "github.com/LissaGreense/GO4SQL/token" ) @@ -13,7 +12,7 @@ type Column struct { Values []ValueInterface } -func extractColumnContent(columns []*Column, wantedColumnNames *[]string) *Table { +func extractColumnContent(columns []*Column, wantedColumnNames *[]string) (*Table, error) { selectedTable := &Table{Columns: make([]*Column, 0)} mappedIndexes := make([]int, 0) for wantedColumnIndex := range *wantedColumnNames { @@ -23,7 +22,7 @@ func extractColumnContent(columns []*Column, wantedColumnNames *[]string) *Table break } if columnNameIndex == len(columns)-1 { - log.Fatal("Provided column name: " + (*wantedColumnNames)[wantedColumnIndex] + " doesn't exist") + return nil, fmt.Errorf("provided column name: %s doesn't exist", (*wantedColumnNames)[wantedColumnIndex]) } } } @@ -43,5 +42,5 @@ func extractColumnContent(columns []*Column, wantedColumnNames *[]string) *Table append(selectedTable.Columns[iColumn].Values, columns[mappedIndexes[iColumn]].Values[iRow]) } } - return selectedTable + return selectedTable, nil } diff --git a/engine/engine.go b/engine/engine.go index 7271d7c..4732cc2 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -1,11 +1,8 @@ package engine import ( - "errors" "fmt" - "log" "sort" - "strconv" "github.com/LissaGreense/GO4SQL/ast" "github.com/LissaGreense/GO4SQL/token" @@ -25,7 +22,7 @@ func New() *DbEngine { } // Evaluate - it takes sequences, map them to specific implementation and then process it in SQL engine -func (engine *DbEngine) Evaluate(sequences *ast.Sequence) string { +func (engine *DbEngine) Evaluate(sequences *ast.Sequence) (string, error) { commands := sequences.Commands result := "" @@ -37,20 +34,33 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) string { case *ast.OrderByCommand: continue case *ast.CreateCommand: - engine.createTable(mappedCommand) + err := engine.createTable(mappedCommand) + if err != nil { + return "", err + } result += "Table '" + mappedCommand.Name.GetToken().Literal + "' has been created\n" continue case *ast.InsertCommand: - engine.insertIntoTable(mappedCommand) + err := engine.insertIntoTable(mappedCommand) + if err != nil { + return "", err + } result += "Data Inserted\n" continue case *ast.SelectCommand: - result += engine.getSelectResponse(mappedCommand).ToString() + "\n" + selectOutput, err := engine.getSelectResponse(mappedCommand) + if err != nil { + return "", err + } + result += selectOutput.ToString() + "\n" continue case *ast.DeleteCommand: deleteCommand := command.(*ast.DeleteCommand) if deleteCommand.HasWhereCommand() { - engine.deleteFromTable(mappedCommand, deleteCommand.WhereCommand) + err := engine.deleteFromTable(mappedCommand, deleteCommand.WhereCommand) + if err != nil { + return "", err + } } result += "Data from '" + mappedCommand.Name.GetToken().Literal + "' has been deleted\n" continue @@ -59,15 +69,15 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) string { result += "Table: '" + mappedCommand.Name.GetToken().Literal + "' has been dropped\n" continue default: - log.Fatalf("Unsupported Command detected: %v", command) + return "", fmt.Errorf("unsupported Command detected: %v", command) } } - return result + return result, nil } // getSelectResponse - Returns Select response basing on ast.OrderByCommand and ast.WhereCommand included in this Select -func (engine *DbEngine) getSelectResponse(selectCommand *ast.SelectCommand) *Table { +func (engine *DbEngine) getSelectResponse(selectCommand *ast.SelectCommand) (*Table, error) { if selectCommand.HasWhereCommand() { whereCommand := selectCommand.WhereCommand if selectCommand.HasOrderByCommand() { @@ -84,11 +94,11 @@ func (engine *DbEngine) getSelectResponse(selectCommand *ast.SelectCommand) *Tab } // createTable - initialize new table in engine with specified name -func (engine *DbEngine) createTable(command *ast.CreateCommand) { +func (engine *DbEngine) createTable(command *ast.CreateCommand) error { _, exist := engine.Tables[command.Name.Token.Literal] if exist { - log.Fatal("Table with the name of " + command.Name.Token.Literal + " already exist!") + return fmt.Errorf("table with the name of %s already exist", command.Name.Token.Literal) } engine.Tables[command.Name.Token.Literal] = &Table{Columns: []*Column{}} @@ -100,42 +110,48 @@ func (engine *DbEngine) createTable(command *ast.CreateCommand) { Name: columnName, }) } + return nil } // insertIntoTable - Insert row of values into the table -func (engine *DbEngine) insertIntoTable(command *ast.InsertCommand) { +func (engine *DbEngine) insertIntoTable(command *ast.InsertCommand) error { table, exist := engine.Tables[command.Name.Token.Literal] if !exist { - log.Fatal("Table with the name of " + command.Name.Token.Literal + " doesn't exist!") + return fmt.Errorf("table with the name of %s doesn't exist", command.Name.Token.Literal) } columns := table.Columns if len(command.Values) != len(columns) { - log.Fatal("Invalid number of parameters in insert, should be: " + strconv.Itoa(len(columns)) + ", but got: " + strconv.Itoa(len(columns))) + return fmt.Errorf("invalid number of parameters in insert, should be: %d, but got: %d", len(columns), len(columns)) } for i := range columns { expectedToken := tokenMapper(columns[i].Type.Type) if expectedToken != command.Values[i].Type { - log.Fatal("Invalid Token TokenType in Insert Command, expecting: " + expectedToken + ", got: " + command.Values[i].Type) + return fmt.Errorf("invalid Token TokenType in Insert Command, expecting: %s, got: %s", expectedToken, command.Values[i].Type) } - columns[i].Values = append(columns[i].Values, getInterfaceValue(command.Values[i])) + interfaceValue, err := getInterfaceValue(command.Values[i]) + if err != nil { + return err + } + columns[i].Values = append(columns[i].Values, interfaceValue) } + return nil } // selectFromTable - Return Table containing all values requested by SelectCommand -func (engine *DbEngine) selectFromTable(command *ast.SelectCommand) *Table { +func (engine *DbEngine) selectFromTable(command *ast.SelectCommand) (*Table, error) { table, exist := engine.Tables[command.Name.Token.Literal] if !exist { - log.Fatal("Table with the name of " + command.Name.Token.Literal + " doesn't exist!") + return nil, fmt.Errorf("table with the name of %s doesn't exist", command.Name.Token.Literal) } return engine.selectFromProvidedTable(command, table) } -func (engine *DbEngine) selectFromProvidedTable(command *ast.SelectCommand, table *Table) *Table { +func (engine *DbEngine) selectFromProvidedTable(command *ast.SelectCommand, table *Table) (*Table, error) { columns := table.Columns wantedColumnNames := make([]string, 0) @@ -153,14 +169,21 @@ func (engine *DbEngine) selectFromProvidedTable(command *ast.SelectCommand, tabl } // deleteFromTable - Delete all rows of data from table that match given condition -func (engine *DbEngine) deleteFromTable(deleteCommand *ast.DeleteCommand, whereCommand *ast.WhereCommand) { +func (engine *DbEngine) deleteFromTable(deleteCommand *ast.DeleteCommand, whereCommand *ast.WhereCommand) error { table, exist := engine.Tables[deleteCommand.Name.Token.Literal] if !exist { - log.Fatal("Table with the name of " + deleteCommand.Name.Token.Literal + " doesn't exist!") + return fmt.Errorf("table with the name of %s doesn't exist", deleteCommand.Name.Token.Literal) } - engine.Tables[deleteCommand.Name.Token.Literal] = engine.getFilteredTable(table, whereCommand, true) + newTable, err := engine.getFilteredTable(table, whereCommand, true) + + if err != nil { + return err + } + engine.Tables[deleteCommand.Name.Token.Literal] = newTable + + return nil } // dropTable - Drop table with given name @@ -169,32 +192,40 @@ func (engine *DbEngine) dropTable(dropCommand *ast.DropCommand) { } // selectFromTableWithWhere - Return Table containing all values requested by SelectCommand and filtered by WhereCommand -func (engine *DbEngine) selectFromTableWithWhere(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand) *Table { +func (engine *DbEngine) selectFromTableWithWhere(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand) (*Table, error) { table, exist := engine.Tables[selectCommand.Name.Token.Literal] if !exist { - log.Fatal("Table with the name of " + selectCommand.Name.Token.Literal + " doesn't exist!") + return nil, fmt.Errorf("table with the name of %s doesn't exist", selectCommand.Name.Token.Literal) } if len(table.Columns) == 0 || len(table.Columns[0].Values) == 0 { return engine.selectFromProvidedTable(selectCommand, &Table{Columns: []*Column{}}) } - filteredTable := engine.getFilteredTable(table, whereCommand, false) + filteredTable, err := engine.getFilteredTable(table, whereCommand, false) + + if err != nil { + return nil, err + } return engine.selectFromProvidedTable(selectCommand, filteredTable) } // selectFromTableWithWhereAndOrderBy - Return Table containing all values requested by SelectCommand, // filtered by WhereCommand and sorted by OrderByCommand -func (engine *DbEngine) selectFromTableWithWhereAndOrderBy(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand, orderByCommand *ast.OrderByCommand) *Table { +func (engine *DbEngine) selectFromTableWithWhereAndOrderBy(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand, orderByCommand *ast.OrderByCommand) (*Table, error) { table, exist := engine.Tables[selectCommand.Name.Token.Literal] if !exist { - log.Fatal("Table with the name of " + selectCommand.Name.Token.Literal + " doesn't exist!") + return nil, fmt.Errorf("table with the name of %s doesn't exist", selectCommand.Name.Token.Literal) } - filteredTable := engine.getFilteredTable(table, whereCommand, false) + filteredTable, err := engine.getFilteredTable(table, whereCommand, false) + + if err != nil { + return nil, err + } emptyTable := getCopyOfTableWithoutRows(table) @@ -202,11 +233,11 @@ func (engine *DbEngine) selectFromTableWithWhereAndOrderBy(selectCommand *ast.Se } // selectFromTableWithOrderBy - Return Table containing all values requested by SelectCommand and sorted by OrderByCommand -func (engine *DbEngine) selectFromTableWithOrderBy(selectCommand *ast.SelectCommand, orderByCommand *ast.OrderByCommand) *Table { +func (engine *DbEngine) selectFromTableWithOrderBy(selectCommand *ast.SelectCommand, orderByCommand *ast.OrderByCommand) (*Table, error) { table, exist := engine.Tables[selectCommand.Name.Token.Literal] if !exist { - log.Fatal("Table with the name of " + selectCommand.Name.Token.Literal + " doesn't exist!") + return nil, fmt.Errorf("table with the name of %s doesn't exist", selectCommand.Name.Token.Literal) } emptyTable := getCopyOfTableWithoutRows(table) @@ -252,13 +283,13 @@ func (engine *DbEngine) getSortedTable(orderByCommand *ast.OrderByCommand, filte return copyOfTable } -func (engine *DbEngine) getFilteredTable(table *Table, whereCommand *ast.WhereCommand, negation bool) *Table { +func (engine *DbEngine) getFilteredTable(table *Table, whereCommand *ast.WhereCommand, negation bool) (*Table, error) { filteredTable := getCopyOfTableWithoutRows(table) for _, row := range MapTableToRows(table).rows { fulfilledFilters, err := isFulfillingFilters(row, whereCommand.Expression) if err != nil { - log.Fatal(err.Error()) + return nil, err } if xor(fulfilledFilters, negation) { @@ -268,7 +299,7 @@ func (engine *DbEngine) getFilteredTable(table *Table, whereCommand *ast.WhereCo } } } - return filteredTable + return filteredTable, nil } func xor(fulfilledFilters bool, negation bool) bool { @@ -303,14 +334,14 @@ func isFulfillingFilters(row map[string]ValueInterface, expressionTree ast.Expre } func processConditionExpression(row map[string]ValueInterface, conditionExpression *ast.ConditionExpression) (bool, error) { - valueLeft, isValueLeftValid := getTifierValue(conditionExpression.Left, row) - if isValueLeftValid != nil { - log.Fatal(isValueLeftValid.Error()) + valueLeft, err := getTifierValue(conditionExpression.Left, row) + if err != nil { + return false, err } - valueRight, isValueRightValid := getTifierValue(conditionExpression.Right, row) - if isValueLeftValid != nil { - log.Fatal(isValueRightValid.Error()) + valueRight, err := getTifierValue(conditionExpression.Right, row) + if err != nil { + return false, err } switch conditionExpression.Condition.Type { @@ -319,7 +350,7 @@ func processConditionExpression(row map[string]ValueInterface, conditionExpressi case token.NOT: return !(valueLeft.IsEqual(valueRight)), nil default: - return false, errors.New("Operation '" + conditionExpression.Condition.Literal + "' provided in WHERE command isn't allowed!") + return false, fmt.Errorf("operation '%s' provided in WHERE command isn't allowed", conditionExpression.Condition.Literal) } } @@ -344,7 +375,7 @@ func processOperationExpression(row map[string]ValueInterface, operationExpressi return left || right, err } - return false, errors.New("unsupported operation token has been used: " + operationExpression.Operation.Literal) + return false, fmt.Errorf("unsupported operation token has been used: %s", operationExpression.Operation.Literal) } func processBooleanExpression(booleanExpression *ast.BooleanExpression) (bool, error) { @@ -359,8 +390,8 @@ func getTifierValue(tifier ast.Tifier, row map[string]ValueInterface) (ValueInte case ast.Identifier: return row[mappedTifier.GetToken().Literal], nil case ast.Anonymitifier: - return getInterfaceValue(mappedTifier.GetToken()), nil + return getInterfaceValue(mappedTifier.GetToken()) default: - return nil, errors.New("Couldn't map interface to any implementation of it: " + tifier.GetToken().Literal) + return nil, fmt.Errorf("couldn't map interface to any implementation of it: %s", tifier.GetToken().Literal) } } diff --git a/engine/engine_test.go b/engine/engine_test.go index ade7f6b..082a74b 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -365,8 +365,14 @@ func (engineTestSuite *engineTableContentTestSuite) runTestSuite(t *testing.T) { } engine := New() - engine.Evaluate(sequencesWithoutSelect) - actualTable := engine.getSelectResponse(selectCommand.Commands[0].(*ast.SelectCommand)) + _, err := engine.Evaluate(sequencesWithoutSelect) + if err != nil { + log.Fatal(err) + } + actualTable, err := engine.getSelectResponse(selectCommand.Commands[0].(*ast.SelectCommand)) + if err != nil { + log.Fatal(err) + } if len(engineTestSuite.expectedOutput) == 0 { if len(actualTable.Columns[0].Values) != 0 { diff --git a/engine/engine_utils.go b/engine/engine_utils.go index 5e27542..3486447 100644 --- a/engine/engine_utils.go +++ b/engine/engine_utils.go @@ -1,22 +1,21 @@ package engine import ( - "log" "strconv" "github.com/LissaGreense/GO4SQL/token" ) -func getInterfaceValue(t token.Token) ValueInterface { +func getInterfaceValue(t token.Token) (ValueInterface, error) { switch t.Type { case token.LITERAL: castedInteger, err := strconv.Atoi(t.Literal) if err != nil { - log.Fatal("Cannot cast \"" + t.Literal + "\" to Integer") + return nil, err } - return IntegerValue{Value: castedInteger} + return IntegerValue{Value: castedInteger}, nil default: - return StringValue{Value: t.Literal} + return StringValue{Value: t.Literal}, nil } } diff --git a/main.go b/main.go index ff127e0..93a59dc 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "flag" + "fmt" "github.com/LissaGreense/GO4SQL/engine" "github.com/LissaGreense/GO4SQL/modes" "log" @@ -15,14 +16,19 @@ func main() { flag.Parse() engineSQL := engine.New() + var err error if len(*filePath) > 0 { - modes.HandleFileMode(*filePath, engineSQL) + err = modes.HandleFileMode(*filePath, engineSQL) } else if *streamMode { - modes.HandleStreamMode(engineSQL) + err = modes.HandleStreamMode(engineSQL) } else if *socketMode { modes.HandleSocketMode(*port, engineSQL) } else { - log.Println("No mode has been providing. Exiting.") + err = fmt.Errorf("no mode has been providing, exiting") + } + + if err != nil { + log.Fatal(err) } } diff --git a/modes/handler.go b/modes/handler.go index 40cb1a5..d787202 100644 --- a/modes/handler.go +++ b/modes/handler.go @@ -14,27 +14,40 @@ import ( ) // HandleFileMode - Handle GO4SQL use case where client sends input via text file -func HandleFileMode(filePath string, engine *engine.DbEngine) { +func HandleFileMode(filePath string, engine *engine.DbEngine) error { content, err := os.ReadFile(filePath) if err != nil { - log.Fatal(err) + return err } - - sequences := bytesToSequences(content) - fmt.Print(engine.Evaluate(sequences)) + sequences, err := bytesToSequences(content) + if err != nil { + return err + } + evaluate, err := engine.Evaluate(sequences) + if err != nil { + return err + } + fmt.Print(evaluate) + return nil } // HandleStreamMode - Handle GO4SQL use case where client sends input via stdin -func HandleStreamMode(engine *engine.DbEngine) { +func HandleStreamMode(engine *engine.DbEngine) error { reader := bufio.NewScanner(os.Stdin) for reader.Scan() { - sequences := bytesToSequences(reader.Bytes()) - fmt.Print(engine.Evaluate(sequences)) - } - err := reader.Err() - if err != nil { - log.Fatal(err) + sequences, err := bytesToSequences(reader.Bytes()) + if err != nil { + fmt.Print(err) + } else { + evaluate, err := engine.Evaluate(sequences) + if err != nil { + fmt.Print(err) + } else { + fmt.Print(evaluate) + } + } } + return reader.Err() } // HandleSocketMode - Handle GO4SQL use case where client sends input via socket protocol @@ -43,7 +56,7 @@ func HandleSocketMode(port int, engine *engine.DbEngine) { log.Printf("Starting Socket Server on %d port\n", port) if err != nil { - log.Fatal("Error:", err) + log.Fatal(err.Error()) } defer func(listener net.Listener) { @@ -64,15 +77,11 @@ func HandleSocketMode(port int, engine *engine.DbEngine) { } } -func bytesToSequences(content []byte) *ast.Sequence { +func bytesToSequences(content []byte) (*ast.Sequence, error) { lex := lexer.RunLexer(string(content)) parserInstance := parser.New(lex) sequences, err := parserInstance.ParseSequence() - if err != nil { - log.Fatal(err) - } - - return sequences + return sequences, err } func handleSocketClient(conn net.Conn, engine *engine.DbEngine) { @@ -88,19 +97,26 @@ func handleSocketClient(conn net.Conn, engine *engine.DbEngine) { for { n, err := conn.Read(buffer) if err != nil && err.Error() != "EOF" { - log.Fatal("Error:", err) + log.Fatal(err.Error()) } - sequences := bytesToSequences(buffer) - commandResult := engine.Evaluate(sequences) + sequences, err := bytesToSequences(buffer) - if len(commandResult) > 0 { + if err != nil { + log.Fatal(err.Error()) + } + + commandResult, err := engine.Evaluate(sequences) + + if err != nil { + _, err = conn.Write([]byte(err.Error())) + } else if len(commandResult) > 0 { _, err = conn.Write([]byte(commandResult)) } if err != nil { - log.Fatal("Error:", err) + log.Fatal(err.Error()) } - fmt.Printf("Received: %s\n", buffer[:n]) + log.Printf("Received: %s\n", buffer[:n]) } } diff --git a/parser/parser.go b/parser/parser.go index 801a46f..119fc61 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -1,7 +1,7 @@ package parser import ( - "errors" + "fmt" "github.com/LissaGreense/GO4SQL/ast" "github.com/LissaGreense/GO4SQL/lexer" "github.com/LissaGreense/GO4SQL/token" @@ -62,7 +62,7 @@ func validateToken(tokenType token.Type, expectedTokens []token.Type) error { } } if !contains { - return errors.New("Syntax error, expecting: " + tokensPrintMessage + ", got: " + string(tokenType)) + return fmt.Errorf("syntax error, expecting: %s, got: %s", tokensPrintMessage, string(tokenType)) } return nil } @@ -284,7 +284,7 @@ func (parser *Parser) parseWhereCommand() (ast.Command, error) { } if !expressionIsValid { - return nil, errors.New("Expression withing Where statement couldn't be parsed correctly") + return nil, fmt.Errorf("expression withing Where statement couldn't be parsed correctly") } err = validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.ORDER}) @@ -466,7 +466,7 @@ func (parser *Parser) getOperationExpression(booleanExpressionExists bool, condi return false, nil, err } if !expressionIsValid { - return false, nil, errors.New("Couldn't parse right side of the OperationExpression after " + operationExpression.Operation.Literal + " token.") + return false, nil, fmt.Errorf("couldn't parse right side of the OperationExpression after %s token", operationExpression.Operation.Literal) } operationExpression.Right = expression @@ -537,7 +537,7 @@ func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} parser.nextToken() default: - return false, nil, errors.New("Syntax error, expecting: " + token.APOSTROPHE + "," + token.IDENT + "," + token.LITERAL + ", got: " + parser.currentToken.Literal) + return false, nil, fmt.Errorf("syntax error, expecting: { %s, %s, %s }, got: %s", token.APOSTROPHE, token.IDENT, token.LITERAL, parser.currentToken.Literal) } return true, conditionalExpression, nil @@ -583,7 +583,7 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { } lastCommand.(*ast.DeleteCommand).WhereCommand = newCommand.(*ast.WhereCommand) } else { - return nil, errors.New("Syntax error, WHERE command needs SELECT or DELETE command before") + return nil, fmt.Errorf("syntax error, WHERE command needs SELECT or DELETE command before") } case token.ORDER: lastCommand, parserError := parser.getLastCommand(sequence) @@ -592,7 +592,7 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { } if lastCommand.TokenLiteral() != token.SELECT { - return nil, errors.New("Syntax error, ORDER BY command needs SELECT command before") + return nil, fmt.Errorf("syntax error, ORDER BY command needs SELECT command before") } selectCommand := lastCommand.(*ast.SelectCommand) @@ -602,7 +602,7 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { } selectCommand.OrderByCommand = newCommand.(*ast.OrderByCommand) default: - return nil, errors.New("Syntax error, invalid command found: " + parser.currentToken.Literal) + return nil, fmt.Errorf("syntax error, invalid command found: %s", parser.currentToken.Literal) } if err != nil { @@ -620,7 +620,7 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { func (parser *Parser) getLastCommand(sequence *ast.Sequence) (ast.Command, error) { if len(sequence.Commands) == 0 { - return nil, errors.New("Syntax error, Where Command can't be used without predecessor") + return nil, fmt.Errorf("syntax error, Where Command can't be used without predecessor") } lastCommand := sequence.Commands[len(sequence.Commands)-1] return lastCommand, nil From 2b91d01de2292b987c31c35d014c2baed422fd38 Mon Sep 17 00:00:00 2001 From: LissaGreense Date: Wed, 12 Jun 2024 00:10:47 +0200 Subject: [PATCH 09/21] Adds LIMIT and OFFSET keywords :D --- .github/expected_results/end2end.txt | 16 +++ README.md | 30 ++++++ ast/ast.go | 50 +++++++++ engine/engine.go | 67 ++++++++++-- engine/engine_test.go | 155 +++++++++++++++++++++++++++ lexer/lexer_test.go | 17 +++ parser/parser.go | 108 ++++++++++++++++++- parser/parser_test.go | 131 ++++++++++++++++++++++ test_file | 3 + token/token.go | 4 + 10 files changed, 572 insertions(+), 9 deletions(-) diff --git a/.github/expected_results/end2end.txt b/.github/expected_results/end2end.txt index 2604fbd..d756212 100644 --- a/.github/expected_results/end2end.txt +++ b/.github/expected_results/end2end.txt @@ -22,6 +22,22 @@ Data Inserted | one | two | three | four | +-----+-----+-------+------+ +-----+-----+-------+------+ ++---------+-----+-------+------+ +| one | two | three | four | ++---------+-----+-------+------+ +| 'hello' | 1 | 11 | 'q' | ++---------+-----+-------+------+ ++-----------+-----+-------+------+ +| one | two | three | four | ++-----------+-----+-------+------+ +| 'goodbye' | 1 | 22 | 'w' | +| 'byebye' | 3 | 33 | 'e' | ++-----------+-----+-------+------+ ++-----------+-----+-------+------+ +| one | two | three | four | ++-----------+-----+-------+------+ +| 'goodbye' | 1 | 22 | 'w' | ++-----------+-----+-------+------+ Data from 'tbl' has been deleted +-----------+-----+-------+------+ | one | two | three | four | diff --git a/README.md b/README.md index b111801..5e2258a 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,14 @@ Currently, there are 3 modes to chose from: First column is called ``one`` and it contains strings (keyword ``TEXT``), second one is called ``two`` and it contains integers (keyword ``INT``). +* ***DROP TABLE*** - you can destroy the table of name ``table1`` using + command: + ```sql + DROP TABLE table1; + ``` + After using this command table1 will no longer be available and all data connected to it (column + definitions and inserted values) will be lost. + * ***INSERT INTO*** - you can insert values into table called ``table1`` with command: @@ -94,6 +102,28 @@ Currently, there are 3 modes to chose from: In this case, this command will order by ``column1`` in ascending order, but if some rows have the same ``column1``, it orders them by column2 in descending order. +* ***LIMIT*** is used to reduce number of rows printed out by returning only specified number of + records with ``SELECT`` like this: + ```sql + SELECT column1, column2, + FROM table_name + ORDER BY column1 ASC + LIMIT 5; + ``` + In this case, this command will order by ``column1`` in ascending order and return 5 first records. + + +* ***OFFSET*** is used to reduce number of rows printed out by not skipping specified numbers of + rows in returned output with ``SELECT`` like this: + ```sql + SELECT column1, column2, + FROM table_name + ORDER BY column1 ASC + LIMIT 5 OFFSET 3; + ``` + In this case, this command will order by ``column1`` in ascending order and skip 3 first records, + then return records from 4th to 8th. + ## UNIT TESTS To run all the tests locally run this in root directory: diff --git a/ast/ast.go b/ast/ast.go index 21f029f..b0943a0 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -163,6 +163,8 @@ type SelectCommand struct { Space []token.Token // ex. column names WhereCommand *WhereCommand // optional OrderByCommand *OrderByCommand // optional + LimitCommand *LimitCommand // optional + OffsetCommand *OffsetCommand // optional } func (ls SelectCommand) CommandNode() {} @@ -198,6 +200,36 @@ func (ls SelectCommand) HasOrderByCommand() bool { return true } +// HasLimitCommand - returns true if optional HasLimitCommand is present in SelectCommand +// +// Example: +// SELECT * FROM table LIMIT 5; +// Returns true +// +// SELECT * FROM table; +// Returns false +func (ls SelectCommand) HasLimitCommand() bool { + if ls.LimitCommand == nil { + return false + } + return true +} + +// HasOffsetCommand - returns true if optional HasOffsetCommand is present in SelectCommand +// +// Example: +// SELECT * FROM table OFFSET 100; +// Returns true +// +// SELECT * FROM table LIMIT 10; +// Returns false +func (ls SelectCommand) HasOffsetCommand() bool { + if ls.OffsetCommand == nil { + return false + } + return true +} + // WhereCommand - Part of Command that represent Where statement with expression that will qualify values from Select // // Example: @@ -267,3 +299,21 @@ type SortPattern struct { ColumnName token.Token // column name Order token.Token // ASC or DESC } + +// LimitCommand - Part of Command that limits results from SelectCommand +type LimitCommand struct { + Token token.Token + Count int +} + +func (ls LimitCommand) CommandNode() {} +func (ls LimitCommand) TokenLiteral() string { return ls.Token.Literal } + +// OffsetCommand - Part of Command that skip begging rows from SelectCommand +type OffsetCommand struct { + Token token.Token + Count int +} + +func (ls OffsetCommand) CommandNode() {} +func (ls OffsetCommand) TokenLiteral() string { return ls.Token.Literal } diff --git a/engine/engine.go b/engine/engine.go index 4732cc2..677439e 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -33,6 +33,10 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) (string, error) { continue case *ast.OrderByCommand: continue + case *ast.LimitCommand: + continue + case *ast.OffsetCommand: + continue case *ast.CreateCommand: err := engine.createTable(mappedCommand) if err != nil { @@ -78,19 +82,42 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) (string, error) { // getSelectResponse - Returns Select response basing on ast.OrderByCommand and ast.WhereCommand included in this Select func (engine *DbEngine) getSelectResponse(selectCommand *ast.SelectCommand) (*Table, error) { + var table *Table + var err error + if selectCommand.HasWhereCommand() { whereCommand := selectCommand.WhereCommand if selectCommand.HasOrderByCommand() { orderByCommand := selectCommand.OrderByCommand - return engine.selectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand) + table, err = engine.selectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand) + if err != nil { + return nil, err + } + } else { + table, err = engine.selectFromTableWithWhere(selectCommand, whereCommand) + if err != nil { + return nil, err + } + } + } else if selectCommand.HasOrderByCommand() { + table, err = engine.selectFromTableWithOrderBy(selectCommand, selectCommand.OrderByCommand) + if err != nil { + return nil, err } - return engine.selectFromTableWithWhere(selectCommand, whereCommand) } - if selectCommand.HasOrderByCommand() { - orderByCommand := selectCommand.OrderByCommand - return engine.selectFromTableWithOrderBy(selectCommand, orderByCommand) + + if table == nil { + table, err = engine.selectFromTable(selectCommand) + if err != nil { + return nil, err + } } - return engine.selectFromTable(selectCommand) + + if selectCommand.HasLimitCommand() || selectCommand.HasOffsetCommand() { + table.applyOffsetAndLimit(selectCommand) + } + + return table, nil } // createTable - initialize new table in engine with specified name @@ -302,6 +329,34 @@ func (engine *DbEngine) getFilteredTable(table *Table, whereCommand *ast.WhereCo return filteredTable, nil } +func (table *Table) applyOffsetAndLimit(command *ast.SelectCommand) { + var offset = 0 + var limitRaw = -1 + + if command.HasLimitCommand() { + limitRaw = command.LimitCommand.Count + } + if command.HasOffsetCommand() { + offset = command.OffsetCommand.Count + } + + for _, column := range table.Columns { + var limit int + + if limitRaw == -1 || limitRaw+offset > len(column.Values) { + limit = len(column.Values) + } else { + limit = limitRaw + offset + } + + if offset > len(column.Values) || limit == 0 { + column.Values = make([]ValueInterface, 0) + } else { + column.Values = column.Values[offset:limit] + } + } +} + func xor(fulfilledFilters bool, negation bool) bool { return (fulfilledFilters || negation) && !(fulfilledFilters && negation) } diff --git a/engine/engine_test.go b/engine/engine_test.go index 082a74b..d84d9d7 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -322,6 +322,161 @@ func TestOrderByWithMultipleSorts(t *testing.T) { engineTestSuite.runTestSuite(t) } +func TestLimit(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'aa' );", + "INSERT INTO tb1 VALUES( 'sorry', 2, 55, 'ba' );", + "INSERT INTO tb1 VALUES( 'welcome', 2, 66, 'bb' );", + "INSERT INTO tb1 VALUES( 'seeYouLater', 2, 95, 'ab' );", + }, + selectInput: "SELECT one FROM tb1 LIMIT 2;", + expectedOutput: [][]string{ + {"one"}, + {"hello"}, + {"byebye"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestLimitEqualToZero(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'aa' );", + "INSERT INTO tb1 VALUES( 'sorry', 2, 55, 'ba' );", + "INSERT INTO tb1 VALUES( 'welcome', 2, 66, 'bb' );", + "INSERT INTO tb1 VALUES( 'seeYouLater', 2, 95, 'ab' );", + }, + selectInput: "SELECT one FROM tb1 LIMIT 0;", + expectedOutput: [][]string{ + {"one"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestLimitThatIsMoreThanSize(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'aa' );", + }, + selectInput: "SELECT one FROM tb1 LIMIT 666;", + expectedOutput: [][]string{ + {"one"}, + {"hello"}, + {"byebye"}, + {"goodbye"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestOffset(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 4, 22, 'aa' );", + "INSERT INTO tb1 VALUES( 'sorry', 2, 55, 'ba' );", + }, + selectInput: "SELECT one FROM tb1 OFFSET 3;", + expectedOutput: [][]string{ + {"one"}, + {"sorry"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestOffsetThatOverExceedSize(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 4, 22, 'aa' );", + "INSERT INTO tb1 VALUES( 'sorry', 2, 55, 'ba' );", + }, + selectInput: "SELECT one FROM tb1 WHERE TRUE ORDER BY two ASC OFFSET 4;", + expectedOutput: [][]string{ + {"one"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestOffsetEqualToZero(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 4, 22, 'aa' );", + "INSERT INTO tb1 VALUES( 'sorry', 2, 55, 'ba' );", + }, + selectInput: "SELECT one FROM tb1 OFFSET 0;", + expectedOutput: [][]string{ + {"one"}, + {"hello"}, + {"byebye"}, + {"goodbye"}, + {"sorry"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestLimitAndOffset(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'aa' );", + "INSERT INTO tb1 VALUES( 'sorry', 2, 55, 'ba' );", + }, + selectInput: "SELECT one FROM tb1 WHERE TRUE ORDER BY two ASC, four DESC LIMIT 2 OFFSET 2;", + expectedOutput: [][]string{ + {"one"}, + {"goodbye"}, + {"hello"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + type engineDBContentTestSuite struct { inputs []string expectedTableNames []string diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index 7ef4016..f190b48 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -251,6 +251,23 @@ func TestDropStatement(t *testing.T) { runLexerTestSuite(t, input, tests) } +func TestLimitAndOffsetStatement(t *testing.T) { + input := `LIMIT 5 OFFSET 6;` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.LIMIT, "LIMIT"}, + {token.LITERAL, "5"}, + {token.OFFSET, "OFFSET"}, + {token.LITERAL, "6"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + func runLexerTestSuite(t *testing.T, input string, tests []struct { expectedType token.Type expectedLiteral string diff --git a/parser/parser.go b/parser/parser.go index 119fc61..695ec9a 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -5,6 +5,7 @@ import ( "github.com/LissaGreense/GO4SQL/ast" "github.com/LissaGreense/GO4SQL/lexer" "github.com/LissaGreense/GO4SQL/token" + "strconv" ) // Parser - Contain token that is currently analyzed by parser and the next one. Lexer is used to tokenize the client @@ -254,7 +255,7 @@ func (parser *Parser) parseSelectCommand() (ast.Command, error) { parser.nextToken() // expect SEMICOLON or WHERE - err = validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.WHERE, token.ORDER}) + err = validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.WHERE, token.ORDER, token.LIMIT, token.OFFSET}) if err != nil { return nil, err } @@ -376,6 +377,9 @@ func (parser *Parser) parseOrderByCommand() (ast.Command, error) { // ensure that loop below will execute at least once err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } // array of SortPattern for parser.currentToken.Type == token.IDENT { @@ -405,9 +409,79 @@ func (parser *Parser) parseOrderByCommand() (ast.Command, error) { parser.nextToken() } - err = validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) + parser.skipIfCurrentTokenIsSemicolon() + + return orderCommand, nil +} + +// parseLimitCommand - Return ast.parseLimitCommand created from tokens and validate the syntax +// +// Example of input parsable to the ast.parseLimitCommand: +// LIMIT 10 +func (parser *Parser) parseLimitCommand() (ast.Command, error) { + // token.LIMIT already at current position in parser + limitCommand := &ast.LimitCommand{Token: parser.currentToken} + + // token.LIMIT no longer needed + parser.nextToken() + + err := validateToken(parser.currentToken.Type, []token.Type{token.LITERAL}) + if err != nil { + return nil, err + } + + // convert count number to int + count, err := strconv.Atoi(parser.currentToken.Literal) + if err != nil { + return nil, err + } + + if count < 0 { + return nil, fmt.Errorf("limit value should be more than 0") + } + + limitCommand.Count = count + + // Skip token.IDENT + parser.nextToken() + + parser.skipIfCurrentTokenIsSemicolon() + + return limitCommand, nil +} + +// parseOffsetCommand - Return ast.parseOffsetCommand created from tokens and validate the syntax +// +// Example of input parsable to the ast.parseLimitCommand: +// OFFSET 10 +func (parser *Parser) parseOffsetCommand() (ast.Command, error) { + // token.OFFSET already at current position in parser + offsetCommand := &ast.OffsetCommand{Token: parser.currentToken} + + // token.OFFSET no longer needed + parser.nextToken() + + err := validateToken(parser.currentToken.Type, []token.Type{token.LITERAL}) + if err != nil { + return nil, err + } + + count, err := strconv.Atoi(parser.currentToken.Literal) + if err != nil { + return nil, err + } + if count < 0 { + return nil, fmt.Errorf("limit value should be more than 0") + } + + offsetCommand.Count = count - return orderCommand, err + // Skip token.IDENT + parser.nextToken() + + parser.skipIfCurrentTokenIsSemicolon() + + return offsetCommand, nil } // getExpression - Return proper structure of ast.Expression and validate the syntax @@ -601,6 +675,34 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { return nil, err } selectCommand.OrderByCommand = newCommand.(*ast.OrderByCommand) + case token.LIMIT: + lastCommand, parserError := parser.getLastCommand(sequence) + if parserError != nil { + return nil, parserError + } + if lastCommand.TokenLiteral() != token.SELECT { + return nil, fmt.Errorf("syntax error, LIMIT command needs SELECT command before") + } + selectCommand := lastCommand.(*ast.SelectCommand) + newCommand, err := parser.parseLimitCommand() + if err != nil { + return nil, err + } + selectCommand.LimitCommand = newCommand.(*ast.LimitCommand) + case token.OFFSET: + lastCommand, parserError := parser.getLastCommand(sequence) + if parserError != nil { + return nil, parserError + } + if lastCommand.TokenLiteral() != token.SELECT { + return nil, fmt.Errorf("syntax error, OFFSET command needs SELECT command before") + } + selectCommand := lastCommand.(*ast.SelectCommand) + newCommand, err := parser.parseOffsetCommand() + if err != nil { + return nil, err + } + selectCommand.OffsetCommand = newCommand.(*ast.OffsetCommand) default: return nil, fmt.Errorf("syntax error, invalid command found: %s", parser.currentToken.Literal) } diff --git a/parser/parser_test.go b/parser/parser_test.go index a59ce72..d3179f9 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -315,6 +315,113 @@ func TestSelectWithOrderByCommand(t *testing.T) { testOrderByCommands(t, expectedOrderByCommand, selectCommand.OrderByCommand) } +func TestSelectWithLimitCommand(t *testing.T) { + input := "SELECT * FROM tableName LIMIT 5;" + expectedLimitCommand := ast.LimitCommand{ + Token: token.Token{Type: token.LIMIT, Literal: "LIMIT"}, + Count: 5, + } + expectedTableName := "tableName" + expectedColumnName := []token.Token{{Type: token.ASTERISK, Literal: "*"}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName) { + return + } + + if !selectCommand.HasLimitCommand() { + t.Fatalf("sequences does not contain where command") + } + + testLimitCommands(t, expectedLimitCommand, selectCommand.LimitCommand) +} + +func TestSelectWithOffsetCommand(t *testing.T) { + input := "SELECT * FROM tableName OFFSET 5;" + expectedOffsetCommand := ast.OffsetCommand{ + Token: token.Token{Type: token.OFFSET, Literal: "OFFSET"}, + Count: 5, + } + + expectedTableName := "tableName" + expectedColumnName := []token.Token{{Type: token.ASTERISK, Literal: "*"}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName) { + return + } + + if !selectCommand.HasOffsetCommand() { + t.Fatalf("select command should have offset command") + } + testOffsetCommands(t, expectedOffsetCommand, selectCommand.OffsetCommand) +} + +func TestSelectWithLimitAndOffsetCommand(t *testing.T) { + input := "SELECT * FROM tableName ORDER BY colName1 DESC LIMIT 2 OFFSET 13;" + expectedLimitCommand := ast.LimitCommand{ + Token: token.Token{Type: token.LIMIT, Literal: "LIMIT"}, + Count: 2, + } + expectedOffsetCommand := ast.OffsetCommand{ + Token: token.Token{Type: token.OFFSET, Literal: "OFFSET"}, + Count: 13, + } + expectedTableName := "tableName" + expectedColumnName := []token.Token{{Type: token.ASTERISK, Literal: "*"}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName) { + return + } + + if !selectCommand.HasLimitCommand() { + t.Fatalf("select command should have limit command") + } + if !selectCommand.HasOffsetCommand() { + t.Fatalf("select command should have offset command") + } + + testLimitCommands(t, expectedLimitCommand, selectCommand.LimitCommand) + testOffsetCommands(t, expectedOffsetCommand, selectCommand.OffsetCommand) +} + func TestParseLogicOperatorsInCommand(t *testing.T) { firstExpression := ast.OperationExpression{ @@ -476,9 +583,33 @@ func testOrderByCommands(t *testing.T, expectedOrderByCommand ast.OrderByCommand t.Errorf("Expecting Column Name: %s, got: %s", expectedSortPattern.ColumnName.Literal, actualOrderByCommand.SortPatterns[i].ColumnName.Literal) } } +} + +func testLimitCommands(t *testing.T, expectedLimitCommand ast.LimitCommand, actualLimitCommand *ast.LimitCommand) { + if expectedLimitCommand.Token.Type != actualLimitCommand.Token.Type { + t.Errorf("Expecting Token TokenType: %q, got: %q", expectedLimitCommand.Token.Type, actualLimitCommand.Token.Type) + } + if expectedLimitCommand.Token.Literal != actualLimitCommand.Token.Literal { + t.Errorf("Expecting Token Literal: %s, got: %s", expectedLimitCommand.Token.Literal, actualLimitCommand.Token.Literal) + } + if expectedLimitCommand.Count != actualLimitCommand.Count { + t.Errorf("Expecting Count to have value: %d, got: %d", expectedLimitCommand.Count, actualLimitCommand.Count) + } } +func testOffsetCommands(t *testing.T, expectedOffsetCommand ast.OffsetCommand, actualOffsetCommand *ast.OffsetCommand) { + + if expectedOffsetCommand.Token.Type != actualOffsetCommand.Token.Type { + t.Errorf("Expecting Token TokenType: %q, got: %q", expectedOffsetCommand.Token.Type, actualOffsetCommand.Token.Type) + } + if expectedOffsetCommand.Token.Literal != actualOffsetCommand.Token.Literal { + t.Errorf("Expecting Token Literal: %s, got: %s", expectedOffsetCommand.Token.Literal, actualOffsetCommand.Token.Literal) + } + if expectedOffsetCommand.Count != actualOffsetCommand.Count { + t.Errorf("Expecting Count to have value: %d, got: %d", expectedOffsetCommand.Count, actualOffsetCommand.Count) + } +} func expressionsAreEqual(first ast.Expression, second ast.Expression) bool { booleanExpression, booleanExpressionIsValid := first.(*ast.BooleanExpression) diff --git a/test_file b/test_file index 5f4c31d..142b87b 100644 --- a/test_file +++ b/test_file @@ -6,6 +6,9 @@ SELECT one, three FROM tbl WHERE two NOT 3; SELECT * FROM tbl WHERE one NOT 'goodbye' AND two EQUAL 3; SELECT * FROM tbl WHERE FALSE; + SELECT * FROM tbl LIMIT 1; + SELECT * FROM tbl OFFSET 1; + SELECT * FROM tbl LIMIT 1 OFFSET 1; DELETE FROM tbl WHERE one EQUAL 'byebye'; SELECT * FROM tbl; SELECT one FROM tbl WHERE TRUE ORDER BY two ASC, four DESC; diff --git a/token/token.go b/token/token.go index 5643cab..d471e5a 100644 --- a/token/token.go +++ b/token/token.go @@ -43,6 +43,8 @@ const ( BY = "BY" ASC = "ASC" DESC = "DESC" + LIMIT = "LIMIT" + OFFSET = "OFFSET" // EQUAL - Logical operations EQUAL = "EQUAL" @@ -75,6 +77,8 @@ var keywords = map[string]Type{ "BY": BY, "ASC": ASC, "DESC": DESC, + "LIMIT": LIMIT, + "OFFSET": OFFSET, "VALUES": VALUES, "WHERE": WHERE, "EQUAL": EQUAL, From 14c009d11b21ef90bb2e98ee48f6f87f798e376e Mon Sep 17 00:00:00 2001 From: LissaGreense Date: Thu, 27 Jun 2024 23:17:40 +0200 Subject: [PATCH 10/21] Add Update command --- .github/expected_results/end2end.txt | 7 ++ README.md | 8 ++ ast/ast.go | 29 +++++++ engine/engine.go | 53 ++++++++++++ engine/engine_test.go | 46 ++++++++++ engine/row.go | 13 ++- lexer/lexer_test.go | 38 ++++++++- parser/parser.go | 80 +++++++++++++++++ parser/parser_test.go | 123 +++++++++++++++++++++++++++ test_file | 2 + token/token.go | 7 ++ 11 files changed, 400 insertions(+), 6 deletions(-) diff --git a/.github/expected_results/end2end.txt b/.github/expected_results/end2end.txt index d756212..5e8f803 100644 --- a/.github/expected_results/end2end.txt +++ b/.github/expected_results/end2end.txt @@ -51,4 +51,11 @@ Data from 'tbl' has been deleted | 'goodbye' | | 'hello' | +-----------+ +Table: 'tbl' has been updated ++-----------+-----+-------+------+ +| one | two | three | four | ++-----------+-----+-------+------+ +| 'hello' | 1 | 11 | 'q' | +| 'goodbye' | 5 | 22 | 'P' | ++-----------+-----+-------+------+ Table: 'tbl' has been dropped diff --git a/README.md b/README.md index 5e2258a..85adbb4 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,14 @@ Currently, there are 3 modes to chose from: Please note that the number of arguments and types of the values must be the same as you declared with ``CREATE``. +* ***UPDATE*** - you can update values in table called ``table1`` with command: + ```sql + UPDATE table1 + SET column_name_1 TO new_value_1, column_name_2 TO new_value_2 + WHERE id EQUAL 1; + ``` + It will update all rows where column ``id`` is equal to ``1`` by replacing value in + ``column_name_1`` with ``new_value_1`` and ``column_name_2`` with ``new_value_2``. * ***SELECT FROM*** - you can either select everything from ``table1`` with: ```SELECT * FROM table1;``` diff --git a/ast/ast.go b/ast/ast.go index b0943a0..8d99aa7 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -230,6 +230,35 @@ func (ls SelectCommand) HasOffsetCommand() bool { return true } +// UpdateCommand - Part of Command that allow to change existing data +// +// Example: +// UPDATE table SET col1 TO 2 WHERE column1 NOT 'hi'; +type UpdateCommand struct { + Token token.Token + Name Identifier // ex. name of table + Changes map[token.Token]Anonymitifier // column names with new values + WhereCommand *WhereCommand // optional +} + +func (ls UpdateCommand) CommandNode() {} +func (ls UpdateCommand) TokenLiteral() string { return ls.Token.Literal } + +// HasWhereCommand - returns true if optional HasWhereCommand is present in UpdateCommand +// +// Example: +// UPDATE table SET col1 TO 2 WHERE column1 NOT 'hi'; +// Returns true +// +// UPDATE table SET col1 TO 2; +// Returns false +func (ls UpdateCommand) HasWhereCommand() bool { + if ls.WhereCommand == nil { + return false + } + return true +} + // WhereCommand - Part of Command that represent Where statement with expression that will qualify values from Select // // Example: diff --git a/engine/engine.go b/engine/engine.go index 677439e..d5b2bab 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -72,6 +72,13 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) (string, error) { engine.dropTable(mappedCommand) result += "Table: '" + mappedCommand.Name.GetToken().Literal + "' has been dropped\n" continue + case *ast.UpdateCommand: + err := engine.updateTable(mappedCommand) + if err != nil { + return "", err + } + result += "Table: '" + mappedCommand.Name.GetToken().Literal + "' has been updated\n" + continue default: return "", fmt.Errorf("unsupported Command detected: %v", command) } @@ -140,6 +147,52 @@ func (engine *DbEngine) createTable(command *ast.CreateCommand) error { return nil } +func (engine *DbEngine) updateTable(command *ast.UpdateCommand) error { + table, exist := engine.Tables[command.Name.Token.Literal] + + if !exist { + return fmt.Errorf("table with the name of %s doesn't exist", command.Name.Token.Literal) + } + + columns := table.Columns + + // TODO: This could be optimized + mappedChanges := make(map[int]ast.Anonymitifier) + for updatedCol, newValue := range command.Changes { + for colIndex := 0; colIndex < len(columns); colIndex++ { + if columns[colIndex].Name == updatedCol.Literal { + mappedChanges[colIndex] = newValue + break + } + if colIndex == len(columns)-1 { + return fmt.Errorf("column with the name of %s doesn't exist in table %s", updatedCol.Literal, command.Name.GetToken().Literal) + } + } + } + + numberOfRows := len(columns[0].Values) + for rowIndex := 0; rowIndex < numberOfRows; rowIndex++ { + if command.HasWhereCommand() { + fulfilledFilters, err := isFulfillingFilters(getRow(table, rowIndex), command.WhereCommand.Expression) + if err != nil { + return err + } + if !fulfilledFilters { + continue + } + } + for colIndex, value := range mappedChanges { + interfaceValue, err := getInterfaceValue(value.GetToken()) + if err != nil { + return err + } + table.Columns[colIndex].Values[rowIndex] = interfaceValue + } + } + + return nil +} + // insertIntoTable - Insert row of values into the table func (engine *DbEngine) insertIntoTable(command *ast.InsertCommand) error { table, exist := engine.Tables[command.Name.Token.Literal] diff --git a/engine/engine_test.go b/engine/engine_test.go index d84d9d7..907fecd 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -255,6 +255,52 @@ func TestDelete(t *testing.T) { engineTestSuite.runTestSuite(t) } +func TestUpdateWithWhere(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + "UPDATE tb1 SET one TO 'hi hello', three TO 5 WHERE two EQUAL 3;", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + }, + selectInput: "SELECT one, two, three, four FROM tb1;", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + {"hello", "1", "11", "q"}, + {"hi hello", "3", "5", "e"}, + {"goodbye", "2", "22", "w"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestUpdateWithoutWhere(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + "UPDATE tb1 SET one TO 'hi hello', three TO 5;", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + }, + selectInput: "SELECT one, two, three, four FROM tb1;", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + {"hi hello", "1", "5", "q"}, + {"hi hello", "3", "5", "e"}, + {"goodbye", "2", "22", "w"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + func TestOrderBy(t *testing.T) { engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ diff --git a/engine/row.go b/engine/row.go index 35fc707..5dd9743 100644 --- a/engine/row.go +++ b/engine/row.go @@ -12,11 +12,16 @@ func MapTableToRows(table *Table) Rows { numberOfRows := len(table.Columns[0].Values) for rowIndex := 0; rowIndex < numberOfRows; rowIndex++ { - row := make(map[string]ValueInterface) - for _, column := range table.Columns { - row[column.Name] = column.Values[rowIndex] - } + row := getRow(table, rowIndex) rows = append(rows, row) } return Rows{rows: rows} } + +func getRow(table *Table, rowIndex int) map[string]ValueInterface { + row := make(map[string]ValueInterface) + for _, column := range table.Columns { + row[column.Name] = column.Values[rowIndex] + } + return row +} diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index f190b48..b2b154f 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -6,7 +6,7 @@ import ( "github.com/LissaGreense/GO4SQL/token" ) -func TestLexer(t *testing.T) { +func TestLexerWithInsertCommand(t *testing.T) { input := ` CREATE TABLE 1tbl( one TEXT , two INT ); @@ -58,7 +58,41 @@ func TestLexer(t *testing.T) { runLexerTestSuite(t, input, tests) } -func TestLexerWithNumbersMixedInLitterals(t *testing.T) { +func TestLexerWithUpdateCommand(t *testing.T) { + input := + ` + UPDATE table1 + SET column_name_1 TO 'UPDATE', column_name_2 TO 42 + WHERE column_name_3 EQUAL 1; + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.UPDATE, "UPDATE"}, + {token.IDENT, "table1"}, + {token.SET, "SET"}, + {token.IDENT, "column_name_1"}, + {token.TO, "TO"}, + {token.APOSTROPHE, "'"}, + {token.IDENT, "UPDATE"}, + {token.APOSTROPHE, "'"}, + {token.COMMA, ","}, + {token.IDENT, "column_name_2"}, + {token.TO, "TO"}, + {token.LITERAL, "42"}, + {token.WHERE, "WHERE"}, + {token.IDENT, "column_name_3"}, + {token.EQUAL, "EQUAL"}, + {token.LITERAL, "1"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestLexerWithNumbersMixedInLiterals(t *testing.T) { input := ` CREATE TABLE tbl2( one TEXT , two INT ); diff --git a/parser/parser.go b/parser/parser.go index 695ec9a..39262b5 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -484,6 +484,78 @@ func (parser *Parser) parseOffsetCommand() (ast.Command, error) { return offsetCommand, nil } +// parseUpdateCommand - Return ast.parseUpdateCommand created from tokens and validate the syntax +// +// Example of input parsable to the ast.parseUpdateCommand: +// UPDATE table SET col1 TO 'value' WHERE col2 EQUAL 10; +func (parser *Parser) parseUpdateCommand() (ast.Command, error) { + // token.UPDATE already at current position in parser + updateCommand := &ast.UpdateCommand{Token: parser.currentToken} + + // Ignore token.UPDATE + parser.nextToken() + + // Get table name + err := validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + updateCommand.Name = ast.Identifier{Token: parser.currentToken} + + // Ignore token.INDENT + parser.nextToken() + + err = validateToken(parser.currentToken.Type, []token.Type{token.SET}) + if err != nil { + return nil, err + } + + // Ignore token.SET + parser.nextToken() + + updateCommand.Changes = make(map[token.Token]ast.Anonymitifier) + for parser.currentToken.Type == token.IDENT { + // Get column name + err := validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + colKey := parser.currentToken + + // skip column name + parser.nextToken() + + err = validateToken(parser.currentToken.Type, []token.Type{token.TO}) + if err != nil { + return nil, err + } + // skip token.TO + parser.nextToken() + + parser.skipIfCurrentTokenIsApostrophe() + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL}) + if err != nil { + return nil, err + } + updateCommand.Changes[colKey] = ast.Anonymitifier{Token: parser.currentToken} + + // skip token.IDENT or token.LITERAL + parser.nextToken() + parser.skipIfCurrentTokenIsApostrophe() + + if parser.currentToken.Type != token.COMMA { + break + } + + // Skip token.COMMA + parser.nextToken() + } + + parser.skipIfCurrentTokenIsSemicolon() + + return updateCommand, nil +} + // getExpression - Return proper structure of ast.Expression and validate the syntax // // Available expressions: @@ -632,6 +704,8 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { command, err = parser.parseCreateCommand() case token.INSERT: command, err = parser.parseInsertCommand() + case token.UPDATE: + command, err = parser.parseUpdateCommand() case token.SELECT: command, err = parser.parseSelectCommand() case token.DELETE: @@ -656,6 +730,12 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { return nil, err } lastCommand.(*ast.DeleteCommand).WhereCommand = newCommand.(*ast.WhereCommand) + } else if lastCommand.TokenLiteral() == token.UPDATE { + newCommand, err := parser.parseWhereCommand() + if err != nil { + return nil, err + } + lastCommand.(*ast.UpdateCommand).WhereCommand = newCommand.(*ast.WhereCommand) } else { return nil, fmt.Errorf("syntax error, WHERE command needs SELECT or DELETE command before") } diff --git a/parser/parser_test.go b/parser/parser_test.go index d3179f9..b86709b 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -422,6 +422,94 @@ func TestSelectWithLimitAndOffsetCommand(t *testing.T) { testOffsetCommands(t, expectedOffsetCommand, selectCommand.OffsetCommand) } +func TestParseUpdateCommand(t *testing.T) { + tests := []struct { + input string + expectedTableName string + expectedChanges map[token.Token]ast.Anonymitifier + }{ + {input: "UPDATE tbl SET colName TO 5;", expectedTableName: "tbl", expectedChanges: map[token.Token]ast.Anonymitifier{ + {Type: token.IDENT, Literal: "colName"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + }, + }, + {input: "UPDATE tbl1 SET colName1 TO 'hi hello', colName2 TO 5;", expectedTableName: "tbl1", expectedChanges: map[token.Token]ast.Anonymitifier{ + {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.IDENT, Literal: "hi hello"}}, + {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + }, + }, + } + + for _, tt := range tests { + lexer := lexer.RunLexer(tt.input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + if !testUpdateStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedChanges) { + return + } + } +} + +func TestParseUpdateCommandWithWhere(t *testing.T) { + tests := []struct { + input string + expectedTableName string + expectedChanges map[token.Token]ast.Anonymitifier + expectedWhereCommand ast.Expression + }{ + { + input: "UPDATE tbl SET colName TO 5 WHERE id EQUAL 3;", + expectedTableName: "tbl", + expectedChanges: map[token.Token]ast.Anonymitifier{ + {Type: token.IDENT, Literal: "colName"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + }, + expectedWhereCommand: ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "id"}}, + Right: ast.Anonymitifier{Token: token.Token{Type: token.LITERAL, Literal: "3"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + }, + }, + } + + for _, tt := range tests { + lexer := lexer.RunLexer(tt.input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + actualUpdateCommand, ok := sequences.Commands[0].(*ast.UpdateCommand) + + if !ok { + t.Errorf("actualUpdateCommand is not %T. got=%T", &ast.UpdateCommand{}, sequences.Commands[0]) + } + + if !testUpdateStatement(t, actualUpdateCommand, tt.expectedTableName, tt.expectedChanges) { + return + } + + if !actualUpdateCommand.HasWhereCommand() { + t.Errorf("actualUpdateCommand should have where command") + } + + if !whereStatementIsValid(t, actualUpdateCommand.WhereCommand, tt.expectedWhereCommand) { + return + } + } +} + func TestParseLogicOperatorsInCommand(t *testing.T) { firstExpression := ast.OperationExpression{ @@ -519,6 +607,29 @@ func testSelectStatement(t *testing.T, command ast.Command, expectedTableName st return true } +func testUpdateStatement(t *testing.T, command ast.Command, expectedTableName string, expectedChanges map[token.Token]ast.Anonymitifier) bool { + if command.TokenLiteral() != "UPDATE" { + t.Errorf("command.TokenLiteral() not 'UPDATE'. got=%q", command.TokenLiteral()) + return false + } + actualUpdateCommand, ok := command.(*ast.UpdateCommand) + + if !ok { + t.Errorf("actualUpdateCommand is not %T. got=%T", &ast.UpdateCommand{}, command) + return false + } + if actualUpdateCommand.Name.Token.Literal != expectedTableName { + t.Errorf("%s != %s", actualUpdateCommand.TokenLiteral(), expectedTableName) + return false + } + if !tokenMapEquals(actualUpdateCommand.Changes, expectedChanges) { + t.Errorf("") + return false + } + + return true +} + func whereStatementIsValid(t *testing.T, command ast.Command, expectedExpression ast.Expression) bool { if command.TokenLiteral() != "WHERE" { t.Errorf("command.TokenLiteral() not 'WHERE'. got=%q", command.TokenLiteral()) @@ -563,6 +674,18 @@ func tokenArrayEquals(a []token.Token, b []token.Token) bool { return true } +func tokenMapEquals(a map[token.Token]ast.Anonymitifier, b map[token.Token]ast.Anonymitifier) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if v.GetToken().Literal != b[k].GetToken().Literal { + return false + } + } + return true +} + func testOrderByCommands(t *testing.T, expectedOrderByCommand ast.OrderByCommand, actualOrderByCommand *ast.OrderByCommand) { if expectedOrderByCommand.Token.Type != actualOrderByCommand.Token.Type { diff --git a/test_file b/test_file index 142b87b..493a292 100644 --- a/test_file +++ b/test_file @@ -12,4 +12,6 @@ DELETE FROM tbl WHERE one EQUAL 'byebye'; SELECT * FROM tbl; SELECT one FROM tbl WHERE TRUE ORDER BY two ASC, four DESC; + UPDATE tbl SET two TO 5, four TO 'P' WHERE one EQUAL 'goodbye'; + SELECT * FROM tbl; DROP TABLE tbl; diff --git a/token/token.go b/token/token.go index d471e5a..327a21c 100644 --- a/token/token.go +++ b/token/token.go @@ -45,6 +45,10 @@ const ( DESC = "DESC" LIMIT = "LIMIT" OFFSET = "OFFSET" + UPDATE = "UPDATE" + SET = "SET" + + TO = "TO" // EQUAL - Logical operations EQUAL = "EQUAL" @@ -79,6 +83,9 @@ var keywords = map[string]Type{ "DESC": DESC, "LIMIT": LIMIT, "OFFSET": OFFSET, + "UPDATE": UPDATE, + "SET": SET, + "TO": TO, "VALUES": VALUES, "WHERE": WHERE, "EQUAL": EQUAL, From dc6c57d88d101bfde1e36e2b6d065654a9fe0f3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Krupski?= <34219324+ixior462@users.noreply.github.com> Date: Thu, 11 Jul 2024 23:16:34 +0200 Subject: [PATCH 11/21] Feature/error handling tests (#20) * Parser error handling wip * Add error handling for engine and finish parser --------- Co-authored-by: LissaGreense --- engine/column.go | 9 +- engine/engine.go | 131 +++++++++++----- engine/engine_error_handling_test.go | 117 ++++++++++++++ engine/errors.go | 106 +++++++++++++ engine/table.go | 4 + parser/errors.go | 100 ++++++++++++ parser/parser.go | 79 +++++----- parser/parser_error_handling_test.go | 218 +++++++++++++++++++++++++++ parser/parser_test.go | 49 +++--- 9 files changed, 712 insertions(+), 101 deletions(-) create mode 100644 engine/engine_error_handling_test.go create mode 100644 engine/errors.go create mode 100644 parser/errors.go create mode 100644 parser/parser_error_handling_test.go diff --git a/engine/column.go b/engine/column.go index 0b26de7..419b157 100644 --- a/engine/column.go +++ b/engine/column.go @@ -1,7 +1,6 @@ package engine import ( - "fmt" "github.com/LissaGreense/GO4SQL/token" ) @@ -12,7 +11,7 @@ type Column struct { Values []ValueInterface } -func extractColumnContent(columns []*Column, wantedColumnNames *[]string) (*Table, error) { +func extractColumnContent(columns []*Column, wantedColumnNames *[]string, tableName string) (*Table, error) { selectedTable := &Table{Columns: make([]*Column, 0)} mappedIndexes := make([]int, 0) for wantedColumnIndex := range *wantedColumnNames { @@ -22,7 +21,7 @@ func extractColumnContent(columns []*Column, wantedColumnNames *[]string) (*Tabl break } if columnNameIndex == len(columns)-1 { - return nil, fmt.Errorf("provided column name: %s doesn't exist", (*wantedColumnNames)[wantedColumnIndex]) + return nil, &ColumnDoesNotExistError{columnName: (*wantedColumnNames)[wantedColumnIndex], tableName: tableName} } } } @@ -34,6 +33,10 @@ func extractColumnContent(columns []*Column, wantedColumnNames *[]string) (*Tabl Values: make([]ValueInterface, 0), }) } + if len(columns) == 0 { + return selectedTable, nil + } + rowsCount := len(columns[0].Values) for iRow := 0; iRow < rowsCount; iRow++ { diff --git a/engine/engine.go b/engine/engine.go index d5b2bab..1c606ab 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -80,7 +80,7 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) (string, error) { result += "Table: '" + mappedCommand.Name.GetToken().Literal + "' has been updated\n" continue default: - return "", fmt.Errorf("unsupported Command detected: %v", command) + return "", &UnsupportedCommandTypeFromParserError{variable: fmt.Sprintf("%s", command)} } } @@ -132,7 +132,7 @@ func (engine *DbEngine) createTable(command *ast.CreateCommand) error { _, exist := engine.Tables[command.Name.Token.Literal] if exist { - return fmt.Errorf("table with the name of %s already exist", command.Name.Token.Literal) + return &TableAlreadyExistsError{command.Name.Token.Literal} } engine.Tables[command.Name.Token.Literal] = &Table{Columns: []*Column{}} @@ -151,7 +151,7 @@ func (engine *DbEngine) updateTable(command *ast.UpdateCommand) error { table, exist := engine.Tables[command.Name.Token.Literal] if !exist { - return fmt.Errorf("table with the name of %s doesn't exist", command.Name.Token.Literal) + return &TableDoesNotExistError{command.Name.Token.Literal} } columns := table.Columns @@ -165,7 +165,7 @@ func (engine *DbEngine) updateTable(command *ast.UpdateCommand) error { break } if colIndex == len(columns)-1 { - return fmt.Errorf("column with the name of %s doesn't exist in table %s", updatedCol.Literal, command.Name.GetToken().Literal) + return &ColumnDoesNotExistError{tableName: command.Name.GetToken().Literal, columnName: updatedCol.Literal} } } } @@ -173,7 +173,7 @@ func (engine *DbEngine) updateTable(command *ast.UpdateCommand) error { numberOfRows := len(columns[0].Values) for rowIndex := 0; rowIndex < numberOfRows; rowIndex++ { if command.HasWhereCommand() { - fulfilledFilters, err := isFulfillingFilters(getRow(table, rowIndex), command.WhereCommand.Expression) + fulfilledFilters, err := isFulfillingFilters(getRow(table, rowIndex), command.WhereCommand.Expression, command.WhereCommand.Token.Literal) if err != nil { return err } @@ -197,19 +197,19 @@ func (engine *DbEngine) updateTable(command *ast.UpdateCommand) error { func (engine *DbEngine) insertIntoTable(command *ast.InsertCommand) error { table, exist := engine.Tables[command.Name.Token.Literal] if !exist { - return fmt.Errorf("table with the name of %s doesn't exist", command.Name.Token.Literal) + return &TableDoesNotExistError{command.Name.Token.Literal} } columns := table.Columns if len(command.Values) != len(columns) { - return fmt.Errorf("invalid number of parameters in insert, should be: %d, but got: %d", len(columns), len(columns)) + return &InvalidNumberOfParametersError{expectedNumber: len(columns), actualNumber: len(command.Values), commandName: command.Token.Literal} } for i := range columns { expectedToken := tokenMapper(columns[i].Type.Type) if expectedToken != command.Values[i].Type { - return fmt.Errorf("invalid Token TokenType in Insert Command, expecting: %s, got: %s", expectedToken, command.Values[i].Type) + return &InvalidValueTypeError{expectedType: string(expectedToken), actualType: string(command.Values[i].Type), commandName: command.Token.Literal} } interfaceValue, err := getInterfaceValue(command.Values[i]) if err != nil { @@ -225,7 +225,7 @@ func (engine *DbEngine) selectFromTable(command *ast.SelectCommand) (*Table, err table, exist := engine.Tables[command.Name.Token.Literal] if !exist { - return nil, fmt.Errorf("table with the name of %s doesn't exist", command.Name.Token.Literal) + return nil, &TableDoesNotExistError{command.Name.Token.Literal} } return engine.selectFromProvidedTable(command, table) @@ -239,12 +239,12 @@ func (engine *DbEngine) selectFromProvidedTable(command *ast.SelectCommand, tabl for i := 0; i < len(columns); i++ { wantedColumnNames = append(wantedColumnNames, columns[i].Name) } - return extractColumnContent(columns, &wantedColumnNames) + return extractColumnContent(columns, &wantedColumnNames, command.Name.GetToken().Literal) } else { for i := 0; i < len(command.Space); i++ { wantedColumnNames = append(wantedColumnNames, command.Space[i].Literal) } - return extractColumnContent(columns, unique(wantedColumnNames)) + return extractColumnContent(columns, unique(wantedColumnNames), command.Name.GetToken().Literal) } } @@ -253,10 +253,10 @@ func (engine *DbEngine) deleteFromTable(deleteCommand *ast.DeleteCommand, whereC table, exist := engine.Tables[deleteCommand.Name.Token.Literal] if !exist { - return fmt.Errorf("table with the name of %s doesn't exist", deleteCommand.Name.Token.Literal) + return &TableDoesNotExistError{deleteCommand.Name.Token.Literal} } - newTable, err := engine.getFilteredTable(table, whereCommand, true) + newTable, err := engine.getFilteredTable(table, whereCommand, true, deleteCommand.Name.Token.Literal) if err != nil { return err @@ -276,14 +276,14 @@ func (engine *DbEngine) selectFromTableWithWhere(selectCommand *ast.SelectComman table, exist := engine.Tables[selectCommand.Name.Token.Literal] if !exist { - return nil, fmt.Errorf("table with the name of %s doesn't exist", selectCommand.Name.Token.Literal) + return nil, &TableDoesNotExistError{selectCommand.Name.Token.Literal} } if len(table.Columns) == 0 || len(table.Columns[0].Values) == 0 { return engine.selectFromProvidedTable(selectCommand, &Table{Columns: []*Column{}}) } - filteredTable, err := engine.getFilteredTable(table, whereCommand, false) + filteredTable, err := engine.getFilteredTable(table, whereCommand, false, selectCommand.Name.GetToken().Literal) if err != nil { return nil, err @@ -298,10 +298,10 @@ func (engine *DbEngine) selectFromTableWithWhereAndOrderBy(selectCommand *ast.Se table, exist := engine.Tables[selectCommand.Name.Token.Literal] if !exist { - return nil, fmt.Errorf("table with the name of %s doesn't exist", selectCommand.Name.Token.Literal) + return nil, &TableDoesNotExistError{selectCommand.Name.Token.Literal} } - filteredTable, err := engine.getFilteredTable(table, whereCommand, false) + filteredTable, err := engine.getFilteredTable(table, whereCommand, false, selectCommand.Name.GetToken().Literal) if err != nil { return nil, err @@ -309,7 +309,13 @@ func (engine *DbEngine) selectFromTableWithWhereAndOrderBy(selectCommand *ast.Se emptyTable := getCopyOfTableWithoutRows(table) - return engine.selectFromProvidedTable(selectCommand, engine.getSortedTable(orderByCommand, filteredTable, emptyTable)) + sortedTable, err := engine.getSortedTable(orderByCommand, filteredTable, emptyTable, selectCommand.Name.GetToken().Literal) + + if err != nil { + return nil, err + } + + return engine.selectFromProvidedTable(selectCommand, sortedTable) } // selectFromTableWithOrderBy - Return Table containing all values requested by SelectCommand and sorted by OrderByCommand @@ -317,20 +323,37 @@ func (engine *DbEngine) selectFromTableWithOrderBy(selectCommand *ast.SelectComm table, exist := engine.Tables[selectCommand.Name.Token.Literal] if !exist { - return nil, fmt.Errorf("table with the name of %s doesn't exist", selectCommand.Name.Token.Literal) + return nil, &TableDoesNotExistError{selectCommand.Name.Token.Literal} } emptyTable := getCopyOfTableWithoutRows(table) - sortedTable := engine.getSortedTable(orderByCommand, table, emptyTable) + sortedTable, err := engine.getSortedTable(orderByCommand, table, emptyTable, selectCommand.Name.GetToken().Literal) + + if err != nil { + return nil, err + } return engine.selectFromProvidedTable(selectCommand, sortedTable) } -func (engine *DbEngine) getSortedTable(orderByCommand *ast.OrderByCommand, filteredTable *Table, copyOfTable *Table) *Table { +func (engine *DbEngine) getSortedTable(orderByCommand *ast.OrderByCommand, table *Table, copyOfTable *Table, tableName string) (*Table, error) { sortPatterns := orderByCommand.SortPatterns - rows := MapTableToRows(filteredTable).rows + columnNames := make([]string, 0) + for _, sortPattern := range sortPatterns { + columnNames = append(columnNames, sortPattern.ColumnName.Literal) + } + + missingColName := engine.getMissingColumnName(columnNames, table) + if missingColName != "" { + return nil, &ColumnDoesNotExistError{ + tableName: tableName, + columnName: missingColName, + } + } + + rows := MapTableToRows(table).rows sort.Slice(rows, func(i, j int) bool { howDeepWeSort := 0 @@ -360,14 +383,40 @@ func (engine *DbEngine) getSortedTable(orderByCommand *ast.OrderByCommand, filte newColumn.Values = append(newColumn.Values, value) } } - return copyOfTable + return copyOfTable, nil } -func (engine *DbEngine) getFilteredTable(table *Table, whereCommand *ast.WhereCommand, negation bool) (*Table, error) { +func (engine *DbEngine) getMissingColumnName(columnNames []string, table *Table) string { + for _, columnName := range columnNames { + exists := false + for _, column := range table.Columns { + if column.Name == columnName { + exists = true + break + } + } + if !exists { + return columnName + } + } + return "" +} + +func (engine *DbEngine) getFilteredTable(table *Table, whereCommand *ast.WhereCommand, negation bool, tableName string) (*Table, error) { filteredTable := getCopyOfTableWithoutRows(table) + identifiers := whereCommand.Expression.GetIdentifiers() + columnNames := make([]string, 0) + for _, identifier := range identifiers { + columnNames = append(columnNames, identifier.Token.Literal) + } + missingColumnName := engine.getMissingColumnName(columnNames, table) + if missingColumnName != "" { + return nil, &ColumnDoesNotExistError{tableName: tableName, columnName: missingColumnName} + } + for _, row := range MapTableToRows(table).rows { - fulfilledFilters, err := isFulfillingFilters(row, whereCommand.Expression) + fulfilledFilters, err := isFulfillingFilters(row, whereCommand.Expression, whereCommand.Token.Literal) if err != nil { return nil, err } @@ -428,20 +477,20 @@ func getCopyOfTableWithoutRows(table *Table) *Table { return filteredTable } -func isFulfillingFilters(row map[string]ValueInterface, expressionTree ast.Expression) (bool, error) { +func isFulfillingFilters(row map[string]ValueInterface, expressionTree ast.Expression, commandName string) (bool, error) { switch mappedExpression := expressionTree.(type) { case *ast.OperationExpression: - return processOperationExpression(row, mappedExpression) + return processOperationExpression(row, mappedExpression, commandName) case *ast.BooleanExpression: return processBooleanExpression(mappedExpression) case *ast.ConditionExpression: - return processConditionExpression(row, mappedExpression) + return processConditionExpression(row, mappedExpression, commandName) default: - return false, fmt.Errorf("unsupported expression has been used in WHERE command: %v", expressionTree.GetIdentifiers()) + return false, &UnsupportedExpressionTypeError{commandName: commandName, variable: fmt.Sprintf("%s", mappedExpression)} } } -func processConditionExpression(row map[string]ValueInterface, conditionExpression *ast.ConditionExpression) (bool, error) { +func processConditionExpression(row map[string]ValueInterface, conditionExpression *ast.ConditionExpression, commandName string) (bool, error) { valueLeft, err := getTifierValue(conditionExpression.Left, row) if err != nil { return false, err @@ -458,32 +507,32 @@ func processConditionExpression(row map[string]ValueInterface, conditionExpressi case token.NOT: return !(valueLeft.IsEqual(valueRight)), nil default: - return false, fmt.Errorf("operation '%s' provided in WHERE command isn't allowed", conditionExpression.Condition.Literal) + return false, &UnsupportedConditionalTokenError{variable: conditionExpression.Condition.Literal, commandName: commandName} } } -func processOperationExpression(row map[string]ValueInterface, operationExpression *ast.OperationExpression) (bool, error) { +func processOperationExpression(row map[string]ValueInterface, operationExpression *ast.OperationExpression, commandName string) (bool, error) { if operationExpression.Operation.Type == token.AND { - left, err := isFulfillingFilters(row, operationExpression.Left) + left, err := isFulfillingFilters(row, operationExpression.Left, commandName) if !left { return left, err } - right, err := isFulfillingFilters(row, operationExpression.Right) + right, err := isFulfillingFilters(row, operationExpression.Right, commandName) return left && right, err } if operationExpression.Operation.Type == token.OR { - left, err := isFulfillingFilters(row, operationExpression.Left) + left, err := isFulfillingFilters(row, operationExpression.Left, commandName) if left { return left, err } - right, err := isFulfillingFilters(row, operationExpression.Right) + right, err := isFulfillingFilters(row, operationExpression.Right, commandName) return left || right, err } - return false, fmt.Errorf("unsupported operation token has been used: %s", operationExpression.Operation.Literal) + return false, &UnsupportedOperationTokenError{operationExpression.Operation.Literal} } func processBooleanExpression(booleanExpression *ast.BooleanExpression) (bool, error) { @@ -496,10 +545,14 @@ func processBooleanExpression(booleanExpression *ast.BooleanExpression) (bool, e func getTifierValue(tifier ast.Tifier, row map[string]ValueInterface) (ValueInterface, error) { switch mappedTifier := tifier.(type) { case ast.Identifier: - return row[mappedTifier.GetToken().Literal], nil + value, ok := row[mappedTifier.GetToken().Literal] + if ok == false { + return nil, &ColumnDoesNotExistError{tableName: "", columnName: mappedTifier.GetToken().Literal} + } + return value, nil case ast.Anonymitifier: return getInterfaceValue(mappedTifier.GetToken()) default: - return nil, fmt.Errorf("couldn't map interface to any implementation of it: %s", tifier.GetToken().Literal) + return nil, &UnsupportedValueType{tifier.GetToken().Literal} } } diff --git a/engine/engine_error_handling_test.go b/engine/engine_error_handling_test.go new file mode 100644 index 0000000..eeba2f0 --- /dev/null +++ b/engine/engine_error_handling_test.go @@ -0,0 +1,117 @@ +package engine + +import ( + "github.com/LissaGreense/GO4SQL/lexer" + "github.com/LissaGreense/GO4SQL/parser" + "github.com/LissaGreense/GO4SQL/token" + "testing" +) + +type errorHandlingTestSuite struct { + input string + expectedError string +} + +func TestEngineCreateCommandErrorHandling(t *testing.T) { + duplicateTableNameError := TableAlreadyExistsError{"table1"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE table1( one TEXT , two INT);CREATE TABLE table1(two INT);", duplicateTableNameError.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineInsertCommandErrorHandling(t *testing.T) { + tableDoNotExistError := TableDoesNotExistError{"table1"} + invalidNumberOfParametersError := InvalidNumberOfParametersError{expectedNumber: 2, actualNumber: 1, commandName: token.INSERT} + invalidParametersTypeError := InvalidValueTypeError{expectedType: token.IDENT, actualType: token.LITERAL, commandName: token.INSERT} + tests := []errorHandlingTestSuite{ + {"INSERT INTO table1 VALUES( 'hello', 1);", tableDoNotExistError.Error()}, + {"CREATE TABLE table1( one TEXT , two INT); INSERT INTO table1 VALUES(1);", invalidNumberOfParametersError.Error()}, + {"CREATE TABLE table1( one TEXT , two INT); INSERT INTO table1 VALUES(1, 1 );", invalidParametersTypeError.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineSelectCommandErrorHandling(t *testing.T) { + noTableDoesNotExist := TableDoesNotExistError{"tb1"} + columnDoesNotExist := ColumnDoesNotExistError{tableName: "tbl", columnName: "two"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE tbl(one TEXT); SELECT * FROM tb1;", noTableDoesNotExist.Error()}, + {"CREATE TABLE tbl(one TEXT); SELECT two FROM tbl;", columnDoesNotExist.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineDeleteCommandErrorHandling(t *testing.T) { + noTableDoesNotExist := TableDoesNotExistError{"tb1"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE tbl(one TEXT); DELETE FROM tb1 WHERE one EQUAL 3;", noTableDoesNotExist.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineWhereCommandErrorHandling(t *testing.T) { + columnDoesNotExist := ColumnDoesNotExistError{tableName: "tbl", columnName: "two"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE tbl(one TEXT); INSERT INTO tbl VALUES('hello'); SELECT * FROM tbl WHERE two EQUAL 3;", columnDoesNotExist.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineUpdateCommandErrorHandling(t *testing.T) { + noTableDoesNotExist := TableDoesNotExistError{"tb1"} + columnDoesNotExist := ColumnDoesNotExistError{tableName: "tbl", columnName: "two"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE tbl(one TEXT); UPDATE tb1 SET one TO 2;", noTableDoesNotExist.Error()}, + {"CREATE TABLE tbl(one TEXT);UPDATE tbl SET two TO 2;", columnDoesNotExist.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineOrderByCommandErrorHandling(t *testing.T) { + columnDoesNotExist := ColumnDoesNotExistError{tableName: "tbl", columnName: "two"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE tbl(one TEXT); SELECT * FROM tbl ORDER BY two ASC;", columnDoesNotExist.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func runEngineErrorHandlingSuite(t *testing.T, suite []errorHandlingTestSuite) { + for i, test := range suite { + errorMsg := getErrorMessage(t, test.input, i) + + if errorMsg != test.expectedError { + t.Fatalf("[%v]Was expecting error: \n\t{%s},\n\tbut it was:\n\t{%s}", i, test.expectedError, errorMsg) + } + } +} + +func getErrorMessage(t *testing.T, input string, testIndex int) string { + lexerInstance := lexer.RunLexer(input) + parserInstance := parser.New(lexerInstance) + sequences, parserError := parserInstance.ParseSequence() + if parserError != nil { + t.Fatalf("[%d] Error has occured in parser not in engine, error: %s", testIndex, parserError.Error()) + } + + engine := New() + _, engineError := engine.Evaluate(sequences) + if engineError == nil { + t.Fatalf("[%d] Was expecting error from engine but there was none", testIndex) + } + + return engineError.Error() +} diff --git a/engine/errors.go b/engine/errors.go new file mode 100644 index 0000000..81910de --- /dev/null +++ b/engine/errors.go @@ -0,0 +1,106 @@ +package engine + +import "strconv" + +// TableAlreadyExistsError - error thrown when user tries to create table using name that already +// exists in database +type TableAlreadyExistsError struct { + tableName string +} + +func (m *TableAlreadyExistsError) Error() string { + return "table with the name of " + m.tableName + " already exists" +} + +// TableDoesNotExistError - error thrown when user tries to make operation on un-existing table +type TableDoesNotExistError struct { + tableName string +} + +func (m *TableDoesNotExistError) Error() string { + return "table with the name of " + m.tableName + " doesn't exist" +} + +// ColumnDoesNotExistError - error thrown when user tries to make operation on un-existing column +type ColumnDoesNotExistError struct { + tableName string + columnName string +} + +func (m *ColumnDoesNotExistError) Error() string { + return "column with the name of " + m.columnName + " doesn't exist in table " + m.tableName +} + +// InvalidNumberOfParametersError - error thrown when user provides invalid number of expected parameters +// (ex. fewer values in insert than defined ) +type InvalidNumberOfParametersError struct { + expectedNumber int + actualNumber int + commandName string +} + +func (m *InvalidNumberOfParametersError) Error() string { + return "invalid number of parameters in " + m.commandName + " command, should be: " + strconv.Itoa(m.expectedNumber) + ", but got: " + strconv.Itoa(m.actualNumber) +} + +// InvalidValueTypeError - error thrown when user provides value of different type than expected +type InvalidValueTypeError struct { + expectedType string + actualType string + commandName string +} + +func (m *InvalidValueTypeError) Error() string { + return "invalid value type provided in " + m.commandName + " command, expecting: " + m.expectedType + ", got: " + m.actualType +} + +// UnsupportedValueType - error thrown when engine found unsupported data type to be stored inside +// the columns +type UnsupportedValueType struct { + variable string +} + +func (m *UnsupportedValueType) Error() string { + return "couldn't map interface to any implementation of it: " + m.variable +} + +// UnsupportedOperationTokenError - error thrown when engine found unsupported operation token +// (supported are: AND, OR) +type UnsupportedOperationTokenError struct { + variable string +} + +func (m *UnsupportedOperationTokenError) Error() string { + return "unsupported operation token has been used: " + m.variable +} + +// UnsupportedConditionalTokenError - error thrown when engine found unsupported conditional token +// inside expression (supported are: EQUAL, NOT) +type UnsupportedConditionalTokenError struct { + variable string + commandName string +} + +func (m *UnsupportedConditionalTokenError) Error() string { + return "operation '" + m.variable + "' provided in " + m.commandName + " command isn't allowed" +} + +// UnsupportedExpressionTypeError - error thrown when engine found unsupported expression type +type UnsupportedExpressionTypeError struct { + variable string + commandName string +} + +func (m *UnsupportedExpressionTypeError) Error() string { + return "unsupported expression has been used in " + m.commandName + "command: " + m.variable +} + +// UnsupportedCommandTypeFromParserError - error thrown when engine found unsupported command +// from parser +type UnsupportedCommandTypeFromParserError struct { + variable string +} + +func (m *UnsupportedCommandTypeFromParserError) Error() string { + return "unsupported Command detected: " + m.variable +} diff --git a/engine/table.go b/engine/table.go index 2babc71..dbc58c7 100644 --- a/engine/table.go +++ b/engine/table.go @@ -52,6 +52,10 @@ func (table *Table) ToString() string { } result += "\n" + bar + "\n" + if len(table.Columns) == 0 { + return result + } + rowsCount := len(table.Columns[0].Values) for iRow := 0; iRow < rowsCount; iRow++ { diff --git a/parser/errors.go b/parser/errors.go new file mode 100644 index 0000000..7cc11b3 --- /dev/null +++ b/parser/errors.go @@ -0,0 +1,100 @@ +package parser + +// SyntaxError - error thrown when parser was expecting different token from lexer +type SyntaxError struct { + expecting []string + got string +} + +func (m *SyntaxError) Error() string { + var expectingText string + + if len(m.expecting) == 1 { + expectingText = m.expecting[0] + } else { + for i, expected := range m.expecting { + expectingText += expected + if i != len(m.expecting)-1 { + expectingText += ", " + } + } + } + + return "syntax error, expecting: {" + expectingText + "}, got: {" + m.got + "}" +} + +// SyntaxCommandExpectedError - error thrown when there was command that logically should only +// appear after certain different command, but it wasn't found +type SyntaxCommandExpectedError struct { + command string + neededCommands []string +} + +func (m *SyntaxCommandExpectedError) Error() string { + var neededCommandsText string + + if len(neededCommandsText) == 1 { + neededCommandsText = m.neededCommands[0] + " command" + } else if len(neededCommandsText) == 2 { + neededCommandsText = m.neededCommands[0] + " or " + m.neededCommands[1] + " commands" + } else { + for i, command := range m.neededCommands { + if i == len(m.neededCommands)-1 { + neededCommandsText += " or " + } + + neededCommandsText += command + + if i != len(m.neededCommands)-1 || i != len(m.neededCommands)-2 { + neededCommandsText += ", " + } + } + neededCommandsText += " commands" + } + + return "syntax error, {" + m.command + "} command needs {" + neededCommandsText + "} before" +} + +// SyntaxInvalidCommandError - error thrown when invalid (non-existing) type of command has been +// found +type SyntaxInvalidCommandError struct { + invalidCommand string +} + +func (m *SyntaxInvalidCommandError) Error() string { + return "syntax error, invalid command found: {" + m.invalidCommand + "}" +} + +// LogicalExpressionParsingError - error thrown when logical expression inside WHERE statement +// couldn't be parsed correctly +type LogicalExpressionParsingError struct { + afterToken *string +} + +func (m *LogicalExpressionParsingError) Error() string { + errorMsg := "syntax error, logical expression within WHERE command couldn't be parsed correctly" + if m.afterToken != nil { + return errorMsg + ", after {" + *m.afterToken + "} character" + } + return errorMsg +} + +// ArithmeticLessThanZeroParserError - error thrown when parser found integer value that shouldn't +// be less than 0, but it is +type ArithmeticLessThanZeroParserError struct { + variable string +} + +func (m *ArithmeticLessThanZeroParserError) Error() string { + return "syntax error, {" + m.variable + "} value should be more than 0" +} + +// NoPredecessorParserError - error thrown when parser found integer value that shouldn't +// be less than 0, but it is +type NoPredecessorParserError struct { + command string +} + +func (m *NoPredecessorParserError) Error() string { + return "syntax error, {" + m.command + "} command can't be used without predecessor" +} diff --git a/parser/parser.go b/parser/parser.go index 39262b5..b6491bb 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -1,7 +1,6 @@ package parser import ( - "fmt" "github.com/LissaGreense/GO4SQL/ast" "github.com/LissaGreense/GO4SQL/lexer" "github.com/LissaGreense/GO4SQL/token" @@ -49,13 +48,9 @@ func validateTokenAndSkip(parser *Parser, expectedTokens []token.Type) error { // validateToken - Check if current token type is appearing in provided expectedTokens array func validateToken(tokenType token.Type, expectedTokens []token.Type) error { var contains = false - var tokensPrintMessage = "" - for i, x := range expectedTokens { - if i == 0 { - tokensPrintMessage += string(x) - } else { - tokensPrintMessage += ", or: " + string(x) - } + expectedTokensStrings := make([]string, 0) + for _, x := range expectedTokens { + expectedTokensStrings = append(expectedTokensStrings, string(x)) if x == tokenType { contains = true @@ -63,7 +58,7 @@ func validateToken(tokenType token.Type, expectedTokens []token.Type) error { } } if !contains { - return fmt.Errorf("syntax error, expecting: %s, got: %s", tokensPrintMessage, string(tokenType)) + return &SyntaxError{expectedTokensStrings, string(tokenType)} } return nil } @@ -205,10 +200,12 @@ func (parser *Parser) parseInsertCommand() (ast.Command, error) { if err != nil { return nil, err } + err = validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) if err != nil { return nil, err } + return insertCommand, nil } @@ -223,6 +220,10 @@ func (parser *Parser) parseSelectCommand() (ast.Command, error) { // Ignore token.SELECT parser.nextToken() + err := validateToken(parser.currentToken.Type, []token.Type{token.ASTERISK, token.IDENT}) + if err != nil { + return nil, err + } if parser.currentToken.Type == token.ASTERISK { selectCommand.Space = append(selectCommand.Space, parser.currentToken) parser.nextToken() @@ -245,13 +246,18 @@ func (parser *Parser) parseSelectCommand() (ast.Command, error) { } } - err := validateTokenAndSkip(parser, []token.Type{token.FROM}) + err = validateTokenAndSkip(parser, []token.Type{token.FROM}) + if err != nil { + return nil, err + } + + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) if err != nil { return nil, err } selectCommand.Name = ast.Identifier{Token: parser.currentToken} - // Ignore token.INDENT + // Ignore token.IDENT parser.nextToken() // expect SEMICOLON or WHERE @@ -285,7 +291,7 @@ func (parser *Parser) parseWhereCommand() (ast.Command, error) { } if !expressionIsValid { - return nil, fmt.Errorf("expression withing Where statement couldn't be parsed correctly") + return nil, &LogicalExpressionParsingError{} } err = validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.ORDER}) @@ -354,9 +360,9 @@ func (parser *Parser) parseDropCommand() (ast.Command, error) { // token.IDENT no longer needed parser.nextToken() - parser.skipIfCurrentTokenIsSemicolon() + err = validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) - return dropCommand, nil + return dropCommand, err } // parseOrderByCommand - Return ast.OrderByCommand created from tokens and validate the syntax @@ -437,7 +443,7 @@ func (parser *Parser) parseLimitCommand() (ast.Command, error) { } if count < 0 { - return nil, fmt.Errorf("limit value should be more than 0") + return nil, &ArithmeticLessThanZeroParserError{variable: "limit"} } limitCommand.Count = count @@ -471,7 +477,7 @@ func (parser *Parser) parseOffsetCommand() (ast.Command, error) { return nil, err } if count < 0 { - return nil, fmt.Errorf("limit value should be more than 0") + return nil, &ArithmeticLessThanZeroParserError{variable: "offset"} } offsetCommand.Count = count @@ -502,16 +508,18 @@ func (parser *Parser) parseUpdateCommand() (ast.Command, error) { } updateCommand.Name = ast.Identifier{Token: parser.currentToken} - // Ignore token.INDENT + // Ignore token.IDENT parser.nextToken() - err = validateToken(parser.currentToken.Type, []token.Type{token.SET}) + err = validateTokenAndSkip(parser, []token.Type{token.SET}) if err != nil { return nil, err } - // Ignore token.SET - parser.nextToken() + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } updateCommand.Changes = make(map[token.Token]ast.Anonymitifier) for parser.currentToken.Type == token.IDENT { @@ -551,8 +559,11 @@ func (parser *Parser) parseUpdateCommand() (ast.Command, error) { parser.nextToken() } + err = validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.WHERE}) + if err != nil { + return nil, err + } parser.skipIfCurrentTokenIsSemicolon() - return updateCommand, nil } @@ -612,7 +623,7 @@ func (parser *Parser) getOperationExpression(booleanExpressionExists bool, condi return false, nil, err } if !expressionIsValid { - return false, nil, fmt.Errorf("couldn't parse right side of the OperationExpression after %s token", operationExpression.Operation.Literal) + return false, nil, &LogicalExpressionParsingError{afterToken: &operationExpression.Operation.Literal} } operationExpression.Right = expression @@ -683,7 +694,7 @@ func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} parser.nextToken() default: - return false, nil, fmt.Errorf("syntax error, expecting: { %s, %s, %s }, got: %s", token.APOSTROPHE, token.IDENT, token.LITERAL, parser.currentToken.Literal) + return false, nil, &SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL}, got: parser.currentToken.Literal} } return true, conditionalExpression, nil @@ -713,7 +724,7 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { case token.DROP: command, err = parser.parseDropCommand() case token.WHERE: - lastCommand, parserError := parser.getLastCommand(sequence) + lastCommand, parserError := parser.getLastCommand(sequence, token.WHERE) if parserError != nil { return nil, parserError } @@ -737,16 +748,16 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { } lastCommand.(*ast.UpdateCommand).WhereCommand = newCommand.(*ast.WhereCommand) } else { - return nil, fmt.Errorf("syntax error, WHERE command needs SELECT or DELETE command before") + return nil, &SyntaxCommandExpectedError{command: "WHERE", neededCommands: []string{"SELECT", "DELETE", "UPDATE"}} } case token.ORDER: - lastCommand, parserError := parser.getLastCommand(sequence) + lastCommand, parserError := parser.getLastCommand(sequence, token.ORDER) if parserError != nil { return nil, parserError } if lastCommand.TokenLiteral() != token.SELECT { - return nil, fmt.Errorf("syntax error, ORDER BY command needs SELECT command before") + return nil, &SyntaxCommandExpectedError{command: "ORDER BY", neededCommands: []string{"SELECT"}} } selectCommand := lastCommand.(*ast.SelectCommand) @@ -756,12 +767,12 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { } selectCommand.OrderByCommand = newCommand.(*ast.OrderByCommand) case token.LIMIT: - lastCommand, parserError := parser.getLastCommand(sequence) + lastCommand, parserError := parser.getLastCommand(sequence, token.LIMIT) if parserError != nil { return nil, parserError } if lastCommand.TokenLiteral() != token.SELECT { - return nil, fmt.Errorf("syntax error, LIMIT command needs SELECT command before") + return nil, &SyntaxCommandExpectedError{command: "LIMIT", neededCommands: []string{"SELECT"}} } selectCommand := lastCommand.(*ast.SelectCommand) newCommand, err := parser.parseLimitCommand() @@ -770,12 +781,12 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { } selectCommand.LimitCommand = newCommand.(*ast.LimitCommand) case token.OFFSET: - lastCommand, parserError := parser.getLastCommand(sequence) + lastCommand, parserError := parser.getLastCommand(sequence, token.OFFSET) if parserError != nil { return nil, parserError } if lastCommand.TokenLiteral() != token.SELECT { - return nil, fmt.Errorf("syntax error, OFFSET command needs SELECT command before") + return nil, &SyntaxCommandExpectedError{command: "OFFSET", neededCommands: []string{"SELECT"}} } selectCommand := lastCommand.(*ast.SelectCommand) newCommand, err := parser.parseOffsetCommand() @@ -784,7 +795,7 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { } selectCommand.OffsetCommand = newCommand.(*ast.OffsetCommand) default: - return nil, fmt.Errorf("syntax error, invalid command found: %s", parser.currentToken.Literal) + return nil, &SyntaxInvalidCommandError{invalidCommand: parser.currentToken.Literal} } if err != nil { @@ -800,9 +811,9 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { return sequence, nil } -func (parser *Parser) getLastCommand(sequence *ast.Sequence) (ast.Command, error) { +func (parser *Parser) getLastCommand(sequence *ast.Sequence, currentToken string) (ast.Command, error) { if len(sequence.Commands) == 0 { - return nil, fmt.Errorf("syntax error, Where Command can't be used without predecessor") + return nil, &NoPredecessorParserError{command: currentToken} } lastCommand := sequence.Commands[len(sequence.Commands)-1] return lastCommand, nil diff --git a/parser/parser_error_handling_test.go b/parser/parser_error_handling_test.go new file mode 100644 index 0000000..2768b3d --- /dev/null +++ b/parser/parser_error_handling_test.go @@ -0,0 +1,218 @@ +package parser + +import ( + "github.com/LissaGreense/GO4SQL/lexer" + "github.com/LissaGreense/GO4SQL/token" + "testing" +) + +type errorHandlingTestSuite struct { + input string + expectedError string +} + +func TestParseCreateCommandErrorHandling(t *testing.T) { + noTableKeyword := SyntaxError{[]string{token.TABLE}, token.IDENT} + noTableName := SyntaxError{[]string{token.IDENT}, token.LPAREN} + noLeftParen := SyntaxError{[]string{token.LPAREN}, token.IDENT} + noRightParen := SyntaxError{[]string{token.RPAREN}, token.SEMICOLON} + noColumnName := SyntaxError{[]string{token.RPAREN}, token.TEXT} + noColumnType := SyntaxError{[]string{token.TEXT, token.INT}, token.COMMA} + noSemicolon := SyntaxError{[]string{token.SEMICOLON}, ""} + + tests := []errorHandlingTestSuite{ + {"CREATE tbl(one TEXT);", noTableKeyword.Error()}, + {"CREATE TABLE (one TEXT);", noTableName.Error()}, + {"CREATE TABLE tbl one TEXT);", noLeftParen.Error()}, + {"CREATE TABLE tbl (one TEXT;", noRightParen.Error()}, + {"CREATE TABLE tbl (TEXT, two INT);", noColumnName.Error()}, + {"CREATE TABLE tbl (one , two INT);", noColumnType.Error()}, + {"CREATE TABLE tbl (one TEXT, two INT)", noSemicolon.Error()}, + } + + runParserErrorHandlingSuite(t, tests) + +} + +func TestParseDropCommandErrorHandling(t *testing.T) { + missingTableKeywordError := SyntaxError{expecting: []string{token.TABLE}, got: token.IDENT} + missingDropKeywordError := SyntaxInvalidCommandError{token.TABLE} + missingSemicolonError := &SyntaxError{expecting: []string{token.SEMICOLON}, got: ""} + invalidIdentError := &SyntaxError{expecting: []string{token.IDENT}, got: token.LITERAL} + tests := []errorHandlingTestSuite{ + {input: "DROP table;", expectedError: missingTableKeywordError.Error()}, + {input: "TABLE table;", expectedError: missingDropKeywordError.Error()}, + {input: "DROP TABLE table", expectedError: missingSemicolonError.Error()}, + {input: "DROP TABLE 2;", expectedError: invalidIdentError.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func TestParseInsertCommandErrorHandling(t *testing.T) { + noIntoKeyword := SyntaxError{[]string{token.INTO}, token.IDENT} + noTableName := SyntaxError{[]string{token.IDENT}, token.VALUES} + noLeftParen := SyntaxError{[]string{token.LPAREN}, token.APOSTROPHE} + noValue := SyntaxError{[]string{token.IDENT, token.LITERAL}, token.APOSTROPHE} + noRightParen := SyntaxError{[]string{token.RPAREN}, token.SEMICOLON} + noSemicolon := SyntaxError{[]string{token.SEMICOLON}, ""} + + tests := []errorHandlingTestSuite{ + {"INSERT tbl VALUES( 'hello', 10);", noIntoKeyword.Error()}, + {"INSERT INTO VALUES( 'hello', 10);", noTableName.Error()}, + {"INSERT INTO tl VALUES 'hello', 10);", noLeftParen.Error()}, + {"INSERT INTO tl VALUES ('', 10);", noValue.Error()}, + {"INSERT INTO tl VALUES ('hello', 10;", noRightParen.Error()}, + {"INSERT INTO tl VALUES ('hello', 10)", noSemicolon.Error()}, + } + + runParserErrorHandlingSuite(t, tests) + +} + +func TestParseUpdateCommandErrorHandling(t *testing.T) { + notableName := SyntaxError{expecting: []string{token.IDENT}, got: token.SEMICOLON} + noSetKeyword := SyntaxError{expecting: []string{token.SET}, got: token.SEMICOLON} + noColumnName := SyntaxError{expecting: []string{token.IDENT}, got: token.LITERAL} + noToKeyword := SyntaxError{expecting: []string{token.TO}, got: token.SEMICOLON} + noSecondIdentOrLiteralForValue := SyntaxError{expecting: []string{token.IDENT, token.LITERAL}, got: token.SEMICOLON} + noCommaBetweenValues := SyntaxError{expecting: []string{token.SEMICOLON, token.WHERE}, got: token.IDENT} + noWhereOrSemicolon := SyntaxError{expecting: []string{token.SEMICOLON, token.WHERE}, got: token.SELECT} + + tests := []errorHandlingTestSuite{ + {"UPDATE;", notableName.Error()}, + {"UPDATE table;", noSetKeyword.Error()}, + {"UPDATE table SET 2;", noColumnName.Error()}, + {"UPDATE table SET column_name_1;", noToKeyword.Error()}, + {"UPDATE table SET column_name_1 TO;", noSecondIdentOrLiteralForValue.Error()}, + {"UPDATE table SET column_name_1 TO 2 column_name_1 TO 3;", noCommaBetweenValues.Error()}, + {"UPDATE table SET column_name_1 TO 'new_value_1' SELECT;", noWhereOrSemicolon.Error()}, + } + + runParserErrorHandlingSuite(t, tests) + +} + +func TestParseSelectCommandErrorHandling(t *testing.T) { + noFromKeyword := SyntaxError{[]string{token.FROM}, token.IDENT} + noColumns := SyntaxError{[]string{token.ASTERISK, token.IDENT}, token.FROM} + noTableName := SyntaxError{[]string{token.IDENT}, token.SEMICOLON} + noSemicolon := SyntaxError{[]string{token.SEMICOLON, token.WHERE, token.ORDER, token.LIMIT, token.OFFSET}, ""} + + tests := []errorHandlingTestSuite{ + {"SELECT column1, column2 tbl;", noFromKeyword.Error()}, + {"SELECT FROM table;", noColumns.Error()}, + {"SELECT column1, column2 FROM ;", noTableName.Error()}, + {"SELECT column1, column2 FROM table", noSemicolon.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func TestParseWhereCommandErrorHandling(t *testing.T) { + selectCommandPrefix := "SELECT * FROM tbl " + noPredecessorError := NoPredecessorParserError{command: token.WHERE} + noColName := LogicalExpressionParsingError{} + notOrEqualIsMissing := SyntaxError{expecting: []string{token.EQUAL, token.NOT}, got: token.APOSTROPHE} + valueIsMissing := SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL}, got: token.SEMICOLON} + tokenAnd := token.AND + conjunctionIsMissing := SyntaxError{expecting: []string{token.SEMICOLON, token.ORDER}, got: token.IDENT} + nextLogicalExpressionIsMissing := LogicalExpressionParsingError{afterToken: &tokenAnd} + noSemicolon := SyntaxError{expecting: []string{token.SEMICOLON, token.ORDER}, got: ""} + + tests := []errorHandlingTestSuite{ + {"WHERE col1 NOT 'goodbye' OR col2 EQUAL 3;", noPredecessorError.Error()}, + {selectCommandPrefix + "WHERE NOT 'goodbye' OR column2 EQUAL 3;", noColName.Error()}, + {selectCommandPrefix + "WHERE one 'goodbye';", notOrEqualIsMissing.Error()}, + {selectCommandPrefix + "WHERE one EQUAL;", valueIsMissing.Error()}, + {selectCommandPrefix + "WHERE one EQUAL 5 two NOT 1;", conjunctionIsMissing.Error()}, + {selectCommandPrefix + "WHERE one EQUAL 5 AND;", nextLogicalExpressionIsMissing.Error()}, + {selectCommandPrefix + "WHERE one EQUAL 5 AND two NOT 5", noSemicolon.Error()}, + } + + runParserErrorHandlingSuite(t, tests) + +} + +func TestParseOrderByCommandErrorHandling(t *testing.T) { + selectCommandPrefix := "SELECT * FROM tbl " + noPredecessorError := NoPredecessorParserError{command: token.ORDER} + noAscDescError := SyntaxError{expecting: []string{token.ASC, token.DESC}, got: token.SEMICOLON} + noByKeywordError := SyntaxError{expecting: []string{token.BY}, got: token.IDENT} + noIdentKeywordError := SyntaxError{expecting: []string{token.IDENT}, got: token.ASC} + + tests := []errorHandlingTestSuite{ + {"ORDER BY column1;", noPredecessorError.Error()}, + {selectCommandPrefix + "ORDER BY column1;", noAscDescError.Error()}, + {selectCommandPrefix + "ORDER column1 ASC;", noByKeywordError.Error()}, + {selectCommandPrefix + "ORDER BY ASC;", noIdentKeywordError.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func TestParseLimitCommandErrorHandling(t *testing.T) { + selectCommandPrefix := "SELECT * FROM tbl " + noPredecessorError := NoPredecessorParserError{command: token.LIMIT} + noLiteralError := SyntaxError{expecting: []string{token.LITERAL}, got: token.SEMICOLON} + lessThanZeroError := ArithmeticLessThanZeroParserError{variable: "limit"} + + tests := []errorHandlingTestSuite{ + {"LIMIT 5;", noPredecessorError.Error()}, + {selectCommandPrefix + "LIMIT;", noLiteralError.Error()}, + {selectCommandPrefix + "LIMIT -10;", lessThanZeroError.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func TestParseOffsetCommandErrorHandling(t *testing.T) { + selectCommandPrefix := "SELECT * FROM tbl " + noPredecessorError := NoPredecessorParserError{command: token.OFFSET} + noLiteralError := SyntaxError{expecting: []string{token.LITERAL}, got: token.IDENT} + lessThanZeroError := ArithmeticLessThanZeroParserError{variable: "offset"} + + tests := []errorHandlingTestSuite{ + {"OFFSET 5;", noPredecessorError.Error()}, + {selectCommandPrefix + "OFFSET hi;", noLiteralError.Error()}, + {selectCommandPrefix + "OFFSET -10;", lessThanZeroError.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func TestParseDeleteCommandErrorHandling(t *testing.T) { + noFromKeyword := SyntaxError{[]string{token.FROM}, token.IDENT} + noTableName := SyntaxError{[]string{token.IDENT}, token.WHERE} + noWhereCommand := SyntaxError{[]string{token.WHERE}, ";"} + + tests := []errorHandlingTestSuite{ + {"DELETE table WHERE TRUE", noFromKeyword.Error()}, + {"DELETE FROM WHERE TRUE;", noTableName.Error()}, + {"DELETE FROM table;", noWhereCommand.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func runParserErrorHandlingSuite(t *testing.T, suite []errorHandlingTestSuite) { + for i, test := range suite { + errorMsg := getErrorMessage(t, test.input, i) + + if errorMsg != test.expectedError { + t.Fatalf("[%v]Was expecting error: \n\t{%s},\n\tbut it was:\n\t{%s}", i, test.expectedError, errorMsg) + } + } +} + +func getErrorMessage(t *testing.T, input string, testIndex int) string { + lexerInstance := lexer.RunLexer(input) + parserInstance := New(lexerInstance) + _, err := parserInstance.ParseSequence() + + if err == nil { + t.Fatalf("[%v]Was expecting error from parser but there was none", testIndex) + } + + return err.Error() +} diff --git a/parser/parser_test.go b/parser/parser_test.go index b86709b..366f969 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -20,16 +20,16 @@ func TestParserCreateCommand(t *testing.T) { {"CREATE TABLE TBL( );", "TBL", []string{}, []token.Token{}}, } - for _, tt := range tests { + for testIndex, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) sequences, err := parserInstance.ParseSequence() if err != nil { - t.Fatalf("Got error from parser: %s", err) + t.Fatalf("[%d] Got error from parser: %s", testIndex, err) } if len(sequences.Commands) != 1 { - t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) } if !testCreateStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedColumnNames, tt.expectedColumTypes) { @@ -79,16 +79,16 @@ func TestParseInsertCommand(t *testing.T) { {"INSERT INTO TBL VALUES( 'HELLO', 10 , 'LOL');", "TBL", []token.Token{{Type: token.IDENT, Literal: "HELLO"}, {Type: token.LITERAL, Literal: "10"}, {Type: token.IDENT, Literal: "LOL"}}}, } - for _, tt := range tests { + for testIndex, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) sequences, err := parserInstance.ParseSequence() if err != nil { - t.Fatalf("Got error from parser: %s", err) + t.Fatalf("[%d] Got error from parser: %s", testIndex, err) } if len(sequences.Commands) != 1 { - t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) } if !testInsertStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedValuesTokens) { @@ -130,19 +130,18 @@ func TestParseSelectCommand(t *testing.T) { }{ {"SELECT * FROM TBL;", "TBL", []token.Token{{Type: token.ASTERISK, Literal: "*"}}}, {"SELECT ONE, TWO, THREE FROM TBL;", "TBL", []token.Token{{Type: token.IDENT, Literal: "ONE"}, {Type: token.IDENT, Literal: "TWO"}, {Type: token.IDENT, Literal: "THREE"}}}, - {"SELECT FROM TBL;", "TBL", []token.Token{}}, } - for _, tt := range tests { + for testIndex, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) sequences, err := parserInstance.ParseSequence() if err != nil { - t.Fatalf("Got error from parser: %s", err) + t.Fatalf("[%d] Got error from parser: %s", testIndex, err) } if len(sequences.Commands) != 1 { - t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) } if !testSelectStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedColumns) { @@ -178,21 +177,21 @@ func TestParseWhereCommand(t *testing.T) { }, } - for _, tt := range tests { + for testIndex, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) sequences, err := parserInstance.ParseSequence() if err != nil { - t.Fatalf("Got error from parser: %s", err) + t.Fatalf("[%d] Got error from parser: %s", testIndex, err) } if len(sequences.Commands) != 1 { - t.Fatalf("sequences does not contain 1 statements, got=%d", len(sequences.Commands)) + t.Fatalf("[%d] sequences does not contain 1 statements, got=%d", testIndex, len(sequences.Commands)) } selectCommand := sequences.Commands[0].(*ast.SelectCommand) if !selectCommand.HasWhereCommand() { - t.Fatalf("sequences does not contain where command") + t.Fatalf("[%d] sequences does not contain where command", testIndex) } if !whereStatementIsValid(t, selectCommand.WhereCommand, tt.expectedExpression) { @@ -439,16 +438,16 @@ func TestParseUpdateCommand(t *testing.T) { }, } - for _, tt := range tests { + for testIndex, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) sequences, err := parserInstance.ParseSequence() if err != nil { - t.Fatalf("Got error from parser: %s", err) + t.Fatalf("[%d] Got error from parser: %s", testIndex, err) } if len(sequences.Commands) != 1 { - t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) } if !testUpdateStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedChanges) { @@ -478,7 +477,7 @@ func TestParseUpdateCommandWithWhere(t *testing.T) { }, } - for _, tt := range tests { + for testIndex, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) sequences, err := parserInstance.ParseSequence() @@ -487,13 +486,13 @@ func TestParseUpdateCommandWithWhere(t *testing.T) { } if len(sequences.Commands) != 1 { - t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) } actualUpdateCommand, ok := sequences.Commands[0].(*ast.UpdateCommand) if !ok { - t.Errorf("actualUpdateCommand is not %T. got=%T", &ast.UpdateCommand{}, sequences.Commands[0]) + t.Errorf("[%d] actualUpdateCommand is not %T. got=%T", testIndex, &ast.UpdateCommand{}, sequences.Commands[0]) } if !testUpdateStatement(t, actualUpdateCommand, tt.expectedTableName, tt.expectedChanges) { @@ -501,7 +500,7 @@ func TestParseUpdateCommandWithWhere(t *testing.T) { } if !actualUpdateCommand.HasWhereCommand() { - t.Errorf("actualUpdateCommand should have where command") + t.Errorf("[%d] actualUpdateCommand should have where command", testIndex) } if !whereStatementIsValid(t, actualUpdateCommand.WhereCommand, tt.expectedWhereCommand) { @@ -558,7 +557,7 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { }, } - for _, tt := range tests { + for testIndex, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) sequences, err := parserInstance.ParseSequence() @@ -567,17 +566,17 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { } if len(sequences.Commands) != 1 { - t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) } selectCommand := sequences.Commands[0].(*ast.SelectCommand) if !selectCommand.HasWhereCommand() { - t.Fatalf("sequences does not contain where command") + t.Fatalf("[%d] sequences does not contain where command", testIndex) } if !whereStatementIsValid(t, selectCommand.WhereCommand, tt.expectedExpression) { - t.Fatalf("Actual expression and expected one are different") + t.Fatalf("[%d] Actual expression and expected one are different", testIndex) } } } From fcb01188e6eaa7dcfec5825246e1d0c552d0e86e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sara=20Ryfczy=C5=84ska?= Date: Tue, 6 Aug 2024 22:21:19 +0200 Subject: [PATCH 12/21] Add Distinct select implementation (#21) --- .github/expected_results/end2end.txt | 7 +++ README.md | 8 +++ ast/ast.go | 1 + engine/engine.go | 4 ++ engine/engine_test.go | 24 ++++++++ engine/table.go | 39 +++++++++++- lexer/lexer_test.go | 18 ++++++ parser/parser.go | 8 +++ parser/parser_test.go | 25 +++++--- test_file | 2 + token/token.go | 92 ++++++++++++++-------------- 11 files changed, 172 insertions(+), 56 deletions(-) diff --git a/.github/expected_results/end2end.txt b/.github/expected_results/end2end.txt index 5e8f803..009a9fe 100644 --- a/.github/expected_results/end2end.txt +++ b/.github/expected_results/end2end.txt @@ -58,4 +58,11 @@ Table: 'tbl' has been updated | 'hello' | 1 | 11 | 'q' | | 'goodbye' | 5 | 22 | 'P' | +-----------+-----+-------+------+ +Data Inserted ++-----------+-----+-------+------+ +| one | two | three | four | ++-----------+-----+-------+------+ +| 'hello' | 1 | 11 | 'q' | +| 'goodbye' | 5 | 22 | 'P' | ++-----------+-----+-------+------+ Table: 'tbl' has been dropped diff --git a/README.md b/README.md index 85adbb4..3e59d60 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,14 @@ Currently, there are 3 modes to chose from: In this case, this command will order by ``column1`` in ascending order and skip 3 first records, then return records from 4th to 8th. +* ***DISTINCT*** is used to return only distinct (different) values in returned output with + ``SELECT`` like this: + ```sql + SELECT DISTINCT column1, column2, + FROM table_name; + ``` + In this case, this command will return only unique rows from ``table_name`` table. + ## UNIT TESTS To run all the tests locally run this in root directory: diff --git a/ast/ast.go b/ast/ast.go index 8d99aa7..ca9b332 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -161,6 +161,7 @@ type SelectCommand struct { Token token.Token Name Identifier // ex. name of table Space []token.Token // ex. column names + HasDistinct bool // DISTINCT keyword has been used WhereCommand *WhereCommand // optional OrderByCommand *OrderByCommand // optional LimitCommand *LimitCommand // optional diff --git a/engine/engine.go b/engine/engine.go index 1c606ab..4cba86e 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -124,6 +124,10 @@ func (engine *DbEngine) getSelectResponse(selectCommand *ast.SelectCommand) (*Ta table.applyOffsetAndLimit(selectCommand) } + if selectCommand.HasDistinct { + table = table.getDistinctTable() + } + return table, nil } diff --git a/engine/engine_test.go b/engine/engine_test.go index 907fecd..18d7be2 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -232,6 +232,30 @@ func TestSelectWithWhereEqualToFalse(t *testing.T) { engineTestSuite.runTestSuite(t) } +func TestDistinctSelect(t *testing.T) { + + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + }, + selectInput: "SELECT DISTINCT * FROM tb1;", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + {"hello", "1", "11", "q"}, + {"goodbye", "2", "22", "w"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + func TestDelete(t *testing.T) { engineTestSuite := engineTableContentTestSuite{ diff --git a/engine/table.go b/engine/table.go index dbc58c7..262d2a8 100644 --- a/engine/table.go +++ b/engine/table.go @@ -1,6 +1,9 @@ package engine -import "github.com/LissaGreense/GO4SQL/token" +import ( + "github.com/LissaGreense/GO4SQL/token" + "hash/adler32" +) // Table - Contain Columns that store values in engine type Table struct { @@ -35,6 +38,38 @@ func (table *Table) isEqual(secondTable *Table) bool { return true } +// getDistinctTable - Takes input table, and returns new one without any duplicates +func (table *Table) getDistinctTable() *Table { + distinctTable := getCopyOfTableWithoutRows(table) + + rowsCount := len(table.Columns[0].Values) + + checksumSet := map[uint32]struct{}{} + + for iRow := 0; iRow < rowsCount; iRow++ { + + mergedColumnValues := "" + for iColumn := range table.Columns { + fieldValue := table.Columns[iColumn].Values[iRow].ToString() + if table.Columns[iColumn].Type.Literal == token.TEXT { + fieldValue = "'" + fieldValue + "'" + } + mergedColumnValues += fieldValue + } + checksum := adler32.Checksum([]byte(mergedColumnValues)) + + _, exist := checksumSet[checksum] + if !exist { + checksumSet[checksum] = struct{}{} + for i, column := range distinctTable.Columns { + column.Values = append(column.Values, table.Columns[i].Values[iRow]) + } + } + } + + return distinctTable +} + // ToString - Return string contain all values and Column names in Table func (table *Table) ToString() string { columWidths := getColumWidths(table.Columns) @@ -103,7 +138,7 @@ func getColumWidths(columns []*Column) []int { for iRow := range columns[iColumn].Values { valueLength := len(columns[iColumn].Values[iRow].ToString()) if columns[iColumn].Type.Literal == token.TEXT { - valueLength += 2 // double "'" + valueLength += 2 // double ' } if valueLength > maxLength { maxLength = valueLength diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index b2b154f..5c4c0b1 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -302,6 +302,24 @@ func TestLimitAndOffsetStatement(t *testing.T) { runLexerTestSuite(t, input, tests) } +func TestSelectWithDistinct(t *testing.T) { + input := `SELECT DISTINCT * FROM table;` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.DISTINCT, "DISTINCT"}, + {token.ASTERISK, "*"}, + {token.FROM, "FROM"}, + {token.IDENT, "table"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + func runLexerTestSuite(t *testing.T, input string, tests []struct { expectedType token.Type expectedLiteral string diff --git a/parser/parser.go b/parser/parser.go index b6491bb..15b46d3 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -220,6 +220,14 @@ func (parser *Parser) parseSelectCommand() (ast.Command, error) { // Ignore token.SELECT parser.nextToken() + // optional DISTINCT + if parser.currentToken.Type == token.DISTINCT { + selectCommand.HasDistinct = true + + // Ignore token.DISTINCT + parser.nextToken() + } + err := validateToken(parser.currentToken.Type, []token.Type{token.ASTERISK, token.IDENT}) if err != nil { return nil, err diff --git a/parser/parser_test.go b/parser/parser_test.go index 366f969..92fa217 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -127,9 +127,11 @@ func TestParseSelectCommand(t *testing.T) { input string expectedTableName string expectedColumns []token.Token + expectedDistinct bool }{ - {"SELECT * FROM TBL;", "TBL", []token.Token{{Type: token.ASTERISK, Literal: "*"}}}, - {"SELECT ONE, TWO, THREE FROM TBL;", "TBL", []token.Token{{Type: token.IDENT, Literal: "ONE"}, {Type: token.IDENT, Literal: "TWO"}, {Type: token.IDENT, Literal: "THREE"}}}, + {"SELECT * FROM TBL;", "TBL", []token.Token{{Type: token.ASTERISK, Literal: "*"}}, false}, + {"SELECT ONE, TWO, THREE FROM TBL;", "TBL", []token.Token{{Type: token.IDENT, Literal: "ONE"}, {Type: token.IDENT, Literal: "TWO"}, {Type: token.IDENT, Literal: "THREE"}}, false}, + {"SELECT DISTINCT * FROM TBL;", "TBL", []token.Token{{Type: token.ASTERISK, Literal: "*"}}, true}, } for testIndex, tt := range tests { @@ -144,7 +146,7 @@ func TestParseSelectCommand(t *testing.T) { t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) } - if !testSelectStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedColumns) { + if !testSelectStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedColumns, tt.expectedDistinct) { return } } @@ -303,7 +305,7 @@ func TestSelectWithOrderByCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { return } @@ -336,7 +338,7 @@ func TestSelectWithLimitCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { return } @@ -370,7 +372,7 @@ func TestSelectWithOffsetCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { return } @@ -406,7 +408,7 @@ func TestSelectWithLimitAndOffsetCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { return } @@ -581,7 +583,7 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { } } -func testSelectStatement(t *testing.T, command ast.Command, expectedTableName string, expectedColumnsTokens []token.Token) bool { +func testSelectStatement(t *testing.T, command ast.Command, expectedTableName string, expectedColumnsTokens []token.Token, expectedDistinct bool) bool { if command.TokenLiteral() != "SELECT" { t.Errorf("command.TokenLiteral() not 'SELECT'. got=%q", command.TokenLiteral()) return false @@ -598,8 +600,13 @@ func testSelectStatement(t *testing.T, command ast.Command, expectedTableName st return false } + if actualSelectCommand.HasDistinct != expectedDistinct { + t.Errorf("HasDistinct should be set to %t, got=%t", expectedDistinct, actualSelectCommand.HasDistinct) + return false + } + if !tokenArrayEquals(actualSelectCommand.Space, expectedColumnsTokens) { - t.Errorf("") + t.Errorf("actualSelectCommand has diffrent space tan expected. %v != %v", actualSelectCommand.Space, expectedColumnsTokens) return false } diff --git a/test_file b/test_file index 493a292..ad768bd 100644 --- a/test_file +++ b/test_file @@ -14,4 +14,6 @@ SELECT one FROM tbl WHERE TRUE ORDER BY two ASC, four DESC; UPDATE tbl SET two TO 5, four TO 'P' WHERE one EQUAL 'goodbye'; SELECT * FROM tbl; + INSERT INTO tbl VALUES( 'goodbye', 5, 22, 'P' ); + SELECT DISTINCT * FROM tbl; DROP TABLE tbl; diff --git a/token/token.go b/token/token.go index 327a21c..913bc74 100644 --- a/token/token.go +++ b/token/token.go @@ -29,24 +29,25 @@ const ( RPAREN = ")" // CREATE - Keywords - CREATE = "CREATE" - DROP = "DROP" - TABLE = "TABLE" - INSERT = "INSERT" - INTO = "INTO" - VALUES = "VALUES" - SELECT = "SELECT" - FROM = "FROM" - WHERE = "WHERE" - DELETE = "DELETE" - ORDER = "ORDER" - BY = "BY" - ASC = "ASC" - DESC = "DESC" - LIMIT = "LIMIT" - OFFSET = "OFFSET" - UPDATE = "UPDATE" - SET = "SET" + CREATE = "CREATE" + DROP = "DROP" + TABLE = "TABLE" + INSERT = "INSERT" + INTO = "INTO" + VALUES = "VALUES" + SELECT = "SELECT" + FROM = "FROM" + WHERE = "WHERE" + DELETE = "DELETE" + ORDER = "ORDER" + BY = "BY" + ASC = "ASC" + DESC = "DESC" + LIMIT = "LIMIT" + OFFSET = "OFFSET" + UPDATE = "UPDATE" + SET = "SET" + DISTINCT = "DISTINCT" TO = "TO" @@ -67,33 +68,34 @@ const ( ) var keywords = map[string]Type{ - "TEXT": TEXT, - "INT": INT, - "CREATE": CREATE, - "DROP": DROP, - "TABLE": TABLE, - "INSERT": INSERT, - "INTO": INTO, - "SELECT": SELECT, - "FROM": FROM, - "DELETE": DELETE, - "ORDER": ORDER, - "BY": BY, - "ASC": ASC, - "DESC": DESC, - "LIMIT": LIMIT, - "OFFSET": OFFSET, - "UPDATE": UPDATE, - "SET": SET, - "TO": TO, - "VALUES": VALUES, - "WHERE": WHERE, - "EQUAL": EQUAL, - "NOT": NOT, - "AND": AND, - "OR": OR, - "TRUE": TRUE, - "FALSE": FALSE, + "TEXT": TEXT, + "INT": INT, + "CREATE": CREATE, + "DROP": DROP, + "TABLE": TABLE, + "INSERT": INSERT, + "INTO": INTO, + "SELECT": SELECT, + "FROM": FROM, + "DELETE": DELETE, + "ORDER": ORDER, + "BY": BY, + "ASC": ASC, + "DESC": DESC, + "LIMIT": LIMIT, + "OFFSET": OFFSET, + "UPDATE": UPDATE, + "SET": SET, + "DISTINCT": DISTINCT, + "TO": TO, + "VALUES": VALUES, + "WHERE": WHERE, + "EQUAL": EQUAL, + "NOT": NOT, + "AND": AND, + "OR": OR, + "TRUE": TRUE, + "FALSE": FALSE, } // LookupIdent - Return keyword type from defined list if exists, otherwise it returns IDENT type From efa0622e88ce145f42e97bca14907d54413a9cb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Krupski?= <34219324+ixior462@users.noreply.github.com> Date: Tue, 8 Oct 2024 23:23:55 +0200 Subject: [PATCH 13/21] Feature/joins (#22) JOIN FEATURE * Add full join implementation * Add full join error handling tests in engine * Add documentation, lexer, parser and only tests for engine, handling full,inner,left,right joins * Add enginge implemenatation and e2e tests for full,inner,left and right joins * Refactore join method --- .github/expected_results/end2end.txt | 30 ++++ .github/workflows/end2end-tests.yml | 2 +- .github/workflows/unit-tests.yml | 2 +- README.md | 91 ++++++++++-- ast/ast.go | 40 ++++- engine/engine.go | 142 +++++++++++++----- engine/engine_error_handling_test.go | 14 ++ engine/engine_test.go | 142 ++++++++++++++++++ engine/generic_value.go | 31 ++++ engine/generic_value_test.go | 16 +- engine/row.go | 8 + engine/table.go | 18 ++- go.mod | 2 +- lexer/lexer_test.go | 134 +++++++++++++++++ parser/errors.go | 9 ++ parser/parser.go | 87 ++++++++++- parser/parser_error_handling_test.go | 14 +- parser/parser_test.go | 211 +++++++++++++++++++++++++++ test_file | 10 ++ token/token.go | 12 ++ 20 files changed, 950 insertions(+), 65 deletions(-) diff --git a/.github/expected_results/end2end.txt b/.github/expected_results/end2end.txt index 009a9fe..2a438cb 100644 --- a/.github/expected_results/end2end.txt +++ b/.github/expected_results/end2end.txt @@ -66,3 +66,33 @@ Data Inserted | 'goodbye' | 5 | 22 | 'P' | +-----------+-----+-------+------+ Table: 'tbl' has been dropped +Table 'table1' has been created +Table 'table2' has been created +Data Inserted +Data Inserted +Data Inserted +Data Inserted ++--------------+--------------+ +| table1.value | table2.value | ++--------------+--------------+ +| 'Value1' | NULL | +| 'Value2' | 'Value2' | +| NULL | 'Value3' | ++--------------+--------------+ ++--------------+--------------+ +| table1.value | table2.value | ++--------------+--------------+ +| 'Value2' | 'Value2' | ++--------------+--------------+ ++--------------+--------------+ +| table1.value | table2.value | ++--------------+--------------+ +| 'Value1' | NULL | +| 'Value2' | 'Value2' | ++--------------+--------------+ ++--------------+--------------+ +| table1.value | table2.value | ++--------------+--------------+ +| 'Value2' | 'Value2' | +| NULL | 'Value3' | ++--------------+--------------+ diff --git a/.github/workflows/end2end-tests.yml b/.github/workflows/end2end-tests.yml index 43e111b..66d7bf6 100644 --- a/.github/workflows/end2end-tests.yml +++ b/.github/workflows/end2end-tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.16.15', '1.17.11', '1.18.10', '1.19.13', '1.20.14', '1.21.9', '1.22.3' ] + go: [ '1.21.13', '1.22.7', '1.23.1' ] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 3648159..835a56b 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.16.15', '1.17.11' ] + go: [ '1.21.13', '1.22.7', '1.23.1' ] steps: - uses: actions/checkout@v3 diff --git a/README.md b/README.md index 3e59d60..f78dadb 100644 --- a/README.md +++ b/README.md @@ -31,9 +31,19 @@ Currently, there are 3 modes to chose from: 3. `Socket Mode` - To start Socket Server use `./GO4SQL -socket`, it will be listening on port `1433` by default. To choose port different other than that, for example equal to `1444`, go with: `./GO4SQL -socket -port 1444` +## UNIT TESTS + +To run all the tests locally paste this in root directory: + +```shell +go clean -testcache; go test ./... +``` + ### Docker + 1. Pull docker image: `docker pull kajedot/go4sql:latest` -2. Run docker container in the interactive mode, remember to provide flag, for example: `docker run -i kajedot/go4sql -stream` +2. Run docker container in the interactive mode, remember to provide flag, for example: + `docker run -i kajedot/go4sql -stream` 3. You can test this image with `test_file` provided in this repo: `docker run -i kajedot/go4sql -stream < test_file` ## FUNCTIONALITY @@ -51,7 +61,7 @@ Currently, there are 3 modes to chose from: ```sql DROP TABLE table1; ``` - After using this command table1 will no longer be available and all data connected to it (column + After using this command table1 will no longer be available and all data connected to it (column definitions and inserted values) will be lost. @@ -129,10 +139,10 @@ Currently, there are 3 modes to chose from: ORDER BY column1 ASC LIMIT 5 OFFSET 3; ``` - In this case, this command will order by ``column1`` in ascending order and skip 3 first records, + In this case, this command will order by ``column1`` in ascending order and skip 3 first records, then return records from 4th to 8th. -* ***DISTINCT*** is used to return only distinct (different) values in returned output with +* ***DISTINCT*** is used to return only distinct (different) values in returned output with ``SELECT`` like this: ```sql SELECT DISTINCT column1, column2, @@ -140,13 +150,64 @@ Currently, there are 3 modes to chose from: ``` In this case, this command will return only unique rows from ``table_name`` table. -## UNIT TESTS - -To run all the tests locally run this in root directory: - -```shell -go clean -testcache; go test ./... -``` +* ***INNER JOIN*** is used to return a new table by combining rows from both tables where there is a match on the + specified condition. Only the rows that satisfy the condition from both tables are included in the result. + Rows from either table that do not meet the condition are excluded from the result. + ```sql + SELECT * + FROM tableOne + JOIN tableTwo + ON tableOne.columnY EQUAL tableTwo.columnX; + ``` + or + ```sql + SELECT * + FROM tableOne + INNER JOIN tableTwo + ON tableOne.columnY EQUAL tableTwo.columnX; + ``` + In this case, this command will return all columns from tableOne and tableTwo for rows where the condition + ``tableOne.columnY`` = ``tableTwo.columnX`` is met (i.e., the value of ``columnY`` in ``tableOne`` is equal to the + value of ``columnX`` in ``tableTwo``). +* ***LEFT JOIN*** is used to return a new table that includes all records from the left table and the matched records + from the right table. If there is no match, the result will contain empty values for columns from the right table. + ```sql + SELECT * + FROM tableOne + LEFT JOIN tableTwo + ON tableOne.columnY EQUAL tableTwo.columnX; + ``` + In this case, this command will return all columns from ``tableOne`` and the matching columns from ``tableTwo``. For + rows in + ``tableOne`` that do not have a corresponding match in ``tableTwo``, the result will include empty values for columns + from + ``tableTwo``. +* ***RIGHT JOIN*** is used to return a new table that includes all records from the right table and the matched records + from the left table. If there is no match, the result will contain empty values for columns from the left table. + ```sql + SELECT * + FROM tableOne + RIGHT JOIN tableTwo + ON tableOne.columnY EQUAL tableTwo.columnX; + ``` + In this case, this command will return all columns from ``tableTwo`` and the matching columns from ``tableOne``. For + rows in + ``tableTwo`` that do not have a corresponding match in ``tableOne``, the result will include empty values for columns + from + ``tableOne``. + +* ***FULL JOIN*** is used to return a new table created by joining two tables as a whole. The joined table contains all + records from both tables and fills empty values for missing matches on either side. This join combines the results of + both ``LEFT JOIN`` and ``RIGHT JOIN``. + ```sql + SELECT * + FROM tableOne + FULL JOIN tableTwo + ON tableOne.columnY EQUAL tableTwo.columnX; + ``` + In this case, this command will return all columns from ``tableOne`` and ``tableTwo`` for rows fulfilling condition + ``tableOne.columnY EQUAL tableTwo.columnX`` (value of ``columnY`` in ``tableOne`` is equal the value of ``columnX`` in + ``tableTwo``). ## E2E TEST @@ -158,11 +219,13 @@ This is integrated into github workflows. To build your docker image run this command in root directory: -```shell +``` +shell docker build -t go4sql:test . ``` ### Run docker in interactive stream mode + To run this docker image in interactive stream mode use this command: ```shell @@ -170,6 +233,7 @@ docker run -i go4sql:test -stream ``` ### Run docker in socket mode + To run this docker image in socket mode use this command: ```shell @@ -177,6 +241,7 @@ docker run go4sql:test -socket ``` ### Run docker in file mode + **NOT RECOMMENDED** Alternatively you can run a docker image in file mode: @@ -190,12 +255,14 @@ docker run -i go4sql:test -file To create a pod deployment using helm chart, there is configuration under `./helm` directory. Commands: + ```shell cd ./helm helm install go4sql_pod_name GO4SQL/ ``` To check status of pod, use: + ```shell kubectl get pods ``` diff --git a/ast/ast.go b/ast/ast.go index ca9b332..17ede23 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -166,6 +166,7 @@ type SelectCommand struct { OrderByCommand *OrderByCommand // optional LimitCommand *LimitCommand // optional OffsetCommand *OffsetCommand // optional + JoinCommand *JoinCommand // optional } func (ls SelectCommand) CommandNode() {} @@ -201,7 +202,7 @@ func (ls SelectCommand) HasOrderByCommand() bool { return true } -// HasLimitCommand - returns true if optional HasLimitCommand is present in SelectCommand +// HasLimitCommand - returns true if optional LimitCommand is present in SelectCommand // // Example: // SELECT * FROM table LIMIT 5; @@ -216,7 +217,7 @@ func (ls SelectCommand) HasLimitCommand() bool { return true } -// HasOffsetCommand - returns true if optional HasOffsetCommand is present in SelectCommand +// HasOffsetCommand - returns true if optional OffsetCommand is present in SelectCommand // // Example: // SELECT * FROM table OFFSET 100; @@ -231,6 +232,21 @@ func (ls SelectCommand) HasOffsetCommand() bool { return true } +// HasJoinCommand - returns true if optional JoinCommand is present in SelectCommand +// +// Example: +// SELECT * FROM table JOIN table2 ON table.one EQUAL table2.two; +// Returns true +// +// SELECT * FROM table; +// Returns false +func (ls SelectCommand) HasJoinCommand() bool { + if ls.JoinCommand == nil { + return false + } + return true +} + // UpdateCommand - Part of Command that allow to change existing data // // Example: @@ -272,6 +288,26 @@ type WhereCommand struct { func (ls WhereCommand) CommandNode() {} func (ls WhereCommand) TokenLiteral() string { return ls.Token.Literal } +// JoinCommand - Part of Command that represent JOIN statement with expression that will merge tables +// +// Example: +// JOIN tbl2 ON tbl1.id EQUAL tbl2.f_idy; +type JoinCommand struct { + Token token.Token + Name Identifier // ex. name of table + JoinType token.Token + Expression Expression +} + +func (ls JoinCommand) CommandNode() {} +func (ls JoinCommand) TokenLiteral() string { return ls.Token.Literal } +func (ls JoinCommand) ShouldTakeLeftSide() bool { + return ls.JoinType.Type == token.LEFT || ls.JoinType.Type == token.FULL +} +func (ls JoinCommand) ShouldTakeRightSide() bool { + return ls.JoinType.Type == token.RIGHT || ls.JoinType.Type == token.FULL +} + // DeleteCommand - Part of Command that represent deleting row from table // // Example: diff --git a/engine/engine.go b/engine/engine.go index 4cba86e..4a149de 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -2,6 +2,7 @@ package engine import ( "fmt" + "maps" "sort" "github.com/LissaGreense/GO4SQL/ast" @@ -37,6 +38,8 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) (string, error) { continue case *ast.OffsetCommand: continue + case *ast.JoinCommand: + continue case *ast.CreateCommand: err := engine.createTable(mappedCommand) if err != nil { @@ -92,29 +95,43 @@ func (engine *DbEngine) getSelectResponse(selectCommand *ast.SelectCommand) (*Ta var table *Table var err error + if selectCommand.HasJoinCommand() { + joinCommand := selectCommand.JoinCommand + table, err = engine.joinTables(joinCommand, selectCommand.Name.Token.Literal) + if err != nil { + return nil, err + } + } else { + var exist bool + table, exist = engine.Tables[selectCommand.Name.Token.Literal] + + if !exist { + return nil, &TableDoesNotExistError{selectCommand.Name.Token.Literal} + } + } + if selectCommand.HasWhereCommand() { whereCommand := selectCommand.WhereCommand if selectCommand.HasOrderByCommand() { orderByCommand := selectCommand.OrderByCommand - table, err = engine.selectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand) + table, err = engine.selectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand, table) if err != nil { return nil, err } } else { - table, err = engine.selectFromTableWithWhere(selectCommand, whereCommand) + table, err = engine.selectFromTableWithWhere(selectCommand, whereCommand, table) if err != nil { return nil, err } } } else if selectCommand.HasOrderByCommand() { - table, err = engine.selectFromTableWithOrderBy(selectCommand, selectCommand.OrderByCommand) + table, err = engine.selectFromTableWithOrderBy(selectCommand, selectCommand.OrderByCommand, table) if err != nil { return nil, err } - } - - if table == nil { - table, err = engine.selectFromTable(selectCommand) + } else { + // panic: runtime error: invalid memory address or nil pointer dereference [recovered] + table, err = engine.selectFromProvidedTable(selectCommand, table) if err != nil { return nil, err } @@ -224,17 +241,6 @@ func (engine *DbEngine) insertIntoTable(command *ast.InsertCommand) error { return nil } -// selectFromTable - Return Table containing all values requested by SelectCommand -func (engine *DbEngine) selectFromTable(command *ast.SelectCommand) (*Table, error) { - table, exist := engine.Tables[command.Name.Token.Literal] - - if !exist { - return nil, &TableDoesNotExistError{command.Name.Token.Literal} - } - - return engine.selectFromProvidedTable(command, table) -} - func (engine *DbEngine) selectFromProvidedTable(command *ast.SelectCommand, table *Table) (*Table, error) { columns := table.Columns @@ -276,13 +282,7 @@ func (engine *DbEngine) dropTable(dropCommand *ast.DropCommand) { } // selectFromTableWithWhere - Return Table containing all values requested by SelectCommand and filtered by WhereCommand -func (engine *DbEngine) selectFromTableWithWhere(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand) (*Table, error) { - table, exist := engine.Tables[selectCommand.Name.Token.Literal] - - if !exist { - return nil, &TableDoesNotExistError{selectCommand.Name.Token.Literal} - } - +func (engine *DbEngine) selectFromTableWithWhere(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand, table *Table) (*Table, error) { if len(table.Columns) == 0 || len(table.Columns[0].Values) == 0 { return engine.selectFromProvidedTable(selectCommand, &Table{Columns: []*Column{}}) } @@ -298,13 +298,7 @@ func (engine *DbEngine) selectFromTableWithWhere(selectCommand *ast.SelectComman // selectFromTableWithWhereAndOrderBy - Return Table containing all values requested by SelectCommand, // filtered by WhereCommand and sorted by OrderByCommand -func (engine *DbEngine) selectFromTableWithWhereAndOrderBy(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand, orderByCommand *ast.OrderByCommand) (*Table, error) { - table, exist := engine.Tables[selectCommand.Name.Token.Literal] - - if !exist { - return nil, &TableDoesNotExistError{selectCommand.Name.Token.Literal} - } - +func (engine *DbEngine) selectFromTableWithWhereAndOrderBy(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand, orderByCommand *ast.OrderByCommand, table *Table) (*Table, error) { filteredTable, err := engine.getFilteredTable(table, whereCommand, false, selectCommand.Name.GetToken().Literal) if err != nil { @@ -323,13 +317,7 @@ func (engine *DbEngine) selectFromTableWithWhereAndOrderBy(selectCommand *ast.Se } // selectFromTableWithOrderBy - Return Table containing all values requested by SelectCommand and sorted by OrderByCommand -func (engine *DbEngine) selectFromTableWithOrderBy(selectCommand *ast.SelectCommand, orderByCommand *ast.OrderByCommand) (*Table, error) { - table, exist := engine.Tables[selectCommand.Name.Token.Literal] - - if !exist { - return nil, &TableDoesNotExistError{selectCommand.Name.Token.Literal} - } - +func (engine *DbEngine) selectFromTableWithOrderBy(selectCommand *ast.SelectCommand, orderByCommand *ast.OrderByCommand, table *Table) (*Table, error) { emptyTable := getCopyOfTableWithoutRows(table) sortedTable, err := engine.getSortedTable(orderByCommand, table, emptyTable, selectCommand.Name.GetToken().Literal) @@ -435,6 +423,82 @@ func (engine *DbEngine) getFilteredTable(table *Table, whereCommand *ast.WhereCo return filteredTable, nil } +func (engine *DbEngine) joinTables(joinCommand *ast.JoinCommand, leftTableName string) (*Table, error) { + leftTable, exist := engine.Tables[leftTableName] + leftTablePrefix := leftTableName + "." + if !exist { + return nil, &TableDoesNotExistError{leftTableName} + } + + rightTableName := joinCommand.Name.Token.Literal + rightTablePrefix := rightTableName + "." + rightTable, exist := engine.Tables[rightTableName] + if !exist { + return nil, &TableDoesNotExistError{rightTableName} + } + + joinedTable := &Table{Columns: []*Column{}} + + addColumnsWithPrefix(joinedTable, leftTable.Columns, leftTablePrefix) + addColumnsWithPrefix(joinedTable, rightTable.Columns, rightTablePrefix) + + leftTableWithAddedPrefix := leftTable.getTableCopyWithAddedPrefixToColumnNames(leftTablePrefix) + rightTableWithAddedPrefix := rightTable.getTableCopyWithAddedPrefixToColumnNames(rightTablePrefix) + var unmatchedRightRows = make(map[int]bool) + + for leftRowIndex := 0; leftRowIndex < len(leftTable.Columns[0].Values); leftRowIndex++ { + joinedRowLeft := getRow(leftTableWithAddedPrefix, leftRowIndex) + leftRowMatches := false + + for rightRowIndex := 0; rightRowIndex < len(rightTable.Columns[0].Values); rightRowIndex++ { + joinedRowRight := getRow(rightTableWithAddedPrefix, rightRowIndex) + maps.Copy(joinedRowRight, joinedRowLeft) + + fulfilledFilters, err := isFulfillingFilters(joinedRowRight, joinCommand.Expression, joinCommand.Token.Literal) + if err != nil { + return nil, err + } + + isLastLeftRow := leftRowIndex == len(leftTable.Columns[0].Values)-1 + + if fulfilledFilters { + for colIndex, column := range joinedTable.Columns { + joinedTable.Columns[colIndex].Values = append(joinedTable.Columns[colIndex].Values, joinedRowRight[column.Name]) + } + leftRowMatches, unmatchedRightRows[rightRowIndex] = true, true + } else if isLastLeftRow && joinCommand.ShouldTakeRightSide() && !unmatchedRightRows[rightRowIndex] { + joinedRowRight = getRow(rightTableWithAddedPrefix, rightRowIndex) + aggregateRowIntoJoinTable(leftTableWithAddedPrefix, joinedRowRight, joinedTable) + } + } + + if joinCommand.ShouldTakeLeftSide() && !leftRowMatches { + aggregateRowIntoJoinTable(rightTableWithAddedPrefix, joinedRowLeft, joinedTable) + } + } + + return joinedTable, nil +} + +func aggregateRowIntoJoinTable(tableWithAddedPrefix *Table, joinedRow map[string]ValueInterface, joinedTable *Table) { + joinedEmptyRow := getEmptyRow(tableWithAddedPrefix) + maps.Copy(joinedRow, joinedEmptyRow) + for colIndex, column := range joinedTable.Columns { + joinedTable.Columns[colIndex].Values = append(joinedTable.Columns[colIndex].Values, joinedRow[column.Name]) + } +} + +func addColumnsWithPrefix(finalTable *Table, columnsToAdd []*Column, prefix string) { + for _, column := range columnsToAdd { + finalTable.Columns = append(finalTable.Columns, + &Column{ + Type: column.Type, + Values: make([]ValueInterface, 0), + Name: prefix + column.Name, + }) + } +} + func (table *Table) applyOffsetAndLimit(command *ast.SelectCommand) { var offset = 0 var limitRaw = -1 diff --git a/engine/engine_error_handling_test.go b/engine/engine_error_handling_test.go index eeba2f0..f40f659 100644 --- a/engine/engine_error_handling_test.go +++ b/engine/engine_error_handling_test.go @@ -89,6 +89,20 @@ func TestEngineOrderByCommandErrorHandling(t *testing.T) { runEngineErrorHandlingSuite(t, tests) } +func TestEngineFullJoinErrorHandling(t *testing.T) { + leftTableNotExist := TableDoesNotExistError{tableName: "leftTable"} + rightTableNotExist := TableDoesNotExistError{tableName: "rightTable"} + columnDoesNotExist := ColumnDoesNotExistError{tableName: "", columnName: "leftTable.two"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE rightTable(one TEXT); SELECT leftTable.one, rightTable.one FROM leftTable JOIN rightTable ON leftTable.one EQUAL rightTable.one;", leftTableNotExist.Error()}, + {"CREATE TABLE leftTable(one TEXT); SELECT leftTable.one, rightTable.one FROM leftTable JOIN rightTable ON leftTable.one EQUAL rightTable.one;", rightTableNotExist.Error()}, + {"CREATE TABLE leftTable(one TEXT); CREATE TABLE rightTable(one TEXT); INSERT INTO leftTable VALUES('hi'); INSERT INTO rightTable VALUES('hi'); SELECT * FROM leftTable JOIN rightTable ON leftTable.two EQUAL rightTable.one;", columnDoesNotExist.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + func runEngineErrorHandlingSuite(t *testing.T, suite []errorHandlingTestSuite) { for i, test := range suite { errorMsg := getErrorMessage(t, test.input, i) diff --git a/engine/engine_test.go b/engine/engine_test.go index 18d7be2..afe2d98 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -547,6 +547,148 @@ func TestLimitAndOffset(t *testing.T) { engineTestSuite.runTestSuite(t) } +func TestDefaultJoinToBehaveLikeInnerJoin(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE books( author_id INT, title TEXT);", + "CREATE TABLE authors( author_id INT, name TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO books VALUES(2, 'Fire');", + "INSERT INTO books VALUES(1, 'Earth');", + "INSERT INTO books VALUES(1, 'Air');", + "INSERT INTO authors VALUES( 1, 'Reynold Boyka' );", + "INSERT INTO authors VALUES( 2, 'Alissa Ireneus' );", + }, + selectInput: "SELECT books.title, authors.name FROM books JOIN authors ON books.author_id EQUAL authors.author_id;", + expectedOutput: [][]string{ + {"books.title", "authors.name"}, + {"Fire", "Alissa Ireneus"}, + {"Earth", "Reynold Boyka"}, + {"Air", "Reynold Boyka"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestInnerJoinOnMultipleMatches(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE books( author_id INT, title TEXT);", + "CREATE TABLE authors( author_id INT, name TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO books VALUES(1, 'Book One');", + "INSERT INTO books VALUES(1, 'Book Two');", + "INSERT INTO authors VALUES(1, 'Author One');", + "INSERT INTO authors VALUES(1, 'Author Two');", + }, + selectInput: "SELECT books.title, authors.name FROM books JOIN authors ON books.author_id EQUAL authors.author_id;", + expectedOutput: [][]string{ + {"books.title", "authors.name"}, + {"Book One", "Author One"}, + {"Book One", "Author Two"}, + {"Book Two", "Author One"}, + {"Book Two", "Author Two"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestFullJoinOnIdenticalTables(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + "CREATE TABLE table2( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(3, 'Value3');", + }, + selectInput: "SELECT table1.value, table2.value FROM table1 FULL JOIN table2 ON table1.id EQUAL table2.id;", + expectedOutput: [][]string{ + {"table1.value", "table2.value"}, + {"Value1", "NULL"}, + {"Value2", "Value2"}, + {"NULL", "Value3"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestInnerJoinWithSpecifiedKeywordOnIdenticalTables(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + "CREATE TABLE table2( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(3, 'Value3');", + }, + selectInput: "SELECT table1.value, table2.value FROM table1 INNER JOIN table2 ON table1.id EQUAL table2.id;", + expectedOutput: [][]string{ + {"table1.value", "table2.value"}, + {"Value2", "Value2"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestLeftJoinOnIdenticalTables(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + "CREATE TABLE table2( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(3, 'Value3');", + }, + selectInput: "SELECT table1.value, table2.value FROM table1 LEFT JOIN table2 ON table1.id EQUAL table2.id;", + expectedOutput: [][]string{ + {"table1.value", "table2.value"}, + {"Value1", "NULL"}, + {"Value2", "Value2"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestRightJoinOnIdenticalTables(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + "CREATE TABLE table2( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(3, 'Value3');", + }, + selectInput: "SELECT table1.value, table2.value FROM table1 RIGHT JOIN table2 ON table1.id EQUAL table2.id;", + expectedOutput: [][]string{ + {"table1.value", "table2.value"}, + {"Value2", "Value2"}, + {"NULL", "Value3"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + type engineDBContentTestSuite struct { inputs []string expectedTableNames []string diff --git a/engine/generic_value.go b/engine/generic_value.go index 284a912..2364dd4 100644 --- a/engine/generic_value.go +++ b/engine/generic_value.go @@ -19,6 +19,7 @@ type SupportedTypes int const ( IntType = iota StringType + NullType ) // IntegerValue - Implementation of ValueInterface that is containing integer values @@ -31,13 +32,19 @@ type StringValue struct { Value string } +// NullValue - Implementation of ValueInterface that is containing null +type NullValue struct { +} + // ToString implementations func (value IntegerValue) ToString() string { return strconv.Itoa(value.Value) } func (value StringValue) ToString() string { return value.Value } +func (value NullValue) ToString() string { return "NULL" } // GetType implementations func (value IntegerValue) GetType() SupportedTypes { return IntType } func (value StringValue) GetType() SupportedTypes { return StringType } +func (value NullValue) GetType() SupportedTypes { return NullType } // IsEqual implementations func (value IntegerValue) IsEqual(valueInterface ValueInterface) bool { @@ -46,6 +53,9 @@ func (value IntegerValue) IsEqual(valueInterface ValueInterface) bool { func (value StringValue) IsEqual(valueInterface ValueInterface) bool { return areEqual(value, valueInterface) } +func (value NullValue) IsEqual(valueInterface ValueInterface) bool { + return areEqual(value, valueInterface) +} // isSmallerThan implementations func (value IntegerValue) isSmallerThan(secondValue ValueInterface) bool { @@ -57,6 +67,7 @@ func (value IntegerValue) isSmallerThan(secondValue ValueInterface) bool { return value.Value < secondValueAsInteger.Value } + func (value StringValue) isSmallerThan(secondValue ValueInterface) bool { secondValueAsString, isString := secondValue.(StringValue) @@ -67,6 +78,16 @@ func (value StringValue) isSmallerThan(secondValue ValueInterface) bool { return value.Value < secondValueAsString.Value } +func (value NullValue) isSmallerThan(secondValue ValueInterface) bool { + _, isNull := secondValue.(NullValue) + + if !isNull { + log.Fatal("Can't compare Null with other type") + } + + return true +} + // isGreaterThan implementations func (value IntegerValue) isGreaterThan(secondValue ValueInterface) bool { secondValueAsInteger, isInteger := secondValue.(IntegerValue) @@ -87,6 +108,16 @@ func (value StringValue) isGreaterThan(secondValue ValueInterface) bool { return value.Value > secondValueAsString.Value } +func (value NullValue) isGreaterThan(secondValue ValueInterface) bool { + _, isNull := secondValue.(NullValue) + + if !isNull { + log.Fatal("Can't compare Null with other type") + } + + return true +} + func areEqual(first ValueInterface, second ValueInterface) bool { return first.GetType() == second.GetType() && first.ToString() == second.ToString() } diff --git a/engine/generic_value_test.go b/engine/generic_value_test.go index d4815e5..1198e0d 100644 --- a/engine/generic_value_test.go +++ b/engine/generic_value_test.go @@ -17,6 +17,8 @@ func TestIsGreaterThan(t *testing.T) { twoString := StringValue{ Value: "aab", } + oneNull := NullValue{} + twoNull := NullValue{} if oneInt.isGreaterThan(twoInt) { t.Errorf("1 shouldn't be greater than 2") @@ -33,6 +35,10 @@ func TestIsGreaterThan(t *testing.T) { if !twoString.isGreaterThan(oneString) { t.Errorf("1 shouldn't be greater than 2") } + + if !twoNull.isGreaterThan(oneNull) { + t.Errorf("null to null operations should always return true") + } } func TestIsSmallerThan(t *testing.T) { @@ -48,6 +54,8 @@ func TestIsSmallerThan(t *testing.T) { twoString := StringValue{ Value: "aab", } + oneNull := NullValue{} + twoNull := NullValue{} if !oneInt.isSmallerThan(twoInt) { t.Errorf("1 should be smaller than 2") @@ -64,10 +72,13 @@ func TestIsSmallerThan(t *testing.T) { if twoString.isSmallerThan(oneString) { t.Errorf("1 should be smaller than 2") } + + if !twoNull.isSmallerThan(oneNull) { + t.Errorf("null to null operations should always return true") + } } func TestEquals(t *testing.T) { - oneInt := IntegerValue{ Value: 1, } @@ -80,12 +91,15 @@ func TestEquals(t *testing.T) { twoString := StringValue{ Value: "two", } + oneNull := NullValue{} + twoNull := NullValue{} shouldBeEqual(t, oneInt, oneInt) shouldBeEqual(t, oneString, oneString) shouldNotBeEqual(t, oneInt, twoInt) shouldNotBeEqual(t, oneString, twoString) shouldNotBeEqual(t, oneString, oneInt) + shouldBeEqual(t, oneNull, twoNull) } func shouldBeEqual(t *testing.T, valueOne ValueInterface, valueTwo ValueInterface) { diff --git a/engine/row.go b/engine/row.go index 5dd9743..6160956 100644 --- a/engine/row.go +++ b/engine/row.go @@ -25,3 +25,11 @@ func getRow(table *Table, rowIndex int) map[string]ValueInterface { } return row } + +func getEmptyRow(table *Table) map[string]ValueInterface { + row := make(map[string]ValueInterface) + for _, column := range table.Columns { + row[column.Name] = NullValue{} + } + return row +} diff --git a/engine/table.go b/engine/table.go index 262d2a8..ec8561a 100644 --- a/engine/table.go +++ b/engine/table.go @@ -100,7 +100,8 @@ func (table *Table) ToString() string { result += " " printedValue := table.Columns[iColumn].Values[iRow].ToString() - if table.Columns[iColumn].Type.Literal == token.TEXT { + if table.Columns[iColumn].Type.Literal == token.TEXT && + table.Columns[iColumn].Values[iRow].GetType() != NullType { printedValue = "'" + printedValue + "'" } for i := 0; i < columWidths[iColumn]-len(printedValue); i++ { @@ -116,6 +117,21 @@ func (table *Table) ToString() string { return result + bar } +func (table *Table) getTableCopyWithAddedPrefixToColumnNames(columnNamePrefix string) *Table { + newTable := &Table{Columns: []*Column{}} + + for _, column := range table.Columns { + newTable.Columns = append(newTable.Columns, + &Column{ + Type: column.Type, + Values: column.Values, + Name: columnNamePrefix + column.Name, + }) + } + + return newTable +} + func getBar(columWidths []int) string { bar := "+" diff --git a/go.mod b/go.mod index f35fa57..aed510b 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/LissaGreense/GO4SQL -go 1.16 +go 1.21 diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index 5c4c0b1..478660c 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -320,6 +320,140 @@ func TestSelectWithDistinct(t *testing.T) { runLexerTestSuite(t, input, tests) } +func TestDefaultJoin(t *testing.T) { + input := ` SELECT title FROM books + JOIN authors ON + books.author_id EQUAL authors.author_id; + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.IDENT, "title"}, + {token.FROM, "FROM"}, + {token.IDENT, "books"}, + {token.JOIN, "JOIN"}, + {token.IDENT, "authors"}, + {token.ON, "ON"}, + {token.IDENT, "books.author_id"}, + {token.EQUAL, "EQUAL"}, + {token.IDENT, "authors.author_id"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestInnerJoin(t *testing.T) { + input := ` SELECT title FROM books + INNER JOIN authors ON + books.author_id EQUAL authors.author_id; + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.IDENT, "title"}, + {token.FROM, "FROM"}, + {token.IDENT, "books"}, + {token.INNER, "INNER"}, + {token.JOIN, "JOIN"}, + {token.IDENT, "authors"}, + {token.ON, "ON"}, + {token.IDENT, "books.author_id"}, + {token.EQUAL, "EQUAL"}, + {token.IDENT, "authors.author_id"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestLeftJoin(t *testing.T) { + input := ` SELECT title FROM books + LEFT JOIN authors ON + books.author_id EQUAL authors.author_id; + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.IDENT, "title"}, + {token.FROM, "FROM"}, + {token.IDENT, "books"}, + {token.LEFT, "LEFT"}, + {token.JOIN, "JOIN"}, + {token.IDENT, "authors"}, + {token.ON, "ON"}, + {token.IDENT, "books.author_id"}, + {token.EQUAL, "EQUAL"}, + {token.IDENT, "authors.author_id"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestRightJoin(t *testing.T) { + input := ` SELECT title FROM books + RIGHT JOIN authors ON + books.author_id EQUAL authors.author_id; + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.IDENT, "title"}, + {token.FROM, "FROM"}, + {token.IDENT, "books"}, + {token.RIGHT, "RIGHT"}, + {token.JOIN, "JOIN"}, + {token.IDENT, "authors"}, + {token.ON, "ON"}, + {token.IDENT, "books.author_id"}, + {token.EQUAL, "EQUAL"}, + {token.IDENT, "authors.author_id"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestFullJoin(t *testing.T) { + input := ` SELECT title FROM books + FULL JOIN authors ON + books.author_id EQUAL authors.author_id; + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.IDENT, "title"}, + {token.FROM, "FROM"}, + {token.IDENT, "books"}, + {token.FULL, "FULL"}, + {token.JOIN, "JOIN"}, + {token.IDENT, "authors"}, + {token.ON, "ON"}, + {token.IDENT, "books.author_id"}, + {token.EQUAL, "EQUAL"}, + {token.IDENT, "authors.author_id"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + func runLexerTestSuite(t *testing.T, input string, tests []struct { expectedType token.Type expectedLiteral string diff --git a/parser/errors.go b/parser/errors.go index 7cc11b3..dbd71e4 100644 --- a/parser/errors.go +++ b/parser/errors.go @@ -98,3 +98,12 @@ type NoPredecessorParserError struct { func (m *NoPredecessorParserError) Error() string { return "syntax error, {" + m.command + "} command can't be used without predecessor" } + +// IllegalPeriodInIdentParserError - error thrown when parser found period in ident when parsing create command +type IllegalPeriodInIdentParserError struct { + name string +} + +func (m *IllegalPeriodInIdentParserError) Error() string { + return "syntax error, {" + m.name + "} shouldn't contain '.'" +} diff --git a/parser/parser.go b/parser/parser.go index 15b46d3..ec0e43d 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -5,6 +5,7 @@ import ( "github.com/LissaGreense/GO4SQL/lexer" "github.com/LissaGreense/GO4SQL/token" "strconv" + "strings" ) // Parser - Contain token that is currently analyzed by parser and the next one. Lexer is used to tokenize the client @@ -84,6 +85,9 @@ func (parser *Parser) parseCreateCommand() (ast.Command, error) { return nil, err } + if strings.Contains(parser.currentToken.Literal, ".") { + return nil, &IllegalPeriodInIdentParserError{name: parser.currentToken.Literal} + } createCommand.Name = ast.Identifier{Token: parser.currentToken} // Skip token.IDENT @@ -101,6 +105,10 @@ func (parser *Parser) parseCreateCommand() (ast.Command, error) { return nil, err } + if strings.Contains(parser.currentToken.Literal, ".") { + return nil, &IllegalPeriodInIdentParserError{name: parser.currentToken.Literal} + } + createCommand.ColumnNames = append(createCommand.ColumnNames, parser.currentToken.Literal) createCommand.ColumnTypes = append(createCommand.ColumnTypes, parser.peekToken) @@ -268,8 +276,8 @@ func (parser *Parser) parseSelectCommand() (ast.Command, error) { // Ignore token.IDENT parser.nextToken() - // expect SEMICOLON or WHERE - err = validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.WHERE, token.ORDER, token.LIMIT, token.OFFSET}) + // expect SEMICOLON or other keywords expected in SELECT statement + err = validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.WHERE, token.ORDER, token.LIMIT, token.OFFSET, token.JOIN, token.LEFT, token.RIGHT, token.INNER, token.FULL}) if err != nil { return nil, err } @@ -428,9 +436,9 @@ func (parser *Parser) parseOrderByCommand() (ast.Command, error) { return orderCommand, nil } -// parseLimitCommand - Return ast.parseLimitCommand created from tokens and validate the syntax +// parseLimitCommand - Return ast.LimitCommand created from tokens and validate the syntax // -// Example of input parsable to the ast.parseLimitCommand: +// Example of input parsable to the ast.LimitCommand: // LIMIT 10 func (parser *Parser) parseLimitCommand() (ast.Command, error) { // token.LIMIT already at current position in parser @@ -464,9 +472,9 @@ func (parser *Parser) parseLimitCommand() (ast.Command, error) { return limitCommand, nil } -// parseOffsetCommand - Return ast.parseOffsetCommand created from tokens and validate the syntax +// parseOffsetCommand - Return ast.OffsetCommand created from tokens and validate the syntax // -// Example of input parsable to the ast.parseLimitCommand: +// Example of input parsable to the ast.LimitCommand: // OFFSET 10 func (parser *Parser) parseOffsetCommand() (ast.Command, error) { // token.OFFSET already at current position in parser @@ -498,6 +506,59 @@ func (parser *Parser) parseOffsetCommand() (ast.Command, error) { return offsetCommand, nil } +// parseJoinCommand - Return ast.JoinCommand created from tokens and validate the syntax +// +// Example of input parsable to the ast.JoinCommand: +// JOIN table on table.one EQUAL table2.one; +func (parser *Parser) parseJoinCommand() (ast.Command, error) { + // parser has either token.JOIN, token.LEFT, token.RIGHT, token.INNER or token.FULL + var joinCommand *ast.JoinCommand + + if parser.currentToken.Type == token.JOIN { + joinCommand = &ast.JoinCommand{Token: parser.currentToken} + joinCommand.JoinType = token.Token{Type: token.INNER, Literal: token.INNER} + } else { + joinTypeTokenType := parser.currentToken + parser.nextToken() + err := validateToken(parser.currentToken.Type, []token.Type{token.JOIN}) + if err != nil { + return nil, err + } + joinCommand = &ast.JoinCommand{Token: parser.currentToken} + joinCommand.JoinType = joinTypeTokenType + } + + // token.JOIN no longer needed + parser.nextToken() + + err := validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + + joinCommand.Name = ast.Identifier{Token: parser.currentToken} + parser.nextToken() + + err = validateTokenAndSkip(parser, []token.Type{token.ON}) + if err != nil { + return nil, err + } + + var expressionIsValid bool + expressionIsValid, joinCommand.Expression, err = parser.getExpression() + if err != nil { + return nil, err + } + + if !expressionIsValid { + return nil, &LogicalExpressionParsingError{} + } + + parser.skipIfCurrentTokenIsSemicolon() + + return joinCommand, nil +} + // parseUpdateCommand - Return ast.parseUpdateCommand created from tokens and validate the syntax // // Example of input parsable to the ast.parseUpdateCommand: @@ -802,6 +863,20 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { return nil, err } selectCommand.OffsetCommand = newCommand.(*ast.OffsetCommand) + case token.JOIN, token.LEFT, token.RIGHT, token.INNER, token.FULL: + lastCommand, parserError := parser.getLastCommand(sequence, token.JOIN) + if parserError != nil { + return nil, parserError + } + if lastCommand.TokenLiteral() != token.SELECT { + return nil, &SyntaxCommandExpectedError{command: "JOIN", neededCommands: []string{"SELECT"}} + } + selectCommand := lastCommand.(*ast.SelectCommand) + newCommand, err := parser.parseJoinCommand() + if err != nil { + return nil, err + } + selectCommand.JoinCommand = newCommand.(*ast.JoinCommand) default: return nil, &SyntaxInvalidCommandError{invalidCommand: parser.currentToken.Literal} } diff --git a/parser/parser_error_handling_test.go b/parser/parser_error_handling_test.go index 2768b3d..c01cfee 100644 --- a/parser/parser_error_handling_test.go +++ b/parser/parser_error_handling_test.go @@ -97,7 +97,7 @@ func TestParseSelectCommandErrorHandling(t *testing.T) { noFromKeyword := SyntaxError{[]string{token.FROM}, token.IDENT} noColumns := SyntaxError{[]string{token.ASTERISK, token.IDENT}, token.FROM} noTableName := SyntaxError{[]string{token.IDENT}, token.SEMICOLON} - noSemicolon := SyntaxError{[]string{token.SEMICOLON, token.WHERE, token.ORDER, token.LIMIT, token.OFFSET}, ""} + noSemicolon := SyntaxError{[]string{token.SEMICOLON, token.WHERE, token.ORDER, token.LIMIT, token.OFFSET, token.JOIN, token.LEFT, token.RIGHT, token.INNER, token.FULL}, ""} tests := []errorHandlingTestSuite{ {"SELECT column1, column2 tbl;", noFromKeyword.Error()}, @@ -195,6 +195,18 @@ func TestParseDeleteCommandErrorHandling(t *testing.T) { runParserErrorHandlingSuite(t, tests) } +func TestPeriodInIdentWhileCreatingTableErrorHandling(t *testing.T) { + illegalPeriodInTableName := IllegalPeriodInIdentParserError{"tab.le"} + illegalPeriodInColumnName := IllegalPeriodInIdentParserError{"col.umn"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE tab.le( one TEXT , two INT);", illegalPeriodInTableName.Error()}, + {"CREATE TABLE table1( col.umn TEXT , two INT);", illegalPeriodInColumnName.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + func runParserErrorHandlingSuite(t *testing.T, suite []errorHandlingTestSuite) { for i, test := range suite { errorMsg := getErrorMessage(t, test.input, i) diff --git a/parser/parser_test.go b/parser/parser_test.go index 92fa217..06923ae 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -423,6 +423,201 @@ func TestSelectWithLimitAndOffsetCommand(t *testing.T) { testOffsetCommands(t, expectedOffsetCommand, selectCommand.OffsetCommand) } +func TestSelectWithDefaultInnerJoinCommand(t *testing.T) { + input := "SELECT tbl.one, tbl2.two FROM tbl JOIN tbl2 ON tbl.one EQUAL tbl2.one;" + expectedJoinCommand := ast.JoinCommand{ + Token: token.Token{Type: token.JOIN, Literal: "JOIN"}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2"}}, + JoinType: token.Token{Type: token.INNER, Literal: "INNER"}, + Expression: ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, + Right: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2.one"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + }, + } + expectedTableName := "tbl" + expectedColumnName := []token.Token{{Type: token.IDENT, Literal: "tbl.one"}, {Type: token.IDENT, Literal: "tbl2.two"}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + return + } + + if !selectCommand.HasJoinCommand() { + t.Fatalf("select command should have join command") + } + + testJoinCommands(t, expectedJoinCommand, *selectCommand.JoinCommand) +} + +func TestSelectWithInnerJoinCommand(t *testing.T) { + input := "SELECT tbl.one, tbl2.two FROM tbl INNER JOIN tbl2 ON tbl.one EQUAL tbl2.one;" + expectedJoinCommand := ast.JoinCommand{ + Token: token.Token{Type: token.JOIN, Literal: "JOIN"}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2"}}, + JoinType: token.Token{Type: token.INNER, Literal: "INNER"}, + Expression: ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, + Right: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2.one"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + }, + } + expectedTableName := "tbl" + expectedColumnName := []token.Token{{Type: token.IDENT, Literal: "tbl.one"}, {Type: token.IDENT, Literal: "tbl2.two"}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + return + } + + if !selectCommand.HasJoinCommand() { + t.Fatalf("select command should have join command") + } + + testJoinCommands(t, expectedJoinCommand, *selectCommand.JoinCommand) +} + +func TestSelectWithLeftJoinCommand(t *testing.T) { + input := "SELECT tbl.one, tbl2.two FROM tbl LEFT JOIN tbl2 ON tbl.one EQUAL tbl2.one;" + expectedJoinCommand := ast.JoinCommand{ + Token: token.Token{Type: token.JOIN, Literal: "JOIN"}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2"}}, + JoinType: token.Token{Type: token.LEFT, Literal: "LEFT"}, + Expression: ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, + Right: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2.one"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + }, + } + expectedTableName := "tbl" + expectedColumnName := []token.Token{{Type: token.IDENT, Literal: "tbl.one"}, {Type: token.IDENT, Literal: "tbl2.two"}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + return + } + + if !selectCommand.HasJoinCommand() { + t.Fatalf("select command should have join command") + } + + testJoinCommands(t, expectedJoinCommand, *selectCommand.JoinCommand) +} + +func TestSelectWithRightJoinCommand(t *testing.T) { + input := "SELECT tbl.one, tbl2.two FROM tbl RIGHT JOIN tbl2 ON tbl.one EQUAL tbl2.one;" + expectedJoinCommand := ast.JoinCommand{ + Token: token.Token{Type: token.JOIN, Literal: "JOIN"}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2"}}, + JoinType: token.Token{Type: token.RIGHT, Literal: "RIGHT"}, + Expression: ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, + Right: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2.one"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + }, + } + expectedTableName := "tbl" + expectedColumnName := []token.Token{{Type: token.IDENT, Literal: "tbl.one"}, {Type: token.IDENT, Literal: "tbl2.two"}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + return + } + + if !selectCommand.HasJoinCommand() { + t.Fatalf("select command should have join command") + } + + testJoinCommands(t, expectedJoinCommand, *selectCommand.JoinCommand) +} + +func TestSelectWithFullJoinCommand(t *testing.T) { + input := "SELECT tbl.one, tbl2.two FROM tbl FULL JOIN tbl2 ON tbl.one EQUAL tbl2.one;" + expectedJoinCommand := ast.JoinCommand{ + Token: token.Token{Type: token.JOIN, Literal: "JOIN"}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2"}}, + JoinType: token.Token{Type: token.FULL, Literal: "FULL"}, + Expression: ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, + Right: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2.one"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + }, + } + expectedTableName := "tbl" + expectedColumnName := []token.Token{{Type: token.IDENT, Literal: "tbl.one"}, {Type: token.IDENT, Literal: "tbl2.two"}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + return + } + + if !selectCommand.HasJoinCommand() { + t.Fatalf("select command should have join command") + } + + testJoinCommands(t, expectedJoinCommand, *selectCommand.JoinCommand) +} + func TestParseUpdateCommand(t *testing.T) { tests := []struct { input string @@ -739,6 +934,22 @@ func testOffsetCommands(t *testing.T, expectedOffsetCommand ast.OffsetCommand, a t.Errorf("Expecting Count to have value: %d, got: %d", expectedOffsetCommand.Count, actualOffsetCommand.Count) } } + +func testJoinCommands(t *testing.T, expectedJoinCommand ast.JoinCommand, actualJoinCommand ast.JoinCommand) { + + if expectedJoinCommand.Token.Type != actualJoinCommand.Token.Type { + t.Errorf("Expecting Token TokenType: %q, got: %q", expectedJoinCommand.Token.Type, actualJoinCommand.Token.Type) + } + if expectedJoinCommand.Token.Literal != actualJoinCommand.Token.Literal { + t.Errorf("Expecting Token Literal: %s, got: %s", expectedJoinCommand.Token.Literal, actualJoinCommand.Token.Literal) + } + if expectedJoinCommand.Name != actualJoinCommand.Name { + t.Errorf("Expecting Name to has a value: %s, got: %s", expectedJoinCommand.Name, actualJoinCommand.Name) + } + if !expressionsAreEqual(actualJoinCommand.Expression, expectedJoinCommand.Expression) { + t.Errorf("Actual expression is not equal to expected one.\nActual: %#v\nExpected: %#v", actualJoinCommand.Expression, expectedJoinCommand.Expression) + } +} func expressionsAreEqual(first ast.Expression, second ast.Expression) bool { booleanExpression, booleanExpressionIsValid := first.(*ast.BooleanExpression) diff --git a/test_file b/test_file index ad768bd..aac029a 100644 --- a/test_file +++ b/test_file @@ -17,3 +17,13 @@ INSERT INTO tbl VALUES( 'goodbye', 5, 22, 'P' ); SELECT DISTINCT * FROM tbl; DROP TABLE tbl; + CREATE TABLE table1( id INT, value TEXT); + CREATE TABLE table2( id INT, value TEXT); + INSERT INTO table1 VALUES(1, 'Value1'); + INSERT INTO table1 VALUES(2, 'Value2'); + INSERT INTO table2 VALUES(2, 'Value2'); + INSERT INTO table2 VALUES(3, 'Value3'); + SELECT table1.value, table2.value FROM table1 FULL JOIN table2 ON table1.id EQUAL table2.id; + SELECT table1.value, table2.value FROM table1 INNER JOIN table2 ON table1.id EQUAL table2.id; + SELECT table1.value, table2.value FROM table1 LEFT JOIN table2 ON table1.id EQUAL table2.id; + SELECT table1.value, table2.value FROM table1 RIGHT JOIN table2 ON table1.id EQUAL table2.id; diff --git a/token/token.go b/token/token.go index 913bc74..5da87ef 100644 --- a/token/token.go +++ b/token/token.go @@ -48,6 +48,12 @@ const ( UPDATE = "UPDATE" SET = "SET" DISTINCT = "DISTINCT" + JOIN = "JOIN" + INNER = "INNER" + FULL = "FULL" + LEFT = "LEFT" + RIGHT = "RIGHT" + ON = "ON" TO = "TO" @@ -87,6 +93,12 @@ var keywords = map[string]Type{ "UPDATE": UPDATE, "SET": SET, "DISTINCT": DISTINCT, + "INNER": INNER, + "FULL": FULL, + "LEFT": LEFT, + "RIGHT": RIGHT, + "JOIN": JOIN, + "ON": ON, "TO": TO, "VALUES": VALUES, "WHERE": WHERE, From e2e9191715d96e5f472f28ba14f3b2d3ecd85bec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sara=20Ryfczy=C5=84ska?= Date: Fri, 8 Nov 2024 00:25:47 +0100 Subject: [PATCH 14/21] Aggregate functions (#23) * Add Aggregate function handling in lexer, parser and even write documentation * Aggregate functions engine WIP * Finish tests, fix implementation of aggr functions --------- Co-authored-by: ixior462 --- .github/expected_results/end2end.txt | 31 ++++++ README.md | 36 +++++++ ast/ast.go | 30 +++++- engine/engine.go | 127 +++++++++++++++++++++++- engine/engine_test.go | 141 +++++++++++++++++++++++++++ engine/generic_value.go | 45 +++++++++ lexer/lexer_test.go | 40 ++++++++ parser/parser.go | 59 +++++++++-- parser/parser_error_handling_test.go | 10 +- parser/parser_test.go | 119 +++++++++++++++++----- test_file | 27 +++-- token/token.go | 10 ++ 12 files changed, 625 insertions(+), 50 deletions(-) diff --git a/.github/expected_results/end2end.txt b/.github/expected_results/end2end.txt index 2a438cb..68f7159 100644 --- a/.github/expected_results/end2end.txt +++ b/.github/expected_results/end2end.txt @@ -96,3 +96,34 @@ Data Inserted | 'Value2' | 'Value2' | | NULL | 'Value3' | +--------------+--------------+ +Data Inserted ++---------+------------+ +| MAX(id) | MAX(value) | ++---------+------------+ +| 3 | Value3 | ++---------+------------+ ++------------+---------+ +| MIN(value) | MIN(id) | ++------------+---------+ +| Value1 | 1 | ++------------+---------+ ++----------+-----------+--------------+ +| COUNT(*) | COUNT(id) | COUNT(value) | ++----------+-----------+--------------+ +| 3 | 3 | 3 | ++----------+-----------+--------------+ ++---------+------------+ +| SUM(id) | SUM(value) | ++---------+------------+ +| 6 | 0 | ++---------+------------+ ++---------+------------+ +| AVG(id) | AVG(value) | ++---------+------------+ +| 2 | 0 | ++---------+------------+ ++---------+----+ +| AVG(id) | id | ++---------+----+ +| 2 | 1 | ++---------+----+ diff --git a/README.md b/README.md index f78dadb..f0d7904 100644 --- a/README.md +++ b/README.md @@ -209,6 +209,42 @@ go clean -testcache; go test ./... ``tableOne.columnY EQUAL tableTwo.columnX`` (value of ``columnY`` in ``tableOne`` is equal the value of ``columnX`` in ``tableTwo``). +* ***MIN()*** is used to return the smallest value in a specified column. + ```sql + SELECT MIN(columnName) + FROM tableName; + ``` + In this case, this command will return the smallest value found in the column ``columnName`` of ``tableName``. + +* ***MAX()*** is used to return the largest value in a specified column. + ```sql + SELECT MAX(columnName) + FROM tableName; + ``` + This command will return the largest value found in the column ``columnName`` of ``tableName``. + +* ***COUNT()*** is used to return the number of rows that match a given condition or the total number of rows in a + specified column. + ```sql + SELECT COUNT(columnName) + FROM tableName; + ``` + This command will return the number of rows in the ``columnName`` of ``tableName``. + +* ***SUM()*** is used to return the total sum of the values in a specified numerical column. + ```sql + SELECT SUM(columnName) + FROM tableName; + ``` + This command will return the total sum of all values in the numerical column ``columnName`` of ``tableName``. + +* ***AVG()*** is used to return the average of values in a specified numerical column. + ```sql + SELECT AVG(columnName) + FROM tableName; + ``` + This command will return the average of all values in the numerical column ``columnName`` of ``tableName``. + ## E2E TEST In root directory there is **test_file** containing input commands for E2E tests. File diff --git a/ast/ast.go b/ast/ast.go index 17ede23..5e3a434 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -153,6 +153,26 @@ type InsertCommand struct { func (ls InsertCommand) CommandNode() {} func (ls InsertCommand) TokenLiteral() string { return ls.Token.Literal } +// Space - part of SelectCommand which is containing either * or a column name with an optional function aggregating it +type Space struct { + ColumnName token.Token + AggregateFunc *token.Token +} + +func (space Space) String() string { + columnName := "ColumnName={Type: " + string(space.ColumnName.Type) + ", Literal: " + space.ColumnName.Literal + "}" + if space.ContainsAggregateFunc() { + aggFunc := "AggregateFunc={Type: " + string(space.AggregateFunc.Type) + ", Literal: " + space.AggregateFunc.Literal + "}" + return columnName + ", " + aggFunc + } + return columnName +} + +// ContainsAggregateFunc - return true if space contains AggregateFunc that aggregate columnName or * +func (space Space) ContainsAggregateFunc() bool { + return space.AggregateFunc != nil +} + // SelectCommand - Part of Command that represent selecting values from tables // // Example: @@ -160,7 +180,7 @@ func (ls InsertCommand) TokenLiteral() string { return ls.Token.Literal } type SelectCommand struct { Token token.Token Name Identifier // ex. name of table - Space []token.Token // ex. column names + Space []Space // ex. column names HasDistinct bool // DISTINCT keyword has been used WhereCommand *WhereCommand // optional OrderByCommand *OrderByCommand // optional @@ -171,6 +191,14 @@ type SelectCommand struct { func (ls SelectCommand) CommandNode() {} func (ls SelectCommand) TokenLiteral() string { return ls.Token.Literal } +func (ls *SelectCommand) AggregateFunctionAppears() bool { + for _, space := range ls.Space { + if space.ContainsAggregateFunc() { + return true + } + } + return false +} // HasWhereCommand - returns true if optional HasWhereCommand is present in SelectCommand // diff --git a/engine/engine.go b/engine/engine.go index 4a149de..5874d78 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -4,6 +4,7 @@ import ( "fmt" "maps" "sort" + "strconv" "github.com/LissaGreense/GO4SQL/ast" "github.com/LissaGreense/GO4SQL/token" @@ -245,19 +246,141 @@ func (engine *DbEngine) selectFromProvidedTable(command *ast.SelectCommand, tabl columns := table.Columns wantedColumnNames := make([]string, 0) - if command.Space[0].Type == token.ASTERISK { + if command.AggregateFunctionAppears() { + selectedTable := &Table{Columns: make([]*Column, 0)} + + for i := 0; i < len(command.Space); i++ { + var columnType token.Token + var columnName string + var columnValues []ValueInterface + var err error + value := make([]ValueInterface, 0) + currentSpace := command.Space[i] + + if currentSpace.ColumnName.Type == token.ASTERISK && currentSpace.AggregateFunc.Type == token.COUNT { + if len(columns) > 0 { + columnValues = columns[0].Values + } + } else { + columnValues, err = getValuesOfColumn(currentSpace.ColumnName.Literal, columns) + } + + if err != nil { + return nil, err + } + + if currentSpace.ContainsAggregateFunc() { + columnName = fmt.Sprintf("%s(%s)", currentSpace.AggregateFunc.Literal, + currentSpace.ColumnName.Literal) + columnType = evaluateColumnTypeOfAggregateFunc(currentSpace) + aggregatedValue, aggregateErr := aggregateColumnContent(currentSpace, columnValues) + if aggregateErr != nil { + return nil, aggregateErr + } + value = append(value, aggregatedValue) + } else { + columnName = currentSpace.ColumnName.Literal + columnType = currentSpace.ColumnName + value = append(value, columnValues[0]) + } + + selectedTable.Columns = append(selectedTable.Columns, &Column{ + Name: columnName, + Type: columnType, + Values: value, + }) + } + return selectedTable, nil + } else if command.Space[0].ColumnName.Type == token.ASTERISK { for i := 0; i < len(columns); i++ { wantedColumnNames = append(wantedColumnNames, columns[i].Name) } return extractColumnContent(columns, &wantedColumnNames, command.Name.GetToken().Literal) } else { for i := 0; i < len(command.Space); i++ { - wantedColumnNames = append(wantedColumnNames, command.Space[i].Literal) + wantedColumnNames = append(wantedColumnNames, command.Space[i].ColumnName.Literal) } return extractColumnContent(columns, unique(wantedColumnNames), command.Name.GetToken().Literal) } } +func getValuesOfColumn(columnName string, columns []*Column) ([]ValueInterface, error) { + wantedColumnName := []string{columnName} + columnContent, err := extractColumnContent(columns, &wantedColumnName, "") + if err != nil { + return nil, err + } + return columnContent.Columns[0].Values, nil +} + +func evaluateColumnTypeOfAggregateFunc(space ast.Space) token.Token { + if space.AggregateFunc.Type == token.MIN || + space.AggregateFunc.Type == token.MAX { + return space.ColumnName + } + return token.Token{Type: token.INT, Literal: "INT"} +} + +func aggregateColumnContent(space ast.Space, columnValues []ValueInterface) (ValueInterface, error) { + if space.AggregateFunc.Type == token.COUNT { + if space.ColumnName.Type == token.ASTERISK { + return IntegerValue{Value: len(columnValues)}, nil + } + count := 0 + for _, value := range columnValues { + if value.GetType() != NullType { + count++ + } + } + return IntegerValue{Value: count}, nil + } + if len(columnValues) == 0 { + return NullValue{}, nil + } + switch space.AggregateFunc.Type { + case token.MAX: + maxValue, err := getMax(columnValues) + if err != nil { + return nil, err + } + return maxValue, nil + case token.MIN: + minValue, err := getMin(columnValues) + if err != nil { + return nil, err + } + return minValue, nil + case token.SUM: + if columnValues[0].GetType() == StringType { + return IntegerValue{Value: 0}, nil + } else { + sum := 0 + for _, value := range columnValues { + num, err := strconv.Atoi(value.ToString()) + if err != nil { + return nil, err + } + sum += num + } + return IntegerValue{Value: sum}, nil + } + default: + if columnValues[0].GetType() == StringType { + return IntegerValue{Value: 0}, nil + } else { + sum := 0 + for _, value := range columnValues { + num, err := strconv.Atoi(value.ToString()) + if err != nil { + return nil, err + } + sum += num + } + return IntegerValue{Value: sum / len(columnValues)}, nil + } + } +} + // deleteFromTable - Delete all rows of data from table that match given condition func (engine *DbEngine) deleteFromTable(deleteCommand *ast.DeleteCommand, whereCommand *ast.WhereCommand) error { table, exist := engine.Tables[deleteCommand.Name.Token.Literal] diff --git a/engine/engine_test.go b/engine/engine_test.go index afe2d98..6320ca7 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -689,6 +689,147 @@ func TestRightJoinOnIdenticalTables(t *testing.T) { engineTestSuite.runTestSuite(t) } +func TestAggregateFunctionMax(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + }, + selectInput: "SELECT MAX(id), MAX(value) FROM table1;", + expectedOutput: [][]string{ + {"MAX(id)", "MAX(value)"}, + {"2", "Value2"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionMin(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + }, + selectInput: "SELECT MIN(value), MIN(id) FROM table1;", + expectedOutput: [][]string{ + {"MIN(value)", "MIN(id)"}, + {"Value1", "1"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionCount(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, 'Value3');", + // TODO: Add test case mentioned in comment below once inserting + // null values will be added + //"INSERT INTO table1 VALUES(NULL, NULL);", + }, + selectInput: "SELECT COUNT(*), COUNT(id), COUNT(value) FROM table1;", + expectedOutput: [][]string{ + {"COUNT(*)", "COUNT(id)", "COUNT(value)"}, + {"3", "3", "3"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionSum(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, 'Value3');", + }, + selectInput: "SELECT SUM(id), SUM(value) FROM table1;", + expectedOutput: [][]string{ + {"SUM(id)", "SUM(value)"}, + {"6", "0"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionAvg(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, 'Value3');", + }, + selectInput: "SELECT AVG(id), AVG(value) FROM table1;", + expectedOutput: [][]string{ + {"AVG(id)", "AVG(value)"}, + {"2", "0"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionWithColumnSelection(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, 'Value3');", + }, + selectInput: "SELECT AVG(id), id FROM table1;", + expectedOutput: [][]string{ + {"AVG(id)", "id"}, + {"2", "1"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionWithColumnSelectionAndOrderBy(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, 'Value3');", + }, + selectInput: "SELECT MAX(id), id FROM table1 ORDER BY id DESC;", + expectedOutput: [][]string{ + {"MAX(id)", "id"}, + {"3", "3"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + type engineDBContentTestSuite struct { inputs []string expectedTableNames []string diff --git a/engine/generic_value.go b/engine/generic_value.go index 2364dd4..c6fd07c 100644 --- a/engine/generic_value.go +++ b/engine/generic_value.go @@ -1,6 +1,8 @@ package engine import ( + "errors" + "fmt" "log" "strconv" ) @@ -36,6 +38,20 @@ type StringValue struct { type NullValue struct { } +// HandleValue - Function to take an instance of ValueInterface and cast to a specific implementation +func CastValueInterface(v ValueInterface) { + switch value := v.(type) { + case IntegerValue: + fmt.Printf("IntegerValue with Value: %d\n", value.Value) + case StringValue: + fmt.Printf("StringValue with Value: %s\n", value.Value) + case NullValue: + fmt.Println("NullValue (no value)") + default: + fmt.Println("Unknown type") + } +} + // ToString implementations func (value IntegerValue) ToString() string { return strconv.Itoa(value.Value) } func (value StringValue) ToString() string { return value.Value } @@ -121,3 +137,32 @@ func (value NullValue) isGreaterThan(secondValue ValueInterface) bool { func areEqual(first ValueInterface, second ValueInterface) bool { return first.GetType() == second.GetType() && first.ToString() == second.ToString() } + +func getMin(values []ValueInterface) (ValueInterface, error) { + if len(values) == 0 { + return nil, errors.New("can't extract min from empty array") + } + minValue := values[0] + + for _, value := range values[1:] { + if value.isSmallerThan(minValue) { + minValue = value + } + } + return minValue, nil +} + +func getMax(values []ValueInterface) (ValueInterface, error) { + if len(values) == 0 { + return nil, errors.New("can't extract max from empty array") + } + + maxValue := values[0] + for _, value := range values[1:] { + if value.isGreaterThan(maxValue) { + maxValue = value + } + } + + return maxValue, nil +} diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index 478660c..66a6294 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -302,6 +302,46 @@ func TestLimitAndOffsetStatement(t *testing.T) { runLexerTestSuite(t, input, tests) } +func TestAggregateFunctions(t *testing.T) { + input := `SELECT MIN(colOne), MAX(colOne), COUNT(colOne), SUM(colOne), AVG(colOne) FROM tbl;` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.MIN, "MIN"}, + {token.LPAREN, "("}, + {token.IDENT, "colOne"}, + {token.RPAREN, ")"}, + {token.COMMA, ","}, + {token.MAX, "MAX"}, + {token.LPAREN, "("}, + {token.IDENT, "colOne"}, + {token.RPAREN, ")"}, + {token.COMMA, ","}, + {token.COUNT, "COUNT"}, + {token.LPAREN, "("}, + {token.IDENT, "colOne"}, + {token.RPAREN, ")"}, + {token.COMMA, ","}, + {token.SUM, "SUM"}, + {token.LPAREN, "("}, + {token.IDENT, "colOne"}, + {token.RPAREN, ")"}, + {token.COMMA, ","}, + {token.AVG, "AVG"}, + {token.LPAREN, "("}, + {token.IDENT, "colOne"}, + {token.RPAREN, ")"}, + {token.FROM, "FROM"}, + {token.IDENT, "tbl"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + func TestSelectWithDistinct(t *testing.T) { input := `SELECT DISTINCT * FROM table;` tests := []struct { diff --git a/parser/parser.go b/parser/parser.go index ec0e43d..59d2f09 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -236,23 +236,47 @@ func (parser *Parser) parseSelectCommand() (ast.Command, error) { parser.nextToken() } - err := validateToken(parser.currentToken.Type, []token.Type{token.ASTERISK, token.IDENT}) + err := validateToken(parser.currentToken.Type, []token.Type{token.ASTERISK, token.IDENT, token.MAX, token.MIN, token.SUM, token.AVG, token.COUNT}) if err != nil { return nil, err } + if parser.currentToken.Type == token.ASTERISK { - selectCommand.Space = append(selectCommand.Space, parser.currentToken) + selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken}) parser.nextToken() - } else { - for parser.currentToken.Type == token.IDENT { - // Get column name - err := validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) - if err != nil { - return nil, err + for parser.currentToken.Type == token.IDENT || isAggregateFunction(parser.currentToken.Type) { + if parser.currentToken.Type != token.IDENT { + aggregateFunction := parser.currentToken + parser.nextToken() + err := validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + if err != nil { + return nil, err + } + if aggregateFunction.Type == token.COUNT { + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.ASTERISK}) + } else { + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + } + if err != nil { + return nil, err + } + selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken, AggregateFunc: &aggregateFunction}) + parser.nextToken() + + err = validateTokenAndSkip(parser, []token.Type{token.RPAREN}) + if err != nil { + return nil, err + } + } else { + // Get column name + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken}) + parser.nextToken() } - selectCommand.Space = append(selectCommand.Space, parser.currentToken) - parser.nextToken() if parser.currentToken.Type != token.COMMA { break @@ -289,6 +313,21 @@ func (parser *Parser) parseSelectCommand() (ast.Command, error) { return selectCommand, nil } +func (parser *Parser) getColumnName(err error, selectCommand *ast.SelectCommand, aggregateFunction token.Token) error { + // Get column name + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.ASTERISK}) + if err != nil { + return err + } + selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken, AggregateFunc: &aggregateFunction}) + parser.nextToken() + return nil +} + +func isAggregateFunction(t token.Type) bool { + return t == token.MIN || t == token.MAX || t == token.COUNT || t == token.SUM || t == token.AVG +} + // parseWhereCommand - Return ast.WhereCommand created from tokens and validate the syntax // // Example of input parsable to the ast.WhereCommand: diff --git a/parser/parser_error_handling_test.go b/parser/parser_error_handling_test.go index c01cfee..bd51d49 100644 --- a/parser/parser_error_handling_test.go +++ b/parser/parser_error_handling_test.go @@ -95,15 +95,23 @@ func TestParseUpdateCommandErrorHandling(t *testing.T) { func TestParseSelectCommandErrorHandling(t *testing.T) { noFromKeyword := SyntaxError{[]string{token.FROM}, token.IDENT} - noColumns := SyntaxError{[]string{token.ASTERISK, token.IDENT}, token.FROM} + noColumns := SyntaxError{[]string{token.ASTERISK, token.IDENT, token.MAX, token.MIN, token.SUM, token.AVG, token.COUNT}, token.FROM} noTableName := SyntaxError{[]string{token.IDENT}, token.SEMICOLON} noSemicolon := SyntaxError{[]string{token.SEMICOLON, token.WHERE, token.ORDER, token.LIMIT, token.OFFSET, token.JOIN, token.LEFT, token.RIGHT, token.INNER, token.FULL}, ""} + noAggregateFunctionParenClosure := SyntaxError{[]string{token.RPAREN}, ","} + noAggregateFunctionLeftParen := SyntaxError{[]string{token.LPAREN}, token.IDENT} + noFromAfterAsterisk := SyntaxError{[]string{token.FROM}, ","} + noAsteriskInsideMaxArgument := SyntaxError{[]string{token.IDENT}, "*"} tests := []errorHandlingTestSuite{ {"SELECT column1, column2 tbl;", noFromKeyword.Error()}, {"SELECT FROM table;", noColumns.Error()}, {"SELECT column1, column2 FROM ;", noTableName.Error()}, {"SELECT column1, column2 FROM table", noSemicolon.Error()}, + {"SELECT SUM(column1, column2 FROM table", noAggregateFunctionParenClosure.Error()}, + {"SELECT SUM column1 FROM table", noAggregateFunctionLeftParen.Error()}, + {"SELECT *, colName FROM table", noFromAfterAsterisk.Error()}, + {"SELECT MAX(*) FROM table", noAsteriskInsideMaxArgument.Error()}, } runParserErrorHandlingSuite(t, tests) diff --git a/parser/parser_test.go b/parser/parser_test.go index 06923ae..e6ccaf1 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -126,12 +126,12 @@ func TestParseSelectCommand(t *testing.T) { tests := []struct { input string expectedTableName string - expectedColumns []token.Token + expectedSpaces []ast.Space expectedDistinct bool }{ - {"SELECT * FROM TBL;", "TBL", []token.Token{{Type: token.ASTERISK, Literal: "*"}}, false}, - {"SELECT ONE, TWO, THREE FROM TBL;", "TBL", []token.Token{{Type: token.IDENT, Literal: "ONE"}, {Type: token.IDENT, Literal: "TWO"}, {Type: token.IDENT, Literal: "THREE"}}, false}, - {"SELECT DISTINCT * FROM TBL;", "TBL", []token.Token{{Type: token.ASTERISK, Literal: "*"}}, true}, + {"SELECT * FROM TBL;", "TBL", []ast.Space{{ColumnName: token.Token{Type: token.ASTERISK, Literal: "*"}}}, false}, + {"SELECT ONE, TWO, THREE FROM TBL;", "TBL", []ast.Space{{ColumnName: token.Token{Type: token.IDENT, Literal: "ONE"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "TWO"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "THREE"}}}, false}, + {"SELECT DISTINCT * FROM TBL;", "TBL", []ast.Space{{ColumnName: token.Token{Type: token.ASTERISK, Literal: "*"}}}, true}, } for testIndex, tt := range tests { @@ -146,7 +146,7 @@ func TestParseSelectCommand(t *testing.T) { t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) } - if !testSelectStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedColumns, tt.expectedDistinct) { + if !testSelectStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedSpaces, tt.expectedDistinct) { return } } @@ -290,7 +290,7 @@ func TestSelectWithOrderByCommand(t *testing.T) { SortPatterns: []ast.SortPattern{expectedSortPattern}, } expectedTableName := "tableName" - expectedColumnName := []token.Token{{Type: token.ASTERISK, Literal: "*"}} + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.ASTERISK, Literal: "*"}}} lexer := lexer.RunLexer(input) parserInstance := New(lexer) @@ -305,7 +305,7 @@ func TestSelectWithOrderByCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { return } @@ -323,7 +323,7 @@ func TestSelectWithLimitCommand(t *testing.T) { Count: 5, } expectedTableName := "tableName" - expectedColumnName := []token.Token{{Type: token.ASTERISK, Literal: "*"}} + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.ASTERISK, Literal: "*"}}} lexer := lexer.RunLexer(input) parserInstance := New(lexer) @@ -338,7 +338,7 @@ func TestSelectWithLimitCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { return } @@ -357,7 +357,7 @@ func TestSelectWithOffsetCommand(t *testing.T) { } expectedTableName := "tableName" - expectedColumnName := []token.Token{{Type: token.ASTERISK, Literal: "*"}} + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.ASTERISK, Literal: "*"}}} lexer := lexer.RunLexer(input) parserInstance := New(lexer) @@ -372,7 +372,7 @@ func TestSelectWithOffsetCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { return } @@ -393,7 +393,7 @@ func TestSelectWithLimitAndOffsetCommand(t *testing.T) { Count: 13, } expectedTableName := "tableName" - expectedColumnName := []token.Token{{Type: token.ASTERISK, Literal: "*"}} + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.ASTERISK, Literal: "*"}}} lexer := lexer.RunLexer(input) parserInstance := New(lexer) @@ -408,7 +408,7 @@ func TestSelectWithLimitAndOffsetCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { return } @@ -436,7 +436,7 @@ func TestSelectWithDefaultInnerJoinCommand(t *testing.T) { }, } expectedTableName := "tbl" - expectedColumnName := []token.Token{{Type: token.IDENT, Literal: "tbl.one"}, {Type: token.IDENT, Literal: "tbl2.two"}} + expectedSpace := []ast.Space{{ColumnName: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "tbl2.two"}}} lexer := lexer.RunLexer(input) parserInstance := New(lexer) @@ -451,7 +451,7 @@ func TestSelectWithDefaultInnerJoinCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpace, false) { return } @@ -475,7 +475,7 @@ func TestSelectWithInnerJoinCommand(t *testing.T) { }, } expectedTableName := "tbl" - expectedColumnName := []token.Token{{Type: token.IDENT, Literal: "tbl.one"}, {Type: token.IDENT, Literal: "tbl2.two"}} + expectedSpace := []ast.Space{{ColumnName: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "tbl2.two"}}} lexer := lexer.RunLexer(input) parserInstance := New(lexer) @@ -490,7 +490,7 @@ func TestSelectWithInnerJoinCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpace, false) { return } @@ -514,7 +514,7 @@ func TestSelectWithLeftJoinCommand(t *testing.T) { }, } expectedTableName := "tbl" - expectedColumnName := []token.Token{{Type: token.IDENT, Literal: "tbl.one"}, {Type: token.IDENT, Literal: "tbl2.two"}} + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "tbl2.two"}}} lexer := lexer.RunLexer(input) parserInstance := New(lexer) @@ -529,7 +529,7 @@ func TestSelectWithLeftJoinCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { return } @@ -553,7 +553,7 @@ func TestSelectWithRightJoinCommand(t *testing.T) { }, } expectedTableName := "tbl" - expectedColumnName := []token.Token{{Type: token.IDENT, Literal: "tbl.one"}, {Type: token.IDENT, Literal: "tbl2.two"}} + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "tbl2.two"}}} lexer := lexer.RunLexer(input) parserInstance := New(lexer) @@ -568,7 +568,7 @@ func TestSelectWithRightJoinCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { return } @@ -592,7 +592,7 @@ func TestSelectWithFullJoinCommand(t *testing.T) { }, } expectedTableName := "tbl" - expectedColumnName := []token.Token{{Type: token.IDENT, Literal: "tbl.one"}, {Type: token.IDENT, Literal: "tbl2.two"}} + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "tbl2.two"}}} lexer := lexer.RunLexer(input) parserInstance := New(lexer) @@ -607,7 +607,7 @@ func TestSelectWithFullJoinCommand(t *testing.T) { selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if !testSelectStatement(t, selectCommand, expectedTableName, expectedColumnName, false) { + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { return } @@ -618,6 +618,55 @@ func TestSelectWithFullJoinCommand(t *testing.T) { testJoinCommands(t, expectedJoinCommand, *selectCommand.JoinCommand) } +func TestSelectWithAggregateFunctions(t *testing.T) { + input := "SELECT MIN(colOne), MAX(colOne), COUNT(*), COUNT(colOne), SUM(colOne), AVG(colOne) FROM tbl;" + + expectedTableName := "tbl" + expectedSpaces := []ast.Space{ + { + ColumnName: token.Token{Type: token.IDENT, Literal: "colOne"}, + AggregateFunc: &token.Token{Type: token.MIN, Literal: "MIN"}, + }, + { + ColumnName: token.Token{Type: token.ASTERISK, Literal: "colOne"}, + AggregateFunc: &token.Token{Type: token.MAX, Literal: "MAX"}, + }, + { + ColumnName: token.Token{Type: token.IDENT, Literal: "*"}, + AggregateFunc: &token.Token{Type: token.COUNT, Literal: "COUNT"}, + }, + { + ColumnName: token.Token{Type: token.IDENT, Literal: "colOne"}, + AggregateFunc: &token.Token{Type: token.COUNT, Literal: "COUNT"}, + }, + { + ColumnName: token.Token{Type: token.IDENT, Literal: "colOne"}, + AggregateFunc: &token.Token{Type: token.SUM, Literal: "SUM"}, + }, + { + ColumnName: token.Token{Type: token.IDENT, Literal: "colOne"}, + AggregateFunc: &token.Token{Type: token.AVG, Literal: "AVG"}, + }, + } + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { + return + } +} + func TestParseUpdateCommand(t *testing.T) { tests := []struct { input string @@ -778,7 +827,7 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { } } -func testSelectStatement(t *testing.T, command ast.Command, expectedTableName string, expectedColumnsTokens []token.Token, expectedDistinct bool) bool { +func testSelectStatement(t *testing.T, command ast.Command, expectedTableName string, expectedSpaces []ast.Space, expectedDistinct bool) bool { if command.TokenLiteral() != "SELECT" { t.Errorf("command.TokenLiteral() not 'SELECT'. got=%q", command.TokenLiteral()) return false @@ -800,8 +849,8 @@ func testSelectStatement(t *testing.T, command ast.Command, expectedTableName st return false } - if !tokenArrayEquals(actualSelectCommand.Space, expectedColumnsTokens) { - t.Errorf("actualSelectCommand has diffrent space tan expected. %v != %v", actualSelectCommand.Space, expectedColumnsTokens) + if !spaceArrayEquals(actualSelectCommand.Space, expectedSpaces) { + t.Errorf("actualSelectCommand has diffrent space than expected. %+v != %+v", actualSelectCommand.Space, expectedSpaces) return false } @@ -875,6 +924,24 @@ func tokenArrayEquals(a []token.Token, b []token.Token) bool { return true } +func spaceArrayEquals(a []ast.Space, b []ast.Space) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v.ColumnName.Literal != b[i].ColumnName.Literal { + return false + } + if v.ContainsAggregateFunc() != b[i].ContainsAggregateFunc() { + return false + } + if v.ContainsAggregateFunc() && b[i].ContainsAggregateFunc() && v.AggregateFunc.Literal != b[i].AggregateFunc.Literal { + return false + } + } + return true +} + func tokenMapEquals(a map[token.Token]ast.Anonymitifier, b map[token.Token]ast.Anonymitifier) bool { if len(a) != len(b) { return false diff --git a/test_file b/test_file index aac029a..35ec3d2 100644 --- a/test_file +++ b/test_file @@ -17,13 +17,20 @@ INSERT INTO tbl VALUES( 'goodbye', 5, 22, 'P' ); SELECT DISTINCT * FROM tbl; DROP TABLE tbl; - CREATE TABLE table1( id INT, value TEXT); - CREATE TABLE table2( id INT, value TEXT); - INSERT INTO table1 VALUES(1, 'Value1'); - INSERT INTO table1 VALUES(2, 'Value2'); - INSERT INTO table2 VALUES(2, 'Value2'); - INSERT INTO table2 VALUES(3, 'Value3'); - SELECT table1.value, table2.value FROM table1 FULL JOIN table2 ON table1.id EQUAL table2.id; - SELECT table1.value, table2.value FROM table1 INNER JOIN table2 ON table1.id EQUAL table2.id; - SELECT table1.value, table2.value FROM table1 LEFT JOIN table2 ON table1.id EQUAL table2.id; - SELECT table1.value, table2.value FROM table1 RIGHT JOIN table2 ON table1.id EQUAL table2.id; + CREATE TABLE table1( id INT, value TEXT); + CREATE TABLE table2( id INT, value TEXT); + INSERT INTO table1 VALUES(1, 'Value1'); + INSERT INTO table1 VALUES(2, 'Value2'); + INSERT INTO table2 VALUES(2, 'Value2'); + INSERT INTO table2 VALUES(3, 'Value3'); + SELECT table1.value, table2.value FROM table1 FULL JOIN table2 ON table1.id EQUAL table2.id; + SELECT table1.value, table2.value FROM table1 INNER JOIN table2 ON table1.id EQUAL table2.id; + SELECT table1.value, table2.value FROM table1 LEFT JOIN table2 ON table1.id EQUAL table2.id; + SELECT table1.value, table2.value FROM table1 RIGHT JOIN table2 ON table1.id EQUAL table2.id; + INSERT INTO table1 VALUES(3, 'Value3'); + SELECT MAX(id), MAX(value) FROM table1; + SELECT MIN(value), MIN(id) FROM table1; + SELECT COUNT(*), COUNT(id), COUNT(value) FROM table1; + SELECT SUM(id), SUM(value) FROM table1; + SELECT AVG(id), AVG(value) FROM table1; + SELECT AVG(id), id FROM table1; \ No newline at end of file diff --git a/token/token.go b/token/token.go index 5da87ef..b02df85 100644 --- a/token/token.go +++ b/token/token.go @@ -54,6 +54,11 @@ const ( LEFT = "LEFT" RIGHT = "RIGHT" ON = "ON" + MIN = "MIN" + MAX = "MAX" + COUNT = "COUNT" + SUM = "SUM" + AVG = "AVG" TO = "TO" @@ -99,6 +104,11 @@ var keywords = map[string]Type{ "RIGHT": RIGHT, "JOIN": JOIN, "ON": ON, + "MIN": MIN, + "MAX": MAX, + "COUNT": COUNT, + "SUM": SUM, + "AVG": AVG, "TO": TO, "VALUES": VALUES, "WHERE": WHERE, From 364685834483c6a3f7c04cd886cd448937bfa9db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sara=20Ryfczy=C5=84ska?= Date: Tue, 19 Nov 2024 23:40:34 +0100 Subject: [PATCH 15/21] Feature - "IN" and "NOTIN" condition (#24) * Parser, lexer, token and ast impl * engine, parser err handling, readme, tests for IN --- .github/expected_results/end2end.txt | 11 ++++ README.md | 19 ++++++ ast/ast.go | 14 +++++ engine/engine.go | 32 +++++++++- engine/engine_test.go | 84 ++++++++++++++++++++++++- lexer/lexer_test.go | 38 +++++++++++ parser/parser.go | 94 +++++++++++++++++++++++++--- parser/parser_error_handling_test.go | 12 +++- parser/parser_test.go | 60 ++++++++++++++++++ test_file | 2 + token/token.go | 4 ++ 11 files changed, 358 insertions(+), 12 deletions(-) diff --git a/.github/expected_results/end2end.txt b/.github/expected_results/end2end.txt index 68f7159..3f6d4e1 100644 --- a/.github/expected_results/end2end.txt +++ b/.github/expected_results/end2end.txt @@ -18,6 +18,17 @@ Data Inserted +----------+-----+-------+------+ | 'byebye' | 3 | 33 | 'e' | +----------+-----+-------+------+ ++-----------+-----+-------+------+ +| one | two | three | four | ++-----------+-----+-------+------+ +| 'goodbye' | 1 | 22 | 'w' | +| 'byebye' | 3 | 33 | 'e' | ++-----------+-----+-------+------+ ++---------+-----+-------+------+ +| one | two | three | four | ++---------+-----+-------+------+ +| 'hello' | 1 | 11 | 'q' | ++---------+-----+-------+------+ +-----+-----+-------+------+ | one | two | three | four | +-----+-----+-------+------+ diff --git a/README.md b/README.md index f0d7904..d95c5ee 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,25 @@ go clean -testcache; go test ./... ``` Supported logical operations are: ``EQUAL``, ``NOT``, ``OR``, ``AND``, ```FALSE```, ```TRUE```. +* ***IN*** - is used to check if a value from a column exists in a specified list of values. + It can be used with ``WHERE`` like this: + ```sql + SELECT column1, column2 + FROM table_name + WHERE column1 IN ('value1', 'value2'); + ``` + ``table_name`` is the name of the table, and ``WHERE`` returns rows that value is either equal to + ``value1`` or ``value2`` + +* ***NOTIN*** - is used to check if a value from a column doesn't exist in a specified list of + values. It can be used with ``WHERE`` like this: + ```sql + SELECT column1, column2 + FROM table_name + WHERE column1 NOTIN ('value1', 'value2'); + ``` + ``table_name`` is the name of the table, and ``WHERE`` returns rows which values are not equal to + ``value1`` and not equal to ``value2`` * ***DELETE FROM*** is used to delete existing records in a table. It can be used like this: ```sql diff --git a/ast/ast.go b/ast/ast.go index 5e3a434..5aa7657 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -107,6 +107,20 @@ func (ls ConditionExpression) GetIdentifiers() []Identifier { return identifiers } +// ContainExpression - TokenType of Expression that represents structure for IN operator +// +// Example: +// colName IN ('value1', 'value2', 'value3') +type ContainExpression struct { + Left Identifier // name of column + Right []Anonymitifier // name of column + Contains bool // IN or NOTIN +} + +func (ls ContainExpression) GetIdentifiers() []Identifier { + return []Identifier{ls.Left} +} + // OperationExpression - TokenType of Expression that represent 2 other Expressions and conditional operation // // Example: diff --git a/engine/engine.go b/engine/engine.go index 5874d78..37cb9d8 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -131,7 +131,6 @@ func (engine *DbEngine) getSelectResponse(selectCommand *ast.SelectCommand) (*Ta return nil, err } } else { - // panic: runtime error: invalid memory address or nil pointer dereference [recovered] table, err = engine.selectFromProvidedTable(selectCommand, table) if err != nil { return nil, err @@ -676,6 +675,9 @@ func isFulfillingFilters(row map[string]ValueInterface, expressionTree ast.Expre return processBooleanExpression(mappedExpression) case *ast.ConditionExpression: return processConditionExpression(row, mappedExpression, commandName) + case *ast.ContainExpression: + return processContainExpression(row, mappedExpression) + default: return false, &UnsupportedExpressionTypeError{commandName: commandName, variable: fmt.Sprintf("%s", mappedExpression)} } @@ -702,6 +704,34 @@ func processConditionExpression(row map[string]ValueInterface, conditionExpressi } } +func processContainExpression(row map[string]ValueInterface, containExpression *ast.ContainExpression) (bool, error) { + valueLeft, err := getTifierValue(containExpression.Left, row) + if err != nil { + return false, err + } + + result, err := ifValueInterfaceInArray(containExpression.Right, valueLeft) + + if containExpression.Contains { + return result, err + } + + return !result, err +} + +func ifValueInterfaceInArray(array []ast.Anonymitifier, valueLeft ValueInterface) (bool, error) { + for _, expectedValue := range array { + value, err := getInterfaceValue(expectedValue.Token) + if err != nil { + return false, err + } + if value.IsEqual(valueLeft) { + return true, nil + } + } + return false, nil +} + func processOperationExpression(row map[string]ValueInterface, operationExpression *ast.OperationExpression, commandName string) (bool, error) { if operationExpression.Operation.Type == token.AND { left, err := isFulfillingFilters(row, operationExpression.Left, commandName) diff --git a/engine/engine_test.go b/engine/engine_test.go index 6320ca7..32e6a0f 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -126,6 +126,88 @@ func TestSelectWithWhereNotEqual(t *testing.T) { engineTestSuite.runTestSuite(t) } +func TestSelectWithWhereContains(t *testing.T) { + + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + }, + selectInput: "SELECT one, two, three, four FROM tb1 WHERE three IN (11, 22, 67);", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + {"hello", "1", "11", "q"}, + {"goodbye", "2", "22", "w"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestSelectWithWhereNotContains(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + }, + selectInput: "SELECT one, two, three, four FROM tb1 WHERE one NOTIN ('hello', 'byebye', 'youAreTheBest');", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + {"goodbye", "2", "22", "w"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestSelectWithWhereContainsButResponseIsEmpty(t *testing.T) { + + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + }, + selectInput: "SELECT one, two, three, four FROM tb1 WHERE one IN ('I', 'dont', 'exist', 'anymore');", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestSelectWithWhereNotContainsButResponseIsEmpty(t *testing.T) { + + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + }, + selectInput: "SELECT one, two, three, four FROM tb1 WHERE two NOTIN (1, 2, 3, 4);", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + func TestSelectWithWhereLogicalOperationAnd(t *testing.T) { engineTestSuite := engineTableContentTestSuite{ @@ -180,7 +262,7 @@ func TestSelectWithWhereLogicalOperationOROperationAND(t *testing.T) { "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", "INSERT INTO tb1 VALUES( 'goodbye', 3, 33, 'e' );", }, - selectInput: "SELECT * FROM tb1 WHERE one NOT 'goodbye' OR two EQUAL 3 AND four EQUAL 'e';", + selectInput: "SELECT * FROM tb1 WHERE one NOT 'goodbye' OR two IN (3) AND four EQUAL 'e';", expectedOutput: [][]string{ {"one", "two", "three", "four"}, {"hello", "1", "11", "q"}, diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index 66a6294..4cf8644 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -226,6 +226,44 @@ func TestLogicalStatements(t *testing.T) { runLexerTestSuite(t, input, tests) } +func TestInStatement(t *testing.T) { + input := + ` + WHERE two IN (1, 2) AND + WHERE three NOTIN ('one', 'two'); + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.WHERE, "WHERE"}, + {token.IDENT, "two"}, + {token.IN, "IN"}, + {token.LPAREN, "("}, + {token.LITERAL, "1"}, + {token.COMMA, ","}, + {token.LITERAL, "2"}, + {token.RPAREN, ")"}, + {token.AND, "AND"}, + {token.WHERE, "WHERE"}, + {token.IDENT, "three"}, + {token.NOTIN, "NOTIN"}, + {token.LPAREN, "("}, + {token.APOSTROPHE, "'"}, + {token.IDENT, "one"}, + {token.APOSTROPHE, "'"}, + {token.COMMA, ","}, + {token.APOSTROPHE, "'"}, + {token.IDENT, "two"}, + {token.APOSTROPHE, "'"}, + {token.RPAREN, ")"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + func TestDeleteStatement(t *testing.T) { input := `DELETE FROM table WHERE two NOT 11 OR TRUE;` tests := []struct { diff --git a/parser/parser.go b/parser/parser.go index 59d2f09..c4d6da3 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -681,6 +681,7 @@ func (parser *Parser) parseUpdateCommand() (ast.Command, error) { // - ast.OperationExpression // - ast.BooleanExpression // - ast.ConditionExpression +// - ast.ContainExpression func (parser *Parser) getExpression() (bool, ast.Expression, error) { booleanExpressionExists, booleanExpression := parser.getBooleanExpression() @@ -689,7 +690,12 @@ func (parser *Parser) getExpression() (bool, ast.Expression, error) { return false, nil, err } - operationExpressionExists, operationExpression, err := parser.getOperationExpression(booleanExpressionExists, conditionalExpressionExists, booleanExpression, conditionalExpression) + containExpressionExists, containExpression, err := parser.getContainExpression() + if err != nil { + return false, nil, err + } + + operationExpressionExists, operationExpression, err := parser.getOperationExpression(booleanExpressionExists, conditionalExpressionExists, containExpressionExists, booleanExpression, conditionalExpression, containExpression) if err != nil { return false, nil, err } @@ -702,6 +708,10 @@ func (parser *Parser) getExpression() (bool, ast.Expression, error) { return true, conditionalExpression, err } + if containExpressionExists { + return true, containExpression, err + } + if booleanExpressionExists { return true, booleanExpression, err } @@ -710,10 +720,10 @@ func (parser *Parser) getExpression() (bool, ast.Expression, error) { } // getOperationExpression - Return ast.OperationExpression created from tokens and validate the syntax -func (parser *Parser) getOperationExpression(booleanExpressionExists bool, conditionalExpressionExists bool, booleanExpression *ast.BooleanExpression, conditionalExpression *ast.ConditionExpression) (bool, *ast.OperationExpression, error) { +func (parser *Parser) getOperationExpression(booleanExpressionExists bool, conditionalExpressionExists bool, containExpressionExists bool, booleanExpression *ast.BooleanExpression, conditionalExpression *ast.ConditionExpression, containExpression *ast.ContainExpression) (bool, *ast.OperationExpression, error) { operationExpression := &ast.OperationExpression{} - if (booleanExpressionExists || conditionalExpressionExists) && (parser.currentToken.Type == token.OR || parser.currentToken.Type == token.AND) { + if (booleanExpressionExists || conditionalExpressionExists || containExpressionExists) && (parser.currentToken.Type == token.OR || parser.currentToken.Type == token.AND) { if booleanExpressionExists { operationExpression.Left = booleanExpression } @@ -722,6 +732,10 @@ func (parser *Parser) getOperationExpression(booleanExpressionExists bool, condi operationExpression.Left = conditionalExpression } + if containExpressionExists { + operationExpression.Left = containExpression + } + operationExpression.Operation = parser.currentToken parser.nextToken() @@ -760,6 +774,12 @@ func (parser *Parser) getBooleanExpression() (bool, *ast.BooleanExpression) { func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression, error) { conditionalExpression := &ast.ConditionExpression{} + err := validateToken(parser.peekToken.Type, []token.Type{token.EQUAL, token.NOT}) + if err != nil { + return false, nil, nil + } + conditionalExpression.Condition = parser.peekToken + switch parser.currentToken.Type { case token.IDENT: conditionalExpression.Left = ast.Identifier{Token: parser.currentToken} @@ -779,11 +799,7 @@ func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression return false, conditionalExpression, nil } - err := validateToken(parser.currentToken.Type, []token.Type{token.EQUAL, token.NOT}) - if err != nil { - return false, nil, err - } - conditionalExpression.Condition = parser.currentToken + // skip EQUAL or NOT parser.nextToken() switch parser.currentToken.Type { @@ -808,6 +824,68 @@ func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression return true, conditionalExpression, nil } +// getContainExpression - Return ast.ContainExpression created from tokens and validate the syntax +func (parser *Parser) getContainExpression() (bool, *ast.ContainExpression, error) { + containExpression := &ast.ContainExpression{} + + err := validateToken(parser.peekToken.Type, []token.Type{token.IN, token.NOTIN}) + if err != nil { + return false, nil, nil + } + if parser.peekToken.Type == token.IN { + containExpression.Contains = true + } else { + containExpression.Contains = false + } + + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return false, nil, nil + } + containExpression.Left = ast.Identifier{Token: parser.currentToken} + + parser.nextToken() + + // skip IN or NOTIN + parser.nextToken() + + err = validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + if err != nil { + return false, nil, err + } + + for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.APOSTROPHE { + parser.skipIfCurrentTokenIsApostrophe() + + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL}) + if err != nil { + return false, nil, err + } + containExpression.Right = append(containExpression.Right, ast.Anonymitifier{Token: parser.currentToken}) + // Ignore token.IDENT or token.LITERAL + parser.nextToken() + + parser.skipIfCurrentTokenIsApostrophe() + + if parser.currentToken.Type != token.COMMA { + if parser.currentToken.Type != token.RPAREN { + return false, nil, &SyntaxError{expecting: []string{token.COMMA, token.RPAREN}, got: string(parser.currentToken.Type)} + } + break + } + + // Ignore token.COMMA + parser.nextToken() + } + + err = validateTokenAndSkip(parser, []token.Type{token.RPAREN}) + if err != nil { + return false, nil, err + } + + return true, containExpression, err +} + // ParseSequence - Return ast.Sequence (sequence of commands) created from client input after tokenization // // Parse tokens returned by lexer to structures defines in ast package, and it's responsible for syntax validation. diff --git a/parser/parser_error_handling_test.go b/parser/parser_error_handling_test.go index bd51d49..1dd0112 100644 --- a/parser/parser_error_handling_test.go +++ b/parser/parser_error_handling_test.go @@ -121,21 +121,29 @@ func TestParseWhereCommandErrorHandling(t *testing.T) { selectCommandPrefix := "SELECT * FROM tbl " noPredecessorError := NoPredecessorParserError{command: token.WHERE} noColName := LogicalExpressionParsingError{} - notOrEqualIsMissing := SyntaxError{expecting: []string{token.EQUAL, token.NOT}, got: token.APOSTROPHE} + noOperatorInsideWhereStatementException := LogicalExpressionParsingError{} valueIsMissing := SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL}, got: token.SEMICOLON} tokenAnd := token.AND conjunctionIsMissing := SyntaxError{expecting: []string{token.SEMICOLON, token.ORDER}, got: token.IDENT} nextLogicalExpressionIsMissing := LogicalExpressionParsingError{afterToken: &tokenAnd} noSemicolon := SyntaxError{expecting: []string{token.SEMICOLON, token.ORDER}, got: ""} + noLeftParGotSemicolon := SyntaxError{expecting: []string{token.LPAREN}, got: ";"} + noLeftParGotNumber := SyntaxError{expecting: []string{token.LPAREN}, got: token.LITERAL} + noComma := SyntaxError{expecting: []string{token.COMMA, token.RPAREN}, got: token.LITERAL} + noInKeywordException := LogicalExpressionParsingError{} tests := []errorHandlingTestSuite{ {"WHERE col1 NOT 'goodbye' OR col2 EQUAL 3;", noPredecessorError.Error()}, {selectCommandPrefix + "WHERE NOT 'goodbye' OR column2 EQUAL 3;", noColName.Error()}, - {selectCommandPrefix + "WHERE one 'goodbye';", notOrEqualIsMissing.Error()}, + {selectCommandPrefix + "WHERE one 'goodbye';", noOperatorInsideWhereStatementException.Error()}, {selectCommandPrefix + "WHERE one EQUAL;", valueIsMissing.Error()}, {selectCommandPrefix + "WHERE one EQUAL 5 two NOT 1;", conjunctionIsMissing.Error()}, {selectCommandPrefix + "WHERE one EQUAL 5 AND;", nextLogicalExpressionIsMissing.Error()}, {selectCommandPrefix + "WHERE one EQUAL 5 AND two NOT 5", noSemicolon.Error()}, + {selectCommandPrefix + "WHERE one IN ;", noLeftParGotSemicolon.Error()}, + {selectCommandPrefix + "WHERE one IN 5;", noLeftParGotNumber.Error()}, + {selectCommandPrefix + "WHERE one IN (5 6);", noComma.Error()}, + {selectCommandPrefix + "WHERE one (5, 6);", noInKeywordException.Error()}, } runParserErrorHandlingSuite(t, tests) diff --git a/parser/parser_test.go b/parser/parser_test.go index e6ccaf1..42d8f3c 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -165,6 +165,24 @@ func TestParseWhereCommand(t *testing.T) { Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, } + thirdExpression := ast.ContainExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName3"}}, + Right: []ast.Anonymitifier{ + {Token: token.Token{Type: token.LITERAL, Literal: "1"}}, + {Token: token.Token{Type: token.LITERAL, Literal: "2"}}, + }, + Contains: true, + } + + fourthExpression := ast.ContainExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName4"}}, + Right: []ast.Anonymitifier{ + {Token: token.Token{Type: token.IDENT, Literal: "one"}}, + {Token: token.Token{Type: token.IDENT, Literal: "two"}}, + }, + Contains: false, + } + tests := []struct { input string expectedExpression ast.Expression @@ -177,6 +195,14 @@ func TestParseWhereCommand(t *testing.T) { input: "SELECT * FROM TBL WHERE colName2 EQUAL 6462389;", expectedExpression: secondExpression, }, + { + input: "SELECT * FROM TBL WHERE colName3 IN (1, 2);", + expectedExpression: thirdExpression, + }, + { + input: "SELECT * FROM TBL WHERE colName4 NOTIN ('one', 'two');", + expectedExpression: fourthExpression, + }, } for testIndex, tt := range tests { @@ -1034,6 +1060,11 @@ func expressionsAreEqual(first ast.Expression, second ast.Expression) bool { return validateOperationExpression(second, operationExpression) } + containExpression, containExpressionIsValid := first.(*ast.ContainExpression) + if containExpressionIsValid { + return validateContainExpression(second, containExpression) + } + return false } @@ -1051,6 +1082,35 @@ func validateOperationExpression(second ast.Expression, operationExpression *ast return expressionsAreEqual(operationExpression.Left, secondOperationExpression.Left) && expressionsAreEqual(operationExpression.Right, secondOperationExpression.Right) } +func validateContainExpression(expression ast.Expression, exptectedContainExpression *ast.ContainExpression) bool { + actualContainExpression, actualContainExpressionIsValid := expression.(ast.ContainExpression) + + if !actualContainExpressionIsValid { + return false + } + + if exptectedContainExpression.Contains != actualContainExpression.Contains { + return false + } + + if actualContainExpression.Left.GetToken().Literal != exptectedContainExpression.Left.GetToken().Literal && + actualContainExpression.Left.IsIdentifier() == exptectedContainExpression.Left.IsIdentifier() { + return false + } + + if len(exptectedContainExpression.Right) != len(actualContainExpression.Right) { + return false + } + + for i := 0; i < len(exptectedContainExpression.Right); i++ { + if exptectedContainExpression.Right[i] != actualContainExpression.Right[i] { + return false + } + } + + return true +} + func validateConditionExpression(second ast.Expression, conditionExpression *ast.ConditionExpression) bool { secondConditionExpression, secondConditionExpressionIsValid := second.(ast.ConditionExpression) diff --git a/test_file b/test_file index 35ec3d2..bf97956 100644 --- a/test_file +++ b/test_file @@ -5,6 +5,8 @@ SELECT * FROM tbl WHERE one EQUAL 'byebye'; SELECT one, three FROM tbl WHERE two NOT 3; SELECT * FROM tbl WHERE one NOT 'goodbye' AND two EQUAL 3; + SELECT * FROM tbl WHERE one IN ('goodbye', 'byebye'); + SELECT * FROM tbl WHERE one NOTIN ('goodbye', 'byebye'); SELECT * FROM tbl WHERE FALSE; SELECT * FROM tbl LIMIT 1; SELECT * FROM tbl OFFSET 1; diff --git a/token/token.go b/token/token.go index b02df85..049baa1 100644 --- a/token/token.go +++ b/token/token.go @@ -59,6 +59,8 @@ const ( COUNT = "COUNT" SUM = "SUM" AVG = "AVG" + IN = "IN" + NOTIN = "NOTIN" TO = "TO" @@ -109,6 +111,8 @@ var keywords = map[string]Type{ "COUNT": COUNT, "SUM": SUM, "AVG": AVG, + "IN": IN, + "NOTIN": NOTIN, "TO": TO, "VALUES": VALUES, "WHERE": WHERE, From af34c3d869a4cff92546150285c19ce91a1bc9a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Krupski?= <34219324+ixior462@users.noreply.github.com> Date: Wed, 4 Dec 2024 01:32:59 +0100 Subject: [PATCH 16/21] Feature/null insertion (#25) * Tokens have been added * Parser and lexer logic has been updated. * Introduced a new section in README 'supported types' * Add engine integration with NULL values --------- Co-authored-by: LissaGreense --- .github/expected_results/end2end.txt | 81 +++++++++++------------ .github/workflows/docker-publish.yml | 6 +- README.md | 90 +++++++++++++++++-------- engine/engine.go | 12 ++-- engine/engine_test.go | 98 +++++++++++++++++++--------- engine/engine_utils.go | 2 + engine/generic_value.go | 38 +++++++---- engine/generic_value_test.go | 57 ++++++++++++---- lexer/lexer_test.go | 23 +++++++ parser/parser.go | 24 ++++--- parser/parser_error_handling_test.go | 6 +- parser/parser_test.go | 56 +++++++++++----- test_file | 10 +-- token/token.go | 2 + 14 files changed, 341 insertions(+), 164 deletions(-) diff --git a/.github/expected_results/end2end.txt b/.github/expected_results/end2end.txt index 3f6d4e1..e0e7f34 100644 --- a/.github/expected_results/end2end.txt +++ b/.github/expected_results/end2end.txt @@ -2,28 +2,28 @@ Table 'tbl' has been created Data Inserted Data Inserted Data Inserted -+----------+-----+-------+------+ -| one | two | three | four | -+----------+-----+-------+------+ -| 'byebye' | 3 | 33 | 'e' | -+----------+-----+-------+------+ ++----------+------+-------+------+ +| one | two | three | four | ++----------+------+-------+------+ +| 'byebye' | NULL | 33 | 'e' | ++----------+------+-------+------+ +-----------+-------+ | one | three | +-----------+-------+ | 'hello' | 11 | | 'goodbye' | 22 | +-----------+-------+ -+----------+-----+-------+------+ -| one | two | three | four | -+----------+-----+-------+------+ -| 'byebye' | 3 | 33 | 'e' | -+----------+-----+-------+------+ -+-----------+-----+-------+------+ -| one | two | three | four | -+-----------+-----+-------+------+ -| 'goodbye' | 1 | 22 | 'w' | -| 'byebye' | 3 | 33 | 'e' | -+-----------+-----+-------+------+ ++----------+------+-------+------+ +| one | two | three | four | ++----------+------+-------+------+ +| 'byebye' | NULL | 33 | 'e' | ++----------+------+-------+------+ ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'goodbye' | 1 | 22 | 'w' | +| 'byebye' | NULL | 33 | 'e' | ++-----------+------+-------+------+ +---------+-----+-------+------+ | one | two | three | four | +---------+-----+-------+------+ @@ -38,12 +38,12 @@ Data Inserted +---------+-----+-------+------+ | 'hello' | 1 | 11 | 'q' | +---------+-----+-------+------+ -+-----------+-----+-------+------+ -| one | two | three | four | -+-----------+-----+-------+------+ -| 'goodbye' | 1 | 22 | 'w' | -| 'byebye' | 3 | 33 | 'e' | -+-----------+-----+-------+------+ ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'goodbye' | 1 | 22 | 'w' | +| 'byebye' | NULL | 33 | 'e' | ++-----------+------+-------+------+ +-----------+-----+-------+------+ | one | two | three | four | +-----------+-----+-------+------+ @@ -63,19 +63,20 @@ Data from 'tbl' has been deleted | 'hello' | +-----------+ Table: 'tbl' has been updated -+-----------+-----+-------+------+ -| one | two | three | four | -+-----------+-----+-------+------+ -| 'hello' | 1 | 11 | 'q' | -| 'goodbye' | 5 | 22 | 'P' | -+-----------+-----+-------+------+ ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'hello' | 1 | 11 | 'q' | +| 'goodbye' | NULL | 22 | 'P' | ++-----------+------+-------+------+ Data Inserted -+-----------+-----+-------+------+ -| one | two | three | four | -+-----------+-----+-------+------+ -| 'hello' | 1 | 11 | 'q' | -| 'goodbye' | 5 | 22 | 'P' | -+-----------+-----+-------+------+ ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'hello' | 1 | 11 | 'q' | +| 'goodbye' | NULL | 22 | 'P' | +| 'goodbye' | 5 | 22 | 'P' | ++-----------+------+-------+------+ Table: 'tbl' has been dropped Table 'table1' has been created Table 'table2' has been created @@ -87,24 +88,24 @@ Data Inserted | table1.value | table2.value | +--------------+--------------+ | 'Value1' | NULL | -| 'Value2' | 'Value2' | +| NULL | 'Value2' | | NULL | 'Value3' | +--------------+--------------+ +--------------+--------------+ | table1.value | table2.value | +--------------+--------------+ -| 'Value2' | 'Value2' | +| NULL | 'Value2' | +--------------+--------------+ +--------------+--------------+ | table1.value | table2.value | +--------------+--------------+ | 'Value1' | NULL | -| 'Value2' | 'Value2' | +| NULL | 'Value2' | +--------------+--------------+ +--------------+--------------+ | table1.value | table2.value | +--------------+--------------+ -| 'Value2' | 'Value2' | +| NULL | 'Value2' | | NULL | 'Value3' | +--------------+--------------+ Data Inserted @@ -116,12 +117,12 @@ Data Inserted +------------+---------+ | MIN(value) | MIN(id) | +------------+---------+ -| Value1 | 1 | +| NULL | 1 | +------------+---------+ +----------+-----------+--------------+ | COUNT(*) | COUNT(id) | COUNT(value) | +----------+-----------+--------------+ -| 3 | 3 | 3 | +| 3 | 3 | 2 | +----------+-----------+--------------+ +---------+------------+ | SUM(id) | SUM(value) | diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index 2b66e31..caff5a1 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -17,19 +17,19 @@ jobs: packages: write id-token: write steps: - + - name: Checkout repository uses: actions/checkout@v3 - name: Docker build run: docker build . --tag ${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }} - + - name: Docker login run: docker login -u ${{ secrets.DOCKERHUB_USERNAME }} -p ${{ secrets.DOCKERHUB_TOKEN }} - name: Docker tag run: docker image tag ${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }} ${{ env.REGISTRY }}/${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }} - + - name: Docker push run: docker push ${{ env.REGISTRY }}/${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }} diff --git a/README.md b/README.md index d95c5ee..1bb4c95 100644 --- a/README.md +++ b/README.md @@ -22,14 +22,18 @@ You can compile the project with ``go build``, this will create ``GO4SQL`` binar Currently, there are 3 modes to chose from: -1. `File Mode` - You can specify file path with ``./GO4SQL -file file_path``, that will read the input +1. `File Mode` - You can specify file path with ``./GO4SQL -file file_path``, that will read the + input data directly into the program and print the result. -2. `Stream Mode` - With ``./GO4SQL -stream`` you can run the program in stream mode, then you provide SQL commands +2. `Stream Mode` - With ``./GO4SQL -stream`` you can run the program in stream mode, then you + provide SQL commands in your console (from standard input). -3. `Socket Mode` - To start Socket Server use `./GO4SQL -socket`, it will be listening on port `1433` by default. To - choose port different other than that, for example equal to `1444`, go with: `./GO4SQL -socket -port 1444` +3. `Socket Mode` - To start Socket Server use `./GO4SQL -socket`, it will be listening on port + `1433` by default. To + choose port different other than that, for example equal to `1444`, go with: + `./GO4SQL -socket -port 1444` ## UNIT TESTS @@ -44,7 +48,20 @@ go clean -testcache; go test ./... 1. Pull docker image: `docker pull kajedot/go4sql:latest` 2. Run docker container in the interactive mode, remember to provide flag, for example: `docker run -i kajedot/go4sql -stream` -3. You can test this image with `test_file` provided in this repo: `docker run -i kajedot/go4sql -stream < test_file` +3. You can test this image with `test_file` provided in this repo: + `docker run -i kajedot/go4sql -stream < test_file` + +## SUPPORTED TYPES + ++ **TEXT Type** - represents string values. Number or NULL can be converted to this type by wrapping + with apostrophes. Columns can store this type with **TEXT** keyword while using **CREATE** + command. ++ **NUMERIC Type** - represents integer values, columns can store this type with **INT** keyword + while using **CREATE** command. In general every digit-only value is interpreted as this type. ++ **NULL Type** - columns can't be assigned that type, but it can be used with **INSERT INTO**, + **UPDATE**, and inside **WHERE** statements, also it can be a product of **JOIN** commands + (besides **FULL JOIN**). In GO4SQL NULL is the smallest possible value, what means it can be + compared with other types with **EQUAL** and **NOT** statements. ## FUNCTIONALITY @@ -147,7 +164,8 @@ go clean -testcache; go test ./... ORDER BY column1 ASC LIMIT 5; ``` - In this case, this command will order by ``column1`` in ascending order and return 5 first records. + In this case, this command will order by ``column1`` in ascending order and return 5 first + records. * ***OFFSET*** is used to reduce number of rows printed out by not skipping specified numbers of @@ -169,8 +187,10 @@ go clean -testcache; go test ./... ``` In this case, this command will return only unique rows from ``table_name`` table. -* ***INNER JOIN*** is used to return a new table by combining rows from both tables where there is a match on the - specified condition. Only the rows that satisfy the condition from both tables are included in the result. +* ***INNER JOIN*** is used to return a new table by combining rows from both tables where there is a + match on the + specified condition. Only the rows that satisfy the condition from both tables are included in the + result. Rows from either table that do not meet the condition are excluded from the result. ```sql SELECT * @@ -185,38 +205,50 @@ go clean -testcache; go test ./... INNER JOIN tableTwo ON tableOne.columnY EQUAL tableTwo.columnX; ``` - In this case, this command will return all columns from tableOne and tableTwo for rows where the condition - ``tableOne.columnY`` = ``tableTwo.columnX`` is met (i.e., the value of ``columnY`` in ``tableOne`` is equal to the + In this case, this command will return all columns from tableOne and tableTwo for rows where the + condition + ``tableOne.columnY`` = ``tableTwo.columnX`` is met (i.e., the value of ``columnY`` in ``tableOne`` + is equal to the value of ``columnX`` in ``tableTwo``). -* ***LEFT JOIN*** is used to return a new table that includes all records from the left table and the matched records - from the right table. If there is no match, the result will contain empty values for columns from the right table. +* ***LEFT JOIN*** is used to return a new table that includes all records from the left table and + the matched records + from the right table. If there is no match, the result will contain empty values for columns from + the right table. ```sql SELECT * FROM tableOne LEFT JOIN tableTwo ON tableOne.columnY EQUAL tableTwo.columnX; ``` - In this case, this command will return all columns from ``tableOne`` and the matching columns from ``tableTwo``. For + In this case, this command will return all columns from ``tableOne`` and the matching columns from + ``tableTwo``. For rows in - ``tableOne`` that do not have a corresponding match in ``tableTwo``, the result will include empty values for columns + ``tableOne`` that do not have a corresponding match in ``tableTwo``, the result will include empty + values for columns from ``tableTwo``. -* ***RIGHT JOIN*** is used to return a new table that includes all records from the right table and the matched records - from the left table. If there is no match, the result will contain empty values for columns from the left table. +* ***RIGHT JOIN*** is used to return a new table that includes all records from the right table and + the matched records + from the left table. If there is no match, the result will contain empty values for columns from + the left table. ```sql SELECT * FROM tableOne RIGHT JOIN tableTwo ON tableOne.columnY EQUAL tableTwo.columnX; ``` - In this case, this command will return all columns from ``tableTwo`` and the matching columns from ``tableOne``. For + In this case, this command will return all columns from ``tableTwo`` and the matching columns from + ``tableOne``. For rows in - ``tableTwo`` that do not have a corresponding match in ``tableOne``, the result will include empty values for columns + ``tableTwo`` that do not have a corresponding match in ``tableOne``, the result will include empty + values for columns from ``tableOne``. -* ***FULL JOIN*** is used to return a new table created by joining two tables as a whole. The joined table contains all - records from both tables and fills empty values for missing matches on either side. This join combines the results of +* ***FULL JOIN*** is used to return a new table created by joining two tables as a whole. The + joined table contains all + records from both tables and fills empty values for missing matches on either side. This join + combines the results of both ``LEFT JOIN`` and ``RIGHT JOIN``. ```sql SELECT * @@ -224,8 +256,10 @@ go clean -testcache; go test ./... FULL JOIN tableTwo ON tableOne.columnY EQUAL tableTwo.columnX; ``` - In this case, this command will return all columns from ``tableOne`` and ``tableTwo`` for rows fulfilling condition - ``tableOne.columnY EQUAL tableTwo.columnX`` (value of ``columnY`` in ``tableOne`` is equal the value of ``columnX`` in + In this case, this command will return all columns from ``tableOne`` and ``tableTwo`` for rows + fulfilling condition + ``tableOne.columnY EQUAL tableTwo.columnX`` (value of ``columnY`` in ``tableOne`` is equal the + value of ``columnX`` in ``tableTwo``). * ***MIN()*** is used to return the smallest value in a specified column. @@ -233,7 +267,8 @@ go clean -testcache; go test ./... SELECT MIN(columnName) FROM tableName; ``` - In this case, this command will return the smallest value found in the column ``columnName`` of ``tableName``. + In this case, this command will return the smallest value found in the column ``columnName`` of + ``tableName``. * ***MAX()*** is used to return the largest value in a specified column. ```sql @@ -242,7 +277,8 @@ go clean -testcache; go test ./... ``` This command will return the largest value found in the column ``columnName`` of ``tableName``. -* ***COUNT()*** is used to return the number of rows that match a given condition or the total number of rows in a +* ***COUNT()*** is used to return the number of rows that match a given condition or the total + number of rows in a specified column. ```sql SELECT COUNT(columnName) @@ -255,14 +291,16 @@ go clean -testcache; go test ./... SELECT SUM(columnName) FROM tableName; ``` - This command will return the total sum of all values in the numerical column ``columnName`` of ``tableName``. + This command will return the total sum of all values in the numerical column ``columnName`` of + ``tableName``. * ***AVG()*** is used to return the average of values in a specified numerical column. ```sql SELECT AVG(columnName) FROM tableName; ``` - This command will return the average of all values in the numerical column ``columnName`` of ``tableName``. + This command will return the average of all values in the numerical column ``columnName`` of + ``tableName``. ## E2E TEST diff --git a/engine/engine.go b/engine/engine.go index 37cb9d8..4ddf188 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -229,7 +229,7 @@ func (engine *DbEngine) insertIntoTable(command *ast.InsertCommand) error { for i := range columns { expectedToken := tokenMapper(columns[i].Type.Type) - if expectedToken != command.Values[i].Type { + if (expectedToken != command.Values[i].Type) && (command.Values[i].Type != token.NULL) { return &InvalidValueTypeError{expectedType: string(expectedToken), actualType: string(command.Values[i].Type), commandName: command.Token.Literal} } interfaceValue, err := getInterfaceValue(command.Values[i]) @@ -355,11 +355,13 @@ func aggregateColumnContent(space ast.Space, columnValues []ValueInterface) (Val } else { sum := 0 for _, value := range columnValues { - num, err := strconv.Atoi(value.ToString()) - if err != nil { - return nil, err + if value.GetType() != NullType { + num, err := strconv.Atoi(value.ToString()) + if err != nil { + return nil, err + } + sum += num } - sum += num } return IntegerValue{Value: sum}, nil } diff --git a/engine/engine_test.go b/engine/engine_test.go index 32e6a0f..4e1111d 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -47,15 +47,15 @@ func TestSelectCommand(t *testing.T) { }, insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", - "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", - "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, NULL );", + "INSERT INTO tb1 VALUES( 'byebye', NULL, 33, 'e' );", }, selectInput: "SELECT * FROM tb1;", expectedOutput: [][]string{ {"one", "two", "three", "four"}, {"hello", "1", "11", "q"}, - {"goodbye", "2", "22", "w"}, - {"byebye", "3", "33", "e"}, + {"goodbye", "2", "22", "NULL"}, + {"byebye", "NULL", "33", "e"}, }, } @@ -112,10 +112,10 @@ func TestSelectWithWhereNotEqual(t *testing.T) { }, insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", - "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, NULL, 'w' );", "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", }, - selectInput: "SELECT one, two, three, four FROM tb1 WHERE three NOT 22;", + selectInput: "SELECT one, two, three, four FROM tb1 WHERE three NOT NULL;", expectedOutput: [][]string{ {"one", "two", "three", "four"}, {"hello", "1", "11", "q"}, @@ -134,14 +134,14 @@ func TestSelectWithWhereContains(t *testing.T) { }, insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", - "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, NULL, 'w' );", "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", }, - selectInput: "SELECT one, two, three, four FROM tb1 WHERE three IN (11, 22, 67);", + selectInput: "SELECT one, two, three, four FROM tb1 WHERE three IN (11, NULL, 67);", expectedOutput: [][]string{ {"one", "two", "three", "four"}, {"hello", "1", "11", "q"}, - {"goodbye", "2", "22", "w"}, + {"goodbye", "2", "NULL", "w"}, }, } @@ -216,10 +216,10 @@ func TestSelectWithWhereLogicalOperationAnd(t *testing.T) { }, insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", - "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'goodbye', NULL, 22, 'w' );", "INSERT INTO tb1 VALUES( 'goodbye', 3, 33, 'e' );", }, - selectInput: "SELECT * FROM tb1 WHERE one EQUAL 'goodbye' AND two NOT 2;", + selectInput: "SELECT * FROM tb1 WHERE one EQUAL 'goodbye' AND two NOT NULL;", expectedOutput: [][]string{ {"one", "two", "three", "four"}, {"goodbye", "3", "33", "e"}, @@ -260,13 +260,13 @@ func TestSelectWithWhereLogicalOperationOROperationAND(t *testing.T) { insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", - "INSERT INTO tb1 VALUES( 'goodbye', 3, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 3, 33, NULL );", }, - selectInput: "SELECT * FROM tb1 WHERE one NOT 'goodbye' OR two IN (3) AND four EQUAL 'e';", + selectInput: "SELECT * FROM tb1 WHERE one NOT 'goodbye' OR two IN (3) AND four EQUAL NULL;", expectedOutput: [][]string{ {"one", "two", "three", "four"}, {"hello", "1", "11", "q"}, - {"goodbye", "3", "33", "e"}, + {"goodbye", "3", "33", "NULL"}, }, } @@ -323,14 +323,15 @@ func TestDistinctSelect(t *testing.T) { insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", - "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", - "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'goodbye', NULL, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'goodbye', NULL, 22, 'w' );", "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", }, selectInput: "SELECT DISTINCT * FROM tb1;", expectedOutput: [][]string{ {"one", "two", "three", "four"}, {"hello", "1", "11", "q"}, + {"goodbye", "NULL", "22", "w"}, {"goodbye", "2", "22", "w"}, }, } @@ -369,14 +370,14 @@ func TestUpdateWithWhere(t *testing.T) { insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", - "UPDATE tb1 SET one TO 'hi hello', three TO 5 WHERE two EQUAL 3;", + "UPDATE tb1 SET one TO 'hi hello', three TO NULL WHERE two EQUAL 3;", "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", }, selectInput: "SELECT one, two, three, four FROM tb1;", expectedOutput: [][]string{ {"one", "two", "three", "four"}, {"hello", "1", "11", "q"}, - {"hi hello", "3", "5", "e"}, + {"hi hello", "3", "NULL", "e"}, {"goodbye", "2", "22", "w"}, }, } @@ -414,13 +415,13 @@ func TestOrderBy(t *testing.T) { }, insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", - "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'byebye', NULL, 33, 'e' );", "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", }, selectInput: "SELECT one, two, three, four FROM tb1 ORDER BY two ASC;", expectedOutput: [][]string{ {"one", "two", "three", "four"}, - {"byebye", "1", "33", "e"}, + {"byebye", "NULL", "33", "e"}, {"goodbye", "2", "22", "w"}, {"hello", "3", "11", "q"}, }, @@ -639,8 +640,10 @@ func TestDefaultJoinToBehaveLikeInnerJoin(t *testing.T) { "INSERT INTO books VALUES(2, 'Fire');", "INSERT INTO books VALUES(1, 'Earth');", "INSERT INTO books VALUES(1, 'Air');", + "INSERT INTO books VALUES(3, 'Smoke');", "INSERT INTO authors VALUES( 1, 'Reynold Boyka' );", "INSERT INTO authors VALUES( 2, 'Alissa Ireneus' );", + "INSERT INTO authors VALUES( 3, NULL );", }, selectInput: "SELECT books.title, authors.name FROM books JOIN authors ON books.author_id EQUAL authors.author_id;", expectedOutput: [][]string{ @@ -648,6 +651,7 @@ func TestDefaultJoinToBehaveLikeInnerJoin(t *testing.T) { {"Fire", "Alissa Ireneus"}, {"Earth", "Reynold Boyka"}, {"Air", "Reynold Boyka"}, + {"Smoke", "NULL"}, }, } @@ -689,14 +693,14 @@ func TestFullJoinOnIdenticalTables(t *testing.T) { "INSERT INTO table1 VALUES(1, 'Value1');", "INSERT INTO table1 VALUES(2, 'Value2');", "INSERT INTO table2 VALUES(2, 'Value2');", - "INSERT INTO table2 VALUES(3, 'Value3');", + "INSERT INTO table2 VALUES(3, NULL);", }, selectInput: "SELECT table1.value, table2.value FROM table1 FULL JOIN table2 ON table1.id EQUAL table2.id;", expectedOutput: [][]string{ {"table1.value", "table2.value"}, {"Value1", "NULL"}, {"Value2", "Value2"}, - {"NULL", "Value3"}, + {"NULL", "NULL"}, }, } @@ -712,13 +716,16 @@ func TestInnerJoinWithSpecifiedKeywordOnIdenticalTables(t *testing.T) { insertAndDeleteInputs: []string{ "INSERT INTO table1 VALUES(1, 'Value1');", "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(NULL, NULL);", "INSERT INTO table2 VALUES(2, 'Value2');", "INSERT INTO table2 VALUES(3, 'Value3');", + "INSERT INTO table2 VALUES(NULL, 'Value4');", }, selectInput: "SELECT table1.value, table2.value FROM table1 INNER JOIN table2 ON table1.id EQUAL table2.id;", expectedOutput: [][]string{ {"table1.value", "table2.value"}, {"Value2", "Value2"}, + {"NULL", "Value4"}, }, } @@ -734,14 +741,17 @@ func TestLeftJoinOnIdenticalTables(t *testing.T) { insertAndDeleteInputs: []string{ "INSERT INTO table1 VALUES(1, 'Value1');", "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(NULL, 'Value4');", "INSERT INTO table2 VALUES(2, 'Value2');", "INSERT INTO table2 VALUES(3, 'Value3');", + "INSERT INTO table2 VALUES(NULL, NULL);", }, selectInput: "SELECT table1.value, table2.value FROM table1 LEFT JOIN table2 ON table1.id EQUAL table2.id;", expectedOutput: [][]string{ {"table1.value", "table2.value"}, {"Value1", "NULL"}, {"Value2", "Value2"}, + {"Value4", "NULL"}, }, } @@ -757,14 +767,17 @@ func TestRightJoinOnIdenticalTables(t *testing.T) { insertAndDeleteInputs: []string{ "INSERT INTO table1 VALUES(1, 'Value1');", "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(NULL, NULL);", "INSERT INTO table2 VALUES(2, 'Value2');", "INSERT INTO table2 VALUES(3, 'Value3');", + "INSERT INTO table2 VALUES(NULL, 'Value4');", }, selectInput: "SELECT table1.value, table2.value FROM table1 RIGHT JOIN table2 ON table1.id EQUAL table2.id;", expectedOutput: [][]string{ {"table1.value", "table2.value"}, {"Value2", "Value2"}, {"NULL", "Value3"}, + {"NULL", "Value4"}, }, } @@ -798,8 +811,9 @@ func TestAggregateFunctionMin(t *testing.T) { insertAndDeleteInputs: []string{ "INSERT INTO table1 VALUES(1, 'Value1');", "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, NULL);", }, - selectInput: "SELECT MIN(value), MIN(id) FROM table1;", + selectInput: "SELECT MIN(value), MIN(id) FROM table1 WHERE value NOT NULL;", expectedOutput: [][]string{ {"MIN(value)", "MIN(id)"}, {"Value1", "1"}, @@ -809,6 +823,26 @@ func TestAggregateFunctionMin(t *testing.T) { engineTestSuite.runTestSuite(t) } +func TestAggregateFunctionMinWithNull(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, NULL);", + }, + selectInput: "SELECT MIN(value), MIN(id) FROM table1;", + expectedOutput: [][]string{ + {"MIN(value)", "MIN(id)"}, + {"NULL", "1"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + func TestAggregateFunctionCount(t *testing.T) { engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ @@ -818,14 +852,12 @@ func TestAggregateFunctionCount(t *testing.T) { "INSERT INTO table1 VALUES(1, 'Value1');", "INSERT INTO table1 VALUES(2, 'Value2');", "INSERT INTO table1 VALUES(3, 'Value3');", - // TODO: Add test case mentioned in comment below once inserting - // null values will be added - //"INSERT INTO table1 VALUES(NULL, NULL);", + "INSERT INTO table1 VALUES(NULL, NULL);", }, selectInput: "SELECT COUNT(*), COUNT(id), COUNT(value) FROM table1;", expectedOutput: [][]string{ {"COUNT(*)", "COUNT(id)", "COUNT(value)"}, - {"3", "3", "3"}, + {"4", "3", "3"}, }, } @@ -841,6 +873,7 @@ func TestAggregateFunctionSum(t *testing.T) { "INSERT INTO table1 VALUES(1, 'Value1');", "INSERT INTO table1 VALUES(2, 'Value2');", "INSERT INTO table1 VALUES(3, 'Value3');", + "INSERT INTO table1 VALUES(NULL, 'Value4');", }, selectInput: "SELECT SUM(id), SUM(value) FROM table1;", expectedOutput: [][]string{ @@ -861,11 +894,12 @@ func TestAggregateFunctionAvg(t *testing.T) { "INSERT INTO table1 VALUES(1, 'Value1');", "INSERT INTO table1 VALUES(2, 'Value2');", "INSERT INTO table1 VALUES(3, 'Value3');", + "INSERT INTO table1 VALUES(10, NULL);", }, selectInput: "SELECT AVG(id), AVG(value) FROM table1;", expectedOutput: [][]string{ {"AVG(id)", "AVG(value)"}, - {"2", "0"}, + {"4", "0"}, }, } @@ -880,12 +914,13 @@ func TestAggregateFunctionWithColumnSelection(t *testing.T) { insertAndDeleteInputs: []string{ "INSERT INTO table1 VALUES(1, 'Value1');", "INSERT INTO table1 VALUES(2, 'Value2');", - "INSERT INTO table1 VALUES(3, 'Value3');", + "INSERT INTO table1 VALUES(3, NULL);", + "INSERT INTO table1 VALUES(6, 'Value3');", }, selectInput: "SELECT AVG(id), id FROM table1;", expectedOutput: [][]string{ {"AVG(id)", "id"}, - {"2", "1"}, + {"3", "1"}, }, } @@ -901,11 +936,12 @@ func TestAggregateFunctionWithColumnSelectionAndOrderBy(t *testing.T) { "INSERT INTO table1 VALUES(1, 'Value1');", "INSERT INTO table1 VALUES(2, 'Value2');", "INSERT INTO table1 VALUES(3, 'Value3');", + "INSERT INTO table1 VALUES(4, NULL);", }, selectInput: "SELECT MAX(id), id FROM table1 ORDER BY id DESC;", expectedOutput: [][]string{ {"MAX(id)", "id"}, - {"3", "3"}, + {"4", "4"}, }, } diff --git a/engine/engine_utils.go b/engine/engine_utils.go index 3486447..69329d2 100644 --- a/engine/engine_utils.go +++ b/engine/engine_utils.go @@ -8,6 +8,8 @@ import ( func getInterfaceValue(t token.Token) (ValueInterface, error) { switch t.Type { + case token.NULL: + return NullValue{}, nil case token.LITERAL: castedInteger, err := strconv.Atoi(t.Literal) if err != nil { diff --git a/engine/generic_value.go b/engine/generic_value.go index c6fd07c..d5c26c9 100644 --- a/engine/generic_value.go +++ b/engine/generic_value.go @@ -75,8 +75,12 @@ func (value NullValue) IsEqual(valueInterface ValueInterface) bool { // isSmallerThan implementations func (value IntegerValue) isSmallerThan(secondValue ValueInterface) bool { - secondValueAsInteger, isInteger := secondValue.(IntegerValue) + nullValue, isNull := secondValue.(NullValue) + if isNull { + return nullValue.isGreaterThan(value) + } + secondValueAsInteger, isInteger := secondValue.(IntegerValue) if !isInteger { log.Fatal("Can't compare Integer with other type") } @@ -85,8 +89,12 @@ func (value IntegerValue) isSmallerThan(secondValue ValueInterface) bool { } func (value StringValue) isSmallerThan(secondValue ValueInterface) bool { - secondValueAsString, isString := secondValue.(StringValue) + nullValue, isNull := secondValue.(NullValue) + if isNull { + return nullValue.isGreaterThan(value) + } + secondValueAsString, isString := secondValue.(StringValue) if !isString { log.Fatal("Can't compare String with other type") } @@ -97,8 +105,8 @@ func (value StringValue) isSmallerThan(secondValue ValueInterface) bool { func (value NullValue) isSmallerThan(secondValue ValueInterface) bool { _, isNull := secondValue.(NullValue) - if !isNull { - log.Fatal("Can't compare Null with other type") + if isNull { + return false } return true @@ -106,8 +114,12 @@ func (value NullValue) isSmallerThan(secondValue ValueInterface) bool { // isGreaterThan implementations func (value IntegerValue) isGreaterThan(secondValue ValueInterface) bool { - secondValueAsInteger, isInteger := secondValue.(IntegerValue) + nullValue, isNull := secondValue.(NullValue) + if isNull { + return nullValue.isSmallerThan(value) + } + secondValueAsInteger, isInteger := secondValue.(IntegerValue) if !isInteger { log.Fatal("Can't compare Integer with other type") } @@ -115,8 +127,12 @@ func (value IntegerValue) isGreaterThan(secondValue ValueInterface) bool { return value.Value > secondValueAsInteger.Value } func (value StringValue) isGreaterThan(secondValue ValueInterface) bool { - secondValueAsString, isString := secondValue.(StringValue) + nullValue, isNull := secondValue.(NullValue) + if isNull { + return nullValue.isSmallerThan(value) + } + secondValueAsString, isString := secondValue.(StringValue) if !isString { log.Fatal("Can't compare String with other type") } @@ -124,14 +140,8 @@ func (value StringValue) isGreaterThan(secondValue ValueInterface) bool { return value.Value > secondValueAsString.Value } -func (value NullValue) isGreaterThan(secondValue ValueInterface) bool { - _, isNull := secondValue.(NullValue) - - if !isNull { - log.Fatal("Can't compare Null with other type") - } - - return true +func (value NullValue) isGreaterThan(_ ValueInterface) bool { + return false } func areEqual(first ValueInterface, second ValueInterface) bool { diff --git a/engine/generic_value_test.go b/engine/generic_value_test.go index 1198e0d..5b53198 100644 --- a/engine/generic_value_test.go +++ b/engine/generic_value_test.go @@ -6,7 +6,7 @@ import ( func TestIsGreaterThan(t *testing.T) { oneInt := IntegerValue{ - Value: 1, + Value: 0, } twoInt := IntegerValue{ Value: 2, @@ -21,11 +21,11 @@ func TestIsGreaterThan(t *testing.T) { twoNull := NullValue{} if oneInt.isGreaterThan(twoInt) { - t.Errorf("1 shouldn't be greater than 2") + t.Errorf("0 shouldn't be greater than 2") } if !twoInt.isGreaterThan(oneInt) { - t.Errorf("1 shouldn't be greater than 2") + t.Errorf("0 shouldn't be greater than 2") } if oneString.isGreaterThan(twoString) { @@ -36,14 +36,30 @@ func TestIsGreaterThan(t *testing.T) { t.Errorf("1 shouldn't be greater than 2") } - if !twoNull.isGreaterThan(oneNull) { - t.Errorf("null to null operations should always return true") + if twoNull.isGreaterThan(oneNull) { + t.Errorf("null is not greater than null") + } + + if !oneInt.isGreaterThan(oneNull) { + t.Errorf("Any Int value cannot be smaller than null") + } + + if !oneString.isGreaterThan(oneNull) { + t.Errorf("Any String value cannot be smaller than null") + } + + if oneNull.isGreaterThan(oneInt) { + t.Errorf("Null cannot be greater than any int value") + } + + if oneNull.isGreaterThan(oneString) { + t.Errorf("Null cannot be greater than any string value") } } func TestIsSmallerThan(t *testing.T) { oneInt := IntegerValue{ - Value: 1, + Value: 0, } twoInt := IntegerValue{ Value: 2, @@ -58,11 +74,11 @@ func TestIsSmallerThan(t *testing.T) { twoNull := NullValue{} if !oneInt.isSmallerThan(twoInt) { - t.Errorf("1 should be smaller than 2") + t.Errorf("0 should be smaller than 2") } if twoInt.isSmallerThan(oneInt) { - t.Errorf("1 should be smaller than 2") + t.Errorf("0 should be smaller than 2") } if !oneString.isSmallerThan(twoString) { @@ -73,8 +89,24 @@ func TestIsSmallerThan(t *testing.T) { t.Errorf("1 should be smaller than 2") } - if !twoNull.isSmallerThan(oneNull) { - t.Errorf("null to null operations should always return true") + if twoNull.isSmallerThan(oneNull) { + t.Errorf("null is not smaller than null") + } + + if oneInt.isSmallerThan(oneNull) { + t.Errorf("Any int value cannot be smaller than null") + } + + if oneString.isSmallerThan(oneNull) { + t.Errorf("Any string value cannot be smaller than null") + } + + if !oneNull.isSmallerThan(oneInt) { + t.Errorf("Null cannot be greater than any int value") + } + + if !oneNull.isSmallerThan(oneString) { + t.Errorf("Null cannot be greater than any string value") } } @@ -96,10 +128,13 @@ func TestEquals(t *testing.T) { shouldBeEqual(t, oneInt, oneInt) shouldBeEqual(t, oneString, oneString) + shouldBeEqual(t, oneNull, twoNull) shouldNotBeEqual(t, oneInt, twoInt) shouldNotBeEqual(t, oneString, twoString) shouldNotBeEqual(t, oneString, oneInt) - shouldBeEqual(t, oneNull, twoNull) + shouldNotBeEqual(t, oneNull, oneInt) + shouldNotBeEqual(t, oneNull, oneString) + shouldNotBeEqual(t, twoInt, twoNull) } func shouldBeEqual(t *testing.T, valueOne ValueInterface, valueTwo ValueInterface) { diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index 4cf8644..0100f29 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -532,6 +532,29 @@ func TestFullJoin(t *testing.T) { runLexerTestSuite(t, input, tests) } +func TestHandlingNullValues(t *testing.T) { + input := `INSERT INTO tbl VALUES( 'NULL', NULL );` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.INSERT, "INSERT"}, + {token.INTO, "INTO"}, + {token.IDENT, "tbl"}, + {token.VALUES, "VALUES"}, + {token.LPAREN, "("}, + {token.APOSTROPHE, "'"}, + {token.IDENT, "NULL"}, + {token.APOSTROPHE, "'"}, + {token.COMMA, ","}, + {token.NULL, "NULL"}, + {token.RPAREN, ")"}, + {token.SEMICOLON, ";"}, + } + + runLexerTestSuite(t, input, tests) +} + func runLexerTestSuite(t *testing.T, input string, tests []struct { expectedType token.Type expectedLiteral string diff --git a/parser/parser.go b/parser/parser.go index c4d6da3..18a90f8 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -183,15 +183,15 @@ func (parser *Parser) parseInsertCommand() (ast.Command, error) { return nil, err } - for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.APOSTROPHE { + for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { parser.skipIfCurrentTokenIsApostrophe() - err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL}) + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL, token.NULL}) if err != nil { return nil, err } insertCommand.Values = append(insertCommand.Values, parser.currentToken) - // Ignore token.IDENT or token.LITERAL + // Ignore token.IDENT, token.LITERAL or token.NULL parser.nextToken() parser.skipIfCurrentTokenIsApostrophe() @@ -649,13 +649,13 @@ func (parser *Parser) parseUpdateCommand() (ast.Command, error) { parser.nextToken() parser.skipIfCurrentTokenIsApostrophe() - err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL}) + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL, token.NULL}) if err != nil { return nil, err } updateCommand.Changes[colKey] = ast.Anonymitifier{Token: parser.currentToken} - // skip token.IDENT or token.LITERAL + // skip token.IDENT, token.LITERAL or token.NULL parser.nextToken() parser.skipIfCurrentTokenIsApostrophe() @@ -792,6 +792,9 @@ func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression if err != nil { return false, nil, err } + case token.NULL: + conditionalExpression.Left = ast.Anonymitifier{Token: parser.currentToken} + parser.nextToken() case token.LITERAL: conditionalExpression.Left = ast.Anonymitifier{Token: parser.currentToken} parser.nextToken() @@ -814,11 +817,14 @@ func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression if err != nil { return false, nil, err } + case token.NULL: + conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} + parser.nextToken() case token.LITERAL: conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} parser.nextToken() default: - return false, nil, &SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL}, got: parser.currentToken.Literal} + return false, nil, &SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL, token.NULL}, got: parser.currentToken.Literal} } return true, conditionalExpression, nil @@ -854,15 +860,15 @@ func (parser *Parser) getContainExpression() (bool, *ast.ContainExpression, erro return false, nil, err } - for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.APOSTROPHE { + for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { parser.skipIfCurrentTokenIsApostrophe() - err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL}) + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL, token.NULL}) if err != nil { return false, nil, err } containExpression.Right = append(containExpression.Right, ast.Anonymitifier{Token: parser.currentToken}) - // Ignore token.IDENT or token.LITERAL + // Ignore token.IDENT, token.LITERAL or token.NULL parser.nextToken() parser.skipIfCurrentTokenIsApostrophe() diff --git a/parser/parser_error_handling_test.go b/parser/parser_error_handling_test.go index 1dd0112..abb10d7 100644 --- a/parser/parser_error_handling_test.go +++ b/parser/parser_error_handling_test.go @@ -53,7 +53,7 @@ func TestParseInsertCommandErrorHandling(t *testing.T) { noIntoKeyword := SyntaxError{[]string{token.INTO}, token.IDENT} noTableName := SyntaxError{[]string{token.IDENT}, token.VALUES} noLeftParen := SyntaxError{[]string{token.LPAREN}, token.APOSTROPHE} - noValue := SyntaxError{[]string{token.IDENT, token.LITERAL}, token.APOSTROPHE} + noValue := SyntaxError{[]string{token.IDENT, token.LITERAL, token.NULL}, token.APOSTROPHE} noRightParen := SyntaxError{[]string{token.RPAREN}, token.SEMICOLON} noSemicolon := SyntaxError{[]string{token.SEMICOLON}, ""} @@ -75,7 +75,7 @@ func TestParseUpdateCommandErrorHandling(t *testing.T) { noSetKeyword := SyntaxError{expecting: []string{token.SET}, got: token.SEMICOLON} noColumnName := SyntaxError{expecting: []string{token.IDENT}, got: token.LITERAL} noToKeyword := SyntaxError{expecting: []string{token.TO}, got: token.SEMICOLON} - noSecondIdentOrLiteralForValue := SyntaxError{expecting: []string{token.IDENT, token.LITERAL}, got: token.SEMICOLON} + noSecondIdentOrLiteralForValue := SyntaxError{expecting: []string{token.IDENT, token.LITERAL, token.NULL}, got: token.SEMICOLON} noCommaBetweenValues := SyntaxError{expecting: []string{token.SEMICOLON, token.WHERE}, got: token.IDENT} noWhereOrSemicolon := SyntaxError{expecting: []string{token.SEMICOLON, token.WHERE}, got: token.SELECT} @@ -122,7 +122,7 @@ func TestParseWhereCommandErrorHandling(t *testing.T) { noPredecessorError := NoPredecessorParserError{command: token.WHERE} noColName := LogicalExpressionParsingError{} noOperatorInsideWhereStatementException := LogicalExpressionParsingError{} - valueIsMissing := SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL}, got: token.SEMICOLON} + valueIsMissing := SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL, token.NULL}, got: token.SEMICOLON} tokenAnd := token.AND conjunctionIsMissing := SyntaxError{expecting: []string{token.SEMICOLON, token.ORDER}, got: token.IDENT} nextLogicalExpressionIsMissing := LogicalExpressionParsingError{afterToken: &tokenAnd} diff --git a/parser/parser_test.go b/parser/parser_test.go index 42d8f3c..ae46a99 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -77,6 +77,7 @@ func TestParseInsertCommand(t *testing.T) { {"INSERT INTO TBL VALUES();", "TBL", []token.Token{}}, {"INSERT INTO TBL VALUES( 'HELLO' );", "TBL", []token.Token{{Type: token.IDENT, Literal: "HELLO"}}}, {"INSERT INTO TBL VALUES( 'HELLO', 10 , 'LOL');", "TBL", []token.Token{{Type: token.IDENT, Literal: "HELLO"}, {Type: token.LITERAL, Literal: "10"}, {Type: token.IDENT, Literal: "LOL"}}}, + {"INSERT INTO TBL VALUES(NULL, 'NULL', null);", "TBL", []token.Token{{Type: token.NULL, Literal: "NULL"}, {Type: token.IDENT, Literal: "NULL"}, {Type: token.IDENT, Literal: "null"}}}, } for testIndex, tt := range tests { @@ -183,6 +184,12 @@ func TestParseWhereCommand(t *testing.T) { Contains: false, } + fifthExpression := ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName5"}}, + Right: ast.Anonymitifier{Token: token.Token{Type: token.NULL, Literal: "NULL"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + } + tests := []struct { input string expectedExpression ast.Expression @@ -203,6 +210,10 @@ func TestParseWhereCommand(t *testing.T) { input: "SELECT * FROM TBL WHERE colName4 NOTIN ('one', 'two');", expectedExpression: fourthExpression, }, + { + input: "SELECT * FROM TBL WHERE colName5 EQUAL NULL;", + expectedExpression: fifthExpression, + }, } for testIndex, tt := range tests { @@ -699,14 +710,22 @@ func TestParseUpdateCommand(t *testing.T) { expectedTableName string expectedChanges map[token.Token]ast.Anonymitifier }{ - {input: "UPDATE tbl SET colName TO 5;", expectedTableName: "tbl", expectedChanges: map[token.Token]ast.Anonymitifier{ - {Type: token.IDENT, Literal: "colName"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, - }, + { + input: "UPDATE tbl SET colName TO 5;", expectedTableName: "tbl", expectedChanges: map[token.Token]ast.Anonymitifier{ + {Type: token.IDENT, Literal: "colName"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + }, }, - {input: "UPDATE tbl1 SET colName1 TO 'hi hello', colName2 TO 5;", expectedTableName: "tbl1", expectedChanges: map[token.Token]ast.Anonymitifier{ - {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.IDENT, Literal: "hi hello"}}, - {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + { + input: "UPDATE tbl1 SET colName1 TO 'hi hello', colName2 TO 5;", expectedTableName: "tbl1", expectedChanges: map[token.Token]ast.Anonymitifier{ + {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.IDENT, Literal: "hi hello"}}, + {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + }, }, + { + input: "UPDATE tbl1 SET colName1 TO NULL, colName2 TO 'NULL';", expectedTableName: "tbl1", expectedChanges: map[token.Token]ast.Anonymitifier{ + {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.NULL, Literal: "NULL"}}, + {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "NULL"}}, + }, }, } @@ -790,7 +809,7 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}}, Right: ast.ConditionExpression{ Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName2"}}, - Right: ast.Anonymitifier{Token: token.Token{Type: token.LITERAL, Literal: "123"}}, + Right: ast.Anonymitifier{Token: token.Token{Type: token.NULL, Literal: "NULL"}}, Condition: token.Token{Type: token.EQUAL, Literal: "NOT"}}, Operation: token.Token{Type: token.AND, Literal: "AND"}, } @@ -802,7 +821,7 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { Condition: token.Token{Type: token.NOT, Literal: "NOT"}}, Right: ast.ConditionExpression{ Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName1"}}, - Right: ast.Anonymitifier{Token: token.Token{Type: token.IDENT, Literal: "qwe"}}, + Right: ast.Anonymitifier{Token: token.Token{Type: token.IDENT, Literal: "NULL"}}, Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}}, Operation: token.Token{Type: token.OR, Literal: "OR"}, } @@ -816,11 +835,11 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { expectedExpression ast.Expression }{ { - input: "SELECT * FROM TBL WHERE colName1 EQUAL 'fda' AND colName2 NOT 123;", + input: "SELECT * FROM TBL WHERE colName1 EQUAL 'fda' AND colName2 NOT NULL;", expectedExpression: firstExpression, }, { - input: "SELECT * FROM TBL WHERE colName2 NOT 6462389 OR colName1 EQUAL 'qwe';", + input: "SELECT * FROM TBL WHERE colName2 NOT 6462389 OR colName1 EQUAL 'NULL';", expectedExpression: secondExpression, }, { @@ -946,6 +965,9 @@ func tokenArrayEquals(a []token.Token, b []token.Token) bool { if v.Literal != b[i].Literal { return false } + if v.Type != b[i].Type { + return false + } } return true } @@ -1082,28 +1104,28 @@ func validateOperationExpression(second ast.Expression, operationExpression *ast return expressionsAreEqual(operationExpression.Left, secondOperationExpression.Left) && expressionsAreEqual(operationExpression.Right, secondOperationExpression.Right) } -func validateContainExpression(expression ast.Expression, exptectedContainExpression *ast.ContainExpression) bool { +func validateContainExpression(expression ast.Expression, expectedContainExpression *ast.ContainExpression) bool { actualContainExpression, actualContainExpressionIsValid := expression.(ast.ContainExpression) if !actualContainExpressionIsValid { return false } - if exptectedContainExpression.Contains != actualContainExpression.Contains { + if expectedContainExpression.Contains != actualContainExpression.Contains { return false } - if actualContainExpression.Left.GetToken().Literal != exptectedContainExpression.Left.GetToken().Literal && - actualContainExpression.Left.IsIdentifier() == exptectedContainExpression.Left.IsIdentifier() { + if actualContainExpression.Left.GetToken().Literal != expectedContainExpression.Left.GetToken().Literal && + actualContainExpression.Left.IsIdentifier() == expectedContainExpression.Left.IsIdentifier() { return false } - if len(exptectedContainExpression.Right) != len(actualContainExpression.Right) { + if len(expectedContainExpression.Right) != len(actualContainExpression.Right) { return false } - for i := 0; i < len(exptectedContainExpression.Right); i++ { - if exptectedContainExpression.Right[i] != actualContainExpression.Right[i] { + for i := 0; i < len(expectedContainExpression.Right); i++ { + if expectedContainExpression.Right[i] != actualContainExpression.Right[i] { return false } } diff --git a/test_file b/test_file index bf97956..8149bf0 100644 --- a/test_file +++ b/test_file @@ -1,10 +1,10 @@ CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); INSERT INTO tbl VALUES( 'hello', 1, 11, 'q' ); INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); - INSERT INTO tbl VALUES( 'byebye', 3, 33, 'e' ); + INSERT INTO tbl VALUES( 'byebye', NULL, 33, 'e' ); SELECT * FROM tbl WHERE one EQUAL 'byebye'; - SELECT one, three FROM tbl WHERE two NOT 3; - SELECT * FROM tbl WHERE one NOT 'goodbye' AND two EQUAL 3; + SELECT one, three FROM tbl WHERE two NOT NULL; + SELECT * FROM tbl WHERE one NOT 'goodbye' AND two EQUAL NULL; SELECT * FROM tbl WHERE one IN ('goodbye', 'byebye'); SELECT * FROM tbl WHERE one NOTIN ('goodbye', 'byebye'); SELECT * FROM tbl WHERE FALSE; @@ -14,7 +14,7 @@ DELETE FROM tbl WHERE one EQUAL 'byebye'; SELECT * FROM tbl; SELECT one FROM tbl WHERE TRUE ORDER BY two ASC, four DESC; - UPDATE tbl SET two TO 5, four TO 'P' WHERE one EQUAL 'goodbye'; + UPDATE tbl SET two TO NULL, four TO 'P' WHERE one EQUAL 'goodbye'; SELECT * FROM tbl; INSERT INTO tbl VALUES( 'goodbye', 5, 22, 'P' ); SELECT DISTINCT * FROM tbl; @@ -22,7 +22,7 @@ CREATE TABLE table1( id INT, value TEXT); CREATE TABLE table2( id INT, value TEXT); INSERT INTO table1 VALUES(1, 'Value1'); - INSERT INTO table1 VALUES(2, 'Value2'); + INSERT INTO table1 VALUES(2, NULL); INSERT INTO table2 VALUES(2, 'Value2'); INSERT INTO table2 VALUES(3, 'Value3'); SELECT table1.value, table2.value FROM table1 FULL JOIN table2 ON table1.id EQUAL table2.id; diff --git a/token/token.go b/token/token.go index 049baa1..381a168 100644 --- a/token/token.go +++ b/token/token.go @@ -61,6 +61,7 @@ const ( AVG = "AVG" IN = "IN" NOTIN = "NOTIN" + NULL = "NULL" TO = "TO" @@ -122,6 +123,7 @@ var keywords = map[string]Type{ "OR": OR, "TRUE": TRUE, "FALSE": FALSE, + "NULL": NULL, } // LookupIdent - Return keyword type from defined list if exists, otherwise it returns IDENT type From 940f5fd288cb11605c0e738c7de13e72cf786b6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sara=20Ryfczy=C5=84ska?= Date: Tue, 10 Dec 2024 23:33:01 +0100 Subject: [PATCH 17/21] Add apostrophe validation errors (#26) --- parser/errors.go | 18 +++++ parser/parser.go | 104 ++++++++++++++++----------- parser/parser_error_handling_test.go | 20 +++++- parser/parser_test.go | 20 +++--- 4 files changed, 113 insertions(+), 49 deletions(-) diff --git a/parser/errors.go b/parser/errors.go index dbd71e4..8bc04fd 100644 --- a/parser/errors.go +++ b/parser/errors.go @@ -107,3 +107,21 @@ type IllegalPeriodInIdentParserError struct { func (m *IllegalPeriodInIdentParserError) Error() string { return "syntax error, {" + m.name + "} shouldn't contain '.'" } + +// NoApostropheOnRightParserError - error thrown when parser found no apostrophe on right of ident +type NoApostropheOnRightParserError struct { + ident string +} + +func (m *NoApostropheOnRightParserError) Error() string { + return "syntax error, Identifier: {" + m.ident + "} has no apostrophe on right" +} + +// NoApostropheOnLeftParserError - error thrown when parser found no apostrophe on left of ident +type NoApostropheOnLeftParserError struct { + ident string +} + +func (m *NoApostropheOnLeftParserError) Error() string { + return "syntax error, Identifier: {" + m.ident + "} has no apostrophe on left" +} diff --git a/parser/parser.go b/parser/parser.go index 18a90f8..ce5e5d4 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -138,10 +138,12 @@ func (parser *Parser) parseCreateCommand() (ast.Command, error) { return createCommand, nil } -func (parser *Parser) skipIfCurrentTokenIsApostrophe() { +func (parser *Parser) skipIfCurrentTokenIsApostrophe() bool { if parser.currentToken.Type == token.APOSTROPHE { parser.nextToken() + return true } + return false } func (parser *Parser) skipIfCurrentTokenIsSemicolon() { @@ -184,17 +186,24 @@ func (parser *Parser) parseInsertCommand() (ast.Command, error) { } for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { - parser.skipIfCurrentTokenIsApostrophe() + startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL, token.NULL}) if err != nil { return nil, err } - insertCommand.Values = append(insertCommand.Values, parser.currentToken) + value := parser.currentToken + insertCommand.Values = append(insertCommand.Values, value) // Ignore token.IDENT, token.LITERAL or token.NULL parser.nextToken() - parser.skipIfCurrentTokenIsApostrophe() + finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + + err = validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, value) + + if err != nil { + return nil, err + } if parser.currentToken.Type != token.COMMA { break @@ -217,6 +226,15 @@ func (parser *Parser) parseInsertCommand() (ast.Command, error) { return insertCommand, nil } +func validateApostropheWrapping(startedWithApostrophe bool, finishedWithApostrophe bool, value token.Token) error { + if startedWithApostrophe && !finishedWithApostrophe { + return &NoApostropheOnRightParserError{ident: value.Literal} + } else if !startedWithApostrophe && finishedWithApostrophe { + return &NoApostropheOnLeftParserError{ident: value.Literal} + } + return nil +} + // parseSelectCommand - Return ast.SelectCommand created from tokens and validate the syntax // // Example of input parsable to the ast.SelectCommand: @@ -648,7 +666,7 @@ func (parser *Parser) parseUpdateCommand() (ast.Command, error) { // skip token.TO parser.nextToken() - parser.skipIfCurrentTokenIsApostrophe() + startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL, token.NULL}) if err != nil { return nil, err @@ -657,7 +675,13 @@ func (parser *Parser) parseUpdateCommand() (ast.Command, error) { // skip token.IDENT, token.LITERAL or token.NULL parser.nextToken() - parser.skipIfCurrentTokenIsApostrophe() + finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + + err = validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, updateCommand.Changes[colKey].GetToken()) + + if err != nil { + return nil, err + } if parser.currentToken.Type != token.COMMA { break @@ -780,51 +804,45 @@ func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression } conditionalExpression.Condition = parser.peekToken - switch parser.currentToken.Type { - case token.IDENT: - conditionalExpression.Left = ast.Identifier{Token: parser.currentToken} - parser.nextToken() - case token.APOSTROPHE: - parser.skipIfCurrentTokenIsApostrophe() - conditionalExpression.Left = ast.Anonymitifier{Token: parser.currentToken} + if parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { + startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + + if !startedWithApostrophe && parser.currentToken.Type == token.IDENT { + conditionalExpression.Left = ast.Identifier{Token: parser.currentToken} + } else { + conditionalExpression.Left = ast.Anonymitifier{Token: parser.currentToken} + } parser.nextToken() - err := validateTokenAndSkip(parser, []token.Type{token.APOSTROPHE}) + + finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + err := validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, conditionalExpression.Left.GetToken()) if err != nil { return false, nil, err } - case token.NULL: - conditionalExpression.Left = ast.Anonymitifier{Token: parser.currentToken} - parser.nextToken() - case token.LITERAL: - conditionalExpression.Left = ast.Anonymitifier{Token: parser.currentToken} - parser.nextToken() - default: + } else { return false, conditionalExpression, nil } // skip EQUAL or NOT parser.nextToken() - switch parser.currentToken.Type { - case token.IDENT: - conditionalExpression.Right = ast.Identifier{Token: parser.currentToken} - parser.nextToken() - case token.APOSTROPHE: - parser.skipIfCurrentTokenIsApostrophe() - conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} + if parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { + startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + + if !startedWithApostrophe && parser.currentToken.Type == token.IDENT { + conditionalExpression.Right = ast.Identifier{Token: parser.currentToken} + } else { + conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} + } parser.nextToken() - err := validateTokenAndSkip(parser, []token.Type{token.APOSTROPHE}) + + finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + err = validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, conditionalExpression.Right.GetToken()) if err != nil { return false, nil, err } - case token.NULL: - conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} - parser.nextToken() - case token.LITERAL: - conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} - parser.nextToken() - default: - return false, nil, &SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL, token.NULL}, got: parser.currentToken.Literal} + } else { + return false, conditionalExpression, &SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL, token.NULL}, got: parser.currentToken.Literal} } return true, conditionalExpression, nil @@ -861,17 +879,23 @@ func (parser *Parser) getContainExpression() (bool, *ast.ContainExpression, erro } for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { - parser.skipIfCurrentTokenIsApostrophe() + startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL, token.NULL}) if err != nil { return false, nil, err } - containExpression.Right = append(containExpression.Right, ast.Anonymitifier{Token: parser.currentToken}) + currentAnonymitifier := ast.Anonymitifier{Token: parser.currentToken} + containExpression.Right = append(containExpression.Right, currentAnonymitifier) // Ignore token.IDENT, token.LITERAL or token.NULL parser.nextToken() - parser.skipIfCurrentTokenIsApostrophe() + finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + + err = validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, currentAnonymitifier.GetToken()) + if err != nil { + return false, nil, err + } if parser.currentToken.Type != token.COMMA { if parser.currentToken.Type != token.RPAREN { diff --git a/parser/parser_error_handling_test.go b/parser/parser_error_handling_test.go index abb10d7..cfd17c0 100644 --- a/parser/parser_error_handling_test.go +++ b/parser/parser_error_handling_test.go @@ -56,6 +56,8 @@ func TestParseInsertCommandErrorHandling(t *testing.T) { noValue := SyntaxError{[]string{token.IDENT, token.LITERAL, token.NULL}, token.APOSTROPHE} noRightParen := SyntaxError{[]string{token.RPAREN}, token.SEMICOLON} noSemicolon := SyntaxError{[]string{token.SEMICOLON}, ""} + noLeftApostrophe := NoApostropheOnLeftParserError{ident: "hello"} + noRightApostrophe := NoApostropheOnRightParserError{ident: "hello, 10)"} tests := []errorHandlingTestSuite{ {"INSERT tbl VALUES( 'hello', 10);", noIntoKeyword.Error()}, @@ -64,6 +66,8 @@ func TestParseInsertCommandErrorHandling(t *testing.T) { {"INSERT INTO tl VALUES ('', 10);", noValue.Error()}, {"INSERT INTO tl VALUES ('hello', 10;", noRightParen.Error()}, {"INSERT INTO tl VALUES ('hello', 10)", noSemicolon.Error()}, + {"INSERT INTO tl VALUES (hello', 10)", noLeftApostrophe.Error()}, + {"INSERT INTO tl VALUES ('hello, 10)", noRightApostrophe.Error()}, } runParserErrorHandlingSuite(t, tests) @@ -78,6 +82,8 @@ func TestParseUpdateCommandErrorHandling(t *testing.T) { noSecondIdentOrLiteralForValue := SyntaxError{expecting: []string{token.IDENT, token.LITERAL, token.NULL}, got: token.SEMICOLON} noCommaBetweenValues := SyntaxError{expecting: []string{token.SEMICOLON, token.WHERE}, got: token.IDENT} noWhereOrSemicolon := SyntaxError{expecting: []string{token.SEMICOLON, token.WHERE}, got: token.SELECT} + noLeftApostrophe := NoApostropheOnLeftParserError{ident: "new_value_1"} + noRightApostrophe := NoApostropheOnRightParserError{ident: "new_value_1"} tests := []errorHandlingTestSuite{ {"UPDATE;", notableName.Error()}, @@ -87,6 +93,8 @@ func TestParseUpdateCommandErrorHandling(t *testing.T) { {"UPDATE table SET column_name_1 TO;", noSecondIdentOrLiteralForValue.Error()}, {"UPDATE table SET column_name_1 TO 2 column_name_1 TO 3;", noCommaBetweenValues.Error()}, {"UPDATE table SET column_name_1 TO 'new_value_1' SELECT;", noWhereOrSemicolon.Error()}, + {"UPDATE table SET column_name_1 TO new_value_1'", noLeftApostrophe.Error()}, + {"UPDATE table SET column_name_1 TO 'new_value_1", noRightApostrophe.Error()}, } runParserErrorHandlingSuite(t, tests) @@ -131,6 +139,10 @@ func TestParseWhereCommandErrorHandling(t *testing.T) { noLeftParGotNumber := SyntaxError{expecting: []string{token.LPAREN}, got: token.LITERAL} noComma := SyntaxError{expecting: []string{token.COMMA, token.RPAREN}, got: token.LITERAL} noInKeywordException := LogicalExpressionParsingError{} + noLeftApostropheGoodbye := NoApostropheOnLeftParserError{ident: "goodbye"} + noLeftApostropheFive := NoApostropheOnLeftParserError{ident: "5"} + noRightApostropheGoodbye := NoApostropheOnRightParserError{ident: "goodbye"} + noRightApostropheFive := NoApostropheOnRightParserError{ident: "5"} tests := []errorHandlingTestSuite{ {"WHERE col1 NOT 'goodbye' OR col2 EQUAL 3;", noPredecessorError.Error()}, @@ -143,11 +155,17 @@ func TestParseWhereCommandErrorHandling(t *testing.T) { {selectCommandPrefix + "WHERE one IN ;", noLeftParGotSemicolon.Error()}, {selectCommandPrefix + "WHERE one IN 5;", noLeftParGotNumber.Error()}, {selectCommandPrefix + "WHERE one IN (5 6);", noComma.Error()}, + {selectCommandPrefix + "WHERE one IN ('5", noRightApostropheFive.Error()}, + {selectCommandPrefix + "WHERE one IN (5');", noLeftApostropheFive.Error()}, {selectCommandPrefix + "WHERE one (5, 6);", noInKeywordException.Error()}, + {selectCommandPrefix + "WHERE one EQUAL goodbye';", noLeftApostropheGoodbye.Error()}, + {selectCommandPrefix + "WHERE one EQUAL 'goodbye", noRightApostropheGoodbye.Error()}, + // TODO: Add after fix apostrophe on left side of condition + //{selectCommandPrefix + "WHERE 'goodbye EQUAL two", noRightApostropheGoodbye.Error()}, + //{selectCommandPrefix + "WHERE goodbye' EQUAL two", noLeftApostropheGoodbye.Error()}, } runParserErrorHandlingSuite(t, tests) - } func TestParseOrderByCommandErrorHandling(t *testing.T) { diff --git a/parser/parser_test.go b/parser/parser_test.go index ae46a99..c887986 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -214,6 +214,10 @@ func TestParseWhereCommand(t *testing.T) { input: "SELECT * FROM TBL WHERE colName5 EQUAL NULL;", expectedExpression: fifthExpression, }, + { + input: "SELECT * FROM TBL WHERE colName5 EQUAL NULL;", + expectedExpression: fifthExpression, + }, } for testIndex, tt := range tests { @@ -712,20 +716,20 @@ func TestParseUpdateCommand(t *testing.T) { }{ { input: "UPDATE tbl SET colName TO 5;", expectedTableName: "tbl", expectedChanges: map[token.Token]ast.Anonymitifier{ - {Type: token.IDENT, Literal: "colName"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, - }, + {Type: token.IDENT, Literal: "colName"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + }, }, { input: "UPDATE tbl1 SET colName1 TO 'hi hello', colName2 TO 5;", expectedTableName: "tbl1", expectedChanges: map[token.Token]ast.Anonymitifier{ - {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.IDENT, Literal: "hi hello"}}, - {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, - }, + {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.IDENT, Literal: "hi hello"}}, + {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + }, }, { input: "UPDATE tbl1 SET colName1 TO NULL, colName2 TO 'NULL';", expectedTableName: "tbl1", expectedChanges: map[token.Token]ast.Anonymitifier{ - {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.NULL, Literal: "NULL"}}, - {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "NULL"}}, - }, + {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.NULL, Literal: "NULL"}}, + {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "NULL"}}, + }, }, } From 8d6f5a187c75a22ecf1e8a48a3f2731cd11a7b1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Krupski?= <34219324+ixior462@users.noreply.github.com> Date: Thu, 16 Jan 2025 21:50:12 +0100 Subject: [PATCH 18/21] Add gopher to README.md (#28) --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 1bb4c95..c641ee2 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +

+ gopher_GO4SQL +

+ # GO4SQL

From 2f0a67edba65d15576443d19738426c3ecac7349 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Krupski?= <34219324+ixior462@users.noreply.github.com> Date: Thu, 16 Jan 2025 22:08:23 +0100 Subject: [PATCH 19/21] Refactore e2e tests (#27) * Refactore e2e tests structure * Update README in E2E section * Move e2e test to seperate file --- .github/expected_results/end2end.txt | 141 ------------------ .github/workflows/end2end-tests.yml | 8 +- README.md | 26 ++-- e2e/e2e_test.sh | 22 +++ .../1_select_with_where_expected_output | 35 +++++ ...lect_with_limit_and_offset_expected_output | 20 +++ e2e/expected_outputs/3_delete_expected_output | 11 ++ .../4_orderby_expected_output | 11 ++ e2e/expected_outputs/5_update_expected_output | 12 ++ .../6_select_distinct_expected_output | 13 ++ .../7_drop_table_expected_output | 2 + .../8_select_with_join_expected_output | 30 ++++ .../9_aggregate_functions_expected_output | 36 +++++ e2e/test_files/1_select_with_where_test | 12 ++ .../2_select_with_limit_and_offset_test | 9 ++ e2e/test_files/3_delete_test | 9 ++ e2e/test_files/4_orderby_test | 8 + e2e/test_files/5_update_test | 9 ++ e2e/test_files/6_select_distinct_test | 9 ++ e2e/test_files/7_drop_table_test | 3 + e2e/test_files/8_select_with_join_test | 12 ++ e2e/test_files/9_aggregate_functions_test | 14 ++ test_file | 38 ----- 23 files changed, 295 insertions(+), 195 deletions(-) delete mode 100644 .github/expected_results/end2end.txt create mode 100644 e2e/e2e_test.sh create mode 100644 e2e/expected_outputs/1_select_with_where_expected_output create mode 100644 e2e/expected_outputs/2_select_with_limit_and_offset_expected_output create mode 100644 e2e/expected_outputs/3_delete_expected_output create mode 100644 e2e/expected_outputs/4_orderby_expected_output create mode 100644 e2e/expected_outputs/5_update_expected_output create mode 100644 e2e/expected_outputs/6_select_distinct_expected_output create mode 100644 e2e/expected_outputs/7_drop_table_expected_output create mode 100644 e2e/expected_outputs/8_select_with_join_expected_output create mode 100644 e2e/expected_outputs/9_aggregate_functions_expected_output create mode 100644 e2e/test_files/1_select_with_where_test create mode 100644 e2e/test_files/2_select_with_limit_and_offset_test create mode 100644 e2e/test_files/3_delete_test create mode 100644 e2e/test_files/4_orderby_test create mode 100644 e2e/test_files/5_update_test create mode 100644 e2e/test_files/6_select_distinct_test create mode 100644 e2e/test_files/7_drop_table_test create mode 100644 e2e/test_files/8_select_with_join_test create mode 100644 e2e/test_files/9_aggregate_functions_test delete mode 100644 test_file diff --git a/.github/expected_results/end2end.txt b/.github/expected_results/end2end.txt deleted file mode 100644 index e0e7f34..0000000 --- a/.github/expected_results/end2end.txt +++ /dev/null @@ -1,141 +0,0 @@ -Table 'tbl' has been created -Data Inserted -Data Inserted -Data Inserted -+----------+------+-------+------+ -| one | two | three | four | -+----------+------+-------+------+ -| 'byebye' | NULL | 33 | 'e' | -+----------+------+-------+------+ -+-----------+-------+ -| one | three | -+-----------+-------+ -| 'hello' | 11 | -| 'goodbye' | 22 | -+-----------+-------+ -+----------+------+-------+------+ -| one | two | three | four | -+----------+------+-------+------+ -| 'byebye' | NULL | 33 | 'e' | -+----------+------+-------+------+ -+-----------+------+-------+------+ -| one | two | three | four | -+-----------+------+-------+------+ -| 'goodbye' | 1 | 22 | 'w' | -| 'byebye' | NULL | 33 | 'e' | -+-----------+------+-------+------+ -+---------+-----+-------+------+ -| one | two | three | four | -+---------+-----+-------+------+ -| 'hello' | 1 | 11 | 'q' | -+---------+-----+-------+------+ -+-----+-----+-------+------+ -| one | two | three | four | -+-----+-----+-------+------+ -+-----+-----+-------+------+ -+---------+-----+-------+------+ -| one | two | three | four | -+---------+-----+-------+------+ -| 'hello' | 1 | 11 | 'q' | -+---------+-----+-------+------+ -+-----------+------+-------+------+ -| one | two | three | four | -+-----------+------+-------+------+ -| 'goodbye' | 1 | 22 | 'w' | -| 'byebye' | NULL | 33 | 'e' | -+-----------+------+-------+------+ -+-----------+-----+-------+------+ -| one | two | three | four | -+-----------+-----+-------+------+ -| 'goodbye' | 1 | 22 | 'w' | -+-----------+-----+-------+------+ -Data from 'tbl' has been deleted -+-----------+-----+-------+------+ -| one | two | three | four | -+-----------+-----+-------+------+ -| 'hello' | 1 | 11 | 'q' | -| 'goodbye' | 1 | 22 | 'w' | -+-----------+-----+-------+------+ -+-----------+ -| one | -+-----------+ -| 'goodbye' | -| 'hello' | -+-----------+ -Table: 'tbl' has been updated -+-----------+------+-------+------+ -| one | two | three | four | -+-----------+------+-------+------+ -| 'hello' | 1 | 11 | 'q' | -| 'goodbye' | NULL | 22 | 'P' | -+-----------+------+-------+------+ -Data Inserted -+-----------+------+-------+------+ -| one | two | three | four | -+-----------+------+-------+------+ -| 'hello' | 1 | 11 | 'q' | -| 'goodbye' | NULL | 22 | 'P' | -| 'goodbye' | 5 | 22 | 'P' | -+-----------+------+-------+------+ -Table: 'tbl' has been dropped -Table 'table1' has been created -Table 'table2' has been created -Data Inserted -Data Inserted -Data Inserted -Data Inserted -+--------------+--------------+ -| table1.value | table2.value | -+--------------+--------------+ -| 'Value1' | NULL | -| NULL | 'Value2' | -| NULL | 'Value3' | -+--------------+--------------+ -+--------------+--------------+ -| table1.value | table2.value | -+--------------+--------------+ -| NULL | 'Value2' | -+--------------+--------------+ -+--------------+--------------+ -| table1.value | table2.value | -+--------------+--------------+ -| 'Value1' | NULL | -| NULL | 'Value2' | -+--------------+--------------+ -+--------------+--------------+ -| table1.value | table2.value | -+--------------+--------------+ -| NULL | 'Value2' | -| NULL | 'Value3' | -+--------------+--------------+ -Data Inserted -+---------+------------+ -| MAX(id) | MAX(value) | -+---------+------------+ -| 3 | Value3 | -+---------+------------+ -+------------+---------+ -| MIN(value) | MIN(id) | -+------------+---------+ -| NULL | 1 | -+------------+---------+ -+----------+-----------+--------------+ -| COUNT(*) | COUNT(id) | COUNT(value) | -+----------+-----------+--------------+ -| 3 | 3 | 2 | -+----------+-----------+--------------+ -+---------+------------+ -| SUM(id) | SUM(value) | -+---------+------------+ -| 6 | 0 | -+---------+------------+ -+---------+------------+ -| AVG(id) | AVG(value) | -+---------+------------+ -| 2 | 0 | -+---------+------------+ -+---------+----+ -| AVG(id) | id | -+---------+----+ -| 2 | 1 | -+---------+----+ diff --git a/.github/workflows/end2end-tests.yml b/.github/workflows/end2end-tests.yml index 66d7bf6..8f6f0c9 100644 --- a/.github/workflows/end2end-tests.yml +++ b/.github/workflows/end2end-tests.yml @@ -25,8 +25,8 @@ jobs: - name: Build run: go build -v - - name: Run - run: ./GO4SQL -file test_file > output.txt + - name: Make Test Script Executable + run: chmod +x e2e/e2e_test.sh - - name: Check Result - run: diff output.txt ./.github/expected_results/end2end.txt + - name: Run Tests + run: e2e/e2e_test.sh diff --git a/README.md b/README.md index c641ee2..817250b 100644 --- a/README.md +++ b/README.md @@ -27,9 +27,10 @@ You can compile the project with ``go build``, this will create ``GO4SQL`` binar Currently, there are 3 modes to chose from: 1. `File Mode` - You can specify file path with ``./GO4SQL -file file_path``, that will read the - input - data directly into the program and print the result. - + input data directly into the program and print the result. In order to run one of e2e test files you can use: + ```shell + go build; ./GO4SQL -file e2e/test_files/1_select_with_where_test + ``` 2. `Stream Mode` - With ``./GO4SQL -stream`` you can run the program in stream mode, then you provide SQL commands in your console (from standard input). @@ -47,7 +48,15 @@ To run all the tests locally paste this in root directory: go clean -testcache; go test ./... ``` -### Docker +## E2E TESTS + +There are integrated with Github actions e2e tests that can be found in: `.github/workflows/end2end-tests.yml` file. +Tests run files inside `e2e/test_files` directory through `GO4SQL`, save stdout into files, and finally compare +then with expected outputs inside `e2e/expected_outputs` directory. + +To run e2e test locally, you can run script `./e2e/e2e_test.sh` if you're in the root directory. + +## Docker 1. Pull docker image: `docker pull kajedot/go4sql:latest` 2. Run docker container in the interactive mode, remember to provide flag, for example: @@ -306,18 +315,11 @@ go clean -testcache; go test ./... This command will return the average of all values in the numerical column ``columnName`` of ``tableName``. -## E2E TEST - -In root directory there is **test_file** containing input commands for E2E tests. File -**.github/expected_results/end2end.txt** has expected results for it. -This is integrated into github workflows. - ## DOCKER To build your docker image run this command in root directory: -``` -shell +```shell docker build -t go4sql:test . ``` diff --git a/e2e/e2e_test.sh b/e2e/e2e_test.sh new file mode 100644 index 0000000..9dff5ef --- /dev/null +++ b/e2e/e2e_test.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +e2e_failed=false + +for test_file in e2e/test_files/*_test; do + output_file="./e2e/$(basename "${test_file/_test/_output}")" + ./GO4SQL -file "$test_file" > "$output_file" + expected_output="e2e/expected_outputs/$(basename "${test_file/_test/_expected_output}")" + diff "$output_file" "$expected_output" + if [ $? -ne 0 ]; then + echo "E2E test for: {$test_file} failed" + e2e_failed=true + fi + rm "./$output_file" +done + +if [ "$e2e_failed" = true ]; then + echo "E2E tests failed." + exit 1 +else + echo "All E2E tests passed." +fi diff --git a/e2e/expected_outputs/1_select_with_where_expected_output b/e2e/expected_outputs/1_select_with_where_expected_output new file mode 100644 index 0000000..ba58bca --- /dev/null +++ b/e2e/expected_outputs/1_select_with_where_expected_output @@ -0,0 +1,35 @@ +Table 'tbl' has been created +Data Inserted +Data Inserted +Data Inserted ++----------+------+-------+------+ +| one | two | three | four | ++----------+------+-------+------+ +| 'byebye' | NULL | 33 | 'e' | ++----------+------+-------+------+ ++-----------+-------+ +| one | three | ++-----------+-------+ +| 'hello' | 11 | +| 'goodbye' | 22 | ++-----------+-------+ ++----------+------+-------+------+ +| one | two | three | four | ++----------+------+-------+------+ +| 'byebye' | NULL | 33 | 'e' | ++----------+------+-------+------+ ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'goodbye' | 1 | 22 | 'w' | +| 'byebye' | NULL | 33 | 'e' | ++-----------+------+-------+------+ ++---------+-----+-------+------+ +| one | two | three | four | ++---------+-----+-------+------+ +| 'hello' | 1 | 11 | 'q' | ++---------+-----+-------+------+ ++-----+-----+-------+------+ +| one | two | three | four | ++-----+-----+-------+------+ ++-----+-----+-------+------+ diff --git a/e2e/expected_outputs/2_select_with_limit_and_offset_expected_output b/e2e/expected_outputs/2_select_with_limit_and_offset_expected_output new file mode 100644 index 0000000..cea7b84 --- /dev/null +++ b/e2e/expected_outputs/2_select_with_limit_and_offset_expected_output @@ -0,0 +1,20 @@ +Table 'tbl' has been created +Data Inserted +Data Inserted +Data Inserted ++---------+-----+-------+------+ +| one | two | three | four | ++---------+-----+-------+------+ +| 'hello' | 1 | 11 | 'q' | ++---------+-----+-------+------+ ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'goodbye' | 1 | 22 | 'w' | +| 'byebye' | NULL | 33 | 'e' | ++-----------+------+-------+------+ ++-----------+-----+-------+------+ +| one | two | three | four | ++-----------+-----+-------+------+ +| 'goodbye' | 1 | 22 | 'w' | ++-----------+-----+-------+------+ diff --git a/e2e/expected_outputs/3_delete_expected_output b/e2e/expected_outputs/3_delete_expected_output new file mode 100644 index 0000000..f0c6911 --- /dev/null +++ b/e2e/expected_outputs/3_delete_expected_output @@ -0,0 +1,11 @@ +Table 'tbl' has been created +Data Inserted +Data Inserted +Data Inserted +Data from 'tbl' has been deleted ++-----------+-----+-------+------+ +| one | two | three | four | ++-----------+-----+-------+------+ +| 'hello' | 1 | 11 | 'q' | +| 'goodbye' | 1 | 22 | 'w' | ++-----------+-----+-------+------+ diff --git a/e2e/expected_outputs/4_orderby_expected_output b/e2e/expected_outputs/4_orderby_expected_output new file mode 100644 index 0000000..b920afb --- /dev/null +++ b/e2e/expected_outputs/4_orderby_expected_output @@ -0,0 +1,11 @@ +Table 'tbl' has been created +Data Inserted +Data Inserted +Data Inserted ++-----------+ +| one | ++-----------+ +| 'byebye' | +| 'goodbye' | +| 'hello' | ++-----------+ diff --git a/e2e/expected_outputs/5_update_expected_output b/e2e/expected_outputs/5_update_expected_output new file mode 100644 index 0000000..6094e0a --- /dev/null +++ b/e2e/expected_outputs/5_update_expected_output @@ -0,0 +1,12 @@ +Table 'tbl' has been created +Data Inserted +Data Inserted +Data Inserted +Table: 'tbl' has been updated ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'hello' | 1 | 11 | 'q' | +| 'goodbye' | NULL | 22 | 'P' | +| 'byebye' | NULL | 33 | 'e' | ++-----------+------+-------+------+ diff --git a/e2e/expected_outputs/6_select_distinct_expected_output b/e2e/expected_outputs/6_select_distinct_expected_output new file mode 100644 index 0000000..7ed7fb7 --- /dev/null +++ b/e2e/expected_outputs/6_select_distinct_expected_output @@ -0,0 +1,13 @@ +Table 'tbl' has been created +Data Inserted +Data Inserted +Data Inserted +Data Inserted +Data Inserted ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'hello' | 1 | 11 | 'q' | +| 'goodbye' | 1 | 22 | 'w' | +| 'byebye' | NULL | 33 | 'e' | ++-----------+------+-------+------+ diff --git a/e2e/expected_outputs/7_drop_table_expected_output b/e2e/expected_outputs/7_drop_table_expected_output new file mode 100644 index 0000000..b0852d5 --- /dev/null +++ b/e2e/expected_outputs/7_drop_table_expected_output @@ -0,0 +1,2 @@ +Table 'tbl' has been created +Table: 'tbl' has been dropped diff --git a/e2e/expected_outputs/8_select_with_join_expected_output b/e2e/expected_outputs/8_select_with_join_expected_output new file mode 100644 index 0000000..6e9c1df --- /dev/null +++ b/e2e/expected_outputs/8_select_with_join_expected_output @@ -0,0 +1,30 @@ +Table 'table1' has been created +Table 'table2' has been created +Data Inserted +Data Inserted +Data Inserted +Data Inserted ++--------------+--------------+ +| table1.value | table2.value | ++--------------+--------------+ +| 'Value1' | NULL | +| NULL | 'Value2' | +| NULL | 'Value3' | ++--------------+--------------+ ++--------------+--------------+ +| table1.value | table2.value | ++--------------+--------------+ +| NULL | 'Value2' | ++--------------+--------------+ ++--------------+--------------+ +| table1.value | table2.value | ++--------------+--------------+ +| 'Value1' | NULL | +| NULL | 'Value2' | ++--------------+--------------+ ++--------------+--------------+ +| table1.value | table2.value | ++--------------+--------------+ +| NULL | 'Value2' | +| NULL | 'Value3' | ++--------------+--------------+ diff --git a/e2e/expected_outputs/9_aggregate_functions_expected_output b/e2e/expected_outputs/9_aggregate_functions_expected_output new file mode 100644 index 0000000..4be149e --- /dev/null +++ b/e2e/expected_outputs/9_aggregate_functions_expected_output @@ -0,0 +1,36 @@ +Table 'table1' has been created +Table 'table2' has been created +Data Inserted +Data Inserted +Data Inserted +Data Inserted ++---------+------------+ +| MAX(id) | MAX(value) | ++---------+------------+ +| 2 | Value1 | ++---------+------------+ ++------------+---------+ +| MIN(value) | MIN(id) | ++------------+---------+ +| NULL | 1 | ++------------+---------+ ++----------+-----------+--------------+ +| COUNT(*) | COUNT(id) | COUNT(value) | ++----------+-----------+--------------+ +| 2 | 2 | 1 | ++----------+-----------+--------------+ ++---------+------------+ +| SUM(id) | SUM(value) | ++---------+------------+ +| 3 | 0 | ++---------+------------+ ++---------+------------+ +| AVG(id) | AVG(value) | ++---------+------------+ +| 1 | 0 | ++---------+------------+ ++---------+----+ +| AVG(id) | id | ++---------+----+ +| 1 | 1 | ++---------+----+ diff --git a/e2e/test_files/1_select_with_where_test b/e2e/test_files/1_select_with_where_test new file mode 100644 index 0000000..cd8b6ce --- /dev/null +++ b/e2e/test_files/1_select_with_where_test @@ -0,0 +1,12 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +INSERT INTO tbl VALUES( 'hello',1, 11, 'q' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); + +SELECT * FROM tbl WHERE one EQUAL 'byebye'; +SELECT one, three FROM tbl WHERE two NOT NULL; +SELECT * FROM tbl WHERE one NOT 'goodbye' AND two EQUAL NULL; +SELECT * FROM tbl WHERE one IN ('goodbye', 'byebye'); +SELECT * FROM tbl WHERE one NOTIN ('goodbye', 'byebye'); +SELECT * FROM tbl WHERE FALSE; diff --git a/e2e/test_files/2_select_with_limit_and_offset_test b/e2e/test_files/2_select_with_limit_and_offset_test new file mode 100644 index 0000000..51902d1 --- /dev/null +++ b/e2e/test_files/2_select_with_limit_and_offset_test @@ -0,0 +1,9 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +INSERT INTO tbl VALUES( 'hello',1, 11, 'q' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); + +SELECT * FROM tbl LIMIT 1; +SELECT * FROM tbl OFFSET 1; +SELECT * FROM tbl LIMIT 1 OFFSET 1; diff --git a/e2e/test_files/3_delete_test b/e2e/test_files/3_delete_test new file mode 100644 index 0000000..008c84a --- /dev/null +++ b/e2e/test_files/3_delete_test @@ -0,0 +1,9 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +INSERT INTO tbl VALUES( 'hello',1, 11, 'q' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); + +DELETE FROM tbl WHERE one EQUAL 'byebye'; + +SELECT * FROM tbl; diff --git a/e2e/test_files/4_orderby_test b/e2e/test_files/4_orderby_test new file mode 100644 index 0000000..59c27ba --- /dev/null +++ b/e2e/test_files/4_orderby_test @@ -0,0 +1,8 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +INSERT INTO tbl VALUES( 'hello',1, 11, 'q' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); + +SELECT one FROM tbl WHERE TRUE ORDER BY two ASC, four DESC; + diff --git a/e2e/test_files/5_update_test b/e2e/test_files/5_update_test new file mode 100644 index 0000000..4325950 --- /dev/null +++ b/e2e/test_files/5_update_test @@ -0,0 +1,9 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +INSERT INTO tbl VALUES( 'hello',1, 11, 'q' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); + +UPDATE tbl SET two TO NULL, four TO 'P' WHERE one EQUAL 'goodbye'; + +SELECT * FROM tbl; diff --git a/e2e/test_files/6_select_distinct_test b/e2e/test_files/6_select_distinct_test new file mode 100644 index 0000000..0e8a880 --- /dev/null +++ b/e2e/test_files/6_select_distinct_test @@ -0,0 +1,9 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +INSERT INTO tbl VALUES( 'hello',1, 11, 'q' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); + +SELECT DISTINCT * FROM tbl; diff --git a/e2e/test_files/7_drop_table_test b/e2e/test_files/7_drop_table_test new file mode 100644 index 0000000..e9b24c3 --- /dev/null +++ b/e2e/test_files/7_drop_table_test @@ -0,0 +1,3 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +DROP TABLE tbl; diff --git a/e2e/test_files/8_select_with_join_test b/e2e/test_files/8_select_with_join_test new file mode 100644 index 0000000..5362818 --- /dev/null +++ b/e2e/test_files/8_select_with_join_test @@ -0,0 +1,12 @@ +CREATE TABLE table1( id INT, value TEXT); +CREATE TABLE table2( id INT, value TEXT); + +INSERT INTO table1 VALUES(1, 'Value1'); +INSERT INTO table1 VALUES(2, NULL); +INSERT INTO table2 VALUES(2, 'Value2'); +INSERT INTO table2 VALUES(3, 'Value3'); + +SELECT table1.value, table2.value FROM table1 FULL JOIN table2 ON table1.id EQUAL table2.id; +SELECT table1.value, table2.value FROM table1 INNER JOIN table2 ON table1.id EQUAL table2.id; +SELECT table1.value, table2.value FROM table1 LEFT JOIN table2 ON table1.id EQUAL table2.id; +SELECT table1.value, table2.value FROM table1 RIGHT JOIN table2 ON table1.id EQUAL table2.id; diff --git a/e2e/test_files/9_aggregate_functions_test b/e2e/test_files/9_aggregate_functions_test new file mode 100644 index 0000000..07b1262 --- /dev/null +++ b/e2e/test_files/9_aggregate_functions_test @@ -0,0 +1,14 @@ +CREATE TABLE table1( id INT, value TEXT); +CREATE TABLE table2( id INT, value TEXT); + +INSERT INTO table1 VALUES(1, 'Value1'); +INSERT INTO table1 VALUES(2, NULL); +INSERT INTO table2 VALUES(2, 'Value2'); +INSERT INTO table2 VALUES(3, 'Value3'); + +SELECT MAX(id), MAX(value) FROM table1; +SELECT MIN(value), MIN(id) FROM table1; +SELECT COUNT(*), COUNT(id), COUNT(value) FROM table1; +SELECT SUM(id), SUM(value) FROM table1; +SELECT AVG(id), AVG(value) FROM table1; +SELECT AVG(id), id FROM table1; diff --git a/test_file b/test_file deleted file mode 100644 index 8149bf0..0000000 --- a/test_file +++ /dev/null @@ -1,38 +0,0 @@ - CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); - INSERT INTO tbl VALUES( 'hello', 1, 11, 'q' ); - INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); - INSERT INTO tbl VALUES( 'byebye', NULL, 33, 'e' ); - SELECT * FROM tbl WHERE one EQUAL 'byebye'; - SELECT one, three FROM tbl WHERE two NOT NULL; - SELECT * FROM tbl WHERE one NOT 'goodbye' AND two EQUAL NULL; - SELECT * FROM tbl WHERE one IN ('goodbye', 'byebye'); - SELECT * FROM tbl WHERE one NOTIN ('goodbye', 'byebye'); - SELECT * FROM tbl WHERE FALSE; - SELECT * FROM tbl LIMIT 1; - SELECT * FROM tbl OFFSET 1; - SELECT * FROM tbl LIMIT 1 OFFSET 1; - DELETE FROM tbl WHERE one EQUAL 'byebye'; - SELECT * FROM tbl; - SELECT one FROM tbl WHERE TRUE ORDER BY two ASC, four DESC; - UPDATE tbl SET two TO NULL, four TO 'P' WHERE one EQUAL 'goodbye'; - SELECT * FROM tbl; - INSERT INTO tbl VALUES( 'goodbye', 5, 22, 'P' ); - SELECT DISTINCT * FROM tbl; - DROP TABLE tbl; - CREATE TABLE table1( id INT, value TEXT); - CREATE TABLE table2( id INT, value TEXT); - INSERT INTO table1 VALUES(1, 'Value1'); - INSERT INTO table1 VALUES(2, NULL); - INSERT INTO table2 VALUES(2, 'Value2'); - INSERT INTO table2 VALUES(3, 'Value3'); - SELECT table1.value, table2.value FROM table1 FULL JOIN table2 ON table1.id EQUAL table2.id; - SELECT table1.value, table2.value FROM table1 INNER JOIN table2 ON table1.id EQUAL table2.id; - SELECT table1.value, table2.value FROM table1 LEFT JOIN table2 ON table1.id EQUAL table2.id; - SELECT table1.value, table2.value FROM table1 RIGHT JOIN table2 ON table1.id EQUAL table2.id; - INSERT INTO table1 VALUES(3, 'Value3'); - SELECT MAX(id), MAX(value) FROM table1; - SELECT MIN(value), MIN(id) FROM table1; - SELECT COUNT(*), COUNT(id), COUNT(value) FROM table1; - SELECT SUM(id), SUM(value) FROM table1; - SELECT AVG(id), AVG(value) FROM table1; - SELECT AVG(id), id FROM table1; \ No newline at end of file From 87bed280ad9c50673676ab3a40a39b83d66364ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sara=20Ryfczy=C5=84ska?= Date: Fri, 17 Jan 2025 00:59:40 +0100 Subject: [PATCH 20/21] Feature - apostrophe error validate (#29) * Add apostrophe validation errors * refactor * Rewrite getExpression logic --- .../1_select_with_where_expected_output | 7 + e2e/test_files/1_select_with_where_test | 1 + parser/parser.go | 186 ++++++++---------- parser/parser_error_handling_test.go | 9 +- parser/parser_test.go | 25 ++- 5 files changed, 117 insertions(+), 111 deletions(-) diff --git a/e2e/expected_outputs/1_select_with_where_expected_output b/e2e/expected_outputs/1_select_with_where_expected_output index ba58bca..3a89f6a 100644 --- a/e2e/expected_outputs/1_select_with_where_expected_output +++ b/e2e/expected_outputs/1_select_with_where_expected_output @@ -33,3 +33,10 @@ Data Inserted | one | two | three | four | +-----+-----+-------+------+ +-----+-----+-------+------+ ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'hello' | 1 | 11 | 'q' | +| 'goodbye' | 1 | 22 | 'w' | +| 'byebye' | NULL | 33 | 'e' | ++-----------+------+-------+------+ diff --git a/e2e/test_files/1_select_with_where_test b/e2e/test_files/1_select_with_where_test index cd8b6ce..00156cb 100644 --- a/e2e/test_files/1_select_with_where_test +++ b/e2e/test_files/1_select_with_where_test @@ -10,3 +10,4 @@ SELECT * FROM tbl WHERE one NOT 'goodbye' AND two EQUAL NULL; SELECT * FROM tbl WHERE one IN ('goodbye', 'byebye'); SELECT * FROM tbl WHERE one NOTIN ('goodbye', 'byebye'); SELECT * FROM tbl WHERE FALSE; +SELECT * FROM tbl WHERE 'colName1 EQUAL;' EQUAL 'colName1 EQUAL;'; diff --git a/parser/parser.go b/parser/parser.go index ce5e5d4..3cc7e0e 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -707,126 +707,118 @@ func (parser *Parser) parseUpdateCommand() (ast.Command, error) { // - ast.ConditionExpression // - ast.ContainExpression func (parser *Parser) getExpression() (bool, ast.Expression, error) { - booleanExpressionExists, booleanExpression := parser.getBooleanExpression() - conditionalExpressionExists, conditionalExpression, err := parser.getConditionalExpression() - if err != nil { - return false, nil, err - } + if parser.currentToken.Type == token.IDENT || + parser.currentToken.Type == token.LITERAL || + parser.currentToken.Type == token.NULL || + parser.currentToken.Type == token.APOSTROPHE || + parser.currentToken.Type == token.TRUE || + parser.currentToken.Type == token.FALSE { - containExpressionExists, containExpression, err := parser.getContainExpression() - if err != nil { - return false, nil, err - } + leftSide, isAnonymitifier, err := parser.getExpressionLeftSideValue() + if err != nil { + return false, nil, err + } - operationExpressionExists, operationExpression, err := parser.getOperationExpression(booleanExpressionExists, conditionalExpressionExists, containExpressionExists, booleanExpression, conditionalExpression, containExpression) - if err != nil { - return false, nil, err - } + isValidExpression := false + var expression ast.Expression + + if parser.currentToken.Type == token.EQUAL || parser.currentToken.Type == token.NOT { + isValidExpression, expression, err = parser.getConditionalExpression(leftSide, isAnonymitifier) + } else if parser.currentToken.Type == token.IN || parser.currentToken.Type == token.NOTIN { + isValidExpression, expression, err = parser.getContainExpression(leftSide, isAnonymitifier) + } else if leftSide.Type == token.TRUE || leftSide.Type == token.FALSE { + expression = &ast.BooleanExpression{Boolean: leftSide} + isValidExpression = true + err = nil + } - if operationExpressionExists { - return true, operationExpression, err - } + if err != nil { + return false, nil, err + } - if conditionalExpressionExists { - return true, conditionalExpression, err - } + if (parser.currentToken.Type == token.AND || parser.currentToken.Type == token.OR) && isValidExpression { + isValidExpression, expression, err = parser.getOperationExpression(expression) + } - if containExpressionExists { - return true, containExpression, err - } + if err != nil { + return false, nil, err + } - if booleanExpressionExists { - return true, booleanExpression, err + if isValidExpression { + return true, expression, nil + } } - - return false, nil, err + return false, nil, nil } -// getOperationExpression - Return ast.OperationExpression created from tokens and validate the syntax -func (parser *Parser) getOperationExpression(booleanExpressionExists bool, conditionalExpressionExists bool, containExpressionExists bool, booleanExpression *ast.BooleanExpression, conditionalExpression *ast.ConditionExpression, containExpression *ast.ContainExpression) (bool, *ast.OperationExpression, error) { - operationExpression := &ast.OperationExpression{} - - if (booleanExpressionExists || conditionalExpressionExists || containExpressionExists) && (parser.currentToken.Type == token.OR || parser.currentToken.Type == token.AND) { - if booleanExpressionExists { - operationExpression.Left = booleanExpression - } +func (parser *Parser) getExpressionLeftSideValue() (token.Token, bool, error) { + var leftSide token.Token + isAnonymitifier := false + startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() - if conditionalExpressionExists { - operationExpression.Left = conditionalExpression + if startedWithApostrophe { + isAnonymitifier = true + value := "" + for parser.currentToken.Type != token.EOF && parser.currentToken.Type != token.APOSTROPHE { + value += parser.currentToken.Literal + parser.nextToken() } - if containExpressionExists { - operationExpression.Left = containExpression - } + leftSide = token.Token{Type: token.IDENT, Literal: value} - operationExpression.Operation = parser.currentToken - parser.nextToken() + finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() - expressionIsValid, expression, err := parser.getExpression() + err := validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, leftSide) if err != nil { - return false, nil, err - } - if !expressionIsValid { - return false, nil, &LogicalExpressionParsingError{afterToken: &operationExpression.Operation.Literal} + return token.Token{}, isAnonymitifier, err } - - operationExpression.Right = expression - - return true, operationExpression, nil + } else { + leftSide = parser.currentToken + parser.nextToken() } - - return false, operationExpression, nil + return leftSide, isAnonymitifier, nil } -// getBooleanExpression - Return ast.BooleanExpression created from tokens and validate the syntax -func (parser *Parser) getBooleanExpression() (bool, *ast.BooleanExpression) { - booleanExpression := &ast.BooleanExpression{} - isValid := false +// getOperationExpression - Return ast.OperationExpression created from tokens and validate the syntax +func (parser *Parser) getOperationExpression(expression ast.Expression) (bool, *ast.OperationExpression, error) { + operationExpression := &ast.OperationExpression{} + operationExpression.Left = expression - if parser.currentToken.Type == token.TRUE || parser.currentToken.Type == token.FALSE { - booleanExpression.Boolean = parser.currentToken - parser.nextToken() - isValid = true - } + operationExpression.Operation = parser.currentToken + parser.nextToken() - return isValid, booleanExpression -} + expressionIsValid, expression, err := parser.getExpression() -// getConditionalExpression - Return ast.ConditionExpression created from tokens and validate the syntax -func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression, error) { - conditionalExpression := &ast.ConditionExpression{} - - err := validateToken(parser.peekToken.Type, []token.Type{token.EQUAL, token.NOT}) if err != nil { - return false, nil, nil + return false, nil, err } - conditionalExpression.Condition = parser.peekToken - if parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { - startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + if !expressionIsValid { + return false, nil, &LogicalExpressionParsingError{afterToken: &operationExpression.Operation.Literal} + } - if !startedWithApostrophe && parser.currentToken.Type == token.IDENT { - conditionalExpression.Left = ast.Identifier{Token: parser.currentToken} - } else { - conditionalExpression.Left = ast.Anonymitifier{Token: parser.currentToken} - } - parser.nextToken() + operationExpression.Right = expression - finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() - err := validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, conditionalExpression.Left.GetToken()) - if err != nil { - return false, nil, err - } + return true, operationExpression, nil +} + +// getConditionalExpression - Return ast.ConditionExpression created from tokens and validate the syntax +func (parser *Parser) getConditionalExpression(leftSide token.Token, isAnonymitifier bool) (bool, *ast.ConditionExpression, error) { + conditionalExpression := &ast.ConditionExpression{Condition: parser.currentToken} + + if isAnonymitifier { + conditionalExpression.Left = ast.Anonymitifier{Token: leftSide} } else { - return false, conditionalExpression, nil + conditionalExpression.Left = ast.Identifier{Token: leftSide} } // skip EQUAL or NOT parser.nextToken() - if parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { + if parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || + parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() if !startedWithApostrophe && parser.currentToken.Type == token.IDENT { @@ -837,7 +829,7 @@ func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression parser.nextToken() finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() - err = validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, conditionalExpression.Right.GetToken()) + err := validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, conditionalExpression.Right.GetToken()) if err != nil { return false, nil, err } @@ -849,31 +841,25 @@ func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression } // getContainExpression - Return ast.ContainExpression created from tokens and validate the syntax -func (parser *Parser) getContainExpression() (bool, *ast.ContainExpression, error) { +func (parser *Parser) getContainExpression(leftSide token.Token, isAnonymitifier bool) (bool, *ast.ContainExpression, error) { containExpression := &ast.ContainExpression{} - err := validateToken(parser.peekToken.Type, []token.Type{token.IN, token.NOTIN}) - if err != nil { - return false, nil, nil + if isAnonymitifier { + return false, nil, &SyntaxError{expecting: []string{token.IDENT}, got: "'" + leftSide.Literal + "'"} } - if parser.peekToken.Type == token.IN { + + containExpression.Left = ast.Identifier{Token: leftSide} + + if parser.currentToken.Type == token.IN { containExpression.Contains = true } else { containExpression.Contains = false } - err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) - if err != nil { - return false, nil, nil - } - containExpression.Left = ast.Identifier{Token: parser.currentToken} - - parser.nextToken() - // skip IN or NOTIN parser.nextToken() - err = validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + err := validateTokenAndSkip(parser, []token.Type{token.LPAREN}) if err != nil { return false, nil, err } diff --git a/parser/parser_error_handling_test.go b/parser/parser_error_handling_test.go index cfd17c0..a823169 100644 --- a/parser/parser_error_handling_test.go +++ b/parser/parser_error_handling_test.go @@ -129,6 +129,7 @@ func TestParseWhereCommandErrorHandling(t *testing.T) { selectCommandPrefix := "SELECT * FROM tbl " noPredecessorError := NoPredecessorParserError{command: token.WHERE} noColName := LogicalExpressionParsingError{} + noLeftAphostrophe := LogicalExpressionParsingError{} noOperatorInsideWhereStatementException := LogicalExpressionParsingError{} valueIsMissing := SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL, token.NULL}, got: token.SEMICOLON} tokenAnd := token.AND @@ -138,10 +139,12 @@ func TestParseWhereCommandErrorHandling(t *testing.T) { noLeftParGotSemicolon := SyntaxError{expecting: []string{token.LPAREN}, got: ";"} noLeftParGotNumber := SyntaxError{expecting: []string{token.LPAREN}, got: token.LITERAL} noComma := SyntaxError{expecting: []string{token.COMMA, token.RPAREN}, got: token.LITERAL} + anonymitifierInContains := SyntaxError{expecting: []string{token.IDENT}, got: "'one'"} noInKeywordException := LogicalExpressionParsingError{} noLeftApostropheGoodbye := NoApostropheOnLeftParserError{ident: "goodbye"} noLeftApostropheFive := NoApostropheOnLeftParserError{ident: "5"} noRightApostropheGoodbye := NoApostropheOnRightParserError{ident: "goodbye"} + noRightApostropheGoodbyeBigger := NoApostropheOnRightParserError{ident: "goodbye EQUAL two"} noRightApostropheFive := NoApostropheOnRightParserError{ident: "5"} tests := []errorHandlingTestSuite{ @@ -157,12 +160,12 @@ func TestParseWhereCommandErrorHandling(t *testing.T) { {selectCommandPrefix + "WHERE one IN (5 6);", noComma.Error()}, {selectCommandPrefix + "WHERE one IN ('5", noRightApostropheFive.Error()}, {selectCommandPrefix + "WHERE one IN (5');", noLeftApostropheFive.Error()}, + {selectCommandPrefix + "WHERE 'one' IN (5);", anonymitifierInContains.Error()}, {selectCommandPrefix + "WHERE one (5, 6);", noInKeywordException.Error()}, {selectCommandPrefix + "WHERE one EQUAL goodbye';", noLeftApostropheGoodbye.Error()}, {selectCommandPrefix + "WHERE one EQUAL 'goodbye", noRightApostropheGoodbye.Error()}, - // TODO: Add after fix apostrophe on left side of condition - //{selectCommandPrefix + "WHERE 'goodbye EQUAL two", noRightApostropheGoodbye.Error()}, - //{selectCommandPrefix + "WHERE goodbye' EQUAL two", noLeftApostropheGoodbye.Error()}, + {selectCommandPrefix + "WHERE 'goodbye EQUAL two", noRightApostropheGoodbyeBigger.Error()}, + {selectCommandPrefix + "WHERE goodbye' EQUAL two", noLeftAphostrophe.Error()}, } runParserErrorHandlingSuite(t, tests) diff --git a/parser/parser_test.go b/parser/parser_test.go index c887986..0823da4 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -716,20 +716,20 @@ func TestParseUpdateCommand(t *testing.T) { }{ { input: "UPDATE tbl SET colName TO 5;", expectedTableName: "tbl", expectedChanges: map[token.Token]ast.Anonymitifier{ - {Type: token.IDENT, Literal: "colName"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, - }, + {Type: token.IDENT, Literal: "colName"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + }, }, { input: "UPDATE tbl1 SET colName1 TO 'hi hello', colName2 TO 5;", expectedTableName: "tbl1", expectedChanges: map[token.Token]ast.Anonymitifier{ - {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.IDENT, Literal: "hi hello"}}, - {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, - }, + {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.IDENT, Literal: "hi hello"}}, + {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + }, }, { input: "UPDATE tbl1 SET colName1 TO NULL, colName2 TO 'NULL';", expectedTableName: "tbl1", expectedChanges: map[token.Token]ast.Anonymitifier{ - {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.NULL, Literal: "NULL"}}, - {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "NULL"}}, - }, + {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.NULL, Literal: "NULL"}}, + {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "NULL"}}, + }, }, } @@ -834,6 +834,11 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { Boolean: token.Token{Type: token.TRUE, Literal: "TRUE"}, } + fourthExpression := ast.ConditionExpression{ + Left: ast.Anonymitifier{Token: token.Token{Type: token.IDENT, Literal: "colName1 EQUAL;"}}, + Right: ast.Anonymitifier{Token: token.Token{Type: token.IDENT, Literal: "colName1 EQUAL;"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}} + tests := []struct { input string expectedExpression ast.Expression @@ -850,6 +855,10 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { input: "SELECT * FROM TBL WHERE TRUE;", expectedExpression: thirdExpression, }, + { + input: "SELECT * FROM TBL WHERE 'colName1 EQUAL;' EQUAL 'colName1 EQUAL;';", + expectedExpression: fourthExpression, + }, } for testIndex, tt := range tests { From c476fb460c86aee70058045220d6c1238ec66646 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sara=20Ryfczy=C5=84ska?= Date: Fri, 23 May 2025 23:01:14 +0200 Subject: [PATCH 21/21] Refactor and improvements (#31) * Upgrade go version and fix issues from Go Report Card * Fix errorhandling tests * Resolve go report card feed * Refactore engine functions * Change getSelectResponse function * Refactor select response logic * Improve comments --------- Co-authored-by: ixior462 --- .github/workflows/end2end-tests.yml | 2 +- .github/workflows/unit-tests.yml | 2 +- ast/ast.go | 34 +-- e2e/e2e_test.sh | 0 engine/column.go | 2 +- engine/engine.go | 368 ++++++++++----------------- engine/engine_error_handling_test.go | 2 +- engine/errors.go | 22 +- engine/generic_value.go | 15 -- engine/query_processor.go | 178 +++++++++++++ engine/row.go | 3 +- engine/table.go | 42 +-- go.mod | 2 +- parser/parser.go | 363 ++++++++++++++------------ parser/parser_test.go | 2 +- token/token.go | 85 ++++--- 16 files changed, 593 insertions(+), 529 deletions(-) mode change 100644 => 100755 e2e/e2e_test.sh create mode 100644 engine/query_processor.go diff --git a/.github/workflows/end2end-tests.yml b/.github/workflows/end2end-tests.yml index 8f6f0c9..0079f03 100644 --- a/.github/workflows/end2end-tests.yml +++ b/.github/workflows/end2end-tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.21.13', '1.22.7', '1.23.1' ] + go: [ '1.23.6', '1.24.0' ] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 835a56b..82e6a3c 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.21.13', '1.22.7', '1.23.1' ] + go: [ '1.23.6', '1.24.0' ] steps: - uses: actions/checkout@v3 diff --git a/ast/ast.go b/ast/ast.go index 5aa7657..3c10f13 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -49,9 +49,8 @@ type Tifier interface { func (p *Sequence) TokenLiteral() string { if len(p.Commands) > 0 { return p.Commands[0].TokenLiteral() - } else { - return "" } + return "" } // Identifier - Represent Token with string value that is equal to either column or table name @@ -62,7 +61,7 @@ type Identifier struct { func (ls Identifier) IsIdentifier() bool { return true } func (ls Identifier) GetToken() token.Token { return ls.Token } -// Anonymitifier - Represent Token with string value that is equal to simple value that is put into columns +// Anonymitifier - Represent Token with a string value that is equal to a simple value that is put into columns type Anonymitifier struct { Token token.Token // the token.IDENT token } @@ -107,7 +106,7 @@ func (ls ConditionExpression) GetIdentifiers() []Identifier { return identifiers } -// ContainExpression - TokenType of Expression that represents structure for IN operator +// ContainExpression - TokenType of Expression that represents structure for-IN-operator // // Example: // colName IN ('value1', 'value2', 'value3') @@ -205,7 +204,7 @@ type SelectCommand struct { func (ls SelectCommand) CommandNode() {} func (ls SelectCommand) TokenLiteral() string { return ls.Token.Literal } -func (ls *SelectCommand) AggregateFunctionAppears() bool { +func (ls SelectCommand) AggregateFunctionAppears() bool { for _, space := range ls.Space { if space.ContainsAggregateFunc() { return true @@ -223,10 +222,7 @@ func (ls *SelectCommand) AggregateFunctionAppears() bool { // SELECT * FROM table; // Returns false func (ls SelectCommand) HasWhereCommand() bool { - if ls.WhereCommand == nil { - return false - } - return true + return ls.WhereCommand != nil } // HasOrderByCommand - returns true if optional OrderByCommand is present in SelectCommand @@ -238,10 +234,7 @@ func (ls SelectCommand) HasWhereCommand() bool { // SELECT * FROM table; // Returns false func (ls SelectCommand) HasOrderByCommand() bool { - if ls.OrderByCommand == nil { - return false - } - return true + return ls.OrderByCommand != nil } // HasLimitCommand - returns true if optional LimitCommand is present in SelectCommand @@ -253,10 +246,7 @@ func (ls SelectCommand) HasOrderByCommand() bool { // SELECT * FROM table; // Returns false func (ls SelectCommand) HasLimitCommand() bool { - if ls.LimitCommand == nil { - return false - } - return true + return ls.LimitCommand != nil } // HasOffsetCommand - returns true if optional OffsetCommand is present in SelectCommand @@ -268,10 +258,7 @@ func (ls SelectCommand) HasLimitCommand() bool { // SELECT * FROM table LIMIT 10; // Returns false func (ls SelectCommand) HasOffsetCommand() bool { - if ls.OffsetCommand == nil { - return false - } - return true + return ls.OffsetCommand != nil } // HasJoinCommand - returns true if optional JoinCommand is present in SelectCommand @@ -283,10 +270,7 @@ func (ls SelectCommand) HasOffsetCommand() bool { // SELECT * FROM table; // Returns false func (ls SelectCommand) HasJoinCommand() bool { - if ls.JoinCommand == nil { - return false - } - return true + return ls.JoinCommand != nil } // UpdateCommand - Part of Command that allow to change existing data diff --git a/e2e/e2e_test.sh b/e2e/e2e_test.sh old mode 100644 new mode 100755 diff --git a/engine/column.go b/engine/column.go index 419b157..d583b90 100644 --- a/engine/column.go +++ b/engine/column.go @@ -4,7 +4,7 @@ import ( "github.com/LissaGreense/GO4SQL/token" ) -// Column - part of the Table containing name of Column and values in it +// Column - part of the Table containing the name of Column and values in it type Column struct { Name string Type token.Token diff --git a/engine/engine.go b/engine/engine.go index 4ddf188..658d569 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -28,8 +28,11 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) (string, error) { commands := sequences.Commands result := "" + var err error for _, command := range commands { - + if err != nil { + return "", err + } switch mappedCommand := command.(type) { case *ast.WhereCommand: continue @@ -42,33 +45,22 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) (string, error) { case *ast.JoinCommand: continue case *ast.CreateCommand: - err := engine.createTable(mappedCommand) - if err != nil { - return "", err - } + err = engine.createTable(mappedCommand) result += "Table '" + mappedCommand.Name.GetToken().Literal + "' has been created\n" continue case *ast.InsertCommand: - err := engine.insertIntoTable(mappedCommand) - if err != nil { - return "", err - } + err = engine.insertIntoTable(mappedCommand) result += "Data Inserted\n" continue case *ast.SelectCommand: - selectOutput, err := engine.getSelectResponse(mappedCommand) - if err != nil { - return "", err - } + var selectOutput *Table + selectOutput, err = engine.getSelectResponse(mappedCommand) result += selectOutput.ToString() + "\n" continue case *ast.DeleteCommand: deleteCommand := command.(*ast.DeleteCommand) if deleteCommand.HasWhereCommand() { - err := engine.deleteFromTable(mappedCommand, deleteCommand.WhereCommand) - if err != nil { - return "", err - } + err = engine.deleteFromTable(mappedCommand, deleteCommand.WhereCommand) } result += "Data from '" + mappedCommand.Name.GetToken().Literal + "' has been deleted\n" continue @@ -77,75 +69,62 @@ func (engine *DbEngine) Evaluate(sequences *ast.Sequence) (string, error) { result += "Table: '" + mappedCommand.Name.GetToken().Literal + "' has been dropped\n" continue case *ast.UpdateCommand: - err := engine.updateTable(mappedCommand) - if err != nil { - return "", err - } + err = engine.updateTable(mappedCommand) result += "Table: '" + mappedCommand.Name.GetToken().Literal + "' has been updated\n" continue default: return "", &UnsupportedCommandTypeFromParserError{variable: fmt.Sprintf("%s", command)} } } - - return result, nil + return result, err } -// getSelectResponse - Returns Select response basing on ast.OrderByCommand and ast.WhereCommand included in this Select +// getSelectResponse - processes a SELECT query represented by the ast.SelectCommand and applies a pipeline of +// transformations based on options applied to ast.SelectCommand func (engine *DbEngine) getSelectResponse(selectCommand *ast.SelectCommand) (*Table, error) { var table *Table var err error if selectCommand.HasJoinCommand() { - joinCommand := selectCommand.JoinCommand - table, err = engine.joinTables(joinCommand, selectCommand.Name.Token.Literal) - if err != nil { - return nil, err - } + table, err = engine.joinTables(selectCommand.JoinCommand, selectCommand.Name.Token.Literal) } else { - var exist bool - table, exist = engine.Tables[selectCommand.Name.Token.Literal] - - if !exist { + var exists bool + table, exists = engine.Tables[selectCommand.Name.Token.Literal] + if !exists { return nil, &TableDoesNotExistError{selectCommand.Name.Token.Literal} } } + if err != nil { + return nil, err + } + + processor := NewSelectProcessor(engine, selectCommand) + + // Build the transformation pipeline using the builder pattern + if selectCommand.HasOrderByCommand() { + processor.WithOrderByClause() + } + if selectCommand.HasWhereCommand() { - whereCommand := selectCommand.WhereCommand - if selectCommand.HasOrderByCommand() { - orderByCommand := selectCommand.OrderByCommand - table, err = engine.selectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand, table) - if err != nil { - return nil, err - } - } else { - table, err = engine.selectFromTableWithWhere(selectCommand, whereCommand, table) - if err != nil { - return nil, err - } - } - } else if selectCommand.HasOrderByCommand() { - table, err = engine.selectFromTableWithOrderBy(selectCommand, selectCommand.OrderByCommand, table) - if err != nil { - return nil, err - } - } else { - table, err = engine.selectFromProvidedTable(selectCommand, table) - if err != nil { - return nil, err - } + processor.WithWhereClause() } - if selectCommand.HasLimitCommand() || selectCommand.HasOffsetCommand() { - table.applyOffsetAndLimit(selectCommand) + // If no WHERE or ORDER BY, the vanilla select (projection) is applied first. + // Otherwise, WHERE/ORDER BY are applied, and then the projection happens within them (handled by their respective transformers). + if !selectCommand.HasOrderByCommand() && !selectCommand.HasWhereCommand() { + processor.WithVanillaSelectClause() + } + + if selectCommand.HasOffsetCommand() || selectCommand.HasLimitCommand() { + processor.WithOffsetLimitClause() } if selectCommand.HasDistinct { - table = table.getDistinctTable() + processor.WithDistinctClause() } - return table, nil + return processor.Process(table) } // createTable - initialize new table in engine with specified name @@ -169,45 +148,52 @@ func (engine *DbEngine) createTable(command *ast.CreateCommand) error { } func (engine *DbEngine) updateTable(command *ast.UpdateCommand) error { - table, exist := engine.Tables[command.Name.Token.Literal] - - if !exist { + table, exists := engine.Tables[command.Name.Token.Literal] + if !exists { return &TableDoesNotExistError{command.Name.Token.Literal} } - columns := table.Columns + columnIndices := make(map[string]int, len(table.Columns)) + for i, col := range table.Columns { + columnIndices[col.Name] = i + } - // TODO: This could be optimized - mappedChanges := make(map[int]ast.Anonymitifier) - for updatedCol, newValue := range command.Changes { - for colIndex := 0; colIndex < len(columns); colIndex++ { - if columns[colIndex].Name == updatedCol.Literal { - mappedChanges[colIndex] = newValue - break - } - if colIndex == len(columns)-1 { - return &ColumnDoesNotExistError{tableName: command.Name.GetToken().Literal, columnName: updatedCol.Literal} + // Map changes to column indices + type change struct { + index int + value ast.Anonymitifier + } + + changes := make([]change, 0, len(command.Changes)) + for colToken, newValue := range command.Changes { + colName := colToken.Literal + colIndex, ok := columnIndices[colName] + if !ok { + return &ColumnDoesNotExistError{ + tableName: command.Name.Token.Literal, + columnName: colName, } } + changes = append(changes, change{index: colIndex, value: newValue}) } - numberOfRows := len(columns[0].Values) - for rowIndex := 0; rowIndex < numberOfRows; rowIndex++ { + for rowIndex := 0; rowIndex < len(table.Columns[0].Values); rowIndex++ { if command.HasWhereCommand() { - fulfilledFilters, err := isFulfillingFilters(getRow(table, rowIndex), command.WhereCommand.Expression, command.WhereCommand.Token.Literal) + matches, err := isFulfillingFilters(getRow(table, rowIndex), command.WhereCommand.Expression, command.WhereCommand.Token.Literal) if err != nil { return err } - if !fulfilledFilters { + if !matches { continue } } - for colIndex, value := range mappedChanges { - interfaceValue, err := getInterfaceValue(value.GetToken()) + + for _, change := range changes { + val, err := getInterfaceValue(change.value.GetToken()) if err != nil { return err } - table.Columns[colIndex].Values[rowIndex] = interfaceValue + table.Columns[change.index].Values[rowIndex] = val } } @@ -224,13 +210,23 @@ func (engine *DbEngine) insertIntoTable(command *ast.InsertCommand) error { columns := table.Columns if len(command.Values) != len(columns) { - return &InvalidNumberOfParametersError{expectedNumber: len(columns), actualNumber: len(command.Values), commandName: command.Token.Literal} + return &InvalidNumberOfParametersError{ + expectedNumber: len(columns), + actualNumber: len(command.Values), + commandName: command.Token.Literal, + } } for i := range columns { expectedToken := tokenMapper(columns[i].Type.Type) - if (expectedToken != command.Values[i].Type) && (command.Values[i].Type != token.NULL) { - return &InvalidValueTypeError{expectedType: string(expectedToken), actualType: string(command.Values[i].Type), commandName: command.Token.Literal} + colValueType := command.Values[i].Type + + if (expectedToken != colValueType) && (colValueType != token.NULL) { + return &InvalidValueTypeError{ + expectedType: string(expectedToken), + actualType: string(colValueType), + commandName: command.Token.Literal, + } } interfaceValue, err := getInterfaceValue(command.Values[i]) if err != nil { @@ -249,11 +245,8 @@ func (engine *DbEngine) selectFromProvidedTable(command *ast.SelectCommand, tabl selectedTable := &Table{Columns: make([]*Column, 0)} for i := 0; i < len(command.Space); i++ { - var columnType token.Token - var columnName string + col := &Column{} var columnValues []ValueInterface - var err error - value := make([]ValueInterface, 0) currentSpace := command.Space[i] if currentSpace.ColumnName.Type == token.ASTERISK && currentSpace.AggregateFunc.Type == token.COUNT { @@ -261,33 +254,28 @@ func (engine *DbEngine) selectFromProvidedTable(command *ast.SelectCommand, tabl columnValues = columns[0].Values } } else { + var err error columnValues, err = getValuesOfColumn(currentSpace.ColumnName.Literal, columns) - } - - if err != nil { - return nil, err + if err != nil { + return nil, err + } } if currentSpace.ContainsAggregateFunc() { - columnName = fmt.Sprintf("%s(%s)", currentSpace.AggregateFunc.Literal, + col.Name = fmt.Sprintf("%s(%s)", currentSpace.AggregateFunc.Literal, currentSpace.ColumnName.Literal) - columnType = evaluateColumnTypeOfAggregateFunc(currentSpace) - aggregatedValue, aggregateErr := aggregateColumnContent(currentSpace, columnValues) - if aggregateErr != nil { - return nil, aggregateErr + col.Type = evaluateColumnTypeOfAggregateFunc(currentSpace) + aggregatedValue, err := aggregateColumnContent(currentSpace, columnValues) + if err != nil { + return nil, err } - value = append(value, aggregatedValue) + col.Values = []ValueInterface{aggregatedValue} } else { - columnName = currentSpace.ColumnName.Literal - columnType = currentSpace.ColumnName - value = append(value, columnValues[0]) + col.Name = currentSpace.ColumnName.Literal + col.Type = currentSpace.ColumnName + col.Values = []ValueInterface{columnValues[0]} } - - selectedTable.Columns = append(selectedTable.Columns, &Column{ - Name: columnName, - Type: columnType, - Values: value, - }) + selectedTable.Columns = append(selectedTable.Columns, col) } return selectedTable, nil } else if command.Space[0].ColumnName.Type == token.ASTERISK { @@ -322,63 +310,59 @@ func evaluateColumnTypeOfAggregateFunc(space ast.Space) token.Token { func aggregateColumnContent(space ast.Space, columnValues []ValueInterface) (ValueInterface, error) { if space.AggregateFunc.Type == token.COUNT { - if space.ColumnName.Type == token.ASTERISK { - return IntegerValue{Value: len(columnValues)}, nil - } - count := 0 - for _, value := range columnValues { - if value.GetType() != NullType { - count++ - } - } - return IntegerValue{Value: count}, nil + return getCount(space, columnValues) } if len(columnValues) == 0 { return NullValue{}, nil } switch space.AggregateFunc.Type { case token.MAX: - maxValue, err := getMax(columnValues) - if err != nil { - return nil, err - } - return maxValue, nil + return getMax(columnValues) case token.MIN: - minValue, err := getMin(columnValues) - if err != nil { - return nil, err - } - return minValue, nil + return getMin(columnValues) case token.SUM: - if columnValues[0].GetType() == StringType { - return IntegerValue{Value: 0}, nil - } else { - sum := 0 - for _, value := range columnValues { - if value.GetType() != NullType { - num, err := strconv.Atoi(value.ToString()) - if err != nil { - return nil, err - } - sum += num - } - } - return IntegerValue{Value: sum}, nil - } + return getSum(columnValues) default: - if columnValues[0].GetType() == StringType { - return IntegerValue{Value: 0}, nil - } else { - sum := 0 - for _, value := range columnValues { + return getAvg(columnValues) + } +} + +func getCount(space ast.Space, columnValues []ValueInterface) (*IntegerValue, error) { + if space.ColumnName.Type == token.ASTERISK { + return &IntegerValue{Value: len(columnValues)}, nil + } + count := 0 + for _, value := range columnValues { + if value.GetType() != NullType { + count++ + } + } + return &IntegerValue{Value: count}, nil +} + +func getAvg(columnValues []ValueInterface) (*IntegerValue, error) { + sum, err := getSum(columnValues) + if err != nil { + return nil, err + } + return &IntegerValue{Value: sum.Value / len(columnValues)}, nil +} + +func getSum(columnValues []ValueInterface) (*IntegerValue, error) { + if columnValues[0].GetType() == StringType { + return &IntegerValue{Value: 0}, nil + } else { + sum := 0 + for _, value := range columnValues { + if value.GetType() != NullType { num, err := strconv.Atoi(value.ToString()) if err != nil { return nil, err } sum += num } - return IntegerValue{Value: sum / len(columnValues)}, nil } + return &IntegerValue{Value: sum}, nil } } @@ -405,54 +389,6 @@ func (engine *DbEngine) dropTable(dropCommand *ast.DropCommand) { delete(engine.Tables, dropCommand.Name.GetToken().Literal) } -// selectFromTableWithWhere - Return Table containing all values requested by SelectCommand and filtered by WhereCommand -func (engine *DbEngine) selectFromTableWithWhere(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand, table *Table) (*Table, error) { - if len(table.Columns) == 0 || len(table.Columns[0].Values) == 0 { - return engine.selectFromProvidedTable(selectCommand, &Table{Columns: []*Column{}}) - } - - filteredTable, err := engine.getFilteredTable(table, whereCommand, false, selectCommand.Name.GetToken().Literal) - - if err != nil { - return nil, err - } - - return engine.selectFromProvidedTable(selectCommand, filteredTable) -} - -// selectFromTableWithWhereAndOrderBy - Return Table containing all values requested by SelectCommand, -// filtered by WhereCommand and sorted by OrderByCommand -func (engine *DbEngine) selectFromTableWithWhereAndOrderBy(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand, orderByCommand *ast.OrderByCommand, table *Table) (*Table, error) { - filteredTable, err := engine.getFilteredTable(table, whereCommand, false, selectCommand.Name.GetToken().Literal) - - if err != nil { - return nil, err - } - - emptyTable := getCopyOfTableWithoutRows(table) - - sortedTable, err := engine.getSortedTable(orderByCommand, filteredTable, emptyTable, selectCommand.Name.GetToken().Literal) - - if err != nil { - return nil, err - } - - return engine.selectFromProvidedTable(selectCommand, sortedTable) -} - -// selectFromTableWithOrderBy - Return Table containing all values requested by SelectCommand and sorted by OrderByCommand -func (engine *DbEngine) selectFromTableWithOrderBy(selectCommand *ast.SelectCommand, orderByCommand *ast.OrderByCommand, table *Table) (*Table, error) { - emptyTable := getCopyOfTableWithoutRows(table) - - sortedTable, err := engine.getSortedTable(orderByCommand, table, emptyTable, selectCommand.Name.GetToken().Literal) - - if err != nil { - return nil, err - } - - return engine.selectFromProvidedTable(selectCommand, sortedTable) -} - func (engine *DbEngine) getSortedTable(orderByCommand *ast.OrderByCommand, table *Table, copyOfTable *Table, tableName string) (*Table, error) { sortPatterns := orderByCommand.SortPatterns @@ -623,34 +559,6 @@ func addColumnsWithPrefix(finalTable *Table, columnsToAdd []*Column, prefix stri } } -func (table *Table) applyOffsetAndLimit(command *ast.SelectCommand) { - var offset = 0 - var limitRaw = -1 - - if command.HasLimitCommand() { - limitRaw = command.LimitCommand.Count - } - if command.HasOffsetCommand() { - offset = command.OffsetCommand.Count - } - - for _, column := range table.Columns { - var limit int - - if limitRaw == -1 || limitRaw+offset > len(column.Values) { - limit = len(column.Values) - } else { - limit = limitRaw + offset - } - - if offset > len(column.Values) || limit == 0 { - column.Values = make([]ValueInterface, 0) - } else { - column.Values = column.Values[offset:limit] - } - } -} - func xor(fulfilledFilters bool, negation bool) bool { return (fulfilledFilters || negation) && !(fulfilledFilters && negation) } @@ -674,7 +582,7 @@ func isFulfillingFilters(row map[string]ValueInterface, expressionTree ast.Expre case *ast.OperationExpression: return processOperationExpression(row, mappedExpression, commandName) case *ast.BooleanExpression: - return processBooleanExpression(mappedExpression) + return processBooleanExpression(mappedExpression), nil case *ast.ConditionExpression: return processConditionExpression(row, mappedExpression, commandName) case *ast.ContainExpression: @@ -742,7 +650,7 @@ func processOperationExpression(row map[string]ValueInterface, operationExpressi } right, err := isFulfillingFilters(row, operationExpression.Right, commandName) - return left && right, err + return right, err } if operationExpression.Operation.Type == token.OR { @@ -752,17 +660,17 @@ func processOperationExpression(row map[string]ValueInterface, operationExpressi } right, err := isFulfillingFilters(row, operationExpression.Right, commandName) - return left || right, err + return right, err } return false, &UnsupportedOperationTokenError{operationExpression.Operation.Literal} } -func processBooleanExpression(booleanExpression *ast.BooleanExpression) (bool, error) { +func processBooleanExpression(booleanExpression *ast.BooleanExpression) bool { if booleanExpression.Boolean.Literal == token.TRUE { - return true, nil + return true } - return false, nil + return false } func getTifierValue(tifier ast.Tifier, row map[string]ValueInterface) (ValueInterface, error) { diff --git a/engine/engine_error_handling_test.go b/engine/engine_error_handling_test.go index f40f659..9c7f88b 100644 --- a/engine/engine_error_handling_test.go +++ b/engine/engine_error_handling_test.go @@ -118,7 +118,7 @@ func getErrorMessage(t *testing.T, input string, testIndex int) string { parserInstance := parser.New(lexerInstance) sequences, parserError := parserInstance.ParseSequence() if parserError != nil { - t.Fatalf("[%d] Error has occured in parser not in engine, error: %s", testIndex, parserError.Error()) + t.Fatalf("[%d] Error has occurred in parser not in engine, error: %s", testIndex, parserError.Error()) } engine := New() diff --git a/engine/errors.go b/engine/errors.go index 81910de..e0497ad 100644 --- a/engine/errors.go +++ b/engine/errors.go @@ -2,8 +2,8 @@ package engine import "strconv" -// TableAlreadyExistsError - error thrown when user tries to create table using name that already -// exists in database +// TableAlreadyExistsError - error thrown when a user tries to create table using name that already +// exists in a database type TableAlreadyExistsError struct { tableName string } @@ -12,7 +12,7 @@ func (m *TableAlreadyExistsError) Error() string { return "table with the name of " + m.tableName + " already exists" } -// TableDoesNotExistError - error thrown when user tries to make operation on un-existing table +// TableDoesNotExistError - error thrown when the user tries to make operation on an unexisting table type TableDoesNotExistError struct { tableName string } @@ -21,7 +21,7 @@ func (m *TableDoesNotExistError) Error() string { return "table with the name of " + m.tableName + " doesn't exist" } -// ColumnDoesNotExistError - error thrown when user tries to make operation on un-existing column +// ColumnDoesNotExistError - error thrown when the user tries to make operation on an unexisting column type ColumnDoesNotExistError struct { tableName string columnName string @@ -32,7 +32,7 @@ func (m *ColumnDoesNotExistError) Error() string { } // InvalidNumberOfParametersError - error thrown when user provides invalid number of expected parameters -// (ex. fewer values in insert than defined ) +// (ex. fewer values in insert than defined) type InvalidNumberOfParametersError struct { expectedNumber int actualNumber int @@ -43,7 +43,7 @@ func (m *InvalidNumberOfParametersError) Error() string { return "invalid number of parameters in " + m.commandName + " command, should be: " + strconv.Itoa(m.expectedNumber) + ", but got: " + strconv.Itoa(m.actualNumber) } -// InvalidValueTypeError - error thrown when user provides value of different type than expected +// InvalidValueTypeError - error thrown when a user provides value of a different type than expected type InvalidValueTypeError struct { expectedType string actualType string @@ -54,7 +54,7 @@ func (m *InvalidValueTypeError) Error() string { return "invalid value type provided in " + m.commandName + " command, expecting: " + m.expectedType + ", got: " + m.actualType } -// UnsupportedValueType - error thrown when engine found unsupported data type to be stored inside +// UnsupportedValueType - error thrown when the engine found unsupported data type to be stored inside // the columns type UnsupportedValueType struct { variable string @@ -74,8 +74,8 @@ func (m *UnsupportedOperationTokenError) Error() string { return "unsupported operation token has been used: " + m.variable } -// UnsupportedConditionalTokenError - error thrown when engine found unsupported conditional token -// inside expression (supported are: EQUAL, NOT) +// UnsupportedConditionalTokenError - error thrown when the engine found unsupported conditional token +// inside the expression (supported are: EQUAL, NOT) type UnsupportedConditionalTokenError struct { variable string commandName string @@ -85,7 +85,7 @@ func (m *UnsupportedConditionalTokenError) Error() string { return "operation '" + m.variable + "' provided in " + m.commandName + " command isn't allowed" } -// UnsupportedExpressionTypeError - error thrown when engine found unsupported expression type +// UnsupportedExpressionTypeError - error thrown when the engine found an unsupported expression type type UnsupportedExpressionTypeError struct { variable string commandName string @@ -95,7 +95,7 @@ func (m *UnsupportedExpressionTypeError) Error() string { return "unsupported expression has been used in " + m.commandName + "command: " + m.variable } -// UnsupportedCommandTypeFromParserError - error thrown when engine found unsupported command +// UnsupportedCommandTypeFromParserError - error thrown when the engine found unsupported command // from parser type UnsupportedCommandTypeFromParserError struct { variable string diff --git a/engine/generic_value.go b/engine/generic_value.go index d5c26c9..048f49f 100644 --- a/engine/generic_value.go +++ b/engine/generic_value.go @@ -2,7 +2,6 @@ package engine import ( "errors" - "fmt" "log" "strconv" ) @@ -38,20 +37,6 @@ type StringValue struct { type NullValue struct { } -// HandleValue - Function to take an instance of ValueInterface and cast to a specific implementation -func CastValueInterface(v ValueInterface) { - switch value := v.(type) { - case IntegerValue: - fmt.Printf("IntegerValue with Value: %d\n", value.Value) - case StringValue: - fmt.Printf("StringValue with Value: %s\n", value.Value) - case NullValue: - fmt.Println("NullValue (no value)") - default: - fmt.Println("Unknown type") - } -} - // ToString implementations func (value IntegerValue) ToString() string { return strconv.Itoa(value.Value) } func (value StringValue) ToString() string { return value.Value } diff --git a/engine/query_processor.go b/engine/query_processor.go new file mode 100644 index 0000000..46b4343 --- /dev/null +++ b/engine/query_processor.go @@ -0,0 +1,178 @@ +package engine + +import ( + "hash/adler32" + + "github.com/LissaGreense/GO4SQL/ast" + "github.com/LissaGreense/GO4SQL/token" +) + +// TableTransformer defines a function that takes a Table as input and then applies a transformation +type TableTransformer func(*Table) (*Table, error) + +// SelectProcessor handles the step-by-step processing of a SELECT query using a builder pattern. +type SelectProcessor struct { + engine *DbEngine + cmd *ast.SelectCommand + transformers []TableTransformer +} + +// NewSelectProcessor creates a new SelectProcessor. +func NewSelectProcessor(engine *DbEngine, cmd *ast.SelectCommand) *SelectProcessor { + return &SelectProcessor{ + engine: engine, + cmd: cmd, + transformers: []TableTransformer{}, // Initialize as an empty slice + } +} + +// WithVanillaSelectClause adds the vanilla select (projection) transformation. +func (sp *SelectProcessor) WithVanillaSelectClause() *SelectProcessor { + sp.transformers = append(sp.transformers, sp.getVanillaSelectTransformer()) + return sp +} + +// WithWhereClause adds the WHERE clause transformation. +func (sp *SelectProcessor) WithWhereClause() *SelectProcessor { + sp.transformers = append(sp.transformers, sp.getWhereTransformer()) + return sp +} + +// WithOrderByClause adds the ORDER BY clause transformation. +func (sp *SelectProcessor) WithOrderByClause() *SelectProcessor { + sp.transformers = append(sp.transformers, sp.getOrderByTransformer()) + return sp +} + +// WithOffsetLimitClause adds the OFFSET and LIMIT clause transformation. +func (sp *SelectProcessor) WithOffsetLimitClause() *SelectProcessor { + sp.transformers = append(sp.transformers, sp.getOffsetLimitTransformer()) + return sp +} + +// WithDistinctClause adds the DISTINCT clause transformation. +func (sp *SelectProcessor) WithDistinctClause() *SelectProcessor { + sp.transformers = append(sp.transformers, sp.getDistinctTransformer()) + return sp +} + +// Process applies the configured pipeline of transformations to the initialTable. +func (sp *SelectProcessor) Process(initialTable *Table) (*Table, error) { + table := initialTable + var err error + + for _, transform := range sp.transformers { + table, err = transform(table) + if err != nil { + return nil, err + } + } + return table, nil +} + +// --- Private Transformer Getters --- + +func (sp *SelectProcessor) getVanillaSelectTransformer() TableTransformer { + return func(tbl *Table) (*Table, error) { + return sp.engine.selectFromProvidedTable(sp.cmd, tbl) + } +} + +func (sp *SelectProcessor) getWhereTransformer() TableTransformer { + return func(tbl *Table) (*Table, error) { + if len(tbl.Columns) == 0 || (len(tbl.Columns) > 0 && len(tbl.Columns[0].Values) == 0) { + return sp.engine.selectFromProvidedTable(sp.cmd, &Table{Columns: []*Column{}}) + } + filtered, err := sp.engine.getFilteredTable(tbl, sp.cmd.WhereCommand, false, sp.cmd.Name.GetToken().Literal) + if err != nil { + return nil, err + } + return sp.engine.selectFromProvidedTable(sp.cmd, filtered) + } +} + +func (sp *SelectProcessor) getOrderByTransformer() TableTransformer { + return func(tbl *Table) (*Table, error) { + emptyTable := getCopyOfTableWithoutRows(tbl) + sorted, err := sp.engine.getSortedTable(sp.cmd.OrderByCommand, tbl, emptyTable, sp.cmd.Name.GetToken().Literal) + if err != nil { + return nil, err + } + return sp.engine.selectFromProvidedTable(sp.cmd, sorted) + } +} + +func (sp *SelectProcessor) getOffsetLimitTransformer() TableTransformer { + return func(tbl *Table) (*Table, error) { + var offset = 0 + var limitRaw = -1 + + if sp.cmd.HasLimitCommand() { + limitRaw = sp.cmd.LimitCommand.Count + } + if sp.cmd.HasOffsetCommand() { + offset = sp.cmd.OffsetCommand.Count + } + + if len(tbl.Columns) == 0 { + return tbl, nil + } + + for _, column := range tbl.Columns { + var limit int + + if limitRaw == -1 || limitRaw+offset > len(column.Values) { + limit = len(column.Values) + } else { + limit = limitRaw + offset + } + + if offset >= len(column.Values) { + column.Values = make([]ValueInterface, 0) + } else if offset < len(column.Values) && limit > offset { + column.Values = column.Values[offset:limit] + } else { + column.Values = make([]ValueInterface, 0) + } + } + return tbl, nil + } +} + +func (sp *SelectProcessor) getDistinctTransformer() TableTransformer { + return func(tbl *Table) (*Table, error) { + if len(tbl.Columns) == 0 || len(tbl.Columns[0].Values) == 0 { + return tbl, nil + } + + distinctTable := getCopyOfTableWithoutRows(tbl) + rowsCount := len(tbl.Columns[0].Values) + checksumSet := make(map[uint32]struct{}) + + for iRow := range rowsCount { + mergedColumnValues := "" + for iColumn := range tbl.Columns { + if iRow < len(tbl.Columns[iColumn].Values) { + fieldValue := tbl.Columns[iColumn].Values[iRow].ToString() + if tbl.Columns[iColumn].Type.Literal == token.TEXT { + fieldValue = "'" + fieldValue + "'" + } + mergedColumnValues += fieldValue + } else { + mergedColumnValues += "" + } + } + checksum := adler32.Checksum([]byte(mergedColumnValues)) + + if _, exist := checksumSet[checksum]; !exist { + checksumSet[checksum] = struct{}{} + for i, column := range distinctTable.Columns { + if iRow < len(tbl.Columns[i].Values) { + column.Values = append(column.Values, tbl.Columns[i].Values[iRow]) + } + } + } + } + return distinctTable, nil + } +} diff --git a/engine/row.go b/engine/row.go index 6160956..b31f881 100644 --- a/engine/row.go +++ b/engine/row.go @@ -12,8 +12,7 @@ func MapTableToRows(table *Table) Rows { numberOfRows := len(table.Columns[0].Values) for rowIndex := 0; rowIndex < numberOfRows; rowIndex++ { - row := getRow(table, rowIndex) - rows = append(rows, row) + rows = append(rows, getRow(table, rowIndex)) } return Rows{rows: rows} } diff --git a/engine/table.go b/engine/table.go index ec8561a..a3fedd3 100644 --- a/engine/table.go +++ b/engine/table.go @@ -2,14 +2,15 @@ package engine import ( "github.com/LissaGreense/GO4SQL/token" - "hash/adler32" ) // Table - Contain Columns that store values in engine type Table struct { - Columns []*Column + Columns Columns } +type Columns []*Column + func (table *Table) isEqual(secondTable *Table) bool { if len(table.Columns) != len(secondTable.Columns) { return false @@ -38,40 +39,11 @@ func (table *Table) isEqual(secondTable *Table) bool { return true } -// getDistinctTable - Takes input table, and returns new one without any duplicates -func (table *Table) getDistinctTable() *Table { - distinctTable := getCopyOfTableWithoutRows(table) - - rowsCount := len(table.Columns[0].Values) - - checksumSet := map[uint32]struct{}{} - - for iRow := 0; iRow < rowsCount; iRow++ { - - mergedColumnValues := "" - for iColumn := range table.Columns { - fieldValue := table.Columns[iColumn].Values[iRow].ToString() - if table.Columns[iColumn].Type.Literal == token.TEXT { - fieldValue = "'" + fieldValue + "'" - } - mergedColumnValues += fieldValue - } - checksum := adler32.Checksum([]byte(mergedColumnValues)) - - _, exist := checksumSet[checksum] - if !exist { - checksumSet[checksum] = struct{}{} - for i, column := range distinctTable.Columns { - column.Values = append(column.Values, table.Columns[i].Values[iRow]) - } - } - } - - return distinctTable -} - -// ToString - Return string contain all values and Column names in Table +// ToString - Return string contains all values and Column names in Table func (table *Table) ToString() string { + if table == nil { + return "" + } columWidths := getColumWidths(table.Columns) bar := getBar(columWidths) result := bar + "\n" diff --git a/go.mod b/go.mod index aed510b..d6e4689 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/LissaGreense/GO4SQL -go 1.21 +go 1.23 diff --git a/parser/parser.go b/parser/parser.go index 3cc7e0e..e600b4a 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -263,44 +263,9 @@ func (parser *Parser) parseSelectCommand() (ast.Command, error) { selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken}) parser.nextToken() } else { - for parser.currentToken.Type == token.IDENT || isAggregateFunction(parser.currentToken.Type) { - if parser.currentToken.Type != token.IDENT { - aggregateFunction := parser.currentToken - parser.nextToken() - err := validateTokenAndSkip(parser, []token.Type{token.LPAREN}) - if err != nil { - return nil, err - } - if aggregateFunction.Type == token.COUNT { - err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.ASTERISK}) - } else { - err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) - } - if err != nil { - return nil, err - } - selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken, AggregateFunc: &aggregateFunction}) - parser.nextToken() - - err = validateTokenAndSkip(parser, []token.Type{token.RPAREN}) - if err != nil { - return nil, err - } - } else { - // Get column name - err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) - if err != nil { - return nil, err - } - selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken}) - parser.nextToken() - } - - if parser.currentToken.Type != token.COMMA { - break - } - // Ignore token.COMMA - parser.nextToken() + command, err := parser.parseSelectSpace(selectCommand) + if err != nil { + return command, err } } @@ -331,6 +296,54 @@ func (parser *Parser) parseSelectCommand() (ast.Command, error) { return selectCommand, nil } +func (parser *Parser) parseSelectSpace(selectCommand *ast.SelectCommand) (ast.Command, error) { + for parser.currentToken.Type == token.IDENT || isAggregateFunction(parser.currentToken.Type) { + if parser.currentToken.Type != token.IDENT { + err := parser.parseAggregateFunction(selectCommand) + if err != nil { + return nil, err + } + } else { + // Get column name + err := validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken}) + parser.nextToken() + } + + if parser.currentToken.Type != token.COMMA { + break + } + // Ignore token.COMMA + parser.nextToken() + } + return nil, nil +} + +func (parser *Parser) parseAggregateFunction(selectCommand *ast.SelectCommand) error { + aggregateFunction := parser.currentToken + parser.nextToken() + err := validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + if err != nil { + return err + } + if aggregateFunction.Type == token.COUNT { + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.ASTERISK}) + } else { + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + } + if err != nil { + return err + } + selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken, AggregateFunc: &aggregateFunction}) + parser.nextToken() + + err = validateTokenAndSkip(parser, []token.Type{token.RPAREN}) + return err +} + func (parser *Parser) getColumnName(err error, selectCommand *ast.SelectCommand, aggregateFunction token.Token) error { // Get column name err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.ASTERISK}) @@ -353,17 +366,16 @@ func isAggregateFunction(t token.Type) bool { func (parser *Parser) parseWhereCommand() (ast.Command, error) { // token.WHERE already at current position in parser whereCommand := &ast.WhereCommand{Token: parser.currentToken} - expressionIsValid := false // Ignore token.WHERE parser.nextToken() var err error - expressionIsValid, whereCommand.Expression, err = parser.getExpression() + whereCommand.Expression, err = parser.getExpression() if err != nil { return nil, err } - if !expressionIsValid { + if whereCommand.Expression == nil { return nil, &LogicalExpressionParsingError{} } @@ -601,13 +613,12 @@ func (parser *Parser) parseJoinCommand() (ast.Command, error) { return nil, err } - var expressionIsValid bool - expressionIsValid, joinCommand.Expression, err = parser.getExpression() + joinCommand.Expression, err = parser.getExpression() if err != nil { return nil, err } - if !expressionIsValid { + if joinCommand.Expression == nil { return nil, &LogicalExpressionParsingError{} } @@ -706,7 +717,7 @@ func (parser *Parser) parseUpdateCommand() (ast.Command, error) { // - ast.BooleanExpression // - ast.ConditionExpression // - ast.ContainExpression -func (parser *Parser) getExpression() (bool, ast.Expression, error) { +func (parser *Parser) getExpression() (ast.Expression, error) { if parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || @@ -717,39 +728,39 @@ func (parser *Parser) getExpression() (bool, ast.Expression, error) { leftSide, isAnonymitifier, err := parser.getExpressionLeftSideValue() if err != nil { - return false, nil, err + return nil, err } - isValidExpression := false - var expression ast.Expression - - if parser.currentToken.Type == token.EQUAL || parser.currentToken.Type == token.NOT { - isValidExpression, expression, err = parser.getConditionalExpression(leftSide, isAnonymitifier) - } else if parser.currentToken.Type == token.IN || parser.currentToken.Type == token.NOTIN { - isValidExpression, expression, err = parser.getContainExpression(leftSide, isAnonymitifier) - } else if leftSide.Type == token.TRUE || leftSide.Type == token.FALSE { - expression = &ast.BooleanExpression{Boolean: leftSide} - isValidExpression = true - err = nil - } + expression, err := parser.getExpressionLeaf(leftSide, isAnonymitifier) if err != nil { - return false, nil, err + return nil, err } - if (parser.currentToken.Type == token.AND || parser.currentToken.Type == token.OR) && isValidExpression { - isValidExpression, expression, err = parser.getOperationExpression(expression) + if (parser.currentToken.Type == token.AND || parser.currentToken.Type == token.OR) && expression != nil { + expression, err = parser.getOperationExpression(expression) } if err != nil { - return false, nil, err + return nil, err } - if isValidExpression { - return true, expression, nil + if expression != nil { + return expression, nil } } - return false, nil, nil + return nil, nil +} + +func (parser *Parser) getExpressionLeaf(leftSide token.Token, isAnonymitifier bool) (ast.Expression, error) { + if parser.currentToken.Type == token.EQUAL || parser.currentToken.Type == token.NOT { + return parser.getConditionalExpression(leftSide, isAnonymitifier) + } else if parser.currentToken.Type == token.IN || parser.currentToken.Type == token.NOTIN { + return parser.getContainExpression(leftSide, isAnonymitifier) + } else if leftSide.Type == token.TRUE || leftSide.Type == token.FALSE { + return &ast.BooleanExpression{Boolean: leftSide}, nil + } + return nil, nil } func (parser *Parser) getExpressionLeftSideValue() (token.Token, bool, error) { @@ -782,30 +793,30 @@ func (parser *Parser) getExpressionLeftSideValue() (token.Token, bool, error) { } // getOperationExpression - Return ast.OperationExpression created from tokens and validate the syntax -func (parser *Parser) getOperationExpression(expression ast.Expression) (bool, *ast.OperationExpression, error) { +func (parser *Parser) getOperationExpression(expression ast.Expression) (*ast.OperationExpression, error) { operationExpression := &ast.OperationExpression{} operationExpression.Left = expression operationExpression.Operation = parser.currentToken parser.nextToken() - expressionIsValid, expression, err := parser.getExpression() + expression, err := parser.getExpression() if err != nil { - return false, nil, err + return nil, err } - if !expressionIsValid { - return false, nil, &LogicalExpressionParsingError{afterToken: &operationExpression.Operation.Literal} + if expression == nil { + return nil, &LogicalExpressionParsingError{afterToken: &operationExpression.Operation.Literal} } operationExpression.Right = expression - return true, operationExpression, nil + return operationExpression, nil } // getConditionalExpression - Return ast.ConditionExpression created from tokens and validate the syntax -func (parser *Parser) getConditionalExpression(leftSide token.Token, isAnonymitifier bool) (bool, *ast.ConditionExpression, error) { +func (parser *Parser) getConditionalExpression(leftSide token.Token, isAnonymitifier bool) (*ast.ConditionExpression, error) { conditionalExpression := &ast.ConditionExpression{Condition: parser.currentToken} if isAnonymitifier { @@ -831,21 +842,21 @@ func (parser *Parser) getConditionalExpression(leftSide token.Token, isAnonymiti finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() err := validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, conditionalExpression.Right.GetToken()) if err != nil { - return false, nil, err + return nil, err } } else { - return false, conditionalExpression, &SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL, token.NULL}, got: parser.currentToken.Literal} + return nil, &SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL, token.NULL}, got: parser.currentToken.Literal} } - return true, conditionalExpression, nil + return conditionalExpression, nil } // getContainExpression - Return ast.ContainExpression created from tokens and validate the syntax -func (parser *Parser) getContainExpression(leftSide token.Token, isAnonymitifier bool) (bool, *ast.ContainExpression, error) { +func (parser *Parser) getContainExpression(leftSide token.Token, isAnonymitifier bool) (*ast.ContainExpression, error) { containExpression := &ast.ContainExpression{} if isAnonymitifier { - return false, nil, &SyntaxError{expecting: []string{token.IDENT}, got: "'" + leftSide.Literal + "'"} + return nil, &SyntaxError{expecting: []string{token.IDENT}, got: "'" + leftSide.Literal + "'"} } containExpression.Left = ast.Identifier{Token: leftSide} @@ -861,7 +872,7 @@ func (parser *Parser) getContainExpression(leftSide token.Token, isAnonymitifier err := validateTokenAndSkip(parser, []token.Type{token.LPAREN}) if err != nil { - return false, nil, err + return nil, err } for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { @@ -869,7 +880,7 @@ func (parser *Parser) getContainExpression(leftSide token.Token, isAnonymitifier err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL, token.NULL}) if err != nil { - return false, nil, err + return nil, err } currentAnonymitifier := ast.Anonymitifier{Token: parser.currentToken} containExpression.Right = append(containExpression.Right, currentAnonymitifier) @@ -880,12 +891,12 @@ func (parser *Parser) getContainExpression(leftSide token.Token, isAnonymitifier err = validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, currentAnonymitifier.GetToken()) if err != nil { - return false, nil, err + return nil, err } if parser.currentToken.Type != token.COMMA { if parser.currentToken.Type != token.RPAREN { - return false, nil, &SyntaxError{expecting: []string{token.COMMA, token.RPAREN}, got: string(parser.currentToken.Type)} + return nil, &SyntaxError{expecting: []string{token.COMMA, token.RPAREN}, got: string(parser.currentToken.Type)} } break } @@ -896,10 +907,10 @@ func (parser *Parser) getContainExpression(leftSide token.Token, isAnonymitifier err = validateTokenAndSkip(parser, []token.Type{token.RPAREN}) if err != nil { - return false, nil, err + return nil, err } - return true, containExpression, err + return containExpression, err } // ParseSequence - Return ast.Sequence (sequence of commands) created from client input after tokenization @@ -926,90 +937,15 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { case token.DROP: command, err = parser.parseDropCommand() case token.WHERE: - lastCommand, parserError := parser.getLastCommand(sequence, token.WHERE) - if parserError != nil { - return nil, parserError - } - - if lastCommand.TokenLiteral() == token.SELECT { - newCommand, err := parser.parseWhereCommand() - if err != nil { - return nil, err - } - lastCommand.(*ast.SelectCommand).WhereCommand = newCommand.(*ast.WhereCommand) - } else if lastCommand.TokenLiteral() == token.DELETE { - newCommand, err := parser.parseWhereCommand() - if err != nil { - return nil, err - } - lastCommand.(*ast.DeleteCommand).WhereCommand = newCommand.(*ast.WhereCommand) - } else if lastCommand.TokenLiteral() == token.UPDATE { - newCommand, err := parser.parseWhereCommand() - if err != nil { - return nil, err - } - lastCommand.(*ast.UpdateCommand).WhereCommand = newCommand.(*ast.WhereCommand) - } else { - return nil, &SyntaxCommandExpectedError{command: "WHERE", neededCommands: []string{"SELECT", "DELETE", "UPDATE"}} - } + err = parser.updateLastCommandWithWhereConstraints(sequence) case token.ORDER: - lastCommand, parserError := parser.getLastCommand(sequence, token.ORDER) - if parserError != nil { - return nil, parserError - } - - if lastCommand.TokenLiteral() != token.SELECT { - return nil, &SyntaxCommandExpectedError{command: "ORDER BY", neededCommands: []string{"SELECT"}} - } - - selectCommand := lastCommand.(*ast.SelectCommand) - newCommand, err := parser.parseOrderByCommand() - if err != nil { - return nil, err - } - selectCommand.OrderByCommand = newCommand.(*ast.OrderByCommand) + err = parser.updateSelectCommandWithOrderByConstraints(sequence) case token.LIMIT: - lastCommand, parserError := parser.getLastCommand(sequence, token.LIMIT) - if parserError != nil { - return nil, parserError - } - if lastCommand.TokenLiteral() != token.SELECT { - return nil, &SyntaxCommandExpectedError{command: "LIMIT", neededCommands: []string{"SELECT"}} - } - selectCommand := lastCommand.(*ast.SelectCommand) - newCommand, err := parser.parseLimitCommand() - if err != nil { - return nil, err - } - selectCommand.LimitCommand = newCommand.(*ast.LimitCommand) + err = parser.updateSelectCommandWithLimitConstraints(sequence) case token.OFFSET: - lastCommand, parserError := parser.getLastCommand(sequence, token.OFFSET) - if parserError != nil { - return nil, parserError - } - if lastCommand.TokenLiteral() != token.SELECT { - return nil, &SyntaxCommandExpectedError{command: "OFFSET", neededCommands: []string{"SELECT"}} - } - selectCommand := lastCommand.(*ast.SelectCommand) - newCommand, err := parser.parseOffsetCommand() - if err != nil { - return nil, err - } - selectCommand.OffsetCommand = newCommand.(*ast.OffsetCommand) + err = parser.updateSelectCommandWithOffsetConstraints(sequence) case token.JOIN, token.LEFT, token.RIGHT, token.INNER, token.FULL: - lastCommand, parserError := parser.getLastCommand(sequence, token.JOIN) - if parserError != nil { - return nil, parserError - } - if lastCommand.TokenLiteral() != token.SELECT { - return nil, &SyntaxCommandExpectedError{command: "JOIN", neededCommands: []string{"SELECT"}} - } - selectCommand := lastCommand.(*ast.SelectCommand) - newCommand, err := parser.parseJoinCommand() - if err != nil { - return nil, err - } - selectCommand.JoinCommand = newCommand.(*ast.JoinCommand) + err = parser.updateSelectCommandWithJoinConstraints(sequence) default: return nil, &SyntaxInvalidCommandError{invalidCommand: parser.currentToken.Literal} } @@ -1023,10 +959,109 @@ func (parser *Parser) ParseSequence() (*ast.Sequence, error) { sequence.Commands = append(sequence.Commands, command) } } - return sequence, nil } +func (parser *Parser) updateSelectCommandWithJoinConstraints(sequence *ast.Sequence) error { + lastCommand, parserError := parser.getLastCommand(sequence, token.JOIN) + if parserError != nil { + return parserError + } + if lastCommand.TokenLiteral() != token.SELECT { + return &SyntaxCommandExpectedError{command: "JOIN", neededCommands: []string{"SELECT"}} + } + selectCommand := lastCommand.(*ast.SelectCommand) + newCommand, err := parser.parseJoinCommand() + if err != nil { + return err + } + selectCommand.JoinCommand = newCommand.(*ast.JoinCommand) + return nil +} + +func (parser *Parser) updateSelectCommandWithOffsetConstraints(sequence *ast.Sequence) error { + lastCommand, parserError := parser.getLastCommand(sequence, token.OFFSET) + if parserError != nil { + return parserError + } + if lastCommand.TokenLiteral() != token.SELECT { + return &SyntaxCommandExpectedError{command: "OFFSET", neededCommands: []string{"SELECT"}} + } + selectCommand := lastCommand.(*ast.SelectCommand) + newCommand, err := parser.parseOffsetCommand() + if err != nil { + return err + } + selectCommand.OffsetCommand = newCommand.(*ast.OffsetCommand) + return nil +} + +func (parser *Parser) updateSelectCommandWithLimitConstraints(sequence *ast.Sequence) error { + lastCommand, parserError := parser.getLastCommand(sequence, token.LIMIT) + if parserError != nil { + return parserError + } + if lastCommand.TokenLiteral() != token.SELECT { + return &SyntaxCommandExpectedError{command: "LIMIT", neededCommands: []string{"SELECT"}} + } + selectCommand := lastCommand.(*ast.SelectCommand) + newCommand, err := parser.parseLimitCommand() + if err != nil { + return err + } + selectCommand.LimitCommand = newCommand.(*ast.LimitCommand) + return nil +} + +func (parser *Parser) updateSelectCommandWithOrderByConstraints(sequence *ast.Sequence) error { + lastCommand, parserError := parser.getLastCommand(sequence, token.ORDER) + if parserError != nil { + return parserError + } + + if lastCommand.TokenLiteral() != token.SELECT { + return &SyntaxCommandExpectedError{command: "ORDER BY", neededCommands: []string{"SELECT"}} + } + + selectCommand := lastCommand.(*ast.SelectCommand) + newCommand, err := parser.parseOrderByCommand() + if err != nil { + return err + } + selectCommand.OrderByCommand = newCommand.(*ast.OrderByCommand) + return nil +} + +func (parser *Parser) updateLastCommandWithWhereConstraints(sequence *ast.Sequence) error { + lastCommand, parserError := parser.getLastCommand(sequence, token.WHERE) + if parserError != nil { + return parserError + } + + if lastCommand.TokenLiteral() == token.SELECT { + newCommand, err := parser.parseWhereCommand() + if err != nil { + return err + } + lastCommand.(*ast.SelectCommand).WhereCommand = newCommand.(*ast.WhereCommand) + } else if lastCommand.TokenLiteral() == token.DELETE { + newCommand, err := parser.parseWhereCommand() + if err != nil { + return err + } + lastCommand.(*ast.DeleteCommand).WhereCommand = newCommand.(*ast.WhereCommand) + } else if lastCommand.TokenLiteral() == token.UPDATE { + newCommand, err := parser.parseWhereCommand() + if err != nil { + return err + } + lastCommand.(*ast.UpdateCommand).WhereCommand = newCommand.(*ast.WhereCommand) + } else { + return &SyntaxCommandExpectedError{command: "WHERE", neededCommands: []string{"SELECT", "DELETE", "UPDATE"}} + } + return nil +} + func (parser *Parser) getLastCommand(sequence *ast.Sequence, currentToken string) (ast.Command, error) { if len(sequence.Commands) == 0 { return nil, &NoPredecessorParserError{command: currentToken} diff --git a/parser/parser_test.go b/parser/parser_test.go index 0823da4..3a00ad8 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -908,7 +908,7 @@ func testSelectStatement(t *testing.T, command ast.Command, expectedTableName st } if !spaceArrayEquals(actualSelectCommand.Space, expectedSpaces) { - t.Errorf("actualSelectCommand has diffrent space than expected. %+v != %+v", actualSelectCommand.Space, expectedSpaces) + t.Errorf("actualSelectCommand has different space than expected. %+v != %+v", actualSelectCommand.Space, expectedSpaces) return false } diff --git a/token/token.go b/token/token.go index 381a168..6842c54 100644 --- a/token/token.go +++ b/token/token.go @@ -9,76 +9,79 @@ type Token struct { } const ( - // ASTERISK - Operators + // Operators ASTERISK = "*" - // IDENT - Identifiers + literals - IDENT = "IDENT" // tab, car, apple... - LITERAL = "LITERAL" // 1343456 + // Identifiers & Literals + IDENT = "IDENT" // e.g., table, column names + LITERAL = "LITERAL" // e.g., numeric or string literals - // COMMA - Delimiters - COMMA = "," - SEMICOLON = ";" - - // EOF - Special tokens - EOF = "" + // Delimiters + COMMA = "," + SEMICOLON = ";" APOSTROPHE = "'" - // LPAREN - Paren + // Parentheses LPAREN = "(" RPAREN = ")" - // CREATE - Keywords - CREATE = "CREATE" - DROP = "DROP" - TABLE = "TABLE" - INSERT = "INSERT" - INTO = "INTO" - VALUES = "VALUES" - SELECT = "SELECT" + // Special Tokens + EOF = "" + ILLEGAL = "ILLEGAL" + + // Commands + CREATE = "CREATE" + DROP = "DROP" + TABLE = "TABLE" + INSERT = "INSERT" + INTO = "INTO" + VALUES = "VALUES" + SELECT = "SELECT" + DELETE = "DELETE" + UPDATE = "UPDATE" + + // Clauses FROM = "FROM" WHERE = "WHERE" - DELETE = "DELETE" ORDER = "ORDER" BY = "BY" ASC = "ASC" DESC = "DESC" LIMIT = "LIMIT" OFFSET = "OFFSET" - UPDATE = "UPDATE" SET = "SET" DISTINCT = "DISTINCT" - JOIN = "JOIN" - INNER = "INNER" - FULL = "FULL" - LEFT = "LEFT" - RIGHT = "RIGHT" - ON = "ON" - MIN = "MIN" - MAX = "MAX" - COUNT = "COUNT" - SUM = "SUM" - AVG = "AVG" - IN = "IN" - NOTIN = "NOTIN" - NULL = "NULL" + TO = "TO" - TO = "TO" + // Joins + JOIN = "JOIN" + INNER = "INNER" + FULL = "FULL" + LEFT = "LEFT" + RIGHT = "RIGHT" + ON = "ON" - // EQUAL - Logical operations + // Aggregates + MIN = "MIN" + MAX = "MAX" + COUNT = "COUNT" + SUM = "SUM" + AVG = "AVG" + + // Logical EQUAL = "EQUAL" NOT = "NOT" AND = "AND" OR = "OR" TRUE = "TRUE" FALSE = "FALSE" + IN = "IN" + NOTIN = "NOTIN" + NULL = "NULL" - // TEXT - Data types + // Data Types TEXT = "TEXT" INT = "INT" - - // ILLEGAL - System - ILLEGAL = "ILLEGAL" ) var keywords = map[string]Type{