Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.PriorityQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleResponse;
import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
import org.apache.beam.sdk.values.TimestampedValue;
import org.joda.time.Duration;
import org.joda.time.Instant;

Expand All @@ -45,14 +46,11 @@
* <p>See <a href="https://s.apache.org/beam-finalizing-bundles">Apache Beam Portability API: How to
* Finalize Bundles</a> for further details.
*/
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public class FinalizeBundleHandler {

/** A {@link BundleFinalizer.Callback} and expiry time pair. */
@AutoValue
abstract static class CallbackRegistration {
public abstract static class CallbackRegistration {
public static CallbackRegistration create(
Instant expiryTime, BundleFinalizer.Callback callback) {
return new AutoValue_FinalizeBundleHandler_CallbackRegistration(expiryTime, callback);
Expand All @@ -63,94 +61,125 @@ public static CallbackRegistration create(
public abstract BundleFinalizer.Callback getCallback();
}

private final ConcurrentMap<String, Collection<CallbackRegistration>> bundleFinalizationCallbacks;
private final PriorityQueue<TimestampedValue<String>> cleanUpQueue;
/**
 * Bookkeeping record for a single bundle that has registered finalization callbacks.
 *
 * <p>Instances are compared by identity by the surrounding handler, so this class intentionally
 * does not override {@code equals}/{@code hashCode}.
 */
private static class FinalizationInfo {
  /** Identifier of the bundle the callbacks belong to. */
  final String id;

  /** Timestamp used to order entries in the cleanup queue. */
  final Instant expiryTimestamp;

  /** The callbacks registered for the bundle. The collection is stored as given, not copied. */
  final Collection<CallbackRegistration> callbacks;

  FinalizationInfo(
      String bundleId, Instant expiry, Collection<CallbackRegistration> registrations) {
    id = bundleId;
    expiryTimestamp = expiry;
    callbacks = registrations;
  }

  /** Accessor form of {@link #expiryTimestamp}, usable as a method reference. */
  Instant getExpiryTimestamp() {
    return this.expiryTimestamp;
  }
}

private final ReentrantLock lock = new ReentrantLock();
private final Condition queueMinChanged = lock.newCondition();

@SuppressWarnings("unused")
private final Future<Void> cleanUpResult;
@GuardedBy("lock")
private final HashMap<String, FinalizationInfo> bundleFinalizationCallbacks;

@GuardedBy("lock")
private final PriorityQueue<FinalizationInfo> cleanUpQueue;

@SuppressWarnings("methodref.receiver.bound")
public FinalizeBundleHandler(ExecutorService executorService) {
this.bundleFinalizationCallbacks = new ConcurrentHashMap<>();
this.bundleFinalizationCallbacks = new HashMap<>();
this.cleanUpQueue =
new PriorityQueue<>(11, Comparator.comparing(TimestampedValue::getTimestamp));
// Wait until we have at least one element. We are notified on each element
// being added.
// Wait until the current time has passed the expiry time for the head of the
// queue.
// We are notified on each element being added.
// Wait until we have at least one element. We are notified on each element
// being added.
// Wait until the current time has passed the expiry time for the head of the
// queue.
// We are notified on each element being added.
cleanUpResult =
executorService.submit(
(Callable<Void>)
() -> {
while (true) {
synchronized (cleanUpQueue) {
TimestampedValue<String> expiryTime = cleanUpQueue.peek();

// Wait until we have at least one element. We are notified on each element
// being added.
while (expiryTime == null) {
cleanUpQueue.wait();
expiryTime = cleanUpQueue.peek();
}

// Wait until the current time has passed the expiry time for the head of the
// queue.
// We are notified on each element being added.
Instant now = Instant.now();
while (expiryTime.getTimestamp().isAfter(now)) {
Duration timeDifference = new Duration(now, expiryTime.getTimestamp());
cleanUpQueue.wait(timeDifference.getMillis());
expiryTime = cleanUpQueue.peek();
now = Instant.now();
}

bundleFinalizationCallbacks.remove(cleanUpQueue.poll().getValue());
}
}
});
new PriorityQueue<>(11, Comparator.comparing(FinalizationInfo::getExpiryTimestamp));
executorService.execute(this::cleanupThreadBody);
}

/**
 * Body of the background cleanup task submitted by the constructor.
 *
 * <p>Holds {@code lock} for its whole lifetime, releasing it only while blocked on {@code
 * queueMinChanged}. On every pass it looks at the entry with the earliest expiry timestamp:
 *
 * <ul>
 *   <li>if the queue is empty, it waits until it is signalled;
 *   <li>if the head has not expired yet, it waits until the expiry time or until it is signalled,
 *       and then re-evaluates from scratch;
 *   <li>if the head has expired, it is dropped from both the queue and the callback map without
 *       invoking its callbacks.
 * </ul>
 *
 * <p>The decision to drop an entry is based solely on comparing its expiry time with the current
 * time. The result of the timed wait is deliberately not used for that decision: a wake-up
 * before the timeout (a signal or a spurious wake-up) says nothing about whether the head has
 * expired, and the entry that caused a signal may already have been removed again by the time
 * this thread reacquires the lock.
 *
 * <p>The loop ends when the thread is interrupted; the interrupt status is restored before
 * returning.
 */
private void cleanupThreadBody() {
  lock.lock();
  try {
    while (true) {
      final @Nullable FinalizationInfo minValue = cleanUpQueue.peek();
      if (minValue == null) {
        // Wait for an element to be added and loop to re-examine the min.
        queueMinChanged.await();
        continue;
      }

      Instant now = Instant.now();
      if (minValue.expiryTimestamp.isAfter(now)) {
        // Not expired yet (at least 1ms remains). Sleep until the expiry time or until the
        // minimum changes, then loop to re-read both the head of the queue and the clock.
        Duration timeDifference = new Duration(now, minValue.expiryTimestamp);
        queueMinChanged.await(timeDifference.getMillis(), TimeUnit.MILLISECONDS);
        continue;
      }

      // The minimum element has an expiry time at or before now. We still hold the lock since
      // peeking, so it must still be the head of the queue and present in the map.
      checkState(minValue == cleanUpQueue.poll());
      checkState(bundleFinalizationCallbacks.remove(minValue.id) == minValue);
    }
  } catch (InterruptedException e) {
    // We're being shut down. Restore the interrupt status for the owner of this thread.
    Thread.currentThread().interrupt();
  } finally {
    lock.unlock();
  }
}

public void registerCallbacks(String bundleId, Collection<CallbackRegistration> callbacks) {
if (callbacks.isEmpty()) {
return;
}

Collection<CallbackRegistration> priorCallbacks =
bundleFinalizationCallbacks.putIfAbsent(bundleId, callbacks);
checkState(
priorCallbacks == null,
"Expected to not have any past callbacks for bundle %s but found %s.",
bundleId,
priorCallbacks);
long expiryTimeMillis = Long.MIN_VALUE;
Instant maxExpiryTime = Instant.EPOCH;
for (CallbackRegistration callback : callbacks) {
expiryTimeMillis = Math.max(expiryTimeMillis, callback.getExpiryTime().getMillis());
Instant callbackExpiry = callback.getExpiryTime();
if (callbackExpiry.isAfter(maxExpiryTime)) {
maxExpiryTime = callbackExpiry;
}
}
synchronized (cleanUpQueue) {
cleanUpQueue.offer(TimestampedValue.of(bundleId, new Instant(expiryTimeMillis)));
cleanUpQueue.notify();
final FinalizationInfo info = new FinalizationInfo(bundleId, maxExpiryTime, callbacks);

lock.lock();
try {
FinalizationInfo existingInfo = bundleFinalizationCallbacks.put(bundleId, info);
if (existingInfo != null) {
throw new IllegalStateException(
"Expected to not have any past callbacks for bundle "
+ bundleId
+ " but had "
+ existingInfo.callbacks);
}
cleanUpQueue.add(info);
@SuppressWarnings("ReferenceEquality")
boolean newMin = cleanUpQueue.peek() == info;
if (newMin) {
queueMinChanged.signal();
}
} finally {
lock.unlock();
}
}

public BeamFnApi.InstructionResponse.Builder finalizeBundle(BeamFnApi.InstructionRequest request)
throws Exception {
String bundleId = request.getFinalizeBundle().getInstructionId();

Collection<CallbackRegistration> callbacks = bundleFinalizationCallbacks.remove(bundleId);

if (callbacks == null) {
@Nullable FinalizationInfo info;
lock.lock();
try {
info = bundleFinalizationCallbacks.remove(bundleId);
if (info != null) {
checkState(cleanUpQueue.remove(info));
}
} finally {
lock.unlock();
}
if (info == null) {
// We have already processed the callbacks on a prior bundle finalization attempt
return BeamFnApi.InstructionResponse.newBuilder()
.setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance());
}

Collection<Exception> failures = new ArrayList<>();
for (CallbackRegistration callback : callbacks) {
for (CallbackRegistration callback : info.callbacks) {
try {
callback.getCallback().onBundleSuccess();
} catch (Exception e) {
Expand All @@ -170,4 +199,13 @@ public BeamFnApi.InstructionResponse.Builder finalizeBundle(BeamFnApi.Instructio
return BeamFnApi.InstructionResponse.newBuilder()
.setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance());
}

/**
 * Returns the number of entries currently held in the cleanup queue.
 *
 * <p>The size is read while holding {@code lock}.
 */
int cleanupQueueSize() {
  int pending;
  lock.lock();
  try {
    pending = cleanUpQueue.size();
  } finally {
    lock.unlock();
  }
  return pending;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,22 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.beam.fn.harness.control.FinalizeBundleHandler.CallbackRegistration;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse;
import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Test;
Expand Down Expand Up @@ -106,6 +112,49 @@ public void testFinalizationContinuesToNextCallbackEvenInFailure() throws Except
}
}

/**
 * Verifies that bundles whose callbacks have expired are dropped by the cleanup task without
 * their callbacks being invoked, while a bundle that has not expired is still finalized normally.
 */
@Test
public void testCallbackExpiration() throws Exception {
  ExecutorService executor = Executors.newCachedThreadPool();
  try {
    FinalizeBundleHandler handler = new FinalizeBundleHandler(executor);
    BundleFinalizer.Callback callback = mock(BundleFinalizer.Callback.class);
    handler.registerCallbacks(
        "test",
        Collections.singletonList(
            CallbackRegistration.create(Instant.now().plus(Duration.standardHours(1)), callback)));
    assertEquals(1, handler.cleanupQueueSize());

    BundleFinalizer.Callback callback2 = mock(BundleFinalizer.Callback.class);
    handler.registerCallbacks(
        "test2",
        Collections.singletonList(
            CallbackRegistration.create(Instant.now().plus(Duration.millis(100)), callback2)));
    BundleFinalizer.Callback callback3 = mock(BundleFinalizer.Callback.class);
    handler.registerCallbacks(
        "test3",
        Collections.singletonList(
            CallbackRegistration.create(Instant.now().plus(Duration.millis(1)), callback3)));

    // Poll until the two short-lived bundles have been cleaned up, but give up after a generous
    // deadline so that a broken cleanup task fails the test instead of hanging it.
    Instant deadline = Instant.now().plus(Duration.standardSeconds(30));
    while (handler.cleanupQueueSize() > 1) {
      if (Instant.now().isAfter(deadline)) {
        fail("Timed out waiting for bundles test2 and test3 to expire.");
      }
      Thread.sleep(10);
    }
    // Just the "test" bundle should remain as "test2" and "test3" should have timed out.
    assertEquals(1, handler.cleanupQueueSize());
    // Completing test2 and test3 should have successful response but not invoke the callbacks
    // as they were cleaned up.
    assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test2")).build());
    verifyNoMoreInteractions(callback2);
    assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test3")).build());
    verifyNoMoreInteractions(callback3);
    // Completing "test" bundle should call the callback and remove it from cleanup queue.
    assertEquals(1, handler.cleanupQueueSize());
    assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test")).build());
    verify(callback).onBundleSuccess();
    assertEquals(0, handler.cleanupQueueSize());
    // Verify that completing again is a no-op as it was cleaned up.
    assertEquals(SUCCESSFUL_RESPONSE, handler.finalizeBundle(requestFor("test")).build());
    verifyNoMoreInteractions(callback);
  } finally {
    // Always stop the cleanup task, even when an assertion above fails, so its thread does not
    // outlive the test.
    executor.shutdownNow();
  }
}

private static InstructionRequest requestFor(String bundleId) {
return InstructionRequest.newBuilder()
.setInstructionId(INSTRUCTION_ID)
Expand Down
Loading