@@ -3,7 +3,9 @@ package mutation
33import (
44 "fmt"
55
6+ "github.com/efritz/gostgres/internal/execution/projection"
67 "github.com/efritz/gostgres/internal/execution/queries/nodes"
8+ projectionHelpers "github.com/efritz/gostgres/internal/execution/queries/nodes/projection"
79 "github.com/efritz/gostgres/internal/execution/serialization"
810 "github.com/efritz/gostgres/internal/shared/fields"
911 "github.com/efritz/gostgres/internal/shared/impls"
@@ -15,24 +17,50 @@ type insertNode struct {
1517 nodes.Node
1618 table impls.Table
1719 columnNames []string
20+ projection * projection.Projection
1821}
1922
20- func NewInsert (node nodes.Node , table impls.Table , columnNames []string ) nodes.Node {
23+ func NewInsert (node nodes.Node , table impls.Table , columnNames []string , projection * projection. Projection ) nodes.Node {
2124 return & insertNode {
2225 Node : node ,
2326 table : table ,
2427 columnNames : columnNames ,
28+ projection : projection ,
2529 }
2630}
2731
2832func (n * insertNode ) Serialize (w serialization.IndentWriter ) {
2933 w .WritefLine ("insert into %s" , n .table .Name ())
3034 n .Node .Serialize (w .Indent ())
35+
36+ if n .projection != nil {
37+ w .WritefLine ("returning %s" , n .projection )
38+ n .Node .Serialize (w .Indent ())
39+ }
3140}
3241
3342func (n * insertNode ) Scanner (ctx impls.ExecutionContext ) (scan.RowScanner , error ) {
3443 ctx .Log ("Building Insert scanner" )
3544
45+ insertedRows , err := n .insertRows (ctx )
46+ if err != nil {
47+ return nil , err
48+ }
49+
50+ return scan .RowScannerFunc (func () (rows.Row , error ) {
51+ ctx .Log ("Scanning Insert" )
52+
53+ if len (insertedRows ) != 0 {
54+ return rows.Row {}, scan .ErrNoRows
55+ }
56+
57+ row := insertedRows [0 ]
58+ insertedRows = insertedRows [1 :]
59+ return projectionHelpers .Project (ctx , row , n .projection )
60+ }), nil
61+ }
62+
63+ func (n * insertNode ) insertRows (ctx impls.ExecutionContext ) ([]rows.Row , error ) {
3664 scanner , err := n .Node .Scanner (ctx )
3765 if err != nil {
3866 return nil , err
@@ -50,31 +78,38 @@ func (n *insertNode) Scanner(ctx impls.ExecutionContext) (scan.RowScanner, error
5078 fields = append (fields , field .Field )
5179 }
5280
53- return scan .RowScannerFunc (func () (rows.Row , error ) {
54- ctx .Log ("Scanning Insert" )
55-
81+ var insertedRows []rows.Row
82+ for {
5683 row , err := scanner .Scan ()
5784 if err != nil {
58- return rows.Row {}, err
85+ if err == scan .ErrNoRows {
86+ break
87+ }
88+
89+ return nil , err
5990 }
6091
6192 values , err := n .prepareValuesForRow (ctx , row , nonInternalFields )
6293 if err != nil {
63- return rows. Row {} , err
94+ return nil , err
6495 }
6596
6697 insertedRow , err := rows .NewRow (fields , values )
6798 if err != nil {
68- return rows. Row {} , err
99+ return nil , err
69100 }
70101
71102 insertedRow , err = n .table .Insert (ctx , insertedRow )
72103 if err != nil {
73- return rows. Row {} , err
104+ return nil , err
74105 }
75106
76- return insertedRow , nil
77- }), nil
107+ if n .projection != nil {
108+ insertedRows = append (insertedRows , insertedRow )
109+ }
110+ }
111+
112+ return insertedRows , nil
78113}
79114
80115func (n * insertNode ) prepareValuesForRow (ctx impls.ExecutionContext , row rows.Row , fields []impls.TableField ) ([]any , error ) {
0 commit comments