Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 36 additions & 18 deletions dbmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,25 +111,10 @@ func (m *DbMap) AddTable(i interface{}, name ...string) *TableMap {
tmap := &TableMap{gotype: t, TableName: Name, dbmap: m, mapper: m.mapper}
tmap.setupHooks(i)

n := t.NumField()
tmap.Columns = make([]*ColumnMap, 0, n)
for i := 0; i < n; i++ {
f := t.Field(i)
columnName := f.Tag.Get("db")
if columnName == "" {
columnName = sqlx.NameMapper(f.Name)
}

cm := &ColumnMap{
ColumnName: columnName,
Transient: columnName == "-",
fieldName: f.Name,
gotype: f.Type,
table: tmap,
}
tmap.Columns = append(tmap.Columns, cm)
tmap.Columns = columnMaps(t, tmap, nil)
for _, cm := range tmap.Columns {
if cm.fieldName == "Version" {
tmap.version = tmap.Columns[len(tmap.Columns)-1]
tmap.version = cm
}
}
m.tables = append(m.tables, tmap)
Expand All @@ -138,6 +123,39 @@ func (m *DbMap) AddTable(i interface{}, name ...string) *TableMap {

}

func columnMaps(t reflect.Type, tmap *TableMap, parentFieldIdx []int) []*ColumnMap {
var cols []*ColumnMap
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
name := f.Tag.Get("db")
if f.Anonymous {
cols = append(cols, columnMaps(f.Type, tmap, makeFieldIdx(parentFieldIdx, i))...)
} else {
if name == "" {
name = sqlx.NameMapper(f.Name)
}
cols = append(cols, &ColumnMap{
ColumnName: name,
Transient: name == "-",
fieldName: f.Name,
fieldIdx: makeFieldIdx(parentFieldIdx, i),
gotype: f.Type,
table: tmap,
})
}
}
return cols
}

// makeFieldIdx returns a new slice whose elements are equal to
// append(parent, i).
func makeFieldIdx(parent []int, i int) []int {
s := make([]int, len(parent)+1)
copy(s, parent)
s[len(s)-1] = i
return s
}

// AddTableWithName adds a new mapping of the interface to a table name.
func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap {
return m.AddTable(i, name)
Expand Down
8 changes: 4 additions & 4 deletions modl.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ type bindPlan struct {
argFields []string
keyFields []string
versField string
autoIncrIdx int
autoIncrIdx []int // Go struct field index
}

func (plan bindPlan) createBindInstance(elem reflect.Value) bindInstance {
Expand Down Expand Up @@ -103,7 +103,7 @@ type bindInstance struct {
keys []interface{}
existingVersion int64
versField string
autoIncrIdx int
autoIncrIdx []int // the autoincr. column's Go struct field index
}

// SqlExecutor exposes modl operations that can be run from Pre/Post
Expand Down Expand Up @@ -330,12 +330,12 @@ func insert(m *DbMap, e SqlExecutor, list ...interface{}) error {

bi := table.bindInsert(elem)

if bi.autoIncrIdx > -1 {
if bi.autoIncrIdx != nil {
id, err := m.Dialect.InsertAutoIncr(e, bi.query, bi.args...)
if err != nil {
return err
}
f := elem.Field(bi.autoIncrIdx)
f := elem.FieldByIndex(bi.autoIncrIdx)
k := f.Kind()
if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) {
f.SetInt(id)
Expand Down
97 changes: 97 additions & 0 deletions modl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ type WithStringPk struct {

type CustomStringType string

type WithEmbeddedStruct struct {
Id int64
Names
}

type Names struct {
FirstName string
LastName string
}

func (p *Person) PreInsert(s SqlExecutor) error {
p.Created = time.Now().UnixNano()
p.Updated = p.Created
Expand Down Expand Up @@ -598,6 +608,93 @@ func TestWithStringPk(t *testing.T) {
}
}

func TestWithEmbeddedStruct(t *testing.T) {
dbmap := newDbMap()
//dbmap.TraceOn("", log.New(os.Stdout, "modltest: ", log.Lmicroseconds))
dbmap.AddTableWithName(WithEmbeddedStruct{}, "embedded_struct_test").SetKeys(true, "ID")
err := dbmap.CreateTables()
if err != nil {
t.Errorf("couldn't create embedded_struct_test: %v", err)
}
defer dbmap.DropTables()

row := &WithEmbeddedStruct{Names: Names{"Alice", "Smith"}}
err = dbmap.Insert(row)
if err != nil {
t.Errorf("Error inserting into table w/embedded struct: %v", err)
}

var es WithEmbeddedStruct
err = dbmap.Get(&es, row.Id)
if err != nil {
t.Errorf("Error selecting from table w/embedded struct: %v", err)
}
}

func TestWithEmbeddedStructAutoIncrColNotFirst(t *testing.T) {
// Tests that the tablemap retains separate indices for SQL
// columns (which are flattened with respect to struct embedding)
// and Go fields (which are not). In this test case, the
// auto-incremented column is the 3rd column in SQL but the 2nd Go
// field.

type Embedded struct{ A, B string }
type withAutoIncrColNotFirst struct {
Embedded
ID int
}

dbmap := newDbMap()
//dbmap.TraceOn("", log.New(os.Stdout, "modltest: ", log.Lmicroseconds))
dbmap.AddTableWithName(withAutoIncrColNotFirst{}, "auto_incr_col_not_first_test").SetKeys(true, "ID")
if err := dbmap.CreateTables(); err != nil {
t.Errorf("couldn't create auto_incr_col_not_first_test: %v", err)
}
defer dbmap.Cleanup()

row := withAutoIncrColNotFirst{Embedded: Embedded{A: "a"}, ID: 0}
if err := dbmap.Insert(&row); err != nil {
t.Fatal(err)
}

var got withAutoIncrColNotFirst
if err := dbmap.Get(&got, row.ID); err != nil {
t.Fatal(err)
}
if got != row {
t.Errorf("Got %+v, want %+v", got, row)
}
}

func TestWithEmbeddedAutoIncrCol(t *testing.T) {
type EmbeddedID struct {
A string
ID int
}
type embeddedAutoIncrCol struct{ EmbeddedID }

dbmap := newDbMap()
//dbmap.TraceOn("", log.New(os.Stdout, "modltest: ", log.Lmicroseconds))
dbmap.AddTableWithName(embeddedAutoIncrCol{}, "embedded_auto_incr_col_test").SetKeys(true, "ID")
if err := dbmap.CreateTables(); err != nil {
t.Errorf("couldn't create embedded_auto_incr_col_test: %v", err)
}
defer dbmap.Cleanup()

row := embeddedAutoIncrCol{EmbeddedID{A: "a", ID: 0}}
if err := dbmap.Insert(&row); err != nil {
t.Fatal(err)
}

var got embeddedAutoIncrCol
if err := dbmap.Get(&got, row.ID); err != nil {
t.Fatal(err)
}
if got != row {
t.Errorf("Got %+v, want %+v", got, row)
}
}

func BenchmarkNativeCrud(b *testing.B) {
var err error

Expand Down
11 changes: 6 additions & 5 deletions tablemap.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,11 @@ func (t *TableMap) bindUpdate(elem reflect.Value) bindInstance {
func (t *TableMap) bindInsert(elem reflect.Value) bindInstance {
plan := t.insertPlan
if plan.query == "" {
plan.autoIncrIdx = -1

s := bytes.Buffer{}
s2 := bytes.Buffer{}
s.WriteString(fmt.Sprintf("insert into %s (", t.dbmap.Dialect.QuoteField(t.TableName)))

var autoIncrCol *ColumnMap
x := 0
first := true
for y := range t.Columns {
Expand All @@ -257,7 +256,8 @@ func (t *TableMap) bindInsert(elem reflect.Value) bindInstance {

if col.isAutoIncr {
s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue())
plan.autoIncrIdx = y
plan.autoIncrIdx = col.fieldIdx
autoIncrCol = col
} else {
s2.WriteString(t.dbmap.Dialect.BindVar(x))
if col == t.version {
Expand All @@ -276,8 +276,8 @@ func (t *TableMap) bindInsert(elem reflect.Value) bindInstance {
s.WriteString(") values (")
s.WriteString(s2.String())
s.WriteString(")")
if plan.autoIncrIdx > -1 {
s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.Columns[plan.autoIncrIdx]))
if autoIncrCol != nil {
s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(autoIncrCol))
}
s.WriteString(";")

Expand Down Expand Up @@ -310,6 +310,7 @@ type ColumnMap struct {
table *TableMap

fieldName string
fieldIdx []int // Go struct field index from table's struct
gotype reflect.Type
sqltype string
createSql string
Expand Down