Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,14 @@ public DoFn<InputT, OutputT>.StartBundleContext startBundleContext(DoFn<InputT,
/** Returns the error-context label used to attribute failures raised during {@code @StartBundle}. */
public String getErrorContext() {
return "SimpleDoFnRunner/StartBundle";
}

/**
 * Exposes the step's {@link BundleFinalizer} to {@code @StartBundle} methods; delegates to the
 * enclosing step context.
 */
@Override
public BundleFinalizer bundleFinalizer() {
return stepContext.bundleFinalizer();
}
}

/** An {@link DoFnInvoker.ArgumentProvider} for {@link DoFn.FinishBundle @FinishBundle}. */
private class DoFnFinishBundleArgumentProvider
extends DoFnInvoker.BaseArgumentProvider<InputT, OutputT> {
/** A concrete implementation of {@link DoFn.FinishBundleContext}. */
Expand Down Expand Up @@ -355,6 +360,11 @@ public DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext(
/** Returns the error-context label used to attribute failures raised during {@code @FinishBundle}. */
public String getErrorContext() {
return "SimpleDoFnRunner/FinishBundle";
}

/**
 * Exposes the step's {@link BundleFinalizer} to {@code @FinishBundle} methods; delegates to the
 * enclosing step context.
 */
@Override
public BundleFinalizer bundleFinalizer() {
return stepContext.bundleFinalizer();
}
}

/**
Expand Down Expand Up @@ -1005,7 +1015,7 @@ public <T> void outputWindowedValue(
@Override
public BundleFinalizer bundleFinalizer() {
throw new UnsupportedOperationException(
"Bundle finalization is not supported in non-portable pipelines.");
"Bundle finalization is not supported in OnTimer calls.");
}
}

Expand Down Expand Up @@ -1259,7 +1269,7 @@ public <T> void outputWindowedValue(
@Override
public BundleFinalizer bundleFinalizer() {
throw new UnsupportedOperationException(
"Bundle finalization is not supported in non-portable pipelines.");
"Bundle finalization is not supported in OnWindowExpiration calls.");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2736,15 +2736,6 @@ static void verifyDoFnSupported(
"%s does not currently support @RequiresTimeSortedInput in streaming mode.",
DataflowRunner.class.getSimpleName()));
}
boolean isUnifiedWorker = useUnifiedWorker(options);
if (DoFnSignatures.usesBundleFinalizer(fn) && !isUnifiedWorker) {
throw new UnsupportedOperationException(
String.format(
"%s does not currently support %s when not using unified worker because it uses "
+ "BundleFinalizers in its implementation. Set the `--experiments=use_runner_v2` "
+ "option to use this DoFn.",
DataflowRunner.class.getSimpleName(), fn.getClass().getSimpleName()));
}
}

static void verifyStateSupportForWindowingStrategy(WindowingStrategy strategy) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.api.services.dataflow.model.SideInputInfo;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
Expand Down Expand Up @@ -71,6 +72,7 @@
import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader;
import org.apache.beam.sdk.metrics.MetricsContainer;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.ByteStringOutputStream;
import org.apache.beam.sdk.values.PCollectionView;
Expand All @@ -86,6 +88,7 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.PeekingIterator;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;
import org.apache.commons.lang3.tuple.Pair;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
import org.joda.time.Instant;
Expand Down Expand Up @@ -444,11 +447,29 @@ public void invalidateCache() {
}
}

public Map<Long, Runnable> flushState() {
Map<Long, Runnable> callbacks = new HashMap<>();
public Map<Long, Pair<Instant, Runnable>> flushState() {
Map<Long, Pair<Instant, Runnable>> callbacks = new HashMap<>();

List<Pair<Instant, BundleFinalizer.Callback>> bundleFinalizers = new ArrayList<>();
for (StepContext stepContext : getAllStepContexts()) {
stepContext.flushState();
bundleFinalizers.addAll(stepContext.getBundleFinalizerCallbacks());
stepContext.clearBundleFinalizerCallbacks();
}
for (Pair<Instant, BundleFinalizer.Callback> bundleFinalizer : bundleFinalizers) {
long id = ThreadLocalRandom.current().nextLong();
callbacks.put(
id,
Pair.of(
bundleFinalizer.getLeft(),
() -> {
try {
bundleFinalizer.getRight().onBundleSuccess();
} catch (Exception e) {
throw new RuntimeException("Exception while running bundle finalizer", e);
}
}));
outputBuilder.addFinalizeIds(id);
}

if (activeReader != null) {
Expand All @@ -460,13 +481,15 @@ public Map<Long, Runnable> flushState() {
sourceStateBuilder.addFinalizeIds(id);
callbacks.put(
id,
() -> {
try {
checkpointMark.finalizeCheckpoint();
} catch (IOException e) {
throw new RuntimeException("Exception while finalizing checkpoint", e);
}
});
Pair.of(
Instant.now().plus(Duration.standardMinutes(5)),
() -> {
try {
checkpointMark.finalizeCheckpoint();
} catch (IOException e) {
throw new RuntimeException("Exception while finalizing checkpoint", e);
}
}));

@SuppressWarnings("unchecked")
Coder<UnboundedSource.CheckpointMark> checkpointCoder =
Expand Down Expand Up @@ -697,6 +720,11 @@ public <W extends BoundedWindow> void setStateCleanupTimer(
/**
 * Returns this same context as the user-visible step context (no separate user-namespaced
 * wrapper is created).
 */
public DataflowStepContext namespacedToUser() {
return this;
}

/** Delegates bundle finalization to the wrapped step context. */
@Override
public BundleFinalizer bundleFinalizer() {
return wrapped.bundleFinalizer();
}
}

