Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion dat/execer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package dat

import "time"
import (
"context"
"time"
)

// Result serves the same purpose as sql.Result. Defining
// it for the package avoids tight coupling with database/sql.
Expand All @@ -13,6 +16,7 @@ type Result struct {
type Execer interface {
Cache(id string, ttl time.Duration, invalidate bool) Execer
Timeout(time.Duration) Execer
Context(context.Context) Execer
Interpolate() (string, []interface{}, error)
Exec() (*Result, error)

Expand All @@ -38,6 +42,10 @@ func (nop *disconnectedExecer) Timeout(time.Duration) Execer {
return nil
}

func (nop *disconnectedExecer) Context(context.Context) Execer {
return nil
}

// Exec panics when Exec is called.
func (nop *disconnectedExecer) Exec() (*Result, error) {
return nil, ErrDisconnectedExecer
Expand Down
4 changes: 2 additions & 2 deletions glide.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 15 additions & 9 deletions sqlx-runner/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package runner

import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
Expand All @@ -22,10 +23,15 @@ import (
// queries can be executed
type database interface {
Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Queryx(query string, args ...interface{}) (*sqlx.Rows, error)
QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error)
QueryRowx(query string, args ...interface{}) *sqlx.Row
QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row
Select(dest interface{}, query string, args ...interface{}) error
SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
Get(dest interface{}, query string, args ...interface{}) error
GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
}

func toOutputStr(args []interface{}) string {
Expand Down Expand Up @@ -126,7 +132,7 @@ func (ex *Execer) execFn() (sql.Result, error) {
defer logExecutionTime(time.Now(), fullSQL, args)

var result sql.Result
result, err = ex.database.Exec(fullSQL, args...)
result, err = ex.database.ExecContext(ex.ctx, fullSQL, args...)
if err != nil {
return nil, logSQLError(err, "execFn.30:"+fmt.Sprintf("%T", err), fullSQL, args)
}
Expand All @@ -141,7 +147,7 @@ func (ex *Execer) execSQL(fullSQL string, args []interface{}) (sql.Result, error

var result sql.Result
var err error
result, err = ex.database.Exec(fullSQL, args...)
result, err = ex.database.ExecContext(ex.ctx, fullSQL, args...)
if err != nil {
return nil, logSQLError(err, "execSQL.30", fullSQL, args)
}
Expand Down Expand Up @@ -180,7 +186,7 @@ func (ex *Execer) queryFn() (*sqlx.Rows, error) {
}

defer logExecutionTime(time.Now(), fullSQL, args)
rows, err := ex.database.Queryx(fullSQL, args...)
rows, err := ex.database.QueryxContext(ex.ctx, fullSQL, args...)
if err != nil {
return nil, logSQLError(err, "queryFn.30", fullSQL, args)
}
Expand Down Expand Up @@ -230,7 +236,7 @@ func (ex *Execer) queryScalarFn(destinations []interface{}) error {
defer logExecutionTime(time.Now(), fullSQL, args)
// Run the query:
var rows *sqlx.Rows
rows, err = ex.database.Queryx(fullSQL, args...)
rows, err = ex.database.QueryxContext(ex.ctx, fullSQL, args...)
if err != nil {
return logSQLError(err, "queryScalarFn.12: querying database", fullSQL, args)
}
Expand Down Expand Up @@ -316,7 +322,7 @@ func (ex *Execer) querySliceFn(dest interface{}) error {
}

defer logExecutionTime(time.Now(), fullSQL, args)
rows, err := ex.database.Queryx(fullSQL, args...)
rows, err := ex.database.QueryxContext(ex.ctx, fullSQL, args...)
if err != nil {
return logSQLError(err, "querySlice.load_all_values.query", fullSQL, args)
}
Expand Down Expand Up @@ -387,7 +393,7 @@ func (ex *Execer) queryStructFn(dest interface{}) error {
}

defer logExecutionTime(time.Now(), fullSQL, args)
err = ex.database.Get(dest, fullSQL, args...)
err = ex.database.GetContext(ex.ctx, dest, fullSQL, args...)
if err != nil {
return logSQLError(err, "queryStruct.3", fullSQL, args)
}
Expand Down Expand Up @@ -438,7 +444,7 @@ func (ex *Execer) queryStructsFn(dest interface{}) error {
}

defer logExecutionTime(time.Now(), fullSQL, args)
err = ex.database.Select(dest, fullSQL, args...)
err = ex.database.SelectContext(ex.ctx, dest, fullSQL, args...)
if err != nil {
logSQLError(err, "queryStructs", fullSQL, args)
}
Expand Down Expand Up @@ -498,7 +504,7 @@ func (ex *Execer) queryJSONBlobFn(single bool) ([]byte, error) {
}

defer logExecutionTime(time.Now(), fullSQL, args)
rows, err := ex.database.Queryx(fullSQL, args...)
rows, err := ex.database.QueryxContext(ex.ctx, fullSQL, args...)
if err != nil {
return nil, logSQLError(err, "queryJSONStructs", fullSQL, args)
}
Expand Down Expand Up @@ -689,7 +695,7 @@ func (ex *Execer) queryJSONFn() ([]byte, error) {
defer logExecutionTime(time.Now(), fullSQL, args)
jsonSQL := fmt.Sprintf("SELECT TO_JSON(ARRAY_AGG(__datq.*)) FROM (%s) AS __datq", fullSQL)

err = ex.database.Get(&blob, jsonSQL, args...)
err = ex.database.GetContext(ex.ctx, &blob, jsonSQL, args...)
if err != nil {
logSQLError(err, "queryJSON", jsonSQL, args)
}
Expand Down
10 changes: 10 additions & 0 deletions sqlx-runner/execer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package runner

import (
"context"
"encoding/json"
"fmt"
"time"
Expand All @@ -24,6 +25,8 @@ type Execer struct {
// uuid is prepended into the SQL for the query to be searched
// in pg_stat_activity, used by timeout logic
queryID string

ctx context.Context
}

const queryIDPrefix = "--dat:qid="
Expand All @@ -33,6 +36,7 @@ func NewExecer(database database, builder dat.Builder) *Execer {
return &Execer{
database: database,
builder: builder,
ctx: context.Background(),
}
}

Expand All @@ -55,6 +59,12 @@ func (ex *Execer) Timeout(timeout time.Duration) dat.Execer {
return ex
}

// Context sets the context for current query.
func (ex *Execer) Context(ctx context.Context) dat.Execer {
ex.ctx = ctx
return ex
}

func datQueryID(id string) string {
return fmt.Sprintf("--dat:qid=%s", id)
}
Expand Down