diff --git a/pkg/sql/parser/cte.go b/pkg/sql/parser/cte.go index daef735..a769e1b 100644 --- a/pkg/sql/parser/cte.go +++ b/pkg/sql/parser/cte.go @@ -108,8 +108,8 @@ func (p *Parser) parseCommonTableExpr() (*ast.CommonTableExpr, error) { ) } - // Parse CTE name - if !p.isType(models.TokenTypeIdentifier) { + // Parse CTE name (supports double-quoted identifiers) + if !p.isIdentifier() { return nil, p.expectedError("CTE name") } name := p.currentToken.Literal @@ -121,7 +121,7 @@ func (p *Parser) parseCommonTableExpr() (*ast.CommonTableExpr, error) { p.advance() // Consume ( for { - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("column name") } columns = append(columns, p.currentToken.Literal) diff --git a/pkg/sql/parser/ddl.go b/pkg/sql/parser/ddl.go index 9f7ee15..83f2c30 100644 --- a/pkg/sql/parser/ddl.go +++ b/pkg/sql/parser/ddl.go @@ -98,8 +98,8 @@ func (p *Parser) parseCreateView(orReplace, temporary bool) (*ast.CreateViewStat stmt.IfNotExists = true } - // Parse view name - if !p.isType(models.TokenTypeIdentifier) { + // Parse view name (supports double-quoted identifiers for PostgreSQL compatibility) + if !p.isIdentifier() { return nil, p.expectedError("view name") } stmt.Name = p.currentToken.Literal @@ -109,7 +109,7 @@ func (p *Parser) parseCreateView(orReplace, temporary bool) (*ast.CreateViewStat if p.isType(models.TokenTypeLParen) { p.advance() // Consume ( for { - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("column name") } stmt.Columns = append(stmt.Columns, p.currentToken.Literal) @@ -202,8 +202,8 @@ func (p *Parser) parseCreateMaterializedView() (*ast.CreateMaterializedViewState stmt.IfNotExists = true } - // Parse view name - if !p.isType(models.TokenTypeIdentifier) { + // Parse view name (supports double-quoted identifiers for PostgreSQL compatibility) + if !p.isIdentifier() { return nil, p.expectedError("materialized view name") } stmt.Name = p.currentToken.Literal @@ -213,7 +213,7 @@ func (p *Parser) parseCreateMaterializedView() (*ast.CreateMaterializedViewState if p.isType(models.TokenTypeLParen) { p.advance() // Consume ( for { - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("column name") } stmt.Columns = append(stmt.Columns, p.currentToken.Literal) @@ -234,7 +234,7 @@ func (p *Parser) parseCreateMaterializedView() (*ast.CreateMaterializedViewState // Parse optional TABLESPACE if p.isTokenMatch("TABLESPACE") { p.advance() // Consume TABLESPACE - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("tablespace name") } stmt.Tablespace = p.currentToken.Literal @@ -307,8 +307,8 @@ func (p *Parser) parseCreateTable(temporary bool) (*ast.CreateTableStatement, er stmt.IfNotExists = true } - // Parse table name - if !p.isType(models.TokenTypeIdentifier) { + // Parse table name (supports double-quoted identifiers for PostgreSQL compatibility) + if !p.isIdentifier() { return nil, p.expectedError("table name") } stmt.Name = p.currentToken.Literal @@ -398,7 +398,7 @@ func (p *Parser) parseCreateTable(temporary bool) (*ast.CreateTableStatement, er if p.isType(models.TokenTypeEq) { p.advance() // Consume = } - if p.isType(models.TokenTypeIdentifier) || p.isType(models.TokenTypeString) { + if p.isIdentifier() || p.isType(models.TokenTypeString) { opt.Value = p.currentToken.Literal p.advance() } @@ -434,7 +434,7 @@ func (p *Parser) parsePartitionByClause() (*ast.PartitionBy, error) { // Parse column list for { - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("column name") } partitionBy.Columns = append(partitionBy.Columns, p.currentToken.Literal) @@ -466,8 +466,8 @@ func (p *Parser) parsePartitionDefinition() (*ast.PartitionDefinition, error) { } p.advance() // Consume PARTITION - // Parse partition name - if !p.isType(models.TokenTypeIdentifier) { + // Parse partition name (supports double-quoted identifiers) + if !p.isIdentifier() { return nil, p.expectedError("partition name") } partDef.Name = p.currentToken.Literal @@ -581,7 +581,7 @@ func (p *Parser) parsePartitionDefinition() (*ast.PartitionDefinition, error) { // Parse optional TABLESPACE if p.isTokenMatch("TABLESPACE") { p.advance() // Consume TABLESPACE - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("tablespace name") } partDef.Tablespace = p.currentToken.Literal @@ -611,8 +611,8 @@ func (p *Parser) parseCreateIndex(unique bool) (*ast.CreateIndexStatement, error stmt.IfNotExists = true } - // Parse index name - if !p.isType(models.TokenTypeIdentifier) { + // Parse index name (supports double-quoted identifiers) + if !p.isIdentifier() { return nil, p.expectedError("index name") } stmt.Name = p.currentToken.Literal @@ -624,8 +624,8 @@ func (p *Parser) parseCreateIndex(unique bool) (*ast.CreateIndexStatement, error } p.advance() // Consume ON - // Parse table name - if !p.isType(models.TokenTypeIdentifier) { + // Parse table name (supports double-quoted identifiers for PostgreSQL compatibility) + if !p.isIdentifier() { return nil, p.expectedError("table name") } stmt.Table = p.currentToken.Literal @@ -634,7 +634,7 @@ func (p *Parser) parseCreateIndex(unique bool) (*ast.CreateIndexStatement, error // Parse optional USING if p.isType(models.TokenTypeUsing) { p.advance() // Consume USING - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("index method") } stmt.Using = p.currentToken.Literal @@ -650,7 +650,7 @@ func (p *Parser) parseCreateIndex(unique bool) (*ast.CreateIndexStatement, error // Parse column list for { col := ast.IndexColumn{} - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("column name") } col.Column = p.currentToken.Literal @@ -739,9 +739,9 @@ func (p *Parser) parseDropStatement() (*ast.DropStatement, error) { stmt.IfExists = true } - // Parse object names (can be comma-separated) + // Parse object names (can be comma-separated, supports double-quoted identifiers) for { - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("object name") } stmt.Names = append(stmt.Names, p.currentToken.Literal) @@ -788,8 +788,8 @@ func (p *Parser) parseRefreshStatement() (*ast.RefreshMaterializedViewStatement, p.advance() } - // Parse view name - if !p.isType(models.TokenTypeIdentifier) { + // Parse view name (supports double-quoted identifiers for PostgreSQL compatibility) + if !p.isIdentifier() { return nil, p.expectedError("materialized view name") } stmt.Name = p.currentToken.Literal @@ -827,9 +827,9 @@ func (p *Parser) parseTruncateStatement() (*ast.TruncateStatement, error) { p.advance() // Consume TABLE } - // Parse table names (can be comma-separated) + // Parse table names (can be comma-separated, supports double-quoted identifiers) for { - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("table name") } stmt.Tables = append(stmt.Tables, p.currentToken.Literal) diff --git a/pkg/sql/parser/dml.go b/pkg/sql/parser/dml.go index a345b7a..a0b8d40 100644 --- a/pkg/sql/parser/dml.go +++ b/pkg/sql/parser/dml.go @@ -21,8 +21,8 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) { } p.advance() // Consume INTO - // Parse table name - if !p.isType(models.TokenTypeIdentifier) { + // Parse table name (supports double-quoted identifiers for PostgreSQL compatibility) + if !p.isIdentifier() { return nil, p.expectedError("table name") } tableName := p.currentToken.Literal @@ -34,8 +34,8 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) { p.advance() // Consume ( for { - // Parse column name - if !p.isType(models.TokenTypeIdentifier) { + // Parse column name (supports double-quoted identifiers) + if !p.isIdentifier() { return nil, p.expectedError("column name") } columns = append(columns, &ast.Identifier{Name: p.currentToken.Literal}) @@ -143,8 +143,8 @@ func (p *Parser) parseInsertStatement() (ast.Statement, error) { func (p *Parser) parseUpdateStatement() (ast.Statement, error) { // We've already consumed the UPDATE token in matchToken - // Parse table name - if !p.isType(models.TokenTypeIdentifier) { + // Parse table name (supports double-quoted identifiers for PostgreSQL compatibility) + if !p.isIdentifier() { return nil, p.expectedError("table name") } tableName := p.currentToken.Literal @@ -159,8 +159,8 @@ func (p *Parser) parseUpdateStatement() (ast.Statement, error) { // Parse assignments updates := make([]ast.UpdateExpression, 0) for { - // Parse column name - if !p.isType(models.TokenTypeIdentifier) { + // Parse column name (supports double-quoted identifiers) + if !p.isIdentifier() { return nil, p.expectedError("column name") } columnName := p.currentToken.Literal @@ -250,8 +250,8 @@ func (p *Parser) parseDeleteStatement() (ast.Statement, error) { } p.advance() // Consume FROM - // Parse table name - if !p.isType(models.TokenTypeIdentifier) { + // Parse table name (supports double-quoted identifiers for PostgreSQL compatibility) + if !p.isIdentifier() { return nil, p.expectedError("table name") } tableName := p.currentToken.Literal @@ -311,7 +311,7 @@ func (p *Parser) parseMergeStatement() (ast.Statement, error) { // Parse optional target alias (AS alias or just alias) if p.isType(models.TokenTypeAs) { p.advance() // Consume AS - if !p.isType(models.TokenTypeIdentifier) && !p.isNonReservedKeyword() { + if !p.isIdentifier() && !p.isNonReservedKeyword() { return nil, p.expectedError("target alias after AS") } stmt.TargetAlias = p.currentToken.Literal @@ -337,7 +337,7 @@ func (p *Parser) parseMergeStatement() (ast.Statement, error) { // Parse optional source alias if p.isType(models.TokenTypeAs) { p.advance() // Consume AS - if !p.isType(models.TokenTypeIdentifier) && !p.isNonReservedKeyword() { + if !p.isIdentifier() && !p.isNonReservedKeyword() { return nil, p.expectedError("source alias after AS") } stmt.SourceAlias = p.currentToken.Literal @@ -449,7 +449,7 @@ func (p *Parser) parseMergeAction(clauseType string) (*ast.MergeAction, error) { // Parse SET clauses for { - if !p.isType(models.TokenTypeIdentifier) && !p.canBeAlias() { + if !p.isIdentifier() && !p.canBeAlias() { return nil, p.expectedError("column name") } // Handle qualified column names (e.g., t.name) @@ -459,7 +459,7 @@ func (p *Parser) parseMergeAction(clauseType string) (*ast.MergeAction, error) { // Check for qualified name (table.column) if p.isType(models.TokenTypePeriod) { p.advance() // Consume . - if !p.isType(models.TokenTypeIdentifier) && !p.canBeAlias() { + if !p.isIdentifier() && !p.canBeAlias() { return nil, p.expectedError("column name after .") } columnName = columnName + "." + p.currentToken.Literal @@ -496,7 +496,7 @@ func (p *Parser) parseMergeAction(clauseType string) (*ast.MergeAction, error) { if p.isType(models.TokenTypeLParen) { p.advance() // Consume ( for { - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("column name") } action.Columns = append(action.Columns, p.currentToken.Literal) @@ -601,7 +601,7 @@ func (p *Parser) parseOnConflictClause() (*ast.OnConflict, error) { var targets []ast.Expression for { - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("column name in ON CONFLICT target") } targets = append(targets, &ast.Identifier{Name: p.currentToken.Literal}) @@ -622,7 +622,7 @@ func (p *Parser) parseOnConflictClause() (*ast.OnConflict, error) { // ON CONSTRAINT constraint_name p.advance() // Consume ON p.advance() // Consume CONSTRAINT - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("constraint name") } onConflict.Constraint = p.currentToken.Literal @@ -651,7 +651,7 @@ func (p *Parser) parseOnConflictClause() (*ast.OnConflict, error) { // Parse update assignments var updates []ast.UpdateExpression for { - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("column name") } columnName := p.currentToken.Literal diff --git a/pkg/sql/parser/double_quoted_identifier_test.go b/pkg/sql/parser/double_quoted_identifier_test.go new file mode 100644 index 0000000..c732596 --- /dev/null +++ b/pkg/sql/parser/double_quoted_identifier_test.go @@ -0,0 +1,527 @@ +// Package parser - double_quoted_identifier_test.go +// Tests for double-quoted identifier support in DML and DDL statements. +// Double-quoted identifiers are part of the ANSI SQL standard and are used by +// PostgreSQL, Oracle, SQLite, and other databases. + +package parser + +import ( + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/models" + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/token" + "github.com/ajitpratap0/GoSQLX/pkg/sql/tokenizer" +) + +// parseSQLWithQuotedIdentifiers is a helper to tokenize and parse SQL for testing quoted identifiers +// (double-quoted for ANSI SQL/PostgreSQL, backticks for MySQL, etc.) +func parseSQLWithQuotedIdentifiers(t *testing.T, sql string) (*ast.AST, error) { + t.Helper() + + tkz := tokenizer.GetTokenizer() + defer tokenizer.PutTokenizer(tkz) + + tokens, err := tkz.Tokenize([]byte(sql)) + if err != nil { + return nil, err + } + + convertedTokens := convertTokensWithQuotedIdentifiers(tokens) + + parser := NewParser() + defer parser.Release() + + tree, err := parser.Parse(convertedTokens) + return tree, err +} + +// convertTokensWithQuotedIdentifiers converts tokenizer tokens to parser tokens, +// including proper handling of quoted strings (double-quoted, backticks) as identifiers +func convertTokensWithQuotedIdentifiers(tokens []models.TokenWithSpan) []token.Token { + result := make([]token.Token, 0, len(tokens)) + for _, t := range tokens { + var tokenType token.Type + var modelType models.TokenType = t.Token.Type // Preserve the original ModelType + literal := t.Token.Value + + switch t.Token.Type { + case models.TokenTypeIdentifier: + tokenType = "IDENT" + case models.TokenTypeDoubleQuotedString: + // Double-quoted strings should be treated as identifiers in SQL + tokenType = "DOUBLE_QUOTED_STRING" + case models.TokenTypeKeyword: + tokenType = token.Type(t.Token.Value) + case models.TokenTypeString: + tokenType = "STRING" + case models.TokenTypeNumber: + tokenType = "INT" + case models.TokenTypeOperator: + tokenType = token.Type(t.Token.Value) + case models.TokenTypeLParen: + tokenType = "(" + case models.TokenTypeRParen: + tokenType = ")" + case models.TokenTypeComma: + tokenType = "," + case models.TokenTypePeriod: + tokenType = "." + case models.TokenTypeEq: + tokenType = "=" + case models.TokenTypeSemicolon: + tokenType = ";" + case models.TokenTypeAsterisk: + tokenType = "*" + literal = "*" + case models.TokenTypeMul: + // Normalize multiplication to asterisk for parser compatibility + tokenType = "*" + modelType = models.TokenTypeAsterisk + literal = "*" + default: + if t.Token.Value != "" { + tokenType = token.Type(t.Token.Value) + } + } + + if tokenType != "" { + result = append(result, token.Token{ + Type: tokenType, + ModelType: modelType, + Literal: literal, + }) + } + } + return result +} + +func TestDoubleQuotedIdentifiers_SELECT(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "double-quoted column in SELECT", + sql: `SELECT "id" FROM users`, + }, + { + name: "double-quoted table in SELECT", + sql: `SELECT id FROM "users"`, + }, + { + name: "double-quoted column and table in SELECT", + sql: `SELECT "id", "name" FROM "users"`, + }, + { + name: "double-quoted in WHERE clause", + sql: `SELECT id FROM users WHERE "id" = 1`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} + +func TestDoubleQuotedIdentifiers_INSERT(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "double-quoted table in INSERT", + sql: `INSERT INTO "users" (name) VALUES (1)`, + }, + { + name: "double-quoted columns in INSERT", + sql: `INSERT INTO users ("id", "name") VALUES (1, 2)`, + }, + { + name: "double-quoted table and columns in INSERT", + sql: `INSERT INTO "users" ("id", "name") VALUES (1, 2)`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} + +func TestDoubleQuotedIdentifiers_UPDATE(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "double-quoted table in UPDATE", + sql: `UPDATE "users" SET name = 1`, + }, + { + name: "double-quoted column in UPDATE SET", + sql: `UPDATE users SET "name" = 1`, + }, + { + name: "double-quoted table and column in UPDATE", + sql: `UPDATE "users" SET "name" = 1 WHERE "id" = 1`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} + +func TestDoubleQuotedIdentifiers_DELETE(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "double-quoted table in DELETE", + sql: `DELETE FROM "users"`, + }, + { + name: "double-quoted table with WHERE in DELETE", + sql: `DELETE FROM "users" WHERE "id" = 1`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} + +func TestDoubleQuotedIdentifiers_DROP(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "double-quoted table in DROP TABLE", + sql: `DROP TABLE "users"`, + }, + { + name: "double-quoted table with IF EXISTS in DROP", + sql: `DROP TABLE IF EXISTS "users"`, + }, + { + name: "double-quoted view in DROP VIEW", + sql: `DROP VIEW "user_summary"`, + }, + { + name: "double-quoted index in DROP INDEX", + sql: `DROP INDEX "idx_users_name"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} + +func TestDoubleQuotedIdentifiers_CREATE(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "double-quoted table in CREATE TABLE", + sql: `CREATE TABLE "users" (id INT)`, + }, + { + name: "double-quoted view in CREATE VIEW", + sql: `CREATE VIEW "user_summary" AS SELECT id FROM users`, + }, + { + name: "double-quoted index in CREATE INDEX", + sql: `CREATE INDEX "idx_users_name" ON users (name)`, + }, + { + name: "double-quoted table in CREATE INDEX ON", + sql: `CREATE INDEX idx_users_name ON "users" (name)`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} + +func TestDoubleQuotedIdentifiers_TRUNCATE(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "double-quoted table in TRUNCATE", + sql: `TRUNCATE TABLE "users"`, + }, + { + name: "double-quoted table without TABLE keyword", + sql: `TRUNCATE "users"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} + +// TestDoubleQuotedIdentifiers_Mixed tests mixing quoted and unquoted identifiers +func TestDoubleQuotedIdentifiers_Mixed(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "mixed identifiers in SELECT", + sql: `SELECT "id", name FROM "users" WHERE status = 1`, + }, + { + name: "mixed identifiers in INSERT", + sql: `INSERT INTO "users" (id, "name") VALUES (1, 2)`, + }, + { + name: "mixed identifiers in UPDATE", + sql: `UPDATE "users" SET name = 1, "status" = 2`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} + +// TestDoubleQuotedIdentifiers_CTE tests double-quoted identifiers in Common Table Expressions +func TestDoubleQuotedIdentifiers_CTE(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "double-quoted CTE name", + sql: `WITH "reserved-word" AS (SELECT 1) SELECT * FROM "reserved-word"`, + }, + { + name: "double-quoted CTE column", + sql: `WITH cte ("column") AS (SELECT 1) SELECT * FROM cte`, + }, + { + name: "double-quoted CTE name and columns", + sql: `WITH "my-cte" ("col1", "col2") AS (SELECT 1, 2) SELECT * FROM "my-cte"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} + +// TestDoubleQuotedIdentifiers_MERGE tests double-quoted identifiers in MERGE statements +func TestDoubleQuotedIdentifiers_MERGE(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "double-quoted target table in MERGE", + sql: `MERGE INTO "target" t USING source s ON t.id = s.id WHEN MATCHED THEN UPDATE SET name = s.name`, + }, + { + name: "double-quoted source table in MERGE", + sql: `MERGE INTO target t USING "source" s ON t.id = s.id WHEN MATCHED THEN UPDATE SET name = s.name`, + }, + { + name: "double-quoted column in MERGE UPDATE", + sql: `MERGE INTO target t USING source s ON t.id = s.id WHEN MATCHED THEN UPDATE SET "col" = s.val`, + }, + { + name: "double-quoted tables in MERGE", + sql: `MERGE INTO "target" t USING "source" s ON t.id = s.id WHEN MATCHED THEN UPDATE SET "col" = s.val`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} + +// TestDoubleQuotedIdentifiers_MaterializedView tests double-quoted identifiers in materialized view statements +func TestDoubleQuotedIdentifiers_MaterializedView(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "double-quoted view in CREATE MATERIALIZED VIEW", + sql: `CREATE MATERIALIZED VIEW "my-view" AS SELECT id FROM users`, + }, + { + name: "double-quoted view in REFRESH MATERIALIZED VIEW", + sql: `REFRESH MATERIALIZED VIEW "my-view"`, + }, + { + name: "double-quoted view in DROP MATERIALIZED VIEW", + sql: `DROP MATERIALIZED VIEW "my-view"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} + +// TestDoubleQuotedIdentifiers_OnConflict tests double-quoted identifiers in ON CONFLICT clauses +func TestDoubleQuotedIdentifiers_OnConflict(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "double-quoted column in ON CONFLICT", + sql: `INSERT INTO users ("id") VALUES (1) ON CONFLICT ("id") DO NOTHING`, + }, + { + name: "double-quoted table and column in INSERT with ON CONFLICT", + sql: `INSERT INTO "users" ("id") VALUES (1) ON CONFLICT ("id") DO NOTHING`, + }, + { + name: "double-quoted in ON CONFLICT DO UPDATE", + sql: `INSERT INTO "users" ("id", "name") VALUES (1, 2) ON CONFLICT ("id") DO UPDATE SET "name" = 3`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} + +// TestDoubleQuotedIdentifiers_EdgeCases tests edge cases for double-quoted identifiers +func TestDoubleQuotedIdentifiers_EdgeCases(t *testing.T) { + tests := []struct { + name string + sql string + }{ + { + name: "reserved word as identifier", + sql: `SELECT "select", "from" FROM users`, + }, + { + name: "identifier with hyphen", + sql: `SELECT * FROM "my-table"`, + }, + { + name: "identifier with space", + sql: `SELECT * FROM "table name"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree, err := parseSQLWithQuotedIdentifiers(t, tt.sql) + if err != nil { + t.Fatalf("Failed to parse %q: %v", tt.sql, err) + } + if tree != nil { + defer ast.ReleaseAST(tree) + } + }) + } +} diff --git a/pkg/sql/parser/select.go b/pkg/sql/parser/select.go index 5a85e33..3355a63 100644 --- a/pkg/sql/parser/select.go +++ b/pkg/sql/parser/select.go @@ -167,8 +167,8 @@ func (p *Parser) parseColumnConstraint() (*ast.ColumnConstraint, bool, error) { p.advance() // Consume REFERENCES constraint.Type = "REFERENCES" - // Parse referenced table name - if !p.isType(models.TokenTypeIdentifier) { + // Parse referenced table name (supports double-quoted identifiers) + if !p.isIdentifier() { return nil, false, p.expectedError("table name after REFERENCES") } refDef := &ast.ReferenceDefinition{ @@ -180,7 +180,7 @@ func (p *Parser) parseColumnConstraint() (*ast.ColumnConstraint, bool, error) { if p.isType(models.TokenTypeLParen) { p.advance() // Consume ( for { - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, false, p.expectedError("column name in REFERENCES") } refDef.Columns = append(refDef.Columns, p.currentToken.Literal) @@ -325,8 +325,8 @@ func (p *Parser) parseTableConstraint() (*ast.TableConstraint, error) { } p.advance() // Consume REFERENCES - // Parse referenced table - if !p.isType(models.TokenTypeIdentifier) { + // Parse referenced table (supports double-quoted identifiers) + if !p.isIdentifier() { return nil, p.expectedError("table name after REFERENCES") } refDef := &ast.ReferenceDefinition{ @@ -398,7 +398,7 @@ func (p *Parser) parseConstraintColumnList() ([]string, error) { var columns []string for { - if !p.isType(models.TokenTypeIdentifier) { + if !p.isIdentifier() { return nil, p.expectedError("column name") } columns = append(columns, p.currentToken.Literal)