diff --git a/compiler/ast.go b/compiler/ast.go
index e1b9bf8..ae1220a 100644
--- a/compiler/ast.go
+++ b/compiler/ast.go
@@ -1,6 +1,10 @@
 package compiler
 
-import "github.com/chirst/cdb/catalog"
+import (
+	"fmt"
+
+	"github.com/chirst/cdb/catalog"
+)
 
 // ast (Abstract Syntax Tree) defines a data structure representing a SQL
 // program. This data structure is generated from the parser. This data
@@ -92,6 +96,9 @@ type Expr interface {
 	// BreadthWalk implements the visitor pattern for a in-order breadth first
 	// walk.
 	BreadthWalk(v ExprVisitor)
+	// Print returns a human readable rendering of the expression, such as the
+	// one shown when a plan node prints its predicate.
+	Print() string
 }
 
 // BinaryExpr is for an expression with two operands.
@@ -107,6 +114,12 @@ func (be *BinaryExpr) BreadthWalk(v ExprVisitor) {
 	be.Right.BreadthWalk(v)
 }
 
+// Print renders the expression as the left operand, the operator and the
+// right operand separated by single spaces.
+func (be *BinaryExpr) Print() string {
+	return fmt.Sprintf("%s %s %s", be.Left.Print(), be.Operator, be.Right.Print())
+}
+
 // UnaryExpr is an expression with one operand.
 type UnaryExpr struct {
 	Operator string
@@ -137,6 +150,15 @@ func (cr *ColumnRef) BreadthWalk(v ExprVisitor) {
 	v.VisitColumnRefExpr(cr)
 }
 
+// Print returns the column name, followed by " PRIMARY KEY" when the column
+// is the primary key.
+func (cr *ColumnRef) Print() string {
+	if cr.IsPrimaryKey {
+		return fmt.Sprintf("%s PRIMARY KEY", cr.Column)
+	}
+	return cr.Column
+}
+
 // IntLit is an expression that is a literal integer such as "1".
 type IntLit struct {
 	Value int
@@ -146,6 +168,12 @@ func (il *IntLit) BreadthWalk(v ExprVisitor) {
 	v.VisitIntLit(il)
 }
 
+// Print returns the placeholder "?" rather than the literal's value, so
+// every IntLit prints the same way.
+func (il *IntLit) Print() string {
+	return "?"
+}
+
 // StringLit is an expression that is a literal string such as "'asdf'".
 type StringLit struct {
 	Value string
@@ -155,6 +183,12 @@ func (sl *StringLit) BreadthWalk(v ExprVisitor) {
 	v.VisitStringLit(sl)
 }
 
+// Print returns the placeholder "?" rather than the literal's value, so
+// every StringLit prints the same way.
+func (sl *StringLit) Print() string {
+	return "?"
+}
+
 type Variable struct {
 	// Position is a unique integer defining what order the variable appeared in
 	// the statement.
@@ -165,6 +199,12 @@ func (vi *Variable) BreadthWalk(v ExprVisitor) {
 	v.VisitVariable(vi)
 }
 
+// Print returns the placeholder "?" regardless of the variable's
+// Position.
+func (vi *Variable) Print() string {
+	return "?"
+}
+
 // FunctionExpr is an expression that represents a function.
 type FunctionExpr struct {
 	// FnType corresponds to the type of function. For example fnCount is for
@@ -180,3 +220,9 @@ const (
 func (f *FunctionExpr) BreadthWalk(v ExprVisitor) {
 	v.VisitFunctionExpr(f)
 }
+
+// Print returns the function's FnType as the whole rendering of the
+// expression.
+func (f *FunctionExpr) Print() string {
+	return f.FnType
+}
diff --git a/kv/kv.go b/kv/kv.go
index 61fdd21..4672ccf 100644
--- a/kv/kv.go
+++ b/kv/kv.go
@@ -190,6 +190,33 @@ func (c *Cursor) GotoLastRecord() bool {
 	return true
 }
 
+// GotoKey moves the cursor to the tuple stored under key and reports whether
+// such a tuple exists. It returns false when a page on the way down has no
+// value for key or when the leaf reached has no entry with exactly that key.
+func (c *Cursor) GotoKey(key []byte) bool {
+	candidatePage := c.pager.GetPage(c.rootPageNumber)
+	for !candidatePage.IsLeaf() {
+		v, exists := candidatePage.GetValue(key)
+		if !exists {
+			return false
+		}
+		// The value read from a non-leaf page is the number of the next page.
+		nextPageNumber := int(binary.LittleEndian.Uint32(v))
+		candidatePage = c.pager.GetPage(nextPageNumber)
+	}
+	// NOTE(review): the cursor moves to the leaf even when the key turns out
+	// to be absent, while currentTupleKey keeps its old value — confirm
+	// callers do not read the cursor after a false result.
+	c.moveToPage(candidatePage)
+	for _, e := range c.currentPage.GetEntries() {
+		if bytes.Equal(e.Key, key) {
+			c.currentTupleKey = e.Key
+			return true
+		}
+	}
+	return false
+}
+
 // GetKey returns the key of the current tuple.
func (c *Cursor) GetKey() []byte {
 	return c.currentTupleKey
diff --git a/planner/delete.go b/planner/delete.go
index 3d4d8dc..999217e 100644
--- a/planner/delete.go
+++ b/planner/delete.go
@@ -69,6 +69,7 @@ func (d *deletePlanner) QueryPlan() (*QueryPlan, error) {
 		deleteNode.child = sn
 		sn.parent = deleteNode
 	}
+	(&optimizer{}).optimizePlan(qp)
 	return qp, nil
 }
 
diff --git a/planner/delete_test.go b/planner/delete_test.go
index abb5150..cfecb45 100644
--- a/planner/delete_test.go
+++ b/planner/delete_test.go
@@ -81,13 +81,11 @@ func TestDelete(t *testing.T) {
 				},
 			},
 			expectedCommands: []vm.Command{
-				&vm.InitCmd{P2: 8},
+				&vm.InitCmd{P2: 6},
 				&vm.OpenWriteCmd{P1: 1, P2: 2},
-				&vm.RewindCmd{P1: 1, P2: 7},
-				&vm.RowIdCmd{P1: 1, P2: 1},
-				&vm.NotEqualCmd{P1: 1, P2: 6, P3: 2},
+				&vm.CopyCmd{P1: 2, P2: 1},
+				&vm.SeekRowId{P1: 1, P2: 5, P3: 1},
 				&vm.DeleteCmd{P1: 1},
-				&vm.NextCmd{P1: 1, P2: 3},
 				&vm.HaltCmd{},
 				&vm.TransactionCmd{P2: 1},
 				&vm.IntegerCmd{P1: 1, P2: 2},
diff --git a/planner/generator.go b/planner/generator.go
index 2f600c1..00947e8 100644
--- a/planner/generator.go
+++ b/planner/generator.go
@@ -227,3 +227,39 @@ func (d *deleteNode) produce() {
 func (n *joinNode) produce() {}
 
 func (n *joinNode) consume() {}
+
+// produce emits the seek by handing straight off to consume; a seek has no
+// child node to ask for rows.
+func (s *seekNode) produce() {
+	s.consume()
+}
+
+// consume opens the cursor, evaluates the predicate into a fresh register
+// and emits a SeekRowId on it. The parent's commands follow the seek, and
+// the seek's miss address (P2) is patched to land just past them.
+func (s *seekNode) consume() {
+	switch {
+	case s.isWriteCursor:
+		s.plan.commands = append(s.plan.commands, &vm.OpenWriteCmd{
+			P1: s.cursorId,
+			P2: s.rootPageNumber,
+		})
+	default:
+		s.plan.commands = append(s.plan.commands, &vm.OpenReadCmd{
+			P1: s.cursorId,
+			P2: s.rootPageNumber,
+		})
+	}
+
+	// Reserve the next free register to hold the row id being sought.
+	keyRegister := s.plan.freeRegister
+	s.plan.freeRegister++
+	generateExpressionTo(s.plan, s.predicate, keyRegister, s.cursorId)
+
+	seek := &vm.SeekRowId{P1: s.cursorId, P3: keyRegister}
+	s.plan.commands = append(s.plan.commands, seek)
+	s.parent.consume()
+	// Patched last because the address is only known once the parent has
+	// appended its commands.
+	seek.P2 = len(s.plan.commands)
+}
diff --git a/planner/node.go b/planner/node.go
index 5da2390..f5f0e99 100644
--- a/planner/node.go
+++ 
b/planner/node.go
@@ -21,6 +21,10 @@ type logicalNode interface {
 	produce()
 	// consume works with produce.
 	consume()
+	// setChildren replaces the node's children with n, given in the same
+	// order children returns them. Nodes without children ignore n. Call
+	// children first to learn how many children a node expects.
+	setChildren(n ...logicalNode)
 }
 
 // TODO joinNode is unused, but remains as a prototype binary operation node.
@@ -42,6 +46,13 @@ func (j *joinNode) children() []logicalNode {
 	return []logicalNode{j.left, j.right}
 }
 
+// setChildren takes the left child from n[0] and the right child from n[1].
+// It panics when fewer than two nodes are passed.
+func (j *joinNode) setChildren(n ...logicalNode) {
+	j.left = n[0]
+	j.right = n[1]
+}
+
 // createNode represents a operation to create an object in the system catalog.
 // For example a table, index, or trigger.
 type createNode struct {
@@ -81,6 +92,9 @@ func (c *createNode) children() []logicalNode {
 	return []logicalNode{}
 }
 
+// setChildren is a no-op because createNode has no children.
+func (c *createNode) setChildren(n ...logicalNode) {}
+
 // insertNode represents an insert operation.
 type insertNode struct {
 	plan *QueryPlan
@@ -114,6 +128,9 @@ func (i *insertNode) children() []logicalNode {
 	return []logicalNode{}
 }
 
+// setChildren is a no-op because insertNode has no children.
+func (i *insertNode) setChildren(n ...logicalNode) {}
+
 type countNode struct {
 	plan       *QueryPlan
 	projection projection
@@ -133,6 +150,9 @@ func (c *countNode) print() string {
 	return fmt.Sprintf("count table %s", c.tableName)
 }
 
+// setChildren ignores n; it is a no-op for countNode.
+func (c *countNode) setChildren(n ...logicalNode) {}
+
 type constantNode struct {
 	parent logicalNode
 	plan   *QueryPlan
@@ -146,6 +166,9 @@ func (c *constantNode) children() []logicalNode {
 	return []logicalNode{}
 }
 
+// setChildren is a no-op because constantNode has no children.
+func (c *constantNode) setChildren(n ...logicalNode) {}
+
 type projection struct {
 	expr compiler.Expr
 	// alias is the alias of the projection or no alias for the zero value.
@@ -170,6 +193,12 @@ func (p *projectNode) children() []logicalNode {
 	return []logicalNode{p.child}
 }
 
+// setChildren makes n[0] the node's only child. It panics when n is
+// empty.
+func (p *projectNode) setChildren(n ...logicalNode) {
+	p.child = n[0]
+}
+
 type scanNode struct {
 	parent logicalNode
 	plan   *QueryPlan
@@ -191,6 +220,41 @@ func (s *scanNode) children() []logicalNode {
 	return []logicalNode{}
 }
 
+// setChildren is a no-op because scanNode has no children.
+func (s *scanNode) setChildren(n ...logicalNode) {}
+
+// seekNode represents a direct lookup of a single row by key. The optimizer
+// substitutes it for a filter over a scan when the predicate allows.
+type seekNode struct {
+	parent logicalNode
+	plan   *QueryPlan
+	// tableName is the name of the table being searched.
+	tableName string
+	// rootPageNumber is the root page number of the table being searched.
+	rootPageNumber int
+	// cursorId is the id of the cursor associated with the search.
+	cursorId int
+	// isWriteCursor determines whether or not the cursor is for read or write.
+	isWriteCursor bool
+	// fullPredicate is the entire expression this node matches.
+	fullPredicate compiler.Expr
+	// predicate is a subset of fullPredicate usually excluding the columnRef.
+	predicate compiler.Expr
+}
+
+// print describes the seek and the full predicate it was built from.
+func (s *seekNode) print() string {
+	return fmt.Sprintf("seek table %s (%s)", s.tableName, s.fullPredicate.Print())
+}
+
+// children returns no nodes; a seek is a leaf of the plan.
+func (s *seekNode) children() []logicalNode {
+	return []logicalNode{}
+}
+
+// setChildren is a no-op because seekNode has no children.
+func (s *seekNode) setChildren(n ...logicalNode) {}
+
 type filterNode struct {
 	child  logicalNode
 	parent logicalNode
@@ -203,13 +267,19 @@ type filterNode struct {
 }
 
 func (f *filterNode) print() string {
-	return "filter"
+	return "filter (" + f.predicate.Print() + ")"
 }
 
 func (f *filterNode) children() []logicalNode {
 	return []logicalNode{f.child}
 }
 
+// setChildren makes n[0] the node's only child. It panics when n is
+// empty.
+func (f *filterNode) setChildren(n ...logicalNode) {
+	f.child = n[0]
+}
+
 type updateNode struct {
 	child logicalNode
 	plan  *QueryPlan
@@ -240,6 +310,12 @@ func (u *updateNode) children() []logicalNode {
 	return []logicalNode{u.child}
 }
 
+// setChildren makes n[0] the node's only child. It panics when n is
+// empty.
+func (u *updateNode) setChildren(n ...logicalNode) {
+	u.child = n[0]
+}
+
 type deleteNode struct {
 	child logicalNode
 	plan  *QueryPlan
@@ -254,3 +330,9 @@ func (d *deleteNode) print() string {
 func (d *deleteNode) children() []logicalNode {
 	return []logicalNode{d.child}
 }
+
+// setChildren makes n[0] the node's only child. It panics when n is
+// empty.
+func (d *deleteNode) setChildren(n ...logicalNode) {
+	d.child = n[0]
+}
diff --git a/planner/optimizer.go b/planner/optimizer.go
new file mode 100644
index 0000000..586356c
--- /dev/null
+++ b/planner/optimizer.go
@@ -0,0 +1,72 @@
+package planner
+
+import "github.com/chirst/cdb/compiler"
+
+// optimizer rewrites a logical query plan into a cheaper equivalent.
+type optimizer struct{}
+
+// optimizePlan replaces a filter over a table scan with a seek when the
+// filter's predicate compares the primary key to a constant for equality.
+// Only the first child of the plan root is considered; any other plan shape
+// is left untouched.
+func (o *optimizer) optimizePlan(plan *QueryPlan) {
+	rootChildren := plan.root.children()
+	if len(rootChildren) == 0 {
+		return
+	}
+	fn, ok := rootChildren[0].(*filterNode)
+	if !ok {
+		return
+	}
+	sn, ok := fn.child.(*scanNode)
+	if !ok {
+		return
+	}
+	rowExpr := o.canOpt(fn.predicate)
+	if rowExpr == nil {
+		return
+	}
+	// If the filter can be moved to a seek then remove the filter and push the
+	// predicate into a seek.
+	seekN := &seekNode{
+		parent:         fn.parent,
+		plan:           sn.plan,
+		tableName:      sn.tableName,
+		rootPageNumber: sn.rootPageNumber,
+		cursorId:       sn.cursorId,
+		isWriteCursor:  sn.isWriteCursor,
+		fullPredicate:  fn.predicate,
+		predicate:      rowExpr,
+	}
+	seekN.parent.setChildren(seekN)
+}
+
+// canOpt reports whether predicate can be answered by a primary key seek. It
+// returns the constant side of a `primary key = constant` comparison, in
+// either operand order, or nil when the predicate has any other shape.
+func (*optimizer) canOpt(predicate compiler.Expr) compiler.Expr {
+	// The most basic optimization. Is the filter a primary key column ref equal
+	// to a constant of some sort.
+	be, ok := predicate.(*compiler.BinaryExpr)
+	if !ok || be.Operator != compiler.OpEq {
+		return nil
+	}
+	if lcr, ok := be.Left.(*compiler.ColumnRef); ok && lcr.IsPrimaryKey {
+		return constantOperand(be.Right)
+	}
+	// Mirrored form: the constant is on the left and the column on the right.
+	if rcr, ok := be.Right.(*compiler.ColumnRef); ok && rcr.IsPrimaryKey {
+		return constantOperand(be.Left)
+	}
+	return nil
+}
+
+// constantOperand returns expr when it is an int literal, string literal or
+// variable, and nil otherwise.
+func constantOperand(expr compiler.Expr) compiler.Expr {
+	switch expr.(type) {
+	case *compiler.IntLit, *compiler.StringLit, *compiler.Variable:
+		return expr
+	}
+	return nil
+}
diff --git a/planner/select.go b/planner/select.go
index 71f7e7c..93ab5e9 100644
--- a/planner/select.go
+++ b/planner/select.go
@@ -161,6 +161,7 @@ func (p *selectPlanner) QueryPlan() (*QueryPlan, error) {
 	}
 	p.queryPlan = plan
 	plan.root = projectNode
+	(&optimizer{}).optimizePlan(plan)
 	return plan, nil
 }
 
diff --git a/planner/select_test.go b/planner/select_test.go
index 87adb7a..12c7ce1 100644
--- a/planner/select_test.go
+++ b/planner/select_test.go
@@ -389,14 +389,12 @@ func TestSelectPlan(t *testing.T) {
 		{
 			description: "WithWhereClause",
 			expectedCommands: []vm.Command{
-				&vm.InitCmd{P2: 9},
+				&vm.InitCmd{P2: 7},
 				&vm.OpenReadCmd{P1: 1, P2: 2},
-				&vm.RewindCmd{P1: 1, P2: 8},
-				&vm.RowIdCmd{P1: 1, P2: 1},
-				&vm.NotEqualCmd{P1: 1, P2: 7, P3: 2},
-				&vm.RowIdCmd{P1: 1, P2: 4},
-				&vm.ResultRowCmd{P1: 4, P2: 1},
-				&vm.NextCmd{P1: 1, P2: 3},
+				&vm.CopyCmd{P1: 2, P2: 1},
+				&vm.SeekRowId{P1: 1, P2: 6, P3: 1},
+				&vm.RowIdCmd{P1: 1, P2: 3},
+				&vm.ResultRowCmd{P1: 3, P2: 1},
 				&vm.HaltCmd{},
 				&vm.TransactionCmd{P1: 0},
 				&vm.IntegerCmd{P1: 1, P2: 2},
@@ -487,3 +485,57 @@ func TestSelectTableDoesNotExist(t *testing.T) {
 		t.Fatalf("expected err: %s but got: %s", expectErr, err)
 	}
 }
+
+// TestUsePrimaryKeyIndex checks that an equality predicate on the primary key
+// is planned as a seek directly beneath the projection.
+func TestUsePrimaryKeyIndex(t *testing.T) {
+	ast := &compiler.SelectStmt{
+		StmtBase: &compiler.StmtBase{},
+		From: &compiler.From{
+			TableName: "foo",
+		},
+		ResultColumns: []compiler.ResultColumn{
+			{
+				All: true,
+			},
+		},
+		Where: &compiler.BinaryExpr{
+			Left:     &compiler.ColumnRef{Column: "id"},
+			Right:    &compiler.IntLit{Value: 1},
+			Operator: compiler.OpEq,
+		},
+	}
+	mockCatalog := &mockSelectCatalog{
+		primaryKeyColumnName: "id",
+	}
+	qp, err := NewSelect(mockCatalog, ast).QueryPlan()
+	if err != nil {
+		// Fatal because qp is not usable once planning has failed.
+		t.Fatalf("expected no err got err %s", err)
+	}
+	pn, ok := qp.root.(*projectNode)
+	if !ok {
+		t.Fatalf("expected project node but got %#v", qp.root)
+	}
+	seekN, ok := pn.child.(*seekNode)
+	if !ok {
+		t.Fatalf("expected seek node but got %#v", pn.child)
+	}
+	if seekN.parent != pn {
+		t.Error("expected parent to be pn")
+	}
+}
+
+// TestCanOptMirroredPredicate checks that a constant compared to the primary
+// key is recognised when the constant is the left operand.
+func TestCanOptMirroredPredicate(t *testing.T) {
+	lit := &compiler.IntLit{Value: 1}
+	predicate := &compiler.BinaryExpr{
+		Left:     lit,
+		Right:    &compiler.ColumnRef{Column: "id", IsPrimaryKey: true},
+		Operator: compiler.OpEq,
+	}
+	if got := (&optimizer{}).canOpt(predicate); got != lit {
+		t.Errorf("expected %#v but got %#v", lit, got)
+	}
+}
diff --git a/planner/update.go b/planner/update.go
index 8430710..d4fcec4 100644
--- a/planner/update.go
+++ b/planner/update.go
@@ -96,7 +96,7 @@ func (p *updatePlanner) QueryPlan() (*QueryPlan, error) {
 		scanNode.parent = updateNode
 		updateNode.child = scanNode
 	}
-
+	(&optimizer{}).optimizePlan(logicalPlan)
 	return logicalPlan, nil
 }
 
diff --git a/planner/update_test.go b/planner/update_test.go
index 7e45980..fcb556a 100644
--- a/planner/update_test.go
+++ b/planner/update_test.go
@@ -101,18 +101,16 @@ func TestUpdateWithWhere(t *testing.T) {
 		},
 	}
 	expectedCommands := []vm.Command{
-		&vm.InitCmd{P2: 13},
+		&vm.InitCmd{P2: 11},
 		&vm.OpenWriteCmd{P1: 1, P2: 2},
-		&vm.RewindCmd{P1: 1, P2: 12},
-		&vm.RowIdCmd{P1: 1, P2: 1},
-		&vm.NotEqualCmd{P1: 1, P2: 11, P3: 2},
-		&vm.RowIdCmd{P1: 1, P2: 4},
-		&vm.ColumnCmd{P1: 1, P2: 0, P3: 5},
-		&vm.CopyCmd{P1: 2, P2: 6},
-		&vm.MakeRecordCmd{P1: 5, P2: 2, P3: 7},
+		&vm.CopyCmd{P1: 2, P2: 1},
+		&vm.SeekRowId{P1: 1, P2: 10, P3: 1},
+		&vm.RowIdCmd{P1: 1, P2: 3},
+		&vm.ColumnCmd{P1: 1, P2: 0, P3: 4},
+		&vm.CopyCmd{P1: 2, P2: 5},
+		&vm.MakeRecordCmd{P1: 4, P2: 2, P3: 6},
 		&vm.DeleteCmd{P1: 1},
-		&vm.InsertCmd{P1: 1, P2: 7, P3: 4},
-		&vm.NextCmd{P1: 1, P2: 3},
+		&vm.InsertCmd{P1: 1, P2: 6, P3: 3},
 		&vm.HaltCmd{},
 		&vm.TransactionCmd{P2: 1},
 		&vm.IntegerCmd{P1: 1, P2: 2},
diff --git a/vm/vm.go b/vm/vm.go
index 9dd2ea9..c7bb59c
100644
--- a/vm/vm.go
+++ b/vm/vm.go
@@ -107,6 +107,7 @@ func (e *ExecutionPlan) Append(command Command) {
 // the system catalog Execute will return ErrVersionChanged in the ExecuteResult
 // err field so the plan can be recompiled.
 func (v *vm) Execute(plan *ExecutionPlan, parameters []any) *ExecuteResult {
+	parameters = v.normalizeParameters(parameters)
 	if plan.Explain {
 		return v.explain(plan)
 	}
@@ -150,6 +151,33 @@ func (v *vm) Execute(plan *ExecutionPlan, parameters []any) *ExecuteResult {
 	}
 }
 
+// normalizeParameters returns a copy of parameters in which every int16, int32
+// and int64 value is converted to int. This is because of things like a int vs
+// int64 producing different byte array values, which can for example cause
+// bugs with comparisons within the key value store. The caller's slice is
+// left untouched.
+func (v *vm) normalizeParameters(parameters []any) []any {
+	if len(parameters) == 0 {
+		return parameters
+	}
+	normalized := make([]any, len(parameters))
+	for i, parameter := range parameters {
+		switch t := parameter.(type) {
+		case int16:
+			normalized[i] = int(t)
+		case int32:
+			normalized[i] = int(t)
+		case int64:
+			// NOTE(review): int is 32 bits on some platforms, so this can
+			// truncate there — confirm that is acceptable.
+			normalized[i] = int(t)
+		default:
+			normalized[i] = parameter
+		}
+	}
+	return normalized
+}
+
 // resolveVarTypes takes unresolved var types in the result types and determines
 // their type from the passed in go type.
 func (v *vm) resolveVarTypes(plan *ExecutionPlan, parameters []any) error {
@@ -158,12 +186,6 @@ func (v *vm) resolveVarTypes(plan *ExecutionPlan, parameters []any) error {
 			switch parameters[plan.ResultTypes[i].VarPosition].(type) {
 			case int:
 				plan.ResultTypes[i].ID = catalog.CTInt
-			case int16:
-				plan.ResultTypes[i].ID = catalog.CTInt
-			case int32:
-				plan.ResultTypes[i].ID = catalog.CTInt
-			case int64:
-				plan.ResultTypes[i].ID = catalog.CTInt
 			case string:
 				plan.ResultTypes[i].ID = catalog.CTStr
 			default:
@@ -555,6 +577,34 @@ func (c *NewRowIdCmd) explain(addr int) []*string {
 	return formatExplain(addr, "NewRowID", c.P1, c.P2, c.P3, c.P4, c.P5, comment)
 }
 
+// SeekRowId moves cursor P1 to the row id in register P3. If there is no
+// record it jumps to P2.
+type SeekRowId cmd
+
+// execute encodes the value in register P3 with kv.EncodeKey and moves cursor
+// P1 to that key. An encoding failure is returned as the command's error.
+func (c *SeekRowId) execute(vm *vm, routine *routine) cmdRes {
+	key, err := kv.EncodeKey(routine.registers[c.P3])
+	if err != nil {
+		return cmdRes{
+			err: err,
+		}
+	}
+	found := routine.cursors[c.P1].GotoKey(key)
+	if !found {
+		return cmdRes{
+			nextAddress: c.P2,
+		}
+	}
+	return cmdRes{}
+}
+
+// explain formats the command and a description of its operands.
+func (c *SeekRowId) explain(addr int) []*string {
+	comment := fmt.Sprintf("Move cursor %d to row in register[%d] or jump to addr[%d]", c.P1, c.P3, c.P2)
+	return formatExplain(addr, "SeekRowID", c.P1, c.P2, c.P3, c.P4, c.P5, comment)
+}
+
 // InsertCmd write to cursor P1 with data in P2 and key in P3
 type InsertCmd cmd
 