From 5e02e6e3f35f53a2e63c763223d9524c8cbbc7c3 Mon Sep 17 00:00:00 2001 From: Quinn Slack Date: Tue, 29 Apr 2014 21:35:00 -0700 Subject: [PATCH 1/2] Support embedded structs (closes #13) --- dbmap.go | 44 ++++++++++++++++++++++++++------------------ modl_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 18 deletions(-) diff --git a/dbmap.go b/dbmap.go index 9580a10..02740c0 100644 --- a/dbmap.go +++ b/dbmap.go @@ -111,25 +111,10 @@ func (m *DbMap) AddTable(i interface{}, name ...string) *TableMap { tmap := &TableMap{gotype: t, TableName: Name, dbmap: m, mapper: m.mapper} tmap.setupHooks(i) - n := t.NumField() - tmap.Columns = make([]*ColumnMap, 0, n) - for i := 0; i < n; i++ { - f := t.Field(i) - columnName := f.Tag.Get("db") - if columnName == "" { - columnName = sqlx.NameMapper(f.Name) - } - - cm := &ColumnMap{ - ColumnName: columnName, - Transient: columnName == "-", - fieldName: f.Name, - gotype: f.Type, - table: tmap, - } - tmap.Columns = append(tmap.Columns, cm) + tmap.Columns = columnMaps(t, tmap) + for _, cm := range tmap.Columns { if cm.fieldName == "Version" { - tmap.version = tmap.Columns[len(tmap.Columns)-1] + tmap.version = cm } } m.tables = append(m.tables, tmap) @@ -138,6 +123,29 @@ func (m *DbMap) AddTable(i interface{}, name ...string) *TableMap { } +func columnMaps(t reflect.Type, tmap *TableMap) []*ColumnMap { + var cols []*ColumnMap + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + name := f.Tag.Get("db") + if f.Anonymous { + cols = append(cols, columnMaps(f.Type, tmap)...) + } else { + if name == "" { + name = sqlx.NameMapper(f.Name) + } + cols = append(cols, &ColumnMap{ + ColumnName: name, + Transient: name == "-", + fieldName: f.Name, + gotype: f.Type, + table: tmap, + }) + } + } + return cols +} + // AddTableWithName adds a new mapping of the interface to a table name. func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap { return m.AddTable(i, name) diff --git a/modl_test.go b/modl_test.go index e533c42..bdaeb4c 100644 --- a/modl_test.go +++ b/modl_test.go @@ -65,6 +65,16 @@ type WithStringPk struct { type CustomStringType string +type WithEmbeddedStruct struct { + Id int64 + Names +} + +type Names struct { + FirstName string + LastName string +} + func (p *Person) PreInsert(s SqlExecutor) error { p.Created = time.Now().UnixNano() p.Updated = p.Created @@ -598,6 +608,29 @@ func TestWithStringPk(t *testing.T) { } } +func TestWithEmbeddedStruct(t *testing.T) { + dbmap := newDbMap() + //dbmap.TraceOn("", log.New(os.Stdout, "modltest: ", log.Lmicroseconds)) + dbmap.AddTableWithName(WithEmbeddedStruct{}, "embedded_struct_test").SetKeys(true, "ID") + err := dbmap.CreateTables() + if err != nil { + t.Errorf("couldn't create embedded_struct_test: %v", err) + } + defer dbmap.DropTables() + + row := &WithEmbeddedStruct{Names: Names{"Alice", "Smith"}} + err = dbmap.Insert(row) + if err != nil { + t.Errorf("Error inserting into table w/embedded struct: %v", err) + } + + var es WithEmbeddedStruct + err = dbmap.Get(&es, row.Id) + if err != nil { + t.Errorf("Error selecting from table w/embedded struct: %v", err) + } +} + func BenchmarkNativeCrud(b *testing.B) { var err error From d6485b99aa47f8dad64193930ed687f1efabb156 Mon Sep 17 00:00:00 2001 From: Quinn Slack Date: Sun, 19 Oct 2014 11:28:57 -0700 Subject: [PATCH 2/2] Handle auto-increment columns in or after embedded structs Fixes issue reported by @cryptix at https://github.com/jmoiron/modl/pull/17#issuecomment-55384738. --- dbmap.go | 16 ++++++++++--- modl.go | 8 +++---- modl_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++ tablemap.go | 11 +++++---- 4 files changed, 87 insertions(+), 12 deletions(-) diff --git a/dbmap.go b/dbmap.go index 02740c0..b37a2e8 100644 --- a/dbmap.go +++ b/dbmap.go @@ -111,7 +111,7 @@ func (m *DbMap) AddTable(i interface{}, name ...string) *TableMap { tmap := &TableMap{gotype: t, TableName: Name, dbmap: m, mapper: m.mapper} tmap.setupHooks(i) - tmap.Columns = columnMaps(t, tmap) + tmap.Columns = columnMaps(t, tmap, nil) for _, cm := range tmap.Columns { if cm.fieldName == "Version" { tmap.version = cm @@ -123,13 +123,13 @@ func (m *DbMap) AddTable(i interface{}, name ...string) *TableMap { } -func columnMaps(t reflect.Type, tmap *TableMap) []*ColumnMap { +func columnMaps(t reflect.Type, tmap *TableMap, parentFieldIdx []int) []*ColumnMap { var cols []*ColumnMap for i := 0; i < t.NumField(); i++ { f := t.Field(i) name := f.Tag.Get("db") if f.Anonymous { - cols = append(cols, columnMaps(f.Type, tmap)...) + cols = append(cols, columnMaps(f.Type, tmap, makeFieldIdx(parentFieldIdx, i))...) } else { if name == "" { name = sqlx.NameMapper(f.Name) @@ -138,6 +138,7 @@ func columnMaps(t reflect.Type, tmap *TableMap) []*ColumnMap { ColumnName: name, Transient: name == "-", fieldName: f.Name, + fieldIdx: makeFieldIdx(parentFieldIdx, i), gotype: f.Type, table: tmap, }) @@ -146,6 +147,15 @@ func columnMaps(t reflect.Type, tmap *TableMap) []*ColumnMap { return cols } +// makeFieldIdx returns a new slice whose elements are equal to +// append(parent, i). +func makeFieldIdx(parent []int, i int) []int { + s := make([]int, len(parent)+1) + copy(s, parent) + s[len(s)-1] = i + return s +} + // AddTableWithName adds a new mapping of the interface to a table name. func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap { return m.AddTable(i, name) diff --git a/modl.go b/modl.go index a05c4a5..27ff76a 100644 --- a/modl.go +++ b/modl.go @@ -65,7 +65,7 @@ type bindPlan struct { argFields []string keyFields []string versField string - autoIncrIdx int + autoIncrIdx []int // Go struct field index } func (plan bindPlan) createBindInstance(elem reflect.Value) bindInstance { @@ -103,7 +103,7 @@ type bindInstance struct { keys []interface{} existingVersion int64 versField string - autoIncrIdx int + autoIncrIdx []int // the autoincr. column's Go struct field index } // SqlExecutor exposes modl operations that can be run from Pre/Post @@ -330,12 +330,12 @@ func insert(m *DbMap, e SqlExecutor, list ...interface{}) error { bi := table.bindInsert(elem) - if bi.autoIncrIdx > -1 { + if bi.autoIncrIdx != nil { id, err := m.Dialect.InsertAutoIncr(e, bi.query, bi.args...) if err != nil { return err } - f := elem.Field(bi.autoIncrIdx) + f := elem.FieldByIndex(bi.autoIncrIdx) k := f.Kind() if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) { f.SetInt(id) diff --git a/modl_test.go b/modl_test.go index bdaeb4c..ed38859 100644 --- a/modl_test.go +++ b/modl_test.go @@ -631,6 +631,70 @@ func TestWithEmbeddedStruct(t *testing.T) { } } +func TestWithEmbeddedStructAutoIncrColNotFirst(t *testing.T) { + // Tests that the tablemap retains separate indices for SQL + // columns (which are flattened with respect to struct embedding) + // and Go fields (which are not). In this test case, the + // auto-incremented column is the 3rd column in SQL but the 2nd Go + // field. + + type Embedded struct{ A, B string } + type withAutoIncrColNotFirst struct { + Embedded + ID int + } + + dbmap := newDbMap() + //dbmap.TraceOn("", log.New(os.Stdout, "modltest: ", log.Lmicroseconds)) + dbmap.AddTableWithName(withAutoIncrColNotFirst{}, "auto_incr_col_not_first_test").SetKeys(true, "ID") + if err := dbmap.CreateTables(); err != nil { + t.Errorf("couldn't create auto_incr_col_not_first_test: %v", err) + } + defer dbmap.Cleanup() + + row := withAutoIncrColNotFirst{Embedded: Embedded{A: "a"}, ID: 0} + if err := dbmap.Insert(&row); err != nil { + t.Fatal(err) + } + + var got withAutoIncrColNotFirst + if err := dbmap.Get(&got, row.ID); err != nil { + t.Fatal(err) + } + if got != row { + t.Errorf("Got %+v, want %+v", got, row) + } +} + +func TestWithEmbeddedAutoIncrCol(t *testing.T) { + type EmbeddedID struct { + A string + ID int + } + type embeddedAutoIncrCol struct{ EmbeddedID } + + dbmap := newDbMap() + //dbmap.TraceOn("", log.New(os.Stdout, "modltest: ", log.Lmicroseconds)) + dbmap.AddTableWithName(embeddedAutoIncrCol{}, "embedded_auto_incr_col_test").SetKeys(true, "ID") + if err := dbmap.CreateTables(); err != nil { + t.Errorf("couldn't create embedded_auto_incr_col_test: %v", err) + } + defer dbmap.Cleanup() + + row := embeddedAutoIncrCol{EmbeddedID{A: "a", ID: 0}} + if err := dbmap.Insert(&row); err != nil { + t.Fatal(err) + } + + var got embeddedAutoIncrCol + if err := dbmap.Get(&got, row.ID); err != nil { + t.Fatal(err) + } + if got != row { + t.Errorf("Got %+v, want %+v", got, row) + } +} + func BenchmarkNativeCrud(b *testing.B) { var err error diff --git a/tablemap.go b/tablemap.go index 98b32f3..49c9e64 100644 --- a/tablemap.go +++ b/tablemap.go @@ -237,12 +237,11 @@ func (t *TableMap) bindUpdate(elem reflect.Value) bindInstance { func (t *TableMap) bindInsert(elem reflect.Value) bindInstance { plan := t.insertPlan if plan.query == "" { - plan.autoIncrIdx = -1 - s := bytes.Buffer{} s2 := bytes.Buffer{} s.WriteString(fmt.Sprintf("insert into %s (", t.dbmap.Dialect.QuoteField(t.TableName))) + var autoIncrCol *ColumnMap x := 0 first := true for y := range t.Columns { @@ -257,7 +256,8 @@ func (t *TableMap) bindInsert(elem reflect.Value) bindInstance { if col.isAutoIncr { s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue()) - plan.autoIncrIdx = y + plan.autoIncrIdx = col.fieldIdx + autoIncrCol = col } else { s2.WriteString(t.dbmap.Dialect.BindVar(x)) if col == t.version { @@ -276,8 +276,8 @@ func (t *TableMap) bindInsert(elem reflect.Value) bindInstance { s.WriteString(") values (") s.WriteString(s2.String()) s.WriteString(")") - if plan.autoIncrIdx > -1 { - s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.Columns[plan.autoIncrIdx])) + if autoIncrCol != nil { + s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(autoIncrCol)) } s.WriteString(";") @@ -310,6 +310,7 @@ type ColumnMap struct { table *TableMap fieldName string + fieldIdx []int // Go struct field index from table's struct gotype reflect.Type sqltype string createSql string