diff --git a/orm.go b/orm.go index 0330993..2fb44d5 100644 --- a/orm.go +++ b/orm.go @@ -285,7 +285,7 @@ func execWithParam(c context.Context, tdx Tdx, paramQuery string, paramMap inter if params != nil && len(params) > 0 { var args []interface{} = make([]interface{}, 0, len(params)) for _, param := range params { - param = param[2: len(param)-1] + param = param[2 : len(param)-1] value, err := getFieldValue(paramMap, param) if err != nil { return nil, err @@ -742,7 +742,7 @@ func selectRawWithParam(c context.Context, tdx Tdx, paramQuery string, paramMap if params != nil && len(params) > 0 { var args []interface{} = make([]interface{}, 0, len(params)) for _, param := range params { - param = param[2: len(param)-1] + param = param[2 : len(param)-1] value, err := getFieldValue(paramMap, param) if err != nil { return nil, nil, err @@ -1201,6 +1201,38 @@ func insertOrUpdate(c context.Context, tdx Tdx, s interface{}, fields []string) return nil } +func insertOrUpdateByTable(c context.Context, tdx Tdx, tbName string, s interface{}, fields []string) error { + cols, vals, ifs, pk, isAi, pkName := columnsByStruct(s) + if len(fields) == 0 { + fields = strings.Split(cols, ",") + } + //重复时,需要更新的字段 + for k, v := range fields { + v = fieldName2ColName(v) + str := fmt.Sprintf("%s=values(%s)", v, v) + fields[k] = str + } + //检查主键的情况,在insert中加入主键 + if pk.Addr().Interface != nil { + cols += fmt.Sprintf(",%s", pkName) + vals += ",?" + ifs = append(ifs, pk.Addr().Interface()) + } + q := fmt.Sprintf("insert into %s (%s) values (%s) on duplicate key update %s", tbName, cols, vals, strings.Join(fields, ",")) + ret, err := exec(c, tdx, q, ifs...) + if err != nil { + return err + } + if isAi { + lid, err := ret.LastInsertId() + if err != nil { + return err + } + pk.SetInt(lid) + } + return nil +} + //通过传递需要更新的字段,去更新部分字段 func updateFieldsByPK(c context.Context, tdx Tdx, s interface{}, cols []string) error { ifs, pk, _, pkName := columnsByStructFields(s, cols) @@ -1425,7 +1457,11 @@ func (o *ORM) InsertBatch(s []interface{}) error { } func (o *ORM) InsertOrUpdate(s interface{}, keys []string) error { - return insertOrUpdate(o.ctx, o.db, s, keys) + return insertOrUpdateByTable(o.ctx, o.db, getTableName(s), s, keys) +} + +func (o *ORM) InsertOrUpdateWithTable(s interface{}, tbName string, keys []string) error { + return insertOrUpdateByTable(o.ctx, o.db, tbName, s, keys) } func (o *ORM) ExecWithRowAffectCheck(n int64, query string, args ...interface{}) error { diff --git a/orm_test.go b/orm_test.go index a87205f..f07d4c8 100644 --- a/orm_test.go +++ b/orm_test.go @@ -4,14 +4,15 @@ import ( "database/sql" "errors" "fmt" - "github.com/magiconair/properties/assert" "log" "testing" "time" + + "github.com/magiconair/properties/assert" ) type TestOrmA123 struct { - TestID int64 `json:"test_id" pk:"true" ai:"true" db:"test_id,ai,pk"` + TestID int64 `json:"test_id" pk:"true" ai:"true" db:"test_id,ai,pk"` OtherId int64 Description string Name sql.NullString @@ -26,7 +27,7 @@ type TestOrmA123 struct { } type TestOrmB999 struct { - NoAiId int64 `pk:"true"` + NoAiId int64 `pk:"true"` Description string TestID int64 `db:"test_id"` CreatedAt time.Time `ignore:"true"` @@ -45,7 +46,7 @@ type TestOrmD222 struct { } type TestOrmE333 struct { - TestOrmEId int64 `pk:"true" ai:"true"` + TestOrmEId int64 `pk:"true" ai:"true"` Name string Description sql.NullString VInt64 int64 @@ -342,6 +343,12 @@ func TestOrmInsertOrUpdate(t *testing.T) { if testObj4.TestID != 3 { t.Fatal("test id should be 3") } + testObj1.Description = "update with table" + err = orm.InsertOrUpdateWithTable(testObj1, "test_orm_a123", []string{"description"}) + if err != nil { + t.Error(err) + } + assert.Equal(t, testObj1.Description, "update with table") }) } func TestORMUpdateFieldsByPK(t *testing.T) { @@ -441,7 +448,7 @@ func TestExecParam(t *testing.T) { "id": testObj.TestID, "description": "lala", } - _, err := orm.ExecWithParam("update " + testTableName+ + _, err := orm.ExecWithParam("update "+testTableName+ " set other_id = #{otherId}, description = #{description} where test_id = #{id}", paramMap) if err != nil { t.Error("failed to update", err) @@ -460,7 +467,7 @@ func TestExecParam(t *testing.T) { "description": "test", } - _, err = orm.ExecWithParam("update " + testTableName+ + _, err = orm.ExecWithParam("update "+testTableName+ " set other_id = #{otherId} + 1, description = #{description} where other_id = #{otherId}", params2) orm.SelectByPK(&loadedObj, testObj.TestID) @@ -477,7 +484,7 @@ func TestExecParam(t *testing.T) { StartDate: time.Now(), EndDate: time.Now(), } - _, err = orm.ExecWithParam("update " + testTableName+ + _, err = orm.ExecWithParam("update "+testTableName+ " set other_id = #{OtherId}, description = #{Description}, name = #{Name} where test_id = #{TestID}", testParam) if err != nil { t.Error(err)