diff --git a/codec.go b/codec.go index e877713..060b174 100644 --- a/codec.go +++ b/codec.go @@ -3,6 +3,7 @@ package plenc import ( "fmt" "reflect" + "strings" "sync" "github.com/philpearl/plenc/plenccodec" @@ -113,9 +114,20 @@ func (p *Plenc) CodecForTypeRegistry(registry plenccodec.CodecRegistry, typ refl c = plenccodec.PointerWrapper{Underlying: subc} case reflect.Struct: - c, err = plenccodec.BuildStructCodec(p, registry, typ, tag) - if err != nil { - return nil, err + // Is this an Optional? The reflect package doesn't have a great way for + // us to tell yet. So we just hack around with the name and package path. + // Improve this when reflect fully supports generics. + if strings.HasSuffix(typ.PkgPath(), "plenc/plenccodec") && + strings.HasPrefix(typ.Name(), "Optional[") { + c, err = plenccodec.BuildOptionalCodec(p, registry, typ, tag) + if err != nil { + return nil, err + } + } else { + c, err = plenccodec.BuildStructCodec(p, registry, typ, tag) + if err != nil { + return nil, err + } } case reflect.Slice: diff --git a/plenccodec/optional.go b/plenccodec/optional.go new file mode 100644 index 0000000..0029003 --- /dev/null +++ b/plenccodec/optional.go @@ -0,0 +1,111 @@ +package plenccodec + +import ( + "fmt" + "reflect" + "unsafe" + + "github.com/philpearl/plenc/plenccore" +) + +// Optional is a type that can be used to represent an optional value, without +// resorting to pointers to indicate presence. +// +// Optional should be used within structs, not as a top-level type. +type Optional[T any] struct { + Set bool + Value T +} + +// OptionalOf creates an Optional[T] with the given value. +// It is a convenience function to avoid having to create an Optional[T] struct manually. +// This is useful for creating optional values in a more readable way. +// +// Example usage: +// +// opt := OptionalOf(42) // Creates an Optional[int] with Set=true and Value=42 +// opt := OptionalOf("hello") // Creates an Optional[string] with Set=true and Value="hello" +// +// Note: This function is generic and works with any type T. +func OptionalOf[T any](value T) Optional[T] { + return Optional[T]{ + Set: true, + Value: value, + } +} + +// optionalHeader lets us access the Set field of any Optional[T] without +// needing a concrete implementation of the actual type. +type optionalHeader struct { + Set bool +} + +func BuildOptionalCodec(p CodecBuilder, registry CodecRegistry, typ reflect.Type, tag string) (Codec, error) { + valueField := typ.Field(1) + offset := valueField.Offset + underlying, err := p.CodecForTypeRegistry(registry, valueField.Type, tag) + if err != nil { + return nil, fmt.Errorf("building codec for underlying type %s: %w", typ.Name(), err) + } + + return OptionalCodec{ + underlying: underlying, + offset: offset, + typ: typ, + }, nil +} + +// OptionalCodec is a codec for Optional[T] +type OptionalCodec struct { + underlying Codec + offset uintptr + typ reflect.Type +} + +func (p OptionalCodec) Omit(ptr unsafe.Pointer) bool { + t := (*optionalHeader)(ptr) + return !t.Set +} + +func (p OptionalCodec) Read(data []byte, ptr unsafe.Pointer, wt plenccore.WireType) (n int, err error) { + t := (*optionalHeader)(ptr) + // Need offset of the value, which depends in its alignment + n, err = p.underlying.Read(data, unsafe.Add(ptr, p.offset), wt) + if err != nil { + return n, err + } + t.Set = true + return n, nil +} + +func (p OptionalCodec) New() unsafe.Pointer { + return unsafe.Pointer(reflect.New(p.typ).Pointer()) +} + +func (p OptionalCodec) WireType() plenccore.WireType { + return p.underlying.WireType() +} + +func (p OptionalCodec) Descriptor() Descriptor { + d := p.underlying.Descriptor() + d.ExplicitPresence = true + return d +} + +func (p OptionalCodec) Size(ptr unsafe.Pointer, tag []byte) int { + // This should never be called if Omit returns true + t := (*optionalHeader)(ptr) + if !t.Set { + return 0 + } + return p.underlying.Size(unsafe.Add(ptr, p.offset), tag) +} + +func (p OptionalCodec) Append(data []byte, ptr unsafe.Pointer, tag []byte) []byte { + // This should never be called if Omit returns true + t := (*optionalHeader)(ptr) + if !t.Set { + return data + } + return p.underlying.Append(data, unsafe.Add(ptr, p.offset), tag) +} diff --git a/plenccodec/optional_test.go b/plenccodec/optional_test.go new file mode 100644 index 0000000..1a6976d --- /dev/null +++ b/plenccodec/optional_test.go @@ -0,0 +1,108 @@ +package plenccodec_test + +import ( + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/philpearl/plenc" + "github.com/philpearl/plenc/plenccodec" +) + +func TestOptionalRoundTrip(t *testing.T) { + tests := []struct { + name string + in any + expbytes []byte + }{ + { + name: "zero", + in: plenccodec.Optional[int]{Set: true, Value: 0}, + expbytes: []byte{0x00}, + }, + { + name: "1", + in: plenccodec.Optional[int]{Set: true, Value: 1}, + expbytes: []byte{0x02}, + }, + { + name: "empty string", + in: plenccodec.Optional[string]{Set: true, Value: ""}, + }, + { + name: "struct", + in: struct { + A plenccodec.Optional[int] `plenc:"1"` + B plenccodec.Optional[string] `plenc:"2"` + C plenccodec.Optional[float64] `plenc:"3"` + }{ + A: plenccodec.Optional[int]{Set: true, Value: 42}, + B: plenccodec.Optional[string]{Set: true, Value: "hello"}, + }, + expbytes: []byte{0x08, 0x54, 0x12, 0x05, 'h', 'e', 'l', 'l', 'o'}, + }, + { + name: "struct all set", + in: struct { + A plenccodec.Optional[int] `plenc:"1"` + B plenccodec.Optional[string] `plenc:"2"` + C plenccodec.Optional[float64] `plenc:"3"` + }{ + A: plenccodec.Optional[int]{Set: true, Value: 42}, + B: plenccodec.Optional[string]{Set: true, Value: "hello"}, + C: plenccodec.Optional[float64]{Set: true, Value: 3.14}, + }, + expbytes: []byte{0x08, 0x54, + 0x12, 0x05, 'h', 'e', 'l', 'l', 'o', + 0x19, 0x1f, 0x85, 0xeb, 0x51, 0xb8, 0x1e, 0x09, 0x40}, + }, + { + name: "struct zero values", + in: struct { + A plenccodec.Optional[int] `plenc:"1"` + B plenccodec.Optional[string] `plenc:"2"` + C plenccodec.Optional[float64] `plenc:"3"` + }{ + A: plenccodec.Optional[int]{Set: true, Value: 0}, + B: plenccodec.Optional[string]{Set: true, Value: ""}, + C: plenccodec.Optional[float64]{Set: true, Value: 0.0}, + }, + + expbytes: []byte{0x08, 0x00, + 0x12, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + }, + { + name: "struct empty", + in: struct { + A plenccodec.Optional[int] `plenc:"1"` + B plenccodec.Optional[string] `plenc:"2"` + C plenccodec.Optional[float64] `plenc:"3"` + }{}, + + expbytes: []byte{}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + data, err := plenc.Marshal(nil, test.in) + if err != nil { + t.Fatal(err) + } + + if string(test.expbytes) != string(data) { + t.Errorf("Expected bytes %x, got %x", test.expbytes, data) + } + + out := reflect.New(reflect.TypeOf(test.in)) + if err := plenc.Unmarshal(data, out.Interface()); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(test.in, out.Elem().Interface()); diff != "" { + t.Errorf("Round trip failed: %s", diff) + } + }) + } +}