diff --git a/frame-reader.go b/frame-reader.go index 4ccfc23..22eac2b 100644 --- a/frame-reader.go +++ b/frame-reader.go @@ -3,96 +3,62 @@ package mint type framing interface { - headerLen() int - defaultReadLen() int - frameLen(hdr []byte) (int, error) + parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) } -const ( - kFrameReaderHdr = 0 - kFrameReaderBody = 1 -) +type lastNBytesFraming struct { + headerSize int + lengthSize int +} -type frameNextAction func(f *frameReader) error +func (lnb lastNBytesFraming) parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) { + headerReady = len(buffer) >= lnb.headerSize + if !headerReady { + return + } + + headerLen = lnb.headerSize + val, _ := decodeUint(buffer[lnb.headerSize-lnb.lengthSize:], lnb.lengthSize) + bodyLen = int(val) + return +} type frameReader struct { - details framing - state uint8 - header []byte - body []byte - working []byte - writeOffset int - remainder []byte + details framing + remainder []byte } func newFrameReader(d framing) *frameReader { - hdr := make([]byte, d.headerLen()) return &frameReader{ - d, - kFrameReaderHdr, - hdr, - nil, - hdr, - 0, - nil, + details: d, + remainder: make([]byte, 0), } } -func dup(a []byte) []byte { - r := make([]byte, len(a)) - copy(r, a) - return r -} - -func (f *frameReader) needed() int { - tmp := (len(f.working) - f.writeOffset) - len(f.remainder) - if tmp < 0 { - return 0 - } - return tmp +func (f *frameReader) ready() bool { + headerReady, headerLen, bodyLen := f.details.parse(f.remainder) + //logf(logTypeFrameReader, "header=%v body=(%v > %v)", headerReady, len(f.remainder), headerLen+bodyLen) + return headerReady && len(f.remainder) >= headerLen+bodyLen } func (f *frameReader) addChunk(in []byte) { - // Append to the buffer. + // Append to the buffer logf(logTypeFrameReader, "Appending %v", len(in)) f.remainder = append(f.remainder, in...) } -func (f *frameReader) process() (hdr []byte, body []byte, err error) { - for f.needed() == 0 { - logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset) - // Fill out our working block - copied := copy(f.working[f.writeOffset:], f.remainder) - f.remainder = f.remainder[copied:] - f.writeOffset += copied - if f.writeOffset < len(f.working) { - logf(logTypeVerbose, "Read would have blocked 1") - return nil, nil, AlertWouldBlock - } - // Reset the write offset, because we are now full. - f.writeOffset = 0 - - // We have read a full frame - if f.state == kFrameReaderBody { - logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder)) - f.state = kFrameReaderHdr - f.working = f.header - return dup(f.header), dup(f.body), nil - } - - // We have read the header - bodyLen, err := f.details.frameLen(f.header) - if err != nil { - return nil, nil, err - } - logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen) - - f.body = make([]byte, bodyLen) - f.working = f.body - f.writeOffset = 0 - f.state = kFrameReaderBody +func (f *frameReader) next() ([]byte, []byte, error) { + // Check to see if we have enough data + headerReady, headerLen, bodyLen := f.details.parse(f.remainder) + if !headerReady || len(f.remainder) < headerLen+bodyLen { + logf(logTypeVerbose, "Read would have blocked") + return nil, nil, AlertWouldBlock } - logf(logTypeVerbose, "Read would have blocked 2") - return nil, nil, AlertWouldBlock + // Read a record off the front of the buffer + header, body := make([]byte, headerLen), make([]byte, bodyLen) + copy(header, f.remainder[:headerLen]) + copy(body, f.remainder[headerLen:headerLen+bodyLen]) + f.remainder = f.remainder[headerLen+bodyLen:] + return header, body, nil } diff --git a/frame-reader_test.go b/frame-reader_test.go index 4ea5efd..3d29508 100644 --- a/frame-reader_test.go +++ b/frame-reader_test.go @@ -1,75 +1,126 @@ package mint import ( + "strings" "testing" + + "github.com/bifurcation/mint/syntax" ) -var kTestFrame = []byte{0x00, 0x05, 'a', 'b', 'c', 'd', 'e'} -var kTestEmptyFrame = []byte{0x00, 0x00} +var ( + fixedFullFrame = unhex("ff00056162636465") + fixedEmptyFrame = unhex("ff0000") + variableFullFrame = unhex("40ff" + strings.Repeat("A0", 255)) + variableEmptyFrame = unhex("00") +) -type simpleHeader struct{} +type variableHeader struct{} -func (h simpleHeader) headerLen() int { - return 2 -} +func (h variableHeader) parse(buffer []byte) (headerReady bool, headerLen, bodyLen int) { + if len(buffer) == 0 { + headerReady = false + return + } -func (h simpleHeader) defaultReadLen() int { - return 1024 -} + // XXX: Need a way to return parse errors other than "insufficient data" + length := struct { + Value uint64 `tls:"varint"` + }{} + read, err := syntax.Unmarshal(buffer, &length) -func (h simpleHeader) frameLen(hdr []byte) (int, error) { - if len(hdr) != 2 { - panic("Assert!") + headerReady = (err == nil) + if !headerReady { + return } - return (int(hdr[0]) << 8) | int(hdr[1]), nil + headerLen = read + bodyLen = int(length.Value) + return +} + +type frameReaderTester struct { + details framing + headerLenFull int + fullFrame []byte + headerLenEmpty int + emptyFrame []byte +} + +func (frt frameReaderTester) checkFrameFull(t *testing.T, hdr, body []byte) { + assertByteEquals(t, hdr, frt.fullFrame[:frt.headerLenFull]) + assertByteEquals(t, body, frt.fullFrame[frt.headerLenFull:]) } -func checkFrame(t *testing.T, hdr []byte, body []byte) { - assertByteEquals(t, hdr, kTestFrame[:2]) - assertByteEquals(t, body, kTestFrame[2:]) +func (frt frameReaderTester) checkFrameEmpty(t *testing.T, hdr, body []byte) { + assertByteEquals(t, hdr, frt.emptyFrame[:frt.headerLenEmpty]) + assertByteEquals(t, body, frt.emptyFrame[frt.headerLenEmpty:]) } -func TestFrameReaderFullFrame(t *testing.T) { - r := newFrameReader(simpleHeader{}) - r.addChunk(kTestFrame) - hdr, body, err := r.process() +func (frt frameReaderTester) TestFrames(t *testing.T) { + r := newFrameReader(frt.details) + r.addChunk(frt.fullFrame) + hdr, body, err := r.next() assertNotError(t, err, "Couldn't read frame 1") - checkFrame(t, hdr, body) + frt.checkFrameFull(t, hdr, body) - r.addChunk(kTestFrame) - hdr, body, err = r.process() + r.addChunk(frt.emptyFrame) + hdr, body, err = r.next() assertNotError(t, err, "Couldn't read frame 2") - checkFrame(t, hdr, body) + frt.checkFrameEmpty(t, hdr, body) } -func TestFrameReaderTwoFrames(t *testing.T) { - r := newFrameReader(simpleHeader{}) - r.addChunk(kTestFrame) - r.addChunk(kTestFrame) - hdr, body, err := r.process() +func (frt frameReaderTester) TestTwoFrames(t *testing.T) { + r := newFrameReader(frt.details) + r.addChunk(frt.fullFrame) + r.addChunk(frt.fullFrame) + hdr, body, err := r.next() assertNotError(t, err, "Couldn't read frame 1") - checkFrame(t, hdr, body) + frt.checkFrameFull(t, hdr, body) - hdr, body, err = r.process() + hdr, body, err = r.next() assertNotError(t, err, "Couldn't read frame 2") - checkFrame(t, hdr, body) + frt.checkFrameFull(t, hdr, body) } -func TestFrameReaderTrickle(t *testing.T) { - r := newFrameReader(simpleHeader{}) +func (frt frameReaderTester) TestTrickle(t *testing.T) { + r := newFrameReader(frt.details) var hdr, body []byte var err error - for i := 0; i <= len(kTestFrame); i += 1 { - hdr, body, err = r.process() - if i < len(kTestFrame) { + for i := 0; i <= len(frt.fullFrame); i += 1 { + hdr, body, err = r.next() + if i < len(frt.fullFrame) { assertEquals(t, err, AlertWouldBlock) assertEquals(t, 0, len(hdr)) assertEquals(t, 0, len(body)) - r.addChunk(kTestFrame[i : i+1]) + r.addChunk(frt.fullFrame[i : i+1]) } } assertNil(t, err, "Error reading") - checkFrame(t, hdr, body) + frt.checkFrameFull(t, hdr, body) +} + +func (frt frameReaderTester) Run(t *testing.T) { + t.Run("frames", frt.TestFrames) + t.Run("two-frames", frt.TestTwoFrames) + t.Run("trickle", frt.TestTrickle) +} + +func TestFrameReader(t *testing.T) { + cases := map[string]frameReaderTester{ + "fixed": frameReaderTester{ + lastNBytesFraming{3, 2}, + 3, fixedFullFrame, + 3, fixedEmptyFrame, + }, + "variable": frameReaderTester{ + variableHeader{}, + 2, variableFullFrame, + 1, variableEmptyFrame, + }, + } + + for label, c := range cases { + t.Run(label, c.Run) + } } diff --git a/handshake-layer.go b/handshake-layer.go index ae11cb8..ae7f506 100644 --- a/handshake-layer.go +++ b/handshake-layer.go @@ -130,6 +130,7 @@ type HandshakeLayer struct { maxFragmentLen int } +/* type handshakeLayerFrameDetails struct { datagram bool } @@ -152,13 +153,14 @@ func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { val, _ := decodeUint(hdr[len(hdr)-3:], 3) return int(val), nil } +*/ func NewHandshakeLayerTLS(c *HandshakeContext, r RecordLayer) *HandshakeLayer { h := HandshakeLayer{} h.ctx = c h.conn = r h.datagram = false - h.frame = newFrameReader(&handshakeLayerFrameDetails{false}) + h.frame = newFrameReader(lastNBytesFraming{handshakeHeaderLenTLS, 3}) h.maxFragmentLen = maxFragmentLen return &h } @@ -168,7 +170,7 @@ func NewHandshakeLayerDTLS(c *HandshakeContext, r RecordLayer) *HandshakeLayer { h.ctx = c h.conn = r h.datagram = true - h.frame = newFrameReader(&handshakeLayerFrameDetails{true}) + h.frame = newFrameReader(lastNBytesFraming{handshakeHeaderLenDTLS, 3}) h.maxFragmentLen = initialMtu // Not quite right return &h } @@ -359,7 +361,7 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { } for { logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder)) - if h.frame.needed() > 0 { + if !h.frame.ready() { logf(logTypeVerbose, "Trying to read a new record") err = h.readRecord() @@ -368,7 +370,7 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { } } - hdr, body, err = h.frame.process() + hdr, body, err = h.frame.next() if err == nil { break } diff --git a/record-layer.go b/record-layer.go index ccb90bd..c8591be 100644 --- a/record-layer.go +++ b/record-layer.go @@ -11,6 +11,7 @@ const ( sequenceNumberLen = 8 // sequence number length recordHeaderLenTLS = 5 // record header length (TLS) recordHeaderLenDTLS = 13 // record header length (DTLS) + maxHeaderLen = 256 // invented upper bound for header size maxFragmentLen = 1 << 14 // max number of bytes in a record labelForKey = "key" labelForIV = "iv" @@ -103,25 +104,6 @@ func (r *DefaultRecordLayer) Impl() *DefaultRecordLayer { return r } -type recordLayerFrameDetails struct { - datagram bool -} - -func (d recordLayerFrameDetails) headerLen() int { - if d.datagram { - return recordHeaderLenDTLS - } - return recordHeaderLenTLS -} - -func (d recordLayerFrameDetails) defaultReadLen() int { - return d.headerLen() + maxFragmentLen -} - -func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { - return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil -} - func newCipherStateNull() *cipherState { return &cipherState{EpochClear, 0, 0, nil, nil} } @@ -140,7 +122,7 @@ func NewRecordLayerTLS(conn io.ReadWriter, dir Direction) *DefaultRecordLayer { r.label = "" r.direction = dir r.conn = conn - r.frame = newFrameReader(recordLayerFrameDetails{false}) + r.frame = newFrameReader(lastNBytesFraming{recordHeaderLenTLS, 2}) r.cipher = newCipherStateNull() r.version = tls10Version return &r @@ -151,7 +133,7 @@ func NewRecordLayerDTLS(conn io.ReadWriter, dir Direction) *DefaultRecordLayer { r.label = "" r.direction = dir r.conn = conn - r.frame = newFrameReader(recordLayerFrameDetails{true}) + r.frame = newFrameReader(lastNBytesFraming{recordHeaderLenDTLS, 2}) r.cipher = newCipherStateNull() r.readCiphers = make(map[Epoch]*cipherState, 0) r.readCiphers[0] = r.cipher @@ -352,8 +334,8 @@ func (r *DefaultRecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, erro var header, body []byte for err != nil { - if r.frame.needed() > 0 { - buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen) + if !r.frame.ready() { + buf := make([]byte, maxHeaderLen+maxFragmentLen) n, err := r.conn.Read(buf) if err != nil { logf(logTypeIO, "%s Error reading, %v", r.label, err) @@ -370,7 +352,7 @@ func (r *DefaultRecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, erro r.frame.addChunk(buf) } - header, body, err = r.frame.process() + header, body, err = r.frame.next() // Loop around onAlertWouldBlock to see if some // data is now available. if err != nil && err != AlertWouldBlock {