diff --git a/README.md b/README.md index 497712f..f3aa58d 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,19 @@ go run *.go -token my-secret-token The token is marked as a secret, which is important to avoid leaking its value. +#### Empty Values for Optional Parameters + +It's important to understand how optional parameters with default values behave +when they receive an empty string (`""`) from a configuration source. + +For most types (including numeric types, `bool`, `time.Time`, `time.Duration`, +and most `xtypes`), providing an empty string is treated as an **absent** +value. This means the parameter will correctly use its specified default value, +just as it would if the parameter was omitted entirely. + +The `string` and `xtypes.String` types are an exception. For these, an empty +string is considered a valid, intentional value that will override any default. + ### XTypes _XTypes_ are types provided by _proteus_ to handle complex types and to provide diff --git a/basic_types.go b/basic_types.go index 451e802..2a4d7e3 100644 --- a/basic_types.go +++ b/basic_types.go @@ -5,8 +5,11 @@ import ( "reflect" "strconv" "time" + + "github.com/simplesurance/proteus/types" ) +//nolint:gocyclo func configStandardCallbacks(fieldData *paramSetField, val reflect.Value) error { // the redact function is to allow redacting part of a value, like // redacting the "password" part of an URL. For basic types use @@ -17,12 +20,18 @@ func configStandardCallbacks(fieldData *paramSetField, val reflect.Value) error switch valT := val.Interface().(type) { case time.Time: fieldData.validFn = func(str string) error { + if str == "" { + return types.ErrNoValue + } _, err := time.Parse(time.RFC3339Nano, str) return err } fieldData.setValueFn = func(str *string) error { panicOnNil(str) + if *str == "" { + return nil + } v, err := time.Parse(time.RFC3339Nano, *str) if err != nil { return err @@ -39,12 +48,18 @@ func configStandardCallbacks(fieldData *paramSetField, val reflect.Value) error return nil case time.Duration: fieldData.validFn = func(str string) error { + if str == "" { + return types.ErrNoValue + } _, err := time.ParseDuration(str) return err } fieldData.setValueFn = func(str *string) error { panicOnNil(str) + if *str == "" { + return nil + } v, err := time.ParseDuration(*str) if err != nil { return err @@ -84,12 +99,18 @@ func configStandardCallbacks(fieldData *paramSetField, val reflect.Value) error case reflect.Bool: fieldData.boolean = true fieldData.validFn = func(str string) error { + if str == "" { + return types.ErrNoValue + } _, err := strconv.ParseBool(str) return err } fieldData.setValueFn = func(str *string) error { panicOnNil(str) + if *str == "" { + return nil + } v, err := strconv.ParseBool(*str) if err != nil { return err @@ -142,6 +163,9 @@ func configStandardCallbacks(fieldData *paramSetField, val reflect.Value) error func configAsInt(fieldData *paramSetField, val reflect.Value, bitSize int) { fieldData.validFn = func(str string) error { + if str == "" { + return types.ErrNoValue + } _, err := strconv.ParseInt(str, 10, bitSize) if err != nil { return badNumberErr(true, bitSize) @@ -152,6 +176,9 @@ func configAsInt(fieldData *paramSetField, val reflect.Value, bitSize int) { fieldData.setValueFn = func(str *string) error { panicOnNil(str) + if *str == "" { + return nil + } v, err := strconv.ParseInt(*str, 10, bitSize) if err != nil { return badNumberErr(true, bitSize) @@ -168,6 +195,9 @@ func configAsInt(fieldData *paramSetField, val reflect.Value, bitSize int) { func configAsUint(fieldData *paramSetField, val reflect.Value, bitSize int) { fieldData.validFn = func(str string) error { + if str == "" { + return types.ErrNoValue + } _, err := strconv.ParseUint(str, 10, bitSize) if err != nil { return badNumberErr(false, bitSize) @@ -178,6 +208,9 @@ func configAsUint(fieldData *paramSetField, val reflect.Value, bitSize int) { fieldData.setValueFn = func(str *string) error { panicOnNil(str) + if *str == "" { + return nil + } v, err := strconv.ParseUint(*str, 10, bitSize) if err != nil { return badNumberErr(false, bitSize) diff --git a/parser_test.go b/parser_test.go index 29e7d8e..3db3aca 100644 --- a/parser_test.go +++ b/parser_test.go @@ -442,3 +442,176 @@ func generateTestKey(t *testing.T) (*rsa.PrivateKey, string) { } return privateKey, string(pem.EncodeToMemory(privateKeyPEM)) } + +func TestOptionalBasicTypes(t *testing.T) { + tests := []struct { + name string + params types.ParamValues + shouldErr bool + expectInt int + expectBool bool + expectDur time.Duration + }{ + { + name: "no value for optional params", + params: types.ParamValues{ + "": { + "req": "1", + }, + }, + shouldErr: false, + expectInt: 42, + expectBool: true, + expectDur: time.Hour, + }, + { + name: "empty string for optional params", + params: types.ParamValues{ + "": { + "i": "", + "b": "", + "d": "", + "req": "2", + }, + }, + shouldErr: false, + expectInt: 42, + expectBool: true, + expectDur: time.Hour, + }, + { + name: "valid values for optional params", + params: types.ParamValues{ + "": { + "i": "123", + "b": "false", + "d": "10s", + "req": "3", + }, + }, + shouldErr: false, + expectInt: 123, + expectBool: false, + expectDur: 10 * time.Second, + }, + { + name: "empty string for required param", + params: types.ParamValues{ + "": { + "req": "", + }, + }, + shouldErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := struct { + I int `param:",optional"` + B bool `param:",optional"` + D time.Duration `param:",optional"` + Req int + }{ + I: 42, + B: true, + D: time.Hour, + Req: 99, + } + + testProvider := cfgtest.New(tt.params) + defer testProvider.Stop() + + _, err := proteus.MustParse(&cfg, proteus.WithProviders(testProvider)) + + if tt.shouldErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectInt, cfg.I) + assert.Equal(t, tt.expectBool, cfg.B) + assert.Equal(t, tt.expectDur, cfg.D) + } + }) + } +} + +func TestOptionalXTypes(t *testing.T) { + tests := []struct { + name string + params types.ParamValues + shouldErr bool + expectInt int + expectBool bool + expectJSON string + }{ + { + name: "no value for optional xtypes", + params: types.ParamValues{"": {"req": "1"}}, + shouldErr: false, + expectInt: 88, + expectBool: true, + expectJSON: `{"a":"b"}`, + }, + { + name: "empty string for optional xtypes", + params: types.ParamValues{ + "": { + "xi": "", + "xb": "", + "xj": "", + "req": "2", + }, + }, + shouldErr: false, + expectInt: 88, + expectBool: true, + expectJSON: `{"a":"b"}`, + }, + { + name: "valid values for optional xtypes", + params: types.ParamValues{ + "": { + "xi": "-5", + "xb": "false", + "xj": `[1,2]`, // raw json + "req": "3", + }, + }, + shouldErr: false, + expectInt: -5, + expectBool: false, + expectJSON: `[1,2]`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := struct { + XI *xtypes.Integer[int] `param:",optional"` + XB *xtypes.Bool `param:",optional"` + XJ *xtypes.RawJSON `param:",optional"` + Req int + }{ + XI: &xtypes.Integer[int]{DefaultValue: 88}, + XB: &xtypes.Bool{DefaultValue: true}, + XJ: &xtypes.RawJSON{DefaultValue: []byte(`{"a":"b"}`)}, + Req: 99, + } + + testProvider := cfgtest.New(tt.params) + defer testProvider.Stop() + + _, err := proteus.MustParse(&cfg, proteus.WithProviders(testProvider)) + + if tt.shouldErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectInt, cfg.XI.Value()) + assert.Equal(t, tt.expectBool, cfg.XB.Value()) + assert.Equal(t, tt.expectJSON, string(cfg.XJ.Value())) + } + }) + } +} diff --git a/xtypes/bool.go b/xtypes/bool.go index 5a885db..9555e8a 100644 --- a/xtypes/bool.go +++ b/xtypes/bool.go @@ -24,7 +24,7 @@ var _ types.XType = &Bool{} // UnmarshalParam parses the input as a boolean. func (d *Bool) UnmarshalParam(in *string) error { var ptrBool *bool - if in != nil { + if in != nil && *in != "" { boolValue, err := strconv.ParseBool(*in) if err != nil { return errors.New("not a valid boolean") @@ -60,6 +60,9 @@ func (d *Bool) Value() bool { // ValueValid test if the provided parameter value is valid. Has no side // effects. func (d *Bool) ValueValid(s string) error { + if s == "" { + return types.ErrNoValue + } _, err := strconv.ParseBool(s) return err } diff --git a/xtypes/integer.go b/xtypes/integer.go index 706ce82..50aee77 100644 --- a/xtypes/integer.go +++ b/xtypes/integer.go @@ -27,7 +27,7 @@ var _ types.XType = &Integer[int]{} // UnmarshalParam parses the input as an integer of type T. func (d *Integer[T]) UnmarshalParam(in *string) error { var ptrT *T - if in != nil { + if in != nil && *in != "" { valT, err := parseInt[T](*in) if err != nil { return errors.New("invalid value for the numeric type") @@ -63,6 +63,9 @@ func (d *Integer[T]) Value() T { // ValueValid test if the provided parameter value is valid. Has no side // effects. func (d *Integer[T]) ValueValid(s string) error { + if s == "" { + return types.ErrNoValue + } _, err := parseInt[T](s) return err } diff --git a/xtypes/rawjson.go b/xtypes/rawjson.go index 87c3f23..8ee7d09 100644 --- a/xtypes/rawjson.go +++ b/xtypes/rawjson.go @@ -22,7 +22,7 @@ var _ types.XType = &RawJSON{} // UnmarshalParam parses the input as a string. func (d *RawJSON) UnmarshalParam(in *string) error { var j json.RawMessage - if in != nil { + if in != nil && *in != "" { err := json.Unmarshal([]byte(*in), &j) if err != nil { return err @@ -63,6 +63,9 @@ func (d *RawJSON) Value() json.RawMessage { // ValueValid test if the provided parameter value is valid. Has no side // effects. func (d *RawJSON) ValueValid(s string) error { + if s == "" { + return types.ErrNoValue + } var j json.RawMessage return json.Unmarshal([]byte(s), &j) } diff --git a/xtypes/url.go b/xtypes/url.go index 58d2532..eeb604c 100644 --- a/xtypes/url.go +++ b/xtypes/url.go @@ -25,7 +25,7 @@ var _ types.Redactor = &URL{} // UnmarshalParam parses the input as a string. func (d *URL) UnmarshalParam(in *string) error { var url *url.URL - if in != nil { + if in != nil && *in != "" { var err error url, err = parseURL(*in, d.ValidateFn) if err != nil { @@ -61,6 +61,9 @@ func (d *URL) Value() *url.URL { // ValueValid test if the provided parameter value is valid. Has no side // effects. func (d *URL) ValueValid(s string) error { + if s == "" { + return types.ErrNoValue + } _, err := parseURL(s, d.ValidateFn) return err } diff --git a/xtypes/url_test.go b/xtypes/url_test.go index 8dd6ead..2740856 100644 --- a/xtypes/url_test.go +++ b/xtypes/url_test.go @@ -82,6 +82,27 @@ func TestEmptyURL(t *testing.T) { "": map[string]string{"url": ""}, }) + _, err := proteus.MustParse(¶ms, proteus.WithProviders(provider)) + assert.ErrorNow(t, err) +} + +func TestEmptyOptionalURL(t *testing.T) { + defURL, err := url.Parse("https://localhost?xxx") + assert.NoErrorNow(t, err) + + params := struct { + URL *xtypes.URL `param:",optional"` + }{ + URL: &xtypes.URL{ + ValidateFn: func(_ *url.URL) error { return nil }, + DefaultValue: defURL, + }, + } + + provider := cfgtest.New(types.ParamValues{ + "": map[string]string{"url": ""}, + }) + parsed, err := proteus.MustParse(¶ms, proteus.WithProviders(provider)) assert.NoErrorNow(t, err) @@ -93,7 +114,7 @@ func TestEmptyURL(t *testing.T) { parsed.Usage(&buffer) t.Log("USAGE INFORMATION\n" + buffer.String()) - assert.Equal(t, "", params.URL.Value().String()) + assert.Equal(t, params.URL.Value(), defURL) } func TestCustomValidator(t *testing.T) {