Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ go run *.go -token my-secret-token

The token is marked as a secret, which is important to avoid leaking its value.

#### Empty Values for Optional Parameters

It's important to understand how optional parameters with default values behave
when they receive an empty string (`""`) from a configuration source.

For most types (including numeric types, `bool`, `time.Time`, `time.Duration`,
and most `xtypes`), providing an empty string is treated as an **absent**
value. This means the parameter will correctly use its specified default value,
just as it would if the parameter was omitted entirely.

The `string` and `xtypes.String` types are an exception. For these, an empty
string is considered a valid, intentional value that will override any default.

### XTypes

_XTypes_ are types provided by _proteus_ to handle complex types and to provide
Expand Down
33 changes: 33 additions & 0 deletions basic_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"reflect"
"strconv"
"time"

"github.com/simplesurance/proteus/types"
)

//nolint:gocyclo
func configStandardCallbacks(fieldData *paramSetField, val reflect.Value) error {
// the redact function is to allow redacting part of a value, like
// redacting the "password" part of an URL. For basic types use
Expand All @@ -17,12 +20,18 @@ func configStandardCallbacks(fieldData *paramSetField, val reflect.Value) error
switch valT := val.Interface().(type) {
case time.Time:
fieldData.validFn = func(str string) error {
if str == "" {
return types.ErrNoValue
}
_, err := time.Parse(time.RFC3339Nano, str)
return err
}

fieldData.setValueFn = func(str *string) error {
panicOnNil(str)
if *str == "" {
return nil
}
v, err := time.Parse(time.RFC3339Nano, *str)
if err != nil {
return err
Expand All @@ -39,12 +48,18 @@ func configStandardCallbacks(fieldData *paramSetField, val reflect.Value) error
return nil
case time.Duration:
fieldData.validFn = func(str string) error {
if str == "" {
return types.ErrNoValue
}
_, err := time.ParseDuration(str)
return err
}

fieldData.setValueFn = func(str *string) error {
panicOnNil(str)
if *str == "" {
return nil
}
v, err := time.ParseDuration(*str)
if err != nil {
return err
Expand Down Expand Up @@ -84,12 +99,18 @@ func configStandardCallbacks(fieldData *paramSetField, val reflect.Value) error
case reflect.Bool:
fieldData.boolean = true
fieldData.validFn = func(str string) error {
if str == "" {
return types.ErrNoValue
}
_, err := strconv.ParseBool(str)
return err
}

fieldData.setValueFn = func(str *string) error {
panicOnNil(str)
if *str == "" {
return nil
}
v, err := strconv.ParseBool(*str)
if err != nil {
return err
Expand Down Expand Up @@ -142,6 +163,9 @@ func configStandardCallbacks(fieldData *paramSetField, val reflect.Value) error

func configAsInt(fieldData *paramSetField, val reflect.Value, bitSize int) {
fieldData.validFn = func(str string) error {
if str == "" {
return types.ErrNoValue
}
_, err := strconv.ParseInt(str, 10, bitSize)
if err != nil {
return badNumberErr(true, bitSize)
Expand All @@ -152,6 +176,9 @@ func configAsInt(fieldData *paramSetField, val reflect.Value, bitSize int) {

fieldData.setValueFn = func(str *string) error {
panicOnNil(str)
if *str == "" {
return nil
}
v, err := strconv.ParseInt(*str, 10, bitSize)
if err != nil {
return badNumberErr(true, bitSize)
Expand All @@ -168,6 +195,9 @@ func configAsInt(fieldData *paramSetField, val reflect.Value, bitSize int) {

func configAsUint(fieldData *paramSetField, val reflect.Value, bitSize int) {
fieldData.validFn = func(str string) error {
if str == "" {
return types.ErrNoValue
}
_, err := strconv.ParseUint(str, 10, bitSize)
if err != nil {
return badNumberErr(false, bitSize)
Expand All @@ -178,6 +208,9 @@ func configAsUint(fieldData *paramSetField, val reflect.Value, bitSize int) {

fieldData.setValueFn = func(str *string) error {
panicOnNil(str)
if *str == "" {
return nil
}
v, err := strconv.ParseUint(*str, 10, bitSize)
if err != nil {
return badNumberErr(false, bitSize)
Expand Down
173 changes: 173 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,176 @@ func generateTestKey(t *testing.T) (*rsa.PrivateKey, string) {
}
return privateKey, string(pem.EncodeToMemory(privateKeyPEM))
}

func TestOptionalBasicTypes(t *testing.T) {
tests := []struct {
name string
params types.ParamValues
shouldErr bool
expectInt int
expectBool bool
expectDur time.Duration
}{
{
name: "no value for optional params",
params: types.ParamValues{
"": {
"req": "1",
},
},
shouldErr: false,
expectInt: 42,
expectBool: true,
expectDur: time.Hour,
},
{
name: "empty string for optional params",
params: types.ParamValues{
"": {
"i": "",
"b": "",
"d": "",
"req": "2",
},
},
shouldErr: false,
expectInt: 42,
expectBool: true,
expectDur: time.Hour,
},
{
name: "valid values for optional params",
params: types.ParamValues{
"": {
"i": "123",
"b": "false",
"d": "10s",
"req": "3",
},
},
shouldErr: false,
expectInt: 123,
expectBool: false,
expectDur: 10 * time.Second,
},
{
name: "empty string for required param",
params: types.ParamValues{
"": {
"req": "",
},
},
shouldErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := struct {
I int `param:",optional"`
B bool `param:",optional"`
D time.Duration `param:",optional"`
Req int
}{
I: 42,
B: true,
D: time.Hour,
Req: 99,
}

testProvider := cfgtest.New(tt.params)
defer testProvider.Stop()

_, err := proteus.MustParse(&cfg, proteus.WithProviders(testProvider))

if tt.shouldErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectInt, cfg.I)
assert.Equal(t, tt.expectBool, cfg.B)
assert.Equal(t, tt.expectDur, cfg.D)
}
})
}
}

func TestOptionalXTypes(t *testing.T) {
tests := []struct {
name string
params types.ParamValues
shouldErr bool
expectInt int
expectBool bool
expectJSON string
}{
{
name: "no value for optional xtypes",
params: types.ParamValues{"": {"req": "1"}},
shouldErr: false,
expectInt: 88,
expectBool: true,
expectJSON: `{"a":"b"}`,
},
{
name: "empty string for optional xtypes",
params: types.ParamValues{
"": {
"xi": "",
"xb": "",
"xj": "",
"req": "2",
},
},
shouldErr: false,
expectInt: 88,
expectBool: true,
expectJSON: `{"a":"b"}`,
},
{
name: "valid values for optional xtypes",
params: types.ParamValues{
"": {
"xi": "-5",
"xb": "false",
"xj": `[1,2]`, // raw json
"req": "3",
},
},
shouldErr: false,
expectInt: -5,
expectBool: false,
expectJSON: `[1,2]`,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := struct {
XI *xtypes.Integer[int] `param:",optional"`
XB *xtypes.Bool `param:",optional"`
XJ *xtypes.RawJSON `param:",optional"`
Req int
}{
XI: &xtypes.Integer[int]{DefaultValue: 88},
XB: &xtypes.Bool{DefaultValue: true},
XJ: &xtypes.RawJSON{DefaultValue: []byte(`{"a":"b"}`)},
Req: 99,
}

testProvider := cfgtest.New(tt.params)
defer testProvider.Stop()

_, err := proteus.MustParse(&cfg, proteus.WithProviders(testProvider))

if tt.shouldErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectInt, cfg.XI.Value())
assert.Equal(t, tt.expectBool, cfg.XB.Value())
assert.Equal(t, tt.expectJSON, string(cfg.XJ.Value()))
}
})
}
}
5 changes: 4 additions & 1 deletion xtypes/bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ var _ types.XType = &Bool{}
// UnmarshalParam parses the input as a boolean.
func (d *Bool) UnmarshalParam(in *string) error {
var ptrBool *bool
if in != nil {
if in != nil && *in != "" {
boolValue, err := strconv.ParseBool(*in)
if err != nil {
return errors.New("not a valid boolean")
Expand Down Expand Up @@ -60,6 +60,9 @@ func (d *Bool) Value() bool {
// ValueValid test if the provided parameter value is valid. Has no side
// effects.
func (d *Bool) ValueValid(s string) error {
if s == "" {
return types.ErrNoValue
}
_, err := strconv.ParseBool(s)
return err
}
Expand Down
5 changes: 4 additions & 1 deletion xtypes/integer.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ var _ types.XType = &Integer[int]{}
// UnmarshalParam parses the input as an integer of type T.
func (d *Integer[T]) UnmarshalParam(in *string) error {
var ptrT *T
if in != nil {
if in != nil && *in != "" {
valT, err := parseInt[T](*in)
if err != nil {
return errors.New("invalid value for the numeric type")
Expand Down Expand Up @@ -63,6 +63,9 @@ func (d *Integer[T]) Value() T {
// ValueValid test if the provided parameter value is valid. Has no side
// effects.
func (d *Integer[T]) ValueValid(s string) error {
if s == "" {
return types.ErrNoValue
}
_, err := parseInt[T](s)
return err
}
Expand Down
5 changes: 4 additions & 1 deletion xtypes/rawjson.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ var _ types.XType = &RawJSON{}
// UnmarshalParam parses the input as a string.
func (d *RawJSON) UnmarshalParam(in *string) error {
var j json.RawMessage
if in != nil {
if in != nil && *in != "" {
err := json.Unmarshal([]byte(*in), &j)
if err != nil {
return err
Expand Down Expand Up @@ -63,6 +63,9 @@ func (d *RawJSON) Value() json.RawMessage {
// ValueValid test if the provided parameter value is valid. Has no side
// effects.
func (d *RawJSON) ValueValid(s string) error {
if s == "" {
return types.ErrNoValue
}
var j json.RawMessage
return json.Unmarshal([]byte(s), &j)
}
Expand Down
5 changes: 4 additions & 1 deletion xtypes/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ var _ types.Redactor = &URL{}
// UnmarshalParam parses the input as a string.
func (d *URL) UnmarshalParam(in *string) error {
var url *url.URL
if in != nil {
if in != nil && *in != "" {
var err error
url, err = parseURL(*in, d.ValidateFn)
if err != nil {
Expand Down Expand Up @@ -61,6 +61,9 @@ func (d *URL) Value() *url.URL {
// ValueValid test if the provided parameter value is valid. Has no side
// effects.
func (d *URL) ValueValid(s string) error {
if s == "" {
return types.ErrNoValue
}
_, err := parseURL(s, d.ValidateFn)
return err
}
Expand Down
Loading