Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2721,9 +2721,11 @@ type InsertStatement struct {
InsertOrIgnore Pos // position of IGNORE keyword after INSERT OR
Into Pos // position of INTO keyword

Table *Ident // table name
As Pos // position of AS keyword
Alias *Ident // optional alias
Schema *Ident // optional schema name
Dot Pos // position of DOT between schema and table name
Table *Ident // table name
As Pos // position of AS keyword
Alias *Ident // optional alias

ColumnsLparen Pos // position of column list left paren
Columns []*Ident // optional column list
Expand All @@ -2748,6 +2750,7 @@ func (s *InsertStatement) Clone() *InsertStatement {
}
other := *s
other.WithClause = s.WithClause.Clone()
other.Schema = s.Schema.Clone()
other.Table = s.Table.Clone()
other.Alias = s.Alias.Clone()
other.Columns = cloneIdents(s.Columns)
Expand Down Expand Up @@ -2783,7 +2786,12 @@ func (s *InsertStatement) String() string {
}
}

fmt.Fprintf(&buf, " INTO %s", s.Table.String())
buf.WriteString(" INTO ")
if s.Schema != nil {
buf.WriteString(s.Schema.String())
buf.WriteString(".")
}
buf.WriteString(s.Table.String())
if s.Alias != nil {
fmt.Fprintf(&buf, " AS %s", s.Alias.String())
}
Expand Down
12 changes: 10 additions & 2 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -1475,10 +1475,18 @@ func (p *Parser) parseInsertStatement(inTrigger bool, withClause *WithClause) (_
}
stmt.Into, _, _ = p.scan()

// Parse table name & optional alias.
// Parse table name & optional schema & alias.
if stmt.Table, err = p.parseIdent("table name"); err != nil {
return &stmt, err
}
// Check for schema.table syntax
if p.peek() == DOT {
stmt.Schema = stmt.Table
stmt.Dot, _, _ = p.scan()
if stmt.Table, err = p.parseIdent("table name"); err != nil {
return &stmt, err
}
}
if p.peek() == AS {
stmt.As, _, _ = p.scan()
if stmt.Alias, err = p.parseIdent("alias"); err != nil {
Expand Down Expand Up @@ -1731,7 +1739,7 @@ func (p *Parser) parseUpdateStatement(inTrigger bool, withClause *WithClause) (_
return nil, p.errorExpected(p.pos, p.tok, "table name")
}
ident, _ := p.parseIdent("table name")
if stmt.Table, err = p.parseQualifiedTableName(ident, false, false, false); err != nil {
if stmt.Table, err = p.parseQualifiedTableName(ident, true, false, false); err != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot -- why did you make this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the first parameter (schemaOK) from false to true to allow schema-qualified table names in UPDATE statements. The parseQualifiedTableName function was rejecting the dot separator with "expected unqualified table name" error when schemaOK=false. Since UPDATE already uses *QualifiedTableName in its AST (which supports schema.table), this was just enabling that existing capability in the parser.

return &stmt, err
}

Expand Down
44 changes: 44 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3685,6 +3685,34 @@ func TestParser_ParseStatement(t *testing.T) {
},
})

// Test schema-qualified table name
AssertParseStatement(t, `INSERT INTO main.tableau VALUES (1, 2), (3, 4)`, &sql.InsertStatement{
Insert: pos(0),
Into: pos(7),
Schema: &sql.Ident{NamePos: pos(12), Name: "main"},
Dot: pos(16),
Table: &sql.Ident{NamePos: pos(17), Name: "tableau"},
Values: pos(25),
ValueLists: []*sql.ExprList{
{
Lparen: pos(32),
Exprs: []sql.Expr{
&sql.NumberLit{ValuePos: pos(33), Value: "1"},
&sql.NumberLit{ValuePos: pos(36), Value: "2"},
},
Rparen: pos(37),
},
{
Lparen: pos(40),
Exprs: []sql.Expr{
&sql.NumberLit{ValuePos: pos(41), Value: "3"},
&sql.NumberLit{ValuePos: pos(44), Value: "4"},
},
Rparen: pos(45),
},
},
})

AssertParseStatementError(t, `INSERT`, `1:6: expected INTO, found 'EOF'`)
AssertParseStatementError(t, `INSERT OR`, `1:9: expected ROLLBACK, REPLACE, ABORT, FAIL, or IGNORE, found 'EOF'`)
AssertParseStatementError(t, `INSERT INTO`, `1:11: expected table name, found 'EOF'`)
Expand Down Expand Up @@ -3856,6 +3884,22 @@ func TestParser_ParseStatement(t *testing.T) {
}},
})

// Test schema-qualified table name
AssertParseStatement(t, `UPDATE main.tableau SET n2=n1`, &sql.UpdateStatement{
Update: pos(0),
Table: &sql.QualifiedTableName{
Schema: &sql.Ident{NamePos: pos(7), Name: "main"},
Dot: pos(11),
Name: &sql.Ident{NamePos: pos(12), Name: "tableau"},
},
Set: pos(20),
Assignments: []*sql.Assignment{{
Columns: []*sql.Ident{{NamePos: pos(24), Name: "n2"}},
Eq: pos(26),
Expr: &sql.Ident{NamePos: pos(27), Name: "n1"},
}},
})

AssertParseStatementError(t, `UPDATE`, `1:6: expected table name, found 'EOF'`)
AssertParseStatementError(t, `UPDATE OR`, `1:9: expected ROLLBACK, REPLACE, ABORT, FAIL, or IGNORE, found 'EOF'`)
AssertParseStatementError(t, `UPDATE tbl`, `1:10: expected SET, found 'EOF'`)
Expand Down