From a03ef26e7b54c53afb50275a7ecba83ac9648fac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Krupski?= <34219324+ixior462@users.noreply.github.com> Date: Fri, 17 Jan 2025 01:16:29 +0100 Subject: [PATCH] =?UTF-8?q?=20=F0=9F=8E=89=20Release=201.0.0=20?= =?UTF-8?q?=F0=9F=9A=80=20(#30)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * e2e version matrix update * Refactor that makes code shiny * Fix parser test pointer * Final refactor - improve select structure * Fix delete command * Add Drop Function * Add error handling for parser * second part of error handling * Adds LIMIT and OFFSET keywords :D * Add Update command * Feature/error handling tests (#20) * Parser error handling wip * Add error handling for engine and finish parser --------- Co-authored-by: LissaGreense * Add Distinct select implementation (#21) * Feature/joins (#22) JOIN FEATURE * Add full join implementation * Add full join error handling tests in engine * Add documentation, lexer, parser and only tests for engine, handling full,inner,left,right joins * Add enginge implemenatation and e2e tests for full,inner,left and right joins * Refactore join method * Aggregate functions (#23) * Add Aggregate function handling in lexer, parser and even write documentation * Aggregate functions engine WIP * Finish tests, fix implementation of aggr functions --------- Co-authored-by: ixior462 * Feature - "IN" and "NOTIN" condition (#24) * Parser, lexer, token and ast impl * engine, parser err handling, readme, tests for IN * Feature/null insertion (#25) * Tokens have been added * Parser and lexer logic has been updated. * Introduced a new section in README 'supported types' * Add engine integration with NULL values --------- Co-authored-by: LissaGreense * Add apostrophe validation errors (#26) * Add gopher to README.md (#28) * Refactore e2e tests (#27) * Refactore e2e tests structure * Update README in E2E section * Move e2e test to seperate file * Feature - apostrophe error validate (#29) * Add apostrophe validation errors * refactor * Rewrite getExpression logic --------- Co-authored-by: LissaGreense Co-authored-by: Sara RyfczyƄska --- .github/expected_results/end2end.txt | 37 - .github/workflows/docker-publish.yml | 6 +- .github/workflows/end2end-tests.yml | 10 +- .github/workflows/unit-tests.yml | 2 +- README.md | 250 ++++- ast/ast.go | 266 +++++- e2e/e2e_test.sh | 22 + .../1_select_with_where_expected_output | 42 + ...lect_with_limit_and_offset_expected_output | 20 + e2e/expected_outputs/3_delete_expected_output | 11 + .../4_orderby_expected_output | 11 + e2e/expected_outputs/5_update_expected_output | 12 + .../6_select_distinct_expected_output | 13 + .../7_drop_table_expected_output | 2 + .../8_select_with_join_expected_output | 30 + .../9_aggregate_functions_expected_output | 36 + e2e/test_files/1_select_with_where_test | 13 + .../2_select_with_limit_and_offset_test | 9 + e2e/test_files/3_delete_test | 9 + e2e/test_files/4_orderby_test | 8 + e2e/test_files/5_update_test | 9 + e2e/test_files/6_select_distinct_test | 9 + e2e/test_files/7_drop_table_test | 3 + e2e/test_files/8_select_with_join_test | 12 + e2e/test_files/9_aggregate_functions_test | 14 + engine/column.go | 20 +- engine/engine.go | 697 +++++++++++--- engine/engine_error_handling_test.go | 131 +++ engine/engine_test.go | 817 ++++++++++++++-- engine/engine_utils.go | 15 +- engine/errors.go | 106 +++ engine/generic_value.go | 110 ++- engine/generic_value_test.go | 63 +- engine/row.go | 35 + engine/table.go | 73 +- go.mod | 2 +- lexer/lexer_test.go | 404 ++++++-- main.go | 97 +- modes/handler.go | 74 +- parser/errors.go | 127 +++ parser/parser.go | 890 ++++++++++++++---- parser/parser_error_handling_test.go | 267 ++++++ parser/parser_test.go | 800 +++++++++++++++- test_file | 11 - token/token.go | 111 ++- 45 files changed, 4899 insertions(+), 807 deletions(-) delete mode 100644 .github/expected_results/end2end.txt create mode 100644 e2e/e2e_test.sh create mode 100644 e2e/expected_outputs/1_select_with_where_expected_output create mode 100644 e2e/expected_outputs/2_select_with_limit_and_offset_expected_output create mode 100644 e2e/expected_outputs/3_delete_expected_output create mode 100644 e2e/expected_outputs/4_orderby_expected_output create mode 100644 e2e/expected_outputs/5_update_expected_output create mode 100644 e2e/expected_outputs/6_select_distinct_expected_output create mode 100644 e2e/expected_outputs/7_drop_table_expected_output create mode 100644 e2e/expected_outputs/8_select_with_join_expected_output create mode 100644 e2e/expected_outputs/9_aggregate_functions_expected_output create mode 100644 e2e/test_files/1_select_with_where_test create mode 100644 e2e/test_files/2_select_with_limit_and_offset_test create mode 100644 e2e/test_files/3_delete_test create mode 100644 e2e/test_files/4_orderby_test create mode 100644 e2e/test_files/5_update_test create mode 100644 e2e/test_files/6_select_distinct_test create mode 100644 e2e/test_files/7_drop_table_test create mode 100644 e2e/test_files/8_select_with_join_test create mode 100644 e2e/test_files/9_aggregate_functions_test create mode 100644 engine/engine_error_handling_test.go create mode 100644 engine/errors.go create mode 100644 engine/row.go create mode 100644 parser/errors.go create mode 100644 parser/parser_error_handling_test.go delete mode 100644 test_file diff --git a/.github/expected_results/end2end.txt b/.github/expected_results/end2end.txt deleted file mode 100644 index 5a89e9a..0000000 --- a/.github/expected_results/end2end.txt +++ /dev/null @@ -1,37 +0,0 @@ -Table 'tbl' has been created -Data Inserted -Data Inserted -Data Inserted -+----------+-----+-------+------+ -| one | two | three | four | -+----------+-----+-------+------+ -| 'byebye' | 3 | 33 | 'e' | -+----------+-----+-------+------+ -+-----------+-------+ -| one | three | -+-----------+-------+ -| 'hello' | 11 | -| 'goodbye' | 22 | -+-----------+-------+ -+----------+-----+-------+------+ -| one | two | three | four | -+----------+-----+-------+------+ -| 'byebye' | 3 | 33 | 'e' | -+----------+-----+-------+------+ -+-----+-----+-------+------+ -| one | two | three | four | -+-----+-----+-------+------+ -+-----+-----+-------+------+ -Data from 'tbl' has been deleted -+-----------+-----+-------+------+ -| one | two | three | four | -+-----------+-----+-------+------+ -| 'hello' | 1 | 11 | 'q' | -| 'goodbye' | 1 | 22 | 'w' | -+-----------+-----+-------+------+ -+-----------+ -| one | -+-----------+ -| 'goodbye' | -| 'hello' | -+-----------+ diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index 2b66e31..caff5a1 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -17,19 +17,19 @@ jobs: packages: write id-token: write steps: - + - name: Checkout repository uses: actions/checkout@v3 - name: Docker build run: docker build . --tag ${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }} - + - name: Docker login run: docker login -u ${{ secrets.DOCKERHUB_USERNAME }} -p ${{ secrets.DOCKERHUB_TOKEN }} - name: Docker tag run: docker image tag ${{ env.IMAGE_NAME }}:${{ env.IMAGE_TAG }} ${{ env.REGISTRY }}/${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }} - + - name: Docker push run: docker push ${{ env.REGISTRY }}/${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }} diff --git a/.github/workflows/end2end-tests.yml b/.github/workflows/end2end-tests.yml index 11dc859..8f6f0c9 100644 --- a/.github/workflows/end2end-tests.yml +++ b/.github/workflows/end2end-tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.16.15', '1.17.11' ] + go: [ '1.21.13', '1.22.7', '1.23.1' ] steps: - uses: actions/checkout@v3 @@ -25,8 +25,8 @@ jobs: - name: Build run: go build -v - - name: Run - run: ./GO4SQL -file test_file > output.txt + - name: Make Test Script Executable + run: chmod +x e2e/e2e_test.sh - - name: Check Result - run: diff output.txt ./.github/expected_results/end2end.txt + - name: Run Tests + run: e2e/e2e_test.sh diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 3648159..835a56b 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.16.15', '1.17.11' ] + go: [ '1.21.13', '1.22.7', '1.23.1' ] steps: - uses: actions/checkout@v3 diff --git a/README.md b/README.md index db942e2..817250b 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +

+ gopher_GO4SQL +

+ # GO4SQL

@@ -22,19 +26,55 @@ You can compile the project with ``go build``, this will create ``GO4SQL`` binar Currently, there are 3 modes to chose from: -1. `File Mode` - You can specify file path with ``./GO4SQL -file file_path``, that will read the input - data directly into the program and print the result. - -2. `Stream Mode` - With ``./GO4SQL -stream`` you can run the program in stream mode, then you provide SQL commands +1. `File Mode` - You can specify file path with ``./GO4SQL -file file_path``, that will read the + input data directly into the program and print the result. In order to run one of e2e test files you can use: + ```shell + go build; ./GO4SQL -file e2e/test_files/1_select_with_where_test + ``` +2. `Stream Mode` - With ``./GO4SQL -stream`` you can run the program in stream mode, then you + provide SQL commands in your console (from standard input). -3. `Socket Mode` - To start Socket Server use `./GO4SQL -socket`, it will be listening on port `1433` by default. To - choose port different other than that, for example equal to `1444`, go with: `./GO4SQL -socket -port 1444` +3. `Socket Mode` - To start Socket Server use `./GO4SQL -socket`, it will be listening on port + `1433` by default. To + choose port different other than that, for example equal to `1444`, go with: + `./GO4SQL -socket -port 1444` + +## UNIT TESTS + +To run all the tests locally paste this in root directory: + +```shell +go clean -testcache; go test ./... +``` + +## E2E TESTS + +There are integrated with Github actions e2e tests that can be found in: `.github/workflows/end2end-tests.yml` file. +Tests run files inside `e2e/test_files` directory through `GO4SQL`, save stdout into files, and finally compare +then with expected outputs inside `e2e/expected_outputs` directory. + +To run e2e test locally, you can run script `./e2e/e2e_test.sh` if you're in the root directory. + +## Docker -### Docker 1. Pull docker image: `docker pull kajedot/go4sql:latest` -2. Run docker container in the interactive mode, remember to provide flag, for example: `docker run -i kajedot/go4sql -stream` -3. You can test this image with `test_file` provided in this repo: `docker run -i kajedot/go4sql -stream < test_file` +2. Run docker container in the interactive mode, remember to provide flag, for example: + `docker run -i kajedot/go4sql -stream` +3. You can test this image with `test_file` provided in this repo: + `docker run -i kajedot/go4sql -stream < test_file` + +## SUPPORTED TYPES + ++ **TEXT Type** - represents string values. Number or NULL can be converted to this type by wrapping + with apostrophes. Columns can store this type with **TEXT** keyword while using **CREATE** + command. ++ **NUMERIC Type** - represents integer values, columns can store this type with **INT** keyword + while using **CREATE** command. In general every digit-only value is interpreted as this type. ++ **NULL Type** - columns can't be assigned that type, but it can be used with **INSERT INTO**, + **UPDATE**, and inside **WHERE** statements, also it can be a product of **JOIN** commands + (besides **FULL JOIN**). In GO4SQL NULL is the smallest possible value, what means it can be + compared with other types with **EQUAL** and **NOT** statements. ## FUNCTIONALITY @@ -46,6 +86,14 @@ Currently, there are 3 modes to chose from: First column is called ``one`` and it contains strings (keyword ``TEXT``), second one is called ``two`` and it contains integers (keyword ``INT``). +* ***DROP TABLE*** - you can destroy the table of name ``table1`` using + command: + ```sql + DROP TABLE table1; + ``` + After using this command table1 will no longer be available and all data connected to it (column + definitions and inserted values) will be lost. + * ***INSERT INTO*** - you can insert values into table called ``table1`` with command: @@ -55,6 +103,14 @@ Currently, there are 3 modes to chose from: Please note that the number of arguments and types of the values must be the same as you declared with ``CREATE``. +* ***UPDATE*** - you can update values in table called ``table1`` with command: + ```sql + UPDATE table1 + SET column_name_1 TO new_value_1, column_name_2 TO new_value_2 + WHERE id EQUAL 1; + ``` + It will update all rows where column ``id`` is equal to ``1`` by replacing value in + ``column_name_1`` with ``new_value_1`` and ``column_name_2`` with ``new_value_2``. * ***SELECT FROM*** - you can either select everything from ``table1`` with: ```SELECT * FROM table1;``` @@ -75,6 +131,25 @@ Currently, there are 3 modes to chose from: ``` Supported logical operations are: ``EQUAL``, ``NOT``, ``OR``, ``AND``, ```FALSE```, ```TRUE```. +* ***IN*** - is used to check if a value from a column exists in a specified list of values. + It can be used with ``WHERE`` like this: + ```sql + SELECT column1, column2 + FROM table_name + WHERE column1 IN ('value1', 'value2'); + ``` + ``table_name`` is the name of the table, and ``WHERE`` returns rows that value is either equal to + ``value1`` or ``value2`` + +* ***NOTIN*** - is used to check if a value from a column doesn't exist in a specified list of + values. It can be used with ``WHERE`` like this: + ```sql + SELECT column1, column2 + FROM table_name + WHERE column1 NOTIN ('value1', 'value2'); + ``` + ``table_name`` is the name of the table, and ``WHERE`` returns rows which values are not equal to + ``value1`` and not equal to ``value2`` * ***DELETE FROM*** is used to delete existing records in a table. It can be used like this: ```sql @@ -94,19 +169,151 @@ Currently, there are 3 modes to chose from: In this case, this command will order by ``column1`` in ascending order, but if some rows have the same ``column1``, it orders them by column2 in descending order. -## UNIT TESTS +* ***LIMIT*** is used to reduce number of rows printed out by returning only specified number of + records with ``SELECT`` like this: + ```sql + SELECT column1, column2, + FROM table_name + ORDER BY column1 ASC + LIMIT 5; + ``` + In this case, this command will order by ``column1`` in ascending order and return 5 first + records. -To run all the tests locally run this in root directory: -```shell -go clean -testcache; go test ./... -``` +* ***OFFSET*** is used to reduce number of rows printed out by not skipping specified numbers of + rows in returned output with ``SELECT`` like this: + ```sql + SELECT column1, column2, + FROM table_name + ORDER BY column1 ASC + LIMIT 5 OFFSET 3; + ``` + In this case, this command will order by ``column1`` in ascending order and skip 3 first records, + then return records from 4th to 8th. -## E2E TEST +* ***DISTINCT*** is used to return only distinct (different) values in returned output with + ``SELECT`` like this: + ```sql + SELECT DISTINCT column1, column2, + FROM table_name; + ``` + In this case, this command will return only unique rows from ``table_name`` table. -In root directory there is **test_file** containing input commands for E2E tests. File -**.github/expected_results/end2end.txt** has expected results for it. -This is integrated into github workflows. +* ***INNER JOIN*** is used to return a new table by combining rows from both tables where there is a + match on the + specified condition. Only the rows that satisfy the condition from both tables are included in the + result. + Rows from either table that do not meet the condition are excluded from the result. + ```sql + SELECT * + FROM tableOne + JOIN tableTwo + ON tableOne.columnY EQUAL tableTwo.columnX; + ``` + or + ```sql + SELECT * + FROM tableOne + INNER JOIN tableTwo + ON tableOne.columnY EQUAL tableTwo.columnX; + ``` + In this case, this command will return all columns from tableOne and tableTwo for rows where the + condition + ``tableOne.columnY`` = ``tableTwo.columnX`` is met (i.e., the value of ``columnY`` in ``tableOne`` + is equal to the + value of ``columnX`` in ``tableTwo``). +* ***LEFT JOIN*** is used to return a new table that includes all records from the left table and + the matched records + from the right table. If there is no match, the result will contain empty values for columns from + the right table. + ```sql + SELECT * + FROM tableOne + LEFT JOIN tableTwo + ON tableOne.columnY EQUAL tableTwo.columnX; + ``` + In this case, this command will return all columns from ``tableOne`` and the matching columns from + ``tableTwo``. For + rows in + ``tableOne`` that do not have a corresponding match in ``tableTwo``, the result will include empty + values for columns + from + ``tableTwo``. +* ***RIGHT JOIN*** is used to return a new table that includes all records from the right table and + the matched records + from the left table. If there is no match, the result will contain empty values for columns from + the left table. + ```sql + SELECT * + FROM tableOne + RIGHT JOIN tableTwo + ON tableOne.columnY EQUAL tableTwo.columnX; + ``` + In this case, this command will return all columns from ``tableTwo`` and the matching columns from + ``tableOne``. For + rows in + ``tableTwo`` that do not have a corresponding match in ``tableOne``, the result will include empty + values for columns + from + ``tableOne``. + +* ***FULL JOIN*** is used to return a new table created by joining two tables as a whole. The + joined table contains all + records from both tables and fills empty values for missing matches on either side. This join + combines the results of + both ``LEFT JOIN`` and ``RIGHT JOIN``. + ```sql + SELECT * + FROM tableOne + FULL JOIN tableTwo + ON tableOne.columnY EQUAL tableTwo.columnX; + ``` + In this case, this command will return all columns from ``tableOne`` and ``tableTwo`` for rows + fulfilling condition + ``tableOne.columnY EQUAL tableTwo.columnX`` (value of ``columnY`` in ``tableOne`` is equal the + value of ``columnX`` in + ``tableTwo``). + +* ***MIN()*** is used to return the smallest value in a specified column. + ```sql + SELECT MIN(columnName) + FROM tableName; + ``` + In this case, this command will return the smallest value found in the column ``columnName`` of + ``tableName``. + +* ***MAX()*** is used to return the largest value in a specified column. + ```sql + SELECT MAX(columnName) + FROM tableName; + ``` + This command will return the largest value found in the column ``columnName`` of ``tableName``. + +* ***COUNT()*** is used to return the number of rows that match a given condition or the total + number of rows in a + specified column. + ```sql + SELECT COUNT(columnName) + FROM tableName; + ``` + This command will return the number of rows in the ``columnName`` of ``tableName``. + +* ***SUM()*** is used to return the total sum of the values in a specified numerical column. + ```sql + SELECT SUM(columnName) + FROM tableName; + ``` + This command will return the total sum of all values in the numerical column ``columnName`` of + ``tableName``. + +* ***AVG()*** is used to return the average of values in a specified numerical column. + ```sql + SELECT AVG(columnName) + FROM tableName; + ``` + This command will return the average of all values in the numerical column ``columnName`` of + ``tableName``. ## DOCKER @@ -117,13 +324,15 @@ docker build -t go4sql:test . ``` ### Run docker in interactive stream mode -To run this docker image in interactive stream mode mode use this command: + +To run this docker image in interactive stream mode use this command: ```shell docker run -i go4sql:test -stream ``` ### Run docker in socket mode + To run this docker image in socket mode use this command: ```shell @@ -131,6 +340,7 @@ docker run go4sql:test -socket ``` ### Run docker in file mode + **NOT RECOMMENDED** Alternatively you can run a docker image in file mode: @@ -144,12 +354,14 @@ docker run -i go4sql:test -file To create a pod deployment using helm chart, there is configuration under `./helm` directory. Commands: + ```shell cd ./helm helm install go4sql_pod_name GO4SQL/ ``` To check status of pod, use: + ```shell kubectl get pods ``` diff --git a/ast/ast.go b/ast/ast.go index 10244cf..5aa7657 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -29,12 +29,9 @@ type Command interface { // // Methods: // -// ExpressionNode: Abstraction needed for creating tree abstraction in order to optimise evaluating -// GetIdentifiers - Return array of pointers for all Identifiers within expression +// GetIdentifiers - Return array for all Identifiers within expression type Expression interface { - // ExpressionNode TODO: Check if ExpressionNode is needed - ExpressionNode() - GetIdentifiers() []*Identifier + GetIdentifiers() []Identifier } // Tifier - Interface that represent Token with string value @@ -73,7 +70,7 @@ type Anonymitifier struct { func (ls Anonymitifier) IsIdentifier() bool { return false } func (ls Anonymitifier) GetToken() token.Token { return ls.Token } -// BooleanExpression - Type of Expression that represent single boolean value +// BooleanExpression - TokenType of Expression that represent single boolean value // // Example: // TRUE @@ -81,13 +78,12 @@ type BooleanExpression struct { Boolean token.Token // example: token.TRUE } -func (ls BooleanExpression) ExpressionNode() {} -func (ls BooleanExpression) GetIdentifiers() []*Identifier { - var identifiers []*Identifier +func (ls BooleanExpression) GetIdentifiers() []Identifier { + var identifiers []Identifier return identifiers } -// ConditionExpression - Type of Expression that represent condition that is comparing value from column to static one +// ConditionExpression - TokenType of Expression that represent condition that is comparing value from column to static one // // Example: // column1 EQUAL 123 @@ -97,22 +93,35 @@ type ConditionExpression struct { Condition token.Token // example: token.EQUAL } -func (ls ConditionExpression) ExpressionNode() {} -func (ls ConditionExpression) GetIdentifiers() []*Identifier { - var identifiers []*Identifier +func (ls ConditionExpression) GetIdentifiers() []Identifier { + var identifiers []Identifier if ls.Left.IsIdentifier() { - identifiers = append(identifiers, &Identifier{ls.Left.GetToken()}) + identifiers = append(identifiers, Identifier{ls.Left.GetToken()}) } if ls.Right.IsIdentifier() { - identifiers = append(identifiers, &Identifier{ls.Right.GetToken()}) + identifiers = append(identifiers, Identifier{ls.Right.GetToken()}) } return identifiers } -// OperationExpression - Type of Expression that represent 2 other Expressions and conditional operation +// ContainExpression - TokenType of Expression that represents structure for IN operator +// +// Example: +// colName IN ('value1', 'value2', 'value3') +type ContainExpression struct { + Left Identifier // name of column + Right []Anonymitifier // name of column + Contains bool // IN or NOTIN +} + +func (ls ContainExpression) GetIdentifiers() []Identifier { + return []Identifier{ls.Left} +} + +// OperationExpression - TokenType of Expression that represent 2 other Expressions and conditional operation // // Example: // TRUE OR FALSE @@ -122,9 +131,8 @@ type OperationExpression struct { Operation token.Token // example: token.AND } -func (ls OperationExpression) ExpressionNode() {} -func (ls OperationExpression) GetIdentifiers() []*Identifier { - var identifiers []*Identifier +func (ls OperationExpression) GetIdentifiers() []Identifier { + var identifiers []Identifier identifiers = append(identifiers, ls.Left.GetIdentifiers()...) identifiers = append(identifiers, ls.Right.GetIdentifiers()...) @@ -138,7 +146,7 @@ func (ls OperationExpression) GetIdentifiers() []*Identifier { // CREATE TABLE table1( one TEXT , two INT); type CreateCommand struct { Token token.Token - Name *Identifier // name of the table + Name Identifier // name of the table ColumnNames []string ColumnTypes []token.Token } @@ -149,28 +157,166 @@ func (ls CreateCommand) TokenLiteral() string { return ls.Token.Literal } // InsertCommand - Part of Command that represent insertion of values into columns // // Example: -// INSERT INTO table1 VALUES( 'hello', 1); +// INSERT INTO table1 VALUES('hello', 1); type InsertCommand struct { Token token.Token - Name *Identifier // name of the table + Name Identifier // name of the table Values []token.Token } func (ls InsertCommand) CommandNode() {} func (ls InsertCommand) TokenLiteral() string { return ls.Token.Literal } +// Space - part of SelectCommand which is containing either * or a column name with an optional function aggregating it +type Space struct { + ColumnName token.Token + AggregateFunc *token.Token +} + +func (space Space) String() string { + columnName := "ColumnName={Type: " + string(space.ColumnName.Type) + ", Literal: " + space.ColumnName.Literal + "}" + if space.ContainsAggregateFunc() { + aggFunc := "AggregateFunc={Type: " + string(space.AggregateFunc.Type) + ", Literal: " + space.AggregateFunc.Literal + "}" + return columnName + ", " + aggFunc + } + return columnName +} + +// ContainsAggregateFunc - return true if space contains AggregateFunc that aggregate columnName or * +func (space Space) ContainsAggregateFunc() bool { + return space.AggregateFunc != nil +} + // SelectCommand - Part of Command that represent selecting values from tables // // Example: // SELECT one, two FROM table1; type SelectCommand struct { - Token token.Token - Name *Identifier - Space []token.Token // ex. column names + Token token.Token + Name Identifier // ex. name of table + Space []Space // ex. column names + HasDistinct bool // DISTINCT keyword has been used + WhereCommand *WhereCommand // optional + OrderByCommand *OrderByCommand // optional + LimitCommand *LimitCommand // optional + OffsetCommand *OffsetCommand // optional + JoinCommand *JoinCommand // optional } func (ls SelectCommand) CommandNode() {} func (ls SelectCommand) TokenLiteral() string { return ls.Token.Literal } +func (ls *SelectCommand) AggregateFunctionAppears() bool { + for _, space := range ls.Space { + if space.ContainsAggregateFunc() { + return true + } + } + return false +} + +// HasWhereCommand - returns true if optional HasWhereCommand is present in SelectCommand +// +// Example: +// SELECT * FROM table WHERE column1 NOT 'hi'; +// Returns true +// +// SELECT * FROM table; +// Returns false +func (ls SelectCommand) HasWhereCommand() bool { + if ls.WhereCommand == nil { + return false + } + return true +} + +// HasOrderByCommand - returns true if optional OrderByCommand is present in SelectCommand +// +// Example: +// SELECT * FROM table ORDER BY column1 ASC; +// Returns true +// +// SELECT * FROM table; +// Returns false +func (ls SelectCommand) HasOrderByCommand() bool { + if ls.OrderByCommand == nil { + return false + } + return true +} + +// HasLimitCommand - returns true if optional LimitCommand is present in SelectCommand +// +// Example: +// SELECT * FROM table LIMIT 5; +// Returns true +// +// SELECT * FROM table; +// Returns false +func (ls SelectCommand) HasLimitCommand() bool { + if ls.LimitCommand == nil { + return false + } + return true +} + +// HasOffsetCommand - returns true if optional OffsetCommand is present in SelectCommand +// +// Example: +// SELECT * FROM table OFFSET 100; +// Returns true +// +// SELECT * FROM table LIMIT 10; +// Returns false +func (ls SelectCommand) HasOffsetCommand() bool { + if ls.OffsetCommand == nil { + return false + } + return true +} + +// HasJoinCommand - returns true if optional JoinCommand is present in SelectCommand +// +// Example: +// SELECT * FROM table JOIN table2 ON table.one EQUAL table2.two; +// Returns true +// +// SELECT * FROM table; +// Returns false +func (ls SelectCommand) HasJoinCommand() bool { + if ls.JoinCommand == nil { + return false + } + return true +} + +// UpdateCommand - Part of Command that allow to change existing data +// +// Example: +// UPDATE table SET col1 TO 2 WHERE column1 NOT 'hi'; +type UpdateCommand struct { + Token token.Token + Name Identifier // ex. name of table + Changes map[token.Token]Anonymitifier // column names with new values + WhereCommand *WhereCommand // optional +} + +func (ls UpdateCommand) CommandNode() {} +func (ls UpdateCommand) TokenLiteral() string { return ls.Token.Literal } + +// HasWhereCommand - returns true if optional HasWhereCommand is present in UpdateCommand +// +// Example: +// UPDATE table SET col1 TO 2 WHERE column1 NOT 'hi'; +// Returns true +// +// UPDATE table SET col1 TO 2; +// Returns false +func (ls UpdateCommand) HasWhereCommand() bool { + if ls.WhereCommand == nil { + return false + } + return true +} // WhereCommand - Part of Command that represent Where statement with expression that will qualify values from Select // @@ -184,18 +330,66 @@ type WhereCommand struct { func (ls WhereCommand) CommandNode() {} func (ls WhereCommand) TokenLiteral() string { return ls.Token.Literal } +// JoinCommand - Part of Command that represent JOIN statement with expression that will merge tables +// +// Example: +// JOIN tbl2 ON tbl1.id EQUAL tbl2.f_idy; +type JoinCommand struct { + Token token.Token + Name Identifier // ex. name of table + JoinType token.Token + Expression Expression +} + +func (ls JoinCommand) CommandNode() {} +func (ls JoinCommand) TokenLiteral() string { return ls.Token.Literal } +func (ls JoinCommand) ShouldTakeLeftSide() bool { + return ls.JoinType.Type == token.LEFT || ls.JoinType.Type == token.FULL +} +func (ls JoinCommand) ShouldTakeRightSide() bool { + return ls.JoinType.Type == token.RIGHT || ls.JoinType.Type == token.FULL +} + // DeleteCommand - Part of Command that represent deleting row from table // // Example: // DELETE FROM tb1 WHERE two EQUAL 3; type DeleteCommand struct { - Token token.Token - Name *Identifier // name of the table + Token token.Token + Name Identifier // name of the table + WhereCommand *WhereCommand // optional } func (ls DeleteCommand) CommandNode() {} func (ls DeleteCommand) TokenLiteral() string { return ls.Token.Literal } +// DropCommand - Part of Command that represent dropping table +// +// Example: +// DROP TABLE table; +type DropCommand struct { + Token token.Token + Name Identifier // name of the table +} + +func (ls DropCommand) CommandNode() {} +func (ls DropCommand) TokenLiteral() string { return ls.Token.Literal } + +// HasWhereCommand - returns true if optional HasWhereCommand is present in SelectCommand +// +// Example: +// SELECT * FROM table WHERE column1 NOT 'hi'; +// Returns true +// +// SELECT * FROM table; +// Returns false +func (ls DeleteCommand) HasWhereCommand() bool { + if ls.WhereCommand == nil { + return false + } + return true +} + // OrderByCommand - Part of Command that ordering columns from SelectCommand // // Example: @@ -213,3 +407,21 @@ type SortPattern struct { ColumnName token.Token // column name Order token.Token // ASC or DESC } + +// LimitCommand - Part of Command that limits results from SelectCommand +type LimitCommand struct { + Token token.Token + Count int +} + +func (ls LimitCommand) CommandNode() {} +func (ls LimitCommand) TokenLiteral() string { return ls.Token.Literal } + +// OffsetCommand - Part of Command that skip begging rows from SelectCommand +type OffsetCommand struct { + Token token.Token + Count int +} + +func (ls OffsetCommand) CommandNode() {} +func (ls OffsetCommand) TokenLiteral() string { return ls.Token.Literal } diff --git a/e2e/e2e_test.sh b/e2e/e2e_test.sh new file mode 100644 index 0000000..9dff5ef --- /dev/null +++ b/e2e/e2e_test.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +e2e_failed=false + +for test_file in e2e/test_files/*_test; do + output_file="./e2e/$(basename "${test_file/_test/_output}")" + ./GO4SQL -file "$test_file" > "$output_file" + expected_output="e2e/expected_outputs/$(basename "${test_file/_test/_expected_output}")" + diff "$output_file" "$expected_output" + if [ $? -ne 0 ]; then + echo "E2E test for: {$test_file} failed" + e2e_failed=true + fi + rm "./$output_file" +done + +if [ "$e2e_failed" = true ]; then + echo "E2E tests failed." + exit 1 +else + echo "All E2E tests passed." +fi diff --git a/e2e/expected_outputs/1_select_with_where_expected_output b/e2e/expected_outputs/1_select_with_where_expected_output new file mode 100644 index 0000000..3a89f6a --- /dev/null +++ b/e2e/expected_outputs/1_select_with_where_expected_output @@ -0,0 +1,42 @@ +Table 'tbl' has been created +Data Inserted +Data Inserted +Data Inserted ++----------+------+-------+------+ +| one | two | three | four | ++----------+------+-------+------+ +| 'byebye' | NULL | 33 | 'e' | ++----------+------+-------+------+ ++-----------+-------+ +| one | three | ++-----------+-------+ +| 'hello' | 11 | +| 'goodbye' | 22 | ++-----------+-------+ ++----------+------+-------+------+ +| one | two | three | four | ++----------+------+-------+------+ +| 'byebye' | NULL | 33 | 'e' | ++----------+------+-------+------+ ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'goodbye' | 1 | 22 | 'w' | +| 'byebye' | NULL | 33 | 'e' | ++-----------+------+-------+------+ ++---------+-----+-------+------+ +| one | two | three | four | ++---------+-----+-------+------+ +| 'hello' | 1 | 11 | 'q' | ++---------+-----+-------+------+ ++-----+-----+-------+------+ +| one | two | three | four | ++-----+-----+-------+------+ ++-----+-----+-------+------+ ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'hello' | 1 | 11 | 'q' | +| 'goodbye' | 1 | 22 | 'w' | +| 'byebye' | NULL | 33 | 'e' | ++-----------+------+-------+------+ diff --git a/e2e/expected_outputs/2_select_with_limit_and_offset_expected_output b/e2e/expected_outputs/2_select_with_limit_and_offset_expected_output new file mode 100644 index 0000000..cea7b84 --- /dev/null +++ b/e2e/expected_outputs/2_select_with_limit_and_offset_expected_output @@ -0,0 +1,20 @@ +Table 'tbl' has been created +Data Inserted +Data Inserted +Data Inserted ++---------+-----+-------+------+ +| one | two | three | four | ++---------+-----+-------+------+ +| 'hello' | 1 | 11 | 'q' | ++---------+-----+-------+------+ ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'goodbye' | 1 | 22 | 'w' | +| 'byebye' | NULL | 33 | 'e' | ++-----------+------+-------+------+ ++-----------+-----+-------+------+ +| one | two | three | four | ++-----------+-----+-------+------+ +| 'goodbye' | 1 | 22 | 'w' | ++-----------+-----+-------+------+ diff --git a/e2e/expected_outputs/3_delete_expected_output b/e2e/expected_outputs/3_delete_expected_output new file mode 100644 index 0000000..f0c6911 --- /dev/null +++ b/e2e/expected_outputs/3_delete_expected_output @@ -0,0 +1,11 @@ +Table 'tbl' has been created +Data Inserted +Data Inserted +Data Inserted +Data from 'tbl' has been deleted ++-----------+-----+-------+------+ +| one | two | three | four | ++-----------+-----+-------+------+ +| 'hello' | 1 | 11 | 'q' | +| 'goodbye' | 1 | 22 | 'w' | ++-----------+-----+-------+------+ diff --git a/e2e/expected_outputs/4_orderby_expected_output b/e2e/expected_outputs/4_orderby_expected_output new file mode 100644 index 0000000..b920afb --- /dev/null +++ b/e2e/expected_outputs/4_orderby_expected_output @@ -0,0 +1,11 @@ +Table 'tbl' has been created +Data Inserted +Data Inserted +Data Inserted ++-----------+ +| one | ++-----------+ +| 'byebye' | +| 'goodbye' | +| 'hello' | ++-----------+ diff --git a/e2e/expected_outputs/5_update_expected_output b/e2e/expected_outputs/5_update_expected_output new file mode 100644 index 0000000..6094e0a --- /dev/null +++ b/e2e/expected_outputs/5_update_expected_output @@ -0,0 +1,12 @@ +Table 'tbl' has been created +Data Inserted +Data Inserted +Data Inserted +Table: 'tbl' has been updated ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'hello' | 1 | 11 | 'q' | +| 'goodbye' | NULL | 22 | 'P' | +| 'byebye' | NULL | 33 | 'e' | ++-----------+------+-------+------+ diff --git a/e2e/expected_outputs/6_select_distinct_expected_output b/e2e/expected_outputs/6_select_distinct_expected_output new file mode 100644 index 0000000..7ed7fb7 --- /dev/null +++ b/e2e/expected_outputs/6_select_distinct_expected_output @@ -0,0 +1,13 @@ +Table 'tbl' has been created +Data Inserted +Data Inserted +Data Inserted +Data Inserted +Data Inserted ++-----------+------+-------+------+ +| one | two | three | four | ++-----------+------+-------+------+ +| 'hello' | 1 | 11 | 'q' | +| 'goodbye' | 1 | 22 | 'w' | +| 'byebye' | NULL | 33 | 'e' | ++-----------+------+-------+------+ diff --git a/e2e/expected_outputs/7_drop_table_expected_output b/e2e/expected_outputs/7_drop_table_expected_output new file mode 100644 index 0000000..b0852d5 --- /dev/null +++ b/e2e/expected_outputs/7_drop_table_expected_output @@ -0,0 +1,2 @@ +Table 'tbl' has been created +Table: 'tbl' has been dropped diff --git a/e2e/expected_outputs/8_select_with_join_expected_output b/e2e/expected_outputs/8_select_with_join_expected_output new file mode 100644 index 0000000..6e9c1df --- /dev/null +++ b/e2e/expected_outputs/8_select_with_join_expected_output @@ -0,0 +1,30 @@ +Table 'table1' has been created +Table 'table2' has been created +Data Inserted +Data Inserted +Data Inserted +Data Inserted ++--------------+--------------+ +| table1.value | table2.value | ++--------------+--------------+ +| 'Value1' | NULL | +| NULL | 'Value2' | +| NULL | 'Value3' | ++--------------+--------------+ ++--------------+--------------+ +| table1.value | table2.value | ++--------------+--------------+ +| NULL | 'Value2' | ++--------------+--------------+ ++--------------+--------------+ +| table1.value | table2.value | ++--------------+--------------+ +| 'Value1' | NULL | +| NULL | 'Value2' | ++--------------+--------------+ ++--------------+--------------+ +| table1.value | table2.value | ++--------------+--------------+ +| NULL | 'Value2' | +| NULL | 'Value3' | ++--------------+--------------+ diff --git a/e2e/expected_outputs/9_aggregate_functions_expected_output b/e2e/expected_outputs/9_aggregate_functions_expected_output new file mode 100644 index 0000000..4be149e --- /dev/null +++ b/e2e/expected_outputs/9_aggregate_functions_expected_output @@ -0,0 +1,36 @@ +Table 'table1' has been created +Table 'table2' has been created +Data Inserted +Data Inserted +Data Inserted +Data Inserted ++---------+------------+ +| MAX(id) | MAX(value) | ++---------+------------+ +| 2 | Value1 | ++---------+------------+ ++------------+---------+ +| MIN(value) | MIN(id) | ++------------+---------+ +| NULL | 1 | ++------------+---------+ ++----------+-----------+--------------+ +| COUNT(*) | COUNT(id) | COUNT(value) | ++----------+-----------+--------------+ +| 2 | 2 | 1 | ++----------+-----------+--------------+ ++---------+------------+ +| SUM(id) | SUM(value) | ++---------+------------+ +| 3 | 0 | ++---------+------------+ ++---------+------------+ +| AVG(id) | AVG(value) | ++---------+------------+ +| 1 | 0 | ++---------+------------+ ++---------+----+ +| AVG(id) | id | ++---------+----+ +| 1 | 1 | ++---------+----+ diff --git a/e2e/test_files/1_select_with_where_test b/e2e/test_files/1_select_with_where_test new file mode 100644 index 0000000..00156cb --- /dev/null +++ b/e2e/test_files/1_select_with_where_test @@ -0,0 +1,13 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +INSERT INTO tbl VALUES( 'hello',1, 11, 'q' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); + +SELECT * FROM tbl WHERE one EQUAL 'byebye'; +SELECT one, three FROM tbl WHERE two NOT NULL; +SELECT * FROM tbl WHERE one NOT 'goodbye' AND two EQUAL NULL; +SELECT * FROM tbl WHERE one IN ('goodbye', 'byebye'); +SELECT * FROM tbl WHERE one NOTIN ('goodbye', 'byebye'); +SELECT * FROM tbl WHERE FALSE; +SELECT * FROM tbl WHERE 'colName1 EQUAL;' EQUAL 'colName1 EQUAL;'; diff --git a/e2e/test_files/2_select_with_limit_and_offset_test b/e2e/test_files/2_select_with_limit_and_offset_test new file mode 100644 index 0000000..51902d1 --- /dev/null +++ b/e2e/test_files/2_select_with_limit_and_offset_test @@ -0,0 +1,9 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +INSERT INTO tbl VALUES( 'hello',1, 11, 'q' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); + +SELECT * FROM tbl LIMIT 1; +SELECT * FROM tbl OFFSET 1; +SELECT * FROM tbl LIMIT 1 OFFSET 1; diff --git a/e2e/test_files/3_delete_test b/e2e/test_files/3_delete_test new file mode 100644 index 0000000..008c84a --- /dev/null +++ b/e2e/test_files/3_delete_test @@ -0,0 +1,9 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +INSERT INTO tbl VALUES( 'hello',1, 11, 'q' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); + +DELETE FROM tbl WHERE one EQUAL 'byebye'; + +SELECT * FROM tbl; diff --git a/e2e/test_files/4_orderby_test b/e2e/test_files/4_orderby_test new file mode 100644 index 0000000..59c27ba --- /dev/null +++ b/e2e/test_files/4_orderby_test @@ -0,0 +1,8 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +INSERT INTO tbl VALUES( 'hello',1, 11, 'q' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); + +SELECT one FROM tbl WHERE TRUE ORDER BY two ASC, four DESC; + diff --git a/e2e/test_files/5_update_test b/e2e/test_files/5_update_test new file mode 100644 index 0000000..4325950 --- /dev/null +++ b/e2e/test_files/5_update_test @@ -0,0 +1,9 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +INSERT INTO tbl VALUES( 'hello',1, 11, 'q' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); + +UPDATE tbl SET two TO NULL, four TO 'P' WHERE one EQUAL 'goodbye'; + +SELECT * FROM tbl; diff --git a/e2e/test_files/6_select_distinct_test b/e2e/test_files/6_select_distinct_test new file mode 100644 index 0000000..0e8a880 --- /dev/null +++ b/e2e/test_files/6_select_distinct_test @@ -0,0 +1,9 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +INSERT INTO tbl VALUES( 'hello',1, 11, 'q' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); +INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); +INSERT INTO tbl VALUES( 'byebye', NULL, 33,'e' ); + +SELECT DISTINCT * FROM tbl; diff --git a/e2e/test_files/7_drop_table_test b/e2e/test_files/7_drop_table_test new file mode 100644 index 0000000..e9b24c3 --- /dev/null +++ b/e2e/test_files/7_drop_table_test @@ -0,0 +1,3 @@ +CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); + +DROP TABLE tbl; diff --git a/e2e/test_files/8_select_with_join_test b/e2e/test_files/8_select_with_join_test new file mode 100644 index 0000000..5362818 --- /dev/null +++ b/e2e/test_files/8_select_with_join_test @@ -0,0 +1,12 @@ +CREATE TABLE table1( id INT, value TEXT); +CREATE TABLE table2( id INT, value TEXT); + +INSERT INTO table1 VALUES(1, 'Value1'); +INSERT INTO table1 VALUES(2, NULL); +INSERT INTO table2 VALUES(2, 'Value2'); +INSERT INTO table2 VALUES(3, 'Value3'); + +SELECT table1.value, table2.value FROM table1 FULL JOIN table2 ON table1.id EQUAL table2.id; +SELECT table1.value, table2.value FROM table1 INNER JOIN table2 ON table1.id EQUAL table2.id; +SELECT table1.value, table2.value FROM table1 LEFT JOIN table2 ON table1.id EQUAL table2.id; +SELECT table1.value, table2.value FROM table1 RIGHT JOIN table2 ON table1.id EQUAL table2.id; diff --git a/e2e/test_files/9_aggregate_functions_test b/e2e/test_files/9_aggregate_functions_test new file mode 100644 index 0000000..07b1262 --- /dev/null +++ b/e2e/test_files/9_aggregate_functions_test @@ -0,0 +1,14 @@ +CREATE TABLE table1( id INT, value TEXT); +CREATE TABLE table2( id INT, value TEXT); + +INSERT INTO table1 VALUES(1, 'Value1'); +INSERT INTO table1 VALUES(2, NULL); +INSERT INTO table2 VALUES(2, 'Value2'); +INSERT INTO table2 VALUES(3, 'Value3'); + +SELECT MAX(id), MAX(value) FROM table1; +SELECT MIN(value), MIN(id) FROM table1; +SELECT COUNT(*), COUNT(id), COUNT(value) FROM table1; +SELECT SUM(id), SUM(value) FROM table1; +SELECT AVG(id), AVG(value) FROM table1; +SELECT AVG(id), id FROM table1; diff --git a/engine/column.go b/engine/column.go index 2837ee1..419b157 100644 --- a/engine/column.go +++ b/engine/column.go @@ -1,8 +1,6 @@ package engine import ( - "log" - "github.com/LissaGreense/GO4SQL/token" ) @@ -13,28 +11,32 @@ type Column struct { Values []ValueInterface } -func extractColumnContent(columns []*Column, wantedColumnNames []string) *Table { +func extractColumnContent(columns []*Column, wantedColumnNames *[]string, tableName string) (*Table, error) { selectedTable := &Table{Columns: make([]*Column, 0)} mappedIndexes := make([]int, 0) - for wantedColumnIndex := 0; wantedColumnIndex < len(wantedColumnNames); wantedColumnIndex++ { - for columnNameIndex := 0; columnNameIndex < len(columns); columnNameIndex++ { - if columns[columnNameIndex].Name == wantedColumnNames[wantedColumnIndex] { + for wantedColumnIndex := range *wantedColumnNames { + for columnNameIndex := range columns { + if columns[columnNameIndex].Name == (*wantedColumnNames)[wantedColumnIndex] { mappedIndexes = append(mappedIndexes, columnNameIndex) break } if columnNameIndex == len(columns)-1 { - log.Fatal("Provided column name: " + wantedColumnNames[wantedColumnIndex] + " doesn't exist") + return nil, &ColumnDoesNotExistError{columnName: (*wantedColumnNames)[wantedColumnIndex], tableName: tableName} } } } - for i := 0; i < len(mappedIndexes); i++ { + for i := range mappedIndexes { selectedTable.Columns = append(selectedTable.Columns, &Column{ Name: columns[mappedIndexes[i]].Name, Type: columns[mappedIndexes[i]].Type, Values: make([]ValueInterface, 0), }) } + if len(columns) == 0 { + return selectedTable, nil + } + rowsCount := len(columns[0].Values) for iRow := 0; iRow < rowsCount; iRow++ { @@ -43,5 +45,5 @@ func extractColumnContent(columns []*Column, wantedColumnNames []string) *Table append(selectedTable.Columns[iColumn].Values, columns[mappedIndexes[iColumn]].Values[iRow]) } } - return selectedTable + return selectedTable, nil } diff --git a/engine/engine.go b/engine/engine.go index c644f68..4ddf188 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -1,9 +1,8 @@ package engine import ( - "errors" "fmt" - "log" + "maps" "sort" "strconv" @@ -12,22 +11,149 @@ import ( ) type DbEngine struct { - Tables map[string]*Table + Tables Tables } +type Tables map[string]*Table // New Return new DbEngine struct func New() *DbEngine { engine := &DbEngine{} - engine.Tables = make(map[string]*Table) + engine.Tables = make(Tables) + return engine } -// CreateTable - initialize new table in engine with specified name -func (engine *DbEngine) CreateTable(command *ast.CreateCommand) { +// Evaluate - it takes sequences, map them to specific implementation and then process it in SQL engine +func (engine *DbEngine) Evaluate(sequences *ast.Sequence) (string, error) { + commands := sequences.Commands + + result := "" + for _, command := range commands { + + switch mappedCommand := command.(type) { + case *ast.WhereCommand: + continue + case *ast.OrderByCommand: + continue + case *ast.LimitCommand: + continue + case *ast.OffsetCommand: + continue + case *ast.JoinCommand: + continue + case *ast.CreateCommand: + err := engine.createTable(mappedCommand) + if err != nil { + return "", err + } + result += "Table '" + mappedCommand.Name.GetToken().Literal + "' has been created\n" + continue + case *ast.InsertCommand: + err := engine.insertIntoTable(mappedCommand) + if err != nil { + return "", err + } + result += "Data Inserted\n" + continue + case *ast.SelectCommand: + selectOutput, err := engine.getSelectResponse(mappedCommand) + if err != nil { + return "", err + } + result += selectOutput.ToString() + "\n" + continue + case *ast.DeleteCommand: + deleteCommand := command.(*ast.DeleteCommand) + if deleteCommand.HasWhereCommand() { + err := engine.deleteFromTable(mappedCommand, deleteCommand.WhereCommand) + if err != nil { + return "", err + } + } + result += "Data from '" + mappedCommand.Name.GetToken().Literal + "' has been deleted\n" + continue + case *ast.DropCommand: + engine.dropTable(mappedCommand) + result += "Table: '" + mappedCommand.Name.GetToken().Literal + "' has been dropped\n" + continue + case *ast.UpdateCommand: + err := engine.updateTable(mappedCommand) + if err != nil { + return "", err + } + result += "Table: '" + mappedCommand.Name.GetToken().Literal + "' has been updated\n" + continue + default: + return "", &UnsupportedCommandTypeFromParserError{variable: fmt.Sprintf("%s", command)} + } + } + + return result, nil +} + +// getSelectResponse - Returns Select response basing on ast.OrderByCommand and ast.WhereCommand included in this Select +func (engine *DbEngine) getSelectResponse(selectCommand *ast.SelectCommand) (*Table, error) { + var table *Table + var err error + + if selectCommand.HasJoinCommand() { + joinCommand := selectCommand.JoinCommand + table, err = engine.joinTables(joinCommand, selectCommand.Name.Token.Literal) + if err != nil { + return nil, err + } + } else { + var exist bool + table, exist = engine.Tables[selectCommand.Name.Token.Literal] + + if !exist { + return nil, &TableDoesNotExistError{selectCommand.Name.Token.Literal} + } + } + + if selectCommand.HasWhereCommand() { + whereCommand := selectCommand.WhereCommand + if selectCommand.HasOrderByCommand() { + orderByCommand := selectCommand.OrderByCommand + table, err = engine.selectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand, table) + if err != nil { + return nil, err + } + } else { + table, err = engine.selectFromTableWithWhere(selectCommand, whereCommand, table) + if err != nil { + return nil, err + } + } + } else if selectCommand.HasOrderByCommand() { + table, err = engine.selectFromTableWithOrderBy(selectCommand, selectCommand.OrderByCommand, table) + if err != nil { + return nil, err + } + } else { + table, err = engine.selectFromProvidedTable(selectCommand, table) + if err != nil { + return nil, err + } + } + + if selectCommand.HasLimitCommand() || selectCommand.HasOffsetCommand() { + table.applyOffsetAndLimit(selectCommand) + } + + if selectCommand.HasDistinct { + table = table.getDistinctTable() + } + + return table, nil +} + +// createTable - initialize new table in engine with specified name +func (engine *DbEngine) createTable(command *ast.CreateCommand) error { _, exist := engine.Tables[command.Name.Token.Literal] if exist { - log.Fatal("Table with the name of " + command.Name.Token.Literal + " already exist!") + return &TableAlreadyExistsError{command.Name.Token.Literal} } engine.Tables[command.Name.Token.Literal] = &Table{Columns: []*Column{}} @@ -39,121 +165,311 @@ func (engine *DbEngine) CreateTable(command *ast.CreateCommand) { Name: columnName, }) } + return nil } -// InsertIntoTable - Insert row of values into the table -func (engine *DbEngine) InsertIntoTable(command *ast.InsertCommand) { +func (engine *DbEngine) updateTable(command *ast.UpdateCommand) error { table, exist := engine.Tables[command.Name.Token.Literal] + if !exist { - log.Fatal("Table with the name of " + command.Name.Token.Literal + " doesn't exist!") + return &TableDoesNotExistError{command.Name.Token.Literal} } columns := table.Columns - if len(command.Values) != len(columns) { - log.Fatal("Invalid number of parameters in insert, should be: " + strconv.Itoa(len(columns)) + ", but got: " + strconv.Itoa(len(columns))) + // TODO: This could be optimized + mappedChanges := make(map[int]ast.Anonymitifier) + for updatedCol, newValue := range command.Changes { + for colIndex := 0; colIndex < len(columns); colIndex++ { + if columns[colIndex].Name == updatedCol.Literal { + mappedChanges[colIndex] = newValue + break + } + if colIndex == len(columns)-1 { + return &ColumnDoesNotExistError{tableName: command.Name.GetToken().Literal, columnName: updatedCol.Literal} + } + } } - for i := 0; i < len(columns); i++ { - expectedToken := tokenMapper(columns[i].Type.Type) - if expectedToken != command.Values[i].Type { - log.Fatal("Invalid Token Type in Insert Command, expecting: " + expectedToken + ", got: " + command.Values[i].Type) + numberOfRows := len(columns[0].Values) + for rowIndex := 0; rowIndex < numberOfRows; rowIndex++ { + if command.HasWhereCommand() { + fulfilledFilters, err := isFulfillingFilters(getRow(table, rowIndex), command.WhereCommand.Expression, command.WhereCommand.Token.Literal) + if err != nil { + return err + } + if !fulfilledFilters { + continue + } + } + for colIndex, value := range mappedChanges { + interfaceValue, err := getInterfaceValue(value.GetToken()) + if err != nil { + return err + } + table.Columns[colIndex].Values[rowIndex] = interfaceValue } - columns[i].Values = append(columns[i].Values, getInterfaceValue(command.Values[i])) } + + return nil } -// SelectFromTable - Return Table containing all values requested by SelectCommand -func (engine *DbEngine) SelectFromTable(command *ast.SelectCommand) *Table { +// insertIntoTable - Insert row of values into the table +func (engine *DbEngine) insertIntoTable(command *ast.InsertCommand) error { table, exist := engine.Tables[command.Name.Token.Literal] - if !exist { - log.Fatal("Table with the name of " + command.Name.Token.Literal + " doesn't exist!") + return &TableDoesNotExistError{command.Name.Token.Literal} } - return engine.selectFromProvidedTable(command, table) + columns := table.Columns + + if len(command.Values) != len(columns) { + return &InvalidNumberOfParametersError{expectedNumber: len(columns), actualNumber: len(command.Values), commandName: command.Token.Literal} + } + + for i := range columns { + expectedToken := tokenMapper(columns[i].Type.Type) + if (expectedToken != command.Values[i].Type) && (command.Values[i].Type != token.NULL) { + return &InvalidValueTypeError{expectedType: string(expectedToken), actualType: string(command.Values[i].Type), commandName: command.Token.Literal} + } + interfaceValue, err := getInterfaceValue(command.Values[i]) + if err != nil { + return err + } + columns[i].Values = append(columns[i].Values, interfaceValue) + } + return nil } -func (engine *DbEngine) selectFromProvidedTable(command *ast.SelectCommand, table *Table) *Table { +func (engine *DbEngine) selectFromProvidedTable(command *ast.SelectCommand, table *Table) (*Table, error) { columns := table.Columns wantedColumnNames := make([]string, 0) - if command.Space[0].Type == token.ASTERISK { + if command.AggregateFunctionAppears() { + selectedTable := &Table{Columns: make([]*Column, 0)} + + for i := 0; i < len(command.Space); i++ { + var columnType token.Token + var columnName string + var columnValues []ValueInterface + var err error + value := make([]ValueInterface, 0) + currentSpace := command.Space[i] + + if currentSpace.ColumnName.Type == token.ASTERISK && currentSpace.AggregateFunc.Type == token.COUNT { + if len(columns) > 0 { + columnValues = columns[0].Values + } + } else { + columnValues, err = getValuesOfColumn(currentSpace.ColumnName.Literal, columns) + } + + if err != nil { + return nil, err + } + + if currentSpace.ContainsAggregateFunc() { + columnName = fmt.Sprintf("%s(%s)", currentSpace.AggregateFunc.Literal, + currentSpace.ColumnName.Literal) + columnType = evaluateColumnTypeOfAggregateFunc(currentSpace) + aggregatedValue, aggregateErr := aggregateColumnContent(currentSpace, columnValues) + if aggregateErr != nil { + return nil, aggregateErr + } + value = append(value, aggregatedValue) + } else { + columnName = currentSpace.ColumnName.Literal + columnType = currentSpace.ColumnName + value = append(value, columnValues[0]) + } + + selectedTable.Columns = append(selectedTable.Columns, &Column{ + Name: columnName, + Type: columnType, + Values: value, + }) + } + return selectedTable, nil + } else if command.Space[0].ColumnName.Type == token.ASTERISK { for i := 0; i < len(columns); i++ { wantedColumnNames = append(wantedColumnNames, columns[i].Name) } - return extractColumnContent(columns, wantedColumnNames) + return extractColumnContent(columns, &wantedColumnNames, command.Name.GetToken().Literal) } else { for i := 0; i < len(command.Space); i++ { - wantedColumnNames = append(wantedColumnNames, command.Space[i].Literal) + wantedColumnNames = append(wantedColumnNames, command.Space[i].ColumnName.Literal) } - return extractColumnContent(columns, unique(wantedColumnNames)) + return extractColumnContent(columns, unique(wantedColumnNames), command.Name.GetToken().Literal) } } -// DeleteFromTable - Delete all rows of data from table that match given condition -func (engine *DbEngine) DeleteFromTable(deleteCommand *ast.DeleteCommand, whereCommand *ast.WhereCommand) { - table, exist := engine.Tables[deleteCommand.Name.Token.Literal] +func getValuesOfColumn(columnName string, columns []*Column) ([]ValueInterface, error) { + wantedColumnName := []string{columnName} + columnContent, err := extractColumnContent(columns, &wantedColumnName, "") + if err != nil { + return nil, err + } + return columnContent.Columns[0].Values, nil +} - if !exist { - log.Fatal("Table with the name of " + deleteCommand.Name.Token.Literal + " doesn't exist!") +func evaluateColumnTypeOfAggregateFunc(space ast.Space) token.Token { + if space.AggregateFunc.Type == token.MIN || + space.AggregateFunc.Type == token.MAX { + return space.ColumnName } + return token.Token{Type: token.INT, Literal: "INT"} +} - engine.Tables[deleteCommand.Name.Token.Literal] = engine.getFilteredTable(table, whereCommand, true) +func aggregateColumnContent(space ast.Space, columnValues []ValueInterface) (ValueInterface, error) { + if space.AggregateFunc.Type == token.COUNT { + if space.ColumnName.Type == token.ASTERISK { + return IntegerValue{Value: len(columnValues)}, nil + } + count := 0 + for _, value := range columnValues { + if value.GetType() != NullType { + count++ + } + } + return IntegerValue{Value: count}, nil + } + if len(columnValues) == 0 { + return NullValue{}, nil + } + switch space.AggregateFunc.Type { + case token.MAX: + maxValue, err := getMax(columnValues) + if err != nil { + return nil, err + } + return maxValue, nil + case token.MIN: + minValue, err := getMin(columnValues) + if err != nil { + return nil, err + } + return minValue, nil + case token.SUM: + if columnValues[0].GetType() == StringType { + return IntegerValue{Value: 0}, nil + } else { + sum := 0 + for _, value := range columnValues { + if value.GetType() != NullType { + num, err := strconv.Atoi(value.ToString()) + if err != nil { + return nil, err + } + sum += num + } + } + return IntegerValue{Value: sum}, nil + } + default: + if columnValues[0].GetType() == StringType { + return IntegerValue{Value: 0}, nil + } else { + sum := 0 + for _, value := range columnValues { + num, err := strconv.Atoi(value.ToString()) + if err != nil { + return nil, err + } + sum += num + } + return IntegerValue{Value: sum / len(columnValues)}, nil + } + } } -// SelectFromTableWithWhere - Return Table containing all values requested by SelectCommand and filtered by WhereCommand -func (engine *DbEngine) SelectFromTableWithWhere(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand) *Table { - table, exist := engine.Tables[selectCommand.Name.Token.Literal] +// deleteFromTable - Delete all rows of data from table that match given condition +func (engine *DbEngine) deleteFromTable(deleteCommand *ast.DeleteCommand, whereCommand *ast.WhereCommand) error { + table, exist := engine.Tables[deleteCommand.Name.Token.Literal] if !exist { - log.Fatal("Table with the name of " + selectCommand.Name.Token.Literal + " doesn't exist!") + return &TableDoesNotExistError{deleteCommand.Name.Token.Literal} + } + + newTable, err := engine.getFilteredTable(table, whereCommand, true, deleteCommand.Name.Token.Literal) + + if err != nil { + return err } + engine.Tables[deleteCommand.Name.Token.Literal] = newTable + + return nil +} +// dropTable - Drop table with given name +func (engine *DbEngine) dropTable(dropCommand *ast.DropCommand) { + delete(engine.Tables, dropCommand.Name.GetToken().Literal) +} + +// selectFromTableWithWhere - Return Table containing all values requested by SelectCommand and filtered by WhereCommand +func (engine *DbEngine) selectFromTableWithWhere(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand, table *Table) (*Table, error) { if len(table.Columns) == 0 || len(table.Columns[0].Values) == 0 { return engine.selectFromProvidedTable(selectCommand, &Table{Columns: []*Column{}}) } - filteredTable := engine.getFilteredTable(table, whereCommand, false) + filteredTable, err := engine.getFilteredTable(table, whereCommand, false, selectCommand.Name.GetToken().Literal) + + if err != nil { + return nil, err + } return engine.selectFromProvidedTable(selectCommand, filteredTable) } -// SelectFromTableWithWhereAndOrderBy - Return Table containing all values requested by SelectCommand, +// selectFromTableWithWhereAndOrderBy - Return Table containing all values requested by SelectCommand, // filtered by WhereCommand and sorted by OrderByCommand -func (engine *DbEngine) SelectFromTableWithWhereAndOrderBy(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand, orderByCommand *ast.OrderByCommand) *Table { - table, exist := engine.Tables[selectCommand.Name.Token.Literal] +func (engine *DbEngine) selectFromTableWithWhereAndOrderBy(selectCommand *ast.SelectCommand, whereCommand *ast.WhereCommand, orderByCommand *ast.OrderByCommand, table *Table) (*Table, error) { + filteredTable, err := engine.getFilteredTable(table, whereCommand, false, selectCommand.Name.GetToken().Literal) - if !exist { - log.Fatal("Table with the name of " + selectCommand.Name.Token.Literal + " doesn't exist!") + if err != nil { + return nil, err } - filteredTable := engine.getFilteredTable(table, whereCommand, false) - emptyTable := getCopyOfTableWithoutRows(table) - return engine.selectFromProvidedTable(selectCommand, engine.getSortedTable(orderByCommand, filteredTable, emptyTable)) -} - -// SelectFromTableWithOrderBy - Return Table containing all values requested by SelectCommand and sorted by OrderByCommand -func (engine *DbEngine) SelectFromTableWithOrderBy(selectCommand *ast.SelectCommand, orderByCommand *ast.OrderByCommand) *Table { - table, exist := engine.Tables[selectCommand.Name.Token.Literal] + sortedTable, err := engine.getSortedTable(orderByCommand, filteredTable, emptyTable, selectCommand.Name.GetToken().Literal) - if !exist { - log.Fatal("Table with the name of " + selectCommand.Name.Token.Literal + " doesn't exist!") + if err != nil { + return nil, err } + return engine.selectFromProvidedTable(selectCommand, sortedTable) +} + +// selectFromTableWithOrderBy - Return Table containing all values requested by SelectCommand and sorted by OrderByCommand +func (engine *DbEngine) selectFromTableWithOrderBy(selectCommand *ast.SelectCommand, orderByCommand *ast.OrderByCommand, table *Table) (*Table, error) { emptyTable := getCopyOfTableWithoutRows(table) - sortedTable := engine.getSortedTable(orderByCommand, table, emptyTable) + sortedTable, err := engine.getSortedTable(orderByCommand, table, emptyTable, selectCommand.Name.GetToken().Literal) + + if err != nil { + return nil, err + } return engine.selectFromProvidedTable(selectCommand, sortedTable) } -func (engine *DbEngine) getSortedTable(orderByCommand *ast.OrderByCommand, filteredTable *Table, copyOfTable *Table) *Table { +func (engine *DbEngine) getSortedTable(orderByCommand *ast.OrderByCommand, table *Table, copyOfTable *Table, tableName string) (*Table, error) { sortPatterns := orderByCommand.SortPatterns - rows := mapTableToRows(filteredTable) + columnNames := make([]string, 0) + for _, sortPattern := range sortPatterns { + columnNames = append(columnNames, sortPattern.ColumnName.Literal) + } + + missingColName := engine.getMissingColumnName(columnNames, table) + if missingColName != "" { + return nil, &ColumnDoesNotExistError{ + tableName: tableName, + columnName: missingColName, + } + } + + rows := MapTableToRows(table).rows sort.Slice(rows, func(i, j int) bool { howDeepWeSort := 0 @@ -183,19 +499,42 @@ func (engine *DbEngine) getSortedTable(orderByCommand *ast.OrderByCommand, filte newColumn.Values = append(newColumn.Values, value) } } - return copyOfTable + return copyOfTable, nil } -func (engine *DbEngine) getFilteredTable(table *Table, whereCommand *ast.WhereCommand, negation bool) *Table { +func (engine *DbEngine) getMissingColumnName(columnNames []string, table *Table) string { + for _, columnName := range columnNames { + exists := false + for _, column := range table.Columns { + if column.Name == columnName { + exists = true + break + } + } + if !exists { + return columnName + } + } + return "" +} + +func (engine *DbEngine) getFilteredTable(table *Table, whereCommand *ast.WhereCommand, negation bool, tableName string) (*Table, error) { filteredTable := getCopyOfTableWithoutRows(table) - //TODO: maybe rows should have separate structure, so it would would have it's on methods - rows := mapTableToRows(table) + identifiers := whereCommand.Expression.GetIdentifiers() + columnNames := make([]string, 0) + for _, identifier := range identifiers { + columnNames = append(columnNames, identifier.Token.Literal) + } + missingColumnName := engine.getMissingColumnName(columnNames, table) + if missingColumnName != "" { + return nil, &ColumnDoesNotExistError{tableName: tableName, columnName: missingColumnName} + } - for _, row := range rows { - fulfilledFilters, err := isFulfillingFilters(row, whereCommand.Expression) + for _, row := range MapTableToRows(table).rows { + fulfilledFilters, err := isFulfillingFilters(row, whereCommand.Expression, whereCommand.Token.Literal) if err != nil { - log.Fatal(err.Error()) + return nil, err } if xor(fulfilledFilters, negation) { @@ -205,7 +544,111 @@ func (engine *DbEngine) getFilteredTable(table *Table, whereCommand *ast.WhereCo } } } - return filteredTable + return filteredTable, nil +} + +func (engine *DbEngine) joinTables(joinCommand *ast.JoinCommand, leftTableName string) (*Table, error) { + leftTable, exist := engine.Tables[leftTableName] + leftTablePrefix := leftTableName + "." + if !exist { + return nil, &TableDoesNotExistError{leftTableName} + } + + rightTableName := joinCommand.Name.Token.Literal + rightTablePrefix := rightTableName + "." + rightTable, exist := engine.Tables[rightTableName] + if !exist { + return nil, &TableDoesNotExistError{rightTableName} + } + + joinedTable := &Table{Columns: []*Column{}} + + addColumnsWithPrefix(joinedTable, leftTable.Columns, leftTablePrefix) + addColumnsWithPrefix(joinedTable, rightTable.Columns, rightTablePrefix) + + leftTableWithAddedPrefix := leftTable.getTableCopyWithAddedPrefixToColumnNames(leftTablePrefix) + rightTableWithAddedPrefix := rightTable.getTableCopyWithAddedPrefixToColumnNames(rightTablePrefix) + var unmatchedRightRows = make(map[int]bool) + + for leftRowIndex := 0; leftRowIndex < len(leftTable.Columns[0].Values); leftRowIndex++ { + joinedRowLeft := getRow(leftTableWithAddedPrefix, leftRowIndex) + leftRowMatches := false + + for rightRowIndex := 0; rightRowIndex < len(rightTable.Columns[0].Values); rightRowIndex++ { + joinedRowRight := getRow(rightTableWithAddedPrefix, rightRowIndex) + maps.Copy(joinedRowRight, joinedRowLeft) + + fulfilledFilters, err := isFulfillingFilters(joinedRowRight, joinCommand.Expression, joinCommand.Token.Literal) + if err != nil { + return nil, err + } + + isLastLeftRow := leftRowIndex == len(leftTable.Columns[0].Values)-1 + + if fulfilledFilters { + for colIndex, column := range joinedTable.Columns { + joinedTable.Columns[colIndex].Values = append(joinedTable.Columns[colIndex].Values, joinedRowRight[column.Name]) + } + leftRowMatches, unmatchedRightRows[rightRowIndex] = true, true + } else if isLastLeftRow && joinCommand.ShouldTakeRightSide() && !unmatchedRightRows[rightRowIndex] { + joinedRowRight = getRow(rightTableWithAddedPrefix, rightRowIndex) + aggregateRowIntoJoinTable(leftTableWithAddedPrefix, joinedRowRight, joinedTable) + } + } + + if joinCommand.ShouldTakeLeftSide() && !leftRowMatches { + aggregateRowIntoJoinTable(rightTableWithAddedPrefix, joinedRowLeft, joinedTable) + } + } + + return joinedTable, nil +} + +func aggregateRowIntoJoinTable(tableWithAddedPrefix *Table, joinedRow map[string]ValueInterface, joinedTable *Table) { + joinedEmptyRow := getEmptyRow(tableWithAddedPrefix) + maps.Copy(joinedRow, joinedEmptyRow) + for colIndex, column := range joinedTable.Columns { + joinedTable.Columns[colIndex].Values = append(joinedTable.Columns[colIndex].Values, joinedRow[column.Name]) + } +} + +func addColumnsWithPrefix(finalTable *Table, columnsToAdd []*Column, prefix string) { + for _, column := range columnsToAdd { + finalTable.Columns = append(finalTable.Columns, + &Column{ + Type: column.Type, + Values: make([]ValueInterface, 0), + Name: prefix + column.Name, + }) + } +} + +func (table *Table) applyOffsetAndLimit(command *ast.SelectCommand) { + var offset = 0 + var limitRaw = -1 + + if command.HasLimitCommand() { + limitRaw = command.LimitCommand.Count + } + if command.HasOffsetCommand() { + offset = command.OffsetCommand.Count + } + + for _, column := range table.Columns { + var limit int + + if limitRaw == -1 || limitRaw+offset > len(column.Values) { + limit = len(column.Values) + } else { + limit = limitRaw + offset + } + + if offset > len(column.Values) || limit == 0 { + column.Values = make([]ValueInterface, 0) + } else { + column.Values = column.Values[offset:limit] + } + } } func xor(fulfilledFilters bool, negation bool) bool { @@ -226,83 +669,93 @@ func getCopyOfTableWithoutRows(table *Table) *Table { return filteredTable } -func mapTableToRows(table *Table) []map[string]ValueInterface { - rows := make([]map[string]ValueInterface, 0) - - numberOfRows := len(table.Columns[0].Values) +func isFulfillingFilters(row map[string]ValueInterface, expressionTree ast.Expression, commandName string) (bool, error) { + switch mappedExpression := expressionTree.(type) { + case *ast.OperationExpression: + return processOperationExpression(row, mappedExpression, commandName) + case *ast.BooleanExpression: + return processBooleanExpression(mappedExpression) + case *ast.ConditionExpression: + return processConditionExpression(row, mappedExpression, commandName) + case *ast.ContainExpression: + return processContainExpression(row, mappedExpression) - for rowIndex := 0; rowIndex < numberOfRows; rowIndex++ { - row := make(map[string]ValueInterface) - for _, column := range table.Columns { - row[column.Name] = column.Values[rowIndex] - } - rows = append(rows, row) + default: + return false, &UnsupportedExpressionTypeError{commandName: commandName, variable: fmt.Sprintf("%s", mappedExpression)} } - return rows } -func isFulfillingFilters(row map[string]ValueInterface, expressionTree ast.Expression) (bool, error) { - operationExpression, operationExpressionIsValid := expressionTree.(*ast.OperationExpression) - if operationExpressionIsValid { - return processOperationExpression(row, operationExpression) +func processConditionExpression(row map[string]ValueInterface, conditionExpression *ast.ConditionExpression, commandName string) (bool, error) { + valueLeft, err := getTifierValue(conditionExpression.Left, row) + if err != nil { + return false, err } - booleanExpression, booleanExpressionIsValid := expressionTree.(*ast.BooleanExpression) - if booleanExpressionIsValid { - return processBooleanExpression(booleanExpression) + valueRight, err := getTifierValue(conditionExpression.Right, row) + if err != nil { + return false, err } - conditionExpression, conditionExpressionIsValid := expressionTree.(*ast.ConditionExpression) - if conditionExpressionIsValid { - return processConditionExpression(row, conditionExpression) + switch conditionExpression.Condition.Type { + case token.EQUAL: + return valueLeft.IsEqual(valueRight), nil + case token.NOT: + return !(valueLeft.IsEqual(valueRight)), nil + default: + return false, &UnsupportedConditionalTokenError{variable: conditionExpression.Condition.Literal, commandName: commandName} } - - return false, fmt.Errorf("unsupported expression has been used in WHERE command: %v", expressionTree.GetIdentifiers()) } -func processConditionExpression(row map[string]ValueInterface, conditionExpression *ast.ConditionExpression) (bool, error) { - valueLeft, isValueLeftValid := getTifierValue(conditionExpression.Left, row) - if isValueLeftValid != nil { - log.Fatal(isValueLeftValid.Error()) +func processContainExpression(row map[string]ValueInterface, containExpression *ast.ContainExpression) (bool, error) { + valueLeft, err := getTifierValue(containExpression.Left, row) + if err != nil { + return false, err } - valueRight, isValueRightValid := getTifierValue(conditionExpression.Right, row) - if isValueLeftValid != nil { - log.Fatal(isValueRightValid.Error()) + result, err := ifValueInterfaceInArray(containExpression.Right, valueLeft) + + if containExpression.Contains { + return result, err } - switch conditionExpression.Condition.Type { - case token.EQUAL: - return valueLeft.IsEqual(valueRight), nil - case token.NOT: - return !(valueLeft.IsEqual(valueRight)), nil - default: - return false, errors.New("Operation '" + conditionExpression.Condition.Literal + "' provided in WHERE command isn't allowed!") + return !result, err +} + +func ifValueInterfaceInArray(array []ast.Anonymitifier, valueLeft ValueInterface) (bool, error) { + for _, expectedValue := range array { + value, err := getInterfaceValue(expectedValue.Token) + if err != nil { + return false, err + } + if value.IsEqual(valueLeft) { + return true, nil + } } + return false, nil } -func processOperationExpression(row map[string]ValueInterface, operationExpression *ast.OperationExpression) (bool, error) { +func processOperationExpression(row map[string]ValueInterface, operationExpression *ast.OperationExpression, commandName string) (bool, error) { if operationExpression.Operation.Type == token.AND { - left, err := isFulfillingFilters(row, operationExpression.Left) + left, err := isFulfillingFilters(row, operationExpression.Left, commandName) if !left { return left, err } - right, err := isFulfillingFilters(row, operationExpression.Right) + right, err := isFulfillingFilters(row, operationExpression.Right, commandName) return left && right, err } if operationExpression.Operation.Type == token.OR { - left, err := isFulfillingFilters(row, operationExpression.Left) + left, err := isFulfillingFilters(row, operationExpression.Left, commandName) if left { return left, err } - right, err := isFulfillingFilters(row, operationExpression.Right) + right, err := isFulfillingFilters(row, operationExpression.Right, commandName) return left || right, err } - return false, errors.New("unsupported operation token has been used: " + operationExpression.Operation.Literal) + return false, &UnsupportedOperationTokenError{operationExpression.Operation.Literal} } func processBooleanExpression(booleanExpression *ast.BooleanExpression) (bool, error) { @@ -313,26 +766,16 @@ func processBooleanExpression(booleanExpression *ast.BooleanExpression) (bool, e } func getTifierValue(tifier ast.Tifier, row map[string]ValueInterface) (ValueInterface, error) { - identifier, identifierIsValid := tifier.(ast.Identifier) - - if identifierIsValid { - return row[identifier.GetToken().Literal], nil - } - - anonymitifier, anonymitifierIsValid := tifier.(ast.Anonymitifier) - if anonymitifierIsValid { - return getInterfaceValue(anonymitifier.GetToken()), nil - } - - // TODO: Maybe information in which table this column doesn't exist is needed - return nil, errors.New("Column name:'" + tifier.GetToken().Literal + "' doesn't exist!") -} - -func getColumnIndexByName(columns []*Column, columName string) (int, error) { - for i, column := range columns { - if column.Name == columName { - return i, nil + switch mappedTifier := tifier.(type) { + case ast.Identifier: + value, ok := row[mappedTifier.GetToken().Literal] + if ok == false { + return nil, &ColumnDoesNotExistError{tableName: "", columnName: mappedTifier.GetToken().Literal} } + return value, nil + case ast.Anonymitifier: + return getInterfaceValue(mappedTifier.GetToken()) + default: + return nil, &UnsupportedValueType{tifier.GetToken().Literal} } - return -1, errors.New("Column name:'" + columName + "' doesn't exist!") } diff --git a/engine/engine_error_handling_test.go b/engine/engine_error_handling_test.go new file mode 100644 index 0000000..f40f659 --- /dev/null +++ b/engine/engine_error_handling_test.go @@ -0,0 +1,131 @@ +package engine + +import ( + "github.com/LissaGreense/GO4SQL/lexer" + "github.com/LissaGreense/GO4SQL/parser" + "github.com/LissaGreense/GO4SQL/token" + "testing" +) + +type errorHandlingTestSuite struct { + input string + expectedError string +} + +func TestEngineCreateCommandErrorHandling(t *testing.T) { + duplicateTableNameError := TableAlreadyExistsError{"table1"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE table1( one TEXT , two INT);CREATE TABLE table1(two INT);", duplicateTableNameError.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineInsertCommandErrorHandling(t *testing.T) { + tableDoNotExistError := TableDoesNotExistError{"table1"} + invalidNumberOfParametersError := InvalidNumberOfParametersError{expectedNumber: 2, actualNumber: 1, commandName: token.INSERT} + invalidParametersTypeError := InvalidValueTypeError{expectedType: token.IDENT, actualType: token.LITERAL, commandName: token.INSERT} + tests := []errorHandlingTestSuite{ + {"INSERT INTO table1 VALUES( 'hello', 1);", tableDoNotExistError.Error()}, + {"CREATE TABLE table1( one TEXT , two INT); INSERT INTO table1 VALUES(1);", invalidNumberOfParametersError.Error()}, + {"CREATE TABLE table1( one TEXT , two INT); INSERT INTO table1 VALUES(1, 1 );", invalidParametersTypeError.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineSelectCommandErrorHandling(t *testing.T) { + noTableDoesNotExist := TableDoesNotExistError{"tb1"} + columnDoesNotExist := ColumnDoesNotExistError{tableName: "tbl", columnName: "two"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE tbl(one TEXT); SELECT * FROM tb1;", noTableDoesNotExist.Error()}, + {"CREATE TABLE tbl(one TEXT); SELECT two FROM tbl;", columnDoesNotExist.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineDeleteCommandErrorHandling(t *testing.T) { + noTableDoesNotExist := TableDoesNotExistError{"tb1"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE tbl(one TEXT); DELETE FROM tb1 WHERE one EQUAL 3;", noTableDoesNotExist.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineWhereCommandErrorHandling(t *testing.T) { + columnDoesNotExist := ColumnDoesNotExistError{tableName: "tbl", columnName: "two"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE tbl(one TEXT); INSERT INTO tbl VALUES('hello'); SELECT * FROM tbl WHERE two EQUAL 3;", columnDoesNotExist.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineUpdateCommandErrorHandling(t *testing.T) { + noTableDoesNotExist := TableDoesNotExistError{"tb1"} + columnDoesNotExist := ColumnDoesNotExistError{tableName: "tbl", columnName: "two"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE tbl(one TEXT); UPDATE tb1 SET one TO 2;", noTableDoesNotExist.Error()}, + {"CREATE TABLE tbl(one TEXT);UPDATE tbl SET two TO 2;", columnDoesNotExist.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineOrderByCommandErrorHandling(t *testing.T) { + columnDoesNotExist := ColumnDoesNotExistError{tableName: "tbl", columnName: "two"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE tbl(one TEXT); SELECT * FROM tbl ORDER BY two ASC;", columnDoesNotExist.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func TestEngineFullJoinErrorHandling(t *testing.T) { + leftTableNotExist := TableDoesNotExistError{tableName: "leftTable"} + rightTableNotExist := TableDoesNotExistError{tableName: "rightTable"} + columnDoesNotExist := ColumnDoesNotExistError{tableName: "", columnName: "leftTable.two"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE rightTable(one TEXT); SELECT leftTable.one, rightTable.one FROM leftTable JOIN rightTable ON leftTable.one EQUAL rightTable.one;", leftTableNotExist.Error()}, + {"CREATE TABLE leftTable(one TEXT); SELECT leftTable.one, rightTable.one FROM leftTable JOIN rightTable ON leftTable.one EQUAL rightTable.one;", rightTableNotExist.Error()}, + {"CREATE TABLE leftTable(one TEXT); CREATE TABLE rightTable(one TEXT); INSERT INTO leftTable VALUES('hi'); INSERT INTO rightTable VALUES('hi'); SELECT * FROM leftTable JOIN rightTable ON leftTable.two EQUAL rightTable.one;", columnDoesNotExist.Error()}, + } + + runEngineErrorHandlingSuite(t, tests) +} + +func runEngineErrorHandlingSuite(t *testing.T, suite []errorHandlingTestSuite) { + for i, test := range suite { + errorMsg := getErrorMessage(t, test.input, i) + + if errorMsg != test.expectedError { + t.Fatalf("[%v]Was expecting error: \n\t{%s},\n\tbut it was:\n\t{%s}", i, test.expectedError, errorMsg) + } + } +} + +func getErrorMessage(t *testing.T, input string, testIndex int) string { + lexerInstance := lexer.RunLexer(input) + parserInstance := parser.New(lexerInstance) + sequences, parserError := parserInstance.ParseSequence() + if parserError != nil { + t.Fatalf("[%d] Error has occured in parser not in engine, error: %s", testIndex, parserError.Error()) + } + + engine := New() + _, engineError := engine.Evaluate(sequences) + if engineError == nil { + t.Fatalf("[%d] Was expecting error from engine but there was none", testIndex) + } + + return engineError.Error() +} diff --git a/engine/engine_test.go b/engine/engine_test.go index 8e7f54c..4e1111d 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -1,7 +1,7 @@ package engine import ( - "strings" + "log" "testing" "github.com/LissaGreense/GO4SQL/ast" @@ -9,22 +9,53 @@ import ( "github.com/LissaGreense/GO4SQL/parser" ) +func TestCreate(t *testing.T) { + simpleCreateCase := engineDBContentTestSuite{ + inputs: []string{"CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );"}, + expectedTableNames: []string{"tb1"}, + } + + simpleCreateCase.runTestSuite(t) + + multiplyCreationCase := engineDBContentTestSuite{ + inputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + "CREATE TABLE tb2( one TEXT, two INT, three INT, four TEXT );", + }, + expectedTableNames: []string{"tb1", "tb2"}, + } + + multiplyCreationCase.runTestSuite(t) + +} + +func TestDrop(t *testing.T) { + simpleDropCase := engineDBContentTestSuite{ + inputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + "DROP TABLE tb1;", + }, + expectedTableNames: []string{}, + } + simpleDropCase.runTestSuite(t) +} + func TestSelectCommand(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", - "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", - "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, NULL );", + "INSERT INTO tb1 VALUES( 'byebye', NULL, 33, 'e' );", }, selectInput: "SELECT * FROM tb1;", expectedOutput: [][]string{ {"one", "two", "three", "four"}, {"hello", "1", "11", "q"}, - {"goodbye", "2", "22", "w"}, - {"byebye", "3", "33", "e"}, + {"goodbye", "2", "22", "NULL"}, + {"byebye", "NULL", "33", "e"}, }, } @@ -32,7 +63,7 @@ func TestSelectCommand(t *testing.T) { } func TestSelectWithColumnNamesCommand(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -54,7 +85,7 @@ func TestSelectWithColumnNamesCommand(t *testing.T) { } func TestSelectWithWhereEqual(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -75,16 +106,16 @@ func TestSelectWithWhereEqual(t *testing.T) { func TestSelectWithWhereNotEqual(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", - "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, NULL, 'w' );", "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", }, - selectInput: "SELECT one, two, three, four FROM tb1 WHERE three NOT 22;", + selectInput: "SELECT one, two, three, four FROM tb1 WHERE three NOT NULL;", expectedOutput: [][]string{ {"one", "two", "three", "four"}, {"hello", "1", "11", "q"}, @@ -95,18 +126,100 @@ func TestSelectWithWhereNotEqual(t *testing.T) { engineTestSuite.runTestSuite(t) } -func TestSelectWithWhereLogicalOperationAnd(t *testing.T) { +func TestSelectWithWhereContains(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, NULL, 'w' );", + "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + }, + selectInput: "SELECT one, two, three, four FROM tb1 WHERE three IN (11, NULL, 67);", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + {"hello", "1", "11", "q"}, + {"goodbye", "2", "NULL", "w"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestSelectWithWhereNotContains(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + }, + selectInput: "SELECT one, two, three, four FROM tb1 WHERE one NOTIN ('hello', 'byebye', 'youAreTheBest');", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + {"goodbye", "2", "22", "w"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestSelectWithWhereContainsButResponseIsEmpty(t *testing.T) { + + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + }, + selectInput: "SELECT one, two, three, four FROM tb1 WHERE one IN ('I', 'dont', 'exist', 'anymore');", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestSelectWithWhereNotContainsButResponseIsEmpty(t *testing.T) { + + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + }, + selectInput: "SELECT one, two, three, four FROM tb1 WHERE two NOTIN (1, 2, 3, 4);", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestSelectWithWhereLogicalOperationAnd(t *testing.T) { + + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'goodbye', NULL, 22, 'w' );", "INSERT INTO tb1 VALUES( 'goodbye', 3, 33, 'e' );", }, - selectInput: "SELECT * FROM tb1 WHERE one EQUAL 'goodbye' AND two NOT 2;", + selectInput: "SELECT * FROM tb1 WHERE one EQUAL 'goodbye' AND two NOT NULL;", expectedOutput: [][]string{ {"one", "two", "three", "four"}, {"goodbye", "3", "33", "e"}, @@ -118,7 +231,7 @@ func TestSelectWithWhereLogicalOperationAnd(t *testing.T) { func TestSelectWithWhereLogicalOperationOR(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -140,20 +253,20 @@ func TestSelectWithWhereLogicalOperationOR(t *testing.T) { func TestSelectWithWhereLogicalOperationOROperationAND(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", - "INSERT INTO tb1 VALUES( 'goodbye', 3, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 3, 33, NULL );", }, - selectInput: "SELECT * FROM tb1 WHERE one NOT 'goodbye' OR two EQUAL 3 AND four EQUAL 'e';", + selectInput: "SELECT * FROM tb1 WHERE one NOT 'goodbye' OR two IN (3) AND four EQUAL NULL;", expectedOutput: [][]string{ {"one", "two", "three", "four"}, {"hello", "1", "11", "q"}, - {"goodbye", "3", "33", "e"}, + {"goodbye", "3", "33", "NULL"}, }, } @@ -162,7 +275,7 @@ func TestSelectWithWhereLogicalOperationOROperationAND(t *testing.T) { func TestSelectWithWhereEqualToTrue(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -185,7 +298,7 @@ func TestSelectWithWhereEqualToTrue(t *testing.T) { func TestSelectWithWhereEqualToFalse(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -201,9 +314,34 @@ func TestSelectWithWhereEqualToFalse(t *testing.T) { engineTestSuite.runTestSuite(t) } +func TestDistinctSelect(t *testing.T) { + + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'goodbye', NULL, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'goodbye', NULL, 22, 'w' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + }, + selectInput: "SELECT DISTINCT * FROM tb1;", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + {"hello", "1", "11", "q"}, + {"goodbye", "NULL", "22", "w"}, + {"goodbye", "2", "22", "w"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + func TestDelete(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -224,20 +362,66 @@ func TestDelete(t *testing.T) { engineTestSuite.runTestSuite(t) } +func TestUpdateWithWhere(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + "UPDATE tb1 SET one TO 'hi hello', three TO NULL WHERE two EQUAL 3;", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + }, + selectInput: "SELECT one, two, three, four FROM tb1;", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + {"hello", "1", "11", "q"}, + {"hi hello", "3", "NULL", "e"}, + {"goodbye", "2", "22", "w"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestUpdateWithoutWhere(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 1, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 3, 33, 'e' );", + "UPDATE tb1 SET one TO 'hi hello', three TO 5;", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", + }, + selectInput: "SELECT one, two, three, four FROM tb1;", + expectedOutput: [][]string{ + {"one", "two", "three", "four"}, + {"hi hello", "1", "5", "q"}, + {"hi hello", "3", "5", "e"}, + {"goodbye", "2", "22", "w"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + func TestOrderBy(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, insertAndDeleteInputs: []string{ "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", - "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'byebye', NULL, 33, 'e' );", "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'w' );", }, selectInput: "SELECT one, two, three, four FROM tb1 ORDER BY two ASC;", expectedOutput: [][]string{ {"one", "two", "three", "four"}, - {"byebye", "1", "33", "e"}, + {"byebye", "NULL", "33", "e"}, {"goodbye", "2", "22", "w"}, {"hello", "3", "11", "q"}, }, @@ -247,7 +431,7 @@ func TestOrderBy(t *testing.T) { } func TestOrderByWithWhere(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -268,7 +452,7 @@ func TestOrderByWithWhere(t *testing.T) { } func TestOrderByWithMultipleSorts(t *testing.T) { - engineTestSuite := engineTestSuite{ + engineTestSuite := engineTableContentTestSuite{ createInputs: []string{ "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", }, @@ -291,72 +475,530 @@ func TestOrderByWithMultipleSorts(t *testing.T) { engineTestSuite.runTestSuite(t) } -type engineTestSuite struct { - createInputs []string - insertAndDeleteInputs []string - selectInput string - expectedOutput [][]string +func TestLimit(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'aa' );", + "INSERT INTO tb1 VALUES( 'sorry', 2, 55, 'ba' );", + "INSERT INTO tb1 VALUES( 'welcome', 2, 66, 'bb' );", + "INSERT INTO tb1 VALUES( 'seeYouLater', 2, 95, 'ab' );", + }, + selectInput: "SELECT one FROM tb1 LIMIT 2;", + expectedOutput: [][]string{ + {"one"}, + {"hello"}, + {"byebye"}, + }, + } + + engineTestSuite.runTestSuite(t) } -func (engineTestSuite *engineTestSuite) runTestSuite(t *testing.T) { - input := "" - expectedSequencesNumber := 0 - for inputIndex := 0; inputIndex < len(engineTestSuite.createInputs); inputIndex++ { - input += engineTestSuite.createInputs[inputIndex] + "\n" +func TestLimitEqualToZero(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'aa' );", + "INSERT INTO tb1 VALUES( 'sorry', 2, 55, 'ba' );", + "INSERT INTO tb1 VALUES( 'welcome', 2, 66, 'bb' );", + "INSERT INTO tb1 VALUES( 'seeYouLater', 2, 95, 'ab' );", + }, + selectInput: "SELECT one FROM tb1 LIMIT 0;", + expectedOutput: [][]string{ + {"one"}, + }, } - for inputIndex := 0; inputIndex < len(engineTestSuite.insertAndDeleteInputs); inputIndex++ { - if strings.HasPrefix(engineTestSuite.insertAndDeleteInputs[inputIndex], "DELETE") { - expectedSequencesNumber++ - } - input += engineTestSuite.insertAndDeleteInputs[inputIndex] + "\n" + + engineTestSuite.runTestSuite(t) +} + +func TestLimitThatIsMoreThanSize(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'aa' );", + }, + selectInput: "SELECT one FROM tb1 LIMIT 666;", + expectedOutput: [][]string{ + {"one"}, + {"hello"}, + {"byebye"}, + {"goodbye"}, + }, } - input += engineTestSuite.selectInput - lexerInstance := lexer.RunLexer(input) - parserInstance := parser.New(lexerInstance) - sequences := parserInstance.ParseSequence() + engineTestSuite.runTestSuite(t) +} - expectedSequencesNumber += len(engineTestSuite.createInputs) + len(engineTestSuite.insertAndDeleteInputs) + 1 +func TestOffset(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 4, 22, 'aa' );", + "INSERT INTO tb1 VALUES( 'sorry', 2, 55, 'ba' );", + }, + selectInput: "SELECT one FROM tb1 OFFSET 3;", + expectedOutput: [][]string{ + {"one"}, + {"sorry"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestOffsetThatOverExceedSize(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 4, 22, 'aa' );", + "INSERT INTO tb1 VALUES( 'sorry', 2, 55, 'ba' );", + }, + selectInput: "SELECT one FROM tb1 WHERE TRUE ORDER BY two ASC OFFSET 4;", + expectedOutput: [][]string{ + {"one"}, + }, + } - var actualTable *Table + engineTestSuite.runTestSuite(t) +} - if strings.Contains(engineTestSuite.selectInput, "ORDER BY") { - expectedSequencesNumber++ +func TestOffsetEqualToZero(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 4, 22, 'aa' );", + "INSERT INTO tb1 VALUES( 'sorry', 2, 55, 'ba' );", + }, + selectInput: "SELECT one FROM tb1 OFFSET 0;", + expectedOutput: [][]string{ + {"one"}, + {"hello"}, + {"byebye"}, + {"goodbye"}, + {"sorry"}, + }, } - if strings.Contains(engineTestSuite.selectInput, " WHERE ") { + engineTestSuite.runTestSuite(t) +} - // WHERE CONDITION +func TestLimitAndOffset(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE tb1( one TEXT, two INT, three INT, four TEXT );", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO tb1 VALUES( 'hello', 3, 11, 'q' );", + "INSERT INTO tb1 VALUES( 'byebye', 1, 33, 'e' );", + "INSERT INTO tb1 VALUES( 'goodbye', 2, 22, 'aa' );", + "INSERT INTO tb1 VALUES( 'sorry', 2, 55, 'ba' );", + }, + selectInput: "SELECT one FROM tb1 WHERE TRUE ORDER BY two ASC, four DESC LIMIT 2 OFFSET 2;", + expectedOutput: [][]string{ + {"one"}, + {"goodbye"}, + {"hello"}, + }, + } - expectedSequencesNumber++ - if len(sequences.Commands) != expectedSequencesNumber { - t.Fatalf("sequences does not contain %d statements. got=%d", expectedSequencesNumber, len(sequences.Commands)) - } + engineTestSuite.runTestSuite(t) +} - engine := engineTestSuite.getEngineWithInsertedValues(sequences) +func TestDefaultJoinToBehaveLikeInnerJoin(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE books( author_id INT, title TEXT);", + "CREATE TABLE authors( author_id INT, name TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO books VALUES(2, 'Fire');", + "INSERT INTO books VALUES(1, 'Earth');", + "INSERT INTO books VALUES(1, 'Air');", + "INSERT INTO books VALUES(3, 'Smoke');", + "INSERT INTO authors VALUES( 1, 'Reynold Boyka' );", + "INSERT INTO authors VALUES( 2, 'Alissa Ireneus' );", + "INSERT INTO authors VALUES( 3, NULL );", + }, + selectInput: "SELECT books.title, authors.name FROM books JOIN authors ON books.author_id EQUAL authors.author_id;", + expectedOutput: [][]string{ + {"books.title", "authors.name"}, + {"Fire", "Alissa Ireneus"}, + {"Earth", "Reynold Boyka"}, + {"Air", "Reynold Boyka"}, + {"Smoke", "NULL"}, + }, + } - if strings.Contains(engineTestSuite.selectInput, "ORDER BY") { - actualTable = engine.SelectFromTableWithWhereAndOrderBy(sequences.Commands[len(sequences.Commands)-3].(*ast.SelectCommand), sequences.Commands[len(sequences.Commands)-2].(*ast.WhereCommand), sequences.Commands[len(sequences.Commands)-1].(*ast.OrderByCommand)) - } else { - actualTable = engine.SelectFromTableWithWhere(sequences.Commands[len(sequences.Commands)-2].(*ast.SelectCommand), sequences.Commands[len(sequences.Commands)-1].(*ast.WhereCommand)) - } + engineTestSuite.runTestSuite(t) +} - } else { +func TestInnerJoinOnMultipleMatches(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE books( author_id INT, title TEXT);", + "CREATE TABLE authors( author_id INT, name TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO books VALUES(1, 'Book One');", + "INSERT INTO books VALUES(1, 'Book Two');", + "INSERT INTO authors VALUES(1, 'Author One');", + "INSERT INTO authors VALUES(1, 'Author Two');", + }, + selectInput: "SELECT books.title, authors.name FROM books JOIN authors ON books.author_id EQUAL authors.author_id;", + expectedOutput: [][]string{ + {"books.title", "authors.name"}, + {"Book One", "Author One"}, + {"Book One", "Author Two"}, + {"Book Two", "Author One"}, + {"Book Two", "Author Two"}, + }, + } - // NO WHERE CONDITION + engineTestSuite.runTestSuite(t) +} - if len(sequences.Commands) != expectedSequencesNumber { - t.Fatalf("sequences does not contain %d statements. got=%d", expectedSequencesNumber, len(sequences.Commands)) - } +func TestFullJoinOnIdenticalTables(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + "CREATE TABLE table2( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(3, NULL);", + }, + selectInput: "SELECT table1.value, table2.value FROM table1 FULL JOIN table2 ON table1.id EQUAL table2.id;", + expectedOutput: [][]string{ + {"table1.value", "table2.value"}, + {"Value1", "NULL"}, + {"Value2", "Value2"}, + {"NULL", "NULL"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestInnerJoinWithSpecifiedKeywordOnIdenticalTables(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + "CREATE TABLE table2( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(NULL, NULL);", + "INSERT INTO table2 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(3, 'Value3');", + "INSERT INTO table2 VALUES(NULL, 'Value4');", + }, + selectInput: "SELECT table1.value, table2.value FROM table1 INNER JOIN table2 ON table1.id EQUAL table2.id;", + expectedOutput: [][]string{ + {"table1.value", "table2.value"}, + {"Value2", "Value2"}, + {"NULL", "Value4"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestLeftJoinOnIdenticalTables(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + "CREATE TABLE table2( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(NULL, 'Value4');", + "INSERT INTO table2 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(3, 'Value3');", + "INSERT INTO table2 VALUES(NULL, NULL);", + }, + selectInput: "SELECT table1.value, table2.value FROM table1 LEFT JOIN table2 ON table1.id EQUAL table2.id;", + expectedOutput: [][]string{ + {"table1.value", "table2.value"}, + {"Value1", "NULL"}, + {"Value2", "Value2"}, + {"Value4", "NULL"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestRightJoinOnIdenticalTables(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + "CREATE TABLE table2( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(NULL, NULL);", + "INSERT INTO table2 VALUES(2, 'Value2');", + "INSERT INTO table2 VALUES(3, 'Value3');", + "INSERT INTO table2 VALUES(NULL, 'Value4');", + }, + selectInput: "SELECT table1.value, table2.value FROM table1 RIGHT JOIN table2 ON table1.id EQUAL table2.id;", + expectedOutput: [][]string{ + {"table1.value", "table2.value"}, + {"Value2", "Value2"}, + {"NULL", "Value3"}, + {"NULL", "Value4"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionMax(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + }, + selectInput: "SELECT MAX(id), MAX(value) FROM table1;", + expectedOutput: [][]string{ + {"MAX(id)", "MAX(value)"}, + {"2", "Value2"}, + }, + } - engine := engineTestSuite.getEngineWithInsertedValues(sequences) + engineTestSuite.runTestSuite(t) +} - if strings.Contains(engineTestSuite.selectInput, "ORDER BY") { - actualTable = engine.SelectFromTableWithOrderBy(sequences.Commands[len(sequences.Commands)-2].(*ast.SelectCommand), sequences.Commands[len(sequences.Commands)-1].(*ast.OrderByCommand)) - } else { - actualTable = engine.SelectFromTable(sequences.Commands[len(sequences.Commands)-1].(*ast.SelectCommand)) +func TestAggregateFunctionMin(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, NULL);", + }, + selectInput: "SELECT MIN(value), MIN(id) FROM table1 WHERE value NOT NULL;", + expectedOutput: [][]string{ + {"MIN(value)", "MIN(id)"}, + {"Value1", "1"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionMinWithNull(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, NULL);", + }, + selectInput: "SELECT MIN(value), MIN(id) FROM table1;", + expectedOutput: [][]string{ + {"MIN(value)", "MIN(id)"}, + {"NULL", "1"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionCount(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, 'Value3');", + "INSERT INTO table1 VALUES(NULL, NULL);", + }, + selectInput: "SELECT COUNT(*), COUNT(id), COUNT(value) FROM table1;", + expectedOutput: [][]string{ + {"COUNT(*)", "COUNT(id)", "COUNT(value)"}, + {"4", "3", "3"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionSum(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, 'Value3');", + "INSERT INTO table1 VALUES(NULL, 'Value4');", + }, + selectInput: "SELECT SUM(id), SUM(value) FROM table1;", + expectedOutput: [][]string{ + {"SUM(id)", "SUM(value)"}, + {"6", "0"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionAvg(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, 'Value3');", + "INSERT INTO table1 VALUES(10, NULL);", + }, + selectInput: "SELECT AVG(id), AVG(value) FROM table1;", + expectedOutput: [][]string{ + {"AVG(id)", "AVG(value)"}, + {"4", "0"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionWithColumnSelection(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, NULL);", + "INSERT INTO table1 VALUES(6, 'Value3');", + }, + selectInput: "SELECT AVG(id), id FROM table1;", + expectedOutput: [][]string{ + {"AVG(id)", "id"}, + {"3", "1"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +func TestAggregateFunctionWithColumnSelectionAndOrderBy(t *testing.T) { + engineTestSuite := engineTableContentTestSuite{ + createInputs: []string{ + "CREATE TABLE table1( id INT, value TEXT);", + }, + insertAndDeleteInputs: []string{ + "INSERT INTO table1 VALUES(1, 'Value1');", + "INSERT INTO table1 VALUES(2, 'Value2');", + "INSERT INTO table1 VALUES(3, 'Value3');", + "INSERT INTO table1 VALUES(4, NULL);", + }, + selectInput: "SELECT MAX(id), id FROM table1 ORDER BY id DESC;", + expectedOutput: [][]string{ + {"MAX(id)", "id"}, + {"4", "4"}, + }, + } + + engineTestSuite.runTestSuite(t) +} + +type engineDBContentTestSuite struct { + inputs []string + expectedTableNames []string +} + +func (engineTestSuite *engineDBContentTestSuite) runTestSuite(t *testing.T) { + sequences := getSequences(inputsToString(engineTestSuite.inputs)) + engine := New() + engine.Evaluate(sequences) + + if len(engine.Tables) != len(engineTestSuite.expectedTableNames) { + t.Fatalf("Number of tables is incorrect, should be %d, got %d", len(engineTestSuite.expectedTableNames), len(engine.Tables)) + } + + for _, tableName := range engineTestSuite.expectedTableNames { + if engine.Tables[tableName] == nil { + t.Fatalf("Expected table '%s' does not exist", tableName) } } +} + +type engineTableContentTestSuite struct { + createInputs []string + insertAndDeleteInputs []string + selectInput string + expectedOutput [][]string +} + +func (engineTestSuite *engineTableContentTestSuite) runTestSuite(t *testing.T) { + expectedSequencesNumber := 0 + + input := inputsToString(engineTestSuite.createInputs) + inputsToString(engineTestSuite.insertAndDeleteInputs) + + sequencesWithoutSelect := getSequences(input) + selectCommand := getSequences(engineTestSuite.selectInput) + + expectedSequencesNumber += len(engineTestSuite.createInputs) + len(engineTestSuite.insertAndDeleteInputs) + 1 + + if len(sequencesWithoutSelect.Commands)+len(selectCommand.Commands) != expectedSequencesNumber { + t.Fatalf("sequences does not contain %d statements. got=%d", expectedSequencesNumber, len(sequencesWithoutSelect.Commands)) + } + + engine := New() + _, err := engine.Evaluate(sequencesWithoutSelect) + if err != nil { + log.Fatal(err) + } + actualTable, err := engine.getSelectResponse(selectCommand.Commands[0].(*ast.SelectCommand)) + if err != nil { + log.Fatal(err) + } if len(engineTestSuite.expectedOutput) == 0 { if len(actualTable.Columns[0].Values) != 0 { @@ -382,19 +1024,22 @@ func (engineTestSuite *engineTestSuite) runTestSuite(t *testing.T) { } -func (engineTestSuite *engineTestSuite) getEngineWithInsertedValues(sequences *ast.Sequence) *DbEngine { - engine := New() - for commandIndex := 0; commandIndex < len(sequences.Commands); commandIndex++ { - if createCommand, ok := sequences.Commands[commandIndex].(*ast.CreateCommand); ok { - engine.CreateTable(createCommand) - } - if insertCommand, ok := sequences.Commands[commandIndex].(*ast.InsertCommand); ok { - engine.InsertIntoTable(insertCommand) - } - if deleteCommand, ok := sequences.Commands[commandIndex].(*ast.DeleteCommand); ok { - whereCommand := sequences.Commands[commandIndex+1].(*ast.WhereCommand) - engine.DeleteFromTable(deleteCommand, whereCommand) - } +func inputsToString(inputs []string) string { + input := "" + + for inputIndex := 0; inputIndex < len(inputs); inputIndex++ { + input += inputs[inputIndex] + "\n" + } + + return input +} + +func getSequences(input string) *ast.Sequence { + lexerInstance := lexer.RunLexer(input) + parserInstance := parser.New(lexerInstance) + sequences, err := parserInstance.ParseSequence() + if err != nil { + log.Fatal(err) } - return engine + return sequences } diff --git a/engine/engine_utils.go b/engine/engine_utils.go index 83c1036..69329d2 100644 --- a/engine/engine_utils.go +++ b/engine/engine_utils.go @@ -1,22 +1,23 @@ package engine import ( - "log" "strconv" "github.com/LissaGreense/GO4SQL/token" ) -func getInterfaceValue(t token.Token) ValueInterface { +func getInterfaceValue(t token.Token) (ValueInterface, error) { switch t.Type { + case token.NULL: + return NullValue{}, nil case token.LITERAL: castedInteger, err := strconv.Atoi(t.Literal) if err != nil { - log.Fatal("Cannot cast \"" + t.Literal + "\" to Integer") + return nil, err } - return IntegerValue{Value: castedInteger} + return IntegerValue{Value: castedInteger}, nil default: - return StringValue{Value: t.Literal} + return StringValue{Value: t.Literal}, nil } } @@ -31,7 +32,7 @@ func tokenMapper(inputToken token.Type) token.Type { } } -func unique(arr []string) []string { +func unique(arr []string) *[]string { occurred := map[string]bool{} var result []string @@ -41,5 +42,5 @@ func unique(arr []string) []string { result = append(result, arr[e]) } } - return result + return &result } diff --git a/engine/errors.go b/engine/errors.go new file mode 100644 index 0000000..81910de --- /dev/null +++ b/engine/errors.go @@ -0,0 +1,106 @@ +package engine + +import "strconv" + +// TableAlreadyExistsError - error thrown when user tries to create table using name that already +// exists in database +type TableAlreadyExistsError struct { + tableName string +} + +func (m *TableAlreadyExistsError) Error() string { + return "table with the name of " + m.tableName + " already exists" +} + +// TableDoesNotExistError - error thrown when user tries to make operation on un-existing table +type TableDoesNotExistError struct { + tableName string +} + +func (m *TableDoesNotExistError) Error() string { + return "table with the name of " + m.tableName + " doesn't exist" +} + +// ColumnDoesNotExistError - error thrown when user tries to make operation on un-existing column +type ColumnDoesNotExistError struct { + tableName string + columnName string +} + +func (m *ColumnDoesNotExistError) Error() string { + return "column with the name of " + m.columnName + " doesn't exist in table " + m.tableName +} + +// InvalidNumberOfParametersError - error thrown when user provides invalid number of expected parameters +// (ex. fewer values in insert than defined ) +type InvalidNumberOfParametersError struct { + expectedNumber int + actualNumber int + commandName string +} + +func (m *InvalidNumberOfParametersError) Error() string { + return "invalid number of parameters in " + m.commandName + " command, should be: " + strconv.Itoa(m.expectedNumber) + ", but got: " + strconv.Itoa(m.actualNumber) +} + +// InvalidValueTypeError - error thrown when user provides value of different type than expected +type InvalidValueTypeError struct { + expectedType string + actualType string + commandName string +} + +func (m *InvalidValueTypeError) Error() string { + return "invalid value type provided in " + m.commandName + " command, expecting: " + m.expectedType + ", got: " + m.actualType +} + +// UnsupportedValueType - error thrown when engine found unsupported data type to be stored inside +// the columns +type UnsupportedValueType struct { + variable string +} + +func (m *UnsupportedValueType) Error() string { + return "couldn't map interface to any implementation of it: " + m.variable +} + +// UnsupportedOperationTokenError - error thrown when engine found unsupported operation token +// (supported are: AND, OR) +type UnsupportedOperationTokenError struct { + variable string +} + +func (m *UnsupportedOperationTokenError) Error() string { + return "unsupported operation token has been used: " + m.variable +} + +// UnsupportedConditionalTokenError - error thrown when engine found unsupported conditional token +// inside expression (supported are: EQUAL, NOT) +type UnsupportedConditionalTokenError struct { + variable string + commandName string +} + +func (m *UnsupportedConditionalTokenError) Error() string { + return "operation '" + m.variable + "' provided in " + m.commandName + " command isn't allowed" +} + +// UnsupportedExpressionTypeError - error thrown when engine found unsupported expression type +type UnsupportedExpressionTypeError struct { + variable string + commandName string +} + +func (m *UnsupportedExpressionTypeError) Error() string { + return "unsupported expression has been used in " + m.commandName + "command: " + m.variable +} + +// UnsupportedCommandTypeFromParserError - error thrown when engine found unsupported command +// from parser +type UnsupportedCommandTypeFromParserError struct { + variable string +} + +func (m *UnsupportedCommandTypeFromParserError) Error() string { + return "unsupported Command detected: " + m.variable +} diff --git a/engine/generic_value.go b/engine/generic_value.go index 7adab43..d5c26c9 100644 --- a/engine/generic_value.go +++ b/engine/generic_value.go @@ -1,6 +1,8 @@ package engine import ( + "errors" + "fmt" "log" "strconv" ) @@ -19,6 +21,7 @@ type SupportedTypes int const ( IntType = iota StringType + NullType ) // IntegerValue - Implementation of ValueInterface that is containing integer values @@ -31,13 +34,33 @@ type StringValue struct { Value string } +// NullValue - Implementation of ValueInterface that is containing null +type NullValue struct { +} + +// HandleValue - Function to take an instance of ValueInterface and cast to a specific implementation +func CastValueInterface(v ValueInterface) { + switch value := v.(type) { + case IntegerValue: + fmt.Printf("IntegerValue with Value: %d\n", value.Value) + case StringValue: + fmt.Printf("StringValue with Value: %s\n", value.Value) + case NullValue: + fmt.Println("NullValue (no value)") + default: + fmt.Println("Unknown type") + } +} + // ToString implementations func (value IntegerValue) ToString() string { return strconv.Itoa(value.Value) } func (value StringValue) ToString() string { return value.Value } +func (value NullValue) ToString() string { return "NULL" } // GetType implementations func (value IntegerValue) GetType() SupportedTypes { return IntType } func (value StringValue) GetType() SupportedTypes { return StringType } +func (value NullValue) GetType() SupportedTypes { return NullType } // IsEqual implementations func (value IntegerValue) IsEqual(valueInterface ValueInterface) bool { @@ -46,47 +69,110 @@ func (value IntegerValue) IsEqual(valueInterface ValueInterface) bool { func (value StringValue) IsEqual(valueInterface ValueInterface) bool { return areEqual(value, valueInterface) } +func (value NullValue) IsEqual(valueInterface ValueInterface) bool { + return areEqual(value, valueInterface) +} // isSmallerThan implementations -func (firstValue IntegerValue) isSmallerThan(secondValue ValueInterface) bool { - secondValueAsInteger, isInteger := secondValue.(IntegerValue) +func (value IntegerValue) isSmallerThan(secondValue ValueInterface) bool { + nullValue, isNull := secondValue.(NullValue) + if isNull { + return nullValue.isGreaterThan(value) + } + secondValueAsInteger, isInteger := secondValue.(IntegerValue) if !isInteger { log.Fatal("Can't compare Integer with other type") } - return firstValue.Value < secondValueAsInteger.Value + return value.Value < secondValueAsInteger.Value } -func (firstValue StringValue) isSmallerThan(secondValue ValueInterface) bool { - secondValueAsString, isString := secondValue.(StringValue) +func (value StringValue) isSmallerThan(secondValue ValueInterface) bool { + nullValue, isNull := secondValue.(NullValue) + if isNull { + return nullValue.isGreaterThan(value) + } + + secondValueAsString, isString := secondValue.(StringValue) if !isString { log.Fatal("Can't compare String with other type") } - return firstValue.Value < secondValueAsString.Value + return value.Value < secondValueAsString.Value +} + +func (value NullValue) isSmallerThan(secondValue ValueInterface) bool { + _, isNull := secondValue.(NullValue) + + if isNull { + return false + } + + return true } // isGreaterThan implementations -func (firstValue IntegerValue) isGreaterThan(secondValue ValueInterface) bool { - secondValueAsInteger, isInteger := secondValue.(IntegerValue) +func (value IntegerValue) isGreaterThan(secondValue ValueInterface) bool { + nullValue, isNull := secondValue.(NullValue) + if isNull { + return nullValue.isSmallerThan(value) + } + secondValueAsInteger, isInteger := secondValue.(IntegerValue) if !isInteger { log.Fatal("Can't compare Integer with other type") } - return firstValue.Value > secondValueAsInteger.Value + return value.Value > secondValueAsInteger.Value } -func (firstValue StringValue) isGreaterThan(secondValue ValueInterface) bool { - secondValueAsString, isString := secondValue.(StringValue) +func (value StringValue) isGreaterThan(secondValue ValueInterface) bool { + nullValue, isNull := secondValue.(NullValue) + if isNull { + return nullValue.isSmallerThan(value) + } + secondValueAsString, isString := secondValue.(StringValue) if !isString { log.Fatal("Can't compare String with other type") } - return firstValue.Value > secondValueAsString.Value + return value.Value > secondValueAsString.Value +} + +func (value NullValue) isGreaterThan(_ ValueInterface) bool { + return false } func areEqual(first ValueInterface, second ValueInterface) bool { return first.GetType() == second.GetType() && first.ToString() == second.ToString() } + +func getMin(values []ValueInterface) (ValueInterface, error) { + if len(values) == 0 { + return nil, errors.New("can't extract min from empty array") + } + minValue := values[0] + + for _, value := range values[1:] { + if value.isSmallerThan(minValue) { + minValue = value + } + } + return minValue, nil +} + +func getMax(values []ValueInterface) (ValueInterface, error) { + if len(values) == 0 { + return nil, errors.New("can't extract max from empty array") + } + + maxValue := values[0] + for _, value := range values[1:] { + if value.isGreaterThan(maxValue) { + maxValue = value + } + } + + return maxValue, nil +} diff --git a/engine/generic_value_test.go b/engine/generic_value_test.go index d4815e5..5b53198 100644 --- a/engine/generic_value_test.go +++ b/engine/generic_value_test.go @@ -6,7 +6,7 @@ import ( func TestIsGreaterThan(t *testing.T) { oneInt := IntegerValue{ - Value: 1, + Value: 0, } twoInt := IntegerValue{ Value: 2, @@ -17,13 +17,15 @@ func TestIsGreaterThan(t *testing.T) { twoString := StringValue{ Value: "aab", } + oneNull := NullValue{} + twoNull := NullValue{} if oneInt.isGreaterThan(twoInt) { - t.Errorf("1 shouldn't be greater than 2") + t.Errorf("0 shouldn't be greater than 2") } if !twoInt.isGreaterThan(oneInt) { - t.Errorf("1 shouldn't be greater than 2") + t.Errorf("0 shouldn't be greater than 2") } if oneString.isGreaterThan(twoString) { @@ -33,11 +35,31 @@ func TestIsGreaterThan(t *testing.T) { if !twoString.isGreaterThan(oneString) { t.Errorf("1 shouldn't be greater than 2") } + + if twoNull.isGreaterThan(oneNull) { + t.Errorf("null is not greater than null") + } + + if !oneInt.isGreaterThan(oneNull) { + t.Errorf("Any Int value cannot be smaller than null") + } + + if !oneString.isGreaterThan(oneNull) { + t.Errorf("Any String value cannot be smaller than null") + } + + if oneNull.isGreaterThan(oneInt) { + t.Errorf("Null cannot be greater than any int value") + } + + if oneNull.isGreaterThan(oneString) { + t.Errorf("Null cannot be greater than any string value") + } } func TestIsSmallerThan(t *testing.T) { oneInt := IntegerValue{ - Value: 1, + Value: 0, } twoInt := IntegerValue{ Value: 2, @@ -48,13 +70,15 @@ func TestIsSmallerThan(t *testing.T) { twoString := StringValue{ Value: "aab", } + oneNull := NullValue{} + twoNull := NullValue{} if !oneInt.isSmallerThan(twoInt) { - t.Errorf("1 should be smaller than 2") + t.Errorf("0 should be smaller than 2") } if twoInt.isSmallerThan(oneInt) { - t.Errorf("1 should be smaller than 2") + t.Errorf("0 should be smaller than 2") } if !oneString.isSmallerThan(twoString) { @@ -64,10 +88,29 @@ func TestIsSmallerThan(t *testing.T) { if twoString.isSmallerThan(oneString) { t.Errorf("1 should be smaller than 2") } + + if twoNull.isSmallerThan(oneNull) { + t.Errorf("null is not smaller than null") + } + + if oneInt.isSmallerThan(oneNull) { + t.Errorf("Any int value cannot be smaller than null") + } + + if oneString.isSmallerThan(oneNull) { + t.Errorf("Any string value cannot be smaller than null") + } + + if !oneNull.isSmallerThan(oneInt) { + t.Errorf("Null cannot be greater than any int value") + } + + if !oneNull.isSmallerThan(oneString) { + t.Errorf("Null cannot be greater than any string value") + } } func TestEquals(t *testing.T) { - oneInt := IntegerValue{ Value: 1, } @@ -80,12 +123,18 @@ func TestEquals(t *testing.T) { twoString := StringValue{ Value: "two", } + oneNull := NullValue{} + twoNull := NullValue{} shouldBeEqual(t, oneInt, oneInt) shouldBeEqual(t, oneString, oneString) + shouldBeEqual(t, oneNull, twoNull) shouldNotBeEqual(t, oneInt, twoInt) shouldNotBeEqual(t, oneString, twoString) shouldNotBeEqual(t, oneString, oneInt) + shouldNotBeEqual(t, oneNull, oneInt) + shouldNotBeEqual(t, oneNull, oneString) + shouldNotBeEqual(t, twoInt, twoNull) } func shouldBeEqual(t *testing.T, valueOne ValueInterface, valueTwo ValueInterface) { diff --git a/engine/row.go b/engine/row.go new file mode 100644 index 0000000..6160956 --- /dev/null +++ b/engine/row.go @@ -0,0 +1,35 @@ +package engine + +// Rows - Contain rows that store values, alternative to Table, some operations are easier +type Rows struct { + rows []map[string]ValueInterface +} + +// MapTableToRows - transform Table struct into Rows +func MapTableToRows(table *Table) Rows { + rows := make([]map[string]ValueInterface, 0) + + numberOfRows := len(table.Columns[0].Values) + + for rowIndex := 0; rowIndex < numberOfRows; rowIndex++ { + row := getRow(table, rowIndex) + rows = append(rows, row) + } + return Rows{rows: rows} +} + +func getRow(table *Table, rowIndex int) map[string]ValueInterface { + row := make(map[string]ValueInterface) + for _, column := range table.Columns { + row[column.Name] = column.Values[rowIndex] + } + return row +} + +func getEmptyRow(table *Table) map[string]ValueInterface { + row := make(map[string]ValueInterface) + for _, column := range table.Columns { + row[column.Name] = NullValue{} + } + return row +} diff --git a/engine/table.go b/engine/table.go index 26f4b5f..ec8561a 100644 --- a/engine/table.go +++ b/engine/table.go @@ -1,6 +1,9 @@ package engine -import "github.com/LissaGreense/GO4SQL/token" +import ( + "github.com/LissaGreense/GO4SQL/token" + "hash/adler32" +) // Table - Contain Columns that store values in engine type Table struct { @@ -12,7 +15,7 @@ func (table *Table) isEqual(secondTable *Table) bool { return false } - for i := 0; i < len(table.Columns); i++ { + for i := range table.Columns { if table.Columns[i].Name != secondTable.Columns[i].Name { return false } @@ -25,7 +28,7 @@ func (table *Table) isEqual(secondTable *Table) bool { if len(table.Columns[i].Values) != len(secondTable.Columns[i].Values) { return false } - for j := 0; j < len(table.Columns[i].Values); j++ { + for j := range table.Columns[i].Values { if table.Columns[i].Values[j].ToString() != secondTable.Columns[i].Values[j].ToString() { return false } @@ -35,6 +38,38 @@ func (table *Table) isEqual(secondTable *Table) bool { return true } +// getDistinctTable - Takes input table, and returns new one without any duplicates +func (table *Table) getDistinctTable() *Table { + distinctTable := getCopyOfTableWithoutRows(table) + + rowsCount := len(table.Columns[0].Values) + + checksumSet := map[uint32]struct{}{} + + for iRow := 0; iRow < rowsCount; iRow++ { + + mergedColumnValues := "" + for iColumn := range table.Columns { + fieldValue := table.Columns[iColumn].Values[iRow].ToString() + if table.Columns[iColumn].Type.Literal == token.TEXT { + fieldValue = "'" + fieldValue + "'" + } + mergedColumnValues += fieldValue + } + checksum := adler32.Checksum([]byte(mergedColumnValues)) + + _, exist := checksumSet[checksum] + if !exist { + checksumSet[checksum] = struct{}{} + for i, column := range distinctTable.Columns { + column.Values = append(column.Values, table.Columns[i].Values[iRow]) + } + } + } + + return distinctTable +} + // ToString - Return string contain all values and Column names in Table func (table *Table) ToString() string { columWidths := getColumWidths(table.Columns) @@ -42,7 +77,7 @@ func (table *Table) ToString() string { result := bar + "\n" result += "|" - for i := 0; i < len(table.Columns); i++ { + for i := range table.Columns { result += " " for j := 0; j < columWidths[i]-len(table.Columns[i].Name); j++ { result += " " @@ -52,16 +87,21 @@ func (table *Table) ToString() string { } result += "\n" + bar + "\n" + if len(table.Columns) == 0 { + return result + } + rowsCount := len(table.Columns[0].Values) for iRow := 0; iRow < rowsCount; iRow++ { result += "|" - for iColumn := 0; iColumn < len(table.Columns); iColumn++ { + for iColumn := range table.Columns { result += " " printedValue := table.Columns[iColumn].Values[iRow].ToString() - if table.Columns[iColumn].Type.Literal == token.TEXT { + if table.Columns[iColumn].Type.Literal == token.TEXT && + table.Columns[iColumn].Values[iRow].GetType() != NullType { printedValue = "'" + printedValue + "'" } for i := 0; i < columWidths[iColumn]-len(printedValue); i++ { @@ -77,6 +117,21 @@ func (table *Table) ToString() string { return result + bar } +func (table *Table) getTableCopyWithAddedPrefixToColumnNames(columnNamePrefix string) *Table { + newTable := &Table{Columns: []*Column{}} + + for _, column := range table.Columns { + newTable.Columns = append(newTable.Columns, + &Column{ + Type: column.Type, + Values: column.Values, + Name: columnNamePrefix + column.Name, + }) + } + + return newTable +} + func getBar(columWidths []int) string { bar := "+" @@ -94,12 +149,12 @@ func getBar(columWidths []int) string { func getColumWidths(columns []*Column) []int { widths := make([]int, 0) - for iColumn := 0; iColumn < len(columns); iColumn++ { + for iColumn := range columns { maxLength := len(columns[iColumn].Name) - for iRow := 0; iRow < len(columns[iColumn].Values); iRow++ { + for iRow := range columns[iColumn].Values { valueLength := len(columns[iColumn].Values[iRow].ToString()) if columns[iColumn].Type.Literal == token.TEXT { - valueLength += 2 // double "'" + valueLength += 2 // double ' } if valueLength > maxLength { maxLength = valueLength diff --git a/go.mod b/go.mod index f35fa57..aed510b 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/LissaGreense/GO4SQL -go 1.16 +go 1.21 diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index 74a1010..0100f29 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -6,7 +6,7 @@ import ( "github.com/LissaGreense/GO4SQL/token" ) -func TestLexer(t *testing.T) { +func TestLexerWithInsertCommand(t *testing.T) { input := ` CREATE TABLE 1tbl( one TEXT , two INT ); @@ -55,24 +55,44 @@ func TestLexer(t *testing.T) { {token.EOF, ""}, } - l := RunLexer(input) - - for i, tt := range tests { - tok := l.NextToken() - - if tok.Type != tt.expectedType { - t.Fatalf("tests[%d] - tokentype wrong. expected=%q, got=%q", - i, tt.expectedType, tok.Type) - } + runLexerTestSuite(t, input, tests) +} - if tok.Literal != tt.expectedLiteral { - t.Fatalf("tests[%d] - literal wrong. expected=%q, got=%q", - i, tt.expectedLiteral, tok.Literal) - } +func TestLexerWithUpdateCommand(t *testing.T) { + input := + ` + UPDATE table1 + SET column_name_1 TO 'UPDATE', column_name_2 TO 42 + WHERE column_name_3 EQUAL 1; + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.UPDATE, "UPDATE"}, + {token.IDENT, "table1"}, + {token.SET, "SET"}, + {token.IDENT, "column_name_1"}, + {token.TO, "TO"}, + {token.APOSTROPHE, "'"}, + {token.IDENT, "UPDATE"}, + {token.APOSTROPHE, "'"}, + {token.COMMA, ","}, + {token.IDENT, "column_name_2"}, + {token.TO, "TO"}, + {token.LITERAL, "42"}, + {token.WHERE, "WHERE"}, + {token.IDENT, "column_name_3"}, + {token.EQUAL, "EQUAL"}, + {token.LITERAL, "1"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, } + + runLexerTestSuite(t, input, tests) } -func TestLexerWithNumbersMixedInLitterals(t *testing.T) { +func TestLexerWithNumbersMixedInLiterals(t *testing.T) { input := ` CREATE TABLE tbl2( one TEXT , two INT ); @@ -121,21 +141,7 @@ func TestLexerWithNumbersMixedInLitterals(t *testing.T) { {token.EOF, ""}, } - l := RunLexer(input) - - for i, tt := range tests { - tok := l.NextToken() - - if tok.Type != tt.expectedType { - t.Fatalf("tests[%d] - tokentype wrong. expected=%q, got=%q", - i, tt.expectedType, tok.Type) - } - - if tok.Literal != tt.expectedLiteral { - t.Fatalf("tests[%d] - literal wrong. expected=%q, got=%q", - i, tt.expectedLiteral, tok.Literal) - } - } + runLexerTestSuite(t, input, tests) } func TestLexerWithNumbersWithWhitespacesIdentifier(t *testing.T) { @@ -187,21 +193,7 @@ func TestLexerWithNumbersWithWhitespacesIdentifier(t *testing.T) { {token.EOF, ""}, } - l := RunLexer(input) - - for i, tt := range tests { - tok := l.NextToken() - - if tok.Type != tt.expectedType { - t.Fatalf("tests[%d] - tokentype wrong. expected=%q, got=%q", - i, tt.expectedType, tok.Type) - } - - if tok.Literal != tt.expectedLiteral { - t.Fatalf("tests[%d] - literal wrong. expected=%q, got=%q", - i, tt.expectedLiteral, tok.Literal) - } - } + runLexerTestSuite(t, input, tests) } func TestLogicalStatements(t *testing.T) { @@ -231,21 +223,45 @@ func TestLogicalStatements(t *testing.T) { {token.EOF, ""}, } - l := RunLexer(input) - - for i, tt := range tests { - tok := l.NextToken() - - if tok.Type != tt.expectedType { - t.Fatalf("tests[%d] - tokentype wrong. expected=%q, got=%q", - i, tt.expectedType, tok.Type) - } + runLexerTestSuite(t, input, tests) +} - if tok.Literal != tt.expectedLiteral { - t.Fatalf("tests[%d] - literal wrong. expected=%q, got=%q", - i, tt.expectedLiteral, tok.Literal) - } +func TestInStatement(t *testing.T) { + input := + ` + WHERE two IN (1, 2) AND + WHERE three NOTIN ('one', 'two'); + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.WHERE, "WHERE"}, + {token.IDENT, "two"}, + {token.IN, "IN"}, + {token.LPAREN, "("}, + {token.LITERAL, "1"}, + {token.COMMA, ","}, + {token.LITERAL, "2"}, + {token.RPAREN, ")"}, + {token.AND, "AND"}, + {token.WHERE, "WHERE"}, + {token.IDENT, "three"}, + {token.NOTIN, "NOTIN"}, + {token.LPAREN, "("}, + {token.APOSTROPHE, "'"}, + {token.IDENT, "one"}, + {token.APOSTROPHE, "'"}, + {token.COMMA, ","}, + {token.APOSTROPHE, "'"}, + {token.IDENT, "two"}, + {token.APOSTROPHE, "'"}, + {token.RPAREN, ")"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, } + + runLexerTestSuite(t, input, tests) } func TestDeleteStatement(t *testing.T) { @@ -267,21 +283,7 @@ func TestDeleteStatement(t *testing.T) { {token.EOF, ""}, } - l := RunLexer(input) - - for i, tt := range tests { - tok := l.NextToken() - - if tok.Type != tt.expectedType { - t.Fatalf("tests[%d] - tokentype wrong. expected=%q, got=%q", - i, tt.expectedType, tok.Type) - } - - if tok.Literal != tt.expectedLiteral { - t.Fatalf("tests[%d] - literal wrong. expected=%q, got=%q", - i, tt.expectedLiteral, tok.Literal) - } - } + runLexerTestSuite(t, input, tests) } func TestOrderByStatement(t *testing.T) { @@ -302,6 +304,261 @@ func TestOrderByStatement(t *testing.T) { {token.EOF, ""}, } + runLexerTestSuite(t, input, tests) +} + +func TestDropStatement(t *testing.T) { + input := `DROP TABLE table;` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.DROP, "DROP"}, + {token.TABLE, "TABLE"}, + {token.IDENT, "table"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestLimitAndOffsetStatement(t *testing.T) { + input := `LIMIT 5 OFFSET 6;` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.LIMIT, "LIMIT"}, + {token.LITERAL, "5"}, + {token.OFFSET, "OFFSET"}, + {token.LITERAL, "6"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestAggregateFunctions(t *testing.T) { + input := `SELECT MIN(colOne), MAX(colOne), COUNT(colOne), SUM(colOne), AVG(colOne) FROM tbl;` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.MIN, "MIN"}, + {token.LPAREN, "("}, + {token.IDENT, "colOne"}, + {token.RPAREN, ")"}, + {token.COMMA, ","}, + {token.MAX, "MAX"}, + {token.LPAREN, "("}, + {token.IDENT, "colOne"}, + {token.RPAREN, ")"}, + {token.COMMA, ","}, + {token.COUNT, "COUNT"}, + {token.LPAREN, "("}, + {token.IDENT, "colOne"}, + {token.RPAREN, ")"}, + {token.COMMA, ","}, + {token.SUM, "SUM"}, + {token.LPAREN, "("}, + {token.IDENT, "colOne"}, + {token.RPAREN, ")"}, + {token.COMMA, ","}, + {token.AVG, "AVG"}, + {token.LPAREN, "("}, + {token.IDENT, "colOne"}, + {token.RPAREN, ")"}, + {token.FROM, "FROM"}, + {token.IDENT, "tbl"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestSelectWithDistinct(t *testing.T) { + input := `SELECT DISTINCT * FROM table;` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.DISTINCT, "DISTINCT"}, + {token.ASTERISK, "*"}, + {token.FROM, "FROM"}, + {token.IDENT, "table"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestDefaultJoin(t *testing.T) { + input := ` SELECT title FROM books + JOIN authors ON + books.author_id EQUAL authors.author_id; + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.IDENT, "title"}, + {token.FROM, "FROM"}, + {token.IDENT, "books"}, + {token.JOIN, "JOIN"}, + {token.IDENT, "authors"}, + {token.ON, "ON"}, + {token.IDENT, "books.author_id"}, + {token.EQUAL, "EQUAL"}, + {token.IDENT, "authors.author_id"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestInnerJoin(t *testing.T) { + input := ` SELECT title FROM books + INNER JOIN authors ON + books.author_id EQUAL authors.author_id; + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.IDENT, "title"}, + {token.FROM, "FROM"}, + {token.IDENT, "books"}, + {token.INNER, "INNER"}, + {token.JOIN, "JOIN"}, + {token.IDENT, "authors"}, + {token.ON, "ON"}, + {token.IDENT, "books.author_id"}, + {token.EQUAL, "EQUAL"}, + {token.IDENT, "authors.author_id"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestLeftJoin(t *testing.T) { + input := ` SELECT title FROM books + LEFT JOIN authors ON + books.author_id EQUAL authors.author_id; + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.IDENT, "title"}, + {token.FROM, "FROM"}, + {token.IDENT, "books"}, + {token.LEFT, "LEFT"}, + {token.JOIN, "JOIN"}, + {token.IDENT, "authors"}, + {token.ON, "ON"}, + {token.IDENT, "books.author_id"}, + {token.EQUAL, "EQUAL"}, + {token.IDENT, "authors.author_id"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestRightJoin(t *testing.T) { + input := ` SELECT title FROM books + RIGHT JOIN authors ON + books.author_id EQUAL authors.author_id; + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.IDENT, "title"}, + {token.FROM, "FROM"}, + {token.IDENT, "books"}, + {token.RIGHT, "RIGHT"}, + {token.JOIN, "JOIN"}, + {token.IDENT, "authors"}, + {token.ON, "ON"}, + {token.IDENT, "books.author_id"}, + {token.EQUAL, "EQUAL"}, + {token.IDENT, "authors.author_id"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestFullJoin(t *testing.T) { + input := ` SELECT title FROM books + FULL JOIN authors ON + books.author_id EQUAL authors.author_id; + ` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.SELECT, "SELECT"}, + {token.IDENT, "title"}, + {token.FROM, "FROM"}, + {token.IDENT, "books"}, + {token.FULL, "FULL"}, + {token.JOIN, "JOIN"}, + {token.IDENT, "authors"}, + {token.ON, "ON"}, + {token.IDENT, "books.author_id"}, + {token.EQUAL, "EQUAL"}, + {token.IDENT, "authors.author_id"}, + {token.SEMICOLON, ";"}, + {token.EOF, ""}, + } + + runLexerTestSuite(t, input, tests) +} + +func TestHandlingNullValues(t *testing.T) { + input := `INSERT INTO tbl VALUES( 'NULL', NULL );` + tests := []struct { + expectedType token.Type + expectedLiteral string + }{ + {token.INSERT, "INSERT"}, + {token.INTO, "INTO"}, + {token.IDENT, "tbl"}, + {token.VALUES, "VALUES"}, + {token.LPAREN, "("}, + {token.APOSTROPHE, "'"}, + {token.IDENT, "NULL"}, + {token.APOSTROPHE, "'"}, + {token.COMMA, ","}, + {token.NULL, "NULL"}, + {token.RPAREN, ")"}, + {token.SEMICOLON, ";"}, + } + + runLexerTestSuite(t, input, tests) +} + +func runLexerTestSuite(t *testing.T, input string, tests []struct { + expectedType token.Type + expectedLiteral string +}) { l := RunLexer(input) for i, tt := range tests { @@ -317,5 +574,4 @@ func TestOrderByStatement(t *testing.T) { i, tt.expectedLiteral, tok.Literal) } } - } diff --git a/main.go b/main.go index d6e6497..93a59dc 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,7 @@ package main import ( "flag" - "github.com/LissaGreense/GO4SQL/ast" + "fmt" "github.com/LissaGreense/GO4SQL/engine" "github.com/LissaGreense/GO4SQL/modes" "log" @@ -16,100 +16,19 @@ func main() { flag.Parse() engineSQL := engine.New() + var err error if len(*filePath) > 0 { - modes.HandleFileMode(*filePath, engineSQL, evaluateInEngine) + err = modes.HandleFileMode(*filePath, engineSQL) } else if *streamMode { - modes.HandleStreamMode(engineSQL, evaluateInEngine) + err = modes.HandleStreamMode(engineSQL) } else if *socketMode { - modes.HandleSocketMode(*port, engineSQL, evaluateInEngine) + modes.HandleSocketMode(*port, engineSQL) } else { - log.Println("No mode has been providing. Exiting.") + err = fmt.Errorf("no mode has been providing, exiting") } -} - -func evaluateInEngine(sequences *ast.Sequence, engineSQL *engine.DbEngine) string { - commands := sequences.Commands - - result := "" - for commandIndex, command := range commands { - - // TODO: Check if those statements are necessary - _, whereCommandIsValid := command.(*ast.WhereCommand) - if whereCommandIsValid { - continue - } - - _, orderByCommandIsValid := command.(*ast.OrderByCommand) - if orderByCommandIsValid { - continue - } - - createCommand, createCommandIsValid := command.(*ast.CreateCommand) - if createCommandIsValid { - engineSQL.CreateTable(createCommand) - result += "Table '" + createCommand.Name.GetToken().Literal + "' has been created\n" - continue - } - - insertCommand, insertCommandIsValid := command.(*ast.InsertCommand) - if insertCommandIsValid { - engineSQL.InsertIntoTable(insertCommand) - result += "Data Inserted\n" - continue - } - - selectCommand, selectCommandIsValid := command.(*ast.SelectCommand) - if selectCommandIsValid { - result += getSelectResponse(commandIndex, commands, engineSQL, selectCommand) + "\n" - continue - } - - deleteCommand, deleteCommandIsValid := command.(*ast.DeleteCommand) - if deleteCommandIsValid { - nextCommandIndex := commandIndex + 1 - - if nextCommandIndex != len(commands) { - whereCommand, whereCommandIsValid := commands[nextCommandIndex].(*ast.WhereCommand) - - if whereCommandIsValid { - engineSQL.DeleteFromTable(deleteCommand, whereCommand) - } - } - result += "Data from '" + deleteCommand.Name.GetToken().Literal + "' has been deleted\n" - continue - } + if err != nil { + log.Fatal(err) } - - return result -} - -func getSelectResponse(commandIndex int, commands []ast.Command, engineSQL *engine.DbEngine, selectCommand *ast.SelectCommand) string { - nextCommandIndex := commandIndex + 1 - - if nextCommandIndex != len(commands) { - whereCommand, whereCommandIsValid := commands[nextCommandIndex].(*ast.WhereCommand) - - // TODO: It cannot be like that. Have to be refactored to tree structure. - if whereCommandIsValid { - if nextCommandIndex+1 < len(commands) { - orderByCommand, orderByCommandIsValid := commands[nextCommandIndex+1].(*ast.OrderByCommand) - - if orderByCommandIsValid { - return engineSQL.SelectFromTableWithWhereAndOrderBy(selectCommand, whereCommand, orderByCommand).ToString() - } - } - - return engineSQL.SelectFromTableWithWhere(selectCommand, whereCommand).ToString() - } - - orderByCommand, orderByCommandIsValid := commands[nextCommandIndex].(*ast.OrderByCommand) - - if orderByCommandIsValid { - return engineSQL.SelectFromTableWithOrderBy(selectCommand, orderByCommand).ToString() - } - } - - return engineSQL.SelectFromTable(selectCommand).ToString() } diff --git a/modes/handler.go b/modes/handler.go index 4e4116f..d787202 100644 --- a/modes/handler.go +++ b/modes/handler.go @@ -7,7 +7,6 @@ import ( "github.com/LissaGreense/GO4SQL/engine" "github.com/LissaGreense/GO4SQL/lexer" "github.com/LissaGreense/GO4SQL/parser" - "io/ioutil" "log" "net" "os" @@ -15,36 +14,49 @@ import ( ) // HandleFileMode - Handle GO4SQL use case where client sends input via text file -func HandleFileMode(filePath string, engine *engine.DbEngine, evaluate func(sequences *ast.Sequence, engineSQL *engine.DbEngine) string) { - content, err := ioutil.ReadFile(filePath) +func HandleFileMode(filePath string, engine *engine.DbEngine) error { + content, err := os.ReadFile(filePath) if err != nil { - log.Fatal(err) + return err } - - sequences := bytesToSequences(content) - fmt.Print(evaluate(sequences, engine)) + sequences, err := bytesToSequences(content) + if err != nil { + return err + } + evaluate, err := engine.Evaluate(sequences) + if err != nil { + return err + } + fmt.Print(evaluate) + return nil } // HandleStreamMode - Handle GO4SQL use case where client sends input via stdin -func HandleStreamMode(engine *engine.DbEngine, evaluate func(sequences *ast.Sequence, engineSQL *engine.DbEngine) string) { +func HandleStreamMode(engine *engine.DbEngine) error { reader := bufio.NewScanner(os.Stdin) for reader.Scan() { - sequences := bytesToSequences(reader.Bytes()) - fmt.Print(evaluate(sequences, engine)) - } - err := reader.Err() - if err != nil { - log.Fatal(err) + sequences, err := bytesToSequences(reader.Bytes()) + if err != nil { + fmt.Print(err) + } else { + evaluate, err := engine.Evaluate(sequences) + if err != nil { + fmt.Print(err) + } else { + fmt.Print(evaluate) + } + } } + return reader.Err() } // HandleSocketMode - Handle GO4SQL use case where client sends input via socket protocol -func HandleSocketMode(port int, engine *engine.DbEngine, evaluate func(sequences *ast.Sequence, engineSQL *engine.DbEngine) string) { +func HandleSocketMode(port int, engine *engine.DbEngine) { listener, err := net.Listen("tcp", "localhost:"+strconv.Itoa(port)) log.Printf("Starting Socket Server on %d port\n", port) if err != nil { - log.Fatal("Error:", err) + log.Fatal(err.Error()) } defer func(listener net.Listener) { @@ -61,19 +73,18 @@ func HandleSocketMode(port int, engine *engine.DbEngine, evaluate func(sequences continue } - go handleSocketClient(conn, engine, evaluate) + go handleSocketClient(conn, engine) } } -func bytesToSequences(content []byte) *ast.Sequence { +func bytesToSequences(content []byte) (*ast.Sequence, error) { lex := lexer.RunLexer(string(content)) parserInstance := parser.New(lex) - sequences := parserInstance.ParseSequence() - - return sequences + sequences, err := parserInstance.ParseSequence() + return sequences, err } -func handleSocketClient(conn net.Conn, engine *engine.DbEngine, evaluate func(sequences *ast.Sequence, engineSQL *engine.DbEngine) string) { +func handleSocketClient(conn net.Conn, engine *engine.DbEngine) { defer func(conn net.Conn) { err := conn.Close() if err != nil { @@ -86,19 +97,26 @@ func handleSocketClient(conn net.Conn, engine *engine.DbEngine, evaluate func(se for { n, err := conn.Read(buffer) if err != nil && err.Error() != "EOF" { - log.Fatal("Error:", err) + log.Fatal(err.Error()) + } + sequences, err := bytesToSequences(buffer) + + if err != nil { + log.Fatal(err.Error()) } - sequences := bytesToSequences(buffer) - commandResult := evaluate(sequences, engine) - if len(commandResult) > 0 { + commandResult, err := engine.Evaluate(sequences) + + if err != nil { + _, err = conn.Write([]byte(err.Error())) + } else if len(commandResult) > 0 { _, err = conn.Write([]byte(commandResult)) } if err != nil { - log.Fatal("Error:", err) + log.Fatal(err.Error()) } - fmt.Printf("Received: %s\n", buffer[:n]) + log.Printf("Received: %s\n", buffer[:n]) } } diff --git a/parser/errors.go b/parser/errors.go new file mode 100644 index 0000000..8bc04fd --- /dev/null +++ b/parser/errors.go @@ -0,0 +1,127 @@ +package parser + +// SyntaxError - error thrown when parser was expecting different token from lexer +type SyntaxError struct { + expecting []string + got string +} + +func (m *SyntaxError) Error() string { + var expectingText string + + if len(m.expecting) == 1 { + expectingText = m.expecting[0] + } else { + for i, expected := range m.expecting { + expectingText += expected + if i != len(m.expecting)-1 { + expectingText += ", " + } + } + } + + return "syntax error, expecting: {" + expectingText + "}, got: {" + m.got + "}" +} + +// SyntaxCommandExpectedError - error thrown when there was command that logically should only +// appear after certain different command, but it wasn't found +type SyntaxCommandExpectedError struct { + command string + neededCommands []string +} + +func (m *SyntaxCommandExpectedError) Error() string { + var neededCommandsText string + + if len(neededCommandsText) == 1 { + neededCommandsText = m.neededCommands[0] + " command" + } else if len(neededCommandsText) == 2 { + neededCommandsText = m.neededCommands[0] + " or " + m.neededCommands[1] + " commands" + } else { + for i, command := range m.neededCommands { + if i == len(m.neededCommands)-1 { + neededCommandsText += " or " + } + + neededCommandsText += command + + if i != len(m.neededCommands)-1 || i != len(m.neededCommands)-2 { + neededCommandsText += ", " + } + } + neededCommandsText += " commands" + } + + return "syntax error, {" + m.command + "} command needs {" + neededCommandsText + "} before" +} + +// SyntaxInvalidCommandError - error thrown when invalid (non-existing) type of command has been +// found +type SyntaxInvalidCommandError struct { + invalidCommand string +} + +func (m *SyntaxInvalidCommandError) Error() string { + return "syntax error, invalid command found: {" + m.invalidCommand + "}" +} + +// LogicalExpressionParsingError - error thrown when logical expression inside WHERE statement +// couldn't be parsed correctly +type LogicalExpressionParsingError struct { + afterToken *string +} + +func (m *LogicalExpressionParsingError) Error() string { + errorMsg := "syntax error, logical expression within WHERE command couldn't be parsed correctly" + if m.afterToken != nil { + return errorMsg + ", after {" + *m.afterToken + "} character" + } + return errorMsg +} + +// ArithmeticLessThanZeroParserError - error thrown when parser found integer value that shouldn't +// be less than 0, but it is +type ArithmeticLessThanZeroParserError struct { + variable string +} + +func (m *ArithmeticLessThanZeroParserError) Error() string { + return "syntax error, {" + m.variable + "} value should be more than 0" +} + +// NoPredecessorParserError - error thrown when parser found integer value that shouldn't +// be less than 0, but it is +type NoPredecessorParserError struct { + command string +} + +func (m *NoPredecessorParserError) Error() string { + return "syntax error, {" + m.command + "} command can't be used without predecessor" +} + +// IllegalPeriodInIdentParserError - error thrown when parser found period in ident when parsing create command +type IllegalPeriodInIdentParserError struct { + name string +} + +func (m *IllegalPeriodInIdentParserError) Error() string { + return "syntax error, {" + m.name + "} shouldn't contain '.'" +} + +// NoApostropheOnRightParserError - error thrown when parser found no apostrophe on right of ident +type NoApostropheOnRightParserError struct { + ident string +} + +func (m *NoApostropheOnRightParserError) Error() string { + return "syntax error, Identifier: {" + m.ident + "} has no apostrophe on right" +} + +// NoApostropheOnLeftParserError - error thrown when parser found no apostrophe on left of ident +type NoApostropheOnLeftParserError struct { + ident string +} + +func (m *NoApostropheOnLeftParserError) Error() string { + return "syntax error, Identifier: {" + m.ident + "} has no apostrophe on left" +} diff --git a/parser/parser.go b/parser/parser.go index dc8d42d..3cc7e0e 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -1,27 +1,29 @@ package parser import ( - "log" - "github.com/LissaGreense/GO4SQL/ast" "github.com/LissaGreense/GO4SQL/lexer" "github.com/LissaGreense/GO4SQL/token" + "strconv" + "strings" ) // Parser - Contain token that is currently analyzed by parser and the next one. Lexer is used to tokenize the client // text input. type Parser struct { - lexer *lexer.Lexer + lexer lexer.Lexer currentToken token.Token peekToken token.Token } // New - Return new Parser struct func New(lexer *lexer.Lexer) *Parser { - p := &Parser{lexer: lexer} + p := &Parser{lexer: *lexer} + // Read two tokens, so curToken and peekToken are both set p.nextToken() p.nextToken() + return p } @@ -32,24 +34,24 @@ func (parser *Parser) nextToken() { } // validateTokenAndSkip - Check if current token type is appearing in provided expectedTokens array then move to the next token -func validateTokenAndSkip(parser *Parser, expectedTokens []token.Type) { - validateToken(parser.currentToken.Type, expectedTokens) +func validateTokenAndSkip(parser *Parser, expectedTokens []token.Type) error { + err := validateToken(parser.currentToken.Type, expectedTokens) + + if err != nil { + return err + } // Ignore validated token parser.nextToken() + return nil } // validateToken - Check if current token type is appearing in provided expectedTokens array -func validateToken(tokenType token.Type, expectedTokens []token.Type) { +func validateToken(tokenType token.Type, expectedTokens []token.Type) error { var contains = false - var tokensPrintMessage = "" - for i, x := range expectedTokens { - - if i == 0 { - tokensPrintMessage += string(x) - } else { - tokensPrintMessage += ", or: " + string(x) - } + expectedTokensStrings := make([]string, 0) + for _, x := range expectedTokens { + expectedTokensStrings = append(expectedTokensStrings, string(x)) if x == tokenType { contains = true @@ -57,34 +59,56 @@ func validateToken(tokenType token.Type, expectedTokens []token.Type) { } } if !contains { - log.Fatal("Syntax error, expecting: ", tokensPrintMessage, ", got: ", tokenType) + return &SyntaxError{expectedTokensStrings, string(tokenType)} } + return nil } // parseCreateCommand - Return ast.CreateCommand created from tokens and validate the syntax // // Example of input parsable to the ast.CreateCommand: // create table tbl( one TEXT , two INT ); -func (parser *Parser) parseCreateCommand() ast.Command { // TODO make it return the pointer +func (parser *Parser) parseCreateCommand() (ast.Command, error) { // token.CREATE already at current position in parser createCommand := &ast.CreateCommand{Token: parser.currentToken} // Skip token.CREATE parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.TABLE}) + err := validateTokenAndSkip(parser, []token.Type{token.TABLE}) + if err != nil { + return nil, err + } + + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) - createCommand.Name = &ast.Identifier{Token: parser.currentToken} + if strings.Contains(parser.currentToken.Literal, ".") { + return nil, &IllegalPeriodInIdentParserError{name: parser.currentToken.Literal} + } + createCommand.Name = ast.Identifier{Token: parser.currentToken} // Skip token.IDENT parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + err = validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + if err != nil { + return nil, err + } // Begin of inside Paren for parser.currentToken.Type == token.IDENT { - validateToken(parser.peekToken.Type, []token.Type{token.TEXT, token.INT}) + err = validateToken(parser.peekToken.Type, []token.Type{token.TEXT, token.INT}) + if err != nil { + return nil, err + } + + if strings.Contains(parser.currentToken.Literal, ".") { + return nil, &IllegalPeriodInIdentParserError{name: parser.currentToken.Literal} + } + createCommand.ColumnNames = append(createCommand.ColumnNames, parser.currentToken.Literal) createCommand.ColumnTypes = append(createCommand.ColumnTypes, parser.peekToken) @@ -102,16 +126,24 @@ func (parser *Parser) parseCreateCommand() ast.Command { // TODO make it return } // End of inside Paren - validateTokenAndSkip(parser, []token.Type{token.RPAREN}) - validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) + err = validateTokenAndSkip(parser, []token.Type{token.RPAREN}) + if err != nil { + return nil, err + } + err = validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) + if err != nil { + return nil, err + } - return createCommand + return createCommand, nil } -func (parser *Parser) skipIfCurrentTokenIsApostrophe() { +func (parser *Parser) skipIfCurrentTokenIsApostrophe() bool { if parser.currentToken.Type == token.APOSTROPHE { parser.nextToken() + return true } + return false } func (parser *Parser) skipIfCurrentTokenIsSemicolon() { @@ -124,32 +156,54 @@ func (parser *Parser) skipIfCurrentTokenIsSemicolon() { // // Example of input parsable to the ast.InsertCommand: // insert into tbl values( 'hello', 10 ); -func (parser *Parser) parseInsertCommand() ast.Command { +func (parser *Parser) parseInsertCommand() (ast.Command, error) { // token.INSERT already at current position in parser insertCommand := &ast.InsertCommand{Token: parser.currentToken} // Ignore token.INSERT parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.INTO}) + err := validateTokenAndSkip(parser, []token.Type{token.INTO}) + if err != nil { + return nil, err + } - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) - insertCommand.Name = &ast.Identifier{Token: parser.currentToken} + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + insertCommand.Name = ast.Identifier{Token: parser.currentToken} // Ignore token.INDENT parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.VALUES}) - validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + err = validateTokenAndSkip(parser, []token.Type{token.VALUES}) + if err != nil { + return nil, err + } + err = validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + if err != nil { + return nil, err + } - for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.APOSTROPHE { - // TODO: Add apostrophe validation - parser.skipIfCurrentTokenIsApostrophe() + for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { + startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() - validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL}) - insertCommand.Values = append(insertCommand.Values, parser.currentToken) - // Ignore token.IDENT or token.LITERAL + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL, token.NULL}) + if err != nil { + return nil, err + } + value := parser.currentToken + insertCommand.Values = append(insertCommand.Values, value) + // Ignore token.IDENT, token.LITERAL or token.NULL parser.nextToken() - parser.skipIfCurrentTokenIsApostrophe() + + finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + + err = validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, value) + + if err != nil { + return nil, err + } if parser.currentToken.Type != token.COMMA { break @@ -159,32 +213,88 @@ func (parser *Parser) parseInsertCommand() ast.Command { parser.nextToken() } - validateTokenAndSkip(parser, []token.Type{token.RPAREN}) - validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) - return insertCommand + err = validateTokenAndSkip(parser, []token.Type{token.RPAREN}) + if err != nil { + return nil, err + } + + err = validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) + if err != nil { + return nil, err + } + + return insertCommand, nil +} + +func validateApostropheWrapping(startedWithApostrophe bool, finishedWithApostrophe bool, value token.Token) error { + if startedWithApostrophe && !finishedWithApostrophe { + return &NoApostropheOnRightParserError{ident: value.Literal} + } else if !startedWithApostrophe && finishedWithApostrophe { + return &NoApostropheOnLeftParserError{ident: value.Literal} + } + return nil } // parseSelectCommand - Return ast.SelectCommand created from tokens and validate the syntax // // Example of input parsable to the ast.SelectCommand: // SELECT col1, col2, col3 FROM tbl; -func (parser *Parser) parseSelectCommand() ast.Command { +func (parser *Parser) parseSelectCommand() (ast.Command, error) { // token.SELECT already at current position in parser selectCommand := &ast.SelectCommand{Token: parser.currentToken} // Ignore token.SELECT parser.nextToken() - if parser.currentToken.Type == token.ASTERISK { - selectCommand.Space = append(selectCommand.Space, parser.currentToken) + // optional DISTINCT + if parser.currentToken.Type == token.DISTINCT { + selectCommand.HasDistinct = true + + // Ignore token.DISTINCT parser.nextToken() + } + + err := validateToken(parser.currentToken.Type, []token.Type{token.ASTERISK, token.IDENT, token.MAX, token.MIN, token.SUM, token.AVG, token.COUNT}) + if err != nil { + return nil, err + } + if parser.currentToken.Type == token.ASTERISK { + selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken}) + parser.nextToken() } else { - for parser.currentToken.Type == token.IDENT { - // Get column name - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) - selectCommand.Space = append(selectCommand.Space, parser.currentToken) - parser.nextToken() + for parser.currentToken.Type == token.IDENT || isAggregateFunction(parser.currentToken.Type) { + if parser.currentToken.Type != token.IDENT { + aggregateFunction := parser.currentToken + parser.nextToken() + err := validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + if err != nil { + return nil, err + } + if aggregateFunction.Type == token.COUNT { + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.ASTERISK}) + } else { + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + } + if err != nil { + return nil, err + } + selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken, AggregateFunc: &aggregateFunction}) + parser.nextToken() + + err = validateTokenAndSkip(parser, []token.Type{token.RPAREN}) + if err != nil { + return nil, err + } + } else { + // Get column name + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken}) + parser.nextToken() + } if parser.currentToken.Type != token.COMMA { break @@ -194,97 +304,177 @@ func (parser *Parser) parseSelectCommand() ast.Command { } } - validateTokenAndSkip(parser, []token.Type{token.FROM}) + err = validateTokenAndSkip(parser, []token.Type{token.FROM}) + if err != nil { + return nil, err + } - selectCommand.Name = &ast.Identifier{Token: parser.currentToken} - // Ignore token.INDENT + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + + selectCommand.Name = ast.Identifier{Token: parser.currentToken} + // Ignore token.IDENT parser.nextToken() - // expect SEMICOLON or WHERE - validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.WHERE, token.ORDER}) + // expect SEMICOLON or other keywords expected in SELECT statement + err = validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.WHERE, token.ORDER, token.LIMIT, token.OFFSET, token.JOIN, token.LEFT, token.RIGHT, token.INNER, token.FULL}) + if err != nil { + return nil, err + } if parser.currentToken.Type == token.SEMICOLON { parser.nextToken() } - return selectCommand + return selectCommand, nil +} + +func (parser *Parser) getColumnName(err error, selectCommand *ast.SelectCommand, aggregateFunction token.Token) error { + // Get column name + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.ASTERISK}) + if err != nil { + return err + } + selectCommand.Space = append(selectCommand.Space, ast.Space{ColumnName: parser.currentToken, AggregateFunc: &aggregateFunction}) + parser.nextToken() + return nil +} + +func isAggregateFunction(t token.Type) bool { + return t == token.MIN || t == token.MAX || t == token.COUNT || t == token.SUM || t == token.AVG } // parseWhereCommand - Return ast.WhereCommand created from tokens and validate the syntax // // Example of input parsable to the ast.WhereCommand: // WHERE colName EQUAL 'potato' -func (parser *Parser) parseWhereCommand() ast.Command { +func (parser *Parser) parseWhereCommand() (ast.Command, error) { // token.WHERE already at current position in parser whereCommand := &ast.WhereCommand{Token: parser.currentToken} expressionIsValid := false // Ignore token.WHERE parser.nextToken() - - expressionIsValid, whereCommand.Expression = parser.getExpression() + var err error + expressionIsValid, whereCommand.Expression, err = parser.getExpression() + if err != nil { + return nil, err + } if !expressionIsValid { - log.Fatal("Expression withing Where statement couldn't be parsed correctly") + return nil, &LogicalExpressionParsingError{} } - validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.ORDER}) + err = validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.ORDER}) + if err != nil { + return nil, err + } parser.skipIfCurrentTokenIsSemicolon() - return whereCommand + return whereCommand, nil } // parseDeleteCommand - Return ast.DeleteCommand created from tokens and validate the syntax // // Example of input parsable to the ast.DeleteCommand: -// DELETE FROM table -func (parser *Parser) parseDeleteCommand() ast.Command { +// DELETE FROM table; +func (parser *Parser) parseDeleteCommand() (ast.Command, error) { // token.DELETE already at current position in parser deleteCommand := &ast.DeleteCommand{Token: parser.currentToken} // token.DELETE no longer needed parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.FROM}) + err := validateTokenAndSkip(parser, []token.Type{token.FROM}) + if err != nil { + return nil, err + } - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) - deleteCommand.Name = &ast.Identifier{Token: parser.currentToken} + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + deleteCommand.Name = ast.Identifier{Token: parser.currentToken} // token.IDENT no longer needed parser.nextToken() // expect WHERE - validateToken(parser.currentToken.Type, []token.Type{token.WHERE}) + err = validateToken(parser.currentToken.Type, []token.Type{token.WHERE}) - return deleteCommand + return deleteCommand, err +} + +// parseDropCommand - Return ast.DropCommand created from tokens and validate the syntax +// +// Example of input parsable to the ast.DropCommand: +// DROP TABLE table; +func (parser *Parser) parseDropCommand() (ast.Command, error) { + // token.DROP already at current position in parser + dropCommand := &ast.DropCommand{Token: parser.currentToken} + + // token.DROP no longer needed + parser.nextToken() + + err := validateTokenAndSkip(parser, []token.Type{token.TABLE}) + if err != nil { + return nil, err + } + + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + dropCommand.Name = ast.Identifier{Token: parser.currentToken} + + // token.IDENT no longer needed + parser.nextToken() + + err = validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) + + return dropCommand, err } // parseOrderByCommand - Return ast.OrderByCommand created from tokens and validate the syntax // // Example of input parsable to the ast.OrderByCommand: // ORDER BY colName ASC -func (parser *Parser) parseOrderByCommand() ast.Command { +func (parser *Parser) parseOrderByCommand() (ast.Command, error) { // token.ORDER already at current position in parser orderCommand := &ast.OrderByCommand{Token: parser.currentToken} // token.ORDER no longer needed parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.BY}) + err := validateTokenAndSkip(parser, []token.Type{token.BY}) + if err != nil { + return nil, err + } // ensure that loop below will execute at least once - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } // array of SortPattern for parser.currentToken.Type == token.IDENT { // Get column name - validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } columnName := parser.currentToken parser.nextToken() // Get ASC or DESC - validateToken(parser.currentToken.Type, []token.Type{token.ASC, token.DESC}) + err = validateToken(parser.currentToken.Type, []token.Type{token.ASC, token.DESC}) + if err != nil { + return nil, err + } order := parser.currentToken parser.nextToken() @@ -298,186 +488,534 @@ func (parser *Parser) parseOrderByCommand() ast.Command { parser.nextToken() } - validateTokenAndSkip(parser, []token.Type{token.SEMICOLON}) + parser.skipIfCurrentTokenIsSemicolon() - return orderCommand + return orderCommand, nil } -// getExpression - Return proper structure of ast.Expression and validate the syntax +// parseLimitCommand - Return ast.LimitCommand created from tokens and validate the syntax // -// Available expressions: -// - ast.OperationExpression -// - ast.BooleanExpression -// - ast.ConditionExpression -func (parser *Parser) getExpression() (bool, ast.Expression) { - booleanExpressionExists, booleanExpression := parser.getBooleanExpression() +// Example of input parsable to the ast.LimitCommand: +// LIMIT 10 +func (parser *Parser) parseLimitCommand() (ast.Command, error) { + // token.LIMIT already at current position in parser + limitCommand := &ast.LimitCommand{Token: parser.currentToken} - conditionalExpressionExists, conditionalExpression := parser.getConditionalExpression() - - operationExpressionExists, operationExpression := parser.getOperationExpression(booleanExpressionExists, conditionalExpressionExists, booleanExpression, conditionalExpression) + // token.LIMIT no longer needed + parser.nextToken() - if operationExpressionExists { - return true, operationExpression + err := validateToken(parser.currentToken.Type, []token.Type{token.LITERAL}) + if err != nil { + return nil, err } - if conditionalExpressionExists { - return true, conditionalExpression + // convert count number to int + count, err := strconv.Atoi(parser.currentToken.Literal) + if err != nil { + return nil, err } - if booleanExpressionExists { - return true, booleanExpression + if count < 0 { + return nil, &ArithmeticLessThanZeroParserError{variable: "limit"} } - return false, nil + limitCommand.Count = count + + // Skip token.IDENT + parser.nextToken() + + parser.skipIfCurrentTokenIsSemicolon() + + return limitCommand, nil } -// getOperationExpression - Return ast.OperationExpression created from tokens and validate the syntax -func (parser *Parser) getOperationExpression(booleanExpressionExists bool, conditionalExpressionExists bool, booleanExpression *ast.BooleanExpression, conditionalExpression *ast.ConditionExpression) (bool, *ast.OperationExpression) { - operationExpression := &ast.OperationExpression{} +// parseOffsetCommand - Return ast.OffsetCommand created from tokens and validate the syntax +// +// Example of input parsable to the ast.LimitCommand: +// OFFSET 10 +func (parser *Parser) parseOffsetCommand() (ast.Command, error) { + // token.OFFSET already at current position in parser + offsetCommand := &ast.OffsetCommand{Token: parser.currentToken} + + // token.OFFSET no longer needed + parser.nextToken() + + err := validateToken(parser.currentToken.Type, []token.Type{token.LITERAL}) + if err != nil { + return nil, err + } + + count, err := strconv.Atoi(parser.currentToken.Literal) + if err != nil { + return nil, err + } + if count < 0 { + return nil, &ArithmeticLessThanZeroParserError{variable: "offset"} + } + + offsetCommand.Count = count + + // Skip token.IDENT + parser.nextToken() + + parser.skipIfCurrentTokenIsSemicolon() - if (booleanExpressionExists || conditionalExpressionExists) && (parser.currentToken.Type == token.OR || parser.currentToken.Type == token.AND) { - if booleanExpressionExists { - operationExpression.Left = booleanExpression + return offsetCommand, nil +} + +// parseJoinCommand - Return ast.JoinCommand created from tokens and validate the syntax +// +// Example of input parsable to the ast.JoinCommand: +// JOIN table on table.one EQUAL table2.one; +func (parser *Parser) parseJoinCommand() (ast.Command, error) { + // parser has either token.JOIN, token.LEFT, token.RIGHT, token.INNER or token.FULL + var joinCommand *ast.JoinCommand + + if parser.currentToken.Type == token.JOIN { + joinCommand = &ast.JoinCommand{Token: parser.currentToken} + joinCommand.JoinType = token.Token{Type: token.INNER, Literal: token.INNER} + } else { + joinTypeTokenType := parser.currentToken + parser.nextToken() + err := validateToken(parser.currentToken.Type, []token.Type{token.JOIN}) + if err != nil { + return nil, err } + joinCommand = &ast.JoinCommand{Token: parser.currentToken} + joinCommand.JoinType = joinTypeTokenType + } + + // token.JOIN no longer needed + parser.nextToken() + + err := validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + + joinCommand.Name = ast.Identifier{Token: parser.currentToken} + parser.nextToken() + + err = validateTokenAndSkip(parser, []token.Type{token.ON}) + if err != nil { + return nil, err + } + + var expressionIsValid bool + expressionIsValid, joinCommand.Expression, err = parser.getExpression() + if err != nil { + return nil, err + } + + if !expressionIsValid { + return nil, &LogicalExpressionParsingError{} + } + + parser.skipIfCurrentTokenIsSemicolon() + + return joinCommand, nil +} + +// parseUpdateCommand - Return ast.parseUpdateCommand created from tokens and validate the syntax +// +// Example of input parsable to the ast.parseUpdateCommand: +// UPDATE table SET col1 TO 'value' WHERE col2 EQUAL 10; +func (parser *Parser) parseUpdateCommand() (ast.Command, error) { + // token.UPDATE already at current position in parser + updateCommand := &ast.UpdateCommand{Token: parser.currentToken} + + // Ignore token.UPDATE + parser.nextToken() + + // Get table name + err := validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + updateCommand.Name = ast.Identifier{Token: parser.currentToken} + + // Ignore token.IDENT + parser.nextToken() - if conditionalExpressionExists { - operationExpression.Left = conditionalExpression + err = validateTokenAndSkip(parser, []token.Type{token.SET}) + if err != nil { + return nil, err + } + + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err + } + + updateCommand.Changes = make(map[token.Token]ast.Anonymitifier) + for parser.currentToken.Type == token.IDENT { + // Get column name + err := validateToken(parser.currentToken.Type, []token.Type{token.IDENT}) + if err != nil { + return nil, err } + colKey := parser.currentToken - operationExpression.Operation = parser.currentToken + // skip column name parser.nextToken() - expressionIsValid, expression := parser.getExpression() + err = validateToken(parser.currentToken.Type, []token.Type{token.TO}) + if err != nil { + return nil, err + } + // skip token.TO + parser.nextToken() - if !expressionIsValid { - log.Fatal("Couldn't parse right side of the OperationExpression after ", operationExpression.Operation.Literal, " token.") + startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL, token.NULL}) + if err != nil { + return nil, err } + updateCommand.Changes[colKey] = ast.Anonymitifier{Token: parser.currentToken} - operationExpression.Right = expression + // skip token.IDENT, token.LITERAL or token.NULL + parser.nextToken() + finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() - return true, operationExpression - } + err = validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, updateCommand.Changes[colKey].GetToken()) - return false, operationExpression -} + if err != nil { + return nil, err + } -// getBooleanExpression - Return ast.BooleanExpression created from tokens and validate the syntax -func (parser *Parser) getBooleanExpression() (bool, *ast.BooleanExpression) { - booleanExpression := &ast.BooleanExpression{} - isValid := false + if parser.currentToken.Type != token.COMMA { + break + } - if parser.currentToken.Type == token.TRUE || parser.currentToken.Type == token.FALSE { - booleanExpression.Boolean = parser.currentToken + // Skip token.COMMA parser.nextToken() - isValid = true } - return isValid, booleanExpression + err = validateToken(parser.currentToken.Type, []token.Type{token.SEMICOLON, token.WHERE}) + if err != nil { + return nil, err + } + parser.skipIfCurrentTokenIsSemicolon() + return updateCommand, nil } -// getConditionalExpression - Return ast.ConditionExpression created from tokens and validate the syntax -func (parser *Parser) getConditionalExpression() (bool, *ast.ConditionExpression) { - // TODO REFACTOR THIS - conditionalExpression := &ast.ConditionExpression{} +// getExpression - Return proper structure of ast.Expression and validate the syntax +// +// Available expressions: +// - ast.OperationExpression +// - ast.BooleanExpression +// - ast.ConditionExpression +// - ast.ContainExpression +func (parser *Parser) getExpression() (bool, ast.Expression, error) { + + if parser.currentToken.Type == token.IDENT || + parser.currentToken.Type == token.LITERAL || + parser.currentToken.Type == token.NULL || + parser.currentToken.Type == token.APOSTROPHE || + parser.currentToken.Type == token.TRUE || + parser.currentToken.Type == token.FALSE { + + leftSide, isAnonymitifier, err := parser.getExpressionLeftSideValue() + if err != nil { + return false, nil, err + } - if parser.currentToken.Type == token.IDENT { - conditionalExpression.Left = ast.Identifier{ - Token: parser.currentToken, + isValidExpression := false + var expression ast.Expression + + if parser.currentToken.Type == token.EQUAL || parser.currentToken.Type == token.NOT { + isValidExpression, expression, err = parser.getConditionalExpression(leftSide, isAnonymitifier) + } else if parser.currentToken.Type == token.IN || parser.currentToken.Type == token.NOTIN { + isValidExpression, expression, err = parser.getContainExpression(leftSide, isAnonymitifier) + } else if leftSide.Type == token.TRUE || leftSide.Type == token.FALSE { + expression = &ast.BooleanExpression{Boolean: leftSide} + isValidExpression = true + err = nil } - parser.nextToken() - } else if parser.currentToken.Type == token.APOSTROPHE { - parser.skipIfCurrentTokenIsApostrophe() + if err != nil { + return false, nil, err + } - conditionalExpression.Left = ast.Anonymitifier{ - Token: parser.currentToken, + if (parser.currentToken.Type == token.AND || parser.currentToken.Type == token.OR) && isValidExpression { + isValidExpression, expression, err = parser.getOperationExpression(expression) } - parser.nextToken() + if err != nil { + return false, nil, err + } - validateTokenAndSkip(parser, []token.Type{token.APOSTROPHE}) - } else if parser.currentToken.Type == token.LITERAL { - conditionalExpression.Left = ast.Anonymitifier{ - Token: parser.currentToken, + if isValidExpression { + return true, expression, nil } + } + return false, nil, nil +} + +func (parser *Parser) getExpressionLeftSideValue() (token.Token, bool, error) { + var leftSide token.Token + isAnonymitifier := false + startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + + if startedWithApostrophe { + isAnonymitifier = true + value := "" + for parser.currentToken.Type != token.EOF && parser.currentToken.Type != token.APOSTROPHE { + value += parser.currentToken.Literal + parser.nextToken() + } + + leftSide = token.Token{Type: token.IDENT, Literal: value} + + finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + + err := validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, leftSide) + + if err != nil { + return token.Token{}, isAnonymitifier, err + } + } else { + leftSide = parser.currentToken parser.nextToken() + } + return leftSide, isAnonymitifier, nil +} +// getOperationExpression - Return ast.OperationExpression created from tokens and validate the syntax +func (parser *Parser) getOperationExpression(expression ast.Expression) (bool, *ast.OperationExpression, error) { + operationExpression := &ast.OperationExpression{} + operationExpression.Left = expression + + operationExpression.Operation = parser.currentToken + parser.nextToken() + + expressionIsValid, expression, err := parser.getExpression() + + if err != nil { + return false, nil, err + } + + if !expressionIsValid { + return false, nil, &LogicalExpressionParsingError{afterToken: &operationExpression.Operation.Literal} + } + + operationExpression.Right = expression + + return true, operationExpression, nil +} + +// getConditionalExpression - Return ast.ConditionExpression created from tokens and validate the syntax +func (parser *Parser) getConditionalExpression(leftSide token.Token, isAnonymitifier bool) (bool, *ast.ConditionExpression, error) { + conditionalExpression := &ast.ConditionExpression{Condition: parser.currentToken} + + if isAnonymitifier { + conditionalExpression.Left = ast.Anonymitifier{Token: leftSide} } else { - return false, conditionalExpression + conditionalExpression.Left = ast.Identifier{Token: leftSide} } - validateToken(parser.currentToken.Type, []token.Type{token.EQUAL, token.NOT}) - conditionalExpression.Condition = parser.currentToken + // skip EQUAL or NOT parser.nextToken() - if parser.currentToken.Type == token.IDENT { - conditionalExpression.Right = ast.Identifier{ - Token: parser.currentToken, + if parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || + parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { + startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + + if !startedWithApostrophe && parser.currentToken.Type == token.IDENT { + conditionalExpression.Right = ast.Identifier{Token: parser.currentToken} + } else { + conditionalExpression.Right = ast.Anonymitifier{Token: parser.currentToken} } parser.nextToken() - } else if parser.currentToken.Type == token.APOSTROPHE { - parser.skipIfCurrentTokenIsApostrophe() - - conditionalExpression.Right = ast.Anonymitifier{ - Token: parser.currentToken, + finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + err := validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, conditionalExpression.Right.GetToken()) + if err != nil { + return false, nil, err } + } else { + return false, conditionalExpression, &SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL, token.NULL}, got: parser.currentToken.Literal} + } + + return true, conditionalExpression, nil +} + +// getContainExpression - Return ast.ContainExpression created from tokens and validate the syntax +func (parser *Parser) getContainExpression(leftSide token.Token, isAnonymitifier bool) (bool, *ast.ContainExpression, error) { + containExpression := &ast.ContainExpression{} + + if isAnonymitifier { + return false, nil, &SyntaxError{expecting: []string{token.IDENT}, got: "'" + leftSide.Literal + "'"} + } + containExpression.Left = ast.Identifier{Token: leftSide} + + if parser.currentToken.Type == token.IN { + containExpression.Contains = true + } else { + containExpression.Contains = false + } + + // skip IN or NOTIN + parser.nextToken() + + err := validateTokenAndSkip(parser, []token.Type{token.LPAREN}) + if err != nil { + return false, nil, err + } + + for parser.currentToken.Type == token.IDENT || parser.currentToken.Type == token.LITERAL || parser.currentToken.Type == token.NULL || parser.currentToken.Type == token.APOSTROPHE { + startedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + + err = validateToken(parser.currentToken.Type, []token.Type{token.IDENT, token.LITERAL, token.NULL}) + if err != nil { + return false, nil, err + } + currentAnonymitifier := ast.Anonymitifier{Token: parser.currentToken} + containExpression.Right = append(containExpression.Right, currentAnonymitifier) + // Ignore token.IDENT, token.LITERAL or token.NULL parser.nextToken() - validateTokenAndSkip(parser, []token.Type{token.APOSTROPHE}) + finishedWithApostrophe := parser.skipIfCurrentTokenIsApostrophe() + + err = validateApostropheWrapping(startedWithApostrophe, finishedWithApostrophe, currentAnonymitifier.GetToken()) + if err != nil { + return false, nil, err + } - } else if parser.currentToken.Type == token.LITERAL { - conditionalExpression.Right = ast.Anonymitifier{ - Token: parser.currentToken, + if parser.currentToken.Type != token.COMMA { + if parser.currentToken.Type != token.RPAREN { + return false, nil, &SyntaxError{expecting: []string{token.COMMA, token.RPAREN}, got: string(parser.currentToken.Type)} + } + break } + + // Ignore token.COMMA parser.nextToken() + } - } else { - log.Fatal("Syntax error, expecting: ", token.APOSTROPHE, ",", token.IDENT, ",", token.LITERAL, ", got: ", parser.currentToken.Literal) + err = validateTokenAndSkip(parser, []token.Type{token.RPAREN}) + if err != nil { + return false, nil, err } - return true, conditionalExpression + return true, containExpression, err } // ParseSequence - Return ast.Sequence (sequence of commands) created from client input after tokenization // // Parse tokens returned by lexer to structures defines in ast package, and it's responsible for syntax validation. -func (parser *Parser) ParseSequence() *ast.Sequence { +func (parser *Parser) ParseSequence() (*ast.Sequence, error) { // Create variable holding sequence/commands sequence := &ast.Sequence{} for parser.currentToken.Type != token.EOF { var command ast.Command + var err error switch parser.currentToken.Type { case token.CREATE: - command = parser.parseCreateCommand() + command, err = parser.parseCreateCommand() case token.INSERT: - command = parser.parseInsertCommand() + command, err = parser.parseInsertCommand() + case token.UPDATE: + command, err = parser.parseUpdateCommand() case token.SELECT: - command = parser.parseSelectCommand() + command, err = parser.parseSelectCommand() case token.DELETE: - command = parser.parseDeleteCommand() + command, err = parser.parseDeleteCommand() + case token.DROP: + command, err = parser.parseDropCommand() case token.WHERE: - if len(sequence.Commands) == 0 { - log.Fatal("Syntax error, Where Command can't be used without predecessor") + lastCommand, parserError := parser.getLastCommand(sequence, token.WHERE) + if parserError != nil { + return nil, parserError } - lastStartingToken := sequence.Commands[len(sequence.Commands)-1].TokenLiteral() - if lastStartingToken != token.SELECT && lastStartingToken != token.DELETE { - log.Fatal("Syntax error, WHERE command needs SELECT or DELETE command before") + + if lastCommand.TokenLiteral() == token.SELECT { + newCommand, err := parser.parseWhereCommand() + if err != nil { + return nil, err + } + lastCommand.(*ast.SelectCommand).WhereCommand = newCommand.(*ast.WhereCommand) + } else if lastCommand.TokenLiteral() == token.DELETE { + newCommand, err := parser.parseWhereCommand() + if err != nil { + return nil, err + } + lastCommand.(*ast.DeleteCommand).WhereCommand = newCommand.(*ast.WhereCommand) + } else if lastCommand.TokenLiteral() == token.UPDATE { + newCommand, err := parser.parseWhereCommand() + if err != nil { + return nil, err + } + lastCommand.(*ast.UpdateCommand).WhereCommand = newCommand.(*ast.WhereCommand) + } else { + return nil, &SyntaxCommandExpectedError{command: "WHERE", neededCommands: []string{"SELECT", "DELETE", "UPDATE"}} } - command = parser.parseWhereCommand() case token.ORDER: - if len(sequence.Commands) == 0 { - log.Fatal("Syntax error, Order Command can't be used without predecessor") + lastCommand, parserError := parser.getLastCommand(sequence, token.ORDER) + if parserError != nil { + return nil, parserError + } + + if lastCommand.TokenLiteral() != token.SELECT { + return nil, &SyntaxCommandExpectedError{command: "ORDER BY", neededCommands: []string{"SELECT"}} } - lastStartingToken := sequence.Commands[len(sequence.Commands)-1].TokenLiteral() - if lastStartingToken != token.SELECT && lastStartingToken != token.WHERE { - log.Fatal("Syntax error, WHERE command needs SELECT or WHERE command before") + + selectCommand := lastCommand.(*ast.SelectCommand) + newCommand, err := parser.parseOrderByCommand() + if err != nil { + return nil, err + } + selectCommand.OrderByCommand = newCommand.(*ast.OrderByCommand) + case token.LIMIT: + lastCommand, parserError := parser.getLastCommand(sequence, token.LIMIT) + if parserError != nil { + return nil, parserError + } + if lastCommand.TokenLiteral() != token.SELECT { + return nil, &SyntaxCommandExpectedError{command: "LIMIT", neededCommands: []string{"SELECT"}} + } + selectCommand := lastCommand.(*ast.SelectCommand) + newCommand, err := parser.parseLimitCommand() + if err != nil { + return nil, err + } + selectCommand.LimitCommand = newCommand.(*ast.LimitCommand) + case token.OFFSET: + lastCommand, parserError := parser.getLastCommand(sequence, token.OFFSET) + if parserError != nil { + return nil, parserError + } + if lastCommand.TokenLiteral() != token.SELECT { + return nil, &SyntaxCommandExpectedError{command: "OFFSET", neededCommands: []string{"SELECT"}} + } + selectCommand := lastCommand.(*ast.SelectCommand) + newCommand, err := parser.parseOffsetCommand() + if err != nil { + return nil, err + } + selectCommand.OffsetCommand = newCommand.(*ast.OffsetCommand) + case token.JOIN, token.LEFT, token.RIGHT, token.INNER, token.FULL: + lastCommand, parserError := parser.getLastCommand(sequence, token.JOIN) + if parserError != nil { + return nil, parserError } - command = parser.parseOrderByCommand() + if lastCommand.TokenLiteral() != token.SELECT { + return nil, &SyntaxCommandExpectedError{command: "JOIN", neededCommands: []string{"SELECT"}} + } + selectCommand := lastCommand.(*ast.SelectCommand) + newCommand, err := parser.parseJoinCommand() + if err != nil { + return nil, err + } + selectCommand.JoinCommand = newCommand.(*ast.JoinCommand) default: - log.Fatal("Syntax error, invalid command found: ", parser.currentToken.Type) + return nil, &SyntaxInvalidCommandError{invalidCommand: parser.currentToken.Literal} + } + + if err != nil { + return nil, err } // Add command to the list of parsed commands @@ -486,5 +1024,13 @@ func (parser *Parser) ParseSequence() *ast.Sequence { } } - return sequence + return sequence, nil +} + +func (parser *Parser) getLastCommand(sequence *ast.Sequence, currentToken string) (ast.Command, error) { + if len(sequence.Commands) == 0 { + return nil, &NoPredecessorParserError{command: currentToken} + } + lastCommand := sequence.Commands[len(sequence.Commands)-1] + return lastCommand, nil } diff --git a/parser/parser_error_handling_test.go b/parser/parser_error_handling_test.go new file mode 100644 index 0000000..a823169 --- /dev/null +++ b/parser/parser_error_handling_test.go @@ -0,0 +1,267 @@ +package parser + +import ( + "github.com/LissaGreense/GO4SQL/lexer" + "github.com/LissaGreense/GO4SQL/token" + "testing" +) + +type errorHandlingTestSuite struct { + input string + expectedError string +} + +func TestParseCreateCommandErrorHandling(t *testing.T) { + noTableKeyword := SyntaxError{[]string{token.TABLE}, token.IDENT} + noTableName := SyntaxError{[]string{token.IDENT}, token.LPAREN} + noLeftParen := SyntaxError{[]string{token.LPAREN}, token.IDENT} + noRightParen := SyntaxError{[]string{token.RPAREN}, token.SEMICOLON} + noColumnName := SyntaxError{[]string{token.RPAREN}, token.TEXT} + noColumnType := SyntaxError{[]string{token.TEXT, token.INT}, token.COMMA} + noSemicolon := SyntaxError{[]string{token.SEMICOLON}, ""} + + tests := []errorHandlingTestSuite{ + {"CREATE tbl(one TEXT);", noTableKeyword.Error()}, + {"CREATE TABLE (one TEXT);", noTableName.Error()}, + {"CREATE TABLE tbl one TEXT);", noLeftParen.Error()}, + {"CREATE TABLE tbl (one TEXT;", noRightParen.Error()}, + {"CREATE TABLE tbl (TEXT, two INT);", noColumnName.Error()}, + {"CREATE TABLE tbl (one , two INT);", noColumnType.Error()}, + {"CREATE TABLE tbl (one TEXT, two INT)", noSemicolon.Error()}, + } + + runParserErrorHandlingSuite(t, tests) + +} + +func TestParseDropCommandErrorHandling(t *testing.T) { + missingTableKeywordError := SyntaxError{expecting: []string{token.TABLE}, got: token.IDENT} + missingDropKeywordError := SyntaxInvalidCommandError{token.TABLE} + missingSemicolonError := &SyntaxError{expecting: []string{token.SEMICOLON}, got: ""} + invalidIdentError := &SyntaxError{expecting: []string{token.IDENT}, got: token.LITERAL} + tests := []errorHandlingTestSuite{ + {input: "DROP table;", expectedError: missingTableKeywordError.Error()}, + {input: "TABLE table;", expectedError: missingDropKeywordError.Error()}, + {input: "DROP TABLE table", expectedError: missingSemicolonError.Error()}, + {input: "DROP TABLE 2;", expectedError: invalidIdentError.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func TestParseInsertCommandErrorHandling(t *testing.T) { + noIntoKeyword := SyntaxError{[]string{token.INTO}, token.IDENT} + noTableName := SyntaxError{[]string{token.IDENT}, token.VALUES} + noLeftParen := SyntaxError{[]string{token.LPAREN}, token.APOSTROPHE} + noValue := SyntaxError{[]string{token.IDENT, token.LITERAL, token.NULL}, token.APOSTROPHE} + noRightParen := SyntaxError{[]string{token.RPAREN}, token.SEMICOLON} + noSemicolon := SyntaxError{[]string{token.SEMICOLON}, ""} + noLeftApostrophe := NoApostropheOnLeftParserError{ident: "hello"} + noRightApostrophe := NoApostropheOnRightParserError{ident: "hello, 10)"} + + tests := []errorHandlingTestSuite{ + {"INSERT tbl VALUES( 'hello', 10);", noIntoKeyword.Error()}, + {"INSERT INTO VALUES( 'hello', 10);", noTableName.Error()}, + {"INSERT INTO tl VALUES 'hello', 10);", noLeftParen.Error()}, + {"INSERT INTO tl VALUES ('', 10);", noValue.Error()}, + {"INSERT INTO tl VALUES ('hello', 10;", noRightParen.Error()}, + {"INSERT INTO tl VALUES ('hello', 10)", noSemicolon.Error()}, + {"INSERT INTO tl VALUES (hello', 10)", noLeftApostrophe.Error()}, + {"INSERT INTO tl VALUES ('hello, 10)", noRightApostrophe.Error()}, + } + + runParserErrorHandlingSuite(t, tests) + +} + +func TestParseUpdateCommandErrorHandling(t *testing.T) { + notableName := SyntaxError{expecting: []string{token.IDENT}, got: token.SEMICOLON} + noSetKeyword := SyntaxError{expecting: []string{token.SET}, got: token.SEMICOLON} + noColumnName := SyntaxError{expecting: []string{token.IDENT}, got: token.LITERAL} + noToKeyword := SyntaxError{expecting: []string{token.TO}, got: token.SEMICOLON} + noSecondIdentOrLiteralForValue := SyntaxError{expecting: []string{token.IDENT, token.LITERAL, token.NULL}, got: token.SEMICOLON} + noCommaBetweenValues := SyntaxError{expecting: []string{token.SEMICOLON, token.WHERE}, got: token.IDENT} + noWhereOrSemicolon := SyntaxError{expecting: []string{token.SEMICOLON, token.WHERE}, got: token.SELECT} + noLeftApostrophe := NoApostropheOnLeftParserError{ident: "new_value_1"} + noRightApostrophe := NoApostropheOnRightParserError{ident: "new_value_1"} + + tests := []errorHandlingTestSuite{ + {"UPDATE;", notableName.Error()}, + {"UPDATE table;", noSetKeyword.Error()}, + {"UPDATE table SET 2;", noColumnName.Error()}, + {"UPDATE table SET column_name_1;", noToKeyword.Error()}, + {"UPDATE table SET column_name_1 TO;", noSecondIdentOrLiteralForValue.Error()}, + {"UPDATE table SET column_name_1 TO 2 column_name_1 TO 3;", noCommaBetweenValues.Error()}, + {"UPDATE table SET column_name_1 TO 'new_value_1' SELECT;", noWhereOrSemicolon.Error()}, + {"UPDATE table SET column_name_1 TO new_value_1'", noLeftApostrophe.Error()}, + {"UPDATE table SET column_name_1 TO 'new_value_1", noRightApostrophe.Error()}, + } + + runParserErrorHandlingSuite(t, tests) + +} + +func TestParseSelectCommandErrorHandling(t *testing.T) { + noFromKeyword := SyntaxError{[]string{token.FROM}, token.IDENT} + noColumns := SyntaxError{[]string{token.ASTERISK, token.IDENT, token.MAX, token.MIN, token.SUM, token.AVG, token.COUNT}, token.FROM} + noTableName := SyntaxError{[]string{token.IDENT}, token.SEMICOLON} + noSemicolon := SyntaxError{[]string{token.SEMICOLON, token.WHERE, token.ORDER, token.LIMIT, token.OFFSET, token.JOIN, token.LEFT, token.RIGHT, token.INNER, token.FULL}, ""} + noAggregateFunctionParenClosure := SyntaxError{[]string{token.RPAREN}, ","} + noAggregateFunctionLeftParen := SyntaxError{[]string{token.LPAREN}, token.IDENT} + noFromAfterAsterisk := SyntaxError{[]string{token.FROM}, ","} + noAsteriskInsideMaxArgument := SyntaxError{[]string{token.IDENT}, "*"} + + tests := []errorHandlingTestSuite{ + {"SELECT column1, column2 tbl;", noFromKeyword.Error()}, + {"SELECT FROM table;", noColumns.Error()}, + {"SELECT column1, column2 FROM ;", noTableName.Error()}, + {"SELECT column1, column2 FROM table", noSemicolon.Error()}, + {"SELECT SUM(column1, column2 FROM table", noAggregateFunctionParenClosure.Error()}, + {"SELECT SUM column1 FROM table", noAggregateFunctionLeftParen.Error()}, + {"SELECT *, colName FROM table", noFromAfterAsterisk.Error()}, + {"SELECT MAX(*) FROM table", noAsteriskInsideMaxArgument.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func TestParseWhereCommandErrorHandling(t *testing.T) { + selectCommandPrefix := "SELECT * FROM tbl " + noPredecessorError := NoPredecessorParserError{command: token.WHERE} + noColName := LogicalExpressionParsingError{} + noLeftAphostrophe := LogicalExpressionParsingError{} + noOperatorInsideWhereStatementException := LogicalExpressionParsingError{} + valueIsMissing := SyntaxError{expecting: []string{token.APOSTROPHE, token.IDENT, token.LITERAL, token.NULL}, got: token.SEMICOLON} + tokenAnd := token.AND + conjunctionIsMissing := SyntaxError{expecting: []string{token.SEMICOLON, token.ORDER}, got: token.IDENT} + nextLogicalExpressionIsMissing := LogicalExpressionParsingError{afterToken: &tokenAnd} + noSemicolon := SyntaxError{expecting: []string{token.SEMICOLON, token.ORDER}, got: ""} + noLeftParGotSemicolon := SyntaxError{expecting: []string{token.LPAREN}, got: ";"} + noLeftParGotNumber := SyntaxError{expecting: []string{token.LPAREN}, got: token.LITERAL} + noComma := SyntaxError{expecting: []string{token.COMMA, token.RPAREN}, got: token.LITERAL} + anonymitifierInContains := SyntaxError{expecting: []string{token.IDENT}, got: "'one'"} + noInKeywordException := LogicalExpressionParsingError{} + noLeftApostropheGoodbye := NoApostropheOnLeftParserError{ident: "goodbye"} + noLeftApostropheFive := NoApostropheOnLeftParserError{ident: "5"} + noRightApostropheGoodbye := NoApostropheOnRightParserError{ident: "goodbye"} + noRightApostropheGoodbyeBigger := NoApostropheOnRightParserError{ident: "goodbye EQUAL two"} + noRightApostropheFive := NoApostropheOnRightParserError{ident: "5"} + + tests := []errorHandlingTestSuite{ + {"WHERE col1 NOT 'goodbye' OR col2 EQUAL 3;", noPredecessorError.Error()}, + {selectCommandPrefix + "WHERE NOT 'goodbye' OR column2 EQUAL 3;", noColName.Error()}, + {selectCommandPrefix + "WHERE one 'goodbye';", noOperatorInsideWhereStatementException.Error()}, + {selectCommandPrefix + "WHERE one EQUAL;", valueIsMissing.Error()}, + {selectCommandPrefix + "WHERE one EQUAL 5 two NOT 1;", conjunctionIsMissing.Error()}, + {selectCommandPrefix + "WHERE one EQUAL 5 AND;", nextLogicalExpressionIsMissing.Error()}, + {selectCommandPrefix + "WHERE one EQUAL 5 AND two NOT 5", noSemicolon.Error()}, + {selectCommandPrefix + "WHERE one IN ;", noLeftParGotSemicolon.Error()}, + {selectCommandPrefix + "WHERE one IN 5;", noLeftParGotNumber.Error()}, + {selectCommandPrefix + "WHERE one IN (5 6);", noComma.Error()}, + {selectCommandPrefix + "WHERE one IN ('5", noRightApostropheFive.Error()}, + {selectCommandPrefix + "WHERE one IN (5');", noLeftApostropheFive.Error()}, + {selectCommandPrefix + "WHERE 'one' IN (5);", anonymitifierInContains.Error()}, + {selectCommandPrefix + "WHERE one (5, 6);", noInKeywordException.Error()}, + {selectCommandPrefix + "WHERE one EQUAL goodbye';", noLeftApostropheGoodbye.Error()}, + {selectCommandPrefix + "WHERE one EQUAL 'goodbye", noRightApostropheGoodbye.Error()}, + {selectCommandPrefix + "WHERE 'goodbye EQUAL two", noRightApostropheGoodbyeBigger.Error()}, + {selectCommandPrefix + "WHERE goodbye' EQUAL two", noLeftAphostrophe.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func TestParseOrderByCommandErrorHandling(t *testing.T) { + selectCommandPrefix := "SELECT * FROM tbl " + noPredecessorError := NoPredecessorParserError{command: token.ORDER} + noAscDescError := SyntaxError{expecting: []string{token.ASC, token.DESC}, got: token.SEMICOLON} + noByKeywordError := SyntaxError{expecting: []string{token.BY}, got: token.IDENT} + noIdentKeywordError := SyntaxError{expecting: []string{token.IDENT}, got: token.ASC} + + tests := []errorHandlingTestSuite{ + {"ORDER BY column1;", noPredecessorError.Error()}, + {selectCommandPrefix + "ORDER BY column1;", noAscDescError.Error()}, + {selectCommandPrefix + "ORDER column1 ASC;", noByKeywordError.Error()}, + {selectCommandPrefix + "ORDER BY ASC;", noIdentKeywordError.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func TestParseLimitCommandErrorHandling(t *testing.T) { + selectCommandPrefix := "SELECT * FROM tbl " + noPredecessorError := NoPredecessorParserError{command: token.LIMIT} + noLiteralError := SyntaxError{expecting: []string{token.LITERAL}, got: token.SEMICOLON} + lessThanZeroError := ArithmeticLessThanZeroParserError{variable: "limit"} + + tests := []errorHandlingTestSuite{ + {"LIMIT 5;", noPredecessorError.Error()}, + {selectCommandPrefix + "LIMIT;", noLiteralError.Error()}, + {selectCommandPrefix + "LIMIT -10;", lessThanZeroError.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func TestParseOffsetCommandErrorHandling(t *testing.T) { + selectCommandPrefix := "SELECT * FROM tbl " + noPredecessorError := NoPredecessorParserError{command: token.OFFSET} + noLiteralError := SyntaxError{expecting: []string{token.LITERAL}, got: token.IDENT} + lessThanZeroError := ArithmeticLessThanZeroParserError{variable: "offset"} + + tests := []errorHandlingTestSuite{ + {"OFFSET 5;", noPredecessorError.Error()}, + {selectCommandPrefix + "OFFSET hi;", noLiteralError.Error()}, + {selectCommandPrefix + "OFFSET -10;", lessThanZeroError.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func TestParseDeleteCommandErrorHandling(t *testing.T) { + noFromKeyword := SyntaxError{[]string{token.FROM}, token.IDENT} + noTableName := SyntaxError{[]string{token.IDENT}, token.WHERE} + noWhereCommand := SyntaxError{[]string{token.WHERE}, ";"} + + tests := []errorHandlingTestSuite{ + {"DELETE table WHERE TRUE", noFromKeyword.Error()}, + {"DELETE FROM WHERE TRUE;", noTableName.Error()}, + {"DELETE FROM table;", noWhereCommand.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func TestPeriodInIdentWhileCreatingTableErrorHandling(t *testing.T) { + illegalPeriodInTableName := IllegalPeriodInIdentParserError{"tab.le"} + illegalPeriodInColumnName := IllegalPeriodInIdentParserError{"col.umn"} + + tests := []errorHandlingTestSuite{ + {"CREATE TABLE tab.le( one TEXT , two INT);", illegalPeriodInTableName.Error()}, + {"CREATE TABLE table1( col.umn TEXT , two INT);", illegalPeriodInColumnName.Error()}, + } + + runParserErrorHandlingSuite(t, tests) +} + +func runParserErrorHandlingSuite(t *testing.T, suite []errorHandlingTestSuite) { + for i, test := range suite { + errorMsg := getErrorMessage(t, test.input, i) + + if errorMsg != test.expectedError { + t.Fatalf("[%v]Was expecting error: \n\t{%s},\n\tbut it was:\n\t{%s}", i, test.expectedError, errorMsg) + } + } +} + +func getErrorMessage(t *testing.T, input string, testIndex int) string { + lexerInstance := lexer.RunLexer(input) + parserInstance := New(lexerInstance) + _, err := parserInstance.ParseSequence() + + if err == nil { + t.Fatalf("[%v]Was expecting error from parser but there was none", testIndex) + } + + return err.Error() +} diff --git a/parser/parser_test.go b/parser/parser_test.go index 0e4a1e1..0823da4 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -20,13 +20,16 @@ func TestParserCreateCommand(t *testing.T) { {"CREATE TABLE TBL( );", "TBL", []string{}, []token.Token{}}, } - for _, tt := range tests { + for testIndex, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("[%d] Got error from parser: %s", testIndex, err) + } if len(sequences.Commands) != 1 { - t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) } if !testCreateStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedColumnNames, tt.expectedColumTypes) { @@ -74,15 +77,19 @@ func TestParseInsertCommand(t *testing.T) { {"INSERT INTO TBL VALUES();", "TBL", []token.Token{}}, {"INSERT INTO TBL VALUES( 'HELLO' );", "TBL", []token.Token{{Type: token.IDENT, Literal: "HELLO"}}}, {"INSERT INTO TBL VALUES( 'HELLO', 10 , 'LOL');", "TBL", []token.Token{{Type: token.IDENT, Literal: "HELLO"}, {Type: token.LITERAL, Literal: "10"}, {Type: token.IDENT, Literal: "LOL"}}}, + {"INSERT INTO TBL VALUES(NULL, 'NULL', null);", "TBL", []token.Token{{Type: token.NULL, Literal: "NULL"}, {Type: token.IDENT, Literal: "NULL"}, {Type: token.IDENT, Literal: "null"}}}, } - for _, tt := range tests { + for testIndex, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("[%d] Got error from parser: %s", testIndex, err) + } if len(sequences.Commands) != 1 { - t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) } if !testInsertStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedValuesTokens) { @@ -120,23 +127,27 @@ func TestParseSelectCommand(t *testing.T) { tests := []struct { input string expectedTableName string - expectedColumns []token.Token + expectedSpaces []ast.Space + expectedDistinct bool }{ - {"SELECT * FROM TBL;", "TBL", []token.Token{{Type: token.ASTERISK, Literal: "*"}}}, - {"SELECT ONE, TWO, THREE FROM TBL;", "TBL", []token.Token{{Type: token.IDENT, Literal: "ONE"}, {Type: token.IDENT, Literal: "TWO"}, {Type: token.IDENT, Literal: "THREE"}}}, - {"SELECT FROM TBL;", "TBL", []token.Token{}}, + {"SELECT * FROM TBL;", "TBL", []ast.Space{{ColumnName: token.Token{Type: token.ASTERISK, Literal: "*"}}}, false}, + {"SELECT ONE, TWO, THREE FROM TBL;", "TBL", []ast.Space{{ColumnName: token.Token{Type: token.IDENT, Literal: "ONE"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "TWO"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "THREE"}}}, false}, + {"SELECT DISTINCT * FROM TBL;", "TBL", []ast.Space{{ColumnName: token.Token{Type: token.ASTERISK, Literal: "*"}}}, true}, } - for _, tt := range tests { + for testIndex, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("[%d] Got error from parser: %s", testIndex, err) + } if len(sequences.Commands) != 1 { - t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) } - if !testSelectStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedColumns) { + if !testSelectStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedSpaces, tt.expectedDistinct) { return } } @@ -155,6 +166,30 @@ func TestParseWhereCommand(t *testing.T) { Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, } + thirdExpression := ast.ContainExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName3"}}, + Right: []ast.Anonymitifier{ + {Token: token.Token{Type: token.LITERAL, Literal: "1"}}, + {Token: token.Token{Type: token.LITERAL, Literal: "2"}}, + }, + Contains: true, + } + + fourthExpression := ast.ContainExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName4"}}, + Right: []ast.Anonymitifier{ + {Token: token.Token{Type: token.IDENT, Literal: "one"}}, + {Token: token.Token{Type: token.IDENT, Literal: "two"}}, + }, + Contains: false, + } + + fifthExpression := ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName5"}}, + Right: ast.Anonymitifier{Token: token.Token{Type: token.NULL, Literal: "NULL"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + } + tests := []struct { input string expectedExpression ast.Expression @@ -167,18 +202,42 @@ func TestParseWhereCommand(t *testing.T) { input: "SELECT * FROM TBL WHERE colName2 EQUAL 6462389;", expectedExpression: secondExpression, }, + { + input: "SELECT * FROM TBL WHERE colName3 IN (1, 2);", + expectedExpression: thirdExpression, + }, + { + input: "SELECT * FROM TBL WHERE colName4 NOTIN ('one', 'two');", + expectedExpression: fourthExpression, + }, + { + input: "SELECT * FROM TBL WHERE colName5 EQUAL NULL;", + expectedExpression: fifthExpression, + }, + { + input: "SELECT * FROM TBL WHERE colName5 EQUAL NULL;", + expectedExpression: fifthExpression, + }, } - for _, tt := range tests { + for testIndex, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("[%d] Got error from parser: %s", testIndex, err) + } - if len(sequences.Commands) != 2 { - t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + if len(sequences.Commands) != 1 { + t.Fatalf("[%d] sequences does not contain 1 statements, got=%d", testIndex, len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + if !selectCommand.HasWhereCommand() { + t.Fatalf("[%d] sequences does not contain where command", testIndex) } - if !whereStatementIsValid(t, sequences.Commands[1], tt.expectedExpression) { + if !whereStatementIsValid(t, selectCommand.WhereCommand, tt.expectedExpression) { return } } @@ -188,7 +247,7 @@ func TestParseDeleteCommand(t *testing.T) { input := "DELETE FROM colName1 WHERE colName2 EQUAL 6462389;" expectedDeleteCommand := ast.DeleteCommand{ Token: token.Token{Type: token.DELETE, Literal: "DELETE"}, - Name: &ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName1"}}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName1"}}, } expectedWhereCommand := ast.ConditionExpression{ Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName2"}}, @@ -198,10 +257,13 @@ func TestParseDeleteCommand(t *testing.T) { lexer := lexer.RunLexer(input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } - if len(sequences.Commands) != 2 { - t.Fatalf("sequences does not contain 2 statements. got=%d", len(sequences.Commands)) + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) } actualDeleteCommand, ok := sequences.Commands[0].(*ast.DeleteCommand) @@ -217,11 +279,47 @@ func TestParseDeleteCommand(t *testing.T) { t.Errorf("Table name of DeleteCommand is not %s. got=%s", expectedDeleteCommand.Name.GetToken().Literal, actualDeleteCommand.Name.GetToken().Literal) } - if !whereStatementIsValid(t, sequences.Commands[1], expectedWhereCommand) { + if !actualDeleteCommand.HasWhereCommand() { + t.Fatalf("sequences does not contain where command") + } + + if !whereStatementIsValid(t, actualDeleteCommand.WhereCommand, expectedWhereCommand) { return } } +func TestParseDropCommand(t *testing.T) { + input := "DROP TABLE table;" + expectedDropCommand := ast.DropCommand{ + Token: token.Token{Type: token.DROP, Literal: "DROP"}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "table"}}, + } + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + actualDropCommand, ok := sequences.Commands[0].(*ast.DropCommand) + if !ok { + t.Errorf("actualDropCommand is not %T. got=%T", &ast.DropCommand{}, sequences.Commands[0]) + } + + if expectedDropCommand.TokenLiteral() != actualDropCommand.TokenLiteral() { + t.Errorf("TokenLiteral of DropCommand is not %s. got=%s", expectedDropCommand.TokenLiteral(), actualDropCommand.TokenLiteral()) + } + + if expectedDropCommand.Name.GetToken().Literal != actualDropCommand.Name.GetToken().Literal { + t.Errorf("Table name of DropCommand is not %s. got=%s", expectedDropCommand.Name.GetToken().Literal, actualDropCommand.Name.GetToken().Literal) + } +} + func TestSelectWithOrderByCommand(t *testing.T) { input := "SELECT * FROM tableName ORDER BY colName1 DESC;" expectedSortPattern := ast.SortPattern{ @@ -233,26 +331,477 @@ func TestSelectWithOrderByCommand(t *testing.T) { SortPatterns: []ast.SortPattern{expectedSortPattern}, } expectedTableName := "tableName" - expectedColumnName := []token.Token{{Type: token.ASTERISK, Literal: "*"}} + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.ASTERISK, Literal: "*"}}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { + return + } + + if !selectCommand.HasOrderByCommand() { + t.Fatalf("sequences does not contain where command") + } + + testOrderByCommands(t, expectedOrderByCommand, selectCommand.OrderByCommand) +} + +func TestSelectWithLimitCommand(t *testing.T) { + input := "SELECT * FROM tableName LIMIT 5;" + expectedLimitCommand := ast.LimitCommand{ + Token: token.Token{Type: token.LIMIT, Literal: "LIMIT"}, + Count: 5, + } + expectedTableName := "tableName" + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.ASTERISK, Literal: "*"}}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { + return + } + + if !selectCommand.HasLimitCommand() { + t.Fatalf("sequences does not contain where command") + } + + testLimitCommands(t, expectedLimitCommand, selectCommand.LimitCommand) +} + +func TestSelectWithOffsetCommand(t *testing.T) { + input := "SELECT * FROM tableName OFFSET 5;" + expectedOffsetCommand := ast.OffsetCommand{ + Token: token.Token{Type: token.OFFSET, Literal: "OFFSET"}, + Count: 5, + } + + expectedTableName := "tableName" + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.ASTERISK, Literal: "*"}}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { + return + } + + if !selectCommand.HasOffsetCommand() { + t.Fatalf("select command should have offset command") + } + testOffsetCommands(t, expectedOffsetCommand, selectCommand.OffsetCommand) +} + +func TestSelectWithLimitAndOffsetCommand(t *testing.T) { + input := "SELECT * FROM tableName ORDER BY colName1 DESC LIMIT 2 OFFSET 13;" + expectedLimitCommand := ast.LimitCommand{ + Token: token.Token{Type: token.LIMIT, Literal: "LIMIT"}, + Count: 2, + } + expectedOffsetCommand := ast.OffsetCommand{ + Token: token.Token{Type: token.OFFSET, Literal: "OFFSET"}, + Count: 13, + } + expectedTableName := "tableName" + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.ASTERISK, Literal: "*"}}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { + return + } + + if !selectCommand.HasLimitCommand() { + t.Fatalf("select command should have limit command") + } + if !selectCommand.HasOffsetCommand() { + t.Fatalf("select command should have offset command") + } + + testLimitCommands(t, expectedLimitCommand, selectCommand.LimitCommand) + testOffsetCommands(t, expectedOffsetCommand, selectCommand.OffsetCommand) +} + +func TestSelectWithDefaultInnerJoinCommand(t *testing.T) { + input := "SELECT tbl.one, tbl2.two FROM tbl JOIN tbl2 ON tbl.one EQUAL tbl2.one;" + expectedJoinCommand := ast.JoinCommand{ + Token: token.Token{Type: token.JOIN, Literal: "JOIN"}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2"}}, + JoinType: token.Token{Type: token.INNER, Literal: "INNER"}, + Expression: ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, + Right: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2.one"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + }, + } + expectedTableName := "tbl" + expectedSpace := []ast.Space{{ColumnName: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "tbl2.two"}}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpace, false) { + return + } + + if !selectCommand.HasJoinCommand() { + t.Fatalf("select command should have join command") + } + + testJoinCommands(t, expectedJoinCommand, *selectCommand.JoinCommand) +} + +func TestSelectWithInnerJoinCommand(t *testing.T) { + input := "SELECT tbl.one, tbl2.two FROM tbl INNER JOIN tbl2 ON tbl.one EQUAL tbl2.one;" + expectedJoinCommand := ast.JoinCommand{ + Token: token.Token{Type: token.JOIN, Literal: "JOIN"}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2"}}, + JoinType: token.Token{Type: token.INNER, Literal: "INNER"}, + Expression: ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, + Right: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2.one"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + }, + } + expectedTableName := "tbl" + expectedSpace := []ast.Space{{ColumnName: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "tbl2.two"}}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpace, false) { + return + } + + if !selectCommand.HasJoinCommand() { + t.Fatalf("select command should have join command") + } + + testJoinCommands(t, expectedJoinCommand, *selectCommand.JoinCommand) +} + +func TestSelectWithLeftJoinCommand(t *testing.T) { + input := "SELECT tbl.one, tbl2.two FROM tbl LEFT JOIN tbl2 ON tbl.one EQUAL tbl2.one;" + expectedJoinCommand := ast.JoinCommand{ + Token: token.Token{Type: token.JOIN, Literal: "JOIN"}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2"}}, + JoinType: token.Token{Type: token.LEFT, Literal: "LEFT"}, + Expression: ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, + Right: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2.one"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + }, + } + expectedTableName := "tbl" + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "tbl2.two"}}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { + return + } + + if !selectCommand.HasJoinCommand() { + t.Fatalf("select command should have join command") + } + + testJoinCommands(t, expectedJoinCommand, *selectCommand.JoinCommand) +} + +func TestSelectWithRightJoinCommand(t *testing.T) { + input := "SELECT tbl.one, tbl2.two FROM tbl RIGHT JOIN tbl2 ON tbl.one EQUAL tbl2.one;" + expectedJoinCommand := ast.JoinCommand{ + Token: token.Token{Type: token.JOIN, Literal: "JOIN"}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2"}}, + JoinType: token.Token{Type: token.RIGHT, Literal: "RIGHT"}, + Expression: ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, + Right: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2.one"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + }, + } + expectedTableName := "tbl" + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "tbl2.two"}}} lexer := lexer.RunLexer(input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } - if len(sequences.Commands) != 2 { - t.Fatalf("sequences does not contain 2 statements. got=%d", len(sequences.Commands)) + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) } - if !testSelectStatement(t, sequences.Commands[0], expectedTableName, expectedColumnName) { + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { return } - actualOrderByCommand, orderByCommandIsOk := sequences.Commands[1].(*ast.OrderByCommand) - if !orderByCommandIsOk { - t.Errorf("actualDeleteCommand is not %T. got=%T", &ast.OrderByCommand{}, sequences.Commands[0]) + if !selectCommand.HasJoinCommand() { + t.Fatalf("select command should have join command") } - testOrderByCommands(t, expectedOrderByCommand, actualOrderByCommand) + testJoinCommands(t, expectedJoinCommand, *selectCommand.JoinCommand) +} + +func TestSelectWithFullJoinCommand(t *testing.T) { + input := "SELECT tbl.one, tbl2.two FROM tbl FULL JOIN tbl2 ON tbl.one EQUAL tbl2.one;" + expectedJoinCommand := ast.JoinCommand{ + Token: token.Token{Type: token.JOIN, Literal: "JOIN"}, + Name: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2"}}, + JoinType: token.Token{Type: token.FULL, Literal: "FULL"}, + Expression: ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, + Right: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "tbl2.one"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + }, + } + expectedTableName := "tbl" + expectedSpaces := []ast.Space{{ColumnName: token.Token{Type: token.IDENT, Literal: "tbl.one"}}, {ColumnName: token.Token{Type: token.IDENT, Literal: "tbl2.two"}}} + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { + return + } + + if !selectCommand.HasJoinCommand() { + t.Fatalf("select command should have join command") + } + + testJoinCommands(t, expectedJoinCommand, *selectCommand.JoinCommand) +} + +func TestSelectWithAggregateFunctions(t *testing.T) { + input := "SELECT MIN(colOne), MAX(colOne), COUNT(*), COUNT(colOne), SUM(colOne), AVG(colOne) FROM tbl;" + + expectedTableName := "tbl" + expectedSpaces := []ast.Space{ + { + ColumnName: token.Token{Type: token.IDENT, Literal: "colOne"}, + AggregateFunc: &token.Token{Type: token.MIN, Literal: "MIN"}, + }, + { + ColumnName: token.Token{Type: token.ASTERISK, Literal: "colOne"}, + AggregateFunc: &token.Token{Type: token.MAX, Literal: "MAX"}, + }, + { + ColumnName: token.Token{Type: token.IDENT, Literal: "*"}, + AggregateFunc: &token.Token{Type: token.COUNT, Literal: "COUNT"}, + }, + { + ColumnName: token.Token{Type: token.IDENT, Literal: "colOne"}, + AggregateFunc: &token.Token{Type: token.COUNT, Literal: "COUNT"}, + }, + { + ColumnName: token.Token{Type: token.IDENT, Literal: "colOne"}, + AggregateFunc: &token.Token{Type: token.SUM, Literal: "SUM"}, + }, + { + ColumnName: token.Token{Type: token.IDENT, Literal: "colOne"}, + AggregateFunc: &token.Token{Type: token.AVG, Literal: "AVG"}, + }, + } + + lexer := lexer.RunLexer(input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("sequences does not contain 1 statements. got=%d", len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) + + if !testSelectStatement(t, selectCommand, expectedTableName, expectedSpaces, false) { + return + } +} + +func TestParseUpdateCommand(t *testing.T) { + tests := []struct { + input string + expectedTableName string + expectedChanges map[token.Token]ast.Anonymitifier + }{ + { + input: "UPDATE tbl SET colName TO 5;", expectedTableName: "tbl", expectedChanges: map[token.Token]ast.Anonymitifier{ + {Type: token.IDENT, Literal: "colName"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + }, + }, + { + input: "UPDATE tbl1 SET colName1 TO 'hi hello', colName2 TO 5;", expectedTableName: "tbl1", expectedChanges: map[token.Token]ast.Anonymitifier{ + {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.IDENT, Literal: "hi hello"}}, + {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + }, + }, + { + input: "UPDATE tbl1 SET colName1 TO NULL, colName2 TO 'NULL';", expectedTableName: "tbl1", expectedChanges: map[token.Token]ast.Anonymitifier{ + {Type: token.IDENT, Literal: "colName1"}: {Token: token.Token{Type: token.NULL, Literal: "NULL"}}, + {Type: token.IDENT, Literal: "colName2"}: {Token: token.Token{Type: token.LITERAL, Literal: "NULL"}}, + }, + }, + } + + for testIndex, tt := range tests { + lexer := lexer.RunLexer(tt.input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("[%d] Got error from parser: %s", testIndex, err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) + } + + if !testUpdateStatement(t, sequences.Commands[0], tt.expectedTableName, tt.expectedChanges) { + return + } + } +} + +func TestParseUpdateCommandWithWhere(t *testing.T) { + tests := []struct { + input string + expectedTableName string + expectedChanges map[token.Token]ast.Anonymitifier + expectedWhereCommand ast.Expression + }{ + { + input: "UPDATE tbl SET colName TO 5 WHERE id EQUAL 3;", + expectedTableName: "tbl", + expectedChanges: map[token.Token]ast.Anonymitifier{ + {Type: token.IDENT, Literal: "colName"}: {Token: token.Token{Type: token.LITERAL, Literal: "5"}}, + }, + expectedWhereCommand: ast.ConditionExpression{ + Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "id"}}, + Right: ast.Anonymitifier{Token: token.Token{Type: token.LITERAL, Literal: "3"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}, + }, + }, + } + + for testIndex, tt := range tests { + lexer := lexer.RunLexer(tt.input) + parserInstance := New(lexer) + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) + } + + actualUpdateCommand, ok := sequences.Commands[0].(*ast.UpdateCommand) + + if !ok { + t.Errorf("[%d] actualUpdateCommand is not %T. got=%T", testIndex, &ast.UpdateCommand{}, sequences.Commands[0]) + } + + if !testUpdateStatement(t, actualUpdateCommand, tt.expectedTableName, tt.expectedChanges) { + return + } + + if !actualUpdateCommand.HasWhereCommand() { + t.Errorf("[%d] actualUpdateCommand should have where command", testIndex) + } + + if !whereStatementIsValid(t, actualUpdateCommand.WhereCommand, tt.expectedWhereCommand) { + return + } + } } func TestParseLogicOperatorsInCommand(t *testing.T) { @@ -264,7 +813,7 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}}, Right: ast.ConditionExpression{ Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName2"}}, - Right: ast.Anonymitifier{Token: token.Token{Type: token.LITERAL, Literal: "123"}}, + Right: ast.Anonymitifier{Token: token.Token{Type: token.NULL, Literal: "NULL"}}, Condition: token.Token{Type: token.EQUAL, Literal: "NOT"}}, Operation: token.Token{Type: token.AND, Literal: "AND"}, } @@ -276,7 +825,7 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { Condition: token.Token{Type: token.NOT, Literal: "NOT"}}, Right: ast.ConditionExpression{ Left: ast.Identifier{Token: token.Token{Type: token.IDENT, Literal: "colName1"}}, - Right: ast.Anonymitifier{Token: token.Token{Type: token.IDENT, Literal: "qwe"}}, + Right: ast.Anonymitifier{Token: token.Token{Type: token.IDENT, Literal: "NULL"}}, Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}}, Operation: token.Token{Type: token.OR, Literal: "OR"}, } @@ -285,40 +834,58 @@ func TestParseLogicOperatorsInCommand(t *testing.T) { Boolean: token.Token{Type: token.TRUE, Literal: "TRUE"}, } + fourthExpression := ast.ConditionExpression{ + Left: ast.Anonymitifier{Token: token.Token{Type: token.IDENT, Literal: "colName1 EQUAL;"}}, + Right: ast.Anonymitifier{Token: token.Token{Type: token.IDENT, Literal: "colName1 EQUAL;"}}, + Condition: token.Token{Type: token.EQUAL, Literal: "EQUAL"}} + tests := []struct { input string expectedExpression ast.Expression }{ { - input: "SELECT * FROM TBL WHERE colName1 EQUAL 'fda' AND colName2 NOT 123;", + input: "SELECT * FROM TBL WHERE colName1 EQUAL 'fda' AND colName2 NOT NULL;", expectedExpression: firstExpression, }, { - input: "SELECT * FROM TBL WHERE colName2 NOT 6462389 OR colName1 EQUAL 'qwe';", + input: "SELECT * FROM TBL WHERE colName2 NOT 6462389 OR colName1 EQUAL 'NULL';", expectedExpression: secondExpression, }, { input: "SELECT * FROM TBL WHERE TRUE;", expectedExpression: thirdExpression, }, + { + input: "SELECT * FROM TBL WHERE 'colName1 EQUAL;' EQUAL 'colName1 EQUAL;';", + expectedExpression: fourthExpression, + }, } - for _, tt := range tests { + for testIndex, tt := range tests { lexer := lexer.RunLexer(tt.input) parserInstance := New(lexer) - sequences := parserInstance.ParseSequence() + sequences, err := parserInstance.ParseSequence() + if err != nil { + t.Fatalf("Got error from parser: %s", err) + } + + if len(sequences.Commands) != 1 { + t.Fatalf("[%d] sequences does not contain 1 statements. got=%d", testIndex, len(sequences.Commands)) + } + + selectCommand := sequences.Commands[0].(*ast.SelectCommand) - if len(sequences.Commands) != 2 { - t.Fatalf("sequences does not contain 2 statements. got=%d", len(sequences.Commands)) + if !selectCommand.HasWhereCommand() { + t.Fatalf("[%d] sequences does not contain where command", testIndex) } - if !whereStatementIsValid(t, sequences.Commands[1], tt.expectedExpression) { - t.Fatalf("Actual expression and expected one are different") + if !whereStatementIsValid(t, selectCommand.WhereCommand, tt.expectedExpression) { + t.Fatalf("[%d] Actual expression and expected one are different", testIndex) } } } -func testSelectStatement(t *testing.T, command ast.Command, expectedTableName string, expectedColumnsTokens []token.Token) bool { +func testSelectStatement(t *testing.T, command ast.Command, expectedTableName string, expectedSpaces []ast.Space, expectedDistinct bool) bool { if command.TokenLiteral() != "SELECT" { t.Errorf("command.TokenLiteral() not 'SELECT'. got=%q", command.TokenLiteral()) return false @@ -335,7 +902,35 @@ func testSelectStatement(t *testing.T, command ast.Command, expectedTableName st return false } - if !tokenArrayEquals(actualSelectCommand.Space, expectedColumnsTokens) { + if actualSelectCommand.HasDistinct != expectedDistinct { + t.Errorf("HasDistinct should be set to %t, got=%t", expectedDistinct, actualSelectCommand.HasDistinct) + return false + } + + if !spaceArrayEquals(actualSelectCommand.Space, expectedSpaces) { + t.Errorf("actualSelectCommand has diffrent space than expected. %+v != %+v", actualSelectCommand.Space, expectedSpaces) + return false + } + + return true +} + +func testUpdateStatement(t *testing.T, command ast.Command, expectedTableName string, expectedChanges map[token.Token]ast.Anonymitifier) bool { + if command.TokenLiteral() != "UPDATE" { + t.Errorf("command.TokenLiteral() not 'UPDATE'. got=%q", command.TokenLiteral()) + return false + } + actualUpdateCommand, ok := command.(*ast.UpdateCommand) + + if !ok { + t.Errorf("actualUpdateCommand is not %T. got=%T", &ast.UpdateCommand{}, command) + return false + } + if actualUpdateCommand.Name.Token.Literal != expectedTableName { + t.Errorf("%s != %s", actualUpdateCommand.TokenLiteral(), expectedTableName) + return false + } + if !tokenMapEquals(actualUpdateCommand.Changes, expectedChanges) { t.Errorf("") return false } @@ -383,6 +978,39 @@ func tokenArrayEquals(a []token.Token, b []token.Token) bool { if v.Literal != b[i].Literal { return false } + if v.Type != b[i].Type { + return false + } + } + return true +} + +func spaceArrayEquals(a []ast.Space, b []ast.Space) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v.ColumnName.Literal != b[i].ColumnName.Literal { + return false + } + if v.ContainsAggregateFunc() != b[i].ContainsAggregateFunc() { + return false + } + if v.ContainsAggregateFunc() && b[i].ContainsAggregateFunc() && v.AggregateFunc.Literal != b[i].AggregateFunc.Literal { + return false + } + } + return true +} + +func tokenMapEquals(a map[token.Token]ast.Anonymitifier, b map[token.Token]ast.Anonymitifier) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if v.GetToken().Literal != b[k].GetToken().Literal { + return false + } } return true } @@ -390,7 +1018,7 @@ func tokenArrayEquals(a []token.Token, b []token.Token) bool { func testOrderByCommands(t *testing.T, expectedOrderByCommand ast.OrderByCommand, actualOrderByCommand *ast.OrderByCommand) { if expectedOrderByCommand.Token.Type != actualOrderByCommand.Token.Type { - t.Errorf("Expecting Token Type: %q, got: %q", expectedOrderByCommand.Token.Type, actualOrderByCommand.Token.Type) + t.Errorf("Expecting Token TokenType: %q, got: %q", expectedOrderByCommand.Token.Type, actualOrderByCommand.Token.Type) } if expectedOrderByCommand.Token.Literal != actualOrderByCommand.Token.Literal { t.Errorf("Expecting Token Literal: %s, got: %s", expectedOrderByCommand.Token.Literal, actualOrderByCommand.Token.Literal) @@ -407,9 +1035,49 @@ func testOrderByCommands(t *testing.T, expectedOrderByCommand ast.OrderByCommand t.Errorf("Expecting Column Name: %s, got: %s", expectedSortPattern.ColumnName.Literal, actualOrderByCommand.SortPatterns[i].ColumnName.Literal) } } +} + +func testLimitCommands(t *testing.T, expectedLimitCommand ast.LimitCommand, actualLimitCommand *ast.LimitCommand) { + + if expectedLimitCommand.Token.Type != actualLimitCommand.Token.Type { + t.Errorf("Expecting Token TokenType: %q, got: %q", expectedLimitCommand.Token.Type, actualLimitCommand.Token.Type) + } + if expectedLimitCommand.Token.Literal != actualLimitCommand.Token.Literal { + t.Errorf("Expecting Token Literal: %s, got: %s", expectedLimitCommand.Token.Literal, actualLimitCommand.Token.Literal) + } + if expectedLimitCommand.Count != actualLimitCommand.Count { + t.Errorf("Expecting Count to have value: %d, got: %d", expectedLimitCommand.Count, actualLimitCommand.Count) + } +} +func testOffsetCommands(t *testing.T, expectedOffsetCommand ast.OffsetCommand, actualOffsetCommand *ast.OffsetCommand) { + + if expectedOffsetCommand.Token.Type != actualOffsetCommand.Token.Type { + t.Errorf("Expecting Token TokenType: %q, got: %q", expectedOffsetCommand.Token.Type, actualOffsetCommand.Token.Type) + } + if expectedOffsetCommand.Token.Literal != actualOffsetCommand.Token.Literal { + t.Errorf("Expecting Token Literal: %s, got: %s", expectedOffsetCommand.Token.Literal, actualOffsetCommand.Token.Literal) + } + if expectedOffsetCommand.Count != actualOffsetCommand.Count { + t.Errorf("Expecting Count to have value: %d, got: %d", expectedOffsetCommand.Count, actualOffsetCommand.Count) + } } +func testJoinCommands(t *testing.T, expectedJoinCommand ast.JoinCommand, actualJoinCommand ast.JoinCommand) { + + if expectedJoinCommand.Token.Type != actualJoinCommand.Token.Type { + t.Errorf("Expecting Token TokenType: %q, got: %q", expectedJoinCommand.Token.Type, actualJoinCommand.Token.Type) + } + if expectedJoinCommand.Token.Literal != actualJoinCommand.Token.Literal { + t.Errorf("Expecting Token Literal: %s, got: %s", expectedJoinCommand.Token.Literal, actualJoinCommand.Token.Literal) + } + if expectedJoinCommand.Name != actualJoinCommand.Name { + t.Errorf("Expecting Name to has a value: %s, got: %s", expectedJoinCommand.Name, actualJoinCommand.Name) + } + if !expressionsAreEqual(actualJoinCommand.Expression, expectedJoinCommand.Expression) { + t.Errorf("Actual expression is not equal to expected one.\nActual: %#v\nExpected: %#v", actualJoinCommand.Expression, expectedJoinCommand.Expression) + } +} func expressionsAreEqual(first ast.Expression, second ast.Expression) bool { booleanExpression, booleanExpressionIsValid := first.(*ast.BooleanExpression) @@ -427,6 +1095,11 @@ func expressionsAreEqual(first ast.Expression, second ast.Expression) bool { return validateOperationExpression(second, operationExpression) } + containExpression, containExpressionIsValid := first.(*ast.ContainExpression) + if containExpressionIsValid { + return validateContainExpression(second, containExpression) + } + return false } @@ -444,6 +1117,35 @@ func validateOperationExpression(second ast.Expression, operationExpression *ast return expressionsAreEqual(operationExpression.Left, secondOperationExpression.Left) && expressionsAreEqual(operationExpression.Right, secondOperationExpression.Right) } +func validateContainExpression(expression ast.Expression, expectedContainExpression *ast.ContainExpression) bool { + actualContainExpression, actualContainExpressionIsValid := expression.(ast.ContainExpression) + + if !actualContainExpressionIsValid { + return false + } + + if expectedContainExpression.Contains != actualContainExpression.Contains { + return false + } + + if actualContainExpression.Left.GetToken().Literal != expectedContainExpression.Left.GetToken().Literal && + actualContainExpression.Left.IsIdentifier() == expectedContainExpression.Left.IsIdentifier() { + return false + } + + if len(expectedContainExpression.Right) != len(actualContainExpression.Right) { + return false + } + + for i := 0; i < len(expectedContainExpression.Right); i++ { + if expectedContainExpression.Right[i] != actualContainExpression.Right[i] { + return false + } + } + + return true +} + func validateConditionExpression(second ast.Expression, conditionExpression *ast.ConditionExpression) bool { secondConditionExpression, secondConditionExpressionIsValid := second.(ast.ConditionExpression) @@ -469,13 +1171,13 @@ func validateConditionExpression(second ast.Expression, conditionExpression *ast } func validateBooleanExpressions(second ast.Expression, booleanExpression *ast.BooleanExpression) bool { - secondBooleanExpresion, secondBooleanExpresionIsValid := second.(ast.BooleanExpression) + secondBooleanExpression, secondBooleanExpressionIsValid := second.(ast.BooleanExpression) - if !secondBooleanExpresionIsValid { + if !secondBooleanExpressionIsValid { return false } - if booleanExpression.Boolean.Literal != secondBooleanExpresion.Boolean.Literal { + if booleanExpression.Boolean.Literal != secondBooleanExpression.Boolean.Literal { return false } diff --git a/test_file b/test_file deleted file mode 100644 index e69a5c3..0000000 --- a/test_file +++ /dev/null @@ -1,11 +0,0 @@ - CREATE TABLE tbl( one TEXT , two INT, three INT, four TEXT ); - INSERT INTO tbl VALUES( 'hello', 1, 11, 'q' ); - INSERT INTO tbl VALUES( 'goodbye', 1, 22, 'w' ); - INSERT INTO tbl VALUES( 'byebye', 3, 33, 'e' ); - SELECT * FROM tbl WHERE one EQUAL 'byebye'; - SELECT one, three FROM tbl WHERE two NOT 3; - SELECT * FROM tbl WHERE one NOT 'goodbye' AND two EQUAL 3; - SELECT * FROM tbl WHERE FALSE; - DELETE FROM tbl WHERE one EQUAL 'byebye'; - SELECT * FROM tbl; - SELECT one FROM tbl WHERE TRUE ORDER BY two ASC, four DESC; diff --git a/token/token.go b/token/token.go index 69e5956..381a168 100644 --- a/token/token.go +++ b/token/token.go @@ -29,19 +29,41 @@ const ( RPAREN = ")" // CREATE - Keywords - CREATE = "CREATE" - TABLE = "TABLE" - INSERT = "INSERT" - INTO = "INTO" - VALUES = "VALUES" - SELECT = "SELECT" - FROM = "FROM" - WHERE = "WHERE" - DELETE = "DELETE" - ORDER = "ORDER" - BY = "BY" - ASC = "ASC" - DESC = "DESC" + CREATE = "CREATE" + DROP = "DROP" + TABLE = "TABLE" + INSERT = "INSERT" + INTO = "INTO" + VALUES = "VALUES" + SELECT = "SELECT" + FROM = "FROM" + WHERE = "WHERE" + DELETE = "DELETE" + ORDER = "ORDER" + BY = "BY" + ASC = "ASC" + DESC = "DESC" + LIMIT = "LIMIT" + OFFSET = "OFFSET" + UPDATE = "UPDATE" + SET = "SET" + DISTINCT = "DISTINCT" + JOIN = "JOIN" + INNER = "INNER" + FULL = "FULL" + LEFT = "LEFT" + RIGHT = "RIGHT" + ON = "ON" + MIN = "MIN" + MAX = "MAX" + COUNT = "COUNT" + SUM = "SUM" + AVG = "AVG" + IN = "IN" + NOTIN = "NOTIN" + NULL = "NULL" + + TO = "TO" // EQUAL - Logical operations EQUAL = "EQUAL" @@ -60,27 +82,48 @@ const ( ) var keywords = map[string]Type{ - "TEXT": TEXT, - "INT": INT, - "CREATE": CREATE, - "TABLE": TABLE, - "INSERT": INSERT, - "INTO": INTO, - "SELECT": SELECT, - "FROM": FROM, - "DELETE": DELETE, - "ORDER": ORDER, - "BY": BY, - "ASC": ASC, - "DESC": DESC, - "VALUES": VALUES, - "WHERE": WHERE, - "EQUAL": EQUAL, - "NOT": NOT, - "AND": AND, - "OR": OR, - "TRUE": TRUE, - "FALSE": FALSE, + "TEXT": TEXT, + "INT": INT, + "CREATE": CREATE, + "DROP": DROP, + "TABLE": TABLE, + "INSERT": INSERT, + "INTO": INTO, + "SELECT": SELECT, + "FROM": FROM, + "DELETE": DELETE, + "ORDER": ORDER, + "BY": BY, + "ASC": ASC, + "DESC": DESC, + "LIMIT": LIMIT, + "OFFSET": OFFSET, + "UPDATE": UPDATE, + "SET": SET, + "DISTINCT": DISTINCT, + "INNER": INNER, + "FULL": FULL, + "LEFT": LEFT, + "RIGHT": RIGHT, + "JOIN": JOIN, + "ON": ON, + "MIN": MIN, + "MAX": MAX, + "COUNT": COUNT, + "SUM": SUM, + "AVG": AVG, + "IN": IN, + "NOTIN": NOTIN, + "TO": TO, + "VALUES": VALUES, + "WHERE": WHERE, + "EQUAL": EQUAL, + "NOT": NOT, + "AND": AND, + "OR": OR, + "TRUE": TRUE, + "FALSE": FALSE, + "NULL": NULL, } // LookupIdent - Return keyword type from defined list if exists, otherwise it returns IDENT type