diff --git a/decode_hooks_test.go b/decode_hooks_test.go index 02d5e662..1625ce2d 100644 --- a/decode_hooks_test.go +++ b/decode_hooks_test.go @@ -1002,6 +1002,65 @@ func TestStructToMapHookFuncTabled(t *testing.T) { } } +func TestNestedStructFieldTypeOverride(t *testing.T) { + type Struct struct { + Time time.Time `mapstructure:"time"` + } + + input := Struct{ + Time: time.Now(), + } + + var actualNoHook map[string]any + + d, err := NewDecoder(&DecoderConfig{Result: &actualNoHook}) + if err != nil { + t.Fatalf("unexpected err %#v", err) + } + + expectedNoHook := map[string]any{} + + err = d.Decode(input) + if err != nil { + t.Fatalf("unexpected err %#v", err) + } + + if !reflect.DeepEqual(expectedNoHook, actualNoHook["time"]) { + t.Fatalf("expected %#v, got %#v", expectedNoHook, actualNoHook["time"]) + } + + var actualHook map[string]any + + d, err = NewDecoder(&DecoderConfig{ + Result: &actualHook, + DecodeHook: func(from, to reflect.Type, data any) (any, error) { + if tm, ok := data.(time.Time); ok { + return tm.Format(time.RFC3339), nil + } + + if tm, ok := data.(*time.Time); ok { + return tm.Format(time.RFC3339), nil + } + + return data, nil + }, + }) + if err != nil { + t.Fatalf("unexpected err %#v", err) + } + + expectedHook := input.Time.Format(time.RFC3339) + + err = d.Decode(input) + if err != nil { + t.Fatalf("unexpected err %#v", err) + } + + if !reflect.DeepEqual(expectedHook, actualHook["time"]) { + t.Fatalf("expected %#v, got %#v", expectedHook, actualHook["time"]) + } +} + func TestTextUnmarshallerHookFunc(t *testing.T) { type MyString string diff --git a/mapstructure.go b/mapstructure.go index 9087fd96..d22128b1 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -1220,6 +1220,21 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re keyName = tagValue } + if d.cachedDecodeHook != nil { + mapelem := reflect.New(valMap.Type().Elem()).Elem() + + input, err := d.cachedDecodeHook(v, mapelem) + if err != nil { + return fmt.Errorf("error decoding '%s': %w", name, err) + } + + if !reflect.DeepEqual(input, v.Interface()) { + valMap.SetMapIndex(reflect.ValueOf(keyName), reflect.ValueOf(input)) + + continue + } + } + switch v.Kind() { // this is an embedded struct, so handle it differently case reflect.Struct: