diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaExactlyOnceSink.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaExactlyOnceSink.java index 93e7ff2b663d..f34547bd2611 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaExactlyOnceSink.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaExactlyOnceSink.java @@ -51,9 +51,12 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.windowing.AfterFirst; import org.apache.beam.sdk.transforms.windowing.AfterPane; +import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime; import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.Repeatedly; +import org.apache.beam.sdk.transforms.windowing.Trigger; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.KV; @@ -84,6 +87,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTimeUtils; import org.joda.time.DateTimeZone; +import org.joda.time.Duration; import org.joda.time.format.DateTimeFormat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -159,7 +163,8 @@ static void ensureEOSSupport() { @Override public PCollection expand(PCollection> input) { String topic = Preconditions.checkStateNotNull(spec.getTopic()); - + int numElements = spec.getEosTriggerNumElements(); + Duration timeout = spec.getEosTriggerTimeout(); int numShards = spec.getNumShards(); if (numShards <= 0) { try (Consumer consumer = openConsumer(spec)) { @@ -172,17 +177,34 @@ public PCollection expand(PCollection> input) { } } checkState(numShards > 0, "Could not set number of shards"); - + Trigger.OnceTrigger trigger = null; + if (timeout != null) { + trigger = + 
AfterFirst.of( + AfterPane.elementCountAtLeast(numElements), + AfterProcessingTime.pastFirstElementInPane().plusDelayOf(timeout)); + } else { + // fallback to default + trigger = AfterPane.elementCountAtLeast(numElements); + } return input .apply( Window.>into(new GlobalWindows()) // Everything into global window. - .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1))) + .triggering(Repeatedly.forever(trigger)) .discardingFiredPanes()) .apply( String.format("Shuffle across %d shards", numShards), ParDo.of(new Reshard<>(numShards))) .apply("Persist sharding", GroupByKey.create()) .apply("Assign sequential ids", ParDo.of(new Sequencer<>())) + // Reapply the windowing configuration as the continuation trigger doesn't maintain the + // desired batching. + .apply( + "Windowing", + Window.>>>>into( + new GlobalWindows()) // Everything into global window. + .triggering(Repeatedly.forever(trigger)) + .discardingFiredPanes()) .apply("Persist ids", GroupByKey.create()) .apply( String.format("Write to Kafka topic '%s'", spec.getTopic()), diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java index ad5535517646..02d14b745fe8 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java @@ -648,6 +648,8 @@ public static WriteRecords writeRecords() { return new AutoValue_KafkaIO_WriteRecords.Builder() .setProducerConfig(WriteRecords.DEFAULT_PRODUCER_PROPERTIES) .setEOS(false) + .setEosTriggerNumElements(1) // keep default numElements + .setEosTriggerTimeout(null) // keep default trigger (timeout) .setNumShards(0) .setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN) .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER) @@ -3185,6 +3187,10 @@ public abstract static class WriteRecords @Pure public abstract boolean isEOS(); + public abstract int 
getEosTriggerNumElements(); + + public abstract @Nullable Duration getEosTriggerTimeout(); + @Pure public abstract @Nullable String getSinkGroupId(); @@ -3221,6 +3227,10 @@ abstract Builder setPublishTimestampFunction( abstract Builder setEOS(boolean eosEnabled); + abstract Builder setEosTriggerNumElements(int numElements); + + abstract Builder setEosTriggerTimeout(@Nullable Duration timeout); + abstract Builder setSinkGroupId(String sinkGroupId); abstract Builder setNumShards(int numShards); @@ -3368,6 +3378,15 @@ public WriteRecords withEOS(int numShards, String sinkGroupId) { return toBuilder().setEOS(true).setNumShards(numShards).setSinkGroupId(sinkGroupId).build(); } + public WriteRecords withEOSTriggerConfig(int numElements, Duration timeout) { + checkArgument(numElements >= 1, "numElements should be >= 1"); + checkArgument(timeout != null, "timeout is required for exactly-once sink"); + return toBuilder() + .setEosTriggerNumElements(numElements) + .setEosTriggerTimeout(timeout) + .build(); + } + /** * When exactly-once semantics are enabled (see {@link #withEOS(int, String)}), the sink needs * to fetch previously stored state with Kafka topic. Fetching the metadata requires a consumer. @@ -3653,6 +3672,19 @@ public Write withEOS(int numShards, String sinkGroupId) { return withWriteRecordsTransform(getWriteRecordsTransform().withEOS(numShards, sinkGroupId)); } + /** + * Set the frequency and numElements threshold at which messages are triggered. + * + *

<p>This is only applicable when exactly-once semantics are enabled (see {@link #withEOS(int, String)}). + * + *

<p>A pane of records is emitted for writing as soon as either {@code numElements} records have + * accumulated or {@code timeout} has elapsed since the first record in the pane, whichever + * happens first; the trigger then repeats for subsequent records. + */ + public Write withEOSTriggerConfig(int numElements, Duration timeout) { + return withWriteRecordsTransform( + getWriteRecordsTransform().withEOSTriggerConfig(numElements, timeout)); + } + /** * Wrapper method over {@link WriteRecords#withConsumerFactoryFn(SerializableFunction)}, used to * keep the compatibility with old API based on KV type of element. diff --git a/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java b/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java index 51d9b028bab0..a015d6d48f32 100644 --- a/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java +++ b/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java @@ -479,6 +479,8 @@ static class KafkaIOWriteTranslator implements TransformPayloadTranslator transform) { } fieldValues.put("eos", writeRecordsTransform.isEOS()); + org.joda.time.Duration eosTriggerTimeout = writeRecordsTransform.getEosTriggerTimeout(); + if (eosTriggerTimeout != null) { + fieldValues.put("eos_trigger_timeout_ms", eosTriggerTimeout.getMillis()); + } + fieldValues.put("eos_trigger_num_elements", writeRecordsTransform.getEosTriggerNumElements()); fieldValues.put("num_shards", writeRecordsTransform.getNumShards()); if (writeRecordsTransform.getSinkGroupId() != null) { diff --git a/sdks/java/io/kafka/upgrade/src/test/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslationTest.java b/sdks/java/io/kafka/upgrade/src/test/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslationTest.java index 845e89b3b659..205884b2cb60 100644 --- a/sdks/java/io/kafka/upgrade/src/test/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslationTest.java +++ 
b/sdks/java/io/kafka/upgrade/src/test/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslationTest.java @@ -94,6 +94,8 @@ public class KafkaIOTranslationTest { WRITE_TRANSFORM_SCHEMA_MAPPING.put("getValueSerializer", "value_serializer"); WRITE_TRANSFORM_SCHEMA_MAPPING.put("getPublishTimestampFunction", "publish_timestamp_fn"); WRITE_TRANSFORM_SCHEMA_MAPPING.put("isEOS", "eos"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getEosTriggerTimeout", "eos_trigger_timeout_ms"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getEosTriggerNumElements", "eos_trigger_num_elements"); WRITE_TRANSFORM_SCHEMA_MAPPING.put("getSinkGroupId", "sink_group_id"); WRITE_TRANSFORM_SCHEMA_MAPPING.put("getNumShards", "num_shards"); WRITE_TRANSFORM_SCHEMA_MAPPING.put("getConsumerFactoryFn", "consumer_factory_fn");