diff --git a/binary/borsh_test.go b/binary/borsh_test.go new file mode 100644 index 000000000..85223b0c5 --- /dev/null +++ b/binary/borsh_test.go @@ -0,0 +1,1764 @@ +// Copyright 2021 github.com/gagliardetto +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "bytes" + "fmt" + "io" + "math" + "reflect" + strings2 "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type OptionalPointerFields struct { + Good uint8 + Arr *Arr `bin:"optional"` +} + +type Arr []string + +func TestOptionWithPointer(t *testing.T) { + // nil (optional not present) + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := OptionalPointerFields{ + Good: 9, + // Will be decoded as nil pointer. + Arr: nil, + } + require.NoError(t, enc.Encode(val)) + require.Equal(t, + concatByteSlices( + []byte{9}, + []byte{0}, + ), + buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got OptionalPointerFields + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + // optional is present but has zero elements + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := OptionalPointerFields{ + Good: 9, + // Will be decoded as pointer to nil Arr. + Arr: &Arr{}, + } + require.NoError(t, enc.Encode(val)) + require.Equal(t, + concatByteSlices( + []byte{9}, + []byte{1}, + []byte{0, 0, 0, 0}, + ), + buf.Bytes(), + ) + { + dec := NewBorshDecoder(buf.Bytes()) + var got OptionalPointerFields + require.NoError(t, dec.Decode(&got)) + // an empty slice is decoded as nil. + po := (Arr)(nil) + val.Arr = &po + require.Equal(t, + val, got) + } + } + // optional is present and has elements + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := OptionalPointerFields{ + Good: 9, + Arr: &Arr{"foo"}, + } + require.NoError(t, enc.Encode(val)) + require.Equal(t, + concatByteSlices( + []byte{9}, + []byte{1}, + []byte{1, 0, 0, 0}, + + []byte{3, 0, 0, 0}, + []byte("foo"), + ), + buf.Bytes(), + ) + { + dec := NewBorshDecoder(buf.Bytes()) + var got OptionalPointerFields + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } +} + +type StructWithComplexPeculiarEnums struct { + Complex2NotSet ComplexEnumPointers + Complex2PtrNotSet *ComplexEnumPointers + + // Complex2PtrOptionalSet *ComplexEnumPointers `bin:"optional"` +} + +func TestBorsh_peculiarEnums(t *testing.T) { + t.Skip() + { + // struct with peculiar complex enums: + { + // buf := new(bytes.Buffer) + buf := NewWriteByWrite("") + enc := NewBorshEncoder(buf) + val := StructWithComplexPeculiarEnums{ + // If the enums are left empty, they won't serialize correctly. + } + require.NoError(t, enc.Encode(val)) + fmt.Println(buf.String()) + require.Equal(t, + concatByteSlices( + []byte{0}, + []byte{0, 0, 0, 0}, + []byte{0, 0, 0, 0}, + + []byte{0}, + []byte{0, 0, 0, 0}, + []byte{0, 0, 0, 0}, + ), + buf.Bytes(), + ) + + { + dec := NewBorshDecoder(buf.Bytes()) + var got StructWithComplexPeculiarEnums + require.NoError(t, dec.Decode(&got)) + { + } + require.Equal(t, val, got) + } + } + } +} + +func TestBorsh_Encode(t *testing.T) { + // ints: + { + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := int8(33) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{33}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got int8 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := int16(44) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{44, 0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got int16 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := int32(55) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{55, 0, 0, 0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got int32 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := int64(556) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0x2c, 0x2, 0, 0, 0, 0, 0, 0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got int64 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + // pointers to a basic type shall be encoded as values. + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := int64(556) + require.NoError(t, enc.Encode(&val)) + require.Equal(t, []byte{0x2c, 0x2, 0, 0, 0, 0, 0, 0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got int64 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := int8(120) + require.NoError(t, enc.Encode(&val)) + require.Equal(t, []byte{120}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got int8 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + } + { + // pointer to a nil value of a basic type shall be encoded as the zero value of that type: + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := new(int64) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0, 0, 0, 0, 0, 0, 0, 0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got int64 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, *val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := new(int8) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got int8 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, *val, got) + } + } + } + } + // uints: + { + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := uint8(33) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{33}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got uint8 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := uint16(44) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{44, 0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got uint16 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := uint32(55) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{55, 0, 0, 0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got uint32 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := uint64(556) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0x2c, 0x2, 0, 0, 0, 0, 0, 0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got uint64 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + // pouinters to a basic type shall be encoded as values. + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := uint64(556) + require.NoError(t, enc.Encode(&val)) + require.Equal(t, []byte{0x2c, 0x2, 0, 0, 0, 0, 0, 0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got uint64 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := uint8(120) + require.NoError(t, enc.Encode(&val)) + require.Equal(t, []byte{120}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got uint8 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + } + { + // pointer to a nil value of a basic type shall be encoded as the zero value of that type: + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := new(uint64) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0, 0, 0, 0, 0, 0, 0, 0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got uint64 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, *val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := new(uint8) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got uint8 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, *val, got) + } + } + } + } + { + // bool + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + require.NoError(t, enc.Encode(true)) + require.Equal(t, []byte{1}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got bool + require.NoError(t, dec.Decode(&got)) + require.Equal(t, true, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + require.NoError(t, enc.Encode(false)) + require.Equal(t, []byte{0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got bool + require.NoError(t, dec.Decode(&got)) + require.Equal(t, false, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := false + require.NoError(t, enc.Encode(&val)) + require.Equal(t, []byte{0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got bool + require.NoError(t, dec.Decode(&got)) + require.Equal(t, false, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := true + require.NoError(t, enc.Encode(&val)) + require.Equal(t, []byte{1}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got bool + require.NoError(t, dec.Decode(&got)) + require.Equal(t, true, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := new(bool) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got bool + require.NoError(t, dec.Decode(&got)) + require.Equal(t, false, got) + } + } + } + { + // floats + { + // float32 + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := float32(1.123) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0x77, 0xbe, 0x8f, 0x3f}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got float32 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := float32(1.123) + require.NoError(t, enc.Encode(&val)) + require.Equal(t, []byte{0x77, 0xbe, 0x8f, 0x3f}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got float32 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := new(float32) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0, 0, 0, 0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got float32 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, *val, got) + } + } + } + { + // float64 + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := float64(1.123) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0x2b, 0x87, 0x16, 0xd9, 0xce, 0xf7, 0xf1, 0x3f}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got float64 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := float64(1.123) + require.NoError(t, enc.Encode(&val)) + require.Equal(t, []byte{0x2b, 0x87, 0x16, 0xd9, 0xce, 0xf7, 0xf1, 0x3f}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got float64 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := new(float64) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0, 0, 0, 0, 0, 0, 0, 0}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got float64 + require.NoError(t, dec.Decode(&got)) + require.Equal(t, *val, got) + } + } + } + } + { + // string + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := string("hello world") + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0xb, 0x0, 0x0, 0x0, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64}, buf.Bytes()) + require.Equal(t, append([]byte{byte(len(val)), 0, 0, 0}, []byte(val)...), buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got string + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := string("hello world") + require.NoError(t, enc.Encode(&val)) + require.Equal(t, []byte{0xb, 0x0, 0x0, 0x0, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64}, buf.Bytes()) + require.Equal(t, append([]byte{byte(len(val)), 0, 0, 0}, []byte(val)...), buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got string + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := new(string) + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{0x0, 0x0, 0x0, 0x0}, buf.Bytes()) + require.Equal(t, append([]byte{0, 0, 0, 0}, []byte{}...), buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got string + require.NoError(t, dec.Decode(&got)) + require.Equal(t, *val, got) + } + } + } + { + // interface + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + var val io.Reader + require.NoError(t, enc.Encode(val)) + require.Equal(t, ([]byte)(nil), buf.Bytes()) + } + { + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + var val io.Reader + require.NoError(t, enc.Encode(&val)) + require.Equal(t, ([]byte)(nil), buf.Bytes()) + } + } + { + // type that has `func (e CustomEncoding) MarshalWithEncoder(encoder *Encoder) error` method. + // NOTE: the `MarshalWithEncoder` method MUST be on value (NOT on pointer). + { + // by value: + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := CustomEncoding{ + Prefix: byte('a'), + Value: 33, + } + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{33, 0, 0, 0, byte('a')}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got CustomEncoding + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + // by pointer: + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := &CustomEncoding{ + Prefix: byte('a'), + Value: 33, + } + require.NoError(t, enc.Encode(val)) + require.Equal(t, []byte{33, 0, 0, 0, byte('a')}, buf.Bytes()) + { + dec := NewBorshDecoder(buf.Bytes()) + var got CustomEncoding + require.NoError(t, dec.Decode(&got)) + require.Equal(t, *val, got) + } + } + } + { + // struct + { + // simple + { + // by value: + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := Struct{ + Foo: "hello", + Bar: 33, + } + require.NoError(t, enc.Encode(val)) + require.Equal(t, + concatByteSlices( + []byte{byte(len(val.Foo)), 0, 0, 0}, + []byte(val.Foo), + []byte{33, 0, 0, 0}, + ), + buf.Bytes(), + ) + { + dec := NewBorshDecoder(buf.Bytes()) + var got Struct + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + // by pointer: + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := &Struct{ + Foo: "hello", + Bar: 33, + } + require.NoError(t, enc.Encode(val)) + require.Equal(t, + concatByteSlices( + []byte{byte(len(val.Foo)), 0, 0, 0}, + []byte(val.Foo), + []byte{33, 0, 0, 0}, + ), + buf.Bytes(), + ) + { + dec := NewBorshDecoder(buf.Bytes()) + var got Struct + require.NoError(t, dec.Decode(&got)) + require.Equal(t, *val, got) + } + } + } + { + // with fields that are pointers + hello := "hello" + bar := uint32(33) + { + // by value: + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := StructWithPointerFields{ + Foo: &hello, + Bar: &bar, + } + require.NoError(t, enc.Encode(val)) + require.Equal(t, + concatByteSlices( + []byte{byte(len(*val.Foo)), 0, 0, 0}, + []byte(*val.Foo), + []byte{33, 0, 0, 0}, + ), + buf.Bytes(), + ) + { + dec := NewBorshDecoder(buf.Bytes()) + var got StructWithPointerFields + require.NoError(t, dec.Decode(&got)) + require.Equal(t, val, got) + } + } + { + // by pointer: + buf := new(bytes.Buffer) + enc := NewBorshEncoder(buf) + val := &StructWithPointerFields{ + Foo: &hello, + Bar: &bar, + } + require.NoError(t, enc.Encode(val)) + require.Equal(t, + concatByteSlices( + []byte{byte(len(*val.Foo)), 0, 0, 0}, + []byte(*val.Foo), + []byte{33, 0, 0, 0}, + ), + buf.Bytes(), + ) + { + dec := NewBorshDecoder(buf.Bytes()) + var got StructWithPointerFields + require.NoError(t, dec.Decode(&got)) + require.Equal(t, *val, got) + } + } + } + { + // with optional fields + { + // buf := new(bytes.Buffer) + fooRequired := "hello" + fooOptional := "-world" + barPointer := uint32(33) + buf := NewWriteByWrite("") + enc := NewBorshEncoder(buf) + val := StructWithOptionalFields{ + FooRequired: &fooRequired, + FooPointer: &fooOptional, + BarPointer: &barPointer, + FooValue: "hi", + } + require.NoError(t, enc.Encode(val)) + // fmt.Println(buf.String()) + require.Equal(t, + concatByteSlices( + // .FooRequired + []byte{5, 0, 0, 0}, + []byte(*val.FooRequired), + + // .BarRequiredNotSet + []byte{0, 0, 0, 0}, + + // .FooPointer (optional) + []byte{1}, + []byte{6, 0, 0, 0}, + []byte(*val.FooPointer), + + // .FooPointerNotSet (optional) + []byte{0}, + + // .BarPointer (optional) + []byte{1}, + []byte{33, 0, 0, 0}, + + // .FooValue (optional) + []byte{1}, + []byte{2, 0, 0, 0}, + []byte(val.FooValue), + + // .BarValueNotSet (optional) + []byte{0}, + + // .Hello + []byte{0, 0, 0, 0}, + ), + buf.Bytes(), + ) + // - 0: [5, 0, 0, 0](len=4) + // - 1: [104, 101, 108, 108, 111](len=5) + // - 2: [0, 0, 0, 0](len=4) + // - 3: [1](len=1) + // - 4: [6, 0, 0, 0](len=4) + // - 5: [45, 119, 111, 114, 108, 100](len=6) + // - 6: [0](len=1) + // - 7: [1](len=1) + // - 8: [33, 0, 0, 0](len=4) + // - 9: [1](len=1) + // - 10: [2, 0, 0, 0](len=4) + // - 11: [104, 105](len=2) + // - 12: [0](len=1) + // - 13: [0, 0, 0, 0](len=4) + // - 14: [](len=0) + { + dec := NewBorshDecoder(buf.Bytes()) + var got StructWithOptionalFields + require.NoError(t, dec.Decode(&got)) + { + // .BarRequiredNotSet is NOT an optiona field, + // which means that it was encoded as zero, + // and will be decoded as zero. + zero := uint32(0) + val.BarRequiredNotSet = &zero + } + require.Equal(t, val, got) + } + } + } + { + // struct with enums + { + buf := NewWriteByWrite("") + enc := NewBorshEncoder(buf) + simple := z + val := StructWithEnum{ + Simple: y, + SimplePointer: &simple, + + Complex: ComplexEnum{ + Enum: 1, + Bar: Bar{ + BarA: 99, + BarB: "this is bar", + }, + }, + ComplexPtr: &ComplexEnum{ + Enum: 1, + Bar: Bar{ + BarA: 22, + BarB: "this is bar from pointer", + }, + }, + ComplexEmpty: ComplexEnumEmpty{ + Enum: 0, + Foo: EmptyVariant{}, + }, + + ComplexPrimitives1: ComplexEnumPrimitives{ + Enum: 0, + Foo: 20, + }, + + ComplexPrimitives2: ComplexEnumPrimitives{ + Enum: 1, + Bar: 11, + }, + + Complex2: ComplexEnumPointers{ + Enum: 1, + Bar: &Bar{ + BarA: 62, + BarB: "very tested!!!", + }, + }, + + Complex2Ptr: &ComplexEnumPointers{ + Enum: 1, + Bar: &Bar{ + BarA: 123, + BarB: "lorem ipsum", + }, + }, + + Complex2PtrOptionalSet: &ComplexEnumPointers{ + Enum: 1, + Bar: &Bar{ + BarA: 32, + BarB: "very complex", + }, + }, + + Map: map[string]uint64{ + "foo": 1, + "bar": 46, + }, + + Slice: []Struct{ + { + Foo: "this is first foo", + Bar: 97, + }, + { + Foo: "this is second foo", + Bar: 98, + }, + }, + + Array: [4]Struct{ + { + Foo: "arr 0", + Bar: 22, + }, + { + Foo: "arr 1", + Bar: 23, + }, + { + Foo: "arr 2", + Bar: 24, + }, + { + Foo: "arr 3", + Bar: 25, + }, + }, + } + require.NoError(t, enc.Encode(val)) + // fmt.Println(buf.String()) + require.Equal(t, + concatByteSlices( + // .Simple + []byte{1}, + + // .SimplePointer + []byte{2}, + + // .Complex + []byte{1}, + []byte{99, 0, 0, 0, 0, 0, 0, 0}, + []byte{11, 0, 0, 0}, + []byte(val.Complex.Bar.BarB), + + // .ComplexNotSet + []byte{0}, + []byte{0, 0, 0, 0}, + []byte{0, 0, 0, 0}, + + // .ComplexPtr + []byte{1}, + []byte{22, 0, 0, 0, 0, 0, 0, 0}, + []byte{24, 0, 0, 0}, + []byte(val.ComplexPtr.Bar.BarB), + + // .ComplexPtrNotSet is not set, leaving the index to zero + // which corresponds to ComplexPtrNotSet.Foo + []byte{0}, + []byte{0, 0, 0, 0}, + []byte{0, 0, 0, 0}, + + // .ComplexEmpty + []byte{0}, + + // .ComplexPrimitives1 + []byte{0}, + []byte{20, 0, 0, 0}, + + // .ComplexPrimitives2 + []byte{1}, + []byte{11, 0}, + + // .Complex2 + []byte{1}, + []byte{62, 0, 0, 0, 0, 0, 0, 0}, + []byte{14, 0, 0, 0}, + []byte(val.Complex2.Bar.BarB), + + // .Complex2Ptr + []byte{1}, + []byte{123, 0, 0, 0, 0, 0, 0, 0}, + []byte{11, 0, 0, 0}, // = len(.Complex2Ptr.Bar.BarB) + []byte(val.Complex2Ptr.Bar.BarB), + + // .Complex2PtrOptionalSet + []byte{1}, // TODO: why is this set? this shouldn't be here. + []byte{1}, + []byte{32, 0, 0, 0, 0, 0, 0, 0}, + []byte{12, 0, 0, 0}, // = len(.Complex2PtrOptionalSet.Bar.BarB) + []byte(val.Complex2PtrOptionalSet.Bar.BarB), + + // .Complex2PtrOptionalNotSet is optional, and is not set. + []byte{0}, + + // .Map + []byte{2, 0, 0, 0}, // len of map + []byte{3, 0, 0, 0}, // len of key "bar" (comes in alphabetical order) + []byte("bar"), + []byte{46, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0}, // len of key "foo" (comes in alphabetical order) + []byte("foo"), + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + + // .Slice + []byte{2, 0, 0, 0}, // len of slice + // .Slice[0] + []byte{17, 0, 0, 0}, // len of [0].Foo + []byte(val.Slice[0].Foo), + []byte{97, 0, 0, 0}, + // .Slice[1] + []byte{18, 0, 0, 0}, // len of [1].Foo + []byte(val.Slice[1].Foo), + []byte{98, 0, 0, 0}, + + // .Array + // .Array[0] + []byte{5, 0, 0, 0}, // len of [0].Foo + []byte(val.Array[0].Foo), + []byte{22, 0, 0, 0}, + // .Array[1] + []byte{5, 0, 0, 0}, // len of [1].Foo + []byte(val.Array[1].Foo), + []byte{23, 0, 0, 0}, + // .Array[2] + []byte{5, 0, 0, 0}, // len of [2].Foo + []byte(val.Array[2].Foo), + []byte{24, 0, 0, 0}, + // .Array[3] + []byte{5, 0, 0, 0}, // len of [3].Foo + []byte(val.Array[3].Foo), + []byte{25, 0, 0, 0}, + ), + buf.Bytes(), + ) + + { + dec := NewBorshDecoder(buf.Bytes()) + var got StructWithEnum + require.NoError(t, dec.Decode(&got)) + { + val.ComplexPtrNotSet = &ComplexEnum{} + } + require.Equal(t, val, got) + } + } + } + } +} + +type StructWithEnum struct { + Simple Dummy + SimplePointer *Dummy + + Complex ComplexEnum + ComplexNotSet ComplexEnum + ComplexPtr *ComplexEnum + ComplexPtrNotSet *ComplexEnum + + ComplexEmpty ComplexEnumEmpty + ComplexPrimitives1 ComplexEnumPrimitives + ComplexPrimitives2 ComplexEnumPrimitives + + Complex2 ComplexEnumPointers + Complex2Ptr *ComplexEnumPointers + + Complex2PtrOptionalSet *ComplexEnumPointers `bin:"optional"` + Complex2PtrOptionalNotSet *ComplexEnumPointers `bin:"optional"` + + Map map[string]uint64 + Slice []Struct + Array [4]Struct +} + +type StructWithOptionalFields struct { + FooRequired *string + BarRequiredNotSet *uint32 + FooPointer *string `bin:"optional"` + FooPointerNotSet *string `bin:"optional"` + BarPointer *uint32 `bin:"optional"` + FooValue string `bin:"optional"` + BarValueNotSet uint32 `bin:"optional"` + Hello string +} + +type Struct struct { + Foo string + Bar uint32 +} +type StructWithPointerFields struct { + Foo *string + Bar *uint32 +} + +type AA struct { + A int64 + B int32 + C bool + D *bool `bin:"optional"` + E *uint64 `bin:"optional"` + // NOTE: multilevel pointers are not supported. + // DoublePointer **uint64 + + Map map[string]string + EmptyMap map[int64]string + // // NOTE: pointers to map are not supported. + // // PointerToMap *map[string]string + // // PointerToMapEmpty *map[string]string + Array [2]int64 + + Optional *Struct `bin:"optional"` + Value Struct + + InterfaceEncoderDecoderByValue CustomEncoding + InterfaceEncoderDecoderByPointer *CustomEncoding + + // InterfaceEncoderDecoderByValueEmpty CustomEncoding + // InterfaceEncoderDecoderByPointerEmpty *CustomEncoding `bin:"optional"` + + HighValuesInt64 []int64 + HighValuesUint64 []uint64 + HighValuesFloat64 []float64 +} + +type CustomEncoding struct { + Prefix byte + Value uint32 +} + +func (e CustomEncoding) MarshalWithEncoder(encoder *Encoder) error { + if err := encoder.WriteUint32(e.Value, LE); err != nil { + return err + } + return encoder.WriteByte(e.Prefix) +} + +func (e *CustomEncoding) UnmarshalWithDecoder(decoder *Decoder) (err error) { + if e.Value, err = decoder.ReadUint32(LE); err != nil { + return err + } + if e.Prefix, err = decoder.ReadByte(); err != nil { + return err + } + return nil +} + +var _ EncoderDecoder = &CustomEncoding{} + +func TestBorsh_kitchenSink(t *testing.T) { + boolTrue := true + uint64Num := uint64(25464132585) + x := AA{ + A: 1, + B: 32, + C: true, + D: &boolTrue, + E: &uint64Num, + Map: map[string]string{"foo": "bar"}, + Array: [2]int64{57, 88}, + Optional: &Struct{ + Foo: "optional foo", + Bar: 8888886, + }, + Value: Struct{ + Foo: "value foo", + Bar: 7777, + }, + InterfaceEncoderDecoderByValue: CustomEncoding{Prefix: byte('b'), Value: 72}, + InterfaceEncoderDecoderByPointer: &CustomEncoding{Prefix: byte('c'), Value: 9999}, + + HighValuesInt64: []int64{ + math.MaxInt8, + math.MaxInt16, + math.MaxInt32, + math.MaxInt64, + + -math.MaxInt8, + -math.MaxInt16, + -math.MaxInt32, + -math.MaxInt64, + + math.MaxUint8, + math.MaxUint16, + math.MaxUint32, + // math.MaxUint64, + + -math.MaxUint8, + -math.MaxUint16, + -math.MaxUint32, + // -math.MaxUint64, + }, + + HighValuesUint64: []uint64{ + math.MaxInt8, + math.MaxInt16, + math.MaxInt32, + math.MaxInt64, + + math.MaxUint8, + math.MaxUint16, + math.MaxUint32, + math.MaxUint64, + }, + + HighValuesFloat64: []float64{ + math.MaxFloat32, + math.MaxFloat64, + + -math.MaxFloat32, + -math.MaxFloat64, + }, + } + buf := NewWriteByWrite("") + borshEnc := NewBorshEncoder(buf) + err := borshEnc.Encode(x) + // fmt.Println(buf.String()) + require.NoError(t, err) + + y := new(AA) + err = UnmarshalBorsh(y, buf.Bytes()) + require.NoError(t, err) + require.Equal(t, x, *y) +} + +type A struct { + A int64 + B int32 + C bool + D *bool + E *uint64 +} + +func TestSimple(t *testing.T) { + boolTrue := true + uint64Num := uint64(25464132585) + x := A{ + A: 1, + B: 32, + C: true, + D: &boolTrue, + E: &uint64Num, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + y := new(A) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + require.Equal(t, x, *y) +} + +type B struct { + I8 int8 + I16 int16 + I32 int32 + I64 int64 + U8 uint8 + U16 uint16 + U32 uint32 + U64 uint64 + F32 float32 + F64 float64 + unexported int64 // unexported fields are skipped. + Err error // nil interfaces must be specified to be skipped. +} + +func TestBasic(t *testing.T) { + x := B{ + I8: 12, + I16: -1, + I32: 124, + I64: 1243, + U8: 1, + U16: 979, + U32: 123124, + U64: 1135351135, + F32: -231.23, + F64: 3121221.232, + unexported: 333, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + y := new(B) + + // expect the unexported field to be zero because + // it shouldn't have been encoded or be tried to be decoded: + x.unexported = 0 + + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + require.Equal(t, x, *y) +} + +type C struct { + A3 [3]int64 + S []int64 + P *int64 + M map[string]string +} + +func TestBasicContainer(t *testing.T) { + ip := new(int64) + *ip = 213 + x := C{ + A3: [3]int64{234, -123, 123}, + S: []int64{21442, 421241241, 2424}, + P: ip, + M: map[string]string{"foo": "bar"}, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + y := new(C) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + require.Equal(t, x, *y) +} + +type N struct { + B B + C C +} + +func TestNested(t *testing.T) { + ip := new(int64) + *ip = 213 + x := N{ + B: B{ + I8: 12, + I16: -1, + I32: 124, + I64: 1243, + U8: 1, + U16: 979, + U32: 123124, + U64: 1135351135, + F32: -231.23, + F64: 3121221.232, + }, + C: C{ + A3: [3]int64{234, -123, 123}, + S: []int64{21442, 421241241, 2424}, + P: ip, + M: map[string]string{"foo": "bar"}, + }, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + y := new(N) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + require.Equal(t, x, *y) +} + +type Dummy BorshEnum + +const ( + x Dummy = iota + y + z +) + +type D struct { + D Dummy +} + +func TestSimpleEnum(t *testing.T) { + x := D{ + D: y, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + y := new(D) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) +} + +type ComplexEnum struct { + Enum BorshEnum `borsh_enum:"true"` + Foo Foo + Bar Bar +} + +type ComplexEnumPointers struct { + Enum BorshEnum `borsh_enum:"true"` + Foo *Foo + Bar *Bar +} + +type ComplexEnumEmpty struct { + Enum BorshEnum `borsh_enum:"true"` + Foo EmptyVariant + Bar Bar +} + +type ComplexEnumPrimitives struct { + Enum BorshEnum `borsh_enum:"true"` + Foo uint32 + Bar int16 +} + +type Foo struct { + FooA int32 + FooB string +} + +type Bar struct { + BarA int64 + BarB string +} + +func TestComplexEnum(t *testing.T) { + { + x := ComplexEnum{ + Enum: 1, + Bar: Bar{ + BarA: 23, + BarB: "baz", + }, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + y := new(ComplexEnum) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) + } + { + x := ComplexEnumPointers{ + Enum: 1, + Bar: &Bar{ + BarA: 99999, + BarB: "hello world", + }, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + y := new(ComplexEnumPointers) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) + } + { + x := ComplexEnumEmpty{ + Enum: 1, + Bar: Bar{ + BarA: 23, + BarB: "baz", + }, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + y := new(ComplexEnumEmpty) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) + } + { + x := ComplexEnumPrimitives{ + Enum: 1, + Bar: 22, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + y := new(ComplexEnumPrimitives) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) + } +} + +type S struct { + S map[int64]struct{} +} + +func TestSet(t *testing.T) { + emptyStruct := struct{}{} + x := S{ + S: map[int64]struct{}{124: emptyStruct, 214: emptyStruct, 24: emptyStruct, 53: emptyStruct}, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + y := new(S) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + require.Equal(t, x, *y) +} + +type Skipped struct { + A int64 + B int64 `borsh_skip:"true"` + C int64 +} + +func TestSkipped(t *testing.T) { + x := Skipped{ + A: 32, + B: 535, + C: 123, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + y := new(Skipped) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x.A, y.A) + require.Equal(t, x.C, y.C) + require.NotEqual(t, y.B, x.B, "didn't skip field B") +} + +type E struct{} + +func TestEmpty(t *testing.T) { + x := E{} + data, err := MarshalBorsh(x) + require.NoError(t, err) + if len(data) != 0 { + t.Error("not empty") + } + y := new(E) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + require.Equal(t, x, *y) +} + +func testValue(t *testing.T, v interface{}) { + data, err := MarshalBorsh(v) + require.NoError(t, err) + + parsed := reflect.New(reflect.TypeOf(v)) + err = UnmarshalBorsh(parsed.Interface(), data) + require.NoError(t, err) + require.Equal(t, v, parsed.Elem().Interface()) +} + +func TestStrings(t *testing.T) { + tests := []struct { + in string + }{ + {""}, + {"a"}, + {"hellow world"}, + {strings2.Repeat("x", 1024)}, + {strings2.Repeat("x", 4096)}, + {strings2.Repeat("x", 65535)}, + {strings2.Repeat("hello world!", 1000)}, + {"🎯"}, + } + + for _, tt := range tests { + testValue(t, tt.in) + } +} + +func makeInt32Slice(val int32, len int) []int32 { + s := make([]int32, len) + for i := 0; i < len; i++ { + s[i] = val + } + return s +} + +func TestSlices(t *testing.T) { + tests := []struct { + in []int32 + }{ + {nil}, // zero length slice + {makeInt32Slice(1000000000, 1)}, + {makeInt32Slice(1000000001, 2)}, + {makeInt32Slice(1000000002, 3)}, + {makeInt32Slice(1000000003, 4)}, + {makeInt32Slice(1000000004, 8)}, + {makeInt32Slice(1000000005, 16)}, + {makeInt32Slice(1000000006, 32)}, + {makeInt32Slice(1000000007, 64)}, + {makeInt32Slice(1000000008, 65)}, + } + + for _, tt := range tests { + testValue(t, tt.in) + } +} + +func TestUint128_old(t *testing.T) { + tests := []struct { + in Int128 + }{ + {func() Int128 { + v := Int128{ + Hi: math.MaxInt16, + Lo: math.MaxInt16, + } + return v + }()}, + } + + for _, tt := range tests { + testValue(t, tt.in) + } +} + +type ( + Myu8 uint8 + Myu16 uint16 + Myu32 uint32 + Myu64 uint64 + Myi8 int8 + Myi16 int16 + Myi32 int32 + Myi64 int64 +) + +type CustomType struct { + U8 Myu8 + U16 Myu16 + U32 Myu32 + U64 Myu64 + I8 Myi8 + I16 Myi16 + I32 Myi32 + I64 Myi64 +} + +func TestCustomType(t *testing.T) { + x := CustomType{ + U8: 1, + U16: 2, + U32: 3, + U64: 4, + I8: 5, + I16: 6, + I32: 7, + I64: 8, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + y := new(CustomType) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) +} + +func TestStringSlice(t *testing.T) { + { + // slice: + x := []string{"a", "b", "c"} + data, err := MarshalBorsh(x) + require.NoError(t, err) + + require.Equal(t, concatByteSlices( + []byte{0x3, 0x0, 0x0, 0x0}, // length + + []byte{0x1, 0x0, 0x0, 0x0}, // length of first string + []byte("a"), + + []byte{0x1, 0x0, 0x0, 0x0}, // length of second string + []byte("b"), + + []byte{0x1, 0x0, 0x0, 0x0}, // length of third string + []byte("c"), + ), data) + + y := new([]string) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) + } + { + // string slice as field: + type S struct { + A []string + } + x := S{ + A: []string{"a", "b", "c"}, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + require.Equal(t, concatByteSlices( + []byte{0x3, 0x0, 0x0, 0x0}, // length of A + + []byte{0x1, 0x0, 0x0, 0x0}, // length of A[0] + []byte("a"), + + []byte{0x1, 0x0, 0x0, 0x0}, // length of A[1] + []byte("b"), + + []byte{0x1, 0x0, 0x0, 0x0}, // length of A[2] + []byte("c"), + ), data) + + y := new(S) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) + } + { + // string slice as optional field (present): + type S struct { + A *[]string `bin:"optional"` + } + slice := []string{"a", "b", "c"} + x := S{ + A: &slice, + } + data, err := MarshalBorsh(x) + require.NoError(t, err) + + require.Equal(t, concatByteSlices( + []byte{0x01}, // optionality + []byte{0x3, 0x0, 0x0, 0x0}, // slice length + + []byte{0x1, 0x0, 0x0, 0x0}, // slice item length (string) + []byte("a"), + + []byte{0x1, 0x0, 0x0, 0x0}, // slice item length (string) + []byte("b"), + + []byte{0x1, 0x0, 0x0, 0x0}, // slice item length (string) + []byte("c"), + ), data) + + y := new(S) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) + } + { + // string slice as optional field (absent): + type S struct { + A *[]string `bin:"optional"` + } + x := S{} + data, err := MarshalBorsh(x) + require.NoError(t, err) + + require.Equal(t, concatByteSlices( + []byte{0x0}, // optionality + ), data) + + y := new(S) + err = UnmarshalBorsh(y, data) + require.NoError(t, err) + + require.Equal(t, x, *y) + } +} diff --git a/binary/compact-u16.go b/binary/compact-u16.go new file mode 100644 index 000000000..5caee0862 --- /dev/null +++ b/binary/compact-u16.go @@ -0,0 +1,138 @@ +// Copyright 2021 github.com/gagliardetto +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "fmt" + "io" + "math" +) + +// EncodeCompactU16Length encodes a "Compact-u16" length into the provided slice pointer. +// See https://docs.solana.com/developing/programming-model/transactions#compact-u16-format +// See https://github.com/solana-labs/solana/blob/2ef2b6daa05a7cff057e9d3ef95134cee3e4045d/web3.js/src/util/shortvec-encoding.ts +func EncodeCompactU16Length(buf *[]byte, ln int) error { + if ln < 0 || ln > math.MaxUint16 { + return fmt.Errorf("length %d out of range", ln) + } + u := uint(ln) + switch { + case u < 0x80: + *buf = append(*buf, byte(u)) + case u < 0x4000: + *buf = append(*buf, byte(u)|0x80, byte(u>>7)) + default: + *buf = append(*buf, byte(u)|0x80, byte(u>>7)|0x80, byte(u>>14)) + } + return nil +} + +// PutCompactU16Length writes a "Compact-u16" length into dst and returns the +// number of bytes written (1, 2, or 3). dst must be at least 3 bytes long. +// This is the allocation-free variant of EncodeCompactU16Length, used by the +// Encoder's scratch-buffer hot path. +func PutCompactU16Length(dst []byte, ln int) (int, error) { + if ln < 0 || ln > math.MaxUint16 { + return 0, fmt.Errorf("length %d out of range", ln) + } + u := uint(ln) + switch { + case u < 0x80: + dst[0] = byte(u) + return 1, nil + case u < 0x4000: + dst[0] = byte(u) | 0x80 + dst[1] = byte(u >> 7) + return 2, nil + default: + dst[0] = byte(u) | 0x80 + dst[1] = byte(u>>7) | 0x80 + dst[2] = byte(u >> 14) + return 3, nil + } +} + +const _MAX_COMPACTU16_ENCODING_LENGTH = 3 + +// DecodeCompactU16 decodes a Solana "Compact-u16" length from bytes and returns +// (value, bytes_consumed, error). Hand-unrolled for the max 3-byte encoding to +// avoid a per-iteration loop overhead. +func DecodeCompactU16(bytes []byte) (int, int, error) { + if len(bytes) == 0 { + return 0, 0, io.ErrUnexpectedEOF + } + b0 := int(bytes[0]) + if b0&0x80 == 0 { + return b0, 1, nil + } + if len(bytes) < 2 { + return 0, 0, io.ErrUnexpectedEOF + } + b1 := int(bytes[1]) + if b1&0x80 == 0 { + if b1 == 0 { + return 0, 0, fmt.Errorf("compact-u16: non-canonical 2-byte encoding (trailing zero byte)") + } + return (b0 & 0x7f) | (b1 << 7), 2, nil + } + if len(bytes) < 3 { + return 0, 0, io.ErrUnexpectedEOF + } + b2 := int(bytes[2]) + if b2 == 0 { + return 0, 0, fmt.Errorf("compact-u16: non-canonical 3-byte encoding (trailing zero byte)") + } + if b2&0x80 != 0 { + return 0, 0, fmt.Errorf("byte three continues") + } + ln := (b0 & 0x7f) | ((b1 & 0x7f) << 7) | (b2 << 14) + if ln > math.MaxUint16 { + return 0, 0, fmt.Errorf("invalid length: %d", ln) + } + return ln, 3, nil +} + +// DecodeCompactU16LengthFromByteReader decodes a "Compact-u16" length from the provided io.ByteReader. +func DecodeCompactU16LengthFromByteReader(reader io.ByteReader) (int, error) { + ln := 0 + size := 0 + for nthByte := range _MAX_COMPACTU16_ENCODING_LENGTH { + elemByte, err := reader.ReadByte() + if err != nil { + return 0, err + } + elem := int(elemByte) + if elem == 0 && nthByte != 0 { + return 0, fmt.Errorf("compact-u16: non-canonical encoding (trailing zero byte at position %d)", nthByte) + } + if nthByte == _MAX_COMPACTU16_ENCODING_LENGTH-1 && (elem&0x80) != 0 { + return 0, fmt.Errorf("compact-u16: byte three has continuation bit set") + } + ln |= (elem & 0x7f) << (size * 7) + size += 1 + if (elem & 0x80) == 0 { + break + } + } + // check for non-valid sizes + if size == 0 || size > _MAX_COMPACTU16_ENCODING_LENGTH { + return 0, fmt.Errorf("compact-u16: invalid size: %d", size) + } + // check for non-valid lengths + if ln < 0 || ln > math.MaxUint16 { + return 0, fmt.Errorf("compact-u16: invalid length: %d", ln) + } + return ln, nil +} diff --git a/binary/compact-u16_test.go b/binary/compact-u16_test.go new file mode 100644 index 000000000..077b88d88 --- /dev/null +++ b/binary/compact-u16_test.go @@ -0,0 +1,221 @@ +// Copyright 2021 github.com/gagliardetto +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "bytes" + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCompactU16(t *testing.T) { + candidates := []int{0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 100, 1000, 10000, math.MaxUint16 - 1, math.MaxUint16} + for _, val := range candidates { + if val < 0 || val > math.MaxUint16 { + panic("value too large") + } + buf := make([]byte, 0) + require.NoError(t, EncodeCompactU16Length(&buf, val)) + + buf = append(buf, []byte("hello world")...) + decoded, _, err := DecodeCompactU16(buf) + require.NoError(t, err) + + require.Equal(t, val, decoded) + } + for _, val := range candidates { + buf := make([]byte, 0) + EncodeCompactU16Length(&buf, val) + + buf = append(buf, []byte("hello world")...) + { + decoded, err := DecodeCompactU16LengthFromByteReader(bytes.NewReader(buf)) + require.NoError(t, err) + require.Equal(t, val, decoded) + } + { + decoded, _, err := DecodeCompactU16(buf) + require.NoError(t, err) + require.Equal(t, val, decoded) + } + } + { + // now test all from 0 to 0xffff + for i := 0; i < math.MaxUint16; i++ { + buf := make([]byte, 0) + EncodeCompactU16Length(&buf, i) + + buf = append(buf, []byte("hello world")...) + { + decoded, err := DecodeCompactU16LengthFromByteReader(bytes.NewReader(buf)) + require.NoError(t, err) + require.Equal(t, i, decoded) + } + { + decoded, _, err := DecodeCompactU16(buf) + require.NoError(t, err) + require.Equal(t, i, decoded) + } + } + } +} + +func BenchmarkCompactU16(b *testing.B) { + // generate 1000 random values + candidates := make([]int, 1000) + for i := 0; i < 1000; i++ { + candidates[i] = i + } + + buf := make([]byte, 0) + EncodeCompactU16Length(&buf, math.MaxUint16) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _, _ = DecodeCompactU16(buf) + } +} + +func BenchmarkCompactU16Encode(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf := make([]byte, 0) + EncodeCompactU16Length(&buf, math.MaxUint16) + } +} + +func BenchmarkCompactU16Reader(b *testing.B) { + // generate 1000 random values + candidates := make([]int, 1000) + for i := 0; i < 1000; i++ { + candidates[i] = i + } + + buf := make([]byte, 0) + EncodeCompactU16Length(&buf, math.MaxUint16) + + reader := NewBorshDecoder(buf) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + out, _ := reader.ReadCompactU16() + if out != math.MaxUint16 { + panic("not equal") + } + reader.SetPosition(0) + } +} + +func encode_len(len uint16) []byte { + buf := make([]byte, 0) + err := EncodeCompactU16Length(&buf, int(len)) + if err != nil { + panic(err) + } + return buf +} + +func assert_len_encoding(t *testing.T, len uint16, buf []byte) { + require.Equal(t, encode_len(len), buf, "unexpected usize encoding") + decoded, _, err := DecodeCompactU16(buf) + require.NoError(t, err) + require.Equal(t, int(len), decoded) + { + // now try with a reader + reader := bytes.NewReader(buf) + out, _ := DecodeCompactU16LengthFromByteReader(reader) + require.Equal(t, int(len), out) + } +} + +func TestShortVecEncodeLen(t *testing.T) { + assert_len_encoding(t, 0x0, []byte{0x0}) + assert_len_encoding(t, 0x7f, []byte{0x7f}) + assert_len_encoding(t, 0x80, []byte{0x80, 0x01}) + assert_len_encoding(t, 0xff, []byte{0xff, 0x01}) + assert_len_encoding(t, 0x100, []byte{0x80, 0x02}) + assert_len_encoding(t, 0x7fff, []byte{0xff, 0xff, 0x01}) + assert_len_encoding(t, 0xffff, []byte{0xff, 0xff, 0x03}) +} + +func assert_good_deserialized_value(t *testing.T, value uint16, buf []byte) { + decoded, _, err := DecodeCompactU16(buf) + require.NoError(t, err) + require.Equal(t, int(value), decoded) + { + // now try with a reader + reader := bytes.NewReader(buf) + out, _ := DecodeCompactU16LengthFromByteReader(reader) + require.Equal(t, int(value), out) + } +} + +func assert_bad_deserialized_value(t *testing.T, buf []byte) { + _, _, err := DecodeCompactU16(buf) + require.Error(t, err, "expected an error for bytes: %v", buf) + { + // now try with a reader + reader := bytes.NewReader(buf) + _, err := DecodeCompactU16LengthFromByteReader(reader) + require.Error(t, err, "expected an error for bytes: %v", buf) + } +} + +func TestDeserialize(t *testing.T) { + assert_good_deserialized_value(t, 0x0000, []byte{0x00}) + assert_good_deserialized_value(t, 0x007f, []byte{0x7f}) + assert_good_deserialized_value(t, 0x0080, []byte{0x80, 0x01}) + assert_good_deserialized_value(t, 0x00ff, []byte{0xff, 0x01}) + assert_good_deserialized_value(t, 0x0100, []byte{0x80, 0x02}) + assert_good_deserialized_value(t, 0x07ff, []byte{0xff, 0x0f}) + assert_good_deserialized_value(t, 0x3fff, []byte{0xff, 0x7f}) + assert_good_deserialized_value(t, 0x4000, []byte{0x80, 0x80, 0x01}) + assert_good_deserialized_value(t, 0xffff, []byte{0xff, 0xff, 0x03}) + + // aliases + // 0x0000 + assert_bad_deserialized_value(t, []byte{0x80, 0x00}) + assert_bad_deserialized_value(t, []byte{0x80, 0x80, 0x00}) + // 0x007f + assert_bad_deserialized_value(t, []byte{0xff, 0x00}) + assert_bad_deserialized_value(t, []byte{0xff, 0x80, 0x00}) + // 0x0080 + assert_bad_deserialized_value(t, []byte{0x80, 0x81, 0x00}) + // 0x00ff + assert_bad_deserialized_value(t, []byte{0xff, 0x81, 0x00}) + // 0x0100 + assert_bad_deserialized_value(t, []byte{0x80, 0x82, 0x00}) + // 0x07ff + assert_bad_deserialized_value(t, []byte{0xff, 0x8f, 0x00}) + // 0x3fff + assert_bad_deserialized_value(t, []byte{0xff, 0xff, 0x00}) + + // too short + assert_bad_deserialized_value(t, []byte{}) + assert_bad_deserialized_value(t, []byte{0x80}) + + // too long + assert_bad_deserialized_value(t, []byte{0x80, 0x80, 0x80, 0x00}) + + // too large + // 0x0001_0000 + assert_bad_deserialized_value(t, []byte{0x80, 0x80, 0x04}) + // 0x0001_8000 + assert_bad_deserialized_value(t, []byte{0x80, 0x80, 0x06}) +} diff --git a/binary/decoder.go b/binary/decoder.go new file mode 100644 index 000000000..f8801e2e3 --- /dev/null +++ b/binary/decoder.go @@ -0,0 +1,1027 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "math" + "reflect" + "strings" + "unsafe" + + "go.uber.org/zap" +) + +// isHostLittleEndian is true on little-endian platforms (amd64, arm64, ...). +// When the requested byte order matches the host, slices of fixed-width +// integers can be read via a single memcpy-style copy from the input buffer +// instead of decoding element-by-element. +var isHostLittleEndian = func() bool { + var x uint16 = 1 + return *(*byte)(unsafe.Pointer(&x)) == 1 +}() + +var TypeSize = struct { + Bool int + Byte int + + Int8 int + Int16 int + + Uint8 int + Uint16 int + Uint32 int + Uint64 int + Uint128 int + + Float32 int + Float64 int + + PublicKey int + Signature int +}{ + Byte: 1, + Bool: 1, + + Int8: 1, + Int16: 2, + + Uint8: 1, + Uint16: 2, + Uint32: 4, + Uint64: 8, + Uint128: 16, + + Float32: 4, + Float64: 8, + + PublicKey: 32, + Signature: 64, +} + +type Decoder struct { + data []byte + pos int + + // currentFieldOpt is the per-field option of the most recent decode call. + // Held by value (not pointer) so it doesn't escape to the heap. The Order + // field is consulted by deeply-nested types (e.g. Uint128) to find the + // active byte order; defaultByteOrder is used when Order is nil. + currentFieldOpt option + + encoding Encoding +} + +// Reset resets the decoder to decode a new message. +func (dec *Decoder) Reset(data []byte) { + dec.data = data + dec.pos = 0 + dec.currentFieldOpt = option{} +} + +func (dec *Decoder) IsBorsh() bool { + return dec.encoding.IsBorsh() +} + +func (dec *Decoder) IsBin() bool { + return dec.encoding.IsBin() +} + +func (dec *Decoder) IsCompactU16() bool { + return dec.encoding.IsCompactU16() +} + +func NewDecoderWithEncoding(data []byte, enc Encoding) *Decoder { + if !isValidEncoding(enc) { + panic(fmt.Sprintf("provided encoding is not valid: %s", enc)) + } + return &Decoder{ + data: data, + encoding: enc, + } +} + +// SetEncoding sets the encoding scheme to use for decoding. +func (dec *Decoder) SetEncoding(enc Encoding) { + dec.encoding = enc +} + +func NewBinDecoder(data []byte) *Decoder { + return NewDecoderWithEncoding(data, EncodingBin) +} + +func NewBorshDecoder(data []byte) *Decoder { + return NewDecoderWithEncoding(data, EncodingBorsh) +} + +func NewCompactU16Decoder(data []byte) *Decoder { + return NewDecoderWithEncoding(data, EncodingCompactU16) +} + +func (dec *Decoder) Decode(v interface{}) (err error) { + switch dec.encoding { + case EncodingBin: + return dec.decodeWithOptionBin(v, defaultOption) + case EncodingBorsh: + return dec.decodeWithOptionBorsh(v, defaultOption) + case EncodingCompactU16: + return dec.decodeWithOptionCompactU16(v, defaultOption) + default: + panic(fmt.Errorf("encoding not implemented: %s", dec.encoding)) + } +} + +func sizeof(t reflect.Type, v reflect.Value) int { + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return int(v.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + n := int(v.Uint()) + // all the builtin array length types are native int + // so this guards against weird truncation + if n < 0 { + return 0 + } + return n + default: + panic(fmt.Sprintf("sizeof field not implemented for kind %s", t.Kind())) + } +} + +var ErrVarIntBufferSize = errors.New("varint: invalid buffer size") + +func (dec *Decoder) ReadUvarint64() (uint64, error) { + l, read := binary.Uvarint(dec.data[dec.pos:]) + if read <= 0 { + return l, ErrVarIntBufferSize + } + if traceEnabled { + zlog.Debug("decode: read uvarint64", zap.Uint64("val", l)) + } + dec.pos += read + return l, nil +} + +func (d *Decoder) ReadVarint64() (out int64, err error) { + l, read := binary.Varint(d.data[d.pos:]) + if read <= 0 { + return l, ErrVarIntBufferSize + } + if traceEnabled { + zlog.Debug("decode: read varint", zap.Int64("val", l)) + } + d.pos += read + return l, nil +} + +func (dec *Decoder) ReadVarint32() (out int32, err error) { + n, err := dec.ReadVarint64() + if err != nil { + return out, err + } + out = int32(n) + if traceEnabled { + zlog.Debug("decode: read varint32", zap.Int32("val", out)) + } + return +} + +func (dec *Decoder) ReadUvarint32() (out uint32, err error) { + n, err := dec.ReadUvarint64() + if err != nil { + return out, err + } + out = uint32(n) + if traceEnabled { + zlog.Debug("decode: read uvarint32", zap.Uint32("val", out)) + } + return +} + +func (dec *Decoder) ReadVarint16() (out int16, err error) { + n, err := dec.ReadVarint64() + if err != nil { + return out, err + } + out = int16(n) + if traceEnabled { + zlog.Debug("decode: read varint16", zap.Int16("val", out)) + } + return +} + +func (dec *Decoder) ReadUvarint16() (out uint16, err error) { + n, err := dec.ReadUvarint64() + if err != nil { + return out, err + } + out = uint16(n) + if traceEnabled { + zlog.Debug("decode: read uvarint16", zap.Uint16("val", out)) + } + return +} + +// ReadByteSlice reads a length-prefixed byte slice from the decoder. The +// returned slice is an independent copy; the caller may retain it and mutate +// it without affecting the decoder's input buffer. For the zero-copy variant, +// see ReadByteSliceBorrow. +func (dec *Decoder) ReadByteSlice() (out []byte, err error) { + borrowed, err := dec.ReadByteSliceBorrow() + if err != nil { + return nil, err + } + if len(borrowed) == 0 { + return nil, nil + } + out = make([]byte, len(borrowed)) + copy(out, borrowed) + return out, nil +} + +// ReadByteSliceBorrow is the zero-copy variant of ReadByteSlice. The returned +// slice aliases the decoder's input buffer and is only safe to use while that +// buffer is alive and unmodified. Use ReadByteSlice when you need an owned +// copy you can retain. +func (dec *Decoder) ReadByteSliceBorrow() (out []byte, err error) { + length, err := dec.ReadLength() + if err != nil { + return nil, err + } + + if len(dec.data) < dec.pos+length { + return nil, fmt.Errorf("byte array: varlen=%d, missing %d bytes", length, dec.pos+length-len(dec.data)) + } + + out = dec.data[dec.pos : dec.pos+length] + dec.pos += length + if traceEnabled { + zlog.Debug("decode: read byte array", zap.Stringer("hex", HexBytes(out))) + } + return +} + +func (dec *Decoder) ReadLength() (length int, err error) { + switch dec.encoding { + case EncodingBin: + val, err := dec.ReadUvarint64() + if err != nil { + return 0, err + } + if val > 0x7FFF_FFFF { + return 0, io.ErrUnexpectedEOF + } + length = int(val) + case EncodingBorsh: + val, err := dec.ReadUint32(LE) + if err != nil { + return 0, err + } + if val > 0x7FFF_FFFF { + return 0, io.ErrUnexpectedEOF + } + length = int(val) + case EncodingCompactU16: + val, err := dec.ReadCompactU16() + if err != nil { + return 0, err + } + length = val + default: + panic(fmt.Errorf("encoding not implemented: %s", dec.encoding)) + } + return +} + +func readNBytes(n int, reader *Decoder) ([]byte, error) { + if n == 0 { + return make([]byte, 0), nil + } + if n < 0 || n > 0x7FFF_FFFF { + return nil, fmt.Errorf("invalid length n: %v", n) + } + if reader.pos+n > len(reader.data) { + return nil, fmt.Errorf("not enough data: %d bytes missing", reader.pos+n-len(reader.data)) + } + out := reader.data[reader.pos : reader.pos+n] + reader.pos += n + return out, nil +} + +func discardNBytes(n int, reader *Decoder) error { + if n == 0 { + return nil + } + if n < 0 || n > 0x7FFF_FFFF { + return fmt.Errorf("invalid length n: %v", n) + } + return reader.SkipBytes(uint(n)) +} + +func (d *Decoder) Read(buf []byte) (int, error) { + if d.pos+len(buf) > len(d.data) { + return 0, io.ErrShortBuffer + } + numCopied := copy(buf, d.data[d.pos:]) + d.pos += numCopied + // must read exactly len(buf) bytes + if numCopied != len(buf) { + return 0, io.ErrUnexpectedEOF + } + return len(buf), nil +} + +func (dec *Decoder) ReadNBytes(n int) (out []byte, err error) { + return readNBytes(n, dec) +} + +// ReadBytes reads a byte slice of length n. +func (dec *Decoder) ReadBytes(n int) (out []byte, err error) { + return readNBytes(n, dec) +} + +func (dec *Decoder) Discard(n int) (err error) { + return discardNBytes(n, dec) +} + +func (dec *Decoder) ReadTypeID() (out TypeID, err error) { + discriminator, err := dec.ReadNBytes(8) + if err != nil { + return TypeID{}, err + } + return TypeIDFromBytes(discriminator), nil +} + +func (dec *Decoder) ReadDiscriminator() (out TypeID, err error) { + return dec.ReadTypeID() +} + +func (dec *Decoder) PeekDiscriminator() (out TypeID, err error) { + discriminator, err := dec.Peek(8) + if err != nil { + return TypeID{}, err + } + return TypeIDFromBytes(discriminator), nil +} + +func (dec *Decoder) Peek(n int) (out []byte, err error) { + if n < 0 { + err = fmt.Errorf("n not valid: %d", n) + return + } + + requiredSize := TypeSize.Byte * n + if dec.Remaining() < requiredSize { + err = fmt.Errorf("required [%d] bytes, remaining [%d]", requiredSize, dec.Remaining()) + return + } + + out = dec.data[dec.pos : dec.pos+n] + if traceEnabled { + zlog.Debug("decode: peek", zap.Int("n", n), zap.Binary("out", out)) + } + return +} + +// ReadCompactU16 reads a compact u16 from the decoder. +func (dec *Decoder) ReadCompactU16() (out int, err error) { + out, size, err := DecodeCompactU16(dec.data[dec.pos:]) + if traceEnabled { + zlog.Debug("decode: read compact u16", zap.Int("val", out)) + } + dec.pos += size + return out, err +} + +func (dec *Decoder) ReadOption() (out bool, err error) { + b, err := dec.ReadByte() + if err != nil { + return false, fmt.Errorf("decode: read option, %w", err) + } + out = b != 0 + if traceEnabled { + zlog.Debug("decode: read option", zap.Bool("val", out)) + } + return +} + +func (dec *Decoder) ReadCOption() (out bool, err error) { + b, err := dec.ReadUint32(LE) + if err != nil { + return false, fmt.Errorf("decode: read c-option, %w", err) + } + if b > 1 { + return false, fmt.Errorf("decode: read c-option, invalid value: %d", b) + } + out = b != 0 + if traceEnabled { + zlog.Debug("decode: read c-option", zap.Bool("val", out)) + } + return +} + +func (dec *Decoder) ReadByte() (out byte, err error) { + if dec.Remaining() < TypeSize.Byte { + err = fmt.Errorf("required [1] byte, remaining [%d]", dec.Remaining()) + return + } + + out = dec.data[dec.pos] + dec.pos++ + if traceEnabled { + zlog.Debug("decode: read byte", zap.Uint8("byte", out), zap.String("hex", hex.EncodeToString([]byte{out}))) + } + return +} + +func (dec *Decoder) ReadBool() (out bool, err error) { + if dec.Remaining() < TypeSize.Bool { + err = fmt.Errorf("bool required [%d] byte, remaining [%d]", TypeSize.Bool, dec.Remaining()) + return + } + + b, err := dec.ReadByte() + if err != nil { + err = fmt.Errorf("readBool: %w", err) + } + out = b != 0 + if traceEnabled { + zlog.Debug("decode: read bool", zap.Bool("val", out)) + } + return +} + +func (dec *Decoder) ReadUint8() (out uint8, err error) { + out, err = dec.ReadByte() + return +} + +func (dec *Decoder) ReadInt8() (out int8, err error) { + b, err := dec.ReadByte() + out = int8(b) + if traceEnabled { + zlog.Debug("decode: read int8", zap.Int8("val", out)) + } + return +} + +func (dec *Decoder) ReadUint16(order binary.ByteOrder) (out uint16, err error) { + if dec.Remaining() < TypeSize.Uint16 { + err = fmt.Errorf("uint16 required [%d] bytes, remaining [%d]", TypeSize.Uint16, dec.Remaining()) + return + } + + out = order.Uint16(dec.data[dec.pos:]) + dec.pos += TypeSize.Uint16 + if traceEnabled { + zlog.Debug("decode: read uint16", zap.Uint16("val", out)) + } + return +} + +func (dec *Decoder) ReadInt16(order binary.ByteOrder) (out int16, err error) { + n, err := dec.ReadUint16(order) + out = int16(n) + if traceEnabled { + zlog.Debug("decode: read int16", zap.Int16("val", out)) + } + return +} + +func (dec *Decoder) ReadUint32(order binary.ByteOrder) (out uint32, err error) { + if dec.Remaining() < TypeSize.Uint32 { + err = fmt.Errorf("uint32 required [%d] bytes, remaining [%d]", TypeSize.Uint32, dec.Remaining()) + return + } + + out = order.Uint32(dec.data[dec.pos:]) + dec.pos += TypeSize.Uint32 + if traceEnabled { + zlog.Debug("decode: read uint32", zap.Uint32("val", out)) + } + return +} + +func (dec *Decoder) ReadInt32(order binary.ByteOrder) (out int32, err error) { + n, err := dec.ReadUint32(order) + out = int32(n) + if traceEnabled { + zlog.Debug("decode: read int32", zap.Int32("val", out)) + } + return +} + +func (dec *Decoder) ReadUint64(order binary.ByteOrder) (out uint64, err error) { + if dec.Remaining() < TypeSize.Uint64 { + err = fmt.Errorf("decode: uint64 required [%d] bytes, remaining [%d]", TypeSize.Uint64, dec.Remaining()) + return + } + + out = order.Uint64(dec.data[dec.pos:]) + dec.pos += TypeSize.Uint64 + if traceEnabled { + zlog.Debug("decode: read uint64", zap.Uint64("val", out)) + } + return +} + +func (dec *Decoder) ReadInt64(order binary.ByteOrder) (out int64, err error) { + n, err := dec.ReadUint64(order) + out = int64(n) + if traceEnabled { + zlog.Debug("decode: read int64", zap.Int64("val", out)) + } + return +} + +func (dec *Decoder) ReadUint128(order binary.ByteOrder) (out Uint128, err error) { + if dec.Remaining() < TypeSize.Uint128 { + err = fmt.Errorf("uint128 required [%d] bytes, remaining [%d]", TypeSize.Uint128, dec.Remaining()) + return + } + + data := dec.data[dec.pos : dec.pos+TypeSize.Uint128] + + if order == binary.LittleEndian { + out.Hi = order.Uint64(data[8:]) + out.Lo = order.Uint64(data[:8]) + } else { + // TODO: is this correct? + out.Hi = order.Uint64(data[:8]) + out.Lo = order.Uint64(data[8:]) + } + + dec.pos += TypeSize.Uint128 + if traceEnabled { + zlog.Debug("decode: read uint128", zap.Stringer("hex", out), zap.Uint64("hi", out.Hi), zap.Uint64("lo", out.Lo)) + } + return +} + +func (dec *Decoder) ReadInt128(order binary.ByteOrder) (out Int128, err error) { + v, err := dec.ReadUint128(order) + if err != nil { + return + } + return Int128(v), nil +} + +func (dec *Decoder) ReadFloat32(order binary.ByteOrder) (out float32, err error) { + if dec.Remaining() < TypeSize.Float32 { + err = fmt.Errorf("float32 required [%d] bytes, remaining [%d]", TypeSize.Float32, dec.Remaining()) + return + } + + n := order.Uint32(dec.data[dec.pos:]) + out = math.Float32frombits(n) + dec.pos += TypeSize.Float32 + if traceEnabled { + zlog.Debug("decode: read float32", zap.Float32("val", out)) + } + + if dec.IsBorsh() { + if math.IsNaN(float64(out)) { + return 0, errors.New("NaN for float not allowed") + } + } + return +} + +func (dec *Decoder) ReadFloat64(order binary.ByteOrder) (out float64, err error) { + if dec.Remaining() < TypeSize.Float64 { + err = fmt.Errorf("float64 required [%d] bytes, remaining [%d]", TypeSize.Float64, dec.Remaining()) + return + } + + n := order.Uint64(dec.data[dec.pos:]) + out = math.Float64frombits(n) + dec.pos += TypeSize.Float64 + if traceEnabled { + zlog.Debug("decode: read Float64", zap.Float64("val", out)) + } + if dec.IsBorsh() { + if math.IsNaN(out) { + return 0, errors.New("NaN for float not allowed") + } + } + return +} + +func (dec *Decoder) ReadFloat128(order binary.ByteOrder) (out Float128, err error) { + value, err := dec.ReadUint128(order) + if err != nil { + return out, fmt.Errorf("float128: %w", err) + } + return Float128(value), nil +} + +// SafeReadUTF8String reads a length-prefixed byte slice and returns it as a +// string with any invalid UTF-8 sequences replaced by the Unicode replacement +// character (U+FFFD). Use when the input is untrusted and may contain +// non-UTF-8 bytes you'd rather sanitize than reject. +func (dec *Decoder) SafeReadUTF8String() (out string, err error) { + data, err := dec.ReadByteSliceBorrow() + if err != nil { + return "", err + } + out = strings.ToValidUTF8(string(data), "\uFFFD") + if traceEnabled { + zlog.Debug("read safe UTF8 string", zap.String("val", out)) + } + return +} + +func (dec *Decoder) ReadString() (out string, err error) { + // Borrow and let `string(...)` do the copy — avoids the double-copy of + // ReadByteSlice followed by string(). + data, err := dec.ReadByteSliceBorrow() + out = string(data) + if traceEnabled { + zlog.Debug("read string", zap.String("val", out)) + } + return +} + +// ReadStringBorrow returns a string that aliases the decoder's input buffer +// without copying. The returned string is only safe to use while the +// decoder's underlying []byte stays alive and unmodified — typically that +// means as long as the source buffer outlives the call site. Use ReadString +// when you need a copy you can retain across the buffer's lifetime. +// +// This is the wincode-style zero-copy fast path. It is allocation-free. +func (dec *Decoder) ReadStringBorrow() (string, error) { + data, err := dec.ReadByteSliceBorrow() + if err != nil { + return "", err + } + if len(data) == 0 { + return "", nil + } + return unsafe.String(&data[0], len(data)), nil +} + +func (dec *Decoder) ReadRustString() (out string, err error) { + length, err := dec.ReadUint64(binary.LittleEndian) + if err != nil { + return "", err + } + if length > 0x7FFF_FFFF { + return "", io.ErrUnexpectedEOF + } + bytes, err := dec.ReadNBytes(int(length)) + if err != nil { + return "", err + } + out = string(bytes) + if traceEnabled { + zlog.Debug("read Rust string", zap.String("val", out)) + } + return +} + +// ReadRustStringBorrow is the zero-copy variant of ReadRustString. Same +// lifetime caveats as ReadStringBorrow apply: the returned string aliases the +// decoder's input buffer. +func (dec *Decoder) ReadRustStringBorrow() (string, error) { + length, err := dec.ReadUint64(binary.LittleEndian) + if err != nil { + return "", err + } + if length > 0x7FFF_FFFF { + return "", io.ErrUnexpectedEOF + } + bytes, err := dec.ReadNBytes(int(length)) + if err != nil { + return "", err + } + if len(bytes) == 0 { + return "", nil + } + return unsafe.String(&bytes[0], len(bytes)), nil +} + +func (dec *Decoder) ReadCompactU16Length() (int, error) { + return dec.ReadCompactU16() +} + +func (dec *Decoder) SkipBytes(count uint) error { + if uint(dec.Remaining()) < count { + return fmt.Errorf("request to skip %d but only %d bytes remain", count, dec.Remaining()) + } + dec.pos += int(count) + return nil +} + +func (dec *Decoder) SetPosition(idx uint) error { + if idx > uint(len(dec.data)) { + return fmt.Errorf("request to set position to %d outside of buffer (buffer size %d)", idx, len(dec.data)) + } + dec.pos = int(idx) + return nil +} + +func (dec *Decoder) Position() uint { + return uint(dec.pos) +} + +func (dec *Decoder) Remaining() int { + return len(dec.data) - dec.pos +} + +func (dec *Decoder) Len() int { + return len(dec.data) +} + +func (dec *Decoder) HasRemaining() bool { + return dec.Remaining() > 0 +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// if it encounters an Unmarshaler, indirect stops and returns that. +// if decodingNull is true, indirect stops at the last pointer so it can be set to nil. +// +// *Note* This is a copy of `encoding/json/decoder.go#indirect` of Golang 1.14. +// +// See here: https://github.com/golang/go/blob/go1.14.2/src/encoding/json/decode.go#L439 +func indirect(v reflect.Value, decodingNull bool) (BinaryUnmarshaler, reflect.Value) { + // Issue #24153 indicates that it is generally not a guaranteed property + // that you may round-trip a reflect.Value by calling Value.Addr().Elem() + // and expect the value to still be settable for values derived from + // unexported embedded struct fields. + // + // The logic below effectively does this when it first addresses the value + // (to satisfy possible pointer methods) and continues to dereference + // subsequent pointers as necessary. + // + // After the first round-trip, we set v back to the original value to + // preserve the original RW flags contained in reflect.Value. + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) { + haveAddr = false + v = e + continue + } + } + + if v.Kind() != reflect.Ptr { + break + } + + if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v { + v = v.Elem() + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 && v.CanInterface() { + if u, ok := v.Interface().(BinaryUnmarshaler); ok { + return u, reflect.Value{} + } + } + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } + } + return nil, v +} + +func reflect_readArrayOfBytes(d *Decoder, l int, rv reflect.Value) error { + buf, err := d.ReadNBytes(l) + if err != nil { + return err + } + switch rv.Kind() { + case reflect.Array: + // if the type of the array is not [n]uint8, but a custom type like [n]CustomUint8: + if rv.Type().Elem() != typeOfUint8 { + // if the type of the array is not [n]uint8, but a custom type like [n]CustomUint8: + // then we need to convert each uint8 to the custom type + for i := range l { + rv.Index(i).Set(reflect.ValueOf(buf[i]).Convert(rv.Index(i).Type())) + } + } else { + reflect.Copy(rv, reflect.ValueOf(buf)) + } + case reflect.Slice: + // if the type of the slice is not []uint8, but a custom type like []CustomUint8: + if rv.Type().Elem() != typeOfUint8 { + // convert the []uint8 to the custom type + customSlice := reflect.MakeSlice(rv.Type(), len(buf), len(buf)) + for i := range len(buf) { + customSlice.Index(i).SetUint(uint64(buf[i])) + } + rv.Set(customSlice) + } else { + rv.Set(reflect.ValueOf(buf)) + } + default: + return fmt.Errorf("unsupported kind: %s", rv.Kind()) + } + return nil +} + +// podSliceReadTarget returns a destination reflect.Value for a slice-or-array +// fixed-width integer read. For arrays it returns rv directly (in-place). For +// slices, when the caller has pre-allocated a slice with sufficient capacity, +// the existing backing array is reused (length is reset via SetLen) — the +// only allocation-free path for hot decode loops. Otherwise a fresh slice of +// length l is allocated. +// +// The caller does not need to call rv.Set when capacity is reused. +func podSliceReadTarget(rv reflect.Value, l int) (reflect.Value, error) { + switch rv.Kind() { + case reflect.Array: + return rv, nil + case reflect.Slice: + if rv.CanSet() && rv.Cap() >= l { + rv.SetLen(l) + return rv, nil + } + return reflect.MakeSlice(rv.Type(), l, l), nil + default: + return reflect.Value{}, fmt.Errorf("unsupported kind: %s", rv.Kind()) + } +} + +// readPoDSliceBytes is the zero-copy-ish fast path for decoding a slice or +// array of fixed-width integers whose element Kind is uint16/uint32/uint64 +// (and by extension type aliases like `type MyU64 uint64`, since the memory +// layout is identical). On little-endian hosts with a little-endian wire +// format it performs a single memcpy from the decoder buffer into the +// destination's backing storage. Otherwise it falls back to an element loop +// using direct pointer writes so we still avoid an intermediate typed slice. +// +// elemSize must be 2, 4, or 8. dst must be an addressable array or slice. +func readPoDSliceBytes(d *Decoder, dst reflect.Value, l, elemSize int, order binary.ByteOrder) { + if l == 0 { + return + } + need := l * elemSize + src := d.data[d.pos : d.pos+need] + base := unsafe.Pointer(dst.Index(0).UnsafeAddr()) + + if isHostLittleEndian && order == binary.LittleEndian { + // Single memcpy into the destination's backing array. + dstBytes := unsafe.Slice((*byte)(base), need) + copy(dstBytes, src) + d.pos += need + return + } + + // Host is BE, or caller requested BE: decode per element but still write + // directly into the destination's memory to avoid the intermediate slice. + switch elemSize { + case 2: + for i := range l { + *(*uint16)(unsafe.Add(base, i*2)) = order.Uint16(src[i*2:]) + } + case 4: + for i := range l { + *(*uint32)(unsafe.Add(base, i*4)) = order.Uint32(src[i*4:]) + } + case 8: + for i := range l { + *(*uint64)(unsafe.Add(base, i*8)) = order.Uint64(src[i*8:]) + } + } + d.pos += need +} + +func reflect_readArrayOfUint16(d *Decoder, l int, rv reflect.Value, order binary.ByteOrder) error { + need := l * 2 + if need > d.Remaining() { + return io.ErrUnexpectedEOF + } + dst, err := podSliceReadTarget(rv, l) + if err != nil { + return err + } + readPoDSliceBytes(d, dst, l, 2, order) + if rv.Kind() == reflect.Slice { + rv.Set(dst) + } + return nil +} + +func reflect_readArrayOfUint32(d *Decoder, l int, rv reflect.Value, order binary.ByteOrder) error { + need := l * 4 + if need > d.Remaining() { + return io.ErrUnexpectedEOF + } + dst, err := podSliceReadTarget(rv, l) + if err != nil { + return err + } + readPoDSliceBytes(d, dst, l, 4, order) + if rv.Kind() == reflect.Slice { + rv.Set(dst) + } + return nil +} + +func init() { + if typeOfByte != typeOfUint8 { + panic("typeOfByte != typeOfUint8") + } +} + +var ( + typeOfByte = reflect.TypeOf(byte(0)) + typeOfUint8 = reflect.TypeOf(uint8(0)) + typeOfUint16 = reflect.TypeOf(uint16(0)) + typeOfUint32 = reflect.TypeOf(uint32(0)) + typeOfUint64 = reflect.TypeOf(uint64(0)) +) + +func reflect_readArrayOfUint64(d *Decoder, l int, rv reflect.Value, order binary.ByteOrder) error { + need := l * 8 + if need > d.Remaining() { + return io.ErrUnexpectedEOF + } + dst, err := podSliceReadTarget(rv, l) + if err != nil { + return err + } + readPoDSliceBytes(d, dst, l, 8, order) + if rv.Kind() == reflect.Slice { + rv.Set(dst) + } + return nil +} + +// reflect_readArrayOfUint_ is used for reading arrays/slices of uints of any size. +func reflect_readArrayOfUint_(d *Decoder, l int, k reflect.Kind, rv reflect.Value, order binary.ByteOrder) error { + // uint64 arithmetic so `l * elemSize` can't wrap int on 32-bit hosts. + var elemSize uint64 + switch k { + case reflect.Uint8: + elemSize = 1 + case reflect.Uint16: + elemSize = 2 + case reflect.Uint32: + elemSize = 4 + case reflect.Uint64: + elemSize = 8 + default: + return fmt.Errorf("unsupported kind: %v", k) + } + if uint64(l) > uint64(d.Remaining())/elemSize { + return io.ErrUnexpectedEOF + } + switch k { + case reflect.Uint8: + return reflect_readArrayOfBytes(d, l, rv) + case reflect.Uint16: + return reflect_readArrayOfUint16(d, l, rv, order) + case reflect.Uint32: + return reflect_readArrayOfUint32(d, l, rv, order) + default: + return reflect_readArrayOfUint64(d, l, rv, order) + } +} diff --git a/binary/decoder_bench_test.go b/binary/decoder_bench_test.go new file mode 100644 index 000000000..5e816370c --- /dev/null +++ b/binary/decoder_bench_test.go @@ -0,0 +1,419 @@ +package bin + +import ( + "reflect" + "testing" +) + +func newUint64SliceEncoded(l int) []byte { + buf := make([]byte, 0) + for i := 0; i < l; i++ { + buf = append(buf, uint64ToBytes(uint64(i), LE)...) + } + return buf +} + +func Benchmark_uintSlice64_Decode_noMake(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var got []uint64 + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} +func Benchmark_uintSlice64_Decode_make(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + got := make([]uint64, 0) + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +func Benchmark_uintSlice64_Decode_field_noMake(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + type S struct { + Field []uint64 + } + for i := 0; i < b.N; i++ { + var got S + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got.Field) != l { + b.Errorf("got %d, want %d", len(got.Field), l) + } + } +} + +func Benchmark_uintSlice64_Decode_field_make(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + type S struct { + Field []uint64 + } + for i := 0; i < b.N; i++ { + var got S + got.Field = make([]uint64, 0) + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got.Field) != l { + b.Errorf("got %d, want %d", len(got.Field), l) + } + } +} + +func Benchmark_uintSlice64_readArray_noMake(b *testing.B) { + l := 1024 + buf := concatByteSlices( + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var got []uint64 + + decoder := NewBorshDecoder(buf) + rv := reflect.ValueOf(&got).Elem() + k := rv.Type().Elem().Kind() + + err := reflect_readArrayOfUint_(decoder, len(buf)/8, k, rv, LE) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +func Benchmark_uintSlice64_readArray_make(b *testing.B) { + l := 1024 + buf := concatByteSlices( + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + got := make([]uint64, 0) + + decoder := NewBorshDecoder(buf) + rv := reflect.ValueOf(&got).Elem() + k := rv.Type().Elem().Kind() + + err := reflect_readArrayOfUint_(decoder, len(buf)/8, k, rv, LE) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +type sliceUint64WithCustomDecoder []uint64 + +// UnmarshalWithDecoder +func (s *sliceUint64WithCustomDecoder) UnmarshalWithDecoder(decoder *Decoder) error { + // read length + l, err := decoder.ReadUint32(LE) + if err != nil { + return err + } + // read data + *s = make([]uint64, l) + for i := 0; i < int(l); i++ { + (*s)[i], err = decoder.ReadUint64(LE) + if err != nil { + return err + } + } + return nil +} + +func Benchmark_uintSlice64_Decode_field_withCustomDecoder(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint64SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var got sliceUint64WithCustomDecoder + + decoder := NewBorshDecoder(buf) + err := got.UnmarshalWithDecoder(decoder) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +func newUint32SliceEncoded(l int) []byte { + buf := make([]byte, 0) + for i := 0; i < l; i++ { + buf = append(buf, uint32ToBytes(uint32(i), LE)...) + } + return buf +} + +func Benchmark_uintSlice32_Decode_noMake(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var got []uint32 + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} +func Benchmark_uintSlice32_Decode_make(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + got := make([]uint32, 0) + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +func Benchmark_uintSlice32_Decode_field_noMake(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + type S struct { + Field []uint32 + } + for i := 0; i < b.N; i++ { + var got S + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got.Field) != l { + b.Errorf("got %d, want %d", len(got.Field), l) + } + } +} + +func Benchmark_uintSlice32_Decode_field_make(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + type S struct { + Field []uint32 + } + for i := 0; i < b.N; i++ { + var got S + got.Field = make([]uint32, 0) + + decoder := NewBorshDecoder(buf) + err := decoder.Decode(&got) + if err != nil { + b.Error(err) + } + if len(got.Field) != l { + b.Errorf("got %d, want %d", len(got.Field), l) + } + } +} + +func Benchmark_uintSlice32_readArray_noMake(b *testing.B) { + l := 1024 + buf := concatByteSlices( + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var got []uint32 + + decoder := NewBorshDecoder(buf) + rv := reflect.ValueOf(&got).Elem() + k := rv.Type().Elem().Kind() + + err := reflect_readArrayOfUint_(decoder, len(buf)/4, k, rv, LE) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +func Benchmark_uintSlice32_readArray_make(b *testing.B) { + l := 1024 + buf := concatByteSlices( + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + got := make([]uint32, 0) + + decoder := NewBorshDecoder(buf) + rv := reflect.ValueOf(&got).Elem() + k := rv.Type().Elem().Kind() + + err := reflect_readArrayOfUint_(decoder, len(buf)/4, k, rv, LE) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} + +type sliceUint32WithCustomDecoder []uint32 + +// UnmarshalWithDecoder +func (s *sliceUint32WithCustomDecoder) UnmarshalWithDecoder(decoder *Decoder) error { + // read length + l, err := decoder.ReadUint32(LE) + if err != nil { + return err + } + // read data + *s = make([]uint32, l) + for i := 0; i < int(l); i++ { + (*s)[i], err = decoder.ReadUint32(LE) + if err != nil { + return err + } + } + return nil +} +func Benchmark_uintSlice32_Decode_field_withCustomDecoder(b *testing.B) { + l := 1024 + buf := concatByteSlices( + // length: + uint32ToBytes(uint32(l), LE), + // data: + newUint32SliceEncoded(l), + ) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var got sliceUint32WithCustomDecoder + + decoder := NewBorshDecoder(buf) + err := got.UnmarshalWithDecoder(decoder) + if err != nil { + b.Error(err) + } + if len(got) != l { + b.Errorf("got %d, want %d", len(got), l) + } + } +} diff --git a/binary/decoder_bin.go b/binary/decoder_bin.go new file mode 100644 index 000000000..9d237d03c --- /dev/null +++ b/binary/decoder_bin.go @@ -0,0 +1,364 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "encoding/binary" + "fmt" + "io" + "reflect" + + "go.uber.org/zap" +) + +func (dec *Decoder) decodeWithOptionBin(v interface{}, opt option) (err error) { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return &InvalidDecoderError{reflect.TypeOf(v)} + } + + // We decode rv not rv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + err = dec.decodeBin(rv, opt) + if err != nil { + return err + } + return nil +} + +func (dec *Decoder) decodeBin(rv reflect.Value, opt option) (err error) { + if opt.Order == nil { + opt.Order = defaultByteOrder + } + dec.currentFieldOpt = opt + + unmarshaler, rv := indirect(rv, opt.is_Optional()) + + if traceEnabled { + zlog.Debug("decode: type", + zap.Stringer("value_kind", rv.Kind()), + zap.Bool("has_unmarshaler", (unmarshaler != nil)), + zap.Reflect("options", opt), + ) + } + + if opt.is_Optional() { + isPresent, e := dec.ReadUint32(binary.LittleEndian) + if e != nil { + err = fmt.Errorf("decode: %s isPresent, %s", rv.Type().String(), e) + return + } + + if isPresent == 0 { + if traceEnabled { + zlog.Debug("decode: skipping optional value", zap.Stringer("type", rv.Kind())) + } + + rv.Set(reflect.Zero(rv.Type())) + return + } + + // we have ptr here we should not go get the element + unmarshaler, rv = indirect(rv, false) + } + + if unmarshaler != nil { + if traceEnabled { + zlog.Debug("decode: using UnmarshalWithDecoder method to decode type") + } + return unmarshaler.UnmarshalWithDecoder(dec) + } + rt := rv.Type() + + switch rv.Kind() { + case reflect.String: + s, e := dec.ReadRustString() + if e != nil { + err = e + return + } + rv.SetString(s) + return + case reflect.Uint8: + var n byte + n, err = dec.ReadByte() + rv.SetUint(uint64(n)) + return + case reflect.Int8: + var n int8 + n, err = dec.ReadInt8() + rv.SetInt(int64(n)) + return + case reflect.Int16: + var n int16 + n, err = dec.ReadInt16(opt.Order) + rv.SetInt(int64(n)) + return + case reflect.Int32: + var n int32 + n, err = dec.ReadInt32(opt.Order) + rv.SetInt(int64(n)) + return + case reflect.Int64: + var n int64 + n, err = dec.ReadInt64(opt.Order) + rv.SetInt(int64(n)) + return + case reflect.Uint16: + var n uint16 + n, err = dec.ReadUint16(opt.Order) + rv.SetUint(uint64(n)) + return + case reflect.Uint32: + var n uint32 + n, err = dec.ReadUint32(opt.Order) + rv.SetUint(uint64(n)) + return + case reflect.Uint64: + var n uint64 + n, err = dec.ReadUint64(opt.Order) + rv.SetUint(n) + return + case reflect.Float32: + var n float32 + n, err = dec.ReadFloat32(opt.Order) + rv.SetFloat(float64(n)) + return + case reflect.Float64: + var n float64 + n, err = dec.ReadFloat64(opt.Order) + rv.SetFloat(n) + return + case reflect.Bool: + var r bool + r, err = dec.ReadBool() + rv.SetBool(r) + return + case reflect.Interface: + // skip + return nil + } + switch rt.Kind() { + case reflect.Array: + l := rt.Len() + if traceEnabled { + zlog.Debug("decoding: reading array", zap.Int("length", l)) + } + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := reflect_readArrayOfUint_(dec, l, k, rv, LE); err != nil { + return err + } + default: + for i := range l { + if err = dec.decodeBin(rv.Index(i), opt); err != nil { + return + } + } + } + return + case reflect.Slice: + var l int + if opt.hasSizeOfSlice() { + l = opt.getSizeOfSlice() + } else { + length, err := dec.ReadLength() + if err != nil { + return err + } + l = length + } + + if traceEnabled { + zlog.Debug("reading slice", zap.Int("len", l), typeField("type", rv)) + } + + if l > dec.Remaining() { + return io.ErrUnexpectedEOF + } + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := reflect_readArrayOfUint_(dec, l, k, rv, LE); err != nil { + return err + } + default: + // Pre-size the slice once instead of growing it via reflect.Append + // in a loop. Decoding into slc.Index(i).Addr() also avoids the + // per-element reflect.New heap allocation. + slc := reflect.MakeSlice(rt, l, l) + elOpt := option{Order: opt.Order} + for i := range l { + if err = dec.decodeBin(slc.Index(i).Addr(), elOpt); err != nil { + return + } + } + rv.Set(slc) + } + + case reflect.Struct: + if err = dec.decodeStructBin(rt, rv); err != nil { + return + } + + case reflect.Map: + l, err := dec.ReadLength() + if err != nil { + return err + } + if l == 0 { + // If the map has no content, keep it nil. + return nil + } + rv.Set(reflect.MakeMap(rt)) + mapOpt := option{Order: opt.Order} + for i := 0; i < int(l); i++ { + key := reflect.New(rt.Key()) + err := dec.decodeBin(key.Elem(), mapOpt) + if err != nil { + return err + } + val := reflect.New(rt.Elem()) + err = dec.decodeBin(val.Elem(), mapOpt) + if err != nil { + return err + } + rv.SetMapIndex(key.Elem(), val.Elem()) + } + return nil + + default: + return fmt.Errorf("decode: unsupported type %q", rt) + } + + return +} + +func (dec *Decoder) decodeStructBin(rt reflect.Type, rv reflect.Value) (err error) { + plan := planForStruct(rt) + + if traceEnabled { + zlog.Debug("decode: struct", zap.Int("fields", len(plan.fields)), zap.Stringer("type", rv.Kind())) + } + + var sizes []int + if plan.hasSizeOf { + var stack sizesScratch + if len(plan.fields) <= sizesScratchLen { + sizes = stack[:len(plan.fields)] + } else { + sizes = make([]int, len(plan.fields)) + } + for i := range sizes { + sizes[i] = -1 + } + } + + seenBinaryExtensionField := false + for i := range plan.fields { + fp := &plan.fields[i] + + if fp.skip { + if traceEnabled { + zlog.Debug("decode: skipping struct field with skip flag", + zap.String("struct_field_name", fp.name), + ) + } + continue + } + + if !fp.binaryExtension && seenBinaryExtensionField { + panic(fmt.Sprintf("the `bin:\"binary_extension\"` tags must be packed together at the end of struct fields, problematic field %q", fp.name)) + } + + if fp.binaryExtension { + seenBinaryExtensionField = true + if len(dec.data[dec.pos:]) <= 0 { + continue + } + } + + // Fast primitive path: no option construction, no kind switch. + if fp.binFastDecode != nil { + if err = fp.binFastDecode(dec, rv.Field(i)); err != nil { + return fmt.Errorf("error while decoding %q field: %w", fp.name, err) + } + continue + } + + v := rv.Field(i) + if !v.CanSet() { + if !v.CanAddr() { + if traceEnabled { + zlog.Debug("skipping struct field that cannot be addressed", + zap.String("struct_field_name", fp.name), + zap.Stringer("struct_value_type", v.Kind()), + ) + } + return fmt.Errorf("unable to decode a none setup struc field %q with type %q", fp.name, v.Kind()) + } + v = v.Addr() + } + + if !v.CanSet() { + if traceEnabled { + zlog.Debug("skipping struct field that cannot be addressed", + zap.String("struct_field_name", fp.name), + zap.Stringer("struct_value_type", v.Kind()), + ) + } + continue + } + + opt := option{ + is_OptionalField: fp.tag.Option, + Order: fp.tag.Order, + } + + if sizes != nil && fp.sizeFromIdx >= 0 && sizes[i] >= 0 { + opt.sliceSizeIsSet = true + opt.sliceSize = sizes[i] + } + + if traceEnabled { + zlog.Debug("decode: struct field", + zap.Stringer("struct_field_value_type", v.Kind()), + zap.String("struct_field_name", fp.name), + zap.Reflect("struct_field_tags", fp.tag), + zap.Reflect("struct_field_option", opt), + ) + } + + if err = dec.decodeBin(v, opt); err != nil { + return fmt.Errorf("error while decoding %q field: %w", fp.name, err) + } + + if fp.sizeOfTargetIdx >= 0 && sizes != nil { + size := sizeof(fp.fieldType, v) + if traceEnabled { + zlog.Debug("setting size of field", + zap.String("field_name", plan.fields[fp.sizeOfTargetIdx].name), + zap.Int("size", size), + ) + } + sizes[fp.sizeOfTargetIdx] = size + } + } + return +} diff --git a/binary/decoder_borsh.go b/binary/decoder_borsh.go new file mode 100644 index 000000000..bf8afe4bd --- /dev/null +++ b/binary/decoder_borsh.go @@ -0,0 +1,464 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "errors" + "fmt" + "io" + "reflect" + + "go.uber.org/zap" +) + +func (dec *Decoder) decodeWithOptionBorsh(v interface{}, opt option) (err error) { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return &InvalidDecoderError{reflect.TypeOf(v)} + } + + // We decode rv not rv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + err = dec.decodeBorsh(rv, opt) + if err != nil { + return err + } + return nil +} + +func (dec *Decoder) decodeBorsh(rv reflect.Value, opt option) (err error) { + if opt.Order == nil { + opt.Order = defaultByteOrder + } + dec.currentFieldOpt = opt + + unmarshaler, rv := indirect(rv, opt.is_Optional() || opt.is_COptional()) + + if traceEnabled { + zlog.Debug("decode: type", + zap.Stringer("value_kind", rv.Kind()), + zap.Bool("has_unmarshaler", (unmarshaler != nil)), + zap.Reflect("options", opt), + ) + } + + if opt.is_Optional() { + isPresent, e := dec.ReadOption() + if e != nil { + err = fmt.Errorf("decode: %s isPresent: %w", rv.Type(), e) + return + } + + if !isPresent { + if traceEnabled { + zlog.Debug("decode: skipping optional value", zap.Stringer("type", rv.Kind())) + } + + rv.Set(reflect.Zero(rv.Type())) + return + } + + // we have ptr here we should not go get the element + unmarshaler, rv = indirect(rv, false) + } + if opt.is_COptional() { + isPresent, e := dec.ReadCOption() + if e != nil { + err = fmt.Errorf("decode: %s isPresent: %w", rv.Type(), e) + return + } + + if !isPresent { + if traceEnabled { + zlog.Debug("decode: skipping optional value", zap.Stringer("type", rv.Kind())) + } + + rv.Set(reflect.Zero(rv.Type())) + return + } + + // we have ptr here we should not go get the element + unmarshaler, rv = indirect(rv, false) + } + // Reset optionality so it won't propagate to child types. opt is a value + // copy so we mutate it locally without affecting the caller. + opt.is_OptionalField = false + opt.is_COptionalField = false + + if unmarshaler != nil { + if traceEnabled { + zlog.Debug("decode: using UnmarshalWithDecoder method to decode type") + } + return unmarshaler.UnmarshalWithDecoder(dec) + } + + rt := rv.Type() + switch rv.Kind() { + // case reflect.Int: + // // TODO: check if is x32 or x64 + // var n int64 + // n, err = dec.ReadInt64(LE) + // rv.SetInt(n) + // return + // case reflect.Uint: + // // TODO: check if is x32 or x64 + // var n uint64 + // n, err = dec.ReadUint64(LE) + // rv.SetUint(n) + // return + case reflect.String: + s, e := dec.ReadString() + if e != nil { + err = e + return + } + rv.SetString(s) + return + case reflect.Uint8: + var n byte + n, err = dec.ReadByte() + rv.SetUint(uint64(n)) + return + case reflect.Int8: + var n int8 + n, err = dec.ReadInt8() + rv.SetInt(int64(n)) + return + case reflect.Int16: + var n int16 + n, err = dec.ReadInt16(LE) + rv.SetInt(int64(n)) + return + case reflect.Int32: + var n int32 + n, err = dec.ReadInt32(LE) + rv.SetInt(int64(n)) + return + case reflect.Int64: + var n int64 + n, err = dec.ReadInt64(LE) + rv.SetInt(int64(n)) + return + case reflect.Uint16: + var n uint16 + n, err = dec.ReadUint16(LE) + rv.SetUint(uint64(n)) + return + case reflect.Uint32: + var n uint32 + n, err = dec.ReadUint32(LE) + rv.SetUint(uint64(n)) + return + case reflect.Uint64: + var n uint64 + n, err = dec.ReadUint64(LE) + rv.SetUint(n) + return + case reflect.Float32: + var n float32 + n, err = dec.ReadFloat32(LE) + rv.SetFloat(float64(n)) + return + case reflect.Float64: + var n float64 + n, err = dec.ReadFloat64(LE) + rv.SetFloat(n) + return + case reflect.Bool: + var r bool + r, err = dec.ReadBool() + rv.SetBool(r) + return + case reflect.Interface: + // Skip: cannot know the concrete type of the interface. + // The parent container should implement a custom decoder. + return nil + // TODO: handle reflect.Ptr ??? + } + switch rt.Kind() { + case reflect.Array: + l := rt.Len() + if traceEnabled { + zlog.Debug("decoding: reading array", zap.Int("length", l)) + } + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := reflect_readArrayOfUint_(dec, l, k, rv, LE); err != nil { + return err + } + default: + for i := range l { + if err = dec.decodeBorsh(rv.Index(i), opt); err != nil { + return + } + } + } + return + case reflect.Slice: + var l int + if opt.hasSizeOfSlice() { + l = opt.getSizeOfSlice() + } else { + length, err := dec.ReadUint32(LE) + if err != nil { + return err + } + l = int(length) + } + + if traceEnabled { + zlog.Debug("reading slice", zap.Int("len", l), typeField("type", rv)) + } + + if l == 0 { + // Empty slices are left nil + return + } + if l > dec.Remaining() { + return io.ErrUnexpectedEOF + } + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := reflect_readArrayOfUint_(dec, l, k, rv, LE); err != nil { + return err + } + default: + // Pre-size and decode in place; avoids per-element reflect.New + // and the O(log n) reallocs that reflect.Append would trigger. + slc := reflect.MakeSlice(rt, l, l) + elOpt := option{Order: opt.Order} + for i := range l { + if err = dec.decodeBorsh(slc.Index(i).Addr(), elOpt); err != nil { + return + } + } + rv.Set(slc) + } + + case reflect.Struct: + if err = dec.decodeStructBorsh(rt, rv); err != nil { + return + } + + case reflect.Map: + l, err := dec.ReadUint32(LE) + if err != nil { + return err + } + if l == 0 { + // If the map has no content, keep it nil. + return nil + } + rv.Set(reflect.MakeMap(rt)) + mapOpt := option{Order: opt.Order} + for i := 0; i < int(l); i++ { + key := reflect.New(rt.Key()) + err := dec.decodeBorsh(key.Elem(), mapOpt) + if err != nil { + return err + } + val := reflect.New(rt.Elem()) + err = dec.decodeBorsh(val.Elem(), mapOpt) + if err != nil { + return err + } + rv.SetMapIndex(key.Elem(), val.Elem()) + } + return nil + + default: + return fmt.Errorf("decode: unsupported type %q", rt) + } + + return +} + +func (dec *Decoder) deserializeComplexEnum(rv reflect.Value) error { + rt := rv.Type() + // read enum identifier + tmp, err := dec.ReadUint8() + if err != nil { + return err + } + enum := BorshEnum(tmp) + rv.Field(0).Set(reflect.ValueOf(enum).Convert(rv.Field(0).Type())) + + // read enum field, if necessary + if int(enum)+1 >= rt.NumField() { + return errors.New("complex enum too large") + } + field := rv.Field(int(enum) + 1) + return dec.decodeBorsh(field, defaultOption) +} + +var borshEnumType = reflect.TypeOf(BorshEnum(0)) + +func isTypeBorshEnum(typ reflect.Type) bool { + return typ.Kind() == reflect.Uint8 && typ == borshEnumType +} + +func (dec *Decoder) decodeStructBorsh(rt reflect.Type, rv reflect.Value) (err error) { + plan := planForStruct(rt) + + if traceEnabled { + zlog.Debug("decode: struct", zap.Int("fields", len(plan.fields)), zap.Stringer("type", rv.Kind())) + } + + if plan.isComplexEnum { + return dec.deserializeComplexEnum(rv) + } + + // sizes is non-nil only when the struct actually has sizeof wiring; + // allocated on the stack via sizesScratch for small structs. + var sizes []int + if plan.hasSizeOf { + var stack sizesScratch + if len(plan.fields) <= sizesScratchLen { + sizes = stack[:len(plan.fields)] + } else { + sizes = make([]int, len(plan.fields)) + } + for i := range sizes { + sizes[i] = -1 + } + } + + seenBinaryExtensionField := false + for i := range plan.fields { + fp := &plan.fields[i] + + if fp.skip { + if traceEnabled { + zlog.Debug("decode: skipping struct field with skip flag", + zap.String("struct_field_name", fp.name), + ) + } + continue + } + + if !fp.binaryExtension && seenBinaryExtensionField { + panic(fmt.Sprintf("the `bin:\"binary_extension\"` tags must be packed together at the end of struct fields, problematic field %q", fp.name)) + } + + if fp.binaryExtension { + seenBinaryExtensionField = true + // FIXME: This works only if what is in `d.data` is the actual full data buffer that + // needs to be decoded. If there is for example two structs in the buffer, this + // will not work as we would continue into the next struct. + if len(dec.data[dec.pos:]) <= 0 { + continue + } + } + + // Fast primitive path: no option construction, no kind switch, just an + // inline read straight into the field's memory. + if fp.borshFastDecode != nil { + if err = fp.borshFastDecode(dec, rv.Field(i)); err != nil { + return fmt.Errorf("error while decoding %q field: %w", fp.name, err) + } + continue + } + + v := rv.Field(i) + if !v.CanSet() { + if !v.CanAddr() { + if traceEnabled { + zlog.Debug("skipping struct field that cannot be addressed", + zap.String("struct_field_name", fp.name), + zap.Stringer("struct_value_type", v.Kind()), + ) + } + return fmt.Errorf("unable to decode a none setup struc field %q with type %q", fp.name, v.Kind()) + } + v = v.Addr() + } + + if !v.CanSet() { + if traceEnabled { + zlog.Debug("skipping struct field that cannot be addressed", + zap.String("struct_field_name", fp.name), + zap.Stringer("struct_value_type", v.Kind()), + ) + } + continue + } + + opt := option{ + is_OptionalField: fp.tag.Option, + is_COptionalField: fp.tag.COption, + Order: fp.tag.Order, + } + + if sizes != nil && fp.sizeFromIdx >= 0 && sizes[i] >= 0 { + opt.sliceSizeIsSet = true + opt.sliceSize = sizes[i] + } + + if traceEnabled { + zlog.Debug("decode: struct field", + zap.Stringer("struct_field_value_type", v.Kind()), + zap.String("struct_field_name", fp.name), + zap.Reflect("struct_field_tags", fp.tag), + zap.Reflect("struct_field_option", opt), + ) + } + + if fp.ptrImplementsUnmarshaler || fp.valImplementsUnmarshaler { + ft := fp.fieldType + switch { + case fp.ptrImplementsUnmarshaler: + m := reflect.New(ft) + val := m.Interface() + if err := val.(BinaryUnmarshaler).UnmarshalWithDecoder(dec); err != nil { + return err + } + v.Set(reflect.ValueOf(val).Elem()) + case fp.valImplementsUnmarshaler: + m := reflect.New(ft.Elem()) + val := m.Interface() + if err := val.(BinaryUnmarshaler).UnmarshalWithDecoder(dec); err != nil { + return err + } + v.Set(reflect.ValueOf(val)) + } + } else { + if err = dec.decodeBorsh(v, opt); err != nil { + return fmt.Errorf("error while decoding %q field: %w", fp.name, err) + } + } + + if fp.sizeOfTargetIdx >= 0 && sizes != nil { + size := sizeof(fp.fieldType, v) + if traceEnabled { + zlog.Debug("setting size of field", + zap.String("field_name", plan.fields[fp.sizeOfTargetIdx].name), + zap.Int("size", size), + ) + } + sizes[fp.sizeOfTargetIdx] = size + } + } + return +} + +var ( + marshalableType = reflect.TypeOf((*BinaryMarshaler)(nil)).Elem() + unmarshalableType = reflect.TypeOf((*BinaryUnmarshaler)(nil)).Elem() +) diff --git a/binary/decoder_compact-u16.go b/binary/decoder_compact-u16.go new file mode 100644 index 000000000..f3f94896c --- /dev/null +++ b/binary/decoder_compact-u16.go @@ -0,0 +1,360 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "fmt" + "io" + "reflect" + + "go.uber.org/zap" +) + +func (dec *Decoder) decodeWithOptionCompactU16(v interface{}, opt option) (err error) { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return &InvalidDecoderError{reflect.TypeOf(v)} + } + + // We decode rv not rv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + err = dec.decodeCompactU16(rv, opt) + if err != nil { + return err + } + return nil +} + +func (dec *Decoder) decodeCompactU16(rv reflect.Value, opt option) (err error) { + if opt.Order == nil { + opt.Order = defaultByteOrder + } + dec.currentFieldOpt = opt + + unmarshaler, rv := indirect(rv, opt.is_Optional()) + + if traceEnabled { + zlog.Debug("decode: type", + zap.Stringer("value_kind", rv.Kind()), + zap.Bool("has_unmarshaler", (unmarshaler != nil)), + zap.Reflect("options", opt), + ) + } + + if opt.is_Optional() { + isPresent, e := dec.ReadByte() + if e != nil { + err = fmt.Errorf("decode: %t isPresent, %s", rv.Type(), e) + return + } + + if isPresent == 0 { + if traceEnabled { + zlog.Debug("decode: skipping optional value", zap.Stringer("type", rv.Kind())) + } + + rv.Set(reflect.Zero(rv.Type())) + return + } + + // we have ptr here we should not go get the element + unmarshaler, rv = indirect(rv, false) + } + + if unmarshaler != nil { + if traceEnabled { + zlog.Debug("decode: using UnmarshalWithDecoder method to decode type") + } + return unmarshaler.UnmarshalWithDecoder(dec) + } + rt := rv.Type() + + switch rv.Kind() { + case reflect.String: + s, e := dec.ReadString() + if e != nil { + err = e + return + } + rv.SetString(s) + return + case reflect.Uint8: + var n byte + n, err = dec.ReadByte() + rv.SetUint(uint64(n)) + return + case reflect.Int8: + var n int8 + n, err = dec.ReadInt8() + rv.SetInt(int64(n)) + return + case reflect.Int16: + var n int16 + n, err = dec.ReadInt16(opt.Order) + rv.SetInt(int64(n)) + return + case reflect.Int32: + var n int32 + n, err = dec.ReadInt32(opt.Order) + rv.SetInt(int64(n)) + return + case reflect.Int64: + var n int64 + n, err = dec.ReadInt64(opt.Order) + rv.SetInt(int64(n)) + return + case reflect.Uint16: + var n uint16 + n, err = dec.ReadUint16(opt.Order) + rv.SetUint(uint64(n)) + return + case reflect.Uint32: + var n uint32 + n, err = dec.ReadUint32(opt.Order) + rv.SetUint(uint64(n)) + return + case reflect.Uint64: + var n uint64 + n, err = dec.ReadUint64(opt.Order) + rv.SetUint(n) + return + case reflect.Float32: + var n float32 + n, err = dec.ReadFloat32(opt.Order) + rv.SetFloat(float64(n)) + return + case reflect.Float64: + var n float64 + n, err = dec.ReadFloat64(opt.Order) + rv.SetFloat(n) + return + case reflect.Bool: + var r bool + r, err = dec.ReadBool() + rv.SetBool(r) + return + case reflect.Interface: + // skip + return nil + } + switch rt.Kind() { + case reflect.Array: + l := rt.Len() + if traceEnabled { + zlog.Debug("decoding: reading array", zap.Int("length", l)) + } + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := reflect_readArrayOfUint_(dec, l, k, rv, LE); err != nil { + return err + } + default: + for i := range l { + if err = dec.decodeCompactU16(rv.Index(i), opt); err != nil { + return + } + } + } + return + case reflect.Slice: + var l int + if opt.hasSizeOfSlice() { + l = opt.getSizeOfSlice() + } else { + length, err := dec.ReadCompactU16Length() + if err != nil { + return err + } + l = int(length) + } + + if traceEnabled { + zlog.Debug("reading slice", zap.Int("len", l), typeField("type", rv)) + } + + if l > dec.Remaining() { + return io.ErrUnexpectedEOF + } + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := reflect_readArrayOfUint_(dec, l, k, rv, LE); err != nil { + return err + } + default: + slc := reflect.MakeSlice(rt, l, l) + elOpt := option{Order: opt.Order} + for i := range l { + if err = dec.decodeCompactU16(slc.Index(i).Addr(), elOpt); err != nil { + return + } + } + rv.Set(slc) + } + + case reflect.Struct: + if err = dec.decodeStructCompactU16(rt, rv); err != nil { + return + } + + case reflect.Map: + l, err := dec.ReadCompactU16Length() + if err != nil { + return err + } + if l == 0 { + // If the map has no content, keep it nil. + return nil + } + rv.Set(reflect.MakeMap(rt)) + mapOpt := option{Order: opt.Order} + for i := 0; i < int(l); i++ { + key := reflect.New(rt.Key()) + err := dec.decodeCompactU16(key.Elem(), mapOpt) + if err != nil { + return err + } + val := reflect.New(rt.Elem()) + err = dec.decodeCompactU16(val.Elem(), mapOpt) + if err != nil { + return err + } + rv.SetMapIndex(key.Elem(), val.Elem()) + } + return nil + + default: + return fmt.Errorf("decode: unsupported type %q", rt) + } + + return +} + +func (dec *Decoder) decodeStructCompactU16(rt reflect.Type, rv reflect.Value) (err error) { + plan := planForStruct(rt) + + if traceEnabled { + zlog.Debug("decode: struct", zap.Int("fields", len(plan.fields)), zap.Stringer("type", rv.Kind())) + } + + var sizes []int + if plan.hasSizeOf { + var stack sizesScratch + if len(plan.fields) <= sizesScratchLen { + sizes = stack[:len(plan.fields)] + } else { + sizes = make([]int, len(plan.fields)) + } + for i := range sizes { + sizes[i] = -1 + } + } + + seenBinaryExtensionField := false + for i := range plan.fields { + fp := &plan.fields[i] + + if fp.skip { + if traceEnabled { + zlog.Debug("decode: skipping struct field with skip flag", + zap.String("struct_field_name", fp.name), + ) + } + continue + } + + if !fp.binaryExtension && seenBinaryExtensionField { + panic(fmt.Sprintf("the `bin:\"binary_extension\"` tags must be packed together at the end of struct fields, problematic field %q", fp.name)) + } + + if fp.binaryExtension { + seenBinaryExtensionField = true + if len(dec.data[dec.pos:]) <= 0 { + continue + } + } + + // Fast primitive path: no option construction, no kind switch. + if fp.binFastDecode != nil { + if err = fp.binFastDecode(dec, rv.Field(i)); err != nil { + return fmt.Errorf("error while decoding %q field: %w", fp.name, err) + } + continue + } + + v := rv.Field(i) + if !v.CanSet() { + if !v.CanAddr() { + if traceEnabled { + zlog.Debug("skipping struct field that cannot be addressed", + zap.String("struct_field_name", fp.name), + zap.Stringer("struct_value_type", v.Kind()), + ) + } + return fmt.Errorf("unable to decode a none setup struc field %q with type %q", fp.name, v.Kind()) + } + v = v.Addr() + } + + if !v.CanSet() { + if traceEnabled { + zlog.Debug("skipping struct field that cannot be addressed", + zap.String("struct_field_name", fp.name), + zap.Stringer("struct_value_type", v.Kind()), + ) + } + continue + } + + opt := option{ + is_OptionalField: fp.tag.Option, + Order: fp.tag.Order, + } + + if sizes != nil && fp.sizeFromIdx >= 0 && sizes[i] >= 0 { + opt.sliceSizeIsSet = true + opt.sliceSize = sizes[i] + } + + if traceEnabled { + zlog.Debug("decode: struct field", + zap.Stringer("struct_field_value_type", v.Kind()), + zap.String("struct_field_name", fp.name), + zap.Reflect("struct_field_tags", fp.tag), + zap.Reflect("struct_field_option", opt), + ) + } + + if err = dec.decodeCompactU16(v, opt); err != nil { + return fmt.Errorf("error while decoding %q field: %w", fp.name, err) + } + + if fp.sizeOfTargetIdx >= 0 && sizes != nil { + size := sizeof(fp.fieldType, v) + if traceEnabled { + zlog.Debug("setting size of field", + zap.String("field_name", plan.fields[fp.sizeOfTargetIdx].name), + zap.Int("size", size), + ) + } + sizes[fp.sizeOfTargetIdx] = size + } + } + return +} diff --git a/binary/decoder_test.go b/binary/decoder_test.go new file mode 100644 index 000000000..7f01fa0d5 --- /dev/null +++ b/binary/decoder_test.go @@ -0,0 +1,1431 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "encoding/binary" + "encoding/hex" + "math" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecoder_Peek(t *testing.T) { + buf := []byte{ + 0x17, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0, 0x0, + } + + dec := NewBinDecoder(buf) + { + peeked, err := dec.Peek(8) + assert.NoError(t, err) + assert.Len(t, peeked, 8) + assert.Equal(t, buf, peeked) + } + { + peeked, err := dec.Peek(8) + assert.NoError(t, err) + assert.Len(t, peeked, 8) + assert.Equal(t, buf, peeked) + } + { + peeked, err := dec.Peek(1) + assert.NoError(t, err) + assert.Len(t, peeked, 1) + assert.Equal(t, buf[0], peeked[0]) + } + { + peeked, err := dec.Peek(2) + assert.NoError(t, err) + assert.Len(t, peeked, 2) + assert.Equal(t, buf[:2], peeked) + } + { + read, err := dec.ReadByte() + assert.Equal(t, buf[0], read) + assert.NoError(t, err) + + peeked, err := dec.Peek(1) + assert.NoError(t, err) + assert.Len(t, peeked, 1) + assert.Equal(t, buf[1], peeked[0]) + } +} + +func TestDecoder_AliastTestType(t *testing.T) { + buf := []byte{ + 0x17, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0, 0x0, + } + + var s aliasTestType + err := NewBinDecoder(buf).Decode(&s) + assert.NoError(t, err) + assert.Equal(t, uint64(23), uint64(s)) +} + +func TestDecoder_Remaining(t *testing.T) { + b := make([]byte, 4) + binary.LittleEndian.PutUint16(b, 1) + binary.LittleEndian.PutUint16(b[2:], 2) + + d := NewBinDecoder(b) + + n, err := d.ReadUint16(LE) + assert.NoError(t, err) + assert.Equal(t, uint16(1), n) + assert.Equal(t, 2, d.Remaining()) + + n, err = d.ReadUint16(LE) + assert.NoError(t, err) + assert.Equal(t, uint16(2), n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_int8(t *testing.T) { + buf := []byte{ + 0x9d, // -99 + 0x64, // 100 + } + + d := NewBinDecoder(buf) + + n, err := d.ReadInt8() + assert.NoError(t, err) + assert.Equal(t, int8(-99), n) + assert.Equal(t, 1, d.Remaining()) + + n, err = d.ReadInt8() + assert.NoError(t, err) + assert.Equal(t, int8(100), n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_int16(t *testing.T) { + // little endian + buf := []byte{ + 0xae, 0xff, // -82 + 0x49, 0x00, // 73 + } + + d := NewBinDecoder(buf) + + n, err := d.ReadInt16(LE) + assert.NoError(t, err) + assert.Equal(t, int16(-82), n) + assert.Equal(t, 2, d.Remaining()) + + n, err = d.ReadInt16(LE) + assert.NoError(t, err) + assert.Equal(t, int16(73), n) + assert.Equal(t, 0, d.Remaining()) + + // big endian + buf = []byte{ + 0xff, 0xae, // -82 + 0x00, 0x49, // 73 + } + + d = NewBinDecoder(buf) + + n, err = d.ReadInt16(BE) + assert.NoError(t, err) + assert.Equal(t, int16(-82), n) + assert.Equal(t, 2, d.Remaining()) + + n, err = d.ReadInt16(BE) + assert.NoError(t, err) + assert.Equal(t, int16(73), n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_int32(t *testing.T) { + // little endian + buf := []byte{ + 0xd8, 0x8d, 0x8a, 0xef, // -276132392 + 0x4f, 0x9f, 0x3, 0x00, // 237391 + } + + d := NewBinDecoder(buf) + + n, err := d.ReadInt32(LE) + assert.NoError(t, err) + assert.Equal(t, int32(-276132392), n) + assert.Equal(t, 4, d.Remaining()) + + n, err = d.ReadInt32(LE) + assert.NoError(t, err) + assert.Equal(t, int32(237391), n) + assert.Equal(t, 0, d.Remaining()) + + // big endian + buf = []byte{ + 0xef, 0x8a, 0x8d, 0xd8, // -276132392 + 0x00, 0x3, 0x9f, 0x4f, // 237391 + } + + d = NewBinDecoder(buf) + + n, err = d.ReadInt32(BE) + assert.NoError(t, err) + assert.Equal(t, int32(-276132392), n) + assert.Equal(t, 4, d.Remaining()) + + n, err = d.ReadInt32(BE) + assert.NoError(t, err) + assert.Equal(t, int32(237391), n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_int64(t *testing.T) { + // little endian + buf := []byte{ + 0x91, 0x7d, 0xf3, 0xff, 0xff, 0xff, 0xff, 0xff, //-819823 + 0xe3, 0x1c, 0x1, 0x00, 0x00, 0x00, 0x00, 0x00, //72931 + } + + d := NewBinDecoder(buf) + + n, err := d.ReadInt64(LE) + assert.NoError(t, err) + assert.Equal(t, int64(-819823), n) + assert.Equal(t, 8, d.Remaining()) + + n, err = d.ReadInt64(LE) + assert.NoError(t, err) + assert.Equal(t, int64(72931), n) + assert.Equal(t, 0, d.Remaining()) + + // big endian + buf = []byte{ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xf3, 0x7d, 0x91, //-819823 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x1, 0x1c, 0xe3, //72931 + } + + d = NewBinDecoder(buf) + + n, err = d.ReadInt64(BE) + assert.NoError(t, err) + assert.Equal(t, int64(-819823), n) + assert.Equal(t, 8, d.Remaining()) + + n, err = d.ReadInt64(BE) + assert.NoError(t, err) + assert.Equal(t, int64(72931), n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_uint8(t *testing.T) { + buf := []byte{ + 0x63, // 99 + 0x64, // 100 + } + + d := NewBinDecoder(buf) + + n, err := d.ReadUint8() + assert.NoError(t, err) + assert.Equal(t, uint8(99), n) + assert.Equal(t, 1, d.Remaining()) + + n, err = d.ReadUint8() + assert.NoError(t, err) + assert.Equal(t, uint8(100), n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_uint16(t *testing.T) { + // little endian + buf := []byte{ + 0x52, 0x00, // 82 + 0x49, 0x00, // 73 + } + + d := NewBinDecoder(buf) + + n, err := d.ReadUint16(LE) + assert.NoError(t, err) + assert.Equal(t, uint16(82), n) + assert.Equal(t, 2, d.Remaining()) + + n, err = d.ReadUint16(LE) + assert.NoError(t, err) + assert.Equal(t, uint16(73), n) + assert.Equal(t, 0, d.Remaining()) + + // big endian + buf = []byte{ + 0x00, 0x52, // 82 + 0x00, 0x49, // 73 + } + + d = NewBinDecoder(buf) + + n, err = d.ReadUint16(BE) + assert.NoError(t, err) + assert.Equal(t, uint16(82), n) + assert.Equal(t, 2, d.Remaining()) + + n, err = d.ReadUint16(BE) + assert.NoError(t, err) + assert.Equal(t, uint16(73), n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_uint32(t *testing.T) { + // little endian + buf := []byte{ + 0x28, 0x72, 0x75, 0x10, // 276132392 as LE + 0x4f, 0x9f, 0x03, 0x00, // 237391 as LE + } + + d := NewBinDecoder(buf) + + n, err := d.ReadUint32(LE) + assert.NoError(t, err) + assert.Equal(t, uint32(276132392), n) + assert.Equal(t, 4, d.Remaining()) + + n, err = d.ReadUint32(LE) + assert.NoError(t, err) + assert.Equal(t, uint32(237391), n) + assert.Equal(t, 0, d.Remaining()) + + // big endian + buf = []byte{ + 0x10, 0x75, 0x72, 0x28, // 276132392 as LE + 0x00, 0x03, 0x9f, 0x4f, // 237391 as LE + } + + d = NewBinDecoder(buf) + + n, err = d.ReadUint32(BE) + assert.NoError(t, err) + assert.Equal(t, uint32(276132392), n) + assert.Equal(t, 4, d.Remaining()) + + n, err = d.ReadUint32(BE) + assert.NoError(t, err) + assert.Equal(t, uint32(237391), n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_uint64(t *testing.T) { + // little endian + buf := []byte{ + 0x6f, 0x82, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, //819823 + 0xe3, 0x1c, 0x1, 0x00, 0x00, 0x00, 0x00, 0x00, //72931 + } + + d := NewBinDecoder(buf) + + n, err := d.ReadUint64(LE) + assert.NoError(t, err) + assert.Equal(t, uint64(819823), n) + assert.Equal(t, 8, d.Remaining()) + + n, err = d.ReadUint64(LE) + assert.NoError(t, err) + assert.Equal(t, uint64(72931), n) + assert.Equal(t, 0, d.Remaining()) + + // big endian + buf = []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x82, 0x6f, //819823 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x1, 0x1c, 0xe3, //72931 + } + + d = NewBinDecoder(buf) + + n, err = d.ReadUint64(BE) + assert.NoError(t, err) + assert.Equal(t, uint64(819823), n) + assert.Equal(t, 8, d.Remaining()) + + n, err = d.ReadUint64(BE) + assert.NoError(t, err) + assert.Equal(t, uint64(72931), n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_float32(t *testing.T) { + // little endian + buf := []byte{ + 0xc3, 0xf5, 0xa8, 0x3f, + 0xa4, 0x70, 0x4d, 0xc0, + } + + d := NewBinDecoder(buf) + + n, err := d.ReadFloat32(LE) + assert.NoError(t, err) + assert.Equal(t, float32(1.32), n) + assert.Equal(t, 4, d.Remaining()) + + n, err = d.ReadFloat32(LE) + assert.NoError(t, err) + assert.Equal(t, float32(-3.21), n) + assert.Equal(t, 0, d.Remaining()) + + // big endian + buf = []byte{ + 0x3f, 0xa8, 0xf5, 0xc3, + 0xc0, 0x4d, 0x70, 0xa4, + } + + d = NewBinDecoder(buf) + + n, err = d.ReadFloat32(BE) + assert.NoError(t, err) + assert.Equal(t, float32(1.32), n) + assert.Equal(t, 4, d.Remaining()) + + n, err = d.ReadFloat32(BE) + assert.NoError(t, err) + assert.Equal(t, float32(-3.21), n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_float64(t *testing.T) { + // little endian + buf := []byte{ + 0x3d, 0x0a, 0xd7, 0xa3, 0x70, 0x1d, 0x4f, 0xc0, + 0x77, 0xbe, 0x9f, 0x1a, 0x2f, 0x3d, 0x37, 0x40, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x7f, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf8, 0x7f, + } + + d := NewBinDecoder(buf) + + n, err := d.ReadFloat64(LE) + assert.NoError(t, err) + assert.Equal(t, float64(-62.23), n) + assert.Equal(t, 32, d.Remaining()) + + n, err = d.ReadFloat64(LE) + assert.NoError(t, err) + assert.Equal(t, float64(23.239), n) + assert.Equal(t, 24, d.Remaining()) + + n, err = d.ReadFloat64(LE) + assert.NoError(t, err) + assert.Equal(t, math.Inf(1), n) + assert.Equal(t, 16, d.Remaining()) + + n, err = d.ReadFloat64(LE) + assert.NoError(t, err) + assert.Equal(t, math.Inf(-1), n) + assert.Equal(t, 8, d.Remaining()) + + n, err = d.ReadFloat64(LE) + assert.NoError(t, err) + assert.True(t, math.IsNaN(n)) + + // big endian + buf = []byte{ + 0xc0, 0x4f, 0x1d, 0x70, 0xa3, 0xd7, 0x0a, 0x3d, + 0x40, 0x37, 0x3d, 0x2f, 0x1a, 0x9f, 0xbe, 0x77, + 0x7f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xff, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x7f, 0xf8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + } + + d = NewBinDecoder(buf) + + n, err = d.ReadFloat64(BE) + assert.NoError(t, err) + assert.Equal(t, float64(-62.23), n) + assert.Equal(t, 32, d.Remaining()) + + n, err = d.ReadFloat64(BE) + assert.NoError(t, err) + assert.Equal(t, float64(23.239), n) + assert.Equal(t, 24, d.Remaining()) + + n, err = d.ReadFloat64(BE) + assert.NoError(t, err) + assert.Equal(t, math.Inf(1), n) + assert.Equal(t, 16, d.Remaining()) + + n, err = d.ReadFloat64(BE) + assert.NoError(t, err) + assert.Equal(t, math.Inf(-1), n) + assert.Equal(t, 8, d.Remaining()) + + n, err = d.ReadFloat64(BE) + assert.NoError(t, err) + assert.True(t, math.IsNaN(n)) +} + +func TestDecoder_string(t *testing.T) { + buf := []byte{ + 0x03, 0x31, 0x32, 0x33, // "123" + 0x00, // "" + 0x03, 0x61, 0x62, 0x63, // "abc + } + + d := NewBinDecoder(buf) + + s, err := d.ReadString() + assert.NoError(t, err) + assert.Equal(t, "123", s) + assert.Equal(t, 5, d.Remaining()) + + s, err = d.ReadString() + assert.NoError(t, err) + assert.Equal(t, "", s) + assert.Equal(t, 4, d.Remaining()) + + s, err = d.ReadString() + assert.NoError(t, err) + assert.Equal(t, "abc", s) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_Decode_String_Err(t *testing.T) { + buf := []byte{ + 0x01, 0x00, 0x00, 0x00, + byte('a'), + } + + decoder := NewBinDecoder(buf) + + var s string + err := decoder.Decode(&s) + assert.EqualError(t, err, "decode: uint64 required [8] bytes, remaining [5]") +} + +func TestDecoder_Byte(t *testing.T) { + buf := []byte{ + 0x00, 0x01, + } + + d := NewBinDecoder(buf) + + n, err := d.ReadByte() + assert.NoError(t, err) + assert.Equal(t, byte(0), n) + assert.Equal(t, 1, d.Remaining()) + + n, err = d.ReadByte() + assert.NoError(t, err) + assert.Equal(t, byte(1), n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_Bool(t *testing.T) { + buf := []byte{ + 0x01, 0x00, + } + + d := NewBinDecoder(buf) + + n, err := d.ReadBool() + assert.NoError(t, err) + assert.Equal(t, true, n) + assert.Equal(t, 1, d.Remaining()) + + n, err = d.ReadBool() + assert.NoError(t, err) + assert.Equal(t, false, n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_ByteArray(t *testing.T) { + buf := []byte{ + 0x03, 0x01, 0x02, 0x03, + 0x03, 0x04, 0x05, 0x06, + } + + d := NewBinDecoder(buf) + + data, err := d.ReadByteSlice() + assert.NoError(t, err) + assert.Equal(t, []byte{1, 2, 3}, data) + assert.Equal(t, 4, d.Remaining()) + + data, err = d.ReadByteSlice() + assert.Equal(t, []byte{4, 5, 6}, data) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_ByteArray_MissingData(t *testing.T) { + buf := []byte{ + 0x0a, + } + + d := NewBinDecoder(buf) + + _, err := d.ReadByteSlice() + assert.EqualError(t, err, "byte array: varlen=10, missing 10 bytes") +} + +func TestDecoder_Array(t *testing.T) { + buf := []byte{1, 2, 4} + + decoder := NewBinDecoder(buf) + + var decoded [3]byte + decoder.Decode(&decoded) + assert.Equal(t, [3]byte{1, 2, 4}, decoded) +} + +func TestDecoder_Slice_Err(t *testing.T) { + buf := []byte{} + + decoder := NewBinDecoder(buf) + var s []string + err := decoder.Decode(&s) + assert.Equal(t, ErrVarIntBufferSize, err) + + buf = []byte{0x01} + + decoder = NewBinDecoder(buf) + err = decoder.Decode(&s) + assert.EqualError(t, err, "unexpected EOF") +} + +func TestDecoder_Slice_InvalidLen(t *testing.T) { + buf := []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01} + + decoder := NewBinDecoder(buf) + var s []string + err := decoder.Decode(&s) + assert.EqualError(t, err, "unexpected EOF") +} + +func TestDecoder_Int64(t *testing.T) { + // little endian + buf := []byte{ + 0x91, 0x7d, 0xf3, 0xff, 0xff, 0xff, 0xff, 0xff, //-819823 + 0xe3, 0x1c, 0x1, 0x00, 0x00, 0x00, 0x00, 0x00, //72931 + } + + d := NewBinDecoder(buf) + + n, err := d.ReadInt64(LE) + assert.NoError(t, err) + assert.Equal(t, int64(-819823), n) + assert.Equal(t, 8, d.Remaining()) + + n, err = d.ReadInt64(LE) + assert.NoError(t, err) + assert.Equal(t, int64(72931), n) + assert.Equal(t, 0, d.Remaining()) + + // big endian + buf = []byte{ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xf3, 0x7d, 0x91, //-819823 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x1, 0x1c, 0xe3, //72931 + } + + d = NewBinDecoder(buf) + + n, err = d.ReadInt64(BE) + assert.NoError(t, err) + assert.Equal(t, int64(-819823), n) + assert.Equal(t, 8, d.Remaining()) + + n, err = d.ReadInt64(BE) + assert.NoError(t, err) + assert.Equal(t, int64(72931), n) + assert.Equal(t, 0, d.Remaining()) +} + +func TestDecoder_Uint128_2(t *testing.T) { + // little endian + buf := []byte{ + 0x0d, 0x88, 0xd3, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x6d, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + + d := NewBinDecoder(buf) + + n, err := d.ReadUint128(LE) + assert.NoError(t, err) + assert.Equal(t, Uint128{Hi: 0xb6d, Lo: 0xffffffffffd3880d}, n) + + buf = []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0xbb, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xac, 0xdc, 0xad, + } + + d = NewBinDecoder(buf) + + n, err = d.ReadUint128(BE) + assert.NoError(t, err) + assert.Equal(t, Uint128{Hi: 0x00000000000008bb, Lo: 0xffffffffffacdcad}, n) + +} + +func TestDecoder_BinaryStruct(t *testing.T) { + cnt, err := hex.DecodeString("0300000000000000616263b5ff630019ffffffe703000051ccffffffffffff9f860100000000003d0ab9c15c8fc2f5285c0f4002030000000000000064656603000000000000003738390300000000000000666f6f0300000000000000626172ff05010203040501e9ffffffffffffff17000000000000001f85eb51b81e09400a000000000000005200000000000000070000000000000003000000000000000a000000000000005200000000000000e707cd0f01050102030405") + require.NoError(t, err) + + s := binaryTestStruct{} + decoder := NewBinDecoder(cnt) + assert.NoError(t, decoder.Decode(&s)) + + assert.Equal(t, "abc", s.F1) + assert.Equal(t, int16(-75), s.F2) + assert.Equal(t, uint16(99), s.F3) + assert.Equal(t, int32(-231), s.F4) + assert.Equal(t, uint32(999), s.F5) + assert.Equal(t, int64(-13231), s.F6) + assert.Equal(t, uint64(99999), s.F7) + assert.Equal(t, float32(-23.13), s.F8) + assert.Equal(t, float64(3.92), s.F9) + assert.Equal(t, []string{"def", "789"}, s.F10) + assert.Equal(t, [2]string{"foo", "bar"}, s.F11) + assert.Equal(t, uint8(0xff), s.F12) + assert.Equal(t, []byte{1, 2, 3, 4, 5}, s.F13) + assert.Equal(t, true, s.F14) + assert.Equal(t, Int64(-23), s.F15) + assert.Equal(t, Uint64(23), s.F16) + assert.Equal(t, JSONFloat64(3.14), s.F17) + assert.Equal(t, Uint128{ + Lo: 10, + Hi: 82, + }, s.F18) + assert.Equal(t, Int128{ + Lo: 7, + Hi: 3, + }, s.F19) + assert.Equal(t, Float128{ + Lo: 10, + Hi: 82, + }, s.F20) + assert.Equal(t, Varuint32(999), s.F21) + assert.Equal(t, Varint32(-999), s.F22) + assert.Equal(t, Bool(true), s.F23) + assert.Equal(t, HexBytes([]byte{1, 2, 3, 4, 5}), s.F24) +} + +func TestDecoder_Decode_No_Ptr(t *testing.T) { + decoder := NewBinDecoder([]byte{}) + err := decoder.Decode(1) + assert.EqualError(t, err, "decoder: Decode(non-pointer int)") +} + +func TestDecoder_BinaryTestStructWithTags(t *testing.T) { + cnt, err := hex.DecodeString("ffb50063ffffff19000003e7ffffffffffffcc51000000000001869fc1b90a3d400f5c28f5c28f5c010000000000000000") + require.NoError(t, err) + + s := &binaryTestStructWithTags{} + decoder := NewBinDecoder(cnt) + assert.NoError(t, decoder.Decode(s)) + + assert.Equal(t, "", s.F1) + assert.Equal(t, int16(-75), s.F2) + assert.Equal(t, uint16(99), s.F3) + assert.Equal(t, int32(-231), s.F4) + assert.Equal(t, uint32(999), s.F5) + assert.Equal(t, int64(-13231), s.F6) + assert.Equal(t, uint64(99999), s.F7) + assert.Equal(t, float32(-23.13), s.F8) + assert.Equal(t, float64(3.92), s.F9) + assert.Equal(t, true, s.F10) + var i *Int64 + assert.Equal(t, i, s.F11) +} + +func TestDecoder_SkipBytes(t *testing.T) { + buf := []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} + decoder := NewBinDecoder(buf) + err := decoder.SkipBytes(1) + require.NoError(t, err) + require.Equal(t, 7, decoder.Remaining()) + + err = decoder.SkipBytes(2) + require.NoError(t, err) + require.Equal(t, 5, decoder.Remaining()) + + err = decoder.SkipBytes(6) + require.Error(t, err) + + err = decoder.SkipBytes(5) + require.NoError(t, err) + require.Equal(t, 0, decoder.Remaining()) +} + +func Test_Discard(t *testing.T) { + buf := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + decoder := NewBinDecoder(buf) + err := decoder.Discard(5) + require.NoError(t, err) + require.Equal(t, 5, decoder.Remaining()) + remaining, err := decoder.Peek(decoder.Remaining()) + require.NoError(t, err) + require.Equal(t, []byte{5, 6, 7, 8, 9}, remaining) +} + +func Test_reflect_readArrayOfBytes(t *testing.T) { + { + { + buf := []byte{0, 1, 2, 3, 4, 5, 6, 7} + decoder := NewBinDecoder(buf) + + got := make([]byte, 0) + err := reflect_readArrayOfBytes(decoder, len(buf), reflect.ValueOf(&got).Elem()) + require.NoError(t, err) + require.Equal(t, buf, got) + } + { + buf := []byte{0, 1, 2, 3, 4, 5, 6, 7} + decoder := NewBinDecoder(buf) + + got := [8]byte{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfBytes(decoder, len(buf), reflect.ValueOf(&got).Elem()) + require.NoError(t, err) + require.Equal(t, buf, got[:]) + } + } + { + { + buf := []byte{0, 1, 2, 3, 4, 5, 6, 7} + decoder := NewBorshDecoder(buf) + + got := make([]byte, 0) + err := reflect_readArrayOfBytes(decoder, len(buf), reflect.ValueOf(&got).Elem()) + require.NoError(t, err) + require.Equal(t, buf, got) + } + { + buf := []byte{0, 1, 2, 3, 4, 5, 6, 7} + decoder := NewBorshDecoder(buf) + + got := [8]byte{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfBytes(decoder, len(buf), reflect.ValueOf(&got).Elem()) + require.NoError(t, err) + require.Equal(t, buf, got[:]) + } + } +} + +func Test_reflect_readArrayOfUint16(t *testing.T) { + { + { + buf := concatByteSlices( + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + got := make([]uint16, 0) + err := reflect_readArrayOfUint16(decoder, len(buf)/2, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint16{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + got := [8]uint16{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfUint16(decoder, len(buf)/2, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint16{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } + { + { + buf := concatByteSlices( + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + got := make([]uint16, 0) + err := reflect_readArrayOfUint16(decoder, len(buf)/2, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint16{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + got := [8]uint16{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfUint16(decoder, len(buf)/2, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint16{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } +} + +func Test_reflect_readArrayOfUint32(t *testing.T) { + { + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + got := make([]uint32, 0) + err := reflect_readArrayOfUint32(decoder, len(buf)/4, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + got := [8]uint32{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfUint32(decoder, len(buf)/4, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } + { + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + got := make([]uint32, 0) + err := reflect_readArrayOfUint32(decoder, len(buf)/4, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + got := [8]uint32{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfUint32(decoder, len(buf)/4, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } +} + +func Test_reflect_readArrayOfUint64(t *testing.T) { + { + { + buf := concatByteSlices( + uint64ToBytes(0, LE), + uint64ToBytes(1, LE), + uint64ToBytes(2, LE), + uint64ToBytes(3, LE), + uint64ToBytes(4, LE), + uint64ToBytes(5, LE), + uint64ToBytes(6, LE), + uint64ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + got := make([]uint64, 0) + err := reflect_readArrayOfUint64(decoder, len(buf)/8, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint64{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint64ToBytes(0, LE), + uint64ToBytes(1, LE), + uint64ToBytes(2, LE), + uint64ToBytes(3, LE), + uint64ToBytes(4, LE), + uint64ToBytes(5, LE), + uint64ToBytes(6, LE), + uint64ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + got := [8]uint64{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfUint64(decoder, len(buf)/8, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint64{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } + { + { + buf := concatByteSlices( + uint64ToBytes(0, LE), + uint64ToBytes(1, LE), + uint64ToBytes(2, LE), + uint64ToBytes(3, LE), + uint64ToBytes(4, LE), + uint64ToBytes(5, LE), + uint64ToBytes(6, LE), + uint64ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + got := make([]uint64, 0) + err := reflect_readArrayOfUint64(decoder, len(buf)/8, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint64{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint64ToBytes(0, LE), + uint64ToBytes(1, LE), + uint64ToBytes(2, LE), + uint64ToBytes(3, LE), + uint64ToBytes(4, LE), + uint64ToBytes(5, LE), + uint64ToBytes(6, LE), + uint64ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + got := [8]uint64{0, 0, 0, 0, 0, 0, 0, 0} + err := reflect_readArrayOfUint64(decoder, len(buf)/8, reflect.ValueOf(&got).Elem(), LE) + require.NoError(t, err) + require.Equal(t, []uint64{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } +} + +func Test_reflect_readArrayOfUint(t *testing.T) { + { + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + got := make([]uint32, 0) + rv := reflect.ValueOf(&got).Elem() + k := rv.Type().Elem().Kind() + err := reflect_readArrayOfUint_(decoder, len(buf)/4, k, rv, LE) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + got := [8]uint32{0, 0, 0, 0, 0, 0, 0, 0} + rv := reflect.ValueOf(&got).Elem() + k := rv.Type().Elem().Kind() + err := reflect_readArrayOfUint_(decoder, len(buf)/4, k, rv, LE) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } +} + +func Test_Decode_custom(t *testing.T) { + t.Run("custom-type-uint32 slice", func(t *testing.T) { + { + buf := concatByteSlices( + // length: + []byte{3}, + // data: + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + ) + decoder := NewBinDecoder(buf) + + type CustomUint32 uint32 + got := make([]CustomUint32, 0) + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, []CustomUint32{0, 1, 2}, got) + } + }) + t.Run("custom-type-uint32 array", func(t *testing.T) { + { + buf := concatByteSlices( + // data: + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + ) + decoder := NewBinDecoder(buf) + + type CustomUint32 uint32 + got := [3]CustomUint32{0, 0, 0} + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, [3]CustomUint32{0, 1, 2}, got) + } + }) + t.Run("uint32 custom-type-slice", func(t *testing.T) { + { + buf := concatByteSlices( + // length: + []byte{3}, + // data: + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + ) + decoder := NewBinDecoder(buf) + + type CustomSliceUint32 []uint32 + got := make(CustomSliceUint32, 0) + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, CustomSliceUint32{0, 1, 2}, got) + } + }) + t.Run("uint32 custom-type-array", func(t *testing.T) { + { + buf := concatByteSlices( + // data: + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + ) + decoder := NewBinDecoder(buf) + + type CustomArrayUint32 [3]uint32 + got := CustomArrayUint32{0, 0, 0} + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, CustomArrayUint32{0, 1, 2}, got) + } + }) +} + +func Test_ReadNBytes(t *testing.T) { + { + b1 := []byte{123, 99, 88, 77, 66, 55, 44, 33, 22, 11} + b2 := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + buf := concatByteSlices( + b1, + b2, + ) + decoder := NewBinDecoder(buf) + + got, err := decoder.ReadNBytes(10) + require.NoError(t, err) + require.Equal(t, b1, got) + + got, err = decoder.ReadNBytes(10) + require.NoError(t, err) + require.Equal(t, b2, got) + } +} + +func Test_ReadBytes(t *testing.T) { + { + b1 := []byte{123, 99, 88, 77, 66, 55, 44, 33, 22, 11} + b2 := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + buf := concatByteSlices( + b1, + b2, + ) + decoder := NewBinDecoder(buf) + + got, err := decoder.ReadBytes(10) + require.NoError(t, err) + require.Equal(t, b1, got) + + got, err = decoder.ReadBytes(10) + require.NoError(t, err) + require.Equal(t, b2, got) + } +} + +func Test_Read(t *testing.T) { + { + b1 := []byte{123, 99, 88, 77, 66, 55, 44, 33, 22, 11} + b2 := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + buf := concatByteSlices( + b1, + b2, + ) + decoder := NewBinDecoder(buf) + + { + got := make([]byte, 10) + num, err := decoder.Read(got) + require.NoError(t, err) + require.Equal(t, b1, got) + require.Equal(t, 10, num) + } + + { + got := make([]byte, 10) + num, err := decoder.Read(got) + require.NoError(t, err) + require.Equal(t, b2, got) + require.Equal(t, 10, num) + } + { + got := make([]byte, 11) + _, err := decoder.Read(got) + require.EqualError(t, err, "short buffer") + } + { + got := make([]byte, 0) + num, err := decoder.Read(got) + require.NoError(t, err) + require.Equal(t, 0, num) + require.Equal(t, []byte{}, got) + } + } +} + +func Test_Decode_readArrayOfUint(t *testing.T) { + { + { + buf := concatByteSlices( + // length: + []byte{3}, + // data: + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + ) + decoder := NewBinDecoder(buf) + + got := make([]uint32, 0) + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2}, got) + } + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + got := [8]uint32{0, 0, 0, 0, 0, 0, 0, 0} + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } + { + { + buf := concatByteSlices( + // length: + uint32ToBytes(8, LE), + // data: + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + got := make([]uint32, 0) + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got) + } + { + buf := concatByteSlices( + uint32ToBytes(0, LE), + uint32ToBytes(1, LE), + uint32ToBytes(2, LE), + uint32ToBytes(3, LE), + uint32ToBytes(4, LE), + uint32ToBytes(5, LE), + uint32ToBytes(6, LE), + uint32ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + got := [8]uint32{0, 0, 0, 0, 0, 0, 0, 0} + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, got[:]) + } + } +} + +func Test_reflect_readArrayOfUint16_asField(t *testing.T) { + { + { + buf := concatByteSlices( + // length: + []byte{8}, + // data: + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + type S struct { + Val []uint16 + } + var got S + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, S{[]uint16{0, 1, 2, 3, 4, 5, 6, 7}}, got) + } + { + buf := concatByteSlices( + // data: + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBinDecoder(buf) + + type S struct { + Val [8]uint16 + } + var got S + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, S{[8]uint16{0, 1, 2, 3, 4, 5, 6, 7}}, got) + } + } + { + { + buf := concatByteSlices( + // length: + uint32ToBytes(8, LE), + // data: + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + type S struct { + Val []uint16 + } + var got S + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, S{[]uint16{0, 1, 2, 3, 4, 5, 6, 7}}, got) + } + { + buf := concatByteSlices( + uint16ToBytes(0, LE), + uint16ToBytes(1, LE), + uint16ToBytes(2, LE), + uint16ToBytes(3, LE), + uint16ToBytes(4, LE), + uint16ToBytes(5, LE), + uint16ToBytes(6, LE), + uint16ToBytes(7, LE), + ) + decoder := NewBorshDecoder(buf) + + type S struct { + Val [8]uint16 + } + var got S + err := decoder.Decode(&got) + require.NoError(t, err) + require.Equal(t, S{[8]uint16{0, 1, 2, 3, 4, 5, 6, 7}}, got) + } + } +} diff --git a/binary/encoder.go b/binary/encoder.go new file mode 100644 index 000000000..754f95080 --- /dev/null +++ b/binary/encoder.go @@ -0,0 +1,539 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "reflect" + "unsafe" + + "go.uber.org/zap" +) + +type Encoder struct { + count int + + // currentFieldOpt is held by value (not pointer) so it doesn't escape. + // Same role as Decoder.currentFieldOpt — gives nested types like Uint128 + // access to the active byte order. + currentFieldOpt option + encoding Encoding + + // output is the destination io.Writer. May be nil when the Encoder is + // running in buffered mode (see NewBinEncoderBuf etc.) — in that case all + // writes accumulate in `buf` and the caller retrieves them via Bytes() + // or WriteTo. + output io.Writer + buf []byte + + // scratch is a per-Encoder staging buffer reused across primitive writes + // so WriteUint*/WriteFloat*/WriteVarInt/WriteCompactU16/... don't allocate. + // 16 bytes is enough for any fixed-width primitive (Uint128) and for a + // Uvarint/Varint header (max 10 bytes) and for a CompactU16 (max 3 bytes). + // Safe to reuse: io.Writer.Write must not retain p after return. + scratch [16]byte + + // skipMarshalerCheck tells encodeBin/encodeBorsh/encodeCompactU16 to + // skip the per-call asBinaryMarshaler() type assertion. encodeStructBin + // (and friends) sets this to true before encoding a field whose typePlan + // has already proven that neither the value nor the pointer type + // implements BinaryMarshaler. The flag is propagated through Ptr.Elem + // recursion (because *T not implementing means T doesn't either), and + // reset around array/slice element loops where elements are independent + // types. Without this flag, the rv.Interface() boxing dominates encode + // allocations for non-marshaler types like solana.PublicKey. + skipMarshalerCheck bool +} + +func (enc *Encoder) IsBorsh() bool { + return enc.encoding.IsBorsh() +} + +func (enc *Encoder) IsBin() bool { + return enc.encoding.IsBin() +} + +func (enc *Encoder) IsCompactU16() bool { + return enc.encoding.IsCompactU16() +} + +func NewEncoderWithEncoding(writer io.Writer, enc Encoding) *Encoder { + if !isValidEncoding(enc) { + panic(fmt.Sprintf("provided encoding is not valid: %s", enc)) + } + return &Encoder{ + output: writer, + count: 0, + encoding: enc, + } +} + +func NewBinEncoder(writer io.Writer) *Encoder { + return NewEncoderWithEncoding(writer, EncodingBin) +} + +func NewBorshEncoder(writer io.Writer) *Encoder { + return NewEncoderWithEncoding(writer, EncodingBorsh) +} + +func NewCompactU16Encoder(writer io.Writer) *Encoder { + return NewEncoderWithEncoding(writer, EncodingCompactU16) +} + +// NewBufferedEncoder returns an Encoder that writes into an internal []byte +// buffer instead of an io.Writer. Use Bytes() to retrieve the encoded payload +// and Reset()/Bytes() to reuse the encoder across multiple messages. +// +// This is the lowest-overhead encode mode: every primitive write becomes an +// `append(e.buf, ...)` with no interface dispatch and no per-call allocation. +func NewBufferedEncoder(enc Encoding) *Encoder { + if !isValidEncoding(enc) { + panic(fmt.Sprintf("provided encoding is not valid: %s", enc)) + } + return &Encoder{encoding: enc} +} + +func NewBinEncoderBuf() *Encoder { return NewBufferedEncoder(EncodingBin) } +func NewBorshEncoderBuf() *Encoder { return NewBufferedEncoder(EncodingBorsh) } +func NewCompactU16EncoderBuf() *Encoder { return NewBufferedEncoder(EncodingCompactU16) } + +// Bytes returns the encoded payload accumulated in buffered mode. The slice +// aliases the encoder's internal buffer; copy it if you need to retain it +// across a Reset() / further writes. +func (e *Encoder) Bytes() []byte { + return e.buf +} + +// Reset clears the encoder's internal state (count, buffer, current option) +// so it can be reused for another message. The output writer is preserved. +func (e *Encoder) Reset() { + e.count = 0 + e.buf = e.buf[:0] + e.currentFieldOpt = option{} + e.skipMarshalerCheck = false +} + +// Grow ensures the internal buffer has at least n free bytes available. +// Useful in buffered mode to amortize append-driven growth when the encoded +// size is known in advance. +func (e *Encoder) Grow(n int) { + if cap(e.buf)-len(e.buf) >= n { + return + } + nb := make([]byte, len(e.buf), len(e.buf)+n) + copy(nb, e.buf) + e.buf = nb +} + +func (e *Encoder) Encode(v interface{}) (err error) { + switch e.encoding { + case EncodingBin: + return e.encodeBin(reflect.ValueOf(v), defaultOption) + case EncodingBorsh: + return e.encodeBorsh(reflect.ValueOf(v), defaultOption) + case EncodingCompactU16: + return e.encodeCompactU16(reflect.ValueOf(v), defaultOption) + default: + panic(fmt.Errorf("encoding not implemented: %s", e.encoding)) + } +} + +func (e *Encoder) toWriter(bytes []byte) (err error) { + e.count += len(bytes) + if traceEnabled { + zlog.Debug(" > encode: appending", zap.Stringer("hex", HexBytes(bytes)), zap.Int("pos", e.count)) + } + if e.output == nil { + e.buf = append(e.buf, bytes...) + return nil + } + _, err = e.output.Write(bytes) + return +} + +// Written returns the count of bytes written. +func (e *Encoder) Written() int { + return e.count +} + +func (e *Encoder) WriteBytes(b []byte, writeLength bool) error { + if traceEnabled { + zlog.Debug("encode: write byte array", zap.Int("len", len(b))) + } + if writeLength { + if err := e.WriteLength(len(b)); err != nil { + return err + } + } + if len(b) == 0 { + return nil + } + return e.toWriter(b) +} + +func (e *Encoder) Write(b []byte) (n int, err error) { + // Route through toWriter so buffered Encoders (output == nil) append to + // e.buf instead of nil-derefing. Matches WriteBytes semantics. + if err := e.toWriter(b); err != nil { + return 0, err + } + return len(b), nil +} + +func (e *Encoder) WriteLength(length int) error { + if traceEnabled { + zlog.Debug("encode: write length", zap.Int("len", length)) + } + switch e.encoding { + case EncodingBin: + if err := e.WriteUVarInt(length); err != nil { + return err + } + case EncodingBorsh: + if err := e.WriteUint32(uint32(length), LE); err != nil { + return err + } + case EncodingCompactU16: + n, err := PutCompactU16Length(e.scratch[:3], length) + if err != nil { + return err + } + if err := e.toWriter(e.scratch[:n]); err != nil { + return err + } + default: + panic(fmt.Errorf("encoding not implemented: %s", e.encoding)) + } + return nil +} + +func (e *Encoder) WriteUVarInt(v int) (err error) { + if traceEnabled { + zlog.Debug("encode: write uvarint", zap.Int("val", v)) + } + l := binary.PutUvarint(e.scratch[:], uint64(v)) + return e.toWriter(e.scratch[:l]) +} + +func (e *Encoder) WriteVarInt(v int) (err error) { + if traceEnabled { + zlog.Debug("encode: write varint", zap.Int("val", v)) + } + l := binary.PutVarint(e.scratch[:], int64(v)) + return e.toWriter(e.scratch[:l]) +} + +func (e *Encoder) WriteByte(b byte) (err error) { + if traceEnabled { + zlog.Debug("encode: write byte", zap.Uint8("val", b)) + } + e.scratch[0] = b + return e.toWriter(e.scratch[:1]) +} + +func (e *Encoder) WriteOption(b bool) (err error) { + if traceEnabled { + zlog.Debug("encode: write option", zap.Bool("val", b)) + } + return e.WriteBool(b) +} + +func (e *Encoder) WriteCOption(b bool) (err error) { + if traceEnabled { + zlog.Debug("encode: write c-option", zap.Bool("val", b)) + } + var num uint32 + if b { + num = 1 + } + return e.WriteUint32(num, LE) +} + +func (e *Encoder) WriteBool(b bool) (err error) { + if traceEnabled { + zlog.Debug("encode: write bool", zap.Bool("val", b)) + } + var out byte + if b { + out = 1 + } + return e.WriteByte(out) +} + +func (e *Encoder) WriteUint8(i uint8) (err error) { + return e.WriteByte(i) +} + +func (e *Encoder) WriteInt8(i int8) (err error) { + return e.WriteByte(uint8(i)) +} + +func (e *Encoder) WriteUint16(i uint16, order binary.ByteOrder) (err error) { + if traceEnabled { + zlog.Debug("encode: write uint16", zap.Uint16("val", i)) + } + order.PutUint16(e.scratch[:2], i) + return e.toWriter(e.scratch[:2]) +} + +func (e *Encoder) WriteInt16(i int16, order binary.ByteOrder) (err error) { + if traceEnabled { + zlog.Debug("encode: write int16", zap.Int16("val", i)) + } + return e.WriteUint16(uint16(i), order) +} + +func (e *Encoder) WriteUint32(i uint32, order binary.ByteOrder) (err error) { + if traceEnabled { + zlog.Debug("encode: write uint32", zap.Uint32("val", i)) + } + order.PutUint32(e.scratch[:4], i) + return e.toWriter(e.scratch[:4]) +} + +func (e *Encoder) WriteInt32(i int32, order binary.ByteOrder) (err error) { + if traceEnabled { + zlog.Debug("encode: write int32", zap.Int32("val", i)) + } + return e.WriteUint32(uint32(i), order) +} + +func (e *Encoder) WriteUint64(i uint64, order binary.ByteOrder) (err error) { + if traceEnabled { + zlog.Debug("encode: write uint64", zap.Uint64("val", i)) + } + order.PutUint64(e.scratch[:8], i) + return e.toWriter(e.scratch[:8]) +} + +func (e *Encoder) WriteInt64(i int64, order binary.ByteOrder) (err error) { + if traceEnabled { + zlog.Debug("encode: write int64", zap.Int64("val", i)) + } + return e.WriteUint64(uint64(i), order) +} + +func (e *Encoder) WriteUint128(i Uint128, order binary.ByteOrder) (err error) { + if traceEnabled { + zlog.Debug("encode: write uint128", zap.Stringer("hex", i), zap.Uint64("lo", i.Lo), zap.Uint64("hi", i.Hi)) + } + buf := e.scratch[:16] + switch order { + case binary.LittleEndian: + order.PutUint64(buf[:8], i.Lo) + order.PutUint64(buf[8:], i.Hi) + case binary.BigEndian: + order.PutUint64(buf[:8], i.Hi) + order.PutUint64(buf[8:], i.Lo) + default: + return fmt.Errorf("invalid byte order: %v", order) + } + return e.toWriter(buf) +} + +func (e *Encoder) WriteInt128(i Int128, order binary.ByteOrder) (err error) { + if traceEnabled { + zlog.Debug("encode: write int128", zap.Stringer("hex", i), zap.Uint64("lo", i.Lo), zap.Uint64("hi", i.Hi)) + } + buf := e.scratch[:16] + switch order { + case binary.LittleEndian: + order.PutUint64(buf[:8], i.Lo) + order.PutUint64(buf[8:], i.Hi) + case binary.BigEndian: + order.PutUint64(buf[:8], i.Hi) + order.PutUint64(buf[8:], i.Lo) + default: + return fmt.Errorf("invalid byte order: %v", order) + } + return e.toWriter(buf) +} + +func (e *Encoder) WriteFloat32(f float32, order binary.ByteOrder) (err error) { + if traceEnabled { + zlog.Debug("encode: write float32", zap.Float32("val", f)) + } + + if e.IsBorsh() { + if math.IsNaN(float64(f)) { + return errors.New("NaN float value") + } + } + + order.PutUint32(e.scratch[:4], math.Float32bits(f)) + return e.toWriter(e.scratch[:4]) +} + +func (e *Encoder) WriteFloat64(f float64, order binary.ByteOrder) (err error) { + if traceEnabled { + zlog.Debug("encode: write float64", zap.Float64("val", f)) + } + + if e.IsBorsh() { + if math.IsNaN(float64(f)) { + return errors.New("NaN float value") + } + } + order.PutUint64(e.scratch[:8], math.Float64bits(f)) + return e.toWriter(e.scratch[:8]) +} + +func (e *Encoder) WriteString(s string) (err error) { + if traceEnabled { + zlog.Debug("encode: write string", zap.String("val", s)) + } + return e.WriteBytes([]byte(s), true) +} + +func (e *Encoder) WriteRustString(s string) (err error) { + err = e.WriteUint64(uint64(len(s)), binary.LittleEndian) + if err != nil { + return err + } + if traceEnabled { + zlog.Debug("encode: write Rust string", zap.String("val", s)) + } + return e.WriteBytes([]byte(s), false) +} + +func (e *Encoder) WriteCompactU16(ln int) (err error) { + if traceEnabled { + zlog.Debug("encode: write compact-u16", zap.Int("val", ln)) + } + n, err := PutCompactU16Length(e.scratch[:3], ln) + if err != nil { + return err + } + return e.toWriter(e.scratch[:n]) +} + +func (e *Encoder) WriteCompactU16Length(ln int) (err error) { + return e.WriteCompactU16(ln) +} + +// writePoDSliceBytes is the encoder analog of readPoDSliceBytes. When the +// destination is addressable AND the host and wire byte orders match, it +// sends the destination's backing memory straight through toWriter via a +// single byte-view — no make, no per-element reflect.Index/Uint loop. +// +// When rv is not addressable (caller passed a struct by value), we fall back +// to a per-element reflect.Index path that uses an intermediate scratch +// slice. This is the same behaviour as the pre-PoD-fast-path implementation. +// +// elemSize must be 1, 2, 4, or 8. +func writePoDSliceBytes(e *Encoder, rv reflect.Value, l, elemSize int, order binary.ByteOrder) error { + if l == 0 { + return nil + } + need := l * elemSize + + if rv.CanAddr() && rv.Len() > 0 { + base := unsafe.Pointer(rv.Index(0).UnsafeAddr()) + + if elemSize == 1 || (isHostLittleEndian && order == binary.LittleEndian) { + // Single memcpy from the slice/array's backing storage. WriteBytes + // (or the buffered append in toWriter) will copy the bytes onward. + return e.toWriter(unsafe.Slice((*byte)(base), need)) + } + + // Byte-order mismatch: serialize element-by-element into a temporary + // scratch slice. We still avoid going through reflect.Index for each + // element by reading directly from the backing memory. + tmp := make([]byte, need) + switch elemSize { + case 2: + for i := range l { + order.PutUint16(tmp[i*2:], *(*uint16)(unsafe.Add(base, i*2))) + } + case 4: + for i := range l { + order.PutUint32(tmp[i*4:], *(*uint32)(unsafe.Add(base, i*4))) + } + case 8: + for i := range l { + order.PutUint64(tmp[i*8:], *(*uint64)(unsafe.Add(base, i*8))) + } + } + return e.toWriter(tmp) + } + + // Fallback: rv is not addressable. Walk via reflect.Index — slower but + // correct for callers that pass arrays by value. + tmp := make([]byte, need) + switch elemSize { + case 1: + for i := range l { + tmp[i] = byte(rv.Index(i).Uint()) + } + case 2: + for i := range l { + order.PutUint16(tmp[i*2:], uint16(rv.Index(i).Uint())) + } + case 4: + for i := range l { + order.PutUint32(tmp[i*4:], uint32(rv.Index(i).Uint())) + } + case 8: + for i := range l { + order.PutUint64(tmp[i*8:], rv.Index(i).Uint()) + } + } + return e.toWriter(tmp) +} + +func reflect_writeArrayOfBytes(e *Encoder, l int, rv reflect.Value) error { + return writePoDSliceBytes(e, rv, l, 1, binary.LittleEndian) +} + +func reflect_writeArrayOfUint16(e *Encoder, l int, rv reflect.Value, order binary.ByteOrder) error { + return writePoDSliceBytes(e, rv, l, 2, order) +} + +func reflect_writeArrayOfUint32(e *Encoder, l int, rv reflect.Value, order binary.ByteOrder) error { + return writePoDSliceBytes(e, rv, l, 4, order) +} + +func reflect_writeArrayOfUint64(e *Encoder, l int, rv reflect.Value, order binary.ByteOrder) error { + return writePoDSliceBytes(e, rv, l, 8, order) +} + +// reflect_writeArrayOfUint_ is used for writing arrays/slices of uints of any size. +func reflect_writeArrayOfUint_(e *Encoder, l int, k reflect.Kind, rv reflect.Value, order binary.ByteOrder) error { + switch k { + // case reflect.Uint: + // // switch on system architecture (32 or 64 bit) + // if unsafe.Sizeof(uintptr(0)) == 4 { + // return reflect_writeArrayOfUint32(e, l, rv, order) + // } + // return reflect_writeArrayOfUint64(e, l, rv, order) + case reflect.Uint8: + return reflect_writeArrayOfBytes(e, l, rv) + case reflect.Uint16: + return reflect_writeArrayOfUint16(e, l, rv, order) + case reflect.Uint32: + return reflect_writeArrayOfUint32(e, l, rv, order) + case reflect.Uint64: + return reflect_writeArrayOfUint64(e, l, rv, order) + default: + return fmt.Errorf("unsupported kind: %v", k) + } +} diff --git a/binary/encoder_bin.go b/binary/encoder_bin.go new file mode 100644 index 000000000..7a7edf1c7 --- /dev/null +++ b/binary/encoder_bin.go @@ -0,0 +1,307 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "encoding/binary" + "fmt" + "reflect" + + "go.uber.org/zap" +) + +func (e *Encoder) encodeBin(rv reflect.Value, opt option) (err error) { + if opt.Order == nil { + opt.Order = defaultByteOrder + } + e.currentFieldOpt = opt + + if traceEnabled { + zlog.Debug("encode: type", + zap.Stringer("value_kind", rv.Kind()), + zap.Reflect("options", opt), + ) + } + + if opt.is_Optional() { + if rv.IsZero() { + if traceEnabled { + zlog.Debug("encode: skipping optional value with", zap.Stringer("type", rv.Kind())) + } + return e.WriteUint32(0, binary.LittleEndian) + } + err := e.WriteUint32(1, binary.LittleEndian) + if err != nil { + return err + } + // The optionality has been used; stop its propagation: + opt.is_OptionalField = false + } + + if isInvalidValue(rv) { + return nil + } + + // Skip the asBinaryMarshaler boxing call when encodeStructBin has + // proven via the cached fieldPlan that neither the value nor the + // pointer type implements BinaryMarshaler. This is the dominant + // allocation site for hot encode loops on non-marshaler types + // (every field used to box rv via rv.Interface() to test the + // assertion). + if !e.skipMarshalerCheck { + if marshaler, ok := asBinaryMarshaler(rv); ok { + if traceEnabled { + zlog.Debug("encode: using MarshalerBinary method to encode type") + } + return marshaler.MarshalWithEncoder(e) + } + } + + switch rv.Kind() { + case reflect.String: + return e.WriteRustString(rv.String()) + case reflect.Uint8: + return e.WriteByte(byte(rv.Uint())) + case reflect.Int8: + return e.WriteByte(byte(rv.Int())) + case reflect.Int16: + return e.WriteInt16(int16(rv.Int()), opt.Order) + case reflect.Uint16: + return e.WriteUint16(uint16(rv.Uint()), opt.Order) + case reflect.Int32: + return e.WriteInt32(int32(rv.Int()), opt.Order) + case reflect.Uint32: + return e.WriteUint32(uint32(rv.Uint()), opt.Order) + case reflect.Uint64: + return e.WriteUint64(rv.Uint(), opt.Order) + case reflect.Int64: + return e.WriteInt64(rv.Int(), opt.Order) + case reflect.Float32: + return e.WriteFloat32(float32(rv.Float()), opt.Order) + case reflect.Float64: + return e.WriteFloat64(rv.Float(), opt.Order) + case reflect.Bool: + return e.WriteBool(rv.Bool()) + case reflect.Ptr: + return e.encodeBin(rv.Elem(), opt) + case reflect.Interface: + // skip + return nil + } + + rv = reflect.Indirect(rv) + rt := rv.Type() + switch rt.Kind() { + case reflect.Array: + l := rt.Len() + if traceEnabled { + defer func(prev *zap.Logger) { zlog = prev }(zlog) + zlog = zlog.Named("array") + zlog.Debug("encode: array", zap.Int("length", l), zap.Stringer("type", rv.Kind())) + } + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // if it's a [n]byte, accumulate and write in one command: + if err := reflect_writeArrayOfUint_(e, l, k, rv, LE); err != nil { + return err + } + default: + // Element-wise recursion: each element is an independent type + // that may have its own marshaler, so reset the skip flag + // inherited from the field-level entry. The flag is restored + // when this Array case returns. + prevSkip := e.skipMarshalerCheck + e.skipMarshalerCheck = false + for i := range l { + if err = e.encodeBin(rv.Index(i), opt); err != nil { + e.skipMarshalerCheck = prevSkip + return + } + } + e.skipMarshalerCheck = prevSkip + } + + case reflect.Slice: + var l int + if opt.hasSizeOfSlice() { + l = opt.getSizeOfSlice() + if traceEnabled { + zlog.Debug("encode: slice with sizeof set", zap.Int("size_of", l)) + } + } else { + l = rv.Len() + if err = e.WriteUVarInt(l); err != nil { + return + } + } + if traceEnabled { + defer func(prev *zap.Logger) { zlog = prev }(zlog) + zlog = zlog.Named("slice") + zlog.Debug("encode: slice", zap.Int("length", l), zap.Stringer("type", rv.Kind())) + } + + // we would want to skip to the correct head_offset + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // if it's a [n]byte, accumulate and write in one command: + if err := reflect_writeArrayOfUint_(e, l, k, rv, LE); err != nil { + return err + } + default: + prevSkip := e.skipMarshalerCheck + e.skipMarshalerCheck = false + for i := range l { + if err = e.encodeBin(rv.Index(i), opt); err != nil { + e.skipMarshalerCheck = prevSkip + return + } + } + e.skipMarshalerCheck = prevSkip + } + case reflect.Struct: + if err = e.encodeStructBin(rt, rv); err != nil { + return + } + + case reflect.Map: + keyCount := len(rv.MapKeys()) + + if traceEnabled { + zlog.Debug("encode: map", + zap.Int("key_count", keyCount), + zap.String("key_type", rt.String()), + typeField("value_type", rv.Elem()), + ) + defer func(prev *zap.Logger) { zlog = prev }(zlog) + zlog = zlog.Named("struct") + } + + if err = e.WriteUVarInt(keyCount); err != nil { + return + } + + for _, mapKey := range rv.MapKeys() { + if err = e.Encode(mapKey.Interface()); err != nil { + return + } + + if err = e.Encode(rv.MapIndex(mapKey).Interface()); err != nil { + return + } + } + + default: + return fmt.Errorf("encode: unsupported type %q", rt) + } + return +} + +func (e *Encoder) encodeStructBin(rt reflect.Type, rv reflect.Value) (err error) { + plan := planForStruct(rt) + + if traceEnabled { + zlog.Debug("encode: struct", zap.Int("fields", len(plan.fields)), zap.Stringer("type", rv.Kind())) + } + + var sizes []int + if plan.hasSizeOf { + var stack sizesScratch + if len(plan.fields) <= sizesScratchLen { + sizes = stack[:len(plan.fields)] + } else { + sizes = make([]int, len(plan.fields)) + } + for i := range sizes { + sizes[i] = -1 + } + } + + fastOK := rv.CanAddr() + for i := range plan.fields { + fp := &plan.fields[i] + + if fp.skip { + if traceEnabled { + zlog.Debug("encode: skipping struct field with skip flag", + zap.String("struct_field_name", fp.name), + ) + } + continue + } + + // Fast primitive path: no option construction, no kind switch. + if fastOK && fp.binFastEncode != nil { + if err := fp.binFastEncode(e, rv.Field(i)); err != nil { + return fmt.Errorf("error while encoding %q field: %w", fp.name, err) + } + continue + } + + fv := rv.Field(i) + + if fp.sizeOfTargetIdx >= 0 && sizes != nil { + sizes[fp.sizeOfTargetIdx] = sizeof(fp.fieldType, fv) + } + + if !fp.canInterface { + if traceEnabled { + zlog.Debug("encode: skipping field: unable to interface field, probably since field is not exported", + zap.String("struct_field_name", fp.name), + ) + } + continue + } + + opt := option{ + is_OptionalField: fp.tag.Option, + Order: fp.tag.Order, + } + + if sizes != nil && fp.sizeFromIdx >= 0 && sizes[i] >= 0 { + opt.sliceSizeIsSet = true + opt.sliceSize = sizes[i] + } + + if traceEnabled { + zlog.Debug("encode: struct field", + zap.Stringer("struct_field_value_type", fv.Kind()), + zap.String("struct_field_name", fp.name), + zap.Reflect("struct_field_tags", fp.tag), + zap.Reflect("struct_field_option", opt), + ) + } + + // Tell encodeBin to skip its asBinaryMarshaler boxing when the + // cached plan has already proven this field's type doesn't + // implement BinaryMarshaler. The flag is propagated through + // Ptr.Elem recursion in encodeBin and reset around array/slice + // element loops. + prevSkip := e.skipMarshalerCheck + if !fp.valImplementsMarshaler && !fp.ptrImplementsMarshaler { + e.skipMarshalerCheck = true + } + err := e.encodeBin(fv, opt) + e.skipMarshalerCheck = prevSkip + if err != nil { + return fmt.Errorf("error while encoding %q field: %w", fp.name, err) + } + } + return nil +} diff --git a/binary/encoder_borsh.go b/binary/encoder_borsh.go new file mode 100644 index 000000000..2b46d8f7f --- /dev/null +++ b/binary/encoder_borsh.go @@ -0,0 +1,406 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "errors" + "fmt" + "reflect" + "sort" + + "go.uber.org/zap" +) + +func (e *Encoder) encodePrimitive(rv reflect.Value) (isPrimitive bool, err error) { + isPrimitive = true + switch rv.Kind() { + // case reflect.Int: + // err = e.WriteInt64(rv.Int(), LE) + // case reflect.Uint: + // err = e.WriteUint64(rv.Uint(), LE) + case reflect.String: + err = e.WriteString(rv.String()) + case reflect.Uint8: + err = e.WriteByte(byte(rv.Uint())) + case reflect.Int8: + err = e.WriteByte(byte(rv.Int())) + case reflect.Int16: + err = e.WriteInt16(int16(rv.Int()), LE) + case reflect.Uint16: + err = e.WriteUint16(uint16(rv.Uint()), LE) + case reflect.Int32: + err = e.WriteInt32(int32(rv.Int()), LE) + case reflect.Uint32: + err = e.WriteUint32(uint32(rv.Uint()), LE) + case reflect.Uint64: + err = e.WriteUint64(rv.Uint(), LE) + case reflect.Int64: + err = e.WriteInt64(rv.Int(), LE) + case reflect.Float32: + err = e.WriteFloat32(float32(rv.Float()), LE) + case reflect.Float64: + err = e.WriteFloat64(rv.Float(), LE) + case reflect.Bool: + err = e.WriteBool(rv.Bool()) + default: + isPrimitive = false + } + return +} + +func (e *Encoder) encodeBorsh(rv reflect.Value, opt option) (err error) { + if opt.Order == nil { + opt.Order = defaultByteOrder + } + e.currentFieldOpt = opt + + if traceEnabled { + zlog.Debug("encode: type", + zap.Stringer("value_kind", rv.Kind()), + zap.Reflect("options", opt), + ) + } + + if opt.is_Optional() { + if rv.IsZero() { + if traceEnabled { + zlog.Debug("encode: skipping optional value with", zap.Stringer("type", rv.Kind())) + } + return e.WriteOption(false) + } + err := e.WriteOption(true) + if err != nil { + return err + } + // The optionality has been used; stop its propagation: + opt.is_OptionalField = false + } + if opt.is_COptional() { + if rv.IsZero() { + if traceEnabled { + zlog.Debug("encode: skipping optional value with", zap.Stringer("type", rv.Kind())) + } + return e.WriteCOption(false) + } + err := e.WriteCOption(true) + if err != nil { + return err + } + // The optionality has been used; stop its propagation: + opt.is_COptionalField = false + } + // Reset optionality so it won't propagate to child types. opt is a value + // copy so we mutate locally without affecting the caller. + opt.is_OptionalField = false + opt.is_COptionalField = false + + if isInvalidValue(rv) { + return nil + } + + if marshaler, ok := asBinaryMarshaler(rv); ok { + if rv.Kind() == reflect.Ptr && rv.IsZero() { + return nil + } + if traceEnabled { + zlog.Debug("encode: using MarshalerBinary method to encode type") + } + return marshaler.MarshalWithEncoder(e) + } + + // Encode the value if it's a primitive type + isPrimitive, err := e.encodePrimitive(rv) + if isPrimitive { + return err + } + + switch rv.Kind() { + case reflect.Ptr: + if rv.IsNil() { + el := reflect.New(rv.Type().Elem()).Elem() + return e.encodeBorsh(el, opt) + } else { + return e.encodeBorsh(rv.Elem(), opt) + } + case reflect.Interface: + // skip + return nil + } + + if !rv.IsZero() && !reflect.Indirect(rv).IsZero() { + rv = reflect.Indirect(rv) + } + rt := rv.Type() + switch rt.Kind() { + case reflect.Array: + l := rt.Len() + if traceEnabled { + defer func(prev *zap.Logger) { zlog = prev }(zlog) + zlog = zlog.Named("array") + zlog.Debug("encode: array", zap.Int("length", l), zap.Stringer("type", rv.Kind())) + } + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // if it's a [n]byte, accumulate and write in one command: + if err := reflect_writeArrayOfUint_(e, l, k, rv, LE); err != nil { + return err + } + default: + for i := range l { + if err = e.encodeBorsh(rv.Index(i), opt); err != nil { + return + } + } + } + case reflect.Slice: + var l int + if opt.hasSizeOfSlice() { + l = opt.getSizeOfSlice() + if traceEnabled { + zlog.Debug("encode: slice with sizeof set", zap.Int("size_of", l)) + } + } else { + l = rv.Len() + if err = e.WriteUint32(uint32(l), LE); err != nil { + return + } + } + if traceEnabled { + defer func(prev *zap.Logger) { zlog = prev }(zlog) + zlog = zlog.Named("slice") + zlog.Debug("encode: slice", zap.Int("length", l), zap.Stringer("type", rv.Kind())) + } + + // we would want to skip to the correct head_offset + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // if it's a [n]byte, accumulate and write in one command: + if err := reflect_writeArrayOfUint_(e, l, k, rv, LE); err != nil { + return err + } + default: + for i := range l { + if err = e.encodeBorsh(rv.Index(i), opt); err != nil { + return + } + } + } + + case reflect.Struct: + if err = e.encodeStructBorsh(rt, rv); err != nil { + return + } + + case reflect.Map: + keys := rv.MapKeys() + sort.Slice(keys, vComp(keys)) + + keyCount := rv.Len() + if traceEnabled { + zlog.Debug("encode: map", + zap.Int("key_count", keyCount), + zap.String("key_type", rt.String()), + typeField("value_type", rv), + ) + defer func(prev *zap.Logger) { zlog = prev }(zlog) + zlog = zlog.Named("struct") + } + + if err = e.WriteUint32(uint32(keyCount), LE); err != nil { + return + } + + for _, mapKey := range keys { + if err = e.Encode(mapKey.Interface()); err != nil { + return + } + + if err = e.Encode(rv.MapIndex(mapKey).Interface()); err != nil { + return + } + } + // TODO: + // case reflect.Ptr: + // if rv.IsNil() { + // } else { + // return e.encodeBorsh(rv.Elem(), opt) + // } + default: + return fmt.Errorf("encode: unsupported type %q", rt) + } + return +} + +func (enc *Encoder) encodeComplexEnumBorsh(rv reflect.Value) error { + t := rv.Type() + enum := BorshEnum(rv.Field(0).Uint()) + // write enum identifier + if err := enc.WriteByte(byte(enum)); err != nil { + return err + } + // write enum field, if necessary + if int(enum)+1 >= t.NumField() { + return errors.New("complex enum too large") + } + // Enum is empty + field := rv.Field(int(enum) + 1) + if field.Kind() == reflect.Ptr { + field = field.Elem() + } + if field.Kind() == reflect.Struct { + return enc.encodeStructBorsh(field.Type(), field) + } + // Encode the value if it's a primitive type + isPrimitive, err := enc.encodePrimitive(field) + if isPrimitive { + return err + } + return nil +} + +type BorshEnum uint8 + +// EmptyVariant is an empty borsh enum variant. +type EmptyVariant struct{} + +func (*EmptyVariant) MarshalWithEncoder(_ *Encoder) error { + return nil +} + +func (*EmptyVariant) UnmarshalWithDecoder(_ *Decoder) error { + return nil +} + +func (e *Encoder) encodeStructBorsh(rt reflect.Type, rv reflect.Value) (err error) { + plan := planForStruct(rt) + + if traceEnabled { + zlog.Debug("encode: struct", zap.Int("fields", len(plan.fields)), zap.Stringer("type", rv.Kind())) + } + + if plan.isComplexEnum { + return e.encodeComplexEnumBorsh(rv) + } + + // The fast primitive encode closures use unsafe.Pointer(fv.UnsafeAddr()) + // which requires the struct to be addressable. When the caller passed the + // struct by value (e.g. enc.Encode(myStruct)) we fall back to the generic + // reflect-based dispatch. + fastOK := rv.CanAddr() + + var sizes []int + if plan.hasSizeOf { + var stack sizesScratch + if len(plan.fields) <= sizesScratchLen { + sizes = stack[:len(plan.fields)] + } else { + sizes = make([]int, len(plan.fields)) + } + for i := range sizes { + sizes[i] = -1 + } + } + + for i := range plan.fields { + fp := &plan.fields[i] + + if fp.skip { + if traceEnabled { + zlog.Debug("encode: skipping struct field with skip flag", + zap.String("struct_field_name", fp.name), + ) + } + continue + } + + // Fast primitive path: no option construction, no kind switch. + if fastOK && fp.borshFastEncode != nil { + if err := fp.borshFastEncode(e, rv.Field(i)); err != nil { + return fmt.Errorf("error while encoding %q field: %w", fp.name, err) + } + continue + } + + fv := rv.Field(i) + + if fp.sizeOfTargetIdx >= 0 && sizes != nil { + sizes[fp.sizeOfTargetIdx] = sizeof(fp.fieldType, fv) + } + + if !fp.canInterface { + if traceEnabled { + zlog.Debug("encode: skipping field: unable to interface field, probably since field is not exported", + zap.String("struct_field_name", fp.name), + ) + } + continue + } + + opt := option{ + is_OptionalField: fp.tag.Option, + is_COptionalField: fp.tag.COption, + Order: fp.tag.Order, + } + + if sizes != nil && fp.sizeFromIdx >= 0 && sizes[i] >= 0 { + opt.sliceSizeIsSet = true + opt.sliceSize = sizes[i] + } + + if traceEnabled { + zlog.Debug("encode: struct field", + zap.Stringer("struct_field_value_type", fv.Kind()), + zap.String("struct_field_name", fp.name), + zap.Reflect("struct_field_tags", fp.tag), + zap.Reflect("struct_field_option", opt), + ) + } + + if err := e.encodeBorsh(fv, opt); err != nil { + return fmt.Errorf("error while encoding %q field: %w", fp.name, err) + } + } + return nil +} + +func vComp(keys []reflect.Value) func(int, int) bool { + return func(i int, j int) bool { + a, b := keys[i], keys[j] + if a.Kind() == reflect.Interface { + a = a.Elem() + b = b.Elem() + } + switch a.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: + return a.Int() < b.Int() + case reflect.Int64: + return a.Interface().(int64) < b.Interface().(int64) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + return a.Uint() < b.Uint() + case reflect.Uint64: + return a.Interface().(uint64) < b.Interface().(uint64) + case reflect.Float32, reflect.Float64: + return a.Float() < b.Float() + case reflect.String: + return a.String() < b.String() + } + panic("unsupported key compare") + } +} diff --git a/binary/encoder_compact-u16.go b/binary/encoder_compact-u16.go new file mode 100644 index 000000000..dcf396198 --- /dev/null +++ b/binary/encoder_compact-u16.go @@ -0,0 +1,274 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "fmt" + "reflect" + + "go.uber.org/zap" +) + +func (e *Encoder) encodeCompactU16(rv reflect.Value, opt option) (err error) { + if opt.Order == nil { + opt.Order = defaultByteOrder + } + e.currentFieldOpt = opt + + if traceEnabled { + zlog.Debug("encode: type", + zap.Stringer("value_kind", rv.Kind()), + zap.Reflect("options", opt), + ) + } + + if opt.is_Optional() { + if rv.IsZero() { + if traceEnabled { + zlog.Debug("encode: skipping optional value with", zap.Stringer("type", rv.Kind())) + } + return e.WriteBool(false) + } + err := e.WriteBool(true) + if err != nil { + return err + } + // The optionality has been used; stop its propagation: + opt.is_OptionalField = false + } + + if isInvalidValue(rv) { + return nil + } + + if marshaler, ok := asBinaryMarshaler(rv); ok { + if traceEnabled { + zlog.Debug("encode: using MarshalerBinary method to encode type") + } + return marshaler.MarshalWithEncoder(e) + } + + switch rv.Kind() { + case reflect.String: + return e.WriteString(rv.String()) + case reflect.Uint8: + return e.WriteByte(byte(rv.Uint())) + case reflect.Int8: + return e.WriteByte(byte(rv.Int())) + case reflect.Int16: + return e.WriteInt16(int16(rv.Int()), opt.Order) + case reflect.Uint16: + return e.WriteUint16(uint16(rv.Uint()), opt.Order) + case reflect.Int32: + return e.WriteInt32(int32(rv.Int()), opt.Order) + case reflect.Uint32: + return e.WriteUint32(uint32(rv.Uint()), opt.Order) + case reflect.Uint64: + return e.WriteUint64(rv.Uint(), opt.Order) + case reflect.Int64: + return e.WriteInt64(rv.Int(), opt.Order) + case reflect.Float32: + return e.WriteFloat32(float32(rv.Float()), opt.Order) + case reflect.Float64: + return e.WriteFloat64(rv.Float(), opt.Order) + case reflect.Bool: + return e.WriteBool(rv.Bool()) + case reflect.Ptr: + return e.encodeCompactU16(rv.Elem(), opt) + case reflect.Interface: + // skip + return nil + } + + rv = reflect.Indirect(rv) + rt := rv.Type() + switch rt.Kind() { + case reflect.Array: + l := rt.Len() + if traceEnabled { + defer func(prev *zap.Logger) { zlog = prev }(zlog) + zlog = zlog.Named("array") + zlog.Debug("encode: array", zap.Int("length", l), zap.Stringer("type", rv.Kind())) + } + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // if it's a [n]byte, accumulate and write in one command: + if err := reflect_writeArrayOfUint_(e, l, k, rv, LE); err != nil { + return err + } + default: + for i := range l { + if err = e.encodeCompactU16(rv.Index(i), opt); err != nil { + return + } + } + } + case reflect.Slice: + var l int + if opt.hasSizeOfSlice() { + l = opt.getSizeOfSlice() + if traceEnabled { + zlog.Debug("encode: slice with sizeof set", zap.Int("size_of", l)) + } + } else { + l = rv.Len() + if err = e.WriteCompactU16Length(l); err != nil { + return + } + } + if traceEnabled { + defer func(prev *zap.Logger) { zlog = prev }(zlog) + zlog = zlog.Named("slice") + zlog.Debug("encode: slice", zap.Int("length", l), zap.Stringer("type", rv.Kind())) + } + + // we would want to skip to the correct head_offset + + switch k := rv.Type().Elem().Kind(); k { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // if it's a [n]byte, accumulate and write in one command: + if err := reflect_writeArrayOfUint_(e, l, k, rv, LE); err != nil { + return err + } + default: + for i := range l { + if err = e.encodeCompactU16(rv.Index(i), opt); err != nil { + return + } + } + } + case reflect.Struct: + if err = e.encodeStructCompactU16(rt, rv); err != nil { + return + } + + case reflect.Map: + keyCount := len(rv.MapKeys()) + + if traceEnabled { + zlog.Debug("encode: map", + zap.Int("key_count", keyCount), + zap.String("key_type", rt.String()), + typeField("value_type", rv.Elem()), + ) + defer func(prev *zap.Logger) { zlog = prev }(zlog) + zlog = zlog.Named("struct") + } + + if err = e.WriteCompactU16Length(keyCount); err != nil { + return + } + + for _, mapKey := range rv.MapKeys() { + if err = e.Encode(mapKey.Interface()); err != nil { + return + } + + if err = e.Encode(rv.MapIndex(mapKey).Interface()); err != nil { + return + } + } + + default: + return fmt.Errorf("encode: unsupported type %q", rt) + } + return +} + +func (e *Encoder) encodeStructCompactU16(rt reflect.Type, rv reflect.Value) (err error) { + plan := planForStruct(rt) + + if traceEnabled { + zlog.Debug("encode: struct", zap.Int("fields", len(plan.fields)), zap.Stringer("type", rv.Kind())) + } + + var sizes []int + if plan.hasSizeOf { + var stack sizesScratch + if len(plan.fields) <= sizesScratchLen { + sizes = stack[:len(plan.fields)] + } else { + sizes = make([]int, len(plan.fields)) + } + for i := range sizes { + sizes[i] = -1 + } + } + + fastOK := rv.CanAddr() + for i := range plan.fields { + fp := &plan.fields[i] + + if fp.skip { + if traceEnabled { + zlog.Debug("encode: skipping struct field with skip flag", + zap.String("struct_field_name", fp.name), + ) + } + continue + } + + // Fast primitive path: no option construction, no kind switch. + if fastOK && fp.binFastEncode != nil { + if err := fp.binFastEncode(e, rv.Field(i)); err != nil { + return fmt.Errorf("error while encoding %q field: %w", fp.name, err) + } + continue + } + + fv := rv.Field(i) + + if fp.sizeOfTargetIdx >= 0 && sizes != nil { + sizes[fp.sizeOfTargetIdx] = sizeof(fp.fieldType, fv) + } + + if !fp.canInterface { + if traceEnabled { + zlog.Debug("encode: skipping field: unable to interface field, probably since field is not exported", + zap.String("struct_field_name", fp.name), + ) + } + continue + } + + opt := option{ + is_OptionalField: fp.tag.Option, + Order: fp.tag.Order, + } + + if sizes != nil && fp.sizeFromIdx >= 0 && sizes[i] >= 0 { + opt.sliceSizeIsSet = true + opt.sliceSize = sizes[i] + } + + if traceEnabled { + zlog.Debug("encode: struct field", + zap.Stringer("struct_field_value_type", fv.Kind()), + zap.String("struct_field_name", fp.name), + zap.Reflect("struct_field_tags", fp.tag), + zap.Reflect("struct_field_option", opt), + ) + } + + if err := e.encodeCompactU16(fv, opt); err != nil { + return fmt.Errorf("error while encoding %q field: %w", fp.name, err) + } + } + return nil +} diff --git a/binary/encoder_test.go b/binary/encoder_test.go new file mode 100644 index 000000000..72bf96606 --- /dev/null +++ b/binary/encoder_test.go @@ -0,0 +1,1274 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "math" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncoder_Size(t *testing.T) { + { + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + assert.Equal(t, enc.Written(), 0) + enc.Encode(SafeString("hello")) + + assert.Equal(t, enc.Written(), 6) + enc.WriteBool(true) + assert.Equal(t, enc.Written(), 7) + } + { + buf := new(bytes.Buffer) + + enc := NewBorshEncoder(buf) + assert.Equal(t, enc.Written(), 0) + enc.WriteByte(123) + + assert.Equal(t, enc.Written(), 1) + enc.WriteBool(true) + assert.Equal(t, enc.Written(), 2) + } +} + +func TestEncoder_AliastTestType(t *testing.T) { + buf := new(bytes.Buffer) + enc := NewBinEncoder(buf) + enc.Encode(aliasTestType(23)) + + assert.Equal(t, []byte{ + 0x17, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0, 0x0, + }, buf.Bytes()) +} + +func TestEncoder_safeString(t *testing.T) { + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.Encode(SafeString("hello")) + + assert.Equal(t, []byte{ + 0x5, 0x68, 0x65, 0x6c, 0x6c, 0x6f, + }, buf.Bytes()) +} + +func TestEncoder_int8(t *testing.T) { + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + v := int8(-99) + enc.WriteByte(byte(v)) + enc.WriteByte(byte(int8(100))) + + assert.Equal(t, []byte{ + 0x9d, // -99 + 0x64, // 100 + }, buf.Bytes()) +} + +func TestEncoder_int16(t *testing.T) { + // little endian + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteInt16(int16(-82), LE) + enc.WriteInt16(int16(73), LE) + + assert.Equal(t, []byte{ + 0xae, 0xff, // -82 + 0x49, 0x00, // 73 + }, buf.Bytes()) + + // big endian + buf = new(bytes.Buffer) + + enc = NewBinEncoder(buf) + enc.WriteInt16(int16(-82), BE) + enc.WriteInt16(int16(73), BE) + + assert.Equal(t, []byte{ + 0xff, 0xae, // -82 + 0x00, 0x49, // 73 + }, buf.Bytes()) +} + +func TestEncoder_int32(t *testing.T) { + // little endian + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteInt32(int32(-276132392), LE) + enc.WriteInt32(int32(237391), LE) + + assert.Equal(t, []byte{ + 0xd8, 0x8d, 0x8a, 0xef, + 0x4f, 0x9f, 0x3, 0x00, + }, buf.Bytes()) + + // big endian + buf = new(bytes.Buffer) + + enc = NewBinEncoder(buf) + enc.WriteInt32(int32(-276132392), BE) + enc.WriteInt32(int32(237391), BE) + + assert.Equal(t, []byte{ + 0xef, 0x8a, 0x8d, 0xd8, + 0x00, 0x3, 0x9f, 0x4f, + }, buf.Bytes()) +} + +func TestEncoder_int64(t *testing.T) { + // little endian + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteInt64(int64(-819823), LE) + enc.WriteInt64(int64(72931), LE) + + assert.Equal(t, []byte{ + 0x91, 0x7d, 0xf3, 0xff, 0xff, 0xff, 0xff, 0xff, //-819823 + 0xe3, 0x1c, 0x1, 0x00, 0x00, 0x00, 0x00, 0x00, // 72931 + }, buf.Bytes()) + + // big endian + buf = new(bytes.Buffer) + + enc = NewBinEncoder(buf) + enc.WriteInt64(int64(-819823), BE) + enc.WriteInt64(int64(72931), BE) + + assert.Equal(t, []byte{ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xf3, 0x7d, 0x91, //-819823 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x1, 0x1c, 0xe3, // 72931 + }, buf.Bytes()) +} + +func TestEncoder_uint8(t *testing.T) { + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteByte(uint8(99)) + enc.WriteByte(uint8(100)) + + assert.Equal(t, []byte{ + 0x63, // 99 + 0x64, // 100 + }, buf.Bytes()) +} + +func TestEncoder_uint16(t *testing.T) { + // little endian + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteUint16(uint16(82), LE) + enc.WriteUint16(uint16(73), LE) + + assert.Equal(t, []byte{ + 0x52, 0x00, // 82 + 0x49, 0x00, // 73 + }, buf.Bytes()) + + // big endian + buf = new(bytes.Buffer) + + enc = NewBinEncoder(buf) + enc.WriteUint16(uint16(82), BE) + enc.WriteUint16(uint16(73), BE) + + assert.Equal(t, []byte{ + 0x00, 0x52, // 82 + 0x00, 0x49, // 73 + }, buf.Bytes()) +} + +func TestEncoder_uint32(t *testing.T) { + // little endian + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteUint32(uint32(276132392), LE) + enc.WriteUint32(uint32(237391), LE) + + assert.Equal(t, []byte{ + 0x28, 0x72, 0x75, 0x10, // 276132392 as LE + 0x4f, 0x9f, 0x03, 0x00, // 237391 as LE + }, buf.Bytes()) + + // big endian + buf = new(bytes.Buffer) + + enc = NewBinEncoder(buf) + enc.WriteUint32(uint32(276132392), BE) + enc.WriteUint32(uint32(237391), BE) + + assert.Equal(t, []byte{ + 0x10, 0x75, 0x72, 0x28, // 276132392 as LE + 0x00, 0x03, 0x9f, 0x4f, // 237391 as LE + }, buf.Bytes()) +} + +func TestEncoder_uint64(t *testing.T) { + // little endian + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteUint64(uint64(819823), LE) + enc.WriteUint64(uint64(72931), LE) + + assert.Equal(t, []byte{ + 0x6f, 0x82, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, // 819823 + 0xe3, 0x1c, 0x1, 0x00, 0x00, 0x00, 0x00, 0x00, // 72931 + }, buf.Bytes()) + + // big endian + buf = new(bytes.Buffer) + + enc = NewBinEncoder(buf) + enc.WriteUint64(uint64(819823), BE) + enc.WriteUint64(uint64(72931), BE) + + assert.Equal(t, []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x82, 0x6f, // 819823 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x1, 0x1c, 0xe3, // 72931 + }, buf.Bytes()) +} + +func TestEncoder_float32(t *testing.T) { + // little endian + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteFloat32(float32(1.32), LE) + enc.WriteFloat32(float32(-3.21), LE) + + assert.Equal(t, []byte{ + 0xc3, 0xf5, 0xa8, 0x3f, + 0xa4, 0x70, 0x4d, 0xc0, + }, buf.Bytes()) + + // big endian + buf = new(bytes.Buffer) + + enc = NewBinEncoder(buf) + enc.WriteFloat32(float32(1.32), BE) + enc.WriteFloat32(float32(-3.21), BE) + assert.Equal(t, []byte{ + 0x3f, 0xa8, 0xf5, 0xc3, + 0xc0, 0x4d, 0x70, 0xa4, + }, buf.Bytes()) +} + +func TestEncoder_float64(t *testing.T) { + // little endian + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteFloat64(float64(-62.23), LE) + enc.WriteFloat64(float64(23.239), LE) + enc.WriteFloat64(float64(math.Inf(1)), LE) + enc.WriteFloat64(float64(math.Inf(-1)), LE) + + assert.Equal(t, []byte{ + 0x3d, 0x0a, 0xd7, 0xa3, 0x70, 0x1d, 0x4f, 0xc0, + 0x77, 0xbe, 0x9f, 0x1a, 0x2f, 0x3d, 0x37, 0x40, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x7f, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0xff, + }, buf.Bytes()) + + // big endian + buf = new(bytes.Buffer) + + enc = NewBinEncoder(buf) + enc.WriteFloat64(float64(-62.23), BE) + enc.WriteFloat64(float64(23.239), BE) + enc.WriteFloat64(float64(math.Inf(1)), BE) + enc.WriteFloat64(float64(math.Inf(-1)), BE) + + assert.Equal(t, []byte{ + 0xc0, 0x4f, 0x1d, 0x70, 0xa3, 0xd7, 0x0a, 0x3d, + 0x40, 0x37, 0x3d, 0x2f, 0x1a, 0x9f, 0xbe, 0x77, + 0x7f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xff, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, buf.Bytes()) +} + +func TestEncoder_string(t *testing.T) { + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteString("123") + enc.WriteString("") + enc.WriteString("abc") + + assert.Equal(t, []byte{ + 0x03, 0x31, 0x32, 0x33, // "123" + 0x00, // "" + 0x03, 0x61, 0x62, 0x63, // "abc + }, buf.Bytes()) +} + +func TestEncoder_byte(t *testing.T) { + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteByte(0) + enc.WriteByte(1) + + assert.Equal(t, []byte{ + 0x00, 0x01, + }, buf.Bytes()) +} + +func TestEncoder_bool(t *testing.T) { + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteBool(true) + enc.WriteBool(false) + + assert.Equal(t, []byte{ + 0x01, 0x00, + }, buf.Bytes()) +} + +func TestEncoder_ByteArray(t *testing.T) { + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteBytes([]byte{1, 2, 3}, true) + enc.WriteBytes([]byte{4, 5, 6}, true) + enc.WriteBytes([]byte{7, 8}, false) + + assert.Equal(t, []byte{ + 0x03, 0x01, 0x02, 0x03, + 0x03, 0x04, 0x05, 0x06, + 0x07, 0x08, + }, buf.Bytes()) + + bufB := new(bytes.Buffer) + + enc = NewBinEncoder(bufB) + enc.Encode([]byte{1, 2, 3}) + + assert.Equal(t, []byte{ + 0x03, 0x01, 0x02, 0x03, + }, bufB.Bytes()) +} + +func TestEncode_Array(t *testing.T) { + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.Encode([3]byte{1, 2, 4}) + + assert.Equal(t, + []byte{1, 2, 4}, + buf.Bytes(), + ) +} + +func Test_OptionalPointerToPrimitiveType(t *testing.T) { + type test struct { + ID *Uint64 `bin:"optional"` + } + + expect := []byte{0x00, 0x00, 0x00, 0x00} + + out, err := MarshalBin(test{ID: nil}) + require.NoError(t, err) + assert.Equal(t, expect, out) + + id := Uint64(0) + out, err = MarshalBin(test{ID: &id}) + require.NoError(t, err) + assert.Equal(t, []byte{0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, out) + + id = Uint64(10) + out, err = MarshalBin(test{ID: &id}) + require.NoError(t, err) + + assert.Equal(t, []byte{0x1, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, out) +} + +func TestEncoder_Uint128(t *testing.T) { + // little endian + u := Uint128{ + Lo: 7, + Hi: 9, + } + + buf := new(bytes.Buffer) + + enc := NewBinEncoder(buf) + enc.WriteUint128(u, LE) + + assert.Equal(t, []byte{ + 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, buf.Bytes()) + + // big endian + buf = new(bytes.Buffer) + + enc = NewBinEncoder(buf) + enc.WriteUint128(u, BE) + + assert.Equal(t, []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, + }, buf.Bytes()) +} + +func TestEncoder_BinaryStruct(t *testing.T) { + s := &binaryTestStruct{ + F1: "abc", + F2: -75, + F3: 99, + F4: -231, + F5: 999, + F6: -13231, + F7: 99999, + F8: -23.13, + F9: 3.92, + F10: []string{"def", "789"}, + F11: [2]string{"foo", "bar"}, + F12: 0xff, + F13: []byte{1, 2, 3, 4, 5}, + F14: true, + F15: Int64(-23), + F16: Uint64(23), + F17: JSONFloat64(3.14), + F18: Uint128{ + Lo: 10, + Hi: 82, + }, + F19: Int128{ + Lo: 7, + Hi: 3, + }, + F20: Float128{ + Lo: 10, + Hi: 82, + }, + F21: Varuint32(999), + F22: Varint32(-999), + F23: Bool(true), + F24: HexBytes([]byte{1, 2, 3, 4, 5}), + } + + buf := new(bytes.Buffer) + enc := NewBinEncoder(buf) + err := enc.Encode(s) + assert.NoError(t, err) + + assert.Equal(t, + "0300000000000000616263b5ff630019ffffffe703000051ccffffffffffff9f860100000000003d0ab9c15c8fc2f5285c0f4002030000000000000064656603000000000000003738390300000000000000666f6f0300000000000000626172ff05010203040501e9ffffffffffffff17000000000000001f85eb51b81e09400a000000000000005200000000000000070000000000000003000000000000000a000000000000005200000000000000e707cd0f01050102030405", + hex.EncodeToString(buf.Bytes()), + ) +} + +func TestEncoder_BinaryTestStructWithTags(t *testing.T) { + s := &binaryTestStructWithTags{ + F1: "abc", + F2: -75, + F3: 99, + F4: -231, + F5: 999, + F6: -13231, + F7: 99999, + F8: -23.13, + F9: 3.92, + F10: true, + F12: (*[]int64)(&[]int64{99, 33}), + } + + expected := []byte{ + 255, 181, // F2 + 0, 99, // F3 + 255, 255, 255, 25, // F4 + 0, 0, 3, 231, // F5 + 255, 255, 255, 255, 255, 255, 204, 81, // F6 + 0, 0, 0, 0, 0, 1, 134, 159, // F7 + 193, 185, 10, 61, // F8 + 64, 15, 92, 40, 245, 194, 143, 92, // F9 + 1, // F10 + + 0, 0, 0, 0, // F11 is optional, and NOT SET (meaning uint32(0)) + + 1, 0, 0, 0, // F12 is optional, and IS SET (meaning uint32(1)) + 2, // F12 is a slice, and the len is encoded as WriteUVarInt) + 99, 0, 0, 0, 0, 0, 0, 0, + 33, 0, 0, 0, 0, 0, 0, 0, + } + { + buf := new(bytes.Buffer) + enc := NewBinEncoder(buf) + { + err := enc.WriteInt16(s.F2, binary.BigEndian) // [255, 181](len=2) + if err != nil { + panic(err) + } + } + { + err := enc.WriteUint16(s.F3, binary.BigEndian) // [0, 99](len=2) + if err != nil { + panic(err) + } + } + { + err := enc.WriteInt32(s.F4, binary.BigEndian) // [255, 255, 255, 25](len=4) + if err != nil { + panic(err) + } + } + { + err := enc.WriteUint32(s.F5, binary.BigEndian) // [0, 0, 3, 231](len=4) + if err != nil { + panic(err) + } + } + { + err := enc.WriteInt64(s.F6, binary.BigEndian) // [255, 255, 255, 255, 255, 255, 204, 81](len=8) + if err != nil { + panic(err) + } + } + { + err := enc.WriteUint64(s.F7, binary.BigEndian) // [0, 0, 0, 0, 0, 1, 134, 159](len=8) + if err != nil { + panic(err) + } + } + { + err := enc.WriteFloat32(s.F8, binary.BigEndian) // [193, 185, 10, 61](len=4) + if err != nil { + panic(err) + } + } + { + err := enc.WriteFloat64(s.F9, binary.BigEndian) // [64, 15, 92, 40, 245, 194, 143, 92](len=8) + if err != nil { + panic(err) + } + } + { + err := enc.WriteBool(s.F10) // [1](len=1) + if err != nil { + panic(err) + } + } + { + err := enc.WriteUint32(0, binary.LittleEndian) // [0, 0, 0, 0](len=4) + if err != nil { + panic(err) + } + } + { + err := enc.WriteUint32(1, binary.LittleEndian) // [1, 0, 0, 0](len=4) + if err != nil { + panic(err) + } + } + { + err := enc.WriteUVarInt(2) // [2](len=1) + if err != nil { + panic(err) + } + } + { + err := enc.WriteInt64((*s.F12)[0], binary.LittleEndian) // [99, 0, 0, 0, 0, 0, 0, 0](len=8) + if err != nil { + panic(err) + } + } + { + err := enc.WriteInt64((*s.F12)[1], binary.LittleEndian) // [33, 0, 0, 0, 0, 0, 0, 0](len=8) + if err != nil { + panic(err) + } + } + + assert.Equal(t, + expected, + buf.Bytes(), + FormatByteSlice(buf.Bytes()), + ) + } + + buf := new(bytes.Buffer) + enc := NewBinEncoder(buf) + err := enc.Encode(s) + assert.NoError(t, err) + + assert.Equal(t, + expected, + buf.Bytes(), + FormatByteSlice(buf.Bytes()), + ) +} + +func TestEncoder_InterfaceNil(t *testing.T) { + var foo interface{} + foo = nil + buf := new(bytes.Buffer) + enc := NewBinEncoder(buf) + err := enc.Encode(foo) + assert.NoError(t, err) +} + +func TestByteArrays(t *testing.T) { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([3]byte{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, []byte{1, 2, 3}, buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([3]byte{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, []byte{1, 2, 3}, buf.Bytes()) + } +} + +func TestUintArrays(t *testing.T) { + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([3]uint8{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, []byte{1, 2, 3}, buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([3]uint8{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, []byte{1, 2, 3}, buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([3]uint16{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([3]uint16{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([3]uint32{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([3]uint32{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([3]uint64{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([3]uint64{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), buf.Bytes()) + } + } +} + +func TestUintSlices(t *testing.T) { + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([]uint8{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + // length: + []byte{3}, + // data: + []byte{1, 2, 3}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([]uint8{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + // length: + []byte{3, 0, 0, 0}, + // data: + []byte{1, 2, 3}, + ), buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([]uint16{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + // length: + []byte{3}, + // data: + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([]uint16{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + // length: + []byte{3, 0, 0, 0}, + // data: + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([]uint32{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + // length: + []byte{3}, + // data: + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([]uint32{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + // length: + []byte{3, 0, 0, 0}, + // data: + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + err := enc.Encode([]uint64{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + // length: + []byte{3}, + // data: + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + err := enc.Encode([]uint64{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + // length: + []byte{3, 0, 0, 0}, + // data: + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), buf.Bytes()) + } + } +} + +func Test_writeArrayOfBytes(t *testing.T) { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [3]byte{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfBytes(enc, l, reflect.ValueOf(arr)) + assert.NoError(t, err) + assert.Equal(t, arr[:], buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [10]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + l := len(arr) + + err := reflect_writeArrayOfBytes(enc, l, reflect.ValueOf(arr)) + assert.NoError(t, err) + assert.Equal(t, arr[:], buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} + l := len(arr) + + err := reflect_writeArrayOfBytes(enc, l, reflect.ValueOf(arr)) + assert.NoError(t, err) + assert.Equal(t, arr[:], buf.Bytes()) + } +} + +func Test_writeArrayOfUint16(t *testing.T) { + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [3]uint16{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint16(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + + arr := [3]uint16{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint16(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := []uint16{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint16(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + + arr := []uint16{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint16(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, concatByteSlices( + []byte{1, 0, 2, 0, 3, 0}, + ), buf.Bytes()) + } + } +} + +func Test_writeArrayOfUint32(t *testing.T) { + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [3]uint32{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint32(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + + arr := [3]uint32{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint32(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + } + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := []uint32{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint32(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + + arr := []uint32{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint32(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + } + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [10]uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + l := len(arr) + + err := reflect_writeArrayOfUint32(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + []byte{4, 0, 0, 0}, + []byte{5, 0, 0, 0}, + []byte{6, 0, 0, 0}, + []byte{7, 0, 0, 0}, + []byte{8, 0, 0, 0}, + []byte{9, 0, 0, 0}, + []byte{10, 0, 0, 0}, + ), + buf.Bytes()) + } + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [32]uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} + l := len(arr) + + err := reflect_writeArrayOfUint32(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0}, + []byte{2, 0, 0, 0}, + []byte{3, 0, 0, 0}, + []byte{4, 0, 0, 0}, + []byte{5, 0, 0, 0}, + []byte{6, 0, 0, 0}, + []byte{7, 0, 0, 0}, + []byte{8, 0, 0, 0}, + []byte{9, 0, 0, 0}, + []byte{10, 0, 0, 0}, + []byte{11, 0, 0, 0}, + []byte{12, 0, 0, 0}, + []byte{13, 0, 0, 0}, + []byte{14, 0, 0, 0}, + []byte{15, 0, 0, 0}, + []byte{16, 0, 0, 0}, + []byte{17, 0, 0, 0}, + []byte{18, 0, 0, 0}, + []byte{19, 0, 0, 0}, + []byte{20, 0, 0, 0}, + []byte{21, 0, 0, 0}, + []byte{22, 0, 0, 0}, + []byte{23, 0, 0, 0}, + []byte{24, 0, 0, 0}, + []byte{25, 0, 0, 0}, + []byte{26, 0, 0, 0}, + []byte{27, 0, 0, 0}, + []byte{28, 0, 0, 0}, + []byte{29, 0, 0, 0}, + []byte{30, 0, 0, 0}, + []byte{31, 0, 0, 0}, + []byte{32, 0, 0, 0}, + ), + buf.Bytes()) + } +} + +func Test_writeArrayOfUint64(t *testing.T) { + { + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := [3]uint64{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint64(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + { + var buf bytes.Buffer + enc := NewBinEncoder(&buf) + + arr := []uint64{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint64(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + } + { + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + + arr := [3]uint64{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint64(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + { + var buf bytes.Buffer + enc := NewBorshEncoder(&buf) + + arr := []uint64{1, 2, 3} + l := len(arr) + + err := reflect_writeArrayOfUint64(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + ), + buf.Bytes(), + ) + } + } + { + var buf bytes.Buffer + + enc := NewBinEncoder(&buf) + arr := [64]uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64} + l := len(arr) + + err := reflect_writeArrayOfUint64(enc, l, reflect.ValueOf(arr), LE) + assert.NoError(t, err) + + assert.Equal(t, + concatByteSlices( + []byte{1, 0, 0, 0, 0, 0, 0, 0}, + []byte{2, 0, 0, 0, 0, 0, 0, 0}, + []byte{3, 0, 0, 0, 0, 0, 0, 0}, + []byte{4, 0, 0, 0, 0, 0, 0, 0}, + []byte{5, 0, 0, 0, 0, 0, 0, 0}, + []byte{6, 0, 0, 0, 0, 0, 0, 0}, + []byte{7, 0, 0, 0, 0, 0, 0, 0}, + []byte{8, 0, 0, 0, 0, 0, 0, 0}, + []byte{9, 0, 0, 0, 0, 0, 0, 0}, + []byte{10, 0, 0, 0, 0, 0, 0, 0}, + []byte{11, 0, 0, 0, 0, 0, 0, 0}, + []byte{12, 0, 0, 0, 0, 0, 0, 0}, + []byte{13, 0, 0, 0, 0, 0, 0, 0}, + []byte{14, 0, 0, 0, 0, 0, 0, 0}, + []byte{15, 0, 0, 0, 0, 0, 0, 0}, + []byte{16, 0, 0, 0, 0, 0, 0, 0}, + []byte{17, 0, 0, 0, 0, 0, 0, 0}, + []byte{18, 0, 0, 0, 0, 0, 0, 0}, + []byte{19, 0, 0, 0, 0, 0, 0, 0}, + []byte{20, 0, 0, 0, 0, 0, 0, 0}, + []byte{21, 0, 0, 0, 0, 0, 0, 0}, + []byte{22, 0, 0, 0, 0, 0, 0, 0}, + []byte{23, 0, 0, 0, 0, 0, 0, 0}, + []byte{24, 0, 0, 0, 0, 0, 0, 0}, + []byte{25, 0, 0, 0, 0, 0, 0, 0}, + []byte{26, 0, 0, 0, 0, 0, 0, 0}, + []byte{27, 0, 0, 0, 0, 0, 0, 0}, + []byte{28, 0, 0, 0, 0, 0, 0, 0}, + []byte{29, 0, 0, 0, 0, 0, 0, 0}, + []byte{30, 0, 0, 0, 0, 0, 0, 0}, + []byte{31, 0, 0, 0, 0, 0, 0, 0}, + []byte{32, 0, 0, 0, 0, 0, 0, 0}, + []byte{33, 0, 0, 0, 0, 0, 0, 0}, + []byte{34, 0, 0, 0, 0, 0, 0, 0}, + []byte{35, 0, 0, 0, 0, 0, 0, 0}, + []byte{36, 0, 0, 0, 0, 0, 0, 0}, + []byte{37, 0, 0, 0, 0, 0, 0, 0}, + []byte{38, 0, 0, 0, 0, 0, 0, 0}, + []byte{39, 0, 0, 0, 0, 0, 0, 0}, + []byte{40, 0, 0, 0, 0, 0, 0, 0}, + []byte{41, 0, 0, 0, 0, 0, 0, 0}, + []byte{42, 0, 0, 0, 0, 0, 0, 0}, + []byte{43, 0, 0, 0, 0, 0, 0, 0}, + []byte{44, 0, 0, 0, 0, 0, 0, 0}, + []byte{45, 0, 0, 0, 0, 0, 0, 0}, + []byte{46, 0, 0, 0, 0, 0, 0, 0}, + []byte{47, 0, 0, 0, 0, 0, 0, 0}, + []byte{48, 0, 0, 0, 0, 0, 0, 0}, + []byte{49, 0, 0, 0, 0, 0, 0, 0}, + []byte{50, 0, 0, 0, 0, 0, 0, 0}, + []byte{51, 0, 0, 0, 0, 0, 0, 0}, + []byte{52, 0, 0, 0, 0, 0, 0, 0}, + []byte{53, 0, 0, 0, 0, 0, 0, 0}, + []byte{54, 0, 0, 0, 0, 0, 0, 0}, + []byte{55, 0, 0, 0, 0, 0, 0, 0}, + []byte{56, 0, 0, 0, 0, 0, 0, 0}, + []byte{57, 0, 0, 0, 0, 0, 0, 0}, + []byte{58, 0, 0, 0, 0, 0, 0, 0}, + []byte{59, 0, 0, 0, 0, 0, 0, 0}, + []byte{60, 0, 0, 0, 0, 0, 0, 0}, + []byte{61, 0, 0, 0, 0, 0, 0, 0}, + []byte{62, 0, 0, 0, 0, 0, 0, 0}, + []byte{63, 0, 0, 0, 0, 0, 0, 0}, + []byte{64, 0, 0, 0, 0, 0, 0, 0}, + ), + buf.Bytes()) + } +} diff --git a/binary/error.go b/binary/error.go new file mode 100644 index 000000000..40e448e0e --- /dev/null +++ b/binary/error.go @@ -0,0 +1,34 @@ +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import "reflect" + +// An InvalidDecoderError describes an invalid argument passed to Decoder. +// (The argument to Decoder must be a non-nil pointer.) +type InvalidDecoderError struct { + Type reflect.Type +} + +func (e *InvalidDecoderError) Error() string { + if e.Type == nil { + return "decoder: Decode(nil)" + } + + if e.Type.Kind() != reflect.Ptr { + return "decoder: Decode(non-pointer " + e.Type.String() + ")" + } + return "decoder: Decode(nil " + e.Type.String() + ")" +} diff --git a/binary/fuzz_test.go b/binary/fuzz_test.go new file mode 100644 index 000000000..28c32be4c --- /dev/null +++ b/binary/fuzz_test.go @@ -0,0 +1,114 @@ +package bin + +import ( + "encoding/binary" + "testing" +) + +// The binary package is used to decode on-chain Solana data, which is +// attacker-controlled. These fuzz targets exist to shake out panics, +// over-allocations, and infinite loops in the three top-level decoders when +// they're handed arbitrary bytes. A "pass" means: no panic, no OOM, no hang. +// Returning an error is fine — silently accepting malformed input is what +// we're looking for. + +type fuzzStruct struct { + A uint64 + B uint32 + C int16 + D bool + E string + F []byte + G [4]uint64 +} + +// seedForFuzz returns a buffer produced by the encoder for a known-valid +// value. Seeding the fuzzer with valid-looking bytes gives coverage-guided +// fuzzing a starting point close to the decoder's happy path. +func seedForFuzz(t testing.TB, enc Encoding) []byte { + t.Helper() + v := fuzzStruct{ + A: 0xdeadbeefcafebabe, + B: 0x01020304, + C: -1234, + D: true, + E: "hello", + F: []byte{1, 2, 3, 4, 5}, + G: [4]uint64{1, 2, 3, 4}, + } + var buf []byte + var err error + switch enc { + case EncodingBin: + buf, err = MarshalBin(&v) + case EncodingBorsh: + buf, err = MarshalBorsh(&v) + case EncodingCompactU16: + buf, err = MarshalCompactU16(&v) + } + if err != nil { + t.Fatalf("seed encode: %v", err) + } + return buf +} + +func FuzzDecodeBin(f *testing.F) { + f.Add(seedForFuzz(f, EncodingBin)) + // Extra small seeds to exercise length-prefix edge cases. + f.Add([]byte{}) + f.Add([]byte{0xff}) + f.Add([]byte{0xff, 0xff, 0xff, 0xff}) + f.Fuzz(func(t *testing.T, data []byte) { + var v fuzzStruct + _ = NewBinDecoder(data).Decode(&v) + }) +} + +func FuzzDecodeBorsh(f *testing.F) { + f.Add(seedForFuzz(f, EncodingBorsh)) + f.Add([]byte{}) + f.Add([]byte{0xff}) + f.Add([]byte{0xff, 0xff, 0xff, 0xff}) + f.Fuzz(func(t *testing.T, data []byte) { + var v fuzzStruct + _ = NewBorshDecoder(data).Decode(&v) + }) +} + +func FuzzDecodeCompactU16(f *testing.F) { + f.Add(seedForFuzz(f, EncodingCompactU16)) + f.Add([]byte{}) + f.Add([]byte{0xff}) + f.Add([]byte{0xff, 0xff, 0xff, 0xff}) + f.Fuzz(func(t *testing.T, data []byte) { + var v fuzzStruct + _ = NewCompactU16Decoder(data).Decode(&v) + }) +} + +// FuzzCompactU16Length targets the three compact-u16 length decoders for +// non-canonical encodings, overflow, and continuation-bit handling. +func FuzzCompactU16Length(f *testing.F) { + f.Add([]byte{0x00}) + f.Add([]byte{0x7f}) + f.Add([]byte{0x80, 0x01}) + f.Add([]byte{0xff, 0xff, 0x03}) + f.Add([]byte{0xff, 0xff, 0xff}) + f.Fuzz(func(t *testing.T, data []byte) { + _, _, _ = DecodeCompactU16(data) + }) +} + +// FuzzUint128JSON exercises the big.Int / hex JSON paths, which do their +// own length checks and base-conversion parsing. +func FuzzUint128JSON(f *testing.F) { + f.Add([]byte(`"0"`)) + f.Add([]byte(`"0x00000000000000000000000000000001"`)) + f.Add([]byte(`"123456789012345678901234567890"`)) + f.Add([]byte(`null`)) + f.Fuzz(func(t *testing.T, data []byte) { + var v Uint128 + v.Endianness = binary.LittleEndian + _ = v.UnmarshalJSON(data) + }) +} diff --git a/binary/heck.go b/binary/heck.go new file mode 100644 index 000000000..2839bccea --- /dev/null +++ b/binary/heck.go @@ -0,0 +1,164 @@ +package bin + +import ( + "strings" + "unicode" +) + +// Ported from https://github.com/withoutboats/heck +// https://github.com/withoutboats/heck/blob/master/LICENSE-APACHE +// https://github.com/withoutboats/heck/blob/master/LICENSE-MIT + +// ToPascalCase converts a string to upper camel case. +func ToPascalCase(s string) string { + return transform( + s, + capitalize, + func(*strings.Builder) {}, + ) +} + +// ToRustSnakeCase converts the given string to a snake_case string. +// Ported from https://github.com/withoutboats/heck/blob/c501fc95db91ce20eaef248a511caec7142208b4/src/lib.rs#L75 as used by Anchor. +func ToRustSnakeCase(s string) string { + return transform( + s, + func(w string, b *strings.Builder) { b.WriteString(strings.ToLower(w)) }, + func(b *strings.Builder) { b.WriteRune('_') }, + ) +} + +// ToSnakeForSighash is the Anchor-sighash-ready alias for ToRustSnakeCase. +func ToSnakeForSighash(s string) string { + return ToRustSnakeCase(s) +} + +// transform walks s token-by-token, emitting a boundary callback between +// tokens and feeding each token to withWord. The word-boundary rules match +// Rust's heck crate: underscores are separators, and case changes inside a +// word (e.g. "camelCase" -> "camel" + "Case", "HTTPServer" -> "HTTP" + +// "Server") create boundaries. +func transform( + s string, + withWord func(string, *strings.Builder), + boundary func(*strings.Builder), +) string { + var builder strings.Builder + firstWord := true + + for _, word := range splitIntoWords(s) { + runes := []rune(word) + init := 0 + mode := wordModeBoundary + + for i := 0; i < len(runes); i++ { + c := runes[i] + + // Skip leading underscores within a token. + if c == '_' { + if init == i { + init++ + } + continue + } + + if i+1 < len(runes) { + next := runes[i+1] + + // The mode including the current character, assuming the + // current character does not result in a word boundary. + nextMode := mode + switch { + case unicode.IsLower(c): + nextMode = wordModeLowercase + case unicode.IsUpper(c): + nextMode = wordModeUppercase + } + + // Word boundary after if next is underscore, or current is + // not uppercase and next is uppercase. + if next == '_' || (nextMode == wordModeLowercase && unicode.IsUpper(next)) { + if !firstWord { + boundary(&builder) + } + withWord(string(runes[init:i+1]), &builder) + firstWord = false + init = i + 1 + mode = wordModeBoundary + continue + } + + // Word boundary before if current and previous are uppercase + // and next is lowercase (XMLHttp -> XML + Http). + if mode == wordModeUppercase && unicode.IsUpper(c) && unicode.IsLower(next) { + if !firstWord { + boundary(&builder) + } else { + firstWord = false + } + withWord(string(runes[init:i]), &builder) + init = i + mode = wordModeBoundary + continue + } + + // Otherwise no word boundary, just update the mode. + mode = nextMode + continue + } + + // Last rune of the token: flush trailing characters as a word. + if !firstWord { + boundary(&builder) + } else { + firstWord = false + } + withWord(string(runes[init:]), &builder) + } + } + + return builder.String() +} + +func capitalize(s string, b *strings.Builder) { + if s == "" { + return + } + runes := []rune(s) + b.WriteString(strings.ToUpper(string(runes[0]))) + if len(runes) > 1 { + lowercase(string(runes[1:]), b) + } +} + +func lowercase(s string, b *strings.Builder) { + runes := []rune(s) + for i, c := range runes { + // Final sigma special-case per the heck crate. + if c == 'Σ' && i == len(runes)-1 { + b.WriteString("ς") + continue + } + b.WriteString(strings.ToLower(string(c))) + } +} + +func splitIntoWords(s string) []string { + return strings.FieldsFunc(s, func(r rune) bool { + return !(unicode.IsLetter(r) || unicode.IsDigit(r)) + }) +} + +type wordMode int + +const ( + // wordModeBoundary: no lowercase or uppercase characters yet in the + // current word. + wordModeBoundary wordMode = iota + // wordModeLowercase: the previous cased character in the current word is + // lowercase. + wordModeLowercase + // wordModeUppercase: the previous cased character in the current word is + // uppercase. + wordModeUppercase +) diff --git a/binary/heck_test.go b/binary/heck_test.go new file mode 100644 index 000000000..96bcb89ff --- /dev/null +++ b/binary/heck_test.go @@ -0,0 +1,78 @@ +package bin + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCamelCase(t *testing.T) { + type item struct { + input string + want string + } + tests := []item{ + // TODO: find out if need to fix, and if yes, then fix. + // {"1hello", "1Hello"}, // actual: `1hello` + // {"1Hello", "1Hello"}, // actual: `1hello` + // {"hello1world", "Hello1World"}, // actual: `Hello1world` + // {"mGridCol6@md", "MGridCol6md"}, // actual: `MGridCol6Md` + // {"A::a", "Aa"}, // actual: `AA` + // {"foìBar-baz", "FoìBarBaz"}, + // + {"hello1World", "Hello1World"}, + {"Hello1World", "Hello1World"}, + {"foo", "Foo"}, + {"foo-bar", "FooBar"}, + {"foo-bar-baz", "FooBarBaz"}, + {"foo--bar", "FooBar"}, + {"--foo-bar", "FooBar"}, + {"--foo--bar", "FooBar"}, + {"FOO-BAR", "FooBar"}, + {"FOÈ-BAR", "FoèBar"}, + {"-foo-bar-", "FooBar"}, + {"--foo--bar--", "FooBar"}, + {"foo-1", "Foo1"}, + {"foo.bar", "FooBar"}, + {"foo..bar", "FooBar"}, + {"..foo..bar..", "FooBar"}, + {"foo_bar", "FooBar"}, + {"__foo__bar__", "FooBar"}, + {"__foo__bar__", "FooBar"}, + {"foo bar", "FooBar"}, + {" foo bar ", "FooBar"}, + {"-", ""}, + {" - ", ""}, + {"fooBar", "FooBar"}, + {"fooBar-baz", "FooBarBaz"}, + {"fooBarBaz-bazzy", "FooBarBazBazzy"}, + {"FBBazzy", "FbBazzy"}, + {"F", "F"}, + {"FooBar", "FooBar"}, + {"Foo", "Foo"}, + {"FOO", "Foo"}, + {"--", ""}, + {"", ""}, + {"--__--_--_", ""}, + {"foo bar?", "FooBar"}, + {"foo bar!", "FooBar"}, + {"foo bar$", "FooBar"}, + {"foo-bar#", "FooBar"}, + {"XMLHttpRequest", "XmlHttpRequest"}, + {"AjaxXMLHttpRequest", "AjaxXmlHttpRequest"}, + {"Ajax-XMLHttpRequest", "AjaxXmlHttpRequest"}, + {"Hello11World", "Hello11World"}, + {"hello1", "Hello1"}, + {"Hello1", "Hello1"}, + {"h1W", "H1W"}, + // TODO: add support to non-alphanumeric characters (non-latin, non-ascii). + } + + for i := range tests { + test := tests[i] + t.Run(test.input, func(t *testing.T) { + t.Parallel() + assert.Equal(t, test.want, ToPascalCase(test.input)) + }) + } +} diff --git a/binary/init_test.go b/binary/init_test.go new file mode 100644 index 000000000..0a1044093 --- /dev/null +++ b/binary/init_test.go @@ -0,0 +1,70 @@ +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "testing" +) + +type aliasTestType uint64 + +type unknownType struct{} + +type binaryTestStruct struct { + F1 string + F2 int16 + F3 uint16 + F4 int32 + F5 uint32 + F6 int64 + F7 uint64 + F8 float32 + F9 float64 + F10 []string + F11 [2]string + F12 byte + F13 []byte + F14 bool + F15 Int64 + F16 Uint64 + F17 JSONFloat64 + F18 Uint128 + F19 Int128 + F20 Float128 + F21 Varuint32 + F22 Varint32 + F23 Bool + F24 HexBytes +} + +type binaryTestStructWithTags struct { + F1 string `bin:"-"` + F2 int16 `bin:"big"` + F3 uint16 `bin:"big"` + F4 int32 `bin:"big"` + F5 uint32 `bin:"big"` + F6 int64 `bin:"big"` + F7 uint64 `bin:"big"` + F8 float32 `bin:"big"` + F9 float64 `bin:"big"` + F10 bool + F11 *Int64 `bin:"optional"` + F12 *[]int64 `bin:"optional"` +} + +func setupBench(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() +} diff --git a/binary/interface.go b/binary/interface.go new file mode 100644 index 000000000..08b3e6ce3 --- /dev/null +++ b/binary/interface.go @@ -0,0 +1,227 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "fmt" + "sync" +) + +// marshalBufInitialCap is the initial capacity allocated for a pooled +// Encoder's internal buffer. Chosen to cover the typical Solana +// instruction payload size (~256 bytes fits the vast majority of +// variant-wrapped instructions and most small state structs) without +// requiring a grow step on the hot path. Larger payloads still grow +// via append. +const marshalBufInitialCap = 256 + +// Per-encoding Encoder pools. Pooled encoders are always in buffered +// mode (output == nil, writes append to e.buf). Marshal helpers Get a +// pooled encoder, encode into it, copy the result out (the pooled +// buffer can be reused by the next caller), Reset, and Put back. +var ( + binEncoderPool = sync.Pool{ + New: func() any { + return &Encoder{ + encoding: EncodingBin, + buf: make([]byte, 0, marshalBufInitialCap), + } + }, + } + borshEncoderPool = sync.Pool{ + New: func() any { + return &Encoder{ + encoding: EncodingBorsh, + buf: make([]byte, 0, marshalBufInitialCap), + } + }, + } + compactU16EncoderPool = sync.Pool{ + New: func() any { + return &Encoder{ + encoding: EncodingCompactU16, + buf: make([]byte, 0, marshalBufInitialCap), + } + }, + } + + // Per-encoding Decoder pools. Pooled decoders have their data/pos + // reset between uses via Decoder.Reset. The encoding field is set + // at pool-New time and preserved by Reset. + binDecoderPool = sync.Pool{ + New: func() any { + return &Decoder{encoding: EncodingBin} + }, + } + borshDecoderPool = sync.Pool{ + New: func() any { + return &Decoder{encoding: EncodingBorsh} + }, + } + compactU16DecoderPool = sync.Pool{ + New: func() any { + return &Decoder{encoding: EncodingCompactU16} + }, + } +) + +// pooledMarshal runs enc.Encode(v) on a pooled encoder, copies the +// resulting bytes out so the pooled buffer can be reused safely, and +// returns the encoder to the pool. It is the shared implementation +// behind MarshalBin / MarshalBorsh / MarshalCompactU16. +func pooledMarshal(pool *sync.Pool, v any) ([]byte, error) { + enc := pool.Get().(*Encoder) + err := enc.Encode(v) + if err != nil { + enc.Reset() + pool.Put(enc) + return nil, err + } + // Copy the bytes out: the pooled encoder's underlying slice will + // be reused by future callers, so returning enc.buf directly would + // let the next Marshal call silently stomp on the caller's result. + out := make([]byte, len(enc.buf)) + copy(out, enc.buf) + enc.Reset() + pool.Put(enc) + return out, nil +} + +// pooledUnmarshal runs dec.Decode(v) on a pooled decoder over the +// provided bytes, then returns the decoder to the pool. Shared +// implementation behind UnmarshalBin / UnmarshalBorsh / +// UnmarshalCompactU16. +func pooledUnmarshal(pool *sync.Pool, v any, b []byte) error { + dec := pool.Get().(*Decoder) + dec.Reset(b) + err := dec.Decode(v) + // Clear the data reference before returning to pool so we don't + // pin the caller's input buffer in the pool. + dec.Reset(nil) + pool.Put(dec) + return err +} + +type BinaryMarshaler interface { + MarshalWithEncoder(encoder *Encoder) error +} + +type BinaryUnmarshaler interface { + UnmarshalWithDecoder(decoder *Decoder) error +} + +type EncoderDecoder interface { + BinaryMarshaler + BinaryUnmarshaler +} + +func MarshalBin(v any) ([]byte, error) { + return pooledMarshal(&binEncoderPool, v) +} + +func MarshalBorsh(v any) ([]byte, error) { + return pooledMarshal(&borshEncoderPool, v) +} + +func MarshalCompactU16(v any) ([]byte, error) { + return pooledMarshal(&compactU16EncoderPool, v) +} + +func UnmarshalBin(v any, b []byte) error { + return pooledUnmarshal(&binDecoderPool, v, b) +} + +func UnmarshalBorsh(v any, b []byte) error { + return pooledUnmarshal(&borshDecoderPool, v, b) +} + +func UnmarshalCompactU16(v any, b []byte) error { + return pooledUnmarshal(&compactU16DecoderPool, v, b) +} + +type byteCounter struct { + count uint64 +} + +func (c *byteCounter) Write(p []byte) (n int, err error) { + c.count += uint64(len(p)) + return len(p), nil +} + +// BinByteCount computes the byte count size for the received populated structure. The reported size +// is the one for the populated structure received in arguments. Depending on how serialization of +// your fields is performed, size could vary for different structure. +func BinByteCount(v any) (uint64, error) { + counter := byteCounter{} + err := NewBinEncoder(&counter).Encode(v) + if err != nil { + return 0, fmt.Errorf("encode %T: %w", v, err) + } + return counter.count, nil +} + +// BorshByteCount computes the byte count size for the received populated structure. The reported size +// is the one for the populated structure received in arguments. Depending on how serialization of +// your fields is performed, size could vary for different structure. +func BorshByteCount(v any) (uint64, error) { + counter := byteCounter{} + err := NewBorshEncoder(&counter).Encode(v) + if err != nil { + return 0, fmt.Errorf("encode %T: %w", v, err) + } + return counter.count, nil +} + +// CompactU16ByteCount computes the byte count size for the received populated structure. The reported size +// is the one for the populated structure received in arguments. Depending on how serialization of +// your fields is performed, size could vary for different structure. +func CompactU16ByteCount(v any) (uint64, error) { + counter := byteCounter{} + err := NewCompactU16Encoder(&counter).Encode(v) + if err != nil { + return 0, fmt.Errorf("encode %T: %w", v, err) + } + return counter.count, nil +} + +// MustBinByteCount acts just like BinByteCount but panics if it encounters any encoding errors. +func MustBinByteCount(v any) uint64 { + count, err := BinByteCount(v) + if err != nil { + panic(err) + } + return count +} + +// MustBorshByteCount acts just like BorshByteCount but panics if it encounters any encoding errors. +func MustBorshByteCount(v any) uint64 { + count, err := BorshByteCount(v) + if err != nil { + panic(err) + } + return count +} + +// MustCompactU16ByteCount acts just like CompactU16ByteCount but panics if it encounters any encoding errors. +func MustCompactU16ByteCount(v any) uint64 { + count, err := CompactU16ByteCount(v) + if err != nil { + panic(err) + } + return count +} diff --git a/binary/interface_bench_test.go b/binary/interface_bench_test.go new file mode 100644 index 000000000..0c4a3e089 --- /dev/null +++ b/binary/interface_bench_test.go @@ -0,0 +1,126 @@ +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "crypto/rand" + "io" + mrand "math/rand" + "testing" +) + +func BenchmarkByteCount(b *testing.B) { + nestedSmall := &benchNested{ + N1: &benchSubset1{F3: makeStringList(10), F4: makeUint64List(10)}, + N2: &benchSubset2{}, + } + + nestedLarge := &benchNested{ + N1: &benchSubset1{F3: makeStringList(200), F4: makeUint64List(200)}, + N2: &benchSubset2{}, + } + + benchmarks := []struct { + name string + v interface{} + }{ + {"flat", &benchFlat{}}, + {"nested/small list", nestedSmall}, + {"nested/large list", nestedLarge}, + {"deep/small list", &benchDeepNested{N1: nestedSmall, N2: nestedSmall, N3: nestedSmall}}, + {"deep/large list", &benchDeepNested{N1: nestedLarge, N2: nestedLarge, N3: nestedLarge}}, + } + + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + setupBench(b) + for i := 0; i < b.N; i++ { + BinByteCount(bm.v) + } + }) + } +} + +type benchFlat struct { + F1 string + F2 int16 + F3 uint16 + F4 int32 + F5 uint32 + F6 int64 + F7 uint64 + F8 float32 + F9 float64 +} + +type benchNested struct { + N1 *benchSubset1 + F1 string + F2 int16 + F3 uint16 + F4 int32 + F5 uint32 + N2 *benchSubset2 + F6 int64 + F7 uint64 + F8 float32 + F9 float64 +} + +type benchDeepNested struct { + N1 *benchNested + F1 string + F2 int16 + N2 *benchNested + F4 int32 + F5 uint32 + N3 *benchNested + F6 int64 + F7 uint64 + F8 float32 + F9 float64 +} + +type benchSubset1 struct { + F1 int64 + F2 string + F3 []string + F4 []int64 +} + +type benchSubset2 struct { + F7 uint64 + F8 float32 + F9 float64 +} + +func makeUint64List(itemCount int) (out []int64) { + out = make([]int64, itemCount) + for i := 0; i < itemCount; i++ { + // get random int64: + out[i] = mrand.Int63() + } + return +} + +func makeStringList(itemCount int) (out []string) { + out = make([]string, itemCount) + for i := 0; i < itemCount; i++ { + data := make([]byte, i>>1) + io.ReadFull(rand.Reader, data) + out[i] = string(data) + } + return +} diff --git a/binary/interface_test.go b/binary/interface_test.go new file mode 100644 index 000000000..f05333068 --- /dev/null +++ b/binary/interface_test.go @@ -0,0 +1,151 @@ +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +type Example struct { + Prefix byte + Value uint32 +} + +func (e *Example) UnmarshalWithDecoder(decoder *Decoder) (err error) { + if e.Prefix, err = decoder.ReadByte(); err != nil { + return err + } + if e.Value, err = decoder.ReadUint32(BE); err != nil { + return err + } + return nil +} + +func (e Example) MarshalWithEncoder(encoder *Encoder) error { + if err := encoder.WriteByte(e.Prefix); err != nil { + return err + } + return encoder.WriteUint32(e.Value, BE) +} + +type testCustomCoder struct { + val string +} + +func (d *testCustomCoder) UnmarshalWithDecoder(decoder *Decoder) error { + d.val = "hello world" + return nil +} + +func (d testCustomCoder) MarshalWithEncoder(encoder *Encoder) error { + return encoder.WriteBytes([]byte("this is a test"), false) +} + +func TestMarshalWithEncoder(t *testing.T) { + { + buf := new(bytes.Buffer) + e := &Example{Value: 72, Prefix: 0xaa} + enc := NewBinEncoder(buf) + enc.Encode(e) + + assert.Equal(t, []byte{ + 0xaa, 0x00, 0x00, 0x00, 0x48, + }, buf.Bytes()) + } + { + // on pointer: + { + buf := new(bytes.Buffer) + e := &testCustomCoder{} + enc := NewBinEncoder(buf) + err := enc.Encode(e) + assert.NoError(t, err) + + assert.Equal(t, []byte("this is a test"), buf.Bytes()) + } + { + buf := new(bytes.Buffer) + e := &testCustomCoder{} + enc := NewBorshEncoder(buf) + err := enc.Encode(e) + assert.NoError(t, err) + + assert.Equal(t, []byte("this is a test"), buf.Bytes()) + } + // on value: + { + buf := new(bytes.Buffer) + e := testCustomCoder{} + enc := NewBinEncoder(buf) + err := enc.Encode(e) + assert.NoError(t, err) + + assert.Equal(t, []byte("this is a test"), buf.Bytes()) + } + { + buf := new(bytes.Buffer) + e := testCustomCoder{} + enc := NewBorshEncoder(buf) + err := enc.Encode(e) + assert.NoError(t, err) + + assert.Equal(t, []byte("this is a test"), buf.Bytes()) + } + } +} + +func TestUnmarshalWithDecoder(t *testing.T) { + { + buf := []byte{ + 0xaa, 0x00, 0x00, 0x00, 0x48, + } + + e := &Example{} + d := NewBinDecoder(buf) + err := d.Decode(e) + assert.NoError(t, err) + assert.Equal(t, e, &Example{Value: 72, Prefix: 0xaa}) + assert.Equal(t, 0, d.Remaining()) + } + { + { + buf := []byte{ + 0xaa, 0x00, 0x00, 0x00, 0x48, + } + + e := &testCustomCoder{} + d := NewBinDecoder(buf) + err := d.Decode(e) + assert.NoError(t, err) + + assert.Equal(t, "hello world", e.val) + } + { + buf := []byte{ + 0xaa, 0x00, 0x00, 0x00, 0x48, + } + + e := &testCustomCoder{} + d := NewBorshDecoder(buf) + err := d.Decode(e) + assert.NoError(t, err) + + assert.Equal(t, "hello world", e.val) + } + } +} diff --git a/binary/logging.go b/binary/logging.go new file mode 100644 index 000000000..7bce191d8 --- /dev/null +++ b/binary/logging.go @@ -0,0 +1,46 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "fmt" + + "github.com/streamingfast/logging" + "go.uber.org/zap" +) + +var ( + zlog = zap.NewNop() + traceEnabled = false +) + +func init() { + zlog_, tracer := logging.PackageLogger("binary", "github.com/gagliardetto/solana-go/binary") + traceEnabled = tracer.Enabled() + zlog = zlog_ +} + +type logStringerFunc func() string + +func (f logStringerFunc) String() string { return f() } + +func typeField(field string, v interface{}) zap.Field { + return zap.Stringer(field, logStringerFunc(func() string { + return fmt.Sprintf("%T", v) + })) +} diff --git a/binary/parse_test.go b/binary/parse_test.go new file mode 100644 index 000000000..1c7c11419 --- /dev/null +++ b/binary/parse_test.go @@ -0,0 +1,78 @@ +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "encoding/binary" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_parseFieldTag(t *testing.T) { + tests := []struct { + name string + tag string + expectValue *fieldTag + }{ + { + name: "no tags", + tag: "", + expectValue: &fieldTag{ + Order: binary.LittleEndian, + }, + }, + { + name: "with a skip", + tag: `bin:"-"`, + expectValue: &fieldTag{ + Order: binary.LittleEndian, + Skip: true, + }, + }, + { + name: "with a sizeof", + tag: `bin:"sizeof=Tokens"`, + expectValue: &fieldTag{ + Order: binary.LittleEndian, + SizeOf: "Tokens", + }, + }, + { + name: "with a optional", + tag: `bin:"optional"`, + expectValue: &fieldTag{ + Order: binary.LittleEndian, + Option: true, + }, + }, + { + name: "with a optional and size of", + tag: `bin:"optional sizeof=Nodes"`, + expectValue: &fieldTag{ + Order: binary.LittleEndian, + Option: true, + SizeOf: "Nodes", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expectValue, parseFieldTag(reflect.StructTag(test.tag))) + }) + } +} diff --git a/binary/perf_bench_test.go b/binary/perf_bench_test.go new file mode 100644 index 000000000..f13af4c33 --- /dev/null +++ b/binary/perf_bench_test.go @@ -0,0 +1,303 @@ +package bin + +import ( + "bytes" + "encoding/binary" + "testing" +) + +// discardWriter is a zero-alloc io.Writer used to measure the encoder itself +// without pulling in the growth/copy costs of bytes.Buffer. +type discardWriter struct{ n int } + +func (d *discardWriter) Write(p []byte) (int, error) { + d.n += len(p) + return len(p), nil +} + +// ---- primitive writes (target of review item #1) ---- + +func BenchmarkEncode_WriteUint16(b *testing.B) { + var w discardWriter + e := NewBinEncoder(&w) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = e.WriteUint16(uint16(i), binary.LittleEndian) + } +} + +func BenchmarkEncode_WriteUint32(b *testing.B) { + var w discardWriter + e := NewBinEncoder(&w) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = e.WriteUint32(uint32(i), binary.LittleEndian) + } +} + +func BenchmarkEncode_WriteUint64(b *testing.B) { + var w discardWriter + e := NewBinEncoder(&w) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = e.WriteUint64(uint64(i), binary.LittleEndian) + } +} + +// review item #8 +func BenchmarkEncode_WriteUint128(b *testing.B) { + var w discardWriter + e := NewBinEncoder(&w) + v := Uint128{Lo: 0xdeadbeef, Hi: 0xcafebabe} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = e.WriteUint128(v, binary.LittleEndian) + } +} + +func BenchmarkEncode_WriteUVarInt(b *testing.B) { + var w discardWriter + e := NewBinEncoder(&w) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = e.WriteUVarInt(i) + } +} + +// ---- compact-u16 (target of review item #7) ---- + +func BenchmarkEncode_CompactU16_1byte(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf := make([]byte, 0, 3) + _ = EncodeCompactU16Length(&buf, 42) + } +} + +func BenchmarkEncode_CompactU16_2byte(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf := make([]byte, 0, 3) + _ = EncodeCompactU16Length(&buf, 300) + } +} + +func BenchmarkEncode_CompactU16_3byte(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf := make([]byte, 0, 3) + _ = EncodeCompactU16Length(&buf, 20000) + } +} + +func BenchmarkDecode_CompactU16_1byte(b *testing.B) { + var buf []byte + _ = EncodeCompactU16Length(&buf, 42) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = DecodeCompactU16(buf) + } +} + +func BenchmarkDecode_CompactU16_3byte(b *testing.B) { + var buf []byte + _ = EncodeCompactU16Length(&buf, 20000) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = DecodeCompactU16(buf) + } +} + +// WriteCompactU16 routes through the Encoder, which currently allocates twice +// (once for the scratch append buffer, once via toWriter). +func BenchmarkEncode_CompactU16_ViaEncoder(b *testing.B) { + var w discardWriter + e := NewCompactU16Encoder(&w) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = e.WriteCompactU16(20000) + } +} + +// ---- PoD slice decoding (target of review item #3) ---- + +func BenchmarkDecode_SliceUint64_8k(b *testing.B) { + const l = 8192 + var buf bytes.Buffer + e := NewBorshEncoder(&buf) + _ = e.WriteUint32(uint32(l), LE) + for i := 0; i < l; i++ { + _ = e.WriteUint64(uint64(i), LE) + } + data := buf.Bytes() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var got []uint64 + dec := NewBorshDecoder(data) + _ = dec.Decode(&got) + } +} + +func BenchmarkDecode_SliceUint32_8k(b *testing.B) { + const l = 8192 + var buf bytes.Buffer + e := NewBorshEncoder(&buf) + _ = e.WriteUint32(uint32(l), LE) + for i := 0; i < l; i++ { + _ = e.WriteUint32(uint32(i), LE) + } + data := buf.Bytes() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var got []uint32 + dec := NewBorshDecoder(data) + _ = dec.Decode(&got) + } +} + +func BenchmarkDecode_SliceUint16_8k(b *testing.B) { + const l = 8192 + var buf bytes.Buffer + e := NewBorshEncoder(&buf) + _ = e.WriteUint32(uint32(l), LE) + for i := 0; i < l; i++ { + _ = e.WriteUint16(uint16(i), LE) + } + data := buf.Bytes() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var got []uint16 + dec := NewBorshDecoder(data) + _ = dec.Decode(&got) + } +} + +// ---- end-to-end struct encode/decode (Solana-ish layout) ---- + +type perfBenchStruct struct { + A uint64 + B uint64 + C uint32 + D [32]byte + E []uint64 +} + +func makePerfBenchStruct() perfBenchStruct { + s := perfBenchStruct{A: 1, B: 2, C: 3, E: make([]uint64, 64)} + for i := range s.E { + s.E[i] = uint64(i) + } + for i := range s.D { + s.D[i] = byte(i) + } + return s +} + +func BenchmarkEncode_Struct_Borsh(b *testing.B) { + s := makePerfBenchStruct() + var w discardWriter + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + e := NewBorshEncoder(&w) + _ = e.Encode(&s) + } +} + +// Buffered-mode encoder (writes into internal []byte instead of via io.Writer). +func BenchmarkEncode_Struct_Borsh_Buffered(b *testing.B) { + s := makePerfBenchStruct() + e := NewBorshEncoderBuf() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + e.Reset() + _ = e.Encode(&s) + } +} + +func BenchmarkEncode_WriteUint64_Buffered(b *testing.B) { + e := NewBorshEncoderBuf() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if i&0xfff == 0 { + e.Reset() + } + _ = e.WriteUint64(uint64(i), LE) + } +} + +// PoD slice decode with capacity already in place — measures the cap-reuse +// fast path added in round 4. +func BenchmarkDecode_SliceUint64_8k_Reused(b *testing.B) { + const l = 8192 + var buf bytes.Buffer + e := NewBorshEncoder(&buf) + _ = e.WriteUint32(uint32(l), LE) + for i := 0; i < l; i++ { + _ = e.WriteUint64(uint64(i), LE) + } + data := buf.Bytes() + got := make([]uint64, 0, l) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + dec := NewBorshDecoder(data) + _ = dec.Decode(&got) + } +} + +// ReadString vs ReadStringBorrow — measure the unsafe.String zero-copy win. +func BenchmarkDecode_ReadString_Copy(b *testing.B) { + payload := []byte("the quick brown fox jumps over the lazy dog") + var buf bytes.Buffer + e := NewBinEncoder(&buf) + _ = e.WriteBytes(payload, true) + data := buf.Bytes() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + dec := NewBinDecoder(data) + _, _ = dec.ReadString() + } +} + +func BenchmarkDecode_ReadString_Borrow(b *testing.B) { + payload := []byte("the quick brown fox jumps over the lazy dog") + var buf bytes.Buffer + e := NewBinEncoder(&buf) + _ = e.WriteBytes(payload, true) + data := buf.Bytes() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + dec := NewBinDecoder(data) + _, _ = dec.ReadStringBorrow() + } +} + +func BenchmarkDecode_Struct_Borsh(b *testing.B) { + s := makePerfBenchStruct() + var buf bytes.Buffer + _ = NewBorshEncoder(&buf).Encode(&s) + data := buf.Bytes() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var out perfBenchStruct + dec := NewBorshDecoder(data) + _ = dec.Decode(&out) + } +} diff --git a/binary/sighash.go b/binary/sighash.go new file mode 100644 index 000000000..b920e3e55 --- /dev/null +++ b/binary/sighash.go @@ -0,0 +1,64 @@ +// Copyright 2021 github.com/gagliardetto +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "crypto/sha256" +) + +// Sighash creates an anchor sighash for the provided namespace and element. +// An anchor sighash is the first 8 bytes of the sha256 of {namespace}:{name} +// NOTE: you must first convert the name to snake case using `ToSnakeForSighash`. +func Sighash(namespace string, name string) []byte { + data := namespace + ":" + name + sum := sha256.Sum256([]byte(data)) + return sum[0:8] +} + +func SighashInstruction(name string) []byte { + // Instruction sighash are the first 8 bytes of the sha256 of + // {SIGHASH_INSTRUCTION_NAMESPACE}:{snake_case(name)} + return Sighash(SIGHASH_GLOBAL_NAMESPACE, ToSnakeForSighash(name)) +} + +func SighashAccount(name string) []byte { + // Account sighash are the first 8 bytes of the sha256 of + // {SIGHASH_ACCOUNT_NAMESPACE}:{camelCase(name)} + return Sighash(SIGHASH_ACCOUNT_NAMESPACE, ToPascalCase(name)) +} + +// NOTE: no casing conversion is done here, it's up to the caller to +// provide the correct casing. +func SighashTypeID(namespace string, name string) TypeID { + return TypeIDFromBytes(Sighash(namespace, (name))) +} + +// Namespace for calculating state instruction sighash signatures. +const SIGHASH_STATE_NAMESPACE string = "state" + +// Namespace for calculating instruction sighash signatures for any instruction +// not affecting program state. +const SIGHASH_GLOBAL_NAMESPACE string = "global" + +const SIGHASH_ACCOUNT_NAMESPACE string = "account" + +const ACCOUNT_DISCRIMINATOR_SIZE = 8 + +// https://github.com/project-serum/anchor/pull/64/files +// https://github.com/project-serum/anchor/blob/2f780e0d274f47e442b3f0d107db805a41c6def0/ts/src/coder/common.ts#L109 +// https://github.com/project-serum/anchor/blob/6b5ed789fc856408986e8868229887354d6d4073/lang/syn/src/codegen/program/common.rs#L17 + +// TODO: +// https://github.com/project-serum/anchor/blob/84a2b8200cc3c7cb51d7127918e6cbbd836f0e99/ts/src/error.ts#L48 diff --git a/binary/sighash_test.go b/binary/sighash_test.go new file mode 100644 index 000000000..abab2b115 --- /dev/null +++ b/binary/sighash_test.go @@ -0,0 +1,198 @@ +// Copyright 2021 github.com/gagliardetto +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSighash(t *testing.T) { + { + ixName := "hello" + got := Sighash(SIGHASH_GLOBAL_NAMESPACE, ToSnakeForSighash(ixName)) + require.NotEmpty(t, got) + require.Equal(t, + got, + SighashInstruction(ixName), + ) + + expected := []byte{149, 118, 59, 220, 196, 127, 161, 179} + require.Equal(t, + expected, + got, + ) + } + { + ixName := "serumSwap" + got := Sighash(SIGHASH_GLOBAL_NAMESPACE, ToSnakeForSighash(ixName)) + require.NotEmpty(t, got) + require.Equal(t, + got, + SighashInstruction(ixName), + ) + + expected := []byte{88, 183, 70, 249, 214, 118, 82, 210} + require.Equal(t, + expected, + got, + ) + require.Equal(t, + expected, + SighashInstruction(ixName), + ) + } + { + ixName := "aldrinV2Swap" + got := Sighash(SIGHASH_GLOBAL_NAMESPACE, ToSnakeForSighash(ixName)) + require.NotEmpty(t, got) + require.Equal(t, + got, + SighashInstruction(ixName), + ) + + expected := []byte{190, 166, 89, 139, 33, 152, 16, 10} + require.Equal(t, + expected, + got, + ) + require.Equal(t, + expected, + SighashInstruction(ixName), + ) + } + { + ixName := "raydiumSwapV2" + got := Sighash(SIGHASH_GLOBAL_NAMESPACE, ToSnakeForSighash(ixName)) + require.NotEmpty(t, got) + require.Equal(t, + got, + SighashInstruction(ixName), + ) + + expected := []byte{69, 227, 98, 93, 237, 202, 223, 140} + require.Equal(t, + expected, + got, + ) + require.Equal(t, + expected, + SighashInstruction(ixName), + ) + } + { + accountName := "DialectAccount" + got := Sighash(SIGHASH_ACCOUNT_NAMESPACE, (accountName)) + require.NotEmpty(t, got) + require.Equal(t, + got, + SighashAccount(accountName), + ) + + expected := []byte{157, 38, 120, 189, 93, 204, 119, 18} + require.Equal(t, + expected, + got, + ) + require.Equal(t, + expected, + SighashAccount(accountName), + ) + } +} + +func TestToSnakeForSighash(t *testing.T) { + t.Run( + "typescript", + // "typescript package: https://www.npmjs.com/package/snake-case", + func(t *testing.T) { + // copied from https://github.com/blakeembrey/change-case/blob/040a079f007879cb0472ba4f7cc2e1d3185e90ba/packages/snake-case/src/index.spec.ts + // as used in anchor. + testCases := [][2]string{ + {"", ""}, + {"_id", "id"}, + {"test", "test"}, + {"test string", "test_string"}, + {"Test String", "test_string"}, + {"TestV2", "test_v2"}, + {"version 1.2.10", "version_1_2_10"}, + {"version 1.21.0", "version_1_21_0"}, + {"doSomething2", "do_something2"}, + } + + for _, testCase := range testCases { + t.Run( + testCase[0], + func(t *testing.T) { + assert.Equal(t, + testCase[1], + ToSnakeForSighash(testCase[0]), + "from %q", testCase[0], + ) + }) + } + }, + ) + t.Run( + "rust", + // "rust package: https://docs.rs/heck", + func(t *testing.T) { + // copied from https://github.com/withoutboats/heck/blob/dbcfc7b8db8e532d1fad44518cf73e88d5212161/src/snake.rs#L60 + // as used in anchor. + testCases := [][2]string{ + {"CamelCase", "camel_case"}, + {"This is Human case.", "this_is_human_case"}, + {"MixedUP CamelCase, with some Spaces", "mixed_up_camel_case_with_some_spaces"}, + {"mixed_up_ snake_case with some _spaces", "mixed_up_snake_case_with_some_spaces"}, + {"kebab-case", "kebab_case"}, + {"SHOUTY_SNAKE_CASE", "shouty_snake_case"}, + {"snake_case", "snake_case"}, + {"this-contains_ ALLKinds OfWord_Boundaries", "this_contains_all_kinds_of_word_boundaries"}, + + // #[cfg(feature = "unicode")] + {"XΣXΣ baffle", "xσxσ_baffle"}, + {"XMLHttpRequest", "xml_http_request"}, + {"FIELD_NAME11", "field_name11"}, + {"99BOTTLES", "99bottles"}, + {"FieldNamE11", "field_nam_e11"}, + + {"abc123def456", "abc123def456"}, + {"abc123DEF456", "abc123_def456"}, + {"abc123Def456", "abc123_def456"}, + {"abc123DEf456", "abc123_d_ef456"}, + {"ABC123def456", "abc123def456"}, + {"ABC123DEF456", "abc123def456"}, + {"ABC123Def456", "abc123_def456"}, + {"ABC123DEf456", "abc123d_ef456"}, + {"ABC123dEEf456FOO", "abc123d_e_ef456_foo"}, + {"abcDEF", "abc_def"}, + {"ABcDE", "a_bc_de"}, + } + + for _, testCase := range testCases { + t.Run( + testCase[0], + func(t *testing.T) { + assert.Equal(t, + testCase[1], + ToSnakeForSighash(testCase[0]), + "from %q", testCase[0], + ) + }) + } + }) +} diff --git a/binary/tags-options.go b/binary/tags-options.go new file mode 100644 index 000000000..3e0cd4759 --- /dev/null +++ b/binary/tags-options.go @@ -0,0 +1,107 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import "encoding/binary" + +// option carries per-field decode/encode flags. It is intentionally a small +// value type so it can be passed by value and stack-allocated. Previously it +// was always heap-allocated via &option{...}; now callers should use a literal +// option{} (the zero value is a valid "no special handling" sentinel; the +// decode/encode entry points fill in a default Order if it's nil). +type option struct { + is_OptionalField bool + is_COptionalField bool + sliceSizeIsSet bool + sliceSize int + Order binary.ByteOrder +} + +var ( + LE binary.ByteOrder = binary.LittleEndian + BE binary.ByteOrder = binary.BigEndian +) + +var defaultByteOrder = binary.LittleEndian + +// defaultOption is the value used when a caller would otherwise pass nil. +var defaultOption = option{Order: defaultByteOrder} + +func (o option) is_Optional() bool { + return o.is_OptionalField +} + +func (o option) is_COptional() bool { + return o.is_COptionalField +} + +func (o option) hasSizeOfSlice() bool { + return o.sliceSizeIsSet +} + +func (o option) getSizeOfSlice() int { + return o.sliceSize +} + +func (o *option) setSizeOfSlice(size int) *option { + o.sliceSize = size + o.sliceSizeIsSet = true + return o +} + +type Encoding int + +const ( + EncodingBin Encoding = iota + EncodingCompactU16 + EncodingBorsh +) + +func (enc Encoding) String() string { + switch enc { + case EncodingBin: + return "Bin" + case EncodingCompactU16: + return "CompactU16" + case EncodingBorsh: + return "Borsh" + default: + return "" + } +} + +func (en Encoding) IsBorsh() bool { + return en == EncodingBorsh +} + +func (en Encoding) IsBin() bool { + return en == EncodingBin +} + +func (en Encoding) IsCompactU16() bool { + return en == EncodingCompactU16 +} + +func isValidEncoding(enc Encoding) bool { + switch enc { + case EncodingBin, EncodingCompactU16, EncodingBorsh: + return true + default: + return false + } +} diff --git a/binary/tags-parser.go b/binary/tags-parser.go new file mode 100644 index 000000000..39f8e3ad9 --- /dev/null +++ b/binary/tags-parser.go @@ -0,0 +1,80 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "encoding/binary" + "reflect" + "strings" +) + +type fieldTag struct { + SizeOf string + Skip bool + Order binary.ByteOrder + Option bool + COption bool + BinaryExtension bool + + IsBorshEnum bool +} + +func isIn(s string, candidates ...string) bool { + for _, c := range candidates { + if s == c { + return true + } + } + return false +} + +func parseFieldTag(tag reflect.StructTag) *fieldTag { + t := &fieldTag{ + Order: defaultByteOrder, + } + tagStr := tag.Get("bin") + for _, s := range strings.Split(tagStr, " ") { + if strings.HasPrefix(s, "sizeof=") { + tmp := strings.SplitN(s, "=", 2) + t.SizeOf = tmp[1] + } else if s == "big" { + t.Order = binary.BigEndian + } else if s == "little" { + t.Order = binary.LittleEndian + } else if isIn(s, "optional", "option") { + t.Option = true + } else if isIn(s, "coption") { + t.COption = true + } else if s == "binary_extension" { + t.BinaryExtension = true + } else if isIn(s, "-", "skip") { + t.Skip = true + } else if isIn(s, "enum") { + t.IsBorshEnum = true + } + } + + // TODO: parse other borsh tags + if strings.TrimSpace(tag.Get("borsh_skip")) == "true" { + t.Skip = true + } + if strings.TrimSpace(tag.Get("borsh_enum")) == "true" { + t.IsBorshEnum = true + } + return t +} diff --git a/binary/testdata/fuzz/FuzzUint128JSON/4cde65be6ccc67df b/binary/testdata/fuzz/FuzzUint128JSON/4cde65be6ccc67df new file mode 100644 index 000000000..457673f9d --- /dev/null +++ b/binary/testdata/fuzz/FuzzUint128JSON/4cde65be6ccc67df @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\"700000000000000000000000000000000000000\"") diff --git a/binary/testdata/fuzz/FuzzUint128JSON/6d619d56cd91d9ba b/binary/testdata/fuzz/FuzzUint128JSON/6d619d56cd91d9ba new file mode 100644 index 000000000..938c4d574 --- /dev/null +++ b/binary/testdata/fuzz/FuzzUint128JSON/6d619d56cd91d9ba @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\"0X0000000000000000\"") diff --git a/binary/tools.go b/binary/tools.go new file mode 100644 index 000000000..54e4df680 --- /dev/null +++ b/binary/tools.go @@ -0,0 +1,84 @@ +// Copyright 2021 github.com/gagliardetto +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "fmt" + "strconv" + "strings" +) + +// FormatByteSlice formats the given byte slice into a readable format. +func FormatByteSlice(buf []byte) string { + elems := make([]string, 0) + for _, v := range buf { + elems = append(elems, strconv.Itoa(int(v))) + } + + return "{" + strings.Join(elems, ", ") + "}" + fmt.Sprintf("(len=%v)", len(elems)) +} + +func FormatDiscriminator(disc [8]byte) string { + elems := make([]string, 0) + for _, v := range disc { + elems = append(elems, strconv.Itoa(int(v))) + } + return "{" + strings.Join(elems, ", ") + "}" +} + +type WriteByWrite struct { + writes [][]byte + name string +} + +func NewWriteByWrite(name string) *WriteByWrite { + return &WriteByWrite{ + name: name, + } +} + +func (rec *WriteByWrite) Write(b []byte) (int, error) { + // Copy defensively: io.Writer's contract forbids retaining p, and the + // Encoder reuses a per-call scratch buffer across primitive writes. + cp := make([]byte, len(b)) + copy(cp, b) + rec.writes = append(rec.writes, cp) + return len(b), nil +} + +func (rec *WriteByWrite) Bytes() []byte { + out := make([]byte, 0) + for _, v := range rec.writes { + out = append(out, v...) + } + return out +} + +func (rec WriteByWrite) String() string { + builder := new(strings.Builder) + if rec.name != "" { + builder.WriteString(rec.name + ":\n") + } + for index, v := range rec.writes { + builder.WriteString(fmt.Sprintf("- %v: %s\n", index, FormatByteSlice(v))) + } + return builder.String() +} + +// IsByteSlice returns true if the provided element is a []byte. +func IsByteSlice(v interface{}) bool { + _, ok := v.([]byte) + return ok +} diff --git a/binary/tools_test.go b/binary/tools_test.go new file mode 100644 index 000000000..24d465727 --- /dev/null +++ b/binary/tools_test.go @@ -0,0 +1,27 @@ +package bin + +import "encoding/binary" + +func concatByteSlices(slices ...[]byte) (out []byte) { + for i := range slices { + out = append(out, slices[i]...) + } + return +} +func uint16ToBytes(i uint16, order binary.ByteOrder) []byte { + buf := make([]byte, 2) + order.PutUint16(buf, i) + return buf +} + +func uint32ToBytes(i uint32, order binary.ByteOrder) []byte { + buf := make([]byte, 4) + order.PutUint32(buf, i) + return buf +} + +func uint64ToBytes(i uint64, order binary.ByteOrder) []byte { + buf := make([]byte, 8) + order.PutUint64(buf, i) + return buf +} diff --git a/binary/type_plan.go b/binary/type_plan.go new file mode 100644 index 000000000..84178201d --- /dev/null +++ b/binary/type_plan.go @@ -0,0 +1,586 @@ +// Copyright 2024 github.com/gagliardetto +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "encoding/binary" + "io" + "reflect" + "sync" + "unsafe" +) + +// fieldFastDecode is a per-field decode closure that bypasses the generic +// type-switch in decodeBin/Borsh/CompactU16 for known primitive kinds. It +// writes directly into the struct field's memory via unsafe pointer arithmetic +// and skips option construction. Returns nil if the field is not eligible. +type fieldFastDecode func(d *Decoder, fv reflect.Value) error + +// fieldFastEncode is the encode-side counterpart to fieldFastDecode. +type fieldFastEncode func(e *Encoder, fv reflect.Value) error + +// fieldPlan caches per-struct-field reflect work that previously ran on every +// decode/encode call: tag parsing, sizeOf wiring, and BinaryUnmarshaler / +// BinaryMarshaler interface satisfaction. +// +// The plan is built lazily once per reflect.Type and cached in typePlanCache. +type fieldPlan struct { + index int // index for rv.Field(i) / rt.Field(i) + name string // for error messages and trace logs + tag fieldTag // parsed once at plan-build time + skip bool + binaryExtension bool + canInterface bool // PkgPath == "" (exported) + fieldType reflect.Type // cached field type + + // sizeOf wiring (struct-local; resolved at plan-build time): + // + // sizeOfTargetIdx: this field has a `sizeof=X` tag — its decoded value + // is the slice length for the field at that index. -1 if not a source. + // sizeFromIdx: this field's slice length is supplied by another + // field's value (i.e. some earlier field had `sizeof=`). + // -1 if not a target. + sizeOfTargetIdx int + sizeFromIdx int + + // BinaryUnmarshaler / BinaryMarshaler satisfaction (Borsh decoder fast + // path inspects these to bypass the indirect() call entirely). + ptrImplementsUnmarshaler bool + valImplementsUnmarshaler bool + ptrImplementsMarshaler bool + valImplementsMarshaler bool + + // Fast per-field dispatch closures populated for primitive kinds. nil + // means the field must go through the generic decodeXxx/encodeXxx path. + // borshFastDecode / binFastDecode are populated separately because + // borsh hard-codes little-endian whereas bin/compact-u16 honour the + // per-field byte-order tag. + borshFastDecode fieldFastDecode + borshFastEncode fieldFastEncode + binFastDecode fieldFastDecode + binFastEncode fieldFastEncode +} + +// typePlan is the cached layout for a struct type. It is immutable once +// constructed and can be shared across goroutines. +type typePlan struct { + // isComplexEnum is true when the first field has type BorshEnum and is + // tagged with `borsh_enum` — these structs decode via deserializeComplexEnum + // instead of the field loop. + isComplexEnum bool + + fields []fieldPlan + + // hasSizeOf is true if any field carries a `sizeof=` tag and at least one + // other field is its target. Enables the per-call sizes-array allocation. + hasSizeOf bool + + // hasBinaryExtension is true if any field is tagged with `binary_extension`. + hasBinaryExtension bool +} + +var typePlanCache sync.Map // map[reflect.Type]*typePlan + +// planForStruct returns the cached typePlan for the given struct type, building +// it on first sight. rt MUST be a struct. +func planForStruct(rt reflect.Type) *typePlan { + if v, ok := typePlanCache.Load(rt); ok { + return v.(*typePlan) + } + plan := buildStructPlan(rt) + actual, _ := typePlanCache.LoadOrStore(rt, plan) + return actual.(*typePlan) +} + +// PrewarmTypes builds and caches the typePlan for each value's underlying +// struct type. Pass struct values (or pointers to them); the helper unwraps +// pointers and ignores anything that doesn't resolve to a struct. Intended +// to be called from package init() in latency-sensitive callers, so the +// first encode/decode of a type doesn't pay the reflect-walk cost. +// +// Steady-state performance is unchanged: the typePlan cache already +// amortizes the cost across calls. Prewarming only moves the one-time +// per-type cost from first-call to init(). +func PrewarmTypes(values ...any) { + for _, v := range values { + rt := reflect.TypeOf(v) + prewarmType(rt) + } +} + +// PrewarmVariantDefinition builds and caches the typePlan for every type +// registered in def. This is the convenience entry point for program +// packages: a single call from init() prewarms every instruction variant +// defined in the package's InstructionImplDef. +// +// def may be nil; the call is a no-op in that case. +func PrewarmVariantDefinition(def *VariantDefinition) { + if def == nil { + return + } + for _, rt := range def.typeIDToType { + prewarmType(rt) + } +} + +// prewarmType unwraps pointer types and, if the result is a struct, builds +// and caches its typePlan. Used by PrewarmTypes and PrewarmVariantDefinition. +func prewarmType(rt reflect.Type) { + if rt == nil { + return + } + for rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + if rt.Kind() == reflect.Struct { + planForStruct(rt) + } +} + +func buildStructPlan(rt reflect.Type) *typePlan { + plan := &typePlan{} + n := rt.NumField() + if n == 0 { + return plan + } + + // Detect Borsh complex enum at plan time so the hot path can short-circuit. + first := rt.Field(0) + if isTypeBorshEnum(first.Type) && parseFieldTag(first.Tag).IsBorshEnum { + plan.isComplexEnum = true + return plan + } + + // Build a name → index map so we can resolve sizeof wiring statically. + // This map only lives during plan construction. + nameToIdx := make(map[string]int, n) + for i := 0; i < n; i++ { + nameToIdx[rt.Field(i).Name] = i + } + + plan.fields = make([]fieldPlan, n) + for i := 0; i < n; i++ { + sf := rt.Field(i) + tag := parseFieldTag(sf.Tag) + + fp := fieldPlan{ + index: i, + name: sf.Name, + tag: *tag, + skip: tag.Skip, + binaryExtension: tag.BinaryExtension, + canInterface: sf.PkgPath == "", + fieldType: sf.Type, + sizeOfTargetIdx: -1, + sizeFromIdx: -1, + } + + if tag.SizeOf != "" { + if idx, ok := nameToIdx[tag.SizeOf]; ok { + fp.sizeOfTargetIdx = idx + plan.hasSizeOf = true + } + } + if tag.BinaryExtension { + plan.hasBinaryExtension = true + } + + // Cache BinaryUnmarshaler / BinaryMarshaler interface satisfaction. + // Used by the encoder/decoder per-field fast paths to skip the + // indirect() call and the per-field rv.Interface() boxing. + ptrType := reflect.PtrTo(sf.Type) + fp.ptrImplementsUnmarshaler = ptrType.Implements(unmarshalableType) + fp.valImplementsUnmarshaler = sf.Type.Implements(unmarshalableType) + fp.ptrImplementsMarshaler = ptrType.Implements(marshalableType) + fp.valImplementsMarshaler = sf.Type.Implements(marshalableType) + + // Populate fast dispatch closures for primitive kinds. Eligible only + // when the field has no special tags that change the wire format + // (Option/COption/sizeOf/etc.) and does not implement a custom marshal + // interface — both would require the generic dispatch path. + fp.assignFastClosures() + + plan.fields[i] = fp + } + + // Second pass: wire sizeFromIdx (target ← source) now that targets are known. + for srcIdx := range plan.fields { + tgt := plan.fields[srcIdx].sizeOfTargetIdx + if tgt >= 0 { + plan.fields[tgt].sizeFromIdx = srcIdx + } + } + + return plan +} + +// sizesScratch is a stack-allocated buffer the decode/encode hot paths can +// use to track per-field slice sizes when a struct has sizeof wiring, without +// allocating a heap map. Most Solana structs have far fewer than 16 fields. +const sizesScratchLen = 16 + +type sizesScratch [sizesScratchLen]int + +// assignFastClosures populates the per-field fast dispatch closures when the +// field is eligible for the unsafe-write fast path. A field is eligible iff: +// +// - It is exported (canInterface). +// - It has no Option/COption/Skip/BinaryExtension/SizeOf wiring (those +// require running the generic dispatch path). +// - It does not implement BinaryMarshaler/BinaryUnmarshaler (custom logic). +// - Its kind is one of the supported primitives. +// +// The closures use unsafe pointer writes against `fv.UnsafeAddr()` so they +// work uniformly for both `uint64` fields and named-type aliases (`type Slot +// uint64`) — the in-memory layout is identical when the kind matches. +func (fp *fieldPlan) assignFastClosures() { + if !fp.canInterface || + fp.skip || + fp.binaryExtension || + fp.tag.Option || + fp.tag.COption || + fp.tag.SizeOf != "" || + fp.sizeOfTargetIdx >= 0 || + fp.sizeFromIdx >= 0 || + fp.ptrImplementsUnmarshaler || + fp.valImplementsUnmarshaler || + fp.ptrImplementsMarshaler || + fp.valImplementsMarshaler { + return + } + + // Borsh always uses little-endian. The bin/compact-u16 paths read the + // field's tag-specified byte order, captured here so the closure doesn't + // need to consult any per-call option. + binOrder := fp.tag.Order + if binOrder == nil { + binOrder = defaultByteOrder + } + + switch fp.fieldType.Kind() { + case reflect.Uint8: + fp.borshFastDecode = fastDecodeUint8 + fp.borshFastEncode = fastEncodeUint8 + fp.binFastDecode = fastDecodeUint8 + fp.binFastEncode = fastEncodeUint8 + case reflect.Int8: + fp.borshFastDecode = fastDecodeInt8 + fp.borshFastEncode = fastEncodeInt8 + fp.binFastDecode = fastDecodeInt8 + fp.binFastEncode = fastEncodeInt8 + case reflect.Bool: + fp.borshFastDecode = fastDecodeBool + fp.borshFastEncode = fastEncodeBool + fp.binFastDecode = fastDecodeBool + fp.binFastEncode = fastEncodeBool + case reflect.Uint16: + fp.borshFastDecode = fastDecodeUint16LE + fp.borshFastEncode = fastEncodeUint16LE + fp.binFastDecode = makeBinFastDecodeUint16(binOrder) + fp.binFastEncode = makeBinFastEncodeUint16(binOrder) + case reflect.Int16: + fp.borshFastDecode = fastDecodeInt16LE + fp.borshFastEncode = fastEncodeInt16LE + fp.binFastDecode = makeBinFastDecodeInt16(binOrder) + fp.binFastEncode = makeBinFastEncodeInt16(binOrder) + case reflect.Uint32: + fp.borshFastDecode = fastDecodeUint32LE + fp.borshFastEncode = fastEncodeUint32LE + fp.binFastDecode = makeBinFastDecodeUint32(binOrder) + fp.binFastEncode = makeBinFastEncodeUint32(binOrder) + case reflect.Int32: + fp.borshFastDecode = fastDecodeInt32LE + fp.borshFastEncode = fastEncodeInt32LE + fp.binFastDecode = makeBinFastDecodeInt32(binOrder) + fp.binFastEncode = makeBinFastEncodeInt32(binOrder) + case reflect.Uint64: + fp.borshFastDecode = fastDecodeUint64LE + fp.borshFastEncode = fastEncodeUint64LE + fp.binFastDecode = makeBinFastDecodeUint64(binOrder) + fp.binFastEncode = makeBinFastEncodeUint64(binOrder) + case reflect.Int64: + fp.borshFastDecode = fastDecodeInt64LE + fp.borshFastEncode = fastEncodeInt64LE + fp.binFastDecode = makeBinFastDecodeInt64(binOrder) + fp.binFastEncode = makeBinFastEncodeInt64(binOrder) + } +} + +// ---- Fast primitive closures (Borsh + LE bin/compact-u16) ---- +// +// All decoders write straight into the destination via fv.UnsafeAddr(); all +// encoders read straight out the same way. fv.UnsafeAddr() is safe here +// because plan-driven hot paths only call these on fields obtained via +// rv.Field(i) on a top-level Decode/Encode value (always addressable). + +func fastDecodeUint8(d *Decoder, fv reflect.Value) error { + if d.pos >= len(d.data) { + return io.ErrUnexpectedEOF + } + *(*uint8)(unsafe.Pointer(fv.UnsafeAddr())) = d.data[d.pos] + d.pos++ + return nil +} + +func fastEncodeUint8(e *Encoder, fv reflect.Value) error { + return e.WriteByte(*(*uint8)(unsafe.Pointer(fv.UnsafeAddr()))) +} + +func fastDecodeInt8(d *Decoder, fv reflect.Value) error { + if d.pos >= len(d.data) { + return io.ErrUnexpectedEOF + } + *(*int8)(unsafe.Pointer(fv.UnsafeAddr())) = int8(d.data[d.pos]) + d.pos++ + return nil +} + +func fastEncodeInt8(e *Encoder, fv reflect.Value) error { + return e.WriteByte(byte(*(*int8)(unsafe.Pointer(fv.UnsafeAddr())))) +} + +func fastDecodeBool(d *Decoder, fv reflect.Value) error { + if d.pos >= len(d.data) { + return io.ErrUnexpectedEOF + } + *(*bool)(unsafe.Pointer(fv.UnsafeAddr())) = d.data[d.pos] != 0 + d.pos++ + return nil +} + +func fastEncodeBool(e *Encoder, fv reflect.Value) error { + if *(*bool)(unsafe.Pointer(fv.UnsafeAddr())) { + return e.WriteByte(1) + } + return e.WriteByte(0) +} + +func fastDecodeUint16LE(d *Decoder, fv reflect.Value) error { + if d.pos+2 > len(d.data) { + return io.ErrUnexpectedEOF + } + *(*uint16)(unsafe.Pointer(fv.UnsafeAddr())) = binary.LittleEndian.Uint16(d.data[d.pos:]) + d.pos += 2 + return nil +} + +func fastEncodeUint16LE(e *Encoder, fv reflect.Value) error { + return e.WriteUint16(*(*uint16)(unsafe.Pointer(fv.UnsafeAddr())), binary.LittleEndian) +} + +func fastDecodeInt16LE(d *Decoder, fv reflect.Value) error { + if d.pos+2 > len(d.data) { + return io.ErrUnexpectedEOF + } + *(*int16)(unsafe.Pointer(fv.UnsafeAddr())) = int16(binary.LittleEndian.Uint16(d.data[d.pos:])) + d.pos += 2 + return nil +} + +func fastEncodeInt16LE(e *Encoder, fv reflect.Value) error { + return e.WriteInt16(*(*int16)(unsafe.Pointer(fv.UnsafeAddr())), binary.LittleEndian) +} + +func fastDecodeUint32LE(d *Decoder, fv reflect.Value) error { + if d.pos+4 > len(d.data) { + return io.ErrUnexpectedEOF + } + *(*uint32)(unsafe.Pointer(fv.UnsafeAddr())) = binary.LittleEndian.Uint32(d.data[d.pos:]) + d.pos += 4 + return nil +} + +func fastEncodeUint32LE(e *Encoder, fv reflect.Value) error { + return e.WriteUint32(*(*uint32)(unsafe.Pointer(fv.UnsafeAddr())), binary.LittleEndian) +} + +func fastDecodeInt32LE(d *Decoder, fv reflect.Value) error { + if d.pos+4 > len(d.data) { + return io.ErrUnexpectedEOF + } + *(*int32)(unsafe.Pointer(fv.UnsafeAddr())) = int32(binary.LittleEndian.Uint32(d.data[d.pos:])) + d.pos += 4 + return nil +} + +func fastEncodeInt32LE(e *Encoder, fv reflect.Value) error { + return e.WriteInt32(*(*int32)(unsafe.Pointer(fv.UnsafeAddr())), binary.LittleEndian) +} + +func fastDecodeUint64LE(d *Decoder, fv reflect.Value) error { + if d.pos+8 > len(d.data) { + return io.ErrUnexpectedEOF + } + *(*uint64)(unsafe.Pointer(fv.UnsafeAddr())) = binary.LittleEndian.Uint64(d.data[d.pos:]) + d.pos += 8 + return nil +} + +func fastEncodeUint64LE(e *Encoder, fv reflect.Value) error { + return e.WriteUint64(*(*uint64)(unsafe.Pointer(fv.UnsafeAddr())), binary.LittleEndian) +} + +func fastDecodeInt64LE(d *Decoder, fv reflect.Value) error { + if d.pos+8 > len(d.data) { + return io.ErrUnexpectedEOF + } + *(*int64)(unsafe.Pointer(fv.UnsafeAddr())) = int64(binary.LittleEndian.Uint64(d.data[d.pos:])) + d.pos += 8 + return nil +} + +func fastEncodeInt64LE(e *Encoder, fv reflect.Value) error { + return e.WriteInt64(*(*int64)(unsafe.Pointer(fv.UnsafeAddr())), binary.LittleEndian) +} + +// ---- Bin/CompactU16 fast closures (configurable byte order via tag) ---- + +func makeBinFastDecodeUint16(order binary.ByteOrder) fieldFastDecode { + if order == binary.LittleEndian { + return fastDecodeUint16LE + } + return func(d *Decoder, fv reflect.Value) error { + if d.pos+2 > len(d.data) { + return io.ErrUnexpectedEOF + } + *(*uint16)(unsafe.Pointer(fv.UnsafeAddr())) = order.Uint16(d.data[d.pos:]) + d.pos += 2 + return nil + } +} + +func makeBinFastEncodeUint16(order binary.ByteOrder) fieldFastEncode { + if order == binary.LittleEndian { + return fastEncodeUint16LE + } + return func(e *Encoder, fv reflect.Value) error { + return e.WriteUint16(*(*uint16)(unsafe.Pointer(fv.UnsafeAddr())), order) + } +} + +func makeBinFastDecodeInt16(order binary.ByteOrder) fieldFastDecode { + if order == binary.LittleEndian { + return fastDecodeInt16LE + } + return func(d *Decoder, fv reflect.Value) error { + if d.pos+2 > len(d.data) { + return io.ErrUnexpectedEOF + } + *(*int16)(unsafe.Pointer(fv.UnsafeAddr())) = int16(order.Uint16(d.data[d.pos:])) + d.pos += 2 + return nil + } +} + +func makeBinFastEncodeInt16(order binary.ByteOrder) fieldFastEncode { + if order == binary.LittleEndian { + return fastEncodeInt16LE + } + return func(e *Encoder, fv reflect.Value) error { + return e.WriteInt16(*(*int16)(unsafe.Pointer(fv.UnsafeAddr())), order) + } +} + +func makeBinFastDecodeUint32(order binary.ByteOrder) fieldFastDecode { + if order == binary.LittleEndian { + return fastDecodeUint32LE + } + return func(d *Decoder, fv reflect.Value) error { + if d.pos+4 > len(d.data) { + return io.ErrUnexpectedEOF + } + *(*uint32)(unsafe.Pointer(fv.UnsafeAddr())) = order.Uint32(d.data[d.pos:]) + d.pos += 4 + return nil + } +} + +func makeBinFastEncodeUint32(order binary.ByteOrder) fieldFastEncode { + if order == binary.LittleEndian { + return fastEncodeUint32LE + } + return func(e *Encoder, fv reflect.Value) error { + return e.WriteUint32(*(*uint32)(unsafe.Pointer(fv.UnsafeAddr())), order) + } +} + +func makeBinFastDecodeInt32(order binary.ByteOrder) fieldFastDecode { + if order == binary.LittleEndian { + return fastDecodeInt32LE + } + return func(d *Decoder, fv reflect.Value) error { + if d.pos+4 > len(d.data) { + return io.ErrUnexpectedEOF + } + *(*int32)(unsafe.Pointer(fv.UnsafeAddr())) = int32(order.Uint32(d.data[d.pos:])) + d.pos += 4 + return nil + } +} + +func makeBinFastEncodeInt32(order binary.ByteOrder) fieldFastEncode { + if order == binary.LittleEndian { + return fastEncodeInt32LE + } + return func(e *Encoder, fv reflect.Value) error { + return e.WriteInt32(*(*int32)(unsafe.Pointer(fv.UnsafeAddr())), order) + } +} + +func makeBinFastDecodeUint64(order binary.ByteOrder) fieldFastDecode { + if order == binary.LittleEndian { + return fastDecodeUint64LE + } + return func(d *Decoder, fv reflect.Value) error { + if d.pos+8 > len(d.data) { + return io.ErrUnexpectedEOF + } + *(*uint64)(unsafe.Pointer(fv.UnsafeAddr())) = order.Uint64(d.data[d.pos:]) + d.pos += 8 + return nil + } +} + +func makeBinFastEncodeUint64(order binary.ByteOrder) fieldFastEncode { + if order == binary.LittleEndian { + return fastEncodeUint64LE + } + return func(e *Encoder, fv reflect.Value) error { + return e.WriteUint64(*(*uint64)(unsafe.Pointer(fv.UnsafeAddr())), order) + } +} + +func makeBinFastDecodeInt64(order binary.ByteOrder) fieldFastDecode { + if order == binary.LittleEndian { + return fastDecodeInt64LE + } + return func(d *Decoder, fv reflect.Value) error { + if d.pos+8 > len(d.data) { + return io.ErrUnexpectedEOF + } + *(*int64)(unsafe.Pointer(fv.UnsafeAddr())) = int64(order.Uint64(d.data[d.pos:])) + d.pos += 8 + return nil + } +} + +func makeBinFastEncodeInt64(order binary.ByteOrder) fieldFastEncode { + if order == binary.LittleEndian { + return fastEncodeInt64LE + } + return func(e *Encoder, fv reflect.Value) error { + return e.WriteInt64(*(*int64)(unsafe.Pointer(fv.UnsafeAddr())), order) + } +} diff --git a/binary/types.go b/binary/types.go new file mode 100644 index 000000000..8196002ee --- /dev/null +++ b/binary/types.go @@ -0,0 +1,341 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "strconv" +) + +type SafeString string + +func (ss SafeString) MarshalWithEncoder(encoder *Encoder) error { + return encoder.WriteString(string(ss)) +} + +func (ss *SafeString) UnmarshalWithDecoder(d *Decoder) error { + s, e := d.SafeReadUTF8String() + if e != nil { + return e + } + + *ss = SafeString(s) + return nil +} + +type Bool bool + +func (b *Bool) UnmarshalJSON(data []byte) error { + var num int + err := json.Unmarshal(data, &num) + if err == nil { + *b = Bool(num != 0) + return nil + } + + var boolVal bool + if err := json.Unmarshal(data, &boolVal); err != nil { + return fmt.Errorf("couldn't unmarshal bool as int or true/false: %w", err) + } + + *b = Bool(boolVal) + return nil +} + +func (b *Bool) UnmarshalWithDecoder(decoder *Decoder) error { + value, err := decoder.ReadBool() + if err != nil { + return err + } + + *b = Bool(value) + return nil +} + +func (b Bool) MarshalWithEncoder(encoder *Encoder) error { + return encoder.WriteBool(bool(b)) +} + +type HexBytes []byte + +func (t HexBytes) MarshalJSON() ([]byte, error) { + return json.Marshal(hex.EncodeToString(t)) +} + +func (t *HexBytes) UnmarshalJSON(data []byte) (err error) { + var s string + err = json.Unmarshal(data, &s) + if err != nil { + return + } + + *t, err = hex.DecodeString(s) + return +} + +func (t HexBytes) String() string { + return hex.EncodeToString(t) +} + +func (o *HexBytes) UnmarshalWithDecoder(decoder *Decoder) error { + value, err := decoder.ReadByteSlice() + if err != nil { + return fmt.Errorf("hex bytes: %w", err) + } + + *o = HexBytes(value) + return nil +} + +func (o HexBytes) MarshalWithEncoder(encoder *Encoder) error { + return encoder.WriteBytes([]byte(o), true) +} + +type Varint16 int16 + +func (o *Varint16) UnmarshalWithDecoder(decoder *Decoder) error { + value, err := decoder.ReadVarint16() + if err != nil { + return fmt.Errorf("varint16: %w", err) + } + + *o = Varint16(value) + return nil +} + +func (o Varint16) MarshalWithEncoder(encoder *Encoder) error { + return encoder.WriteVarInt(int(o)) +} + +type Varuint16 uint16 + +func (o *Varuint16) UnmarshalWithDecoder(decoder *Decoder) error { + value, err := decoder.ReadUvarint16() + if err != nil { + return fmt.Errorf("varuint16: %w", err) + } + + *o = Varuint16(value) + return nil +} + +func (o Varuint16) MarshalWithEncoder(encoder *Encoder) error { + return encoder.WriteUVarInt(int(o)) +} + +type Varuint32 uint32 + +func (o *Varuint32) UnmarshalWithDecoder(decoder *Decoder) error { + value, err := decoder.ReadUvarint64() + if err != nil { + return fmt.Errorf("varuint32: %w", err) + } + + *o = Varuint32(value) + return nil +} + +func (o Varuint32) MarshalWithEncoder(encoder *Encoder) error { + return encoder.WriteUVarInt(int(o)) +} + +type Varint32 int32 + +func (o *Varint32) UnmarshalWithDecoder(decoder *Decoder) error { + value, err := decoder.ReadVarint32() + if err != nil { + return err + } + + *o = Varint32(value) + return nil +} + +func (o Varint32) MarshalWithEncoder(encoder *Encoder) error { + return encoder.WriteVarInt(int(o)) +} + +type JSONFloat64 float64 + +func (f *JSONFloat64) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + return errors.New("empty value") + } + + if data[0] == '"' { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + val, err := strconv.ParseFloat(s, 64) + if err != nil { + return err + } + + *f = JSONFloat64(val) + + return nil + } + + var fl float64 + if err := json.Unmarshal(data, &fl); err != nil { + return err + } + + *f = JSONFloat64(fl) + + return nil +} + +func (f *JSONFloat64) UnmarshalWithDecoder(dec *Decoder) error { + value, err := dec.ReadFloat64(dec.currentFieldOpt.Order) + if err != nil { + return err + } + + *f = JSONFloat64(value) + return nil +} + +func (f JSONFloat64) MarshalWithEncoder(enc *Encoder) error { + return enc.WriteFloat64(float64(f), enc.currentFieldOpt.Order) +} + +type Int64 int64 + +func (i Int64) MarshalJSON() (data []byte, err error) { + if i > 0xffffffff || i < -0xffffffff { + encodedInt, err := json.Marshal(int64(i)) + if err != nil { + return nil, err + } + data = append([]byte{'"'}, encodedInt...) + data = append(data, '"') + return data, nil + } + return json.Marshal(int64(i)) +} + +func (i *Int64) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + return errors.New("empty value") + } + + if data[0] == '"' { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + val, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return err + } + + *i = Int64(val) + + return nil + } + + var v int64 + if err := json.Unmarshal(data, &v); err != nil { + return err + } + + *i = Int64(v) + + return nil +} + +func (i *Int64) UnmarshalWithDecoder(dec *Decoder) error { + value, err := dec.ReadInt64(dec.currentFieldOpt.Order) + if err != nil { + return err + } + + *i = Int64(value) + return nil +} + +func (i Int64) MarshalWithEncoder(enc *Encoder) error { + return enc.WriteInt64(int64(i), enc.currentFieldOpt.Order) +} + +type Uint64 uint64 + +func (i Uint64) MarshalJSON() (data []byte, err error) { + if i > 0xffffffff { + encodedInt, err := json.Marshal(uint64(i)) + if err != nil { + return nil, err + } + data = append([]byte{'"'}, encodedInt...) + data = append(data, '"') + return data, nil + } + return json.Marshal(uint64(i)) +} + +func (i *Uint64) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + return errors.New("empty value") + } + + if data[0] == '"' { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + val, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return err + } + + *i = Uint64(val) + + return nil + } + + var v uint64 + if err := json.Unmarshal(data, &v); err != nil { + return err + } + + *i = Uint64(v) + + return nil +} + +func (i *Uint64) UnmarshalWithDecoder(dec *Decoder) error { + value, err := dec.ReadUint64(dec.currentFieldOpt.Order) + if err != nil { + return err + } + + *i = Uint64(value) + return nil +} + +func (i Uint64) MarshalWithEncoder(enc *Encoder) error { + return enc.WriteUint64(uint64(i), enc.currentFieldOpt.Order) +} diff --git a/binary/u128.go b/binary/u128.go new file mode 100644 index 000000000..396df7939 --- /dev/null +++ b/binary/u128.go @@ -0,0 +1,303 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "encoding/binary" + "encoding/hex" + "encoding/json" + "fmt" + "math/big" + "strings" +) + +// Uint128 +type Uint128 struct { + Lo uint64 + Hi uint64 + Endianness binary.ByteOrder +} + +func NewUint128BigEndian() *Uint128 { + return &Uint128{ + Endianness: binary.BigEndian, + } +} + +func NewUint128LittleEndian() *Uint128 { + return &Uint128{ + Endianness: binary.LittleEndian, + } +} + +func (i Uint128) getByteOrder() binary.ByteOrder { + if i.Endianness == nil { + return defaultByteOrder + } + return i.Endianness +} + +func (i Int128) getByteOrder() binary.ByteOrder { + return Uint128(i).getByteOrder() +} +func (i Float128) getByteOrder() binary.ByteOrder { + return Uint128(i).getByteOrder() +} + +// Bytes returns the 16-byte big-endian representation of the value, regardless +// of the configured Endianness. The big-endian form is what (*big.Int).SetBytes +// expects and is what BigInt/String/DecimalString consume downstream — this +// function exists to feed those, not to produce wire bytes. For wire encoding, +// use Encoder.WriteUint128 with the desired byte order. +func (i Uint128) Bytes() []byte { + buf := make([]byte, 16) + order := i.getByteOrder() + if order == binary.LittleEndian { + order.PutUint64(buf[:8], i.Lo) + order.PutUint64(buf[8:], i.Hi) + ReverseBytes(buf) + } else { + order.PutUint64(buf[:8], i.Hi) + order.PutUint64(buf[8:], i.Lo) + } + return buf +} + +func (i Uint128) BigInt() *big.Int { + buf := i.Bytes() + value := (&big.Int{}).SetBytes(buf) + return value +} + +func (i Uint128) String() string { + //Same for Int128, Float128 + return i.DecimalString() +} + +func (i Uint128) DecimalString() string { + return i.BigInt().String() +} + +func (i Uint128) HexString() string { + number := i.Bytes() + return fmt.Sprintf("0x%s", hex.EncodeToString(number)) +} + +func (i Uint128) MarshalJSON() (data []byte, err error) { + return []byte(`"` + i.String() + `"`), nil +} + +func ReverseBytes(s []byte) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } +} + +func (i *Uint128) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + return nil + } + + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + if strings.HasPrefix(s, "0x") || strings.HasPrefix(s, "0X") { + return i.unmarshalJSON_hex(s) + } + + return i.unmarshalJSON_decimal(s) +} + +func (i *Uint128) unmarshalJSON_decimal(s string) error { + parsed, ok := (&big.Int{}).SetString(s, 0) + if !ok { + return fmt.Errorf("could not parse %q", s) + } + // FillBytes panics on negatives or values that don't fit — guard both. + if parsed.Sign() < 0 { + return fmt.Errorf("uint128: negative value %q", s) + } + if parsed.BitLen() > 128 { + return fmt.Errorf("uint128: value %q exceeds 128 bits", s) + } + oo := parsed.FillBytes(make([]byte, 16)) + ReverseBytes(oo) + + dec := NewBinDecoder(oo) + + out, err := dec.ReadUint128(i.getByteOrder()) + if err != nil { + return err + } + i.Lo = out.Lo + i.Hi = out.Hi + + return nil +} + +func (i *Uint128) unmarshalJSON_hex(s string) error { + // 16 bytes = 32 hex characters. + truncatedVal := s[2:] + if len(truncatedVal) != 32 { + return fmt.Errorf("uint128 expects 32 hex characters after 0x, had %d", len(truncatedVal)) + } + + data, err := hex.DecodeString(truncatedVal) + if err != nil { + return err + } + + order := i.getByteOrder() + if order == binary.LittleEndian { + i.Lo = order.Uint64(data[:8]) + i.Hi = order.Uint64(data[8:]) + } else { + i.Hi = order.Uint64(data[:8]) + i.Lo = order.Uint64(data[8:]) + } + + return nil +} + +func (i *Uint128) UnmarshalWithDecoder(dec *Decoder) error { + order := i.getByteOrder() + if dec.currentFieldOpt.Order != nil { + order = dec.currentFieldOpt.Order + } + value, err := dec.ReadUint128(order) + if err != nil { + return err + } + + *i = value + return nil +} + +func (i Uint128) MarshalWithEncoder(enc *Encoder) error { + order := i.getByteOrder() + if enc.currentFieldOpt.Order != nil { + order = enc.currentFieldOpt.Order + } + return enc.WriteUint128(i, order) +} + +// Int128 +type Int128 Uint128 + +func (i Int128) BigInt() *big.Int { + comp := byte(0x80) + buf := Uint128(i).Bytes() + + var value *big.Int + if (buf[0] & comp) == comp { + buf = twosComplement(buf) + value = (&big.Int{}).SetBytes(buf) + value = value.Neg(value) + } else { + value = (&big.Int{}).SetBytes(buf) + } + return value +} + +func (i Int128) String() string { + return Uint128(i).String() +} + +func (i Int128) DecimalString() string { + return i.BigInt().String() +} + +func (i Int128) MarshalJSON() (data []byte, err error) { + return []byte(`"` + Uint128(i).String() + `"`), nil +} + +func (i *Int128) UnmarshalJSON(data []byte) error { + var el Uint128 + if err := json.Unmarshal(data, &el); err != nil { + return err + } + + out := Int128(el) + *i = out + + return nil +} + +func (i *Int128) UnmarshalWithDecoder(dec *Decoder) error { + order := i.getByteOrder() + if dec.currentFieldOpt.Order != nil { + order = dec.currentFieldOpt.Order + } + value, err := dec.ReadInt128(order) + if err != nil { + return err + } + + *i = value + return nil +} + +func (i Int128) MarshalWithEncoder(enc *Encoder) error { + order := i.getByteOrder() + if enc.currentFieldOpt.Order != nil { + order = enc.currentFieldOpt.Order + } + return enc.WriteInt128(i, order) +} + +type Float128 Uint128 + +func (i Float128) MarshalJSON() (data []byte, err error) { + return []byte(`"` + Uint128(i).String() + `"`), nil +} + +func (i *Float128) UnmarshalJSON(data []byte) error { + var el Uint128 + if err := json.Unmarshal(data, &el); err != nil { + return err + } + + out := Float128(el) + *i = out + + return nil +} + +func (i *Float128) UnmarshalWithDecoder(dec *Decoder) error { + order := i.getByteOrder() + if dec.currentFieldOpt.Order != nil { + order = dec.currentFieldOpt.Order + } + value, err := dec.ReadFloat128(order) + if err != nil { + return err + } + + *i = Float128(value) + return nil +} + +func (i Float128) MarshalWithEncoder(enc *Encoder) error { + order := i.getByteOrder() + if enc.currentFieldOpt.Order != nil { + order = enc.currentFieldOpt.Order + } + return enc.WriteUint128(Uint128(i), order) +} diff --git a/binary/u128_test.go b/binary/u128_test.go new file mode 100644 index 000000000..9a3711888 --- /dev/null +++ b/binary/u128_test.go @@ -0,0 +1,64 @@ +package bin + +import ( + "encoding/json" + "testing" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" +) + +func TestUint128(t *testing.T) { + // from bytes: + data := []byte{51, 47, 223, 255, 255, 255, 255, 255, 30, 12, 0, 0, 0, 0, 0, 0} + + numberString := "57240246860720736513843" + parsed, err := decimal.NewFromString(numberString) + if err != nil { + panic(err) + } + { + if parsed.String() != numberString { + t.Errorf("parsed.String() != numberString") + } + } + + { + u128 := NewUint128LittleEndian() + err := u128.UnmarshalWithDecoder(NewBorshDecoder(data)) + require.NoError(t, err) + require.Equal(t, uint64(3102), u128.Hi) + require.Equal(t, uint64(18446744073707401011), u128.Lo) + require.Equal(t, parsed.BigInt(), u128.BigInt()) + require.Equal(t, parsed.String(), u128.DecimalString()) + } + { + u128 := NewUint128BigEndian() + ReverseBytes(data) + err := u128.UnmarshalWithDecoder(NewBorshDecoder(data)) + require.NoError(t, err) + require.Equal(t, uint64(3102), u128.Hi) + require.Equal(t, uint64(18446744073707401011), u128.Lo) + require.Equal(t, parsed.BigInt(), u128.BigInt()) + require.Equal(t, parsed.String(), u128.DecimalString()) + } + { + j := []byte(`{"i":"57240246860720736513843"}`) + var object struct { + I Uint128 `json:"i"` + } + + err := json.Unmarshal(j, &object) + require.NoError(t, err) + require.Equal(t, uint64(3102), object.I.Hi) + require.Equal(t, uint64(18446744073707401011), object.I.Lo) + require.Equal(t, parsed.BigInt(), object.I.BigInt()) + require.Equal(t, parsed.String(), object.I.DecimalString()) + + { + out, err := json.Marshal(object) + require.NoError(t, err) + require.Equal(t, j, out) + } + } +} diff --git a/binary/utils.go b/binary/utils.go new file mode 100644 index 000000000..43dd28859 --- /dev/null +++ b/binary/utils.go @@ -0,0 +1,72 @@ +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "math/big" + "reflect" +) + +// isInvalidValue reports whether rv is the zero reflect.Value (Kind == +// Invalid). It is NOT the same as rv.IsZero(): a valid reflect.Value holding +// a zero-valued T returns false here. Used by the encoders to short-circuit +// before attempting to interface() a nil reflect.Value. +func isInvalidValue(rv reflect.Value) bool { + return rv.Kind() == reflect.Invalid +} + +// asBinaryMarshaler returns a BinaryMarshaler for rv if one is reachable. +// It first tries the value itself; if that fails and rv is addressable, it +// retries via rv.Addr() so that marshalers implemented on *T are still found +// when the field is held by value. Without the second try, a legitimate +// custom marshaler is silently skipped and the encoder falls back to the +// generic reflect path — producing a different wire encoding. +// +// Performance note: rv.Interface() boxes the value into an interface{}, +// which heap-allocates for any non-pointer type larger than a word. We +// short-circuit via reflect.Type.Implements (a static type-info lookup +// with no allocation) so the boxing only happens for types that actually +// satisfy BinaryMarshaler — turning the dominant per-field allocation +// site into a no-op for types like solana.PublicKey. +func asBinaryMarshaler(rv reflect.Value) (BinaryMarshaler, bool) { + if !rv.IsValid() { + return nil, false + } + rt := rv.Type() + if rt.Implements(marshalableType) && rv.CanInterface() { + if m, ok := rv.Interface().(BinaryMarshaler); ok { + return m, true + } + } + if rv.CanAddr() && reflect.PointerTo(rt).Implements(marshalableType) { + addr := rv.Addr() + if addr.CanInterface() { + if m, ok := addr.Interface().(BinaryMarshaler); ok { + return m, true + } + } + } + return nil, false +} + +func twosComplement(v []byte) []byte { + buf := make([]byte, len(v)) + for i, b := range v { + buf[i] = b ^ byte(0xff) + } + one := big.NewInt(1) + value := (&big.Int{}).SetBytes(buf) + return value.Add(value, one).Bytes() +} diff --git a/binary/utils_test.go b/binary/utils_test.go new file mode 100644 index 000000000..a7d219aa1 --- /dev/null +++ b/binary/utils_test.go @@ -0,0 +1,52 @@ +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_twosComplement(t *testing.T) { + tests := []struct { + name string + in []byte + expect []byte + }{ + { + name: "empty array", + in: []byte{}, + expect: []byte{0x1}, + }, + { + name: "one element", + in: []byte{0x01}, + expect: []byte{0xff}, + }, + { + name: "basic test", + in: []byte{0xaa, 0xbb, 0xcc}, + expect: []byte{0x55, 0x44, 0x34}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expect, twosComplement(test.in)) + }) + } + +} diff --git a/binary/variant.go b/binary/variant.go new file mode 100644 index 000000000..185565cf6 --- /dev/null +++ b/binary/variant.go @@ -0,0 +1,355 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "bytes" + "encoding/binary" + "fmt" + "reflect" + "strings" +) + +// +/// Variant (emulates `fc::static_variant` type) +// + +type Variant interface { + Assign(typeID TypeID, impl any) + Obtain(*VariantDefinition) (typeID TypeID, typeName string, impl any) +} + +type VariantType struct { + Name string + Type any +} + +type VariantDefinition struct { + typeIDToType map[TypeID]reflect.Type + typeIDToName map[TypeID]string + typeNameToID map[string]TypeID + typeIDEncoding TypeIDEncoding +} + +// TypeID defines the internal representation of an instruction type ID +// (or account type, etc. in anchor programs) +// and it's used to associate instructions to decoders in the variant tracker. +type TypeID [8]byte + +func (vid TypeID) Bytes() []byte { + return vid[:] +} + +// Uvarint32 parses the TypeID to a uint32. +func (vid TypeID) Uvarint32() uint32 { + return Uvarint32FromTypeID(vid) +} + +// Uint32 parses the TypeID to a uint32. +func (vid TypeID) Uint32() uint32 { + return Uint32FromTypeID(vid, binary.LittleEndian) +} + +// Uint8 parses the TypeID to a Uint8. +func (vid TypeID) Uint8() uint8 { + return Uint8FromTypeID(vid) +} + +// Equal returns true if the provided bytes are equal to +// the bytes of the TypeID. +func (vid TypeID) Equal(b []byte) bool { + return bytes.Equal(vid.Bytes(), b) +} + +// TypeIDFromBytes converts a []byte to a TypeID. +// The provided slice must be 8 bytes long or less. +func TypeIDFromBytes(slice []byte) (id TypeID) { + // TODO: panic if len(slice) > 8 ??? + copy(id[:], slice) + return id +} + +// TypeIDFromSighash converts a sighash bytes to a TypeID. +func TypeIDFromSighash(sh []byte) TypeID { + return TypeIDFromBytes(sh) +} + +// TypeIDFromUvarint32 converts a Uvarint to a TypeID. +func TypeIDFromUvarint32(v uint32) TypeID { + buf := make([]byte, 8) + l := binary.PutUvarint(buf, uint64(v)) + return TypeIDFromBytes(buf[:l]) +} + +// TypeIDFromUint32 converts a uint32 to a TypeID. +func TypeIDFromUint32(v uint32, bo binary.ByteOrder) TypeID { + out := make([]byte, TypeSize.Uint32) + bo.PutUint32(out, v) + return TypeIDFromBytes(out) +} + +// TypeIDFromUint8 converts a uint8 to a TypeID. +func TypeIDFromUint8(v uint8) TypeID { + return TypeIDFromBytes([]byte{v}) +} + +// Uvarint32FromTypeID parses a TypeID bytes to a uvarint 32. +func Uvarint32FromTypeID(vid TypeID) (out uint32) { + l, _ := binary.Uvarint(vid[:]) + out = uint32(l) + return out +} + +// Uint32FromTypeID parses a TypeID bytes to a uint32. +func Uint32FromTypeID(vid TypeID, order binary.ByteOrder) (out uint32) { + out = order.Uint32(vid[:]) + return out +} + +// Uint32FromTypeID parses a TypeID bytes to a uint8. +func Uint8FromTypeID(vid TypeID) (out uint8) { + return vid[0] +} + +type TypeIDEncoding uint32 + +const ( + Uvarint32TypeIDEncoding TypeIDEncoding = iota + Uint32TypeIDEncoding + Uint8TypeIDEncoding + // AnchorTypeIDEncoding is the instruction ID encoding used by programs + // written using the anchor SDK. + // The typeID is the sighash of the instruction. + AnchorTypeIDEncoding + // No type ID; ONLY ONE VARIANT PER PROGRAM. + NoTypeIDEncoding +) + +var NoTypeIDDefaultID = TypeIDFromUint8(0) + +// NewVariantDefinition creates a variant definition based on the *ordered* provided types. +// +// - For anchor instructions, it's the name that defines the binary variant value. +// - For all other types, it's the ordering that defines the binary variant value just like in native `nodeos` C++ +// and in Smart Contract via the `std::variant` type. It's important to pass the entries +// in the right order! +// +// This variant definition can now be passed to functions of `BaseVariant` to implement +// marshal/unmarshaling functionalities for binary & JSON. +// +// This function panics on invalid input (unknown TypeIDEncoding, or +// NoTypeIDEncoding with ≠1 variants). Callers should validate their inputs +// upfront — typically at init time with a known-at-compile-time encoding — +// since variant definitions are expected to be constructed once and reused. +func NewVariantDefinition(typeIDEncoding TypeIDEncoding, types []VariantType) (out *VariantDefinition) { + typeCount := len(types) + out = &VariantDefinition{ + typeIDEncoding: typeIDEncoding, + typeIDToType: make(map[TypeID]reflect.Type, typeCount), + typeIDToName: make(map[TypeID]string, typeCount), + typeNameToID: make(map[string]TypeID, typeCount), + } + + switch typeIDEncoding { + case Uvarint32TypeIDEncoding: + for i, typeDef := range types { + typeID := TypeIDFromUvarint32(uint32(i)) + + // FIXME: Check how the reflect.Type is used and cache all its usage in the definition. + // Right now, on each Unmarshal, we re-compute some expensive stuff that can be + // re-used like the `typeGo.Elem()` which is always the same. It would be preferable + // to have those already pre-defined here so we can actually speed up the + // Unmarshal code. + out.typeIDToType[typeID] = reflect.TypeOf(typeDef.Type) + out.typeIDToName[typeID] = typeDef.Name + out.typeNameToID[typeDef.Name] = typeID + } + case Uint32TypeIDEncoding: + for i, typeDef := range types { + typeID := TypeIDFromUint32(uint32(i), binary.LittleEndian) + + // FIXME: Check how the reflect.Type is used and cache all its usage in the definition. + // Right now, on each Unmarshal, we re-compute some expensive stuff that can be + // re-used like the `typeGo.Elem()` which is always the same. It would be preferable + // to have those already pre-defined here so we can actually speed up the + // Unmarshal code. + out.typeIDToType[typeID] = reflect.TypeOf(typeDef.Type) + out.typeIDToName[typeID] = typeDef.Name + out.typeNameToID[typeDef.Name] = typeID + } + case Uint8TypeIDEncoding: + for i, typeDef := range types { + typeID := TypeIDFromUint8(uint8(i)) + + // FIXME: Check how the reflect.Type is used and cache all its usage in the definition. + // Right now, on each Unmarshal, we re-compute some expensive stuff that can be + // re-used like the `typeGo.Elem()` which is always the same. It would be preferable + // to have those already pre-defined here so we can actually speed up the + // Unmarshal code. + out.typeIDToType[typeID] = reflect.TypeOf(typeDef.Type) + out.typeIDToName[typeID] = typeDef.Name + out.typeNameToID[typeDef.Name] = typeID + } + case AnchorTypeIDEncoding: + for _, typeDef := range types { + typeID := TypeIDFromSighash(Sighash(SIGHASH_GLOBAL_NAMESPACE, typeDef.Name)) + + // FIXME: Check how the reflect.Type is used and cache all its usage in the definition. + // Right now, on each Unmarshal, we re-compute some expensive stuff that can be + // re-used like the `typeGo.Elem()` which is always the same. It would be preferable + // to have those already pre-defined here so we can actually speed up the + // Unmarshal code. + out.typeIDToType[typeID] = reflect.TypeOf(typeDef.Type) + out.typeIDToName[typeID] = typeDef.Name + out.typeNameToID[typeDef.Name] = typeID + } + case NoTypeIDEncoding: + if len(types) != 1 { + panic(fmt.Sprintf("NoTypeIDEncoding can only have one variant type definition, got %v", len(types))) + } + typeDef := types[0] + + typeID := NoTypeIDDefaultID + + // FIXME: Check how the reflect.Type is used and cache all its usage in the definition. + // Right now, on each Unmarshal, we re-compute some expensive stuff that can be + // re-used like the `typeGo.Elem()` which is always the same. It would be preferable + // to have those already pre-defined here so we can actually speed up the + // Unmarshal code. + out.typeIDToType[typeID] = reflect.TypeOf(typeDef.Type) + out.typeIDToName[typeID] = typeDef.Name + out.typeNameToID[typeDef.Name] = typeID + + default: + panic(fmt.Errorf("unsupported TypeIDEncoding: %v", typeIDEncoding)) + } + + return out +} + +func (d *VariantDefinition) TypeID(name string) TypeID { + id, found := d.typeNameToID[name] + if !found { + knownNames := make([]string, len(d.typeNameToID)) + i := 0 + for name := range d.typeNameToID { + knownNames[i] = name + i++ + } + + panic(fmt.Errorf("trying to use an unknown type name %q, known names are %q", name, strings.Join(knownNames, ", "))) + } + + return id +} + +type ( + VariantImplFactory = func() any + OnVariant = func(impl any) error +) + +type BaseVariant struct { + TypeID TypeID + Impl any +} + +var _ Variant = &BaseVariant{} + +func (a *BaseVariant) Assign(typeID TypeID, impl any) { + a.TypeID = typeID + a.Impl = impl +} + +func (a *BaseVariant) Obtain(def *VariantDefinition) (typeID TypeID, typeName string, impl any) { + return a.TypeID, def.typeIDToName[a.TypeID], a.Impl +} + +func (a *BaseVariant) UnmarshalBinaryVariant(decoder *Decoder, def *VariantDefinition) error { + var typeID TypeID + switch def.typeIDEncoding { + case Uvarint32TypeIDEncoding: + val, err := decoder.ReadUvarint32() + if err != nil { + return fmt.Errorf("uvarint32: unable to read variant type id: %w", err) + } + typeID = TypeIDFromUvarint32(val) + case Uint32TypeIDEncoding: + val, err := decoder.ReadUint32(binary.LittleEndian) + if err != nil { + return fmt.Errorf("uint32: unable to read variant type id: %w", err) + } + typeID = TypeIDFromUint32(val, binary.LittleEndian) + case Uint8TypeIDEncoding: + id, err := decoder.ReadUint8() + if err != nil { + return fmt.Errorf("uint8: unable to read variant type id: %w", err) + } + typeID = TypeIDFromBytes([]byte{id}) + case AnchorTypeIDEncoding: + id, err := decoder.ReadTypeID() + if err != nil { + return fmt.Errorf("anchor: unable to read variant type id: %w", err) + } + typeID = id + case NoTypeIDEncoding: + typeID = NoTypeIDDefaultID + } + + a.TypeID = typeID + + typeGo := def.typeIDToType[typeID] + if typeGo == nil { + return fmt.Errorf("no known type for type %d", typeID) + } + + if typeGo.Kind() == reflect.Ptr { + a.Impl = reflect.New(typeGo.Elem()).Interface() + if err := decoder.Decode(a.Impl); err != nil { + return fmt.Errorf("unable to decode variant type %d: %w", typeID, err) + } + } else { + // This is not the most optimal way of doing things for "value" + // types (over "pointer" types) as we always allocate a new pointer + // element, unmarshal it and then either keep the pointer type or turn + // it into a value type. + // + // However, in non-reflection based code, one would do like this and + // avoid an `new` memory allocation: + // + // ``` + // name := eos.Name("") + // json.Unmarshal(data, &name) + // ``` + // + // This would work without a problem. In reflection code however, I + // did not find how one can go from `reflect.Zero(typeGo)` (which is + // the equivalence of doing `name := eos.Name("")`) and take the + // pointer to it so it can be unmarshalled correctly. + // + // A played with various iteration, and nothing got it working. Maybe + // the next step would be to explore the `unsafe` package and obtain + // an unsafe pointer and play with it. + value := reflect.New(typeGo) + if err := decoder.Decode(value.Interface()); err != nil { + return fmt.Errorf("unable to decode variant type %d: %w", typeID, err) + } + + a.Impl = value.Elem().Interface() + } + return nil +} diff --git a/binary/variant_test.go b/binary/variant_test.go new file mode 100644 index 000000000..a24ee4ade --- /dev/null +++ b/binary/variant_test.go @@ -0,0 +1,315 @@ +// Copyright 2021 github.com/gagliardetto +// This file has been modified by github.com/gagliardetto +// +// Copyright 2020 dfuse Platform Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bin + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTypeID(t *testing.T) { + { + ha := Sighash(SIGHASH_GLOBAL_NAMESPACE, "hello") + vid := TypeIDFromSighash(ha) + require.Equal(t, ha, vid.Bytes()) + require.True(t, vid.Equal(ha)) + } + { + expected := uint32(66) + vid := TypeIDFromUint32(expected, binary.LittleEndian) + + got := Uint32FromTypeID(vid, binary.LittleEndian) + require.Equal(t, expected, got) + require.Equal(t, expected, vid.Uint32()) + } + { + expected := uint32(66) + vid := TypeIDFromUvarint32(expected) + + got := Uvarint32FromTypeID(vid) + require.Equal(t, expected, got) + require.Equal(t, expected, vid.Uvarint32()) + } + { + { + vid := TypeIDFromBytes([]byte{}) + expected := []byte{0, 0, 0, 0, 0, 0, 0, 0} + require.Equal(t, expected, vid.Bytes()) + } + { + expected := []byte{1, 2, 3, 4, 5, 6, 7, 8} + vid := TypeIDFromBytes(expected) + require.Equal(t, expected, vid.Bytes()) + } + } + { + expected := uint8(33) + vid := TypeIDFromUint8(expected) + got := Uint8FromTypeID(vid) + require.Equal(t, expected, got) + require.Equal(t, expected, vid.Uint8()) + } + { + m := map[TypeID]string{ + TypeIDFromSighash(Sighash(SIGHASH_GLOBAL_NAMESPACE, "hello")): "hello", + TypeIDFromSighash(Sighash(SIGHASH_GLOBAL_NAMESPACE, "world")): "world", + } + + expected := "world" + require.Equal(t, + expected, + m[TypeIDFromSighash(Sighash(SIGHASH_GLOBAL_NAMESPACE, "world"))], + ) + } +} + +type Forest struct { + T Tree +} + +type Tree struct { + Padding [5]byte + NodeCount uint32 `bin:"sizeof=Nodes"` + Random uint64 + Nodes []*Node +} + +var NodeVariantDef = NewVariantDefinition( + Uint32TypeIDEncoding, + + []VariantType{ + {"left_node", (*NodeLeft)(nil)}, + {"right_node", (*NodeRight)(nil)}, + {"inner_node", (*NodeInner)(nil)}, + }) + +type Node struct { + BaseVariant +} + +type NodeLeft struct { + Key uint32 + Description string +} + +type NodeRight struct { + Owner uint64 + Padding [2]byte + Quantity Uint64 +} + +type NodeInner struct { + Key Uint128 +} + +func (n *Node) UnmarshalWithDecoder(decoder *Decoder) error { + return n.BaseVariant.UnmarshalBinaryVariant(decoder, NodeVariantDef) +} + +func (n *Node) MarshalWithEncoder(encoder *Encoder) error { + err := encoder.WriteUint32(n.TypeID.Uint32(), binary.LittleEndian) + if err != nil { + return err + } + return encoder.Encode(n.Impl) +} + +func TestDecode_Variant(t *testing.T) { + buf := []byte{ + 0x73, 0x65, 0x72, 0x75, 0x6d, // Padding[5]byte + 0x05, 0x00, 0x00, 0x00, // Node length 5 + 0xff, 0xff, 0x00, 0x00, 0x00, 0x0, 0x00, 0x00, // ROOT 65,535 + 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x61, 0x62, 0x63, // left node -> key = 3, description "abc" + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // right node -> owner = 3, quantity 13 + 0x01, 0x00, 0x00, 0x00, 0x52, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x9b, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // right node -> owner = 82, quantity 923 + 0x02, 0x00, 0x00, 0x00, 0xff, 0x7f, 0xc6, 0xa4, 0x7e, 0x8d, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // inner node -> key = 999999999999999 + 0x02, 0x00, 0x00, 0x00, 0x23, 0xd3, 0xd8, 0x9a, 0x99, 0x7e, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // inner node -> key = 983623123129123 + } + + decoder := NewBinDecoder(buf) + forest := Forest{} + err := decoder.Decode(&forest) + require.NoError(t, err) + require.Equal(t, 0, decoder.Remaining()) + assert.Equal(t, Tree{ + Padding: [5]byte{0x73, 0x65, 0x72, 0x75, 0x6d}, + NodeCount: 5, + Random: 65535, + Nodes: []*Node{ + { + BaseVariant: BaseVariant{ + TypeID: TypeIDFromUint32(0, binary.LittleEndian), + Impl: &NodeLeft{ + Key: 3, + Description: "abc", + }, + }, + }, + { + BaseVariant: BaseVariant{ + TypeID: TypeIDFromUint32(1, binary.LittleEndian), + Impl: &NodeRight{ + Owner: 3, + Padding: [2]byte{0x00, 0x00}, + Quantity: 13, + }, + }, + }, + { + BaseVariant: BaseVariant{ + TypeID: TypeIDFromUint32(1, binary.LittleEndian), + Impl: &NodeRight{ + Owner: 82, + Padding: [2]byte{0x00, 0x00}, + Quantity: 923, + }, + }, + }, + { + BaseVariant: BaseVariant{ + TypeID: TypeIDFromUint32(2, binary.LittleEndian), + Impl: &NodeInner{ + Key: Uint128{ + Lo: 999999999999999, + Hi: 0, + }, + }, + }, + }, + { + BaseVariant: BaseVariant{ + TypeID: TypeIDFromUint32(2, binary.LittleEndian), + Impl: &NodeInner{ + Key: Uint128{ + Lo: 983623123129123, + Hi: 0, + }, + }, + }, + }, + }, + }, forest.T) +} + +func TestEncode_Variant(t *testing.T) { + expectBuf := []byte{ + 0x73, 0x65, 0x72, 0x75, 0x6d, // Padding[5]byte + 0x05, 0x00, 0x00, 0x00, // Node length 5 + 0xff, 0xff, 0x00, 0x00, 0x00, 0x0, 0x00, 0x00, // ROOT 65,535 + 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x61, 0x62, 0x63, // left node -> key = 3, description "abc" + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // right node -> owner = 3, quantity 13 + 0x01, 0x00, 0x00, 0x00, 0x52, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x9b, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // right node -> owner = 82, quantity 923 + 0x02, 0x00, 0x00, 0x00, 0xff, 0x7f, 0xc6, 0xa4, 0x7e, 0x8d, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // inner node -> key = 999999999999999 + 0x02, 0x00, 0x00, 0x00, 0x23, 0xd3, 0xd8, 0x9a, 0x99, 0x7e, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // inner node -> key = 983623123129123 + } + + buf := new(bytes.Buffer) + enc := NewBinEncoder(buf) + + enc.Encode(&Forest{T: Tree{ + Padding: [5]byte{0x73, 0x65, 0x72, 0x75, 0x6d}, + NodeCount: 5, + Random: 65535, + Nodes: []*Node{ + { + BaseVariant: BaseVariant{ + TypeID: TypeIDFromUint32(0, binary.LittleEndian), + Impl: &NodeLeft{ + Key: 3, + Description: "abc", + }, + }, + }, + { + BaseVariant: BaseVariant{ + TypeID: TypeIDFromUint32(1, binary.LittleEndian), + Impl: &NodeRight{ + Owner: 3, + Padding: [2]byte{0x00, 0x00}, + Quantity: 13, + }, + }, + }, + { + BaseVariant: BaseVariant{ + TypeID: TypeIDFromUint32(1, binary.LittleEndian), + Impl: &NodeRight{ + Owner: 82, + Padding: [2]byte{0x00, 0x00}, + Quantity: 923, + }, + }, + }, + { + BaseVariant: BaseVariant{ + TypeID: TypeIDFromUint32(2, binary.LittleEndian), + Impl: &NodeInner{ + Key: Uint128{ + Lo: 999999999999999, + Hi: 0, + }, + }, + }, + }, + { + BaseVariant: BaseVariant{ + TypeID: TypeIDFromUint32(2, binary.LittleEndian), + Impl: &NodeInner{ + Key: Uint128{ + Lo: 983623123129123, + Hi: 0, + }, + }, + }, + }, + }, + }}) + + assert.Equal(t, expectBuf, buf.Bytes()) +} + +type unexportesStruct struct { + value uint32 +} + +func TestDecode_UnexporterStruct(t *testing.T) { + buf := []byte{ + 0x05, 0x00, 0x00, 0x00, + } + + decoder := NewBinDecoder(buf) + s := unexportesStruct{} + err := decoder.Decode(&s) + require.NoError(t, err) + require.Equal(t, 4, decoder.Remaining()) + assert.Equal(t, unexportesStruct{value: 0}, s) +} + +func TestEncode_UnexporterStruct(t *testing.T) { + var expectData []byte + + buf := new(bytes.Buffer) + enc := NewBinEncoder(buf) + + enc.Encode(&unexportesStruct{value: 5}) + assert.Equal(t, expectData, buf.Bytes()) +} diff --git a/cmd/slnc/cmd/decoding.go b/cmd/slnc/cmd/decoding.go index 66bd97920..852145872 100644 --- a/cmd/slnc/cmd/decoding.go +++ b/cmd/slnc/cmd/decoding.go @@ -18,7 +18,7 @@ import ( "fmt" "log" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/programs/token" ) diff --git a/cmd/slnc/cmd/get_spl_token.go b/cmd/slnc/cmd/get_spl_token.go index 8b3d5d9a8..82b9d724c 100644 --- a/cmd/slnc/cmd/get_spl_token.go +++ b/cmd/slnc/cmd/get_spl_token.go @@ -22,7 +22,7 @@ import ( "log" "os" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/programs/token" "github.com/gagliardetto/solana-go/rpc" diff --git a/go.mod b/go.mod index c40a36b66..8ba6fc5e1 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,11 @@ module github.com/gagliardetto/solana-go go 1.24.0 require ( - github.com/gagliardetto/binary v0.8.0 github.com/gagliardetto/gofuzz v1.2.2 github.com/gagliardetto/treeout v0.1.4 github.com/google/uuid v1.6.0 github.com/mr-tron/base58 v1.2.0 + github.com/shopspring/decimal v1.4.0 ) require ( diff --git a/go.sum b/go.sum index 360d066f6..8e8bc48ae 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,6 @@ github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7z github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= -github.com/gagliardetto/binary v0.8.0 h1:U9ahc45v9HW0d15LoN++vIXSJyqR/pWw8DDlhd7zvxg= -github.com/gagliardetto/binary v0.8.0/go.mod h1:2tfj51g5o9dnvsc+fL3Jxr22MuWzYXwx9wEoN0XQ7/c= github.com/gagliardetto/gofuzz v1.2.2 h1:XL/8qDMzcgvR4+CyRQW9UGdwPRPMHVJfqQ/uMvSUuQw= github.com/gagliardetto/gofuzz v1.2.2/go.mod h1:bkH/3hYLZrMLbfYWA0pWzXmi5TTRZnu4pMGZBkqMKvY= github.com/gagliardetto/treeout v0.1.4 h1:ozeYerrLCmCubo1TcIjFiOWTTGteOOHND1twdFpgwaw= @@ -79,7 +77,6 @@ github.com/klauspost/compress v1.11.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYs github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -125,8 +122,8 @@ github.com/ryanuber/columnize v2.1.2+incompatible h1:C89EOx/XBWwIXl8wm8OPJBd7kPF github.com/ryanuber/columnize v2.1.2+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFTqyMTC9k= github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk= -github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= -github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.14.0 h1:9tH6MapGnn/j0eb0yIXiLjERO8RB6xIVZRDCX7PtqWA= @@ -139,7 +136,6 @@ github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4= github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4= -github.com/streamingfast/logging v0.0.0-20230608130331-f22c91403091/go.mod h1:VlduQ80JcGJSargkRU4Sg9Xo63wZD/l8A5NC/Uo1/uU= github.com/streamingfast/logging v0.0.0-20250404134358-92b15d2fbd2e h1:qGVGDR2/bXLyR498un1hvhDQPUJ/m14JBRTJz+c67Bc= github.com/streamingfast/logging v0.0.0-20250404134358-92b15d2fbd2e/go.mod h1:VlduQ80JcGJSargkRU4Sg9Xo63wZD/l8A5NC/Uo1/uU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/message.go b/message.go index d44b61905..de792fce0 100644 --- a/message.go +++ b/message.go @@ -21,7 +21,7 @@ import ( "encoding/base64" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/treeout" "github.com/gagliardetto/solana-go/text" diff --git a/message_test.go b/message_test.go index 19cd77e3c..65c843bc5 100644 --- a/message_test.go +++ b/message_test.go @@ -3,7 +3,7 @@ package solana import ( "testing" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/message_v0_test.go b/message_v0_test.go index aae25293f..e10f3e8a9 100644 --- a/message_v0_test.go +++ b/message_v0_test.go @@ -3,7 +3,7 @@ package solana import ( "testing" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/nativetypes.go b/nativetypes.go index 5a1b7b9d0..e9c35bd26 100644 --- a/nativetypes.go +++ b/nativetypes.go @@ -23,7 +23,7 @@ import ( "fmt" "io" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/base58" "github.com/mostynb/zstdpool-freelist" mrtronbase58 "github.com/mr-tron/base58" diff --git a/programs/address-lookup-table/CloseLookupTable.go b/programs/address-lookup-table/CloseLookupTable.go index 392cfc84b..d198d0491 100644 --- a/programs/address-lookup-table/CloseLookupTable.go +++ b/programs/address-lookup-table/CloseLookupTable.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" solana "github.com/gagliardetto/solana-go" format "github.com/gagliardetto/solana-go/text/format" treeout "github.com/gagliardetto/treeout" diff --git a/programs/address-lookup-table/CreateLookupTable.go b/programs/address-lookup-table/CreateLookupTable.go index bed10c5d6..35a7da895 100644 --- a/programs/address-lookup-table/CreateLookupTable.go +++ b/programs/address-lookup-table/CreateLookupTable.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" solana "github.com/gagliardetto/solana-go" format "github.com/gagliardetto/solana-go/text/format" treeout "github.com/gagliardetto/treeout" diff --git a/programs/address-lookup-table/DeactivateLookupTable.go b/programs/address-lookup-table/DeactivateLookupTable.go index c9a417566..eb0e6d797 100644 --- a/programs/address-lookup-table/DeactivateLookupTable.go +++ b/programs/address-lookup-table/DeactivateLookupTable.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" solana "github.com/gagliardetto/solana-go" format "github.com/gagliardetto/solana-go/text/format" treeout "github.com/gagliardetto/treeout" diff --git a/programs/address-lookup-table/ExtendLookupTable.go b/programs/address-lookup-table/ExtendLookupTable.go index 714abaa3e..d64fb86f6 100644 --- a/programs/address-lookup-table/ExtendLookupTable.go +++ b/programs/address-lookup-table/ExtendLookupTable.go @@ -5,8 +5,8 @@ import ( "errors" "fmt" - bin "github.com/gagliardetto/binary" solana "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" format "github.com/gagliardetto/solana-go/text/format" treeout "github.com/gagliardetto/treeout" ) diff --git a/programs/address-lookup-table/FreezeLookupTable.go b/programs/address-lookup-table/FreezeLookupTable.go index d317c6c2b..e4a911bfb 100644 --- a/programs/address-lookup-table/FreezeLookupTable.go +++ b/programs/address-lookup-table/FreezeLookupTable.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" solana "github.com/gagliardetto/solana-go" format "github.com/gagliardetto/solana-go/text/format" treeout "github.com/gagliardetto/treeout" diff --git a/programs/address-lookup-table/address-lookup.go b/programs/address-lookup-table/address-lookup.go index cbd8fd9c1..83f27ecc3 100644 --- a/programs/address-lookup-table/address-lookup.go +++ b/programs/address-lookup-table/address-lookup.go @@ -5,8 +5,8 @@ import ( "fmt" "math" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/rpc" ) diff --git a/programs/address-lookup-table/address-lookup_test.go b/programs/address-lookup-table/address-lookup_test.go index 89552763d..3894503af 100644 --- a/programs/address-lookup-table/address-lookup_test.go +++ b/programs/address-lookup-table/address-lookup_test.go @@ -6,8 +6,8 @@ import ( "math" "testing" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/stretchr/testify/require" ) diff --git a/programs/address-lookup-table/instruction_test.go b/programs/address-lookup-table/instruction_test.go index a549d293c..4865a8297 100644 --- a/programs/address-lookup-table/instruction_test.go +++ b/programs/address-lookup-table/instruction_test.go @@ -5,8 +5,8 @@ import ( "encoding/hex" "testing" - bin "github.com/gagliardetto/binary" solana "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/programs/address-lookup-table/instructions.go b/programs/address-lookup-table/instructions.go index 0c8d1b8aa..f6418ee2a 100644 --- a/programs/address-lookup-table/instructions.go +++ b/programs/address-lookup-table/instructions.go @@ -6,8 +6,8 @@ import ( "fmt" spew "github.com/davecgh/go-spew/spew" - bin "github.com/gagliardetto/binary" solana "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" text "github.com/gagliardetto/solana-go/text" treeout "github.com/gagliardetto/treeout" ) @@ -23,6 +23,7 @@ const ProgramName = "AddressLookupTable" func init() { solana.MustRegisterInstructionDecoder(ProgramID, registryDecodeInstruction) + bin.PrewarmVariantDefinition(InstructionImplDef) } const ( diff --git a/programs/associated-token-account/Create.go b/programs/associated-token-account/Create.go index 343dbdd31..ea4481a74 100644 --- a/programs/associated-token-account/Create.go +++ b/programs/associated-token-account/Create.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" solana "github.com/gagliardetto/solana-go" format "github.com/gagliardetto/solana-go/text/format" treeout "github.com/gagliardetto/treeout" diff --git a/programs/associated-token-account/CreateIdempotent.go b/programs/associated-token-account/CreateIdempotent.go index 25ab6b507..c7b3af20a 100644 --- a/programs/associated-token-account/CreateIdempotent.go +++ b/programs/associated-token-account/CreateIdempotent.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" solana "github.com/gagliardetto/solana-go" format "github.com/gagliardetto/solana-go/text/format" treeout "github.com/gagliardetto/treeout" diff --git a/programs/associated-token-account/RecoverNested.go b/programs/associated-token-account/RecoverNested.go index 04b929a28..5f62cadcc 100644 --- a/programs/associated-token-account/RecoverNested.go +++ b/programs/associated-token-account/RecoverNested.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" solana "github.com/gagliardetto/solana-go" format "github.com/gagliardetto/solana-go/text/format" treeout "github.com/gagliardetto/treeout" diff --git a/programs/associated-token-account/instruction_test.go b/programs/associated-token-account/instruction_test.go index 057a75969..d3f56690a 100644 --- a/programs/associated-token-account/instruction_test.go +++ b/programs/associated-token-account/instruction_test.go @@ -18,8 +18,8 @@ import ( "encoding/hex" "testing" - bin "github.com/gagliardetto/binary" solana "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/programs/associated-token-account/instructions.go b/programs/associated-token-account/instructions.go index 44b11539d..0a5cd6bca 100644 --- a/programs/associated-token-account/instructions.go +++ b/programs/associated-token-account/instructions.go @@ -19,8 +19,8 @@ import ( "fmt" spew "github.com/davecgh/go-spew/spew" - bin "github.com/gagliardetto/binary" solana "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" text "github.com/gagliardetto/solana-go/text" treeout "github.com/gagliardetto/treeout" ) @@ -36,6 +36,7 @@ const ProgramName = "AssociatedTokenAccount" func init() { solana.MustRegisterInstructionDecoder(ProgramID, registryDecodeInstruction) + bin.PrewarmVariantDefinition(InstructionImplDef) } const ( diff --git a/programs/compute-budget/RequestHeapFrame.go b/programs/compute-budget/RequestHeapFrame.go index 920e5bf24..e98620cb9 100644 --- a/programs/compute-budget/RequestHeapFrame.go +++ b/programs/compute-budget/RequestHeapFrame.go @@ -15,9 +15,10 @@ package computebudget import ( + "encoding/binary" "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -97,7 +98,7 @@ func (inst *RequestHeapFrame) EncodeToTree(parent ag_treeout.Branches) { func (obj RequestHeapFrame) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `HeapSize` param: - err = encoder.Encode(obj.HeapSize) + err = encoder.WriteUint32(obj.HeapSize, binary.LittleEndian) if err != nil { return err } diff --git a/programs/compute-budget/RequestUnitsDeprecated.go b/programs/compute-budget/RequestUnitsDeprecated.go index 8a9153b50..5619935d9 100644 --- a/programs/compute-budget/RequestUnitsDeprecated.go +++ b/programs/compute-budget/RequestUnitsDeprecated.go @@ -15,9 +15,10 @@ package computebudget import ( + "encoding/binary" "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -109,13 +110,13 @@ func (inst *RequestUnitsDeprecated) EncodeToTree(parent ag_treeout.Branches) { func (obj RequestUnitsDeprecated) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Units` param: - err = encoder.Encode(obj.Units) + err = encoder.WriteUint32(obj.Units, binary.LittleEndian) if err != nil { return err } // Serialize `AdditionalFee` param: - err = encoder.Encode(obj.AdditionalFee) + err = encoder.WriteUint32(obj.AdditionalFee, binary.LittleEndian) if err != nil { return err } diff --git a/programs/compute-budget/SetComputeUnitLimit.go b/programs/compute-budget/SetComputeUnitLimit.go index f8c9432b9..09c03e528 100644 --- a/programs/compute-budget/SetComputeUnitLimit.go +++ b/programs/compute-budget/SetComputeUnitLimit.go @@ -15,9 +15,10 @@ package computebudget import ( + "encoding/binary" "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -97,7 +98,7 @@ func (inst *SetComputeUnitLimit) EncodeToTree(parent ag_treeout.Branches) { func (obj SetComputeUnitLimit) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Units` param: - err = encoder.Encode(obj.Units) + err = encoder.WriteUint32(obj.Units, binary.LittleEndian) if err != nil { return err } diff --git a/programs/compute-budget/SetComputeUnitPrice.go b/programs/compute-budget/SetComputeUnitPrice.go index 3eb0b0d08..a730440f9 100644 --- a/programs/compute-budget/SetComputeUnitPrice.go +++ b/programs/compute-budget/SetComputeUnitPrice.go @@ -15,9 +15,10 @@ package computebudget import ( + "encoding/binary" "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -91,7 +92,7 @@ func (inst *SetComputeUnitPrice) EncodeToTree(parent ag_treeout.Branches) { func (obj SetComputeUnitPrice) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `MicroLamports` param: - err = encoder.Encode(obj.MicroLamports) + err = encoder.WriteUint64(obj.MicroLamports, binary.LittleEndian) if err != nil { return err } diff --git a/programs/compute-budget/SetLoadedAccountsDataSizeLimit.go b/programs/compute-budget/SetLoadedAccountsDataSizeLimit.go index 8388c3951..d4ef550fb 100644 --- a/programs/compute-budget/SetLoadedAccountsDataSizeLimit.go +++ b/programs/compute-budget/SetLoadedAccountsDataSizeLimit.go @@ -15,9 +15,10 @@ package computebudget import ( + "encoding/binary" "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -91,7 +92,7 @@ func (inst *SetLoadedAccountsDataSizeLimit) EncodeToTree(parent ag_treeout.Branc func (obj SetLoadedAccountsDataSizeLimit) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Bytes` param: - err = encoder.Encode(obj.Bytes) + err = encoder.WriteUint32(obj.Bytes, binary.LittleEndian) if err != nil { return err } diff --git a/programs/compute-budget/instruction.go b/programs/compute-budget/instruction.go index 29e4a2a5f..7c9917e51 100644 --- a/programs/compute-budget/instruction.go +++ b/programs/compute-budget/instruction.go @@ -19,8 +19,8 @@ import ( "fmt" ag_spew "github.com/davecgh/go-spew/spew" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_text "github.com/gagliardetto/solana-go/text" ag_treeout "github.com/gagliardetto/treeout" ) @@ -38,6 +38,7 @@ func init() { if !ProgramID.IsZero() { ag_solanago.MustRegisterInstructionDecoder(ProgramID, registryDecodeInstruction) } + ag_binary.PrewarmVariantDefinition(InstructionImplDef) } const ( diff --git a/programs/compute-budget/instruction_test.go b/programs/compute-budget/instruction_test.go index 695bdcd73..a6cb65ea6 100644 --- a/programs/compute-budget/instruction_test.go +++ b/programs/compute-budget/instruction_test.go @@ -19,7 +19,7 @@ import ( "encoding/hex" "testing" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/programs/memo/Create.go b/programs/memo/Create.go index 2d5a5d302..b7bd52bce 100644 --- a/programs/memo/Create.go +++ b/programs/memo/Create.go @@ -17,8 +17,8 @@ package memo import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" ) diff --git a/programs/memo/instructions.go b/programs/memo/instructions.go index 1c9c19387..9f6b871ef 100644 --- a/programs/memo/instructions.go +++ b/programs/memo/instructions.go @@ -18,8 +18,8 @@ import ( "bytes" "fmt" "github.com/davecgh/go-spew/spew" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_text "github.com/gagliardetto/solana-go/text" "github.com/gagliardetto/treeout" ) @@ -33,6 +33,7 @@ func SetProgramID(pubkey ag_solanago.PublicKey) error { func init() { ag_solanago.MustRegisterInstructionDecoder(ProgramID, registryDecodeInstruction) + ag_binary.PrewarmVariantDefinition(InstructionImplDef) } type MemoInstruction struct { diff --git a/programs/stake/Authorize.go b/programs/stake/Authorize.go index 612654ffd..41a32a292 100644 --- a/programs/stake/Authorize.go +++ b/programs/stake/Authorize.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" @@ -125,7 +125,7 @@ func (inst *Authorize) UnmarshalWithDecoder(dec *bin.Decoder) error { func (inst *Authorize) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*inst.NewAuthorized) + err := encoder.WriteBytes(inst.NewAuthorized[:], false) if err != nil { return err } diff --git a/programs/stake/AuthorizeChecked.go b/programs/stake/AuthorizeChecked.go index bf0e1b844..ab2fd2ce0 100644 --- a/programs/stake/AuthorizeChecked.go +++ b/programs/stake/AuthorizeChecked.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" diff --git a/programs/stake/AuthorizeCheckedWithSeed.go b/programs/stake/AuthorizeCheckedWithSeed.go index e65f7b558..fd4368881 100644 --- a/programs/stake/AuthorizeCheckedWithSeed.go +++ b/programs/stake/AuthorizeCheckedWithSeed.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" @@ -143,7 +143,7 @@ func (inst *AuthorizeCheckedWithSeed) UnmarshalWithDecoder(dec *bin.Decoder) err func (inst *AuthorizeCheckedWithSeed) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*inst.Args) + err := inst.Args.MarshalWithEncoder(encoder) if err != nil { return err } diff --git a/programs/stake/AuthorizeWithSeed.go b/programs/stake/AuthorizeWithSeed.go index 932e57de8..df162560f 100644 --- a/programs/stake/AuthorizeWithSeed.go +++ b/programs/stake/AuthorizeWithSeed.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" @@ -141,7 +141,7 @@ func (inst *AuthorizeWithSeed) UnmarshalWithDecoder(dec *bin.Decoder) error { func (inst *AuthorizeWithSeed) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*inst.Args) + err := inst.Args.MarshalWithEncoder(encoder) if err != nil { return err } diff --git a/programs/stake/Authorized.go b/programs/stake/Authorized.go index d40162f02..3bc2e70cf 100644 --- a/programs/stake/Authorized.go +++ b/programs/stake/Authorized.go @@ -17,7 +17,7 @@ package stake import ( "errors" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ) @@ -46,13 +46,13 @@ func (auth *Authorized) UnmarshalWithDecoder(dec *bin.Decoder) error { func (auth *Authorized) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*auth.Staker) + err := encoder.WriteBytes(auth.Staker[:], false) if err != nil { return err } } { - err := encoder.Encode(*auth.Withdrawer) + err := encoder.WriteBytes(auth.Withdrawer[:], false) if err != nil { return err } diff --git a/programs/stake/Deactivate.go b/programs/stake/Deactivate.go index 6cbf54fcb..2751c823e 100644 --- a/programs/stake/Deactivate.go +++ b/programs/stake/Deactivate.go @@ -17,8 +17,8 @@ package stake import ( "fmt" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" ) diff --git a/programs/stake/DeactivateDelinquent.go b/programs/stake/DeactivateDelinquent.go index 95dda6b00..21d92ee8a 100644 --- a/programs/stake/DeactivateDelinquent.go +++ b/programs/stake/DeactivateDelinquent.go @@ -17,8 +17,8 @@ package stake import ( "fmt" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" ) diff --git a/programs/stake/DelegateStake.go b/programs/stake/DelegateStake.go index e29830581..ef4528132 100644 --- a/programs/stake/DelegateStake.go +++ b/programs/stake/DelegateStake.go @@ -17,8 +17,8 @@ package stake import ( "fmt" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" ) diff --git a/programs/stake/GetMinimumDelegation.go b/programs/stake/GetMinimumDelegation.go index 42b68358e..72b2258e4 100644 --- a/programs/stake/GetMinimumDelegation.go +++ b/programs/stake/GetMinimumDelegation.go @@ -15,8 +15,8 @@ package stake import ( - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" ) diff --git a/programs/stake/Initialize.go b/programs/stake/Initialize.go index 4f67ff948..9f5ec6ac8 100644 --- a/programs/stake/Initialize.go +++ b/programs/stake/Initialize.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" @@ -58,13 +58,13 @@ func (inst *Initialize) UnmarshalWithDecoder(dec *bin.Decoder) error { func (inst *Initialize) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*inst.Authorized) + err := inst.Authorized.MarshalWithEncoder(encoder) if err != nil { return err } } { - err := encoder.Encode(*inst.Lockup) + err := inst.Lockup.MarshalWithEncoder(encoder) if err != nil { return err } diff --git a/programs/stake/InitializeChecked.go b/programs/stake/InitializeChecked.go index 657ba0e14..5b520234a 100644 --- a/programs/stake/InitializeChecked.go +++ b/programs/stake/InitializeChecked.go @@ -17,8 +17,8 @@ package stake import ( "fmt" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" ) diff --git a/programs/stake/Lockup.go b/programs/stake/Lockup.go index 204667ec9..1b35f9405 100644 --- a/programs/stake/Lockup.go +++ b/programs/stake/Lockup.go @@ -15,9 +15,10 @@ package stake import ( + "encoding/binary" "errors" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ) @@ -54,19 +55,19 @@ func (lockup *Lockup) UnmarshalWithDecoder(dec *bin.Decoder) error { func (lockup *Lockup) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*lockup.UnixTimestamp) + err := encoder.WriteInt64(*lockup.UnixTimestamp, binary.LittleEndian) if err != nil { return err } } { - err := encoder.Encode(*lockup.Epoch) + err := encoder.WriteUint64(*lockup.Epoch, binary.LittleEndian) if err != nil { return err } } { - err := encoder.Encode(*lockup.Custodian) + err := encoder.WriteBytes(lockup.Custodian[:], false) if err != nil { return err } diff --git a/programs/stake/Merge.go b/programs/stake/Merge.go index d0b582fef..5e257e068 100644 --- a/programs/stake/Merge.go +++ b/programs/stake/Merge.go @@ -17,8 +17,8 @@ package stake import ( "fmt" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" ) @@ -81,8 +81,8 @@ func (inst *Merge) GetSourceStakeAccount() *solana.AccountMeta { return inst.AccountMetaSlice[1] } func (inst *Merge) GetClockSysvar() *solana.AccountMeta { return inst.AccountMetaSlice[2] } -func (inst *Merge) GetStakeHistorySysvar() *solana.AccountMeta { return inst.AccountMetaSlice[3] } -func (inst *Merge) GetStakeAuthority() *solana.AccountMeta { return inst.AccountMetaSlice[4] } +func (inst *Merge) GetStakeHistorySysvar() *solana.AccountMeta { return inst.AccountMetaSlice[3] } +func (inst *Merge) GetStakeAuthority() *solana.AccountMeta { return inst.AccountMetaSlice[4] } func (inst Merge) Build() *Instruction { return &Instruction{BaseVariant: bin.BaseVariant{ diff --git a/programs/stake/MoveLamports.go b/programs/stake/MoveLamports.go index 21587fd93..16e1b453a 100644 --- a/programs/stake/MoveLamports.go +++ b/programs/stake/MoveLamports.go @@ -15,10 +15,11 @@ package stake import ( + "encoding/binary" "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" @@ -93,7 +94,7 @@ func (inst *MoveLamports) UnmarshalWithDecoder(dec *bin.Decoder) error { func (inst *MoveLamports) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*inst.Lamports) + err := encoder.WriteUint64(*inst.Lamports, binary.LittleEndian) if err != nil { return err } diff --git a/programs/stake/MoveStake.go b/programs/stake/MoveStake.go index ca28f29bb..b106f4c9c 100644 --- a/programs/stake/MoveStake.go +++ b/programs/stake/MoveStake.go @@ -15,10 +15,11 @@ package stake import ( + "encoding/binary" "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" @@ -93,7 +94,7 @@ func (inst *MoveStake) UnmarshalWithDecoder(dec *bin.Decoder) error { func (inst *MoveStake) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*inst.Lamports) + err := encoder.WriteUint64(*inst.Lamports, binary.LittleEndian) if err != nil { return err } diff --git a/programs/stake/Redelegate.go b/programs/stake/Redelegate.go index f88d87f75..f7b11506b 100644 --- a/programs/stake/Redelegate.go +++ b/programs/stake/Redelegate.go @@ -17,8 +17,8 @@ package stake import ( "fmt" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" ) @@ -81,7 +81,7 @@ func (inst *Redelegate) GetStakeAccount() *solana.AccountMeta { func (inst *Redelegate) GetUninitializedStakeAccount() *solana.AccountMeta { return inst.AccountMetaSlice[1] } -func (inst *Redelegate) GetVoteAccount() *solana.AccountMeta { return inst.AccountMetaSlice[2] } +func (inst *Redelegate) GetVoteAccount() *solana.AccountMeta { return inst.AccountMetaSlice[2] } func (inst *Redelegate) GetUnusedAccount() *solana.AccountMeta { return inst.AccountMetaSlice[3] } func (inst *Redelegate) GetStakeAuthority() *solana.AccountMeta { return inst.AccountMetaSlice[4] diff --git a/programs/stake/SetLockup.go b/programs/stake/SetLockup.go index c70a669b5..23423785e 100644 --- a/programs/stake/SetLockup.go +++ b/programs/stake/SetLockup.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" @@ -96,7 +96,7 @@ func (inst *SetLockup) UnmarshalWithDecoder(dec *bin.Decoder) error { func (inst *SetLockup) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*inst.LockupArgs) + err := inst.LockupArgs.MarshalWithEncoder(encoder) if err != nil { return err } diff --git a/programs/stake/SetLockupChecked.go b/programs/stake/SetLockupChecked.go index 0cab4aa42..413c5f3be 100644 --- a/programs/stake/SetLockupChecked.go +++ b/programs/stake/SetLockupChecked.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" @@ -105,7 +105,7 @@ func (inst *SetLockupChecked) UnmarshalWithDecoder(dec *bin.Decoder) error { func (inst *SetLockupChecked) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*inst.LockupCheckedArgs) + err := inst.LockupCheckedArgs.MarshalWithEncoder(encoder) if err != nil { return err } diff --git a/programs/stake/Split.go b/programs/stake/Split.go index 7d656f91a..b0a148b46 100644 --- a/programs/stake/Split.go +++ b/programs/stake/Split.go @@ -15,10 +15,11 @@ package stake import ( + "encoding/binary" "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" @@ -88,7 +89,7 @@ func (inst *Split) UnmarshalWithDecoder(dec *bin.Decoder) error { func (inst *Split) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*inst.Lamports) + err := encoder.WriteUint64(*inst.Lamports, binary.LittleEndian) if err != nil { return err } diff --git a/programs/stake/Withdraw.go b/programs/stake/Withdraw.go index 85e1756db..44d3b1641 100644 --- a/programs/stake/Withdraw.go +++ b/programs/stake/Withdraw.go @@ -15,10 +15,11 @@ package stake import ( + "encoding/binary" "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" @@ -107,7 +108,7 @@ func (inst *Withdraw) UnmarshalWithDecoder(dec *bin.Decoder) error { func (inst *Withdraw) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*inst.Lamports) + err := encoder.WriteUint64(*inst.Lamports, binary.LittleEndian) if err != nil { return err } diff --git a/programs/stake/instructions.go b/programs/stake/instructions.go index dae923ece..1817e4549 100644 --- a/programs/stake/instructions.go +++ b/programs/stake/instructions.go @@ -21,8 +21,8 @@ import ( "fmt" "github.com/davecgh/go-spew/spew" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/text" "github.com/gagliardetto/treeout" ) @@ -38,6 +38,7 @@ const ProgramName = "Stake" func init() { solana.MustRegisterInstructionDecoder(ProgramID, registryDecodeInstruction) + bin.PrewarmVariantDefinition(InstructionImplDef) } const ( diff --git a/programs/stake/types.go b/programs/stake/types.go index 165237105..c0c7b038e 100644 --- a/programs/stake/types.go +++ b/programs/stake/types.go @@ -17,8 +17,8 @@ package stake import ( "encoding/binary" - bin "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" ) type StakeAuthorize uint32 @@ -71,7 +71,7 @@ func (args *LockupArgs) MarshalWithEncoder(encoder *bin.Encoder) error { if err := encoder.WriteOption(true); err != nil { return err } - if err := encoder.Encode(*args.Custodian); err != nil { + if err := encoder.WriteBytes(args.Custodian[:], false); err != nil { return err } } else { @@ -209,7 +209,7 @@ type AuthorizeWithSeedArgs struct { func (args *AuthorizeWithSeedArgs) MarshalWithEncoder(encoder *bin.Encoder) error { { - err := encoder.Encode(*args.NewAuthorizedPubkey) + err := encoder.WriteBytes(args.NewAuthorizedPubkey[:], false) if err != nil { return err } @@ -227,7 +227,7 @@ func (args *AuthorizeWithSeedArgs) MarshalWithEncoder(encoder *bin.Encoder) erro } } { - err := encoder.Encode(*args.AuthorityOwner) + err := encoder.WriteBytes(args.AuthorityOwner[:], false) if err != nil { return err } @@ -286,7 +286,7 @@ func (args *AuthorizeCheckedWithSeedArgs) MarshalWithEncoder(encoder *bin.Encode } } { - err := encoder.Encode(*args.AuthorityOwner) + err := encoder.WriteBytes(args.AuthorityOwner[:], false) if err != nil { return err } diff --git a/programs/system/AdvanceNonceAccount.go b/programs/system/AdvanceNonceAccount.go index 8a14c79ec..7dd59a30c 100644 --- a/programs/system/AdvanceNonceAccount.go +++ b/programs/system/AdvanceNonceAccount.go @@ -18,7 +18,7 @@ import ( "encoding/binary" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/system/Allocate.go b/programs/system/Allocate.go index 639fb7863..5dda73fa0 100644 --- a/programs/system/Allocate.go +++ b/programs/system/Allocate.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -117,7 +117,7 @@ func (inst *Allocate) EncodeToTree(parent ag_treeout.Branches) { func (inst Allocate) MarshalWithEncoder(encoder *ag_binary.Encoder) error { // Serialize `Space` param: { - err := encoder.Encode(*inst.Space) + err := encoder.WriteUint64(*inst.Space, binary.LittleEndian) if err != nil { return err } diff --git a/programs/system/AllocateWithSeed.go b/programs/system/AllocateWithSeed.go index ade535b20..881d3b8fa 100644 --- a/programs/system/AllocateWithSeed.go +++ b/programs/system/AllocateWithSeed.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -170,7 +170,7 @@ func (inst *AllocateWithSeed) EncodeToTree(parent ag_treeout.Branches) { func (inst AllocateWithSeed) MarshalWithEncoder(encoder *ag_binary.Encoder) error { // Serialize `Base` param: { - err := encoder.Encode(*inst.Base) + err := encoder.WriteBytes(inst.Base[:], false) if err != nil { return err } @@ -184,14 +184,14 @@ func (inst AllocateWithSeed) MarshalWithEncoder(encoder *ag_binary.Encoder) erro } // Serialize `Space` param: { - err := encoder.Encode(*inst.Space) + err := encoder.WriteUint64(*inst.Space, binary.LittleEndian) if err != nil { return err } } // Serialize `Owner` param: { - err := encoder.Encode(*inst.Owner) + err := encoder.WriteBytes(inst.Owner[:], false) if err != nil { return err } diff --git a/programs/system/Assign.go b/programs/system/Assign.go index 36e8ac6e9..2450a6bfd 100644 --- a/programs/system/Assign.go +++ b/programs/system/Assign.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -117,7 +117,7 @@ func (inst *Assign) EncodeToTree(parent ag_treeout.Branches) { func (inst Assign) MarshalWithEncoder(encoder *ag_binary.Encoder) error { // Serialize `Owner` param: { - err := encoder.Encode(*inst.Owner) + err := encoder.WriteBytes(inst.Owner[:], false) if err != nil { return err } diff --git a/programs/system/AssignWithSeed.go b/programs/system/AssignWithSeed.go index 226b6a604..35908adce 100644 --- a/programs/system/AssignWithSeed.go +++ b/programs/system/AssignWithSeed.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -157,7 +157,7 @@ func (inst *AssignWithSeed) EncodeToTree(parent ag_treeout.Branches) { func (inst AssignWithSeed) MarshalWithEncoder(encoder *ag_binary.Encoder) error { // Serialize `Base` param: { - err := encoder.Encode(*inst.Base) + err := encoder.WriteBytes(inst.Base[:], false) if err != nil { return err } @@ -171,7 +171,7 @@ func (inst AssignWithSeed) MarshalWithEncoder(encoder *ag_binary.Encoder) error } // Serialize `Owner` param: { - err := encoder.Encode(*inst.Owner) + err := encoder.WriteBytes(inst.Owner[:], false) if err != nil { return err } diff --git a/programs/system/AuthorizeNonceAccount.go b/programs/system/AuthorizeNonceAccount.go index cd4b43365..ae3ac3eac 100644 --- a/programs/system/AuthorizeNonceAccount.go +++ b/programs/system/AuthorizeNonceAccount.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -131,7 +131,7 @@ func (inst *AuthorizeNonceAccount) EncodeToTree(parent ag_treeout.Branches) { func (inst AuthorizeNonceAccount) MarshalWithEncoder(encoder *ag_binary.Encoder) error { // Serialize `Authorized` param: { - err := encoder.Encode(*inst.Authorized) + err := encoder.WriteBytes(inst.Authorized[:], false) if err != nil { return err } diff --git a/programs/system/CreateAccount.go b/programs/system/CreateAccount.go index 859c0e3ee..5a87391e8 100644 --- a/programs/system/CreateAccount.go +++ b/programs/system/CreateAccount.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -157,21 +157,21 @@ func (inst *CreateAccount) EncodeToTree(parent ag_treeout.Branches) { func (inst CreateAccount) MarshalWithEncoder(encoder *ag_binary.Encoder) error { // Serialize `Lamports` param: { - err := encoder.Encode(*inst.Lamports) + err := encoder.WriteUint64(*inst.Lamports, binary.LittleEndian) if err != nil { return err } } // Serialize `Space` param: { - err := encoder.Encode(*inst.Space) + err := encoder.WriteUint64(*inst.Space, binary.LittleEndian) if err != nil { return err } } // Serialize `Owner` param: { - err := encoder.Encode(*inst.Owner) + err := encoder.WriteBytes(inst.Owner[:], false) if err != nil { return err } diff --git a/programs/system/CreateAccountWithSeed.go b/programs/system/CreateAccountWithSeed.go index bf5657463..3195022a1 100644 --- a/programs/system/CreateAccountWithSeed.go +++ b/programs/system/CreateAccountWithSeed.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -205,7 +205,7 @@ func (inst *CreateAccountWithSeed) EncodeToTree(parent ag_treeout.Branches) { func (inst CreateAccountWithSeed) MarshalWithEncoder(encoder *ag_binary.Encoder) error { // Serialize `Base` param: { - err := encoder.Encode(*inst.Base) + err := encoder.WriteBytes(inst.Base[:], false) if err != nil { return err } @@ -219,21 +219,21 @@ func (inst CreateAccountWithSeed) MarshalWithEncoder(encoder *ag_binary.Encoder) } // Serialize `Lamports` param: { - err := encoder.Encode(*inst.Lamports) + err := encoder.WriteUint64(*inst.Lamports, binary.LittleEndian) if err != nil { return err } } // Serialize `Space` param: { - err := encoder.Encode(*inst.Space) + err := encoder.WriteUint64(*inst.Space, binary.LittleEndian) if err != nil { return err } } // Serialize `Owner` param: { - err := encoder.Encode(*inst.Owner) + err := encoder.WriteBytes(inst.Owner[:], false) if err != nil { return err } diff --git a/programs/system/CreateAccountWithSeed_test.go b/programs/system/CreateAccountWithSeed_test.go index b501dfb2f..bc2fc9fbb 100644 --- a/programs/system/CreateAccountWithSeed_test.go +++ b/programs/system/CreateAccountWithSeed_test.go @@ -19,9 +19,9 @@ import ( "strconv" "testing" - bin "github.com/gagliardetto/binary" ag_gofuzz "github.com/gagliardetto/gofuzz" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" ag_require "github.com/stretchr/testify/require" ) diff --git a/programs/system/InitializeNonceAccount.go b/programs/system/InitializeNonceAccount.go index 802a66abc..d55b21d47 100644 --- a/programs/system/InitializeNonceAccount.go +++ b/programs/system/InitializeNonceAccount.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -149,7 +149,7 @@ func (inst *InitializeNonceAccount) EncodeToTree(parent ag_treeout.Branches) { func (inst InitializeNonceAccount) MarshalWithEncoder(encoder *ag_binary.Encoder) error { // Serialize `Authorized` param: { - err := encoder.Encode(*inst.Authorized) + err := encoder.WriteBytes(inst.Authorized[:], false) if err != nil { return err } diff --git a/programs/system/Transfer.go b/programs/system/Transfer.go index 5ed3cd924..475b9c6d9 100644 --- a/programs/system/Transfer.go +++ b/programs/system/Transfer.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -131,7 +131,7 @@ func (inst *Transfer) EncodeToTree(parent ag_treeout.Branches) { func (inst Transfer) MarshalWithEncoder(encoder *ag_binary.Encoder) error { // Serialize `Lamports` param: { - err := encoder.Encode(*inst.Lamports) + err := encoder.WriteUint64(*inst.Lamports, binary.LittleEndian) if err != nil { return err } diff --git a/programs/system/TransferWithSeed.go b/programs/system/TransferWithSeed.go index dc6cf6fc2..5d4b7edae 100644 --- a/programs/system/TransferWithSeed.go +++ b/programs/system/TransferWithSeed.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -171,7 +171,7 @@ func (inst *TransferWithSeed) EncodeToTree(parent ag_treeout.Branches) { func (inst TransferWithSeed) MarshalWithEncoder(encoder *ag_binary.Encoder) error { // Serialize `Lamports` param: { - err := encoder.Encode(*inst.Lamports) + err := encoder.WriteUint64(*inst.Lamports, binary.LittleEndian) if err != nil { return err } @@ -185,7 +185,7 @@ func (inst TransferWithSeed) MarshalWithEncoder(encoder *ag_binary.Encoder) erro } // Serialize `FromOwner` param: { - err := encoder.Encode(*inst.FromOwner) + err := encoder.WriteBytes(inst.FromOwner[:], false) if err != nil { return err } diff --git a/programs/system/UpgradeNonceAccount.go b/programs/system/UpgradeNonceAccount.go index 5213aa08c..0fa5d3880 100644 --- a/programs/system/UpgradeNonceAccount.go +++ b/programs/system/UpgradeNonceAccount.go @@ -18,7 +18,7 @@ import ( "encoding/binary" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/system/WithdrawNonceAccount.go b/programs/system/WithdrawNonceAccount.go index c9aefc8c7..8af13e913 100644 --- a/programs/system/WithdrawNonceAccount.go +++ b/programs/system/WithdrawNonceAccount.go @@ -19,7 +19,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -175,7 +175,7 @@ func (inst *WithdrawNonceAccount) EncodeToTree(parent ag_treeout.Branches) { func (inst WithdrawNonceAccount) MarshalWithEncoder(encoder *ag_binary.Encoder) error { // Serialize `Lamports` param: { - err := encoder.Encode(*inst.Lamports) + err := encoder.WriteUint64(*inst.Lamports, binary.LittleEndian) if err != nil { return err } diff --git a/programs/system/accounts.go b/programs/system/accounts.go index 863fe36ae..c05be1df6 100644 --- a/programs/system/accounts.go +++ b/programs/system/accounts.go @@ -17,8 +17,8 @@ package system import ( "encoding/binary" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" ) type NonceAccount struct { diff --git a/programs/system/accounts_test.go b/programs/system/accounts_test.go index 062f75770..204b7275b 100644 --- a/programs/system/accounts_test.go +++ b/programs/system/accounts_test.go @@ -18,8 +18,8 @@ import ( "encoding/base64" "testing" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/stretchr/testify/assert" ) diff --git a/programs/system/createaccountwithseed_bench_test.go b/programs/system/createaccountwithseed_bench_test.go new file mode 100644 index 000000000..962b57166 --- /dev/null +++ b/programs/system/createaccountwithseed_bench_test.go @@ -0,0 +1,87 @@ +package system + +import ( + "testing" + + bin "github.com/gagliardetto/solana-go/binary" + + "github.com/gagliardetto/solana-go" +) + +// makeBenchCreateAccountWithSeed builds a fully-populated +// CreateAccountWithSeed instruction. With 5 marshalable parameter fields +// (*PublicKey × 2, *uint64 × 2, *string), it is the largest instruction +// struct in the repo and a representative target for benchmarking the +// instruction-build hot path: a length-prefixed Rust string + two 32-byte +// foreign-package pointer types + two primitive pointers all in a single +// Marshal call. +func makeBenchCreateAccountWithSeed() *CreateAccountWithSeed { + base := solana.PublicKey{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} + owner := solana.PublicKey{32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1} + funding := solana.PublicKey{0xaa, 0xbb, 0xcc} + created := solana.PublicKey{0xdd, 0xee, 0xff} + return NewCreateAccountWithSeedInstruction( + base, + "benchmark-seed-string", + 1_000_000_000, // lamports + 8192, // space + owner, + funding, + created, + base, // baseAccount (same key as base for the test) + ) +} + +// BenchmarkEncode_CreateAccountWithSeed exercises the convenience +// MarshalBin path on the largest instruction struct in the repo. Goes +// through the pooled *Encoder + sync.Pool path inside MarshalBin. +func BenchmarkEncode_CreateAccountWithSeed(b *testing.B) { + inst := makeBenchCreateAccountWithSeed() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf, err := bin.MarshalBin(inst) + if err != nil { + b.Fatal(err) + } + _ = buf + } +} + +// BenchmarkDecode_CreateAccountWithSeed exercises the direct +// NewBinDecoder path; does not benefit from the Decoder pool. Compare +// to BenchmarkDecode_CreateAccountWithSeed_UnmarshalBin which does. +func BenchmarkDecode_CreateAccountWithSeed(b *testing.B) { + inst := makeBenchCreateAccountWithSeed() + data, err := bin.MarshalBin(inst) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var out CreateAccountWithSeed + dec := bin.NewBinDecoder(data) + if err := dec.Decode(&out); err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkDecode_CreateAccountWithSeed_UnmarshalBin exercises the +// pooled UnmarshalBin convenience helper. +func BenchmarkDecode_CreateAccountWithSeed_UnmarshalBin(b *testing.B) { + inst := makeBenchCreateAccountWithSeed() + data, err := bin.MarshalBin(inst) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var out CreateAccountWithSeed + if err := bin.UnmarshalBin(&out, data); err != nil { + b.Fatal(err) + } + } +} diff --git a/programs/system/instructions.go b/programs/system/instructions.go index 8e40caa07..6a70ca7a9 100644 --- a/programs/system/instructions.go +++ b/programs/system/instructions.go @@ -23,8 +23,8 @@ import ( "fmt" ag_spew "github.com/davecgh/go-spew/spew" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_text "github.com/gagliardetto/solana-go/text" ag_treeout "github.com/gagliardetto/treeout" ) @@ -40,6 +40,7 @@ const ProgramName = "System" func init() { ag_solanago.MustRegisterInstructionDecoder(ProgramID, registryDecodeInstruction) + ag_binary.PrewarmVariantDefinition(InstructionImplDef) } const ( diff --git a/programs/system/testing_utils.go b/programs/system/testing_utils.go index db393fbfe..228b8d489 100644 --- a/programs/system/testing_utils.go +++ b/programs/system/testing_utils.go @@ -18,7 +18,7 @@ import ( "bytes" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ) func encodeT(data any, buf *bytes.Buffer) error { diff --git a/programs/token-2022/AmountToUiAmount.go b/programs/token-2022/AmountToUiAmount.go index 04fefd85d..6a2f66ae8 100644 --- a/programs/token-2022/AmountToUiAmount.go +++ b/programs/token-2022/AmountToUiAmount.go @@ -1,9 +1,10 @@ package token2022 import ( + "encoding/binary" "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -104,7 +105,7 @@ func (inst *AmountToUiAmount) EncodeToTree(parent ag_treeout.Branches) { func (obj AmountToUiAmount) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } diff --git a/programs/token-2022/Approve.go b/programs/token-2022/Approve.go index dcea5cf40..79c463719 100644 --- a/programs/token-2022/Approve.go +++ b/programs/token-2022/Approve.go @@ -15,10 +15,11 @@ package token2022 import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -197,7 +198,7 @@ func (inst *Approve) EncodeToTree(parent ag_treeout.Branches) { func (obj Approve) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } diff --git a/programs/token-2022/ApproveChecked.go b/programs/token-2022/ApproveChecked.go index 98ad70754..e4db07a95 100644 --- a/programs/token-2022/ApproveChecked.go +++ b/programs/token-2022/ApproveChecked.go @@ -15,10 +15,11 @@ package token2022 import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -235,12 +236,12 @@ func (inst *ApproveChecked) EncodeToTree(parent ag_treeout.Branches) { func (obj ApproveChecked) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } // Serialize `Decimals` param: - err = encoder.Encode(obj.Decimals) + err = encoder.WriteByte(*obj.Decimals) if err != nil { return err } diff --git a/programs/token-2022/Burn.go b/programs/token-2022/Burn.go index 538c018ec..de4046a40 100644 --- a/programs/token-2022/Burn.go +++ b/programs/token-2022/Burn.go @@ -15,10 +15,11 @@ package token2022 import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -197,7 +198,7 @@ func (inst *Burn) EncodeToTree(parent ag_treeout.Branches) { func (obj Burn) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } diff --git a/programs/token-2022/BurnChecked.go b/programs/token-2022/BurnChecked.go index 68abc9da2..fcf76ec98 100644 --- a/programs/token-2022/BurnChecked.go +++ b/programs/token-2022/BurnChecked.go @@ -15,10 +15,11 @@ package token2022 import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -216,12 +217,12 @@ func (inst *BurnChecked) EncodeToTree(parent ag_treeout.Branches) { func (obj BurnChecked) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } // Serialize `Decimals` param: - err = encoder.Encode(obj.Decimals) + err = encoder.WriteByte(*obj.Decimals) if err != nil { return err } diff --git a/programs/token-2022/CloseAccount.go b/programs/token-2022/CloseAccount.go index b93d401fe..4610cf70c 100644 --- a/programs/token-2022/CloseAccount.go +++ b/programs/token-2022/CloseAccount.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token-2022/FreezeAccount.go b/programs/token-2022/FreezeAccount.go index 023c0abbb..325fd966a 100644 --- a/programs/token-2022/FreezeAccount.go +++ b/programs/token-2022/FreezeAccount.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token-2022/GetAccountDataSize.go b/programs/token-2022/GetAccountDataSize.go index b10efbc3b..e3932ca58 100644 --- a/programs/token-2022/GetAccountDataSize.go +++ b/programs/token-2022/GetAccountDataSize.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token-2022/InitializeAccount.go b/programs/token-2022/InitializeAccount.go index b19cf6318..20218f406 100644 --- a/programs/token-2022/InitializeAccount.go +++ b/programs/token-2022/InitializeAccount.go @@ -17,7 +17,7 @@ package token2022 import ( "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token-2022/InitializeAccount2.go b/programs/token-2022/InitializeAccount2.go index 4c4ec3e33..565e75b5c 100644 --- a/programs/token-2022/InitializeAccount2.go +++ b/programs/token-2022/InitializeAccount2.go @@ -17,7 +17,7 @@ package token2022 import ( "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -162,7 +162,7 @@ func (inst *InitializeAccount2) EncodeToTree(parent ag_treeout.Branches) { func (obj InitializeAccount2) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Owner` param: - err = encoder.Encode(obj.Owner) + err = encoder.WriteBytes(obj.Owner[:], false) if err != nil { return err } diff --git a/programs/token-2022/InitializeAccount3.go b/programs/token-2022/InitializeAccount3.go index 7d2adeb58..39e8c6477 100644 --- a/programs/token-2022/InitializeAccount3.go +++ b/programs/token-2022/InitializeAccount3.go @@ -17,7 +17,7 @@ package token2022 import ( "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -138,7 +138,7 @@ func (inst *InitializeAccount3) EncodeToTree(parent ag_treeout.Branches) { func (obj InitializeAccount3) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Owner` param: - err = encoder.Encode(obj.Owner) + err = encoder.WriteBytes(obj.Owner[:], false) if err != nil { return err } diff --git a/programs/token-2022/InitializeImmutableOwner.go b/programs/token-2022/InitializeImmutableOwner.go index 81765c769..b840fa049 100644 --- a/programs/token-2022/InitializeImmutableOwner.go +++ b/programs/token-2022/InitializeImmutableOwner.go @@ -3,7 +3,7 @@ package token2022 import ( "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token-2022/InitializeMint.go b/programs/token-2022/InitializeMint.go index 40ee8ee45..b818a1577 100644 --- a/programs/token-2022/InitializeMint.go +++ b/programs/token-2022/InitializeMint.go @@ -17,8 +17,8 @@ package token2022 import ( "errors" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" ) @@ -171,12 +171,12 @@ func (inst *InitializeMint) EncodeToTree(parent ag_treeout.Branches) { func (obj InitializeMint) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Decimals` param: - err = encoder.Encode(obj.Decimals) + err = encoder.WriteByte(*obj.Decimals) if err != nil { return err } // Serialize `MintAuthority` param: - err = encoder.Encode(obj.MintAuthority) + err = encoder.WriteBytes(obj.MintAuthority[:], false) if err != nil { return err } @@ -192,7 +192,7 @@ func (obj InitializeMint) MarshalWithEncoder(encoder *ag_binary.Encoder) (err er if err != nil { return err } - err = encoder.Encode(obj.FreezeAuthority) + err = encoder.WriteBytes(obj.FreezeAuthority[:], false) if err != nil { return err } diff --git a/programs/token-2022/InitializeMint2.go b/programs/token-2022/InitializeMint2.go index 9e2300c2d..fdf707e4c 100644 --- a/programs/token-2022/InitializeMint2.go +++ b/programs/token-2022/InitializeMint2.go @@ -17,8 +17,8 @@ package token2022 import ( "errors" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" ) @@ -143,12 +143,12 @@ func (inst *InitializeMint2) EncodeToTree(parent ag_treeout.Branches) { func (obj InitializeMint2) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Decimals` param: - err = encoder.Encode(obj.Decimals) + err = encoder.WriteByte(*obj.Decimals) if err != nil { return err } // Serialize `MintAuthority` param: - err = encoder.Encode(obj.MintAuthority) + err = encoder.WriteBytes(obj.MintAuthority[:], false) if err != nil { return err } @@ -164,7 +164,7 @@ func (obj InitializeMint2) MarshalWithEncoder(encoder *ag_binary.Encoder) (err e if err != nil { return err } - err = encoder.Encode(obj.FreezeAuthority) + err = encoder.WriteBytes(obj.FreezeAuthority[:], false) if err != nil { return err } diff --git a/programs/token-2022/InitializeMintCloseAuthority.go b/programs/token-2022/InitializeMintCloseAuthority.go index 068a4d614..6c53b8716 100644 --- a/programs/token-2022/InitializeMintCloseAuthority.go +++ b/programs/token-2022/InitializeMintCloseAuthority.go @@ -3,8 +3,8 @@ package token2022 import ( "errors" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" ) @@ -108,7 +108,7 @@ func (obj InitializeMintCloseAuthority) MarshalWithEncoder(encoder *ag_binary.En if err != nil { return err } - err = encoder.Encode(obj.CloseAuthority) + err = encoder.WriteBytes(obj.CloseAuthority[:], false) if err != nil { return err } diff --git a/programs/token-2022/InitializeMultisig.go b/programs/token-2022/InitializeMultisig.go index a10dc7420..f96b7b4f1 100644 --- a/programs/token-2022/InitializeMultisig.go +++ b/programs/token-2022/InitializeMultisig.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -193,7 +193,7 @@ func (inst *InitializeMultisig) EncodeToTree(parent ag_treeout.Branches) { func (obj InitializeMultisig) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `M` param: - err = encoder.Encode(obj.M) + err = encoder.WriteByte(*obj.M) if err != nil { return err } diff --git a/programs/token-2022/InitializeMultisig2.go b/programs/token-2022/InitializeMultisig2.go index 06c857c03..c4153d23f 100644 --- a/programs/token-2022/InitializeMultisig2.go +++ b/programs/token-2022/InitializeMultisig2.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -159,7 +159,7 @@ func (inst *InitializeMultisig2) EncodeToTree(parent ag_treeout.Branches) { func (obj InitializeMultisig2) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `M` param: - err = encoder.Encode(obj.M) + err = encoder.WriteByte(*obj.M) if err != nil { return err } diff --git a/programs/token-2022/MintTo.go b/programs/token-2022/MintTo.go index d6f8cbcff..73d253998 100644 --- a/programs/token-2022/MintTo.go +++ b/programs/token-2022/MintTo.go @@ -15,10 +15,11 @@ package token2022 import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -197,7 +198,7 @@ func (inst *MintTo) EncodeToTree(parent ag_treeout.Branches) { func (obj MintTo) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } diff --git a/programs/token-2022/MintToChecked.go b/programs/token-2022/MintToChecked.go index 2c1e6dc73..15f34d426 100644 --- a/programs/token-2022/MintToChecked.go +++ b/programs/token-2022/MintToChecked.go @@ -15,10 +15,11 @@ package token2022 import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -214,12 +215,12 @@ func (inst *MintToChecked) EncodeToTree(parent ag_treeout.Branches) { func (obj MintToChecked) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } // Serialize `Decimals` param: - err = encoder.Encode(obj.Decimals) + err = encoder.WriteByte(*obj.Decimals) if err != nil { return err } diff --git a/programs/token-2022/Revoke.go b/programs/token-2022/Revoke.go index 6ba948b69..6db555b4c 100644 --- a/programs/token-2022/Revoke.go +++ b/programs/token-2022/Revoke.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token-2022/SetAuthority.go b/programs/token-2022/SetAuthority.go index 14871ddd4..da334bccb 100644 --- a/programs/token-2022/SetAuthority.go +++ b/programs/token-2022/SetAuthority.go @@ -18,8 +18,8 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" ) @@ -187,7 +187,7 @@ func (inst *SetAuthority) EncodeToTree(parent ag_treeout.Branches) { func (obj SetAuthority) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `AuthorityType` param: - err = encoder.Encode(obj.AuthorityType) + err = encoder.WriteByte(uint8(*obj.AuthorityType)) if err != nil { return err } @@ -203,7 +203,7 @@ func (obj SetAuthority) MarshalWithEncoder(encoder *ag_binary.Encoder) (err erro if err != nil { return err } - err = encoder.Encode(obj.NewAuthority) + err = encoder.WriteBytes(obj.NewAuthority[:], false) if err != nil { return err } diff --git a/programs/token-2022/SyncNative.go b/programs/token-2022/SyncNative.go index c40bd6207..547399472 100644 --- a/programs/token-2022/SyncNative.go +++ b/programs/token-2022/SyncNative.go @@ -17,7 +17,7 @@ package token2022 import ( "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token-2022/ThawAccount.go b/programs/token-2022/ThawAccount.go index 95d593c33..46dfea94a 100644 --- a/programs/token-2022/ThawAccount.go +++ b/programs/token-2022/ThawAccount.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token-2022/Transfer.go b/programs/token-2022/Transfer.go index 3519c19f3..cbec44e23 100644 --- a/programs/token-2022/Transfer.go +++ b/programs/token-2022/Transfer.go @@ -15,10 +15,11 @@ package token2022 import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -199,7 +200,7 @@ func (inst *Transfer) EncodeToTree(parent ag_treeout.Branches) { func (obj Transfer) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } diff --git a/programs/token-2022/TransferChecked.go b/programs/token-2022/TransferChecked.go index fbe3251d4..97d6f8ee2 100644 --- a/programs/token-2022/TransferChecked.go +++ b/programs/token-2022/TransferChecked.go @@ -15,10 +15,11 @@ package token2022 import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -237,12 +238,12 @@ func (inst *TransferChecked) EncodeToTree(parent ag_treeout.Branches) { func (obj TransferChecked) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } // Serialize `Decimals` param: - err = encoder.Encode(obj.Decimals) + err = encoder.WriteByte(*obj.Decimals) if err != nil { return err } diff --git a/programs/token-2022/UiAmountToAmount.go b/programs/token-2022/UiAmountToAmount.go index e2d48ae55..4afa3da8e 100644 --- a/programs/token-2022/UiAmountToAmount.go +++ b/programs/token-2022/UiAmountToAmount.go @@ -3,8 +3,8 @@ package token2022 import ( "errors" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" ) diff --git a/programs/token-2022/accounts.go b/programs/token-2022/accounts.go index 64cf4abe3..f11cf41dc 100644 --- a/programs/token-2022/accounts.go +++ b/programs/token-2022/accounts.go @@ -3,8 +3,8 @@ package token2022 import ( "encoding/binary" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" ) type Mint struct { diff --git a/programs/token-2022/instructions.go b/programs/token-2022/instructions.go index 620c1f434..877d73036 100644 --- a/programs/token-2022/instructions.go +++ b/programs/token-2022/instructions.go @@ -23,8 +23,8 @@ import ( "fmt" ag_spew "github.com/davecgh/go-spew/spew" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_text "github.com/gagliardetto/solana-go/text" ag_treeout "github.com/gagliardetto/treeout" ) @@ -45,6 +45,7 @@ func init() { if !ProgramID.IsZero() { ag_solanago.MustRegisterInstructionDecoder(ProgramID, registryDecodeInstruction) } + ag_binary.PrewarmVariantDefinition(InstructionImplDef) } const ( diff --git a/programs/token-2022/rpc.go b/programs/token-2022/rpc.go index dc95df9cd..1be4b2a4c 100644 --- a/programs/token-2022/rpc.go +++ b/programs/token-2022/rpc.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/rpc" ) diff --git a/programs/token-2022/testing_utils.go b/programs/token-2022/testing_utils.go index 604e8707f..3e0dffcbc 100644 --- a/programs/token-2022/testing_utils.go +++ b/programs/token-2022/testing_utils.go @@ -4,7 +4,7 @@ import ( "bytes" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ) func encodeT(data any, buf *bytes.Buffer) error { diff --git a/programs/token-2022/types.go b/programs/token-2022/types.go index a38fbde87..9dd68531d 100644 --- a/programs/token-2022/types.go +++ b/programs/token-2022/types.go @@ -1,7 +1,7 @@ package token2022 import ( - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ) type AuthorityType ag_binary.BorshEnum @@ -81,33 +81,33 @@ const ( type ExtensionType uint16 const ( - ExtensionUninitialized ExtensionType = 0 - ExtensionTransferFeeConfig ExtensionType = 1 - ExtensionTransferFeeAmount ExtensionType = 2 - ExtensionMintCloseAuthority ExtensionType = 3 - ExtensionConfidentialTransferMint ExtensionType = 4 - ExtensionConfidentialTransferAccount ExtensionType = 5 - ExtensionDefaultAccountState ExtensionType = 6 - ExtensionImmutableOwner ExtensionType = 7 - ExtensionMemoTransfer ExtensionType = 8 - ExtensionNonTransferable ExtensionType = 9 - ExtensionInterestBearingConfig ExtensionType = 10 - ExtensionCpiGuard ExtensionType = 11 - ExtensionPermanentDelegate ExtensionType = 12 - ExtensionNonTransferableAccount ExtensionType = 13 - ExtensionTransferHook ExtensionType = 14 - ExtensionTransferHookAccount ExtensionType = 15 - ExtensionConfidentialTransferFeeConfig ExtensionType = 16 - ExtensionConfidentialTransferFeeAmount ExtensionType = 17 - ExtensionMetadataPointer ExtensionType = 18 - ExtensionTokenMetadata ExtensionType = 19 - ExtensionGroupPointer ExtensionType = 20 - ExtensionTokenGroup ExtensionType = 21 - ExtensionGroupMemberPointer ExtensionType = 22 - ExtensionTokenGroupMember ExtensionType = 23 - ExtensionConfidentialMintBurn ExtensionType = 24 - ExtensionScaledUiAmount ExtensionType = 25 - ExtensionPausable ExtensionType = 26 - ExtensionPausableAccount ExtensionType = 27 - ExtensionPermissionedBurn ExtensionType = 28 + ExtensionUninitialized ExtensionType = 0 + ExtensionTransferFeeConfig ExtensionType = 1 + ExtensionTransferFeeAmount ExtensionType = 2 + ExtensionMintCloseAuthority ExtensionType = 3 + ExtensionConfidentialTransferMint ExtensionType = 4 + ExtensionConfidentialTransferAccount ExtensionType = 5 + ExtensionDefaultAccountState ExtensionType = 6 + ExtensionImmutableOwner ExtensionType = 7 + ExtensionMemoTransfer ExtensionType = 8 + ExtensionNonTransferable ExtensionType = 9 + ExtensionInterestBearingConfig ExtensionType = 10 + ExtensionCpiGuard ExtensionType = 11 + ExtensionPermanentDelegate ExtensionType = 12 + ExtensionNonTransferableAccount ExtensionType = 13 + ExtensionTransferHook ExtensionType = 14 + ExtensionTransferHookAccount ExtensionType = 15 + ExtensionConfidentialTransferFeeConfig ExtensionType = 16 + ExtensionConfidentialTransferFeeAmount ExtensionType = 17 + ExtensionMetadataPointer ExtensionType = 18 + ExtensionTokenMetadata ExtensionType = 19 + ExtensionGroupPointer ExtensionType = 20 + ExtensionTokenGroup ExtensionType = 21 + ExtensionGroupMemberPointer ExtensionType = 22 + ExtensionTokenGroupMember ExtensionType = 23 + ExtensionConfidentialMintBurn ExtensionType = 24 + ExtensionScaledUiAmount ExtensionType = 25 + ExtensionPausable ExtensionType = 26 + ExtensionPausableAccount ExtensionType = 27 + ExtensionPermissionedBurn ExtensionType = 28 ) diff --git a/programs/token/Approve.go b/programs/token/Approve.go index e7fb66806..8eb13f942 100644 --- a/programs/token/Approve.go +++ b/programs/token/Approve.go @@ -15,10 +15,11 @@ package token import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -197,7 +198,7 @@ func (inst *Approve) EncodeToTree(parent ag_treeout.Branches) { func (obj Approve) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } diff --git a/programs/token/ApproveChecked.go b/programs/token/ApproveChecked.go index fcfb9f50c..6de405605 100644 --- a/programs/token/ApproveChecked.go +++ b/programs/token/ApproveChecked.go @@ -15,10 +15,11 @@ package token import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -235,12 +236,12 @@ func (inst *ApproveChecked) EncodeToTree(parent ag_treeout.Branches) { func (obj ApproveChecked) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } // Serialize `Decimals` param: - err = encoder.Encode(obj.Decimals) + err = encoder.WriteByte(*obj.Decimals) if err != nil { return err } diff --git a/programs/token/Burn.go b/programs/token/Burn.go index ec745ac1c..f66448883 100644 --- a/programs/token/Burn.go +++ b/programs/token/Burn.go @@ -15,10 +15,11 @@ package token import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -197,7 +198,7 @@ func (inst *Burn) EncodeToTree(parent ag_treeout.Branches) { func (obj Burn) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } diff --git a/programs/token/BurnChecked.go b/programs/token/BurnChecked.go index c434e6dbb..5d7ca0358 100644 --- a/programs/token/BurnChecked.go +++ b/programs/token/BurnChecked.go @@ -15,10 +15,11 @@ package token import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -216,12 +217,12 @@ func (inst *BurnChecked) EncodeToTree(parent ag_treeout.Branches) { func (obj BurnChecked) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } // Serialize `Decimals` param: - err = encoder.Encode(obj.Decimals) + err = encoder.WriteByte(*obj.Decimals) if err != nil { return err } diff --git a/programs/token/CloseAccount.go b/programs/token/CloseAccount.go index 3d4101da5..cfd1f43f8 100644 --- a/programs/token/CloseAccount.go +++ b/programs/token/CloseAccount.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token/FreezeAccount.go b/programs/token/FreezeAccount.go index 303cdc311..607432c96 100644 --- a/programs/token/FreezeAccount.go +++ b/programs/token/FreezeAccount.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token/InitializeAccount.go b/programs/token/InitializeAccount.go index eb87cc2ba..2fa0982de 100644 --- a/programs/token/InitializeAccount.go +++ b/programs/token/InitializeAccount.go @@ -17,7 +17,7 @@ package token import ( "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token/InitializeAccount2.go b/programs/token/InitializeAccount2.go index a010027d2..78c244c70 100644 --- a/programs/token/InitializeAccount2.go +++ b/programs/token/InitializeAccount2.go @@ -17,7 +17,7 @@ package token import ( "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -162,7 +162,7 @@ func (inst *InitializeAccount2) EncodeToTree(parent ag_treeout.Branches) { func (obj InitializeAccount2) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Owner` param: - err = encoder.Encode(obj.Owner) + err = encoder.WriteBytes(obj.Owner[:], false) if err != nil { return err } diff --git a/programs/token/InitializeAccount3.go b/programs/token/InitializeAccount3.go index 4b68e4fbf..35a0197f5 100644 --- a/programs/token/InitializeAccount3.go +++ b/programs/token/InitializeAccount3.go @@ -17,7 +17,7 @@ package token import ( "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -138,7 +138,7 @@ func (inst *InitializeAccount3) EncodeToTree(parent ag_treeout.Branches) { func (obj InitializeAccount3) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Owner` param: - err = encoder.Encode(obj.Owner) + err = encoder.WriteBytes(obj.Owner[:], false) if err != nil { return err } diff --git a/programs/token/InitializeMint.go b/programs/token/InitializeMint.go index a3e0dba18..9d80b23a4 100644 --- a/programs/token/InitializeMint.go +++ b/programs/token/InitializeMint.go @@ -17,8 +17,8 @@ package token import ( "errors" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" ) @@ -171,12 +171,12 @@ func (inst *InitializeMint) EncodeToTree(parent ag_treeout.Branches) { func (obj InitializeMint) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Decimals` param: - err = encoder.Encode(obj.Decimals) + err = encoder.WriteByte(*obj.Decimals) if err != nil { return err } // Serialize `MintAuthority` param: - err = encoder.Encode(obj.MintAuthority) + err = encoder.WriteBytes(obj.MintAuthority[:], false) if err != nil { return err } @@ -192,7 +192,7 @@ func (obj InitializeMint) MarshalWithEncoder(encoder *ag_binary.Encoder) (err er if err != nil { return err } - err = encoder.Encode(obj.FreezeAuthority) + err = encoder.WriteBytes(obj.FreezeAuthority[:], false) if err != nil { return err } diff --git a/programs/token/InitializeMint2.go b/programs/token/InitializeMint2.go index 238e606fb..11fe0e93f 100644 --- a/programs/token/InitializeMint2.go +++ b/programs/token/InitializeMint2.go @@ -17,8 +17,8 @@ package token import ( "errors" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" ) @@ -143,12 +143,12 @@ func (inst *InitializeMint2) EncodeToTree(parent ag_treeout.Branches) { func (obj InitializeMint2) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Decimals` param: - err = encoder.Encode(obj.Decimals) + err = encoder.WriteByte(*obj.Decimals) if err != nil { return err } // Serialize `MintAuthority` param: - err = encoder.Encode(obj.MintAuthority) + err = encoder.WriteBytes(obj.MintAuthority[:], false) if err != nil { return err } @@ -164,7 +164,7 @@ func (obj InitializeMint2) MarshalWithEncoder(encoder *ag_binary.Encoder) (err e if err != nil { return err } - err = encoder.Encode(obj.FreezeAuthority) + err = encoder.WriteBytes(obj.FreezeAuthority[:], false) if err != nil { return err } diff --git a/programs/token/InitializeMultisig.go b/programs/token/InitializeMultisig.go index d4042cb6c..ce1219a98 100644 --- a/programs/token/InitializeMultisig.go +++ b/programs/token/InitializeMultisig.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -193,7 +193,7 @@ func (inst *InitializeMultisig) EncodeToTree(parent ag_treeout.Branches) { func (obj InitializeMultisig) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `M` param: - err = encoder.Encode(obj.M) + err = encoder.WriteByte(*obj.M) if err != nil { return err } diff --git a/programs/token/InitializeMultisig2.go b/programs/token/InitializeMultisig2.go index 6165849cb..11b0a776f 100644 --- a/programs/token/InitializeMultisig2.go +++ b/programs/token/InitializeMultisig2.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -159,7 +159,7 @@ func (inst *InitializeMultisig2) EncodeToTree(parent ag_treeout.Branches) { func (obj InitializeMultisig2) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `M` param: - err = encoder.Encode(obj.M) + err = encoder.WriteByte(*obj.M) if err != nil { return err } diff --git a/programs/token/MintTo.go b/programs/token/MintTo.go index d008bc150..708e87b90 100644 --- a/programs/token/MintTo.go +++ b/programs/token/MintTo.go @@ -15,10 +15,11 @@ package token import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -197,7 +198,7 @@ func (inst *MintTo) EncodeToTree(parent ag_treeout.Branches) { func (obj MintTo) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } diff --git a/programs/token/MintToChecked.go b/programs/token/MintToChecked.go index 6edafa4d2..0e8f17828 100644 --- a/programs/token/MintToChecked.go +++ b/programs/token/MintToChecked.go @@ -15,10 +15,11 @@ package token import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -214,12 +215,12 @@ func (inst *MintToChecked) EncodeToTree(parent ag_treeout.Branches) { func (obj MintToChecked) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } // Serialize `Decimals` param: - err = encoder.Encode(obj.Decimals) + err = encoder.WriteByte(*obj.Decimals) if err != nil { return err } diff --git a/programs/token/Revoke.go b/programs/token/Revoke.go index 6c17c87ad..96733f151 100644 --- a/programs/token/Revoke.go +++ b/programs/token/Revoke.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token/SetAuthority.go b/programs/token/SetAuthority.go index bfb4c8b53..0b3e4a91f 100644 --- a/programs/token/SetAuthority.go +++ b/programs/token/SetAuthority.go @@ -18,8 +18,8 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" ) @@ -187,7 +187,7 @@ func (inst *SetAuthority) EncodeToTree(parent ag_treeout.Branches) { func (obj SetAuthority) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `AuthorityType` param: - err = encoder.Encode(obj.AuthorityType) + err = encoder.WriteByte(uint8(*obj.AuthorityType)) if err != nil { return err } @@ -203,7 +203,7 @@ func (obj SetAuthority) MarshalWithEncoder(encoder *ag_binary.Encoder) (err erro if err != nil { return err } - err = encoder.Encode(obj.NewAuthority) + err = encoder.WriteBytes(obj.NewAuthority[:], false) if err != nil { return err } diff --git a/programs/token/SyncNative.go b/programs/token/SyncNative.go index 41605eb26..76b274cfd 100644 --- a/programs/token/SyncNative.go +++ b/programs/token/SyncNative.go @@ -17,7 +17,7 @@ package token import ( "errors" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token/ThawAccount.go b/programs/token/ThawAccount.go index ffe09c08f..e693863f8 100644 --- a/programs/token/ThawAccount.go +++ b/programs/token/ThawAccount.go @@ -18,7 +18,7 @@ import ( "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" diff --git a/programs/token/Transfer.go b/programs/token/Transfer.go index e68d3fe8c..80d40243d 100644 --- a/programs/token/Transfer.go +++ b/programs/token/Transfer.go @@ -15,10 +15,11 @@ package token import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -199,7 +200,7 @@ func (inst *Transfer) EncodeToTree(parent ag_treeout.Branches) { func (obj Transfer) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } diff --git a/programs/token/TransferChecked.go b/programs/token/TransferChecked.go index 3f30aba38..218981135 100644 --- a/programs/token/TransferChecked.go +++ b/programs/token/TransferChecked.go @@ -15,10 +15,11 @@ package token import ( + "encoding/binary" "errors" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_solanago "github.com/gagliardetto/solana-go" ag_format "github.com/gagliardetto/solana-go/text/format" ag_treeout "github.com/gagliardetto/treeout" @@ -237,12 +238,12 @@ func (inst *TransferChecked) EncodeToTree(parent ag_treeout.Branches) { func (obj TransferChecked) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { // Serialize `Amount` param: - err = encoder.Encode(obj.Amount) + err = encoder.WriteUint64(*obj.Amount, binary.LittleEndian) if err != nil { return err } // Serialize `Decimals` param: - err = encoder.Encode(obj.Decimals) + err = encoder.WriteByte(*obj.Decimals) if err != nil { return err } diff --git a/programs/token/accounts.go b/programs/token/accounts.go index eac089e1c..52b5d9dc6 100644 --- a/programs/token/accounts.go +++ b/programs/token/accounts.go @@ -17,8 +17,8 @@ package token import ( "encoding/binary" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" ) type Mint struct { diff --git a/programs/token/accounts_test.go b/programs/token/accounts_test.go index b952b136d..84cc5fab1 100644 --- a/programs/token/accounts_test.go +++ b/programs/token/accounts_test.go @@ -5,8 +5,8 @@ import ( "testing" "github.com/davecgh/go-spew/spew" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/stretchr/testify/require" ) diff --git a/programs/token/instructions.go b/programs/token/instructions.go index 8dae1a451..67ca35a11 100644 --- a/programs/token/instructions.go +++ b/programs/token/instructions.go @@ -22,8 +22,8 @@ import ( "fmt" ag_spew "github.com/davecgh/go-spew/spew" - ag_binary "github.com/gagliardetto/binary" ag_solanago "github.com/gagliardetto/solana-go" + ag_binary "github.com/gagliardetto/solana-go/binary" ag_text "github.com/gagliardetto/solana-go/text" ag_treeout "github.com/gagliardetto/treeout" ) @@ -44,6 +44,7 @@ func init() { if !ProgramID.IsZero() { ag_solanago.MustRegisterInstructionDecoder(ProgramID, registryDecodeInstruction) } + ag_binary.PrewarmVariantDefinition(InstructionImplDef) } const ( diff --git a/programs/token/rpc.go b/programs/token/rpc.go index 3a3949b2f..ac5e70542 100644 --- a/programs/token/rpc.go +++ b/programs/token/rpc.go @@ -21,7 +21,7 @@ import ( "context" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/rpc" ) diff --git a/programs/token/testing_utils.go b/programs/token/testing_utils.go index 233dad1a9..7080892a3 100644 --- a/programs/token/testing_utils.go +++ b/programs/token/testing_utils.go @@ -17,7 +17,7 @@ package token import ( "bytes" "fmt" - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ) func encodeT(data any, buf *bytes.Buffer) error { diff --git a/programs/token/types.go b/programs/token/types.go index 41de8217c..15e7758bf 100644 --- a/programs/token/types.go +++ b/programs/token/types.go @@ -15,7 +15,7 @@ package token import ( - ag_binary "github.com/gagliardetto/binary" + ag_binary "github.com/gagliardetto/solana-go/binary" ) type AuthorityType ag_binary.BorshEnum diff --git a/programs/tokenregistry/instruction.go b/programs/tokenregistry/instruction.go index d0f21511c..08a09a894 100644 --- a/programs/tokenregistry/instruction.go +++ b/programs/tokenregistry/instruction.go @@ -24,12 +24,17 @@ import ( "github.com/gagliardetto/solana-go/text" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" ) func init() { solana.MustRegisterInstructionDecoder(ProgramID(), registryDecodeInstruction) + // Eliminate the first-encode reflect-walk cost for the package's + // marshaled types so the typePlan cache is populated by the time the + // process makes its first call. + bin.PrewarmVariantDefinition(InstructionDefVariant) + bin.PrewarmTypes(TokenMeta{}) } func registryDecodeInstruction(accounts []*solana.AccountMeta, data []byte) (any, error) { diff --git a/programs/tokenregistry/tokenmeta_bench_test.go b/programs/tokenregistry/tokenmeta_bench_test.go new file mode 100644 index 000000000..0020a59fb --- /dev/null +++ b/programs/tokenregistry/tokenmeta_bench_test.go @@ -0,0 +1,180 @@ +package tokenregistry + +import ( + "bytes" + "testing" + "time" + + bin "github.com/gagliardetto/solana-go/binary" + + "github.com/gagliardetto/solana-go" +) + +// makeBenchTokenMeta builds a fully-populated TokenMeta. It is the largest +// reflect-marshaled struct in the repo (9 fields, two foreign-package pointer +// types, four nested fixed-size byte arrays), so it stresses the typePlan +// cache, the indirect() walker, and the array-write fast path simultaneously. +func makeBenchTokenMeta() TokenMeta { + mint := solana.PublicKey{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} + auth := solana.PublicKey{32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1} + logo, _ := LogoFromString("https://example.com/token-logo.png") + name, _ := NameFromString("Example Token") + site, _ := WebsiteFromString("https://example.com") + sym, _ := SymbolFromString("EXMPL") + return TokenMeta{ + IsInitialized: true, + Reg: [3]byte{1, 2, 3}, + DataType: 7, + MintAddress: &mint, + RegistrationAuthority: &auth, + Logo: logo, + Name: name, + Website: site, + Symbol: sym, + } +} + +// BenchmarkEncode_TokenMeta exercises the reflect/typePlan path on the +// largest struct in the repo. Uses MarshalBin (the high-level helper) so +// the benchmark is portable across the upstream gagliardetto/binary +// v0.8.0 module and the vendored copy. +func BenchmarkEncode_TokenMeta(b *testing.B) { + tm := makeBenchTokenMeta() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf, err := bin.MarshalBin(&tm) + if err != nil { + b.Fatal(err) + } + _ = buf + } +} + +func BenchmarkDecode_TokenMeta(b *testing.B) { + tm := makeBenchTokenMeta() + data, err := bin.MarshalBin(&tm) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var out TokenMeta + dec := bin.NewBinDecoder(data) + if err := dec.Decode(&out); err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkDecode_TokenMeta_UnmarshalBin exercises the pooled +// UnmarshalBin convenience helper, which draws a Decoder from a +// sync.Pool, resets it with the input bytes, decodes, and returns it +// for reuse. Compare to BenchmarkDecode_TokenMeta which constructs a +// fresh Decoder every iteration. +func BenchmarkDecode_TokenMeta_UnmarshalBin(b *testing.B) { + tm := makeBenchTokenMeta() + data, err := bin.MarshalBin(&tm) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var out TokenMeta + if err := bin.UnmarshalBin(&out, data); err != nil { + b.Fatal(err) + } + } +} + +// Reused-buffer encoder variant — writes into a single bytes.Buffer that's +// reset between iterations, isolating the reflect-walk cost from +// MarshalBin's allocation of a fresh buffer per call. +func BenchmarkEncode_TokenMeta_Reused(b *testing.B) { + tm := makeBenchTokenMeta() + var buf bytes.Buffer + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.Reset() + enc := bin.NewBinEncoder(&buf) + if err := enc.Encode(&tm); err != nil { + b.Fatal(err) + } + } +} + +// coldStruct is a copy of TokenMeta intentionally never referenced +// outside this file. Its typePlan has not been prewarmed by init(), so +// the very first call to MarshalBin pays the cold-path cost (reflect +// walk + plan construction + sync.Map.LoadOrStore). +type coldStruct struct { + IsInitialized bool + Reg [3]byte + DataType byte + MintAddress *solana.PublicKey + RegistrationAuthority *solana.PublicKey + Logo Logo + Name Name + Website Website + Symbol Symbol +} + +// runtimeWarmup is a throwaway type used to pre-load the encoder's +// reflect-driven code paths, CPU caches, and any lazy-init globals +// BEFORE we measure typePlan-build cost. Without this, the very first +// MarshalBin call in a process pays a fixed ~100-200µs runtime cold- +// start cost that dwarfs the actual typePlan construction (~5-10µs) +// and gives nonsense readings. +type runtimeWarmup struct{ X uint64 } + +// TestFirstCallCost measures, after a runtime warmup, the cost of +// encoding TokenMeta (whose typePlan has been prewarmed by init via +// PrewarmTypes) versus coldStruct (whose typePlan has NEVER been +// touched, so the very first encode pays the reflect-walk cost). The +// delta is the cold-path overhead that PrewarmTypes eliminates. +func TestFirstCallCost(t *testing.T) { + // Throwaway call to warm the runtime, the encoder, and the CPU + // caches. After this, both subsequent measurements start from the + // same hot state and only the typePlan lookup differs. + if _, err := bin.MarshalBin(&runtimeWarmup{X: 42}); err != nil { + t.Fatal(err) + } + + // Warm path: TokenMeta. Prewarmed in init() — the typePlan is + // already in the sync.Map and the first encode hits the load fast + // path. + tm := makeBenchTokenMeta() + startWarm := time.Now() + if _, err := bin.MarshalBin(&tm); err != nil { + t.Fatal(err) + } + warmDur := time.Since(startWarm) + + // Cold path: coldStruct. Cache miss, typePlan built inline. + cs := coldStruct{IsInitialized: true, Reg: [3]byte{1, 2, 3}, DataType: 7} + mint := solana.PublicKey{1, 2, 3} + auth := solana.PublicKey{4, 5, 6} + cs.MintAddress = &mint + cs.RegistrationAuthority = &auth + logo, _ := LogoFromString("logo") + name, _ := NameFromString("name") + site, _ := WebsiteFromString("site") + sym, _ := SymbolFromString("sym") + cs.Logo = logo + cs.Name = name + cs.Website = site + cs.Symbol = sym + + startCold := time.Now() + if _, err := bin.MarshalBin(&cs); err != nil { + t.Fatal(err) + } + coldDur := time.Since(startCold) + + t.Logf("warm (TokenMeta, typePlan prewarmed via init): %v", warmDur) + t.Logf("cold (coldStruct, typePlan miss + build): %v", coldDur) + t.Logf("cold-path overhead saved by prewarm: %v", coldDur-warmDur) +} diff --git a/programs/tokenregistry/types.go b/programs/tokenregistry/types.go index 3444504b3..6ed8ba39d 100644 --- a/programs/tokenregistry/types.go +++ b/programs/tokenregistry/types.go @@ -7,7 +7,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -19,7 +19,7 @@ package tokenregistry import ( "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" ) @@ -38,6 +38,116 @@ type TokenMeta struct { Symbol Symbol } +// MarshalWithEncoder is a hand-written marshaler that bypasses the +// per-field plan-walk overhead of the generic reflect path. The cost +// of the boilerplate is justified for TokenMeta because it is the +// largest reflect-marshaled struct in the repo (9 fields including 4 +// nested fixed-size byte arrays and 2 foreign-package pointer types). +// +// We use the explicit Write* methods on the encoder rather than the +// generic encoder.Encode(v any) entry point, because the latter would +// box every [N]byte field into an interface{} on the way in, paying a +// 32-64 byte heap allocation per call. WriteBytes takes a []byte +// directly with no boxing. +func (obj *TokenMeta) MarshalWithEncoder(encoder *bin.Encoder) error { + if err := encoder.WriteBool(obj.IsInitialized); err != nil { + return err + } + if err := encoder.WriteBytes(obj.Reg[:], false); err != nil { + return err + } + if err := encoder.WriteByte(obj.DataType); err != nil { + return err + } + if obj.MintAddress != nil { + if err := encoder.WriteBytes(obj.MintAddress[:], false); err != nil { + return err + } + } else { + var zero solana.PublicKey + if err := encoder.WriteBytes(zero[:], false); err != nil { + return err + } + } + if obj.RegistrationAuthority != nil { + if err := encoder.WriteBytes(obj.RegistrationAuthority[:], false); err != nil { + return err + } + } else { + var zero solana.PublicKey + if err := encoder.WriteBytes(zero[:], false); err != nil { + return err + } + } + if err := encoder.WriteBytes(obj.Logo[:], false); err != nil { + return err + } + if err := encoder.WriteBytes(obj.Name[:], false); err != nil { + return err + } + if err := encoder.WriteBytes(obj.Website[:], false); err != nil { + return err + } + if err := encoder.WriteBytes(obj.Symbol[:], false); err != nil { + return err + } + return nil +} + +// UnmarshalWithDecoder reads each field directly into its destination +// with no per-field reflect or option-construction cost. The two +// *solana.PublicKey allocations are unavoidable because the field +// types are pointers. +func (obj *TokenMeta) UnmarshalWithDecoder(decoder *bin.Decoder) error { + var err error + if obj.IsInitialized, err = decoder.ReadBool(); err != nil { + return err + } + regBytes, err := decoder.ReadNBytes(3) + if err != nil { + return err + } + copy(obj.Reg[:], regBytes) + if obj.DataType, err = decoder.ReadByte(); err != nil { + return err + } + mint := new(solana.PublicKey) + mintBytes, err := decoder.ReadNBytes(32) + if err != nil { + return err + } + copy(mint[:], mintBytes) + obj.MintAddress = mint + auth := new(solana.PublicKey) + authBytes, err := decoder.ReadNBytes(32) + if err != nil { + return err + } + copy(auth[:], authBytes) + obj.RegistrationAuthority = auth + logoBytes, err := decoder.ReadNBytes(64) + if err != nil { + return err + } + copy(obj.Logo[:], logoBytes) + nameBytes, err := decoder.ReadNBytes(32) + if err != nil { + return err + } + copy(obj.Name[:], nameBytes) + siteBytes, err := decoder.ReadNBytes(32) + if err != nil { + return err + } + copy(obj.Website[:], siteBytes) + symBytes, err := decoder.ReadNBytes(32) + if err != nil { + return err + } + copy(obj.Symbol[:], symBytes) + return nil +} + func DecodeTokenMeta(in []byte) (*TokenMeta, error) { var t *TokenMeta decoder := bin.NewBinDecoder(in) diff --git a/programs/vote/Vote.go b/programs/vote/Vote.go index 651ebd01b..0a428367b 100644 --- a/programs/vote/Vote.go +++ b/programs/vote/Vote.go @@ -18,8 +18,8 @@ import ( "fmt" "time" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" ) diff --git a/programs/vote/Withdraw.go b/programs/vote/Withdraw.go index 67517845f..587fb720c 100644 --- a/programs/vote/Withdraw.go +++ b/programs/vote/Withdraw.go @@ -16,10 +16,11 @@ package vote import ( + "encoding/binary" "errors" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/text/format" "github.com/gagliardetto/treeout" @@ -55,7 +56,7 @@ func (v *Withdraw) UnmarshalWithDecoder(dec *bin.Decoder) error { func (inst *Withdraw) MarshalWithEncoder(encoder *bin.Encoder) error { // Serialize `Lamports` param: { - err := encoder.Encode(*inst.Lamports) + err := encoder.WriteUint64(*inst.Lamports, binary.LittleEndian) if err != nil { return err } diff --git a/programs/vote/instructions.go b/programs/vote/instructions.go index ed518c3bd..2bc74f2cc 100644 --- a/programs/vote/instructions.go +++ b/programs/vote/instructions.go @@ -20,8 +20,8 @@ import ( "fmt" "github.com/davecgh/go-spew/spew" - bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/text" "github.com/gagliardetto/treeout" ) @@ -37,6 +37,7 @@ const ProgramName = "Vote" func init() { solana.MustRegisterInstructionDecoder(ProgramID, registryDecodeInstruction) + bin.PrewarmVariantDefinition(InstructionImplDef) } type Instruction struct { diff --git a/rpc/client_test.go b/rpc/client_test.go index 038673af5..cbee18d83 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -27,7 +27,7 @@ import ( "testing" "github.com/AlekSi/pointer" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/rpc/example_test.go b/rpc/example_test.go index b7a089abc..a67c6b0a6 100644 --- a/rpc/example_test.go +++ b/rpc/example_test.go @@ -7,7 +7,7 @@ import ( "time" "github.com/davecgh/go-spew/spew" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/programs/token" "github.com/gagliardetto/solana-go/rpc" diff --git a/rpc/examples/getAccountInfo/getAccountInfo.go b/rpc/examples/getAccountInfo/getAccountInfo.go index b189ddbd0..c5496c175 100644 --- a/rpc/examples/getAccountInfo/getAccountInfo.go +++ b/rpc/examples/getAccountInfo/getAccountInfo.go @@ -18,7 +18,7 @@ import ( "context" "github.com/davecgh/go-spew/spew" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" solana "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/programs/token" "github.com/gagliardetto/solana-go/rpc" diff --git a/rpc/examples/getTokenAccountsByOwner/getTokenAccountsByOwner.go b/rpc/examples/getTokenAccountsByOwner/getTokenAccountsByOwner.go index b9a27c957..0a505b746 100644 --- a/rpc/examples/getTokenAccountsByOwner/getTokenAccountsByOwner.go +++ b/rpc/examples/getTokenAccountsByOwner/getTokenAccountsByOwner.go @@ -18,7 +18,7 @@ import ( "context" "github.com/davecgh/go-spew/spew" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/programs/token" "github.com/gagliardetto/solana-go/rpc" diff --git a/rpc/examples/getTransaction/getTransaction.go b/rpc/examples/getTransaction/getTransaction.go index 4afd7d40a..ca8f29d15 100644 --- a/rpc/examples/getTransaction/getTransaction.go +++ b/rpc/examples/getTransaction/getTransaction.go @@ -18,7 +18,7 @@ import ( "context" "github.com/davecgh/go-spew/spew" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" ) diff --git a/rpc/examples/sendTransaction/sendTransaction.go b/rpc/examples/sendTransaction/sendTransaction.go index f53d6a854..6fd62c2e2 100644 --- a/rpc/examples/sendTransaction/sendTransaction.go +++ b/rpc/examples/sendTransaction/sendTransaction.go @@ -19,7 +19,7 @@ import ( "encoding/base64" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" ) diff --git a/rpc/getAccountInfo.go b/rpc/getAccountInfo.go index 00784aa5f..22c6c36ee 100644 --- a/rpc/getAccountInfo.go +++ b/rpc/getAccountInfo.go @@ -20,7 +20,7 @@ import ( "context" "errors" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" ) diff --git a/rpc/getParsedTransaction.go b/rpc/getParsedTransaction.go index 25436e756..a381434a6 100644 --- a/rpc/getParsedTransaction.go +++ b/rpc/getParsedTransaction.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" ) diff --git a/rpc/getTransaction.go b/rpc/getTransaction.go index 6a93567b3..a5f53cde8 100644 --- a/rpc/getTransaction.go +++ b/rpc/getTransaction.go @@ -18,7 +18,7 @@ import ( "context" "fmt" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" ) diff --git a/rpc/types.go b/rpc/types.go index d1c396477..1b0dca95f 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -23,7 +23,7 @@ import ( "fmt" "math/big" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go" ) diff --git a/transaction.go b/transaction.go index 8af306acd..8c461f0b7 100644 --- a/transaction.go +++ b/transaction.go @@ -24,7 +24,7 @@ import ( "slices" "github.com/davecgh/go-spew/spew" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/gagliardetto/solana-go/text" "github.com/gagliardetto/treeout" "github.com/mr-tron/base58" diff --git a/transaction_test.go b/transaction_test.go index e5801036e..d3552f986 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -21,7 +21,7 @@ import ( "encoding/base64" "testing" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/mr-tron/base58" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/types_test.go b/types_test.go index 6f61575fc..2491612a7 100644 --- a/types_test.go +++ b/types_test.go @@ -21,7 +21,7 @@ import ( "bytes" "testing" - bin "github.com/gagliardetto/binary" + bin "github.com/gagliardetto/solana-go/binary" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" )