diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java
index d3e9eddf75aa..186b0927f160 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java
@@ -23,16 +23,17 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
+import java.util.HashMap;
import java.util.PriorityQueue;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.ReentrantLock;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleResponse;
import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
-import org.apache.beam.sdk.values.TimestampedValue;
import org.joda.time.Duration;
import org.joda.time.Instant;
@@ -45,14 +46,11 @@
*
See Apache Beam Portability API: How to
* Finalize Bundles for further details.
*/
-@SuppressWarnings({
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
public class FinalizeBundleHandler {
/** A {@link BundleFinalizer.Callback} and expiry time pair. */
@AutoValue
- abstract static class CallbackRegistration {
+ public abstract static class CallbackRegistration {
public static CallbackRegistration create(
Instant expiryTime, BundleFinalizer.Callback callback) {
return new AutoValue_FinalizeBundleHandler_CallbackRegistration(expiryTime, callback);
@@ -63,77 +61,100 @@ public static CallbackRegistration create(
public abstract BundleFinalizer.Callback getCallback();
}
- private final ConcurrentMap> bundleFinalizationCallbacks;
- private final PriorityQueue> cleanUpQueue;
+  /** Callbacks registered for one bundle, with the expiry time that orders the cleanup queue. */
+  private static class FinalizationInfo {
+    final String id;
+    final Instant expiryTimestamp;
+    final Collection<CallbackRegistration> callbacks;
+
+    FinalizationInfo(
+        String id, Instant expiryTimestamp, Collection<CallbackRegistration> callbacks) {
+      this.id = id;
+      this.expiryTimestamp = expiryTimestamp;
+      this.callbacks = callbacks;
+    }
+    Instant getExpiryTimestamp() {
+      return expiryTimestamp;
+    }
+  }
+
+  // Guards the map and queue; queueMinChanged is signalled when the queue gains a new minimum.
+  private final ReentrantLock lock = new ReentrantLock();
+  private final Condition queueMinChanged = lock.newCondition();
-  @SuppressWarnings("unused")
-  private final Future cleanUpResult;
+  @GuardedBy("lock")
+  private final HashMap<String, FinalizationInfo> bundleFinalizationCallbacks;
+  @GuardedBy("lock")
+  private final PriorityQueue<FinalizationInfo> cleanUpQueue;
+ @SuppressWarnings("methodref.receiver.bound")
public FinalizeBundleHandler(ExecutorService executorService) {
- this.bundleFinalizationCallbacks = new ConcurrentHashMap<>();
+ this.bundleFinalizationCallbacks = new HashMap<>();
this.cleanUpQueue =
- new PriorityQueue<>(11, Comparator.comparing(TimestampedValue::getTimestamp));
- // Wait until we have at least one element. We are notified on each element
- // being added.
- // Wait until the current time has past the expiry time for the head of the
- // queue.
- // We are notified on each element being added.
- // Wait until we have at least one element. We are notified on each element
- // being added.
- // Wait until the current time has past the expiry time for the head of the
- // queue.
- // We are notified on each element being added.
- cleanUpResult =
- executorService.submit(
- (Callable)
- () -> {
- while (true) {
- synchronized (cleanUpQueue) {
- TimestampedValue expiryTime = cleanUpQueue.peek();
-
- // Wait until we have at least one element. We are notified on each element
- // being added.
- while (expiryTime == null) {
- cleanUpQueue.wait();
- expiryTime = cleanUpQueue.peek();
- }
-
- // Wait until the current time has past the expiry time for the head of the
- // queue.
- // We are notified on each element being added.
- Instant now = Instant.now();
- while (expiryTime.getTimestamp().isAfter(now)) {
- Duration timeDifference = new Duration(now, expiryTime.getTimestamp());
- cleanUpQueue.wait(timeDifference.getMillis());
- expiryTime = cleanUpQueue.peek();
- now = Instant.now();
- }
-
- bundleFinalizationCallbacks.remove(cleanUpQueue.poll().getValue());
- }
- }
- });
+ new PriorityQueue<>(11, Comparator.comparing(FinalizationInfo::getExpiryTimestamp));
+ executorService.execute(this::cleanupThreadBody);
+ }
+
+  /** Cleanup task body: evicts registrations once their expiry passes; exits on interrupt. */
+  private void cleanupThreadBody() {
+    lock.lock();
+    try {
+      while (true) {
+        final @Nullable FinalizationInfo minValue = cleanUpQueue.peek();
+        if (minValue == null) {
+          // Wait for an element to be added and loop to re-examine the min.
+          queueMinChanged.await();
+          continue;
+        }
+        long waitMillis = new Duration(Instant.now(), minValue.expiryTimestamp).getMillis();
+        // await(time) returns false only on timeout; true means signalled or spurious wakeup.
+        if (waitMillis < 0
+            || (!queueMinChanged.await(waitMillis, TimeUnit.MILLISECONDS)
+                && cleanUpQueue.peek() == minValue)) {
+          // Expired: either already past expiry when peeked, or the whole wait elapsed and it is
+          // still the minimum (not finalized or displaced while the lock was released).
+          checkState(minValue == cleanUpQueue.poll());
+          checkState(bundleFinalizationCallbacks.remove(minValue.id) == minValue);
+        }
+      }
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt(); // Shutting down; preserve the interrupt status.
+    } finally {
+      lock.unlock();
+    }
   }
+  /** Registers finalization callbacks for a bundle; throws if the bundle already has some. */
   public void registerCallbacks(String bundleId, Collection<CallbackRegistration> callbacks) {
     if (callbacks.isEmpty()) {
       return;
     }
-
-    Collection<CallbackRegistration> priorCallbacks =
-        bundleFinalizationCallbacks.putIfAbsent(bundleId, callbacks);
-    checkState(
-        priorCallbacks == null,
-        "Expected to not have any past callbacks for bundle %s but found %s.",
-        bundleId,
-        priorCallbacks);
-    long expiryTimeMillis = Long.MIN_VALUE;
+    Instant maxExpiryTime = Instant.EPOCH;
     for (CallbackRegistration callback : callbacks) {
-      expiryTimeMillis = Math.max(expiryTimeMillis, callback.getExpiryTime().getMillis());
+      Instant callbackExpiry = callback.getExpiryTime();
+      if (callbackExpiry.isAfter(maxExpiryTime)) {
+        maxExpiryTime = callbackExpiry;
+      }
     }
-    synchronized (cleanUpQueue) {
-      cleanUpQueue.offer(TimestampedValue.of(bundleId, new Instant(expiryTimeMillis)));
-      cleanUpQueue.notify();
+    final FinalizationInfo info = new FinalizationInfo(bundleId, maxExpiryTime, callbacks);
+    lock.lock();
+    try {
+      // putIfAbsent, not put: a duplicate must not displace the entry the cleanup queue holds.
+      FinalizationInfo existingInfo = bundleFinalizationCallbacks.putIfAbsent(bundleId, info);
+      if (existingInfo != null) {
+        throw new IllegalStateException(
+            "Expected to not have any past callbacks for bundle " + bundleId + " but had "
+                + existingInfo.callbacks);
+      }
+      cleanUpQueue.add(info);
+      @SuppressWarnings("ReferenceEquality")
+      boolean newMin = cleanUpQueue.peek() == info;
+      if (newMin) {
+        queueMinChanged.signal();
+      }
+    } finally {
+      lock.unlock();
     }
   }
@@ -141,16 +162,24 @@ public BeamFnApi.InstructionResponse.Builder finalizeBundle(BeamFnApi.Instructio
throws Exception {
String bundleId = request.getFinalizeBundle().getInstructionId();
- Collection callbacks = bundleFinalizationCallbacks.remove(bundleId);
-
- if (callbacks == null) {
+ @Nullable FinalizationInfo info;
+ lock.lock();
+ try {
+ info = bundleFinalizationCallbacks.remove(bundleId);
+ if (info != null) {
+ checkState(cleanUpQueue.remove(info));
+ }
+ } finally {
+ lock.unlock();
+ }
+ if (info == null) {
// We have already processed the callbacks on a prior bundle finalization attempt
return BeamFnApi.InstructionResponse.newBuilder()
.setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance());
}
Collection failures = new ArrayList<>();
- for (CallbackRegistration callback : callbacks) {
+ for (CallbackRegistration callback : info.callbacks) {
try {
callback.getCallback().onBundleSuccess();
} catch (Exception e) {
@@ -170,4 +199,13 @@ public BeamFnApi.InstructionResponse.Builder finalizeBundle(BeamFnApi.Instructio
return BeamFnApi.InstructionResponse.newBuilder()
.setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance());
}
+ /** Returns the current size of the cleanup queue, read while holding the lock. */
+ int cleanupQueueSize() {
+ lock.lock();
+ try {
+ return cleanUpQueue.size();
+ } finally {
+ lock.unlock();
+ }
+ }
}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java
index a760d22b78af..136222a2f017 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java
@@ -22,9 +22,14 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
+import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.beam.fn.harness.control.FinalizeBundleHandler.CallbackRegistration;
@@ -32,6 +37,7 @@
import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse;
+import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Test;
@@ -106,6 +112,49 @@ public void testFinalizationContinuesToNextCallbackEvenInFailure() throws Except
}
}
+  @Test(timeout = 60_000)
+  public void testCallbackExpiration() throws Exception {
+    ExecutorService executor = Executors.newCachedThreadPool();
+    try {
+      FinalizeBundleHandler handler = new FinalizeBundleHandler(executor);
+      BundleFinalizer.Callback callback = mock(BundleFinalizer.Callback.class);
+      handler.registerCallbacks(
+          "test",
+          Collections.singletonList(
+              CallbackRegistration.create(
+                  Instant.now().plus(Duration.standardHours(1)), callback)));
+      assertEquals(1, handler.cleanupQueueSize());
+      BundleFinalizer.Callback callback2 = mock(BundleFinalizer.Callback.class);
+      handler.registerCallbacks(
+          "test2",
+          Collections.singletonList(
+              CallbackRegistration.create(Instant.now().plus(Duration.millis(100)), callback2)));
+      BundleFinalizer.Callback callback3 = mock(BundleFinalizer.Callback.class);
+      handler.registerCallbacks(
+          "test3",
+          Collections.singletonList(
+              CallbackRegistration.create(Instant.now().plus(Duration.millis(1)), callback3)));
+      while (handler.cleanupQueueSize() > 1) {
+        Thread.sleep(10);
+      }
+      // Only "test" remains; finalizing the expired bundles succeeds without running callbacks.
+      assertEquals(1, handler.cleanupQueueSize());
+      assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test2")).build());
+      verifyNoMoreInteractions(callback2);
+      assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test3")).build());
+      verifyNoMoreInteractions(callback3);
+      // Finalizing "test" runs its callback once; repeating the request is a successful no-op.
+      assertEquals(1, handler.cleanupQueueSize());
+      assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test")).build());
+      verify(callback).onBundleSuccess();
+      assertEquals(0, handler.cleanupQueueSize());
+      assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test")).build());
+      verifyNoMoreInteractions(callback);
+    } finally {
+      executor.shutdownNow();
+    }
+  }
+
private static InstructionRequest requestFor(String bundleId) {
return InstructionRequest.newBuilder()
.setInstructionId(INSTRUCTION_ID)