/** A {@link SideInputReader} that fetches side inputs from the streaming worker's cache. */
Expand Down Expand Up @@ -769,6 +797,7 @@ class StepContext extends DataflowExecutionContext.DataflowStepContext
// A list of timer keys that were modified by user processing earlier in this bundle. This
// serves a tombstone, so that we know not to fire any bundle timers that were modified.
private Table<String, StateNamespace, TimerData> modifiedUserTimerKeys = null;
private final WindmillBundleFinalizer bundleFinalizer = new WindmillBundleFinalizer();

public StepContext(DataflowOperationContext operationContext) {
super(operationContext.nameContext());
Expand Down Expand Up @@ -1044,9 +1073,41 @@ public TimerInternals timerInternals() {
return checkNotNull(systemTimerInternals);
}

/** Returns the per-step {@link BundleFinalizer} that accumulates bundle-commit callbacks. */
@Override
public BundleFinalizer bundleFinalizer() {
return bundleFinalizer;
}

/**
 * Returns the {@link TimerInternals} backing user-defined timers.
 *
 * <p>Access is guarded by {@code ensureStateful} — presumably only valid for stateful steps
 * (TODO confirm) — and {@code checkNotNull} assumes the internals were initialized earlier.
 */
public TimerInternals userTimerInternals() {
ensureStateful("Tried to access user timers");
return checkNotNull(userTimerInternals);
}

/**
 * Returns the (expiry, callback) pairs registered with this step's bundle finalizer. The live
 * backing list is returned, not a copy.
 */
public List<Pair<Instant, BundleFinalizer.Callback>> getBundleFinalizerCallbacks() {
return bundleFinalizer.getCallbacks();
}

/** Discards all callbacks accumulated by this step's bundle finalizer. */
public void clearBundleFinalizerCallbacks() {
bundleFinalizer.clearCallbacks();
}
}

/**
 * A {@link BundleFinalizer} that records each registered callback with its expiry instant so
 * the runner can invoke them once the bundle commits.
 *
 * <p>NOTE(review): the backing list is unsynchronized, so registration and draining are assumed
 * to happen on a single thread — confirm against callers.
 */
private static class WindmillBundleFinalizer implements BundleFinalizer {
  // Never reassigned, so declared final (it was previously mutable for no reason).
  private final List<Pair<Instant, Callback>> callbacks = new ArrayList<>();

  private WindmillBundleFinalizer() {}

  /** Returns the live list of registered (expiry, callback) pairs; no copy is made. */
  private List<Pair<Instant, Callback>> getCallbacks() {
    return callbacks;
  }

  /** Removes every registered callback. */
  private void clearCallbacks() {
    callbacks.clear();
  }

  @Override
  public void afterBundleCommit(Instant callbackExpiry, Callback callback) {
    callbacks.add(Pair.of(callbackExpiry, callback));
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ final class GetWorkResponseChunkAssembler {
private final WorkItem.Builder workItemBuilder; // Reused to reduce GC overhead.
private ByteString data;
private long bufferedSize;
private final List<Long> appliedFinalizeIds;

/**
 * Creates an assembler with empty buffers: no pending bytes, no computation metadata, and no
 * applied finalize ids yet.
 */
GetWorkResponseChunkAssembler() {
workTimingInfosTracker = new GetWorkTimingInfosTracker(System::currentTimeMillis);
data = ByteString.EMPTY;
bufferedSize = 0;
metadata = null;
workItemBuilder = WorkItem.newBuilder();
appliedFinalizeIds = new ArrayList<>();
}

/**
Expand All @@ -72,6 +74,7 @@ List<AssembledWorkItem> append(Windmill.StreamingGetWorkResponseChunk chunk) {
metadata = ComputationMetadata.fromProto(chunk.getComputationMetadata());
}
workTimingInfosTracker.addTimingInfo(chunk.getPerWorkItemTimingInfosList());
appliedFinalizeIds.addAll(chunk.getAppliedFinalizeIdsList());

List<AssembledWorkItem> response = new ArrayList<>();
for (int i = 0; i < chunk.getSerializedWorkItemList().size(); i++) {
Expand All @@ -90,7 +93,7 @@ List<AssembledWorkItem> append(Windmill.StreamingGetWorkResponseChunk chunk) {
}

/**
* Attempt to flush the {@link #data} bytes into a {@link WorkItem} w/ its metadata. Resets the
* data byte string and tracking metadata afterwards, whether the {@link WorkItem} deserialization
* was successful or not.
*/
Expand All @@ -102,14 +105,16 @@ private Optional<AssembledWorkItem> flushToWorkItem() {
workItemBuilder.build(),
Preconditions.checkNotNull(metadata),
workTimingInfosTracker.getLatencyAttributions(),
bufferedSize));
bufferedSize,
appliedFinalizeIds));
} catch (IOException e) {
LOG.error("Failed to parse work item from stream: ", e);
} finally {
workItemBuilder.clear();
workTimingInfosTracker.reset();
data = ByteString.EMPTY;
bufferedSize = 0;
appliedFinalizeIds.clear();
}

return Optional.empty();
Expand Down Expand Up @@ -144,7 +149,9 @@ private static AssembledWorkItem create(
WorkItem workItem,
ComputationMetadata computationMetadata,
ImmutableList<LatencyAttribution> latencyAttributions,
long size) {
long size,
List<Long> appliedFinalizeIds) {
workItem = workItem.toBuilder().addAllAppliedFinalizeIds(appliedFinalizeIds).build();
return new AutoValue_GetWorkResponseChunkAssembler_AssembledWorkItem(
workItem, computationMetadata, latencyAttributions, size);
}
Expand Down
Loading
Loading