From f4ea0c354c1614cce6810bcda638d2745f478d5d Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Fri, 5 Jul 2024 20:41:06 +0200 Subject: [PATCH 1/6] Extract background poll thread from KafkaUnboundedReader --- .../sdk/io/kafka/KafkaConsumerPollThread.java | 209 ++++++++++++++++++ .../sdk/io/kafka/KafkaUnboundedReader.java | 146 ++---------- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 37 +++- 3 files changed, 253 insertions(+), 139 deletions(-) create mode 100644 sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java new file mode 100644 index 000000000000..4dbf8a9ed50f --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.util.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.WakeupException; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class KafkaConsumerPollThread { + + KafkaConsumerPollThread() { + recordsDequeuePollTimeout = Duration.ofMillis(10); + consumer = null; + pollFuture = null; + } + + private @Nullable Consumer consumer; + + /** + * The poll timeout while reading records from Kafka. If option to commit reader offsets in to + * Kafka in {@link KafkaCheckpointMark#finalizeCheckpoint()} is enabled, it would be delayed until + * this poll returns. It should be reasonably low as a result. At the same time it probably can't + * be very low like 10 millis, I am not sure how it affects when the latency is high. Probably + * good to experiment. Often multiple marks would be finalized in a batch, it reduce finalization + * overhead to wait a short while and finalize only the last checkpoint mark. + */ + private static final Duration KAFKA_POLL_TIMEOUT = Duration.ofSeconds(1); + + private Duration recordsDequeuePollTimeout; + private static final Duration RECORDS_DEQUEUE_POLL_TIMEOUT_MIN = Duration.ofMillis(1); + private static final Duration RECORDS_DEQUEUE_POLL_TIMEOUT_MAX = Duration.ofMillis(20); + private static final Duration RECORDS_ENQUEUE_POLL_TIMEOUT = Duration.ofMillis(100); + + private static final long UNINITIALIZED_OFFSET = -1; + + private final AtomicReference consumerPollException = new AtomicReference<>(); + private final SynchronousQueue> availableRecordsQueue = + new SynchronousQueue<>(); + private final AtomicReference<@Nullable KafkaCheckpointMark> finalizedCheckpointMark = + new AtomicReference<>(); + private final AtomicBoolean closed = new AtomicBoolean(false); + + private static final Logger LOG = LoggerFactory.getLogger(KafkaConsumerPollThread.class); + + private @Nullable Future pollFuture; + + void startOnExecutor(ExecutorService executorService, Consumer consumer) { + this.consumer = consumer; + // Use a separate thread to read Kafka messages. Kafka Consumer does all its work including + // network I/O inside poll(). Polling only inside #advance(), especially with a small timeout + // like 100 milliseconds does not work well. This along with large receive buffer for + // consumer achieved best throughput in tests (see `defaultConsumerProperties`). + pollFuture = executorService.submit(this::consumerPollLoop); + } + + void close() throws IOException { + if (consumer == null) { + LOG.debug("Closing consumer poll thread that was never started."); + return; + } + Preconditions.checkStateNotNull(pollFuture); + closed.set(true); + try { + // Wait for threads to shutdown. Trying this as a loop to handle a tiny race where poll thread + // might block to enqueue right after availableRecordsQueue.poll() below. + while (true) { + if (consumer != null) { + consumer.wakeup(); + } + availableRecordsQueue.poll(); // drain unread batch, this unblocks consumer thread. + try { + Preconditions.checkStateNotNull(pollFuture); + pollFuture.get(10, TimeUnit.SECONDS); + break; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); // not expected + } catch (ExecutionException e) { + throw new IOException(e.getCause()); + } catch (TimeoutException ignored) { + } + LOG.warn("An internal thread is taking a long time to shutdown. will retry."); + } + } finally { + Closeables.close(consumer, true); + } + } + + private void consumerPollLoop() { + // Read in a loop and enqueue the batch of records, if any, to availableRecordsQueue. + Consumer consumer = Preconditions.checkStateNotNull(this.consumer); + + try { + ConsumerRecords records = ConsumerRecords.empty(); + while (!closed.get()) { + try { + if (records.isEmpty()) { + records = consumer.poll(KAFKA_POLL_TIMEOUT); + } else if (availableRecordsQueue.offer( + records, RECORDS_ENQUEUE_POLL_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS)) { + records = ConsumerRecords.empty(); + } + + commitCheckpointMark(); + } catch (InterruptedException e) { + LOG.warn("{}: consumer thread is interrupted", this, e); // not expected + break; + } catch (WakeupException e) { + break; + } + } + } catch (Exception e) { // mostly an unrecoverable KafkaException. + LOG.error("{}: Exception while reading from Kafka", this, e); + consumerPollException.set(e); + throw e; + } + LOG.info("{}: Returning from consumer pool loop", this); + // Commit any pending finalized checkpoint before shutdown. + commitCheckpointMark(); + } + + ConsumerRecords readRecords() throws IOException { + @Nullable ConsumerRecords records = null; + try { + // poll available records, wait (if necessary) up to the specified timeout. + records = + availableRecordsQueue.poll(recordsDequeuePollTimeout.toMillis(), TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("{}: Unexpected", this, e); + } + + if (records == null) { + // Check if the poll thread failed with an exception. + if (consumerPollException.get() != null) { + throw new IOException("Exception while reading from Kafka", consumerPollException.get()); + } + if (recordsDequeuePollTimeout.compareTo(RECORDS_DEQUEUE_POLL_TIMEOUT_MIN) > 0) { + recordsDequeuePollTimeout = recordsDequeuePollTimeout.minus(Duration.ofMillis(1)); + LOG.debug("Reducing poll timeout for reader to " + recordsDequeuePollTimeout.toMillis()); + } + } else if (recordsDequeuePollTimeout.compareTo(RECORDS_DEQUEUE_POLL_TIMEOUT_MAX) < 0) { + recordsDequeuePollTimeout = recordsDequeuePollTimeout.plus(Duration.ofMillis(1)); + LOG.debug("Increasing poll timeout for reader to " + recordsDequeuePollTimeout.toMillis()); + LOG.debug("Record count: " + records.count()); + } + + return records == null ? ConsumerRecords.empty() : records; + } + + /** + * Enqueue checkpoint mark to be committed to Kafka. This does not block until it is committed. + * There could be a delay of up to KAFKA_POLL_TIMEOUT (1 second). Any checkpoint mark enqueued + * earlier is dropped in favor of this checkpoint mark. Documentation for {@link + * UnboundedSource.CheckpointMark#finalizeCheckpoint()} says these are finalized in order. Only + * the latest offsets need to be committed. + * + *

Returns if a existing checkpoint mark was skipped. + */ + boolean finalizeCheckpointMarkAsync(KafkaCheckpointMark checkpointMark) { + return finalizedCheckpointMark.getAndSet(checkpointMark) != null; + } + + private void commitCheckpointMark() { + KafkaCheckpointMark checkpointMark = finalizedCheckpointMark.getAndSet(null); + + if (checkpointMark != null) { + LOG.debug("{}: Committing finalized checkpoint {}", this, checkpointMark); + Consumer consumer = Preconditions.checkStateNotNull(this.consumer); + + consumer.commitSync( + checkpointMark.getPartitions().stream() + .filter(p -> p.getNextOffset() != UNINITIALIZED_OFFSET) + .collect( + Collectors.toMap( + p -> new TopicPartition(p.getTopic(), p.getPartition()), + p -> new OffsetAndMetadata(p.getNextOffset())))); + } + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java index fed03047cf16..7aa9df199f91 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java @@ -33,11 +33,9 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.SynchronousQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; @@ -61,9 +59,7 @@ import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.ConsumerRecords; -import org.apache.kafka.clients.consumer.OffsetAndMetadata; import org.apache.kafka.common.TopicPartition; -import org.apache.kafka.common.errors.WakeupException; import org.apache.kafka.common.serialization.Deserializer; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; @@ -102,7 +98,7 @@ public boolean start() throws IOException { // This problem of blocking API calls to kafka is solved in higher versions of kafka // client by `KIP-266` for (final PartitionState pState : partitionStates) { - Future future = consumerPollThread.submit(() -> setupInitialOffset(pState)); + Future future = backgroundThread.submit(() -> setupInitialOffset(pState)); try { Duration timeout = resolveDefaultApiTimeout(spec); future.get(timeout.getMillis(), TimeUnit.MILLISECONDS); @@ -129,7 +125,7 @@ public boolean start() throws IOException { // Start consumer read loop. // Note that consumer is not thread safe, should not be accessed out side consumerPollLoop(). - consumerPollThread.submit(this::consumerPollLoop); + pollThread.startOnExecutor(backgroundThread, consumer); // offsetConsumer setup : Map offsetConsumerConfig = @@ -345,44 +341,26 @@ public long getSplitBacklogBytes() { private final Counter bytesReadBySplit; private final Gauge backlogBytesOfSplit; private final Gauge backlogElementsOfSplit; - private HashMap perPartitionBacklogMetrics = new HashMap();; + private final HashMap perPartitionBacklogMetrics = new HashMap<>(); private final Counter checkpointMarkCommitsEnqueued = Metrics.counter(METRIC_NAMESPACE, CHECKPOINT_MARK_COMMITS_ENQUEUED_METRIC); // Checkpoint marks skipped in favor of newer mark (only the latest needs to be committed). private final Counter checkpointMarkCommitsSkipped = Metrics.counter(METRIC_NAMESPACE, CHECKPOINT_MARK_COMMITS_SKIPPED_METRIC); - /** - * The poll timeout while reading records from Kafka. If option to commit reader offsets in to - * Kafka in {@link KafkaCheckpointMark#finalizeCheckpoint()} is enabled, it would be delayed until - * this poll returns. It should be reasonably low as a result. At the same time it probably can't - * be very low like 10 millis, I am not sure how it affects when the latency is high. Probably - * good to experiment. Often multiple marks would be finalized in a batch, it reduce finalization - * overhead to wait a short while and finalize only the last checkpoint mark. - */ - private static final Duration KAFKA_POLL_TIMEOUT = Duration.millis(1000); - - private Duration recordsDequeuePollTimeout; - private static final Duration RECORDS_DEQUEUE_POLL_TIMEOUT_MIN = Duration.millis(1); - private static final Duration RECORDS_DEQUEUE_POLL_TIMEOUT_MAX = Duration.millis(20); - private static final Duration RECORDS_ENQUEUE_POLL_TIMEOUT = Duration.millis(100); + private final KafkaConsumerPollThread pollThread; // Use a separate thread to read Kafka messages. Kafka Consumer does all its work including // network I/O inside poll(). Polling only inside #advance(), especially with a small timeout // like 100 milliseconds does not work well. This along with large receive buffer for // consumer achieved best throughput in tests (see `defaultConsumerProperties`). - private final ExecutorService consumerPollThread = + private final ExecutorService backgroundThread = Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setDaemon(true) .setNameFormat("KafkaConsumerPoll-thread") .build()); - private AtomicReference consumerPollException = new AtomicReference<>(); - private final SynchronousQueue> availableRecordsQueue = - new SynchronousQueue<>(); - private AtomicReference<@Nullable KafkaCheckpointMark> finalizedCheckpointMark = - new AtomicReference<>(); - private AtomicBoolean closed = new AtomicBoolean(false); + private final AtomicBoolean closed = new AtomicBoolean(false); // Backlog support : // Kafka consumer does not have an API to fetch latest offset for topic. We need to seekToEnd() @@ -397,7 +375,7 @@ public long getSplitBacklogBytes() { private static final long UNINITIALIZED_OFFSET = -1; /** watermark before any records have been read. */ - private static Instant initialWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; + private static final Instant initialWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; @Override public String toString() { @@ -436,8 +414,8 @@ private static class PartitionState { private Iterator> recordIter = Collections.emptyIterator(); - private KafkaIOUtils.MovingAvg avgRecordSize = new KafkaIOUtils.MovingAvg(); - private KafkaIOUtils.MovingAvg avgOffsetGap = + private final KafkaIOUtils.MovingAvg avgRecordSize = new KafkaIOUtils.MovingAvg(); + private final KafkaIOUtils.MovingAvg avgOffsetGap = new KafkaIOUtils.MovingAvg(); // > 0 only when log compaction is enabled. PartitionState( @@ -538,7 +516,7 @@ String name() { } PartitionState state = - new PartitionState( + new PartitionState<>( tp, nextOffset, source @@ -556,55 +534,7 @@ String name() { bytesReadBySplit = SourceMetrics.bytesReadBySplit(splitId); backlogBytesOfSplit = SourceMetrics.backlogBytesOfSplit(splitId); backlogElementsOfSplit = SourceMetrics.backlogElementsOfSplit(splitId); - recordsDequeuePollTimeout = Duration.millis(10); - } - - private void consumerPollLoop() { - // Read in a loop and enqueue the batch of records, if any, to availableRecordsQueue. - Consumer consumer = Preconditions.checkStateNotNull(this.consumer); - - try { - ConsumerRecords records = ConsumerRecords.empty(); - while (!closed.get()) { - try { - if (records.isEmpty()) { - records = consumer.poll(KAFKA_POLL_TIMEOUT.getMillis()); - } else if (availableRecordsQueue.offer( - records, RECORDS_ENQUEUE_POLL_TIMEOUT.getMillis(), TimeUnit.MILLISECONDS)) { - records = ConsumerRecords.empty(); - } - - commitCheckpointMark(); - } catch (InterruptedException e) { - LOG.warn("{}: consumer thread is interrupted", this, e); // not expected - break; - } catch (WakeupException e) { - break; - } - } - LOG.info("{}: Returning from consumer pool loop", this); - } catch (Exception e) { // mostly an unrecoverable KafkaException. - LOG.error("{}: Exception while reading from Kafka", this, e); - consumerPollException.set(e); - throw e; - } - } - - private void commitCheckpointMark() { - KafkaCheckpointMark checkpointMark = finalizedCheckpointMark.getAndSet(null); - - if (checkpointMark != null) { - LOG.debug("{}: Committing finalized checkpoint {}", this, checkpointMark); - Consumer consumer = Preconditions.checkStateNotNull(this.consumer); - - consumer.commitSync( - checkpointMark.getPartitions().stream() - .filter(p -> p.getNextOffset() != UNINITIALIZED_OFFSET) - .collect( - Collectors.toMap( - p -> new TopicPartition(p.getTopic(), p.getPartition()), - p -> new OffsetAndMetadata(p.getNextOffset())))); - } + pollThread = new KafkaConsumerPollThread(); } /** @@ -615,7 +545,7 @@ private void commitCheckpointMark() { * need to be committed. */ void finalizeCheckpointMarkAsync(KafkaCheckpointMark checkpointMark) { - if (finalizedCheckpointMark.getAndSet(checkpointMark) != null) { + if (pollThread.finalizeCheckpointMarkAsync(checkpointMark)) { checkpointMarkCommitsSkipped.inc(); } checkpointMarkCommitsEnqueued.inc(); @@ -624,39 +554,12 @@ void finalizeCheckpointMarkAsync(KafkaCheckpointMark checkpointMark) { private void nextBatch() throws IOException { curBatch = Collections.emptyIterator(); - ConsumerRecords records; - try { - // poll available records, wait (if necessary) up to the specified timeout. - records = - availableRecordsQueue.poll(recordsDequeuePollTimeout.getMillis(), TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - LOG.warn("{}: Unexpected", this, e); - return; - } - - if (records == null) { - // Check if the poll thread failed with an exception. - if (consumerPollException.get() != null) { - throw new IOException("Exception while reading from Kafka", consumerPollException.get()); - } - if (recordsDequeuePollTimeout.isLongerThan(RECORDS_DEQUEUE_POLL_TIMEOUT_MIN)) { - recordsDequeuePollTimeout = recordsDequeuePollTimeout.minus(Duration.millis(1)); - LOG.debug("Reducing poll timeout for reader to " + recordsDequeuePollTimeout.getMillis()); - } - return; - } - - if (recordsDequeuePollTimeout.isShorterThan(RECORDS_DEQUEUE_POLL_TIMEOUT_MAX)) { - recordsDequeuePollTimeout = recordsDequeuePollTimeout.plus(Duration.millis(1)); - LOG.debug("Increasing poll timeout for reader to " + recordsDequeuePollTimeout.getMillis()); - LOG.debug("Record count: " + records.count()); + ConsumerRecords records = pollThread.readRecords(); + if (records != null) { + partitionStates.forEach(p -> p.recordIter = records.records(p.topicPartition).iterator()); + // cycle through the partitions in order to interleave records from each. + curBatch = Iterators.cycle(new ArrayList<>(partitionStates)); } - - partitionStates.forEach(p -> p.recordIter = records.records(p.topicPartition).iterator()); - - // cycle through the partitions in order to interleave records from each. - curBatch = Iterators.cycle(new ArrayList<>(partitionStates)); } private void setupInitialOffset(PartitionState pState) { @@ -738,7 +641,7 @@ private long getSplitBacklogMessageCount() { @Override public void close() throws IOException { closed.set(true); - consumerPollThread.shutdown(); + pollThread.close(); offsetFetcherThread.shutdown(); boolean isShutdown = false; @@ -746,18 +649,11 @@ public void close() throws IOException { // Wait for threads to shutdown. Trying this as a loop to handle a tiny race where poll thread // might block to enqueue right after availableRecordsQueue.poll() below. while (!isShutdown) { - - if (consumer != null) { - consumer.wakeup(); - } if (offsetConsumer != null) { offsetConsumer.wakeup(); } - availableRecordsQueue.poll(); // drain unread batch, this unblocks consumer thread. try { - isShutdown = - consumerPollThread.awaitTermination(10, TimeUnit.SECONDS) - && offsetFetcherThread.awaitTermination(10, TimeUnit.SECONDS); + isShutdown = offsetFetcherThread.awaitTermination(10, TimeUnit.SECONDS); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException(e); // not expected @@ -768,14 +664,10 @@ public void close() throws IOException { } } - // Commit any pending finalized checkpoint before shutdown. - commitCheckpointMark(); - Closeables.close(keyDeserializerInstance, true); Closeables.close(valueDeserializerInstance, true); Closeables.close(offsetConsumer, true); - Closeables.close(consumer, true); } @VisibleForTesting diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 9bb950bb8e6c..5f934ac2551e 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -19,12 +19,15 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import java.time.Duration; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors; @@ -59,6 +62,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; @@ -97,14 +101,15 @@ * *

Splitting

* - *

TODO(https://github.com/apache/beam/issues/20280): Add support for initial splitting. + *

TODO(...): Add support for initial + * splitting. * *

Checkpoint and Resume Processing

* *

There are 2 types of checkpoint here: self-checkpoint which invokes by the DoFn and * system-checkpoint which is issued by the runner via {@link * org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitRequest}. Every time the - * consumer gets empty response from {@link Consumer#poll(long)}, {@link ReadFromKafkaDoFn} will + * consumer gets empty response from {@link Consumer#poll(Duration)}, {@link ReadFromKafkaDoFn} will * checkpoint the current {@link KafkaSourceDescriptor} and move to process the next element. These * deferred elements will be resumed by the runner as soon as possible. * @@ -117,7 +122,7 @@ * ReadFromKafkaDoFn#restrictionTracker(KafkaSourceDescriptor, OffsetRange)} for details. * *

The size is computed by {@link ReadFromKafkaDoFn#getSize(KafkaSourceDescriptor, OffsetRange)}. - * A {@link KafkaIOUtils.MovingAvg} is used to track the average size of kafka records. + * A {@link MovingAvg} is used to track the average size of kafka records. * *

Track Watermark

* @@ -135,8 +140,8 @@ * {@link ReadFromKafkaDoFn} will stop reading from any removed {@link TopicPartition} automatically * by querying Kafka {@link Consumer} APIs. Please note that stopping reading may not happen as soon * as the {@link TopicPartition} is removed. For example, the removal could happen at the same time - * when {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(java.time.Duration)}. In that - * case, the {@link ReadFromKafkaDoFn} will still output the fetched records. + * when {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(Duration)}. In that case, the + * {@link ReadFromKafkaDoFn} will still output the fetched records. * *

Stop Reading from Stopped {@link TopicPartition}

* @@ -226,7 +231,7 @@ private ReadFromKafkaDoFn( private transient @Nullable LoadingCache avgRecordSize; private static final long DEFAULT_KAFKA_POLL_TIMEOUT = 2L; - private HashMap perPartitionBacklogMetrics = new HashMap();; + private final HashMap perPartitionBacklogMetrics = new HashMap<>(); @VisibleForTesting final long consumerPollingTimeout; @VisibleForTesting final DeserializerProvider keyDeserializerProvider; @@ -444,10 +449,18 @@ public ProcessContinuation processElement( long expectedOffset = startOffset; consumer.seek(kafkaSourceDescriptor.getTopicPartition(), startOffset); - ConsumerRecords rawRecords = ConsumerRecords.empty(); + + KafkaConsumerPollThread pollThread = new KafkaConsumerPollThread(); + ExecutorService backgroundThread = + Executors.newSingleThreadExecutor( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("KafkaConsumerPoll-thread") + .build()); + pollThread.startOnExecutor(backgroundThread, consumer); while (true) { - rawRecords = poll(consumer, kafkaSourceDescriptor.getTopicPartition()); + ConsumerRecords rawRecords = pollThread.readRecords(); // When there are no records available for the current TopicPartition, self-checkpoint // and move to process the next element. if (rawRecords.isEmpty()) { @@ -587,7 +600,7 @@ public void setup() throws Exception { .build( new CacheLoader() { @Override - public AverageRecordSize load(TopicPartition topicPartition) throws Exception { + public AverageRecordSize load(TopicPartition topicPartition) { return new AverageRecordSize(); } }); @@ -626,7 +639,7 @@ private Map overrideBootstrapServersConfig( currentConfig.containsKey(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) || description.getBootStrapServers() != null); Map config = new HashMap<>(currentConfig); - if (description.getBootStrapServers() != null && description.getBootStrapServers().size() > 0) { + if (description.getBootStrapServers() != null && !description.getBootStrapServers().isEmpty()) { config.put( ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, String.join(",", description.getBootStrapServers())); @@ -635,8 +648,8 @@ private Map overrideBootstrapServersConfig( } private static class AverageRecordSize { - private MovingAvg avgRecordSize; - private MovingAvg avgRecordGap; + private final MovingAvg avgRecordSize; + private final MovingAvg avgRecordGap; public AverageRecordSize() { this.avgRecordSize = new MovingAvg(); From 237e3937d60a23aab51cede7317f7b6a5d46c105 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Fri, 5 Jul 2024 23:04:55 +0200 Subject: [PATCH 2/6] Add cache of background polling threads and use from kafka SDF --- .../sdk/io/kafka/KafkaConsumerPollThread.java | 14 +- .../kafka/KafkaConsumerPollThreadCache.java | 207 ++++++++++++++++++ .../sdk/io/kafka/KafkaUnboundedReader.java | 10 +- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 94 +++----- .../apache/beam/sdk/io/kafka/KafkaIOTest.java | 23 +- .../sdk/io/kafka/ReadFromKafkaDoFnTest.java | 2 + 6 files changed, 268 insertions(+), 82 deletions(-) create mode 100644 sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThreadCache.java diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java index 4dbf8a9ed50f..175c5672f3dc 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java @@ -19,7 +19,12 @@ import java.io.IOException; import java.time.Duration; -import java.util.concurrent.*; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -168,13 +173,14 @@ ConsumerRecords readRecords() throws IOException { recordsDequeuePollTimeout = recordsDequeuePollTimeout.minus(Duration.ofMillis(1)); LOG.debug("Reducing poll timeout for reader to " + recordsDequeuePollTimeout.toMillis()); } - } else if (recordsDequeuePollTimeout.compareTo(RECORDS_DEQUEUE_POLL_TIMEOUT_MAX) < 0) { + return ConsumerRecords.empty(); + } + if (recordsDequeuePollTimeout.compareTo(RECORDS_DEQUEUE_POLL_TIMEOUT_MAX) < 0) { recordsDequeuePollTimeout = recordsDequeuePollTimeout.plus(Duration.ofMillis(1)); LOG.debug("Increasing poll timeout for reader to " + recordsDequeuePollTimeout.toMillis()); LOG.debug("Record count: " + records.count()); } - - return records == null ? ConsumerRecords.empty() : records; + return records; } /** diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThreadCache.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThreadCache.java new file mode 100644 index 000000000000..8828cc9e5f44 --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThreadCache.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + +import java.io.IOException; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalCause; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalNotification; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class KafkaConsumerPollThreadCache { + + private static final Logger LOG = LoggerFactory.getLogger(KafkaConsumerPollThreadCache.class); + private final ExecutorService invalidationExecutor = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("KafkaConsumerPollCache-invalidation-%d") + .build()); + private final ExecutorService backgroundThreads = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("KafkaConsumerPollCache-poll-%d") + .build()); + + // Note on thread safety. This class is thread safe because: + // - Guava Cache is thread safe. + // - There is no state other than Cache. + // - API is strictly a 1:1 wrapper over Cache API (not counting cache.cleanUp() calls). + // - i.e. it does not invoke more than one call, which could make it inconsistent. + // If any of these conditions changes, please test ensure and test thread safety. + + private static class CacheKey { + final Map consumerConfig; + final SerializableFunction, Consumer> consumerFactoryFn; + final KafkaSourceDescriptor descriptor; + + CacheKey( + Map consumerConfig, + SerializableFunction, Consumer> consumerFactoryFn, + KafkaSourceDescriptor descriptor) { + this.consumerConfig = consumerConfig; + this.consumerFactoryFn = consumerFactoryFn; + this.descriptor = descriptor; + } + + @Override + public boolean equals(@Nullable Object other) { + if (other == null) { + return false; + } + if (!(other instanceof CacheKey)) { + return false; + } + CacheKey otherKey = (CacheKey) other; + return descriptor.equals(otherKey.descriptor) + && consumerFactoryFn.equals(otherKey.consumerFactoryFn) + && consumerConfig.equals(otherKey.consumerConfig); + } + + @Override + public int hashCode() { + return Objects.hash(descriptor, consumerFactoryFn, consumerConfig); + } + } + + private static class CacheEntry { + + final KafkaConsumerPollThread pollThread; + final long offset; + + CacheEntry(KafkaConsumerPollThread pollThread, long offset) { + this.pollThread = pollThread; + this.offset = offset; + } + } + + private final Duration cacheDuration = Duration.ofMinutes(1); + private final Cache cache; + + @SuppressWarnings("method.invocation") + KafkaConsumerPollThreadCache() { + this.cache = + CacheBuilder.newBuilder() + .expireAfterWrite(cacheDuration.toMillis(), TimeUnit.MILLISECONDS) + .removalListener( + (RemovalNotification notification) -> { + if (notification.getCause() != RemovalCause.EXPLICIT) { + LOG.info( + "Asynchronously closing reader for {} as it has been idle for over {}", + notification.getKey(), + cacheDuration); + asyncCloseConsumer( + checkNotNull(notification.getKey()), checkNotNull(notification.getValue())); + } + }) + .build(); + } + + KafkaConsumerPollThread acquireConsumer( + Map consumerConfig, + SerializableFunction, Consumer> consumerFactoryFn, + KafkaSourceDescriptor kafkaSourceDescriptor, + long offset) { + CacheKey key = new CacheKey(consumerConfig, consumerFactoryFn, kafkaSourceDescriptor); + CacheEntry entry = cache.asMap().remove(key); + cache.cleanUp(); + if (entry != null) { + if (entry.offset == offset) { + return entry.pollThread; + } else { + // Offset doesn't match, close. + LOG.info("Closing consumer as it is no longer valid {}", kafkaSourceDescriptor); + asyncCloseConsumer(key, entry); + } + } + + Map updatedConsumerConfig = + overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); + LOG.info( + "Creating Kafka consumer for process continuation for {}", + kafkaSourceDescriptor.getTopicPartition()); + Consumer consumer = consumerFactoryFn.apply(updatedConsumerConfig); + ConsumerSpEL.evaluateAssign( + consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition())); + consumer.seek(kafkaSourceDescriptor.getTopicPartition(), offset); + KafkaConsumerPollThread pollThread = new KafkaConsumerPollThread(); + pollThread.startOnExecutor(backgroundThreads, consumer); + return pollThread; + } + + /** Close the reader and log a warning if close fails. */ + private void asyncCloseConsumer(CacheKey key, CacheEntry entry) { + invalidationExecutor.execute( + () -> { + try { + entry.pollThread.close(); + LOG.info("Finished closing consumer for {}", key); + } catch (IOException e) { + LOG.warn("Failed to close consumer for {}", key, e); + } + }); + } + + void releaseConsumer( + Map consumerConfig, + SerializableFunction, Consumer> consumerFactoryFn, + KafkaSourceDescriptor kafkaSourceDescriptor, + KafkaConsumerPollThread pollThread, + long offset) { + CacheKey key = new CacheKey(consumerConfig, consumerFactoryFn, kafkaSourceDescriptor); + CacheEntry existing = cache.asMap().putIfAbsent(key, new CacheEntry(pollThread, offset)); + if (existing != null) { + LOG.warn("Unexpected collision of topic and partition"); + asyncCloseConsumer(key, existing); + } + cache.cleanUp(); + } + + private Map overrideBootstrapServersConfig( + Map currentConfig, KafkaSourceDescriptor description) { + checkState( + currentConfig.containsKey(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) + || description.getBootStrapServers() != null); + Map config = new HashMap<>(currentConfig); + if (description.getBootStrapServers() != null && !description.getBootStrapServers().isEmpty()) { + config.put( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, + String.join(",", description.getBootStrapServers())); + } + return config; + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java index 7aa9df199f91..5f1d621605f6 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java @@ -348,7 +348,7 @@ public long getSplitBacklogBytes() { private final Counter checkpointMarkCommitsSkipped = Metrics.counter(METRIC_NAMESPACE, CHECKPOINT_MARK_COMMITS_SKIPPED_METRIC); - private final KafkaConsumerPollThread pollThread; + private final transient KafkaConsumerPollThread pollThread; // Use a separate thread to read Kafka messages. Kafka Consumer does all its work including // network I/O inside poll(). Polling only inside #advance(), especially with a small timeout @@ -555,7 +555,7 @@ private void nextBatch() throws IOException { curBatch = Collections.emptyIterator(); ConsumerRecords records = pollThread.readRecords(); - if (records != null) { + if (!records.isEmpty()) { partitionStates.forEach(p -> p.recordIter = records.records(p.topicPartition).iterator()); // cycle through the partitions in order to interleave records from each. curBatch = Iterators.cycle(new ArrayList<>(partitionStates)); @@ -641,7 +641,11 @@ private long getSplitBacklogMessageCount() { @Override public void close() throws IOException { closed.set(true); - pollThread.close(); + try { + pollThread.close(); + } catch (IOException e) { + LOG.warn("Error shutting down poll thread", e); + } offsetFetcherThread.shutdown(); boolean isShutdown = false; diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 5f934ac2551e..41b4c34c99a6 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -21,14 +21,10 @@ import java.time.Duration; import java.util.HashMap; -import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors; import org.apache.beam.sdk.io.kafka.KafkaIOUtils.MovingAvg; @@ -62,12 +58,10 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.ConsumerRecords; -import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.errors.SerializationException; import org.apache.kafka.common.serialization.Deserializer; @@ -107,11 +101,10 @@ *

Checkpoint and Resume Processing

* *

There are 2 types of checkpoint here: self-checkpoint which invokes by the DoFn and - * system-checkpoint which is issued by the runner via {@link - * org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitRequest}. Every time the - * consumer gets empty response from {@link Consumer#poll(Duration)}, {@link ReadFromKafkaDoFn} will - * checkpoint the current {@link KafkaSourceDescriptor} and move to process the next element. These - * deferred elements will be resumed by the runner as soon as possible. + * system-checkpoint which is issued by the runner via {@link BeamFnApi.ProcessBundleSplitRequest}. + * Every time the consumer gets empty response from {@link Consumer#poll(Duration)}, {@link + * ReadFromKafkaDoFn} will checkpoint the current {@link KafkaSourceDescriptor} and move to process + * the next element. These deferred elements will be resumed by the runner as soon as possible. * *

Progress and Size

* @@ -223,6 +216,9 @@ private ReadFromKafkaDoFn( private final TupleTag>> recordTag; + private static final Supplier cache = + Suppliers.memoize(KafkaConsumerPollThreadCache::new); + // Valid between bundle start and bundle finish. private transient @Nullable Deserializer keyDeserializerInstance = null; private transient @Nullable Deserializer valueDeserializerInstance = null; @@ -438,27 +434,15 @@ public ProcessContinuation processElement( kafkaSourceDescriptor.getTopicPartition(), Optional.ofNullable(watermarkEstimator.currentWatermark())); } - - LOG.info( - "Creating Kafka consumer for process continuation for {}", - kafkaSourceDescriptor.getTopicPartition()); - try (Consumer consumer = consumerFactoryFn.apply(updatedConsumerConfig)) { - ConsumerSpEL.evaluateAssign( - consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition())); - long startOffset = tracker.currentRestriction().getFrom(); - - long expectedOffset = startOffset; - consumer.seek(kafkaSourceDescriptor.getTopicPartition(), startOffset); - - KafkaConsumerPollThread pollThread = new KafkaConsumerPollThread(); - ExecutorService backgroundThread = - Executors.newSingleThreadExecutor( - new ThreadFactoryBuilder() - .setDaemon(true) - .setNameFormat("KafkaConsumerPoll-thread") - .build()); - pollThread.startOnExecutor(backgroundThread, consumer); - + final long startOffset = tracker.currentRestriction().getFrom(); + long resumeOffset = startOffset; + @Nullable KafkaConsumerPollThread pollThread = null; + try { + pollThread = + cache + .get() + .acquireConsumer( + updatedConsumerConfig, consumerFactoryFn, kafkaSourceDescriptor, startOffset); while (true) { ConsumerRecords rawRecords = pollThread.readRecords(); // When there are no records available for the current TopicPartition, self-checkpoint @@ -475,6 +459,7 @@ public ProcessContinuation processElement( } for (ConsumerRecord rawRecord : rawRecords) { if (!tracker.tryClaim(rawRecord.offset())) { + // XXX need to add unconsumed records back. return ProcessContinuation.stop(); } try { @@ -493,9 +478,9 @@ public ProcessContinuation processElement( + (rawRecord.value() == null ? 0 : rawRecord.value().length); avgRecordSize .getUnchecked(kafkaSourceDescriptor.getTopicPartition()) - .update(recordSize, rawRecord.offset() - expectedOffset); + .update(recordSize, rawRecord.offset() - resumeOffset); rawSizes.update(recordSize); - expectedOffset = rawRecord.offset() + 1; + resumeOffset = rawRecord.offset() + 1; Instant outputTimestamp; // The outputTimestamp and watermark will be computed by timestampPolicy, where the // WatermarkEstimator should be a manual one. @@ -525,6 +510,17 @@ public ProcessContinuation processElement( } } } + } finally { + if (pollThread != null) { + cache + .get() + .releaseConsumer( + updatedConsumerConfig, + consumerFactoryFn, + kafkaSourceDescriptor, + pollThread, + resumeOffset); + } } } @@ -545,34 +541,6 @@ private boolean topicPartitionExists( return true; } - // see https://github.com/apache/beam/issues/25962 - private ConsumerRecords poll( - Consumer consumer, TopicPartition topicPartition) { - final Stopwatch sw = Stopwatch.createStarted(); - long previousPosition = -1; - java.time.Duration elapsed = java.time.Duration.ZERO; - java.time.Duration timeout = java.time.Duration.ofSeconds(this.consumerPollingTimeout); - while (true) { - final ConsumerRecords rawRecords = consumer.poll(timeout.minus(elapsed)); - if (!rawRecords.isEmpty()) { - // return as we have found some entries - return rawRecords; - } - if (previousPosition == (previousPosition = consumer.position(topicPartition))) { - // there was no progress on the offset/position, which indicates end of stream - return rawRecords; - } - elapsed = sw.elapsed(); - if (elapsed.toMillis() >= timeout.toMillis()) { - // timeout is over - LOG.warn( - "No messages retrieved with polling timeout {} seconds. Consider increasing the consumer polling timeout using withConsumerPollingTimeout method.", - consumerPollingTimeout); - return rawRecords; - } - } - } - private TimestampPolicyContext updateWatermarkManually( TimestampPolicy timestampPolicy, WatermarkEstimator watermarkEstimator, diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java index 1fe1147a7390..97230730862b 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java @@ -117,12 +117,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; -import org.apache.kafka.clients.consumer.Consumer; -import org.apache.kafka.clients.consumer.ConsumerConfig; -import org.apache.kafka.clients.consumer.ConsumerRecord; -import org.apache.kafka.clients.consumer.MockConsumer; -import org.apache.kafka.clients.consumer.OffsetAndTimestamp; -import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.clients.consumer.*; import org.apache.kafka.clients.producer.MockProducer; import org.apache.kafka.clients.producer.Producer; import org.apache.kafka.clients.producer.ProducerConfig; @@ -154,6 +149,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.slf4j.Logger; @@ -165,7 +161,7 @@ */ @RunWith(JUnit4.class) public class KafkaIOTest { - + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private static final Logger LOG = LoggerFactory.getLogger(KafkaIOTest.class); /* @@ -181,7 +177,7 @@ public class KafkaIOTest { @Rule public ExpectedException thrown = ExpectedException.none(); @Rule - public ExpectedLogs unboundedReaderExpectedLogs = ExpectedLogs.none(KafkaUnboundedReader.class); + public ExpectedLogs pollThreadExpectedLogs = ExpectedLogs.none(KafkaConsumerPollThread.class); @Rule public ExpectedLogs kafkaIOExpectedLogs = ExpectedLogs.none(KafkaIO.class); @@ -1423,7 +1419,7 @@ public void testUnboundedSourceMetrics() { } @Test - public void testUnboundedReaderLogsCommitFailure() throws Exception { + public void testUnboundedReaderLogsFetchFailure() throws Exception { List topics = ImmutableList.of("topic_a"); @@ -1440,9 +1436,12 @@ public void testUnboundedReaderLogsCommitFailure() throws Exception { UnboundedReader> reader = source.createReader(null, null); - reader.start(); - - unboundedReaderExpectedLogs.verifyWarn("exception while fetching latest offset for partition"); + try { + reader.start(); + } catch (Exception e) { + // Racy if we observe the exception on initial advance. + } + pollThreadExpectedLogs.verifyError("Exception while reading from Kafka"); reader.close(); } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java index 6ee3d9d96ef6..c2fec5536be9 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java @@ -83,8 +83,10 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.junit.rules.Timeout; public class ReadFromKafkaDoFnTest { + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private final TopicPartition topicPartition = new TopicPartition("topic", 0); From d69711208e8c924ecbe7082cf8052494e84900bc Mon Sep 17 00:00:00 2001 From: Bartosz Zablocki Date: Thu, 12 Sep 2024 14:01:53 +0200 Subject: [PATCH 3/6] first fix of compilation errors --- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 52 +++++++++---------- .../apache/beam/sdk/io/kafka/KafkaIOTest.java | 7 ++- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 41b4c34c99a6..c2403aea4f1c 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -447,16 +447,16 @@ public ProcessContinuation processElement( ConsumerRecords rawRecords = pollThread.readRecords(); // When there are no records available for the current TopicPartition, self-checkpoint // and move to process the next element. - if (rawRecords.isEmpty()) { - if (!topicPartitionExists( - kafkaSourceDescriptor.getTopicPartition(), consumer.listTopics())) { - return ProcessContinuation.stop(); - } - if (timestampPolicy != null) { - updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); - } - return ProcessContinuation.resume(); - } + // if (rawRecords.isEmpty()) { + // if (!topicPartitionExists( + // kafkaSourceDescriptor.getTopicPartition(), consumer.listTopics())) { + // return ProcessContinuation.stop(); + // } + // if (timestampPolicy != null) { + // updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); + // } + // return ProcessContinuation.resume(); + // } for (ConsumerRecord rawRecord : rawRecords) { if (!tracker.tryClaim(rawRecord.offset())) { // XXX need to add unconsumed records back. @@ -524,22 +524,22 @@ public ProcessContinuation processElement( } } - private boolean topicPartitionExists( - TopicPartition topicPartition, Map> topicListMap) { - // Check if the current TopicPartition still exists. - Set existingTopicPartitions = new HashSet<>(); - for (List topicPartitionList : topicListMap.values()) { - topicPartitionList.forEach( - partitionInfo -> { - existingTopicPartitions.add( - new TopicPartition(partitionInfo.topic(), partitionInfo.partition())); - }); - } - if (!existingTopicPartitions.contains(topicPartition)) { - return false; - } - return true; - } + // private boolean topicPartitionExists( + // TopicPartition topicPartition, Map> topicListMap) { + // // Check if the current TopicPartition still exists. + // Set existingTopicPartitions = new HashSet<>(); + // for (List topicPartitionList : topicListMap.values()) { + // topicPartitionList.forEach( + // partitionInfo -> { + // existingTopicPartitions.add( + // new TopicPartition(partitionInfo.topic(), partitionInfo.partition())); + // }); + // } + // if (!existingTopicPartitions.contains(topicPartition)) { + // return false; + // } + // return true; + // } private TimestampPolicyContext updateWatermarkManually( TimestampPolicy timestampPolicy, diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java index 97230730862b..b5a2d75291bd 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java @@ -117,7 +117,12 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; -import org.apache.kafka.clients.consumer.*; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetAndTimestamp; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; import org.apache.kafka.clients.producer.MockProducer; import org.apache.kafka.clients.producer.Producer; import org.apache.kafka.clients.producer.ProducerConfig; From 4871e15cb4ce6d36402eb8439730afd0a26034e1 Mon Sep 17 00:00:00 2001 From: Bartosz Zablocki Date: Thu, 12 Sep 2024 16:45:15 +0200 Subject: [PATCH 4/6] fix compilation errors related to fetching the current list of topics --- .../sdk/io/kafka/KafkaConsumerPollThread.java | 27 ++++++++- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 58 ++++++++++--------- 2 files changed, 56 insertions(+), 29 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java index 175c5672f3dc..047c0ac556ca 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java @@ -19,6 +19,9 @@ import java.io.IOException; import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; @@ -34,6 +37,7 @@ import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.errors.WakeupException; import org.checkerframework.checker.nullness.qual.Nullable; @@ -73,7 +77,8 @@ public class KafkaConsumerPollThread { private final AtomicReference<@Nullable KafkaCheckpointMark> finalizedCheckpointMark = new AtomicReference<>(); private final AtomicBoolean closed = new AtomicBoolean(false); - + private final Map> topicsList = new ConcurrentHashMap<>(); + private final AtomicBoolean topicListUpdated = new AtomicBoolean(false); private static final Logger LOG = LoggerFactory.getLogger(KafkaConsumerPollThread.class); private @Nullable Future pollFuture; @@ -130,6 +135,7 @@ private void consumerPollLoop() { try { if (records.isEmpty()) { records = consumer.poll(KAFKA_POLL_TIMEOUT); + updateTopicList(consumer); } else if (availableRecordsQueue.offer( records, RECORDS_ENQUEUE_POLL_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS)) { records = ConsumerRecords.empty(); @@ -153,6 +159,25 @@ private void consumerPollLoop() { commitCheckpointMark(); } + private void updateTopicList(Consumer consumer) { + synchronized (topicsList) { + Map> currentTopicsList = consumer.listTopics(); + topicsList.clear(); + topicsList.putAll(currentTopicsList); + topicListUpdated.set(true); + } + } + + public boolean isTopicListUpdated() { + return topicListUpdated.get(); + } + + public Map> getTopicsList() { + synchronized (topicsList) { + return topicsList; + } + } + ConsumerRecords readRecords() throws IOException { @Nullable ConsumerRecords records = null; try { diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index c2403aea4f1c..e3f6ee7ad696 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -21,8 +21,11 @@ import java.time.Duration; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.TimeUnit; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.sdk.coders.Coder; @@ -50,7 +53,6 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; @@ -62,6 +64,7 @@ import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.errors.SerializationException; import org.apache.kafka.common.serialization.Deserializer; @@ -313,7 +316,6 @@ public OffsetRange initialRestriction(@Element KafkaSourceDescriptor kafkaSource } else if (stopReadTime != null) { endOffset = ConsumerSpEL.offsetForTime(offsetConsumer, partition, stopReadTime); } - new OffsetRange(startOffset, endOffset); Lineage.getSources() .add( "kafka", @@ -447,16 +449,16 @@ public ProcessContinuation processElement( ConsumerRecords rawRecords = pollThread.readRecords(); // When there are no records available for the current TopicPartition, self-checkpoint // and move to process the next element. - // if (rawRecords.isEmpty()) { - // if (!topicPartitionExists( - // kafkaSourceDescriptor.getTopicPartition(), consumer.listTopics())) { - // return ProcessContinuation.stop(); - // } - // if (timestampPolicy != null) { - // updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); - // } - // return ProcessContinuation.resume(); - // } + if (pollThread.isTopicListUpdated() && rawRecords.isEmpty()) { + Map> topicsList = pollThread.getTopicsList(); + if (!topicPartitionExists(kafkaSourceDescriptor.getTopicPartition(), topicsList)) { + return ProcessContinuation.stop(); + } + if (timestampPolicy != null) { + updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); + } + return ProcessContinuation.resume(); + } for (ConsumerRecord rawRecord : rawRecords) { if (!tracker.tryClaim(rawRecord.offset())) { // XXX need to add unconsumed records back. @@ -524,22 +526,22 @@ public ProcessContinuation processElement( } } - // private boolean topicPartitionExists( - // TopicPartition topicPartition, Map> topicListMap) { - // // Check if the current TopicPartition still exists. - // Set existingTopicPartitions = new HashSet<>(); - // for (List topicPartitionList : topicListMap.values()) { - // topicPartitionList.forEach( - // partitionInfo -> { - // existingTopicPartitions.add( - // new TopicPartition(partitionInfo.topic(), partitionInfo.partition())); - // }); - // } - // if (!existingTopicPartitions.contains(topicPartition)) { - // return false; - // } - // return true; - // } + private boolean topicPartitionExists( + TopicPartition topicPartition, Map> topicListMap) { + // Check if the current TopicPartition still exists. + Set existingTopicPartitions = new HashSet<>(); + for (List topicPartitionList : topicListMap.values()) { + topicPartitionList.forEach( + partitionInfo -> { + existingTopicPartitions.add( + new TopicPartition(partitionInfo.topic(), partitionInfo.partition())); + }); + } + if (!existingTopicPartitions.contains(topicPartition)) { + return false; + } + return true; + } private TimestampPolicyContext updateWatermarkManually( TimestampPolicy timestampPolicy, From 9f8e26af304a7d3313c22c72c46f4dc05698063a Mon Sep 17 00:00:00 2001 From: Bartosz Zablocki Date: Thu, 12 Sep 2024 17:29:48 +0200 Subject: [PATCH 5/6] Handle tracker interruption --- .../sdk/io/kafka/KafkaConsumerPollThread.java | 9 +++++ .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 35 ++++++++++++++++++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java index 047c0ac556ca..867d5924418f 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java @@ -35,6 +35,7 @@ import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.clients.consumer.OffsetAndMetadata; import org.apache.kafka.common.PartitionInfo; @@ -208,6 +209,12 @@ ConsumerRecords readRecords() throws IOException { return records; } + public void putBackToQueue(Map>> rawRecords) { + // todo handle InterruptedException + availableRecordsQueue.offer( + new ConsumerRecords<>(rawRecords), RECORDS_ENQUEUE_POLL_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS); + } + /** * Enqueue checkpoint mark to be committed to Kafka. This does not block until it is committed. * There could be a delay of up to KAFKA_POLL_TIMEOUT (1 second). Any checkpoint mark enqueued @@ -237,4 +244,6 @@ private void commitCheckpointMark() { p -> new OffsetAndMetadata(p.getNextOffset())))); } } + + } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index e3f6ee7ad696..2040109466bb 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -22,6 +22,7 @@ import java.time.Duration; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -59,6 +60,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheLoader; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; @@ -459,9 +461,40 @@ public ProcessContinuation processElement( } return ProcessContinuation.resume(); } - for (ConsumerRecord rawRecord : rawRecords) { + Iterator> recordsIterator = rawRecords.iterator(); + while (recordsIterator.hasNext()) { + ConsumerRecord rawRecord = recordsIterator.next(); + // for (ConsumerRecord rawRecord : rawRecords) { if (!tracker.tryClaim(rawRecord.offset())) { // XXX need to add unconsumed records back. + // todo what should happen here? + // Let's say we have X outstanding records in the queue in the background thread + // and this tracker is not letting us claim this particular record. + // We have to call `ProcessContinuation.stop()`, and do something with this and further + // records. + // But what? + // Should we also close the thread that reads the data? + // My answer: + // We don't close the background thread, can it be picked up be the next tracker? + // 'acquireConsumer()' take startoffset as an argument, but doesn't use this as a key. + // How to test it? + // So maybe put it back to the queue. + ImmutableList.Builder> recordListBuilder = + new ImmutableList.Builder<>(); + recordListBuilder.add(rawRecord); + // drain rawRecords + while (recordsIterator.hasNext()) { + rawRecord = recordsIterator.next(); + recordListBuilder.add(rawRecord); + } + ImmutableList> recordsToPutBack = + recordListBuilder.build(); + ImmutableMap>> + recordsPerTopicPartition = + ImmutableMap.of(kafkaSourceDescriptor.getTopicPartition(), recordsToPutBack); + + pollThread.putBackToQueue(recordsPerTopicPartition); + return ProcessContinuation.stop(); } try { From 6c0e7709a3e762ed8ff71fe2d678eabdbb22d5e6 Mon Sep 17 00:00:00 2001 From: Bartosz Zablocki Date: Mon, 16 Sep 2024 20:35:21 +0200 Subject: [PATCH 6/6] Handle case when we can't claim a record --- .../sdk/io/kafka/KafkaConsumerPollThread.java | 40 +++++++++--- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 65 ++++++++----------- 2 files changed, 57 insertions(+), 48 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java index 867d5924418f..292c8713ea73 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaConsumerPollThread.java @@ -33,6 +33,8 @@ import java.util.stream.Collectors; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.util.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.PeekingIterator; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerRecord; @@ -83,13 +85,14 @@ public class KafkaConsumerPollThread { private static final Logger LOG = LoggerFactory.getLogger(KafkaConsumerPollThread.class); private @Nullable Future pollFuture; + private @Nullable PeekingIterator> activeBatchIterator; void startOnExecutor(ExecutorService executorService, Consumer consumer) { this.consumer = consumer; // Use a separate thread to read Kafka messages. Kafka Consumer does all its work including // network I/O inside poll(). Polling only inside #advance(), especially with a small timeout // like 100 milliseconds does not work well. This along with large receive buffer for - // consumer achieved best throughput in tests (see `defaultConsumerProperties`). + // consumer achieved the best throughput in tests (see `defaultConsumerProperties`). pollFuture = executorService.submit(this::consumerPollLoop); } @@ -101,12 +104,14 @@ void close() throws IOException { Preconditions.checkStateNotNull(pollFuture); closed.set(true); try { - // Wait for threads to shutdown. Trying this as a loop to handle a tiny race where poll thread + // Wait for threads to shut down. Trying this as a loop to handle a tiny race where poll + // thread // might block to enqueue right after availableRecordsQueue.poll() below. while (true) { if (consumer != null) { consumer.wakeup(); } + // todo will this drop unprocessed records? availableRecordsQueue.poll(); // drain unread batch, this unblocks consumer thread. try { Preconditions.checkStateNotNull(pollFuture); @@ -126,6 +131,29 @@ void close() throws IOException { } } + @Nullable + ConsumerRecord peek() throws IOException { + PeekingIterator> currentIterator = getOrInitialize(); + if (currentIterator.hasNext()) { + return currentIterator.peek(); + } + return null; + } + + void advance() throws IOException { + PeekingIterator> currentIterator = getOrInitialize(); + if (currentIterator.hasNext()) { + currentIterator.next(); + } + } + + private PeekingIterator> getOrInitialize() throws IOException { + if (activeBatchIterator == null || !activeBatchIterator.hasNext()) { + activeBatchIterator = Iterators.peekingIterator(readRecords().iterator()); + } + return activeBatchIterator; + } + private void consumerPollLoop() { // Read in a loop and enqueue the batch of records, if any, to availableRecordsQueue. Consumer consumer = Preconditions.checkStateNotNull(this.consumer); @@ -209,12 +237,6 @@ ConsumerRecords readRecords() throws IOException { return records; } - public void putBackToQueue(Map>> rawRecords) { - // todo handle InterruptedException - availableRecordsQueue.offer( - new ConsumerRecords<>(rawRecords), RECORDS_ENQUEUE_POLL_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS); - } - /** * Enqueue checkpoint mark to be committed to Kafka. This does not block until it is committed. * There could be a delay of up to KAFKA_POLL_TIMEOUT (1 second). Any checkpoint mark enqueued @@ -244,6 +266,4 @@ private void commitCheckpointMark() { p -> new OffsetAndMetadata(p.getNextOffset())))); } } - - } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 2040109466bb..b91b15af3480 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -22,7 +22,6 @@ import java.time.Duration; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -60,12 +59,10 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheLoader; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; -import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.errors.SerializationException; @@ -439,6 +436,11 @@ public ProcessContinuation processElement( Optional.ofNullable(watermarkEstimator.currentWatermark())); } final long startOffset = tracker.currentRestriction().getFrom(); + String restrictionInfo = + String.format( + "%s_%s_%s", + kafkaSourceDescriptor.getTopic(), kafkaSourceDescriptor.getPartition(), startOffset); + LOG.info("bzablockilog start restriction {}", restrictionInfo); long resumeOffset = startOffset; @Nullable KafkaConsumerPollThread pollThread = null; try { @@ -448,55 +450,40 @@ public ProcessContinuation processElement( .acquireConsumer( updatedConsumerConfig, consumerFactoryFn, kafkaSourceDescriptor, startOffset); while (true) { - ConsumerRecords rawRecords = pollThread.readRecords(); + ConsumerRecord rawRecord = pollThread.peek(); + // When there are no records available for the current TopicPartition, self-checkpoint // and move to process the next element. - if (pollThread.isTopicListUpdated() && rawRecords.isEmpty()) { + // if (pollThread.isTopicListUpdated() && !rawRecordOptional.isPresent()) { + if (pollThread.isTopicListUpdated() && rawRecord == null) { Map> topicsList = pollThread.getTopicsList(); if (!topicPartitionExists(kafkaSourceDescriptor.getTopicPartition(), topicsList)) { + LOG.info("bzablockilog stop restriction {}", restrictionInfo); return ProcessContinuation.stop(); } if (timestampPolicy != null) { updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); } + LOG.info("bzablockilog resume restriction {}", restrictionInfo); return ProcessContinuation.resume(); } - Iterator> recordsIterator = rawRecords.iterator(); - while (recordsIterator.hasNext()) { - ConsumerRecord rawRecord = recordsIterator.next(); - // for (ConsumerRecord rawRecord : rawRecords) { + + if (rawRecord != null) { + LOG.info( + "bzablockilog picked up {}_{}_{}", + rawRecord.topic(), + rawRecord.partition(), + rawRecord.offset()); if (!tracker.tryClaim(rawRecord.offset())) { + LOG.info( + "bzablockilog unsuccessful claim of {}_{}_{}", + rawRecord.topic(), + rawRecord.partition(), + rawRecord.offset()); // XXX need to add unconsumed records back. - // todo what should happen here? - // Let's say we have X outstanding records in the queue in the background thread - // and this tracker is not letting us claim this particular record. - // We have to call `ProcessContinuation.stop()`, and do something with this and further - // records. - // But what? - // Should we also close the thread that reads the data? - // My answer: - // We don't close the background thread, can it be picked up be the next tracker? - // 'acquireConsumer()' take startoffset as an argument, but doesn't use this as a key. - // How to test it? - // So maybe put it back to the queue. - ImmutableList.Builder> recordListBuilder = - new ImmutableList.Builder<>(); - recordListBuilder.add(rawRecord); - // drain rawRecords - while (recordsIterator.hasNext()) { - rawRecord = recordsIterator.next(); - recordListBuilder.add(rawRecord); - } - ImmutableList> recordsToPutBack = - recordListBuilder.build(); - ImmutableMap>> - recordsPerTopicPartition = - ImmutableMap.of(kafkaSourceDescriptor.getTopicPartition(), recordsToPutBack); - - pollThread.putBackToQueue(recordsPerTopicPartition); - return ProcessContinuation.stop(); } + try { KafkaRecord kafkaRecord = new KafkaRecord<>( @@ -538,11 +525,13 @@ public ProcessContinuation processElement( rawRecord, null, e, - "Failure deserializing Key or Value of Kakfa record reading from Kafka"); + "Failure deserializing Key or Value of Kafka record reading from Kafka"); if (timestampPolicy != null) { updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); } } + + pollThread.advance(); } } } finally {