From 02c18cfda774bb5c6407243981ff308502c45f71 Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Wed, 15 Jan 2025 01:03:46 +0000 Subject: [PATCH] Change Record.IndexType to use MRType MRType provides stronger typing than a uint8. --- cel/canonical_eventlog.go | 17 +++++++++++------ cel/canonical_eventlog_test.go | 6 +++--- extract/extract_test.go | 5 ++++- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/cel/canonical_eventlog.go b/cel/canonical_eventlog.go index bd69e38..1a47bf1 100644 --- a/cel/canonical_eventlog.go +++ b/cel/canonical_eventlog.go @@ -120,7 +120,7 @@ type Record struct { // Generic Measurement Register index number, register type // is determined by IndexType Index uint8 - IndexType uint8 + IndexType MRType Digests map[crypto.Hash][]byte Content TLV } @@ -205,7 +205,7 @@ func (c *eventLog) AppendEvent(event Content, bankAlgos []crypto.Hash, mrIndex i Index: uint8(mrIndex), Digests: digestMap, Content: eventTlv, - IndexType: uint8(c.Type), + IndexType: c.Type, } c.Recs = append(c.Recs, celrPCR) @@ -246,8 +246,13 @@ func createIndexField(indexType uint8, indexNum uint8) TLV { // unmarshalIndex takes in a TLV with its type equals to the PCR or CCMR type value, and // return its index number. -func unmarshalIndex(tlv TLV) (indexType uint8, pcrNum uint8, err error) { - if tlv.Type != uint8(PCRType) && tlv.Type != uint8(CCMRType) { +func unmarshalIndex(tlv TLV) (indexType MRType, pcrNum uint8, err error) { + switch tlv.Type { + case uint8(PCRType): + indexType = PCRType + case uint8(CCMRType): + indexType = CCMRType + default: return 0, 0, fmt.Errorf("type of the TLV [%d] indicates it is not a PCR [%d] or a CCMR [%d] field ", tlv.Type, uint8(PCRType), uint8(CCMRType)) } @@ -257,7 +262,7 @@ func unmarshalIndex(tlv TLV) (indexType uint8, pcrNum uint8, err error) { len(tlv.Value), regIndexValueLength) } - return tlv.Type, tlv.Value[0], nil + return indexType, tlv.Value[0], nil } func createDigestField(digestMap map[crypto.Hash][]byte) (TLV, error) { @@ -318,7 +323,7 @@ func (r *Record) EncodeCELR(buf *bytes.Buffer) error { return err } - indexField, err := createIndexField(r.IndexType, r.Index).MarshalBinary() + indexField, err := createIndexField(uint8(r.IndexType), r.Index).MarshalBinary() if err != nil { return err } diff --git a/cel/canonical_eventlog_test.go b/cel/canonical_eventlog_test.go index 049c83b..ce1bf47 100644 --- a/cel/canonical_eventlog_test.go +++ b/cel/canonical_eventlog_test.go @@ -50,13 +50,13 @@ func TestCELEncodingDecoding(t *testing.T) { if decodedcel.Records()[1].RecNum != 1 { t.Errorf("recnum mismatch") } - if decodedcel.Records()[0].IndexType != uint8(tc) { + if decodedcel.Records()[0].IndexType != tc { t.Errorf("index type mismatch") } if decodedcel.Records()[0].Index != uint8(16) { t.Errorf("pcr value mismatch") } - if decodedcel.Records()[1].IndexType != uint8(tc) { + if decodedcel.Records()[1].IndexType != tc { t.Errorf("index type mismatch") } if decodedcel.Records()[1].Index != uint8(23) { @@ -90,7 +90,7 @@ func TestCELAppendDifferentMRTypes(t *testing.T) { appendFakeMREventOrFatal(t, &el, rot, 8, measuredHashes, event) for _, rec := range el.Records() { - if rec.IndexType != uint8(tc) { + if rec.IndexType != tc { t.Errorf("AppendEvent(): got Index Type %v, want type %v", rec.IndexType, tc) } } diff --git a/extract/extract_test.go b/extract/extract_test.go index f03a434..5e1d31a 100644 --- a/extract/extract_test.go +++ b/extract/extract_test.go @@ -28,6 +28,7 @@ import ( "github.com/google/go-eventlog/register" "github.com/google/go-eventlog/tcg" "github.com/google/go-eventlog/testdata" + "google.golang.org/protobuf/encoding/protojson" ) func TestExtractFirmwareLogStateRTMR(t *testing.T) { @@ -147,10 +148,12 @@ func TestExtractFirmwareLogStateRTMR(t *testing.T) { t.Run(tc.name, func(t *testing.T) { evts := getCCELEvents(t) tc.mutate(evts) - _, err := FirmwareLogState(evts, crypto.SHA384, RTMRRegisterConfig, Opts{Loader: GRUB}) + fs, err := FirmwareLogState(evts, crypto.SHA384, RTMRRegisterConfig, Opts{Loader: GRUB}) if (err != nil) != tc.expectErr { t.Errorf("ExtractFirmwareLogState(%v) = got %v, wantErr: %v", tc.name, err, tc.expectErr) } + bts, _ := protojson.MarshalOptions{UseProtoNames: true}.Marshal(fs) + t.Log(string(bts)) }) } }