From 489db950c3053c5a1e5af195071372326a63c2e5 Mon Sep 17 00:00:00 2001 From: kungf Date: Mon, 2 Aug 2021 15:38:57 +0800 Subject: [PATCH] check msg magic to avoid invalid memory allocate Now raft uses data[0:4] as the message size, but data[0:4] can be any value, if we not check the messages whether they are sended by raft itself, the memory may be very big. Signed-off-by: kungf --- proto/codec.go | 101 +++++++++++++++++++++++++++++-------------------- proto/proto.go | 6 +++ 2 files changed, 66 insertions(+), 41 deletions(-) diff --git a/proto/codec.go b/proto/codec.go index 07e8a0a..667009d 100644 --- a/proto/codec.go +++ b/proto/codec.go @@ -15,6 +15,7 @@ package proto import ( "encoding/binary" + "fmt" "io" "sort" @@ -26,7 +27,7 @@ const ( peer_size uint64 = 11 entry_header uint64 = 17 snapmeta_header uint64 = 20 - message_header uint64 = 68 + message_header uint64 = 68 + 1 + 4 // 1 is the len(MsgMagic), 4 is the len(entries + context) ) // Peer codec @@ -149,12 +150,8 @@ func (e *Entry) Decode(datas []byte) { } // Message codec -func (m *Message) Size() uint64 { - if m.Type == ReqMsgSnapShot { - return message_header + m.SnapshotMeta.Size() - } - - size := message_header + 4 +func (m *Message) entryAndContextSize() uint64 { + size := uint64(4) if len(m.Entries) > 0 { for _, e := range m.Entries { size = size + e.Size() + 4 @@ -170,28 +167,35 @@ func (m *Message) Encode(w io.Writer) error { buf := getByteSlice() defer returnByteSlice(buf) - binary.BigEndian.PutUint32(buf, uint32(m.Size())) - buf[4] = version1 - buf[5] = byte(m.Type) + buf[0] = MsgMagic + buf[1] = version1 + buf[2] = byte(m.Type) if m.ForceVote { - buf[6] = 1 + buf[3] = 1 } else { - buf[6] = 0 + buf[3] = 0 } if m.Reject { - buf[7] = 1 + buf[4] = 1 } else { - buf[7] = 0 + buf[4] = 0 } - binary.BigEndian.PutUint64(buf[8:], m.RejectIndex) - binary.BigEndian.PutUint64(buf[16:], m.ID) - binary.BigEndian.PutUint64(buf[24:], m.From) - binary.BigEndian.PutUint64(buf[32:], m.To) - binary.BigEndian.PutUint64(buf[40:], m.Term) - binary.BigEndian.PutUint64(buf[48:], m.LogTerm) - binary.BigEndian.PutUint64(buf[56:], m.Index) - binary.BigEndian.PutUint64(buf[64:], m.Commit) - if _, err := w.Write(buf[0 : message_header+4]); err != nil { + binary.BigEndian.PutUint64(buf[5:], m.RejectIndex) + binary.BigEndian.PutUint64(buf[13:], m.ID) + binary.BigEndian.PutUint64(buf[21:], m.From) + binary.BigEndian.PutUint64(buf[29:], m.To) + binary.BigEndian.PutUint64(buf[37:], m.Term) + binary.BigEndian.PutUint64(buf[45:], m.LogTerm) + binary.BigEndian.PutUint64(buf[53:], m.Index) + binary.BigEndian.PutUint64(buf[61:], m.Commit) + + if m.Type == ReqMsgSnapShot { + binary.BigEndian.PutUint32(buf[69:], uint32(m.SnapshotMeta.Size())) + } else { + binary.BigEndian.PutUint32(buf[69:], uint32(m.entryAndContextSize())) + } + + if _, err := w.Write(buf[0:message_header]); err != nil { return err } @@ -227,31 +231,46 @@ func (m *Message) Decode(r *util.BufferReader) error { datas []byte err error ) - if datas, err = r.ReadFull(4); err != nil { + if datas, err = r.ReadFull(int(message_header)); err != nil { return err } - if datas, err = r.ReadFull(int(binary.BigEndian.Uint32(datas))); err != nil { - return err + + if MsgType(datas[0]) != MsgMagic { + return fmt.Errorf("Invalid message magic") } - ver := datas[0] + ver := datas[1] if ver == version1 { - m.Type = MsgType(datas[1]) - m.ForceVote = (datas[2] == 1) - m.Reject = (datas[3] == 1) - m.RejectIndex = binary.BigEndian.Uint64(datas[4:]) - m.ID = binary.BigEndian.Uint64(datas[12:]) - m.From = binary.BigEndian.Uint64(datas[20:]) - m.To = binary.BigEndian.Uint64(datas[28:]) - m.Term = binary.BigEndian.Uint64(datas[36:]) - m.LogTerm = binary.BigEndian.Uint64(datas[44:]) - m.Index = binary.BigEndian.Uint64(datas[52:]) - m.Commit = binary.BigEndian.Uint64(datas[60:]) + m.Type = MsgType(datas[2]) + if m.Type >= MsgTypeEnd { + return fmt.Errorf("Unknow message type %v", m.Type) + + } + m.ForceVote = (datas[3] == 1) + m.Reject = (datas[4] == 1) + m.RejectIndex = binary.BigEndian.Uint64(datas[5:]) + m.ID = binary.BigEndian.Uint64(datas[13:]) + m.From = binary.BigEndian.Uint64(datas[21:]) + m.To = binary.BigEndian.Uint64(datas[29:]) + m.Term = binary.BigEndian.Uint64(datas[37:]) + m.LogTerm = binary.BigEndian.Uint64(datas[45:]) + m.Index = binary.BigEndian.Uint64(datas[53:]) + m.Commit = binary.BigEndian.Uint64(datas[61:]) + + dataSize := binary.BigEndian.Uint32(datas[69:]) + if dataSize <= 0 { + return nil + } + + if datas, err = r.ReadFull(int(dataSize)); err != nil { + return err + } + if m.Type == ReqMsgSnapShot { - m.SnapshotMeta.Decode(datas[message_header:]) + m.SnapshotMeta.Decode(datas[0:]) } else { - size := binary.BigEndian.Uint32(datas[message_header:]) - start := message_header + 4 + size := binary.BigEndian.Uint32(datas[0:]) + start := uint64(4) if size > 0 { for i := uint32(0); i < size; i++ { esize := binary.BigEndian.Uint32(datas[start:]) diff --git a/proto/proto.go b/proto/proto.go index 27a4bc6..47fa135 100644 --- a/proto/proto.go +++ b/proto/proto.go @@ -42,6 +42,10 @@ const ( LeaseMsgTimeout ReqCheckQuorum RespCheckQuorum + + MsgTypeEnd + + MsgMagic = 0x5A ) const ( @@ -163,6 +167,8 @@ func (t MsgType) String() string { return "ReqCheckQuorum" case 15: return "RespCheckQuorum" + case 90: + return "MsgMagic" } return "unkown" }