From 511f6654c814e6ab915c19ce1f90f83941d14adb Mon Sep 17 00:00:00 2001 From: Evgeny Korolev Date: Fri, 8 Nov 2019 13:10:33 +0300 Subject: [PATCH 1/3] Refactor, rename, reorder 1. Split ClientConfig and ManagedClientConfig (revert to Comcast) 2. Rename ConsumerConfig to ManagedConsumerConfig 3. Rename ProducerConfig to ManagedProducerConfig 4. Rename NewPartitionManagedConsumer() to NewManagedPartitionConsumer() (according to ManagedPartitionConsumer type name and NewManagedPartitionProducer()) 5. Reorder code: type constructor, type definition, type methods 6. Unpublish fields that are not used by other pulsar-go-client packages (revert to Comcast) 7. Fixed build error in managed_consumer.go (Go 1.13 faults at https://github.com/wolfstudy/pulsar-client-go/blob/6c1405cf4104fe07f70bbd2a6d26197dd2d8e21e/core/manage/managed_consumer.go#L216) --- cmd/cli/main.go | 26 +- core/conn/conn.go | 46 +-- core/conn/conn_test.go | 30 +- core/conn/connector.go | 38 ++- core/conn/connector_test.go | 5 +- core/conn/mockserver.go | 20 +- core/frame/framedispatcher.go | 154 +++++----- core/frame/mocksender.go | 18 +- core/manage/client.go | 135 +++++---- core/manage/client_test.go | 22 +- core/manage/managed_client.go | 61 ++-- core/manage/managed_client_pool.go | 4 +- core/manage/managed_client_pool_test.go | 48 +-- core/manage/managed_client_test.go | 24 +- core/manage/managed_consumer.go | 282 +++++++++--------- .../managed_consumer_integration_test.go | 54 ++-- core/manage/managed_consumer_test.go | 32 +- core/manage/managed_producer.go | 219 +++++++------- core/manage/managed_producer_test.go | 40 ++- core/manage/pubsub.go | 92 +++--- core/manage/pubsub_test.go | 14 +- core/manage/subscriptions.go | 64 ++-- core/manage/unackedMsgTracker.go | 22 +- core/manage/util_test.go | 25 ++ core/pub/producer.go | 70 ++--- core/srv/discoverer.go | 24 +- core/srv/pinger.go | 14 +- core/sub/consumer.go | 109 ++++--- core/sub/consumer_test.go | 6 +- examples/consumer/consumer.go | 12 +- examples/producer/producer.go | 8 +- utils/util.go | 43 --- 32 files changed, 901 insertions(+), 860 deletions(-) create mode 100644 core/manage/util_test.go diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 10d118d..5aef046 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -131,16 +131,18 @@ func main() { switch args.producer { case true: // Create the managed producer - mpCfg := manage.ProducerConfig{ + mpCfg := manage.ManagedProducerConfig{ Name: args.name, Topic: args.topic, NewProducerTimeout: time.Second, InitialReconnectDelay: time.Second, MaxReconnectDelay: time.Minute, - ClientConfig: manage.ClientConfig{ - Addr: args.pulsar, - TLSConfig: tlsCfg, - Errs: asyncErrs, + ManagedClientConfig: manage.ManagedClientConfig{ + ClientConfig: manage.ClientConfig{ + Addr: args.pulsar, + TLSConfig: tlsCfg, + Errs: asyncErrs, + }, }, } mp := manage.NewManagedProducer(mcp, mpCfg) @@ -183,7 +185,7 @@ func main() { return } sctx, cancel := context.WithTimeout(ctx, time.Second) - _, err := mp.Send(sctx, payload,"") + _, err := mp.Send(sctx, payload, "") cancel() if err != nil { fmt.Fprintln(os.Stderr, err) @@ -199,16 +201,18 @@ func main() { queue := make(chan msg.Message, 8) // Create managed consumer - mcCfg := manage.ConsumerConfig{ + mcCfg := manage.ManagedConsumerConfig{ Name: args.name, Topic: args.topic, NewConsumerTimeout: time.Second, InitialReconnectDelay: time.Second, MaxReconnectDelay: time.Minute, - ClientConfig: manage.ClientConfig{ - Addr: args.pulsar, - TLSConfig: tlsCfg, - Errs: asyncErrs, + ManagedClientConfig: manage.ManagedClientConfig{ + ClientConfig: manage.ClientConfig{ + Addr: args.pulsar, + TLSConfig: tlsCfg, + Errs: asyncErrs, + }, }, } diff --git a/core/conn/conn.go b/core/conn/conn.go index 0e5b523..6c676fa 100644 --- a/core/conn/conn.go +++ b/core/conn/conn.go @@ -41,9 +41,9 @@ func NewTCPConn(addr string, timeout time.Duration) (*Conn, error) { } return &Conn{ - Rc: c, - W: c, - Closedc: make(chan struct{}), + rc: c, + w: c, + closedc: make(chan struct{}), }, nil } @@ -62,23 +62,23 @@ func NewTLSConn(addr string, tlsCfg *tls.Config, timeout time.Duration) (*Conn, } return &Conn{ - Rc: c, - W: c, - Closedc: make(chan struct{}), + rc: c, + w: c, + closedc: make(chan struct{}), }, nil } // Conn is responsible for writing and reading // Frames to and from the underlying connection (r and w). type Conn struct { - Rc io.ReadCloser + rc io.ReadCloser - Wmu sync.Mutex // protects w to ensure frames aren't interleaved - W io.Writer + wmu sync.Mutex // protects w to ensure frames aren't interleaved + w io.Writer - Cmu sync.Mutex // protects following - IsClosed bool - Closedc chan struct{} + cmu sync.Mutex // protects following + isClosed bool + closedc chan struct{} } // Close closes the underlaying connection. @@ -86,16 +86,16 @@ type Conn struct { // an error. It will also cause the closed channel // to unblock. func (c *Conn) Close() error { - c.Cmu.Lock() - defer c.Cmu.Unlock() + c.cmu.Lock() + defer c.cmu.Unlock() - if c.IsClosed { + if c.isClosed { return nil } - err := c.Rc.Close() - close(c.Closedc) - c.IsClosed = true + err := c.rc.Close() + close(c.closedc) + c.isClosed = true return err } @@ -104,7 +104,7 @@ func (c *Conn) Close() error { // when the connection has been closed and is no // longer usable. func (c *Conn) Closed() <-chan struct{} { - return c.Closedc + return c.closedc } // Read blocks while it reads from r until an error occurs. @@ -116,7 +116,7 @@ func (c *Conn) Closed() <-chan struct{} { func (c *Conn) Read(frameHandler func(f frame.Frame)) error { for { var f frame.Frame - if err := f.Decode(c.Rc); err != nil { + if err := f.Decode(c.rc); err != nil { // It's very possible that the connection is already closed at this // point, since any connection closed errors would bubble up // from Decode. But just in case it's a decode error (bad data for example), @@ -164,9 +164,9 @@ func (c *Conn) writeFrame(f *frame.Frame) error { return err } - c.Wmu.Lock() - _, err := b.WriteTo(c.W) - c.Wmu.Unlock() + c.wmu.Lock() + _, err := b.WriteTo(c.w) + c.wmu.Unlock() return err } diff --git a/core/conn/conn_test.go b/core/conn/conn_test.go index 0b246a7..c1ff9d2 100644 --- a/core/conn/conn_test.go +++ b/core/conn/conn_test.go @@ -60,10 +60,10 @@ func TestConn_Read(t *testing.T) { } c := Conn{ - Rc: &mockReadCloser{ + rc: &mockReadCloser{ Reader: &b, }, - Closedc: make(chan struct{}), + closedc: make(chan struct{}), } var gotFrames []frame.Frame @@ -87,10 +87,10 @@ func TestConn_Read(t *testing.T) { func TestConn_Close(t *testing.T) { c := Conn{ - Rc: &mockReadCloser{ + rc: &mockReadCloser{ Reader: new(bytes.Buffer), }, - Closedc: make(chan struct{}), + closedc: make(chan struct{}), } // no-op @@ -116,8 +116,8 @@ func TestConn_GarbageInput(t *testing.T) { Reader: bytes.NewBufferString("this isn't a valid Pulsar frame"), } c := Conn{ - Rc: mrc, - Closedc: make(chan struct{}), + rc: mrc, + closedc: make(chan struct{}), } var gotFrames []frame.Frame @@ -164,8 +164,8 @@ func TestConn_TimeoutReader(t *testing.T) { Reader: iotest.TimeoutReader(&b), } c := Conn{ - Rc: mrc, - Closedc: make(chan struct{}), + rc: mrc, + closedc: make(chan struct{}), } var gotFrames []frame.Frame @@ -205,10 +205,10 @@ func TestConn_Read_SlowSrc(t *testing.T) { c := Conn{ // OneByteReader returns a single byte per read, // regardless of how big its input buffer is. - Rc: &mockReadCloser{ + rc: &mockReadCloser{ Reader: iotest.OneByteReader(&b), }, - Closedc: make(chan struct{}), + closedc: make(chan struct{}), } var gotFrames []frame.Frame @@ -267,10 +267,10 @@ func TestConn_Read_MutliFrame(t *testing.T) { } c := Conn{ - Rc: &mockReadCloser{ + rc: &mockReadCloser{ Reader: &b, }, - Closedc: make(chan struct{}), + closedc: make(chan struct{}), } var gotFrames []frame.Frame @@ -326,11 +326,11 @@ func TestConn_writeFrame(t *testing.T) { // same buffer is used for reads and writes var rw bytes.Buffer c := Conn{ - Rc: &mockReadCloser{ + rc: &mockReadCloser{ Reader: &rw, }, - W: &rw, - Closedc: make(chan struct{}), + w: &rw, + closedc: make(chan struct{}), } // write the frames in parallel (order will diff --git a/core/conn/connector.go b/core/conn/connector.go index 3b8a3fb..4f2fddb 100644 --- a/core/conn/connector.go +++ b/core/conn/connector.go @@ -20,14 +20,30 @@ import ( "github.com/golang/protobuf/proto" "github.com/wolfstudy/pulsar-client-go/core/frame" "github.com/wolfstudy/pulsar-client-go/pkg/api" - "github.com/wolfstudy/pulsar-client-go/utils" +) + +const ( + // ProtoVersion is the Pulsar protocol version + // used by this client. + ProtoVersion = int32(api.ProtocolVersion_v12) + + // ClientVersion is an opaque string sent + // by the client to the server on connect, eg: + // "Pulsar-Client-Java-v1.15.2" + ClientVersion = "pulsar-client-go" + + // undefRequestID defines a RequestID of -1. + // + // Usage example: + // https://github.com/apache/incubator-pulsar/blob/fdc7b8426d8253c9437777ae51a4639239550f00/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ServerCnx.java#L325 + undefRequestID = 1<<64 - 1 ) // NewConnector returns a ready-to-use connector. func NewConnector(s frame.CmdSender, dispatcher *frame.Dispatcher) *Connector { return &Connector{ - S: s, - Dispatcher: dispatcher, + s: s, + dispatcher: dispatcher, } } @@ -36,8 +52,8 @@ func NewConnector(s frame.CmdSender, dispatcher *frame.Dispatcher) *Connector { // // https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#Connectionestablishment-ly8l2n type Connector struct { - S frame.CmdSender - Dispatcher *frame.Dispatcher // used to manage the request/response state + s frame.CmdSender + dispatcher *frame.Dispatcher // used to manage the request/response state } // Connect initiates the client's session. After sending, @@ -48,17 +64,17 @@ type Connector struct { // // It's required to have completed Connect/Connected before using the client. func (c *Connector) Connect(ctx context.Context, authMethod, proxyBrokerURL string) (*api.CommandConnected, error) { - resp, cancel, err := c.Dispatcher.RegisterGlobal() + resp, cancel, err := c.dispatcher.RegisterGlobal() if err != nil { return nil, err } defer cancel() // NOTE: The source seems to indicate that the ERROR messages's - // RequestID will be -1 (ie UndefRequestID) in the case that it's + // RequestID will be -1 (ie undefRequestID) in the case that it's // associated with a CONNECT request. // https://github.com/apache/incubator-pulsar/blob/fdc7b8426d8253c9437777ae51a4639239550f00/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ServerCnx.java#L325 - errResp, cancel, err := c.Dispatcher.RegisterReqID(utils.UndefRequestID) + errResp, cancel, err := c.dispatcher.RegisterReqID(undefRequestID) if err != nil { return nil, err } @@ -67,8 +83,8 @@ func (c *Connector) Connect(ctx context.Context, authMethod, proxyBrokerURL stri // create and send CONNECT msg connect := api.CommandConnect{ - ClientVersion: proto.String(utils.ClientVersion), - ProtocolVersion: proto.Int32(utils.ProtoVersion), + ClientVersion: proto.String(ClientVersion), + ProtocolVersion: proto.Int32(ProtoVersion), } if authMethod != "" { connect.AuthMethodName = proto.String(authMethod) @@ -82,7 +98,7 @@ func (c *Connector) Connect(ctx context.Context, authMethod, proxyBrokerURL stri Connect: &connect, } - if err := c.S.SendSimpleCmd(cmd); err != nil { + if err := c.s.SendSimpleCmd(cmd); err != nil { return nil, err } diff --git a/core/conn/connector_test.go b/core/conn/connector_test.go index 0189ed1..85e4c4a 100644 --- a/core/conn/connector_test.go +++ b/core/conn/connector_test.go @@ -21,7 +21,6 @@ import ( "github.com/golang/protobuf/proto" "github.com/wolfstudy/pulsar-client-go/core/frame" "github.com/wolfstudy/pulsar-client-go/pkg/api" - "github.com/wolfstudy/pulsar-client-go/utils" ) func TestConnector(t *testing.T) { @@ -156,7 +155,7 @@ func TestConnector_Error(t *testing.T) { time.Sleep(100 * time.Millisecond) errorMsg := api.CommandError{ - RequestId: proto.Uint64(utils.UndefRequestID), + RequestId: proto.Uint64(undefRequestID), Message: proto.String("there was an error of sorts"), } f := frame.Frame{ @@ -165,7 +164,7 @@ func TestConnector_Error(t *testing.T) { Error: &errorMsg, }, } - if err := dispatcher.NotifyReqID(utils.UndefRequestID, f); err != nil { + if err := dispatcher.NotifyReqID(undefRequestID, f); err != nil { t.Fatalf("HandleReqID() err = %v; nil expected", err) } diff --git a/core/conn/mockserver.go b/core/conn/mockserver.go index 0a073fb..dcf2634 100644 --- a/core/conn/mockserver.go +++ b/core/conn/mockserver.go @@ -19,13 +19,6 @@ import ( "net" ) -// MockPulsarServer emulates a Pulsar server -type MockPulsarServer struct { - Addr string - Errs chan error - Conns chan *Conn -} - func NewMockPulsarServer(ctx context.Context) (*MockPulsarServer, error) { l, err := net.ListenTCP("tcp4", &net.TCPAddr{ IP: net.IPv4zero, @@ -62,9 +55,9 @@ func NewMockPulsarServer(ctx context.Context) (*MockPulsarServer, error) { }() mock.Conns <- &Conn{ - Rc: c, - W: c, - Closedc: make(chan struct{}), + rc: c, + w: c, + closedc: make(chan struct{}), } } }() @@ -72,4 +65,11 @@ func NewMockPulsarServer(ctx context.Context) (*MockPulsarServer, error) { return &mock, nil } +// MockPulsarServer emulates a Pulsar server +type MockPulsarServer struct { + Addr string + Errs chan error + Conns chan *Conn +} + diff --git a/core/frame/framedispatcher.go b/core/frame/framedispatcher.go index 4937fdd..e1d59e1 100644 --- a/core/frame/framedispatcher.go +++ b/core/frame/framedispatcher.go @@ -24,8 +24,8 @@ import ( // NewFrameDispatcher returns an instantiated FrameDispatcher. func NewFrameDispatcher() *Dispatcher { return &Dispatcher{ - ProdSeqIDs: make(map[ProdSeqKey]AsyncResp), - ReqIDs: make(map[uint64]AsyncResp), + prodSeqIDs: make(map[prodSeqKey]asyncResp), + reqIDs: make(map[uint64]asyncResp), } } @@ -38,37 +38,18 @@ type Dispatcher struct { // therefore a single channel is used as their // Respective FrameDispatcher. If the channel is // nil, there's no outstanding request. - GlobalMu sync.Mutex // protects following - Global *AsyncResp + globalMu sync.Mutex // protects following + global *asyncResp // All Responses that are correlated by their // requestID - ReqIDMu sync.Mutex // protects following - ReqIDs map[uint64]AsyncResp + reqIDMu sync.Mutex // protects following + reqIDs map[uint64]asyncResp // All Responses that are correlated by their // (producerID, sequenceID) tuple - ProdSeqIDsMu sync.Mutex // protects following - ProdSeqIDs map[ProdSeqKey]AsyncResp -} - -// AsyncResp manages the state between a request -// and Response. Requestors wait on the `Resp` channel -// for the corResponding Response frame to their request. -// If they are no longer interested in the Response (timeout), -// then the `done` channel is closed, signaling to the Response -// side that the Response is not expected/needed. -type AsyncResp struct { - Resp chan<- Frame - Done <-chan struct{} -} - -// prodSeqKey is a composite lookup key for the dispatchers -// that use producerID and sequenceID to correlate Responses, -// which are the SendReceipt and SendError Responses. -type ProdSeqKey struct { - ProducerID uint64 - SequenceID uint64 + prodSeqIDsMu sync.Mutex // protects following + prodSeqIDs map[prodSeqKey]asyncResp } // RegisterGlobal is used to wait for Responses that have no identifying @@ -85,9 +66,9 @@ func (f *Dispatcher) RegisterGlobal() (Response <-chan Frame, cancel func(), err return } - f.GlobalMu.Lock() - f.Global = nil - f.GlobalMu.Unlock() + f.globalMu.Lock() + f.global = nil + f.globalMu.Unlock() close(done) done = nil @@ -95,16 +76,16 @@ func (f *Dispatcher) RegisterGlobal() (Response <-chan Frame, cancel func(), err Resp := make(chan Frame) - f.GlobalMu.Lock() - if f.Global != nil { - f.GlobalMu.Unlock() + f.globalMu.Lock() + if f.global != nil { + f.globalMu.Unlock() return nil, nil, errors.New("outstanding global request already in progress") } - f.Global = &AsyncResp{ - Resp: Resp, - Done: done, + f.global = &asyncResp{ + resp: Resp, + done: done, } - f.GlobalMu.Unlock() + f.globalMu.Unlock() return Resp, cancel, nil } @@ -112,22 +93,22 @@ func (f *Dispatcher) RegisterGlobal() (Response <-chan Frame, cancel func(), err // NotifyGlobal should be called with Response frames that have // no identifying id (Pong, Connected). func (f *Dispatcher) NotifyGlobal(frame Frame) error { - f.GlobalMu.Lock() - a := f.Global + f.globalMu.Lock() + a := f.global // ensure additional calls to notify // fail with UnexpectedMsg (unless register is called again) - f.Global = nil - f.GlobalMu.Unlock() + f.global = nil + f.globalMu.Unlock() if a == nil { return utils.NewUnexpectedErrMsg(frame.BaseCmd.GetType()) } select { - case a.Resp <- frame: + case a.resp <- frame: // sent Response back to sender return nil - case <-a.Done: + case <-a.done: return utils.NewUnexpectedErrMsg(frame.BaseCmd.GetType()) } } @@ -137,7 +118,7 @@ func (f *Dispatcher) NotifyGlobal(frame Frame) error { // specifically when they're not interested in the Response. It is an error // to have multiple outstanding requests with the same id tuple. func (f *Dispatcher) RegisterProdSeqIDs(producerID, sequenceID uint64) (Response <-chan Frame, cancel func(), err error) { - key := ProdSeqKey{producerID, sequenceID} + key := prodSeqKey{producerID, sequenceID} var mu sync.Mutex done := make(chan struct{}) @@ -148,9 +129,9 @@ func (f *Dispatcher) RegisterProdSeqIDs(producerID, sequenceID uint64) (Response return } - f.ProdSeqIDsMu.Lock() - delete(f.ProdSeqIDs, key) - f.ProdSeqIDsMu.Unlock() + f.prodSeqIDsMu.Lock() + delete(f.prodSeqIDs, key) + f.prodSeqIDsMu.Unlock() close(done) done = nil @@ -158,16 +139,16 @@ func (f *Dispatcher) RegisterProdSeqIDs(producerID, sequenceID uint64) (Response Resp := make(chan Frame) - f.ProdSeqIDsMu.Lock() - if _, ok := f.ProdSeqIDs[key]; ok { - f.ProdSeqIDsMu.Unlock() + f.prodSeqIDsMu.Lock() + if _, ok := f.prodSeqIDs[key]; ok { + f.prodSeqIDsMu.Unlock() return nil, nil, fmt.Errorf("already exists an outstanding Response for producerID %d, sequenceID %d", producerID, sequenceID) } - f.ProdSeqIDs[key] = AsyncResp{ - Resp: Resp, - Done: done, + f.prodSeqIDs[key] = asyncResp{ + resp: Resp, + done: done, } - f.ProdSeqIDsMu.Unlock() + f.prodSeqIDsMu.Unlock() return Resp, cancel, nil } @@ -175,25 +156,25 @@ func (f *Dispatcher) RegisterProdSeqIDs(producerID, sequenceID uint64) (Response // NotifyProdSeqIDs should be called with Response frames that have // (producerID, sequenceID) id tuples to correlate them to their requests. func (f *Dispatcher) NotifyProdSeqIDs(producerID, sequenceID uint64, frame Frame) error { - key := ProdSeqKey{producerID, sequenceID} + key := prodSeqKey{producerID, sequenceID} - f.ProdSeqIDsMu.Lock() + f.prodSeqIDsMu.Lock() // fetch Response channel from cubbyhole - a, ok := f.ProdSeqIDs[key] + a, ok := f.prodSeqIDs[key] // ensure additional calls to notify with same key will // fail with UnexpectedMsg (unless registerProdSeqIDs with same key is called) - delete(f.ProdSeqIDs, key) - f.ProdSeqIDsMu.Unlock() + delete(f.prodSeqIDs, key) + f.prodSeqIDsMu.Unlock() if !ok { return utils.NewUnexpectedErrMsg(frame.BaseCmd.GetType(), producerID, sequenceID) } select { - case a.Resp <- frame: + case a.resp <- frame: // Response was correctly pushed into channel return nil - case <-a.Done: + case <-a.done: return utils.NewUnexpectedErrMsg(frame.BaseCmd.GetType(), producerID, sequenceID) } } @@ -212,9 +193,9 @@ func (f *Dispatcher) RegisterReqID(requestID uint64) (Response <-chan Frame, can return } - f.ReqIDMu.Lock() - delete(f.ReqIDs, requestID) - f.ReqIDMu.Unlock() + f.reqIDMu.Lock() + delete(f.reqIDs, requestID) + f.reqIDMu.Unlock() close(done) done = nil @@ -222,16 +203,16 @@ func (f *Dispatcher) RegisterReqID(requestID uint64) (Response <-chan Frame, can Resp := make(chan Frame) - f.ReqIDMu.Lock() - if _, ok := f.ReqIDs[requestID]; ok { - f.ReqIDMu.Unlock() + f.reqIDMu.Lock() + if _, ok := f.reqIDs[requestID]; ok { + f.reqIDMu.Unlock() return nil, nil, fmt.Errorf("already exists an outstanding Response for requestID %d", requestID) } - f.ReqIDs[requestID] = AsyncResp{ - Resp: Resp, - Done: done, + f.reqIDs[requestID] = asyncResp{ + resp: Resp, + done: done, } - f.ReqIDMu.Unlock() + f.reqIDMu.Unlock() return Resp, cancel, nil } @@ -239,13 +220,13 @@ func (f *Dispatcher) RegisterReqID(requestID uint64) (Response <-chan Frame, can // NotifyReqID should be called with Response frames that have // a requestID to correlate them to their requests. func (f *Dispatcher) NotifyReqID(requestID uint64, frame Frame) error { - f.ReqIDMu.Lock() + f.reqIDMu.Lock() // fetch Response channel from cubbyhole - a, ok := f.ReqIDs[requestID] + a, ok := f.reqIDs[requestID] // ensure additional calls to notifyReqID with same key will // fail with UnexpectedMsg (unless addReqID with same key is called) - delete(f.ReqIDs, requestID) - f.ReqIDMu.Unlock() + delete(f.reqIDs, requestID) + f.reqIDMu.Unlock() if !ok { return utils.NewUnexpectedErrMsg(frame.BaseCmd.GetType(), requestID) @@ -253,10 +234,29 @@ func (f *Dispatcher) NotifyReqID(requestID uint64, frame Frame) error { // send received message to Response channel select { - case a.Resp <- frame: + case a.resp <- frame: // Response was correctly pushed into channel return nil - case <-a.Done: + case <-a.done: return utils.NewUnexpectedErrMsg(frame.BaseCmd.GetType(), requestID) } } + +// asyncResp manages the state between a request +// and Response. Requestors wait on the `resp` channel +// for the corResponding Response frame to their request. +// If they are no longer interested in the Response (timeout), +// then the `done` channel is closed, signaling to the Response +// side that the Response is not expected/needed. +type asyncResp struct { + resp chan<- Frame + done <-chan struct{} +} + +// prodSeqKey is a composite lookup key for the dispatchers +// that use producerID and sequenceID to correlate Responses, +// which are the SendReceipt and SendError Responses. +type prodSeqKey struct { + producerID uint64 + sequenceID uint64 +} diff --git a/core/frame/mocksender.go b/core/frame/mocksender.go index cd56cf3..3c6d251 100644 --- a/core/frame/mocksender.go +++ b/core/frame/mocksender.go @@ -29,14 +29,14 @@ type CmdSender interface { // MockSender implements the sender interface type MockSender struct { - Mu sync.Mutex // protects following + mu sync.Mutex // protects following Frames []Frame - Closedc chan struct{} + closedc chan struct{} } func (m *MockSender) GetFrames() []Frame { - m.Mu.Lock() - defer m.Mu.Unlock() + m.mu.Lock() + defer m.mu.Unlock() cp := make([]Frame, len(m.Frames)) copy(cp, m.Frames) @@ -45,8 +45,8 @@ func (m *MockSender) GetFrames() []Frame { } func (m *MockSender) SendSimpleCmd(cmd api.BaseCommand) error { - m.Mu.Lock() - defer m.Mu.Unlock() + m.mu.Lock() + defer m.mu.Unlock() m.Frames = append(m.Frames, Frame{ BaseCmd: &cmd, @@ -56,8 +56,8 @@ func (m *MockSender) SendSimpleCmd(cmd api.BaseCommand) error { } func (m *MockSender) SendPayloadCmd(cmd api.BaseCommand, metadata api.MessageMetadata, payload []byte) error { - m.Mu.Lock() - defer m.Mu.Unlock() + m.mu.Lock() + defer m.mu.Unlock() m.Frames = append(m.Frames, Frame{ BaseCmd: &cmd, @@ -69,5 +69,5 @@ func (m *MockSender) SendPayloadCmd(cmd api.BaseCommand, metadata api.MessageMet } func (m *MockSender) Closed() <-chan struct{} { - return m.Closedc + return m.closedc } diff --git a/core/manage/client.go b/core/manage/client.go index 7335d3a..47ad7e4 100644 --- a/core/manage/client.go +++ b/core/manage/client.go @@ -15,7 +15,9 @@ package manage import ( "context" + "crypto/tls" "fmt" + "time" "github.com/wolfstudy/pulsar-client-go/core/conn" "github.com/wolfstudy/pulsar-client-go/core/frame" @@ -27,17 +29,50 @@ import ( "github.com/wolfstudy/pulsar-client-go/utils" ) +// authMethodTLS is the name of the TLS authentication +// method, used in the CONNECT message. +const authMethodTLS = "tls" + +// ClientConfig is used to configure a Pulsar client. +type ClientConfig struct { + Addr string // pulsar broker address. May start with pulsar:// + phyAddr string // if set, the TCP connection should be made using this address. This is only ever set during Topic Lookup + DialTimeout time.Duration // timeout to use when establishing TCP connection + TLSConfig *tls.Config // TLS configuration. May be nil, in which case TLS will not be used + Errs chan<- error // asynchronous errors will be sent here. May be nil +} + +// connAddr returns the address that should be used +// for the TCP connection. It defaults to phyAddr if set, +// otherwise Addr. This is to support the proxying through +// a broker, as determined during topic lookup. +func (c ClientConfig) connAddr() string { + if c.phyAddr != "" { + return c.phyAddr + } + return c.Addr +} + +// setDefaults returns a modified config with appropriate zero values set to defaults. +func (c ClientConfig) setDefaults() ClientConfig { + if c.DialTimeout <= 0 { + c.DialTimeout = 5 * time.Second + } + + return c +} + // NewClient returns a Pulsar client for the given configuration options. func NewClient(cfg ClientConfig) (*Client, error) { - cfg = cfg.SetDefaults() + cfg = cfg.setDefaults() var cnx *conn.Conn var err error if cfg.TLSConfig != nil { - cnx, err = conn.NewTLSConn(cfg.ConnAddr(), cfg.TLSConfig, cfg.DialTimeout) + cnx, err = conn.NewTLSConn(cfg.connAddr(), cfg.TLSConfig, cfg.DialTimeout) } else { - cnx, err = conn.NewTCPConn(cfg.ConnAddr(), cfg.DialTimeout) + cnx, err = conn.NewTCPConn(cfg.connAddr(), cfg.DialTimeout) } if err != nil { return nil, err @@ -49,15 +84,15 @@ func NewClient(cfg ClientConfig) (*Client, error) { subs := NewSubscriptions() c := &Client{ - C: cnx, - AsyncErrs: utils.AsyncErrors(cfg.Errs), - - Dispatcher: dispatcher, - Subscriptions: subs, - Connector: conn.NewConnector(cnx, dispatcher), - Pinger: srv.NewPinger(cnx, dispatcher), - Discoverer: srv.NewDiscoverer(cnx, dispatcher, &reqID), - Pubsub: NewPubsub(cnx, dispatcher, subs, &reqID), + c: cnx, + asyncErrs: utils.AsyncErrors(cfg.Errs), + + dispatcher: dispatcher, + subscriptions: subs, + connector: conn.NewConnector(cnx, dispatcher), + pinger: srv.NewPinger(cnx, dispatcher), + discoverer: srv.NewDiscoverer(cnx, dispatcher, &reqID), + pubsub: NewPubsub(cnx, dispatcher, subs, &reqID), } handler := func(f frame.Frame) { @@ -69,12 +104,12 @@ func NewClient(cfg ClientConfig) (*Client, error) { // the connection has been closed and is no longer usable. defer func() { if err := c.Close(); err != nil { - c.AsyncErrs.Send(err) + c.asyncErrs.Send(err) } }() if err := cnx.Read(handler); err != nil { - c.AsyncErrs.Send(err) + c.asyncErrs.Send(err) } }() @@ -84,16 +119,16 @@ func NewClient(cfg ClientConfig) (*Client, error) { // Client is a Pulsar client, capable of sending and receiving // messages and managing the associated state. type Client struct { - C *conn.Conn - AsyncErrs utils.AsyncErrors + c *conn.Conn + asyncErrs utils.AsyncErrors - Dispatcher *frame.Dispatcher + dispatcher *frame.Dispatcher - Subscriptions *Subscriptions - Connector *conn.Connector - Pinger *srv.Pinger - Discoverer *srv.Discoverer - Pubsub *Pubsub + subscriptions *Subscriptions + connector *conn.Connector + pinger *srv.Pinger + discoverer *srv.Discoverer + pubsub *Pubsub } // Closed returns a channel that unblocks when the client's connection @@ -101,13 +136,13 @@ type Client struct { // channel and recreate the Client if closed. // TODO: Rename to Done func (c *Client) Closed() <-chan struct{} { - return c.C.Closed() + return c.c.Closed() } // Close closes the connection. The channel returned from `Closed` will unblock. // The client should no longer be used after calling Close. func (c *Client) Close() error { - return c.C.Close() + return c.c.Close() } // Connect sends a Connect message to the Pulsar server, then @@ -121,7 +156,7 @@ func (c *Client) Close() error { // See "Connection establishment" for more info: // https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#Connectionestablishment-6pslvw func (c *Client) Connect(ctx context.Context, proxyBrokerURL string) (*api.CommandConnected, error) { - return c.Connector.Connect(ctx, "", proxyBrokerURL) + return c.connector.Connect(ctx, "", proxyBrokerURL) } // ConnectTLS sends a Connect message to the Pulsar server, then @@ -134,14 +169,14 @@ func (c *Client) Connect(ctx context.Context, proxyBrokerURL string) (*api.Comma // See "Connection establishment" for more info: // https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#Connectionestablishment-6pslvw func (c *Client) ConnectTLS(ctx context.Context, proxyBrokerURL string) (*api.CommandConnected, error) { - return c.Connector.Connect(ctx, utils.AuthMethodTLS, proxyBrokerURL) + return c.connector.Connect(ctx, authMethodTLS, proxyBrokerURL) } // Ping sends a PING message to the Pulsar server, then // waits for either a PONG response or the context to // timeout. func (c *Client) Ping(ctx context.Context) error { - return c.Pinger.Ping(ctx) + return c.pinger.Ping(ctx) } // LookupTopic returns metadata about the given topic. Topic lookup needs @@ -154,13 +189,13 @@ func (c *Client) Ping(ctx context.Context) error { // See "Topic lookup" for more info: // https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#Topiclookup-rxds6i func (c *Client) LookupTopic(ctx context.Context, topic string, authoritative bool) (*api.CommandLookupTopicResponse, error) { - return c.Discoverer.LookupTopic(ctx, topic, authoritative) + return c.discoverer.LookupTopic(ctx, topic, authoritative) } // NewProducer creates a new producer capable of sending message to the // given topic. func (c *Client) NewProducer(ctx context.Context, topic, producerName string) (*pub.Producer, error) { - return c.Pubsub.Producer(ctx, topic, producerName) + return c.pubsub.Producer(ctx, topic, producerName) } // NewSharedConsumer creates a new shared consumer capable of reading messages from the @@ -168,11 +203,11 @@ func (c *Client) NewProducer(ctx context.Context, topic, producerName string) (* // See "Subscription modes" for more information: // https://pulsar.incubator.apache.org/docs/latest/getting-started/ConceptsAndArchitecture/#Subscriptionmodes-jdrefl func (c *Client) NewSharedConsumer(ctx context.Context, topic, subscriptionName string, queue chan msg.Message) (*sub.Consumer, error) { - return c.Pubsub.Subscribe(ctx, topic, subscriptionName, api.CommandSubscribe_Shared, api.CommandSubscribe_Latest, queue) + return c.pubsub.Subscribe(ctx, topic, subscriptionName, api.CommandSubscribe_Shared, api.CommandSubscribe_Latest, queue) } -func (c *Client) NewConsumerWithCfg(ctx context.Context, cfg ConsumerConfig, queue chan msg.Message) (*sub.Consumer, error) { - return c.Pubsub.SubscribeWithCfg(ctx, cfg, queue) +func (c *Client) NewConsumerWithCfg(ctx context.Context, cfg ManagedConsumerConfig, queue chan msg.Message) (*sub.Consumer, error) { + return c.pubsub.SubscribeWithCfg(ctx, cfg, queue) } // NewExclusiveConsumer creates a new exclusive consumer capable of reading messages from the @@ -184,7 +219,7 @@ func (c *Client) NewExclusiveConsumer(ctx context.Context, topic, subscriptionNa if earliest { initialPosition = api.CommandSubscribe_Earliest } - return c.Pubsub.Subscribe(ctx, topic, subscriptionName, api.CommandSubscribe_Exclusive, initialPosition, queue) + return c.pubsub.Subscribe(ctx, topic, subscriptionName, api.CommandSubscribe_Exclusive, initialPosition, queue) } // NewFailoverConsumer creates a new failover consumer capable of reading messages from the @@ -192,7 +227,7 @@ func (c *Client) NewExclusiveConsumer(ctx context.Context, topic, subscriptionNa // See "Subscription modes" for more information: // https://pulsar.incubator.apache.org/docs/latest/getting-started/ConceptsAndArchitecture/#Subscriptionmodes-jdrefl func (c *Client) NewFailoverConsumer(ctx context.Context, topic, subscriptionName string, queue chan msg.Message) (*sub.Consumer, error) { - return c.Pubsub.Subscribe(ctx, topic, subscriptionName, api.CommandSubscribe_Failover, api.CommandSubscribe_Latest, queue) + return c.pubsub.Subscribe(ctx, topic, subscriptionName, api.CommandSubscribe_Failover, api.CommandSubscribe_Latest, queue) } // handleFrame is called by the underlaying core with @@ -207,70 +242,70 @@ func (c *Client) handleFrame(f frame.Frame) { // Solicited responses with NO response ID associated case api.BaseCommand_CONNECTED: - err = c.Dispatcher.NotifyGlobal(f) + err = c.dispatcher.NotifyGlobal(f) case api.BaseCommand_PONG: - err = c.Dispatcher.NotifyGlobal(f) + err = c.dispatcher.NotifyGlobal(f) // Solicited responses with a requestID to correlate // it to its request case api.BaseCommand_SUCCESS: - err = c.Dispatcher.NotifyReqID(f.BaseCmd.GetSuccess().GetRequestId(), f) + err = c.dispatcher.NotifyReqID(f.BaseCmd.GetSuccess().GetRequestId(), f) case api.BaseCommand_ERROR: - err = c.Dispatcher.NotifyReqID(f.BaseCmd.GetError().GetRequestId(), f) + err = c.dispatcher.NotifyReqID(f.BaseCmd.GetError().GetRequestId(), f) case api.BaseCommand_LOOKUP_RESPONSE: - err = c.Dispatcher.NotifyReqID(f.BaseCmd.GetLookupTopicResponse().GetRequestId(), f) + err = c.dispatcher.NotifyReqID(f.BaseCmd.GetLookupTopicResponse().GetRequestId(), f) case api.BaseCommand_PARTITIONED_METADATA_RESPONSE: - err = c.Dispatcher.NotifyReqID(f.BaseCmd.GetPartitionMetadataResponse().GetRequestId(), f) + err = c.dispatcher.NotifyReqID(f.BaseCmd.GetPartitionMetadataResponse().GetRequestId(), f) case api.BaseCommand_PRODUCER_SUCCESS: - err = c.Dispatcher.NotifyReqID(f.BaseCmd.GetProducerSuccess().GetRequestId(), f) + err = c.dispatcher.NotifyReqID(f.BaseCmd.GetProducerSuccess().GetRequestId(), f) // Solicited responses with a (producerID, sequenceID) tuple to correlate // it to its request case api.BaseCommand_SEND_RECEIPT: msg := f.BaseCmd.GetSendReceipt() - err = c.Dispatcher.NotifyProdSeqIDs(msg.GetProducerId(), msg.GetSequenceId(), f) + err = c.dispatcher.NotifyProdSeqIDs(msg.GetProducerId(), msg.GetSequenceId(), f) case api.BaseCommand_SEND_ERROR: msg := f.BaseCmd.GetSendError() - err = c.Dispatcher.NotifyProdSeqIDs(msg.GetProducerId(), msg.GetSequenceId(), f) + err = c.dispatcher.NotifyProdSeqIDs(msg.GetProducerId(), msg.GetSequenceId(), f) // Unsolicited responses that have a producer ID case api.BaseCommand_CLOSE_PRODUCER: - err = c.Subscriptions.HandleCloseProducer(f.BaseCmd.GetCloseProducer().GetProducerId(), f) + err = c.subscriptions.HandleCloseProducer(f.BaseCmd.GetCloseProducer().GetProducerId(), f) // Unsolicited responses that have a consumer ID case api.BaseCommand_CLOSE_CONSUMER: - err = c.Subscriptions.HandleCloseConsumer(f.BaseCmd.GetCloseConsumer().GetConsumerId(), f) + err = c.subscriptions.HandleCloseConsumer(f.BaseCmd.GetCloseConsumer().GetConsumerId(), f) case api.BaseCommand_REACHED_END_OF_TOPIC: - err = c.Subscriptions.HandleReachedEndOfTopic(f.BaseCmd.GetReachedEndOfTopic().GetConsumerId(), f) + err = c.subscriptions.HandleReachedEndOfTopic(f.BaseCmd.GetReachedEndOfTopic().GetConsumerId(), f) case api.BaseCommand_MESSAGE: - err = c.Subscriptions.HandleMessage(f.BaseCmd.GetMessage().GetConsumerId(), f) + err = c.subscriptions.HandleMessage(f.BaseCmd.GetMessage().GetConsumerId(), f) // Unsolicited responses case api.BaseCommand_PING: - err = c.Pinger.HandlePing(msgType, f.BaseCmd.GetPing()) + err = c.pinger.HandlePing(msgType, f.BaseCmd.GetPing()) // In the failover subscription mode, // all consumers receive ACTIVE_CONSUMER_CHANGE when a new subscriber is created or a subscriber exits. case api.BaseCommand_ACTIVE_CONSUMER_CHANGE: - err = c.Subscriptions.HandleActiveConsumerChange(f.BaseCmd.GetActiveConsumerChange().GetConsumerId(), f) + err = c.subscriptions.HandleActiveConsumerChange(f.BaseCmd.GetActiveConsumerChange().GetConsumerId(), f) default: err = fmt.Errorf("unhandled message of type %q", f.BaseCmd.GetType()) } if err != nil { - c.AsyncErrs.Send(err) + c.asyncErrs.Send(err) } } diff --git a/core/manage/client_test.go b/core/manage/client_test.go index c463273..54724ed 100644 --- a/core/manage/client_test.go +++ b/core/manage/client_test.go @@ -62,7 +62,7 @@ func TestClient_Int_PubSub(t *testing.T) { t.Logf("PONG received") } - topic := fmt.Sprintf("persistent://sample/standalone/ns1/test-%s", utils.RandString(32)) + topic := fmt.Sprintf("persistent://sample/standalone/ns1/test-%s", RandString(32)) t.Logf("test topic: %q", topic) topicResp, err := c.LookupTopic(ctx, topic, false) @@ -83,7 +83,7 @@ func TestClient_Int_PubSub(t *testing.T) { // create multiple consumers consumers := make([]*sub.Consumer, 32) - subName := utils.RandString(16) + subName := RandString(16) for i := range consumers { name := fmt.Sprintf("%s-%d", subName, i) consumers[i], err = c.NewExclusiveConsumer(ctx, topic, name, false, make(chan msg.Message, N)) @@ -208,7 +208,7 @@ func TestClient_Int_ServerInitiatedTopicClose(t *testing.T) { t.Logf("PONG received") } - topicName := fmt.Sprintf("test-%s", utils.RandString(32)) + topicName := fmt.Sprintf("test-%s", RandString(32)) topic := fmt.Sprintf("persistent://sample/standalone/ns1/%s", topicName) t.Logf("topic: %q", topic) @@ -221,13 +221,13 @@ func TestClient_Int_ServerInitiatedTopicClose(t *testing.T) { } t.Log(topicResp.String()) - subscriptionName := utils.RandString(32) + subscriptionName := RandString(32) topicConsumer, err := c.NewExclusiveConsumer(ctx, topic, subscriptionName, false, make(chan msg.Message, 1)) if err != nil { t.Fatal(err) } - producerName := utils.RandString(32) + producerName := RandString(32) topicProducer, err := c.NewProducer(ctx, topic, producerName) if err != nil { t.Fatal(err) @@ -327,7 +327,7 @@ func TestClient_Int_Unsubscribe(t *testing.T) { t.Logf("PONG received") } - topic := fmt.Sprintf("persistent://sample/standalone/ns1/test-%s", utils.RandString(32)) + topic := fmt.Sprintf("persistent://sample/standalone/ns1/test-%s", RandString(32)) t.Logf("test topic: %q", topic) topicResp, err := c.LookupTopic(ctx, topic, false) @@ -351,7 +351,7 @@ func TestClient_Int_Unsubscribe(t *testing.T) { } t.Log(topicResp.String()) - topicConsumer, err := c.NewExclusiveConsumer(ctx, topic, utils.RandString(32), false, make(chan msg.Message, 1)) + topicConsumer, err := c.NewExclusiveConsumer(ctx, topic, RandString(32), false, make(chan msg.Message, 1)) if err != nil { t.Fatal(err) } @@ -400,7 +400,7 @@ func TestClient_Int_RedeliverOverflow(t *testing.T) { t.Logf("PONG received") } - topic := fmt.Sprintf("persistent://sample/standalone/ns1/test-%s", utils.RandString(32)) + topic := fmt.Sprintf("persistent://sample/standalone/ns1/test-%s", RandString(32)) t.Logf("test topic: %q", topic) topicResp, err := c.LookupTopic(ctx, topic, false) @@ -426,7 +426,7 @@ func TestClient_Int_RedeliverOverflow(t *testing.T) { } // create single consumer with buffer size 1 - cs, err := c.NewSharedConsumer(ctx, topic, utils.RandString(16), make(chan msg.Message, 1)) + cs, err := c.NewSharedConsumer(ctx, topic, RandString(16), make(chan msg.Message, 1)) if err != nil { t.Fatal(err) } @@ -532,7 +532,7 @@ func TestClient_Int_RedeliverAll(t *testing.T) { t.Logf("PONG received") } - topic := fmt.Sprintf("persistent://sample/standalone/ns1/test-%s", utils.RandString(32)) + topic := fmt.Sprintf("persistent://sample/standalone/ns1/test-%s", RandString(32)) t.Logf("test topic: %q", topic) topicResp, err := c.LookupTopic(ctx, topic, false) @@ -558,7 +558,7 @@ func TestClient_Int_RedeliverAll(t *testing.T) { } // create single consumer with buffer size N - cs, err := c.NewExclusiveConsumer(ctx, topic, utils.RandString(16), false, make(chan msg.Message, N)) + cs, err := c.NewExclusiveConsumer(ctx, topic, RandString(16), false, make(chan msg.Message, N)) if err != nil { t.Fatal(err) } diff --git a/core/manage/managed_client.go b/core/manage/managed_client.go index a04a179..9eccdc1 100644 --- a/core/manage/managed_client.go +++ b/core/manage/managed_client.go @@ -15,7 +15,6 @@ package manage import ( "context" - "crypto/tls" "errors" "sync" "time" @@ -23,13 +22,9 @@ import ( "github.com/wolfstudy/pulsar-client-go/utils" ) -// ClientConfig is used to configure a Pulsar client. -type ClientConfig struct { - Addr string // pulsar broker address. May start with pulsar:// - phyAddr string // if set, the TCP connection should be made using this address. This is only ever set during Topic Lookup - DialTimeout time.Duration // timeout to use when establishing TCP connection - TLSConfig *tls.Config // TLS configuration. May be nil, in which case TLS will not be used - Errs chan<- error // asynchronous errors will be sent here. May be nil +// ManagedClientConfig is used to configure a ManagedClient. +type ManagedClientConfig struct { + ClientConfig PingFrequency time.Duration // how often to PING server PingTimeout time.Duration // how long to wait for PONG response @@ -38,49 +33,29 @@ type ClientConfig struct { MaxReconnectDelay time.Duration // maximum time to wait to attempt to reconnect Client } -// ConnAddr returns the address that should be used -// for the TCP connection. It defaults to phyAddr if set, -// otherwise Addr. This is to support the proxying through -// a broker, as determined during topic lookup. -func (c ClientConfig) ConnAddr() string { - if c.phyAddr != "" { - return c.phyAddr - } - return c.Addr -} - -// setDefaults returns a modified config with appropriate zero values set to defaults. -func (c ClientConfig) SetDefaults() ClientConfig { - if c.DialTimeout <= 0 { - c.DialTimeout = 5 * time.Second - } - - return c -} - // setDefaults returns a modified config with appropriate zero values set to defaults. -func (c ClientConfig) setDefaults() ClientConfig { - if c.PingFrequency <= 0 { - c.PingFrequency = 30 * time.Second // default used by Java client +func (m ManagedClientConfig) setDefaults() ManagedClientConfig { + if m.PingFrequency <= 0 { + m.PingFrequency = 30 * time.Second // default used by Java client } - if c.PingTimeout <= 0 { - c.PingTimeout = c.PingFrequency / 2 + if m.PingTimeout <= 0 { + m.PingTimeout = m.PingFrequency / 2 } - if c.ConnectTimeout <= 0 { - c.ConnectTimeout = 5 * time.Second + if m.ConnectTimeout <= 0 { + m.ConnectTimeout = 5 * time.Second } - if c.InitialReconnectDelay <= 0 { - c.InitialReconnectDelay = 1 * time.Second + if m.InitialReconnectDelay <= 0 { + m.InitialReconnectDelay = 1 * time.Second } - if c.MaxReconnectDelay <= 0 { - c.MaxReconnectDelay = 2 * time.Minute + if m.MaxReconnectDelay <= 0 { + m.MaxReconnectDelay = 2 * time.Minute } - return c + return m } // NewManagedClient returns a ManagedClient for the given address. The // Client will be created and monitored in the background. -func NewManagedClient(cfg ClientConfig) *ManagedClient { +func NewManagedClient(cfg ManagedClientConfig) *ManagedClient { cfg = cfg.setDefaults() m := ManagedClient{ @@ -101,7 +76,7 @@ func NewManagedClient(cfg ClientConfig) *ManagedClient { // ManagedClient wraps a Client with re-connect and // connection management logic. type ManagedClient struct { - cfg ClientConfig + cfg ManagedClientConfig asyncErrs utils.AsyncErrors @@ -197,7 +172,7 @@ func (m *ManagedClient) unset() { // newClient attempts to create a Client and perform a Connect request. func (m *ManagedClient) newClient(ctx context.Context) (*Client, error) { - client, err := NewClient(m.cfg) + client, err := NewClient(m.cfg.ClientConfig) if err != nil { return nil, err } diff --git a/core/manage/managed_client_pool.go b/core/manage/managed_client_pool.go index 38e3724..d2aac21 100644 --- a/core/manage/managed_client_pool.go +++ b/core/manage/managed_client_pool.go @@ -54,7 +54,7 @@ type clientPoolKey struct { // Get returns the ManagedClient for the given client configuration. // First the cache is checked for an existing client. If one doesn't exist, // a new one is created and cached, then returned. -func (m *ClientPool) Get(cfg ClientConfig) *ManagedClient { +func (m *ClientPool) Get(cfg ManagedClientConfig) *ManagedClient { key := clientPoolKey{ logicalAddr: strings.TrimPrefix(cfg.Addr, "pulsar://"), dialTimeout: cfg.DialTimeout, @@ -107,7 +107,7 @@ const maxTopicLookupRedirects = 8 // the ManagedClient for the discovered topic information. // https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#Topiclookup-6g0lo // incubator-pulsar/pulsar-client/src/main/java/org/apache/pulsar/client/impl/BinaryProtoLookupService.java -func (m *ClientPool) ForTopic(ctx context.Context, cfg ClientConfig, topic string) (*ManagedClient, error) { +func (m *ClientPool) ForTopic(ctx context.Context, cfg ManagedClientConfig, topic string) (*ManagedClient, error) { // For initial lookup request, authoritative should == false var authoritative bool serviceAddr := cfg.Addr diff --git a/core/manage/managed_client_pool_test.go b/core/manage/managed_client_pool_test.go index 31a8962..b7369da 100644 --- a/core/manage/managed_client_pool_test.go +++ b/core/manage/managed_client_pool_test.go @@ -32,9 +32,9 @@ func TestManagedClientPool(t *testing.T) { } mcp := NewClientPool() - mc := mcp.Get(ClientConfig{ + mc := mcp.Get(ManagedClientConfig{ClientConfig: ClientConfig{ Addr: srv.Addr, - }) + }}) expectedFrames := []api.BaseCommand_Type{ api.BaseCommand_CONNECT, @@ -49,9 +49,9 @@ func TestManagedClientPool(t *testing.T) { } // Assert additional call to Get returns same object - mc2 := mcp.Get(ClientConfig{ + mc2 := mcp.Get(ManagedClientConfig{ClientConfig: ClientConfig{ Addr: srv.Addr, - }) + }}) if mc != mc2 { t.Fatalf("Get() returned %v; expected identical result from first call to Get() %v", mc2, mc) } @@ -67,9 +67,9 @@ func TestManagedClientPool_Stop(t *testing.T) { } mcp := NewClientPool() - mc := mcp.Get(ClientConfig{ + mc := mcp.Get(ManagedClientConfig{ClientConfig: ClientConfig{ Addr: srv.Addr, - }) + }}) expectedFrames := []api.BaseCommand_Type{ api.BaseCommand_CONNECT, @@ -83,9 +83,9 @@ func TestManagedClientPool_Stop(t *testing.T) { // from the pool time.Sleep(200 * time.Millisecond) - mc2 := mcp.Get(ClientConfig{ + mc2 := mcp.Get(ManagedClientConfig{ClientConfig: ClientConfig{ Addr: srv.Addr, - }) + }}) if mc == mc2 { t.Fatal("Get() returned same ManagedClient as previous; expected different object after calling ManagedClient.Stop()") } @@ -111,14 +111,14 @@ func TestManagedClientPool_ForTopic(t *testing.T) { topic := "test" cp := NewClientPool() - mc, err := cp.ForTopic(ctx, ClientConfig{ + mc, err := cp.ForTopic(ctx, ManagedClientConfig{ClientConfig: ClientConfig{ Addr: primarySrv.Addr, - }, topic) + }}, topic) if err != nil { t.Fatalf("ForTopic() err = %v; expected nil", err) } - if got, expected := mc.cfg.ConnAddr(), primarySrv.Addr; got != expected { + if got, expected := mc.cfg.connAddr(), primarySrv.Addr; got != expected { t.Fatalf("ManagedClient address = %q; expected %q", got, expected) } else { t.Logf("ManagedClient address = %q", got) @@ -139,9 +139,9 @@ func TestManagedClientPool_ForTopic_Failed(t *testing.T) { primarySrv.SetTopicLookupResp(topic, primarySrv.Addr, api.CommandLookupTopicResponse_Failed, false) cp := NewClientPool() - _, err = cp.ForTopic(ctx, ClientConfig{ + _, err = cp.ForTopic(ctx, ManagedClientConfig{ClientConfig: ClientConfig{ Addr: primarySrv.Addr, - }, topic) + }}, topic) if err == nil { t.Fatalf("ForTopic() err = %v; expected non-nil", err) @@ -163,15 +163,15 @@ func TestManagedClientPool_ForTopic_Proxy(t *testing.T) { primarySrv.SetTopicLookupResp(topic, brokerURL, api.CommandLookupTopicResponse_Connect, true) cp := NewClientPool() - mc, err := cp.ForTopic(ctx, ClientConfig{ + mc, err := cp.ForTopic(ctx, ManagedClientConfig{ClientConfig: ClientConfig{ Addr: primarySrv.Addr, - }, topic) + }}, topic) if err != nil { t.Fatalf("ForTopic() err = %v; expected nil", err) } // original broker addr should be used as physical address - if got, expected := mc.cfg.ConnAddr(), primarySrv.Addr; got != expected { + if got, expected := mc.cfg.connAddr(), primarySrv.Addr; got != expected { t.Fatalf("ManagedClient address = %q; expected %q", got, expected) } else { t.Logf("ManagedClient address = %q", got) @@ -209,14 +209,14 @@ func TestManagedClientPool_ForTopic_Connect(t *testing.T) { cp := NewClientPool() - mc, err := cp.ForTopic(ctx, ClientConfig{ + mc, err := cp.ForTopic(ctx, ManagedClientConfig{ClientConfig: ClientConfig{ Addr: primarySrv.Addr, - }, topic) + }}, topic) if err != nil { t.Fatalf("ForTopic() err = %v; expected nil", err) } - if got, expected := mc.cfg.ConnAddr(), topicSrv.Addr; got != expected { + if got, expected := mc.cfg.connAddr(), topicSrv.Addr; got != expected { t.Fatalf("ManagedClient address = %q; expected %q", got, expected) } else { t.Logf("redirected ManagedClient address = %q", got) @@ -247,14 +247,14 @@ func TestManagedClientPool_ForTopic_Redirect(t *testing.T) { cp := NewClientPool() - mc, err := cp.ForTopic(ctx, ClientConfig{ + mc, err := cp.ForTopic(ctx, ManagedClientConfig{ClientConfig: ClientConfig{ Addr: primarySrv.Addr, - }, topic) + }}, topic) if err != nil { t.Fatalf("ForTopic() err = %v; expected nil", err) } - if got, expected := mc.cfg.ConnAddr(), topicSrv.Addr; got != expected { + if got, expected := mc.cfg.connAddr(), topicSrv.Addr; got != expected { t.Fatalf("ManagedClient address = %q; expected %q", got, expected) } else { t.Logf("redirected ManagedClient address = %q", got) @@ -286,9 +286,9 @@ func TestManagedClientPool_ForTopic_RedirectLoop(t *testing.T) { cp := NewClientPool() - _, err = cp.ForTopic(ctx, ClientConfig{ + _, err = cp.ForTopic(ctx, ManagedClientConfig{ClientConfig: ClientConfig{ Addr: primarySrv.Addr, - }, topic) + }}, topic) if err == nil { t.Fatalf("ForTopic() err = %v; expected non-nil", err) diff --git a/core/manage/managed_client_test.go b/core/manage/managed_client_test.go index 42dca51..ae78efd 100644 --- a/core/manage/managed_client_test.go +++ b/core/manage/managed_client_test.go @@ -31,9 +31,9 @@ func TestManagedClient(t *testing.T) { t.Fatal(err) } - mc := NewManagedClient(ClientConfig{ + mc := NewManagedClient(ManagedClientConfig{ClientConfig: ClientConfig{ Addr: srv.Addr, - }) + }}) defer mc.Stop() expectedFrames := []api.BaseCommand_Type{ @@ -64,9 +64,9 @@ func TestManagedClient_SrvClosed(t *testing.T) { t.Fatal(err) } - mc := NewManagedClient(ClientConfig{ + mc := NewManagedClient(ManagedClientConfig{ClientConfig: ClientConfig{ Addr: srv.Addr, - }) + }}) defer mc.Stop() // repeatedly close the connection from the server's end; @@ -108,8 +108,10 @@ func TestManagedClient_PingFailure(t *testing.T) { srv.SetIgnorePings(true) // Set client to ping every 1/2 second - mc := NewManagedClient(ClientConfig{ - Addr: srv.Addr, + mc := NewManagedClient(ManagedClientConfig{ + ClientConfig: ClientConfig{ + Addr: srv.Addr, + }, PingFrequency: 500 * time.Millisecond, }) defer mc.Stop() @@ -142,8 +144,10 @@ func TestManagedClient_ConnectFailure(t *testing.T) { // then later enable them srv.SetIgnoreConnects(true) - mc := NewManagedClient(ClientConfig{ - Addr: srv.Addr, + mc := NewManagedClient(ManagedClientConfig{ + ClientConfig: ClientConfig{ + Addr: srv.Addr, + }, ConnectTimeout: time.Second, // shorten connect timeout since no CONNECT response is expected }) defer mc.Stop() @@ -180,9 +184,9 @@ func TestManagedClient_Stop(t *testing.T) { t.Fatal(err) } - mc := NewManagedClient(ClientConfig{ + mc := NewManagedClient(ManagedClientConfig{ClientConfig: ClientConfig{ Addr: srv.Addr, - }) + }}) defer mc.Stop() // wait for the client to connect diff --git a/core/manage/managed_consumer.go b/core/manage/managed_consumer.go index 9fd7ea8..0c0fe28 100644 --- a/core/manage/managed_consumer.go +++ b/core/manage/managed_consumer.go @@ -41,7 +41,7 @@ const ( // and each message is only distributed to one consumer. // When the consumer disconnects, all messages sent to him // but not confirmed will be rescheduled and distributed to other surviving consumers. - SubscriptionModeShard // 2 + SubscriptionModeShard // 2 // SubscriptionModeFailover multiple consumers can be bound to the same subscription. // Consumers will be sorted in lexicographic order, @@ -49,17 +49,17 @@ const ( // This consumer is called the master consumer. // When the master consumer is disconnected, // all messages (unconfirmed and subsequently entered) will be distributed to the next consumer in the queue. - SubscriptionModeFailover // 3 + SubscriptionModeFailover // 3 - SubscriptionModeKeyShared //4 + SubscriptionModeKeyShared //4 ) // ErrorInvalidSubMode When SubscriptionMode is not one of SubscriptionModeExclusive, SubscriptionModeShard, SubscriptionModeFailover var ErrorInvalidSubMode = errors.New("invalid subscription mode") -// ConsumerConfig is used to configure a ManagedConsumer. -type ConsumerConfig struct { - ClientConfig +// ManagedConsumerConfig is used to configure a ManagedConsumer. +type ManagedConsumerConfig struct { + ManagedClientConfig Topic string Name string // subscription name @@ -74,8 +74,8 @@ type ConsumerConfig struct { AckTimeoutMillis time.Duration } -// SetDefaults returns a modified config with appropriate zero values set to defaults. -func (m ConsumerConfig) SetDefaults() ConsumerConfig { +// setDefaults returns a modified config with appropriate zero values set to defaults. +func (m ManagedConsumerConfig) setDefaults() ManagedConsumerConfig { if m.NewConsumerTimeout <= 0 { m.NewConsumerTimeout = 5 * time.Second } @@ -95,8 +95,8 @@ func (m ConsumerConfig) SetDefaults() ConsumerConfig { // NewManagedConsumer returns an initialized ManagedConsumer. It will create and recreate // a Consumer for the given discovery address and topic on a background goroutine. -func NewManagedConsumer(cp *ClientPool, cfg ConsumerConfig) *ManagedConsumer { - cfg = cfg.SetDefaults() +func NewManagedConsumer(cp *ClientPool, cfg ManagedConsumerConfig) *ManagedConsumer { + cfg = cfg.setDefaults() m := &ManagedConsumer{ clientPool: cp, @@ -108,9 +108,9 @@ func NewManagedConsumer(cp *ClientPool, cfg ConsumerConfig) *ManagedConsumer { if cfg.SubMode == SubscriptionModeShard || cfg.SubMode == SubscriptionModeKeyShared { //TODO:end with `-partition-d%` if !strings.Contains(cfg.Topic, "-partition-") && cfg.AckTimeoutMillis != 0 { - m.UnAckTracker = NewUnackedMessageTracker() - m.UnAckTracker.consumer = m - m.UnAckTracker.Start(int64(cfg.AckTimeoutMillis)) + m.unAckTracker = NewUnackedMessageTracker() + m.unAckTracker.consumer = m + m.unAckTracker.Start(int64(cfg.AckTimeoutMillis)) } } @@ -119,84 +119,10 @@ func NewManagedConsumer(cp *ClientPool, cfg ConsumerConfig) *ManagedConsumer { return m } -// NewManagedConsumer returns an initialized ManagedConsumer. It will create and recreate -// a Consumer for the given discovery address and topic on a background goroutine. -func NewPartitionManagedConsumer(cp *ClientPool, cfg ConsumerConfig) (*ManagedPartitionConsumer, error) { - cfg = cfg.SetDefaults() - ctx := context.Background() - - mpc := &ManagedPartitionConsumer{ - clientPool: cp, - cfg: cfg, - asyncErrs: utils.AsyncErrors(cfg.Errs), - queue: make(chan msg.Message, cfg.QueueSize), - MConsumer: make([]*ManagedConsumer, 0), - } - - if cfg.SubMode == SubscriptionModeShard || cfg.SubMode == SubscriptionModeKeyShared { - if cfg.AckTimeoutMillis != 0 { - mpc.UnAckTracker = NewUnackedMessageTracker() - mpc.UnAckTracker.partitionConsumer = mpc - mpc.UnAckTracker.Start(int64(cfg.AckTimeoutMillis)) - } - } - - manageClient := cp.Get(cfg.ClientConfig) - - client, err := manageClient.Get(ctx) - if err != nil { - log.Errorf("create client error:%s", err.Error()) - return nil, err - } - - res, err := client.Discoverer.PartitionedMetadata(ctx, cfg.Topic) - if err != nil { - log.Errorf("get partition metadata error:%s", err.Error()) - return nil, err - } - numPartitions := res.GetPartitions() - topicName := cfg.Topic - for i := 0; uint32(i) < numPartitions; i++ { - cfg.Topic = fmt.Sprintf("%s-partition-%d", topicName, i) - mpc.MConsumer = append(mpc.MConsumer, NewManagedConsumer(cp, cfg)) - } - - go mpc.getMessageFromSubConsumer(ctx) - - return mpc, nil -} - -type ManagedPartitionConsumer struct { - clientPool *ClientPool - cfg ConsumerConfig - asyncErrs utils.AsyncErrors - - queue chan msg.Message - - mu sync.RWMutex // protects following - MConsumer []*ManagedConsumer - UnAckTracker *UnackedMessageTracker -} - -func (mpc *ManagedPartitionConsumer) getMessageFromSubConsumer(ctx context.Context) chan msg.Message { - for i := 0; i < len(mpc.MConsumer); i++ { - go func(index int) { - log.Infof("receive message form index:%d", index) - err := mpc.MConsumer[index].ReceiveAsync(ctx, mpc.queue) - if err != nil { - log.Errorf("receive message error:%s", err.Error()) - return - } - }(i) - } - return mpc.queue - -} - // ManagedConsumer wraps a Consumer with reconnect logic. type ManagedConsumer struct { clientPool *ClientPool - cfg ConsumerConfig + cfg ManagedConsumerConfig asyncErrs utils.AsyncErrors queue chan msg.Message @@ -204,36 +130,7 @@ type ManagedConsumer struct { mu sync.RWMutex // protects following consumer *sub.Consumer // either consumer is nil and wait isn't or vice versa waitc chan struct{} // if consumer is nil, this will unblock when it's been re-set - UnAckTracker *UnackedMessageTracker -} - -func (mpc *ManagedPartitionConsumer) Receive(ctx context.Context) (msg.Message, error) { - for { - select { - case tmpMsg, ok := <-mpc.queue: - if ok { - if mpc.UnAckTracker != nil { - log.Debugf("receive add untrack: key: %s", tmpMsg.Meta.GetPartitionKey(), string(tmpMsg.Payload)) - mpc.UnAckTracker.Add(tmpMsg.Msg.GetMessageId()) - } - return tmpMsg, nil - } - - case <-ctx.Done(): - return msg.Message{}, ctx.Err() - } - } -} - -// Ack acquires a consumer and Sends an ACK message for the given message. -func (mpc *ManagedPartitionConsumer) Ack(ctx context.Context, msg msg.Message) error { - - if mpc.UnAckTracker != nil { - log.Debugf("ack remove untrack: key: %s", msg.Meta.GetPartitionKey(), string(msg.Payload)) - mpc.UnAckTracker.Remove(msg.Msg.GetMessageId()) - } - - return mpc.MConsumer[msg.Msg.GetMessageId().GetPartition()].Ack(ctx, msg) + unAckTracker *UnackedMessageTracker } // Ack acquires a consumer and Sends an ACK message for the given message. @@ -254,8 +151,8 @@ func (m *ManagedConsumer) Ack(ctx context.Context, msg msg.Message) error { return ctx.Err() } } - if m.UnAckTracker != nil { - m.UnAckTracker.Remove(msg.Msg.GetMessageId()) + if m.unAckTracker != nil { + m.unAckTracker.Remove(msg.Msg.GetMessageId()) } return consumer.Ack(msg) } @@ -292,8 +189,8 @@ func (m *ManagedConsumer) Receive(ctx context.Context) (msg.Message, error) { select { case tmpMsg, ok := <-m.queue: if ok { - if m.UnAckTracker != nil { - m.UnAckTracker.Add(tmpMsg.Msg.GetMessageId()) + if m.unAckTracker != nil { + m.unAckTracker.Add(tmpMsg.Msg.GetMessageId()) } return tmpMsg, nil } @@ -368,8 +265,8 @@ CONSUMER: case tmpMsg := <-m.queue: msgs <- tmpMsg - if m.UnAckTracker != nil { - m.UnAckTracker.Add(tmpMsg.Msg.GetMessageId()) + if m.unAckTracker != nil { + m.unAckTracker.Add(tmpMsg.Msg.GetMessageId()) } if receivedSinceFlow++; receivedSinceFlow >= highwater { @@ -429,7 +326,7 @@ func (m *ManagedConsumer) unset() { // newConsumer attempts to create a Consumer. func (m *ManagedConsumer) newConsumer(ctx context.Context) (*sub.Consumer, error) { - mc, err := m.clientPool.ForTopic(ctx, m.cfg.ClientConfig, m.cfg.Topic) + mc, err := m.clientPool.ForTopic(ctx, m.cfg.ManagedClientConfig, m.cfg.Topic) if err != nil { return nil, err } @@ -534,8 +431,8 @@ func (m *ManagedConsumer) RedeliverUnacknowledged(ctx context.Context) error { return ctx.Err() } } - if m.UnAckTracker != nil { - m.UnAckTracker.clear() + if m.unAckTracker != nil { + m.unAckTracker.clear() } return consumer.RedeliverUnacknowledged(ctx) } @@ -585,20 +482,13 @@ func (m *ManagedConsumer) Unsubscribe(ctx context.Context) error { return ctx.Err() } } - if m.UnAckTracker != nil { - m.UnAckTracker.Stop() + if m.unAckTracker != nil { + m.unAckTracker.Stop() } return consumer.Unsubscribe(ctx) } } -func (mpc *ManagedPartitionConsumer) Unsubscribe(ctx context.Context) error { - for _, consumer := range mpc.MConsumer { - return consumer.Unsubscribe(ctx) - } - return nil -} - // Monitor a scoped deferrable lock func (m *ManagedConsumer) Monitor() func() { m.mu.Lock() @@ -608,18 +498,128 @@ func (m *ManagedConsumer) Monitor() func() { // Close consumer func (m *ManagedConsumer) Close(ctx context.Context) error { defer m.Monitor()() - if m.UnAckTracker != nil { - m.UnAckTracker.Stop() + if m.unAckTracker != nil { + m.unAckTracker.Stop() } return m.consumer.Close(ctx) } +// NewManagedPartitionConsumer returns an initialized ManagedPartitionConsumer. It will create and recreate +// a Consumer for the given discovery address and topic on a background goroutine. +func NewManagedPartitionConsumer(cp *ClientPool, cfg ManagedConsumerConfig) (*ManagedPartitionConsumer, error) { + cfg = cfg.setDefaults() + ctx := context.Background() + + mpc := &ManagedPartitionConsumer{ + clientPool: cp, + cfg: cfg, + asyncErrs: utils.AsyncErrors(cfg.Errs), + queue: make(chan msg.Message, cfg.QueueSize), + managedConsumers: make([]*ManagedConsumer, 0), + } + + if cfg.SubMode == SubscriptionModeShard || cfg.SubMode == SubscriptionModeKeyShared { + if cfg.AckTimeoutMillis != 0 { + mpc.unAckTracker = NewUnackedMessageTracker() + mpc.unAckTracker.partitionConsumer = mpc + mpc.unAckTracker.Start(int64(cfg.AckTimeoutMillis)) + } + } + + managedClient := cp.Get(cfg.ManagedClientConfig) + + client, err := managedClient.Get(ctx) + if err != nil { + log.Errorf("create client error:%s", err.Error()) + return nil, err + } + + res, err := client.discoverer.PartitionedMetadata(ctx, cfg.Topic) + if err != nil { + log.Errorf("get partition metadata error:%s", err.Error()) + return nil, err + } + numPartitions := res.GetPartitions() + topicName := cfg.Topic + for i := 0; uint32(i) < numPartitions; i++ { + cfg.Topic = fmt.Sprintf("%s-partition-%d", topicName, i) + mpc.managedConsumers = append(mpc.managedConsumers, NewManagedConsumer(cp, cfg)) + } + + go mpc.getMessageFromSubConsumer(ctx) + + return mpc, nil +} + +type ManagedPartitionConsumer struct { + clientPool *ClientPool + cfg ManagedConsumerConfig + asyncErrs utils.AsyncErrors + + queue chan msg.Message + + mu sync.RWMutex // protects following + managedConsumers []*ManagedConsumer + unAckTracker *UnackedMessageTracker +} + +func (mpc *ManagedPartitionConsumer) getMessageFromSubConsumer(ctx context.Context) chan msg.Message { + for i := 0; i < len(mpc.managedConsumers); i++ { + go func(index int) { + log.Infof("receive message form index:%d", index) + err := mpc.managedConsumers[index].ReceiveAsync(ctx, mpc.queue) + if err != nil { + log.Errorf("receive message error:%s", err.Error()) + return + } + }(i) + } + return mpc.queue + +} + +func (mpc *ManagedPartitionConsumer) Receive(ctx context.Context) (msg.Message, error) { + for { + select { + case tmpMsg, ok := <-mpc.queue: + if ok { + if mpc.unAckTracker != nil { + log.Debugf("receive add untrack: key: %s %s", tmpMsg.Meta.GetPartitionKey(), string(tmpMsg.Payload)) + mpc.unAckTracker.Add(tmpMsg.Msg.GetMessageId()) + } + return tmpMsg, nil + } + + case <-ctx.Done(): + return msg.Message{}, ctx.Err() + } + } +} + +// Ack acquires a consumer and Sends an ACK message for the given message. +func (mpc *ManagedPartitionConsumer) Ack(ctx context.Context, msg msg.Message) error { + + if mpc.unAckTracker != nil { + log.Debugf("ack remove untrack: key: %s %s", msg.Meta.GetPartitionKey(), string(msg.Payload)) + mpc.unAckTracker.Remove(msg.Msg.GetMessageId()) + } + + return mpc.managedConsumers[msg.Msg.GetMessageId().GetPartition()].Ack(ctx, msg) +} + +func (mpc *ManagedPartitionConsumer) Unsubscribe(ctx context.Context) error { + for _, consumer := range mpc.managedConsumers { + return consumer.Unsubscribe(ctx) + } + return nil +} + // Close consumer func (mpc *ManagedPartitionConsumer) Close(ctx context.Context) error { var errMsg string - for _, consumer := range mpc.MConsumer { - if mpc.UnAckTracker != nil { - mpc.UnAckTracker.Stop() + for _, consumer := range mpc.managedConsumers { + if mpc.unAckTracker != nil { + mpc.unAckTracker.Stop() } if err := consumer.Close(ctx); err != nil { errMsg += fmt.Sprintf("topic %s, name %s: %s ", consumer.cfg.Topic, consumer.cfg.Name, err.Error()) diff --git a/core/manage/managed_consumer_integration_test.go b/core/manage/managed_consumer_integration_test.go index cc5a77f..17e1fa7 100644 --- a/core/manage/managed_consumer_integration_test.go +++ b/core/manage/managed_consumer_integration_test.go @@ -35,7 +35,7 @@ func TestManagedConsumer_Int_ReceiveAsync(t *testing.T) { } }() - topic := fmt.Sprintf("persistent://sample/standalone/ns1/test-%s", utils.RandString(32)) + topic := fmt.Sprintf("persistent://sample/standalone/ns1/test-%s", RandString(32)) cp := NewClientPool() mcCfg := ClientConfig{ Addr: utils.PulsarAddr(t), @@ -45,21 +45,25 @@ func TestManagedConsumer_Int_ReceiveAsync(t *testing.T) { messages := make(chan msg.Message, 16) errs := make(chan error, 1) - consumerCfg := ConsumerConfig{ - ClientConfig: mcCfg, - Name: utils.RandString(8), - Topic: topic, - QueueSize: 128, + consumerCfg := ManagedConsumerConfig{ + ManagedClientConfig: ManagedClientConfig{ + ClientConfig: mcCfg, + }, + Name: RandString(8), + Topic: topic, + QueueSize: 128, } mc := NewManagedConsumer(cp, consumerCfg) go func() { errs <- mc.ReceiveAsync(ctx, messages) }() - producerCfg := ProducerConfig{ - ClientConfig: mcCfg, - Name: utils.RandString(8), - Topic: topic, + producerCfg := ManagedProducerConfig{ + ManagedClientConfig: ManagedClientConfig{ + ClientConfig: mcCfg, + }, + Name: RandString(8), + Topic: topic, } mp := NewManagedProducer(cp, producerCfg) @@ -76,7 +80,7 @@ func TestManagedConsumer_Int_ReceiveAsync(t *testing.T) { MORE: for _, msg := range expected { for { - if _, err := mp.Send(ctx, []byte(msg),""); err != nil { + if _, err := mp.Send(ctx, []byte(msg), ""); err != nil { continue } continue MORE @@ -133,7 +137,7 @@ func TestManagedConsumer_Int_ReceiveAsync_Multiple(t *testing.T) { } }() - topic := fmt.Sprintf("persistent://sample/standalone/ns1/test-%s", utils.RandString(32)) + topic := fmt.Sprintf("persistent://sample/standalone/ns1/test-%s", RandString(32)) cp := NewClientPool() mcCfg := ClientConfig{ Addr: utils.PulsarAddr(t), @@ -143,16 +147,18 @@ func TestManagedConsumer_Int_ReceiveAsync_Multiple(t *testing.T) { messages := make(chan msg.Message, 16) errs := make(chan error, 1) consumers := make([]*ManagedConsumer, 8) - consumerName := utils.RandString(8) + consumerName := RandString(8) // Create multiple managed consumers. All will // use the same messages channel. for i := range consumers { - consumerCfg := ConsumerConfig{ - ClientConfig: mcCfg, - Name: consumerName, - Topic: topic, - QueueSize: 128, + consumerCfg := ManagedConsumerConfig{ + ManagedClientConfig: ManagedClientConfig{ + ClientConfig: mcCfg, + }, + Name: consumerName, + Topic: topic, + QueueSize: 128, } consumers[i] = NewManagedConsumer(cp, consumerCfg) go func(mc *ManagedConsumer) { @@ -160,10 +166,12 @@ func TestManagedConsumer_Int_ReceiveAsync_Multiple(t *testing.T) { }(consumers[i]) } - producerCfg := ProducerConfig{ - ClientConfig: mcCfg, - Name: utils.RandString(8), - Topic: topic, + producerCfg := ManagedProducerConfig{ + ManagedClientConfig: ManagedClientConfig{ + ClientConfig: mcCfg, + }, + Name: RandString(8), + Topic: topic, } mp := NewManagedProducer(cp, producerCfg) @@ -180,7 +188,7 @@ func TestManagedConsumer_Int_ReceiveAsync_Multiple(t *testing.T) { MORE: for _, msg := range expected { for { - if _, err := mp.Send(ctx, []byte(msg),""); err != nil { + if _, err := mp.Send(ctx, []byte(msg), ""); err != nil { continue } continue MORE diff --git a/core/manage/managed_consumer_test.go b/core/manage/managed_consumer_test.go index f45b85e..c51915f 100644 --- a/core/manage/managed_consumer_test.go +++ b/core/manage/managed_consumer_test.go @@ -36,9 +36,11 @@ func TestManagedConsumer(t *testing.T) { } cp := NewClientPool() - mc := NewManagedConsumer(cp, ConsumerConfig{ - ClientConfig: ClientConfig{ - Addr: srv.Addr, + mc := NewManagedConsumer(cp, ManagedConsumerConfig{ + ManagedClientConfig: ManagedClientConfig{ + ClientConfig: ClientConfig{ + Addr: srv.Addr, + }, }, NewConsumerTimeout: time.Second, Topic: "test-topic", @@ -113,9 +115,11 @@ func TestManagedConsumer_ReceiveAsync(t *testing.T) { queueSize := 4 cp := NewClientPool() - mc := NewManagedConsumer(cp, ConsumerConfig{ - ClientConfig: ClientConfig{ - Addr: srv.Addr, + mc := NewManagedConsumer(cp, ManagedConsumerConfig{ + ManagedClientConfig: ManagedClientConfig{ + ClientConfig: ClientConfig{ + Addr: srv.Addr, + }, }, NewConsumerTimeout: time.Second, Topic: "test-topic", @@ -236,9 +240,11 @@ func TestManagedConsumer_SrvClosed(t *testing.T) { } cp := NewClientPool() - NewManagedConsumer(cp, ConsumerConfig{ - ClientConfig: ClientConfig{ - Addr: srv.Addr, + NewManagedConsumer(cp, ManagedConsumerConfig{ + ManagedClientConfig: ManagedClientConfig{ + ClientConfig: ClientConfig{ + Addr: srv.Addr, + }, }, NewConsumerTimeout: time.Second, Topic: "test-topic", @@ -270,9 +276,11 @@ func TestManagedConsumer_ConsumerClosed(t *testing.T) { } cp := NewClientPool() - NewManagedConsumer(cp, ConsumerConfig{ - ClientConfig: ClientConfig{ - Addr: srv.Addr, + NewManagedConsumer(cp, ManagedConsumerConfig{ + ManagedClientConfig: ManagedClientConfig{ + ClientConfig: ClientConfig{ + Addr: srv.Addr, + }, }, NewConsumerTimeout: time.Second, diff --git a/core/manage/managed_producer.go b/core/manage/managed_producer.go index 189038c..0c0681e 100644 --- a/core/manage/managed_producer.go +++ b/core/manage/managed_producer.go @@ -26,9 +26,9 @@ import ( "github.com/wolfstudy/pulsar-client-go/utils" ) -// ProducerConfig is used to configure a ManagedProducer. -type ProducerConfig struct { - ClientConfig +// ManagedProducerConfig is used to configure a ManagedProducer. +type ManagedProducerConfig struct { + ManagedClientConfig Topic string Name string @@ -41,7 +41,7 @@ type ProducerConfig struct { } // setDefaults returns a modified config with appropriate zero values set to defaults. -func (m ProducerConfig) setDefaults() ProducerConfig { +func (m ManagedProducerConfig) setDefaults() ManagedProducerConfig { if m.NewProducerTimeout <= 0 { m.NewProducerTimeout = 5 * time.Second } @@ -57,14 +57,14 @@ func (m ProducerConfig) setDefaults() ProducerConfig { // NewManagedProducer returns an initialized ManagedProducer. It will create and re-create // a Producer for the given discovery address and topic on a background goroutine. -func NewManagedProducer(cp *ClientPool, cfg ProducerConfig) *ManagedProducer { +func NewManagedProducer(cp *ClientPool, cfg ManagedProducerConfig) *ManagedProducer { cfg = cfg.setDefaults() m := ManagedProducer{ - ClientPool: cp, - Cfg: cfg, - AsyncErrs: utils.AsyncErrors(cfg.Errs), - Waitc: make(chan struct{}), + clientPool: cp, + cfg: cfg, + asyncErrs: utils.AsyncErrors(cfg.Errs), + waitc: make(chan struct{}), } go m.manage() @@ -72,85 +72,25 @@ func NewManagedProducer(cp *ClientPool, cfg ProducerConfig) *ManagedProducer { return &m } -func NewManagedPartitionProducer(cp *ClientPool, cfg ProducerConfig) (*ManagedPartitionProducer, error) { - cfg = cfg.setDefaults() - ctx := context.Background() - - m := ManagedPartitionProducer{ - ClientPool: cp, - Cfg: cfg, - AsyncErrs: utils.AsyncErrors(cfg.Errs), - Waitc: make(chan struct{}), - MProducer: make([]*ManagedProducer, 0), - } - - manageClient := cp.Get(cfg.ClientConfig) - client, err := manageClient.Get(ctx) - if err != nil { - log.Errorf("create client error:%s", err.Error()) - return nil, err - } - res, err := client.Discoverer.PartitionedMetadata(ctx, cfg.Topic) - if err != nil { - log.Errorf("get partition metadata error:%s", err.Error()) - return nil, err - } - numPartitions := res.GetPartitions() - m.numPartitions = numPartitions - topicName := cfg.Topic - for i := 0; uint32(i) < numPartitions; i++ { - cfg.Topic = fmt.Sprintf("%s-partition-%d", topicName, i) - m.MProducer = append(m.MProducer, NewManagedProducer(cp, cfg)) - } - - var router pub.MessageRouter - if m.Cfg.Router == pub.UseSinglePartition { - router = &pub.SinglePartitionRouter{ - Partition: numPartitions, - } - } else { - router = &pub.RoundRobinRouter{ - Counter: 0, - } - } - - m.MessageRouter = router - - return &m, nil -} - // ManagedProducer wraps a Producer with re-connect logic. type ManagedProducer struct { - ClientPool *ClientPool - Cfg ProducerConfig - AsyncErrs utils.AsyncErrors + clientPool *ClientPool + cfg ManagedProducerConfig + asyncErrs utils.AsyncErrors - Mu sync.RWMutex // protects following - Producer *pub.Producer // either producer is nil and wait isn't or vice versa - Waitc chan struct{} // if producer is nil, this will unblock when it's been re-set -} - -type ManagedPartitionProducer struct { - ClientPool *ClientPool - Cfg ProducerConfig - AsyncErrs utils.AsyncErrors - - Mu sync.RWMutex // protects following - Producer *pub.Producer // either producer is nil and wait isn't or vice versa - Waitc chan struct{} // if producer is nil, this will unblock when it's been re-set - MProducer []*ManagedProducer - MessageRouter pub.MessageRouter - numPartitions uint32 + mu sync.RWMutex // protects following + producer *pub.Producer // either producer is nil and wait isn't or vice versa + waitc chan struct{} // if producer is nil, this will unblock when it's been re-set } // Send attempts to use the Producer's Send method if available. If not available, // an error is returned. func (m *ManagedProducer) Send(ctx context.Context, payload []byte, msgKey string) (*api.CommandSendReceipt, error) { for { - m.Mu.RLock() - producer := m.Producer - wait := m.Waitc - m.Mu.RUnlock() + m.mu.RLock() + producer := m.producer + wait := m.waitc + m.mu.RUnlock() if producer != nil { return producer.Send(ctx, payload, msgKey) @@ -167,47 +107,40 @@ func (m *ManagedProducer) Send(ctx context.Context, payload []byte, msgKey strin } } -func (m *ManagedPartitionProducer) Send(ctx context.Context, payload []byte, msgKey string) (*api.CommandSendReceipt, error) { - partition := m.MessageRouter.ChoosePartition(msgKey, m.numPartitions) - log.Debugf("choose partition is: %d, msg payload is: %s, msg key is: %s", partition, string(payload), msgKey) - - return m.MProducer[partition].Send(ctx, payload, msgKey) -} - // Set unblocks the "wait" channel (if not nil), // and sets the producer under lock. func (m *ManagedProducer) Set(p *pub.Producer) { - m.Mu.Lock() + m.mu.Lock() - m.Producer = p + m.producer = p - if m.Waitc != nil { - close(m.Waitc) - m.Waitc = nil + if m.waitc != nil { + close(m.waitc) + m.waitc = nil } - m.Mu.Unlock() + m.mu.Unlock() } // Unset creates the "wait" channel (if nil), // and sets the producer to nil under lock. func (m *ManagedProducer) Unset() { - m.Mu.Lock() + m.mu.Lock() - if m.Waitc == nil { + if m.waitc == nil { // allow unset() to be called // multiple times by only creating // wait chan if its nil - m.Waitc = make(chan struct{}) + m.waitc = make(chan struct{}) } - m.Producer = nil + m.producer = nil - m.Mu.Unlock() + m.mu.Unlock() } // NewProducer attempts to create a Producer. func (m *ManagedProducer) NewProducer(ctx context.Context) (*pub.Producer, error) { - mc, err := m.ClientPool.ForTopic(ctx, m.Cfg.ClientConfig, m.Cfg.Topic) + mc, err := m.clientPool.ForTopic(ctx, m.cfg.ManagedClientConfig, m.cfg.Topic) if err != nil { return nil, err } @@ -219,31 +152,31 @@ func (m *ManagedProducer) NewProducer(ctx context.Context) (*pub.Producer, error // Create the topic producer. A blank producer name will // cause Pulsar to generate a unique name. - return client.NewProducer(ctx, m.Cfg.Topic, m.Cfg.Name) + return client.NewProducer(ctx, m.cfg.Topic, m.cfg.Name) } // Reconnect blocks while a new Producer is created. func (m *ManagedProducer) Reconnect(initial bool) *pub.Producer { - retryDelay := m.Cfg.InitialReconnectDelay + retryDelay := m.cfg.InitialReconnectDelay for attempt := 1; ; attempt++ { if initial { initial = false } else { <-time.After(retryDelay) - if retryDelay < m.Cfg.MaxReconnectDelay { + if retryDelay < m.cfg.MaxReconnectDelay { // double retry delay until we reach the max - if retryDelay *= 2; retryDelay > m.Cfg.MaxReconnectDelay { - retryDelay = m.Cfg.MaxReconnectDelay + if retryDelay *= 2; retryDelay > m.cfg.MaxReconnectDelay { + retryDelay = m.cfg.MaxReconnectDelay } } } - ctx, cancel := context.WithTimeout(context.Background(), m.Cfg.NewProducerTimeout) + ctx, cancel := context.WithTimeout(context.Background(), m.cfg.NewProducerTimeout) newProducer, err := m.NewProducer(ctx) cancel() if err != nil { - m.AsyncErrs.Send(err) + m.asyncErrs.Send(err) continue } @@ -273,21 +206,87 @@ func (m *ManagedProducer) manage() { // Monitor a scoped deferrable lock func (m *ManagedProducer) Monitor() func() { - m.Mu.Lock() - return m.Mu.Unlock + m.mu.Lock() + return m.mu.Unlock } // Close producer func (m *ManagedProducer) Close(ctx context.Context) error { defer m.Monitor()() - return m.Producer.Close(ctx) + return m.producer.Close(ctx) +} + +func NewManagedPartitionProducer(cp *ClientPool, cfg ManagedProducerConfig) (*ManagedPartitionProducer, error) { + cfg = cfg.setDefaults() + ctx := context.Background() + + m := ManagedPartitionProducer{ + clientPool: cp, + cfg: cfg, + asyncErrs: utils.AsyncErrors(cfg.Errs), + waitc: make(chan struct{}), + managedProducers: make([]*ManagedProducer, 0), + } + + managedClient := cp.Get(cfg.ManagedClientConfig) + client, err := managedClient.Get(ctx) + if err != nil { + log.Errorf("create client error:%s", err.Error()) + return nil, err + } + res, err := client.discoverer.PartitionedMetadata(ctx, cfg.Topic) + if err != nil { + log.Errorf("get partition metadata error:%s", err.Error()) + return nil, err + } + numPartitions := res.GetPartitions() + m.numPartitions = numPartitions + topicName := cfg.Topic + for i := 0; uint32(i) < numPartitions; i++ { + cfg.Topic = fmt.Sprintf("%s-partition-%d", topicName, i) + m.managedProducers = append(m.managedProducers, NewManagedProducer(cp, cfg)) + } + + var router pub.MessageRouter + if m.cfg.Router == pub.UseSinglePartition { + router = &pub.SinglePartitionRouter{ + Partition: numPartitions, + } + } else { + router = &pub.RoundRobinRouter{ + Counter: 0, + } + } + + m.messageRouter = router + + return &m, nil +} + +type ManagedPartitionProducer struct { + clientPool *ClientPool + cfg ManagedProducerConfig + asyncErrs utils.AsyncErrors + + mu sync.RWMutex // protects following + waitc chan struct{} // if producer is nil, this will unblock when it's been re-set + managedProducers []*ManagedProducer + messageRouter pub.MessageRouter + numPartitions uint32 +} + +func (m *ManagedPartitionProducer) Send(ctx context.Context, payload []byte, msgKey string) (*api.CommandSendReceipt, error) { + partition := m.messageRouter.ChoosePartition(msgKey, m.numPartitions) + log.Debugf("choose partition is: %d, msg payload is: %s, msg key is: %s", partition, string(payload), msgKey) + + return m.managedProducers[partition].Send(ctx, payload, msgKey) } func (m *ManagedPartitionProducer) Close(ctx context.Context) error { var errMsg string - for _, producer := range m.MProducer { + for _, producer := range m.managedProducers { if err := producer.Close(ctx); err != nil { - errMsg += fmt.Sprintf("topic %s, name %s: %s ", producer.Cfg.Topic, producer.Cfg.Name, err.Error()) + errMsg += fmt.Sprintf("topic %s, name %s: %s ", producer.cfg.Topic, producer.cfg.Name, err.Error()) } } if errMsg != "" { diff --git a/core/manage/managed_producer_test.go b/core/manage/managed_producer_test.go index 3cffa2b..0c5bd2f 100644 --- a/core/manage/managed_producer_test.go +++ b/core/manage/managed_producer_test.go @@ -35,9 +35,11 @@ func TestManagedProducer(t *testing.T) { } cp := NewClientPool() - mp := NewManagedProducer(cp, ProducerConfig{ - ClientConfig: ClientConfig{ - Addr: srv.Addr, + mp := NewManagedProducer(cp, ManagedProducerConfig{ + ManagedClientConfig: ManagedClientConfig{ + ClientConfig: ClientConfig{ + Addr: srv.Addr, + }, }, NewProducerTimeout: time.Second, Topic: "test-topic", @@ -53,7 +55,7 @@ func TestManagedProducer(t *testing.T) { } payload := []byte("hi") - if _, err = mp.Send(ctx, payload,""); err != nil { + if _, err = mp.Send(ctx, payload, ""); err != nil { t.Fatal(err) } @@ -91,9 +93,11 @@ func TestManagedProducer_Redirect(t *testing.T) { primarySrv.SetTopicLookupResp(topic, topicSrv.Addr, api.CommandLookupTopicResponse_Connect, false) cp := NewClientPool() - mp := NewManagedProducer(cp, ProducerConfig{ - ClientConfig: ClientConfig{ - Addr: primarySrv.Addr, + mp := NewManagedProducer(cp, ManagedProducerConfig{ + ManagedClientConfig: ManagedClientConfig{ + ClientConfig: ClientConfig{ + Addr: primarySrv.Addr, + }, }, NewProducerTimeout: time.Second, Topic: topic, @@ -116,7 +120,7 @@ func TestManagedProducer_Redirect(t *testing.T) { } payload := []byte("hi") - if _, err = mp.Send(ctx, payload,""); err != nil { + if _, err = mp.Send(ctx, payload, ""); err != nil { t.Fatal(err) } @@ -147,9 +151,11 @@ func TestManagedProducer_SrvClosed(t *testing.T) { } cp := NewClientPool() - mp := NewManagedProducer(cp, ProducerConfig{ - ClientConfig: ClientConfig{ - Addr: srv.Addr, + mp := NewManagedProducer(cp, ManagedProducerConfig{ + ManagedClientConfig: ManagedClientConfig{ + ClientConfig: ClientConfig{ + Addr: srv.Addr, + }, }, NewProducerTimeout: time.Second, Topic: "test-topic", @@ -170,7 +176,7 @@ func TestManagedProducer_SrvClosed(t *testing.T) { } payload := []byte("hi") - if _, err = mp.Send(ctx, payload,""); err != nil { + if _, err = mp.Send(ctx, payload, ""); err != nil { t.Fatal(err) } @@ -198,9 +204,11 @@ func TestManagedProducer_ProducerClosed(t *testing.T) { } cp := NewClientPool() - mp := NewManagedProducer(cp, ProducerConfig{ - ClientConfig: ClientConfig{ - Addr: srv.Addr, + mp := NewManagedProducer(cp, ManagedProducerConfig{ + ManagedClientConfig: ManagedClientConfig{ + ClientConfig: ClientConfig{ + Addr: srv.Addr, + }, }, NewProducerTimeout: time.Second, Topic: "test-topic", @@ -257,7 +265,7 @@ func TestManagedProducer_ProducerClosed(t *testing.T) { } payload := []byte("hi") - if _, err = mp.Send(ctx, payload,""); err != nil { + if _, err = mp.Send(ctx, payload, ""); err != nil { t.Fatal(err) } diff --git a/core/manage/pubsub.go b/core/manage/pubsub.go index 0b744ff..14ff7d2 100644 --- a/core/manage/pubsub.go +++ b/core/manage/pubsub.go @@ -29,32 +29,32 @@ import ( // NewPubsub returns a ready-to-use pubsub. func NewPubsub(s frame.CmdSender, dispatcher *frame.Dispatcher, subscriptions *Subscriptions, reqID *msg.MonotonicID) *Pubsub { return &Pubsub{ - S: s, - ReqID: reqID, - ProducerID: &msg.MonotonicID{ID: 0}, - ConsumerID: &msg.MonotonicID{ID: 0}, - Dispatcher: dispatcher, - Subscriptions: subscriptions, + s: s, + reqID: reqID, + producerID: &msg.MonotonicID{ID: 0}, + consumerID: &msg.MonotonicID{ID: 0}, + dispatcher: dispatcher, + subscriptions: subscriptions, } } // Pubsub is responsible for creating producers and consumers on a give topic. type Pubsub struct { - S frame.CmdSender - ReqID *msg.MonotonicID - ProducerID *msg.MonotonicID - ConsumerID *msg.MonotonicID + s frame.CmdSender + reqID *msg.MonotonicID + producerID *msg.MonotonicID + consumerID *msg.MonotonicID - Dispatcher *frame.Dispatcher // handles request response state - Subscriptions *Subscriptions + dispatcher *frame.Dispatcher // handles request response state + subscriptions *Subscriptions } // Subscribe subscribes to the given topic. The queueSize determines the buffer // size of the Consumer.Messages() channel. func (t *Pubsub) Subscribe(ctx context.Context, topic, subscribe string, subType api.CommandSubscribe_SubType, initialPosition api.CommandSubscribe_InitialPosition, queue chan msg.Message) (*sub.Consumer, error) { - requestID := t.ReqID.Next() - consumerID := t.ConsumerID.Next() + requestID := t.reqID.Next() + consumerID := t.consumerID.Next() cmd := api.BaseCommand{ Type: api.BaseCommand_SUBSCRIBE.Enum(), @@ -68,28 +68,28 @@ func (t *Pubsub) Subscribe(ctx context.Context, topic, subscribe string, subType }, } - resp, cancel, errs := t.Dispatcher.RegisterReqID(*requestID) + resp, cancel, errs := t.dispatcher.RegisterReqID(*requestID) if errs != nil { return nil, errs } defer cancel() - c := sub.NewConsumer(t.S, t.Dispatcher, topic, t.ReqID, *consumerID, queue) + c := sub.NewConsumer(t.s, t.dispatcher, topic, t.reqID, *consumerID, queue) // the new subscription needs to be added to the map // before sending the subscribe command, otherwise there'd // be a race between receiving the success result and // a possible message to the subscription - t.Subscriptions.AddConsumer(c) + t.subscriptions.AddConsumer(c) - if errs := t.S.SendSimpleCmd(cmd); errs != nil { - t.Subscriptions.DelConsumer(c) + if errs := t.s.SendSimpleCmd(cmd); errs != nil { + t.subscriptions.DelConsumer(c) return nil, errs } // wait for a response or timeout select { case <-ctx.Done(): - t.Subscriptions.DelConsumer(c) + t.subscriptions.DelConsumer(c) return nil, ctx.Err() case f := <-resp: @@ -102,22 +102,24 @@ func (t *Pubsub) Subscribe(ctx context.Context, topic, subscribe string, subType return c, nil case api.BaseCommand_ERROR: - t.Subscriptions.DelConsumer(c) + t.subscriptions.DelConsumer(c) errMsg := f.BaseCmd.GetError() return nil, fmt.Errorf("%s: %s", errMsg.GetError().String(), errMsg.GetMessage()) default: - t.Subscriptions.DelConsumer(c) + t.subscriptions.DelConsumer(c) return nil, utils.NewUnexpectedErrMsg(msgType, *requestID) } } } -func (t *Pubsub) SubscribeWithCfg(ctx context.Context, cfg ConsumerConfig, queue chan msg.Message) (*sub.Consumer, error) { - requestID := t.ReqID.Next() - consumerID := t.ConsumerID.Next() +// TODO: replace Subscribe() method above + +func (t *Pubsub) SubscribeWithCfg(ctx context.Context, cfg ManagedConsumerConfig, queue chan msg.Message) (*sub.Consumer, error) { + requestID := t.reqID.Next() + consumerID := t.consumerID.Next() subType, subPos := t.GetCfgMode(cfg) @@ -133,22 +135,22 @@ func (t *Pubsub) SubscribeWithCfg(ctx context.Context, cfg ConsumerConfig, queue }, } - resp, cancel, errs := t.Dispatcher.RegisterReqID(*requestID) + resp, cancel, errs := t.dispatcher.RegisterReqID(*requestID) if errs != nil { return nil, errs } defer cancel() - c := sub.NewConsumer(t.S, t.Dispatcher, cfg.Topic, t.ReqID, *consumerID, queue) + c := sub.NewConsumer(t.s, t.dispatcher, cfg.Topic, t.reqID, *consumerID, queue) // the new subscription needs to be added to the map // before sending the subscribe command, otherwise there'd // be a race between receiving the success result and // a possible message to the subscription - t.Subscriptions.AddConsumer(c) + t.subscriptions.AddConsumer(c) - if errs := t.S.SendSimpleCmd(cmd); errs != nil { - t.Subscriptions.DelConsumer(c) + if errs := t.s.SendSimpleCmd(cmd); errs != nil { + t.subscriptions.DelConsumer(c) return nil, errs } @@ -156,7 +158,7 @@ func (t *Pubsub) SubscribeWithCfg(ctx context.Context, cfg ConsumerConfig, queue select { case <-ctx.Done(): - t.Subscriptions.DelConsumer(c) + t.subscriptions.DelConsumer(c) return nil, ctx.Err() case f := <-resp: @@ -169,20 +171,20 @@ func (t *Pubsub) SubscribeWithCfg(ctx context.Context, cfg ConsumerConfig, queue return c, nil case api.BaseCommand_ERROR: - t.Subscriptions.DelConsumer(c) + t.subscriptions.DelConsumer(c) errMsg := f.BaseCmd.GetError() return nil, fmt.Errorf("%s: %s", errMsg.GetError().String(), errMsg.GetMessage()) default: - t.Subscriptions.DelConsumer(c) + t.subscriptions.DelConsumer(c) return nil, utils.NewUnexpectedErrMsg(msgType, *requestID) } } } -func (t *Pubsub) GetCfgMode(cfg ConsumerConfig) (api.CommandSubscribe_SubType, api.CommandSubscribe_InitialPosition) { +func (t *Pubsub) GetCfgMode(cfg ManagedConsumerConfig) (api.CommandSubscribe_SubType, api.CommandSubscribe_InitialPosition) { var ( subType api.CommandSubscribe_SubType subPos api.CommandSubscribe_InitialPosition @@ -210,12 +212,10 @@ func (t *Pubsub) GetCfgMode(cfg ConsumerConfig) (api.CommandSubscribe_SubType, a return subType, subPos } -// TODO: replace Subscribe() method above - // Producer creates a new producer for the given topic and producerName. func (t *Pubsub) Producer(ctx context.Context, topic, producerName string) (*pub.Producer, error) { - requestID := t.ReqID.Next() - producerID := t.ProducerID.Next() + requestID := t.reqID.Next() + producerID := t.producerID.Next() cmd := api.BaseCommand{ Type: api.BaseCommand_PRODUCER.Enum(), @@ -229,19 +229,19 @@ func (t *Pubsub) Producer(ctx context.Context, topic, producerName string) (*pub cmd.Producer.ProducerName = proto.String(producerName) } - resp, cancel, err := t.Dispatcher.RegisterReqID(*requestID) + resp, cancel, err := t.dispatcher.RegisterReqID(*requestID) if err != nil { return nil, err } defer cancel() - p := pub.NewProducer(t.S, t.Dispatcher, t.ReqID, *producerID) + p := pub.NewProducer(t.s, t.dispatcher, t.reqID, *producerID) // the new producer needs to be added to subscriptions before sending // the create command to avoid potential race conditions - t.Subscriptions.AddProducer(p) + t.subscriptions.AddProducer(p) - if err := t.S.SendSimpleCmd(cmd); err != nil { - t.Subscriptions.DelProducer(p) + if err := t.s.SendSimpleCmd(cmd); err != nil { + t.subscriptions.DelProducer(p) return nil, err } @@ -249,7 +249,7 @@ func (t *Pubsub) Producer(ctx context.Context, topic, producerName string) (*pub select { case <-ctx.Done(): - t.Subscriptions.DelProducer(p) + t.subscriptions.DelProducer(p) return nil, ctx.Err() case f := <-resp: @@ -265,13 +265,13 @@ func (t *Pubsub) Producer(ctx context.Context, topic, producerName string) (*pub return p, nil case api.BaseCommand_ERROR: - t.Subscriptions.DelProducer(p) + t.subscriptions.DelProducer(p) errMsg := f.BaseCmd.GetError() return nil, fmt.Errorf("%s: %s", errMsg.GetError().String(), errMsg.GetMessage()) default: - t.Subscriptions.DelProducer(p) + t.subscriptions.DelProducer(p) return nil, utils.NewUnexpectedErrMsg(msgType, *requestID) } diff --git a/core/manage/pubsub_test.go b/core/manage/pubsub_test.go index 23daf9e..e72f81b 100644 --- a/core/manage/pubsub_test.go +++ b/core/manage/pubsub_test.go @@ -37,7 +37,7 @@ func TestPubsub_Subscribe_Success(t *testing.T) { tp := NewPubsub(&ms, dispatcher, subs, reqID) // manually set consumerID to verify that it's correctly // being set on Consumer - tp.ConsumerID = &msg.MonotonicID{ID: consID} + tp.consumerID = &msg.MonotonicID{ID: consID} ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -84,7 +84,7 @@ func TestPubsub_Subscribe_Success(t *testing.T) { t.Fatalf("got Consumer.consumerID = %d; expected %d", got.ConsumerID, consID) } - if _, ok := subs.Consumers[got.ConsumerID]; !ok { + if _, ok := subs.consumers[got.ConsumerID]; !ok { t.Fatalf("subscriptions.consumers[%d] is absent; expected consumer", got.ConsumerID) } } @@ -138,7 +138,7 @@ func TestPubsub_Subscribe_Error(t *testing.T) { } t.Logf("subscribe() err = %v", r.err) - if got, expected := len(subs.Consumers), 0; got != expected { + if got, expected := len(subs.consumers), 0; got != expected { t.Fatalf("subscriptions.consumers has %d elements; expected %d", got, expected) } } @@ -154,7 +154,7 @@ func TestPubsub_Producer_Success(t *testing.T) { tp := NewPubsub(&ms, dispatcher, subs, reqID) // manually set producerID to verify that it's correctly // being set on Producer - tp.ProducerID = &msg.MonotonicID{ID: prodID} + tp.producerID = &msg.MonotonicID{ID: prodID} ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -206,7 +206,7 @@ func TestPubsub_Producer_Success(t *testing.T) { t.Fatalf("got Producer.producerName = %q; expected %q", got.ProducerName, prodName) } - if _, ok := subs.Producers[got.ProducerID]; !ok { + if _, ok := subs.producers[got.ProducerID]; !ok { t.Fatalf("subscriptions.producers[%d] is absent; expected producer", got.ProducerID) } } @@ -222,7 +222,7 @@ func TestPubsub_Producer_Error(t *testing.T) { tp := NewPubsub(&ms, dispatcher, subs, reqID) // manually set producerID to verify that it's correctly // being set on Producer - tp.ProducerID = &msg.MonotonicID{ID: prodID} + tp.producerID = &msg.MonotonicID{ID: prodID} ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -263,7 +263,7 @@ func TestPubsub_Producer_Error(t *testing.T) { } t.Logf("producer() err = %v", r.err) - if got, expected := len(subs.Producers), 0; got != expected { + if got, expected := len(subs.producers), 0; got != expected { t.Fatalf("subscriptions.producers has %d elements; expected %d", got, expected) } } diff --git a/core/manage/subscriptions.go b/core/manage/subscriptions.go index e506a16..4bb985a 100644 --- a/core/manage/subscriptions.go +++ b/core/manage/subscriptions.go @@ -27,52 +27,52 @@ import ( // NewSubscriptions returns a ready-to-use subscriptions. func NewSubscriptions() *Subscriptions { return &Subscriptions{ - Consumers: make(map[uint64]*sub.Consumer), - Producers: make(map[uint64]*pub.Producer), + consumers: make(map[uint64]*sub.Consumer), + producers: make(map[uint64]*pub.Producer), } } // Subscriptions is responsible for storing producers and consumers // based on their IDs. type Subscriptions struct { - Cmu sync.RWMutex // protects following - Consumers map[uint64]*sub.Consumer + cmu sync.RWMutex // protects following + consumers map[uint64]*sub.Consumer - Pmu sync.Mutex // protects following - Producers map[uint64]*pub.Producer + pmu sync.Mutex // protects following + producers map[uint64]*pub.Producer } func (s *Subscriptions) AddConsumer(c *sub.Consumer) { - s.Cmu.Lock() - s.Consumers[c.ConsumerID] = c - s.Cmu.Unlock() + s.cmu.Lock() + s.consumers[c.ConsumerID] = c + s.cmu.Unlock() } func (s *Subscriptions) DelConsumer(c *sub.Consumer) { - s.Cmu.Lock() - delete(s.Consumers, c.ConsumerID) - s.Cmu.Unlock() + s.cmu.Lock() + delete(s.consumers, c.ConsumerID) + s.cmu.Unlock() } func (s *Subscriptions) HandleCloseConsumer(consumerID uint64, f frame.Frame) error { - s.Cmu.Lock() - defer s.Cmu.Unlock() + s.cmu.Lock() + defer s.cmu.Unlock() - c, ok := s.Consumers[consumerID] + c, ok := s.consumers[consumerID] if !ok { return utils.NewUnexpectedErrMsg(f.BaseCmd.GetType(), consumerID) } - delete(s.Consumers, consumerID) + delete(s.consumers, consumerID) return c.HandleCloseConsumer(f) } func (s *Subscriptions) HandleReachedEndOfTopic(consumerID uint64, f frame.Frame) error { - s.Cmu.Lock() - defer s.Cmu.Unlock() + s.cmu.Lock() + defer s.cmu.Unlock() - c, ok := s.Consumers[consumerID] + c, ok := s.consumers[consumerID] if !ok { return utils.NewUnexpectedErrMsg(f.BaseCmd.GetType(), consumerID) } @@ -81,9 +81,9 @@ func (s *Subscriptions) HandleReachedEndOfTopic(consumerID uint64, f frame.Frame } func (s *Subscriptions) HandleMessage(consumerID uint64, f frame.Frame) error { - s.Cmu.RLock() - c, ok := s.Consumers[consumerID] - s.Cmu.RUnlock() + s.cmu.RLock() + c, ok := s.consumers[consumerID] + s.cmu.RUnlock() if !ok { return utils.NewUnexpectedErrMsg(f.BaseCmd.GetType(), consumerID) @@ -93,27 +93,27 @@ func (s *Subscriptions) HandleMessage(consumerID uint64, f frame.Frame) error { } func (s *Subscriptions) AddProducer(p *pub.Producer) { - s.Pmu.Lock() - s.Producers[p.ProducerID] = p - s.Pmu.Unlock() + s.pmu.Lock() + s.producers[p.ProducerID] = p + s.pmu.Unlock() } func (s *Subscriptions) DelProducer(p *pub.Producer) { - s.Pmu.Lock() - delete(s.Producers, p.ProducerID) - s.Pmu.Unlock() + s.pmu.Lock() + delete(s.producers, p.ProducerID) + s.pmu.Unlock() } func (s *Subscriptions) HandleCloseProducer(producerID uint64, f frame.Frame) error { - s.Pmu.Lock() - defer s.Pmu.Unlock() + s.pmu.Lock() + defer s.pmu.Unlock() - p, ok := s.Producers[producerID] + p, ok := s.producers[producerID] if !ok { return utils.NewUnexpectedErrMsg(f.BaseCmd.GetType(), producerID) } - delete(s.Producers, producerID) + delete(s.producers, producerID) return p.HandleCloseProducer(f) } diff --git a/core/manage/unackedMsgTracker.go b/core/manage/unackedMsgTracker.go index 880e6e6..2d22ccf 100644 --- a/core/manage/unackedMsgTracker.go +++ b/core/manage/unackedMsgTracker.go @@ -29,6 +29,15 @@ import ( "github.com/wolfstudy/pulsar-client-go/pkg/log" ) +func NewUnackedMessageTracker() *UnackedMessageTracker { + UnAckTracker := &UnackedMessageTracker{ + currentSet: set.NewSet(), + oldOpenSet: set.NewSet(), + } + + return UnAckTracker +} + type UnackedMessageTracker struct { cmu sync.RWMutex // protects following currentSet set.Set @@ -39,15 +48,6 @@ type UnackedMessageTracker struct { partitionConsumer *ManagedPartitionConsumer } -func NewUnackedMessageTracker() *UnackedMessageTracker { - UnAckTracker := &UnackedMessageTracker{ - currentSet: set.NewSet(), - oldOpenSet: set.NewSet(), - } - - return UnAckTracker -} - func (t *UnackedMessageTracker) Size() int { t.cmu.Lock() defer t.cmu.Unlock() @@ -172,7 +172,7 @@ func (t *UnackedMessageTracker) Start(ackTimeoutMillis int64) { messageIdsMap[msgID.GetPartition()] = append(messageIdsMap[msgID.GetPartition()], msgID) } - for index, subConsumer := range t.partitionConsumer.MConsumer { + for index, subConsumer := range t.partitionConsumer.managedConsumers { if messageIdsMap[int32(index)] != nil { cmd := api.BaseCommand{ Type: api.BaseCommand_REDELIVER_UNACKNOWLEDGED_MESSAGES.Enum(), @@ -182,7 +182,7 @@ func (t *UnackedMessageTracker) Start(ackTimeoutMillis int64) { }, } log.Debugf("index value: %d, partition name is:%s, messageID length:%d", - index, t.partitionConsumer.MConsumer[index].consumer.Topic, len(messageIdsMap[int32(index)])) + index, t.partitionConsumer.managedConsumers[index].consumer.Topic, len(messageIdsMap[int32(index)])) if err := subConsumer.consumer.S.SendSimpleCmd(cmd); err != nil { log.Errorf("send partition subConsumer redeliver cmd error:%s", err.Error()) return diff --git a/core/manage/util_test.go b/core/manage/util_test.go new file mode 100644 index 0000000..5dbd7e7 --- /dev/null +++ b/core/manage/util_test.go @@ -0,0 +1,25 @@ +package manage + +import ( + "math/rand" + "sync" + "time" +) + +var ( + randStringChars = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + randStringMu = new(sync.Mutex) //protects randStringRand, which isn't threadsafe + randStringRand = rand.New(rand.NewSource(time.Now().UnixNano())) +) + +func RandString(n int) string { + b := make([]rune, n) + l := len(randStringChars) + randStringMu.Lock() + for i := range b { + b[i] = randStringChars[randStringRand.Intn(l)] + } + randStringMu.Unlock() + return string(b) +} + diff --git a/core/pub/producer.go b/core/pub/producer.go index 29b6961..37c20d6 100644 --- a/core/pub/producer.go +++ b/core/pub/producer.go @@ -35,42 +35,42 @@ var ErrClosedProducer = errors.New("producer is closed") // sends messages (type MESSAGE) to Pulsar. func NewProducer(s frame.CmdSender, dispatcher *frame.Dispatcher, reqID *msg.MonotonicID, producerID uint64) *Producer { return &Producer{ - S: s, + s: s, ProducerID: producerID, - ReqID: reqID, - SeqID: &msg.MonotonicID{ID: 0}, - Dispatcher: dispatcher, - Closedc: make(chan struct{}), + reqID: reqID, + seqID: &msg.MonotonicID{ID: 0}, + dispatcher: dispatcher, + closedc: make(chan struct{}), } } // Producer is responsible for creating a subscription producer and // managing its state. type Producer struct { - S frame.CmdSender + s frame.CmdSender ProducerID uint64 ProducerName string - ReqID *msg.MonotonicID - SeqID *msg.MonotonicID + reqID *msg.MonotonicID + seqID *msg.MonotonicID - Dispatcher *frame.Dispatcher // handles request/response state + dispatcher *frame.Dispatcher // handles request/response state - Mu sync.RWMutex // protects following - IsClosed bool - Closedc chan struct{} + mu sync.RWMutex // protects following + isClosed bool + closedc chan struct{} } func (p *Producer) Send(ctx context.Context, payload []byte, msgKey string) (*api.CommandSendReceipt, error) { - p.Mu.RLock() - if p.IsClosed { - p.Mu.RUnlock() + p.mu.RLock() + if p.isClosed { + p.mu.RUnlock() return nil, ErrClosedProducer } - p.Mu.RUnlock() + p.mu.RUnlock() - sequenceID := p.SeqID.Next() + sequenceID := p.seqID.Next() cmd := api.BaseCommand{ Type: api.BaseCommand_SEND.Enum(), @@ -100,13 +100,13 @@ func (p *Producer) Send(ctx context.Context, payload []byte, msgKey string) (*ap } } - resp, cancel, err := p.Dispatcher.RegisterProdSeqIDs(p.ProducerID, *sequenceID) + resp, cancel, err := p.dispatcher.RegisterProdSeqIDs(p.ProducerID, *sequenceID) if err != nil { return nil, err } defer cancel() - if err := p.S.SendPayloadCmd(cmd, metadata, payload); err != nil { + if err := p.s.SendPayloadCmd(cmd, metadata, payload); err != nil { return nil, err } @@ -142,7 +142,7 @@ func (p *Producer) Send(ctx context.Context, payload []byte, msgKey string) (*ap // been closed. // TODO: Rename Done func (p *Producer) Closed() <-chan struct{} { - return p.Closedc + return p.closedc } func (p *Producer) Name() string { @@ -150,13 +150,13 @@ func (p *Producer) Name() string { } func (p *Producer) LastSequenceID() uint64 { - return *p.SeqID.Last() + return *p.seqID.Last() } // ConnClosed unblocks when the producer's connection has been closed. Once that // happens, it's necessary to first recreate the client and then the producer. func (p *Producer) ConnClosed() <-chan struct{} { - return p.S.Closed() + return p.s.Closed() } // Close closes the producer. When receiving a CloseProducer command, @@ -164,14 +164,14 @@ func (p *Producer) ConnClosed() <-chan struct{} { // wait until all pending messages are persisted and then reply Success to the client. // https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#command-closeproducer func (p *Producer) Close(ctx context.Context) error { - p.Mu.Lock() - defer p.Mu.Unlock() + p.mu.Lock() + defer p.mu.Unlock() - if p.IsClosed { + if p.isClosed { return nil } - requestID := p.ReqID.Next() + requestID := p.reqID.Next() cmd := api.BaseCommand{ Type: api.BaseCommand_CLOSE_PRODUCER.Enum(), @@ -181,13 +181,13 @@ func (p *Producer) Close(ctx context.Context) error { }, } - resp, cancel, err := p.Dispatcher.RegisterReqID(*requestID) + resp, cancel, err := p.dispatcher.RegisterReqID(*requestID) if err != nil { return err } defer cancel() - if err := p.S.SendSimpleCmd(cmd); err != nil { + if err := p.s.SendSimpleCmd(cmd); err != nil { return err } @@ -196,8 +196,8 @@ func (p *Producer) Close(ctx context.Context) error { return ctx.Err() case <-resp: - p.IsClosed = true - close(p.Closedc) + p.isClosed = true + close(p.closedc) return nil } @@ -212,15 +212,15 @@ func (p *Producer) Close(ctx context.Context) error { // When receiving the CloseProducer, the client is expected to go through the service discovery lookup again and recreate the producer again. The TCP connection is not being affected. // https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#command-closeproducer func (p *Producer) HandleCloseProducer(f frame.Frame) error { - p.Mu.Lock() - defer p.Mu.Unlock() + p.mu.Lock() + defer p.mu.Unlock() - if p.IsClosed { + if p.isClosed { return nil } - p.IsClosed = true - close(p.Closedc) + p.isClosed = true + close(p.closedc) return nil } diff --git a/core/srv/discoverer.go b/core/srv/discoverer.go index 4972b6f..f8c1448 100644 --- a/core/srv/discoverer.go +++ b/core/srv/discoverer.go @@ -25,9 +25,9 @@ import ( // NewDiscoverer returns a ready-to-use discoverer func NewDiscoverer(s frame.CmdSender, dispatcher *frame.Dispatcher, reqID *msg.MonotonicID) *Discoverer { return &Discoverer{ - S: s, - ReqID: reqID, - Dispatcher: dispatcher, + s: s, + reqID: reqID, + dispatcher: dispatcher, } } @@ -35,9 +35,9 @@ func NewDiscoverer(s frame.CmdSender, dispatcher *frame.Dispatcher, reqID *msg.M // // https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#Servicediscovery-40v5m type Discoverer struct { - S frame.CmdSender - ReqID *msg.MonotonicID - Dispatcher *frame.Dispatcher + s frame.CmdSender + reqID *msg.MonotonicID + dispatcher *frame.Dispatcher } // PartitionedMetadata performs a PARTITIONED_METADATA request for the given @@ -46,7 +46,7 @@ type Discoverer struct { // // https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#Partitionedtopicsdiscovery-g14a9h func (d *Discoverer) PartitionedMetadata(ctx context.Context, topic string) (*api.CommandPartitionedTopicMetadataResponse, error) { - requestID := d.ReqID.Next() + requestID := d.reqID.Next() cmd := api.BaseCommand{ Type: api.BaseCommand_PARTITIONED_METADATA.Enum(), PartitionMetadata: &api.CommandPartitionedTopicMetadata{ @@ -55,13 +55,13 @@ func (d *Discoverer) PartitionedMetadata(ctx context.Context, topic string) (*ap }, } - resp, cancel, err := d.Dispatcher.RegisterReqID(*requestID) + resp, cancel, err := d.dispatcher.RegisterReqID(*requestID) if err != nil { return nil, err } defer cancel() - if err := d.S.SendSimpleCmd(cmd); err != nil { + if err := d.s.SendSimpleCmd(cmd); err != nil { return nil, err } @@ -82,7 +82,7 @@ func (d *Discoverer) PartitionedMetadata(ctx context.Context, topic string) (*ap // // https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#Topiclookup-dk72wp func (d *Discoverer) LookupTopic(ctx context.Context, topic string, authoritative bool) (*api.CommandLookupTopicResponse, error) { - requestID := d.ReqID.Next() + requestID := d.reqID.Next() cmd := api.BaseCommand{ Type: api.BaseCommand_LOOKUP.Enum(), @@ -93,13 +93,13 @@ func (d *Discoverer) LookupTopic(ctx context.Context, topic string, authoritativ }, } - resp, cancel, err := d.Dispatcher.RegisterReqID(*requestID) + resp, cancel, err := d.dispatcher.RegisterReqID(*requestID) if err != nil { return nil, err } defer cancel() - if err := d.S.SendSimpleCmd(cmd); err != nil { + if err := d.s.SendSimpleCmd(cmd); err != nil { return nil, err } diff --git a/core/srv/pinger.go b/core/srv/pinger.go index 7dc159d..84c95db 100644 --- a/core/srv/pinger.go +++ b/core/srv/pinger.go @@ -23,8 +23,8 @@ import ( // NewPinger returns a ready-to-use pinger. func NewPinger(s frame.CmdSender, dispatcher *frame.Dispatcher) *Pinger { return &Pinger{ - S: s, - Dispatcher: dispatcher, + s: s, + dispatcher: dispatcher, } } @@ -36,15 +36,15 @@ func NewPinger(s frame.CmdSender, dispatcher *frame.Dispatcher) *Pinger { // // https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#KeepAlive-53utwq type Pinger struct { - S frame.CmdSender - Dispatcher *frame.Dispatcher // used to manage the request/response state + s frame.CmdSender + dispatcher *frame.Dispatcher // used to manage the request/response state } // Ping sends a PING message to the Pulsar server, then // waits for either a PONG response or the context to // timeout. func (p *Pinger) Ping(ctx context.Context) error { - resp, cancel, err := p.Dispatcher.RegisterGlobal() + resp, cancel, err := p.dispatcher.RegisterGlobal() if err != nil { return err } @@ -55,7 +55,7 @@ func (p *Pinger) Ping(ctx context.Context) error { Ping: &api.CommandPing{}, } - if err := p.S.SendSimpleCmd(cmd); err != nil { + if err := p.s.SendSimpleCmd(cmd); err != nil { return err } @@ -80,5 +80,5 @@ func (p *Pinger) HandlePing(msgType api.BaseCommand_Type, msg *api.CommandPing) Pong: &api.CommandPong{}, } - return p.S.SendSimpleCmd(cmd) + return p.s.SendSimpleCmd(cmd) } diff --git a/core/sub/consumer.go b/core/sub/consumer.go index 888704a..d79536d 100644 --- a/core/sub/consumer.go +++ b/core/sub/consumer.go @@ -39,11 +39,11 @@ func NewConsumer(s frame.CmdSender, dispatcher *frame.Dispatcher, topic string, S: s, Topic: topic, ConsumerID: ConsumerID, - ReqID: reqID, - Dispatcher: dispatcher, - Queue: queue, - Closedc: make(chan struct{}), - EndOfTopicc: make(chan struct{}), + reqID: reqID, + dispatcher: dispatcher, + queue: queue, + closedc: make(chan struct{}), + endOfTopicc: make(chan struct{}), } } @@ -54,27 +54,26 @@ type Consumer struct { Topic string ConsumerID uint64 - ReqID *msg.MonotonicID - Dispatcher *frame.Dispatcher // handles request/response state + reqID *msg.MonotonicID + dispatcher *frame.Dispatcher // handles request/response state - Queue chan msg.Message + queue chan msg.Message - Omu sync.Mutex // protects following - Overflow []*api.MessageIdData // IDs of messages that were dropped because of full buffer + omu sync.Mutex // protects following + overflow []*api.MessageIdData // IDs of messages that were dropped because of full buffer - Mu sync.Mutex // protects following - IsClosed bool - Closedc chan struct{} - IsEndOfTopic bool - EndOfTopicc chan struct{} + mu sync.Mutex // protects following + isClosed bool + closedc chan struct{} + isEndOfTopic bool + endOfTopicc chan struct{} } - // Messages returns a read-only channel of messages // received by the consumer. The channel will never be // closed by the consumer. func (c *Consumer) Messages() <-chan msg.Message { - return c.Queue + return c.queue } // Ack is used to signal to the broker that a given message has been @@ -116,7 +115,7 @@ func (c *Consumer) Flow(permits uint32) error { // consumer has been closed, in which case the channel will have // been closed and unblocked. func (c *Consumer) Closed() <-chan struct{} { - return c.Closedc + return c.closedc } // ConnClosed unblocks when the consumer's connection has been closed. Once that @@ -128,14 +127,14 @@ func (c *Consumer) ConnClosed() <-chan struct{} { // Close closes the consumer. The channel returned from the Closed method // will then unblock upon successful closure. func (c *Consumer) Close(ctx context.Context) error { - c.Mu.Lock() - defer c.Mu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() - if c.IsClosed { + if c.isClosed { return nil } - requestID := c.ReqID.Next() + requestID := c.reqID.Next() cmd := api.BaseCommand{ Type: api.BaseCommand_CLOSE_CONSUMER.Enum(), @@ -145,7 +144,7 @@ func (c *Consumer) Close(ctx context.Context) error { }, } - resp, cancel, err := c.Dispatcher.RegisterReqID(*requestID) + resp, cancel, err := c.dispatcher.RegisterReqID(*requestID) if err != nil { return err } @@ -160,8 +159,8 @@ func (c *Consumer) Close(ctx context.Context) error { return ctx.Err() case <-resp: - c.IsClosed = true - close(c.Closedc) + c.isClosed = true + close(c.closedc) return nil } @@ -169,7 +168,7 @@ func (c *Consumer) Close(ctx context.Context) error { // Unsubscribe the consumer from its topic. func (c *Consumer) Unsubscribe(ctx context.Context) error { - requestID := c.ReqID.Next() + requestID := c.reqID.Next() cmd := api.BaseCommand{ Type: api.BaseCommand_UNSUBSCRIBE.Enum(), @@ -179,7 +178,7 @@ func (c *Consumer) Unsubscribe(ctx context.Context) error { }, } - resp, cancel, err := c.Dispatcher.RegisterReqID(*requestID) + resp, cancel, err := c.dispatcher.RegisterReqID(*requestID) if err != nil { return err } @@ -202,15 +201,15 @@ func (c *Consumer) Unsubscribe(ctx context.Context) error { // HandleCloseConsumer should be called when a CLOSE_CONSUMER message is received // associated with this consumer. func (c *Consumer) HandleCloseConsumer(f frame.Frame) error { - c.Mu.Lock() - defer c.Mu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() - if c.IsClosed { + if c.isClosed { return nil } - c.IsClosed = true - close(c.Closedc) + c.isClosed = true + close(c.closedc) return nil } @@ -218,21 +217,21 @@ func (c *Consumer) HandleCloseConsumer(f frame.Frame) error { // ReachedEndOfTopic unblocks whenever the topic has been "terminated" and // all the messages on the subscription were acknowledged. func (c *Consumer) ReachedEndOfTopic() <-chan struct{} { - return c.EndOfTopicc + return c.endOfTopicc } // HandleReachedEndOfTopic should be called for all received REACHED_END_OF_TOPIC messages // associated with this consumer. func (c *Consumer) HandleReachedEndOfTopic(f frame.Frame) error { - c.Mu.Lock() - defer c.Mu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() - if c.IsEndOfTopic { + if c.isEndOfTopic { return nil } - c.IsEndOfTopic = true - close(c.EndOfTopicc) + c.isEndOfTopic = true + close(c.endOfTopicc) return nil } @@ -251,10 +250,10 @@ func (c *Consumer) RedeliverUnacknowledged(ctx context.Context) error { return err } - // clear Overflow slice - c.Omu.Lock() - c.Overflow = nil - c.Omu.Unlock() + // clear overflow slice + c.omu.Lock() + c.overflow = nil + c.omu.Unlock() return nil } @@ -265,10 +264,10 @@ func (c *Consumer) RedeliverUnacknowledged(ctx context.Context) error { // will be redelivered. // https://github.com/apache/incubator-pulsar/issues/2003 func (c *Consumer) RedeliverOverflow(ctx context.Context) (int, error) { - c.Omu.Lock() - defer c.Omu.Unlock() + c.omu.Lock() + defer c.omu.Unlock() - l := len(c.Overflow) + l := len(c.overflow) if l == 0 { return l, nil @@ -286,7 +285,7 @@ func (c *Consumer) RedeliverOverflow(ctx context.Context) (int, error) { Type: api.BaseCommand_REDELIVER_UNACKNOWLEDGED_MESSAGES.Enum(), RedeliverUnacknowledgedMessages: &api.CommandRedeliverUnacknowledgedMessages{ ConsumerId: proto.Uint64(c.ConsumerID), - MessageIds: c.Overflow[i:end], + MessageIds: c.overflow[i:end], }, } @@ -295,8 +294,8 @@ func (c *Consumer) RedeliverOverflow(ctx context.Context) (int, error) { } } - // clear Overflow slice - c.Overflow = nil + // clear overflow slice + c.overflow = nil return l, nil } @@ -313,26 +312,26 @@ func (c *Consumer) HandleMessage(f frame.Frame) error { } select { - case c.Queue <- m: + case c.queue <- m: return nil default: - // Add messageId to Overflow buffer, avoiding duplicates. + // Add messageId to overflow buffer, avoiding duplicates. newMid := f.BaseCmd.GetMessage().GetMessageId() var dup bool - c.Omu.Lock() - for _, mid := range c.Overflow { + c.omu.Lock() + for _, mid := range c.overflow { if proto.Equal(mid, newMid) { dup = true break } } if !dup { - c.Overflow = append(c.Overflow, newMid) + c.overflow = append(c.overflow, newMid) } - c.Omu.Unlock() + c.omu.Unlock() - return fmt.Errorf("consumer message queue on topic %q is full (capacity = %d)", c.Topic, cap(c.Queue)) + return fmt.Errorf("consumer message queue on topic %q is full (capacity = %d)", c.Topic, cap(c.queue)) } } diff --git a/core/sub/consumer_test.go b/core/sub/consumer_test.go index 2114847..890273c 100644 --- a/core/sub/consumer_test.go +++ b/core/sub/consumer_test.go @@ -178,7 +178,7 @@ func TestConsumer_handleMessage_fullQueue(t *testing.T) { t.Fatalf("handleMessage() err = %v; expected nil for msg number %d and queueSize %d", err, i+1, queueSize) } - if got, expected := len(c.Overflow), 0; got != expected { + if got, expected := len(c.overflow), 0; got != expected { t.Fatalf("len(consumer overflow buffer) = %d; expected %d", got, expected) } } @@ -191,7 +191,7 @@ func TestConsumer_handleMessage_fullQueue(t *testing.T) { } t.Logf("handleMessage() err (expected) = %q for msg number %d and queueSize %d", err, queueSize+1, queueSize) - if got, expected := len(c.Overflow), 1; got != expected { + if got, expected := len(c.overflow), 1; got != expected { t.Fatalf("len(consumer overflow buffer) = %d; expected %d", got, expected) } @@ -302,7 +302,7 @@ func TestConsumer_RedeliverOverflow(t *testing.T) { entryID := uint64(i) // the msg.MessageIdData must be unique for each msg.Message, // otherwise the consumer will consider them duplicates - // and not store them in Overflow + // and not store them in overflow f := frame.Frame{ BaseCmd: &api.BaseCommand{ Type: api.BaseCommand_MESSAGE.Enum(), diff --git a/examples/consumer/consumer.go b/examples/consumer/consumer.go index 4192a7c..81eab67 100644 --- a/examples/consumer/consumer.go +++ b/examples/consumer/consumer.go @@ -33,9 +33,11 @@ var clientPool = manage.NewClientPool() func main() { ctx := context.Background() - consumerConf := manage.ConsumerConfig{ - ClientConfig: manage.ClientConfig{ - Addr: "localhost:6650", + consumerConf := manage.ManagedConsumerConfig{ + ManagedClientConfig: manage.ManagedClientConfig{ + ClientConfig: manage.ClientConfig{ + Addr: "localhost:6650", + }, }, Topic: "multi-topic-10", @@ -46,14 +48,14 @@ func main() { QueueSize: 5, } //mp := manage.NewManagedConsumer(clientPool, consumerConf) - mp, err := manage.NewPartitionManagedConsumer(clientPool, consumerConf) + mp, err := manage.NewManagedPartitionConsumer(clientPool, consumerConf) if err != nil { log.Fatal(err) } //mp2 := manage.NewManagedConsumer(clientPool, consumerConf) - mp2, err := manage.NewPartitionManagedConsumer(clientPool, consumerConf) + mp2, err := manage.NewManagedPartitionConsumer(clientPool, consumerConf) if err != nil { log.Fatal(err) } diff --git a/examples/producer/producer.go b/examples/producer/producer.go index 133a593..843b82c 100644 --- a/examples/producer/producer.go +++ b/examples/producer/producer.go @@ -33,9 +33,11 @@ var clientPool = manage.NewClientPool() func main() { ctx := context.Background() - producerConf := manage.ProducerConfig{ - ClientConfig: manage.ClientConfig{ - Addr: "localhost:6650", + producerConf := manage.ManagedProducerConfig{ + ManagedClientConfig: manage.ManagedClientConfig{ + ClientConfig: manage.ClientConfig{ + Addr: "localhost:6650", + }, }, Topic: "multi-topic-10", diff --git a/utils/util.go b/utils/util.go index e1741cf..8be5d00 100644 --- a/utils/util.go +++ b/utils/util.go @@ -15,56 +15,13 @@ package utils import ( "flag" - "math/rand" - "sync" "testing" - "time" - - "github.com/wolfstudy/pulsar-client-go/pkg/api" ) // ################ // helper functions // ################ -var ( - randStringChars = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") - randStringMu = new(sync.Mutex) //protects randStringRand, which isn't threadsafe - randStringRand = rand.New(rand.NewSource(time.Now().UnixNano())) -) - -// authMethodTLS is the name of the TLS authentication -// method, used in the CONNECT message. -const AuthMethodTLS = "tls" - -const ( - // ProtoVersion is the Pulsar protocol version - // used by this client. - ProtoVersion = int32(api.ProtocolVersion_v12) - - // ClientVersion is an opaque string sent - // by the client to the server on connect, eg: - // "Pulsar-Client-Java-v1.15.2" - ClientVersion = "pulsar-client-go" - - // NndefRequestID defines a RequestID of -1. - // - // Usage example: - // https://github.com/apache/incubator-pulsar/blob/fdc7b8426d8253c9437777ae51a4639239550f00/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ServerCnx.java#L325 - UndefRequestID = 1<<64 - 1 -) - -func RandString(n int) string { - b := make([]rune, n) - l := len(randStringChars) - randStringMu.Lock() - for i := range b { - b[i] = randStringChars[randStringRand.Intn(l)] - } - randStringMu.Unlock() - return string(b) -} - // pulsarAddr, if provided, is the Pulsar server to use for integration // tests (most likely Pulsar standalone running on localhost). var _PulsarAddr = flag.String("pulsar", "", "Address of Pulsar server to connect to. If blank, tests are skipped") From da1e0e04d3300c8d07df0dba028f5ae9d0be874b Mon Sep 17 00:00:00 2001 From: Evgeny Korolev Date: Fri, 8 Nov 2019 13:45:34 +0300 Subject: [PATCH 2/3] Add ClientConfig.UseTLS field, support pulsar+ssl:// scheme --- cmd/cli/README.md | 2 +- cmd/cli/main.go | 31 +++++++++++++++--------------- core/conn/conn.go | 10 ++++++++-- core/manage/client.go | 12 +++++++++--- core/manage/managed_client.go | 4 +++- core/manage/managed_client_pool.go | 8 +++++--- 6 files changed, 42 insertions(+), 25 deletions(-) diff --git a/cmd/cli/README.md b/cmd/cli/README.md index e3cd798..863f418 100644 --- a/cmd/cli/README.md +++ b/cmd/cli/README.md @@ -15,7 +15,7 @@ Usage of ./cli: -producer if true, produce messages, otherwise consume -pulsar string - pulsar address (default "localhost:6650") + pulsar address. May start with pulsar:// or pulsar+ssl:// (default "localhost:6650") -rate duration rate at which to send messages (default 1s) -shared diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 5aef046..2dafd5a 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -62,7 +62,7 @@ var args = struct { } func main() { - flag.StringVar(&args.pulsar, "pulsar", args.pulsar, "pulsar address") + flag.StringVar(&args.pulsar, "pulsar", args.pulsar, "pulsar address. May start with pulsar:// or pulsar+ssl://") flag.StringVar(&args.tlsCert, "tls-cert", args.tlsCert, "(optional) path to TLS certificate") flag.StringVar(&args.tlsKey, "tls-key", args.tlsKey, "(optional) path to TLS key") flag.StringVar(&args.tlsCA, "tls-ca", args.tlsKey, "(optional) path to root certificate") @@ -91,11 +91,11 @@ func main() { cancel() }() - var tlsCfg *tls.Config + tlsCfg := &tls.Config{ + InsecureSkipVerify: args.tlsSkipVerify, + } + if args.tlsCert != "" && args.tlsKey != "" { - tlsCfg = &tls.Config{ - InsecureSkipVerify: args.tlsSkipVerify, - } var err error cert, err := tls.LoadX509KeyPair(args.tlsCert, args.tlsKey) if err != nil { @@ -104,16 +104,6 @@ func main() { } tlsCfg.Certificates = []tls.Certificate{cert} - if args.tlsCA != "" { - rootCA, err := ioutil.ReadFile(args.tlsCA) - if err != nil { - fmt.Fprintln(os.Stderr, "error loading certificate authority:", err) - os.Exit(1) - } - tlsCfg.RootCAs = x509.NewCertPool() - tlsCfg.RootCAs.AppendCertsFromPEM(rootCA) - } - // Inspect certificate and print the CommonName attribute, // since this may be used for authorization if len(cert.Certificate[0]) > 0 { @@ -126,6 +116,16 @@ func main() { } } + if args.tlsCA != "" { + rootCA, err := ioutil.ReadFile(args.tlsCA) + if err != nil { + fmt.Fprintln(os.Stderr, "error loading certificate authority:", err) + os.Exit(1) + } + tlsCfg.RootCAs = x509.NewCertPool() + tlsCfg.RootCAs.AppendCertsFromPEM(rootCA) + } + mcp := manage.NewClientPool() switch args.producer { @@ -140,6 +140,7 @@ func main() { ManagedClientConfig: manage.ManagedClientConfig{ ClientConfig: manage.ClientConfig{ Addr: args.pulsar, + UseTLS: args.tlsCert != "" && args.tlsKey != "", TLSConfig: tlsCfg, Errs: asyncErrs, }, diff --git a/core/conn/conn.go b/core/conn/conn.go index 6c676fa..0718685 100644 --- a/core/conn/conn.go +++ b/core/conn/conn.go @@ -26,10 +26,15 @@ import ( "github.com/wolfstudy/pulsar-client-go/pkg/api" ) +const ( + SchemaPulsar = "pulsar://" + SchemaPulsarTSL = "pulsar+ssl://" +) + // NewTCPConn creates a core using a TCPv4 connection to the given // (pulsar server) address. func NewTCPConn(addr string, timeout time.Duration) (*Conn, error) { - addr = strings.TrimPrefix(addr, "pulsar://") + addr = strings.TrimPrefix(addr, SchemaPulsar) d := net.Dialer{ DualStack: false, @@ -50,7 +55,8 @@ func NewTCPConn(addr string, timeout time.Duration) (*Conn, error) { // NewTLSConn creates a core using a TCPv4+TLS connection to the given // (pulsar server) address. func NewTLSConn(addr string, tlsCfg *tls.Config, timeout time.Duration) (*Conn, error) { - addr = strings.TrimPrefix(addr, "pulsar://") + addr = strings.TrimPrefix(addr, SchemaPulsar) + addr = strings.TrimPrefix(addr, SchemaPulsarTSL) d := net.Dialer{ DualStack: false, diff --git a/core/manage/client.go b/core/manage/client.go index 47ad7e4..fb70d0b 100644 --- a/core/manage/client.go +++ b/core/manage/client.go @@ -17,6 +17,7 @@ import ( "context" "crypto/tls" "fmt" + "strings" "time" "github.com/wolfstudy/pulsar-client-go/core/conn" @@ -35,10 +36,11 @@ const authMethodTLS = "tls" // ClientConfig is used to configure a Pulsar client. type ClientConfig struct { - Addr string // pulsar broker address. May start with pulsar:// + Addr string // pulsar broker address. May start with pulsar:// or pulsar+ssl:// phyAddr string // if set, the TCP connection should be made using this address. This is only ever set during Topic Lookup DialTimeout time.Duration // timeout to use when establishing TCP connection - TLSConfig *tls.Config // TLS configuration. May be nil, in which case TLS will not be used + UseTLS bool // use TLS to connect pulsar. + TLSConfig *tls.Config // TLS configuration, applies with UseTLS == true. May be nil Errs chan<- error // asynchronous errors will be sent here. May be nil } @@ -59,6 +61,10 @@ func (c ClientConfig) setDefaults() ClientConfig { c.DialTimeout = 5 * time.Second } + if strings.HasPrefix(c.Addr, conn.SchemaPulsarTSL) { + c.UseTLS = true + } + return c } @@ -69,7 +75,7 @@ func NewClient(cfg ClientConfig) (*Client, error) { var cnx *conn.Conn var err error - if cfg.TLSConfig != nil { + if cfg.UseTLS { cnx, err = conn.NewTLSConn(cfg.connAddr(), cfg.TLSConfig, cfg.DialTimeout) } else { cnx, err = conn.NewTCPConn(cfg.connAddr(), cfg.DialTimeout) diff --git a/core/manage/managed_client.go b/core/manage/managed_client.go index 9eccdc1..1377104 100644 --- a/core/manage/managed_client.go +++ b/core/manage/managed_client.go @@ -35,6 +35,8 @@ type ManagedClientConfig struct { // setDefaults returns a modified config with appropriate zero values set to defaults. func (m ManagedClientConfig) setDefaults() ManagedClientConfig { + m.ClientConfig = m.ClientConfig.setDefaults() + if m.PingFrequency <= 0 { m.PingFrequency = 30 * time.Second // default used by Java client } @@ -186,7 +188,7 @@ func (m *ManagedClient) newClient(ctx context.Context) (*Client, error) { if m.cfg.phyAddr != m.cfg.Addr { proxyBrokerURL = m.cfg.Addr } - if m.cfg.TLSConfig != nil { + if m.cfg.UseTLS { _, err = client.ConnectTLS(ctx, proxyBrokerURL) } else { _, err = client.Connect(ctx, proxyBrokerURL) diff --git a/core/manage/managed_client_pool.go b/core/manage/managed_client_pool.go index d2aac21..7eac38e 100644 --- a/core/manage/managed_client_pool.go +++ b/core/manage/managed_client_pool.go @@ -20,6 +20,7 @@ import ( "sync" "time" + "github.com/wolfstudy/pulsar-client-go/core/conn" "github.com/wolfstudy/pulsar-client-go/pkg/api" ) @@ -55,10 +56,11 @@ type clientPoolKey struct { // First the cache is checked for an existing client. If one doesn't exist, // a new one is created and cached, then returned. func (m *ClientPool) Get(cfg ManagedClientConfig) *ManagedClient { + cfg = cfg.setDefaults() key := clientPoolKey{ - logicalAddr: strings.TrimPrefix(cfg.Addr, "pulsar://"), + logicalAddr: strings.TrimPrefix(strings.TrimPrefix(cfg.Addr, conn.SchemaPulsar), conn.SchemaPulsarTSL), dialTimeout: cfg.DialTimeout, - tls: cfg.TLSConfig != nil, + tls: cfg.UseTLS, pingFrequency: cfg.PingFrequency, pingTimeout: cfg.PingTimeout, connectTimeout: cfg.ConnectTimeout, @@ -139,7 +141,7 @@ func (m *ClientPool) ForTopic(ctx context.Context, cfg ManagedClientConfig, topi // Update configured address with address // provided in response - if cfg.TLSConfig != nil { + if cfg.UseTLS { cfg.Addr = lookupResp.GetBrokerServiceUrlTls() } else { cfg.Addr = lookupResp.GetBrokerServiceUrl() From eddcb7b229e615b4f20fce3d4942f9484ef003f9 Mon Sep 17 00:00:00 2001 From: Evgeny Korolev Date: Fri, 8 Nov 2019 14:14:44 +0300 Subject: [PATCH 3/3] Add Authentication interface and basic implementations from JavaClient --- cmd/cli/main.go | 23 ++++++------- core/auth/basic.go | 48 +++++++++++++++++++++++++++ core/auth/disabled.go | 45 ++++++++++++++++++++++++++ core/auth/interface.go | 26 +++++++++++++++ core/auth/tls.go | 34 +++++++++++++++++++ core/auth/token.go | 61 +++++++++++++++++++++++++++++++++++ core/conn/connector.go | 3 +- core/conn/connector_test.go | 10 +++--- core/manage/client.go | 51 +++++++++++++++++------------ core/manage/managed_client.go | 6 +--- 10 files changed, 262 insertions(+), 45 deletions(-) create mode 100644 core/auth/basic.go create mode 100644 core/auth/disabled.go create mode 100644 core/auth/interface.go create mode 100644 core/auth/tls.go create mode 100644 core/auth/token.go diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 2dafd5a..c047ff0 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -31,6 +31,7 @@ import ( "syscall" "time" + "github.com/wolfstudy/pulsar-client-go/core/auth" "github.com/wolfstudy/pulsar-client-go/core/manage" "github.com/wolfstudy/pulsar-client-go/core/msg" ) @@ -91,23 +92,18 @@ func main() { cancel() }() + var authentication auth.Authentication tlsCfg := &tls.Config{ InsecureSkipVerify: args.tlsSkipVerify, } if args.tlsCert != "" && args.tlsKey != "" { - var err error - cert, err := tls.LoadX509KeyPair(args.tlsCert, args.tlsKey) - if err != nil { - fmt.Fprintln(os.Stderr, "error loading certificates:", err) - os.Exit(1) - } - tlsCfg.Certificates = []tls.Certificate{cert} + authentication = auth.NewAuthenticationTLS(args.tlsCert, args.tlsKey) // Inspect certificate and print the CommonName attribute, // since this may be used for authorization - if len(cert.Certificate[0]) > 0 { - x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if certs := authentication.GetAuthData().GetTlsCertificates(); len(certs) > 0 && len(certs[0].Certificate) > 0 && len(certs[0].Certificate[0]) > 0 { + x509Cert, err := x509.ParseCertificate(certs[0].Certificate[0]) if err != nil { fmt.Fprintln(os.Stderr, "error loading public certificate:", err) os.Exit(1) @@ -139,10 +135,11 @@ func main() { MaxReconnectDelay: time.Minute, ManagedClientConfig: manage.ManagedClientConfig{ ClientConfig: manage.ClientConfig{ - Addr: args.pulsar, - UseTLS: args.tlsCert != "" && args.tlsKey != "", - TLSConfig: tlsCfg, - Errs: asyncErrs, + Addr: args.pulsar, + UseTLS: args.tlsCert != "" && args.tlsKey != "", + TLSConfig: tlsCfg, + Authentication: authentication, + Errs: asyncErrs, }, }, } diff --git a/core/auth/basic.go b/core/auth/basic.go new file mode 100644 index 0000000..1936cb9 --- /dev/null +++ b/core/auth/basic.go @@ -0,0 +1,48 @@ +package auth + +import ( + "encoding/base64" + "net/http" +) + +func NewAuthenticationBasic(userId, password string) Authentication { + return authenticationBasic{userId: userId, password: password} +} + +type authenticationBasic struct { + userId, password string +} + +func (a authenticationBasic) GetAuthMethodName() string { + return "basic" +} +func (a authenticationBasic) GetAuthData() AuthenticationDataProvider { + commandAuthToken := []byte(a.userId + ":" + a.password) + httpAuthToken := "Basic " + base64.StdEncoding.EncodeToString(commandAuthToken) + return authenticationDataBasic{ + httpAuthToken: httpAuthToken, + commandAuthToken: commandAuthToken, + } +} + +type authenticationDataBasic struct { + authenticationDataNull + httpAuthToken string + commandAuthToken []byte +} + +func (adBasic authenticationDataBasic) HasDataForHttp() bool { + return true +} +func (adBasic authenticationDataBasic) GetHttpHeaders() http.Header { + return http.Header{ + "Authorization": []string{adBasic.httpAuthToken}, + } +} + +func (adBasic authenticationDataBasic) HasDataFromCommand() bool { + return true +} +func (adBasic authenticationDataBasic) GetCommandData() []byte { + return adBasic.commandAuthToken +} diff --git a/core/auth/disabled.go b/core/auth/disabled.go new file mode 100644 index 0000000..656070f --- /dev/null +++ b/core/auth/disabled.go @@ -0,0 +1,45 @@ +package auth + +import ( + "crypto/tls" + "net/http" +) + +func NewAuthenticationDisabled() Authentication { + return authenticationDisabled{} +} + +type authenticationDisabled struct{} + +func (a authenticationDisabled) GetAuthMethodName() string { + return "" +} +func (a authenticationDisabled) GetAuthData() AuthenticationDataProvider { + return authenticationDataNull{} +} + +type authenticationDataNull struct{} + +func (adNull authenticationDataNull) HasDataForTls() bool { + return false +} +func (adNull authenticationDataNull) GetTlsCertificates() []tls.Certificate { + return nil +} + +func (adNull authenticationDataNull) HasDataForHttp() bool { + return false +} +func (adNull authenticationDataNull) GetHttpAuthType() string { + return "" +} +func (adNull authenticationDataNull) GetHttpHeaders() http.Header { + return nil +} + +func (adNull authenticationDataNull) HasDataFromCommand() bool { + return false +} +func (adNull authenticationDataNull) GetCommandData() []byte { + return nil +} diff --git a/core/auth/interface.go b/core/auth/interface.go new file mode 100644 index 0000000..a78059c --- /dev/null +++ b/core/auth/interface.go @@ -0,0 +1,26 @@ +package auth + +import ( + "crypto/tls" + "net/http" +) + +type ( + Authentication interface { + GetAuthMethodName() string + GetAuthData() AuthenticationDataProvider + } + + AuthenticationDataProvider interface { + HasDataForTls() bool + GetTlsCertificates() []tls.Certificate + // GetTslPrivateKey is redundant due to Go TLS implementation + + HasDataForHttp() bool + GetHttpAuthType() string + GetHttpHeaders() http.Header + + HasDataFromCommand() bool + GetCommandData() []byte + } +) diff --git a/core/auth/tls.go b/core/auth/tls.go new file mode 100644 index 0000000..2c99cf7 --- /dev/null +++ b/core/auth/tls.go @@ -0,0 +1,34 @@ +package auth + +import "crypto/tls" + +func NewAuthenticationTLS(certFile, keyFile string) Authentication { + return authenticationTls{certFile: certFile, keyFile: keyFile} +} + +type authenticationTls struct { + certFile, keyFile string +} + +func (a authenticationTls) GetAuthMethodName() string { + return "tls" +} +func (a authenticationTls) GetAuthData() AuthenticationDataProvider { + if certificate, err := tls.LoadX509KeyPair(a.certFile, a.keyFile); err == nil { + return authenticationDataTls{certificates: []tls.Certificate{certificate}} + } else { + panic(err) + } +} + +type authenticationDataTls struct { + authenticationDataNull + certificates []tls.Certificate +} + +func (adTls authenticationDataTls) HasDataForTls() bool { + return true +} +func (adTls authenticationDataTls) GetTlsCertificates() []tls.Certificate { + return adTls.certificates +} diff --git a/core/auth/token.go b/core/auth/token.go new file mode 100644 index 0000000..9ca0df7 --- /dev/null +++ b/core/auth/token.go @@ -0,0 +1,61 @@ +package auth + +import ( + "io/ioutil" + "net/http" + "strings" +) + +type AuthenticationTokenSupplier func() []byte + +func NewAuthenticationTokenFromSupplier(tokenSupplier AuthenticationTokenSupplier) Authentication { + return authenticationToken{supplier: tokenSupplier} +} +func NewAuthenticationTokenFromString(token string) Authentication { + token = strings.TrimPrefix(token, "token:") + return authenticationToken{supplier: func() []byte { + return []byte(token) + }} +} +func NewAuthenticationTokenFromFile(fileName string) Authentication { + fileName = strings.TrimPrefix(fileName, "file:") + return authenticationToken{supplier: func() []byte { + if content, err := ioutil.ReadFile(fileName); err == nil { + return content + } else { + panic(err) + } + }} +} + +type authenticationToken struct { + supplier AuthenticationTokenSupplier +} + +func (a authenticationToken) GetAuthMethodName() string { + return "token" +} +func (a authenticationToken) GetAuthData() AuthenticationDataProvider { + return authenticationDataToken{supplier: a.supplier} +} + +type authenticationDataToken struct { + authenticationDataNull + supplier AuthenticationTokenSupplier +} + +func (adToken authenticationDataToken) HasDataForHttp() bool { + return true +} +func (adToken authenticationDataToken) GetHttpHeaders() http.Header { + return http.Header{ + "Authorization": []string{"Bearer " + string(adToken.supplier())}, + } +} + +func (adToken authenticationDataToken) HasDataFromCommand() bool { + return true +} +func (adToken authenticationDataToken) GetCommandData() []byte { + return adToken.supplier() +} diff --git a/core/conn/connector.go b/core/conn/connector.go index 4f2fddb..e9fd031 100644 --- a/core/conn/connector.go +++ b/core/conn/connector.go @@ -63,7 +63,7 @@ type Connector struct { // The provided context should have a timeout associated with it. // // It's required to have completed Connect/Connected before using the client. -func (c *Connector) Connect(ctx context.Context, authMethod, proxyBrokerURL string) (*api.CommandConnected, error) { +func (c *Connector) Connect(ctx context.Context, authMethod string, authData []byte, proxyBrokerURL string) (*api.CommandConnected, error) { resp, cancel, err := c.dispatcher.RegisterGlobal() if err != nil { return nil, err @@ -88,6 +88,7 @@ func (c *Connector) Connect(ctx context.Context, authMethod, proxyBrokerURL stri } if authMethod != "" { connect.AuthMethodName = proto.String(authMethod) + connect.AuthData = authData } if proxyBrokerURL != "" { connect.ProxyToBrokerUrl = proto.String(proxyBrokerURL) diff --git a/core/conn/connector_test.go b/core/conn/connector_test.go index 85e4c4a..94e5d5d 100644 --- a/core/conn/connector_test.go +++ b/core/conn/connector_test.go @@ -40,7 +40,7 @@ func TestConnector(t *testing.T) { go func() { var r response - r.success, r.err = c.Connect(ctx, "", "") + r.success, r.err = c.Connect(ctx, "", nil, "") resps <- r }() @@ -94,7 +94,7 @@ func TestConnector_Timeout(t *testing.T) { go func() { var r response - r.success, r.err = c.Connect(ctx, "", "") + r.success, r.err = c.Connect(ctx, "", nil, "") resps <- r }() @@ -148,7 +148,7 @@ func TestConnector_Error(t *testing.T) { go func() { var r response - r.success, r.err = c.Connect(ctx, "", "") + r.success, r.err = c.Connect(ctx, "", nil, "") resps <- r }() @@ -208,13 +208,13 @@ func TestConnector_Outstanding(t *testing.T) { defer cancel() // perform 1st connect - go c.Connect(ctx, "", "") + go c.Connect(ctx, "", nil, "") time.Sleep(100 * time.Millisecond) // Additional attempts to connect while there's // an outstanding one should cause an error - if _, err := c.Connect(ctx, "", ""); err == nil { + if _, err := c.Connect(ctx, "", nil, ""); err == nil { t.Fatalf("connector.connect() err = %v; expected non-nil because of outstanding request", err) } else { t.Logf("connector.connect() err = %v", err) diff --git a/core/manage/client.go b/core/manage/client.go index fb70d0b..09040b6 100644 --- a/core/manage/client.go +++ b/core/manage/client.go @@ -20,6 +20,7 @@ import ( "strings" "time" + "github.com/wolfstudy/pulsar-client-go/core/auth" "github.com/wolfstudy/pulsar-client-go/core/conn" "github.com/wolfstudy/pulsar-client-go/core/frame" "github.com/wolfstudy/pulsar-client-go/core/msg" @@ -36,12 +37,13 @@ const authMethodTLS = "tls" // ClientConfig is used to configure a Pulsar client. type ClientConfig struct { - Addr string // pulsar broker address. May start with pulsar:// or pulsar+ssl:// - phyAddr string // if set, the TCP connection should be made using this address. This is only ever set during Topic Lookup - DialTimeout time.Duration // timeout to use when establishing TCP connection - UseTLS bool // use TLS to connect pulsar. - TLSConfig *tls.Config // TLS configuration, applies with UseTLS == true. May be nil - Errs chan<- error // asynchronous errors will be sent here. May be nil + Addr string // pulsar broker address. May start with pulsar:// or pulsar+ssl:// + phyAddr string // if set, the TCP connection should be made using this address. This is only ever set during Topic Lookup + DialTimeout time.Duration // timeout to use when establishing TCP connection + UseTLS bool // use TLS to connect pulsar. + TLSConfig *tls.Config // TLS configuration, applies with UseTLS == true. May be nil + Authentication auth.Authentication // authentication provider. May be nil + Errs chan<- error // asynchronous errors will be sent here. May be nil } // connAddr returns the address that should be used @@ -65,6 +67,10 @@ func (c ClientConfig) setDefaults() ClientConfig { c.UseTLS = true } + if c.Authentication == nil { + c.Authentication = auth.NewAuthenticationDisabled() + } + return c } @@ -76,6 +82,13 @@ func NewClient(cfg ClientConfig) (*Client, error) { var err error if cfg.UseTLS { + if data := cfg.Authentication.GetAuthData(); data.HasDataForTls() { + if cfg.TLSConfig == nil { + cfg.TLSConfig = &tls.Config{} + } + cfg.TLSConfig.Certificates = data.GetTlsCertificates() + } + cnx, err = conn.NewTLSConn(cfg.connAddr(), cfg.TLSConfig, cfg.DialTimeout) } else { cnx, err = conn.NewTCPConn(cfg.connAddr(), cfg.DialTimeout) @@ -93,6 +106,8 @@ func NewClient(cfg ClientConfig) (*Client, error) { c: cnx, asyncErrs: utils.AsyncErrors(cfg.Errs), + authentication: cfg.Authentication, + dispatcher: dispatcher, subscriptions: subs, connector: conn.NewConnector(cnx, dispatcher), @@ -128,6 +143,8 @@ type Client struct { c *conn.Conn asyncErrs utils.AsyncErrors + authentication auth.Authentication + dispatcher *frame.Dispatcher subscriptions *Subscriptions @@ -155,27 +172,19 @@ func (c *Client) Close() error { // waits for either a CONNECTED response or the context to // timeout. Connect should be called immediately after // creating a client, before sending any other messages. -// The "auth method" is not set in the CONNECT message. -// See ConnectTLS for TLS auth method. // The proxyBrokerURL may be blank, or it can be used to indicate // that the client is connecting through a proxy server. // See "Connection establishment" for more info: // https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#Connectionestablishment-6pslvw func (c *Client) Connect(ctx context.Context, proxyBrokerURL string) (*api.CommandConnected, error) { - return c.connector.Connect(ctx, "", proxyBrokerURL) -} + authMethod := c.authentication.GetAuthMethodName() -// ConnectTLS sends a Connect message to the Pulsar server, then -// waits for either a CONNECTED response or the context to -// timeout. Connect should be called immediately after -// creating a client, before sending any other messages. -// The "auth method" is set to tls in the CONNECT message. -// The proxyBrokerURL may be blank, or it can be used to indicate -// that the client is connecting through a proxy server. -// See "Connection establishment" for more info: -// https://pulsar.incubator.apache.org/docs/latest/project/BinaryProtocol/#Connectionestablishment-6pslvw -func (c *Client) ConnectTLS(ctx context.Context, proxyBrokerURL string) (*api.CommandConnected, error) { - return c.connector.Connect(ctx, authMethodTLS, proxyBrokerURL) + var authData []byte + if data := c.authentication.GetAuthData(); data.HasDataFromCommand() { + authData = data.GetCommandData() + } + + return c.connector.Connect(ctx, authMethod, authData, proxyBrokerURL) } // Ping sends a PING message to the Pulsar server, then diff --git a/core/manage/managed_client.go b/core/manage/managed_client.go index 1377104..059cbdd 100644 --- a/core/manage/managed_client.go +++ b/core/manage/managed_client.go @@ -188,11 +188,7 @@ func (m *ManagedClient) newClient(ctx context.Context) (*Client, error) { if m.cfg.phyAddr != m.cfg.Addr { proxyBrokerURL = m.cfg.Addr } - if m.cfg.UseTLS { - _, err = client.ConnectTLS(ctx, proxyBrokerURL) - } else { - _, err = client.Connect(ctx, proxyBrokerURL) - } + _, err = client.Connect(ctx, proxyBrokerURL) if err != nil { _ = client.Close()