From 07e4d07d17fcf91ba5b0a2637f9b4f179d960c7c Mon Sep 17 00:00:00 2001 From: Tomas Aschan <1550920+tomasaschan@users.noreply.github.com> Date: Thu, 17 Jul 2025 11:54:51 +0200 Subject: [PATCH] Make Value a drop-in replacement for flag.Value --- flag.go | 28 +++++++++++++++++++++++----- text.go | 4 ++-- time.go | 4 ++-- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/flag.go b/flag.go index 107fa190..b4b57cfa 100644 --- a/flag.go +++ b/flag.go @@ -200,6 +200,15 @@ type Flag struct { type Value interface { String() string Set(string) error +} + +// TypedValue wraps Value but adds an additional Type() string, which +// can be used to special-case things like usage instructions, error +// messages, parsing, etc. +// Its value should be a constant string that identifies the type of +// the underlying flag value. +type TypedValue interface { + Value Type() string } @@ -398,8 +407,8 @@ func (f *FlagSet) getFlagType(name string, ftype string, convFunc func(sval stri return nil, err } - if flag.Value.Type() != ftype { - err := fmt.Errorf("trying to get %s value of flag of type %s", ftype, flag.Value.Type()) + if tvalue, ok := flag.Value.(TypedValue); ok && tvalue.Type() != ftype { + err := fmt.Errorf("trying to get %s value of flag of type %s", ftype, tvalue.Type()) return nil, err } @@ -597,7 +606,10 @@ func UnquoteUsage(flag *Flag) (name string, usage string) { } } - name = flag.Value.Type() + name = "value" + if tvalue, ok := flag.Value.(TypedValue); ok { + name = tvalue.Type() + } switch name { case "bool", "boolfunc": name = "" @@ -716,8 +728,14 @@ func (f *FlagSet) FlagUsagesWrapped(cols int) string { if varname != "" { line += " " + varname } + tvalue, tvalueOk := flag.Value.(TypedValue) if flag.NoOptDefVal != "" { - switch flag.Value.Type() { + if !tvalueOk { + line += fmt.Sprintf(" [=%s]", flag.NoOptDefVal) + return + } + + switch tvalue.Type() { case "string": line += fmt.Sprintf("[=\"%s\"]", flag.NoOptDefVal) case "bool", "boolfunc": @@ -742,7 +760,7 @@ func (f *FlagSet) FlagUsagesWrapped(cols int) string { line += usage if !flag.defaultIsZeroValue() { - if flag.Value.Type() == "string" { + if tvalueOk && tvalue.Type() == "string" { line += fmt.Sprintf(" (default %q)", flag.DefValue) } else { line += fmt.Sprintf(" (default %s)", flag.DefValue) diff --git a/text.go b/text.go index 886d5a3d..8319d0b2 100644 --- a/text.go +++ b/text.go @@ -54,8 +54,8 @@ func (f *FlagSet) GetText(name string, out encoding.TextUnmarshaler) error { if flag == nil { return fmt.Errorf("flag accessed but not defined: %s", name) } - if flag.Value.Type() != reflect.TypeOf(out).Name() { - return fmt.Errorf("trying to get %s value of flag of type %s", reflect.TypeOf(out).Name(), flag.Value.Type()) + if tvalue, ok := flag.Value.(TypedValue); ok && tvalue.Type() != reflect.TypeOf(out).Name() { + return fmt.Errorf("trying to get %s value of flag of type %s", reflect.TypeOf(out).Name(), tvalue.Type()) } return out.UnmarshalText([]byte(flag.Value.String())) } diff --git a/time.go b/time.go index dc024807..d0d44d80 100644 --- a/time.go +++ b/time.go @@ -58,8 +58,8 @@ func (f *FlagSet) GetTime(name string) (time.Time, error) { return time.Time{}, err } - if flag.Value.Type() != "time" { - err := fmt.Errorf("trying to get %s value of flag of type %s", "time", flag.Value.Type()) + if tvalue, ok := flag.Value.(TypedValue); ok && tvalue.Type() != "time" { + err := fmt.Errorf("trying to get %s value of flag of type %s", "time", tvalue.Type()) return time.Time{}, err }