diff --git a/compiler/parser.go b/compiler/parser.go index 3332ecb..7914349 100644 --- a/compiler/parser.go +++ b/compiler/parser.go @@ -363,28 +363,11 @@ func (p *parser) parseValue(stmt *InsertStmt, valueIdx int) (*InsertStmt, error) } stmt.ColValues = append(stmt.ColValues, []Expr{}) for { - v := p.nextNonSpace() - if v.tokenType != tkNumeric && v.tokenType != tkLiteral && v.tokenType != tkParam { - return nil, fmt.Errorf(literalErr, v.value) - } - if v.tokenType == tkLiteral { - stmt.ColValues[valueIdx] = append( - stmt.ColValues[valueIdx], - &StringLit{ - Value: v.value, - }, - ) - } else if v.tokenType == tkParam { - variableExpr := &Variable{Position: p.paramCount} - p.paramCount += 1 - stmt.ColValues[valueIdx] = append(stmt.ColValues[valueIdx], variableExpr) - } else { - intValue, err := strconv.Atoi(v.value) - if err != nil { - return nil, fmt.Errorf("failed to convert %v to integer", v.value) - } - stmt.ColValues[valueIdx] = append(stmt.ColValues[valueIdx], &IntLit{Value: intValue}) + exp, err := p.parseExpression(0) + if err != nil { + return nil, err } + stmt.ColValues[valueIdx] = append(stmt.ColValues[valueIdx], exp) sep := p.nextNonSpace() if sep.value != "," { if sep.value == ")" { diff --git a/compiler/parser_test.go b/compiler/parser_test.go index 60f78c8..49387e4 100644 --- a/compiler/parser_test.go +++ b/compiler/parser_test.go @@ -345,6 +345,7 @@ func TestParseCreate(t *testing.T) { } type insertTestCase struct { + name string tokens []token expected Stmt } @@ -352,6 +353,7 @@ type insertTestCase struct { func TestParseInsert(t *testing.T) { cases := []insertTestCase{ { + name: "ManyValues", tokens: []token{ {tkKeyword, "INSERT"}, {tkWhitespace, " "}, @@ -448,15 +450,87 @@ func TestParseInsert(t *testing.T) { }, }, }, + { + name: "WithExpressions", + tokens: []token{ + {tkKeyword, "INSERT"}, + {tkWhitespace, " "}, + {tkKeyword, "INTO"}, + {tkWhitespace, " "}, + {tkIdentifier, "foo"}, + {tkWhitespace, " "}, + {tkSeparator, "("}, + {tkIdentifier, 
"id"}, + {tkSeparator, ","}, + {tkWhitespace, " "}, + {tkIdentifier, "age"}, + {tkSeparator, ")"}, + {tkWhitespace, " "}, + {tkKeyword, "VALUES"}, + {tkWhitespace, " "}, + {tkSeparator, "("}, + {tkNumeric, "1"}, + {tkWhitespace, " "}, + {tkOperator, "+"}, + {tkWhitespace, " "}, + {tkNumeric, "1"}, + {tkSeparator, ","}, + {tkWhitespace, " "}, + {tkNumeric, "2"}, + {tkWhitespace, " "}, + {tkOperator, "-"}, + {tkWhitespace, " "}, + {tkNumeric, "1"}, + {tkSeparator, ")"}, + {tkSeparator, ","}, + {tkWhitespace, " "}, + {tkSeparator, "("}, + {tkNumeric, "3"}, + {tkSeparator, ","}, + {tkWhitespace, " "}, + {tkNumeric, "4"}, + {tkSeparator, ")"}, + }, + expected: &InsertStmt{ + StmtBase: &StmtBase{ + Explain: false, + }, + TableName: "foo", + ColNames: []string{ + "id", + "age", + }, + ColValues: [][]Expr{ + { + &BinaryExpr{ + Operator: "+", + Left: &IntLit{Value: 1}, + Right: &IntLit{Value: 1}, + }, + &BinaryExpr{ + Operator: "-", + Left: &IntLit{Value: 2}, + Right: &IntLit{Value: 1}, + }, + }, + { + &IntLit{Value: 3}, + &IntLit{Value: 4}, + }, + }, + }, + }, } for _, c := range cases { - ret, err := NewParser(c.tokens).Parse() - if err != nil { - t.Errorf("expected no err got err %s", err) - } - if !reflect.DeepEqual(ret, c.expected) { - t.Errorf("expected %#v got %#v", c.expected, ret) - } + t.Run(c.name, func(t *testing.T) { + ret, err := NewParser(c.tokens).Parse() + if err != nil { + t.Errorf("expected no err got err %s", err) + } + if !reflect.DeepEqual(ret, c.expected) { + t.Errorf("expected %#v got %#v", c.expected, ret) + } + }) } } diff --git a/db/db_test.go b/db/db_test.go index bce71a5..4a4fcfb 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -157,6 +157,25 @@ func TestAddColumns(t *testing.T) { } } +func TestSelectHeaders(t *testing.T) { + db := mustCreateDB(t) + mustExecute(t, db, "CREATE TABLE test (id INTEGER PRIMARY KEY, val INTEGER)") + mustExecute(t, db, "INSERT INTO test (val) VALUES (1)") + res := mustExecute(t, db, "SELECT *, id AS foo FROM test") + 
if rowCount := len(res.ResultRows); rowCount != 1 { + t.Fatalf("want 1 row but got %d", rowCount) + } + if got := res.ResultHeader[0]; got != "id" { + t.Fatalf("want id but got %s", got) + } + if got := res.ResultHeader[1]; got != "val" { + t.Fatalf("want val but got %s", got) + } + if got := res.ResultHeader[2]; got != "foo" { + t.Fatalf("want foo but got %s", got) + } +} + func TestSelectWithWhere(t *testing.T) { db := mustCreateDB(t) mustExecute(t, db, "CREATE TABLE test (id INTEGER PRIMARY KEY, val INTEGER)") diff --git a/planner/assert_test.go b/planner/assert_test.go new file mode 100644 index 0000000..9652c65 --- /dev/null +++ b/planner/assert_test.go @@ -0,0 +1,44 @@ +package planner + +import ( + "errors" + "fmt" + "reflect" + + "github.com/chirst/cdb/vm" +) + +// assertCommandsMatch is a helper for tests in the planner package. +func assertCommandsMatch(gotCommands, expectedCommands []vm.Command) error { + didMatch := true + errOutput := "\n" + green := "\033[32m" + red := "\033[31m" + resetColor := "\033[0m" + for i, c := range expectedCommands { + color := green + if !reflect.DeepEqual(c, gotCommands[i]) { + didMatch = false + color = red + } + errOutput += fmt.Sprintf( + "%s%3d got %#v%s\n want %#v\n\n", + color, i, gotCommands[i], resetColor, c, + ) + } + gl := len(gotCommands) + wl := len(expectedCommands) + if gl != wl { + errOutput += red + errOutput += fmt.Sprintf("got %d want %d commands\n", gl, wl) + errOutput += resetColor + didMatch = false + } + // This helper returns an error instead of making the assertion so a fatal + // error will raise at the test site instead of the helper. This also allows + // the caller to differentiate between a fatal or non fatal assertion. 
+ if !didMatch { + return errors.New(errOutput) + } + return nil +} diff --git a/planner/cevisitor.go b/planner/cevisitor.go index 0675450..13122a6 100644 --- a/planner/cevisitor.go +++ b/planner/cevisitor.go @@ -1,15 +1,24 @@ package planner -import "github.com/chirst/cdb/compiler" +import ( + "github.com/chirst/cdb/catalog" + "github.com/chirst/cdb/compiler" +) // catalogExprVisitor assigns catalog information to visited expressions. type catalogExprVisitor struct { - catalog selectCatalog + catalog cevCatalog tableName string err error } -func (c *catalogExprVisitor) Init(catalog selectCatalog, tableName string) { +type cevCatalog interface { + GetColumns(string) ([]string, error) + GetPrimaryKeyColumn(string) (string, error) + GetColumnType(tableName string, columnName string) (catalog.CdbType, error) +} + +func (c *catalogExprVisitor) Init(catalog cevCatalog, tableName string) { c.catalog = catalog c.tableName = tableName } diff --git a/planner/create.go b/planner/create.go index 5e8b831..7165348 100644 --- a/planner/create.go +++ b/planner/create.go @@ -17,23 +17,8 @@ type createCatalog interface { } // createPlanner is capable of generating a logical query plan and a physical -// executionPlan for a create statement. The planners within are separated by -// their responsibility. +// executionPlan for a create statement. type createPlanner struct { - // queryPlanner is responsible for transforming the AST to a logical query - // plan tree. This tree is made up of nodes similar to a relational algebra - // tree. The query planner also performs binding and validation. - queryPlanner *createQueryPlanner - // executionPlanner is responsible for converting the logical query plan - // tree to a bytecode execution plan capable of being run by the virtual - // machine. - executionPlanner *createExecutionPlanner -} - -// createQueryPlanner converts the AST to a logical query plan. 
Along the way it -// validates the statement makes sense with the catalog a process known as -// binding. -type createQueryPlanner struct { // catalog contains the schema catalog createCatalog // stmt contains the AST @@ -41,14 +26,6 @@ type createQueryPlanner struct { // queryPlan contains the query plan being constructed. The root node must // be createNode. queryPlan *createNode -} - -// createExecutionPlanner converts logical nodes to a bytecode execution plan -// that can be run by the vm. -type createExecutionPlanner struct { - // queryPlan contains the logical query plan. The is populated by calling - // QueryPlan. - queryPlan *createNode // executionPlan contains the bytecode execution plan being constructed. // This is populated by calling ExecutionPlan. executionPlan *vm.ExecutionPlan @@ -57,38 +34,34 @@ type createExecutionPlanner struct { // NewCreate creates a planner for the given create statement. func NewCreate(catalog createCatalog, stmt *compiler.CreateStmt) *createPlanner { return &createPlanner{ - queryPlanner: &createQueryPlanner{ - catalog: catalog, - stmt: stmt, - }, - executionPlanner: &createExecutionPlanner{ - executionPlan: vm.NewExecutionPlan( - catalog.GetVersion(), - stmt.Explain, - ), - }, + catalog: catalog, + stmt: stmt, + executionPlan: vm.NewExecutionPlan( + catalog.GetVersion(), + stmt.Explain, + ), } } // QueryPlan generates the query plan for the planner. 
func (p *createPlanner) QueryPlan() (*QueryPlan, error) { - qp, err := p.queryPlanner.getQueryPlan() - if err != nil { - return nil, err - } - p.executionPlanner.queryPlan = p.queryPlanner.queryPlan - return qp, err -} - -func (p *createQueryPlanner) getQueryPlan() (*QueryPlan, error) { + schemaTableRoot := 1 tableExists := p.catalog.TableExists(p.stmt.TableName) if p.stmt.IfNotExists && tableExists { noopCreateNode := &createNode{ - noop: true, - tableName: p.stmt.TableName, + noop: true, + tableName: p.stmt.TableName, + catalogRootPageNumber: schemaTableRoot, + catalogCursorId: 1, } p.queryPlan = noopCreateNode - return newQueryPlan(noopCreateNode, p.stmt.ExplainQueryPlan), nil + qp := newQueryPlan( + noopCreateNode, + p.stmt.ExplainQueryPlan, + transactionTypeWrite, + ) + noopCreateNode.plan = qp + return qp, nil } if tableExists { return nil, errTableExists @@ -98,16 +71,24 @@ func (p *createQueryPlanner) getQueryPlan() (*QueryPlan, error) { return nil, err } createNode := &createNode{ - objectType: "table", - objectName: p.stmt.TableName, - tableName: p.stmt.TableName, - schema: jSchema, + objectType: "table", + objectName: p.stmt.TableName, + tableName: p.stmt.TableName, + schema: jSchema, + catalogRootPageNumber: schemaTableRoot, + catalogCursorId: 1, } p.queryPlan = createNode - return newQueryPlan(createNode, p.stmt.ExplainQueryPlan), nil + qp := newQueryPlan( + createNode, + p.stmt.ExplainQueryPlan, + transactionTypeWrite, + ) + createNode.plan = qp + return qp, nil } -func (p *createQueryPlanner) getSchemaString() (string, error) { +func (p *createPlanner) getSchemaString() (string, error) { if err := p.ensurePrimaryKeyCount(); err != nil { return "", err } @@ -124,7 +105,7 @@ func (p *createQueryPlanner) getSchemaString() (string, error) { // The id column must be an integer. The index key is capable of being something // other than an integer, but is not worth implementing at the moment. 
Integer // primary keys are superior for auto incrementing and being unique. -func (p *createQueryPlanner) ensurePrimaryKeyInteger() error { +func (p *createPlanner) ensurePrimaryKeyInteger() error { hasPK := slices.ContainsFunc(p.stmt.ColDefs, func(cd compiler.ColDef) bool { return cd.PrimaryKey }) @@ -141,7 +122,7 @@ func (p *createQueryPlanner) ensurePrimaryKeyInteger() error { } // Only one primary key is supported at this time. -func (p *createQueryPlanner) ensurePrimaryKeyCount() error { +func (p *createPlanner) ensurePrimaryKeyCount() error { count := 0 for _, cd := range p.stmt.ColDefs { if cd.PrimaryKey { @@ -154,7 +135,7 @@ func (p *createQueryPlanner) ensurePrimaryKeyCount() error { return nil } -func (p *createQueryPlanner) schemaFrom() *catalog.TableSchema { +func (p *createPlanner) schemaFrom() *catalog.TableSchema { schema := catalog.TableSchema{ Columns: []catalog.TableColumn{}, } @@ -172,43 +153,13 @@ func (p *createQueryPlanner) schemaFrom() *catalog.TableSchema { // QueryPlan is not a prerequisite to this method as it will be called by // ExecutionPlan if needed. func (p *createPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { - if p.queryPlanner.queryPlan == nil { + if p.queryPlan == nil { _, err := p.QueryPlan() if err != nil { return nil, err } } - if p.queryPlanner.queryPlan.noop { - return p.executionPlanner.getNoopExecutionPlan(), nil - } - return p.executionPlanner.getExecutionPlan(), nil -} - -// getNoopExecutionPlan asserts the query can be ran based on the information -// provided by the catalog. If the catalog were to go out of date this execution -// plan will be recompiled before it is ever ran. 
-func (p *createExecutionPlanner) getNoopExecutionPlan() *vm.ExecutionPlan { - p.executionPlan.Append(&vm.InitCmd{P2: 1}) - p.executionPlan.Append(&vm.TransactionCmd{P2: 1}) - p.executionPlan.Append(&vm.HaltCmd{}) - return p.executionPlan -} - -func (p *createExecutionPlanner) getExecutionPlan() *vm.ExecutionPlan { - const cursorId = 1 - p.executionPlan.Append(&vm.InitCmd{P2: 1}) - p.executionPlan.Append(&vm.TransactionCmd{P2: 1}) - p.executionPlan.Append(&vm.CreateBTreeCmd{P2: 1}) - p.executionPlan.Append(&vm.OpenWriteCmd{P1: cursorId, P2: 1}) - p.executionPlan.Append(&vm.NewRowIdCmd{P1: cursorId, P2: 2}) - p.executionPlan.Append(&vm.StringCmd{P1: 3, P4: p.queryPlan.objectType}) - p.executionPlan.Append(&vm.StringCmd{P1: 4, P4: p.queryPlan.objectName}) - p.executionPlan.Append(&vm.StringCmd{P1: 5, P4: p.queryPlan.tableName}) - p.executionPlan.Append(&vm.CopyCmd{P1: 1, P2: 6}) - p.executionPlan.Append(&vm.StringCmd{P1: 7, P4: string(p.queryPlan.schema)}) - p.executionPlan.Append(&vm.MakeRecordCmd{P1: 3, P2: 5, P3: 8}) - p.executionPlan.Append(&vm.InsertCmd{P1: cursorId, P2: 8, P3: 2}) - p.executionPlan.Append(&vm.ParseSchemaCmd{}) - p.executionPlan.Append(&vm.HaltCmd{}) - return p.executionPlan + p.queryPlan.plan.compile() + p.executionPlan.Commands = p.queryPlan.plan.commands + return p.executionPlan, nil } diff --git a/planner/create_test.go b/planner/create_test.go index 88cd493..2f7895e 100644 --- a/planner/create_test.go +++ b/planner/create_test.go @@ -2,7 +2,6 @@ package planner import ( "errors" - "reflect" "testing" "github.com/chirst/cdb/catalog" @@ -55,10 +54,9 @@ func TestCreateWithNoIDColumn(t *testing.T) { t.Fatalf("failed to convert expected schema to json %s", err) } expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, - &vm.CreateBTreeCmd{P2: 1}, + &vm.InitCmd{P2: 13}, &vm.OpenWriteCmd{P1: 1, P2: 1}, + &vm.CreateBTreeCmd{P2: 1}, &vm.NewRowIdCmd{P1: 1, P2: 2}, &vm.StringCmd{P1: 3, P4: "table"}, &vm.StringCmd{P1: 4, 
P4: "foo"}, @@ -69,15 +67,15 @@ func TestCreateWithNoIDColumn(t *testing.T) { &vm.InsertCmd{P1: 1, P2: 8, P3: 2}, &vm.ParseSchemaCmd{}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.GotoCmd{P2: 1}, } plan, err := NewCreate(mc, stmt).ExecutionPlan() if err != nil { t.Fatal(err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) } } @@ -114,10 +112,9 @@ func TestCreateWithAlternateNamedIDColumn(t *testing.T) { t.Fatalf("failed to convert expected schema to json %s", err) } expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, - &vm.CreateBTreeCmd{P2: 1}, + &vm.InitCmd{P2: 13}, &vm.OpenWriteCmd{P1: 1, P2: 1}, + &vm.CreateBTreeCmd{P2: 1}, &vm.NewRowIdCmd{P1: 1, P2: 2}, &vm.StringCmd{P1: 3, P4: "table"}, &vm.StringCmd{P1: 4, P4: "foo"}, @@ -128,15 +125,15 @@ func TestCreateWithAlternateNamedIDColumn(t *testing.T) { &vm.InsertCmd{P1: 1, P2: 8, P3: 2}, &vm.ParseSchemaCmd{}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.GotoCmd{P2: 1}, } plan, err := NewCreate(mc, stmt).ExecutionPlan() if err != nil { t.Fatal(err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) } } @@ -215,17 +212,16 @@ func TestCreateIfNotExistsNoop(t *testing.T) { } mc := &mockCreateCatalog{tableExistsRes: true} expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, + &vm.InitCmd{P2: 2}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.GotoCmd{P2: 1}, } plan, err := NewCreate(mc, stmt).ExecutionPlan() if err != nil { t.Fatal(err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], 
c) - } + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) } } diff --git a/planner/crvisitor.go b/planner/crvisitor.go deleted file mode 100644 index c9ba35b..0000000 --- a/planner/crvisitor.go +++ /dev/null @@ -1,124 +0,0 @@ -package planner - -import ( - "slices" - "strings" - - "github.com/chirst/cdb/compiler" - "github.com/chirst/cdb/vm" -) - -// constantRegisterVisitor fills constantRegisters with constants in the visited -// node. -type constantRegisterVisitor struct { - // nextOpenRegister counts upwards reserving registers as needed. - nextOpenRegister int - // constantRegisters is a mapping of register value to register index. - constantRegisters map[int]int - // variableRegisters is a mapping of variable indices to registers. - variableRegisters map[int]int - // stringRegisters is a mapping of string constants to registers - stringRegisters map[string]int -} - -func (c *constantRegisterVisitor) Init(openRegister int) { - c.constantRegisters = make(map[int]int) - c.variableRegisters = make(map[int]int) - c.stringRegisters = make(map[string]int) - c.nextOpenRegister = openRegister -} - -// GetRegisterCommands returns commands to fill the current constantRegister -// map. -func (c *constantRegisterVisitor) GetRegisterCommands() []vm.Command { - // Maps are unordered so there is some extra work to keep commands in order. - lc := c.getOrderedLitCommands() - vc := c.getOrderedVariableCommands() - sc := c.getOrderedStringCommands() - return append(lc, append(vc, sc...)...) 
-} - -func (c *constantRegisterVisitor) getOrderedLitCommands() []vm.Command { - unordered := []*vm.IntegerCmd{} - for k := range c.constantRegisters { - unordered = append(unordered, &vm.IntegerCmd{P1: k, P2: c.constantRegisters[k]}) - } - slices.SortFunc(unordered, func(a, b *vm.IntegerCmd) int { - return a.P2 - b.P2 - }) - ret := []vm.Command{} - for _, s := range unordered { - ret = append(ret, vm.Command(s)) - } - return ret -} - -func (c *constantRegisterVisitor) getOrderedVariableCommands() []vm.Command { - unordered := []*vm.VariableCmd{} - for k := range c.variableRegisters { - unordered = append(unordered, &vm.VariableCmd{P1: k, P2: c.variableRegisters[k]}) - } - slices.SortFunc(unordered, func(a, b *vm.VariableCmd) int { - return a.P2 - b.P2 - }) - ret := []vm.Command{} - for _, s := range unordered { - ret = append(ret, vm.Command(s)) - } - return ret -} - -func (c *constantRegisterVisitor) getOrderedStringCommands() []vm.Command { - unordered := []*vm.StringCmd{} - for k := range c.stringRegisters { - unordered = append(unordered, &vm.StringCmd{P1: c.stringRegisters[k], P4: k}) - } - slices.SortFunc(unordered, func(a, b *vm.StringCmd) int { - return strings.Compare(a.P4, b.P4) - }) - ret := []vm.Command{} - for _, s := range unordered { - ret = append(ret, vm.Command(s)) - } - return ret -} - -func (c *constantRegisterVisitor) VisitIntLit(e *compiler.IntLit) { - c.fillRegisterIfNeeded(e.Value) -} - -func (c *constantRegisterVisitor) fillRegisterIfNeeded(v int) { - found := false - for k := range c.constantRegisters { - if k == v { - found = true - } - } - if !found { - c.constantRegisters[v] = c.nextOpenRegister - c.nextOpenRegister += 1 - } -} - -func (c *constantRegisterVisitor) VisitVariable(v *compiler.Variable) { - c.variableRegisters[v.Position] = c.nextOpenRegister - c.nextOpenRegister += 1 -} - -func (c *constantRegisterVisitor) VisitStringLit(e *compiler.StringLit) { - found := false - for k := range c.stringRegisters { - if k == e.Value { - 
found = true - } - } - if !found { - c.stringRegisters[e.Value] = c.nextOpenRegister - c.nextOpenRegister += 1 - } -} - -func (c *constantRegisterVisitor) VisitBinaryExpr(e *compiler.BinaryExpr) {} -func (c *constantRegisterVisitor) VisitUnaryExpr(e *compiler.UnaryExpr) {} -func (c *constantRegisterVisitor) VisitColumnRefExpr(e *compiler.ColumnRef) {} -func (c *constantRegisterVisitor) VisitFunctionExpr(e *compiler.FunctionExpr) {} diff --git a/planner/generator.go b/planner/generator.go new file mode 100644 index 0000000..aec6da3 --- /dev/null +++ b/planner/generator.go @@ -0,0 +1,221 @@ +package planner + +import ( + "github.com/chirst/cdb/vm" +) + +func (u *updateNode) produce() { + u.child.produce() +} + +func (u *updateNode) consume() { + // RowID + u.plan.commands = append(u.plan.commands, &vm.RowIdCmd{ + P1: u.cursorId, + P2: u.plan.freeRegister, + }) + rowIdRegister := u.plan.freeRegister + u.plan.freeRegister += 1 + + // Reserve a contiguous block of free registers for the columns. This block + // will be used in makeRecord. 
+ startRecordRegister := u.plan.freeRegister + u.plan.freeRegister += len(u.updateExprs) + recordRegisterCount := len(u.updateExprs) + for i, e := range u.updateExprs { + generateExpressionTo(u.plan, e, startRecordRegister+i, u.cursorId) + } + + // Make the record for inserting + u.plan.commands = append(u.plan.commands, &vm.MakeRecordCmd{ + P1: startRecordRegister, + P2: recordRegisterCount, + P3: u.plan.freeRegister, + }) + recordRegister := u.plan.freeRegister + u.plan.freeRegister += 1 + + // Update by deleting then inserting + u.plan.commands = append(u.plan.commands, &vm.DeleteCmd{ + P1: u.cursorId, + }) + u.plan.commands = append(u.plan.commands, &vm.InsertCmd{ + P1: u.cursorId, + P2: recordRegister, + P3: rowIdRegister, + }) +} + +func (f *filterNode) produce() { + f.child.produce() +} + +func (f *filterNode) consume() { + jumpCommand := generatePredicate(f.plan, f.predicate, f.cursorId) + f.parent.consume() + jumpCommand.SetJumpAddress(len(f.plan.commands)) +} + +func (s *scanNode) produce() { + s.consume() +} + +func (s *scanNode) consume() { + if s.isWriteCursor { + s.plan.commands = append( + s.plan.commands, + &vm.OpenWriteCmd{P1: s.cursorId, P2: s.rootPageNumber}, + ) + } else { + s.plan.commands = append( + s.plan.commands, + &vm.OpenReadCmd{P1: s.cursorId, P2: s.rootPageNumber}, + ) + } + rewindCmd := &vm.RewindCmd{P1: s.cursorId} + s.plan.commands = append(s.plan.commands, rewindCmd) + loopBeginAddress := len(s.plan.commands) + s.parent.consume() + s.plan.commands = append(s.plan.commands, &vm.NextCmd{ + P1: s.cursorId, + P2: loopBeginAddress, + }) + rewindCmd.P2 = len(s.plan.commands) +} + +func (p *projectNode) produce() { + p.child.produce() +} + +func (p *projectNode) consume() { + startRegister := p.plan.freeRegister + reservedRegisters := len(p.projections) + p.plan.freeRegister += reservedRegisters + for i, projection := range p.projections { + generateExpressionTo(p.plan, projection.expr, startRegister+i, p.cursorId) + } + p.plan.commands = 
append(p.plan.commands, &vm.ResultRowCmd{ + P1: startRegister, + P2: reservedRegisters, + }) +} + +func (c *constantNode) produce() { + c.consume() +} + +func (c *constantNode) consume() { + c.parent.consume() +} + +func (c *countNode) produce() { + c.consume() +} + +func (c *countNode) consume() { + c.plan.commands = append( + c.plan.commands, + &vm.OpenReadCmd{P1: c.cursorId, P2: c.rootPageNumber}, + ) + c.plan.commands = append(c.plan.commands, &vm.CountCmd{ + P1: c.cursorId, + P2: c.plan.freeRegister, + }) + countRegister := c.plan.freeRegister + countResults := 1 + c.plan.freeRegister += 1 + c.plan.commands = append(c.plan.commands, &vm.ResultRowCmd{ + P1: countRegister, + P2: countResults, + }) +} + +func (c *createNode) produce() { + c.consume() +} + +func (c *createNode) consume() { + if c.noop { + return + } + c.plan.commands = append( + c.plan.commands, + &vm.OpenWriteCmd{P1: c.catalogCursorId, P2: c.catalogRootPageNumber}, + ) + c.plan.commands = append(c.plan.commands, &vm.CreateBTreeCmd{P2: 1}) + c.plan.commands = append(c.plan.commands, &vm.NewRowIdCmd{P1: c.catalogCursorId, P2: 2}) + c.plan.commands = append(c.plan.commands, &vm.StringCmd{P1: 3, P4: c.objectType}) + c.plan.commands = append(c.plan.commands, &vm.StringCmd{P1: 4, P4: c.objectName}) + c.plan.commands = append(c.plan.commands, &vm.StringCmd{P1: 5, P4: c.tableName}) + c.plan.commands = append(c.plan.commands, &vm.CopyCmd{P1: 1, P2: 6}) + c.plan.commands = append(c.plan.commands, &vm.StringCmd{P1: 7, P4: string(c.schema)}) + c.plan.commands = append(c.plan.commands, &vm.MakeRecordCmd{P1: 3, P2: 5, P3: 8}) + c.plan.commands = append(c.plan.commands, &vm.InsertCmd{P1: c.catalogCursorId, P2: 8, P3: 2}) + c.plan.commands = append(c.plan.commands, &vm.ParseSchemaCmd{}) +} + +func (n *insertNode) produce() { + n.consume() +} + +func (n *insertNode) consume() { + n.plan.commands = append( + n.plan.commands, + &vm.OpenWriteCmd{P1: n.cursorId, P2: n.rootPageNumber}, + ) + for valuesIdx := range 
len(n.colValues) { + // Set up rowid and its uniqueness/type checks + pkRegister := n.plan.freeRegister + n.plan.freeRegister += 1 + if n.autoPk { + n.plan.commands = append(n.plan.commands, &vm.NewRowIdCmd{ + P1: n.cursorId, + P2: pkRegister, + }) + } else { + generateExpressionTo(n.plan, n.pkValues[valuesIdx], pkRegister, n.cursorId) + n.plan.commands = append(n.plan.commands, &vm.MustBeIntCmd{P1: pkRegister}) + nec := &vm.NotExistsCmd{ + P1: n.cursorId, + P3: pkRegister, + } + n.plan.commands = append(n.plan.commands, nec) + n.plan.commands = append(n.plan.commands, &vm.HaltCmd{ + P1: 1, + P4: pkConstraint, + }) + nec.P2 = len(n.plan.commands) + } + + // Reserve registers and make values segment for MakeRecord + startRegister := n.plan.freeRegister + reservedRegisters := len(n.colValues[valuesIdx]) + n.plan.freeRegister += reservedRegisters + for vi := range n.colValues[valuesIdx] { + generateExpressionTo( + n.plan, + n.colValues[valuesIdx][vi], + startRegister+vi, + n.cursorId, + ) + } + + // Insert + n.plan.commands = append(n.plan.commands, &vm.MakeRecordCmd{ + P1: startRegister, + P2: reservedRegisters, + P3: n.plan.freeRegister, + }) + recordRegister := n.plan.freeRegister + n.plan.freeRegister += 1 + n.plan.commands = append(n.plan.commands, &vm.InsertCmd{ + P1: n.cursorId, + P2: recordRegister, + P3: pkRegister, + }) + } +} + +func (n *joinNode) produce() {} + +func (n *joinNode) consume() {} diff --git a/planner/insert.go b/planner/insert.go index 912205e..401a884 100644 --- a/planner/insert.go +++ b/planner/insert.go @@ -1,7 +1,6 @@ package planner import ( - "errors" "fmt" "slices" @@ -24,19 +23,6 @@ type insertCatalog interface { // insertPlanner consists of planners capable of generating a logical query plan // tree and bytecode execution plan for a insert statement. type insertPlanner struct { - // The query planner generates a logical query plan tree made up of nodes - // similar to relational algebra operators.
The query planner performs - // validation while building the tree. Otherwise known as binding. - queryPlanner *insertQueryPlanner - // The executionPlanner transforms the logical query plan tree to a bytecode - // execution plan that can be ran by the virtual machine. - executionPlanner *insertExecutionPlanner -} - -// insertQueryPlanner converts the AST generated by the compiler to a logical -// query plan tree. It is also responsible for validating the AST against the -// system catalog. -type insertQueryPlanner struct { // catalog contains the schema. catalog insertCatalog // stmt contains the AST. @@ -44,14 +30,6 @@ type insertQueryPlanner struct { // queryPlan contains the query plan being constructed. For an insert, the // root node must be an insertNode. queryPlan *insertNode -} - -// insertExecutionPlanner converts the logical query plan to a bytecode routine -// to be ran by the vm. -type insertExecutionPlanner struct { - // queryPlan contains the query plan generated by the query planner's - // QueryPlan method. - queryPlan *insertNode // executionPlan contains the execution plan generated by calling // ExecutionPlan. executionPlan *vm.ExecutionPlan @@ -60,162 +38,115 @@ type insertExecutionPlanner struct { // NewInsert returns an instance of an insert planner for the given AST. func NewInsert(catalog insertCatalog, stmt *compiler.InsertStmt) *insertPlanner { return &insertPlanner{ - queryPlanner: &insertQueryPlanner{ - catalog: catalog, - stmt: stmt, - }, - executionPlanner: &insertExecutionPlanner{ - executionPlan: vm.NewExecutionPlan( - catalog.GetVersion(), - stmt.Explain, - ), - }, + catalog: catalog, + stmt: stmt, + executionPlan: vm.NewExecutionPlan( + catalog.GetVersion(), + stmt.Explain, + ), } } // QueryPlan generates the query plan tree for the planner. 
func (p *insertPlanner) QueryPlan() (*QueryPlan, error) { - qp, err := p.queryPlanner.getQueryPlan() - if err != nil { - return nil, err - } - p.executionPlanner.queryPlan = p.queryPlanner.queryPlan - return qp, err -} - -func (p *insertQueryPlanner) getQueryPlan() (*QueryPlan, error) { rootPage, err := p.catalog.GetRootPageNumber(p.stmt.TableName) if err != nil { return nil, errTableNotExist } - catalogColumnNames, err := p.catalog.GetColumns(p.stmt.TableName) - if err != nil { - return nil, err - } - if err := checkValuesMatchColumns(p.stmt); err != nil { + if err := p.checkValuesMatchColumns(p.stmt); err != nil { return nil, err } - pkColumn, err := p.catalog.GetPrimaryKeyColumn(p.stmt.TableName) + colValues, err := p.getNonPkValues() if err != nil { return nil, err } insertNode := &insertNode{ - rootPage: rootPage, - catalogColumnNames: catalogColumnNames, - pkColumn: pkColumn, - colNames: p.stmt.ColNames, - colValues: p.stmt.ColValues, + colValues: colValues, + rootPageNumber: rootPage, + tableName: p.stmt.TableName, + cursorId: 1, + } + if err := p.setPkValues(insertNode); err != nil { + return nil, err } p.queryPlan = insertNode - return newQueryPlan(insertNode, p.stmt.ExplainQueryPlan), nil + qp := newQueryPlan( + insertNode, + p.stmt.ExplainQueryPlan, + transactionTypeWrite, + ) + insertNode.plan = qp + return qp, nil } -// ExecutionPlan returns the bytecode routine for the planner. Calling QueryPlan -// is not prerequisite to calling ExecutionPlan as ExecutionPlan will be called -// as needed. 
-func (p *insertPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { - if p.queryPlanner.queryPlan == nil { - _, err := p.QueryPlan() - if err != nil { - return nil, err +func (p *insertPlanner) setPkValues(n *insertNode) error { + pkColumnName, err := p.catalog.GetPrimaryKeyColumn(p.stmt.TableName) + if err != nil { + return err + } + statementPkIdx := -1 + if pkColumnName != "" { + statementPkIdx = slices.IndexFunc(p.stmt.ColNames, func(s string) bool { + return s == pkColumnName + }) + } + if statementPkIdx == -1 { + n.autoPk = true + } else { + n.autoPk = false + n.pkValues = []compiler.Expr{} + for _, v := range p.stmt.ColValues { + n.pkValues = append(n.pkValues, v[statementPkIdx]) } } - return p.executionPlanner.getExecutionPlan() + return nil } -func (p *insertExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { - p.buildInit() - cursorId := p.openWrite() - for valueIdx := range len(p.queryPlan.colValues) { - // For simplicity, the primary key is in the first register. - const keyRegister = 1 - if err := p.buildPrimaryKey(cursorId, keyRegister, valueIdx); err != nil { - return nil, err - } - registerIdx := keyRegister - for _, catalogColumnName := range p.queryPlan.catalogColumnNames { - if catalogColumnName != "" && catalogColumnName == p.queryPlan.pkColumn { - // Skip the primary key column since it is handled before. 
+func (p *insertPlanner) getNonPkValues() ([][]compiler.Expr, error) { + pkColumnName, err := p.catalog.GetPrimaryKeyColumn(p.stmt.TableName) + if err != nil { + return nil, err + } + catalogColumnNames, err := p.catalog.GetColumns(p.stmt.TableName) + if err != nil { + return nil, err + } + resultValues := [][]compiler.Expr{} + for _, colValue := range p.stmt.ColValues { + resultValue := []compiler.Expr{} + for _, cn := range catalogColumnNames { + if cn == pkColumnName { continue } - registerIdx += 1 - if err := p.buildNonPkValue(valueIdx, registerIdx, catalogColumnName); err != nil { - return nil, err + stmtColIdx := slices.IndexFunc(p.stmt.ColNames, func(stmtColName string) bool { + return stmtColName == cn + }) + if stmtColIdx == -1 { + return nil, fmt.Errorf("%w %s", errMissingColumnName, cn) } + resultValue = append(resultValue, colValue[stmtColIdx]) } - p.executionPlan.Append(&vm.MakeRecordCmd{P1: 2, P2: registerIdx - 1, P3: registerIdx + 1}) - p.executionPlan.Append(&vm.InsertCmd{P1: cursorId, P2: registerIdx + 1, P3: keyRegister}) - } - p.executionPlan.Append(&vm.HaltCmd{}) - return p.executionPlan, nil -} - -func (p *insertExecutionPlanner) buildInit() { - p.executionPlan.Append(&vm.InitCmd{P2: 1}) - p.executionPlan.Append(&vm.TransactionCmd{P2: 1}) -} - -func (p *insertExecutionPlanner) openWrite() int { - const cursorId = 1 - p.executionPlan.Append(&vm.OpenWriteCmd{P1: cursorId, P2: p.queryPlan.rootPage}) - return cursorId -} - -func (p *insertExecutionPlanner) buildPrimaryKey(writeCursorId, keyRegister, valueIdx int) error { - // If the table has a user defined pk column it needs to be looked up in the - // user defined column list. If the user has defined the pk column the - // execution plan will involve checking the uniqueness of the pk during - // execution. Otherwise the system guarantees a unique key. 
- statementPkIdx := -1 - if p.queryPlan.pkColumn != "" { - statementPkIdx = slices.IndexFunc(p.queryPlan.colNames, func(s string) bool { - return s == p.queryPlan.pkColumn - }) - } - if statementPkIdx == -1 { - p.executionPlan.Append(&vm.NewRowIdCmd{P1: writeCursorId, P2: keyRegister}) - return nil - } - switch rv := p.queryPlan.colValues[valueIdx][statementPkIdx].(type) { - case *compiler.IntLit: - p.executionPlan.Append(&vm.IntegerCmd{P1: rv.Value, P2: keyRegister}) - case *compiler.Variable: - // TODO must be int could likely be used more to enforce schema types. - p.executionPlan.Append(&vm.VariableCmd{P1: rv.Position, P2: keyRegister}) - p.executionPlan.Append(&vm.MustBeIntCmd{P1: keyRegister}) - default: - return errors.New("unsupported row id value") + resultValues = append(resultValues, resultValue) } - continueIdx := len(p.executionPlan.Commands) + 2 - p.executionPlan.Append(&vm.NotExistsCmd{P1: writeCursorId, P2: continueIdx, P3: keyRegister}) - p.executionPlan.Append(&vm.HaltCmd{P1: 1, P4: pkConstraint}) - return nil + return resultValues, nil } -func (p *insertExecutionPlanner) buildNonPkValue(valueIdx, registerIdx int, catalogColumnName string) error { - // Get the statement index of the column name. Because the name positions - // can mismatch the table column positions. - stmtColIdx := slices.IndexFunc(p.queryPlan.colNames, func(stmtColName string) bool { - return stmtColName == catalogColumnName - }) - // Requires the statement to define a value for each column in the table. 
- if stmtColIdx == -1 { - return fmt.Errorf("%w %s", errMissingColumnName, catalogColumnName) - } - switch cv := p.queryPlan.colValues[valueIdx][stmtColIdx].(type) { - case *compiler.StringLit: - p.executionPlan.Append(&vm.StringCmd{P1: registerIdx, P4: cv.Value}) - case *compiler.IntLit: - p.executionPlan.Append(&vm.IntegerCmd{P1: cv.Value, P2: registerIdx}) - case *compiler.Variable: - p.executionPlan.Append(&vm.VariableCmd{P1: cv.Position, P2: registerIdx}) - default: - return errors.New("unsupported type of value") +// ExecutionPlan returns the bytecode routine for the planner. Calling QueryPlan +// is not prerequisite to calling ExecutionPlan as QueryPlan will be called +// as needed. +func (p *insertPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { + if p.queryPlan == nil { + _, err := p.QueryPlan() + if err != nil { + return nil, err + } } - return nil + p.queryPlan.plan.compile() + p.executionPlan.Commands = p.queryPlan.plan.commands + return p.executionPlan, nil } -func checkValuesMatchColumns(s *compiler.InsertStmt) error { +func (p *insertPlanner) checkValuesMatchColumns(s *compiler.InsertStmt) error { cl := len(s.ColNames) for _, cv := range s.ColValues { if cl != len(cv) { diff --git a/planner/insert_test.go b/planner/insert_test.go index 8e4cb10..6130788 100644 --- a/planner/insert_test.go +++ b/planner/insert_test.go @@ -2,7 +2,6 @@ package planner import ( "errors" - "reflect" "testing" "github.com/chirst/cdb/compiler" @@ -38,25 +37,32 @@ func (m *mockInsertCatalog) GetPrimaryKeyColumn(tableName string) (string, error func TestInsertWithoutPrimaryKey(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, + &vm.InitCmd{P2: 18}, &vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.NewRowIdCmd{P1: 1, P2: 1}, - &vm.StringCmd{P1: 2, P4: "gud"}, - &vm.StringCmd{P1: 3, P4: "dude"}, - &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 4}, - &vm.InsertCmd{P1: 1, P2: 4, P3: 1}, - &vm.NewRowIdCmd{P1: 1, P2: 1}, - &vm.StringCmd{P1: 2, P4:
"joe"}, - &vm.StringCmd{P1: 3, P4: "doe"}, - &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 4}, - &vm.InsertCmd{P1: 1, P2: 4, P3: 1}, - &vm.NewRowIdCmd{P1: 1, P2: 1}, - &vm.StringCmd{P1: 2, P4: "jan"}, - &vm.StringCmd{P1: 3, P4: "ice"}, - &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 4}, - &vm.InsertCmd{P1: 1, P2: 4, P3: 1}, + &vm.CopyCmd{P1: 4, P2: 2}, + &vm.CopyCmd{P1: 5, P2: 3}, + &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 6}, + &vm.InsertCmd{P1: 1, P2: 6, P3: 1}, + &vm.NewRowIdCmd{P1: 1, P2: 7}, + &vm.CopyCmd{P1: 10, P2: 8}, + &vm.CopyCmd{P1: 11, P2: 9}, + &vm.MakeRecordCmd{P1: 8, P2: 2, P3: 12}, + &vm.InsertCmd{P1: 1, P2: 12, P3: 7}, + &vm.NewRowIdCmd{P1: 1, P2: 13}, + &vm.CopyCmd{P1: 16, P2: 14}, + &vm.CopyCmd{P1: 17, P2: 15}, + &vm.MakeRecordCmd{P1: 14, P2: 2, P3: 18}, + &vm.InsertCmd{P1: 1, P2: 18, P3: 13}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.StringCmd{P1: 4, P4: "gud"}, + &vm.StringCmd{P1: 5, P4: "dude"}, + &vm.StringCmd{P1: 10, P4: "joe"}, + &vm.StringCmd{P1: 11, P4: "doe"}, + &vm.StringCmd{P1: 16, P4: "jan"}, + &vm.StringCmd{P1: 17, P4: "ice"}, + &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ @@ -87,25 +93,27 @@ func TestInsertWithoutPrimaryKey(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) } } func TestInsertWithPrimaryKey(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, + &vm.InitCmd{P2: 10}, &vm.OpenWriteCmd{P1: 1, P2: 2}, - &vm.IntegerCmd{P1: 22, P2: 1}, + &vm.CopyCmd{P1: 2, P2: 1}, + &vm.MustBeIntCmd{P1: 1}, &vm.NotExistsCmd{P1: 1, P2: 6, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, - &vm.StringCmd{P1: 2, P4: "gud"}, - &vm.MakeRecordCmd{P1: 2, P2: 1, P3: 3}, - &vm.InsertCmd{P1: 1, P2: 3, P3: 1}, + &vm.CopyCmd{P1: 4, P2: 3}, + 
&vm.MakeRecordCmd{P1: 3, P2: 1, P3: 5}, + &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.IntegerCmd{P1: 22, P2: 2}, + &vm.StringCmd{P1: 4, P4: "gud"}, + &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ StmtBase: &compiler.StmtBase{}, @@ -129,25 +137,27 @@ func TestInsertWithPrimaryKey(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) } } func TestInsertWithPrimaryKeyMiddleOrder(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, + &vm.InitCmd{P2: 10}, &vm.OpenWriteCmd{P1: 1, P2: 2}, - &vm.IntegerCmd{P1: 12, P2: 1}, + &vm.CopyCmd{P1: 2, P2: 1}, + &vm.MustBeIntCmd{P1: 1}, &vm.NotExistsCmd{P1: 1, P2: 6, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, - &vm.StringCmd{P1: 2, P4: "feller"}, - &vm.MakeRecordCmd{P1: 2, P2: 1, P3: 3}, - &vm.InsertCmd{P1: 1, P2: 3, P3: 1}, + &vm.CopyCmd{P1: 4, P2: 3}, + &vm.MakeRecordCmd{P1: 3, P2: 1, P3: 5}, + &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.IntegerCmd{P1: 12, P2: 2}, + &vm.StringCmd{P1: 4, P4: "feller"}, + &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ StmtBase: &compiler.StmtBase{}, @@ -171,26 +181,27 @@ func TestInsertWithPrimaryKeyMiddleOrder(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) } } func TestInsertWithPrimaryKeyParameter(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, + &vm.InitCmd{P2: 10}, 
&vm.OpenWriteCmd{P1: 1, P2: 2}, - &vm.VariableCmd{P1: 0, P2: 1}, + &vm.CopyCmd{P1: 2, P2: 1}, &vm.MustBeIntCmd{P1: 1}, - &vm.NotExistsCmd{P1: 1, P2: 7, P3: 1}, + &vm.NotExistsCmd{P1: 1, P2: 6, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, - &vm.StringCmd{P1: 2, P4: "feller"}, - &vm.MakeRecordCmd{P1: 2, P2: 1, P3: 3}, - &vm.InsertCmd{P1: 1, P2: 3, P3: 1}, + &vm.CopyCmd{P1: 4, P2: 3}, + &vm.MakeRecordCmd{P1: 3, P2: 1, P3: 5}, + &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.StringCmd{P1: 4, P4: "feller"}, + &vm.VariableCmd{P1: 0, P2: 2}, + &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ StmtBase: &compiler.StmtBase{}, @@ -214,26 +225,27 @@ func TestInsertWithPrimaryKeyParameter(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) } } func TestInsertWithParameter(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, + &vm.InitCmd{P2: 10}, &vm.OpenWriteCmd{P1: 1, P2: 2}, - &vm.VariableCmd{P1: 0, P2: 1}, + &vm.CopyCmd{P1: 2, P2: 1}, &vm.MustBeIntCmd{P1: 1}, - &vm.NotExistsCmd{P1: 1, P2: 7, P3: 1}, + &vm.NotExistsCmd{P1: 1, P2: 6, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, - &vm.VariableCmd{P1: 1, P2: 2}, - &vm.MakeRecordCmd{P1: 2, P2: 1, P3: 3}, - &vm.InsertCmd{P1: 1, P2: 3, P3: 1}, + &vm.CopyCmd{P1: 4, P2: 3}, + &vm.MakeRecordCmd{P1: 3, P2: 1, P3: 5}, + &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.VariableCmd{P1: 0, P2: 2}, + &vm.VariableCmd{P1: 1, P2: 4}, + &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ StmtBase: &compiler.StmtBase{}, @@ -257,10 +269,8 @@ func TestInsertWithParameter(t *testing.T) { if err != nil { t.Errorf("expected no err 
got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) } } diff --git a/planner/node.go b/planner/node.go index a6d6692..abb4856 100644 --- a/planner/node.go +++ b/planner/node.go @@ -1,59 +1,26 @@ package planner -import "github.com/chirst/cdb/compiler" +import ( + "fmt" + + "github.com/chirst/cdb/compiler" +) // This file defines the relational nodes in a logical query plan. // logicalNode defines the interface for a node in the query plan tree. type logicalNode interface { + // children returns the child nodes. children() []logicalNode + // print returns the string representation for explain. print() string -} - -// projectNode defines what columns should be projected. -type projectNode struct { - projections []projection - child logicalNode -} - -// projection is part of the sum of projections in a project node. -type projection struct { - // isCount signifies the projection is the count function. - isCount bool - // colName is the name of the column to be projected. - colName string -} - -// scanNode represents a full scan on a table -type scanNode struct { - // tableName is the name of the table to be scanned - tableName string - // rootPage is the valid page number corresponding to the table - rootPage int - // scanColumns contains information about how the scan will project columns - scanColumns []scanColumn - // scanPredicate is an expression evaluated as a boolean. This behaves as a - // filter in the scan. - scanPredicate compiler.Expr -} - -type scanColumn = compiler.Expr - -// constantNode is used in select statements where there is no table. -type constantNode struct { - // resultColumns are the result columns containing expressions. - resultColumns []compiler.Expr - // predicate filters the result depending on the result of the expression. 
- predicate compiler.Expr -} - -// countNode represents a special optimization when a table needs a full count -// with no filtering or other projections. -type countNode struct { - // tableName is the name of the table to be scanned - tableName string - // rootPage is the valid page number corresponding to the table - rootPage int + // produce works with consume to generate byte code in the node's associated + // query plan. produce typically calls its children's produce methods until + // a leaf is reached. When the leaf is reached consume is called which emits + // byte code as the stack unwinds. + produce() + // consume works with produce. + consume() } // TODO joinNode is unused, but remains as a prototype binary operation node. @@ -67,9 +34,18 @@ type joinNode struct { operation string } +func (j *joinNode) print() string { + return fmt.Sprint(j.operation) +} + +func (j *joinNode) children() []logicalNode { + return []logicalNode{j.left, j.right} +} + // createNode represents a operation to create an object in the system catalog. // For example a table, index, or trigger. type createNode struct { + plan *QueryPlan // objectName is the name of the index, trigger, or table. objectName string // objectType could be an index, trigger, or in this case a table. @@ -87,35 +63,179 @@ type createNode struct { // because the query plan would be invalidated given the existence of the object // has changed between query planning and query execution. noop bool + // catalogRootPageNumber is the page number of the system catalog. + catalogRootPageNumber int + // catalogCursorId is the id of the cursor associated with the system + // catalog table being updated.
+ catalogCursorId int +} + +func (c *createNode) print() string { + if c.noop { + return fmt.Sprintf("assert table %s does not exist", c.tableName) + } + return fmt.Sprintf("create table %s", c.tableName) +} + +func (c *createNode) children() []logicalNode { + return []logicalNode{} } // insertNode represents an insert operation. type insertNode struct { - // rootPage is the rootPage of the table the insert is performed on. - rootPage int - // catalogColumnNames are all of the names of columns associated with the - // table. - catalogColumnNames []string - // pkColumn is the name of the primary key column in the catalog. The value - // is empty if no user defined pk. - pkColumn string - // colNames are the names of columns specified in the insert statement. - colNames []string + plan *QueryPlan // colValues are the values specified in the insert statement. It is two // dimensional i.e. VALUES (v1, v2), (v3, v4) is [[v1, v2], [v3, v4]]. + // + // The logical planner must guarantee these values are in the correct + // ordinal position as the code generator will not check. colValues [][]compiler.Expr + // pkValues holds the pk expression separate from colValues for each values + // entry. In case a pkValue wasn't specified in the values list a reasonable + // value will be provided for the code generator or the autoPk will be true. + pkValues []compiler.Expr + // autoPk indicates the generator should use a NewRowIdCmd for pk + // generation. + autoPk bool + // tableName is the name of the table being inserted to. + tableName string + // rootPageNumber is the page number of the table being inserted to. + rootPageNumber int + // cursorId is the id of the cursor associated with the table being inserted + // to. + cursorId int } -// updateNode represents an update operation -type updateNode struct { - // rootPage is the rootPage of the table the update is performed on. 
- rootPage int - // recordExprs is a list of expressions that can be evaluated to make the - // desired record. If a column is in the update set list that column will be - // some sort of expression. If a column is not in the set list it will - // simply be a columnRef. Note the ordering is important of these - // expressions because they have to make the record. - recordExprs []compiler.Expr - // predicate is the where clause or nil +func (i *insertNode) print() string { + return "insert" +} + +func (i *insertNode) children() []logicalNode { + return []logicalNode{} +} + +type countNode struct { + plan *QueryPlan + projection projection + // tableName is the name of the table being scanned. + tableName string + // rootPageNumber is the page number of the table being scanned. + rootPageNumber int + // cursorId is the id of the cursor associated with the table being scanned. + cursorId int +} + +func (c *countNode) children() []logicalNode { + return []logicalNode{} +} + +func (c *countNode) print() string { + return fmt.Sprintf("count table %s", c.tableName) +} + +type constantNode struct { + parent logicalNode + plan *QueryPlan +} + +func (c *constantNode) print() string { + return "constant data source" +} + +func (c *constantNode) children() []logicalNode { + return []logicalNode{} +} + +type projection struct { + expr compiler.Expr + // alias is the alias of the projection or no alias for the zero value. + alias string +} + +type projectNode struct { + child logicalNode + plan *QueryPlan + projections []projection + // cursorId is the id of the cursor associated with the table being + // projected. In the future this will likely need to be enhanced since + // projections are not entirely meant for one table. 
+ cursorId int +} + +func (p *projectNode) print() string { + return "project" +} + +func (p *projectNode) children() []logicalNode { + return []logicalNode{p.child} +} + +type scanNode struct { + parent logicalNode + plan *QueryPlan + // tableName is the name of the table being scanned. + tableName string + // rootPageNumber is the page number of the table being scanned. + rootPageNumber int + // cursorId is the id of the cursor associated with the table being scanned. + cursorId int + // isWriteCursor is true when the cursor should be a write cursor. + isWriteCursor bool +} + +func (s *scanNode) print() string { + return fmt.Sprintf("scan table %s", s.tableName) +} + +func (s *scanNode) children() []logicalNode { + return []logicalNode{} +} + +type filterNode struct { + child logicalNode + parent logicalNode + plan *QueryPlan predicate compiler.Expr + // cursorId is the id of the cursor associated with the table being filtered. + // In the future this will likely need to be enhanced since filters are not + // entirely meant for one table. + cursorId int +} + +func (f *filterNode) print() string { + return "filter" +} + +func (f *filterNode) children() []logicalNode { + return []logicalNode{f.child} +} + +type updateNode struct { + child logicalNode + plan *QueryPlan + // updateExprs is formed from the update statement AST. The idea is to + // provide an expression for each column where the expression is either a + // columnRef or the complex expression from the right hand side of the SET + // keyword. Note it is important to provide the expressions in their correct + // ordinal position as the generator will not try to order them correctly. + // + // The row id is not allowed to be updated at the moment because it could + // cause infinite loops due to it changing the physical location of the + // record. The query plan will have to use a temporary storage to update + // primary keys. 
+ updateExprs []compiler.Expr + // tableName is the name of the table being updated. + tableName string + // rootPageNumber is the page number of the table being updated. + rootPageNumber int + // cursorId is the id of the cursor associated with the table being updated. + cursorId int +} + +func (u *updateNode) print() string { + return fmt.Sprintf("update table %s", u.tableName) +} + +func (u *updateNode) children() []logicalNode { + return []logicalNode{} } diff --git a/planner/plan.go b/planner/plan.go index 279183c..222e431 100644 --- a/planner/plan.go +++ b/planner/plan.go @@ -2,26 +2,177 @@ package planner import ( "fmt" + "slices" "strings" "unicode/utf8" + + "github.com/chirst/cdb/vm" ) -// QueryPlan contains the query plan tree. It is capable of converting the tree -// to a string representation for a query prefixed with `EXPLAIN QUERY PLAN`. +// QueryPlan contains the query plan tree stemming from the root node. It is +// capable of converting the tree to a string representation for a query +// prefixed with `EXPLAIN QUERY PLAN`. +// +// The structure also holds the necessary data and receivers for generating a +// plan as well as the final commands that define the execution plan. type QueryPlan struct { // plan holds the string representation also known as the tree. plan string - // root holds the root node of the query plan + // root is the root node of the plan tree. root logicalNode // ExplainQueryPlan is a flag indicating if the SQL asked for the query plan // to be printed as a string representation with `EXPLAIN QUERY PLAN`. ExplainQueryPlan bool -} - -func newQueryPlan(root logicalNode, explainQueryPlan bool) *QueryPlan { + // commands is a list of commands that define the plan. + commands []vm.Command + // constInts is a mapping of constant integer values to the registers that + // contain the value. + constInts map[int]int + // constStrings is a mapping of constant string values to the registers that + // contain the value. 
+ constStrings map[string]int + // constVars is a mapping of a variable's position to the registers that + // holds the variable's value. + constVars map[int]int + // freeRegister is a counter containing the next free register in the plan. + freeRegister int + // transactionType defines what kind of transaction the plan will need. + transactionType transactionType +} + +func newQueryPlan( + root logicalNode, + explainQueryPlan bool, + transactionType transactionType, +) *QueryPlan { return &QueryPlan{ root: root, ExplainQueryPlan: explainQueryPlan, + commands: []vm.Command{}, + constInts: make(map[int]int), + constStrings: make(map[string]int), + constVars: make(map[int]int), + freeRegister: 1, + transactionType: transactionType, + } +} + +// transactionType defines possible transactions for a query plan. +type transactionType int + +const ( + transactionTypeNone transactionType = 0 + transactionTypeRead transactionType = 1 + transactionTypeWrite transactionType = 2 +) + +// declareConstInt gets or sets a register with the const value and returns the +// register. It is guaranteed the value will be in the register for the duration +// of the plan. +func (p *QueryPlan) declareConstInt(i int) int { + _, ok := p.constInts[i] + if !ok { + p.constInts[i] = p.freeRegister + p.freeRegister += 1 + } + return p.constInts[i] +} + +// declareConstString gets or sets a register with the const value and returns +// the register. It is guaranteed the value will be in the register for the +// duration of the plan. +func (p *QueryPlan) declareConstString(s string) int { + _, ok := p.constStrings[s] + if !ok { + p.constStrings[s] = p.freeRegister + p.freeRegister += 1 + } + return p.constStrings[s] +} + +// declareConstVar gets or sets a register with the const value and returns +// the register. It is guaranteed the value will be in the register for the +// duration of the plan. 
+func (p *QueryPlan) declareConstVar(position int) int { + _, ok := p.constVars[position] + if !ok { + p.constVars[position] = p.freeRegister + p.freeRegister += 1 + } + return p.constVars[position] +} + +// compile sets byte code for the root node and its children on commands. +func (p *QueryPlan) compile() { + initCmd := &vm.InitCmd{} + p.commands = append(p.commands, initCmd) + p.root.produce() + p.commands = append(p.commands, &vm.HaltCmd{}) + initCmd.P2 = len(p.commands) + p.pushTransaction() + // these constants are pushed ordered since maps are unordered making it + // difficult to assert that a sequence of instructions appears. + p.pushConstantInts() + p.pushConstantStrings() + p.pushConstantVars() + p.commands = append(p.commands, &vm.GotoCmd{P2: 1}) +} + +func (p *QueryPlan) pushTransaction() { + switch p.transactionType { + case transactionTypeNone: + return + case transactionTypeRead: + p.commands = append( + p.commands, + &vm.TransactionCmd{P2: 0}, + ) + case transactionTypeWrite: + p.commands = append( + p.commands, + &vm.TransactionCmd{P2: 1}, + ) + default: + panic("unexpected transaction type") + } +} + +func (p *QueryPlan) pushConstantInts() { + temp := []*vm.IntegerCmd{} + for k := range p.constInts { + temp = append(temp, &vm.IntegerCmd{P1: k, P2: p.constInts[k]}) + } + slices.SortFunc(temp, func(a, b *vm.IntegerCmd) int { + return a.P2 - b.P2 + }) + for i := range temp { + p.commands = append(p.commands, temp[i]) + } +} + +func (p *QueryPlan) pushConstantStrings() { + temp := []*vm.StringCmd{} + for v := range p.constStrings { + temp = append(temp, &vm.StringCmd{P1: p.constStrings[v], P4: v}) + } + slices.SortFunc(temp, func(a, b *vm.StringCmd) int { + return a.P1 - b.P1 + }) + for i := range temp { + p.commands = append(p.commands, temp[i]) + } +} + +func (p *QueryPlan) pushConstantVars() { + temp := []*vm.VariableCmd{} + for v := range p.constVars { + temp = append(temp, &vm.VariableCmd{P1: v, P2: p.constVars[v]}) + } + slices.SortFunc(temp,
func(a, b *vm.VariableCmd) int { + return a.P2 - b.P2 + }) + for i := range temp { + p.commands = append(p.commands, temp[i]) } } @@ -100,91 +251,3 @@ func (p *QueryPlan) connectSiblings() string { } return strings.Join(planMatrix, "\n") } - -func (p *projectNode) print() string { - list := "(" - for i, proj := range p.projections { - list += proj.print() - if i+1 < len(p.projections) { - list += ", " - } - } - list += ")" - return "project" + list -} - -func (p *projection) print() string { - if p.isCount { - return "count(*)" - } - if p.colName == "" { - return "" - } - return p.colName -} - -func (s *scanNode) print() string { - if s.scanPredicate != nil { - return fmt.Sprintf("scan table %s with predicate", s.tableName) - } - return fmt.Sprintf("scan table %s", s.tableName) -} - -func (c *constantNode) print() string { - return "constant data source" -} - -func (c *countNode) print() string { - return fmt.Sprintf("count table %s", c.tableName) -} - -func (j *joinNode) print() string { - return fmt.Sprint(j.operation) -} - -func (c *createNode) print() string { - if c.noop { - return fmt.Sprintf("assert table %s does not exist", c.tableName) - } - return fmt.Sprintf("create table %s", c.tableName) -} - -func (i *insertNode) print() string { - return "insert" -} - -func (u *updateNode) print() string { - return "update" -} - -func (p *projectNode) children() []logicalNode { - return []logicalNode{p.child} -} - -func (s *scanNode) children() []logicalNode { - return []logicalNode{} -} - -func (c *constantNode) children() []logicalNode { - return []logicalNode{} -} - -func (c *countNode) children() []logicalNode { - return []logicalNode{} -} - -func (j *joinNode) children() []logicalNode { - return []logicalNode{j.left, j.right} -} - -func (c *createNode) children() []logicalNode { - return []logicalNode{} -} - -func (i *insertNode) children() []logicalNode { - return []logicalNode{} -} - -func (u *updateNode) children() []logicalNode { - return []logicalNode{} -} 
diff --git a/planner/plan_test.go b/planner/plan_test.go index 2fcbaf4..e948777 100644 --- a/planner/plan_test.go +++ b/planner/plan_test.go @@ -4,11 +4,6 @@ import "testing" func TestExplainQueryPlan(t *testing.T) { root := &projectNode{ - projections: []projection{ - {colName: "id"}, - {colName: "first_name"}, - {colName: "last_name"}, - }, child: &joinNode{ operation: "join", left: &joinNode{ @@ -19,29 +14,29 @@ func TestExplainQueryPlan(t *testing.T) { right: &joinNode{ operation: "join", left: &scanNode{ - tableName: "baz", + tableName: "bar", }, right: &scanNode{ - tableName: "buzz", + tableName: "baz", }, }, }, right: &scanNode{ - tableName: "bar", + tableName: "buzz", }, }, } - qp := newQueryPlan(root, true) + qp := newQueryPlan(root, true, transactionTypeRead) formattedResult := qp.ToString() expectedResult := "" + - " ── project(id, first_name, last_name)\n" + + " ── project\n" + " └─ join\n" + " ├─ join\n" + " | ├─ scan table foo\n" + " | └─ join\n" + - " | ├─ scan table baz\n" + - " | └─ scan table buzz\n" + - " └─ scan table bar\n" + " | ├─ scan table bar\n" + + " | └─ scan table baz\n" + + " └─ scan table buzz\n" if formattedResult != expectedResult { t.Fatalf("got\n%s\nwant\n%s", formattedResult, expectedResult) } diff --git a/planner/predicate_generator.go b/planner/predicate_generator.go new file mode 100644 index 0000000..3a919f2 --- /dev/null +++ b/planner/predicate_generator.go @@ -0,0 +1,208 @@ +package planner + +import ( + "github.com/chirst/cdb/compiler" + "github.com/chirst/cdb/vm" +) + +// generatePredicate generates code to make a boolean jump for the given +// expression within the plan context. The function returns the jump command to +// lazily set the jump address. 
+func generatePredicate(plan *QueryPlan, expression compiler.Expr, cursorId int) vm.JumpCommand { + pg := &predicateGenerator{} + pg.plan = plan + pg.cursorId = cursorId + pg.build(expression, 0) + return pg.jumpCommand +} + +// predicateGenerator builds commands to calculate the boolean result of an +// expression. +type predicateGenerator struct { + plan *QueryPlan + // jumpCommand is the command used to make the jump. The command can be + // accessed to defer setting the jump address. + jumpCommand vm.JumpCommand + // cursorId is the id of the cursor for the table in the associated query. + // This will need to be enhanced at some point to support more than one + // alias in a predicate, but is fine for now. + cursorId int +} + +func (p *predicateGenerator) build(e compiler.Expr, level int) (int, error) { + switch ce := e.(type) { + case *compiler.BinaryExpr: + ol, err := p.build(ce.Left, level+1) + if err != nil { + return 0, err + } + or, err := p.build(ce.Right, level+1) + if err != nil { + return 0, err + } + r := p.getNextRegister() + switch ce.Operator { + case compiler.OpAdd: + p.plan.commands = append( + p.plan.commands, + &vm.AddCmd{P1: ol, P2: or, P3: r}, + ) + if level == 0 { + jc := &vm.IfNotCmd{P1: r} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return r, nil + case compiler.OpDiv: + p.plan.commands = append( + p.plan.commands, + &vm.DivideCmd{P1: ol, P2: or, P3: r}, + ) + if level == 0 { + jc := &vm.IfNotCmd{P1: r} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return r, nil + case compiler.OpMul: + p.plan.commands = append( + p.plan.commands, + &vm.MultiplyCmd{P1: ol, P2: or, P3: r}, + ) + if level == 0 { + jc := &vm.IfNotCmd{P1: r} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return r, nil + case compiler.OpExp: + p.plan.commands = append( + p.plan.commands, + &vm.ExponentCmd{P1: ol, P2: or, P3: r}, + ) + if level == 0 { + jc := &vm.IfNotCmd{P1: r} + 
p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return r, nil + case compiler.OpSub: + p.plan.commands = append( + p.plan.commands, + &vm.SubtractCmd{P1: ol, P2: or, P3: r}, + ) + if level == 0 { + jc := &vm.IfNotCmd{P1: r} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return r, nil + case compiler.OpEq: + if level == 0 { + jc := &vm.NotEqualCmd{P1: ol, P3: or} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + return 0, nil + } + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) + jumpOverCount := 2 + jumpAddress := len(p.plan.commands) + jumpOverCount + p.plan.commands = append( + p.plan.commands, + &vm.NotEqualCmd{P1: ol, P2: jumpAddress, P3: or}, + ) + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 1, P2: r}) + return r, nil + case compiler.OpLt: + if level == 0 { + jc := &vm.LteCmd{P1: or, P3: ol} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + return 0, nil + } + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) + jumpOverCount := 2 + jumpAddress := len(p.plan.commands) + jumpOverCount + p.plan.commands = append( + p.plan.commands, + &vm.GteCmd{P1: ol, P2: jumpAddress, P3: or}, + ) + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 1, P2: r}) + return r, nil + case compiler.OpGt: + if level == 0 { + jc := &vm.GteCmd{P1: or, P3: ol} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + return 0, nil + } + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) + jumpOverCount := 2 + jumpAddress := len(p.plan.commands) + jumpOverCount + p.plan.commands = append( + p.plan.commands, + &vm.LteCmd{P1: ol, P2: jumpAddress, P3: or}, + ) + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 1, P2: r}) + return r, nil + default: + panic("no vm command for operator") + } + case *compiler.ColumnRef: + colRefReg := p.valueRegisterFor(ce) + if level == 0 { + jc := 
&vm.IfNotCmd{P1: colRefReg} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return colRefReg, nil + case *compiler.IntLit: + cir := p.plan.declareConstInt(ce.Value) + if level == 0 { + jc := &vm.IfNotCmd{P1: cir} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return cir, nil + case *compiler.StringLit: + csr := p.plan.declareConstString(ce.Value) + if level == 0 { + jc := &vm.IfNotCmd{P1: csr} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return csr, nil + case *compiler.Variable: + cvr := p.plan.declareConstVar(ce.Position) + if level == 0 { + jc := &vm.IfNotCmd{P1: cvr} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return cvr, nil + } + panic("unhandled expression in predicate builder") +} + +func (p *predicateGenerator) valueRegisterFor(ce *compiler.ColumnRef) int { + if ce.IsPrimaryKey { + r := p.getNextRegister() + p.plan.commands = append(p.plan.commands, &vm.RowIdCmd{ + P1: p.cursorId, + P2: r, + }) + return r + } + r := p.getNextRegister() + p.plan.commands = append(p.plan.commands, &vm.ColumnCmd{ + P1: p.cursorId, + P2: ce.ColIdx, P3: r, + }) + return r +} + +func (p *predicateGenerator) getNextRegister() int { + r := p.plan.freeRegister + p.plan.freeRegister += 1 + return r +} diff --git a/planner/result_generator.go b/planner/result_generator.go new file mode 100644 index 0000000..0e0a93f --- /dev/null +++ b/planner/result_generator.go @@ -0,0 +1,126 @@ +package planner + +import ( + "github.com/chirst/cdb/compiler" + "github.com/chirst/cdb/vm" +) + +// generateExpressionTo takes the context of the plan and generates commands +// that land the result of the given expr in the toRegister. 
+func generateExpressionTo(plan *QueryPlan, expr compiler.Expr, toRegister int, cursorId int) { + rg := &resultExprGenerator{} + rg.plan = plan + rg.outputRegister = toRegister + rg.cursorId = cursorId + rg.build(expr, 0) +} + +// resultExprGenerator builds commands for the given expression. +type resultExprGenerator struct { + plan *QueryPlan + // outputRegister is the target register for the result of the expression. + outputRegister int + // cursorId is the id of the cursor for the table in the associated query. + // This will need to be enhanced at some point to support more than one + // aliased column in the results, but is fine for now. + cursorId int +} + +func (e *resultExprGenerator) build(root compiler.Expr, level int) int { + switch n := root.(type) { + case *compiler.BinaryExpr: + ol := e.build(n.Left, level+1) + or := e.build(n.Right, level+1) + r := e.getNextRegister(level) + switch n.Operator { + case compiler.OpAdd: + e.plan.commands = append(e.plan.commands, &vm.AddCmd{P1: ol, P2: or, P3: r}) + case compiler.OpDiv: + e.plan.commands = append(e.plan.commands, &vm.DivideCmd{P1: ol, P2: or, P3: r}) + case compiler.OpMul: + e.plan.commands = append(e.plan.commands, &vm.MultiplyCmd{P1: ol, P2: or, P3: r}) + case compiler.OpExp: + e.plan.commands = append(e.plan.commands, &vm.ExponentCmd{P1: ol, P2: or, P3: r}) + case compiler.OpSub: + e.plan.commands = append(e.plan.commands, &vm.SubtractCmd{P1: ol, P2: or, P3: r}) + case compiler.OpEq: + e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) + jumpOverCount := 2 + jumpAddress := len(e.plan.commands) + jumpOverCount + e.plan.commands = append( + e.plan.commands, + &vm.NotEqualCmd{P1: ol, P2: jumpAddress, P3: or}, + ) + e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 1, P2: r}) + case compiler.OpLt: + e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) + jumpOverCount := 2 + jumpAddress := len(e.plan.commands) + jumpOverCount + e.plan.commands = append( 
+ e.plan.commands, + &vm.GteCmd{P1: ol, P2: jumpAddress, P3: or}, + ) + e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 1, P2: r}) + case compiler.OpGt: + e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) + jumpOverCount := 2 + jumpAddress := len(e.plan.commands) + jumpOverCount + e.plan.commands = append( + e.plan.commands, + &vm.LteCmd{P1: ol, P2: jumpAddress, P3: or}, + ) + e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 1, P2: r}) + default: + panic("no vm command for operator") + } + return r + case *compiler.ColumnRef: + r := e.getNextRegister(level) + if n.IsPrimaryKey { + e.plan.commands = append(e.plan.commands, &vm.RowIdCmd{P1: e.cursorId, P2: r}) + } else { + e.plan.commands = append( + e.plan.commands, + &vm.ColumnCmd{P1: e.cursorId, P2: n.ColIdx, P3: r}, + ) + } + return r + case *compiler.IntLit: + cir := e.plan.declareConstInt(n.Value) + if level == 0 { + e.plan.commands = append( + e.plan.commands, + &vm.CopyCmd{P1: cir, P2: e.outputRegister}, + ) + } + return cir + case *compiler.StringLit: + csr := e.plan.declareConstString(n.Value) + if level == 0 { + e.plan.commands = append( + e.plan.commands, + &vm.CopyCmd{P1: csr, P2: e.outputRegister}, + ) + } + return csr + case *compiler.Variable: + cvr := e.plan.declareConstVar(n.Position) + if level == 0 { + e.plan.commands = append( + e.plan.commands, + &vm.CopyCmd{P1: cvr, P2: e.outputRegister}, + ) + } + return cvr + } + panic("unhandled expression in expr command builder") +} + +func (e *resultExprGenerator) getNextRegister(level int) int { + if level == 0 { + return e.outputRegister + } + r := e.plan.freeRegister + e.plan.freeRegister += 1 + return r +} diff --git a/planner/select.go b/planner/select.go index 0fbf03c..71f7e7c 100644 --- a/planner/select.go +++ b/planner/select.go @@ -20,37 +20,14 @@ type selectCatalog interface { } // selectPlanner is capable of generating a logical query plan and a physical -// execution plan for a select statement. 
The planners within are separated by -// their responsibility. +// execution plan for a select statement. type selectPlanner struct { - // queryPlanner is responsible for transforming the AST to a logical query - // plan tree. This tree is made up of nodes that map closely to a relational - // algebra tree. The query planner also performs binding and validation. - queryPlanner *selectQueryPlanner - // executionPlanner transforms the logical query tree to a bytecode routine, - // built to be ran by the virtual machine. - executionPlanner *selectExecutionPlanner -} - -// selectQueryPlanner converts an AST to a logical query plan. Along the way it -// also validates the AST makes sense with the catalog (a process known as -// binding). -type selectQueryPlanner struct { - // catalog contains the schema + // catalog contains the schema. catalog selectCatalog - // stmt contains the AST + // stmt contains the AST. stmt *compiler.SelectStmt - // queryPlan contains the logical plan being built. The root node must be a - // projection. - queryPlan *projectNode -} - -// selectExecutionPlanner converts logical nodes in a query plan tree to -// bytecode that can be run by the vm. -type selectExecutionPlanner struct { - // queryPlan contains the logical plan. This node is populated by calling - // the QueryPlan method. - queryPlan *projectNode + // queryPlan contains the plan being built. + queryPlan *QueryPlan // executionPlan contains the execution plan for the vm. This is built by // calling ExecutionPlan. executionPlan *vm.ExecutionPlan @@ -59,149 +36,151 @@ type selectExecutionPlanner struct { // NewSelect returns an instance of a select planner for the given AST. 
func NewSelect(catalog selectCatalog, stmt *compiler.SelectStmt) *selectPlanner { return &selectPlanner{ - queryPlanner: &selectQueryPlanner{ - catalog: catalog, - stmt: stmt, - }, - executionPlanner: &selectExecutionPlanner{ - executionPlan: vm.NewExecutionPlan( - catalog.GetVersion(), - stmt.Explain, - ), - }, + catalog: catalog, + stmt: stmt, + executionPlan: vm.NewExecutionPlan( + catalog.GetVersion(), + stmt.Explain, + ), } } // QueryPlan generates the query plan tree for the planner. func (p *selectPlanner) QueryPlan() (*QueryPlan, error) { - qp, err := p.queryPlanner.getQueryPlan() + err := p.optimizeResultColumns() if err != nil { return nil, err } - p.executionPlanner.queryPlan = p.queryPlanner.queryPlan - return qp, err -} -// getQueryPlan performs several passes on the AST to compute a more manageable -// tree structure of logical operators who closely resemble relational algebra -// operators. -// -// Firstly, getQueryPlan performs simplification to translate the projection -// portion of the select statement to uniform expressions. This means a "*", -// "table.*", or "alias.*" would simply be translated to ColumnRef expressions. -// From here the query is easier to work on as it is one consistent structure. -// -// From here, more simplification is performed. Folding computes constant -// expressions to reduce the complexity of the expression tree. This saves -// instructions ran during a scan. An example of this folding could be the -// binary expression 1 + 1 becoming a constant expression 2. Or a function UPPER -// on a string literal "foo" being simplified to just the string literal "FOO". -// -// Analysis steps are also performed. Such as assigning catalog information to -// ColumnRef expressions. This means associating table names with root page -// numbers, column names with their indices within a tuple, and column names -// with their constraints and available indexes. 
-func (p *selectQueryPlanner) getQueryPlan() (*QueryPlan, error) { - err := p.optimizeResultColumns() + var tableName string + var rootPageNumber int + if p.stmt.From != nil { + tableName = p.stmt.From.TableName + } + if tableName != "" { + rootPageNumber, err = p.catalog.GetRootPageNumber(tableName) + if err != nil { + return nil, errTableNotExist + } + } + + projections, err := p.getProjections() if err != nil { return nil, err } + for i := range projections { + cev := &catalogExprVisitor{} + cev.Init(p.catalog, tableName) + projections[i].expr.BreadthWalk(cev) + } - // Constant query has no "from". - if p.stmt.From == nil || p.stmt.From.TableName == "" { - constExprs := []compiler.Expr{} - for i := range p.stmt.ResultColumns { - constExprs = append(constExprs, p.stmt.ResultColumns[i].Expression) + hasFunc := false + for i := range projections { + _, ok := projections[i].expr.(*compiler.FunctionExpr) + if ok { + hasFunc = true } - child := &constantNode{ - resultColumns: constExprs, - predicate: p.stmt.Where, + } + if hasFunc { + if len(projections) != 1 { + return nil, errors.New("only one projection allowed for COUNT") } - projections, err := p.getProjections() - if err != nil { - return nil, err + if tableName == "" { + return nil, errors.New("must have from for COUNT") } - p.queryPlan = &projectNode{ - projections: projections, - child: child, + cn := &countNode{ + projection: projections[0], + rootPageNumber: rootPageNumber, + tableName: tableName, + cursorId: 1, } - return newQueryPlan(p.queryPlan, p.stmt.ExplainQueryPlan), nil + plan := newQueryPlan( + cn, + p.stmt.ExplainQueryPlan, + transactionTypeRead, + ) + cn.plan = plan + p.queryPlan = plan + return plan, nil } - tableName := p.stmt.From.TableName - rootPageNumber, err := p.catalog.GetRootPageNumber(tableName) - if err != nil { - return nil, errTableNotExist + tt := transactionTypeRead + if tableName == "" { + tt = transactionTypeNone } - - // Count node is specially supported for now. 
- qp, err := p.getCountNode(tableName, rootPageNumber) - if err != nil { - return nil, err - } - if qp != nil { - return qp, nil + projectNode := &projectNode{ + projections: projections, + cursorId: 1, } - + plan := newQueryPlan(projectNode, p.stmt.ExplainQueryPlan, tt) + projectNode.plan = plan if p.stmt.Where != nil { cev := &catalogExprVisitor{} - cev.Init(p.catalog, p.stmt.From.TableName) - if cev.err != nil { - return nil, err - } + cev.Init(p.catalog, tableName) p.stmt.Where.BreadthWalk(cev) - } - - // At this point a constant and count should be ruled out. The planner isn't - // looking at using indexes yet so we are safe to focus on scanNodes. - child := &scanNode{ - tableName: tableName, - rootPage: rootPageNumber, - scanColumns: []scanColumn{}, - scanPredicate: p.stmt.Where, - } - for _, resultColumn := range p.stmt.ResultColumns { - if resultColumn.All { - cols, err := p.getScanColumns() - if err != nil { - return nil, err + filterNode := &filterNode{ + parent: projectNode, + plan: plan, + predicate: p.stmt.Where, + cursorId: 1, + } + projectNode.child = filterNode + if tableName == "" { + constNode := &constantNode{ + plan: plan, } - child.scanColumns = append(child.scanColumns, cols...) - } else if resultColumn.AllTable != "" { - if tableName != resultColumn.AllTable { - return nil, fmt.Errorf("invalid expression %s.*", resultColumn.AllTable) + filterNode.child = constNode + constNode.parent = filterNode + } else { + scanNode := &scanNode{ + plan: plan, + tableName: tableName, + rootPageNumber: rootPageNumber, + cursorId: 1, } - cols, err := p.getScanColumns() - if err != nil { - return nil, err + filterNode.child = scanNode + scanNode.parent = filterNode + } + } else { + if tableName == "" { + constNode := &constantNode{ + plan: plan, } - child.scanColumns = append(child.scanColumns, cols...) 
- } else if resultColumn.Expression != nil { - child.scanColumns = append(child.scanColumns, resultColumn.Expression) + projectNode.child = constNode + constNode.parent = projectNode } else { - return nil, fmt.Errorf("unhandled result column %#v", resultColumn) - } - for i := range child.scanColumns { - cev := &catalogExprVisitor{} - if cev.err != nil { - return nil, err + scanNode := &scanNode{ + plan: plan, + tableName: tableName, + rootPageNumber: rootPageNumber, + cursorId: 1, } - cev.Init(p.catalog, child.tableName) - child.scanColumns[i].BreadthWalk(cev) + projectNode.child = scanNode + scanNode.parent = projectNode } } - projections, err := p.getProjections() - if err != nil { - return nil, err - } - p.queryPlan = &projectNode{ - projections: projections, - child: child, + p.queryPlan = plan + plan.root = projectNode + return plan, nil +} + +// ExecutionPlan returns the bytecode execution plan for the planner. Calling +// QueryPlan is not a prerequisite to this method as it will be called by +// ExecutionPlan if needed. +func (sp *selectPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { + if sp.queryPlan == nil { + _, err := sp.QueryPlan() + if err != nil { + return nil, err + } } - return newQueryPlan(p.queryPlan, p.stmt.ExplainQueryPlan), nil + sp.setResultHeader() + sp.queryPlan.compile() + sp.executionPlan.Commands = sp.queryPlan.commands + return sp.executionPlan, nil } -func (p *selectQueryPlanner) optimizeResultColumns() error { +func (p *selectPlanner) optimizeResultColumns() error { var err error for i := range p.stmt.ResultColumns { if p.stmt.ResultColumns[i].Expression != nil { @@ -279,67 +258,7 @@ func foldExpr(e compiler.Expr) (compiler.Expr, error) { } } -// getCountNode supports the count function under special circumstances. 
-func (p *selectQueryPlanner) getCountNode(tableName string, rootPageNumber int) (*QueryPlan, error) { - if len(p.stmt.ResultColumns) == 0 { - return nil, nil - } - switch e := p.stmt.ResultColumns[0].Expression.(type) { - case *compiler.FunctionExpr: - if len(p.stmt.ResultColumns) != 1 { - return nil, errors.New("count with other result columns not supported") - } - if e.FnType != compiler.FnCount { - return nil, fmt.Errorf("only %s function is supported", e.FnType) - } - child := &countNode{ - tableName: tableName, - rootPage: rootPageNumber, - } - projections, err := p.getProjections() - if err != nil { - return nil, err - } - p.queryPlan = &projectNode{ - projections: projections, - child: child, - } - return newQueryPlan(p.queryPlan, p.stmt.ExplainQueryPlan), nil - } - return nil, nil -} - -func (p *selectQueryPlanner) getScanColumns() ([]scanColumn, error) { - pkColName, err := p.catalog.GetPrimaryKeyColumn(p.stmt.From.TableName) - if err != nil { - return nil, err - } - cols, err := p.catalog.GetColumns(p.stmt.From.TableName) - if err != nil { - return nil, err - } - scanColumns := []scanColumn{} - idx := 0 - for _, c := range cols { - if c == pkColName { - scanColumns = append(scanColumns, &compiler.ColumnRef{ - Table: p.stmt.From.TableName, - Column: c, - IsPrimaryKey: c == pkColName, - }) - } else { - scanColumns = append(scanColumns, &compiler.ColumnRef{ - Table: p.stmt.From.TableName, - Column: c, - ColIdx: idx, - }) - idx += 1 - } - } - return scanColumns, nil -} - -func (p *selectQueryPlanner) getProjections() ([]projection, error) { +func (p *selectPlanner) getProjections() ([]projection, error) { var projections []projection for _, resultColumn := range p.stmt.ResultColumns { if resultColumn.All { @@ -349,7 +268,10 @@ func (p *selectQueryPlanner) getProjections() ([]projection, error) { } for _, c := range cols { projections = append(projections, projection{ - colName: c, + expr: &compiler.ColumnRef{ + Table: p.stmt.From.TableName, + Column: c, + }, 
}) } } else if resultColumn.AllTable != "" { @@ -359,93 +281,51 @@ func (p *selectQueryPlanner) getProjections() ([]projection, error) { } for _, c := range cols { projections = append(projections, projection{ - colName: c, + expr: &compiler.ColumnRef{ + Table: p.stmt.From.TableName, + Column: c, + }, }) } } else if resultColumn.Expression != nil { - switch e := resultColumn.Expression.(type) { - case *compiler.ColumnRef: - colName := e.Column - if resultColumn.Alias != "" { - colName = resultColumn.Alias - } - projections = append(projections, projection{ - colName: colName, - }) - case *compiler.FunctionExpr: - projections = append(projections, projection{ - isCount: true, - colName: resultColumn.Alias, - }) - default: - projections = append(projections, projection{ - isCount: false, - colName: resultColumn.Alias, - }) - } + projections = append(projections, projection{ + expr: resultColumn.Expression, + alias: resultColumn.Alias, + }) } } return projections, nil } -// ExecutionPlan returns the bytecode execution plan for the planner. Calling -// QueryPlan is not a prerequisite to this method as it will be called by -// ExecutionPlan if needed. 
-func (sp *selectPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { - if sp.queryPlanner.queryPlan == nil { - _, err := sp.QueryPlan() - if err != nil { - return nil, err - } - } - return sp.executionPlanner.getExecutionPlan() -} - -func (p *selectExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { - p.setResultHeader() - p.executionPlan.Append(&vm.InitCmd{P2: 1}) - switch c := p.queryPlan.child.(type) { - case *scanNode: - err := p.setResultTypes(c.scanColumns) - if err != nil { - return nil, err - } - p.executionPlan.Append(&vm.TransactionCmd{P2: 0}) - if err := p.buildScan(c); err != nil { - return nil, err +func (p *selectPlanner) setResultHeader() { + resultHeader := []string{} + switch t := p.queryPlan.root.(type) { + case *projectNode: + projectExprs := []compiler.Expr{} + for _, projection := range t.projections { + header := "" + if projection.alias == "" { + if cr, ok := projection.expr.(*compiler.ColumnRef); ok { + header = cr.Column + } + } else { + header = projection.alias + } + resultHeader = append(resultHeader, header) + projectExprs = append(projectExprs, projection.expr) } + p.setResultTypes(projectExprs) case *countNode: - err := p.setResultTypes([]compiler.Expr{&compiler.IntLit{}}) - if err != nil { - return nil, err - } - p.executionPlan.Append(&vm.TransactionCmd{P2: 0}) - p.buildOptimizedCountScan(c) - case *constantNode: - err := p.setResultTypes(c.resultColumns) - if err != nil { - return nil, err - } - if err := p.buildConstantNode(c); err != nil { - return nil, err - } + resultHeader = append(resultHeader, t.projection.alias) + p.setResultTypes([]compiler.Expr{t.projection.expr}) default: - return nil, fmt.Errorf("unhandled node %#v", c) - } - p.executionPlan.Append(&vm.HaltCmd{}) - return p.executionPlan, nil -} - -func (p *selectExecutionPlanner) setResultHeader() { - resultHeader := []string{} - for _, p := range p.queryPlan.projections { - resultHeader = append(resultHeader, p.colName) + panic("unhandled node for 
result header") } p.executionPlan.ResultHeader = resultHeader } // setResultTypes attempts to precompute the type for each result column expr. -func (p *selectExecutionPlanner) setResultTypes(exprs []compiler.Expr) error { +func (p *selectPlanner) setResultTypes(exprs []compiler.Expr) error { resolvedTypes := []catalog.CdbType{} for _, expr := range exprs { t, err := getExprType(expr) @@ -489,541 +369,3 @@ func getExprType(expr compiler.Expr) (catalog.CdbType, error) { return catalog.CdbType{ID: catalog.CTUnknown}, fmt.Errorf("no handler for expr type %v", expr) } } - -func (p *selectExecutionPlanner) buildScan(n *scanNode) error { - // Build a map of constant values to registers by walking result columns and - // the scan predicate. - const beginningRegister = 1 - crv := &constantRegisterVisitor{} - crv.Init(beginningRegister) - for _, c := range n.scanColumns { - c.BreadthWalk(crv) - } - if n.scanPredicate != nil { - n.scanPredicate.BreadthWalk(crv) - } - rcs := crv.GetRegisterCommands() - for _, rc := range rcs { - p.executionPlan.Append(rc) - } - - // Open an available cursor. Can just be 1 for now since no queries are - // supported at the moment that requires more than one cursor. - const cursorId = 1 - p.executionPlan.Append(&vm.OpenReadCmd{P1: cursorId, P2: n.rootPage}) - - // Rewind moves the aforementioned cursor to the "start" of the table. - rwc := &vm.RewindCmd{P1: cursorId} - p.executionPlan.Append(rwc) - - // Mark beginning of scan for rewind - scanBeginningCommand := len(p.executionPlan.Commands) - - // Reserve registers for the column result. Claim registers after as needed. - startScanRegister := crv.nextOpenRegister - endScanRegisterOffset := len(n.scanColumns) - - // This is the inside of the scan meaning how each result column is handled - // per iteration of the scan (loop). 
- var pkRegister int - var openRegister int - colRegisters := make(map[int]int) - for i, c := range n.scanColumns { - exprBuilder := &resultColumnCommandBuilder{} - exprBuilder.Build( - cursorId, - len(p.executionPlan.Commands), - startScanRegister+endScanRegisterOffset, - crv.constantRegisters, - crv.variableRegisters, - crv.stringRegisters, - startScanRegister+i, - c, - ) - if exprBuilder.pkRegister != 0 { - pkRegister = exprBuilder.pkRegister - } - openRegister = exprBuilder.openRegister - for crk, crv := range exprBuilder.colRegisters { - colRegisters[crk] = crv - } - for _, tc := range exprBuilder.commands { - p.executionPlan.Append(tc) - } - } - - // TODO predicate commands should come as early as possible to save - // instructions, but for now this is easier. - // - // Walk scan predicate and build commands to calculate a conditional jump. - if n.scanPredicate != nil { - bpb := &booleanPredicateBuilder{} - err := bpb.Build( - cursorId, - openRegister, - len(p.executionPlan.Commands), - len(p.executionPlan.Commands), - crv.constantRegisters, - colRegisters, - crv.variableRegisters, - crv.stringRegisters, - pkRegister, - n.scanPredicate, - ) - if err != nil { - return err - } - for _, bc := range bpb.commands { - p.executionPlan.Append(bc) - } - } - - // Result row gathers the aforementioned inside of the scan and makes them - // into a single row for the query results. - p.executionPlan.Append(&vm.ResultRowCmd{P1: startScanRegister, P2: endScanRegisterOffset}) - - // Falls through or goes back to the start of the scan loop. - p.executionPlan.Append(&vm.NextCmd{P1: cursorId, P2: scanBeginningCommand}) - - // Must tell the rewind command where to go in case the table is empty. - rwc.P2 = len(p.executionPlan.Commands) - return nil -} - -// buildOptimizedCountScan is a special optimization made when a table only has -// a count aggregate and no other projections. 
Since the optimized scan -// aggregates the count of tuples on each page, but does not look at individual -// tuples. -func (p *selectExecutionPlanner) buildOptimizedCountScan(n *countNode) { - const cursorId = 1 - p.executionPlan.Append(&vm.OpenReadCmd{P1: cursorId, P2: n.rootPage}) - p.executionPlan.Append(&vm.CountCmd{P1: cursorId, P2: 1}) - p.executionPlan.Append(&vm.ResultRowCmd{P1: 1, P2: 1}) -} - -// buildConstantNode is a single row operation produced by a "select" without a -// "from". -func (p *selectExecutionPlanner) buildConstantNode(n *constantNode) error { - // Build registers with constants. These are likely extra instructions, but - // okay since it allows this to follow the same pattern a scan does. - const beginningRegister = 1 - crv := &constantRegisterVisitor{} - crv.Init(beginningRegister) - for _, c := range n.resultColumns { - c.BreadthWalk(crv) - } - if n.predicate != nil { - n.predicate.BreadthWalk(crv) - } - rcs := crv.GetRegisterCommands() - for _, rc := range rcs { - p.executionPlan.Append(rc) - } - - // Like a scan, but for a single row. 
- reservedRegisterStart := crv.nextOpenRegister - reservedRegisterOffset := len(n.resultColumns) - var openRegister int - for i, rc := range n.resultColumns { - exprBuilder := &resultColumnCommandBuilder{} - exprBuilder.Build( - 1, - len(p.executionPlan.Commands), - reservedRegisterStart+reservedRegisterOffset, - crv.constantRegisters, - crv.variableRegisters, - crv.stringRegisters, - reservedRegisterStart+i, - rc, - ) - for _, tc := range exprBuilder.commands { - p.executionPlan.Append(tc) - } - openRegister = exprBuilder.openRegister - } - - if n.predicate != nil { - bpb := &booleanPredicateBuilder{} - err := bpb.Build( - 0, - openRegister, - len(p.executionPlan.Commands), - len(p.executionPlan.Commands), - crv.constantRegisters, - map[int]int{}, - crv.variableRegisters, - crv.stringRegisters, - 0, - n.predicate, - ) - if err != nil { - return err - } - for _, bc := range bpb.commands { - p.executionPlan.Append(bc) - } - } - - p.executionPlan.Append(&vm.ResultRowCmd{P1: reservedRegisterStart, P2: reservedRegisterOffset}) - return nil -} - -// resultColumnCommandBuilder builds commands for the given expression. -type resultColumnCommandBuilder struct { - // cursorId is the cursor for the related table. - cursorId int - // openRegister is the next available register. - openRegister int - // outputRegister is the target register for the result of the expression. - outputRegister int - // commands are the commands to evaluate the expression. - commands []vm.Command - // commandOffset is the amount of commands prior to calling this routine. - // Useful for calculating jump instructions. - commandOffset int - // litRegisters is a mapping of scalar values to registers containing them. - litRegisters map[int]int - // colRegisters is a mapping of column indexes to registers containing the - // column. This is for subsequent routines to reuse the result of these - // commands. - colRegisters map[int]int - // variableRegisters is a mapping of variable indices to registers. 
- variableRegisters map[int]int - // stringRegisters is a mapping of strings to registers - stringRegisters map[string]int - // pkRegister is 0 value unless a register has been filled as part of Build. - // This is for subsequent routines to reuse the result of the command. - pkRegister int -} - -func (e *resultColumnCommandBuilder) Build( - cursorId int, - commandOffset int, - openRegister int, - litRegisters map[int]int, - variableRegisters map[int]int, - stringRegisters map[string]int, - outputRegister int, - root compiler.Expr, -) int { - e.cursorId = cursorId - e.commandOffset = commandOffset - e.openRegister = openRegister - e.litRegisters = litRegisters - e.colRegisters = make(map[int]int) - e.variableRegisters = variableRegisters - e.stringRegisters = stringRegisters - e.outputRegister = outputRegister - return e.build(root, 0) -} - -func (e *resultColumnCommandBuilder) build(root compiler.Expr, level int) int { - switch n := root.(type) { - case *compiler.BinaryExpr: - ol := e.build(n.Left, level+1) - or := e.build(n.Right, level+1) - r := e.getNextRegister(level) - switch n.Operator { - case compiler.OpAdd: - e.commands = append(e.commands, &vm.AddCmd{P1: ol, P2: or, P3: r}) - case compiler.OpDiv: - e.commands = append(e.commands, &vm.DivideCmd{P1: ol, P2: or, P3: r}) - case compiler.OpMul: - e.commands = append(e.commands, &vm.MultiplyCmd{P1: ol, P2: or, P3: r}) - case compiler.OpExp: - e.commands = append(e.commands, &vm.ExponentCmd{P1: ol, P2: or, P3: r}) - case compiler.OpSub: - e.commands = append(e.commands, &vm.SubtractCmd{P1: ol, P2: or, P3: r}) - case compiler.OpEq: - e.commands = append(e.commands, &vm.IntegerCmd{P1: 0, P2: r}) - jumpOverCount := 2 - jumpAddress := len(e.commands) + jumpOverCount + e.commandOffset - e.commands = append( - e.commands, - &vm.NotEqualCmd{P1: ol, P2: jumpAddress, P3: or}, - ) - e.commands = append(e.commands, &vm.IntegerCmd{P1: 1, P2: r}) - case compiler.OpLt: - e.commands = append(e.commands, &vm.IntegerCmd{P1: 0, 
P2: r}) - jumpOverCount := 2 - jumpAddress := len(e.commands) + jumpOverCount + e.commandOffset - e.commands = append( - e.commands, - &vm.GteCmd{P1: ol, P2: jumpAddress, P3: or}, - ) - e.commands = append(e.commands, &vm.IntegerCmd{P1: 1, P2: r}) - case compiler.OpGt: - e.commands = append(e.commands, &vm.IntegerCmd{P1: 0, P2: r}) - jumpOverCount := 2 - jumpAddress := len(e.commands) + jumpOverCount + e.commandOffset - e.commands = append( - e.commands, - &vm.LteCmd{P1: ol, P2: jumpAddress, P3: or}, - ) - e.commands = append(e.commands, &vm.IntegerCmd{P1: 1, P2: r}) - default: - panic("no vm command for operator") - } - return r - case *compiler.ColumnRef: - r := e.getNextRegister(level) - if n.IsPrimaryKey { - e.pkRegister = r - e.commands = append(e.commands, &vm.RowIdCmd{P1: e.cursorId, P2: r}) - } else { - e.colRegisters[n.ColIdx] = r - e.commands = append( - e.commands, - &vm.ColumnCmd{P1: e.cursorId, P2: n.ColIdx, P3: r}, - ) - } - return r - case *compiler.IntLit: - if level == 0 { - e.commands = append( - e.commands, - &vm.CopyCmd{P1: e.litRegisters[n.Value], P2: e.outputRegister}, - ) - } - return e.litRegisters[n.Value] - case *compiler.StringLit: - if level == 0 { - e.commands = append( - e.commands, - &vm.CopyCmd{P1: e.stringRegisters[n.Value], P2: e.outputRegister}, - ) - } - return e.stringRegisters[n.Value] - case *compiler.Variable: - if level == 0 { - e.commands = append( - e.commands, - &vm.CopyCmd{P1: e.variableRegisters[n.Position], P2: e.outputRegister}, - ) - } - return e.variableRegisters[n.Position] - } - panic("unhandled expression in expr command builder") -} - -func (e *resultColumnCommandBuilder) getNextRegister(level int) int { - if level == 0 { - return e.outputRegister - } - r := e.openRegister - e.openRegister += 1 - return r -} - -// booleanPredicateBuilder builds commands to calculate the boolean result of an -// expression. -type booleanPredicateBuilder struct { - // cursorId is the cursor for the associated table. 
- cursorId int - // openRegister is the next available register - openRegister int - // jumpAddress is the address the result of the boolean expression should - // conditionally jump to. - jumpAddress int - // commands is a list of commands representing the expression. - commands []vm.Command - // commandOffset is used to calculate the amount of commands already in the - // plan. - commandOffset int - // litRegisters is a mapping of scalar values to the register containing - // them. litRegisters should be guaranteed since they have a minimal cost - // due to being calculated outside of any scans/loops. - litRegisters map[int]int - // colRegisters is a mapping of table column index to register containing - // the column value. colRegisters may not be guaranteed since a projection - // may not require them, in which case colRegisters should be calculated as - // part of the predicate. - colRegisters map[int]int - // variableRegisters is a mapping of variable indices to registers. - variableRegisters map[int]int - // stringRegisters is a mapping of strings to registers - stringRegisters map[string]int - // pkRegister is unset when 0. Otherwise, pkRegister is the register - // containing the table row id. pkRegister may not be guaranteed depending - // on the projection in which case the register should be calculated as part - // of the expression evaluation. 
- pkRegister int -} - -func (p *booleanPredicateBuilder) Build( - cursorId int, - openRegister int, - jumpAddress int, - commandOffset int, - litRegisters map[int]int, - colRegisters map[int]int, - variableRegisters map[int]int, - stringRegisters map[string]int, - pkRegister int, - e compiler.Expr, -) error { - p.cursorId = cursorId - p.openRegister = openRegister - p.jumpAddress = jumpAddress - p.commandOffset = commandOffset - p.litRegisters = litRegisters - p.colRegisters = colRegisters - p.variableRegisters = variableRegisters - p.stringRegisters = stringRegisters - p.pkRegister = pkRegister - _, err := p.build(e, 0) - return err -} - -func (p *booleanPredicateBuilder) build(e compiler.Expr, level int) (int, error) { - switch ce := e.(type) { - case *compiler.BinaryExpr: - ol, err := p.build(ce.Left, level+1) - if err != nil { - return 0, err - } - or, err := p.build(ce.Right, level+1) - if err != nil { - return 0, err - } - r := p.getNextRegister() - switch ce.Operator { - case compiler.OpAdd: - p.commands = append(p.commands, &vm.AddCmd{P1: ol, P2: or, P3: r}) - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: r, P2: p.getJumpAddress()}) - } - return r, nil - case compiler.OpDiv: - p.commands = append(p.commands, &vm.DivideCmd{P1: ol, P2: or, P3: r}) - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: r, P2: p.getJumpAddress()}) - } - return r, nil - case compiler.OpMul: - p.commands = append(p.commands, &vm.MultiplyCmd{P1: ol, P2: or, P3: r}) - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: r, P2: p.getJumpAddress()}) - } - return r, nil - case compiler.OpExp: - p.commands = append(p.commands, &vm.ExponentCmd{P1: ol, P2: or, P3: r}) - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: r, P2: p.getJumpAddress()}) - } - return r, nil - case compiler.OpSub: - p.commands = append(p.commands, &vm.SubtractCmd{P1: ol, P2: or, P3: r}) - if level == 0 { - p.commands = append(p.commands, 
&vm.IfNotCmd{P1: r, P2: p.getJumpAddress()}) - } - return r, nil - case compiler.OpEq: - if level == 0 { - p.commands = append( - p.commands, - &vm.NotEqualCmd{P1: ol, P2: p.getJumpAddress(), P3: or}, - ) - return 0, nil - } - p.commands = append(p.commands, &vm.IntegerCmd{P1: 0, P2: r}) - jumpOverCount := 2 - jumpAddress := len(p.commands) + jumpOverCount + p.commandOffset - p.commands = append( - p.commands, - &vm.NotEqualCmd{P1: ol, P2: jumpAddress, P3: or}, - ) - p.commands = append(p.commands, &vm.IntegerCmd{P1: 1, P2: r}) - return r, nil - case compiler.OpLt: - if level == 0 { - p.commands = append( - p.commands, - &vm.LteCmd{P1: or, P2: p.getJumpAddress(), P3: ol}, - ) - return 0, nil - } - p.commands = append(p.commands, &vm.IntegerCmd{P1: 0, P2: r}) - jumpOverCount := 2 - jumpAddress := len(p.commands) + jumpOverCount + p.commandOffset - p.commands = append( - p.commands, - &vm.GteCmd{P1: ol, P2: jumpAddress, P3: or}, - ) - p.commands = append(p.commands, &vm.IntegerCmd{P1: 1, P2: r}) - return r, nil - case compiler.OpGt: - if level == 0 { - p.commands = append( - p.commands, - &vm.GteCmd{P1: or, P2: p.getJumpAddress(), P3: ol}, - ) - return 0, nil - } - p.commands = append(p.commands, &vm.IntegerCmd{P1: 0, P2: r}) - jumpOverCount := 2 - jumpAddress := len(p.commands) + jumpOverCount + p.commandOffset - p.commands = append( - p.commands, - &vm.LteCmd{P1: ol, P2: jumpAddress, P3: or}, - ) - p.commands = append(p.commands, &vm.IntegerCmd{P1: 1, P2: r}) - return r, nil - default: - panic("no vm command for operator") - } - case *compiler.ColumnRef: - colRefReg := p.valueRegisterFor(ce) - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: colRefReg, P2: p.getJumpAddress()}) - } - return colRefReg, nil - case *compiler.IntLit: - litReg := p.litRegisters[ce.Value] - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: litReg, P2: p.getJumpAddress()}) - } - return litReg, nil - case *compiler.StringLit: - strReg := 
p.stringRegisters[ce.Value] - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: strReg, P2: p.getJumpAddress()}) - } - return strReg, nil - case *compiler.Variable: - varReg := p.variableRegisters[ce.Position] - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: varReg, P2: p.getJumpAddress()}) - } - return varReg, nil - } - panic("unhandled expression in predicate builder") -} - -func (p *booleanPredicateBuilder) getJumpAddress() int { - return p.jumpAddress + len(p.commands) + 2 -} - -func (p *booleanPredicateBuilder) valueRegisterFor(ce *compiler.ColumnRef) int { - if ce.IsPrimaryKey { - if p.pkRegister == 0 { - r := p.getNextRegister() - p.commands = append(p.commands, &vm.RowIdCmd{P1: p.cursorId, P2: r}) - return r - } - return p.pkRegister - } - cr := p.colRegisters[ce.ColIdx] - if cr == 0 { - r := p.getNextRegister() - p.commands = append(p.commands, &vm.ColumnCmd{P1: p.cursorId, P2: ce.ColIdx, P3: r}) - return r - } - return cr -} - -func (p *booleanPredicateBuilder) getNextRegister() int { - r := p.openRegister - p.openRegister += 1 - return r -} diff --git a/planner/select_test.go b/planner/select_test.go index 2b9a631..87adb7a 100644 --- a/planner/select_test.go +++ b/planner/select_test.go @@ -2,7 +2,6 @@ package planner import ( "errors" - "reflect" "testing" "github.com/chirst/cdb/catalog" @@ -59,15 +58,16 @@ func TestSelectPlan(t *testing.T) { { description: "StarWithPrimaryKey", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, + &vm.InitCmd{P2: 8}, &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 8}, + &vm.RewindCmd{P1: 1, P2: 7}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 2}, &vm.ResultRowCmd{P1: 1, P2: 2}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -88,15 +88,16 @@ func TestSelectPlan(t *testing.T) { { 
description: "StarWithoutPrimaryKey", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, + &vm.InitCmd{P2: 8}, &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 8}, + &vm.RewindCmd{P1: 1, P2: 7}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 1}, &vm.ColumnCmd{P1: 1, P2: 1, P3: 2}, &vm.ResultRowCmd{P1: 1, P2: 2}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -124,16 +125,17 @@ func TestSelectPlan(t *testing.T) { }, }, expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, + &vm.InitCmd{P2: 9}, &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 9}, + &vm.RewindCmd{P1: 1, P2: 8}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 1}, &vm.RowIdCmd{P1: 1, P2: 2}, &vm.ColumnCmd{P1: 1, P2: 1, P3: 3}, &vm.ResultRowCmd{P1: 1, P2: 3}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.GotoCmd{P2: 1}, }, mockCatalogSetup: func(m *mockSelectCatalog) *mockSelectCatalog { m.primaryKeyColumnName = "id" @@ -167,16 +169,17 @@ func TestSelectPlan(t *testing.T) { }, }, expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, - &vm.IntegerCmd{P1: 10, P2: 1}, + &vm.InitCmd{P2: 8}, &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 9}, - &vm.RowIdCmd{P1: 1, P2: 3}, - &vm.AddCmd{P1: 3, P2: 1, P3: 2}, - &vm.ResultRowCmd{P1: 2, P2: 1}, - &vm.NextCmd{P1: 1, P2: 5}, + &vm.RewindCmd{P1: 1, P2: 7}, + &vm.RowIdCmd{P1: 1, P2: 2}, + &vm.AddCmd{P1: 2, P2: 3, P3: 1}, + &vm.ResultRowCmd{P1: 1, P2: 1}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.IntegerCmd{P1: 10, P2: 3}, + &vm.GotoCmd{P2: 1}, }, mockCatalogSetup: func(m *mockSelectCatalog) *mockSelectCatalog { m.primaryKeyColumnName = "id" @@ -188,15 +191,16 @@ func TestSelectPlan(t *testing.T) { { description: "AllTable", expectedCommands: 
[]vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, + &vm.InitCmd{P2: 8}, &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 8}, + &vm.RewindCmd{P1: 1, P2: 7}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 2}, &vm.ResultRowCmd{P1: 1, P2: 2}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -222,14 +226,15 @@ func TestSelectPlan(t *testing.T) { { description: "SpecificColumnPrimaryKeyMiddleOrdinal", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, + &vm.InitCmd{P2: 7}, &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 7}, + &vm.RewindCmd{P1: 1, P2: 6}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ResultRowCmd{P1: 1, P2: 1}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -258,15 +263,16 @@ func TestSelectPlan(t *testing.T) { { description: "SpecificColumns", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, + &vm.InitCmd{P2: 8}, &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 8}, + &vm.RewindCmd{P1: 1, P2: 7}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ColumnCmd{P1: 1, P2: 1, P3: 2}, &vm.ResultRowCmd{P1: 1, P2: 2}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -300,12 +306,13 @@ func TestSelectPlan(t *testing.T) { { description: "JustCountAggregate", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, + &vm.InitCmd{P2: 5}, &vm.OpenReadCmd{P1: 1, P2: 2}, &vm.CountCmd{P1: 1, P2: 1}, &vm.ResultRowCmd{P1: 1, P2: 1}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: 
&compiler.StmtBase{}, @@ -322,19 +329,20 @@ func TestSelectPlan(t *testing.T) { { description: "Operators", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.IntegerCmd{P1: 1, P2: 1}, - &vm.IntegerCmd{P1: 18, P2: 2}, - &vm.IntegerCmd{P1: 387420489, P2: 3}, - &vm.IntegerCmd{P1: 81, P2: 4}, - &vm.IntegerCmd{P1: 0, P2: 5}, - &vm.CopyCmd{P1: 1, P2: 6}, - &vm.CopyCmd{P1: 2, P2: 7}, - &vm.CopyCmd{P1: 3, P2: 8}, - &vm.CopyCmd{P1: 4, P2: 9}, - &vm.CopyCmd{P1: 5, P2: 10}, - &vm.ResultRowCmd{P1: 6, P2: 5}, + &vm.InitCmd{P2: 8}, + &vm.CopyCmd{P1: 6, P2: 1}, + &vm.CopyCmd{P1: 7, P2: 2}, + &vm.CopyCmd{P1: 8, P2: 3}, + &vm.CopyCmd{P1: 9, P2: 4}, + &vm.CopyCmd{P1: 10, P2: 5}, + &vm.ResultRowCmd{P1: 1, P2: 5}, &vm.HaltCmd{}, + &vm.IntegerCmd{P1: 1, P2: 6}, + &vm.IntegerCmd{P1: 18, P2: 7}, + &vm.IntegerCmd{P1: 387420489, P2: 8}, + &vm.IntegerCmd{P1: 81, P2: 9}, + &vm.IntegerCmd{P1: 0, P2: 10}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -379,18 +387,20 @@ func TestSelectPlan(t *testing.T) { }, }, { - description: "with where clause", + description: "WithWhereClause", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, - &vm.IntegerCmd{P1: 1, P2: 1}, + &vm.InitCmd{P2: 9}, &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 9}, - &vm.RowIdCmd{P1: 1, P2: 2}, - &vm.NotEqualCmd{P1: 2, P2: 8, P3: 1}, - &vm.ResultRowCmd{P1: 2, P2: 1}, - &vm.NextCmd{P1: 1, P2: 5}, + &vm.RewindCmd{P1: 1, P2: 8}, + &vm.RowIdCmd{P1: 1, P2: 1}, + &vm.NotEqualCmd{P1: 1, P2: 7, P3: 2}, + &vm.RowIdCmd{P1: 1, P2: 4}, + &vm.ResultRowCmd{P1: 4, P2: 1}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.IntegerCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -418,11 +428,12 @@ func TestSelectPlan(t *testing.T) { { description: "ConstantString", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.StringCmd{P1: 1, P4: "foo"}, - 
&vm.CopyCmd{P1: 1, P2: 2}, - &vm.ResultRowCmd{P1: 2, P2: 1}, + &vm.InitCmd{P2: 4}, + &vm.CopyCmd{P1: 2, P2: 1}, + &vm.ResultRowCmd{P1: 1, P2: 1}, &vm.HaltCmd{}, + &vm.StringCmd{P1: 2, P4: "foo"}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -451,10 +462,8 @@ func TestSelectPlan(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range c.expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } + if err := assertCommandsMatch(plan.Commands, c.expectedCommands); err != nil { + t.Error(err) } }) } diff --git a/planner/update.go b/planner/update.go index 7aa126e..8430710 100644 --- a/planner/update.go +++ b/planner/update.go @@ -4,6 +4,7 @@ import ( "errors" "slices" + "github.com/chirst/cdb/catalog" "github.com/chirst/cdb/compiler" "github.com/chirst/cdb/vm" ) @@ -14,24 +15,14 @@ type updateCatalog interface { GetRootPageNumber(string) (int, error) GetColumns(string) ([]string, error) GetPrimaryKeyColumn(string) (string, error) + GetColumnType(tableName string, columnName string) (catalog.CdbType, error) } // updatePanner houses the query planner and execution planner for a update // statement. type updatePlanner struct { - queryPlanner *updateQueryPlanner - executionPlanner *updateExecutionPlanner -} - -// updateQueryPlanner generates a queryPlan for the given update statement. -type updateQueryPlanner struct { - catalog updateCatalog - stmt *compiler.UpdateStmt - queryPlan *updateNode -} - -// updateExecutionPlanner generates a byte code routine for the given queryPlan. -type updateExecutionPlanner struct { + catalog updateCatalog + stmt *compiler.UpdateStmt queryPlan *updateNode executionPlan *vm.ExecutionPlan } @@ -39,40 +30,35 @@ type updateExecutionPlanner struct { // NewUpdate create a update planner. 
func NewUpdate(catalog updateCatalog, stmt *compiler.UpdateStmt) *updatePlanner { return &updatePlanner{ - queryPlanner: &updateQueryPlanner{ - catalog: catalog, - stmt: stmt, - }, - executionPlanner: &updateExecutionPlanner{ - executionPlan: vm.NewExecutionPlan( - catalog.GetVersion(), - stmt.Explain, - ), - }, + catalog: catalog, + stmt: stmt, + executionPlan: vm.NewExecutionPlan( + catalog.GetVersion(), + stmt.Explain, + ), } } // QueryPlan sets up a high level plan to be passed to ExecutionPlan. func (p *updatePlanner) QueryPlan() (*QueryPlan, error) { - qp, err := p.queryPlanner.getQueryPlan() - if err != nil { - return nil, err - } - p.executionPlanner.queryPlan = p.queryPlanner.queryPlan - return qp, err -} - -// getQueryPlan returns a updateNode with a high level plan. -func (p *updateQueryPlanner) getQueryPlan() (*QueryPlan, error) { rootPage, err := p.catalog.GetRootPageNumber(p.stmt.TableName) if err != nil { return nil, errTableNotExist } updateNode := &updateNode{ - rootPage: rootPage, - recordExprs: []compiler.Expr{}, - } + updateExprs: []compiler.Expr{}, + tableName: p.stmt.TableName, + rootPageNumber: rootPage, + cursorId: 1, + } + logicalPlan := newQueryPlan( + updateNode, + p.stmt.ExplainQueryPlan, + transactionTypeWrite, + ) + updateNode.plan = logicalPlan p.queryPlan = updateNode + logicalPlan.root = updateNode if err := p.errIfPrimaryKeySet(); err != nil { return nil, err @@ -86,23 +72,37 @@ func (p *updateQueryPlanner) getQueryPlan() (*QueryPlan, error) { return nil, err } - if err := p.errIfSetExprNotSupported(); err != nil { - return nil, err - } - - if err := p.includeUpdate(); err != nil { - return nil, err + scanNode := &scanNode{ + plan: logicalPlan, + tableName: p.stmt.TableName, + rootPageNumber: rootPage, + cursorId: 1, + isWriteCursor: true, + } + if p.stmt.Predicate != nil { + cev := &catalogExprVisitor{} + cev.Init(p.catalog, p.stmt.TableName) + p.stmt.Predicate.BreadthWalk(cev) + filterNode := &filterNode{ + plan: logicalPlan, + 
predicate: p.stmt.Predicate, + parent: updateNode, + child: scanNode, + cursorId: 1, + } + updateNode.child = filterNode + scanNode.parent = filterNode + } else { + scanNode.parent = updateNode + updateNode.child = scanNode } - return &QueryPlan{ - ExplainQueryPlan: p.stmt.ExplainQueryPlan, - root: updateNode, - }, nil + return logicalPlan, nil } // errIfPrimaryKeySet checks the primary key isn't being updated because it // could cause an infinite loop if not handled properly. -func (p *updateQueryPlanner) errIfPrimaryKeySet() error { +func (p *updatePlanner) errIfPrimaryKeySet() error { pkColumnName, err := p.catalog.GetPrimaryKeyColumn(p.stmt.TableName) if err != nil { return err @@ -115,7 +115,7 @@ func (p *updateQueryPlanner) errIfPrimaryKeySet() error { // errIfSetNotOnDestinationTable checks the set list has column names that are // part of the table being updated. -func (p *updateQueryPlanner) errIfSetNotOnDestinationTable() error { +func (p *updatePlanner) errIfSetNotOnDestinationTable() error { schemaColumns, err := p.catalog.GetColumns(p.stmt.TableName) if err != nil { return err @@ -130,7 +130,7 @@ func (p *updateQueryPlanner) errIfSetNotOnDestinationTable() error { // setQueryPlanRecordExpressions populates the query plan with appropriate // expressions for setting up to make a record. 
-func (p *updateQueryPlanner) setQueryPlanRecordExpressions() error { +func (p *updatePlanner) setQueryPlanRecordExpressions() error { schemaColumns, err := p.catalog.GetColumns(p.stmt.TableName) if err != nil { return err @@ -141,14 +141,17 @@ func (p *updateQueryPlanner) setQueryPlanRecordExpressions() error { } idx := 0 for _, schemaColumn := range schemaColumns { + if schemaColumn == pkColName { + continue + } if setListExpression, ok := p.stmt.SetList[schemaColumn]; ok { - p.queryPlan.recordExprs = append( - p.queryPlan.recordExprs, + p.queryPlan.updateExprs = append( + p.queryPlan.updateExprs, setListExpression, ) } else { - p.queryPlan.recordExprs = append( - p.queryPlan.recordExprs, + p.queryPlan.updateExprs = append( + p.queryPlan.updateExprs, &compiler.ColumnRef{ Table: p.stmt.TableName, Column: schemaColumn, @@ -161,155 +164,23 @@ func (p *updateQueryPlanner) setQueryPlanRecordExpressions() error { idx += 1 } } - return nil -} - -// errIfSetExprNotSupported is temporary until more expressions can be supported -// in the execution plan. 
-func (p *updateQueryPlanner) errIfSetExprNotSupported() error { - for _, e := range p.queryPlan.recordExprs { - switch e.(type) { - case *compiler.IntLit: - continue - case *compiler.StringLit: - continue - case *compiler.ColumnRef: - continue - default: - return errors.New("set list expression not supported") - } - } - return nil -} - -func (p *updateQueryPlanner) includeUpdate() error { - if p.stmt.Predicate == nil { - return nil - } - p.queryPlan.predicate = p.stmt.Predicate - t, ok := p.queryPlan.predicate.(*compiler.BinaryExpr) - supportErr := errors.New("only pk update supported in where clause") - if !ok { - return supportErr - } - l, ok := t.Left.(*compiler.ColumnRef) - if !ok { - return supportErr - } - pkColName, err := p.catalog.GetPrimaryKeyColumn(p.stmt.TableName) - if err != nil { - return err - } - if l.Column != pkColName { - return supportErr - } - _, ok = t.Right.(*compiler.IntLit) - if !ok { - return supportErr - } - if t.Operator != compiler.OpEq { - return supportErr + for i := range p.queryPlan.updateExprs { + cev := &catalogExprVisitor{} + cev.Init(p.catalog, p.stmt.TableName) + p.queryPlan.updateExprs[i].BreadthWalk(cev) } return nil } // Execution plan is a byte code routine based off a high level query plan. func (p *updatePlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { - if p.queryPlanner.queryPlan == nil { + if p.queryPlan == nil { _, err := p.QueryPlan() if err != nil { return nil, err } } - return p.executionPlanner.getExecutionPlan() -} - -// getExecutionPlan transforms a query plan to a byte code routine. 
-func (p *updateExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { - freeRegisterCounter := 1 - // Init - p.executionPlan.Append(&vm.InitCmd{P2: 1}) - p.executionPlan.Append(&vm.TransactionCmd{P2: 1}) - cursorId := 1 - p.executionPlan.Append(&vm.OpenWriteCmd{P1: cursorId, P2: p.queryPlan.rootPage}) - - // Go to start of table - rewindCmd := &vm.RewindCmd{P1: cursorId} // P2 deferred - p.executionPlan.Append(rewindCmd) - - // Loop - loopStartAddress := len(p.executionPlan.Commands) - - // If needed, include jump for if. - var notEqCmd *vm.NotEqualCmd - if p.queryPlan.predicate != nil { - p.executionPlan.Append(&vm.RowIdCmd{P1: cursorId, P2: freeRegisterCounter}) - freeRegisterCounter += 1 - // No ok checks because done in logical plan. - pe := p.queryPlan.predicate.(*compiler.BinaryExpr) - r := pe.Right.(*compiler.IntLit) - p.executionPlan.Append(&vm.IntegerCmd{P1: r.Value, P2: freeRegisterCounter}) - freeRegisterCounter += 1 - notEqCmd = &vm.NotEqualCmd{ - P1: freeRegisterCounter - 2, - P2: -1, // deferred - P3: freeRegisterCounter - 1, - } - p.executionPlan.Append(notEqCmd) - } - - // take each item in the set list and build to make a record - loopStartRegister := freeRegisterCounter - var pkRegister int - for _, expression := range p.queryPlan.recordExprs { - switch typedExpression := expression.(type) { - case *compiler.ColumnRef: - if typedExpression.IsPrimaryKey { - p.executionPlan.Append(&vm.RowIdCmd{ - P1: cursorId, - P2: freeRegisterCounter, - }) - pkRegister = freeRegisterCounter - } else { - p.executionPlan.Append(&vm.ColumnCmd{ - P1: cursorId, - P2: typedExpression.ColIdx, - P3: freeRegisterCounter, - }) - } - case *compiler.IntLit: - p.executionPlan.Append(&vm.IntegerCmd{ - P1: typedExpression.Value, - P2: freeRegisterCounter, - }) - case *compiler.StringLit: - p.executionPlan.Append(&vm.StringCmd{ - P1: freeRegisterCounter, - P4: typedExpression.Value, - }) - default: - return nil, errors.New("expression not supported") - } - 
freeRegisterCounter += 1 - } - p.executionPlan.Append(&vm.MakeRecordCmd{ - P1: loopStartRegister + 1, // plus 1 for the pk - P2: len(p.queryPlan.recordExprs) - 1, // minus 1 for the pk - P3: freeRegisterCounter, - }) - p.executionPlan.Append(&vm.DeleteCmd{P1: cursorId}) - p.executionPlan.Append(&vm.InsertCmd{ - P1: cursorId, - P2: freeRegisterCounter, - P3: pkRegister, - }) - p.executionPlan.Append(&vm.NextCmd{P1: cursorId, P2: loopStartAddress}) - if notEqCmd != nil { - notEqCmd.P2 = len(p.executionPlan.Commands) - 1 - } - - // End - p.executionPlan.Append(&vm.HaltCmd{}) - rewindCmd.P2 = len(p.executionPlan.Commands) - 1 + p.queryPlan.plan.compile() + p.executionPlan.Commands = p.queryPlan.plan.commands return p.executionPlan, nil } diff --git a/planner/update_test.go b/planner/update_test.go index 28e753e..7e45980 100644 --- a/planner/update_test.go +++ b/planner/update_test.go @@ -2,9 +2,9 @@ package planner import ( "errors" - "reflect" "testing" + "github.com/chirst/cdb/catalog" "github.com/chirst/cdb/compiler" "github.com/chirst/cdb/vm" ) @@ -40,6 +40,10 @@ func (*mockUpdateCatalog) GetPrimaryKeyColumn(tableName string) (string, error) return "", errors.New("err mock catalog pk") } +func (mockUpdateCatalog) GetColumnType(tableName string, columnName string) (catalog.CdbType, error) { + return catalog.CdbType{ID: catalog.CTInt}, nil +} + func TestUpdate(t *testing.T) { ast := &compiler.UpdateStmt{ StmtBase: &compiler.StmtBase{}, @@ -51,28 +55,28 @@ func TestUpdate(t *testing.T) { }, } expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, + &vm.InitCmd{P2: 11}, &vm.OpenWriteCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 11}, + &vm.RewindCmd{P1: 1, P2: 10}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 2}, - &vm.IntegerCmd{P1: 1, P2: 3}, - &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 4}, + &vm.CopyCmd{P1: 4, P2: 3}, + &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 5}, &vm.DeleteCmd{P1: 1}, - &vm.InsertCmd{P1: 1, P2: 4, P3: 1}, - 
&vm.NextCmd{P1: 1, P2: 4}, + &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.IntegerCmd{P1: 1, P2: 4}, + &vm.GotoCmd{P2: 1}, } mockCatalog := &mockUpdateCatalog{} plan, err := NewUpdate(mockCatalog, ast).ExecutionPlan() if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) } } @@ -87,7 +91,8 @@ func TestUpdateWithWhere(t *testing.T) { }, Predicate: &compiler.BinaryExpr{ Left: &compiler.ColumnRef{ - Column: "id", + Column: "id", + IsPrimaryKey: true, }, Operator: compiler.OpEq, Right: &compiler.IntLit{ @@ -96,30 +101,29 @@ func TestUpdateWithWhere(t *testing.T) { }, } expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, + &vm.InitCmd{P2: 13}, &vm.OpenWriteCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 14}, + &vm.RewindCmd{P1: 1, P2: 12}, &vm.RowIdCmd{P1: 1, P2: 1}, - &vm.IntegerCmd{P1: 1, P2: 2}, - &vm.NotEqualCmd{P1: 1, P2: 13, P3: 2}, - &vm.RowIdCmd{P1: 1, P2: 3}, - &vm.ColumnCmd{P1: 1, P2: 0, P3: 4}, - &vm.IntegerCmd{P1: 1, P2: 5}, - &vm.MakeRecordCmd{P1: 4, P2: 2, P3: 6}, + &vm.NotEqualCmd{P1: 1, P2: 11, P3: 2}, + &vm.RowIdCmd{P1: 1, P2: 4}, + &vm.ColumnCmd{P1: 1, P2: 0, P3: 5}, + &vm.CopyCmd{P1: 2, P2: 6}, + &vm.MakeRecordCmd{P1: 5, P2: 2, P3: 7}, &vm.DeleteCmd{P1: 1}, - &vm.InsertCmd{P1: 1, P2: 6, P3: 3}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.InsertCmd{P1: 1, P2: 7, P3: 4}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.IntegerCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, } mockCatalog := &mockUpdateCatalog{} plan, err := NewUpdate(mockCatalog, ast).ExecutionPlan() if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, 
plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) } } diff --git a/vm/vm.go b/vm/vm.go index 8f321f4..9dd2ea9 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -45,6 +45,11 @@ type Command interface { explain(addr int) []*string } +// JumpCommand is a command capable of jumping +type JumpCommand interface { + SetJumpAddress(address int) +} + type cmdRes struct { doHalt bool nextAddress int @@ -472,6 +477,20 @@ func (c *NextCmd) explain(addr int) []*string { return formatExplain(addr, "Next", c.P1, c.P2, c.P3, c.P4, c.P5, comment) } +// GotoCmd jumps to address P2 +type GotoCmd cmd + +func (c *GotoCmd) execute(vm *vm, routine *routine) cmdRes { + return cmdRes{ + nextAddress: c.P2, + } +} + +func (c *GotoCmd) explain(addr int) []*string { + comment := fmt.Sprintf("Goto to addr[%d]", c.P2) + return formatExplain(addr, "Goto", c.P1, c.P2, c.P3, c.P4, c.P5, comment) +} + // MakeRecordCmd makes a byte array record for registers P1 through P1+P2-1 and // stores the record in register P3. type MakeRecordCmd cmd @@ -837,6 +856,10 @@ func (c *NotEqualCmd) explain(addr int) []*string { return formatExplain(addr, "NotEqual", c.P1, c.P2, c.P3, c.P4, c.P5, comment) } +func (c *NotEqualCmd) SetJumpAddress(address int) { + c.P2 = address +} + // IfNotCmd jumps to P2 if P1 is false otherwise fall through. 
type IfNotCmd cmd @@ -856,6 +879,10 @@ func (c *IfNotCmd) explain(addr int) []*string { return formatExplain(addr, "IfNot", c.P1, c.P2, c.P3, c.P4, c.P5, comment) } +func (c *IfNotCmd) SetJumpAddress(address int) { + c.P2 = address +} + // GteCmd if P1 is greater than or equal to P3 jump to P2 type GteCmd cmd @@ -895,6 +922,10 @@ func (c *GteCmd) explain(addr int) []*string { return formatExplain(addr, "Gte", c.P1, c.P2, c.P3, c.P4, c.P5, comment) } +func (c *GteCmd) SetJumpAddress(address int) { + c.P2 = address +} + // LteCmd if P1 is less than or equal to P3 jump to P2 type LteCmd cmd @@ -934,6 +965,10 @@ func (c *LteCmd) explain(addr int) []*string { return formatExplain(addr, "Lte", c.P1, c.P2, c.P3, c.P4, c.P5, comment) } +func (c *LteCmd) SetJumpAddress(address int) { + c.P2 = address +} + // VariableCmd substitutes variable number P1 into register P2. Where P1 is a // zero based index. type VariableCmd cmd