From adfc7de9f7ef9a5d704b3517da002d79dd5b3806 Mon Sep 17 00:00:00 2001 From: Vincent Behar Date: Fri, 23 Jun 2017 11:00:57 +0200 Subject: [PATCH] add support for context allow to pass context.Context to the sql package, using the *Context funcs on sqlx needs a more recent version of sqlx - see https://github.com/jmoiron/sqlx/pull/270 --- dat/execer.go | 10 +++++++++- glide.lock | 4 ++-- sqlx-runner/exec.go | 24 +++++++++++++++--------- sqlx-runner/execer.go | 10 ++++++++++ 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/dat/execer.go b/dat/execer.go index 29d4967..c42b3a7 100644 --- a/dat/execer.go +++ b/dat/execer.go @@ -1,6 +1,9 @@ package dat -import "time" +import ( + "context" + "time" +) // Result serves the same purpose as sql.Result. Defining // it for the package avoids tight coupling with database/sql. @@ -13,6 +16,7 @@ type Result struct { type Execer interface { Cache(id string, ttl time.Duration, invalidate bool) Execer Timeout(time.Duration) Execer + Context(context.Context) Execer Interpolate() (string, []interface{}, error) Exec() (*Result, error) @@ -38,6 +42,10 @@ func (nop *disconnectedExecer) Timeout(time.Duration) Execer { return nil } +func (nop *disconnectedExecer) Context(context.Context) Execer { + return nil +} + // Exec panics when Exec is called. func (nop *disconnectedExecer) Exec() (*Result, error) { return nil, ErrDisconnectedExecer diff --git a/glide.lock b/glide.lock index 93b145b..5f80d0b 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ hash: 96b33372e619c3e4f2c39384e9577881d2a6ba9c22ac6eb04cfd680471f8dc84 -updated: 2016-06-04T15:45:36.594344149-07:00 +updated: 2017-06-23T10:43:54.561604473+02:00 imports: - name: github.com/cenkalti/backoff version: c29158af31815ccc31ca29c86c121bc39e00d3d8 @@ -11,7 +11,7 @@ imports: - name: github.com/howeyc/gopass version: 66487b23f2880ba32e185121d2cd51a338ea069a - name: github.com/jmoiron/sqlx - version: a7f971fe8ea891a1a74ddb40c623f5722e28e8a8 + version: 8ed836a8adb659e8492bcaa49e0880bb84075fe2 subpackages: - reflectx - name: github.com/lib/pq diff --git a/sqlx-runner/exec.go b/sqlx-runner/exec.go index c9f6bc3..d422fcf 100644 --- a/sqlx-runner/exec.go +++ b/sqlx-runner/exec.go @@ -2,6 +2,7 @@ package runner import ( "bytes" + "context" "database/sql" "encoding/json" "errors" @@ -22,10 +23,15 @@ import ( // queries can be executed type database interface { Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) Queryx(query string, args ...interface{}) (*sqlx.Rows, error) + QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) QueryRowx(query string, args ...interface{}) *sqlx.Row + QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row Select(dest interface{}, query string, args ...interface{}) error + SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error Get(dest interface{}, query string, args ...interface{}) error + GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error } func toOutputStr(args []interface{}) string { @@ -126,7 +132,7 @@ func (ex *Execer) execFn() (sql.Result, error) { defer logExecutionTime(time.Now(), fullSQL, args) var result sql.Result - result, err = ex.database.Exec(fullSQL, args...) + result, err = ex.database.ExecContext(ex.ctx, fullSQL, args...) if err != nil { return nil, logSQLError(err, "execFn.30:"+fmt.Sprintf("%T", err), fullSQL, args) } @@ -141,7 +147,7 @@ func (ex *Execer) execSQL(fullSQL string, args []interface{}) (sql.Result, error var result sql.Result var err error - result, err = ex.database.Exec(fullSQL, args...) + result, err = ex.database.ExecContext(ex.ctx, fullSQL, args...) if err != nil { return nil, logSQLError(err, "execSQL.30", fullSQL, args) } @@ -180,7 +186,7 @@ func (ex *Execer) queryFn() (*sqlx.Rows, error) { } defer logExecutionTime(time.Now(), fullSQL, args) - rows, err := ex.database.Queryx(fullSQL, args...) + rows, err := ex.database.QueryxContext(ex.ctx, fullSQL, args...) if err != nil { return nil, logSQLError(err, "queryFn.30", fullSQL, args) } @@ -230,7 +236,7 @@ func (ex *Execer) queryScalarFn(destinations []interface{}) error { defer logExecutionTime(time.Now(), fullSQL, args) // Run the query: var rows *sqlx.Rows - rows, err = ex.database.Queryx(fullSQL, args...) + rows, err = ex.database.QueryxContext(ex.ctx, fullSQL, args...) if err != nil { return logSQLError(err, "queryScalarFn.12: querying database", fullSQL, args) } @@ -316,7 +322,7 @@ func (ex *Execer) querySliceFn(dest interface{}) error { } defer logExecutionTime(time.Now(), fullSQL, args) - rows, err := ex.database.Queryx(fullSQL, args...) + rows, err := ex.database.QueryxContext(ex.ctx, fullSQL, args...) if err != nil { return logSQLError(err, "querySlice.load_all_values.query", fullSQL, args) } @@ -387,7 +393,7 @@ func (ex *Execer) queryStructFn(dest interface{}) error { } defer logExecutionTime(time.Now(), fullSQL, args) - err = ex.database.Get(dest, fullSQL, args...) + err = ex.database.GetContext(ex.ctx, dest, fullSQL, args...) if err != nil { return logSQLError(err, "queryStruct.3", fullSQL, args) } @@ -438,7 +444,7 @@ func (ex *Execer) queryStructsFn(dest interface{}) error { } defer logExecutionTime(time.Now(), fullSQL, args) - err = ex.database.Select(dest, fullSQL, args...) + err = ex.database.SelectContext(ex.ctx, dest, fullSQL, args...) if err != nil { logSQLError(err, "queryStructs", fullSQL, args) } @@ -498,7 +504,7 @@ func (ex *Execer) queryJSONBlobFn(single bool) ([]byte, error) { } defer logExecutionTime(time.Now(), fullSQL, args) - rows, err := ex.database.Queryx(fullSQL, args...) + rows, err := ex.database.QueryxContext(ex.ctx, fullSQL, args...) if err != nil { return nil, logSQLError(err, "queryJSONStructs", fullSQL, args) } @@ -689,7 +695,7 @@ func (ex *Execer) queryJSONFn() ([]byte, error) { defer logExecutionTime(time.Now(), fullSQL, args) jsonSQL := fmt.Sprintf("SELECT TO_JSON(ARRAY_AGG(__datq.*)) FROM (%s) AS __datq", fullSQL) - err = ex.database.Get(&blob, jsonSQL, args...) + err = ex.database.GetContext(ex.ctx, &blob, jsonSQL, args...) if err != nil { logSQLError(err, "queryJSON", jsonSQL, args) } diff --git a/sqlx-runner/execer.go b/sqlx-runner/execer.go index c680632..bb7886a 100644 --- a/sqlx-runner/execer.go +++ b/sqlx-runner/execer.go @@ -1,6 +1,7 @@ package runner import ( + "context" "encoding/json" "fmt" "time" @@ -24,6 +25,8 @@ type Execer struct { // uuid is prepended into the SQL for the query to be searched // in pg_stat_activity, used by timeout logic queryID string + + ctx context.Context } const queryIDPrefix = "--dat:qid=" @@ -33,6 +36,7 @@ func NewExecer(database database, builder dat.Builder) *Execer { return &Execer{ database: database, builder: builder, + ctx: context.Background(), } } @@ -55,6 +59,12 @@ func (ex *Execer) Timeout(timeout time.Duration) dat.Execer { return ex } +// Context sets the context for current query. +func (ex *Execer) Context(ctx context.Context) dat.Execer { + ex.ctx = ctx + return ex +} + func datQueryID(id string) string { return fmt.Sprintf("--dat:qid=%s", id) }