diff --git a/decode_hooks.go b/decode_hooks.go index a3dcc133..3d933794 100644 --- a/decode_hooks.go +++ b/decode_hooks.go @@ -22,60 +22,81 @@ func safeInterface(v reflect.Value) any { return v.Interface() } -// typedDecodeHook takes a raw DecodeHookFunc (an any) and turns -// it into the proper DecodeHookFunc type, such as DecodeHookFuncType. -func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc { - // Create variables here so we can reference them with the reflect pkg - var f1 DecodeHookFuncType - var f2 DecodeHookFuncKind - var f3 DecodeHookFuncValue - - // Fill in the variables into this interface and the rest is done - // automatically using the reflect package. - potential := []any{f1, f2, f3} - - v := reflect.ValueOf(h) - vt := v.Type() - for _, raw := range potential { - pt := reflect.ValueOf(raw).Type() - if vt.ConvertibleTo(pt) { - return v.Convert(pt).Interface() +// DecodeHookFuncTyped restricting the hook types can be unified to a common DecodeHookFuncValue form. +type DecodeHookFuncTyped interface { + // Unify returns a DecodeHookFuncValue that can be used directly in the decoder, don't return a nil plz. + Unify() DecodeHookFuncValue +} + +// unifyDecodeHook takes a raw DecodeHookFunc (an any) and turns it into a DecodeHookFuncValue(most wide form). +// if the type fails to convert we return a closure always erroring to keep the previous behavior +func unifyDecodeHook(h DecodeHookFunc) DecodeHookFuncValue { + // Note: old versions panicked on nil + var typed DecodeHookFuncTyped + switch v := h.(type) { + case func(reflect.Type, reflect.Type, any) (any, error): // DecodeHookFuncType(implicitly) + typed = DecodeHookFuncType(v) + case func(reflect.Kind, reflect.Kind, any) (any, error): // DecodeHookFuncKind(implicitly) + typed = DecodeHookFuncKind(v) + case func(reflect.Value, reflect.Value) (any, error): // DecodeHookFuncValue(implicitly) + typed = DecodeHookFuncValue(v) + case DecodeHookFuncTyped: // Implemented decodeHookFuncTyped(explicitly type, internal/custom) + typed = v + default: + // Maybe some valid signature derived types that doesn't implement DecodeHookFuncTyped, but can be converted to a hook type + // try reflect-based conversion before giving up + rv := reflect.ValueOf(h) + candidates := []DecodeHookFuncTyped{ + DecodeHookFuncType(nil), + DecodeHookFuncKind(nil), + DecodeHookFuncValue(nil), + } + for _, candidate := range candidates { + ct := reflect.TypeOf(candidate) + if !rv.CanConvert(ct) { + continue + } + // Convert it, then can be recognized as internal type in inner pass + anyV := rv.Convert(ct).Interface() + return unifyDecodeHook(anyV) } - } - - return nil -} -// cachedDecodeHook takes a raw DecodeHookFunc (an any) and turns -// it into a closure to be used directly -// if the type fails to convert we return a closure always erroring to keep the previous behaviour -func cachedDecodeHook(raw DecodeHookFunc) func(from reflect.Value, to reflect.Value) (any, error) { - switch f := typedDecodeHook(raw).(type) { - case DecodeHookFuncType: + // Not a valid hook type, return a closure that always errors return func(from reflect.Value, to reflect.Value) (any, error) { - if !from.IsValid() { - return f(reflect.TypeOf((*any)(nil)).Elem(), to.Type(), nil) - } - return f(from.Type(), to.Type(), from.Interface()) + return nil, errors.New("invalid decode hook signature") } - case DecodeHookFuncKind: + } + + unified := typed.Unify() + if unified == nil { + // Unify should never return nil, guards for further safety (maybe custom decodeHookFuncTyped) return func(from reflect.Value, to reflect.Value) (any, error) { - if !from.IsValid() { - return f(reflect.Invalid, to.Kind(), nil) - } - return f(from.Kind(), to.Kind(), from.Interface()) + return nil, fmt.Errorf("failed to unify decode hook: (%T).Unify() returned nil", typed) } - case DecodeHookFuncValue: - return func(from reflect.Value, to reflect.Value) (any, error) { - return f(from, to) + } + return unified +} + +func (h DecodeHookFuncType) Unify() DecodeHookFuncValue { + return func(from reflect.Value, to reflect.Value) (any, error) { + if !from.IsValid() { + return h(reflect.TypeOf((*any)(nil)).Elem(), to.Type(), nil) } - default: - return func(from reflect.Value, to reflect.Value) (any, error) { - return nil, errors.New("invalid decode hook signature") + return h(from.Type(), to.Type(), from.Interface()) + } +} + +func (h DecodeHookFuncKind) Unify() DecodeHookFuncValue { + return func(from reflect.Value, to reflect.Value) (any, error) { + if !from.IsValid() { + return h(reflect.Invalid, to.Kind(), nil) } + return h(from.Kind(), to.Kind(), from.Interface()) } } +func (h DecodeHookFuncValue) Unify() DecodeHookFuncValue { return h } + // DecodeHookExec executes the given decode hook. This should be used // since it'll naturally degrade to the older backwards compatible DecodeHookFunc // that took reflect.Kind instead of reflect.Type. @@ -83,40 +104,31 @@ func DecodeHookExec( raw DecodeHookFunc, from reflect.Value, to reflect.Value, ) (any, error) { - switch f := typedDecodeHook(raw).(type) { - case DecodeHookFuncType: - if !from.IsValid() { - return f(reflect.TypeOf((*any)(nil)).Elem(), to.Type(), nil) - } - return f(from.Type(), to.Type(), from.Interface()) - case DecodeHookFuncKind: - if !from.IsValid() { - return f(reflect.Invalid, to.Kind(), nil) - } - return f(from.Kind(), to.Kind(), from.Interface()) - case DecodeHookFuncValue: - return f(from, to) - default: - return nil, errors.New("invalid decode hook signature") - } + unified := unifyDecodeHook(raw) + return unified(from, to) } // ComposeDecodeHookFunc creates a single DecodeHookFunc that // automatically composes multiple DecodeHookFuncs. // +// Given hooks should be one of the three function signatures: +// - [DecodeHookFuncType] func(reflect.Type, reflect.Type, any) (any, error) +// - [DecodeHookFuncKind] func(reflect.Kind, reflect.Kind, any) (any, error) +// - [DecodeHookFuncValue] func(reflect.Value, reflect.Value) (any, error) +// // The composed funcs are called in order, with the result of the // previous transformation. -func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc { - cached := make([]func(from reflect.Value, to reflect.Value) (any, error), 0, len(fs)) +func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFuncValue { + unified := make([]DecodeHookFuncValue, 0, len(fs)) for _, f := range fs { - cached = append(cached, cachedDecodeHook(f)) + unified = append(unified, unifyDecodeHook(f)) } return func(f reflect.Value, t reflect.Value) (any, error) { var err error data := safeInterface(f) newFrom := f - for _, c := range cached { + for _, c := range unified { data, err = c(newFrom, t) if err != nil { return nil, err @@ -137,17 +149,22 @@ func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc { // OrComposeDecodeHookFunc executes all input hook functions until one of them returns no error. In that case its value is returned. // If all hooks return an error, OrComposeDecodeHookFunc returns an error concatenating all error messages. -func OrComposeDecodeHookFunc(ff ...DecodeHookFunc) DecodeHookFunc { - cached := make([]func(from reflect.Value, to reflect.Value) (any, error), 0, len(ff)) +// +// Given hooks should be one of the three function signatures: +// - [DecodeHookFuncType] func(reflect.Type, reflect.Type, any) (any, error) +// - [DecodeHookFuncKind] func(reflect.Kind, reflect.Kind, any) (any, error) +// - [DecodeHookFuncValue] func(reflect.Value, reflect.Value) (any, error) +func OrComposeDecodeHookFunc(ff ...DecodeHookFunc) DecodeHookFuncValue { + unified := make([]DecodeHookFuncValue, 0, len(ff)) for _, f := range ff { - cached = append(cached, cachedDecodeHook(f)) + unified = append(unified, unifyDecodeHook(f)) } return func(a, b reflect.Value) (any, error) { var allErrs string var out any var err error - for _, c := range cached { + for _, c := range unified { out, err = c(a, b) if err != nil { allErrs += err.Error() + "\n" @@ -161,9 +178,9 @@ func OrComposeDecodeHookFunc(ff ...DecodeHookFunc) DecodeHookFunc { } } -// StringToSliceHookFunc returns a DecodeHookFunc that converts +// StringToSliceHookFunc returns a DecodeHookFuncType that converts // string to []string by splitting on the given sep. -func StringToSliceHookFunc(sep string) DecodeHookFunc { +func StringToSliceHookFunc(sep string) DecodeHookFuncType { return func( f reflect.Type, t reflect.Type, @@ -189,7 +206,7 @@ func StringToSliceHookFunc(sep string) DecodeHookFunc { // // As of mapstructure v2.0.0 [StringToSliceHookFunc] checks if the return type is a string slice. // This function removes that check. -func StringToWeakSliceHookFunc(sep string) DecodeHookFunc { +func StringToWeakSliceHookFunc(sep string) DecodeHookFuncType { return func( f reflect.Type, t reflect.Type, @@ -208,9 +225,9 @@ func StringToWeakSliceHookFunc(sep string) DecodeHookFunc { } } -// StringToTimeDurationHookFunc returns a DecodeHookFunc that converts +// StringToTimeDurationHookFunc returns a DecodeHookFuncType that converts // strings to time.Duration. -func StringToTimeDurationHookFunc() DecodeHookFunc { +func StringToTimeDurationHookFunc() DecodeHookFuncType { return func( f reflect.Type, t reflect.Type, @@ -230,9 +247,9 @@ func StringToTimeDurationHookFunc() DecodeHookFunc { } } -// StringToTimeLocationHookFunc returns a DecodeHookFunc that converts +// StringToTimeLocationHookFunc returns a DecodeHookFuncType that converts // strings to *time.Location. -func StringToTimeLocationHookFunc() DecodeHookFunc { +func StringToTimeLocationHookFunc() DecodeHookFuncType { return func( f reflect.Type, t reflect.Type, @@ -250,9 +267,9 @@ func StringToTimeLocationHookFunc() DecodeHookFunc { } } -// StringToURLHookFunc returns a DecodeHookFunc that converts +// StringToURLHookFunc returns a DecodeHookFuncType that converts // strings to *url.URL. -func StringToURLHookFunc() DecodeHookFunc { +func StringToURLHookFunc() DecodeHookFuncType { return func( f reflect.Type, t reflect.Type, @@ -272,9 +289,9 @@ func StringToURLHookFunc() DecodeHookFunc { } } -// StringToIPHookFunc returns a DecodeHookFunc that converts +// StringToIPHookFunc returns a DecodeHookFuncType that converts // strings to net.IP -func StringToIPHookFunc() DecodeHookFunc { +func StringToIPHookFunc() DecodeHookFuncType { return func( f reflect.Type, t reflect.Type, @@ -297,9 +314,9 @@ func StringToIPHookFunc() DecodeHookFunc { } } -// StringToIPNetHookFunc returns a DecodeHookFunc that converts +// StringToIPNetHookFunc returns a DecodeHookFuncType that converts // strings to net.IPNet -func StringToIPNetHookFunc() DecodeHookFunc { +func StringToIPNetHookFunc() DecodeHookFuncType { return func( f reflect.Type, t reflect.Type, @@ -318,9 +335,9 @@ func StringToIPNetHookFunc() DecodeHookFunc { } } -// StringToTimeHookFunc returns a DecodeHookFunc that converts +// StringToTimeHookFunc returns a DecodeHookFuncType that converts // strings to time.Time. -func StringToTimeHookFunc(layout string) DecodeHookFunc { +func StringToTimeHookFunc(layout string) DecodeHookFuncType { return func( f reflect.Type, t reflect.Type, @@ -377,7 +394,7 @@ func WeaklyTypedHook( return data, nil } -func RecursiveStructToMapHookFunc() DecodeHookFunc { +func RecursiveStructToMapHookFunc() DecodeHookFuncValue { return func(f reflect.Value, t reflect.Value) (any, error) { if f.Kind() != reflect.Struct { return f.Interface(), nil @@ -395,7 +412,7 @@ func RecursiveStructToMapHookFunc() DecodeHookFunc { } } -// TextUnmarshallerHookFunc returns a DecodeHookFunc that applies +// TextUnmarshallerHookFunc returns a DecodeHookFuncType that applies // strings to the UnmarshalText function, when the target type // implements the encoding.TextUnmarshaler interface func TextUnmarshallerHookFunc() DecodeHookFuncType { @@ -423,9 +440,9 @@ func TextUnmarshallerHookFunc() DecodeHookFuncType { } } -// StringToNetIPAddrHookFunc returns a DecodeHookFunc that converts +// StringToNetIPAddrHookFunc returns a DecodeHookFuncType that converts // strings to netip.Addr. -func StringToNetIPAddrHookFunc() DecodeHookFunc { +func StringToNetIPAddrHookFunc() DecodeHookFuncType { return func( f reflect.Type, t reflect.Type, @@ -445,9 +462,9 @@ func StringToNetIPAddrHookFunc() DecodeHookFunc { } } -// StringToNetIPAddrPortHookFunc returns a DecodeHookFunc that converts +// StringToNetIPAddrPortHookFunc returns a DecodeHookFuncType that converts // strings to netip.AddrPort. -func StringToNetIPAddrPortHookFunc() DecodeHookFunc { +func StringToNetIPAddrPortHookFunc() DecodeHookFuncType { return func( f reflect.Type, t reflect.Type, @@ -467,9 +484,9 @@ func StringToNetIPAddrPortHookFunc() DecodeHookFunc { } } -// StringToNetIPPrefixHookFunc returns a DecodeHookFunc that converts +// StringToNetIPPrefixHookFunc returns a DecodeHookFuncType that converts // strings to netip.Prefix. -func StringToNetIPPrefixHookFunc() DecodeHookFunc { +func StringToNetIPPrefixHookFunc() DecodeHookFuncType { return func( f reflect.Type, t reflect.Type, @@ -489,10 +506,10 @@ func StringToNetIPPrefixHookFunc() DecodeHookFunc { } } -// StringToBasicTypeHookFunc returns a DecodeHookFunc that converts +// StringToBasicTypeHookFunc returns a DecodeHookFuncValue that converts // strings to basic types. // int8, uint8, int16, uint16, int32, uint32, int64, uint64, int, uint, float32, float64, bool, byte, rune, complex64, complex128 -func StringToBasicTypeHookFunc() DecodeHookFunc { +func StringToBasicTypeHookFunc() DecodeHookFuncValue { return ComposeDecodeHookFunc( StringToInt8HookFunc(), StringToUint8HookFunc(), @@ -515,9 +532,9 @@ func StringToBasicTypeHookFunc() DecodeHookFunc { ) } -// StringToInt8HookFunc returns a DecodeHookFunc that converts +// StringToInt8HookFunc returns a DecodeHookFuncType that converts // strings to int8. -func StringToInt8HookFunc() DecodeHookFunc { +func StringToInt8HookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Int8 { return data, nil @@ -529,9 +546,9 @@ func StringToInt8HookFunc() DecodeHookFunc { } } -// StringToUint8HookFunc returns a DecodeHookFunc that converts +// StringToUint8HookFunc returns a DecodeHookFuncType that converts // strings to uint8. -func StringToUint8HookFunc() DecodeHookFunc { +func StringToUint8HookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Uint8 { return data, nil @@ -543,9 +560,9 @@ func StringToUint8HookFunc() DecodeHookFunc { } } -// StringToInt16HookFunc returns a DecodeHookFunc that converts +// StringToInt16HookFunc returns a DecodeHookFuncType that converts // strings to int16. -func StringToInt16HookFunc() DecodeHookFunc { +func StringToInt16HookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Int16 { return data, nil @@ -557,9 +574,9 @@ func StringToInt16HookFunc() DecodeHookFunc { } } -// StringToUint16HookFunc returns a DecodeHookFunc that converts +// StringToUint16HookFunc returns a DecodeHookFuncType that converts // strings to uint16. -func StringToUint16HookFunc() DecodeHookFunc { +func StringToUint16HookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Uint16 { return data, nil @@ -571,9 +588,9 @@ func StringToUint16HookFunc() DecodeHookFunc { } } -// StringToInt32HookFunc returns a DecodeHookFunc that converts +// StringToInt32HookFunc returns a DecodeHookFuncType that converts // strings to int32. -func StringToInt32HookFunc() DecodeHookFunc { +func StringToInt32HookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Int32 { return data, nil @@ -585,9 +602,9 @@ func StringToInt32HookFunc() DecodeHookFunc { } } -// StringToUint32HookFunc returns a DecodeHookFunc that converts +// StringToUint32HookFunc returns a DecodeHookFuncType that converts // strings to uint32. -func StringToUint32HookFunc() DecodeHookFunc { +func StringToUint32HookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Uint32 { return data, nil @@ -599,9 +616,9 @@ func StringToUint32HookFunc() DecodeHookFunc { } } -// StringToInt64HookFunc returns a DecodeHookFunc that converts +// StringToInt64HookFunc returns a DecodeHookFuncType that converts // strings to int64. -func StringToInt64HookFunc() DecodeHookFunc { +func StringToInt64HookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Int64 { return data, nil @@ -613,9 +630,9 @@ func StringToInt64HookFunc() DecodeHookFunc { } } -// StringToUint64HookFunc returns a DecodeHookFunc that converts +// StringToUint64HookFunc returns a DecodeHookFuncType that converts // strings to uint64. -func StringToUint64HookFunc() DecodeHookFunc { +func StringToUint64HookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Uint64 { return data, nil @@ -627,9 +644,9 @@ func StringToUint64HookFunc() DecodeHookFunc { } } -// StringToIntHookFunc returns a DecodeHookFunc that converts +// StringToIntHookFunc returns a DecodeHookFuncType that converts // strings to int. -func StringToIntHookFunc() DecodeHookFunc { +func StringToIntHookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Int { return data, nil @@ -641,9 +658,9 @@ func StringToIntHookFunc() DecodeHookFunc { } } -// StringToUintHookFunc returns a DecodeHookFunc that converts +// StringToUintHookFunc returns a DecodeHookFuncType that converts // strings to uint. -func StringToUintHookFunc() DecodeHookFunc { +func StringToUintHookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Uint { return data, nil @@ -655,9 +672,9 @@ func StringToUintHookFunc() DecodeHookFunc { } } -// StringToFloat32HookFunc returns a DecodeHookFunc that converts +// StringToFloat32HookFunc returns a DecodeHookFuncType that converts // strings to float32. -func StringToFloat32HookFunc() DecodeHookFunc { +func StringToFloat32HookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Float32 { return data, nil @@ -669,9 +686,9 @@ func StringToFloat32HookFunc() DecodeHookFunc { } } -// StringToFloat64HookFunc returns a DecodeHookFunc that converts +// StringToFloat64HookFunc returns a DecodeHookFuncType that converts // strings to float64. -func StringToFloat64HookFunc() DecodeHookFunc { +func StringToFloat64HookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Float64 { return data, nil @@ -683,9 +700,9 @@ func StringToFloat64HookFunc() DecodeHookFunc { } } -// StringToBoolHookFunc returns a DecodeHookFunc that converts +// StringToBoolHookFunc returns a DecodeHookFuncType that converts // strings to bool. -func StringToBoolHookFunc() DecodeHookFunc { +func StringToBoolHookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Bool { return data, nil @@ -697,21 +714,21 @@ func StringToBoolHookFunc() DecodeHookFunc { } } -// StringToByteHookFunc returns a DecodeHookFunc that converts +// StringToByteHookFunc returns a DecodeHookFuncType that converts // strings to byte. -func StringToByteHookFunc() DecodeHookFunc { +func StringToByteHookFunc() DecodeHookFuncType { return StringToUint8HookFunc() } -// StringToRuneHookFunc returns a DecodeHookFunc that converts +// StringToRuneHookFunc returns a DecodeHookFuncType that converts // strings to rune. -func StringToRuneHookFunc() DecodeHookFunc { +func StringToRuneHookFunc() DecodeHookFuncType { return StringToInt32HookFunc() } -// StringToComplex64HookFunc returns a DecodeHookFunc that converts +// StringToComplex64HookFunc returns a DecodeHookFuncType that converts // strings to complex64. -func StringToComplex64HookFunc() DecodeHookFunc { +func StringToComplex64HookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Complex64 { return data, nil @@ -723,9 +740,9 @@ func StringToComplex64HookFunc() DecodeHookFunc { } } -// StringToComplex128HookFunc returns a DecodeHookFunc that converts +// StringToComplex128HookFunc returns a DecodeHookFuncType that converts // strings to complex128. -func StringToComplex128HookFunc() DecodeHookFunc { +func StringToComplex128HookFunc() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if f.Kind() != reflect.String || t.Kind() != reflect.Complex128 { return data, nil diff --git a/decode_hooks_test.go b/decode_hooks_test.go index 02d5e662..1438d996 100644 --- a/decode_hooks_test.go +++ b/decode_hooks_test.go @@ -2172,3 +2172,98 @@ func TestErrorLeakageDecodeHook(t *testing.T) { } } } + +type customHookType string + +func (c customHookType) Unify() DecodeHookFuncValue { + if c == "send nil cuz im bad" { + return nil + } + return func(from, to reflect.Value) (any, error) { + return string(c), nil + } +} + +func Test_unifyDecodeHook(t *testing.T) { + checkResultPassed := func(hook DecodeHookFuncValue) { + got, err := hook(reflect.ValueOf(""), reflect.ValueOf(0)) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if got != "passed" { + t.Fatalf("expected 'passed', got %v", got) + } + } + + t.Run("Passing DecodeHookFuncType", func(t *testing.T) { + hook := unifyDecodeHook(func(from, to reflect.Type, data any) (any, error) { + return "passed", nil + }) + checkResultPassed(hook) + }) + t.Run("Passing DecodeHookFuncType-Explicitly", func(t *testing.T) { + fn := DecodeHookFuncType(func(from, to reflect.Type, data any) (any, error) { + return "passed", nil + }) + hook := unifyDecodeHook(fn) + checkResultPassed(hook) + }) + + t.Run("Passing DecodeHookFuncKind", func(t *testing.T) { + hook := unifyDecodeHook(func(from, to reflect.Kind, data any) (any, error) { + return "passed", nil + }) + checkResultPassed(hook) + }) + t.Run("Passing DecodeHookFuncKind-Explicitly", func(t *testing.T) { + fn := DecodeHookFuncKind(func(from, to reflect.Kind, data any) (any, error) { + return "passed", nil + }) + hook := unifyDecodeHook(fn) + checkResultPassed(hook) + }) + + t.Run("Passing DecodeHookFuncValue", func(t *testing.T) { + hook := unifyDecodeHook(func(from, to reflect.Value) (any, error) { + return "passed", nil + }) + checkResultPassed(hook) + }) + t.Run("Passing DecodeHookFuncValue-Explicitly", func(t *testing.T) { + fn := DecodeHookFuncValue(func(from, to reflect.Value) (any, error) { + return "passed", nil + }) + hook := unifyDecodeHook(fn) + checkResultPassed(hook) + }) + + t.Run("Passing non-hook type", func(t *testing.T) { + hook := unifyDecodeHook(42) + got, err := hook(reflect.ValueOf(""), reflect.ValueOf(0)) + if err == nil { + t.Fatalf("expected error, got nil with output: %v", got) + } + }) + + t.Run("Passing DecodeHookFuncValue-Derived", func(t *testing.T) { + type MyDecodeHookFuncValue func(from, to reflect.Value) (any, error) + fn := MyDecodeHookFuncValue(func(from, to reflect.Value) (any, error) { + return "passed", nil + }) + hook := unifyDecodeHook(fn) + checkResultPassed(hook) + }) + + t.Run("Custom hook type implementing Unify", func(t *testing.T) { + hook := unifyDecodeHook(customHookType("passed")) + checkResultPassed(hook) + }) + + t.Run("Custom hook type implementing Unify that returns nil", func(t *testing.T) { + hook := unifyDecodeHook(customHookType("send nil cuz im bad")) + got, err := hook(reflect.ValueOf(""), reflect.ValueOf(0)) + if err == nil { + t.Fatalf("expected error, got nil with output: %v", got) + } + }) +} diff --git a/mapstructure.go b/mapstructure.go index 9087fd96..c53751ff 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -521,7 +521,7 @@ func NewDecoder(config *DecoderConfig) (*Decoder, error) { config: config, } if config.DecodeHook != nil { - result.cachedDecodeHook = cachedDecodeHook(config.DecodeHook) + result.cachedDecodeHook = unifyDecodeHook(config.DecodeHook) } return result, nil diff --git a/mapstructure_examples_test.go b/mapstructure_examples_test.go index 6eb1f26a..8da5d2e1 100644 --- a/mapstructure_examples_test.go +++ b/mapstructure_examples_test.go @@ -312,7 +312,7 @@ func ExampleDecode_decodeHookFunc() { "location": "-35.2809#149.1300", } - toPersonLocationHookFunc := func() DecodeHookFunc { + toPersonLocationHookFunc := func() DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data any) (any, error) { if t != reflect.TypeOf(PersonLocation{}) { return data, nil