diff --git a/client_test.go b/client_test.go index 952328f..af31ded 100644 --- a/client_test.go +++ b/client_test.go @@ -180,7 +180,7 @@ func TestClient(t *testing.T) { conn1.mx.RUnlock() assert.Equal(t, initID+2, newInitID) - assert.NoError(t, closeClosers(conn1, conn2)) + assert.NoError(t, closeClosers(conn1, conn2, tr1)) checkClientConnsClosed(t, conn1, conn2) assert.Error(t, errWithTimeout(serveErrCh1)) diff --git a/deadline.go b/deadline.go new file mode 100644 index 0000000..5274496 --- /dev/null +++ b/deadline.go @@ -0,0 +1,78 @@ +package dmsg + +import ( + "io" + "math" + "sync" + "sync/atomic" + "time" +) + +type deadline struct { + mu sync.Mutex + timestamp int64 + ch chan struct{} + modCh chan struct{} + modChClosed uint32 +} + +func newDeadline() *deadline { + return &deadline{ + timestamp: math.MaxInt64, + ch: make(chan struct{}), + modCh: make(chan struct{}, 1), + } +} + +func (d *deadline) Close() error { + d.mu.Lock() + defer d.mu.Unlock() + + atomic.StoreUint32(&d.modChClosed, 1) + close(d.modCh) + + return nil +} + +func (d *deadline) Set(t time.Time) error { + d.mu.Lock() + defer d.mu.Unlock() + + if atomic.LoadUint32(&d.modChClosed) != 0 { + return io.ErrClosedPipe + } + + atomic.StoreInt64(&d.timestamp, t.UnixNano()) + d.modCh <- struct{}{} + + return nil +} + +func (d *deadline) Serve() error { + defer close(d.ch) + for { + if d.Exceeded() { + select { + case d.ch <- struct{}{}: + case _, ok := <-d.modCh: + if !ok { + return nil + } + } + continue + } + if _, ok := <-d.modCh; !ok { + return nil + } + } +} + +func (d *deadline) Exceeded() bool { + now := time.Now().UnixNano() + deadline := atomic.LoadInt64(&d.timestamp) + return now >= deadline +} + +func (d *deadline) Chan() <-chan struct{} { + return d.ch +} diff --git a/deadline_test.go b/deadline_test.go new file mode 100644 index 0000000..8790ad5 --- /dev/null +++ b/deadline_test.go @@ -0,0 +1,133 @@ +package dmsg + +import ( + "math" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func Test_newDeadline(t *testing.T) { + deadline := newDeadline() + require.NotNil(t, deadline) + + require.EqualValues(t, math.MaxInt64, deadline.timestamp) + require.EqualValues(t, 0, deadline.modChClosed) + + ch := deadline.ch + require.NotNil(t, ch) + + modCh := deadline.modCh + require.NotNil(t, modCh) +} + +func Test_deadline_SetDeadline(t *testing.T) { + deadline := newDeadline() + + var wg sync.WaitGroup + var serveErr error + wg.Add(1) + go func() { + defer wg.Done() + serveErr = deadline.Serve() + }() + + now := time.Now() + cases := []struct { + deadline time.Time + unixNano int64 + }{ + { + deadline: now.Add(1 * time.Second), + unixNano: now.Add(1 * time.Second).UnixNano(), + }, + { + deadline: now.Add(-1 * time.Second), + unixNano: now.Add(-1 * time.Second).UnixNano(), + }, + } + + for _, tc := range cases { + require.NoError(t, deadline.Set(tc.deadline)) + require.Equal(t, tc.unixNano, deadline.timestamp) + } + + require.NoError(t, deadline.Close()) + require.NotEqual(t, 0, deadline.modChClosed) + + wg.Wait() + require.NoError(t, serveErr) +} + +func Test_deadline_Chan(t *testing.T) { + deadline := newDeadline() + ch := deadline.Chan() + require.NotNil(t, ch) +} + +func Test_deadline_Exceeded(t *testing.T) { + deadline := newDeadline() + + var wg sync.WaitGroup + var serveErr error + wg.Add(1) + go func() { + defer wg.Done() + serveErr = deadline.Serve() + }() + + now := time.Now() + cases := []struct { + deadline time.Time + exceeded bool + }{ + { + deadline: now.Add(1 * time.Second), + exceeded: false, + }, + { + deadline: now.Add(-1 * time.Second), + exceeded: true, + }, + } + + for _, tc := range cases { + require.NoError(t, deadline.Set(tc.deadline)) + require.Equal(t, tc.exceeded, deadline.Exceeded()) + } + + require.NoError(t, deadline.Close()) + require.NotEqual(t, 0, deadline.modChClosed) + + wg.Wait() + require.NoError(t, serveErr) +} + +func Test_deadline_concurrency(t *testing.T) { + t.Run("Set and close", func(t *testing.T) { + deadline := newDeadline() + + var wg sync.WaitGroup + var serveErr error + wg.Add(1) + go func() { + defer wg.Done() + serveErr = deadline.Serve() + }() + + var setErr error + wg.Add(1) + go func() { + defer wg.Done() + setErr = deadline.Set(time.Now().Add(1 * time.Second)) + }() + + require.NoError(t, deadline.Close()) + + wg.Wait() + require.NoError(t, serveErr) + require.Error(t, setErr) + }) +} diff --git a/server_test.go b/server_test.go index 049ab63..8b8641d 100644 --- a/server_test.go +++ b/server_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "math" "math/rand" "net" "os" @@ -228,6 +229,10 @@ func TestServer_Serve(t *testing.T) { testServerSelfDialing(t) }) + t.Run("Deadlines", func(t *testing.T) { + testServerDeadlines(t) + }) + t.Run("Server disconnection closes transports", func(t *testing.T) { testServerDisconnection(t) }) @@ -255,7 +260,7 @@ func testServerDisconnection(t *testing.T) { responder := createClient(t, dc, responderName) initiator := createClient(t, dc, initiatorName) initConn, respConns := dial(t, initiator, responder, port, noDelay) - testTransportMessaging(t, initConn, respConns) + require.NoError(t, testTransportMessaging(initConn, respConns, time.Duration(0))) require.NoError(t, srv.Close()) require.NoError(t, errWithTimeout(srvErrCh)) @@ -276,24 +281,92 @@ func testServerSelfDialing(t *testing.T) { client := createClient(t, dc, "client") selfWrTp, selfRdTp := dial(t, client, client, port, noDelay) // try to write/read message to/from self - testTransportMessaging(t, selfWrTp, selfRdTp) + require.NoError(t, testTransportMessaging(selfWrTp, selfRdTp, time.Duration(0))) require.NoError(t, closeClosers(selfRdTp, selfWrTp, client)) assert.NoError(t, srv.Close()) assert.NoError(t, errWithTimeout(srvErrCh)) } -func testTransportMessaging(t *testing.T, init, resp io.ReadWriter) { +func testServerDeadlines(t *testing.T) { + t.Parallel() + + now := time.Now() + cases := []struct { + deadline time.Time + error error + }{ + {now.Add(10 * time.Second), nil}, + {now.Add(-1 * time.Second), ErrDeadlineExceeded}, + {now.Add(1 * time.Millisecond), ErrDeadlineExceeded}, + {time.Unix(0, math.MaxInt64), nil}, + } + + dc := disc.NewMock() + srv, srvErrCh, err := createServer(dc) + require.NoError(t, err) + + responder := createClient(t, dc, responderName) + initiator := createClient(t, dc, initiatorName) + initiatorTransport, responderTransport := dial(t, initiator, responder, port, noDelay) + + t.Run("Read", func(t *testing.T) { + for _, tc := range cases { + require.NoError(t, responderTransport.SetReadDeadline(tc.deadline)) + require.Equal(t, tc.error, testTransportMessaging(initiatorTransport, responderTransport, time.Duration(0))) + } + }) + + t.Run("Write", func(t *testing.T) { + for _, tc := range cases { + require.NoError(t, initiatorTransport.SetWriteDeadline(tc.deadline)) + require.Equal(t, tc.error, testTransportMessaging(initiatorTransport, responderTransport, time.Duration(0))) + } + }) + + t.Run("Read and write", func(t *testing.T) { + for _, tc := range cases { + require.NoError(t, initiatorTransport.SetDeadline(tc.deadline)) + require.NoError(t, responderTransport.SetDeadline(tc.deadline)) + require.Equal(t, tc.error, testTransportMessaging(initiatorTransport, responderTransport, time.Duration(0))) + } + }) + + t.Run("Multiple read and write", func(t *testing.T) { + duration := 10 * time.Second + + deadline := time.Now().Add(duration) + require.NoError(t, initiatorTransport.SetDeadline(deadline)) + require.NoError(t, responderTransport.SetDeadline(deadline)) + time.Sleep(duration) + require.Equal(t, ErrDeadlineExceeded, testTransportMessaging(initiatorTransport, responderTransport, time.Duration(0))) + + deadline = time.Now().Add(duration) + require.NoError(t, initiatorTransport.SetDeadline(deadline)) + require.NoError(t, responderTransport.SetDeadline(deadline)) + require.Equal(t, ErrDeadlineExceeded, testTransportMessaging(initiatorTransport, responderTransport, duration)) + }) + + assert.NoError(t, srv.Close()) + assert.NoError(t, errWithTimeout(srvErrCh)) +} + +func testTransportMessaging(init, resp io.ReadWriter, delay time.Duration) error { for i := 0; i < msgCount; i++ { - _, err := init.Write([]byte(message)) - require.NoError(t, err) // TODO: Sometimes this returns error: "io: read/write on closed pipe" + if _, err := init.Write([]byte(message)); err != nil { + return err + } + + time.Sleep(delay) recvBuf := make([]byte, bufSize) for i := 0; i < len(message); i += bufSize { - _, err := resp.Read(recvBuf) - require.NoError(t, err) + if _, err := resp.Read(recvBuf); err != nil { + return err + } } } + return nil } func testServerCappedTransport(t *testing.T) { diff --git a/transport.go b/transport.go index 2b1da95..b18bc07 100644 --- a/transport.go +++ b/transport.go @@ -7,6 +7,7 @@ import ( "io" "net" "sync" + "time" "github.com/skycoin/skycoin/src/util/logging" @@ -20,6 +21,8 @@ var ( ErrRequestCheckFailed = errors.New("failed to create transport: request check failed") ErrAcceptCheckFailed = errors.New("failed to create transport: accept check failed") ErrPortNotListening = errors.New("failed to create transport: port not listening") + ErrDeadlineExceeded = errors.New("deadline exceeded") + ErrDeadlineClosed = errors.New("failed to set deadline: deadline manager is closed") ) // Transport represents communication between two nodes via a single hop: @@ -48,28 +51,46 @@ type Transport struct { done chan struct{} // chan which closes when transport stops serving doneOnce sync.Once // ensures 'done' only closes once doneFunc func(id uint16) // contains a method to remove the transport from dmsg.Client + + readDeadline *deadline + writeDeadline *deadline } // NewTransport creates a new dms_tp. func NewTransport(conn net.Conn, log *logging.Logger, local, remote Addr, id uint16, doneFunc func(id uint16)) *Transport { tp := &Transport{ - Conn: conn, - log: log, - id: id, - local: local, - remote: remote, - inCh: make(chan Frame), - ackWaiter: ioutil.NewUint16AckWaiter(), - ackBuf: make([]byte, 0, tpAckCap), - buf: make(net.Buffers, 0, tpBufFrameCap), - bufCh: make(chan struct{}, 1), - serving: make(chan struct{}), - done: make(chan struct{}), - doneFunc: doneFunc, + Conn: conn, + log: log, + id: id, + local: local, + remote: remote, + inCh: make(chan Frame), + ackWaiter: ioutil.NewUint16AckWaiter(), + ackBuf: make([]byte, 0, tpAckCap), + buf: make(net.Buffers, 0, tpBufFrameCap), + bufCh: make(chan struct{}, 1), + serving: make(chan struct{}), + done: make(chan struct{}), + doneFunc: doneFunc, + readDeadline: newDeadline(), + writeDeadline: newDeadline(), } if err := tp.ackWaiter.RandSeq(); err != nil { log.Fatalln("failed to set ack_waiter seq:", err) } + + go func() { + if err := tp.readDeadline.Serve(); err != nil { + log.WithError(err).Warn("Failed to serve read deadline") + } + }() + + go func() { + if err := tp.writeDeadline.Serve(); err != nil { + log.WithError(err).Warn("Failed to serve write deadline") + } + }() + return tp } @@ -105,6 +126,14 @@ func (tp *Transport) close() (closed bool) { tp.inMx.Lock() close(tp.inCh) tp.inMx.Unlock() + + if err := tp.readDeadline.Close(); err != nil { + tp.log.WithError(err).Warn("Failed to close read deadline") + } + + if err := tp.writeDeadline.Close(); err != nil { + tp.log.WithError(err).Warn("Failed to close write deadline") + } }) tp.serve() // just in case. @@ -356,43 +385,87 @@ func (tp *Transport) Serve() { } // Read implements io.Reader -// TODO(evanlinjin): read deadline. -func (tp *Transport) Read(p []byte) (n int, err error) { +func (tp *Transport) Read(p []byte) (int, error) { + if tp.readDeadline.Exceeded() { + return 0, ErrDeadlineExceeded + } + + var n int + var err error + done := make(chan struct{}) + go func() { + n, err = tp.read(p) + close(done) + }() + + select { + case <-done: + return n, err + case _, ok := <-tp.readDeadline.ch: + if !ok { + return 0, io.EOF + } + return 0, ErrDeadlineExceeded + } +} + +func (tp *Transport) read(p []byte) (n int, err error) { <-tp.serving tp.rMx.Lock() defer tp.rMx.Unlock() -startRead: - tp.bufMx.Lock() - n, err = tp.buf.Read(p) - if tp.bufSize -= n; tp.bufSize < tpBufCap && len(tp.ackBuf) > 0 { - acks := tp.ackBuf - tp.ackBuf = make([]byte, 0, tpAckCap) - go func() { - if err := writeFrame(tp.Conn, acks); err != nil { - tp.close() + for { + tp.bufMx.Lock() + n, err = tp.buf.Read(p) + if tp.bufSize -= n; tp.bufSize < tpBufCap && len(tp.ackBuf) > 0 { + acks := tp.ackBuf + tp.ackBuf = make([]byte, 0, tpAckCap) + go func() { + if err := writeFrame(tp.Conn, acks); err != nil { + tp.close() + } + }() + } + tp.bufMx.Unlock() + if n > 0 || len(p) == 0 { + if !tp.IsClosed() { + err = nil } - }() + return n, err + } + if _, ok := <-tp.bufCh; !ok { + return n, err + } } - tp.bufMx.Unlock() +} - if n > 0 || len(p) == 0 { - if !tp.IsClosed() { - err = nil - } - return n, err +// Write implements io.Writer +func (tp *Transport) Write(p []byte) (int, error) { + if tp.writeDeadline.Exceeded() { + return 0, ErrDeadlineExceeded } - if _, ok := <-tp.bufCh; !ok { + var n int + var err error + done := make(chan struct{}) + go func() { + n, err = tp.write(p) + close(done) + }() + + select { + case <-done: return n, err + case _, ok := <-tp.writeDeadline.ch: + if !ok { + return 0, io.ErrClosedPipe + } + return 0, ErrDeadlineExceeded } - goto startRead } -// Write implements io.Writer -// TODO(evanlinjin): write deadline. -func (tp *Transport) Write(p []byte) (int, error) { +func (tp *Transport) write(p []byte) (int, error) { <-tp.serving if tp.IsClosed() { @@ -411,3 +484,21 @@ func (tp *Transport) Write(p []byte) (int, error) { } return len(p), nil } + +// SetDeadline sets read and write deadlines for transport. +func (tp *Transport) SetDeadline(t time.Time) error { + if err := tp.SetReadDeadline(t); err != nil { + return err + } + return tp.SetWriteDeadline(t) +} + +// SetReadDeadline sets read deadline for transport. +func (tp *Transport) SetReadDeadline(t time.Time) error { + return tp.readDeadline.Set(t) +} + +// SetWriteDeadline sets write deadline for transport. +func (tp *Transport) SetWriteDeadline(t time.Time) error { + return tp.writeDeadline.Set(t) +}