diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java index 8e7c670d5c26..8752408770ac 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java @@ -160,6 +160,18 @@ void addIncomingTimerEndpoint( * instant provides the timeout on how long the finalization callback is valid for. */ DoFn.BundleFinalizer getBundleFinalizer(); + + /** + * Returns true if the runner has no state for the keys in the ProcessBundleRequest. If true, + * the SDK can begin stateful processing with an initial empty state. + */ + boolean getHasNoState(); + + /** + * Returns true if the runner will never process another bundle for the keys it contains. + * Therefore, the generated state need not be included in the bundle commit. + */ + boolean getOnlyBundleForKeys(); } /** diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index b8ad51816a7a..e4cffeed0b53 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -246,7 +246,9 @@ private void addRunnerAndConsumersForPTransformRecursively( BundleFinalizer bundleFinalizer, Collection> channelRoots, Map outboundAggregatorMap, - Set runnerCapabilities) + Set runnerCapabilities, + boolean hasNoState, + boolean onlyBundleForKeys) throws IOException { // Recursively ensure that all consumers of the output PCollection have been created. @@ -279,7 +281,9 @@ private void addRunnerAndConsumersForPTransformRecursively( bundleFinalizer, channelRoots, outboundAggregatorMap, - runnerCapabilities); + runnerCapabilities, + hasNoState, + onlyBundleForKeys); } } @@ -488,6 +492,16 @@ public BundleSplitListener getSplitListener() { public BundleFinalizer getBundleFinalizer() { return bundleFinalizer; } + + @Override + public boolean getHasNoState() { + return hasNoState; + } + + @Override + public boolean getOnlyBundleForKeys() { + return onlyBundleForKeys; + } }); processedPTransformIds.add(pTransformId); } @@ -913,7 +927,9 @@ public void afterBundleCommit(Instant callbackExpiry, Callback callback) { bundleFinalizer, bundleProcessor.getChannelRoots(), bundleProcessor.getOutboundAggregators(), - bundleProcessor.getRunnerCapabilities()); + bundleProcessor.getRunnerCapabilities(), + processBundleRequest.getHasNoState(), + processBundleRequest.getOnlyBundleForKeys()); } bundleProcessor.finish(); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java index ba56c6d656ca..c4b98153709f 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java @@ -54,27 +54,33 @@ public class BagUserState { private List newValues; private boolean isCleared; private boolean isClosed; + private final boolean hasNoState; + private final boolean onlyBundleForKeys; static final int BAG_APPEND_BATCHING_LIMIT = 10 * 1024 * 1024; /** The cache must be namespaced for this state object accordingly. */ public BagUserState( - Cache cache, - BeamFnStateClient beamFnStateClient, - String instructionId, - StateKey stateKey, - Coder valueCoder) { + Cache cache, + BeamFnStateClient beamFnStateClient, + String instructionId, + StateKey stateKey, + Coder valueCoder, + boolean hasNoState, + boolean onlyBundleForKeys) { checkArgument( stateKey.hasBagUserState(), "Expected BagUserState StateKey but received %s.", stateKey); this.cache = cache; this.beamFnStateClient = beamFnStateClient; this.valueCoder = valueCoder; + this.hasNoState = hasNoState; + this.onlyBundleForKeys = onlyBundleForKeys; this.request = StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build(); this.oldValues = StateFetchingIterators.readAllAndDecodeStartingFrom( - this.cache, beamFnStateClient, request, valueCoder); + this.cache, beamFnStateClient, request, valueCoder, hasNoState); this.newValues = new ArrayList<>(); } @@ -127,7 +133,7 @@ public void asyncClose() throws Exception { beamFnStateClient.handle( request.toBuilder().setClear(StateClearRequest.getDefaultInstance())); } - if (!newValues.isEmpty()) { + if (!onlyBundleForKeys && !newValues.isEmpty()) { // Batch values up to a arbitrary limit to reduce overhead of write // requests. We treat this limit as strict to ensure that large elements // are not batched as they may otherwise exceed runner limits. diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java index 6913c75a5f2d..c14d9745bc30 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java @@ -114,6 +114,8 @@ public static class Factory { private final Map, SideInputSpec> sideInputSpecMap; private final Coder keyCoder; private final Coder windowCoder; + private final boolean hasNoState; + private final boolean onlyBundleForKeys; public Factory( PipelineOptions pipelineOptions, @@ -126,7 +128,9 @@ public Factory( Map, SideInputSpec> sideInputSpecMap, BeamFnStateClient beamFnStateClient, Coder keyCoder, - Coder windowCoder) { + Coder windowCoder, + boolean hasNoState, + boolean onlyBundleForKeys) { this.pipelineOptions = pipelineOptions; this.runnerCapabilities = runnerCapabilities; this.ptransformId = ptransformId; @@ -138,6 +142,8 @@ public Factory( this.beamFnStateClient = beamFnStateClient; this.keyCoder = keyCoder; this.windowCoder = windowCoder; + this.hasNoState = hasNoState; + this.onlyBundleForKeys = onlyBundleForKeys; } public static Factory factoryForPTransformContext( @@ -220,7 +226,9 @@ public static Factory factoryForPTransformContext( tagToSideInputSpecMap, context.getBeamFnStateClient(), keyCoder, - windowCoder); + windowCoder, + context.getHasNoState(), + context.getOnlyBundleForKeys()); } public FnApiStateAccessor create() { @@ -235,7 +243,9 @@ public FnApiStateAccessor create() { sideInputSpecMap, beamFnStateClient, keyCoder, - windowCoder); + windowCoder, + hasNoState, + onlyBundleForKeys); } } @@ -252,6 +262,8 @@ public FnApiStateAccessor create() { private final Collection stateFinalizers; private final Coder keyCoder; private final Coder windowCoder; + private final boolean hasNoState; + private final boolean onlyBundleForKeys; private @Nullable Supplier currentWindowSupplier; private @Nullable Supplier encodedCurrentKeySupplier; @@ -268,7 +280,9 @@ public FnApiStateAccessor( Map, SideInputSpec> sideInputSpecMap, BeamFnStateClient beamFnStateClient, Coder keyCoder, - Coder windowCoder) { + Coder windowCoder, + boolean hasNoState, + boolean onlyBundleForKeys) { this.pipelineOptions = pipelineOptions; this.runnerCapabilities = runnerCapabilities; this.stateKeyObjectCache = Maps.newHashMap(); @@ -282,6 +296,8 @@ public FnApiStateAccessor( this.keyCoder = keyCoder; this.windowCoder = windowCoder; this.stateFinalizers = new ArrayList<>(); + this.hasNoState = hasNoState; + this.onlyBundleForKeys = onlyBundleForKeys; } public void setKeyAndWindowContext(MutatingStateContext keyAndWindowContext) { @@ -417,7 +433,8 @@ public T get(PCollectionView view, BoundedWindow window) { runnerCapabilities.contains( BeamUrns.getUrn( RunnerApi.StandardRunnerProtocols.Enum - .MULTIMAP_KEYS_VALUES_SIDE_INPUT)))); + .MULTIMAP_KEYS_VALUES_SIDE_INPUT)), + hasNoState)); default: throw new IllegalStateException( String.format( @@ -1201,7 +1218,9 @@ private BagUserState createBagUserState(StateKey stateKey, Coder value beamFnStateClient, processBundleInstructionId.get(), stateKey, - valueCoder); + valueCoder, + hasNoState, + onlyBundleForKeys); stateFinalizers.add(rval::asyncClose); return rval; } @@ -1283,7 +1302,9 @@ private MultimapUserState createMultimapUserState( processBundleInstructionId.get(), stateKey, keyCoder, - valueCoder); + valueCoder, + hasNoState, + onlyBundleForKeys); stateFinalizers.add(rval::asyncClose); return rval; } @@ -1318,7 +1339,9 @@ private OrderedListUserState createOrderedListUserState( beamFnStateClient, processBundleInstructionId.get(), stateKey, - valueCoder); + valueCoder, + hasNoState, + onlyBundleForKeys); stateFinalizers.add(rval::asyncClose); return rval; } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/IterableSideInput.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/IterableSideInput.java index 7f87cb4d4e41..b5492bfa8565 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/IterableSideInput.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/IterableSideInput.java @@ -40,7 +40,8 @@ public IterableSideInput( BeamFnStateClient beamFnStateClient, String instructionId, StateKey stateKey, - Coder valueCoder) { + Coder valueCoder, + boolean hasNoState) { checkArgument( stateKey.hasIterableSideInput(), "Expected IterableSideInput StateKey but received %s.", @@ -50,7 +51,8 @@ public IterableSideInput( cache, beamFnStateClient, StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build(), - valueCoder); + valueCoder, + hasNoState); } @Override diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java index 8e38a57ff9fe..e403332d8f13 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java @@ -55,6 +55,7 @@ public class MultimapSideInput implements MultimapView { private final Coder valueCoder; private volatile Function> bulkReadResult; private final boolean useBulkRead; + private final boolean hasNoState; public MultimapSideInput( Cache cache, @@ -63,7 +64,8 @@ public MultimapSideInput( StateKey stateKey, Coder keyCoder, Coder valueCoder, - boolean useBulkRead) { + boolean useBulkRead, + boolean hasNoState) { checkArgument( stateKey.hasMultimapKeysSideInput(), "Expected MultimapKeysSideInput StateKey but received %s.", @@ -75,12 +77,13 @@ public MultimapSideInput( this.keyCoder = keyCoder; this.valueCoder = valueCoder; this.useBulkRead = useBulkRead; + this.hasNoState = hasNoState; } @Override public Iterable get() { return StateFetchingIterators.readAllAndDecodeStartingFrom( - cache, beamFnStateClient, keysRequest, keyCoder); + cache, beamFnStateClient, keysRequest, keyCoder, hasNoState); } @Override @@ -120,7 +123,8 @@ public Iterable get(K k) { Caches.noop(), beamFnStateClient, bulkReadRequest, - KvCoder.of(keyCoder, IterableCoder.of(valueCoder))) + KvCoder.of(keyCoder, IterableCoder.of(valueCoder)), + hasNoState) .iterator(); while (bulkRead.size() < BULK_READ_SIZE && entries.hasNext()) { KV> entry = entries.next(); @@ -169,7 +173,11 @@ public Iterable get(K k) { StateRequest request = keysRequest.toBuilder().setStateKey(stateKey).build(); return StateFetchingIterators.readAllAndDecodeStartingFrom( - Caches.subCache(cache, "ValuesForKey", encodedKey), beamFnStateClient, request, valueCoder); + Caches.subCache(cache, "ValuesForKey", encodedKey), + beamFnStateClient, + request, + valueCoder, + hasNoState); } private ByteString encodeKey(K k) { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java index 83d78ff836c7..da858240e569 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java @@ -72,6 +72,8 @@ public class MultimapUserState { private boolean isClosed; private boolean isCleared; + private final boolean hasNoState; + private final boolean onlyBundleForKeys; // Pending updates to persistent storage private HashMap pendingRemoves = Maps.newHashMap(); private HashMap>> pendingAdds = Maps.newHashMap(); @@ -84,7 +86,9 @@ public MultimapUserState( String instructionId, StateKey stateKey, Coder mapKeyCoder, - Coder valueCoder) { + Coder valueCoder, + boolean hasNoState, + boolean onlyBundleForKeys) { checkArgument( stateKey.hasMultimapKeysUserState(), "Expected MultimapKeysUserState StateKey but received %s.", @@ -93,6 +97,8 @@ public MultimapUserState( this.beamFnStateClient = beamFnStateClient; this.mapKeyCoder = mapKeyCoder; this.valueCoder = valueCoder; + this.hasNoState = hasNoState; + this.onlyBundleForKeys = onlyBundleForKeys; // Note: These StateRequest protos are constructed even if we never try to read the // corresponding state type. Consider constructing them lazily, as needed. @@ -100,7 +106,7 @@ public MultimapUserState( StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build(); this.persistedKeys = StateFetchingIterators.readAllAndDecodeStartingFrom( - cache, beamFnStateClient, keysStateRequest, mapKeyCoder); + cache, beamFnStateClient, keysStateRequest, mapKeyCoder, hasNoState); StateRequest.Builder userStateRequestBuilder = StateRequest.newBuilder(); userStateRequestBuilder @@ -128,7 +134,8 @@ public MultimapUserState( Caches.subCache(this.cache, "AllEntries"), beamFnStateClient, entriesStateRequest, - KvCoder.of(mapKeyCoder, IterableCoder.of(valueCoder))); + KvCoder.of(mapKeyCoder, IterableCoder.of(valueCoder)), + hasNoState); } public void clear() { @@ -438,7 +445,7 @@ private void startStateApiWrites() { } // Persist pending key-values - if (!pendingAdds.isEmpty()) { + if (!pendingAdds.isEmpty() && !onlyBundleForKeys) { for (KV> entry : pendingAdds.values()) { StateRequest request = createUserStateRequest(entry.getKey()); beamFnStateClient.handle( @@ -542,7 +549,8 @@ private CachingStateIterable getPersistedValues(Object structuralKey, K key) request.getStateKey().getMultimapUserState().getMapKey()), beamFnStateClient, request, - valueCoder)); + valueCoder, + hasNoState)); }) .getValue(); } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/OrderedListUserState.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/OrderedListUserState.java index 47b5057880b9..008ecb77cb37 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/OrderedListUserState.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/OrderedListUserState.java @@ -88,7 +88,8 @@ public class OrderedListUserState { private boolean isCleared = false; private boolean isClosed = false; - + private final boolean hasNoState; + private final boolean onlyBundleForKeys; public static class TimestampedValueCoder extends StructuredCoder> { private final Coder valueCoder; @@ -161,7 +162,9 @@ public OrderedListUserState( BeamFnStateClient beamFnStateClient, String instructionId, StateKey stateKey, - Coder valueCoder) { + Coder valueCoder, + boolean hasNoState, + boolean onlyBundleForKeys) { checkArgument( stateKey.hasOrderedListUserState(), "Expected OrderedListUserState StateKey but received %s.", @@ -170,6 +173,8 @@ public OrderedListUserState( this.timestampedValueCoder = TimestampedValueCoder.of(valueCoder); this.requestTemplate = StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build(); + this.hasNoState = hasNoState; + this.onlyBundleForKeys = onlyBundleForKeys; } public void add(TimestampedValue value) { @@ -218,7 +223,8 @@ public Iterable> readRange(Instant minTimestamp, Instant lim Caches.noop(), this.beamFnStateClient, getRequestBuilder.build(), - this.timestampedValueCoder); + this.timestampedValueCoder, + hasNoState); // Make a snapshot of the current pendingRemoves and use them to filter persistent values. // The values of pendingRemoves are copied, so that they will still be accessible in @@ -303,6 +309,11 @@ public void asyncClose() throws Exception { pendingRemoves.clear(); } + if (onlyBundleForKeys) { + pendingAdds.clear(); + return; + } + if (!pendingAdds.isEmpty()) { ByteStringOutputStream outStream = new ByteStringOutputStream(); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 1e06c98f2e31..0f53a472ba4c 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -86,12 +86,14 @@ public static CachingStateIterable readAllAndDecodeStartingFrom( Cache cache, BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk, - Coder valueCoder) { + Coder valueCoder, + boolean hasNoState) { return new CachingStateIterable<>( (Cache>) cache, beamFnStateClient, stateRequestForFirstChunk, - valueCoder); + valueCoder, + hasNoState); } /** @@ -328,16 +330,19 @@ public static Block fromValues( private final BeamFnStateClient beamFnStateClient; private final StateRequest stateRequestForFirstChunk; private final Coder valueCoder; + private final boolean hasNoState; public CachingStateIterable( Cache> cache, BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk, - Coder valueCoder) { + Coder valueCoder, + boolean hasNoState) { this.cache = cache; this.beamFnStateClient = beamFnStateClient; this.stateRequestForFirstChunk = stateRequestForFirstChunk; this.valueCoder = valueCoder; + this.hasNoState = hasNoState; } /** @@ -510,7 +515,8 @@ class CachingStateIterator implements PrefetchableIterator { public CachingStateIterator() { this.underlyingStateFetchingIterator = - new LazyBlockingStateFetchingIterator(beamFnStateClient, stateRequestForFirstChunk); + new LazyBlockingStateFetchingIterator(beamFnStateClient, stateRequestForFirstChunk, + hasNoState); this.dataStreamDecoder = new DataStreamDecoder<>(valueCoder, underlyingStateFetchingIterator); this.currentBlock = @@ -677,6 +683,15 @@ static class LazyBlockingStateFetchingIterator implements PrefetchableIterator