diff --git a/lib/AckGroupingTracker.cc b/lib/AckGroupingTracker.cc deleted file mode 100644 index 8a9ea0df..00000000 --- a/lib/AckGroupingTracker.cc +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "AckGroupingTracker.h" - -#include -#include -#include - -#include "BitSet.h" -#include "ChunkMessageIdImpl.h" -#include "ClientConnection.h" -#include "Commands.h" -#include "LogUtils.h" -#include "MessageIdImpl.h" - -namespace pulsar { - -DECLARE_LOG_OBJECT(); - -void AckGroupingTracker::doImmediateAck(const MessageId& msgId, const ResultCallback& callback, - CommandAck_AckType ackType) const { - const auto cnx = connectionSupplier_(); - if (!cnx) { - LOG_DEBUG("Connection is not ready, ACK failed for " << msgId); - if (callback) { - callback(ResultAlreadyClosed); - } - return; - } - if (ackType == CommandAck_AckType_Individual) { - // If it's individual ack, we need to acknowledge all message IDs in a chunked message Id - // If it's cumulative ack, we only need to ack the last message ID of a chunked message. - // ChunkedMessageId return last chunk message ID by default, so we don't need to handle it. - if (auto chunkMessageId = - std::dynamic_pointer_cast(Commands::getMessageIdImpl(msgId))) { - auto msgIdList = chunkMessageId->getChunkedMessageIds(); - doImmediateAck(std::set(msgIdList.begin(), msgIdList.end()), callback); - return; - } - } - const auto& ackSet = Commands::getMessageIdImpl(msgId)->getBitSet(); - if (waitResponse_) { - const auto requestId = requestIdSupplier_(); - cnx->sendRequestWithId( - Commands::newAck(consumerId_, msgId.ledgerId(), msgId.entryId(), ackSet, ackType, requestId), - requestId) - .addListener([callback](Result result, const ResponseData&) { - if (callback) { - callback(result); - } - }); - } else { - cnx->sendCommand(Commands::newAck(consumerId_, msgId.ledgerId(), msgId.entryId(), ackSet, ackType)); - if (callback) { - callback(ResultOk); - } - } -} - -static std::ostream& operator<<(std::ostream& os, const std::set& msgIds) { - bool first = true; - for (auto&& msgId : msgIds) { - if (first) { - first = false; - } else { - os << ", "; - } - os << "[" << msgId << "]"; - } - return os; -} - -void AckGroupingTracker::doImmediateAck(const std::set& msgIds, - const ResultCallback& callback) const { - const auto cnx = connectionSupplier_(); - if (!cnx) { - LOG_DEBUG("Connection is not ready, ACK failed for " << msgIds); - if (callback) { - callback(ResultAlreadyClosed); - } - return; - } - - std::set ackMsgIds; - - for (const auto& msgId : msgIds) { - if (auto chunkMessageId = - std::dynamic_pointer_cast(Commands::getMessageIdImpl(msgId))) { - auto msgIdList = chunkMessageId->getChunkedMessageIds(); - ackMsgIds.insert(msgIdList.begin(), msgIdList.end()); - } else { - ackMsgIds.insert(msgId); - } - } - - if (Commands::peerSupportsMultiMessageAcknowledgement(cnx->getServerProtocolVersion())) { - if (waitResponse_) { - const auto requestId = requestIdSupplier_(); - cnx->sendRequestWithId(Commands::newMultiMessageAck(consumerId_, ackMsgIds, requestId), requestId) - .addListener([callback](Result result, const ResponseData&) { - if (callback) { - callback(result); - } - }); - } else { - cnx->sendCommand(Commands::newMultiMessageAck(consumerId_, ackMsgIds)); - if (callback) { - callback(ResultOk); - } - } - } else { - auto count = std::make_shared>(ackMsgIds.size()); - auto wrappedCallback = [callback, count](Result result) { - if (--*count == 0 && callback) { - callback(result); - } - }; - for (auto&& msgId : ackMsgIds) { - doImmediateAck(msgId, wrappedCallback, CommandAck_AckType_Individual); - } - } -} - -} // namespace pulsar diff --git a/lib/AckGroupingTracker.h b/lib/AckGroupingTracker.h index f62492f3..d00c3a27 100644 --- a/lib/AckGroupingTracker.h +++ b/lib/AckGroupingTracker.h @@ -22,11 +22,7 @@ #include #include -#include #include -#include - -#include "ProtoApiEnums.h" namespace pulsar { @@ -34,6 +30,9 @@ class ClientConnection; using ClientConnectionPtr = std::shared_ptr; using ClientConnectionWeakPtr = std::weak_ptr; using ResultCallback = std::function; +class ConsumerImpl; +using ConsumerImplPtr = std::shared_ptr; +using ConsumerImplWeakPtr = std::weak_ptr; /** * @class AckGroupingTracker @@ -42,19 +41,12 @@ using ResultCallback = std::function; */ class AckGroupingTracker : public std::enable_shared_from_this { public: - AckGroupingTracker(std::function connectionSupplier, - std::function requestIdSupplier, uint64_t consumerId, bool waitResponse) - : connectionSupplier_(std::move(connectionSupplier)), - requestIdSupplier_(std::move(requestIdSupplier)), - consumerId_(consumerId), - waitResponse_(waitResponse) {} - virtual ~AckGroupingTracker() = default; /** * Start tracking the ACK requests. */ - virtual void start() {} + virtual void start(const ConsumerImplPtr& consumer) { consumer_ = consumer; } /** * Since ACK requests are grouped and delayed, we need to do some best-effort duplicate check to @@ -72,7 +64,9 @@ class AckGroupingTracker : public std::enable_shared_from_this& msgIds, const ResultCallback& callback) const; - - private: - const std::function connectionSupplier_; - const std::function requestIdSupplier_; - const uint64_t consumerId_; + virtual void close() {} protected: - const bool waitResponse_; + ConsumerImplWeakPtr consumer_; }; // class AckGroupingTracker diff --git a/lib/AckGroupingTrackerDisabled.cc b/lib/AckGroupingTrackerDisabled.cc index d20de44a..92cdfba4 100644 --- a/lib/AckGroupingTrackerDisabled.cc +++ b/lib/AckGroupingTrackerDisabled.cc @@ -19,26 +19,41 @@ #include "AckGroupingTrackerDisabled.h" -#include "ProtoApiEnums.h" +#include "ConsumerImpl.h" namespace pulsar { void AckGroupingTrackerDisabled::addAcknowledge(const MessageId& msgId, const ResultCallback& callback) { - doImmediateAck(msgId, callback, CommandAck_AckType_Individual); + auto consumer = consumer_.lock(); + if (consumer && !consumer->isClosingOrClosed()) { + consumer->doImmediateAck(msgId, callback, CommandAck_AckType_Individual); + } else if (callback) { + callback(ResultAlreadyClosed); + } } void AckGroupingTrackerDisabled::addAcknowledgeList(const MessageIdList& msgIds, const ResultCallback& callback) { - std::set msgIdSet; - for (auto&& msgId : msgIds) { - msgIdSet.emplace(msgId); + auto consumer = consumer_.lock(); + if (consumer && !consumer->isClosingOrClosed()) { + std::set uniqueMsgIds(msgIds.begin(), msgIds.end()); + for (auto&& msgId : msgIds) { + uniqueMsgIds.insert(msgId); + } + consumer->doImmediateAck(uniqueMsgIds, callback); + } else if (callback) { + callback(ResultAlreadyClosed); } - doImmediateAck(msgIdSet, callback); } void AckGroupingTrackerDisabled::addAcknowledgeCumulative(const MessageId& msgId, const ResultCallback& callback) { - doImmediateAck(msgId, callback, CommandAck_AckType_Cumulative); + auto consumer = consumer_.lock(); + if (consumer && !consumer->isClosingOrClosed()) { + consumer->doImmediateAck(msgId, callback, CommandAck_AckType_Cumulative); + } else if (callback) { + callback(ResultAlreadyClosed); + } } } // namespace pulsar diff --git a/lib/AckGroupingTrackerEnabled.cc b/lib/AckGroupingTrackerEnabled.cc index d88426e2..3a2a35d7 100644 --- a/lib/AckGroupingTrackerEnabled.cc +++ b/lib/AckGroupingTrackerEnabled.cc @@ -23,11 +23,8 @@ #include #include -#include "ClientConnection.h" -#include "ClientImpl.h" -#include "Commands.h" +#include "ConsumerImpl.h" #include "ExecutorService.h" -#include "HandlerBase.h" #include "MessageIdUtil.h" namespace pulsar { @@ -45,7 +42,10 @@ static int compare(const MessageId& lhs, const MessageId& rhs) { } } -void AckGroupingTrackerEnabled::start() { this->scheduleTimer(); } +void AckGroupingTrackerEnabled::start(const ConsumerImplPtr& consumer) { + AckGroupingTracker::start(consumer); + this->scheduleTimer(); +} bool AckGroupingTrackerEnabled::isDuplicate(const MessageId& msgId) { { @@ -62,6 +62,13 @@ bool AckGroupingTrackerEnabled::isDuplicate(const MessageId& msgId) { } void AckGroupingTrackerEnabled::addAcknowledge(const MessageId& msgId, const ResultCallback& callback) { + auto consumer = consumer_.lock(); + if (!consumer || consumer->isClosingOrClosed()) { + if (callback) { + callback(ResultAlreadyClosed); + } + return; + } std::lock_guard lock(this->rmutexPendingIndAcks_); this->pendingIndividualAcks_.insert(msgId); if (waitResponse_) { @@ -70,12 +77,19 @@ void AckGroupingTrackerEnabled::addAcknowledge(const MessageId& msgId, const Res callback(ResultOk); } if (this->ackGroupingMaxSize_ > 0 && this->pendingIndividualAcks_.size() >= this->ackGroupingMaxSize_) { - this->flush(); + this->flush(consumer); } } void AckGroupingTrackerEnabled::addAcknowledgeList(const MessageIdList& msgIds, const ResultCallback& callback) { + auto consumer = consumer_.lock(); + if (!consumer || consumer->isClosingOrClosed()) { + if (callback) { + callback(ResultAlreadyClosed); + } + return; + } std::lock_guard lock(this->rmutexPendingIndAcks_); for (const auto& msgId : msgIds) { this->pendingIndividualAcks_.emplace(msgId); @@ -86,12 +100,19 @@ void AckGroupingTrackerEnabled::addAcknowledgeList(const MessageIdList& msgIds, callback(ResultOk); } if (this->ackGroupingMaxSize_ > 0 && this->pendingIndividualAcks_.size() >= this->ackGroupingMaxSize_) { - this->flush(); + this->flush(consumer); } } void AckGroupingTrackerEnabled::addAcknowledgeCumulative(const MessageId& msgId, const ResultCallback& callback) { + auto consumer = consumer_.lock(); + if (!consumer || consumer->isClosingOrClosed()) { + if (callback) { + callback(ResultAlreadyClosed); + } + return; + } std::unique_lock lock(this->mutexCumulativeAckMsgId_); bool completeCallback = true; if (compare(msgId, this->nextCumulativeAckMsgId_) > 0) { @@ -115,23 +136,28 @@ void AckGroupingTrackerEnabled::addAcknowledgeCumulative(const MessageId& msgId, callback(ResultOk); } } - AckGroupingTrackerEnabled::~AckGroupingTrackerEnabled() { - isClosed_ = true; - this->flush(); std::lock_guard lock(this->mutexTimer_); if (this->timer_) { cancelTimer(*this->timer_); } } -void AckGroupingTrackerEnabled::flush() { +void AckGroupingTrackerEnabled::close() { + flushAndClean(); + std::lock_guard lock(this->mutexTimer_); + if (this->timer_) { + cancelTimer(*this->timer_); + } +} + +void AckGroupingTrackerEnabled::flush(const ConsumerImplPtr& consumer) { // Send ACK for cumulative ACK requests. { std::lock_guard lock(this->mutexCumulativeAckMsgId_); if (this->requireCumulativeAck_) { - this->doImmediateAck(this->nextCumulativeAckMsgId_, this->latestCumulativeCallback_, - CommandAck_AckType_Cumulative); + consumer->doImmediateAck(this->nextCumulativeAckMsgId_, this->latestCumulativeCallback_, + CommandAck_AckType_Cumulative); this->latestCumulativeCallback_ = nullptr; this->requireCumulativeAck_ = false; } @@ -147,13 +173,17 @@ void AckGroupingTrackerEnabled::flush() { callback(result); } }; - this->doImmediateAck(this->pendingIndividualAcks_, callback); + consumer->doImmediateAck(this->pendingIndividualAcks_, callback); this->pendingIndividualAcks_.clear(); } } void AckGroupingTrackerEnabled::flushAndClean() { - this->flush(); + auto consumer = consumer_.lock(); + if (!consumer) { + return; + } + this->flush(consumer); { std::lock_guard lock(this->mutexCumulativeAckMsgId_); this->nextCumulativeAckMsgId_ = MessageId::earliest(); @@ -165,10 +195,6 @@ void AckGroupingTrackerEnabled::flushAndClean() { } void AckGroupingTrackerEnabled::scheduleTimer() { - if (isClosed_) { - return; - } - std::lock_guard lock(this->mutexTimer_); this->timer_ = this->executor_->createDeadlineTimer(); this->timer_->expires_after(std::chrono::milliseconds(std::max(1L, this->ackGroupingTimeMs_))); @@ -176,7 +202,11 @@ void AckGroupingTrackerEnabled::scheduleTimer() { this->timer_->async_wait([this, weakSelf](const ASIO_ERROR& ec) -> void { auto self = weakSelf.lock(); if (self && !ec) { - this->flush(); + auto consumer = consumer_.lock(); + if (!consumer || consumer->isClosingOrClosed()) { + return; + } + this->flush(consumer); this->scheduleTimer(); } }); diff --git a/lib/AckGroupingTrackerEnabled.h b/lib/AckGroupingTrackerEnabled.h index 5eb04b98..eb2b449f 100644 --- a/lib/AckGroupingTrackerEnabled.h +++ b/lib/AckGroupingTrackerEnabled.h @@ -21,8 +21,6 @@ #include -#include -#include #include #include @@ -35,9 +33,6 @@ class ClientImpl; using ClientImplPtr = std::shared_ptr; class ExecutorService; using ExecutorServicePtr = std::shared_ptr; -class HandlerBase; -using HandlerBasePtr = std::shared_ptr; -using HandlerBaseWeakPtr = std::weak_ptr; /** * @class AckGroupingTrackerEnabled @@ -45,34 +40,31 @@ using HandlerBaseWeakPtr = std::weak_ptr; */ class AckGroupingTrackerEnabled : public AckGroupingTracker { public: - AckGroupingTrackerEnabled(const std::function& connectionSupplier, - const std::function& requestIdSupplier, uint64_t consumerId, - bool waitResponse, long ackGroupingTimeMs, long ackGroupingMaxSize, + AckGroupingTrackerEnabled(long ackGroupingTimeMs, long ackGroupingMaxSize, bool waitResponse, const ExecutorServicePtr& executor) - : AckGroupingTracker(connectionSupplier, requestIdSupplier, consumerId, waitResponse), - ackGroupingTimeMs_(ackGroupingTimeMs), + : ackGroupingTimeMs_(ackGroupingTimeMs), ackGroupingMaxSize_(ackGroupingMaxSize), + waitResponse_(waitResponse), executor_(executor) { pendingIndividualCallbacks_.reserve(ackGroupingMaxSize); } ~AckGroupingTrackerEnabled(); - void start() override; + void start(const ConsumerImplPtr& consumer) override; bool isDuplicate(const MessageId& msgId) override; void addAcknowledge(const MessageId& msgId, const ResultCallback& callback) override; void addAcknowledgeList(const MessageIdList& msgIds, const ResultCallback& callback) override; void addAcknowledgeCumulative(const MessageId& msgId, const ResultCallback& callback) override; - void flush(); void flushAndClean() override; + void close() override; + + private: + void flush(const ConsumerImplPtr& consumer); protected: - //! Method for scheduling grouping timer. void scheduleTimer(); - //! State - std::atomic_bool isClosed_{false}; - //! Next message ID to be cumulatively cumulatively. MessageId nextCumulativeAckMsgId_{MessageId::earliest()}; bool requireCumulativeAck_{false}; @@ -90,6 +82,8 @@ class AckGroupingTrackerEnabled : public AckGroupingTracker { //! Max number of ACK requests can be grouped. const long ackGroupingMaxSize_; + const bool waitResponse_; + //! ACK request sender's scheduled executor. const ExecutorServicePtr executor_; diff --git a/lib/ConsumerImpl.cc b/lib/ConsumerImpl.cc index 325addaa..92d25cb0 100644 --- a/lib/ConsumerImpl.cc +++ b/lib/ConsumerImpl.cc @@ -74,6 +74,22 @@ static boost::optional getStartMessageId(const boost::optional newAckGroupingTracker(const std::string& topic, + const ConsumerConfiguration& config, + const ClientImplPtr& client) { + if (TopicName::get(topic)->isPersistent()) { + if (config.getAckGroupingTimeMs() > 0) { + return std::make_shared( + config.getAckGroupingTimeMs(), config.getAckGroupingMaxSize(), config.isAckReceiptEnabled(), + client->getIOExecutorProvider()->get()); + } else { + return std::make_shared(); + } + } else { + return std::make_shared(); + } +} + ConsumerImpl::ConsumerImpl(const ClientImplPtr& client, const std::string& topic, const std::string& subscriptionName, const ConsumerConfiguration& conf, bool isPersistent, const ConsumerInterceptorsPtr& interceptors, @@ -105,12 +121,14 @@ ConsumerImpl::ConsumerImpl(const ClientImplPtr& client, const std::string& topic consumerStr_("[" + topic + ", " + subscriptionName + ", " + std::to_string(consumerId_) + "] "), messageListenerRunning_(!conf.isStartPaused()), negativeAcksTracker_(std::make_shared(client, *this, conf)), + ackGroupingTrackerPtr_(newAckGroupingTracker(topic, conf, client)), readCompacted_(conf.isReadCompacted()), startMessageId_(getStartMessageId(startMessageId, conf.isStartMessageIdInclusive())), maxPendingChunkedMessage_(conf.getMaxPendingChunkedMessage()), autoAckOldestChunkedMessageOnQueueFull_(conf.isAutoAckOldestChunkedMessageOnQueueFull()), expireTimeOfIncompleteChunkedMessageMs_(conf.getExpireTimeOfIncompleteChunkedMessageMs()), - interceptors_(interceptors) { + interceptors_(interceptors), + requestIdGenerator_(client->getRequestIdGenerator()) { // Initialize un-ACKed messages OT tracker. if (conf.getUnAckedMessagesTimeoutMs() != 0) { if (conf.getTickDurationInMs() > 0) { @@ -169,9 +187,8 @@ ConsumerImpl::~ConsumerImpl() { LOG_WARN(consumerStr_ << "Destroyed consumer which was not properly closed"); ClientConnectionPtr cnx = getCnx().lock(); - ClientImplPtr client = client_.lock(); - if (client && cnx) { - int requestId = client->newRequestId(); + if (cnx) { + auto requestId = newRequestId(); cnx->sendRequestWithId(Commands::newCloseConsumer(consumerId_, requestId), requestId); cnx->removeConsumer(consumerId_); LOG_INFO(consumerStr_ << "Closed consumer for race condition: " << consumerId_); @@ -186,8 +203,6 @@ void ConsumerImpl::setPartitionIndex(int partitionIndex) { partitionIndex_ = par int ConsumerImpl::getPartitionIndex() { return partitionIndex_; } -uint64_t ConsumerImpl::getConsumerId() { return consumerId_; } - Future ConsumerImpl::getConsumerCreatedFuture() { return consumerCreatedPromise_.getFuture(); } @@ -198,38 +213,7 @@ const std::string& ConsumerImpl::getTopic() const { return topic(); } void ConsumerImpl::start() { HandlerBase::start(); - - std::weak_ptr weakSelf{get_shared_this_ptr()}; - auto connectionSupplier = [weakSelf]() -> ClientConnectionPtr { - auto self = weakSelf.lock(); - if (!self) { - return nullptr; - } - return self->getCnx().lock(); - }; - - // NOTE: start() is always called in `ClientImpl`'s method, so lock() returns not null - const auto requestIdGenerator = client_.lock()->getRequestIdGenerator(); - const auto requestIdSupplier = [requestIdGenerator] { return (*requestIdGenerator)++; }; - - // Initialize ackGroupingTrackerPtr_ here because the get_shared_this_ptr() was not initialized until the - // constructor completed. - if (TopicName::get(topic())->isPersistent()) { - if (config_.getAckGroupingTimeMs() > 0) { - ackGroupingTrackerPtr_.reset(new AckGroupingTrackerEnabled( - connectionSupplier, requestIdSupplier, consumerId_, config_.isAckReceiptEnabled(), - config_.getAckGroupingTimeMs(), config_.getAckGroupingMaxSize(), - client_.lock()->getIOExecutorProvider()->get())); - } else { - ackGroupingTrackerPtr_.reset(new AckGroupingTrackerDisabled( - connectionSupplier, requestIdSupplier, consumerId_, config_.isAckReceiptEnabled())); - } - } else { - LOG_INFO(getName() << "ACK will NOT be sent to broker for this non-persistent topic."); - ackGroupingTrackerPtr_.reset(new AckGroupingTracker(connectionSupplier, requestIdSupplier, - consumerId_, config_.isAckReceiptEnabled())); - } - ackGroupingTrackerPtr_->start(); + ackGroupingTrackerPtr_->start(get_shared_this_ptr()); } void ConsumerImpl::beforeConnectionChange(ClientConnection& cnx) { cnx.removeConsumer(consumerId_); } @@ -265,7 +249,7 @@ Future ConsumerImpl::connectionOpened(const ClientConnectionPtr& c unAckedMessageTrackerPtr_->clear(); ClientImplPtr client = client_.lock(); - long requestId = client->newRequestId(); + auto requestId = newRequestId(); SharedBuffer cmd = Commands::newSubscribe( topic(), subscription_, consumerId_, requestId, getSubType(), getConsumerName(), subscriptionMode_, subscribeMessageId, readCompacted_, config_.getProperties(), config_.getSubscriptionProperties(), @@ -344,7 +328,7 @@ Result ConsumerImpl::handleCreateConsumer(const ClientConnectionPtr& cnx, Result // Creating the consumer has timed out. We need to ensure the broker closes the consumer // in case it was indeed created, otherwise it might prevent new subscribe operation, // since we are not closing the connection - int requestId = client_.lock()->newRequestId(); + auto requestId = newRequestId(); cnx->sendRequestWithId(Commands::newCloseConsumer(consumerId_, requestId), requestId); } @@ -396,7 +380,7 @@ void ConsumerImpl::unsubscribeAsync(const ResultCallback& originalCallback) { LOG_DEBUG(getName() << "Unsubscribe request sent for consumer - " << consumerId_); ClientImplPtr client = client_.lock(); lock.unlock(); - int requestId = client->newRequestId(); + auto requestId = newRequestId(); SharedBuffer cmd = Commands::newUnsubscribe(consumerId_, requestId); auto self = get_shared_this_ptr(); cnx->sendRequestWithId(cmd, requestId) @@ -591,17 +575,16 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: LOG_DEBUG(getName() << " metadata.has_num_messages_in_batch() = " << metadata.has_num_messages_in_batch()); - uint32_t numOfMessageReceived = m.impl_->metadata.num_messages_in_batch(); - auto ackGroupingTrackerPtr = ackGroupingTrackerPtr_; - if (ackGroupingTrackerPtr == nullptr) { // The consumer is closing + const auto state = state_.load(std::memory_order_relaxed); + if (state == Closing || state == Closed) { return; } - if (ackGroupingTrackerPtr->isDuplicate(m.getMessageId())) { + uint32_t numOfMessageReceived = m.impl_->metadata.num_messages_in_batch(); + if (ackGroupingTrackerPtr_->isDuplicate(m.getMessageId())) { LOG_DEBUG(getName() << " Ignoring message as it was ACKed earlier by same consumer."); increaseAvailablePermits(cnx, numOfMessageReceived); return; } - ackGroupingTrackerPtr.reset(); if (metadata.has_num_messages_in_batch()) { BitSet::Data words(msg.ack_set_size()); @@ -1340,12 +1323,7 @@ void ConsumerImpl::closeAsync(const ResultCallback& originalCallback) { incomingMessages_.close(); // Flush pending grouped ACK requests. - if (ackGroupingTrackerPtr_.use_count() != 1) { - LOG_ERROR("AckGroupingTracker is shared by other " - << (ackGroupingTrackerPtr_.use_count() - 1) - << " threads, which will prevent flushing the ACKs"); - } - ackGroupingTrackerPtr_.reset(); + ackGroupingTrackerPtr_->close(); negativeAcksTracker_->close(); ClientConnectionPtr cnx = getCnx().lock(); @@ -1364,7 +1342,7 @@ void ConsumerImpl::closeAsync(const ResultCallback& originalCallback) { cancelTimers(); - int requestId = client->newRequestId(); + auto requestId = newRequestId(); auto self = get_shared_this_ptr(); cnx->sendRequestWithId(Commands::newCloseConsumer(consumerId_, requestId), requestId) .addListener([self, callback](Result result, const ResponseData&) { callback(result); }); @@ -1375,7 +1353,7 @@ const std::string& ConsumerImpl::getName() const { return consumerStr_; } void ConsumerImpl::shutdown() { internalShutdown(); } void ConsumerImpl::internalShutdown() { - ackGroupingTrackerPtr_.reset(); + ackGroupingTrackerPtr_->close(); incomingMessages_.clear(); possibleSendToDeadLetterTopicMessages_.clear(); resetCnx(); @@ -1499,8 +1477,7 @@ void ConsumerImpl::getBrokerConsumerStatsAsync(const BrokerConsumerStatsCallback ClientConnectionPtr cnx = getCnx().lock(); if (cnx) { if (cnx->getServerProtocolVersion() >= proto::v8) { - ClientImplPtr client = client_.lock(); - uint64_t requestId = client->newRequestId(); + auto requestId = newRequestId(); LOG_DEBUG(getName() << " Sending ConsumerStats Command for Consumer - " << getConsumerId() << ", requestId - " << requestId); @@ -1542,12 +1519,7 @@ void ConsumerImpl::seekAsync(const MessageId& msgId, const ResultCallback& callb return; } - ClientImplPtr client = client_.lock(); - if (!client) { - LOG_ERROR(getName() << "Client is expired when seekAsync " << msgId); - return; - } - const auto requestId = client->newRequestId(); + const auto requestId = newRequestId(); seekAsyncInternal(requestId, Commands::newSeek(consumerId_, requestId, msgId), SeekArg{msgId}, callback); } @@ -1561,12 +1533,7 @@ void ConsumerImpl::seekAsync(uint64_t timestamp, const ResultCallback& callback) return; } - ClientImplPtr client = client_.lock(); - if (!client) { - LOG_ERROR(getName() << "Client is expired when seekAsync " << timestamp); - return; - } - const auto requestId = client->newRequestId(); + const auto requestId = newRequestId(); seekAsyncInternal(requestId, Commands::newSeek(consumerId_, requestId, timestamp), SeekArg{timestamp}, callback); } @@ -1658,8 +1625,7 @@ void ConsumerImpl::internalGetLastMessageIdAsync(const BackoffPtr& backoff, Time ClientConnectionPtr cnx = getCnx().lock(); if (cnx) { if (cnx->getServerProtocolVersion() >= proto::v12) { - ClientImplPtr client = client_.lock(); - uint64_t requestId = client->newRequestId(); + auto requestId = newRequestId(); LOG_DEBUG(getName() << " Sending getLastMessageId Command for Consumer - " << getConsumerId() << ", requestId - " << requestId); @@ -1926,4 +1892,100 @@ void ConsumerImpl::processPossibleToDLQ(const MessageId& messageId, const Proces } } +void ConsumerImpl::doImmediateAck(const ClientConnectionPtr& cnx, const MessageId& msgId, + CommandAck_AckType ackType, const ResultCallback& callback) { + const auto& ackSet = Commands::getMessageIdImpl(msgId)->getBitSet(); + if (config_.isAckReceiptEnabled()) { + auto requestId = newRequestId(); + cnx->sendRequestWithId( + Commands::newAck(consumerId_, msgId.ledgerId(), msgId.entryId(), ackSet, ackType, requestId), + requestId) + .addListener([callback](Result result, const ResponseData&) { + if (callback) { + callback(result); + } + }); + } else { + cnx->sendCommand(Commands::newAck(consumerId_, msgId.ledgerId(), msgId.entryId(), ackSet, ackType)); + if (callback) { + callback(ResultOk); + } + } +} + +void ConsumerImpl::doImmediateAck(const ClientConnectionPtr& cnx, const std::set& msgIds, + const ResultCallback& callback) { + std::set ackMsgIds; + + for (const auto& msgId : msgIds) { + if (auto chunkMessageId = + std::dynamic_pointer_cast(Commands::getMessageIdImpl(msgId))) { + auto msgIdList = chunkMessageId->getChunkedMessageIds(); + ackMsgIds.insert(msgIdList.begin(), msgIdList.end()); + } else { + ackMsgIds.insert(msgId); + } + } + if (Commands::peerSupportsMultiMessageAcknowledgement(cnx->getServerProtocolVersion())) { + if (config_.isAckReceiptEnabled()) { + auto requestId = newRequestId(); + cnx->sendRequestWithId(Commands::newMultiMessageAck(consumerId_, ackMsgIds, requestId), requestId) + .addListener([callback](Result result, const ResponseData&) { + if (callback) { + callback(result); + } + }); + } else { + cnx->sendCommand(Commands::newMultiMessageAck(consumerId_, ackMsgIds)); + if (callback) { + callback(ResultOk); + } + } + } else { + auto count = std::make_shared>(ackMsgIds.size()); + auto wrappedCallback = [callback, count](Result result) { + if (--*count == 0 && callback) { + callback(result); + } + }; + for (auto&& msgId : ackMsgIds) { + doImmediateAck(msgId, wrappedCallback, CommandAck_AckType_Individual); + } + } +} + +void ConsumerImpl::doImmediateAck(const MessageId& msgId, const ResultCallback& callback, + CommandAck_AckType ackType) { + const auto cnx = getCnx().lock(); + if (!cnx) { + if (callback) { + callback(ResultAlreadyClosed); + } + return; + } + if (ackType == CommandAck_AckType_Individual) { + // If it's individual ack, we need to acknowledge all message IDs in a chunked message Id + // If it's cumulative ack, we only need to ack the last message ID of a chunked message. + // ChunkedMessageId return last chunk message ID by default, so we don't need to handle it. + if (auto chunkMessageId = + std::dynamic_pointer_cast(Commands::getMessageIdImpl(msgId))) { + auto msgIdList = chunkMessageId->getChunkedMessageIds(); + doImmediateAck(cnx, std::set(msgIdList.begin(), msgIdList.end()), callback); + return; + } + } + doImmediateAck(cnx, msgId, ackType, callback); +} + +void ConsumerImpl::doImmediateAck(const std::set& msgIds, const ResultCallback& callback) { + const auto cnx = getCnx().lock(); + if (!cnx) { + if (callback) { + callback(ResultAlreadyClosed); + } + return; + } + doImmediateAck(cnx, msgIds, callback); +} + } /* namespace pulsar */ diff --git a/lib/ConsumerImpl.h b/lib/ConsumerImpl.h index 055b487e..5e06723b 100644 --- a/lib/ConsumerImpl.h +++ b/lib/ConsumerImpl.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include "BrokerConsumerStatsImpl.h" @@ -96,7 +97,7 @@ class ConsumerImpl : public ConsumerImplBase { void setPartitionIndex(int partitionIndex); int getPartitionIndex(); void sendFlowPermitsToBroker(const ClientConnectionPtr& cnx, int numMessages); - uint64_t getConsumerId(); + uint64_t getConsumerId() const noexcept { return consumerId_; } void messageReceived(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, bool& isChecksumValid, proto::BrokerEntryMetadata& brokerEntryMetadata, proto::MessageMetadata& msgMetadata, SharedBuffer& payload); @@ -124,6 +125,10 @@ class ConsumerImpl : public ConsumerImplBase { void shutdown() override; void internalShutdown(); bool isClosed() override; + bool isClosingOrClosed() const noexcept { + const auto state = state_.load(std::memory_order_relaxed); + return state == Closing || state == Closed; + } bool isOpen() override; Result pauseMessageListener() override; Result resumeMessageListener() override; @@ -152,6 +157,9 @@ class ConsumerImpl : public ConsumerImplBase { void beforeConnectionChange(ClientConnection& cnx) override; void onNegativeAcksSend(const std::set& messageIds); + void doImmediateAck(const MessageId& msgId, const ResultCallback& callback, CommandAck_AckType ackType); + void doImmediateAck(const std::set& msgIds, const ResultCallback& callback); + protected: // overrided methods from HandlerBase Future connectionOpened(const ClientConnectionPtr& cnx) override; @@ -237,7 +245,7 @@ class ConsumerImpl : public ConsumerImplBase { std::queue pendingReceives_; std::atomic_int availablePermits_; const int receiverQueueRefillThreshold_; - uint64_t consumerId_; + const uint64_t consumerId_; const std::string consumerStr_; int32_t partitionIndex_ = -1; Promise consumerCreatedPromise_; @@ -246,7 +254,7 @@ class ConsumerImpl : public ConsumerImplBase { UnAckedMessageTrackerPtr unAckedMessageTrackerPtr_; BrokerConsumerStatsImpl brokerConsumerStats_; std::shared_ptr negativeAcksTracker_; - AckGroupingTrackerPtr ackGroupingTrackerPtr_; + const AckGroupingTrackerPtr ackGroupingTrackerPtr_; MessageCryptoPtr msgCrypto_; const bool readCompacted_; @@ -340,6 +348,9 @@ class ConsumerImpl : public ConsumerImplBase { std::atomic_bool expireChunkMessageTaskScheduled_{false}; ConsumerInterceptorsPtr interceptors_; + const std::shared_ptr> requestIdGenerator_; + + uint64_t newRequestId() const { return (*requestIdGenerator_)++; } void triggerCheckExpiredChunkedTimer(); void discardChunkMessages(const std::string& uuid, const MessageId& messageId, bool autoAck); @@ -379,6 +390,11 @@ class ConsumerImpl : public ConsumerImplBase { } } + void doImmediateAck(const ClientConnectionPtr& cnx, const MessageId& msgId, CommandAck_AckType ackType, + const ResultCallback& callback); + void doImmediateAck(const ClientConnectionPtr& cnx, const std::set& msgIds, + const ResultCallback& callback); + friend class PulsarFriend; friend class MultiTopicsConsumerImpl; diff --git a/tests/AcknowledgeTest.cc b/tests/AcknowledgeTest.cc index 0e2183e2..464d5d2e 100644 --- a/tests/AcknowledgeTest.cc +++ b/tests/AcknowledgeTest.cc @@ -375,6 +375,34 @@ TEST_F(AcknowledgeTest, testAckReceiptEnabled) { client.close(); } +TEST_F(AcknowledgeTest, testCloseConsumer) { + Client client(lookupUrl); + const auto topic = "test-close-consumer" + unique_str(); + Producer producer; + ASSERT_EQ(ResultOk, client.createProducer(topic, producer)); + ConsumerConfiguration consumerConfig; + consumerConfig.setAckGroupingTimeMs(60000); + consumerConfig.setAckGroupingMaxSize(100); + Consumer consumer; + ASSERT_EQ(ResultOk, client.subscribe(topic, "sub", consumerConfig, consumer)); + + producer.send(MessageBuilder().setContent("msg-0").build()); + Message msg; + ASSERT_EQ(ResultOk, consumer.receive(msg, 3000)); + consumer.acknowledgeAsync( + msg, nullptr); // it just adds the msg id to the pending ack list due to the ack grouping configs + consumer.close(); // it will flush the pending ACK and prevent any further ack + ASSERT_EQ(ResultAlreadyClosed, consumer.acknowledge(msg)); + ASSERT_EQ(ResultAlreadyClosed, consumer.acknowledgeCumulative(msg)); + ASSERT_EQ(ResultAlreadyClosed, consumer.acknowledge(std::vector{msg.getMessageId()})); + + producer.send(MessageBuilder().setContent("msg-1").build()); + // Recreate the consumer to verify the first message is acknowledged + ASSERT_EQ(ResultOk, client.subscribe(topic, "sub", consumerConfig, consumer)); + ASSERT_EQ(ResultOk, consumer.receive(msg, 3000)); + ASSERT_EQ("msg-1", msg.getDataAsString()); +} + INSTANTIATE_TEST_SUITE_P(BasicEndToEndTest, AcknowledgeTest, testing::Combine(testing::Values(100, 0), testing::Values(true, false)), [](const testing::TestParamInfo>& info) { diff --git a/tests/BasicEndToEndTest.cc b/tests/BasicEndToEndTest.cc index 43306099..5cf478b2 100644 --- a/tests/BasicEndToEndTest.cc +++ b/tests/BasicEndToEndTest.cc @@ -38,7 +38,6 @@ #include "lib/AckGroupingTrackerEnabled.h" #include "lib/ClientConnection.h" #include "lib/ClientImpl.h" -#include "lib/Commands.h" #include "lib/ConsumerImpl.h" #include "lib/Future.h" #include "lib/Latch.h" @@ -3633,7 +3632,7 @@ TEST(BasicEndToEndTest, testAckGroupingTrackerDefaultBehavior) { ASSERT_EQ(configConsumer.getAckGroupingTimeMs(), 100); ASSERT_EQ(configConsumer.getAckGroupingMaxSize(), 1000); - AckGroupingTracker tracker{nullptr, nullptr, 0, false}; + AckGroupingTracker tracker; Message msg; ASSERT_FALSE(tracker.isDuplicate(msg.getMessageId())); } @@ -3672,10 +3671,8 @@ TEST(BasicEndToEndTest, testAckGroupingTrackerSingleAckBehavior) { // Send ACK. auto clientImplPtr = PulsarFriend::getClientImplPtr(client); - AckGroupingTrackerDisabled tracker([&consumerImpl]() { return consumerImpl.getCnx().lock(); }, - [&clientImplPtr] { return clientImplPtr->newRequestId(); }, - consumerImpl.getConsumerId(), false); - tracker.start(); + AckGroupingTrackerDisabled tracker; + tracker.start(PulsarFriend::getConsumerImplPtr(consumer)); for (auto msgIdx = 0; msgIdx < numMsg; ++msgIdx) { auto connPtr = connWeakPtr.lock(); ASSERT_NE(connPtr, nullptr); @@ -3707,8 +3704,6 @@ TEST(BasicEndToEndTest, testAckGroupingTrackerMultiAckBehavior) { Consumer consumer; ASSERT_EQ(ResultOk, client.subscribe(topicName, subName, consumer)); - auto &consumerImpl = PulsarFriend::getConsumerImpl(consumer); - // Sending and receiving messages. for (auto count = 0; count < numMsg; ++count) { Message msg = MessageBuilder().setContent(std::string("MSG-") + std::to_string(count)).build(); @@ -3724,10 +3719,8 @@ TEST(BasicEndToEndTest, testAckGroupingTrackerMultiAckBehavior) { // Send ACK. auto clientImplPtr = PulsarFriend::getClientImplPtr(client); - AckGroupingTrackerDisabled tracker([&consumerImpl]() { return consumerImpl.getCnx().lock(); }, - [&clientImplPtr] { return clientImplPtr->newRequestId(); }, - consumerImpl.getConsumerId(), false); - tracker.start(); + AckGroupingTrackerDisabled tracker; + tracker.start(PulsarFriend::getConsumerImplPtr(consumer)); tracker.addAcknowledgeList(recvMsgId, nullptr); consumer.close(); @@ -3755,7 +3748,6 @@ TEST(BasicEndToEndTest, testAckGroupingTrackerDisabledIndividualAck) { Consumer consumer; ASSERT_EQ(ResultOk, client.subscribe(topicName, subName, consumer)); - auto &consumerImpl = PulsarFriend::getConsumerImpl(consumer); // Sending and receiving messages. for (auto count = 0; count < numMsg; ++count) { @@ -3771,8 +3763,8 @@ TEST(BasicEndToEndTest, testAckGroupingTrackerDisabledIndividualAck) { } // Send ACK. - AckGroupingTrackerDisabled tracker([&consumerImpl] { return consumerImpl.getCnx().lock(); }, nullptr, - consumerImpl.getConsumerId(), false); + AckGroupingTrackerDisabled tracker; + tracker.start(PulsarFriend::getConsumerImplPtr(consumer)); for (auto &msgId : recvMsgId) { tracker.addAcknowledge(msgId, nullptr); } @@ -3802,7 +3794,6 @@ TEST(BasicEndToEndTest, testAckGroupingTrackerDisabledCumulativeAck) { Consumer consumer; ASSERT_EQ(ResultOk, client.subscribe(topicName, subName, consumer)); - auto &consumerImpl = PulsarFriend::getConsumerImpl(consumer); // Sending and receiving messages. for (auto count = 0; count < numMsg; ++count) { @@ -3818,8 +3809,8 @@ TEST(BasicEndToEndTest, testAckGroupingTrackerDisabledCumulativeAck) { } // Send ACK. - AckGroupingTrackerDisabled tracker([&consumerImpl] { return consumerImpl.getCnx().lock(); }, nullptr, - consumerImpl.getConsumerId(), false); + AckGroupingTrackerDisabled tracker; + tracker.start(PulsarFriend::getConsumerImplPtr(consumer)); auto &latestMsgId = *std::max_element(recvMsgId.begin(), recvMsgId.end()); tracker.addAcknowledgeCumulative(latestMsgId, nullptr); consumer.close(); @@ -3861,7 +3852,6 @@ TEST(BasicEndToEndTest, testAckGroupingTrackerEnabledIndividualAck) { Consumer consumer; ASSERT_EQ(ResultOk, client.subscribe(topicName, subName, consumer)); - auto consumerImpl = PulsarFriend::getConsumerImplPtr(consumer); // Sending and receiving messages. for (auto count = 0; count < numMsg; ++count) { @@ -3877,9 +3867,8 @@ TEST(BasicEndToEndTest, testAckGroupingTrackerEnabledIndividualAck) { } auto tracker = std::make_shared( - [&consumerImpl] { return consumerImpl->getCnx().lock(); }, nullptr, consumerImpl->getConsumerId(), - false, ackGroupingTimeMs, ackGroupingMaxSize, clientImplPtr->getIOExecutorProvider()->get()); - tracker->start(); + ackGroupingTimeMs, ackGroupingMaxSize, false, clientImplPtr->getIOExecutorProvider()->get()); + tracker->start(PulsarFriend::getConsumerImplPtr(consumer)); ASSERT_EQ(tracker->getPendingIndividualAcks().size(), 0); ASSERT_EQ(tracker->getAckGroupingTimeMs(), ackGroupingTimeMs); ASSERT_EQ(tracker->getAckGroupingMaxSize(), ackGroupingMaxSize); @@ -3939,9 +3928,8 @@ TEST(BasicEndToEndTest, testAckGroupingTrackerEnabledCumulativeAck) { std::sort(recvMsgId.begin(), recvMsgId.end()); auto tracker0 = std::make_shared( - [&consumerImpl0] { return consumerImpl0->getCnx().lock(); }, nullptr, consumerImpl0->getConsumerId(), - false, ackGroupingTimeMs, ackGroupingMaxSize, clientImplPtr->getIOExecutorProvider()->get()); - tracker0->start(); + ackGroupingTimeMs, ackGroupingMaxSize, false, clientImplPtr->getIOExecutorProvider()->get()); + tracker0->start(PulsarFriend::getConsumerImplPtr(consumer)); ASSERT_EQ(tracker0->getNextCumulativeAckMsgId(), MessageId::earliest()); ASSERT_FALSE(tracker0->requireCumulativeAck()); @@ -3976,11 +3964,10 @@ TEST(BasicEndToEndTest, testAckGroupingTrackerEnabledCumulativeAck) { auto ret = consumer.receive(msg, 1000); ASSERT_EQ(ResultTimeout, ret) << "Received redundant message: " << msg.getDataAsString(); auto tracker1 = std::make_shared( - [&consumerImpl1] { return consumerImpl1->getCnx().lock(); }, nullptr, consumerImpl1->getConsumerId(), - false, ackGroupingTimeMs, ackGroupingMaxSize, clientImplPtr->getIOExecutorProvider()->get()); - tracker1->start(); + ackGroupingTimeMs, ackGroupingMaxSize, false, clientImplPtr->getIOExecutorProvider()->get()); + tracker1->start(PulsarFriend::getConsumerImplPtr(consumer)); tracker1->addAcknowledgeCumulative(recvMsgId[numMsg - 1], nullptr); - tracker1.reset(); + tracker1->close(); consumer.close(); ASSERT_EQ(ResultOk, client.subscribe(topicName, subName, consumer)); diff --git a/tests/ConsumerTest.cc b/tests/ConsumerTest.cc index dfbc2765..3aa1dd3c 100644 --- a/tests/ConsumerTest.cc +++ b/tests/ConsumerTest.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -1542,12 +1543,19 @@ TEST(ConsumerTest, testConsumerListenerShouldNotSegfaultAfterClose) { consumerConfig.setSubscriptionInitialPosition(InitialPositionEarliest); Latch latchFirstReceiveMsg(1); Latch latchAfterClosed(1); - consumerConfig.setMessageListener( - [&latchFirstReceiveMsg, &latchAfterClosed](Consumer consumer, const Message& msg) { - latchFirstReceiveMsg.countdown(); - LOG_INFO("Consume message: " << msg.getDataAsString()); - latchAfterClosed.wait(); - }); + + std::promise> ackResultsPromise; + consumerConfig.setMessageListener([&latchFirstReceiveMsg, &latchAfterClosed, &ackResultsPromise]( + Consumer consumer, const Message& msg) { + latchFirstReceiveMsg.countdown(); + LOG_INFO("Consume message: " << msg.getDataAsString()); + latchAfterClosed.wait(); + std::vector results(3); + results[0] = consumer.acknowledge(msg); + results[1] = consumer.acknowledgeCumulative(msg); + results[2] = consumer.acknowledge(std::vector{msg.getMessageId()}); + ackResultsPromise.set_value(results); + }); auto result = client.subscribe(topicName, "test-sub", consumerConfig, consumer); ASSERT_EQ(ResultOk, result); @@ -1555,6 +1563,11 @@ TEST(ConsumerTest, testConsumerListenerShouldNotSegfaultAfterClose) { latchFirstReceiveMsg.wait(); ASSERT_EQ(ResultOk, consumer.close()); latchAfterClosed.countdown(); + const auto ackResults = ackResultsPromise.get_future().get(); + ASSERT_EQ(3, ackResults.size()); + for (size_t i = 0; i < ackResults.size(); i++) { + ASSERT_EQ(ResultAlreadyClosed, ackResults[i]) << "ack result[" << i << "] " << ackResults[i]; + } ASSERT_EQ(ResultOk, producer.close()); ASSERT_EQ(ResultOk, client.close());