Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,12 @@
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.windowing.AfterFirst;
import org.apache.beam.sdk.transforms.windowing.AfterPane;
import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Repeatedly;
import org.apache.beam.sdk.transforms.windowing.Trigger;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.Preconditions;
import org.apache.beam.sdk.values.KV;
Expand Down Expand Up @@ -84,6 +87,7 @@
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.DateTimeUtils;
import org.joda.time.DateTimeZone;
import org.joda.time.Duration;
import org.joda.time.format.DateTimeFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -159,7 +163,8 @@ static void ensureEOSSupport() {
@Override
public PCollection<Void> expand(PCollection<ProducerRecord<K, V>> input) {
String topic = Preconditions.checkStateNotNull(spec.getTopic());

int numElements = spec.getEosTriggerNumElements();
Duration timeout = spec.getEosTriggerTimeout();
int numShards = spec.getNumShards();
if (numShards <= 0) {
try (Consumer<?, ?> consumer = openConsumer(spec)) {
Expand All @@ -172,17 +177,34 @@ public PCollection<Void> expand(PCollection<ProducerRecord<K, V>> input) {
}
}
checkState(numShards > 0, "Could not set number of shards");

Trigger.OnceTrigger trigger = null;
if (timeout != null) {
trigger =
AfterFirst.of(
AfterPane.elementCountAtLeast(numElements),
AfterProcessingTime.pastFirstElementInPane().plusDelayOf(timeout));
} else {
// fallback to default
trigger = AfterPane.elementCountAtLeast(numElements);
}
return input
.apply(
Window.<ProducerRecord<K, V>>into(new GlobalWindows()) // Everything into global window.
.triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1)))
.triggering(Repeatedly.forever(trigger))
.discardingFiredPanes())
.apply(
String.format("Shuffle across %d shards", numShards),
ParDo.of(new Reshard<>(numShards)))
.apply("Persist sharding", GroupByKey.create())
.apply("Assign sequential ids", ParDo.of(new Sequencer<>()))
// Reapply the windowing configuration as the continuation trigger doesn't maintain the
// desired batching.
.apply(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed or is the windowing inherited from above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It didn't work without repeating triggering definition

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we discussed, it looks like this is because the continuation trigger changes it to aftercount(1)

can you add a comment here
// Reapply the windowing configuration as the continuation trigger doesn't maintain the desired batching.

"Windowing",
Window.<KV<Integer, KV<Long, TimestampedValue<ProducerRecord<K, V>>>>>into(
new GlobalWindows()) // Everything into global window.
.triggering(Repeatedly.forever(trigger))
.discardingFiredPanes())
.apply("Persist ids", GroupByKey.create())
.apply(
String.format("Write to Kafka topic '%s'", spec.getTopic()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,8 @@ public static <K, V> WriteRecords<K, V> writeRecords() {
return new AutoValue_KafkaIO_WriteRecords.Builder<K, V>()
.setProducerConfig(WriteRecords.DEFAULT_PRODUCER_PROPERTIES)
.setEOS(false)
.setEosTriggerNumElements(1) // keep default numElements
.setEosTriggerTimeout(null) // keep default trigger (timeout)
.setNumShards(0)
.setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN)
.setBadRecordRouter(BadRecordRouter.THROWING_ROUTER)
Expand Down Expand Up @@ -3185,6 +3187,10 @@ public abstract static class WriteRecords<K, V>
@Pure
public abstract boolean isEOS();

/**
 * Returns the element-count threshold at which the exactly-once sink's batching trigger fires.
 */
@Pure
public abstract int getEosTriggerNumElements();

/**
 * Returns the processing-time timeout for the exactly-once sink's batching trigger, or {@code
 * null} when only the element-count condition should fire the trigger.
 */
@Pure
public abstract @Nullable Duration getEosTriggerTimeout();

@Pure
public abstract @Nullable String getSinkGroupId();

Expand Down Expand Up @@ -3221,6 +3227,10 @@ abstract Builder<K, V> setPublishTimestampFunction(

abstract Builder<K, V> setEOS(boolean eosEnabled);

// Element-count threshold for the exactly-once sink's batching trigger.
abstract Builder<K, V> setEosTriggerNumElements(int numElements);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would name these more narrowly. Theres lots of potential numElements and timeouts in the scope of KafkaIO

// Processing-time timeout for the exactly-once sink's batching trigger; null means the trigger
// fires on the element count alone.
abstract Builder<K, V> setEosTriggerTimeout(@Nullable Duration timeout);

abstract Builder<K, V> setSinkGroupId(String sinkGroupId);

abstract Builder<K, V> setNumShards(int numShards);
Expand Down Expand Up @@ -3368,6 +3378,15 @@ public WriteRecords<K, V> withEOS(int numShards, String sinkGroupId) {
return toBuilder().setEOS(true).setNumShards(numShards).setSinkGroupId(sinkGroupId).build();
}

/**
 * Configures the batching trigger used by the exactly-once sink: a pane fires once {@code
 * numElements} elements have been buffered or {@code timeout} of processing time has elapsed
 * since the first element in the pane, whichever comes first.
 *
 * @param numElements element-count threshold; must be at least 1
 * @param timeout processing-time bound for firing a pane; must not be null
 * @throws IllegalArgumentException if {@code numElements} is less than 1 or {@code timeout} is
 *     null
 */
public WriteRecords<K, V> withEOSTriggerConfig(int numElements, Duration timeout) {
  if (numElements < 1) {
    throw new IllegalArgumentException("numElements should be >= 1");
  }
  if (timeout == null) {
    throw new IllegalArgumentException("timeout is required for exactly-once sink");
  }
  return toBuilder().setEosTriggerNumElements(numElements).setEosTriggerTimeout(timeout).build();
}

/**
* When exactly-once semantics are enabled (see {@link #withEOS(int, String)}), the sink needs
* to fetch previously stored state with Kafka topic. Fetching the metadata requires a consumer.
Expand Down Expand Up @@ -3653,6 +3672,19 @@ public Write<K, V> withEOS(int numShards, String sinkGroupId) {
return withWriteRecordsTransform(getWriteRecordsTransform().withEOS(numShards, sinkGroupId));
}

/**
 * Configures the trigger used to batch records for the exactly-once sink.
 *
 * <p>Only applicable when exactly-once semantics are enabled (see {@code withEOS}).
 *
 * <p>A pane of buffered records is written whenever {@code numElements} elements have been
 * collected or {@code timeout} has elapsed since the first element of the pane, whichever
 * condition is met first; the trigger then repeats for subsequent panes.
 */
public Write<K, V> withEOSTriggerConfig(int numElements, Duration timeout) {
return withWriteRecordsTransform(
getWriteRecordsTransform().withEOSTriggerConfig(numElements, timeout));
}

/**
* Wrapper method over {@link WriteRecords#withConsumerFactoryFn(SerializableFunction)}, used to
* keep the compatibility with old API based on KV type of element.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ static class KafkaIOWriteTranslator implements TransformPayloadTranslator<Write<
.addNullableByteArrayField("producer_factory_fn")
.addNullableByteArrayField("publish_timestamp_fn")
.addBooleanField("eos")
.addInt32Field("eos_trigger_num_elements")
.addNullableInt64Field("eos_trigger_timeout_ms")
.addInt32Field("num_shards")
.addNullableStringField("sink_group_id")
.addNullableByteArrayField("consumer_factory_fn")
Expand Down Expand Up @@ -547,6 +549,11 @@ public Row toConfigRow(Write<?, ?> transform) {
}

fieldValues.put("eos", writeRecordsTransform.isEOS());
org.joda.time.Duration eosTriggerTimeout = writeRecordsTransform.getEosTriggerTimeout();
if (eosTriggerTimeout != null) {
fieldValues.put("eos_trigger_timeout_ms", eosTriggerTimeout.getMillis());
}
fieldValues.put("eos_trigger_num_elements", writeRecordsTransform.getEosTriggerNumElements());
fieldValues.put("num_shards", writeRecordsTransform.getNumShards());

if (writeRecordsTransform.getSinkGroupId() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ public class KafkaIOTranslationTest {
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getValueSerializer", "value_serializer");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getPublishTimestampFunction", "publish_timestamp_fn");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("isEOS", "eos");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getEosTriggerTimeout", "eos_trigger_timeout_ms");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getEosTriggerNumElements", "eos_trigger_num_elements");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getSinkGroupId", "sink_group_id");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getNumShards", "num_shards");
WRITE_TRANSFORM_SCHEMA_MAPPING.put("getConsumerFactoryFn", "consumer_factory_fn");
Expand Down
Loading