diff --git a/Makefile b/Makefile index 36085559..4c9f6c0a 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ export PROTOBUF_ROOT=$(PWD)/_vendor/protobuf-21.12 install: bin/protoc-gen-go-vtproto bin/protoc-gen-go bin/protoc-gen-go-vtproto: - go install -tags protolegacy ./cmd/protoc-gen-go-vtproto + go install -buildvcs=false -tags protolegacy ./cmd/protoc-gen-go-vtproto bin/protoc-gen-go: go install -tags protolegacy google.golang.org/protobuf/cmd/protoc-gen-go diff --git a/features/pool/pool.go b/features/pool/pool.go index 004386b0..4f7efdca 100644 --- a/features/pool/pool.go +++ b/features/pool/pool.go @@ -63,6 +63,8 @@ func (p *pool) message(message *protogen.Message) { p.P(`mm.Reset()`) } p.P(`}`) + case protoreflect.BytesKind, protoreflect.StringKind: + p.P(`clear(m.`, fieldName, `)`) } p.P(fmt.Sprintf("f%d", len(saved)), ` := m.`, fieldName, `[:0]`) saved = append(saved, field) diff --git a/go.mod b/go.mod index 91c37af2..dec90527 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/planetscale/vtprotobuf -go 1.20 +go 1.21 require ( github.com/stretchr/testify v1.8.4 diff --git a/go.sum b/go.sum index 378518bc..1bf44503 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,7 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= diff --git a/testproto/pool/pool_with_oneof_vtproto.pb.go b/testproto/pool/pool_with_oneof_vtproto.pb.go index 1c75bbb4..7fd57350 100644 --- a/testproto/pool/pool_with_oneof_vtproto.pb.go +++ b/testproto/pool/pool_with_oneof_vtproto.pb.go @@ -973,6 +973,7 @@ var vtprotoPool_OneofTest_Test2 = sync.Pool{ func (m *OneofTest_Test2) ResetVT() { if m != nil { + clear(m.B) f0 := m.B[:0] m.Reset() m.B = f0 diff --git a/testproto/pool/pool_with_slice_reuse.pb.go b/testproto/pool/pool_with_slice_reuse.pb.go index aca74380..6489a1f2 100644 --- a/testproto/pool/pool_with_slice_reuse.pb.go +++ b/testproto/pool/pool_with_slice_reuse.pb.go @@ -249,6 +249,53 @@ func (x *Element2) GetA() int32 { return 0 } +type Test3 struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Sl [][]byte `protobuf:"bytes,1,rep,name=Sl,proto3" json:"Sl,omitempty"` +} + +func (x *Test3) Reset() { + *x = Test3{} + if protoimpl.UnsafeEnabled { + mi := &file_pool_pool_with_slice_reuse_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Test3) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Test3) ProtoMessage() {} + +func (x *Test3) ProtoReflect() protoreflect.Message { + mi := &file_pool_pool_with_slice_reuse_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Test3.ProtoReflect.Descriptor instead. +func (*Test3) Descriptor() ([]byte, []int) { + return file_pool_pool_with_slice_reuse_proto_rawDescGZIP(), []int{4} +} + +func (x *Test3) GetSl() [][]byte { + if x != nil { + return x.Sl + } + return nil +} + var File_pool_pool_with_slice_reuse_proto protoreflect.FileDescriptor var file_pool_pool_with_slice_reuse_proto_rawDesc = []byte{ @@ -275,8 +322,10 @@ var file_pool_pool_with_slice_reuse_proto_rawDesc = []byte{ 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x04, 0x0a, 0x02, 0x5f, 0x62, 0x22, 0x18, 0x0a, 0x08, 0x45, 0x6c, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x32, 0x12, 0x0c, 0x0a, 0x01, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, - 0x52, 0x01, 0x61, 0x42, 0x10, 0x5a, 0x0e, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x2f, 0x70, 0x6f, 0x6f, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x52, 0x01, 0x61, 0x22, 0x1d, 0x0a, 0x05, 0x54, 0x65, 0x73, 0x74, 0x33, 0x12, 0x0e, 0x0a, 0x02, + 0x53, 0x6c, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x02, 0x53, 0x6c, 0x3a, 0x04, 0xa8, 0xa6, + 0x1f, 0x01, 0x42, 0x10, 0x5a, 0x0e, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, + 0x70, 0x6f, 0x6f, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -291,17 +340,18 @@ func file_pool_pool_with_slice_reuse_proto_rawDescGZIP() []byte { return file_pool_pool_with_slice_reuse_proto_rawDescData } -var file_pool_pool_with_slice_reuse_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_pool_pool_with_slice_reuse_proto_msgTypes = make([]protoimpl.MessageInfo, 6) var file_pool_pool_with_slice_reuse_proto_goTypes = []interface{}{ (*Test1)(nil), // 0: Test1 (*Test2)(nil), // 1: Test2 (*Slice2)(nil), // 2: Slice2 (*Element2)(nil), // 3: Element2 - nil, // 4: Slice2.AEntry + (*Test3)(nil), // 4: Test3 + nil, // 5: Slice2.AEntry } var file_pool_pool_with_slice_reuse_proto_depIdxs = []int32{ 2, // 0: Test2.Sl:type_name -> Slice2 - 4, // 1: Slice2.a:type_name -> Slice2.AEntry + 5, // 1: Slice2.a:type_name -> Slice2.AEntry 3, // 2: Slice2.d:type_name -> Element2 3, // [3:3] is the sub-list for method output_type 3, // [3:3] is the sub-list for method input_type @@ -364,6 +414,18 @@ func file_pool_pool_with_slice_reuse_proto_init() { return nil } } + file_pool_pool_with_slice_reuse_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Test3); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_pool_pool_with_slice_reuse_proto_msgTypes[2].OneofWrappers = []interface{}{} type x struct{} @@ -372,7 +434,7 @@ func file_pool_pool_with_slice_reuse_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_pool_pool_with_slice_reuse_proto_rawDesc, NumEnums: 0, - NumMessages: 5, + NumMessages: 6, NumExtensions: 0, NumServices: 0, }, diff --git a/testproto/pool/pool_with_slice_reuse.proto b/testproto/pool/pool_with_slice_reuse.proto index 6c0d544d..36e476d2 100644 --- a/testproto/pool/pool_with_slice_reuse.proto +++ b/testproto/pool/pool_with_slice_reuse.proto @@ -22,4 +22,9 @@ message Slice2 { } message Element2 { int32 a = 1; +} + +message Test3 { + option (vtproto.mempool) = true; + repeated bytes Sl = 1; } \ No newline at end of file diff --git a/testproto/pool/pool_with_slice_reuse_vtproto.pb.go b/testproto/pool/pool_with_slice_reuse_vtproto.pb.go index 530cb2e9..8d9585c9 100644 --- a/testproto/pool/pool_with_slice_reuse_vtproto.pb.go +++ b/testproto/pool/pool_with_slice_reuse_vtproto.pb.go @@ -117,6 +117,31 @@ func (m *Element2) CloneMessageVT() proto.Message { return m.CloneVT() } +func (m *Test3) CloneVT() *Test3 { + if m == nil { + return (*Test3)(nil) + } + r := Test3FromVTPool() + if rhs := m.Sl; rhs != nil { + tmpContainer := make([][]byte, len(rhs)) + for k, v := range rhs { + tmpBytes := make([]byte, len(v)) + copy(tmpBytes, v) + tmpContainer[k] = tmpBytes + } + r.Sl = tmpContainer + } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } + return r +} + +func (m *Test3) CloneMessageVT() proto.Message { + return m.CloneVT() +} + func (this *Test1) EqualVT(that *Test1) bool { if this == that { return true @@ -243,6 +268,31 @@ func (this *Element2) EqualMessageVT(thatMsg proto.Message) bool { } return this.EqualVT(that) } +func (this *Test3) EqualVT(that *Test3) bool { + if this == that { + return true + } else if this == nil || that == nil { + return false + } + if len(this.Sl) != len(that.Sl) { + return false + } + for i, vx := range this.Sl { + vy := that.Sl[i] + if string(vx) != string(vy) { + return false + } + } + return string(this.unknownFields) == string(that.unknownFields) +} + +func (this *Test3) EqualMessageVT(thatMsg proto.Message) bool { + that, ok := thatMsg.(*Test3) + if !ok { + return false + } + return this.EqualVT(that) +} func (m *Test1) MarshalVT() (dAtA []byte, err error) { if m == nil { return nil, nil @@ -452,6 +502,48 @@ func (m *Element2) MarshalToSizedBufferVT(dAtA []byte) (int, error) { return len(dAtA) - i, nil } +func (m *Test3) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Test3) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *Test3) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if len(m.Sl) > 0 { + for iNdEx := len(m.Sl) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.Sl[iNdEx]) + copy(dAtA[i:], m.Sl[iNdEx]) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Sl[iNdEx]))) + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + func (m *Test1) MarshalVTStrict() (dAtA []byte, err error) { if m == nil { return nil, nil @@ -661,6 +753,48 @@ func (m *Element2) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error) { return len(dAtA) - i, nil } +func (m *Test3) MarshalVTStrict() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVTStrict(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Test3) MarshalToVTStrict(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVTStrict(dAtA[:size]) +} + +func (m *Test3) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if len(m.Sl) > 0 { + for iNdEx := len(m.Sl) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.Sl[iNdEx]) + copy(dAtA[i:], m.Sl[iNdEx]) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Sl[iNdEx]))) + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + var vtprotoPool_Test1 = sync.Pool{ New: func() interface{} { return &Test1{} @@ -669,6 +803,7 @@ var vtprotoPool_Test1 = sync.Pool{ func (m *Test1) ResetVT() { if m != nil { + clear(m.Sl) f0 := m.Sl[:0] m.Reset() m.Sl = f0 @@ -709,6 +844,30 @@ func (m *Test2) ReturnToVTPool() { func Test2FromVTPool() *Test2 { return vtprotoPool_Test2.Get().(*Test2) } + +var vtprotoPool_Test3 = sync.Pool{ + New: func() interface{} { + return &Test3{} + }, +} + +func (m *Test3) ResetVT() { + if m != nil { + clear(m.Sl) + f0 := m.Sl[:0] + m.Reset() + m.Sl = f0 + } +} +func (m *Test3) ReturnToVTPool() { + if m != nil { + m.ResetVT() + vtprotoPool_Test3.Put(m) + } +} +func Test3FromVTPool() *Test3 { + return vtprotoPool_Test3.Get().(*Test3) +} func (m *Test1) SizeVT() (n int) { if m == nil { return 0 @@ -792,6 +951,22 @@ func (m *Element2) SizeVT() (n int) { return n } +func (m *Test3) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.Sl) > 0 { + for _, b := range m.Sl { + l = len(b) + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + } + n += len(m.unknownFields) + return n +} + func (m *Test1) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 @@ -1326,6 +1501,89 @@ func (m *Element2) UnmarshalVT(dAtA []byte) error { } return nil } +func (m *Test3) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Test3: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Test3: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Sl", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Sl = append(m.Sl, make([]byte, postIndex-iNdEx)) + copy(m.Sl[len(m.Sl)-1], dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} func (m *Test1) UnmarshalVTUnsafe(dAtA []byte) error { l := len(dAtA) iNdEx := 0 @@ -1872,3 +2130,85 @@ func (m *Element2) UnmarshalVTUnsafe(dAtA []byte) error { } return nil } +func (m *Test3) UnmarshalVTUnsafe(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Test3: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Test3: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Sl", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Sl = append(m.Sl, dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +}