Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 4 additions & 21 deletions compiler/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,28 +363,11 @@ func (p *parser) parseValue(stmt *InsertStmt, valueIdx int) (*InsertStmt, error)
}
stmt.ColValues = append(stmt.ColValues, []Expr{})
for {
v := p.nextNonSpace()
if v.tokenType != tkNumeric && v.tokenType != tkLiteral && v.tokenType != tkParam {
return nil, fmt.Errorf(literalErr, v.value)
}
if v.tokenType == tkLiteral {
stmt.ColValues[valueIdx] = append(
stmt.ColValues[valueIdx],
&StringLit{
Value: v.value,
},
)
} else if v.tokenType == tkParam {
variableExpr := &Variable{Position: p.paramCount}
p.paramCount += 1
stmt.ColValues[valueIdx] = append(stmt.ColValues[valueIdx], variableExpr)
} else {
intValue, err := strconv.Atoi(v.value)
if err != nil {
return nil, fmt.Errorf("failed to convert %v to integer", v.value)
}
stmt.ColValues[valueIdx] = append(stmt.ColValues[valueIdx], &IntLit{Value: intValue})
exp, err := p.parseExpression(0)
if err != nil {
return nil, err
}
stmt.ColValues[valueIdx] = append(stmt.ColValues[valueIdx], exp)
sep := p.nextNonSpace()
if sep.value != "," {
if sep.value == ")" {
Expand Down
88 changes: 81 additions & 7 deletions compiler/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,15 @@ func TestParseCreate(t *testing.T) {
}

type insertTestCase struct {
name string
tokens []token
expected Stmt
}

func TestParseInsert(t *testing.T) {
cases := []insertTestCase{
{
name: "ManyValues",
tokens: []token{
{tkKeyword, "INSERT"},
{tkWhitespace, " "},
Expand Down Expand Up @@ -448,15 +450,87 @@ func TestParseInsert(t *testing.T) {
},
},
},
{
name: "WithExpressions",
tokens: []token{
{tkKeyword, "INSERT"},
{tkWhitespace, " "},
{tkKeyword, "INTO"},
{tkWhitespace, " "},
{tkIdentifier, "foo"},
{tkWhitespace, " "},
{tkSeparator, "("},
{tkIdentifier, "id"},
{tkSeparator, ","},
{tkWhitespace, " "},
{tkIdentifier, "age"},
{tkSeparator, ")"},
{tkWhitespace, " "},
{tkKeyword, "VALUES"},
{tkWhitespace, " "},
{tkSeparator, "("},
{tkNumeric, "1"},
{tkWhitespace, " "},
{tkOperator, "+"},
{tkWhitespace, " "},
{tkNumeric, "1"},
{tkSeparator, ","},
{tkWhitespace, " "},
{tkNumeric, "2"},
{tkWhitespace, " "},
{tkOperator, "-"},
{tkWhitespace, " "},
{tkNumeric, "1"},
{tkSeparator, ")"},
{tkSeparator, ","},
{tkWhitespace, " "},
{tkSeparator, "("},
{tkNumeric, "3"},
{tkSeparator, ","},
{tkWhitespace, " "},
{tkNumeric, "4"},
{tkSeparator, ")"},
},
expected: &InsertStmt{
StmtBase: &StmtBase{
Explain: false,
},
TableName: "foo",
ColNames: []string{
"id",
"age",
},
ColValues: [][]Expr{
{
&BinaryExpr{
Operator: "+",
Left: &IntLit{Value: 1},
Right: &IntLit{Value: 1},
},
&BinaryExpr{
Operator: "-",
Left: &IntLit{Value: 2},
Right: &IntLit{Value: 1},
},
},
{
&IntLit{Value: 3},
&IntLit{Value: 4},
},
},
},
},
}
for _, c := range cases {
ret, err := NewParser(c.tokens).Parse()
if err != nil {
t.Errorf("expected no err got err %s", err)
}
if !reflect.DeepEqual(ret, c.expected) {
t.Errorf("expected %#v got %#v", c.expected, ret)
}
t.Run(c.name, func(t *testing.T) {
ret, err := NewParser(c.tokens).Parse()
if err != nil {
t.Errorf("expected no err got err %s", err)
}
if !reflect.DeepEqual(ret, c.expected) {
t.Errorf("expected %#v got %#v", c.expected, ret)
}
})
}
}

Expand Down
19 changes: 19 additions & 0 deletions db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,25 @@ func TestAddColumns(t *testing.T) {
}
}

func TestSelectHeaders(t *testing.T) {
db := mustCreateDB(t)
mustExecute(t, db, "CREATE TABLE test (id INTEGER PRIMARY KEY, val INTEGER)")
mustExecute(t, db, "INSERT INTO test (val) VALUES (1)")
res := mustExecute(t, db, "SELECT *, id AS foo FROM test")
if rowCount := len(res.ResultRows); rowCount != 1 {
t.Fatalf("want 1 row but got %d", rowCount)
}
if got := res.ResultHeader[0]; got != "id" {
t.Fatalf("want id but got %s", got)
}
if got := res.ResultHeader[1]; got != "val" {
t.Fatalf("want val but got %s", got)
}
if got := res.ResultHeader[2]; got != "foo" {
t.Fatalf("want foo but got %s", got)
}
}

func TestSelectWithWhere(t *testing.T) {
db := mustCreateDB(t)
mustExecute(t, db, "CREATE TABLE test (id INTEGER PRIMARY KEY, val INTEGER)")
Expand Down
44 changes: 44 additions & 0 deletions planner/assert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package planner

import (
"errors"
"fmt"
"reflect"

"github.com/chirst/cdb/vm"
)

// assertCommandsMatch is a helper for tests in the planner package.
func assertCommandsMatch(gotCommands, expectedCommands []vm.Command) error {
didMatch := true
errOutput := "\n"
green := "\033[32m"
red := "\033[31m"
resetColor := "\033[0m"
for i, c := range expectedCommands {
color := green
if !reflect.DeepEqual(c, gotCommands[i]) {
didMatch = false
color = red
}
errOutput += fmt.Sprintf(
"%s%3d got %#v%s\n want %#v\n\n",
color, i, gotCommands[i], resetColor, c,
)
}
gl := len(gotCommands)
wl := len(expectedCommands)
if gl != wl {
errOutput += red
errOutput += fmt.Sprintf("got %d want %d commands\n", gl, wl)
errOutput += resetColor
didMatch = false
}
// This helper returns an error instead of making the assertion so a fatal
// error will raise at the test site instead of the helper. This also allows
// the caller to differentiate between a fatal or non fatal assertion.
if !didMatch {
return errors.New(errOutput)
}
return nil
}
15 changes: 12 additions & 3 deletions planner/cevisitor.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
package planner

import "github.com/chirst/cdb/compiler"
import (
"github.com/chirst/cdb/catalog"
"github.com/chirst/cdb/compiler"
)

// catalogExprVisitor assigns catalog information to visited expressions.
type catalogExprVisitor struct {
catalog selectCatalog
catalog cevCatalog
tableName string
err error
}

func (c *catalogExprVisitor) Init(catalog selectCatalog, tableName string) {
type cevCatalog interface {
GetColumns(string) ([]string, error)
GetPrimaryKeyColumn(string) (string, error)
GetColumnType(tableName string, columnName string) (catalog.CdbType, error)
}

func (c *catalogExprVisitor) Init(catalog cevCatalog, tableName string) {
c.catalog = catalog
c.tableName = tableName
}
Expand Down
Loading
Loading