Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions compiler/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ type UpdateStmt struct {
Predicate Expr
}

type DeleteStmt struct {
*StmtBase
TableName string
Predicate Expr
}

type ExprVisitor interface {
VisitBinaryExpr(*BinaryExpr)
VisitUnaryExpr(*UnaryExpr)
Expand Down
2 changes: 2 additions & 0 deletions compiler/lexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const (
kwExists = "EXISTS"
kwUpdate = "UPDATE"
kwSet = "SET"
kwDelete = "DELETE"
)

// keywords is a list of all keywords.
Expand All @@ -98,6 +99,7 @@ var keywords = []string{
kwExists,
kwUpdate,
kwSet,
kwDelete,
}

// Operators where op is operator.
Expand Down
31 changes: 31 additions & 0 deletions compiler/lexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,37 @@ func TestLexUpdate(t *testing.T) {
}
}

func TestLexDelete(t *testing.T) {
cases := []tc{
{
sql: "DELETE FROM foo WHERE id = 1",
expected: []token{
{tkKeyword, "DELETE"},
{tkWhitespace, " "},
{tkKeyword, "FROM"},
{tkWhitespace, " "},
{tkIdentifier, "foo"},
{tkWhitespace, " "},
{tkKeyword, "WHERE"},
{tkWhitespace, " "},
{tkIdentifier, "id"},
{tkWhitespace, " "},
{tkOperator, "="},
{tkWhitespace, " "},
{tkNumeric, "1"},
},
},
}
for _, c := range cases {
t.Run(c.sql, func(t *testing.T) {
ret := NewLexer(c.sql).Lex()
if !reflect.DeepEqual(ret, c.expected) {
t.Errorf("expected %#v got %#v", c.expected, ret)
}
})
}
}

