Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func execWithParam(c context.Context, tdx Tdx, paramQuery string, paramMap inter
if params != nil && len(params) > 0 {
var args []interface{} = make([]interface{}, 0, len(params))
for _, param := range params {
param = param[2: len(param)-1]
param = param[2 : len(param)-1]
value, err := getFieldValue(paramMap, param)
if err != nil {
return nil, err
Expand Down Expand Up @@ -742,7 +742,7 @@ func selectRawWithParam(c context.Context, tdx Tdx, paramQuery string, paramMap
if params != nil && len(params) > 0 {
var args []interface{} = make([]interface{}, 0, len(params))
for _, param := range params {
param = param[2: len(param)-1]
param = param[2 : len(param)-1]
value, err := getFieldValue(paramMap, param)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -1201,6 +1201,38 @@ func insertOrUpdate(c context.Context, tdx Tdx, s interface{}, fields []string)
return nil
}

func insertOrUpdateByTable(c context.Context, tdx Tdx, tbName string, s interface{}, fields []string) error {
cols, vals, ifs, pk, isAi, pkName := columnsByStruct(s)
if len(fields) == 0 {
fields = strings.Split(cols, ",")
}
//重复时,需要更新的字段
for k, v := range fields {
v = fieldName2ColName(v)
str := fmt.Sprintf("%s=values(%s)", v, v)
fields[k] = str
}
//检查主键的情况,在insert中加入主键
if pk.Addr().Interface != nil {
cols += fmt.Sprintf(",%s", pkName)
vals += ",?"
ifs = append(ifs, pk.Addr().Interface())
}
q := fmt.Sprintf("insert into %s (%s) values (%s) on duplicate key update %s", tbName, cols, vals, strings.Join(fields, ","))
ret, err := exec(c, tdx, q, ifs...)
if err != nil {
return err
}
if isAi {
lid, err := ret.LastInsertId()
if err != nil {
return err
}
pk.SetInt(lid)
}
return nil
}

//通过传递需要更新的字段,去更新部分字段
func updateFieldsByPK(c context.Context, tdx Tdx, s interface{}, cols []string) error {
ifs, pk, _, pkName := columnsByStructFields(s, cols)
Expand Down Expand Up @@ -1425,7 +1457,11 @@ func (o *ORM) InsertBatch(s []interface{}) error {
}

func (o *ORM) InsertOrUpdate(s interface{}, keys []string) error {
return insertOrUpdate(o.ctx, o.db, s, keys)
return insertOrUpdateByTable(o.ctx, o.db, getTableName(s), s, keys)
}

func (o *ORM) InsertOrUpdateWithTable(s interface{}, tbName string, keys []string) error {
return insertOrUpdateByTable(o.ctx, o.db, tbName, s, keys)
}

func (o *ORM) ExecWithRowAffectCheck(n int64, query string, args ...interface{}) error {
Expand Down
21 changes: 14 additions & 7 deletions orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import (
"database/sql"
"errors"
"fmt"
"github.com/magiconair/properties/assert"
"log"
"testing"
"time"

"github.com/magiconair/properties/assert"
)

type TestOrmA123 struct {
TestID int64 `json:"test_id" pk:"true" ai:"true" db:"test_id,ai,pk"`
TestID int64 `json:"test_id" pk:"true" ai:"true" db:"test_id,ai,pk"`
OtherId int64
Description string
Name sql.NullString
Expand All @@ -26,7 +27,7 @@ type TestOrmA123 struct {
}

type TestOrmB999 struct {
NoAiId int64 `pk:"true"`
NoAiId int64 `pk:"true"`
Description string
TestID int64 `db:"test_id"`
CreatedAt time.Time `ignore:"true"`
Expand All @@ -45,7 +46,7 @@ type TestOrmD222 struct {
}

type TestOrmE333 struct {
TestOrmEId int64 `pk:"true" ai:"true"`
TestOrmEId int64 `pk:"true" ai:"true"`
Name string
Description sql.NullString
VInt64 int64
Expand Down Expand Up @@ -342,6 +343,12 @@ func TestOrmInsertOrUpdate(t *testing.T) {
if testObj4.TestID != 3 {
t.Fatal("test id should be 3")
}
testObj1.Description = "update with table"
err = orm.InsertOrUpdateWithTable(testObj1, "test_orm_a123", []string{"description"})
if err != nil {
t.Error(err)
}
assert.Equal(t, testObj1.Description, "update with table")
})
}
func TestORMUpdateFieldsByPK(t *testing.T) {
Expand Down Expand Up @@ -441,7 +448,7 @@ func TestExecParam(t *testing.T) {
"id": testObj.TestID,
"description": "lala",
}
_, err := orm.ExecWithParam("update " + testTableName+
_, err := orm.ExecWithParam("update "+testTableName+
" set other_id = #{otherId}, description = #{description} where test_id = #{id}", paramMap)
if err != nil {
t.Error("failed to update", err)
Expand All @@ -460,7 +467,7 @@ func TestExecParam(t *testing.T) {
"description": "test",
}

_, err = orm.ExecWithParam("update " + testTableName+
_, err = orm.ExecWithParam("update "+testTableName+
" set other_id = #{otherId} + 1, description = #{description} where other_id = #{otherId}", params2)
orm.SelectByPK(&loadedObj, testObj.TestID)

Expand All @@ -477,7 +484,7 @@ func TestExecParam(t *testing.T) {
StartDate: time.Now(),
EndDate: time.Now(),
}
_, err = orm.ExecWithParam("update " + testTableName+
_, err = orm.ExecWithParam("update "+testTableName+
" set other_id = #{OtherId}, description = #{Description}, name = #{Name} where test_id = #{TestID}", testParam)
if err != nil {
t.Error(err)
Expand Down