diff --git a/go.mod b/go.mod index 8a31f79..8fb8a33 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/upfluence/sql -go 1.21 +go 1.23 require ( github.com/lib/pq v1.10.9 diff --git a/x/sqlbuilder/insert_statement.go b/x/sqlbuilder/insert_statement.go index a473b5f..d96e5f9 100644 --- a/x/sqlbuilder/insert_statement.go +++ b/x/sqlbuilder/insert_statement.go @@ -3,6 +3,8 @@ package sqlbuilder import ( "fmt" "io" + "slices" + "strings" "github.com/upfluence/sql" ) @@ -104,8 +106,14 @@ type InsertStatement struct { Fields []Marker + // Deprecated: Please use Returnings Returning *sql.Returning + + Returnings []*sql.Returning + OnConfict *OnConflictClause + + isQuery bool } func (is InsertStatement) Clone() InsertStatement { @@ -118,11 +126,25 @@ func (is InsertStatement) Clone() InsertStatement { } return InsertStatement{ - Table: is.Table, - Fields: cloneMarkers(is.Fields), - Returning: r, - OnConfict: is.OnConfict.Clone(), + Table: is.Table, + Fields: cloneMarkers(is.Fields), + Returnings: slices.Clone(is.Returnings), + Returning: r, + isQuery: is.isQuery, + OnConfict: is.OnConfict.Clone(), + } +} + +func (is InsertStatement) returnings() []*sql.Returning { + var res []*sql.Returning + + res = append(res, is.Returnings...) + + if is.Returning != nil { + res = append(res, is.Returning) } + + return res } func (is InsertStatement) buildQuery(qvs map[string]interface{}) (string, []interface{}, error) { @@ -190,8 +212,23 @@ func (is InsertStatement) buildQueries(vvs []map[string]interface{}, qvs map[str } } - if is.Returning != nil { - qw.vs = append(qw.vs, is.Returning) + switch rs := is.returnings(); len(rs) { + case 0: + case 1: + if !is.isQuery { + qw.vs = append(qw.vs, is.Returning) + break + } + + fallthrough + default: + var fields = make([]string, len(rs)) + + for i, r := range rs { + fields[i] = r.Field + } + + fmt.Fprintf(&qw, " RETURNING %s", strings.Join(fields, ", ")) } return qw.String(), qw.vs, nil diff --git a/x/sqlbuilder/insert_statement_test.go b/x/sqlbuilder/insert_statement_test.go index 09e0688..8772497 100644 --- a/x/sqlbuilder/insert_statement_test.go +++ b/x/sqlbuilder/insert_statement_test.go @@ -85,6 +85,32 @@ func TestInsertQuery(t *testing.T) { stmt: "INSERT INTO foo(buz) VALUES ($1) ON CONFLICT (buz) DO UPDATE SET bar = $2", args: []interface{}{1, 2}, }, + { + name: "with returning + isQuery", + is: InsertStatement{ + Table: "foo", + Fields: []Marker{Column("buz")}, + Returning: &sql.Returning{Field: "buz"}, + isQuery: true, + }, + vs: map[string]interface{}{"buz": 1}, + stmt: "INSERT INTO foo(buz) VALUES ($1) RETURNING buz", + args: []interface{}{1}, + }, + { + name: "with returnings ", + is: InsertStatement{ + Table: "foo", + Fields: []Marker{Column("buz")}, + Returnings: []*sql.Returning{ + {Field: "buz"}, + {Field: "bar"}, + }, + }, + vs: map[string]interface{}{"buz": 1}, + stmt: "INSERT INTO foo(buz) VALUES ($1) RETURNING buz, bar", + args: []interface{}{1}, + }, } { t.Run(tt.name, func(t *testing.T) { stmt, args, err := tt.is.Clone().buildQuery(tt.vs) diff --git a/x/sqlbuilder/query_builder.go b/x/sqlbuilder/query_builder.go index c9635fe..ad51975 100644 --- a/x/sqlbuilder/query_builder.go +++ b/x/sqlbuilder/query_builder.go @@ -52,6 +52,27 @@ type InsertExecer struct { Statement InsertStatement } +type errScanner struct { + error +} + +func (es errScanner) Scan(...interface{}) error { + return es.error +} + +func (ie *InsertExecer) QueryRow(ctx context.Context, qvs map[string]interface{}) sql.Scanner { + stmt := ie.Statement + stmt.isQuery = true + + sstmt, vs, err := stmt.buildQuery(qvs) + + if err != nil { + return errScanner{error: err} + } + + return ie.qb.QueryRow(ctx, sstmt, vs...) +} + func (ie *InsertExecer) MultiExec(ctx context.Context, vvs []map[string]interface{}, qvs map[string]interface{}) (sql.Result, error) { stmt, vs, err := ie.Statement.buildQueries(vvs, qvs)