From 036d2addd9f363647cc9e7a8e6f7d1363693f0c5 Mon Sep 17 00:00:00 2001 From: Michael Hoffmann Date: Mon, 18 Oct 2021 11:13:58 +0200 Subject: [PATCH] add snappy compression --- channel.go | 2 + channel_test.go | 4 +- connection.go | 38 +++++++++++++++ connection_test.go | 18 ++++--- glide.lock | 6 ++- glide.yaml | 2 + init_test.go | 1 + messages.go | 2 + preinit_connection.go | 9 ++++ snappy.go | 110 ++++++++++++++++++++++++++++++++++++++++++ 10 files changed, 181 insertions(+), 11 deletions(-) create mode 100644 snappy.go diff --git a/channel.go b/channel.go index cac650cfd..c97dcab67 100644 --- a/channel.go +++ b/channel.go @@ -189,6 +189,7 @@ type Channel struct { onPeerStatusChanged func(*Peer) dialer func(ctx context.Context, hostPort string) (net.Conn, error) connContext func(ctx context.Context, conn net.Conn) context.Context + compressionMethod string closed chan struct{} // mutable contains all the members of Channel which are mutable. @@ -339,6 +340,7 @@ func NewChannel(serviceName string, opts *ChannelOptions) (*Channel, error) { LanguageVersion: strings.TrimPrefix(runtime.Version(), "go"), TChannelVersion: VersionInfo, }, + CompressionMethod: ch.ConnectionOptions().CompressionMethod, }, ServiceName: serviceName, } diff --git a/channel_test.go b/channel_test.go index 316d5362a..a1256d2d5 100644 --- a/channel_test.go +++ b/channel_test.go @@ -45,7 +45,8 @@ func toMap(fields LogFields) map[string]interface{} { func TestNewChannel(t *testing.T) { ch, err := NewChannel("svc", &ChannelOptions{ - ProcessName: "pname", + ProcessName: "pname", + DefaultConnectionOptions: ConnectionOptions{CompressionMethod: SnappyCompression}.withDefaults(), }) require.NoError(t, err, "NewChannel failed") @@ -60,6 +61,7 @@ func TestNewChannel(t *testing.T) { LanguageVersion: strings.TrimPrefix(runtime.Version(), "go"), TChannelVersion: VersionInfo, }, + CompressionMethod: SnappyCompression, }, }, ch.PeerInfo(), "Wrong local peer info") } diff --git a/connection.go b/connection.go index dd426a5cb..6687a0058 100644 --- a/connection.go +++ b/connection.go @@ -52,6 +52,32 @@ const ( DefaultConnectionBufferSize = 512 ) +// CompressionMethod used in connection +type CompressionMethod string + +const ( + // NoCompression means that no compression is used + NoCompression CompressionMethod = "" + // SnappyCompression enables the snappy compression + SnappyCompression CompressionMethod = "snappy" +) + +func (cm CompressionMethod) String() string { + return string(cm) +} + +// NewCompressionMethod returns CompressionMethod from string +func NewCompressionMethod(s string) (CompressionMethod, error) { + switch s { + case "none", "": + return NoCompression, nil + case "snappy": + return SnappyCompression, nil + default: + return "", fmt.Errorf("invalid compression method '%s'", s) + } +} + // PeerVersion contains version related information for a specific peer. // These values are extracted from the init headers. type PeerVersion struct { @@ -73,6 +99,9 @@ type PeerInfo struct { // Version returns the version information for the remote peer. Version PeerVersion `json:"version"` + + // CompressionMethod returns the compression method used by the peer. + CompressionMethod CompressionMethod `json:"compressionMethod"` } func (p PeerInfo) String() string { @@ -158,6 +187,9 @@ type ConnectionOptions struct { // MaxCloseTime controls how long we allow a connection to complete pending // calls before shutting down. Only used if it is non-zero. MaxCloseTime time.Duration + + // CompressionMethod specifies the compression used + CompressionMethod CompressionMethod } // connectionEvents are the events that can be triggered by a connection. @@ -314,6 +346,7 @@ func (ch *Channel) newConnection(baseCtx context.Context, conn net.Conn, initial {"remoteHostPort", remotePeer.HostPort}, {"remoteIsEphemeral", remotePeer.IsEphemeral}, {"remoteProcess", remotePeer.ProcessName}, + {"compression", remotePeer.CompressionMethod}, }...) if outboundHP != "" { connDirection = outbound @@ -324,6 +357,11 @@ func (ch *Channel) newConnection(baseCtx context.Context, conn net.Conn, initial peerInfo := ch.PeerInfo() timeNow := ch.timeNow().UnixNano() + // Enable compression if both peers have it enabled + if peerInfo.CompressionMethod == SnappyCompression && peerInfo.CompressionMethod == remotePeer.CompressionMethod { + conn = NewSnappyConnection(conn) + } + c := &Connection{ channelConnectionCommon: ch.channelConnectionCommon, diff --git a/connection_test.go b/connection_test.go index 3ba93736e..032c0cee6 100644 --- a/connection_test.go +++ b/connection_test.go @@ -212,10 +212,11 @@ func TestRemotePeer(t *testing.T) { remote: func(t testing.TB, ts *testutils.TestServer) *Channel { return ts.NewClient(nil) }, expectedFn: func(state *RuntimeState, ts *testutils.TestServer) PeerInfo { return PeerInfo{ - HostPort: state.RootPeers[ts.HostPort()].OutboundConnections[0].LocalHostPort, - IsEphemeral: true, - ProcessName: state.LocalPeer.ProcessName, - Version: wantVersion, + HostPort: state.RootPeers[ts.HostPort()].OutboundConnections[0].LocalHostPort, + IsEphemeral: true, + ProcessName: state.LocalPeer.ProcessName, + Version: wantVersion, + CompressionMethod: NoCompression, } }, }, @@ -224,10 +225,11 @@ func TestRemotePeer(t *testing.T) { remote: func(t testing.TB, ts *testutils.TestServer) *Channel { return ts.NewServer(nil) }, expectedFn: func(state *RuntimeState, ts *testutils.TestServer) PeerInfo { return PeerInfo{ - HostPort: state.LocalPeer.HostPort, - IsEphemeral: false, - ProcessName: state.LocalPeer.ProcessName, - Version: wantVersion, + HostPort: state.LocalPeer.HostPort, + IsEphemeral: false, + ProcessName: state.LocalPeer.ProcessName, + Version: wantVersion, + CompressionMethod: NoCompression, } }, }, diff --git a/glide.lock b/glide.lock index a3186e947..d863c9d1a 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: 7400b8d3d51badea662da185fd9cc0098c7a1fb0a2731b17ceae4e91f6455d1d -updated: 2020-03-17T12:48:49.8534584+11:00 +hash: 59640aee5ecbe113a82206d2ddc96ecea477c4a75f92747bbe04ef16599317eb +updated: 2020-09-03T09:22:05.734767741+03:00 imports: - name: github.com/apache/thrift version: b2a4d4ae21c789b689dd162deb819665567f481c @@ -9,6 +9,8 @@ imports: version: 5ca90424ceb7e5e9affff2765da00e9dd737f274 subpackages: - statsd +- name: github.com/golang/snappy + version: 196ae77b8a26000fa30caa8b2b541e09674dbc43 - name: github.com/opentracing/opentracing-go version: 659c90643e714681897ec2521c60567dd21da733 subpackages: diff --git a/glide.yaml b/glide.yaml index 157c14cd9..28e21914b 100644 --- a/glide.yaml +++ b/glide.yaml @@ -38,6 +38,8 @@ import: - unix - package: go.uber.org/multierr version: ^1.1.0 +- package: github.com/golang/snappy + version: 196ae77b8a26000fa30caa8b2b541e09674dbc43 testImport: - package: github.com/jessevdk/go-flags version: ^1 diff --git a/init_test.go b/init_test.go index b7dcd0e9c..b3f0625a1 100644 --- a/init_test.go +++ b/init_test.go @@ -242,6 +242,7 @@ func TestHandleInitReqNewVersion(t *testing.T) { InitParamTChannelLanguage: "go", InitParamTChannelLanguageVersion: strings.TrimPrefix(runtime.Version(), "go"), InitParamTChannelVersion: VersionInfo, + InitParamTChannelCompression: NoCompression.String(), }, }, }, msg, "unexpected init res") diff --git a/messages.go b/messages.go index afcbf0df3..a6b8256ca 100644 --- a/messages.go +++ b/messages.go @@ -78,6 +78,8 @@ const ( InitParamTChannelLanguageVersion = "tchannel_language_version" // InitParamTChannelVersion contains the library version. InitParamTChannelVersion = "tchannel_version" + // InitParamTChannelCompression contains the compression method. + InitParamTChannelCompression = "tchannel_compression" ) // initMessage is the base for messages in the initialization handshake diff --git a/preinit_connection.go b/preinit_connection.go index f49ddec43..f19923d4e 100644 --- a/preinit_connection.go +++ b/preinit_connection.go @@ -109,6 +109,7 @@ func (ch *Channel) getInitParams() initParams { InitParamTChannelLanguage: localPeer.Version.Language, InitParamTChannelLanguageVersion: localPeer.Version.LanguageVersion, InitParamTChannelVersion: localPeer.Version.TChannelVersion, + InitParamTChannelCompression: localPeer.CompressionMethod.String(), } } @@ -194,6 +195,14 @@ func parseRemotePeer(p initParams, remoteAddr net.Addr) (PeerInfo, peerAddressCo return remotePeer, remotePeerAddress, fmt.Errorf("header %v is required", InitParamProcessName) } + if compressionMethod, ok := p[InitParamTChannelCompression]; ok { + cm, err := NewCompressionMethod(compressionMethod) + if err != nil { + return remotePeer, remotePeerAddress, err + } + remotePeer.CompressionMethod = cm + } + // If the remote host:port is ephemeral, use the socket address as the // host:port and set IsEphemeral to true. if isEphemeralHostPort(remotePeer.HostPort) { diff --git a/snappy.go b/snappy.go new file mode 100644 index 000000000..501750fdd --- /dev/null +++ b/snappy.go @@ -0,0 +1,110 @@ +package tchannel + +import ( + "errors" + "net" + "syscall" + "time" + + "github.com/golang/snappy" + "go.uber.org/multierr" +) + +// SnappyConn wraps net.Conn with Snappy compression +type SnappyConn struct { + conn net.Conn + reader *snappy.Reader + writer *snappy.Writer +} + +// NewSnappyConnection creates a new Snappy compressed connection. +// +// The snappy writer is not buffered, to honor potential deadlines on the underlying net.Conn +func NewSnappyConnection(conn net.Conn) net.Conn { + w := snappy.NewWriter(conn) + r := snappy.NewReader(conn) + return &SnappyConn{conn: conn, writer: w, reader: r} +} + +// Read reads data from the connection. +// Read can be made to time out and return an Error with Timeout() == true +// after a fixed time limit; see SetDeadline and SetReadDeadline. +func (sc *SnappyConn) Read(b []byte) (n int, err error) { + return sc.reader.Read(b) +} + +// Write writes data to the connection. +// Write can be made to time out and return an Error with Timeout() == true +// after a fixed time limit; see SetDeadline and SetWriteDeadline. +func (sc *SnappyConn) Write(b []byte) (n int, err error) { + return sc.writer.Write(b) +} + +// Close closes the connection. +// Any blocked Read or Write operations will be unblocked and return errors. +func (sc *SnappyConn) Close() (err error) { + return multierr.Combine( + sc.writer.Close(), + sc.conn.Close(), + ) +} + +// LocalAddr returns the local network address. +func (sc *SnappyConn) LocalAddr() net.Addr { + return sc.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (sc *SnappyConn) RemoteAddr() net.Addr { + return sc.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. It is equivalent to calling both +// SetReadDeadline and SetWriteDeadline. +// +// A deadline is an absolute time after which I/O operations +// fail with a timeout (see type Error) instead of +// blocking. The deadline applies to all future and pending +// I/O, not just the immediately following call to Read or +// Write. After a deadline has been exceeded, the connection +// can be refreshed by setting a deadline in the future. +// +// An idle timeout can be implemented by repeatedly extending +// the deadline after successful Read or Write calls. +// +// A zero value for t means I/O operations will not time out. +// +// Note that if a TCP connection has keep-alive turned on, +// which is the default unless overridden by Dialer.KeepAlive +// or ListenConfig.KeepAlive, then a keep-alive failure may +// also return a timeout error. On Unix systems a keep-alive +// failure on I/O can be detected using +// errors.Is(err, syscall.ETIMEDOUT). +func (sc *SnappyConn) SetDeadline(t time.Time) error { + return sc.conn.SetDeadline(t) +} + +// SetReadDeadline sets the deadline for future Read calls +// and any currently-blocked Read call. +// A zero value for t means Read will not time out. +func (sc *SnappyConn) SetReadDeadline(t time.Time) error { + return sc.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the deadline for future Write calls +// and any currently-blocked Write call. +// Even if write times out, it may return n > 0, indicating that +// some of the data was successfully written. +// A zero value for t means Write will not time out. +func (sc *SnappyConn) SetWriteDeadline(t time.Time) error { + return sc.conn.SetWriteDeadline(t) +} + +// SyscallConn from the underlying connection +func (sc *SnappyConn) SyscallConn() (syscall.RawConn, error) { + if sysc, ok := sc.conn.(syscall.Conn); ok { + return sysc.SyscallConn() + } + return nil, errors.New("the underlying connection does not implement syscall.Conn") +}