diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java index d3e9eddf75aa..186b0927f160 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java @@ -23,16 +23,17 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Comparator; +import java.util.HashMap; import java.util.PriorityQueue; -import java.util.concurrent.Callable; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleResponse; import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; -import org.apache.beam.sdk.values.TimestampedValue; import org.joda.time.Duration; import org.joda.time.Instant; @@ -45,14 +46,11 @@ *

See Apache Beam Portability API: How to * Finalize Bundles for further details. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) public class FinalizeBundleHandler { /** A {@link BundleFinalizer.Callback} and expiry time pair. */ @AutoValue - abstract static class CallbackRegistration { + public abstract static class CallbackRegistration { public static CallbackRegistration create( Instant expiryTime, BundleFinalizer.Callback callback) { return new AutoValue_FinalizeBundleHandler_CallbackRegistration(expiryTime, callback); @@ -63,77 +61,100 @@ public static CallbackRegistration create( public abstract BundleFinalizer.Callback getCallback(); } - private final ConcurrentMap> bundleFinalizationCallbacks; - private final PriorityQueue> cleanUpQueue; + private static class FinalizationInfo { + FinalizationInfo( + String id, Instant expiryTimestamp, Collection callbacks) { + this.id = id; + this.expiryTimestamp = expiryTimestamp; + this.callbacks = callbacks; + } + + final String id; + final Instant expiryTimestamp; + final Collection callbacks; + + Instant getExpiryTimestamp() { + return expiryTimestamp; + } + } + + private final ReentrantLock lock = new ReentrantLock(); + private final Condition queueMinChanged = lock.newCondition(); - @SuppressWarnings("unused") - private final Future cleanUpResult; + @GuardedBy("lock") + private final HashMap bundleFinalizationCallbacks; + @GuardedBy("lock") + private final PriorityQueue cleanUpQueue; + + @SuppressWarnings("methodref.receiver.bound") public FinalizeBundleHandler(ExecutorService executorService) { - this.bundleFinalizationCallbacks = new ConcurrentHashMap<>(); + this.bundleFinalizationCallbacks = new HashMap<>(); this.cleanUpQueue = - new PriorityQueue<>(11, Comparator.comparing(TimestampedValue::getTimestamp)); - // Wait until we have at least one element. We are notified on each element - // being added. - // Wait until the current time has past the expiry time for the head of the - // queue. - // We are notified on each element being added. - // Wait until we have at least one element. We are notified on each element - // being added. - // Wait until the current time has past the expiry time for the head of the - // queue. - // We are notified on each element being added. - cleanUpResult = - executorService.submit( - (Callable) - () -> { - while (true) { - synchronized (cleanUpQueue) { - TimestampedValue expiryTime = cleanUpQueue.peek(); - - // Wait until we have at least one element. We are notified on each element - // being added. - while (expiryTime == null) { - cleanUpQueue.wait(); - expiryTime = cleanUpQueue.peek(); - } - - // Wait until the current time has past the expiry time for the head of the - // queue. - // We are notified on each element being added. - Instant now = Instant.now(); - while (expiryTime.getTimestamp().isAfter(now)) { - Duration timeDifference = new Duration(now, expiryTime.getTimestamp()); - cleanUpQueue.wait(timeDifference.getMillis()); - expiryTime = cleanUpQueue.peek(); - now = Instant.now(); - } - - bundleFinalizationCallbacks.remove(cleanUpQueue.poll().getValue()); - } - } - }); + new PriorityQueue<>(11, Comparator.comparing(FinalizationInfo::getExpiryTimestamp)); + executorService.execute(this::cleanupThreadBody); + } + + private void cleanupThreadBody() { + lock.lock(); + try { + while (true) { + final @Nullable FinalizationInfo minValue = cleanUpQueue.peek(); + if (minValue == null) { + // Wait for an element to be added and loop to re-examine the min. + queueMinChanged.await(); + continue; + } + + Instant now = Instant.now(); + Duration timeDifference = new Duration(now, minValue.expiryTimestamp); + if (timeDifference.getMillis() < 0 + || (queueMinChanged.await(timeDifference.getMillis(), TimeUnit.MILLISECONDS) + && cleanUpQueue.peek() == minValue)) { + // The minimum element has an expiry time before now, either because it had elapsed when + // we pulled it or because we awaited it and it is still the minimum. + checkState(minValue == cleanUpQueue.poll()); + checkState(bundleFinalizationCallbacks.remove(minValue.id) == minValue); + } + } + } catch (InterruptedException e) { + // We're being shutdown. + } finally { + lock.unlock(); + } } public void registerCallbacks(String bundleId, Collection callbacks) { if (callbacks.isEmpty()) { return; } - - Collection priorCallbacks = - bundleFinalizationCallbacks.putIfAbsent(bundleId, callbacks); - checkState( - priorCallbacks == null, - "Expected to not have any past callbacks for bundle %s but found %s.", - bundleId, - priorCallbacks); - long expiryTimeMillis = Long.MIN_VALUE; + Instant maxExpiryTime = Instant.EPOCH; for (CallbackRegistration callback : callbacks) { - expiryTimeMillis = Math.max(expiryTimeMillis, callback.getExpiryTime().getMillis()); + Instant callbackExpiry = callback.getExpiryTime(); + if (callbackExpiry.isAfter(maxExpiryTime)) { + maxExpiryTime = callbackExpiry; + } } - synchronized (cleanUpQueue) { - cleanUpQueue.offer(TimestampedValue.of(bundleId, new Instant(expiryTimeMillis))); - cleanUpQueue.notify(); + final FinalizationInfo info = new FinalizationInfo(bundleId, maxExpiryTime, callbacks); + + lock.lock(); + try { + FinalizationInfo existingInfo = bundleFinalizationCallbacks.put(bundleId, info); + if (existingInfo != null) { + throw new IllegalStateException( + "Expected to not have any past callbacks for bundle " + + bundleId + + " but had " + + existingInfo.callbacks); + } + cleanUpQueue.add(info); + @SuppressWarnings("ReferenceEquality") + boolean newMin = cleanUpQueue.peek() == info; + if (newMin) { + queueMinChanged.signal(); + } + } finally { + lock.unlock(); } } @@ -141,16 +162,24 @@ public BeamFnApi.InstructionResponse.Builder finalizeBundle(BeamFnApi.Instructio throws Exception { String bundleId = request.getFinalizeBundle().getInstructionId(); - Collection callbacks = bundleFinalizationCallbacks.remove(bundleId); - - if (callbacks == null) { + @Nullable FinalizationInfo info; + lock.lock(); + try { + info = bundleFinalizationCallbacks.remove(bundleId); + if (info != null) { + checkState(cleanUpQueue.remove(info)); + } + } finally { + lock.unlock(); + } + if (info == null) { // We have already processed the callbacks on a prior bundle finalization attempt return BeamFnApi.InstructionResponse.newBuilder() .setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance()); } Collection failures = new ArrayList<>(); - for (CallbackRegistration callback : callbacks) { + for (CallbackRegistration callback : info.callbacks) { try { callback.getCallback().onBundleSuccess(); } catch (Exception e) { @@ -170,4 +199,13 @@ public BeamFnApi.InstructionResponse.Builder finalizeBundle(BeamFnApi.Instructio return BeamFnApi.InstructionResponse.newBuilder() .setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance()); } + + int cleanupQueueSize() { + lock.lock(); + try { + return cleanUpQueue.size(); + } finally { + lock.unlock(); + } + } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java index a760d22b78af..136222a2f017 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/FinalizeBundleHandlerTest.java @@ -22,9 +22,14 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.fn.harness.control.FinalizeBundleHandler.CallbackRegistration; @@ -32,6 +37,7 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Test; @@ -106,6 +112,49 @@ public void testFinalizationContinuesToNextCallbackEvenInFailure() throws Except } } + @Test + public void testCallbackExpiration() throws Exception { + ExecutorService executor = Executors.newCachedThreadPool(); + FinalizeBundleHandler handler = new FinalizeBundleHandler(executor); + BundleFinalizer.Callback callback = mock(BundleFinalizer.Callback.class); + handler.registerCallbacks( + "test", + Collections.singletonList( + CallbackRegistration.create(Instant.now().plus(Duration.standardHours(1)), callback))); + assertEquals(1, handler.cleanupQueueSize()); + + BundleFinalizer.Callback callback2 = mock(BundleFinalizer.Callback.class); + handler.registerCallbacks( + "test2", + Collections.singletonList( + CallbackRegistration.create(Instant.now().plus(Duration.millis(100)), callback2))); + BundleFinalizer.Callback callback3 = mock(BundleFinalizer.Callback.class); + handler.registerCallbacks( + "test3", + Collections.singletonList( + CallbackRegistration.create(Instant.now().plus(Duration.millis(1)), callback3))); + while (handler.cleanupQueueSize() > 1) { + Thread.sleep(500); + } + // Just the "test" bundle should remain as "test2" and "test3" should have timed out. + assertEquals(1, handler.cleanupQueueSize()); + // Completing test2 and test3 should have successful response but not invoke the callbacks + // as they were cleaned up. + assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test2")).build()); + verifyNoMoreInteractions(callback2); + assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test3")).build()); + verifyNoMoreInteractions(callback3); + // Completing "test" bundle should call the callback and remove it from cleanup queue. + assertEquals(1, handler.cleanupQueueSize()); + assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test")).build()); + verify(callback).onBundleSuccess(); + assertEquals(0, handler.cleanupQueueSize()); + // Verify that completing again is a no-op as it was cleaned up. + assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test")).build()); + verifyNoMoreInteractions(callback); + executor.shutdownNow(); + } + private static InstructionRequest requestFor(String bundleId) { return InstructionRequest.newBuilder() .setInstructionId(INSTRUCTION_ID)