diff --git a/compiler/ast.go b/compiler/ast.go index 23be814..e1b9bf8 100644 --- a/compiler/ast.go +++ b/compiler/ast.go @@ -71,6 +71,12 @@ type UpdateStmt struct { Predicate Expr } +type DeleteStmt struct { + *StmtBase + TableName string + Predicate Expr +} + type ExprVisitor interface { VisitBinaryExpr(*BinaryExpr) VisitUnaryExpr(*UnaryExpr) diff --git a/compiler/lexer.go b/compiler/lexer.go index 3573b63..249ed6b 100644 --- a/compiler/lexer.go +++ b/compiler/lexer.go @@ -72,6 +72,7 @@ const ( kwExists = "EXISTS" kwUpdate = "UPDATE" kwSet = "SET" + kwDelete = "DELETE" ) // keywords is a list of all keywords. @@ -98,6 +99,7 @@ var keywords = []string{ kwExists, kwUpdate, kwSet, + kwDelete, } // Operators where op is operator. diff --git a/compiler/lexer_test.go b/compiler/lexer_test.go index c1f978e..01b9a1f 100644 --- a/compiler/lexer_test.go +++ b/compiler/lexer_test.go @@ -493,6 +493,37 @@ func TestLexUpdate(t *testing.T) { } } +func TestLexDelete(t *testing.T) { + cases := []tc{ + { + sql: "DELETE FROM foo WHERE id = 1", + expected: []token{ + {tkKeyword, "DELETE"}, + {tkWhitespace, " "}, + {tkKeyword, "FROM"}, + {tkWhitespace, " "}, + {tkIdentifier, "foo"}, + {tkWhitespace, " "}, + {tkKeyword, "WHERE"}, + {tkWhitespace, " "}, + {tkIdentifier, "id"}, + {tkWhitespace, " "}, + {tkOperator, "="}, + {tkWhitespace, " "}, + {tkNumeric, "1"}, + }, + }, + } + for _, c := range cases { + t.Run(c.sql, func(t *testing.T) { + ret := NewLexer(c.sql).Lex() + if !reflect.DeepEqual(ret, c.expected) { + t.Errorf("expected %#v got %#v", c.expected, ret) + } + }) + } +} + func TestToStatements(t *testing.T) { type testCase struct { src string diff --git a/compiler/parser.go b/compiler/parser.go index 7914349..f344623 100644 --- a/compiler/parser.go +++ b/compiler/parser.go @@ -11,10 +11,9 @@ import ( ) const ( - tokenErr = "unexpected token %s" - identErr = "expected identifier but got %s" - columnErr = "expected column type but got %s" - literalErr = "expected literal but got %s" + tokenErr = "unexpected token %s" + identErr = "expected identifier but got %s" + columnErr = "expected column type but got %s" ) type parser struct { @@ -68,6 +67,8 @@ func (p *parser) parseStmt() (Stmt, error) { return p.parseInsert(sb) case kwUpdate: return p.parseUpdate(sb) + case kwDelete: + return p.parseDelete(sb) } return nil, fmt.Errorf(tokenErr, t.value) } @@ -429,6 +430,29 @@ func (p *parser) parseUpdate(sb *StmtBase) (*UpdateStmt, error) { return stmt, nil } +func (p *parser) parseDelete(sb *StmtBase) (*DeleteStmt, error) { + stmt := &DeleteStmt{StmtBase: sb} + from := p.nextNonSpace() + if from.value != kwFrom { + return nil, fmt.Errorf(tokenErr, p.tokens[p.end].value) + } + tableName := p.nextNonSpace() + if tableName.tokenType != tkIdentifier { + return nil, fmt.Errorf(tokenErr, p.tokens[p.end].value) + } + stmt.TableName = tableName.value + possibleWhere := p.peekNextNonSpace() + if possibleWhere.value == kwWhere { + p.nextNonSpace() + expr, err := p.parseExpression(0) + if err != nil { + return nil, err + } + stmt.Predicate = expr + } + return stmt, nil +} + func (p *parser) nextNonSpace() token { p.end = p.end + 1 if p.end > len(p.tokens)-1 { diff --git a/compiler/parser_test.go b/compiler/parser_test.go index 49387e4..a69015b 100644 --- a/compiler/parser_test.go +++ b/compiler/parser_test.go @@ -648,6 +648,61 @@ func TestParseUpdate(t *testing.T) { } } +type deleteTestCase struct { + caseName string + tokens []token + expected Stmt +} + +func TestParseDelete(t *testing.T) { + cases := []deleteTestCase{ + { + caseName: "", + tokens: []token{ + {tkKeyword, "DELETE"}, + {tkWhitespace, " "}, + {tkKeyword, "FROM"}, + {tkWhitespace, " "}, + {tkIdentifier, "foo"}, + {tkWhitespace, " "}, + {tkKeyword, "WHERE"}, + {tkWhitespace, " "}, + {tkIdentifier, "id"}, + {tkWhitespace, " "}, + {tkOperator, "="}, + {tkWhitespace, " "}, + {tkNumeric, "1"}, + }, + expected: &DeleteStmt{ + StmtBase: &StmtBase{ + Explain: false, + }, + TableName: "foo", + Predicate: &BinaryExpr{ + Left: &ColumnRef{ + Column: "id", + }, + Operator: OpEq, + Right: &IntLit{ + Value: 1, + }, + }, + }, + }, + } + for _, c := range cases { + t.Run(c.caseName, func(t *testing.T) { + ret, err := NewParser(c.tokens).Parse() + if err != nil { + t.Errorf("expected no err got err %s", err) + } + if !reflect.DeepEqual(ret, c.expected) { + t.Errorf("expected %#v got %#v", c.expected, ret) + } + }) + } +} + type resultColumnTestCase struct { name string tokens []token diff --git a/db/db.go b/db/db.go index f7d2666..0bb2cdc 100644 --- a/db/db.go +++ b/db/db.go @@ -127,6 +127,8 @@ func (db *DB) getPlannerFor(statement compiler.Stmt) statementPlanner { return planner.NewInsert(db.catalog, s) case *compiler.UpdateStmt: return planner.NewUpdate(db.catalog, s) + case *compiler.DeleteStmt: + return planner.NewDelete(db.catalog, s) } panic("statement not supported") } diff --git a/db/db_test.go b/db/db_test.go index 4a4fcfb..e3a06bd 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -389,3 +389,34 @@ func TestUpdateStatement(t *testing.T) { t.Fatalf("expected all 3 rows to be 1") } } + +func TestDeleteAll(t *testing.T) { + db := mustCreateDB(t) + mustExecute(t, db, "CREATE TABLE foo (id INTEGER PRIMARY KEY, a INTEGER);") + mustExecute(t, db, "INSERT INTO foo (a) VALUES (1), (2), (3);") + mustExecute(t, db, "DELETE FROM foo;") + res := mustExecute(t, db, "SELECT * FROM foo;") + if lrr := len(res.ResultRows); lrr != 0 { + t.Fatalf("expected no rows but got %d", lrr) + } +} + +func TestDeleteStatementWithWhere(t *testing.T) { + db := mustCreateDB(t) + mustExecute(t, db, "CREATE TABLE foo (id INTEGER PRIMARY KEY, a INTEGER);") + mustExecute(t, db, "INSERT INTO foo (a) VALUES (11), (12), (13);") + mustExecute(t, db, "DELETE FROM foo WHERE a = 12;") + res := mustExecute(t, db, "SELECT * FROM foo;") + expectedRows := 2 + if lrr := len(res.ResultRows); lrr != expectedRows { + t.Fatalf("expected %d rows but got %d", expectedRows, lrr) + } + want1 := "11" + if got1 := *res.ResultRows[0][1]; got1 != want1 { + t.Fatalf("expected %s but got %s", want1, got1) + } + want2 := "13" + if got2 := *res.ResultRows[1][1]; got2 != want2 { + t.Fatalf("expected %s but got %s", want2, got2) + } +} diff --git a/planner/assert_test.go b/planner/assert_test.go index 9652c65..d88515b 100644 --- a/planner/assert_test.go +++ b/planner/assert_test.go @@ -16,6 +16,9 @@ func assertCommandsMatch(gotCommands, expectedCommands []vm.Command) error { red := "\033[31m" resetColor := "\033[0m" for i, c := range expectedCommands { + if i >= len(gotCommands) { + continue + } color := green if !reflect.DeepEqual(c, gotCommands[i]) { didMatch = false diff --git a/planner/delete.go b/planner/delete.go new file mode 100644 index 0000000..3d4d8dc --- /dev/null +++ b/planner/delete.go @@ -0,0 +1,86 @@ +package planner + +import ( + "github.com/chirst/cdb/catalog" + "github.com/chirst/cdb/compiler" + "github.com/chirst/cdb/vm" +) + +type deleteCatalog interface { + GetVersion() string + GetRootPageNumber(string) (int, error) + GetColumns(string) ([]string, error) + GetPrimaryKeyColumn(string) (string, error) + GetColumnType(tableName string, columnName string) (catalog.CdbType, error) +} + +type deletePlanner struct { + catalog deleteCatalog + stmt *compiler.DeleteStmt + queryPlan *deleteNode + executionPlan *vm.ExecutionPlan +} + +func NewDelete(catalog deleteCatalog, stmt *compiler.DeleteStmt) *deletePlanner { + return &deletePlanner{ + catalog: catalog, + stmt: stmt, + executionPlan: vm.NewExecutionPlan( + catalog.GetVersion(), + stmt.Explain, + ), + } +} + +// QueryPlan implements db.statementPlanner. +func (d *deletePlanner) QueryPlan() (*QueryPlan, error) { + rootPageNumber, err := d.catalog.GetRootPageNumber(d.stmt.TableName) + if err != nil { + return nil, errTableNotExist + } + deleteNode := &deleteNode{ + rootPageNumber: rootPageNumber, + cursorId: 1, + } + qp := newQueryPlan(deleteNode, d.stmt.ExplainQueryPlan, transactionTypeWrite) + deleteNode.plan = qp + d.queryPlan = deleteNode + sn := &scanNode{ + plan: qp, + tableName: d.stmt.TableName, + rootPageNumber: rootPageNumber, + cursorId: 1, + isWriteCursor: true, + } + if d.stmt.Predicate != nil { + cev := &catalogExprVisitor{} + cev.Init(d.catalog, d.stmt.TableName) + d.stmt.Predicate.BreadthWalk(cev) + fn := &filterNode{ + plan: qp, + predicate: d.stmt.Predicate, + cursorId: 1, + } + deleteNode.child = fn + fn.parent = deleteNode + sn.parent = fn + fn.child = sn + } else { + deleteNode.child = sn + sn.parent = deleteNode + } + return qp, nil +} + +// ExecutionPlan implements db.statementPlanner. +func (d *deletePlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { + if d.queryPlan == nil { + _, err := d.QueryPlan() + if err != nil { + return nil, err + } + } + d.queryPlan.plan.compile() + d.executionPlan.Commands = d.queryPlan.plan.commands + return d.executionPlan, nil +} diff --git a/planner/delete_test.go b/planner/delete_test.go new file mode 100644 index 0000000..abb5150 --- /dev/null +++ b/planner/delete_test.go @@ -0,0 +1,110 @@ +package planner + +import ( + "errors" + "testing" + + "github.com/chirst/cdb/catalog" + "github.com/chirst/cdb/compiler" + "github.com/chirst/cdb/vm" +) + +type mockDeleteCatalog struct{} + +func (*mockDeleteCatalog) GetVersion() string { + return "mock" +} + +func (*mockDeleteCatalog) GetRootPageNumber(tableName string) (int, error) { + if tableName == "foo" { + return 2, nil + } + return -1, errors.New("err mock catalog root page") +} + +func (*mockDeleteCatalog) GetColumns(tableName string) ([]string, error) { + if tableName == "foo" { + return []string{ + "id", + "age", + }, nil + } + return nil, errors.New("err mock catalog columns") +} + +func (*mockDeleteCatalog) GetPrimaryKeyColumn(tableName string) (string, error) { + if tableName == "foo" { + return "id", nil + } + return "", errors.New("err mock catalog pk") +} + +func (mockDeleteCatalog) GetColumnType(tableName string, columnName string) (catalog.CdbType, error) { + return catalog.CdbType{ID: catalog.CTInt}, nil +} + +func TestDelete(t *testing.T) { + type deleteTestCase struct { + expectation string + ast *compiler.DeleteStmt + expectedCommands []vm.Command + } + tcs := []deleteTestCase{ + { + expectation: "DeleteWithNoPredicate", + ast: &compiler.DeleteStmt{ + StmtBase: &compiler.StmtBase{}, + TableName: "foo", + }, + expectedCommands: []vm.Command{ + &vm.InitCmd{P2: 6}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 5}, + &vm.DeleteCmd{P1: 1}, + &vm.NextCmd{P1: 1, P2: 3}, + &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.GotoCmd{P2: 1}, + }, + }, + { + expectation: "DeleteWithPredicate", + ast: &compiler.DeleteStmt{ + StmtBase: &compiler.StmtBase{}, + TableName: "foo", + Predicate: &compiler.BinaryExpr{ + Operator: compiler.OpEq, + Left: &compiler.ColumnRef{ + Column: "id", + }, + Right: &compiler.IntLit{Value: 1}, + }, + }, + expectedCommands: []vm.Command{ + &vm.InitCmd{P2: 8}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 7}, + &vm.RowIdCmd{P1: 1, P2: 1}, + &vm.NotEqualCmd{P1: 1, P2: 6, P3: 2}, + &vm.DeleteCmd{P1: 1}, + &vm.NextCmd{P1: 1, P2: 3}, + &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.IntegerCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.expectation, func(t *testing.T) { + mockCatalog := &mockDeleteCatalog{} + plan, err := NewDelete(mockCatalog, tc.ast).ExecutionPlan() + if err != nil { + t.Errorf("expected no err got err %s", err) + } + if err := assertCommandsMatch(plan.Commands, tc.expectedCommands); err != nil { + t.Error(err) + } + }) + } +} diff --git a/planner/generator.go b/planner/generator.go index aec6da3..2f600c1 100644 --- a/planner/generator.go +++ b/planner/generator.go @@ -216,6 +216,14 @@ func (n *insertNode) consume() { } } +func (d *deleteNode) consume() { + d.plan.commands = append(d.plan.commands, &vm.DeleteCmd{P1: d.cursorId}) +} + +func (d *deleteNode) produce() { + d.child.produce() +} + func (n *joinNode) produce() {} func (n *joinNode) consume() {} diff --git a/planner/node.go b/planner/node.go index abb4856..ba30cb2 100644 --- a/planner/node.go +++ b/planner/node.go @@ -239,3 +239,18 @@ func (u *updateNode) print() string { func (u *updateNode) children() []logicalNode { return []logicalNode{} } + +type deleteNode struct { + child logicalNode + plan *QueryPlan + rootPageNumber int + cursorId int +} + +func (d *deleteNode) print() string { + return "delete" +} + +func (d *deleteNode) children() []logicalNode { + return []logicalNode{d.child} +}