From a508c39fe2cce71c9521923dca46403199fde5ae Mon Sep 17 00:00:00 2001 From: Colton Date: Mon, 1 Dec 2025 21:29:22 -0700 Subject: [PATCH 01/17] begin v2 planner for update --- planner/update.go | 229 +++++++++++++++++++++++++++++++++++++++++ planner/update_test.go | 4 + vm/vm.go | 14 +++ 3 files changed, 247 insertions(+) diff --git a/planner/update.go b/planner/update.go index 7aa126e..65ab87e 100644 --- a/planner/update.go +++ b/planner/update.go @@ -2,6 +2,7 @@ package planner import ( "errors" + "fmt" "slices" "github.com/chirst/cdb/compiler" @@ -313,3 +314,231 @@ func (p *updateExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { rewindCmd.P2 = len(p.executionPlan.Commands) - 1 return p.executionPlan, nil } + +// CREATE TABLE table (? INTEGER PRIMARY KEY, ? INTEGER, ? TEXT) +// +// - create + +// SELECT * FROM table +// WHERE ? = ? +// +// - project +// - filter +// - scan + +// SELECT constant +// WHERE ? = ? +// +// - project +// - filter +// - constant + +// SELECT COUNT(*) FROM table +// +// - count (similar to scan and breaks rule and no project) + +// INSERT INTO table (?, ?) VALUES (?, ?) +// +// - insert +// - constant? + +// UPDATE TABLE table +// SET ? = ? +// WHERE ? = ? 
+// +// - update +// - filter +// - scan + +func generateUpdate() { + logicalPlan := &planV2{ + commands: []vm.Command{}, + constInts: make(map[int]int), + constStrings: make(map[string]int), + freeRegister: 1, + transactionType: 2, + cursorId: 1, + } + un := &updateNodeV2{ + plan: logicalPlan, + } + fn := &filterNodeV2{ + plan: logicalPlan, + } + fn.parent = un + un.child = fn + sn := &scanNodeV2{ + plan: logicalPlan, + } + sn.parent = fn + fn.child = sn + logicalPlan.root = un + logicalPlan.compile() + for i := range logicalPlan.commands { + fmt.Printf("%d %#v\n", i+1, logicalPlan.commands[i]) + } +} + +type planV2 struct { + root nodeV2 + commands []vm.Command + constInts map[int]int // int to register + constStrings map[string]int // string to register + freeRegister int + transactionType int // 0 none, 1 read, 2 write + cursorId int +} + +// declareConstInt gets or sets a register with the const value and returns the +// register. It is guaranteed the value will be in the register for the duration +// of the plan. +func (p *planV2) declareConstInt(i int) int { + _, ok := p.constInts[i] + if !ok { + p.constInts[i] = p.freeRegister + p.freeRegister += 1 + } + return p.constInts[i] +} + +// declareConstString gets or sets a register with the const value and returns +// the register. It is guaranteed the value will be in the register for the +// duration of the plan. 
+func (p *planV2) declareConstString(s string) int { + _, ok := p.constStrings[s] + if !ok { + p.constStrings[s] = p.freeRegister + p.freeRegister += 1 + } + return p.constStrings[s] +} + +func (p *planV2) compile() { + initCmd := &vm.InitCmd{} + p.commands = append(p.commands, initCmd) + p.root.produce() + p.commands = append(p.commands, &vm.HaltCmd{}) + initCmd.P2 = len(p.commands) + 1 + p.pushTransaction() + p.pushConstants() + p.commands = append(p.commands, &vm.GotoCmd{P1: 2}) +} + +func (p *planV2) pushTransaction() { + switch p.transactionType { + case 0: + return + case 1: + p.commands = append(p.commands, &vm.TransactionCmd{P2: 1}) + case 2: + p.commands = append(p.commands, &vm.TransactionCmd{P2: 2}) + default: + panic("unexpected transaction type") + } +} + +func (p *planV2) pushConstants() { + for v := range p.constInts { + p.commands = append(p.commands, &vm.IntegerCmd{P1: v, P2: p.constInts[v]}) + } + for v := range p.constStrings { + p.commands = append(p.commands, &vm.StringCmd{P1: p.constStrings[v], P4: v}) + } +} + +type nodeV2 interface { + produce() + consume() +} + +type updateNodeV2 struct { + child nodeV2 + plan *planV2 +} + +func (u *updateNodeV2) produce() { + u.child.produce() +} + +func (u *updateNodeV2) consume() { + startRecordRegister := u.plan.freeRegister + u.plan.commands = append(u.plan.commands, &vm.RowIdCmd{ + P1: u.plan.cursorId, + P2: u.plan.freeRegister, + }) + rowIdRegister := u.plan.freeRegister + u.plan.freeRegister += 1 + u.plan.commands = append(u.plan.commands, &vm.CopyCmd{ + P1: u.plan.declareConstInt(1), + P2: u.plan.freeRegister, + }) + u.plan.freeRegister += 1 + u.plan.commands = append(u.plan.commands, &vm.ColumnCmd{ + P1: u.plan.cursorId, + P2: 0, + P3: u.plan.freeRegister, + }) + endRecordRegister := u.plan.freeRegister + u.plan.freeRegister += 1 + u.plan.commands = append(u.plan.commands, &vm.MakeRecordCmd{ + P1: startRecordRegister, + P2: endRecordRegister, + P3: u.plan.freeRegister, + }) + recordRegister := 
u.plan.freeRegister + u.plan.freeRegister += 1 + u.plan.commands = append(u.plan.commands, &vm.DeleteCmd{ + P1: u.plan.cursorId, + }) + u.plan.commands = append(u.plan.commands, &vm.InsertCmd{ + P1: u.plan.cursorId, + P2: recordRegister, + P3: rowIdRegister, + }) +} + +type filterNodeV2 struct { + child nodeV2 + parent nodeV2 + plan *planV2 +} + +func (f *filterNodeV2) produce() { + f.child.produce() +} + +func (f *filterNodeV2) consume() { + f.plan.commands = append(f.plan.commands, &vm.ColumnCmd{ + P1: f.plan.cursorId, + P2: 1, + P3: f.plan.freeRegister, + }) + notEqualCmd := &vm.NotEqualCmd{ + P1: f.plan.freeRegister, + P3: f.plan.declareConstInt(1), + } + f.plan.commands = append(f.plan.commands, notEqualCmd) + f.parent.consume() + notEqualCmd.P2 = len(f.plan.commands) + 1 +} + +type scanNodeV2 struct { + parent nodeV2 + plan *planV2 +} + +func (s *scanNodeV2) produce() { + s.consume() +} + +func (s *scanNodeV2) consume() { + rewindCmd := &vm.RewindCmd{P1: s.plan.cursorId} + s.plan.commands = append(s.plan.commands, rewindCmd) + loopBeginAddress := len(s.plan.commands) + 1 + s.parent.consume() + s.plan.commands = append(s.plan.commands, &vm.NextCmd{ + P1: s.plan.cursorId, + P2: loopBeginAddress, + }) + rewindCmd.P2 = len(s.plan.commands) + 1 +} diff --git a/planner/update_test.go b/planner/update_test.go index 28e753e..320e156 100644 --- a/planner/update_test.go +++ b/planner/update_test.go @@ -123,3 +123,7 @@ func TestUpdateWithWhere(t *testing.T) { } } } + +func TestFoo(t *testing.T) { + generateUpdate() +} diff --git a/vm/vm.go b/vm/vm.go index 8f321f4..ee7dc5d 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -472,6 +472,20 @@ func (c *NextCmd) explain(addr int) []*string { return formatExplain(addr, "Next", c.P1, c.P2, c.P3, c.P4, c.P5, comment) } +// GotoCmd jumps to address P2 +type GotoCmd cmd + +func (c *GotoCmd) execute(vm *vm, routine *routine) cmdRes { + return cmdRes{ + nextAddress: c.P2, + } +} + +func (c *GotoCmd) explain(addr int) []*string { + comment := 
fmt.Sprintf("Goto to addr[%d]", c.P2) + return formatExplain(addr, "Goto", c.P1, c.P2, c.P3, c.P4, c.P5, comment) +} + // MakeRecordCmd makes a byte array record for registers P1 through P1+P2-1 and // stores the record in register P3. type MakeRecordCmd cmd From fc44b5943bd63434e34a0b51160358b2d03289c3 Mon Sep 17 00:00:00 2001 From: Colton Date: Sat, 20 Dec 2025 22:35:44 -0700 Subject: [PATCH 02/17] progress on generator --- planner/generator.go | 345 +++++++++++++++++++++++++++++++++ planner/predicate_generator.go | 209 ++++++++++++++++++++ planner/result_generator.go | 125 ++++++++++++ planner/update.go | 229 ---------------------- planner/update_test.go | 1 + vm/vm.go | 21 ++ 6 files changed, 701 insertions(+), 229 deletions(-) create mode 100644 planner/generator.go create mode 100644 planner/predicate_generator.go create mode 100644 planner/result_generator.go diff --git a/planner/generator.go b/planner/generator.go new file mode 100644 index 0000000..556b5cf --- /dev/null +++ b/planner/generator.go @@ -0,0 +1,345 @@ +package planner + +import ( + "fmt" + + "github.com/chirst/cdb/compiler" + "github.com/chirst/cdb/vm" +) + +// transactionType defines possible transactions for a query plan. +type transactionType int + +const ( + transactionTypeNone transactionType = 0 + transactionTypeRead transactionType = 1 + transactionTypeWrite transactionType = 2 +) + +// planV2 holds the necessary data and receivers for generating a plan as well +// as the final commands that define the execution plan. +type planV2 struct { + // root is the root node of the plan tree. + root nodeV2 + // commands is a list of commands that define the plan. + commands []vm.Command + // constInts is a mapping of constant integer values to the registers that + // contain the value. + constInts map[int]int + // constStrings is a mapping of constant string values to the registers that + // contain the value. 
+ constStrings map[string]int + // constVars is a mapping of a variable's position to the registers that + // holds the variable's value. + constVars map[int]int + // freeRegister is a counter containing the next free register in the plan. + freeRegister int + // transactionType defines what kind of transaction the plan will need. + transactionType transactionType + // cursorId is the id of the cursor the plan is using. Note plans will + // eventually need to use more than one cursor, but for now it is convenient + // to pull the id from here. + cursorId int + // rootPageNumber is the root page number of the table cursorId is + // associated with. This should be a map at some point when multiple tables + // can be queried in one plan. + rootPageNumber int +} + +func generateUpdate() { + logicalPlan := &planV2{ + commands: []vm.Command{}, + constInts: make(map[int]int), + constStrings: make(map[string]int), + freeRegister: 1, + transactionType: transactionTypeWrite, + cursorId: 1, + } + un := &updateNodeV2{ + plan: logicalPlan, + updateExprs: []compiler.Expr{ + &compiler.ColumnRef{ + IsPrimaryKey: false, + ColIdx: 0, + }, + }, + } + fn := &filterNodeV2{ + plan: logicalPlan, + predicate: &compiler.IntLit{Value: 277}, + } + fn.parent = un + un.child = fn + sn := &scanNodeV2{ + plan: logicalPlan, + } + sn.parent = fn + fn.child = sn + logicalPlan.root = un + logicalPlan.compile() + for i := range logicalPlan.commands { + fmt.Printf("%d %#v\n", i+1, logicalPlan.commands[i]) + } +} + +func generateSelect() { + logicalPlan := &planV2{ + commands: []vm.Command{}, + constInts: make(map[int]int), + constStrings: make(map[string]int), + freeRegister: 1, + transactionType: transactionTypeRead, + cursorId: 1, + } + pn := &projectNodeV2{ + plan: logicalPlan, + } + fn := &filterNodeV2{ + plan: logicalPlan, + } + fn.parent = pn + pn.child = fn + sn := &scanNodeV2{ + plan: logicalPlan, + } + sn.parent = fn + fn.child = sn + logicalPlan.root = pn + logicalPlan.compile() + for i := 
range logicalPlan.commands { + fmt.Printf("%d %#v\n", i+1, logicalPlan.commands[i]) + } +} + +// declareConstInt gets or sets a register with the const value and returns the +// register. It is guaranteed the value will be in the register for the duration +// of the plan. +func (p *planV2) declareConstInt(i int) int { + _, ok := p.constInts[i] + if !ok { + p.constInts[i] = p.freeRegister + p.freeRegister += 1 + } + return p.constInts[i] +} + +// declareConstString gets or sets a register with the const value and returns +// the register. It is guaranteed the value will be in the register for the +// duration of the plan. +func (p *planV2) declareConstString(s string) int { + _, ok := p.constStrings[s] + if !ok { + p.constStrings[s] = p.freeRegister + p.freeRegister += 1 + } + return p.constStrings[s] +} + +// declareConstVar gets or sets a register with the const value and returns +// the register. It is guaranteed the value will be in the register for the +// duration of the plan. +func (p *planV2) declareConstVar(position int) int { + _, ok := p.constVars[position] + if !ok { + p.constVars[position] = p.freeRegister + p.freeRegister += 1 + } + return p.constVars[position] +} + +func (p *planV2) compile() { + initCmd := &vm.InitCmd{} + p.commands = append(p.commands, initCmd) + p.root.produce() + p.commands = append(p.commands, &vm.HaltCmd{}) + initCmd.P2 = len(p.commands) + 1 + p.pushTransaction() + p.pushConstants() + p.commands = append(p.commands, &vm.GotoCmd{P1: 2}) +} + +func (p *planV2) pushTransaction() { + switch p.transactionType { + case transactionTypeNone: + return + case transactionTypeRead: + p.commands = append( + p.commands, + &vm.TransactionCmd{P2: int(p.transactionType)}, + ) + p.commands = append( + p.commands, + &vm.OpenReadCmd{P1: p.cursorId, P2: p.rootPageNumber}, + ) + case transactionTypeWrite: + p.commands = append( + p.commands, + &vm.TransactionCmd{P2: int(p.transactionType)}, + ) + p.commands = append( + p.commands, + 
&vm.OpenWriteCmd{P1: p.cursorId, P2: p.rootPageNumber}, + ) + default: + panic("unexpected transaction type") + } +} + +func (p *planV2) pushConstants() { + for v := range p.constInts { + p.commands = append(p.commands, &vm.IntegerCmd{P1: v, P2: p.constInts[v]}) + } + for v := range p.constStrings { + p.commands = append(p.commands, &vm.StringCmd{P1: p.constStrings[v], P4: v}) + } + for v := range p.constVars { + p.commands = append(p.commands, &vm.VariableCmd{P1: p.constVars[v], P2: v}) + } +} + +type nodeV2 interface { + produce() + consume() +} + +type updateNodeV2 struct { + child nodeV2 + plan *planV2 + // updateExprs is formed from the update statement AST. The idea is to + // provide an expression for each column where the expression is either a + // columnRef or the complex expression from the right hand side of the SET + // keyword. Note it is important to provide the expressions in their correct + // ordinal position as the generator will not try to order them correctly. + // + // The row id is not allowed to be updated at the moment because it could + // cause infinite loops due to it changing the physical location of the + // record. The query plan will have to use a temporary storage to update + // primary keys. + updateExprs []compiler.Expr +} + +func (u *updateNodeV2) produce() { + u.child.produce() +} + +func (u *updateNodeV2) consume() { + // RowID + u.plan.commands = append(u.plan.commands, &vm.RowIdCmd{ + P1: u.plan.cursorId, + P2: u.plan.freeRegister, + }) + rowIdRegister := u.plan.freeRegister + u.plan.freeRegister += 1 + + // Reserve a contiguous block of free registers for the columns. This block + // will be used in makeRecord. 
+ startRecordRegister := u.plan.freeRegister + u.plan.freeRegister += len(u.updateExprs) + endRecordRegister := u.plan.freeRegister + for _, e := range u.updateExprs { + generateExpressionTo(u.plan, e, startRecordRegister) + startRecordRegister += 1 + } + + // Make the record for inserting + u.plan.commands = append(u.plan.commands, &vm.MakeRecordCmd{ + P1: startRecordRegister, + P2: endRecordRegister, + P3: u.plan.freeRegister, + }) + recordRegister := u.plan.freeRegister + u.plan.freeRegister += 1 + + // Update by deleting then inserting + u.plan.commands = append(u.plan.commands, &vm.DeleteCmd{ + P1: u.plan.cursorId, + }) + u.plan.commands = append(u.plan.commands, &vm.InsertCmd{ + P1: u.plan.cursorId, + P2: recordRegister, + P3: rowIdRegister, + }) +} + +type filterNodeV2 struct { + child nodeV2 + parent nodeV2 + plan *planV2 + predicate compiler.Expr +} + +func (f *filterNodeV2) produce() { + f.child.produce() +} + +func (f *filterNodeV2) consume() { + if f.predicate == nil { + f.parent.consume() + return + } + jumpCommand := generatePredicate(f.plan, f.predicate) + f.parent.consume() + jumpCommand.SetJumpAddress(len(f.plan.commands) + 1) +} + +type scanNodeV2 struct { + parent nodeV2 + plan *planV2 +} + +func (s *scanNodeV2) produce() { + s.consume() +} + +func (s *scanNodeV2) consume() { + rewindCmd := &vm.RewindCmd{P1: s.plan.cursorId} + s.plan.commands = append(s.plan.commands, rewindCmd) + loopBeginAddress := len(s.plan.commands) + 1 + s.parent.consume() + s.plan.commands = append(s.plan.commands, &vm.NextCmd{ + P1: s.plan.cursorId, + P2: loopBeginAddress, + }) + rewindCmd.P2 = len(s.plan.commands) + 1 +} + +type projectNodeV2 struct { + child nodeV2 + plan *planV2 +} + +func (p *projectNodeV2) produce() { + p.child.produce() +} + +func (p *projectNodeV2) consume() { + startRegister := p.plan.freeRegister + p.plan.commands = append(p.plan.commands, &vm.RowIdCmd{ + P1: p.plan.cursorId, + P2: p.plan.freeRegister, + }) + p.plan.freeRegister += 1 + 
p.plan.commands = append(p.plan.commands, &vm.ColumnCmd{ + P1: p.plan.cursorId, + P2: 0, + P3: p.plan.freeRegister, + }) + p.plan.freeRegister += 1 + p.plan.commands = append(p.plan.commands, &vm.ResultRowCmd{ + P1: startRegister, + P2: p.plan.freeRegister - startRegister, + }) +} + +type constantNodeV2 struct { + parent nodeV2 + plan *planV2 +} + +func (c *constantNodeV2) produce() { + c.consume() +} + +func (c *constantNodeV2) consume() { + c.parent.consume() +} diff --git a/planner/predicate_generator.go b/planner/predicate_generator.go new file mode 100644 index 0000000..eb30e70 --- /dev/null +++ b/planner/predicate_generator.go @@ -0,0 +1,209 @@ +package planner + +import ( + "github.com/chirst/cdb/compiler" + "github.com/chirst/cdb/vm" +) + +// generatePredicate generates code to make a boolean jump for the given +// expression within the plan context. The function returns the jump command to +// lazily set the jump address. +func generatePredicate(plan *planV2, expression compiler.Expr) vm.JumpCommand { + pg := &predicateGenerator{} + pg.plan = plan + pg.commandOffset = len(pg.plan.commands) + pg.build(expression, 0) + return pg.jumpCommand +} + +// predicateGenerator builds commands to calculate the boolean result of an +// expression. +type predicateGenerator struct { + plan *planV2 + // openRegister is the next available register + openRegister int + // commandOffset is used to calculate the amount of commands already in the + // plan. + commandOffset int + // jumpCommand is the command used to make the jump. The command can be + // accessed to defer setting the jump address. 
+ jumpCommand vm.JumpCommand +} + +func (p *predicateGenerator) build(e compiler.Expr, level int) (int, error) { + switch ce := e.(type) { + case *compiler.BinaryExpr: + ol, err := p.build(ce.Left, level+1) + if err != nil { + return 0, err + } + or, err := p.build(ce.Right, level+1) + if err != nil { + return 0, err + } + r := p.getNextRegister() + switch ce.Operator { + case compiler.OpAdd: + p.plan.commands = append( + p.plan.commands, + &vm.AddCmd{P1: ol, P2: or, P3: r}, + ) + if level == 0 { + jc := &vm.IfNotCmd{P1: r} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return r, nil + case compiler.OpDiv: + p.plan.commands = append( + p.plan.commands, + &vm.DivideCmd{P1: ol, P2: or, P3: r}, + ) + if level == 0 { + jc := &vm.IfNotCmd{P1: r} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return r, nil + case compiler.OpMul: + p.plan.commands = append( + p.plan.commands, + &vm.MultiplyCmd{P1: ol, P2: or, P3: r}, + ) + if level == 0 { + jc := &vm.IfNotCmd{P1: r} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return r, nil + case compiler.OpExp: + p.plan.commands = append( + p.plan.commands, + &vm.ExponentCmd{P1: ol, P2: or, P3: r}, + ) + if level == 0 { + jc := &vm.IfNotCmd{P1: r} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return r, nil + case compiler.OpSub: + p.plan.commands = append( + p.plan.commands, + &vm.SubtractCmd{P1: ol, P2: or, P3: r}, + ) + if level == 0 { + jc := &vm.IfNotCmd{P1: r} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return r, nil + case compiler.OpEq: + if level == 0 { + jc := &vm.NotEqualCmd{P1: ol, P3: or} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + return 0, nil + } + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) + jumpOverCount := 2 + jumpAddress := len(p.plan.commands) + jumpOverCount + p.commandOffset + p.plan.commands = append( + 
p.plan.commands, + &vm.NotEqualCmd{P1: ol, P2: jumpAddress, P3: or}, + ) + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 1, P2: r}) + return r, nil + case compiler.OpLt: + if level == 0 { + jc := &vm.LteCmd{P1: or, P3: ol} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + return 0, nil + } + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) + jumpOverCount := 2 + jumpAddress := len(p.plan.commands) + jumpOverCount + p.commandOffset + p.plan.commands = append( + p.plan.commands, + &vm.GteCmd{P1: ol, P2: jumpAddress, P3: or}, + ) + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 1, P2: r}) + return r, nil + case compiler.OpGt: + if level == 0 { + jc := &vm.GteCmd{P1: or, P3: ol} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + return 0, nil + } + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) + jumpOverCount := 2 + jumpAddress := len(p.plan.commands) + jumpOverCount + p.commandOffset + p.plan.commands = append( + p.plan.commands, + &vm.LteCmd{P1: ol, P2: jumpAddress, P3: or}, + ) + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 1, P2: r}) + return r, nil + default: + panic("no vm command for operator") + } + case *compiler.ColumnRef: + colRefReg := p.valueRegisterFor(ce) + if level == 0 { + jc := &vm.IfNotCmd{P1: colRefReg} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return colRefReg, nil + case *compiler.IntLit: + cir := p.plan.declareConstInt(ce.Value) + if level == 0 { + jc := &vm.IfNotCmd{P1: cir} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return cir, nil + case *compiler.StringLit: + csr := p.plan.declareConstString(ce.Value) + if level == 0 { + jc := &vm.IfNotCmd{P1: csr} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return csr, nil + case *compiler.Variable: + cvr := p.plan.declareConstVar(ce.Position) + if level == 0 { + jc := 
&vm.IfNotCmd{P1: cvr} + p.jumpCommand = jc + p.plan.commands = append(p.plan.commands, jc) + } + return cvr, nil + } + panic("unhandled expression in predicate builder") +} + +func (p *predicateGenerator) valueRegisterFor(ce *compiler.ColumnRef) int { + if ce.IsPrimaryKey { + r := p.getNextRegister() + p.plan.commands = append(p.plan.commands, &vm.RowIdCmd{ + P1: p.plan.cursorId, + P2: r, + }) + return r + } + r := p.getNextRegister() + p.plan.commands = append(p.plan.commands, &vm.ColumnCmd{ + P1: p.plan.cursorId, + P2: ce.ColIdx, P3: r, + }) + return r +} + +func (p *predicateGenerator) getNextRegister() int { + r := p.openRegister + p.openRegister += 1 + return r +} diff --git a/planner/result_generator.go b/planner/result_generator.go new file mode 100644 index 0000000..739a09e --- /dev/null +++ b/planner/result_generator.go @@ -0,0 +1,125 @@ +package planner + +import ( + "github.com/chirst/cdb/compiler" + "github.com/chirst/cdb/vm" +) + +// generateExpressionTo takes the context of the plan and generates commands +// that land the result of the given expr in the toRegister. +func generateExpressionTo(plan *planV2, expr compiler.Expr, toRegister int) { + rg := &resultExprGenerator{} + rg.plan = plan + rg.outputRegister = toRegister + rg.commandOffset = len(rg.plan.commands) + rg.build(expr, 0) +} + +// resultExprGenerator builds commands for the given expression. +type resultExprGenerator struct { + plan *planV2 + // outputRegister is the target register for the result of the expression. + outputRegister int + // commandOffset is the amount of commands prior to calling this routine. + // Useful for calculating jump instructions. 
+ commandOffset int +} + +func (e *resultExprGenerator) build(root compiler.Expr, level int) int { + switch n := root.(type) { + case *compiler.BinaryExpr: + ol := e.build(n.Left, level+1) + or := e.build(n.Right, level+1) + r := e.getNextRegister(level) + switch n.Operator { + case compiler.OpAdd: + e.plan.commands = append(e.plan.commands, &vm.AddCmd{P1: ol, P2: or, P3: r}) + case compiler.OpDiv: + e.plan.commands = append(e.plan.commands, &vm.DivideCmd{P1: ol, P2: or, P3: r}) + case compiler.OpMul: + e.plan.commands = append(e.plan.commands, &vm.MultiplyCmd{P1: ol, P2: or, P3: r}) + case compiler.OpExp: + e.plan.commands = append(e.plan.commands, &vm.ExponentCmd{P1: ol, P2: or, P3: r}) + case compiler.OpSub: + e.plan.commands = append(e.plan.commands, &vm.SubtractCmd{P1: ol, P2: or, P3: r}) + case compiler.OpEq: + e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) + jumpOverCount := 2 + jumpAddress := len(e.plan.commands) + jumpOverCount + e.commandOffset + e.plan.commands = append( + e.plan.commands, + &vm.NotEqualCmd{P1: ol, P2: jumpAddress, P3: or}, + ) + e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 1, P2: r}) + case compiler.OpLt: + e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) + jumpOverCount := 2 + jumpAddress := len(e.plan.commands) + jumpOverCount + e.commandOffset + e.plan.commands = append( + e.plan.commands, + &vm.GteCmd{P1: ol, P2: jumpAddress, P3: or}, + ) + e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 1, P2: r}) + case compiler.OpGt: + e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) + jumpOverCount := 2 + jumpAddress := len(e.plan.commands) + jumpOverCount + e.commandOffset + e.plan.commands = append( + e.plan.commands, + &vm.LteCmd{P1: ol, P2: jumpAddress, P3: or}, + ) + e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 1, P2: r}) + default: + panic("no vm command for operator") + } + return r + case *compiler.ColumnRef: + r := 
e.getNextRegister(level) + if n.IsPrimaryKey { + e.plan.commands = append(e.plan.commands, &vm.RowIdCmd{P1: e.plan.cursorId, P2: r}) + } else { + e.plan.commands = append( + e.plan.commands, + &vm.ColumnCmd{P1: e.plan.cursorId, P2: n.ColIdx, P3: r}, + ) + } + return r + case *compiler.IntLit: + cir := e.plan.declareConstInt(n.Value) + if level == 0 { + e.plan.commands = append( + e.plan.commands, + &vm.CopyCmd{P1: cir, P2: e.outputRegister}, + ) + } + return cir + case *compiler.StringLit: + csr := e.plan.declareConstString(n.Value) + if level == 0 { + e.plan.commands = append( + e.plan.commands, + &vm.CopyCmd{P1: csr, P2: e.outputRegister}, + ) + } + return csr + case *compiler.Variable: + cvr := e.plan.declareConstVar(n.Position) + if level == 0 { + e.plan.commands = append( + e.plan.commands, + &vm.CopyCmd{P1: cvr, P2: e.outputRegister}, + ) + } + return cvr + } + panic("unhandled expression in expr command builder") +} + +func (e *resultExprGenerator) getNextRegister(level int) int { + if level == 0 { + return e.outputRegister + } + r := e.plan.freeRegister + e.plan.freeRegister += 1 + return r +} diff --git a/planner/update.go b/planner/update.go index 65ab87e..7aa126e 100644 --- a/planner/update.go +++ b/planner/update.go @@ -2,7 +2,6 @@ package planner import ( "errors" - "fmt" "slices" "github.com/chirst/cdb/compiler" @@ -314,231 +313,3 @@ func (p *updateExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { rewindCmd.P2 = len(p.executionPlan.Commands) - 1 return p.executionPlan, nil } - -// CREATE TABLE table (? INTEGER PRIMARY KEY, ? INTEGER, ? TEXT) -// -// - create - -// SELECT * FROM table -// WHERE ? = ? -// -// - project -// - filter -// - scan - -// SELECT constant -// WHERE ? = ? -// -// - project -// - filter -// - constant - -// SELECT COUNT(*) FROM table -// -// - count (similar to scan and breaks rule and no project) - -// INSERT INTO table (?, ?) VALUES (?, ?) -// -// - insert -// - constant? - -// UPDATE TABLE table -// SET ? = ? 
-// WHERE ? = ? -// -// - update -// - filter -// - scan - -func generateUpdate() { - logicalPlan := &planV2{ - commands: []vm.Command{}, - constInts: make(map[int]int), - constStrings: make(map[string]int), - freeRegister: 1, - transactionType: 2, - cursorId: 1, - } - un := &updateNodeV2{ - plan: logicalPlan, - } - fn := &filterNodeV2{ - plan: logicalPlan, - } - fn.parent = un - un.child = fn - sn := &scanNodeV2{ - plan: logicalPlan, - } - sn.parent = fn - fn.child = sn - logicalPlan.root = un - logicalPlan.compile() - for i := range logicalPlan.commands { - fmt.Printf("%d %#v\n", i+1, logicalPlan.commands[i]) - } -} - -type planV2 struct { - root nodeV2 - commands []vm.Command - constInts map[int]int // int to register - constStrings map[string]int // string to register - freeRegister int - transactionType int // 0 none, 1 read, 2 write - cursorId int -} - -// declareConstInt gets or sets a register with the const value and returns the -// register. It is guaranteed the value will be in the register for the duration -// of the plan. -func (p *planV2) declareConstInt(i int) int { - _, ok := p.constInts[i] - if !ok { - p.constInts[i] = p.freeRegister - p.freeRegister += 1 - } - return p.constInts[i] -} - -// declareConstString gets or sets a register with the const value and returns -// the register. It is guaranteed the value will be in the register for the -// duration of the plan. 
-func (p *planV2) declareConstString(s string) int { - _, ok := p.constStrings[s] - if !ok { - p.constStrings[s] = p.freeRegister - p.freeRegister += 1 - } - return p.constStrings[s] -} - -func (p *planV2) compile() { - initCmd := &vm.InitCmd{} - p.commands = append(p.commands, initCmd) - p.root.produce() - p.commands = append(p.commands, &vm.HaltCmd{}) - initCmd.P2 = len(p.commands) + 1 - p.pushTransaction() - p.pushConstants() - p.commands = append(p.commands, &vm.GotoCmd{P1: 2}) -} - -func (p *planV2) pushTransaction() { - switch p.transactionType { - case 0: - return - case 1: - p.commands = append(p.commands, &vm.TransactionCmd{P2: 1}) - case 2: - p.commands = append(p.commands, &vm.TransactionCmd{P2: 2}) - default: - panic("unexpected transaction type") - } -} - -func (p *planV2) pushConstants() { - for v := range p.constInts { - p.commands = append(p.commands, &vm.IntegerCmd{P1: v, P2: p.constInts[v]}) - } - for v := range p.constStrings { - p.commands = append(p.commands, &vm.StringCmd{P1: p.constStrings[v], P4: v}) - } -} - -type nodeV2 interface { - produce() - consume() -} - -type updateNodeV2 struct { - child nodeV2 - plan *planV2 -} - -func (u *updateNodeV2) produce() { - u.child.produce() -} - -func (u *updateNodeV2) consume() { - startRecordRegister := u.plan.freeRegister - u.plan.commands = append(u.plan.commands, &vm.RowIdCmd{ - P1: u.plan.cursorId, - P2: u.plan.freeRegister, - }) - rowIdRegister := u.plan.freeRegister - u.plan.freeRegister += 1 - u.plan.commands = append(u.plan.commands, &vm.CopyCmd{ - P1: u.plan.declareConstInt(1), - P2: u.plan.freeRegister, - }) - u.plan.freeRegister += 1 - u.plan.commands = append(u.plan.commands, &vm.ColumnCmd{ - P1: u.plan.cursorId, - P2: 0, - P3: u.plan.freeRegister, - }) - endRecordRegister := u.plan.freeRegister - u.plan.freeRegister += 1 - u.plan.commands = append(u.plan.commands, &vm.MakeRecordCmd{ - P1: startRecordRegister, - P2: endRecordRegister, - P3: u.plan.freeRegister, - }) - recordRegister := 
u.plan.freeRegister - u.plan.freeRegister += 1 - u.plan.commands = append(u.plan.commands, &vm.DeleteCmd{ - P1: u.plan.cursorId, - }) - u.plan.commands = append(u.plan.commands, &vm.InsertCmd{ - P1: u.plan.cursorId, - P2: recordRegister, - P3: rowIdRegister, - }) -} - -type filterNodeV2 struct { - child nodeV2 - parent nodeV2 - plan *planV2 -} - -func (f *filterNodeV2) produce() { - f.child.produce() -} - -func (f *filterNodeV2) consume() { - f.plan.commands = append(f.plan.commands, &vm.ColumnCmd{ - P1: f.plan.cursorId, - P2: 1, - P3: f.plan.freeRegister, - }) - notEqualCmd := &vm.NotEqualCmd{ - P1: f.plan.freeRegister, - P3: f.plan.declareConstInt(1), - } - f.plan.commands = append(f.plan.commands, notEqualCmd) - f.parent.consume() - notEqualCmd.P2 = len(f.plan.commands) + 1 -} - -type scanNodeV2 struct { - parent nodeV2 - plan *planV2 -} - -func (s *scanNodeV2) produce() { - s.consume() -} - -func (s *scanNodeV2) consume() { - rewindCmd := &vm.RewindCmd{P1: s.plan.cursorId} - s.plan.commands = append(s.plan.commands, rewindCmd) - loopBeginAddress := len(s.plan.commands) + 1 - s.parent.consume() - s.plan.commands = append(s.plan.commands, &vm.NextCmd{ - P1: s.plan.cursorId, - P2: loopBeginAddress, - }) - rewindCmd.P2 = len(s.plan.commands) + 1 -} diff --git a/planner/update_test.go b/planner/update_test.go index 320e156..ba66fe2 100644 --- a/planner/update_test.go +++ b/planner/update_test.go @@ -126,4 +126,5 @@ func TestUpdateWithWhere(t *testing.T) { func TestFoo(t *testing.T) { generateUpdate() + // generateSelect() } diff --git a/vm/vm.go b/vm/vm.go index ee7dc5d..9dd2ea9 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -45,6 +45,11 @@ type Command interface { explain(addr int) []*string } +// JumpCommand is a command capable of jumping +type JumpCommand interface { + SetJumpAddress(address int) +} + type cmdRes struct { doHalt bool nextAddress int @@ -851,6 +856,10 @@ func (c *NotEqualCmd) explain(addr int) []*string { return formatExplain(addr, "NotEqual", c.P1, 
c.P2, c.P3, c.P4, c.P5, comment) } +func (c *NotEqualCmd) SetJumpAddress(address int) { + c.P2 = address +} + // IfNotCmd jumps to P2 if P1 is false otherwise fall through. type IfNotCmd cmd @@ -870,6 +879,10 @@ func (c *IfNotCmd) explain(addr int) []*string { return formatExplain(addr, "IfNot", c.P1, c.P2, c.P3, c.P4, c.P5, comment) } +func (c *IfNotCmd) SetJumpAddress(address int) { + c.P2 = address +} + // GteCmd if P1 is greater than or equal to P3 jump to P2 type GteCmd cmd @@ -909,6 +922,10 @@ func (c *GteCmd) explain(addr int) []*string { return formatExplain(addr, "Gte", c.P1, c.P2, c.P3, c.P4, c.P5, comment) } +func (c *GteCmd) SetJumpAddress(address int) { + c.P2 = address +} + // LteCmd if P1 is less than or equal to P3 jump to P2 type LteCmd cmd @@ -948,6 +965,10 @@ func (c *LteCmd) explain(addr int) []*string { return formatExplain(addr, "Lte", c.P1, c.P2, c.P3, c.P4, c.P5, comment) } +func (c *LteCmd) SetJumpAddress(address int) { + c.P2 = address +} + // VariableCmd substitutes variable number P1 into register P2. Where P1 is a // zero based index. type VariableCmd cmd From 8f711c4c344905897370112242d97fbff10617b5 Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 21 Dec 2025 11:00:39 -0700 Subject: [PATCH 03/17] update uses new code generator --- planner/cevisitor.go | 15 ++- planner/generator.go | 59 +++------- planner/node.go | 14 --- planner/plan.go | 4 +- planner/predicate_generator.go | 6 +- planner/update.go | 193 +++++++-------------------------- planner/update_test.go | 91 ++++++++++------ 7 files changed, 126 insertions(+), 256 deletions(-) diff --git a/planner/cevisitor.go b/planner/cevisitor.go index 0675450..13122a6 100644 --- a/planner/cevisitor.go +++ b/planner/cevisitor.go @@ -1,15 +1,24 @@ package planner -import "github.com/chirst/cdb/compiler" +import ( + "github.com/chirst/cdb/catalog" + "github.com/chirst/cdb/compiler" +) // catalogExprVisitor assigns catalog information to visited expressions. 
type catalogExprVisitor struct { - catalog selectCatalog + catalog cevCatalog tableName string err error } -func (c *catalogExprVisitor) Init(catalog selectCatalog, tableName string) { +type cevCatalog interface { + GetColumns(string) ([]string, error) + GetPrimaryKeyColumn(string) (string, error) + GetColumnType(tableName string, columnName string) (catalog.CdbType, error) +} + +func (c *catalogExprVisitor) Init(catalog cevCatalog, tableName string) { c.catalog = catalog c.tableName = tableName } diff --git a/planner/generator.go b/planner/generator.go index 556b5cf..46a65af 100644 --- a/planner/generator.go +++ b/planner/generator.go @@ -46,39 +46,15 @@ type planV2 struct { rootPageNumber int } -func generateUpdate() { - logicalPlan := &planV2{ +func newPlan(transactionType transactionType, rootPageNumber int) *planV2 { + return &planV2{ commands: []vm.Command{}, constInts: make(map[int]int), constStrings: make(map[string]int), freeRegister: 1, - transactionType: transactionTypeWrite, + transactionType: transactionType, cursorId: 1, - } - un := &updateNodeV2{ - plan: logicalPlan, - updateExprs: []compiler.Expr{ - &compiler.ColumnRef{ - IsPrimaryKey: false, - ColIdx: 0, - }, - }, - } - fn := &filterNodeV2{ - plan: logicalPlan, - predicate: &compiler.IntLit{Value: 277}, - } - fn.parent = un - un.child = fn - sn := &scanNodeV2{ - plan: logicalPlan, - } - sn.parent = fn - fn.child = sn - logicalPlan.root = un - logicalPlan.compile() - for i := range logicalPlan.commands { - fmt.Printf("%d %#v\n", i+1, logicalPlan.commands[i]) + rootPageNumber: rootPageNumber, } } @@ -152,10 +128,10 @@ func (p *planV2) compile() { p.commands = append(p.commands, initCmd) p.root.produce() p.commands = append(p.commands, &vm.HaltCmd{}) - initCmd.P2 = len(p.commands) + 1 + initCmd.P2 = len(p.commands) p.pushTransaction() p.pushConstants() - p.commands = append(p.commands, &vm.GotoCmd{P1: 2}) + p.commands = append(p.commands, &vm.GotoCmd{P2: 1}) } func (p *planV2) pushTransaction() { @@ 
-165,7 +141,7 @@ func (p *planV2) pushTransaction() { case transactionTypeRead: p.commands = append( p.commands, - &vm.TransactionCmd{P2: int(p.transactionType)}, + &vm.TransactionCmd{P2: 0}, ) p.commands = append( p.commands, @@ -174,7 +150,7 @@ func (p *planV2) pushTransaction() { case transactionTypeWrite: p.commands = append( p.commands, - &vm.TransactionCmd{P2: int(p.transactionType)}, + &vm.TransactionCmd{P2: 1}, ) p.commands = append( p.commands, @@ -235,16 +211,15 @@ func (u *updateNodeV2) consume() { // will be used in makeRecord. startRecordRegister := u.plan.freeRegister u.plan.freeRegister += len(u.updateExprs) - endRecordRegister := u.plan.freeRegister - for _, e := range u.updateExprs { - generateExpressionTo(u.plan, e, startRecordRegister) - startRecordRegister += 1 + recordRegisterCount := len(u.updateExprs) + for i, e := range u.updateExprs { + generateExpressionTo(u.plan, e, startRecordRegister+i) } // Make the record for inserting u.plan.commands = append(u.plan.commands, &vm.MakeRecordCmd{ P1: startRecordRegister, - P2: endRecordRegister, + P2: recordRegisterCount, P3: u.plan.freeRegister, }) recordRegister := u.plan.freeRegister @@ -273,13 +248,9 @@ func (f *filterNodeV2) produce() { } func (f *filterNodeV2) consume() { - if f.predicate == nil { - f.parent.consume() - return - } jumpCommand := generatePredicate(f.plan, f.predicate) f.parent.consume() - jumpCommand.SetJumpAddress(len(f.plan.commands) + 1) + jumpCommand.SetJumpAddress(len(f.plan.commands)) } type scanNodeV2 struct { @@ -294,13 +265,13 @@ func (s *scanNodeV2) produce() { func (s *scanNodeV2) consume() { rewindCmd := &vm.RewindCmd{P1: s.plan.cursorId} s.plan.commands = append(s.plan.commands, rewindCmd) - loopBeginAddress := len(s.plan.commands) + 1 + loopBeginAddress := len(s.plan.commands) s.parent.consume() s.plan.commands = append(s.plan.commands, &vm.NextCmd{ P1: s.plan.cursorId, P2: loopBeginAddress, }) - rewindCmd.P2 = len(s.plan.commands) + 1 + rewindCmd.P2 = 
len(s.plan.commands) } type projectNodeV2 struct { diff --git a/planner/node.go b/planner/node.go index a6d6692..5ee7468 100644 --- a/planner/node.go +++ b/planner/node.go @@ -105,17 +105,3 @@ type insertNode struct { // dimensional i.e. VALUES (v1, v2), (v3, v4) is [[v1, v2], [v3, v4]]. colValues [][]compiler.Expr } - -// updateNode represents an update operation -type updateNode struct { - // rootPage is the rootPage of the table the update is performed on. - rootPage int - // recordExprs is a list of expressions that can be evaluated to make the - // desired record. If a column is in the update set list that column will be - // some sort of expression. If a column is not in the set list it will - // simply be a columnRef. Note the ordering is important of these - // expressions because they have to make the record. - recordExprs []compiler.Expr - // predicate is the where clause or nil - predicate compiler.Expr -} diff --git a/planner/plan.go b/planner/plan.go index 279183c..2452c2d 100644 --- a/planner/plan.go +++ b/planner/plan.go @@ -153,7 +153,7 @@ func (i *insertNode) print() string { return "insert" } -func (u *updateNode) print() string { +func (u *updateNodeV2) print() string { return "update" } @@ -185,6 +185,6 @@ func (i *insertNode) children() []logicalNode { return []logicalNode{} } -func (u *updateNode) children() []logicalNode { +func (u *updateNodeV2) children() []logicalNode { return []logicalNode{} } diff --git a/planner/predicate_generator.go b/planner/predicate_generator.go index eb30e70..0ef6d47 100644 --- a/planner/predicate_generator.go +++ b/planner/predicate_generator.go @@ -20,8 +20,6 @@ func generatePredicate(plan *planV2, expression compiler.Expr) vm.JumpCommand { // expression. type predicateGenerator struct { plan *planV2 - // openRegister is the next available register - openRegister int // commandOffset is used to calculate the amount of commands already in the // plan. 
commandOffset int @@ -203,7 +201,7 @@ func (p *predicateGenerator) valueRegisterFor(ce *compiler.ColumnRef) int { } func (p *predicateGenerator) getNextRegister() int { - r := p.openRegister - p.openRegister += 1 + r := p.plan.freeRegister + p.plan.freeRegister += 1 return r } diff --git a/planner/update.go b/planner/update.go index 7aa126e..d2a38a8 100644 --- a/planner/update.go +++ b/planner/update.go @@ -4,6 +4,7 @@ import ( "errors" "slices" + "github.com/chirst/cdb/catalog" "github.com/chirst/cdb/compiler" "github.com/chirst/cdb/vm" ) @@ -14,6 +15,7 @@ type updateCatalog interface { GetRootPageNumber(string) (int, error) GetColumns(string) ([]string, error) GetPrimaryKeyColumn(string) (string, error) + GetColumnType(tableName string, columnName string) (catalog.CdbType, error) } // updatePanner houses the query planner and execution planner for a update @@ -27,12 +29,12 @@ type updatePlanner struct { type updateQueryPlanner struct { catalog updateCatalog stmt *compiler.UpdateStmt - queryPlan *updateNode + queryPlan *updateNodeV2 } // updateExecutionPlanner generates a byte code routine for the given queryPlan. 
type updateExecutionPlanner struct { - queryPlan *updateNode + queryPlan *updateNodeV2 executionPlan *vm.ExecutionPlan } @@ -68,11 +70,13 @@ func (p *updateQueryPlanner) getQueryPlan() (*QueryPlan, error) { if err != nil { return nil, errTableNotExist } - updateNode := &updateNode{ - rootPage: rootPage, - recordExprs: []compiler.Expr{}, + logicalPlan := newPlan(transactionTypeWrite, rootPage) + updateNode := &updateNodeV2{ + plan: logicalPlan, + updateExprs: []compiler.Expr{}, } p.queryPlan = updateNode + logicalPlan.root = updateNode if err := p.errIfPrimaryKeySet(); err != nil { return nil, err @@ -86,12 +90,24 @@ func (p *updateQueryPlanner) getQueryPlan() (*QueryPlan, error) { return nil, err } - if err := p.errIfSetExprNotSupported(); err != nil { - return nil, err + scanNode := &scanNodeV2{ + plan: logicalPlan, } - - if err := p.includeUpdate(); err != nil { - return nil, err + if p.stmt.Predicate != nil { + cev := &catalogExprVisitor{} + cev.Init(p.catalog, p.stmt.TableName) + p.stmt.Predicate.BreadthWalk(cev) + filterNode := &filterNodeV2{ + plan: logicalPlan, + predicate: p.stmt.Predicate, + parent: updateNode, + child: scanNode, + } + updateNode.child = filterNode + scanNode.parent = filterNode + } else { + scanNode.parent = updateNode + updateNode.child = scanNode } return &QueryPlan{ @@ -141,14 +157,17 @@ func (p *updateQueryPlanner) setQueryPlanRecordExpressions() error { } idx := 0 for _, schemaColumn := range schemaColumns { + if schemaColumn == pkColName { + continue + } if setListExpression, ok := p.stmt.SetList[schemaColumn]; ok { - p.queryPlan.recordExprs = append( - p.queryPlan.recordExprs, + p.queryPlan.updateExprs = append( + p.queryPlan.updateExprs, setListExpression, ) } else { - p.queryPlan.recordExprs = append( - p.queryPlan.recordExprs, + p.queryPlan.updateExprs = append( + p.queryPlan.updateExprs, &compiler.ColumnRef{ Table: p.stmt.TableName, Column: schemaColumn, @@ -161,54 +180,10 @@ func (p *updateQueryPlanner) 
setQueryPlanRecordExpressions() error { idx += 1 } } - return nil -} - -// errIfSetExprNotSupported is temporary until more expressions can be supported -// in the execution plan. -func (p *updateQueryPlanner) errIfSetExprNotSupported() error { - for _, e := range p.queryPlan.recordExprs { - switch e.(type) { - case *compiler.IntLit: - continue - case *compiler.StringLit: - continue - case *compiler.ColumnRef: - continue - default: - return errors.New("set list expression not supported") - } - } - return nil -} - -func (p *updateQueryPlanner) includeUpdate() error { - if p.stmt.Predicate == nil { - return nil - } - p.queryPlan.predicate = p.stmt.Predicate - t, ok := p.queryPlan.predicate.(*compiler.BinaryExpr) - supportErr := errors.New("only pk update supported in where clause") - if !ok { - return supportErr - } - l, ok := t.Left.(*compiler.ColumnRef) - if !ok { - return supportErr - } - pkColName, err := p.catalog.GetPrimaryKeyColumn(p.stmt.TableName) - if err != nil { - return err - } - if l.Column != pkColName { - return supportErr - } - _, ok = t.Right.(*compiler.IntLit) - if !ok { - return supportErr - } - if t.Operator != compiler.OpEq { - return supportErr + for i := range p.queryPlan.updateExprs { + cev := &catalogExprVisitor{} + cev.Init(p.catalog, p.stmt.TableName) + p.queryPlan.updateExprs[i].BreadthWalk(cev) } return nil } @@ -221,95 +196,7 @@ func (p *updatePlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { return nil, err } } - return p.executionPlanner.getExecutionPlan() -} - -// getExecutionPlan transforms a query plan to a byte code routine. 
-func (p *updateExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { - freeRegisterCounter := 1 - // Init - p.executionPlan.Append(&vm.InitCmd{P2: 1}) - p.executionPlan.Append(&vm.TransactionCmd{P2: 1}) - cursorId := 1 - p.executionPlan.Append(&vm.OpenWriteCmd{P1: cursorId, P2: p.queryPlan.rootPage}) - - // Go to start of table - rewindCmd := &vm.RewindCmd{P1: cursorId} // P2 deferred - p.executionPlan.Append(rewindCmd) - - // Loop - loopStartAddress := len(p.executionPlan.Commands) - - // If needed, include jump for if. - var notEqCmd *vm.NotEqualCmd - if p.queryPlan.predicate != nil { - p.executionPlan.Append(&vm.RowIdCmd{P1: cursorId, P2: freeRegisterCounter}) - freeRegisterCounter += 1 - // No ok checks because done in logical plan. - pe := p.queryPlan.predicate.(*compiler.BinaryExpr) - r := pe.Right.(*compiler.IntLit) - p.executionPlan.Append(&vm.IntegerCmd{P1: r.Value, P2: freeRegisterCounter}) - freeRegisterCounter += 1 - notEqCmd = &vm.NotEqualCmd{ - P1: freeRegisterCounter - 2, - P2: -1, // deferred - P3: freeRegisterCounter - 1, - } - p.executionPlan.Append(notEqCmd) - } - - // take each item in the set list and build to make a record - loopStartRegister := freeRegisterCounter - var pkRegister int - for _, expression := range p.queryPlan.recordExprs { - switch typedExpression := expression.(type) { - case *compiler.ColumnRef: - if typedExpression.IsPrimaryKey { - p.executionPlan.Append(&vm.RowIdCmd{ - P1: cursorId, - P2: freeRegisterCounter, - }) - pkRegister = freeRegisterCounter - } else { - p.executionPlan.Append(&vm.ColumnCmd{ - P1: cursorId, - P2: typedExpression.ColIdx, - P3: freeRegisterCounter, - }) - } - case *compiler.IntLit: - p.executionPlan.Append(&vm.IntegerCmd{ - P1: typedExpression.Value, - P2: freeRegisterCounter, - }) - case *compiler.StringLit: - p.executionPlan.Append(&vm.StringCmd{ - P1: freeRegisterCounter, - P4: typedExpression.Value, - }) - default: - return nil, errors.New("expression not supported") - } - 
freeRegisterCounter += 1 - } - p.executionPlan.Append(&vm.MakeRecordCmd{ - P1: loopStartRegister + 1, // plus 1 for the pk - P2: len(p.queryPlan.recordExprs) - 1, // minus 1 for the pk - P3: freeRegisterCounter, - }) - p.executionPlan.Append(&vm.DeleteCmd{P1: cursorId}) - p.executionPlan.Append(&vm.InsertCmd{ - P1: cursorId, - P2: freeRegisterCounter, - P3: pkRegister, - }) - p.executionPlan.Append(&vm.NextCmd{P1: cursorId, P2: loopStartAddress}) - if notEqCmd != nil { - notEqCmd.P2 = len(p.executionPlan.Commands) - 1 - } - - // End - p.executionPlan.Append(&vm.HaltCmd{}) - rewindCmd.P2 = len(p.executionPlan.Commands) - 1 - return p.executionPlan, nil + p.queryPlanner.queryPlan.plan.compile() + p.executionPlanner.executionPlan.Commands = p.queryPlanner.queryPlan.plan.commands + return p.executionPlanner.executionPlan, nil } diff --git a/planner/update_test.go b/planner/update_test.go index ba66fe2..73408e0 100644 --- a/planner/update_test.go +++ b/planner/update_test.go @@ -2,9 +2,11 @@ package planner import ( "errors" + "fmt" "reflect" "testing" + "github.com/chirst/cdb/catalog" "github.com/chirst/cdb/compiler" "github.com/chirst/cdb/vm" ) @@ -40,6 +42,32 @@ func (*mockUpdateCatalog) GetPrimaryKeyColumn(tableName string) (string, error) return "", errors.New("err mock catalog pk") } +func (mockUpdateCatalog) GetColumnType(tableName string, columnName string) (catalog.CdbType, error) { + return catalog.CdbType{ID: catalog.CTInt}, nil +} + +func assertCommandsMatch(t *testing.T, gotCommands, expectedCommands []vm.Command) { + didMatch := true + errOutput := "\n" + for i, c := range expectedCommands { + green := "\033[32m" + red := "\033[31m" + resetColor := "\033[0m" + color := green + if !reflect.DeepEqual(c, gotCommands[i]) { + didMatch = false + color = red + } + errOutput += fmt.Sprintf( + "%s%3d got %#v%s\n want %#v\n\n", + color, i, gotCommands[i], resetColor, c, + ) + } + if !didMatch { + t.Error(errOutput) + } +} + func TestUpdate(t *testing.T) { ast := 
&compiler.UpdateStmt{ StmtBase: &compiler.StmtBase{}, @@ -51,29 +79,27 @@ func TestUpdate(t *testing.T) { }, } expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 11}, + &vm.InitCmd{P2: 10}, + &vm.RewindCmd{P1: 1, P2: 9}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 2}, - &vm.IntegerCmd{P1: 1, P2: 3}, - &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 4}, + &vm.CopyCmd{P1: 4, P2: 3}, + &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 5}, &vm.DeleteCmd{P1: 1}, - &vm.InsertCmd{P1: 1, P2: 4, P3: 1}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, + &vm.NextCmd{P1: 1, P2: 2}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.IntegerCmd{P1: 1, P2: 4}, + &vm.GotoCmd{P2: 1}, } mockCatalog := &mockUpdateCatalog{} plan, err := NewUpdate(mockCatalog, ast).ExecutionPlan() if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } - } + assertCommandsMatch(t, plan.Commands, expectedCommands) } func TestUpdateWithWhere(t *testing.T) { @@ -87,7 +113,8 @@ func TestUpdateWithWhere(t *testing.T) { }, Predicate: &compiler.BinaryExpr{ Left: &compiler.ColumnRef{ - Column: "id", + Column: "id", + IsPrimaryKey: true, }, Operator: compiler.OpEq, Right: &compiler.IntLit{ @@ -96,35 +123,27 @@ func TestUpdateWithWhere(t *testing.T) { }, } expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 14}, + &vm.InitCmd{P2: 12}, + &vm.RewindCmd{P1: 1, P2: 11}, &vm.RowIdCmd{P1: 1, P2: 1}, - &vm.IntegerCmd{P1: 1, P2: 2}, - &vm.NotEqualCmd{P1: 1, P2: 13, P3: 2}, - &vm.RowIdCmd{P1: 1, P2: 3}, - &vm.ColumnCmd{P1: 1, P2: 0, P3: 4}, - &vm.IntegerCmd{P1: 1, P2: 5}, - &vm.MakeRecordCmd{P1: 4, P2: 2, P3: 6}, + &vm.NotEqualCmd{P1: 1, P2: 10, P3: 
2}, + &vm.RowIdCmd{P1: 1, P2: 4}, + &vm.ColumnCmd{P1: 1, P2: 0, P3: 5}, + &vm.CopyCmd{P1: 2, P2: 6}, + &vm.MakeRecordCmd{P1: 5, P2: 2, P3: 7}, &vm.DeleteCmd{P1: 1}, - &vm.InsertCmd{P1: 1, P2: 6, P3: 3}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.InsertCmd{P1: 1, P2: 7, P3: 4}, + &vm.NextCmd{P1: 1, P2: 2}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.IntegerCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, } mockCatalog := &mockUpdateCatalog{} plan, err := NewUpdate(mockCatalog, ast).ExecutionPlan() if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } - } -} - -func TestFoo(t *testing.T) { - generateUpdate() - // generateSelect() + assertCommandsMatch(t, plan.Commands, expectedCommands) } From 8ba40ac90eade12e567a1f66c2d8808b0a8559bc Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 21 Dec 2025 22:18:07 -0700 Subject: [PATCH 04/17] convert select to new generator except for count --- planner/generator.go | 83 ++++- planner/plan.go | 8 + planner/select.go | 763 +++++------------------------------------ planner/select_test.go | 142 ++++---- 4 files changed, 226 insertions(+), 770 deletions(-) diff --git a/planner/generator.go b/planner/generator.go index 46a65af..b977703 100644 --- a/planner/generator.go +++ b/planner/generator.go @@ -2,6 +2,7 @@ package planner import ( "fmt" + "slices" "github.com/chirst/cdb/compiler" "github.com/chirst/cdb/vm" @@ -162,15 +163,50 @@ func (p *planV2) pushTransaction() { } func (p *planV2) pushConstants() { - for v := range p.constInts { - p.commands = append(p.commands, &vm.IntegerCmd{P1: v, P2: p.constInts[v]}) + // these constants are pushed ordered since maps are unordered making it + // difficult to assert that a sequence of instructions appears. 
+ p.pushConstantInts() + p.pushConstantStrings() + p.pushConstantVars() +} + +func (p *planV2) pushConstantInts() { + temp := []*vm.IntegerCmd{} + for k := range p.constInts { + temp = append(temp, &vm.IntegerCmd{P1: k, P2: p.constInts[k]}) + } + slices.SortFunc(temp, func(a, b *vm.IntegerCmd) int { + return a.P2 - b.P2 + }) + for i := range temp { + p.commands = append(p.commands, temp[i]) } +} + +func (p *planV2) pushConstantStrings() { + temp := []*vm.StringCmd{} for v := range p.constStrings { p.commands = append(p.commands, &vm.StringCmd{P1: p.constStrings[v], P4: v}) } + slices.SortFunc(temp, func(a, b *vm.StringCmd) int { + return a.P2 - b.P2 + }) + for i := range temp { + p.commands = append(p.commands, temp[i]) + } +} + +func (p *planV2) pushConstantVars() { + temp := []*vm.VariableCmd{} for v := range p.constVars { p.commands = append(p.commands, &vm.VariableCmd{P1: p.constVars[v], P2: v}) } + slices.SortFunc(temp, func(a, b *vm.VariableCmd) int { + return a.P2 - b.P2 + }) + for i := range temp { + p.commands = append(p.commands, temp[i]) + } } type nodeV2 interface { @@ -274,9 +310,16 @@ func (s *scanNodeV2) consume() { rewindCmd.P2 = len(s.plan.commands) } +type projectionV2 struct { + expr compiler.Expr + // alias is the alias of the projection or no alias for the zero value. 
+ alias string +} + type projectNodeV2 struct { - child nodeV2 - plan *planV2 + child nodeV2 + plan *planV2 + projections []projectionV2 } func (p *projectNodeV2) produce() { @@ -285,20 +328,14 @@ func (p *projectNodeV2) produce() { func (p *projectNodeV2) consume() { startRegister := p.plan.freeRegister - p.plan.commands = append(p.plan.commands, &vm.RowIdCmd{ - P1: p.plan.cursorId, - P2: p.plan.freeRegister, - }) - p.plan.freeRegister += 1 - p.plan.commands = append(p.plan.commands, &vm.ColumnCmd{ - P1: p.plan.cursorId, - P2: 0, - P3: p.plan.freeRegister, - }) - p.plan.freeRegister += 1 + reservedRegisters := len(p.projections) + p.plan.freeRegister += reservedRegisters + for i, projection := range p.projections { + generateExpressionTo(p.plan, projection.expr, startRegister+i) + } p.plan.commands = append(p.plan.commands, &vm.ResultRowCmd{ P1: startRegister, - P2: p.plan.freeRegister - startRegister, + P2: reservedRegisters, }) } @@ -314,3 +351,17 @@ func (c *constantNodeV2) produce() { func (c *constantNodeV2) consume() { c.parent.consume() } + +type countNodeV2 struct { + plan *planV2 +} + +func (c *countNodeV2) produce() { + c.consume() +} + +func (c *countNodeV2) consume() { + c.plan.commands = append(c.plan.commands, &vm.OpenReadCmd{P1: 1, P2: 2}) + c.plan.commands = append(c.plan.commands, &vm.CountCmd{P1: 1, P2: 1}) + c.plan.commands = append(c.plan.commands, &vm.ResultRowCmd{P1: 1, P2: 1}) +} diff --git a/planner/plan.go b/planner/plan.go index 2452c2d..34e2c07 100644 --- a/planner/plan.go +++ b/planner/plan.go @@ -113,6 +113,14 @@ func (p *projectNode) print() string { return "project" + list } +func (p *projectNodeV2) print() string { + return "project" +} + +func (p *projectNodeV2) children() []logicalNode { + return []logicalNode{} +} + func (p *projection) print() string { if p.isCount { return "count(*)" diff --git a/planner/select.go b/planner/select.go index 0fbf03c..f31b0dc 100644 --- a/planner/select.go +++ b/planner/select.go @@ -42,7 +42,7 @@ 
type selectQueryPlanner struct { stmt *compiler.SelectStmt // queryPlan contains the logical plan being built. The root node must be a // projection. - queryPlan *projectNode + queryPlan *projectNodeV2 } // selectExecutionPlanner converts logical nodes in a query plan tree to @@ -50,7 +50,7 @@ type selectQueryPlanner struct { type selectExecutionPlanner struct { // queryPlan contains the logical plan. This node is populated by calling // the QueryPlan method. - queryPlan *projectNode + queryPlan *projectNodeV2 // executionPlan contains the execution plan for the vm. This is built by // calling ExecutionPlan. executionPlan *vm.ExecutionPlan @@ -107,97 +107,76 @@ func (p *selectQueryPlanner) getQueryPlan() (*QueryPlan, error) { return nil, err } - // Constant query has no "from". - if p.stmt.From == nil || p.stmt.From.TableName == "" { - constExprs := []compiler.Expr{} - for i := range p.stmt.ResultColumns { - constExprs = append(constExprs, p.stmt.ResultColumns[i].Expression) - } - child := &constantNode{ - resultColumns: constExprs, - predicate: p.stmt.Where, - } - projections, err := p.getProjections() + var tableName string + var rootPageNumber int + if p.stmt.From != nil { + tableName = p.stmt.From.TableName + } + if tableName != "" { + rootPageNumber, err = p.catalog.GetRootPageNumber(tableName) if err != nil { - return nil, err - } - p.queryPlan = &projectNode{ - projections: projections, - child: child, + return nil, errTableNotExist } - return newQueryPlan(p.queryPlan, p.stmt.ExplainQueryPlan), nil } - tableName := p.stmt.From.TableName - rootPageNumber, err := p.catalog.GetRootPageNumber(tableName) - if err != nil { - return nil, errTableNotExist + projections, err := p.getProjections() + for i := range projections { + cev := &catalogExprVisitor{} + cev.Init(p.catalog, tableName) + projections[i].expr.BreadthWalk(cev) } - - // Count node is specially supported for now. 
- qp, err := p.getCountNode(tableName, rootPageNumber) if err != nil { return nil, err } - if qp != nil { - return qp, nil + tt := transactionTypeRead + if tableName == "" { + tt = transactionTypeNone + } + plan := newPlan(tt, rootPageNumber) + projectNode := &projectNodeV2{ + plan: plan, + projections: projections, } - if p.stmt.Where != nil { cev := &catalogExprVisitor{} - cev.Init(p.catalog, p.stmt.From.TableName) - if cev.err != nil { - return nil, err - } + cev.Init(p.catalog, tableName) p.stmt.Where.BreadthWalk(cev) - } - - // At this point a constant and count should be ruled out. The planner isn't - // looking at using indexes yet so we are safe to focus on scanNodes. - child := &scanNode{ - tableName: tableName, - rootPage: rootPageNumber, - scanColumns: []scanColumn{}, - scanPredicate: p.stmt.Where, - } - for _, resultColumn := range p.stmt.ResultColumns { - if resultColumn.All { - cols, err := p.getScanColumns() - if err != nil { - return nil, err + filterNode := &filterNodeV2{ + parent: projectNode, + plan: plan, + predicate: p.stmt.Where, + } + projectNode.child = filterNode + if tableName == "" { + constNode := &constantNodeV2{ + plan: plan, } - child.scanColumns = append(child.scanColumns, cols...) - } else if resultColumn.AllTable != "" { - if tableName != resultColumn.AllTable { - return nil, fmt.Errorf("invalid expression %s.*", resultColumn.AllTable) + filterNode.child = constNode + constNode.parent = filterNode + } else { + scanNode := &scanNodeV2{ + plan: plan, } - cols, err := p.getScanColumns() - if err != nil { - return nil, err + filterNode.child = scanNode + scanNode.parent = filterNode + } + } else { + if tableName == "" { + constNode := &constantNodeV2{ + plan: plan, } - child.scanColumns = append(child.scanColumns, cols...) 
- } else if resultColumn.Expression != nil { - child.scanColumns = append(child.scanColumns, resultColumn.Expression) + projectNode.child = constNode + constNode.parent = projectNode } else { - return nil, fmt.Errorf("unhandled result column %#v", resultColumn) - } - for i := range child.scanColumns { - cev := &catalogExprVisitor{} - if cev.err != nil { - return nil, err + scanNode := &scanNodeV2{ + plan: plan, } - cev.Init(p.catalog, child.tableName) - child.scanColumns[i].BreadthWalk(cev) + projectNode.child = scanNode + scanNode.parent = projectNode } } - projections, err := p.getProjections() - if err != nil { - return nil, err - } - p.queryPlan = &projectNode{ - projections: projections, - child: child, - } + p.queryPlan = projectNode + plan.root = p.queryPlan return newQueryPlan(p.queryPlan, p.stmt.ExplainQueryPlan), nil } @@ -280,10 +259,7 @@ func foldExpr(e compiler.Expr) (compiler.Expr, error) { } // getCountNode supports the count function under special circumstances. -func (p *selectQueryPlanner) getCountNode(tableName string, rootPageNumber int) (*QueryPlan, error) { - if len(p.stmt.ResultColumns) == 0 { - return nil, nil - } +func (p *selectQueryPlanner) getCountNode(tableName string, rootPageNumber int) (*countNodeV2, error) { switch e := p.stmt.ResultColumns[0].Expression.(type) { case *compiler.FunctionExpr: if len(p.stmt.ResultColumns) != 1 { @@ -292,19 +268,10 @@ func (p *selectQueryPlanner) getCountNode(tableName string, rootPageNumber int) if e.FnType != compiler.FnCount { return nil, fmt.Errorf("only %s function is supported", e.FnType) } - child := &countNode{ - tableName: tableName, - rootPage: rootPageNumber, - } - projections, err := p.getProjections() - if err != nil { - return nil, err - } - p.queryPlan = &projectNode{ - projections: projections, - child: child, + cn := &countNodeV2{ + plan: p.queryPlan.plan, } - return newQueryPlan(p.queryPlan, p.stmt.ExplainQueryPlan), nil + return cn, nil } return nil, nil } @@ -339,8 +306,8 @@ func (p 
*selectQueryPlanner) getScanColumns() ([]scanColumn, error) { return scanColumns, nil } -func (p *selectQueryPlanner) getProjections() ([]projection, error) { - var projections []projection +func (p *selectQueryPlanner) getProjections() ([]projectionV2, error) { + var projections []projectionV2 for _, resultColumn := range p.stmt.ResultColumns { if resultColumn.All { cols, err := p.catalog.GetColumns(p.stmt.From.TableName) @@ -348,8 +315,11 @@ func (p *selectQueryPlanner) getProjections() ([]projection, error) { return nil, err } for _, c := range cols { - projections = append(projections, projection{ - colName: c, + projections = append(projections, projectionV2{ + expr: &compiler.ColumnRef{ + Table: p.stmt.From.TableName, + Column: c, + }, }) } } else if resultColumn.AllTable != "" { @@ -358,31 +328,18 @@ func (p *selectQueryPlanner) getProjections() ([]projection, error) { return nil, err } for _, c := range cols { - projections = append(projections, projection{ - colName: c, + projections = append(projections, projectionV2{ + expr: &compiler.ColumnRef{ + Table: p.stmt.From.TableName, + Column: c, + }, }) } } else if resultColumn.Expression != nil { - switch e := resultColumn.Expression.(type) { - case *compiler.ColumnRef: - colName := e.Column - if resultColumn.Alias != "" { - colName = resultColumn.Alias - } - projections = append(projections, projection{ - colName: colName, - }) - case *compiler.FunctionExpr: - projections = append(projections, projection{ - isCount: true, - colName: resultColumn.Alias, - }) - default: - projections = append(projections, projection{ - isCount: false, - colName: resultColumn.Alias, - }) - } + projections = append(projections, projectionV2{ + expr: resultColumn.Expression, + alias: resultColumn.Alias, + }) } } return projections, nil @@ -403,43 +360,15 @@ func (sp *selectPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { func (p *selectExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { p.setResultHeader() - 
p.executionPlan.Append(&vm.InitCmd{P2: 1}) - switch c := p.queryPlan.child.(type) { - case *scanNode: - err := p.setResultTypes(c.scanColumns) - if err != nil { - return nil, err - } - p.executionPlan.Append(&vm.TransactionCmd{P2: 0}) - if err := p.buildScan(c); err != nil { - return nil, err - } - case *countNode: - err := p.setResultTypes([]compiler.Expr{&compiler.IntLit{}}) - if err != nil { - return nil, err - } - p.executionPlan.Append(&vm.TransactionCmd{P2: 0}) - p.buildOptimizedCountScan(c) - case *constantNode: - err := p.setResultTypes(c.resultColumns) - if err != nil { - return nil, err - } - if err := p.buildConstantNode(c); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("unhandled node %#v", c) - } - p.executionPlan.Append(&vm.HaltCmd{}) + p.queryPlan.plan.compile() + p.executionPlan.Commands = p.queryPlan.plan.commands return p.executionPlan, nil } func (p *selectExecutionPlanner) setResultHeader() { resultHeader := []string{} - for _, p := range p.queryPlan.projections { - resultHeader = append(resultHeader, p.colName) + for range p.queryPlan.projections { + resultHeader = append(resultHeader, "unknown") } p.executionPlan.ResultHeader = resultHeader } @@ -489,541 +418,3 @@ func getExprType(expr compiler.Expr) (catalog.CdbType, error) { return catalog.CdbType{ID: catalog.CTUnknown}, fmt.Errorf("no handler for expr type %v", expr) } } - -func (p *selectExecutionPlanner) buildScan(n *scanNode) error { - // Build a map of constant values to registers by walking result columns and - // the scan predicate. - const beginningRegister = 1 - crv := &constantRegisterVisitor{} - crv.Init(beginningRegister) - for _, c := range n.scanColumns { - c.BreadthWalk(crv) - } - if n.scanPredicate != nil { - n.scanPredicate.BreadthWalk(crv) - } - rcs := crv.GetRegisterCommands() - for _, rc := range rcs { - p.executionPlan.Append(rc) - } - - // Open an available cursor. 
Can just be 1 for now since no queries are - // supported at the moment that requires more than one cursor. - const cursorId = 1 - p.executionPlan.Append(&vm.OpenReadCmd{P1: cursorId, P2: n.rootPage}) - - // Rewind moves the aforementioned cursor to the "start" of the table. - rwc := &vm.RewindCmd{P1: cursorId} - p.executionPlan.Append(rwc) - - // Mark beginning of scan for rewind - scanBeginningCommand := len(p.executionPlan.Commands) - - // Reserve registers for the column result. Claim registers after as needed. - startScanRegister := crv.nextOpenRegister - endScanRegisterOffset := len(n.scanColumns) - - // This is the inside of the scan meaning how each result column is handled - // per iteration of the scan (loop). - var pkRegister int - var openRegister int - colRegisters := make(map[int]int) - for i, c := range n.scanColumns { - exprBuilder := &resultColumnCommandBuilder{} - exprBuilder.Build( - cursorId, - len(p.executionPlan.Commands), - startScanRegister+endScanRegisterOffset, - crv.constantRegisters, - crv.variableRegisters, - crv.stringRegisters, - startScanRegister+i, - c, - ) - if exprBuilder.pkRegister != 0 { - pkRegister = exprBuilder.pkRegister - } - openRegister = exprBuilder.openRegister - for crk, crv := range exprBuilder.colRegisters { - colRegisters[crk] = crv - } - for _, tc := range exprBuilder.commands { - p.executionPlan.Append(tc) - } - } - - // TODO predicate commands should come as early as possible to save - // instructions, but for now this is easier. - // - // Walk scan predicate and build commands to calculate a conditional jump. 
- if n.scanPredicate != nil { - bpb := &booleanPredicateBuilder{} - err := bpb.Build( - cursorId, - openRegister, - len(p.executionPlan.Commands), - len(p.executionPlan.Commands), - crv.constantRegisters, - colRegisters, - crv.variableRegisters, - crv.stringRegisters, - pkRegister, - n.scanPredicate, - ) - if err != nil { - return err - } - for _, bc := range bpb.commands { - p.executionPlan.Append(bc) - } - } - - // Result row gathers the aforementioned inside of the scan and makes them - // into a single row for the query results. - p.executionPlan.Append(&vm.ResultRowCmd{P1: startScanRegister, P2: endScanRegisterOffset}) - - // Falls through or goes back to the start of the scan loop. - p.executionPlan.Append(&vm.NextCmd{P1: cursorId, P2: scanBeginningCommand}) - - // Must tell the rewind command where to go in case the table is empty. - rwc.P2 = len(p.executionPlan.Commands) - return nil -} - -// buildOptimizedCountScan is a special optimization made when a table only has -// a count aggregate and no other projections. Since the optimized scan -// aggregates the count of tuples on each page, but does not look at individual -// tuples. -func (p *selectExecutionPlanner) buildOptimizedCountScan(n *countNode) { - const cursorId = 1 - p.executionPlan.Append(&vm.OpenReadCmd{P1: cursorId, P2: n.rootPage}) - p.executionPlan.Append(&vm.CountCmd{P1: cursorId, P2: 1}) - p.executionPlan.Append(&vm.ResultRowCmd{P1: 1, P2: 1}) -} - -// buildConstantNode is a single row operation produced by a "select" without a -// "from". -func (p *selectExecutionPlanner) buildConstantNode(n *constantNode) error { - // Build registers with constants. These are likely extra instructions, but - // okay since it allows this to follow the same pattern a scan does. 
- const beginningRegister = 1 - crv := &constantRegisterVisitor{} - crv.Init(beginningRegister) - for _, c := range n.resultColumns { - c.BreadthWalk(crv) - } - if n.predicate != nil { - n.predicate.BreadthWalk(crv) - } - rcs := crv.GetRegisterCommands() - for _, rc := range rcs { - p.executionPlan.Append(rc) - } - - // Like a scan, but for a single row. - reservedRegisterStart := crv.nextOpenRegister - reservedRegisterOffset := len(n.resultColumns) - var openRegister int - for i, rc := range n.resultColumns { - exprBuilder := &resultColumnCommandBuilder{} - exprBuilder.Build( - 1, - len(p.executionPlan.Commands), - reservedRegisterStart+reservedRegisterOffset, - crv.constantRegisters, - crv.variableRegisters, - crv.stringRegisters, - reservedRegisterStart+i, - rc, - ) - for _, tc := range exprBuilder.commands { - p.executionPlan.Append(tc) - } - openRegister = exprBuilder.openRegister - } - - if n.predicate != nil { - bpb := &booleanPredicateBuilder{} - err := bpb.Build( - 0, - openRegister, - len(p.executionPlan.Commands), - len(p.executionPlan.Commands), - crv.constantRegisters, - map[int]int{}, - crv.variableRegisters, - crv.stringRegisters, - 0, - n.predicate, - ) - if err != nil { - return err - } - for _, bc := range bpb.commands { - p.executionPlan.Append(bc) - } - } - - p.executionPlan.Append(&vm.ResultRowCmd{P1: reservedRegisterStart, P2: reservedRegisterOffset}) - return nil -} - -// resultColumnCommandBuilder builds commands for the given expression. -type resultColumnCommandBuilder struct { - // cursorId is the cursor for the related table. - cursorId int - // openRegister is the next available register. - openRegister int - // outputRegister is the target register for the result of the expression. - outputRegister int - // commands are the commands to evaluate the expression. - commands []vm.Command - // commandOffset is the amount of commands prior to calling this routine. - // Useful for calculating jump instructions. 
- commandOffset int - // litRegisters is a mapping of scalar values to registers containing them. - litRegisters map[int]int - // colRegisters is a mapping of column indexes to registers containing the - // column. This is for subsequent routines to reuse the result of these - // commands. - colRegisters map[int]int - // variableRegisters is a mapping of variable indices to registers. - variableRegisters map[int]int - // stringRegisters is a mapping of strings to registers - stringRegisters map[string]int - // pkRegister is 0 value unless a register has been filled as part of Build. - // This is for subsequent routines to reuse the result of the command. - pkRegister int -} - -func (e *resultColumnCommandBuilder) Build( - cursorId int, - commandOffset int, - openRegister int, - litRegisters map[int]int, - variableRegisters map[int]int, - stringRegisters map[string]int, - outputRegister int, - root compiler.Expr, -) int { - e.cursorId = cursorId - e.commandOffset = commandOffset - e.openRegister = openRegister - e.litRegisters = litRegisters - e.colRegisters = make(map[int]int) - e.variableRegisters = variableRegisters - e.stringRegisters = stringRegisters - e.outputRegister = outputRegister - return e.build(root, 0) -} - -func (e *resultColumnCommandBuilder) build(root compiler.Expr, level int) int { - switch n := root.(type) { - case *compiler.BinaryExpr: - ol := e.build(n.Left, level+1) - or := e.build(n.Right, level+1) - r := e.getNextRegister(level) - switch n.Operator { - case compiler.OpAdd: - e.commands = append(e.commands, &vm.AddCmd{P1: ol, P2: or, P3: r}) - case compiler.OpDiv: - e.commands = append(e.commands, &vm.DivideCmd{P1: ol, P2: or, P3: r}) - case compiler.OpMul: - e.commands = append(e.commands, &vm.MultiplyCmd{P1: ol, P2: or, P3: r}) - case compiler.OpExp: - e.commands = append(e.commands, &vm.ExponentCmd{P1: ol, P2: or, P3: r}) - case compiler.OpSub: - e.commands = append(e.commands, &vm.SubtractCmd{P1: ol, P2: or, P3: r}) - case compiler.OpEq: 
- e.commands = append(e.commands, &vm.IntegerCmd{P1: 0, P2: r}) - jumpOverCount := 2 - jumpAddress := len(e.commands) + jumpOverCount + e.commandOffset - e.commands = append( - e.commands, - &vm.NotEqualCmd{P1: ol, P2: jumpAddress, P3: or}, - ) - e.commands = append(e.commands, &vm.IntegerCmd{P1: 1, P2: r}) - case compiler.OpLt: - e.commands = append(e.commands, &vm.IntegerCmd{P1: 0, P2: r}) - jumpOverCount := 2 - jumpAddress := len(e.commands) + jumpOverCount + e.commandOffset - e.commands = append( - e.commands, - &vm.GteCmd{P1: ol, P2: jumpAddress, P3: or}, - ) - e.commands = append(e.commands, &vm.IntegerCmd{P1: 1, P2: r}) - case compiler.OpGt: - e.commands = append(e.commands, &vm.IntegerCmd{P1: 0, P2: r}) - jumpOverCount := 2 - jumpAddress := len(e.commands) + jumpOverCount + e.commandOffset - e.commands = append( - e.commands, - &vm.LteCmd{P1: ol, P2: jumpAddress, P3: or}, - ) - e.commands = append(e.commands, &vm.IntegerCmd{P1: 1, P2: r}) - default: - panic("no vm command for operator") - } - return r - case *compiler.ColumnRef: - r := e.getNextRegister(level) - if n.IsPrimaryKey { - e.pkRegister = r - e.commands = append(e.commands, &vm.RowIdCmd{P1: e.cursorId, P2: r}) - } else { - e.colRegisters[n.ColIdx] = r - e.commands = append( - e.commands, - &vm.ColumnCmd{P1: e.cursorId, P2: n.ColIdx, P3: r}, - ) - } - return r - case *compiler.IntLit: - if level == 0 { - e.commands = append( - e.commands, - &vm.CopyCmd{P1: e.litRegisters[n.Value], P2: e.outputRegister}, - ) - } - return e.litRegisters[n.Value] - case *compiler.StringLit: - if level == 0 { - e.commands = append( - e.commands, - &vm.CopyCmd{P1: e.stringRegisters[n.Value], P2: e.outputRegister}, - ) - } - return e.stringRegisters[n.Value] - case *compiler.Variable: - if level == 0 { - e.commands = append( - e.commands, - &vm.CopyCmd{P1: e.variableRegisters[n.Position], P2: e.outputRegister}, - ) - } - return e.variableRegisters[n.Position] - } - panic("unhandled expression in expr command builder") -} 
- -func (e *resultColumnCommandBuilder) getNextRegister(level int) int { - if level == 0 { - return e.outputRegister - } - r := e.openRegister - e.openRegister += 1 - return r -} - -// booleanPredicateBuilder builds commands to calculate the boolean result of an -// expression. -type booleanPredicateBuilder struct { - // cursorId is the cursor for the associated table. - cursorId int - // openRegister is the next available register - openRegister int - // jumpAddress is the address the result of the boolean expression should - // conditionally jump to. - jumpAddress int - // commands is a list of commands representing the expression. - commands []vm.Command - // commandOffset is used to calculate the amount of commands already in the - // plan. - commandOffset int - // litRegisters is a mapping of scalar values to the register containing - // them. litRegisters should be guaranteed since they have a minimal cost - // due to being calculated outside of any scans/loops. - litRegisters map[int]int - // colRegisters is a mapping of table column index to register containing - // the column value. colRegisters may not be guaranteed since a projection - // may not require them, in which case colRegisters should be calculated as - // part of the predicate. - colRegisters map[int]int - // variableRegisters is a mapping of variable indices to registers. - variableRegisters map[int]int - // stringRegisters is a mapping of strings to registers - stringRegisters map[string]int - // pkRegister is unset when 0. Otherwise, pkRegister is the register - // containing the table row id. pkRegister may not be guaranteed depending - // on the projection in which case the register should be calculated as part - // of the expression evaluation. 
- pkRegister int -} - -func (p *booleanPredicateBuilder) Build( - cursorId int, - openRegister int, - jumpAddress int, - commandOffset int, - litRegisters map[int]int, - colRegisters map[int]int, - variableRegisters map[int]int, - stringRegisters map[string]int, - pkRegister int, - e compiler.Expr, -) error { - p.cursorId = cursorId - p.openRegister = openRegister - p.jumpAddress = jumpAddress - p.commandOffset = commandOffset - p.litRegisters = litRegisters - p.colRegisters = colRegisters - p.variableRegisters = variableRegisters - p.stringRegisters = stringRegisters - p.pkRegister = pkRegister - _, err := p.build(e, 0) - return err -} - -func (p *booleanPredicateBuilder) build(e compiler.Expr, level int) (int, error) { - switch ce := e.(type) { - case *compiler.BinaryExpr: - ol, err := p.build(ce.Left, level+1) - if err != nil { - return 0, err - } - or, err := p.build(ce.Right, level+1) - if err != nil { - return 0, err - } - r := p.getNextRegister() - switch ce.Operator { - case compiler.OpAdd: - p.commands = append(p.commands, &vm.AddCmd{P1: ol, P2: or, P3: r}) - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: r, P2: p.getJumpAddress()}) - } - return r, nil - case compiler.OpDiv: - p.commands = append(p.commands, &vm.DivideCmd{P1: ol, P2: or, P3: r}) - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: r, P2: p.getJumpAddress()}) - } - return r, nil - case compiler.OpMul: - p.commands = append(p.commands, &vm.MultiplyCmd{P1: ol, P2: or, P3: r}) - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: r, P2: p.getJumpAddress()}) - } - return r, nil - case compiler.OpExp: - p.commands = append(p.commands, &vm.ExponentCmd{P1: ol, P2: or, P3: r}) - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: r, P2: p.getJumpAddress()}) - } - return r, nil - case compiler.OpSub: - p.commands = append(p.commands, &vm.SubtractCmd{P1: ol, P2: or, P3: r}) - if level == 0 { - p.commands = append(p.commands, 
&vm.IfNotCmd{P1: r, P2: p.getJumpAddress()}) - } - return r, nil - case compiler.OpEq: - if level == 0 { - p.commands = append( - p.commands, - &vm.NotEqualCmd{P1: ol, P2: p.getJumpAddress(), P3: or}, - ) - return 0, nil - } - p.commands = append(p.commands, &vm.IntegerCmd{P1: 0, P2: r}) - jumpOverCount := 2 - jumpAddress := len(p.commands) + jumpOverCount + p.commandOffset - p.commands = append( - p.commands, - &vm.NotEqualCmd{P1: ol, P2: jumpAddress, P3: or}, - ) - p.commands = append(p.commands, &vm.IntegerCmd{P1: 1, P2: r}) - return r, nil - case compiler.OpLt: - if level == 0 { - p.commands = append( - p.commands, - &vm.LteCmd{P1: or, P2: p.getJumpAddress(), P3: ol}, - ) - return 0, nil - } - p.commands = append(p.commands, &vm.IntegerCmd{P1: 0, P2: r}) - jumpOverCount := 2 - jumpAddress := len(p.commands) + jumpOverCount + p.commandOffset - p.commands = append( - p.commands, - &vm.GteCmd{P1: ol, P2: jumpAddress, P3: or}, - ) - p.commands = append(p.commands, &vm.IntegerCmd{P1: 1, P2: r}) - return r, nil - case compiler.OpGt: - if level == 0 { - p.commands = append( - p.commands, - &vm.GteCmd{P1: or, P2: p.getJumpAddress(), P3: ol}, - ) - return 0, nil - } - p.commands = append(p.commands, &vm.IntegerCmd{P1: 0, P2: r}) - jumpOverCount := 2 - jumpAddress := len(p.commands) + jumpOverCount + p.commandOffset - p.commands = append( - p.commands, - &vm.LteCmd{P1: ol, P2: jumpAddress, P3: or}, - ) - p.commands = append(p.commands, &vm.IntegerCmd{P1: 1, P2: r}) - return r, nil - default: - panic("no vm command for operator") - } - case *compiler.ColumnRef: - colRefReg := p.valueRegisterFor(ce) - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: colRefReg, P2: p.getJumpAddress()}) - } - return colRefReg, nil - case *compiler.IntLit: - litReg := p.litRegisters[ce.Value] - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: litReg, P2: p.getJumpAddress()}) - } - return litReg, nil - case *compiler.StringLit: - strReg := 
p.stringRegisters[ce.Value] - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: strReg, P2: p.getJumpAddress()}) - } - return strReg, nil - case *compiler.Variable: - varReg := p.variableRegisters[ce.Position] - if level == 0 { - p.commands = append(p.commands, &vm.IfNotCmd{P1: varReg, P2: p.getJumpAddress()}) - } - return varReg, nil - } - panic("unhandled expression in predicate builder") -} - -func (p *booleanPredicateBuilder) getJumpAddress() int { - return p.jumpAddress + len(p.commands) + 2 -} - -func (p *booleanPredicateBuilder) valueRegisterFor(ce *compiler.ColumnRef) int { - if ce.IsPrimaryKey { - if p.pkRegister == 0 { - r := p.getNextRegister() - p.commands = append(p.commands, &vm.RowIdCmd{P1: p.cursorId, P2: r}) - return r - } - return p.pkRegister - } - cr := p.colRegisters[ce.ColIdx] - if cr == 0 { - r := p.getNextRegister() - p.commands = append(p.commands, &vm.ColumnCmd{P1: p.cursorId, P2: ce.ColIdx, P3: r}) - return r - } - return cr -} - -func (p *booleanPredicateBuilder) getNextRegister() int { - r := p.openRegister - p.openRegister += 1 - return r -} diff --git a/planner/select_test.go b/planner/select_test.go index 2b9a631..dc2c08e 100644 --- a/planner/select_test.go +++ b/planner/select_test.go @@ -2,7 +2,6 @@ package planner import ( "errors" - "reflect" "testing" "github.com/chirst/cdb/catalog" @@ -59,15 +58,16 @@ func TestSelectPlan(t *testing.T) { { description: "StarWithPrimaryKey", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 8}, + &vm.InitCmd{P2: 7}, + &vm.RewindCmd{P1: 1, P2: 6}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 2}, &vm.ResultRowCmd{P1: 1, P2: 2}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.NextCmd{P1: 1, P2: 2}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -88,15 +88,16 @@ func 
TestSelectPlan(t *testing.T) { { description: "StarWithoutPrimaryKey", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 8}, + &vm.InitCmd{P2: 7}, + &vm.RewindCmd{P1: 1, P2: 6}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 1}, &vm.ColumnCmd{P1: 1, P2: 1, P3: 2}, &vm.ResultRowCmd{P1: 1, P2: 2}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.NextCmd{P1: 1, P2: 2}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -124,16 +125,17 @@ func TestSelectPlan(t *testing.T) { }, }, expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 9}, + &vm.InitCmd{P2: 8}, + &vm.RewindCmd{P1: 1, P2: 7}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 1}, &vm.RowIdCmd{P1: 1, P2: 2}, &vm.ColumnCmd{P1: 1, P2: 1, P3: 3}, &vm.ResultRowCmd{P1: 1, P2: 3}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.NextCmd{P1: 1, P2: 2}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, }, mockCatalogSetup: func(m *mockSelectCatalog) *mockSelectCatalog { m.primaryKeyColumnName = "id" @@ -167,16 +169,17 @@ func TestSelectPlan(t *testing.T) { }, }, expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, + &vm.InitCmd{P2: 7}, + &vm.RewindCmd{P1: 1, P2: 6}, + &vm.RowIdCmd{P1: 1, P2: 2}, + &vm.AddCmd{P1: 2, P2: 3, P3: 1}, + &vm.ResultRowCmd{P1: 1, P2: 1}, + &vm.NextCmd{P1: 1, P2: 2}, + &vm.HaltCmd{}, &vm.TransactionCmd{P1: 0}, - &vm.IntegerCmd{P1: 10, P2: 1}, &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 9}, - &vm.RowIdCmd{P1: 1, P2: 3}, - &vm.AddCmd{P1: 3, P2: 1, P3: 2}, - &vm.ResultRowCmd{P1: 2, P2: 1}, - &vm.NextCmd{P1: 1, P2: 5}, - &vm.HaltCmd{}, + &vm.IntegerCmd{P1: 10, P2: 3}, + &vm.GotoCmd{P2: 1}, }, mockCatalogSetup: func(m *mockSelectCatalog) *mockSelectCatalog { m.primaryKeyColumnName = "id" @@ -188,15 
+191,16 @@ func TestSelectPlan(t *testing.T) { { description: "AllTable", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 8}, + &vm.InitCmd{P2: 7}, + &vm.RewindCmd{P1: 1, P2: 6}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 2}, &vm.ResultRowCmd{P1: 1, P2: 2}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.NextCmd{P1: 1, P2: 2}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -222,14 +226,15 @@ func TestSelectPlan(t *testing.T) { { description: "SpecificColumnPrimaryKeyMiddleOrdinal", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 7}, + &vm.InitCmd{P2: 6}, + &vm.RewindCmd{P1: 1, P2: 5}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ResultRowCmd{P1: 1, P2: 1}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.NextCmd{P1: 1, P2: 2}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -258,15 +263,16 @@ func TestSelectPlan(t *testing.T) { { description: "SpecificColumns", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 8}, + &vm.InitCmd{P2: 7}, + &vm.RewindCmd{P1: 1, P2: 6}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ColumnCmd{P1: 1, P2: 1, P3: 2}, &vm.ResultRowCmd{P1: 1, P2: 2}, - &vm.NextCmd{P1: 1, P2: 4}, + &vm.NextCmd{P1: 1, P2: 2}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -322,19 +328,20 @@ func TestSelectPlan(t *testing.T) { { description: "Operators", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.IntegerCmd{P1: 1, P2: 1}, - &vm.IntegerCmd{P1: 18, P2: 2}, - 
&vm.IntegerCmd{P1: 387420489, P2: 3}, - &vm.IntegerCmd{P1: 81, P2: 4}, - &vm.IntegerCmd{P1: 0, P2: 5}, - &vm.CopyCmd{P1: 1, P2: 6}, - &vm.CopyCmd{P1: 2, P2: 7}, - &vm.CopyCmd{P1: 3, P2: 8}, - &vm.CopyCmd{P1: 4, P2: 9}, - &vm.CopyCmd{P1: 5, P2: 10}, - &vm.ResultRowCmd{P1: 6, P2: 5}, + &vm.InitCmd{P2: 8}, + &vm.CopyCmd{P1: 6, P2: 1}, + &vm.CopyCmd{P1: 7, P2: 2}, + &vm.CopyCmd{P1: 8, P2: 3}, + &vm.CopyCmd{P1: 9, P2: 4}, + &vm.CopyCmd{P1: 10, P2: 5}, + &vm.ResultRowCmd{P1: 1, P2: 5}, &vm.HaltCmd{}, + &vm.IntegerCmd{P1: 1, P2: 6}, + &vm.IntegerCmd{P1: 18, P2: 7}, + &vm.IntegerCmd{P1: 387420489, P2: 8}, + &vm.IntegerCmd{P1: 81, P2: 9}, + &vm.IntegerCmd{P1: 0, P2: 10}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -381,16 +388,18 @@ func TestSelectPlan(t *testing.T) { { description: "with where clause", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, + &vm.InitCmd{P2: 8}, + &vm.RewindCmd{P1: 1, P2: 7}, + &vm.RowIdCmd{P1: 1, P2: 1}, + &vm.NotEqualCmd{P1: 1, P2: 6, P3: 2}, + &vm.RowIdCmd{P1: 1, P2: 4}, + &vm.ResultRowCmd{P1: 4, P2: 1}, + &vm.NextCmd{P1: 1, P2: 2}, + &vm.HaltCmd{}, &vm.TransactionCmd{P1: 0}, - &vm.IntegerCmd{P1: 1, P2: 1}, &vm.OpenReadCmd{P1: 1, P2: 2}, - &vm.RewindCmd{P1: 1, P2: 9}, - &vm.RowIdCmd{P1: 1, P2: 2}, - &vm.NotEqualCmd{P1: 2, P2: 8, P3: 1}, - &vm.ResultRowCmd{P1: 2, P2: 1}, - &vm.NextCmd{P1: 1, P2: 5}, - &vm.HaltCmd{}, + &vm.IntegerCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -418,11 +427,12 @@ func TestSelectPlan(t *testing.T) { { description: "ConstantString", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.StringCmd{P1: 1, P4: "foo"}, - &vm.CopyCmd{P1: 1, P2: 2}, - &vm.ResultRowCmd{P1: 2, P2: 1}, + &vm.InitCmd{P2: 4}, + &vm.CopyCmd{P1: 2, P2: 1}, + &vm.ResultRowCmd{P1: 1, P2: 1}, &vm.HaltCmd{}, + &vm.StringCmd{P1: 2, P4: "foo"}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, @@ -451,11 
+461,7 @@ func TestSelectPlan(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range c.expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } - } + assertCommandsMatch(t, plan.Commands, c.expectedCommands) }) } } From 7a5edc84ab221af23ea51f730ade6637277a31f4 Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 21 Dec 2025 22:56:41 -0700 Subject: [PATCH 05/17] support count --- planner/generator.go | 50 ++++++-------------- planner/plan.go | 8 ++++ planner/select.go | 103 ++++++++++++++++++----------------------- planner/select_test.go | 7 +-- 4 files changed, 71 insertions(+), 97 deletions(-) diff --git a/planner/generator.go b/planner/generator.go index b977703..eb63a85 100644 --- a/planner/generator.go +++ b/planner/generator.go @@ -1,7 +1,6 @@ package planner import ( - "fmt" "slices" "github.com/chirst/cdb/compiler" @@ -52,6 +51,7 @@ func newPlan(transactionType transactionType, rootPageNumber int) *planV2 { commands: []vm.Command{}, constInts: make(map[int]int), constStrings: make(map[string]int), + constVars: make(map[int]int), freeRegister: 1, transactionType: transactionType, cursorId: 1, @@ -59,35 +59,6 @@ func newPlan(transactionType transactionType, rootPageNumber int) *planV2 { } } -func generateSelect() { - logicalPlan := &planV2{ - commands: []vm.Command{}, - constInts: make(map[int]int), - constStrings: make(map[string]int), - freeRegister: 1, - transactionType: transactionTypeRead, - cursorId: 1, - } - pn := &projectNodeV2{ - plan: logicalPlan, - } - fn := &filterNodeV2{ - plan: logicalPlan, - } - fn.parent = pn - pn.child = fn - sn := &scanNodeV2{ - plan: logicalPlan, - } - sn.parent = fn - fn.child = sn - logicalPlan.root = pn - logicalPlan.compile() - for i := range logicalPlan.commands { - fmt.Printf("%d %#v\n", i+1, logicalPlan.commands[i]) - } -} - // declareConstInt gets or sets a register with the const value and returns the // register. 
It is guaranteed the value will be in the register for the duration // of the plan. @@ -199,7 +170,7 @@ func (p *planV2) pushConstantStrings() { func (p *planV2) pushConstantVars() { temp := []*vm.VariableCmd{} for v := range p.constVars { - p.commands = append(p.commands, &vm.VariableCmd{P1: p.constVars[v], P2: v}) + p.commands = append(p.commands, &vm.VariableCmd{P1: v, P2: p.constVars[v]}) } slices.SortFunc(temp, func(a, b *vm.VariableCmd) int { return a.P2 - b.P2 @@ -353,7 +324,8 @@ func (c *constantNodeV2) consume() { } type countNodeV2 struct { - plan *planV2 + plan *planV2 + projection projectionV2 } func (c *countNodeV2) produce() { @@ -361,7 +333,15 @@ func (c *countNodeV2) produce() { } func (c *countNodeV2) consume() { - c.plan.commands = append(c.plan.commands, &vm.OpenReadCmd{P1: 1, P2: 2}) - c.plan.commands = append(c.plan.commands, &vm.CountCmd{P1: 1, P2: 1}) - c.plan.commands = append(c.plan.commands, &vm.ResultRowCmd{P1: 1, P2: 1}) + c.plan.commands = append(c.plan.commands, &vm.CountCmd{ + P1: c.plan.cursorId, + P2: c.plan.freeRegister, + }) + countRegister := c.plan.freeRegister + countResults := 1 + c.plan.freeRegister += 1 + c.plan.commands = append(c.plan.commands, &vm.ResultRowCmd{ + P1: countRegister, + P2: countResults, + }) } diff --git a/planner/plan.go b/planner/plan.go index 34e2c07..35877e0 100644 --- a/planner/plan.go +++ b/planner/plan.go @@ -196,3 +196,11 @@ func (i *insertNode) children() []logicalNode { func (u *updateNodeV2) children() []logicalNode { return []logicalNode{} } + +func (p *countNodeV2) print() string { + return "count" +} + +func (p *countNodeV2) children() []logicalNode { + return []logicalNode{} +} diff --git a/planner/select.go b/planner/select.go index f31b0dc..e5295fa 100644 --- a/planner/select.go +++ b/planner/select.go @@ -42,7 +42,7 @@ type selectQueryPlanner struct { stmt *compiler.SelectStmt // queryPlan contains the logical plan being built. The root node must be a // projection. 
- queryPlan *projectNodeV2 + queryPlan *planV2 } // selectExecutionPlanner converts logical nodes in a query plan tree to @@ -50,7 +50,7 @@ type selectQueryPlanner struct { type selectExecutionPlanner struct { // queryPlan contains the logical plan. This node is populated by calling // the QueryPlan method. - queryPlan *projectNodeV2 + queryPlan *planV2 // executionPlan contains the execution plan for the vm. This is built by // calling ExecutionPlan. executionPlan *vm.ExecutionPlan @@ -120,14 +120,36 @@ func (p *selectQueryPlanner) getQueryPlan() (*QueryPlan, error) { } projections, err := p.getProjections() + if err != nil { + return nil, err + } for i := range projections { cev := &catalogExprVisitor{} cev.Init(p.catalog, tableName) projections[i].expr.BreadthWalk(cev) } - if err != nil { - return nil, err + + hasFunc := false + for i := range projections { + _, ok := projections[i].expr.(*compiler.FunctionExpr) + if ok { + hasFunc = true + } + } + if hasFunc { + if len(projections) != 1 { + return nil, errors.New("only one projection allowed for COUNT") + } + if tableName == "" { + return nil, errors.New("must have from for COUNT") + } + plan := newPlan(transactionTypeRead, rootPageNumber) + cn := &countNodeV2{plan: plan, projection: projections[0]} + p.queryPlan = plan + plan.root = cn + return newQueryPlan(cn, p.stmt.ExplainQueryPlan), nil } + tt := transactionTypeRead if tableName == "" { tt = transactionTypeNone @@ -175,9 +197,9 @@ func (p *selectQueryPlanner) getQueryPlan() (*QueryPlan, error) { scanNode.parent = projectNode } } - p.queryPlan = projectNode - plan.root = p.queryPlan - return newQueryPlan(p.queryPlan, p.stmt.ExplainQueryPlan), nil + p.queryPlan = plan + plan.root = projectNode + return newQueryPlan(projectNode, p.stmt.ExplainQueryPlan), nil } func (p *selectQueryPlanner) optimizeResultColumns() error { @@ -258,54 +280,6 @@ func foldExpr(e compiler.Expr) (compiler.Expr, error) { } } -// getCountNode supports the count function under special 
circumstances. -func (p *selectQueryPlanner) getCountNode(tableName string, rootPageNumber int) (*countNodeV2, error) { - switch e := p.stmt.ResultColumns[0].Expression.(type) { - case *compiler.FunctionExpr: - if len(p.stmt.ResultColumns) != 1 { - return nil, errors.New("count with other result columns not supported") - } - if e.FnType != compiler.FnCount { - return nil, fmt.Errorf("only %s function is supported", e.FnType) - } - cn := &countNodeV2{ - plan: p.queryPlan.plan, - } - return cn, nil - } - return nil, nil -} - -func (p *selectQueryPlanner) getScanColumns() ([]scanColumn, error) { - pkColName, err := p.catalog.GetPrimaryKeyColumn(p.stmt.From.TableName) - if err != nil { - return nil, err - } - cols, err := p.catalog.GetColumns(p.stmt.From.TableName) - if err != nil { - return nil, err - } - scanColumns := []scanColumn{} - idx := 0 - for _, c := range cols { - if c == pkColName { - scanColumns = append(scanColumns, &compiler.ColumnRef{ - Table: p.stmt.From.TableName, - Column: c, - IsPrimaryKey: c == pkColName, - }) - } else { - scanColumns = append(scanColumns, &compiler.ColumnRef{ - Table: p.stmt.From.TableName, - Column: c, - ColIdx: idx, - }) - idx += 1 - } - } - return scanColumns, nil -} - func (p *selectQueryPlanner) getProjections() ([]projectionV2, error) { var projections []projectionV2 for _, resultColumn := range p.stmt.ResultColumns { @@ -360,15 +334,26 @@ func (sp *selectPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { func (p *selectExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { p.setResultHeader() - p.queryPlan.plan.compile() - p.executionPlan.Commands = p.queryPlan.plan.commands + p.queryPlan.compile() + p.executionPlan.Commands = p.queryPlan.commands return p.executionPlan, nil } func (p *selectExecutionPlanner) setResultHeader() { resultHeader := []string{} - for range p.queryPlan.projections { - resultHeader = append(resultHeader, "unknown") + switch t := p.queryPlan.root.(type) { + case *projectNodeV2: + 
projectExprs := []compiler.Expr{} + for _, projection := range t.projections { + resultHeader = append(resultHeader, projection.alias) + projectExprs = append(projectExprs, projection.expr) + } + p.setResultTypes(projectExprs) + case *countNodeV2: + resultHeader = append(resultHeader, t.projection.alias) + p.setResultTypes([]compiler.Expr{t.projection.expr}) + default: + panic("unhandled node for result header") } p.executionPlan.ResultHeader = resultHeader } diff --git a/planner/select_test.go b/planner/select_test.go index dc2c08e..d9bf29e 100644 --- a/planner/select_test.go +++ b/planner/select_test.go @@ -306,12 +306,13 @@ func TestSelectPlan(t *testing.T) { { description: "JustCountAggregate", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.InitCmd{P2: 4}, &vm.CountCmd{P1: 1, P2: 1}, &vm.ResultRowCmd{P1: 1, P2: 1}, &vm.HaltCmd{}, + &vm.TransactionCmd{P1: 0}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ StmtBase: &compiler.StmtBase{}, From b5e2dfb7fdf0373bc7f2dbabda3371fed00ac98c Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 21 Dec 2025 23:32:58 -0700 Subject: [PATCH 06/17] fix bug and remove dead code --- planner/crvisitor.go | 124 --------------------------------- planner/node.go | 48 +------------ planner/plan.go | 57 +++------------ planner/plan_test.go | 85 +++++++++++----------- planner/predicate_generator.go | 10 +-- planner/result_generator.go | 10 +-- 6 files changed, 56 insertions(+), 278 deletions(-) delete mode 100644 planner/crvisitor.go diff --git a/planner/crvisitor.go b/planner/crvisitor.go deleted file mode 100644 index c9ba35b..0000000 --- a/planner/crvisitor.go +++ /dev/null @@ -1,124 +0,0 @@ -package planner - -import ( - "slices" - "strings" - - "github.com/chirst/cdb/compiler" - "github.com/chirst/cdb/vm" -) - -// constantRegisterVisitor fills constantRegisters with constants in the visited -// node. 
-type constantRegisterVisitor struct { - // nextOpenRegister counts upwards reserving registers as needed. - nextOpenRegister int - // constantRegisters is a mapping of register value to register index. - constantRegisters map[int]int - // variableRegisters is a mapping of variable indices to registers. - variableRegisters map[int]int - // stringRegisters is a mapping of string constants to registers - stringRegisters map[string]int -} - -func (c *constantRegisterVisitor) Init(openRegister int) { - c.constantRegisters = make(map[int]int) - c.variableRegisters = make(map[int]int) - c.stringRegisters = make(map[string]int) - c.nextOpenRegister = openRegister -} - -// GetRegisterCommands returns commands to fill the current constantRegister -// map. -func (c *constantRegisterVisitor) GetRegisterCommands() []vm.Command { - // Maps are unordered so there is some extra work to keep commands in order. - lc := c.getOrderedLitCommands() - vc := c.getOrderedVariableCommands() - sc := c.getOrderedStringCommands() - return append(lc, append(vc, sc...)...) 
-} - -func (c *constantRegisterVisitor) getOrderedLitCommands() []vm.Command { - unordered := []*vm.IntegerCmd{} - for k := range c.constantRegisters { - unordered = append(unordered, &vm.IntegerCmd{P1: k, P2: c.constantRegisters[k]}) - } - slices.SortFunc(unordered, func(a, b *vm.IntegerCmd) int { - return a.P2 - b.P2 - }) - ret := []vm.Command{} - for _, s := range unordered { - ret = append(ret, vm.Command(s)) - } - return ret -} - -func (c *constantRegisterVisitor) getOrderedVariableCommands() []vm.Command { - unordered := []*vm.VariableCmd{} - for k := range c.variableRegisters { - unordered = append(unordered, &vm.VariableCmd{P1: k, P2: c.variableRegisters[k]}) - } - slices.SortFunc(unordered, func(a, b *vm.VariableCmd) int { - return a.P2 - b.P2 - }) - ret := []vm.Command{} - for _, s := range unordered { - ret = append(ret, vm.Command(s)) - } - return ret -} - -func (c *constantRegisterVisitor) getOrderedStringCommands() []vm.Command { - unordered := []*vm.StringCmd{} - for k := range c.stringRegisters { - unordered = append(unordered, &vm.StringCmd{P1: c.stringRegisters[k], P4: k}) - } - slices.SortFunc(unordered, func(a, b *vm.StringCmd) int { - return strings.Compare(a.P4, b.P4) - }) - ret := []vm.Command{} - for _, s := range unordered { - ret = append(ret, vm.Command(s)) - } - return ret -} - -func (c *constantRegisterVisitor) VisitIntLit(e *compiler.IntLit) { - c.fillRegisterIfNeeded(e.Value) -} - -func (c *constantRegisterVisitor) fillRegisterIfNeeded(v int) { - found := false - for k := range c.constantRegisters { - if k == v { - found = true - } - } - if !found { - c.constantRegisters[v] = c.nextOpenRegister - c.nextOpenRegister += 1 - } -} - -func (c *constantRegisterVisitor) VisitVariable(v *compiler.Variable) { - c.variableRegisters[v.Position] = c.nextOpenRegister - c.nextOpenRegister += 1 -} - -func (c *constantRegisterVisitor) VisitStringLit(e *compiler.StringLit) { - found := false - for k := range c.stringRegisters { - if k == e.Value { - 
found = true - } - } - if !found { - c.stringRegisters[e.Value] = c.nextOpenRegister - c.nextOpenRegister += 1 - } -} - -func (c *constantRegisterVisitor) VisitBinaryExpr(e *compiler.BinaryExpr) {} -func (c *constantRegisterVisitor) VisitUnaryExpr(e *compiler.UnaryExpr) {} -func (c *constantRegisterVisitor) VisitColumnRefExpr(e *compiler.ColumnRef) {} -func (c *constantRegisterVisitor) VisitFunctionExpr(e *compiler.FunctionExpr) {} diff --git a/planner/node.go b/planner/node.go index 5ee7468..1a8c946 100644 --- a/planner/node.go +++ b/planner/node.go @@ -10,54 +10,8 @@ type logicalNode interface { print() string } -// projectNode defines what columns should be projected. -type projectNode struct { - projections []projection - child logicalNode -} - -// projection is part of the sum of projections in a project node. -type projection struct { - // isCount signifies the projection is the count function. - isCount bool - // colName is the name of the column to be projected. - colName string -} - -// scanNode represents a full scan on a table -type scanNode struct { - // tableName is the name of the table to be scanned - tableName string - // rootPage is the valid page number corresponding to the table - rootPage int - // scanColumns contains information about how the scan will project columns - scanColumns []scanColumn - // scanPredicate is an expression evaluated as a boolean. This behaves as a - // filter in the scan. - scanPredicate compiler.Expr -} - -type scanColumn = compiler.Expr - -// constantNode is used in select statements where there is no table. -type constantNode struct { - // resultColumns are the result columns containing expressions. - resultColumns []compiler.Expr - // predicate filters the result depending on the result of the expression. - predicate compiler.Expr -} - -// countNode represents a special optimization when a table needs a full count -// with no filtering or other projections. 
-type countNode struct { - // tableName is the name of the table to be scanned - tableName string - // rootPage is the valid page number corresponding to the table - rootPage int -} - // TODO joinNode is unused, but remains as a prototype binary operation node. -type joinNode struct { +type joinNodeV2 struct { // left is the left subtree of the join. left logicalNode // right is the right subtree of the join. diff --git a/planner/plan.go b/planner/plan.go index 35877e0..312f262 100644 --- a/planner/plan.go +++ b/planner/plan.go @@ -101,18 +101,6 @@ func (p *QueryPlan) connectSiblings() string { return strings.Join(planMatrix, "\n") } -func (p *projectNode) print() string { - list := "(" - for i, proj := range p.projections { - list += proj.print() - if i+1 < len(p.projections) { - list += ", " - } - } - list += ")" - return "project" + list -} - func (p *projectNodeV2) print() string { return "project" } @@ -121,32 +109,19 @@ func (p *projectNodeV2) children() []logicalNode { return []logicalNode{} } -func (p *projection) print() string { - if p.isCount { - return "count(*)" - } - if p.colName == "" { - return "" - } - return p.colName -} - -func (s *scanNode) print() string { - if s.scanPredicate != nil { - return fmt.Sprintf("scan table %s with predicate", s.tableName) - } - return fmt.Sprintf("scan table %s", s.tableName) +func (s *scanNodeV2) print() string { + return fmt.Sprintf("scan table") } -func (c *constantNode) print() string { +func (c *constantNodeV2) print() string { return "constant data source" } -func (c *countNode) print() string { - return fmt.Sprintf("count table %s", c.tableName) +func (c *countNodeV2) print() string { + return fmt.Sprintf("count table") } -func (j *joinNode) print() string { +func (j *joinNodeV2) print() string { return fmt.Sprint(j.operation) } @@ -165,23 +140,19 @@ func (u *updateNodeV2) print() string { return "update" } -func (p *projectNode) children() []logicalNode { - return []logicalNode{p.child} -} - -func (s 
*scanNode) children() []logicalNode { +func (s *scanNodeV2) children() []logicalNode { return []logicalNode{} } -func (c *constantNode) children() []logicalNode { +func (c *constantNodeV2) children() []logicalNode { return []logicalNode{} } -func (c *countNode) children() []logicalNode { +func (c *countNodeV2) children() []logicalNode { return []logicalNode{} } -func (j *joinNode) children() []logicalNode { +func (j *joinNodeV2) children() []logicalNode { return []logicalNode{j.left, j.right} } @@ -196,11 +167,3 @@ func (i *insertNode) children() []logicalNode { func (u *updateNodeV2) children() []logicalNode { return []logicalNode{} } - -func (p *countNodeV2) print() string { - return "count" -} - -func (p *countNodeV2) children() []logicalNode { - return []logicalNode{} -} diff --git a/planner/plan_test.go b/planner/plan_test.go index 2fcbaf4..2204144 100644 --- a/planner/plan_test.go +++ b/planner/plan_test.go @@ -1,48 +1,41 @@ package planner -import "testing" - -func TestExplainQueryPlan(t *testing.T) { - root := &projectNode{ - projections: []projection{ - {colName: "id"}, - {colName: "first_name"}, - {colName: "last_name"}, - }, - child: &joinNode{ - operation: "join", - left: &joinNode{ - operation: "join", - left: &scanNode{ - tableName: "foo", - }, - right: &joinNode{ - operation: "join", - left: &scanNode{ - tableName: "baz", - }, - right: &scanNode{ - tableName: "buzz", - }, - }, - }, - right: &scanNode{ - tableName: "bar", - }, - }, - } - qp := newQueryPlan(root, true) - formattedResult := qp.ToString() - expectedResult := "" + - " ── project(id, first_name, last_name)\n" + - " └─ join\n" + - " ├─ join\n" + - " | ├─ scan table foo\n" + - " | └─ join\n" + - " | ├─ scan table baz\n" + - " | └─ scan table buzz\n" + - " └─ scan table bar\n" - if formattedResult != expectedResult { - t.Fatalf("got\n%s\nwant\n%s", formattedResult, expectedResult) - } -} +// func TestExplainQueryPlan(t *testing.T) { +// root := &projectNodeV2{ +// child: &joinNodeV2{ +// 
operation: "join", +// left: &joinNodeV2{ +// operation: "join", +// left: &scanNodeV2{ +// tableName: "foo", +// }, +// right: &joinNodeV2{ +// operation: "join", +// left: &scanNodeV2{ +// tableName: "baz", +// }, +// right: &scanNodeV2{ +// tableName: "buzz", +// }, +// }, +// }, +// right: &scanNodeV2{ +// tableName: "bar", +// }, +// }, +// } +// qp := newQueryPlan(root, true) +// formattedResult := qp.ToString() +// expectedResult := "" + +// " ── project\n" + +// " └─ join\n" + +// " ├─ join\n" + +// " | ├─ scan table foo\n" + +// " | └─ join\n" + +// " | ├─ scan table baz\n" + +// " | └─ scan table buzz\n" + +// " └─ scan table bar\n" +// if formattedResult != expectedResult { +// t.Fatalf("got\n%s\nwant\n%s", formattedResult, expectedResult) +// } +// } diff --git a/planner/predicate_generator.go b/planner/predicate_generator.go index 0ef6d47..1dd464a 100644 --- a/planner/predicate_generator.go +++ b/planner/predicate_generator.go @@ -11,7 +11,6 @@ import ( func generatePredicate(plan *planV2, expression compiler.Expr) vm.JumpCommand { pg := &predicateGenerator{} pg.plan = plan - pg.commandOffset = len(pg.plan.commands) pg.build(expression, 0) return pg.jumpCommand } @@ -20,9 +19,6 @@ func generatePredicate(plan *planV2, expression compiler.Expr) vm.JumpCommand { // expression. type predicateGenerator struct { plan *planV2 - // commandOffset is used to calculate the amount of commands already in the - // plan. - commandOffset int // jumpCommand is the command used to make the jump. The command can be // accessed to defer setting the jump address. 
jumpCommand vm.JumpCommand @@ -105,7 +101,7 @@ func (p *predicateGenerator) build(e compiler.Expr, level int) (int, error) { } p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) jumpOverCount := 2 - jumpAddress := len(p.plan.commands) + jumpOverCount + p.commandOffset + jumpAddress := len(p.plan.commands) + jumpOverCount p.plan.commands = append( p.plan.commands, &vm.NotEqualCmd{P1: ol, P2: jumpAddress, P3: or}, @@ -121,7 +117,7 @@ func (p *predicateGenerator) build(e compiler.Expr, level int) (int, error) { } p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) jumpOverCount := 2 - jumpAddress := len(p.plan.commands) + jumpOverCount + p.commandOffset + jumpAddress := len(p.plan.commands) + jumpOverCount p.plan.commands = append( p.plan.commands, &vm.GteCmd{P1: ol, P2: jumpAddress, P3: or}, @@ -137,7 +133,7 @@ func (p *predicateGenerator) build(e compiler.Expr, level int) (int, error) { } p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) jumpOverCount := 2 - jumpAddress := len(p.plan.commands) + jumpOverCount + p.commandOffset + jumpAddress := len(p.plan.commands) + jumpOverCount p.plan.commands = append( p.plan.commands, &vm.LteCmd{P1: ol, P2: jumpAddress, P3: or}, diff --git a/planner/result_generator.go b/planner/result_generator.go index 739a09e..2ebe1d1 100644 --- a/planner/result_generator.go +++ b/planner/result_generator.go @@ -11,7 +11,6 @@ func generateExpressionTo(plan *planV2, expr compiler.Expr, toRegister int) { rg := &resultExprGenerator{} rg.plan = plan rg.outputRegister = toRegister - rg.commandOffset = len(rg.plan.commands) rg.build(expr, 0) } @@ -20,9 +19,6 @@ type resultExprGenerator struct { plan *planV2 // outputRegister is the target register for the result of the expression. outputRegister int - // commandOffset is the amount of commands prior to calling this routine. - // Useful for calculating jump instructions. 
- commandOffset int } func (e *resultExprGenerator) build(root compiler.Expr, level int) int { @@ -45,7 +41,7 @@ func (e *resultExprGenerator) build(root compiler.Expr, level int) int { case compiler.OpEq: e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) jumpOverCount := 2 - jumpAddress := len(e.plan.commands) + jumpOverCount + e.commandOffset + jumpAddress := len(e.plan.commands) + jumpOverCount e.plan.commands = append( e.plan.commands, &vm.NotEqualCmd{P1: ol, P2: jumpAddress, P3: or}, @@ -54,7 +50,7 @@ func (e *resultExprGenerator) build(root compiler.Expr, level int) int { case compiler.OpLt: e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) jumpOverCount := 2 - jumpAddress := len(e.plan.commands) + jumpOverCount + e.commandOffset + jumpAddress := len(e.plan.commands) + jumpOverCount e.plan.commands = append( e.plan.commands, &vm.GteCmd{P1: ol, P2: jumpAddress, P3: or}, @@ -63,7 +59,7 @@ func (e *resultExprGenerator) build(root compiler.Expr, level int) int { case compiler.OpGt: e.plan.commands = append(e.plan.commands, &vm.IntegerCmd{P1: 0, P2: r}) jumpOverCount := 2 - jumpAddress := len(e.plan.commands) + jumpOverCount + e.commandOffset + jumpAddress := len(e.plan.commands) + jumpOverCount e.plan.commands = append( e.plan.commands, &vm.LteCmd{P1: ol, P2: jumpAddress, P3: or}, From 04a8586d641e8665b9ec32d6a04ff9445b4de0b1 Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 28 Dec 2025 00:47:59 -0700 Subject: [PATCH 07/17] bridge v1 and v2 nodes --- planner/create.go | 14 ++- planner/generator.go | 150 +++++++++++++-------------------- planner/insert.go | 7 +- planner/node.go | 4 +- planner/plan.go | 92 +++++++++++++++----- planner/plan_test.go | 72 ++++++++-------- planner/predicate_generator.go | 4 +- planner/result_generator.go | 4 +- planner/select.go | 51 +++++------ planner/update.go | 26 +++--- 10 files changed, 229 insertions(+), 195 deletions(-) diff --git a/planner/create.go b/planner/create.go index 
5e8b831..4ed40bf 100644 --- a/planner/create.go +++ b/planner/create.go @@ -88,7 +88,12 @@ func (p *createQueryPlanner) getQueryPlan() (*QueryPlan, error) { tableName: p.stmt.TableName, } p.queryPlan = noopCreateNode - return newQueryPlan(noopCreateNode, p.stmt.ExplainQueryPlan), nil + return newQueryPlan( + noopCreateNode, + p.stmt.ExplainQueryPlan, + transactionTypeWrite, + 0, + ), nil } if tableExists { return nil, errTableExists @@ -104,7 +109,12 @@ func (p *createQueryPlanner) getQueryPlan() (*QueryPlan, error) { schema: jSchema, } p.queryPlan = createNode - return newQueryPlan(createNode, p.stmt.ExplainQueryPlan), nil + return newQueryPlan( + createNode, + p.stmt.ExplainQueryPlan, + transactionTypeWrite, + 0, + ), nil } func (p *createQueryPlanner) getSchemaString() (string, error) { diff --git a/planner/generator.go b/planner/generator.go index eb63a85..8a7f7d8 100644 --- a/planner/generator.go +++ b/planner/generator.go @@ -16,53 +16,10 @@ const ( transactionTypeWrite transactionType = 2 ) -// planV2 holds the necessary data and receivers for generating a plan as well -// as the final commands that define the execution plan. -type planV2 struct { - // root is the root node of the plan tree. - root nodeV2 - // commands is a list of commands that define the plan. - commands []vm.Command - // constInts is a mapping of constant integer values to the registers that - // contain the value. - constInts map[int]int - // constStrings is a mapping of constant string values to the registers that - // contain the value. - constStrings map[string]int - // constVars is a mapping of a variable's position to the registers that - // holds the variable's value. - constVars map[int]int - // freeRegister is a counter containing the next free register in the plan. - freeRegister int - // transactionType defines what kind of transaction the plan will need. - transactionType transactionType - // cursorId is the id of the cursor the plan is using. 
Note plans will - // eventually need to use more than one cursor, but for now it is convenient - // to pull the id from here. - cursorId int - // rootPageNumber is the root page number of the table cursorId is - // associated with. This should be a map at some point when multiple tables - // can be queried in one plan. - rootPageNumber int -} - -func newPlan(transactionType transactionType, rootPageNumber int) *planV2 { - return &planV2{ - commands: []vm.Command{}, - constInts: make(map[int]int), - constStrings: make(map[string]int), - constVars: make(map[int]int), - freeRegister: 1, - transactionType: transactionType, - cursorId: 1, - rootPageNumber: rootPageNumber, - } -} - // declareConstInt gets or sets a register with the const value and returns the // register. It is guaranteed the value will be in the register for the duration // of the plan. -func (p *planV2) declareConstInt(i int) int { +func (p *QueryPlan) declareConstInt(i int) int { _, ok := p.constInts[i] if !ok { p.constInts[i] = p.freeRegister @@ -74,7 +31,7 @@ func (p *planV2) declareConstInt(i int) int { // declareConstString gets or sets a register with the const value and returns // the register. It is guaranteed the value will be in the register for the // duration of the plan. -func (p *planV2) declareConstString(s string) int { +func (p *QueryPlan) declareConstString(s string) int { _, ok := p.constStrings[s] if !ok { p.constStrings[s] = p.freeRegister @@ -86,7 +43,7 @@ func (p *planV2) declareConstString(s string) int { // declareConstVar gets or sets a register with the const value and returns // the register. It is guaranteed the value will be in the register for the // duration of the plan. 
-func (p *planV2) declareConstVar(position int) int { +func (p *QueryPlan) declareConstVar(position int) int { _, ok := p.constVars[position] if !ok { p.constVars[position] = p.freeRegister @@ -95,7 +52,7 @@ func (p *planV2) declareConstVar(position int) int { return p.constVars[position] } -func (p *planV2) compile() { +func (p *QueryPlan) compile() { initCmd := &vm.InitCmd{} p.commands = append(p.commands, initCmd) p.root.produce() @@ -106,7 +63,7 @@ func (p *planV2) compile() { p.commands = append(p.commands, &vm.GotoCmd{P2: 1}) } -func (p *planV2) pushTransaction() { +func (p *QueryPlan) pushTransaction() { switch p.transactionType { case transactionTypeNone: return @@ -133,7 +90,7 @@ func (p *planV2) pushTransaction() { } } -func (p *planV2) pushConstants() { +func (p *QueryPlan) pushConstants() { // these constants are pushed ordered since maps are unordered making it // difficult to assert that a sequence of instructions appears. p.pushConstantInts() @@ -141,7 +98,7 @@ func (p *planV2) pushConstants() { p.pushConstantVars() } -func (p *planV2) pushConstantInts() { +func (p *QueryPlan) pushConstantInts() { temp := []*vm.IntegerCmd{} for k := range p.constInts { temp = append(temp, &vm.IntegerCmd{P1: k, P2: p.constInts[k]}) @@ -154,7 +111,7 @@ func (p *planV2) pushConstantInts() { } } -func (p *planV2) pushConstantStrings() { +func (p *QueryPlan) pushConstantStrings() { temp := []*vm.StringCmd{} for v := range p.constStrings { p.commands = append(p.commands, &vm.StringCmd{P1: p.constStrings[v], P4: v}) @@ -167,7 +124,7 @@ func (p *planV2) pushConstantStrings() { } } -func (p *planV2) pushConstantVars() { +func (p *QueryPlan) pushConstantVars() { temp := []*vm.VariableCmd{} for v := range p.constVars { p.commands = append(p.commands, &vm.VariableCmd{P1: v, P2: p.constVars[v]}) @@ -180,14 +137,9 @@ func (p *planV2) pushConstantVars() { } } -type nodeV2 interface { - produce() - consume() -} - -type updateNodeV2 struct { - child nodeV2 - plan *planV2 +type 
updateNode struct { + child logicalNode + plan *QueryPlan // updateExprs is formed from the update statement AST. The idea is to // provide an expression for each column where the expression is either a // columnRef or the complex expression from the right hand side of the SET @@ -201,11 +153,11 @@ type updateNodeV2 struct { updateExprs []compiler.Expr } -func (u *updateNodeV2) produce() { +func (u *updateNode) produce() { u.child.produce() } -func (u *updateNodeV2) consume() { +func (u *updateNode) consume() { // RowID u.plan.commands = append(u.plan.commands, &vm.RowIdCmd{ P1: u.plan.cursorId, @@ -243,33 +195,33 @@ func (u *updateNodeV2) consume() { }) } -type filterNodeV2 struct { - child nodeV2 - parent nodeV2 - plan *planV2 +type filterNode struct { + child logicalNode + parent logicalNode + plan *QueryPlan predicate compiler.Expr } -func (f *filterNodeV2) produce() { +func (f *filterNode) produce() { f.child.produce() } -func (f *filterNodeV2) consume() { +func (f *filterNode) consume() { jumpCommand := generatePredicate(f.plan, f.predicate) f.parent.consume() jumpCommand.SetJumpAddress(len(f.plan.commands)) } -type scanNodeV2 struct { - parent nodeV2 - plan *planV2 +type scanNode struct { + parent logicalNode + plan *QueryPlan } -func (s *scanNodeV2) produce() { +func (s *scanNode) produce() { s.consume() } -func (s *scanNodeV2) consume() { +func (s *scanNode) consume() { rewindCmd := &vm.RewindCmd{P1: s.plan.cursorId} s.plan.commands = append(s.plan.commands, rewindCmd) loopBeginAddress := len(s.plan.commands) @@ -281,23 +233,23 @@ func (s *scanNodeV2) consume() { rewindCmd.P2 = len(s.plan.commands) } -type projectionV2 struct { +type projection struct { expr compiler.Expr // alias is the alias of the projection or no alias for the zero value. 
alias string } -type projectNodeV2 struct { - child nodeV2 - plan *planV2 - projections []projectionV2 +type projectNode struct { + child logicalNode + plan *QueryPlan + projections []projection } -func (p *projectNodeV2) produce() { +func (p *projectNode) produce() { p.child.produce() } -func (p *projectNodeV2) consume() { +func (p *projectNode) consume() { startRegister := p.plan.freeRegister reservedRegisters := len(p.projections) p.plan.freeRegister += reservedRegisters @@ -310,29 +262,29 @@ func (p *projectNodeV2) consume() { }) } -type constantNodeV2 struct { - parent nodeV2 - plan *planV2 +type constantNode struct { + parent logicalNode + plan *QueryPlan } -func (c *constantNodeV2) produce() { +func (c *constantNode) produce() { c.consume() } -func (c *constantNodeV2) consume() { +func (c *constantNode) consume() { c.parent.consume() } -type countNodeV2 struct { - plan *planV2 - projection projectionV2 +type countNode struct { + plan *QueryPlan + projection projection } -func (c *countNodeV2) produce() { +func (c *countNode) produce() { c.consume() } -func (c *countNodeV2) consume() { +func (c *countNode) consume() { c.plan.commands = append(c.plan.commands, &vm.CountCmd{ P1: c.plan.cursorId, P2: c.plan.freeRegister, @@ -345,3 +297,21 @@ func (c *countNodeV2) consume() { P2: countResults, }) } + +func (c *createNode) produce() { +} + +func (c *createNode) consume() { +} + +func (n *insertNode) produce() { +} + +func (n *insertNode) consume() { +} + +func (n *joinNode) produce() { +} + +func (n *joinNode) consume() { +} diff --git a/planner/insert.go b/planner/insert.go index 912205e..2278b59 100644 --- a/planner/insert.go +++ b/planner/insert.go @@ -107,7 +107,12 @@ func (p *insertQueryPlanner) getQueryPlan() (*QueryPlan, error) { colValues: p.stmt.ColValues, } p.queryPlan = insertNode - return newQueryPlan(insertNode, p.stmt.ExplainQueryPlan), nil + return newQueryPlan( + insertNode, + p.stmt.ExplainQueryPlan, + transactionTypeWrite, + rootPage, + ), nil } 
// ExecutionPlan returns the bytecode routine for the planner. Calling QueryPlan diff --git a/planner/node.go b/planner/node.go index 1a8c946..1627279 100644 --- a/planner/node.go +++ b/planner/node.go @@ -8,10 +8,12 @@ import "github.com/chirst/cdb/compiler" type logicalNode interface { children() []logicalNode print() string + produce() + consume() } // TODO joinNode is unused, but remains as a prototype binary operation node. -type joinNodeV2 struct { +type joinNode struct { // left is the left subtree of the join. left logicalNode // right is the right subtree of the join. diff --git a/planner/plan.go b/planner/plan.go index 312f262..77128b7 100644 --- a/planner/plan.go +++ b/planner/plan.go @@ -4,24 +4,66 @@ import ( "fmt" "strings" "unicode/utf8" + + "github.com/chirst/cdb/vm" ) -// QueryPlan contains the query plan tree. It is capable of converting the tree -// to a string representation for a query prefixed with `EXPLAIN QUERY PLAN`. +// QueryPlan contains the query plan tree stemming from the root node. It is +// capable of converting the tree to a string representation for a query +// prefixed with `EXPLAIN QUERY PLAN`. +// +// The structure also holds the necessary data and receivers for generating a +// plan as well as the final commands that define the execution plan. type QueryPlan struct { // plan holds the string representation also known as the tree. plan string - // root holds the root node of the query plan + // root is the root node of the plan tree. root logicalNode // ExplainQueryPlan is a flag indicating if the SQL asked for the query plan // to be printed as a string representation with `EXPLAIN QUERY PLAN`. ExplainQueryPlan bool -} - -func newQueryPlan(root logicalNode, explainQueryPlan bool) *QueryPlan { + // commands is a list of commands that define the plan. + commands []vm.Command + // constInts is a mapping of constant integer values to the registers that + // contain the value. 
+ constInts map[int]int + // constStrings is a mapping of constant string values to the registers that + // contain the value. + constStrings map[string]int + // constVars is a mapping of a variable's position to the registers that + // holds the variable's value. + constVars map[int]int + // freeRegister is a counter containing the next free register in the plan. + freeRegister int + // transactionType defines what kind of transaction the plan will need. + transactionType transactionType + // cursorId is the id of the cursor the plan is using. Note plans will + // eventually need to use more than one cursor, but for now it is convenient + // to pull the id from here. + cursorId int + // rootPageNumber is the root page number of the table cursorId is + // associated with. This should be a map at some point when multiple tables + // can be queried in one plan. + rootPageNumber int +} + +func newQueryPlan( + root logicalNode, + explainQueryPlan bool, + transactionType transactionType, + rootPageNumber int, +) *QueryPlan { return &QueryPlan{ root: root, ExplainQueryPlan: explainQueryPlan, + commands: []vm.Command{}, + constInts: make(map[int]int), + constStrings: make(map[string]int), + constVars: make(map[int]int), + freeRegister: 1, + transactionType: transactionType, + cursorId: 1, + rootPageNumber: rootPageNumber, } } @@ -101,27 +143,27 @@ func (p *QueryPlan) connectSiblings() string { return strings.Join(planMatrix, "\n") } -func (p *projectNodeV2) print() string { +func (p *projectNode) print() string { return "project" } -func (p *projectNodeV2) children() []logicalNode { - return []logicalNode{} +func (p *projectNode) children() []logicalNode { + return []logicalNode{p.child} } -func (s *scanNodeV2) print() string { - return fmt.Sprintf("scan table") +func (s *scanNode) print() string { + return "scan table" } -func (c *constantNodeV2) print() string { +func (c *constantNode) print() string { return "constant data source" } -func (c *countNodeV2) print() 
string { - return fmt.Sprintf("count table") +func (c *countNode) print() string { + return "count table" } -func (j *joinNodeV2) print() string { +func (j *joinNode) print() string { return fmt.Sprint(j.operation) } @@ -136,23 +178,27 @@ func (i *insertNode) print() string { return "insert" } -func (u *updateNodeV2) print() string { +func (u *updateNode) print() string { return "update" } -func (s *scanNodeV2) children() []logicalNode { +func (f *filterNode) print() string { + return "filter" +} + +func (s *scanNode) children() []logicalNode { return []logicalNode{} } -func (c *constantNodeV2) children() []logicalNode { +func (c *constantNode) children() []logicalNode { return []logicalNode{} } -func (c *countNodeV2) children() []logicalNode { +func (c *countNode) children() []logicalNode { return []logicalNode{} } -func (j *joinNodeV2) children() []logicalNode { +func (j *joinNode) children() []logicalNode { return []logicalNode{j.left, j.right} } @@ -164,6 +210,10 @@ func (i *insertNode) children() []logicalNode { return []logicalNode{} } -func (u *updateNodeV2) children() []logicalNode { +func (u *updateNode) children() []logicalNode { return []logicalNode{} } + +func (f *filterNode) children() []logicalNode { + return []logicalNode{f.child} +} diff --git a/planner/plan_test.go b/planner/plan_test.go index 2204144..eb5b9e6 100644 --- a/planner/plan_test.go +++ b/planner/plan_test.go @@ -1,41 +1,35 @@ package planner -// func TestExplainQueryPlan(t *testing.T) { -// root := &projectNodeV2{ -// child: &joinNodeV2{ -// operation: "join", -// left: &joinNodeV2{ -// operation: "join", -// left: &scanNodeV2{ -// tableName: "foo", -// }, -// right: &joinNodeV2{ -// operation: "join", -// left: &scanNodeV2{ -// tableName: "baz", -// }, -// right: &scanNodeV2{ -// tableName: "buzz", -// }, -// }, -// }, -// right: &scanNodeV2{ -// tableName: "bar", -// }, -// }, -// } -// qp := newQueryPlan(root, true) -// formattedResult := qp.ToString() -// expectedResult := "" + -// 
" ── project\n" + -// " └─ join\n" + -// " ├─ join\n" + -// " | ├─ scan table foo\n" + -// " | └─ join\n" + -// " | ├─ scan table baz\n" + -// " | └─ scan table buzz\n" + -// " └─ scan table bar\n" -// if formattedResult != expectedResult { -// t.Fatalf("got\n%s\nwant\n%s", formattedResult, expectedResult) -// } -// } +import "testing" + +func TestExplainQueryPlan(t *testing.T) { + root := &projectNode{ + child: &joinNode{ + operation: "join", + left: &joinNode{ + operation: "join", + left: &scanNode{}, + right: &joinNode{ + operation: "join", + left: &scanNode{}, + right: &scanNode{}, + }, + }, + right: &scanNode{}, + }, + } + qp := newQueryPlan(root, true, transactionTypeRead, 0) + formattedResult := qp.ToString() + expectedResult := "" + + " ── project\n" + + " └─ join\n" + + " ├─ join\n" + + " | ├─ scan table\n" + + " | └─ join\n" + + " | ├─ scan table\n" + + " | └─ scan table\n" + + " └─ scan table\n" + if formattedResult != expectedResult { + t.Fatalf("got\n%s\nwant\n%s", formattedResult, expectedResult) + } +} diff --git a/planner/predicate_generator.go b/planner/predicate_generator.go index 1dd464a..9380568 100644 --- a/planner/predicate_generator.go +++ b/planner/predicate_generator.go @@ -8,7 +8,7 @@ import ( // generatePredicate generates code to make a boolean jump for the given // expression within the plan context. The function returns the jump command to // lazily set the jump address. -func generatePredicate(plan *planV2, expression compiler.Expr) vm.JumpCommand { +func generatePredicate(plan *QueryPlan, expression compiler.Expr) vm.JumpCommand { pg := &predicateGenerator{} pg.plan = plan pg.build(expression, 0) @@ -18,7 +18,7 @@ func generatePredicate(plan *planV2, expression compiler.Expr) vm.JumpCommand { // predicateGenerator builds commands to calculate the boolean result of an // expression. type predicateGenerator struct { - plan *planV2 + plan *QueryPlan // jumpCommand is the command used to make the jump. 
The command can be // accessed to defer setting the jump address. jumpCommand vm.JumpCommand diff --git a/planner/result_generator.go b/planner/result_generator.go index 2ebe1d1..89ff355 100644 --- a/planner/result_generator.go +++ b/planner/result_generator.go @@ -7,7 +7,7 @@ import ( // generateExpressionTo takes the context of the plan and generates commands // that land the result of the given expr in the toRegister. -func generateExpressionTo(plan *planV2, expr compiler.Expr, toRegister int) { +func generateExpressionTo(plan *QueryPlan, expr compiler.Expr, toRegister int) { rg := &resultExprGenerator{} rg.plan = plan rg.outputRegister = toRegister @@ -16,7 +16,7 @@ func generateExpressionTo(plan *planV2, expr compiler.Expr, toRegister int) { // resultExprGenerator builds commands for the given expression. type resultExprGenerator struct { - plan *planV2 + plan *QueryPlan // outputRegister is the target register for the result of the expression. outputRegister int } diff --git a/planner/select.go b/planner/select.go index e5295fa..bc6defd 100644 --- a/planner/select.go +++ b/planner/select.go @@ -42,7 +42,7 @@ type selectQueryPlanner struct { stmt *compiler.SelectStmt // queryPlan contains the logical plan being built. The root node must be a // projection. - queryPlan *planV2 + queryPlan *QueryPlan } // selectExecutionPlanner converts logical nodes in a query plan tree to @@ -50,7 +50,7 @@ type selectQueryPlanner struct { type selectExecutionPlanner struct { // queryPlan contains the logical plan. This node is populated by calling // the QueryPlan method. - queryPlan *planV2 + queryPlan *QueryPlan // executionPlan contains the execution plan for the vm. This is built by // calling ExecutionPlan. 
executionPlan *vm.ExecutionPlan @@ -143,40 +143,43 @@ func (p *selectQueryPlanner) getQueryPlan() (*QueryPlan, error) { if tableName == "" { return nil, errors.New("must have from for COUNT") } - plan := newPlan(transactionTypeRead, rootPageNumber) - cn := &countNodeV2{plan: plan, projection: projections[0]} + cn := &countNode{projection: projections[0]} + plan := newQueryPlan( + cn, + p.stmt.ExplainQueryPlan, + transactionTypeRead, + rootPageNumber, + ) + cn.plan = plan p.queryPlan = plan - plan.root = cn - return newQueryPlan(cn, p.stmt.ExplainQueryPlan), nil + return plan, nil } tt := transactionTypeRead if tableName == "" { tt = transactionTypeNone } - plan := newPlan(tt, rootPageNumber) - projectNode := &projectNodeV2{ - plan: plan, - projections: projections, - } + projectNode := &projectNode{projections: projections} + plan := newQueryPlan(projectNode, p.stmt.ExplainQueryPlan, tt, rootPageNumber) + projectNode.plan = plan if p.stmt.Where != nil { cev := &catalogExprVisitor{} cev.Init(p.catalog, tableName) p.stmt.Where.BreadthWalk(cev) - filterNode := &filterNodeV2{ + filterNode := &filterNode{ parent: projectNode, plan: plan, predicate: p.stmt.Where, } projectNode.child = filterNode if tableName == "" { - constNode := &constantNodeV2{ + constNode := &constantNode{ plan: plan, } filterNode.child = constNode constNode.parent = filterNode } else { - scanNode := &scanNodeV2{ + scanNode := &scanNode{ plan: plan, } filterNode.child = scanNode @@ -184,13 +187,13 @@ func (p *selectQueryPlanner) getQueryPlan() (*QueryPlan, error) { } } else { if tableName == "" { - constNode := &constantNodeV2{ + constNode := &constantNode{ plan: plan, } projectNode.child = constNode constNode.parent = projectNode } else { - scanNode := &scanNodeV2{ + scanNode := &scanNode{ plan: plan, } projectNode.child = scanNode @@ -199,7 +202,7 @@ func (p *selectQueryPlanner) getQueryPlan() (*QueryPlan, error) { } p.queryPlan = plan plan.root = projectNode - return newQueryPlan(projectNode, 
p.stmt.ExplainQueryPlan), nil + return plan, nil } func (p *selectQueryPlanner) optimizeResultColumns() error { @@ -280,8 +283,8 @@ func foldExpr(e compiler.Expr) (compiler.Expr, error) { } } -func (p *selectQueryPlanner) getProjections() ([]projectionV2, error) { - var projections []projectionV2 +func (p *selectQueryPlanner) getProjections() ([]projection, error) { + var projections []projection for _, resultColumn := range p.stmt.ResultColumns { if resultColumn.All { cols, err := p.catalog.GetColumns(p.stmt.From.TableName) @@ -289,7 +292,7 @@ func (p *selectQueryPlanner) getProjections() ([]projectionV2, error) { return nil, err } for _, c := range cols { - projections = append(projections, projectionV2{ + projections = append(projections, projection{ expr: &compiler.ColumnRef{ Table: p.stmt.From.TableName, Column: c, @@ -302,7 +305,7 @@ func (p *selectQueryPlanner) getProjections() ([]projectionV2, error) { return nil, err } for _, c := range cols { - projections = append(projections, projectionV2{ + projections = append(projections, projection{ expr: &compiler.ColumnRef{ Table: p.stmt.From.TableName, Column: c, @@ -310,7 +313,7 @@ func (p *selectQueryPlanner) getProjections() ([]projectionV2, error) { }) } } else if resultColumn.Expression != nil { - projections = append(projections, projectionV2{ + projections = append(projections, projection{ expr: resultColumn.Expression, alias: resultColumn.Alias, }) @@ -342,14 +345,14 @@ func (p *selectExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { func (p *selectExecutionPlanner) setResultHeader() { resultHeader := []string{} switch t := p.queryPlan.root.(type) { - case *projectNodeV2: + case *projectNode: projectExprs := []compiler.Expr{} for _, projection := range t.projections { resultHeader = append(resultHeader, projection.alias) projectExprs = append(projectExprs, projection.expr) } p.setResultTypes(projectExprs) - case *countNodeV2: + case *countNode: resultHeader = append(resultHeader, 
t.projection.alias) p.setResultTypes([]compiler.Expr{t.projection.expr}) default: diff --git a/planner/update.go b/planner/update.go index d2a38a8..3203aa0 100644 --- a/planner/update.go +++ b/planner/update.go @@ -29,12 +29,12 @@ type updatePlanner struct { type updateQueryPlanner struct { catalog updateCatalog stmt *compiler.UpdateStmt - queryPlan *updateNodeV2 + queryPlan *updateNode } // updateExecutionPlanner generates a byte code routine for the given queryPlan. type updateExecutionPlanner struct { - queryPlan *updateNodeV2 + queryPlan *updateNode executionPlan *vm.ExecutionPlan } @@ -70,11 +70,14 @@ func (p *updateQueryPlanner) getQueryPlan() (*QueryPlan, error) { if err != nil { return nil, errTableNotExist } - logicalPlan := newPlan(transactionTypeWrite, rootPage) - updateNode := &updateNodeV2{ - plan: logicalPlan, - updateExprs: []compiler.Expr{}, - } + updateNode := &updateNode{updateExprs: []compiler.Expr{}} + logicalPlan := newQueryPlan( + updateNode, + p.stmt.ExplainQueryPlan, + transactionTypeWrite, + rootPage, + ) + updateNode.plan = logicalPlan p.queryPlan = updateNode logicalPlan.root = updateNode @@ -90,14 +93,14 @@ func (p *updateQueryPlanner) getQueryPlan() (*QueryPlan, error) { return nil, err } - scanNode := &scanNodeV2{ + scanNode := &scanNode{ plan: logicalPlan, } if p.stmt.Predicate != nil { cev := &catalogExprVisitor{} cev.Init(p.catalog, p.stmt.TableName) p.stmt.Predicate.BreadthWalk(cev) - filterNode := &filterNodeV2{ + filterNode := &filterNode{ plan: logicalPlan, predicate: p.stmt.Predicate, parent: updateNode, @@ -110,10 +113,7 @@ func (p *updateQueryPlanner) getQueryPlan() (*QueryPlan, error) { updateNode.child = scanNode } - return &QueryPlan{ - ExplainQueryPlan: p.stmt.ExplainQueryPlan, - root: updateNode, - }, nil + return logicalPlan, nil } // errIfPrimaryKeySet checks the primary key isn't being updated because it From 52fc8ecbb8c0510bc593af9fe476c0dcca13d062 Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 28 Dec 2025 01:01:58 
-0700 Subject: [PATCH 08/17] simplify select planner --- planner/select.go | 115 +++++++++++++--------------------------------- 1 file changed, 32 insertions(+), 83 deletions(-) diff --git a/planner/select.go b/planner/select.go index bc6defd..7fd39c7 100644 --- a/planner/select.go +++ b/planner/select.go @@ -20,36 +20,13 @@ type selectCatalog interface { } // selectPlanner is capable of generating a logical query plan and a physical -// execution plan for a select statement. The planners within are separated by -// their responsibility. +// execution plan for a select statement. type selectPlanner struct { - // queryPlanner is responsible for transforming the AST to a logical query - // plan tree. This tree is made up of nodes that map closely to a relational - // algebra tree. The query planner also performs binding and validation. - queryPlanner *selectQueryPlanner - // executionPlanner transforms the logical query tree to a bytecode routine, - // built to be ran by the virtual machine. - executionPlanner *selectExecutionPlanner -} - -// selectQueryPlanner converts an AST to a logical query plan. Along the way it -// also validates the AST makes sense with the catalog (a process known as -// binding). -type selectQueryPlanner struct { - // catalog contains the schema + // catalog contains the schema. catalog selectCatalog - // stmt contains the AST + // stmt contains the AST. stmt *compiler.SelectStmt - // queryPlan contains the logical plan being built. The root node must be a - // projection. - queryPlan *QueryPlan -} - -// selectExecutionPlanner converts logical nodes in a query plan tree to -// bytecode that can be run by the vm. -type selectExecutionPlanner struct { - // queryPlan contains the logical plan. This node is populated by calling - // the QueryPlan method. + // queryPlan contains the plan being built. queryPlan *QueryPlan // executionPlan contains the execution plan for the vm. This is built by // calling ExecutionPlan. 
@@ -59,49 +36,41 @@ type selectExecutionPlanner struct { // NewSelect returns an instance of a select planner for the given AST. func NewSelect(catalog selectCatalog, stmt *compiler.SelectStmt) *selectPlanner { return &selectPlanner{ - queryPlanner: &selectQueryPlanner{ - catalog: catalog, - stmt: stmt, - }, - executionPlanner: &selectExecutionPlanner{ - executionPlan: vm.NewExecutionPlan( - catalog.GetVersion(), - stmt.Explain, - ), - }, + catalog: catalog, + stmt: stmt, + executionPlan: vm.NewExecutionPlan( + catalog.GetVersion(), + stmt.Explain, + ), } } // QueryPlan generates the query plan tree for the planner. func (p *selectPlanner) QueryPlan() (*QueryPlan, error) { - qp, err := p.queryPlanner.getQueryPlan() + qp, err := p.getQueryPlan() if err != nil { return nil, err } - p.executionPlanner.queryPlan = p.queryPlanner.queryPlan return qp, err } -// getQueryPlan performs several passes on the AST to compute a more manageable -// tree structure of logical operators who closely resemble relational algebra -// operators. -// -// Firstly, getQueryPlan performs simplification to translate the projection -// portion of the select statement to uniform expressions. This means a "*", -// "table.*", or "alias.*" would simply be translated to ColumnRef expressions. -// From here the query is easier to work on as it is one consistent structure. -// -// From here, more simplification is performed. Folding computes constant -// expressions to reduce the complexity of the expression tree. This saves -// instructions ran during a scan. An example of this folding could be the -// binary expression 1 + 1 becoming a constant expression 2. Or a function UPPER -// on a string literal "foo" being simplified to just the string literal "FOO". -// -// Analysis steps are also performed. Such as assigning catalog information to -// ColumnRef expressions. 
This means associating table names with root page -// numbers, column names with their indices within a tuple, and column names -// with their constraints and available indexes. -func (p *selectQueryPlanner) getQueryPlan() (*QueryPlan, error) { +// ExecutionPlan returns the bytecode execution plan for the planner. Calling +// QueryPlan is not a prerequisite to this method as it will be called by +// ExecutionPlan if needed. +func (sp *selectPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { + if sp.queryPlan == nil { + _, err := sp.QueryPlan() + if err != nil { + return nil, err + } + } + sp.setResultHeader() + sp.queryPlan.compile() + sp.executionPlan.Commands = sp.queryPlan.commands + return sp.executionPlan, nil +} + +func (p *selectPlanner) getQueryPlan() (*QueryPlan, error) { err := p.optimizeResultColumns() if err != nil { return nil, err @@ -205,7 +174,7 @@ func (p *selectQueryPlanner) getQueryPlan() (*QueryPlan, error) { return plan, nil } -func (p *selectQueryPlanner) optimizeResultColumns() error { +func (p *selectPlanner) optimizeResultColumns() error { var err error for i := range p.stmt.ResultColumns { if p.stmt.ResultColumns[i].Expression != nil { @@ -283,7 +252,7 @@ func foldExpr(e compiler.Expr) (compiler.Expr, error) { } } -func (p *selectQueryPlanner) getProjections() ([]projection, error) { +func (p *selectPlanner) getProjections() ([]projection, error) { var projections []projection for _, resultColumn := range p.stmt.ResultColumns { if resultColumn.All { @@ -322,27 +291,7 @@ func (p *selectQueryPlanner) getProjections() ([]projection, error) { return projections, nil } -// ExecutionPlan returns the bytecode execution plan for the planner. Calling -// QueryPlan is not a prerequisite to this method as it will be called by -// ExecutionPlan if needed. 
-func (sp *selectPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { - if sp.queryPlanner.queryPlan == nil { - _, err := sp.QueryPlan() - if err != nil { - return nil, err - } - } - return sp.executionPlanner.getExecutionPlan() -} - -func (p *selectExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { - p.setResultHeader() - p.queryPlan.compile() - p.executionPlan.Commands = p.queryPlan.commands - return p.executionPlan, nil -} - -func (p *selectExecutionPlanner) setResultHeader() { +func (p *selectPlanner) setResultHeader() { resultHeader := []string{} switch t := p.queryPlan.root.(type) { case *projectNode: @@ -362,7 +311,7 @@ func (p *selectExecutionPlanner) setResultHeader() { } // setResultTypes attempts to precompute the type for each result column expr. -func (p *selectExecutionPlanner) setResultTypes(exprs []compiler.Expr) error { +func (p *selectPlanner) setResultTypes(exprs []compiler.Expr) error { resolvedTypes := []catalog.CdbType{} for _, expr := range exprs { t, err := getExprType(expr) From b3a0d929b0c53cca775557b2a37acff169df57e8 Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 28 Dec 2025 01:33:46 -0700 Subject: [PATCH 09/17] simplify create and use generator --- planner/create.go | 113 ++++++++++------------------------------- planner/create_test.go | 39 ++++++-------- planner/generator.go | 14 +++++ planner/node.go | 1 + planner/select.go | 40 ++++++--------- 5 files changed, 72 insertions(+), 135 deletions(-) diff --git a/planner/create.go b/planner/create.go index 4ed40bf..8711875 100644 --- a/planner/create.go +++ b/planner/create.go @@ -17,23 +17,8 @@ type createCatalog interface { } // createPlanner is capable of generating a logical query plan and a physical -// executionPlan for a create statement. The planners within are separated by -// their responsibility. +// executionPlan for a create statement. type createPlanner struct { - // queryPlanner is responsible for transforming the AST to a logical query - // plan tree. 
This tree is made up of nodes similar to a relational algebra - // tree. The query planner also performs binding and validation. - queryPlanner *createQueryPlanner - // executionPlanner is responsible for converting the logical query plan - // tree to a bytecode execution plan capable of being run by the virtual - // machine. - executionPlanner *createExecutionPlanner -} - -// createQueryPlanner converts the AST to a logical query plan. Along the way it -// validates the statement makes sense with the catalog a process known as -// binding. -type createQueryPlanner struct { // catalog contains the schema catalog createCatalog // stmt contains the AST @@ -41,14 +26,6 @@ type createQueryPlanner struct { // queryPlan contains the query plan being constructed. The root node must // be createNode. queryPlan *createNode -} - -// createExecutionPlanner converts logical nodes to a bytecode execution plan -// that can be run by the vm. -type createExecutionPlanner struct { - // queryPlan contains the logical query plan. The is populated by calling - // QueryPlan. - queryPlan *createNode // executionPlan contains the bytecode execution plan being constructed. // This is populated by calling ExecutionPlan. executionPlan *vm.ExecutionPlan @@ -57,30 +34,18 @@ type createExecutionPlanner struct { // NewCreate creates a planner for the given create statement. func NewCreate(catalog createCatalog, stmt *compiler.CreateStmt) *createPlanner { return &createPlanner{ - queryPlanner: &createQueryPlanner{ - catalog: catalog, - stmt: stmt, - }, - executionPlanner: &createExecutionPlanner{ - executionPlan: vm.NewExecutionPlan( - catalog.GetVersion(), - stmt.Explain, - ), - }, + catalog: catalog, + stmt: stmt, + executionPlan: vm.NewExecutionPlan( + catalog.GetVersion(), + stmt.Explain, + ), } } // QueryPlan generates the query plan for the planner. 
func (p *createPlanner) QueryPlan() (*QueryPlan, error) { - qp, err := p.queryPlanner.getQueryPlan() - if err != nil { - return nil, err - } - p.executionPlanner.queryPlan = p.queryPlanner.queryPlan - return qp, err -} - -func (p *createQueryPlanner) getQueryPlan() (*QueryPlan, error) { + schemaTableRoot := 1 tableExists := p.catalog.TableExists(p.stmt.TableName) if p.stmt.IfNotExists && tableExists { noopCreateNode := &createNode{ @@ -88,12 +53,14 @@ func (p *createQueryPlanner) getQueryPlan() (*QueryPlan, error) { tableName: p.stmt.TableName, } p.queryPlan = noopCreateNode - return newQueryPlan( + qp := newQueryPlan( noopCreateNode, p.stmt.ExplainQueryPlan, transactionTypeWrite, - 0, - ), nil + schemaTableRoot, + ) + noopCreateNode.plan = qp + return qp, nil } if tableExists { return nil, errTableExists @@ -109,15 +76,17 @@ func (p *createQueryPlanner) getQueryPlan() (*QueryPlan, error) { schema: jSchema, } p.queryPlan = createNode - return newQueryPlan( + qp := newQueryPlan( createNode, p.stmt.ExplainQueryPlan, transactionTypeWrite, - 0, - ), nil + schemaTableRoot, + ) + createNode.plan = qp + return qp, nil } -func (p *createQueryPlanner) getSchemaString() (string, error) { +func (p *createPlanner) getSchemaString() (string, error) { if err := p.ensurePrimaryKeyCount(); err != nil { return "", err } @@ -134,7 +103,7 @@ func (p *createQueryPlanner) getSchemaString() (string, error) { // The id column must be an integer. The index key is capable of being something // other than an integer, but is not worth implementing at the moment. Integer // primary keys are superior for auto incrementing and being unique. -func (p *createQueryPlanner) ensurePrimaryKeyInteger() error { +func (p *createPlanner) ensurePrimaryKeyInteger() error { hasPK := slices.ContainsFunc(p.stmt.ColDefs, func(cd compiler.ColDef) bool { return cd.PrimaryKey }) @@ -151,7 +120,7 @@ func (p *createQueryPlanner) ensurePrimaryKeyInteger() error { } // Only one primary key is supported at this time. 
-func (p *createQueryPlanner) ensurePrimaryKeyCount() error { +func (p *createPlanner) ensurePrimaryKeyCount() error { count := 0 for _, cd := range p.stmt.ColDefs { if cd.PrimaryKey { @@ -164,7 +133,7 @@ func (p *createQueryPlanner) ensurePrimaryKeyCount() error { return nil } -func (p *createQueryPlanner) schemaFrom() *catalog.TableSchema { +func (p *createPlanner) schemaFrom() *catalog.TableSchema { schema := catalog.TableSchema{ Columns: []catalog.TableColumn{}, } @@ -182,43 +151,13 @@ func (p *createQueryPlanner) schemaFrom() *catalog.TableSchema { // QueryPlan is not a prerequisite to this method as it will be called by // ExecutionPlan if needed. func (p *createPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { - if p.queryPlanner.queryPlan == nil { + if p.queryPlan == nil { _, err := p.QueryPlan() if err != nil { return nil, err } } - if p.queryPlanner.queryPlan.noop { - return p.executionPlanner.getNoopExecutionPlan(), nil - } - return p.executionPlanner.getExecutionPlan(), nil -} - -// getNoopExecutionPlan asserts the query can be ran based on the information -// provided by the catalog. If the catalog were to go out of date this execution -// plan will be recompiled before it is ever ran. 
-func (p *createExecutionPlanner) getNoopExecutionPlan() *vm.ExecutionPlan { - p.executionPlan.Append(&vm.InitCmd{P2: 1}) - p.executionPlan.Append(&vm.TransactionCmd{P2: 1}) - p.executionPlan.Append(&vm.HaltCmd{}) - return p.executionPlan -} - -func (p *createExecutionPlanner) getExecutionPlan() *vm.ExecutionPlan { - const cursorId = 1 - p.executionPlan.Append(&vm.InitCmd{P2: 1}) - p.executionPlan.Append(&vm.TransactionCmd{P2: 1}) - p.executionPlan.Append(&vm.CreateBTreeCmd{P2: 1}) - p.executionPlan.Append(&vm.OpenWriteCmd{P1: cursorId, P2: 1}) - p.executionPlan.Append(&vm.NewRowIdCmd{P1: cursorId, P2: 2}) - p.executionPlan.Append(&vm.StringCmd{P1: 3, P4: p.queryPlan.objectType}) - p.executionPlan.Append(&vm.StringCmd{P1: 4, P4: p.queryPlan.objectName}) - p.executionPlan.Append(&vm.StringCmd{P1: 5, P4: p.queryPlan.tableName}) - p.executionPlan.Append(&vm.CopyCmd{P1: 1, P2: 6}) - p.executionPlan.Append(&vm.StringCmd{P1: 7, P4: string(p.queryPlan.schema)}) - p.executionPlan.Append(&vm.MakeRecordCmd{P1: 3, P2: 5, P3: 8}) - p.executionPlan.Append(&vm.InsertCmd{P1: cursorId, P2: 8, P3: 2}) - p.executionPlan.Append(&vm.ParseSchemaCmd{}) - p.executionPlan.Append(&vm.HaltCmd{}) - return p.executionPlan + p.queryPlan.plan.compile() + p.executionPlan.Commands = p.queryPlan.plan.commands + return p.executionPlan, nil } diff --git a/planner/create_test.go b/planner/create_test.go index 88cd493..41937ca 100644 --- a/planner/create_test.go +++ b/planner/create_test.go @@ -2,7 +2,6 @@ package planner import ( "errors" - "reflect" "testing" "github.com/chirst/cdb/catalog" @@ -55,10 +54,8 @@ func TestCreateWithNoIDColumn(t *testing.T) { t.Fatalf("failed to convert expected schema to json %s", err) } expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, + &vm.InitCmd{P2: 12}, &vm.CreateBTreeCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 1}, &vm.NewRowIdCmd{P1: 1, P2: 2}, &vm.StringCmd{P1: 3, P4: "table"}, &vm.StringCmd{P1: 4, P4: "foo"}, @@ -69,16 +66,15 @@ 
func TestCreateWithNoIDColumn(t *testing.T) { &vm.InsertCmd{P1: 1, P2: 8, P3: 2}, &vm.ParseSchemaCmd{}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.OpenWriteCmd{P1: 1, P2: 1}, + &vm.GotoCmd{P2: 1}, } plan, err := NewCreate(mc, stmt).ExecutionPlan() if err != nil { t.Fatal(err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } - } + assertCommandsMatch(t, plan.Commands, expectedCommands) } func TestCreateWithAlternateNamedIDColumn(t *testing.T) { @@ -114,10 +110,8 @@ func TestCreateWithAlternateNamedIDColumn(t *testing.T) { t.Fatalf("failed to convert expected schema to json %s", err) } expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, + &vm.InitCmd{P2: 12}, &vm.CreateBTreeCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 1}, &vm.NewRowIdCmd{P1: 1, P2: 2}, &vm.StringCmd{P1: 3, P4: "table"}, &vm.StringCmd{P1: 4, P4: "foo"}, @@ -128,16 +122,15 @@ func TestCreateWithAlternateNamedIDColumn(t *testing.T) { &vm.InsertCmd{P1: 1, P2: 8, P3: 2}, &vm.ParseSchemaCmd{}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.OpenWriteCmd{P1: 1, P2: 1}, + &vm.GotoCmd{P2: 1}, } plan, err := NewCreate(mc, stmt).ExecutionPlan() if err != nil { t.Fatal(err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } - } + assertCommandsMatch(t, plan.Commands, expectedCommands) } func TestCreatePrimaryKeyWithTextType(t *testing.T) { @@ -215,17 +208,15 @@ func TestCreateIfNotExistsNoop(t *testing.T) { } mc := &mockCreateCatalog{tableExistsRes: true} expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, + &vm.InitCmd{P2: 2}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.OpenWriteCmd{P1: 1, P2: 1}, + &vm.GotoCmd{P2: 1}, } plan, err := NewCreate(mc, stmt).ExecutionPlan() if err != nil { t.Fatal(err) } - for i, c := range expectedCommands { - if 
!reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } - } + assertCommandsMatch(t, plan.Commands, expectedCommands) } diff --git a/planner/generator.go b/planner/generator.go index 8a7f7d8..2bb5f04 100644 --- a/planner/generator.go +++ b/planner/generator.go @@ -299,9 +299,23 @@ func (c *countNode) consume() { } func (c *createNode) produce() { + c.consume() } func (c *createNode) consume() { + if c.noop { + return + } + c.plan.commands = append(c.plan.commands, &vm.CreateBTreeCmd{P2: 1}) + c.plan.commands = append(c.plan.commands, &vm.NewRowIdCmd{P1: c.plan.cursorId, P2: 2}) + c.plan.commands = append(c.plan.commands, &vm.StringCmd{P1: 3, P4: c.objectType}) + c.plan.commands = append(c.plan.commands, &vm.StringCmd{P1: 4, P4: c.objectName}) + c.plan.commands = append(c.plan.commands, &vm.StringCmd{P1: 5, P4: c.tableName}) + c.plan.commands = append(c.plan.commands, &vm.CopyCmd{P1: 1, P2: 6}) + c.plan.commands = append(c.plan.commands, &vm.StringCmd{P1: 7, P4: string(c.schema)}) + c.plan.commands = append(c.plan.commands, &vm.MakeRecordCmd{P1: 3, P2: 5, P3: 8}) + c.plan.commands = append(c.plan.commands, &vm.InsertCmd{P1: c.plan.cursorId, P2: 8, P3: 2}) + c.plan.commands = append(c.plan.commands, &vm.ParseSchemaCmd{}) } func (n *insertNode) produce() { diff --git a/planner/node.go b/planner/node.go index 1627279..35ff8cb 100644 --- a/planner/node.go +++ b/planner/node.go @@ -26,6 +26,7 @@ type joinNode struct { // createNode represents a operation to create an object in the system catalog. // For example a table, index, or trigger. type createNode struct { + plan *QueryPlan // objectName is the name of the index, trigger, or table. objectName string // objectType could be an index, trigger, or in this case a table. 
diff --git a/planner/select.go b/planner/select.go index 7fd39c7..ccf4474 100644 --- a/planner/select.go +++ b/planner/select.go @@ -47,30 +47,6 @@ func NewSelect(catalog selectCatalog, stmt *compiler.SelectStmt) *selectPlanner // QueryPlan generates the query plan tree for the planner. func (p *selectPlanner) QueryPlan() (*QueryPlan, error) { - qp, err := p.getQueryPlan() - if err != nil { - return nil, err - } - return qp, err -} - -// ExecutionPlan returns the bytecode execution plan for the planner. Calling -// QueryPlan is not a prerequisite to this method as it will be called by -// ExecutionPlan if needed. -func (sp *selectPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { - if sp.queryPlan == nil { - _, err := sp.QueryPlan() - if err != nil { - return nil, err - } - } - sp.setResultHeader() - sp.queryPlan.compile() - sp.executionPlan.Commands = sp.queryPlan.commands - return sp.executionPlan, nil -} - -func (p *selectPlanner) getQueryPlan() (*QueryPlan, error) { err := p.optimizeResultColumns() if err != nil { return nil, err @@ -174,6 +150,22 @@ func (p *selectPlanner) getQueryPlan() (*QueryPlan, error) { return plan, nil } +// ExecutionPlan returns the bytecode execution plan for the planner. Calling +// QueryPlan is not a prerequisite to this method as it will be called by +// ExecutionPlan if needed. 
+func (sp *selectPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { + if sp.queryPlan == nil { + _, err := sp.QueryPlan() + if err != nil { + return nil, err + } + } + sp.setResultHeader() + sp.queryPlan.compile() + sp.executionPlan.Commands = sp.queryPlan.commands + return sp.executionPlan, nil +} + func (p *selectPlanner) optimizeResultColumns() error { var err error for i := range p.stmt.ResultColumns { From 67b6daa912fcbf4ff15fc9892f63e349193357d9 Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 28 Dec 2025 02:09:51 -0700 Subject: [PATCH 10/17] begin moving insert to generator --- planner/generator.go | 77 ++++++++++++++++++++ planner/insert.go | 156 ++++------------------------------------- planner/insert_test.go | 74 ++++++++----------- planner/node.go | 1 + 4 files changed, 122 insertions(+), 186 deletions(-) diff --git a/planner/generator.go b/planner/generator.go index 2bb5f04..783e1d8 100644 --- a/planner/generator.go +++ b/planner/generator.go @@ -319,9 +319,86 @@ func (c *createNode) consume() { } func (n *insertNode) produce() { + n.consume() } func (n *insertNode) consume() { + // TODO should lift anything involving catalog into the logical planning. + for valueIdx := range len(n.colValues) { + // For simplicity, the primary key is in the first register. + const keyRegister = 1 + n.buildPrimaryKey(n.plan.cursorId, keyRegister, valueIdx) + registerIdx := keyRegister + for _, catalogColumnName := range n.catalogColumnNames { + if catalogColumnName != "" && catalogColumnName == n.pkColumn { + // Skip the primary key column since it is handled before. 
+ continue + } + registerIdx += 1 + n.buildNonPkValue(valueIdx, registerIdx, catalogColumnName) + } + n.plan.commands = append( + n.plan.commands, + &vm.MakeRecordCmd{P1: 2, P2: registerIdx - 1, P3: registerIdx + 1}, + ) + n.plan.commands = append( + n.plan.commands, + &vm.InsertCmd{P1: n.plan.cursorId, P2: registerIdx + 1, P3: keyRegister}, + ) + } +} + +func (p *insertNode) buildPrimaryKey(writeCursorId, keyRegister, valueIdx int) { + // If the table has a user defined pk column it needs to be looked up in the + // user defined column list. If the user has defined the pk column the + // execution plan will involve checking the uniqueness of the pk during + // execution. Otherwise the system guarantees a unique key. + statementPkIdx := -1 + if p.pkColumn != "" { + statementPkIdx = slices.IndexFunc(p.colNames, func(s string) bool { + return s == p.pkColumn + }) + } + if statementPkIdx == -1 { + p.plan.commands = append(p.plan.commands, &vm.NewRowIdCmd{P1: writeCursorId, P2: keyRegister}) + return + } + switch rv := p.colValues[valueIdx][statementPkIdx].(type) { + case *compiler.IntLit: + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: rv.Value, P2: keyRegister}) + case *compiler.Variable: + // TODO must be int could likely be used more to enforce schema types. + p.plan.commands = append(p.plan.commands, &vm.VariableCmd{P1: rv.Position, P2: keyRegister}) + p.plan.commands = append(p.plan.commands, &vm.MustBeIntCmd{P1: keyRegister}) + default: + panic("unsupported row id value") + } + continueIdx := len(p.plan.commands) + 2 + p.plan.commands = append(p.plan.commands, &vm.NotExistsCmd{P1: writeCursorId, P2: continueIdx, P3: keyRegister}) + p.plan.commands = append(p.plan.commands, &vm.HaltCmd{P1: 1, P4: pkConstraint}) +} + +func (p *insertNode) buildNonPkValue(valueIdx, registerIdx int, catalogColumnName string) { + // Get the statement index of the column name. Because the name positions + // can mismatch the table column positions. 
+ stmtColIdx := slices.IndexFunc(p.colNames, func(stmtColName string) bool { + return stmtColName == catalogColumnName + }) + // Requires the statement to define a value for each column in the table. + // TODO lift this check up into logical planner + // if stmtColIdx == -1 { + // return fmt.Errorf("%w %s", errMissingColumnName, catalogColumnName) + // } + switch cv := p.colValues[valueIdx][stmtColIdx].(type) { + case *compiler.StringLit: + p.plan.commands = append(p.plan.commands, &vm.StringCmd{P1: registerIdx, P4: cv.Value}) + case *compiler.IntLit: + p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: cv.Value, P2: registerIdx}) + case *compiler.Variable: + p.plan.commands = append(p.plan.commands, &vm.VariableCmd{P1: cv.Position, P2: registerIdx}) + default: + panic("unsupported type of value") + } } func (n *joinNode) produce() { diff --git a/planner/insert.go b/planner/insert.go index 2278b59..6119526 100644 --- a/planner/insert.go +++ b/planner/insert.go @@ -1,10 +1,6 @@ package planner import ( - "errors" - "fmt" - "slices" - "github.com/chirst/cdb/compiler" "github.com/chirst/cdb/vm" ) @@ -24,19 +20,6 @@ type insertCatalog interface { // insertPlanner consists of planners capable of generating a logical query plan // tree and bytecode execution plan for a insert statement. type insertPlanner struct { - // The query planner generates a logical query plan tree made up of nodes - // similar to relational algebra operators. The query planner performs - // validation while building the tree. Otherwise known as binding. - queryPlanner *insertQueryPlanner - // The executionPlanner transforms the logical query plan tree to a bytecode - // execution plan that can be ran by the virtual machine. - executionPlanner *insertExecutionPlanner -} - -// insertQueryPlanner converts the AST generated by the compiler to a logical -// query plan tree. It is also responsible for validating the AST against the -// system catalog. 
-type insertQueryPlanner struct { // catalog contains the schema. catalog insertCatalog // stmt contains the AST. @@ -44,14 +27,6 @@ type insertQueryPlanner struct { // queryPlan contains the query plan being constructed. For an insert, the // root node must be an insertNode. queryPlan *insertNode -} - -// insertExecutionPlanner converts the logical query plan to a bytecode routine -// to be ran by the vm. -type insertExecutionPlanner struct { - // queryPlan contains the query plan generated by the query planner's - // QueryPlan method. - queryPlan *insertNode // executionPlan contains the execution plan generated by calling // ExecutionPlan. executionPlan *vm.ExecutionPlan @@ -60,30 +35,17 @@ type insertExecutionPlanner struct { // NewInsert returns an instance of an insert planner for the given AST. func NewInsert(catalog insertCatalog, stmt *compiler.InsertStmt) *insertPlanner { return &insertPlanner{ - queryPlanner: &insertQueryPlanner{ - catalog: catalog, - stmt: stmt, - }, - executionPlanner: &insertExecutionPlanner{ - executionPlan: vm.NewExecutionPlan( - catalog.GetVersion(), - stmt.Explain, - ), - }, + catalog: catalog, + stmt: stmt, + executionPlan: vm.NewExecutionPlan( + catalog.GetVersion(), + stmt.Explain, + ), } } // QueryPlan generates the query plan tree for the planner. 
func (p *insertPlanner) QueryPlan() (*QueryPlan, error) { - qp, err := p.queryPlanner.getQueryPlan() - if err != nil { - return nil, err - } - p.executionPlanner.queryPlan = p.queryPlanner.queryPlan - return qp, err -} - -func (p *insertQueryPlanner) getQueryPlan() (*QueryPlan, error) { rootPage, err := p.catalog.GetRootPageNumber(p.stmt.TableName) if err != nil { return nil, errTableNotExist @@ -92,7 +54,7 @@ func (p *insertQueryPlanner) getQueryPlan() (*QueryPlan, error) { if err != nil { return nil, err } - if err := checkValuesMatchColumns(p.stmt); err != nil { + if err := p.checkValuesMatchColumns(p.stmt); err != nil { return nil, err } pkColumn, err := p.catalog.GetPrimaryKeyColumn(p.stmt.TableName) @@ -107,120 +69,32 @@ func (p *insertQueryPlanner) getQueryPlan() (*QueryPlan, error) { colValues: p.stmt.ColValues, } p.queryPlan = insertNode - return newQueryPlan( + qp := newQueryPlan( insertNode, p.stmt.ExplainQueryPlan, transactionTypeWrite, rootPage, - ), nil + ) + insertNode.plan = qp + return qp, nil } // ExecutionPlan returns the bytecode routine for the planner. Calling QueryPlan // is not prerequisite to calling ExecutionPlan as ExecutionPlan will be called // as needed. func (p *insertPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { - if p.queryPlanner.queryPlan == nil { + if p.queryPlan == nil { _, err := p.QueryPlan() if err != nil { return nil, err } } - return p.executionPlanner.getExecutionPlan() -} - -func (p *insertExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { - p.buildInit() - cursorId := p.openWrite() - for valueIdx := range len(p.queryPlan.colValues) { - // For simplicity, the primary key is in the first register. 
- const keyRegister = 1 - if err := p.buildPrimaryKey(cursorId, keyRegister, valueIdx); err != nil { - return nil, err - } - registerIdx := keyRegister - for _, catalogColumnName := range p.queryPlan.catalogColumnNames { - if catalogColumnName != "" && catalogColumnName == p.queryPlan.pkColumn { - // Skip the primary key column since it is handled before. - continue - } - registerIdx += 1 - if err := p.buildNonPkValue(valueIdx, registerIdx, catalogColumnName); err != nil { - return nil, err - } - } - p.executionPlan.Append(&vm.MakeRecordCmd{P1: 2, P2: registerIdx - 1, P3: registerIdx + 1}) - p.executionPlan.Append(&vm.InsertCmd{P1: cursorId, P2: registerIdx + 1, P3: keyRegister}) - } - p.executionPlan.Append(&vm.HaltCmd{}) + p.queryPlan.plan.compile() + p.executionPlan.Commands = p.queryPlan.plan.commands return p.executionPlan, nil } -func (p *insertExecutionPlanner) buildInit() { - p.executionPlan.Append(&vm.InitCmd{P2: 1}) - p.executionPlan.Append(&vm.TransactionCmd{P2: 1}) -} - -func (p *insertExecutionPlanner) openWrite() int { - const cursorId = 1 - p.executionPlan.Append(&vm.OpenWriteCmd{P1: cursorId, P2: p.queryPlan.rootPage}) - return cursorId -} - -func (p *insertExecutionPlanner) buildPrimaryKey(writeCursorId, keyRegister, valueIdx int) error { - // If the table has a user defined pk column it needs to be looked up in the - // user defined column list. If the user has defined the pk column the - // execution plan will involve checking the uniqueness of the pk during - // execution. Otherwise the system guarantees a unique key. 
- statementPkIdx := -1 - if p.queryPlan.pkColumn != "" { - statementPkIdx = slices.IndexFunc(p.queryPlan.colNames, func(s string) bool { - return s == p.queryPlan.pkColumn - }) - } - if statementPkIdx == -1 { - p.executionPlan.Append(&vm.NewRowIdCmd{P1: writeCursorId, P2: keyRegister}) - return nil - } - switch rv := p.queryPlan.colValues[valueIdx][statementPkIdx].(type) { - case *compiler.IntLit: - p.executionPlan.Append(&vm.IntegerCmd{P1: rv.Value, P2: keyRegister}) - case *compiler.Variable: - // TODO must be int could likely be used more to enforce schema types. - p.executionPlan.Append(&vm.VariableCmd{P1: rv.Position, P2: keyRegister}) - p.executionPlan.Append(&vm.MustBeIntCmd{P1: keyRegister}) - default: - return errors.New("unsupported row id value") - } - continueIdx := len(p.executionPlan.Commands) + 2 - p.executionPlan.Append(&vm.NotExistsCmd{P1: writeCursorId, P2: continueIdx, P3: keyRegister}) - p.executionPlan.Append(&vm.HaltCmd{P1: 1, P4: pkConstraint}) - return nil -} - -func (p *insertExecutionPlanner) buildNonPkValue(valueIdx, registerIdx int, catalogColumnName string) error { - // Get the statement index of the column name. Because the name positions - // can mismatch the table column positions. - stmtColIdx := slices.IndexFunc(p.queryPlan.colNames, func(stmtColName string) bool { - return stmtColName == catalogColumnName - }) - // Requires the statement to define a value for each column in the table. 
- if stmtColIdx == -1 { - return fmt.Errorf("%w %s", errMissingColumnName, catalogColumnName) - } - switch cv := p.queryPlan.colValues[valueIdx][stmtColIdx].(type) { - case *compiler.StringLit: - p.executionPlan.Append(&vm.StringCmd{P1: registerIdx, P4: cv.Value}) - case *compiler.IntLit: - p.executionPlan.Append(&vm.IntegerCmd{P1: cv.Value, P2: registerIdx}) - case *compiler.Variable: - p.executionPlan.Append(&vm.VariableCmd{P1: cv.Position, P2: registerIdx}) - default: - return errors.New("unsupported type of value") - } - return nil -} - -func checkValuesMatchColumns(s *compiler.InsertStmt) error { +func (p *insertPlanner) checkValuesMatchColumns(s *compiler.InsertStmt) error { cl := len(s.ColNames) for _, cv := range s.ColValues { if cl != len(cv) { diff --git a/planner/insert_test.go b/planner/insert_test.go index 8e4cb10..e9acebf 100644 --- a/planner/insert_test.go +++ b/planner/insert_test.go @@ -2,7 +2,6 @@ package planner import ( "errors" - "reflect" "testing" "github.com/chirst/cdb/compiler" @@ -38,9 +37,7 @@ func (m *mockInsertCatalog) GetPrimaryKeyColumn(tableName string) (string, error func TestInsertWithoutPrimaryKey(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.InitCmd{P2: 17}, &vm.NewRowIdCmd{P1: 1, P2: 1}, &vm.StringCmd{P1: 2, P4: "gud"}, &vm.StringCmd{P1: 3, P4: "dude"}, @@ -57,6 +54,9 @@ func TestInsertWithoutPrimaryKey(t *testing.T) { &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 4}, &vm.InsertCmd{P1: 1, P2: 4, P3: 1}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ @@ -87,25 +87,22 @@ func TestInsertWithoutPrimaryKey(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } - } + assertCommandsMatch(t, plan.Commands, 
expectedCommands) } func TestInsertWithPrimaryKey(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.InitCmd{P2: 8}, &vm.IntegerCmd{P1: 22, P2: 1}, - &vm.NotExistsCmd{P1: 1, P2: 6, P3: 1}, + &vm.NotExistsCmd{P1: 1, P2: 4, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, &vm.StringCmd{P1: 2, P4: "gud"}, &vm.MakeRecordCmd{P1: 2, P2: 1, P3: 3}, &vm.InsertCmd{P1: 1, P2: 3, P3: 1}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ StmtBase: &compiler.StmtBase{}, @@ -129,25 +126,22 @@ func TestInsertWithPrimaryKey(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } - } + assertCommandsMatch(t, plan.Commands, expectedCommands) } func TestInsertWithPrimaryKeyMiddleOrder(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.InitCmd{P2: 8}, &vm.IntegerCmd{P1: 12, P2: 1}, - &vm.NotExistsCmd{P1: 1, P2: 6, P3: 1}, + &vm.NotExistsCmd{P1: 1, P2: 4, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, &vm.StringCmd{P1: 2, P4: "feller"}, &vm.MakeRecordCmd{P1: 2, P2: 1, P3: 3}, &vm.InsertCmd{P1: 1, P2: 3, P3: 1}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ StmtBase: &compiler.StmtBase{}, @@ -171,26 +165,23 @@ func TestInsertWithPrimaryKeyMiddleOrder(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } - } + assertCommandsMatch(t, plan.Commands, expectedCommands) } func 
TestInsertWithPrimaryKeyParameter(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.InitCmd{P2: 9}, &vm.VariableCmd{P1: 0, P2: 1}, &vm.MustBeIntCmd{P1: 1}, - &vm.NotExistsCmd{P1: 1, P2: 7, P3: 1}, + &vm.NotExistsCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, &vm.StringCmd{P1: 2, P4: "feller"}, &vm.MakeRecordCmd{P1: 2, P2: 1, P3: 3}, &vm.InsertCmd{P1: 1, P2: 3, P3: 1}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ StmtBase: &compiler.StmtBase{}, @@ -214,26 +205,23 @@ func TestInsertWithPrimaryKeyParameter(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } - } + assertCommandsMatch(t, plan.Commands, expectedCommands) } func TestInsertWithParameter(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 1}, - &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.InitCmd{P2: 9}, &vm.VariableCmd{P1: 0, P2: 1}, &vm.MustBeIntCmd{P1: 1}, - &vm.NotExistsCmd{P1: 1, P2: 7, P3: 1}, + &vm.NotExistsCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, &vm.VariableCmd{P1: 1, P2: 2}, &vm.MakeRecordCmd{P1: 2, P2: 1, P3: 3}, &vm.InsertCmd{P1: 1, P2: 3, P3: 1}, &vm.HaltCmd{}, + &vm.TransactionCmd{P2: 1}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ StmtBase: &compiler.StmtBase{}, @@ -257,11 +245,7 @@ func TestInsertWithParameter(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - for i, c := range expectedCommands { - if !reflect.DeepEqual(c, plan.Commands[i]) { - t.Errorf("got %#v want %#v", plan.Commands[i], c) - } - } + assertCommandsMatch(t, plan.Commands, expectedCommands) } func 
TestInsertIntoNonExistingTable(t *testing.T) { diff --git a/planner/node.go b/planner/node.go index 35ff8cb..fe2e8d0 100644 --- a/planner/node.go +++ b/planner/node.go @@ -48,6 +48,7 @@ type createNode struct { // insertNode represents an insert operation. type insertNode struct { + plan *QueryPlan // rootPage is the rootPage of the table the insert is performed on. rootPage int // catalogColumnNames are all of the names of columns associated with the From d33de904237e5ed82e8b7d551fec5812397ee148 Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 28 Dec 2025 11:43:50 -0700 Subject: [PATCH 11/17] split insert planning and generation --- planner/generator.go | 122 ++++++++++++++++------------------------- planner/insert.go | 70 +++++++++++++++++++---- planner/insert_test.go | 84 ++++++++++++++++------------ planner/node.go | 20 +++---- 4 files changed, 168 insertions(+), 128 deletions(-) diff --git a/planner/generator.go b/planner/generator.go index 783e1d8..73f5d8e 100644 --- a/planner/generator.go +++ b/planner/generator.go @@ -114,10 +114,10 @@ func (p *QueryPlan) pushConstantInts() { func (p *QueryPlan) pushConstantStrings() { temp := []*vm.StringCmd{} for v := range p.constStrings { - p.commands = append(p.commands, &vm.StringCmd{P1: p.constStrings[v], P4: v}) + temp = append(temp, &vm.StringCmd{P1: p.constStrings[v], P4: v}) } slices.SortFunc(temp, func(a, b *vm.StringCmd) int { - return a.P2 - b.P2 + return a.P1 - b.P1 }) for i := range temp { p.commands = append(p.commands, temp[i]) @@ -127,7 +127,7 @@ func (p *QueryPlan) pushConstantStrings() { func (p *QueryPlan) pushConstantVars() { temp := []*vm.VariableCmd{} for v := range p.constVars { - p.commands = append(p.commands, &vm.VariableCmd{P1: v, P2: p.constVars[v]}) + temp = append(temp, &vm.VariableCmd{P1: v, P2: p.constVars[v]}) } slices.SortFunc(temp, func(a, b *vm.VariableCmd) int { return a.P2 - b.P2 @@ -323,81 +323,55 @@ func (n *insertNode) produce() { } func (n *insertNode) consume() { - // TODO 
should lift anything involving catalog into the logical planning. - for valueIdx := range len(n.colValues) { - // For simplicity, the primary key is in the first register. - const keyRegister = 1 - n.buildPrimaryKey(n.plan.cursorId, keyRegister, valueIdx) - registerIdx := keyRegister - for _, catalogColumnName := range n.catalogColumnNames { - if catalogColumnName != "" && catalogColumnName == n.pkColumn { - // Skip the primary key column since it is handled before. - continue + for valuesIdx := range len(n.colValues) { + // Setup rowid and it's uniqueness/type checks + pkRegister := n.plan.freeRegister + n.plan.freeRegister += 1 + if n.autoPk { + n.plan.commands = append(n.plan.commands, &vm.NewRowIdCmd{ + P1: n.plan.cursorId, + P2: pkRegister, + }) + } else { + generateExpressionTo(n.plan, n.pkValues[valuesIdx], pkRegister) + n.plan.commands = append(n.plan.commands, &vm.MustBeIntCmd{P1: pkRegister}) + nec := &vm.NotExistsCmd{ + P1: n.plan.cursorId, + P3: pkRegister, } - registerIdx += 1 - n.buildNonPkValue(valueIdx, registerIdx, catalogColumnName) + n.plan.commands = append(n.plan.commands, nec) + n.plan.commands = append(n.plan.commands, &vm.HaltCmd{ + P1: 1, + P4: pkConstraint, + }) + nec.P2 = len(n.plan.commands) } - n.plan.commands = append( - n.plan.commands, - &vm.MakeRecordCmd{P1: 2, P2: registerIdx - 1, P3: registerIdx + 1}, - ) - n.plan.commands = append( - n.plan.commands, - &vm.InsertCmd{P1: n.plan.cursorId, P2: registerIdx + 1, P3: keyRegister}, - ) - } -} -func (p *insertNode) buildPrimaryKey(writeCursorId, keyRegister, valueIdx int) { - // If the table has a user defined pk column it needs to be looked up in the - // user defined column list. If the user has defined the pk column the - // execution plan will involve checking the uniqueness of the pk during - // execution. Otherwise the system guarantees a unique key. 
- statementPkIdx := -1 - if p.pkColumn != "" { - statementPkIdx = slices.IndexFunc(p.colNames, func(s string) bool { - return s == p.pkColumn - }) - } - if statementPkIdx == -1 { - p.plan.commands = append(p.plan.commands, &vm.NewRowIdCmd{P1: writeCursorId, P2: keyRegister}) - return - } - switch rv := p.colValues[valueIdx][statementPkIdx].(type) { - case *compiler.IntLit: - p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: rv.Value, P2: keyRegister}) - case *compiler.Variable: - // TODO must be int could likely be used more to enforce schema types. - p.plan.commands = append(p.plan.commands, &vm.VariableCmd{P1: rv.Position, P2: keyRegister}) - p.plan.commands = append(p.plan.commands, &vm.MustBeIntCmd{P1: keyRegister}) - default: - panic("unsupported row id value") - } - continueIdx := len(p.plan.commands) + 2 - p.plan.commands = append(p.plan.commands, &vm.NotExistsCmd{P1: writeCursorId, P2: continueIdx, P3: keyRegister}) - p.plan.commands = append(p.plan.commands, &vm.HaltCmd{P1: 1, P4: pkConstraint}) -} + // Reserve registers and make values segment for MakeRecord + startRegister := n.plan.freeRegister + reservedRegisters := len(n.colValues[valuesIdx]) + n.plan.freeRegister += reservedRegisters + for vi := range n.colValues[valuesIdx] { + generateExpressionTo( + n.plan, + n.colValues[valuesIdx][vi], + startRegister+vi, + ) + } -func (p *insertNode) buildNonPkValue(valueIdx, registerIdx int, catalogColumnName string) { - // Get the statement index of the column name. Because the name positions - // can mismatch the table column positions. - stmtColIdx := slices.IndexFunc(p.colNames, func(stmtColName string) bool { - return stmtColName == catalogColumnName - }) - // Requires the statement to define a value for each column in the table. 
- // TODO lift this check up into logical planner - // if stmtColIdx == -1 { - // return fmt.Errorf("%w %s", errMissingColumnName, catalogColumnName) - // } - switch cv := p.colValues[valueIdx][stmtColIdx].(type) { - case *compiler.StringLit: - p.plan.commands = append(p.plan.commands, &vm.StringCmd{P1: registerIdx, P4: cv.Value}) - case *compiler.IntLit: - p.plan.commands = append(p.plan.commands, &vm.IntegerCmd{P1: cv.Value, P2: registerIdx}) - case *compiler.Variable: - p.plan.commands = append(p.plan.commands, &vm.VariableCmd{P1: cv.Position, P2: registerIdx}) - default: - panic("unsupported type of value") + // Insert + n.plan.commands = append(n.plan.commands, &vm.MakeRecordCmd{ + P1: startRegister, + P2: reservedRegisters, + P3: n.plan.freeRegister, + }) + recordRegister := n.plan.freeRegister + n.plan.freeRegister += 1 + n.plan.commands = append(n.plan.commands, &vm.InsertCmd{ + P1: n.plan.cursorId, + P2: recordRegister, + P3: pkRegister, + }) } } diff --git a/planner/insert.go b/planner/insert.go index 6119526..4ac7ce7 100644 --- a/planner/insert.go +++ b/planner/insert.go @@ -1,6 +1,9 @@ package planner import ( + "fmt" + "slices" + "github.com/chirst/cdb/compiler" "github.com/chirst/cdb/vm" ) @@ -50,23 +53,18 @@ func (p *insertPlanner) QueryPlan() (*QueryPlan, error) { if err != nil { return nil, errTableNotExist } - catalogColumnNames, err := p.catalog.GetColumns(p.stmt.TableName) - if err != nil { - return nil, err - } if err := p.checkValuesMatchColumns(p.stmt); err != nil { return nil, err } - pkColumn, err := p.catalog.GetPrimaryKeyColumn(p.stmt.TableName) + colValues, err := p.getNonPkValues() if err != nil { return nil, err } insertNode := &insertNode{ - rootPage: rootPage, - catalogColumnNames: catalogColumnNames, - pkColumn: pkColumn, - colNames: p.stmt.ColNames, - colValues: p.stmt.ColValues, + colValues: colValues, + } + if err := p.setPkValues(insertNode); err != nil { + return nil, err } p.queryPlan = insertNode qp := newQueryPlan( @@ -79,6 
+77,58 @@ func (p *insertPlanner) QueryPlan() (*QueryPlan, error) { return qp, nil } +func (p *insertPlanner) setPkValues(n *insertNode) error { + pkColumnName, err := p.catalog.GetPrimaryKeyColumn(p.stmt.TableName) + if err != nil { + return err + } + statementPkIdx := -1 + if pkColumnName != "" { + statementPkIdx = slices.IndexFunc(p.stmt.ColNames, func(s string) bool { + return s == pkColumnName + }) + } + if statementPkIdx == -1 { + n.autoPk = true + } else { + n.autoPk = false + n.pkValues = []compiler.Expr{} + for _, v := range p.stmt.ColValues { + n.pkValues = append(n.pkValues, v[statementPkIdx]) + } + } + return nil +} + +func (p *insertPlanner) getNonPkValues() ([][]compiler.Expr, error) { + pkColumnName, err := p.catalog.GetPrimaryKeyColumn(p.stmt.TableName) + if err != nil { + return nil, err + } + catalogColumnNames, err := p.catalog.GetColumns(p.stmt.TableName) + if err != nil { + return nil, err + } + resultValues := [][]compiler.Expr{} + for _, colValue := range p.stmt.ColValues { + resultValue := []compiler.Expr{} + for _, cn := range catalogColumnNames { + if cn == pkColumnName { + continue + } + stmtColIdx := slices.IndexFunc(p.stmt.ColNames, func(stmtColName string) bool { + return stmtColName == cn + }) + if stmtColIdx == -1 { + return nil, fmt.Errorf("%w %s", errMissingColumnName, cn) + } + resultValue = append(resultValue, colValue[stmtColIdx]) + } + resultValues = append(resultValues, resultValue) + } + return resultValues, nil +} + // ExecutionPlan returns the bytecode routine for the planner. Calling QueryPlan // is not prerequisite to calling ExecutionPlan as ExecutionPlan will be called // as needed. 
diff --git a/planner/insert_test.go b/planner/insert_test.go index e9acebf..273c210 100644 --- a/planner/insert_test.go +++ b/planner/insert_test.go @@ -39,23 +39,29 @@ func TestInsertWithoutPrimaryKey(t *testing.T) { expectedCommands := []vm.Command{ &vm.InitCmd{P2: 17}, &vm.NewRowIdCmd{P1: 1, P2: 1}, - &vm.StringCmd{P1: 2, P4: "gud"}, - &vm.StringCmd{P1: 3, P4: "dude"}, - &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 4}, - &vm.InsertCmd{P1: 1, P2: 4, P3: 1}, - &vm.NewRowIdCmd{P1: 1, P2: 1}, - &vm.StringCmd{P1: 2, P4: "joe"}, - &vm.StringCmd{P1: 3, P4: "doe"}, - &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 4}, - &vm.InsertCmd{P1: 1, P2: 4, P3: 1}, - &vm.NewRowIdCmd{P1: 1, P2: 1}, - &vm.StringCmd{P1: 2, P4: "jan"}, - &vm.StringCmd{P1: 3, P4: "ice"}, - &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 4}, - &vm.InsertCmd{P1: 1, P2: 4, P3: 1}, + &vm.CopyCmd{P1: 4, P2: 2}, + &vm.CopyCmd{P1: 5, P2: 3}, + &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 6}, + &vm.InsertCmd{P1: 1, P2: 6, P3: 1}, + &vm.NewRowIdCmd{P1: 1, P2: 7}, + &vm.CopyCmd{P1: 10, P2: 8}, + &vm.CopyCmd{P1: 11, P2: 9}, + &vm.MakeRecordCmd{P1: 8, P2: 2, P3: 12}, + &vm.InsertCmd{P1: 1, P2: 12, P3: 7}, + &vm.NewRowIdCmd{P1: 1, P2: 13}, + &vm.CopyCmd{P1: 16, P2: 14}, + &vm.CopyCmd{P1: 17, P2: 15}, + &vm.MakeRecordCmd{P1: 14, P2: 2, P3: 18}, + &vm.InsertCmd{P1: 1, P2: 18, P3: 13}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.StringCmd{P1: 4, P4: "gud"}, + &vm.StringCmd{P1: 5, P4: "dude"}, + &vm.StringCmd{P1: 10, P4: "joe"}, + &vm.StringCmd{P1: 11, P4: "doe"}, + &vm.StringCmd{P1: 16, P4: "jan"}, + &vm.StringCmd{P1: 17, P4: "ice"}, &vm.GotoCmd{P2: 1}, } @@ -92,16 +98,19 @@ func TestInsertWithoutPrimaryKey(t *testing.T) { func TestInsertWithPrimaryKey(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 8}, - &vm.IntegerCmd{P1: 22, P2: 1}, - &vm.NotExistsCmd{P1: 1, P2: 4, P3: 1}, + &vm.InitCmd{P2: 9}, + &vm.CopyCmd{P1: 2, P2: 1}, + &vm.MustBeIntCmd{P1: 1}, + &vm.NotExistsCmd{P1: 1, P2: 5, P3: 1}, 
&vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, - &vm.StringCmd{P1: 2, P4: "gud"}, - &vm.MakeRecordCmd{P1: 2, P2: 1, P3: 3}, - &vm.InsertCmd{P1: 1, P2: 3, P3: 1}, + &vm.CopyCmd{P1: 4, P2: 3}, + &vm.MakeRecordCmd{P1: 3, P2: 1, P3: 5}, + &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.IntegerCmd{P1: 22, P2: 2}, + &vm.StringCmd{P1: 4, P4: "gud"}, &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ @@ -131,16 +140,19 @@ func TestInsertWithPrimaryKey(t *testing.T) { func TestInsertWithPrimaryKeyMiddleOrder(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 8}, - &vm.IntegerCmd{P1: 12, P2: 1}, - &vm.NotExistsCmd{P1: 1, P2: 4, P3: 1}, + &vm.InitCmd{P2: 9}, + &vm.CopyCmd{P1: 2, P2: 1}, + &vm.MustBeIntCmd{P1: 1}, + &vm.NotExistsCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, - &vm.StringCmd{P1: 2, P4: "feller"}, - &vm.MakeRecordCmd{P1: 2, P2: 1, P3: 3}, - &vm.InsertCmd{P1: 1, P2: 3, P3: 1}, + &vm.CopyCmd{P1: 4, P2: 3}, + &vm.MakeRecordCmd{P1: 3, P2: 1, P3: 5}, + &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.IntegerCmd{P1: 12, P2: 2}, + &vm.StringCmd{P1: 4, P4: "feller"}, &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ @@ -171,16 +183,18 @@ func TestInsertWithPrimaryKeyMiddleOrder(t *testing.T) { func TestInsertWithPrimaryKeyParameter(t *testing.T) { expectedCommands := []vm.Command{ &vm.InitCmd{P2: 9}, - &vm.VariableCmd{P1: 0, P2: 1}, + &vm.CopyCmd{P1: 2, P2: 1}, &vm.MustBeIntCmd{P1: 1}, &vm.NotExistsCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, - &vm.StringCmd{P1: 2, P4: "feller"}, - &vm.MakeRecordCmd{P1: 2, P2: 1, P3: 3}, - &vm.InsertCmd{P1: 1, P2: 3, P3: 1}, + &vm.CopyCmd{P1: 4, P2: 3}, + &vm.MakeRecordCmd{P1: 3, P2: 1, P3: 5}, + &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, &vm.OpenWriteCmd{P1: 1, P2: 2}, 
+ &vm.StringCmd{P1: 4, P4: "feller"}, + &vm.VariableCmd{P1: 0, P2: 2}, &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ @@ -211,16 +225,18 @@ func TestInsertWithPrimaryKeyParameter(t *testing.T) { func TestInsertWithParameter(t *testing.T) { expectedCommands := []vm.Command{ &vm.InitCmd{P2: 9}, - &vm.VariableCmd{P1: 0, P2: 1}, + &vm.CopyCmd{P1: 2, P2: 1}, &vm.MustBeIntCmd{P1: 1}, &vm.NotExistsCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, - &vm.VariableCmd{P1: 1, P2: 2}, - &vm.MakeRecordCmd{P1: 2, P2: 1, P3: 3}, - &vm.InsertCmd{P1: 1, P2: 3, P3: 1}, + &vm.CopyCmd{P1: 4, P2: 3}, + &vm.MakeRecordCmd{P1: 3, P2: 1, P3: 5}, + &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.VariableCmd{P1: 0, P2: 2}, + &vm.VariableCmd{P1: 1, P2: 4}, &vm.GotoCmd{P2: 1}, } ast := &compiler.InsertStmt{ diff --git a/planner/node.go b/planner/node.go index fe2e8d0..0492fda 100644 --- a/planner/node.go +++ b/planner/node.go @@ -49,17 +49,17 @@ type createNode struct { // insertNode represents an insert operation. type insertNode struct { plan *QueryPlan - // rootPage is the rootPage of the table the insert is performed on. - rootPage int - // catalogColumnNames are all of the names of columns associated with the - // table. - catalogColumnNames []string - // pkColumn is the name of the primary key column in the catalog. The value - // is empty if no user defined pk. - pkColumn string - // colNames are the names of columns specified in the insert statement. - colNames []string // colValues are the values specified in the insert statement. It is two // dimensional i.e. VALUES (v1, v2), (v3, v4) is [[v1, v2], [v3, v4]]. + // + // The logical planner must guarantee these values are in the correct + // ordinal position as the code generator will not check. colValues [][]compiler.Expr + // pkValues holds the pk expression separate from colValues for each values + // entry. 
In case a pkValue wasn't specified in the values list a reasonable + // value will be provided for the code generator or the autoPk will be true. + pkValues []compiler.Expr + // autoPk indicates the generator should use a NewRowIdCmd for pk + // generation. + autoPk bool } From b4673512b9d210ac7a04ba08b6b6a83ceb0695c2 Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 28 Dec 2025 11:56:57 -0700 Subject: [PATCH 12/17] move helper --- planner/assert_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++ planner/create_test.go | 12 +++++++++--- planner/insert_test.go | 20 ++++++++++++++----- planner/select_test.go | 4 +++- planner/update_test.go | 32 ++++++------------------------ 5 files changed, 77 insertions(+), 35 deletions(-) create mode 100644 planner/assert_test.go diff --git a/planner/assert_test.go b/planner/assert_test.go new file mode 100644 index 0000000..9652c65 --- /dev/null +++ b/planner/assert_test.go @@ -0,0 +1,44 @@ +package planner + +import ( + "errors" + "fmt" + "reflect" + + "github.com/chirst/cdb/vm" +) + +// assertCommandsMatch is a helper for tests in the planner package. +func assertCommandsMatch(gotCommands, expectedCommands []vm.Command) error { + didMatch := true + errOutput := "\n" + green := "\033[32m" + red := "\033[31m" + resetColor := "\033[0m" + for i, c := range expectedCommands { + color := green + if !reflect.DeepEqual(c, gotCommands[i]) { + didMatch = false + color = red + } + errOutput += fmt.Sprintf( + "%s%3d got %#v%s\n want %#v\n\n", + color, i, gotCommands[i], resetColor, c, + ) + } + gl := len(gotCommands) + wl := len(expectedCommands) + if gl != wl { + errOutput += red + errOutput += fmt.Sprintf("got %d want %d commands\n", gl, wl) + errOutput += resetColor + didMatch = false + } + // This helper returns an error instead of making the assertion so a fatal + // error will raise at the test site instead of the helper. This also allows + // the caller to differentiate between a fatal or non fatal assertion. 
+ if !didMatch { + return errors.New(errOutput) + } + return nil +} diff --git a/planner/create_test.go b/planner/create_test.go index 41937ca..c70d6a7 100644 --- a/planner/create_test.go +++ b/planner/create_test.go @@ -74,7 +74,9 @@ func TestCreateWithNoIDColumn(t *testing.T) { if err != nil { t.Fatal(err) } - assertCommandsMatch(t, plan.Commands, expectedCommands) + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) + } } func TestCreateWithAlternateNamedIDColumn(t *testing.T) { @@ -130,7 +132,9 @@ func TestCreateWithAlternateNamedIDColumn(t *testing.T) { if err != nil { t.Fatal(err) } - assertCommandsMatch(t, plan.Commands, expectedCommands) + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) + } } func TestCreatePrimaryKeyWithTextType(t *testing.T) { @@ -218,5 +222,7 @@ func TestCreateIfNotExistsNoop(t *testing.T) { if err != nil { t.Fatal(err) } - assertCommandsMatch(t, plan.Commands, expectedCommands) + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) + } } diff --git a/planner/insert_test.go b/planner/insert_test.go index 273c210..589bb07 100644 --- a/planner/insert_test.go +++ b/planner/insert_test.go @@ -93,7 +93,9 @@ func TestInsertWithoutPrimaryKey(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - assertCommandsMatch(t, plan.Commands, expectedCommands) + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) + } } func TestInsertWithPrimaryKey(t *testing.T) { @@ -135,7 +137,9 @@ func TestInsertWithPrimaryKey(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - assertCommandsMatch(t, plan.Commands, expectedCommands) + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) + } } func TestInsertWithPrimaryKeyMiddleOrder(t *testing.T) { @@ -177,7 +181,9 @@ func TestInsertWithPrimaryKeyMiddleOrder(t 
*testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - assertCommandsMatch(t, plan.Commands, expectedCommands) + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) + } } func TestInsertWithPrimaryKeyParameter(t *testing.T) { @@ -219,7 +225,9 @@ func TestInsertWithPrimaryKeyParameter(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - assertCommandsMatch(t, plan.Commands, expectedCommands) + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) + } } func TestInsertWithParameter(t *testing.T) { @@ -261,7 +269,9 @@ func TestInsertWithParameter(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - assertCommandsMatch(t, plan.Commands, expectedCommands) + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) + } } func TestInsertIntoNonExistingTable(t *testing.T) { diff --git a/planner/select_test.go b/planner/select_test.go index d9bf29e..8d32de0 100644 --- a/planner/select_test.go +++ b/planner/select_test.go @@ -462,7 +462,9 @@ func TestSelectPlan(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - assertCommandsMatch(t, plan.Commands, c.expectedCommands) + if err := assertCommandsMatch(plan.Commands, c.expectedCommands); err != nil { + t.Error(err) + } }) } } diff --git a/planner/update_test.go b/planner/update_test.go index 73408e0..10d13e1 100644 --- a/planner/update_test.go +++ b/planner/update_test.go @@ -2,8 +2,6 @@ package planner import ( "errors" - "fmt" - "reflect" "testing" "github.com/chirst/cdb/catalog" @@ -46,28 +44,6 @@ func (mockUpdateCatalog) GetColumnType(tableName string, columnName string) (cat return catalog.CdbType{ID: catalog.CTInt}, nil } -func assertCommandsMatch(t *testing.T, gotCommands, expectedCommands []vm.Command) { - didMatch := true - errOutput := "\n" - for i, c := range expectedCommands { - green := "\033[32m" - 
red := "\033[31m" - resetColor := "\033[0m" - color := green - if !reflect.DeepEqual(c, gotCommands[i]) { - didMatch = false - color = red - } - errOutput += fmt.Sprintf( - "%s%3d got %#v%s\n want %#v\n\n", - color, i, gotCommands[i], resetColor, c, - ) - } - if !didMatch { - t.Error(errOutput) - } -} - func TestUpdate(t *testing.T) { ast := &compiler.UpdateStmt{ StmtBase: &compiler.StmtBase{}, @@ -99,7 +75,9 @@ func TestUpdate(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - assertCommandsMatch(t, plan.Commands, expectedCommands) + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) + } } func TestUpdateWithWhere(t *testing.T) { @@ -145,5 +123,7 @@ func TestUpdateWithWhere(t *testing.T) { if err != nil { t.Errorf("expected no err got err %s", err) } - assertCommandsMatch(t, plan.Commands, expectedCommands) + if err := assertCommandsMatch(plan.Commands, expectedCommands); err != nil { + t.Error(err) + } } From d604ae483cb556463bc3ce2be8463943e65e4f71 Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 28 Dec 2025 12:19:17 -0700 Subject: [PATCH 13/17] result col bug --- db/db_test.go | 19 +++++++++++++++++++ planner/select.go | 10 +++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/db/db_test.go b/db/db_test.go index bce71a5..4a4fcfb 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -157,6 +157,25 @@ func TestAddColumns(t *testing.T) { } } +func TestSelectHeaders(t *testing.T) { + db := mustCreateDB(t) + mustExecute(t, db, "CREATE TABLE test (id INTEGER PRIMARY KEY, val INTEGER)") + mustExecute(t, db, "INSERT INTO test (val) VALUES (1)") + res := mustExecute(t, db, "SELECT *, id AS foo FROM test") + if rowCount := len(res.ResultRows); rowCount != 1 { + t.Fatalf("want 1 row but got %d", rowCount) + } + if got := res.ResultHeader[0]; got != "id" { + t.Fatalf("want id but got %s", got) + } + if got := res.ResultHeader[1]; got != "val" { + t.Fatalf("want val but got %s", got) + } + 
if got := res.ResultHeader[2]; got != "foo" { + t.Fatalf("want foo but got %s", got) + } +} + func TestSelectWithWhere(t *testing.T) { db := mustCreateDB(t) mustExecute(t, db, "CREATE TABLE test (id INTEGER PRIMARY KEY, val INTEGER)") diff --git a/planner/select.go b/planner/select.go index ccf4474..45967f4 100644 --- a/planner/select.go +++ b/planner/select.go @@ -289,7 +289,15 @@ func (p *selectPlanner) setResultHeader() { case *projectNode: projectExprs := []compiler.Expr{} for _, projection := range t.projections { - resultHeader = append(resultHeader, projection.alias) + header := "" + if projection.alias == "" { + if cr, ok := projection.expr.(*compiler.ColumnRef); ok { + header = cr.Column + } + } else { + header = projection.alias + } + resultHeader = append(resultHeader, header) projectExprs = append(projectExprs, projection.expr) } p.setResultTypes(projectExprs) From 99163e7d25a19d94f465c4a79fed401c84c1a0b2 Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 28 Dec 2025 12:21:48 -0700 Subject: [PATCH 14/17] merge update planner --- planner/update.go | 55 +++++++++++++---------------------------------- 1 file changed, 15 insertions(+), 40 deletions(-) diff --git a/planner/update.go b/planner/update.go index 3203aa0..9bae097 100644 --- a/planner/update.go +++ b/planner/update.go @@ -21,19 +21,8 @@ type updateCatalog interface { // updatePanner houses the query planner and execution planner for a update // statement. type updatePlanner struct { - queryPlanner *updateQueryPlanner - executionPlanner *updateExecutionPlanner -} - -// updateQueryPlanner generates a queryPlan for the given update statement. -type updateQueryPlanner struct { - catalog updateCatalog - stmt *compiler.UpdateStmt - queryPlan *updateNode -} - -// updateExecutionPlanner generates a byte code routine for the given queryPlan. 
-type updateExecutionPlanner struct { + catalog updateCatalog + stmt *compiler.UpdateStmt queryPlan *updateNode executionPlan *vm.ExecutionPlan } @@ -41,31 +30,17 @@ type updateExecutionPlanner struct { // NewUpdate create a update planner. func NewUpdate(catalog updateCatalog, stmt *compiler.UpdateStmt) *updatePlanner { return &updatePlanner{ - queryPlanner: &updateQueryPlanner{ - catalog: catalog, - stmt: stmt, - }, - executionPlanner: &updateExecutionPlanner{ - executionPlan: vm.NewExecutionPlan( - catalog.GetVersion(), - stmt.Explain, - ), - }, + catalog: catalog, + stmt: stmt, + executionPlan: vm.NewExecutionPlan( + catalog.GetVersion(), + stmt.Explain, + ), } } // QueryPlan sets up a high level plan to be passed to ExecutionPlan. func (p *updatePlanner) QueryPlan() (*QueryPlan, error) { - qp, err := p.queryPlanner.getQueryPlan() - if err != nil { - return nil, err - } - p.executionPlanner.queryPlan = p.queryPlanner.queryPlan - return qp, err -} - -// getQueryPlan returns a updateNode with a high level plan. -func (p *updateQueryPlanner) getQueryPlan() (*QueryPlan, error) { rootPage, err := p.catalog.GetRootPageNumber(p.stmt.TableName) if err != nil { return nil, errTableNotExist @@ -118,7 +93,7 @@ func (p *updateQueryPlanner) getQueryPlan() (*QueryPlan, error) { // errIfPrimaryKeySet checks the primary key isn't being updated because it // could cause an infinite loop if not handled properly. -func (p *updateQueryPlanner) errIfPrimaryKeySet() error { +func (p *updatePlanner) errIfPrimaryKeySet() error { pkColumnName, err := p.catalog.GetPrimaryKeyColumn(p.stmt.TableName) if err != nil { return err @@ -131,7 +106,7 @@ func (p *updateQueryPlanner) errIfPrimaryKeySet() error { // errIfSetNotOnDestinationTable checks the set list has column names that are // part of the table being updated. 
-func (p *updateQueryPlanner) errIfSetNotOnDestinationTable() error { +func (p *updatePlanner) errIfSetNotOnDestinationTable() error { schemaColumns, err := p.catalog.GetColumns(p.stmt.TableName) if err != nil { return err @@ -146,7 +121,7 @@ func (p *updateQueryPlanner) errIfSetNotOnDestinationTable() error { // setQueryPlanRecordExpressions populates the query plan with appropriate // expressions for setting up to make a record. -func (p *updateQueryPlanner) setQueryPlanRecordExpressions() error { +func (p *updatePlanner) setQueryPlanRecordExpressions() error { schemaColumns, err := p.catalog.GetColumns(p.stmt.TableName) if err != nil { return err @@ -190,13 +165,13 @@ func (p *updateQueryPlanner) setQueryPlanRecordExpressions() error { // Execution plan is a byte code routine based off a high level query plan. func (p *updatePlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { - if p.queryPlanner.queryPlan == nil { + if p.queryPlan == nil { _, err := p.QueryPlan() if err != nil { return nil, err } } - p.queryPlanner.queryPlan.plan.compile() - p.executionPlanner.executionPlan.Commands = p.queryPlanner.queryPlan.plan.commands - return p.executionPlanner.executionPlan, nil + p.queryPlan.plan.compile() + p.executionPlan.Commands = p.queryPlan.plan.commands + return p.executionPlan, nil } From fd16e53facf1f42bfe58d72f5b2bcc9478447c84 Mon Sep 17 00:00:00 2001 From: Colton Date: Sun, 28 Dec 2025 22:51:39 -0700 Subject: [PATCH 15/17] move --- planner/generator.go | 189 +--------------------------------------- planner/node.go | 138 ++++++++++++++++++++++++++++- planner/plan.go | 203 +++++++++++++++++++++++++++---------------- 3 files changed, 267 insertions(+), 263 deletions(-) diff --git a/planner/generator.go b/planner/generator.go index 73f5d8e..12661aa 100644 --- a/planner/generator.go +++ b/planner/generator.go @@ -1,158 +1,9 @@ package planner import ( - "slices" - - "github.com/chirst/cdb/compiler" "github.com/chirst/cdb/vm" ) -// transactionType defines possible 
transactions for a query plan. -type transactionType int - -const ( - transactionTypeNone transactionType = 0 - transactionTypeRead transactionType = 1 - transactionTypeWrite transactionType = 2 -) - -// declareConstInt gets or sets a register with the const value and returns the -// register. It is guaranteed the value will be in the register for the duration -// of the plan. -func (p *QueryPlan) declareConstInt(i int) int { - _, ok := p.constInts[i] - if !ok { - p.constInts[i] = p.freeRegister - p.freeRegister += 1 - } - return p.constInts[i] -} - -// declareConstString gets or sets a register with the const value and returns -// the register. It is guaranteed the value will be in the register for the -// duration of the plan. -func (p *QueryPlan) declareConstString(s string) int { - _, ok := p.constStrings[s] - if !ok { - p.constStrings[s] = p.freeRegister - p.freeRegister += 1 - } - return p.constStrings[s] -} - -// declareConstVar gets or sets a register with the const value and returns -// the register. It is guaranteed the value will be in the register for the -// duration of the plan. 
-func (p *QueryPlan) declareConstVar(position int) int { - _, ok := p.constVars[position] - if !ok { - p.constVars[position] = p.freeRegister - p.freeRegister += 1 - } - return p.constVars[position] -} - -func (p *QueryPlan) compile() { - initCmd := &vm.InitCmd{} - p.commands = append(p.commands, initCmd) - p.root.produce() - p.commands = append(p.commands, &vm.HaltCmd{}) - initCmd.P2 = len(p.commands) - p.pushTransaction() - p.pushConstants() - p.commands = append(p.commands, &vm.GotoCmd{P2: 1}) -} - -func (p *QueryPlan) pushTransaction() { - switch p.transactionType { - case transactionTypeNone: - return - case transactionTypeRead: - p.commands = append( - p.commands, - &vm.TransactionCmd{P2: 0}, - ) - p.commands = append( - p.commands, - &vm.OpenReadCmd{P1: p.cursorId, P2: p.rootPageNumber}, - ) - case transactionTypeWrite: - p.commands = append( - p.commands, - &vm.TransactionCmd{P2: 1}, - ) - p.commands = append( - p.commands, - &vm.OpenWriteCmd{P1: p.cursorId, P2: p.rootPageNumber}, - ) - default: - panic("unexpected transaction type") - } -} - -func (p *QueryPlan) pushConstants() { - // these constants are pushed ordered since maps are unordered making it - // difficult to assert that a sequence of instructions appears. 
- p.pushConstantInts() - p.pushConstantStrings() - p.pushConstantVars() -} - -func (p *QueryPlan) pushConstantInts() { - temp := []*vm.IntegerCmd{} - for k := range p.constInts { - temp = append(temp, &vm.IntegerCmd{P1: k, P2: p.constInts[k]}) - } - slices.SortFunc(temp, func(a, b *vm.IntegerCmd) int { - return a.P2 - b.P2 - }) - for i := range temp { - p.commands = append(p.commands, temp[i]) - } -} - -func (p *QueryPlan) pushConstantStrings() { - temp := []*vm.StringCmd{} - for v := range p.constStrings { - temp = append(temp, &vm.StringCmd{P1: p.constStrings[v], P4: v}) - } - slices.SortFunc(temp, func(a, b *vm.StringCmd) int { - return a.P1 - b.P1 - }) - for i := range temp { - p.commands = append(p.commands, temp[i]) - } -} - -func (p *QueryPlan) pushConstantVars() { - temp := []*vm.VariableCmd{} - for v := range p.constVars { - temp = append(temp, &vm.VariableCmd{P1: v, P2: p.constVars[v]}) - } - slices.SortFunc(temp, func(a, b *vm.VariableCmd) int { - return a.P2 - b.P2 - }) - for i := range temp { - p.commands = append(p.commands, temp[i]) - } -} - -type updateNode struct { - child logicalNode - plan *QueryPlan - // updateExprs is formed from the update statement AST. The idea is to - // provide an expression for each column where the expression is either a - // columnRef or the complex expression from the right hand side of the SET - // keyword. Note it is important to provide the expressions in their correct - // ordinal position as the generator will not try to order them correctly. - // - // The row id is not allowed to be updated at the moment because it could - // cause infinite loops due to it changing the physical location of the - // record. The query plan will have to use a temporary storage to update - // primary keys. 
- updateExprs []compiler.Expr -} - func (u *updateNode) produce() { u.child.produce() } @@ -195,13 +46,6 @@ func (u *updateNode) consume() { }) } -type filterNode struct { - child logicalNode - parent logicalNode - plan *QueryPlan - predicate compiler.Expr -} - func (f *filterNode) produce() { f.child.produce() } @@ -212,11 +56,6 @@ func (f *filterNode) consume() { jumpCommand.SetJumpAddress(len(f.plan.commands)) } -type scanNode struct { - parent logicalNode - plan *QueryPlan -} - func (s *scanNode) produce() { s.consume() } @@ -233,18 +72,6 @@ func (s *scanNode) consume() { rewindCmd.P2 = len(s.plan.commands) } -type projection struct { - expr compiler.Expr - // alias is the alias of the projection or no alias for the zero value. - alias string -} - -type projectNode struct { - child logicalNode - plan *QueryPlan - projections []projection -} - func (p *projectNode) produce() { p.child.produce() } @@ -262,11 +89,6 @@ func (p *projectNode) consume() { }) } -type constantNode struct { - parent logicalNode - plan *QueryPlan -} - func (c *constantNode) produce() { c.consume() } @@ -275,11 +97,6 @@ func (c *constantNode) consume() { c.parent.consume() } -type countNode struct { - plan *QueryPlan - projection projection -} - func (c *countNode) produce() { c.consume() } @@ -375,8 +192,6 @@ func (n *insertNode) consume() { } } -func (n *joinNode) produce() { -} +func (n *joinNode) produce() {} -func (n *joinNode) consume() { -} +func (n *joinNode) consume() {} diff --git a/planner/node.go b/planner/node.go index 0492fda..c88d89f 100644 --- a/planner/node.go +++ b/planner/node.go @@ -1,14 +1,25 @@ package planner -import "github.com/chirst/cdb/compiler" +import ( + "fmt" + + "github.com/chirst/cdb/compiler" +) // This file defines the relational nodes in a logical query plan. // logicalNode defines the interface for a node in the query plan tree. type logicalNode interface { + // children returns the child nodes. 
children() []logicalNode + // print returns the string representation for explain. print() string + // produce works with consume to generate byte code in the nodes associated + // query plan. produce typically calls its children's produce methods until + // a leaf is reached. When the leaf is reached consume is called which emits + // byte code as the stack unwinds. produce() + // consume works with produce. consume() } @@ -23,6 +34,14 @@ type joinNode struct { operation string } +func (j *joinNode) print() string { + return fmt.Sprint(j.operation) +} + +func (j *joinNode) children() []logicalNode { + return []logicalNode{j.left, j.right} +} + // createNode represents a operation to create an object in the system catalog. // For example a table, index, or trigger. type createNode struct { @@ -46,6 +65,17 @@ type createNode struct { noop bool } +func (c *createNode) print() string { + if c.noop { + return fmt.Sprintf("assert table %s does not exist", c.tableName) + } + return fmt.Sprintf("create table %s", c.tableName) +} + +func (c *createNode) children() []logicalNode { + return []logicalNode{} +} + // insertNode represents an insert operation. type insertNode struct { plan *QueryPlan @@ -63,3 +93,109 @@ type insertNode struct { // generation. autoPk bool } + +func (i *insertNode) print() string { + return "insert" +} + +func (i *insertNode) children() []logicalNode { + return []logicalNode{} +} + +type countNode struct { + plan *QueryPlan + projection projection +} + +func (c *countNode) children() []logicalNode { + return []logicalNode{} +} + +func (c *countNode) print() string { + return "count table" +} + +type constantNode struct { + parent logicalNode + plan *QueryPlan +} + +func (c *constantNode) print() string { + return "constant data source" +} + +func (c *constantNode) children() []logicalNode { + return []logicalNode{} +} + +type projection struct { + expr compiler.Expr + // alias is the alias of the projection or no alias for the zero value. 
+ alias string +} + +type projectNode struct { + child logicalNode + plan *QueryPlan + projections []projection +} + +func (p *projectNode) print() string { + return "project" +} + +func (p *projectNode) children() []logicalNode { + return []logicalNode{p.child} +} + +type scanNode struct { + parent logicalNode + plan *QueryPlan +} + +func (s *scanNode) print() string { + return "scan table" +} + +func (s *scanNode) children() []logicalNode { + return []logicalNode{} +} + +type filterNode struct { + child logicalNode + parent logicalNode + plan *QueryPlan + predicate compiler.Expr +} + +func (f *filterNode) print() string { + return "filter" +} + +func (f *filterNode) children() []logicalNode { + return []logicalNode{f.child} +} + +type updateNode struct { + child logicalNode + plan *QueryPlan + // updateExprs is formed from the update statement AST. The idea is to + // provide an expression for each column where the expression is either a + // columnRef or the complex expression from the right hand side of the SET + // keyword. Note it is important to provide the expressions in their correct + // ordinal position as the generator will not try to order them correctly. + // + // The row id is not allowed to be updated at the moment because it could + // cause infinite loops due to it changing the physical location of the + // record. The query plan will have to use a temporary storage to update + // primary keys. + updateExprs []compiler.Expr +} + +func (u *updateNode) print() string { + return "update" +} + +func (u *updateNode) children() []logicalNode { + return []logicalNode{} +} diff --git a/planner/plan.go b/planner/plan.go index 77128b7..79321ce 100644 --- a/planner/plan.go +++ b/planner/plan.go @@ -2,6 +2,7 @@ package planner import ( "fmt" + "slices" "strings" "unicode/utf8" @@ -67,6 +68,133 @@ func newQueryPlan( } } +// transactionType defines possible transactions for a query plan. 
+type transactionType int + +const ( + transactionTypeNone transactionType = 0 + transactionTypeRead transactionType = 1 + transactionTypeWrite transactionType = 2 +) + +// declareConstInt gets or sets a register with the const value and returns the +// register. It is guaranteed the value will be in the register for the duration +// of the plan. +func (p *QueryPlan) declareConstInt(i int) int { + _, ok := p.constInts[i] + if !ok { + p.constInts[i] = p.freeRegister + p.freeRegister += 1 + } + return p.constInts[i] +} + +// declareConstString gets or sets a register with the const value and returns +// the register. It is guaranteed the value will be in the register for the +// duration of the plan. +func (p *QueryPlan) declareConstString(s string) int { + _, ok := p.constStrings[s] + if !ok { + p.constStrings[s] = p.freeRegister + p.freeRegister += 1 + } + return p.constStrings[s] +} + +// declareConstVar gets or sets a register with the const value and returns +// the register. It is guaranteed the value will be in the register for the +// duration of the plan. +func (p *QueryPlan) declareConstVar(position int) int { + _, ok := p.constVars[position] + if !ok { + p.constVars[position] = p.freeRegister + p.freeRegister += 1 + } + return p.constVars[position] +} + +// compile sets byte code for the root node and it's children on commands. +func (p *QueryPlan) compile() { + initCmd := &vm.InitCmd{} + p.commands = append(p.commands, initCmd) + p.root.produce() + p.commands = append(p.commands, &vm.HaltCmd{}) + initCmd.P2 = len(p.commands) + p.pushTransaction() + // these constants are pushed ordered since maps are unordered making it + // difficult to assert that a sequence of instructions appears. 
+ p.pushConstantInts() + p.pushConstantStrings() + p.pushConstantVars() + p.commands = append(p.commands, &vm.GotoCmd{P2: 1}) +} + +func (p *QueryPlan) pushTransaction() { + switch p.transactionType { + case transactionTypeNone: + return + case transactionTypeRead: + p.commands = append( + p.commands, + &vm.TransactionCmd{P2: 0}, + ) + p.commands = append( + p.commands, + &vm.OpenReadCmd{P1: p.cursorId, P2: p.rootPageNumber}, + ) + case transactionTypeWrite: + p.commands = append( + p.commands, + &vm.TransactionCmd{P2: 1}, + ) + p.commands = append( + p.commands, + &vm.OpenWriteCmd{P1: p.cursorId, P2: p.rootPageNumber}, + ) + default: + panic("unexpected transaction type") + } +} + +func (p *QueryPlan) pushConstantInts() { + temp := []*vm.IntegerCmd{} + for k := range p.constInts { + temp = append(temp, &vm.IntegerCmd{P1: k, P2: p.constInts[k]}) + } + slices.SortFunc(temp, func(a, b *vm.IntegerCmd) int { + return a.P2 - b.P2 + }) + for i := range temp { + p.commands = append(p.commands, temp[i]) + } +} + +func (p *QueryPlan) pushConstantStrings() { + temp := []*vm.StringCmd{} + for v := range p.constStrings { + temp = append(temp, &vm.StringCmd{P1: p.constStrings[v], P4: v}) + } + slices.SortFunc(temp, func(a, b *vm.StringCmd) int { + return a.P1 - b.P1 + }) + for i := range temp { + p.commands = append(p.commands, temp[i]) + } +} + +func (p *QueryPlan) pushConstantVars() { + temp := []*vm.VariableCmd{} + for v := range p.constVars { + temp = append(temp, &vm.VariableCmd{P1: v, P2: p.constVars[v]}) + } + slices.SortFunc(temp, func(a, b *vm.VariableCmd) int { + return a.P2 - b.P2 + }) + for i := range temp { + p.commands = append(p.commands, temp[i]) + } +} + // ToString evaluates and returns the query plan as a string representation. 
func (p *QueryPlan) ToString() string { qp := &QueryPlan{} @@ -142,78 +270,3 @@ func (p *QueryPlan) connectSiblings() string { } return strings.Join(planMatrix, "\n") } - -func (p *projectNode) print() string { - return "project" -} - -func (p *projectNode) children() []logicalNode { - return []logicalNode{p.child} -} - -func (s *scanNode) print() string { - return "scan table" -} - -func (c *constantNode) print() string { - return "constant data source" -} - -func (c *countNode) print() string { - return "count table" -} - -func (j *joinNode) print() string { - return fmt.Sprint(j.operation) -} - -func (c *createNode) print() string { - if c.noop { - return fmt.Sprintf("assert table %s does not exist", c.tableName) - } - return fmt.Sprintf("create table %s", c.tableName) -} - -func (i *insertNode) print() string { - return "insert" -} - -func (u *updateNode) print() string { - return "update" -} - -func (f *filterNode) print() string { - return "filter" -} - -func (s *scanNode) children() []logicalNode { - return []logicalNode{} -} - -func (c *constantNode) children() []logicalNode { - return []logicalNode{} -} - -func (c *countNode) children() []logicalNode { - return []logicalNode{} -} - -func (j *joinNode) children() []logicalNode { - return []logicalNode{j.left, j.right} -} - -func (c *createNode) children() []logicalNode { - return []logicalNode{} -} - -func (i *insertNode) children() []logicalNode { - return []logicalNode{} -} - -func (u *updateNode) children() []logicalNode { - return []logicalNode{} -} - -func (f *filterNode) children() []logicalNode { - return []logicalNode{f.child} -} From c5b18158e92e7263970ff992315f6754410241b0 Mon Sep 17 00:00:00 2001 From: Colton Date: Sat, 3 Jan 2026 01:46:28 -0700 Subject: [PATCH 16/17] move cursor and root page off plan and bring back table names in explain query plan --- planner/create.go | 18 +++++---- planner/create_test.go | 9 ++--- planner/generator.go | 54 ++++++++++++++++++------- planner/insert.go | 6 ++- 
planner/insert_test.go | 28 ++++++------- planner/node.go | 46 ++++++++++++++++++++-- planner/plan.go | 19 --------- planner/plan_test.go | 26 +++++++----- planner/predicate_generator.go | 11 ++++-- planner/result_generator.go | 11 ++++-- planner/select.go | 26 +++++++++--- planner/select_test.go | 72 +++++++++++++++++----------------- planner/update.go | 15 +++++-- planner/update_test.go | 18 ++++----- 14 files changed, 224 insertions(+), 135 deletions(-) diff --git a/planner/create.go b/planner/create.go index 8711875..7165348 100644 --- a/planner/create.go +++ b/planner/create.go @@ -49,15 +49,16 @@ func (p *createPlanner) QueryPlan() (*QueryPlan, error) { tableExists := p.catalog.TableExists(p.stmt.TableName) if p.stmt.IfNotExists && tableExists { noopCreateNode := &createNode{ - noop: true, - tableName: p.stmt.TableName, + noop: true, + tableName: p.stmt.TableName, + catalogRootPageNumber: schemaTableRoot, + catalogCursorId: 1, } p.queryPlan = noopCreateNode qp := newQueryPlan( noopCreateNode, p.stmt.ExplainQueryPlan, transactionTypeWrite, - schemaTableRoot, ) noopCreateNode.plan = qp return qp, nil @@ -70,17 +71,18 @@ func (p *createPlanner) QueryPlan() (*QueryPlan, error) { return nil, err } createNode := &createNode{ - objectType: "table", - objectName: p.stmt.TableName, - tableName: p.stmt.TableName, - schema: jSchema, + objectType: "table", + objectName: p.stmt.TableName, + tableName: p.stmt.TableName, + schema: jSchema, + catalogRootPageNumber: schemaTableRoot, + catalogCursorId: 1, } p.queryPlan = createNode qp := newQueryPlan( createNode, p.stmt.ExplainQueryPlan, transactionTypeWrite, - schemaTableRoot, ) createNode.plan = qp return qp, nil diff --git a/planner/create_test.go b/planner/create_test.go index c70d6a7..2f7895e 100644 --- a/planner/create_test.go +++ b/planner/create_test.go @@ -54,7 +54,8 @@ func TestCreateWithNoIDColumn(t *testing.T) { t.Fatalf("failed to convert expected schema to json %s", err) } expectedCommands := []vm.Command{ - 
&vm.InitCmd{P2: 12}, + &vm.InitCmd{P2: 13}, + &vm.OpenWriteCmd{P1: 1, P2: 1}, &vm.CreateBTreeCmd{P2: 1}, &vm.NewRowIdCmd{P1: 1, P2: 2}, &vm.StringCmd{P1: 3, P4: "table"}, @@ -67,7 +68,6 @@ func TestCreateWithNoIDColumn(t *testing.T) { &vm.ParseSchemaCmd{}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 1}, &vm.GotoCmd{P2: 1}, } plan, err := NewCreate(mc, stmt).ExecutionPlan() @@ -112,7 +112,8 @@ func TestCreateWithAlternateNamedIDColumn(t *testing.T) { t.Fatalf("failed to convert expected schema to json %s", err) } expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 12}, + &vm.InitCmd{P2: 13}, + &vm.OpenWriteCmd{P1: 1, P2: 1}, &vm.CreateBTreeCmd{P2: 1}, &vm.NewRowIdCmd{P1: 1, P2: 2}, &vm.StringCmd{P1: 3, P4: "table"}, @@ -125,7 +126,6 @@ func TestCreateWithAlternateNamedIDColumn(t *testing.T) { &vm.ParseSchemaCmd{}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 1}, &vm.GotoCmd{P2: 1}, } plan, err := NewCreate(mc, stmt).ExecutionPlan() @@ -215,7 +215,6 @@ func TestCreateIfNotExistsNoop(t *testing.T) { &vm.InitCmd{P2: 2}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 1}, &vm.GotoCmd{P2: 1}, } plan, err := NewCreate(mc, stmt).ExecutionPlan() diff --git a/planner/generator.go b/planner/generator.go index 12661aa..aec6da3 100644 --- a/planner/generator.go +++ b/planner/generator.go @@ -11,7 +11,7 @@ func (u *updateNode) produce() { func (u *updateNode) consume() { // RowID u.plan.commands = append(u.plan.commands, &vm.RowIdCmd{ - P1: u.plan.cursorId, + P1: u.cursorId, P2: u.plan.freeRegister, }) rowIdRegister := u.plan.freeRegister @@ -23,7 +23,7 @@ func (u *updateNode) consume() { u.plan.freeRegister += len(u.updateExprs) recordRegisterCount := len(u.updateExprs) for i, e := range u.updateExprs { - generateExpressionTo(u.plan, e, startRecordRegister+i) + generateExpressionTo(u.plan, e, startRecordRegister+i, u.cursorId) } // Make the record for inserting @@ -37,10 +37,10 @@ func (u 
*updateNode) consume() { // Update by deleting then inserting u.plan.commands = append(u.plan.commands, &vm.DeleteCmd{ - P1: u.plan.cursorId, + P1: u.cursorId, }) u.plan.commands = append(u.plan.commands, &vm.InsertCmd{ - P1: u.plan.cursorId, + P1: u.cursorId, P2: recordRegister, P3: rowIdRegister, }) @@ -51,7 +51,7 @@ func (f *filterNode) produce() { } func (f *filterNode) consume() { - jumpCommand := generatePredicate(f.plan, f.predicate) + jumpCommand := generatePredicate(f.plan, f.predicate, f.cursorId) f.parent.consume() jumpCommand.SetJumpAddress(len(f.plan.commands)) } @@ -61,12 +61,23 @@ func (s *scanNode) produce() { } func (s *scanNode) consume() { - rewindCmd := &vm.RewindCmd{P1: s.plan.cursorId} + if s.isWriteCursor { + s.plan.commands = append( + s.plan.commands, + &vm.OpenWriteCmd{P1: s.cursorId, P2: s.rootPageNumber}, + ) + } else { + s.plan.commands = append( + s.plan.commands, + &vm.OpenReadCmd{P1: s.cursorId, P2: s.rootPageNumber}, + ) + } + rewindCmd := &vm.RewindCmd{P1: s.cursorId} s.plan.commands = append(s.plan.commands, rewindCmd) loopBeginAddress := len(s.plan.commands) s.parent.consume() s.plan.commands = append(s.plan.commands, &vm.NextCmd{ - P1: s.plan.cursorId, + P1: s.cursorId, P2: loopBeginAddress, }) rewindCmd.P2 = len(s.plan.commands) @@ -81,7 +92,7 @@ func (p *projectNode) consume() { reservedRegisters := len(p.projections) p.plan.freeRegister += reservedRegisters for i, projection := range p.projections { - generateExpressionTo(p.plan, projection.expr, startRegister+i) + generateExpressionTo(p.plan, projection.expr, startRegister+i, p.cursorId) } p.plan.commands = append(p.plan.commands, &vm.ResultRowCmd{ P1: startRegister, @@ -102,8 +113,12 @@ func (c *countNode) produce() { } func (c *countNode) consume() { + c.plan.commands = append( + c.plan.commands, + &vm.OpenReadCmd{P1: c.cursorId, P2: c.rootPageNumber}, + ) c.plan.commands = append(c.plan.commands, &vm.CountCmd{ - P1: c.plan.cursorId, + P1: c.cursorId, P2: 
c.plan.freeRegister, }) countRegister := c.plan.freeRegister @@ -123,15 +138,19 @@ func (c *createNode) consume() { if c.noop { return } + c.plan.commands = append( + c.plan.commands, + &vm.OpenWriteCmd{P1: c.catalogCursorId, P2: c.catalogRootPageNumber}, + ) c.plan.commands = append(c.plan.commands, &vm.CreateBTreeCmd{P2: 1}) - c.plan.commands = append(c.plan.commands, &vm.NewRowIdCmd{P1: c.plan.cursorId, P2: 2}) + c.plan.commands = append(c.plan.commands, &vm.NewRowIdCmd{P1: c.catalogCursorId, P2: 2}) c.plan.commands = append(c.plan.commands, &vm.StringCmd{P1: 3, P4: c.objectType}) c.plan.commands = append(c.plan.commands, &vm.StringCmd{P1: 4, P4: c.objectName}) c.plan.commands = append(c.plan.commands, &vm.StringCmd{P1: 5, P4: c.tableName}) c.plan.commands = append(c.plan.commands, &vm.CopyCmd{P1: 1, P2: 6}) c.plan.commands = append(c.plan.commands, &vm.StringCmd{P1: 7, P4: string(c.schema)}) c.plan.commands = append(c.plan.commands, &vm.MakeRecordCmd{P1: 3, P2: 5, P3: 8}) - c.plan.commands = append(c.plan.commands, &vm.InsertCmd{P1: c.plan.cursorId, P2: 8, P3: 2}) + c.plan.commands = append(c.plan.commands, &vm.InsertCmd{P1: c.catalogCursorId, P2: 8, P3: 2}) c.plan.commands = append(c.plan.commands, &vm.ParseSchemaCmd{}) } @@ -140,20 +159,24 @@ func (n *insertNode) produce() { } func (n *insertNode) consume() { + n.plan.commands = append( + n.plan.commands, + &vm.OpenWriteCmd{P1: n.cursorId, P2: n.rootPageNumber}, + ) for valuesIdx := range len(n.colValues) { // Setup rowid and it's uniqueness/type checks pkRegister := n.plan.freeRegister n.plan.freeRegister += 1 if n.autoPk { n.plan.commands = append(n.plan.commands, &vm.NewRowIdCmd{ - P1: n.plan.cursorId, + P1: n.cursorId, P2: pkRegister, }) } else { - generateExpressionTo(n.plan, n.pkValues[valuesIdx], pkRegister) + generateExpressionTo(n.plan, n.pkValues[valuesIdx], pkRegister, n.cursorId) n.plan.commands = append(n.plan.commands, &vm.MustBeIntCmd{P1: pkRegister}) nec := &vm.NotExistsCmd{ - P1: 
n.plan.cursorId, + P1: n.cursorId, P3: pkRegister, } n.plan.commands = append(n.plan.commands, nec) @@ -173,6 +196,7 @@ func (n *insertNode) consume() { n.plan, n.colValues[valuesIdx][vi], startRegister+vi, + n.cursorId, ) } @@ -185,7 +209,7 @@ func (n *insertNode) consume() { recordRegister := n.plan.freeRegister n.plan.freeRegister += 1 n.plan.commands = append(n.plan.commands, &vm.InsertCmd{ - P1: n.plan.cursorId, + P1: n.cursorId, P2: recordRegister, P3: pkRegister, }) diff --git a/planner/insert.go b/planner/insert.go index 4ac7ce7..401a884 100644 --- a/planner/insert.go +++ b/planner/insert.go @@ -61,7 +61,10 @@ func (p *insertPlanner) QueryPlan() (*QueryPlan, error) { return nil, err } insertNode := &insertNode{ - colValues: colValues, + colValues: colValues, + rootPageNumber: rootPage, + tableName: p.stmt.TableName, + cursorId: 1, } if err := p.setPkValues(insertNode); err != nil { return nil, err @@ -71,7 +74,6 @@ func (p *insertPlanner) QueryPlan() (*QueryPlan, error) { insertNode, p.stmt.ExplainQueryPlan, transactionTypeWrite, - rootPage, ) insertNode.plan = qp return qp, nil diff --git a/planner/insert_test.go b/planner/insert_test.go index 589bb07..6130788 100644 --- a/planner/insert_test.go +++ b/planner/insert_test.go @@ -37,7 +37,8 @@ func (m *mockInsertCatalog) GetPrimaryKeyColumn(tableName string) (string, error func TestInsertWithoutPrimaryKey(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 17}, + &vm.InitCmd{P2: 18}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.NewRowIdCmd{P1: 1, P2: 1}, &vm.CopyCmd{P1: 4, P2: 2}, &vm.CopyCmd{P1: 5, P2: 3}, @@ -55,7 +56,6 @@ func TestInsertWithoutPrimaryKey(t *testing.T) { &vm.InsertCmd{P1: 1, P2: 18, P3: 13}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.StringCmd{P1: 4, P4: "gud"}, &vm.StringCmd{P1: 5, P4: "dude"}, &vm.StringCmd{P1: 10, P4: "joe"}, @@ -100,17 +100,17 @@ func TestInsertWithoutPrimaryKey(t *testing.T) { func TestInsertWithPrimaryKey(t *testing.T) { 
expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 9}, + &vm.InitCmd{P2: 10}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.CopyCmd{P1: 2, P2: 1}, &vm.MustBeIntCmd{P1: 1}, - &vm.NotExistsCmd{P1: 1, P2: 5, P3: 1}, + &vm.NotExistsCmd{P1: 1, P2: 6, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, &vm.CopyCmd{P1: 4, P2: 3}, &vm.MakeRecordCmd{P1: 3, P2: 1, P3: 5}, &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.IntegerCmd{P1: 22, P2: 2}, &vm.StringCmd{P1: 4, P4: "gud"}, &vm.GotoCmd{P2: 1}, @@ -144,17 +144,17 @@ func TestInsertWithPrimaryKey(t *testing.T) { func TestInsertWithPrimaryKeyMiddleOrder(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 9}, + &vm.InitCmd{P2: 10}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.CopyCmd{P1: 2, P2: 1}, &vm.MustBeIntCmd{P1: 1}, - &vm.NotExistsCmd{P1: 1, P2: 5, P3: 1}, + &vm.NotExistsCmd{P1: 1, P2: 6, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, &vm.CopyCmd{P1: 4, P2: 3}, &vm.MakeRecordCmd{P1: 3, P2: 1, P3: 5}, &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.IntegerCmd{P1: 12, P2: 2}, &vm.StringCmd{P1: 4, P4: "feller"}, &vm.GotoCmd{P2: 1}, @@ -188,17 +188,17 @@ func TestInsertWithPrimaryKeyMiddleOrder(t *testing.T) { func TestInsertWithPrimaryKeyParameter(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 9}, + &vm.InitCmd{P2: 10}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.CopyCmd{P1: 2, P2: 1}, &vm.MustBeIntCmd{P1: 1}, - &vm.NotExistsCmd{P1: 1, P2: 5, P3: 1}, + &vm.NotExistsCmd{P1: 1, P2: 6, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, &vm.CopyCmd{P1: 4, P2: 3}, &vm.MakeRecordCmd{P1: 3, P2: 1, P3: 5}, &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.StringCmd{P1: 4, P4: "feller"}, &vm.VariableCmd{P1: 0, P2: 2}, &vm.GotoCmd{P2: 1}, @@ -232,17 +232,17 @@ func 
TestInsertWithPrimaryKeyParameter(t *testing.T) { func TestInsertWithParameter(t *testing.T) { expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 9}, + &vm.InitCmd{P2: 10}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.CopyCmd{P1: 2, P2: 1}, &vm.MustBeIntCmd{P1: 1}, - &vm.NotExistsCmd{P1: 1, P2: 5, P3: 1}, + &vm.NotExistsCmd{P1: 1, P2: 6, P3: 1}, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}, &vm.CopyCmd{P1: 4, P2: 3}, &vm.MakeRecordCmd{P1: 3, P2: 1, P3: 5}, &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.VariableCmd{P1: 0, P2: 2}, &vm.VariableCmd{P1: 1, P2: 4}, &vm.GotoCmd{P2: 1}, diff --git a/planner/node.go b/planner/node.go index c88d89f..abb4856 100644 --- a/planner/node.go +++ b/planner/node.go @@ -63,6 +63,11 @@ type createNode struct { // because the query plan would be invalidated given the existence of the object // has changed between query planning and query execution. noop bool + // catalogRootPageNumber is the page number of the system catalog. + catalogRootPageNumber int + // catalogCursorId is the id of the cursor associated with the system + // catalog table being updated. + catalogCursorId int } func (c *createNode) print() string { @@ -92,6 +97,13 @@ type insertNode struct { // autoPk indicates the generator should use a NewRowIdCmd for pk // generation. autoPk bool + // tableName is the name of the table being inserted to. + tableName string + // rootPageNumber is the page number of the table being inserted to. + rootPageNumber int + // cursorId is the id of the cursor associated with the table being inserted + // to. + cursorId int } func (i *insertNode) print() string { @@ -105,6 +117,12 @@ func (i *insertNode) children() []logicalNode { type countNode struct { plan *QueryPlan projection projection + // tableName is the name of the table being scanned. + tableName string + // rootPageNumber is the page number of the table being scanned.
+ rootPageNumber int + // cursorId is the id of the cursor associated with the table being scanned. + cursorId int } func (c *countNode) children() []logicalNode { @@ -112,7 +130,7 @@ func (c *countNode) children() []logicalNode { } func (c *countNode) print() string { - return "count table" + return fmt.Sprintf("count table %s", c.tableName) } type constantNode struct { @@ -138,6 +156,10 @@ type projectNode struct { child logicalNode plan *QueryPlan projections []projection + // cursorId is the id of the cursor associated with the table being + // projected. In the future this will likely need to be enhanced since + // projections are not entirely meant for one table. + cursorId int } func (p *projectNode) print() string { @@ -151,10 +173,18 @@ func (p *projectNode) children() []logicalNode { type scanNode struct { parent logicalNode plan *QueryPlan + // tableName is the name of the table being scanned. + tableName string + // rootPageNumber is the page number of the table being scanned. + rootPageNumber int + // cursorId is the id of the cursor associated with the table being scanned. + cursorId int + // isWriteCursor is true when the cursor should be a write cursor. + isWriteCursor bool } func (s *scanNode) print() string { - return "scan table" + return fmt.Sprintf("scan table %s", s.tableName) } func (s *scanNode) children() []logicalNode { @@ -166,6 +196,10 @@ type filterNode struct { parent logicalNode plan *QueryPlan predicate compiler.Expr + // cursorId is the id of the cursor associated with the table being filtered. + // In the future this will likely need to be enhanced since filters are not + // entirely meant for one table. + cursorId int } func (f *filterNode) print() string { @@ -190,10 +224,16 @@ type updateNode struct { // record. The query plan will have to use a temporary storage to update // primary keys. updateExprs []compiler.Expr + // tableName is the name of the table being updated. 
+ tableName string + // rootPageNumber is the page number of the table being updated. + rootPageNumber int + // cursorId is the id of the cursor associated with the table being updated. + cursorId int } func (u *updateNode) print() string { - return "update" + return fmt.Sprintf("update table %s", u.tableName) } func (u *updateNode) children() []logicalNode { diff --git a/planner/plan.go b/planner/plan.go index 79321ce..222e431 100644 --- a/planner/plan.go +++ b/planner/plan.go @@ -38,21 +38,12 @@ type QueryPlan struct { freeRegister int // transactionType defines what kind of transaction the plan will need. transactionType transactionType - // cursorId is the id of the cursor the plan is using. Note plans will - // eventually need to use more than one cursor, but for now it is convenient - // to pull the id from here. - cursorId int - // rootPageNumber is the root page number of the table cursorId is - // associated with. This should be a map at some point when multiple tables - // can be queried in one plan. 
- rootPageNumber int } func newQueryPlan( root logicalNode, explainQueryPlan bool, transactionType transactionType, - rootPageNumber int, ) *QueryPlan { return &QueryPlan{ root: root, @@ -63,8 +54,6 @@ func newQueryPlan( constVars: make(map[int]int), freeRegister: 1, transactionType: transactionType, - cursorId: 1, - rootPageNumber: rootPageNumber, } } @@ -138,19 +127,11 @@ func (p *QueryPlan) pushTransaction() { p.commands, &vm.TransactionCmd{P2: 0}, ) - p.commands = append( - p.commands, - &vm.OpenReadCmd{P1: p.cursorId, P2: p.rootPageNumber}, - ) case transactionTypeWrite: p.commands = append( p.commands, &vm.TransactionCmd{P2: 1}, ) - p.commands = append( - p.commands, - &vm.OpenWriteCmd{P1: p.cursorId, P2: p.rootPageNumber}, - ) default: panic("unexpected transaction type") } diff --git a/planner/plan_test.go b/planner/plan_test.go index eb5b9e6..e948777 100644 --- a/planner/plan_test.go +++ b/planner/plan_test.go @@ -8,27 +8,35 @@ func TestExplainQueryPlan(t *testing.T) { operation: "join", left: &joinNode{ operation: "join", - left: &scanNode{}, + left: &scanNode{ + tableName: "foo", + }, right: &joinNode{ operation: "join", - left: &scanNode{}, - right: &scanNode{}, + left: &scanNode{ + tableName: "bar", + }, + right: &scanNode{ + tableName: "baz", + }, }, }, - right: &scanNode{}, + right: &scanNode{ + tableName: "buzz", + }, }, } - qp := newQueryPlan(root, true, transactionTypeRead, 0) + qp := newQueryPlan(root, true, transactionTypeRead) formattedResult := qp.ToString() expectedResult := "" + " ── project\n" + " └─ join\n" + " ├─ join\n" + - " | ├─ scan table\n" + + " | ├─ scan table foo\n" + " | └─ join\n" + - " | ├─ scan table\n" + - " | └─ scan table\n" + - " └─ scan table\n" + " | ├─ scan table bar\n" + + " | └─ scan table baz\n" + + " └─ scan table buzz\n" if formattedResult != expectedResult { t.Fatalf("got\n%s\nwant\n%s", formattedResult, expectedResult) } diff --git a/planner/predicate_generator.go b/planner/predicate_generator.go index 
9380568..3a919f2 100644 --- a/planner/predicate_generator.go +++ b/planner/predicate_generator.go @@ -8,9 +8,10 @@ import ( // generatePredicate generates code to make a boolean jump for the given // expression within the plan context. The function returns the jump command to // lazily set the jump address. -func generatePredicate(plan *QueryPlan, expression compiler.Expr) vm.JumpCommand { +func generatePredicate(plan *QueryPlan, expression compiler.Expr, cursorId int) vm.JumpCommand { pg := &predicateGenerator{} pg.plan = plan + pg.cursorId = cursorId pg.build(expression, 0) return pg.jumpCommand } @@ -22,6 +23,10 @@ type predicateGenerator struct { // jumpCommand is the command used to make the jump. The command can be // accessed to defer setting the jump address. jumpCommand vm.JumpCommand + // cursorId is the id of the cursor for the table in the associated query. + // This will need to be enhanced at some point to support more than one + // alias in a predicate, but is fine for now. + cursorId int } func (p *predicateGenerator) build(e compiler.Expr, level int) (int, error) { @@ -183,14 +188,14 @@ func (p *predicateGenerator) valueRegisterFor(ce *compiler.ColumnRef) int { if ce.IsPrimaryKey { r := p.getNextRegister() p.plan.commands = append(p.plan.commands, &vm.RowIdCmd{ - P1: p.plan.cursorId, + P1: p.cursorId, P2: r, }) return r } r := p.getNextRegister() p.plan.commands = append(p.plan.commands, &vm.ColumnCmd{ - P1: p.plan.cursorId, + P1: p.cursorId, P2: ce.ColIdx, P3: r, }) return r diff --git a/planner/result_generator.go b/planner/result_generator.go index 89ff355..0e0a93f 100644 --- a/planner/result_generator.go +++ b/planner/result_generator.go @@ -7,10 +7,11 @@ import ( // generateExpressionTo takes the context of the plan and generates commands // that land the result of the given expr in the toRegister. 
-func generateExpressionTo(plan *QueryPlan, expr compiler.Expr, toRegister int) { +func generateExpressionTo(plan *QueryPlan, expr compiler.Expr, toRegister int, cursorId int) { rg := &resultExprGenerator{} rg.plan = plan rg.outputRegister = toRegister + rg.cursorId = cursorId rg.build(expr, 0) } @@ -19,6 +20,10 @@ type resultExprGenerator struct { plan *QueryPlan // outputRegister is the target register for the result of the expression. outputRegister int + // cursorId is the id of the cursor for the table in the associated query. + // This will need to be enhanced at some point to support more than one + // aliased column in the results, but is fine for now. + cursorId int } func (e *resultExprGenerator) build(root compiler.Expr, level int) int { @@ -72,11 +77,11 @@ func (e *resultExprGenerator) build(root compiler.Expr, level int) int { case *compiler.ColumnRef: r := e.getNextRegister(level) if n.IsPrimaryKey { - e.plan.commands = append(e.plan.commands, &vm.RowIdCmd{P1: e.plan.cursorId, P2: r}) + e.plan.commands = append(e.plan.commands, &vm.RowIdCmd{P1: e.cursorId, P2: r}) } else { e.plan.commands = append( e.plan.commands, - &vm.ColumnCmd{P1: e.plan.cursorId, P2: n.ColIdx, P3: r}, + &vm.ColumnCmd{P1: e.cursorId, P2: n.ColIdx, P3: r}, ) } return r diff --git a/planner/select.go b/planner/select.go index 45967f4..71f7e7c 100644 --- a/planner/select.go +++ b/planner/select.go @@ -88,12 +88,16 @@ func (p *selectPlanner) QueryPlan() (*QueryPlan, error) { if tableName == "" { return nil, errors.New("must have from for COUNT") } - cn := &countNode{projection: projections[0]} + cn := &countNode{ + projection: projections[0], + rootPageNumber: rootPageNumber, + tableName: tableName, + cursorId: 1, + } plan := newQueryPlan( cn, p.stmt.ExplainQueryPlan, transactionTypeRead, - rootPageNumber, ) cn.plan = plan p.queryPlan = plan @@ -104,8 +108,11 @@ func (p *selectPlanner) QueryPlan() (*QueryPlan, error) { if tableName == "" { tt = transactionTypeNone } - projectNode := 
&projectNode{projections: projections} - plan := newQueryPlan(projectNode, p.stmt.ExplainQueryPlan, tt, rootPageNumber) + projectNode := &projectNode{ + projections: projections, + cursorId: 1, + } + plan := newQueryPlan(projectNode, p.stmt.ExplainQueryPlan, tt) projectNode.plan = plan if p.stmt.Where != nil { cev := &catalogExprVisitor{} @@ -115,6 +122,7 @@ func (p *selectPlanner) QueryPlan() (*QueryPlan, error) { parent: projectNode, plan: plan, predicate: p.stmt.Where, + cursorId: 1, } projectNode.child = filterNode if tableName == "" { @@ -125,7 +133,10 @@ func (p *selectPlanner) QueryPlan() (*QueryPlan, error) { constNode.parent = filterNode } else { scanNode := &scanNode{ - plan: plan, + plan: plan, + tableName: tableName, + rootPageNumber: rootPageNumber, + cursorId: 1, } filterNode.child = scanNode scanNode.parent = filterNode @@ -139,7 +150,10 @@ func (p *selectPlanner) QueryPlan() (*QueryPlan, error) { constNode.parent = projectNode } else { scanNode := &scanNode{ - plan: plan, + plan: plan, + tableName: tableName, + rootPageNumber: rootPageNumber, + cursorId: 1, } projectNode.child = scanNode scanNode.parent = projectNode diff --git a/planner/select_test.go b/planner/select_test.go index 8d32de0..87adb7a 100644 --- a/planner/select_test.go +++ b/planner/select_test.go @@ -58,15 +58,15 @@ func TestSelectPlan(t *testing.T) { { description: "StarWithPrimaryKey", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 7}, - &vm.RewindCmd{P1: 1, P2: 6}, + &vm.InitCmd{P2: 8}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 7}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 2}, &vm.ResultRowCmd{P1: 1, P2: 2}, - &vm.NextCmd{P1: 1, P2: 2}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ @@ -88,15 +88,15 @@ func TestSelectPlan(t *testing.T) { { description: "StarWithoutPrimaryKey", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 7}, - 
&vm.RewindCmd{P1: 1, P2: 6}, + &vm.InitCmd{P2: 8}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 7}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 1}, &vm.ColumnCmd{P1: 1, P2: 1, P3: 2}, &vm.ResultRowCmd{P1: 1, P2: 2}, - &vm.NextCmd{P1: 1, P2: 2}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ @@ -125,16 +125,16 @@ func TestSelectPlan(t *testing.T) { }, }, expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 8}, - &vm.RewindCmd{P1: 1, P2: 7}, + &vm.InitCmd{P2: 9}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 8}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 1}, &vm.RowIdCmd{P1: 1, P2: 2}, &vm.ColumnCmd{P1: 1, P2: 1, P3: 3}, &vm.ResultRowCmd{P1: 1, P2: 3}, - &vm.NextCmd{P1: 1, P2: 2}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, &vm.GotoCmd{P2: 1}, }, mockCatalogSetup: func(m *mockSelectCatalog) *mockSelectCatalog { @@ -169,15 +169,15 @@ func TestSelectPlan(t *testing.T) { }, }, expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 7}, - &vm.RewindCmd{P1: 1, P2: 6}, + &vm.InitCmd{P2: 8}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 7}, &vm.RowIdCmd{P1: 1, P2: 2}, &vm.AddCmd{P1: 2, P2: 3, P3: 1}, &vm.ResultRowCmd{P1: 1, P2: 1}, - &vm.NextCmd{P1: 1, P2: 2}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, &vm.IntegerCmd{P1: 10, P2: 3}, &vm.GotoCmd{P2: 1}, }, @@ -191,15 +191,15 @@ func TestSelectPlan(t *testing.T) { { description: "AllTable", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 7}, - &vm.RewindCmd{P1: 1, P2: 6}, + &vm.InitCmd{P2: 8}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 7}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 2}, &vm.ResultRowCmd{P1: 1, P2: 2}, - &vm.NextCmd{P1: 1, P2: 2}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, 
&vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ @@ -226,14 +226,14 @@ func TestSelectPlan(t *testing.T) { { description: "SpecificColumnPrimaryKeyMiddleOrdinal", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 6}, - &vm.RewindCmd{P1: 1, P2: 5}, + &vm.InitCmd{P2: 7}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 6}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ResultRowCmd{P1: 1, P2: 1}, - &vm.NextCmd{P1: 1, P2: 2}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ @@ -263,15 +263,15 @@ func TestSelectPlan(t *testing.T) { { description: "SpecificColumns", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 7}, - &vm.RewindCmd{P1: 1, P2: 6}, + &vm.InitCmd{P2: 8}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 7}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ColumnCmd{P1: 1, P2: 1, P3: 2}, &vm.ResultRowCmd{P1: 1, P2: 2}, - &vm.NextCmd{P1: 1, P2: 2}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ @@ -306,12 +306,12 @@ func TestSelectPlan(t *testing.T) { { description: "JustCountAggregate", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 4}, + &vm.InitCmd{P2: 5}, + &vm.OpenReadCmd{P1: 1, P2: 2}, &vm.CountCmd{P1: 1, P2: 1}, &vm.ResultRowCmd{P1: 1, P2: 1}, &vm.HaltCmd{}, &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, &vm.GotoCmd{P2: 1}, }, ast: &compiler.SelectStmt{ @@ -387,18 +387,18 @@ func TestSelectPlan(t *testing.T) { }, }, { - description: "with where clause", + description: "WithWhereClause", expectedCommands: []vm.Command{ - &vm.InitCmd{P2: 8}, - &vm.RewindCmd{P1: 1, P2: 7}, + &vm.InitCmd{P2: 9}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 8}, &vm.RowIdCmd{P1: 1, P2: 1}, - &vm.NotEqualCmd{P1: 1, P2: 6, P3: 2}, + &vm.NotEqualCmd{P1: 1, P2: 7, P3: 2}, &vm.RowIdCmd{P1: 1, P2: 4}, &vm.ResultRowCmd{P1: 4, P2: 1}, - 
&vm.NextCmd{P1: 1, P2: 2}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, &vm.TransactionCmd{P1: 0}, - &vm.OpenReadCmd{P1: 1, P2: 2}, &vm.IntegerCmd{P1: 1, P2: 2}, &vm.GotoCmd{P2: 1}, }, diff --git a/planner/update.go b/planner/update.go index 9bae097..8430710 100644 --- a/planner/update.go +++ b/planner/update.go @@ -45,12 +45,16 @@ func (p *updatePlanner) QueryPlan() (*QueryPlan, error) { if err != nil { return nil, errTableNotExist } - updateNode := &updateNode{updateExprs: []compiler.Expr{}} + updateNode := &updateNode{ + updateExprs: []compiler.Expr{}, + tableName: p.stmt.TableName, + rootPageNumber: rootPage, + cursorId: 1, + } logicalPlan := newQueryPlan( updateNode, p.stmt.ExplainQueryPlan, transactionTypeWrite, - rootPage, ) updateNode.plan = logicalPlan p.queryPlan = updateNode @@ -69,7 +73,11 @@ func (p *updatePlanner) QueryPlan() (*QueryPlan, error) { } scanNode := &scanNode{ - plan: logicalPlan, + plan: logicalPlan, + tableName: p.stmt.TableName, + rootPageNumber: rootPage, + cursorId: 1, + isWriteCursor: true, } if p.stmt.Predicate != nil { cev := &catalogExprVisitor{} @@ -80,6 +88,7 @@ func (p *updatePlanner) QueryPlan() (*QueryPlan, error) { predicate: p.stmt.Predicate, parent: updateNode, child: scanNode, + cursorId: 1, } updateNode.child = filterNode scanNode.parent = filterNode diff --git a/planner/update_test.go b/planner/update_test.go index 10d13e1..7e45980 100644 --- a/planner/update_test.go +++ b/planner/update_test.go @@ -55,18 +55,18 @@ func TestUpdate(t *testing.T) { }, } expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 10}, - &vm.RewindCmd{P1: 1, P2: 9}, + &vm.InitCmd{P2: 11}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 10}, &vm.RowIdCmd{P1: 1, P2: 1}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 2}, &vm.CopyCmd{P1: 4, P2: 3}, &vm.MakeRecordCmd{P1: 2, P2: 2, P3: 5}, &vm.DeleteCmd{P1: 1}, &vm.InsertCmd{P1: 1, P2: 5, P3: 1}, - &vm.NextCmd{P1: 1, P2: 2}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, - 
&vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.IntegerCmd{P1: 1, P2: 4}, &vm.GotoCmd{P2: 1}, } @@ -101,20 +101,20 @@ func TestUpdateWithWhere(t *testing.T) { }, } expectedCommands := []vm.Command{ - &vm.InitCmd{P2: 12}, - &vm.RewindCmd{P1: 1, P2: 11}, + &vm.InitCmd{P2: 13}, + &vm.OpenWriteCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 12}, &vm.RowIdCmd{P1: 1, P2: 1}, - &vm.NotEqualCmd{P1: 1, P2: 10, P3: 2}, + &vm.NotEqualCmd{P1: 1, P2: 11, P3: 2}, &vm.RowIdCmd{P1: 1, P2: 4}, &vm.ColumnCmd{P1: 1, P2: 0, P3: 5}, &vm.CopyCmd{P1: 2, P2: 6}, &vm.MakeRecordCmd{P1: 5, P2: 2, P3: 7}, &vm.DeleteCmd{P1: 1}, &vm.InsertCmd{P1: 1, P2: 7, P3: 4}, - &vm.NextCmd{P1: 1, P2: 2}, + &vm.NextCmd{P1: 1, P2: 3}, &vm.HaltCmd{}, &vm.TransactionCmd{P2: 1}, - &vm.OpenWriteCmd{P1: 1, P2: 2}, &vm.IntegerCmd{P1: 1, P2: 2}, &vm.GotoCmd{P2: 1}, } From be274496bc50cfe8e06d115e80e7bbf604a5f336 Mon Sep 17 00:00:00 2001 From: Colton Date: Sat, 3 Jan 2026 02:10:08 -0700 Subject: [PATCH 17/17] insert supports expressions --- compiler/parser.go | 25 ++---------- compiler/parser_test.go | 88 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 85 insertions(+), 28 deletions(-) diff --git a/compiler/parser.go b/compiler/parser.go index 3332ecb..7914349 100644 --- a/compiler/parser.go +++ b/compiler/parser.go @@ -363,28 +363,11 @@ func (p *parser) parseValue(stmt *InsertStmt, valueIdx int) (*InsertStmt, error) } stmt.ColValues = append(stmt.ColValues, []Expr{}) for { - v := p.nextNonSpace() - if v.tokenType != tkNumeric && v.tokenType != tkLiteral && v.tokenType != tkParam { - return nil, fmt.Errorf(literalErr, v.value) - } - if v.tokenType == tkLiteral { - stmt.ColValues[valueIdx] = append( - stmt.ColValues[valueIdx], - &StringLit{ - Value: v.value, - }, - ) - } else if v.tokenType == tkParam { - variableExpr := &Variable{Position: p.paramCount} - p.paramCount += 1 - stmt.ColValues[valueIdx] = append(stmt.ColValues[valueIdx], variableExpr) - } else { - intValue, err := strconv.Atoi(v.value) - if err != nil { - 
return nil, fmt.Errorf("failed to convert %v to integer", v.value) - } - stmt.ColValues[valueIdx] = append(stmt.ColValues[valueIdx], &IntLit{Value: intValue}) + exp, err := p.parseExpression(0) + if err != nil { + return nil, err } + stmt.ColValues[valueIdx] = append(stmt.ColValues[valueIdx], exp) sep := p.nextNonSpace() if sep.value != "," { if sep.value == ")" { diff --git a/compiler/parser_test.go b/compiler/parser_test.go index 60f78c8..49387e4 100644 --- a/compiler/parser_test.go +++ b/compiler/parser_test.go @@ -345,6 +345,7 @@ func TestParseCreate(t *testing.T) { } type insertTestCase struct { + name string tokens []token expected Stmt } @@ -352,6 +353,7 @@ type insertTestCase struct { func TestParseInsert(t *testing.T) { cases := []insertTestCase{ { + name: "ManyValues", tokens: []token{ {tkKeyword, "INSERT"}, {tkWhitespace, " "}, @@ -448,15 +450,87 @@ func TestParseInsert(t *testing.T) { }, }, }, + { + name: "WithExpressions", + tokens: []token{ + {tkKeyword, "INSERT"}, + {tkWhitespace, " "}, + {tkKeyword, "INTO"}, + {tkWhitespace, " "}, + {tkIdentifier, "foo"}, + {tkWhitespace, " "}, + {tkSeparator, "("}, + {tkIdentifier, "id"}, + {tkSeparator, ","}, + {tkWhitespace, " "}, + {tkIdentifier, "age"}, + {tkSeparator, ")"}, + {tkWhitespace, " "}, + {tkKeyword, "VALUES"}, + {tkWhitespace, " "}, + {tkSeparator, "("}, + {tkNumeric, "1"}, + {tkWhitespace, " "}, + {tkOperator, "+"}, + {tkWhitespace, " "}, + {tkNumeric, "1"}, + {tkSeparator, ","}, + {tkWhitespace, " "}, + {tkNumeric, "2"}, + {tkWhitespace, " "}, + {tkOperator, "-"}, + {tkWhitespace, " "}, + {tkNumeric, "1"}, + {tkSeparator, ")"}, + {tkSeparator, ","}, + {tkWhitespace, " "}, + {tkSeparator, "("}, + {tkNumeric, "3"}, + {tkSeparator, ","}, + {tkWhitespace, " "}, + {tkNumeric, "4"}, + {tkSeparator, ")"}, + }, + expected: &InsertStmt{ + StmtBase: &StmtBase{ + Explain: false, + }, + TableName: "foo", + ColNames: []string{ + "id", + "age", + }, + ColValues: [][]Expr{ + { + &BinaryExpr{ + Operator: "+", + 
Left: &IntLit{Value: 1}, + Right: &IntLit{Value: 1}, + }, + &BinaryExpr{ + Operator: "-", + Left: &IntLit{Value: 2}, + Right: &IntLit{Value: 1}, + }, + }, + { + &IntLit{Value: 3}, + &IntLit{Value: 4}, + }, + }, + }, + }, } for _, c := range cases { - ret, err := NewParser(c.tokens).Parse() - if err != nil { - t.Errorf("expected no err got err %s", err) - } - if !reflect.DeepEqual(ret, c.expected) { - t.Errorf("expected %#v got %#v", c.expected, ret) - } + t.Run(c.name, func(t *testing.T) { + ret, err := NewParser(c.tokens).Parse() + if err != nil { + t.Errorf("expected no err got err %s", err) + } + if !reflect.DeepEqual(ret, c.expected) { + t.Errorf("expected %#v got %#v", c.expected, ret) + } + }) } }