diff --git a/tpm/constants.go b/tpm/constants.go index 8dc8e4cf..c568987f 100644 --- a/tpm/constants.go +++ b/tpm/constants.go @@ -18,7 +18,7 @@ import "github.com/google/go-tpm/tpmutil" func init() { // TPM 1.2 spec uses uint32 for length prefix of byte arrays. - tpmutil.UseTPM12LengthPrefixSize() + tpmutil.UseTPM12Encoding() } // Supported TPM commands. diff --git a/tpm2/constants.go b/tpm2/constants.go index ea954e7a..0b835463 100644 --- a/tpm2/constants.go +++ b/tpm2/constants.go @@ -26,7 +26,7 @@ import ( ) func init() { - tpmutil.UseTPM20LengthPrefixSize() + tpmutil.UseTPM20Encoding() } // MAX_DIGEST_BUFFER is the maximum size of []byte request or response fields. diff --git a/tpmutil/encoding.go b/tpmutil/encoding.go index 9457777c..e05dd360 100644 --- a/tpmutil/encoding.go +++ b/tpmutil/encoding.go @@ -23,32 +23,49 @@ import ( "reflect" ) -// lengthPrefixSize is the size in bytes of length prefix for byte slices. -// -// In TPM 1.2 this is 4 bytes. -// In TPM 2.0 this is 2 bytes. -var lengthPrefixSize int +// Encoding implements encoding logic for different versions of the TPM +// specification. +type Encoding struct { + // lengthPrefixSize is the size in bytes of length prefix for byte slices. + // + // In TPM 1.2 this is 4 bytes. + // In TPM 2.0 this is 2 bytes. + lengthPrefixSize int +} + +var ( + // Encoding1_2 implements TPM 1.2 encoding. + Encoding1_2 = &Encoding{ + lengthPrefixSize: tpm12PrefixSize, + } + // Encoding2_0 implements TPM 2.0 encoding. + Encoding2_0 = &Encoding{ + lengthPrefixSize: tpm20PrefixSize, + } + + defaultEncoding *Encoding +) const ( tpm12PrefixSize = 4 tpm20PrefixSize = 2 ) -// UseTPM12LengthPrefixSize makes Pack/Unpack use TPM 1.2 encoding for byte -// arrays. -func UseTPM12LengthPrefixSize() { - lengthPrefixSize = tpm12PrefixSize +// UseTPM12Encoding makes the package level Pack/Unpack functions use +// TPM 1.2 encoding for byte arrays. +func UseTPM12Encoding() { + defaultEncoding = Encoding1_2 } -// UseTPM20LengthPrefixSize makes Pack/Unpack use TPM 2.0 encoding for byte -// arrays. -func UseTPM20LengthPrefixSize() { - lengthPrefixSize = tpm20PrefixSize +// UseTPM20Encoding makes the package level Pack/Unpack functions use +// TPM 2.0 encoding for byte arrays. +func UseTPM20Encoding() { + defaultEncoding = Encoding2_0 } // packedSize computes the size of a sequence of types that can be passed to // binary.Read or binary.Write. -func packedSize(elts ...interface{}) (int, error) { +func (enc *Encoding) packedSize(elts ...interface{}) (int, error) { var size int for _, e := range elts { marshaler, ok := e.(SelfMarshaler) @@ -59,7 +76,7 @@ func packedSize(elts ...interface{}) (int, error) { v := reflect.ValueOf(e) switch v.Kind() { case reflect.Ptr: - s, err := packedSize(reflect.Indirect(v).Interface()) + s, err := enc.packedSize(reflect.Indirect(v).Interface()) if err != nil { return 0, err } @@ -67,7 +84,7 @@ func packedSize(elts ...interface{}) (int, error) { size += s case reflect.Struct: for i := 0; i < v.NumField(); i++ { - s, err := packedSize(v.Field(i).Interface()) + s, err := enc.packedSize(v.Field(i).Interface()) if err != nil { return 0, err } @@ -77,7 +94,7 @@ func packedSize(elts ...interface{}) (int, error) { case reflect.Slice: switch s := e.(type) { case []byte: - size += lengthPrefixSize + len(s) + size += enc.lengthPrefixSize + len(s) case RawBytes: size += len(s) default: @@ -100,16 +117,27 @@ func packedSize(elts ...interface{}) (int, error) { // fixed length or slices of fixed-length types and packs them into a single // byte array using binary.Write. It updates the CommandHeader to have the right // length. -func packWithHeader(ch commandHeader, cmd ...interface{}) ([]byte, error) { +func (enc *Encoding) packWithHeader(ch commandHeader, cmd ...interface{}) ([]byte, error) { hdrSize := binary.Size(ch) - bodySize, err := packedSize(cmd...) + bodySize, err := enc.packedSize(cmd...) if err != nil { return nil, fmt.Errorf("couldn't compute packed size for message body: %v", err) } ch.Size = uint32(hdrSize + bodySize) in := []interface{}{ch} in = append(in, cmd...) - return Pack(in...) + return enc.Pack(in...) +} + +// Pack encodes a set of elements using the package's default encoding. +// +// Callers must call UseTPM12Encoding() or UseTPM20Encoding() before calling +// this method. +func Pack(elts ...interface{}) ([]byte, error) { + if defaultEncoding == nil { + return nil, errors.New("default encoding not initialized") + } + return defaultEncoding.Pack(elts...) } // Pack encodes a set of elements into a single byte array, using @@ -119,13 +147,9 @@ func packWithHeader(ch commandHeader, cmd ...interface{}) ([]byte, error) { // It has one difference from encoding/binary: it encodes byte slices with a // prepended length, to match how the TPM encodes variable-length arrays. If // you wish to add a byte slice without length prefix, use RawBytes. -func Pack(elts ...interface{}) ([]byte, error) { - if lengthPrefixSize == 0 { - return nil, errors.New("lengthPrefixSize must be initialized") - } - +func (enc *Encoding) Pack(elts ...interface{}) ([]byte, error) { buf := new(bytes.Buffer) - if err := packType(buf, elts...); err != nil { + if err := enc.packType(buf, elts...); err != nil { return nil, err } @@ -137,7 +161,7 @@ func Pack(elts ...interface{}) ([]byte, error) { // lengthPrefixSize size followed by the bytes. The function unpackType // performs the inverse operation of unpacking slices stored in this manner and // using encoding/binary for everything else. -func packType(buf io.Writer, elts ...interface{}) error { +func (enc *Encoding) packType(buf io.Writer, elts ...interface{}) error { for _, e := range elts { marshaler, ok := e.(SelfMarshaler) if ok { @@ -149,20 +173,20 @@ func packType(buf io.Writer, elts ...interface{}) error { v := reflect.ValueOf(e) switch v.Kind() { case reflect.Ptr: - if err := packType(buf, reflect.Indirect(v).Interface()); err != nil { + if err := enc.packType(buf, reflect.Indirect(v).Interface()); err != nil { return err } case reflect.Struct: // TODO(awly): Currently packType cannot handle non-struct fields that implement SelfMarshaler for i := 0; i < v.NumField(); i++ { - if err := packType(buf, v.Field(i).Interface()); err != nil { + if err := enc.packType(buf, v.Field(i).Interface()); err != nil { return err } } case reflect.Slice: switch s := e.(type) { case []byte: - switch lengthPrefixSize { + switch enc.lengthPrefixSize { case tpm20PrefixSize: if err := binary.Write(buf, binary.BigEndian, uint16(len(s))); err != nil { return err @@ -172,7 +196,7 @@ func packType(buf io.Writer, elts ...interface{}) error { return err } default: - return fmt.Errorf("lengthPrefixSize is %d, must be either 2 or 4", lengthPrefixSize) + return fmt.Errorf("lengthPrefixSize is %d, must be either 2 or 4", enc.lengthPrefixSize) } if err := binary.Write(buf, binary.BigEndian, s); err != nil { return err @@ -195,21 +219,45 @@ func packType(buf io.Writer, elts ...interface{}) error { return nil } +// Unpack is a convenience wrapper around UnpackBuf using the package's default +// encoding. +// +// Callers must call UseTPM12Encoding() or UseTPM20Encoding() before calling +// this method. +func Unpack(b []byte, elts ...interface{}) (int, error) { + if defaultEncoding == nil { + return 0, errors.New("default encoding not initialized") + } + return defaultEncoding.Unpack(b, elts...) +} + // Unpack is a convenience wrapper around UnpackBuf. Unpack returns the number // of bytes read from b to fill elts and error, if any. -func Unpack(b []byte, elts ...interface{}) (int, error) { +func (enc *Encoding) Unpack(b []byte, elts ...interface{}) (int, error) { buf := bytes.NewBuffer(b) - err := UnpackBuf(buf, elts...) + err := enc.UnpackBuf(buf, elts...) read := len(b) - buf.Len() return read, err } +// UnpackBuf recursively unpacks types from a reader using the package's default +// encoding. +// +// Callers must call UseTPM12Encoding() or UseTPM20Encoding() before calling +// this method. +func UnpackBuf(buf io.Reader, elts ...interface{}) error { + if defaultEncoding == nil { + return errors.New("default encoding not initialized") + } + return defaultEncoding.UnpackBuf(buf, elts...) +} + // UnpackBuf recursively unpacks types from a reader just as encoding/binary // does under binary.BigEndian, but with one difference: it unpacks a byte // slice by first reading an integer with lengthPrefixSize bytes, then reading // that many bytes. It assumes that incoming values are pointers to values so // that, e.g., underlying slices can be resized as needed. -func UnpackBuf(buf io.Reader, elts ...interface{}) error { +func (enc *Encoding) UnpackBuf(buf io.Reader, elts ...interface{}) error { for _, e := range elts { v := reflect.ValueOf(e) k := v.Kind() @@ -233,7 +281,7 @@ func UnpackBuf(buf io.Reader, elts ...interface{}) error { case reflect.Struct: // Decompose the struct and copy over the values. for i := 0; i < iv.NumField(); i++ { - if err := UnpackBuf(buf, iv.Field(i).Addr().Interface()); err != nil { + if err := enc.UnpackBuf(buf, iv.Field(i).Addr().Interface()); err != nil { return err } } @@ -250,21 +298,21 @@ func UnpackBuf(buf io.Reader, elts ...interface{}) error { } size = int(tmpSize) // TPM 2.0 - case lengthPrefixSize == tpm20PrefixSize: + case enc.lengthPrefixSize == tpm20PrefixSize: var tmpSize uint16 if err := binary.Read(buf, binary.BigEndian, &tmpSize); err != nil { return err } size = int(tmpSize) // TPM 1.2 - case lengthPrefixSize == tpm12PrefixSize: + case enc.lengthPrefixSize == tpm12PrefixSize: var tmpSize uint32 if err := binary.Read(buf, binary.BigEndian, &tmpSize); err != nil { return err } size = int(tmpSize) default: - return fmt.Errorf("lengthPrefixSize is %d, must be either 2 or 4", lengthPrefixSize) + return fmt.Errorf("lengthPrefixSize is %d, must be either 2 or 4", enc.lengthPrefixSize) } // A zero size is used by the TPM to signal that certain elements diff --git a/tpmutil/encoding_test.go b/tpmutil/encoding_test.go index 9efe3155..106db04b 100644 --- a/tpmutil/encoding_test.go +++ b/tpmutil/encoding_test.go @@ -22,10 +22,6 @@ import ( "testing" ) -func init() { - UseTPM12LengthPrefixSize() -} - type invalidPacked struct { A []int B uint32 @@ -61,7 +57,7 @@ func testEncodingInvalidSlices(t *testing.T, f func(io.Writer, interface{}) erro func TestEncodingPackedSizeInvalid(t *testing.T) { f := func(w io.Writer, i interface{}) error { - _, err := packedSize(i) + _, err := Encoding1_2.packedSize(i) return err } @@ -70,7 +66,7 @@ func TestEncodingPackedSizeInvalid(t *testing.T) { func TestEncodingPackTypeInvalid(t *testing.T) { f := func(w io.Writer, i interface{}) error { - return packType(w, i) + return Encoding1_2.packType(w, i) } testEncodingInvalidSlices(t, f) @@ -106,7 +102,7 @@ func TestEncodingPackedSize(t *testing.T) { {[]byte(nil), 4}, } for _, tt := range tests { - if s, err := packedSize(tt.in); err != nil || s != tt.want { + if s, err := Encoding1_2.packedSize(tt.in); err != nil || s != tt.want { t.Errorf("packedSize(%#v): %d, want %d", tt.in, s, tt.want) } } @@ -125,7 +121,7 @@ func TestEncodingPackType(t *testing.T) { RawBytes(buf), } for _, i := range inputs { - if err := packType(ioutil.Discard, i); err != nil { + if err := Encoding1_2.packType(ioutil.Discard, i); err != nil { t.Errorf("packType(%#v): %v", i, err) } } @@ -140,7 +136,7 @@ func TestEncodingPackTypeWriteFail(t *testing.T) { {3, []byte(nil)}, } for _, tt := range tests { - if err := packType(&limitedDiscard{tt.limit}, tt.in); err == nil { + if err := Encoding1_2.packType(&limitedDiscard{tt.limit}, tt.in); err == nil { t.Errorf("packType(%#v) with write size limit %d returned nil, want error", tt.in, tt.limit) } } @@ -167,7 +163,7 @@ func (l *limitedDiscard) Write(p []byte) (n int, err error) { func TestEncodingCommandHeaderInvalidBody(t *testing.T) { var invalid []int ch := commandHeader{1, 0, 2} - _, err := packWithHeader(ch, invalid) + _, err := Encoding1_2.packWithHeader(ch, invalid) if err == nil { t.Fatal("packWithHeader incorrectly packed a body that with an invalid int slice member") } @@ -176,12 +172,12 @@ func TestEncodingCommandHeaderInvalidBody(t *testing.T) { func TestEncodingInvalidPack(t *testing.T) { var invalid []int ch := commandHeader{1, 0, 2} - _, err := packWithHeader(ch, invalid) + _, err := Encoding1_2.packWithHeader(ch, invalid) if err == nil { t.Fatal("packWithHeader incorrectly packed a body that with an invalid int slice member") } - _, err = Pack(invalid) + _, err = Encoding1_2.Pack(invalid) if err == nil { t.Fatal("pack incorrectly packed a slice of int") } @@ -192,14 +188,14 @@ func TestEncodingCommandHeaderEncoding(t *testing.T) { var c uint32 = 137 in := c - b, err := packWithHeader(ch, in) + b, err := Encoding1_2.packWithHeader(ch, in) if err != nil { t.Fatal("Couldn't pack the bytes:", err) } var hdr commandHeader var size uint32 - if _, err := Unpack(b, &hdr, &size); err != nil { + if _, err := Encoding1_2.Unpack(b, &hdr, &size); err != nil { t.Fatal("Couldn't unpack the packed bytes") } @@ -214,19 +210,19 @@ func TestEncodingInvalidUnpack(t *testing.T) { // The value ui is a serialization of uint32(0). ui := []byte{0, 0, 0, 0} uiBuf := bytes.NewBuffer(ui) - if err := UnpackBuf(uiBuf, i); err == nil { + if err := Encoding1_2.UnpackBuf(uiBuf, i); err == nil { t.Fatal("UnpackBuf incorrectly deserialized into a nil pointer") } var ii uint32 - if err := UnpackBuf(uiBuf, ii); err == nil { + if err := Encoding1_2.UnpackBuf(uiBuf, ii); err == nil { t.Fatal("UnpackBuf incorrectly deserialized into a non pointer") } var b []byte var empty []byte emptyBuf := bytes.NewBuffer(empty) - if err := UnpackBuf(emptyBuf, &b); err == nil { + if err := Encoding1_2.UnpackBuf(emptyBuf, &b); err == nil { t.Fatal("UnpackBuf incorrectly deserialized an empty byte array into a byte slice") } @@ -234,14 +230,14 @@ func TestEncodingInvalidUnpack(t *testing.T) { // The slice ui represents uint32(1), which is the length of an empty byte array. ui2 := []byte{0, 0, 0, 1} uiBuf2 := bytes.NewBuffer(ui2) - if err := UnpackBuf(uiBuf2, &b); err == nil { + if err := Encoding1_2.UnpackBuf(uiBuf2, &b); err == nil { t.Fatal("UnpackBuf incorrectly deserialized a byte array that didn't have enough bytes available") } var iii []int ui3 := []byte{0, 0, 0, 1} uiBuf3 := bytes.NewBuffer(ui3) - if err := UnpackBuf(uiBuf3, &iii); err == nil { + if err := Encoding1_2.UnpackBuf(uiBuf3, &iii); err == nil { t.Fatal("UnpackBuf incorrectly deserialized into a slice of ints (only byte slices are supported)") } @@ -253,14 +249,14 @@ func TestEncodingUnpack(t *testing.T) { // The slice ui represents uint32(0), which is the length of an empty byte array. ui := []byte{0, 0, 0, 0} uiBuf := bytes.NewBuffer(ui) - if err := UnpackBuf(uiBuf, &b); err != nil { + if err := Encoding1_2.UnpackBuf(uiBuf, &b); err != nil { t.Fatal("UnpackBuf failed to unpack the empty byte array") } // A byte slice of length 1 with a single entry: b[0] == 137 ui2 := []byte{0, 0, 0, 1, 137} uiBuf2 := bytes.NewBuffer(ui2) - if err := UnpackBuf(uiBuf2, &b); err != nil { + if err := Encoding1_2.UnpackBuf(uiBuf2, &b); err != nil { t.Fatal("UnpackBuf failed to unpack a byte array with a single value in it") } @@ -269,12 +265,12 @@ func TestEncodingUnpack(t *testing.T) { } sp := simplePacked{137, 138} - bsp, err := Pack(sp) + bsp, err := Encoding1_2.Pack(sp) if err != nil { t.Fatal("Couldn't pack a simple struct:", err) } var sp2 simplePacked - if _, err := Unpack(bsp, &sp2); err != nil { + if _, err := Encoding1_2.Unpack(bsp, &sp2); err != nil { t.Fatal("Couldn't unpack a simple struct:", err) } @@ -283,17 +279,17 @@ func TestEncodingUnpack(t *testing.T) { } // Try unpacking a version that's missing a byte at the end. - if _, err := Unpack(bsp[:len(bsp)-1], &sp2); err == nil { + if _, err := Encoding1_2.Unpack(bsp[:len(bsp)-1], &sp2); err == nil { t.Fatal("unpack incorrectly unpacked from a byte array that didn't have enough values") } np := nestedPacked{sp, 139} - bnp, err := Pack(np) + bnp, err := Encoding1_2.Pack(np) if err != nil { t.Fatal("Couldn't pack a nested struct") } var np2 nestedPacked - if _, err := Unpack(bnp, &np2); err != nil { + if _, err := Encoding1_2.Unpack(bnp, &np2); err != nil { t.Fatal("Couldn't unpack a nested struct:", err) } if np.SP.A != np2.SP.A || np.SP.B != np2.SP.B || np.C != np2.C { @@ -301,12 +297,12 @@ func TestEncodingUnpack(t *testing.T) { } ns := nestedSlice{137, b} - bns, err := Pack(ns) + bns, err := Encoding1_2.Pack(ns) if err != nil { t.Fatal("Couldn't pack a struct with a nested byte slice:", err) } var ns2 nestedSlice - if _, err := Unpack(bns, &ns2); err != nil { + if _, err := Encoding1_2.Unpack(bns, &ns2); err != nil { t.Fatal("Couldn't unpacked a struct with a nested slice:", err) } if ns.A != ns2.A || !bytes.Equal(ns.S, ns2.S) { @@ -314,7 +310,7 @@ func TestEncodingUnpack(t *testing.T) { } var hs []Handle - if _, err := Unpack([]byte{0, 3, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, &hs); err != nil { + if _, err := Encoding1_2.Unpack([]byte{0, 3, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, &hs); err != nil { t.Fatal("Couldn't unpack a list of Handles:", err) } if want := []Handle{0x01020304, 0x05060708, 0x090a0b0c}; !reflect.DeepEqual(want, hs) { @@ -324,20 +320,20 @@ func TestEncodingUnpack(t *testing.T) { func TestPartialUnpack(t *testing.T) { u1, u2 := uint32(1), uint32(2) - buf, err := Pack(u1, u2) + buf, err := Encoding1_2.Pack(u1, u2) if err != nil { t.Fatalf("packing uint32 value: %v", err) } var gu1, gu2 uint32 - read1, err := Unpack(buf, &gu1) + read1, err := Encoding1_2.Unpack(buf, &gu1) if err != nil { t.Fatalf("unpacking first uint32 value: %v", err) } if gu1 != u1 { t.Errorf("first unpacked value: got %d, want %d", gu1, u1) } - read2, err := Unpack(buf[read1:], &gu2) + read2, err := Encoding1_2.Unpack(buf[read1:], &gu2) if err != nil { t.Fatalf("unpacking second uint32 value: %v", err) } diff --git a/tpmutil/run.go b/tpmutil/run.go index 641febef..4eeb8c31 100644 --- a/tpmutil/run.go +++ b/tpmutil/run.go @@ -13,10 +13,6 @@ // limitations under the License. // Package tpmutil provides common utility functions for both TPM 1.2 and TPM 2.0 devices. -// -// Users should call either UseTPM12LengthPrefixSize or -// UseTPM20LengthPrefixSize before using this package, depending on their type -// of TPM device. package tpmutil import ( @@ -31,16 +27,27 @@ import ( // returning a header and a body in separate responses. const maxTPMResponse = 4096 +// RunCommand executes cmd with the package's default encoding. +// +// Callers must call UseTPM12Encoding() or UseTPM20Encoding() before calling +// this method. +func RunCommand(rw io.ReadWriter, tag Tag, cmd Command, in ...interface{}) ([]byte, ResponseCode, error) { + if defaultEncoding == nil { + return nil, 0, errors.New("default encoding not initialized") + } + return defaultEncoding.RunCommand(rw, tag, cmd, in...) +} + // RunCommand executes cmd with given tag and arguments. Returns TPM response // body (without response header) and response code from the header. Returned // error may be nil if response code is not RCSuccess; caller should check // both. -func RunCommand(rw io.ReadWriter, tag Tag, cmd Command, in ...interface{}) ([]byte, ResponseCode, error) { +func (enc *Encoding) RunCommand(rw io.ReadWriter, tag Tag, cmd Command, in ...interface{}) ([]byte, ResponseCode, error) { if rw == nil { return nil, 0, errors.New("nil TPM handle") } ch := commandHeader{tag, 0, cmd} - inb, err := packWithHeader(ch, in...) + inb, err := enc.packWithHeader(ch, in...) if err != nil { return nil, 0, err }