diff --git a/mapstructure.go b/mapstructure.go index 7581806a..6130c4aa 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -414,18 +414,18 @@ func NewDecoder(config *DecoderConfig) (*Decoder, error) { // Decode decodes the given raw interface to the target pointer specified // by the configuration. func (d *Decoder) Decode(input interface{}) error { - return d.decode("", input, reflect.ValueOf(d.config.Result).Elem()) + return d.decode("", input, reflect.ValueOf(d.config.Result).Elem(), true) } // Decodes an unknown data type into a specific reflection value. -func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) error { +func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value, skipTypedNil bool) error { var inputVal reflect.Value if input != nil { inputVal = reflect.ValueOf(input) // We need to check here if input is a typed nil. Typed nils won't // match the "input == nil" below so we check that here. - if inputVal.Kind() == reflect.Ptr && inputVal.IsNil() { + if skipTypedNil && inputVal.Kind() == reflect.Ptr && inputVal.IsNil() { input = nil } } @@ -529,7 +529,7 @@ func (d *Decoder) decodeBasic(name string, data interface{}, val reflect.Value) // Decode. If we have an error then return. We also return right // away if we're not a copy because that means we decoded directly. - if err := d.decode(name, data, elem); err != nil || !copied { + if err := d.decode(name, data, elem, true); err != nil || !copied { return err } @@ -842,7 +842,7 @@ func (d *Decoder) decodeMapFromSlice(name string, dataVal reflect.Value, val ref for i := 0; i < dataVal.Len(); i++ { err := d.decode( name+"["+strconv.Itoa(i)+"]", - dataVal.Index(i).Interface(), val) + dataVal.Index(i).Interface(), val, true) if err != nil { return err } @@ -878,7 +878,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle // First decode the key into the proper type currentKey := reflect.Indirect(reflect.New(valKeyType)) - if err := d.decode(fieldName, k.Interface(), currentKey); err != nil { + if err := d.decode(fieldName, k.Interface(), currentKey, true); err != nil { errors = appendErrors(errors, err) continue } @@ -886,7 +886,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle // Next decode the data into the proper type v := dataVal.MapIndex(k).Interface() currentVal := reflect.Indirect(reflect.New(valElemType)) - if err := d.decode(fieldName, v, currentVal); err != nil { + if err := d.decode(fieldName, v, currentVal, true); err != nil { errors = appendErrors(errors, err) continue } @@ -918,9 +918,6 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re // Next get the actual value of this field and verify it is assignable // to the map value. v := dataVal.Field(i) - if !v.Type().AssignableTo(valMap.Type().Elem()) { - return fmt.Errorf("cannot assign type '%s' to map value field of type '%s'", v.Type(), valMap.Type().Elem()) - } tagValue := f.Tag.Get(d.config.TagName) keyName := f.Name @@ -986,7 +983,7 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re addrVal := reflect.New(vMap.Type()) reflect.Indirect(addrVal).Set(vMap) - err := d.decode(keyName, x.Interface(), reflect.Indirect(addrVal)) + err := d.decode(keyName, x.Interface(), reflect.Indirect(addrVal), true) if err != nil { return err } @@ -1004,7 +1001,13 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re } default: - valMap.SetMapIndex(reflect.ValueOf(keyName), v) + currentVal := reflect.Indirect(reflect.New(valMap.Type().Elem())) + err := d.decode(keyName, v.Interface(), currentVal, false) + if err != nil { + return err + } + + valMap.SetMapIndex(reflect.ValueOf(keyName), currentVal) } } @@ -1049,13 +1052,13 @@ func (d *Decoder) decodePtr(name string, data interface{}, val reflect.Value) (b realVal = reflect.New(valElemType) } - if err := d.decode(name, data, reflect.Indirect(realVal)); err != nil { + if err := d.decode(name, data, reflect.Indirect(realVal), true); err != nil { return false, err } val.Set(realVal) } else { - if err := d.decode(name, data, reflect.Indirect(val)); err != nil { + if err := d.decode(name, data, reflect.Indirect(val), true); err != nil { return false, err } } @@ -1138,7 +1141,7 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) currentField := valSlice.Index(i) fieldName := name + "[" + strconv.Itoa(i) + "]" - if err := d.decode(fieldName, currentData, currentField); err != nil { + if err := d.decode(fieldName, currentData, currentField, true); err != nil { errors = appendErrors(errors, err) } } @@ -1205,7 +1208,7 @@ func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) currentField := valArray.Index(i) fieldName := name + "[" + strconv.Itoa(i) + "]" - if err := d.decode(fieldName, currentData, currentField); err != nil { + if err := d.decode(fieldName, currentData, currentField, true); err != nil { errors = appendErrors(errors, err) } } @@ -1410,7 +1413,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e fieldName = name + "." + fieldName } - if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil { + if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue, true); err != nil { errors = appendErrors(errors, err) } } diff --git a/mapstructure_test.go b/mapstructure_test.go index d31129d7..3ae72836 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -5,6 +5,7 @@ import ( "io" "reflect" "sort" + "strconv" "strings" "testing" "time" @@ -2732,6 +2733,50 @@ func TestDecoder_IgnoreUntaggedFields(t *testing.T) { } } +func TestDecodeStructToMap_DecodeHook(t *testing.T) { + t.Parallel() + + input := struct { + Vstring string + Vint int + }{ + "foo", + 42, + } + + decodeHook := func(f, t reflect.Type, v interface{}) (interface{}, error) { + if f.Kind() != reflect.Int { + return v, nil + } + val := strconv.FormatInt(int64(v.(int)), 10) + return val, nil + } + + var result map[string]interface{} + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + expected := map[string]interface{}{ + "Vstring": "foo", + "Vint": "42", + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("Decode call result should be %#v, got %#v", expected, result) + } +} + func testSliceInput(t *testing.T, input map[string]interface{}, expected *Slice) { var result Slice err := Decode(input, &result)