func TestToStatements(t *testing.T) {
type testCase struct {
src string
Expand Down
32 changes: 28 additions & 4 deletions compiler/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ import (
)

const (
tokenErr = "unexpected token %s"
identErr = "expected identifier but got %s"
columnErr = "expected column type but got %s"
literalErr = "expected literal but got %s"
tokenErr = "unexpected token %s"
identErr = "expected identifier but got %s"
columnErr = "expected column type but got %s"
)

type parser struct {
Expand Down Expand Up @@ -68,6 +67,8 @@ func (p *parser) parseStmt() (Stmt, error) {
return p.parseInsert(sb)
case kwUpdate:
return p.parseUpdate(sb)
case kwDelete:
return p.parseDelete(sb)
}
return nil, fmt.Errorf(tokenErr, t.value)
}
Expand Down Expand Up @@ -429,6 +430,29 @@ func (p *parser) parseUpdate(sb *StmtBase) (*UpdateStmt, error) {
return stmt, nil
}

func (p *parser) parseDelete(sb *StmtBase) (*DeleteStmt, error) {
stmt := &DeleteStmt{StmtBase: sb}
from := p.nextNonSpace()
if from.value != kwFrom {
return nil, fmt.Errorf(tokenErr, p.tokens[p.end].value)
}
tableName := p.nextNonSpace()
if tableName.tokenType != tkIdentifier {
return nil, fmt.Errorf(tokenErr, p.tokens[p.end].value)
}
stmt.TableName = tableName.value
possibleWhere := p.peekNextNonSpace()
if possibleWhere.value == kwWhere {
p.nextNonSpace()
expr, err := p.parseExpression(0)
if err != nil {
return nil, err
}
stmt.Predicate = expr
}
return stmt, nil
}

func (p *parser) nextNonSpace() token {
p.end = p.end + 1
if p.end > len(p.tokens)-1 {
Expand Down
55 changes: 55 additions & 0 deletions compiler/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,61 @@ func TestParseUpdate(t *testing.T) {
}
}

type deleteTestCase struct {
caseName string
tokens []token
expected Stmt
}

func TestParseDelete(t *testing.T) {
cases := []deleteTestCase{
{
caseName: "",
tokens: []token{
{tkKeyword, "DELETE"},
{tkWhitespace, " "},
{tkKeyword, "FROM"},
{tkWhitespace, " "},
{tkIdentifier, "foo"},
{tkWhitespace, " "},
{tkKeyword, "WHERE"},
{tkWhitespace, " "},
{tkIdentifier, "id"},
{tkWhitespace, " "},
{tkOperator, "="},
{tkWhitespace, " "},
{tkNumeric, "1"},
},
expected: &DeleteStmt{
StmtBase: &StmtBase{
Explain: false,
},
TableName: "foo",
Predicate: &BinaryExpr{
Left: &ColumnRef{
Column: "id",
},
Operator: OpEq,
Right: &IntLit{
Value: 1,
},
},
},
},
}
for _, c := range cases {
t.Run(c.caseName, func(t *testing.T) {
ret, err := NewParser(c.tokens).Parse()
if err != nil {
t.Errorf("expected no err got err %s", err)
}
if !reflect.DeepEqual(ret, c.expected) {
t.Errorf("expected %#v got %#v", c.expected, ret)
}
})
}
}

type resultColumnTestCase struct {
name string
tokens []token
Expand Down
2 changes: 2 additions & 0 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ func (db *DB) getPlannerFor(statement compiler.Stmt) statementPlanner {
return planner.NewInsert(db.catalog, s)
case *compiler.UpdateStmt:
return planner.NewUpdate(db.catalog, s)
case *compiler.DeleteStmt:
return planner.NewDelete(db.catalog, s)
}
panic("statement not supported")
}
31 changes: 31 additions & 0 deletions db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,34 @@ func TestUpdateStatement(t *testing.T) {
t.Fatalf("expected all 3 rows to be 1")
}
}

func TestDeleteAll(t *testing.T) {
db := mustCreateDB(t)
mustExecute(t, db, "CREATE TABLE foo (id INTEGER PRIMARY KEY, a INTEGER);")
mustExecute(t, db, "INSERT INTO foo (a) VALUES (1), (2), (3);")
mustExecute(t, db, "DELETE FROM foo;")
res := mustExecute(t, db, "SELECT * FROM foo;")
if lrr := len(res.ResultRows); lrr != 0 {
t.Fatalf("expected no rows but got %d", lrr)
}
}

func TestDeleteStatementWithWhere(t *testing.T) {
db := mustCreateDB(t)
mustExecute(t, db, "CREATE TABLE foo (id INTEGER PRIMARY KEY, a INTEGER);")
mustExecute(t, db, "INSERT INTO foo (a) VALUES (11), (12), (13);")
mustExecute(t, db, "DELETE FROM foo WHERE a = 12;")
res := mustExecute(t, db, "SELECT * FROM foo;")
expectedRows := 2
if lrr := len(res.ResultRows); lrr != expectedRows {
t.Fatalf("expected %d rows but got %d", expectedRows, lrr)
}
want1 := "11"
if got1 := *res.ResultRows[0][1]; got1 != want1 {
t.Fatalf("expected %s but got %s", want1, got1)
}
want2 := "13"
if got2 := *res.ResultRows[1][1]; got2 != want2 {
t.Fatalf("expected %s but got %s", want2, got2)
}
}
3 changes: 3 additions & 0 deletions planner/assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ func assertCommandsMatch(gotCommands, expectedCommands []vm.Command) error {
red := "\033[31m"
resetColor := "\033[0m"
for i, c := range expectedCommands {
if i >= len(gotCommands) {
continue
}
color := green
if !reflect.DeepEqual(c, gotCommands[i]) {
didMatch = false
Expand Down
86 changes: 86 additions & 0 deletions planner/delete.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package planner

import (
"github.com/chirst/cdb/catalog"
"github.com/chirst/cdb/compiler"
"github.com/chirst/cdb/vm"
)

type deleteCatalog interface {
GetVersion() string
GetRootPageNumber(string) (int, error)
GetColumns(string) ([]string, error)
GetPrimaryKeyColumn(string) (string, error)
GetColumnType(tableName string, columnName string) (catalog.CdbType, error)
}

type deletePlanner struct {
catalog deleteCatalog
stmt *compiler.DeleteStmt
queryPlan *deleteNode
executionPlan *vm.ExecutionPlan
}

func NewDelete(catalog deleteCatalog, stmt *compiler.DeleteStmt) *deletePlanner {
return &deletePlanner{
catalog: catalog,
stmt: stmt,
executionPlan: vm.NewExecutionPlan(
catalog.GetVersion(),
stmt.Explain,
),
}
}

// QueryPlan implements db.statementPlanner.
func (d *deletePlanner) QueryPlan() (*QueryPlan, error) {
rootPageNumber, err := d.catalog.GetRootPageNumber(d.stmt.TableName)
if err != nil {
return nil, errTableNotExist
}
deleteNode := &deleteNode{
rootPageNumber: rootPageNumber,
cursorId: 1,
}
qp := newQueryPlan(deleteNode, d.stmt.ExplainQueryPlan, transactionTypeWrite)
deleteNode.plan = qp
d.queryPlan = deleteNode
sn := &scanNode{
plan: qp,
tableName: d.stmt.TableName,
rootPageNumber: rootPageNumber,
cursorId: 1,
isWriteCursor: true,
}
if d.stmt.Predicate != nil {
cev := &catalogExprVisitor{}
cev.Init(d.catalog, d.stmt.TableName)
d.stmt.Predicate.BreadthWalk(cev)
fn := &filterNode{
plan: qp,
predicate: d.stmt.Predicate,
cursorId: 1,
}
deleteNode.child = fn
fn.parent = deleteNode
sn.parent = fn
fn.child = sn
} else {
deleteNode.child = sn
sn.parent = deleteNode
}
return qp, nil
}

// ExecutionPlan implements db.statementPlanner.
func (d *deletePlanner) ExecutionPlan() (*vm.ExecutionPlan, error) {
if d.queryPlan == nil {
_, err := d.QueryPlan()
if err != nil {
return nil, err
}
}
d.queryPlan.plan.compile()
d.executionPlan.Commands = d.queryPlan.plan.commands
return d.executionPlan, nil
}
Loading
Loading