diff --git a/priority_queue.go b/priority_queue.go index 5d7eb83..31882bd 100644 --- a/priority_queue.go +++ b/priority_queue.go @@ -36,17 +36,17 @@ func newPriorityQueue(db *sql.DB, tableName string, opts ...Option) (*PriorityQu func (pq *PriorityQueue) initPriorityColumn() error { // Check if priority column exists var name string - err := pq.client.QueryRow(fmt.Sprintf("PRAGMA table_info(%s)", pq.tableName)).Scan(nil, &name, nil, nil, nil, nil) + err := pq.client.QueryRow(fmt.Sprintf("PRAGMA table_info(%s)", quoteIdent(pq.tableName))).Scan(nil, &name, nil, nil, nil, nil) if err != nil || name != "priority" { // Add priority column with default value 0 - _, err := pq.client.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN priority INTEGER NOT NULL DEFAULT 0", pq.tableName)) + _, err := pq.client.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN priority INTEGER NOT NULL DEFAULT 0", quoteIdent(pq.tableName))) if err != nil { return err } // Create index on priority (ASC for lower numbers = higher priority) - _, err = pq.client.Exec(fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s_priority_idx ON %s (priority ASC, created_at ASC)", pq.tableName, pq.tableName)) + _, err = pq.client.Exec(fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (priority ASC, created_at ASC)", quoteIdent(pq.tableName+"_priority_idx"), quoteIdent(pq.tableName))) if err != nil { return err } @@ -75,7 +75,7 @@ func (pq *PriorityQueue) Enqueue(item any, priority int) bool { }() _, err = tx.Exec( - fmt.Sprintf("INSERT INTO %s (data, status, created_at, updated_at, priority) VALUES (?, ?, ?, ?, ?)", pq.tableName), + fmt.Sprintf("INSERT INTO %s (data, status, created_at, updated_at, priority) VALUES (?, ?, ?, ?, ?)", quoteIdent(pq.tableName)), item, "pending", now, now, priority, ) if err != nil { @@ -108,7 +108,7 @@ func (pq *PriorityQueue) dequeueInternal(withAckId bool) (any, bool, string) { var data []byte row := tx.QueryRow(fmt.Sprintf( "SELECT id, data FROM %s WHERE status = 'pending' ORDER BY priority ASC, created_at ASC LIMIT 1", - pq.tableName, + quoteIdent(pq.tableName), )) err = row.Scan(&id, &data) if err != nil { @@ -128,13 +128,13 @@ func (pq *PriorityQueue) dequeueInternal(withAckId bool) (any, bool, string) { ackID = cuid.New() _, err = tx.Exec( - fmt.Sprintf("UPDATE %s SET status = 'processing', ack_id = ?, updated_at = ? WHERE id = ?", pq.tableName), + fmt.Sprintf("UPDATE %s SET status = 'processing', ack_id = ?, updated_at = ? WHERE id = ?", quoteIdent(pq.tableName)), ackID, now, id, ) } else { // remove the row if there is no ack _, err = tx.Exec( - fmt.Sprintf("DELETE FROM %s WHERE id = ?", pq.tableName), + fmt.Sprintf("DELETE FROM %s WHERE id = ?", quoteIdent(pq.tableName)), id, ) } diff --git a/queue.go b/queue.go index 9171332..6a67467 100644 --- a/queue.go +++ b/queue.go @@ -3,7 +3,6 @@ package sqliteq import ( "database/sql" "fmt" - "strings" "sync/atomic" "time" @@ -317,11 +316,3 @@ func (q *Queue) Close() error { return nil } - -// Applies quotes to an identifier escaping any internal quotes. -// See: https://www.sqlite.org/lang_keywords.html -func quoteIdent(name string) string { - // Replace quotes with dobule quotes - escaped := strings.ReplaceAll(name, `"`, `""`) - return `"` + escaped + `"` -} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..6de4398 --- /dev/null +++ b/utils.go @@ -0,0 +1,11 @@ +package sqliteq + +import "strings" + +// Applies quotes to an identifier escaping any internal quotes. +// See: https://www.sqlite.org/lang_keywords.html +func quoteIdent(name string) string { + // Replace quotes with dobule quotes + escaped := strings.ReplaceAll(name, `"`, `""`) + return `"` + escaped + `"` +}