diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java new file mode 100644 index 000000000000..292c8713ea73 --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.util.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.PeekingIterator; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.WakeupException; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class KafkaConsumerPollThread { + + KafkaConsumerPollThread() { + recordsDequeuePollTimeout = Duration.ofMillis(10); + consumer = null; + pollFuture = null; + } + + private @Nullable Consumer consumer; + + /** + * The poll timeout while reading records from Kafka. If option to commit reader offsets in to + * Kafka in {@link KafkaCheckpointMark#finalizeCheckpoint()} is enabled, it would be delayed until + * this poll returns. It should be reasonably low as a result. At the same time it probably can't + * be very low like 10 millis, I am not sure how it affects when the latency is high. Probably + * good to experiment. Often multiple marks would be finalized in a batch, it reduce finalization + * overhead to wait a short while and finalize only the last checkpoint mark. + */ + private static final Duration KAFKA_POLL_TIMEOUT = Duration.ofSeconds(1); + + private Duration recordsDequeuePollTimeout; + private static final Duration RECORDS_DEQUEUE_POLL_TIMEOUT_MIN = Duration.ofMillis(1); + private static final Duration RECORDS_DEQUEUE_POLL_TIMEOUT_MAX = Duration.ofMillis(20); + private static final Duration RECORDS_ENQUEUE_POLL_TIMEOUT = Duration.ofMillis(100); + + private static final long UNINITIALIZED_OFFSET = -1; + + private final AtomicReference consumerPollException = new AtomicReference<>(); + private final SynchronousQueue> availableRecordsQueue = + new SynchronousQueue<>(); + private final AtomicReference<@Nullable KafkaCheckpointMark> finalizedCheckpointMark = + new AtomicReference<>(); + private final AtomicBoolean closed = new AtomicBoolean(false); + private final Map> topicsList = new ConcurrentHashMap<>(); + private final AtomicBoolean topicListUpdated = new AtomicBoolean(false); + private static final Logger LOG = LoggerFactory.getLogger(KafkaConsumerPollThread.class); + + private @Nullable Future pollFuture; + private @Nullable PeekingIterator> activeBatchIterator; + + void startOnExecutor(ExecutorService executorService, Consumer consumer) { + this.consumer = consumer; + // Use a separate thread to read Kafka messages. Kafka Consumer does all its work including + // network I/O inside poll(). Polling only inside #advance(), especially with a small timeout + // like 100 milliseconds does not work well. This along with large receive buffer for + // consumer achieved the best throughput in tests (see `defaultConsumerProperties`). + pollFuture = executorService.submit(this::consumerPollLoop); + } + + void close() throws IOException { + if (consumer == null) { + LOG.debug("Closing consumer poll thread that was never started."); + return; + } + Preconditions.checkStateNotNull(pollFuture); + closed.set(true); + try { + // Wait for threads to shut down. Trying this as a loop to handle a tiny race where poll + // thread + // might block to enqueue right after availableRecordsQueue.poll() below. + while (true) { + if (consumer != null) { + consumer.wakeup(); + } + // todo will this drop unprocessed records? + availableRecordsQueue.poll(); // drain unread batch, this unblocks consumer thread. + try { + Preconditions.checkStateNotNull(pollFuture); + pollFuture.get(10, TimeUnit.SECONDS); + break; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); // not expected + } catch (ExecutionException e) { + throw new IOException(e.getCause()); + } catch (TimeoutException ignored) { + } + LOG.warn("An internal thread is taking a long time to shutdown. will retry."); + } + } finally { + Closeables.close(consumer, true); + } + } + + @Nullable + ConsumerRecord peek() throws IOException { + PeekingIterator> currentIterator = getOrInitialize(); + if (currentIterator.hasNext()) { + return currentIterator.peek(); + } + return null; + } + + void advance() throws IOException { + PeekingIterator> currentIterator = getOrInitialize(); + if (currentIterator.hasNext()) { + currentIterator.next(); + } + } + + private PeekingIterator> getOrInitialize() throws IOException { + if (activeBatchIterator == null || !activeBatchIterator.hasNext()) { + activeBatchIterator = Iterators.peekingIterator(readRecords().iterator()); + } + return activeBatchIterator; + } + + private void consumerPollLoop() { + // Read in a loop and enqueue the batch of records, if any, to availableRecordsQueue. + Consumer consumer = Preconditions.checkStateNotNull(this.consumer); + + try { + ConsumerRecords records = ConsumerRecords.empty(); + while (!closed.get()) { + try { + if (records.isEmpty()) { + records = consumer.poll(KAFKA_POLL_TIMEOUT); + updateTopicList(consumer); + } else if (availableRecordsQueue.offer( + records, RECORDS_ENQUEUE_POLL_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS)) { + records = ConsumerRecords.empty(); + } + + commitCheckpointMark(); + } catch (InterruptedException e) { + LOG.warn("{}: consumer thread is interrupted", this, e); // not expected + break; + } catch (WakeupException e) { + break; + } + } + } catch (Exception e) { // mostly an unrecoverable KafkaException. + LOG.error("{}: Exception while reading from Kafka", this, e); + consumerPollException.set(e); + throw e; + } + LOG.info("{}: Returning from consumer pool loop", this); + // Commit any pending finalized checkpoint before shutdown. + commitCheckpointMark(); + } + + private void updateTopicList(Consumer consumer) { + synchronized (topicsList) { + Map> currentTopicsList = consumer.listTopics(); + topicsList.clear(); + topicsList.putAll(currentTopicsList); + topicListUpdated.set(true); + } + } + + public boolean isTopicListUpdated() { + return topicListUpdated.get(); + } + + public Map> getTopicsList() { + synchronized (topicsList) { + return topicsList; + } + } + + ConsumerRecords readRecords() throws IOException { + @Nullable ConsumerRecords records = null; + try { + // poll available records, wait (if necessary) up to the specified timeout. + records = + availableRecordsQueue.poll(recordsDequeuePollTimeout.toMillis(), TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("{}: Unexpected", this, e); + } + + if (records == null) { + // Check if the poll thread failed with an exception. + if (consumerPollException.get() != null) { + throw new IOException("Exception while reading from Kafka", consumerPollException.get()); + } + if (recordsDequeuePollTimeout.compareTo(RECORDS_DEQUEUE_POLL_TIMEOUT_MIN) > 0) { + recordsDequeuePollTimeout = recordsDequeuePollTimeout.minus(Duration.ofMillis(1)); + LOG.debug("Reducing poll timeout for reader to " + recordsDequeuePollTimeout.toMillis()); + } + return ConsumerRecords.empty(); + } + if (recordsDequeuePollTimeout.compareTo(RECORDS_DEQUEUE_POLL_TIMEOUT_MAX) < 0) { + recordsDequeuePollTimeout = recordsDequeuePollTimeout.plus(Duration.ofMillis(1)); + LOG.debug("Increasing poll timeout for reader to " + recordsDequeuePollTimeout.toMillis()); + LOG.debug("Record count: " + records.count()); + } + return records; + } + + /** + * Enqueue checkpoint mark to be committed to Kafka. This does not block until it is committed. + * There could be a delay of up to KAFKA_POLL_TIMEOUT (1 second). Any checkpoint mark enqueued + * earlier is dropped in favor of this checkpoint mark. Documentation for {@link + * UnboundedSource.CheckpointMark#finalizeCheckpoint()} says these are finalized in order. Only + * the latest offsets need to be committed. + * + *

Returns if a existing checkpoint mark was skipped. + */ + boolean finalizeCheckpointMarkAsync(KafkaCheckpointMark checkpointMark) { + return finalizedCheckpointMark.getAndSet(checkpointMark) != null; + } + + private void commitCheckpointMark() { + KafkaCheckpointMark checkpointMark = finalizedCheckpointMark.getAndSet(null); + + if (checkpointMark != null) { + LOG.debug("{}: Committing finalized checkpoint {}", this, checkpointMark); + Consumer consumer = Preconditions.checkStateNotNull(this.consumer); + + consumer.commitSync( + checkpointMark.getPartitions().stream() + .filter(p -> p.getNextOffset() != UNINITIALIZED_OFFSET) + .collect( + Collectors.toMap( + p -> new TopicPartition(p.getTopic(), p.getPartition()), + p -> new OffsetAndMetadata(p.getNextOffset())))); + } + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThreadCache.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThreadCache.java new file mode 100644 index 000000000000..8828cc9e5f44 --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThreadCache.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + +import java.io.IOException; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalCause; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalNotification; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class KafkaConsumerPollThreadCache { + + private static final Logger LOG = LoggerFactory.getLogger(KafkaConsumerPollThreadCache.class); + private final ExecutorService invalidationExecutor = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("KafkaConsumerPollCache-invalidation-%d") + .build()); + private final ExecutorService backgroundThreads = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("KafkaConsumerPollCache-poll-%d") + .build()); + + // Note on thread safety. This class is thread safe because: + // - Guava Cache is thread safe. + // - There is no state other than Cache. + // - API is strictly a 1:1 wrapper over Cache API (not counting cache.cleanUp() calls). + // - i.e. it does not invoke more than one call, which could make it inconsistent. + // If any of these conditions changes, please test ensure and test thread safety. + + private static class CacheKey { + final Map consumerConfig; + final SerializableFunction, Consumer> consumerFactoryFn; + final KafkaSourceDescriptor descriptor; + + CacheKey( + Map consumerConfig, + SerializableFunction, Consumer> consumerFactoryFn, + KafkaSourceDescriptor descriptor) { + this.consumerConfig = consumerConfig; + this.consumerFactoryFn = consumerFactoryFn; + this.descriptor = descriptor; + } + + @Override + public boolean equals(@Nullable Object other) { + if (other == null) { + return false; + } + if (!(other instanceof CacheKey)) { + return false; + } + CacheKey otherKey = (CacheKey) other; + return descriptor.equals(otherKey.descriptor) + && consumerFactoryFn.equals(otherKey.consumerFactoryFn) + && consumerConfig.equals(otherKey.consumerConfig); + } + + @Override + public int hashCode() { + return Objects.hash(descriptor, consumerFactoryFn, consumerConfig); + } + } + + private static class CacheEntry { + + final KafkaConsumerPollThread pollThread; + final long offset; + + CacheEntry(KafkaConsumerPollThread pollThread, long offset) { + this.pollThread = pollThread; + this.offset = offset; + } + } + + private final Duration cacheDuration = Duration.ofMinutes(1); + private final Cache cache; + + @SuppressWarnings("method.invocation") + KafkaConsumerPollThreadCache() { + this.cache = + CacheBuilder.newBuilder() + .expireAfterWrite(cacheDuration.toMillis(), TimeUnit.MILLISECONDS) + .removalListener( + (RemovalNotification notification) -> { + if (notification.getCause() != RemovalCause.EXPLICIT) { + LOG.info( + "Asynchronously closing reader for {} as it has been idle for over {}", + notification.getKey(), + cacheDuration); + asyncCloseConsumer( + checkNotNull(notification.getKey()), checkNotNull(notification.getValue())); + } + }) + .build(); + } + + KafkaConsumerPollThread acquireConsumer( + Map consumerConfig, + SerializableFunction, Consumer> consumerFactoryFn, + KafkaSourceDescriptor kafkaSourceDescriptor, + long offset) { + CacheKey key = new CacheKey(consumerConfig, consumerFactoryFn, kafkaSourceDescriptor); + CacheEntry entry = cache.asMap().remove(key); + cache.cleanUp(); + if (entry != null) { + if (entry.offset == offset) { + return entry.pollThread; + } else { + // Offset doesn't match, close. + LOG.info("Closing consumer as it is no longer valid {}", kafkaSourceDescriptor); + asyncCloseConsumer(key, entry); + } + } + + Map updatedConsumerConfig = + overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); + LOG.info( + "Creating Kafka consumer for process continuation for {}", + kafkaSourceDescriptor.getTopicPartition()); + Consumer consumer = consumerFactoryFn.apply(updatedConsumerConfig); + ConsumerSpEL.evaluateAssign( + consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition())); + consumer.seek(kafkaSourceDescriptor.getTopicPartition(), offset); + KafkaConsumerPollThread pollThread = new KafkaConsumerPollThread(); + pollThread.startOnExecutor(backgroundThreads, consumer); + return pollThread; + } + + /** Close the reader and log a warning if close fails. */ + private void asyncCloseConsumer(CacheKey key, CacheEntry entry) { + invalidationExecutor.execute( + () -> { + try { + entry.pollThread.close(); + LOG.info("Finished closing consumer for {}", key); + } catch (IOException e) { + LOG.warn("Failed to close consumer for {}", key, e); + } + }); + } + + void releaseConsumer( + Map consumerConfig, + SerializableFunction, Consumer> consumerFactoryFn, + KafkaSourceDescriptor kafkaSourceDescriptor, + KafkaConsumerPollThread pollThread, + long offset) { + CacheKey key = new CacheKey(consumerConfig, consumerFactoryFn, kafkaSourceDescriptor); + CacheEntry existing = cache.asMap().putIfAbsent(key, new CacheEntry(pollThread, offset)); + if (existing != null) { + LOG.warn("Unexpected collision of topic and partition"); + asyncCloseConsumer(key, existing); + } + cache.cleanUp(); + } + + private Map overrideBootstrapServersConfig( + Map currentConfig, KafkaSourceDescriptor description) { + checkState( + currentConfig.containsKey(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) + || description.getBootStrapServers() != null); + Map config = new HashMap<>(currentConfig); + if (description.getBootStrapServers() != null && !description.getBootStrapServers().isEmpty()) { + config.put( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, + String.join(",", description.getBootStrapServers())); + } + return config; + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java index fed03047cf16..5f1d621605f6 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java @@ -33,11 +33,9 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.SynchronousQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; @@ -61,9 +59,7 @@ import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.ConsumerRecords; -import org.apache.kafka.clients.consumer.OffsetAndMetadata; import org.apache.kafka.common.TopicPartition; -import org.apache.kafka.common.errors.WakeupException; import org.apache.kafka.common.serialization.Deserializer; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; @@ -102,7 +98,7 @@ public boolean start() throws IOException { // This problem of blocking API calls to kafka is solved in higher versions of kafka // client by `KIP-266` for (final PartitionState pState : partitionStates) { - Future future = consumerPollThread.submit(() -> setupInitialOffset(pState)); + Future future = backgroundThread.submit(() -> setupInitialOffset(pState)); try { Duration timeout = resolveDefaultApiTimeout(spec); future.get(timeout.getMillis(), TimeUnit.MILLISECONDS); @@ -129,7 +125,7 @@ public boolean start() throws IOException { // Start consumer read loop. // Note that consumer is not thread safe, should not be accessed out side consumerPollLoop(). - consumerPollThread.submit(this::consumerPollLoop); + pollThread.startOnExecutor(backgroundThread, consumer); // offsetConsumer setup : Map offsetConsumerConfig = @@ -345,44 +341,26 @@ public long getSplitBacklogBytes() { private final Counter bytesReadBySplit; private final Gauge backlogBytesOfSplit; private final Gauge backlogElementsOfSplit; - private HashMap perPartitionBacklogMetrics = new HashMap();; + private final HashMap perPartitionBacklogMetrics = new HashMap<>(); private final Counter checkpointMarkCommitsEnqueued = Metrics.counter(METRIC_NAMESPACE, CHECKPOINT_MARK_COMMITS_ENQUEUED_METRIC); // Checkpoint marks skipped in favor of newer mark (only the latest needs to be committed). private final Counter checkpointMarkCommitsSkipped = Metrics.counter(METRIC_NAMESPACE, CHECKPOINT_MARK_COMMITS_SKIPPED_METRIC); - /** - * The poll timeout while reading records from Kafka. If option to commit reader offsets in to - * Kafka in {@link KafkaCheckpointMark#finalizeCheckpoint()} is enabled, it would be delayed until - * this poll returns. It should be reasonably low as a result. At the same time it probably can't - * be very low like 10 millis, I am not sure how it affects when the latency is high. Probably - * good to experiment. Often multiple marks would be finalized in a batch, it reduce finalization - * overhead to wait a short while and finalize only the last checkpoint mark. - */ - private static final Duration KAFKA_POLL_TIMEOUT = Duration.millis(1000); - - private Duration recordsDequeuePollTimeout; - private static final Duration RECORDS_DEQUEUE_POLL_TIMEOUT_MIN = Duration.millis(1); - private static final Duration RECORDS_DEQUEUE_POLL_TIMEOUT_MAX = Duration.millis(20); - private static final Duration RECORDS_ENQUEUE_POLL_TIMEOUT = Duration.millis(100); + private final transient KafkaConsumerPollThread pollThread; // Use a separate thread to read Kafka messages. Kafka Consumer does all its work including // network I/O inside poll(). Polling only inside #advance(), especially with a small timeout // like 100 milliseconds does not work well. This along with large receive buffer for // consumer achieved best throughput in tests (see `defaultConsumerProperties`). - private final ExecutorService consumerPollThread = + private final ExecutorService backgroundThread = Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setDaemon(true) .setNameFormat("KafkaConsumerPoll-thread") .build()); - private AtomicReference consumerPollException = new AtomicReference<>(); - private final SynchronousQueue> availableRecordsQueue = - new SynchronousQueue<>(); - private AtomicReference<@Nullable KafkaCheckpointMark> finalizedCheckpointMark = - new AtomicReference<>(); - private AtomicBoolean closed = new AtomicBoolean(false); + private final AtomicBoolean closed = new AtomicBoolean(false); // Backlog support : // Kafka consumer does not have an API to fetch latest offset for topic. We need to seekToEnd() @@ -397,7 +375,7 @@ public long getSplitBacklogBytes() { private static final long UNINITIALIZED_OFFSET = -1; /** watermark before any records have been read. */ - private static Instant initialWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; + private static final Instant initialWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; @Override public String toString() { @@ -436,8 +414,8 @@ private static class PartitionState { private Iterator> recordIter = Collections.emptyIterator(); - private KafkaIOUtils.MovingAvg avgRecordSize = new KafkaIOUtils.MovingAvg(); - private KafkaIOUtils.MovingAvg avgOffsetGap = + private final KafkaIOUtils.MovingAvg avgRecordSize = new KafkaIOUtils.MovingAvg(); + private final KafkaIOUtils.MovingAvg avgOffsetGap = new KafkaIOUtils.MovingAvg(); // > 0 only when log compaction is enabled. PartitionState( @@ -538,7 +516,7 @@ String name() { } PartitionState state = - new PartitionState( + new PartitionState<>( tp, nextOffset, source @@ -556,55 +534,7 @@ String name() { bytesReadBySplit = SourceMetrics.bytesReadBySplit(splitId); backlogBytesOfSplit = SourceMetrics.backlogBytesOfSplit(splitId); backlogElementsOfSplit = SourceMetrics.backlogElementsOfSplit(splitId); - recordsDequeuePollTimeout = Duration.millis(10); - } - - private void consumerPollLoop() { - // Read in a loop and enqueue the batch of records, if any, to availableRecordsQueue. - Consumer consumer = Preconditions.checkStateNotNull(this.consumer); - - try { - ConsumerRecords records = ConsumerRecords.empty(); - while (!closed.get()) { - try { - if (records.isEmpty()) { - records = consumer.poll(KAFKA_POLL_TIMEOUT.getMillis()); - } else if (availableRecordsQueue.offer( - records, RECORDS_ENQUEUE_POLL_TIMEOUT.getMillis(), TimeUnit.MILLISECONDS)) { - records = ConsumerRecords.empty(); - } - - commitCheckpointMark(); - } catch (InterruptedException e) { - LOG.warn("{}: consumer thread is interrupted", this, e); // not expected - break; - } catch (WakeupException e) { - break; - } - } - LOG.info("{}: Returning from consumer pool loop", this); - } catch (Exception e) { // mostly an unrecoverable KafkaException. - LOG.error("{}: Exception while reading from Kafka", this, e); - consumerPollException.set(e); - throw e; - } - } - - private void commitCheckpointMark() { - KafkaCheckpointMark checkpointMark = finalizedCheckpointMark.getAndSet(null); - - if (checkpointMark != null) { - LOG.debug("{}: Committing finalized checkpoint {}", this, checkpointMark); - Consumer consumer = Preconditions.checkStateNotNull(this.consumer); - - consumer.commitSync( - checkpointMark.getPartitions().stream() - .filter(p -> p.getNextOffset() != UNINITIALIZED_OFFSET) - .collect( - Collectors.toMap( - p -> new TopicPartition(p.getTopic(), p.getPartition()), - p -> new OffsetAndMetadata(p.getNextOffset())))); - } + pollThread = new KafkaConsumerPollThread(); } /** @@ -615,7 +545,7 @@ private void commitCheckpointMark() { * need to be committed. */ void finalizeCheckpointMarkAsync(KafkaCheckpointMark checkpointMark) { - if (finalizedCheckpointMark.getAndSet(checkpointMark) != null) { + if (pollThread.finalizeCheckpointMarkAsync(checkpointMark)) { checkpointMarkCommitsSkipped.inc(); } checkpointMarkCommitsEnqueued.inc(); @@ -624,39 +554,12 @@ void finalizeCheckpointMarkAsync(KafkaCheckpointMark checkpointMark) { private void nextBatch() throws IOException { curBatch = Collections.emptyIterator(); - ConsumerRecords records; - try { - // poll available records, wait (if necessary) up to the specified timeout. - records = - availableRecordsQueue.poll(recordsDequeuePollTimeout.getMillis(), TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - LOG.warn("{}: Unexpected", this, e); - return; - } - - if (records == null) { - // Check if the poll thread failed with an exception. - if (consumerPollException.get() != null) { - throw new IOException("Exception while reading from Kafka", consumerPollException.get()); - } - if (recordsDequeuePollTimeout.isLongerThan(RECORDS_DEQUEUE_POLL_TIMEOUT_MIN)) { - recordsDequeuePollTimeout = recordsDequeuePollTimeout.minus(Duration.millis(1)); - LOG.debug("Reducing poll timeout for reader to " + recordsDequeuePollTimeout.getMillis()); - } - return; - } - - if (recordsDequeuePollTimeout.isShorterThan(RECORDS_DEQUEUE_POLL_TIMEOUT_MAX)) { - recordsDequeuePollTimeout = recordsDequeuePollTimeout.plus(Duration.millis(1)); - LOG.debug("Increasing poll timeout for reader to " + recordsDequeuePollTimeout.getMillis()); - LOG.debug("Record count: " + records.count()); + ConsumerRecords records = pollThread.readRecords(); + if (!records.isEmpty()) { + partitionStates.forEach(p -> p.recordIter = records.records(p.topicPartition).iterator()); + // cycle through the partitions in order to interleave records from each. + curBatch = Iterators.cycle(new ArrayList<>(partitionStates)); } - - partitionStates.forEach(p -> p.recordIter = records.records(p.topicPartition).iterator()); - - // cycle through the partitions in order to interleave records from each. - curBatch = Iterators.cycle(new ArrayList<>(partitionStates)); } private void setupInitialOffset(PartitionState pState) { @@ -738,7 +641,11 @@ private long getSplitBacklogMessageCount() { @Override public void close() throws IOException { closed.set(true); - consumerPollThread.shutdown(); + try { + pollThread.close(); + } catch (IOException e) { + LOG.warn("Error shutting down poll thread", e); + } offsetFetcherThread.shutdown(); boolean isShutdown = false; @@ -746,18 +653,11 @@ public void close() throws IOException { // Wait for threads to shutdown. Trying this as a loop to handle a tiny race where poll thread // might block to enqueue right after availableRecordsQueue.poll() below. while (!isShutdown) { - - if (consumer != null) { - consumer.wakeup(); - } if (offsetConsumer != null) { offsetConsumer.wakeup(); } - availableRecordsQueue.poll(); // drain unread batch, this unblocks consumer thread. try { - isShutdown = - consumerPollThread.awaitTermination(10, TimeUnit.SECONDS) - && offsetFetcherThread.awaitTermination(10, TimeUnit.SECONDS); + isShutdown = offsetFetcherThread.awaitTermination(10, TimeUnit.SECONDS); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException(e); // not expected @@ -768,14 +668,10 @@ public void close() throws IOException { } } - // Commit any pending finalized checkpoint before shutdown. - commitCheckpointMark(); - Closeables.close(keyDeserializerInstance, true); Closeables.close(valueDeserializerInstance, true); Closeables.close(offsetConsumer, true); - Closeables.close(consumer, true); } @VisibleForTesting diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 9bb950bb8e6c..b91b15af3480 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -19,6 +19,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import java.time.Duration; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -26,6 +27,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.TimeUnit; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors; import org.apache.beam.sdk.io.kafka.KafkaIOUtils.MovingAvg; @@ -51,7 +53,6 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; @@ -62,7 +63,6 @@ import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; -import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.errors.SerializationException; @@ -97,16 +97,16 @@ * *

Splitting

* - *

TODO(https://github.com/apache/beam/issues/20280): Add support for initial splitting. + *

TODO(...): Add support for initial + * splitting. * *

Checkpoint and Resume Processing

* *

There are 2 types of checkpoint here: self-checkpoint which invokes by the DoFn and - * system-checkpoint which is issued by the runner via {@link - * org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitRequest}. Every time the - * consumer gets empty response from {@link Consumer#poll(long)}, {@link ReadFromKafkaDoFn} will - * checkpoint the current {@link KafkaSourceDescriptor} and move to process the next element. These - * deferred elements will be resumed by the runner as soon as possible. + * system-checkpoint which is issued by the runner via {@link BeamFnApi.ProcessBundleSplitRequest}. + * Every time the consumer gets empty response from {@link Consumer#poll(Duration)}, {@link + * ReadFromKafkaDoFn} will checkpoint the current {@link KafkaSourceDescriptor} and move to process + * the next element. These deferred elements will be resumed by the runner as soon as possible. * *

Progress and Size

* @@ -117,7 +117,7 @@ * ReadFromKafkaDoFn#restrictionTracker(KafkaSourceDescriptor, OffsetRange)} for details. * *

The size is computed by {@link ReadFromKafkaDoFn#getSize(KafkaSourceDescriptor, OffsetRange)}. - * A {@link KafkaIOUtils.MovingAvg} is used to track the average size of kafka records. + * A {@link MovingAvg} is used to track the average size of kafka records. * *

Track Watermark

* @@ -135,8 +135,8 @@ * {@link ReadFromKafkaDoFn} will stop reading from any removed {@link TopicPartition} automatically * by querying Kafka {@link Consumer} APIs. Please note that stopping reading may not happen as soon * as the {@link TopicPartition} is removed. For example, the removal could happen at the same time - * when {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(java.time.Duration)}. In that - * case, the {@link ReadFromKafkaDoFn} will still output the fetched records. + * when {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(Duration)}. In that case, the + * {@link ReadFromKafkaDoFn} will still output the fetched records. * *

Stop Reading from Stopped {@link TopicPartition}

* @@ -218,6 +218,9 @@ private ReadFromKafkaDoFn( private final TupleTag>> recordTag; + private static final Supplier cache = + Suppliers.memoize(KafkaConsumerPollThreadCache::new); + // Valid between bundle start and bundle finish. private transient @Nullable Deserializer keyDeserializerInstance = null; private transient @Nullable Deserializer valueDeserializerInstance = null; @@ -226,7 +229,7 @@ private ReadFromKafkaDoFn( private transient @Nullable LoadingCache avgRecordSize; private static final long DEFAULT_KAFKA_POLL_TIMEOUT = 2L; - private HashMap perPartitionBacklogMetrics = new HashMap();; + private final HashMap perPartitionBacklogMetrics = new HashMap<>(); @VisibleForTesting final long consumerPollingTimeout; @VisibleForTesting final DeserializerProvider keyDeserializerProvider; @@ -312,7 +315,6 @@ public OffsetRange initialRestriction(@Element KafkaSourceDescriptor kafkaSource } else if (stopReadTime != null) { endOffset = ConsumerSpEL.offsetForTime(offsetConsumer, partition, stopReadTime); } - new OffsetRange(startOffset, endOffset); Lineage.getSources() .add( "kafka", @@ -433,37 +435,55 @@ public ProcessContinuation processElement( kafkaSourceDescriptor.getTopicPartition(), Optional.ofNullable(watermarkEstimator.currentWatermark())); } - - LOG.info( - "Creating Kafka consumer for process continuation for {}", - kafkaSourceDescriptor.getTopicPartition()); - try (Consumer consumer = consumerFactoryFn.apply(updatedConsumerConfig)) { - ConsumerSpEL.evaluateAssign( - consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition())); - long startOffset = tracker.currentRestriction().getFrom(); - - long expectedOffset = startOffset; - consumer.seek(kafkaSourceDescriptor.getTopicPartition(), startOffset); - ConsumerRecords rawRecords = ConsumerRecords.empty(); - + final long startOffset = tracker.currentRestriction().getFrom(); + String restrictionInfo = + String.format( + "%s_%s_%s", + kafkaSourceDescriptor.getTopic(), kafkaSourceDescriptor.getPartition(), startOffset); + LOG.info("bzablockilog start restriction {}", restrictionInfo); + long resumeOffset = startOffset; + @Nullable KafkaConsumerPollThread pollThread = null; + try { + pollThread = + cache + .get() + .acquireConsumer( + updatedConsumerConfig, consumerFactoryFn, kafkaSourceDescriptor, startOffset); while (true) { - rawRecords = poll(consumer, kafkaSourceDescriptor.getTopicPartition()); + ConsumerRecord rawRecord = pollThread.peek(); + // When there are no records available for the current TopicPartition, self-checkpoint // and move to process the next element. - if (rawRecords.isEmpty()) { - if (!topicPartitionExists( - kafkaSourceDescriptor.getTopicPartition(), consumer.listTopics())) { + // if (pollThread.isTopicListUpdated() && !rawRecordOptional.isPresent()) { + if (pollThread.isTopicListUpdated() && rawRecord == null) { + Map> topicsList = pollThread.getTopicsList(); + if (!topicPartitionExists(kafkaSourceDescriptor.getTopicPartition(), topicsList)) { + LOG.info("bzablockilog stop restriction {}", restrictionInfo); return ProcessContinuation.stop(); } if (timestampPolicy != null) { updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); } + LOG.info("bzablockilog resume restriction {}", restrictionInfo); return ProcessContinuation.resume(); } - for (ConsumerRecord rawRecord : rawRecords) { + + if (rawRecord != null) { + LOG.info( + "bzablockilog picked up {}_{}_{}", + rawRecord.topic(), + rawRecord.partition(), + rawRecord.offset()); if (!tracker.tryClaim(rawRecord.offset())) { + LOG.info( + "bzablockilog unsuccessful claim of {}_{}_{}", + rawRecord.topic(), + rawRecord.partition(), + rawRecord.offset()); + // XXX need to add unconsumed records back. return ProcessContinuation.stop(); } + try { KafkaRecord kafkaRecord = new KafkaRecord<>( @@ -480,9 +500,9 @@ public ProcessContinuation processElement( + (rawRecord.value() == null ? 0 : rawRecord.value().length); avgRecordSize .getUnchecked(kafkaSourceDescriptor.getTopicPartition()) - .update(recordSize, rawRecord.offset() - expectedOffset); + .update(recordSize, rawRecord.offset() - resumeOffset); rawSizes.update(recordSize); - expectedOffset = rawRecord.offset() + 1; + resumeOffset = rawRecord.offset() + 1; Instant outputTimestamp; // The outputTimestamp and watermark will be computed by timestampPolicy, where the // WatermarkEstimator should be a manual one. @@ -505,13 +525,26 @@ public ProcessContinuation processElement( rawRecord, null, e, - "Failure deserializing Key or Value of Kakfa record reading from Kafka"); + "Failure deserializing Key or Value of Kafka record reading from Kafka"); if (timestampPolicy != null) { updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); } } + + pollThread.advance(); } } + } finally { + if (pollThread != null) { + cache + .get() + .releaseConsumer( + updatedConsumerConfig, + consumerFactoryFn, + kafkaSourceDescriptor, + pollThread, + resumeOffset); + } } } @@ -532,34 +565,6 @@ private boolean topicPartitionExists( return true; } - // see https://github.com/apache/beam/issues/25962 - private ConsumerRecords poll( - Consumer consumer, TopicPartition topicPartition) { - final Stopwatch sw = Stopwatch.createStarted(); - long previousPosition = -1; - java.time.Duration elapsed = java.time.Duration.ZERO; - java.time.Duration timeout = java.time.Duration.ofSeconds(this.consumerPollingTimeout); - while (true) { - final ConsumerRecords rawRecords = consumer.poll(timeout.minus(elapsed)); - if (!rawRecords.isEmpty()) { - // return as we have found some entries - return rawRecords; - } - if (previousPosition == (previousPosition = consumer.position(topicPartition))) { - // there was no progress on the offset/position, which indicates end of stream - return rawRecords; - } - elapsed = sw.elapsed(); - if (elapsed.toMillis() >= timeout.toMillis()) { - // timeout is over - LOG.warn( - "No messages retrieved with polling timeout {} seconds. Consider increasing the consumer polling timeout using withConsumerPollingTimeout method.", - consumerPollingTimeout); - return rawRecords; - } - } - } - private TimestampPolicyContext updateWatermarkManually( TimestampPolicy timestampPolicy, WatermarkEstimator watermarkEstimator, @@ -587,7 +592,7 @@ public void setup() throws Exception { .build( new CacheLoader() { @Override - public AverageRecordSize load(TopicPartition topicPartition) throws Exception { + public AverageRecordSize load(TopicPartition topicPartition) { return new AverageRecordSize(); } }); @@ -626,7 +631,7 @@ private Map overrideBootstrapServersConfig( currentConfig.containsKey(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) || description.getBootStrapServers() != null); Map config = new HashMap<>(currentConfig); - if (description.getBootStrapServers() != null && description.getBootStrapServers().size() > 0) { + if (description.getBootStrapServers() != null && !description.getBootStrapServers().isEmpty()) { config.put( ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, String.join(",", description.getBootStrapServers())); @@ -635,8 +640,8 @@ private Map overrideBootstrapServersConfig( } private static class AverageRecordSize { - private MovingAvg avgRecordSize; - private MovingAvg avgRecordGap; + private final MovingAvg avgRecordSize; + private final MovingAvg avgRecordGap; public AverageRecordSize() { this.avgRecordSize = new MovingAvg(); diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java index 1fe1147a7390..b5a2d75291bd 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java @@ -154,6 +154,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.slf4j.Logger; @@ -165,7 +166,7 @@ */ @RunWith(JUnit4.class) public class KafkaIOTest { - + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private static final Logger LOG = LoggerFactory.getLogger(KafkaIOTest.class); /* @@ -181,7 +182,7 @@ public class KafkaIOTest { @Rule public ExpectedException thrown = ExpectedException.none(); @Rule - public ExpectedLogs unboundedReaderExpectedLogs = ExpectedLogs.none(KafkaUnboundedReader.class); + public ExpectedLogs pollThreadExpectedLogs = ExpectedLogs.none(KafkaConsumerPollThread.class); @Rule public ExpectedLogs kafkaIOExpectedLogs = ExpectedLogs.none(KafkaIO.class); @@ -1423,7 +1424,7 @@ public void testUnboundedSourceMetrics() { } @Test - public void testUnboundedReaderLogsCommitFailure() throws Exception { + public void testUnboundedReaderLogsFetchFailure() throws Exception { List topics = ImmutableList.of("topic_a"); @@ -1440,9 +1441,12 @@ public void testUnboundedReaderLogsCommitFailure() throws Exception { UnboundedReader> reader = source.createReader(null, null); - reader.start(); - - unboundedReaderExpectedLogs.verifyWarn("exception while fetching latest offset for partition"); + try { + reader.start(); + } catch (Exception e) { + // Racy if we observe the exception on initial advance. + } + pollThreadExpectedLogs.verifyError("Exception while reading from Kafka"); reader.close(); } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java index 6ee3d9d96ef6..c2fec5536be9 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java @@ -83,8 +83,10 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.junit.rules.Timeout; public class ReadFromKafkaDoFnTest { + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private final TopicPartition topicPartition = new TopicPartition("topic", 0);