keys = stateManager.getProcessingKeys();
if (keys != null) {
for (Object key : keys) {
- keySegmentQueue.addKeyToLastSegment(key);
+ eventRouter.getKeySegmentQueue().addKeyToLastSegment(key);
mailboxExecutor.submit(
() -> tryProcessActionTaskForKey(key), "process action task");
}
}
- getKeyedStateBackend()
- .applyToAllKeys(
- VoidNamespace.INSTANCE,
- VoidNamespaceSerializer.INSTANCE,
- new ListStateDescriptor<>(
- PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)),
- (key, state) ->
- state.get()
- .forEach(
- event -> keySegmentQueue.addKeyToLastSegment(key)));
- }
-
- private void initOrIncSequenceNumber() throws Exception {
- // Initialize the sequence number state if it does not exist.
- Long sequenceNumber = sequenceNumberKState.value();
- if (sequenceNumber == null) {
- sequenceNumberKState.update(0L);
- } else {
- sequenceNumberKState.update(sequenceNumber + 1);
- }
- }
-
- private ActionState maybeGetActionState(
- Object key, long sequenceNum, Action action, Event event) throws Exception {
- return actionStateStore == null
- ? null
- : actionStateStore.get(key.toString(), sequenceNum, action, event);
- }
-
- private void maybeInitActionState(Object key, long sequenceNum, Action action, Event event)
- throws Exception {
- if (actionStateStore != null) {
- // Initialize the action state if it does not exist. It will exist when the action is an
- // async action and
- // has been persisted before the action task is finished.
- if (actionStateStore.get(key, sequenceNum, action, event) == null) {
- actionStateStore.put(key, sequenceNum, action, event, new ActionState(event));
- }
- }
- }
-
- private void maybePersistTaskResult(
- Object key,
- long sequenceNum,
- Action action,
- Event event,
- RunnerContextImpl context,
- ActionTask.ActionTaskResult actionTaskResult)
- throws Exception {
- if (actionStateStore == null) {
- return;
- }
-
- // if the task is not finished, we skip the persistence for now and wait until it is
- // finished.
- if (!actionTaskResult.isFinished()) {
- return;
- }
-
- ActionState actionState = actionStateStore.get(key, sequenceNum, action, event);
-
- for (MemoryUpdate memoryUpdate : context.getSensoryMemoryUpdates()) {
- actionState.addSensoryMemoryUpdate(memoryUpdate);
- }
-
- for (MemoryUpdate memoryUpdate : context.getShortTermMemoryUpdates()) {
- actionState.addShortTermMemoryUpdate(memoryUpdate);
- }
-
- for (Event outputEvent : actionTaskResult.getOutputEvents()) {
- actionState.addEvent(outputEvent);
- }
-
- // Mark the action as completed and clear call records
- // This indicates that recovery should skip the entire action
- actionState.markCompleted();
-
- actionStateStore.put(key, sequenceNum, action, event, actionState);
-
- // Clear durable execution context
- context.clearDurableExecutionContext();
- }
-
- /**
- * Sets up the durable execution context for fine-grained recovery.
- *
- * This method initializes the runner context with a {@link
- * RunnerContextImpl.DurableExecutionContext}, which enables execute/execute_async calls to:
- *
- *
- * Skip re-execution for already completed calls during recovery
- * Persist CallRecords after each code block completion
- *
- */
- private void setupDurableExecutionContext(ActionTask actionTask, ActionState actionState) {
- if (actionStateStore == null) {
- return;
- }
-
- RunnerContextImpl.DurableExecutionContext durableContext;
- if (actionTaskDurableContexts.containsKey(actionTask)) {
- // Reuse existing context for async action continuation
- durableContext = actionTaskDurableContexts.get(actionTask);
- } else {
- // Create new context for first invocation
- final long sequenceNumber;
- try {
- sequenceNumber = sequenceNumberKState.value();
- } catch (Exception e) {
- throw new RuntimeException("Failed to get sequence number from state", e);
- }
-
- durableContext =
- new RunnerContextImpl.DurableExecutionContext(
- actionTask.getKey(),
- sequenceNumber,
- actionTask.action,
- actionTask.event,
- actionState,
- this);
- }
-
- actionTask.getRunnerContext().setDurableExecutionContext(durableContext);
- }
-
- @Override
- public void persist(
- Object key, long sequenceNumber, Action action, Event event, ActionState actionState) {
- try {
- actionStateStore.put(key, sequenceNumber, action, event, actionState);
- } catch (Exception e) {
- LOG.error("Failed to persist ActionState", e);
- throw new RuntimeException("Failed to persist ActionState", e);
- }
- }
-
- private void maybePruneState(Object key, long sequenceNum) throws Exception {
- if (actionStateStore != null) {
- actionStateStore.pruneState(key, sequenceNum);
- }
- }
-
- private void processEligibleWatermarks() throws Exception {
- Watermark mark = keySegmentQueue.popOldestWatermark();
- while (mark != null) {
- super.processWatermark(mark);
- mark = keySegmentQueue.popOldestWatermark();
- }
- }
-
- private RunnerContextImpl createOrGetRunnerContext(Boolean isJava) {
- if (isJava) {
- if (runnerContext == null) {
- if (continuationActionExecutor == null) {
- continuationActionExecutor =
- new ContinuationActionExecutor(
- agentPlan
- .getConfig()
- .get(AgentExecutionOptions.NUM_ASYNC_THREADS));
- }
- runnerContext =
- new JavaRunnerContextImpl(
- this.metricGroup,
- this::checkMailboxThread,
- this.agentPlan,
- this.jobIdentifier,
- continuationActionExecutor);
- }
- return runnerContext;
- } else {
- if (pythonRunnerContext == null) {
- pythonRunnerContext =
- new PythonRunnerContextImpl(
- this.metricGroup,
- this::checkMailboxThread,
- this.agentPlan,
- jobIdentifier);
- }
- return pythonRunnerContext;
- }
- }
-
- private EventLogger createEventLogger(AgentPlan agentPlan) {
- EventLoggerConfig.Builder loggerConfigBuilder = EventLoggerConfig.builder();
- String baseLogDir = agentPlan.getConfig().get(BASE_LOG_DIR);
- if (baseLogDir != null && !baseLogDir.trim().isEmpty()) {
- loggerConfigBuilder.property(FileEventLogger.BASE_LOG_DIR_PROPERTY_KEY, baseLogDir);
- }
- return EventLoggerFactory.createLogger(loggerConfigBuilder.build());
- }
-
- private void maybeInitActionStateStore() {
- if (actionStateStore == null
- && KAFKA.getType()
- .equalsIgnoreCase(agentPlan.getConfig().get(ACTION_STATE_STORE_BACKEND))) {
- LOG.info("Using Kafka as backend of action state store.");
- actionStateStore = new KafkaActionStateStore(agentPlan.getConfig());
- }
+ stateManager.forEachPendingInputEventKey(
+ getKeyedStateBackend(),
+ key -> eventRouter.getKeySegmentQueue().addKeyToLastSegment(key));
}
/** Failed to execute Action task. */
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java
new file mode 100644
index 00000000..c3a5e701
--- /dev/null
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.operator;
+
+import org.apache.flink.agents.api.agents.AgentExecutionOptions;
+import org.apache.flink.agents.plan.AgentPlan;
+import org.apache.flink.agents.plan.JavaFunction;
+import org.apache.flink.agents.plan.PythonFunction;
+import org.apache.flink.agents.runtime.async.ContinuationActionExecutor;
+import org.apache.flink.agents.runtime.async.ContinuationContext;
+import org.apache.flink.agents.runtime.context.JavaRunnerContextImpl;
+import org.apache.flink.agents.runtime.context.RunnerContextImpl;
+import org.apache.flink.agents.runtime.memory.CachedMemoryStore;
+import org.apache.flink.agents.runtime.memory.MemoryObjectImpl;
+import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
+import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl;
+import org.apache.flink.api.common.state.MapState;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Manages runner context lifecycle for action tasks, including creation, context switching for
+ * async continuation, and cleanup.
+ *
+ * Holds four transient maps that track intermediate context state across async action task
+ * invocations: memory contexts, durable execution contexts, continuation contexts, and Python
+ * awaitable references.
+ */
+class ActionTaskContextManager {
+
+ // Invariant dependencies (set once at construction, never change)
+ private final FlinkAgentsMetricGroupImpl metricGroup;
+ private final Runnable mailboxThreadChecker;
+ private final AgentPlan agentPlan;
+ private final String jobIdentifier;
+ private final PythonRunnerContextImpl pythonRunnerContext;
+
+ // Lazily created singleton runner context for Java actions
+ private RunnerContextImpl javaRunnerContext;
+
+ // Transient maps for in-flight async action tasks
+ private final Map actionTaskMemoryContexts;
+ private final Map
+ actionTaskDurableContexts;
+ private final Map continuationContexts;
+ private final Map pythonAwaitableRefs;
+
+ private final ContinuationActionExecutor continuationActionExecutor;
+
+ ActionTaskContextManager(
+ FlinkAgentsMetricGroupImpl metricGroup,
+ Runnable mailboxThreadChecker,
+ AgentPlan agentPlan,
+ String jobIdentifier,
+ PythonRunnerContextImpl pythonRunnerContext) {
+ this.metricGroup = metricGroup;
+ this.mailboxThreadChecker = mailboxThreadChecker;
+ this.agentPlan = agentPlan;
+ this.jobIdentifier = jobIdentifier;
+ this.pythonRunnerContext = pythonRunnerContext;
+ this.continuationActionExecutor =
+ new ContinuationActionExecutor(
+ agentPlan.getConfig().get(AgentExecutionOptions.NUM_ASYNC_THREADS));
+ this.actionTaskMemoryContexts = new HashMap<>();
+ this.actionTaskDurableContexts = new HashMap<>();
+ this.continuationContexts = new HashMap<>();
+ this.pythonAwaitableRefs = new HashMap<>();
+ }
+
+ Map getActionTaskDurableContexts() {
+ return actionTaskDurableContexts;
+ }
+
+ /**
+ * Creates or retrieves the appropriate runner context for the given action task and configures
+ * it with the correct memory context, continuation context, and awaitable ref.
+ *
+ * If the task has previously cached contexts (from an unfinished async action), those are
+ * restored. Otherwise, fresh contexts are created from Flink state.
+ */
+ void createAndSetRunnerContext(
+ ActionTask actionTask,
+ Object key,
+ MapState sensoryMemState,
+ MapState shortTermMemState) {
+ RunnerContextImpl ctx;
+ if (actionTask.action.getExec() instanceof JavaFunction) {
+ ctx = getOrCreateJavaRunnerContext();
+ } else if (actionTask.action.getExec() instanceof PythonFunction) {
+ ctx = getOrCreatePythonRunnerContext();
+ } else {
+ throw new IllegalStateException(
+ "Unsupported action type: " + actionTask.action.getExec().getClass());
+ }
+
+ RunnerContextImpl.MemoryContext memoryContext = actionTaskMemoryContexts.get(actionTask);
+ if (memoryContext == null) {
+ memoryContext =
+ new RunnerContextImpl.MemoryContext(
+ new CachedMemoryStore(sensoryMemState),
+ new CachedMemoryStore(shortTermMemState));
+ }
+
+ ctx.switchActionContext(
+ actionTask.action.getName(), memoryContext, String.valueOf(key.hashCode()));
+
+ if (ctx instanceof JavaRunnerContextImpl) {
+ ContinuationContext continuationContext = continuationContexts.get(actionTask);
+ if (continuationContext == null) {
+ continuationContext = new ContinuationContext();
+ }
+ ((JavaRunnerContextImpl) ctx).setContinuationContext(continuationContext);
+ }
+ if (ctx instanceof PythonRunnerContextImpl) {
+ String awaitableRef = pythonAwaitableRefs.get(actionTask);
+ ((PythonRunnerContextImpl) ctx).setPythonAwaitableRef(awaitableRef);
+ }
+ actionTask.setRunnerContext(ctx);
+ }
+
+ /**
+ * Removes cached contexts for the given action task after it completes execution. Called after
+ * each action task invocation.
+ */
+ void removeContexts(ActionTask actionTask) {
+ actionTaskMemoryContexts.remove(actionTask);
+ actionTaskDurableContexts.remove(actionTask);
+ continuationContexts.remove(actionTask);
+ pythonAwaitableRefs.remove(actionTask);
+ }
+
+ /**
+ * Transfers context state from a completed (but not finished) action task to its generated
+ * continuation task. This preserves memory, durable execution, continuation, and awaitable
+ * state across async invocations.
+ */
+ void transferContexts(ActionTask fromTask, ActionTask toTask) {
+ actionTaskMemoryContexts.put(toTask, fromTask.getRunnerContext().getMemoryContext());
+ RunnerContextImpl.DurableExecutionContext durableContext =
+ fromTask.getRunnerContext().getDurableExecutionContext();
+ if (durableContext != null) {
+ actionTaskDurableContexts.put(toTask, durableContext);
+ }
+ if (fromTask.getRunnerContext() instanceof JavaRunnerContextImpl) {
+ continuationContexts.put(
+ toTask,
+ ((JavaRunnerContextImpl) fromTask.getRunnerContext()).getContinuationContext());
+ }
+ if (fromTask.getRunnerContext() instanceof PythonRunnerContextImpl) {
+ String awaitableRef =
+ ((PythonRunnerContextImpl) fromTask.getRunnerContext()).getPythonAwaitableRef();
+ if (awaitableRef != null) {
+ pythonAwaitableRefs.put(toTask, awaitableRef);
+ }
+ }
+ }
+
+ private RunnerContextImpl getOrCreateJavaRunnerContext() {
+ if (javaRunnerContext == null) {
+ javaRunnerContext =
+ new JavaRunnerContextImpl(
+ metricGroup,
+ mailboxThreadChecker,
+ agentPlan,
+ jobIdentifier,
+ continuationActionExecutor);
+ }
+ return javaRunnerContext;
+ }
+
+ private RunnerContextImpl getOrCreatePythonRunnerContext() {
+ if (pythonRunnerContext == null) {
+ throw new IllegalStateException(
+ "PythonRunnerContext should have been initialized by PythonBridgeManager");
+ }
+ return pythonRunnerContext;
+ }
+
+ void close() throws Exception {
+ if (javaRunnerContext != null) {
+ try {
+ javaRunnerContext.close();
+ } finally {
+ javaRunnerContext = null;
+ }
+ }
+ continuationActionExecutor.close();
+ }
+}
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java
new file mode 100644
index 00000000..f3315990
--- /dev/null
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.operator;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.api.context.MemoryUpdate;
+import org.apache.flink.agents.plan.AgentPlan;
+import org.apache.flink.agents.plan.actions.Action;
+import org.apache.flink.agents.runtime.actionstate.ActionState;
+import org.apache.flink.agents.runtime.actionstate.ActionStateStore;
+import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore;
+import org.apache.flink.agents.runtime.context.ActionStatePersister;
+import org.apache.flink.agents.runtime.context.RunnerContextImpl;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND;
+import static org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA;
+
+/**
+ * Manages durable execution state including the {@link ActionStateStore}, action state persistence,
+ * recovery markers, and checkpoint-based state pruning.
+ *
+ * When no {@link ActionStateStore} is configured (the common case for simple agents), all
+ * methods are no-ops. This class implements {@link ActionStatePersister} so it can be passed to
+ * {@link RunnerContextImpl.DurableExecutionContext} for fine-grained call-level recovery.
+ */
+class DurableExecutionManager implements ActionStatePersister {
+
+ private static final Logger LOG = LoggerFactory.getLogger(DurableExecutionManager.class);
+
+ private ActionStateStore actionStateStore;
+ private final Map> checkpointIdToSeqNums;
+
+ DurableExecutionManager(ActionStateStore actionStateStore) {
+ this.actionStateStore = actionStateStore;
+ this.checkpointIdToSeqNums = new HashMap<>();
+ }
+
+ ActionStateStore getActionStateStore() {
+ return actionStateStore;
+ }
+
+ /** Initializes the Kafka action state store if configured but not yet created. */
+ void maybeInitActionStateStore(AgentPlan agentPlan) {
+ if (actionStateStore == null
+ && KAFKA.getType()
+ .equalsIgnoreCase(agentPlan.getConfig().get(ACTION_STATE_STORE_BACKEND))) {
+ LOG.info("Using Kafka as backend of action state store.");
+ actionStateStore = new KafkaActionStateStore(agentPlan.getConfig());
+ }
+ }
+
+ /** Rebuilds state from recovery markers during state initialization. */
+ void rebuildStateFromMarkers(List markers) throws Exception {
+ if (actionStateStore != null) {
+ LOG.info("Rebuilding action state from {} recovery markers", markers.size());
+ actionStateStore.rebuildState(markers);
+ }
+ }
+
+ ActionState maybeGetActionState(Object key, long sequenceNum, Action action, Event event)
+ throws Exception {
+ return actionStateStore == null
+ ? null
+ : actionStateStore.get(key.toString(), sequenceNum, action, event);
+ }
+
+ void maybeInitActionState(Object key, long sequenceNum, Action action, Event event)
+ throws Exception {
+ if (actionStateStore != null) {
+ if (actionStateStore.get(key, sequenceNum, action, event) == null) {
+ actionStateStore.put(key, sequenceNum, action, event, new ActionState(event));
+ }
+ }
+ }
+
+ void maybePersistTaskResult(
+ Object key,
+ long sequenceNum,
+ Action action,
+ Event event,
+ RunnerContextImpl context,
+ ActionTask.ActionTaskResult actionTaskResult)
+ throws Exception {
+ if (actionStateStore == null || !actionTaskResult.isFinished()) {
+ return;
+ }
+
+ ActionState actionState = actionStateStore.get(key, sequenceNum, action, event);
+
+ for (MemoryUpdate memoryUpdate : context.getSensoryMemoryUpdates()) {
+ actionState.addSensoryMemoryUpdate(memoryUpdate);
+ }
+
+ for (MemoryUpdate memoryUpdate : context.getShortTermMemoryUpdates()) {
+ actionState.addShortTermMemoryUpdate(memoryUpdate);
+ }
+
+ for (Event outputEvent : actionTaskResult.getOutputEvents()) {
+ actionState.addEvent(outputEvent);
+ }
+
+ actionState.markCompleted();
+ actionStateStore.put(key, sequenceNum, action, event, actionState);
+ context.clearDurableExecutionContext();
+ }
+
+ /**
+ * Sets up the durable execution context for fine-grained recovery.
+ *
+ * This method initializes the runner context with a {@link
+ * RunnerContextImpl.DurableExecutionContext}, which enables execute/execute_async calls to:
+ *
+ *
+ * Skip re-execution for already completed calls during recovery
+ * Persist CallRecords after each code block completion
+ *
+ */
+ void setupDurableExecutionContext(
+ ActionTask actionTask,
+ ActionState actionState,
+ long sequenceNumber,
+ Map actionTaskDurableContexts) {
+ if (actionStateStore == null) {
+ return;
+ }
+
+ RunnerContextImpl.DurableExecutionContext durableContext =
+ actionTaskDurableContexts.get(actionTask);
+ if (durableContext == null) {
+ durableContext =
+ new RunnerContextImpl.DurableExecutionContext(
+ actionTask.getKey(),
+ sequenceNumber,
+ actionTask.action,
+ actionTask.event,
+ actionState,
+ this);
+ }
+
+ actionTask.getRunnerContext().setDurableExecutionContext(durableContext);
+ }
+
+ @Override
+ public void persist(
+ Object key, long sequenceNumber, Action action, Event event, ActionState actionState) {
+ try {
+ actionStateStore.put(key, sequenceNumber, action, event, actionState);
+ } catch (Exception e) {
+ LOG.error("Failed to persist ActionState", e);
+ throw new RuntimeException("Failed to persist ActionState", e);
+ }
+ }
+
+ void maybePruneState(Object key, long sequenceNum) throws Exception {
+ if (actionStateStore != null) {
+ actionStateStore.pruneState(key, sequenceNum);
+ }
+ }
+
+ /** Records sequence numbers per key at snapshot time for later pruning. */
+ void recordCheckpointSeqNums(long checkpointId, Map keyToSeqNum) {
+ checkpointIdToSeqNums.put(checkpointId, keyToSeqNum);
+ }
+
+ /** Prunes state for completed checkpoints. */
+ void notifyCheckpointComplete(long checkpointId) throws Exception {
+ Map keyToSeqNum = checkpointIdToSeqNums.remove(checkpointId);
+ if (actionStateStore != null && keyToSeqNum != null) {
+ for (Map.Entry entry : keyToSeqNum.entrySet()) {
+ actionStateStore.pruneState(entry.getKey(), entry.getValue());
+ }
+ }
+ }
+
+ /** Gets recovery marker from the action state store, if available. */
+ Object getRecoveryMarker() {
+ return actionStateStore != null ? actionStateStore.getRecoveryMarker() : null;
+ }
+
+ void close() throws Exception {
+ if (actionStateStore != null) {
+ actionStateStore.close();
+ }
+ }
+}
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java
new file mode 100644
index 00000000..47c063ca
--- /dev/null
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.operator;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.api.EventContext;
+import org.apache.flink.agents.api.InputEvent;
+import org.apache.flink.agents.api.OutputEvent;
+import org.apache.flink.agents.api.listener.EventListener;
+import org.apache.flink.agents.api.logger.EventLogger;
+import org.apache.flink.agents.api.logger.EventLoggerConfig;
+import org.apache.flink.agents.api.logger.EventLoggerFactory;
+import org.apache.flink.agents.api.logger.EventLoggerOpenParams;
+import org.apache.flink.agents.plan.AgentPlan;
+import org.apache.flink.agents.plan.actions.Action;
+import org.apache.flink.agents.runtime.eventlog.FileEventLogger;
+import org.apache.flink.agents.runtime.metrics.BuiltInMetrics;
+import org.apache.flink.agents.runtime.operator.queue.SegmentedQueue;
+import org.apache.flink.agents.runtime.python.event.PythonEvent;
+import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
+import org.apache.flink.agents.runtime.utils.EventUtil;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.types.Row;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.flink.agents.api.configuration.AgentConfigOptions.BASE_LOG_DIR;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Handles event wrapping, unwrapping, routing to actions, notification (event logger and
+ * listeners), and watermark management via a {@link SegmentedQueue}.
+ */
+class EventRouter {
+
+ private final AgentPlan agentPlan;
+ private final boolean inputIsJava;
+ private final EventLogger eventLogger;
+ private final List eventListeners;
+ private final SegmentedQueue keySegmentQueue;
+ private BuiltInMetrics builtInMetrics;
+ private PythonActionExecutor pythonActionExecutor;
+
+ EventRouter(AgentPlan agentPlan, boolean inputIsJava) {
+ this.agentPlan = agentPlan;
+ this.inputIsJava = inputIsJava;
+ this.eventLogger = createEventLogger(agentPlan);
+ this.eventListeners = new ArrayList<>();
+ this.keySegmentQueue = new SegmentedQueue();
+ }
+
+ SegmentedQueue getKeySegmentQueue() {
+ return keySegmentQueue;
+ }
+
+ void setBuiltInMetrics(BuiltInMetrics builtInMetrics) {
+ this.builtInMetrics = builtInMetrics;
+ }
+
+ void setPythonActionExecutor(PythonActionExecutor pythonActionExecutor) {
+ this.pythonActionExecutor = pythonActionExecutor;
+ }
+
+ /** Initializes the event logger if it is set. */
+ void initEventLogger(StreamingRuntimeContext runtimeContext) throws Exception {
+ if (eventLogger == null) {
+ return;
+ }
+ eventLogger.open(new EventLoggerOpenParams(runtimeContext));
+ }
+
+ /** Notifies event logger and listeners that an event has been processed. */
+ void notifyEventProcessed(Event event) throws Exception {
+ EventContext eventContext = new EventContext(event);
+ if (eventLogger != null) {
+ eventLogger.append(eventContext, event);
+ eventLogger.flush();
+ }
+ for (EventListener listener : eventListeners) {
+ listener.onEventProcessed(eventContext, event);
+ }
+ builtInMetrics.markEventProcessed();
+ }
+
+ /** Wraps raw input into an InputEvent (Java or Python). */
+ @SuppressWarnings("unchecked")
+ Event wrapToInputEvent(IN input) {
+ if (inputIsJava) {
+ return new InputEvent(input);
+ } else {
+ checkState(input instanceof Row && ((Row) input).getArity() == 2);
+ return pythonActionExecutor.wrapToInputEvent(((Row) input).getField(1));
+ }
+ }
+
+ /** Extracts output data from an OutputEvent (Java or Python). */
+ @SuppressWarnings("unchecked")
+ OUT getOutputFromOutputEvent(Event event) {
+ checkState(EventUtil.isOutputEvent(event));
+ if (event instanceof OutputEvent) {
+ return (OUT) ((OutputEvent) event).getOutput();
+ } else if (event instanceof PythonEvent) {
+ return (OUT)
+ pythonActionExecutor.getOutputFromOutputEvent(((PythonEvent) event).getEvent());
+ } else {
+ throw new IllegalStateException(
+ "Unsupported event type: " + event.getClass().getName());
+ }
+ }
+
+ /** Gets actions triggered by a given event from the agent plan. */
+ List getActionsTriggeredBy(Event event) {
+ if (event instanceof PythonEvent) {
+ return agentPlan.getActionsTriggeredBy(((PythonEvent) event).getEventType());
+ } else {
+ return agentPlan.getActionsTriggeredBy(event.getClass().getName());
+ }
+ }
+
+ /** Processes and emits all eligible watermarks from the segmented queue. */
+ void processEligibleWatermarks(WatermarkEmitter emitter) throws Exception {
+ Watermark mark = keySegmentQueue.popOldestWatermark();
+ while (mark != null) {
+ emitter.emit(mark);
+ mark = keySegmentQueue.popOldestWatermark();
+ }
+ }
+
+ void close() throws Exception {
+ if (eventLogger != null) {
+ eventLogger.close();
+ }
+ }
+
+ private static EventLogger createEventLogger(AgentPlan agentPlan) {
+ EventLoggerConfig.Builder loggerConfigBuilder = EventLoggerConfig.builder();
+ String baseLogDir = agentPlan.getConfig().get(BASE_LOG_DIR);
+ if (baseLogDir != null && !baseLogDir.trim().isEmpty()) {
+ loggerConfigBuilder.property(FileEventLogger.BASE_LOG_DIR_PROPERTY_KEY, baseLogDir);
+ }
+ return EventLoggerFactory.createLogger(loggerConfigBuilder.build());
+ }
+
+ /** Callback interface for emitting watermarks through the operator. */
+ @FunctionalInterface
+ interface WatermarkEmitter {
+ void emit(Watermark watermark) throws Exception;
+ }
+}
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java
new file mode 100644
index 00000000..b668804a
--- /dev/null
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.operator;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.runtime.memory.MemoryObjectImpl;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.VoidNamespace;
+import org.apache.flink.runtime.state.VoidNamespaceSerializer;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import static org.apache.flink.agents.runtime.utils.StateUtil.*;
+
+/**
+ * Manages all Flink state objects used by the {@link ActionExecutionOperator}: keyed states for
+ * memory, action tasks, pending input events, sequence numbers, and operator states for processing
+ * keys and recovery markers.
+ */
+class OperatorStateManager {
+
+ static final String RECOVERY_MARKER_STATE_NAME = "recoveryMarker";
+ static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME = "messageSequenceNumber";
+ static final String PENDING_INPUT_EVENT_STATE_NAME = "pendingInputEvents";
+
+ private MapState sensoryMemState;
+ private MapState shortTermMemState;
+ private ListState actionTasksKState;
+ private ListState pendingInputEventsKState;
+ private ListState currentProcessingKeysOpState;
+ private ValueState sequenceNumberKState;
+ private ListState recoveryMarkerOpState;
+
+ // --- State Initialization ---
+
+ /**
+ * Initializes all Flink state objects. Should be called from the operator's {@code open()}
+ * method.
+ */
+ void initializeStates(
+ StreamingRuntimeContext runtimeContext,
+ OperatorStateBackend operatorStateBackend,
+ boolean hasActionStateStore)
+ throws Exception {
+ sensoryMemState =
+ runtimeContext.getMapState(
+ new MapStateDescriptor<>(
+ "sensoryMemory",
+ TypeInformation.of(String.class),
+ TypeInformation.of(MemoryObjectImpl.MemoryItem.class)));
+
+ shortTermMemState =
+ runtimeContext.getMapState(
+ new MapStateDescriptor<>(
+ "shortTermMemory",
+ TypeInformation.of(String.class),
+ TypeInformation.of(MemoryObjectImpl.MemoryItem.class)));
+
+ if (hasActionStateStore) {
+ recoveryMarkerOpState =
+ operatorStateBackend.getUnionListState(
+ new ListStateDescriptor<>(
+ RECOVERY_MARKER_STATE_NAME, TypeInformation.of(Object.class)));
+ }
+
+ sequenceNumberKState =
+ runtimeContext.getState(
+ new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class));
+
+ actionTasksKState =
+ runtimeContext.getListState(
+ new ListStateDescriptor<>(
+ "actionTasks", TypeInformation.of(ActionTask.class)));
+ pendingInputEventsKState =
+ runtimeContext.getListState(
+ new ListStateDescriptor<>(
+ PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)));
+
+ // We use UnionList here to ensure that the task can access all keys after parallelism
+ // modifications.
+ currentProcessingKeysOpState =
+ operatorStateBackend.getUnionListState(
+ new ListStateDescriptor<>(
+ "currentProcessingKeys", TypeInformation.of(Object.class)));
+ }
+
+ // --- Memory State Accessors ---
+
+ MapState getSensoryMemState() {
+ return sensoryMemState;
+ }
+
+ MapState getShortTermMemState() {
+ return shortTermMemState;
+ }
+
+ // --- Action Task State ---
+
+ ActionTask pollNextActionTask() throws Exception {
+ return pollFromListState(actionTasksKState);
+ }
+
+ void addActionTask(ActionTask task) throws Exception {
+ actionTasksKState.add(task);
+ }
+
+ boolean hasMoreActionTasks() throws Exception {
+ return listStateNotEmpty(actionTasksKState);
+ }
+
+ // --- Pending Input Events ---
+
+ Event pollNextPendingInputEvent() throws Exception {
+ return pollFromListState(pendingInputEventsKState);
+ }
+
+ void addPendingInputEvent(Event event) throws Exception {
+ pendingInputEventsKState.add(event);
+ }
+
+ // --- Processing Key Tracking ---
+
+ void addProcessingKey(Object key) throws Exception {
+ currentProcessingKeysOpState.add(key);
+ }
+
+ int removeProcessingKey(Object key) throws Exception {
+ return removeFromListState(currentProcessingKeysOpState, key);
+ }
+
+ Iterable getProcessingKeys() throws Exception {
+ return currentProcessingKeysOpState.get();
+ }
+
+ boolean hasProcessingKeys() throws Exception {
+ return listStateNotEmpty(currentProcessingKeysOpState);
+ }
+
+ // --- Sequence Number ---
+
+ void initOrIncSequenceNumber() throws Exception {
+ Long sequenceNumber = sequenceNumberKState.value();
+ if (sequenceNumber == null) {
+ sequenceNumberKState.update(0L);
+ } else {
+ sequenceNumberKState.update(sequenceNumber + 1);
+ }
+ }
+
+ long getSequenceNumber() throws Exception {
+ return sequenceNumberKState.value();
+ }
+
+ // --- Recovery Marker ---
+
+ /**
+ * Updates the recovery marker operator state. Only valid when the action state store is enabled
+ * (i.e., recoveryMarkerOpState was initialized).
+ */
+ void updateRecoveryMarker(Object marker) throws Exception {
+ recoveryMarkerOpState.update(List.of(marker));
+ }
+
+ /**
+ * Collects recovery markers from operator state during state initialization.
+ *
+ * @return list of recovery markers
+ */
+ List collectRecoveryMarkers(
+ OperatorStateBackend operatorStateBackend, boolean hasActionStateStore)
+ throws Exception {
+ List markers = new ArrayList<>();
+ if (hasActionStateStore) {
+ ListState recoveryMarkerState =
+ operatorStateBackend.getUnionListState(
+ new ListStateDescriptor<>(
+ RECOVERY_MARKER_STATE_NAME, TypeInformation.of(Object.class)));
+ Iterable recoveryMarkers = recoveryMarkerState.get();
+ if (recoveryMarkers != null) {
+ recoveryMarkers.forEach(markers::add);
+ }
+ }
+ return markers;
+ }
+
+ /**
+ * Iterates over all keys that have pending input events and passes each key to the given
+ * consumer. Used during recovery to re-register pending keys with the segmented queue.
+ */
+ void forEachPendingInputEventKey(
+ KeyedStateBackend> keyedStateBackend, Consumer action) throws Exception {
+ keyedStateBackend.applyToAllKeys(
+ VoidNamespace.INSTANCE,
+ VoidNamespaceSerializer.INSTANCE,
+ new ListStateDescriptor<>(
+ PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)),
+ (key, state) -> state.get().forEach(event -> action.accept(key)));
+ }
+
+ /**
+ * Snapshots sequence numbers for all keys. Returns a map of key to sequence number for the
+ * given checkpoint.
+ */
+ Map snapshotSequenceNumbers(KeyedStateBackend> keyedStateBackend)
+ throws Exception {
+ HashMap keyToSeqNum = new HashMap<>();
+ keyedStateBackend.applyToAllKeys(
+ VoidNamespace.INSTANCE,
+ VoidNamespaceSerializer.INSTANCE,
+ new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class),
+ (key, state) -> keyToSeqNum.put(key, state.value()));
+ return keyToSeqNum;
+ }
+}
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java
new file mode 100644
index 00000000..4105efc4
--- /dev/null
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.operator;
+
+import org.apache.flink.agents.api.resource.Resource;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.agents.plan.AgentPlan;
+import org.apache.flink.agents.plan.PythonFunction;
+import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider;
+import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment;
+import org.apache.flink.agents.runtime.env.PythonEnvironmentManager;
+import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
+import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl;
+import org.apache.flink.agents.runtime.python.utils.JavaResourceAdapter;
+import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
+import org.apache.flink.agents.runtime.python.utils.PythonResourceAdapterImpl;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.cache.DistributedCache;
+import org.apache.flink.python.env.PythonDependencyInfo;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import pemja.core.PythonInterpreter;
+
+import java.util.HashMap;
+import java.util.function.BiFunction;
+
+/**
+ * Manages the Python execution environment, interpreter, action executor, and resource adapters.
+ *
+ * This class is only active when the agent plan contains Python actions or Python resources.
+ * When no Python code is present, all fields remain null and {@link #getPythonActionExecutor()}
+ * returns null.
+ */
+class PythonBridgeManager {
+
+ private static final Logger LOG = LoggerFactory.getLogger(PythonBridgeManager.class);
+
+ private PythonEnvironmentManager pythonEnvironmentManager;
+ private PythonInterpreter pythonInterpreter;
+ private PythonActionExecutor pythonActionExecutor;
+ private PythonRunnerContextImpl pythonRunnerContext;
+
+ /** Returns the Python action executor, or null if no Python actions exist. */
+ PythonActionExecutor getPythonActionExecutor() {
+ return pythonActionExecutor;
+ }
+
+ /** Returns the Python runner context, or null if no Python components exist. */
+ PythonRunnerContextImpl getPythonRunnerContext() {
+ return pythonRunnerContext;
+ }
+
+ /**
+ * Initializes the Python environment, interpreter, action executor, and resource adapters if
+ * the agent plan contains Python actions or resources.
+ */
+ void initPythonEnvironment(
+ AgentPlan agentPlan,
+ ExecutionConfig executionConfig,
+ DistributedCache distributedCache,
+ String[] tmpDirectories,
+ JobID jobId,
+ FlinkAgentsMetricGroupImpl metricGroup,
+ Runnable mailboxThreadChecker,
+ String jobIdentifier)
+ throws Exception {
+ boolean containPythonAction =
+ agentPlan.getActions().values().stream()
+ .anyMatch(action -> action.getExec() instanceof PythonFunction);
+
+ boolean containPythonResource =
+ agentPlan.getResourceProviders().values().stream()
+ .anyMatch(
+ resourceProviderMap ->
+ resourceProviderMap.values().stream()
+ .anyMatch(
+ resourceProvider ->
+ resourceProvider
+ instanceof
+ PythonResourceProvider));
+
+ if (!containPythonAction && !containPythonResource) {
+ return;
+ }
+
+ LOG.debug("Begin initialize PythonEnvironmentManager.");
+ PythonDependencyInfo dependencyInfo =
+ PythonDependencyInfo.create(executionConfig.toConfiguration(), distributedCache);
+ pythonEnvironmentManager =
+ new PythonEnvironmentManager(
+ dependencyInfo, tmpDirectories, new HashMap<>(System.getenv()), jobId);
+ pythonEnvironmentManager.open();
+ EmbeddedPythonEnvironment env = pythonEnvironmentManager.createEnvironment();
+ pythonInterpreter = env.getInterpreter();
+ pythonRunnerContext =
+ new PythonRunnerContextImpl(
+ metricGroup, mailboxThreadChecker, agentPlan, jobIdentifier);
+
+ BiFunction resourceResolver =
+ (name, type) -> getResource(agentPlan, name, type);
+ JavaResourceAdapter javaResourceAdapter =
+ new JavaResourceAdapter(resourceResolver, pythonInterpreter);
+ if (containPythonResource) {
+ PythonResourceAdapterImpl pythonResourceAdapter =
+ new PythonResourceAdapterImpl(
+ resourceResolver, pythonInterpreter, javaResourceAdapter);
+ pythonResourceAdapter.open();
+ agentPlan.setPythonResourceAdapter(pythonResourceAdapter);
+ }
+ if (containPythonAction) {
+ pythonActionExecutor =
+ new PythonActionExecutor(
+ pythonInterpreter,
+ agentPlan,
+ javaResourceAdapter,
+ pythonRunnerContext,
+ jobIdentifier);
+ pythonActionExecutor.open();
+ }
+ }
+
+ private static Resource getResource(AgentPlan agentPlan, String name, ResourceType type) {
+ try {
+ return agentPlan.getResource(name, type);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /** Closes all Python-related resources in the correct order. */
+ void close() throws Exception {
+ if (pythonActionExecutor != null) {
+ pythonActionExecutor.close();
+ }
+ if (pythonInterpreter != null) {
+ pythonInterpreter.close();
+ }
+ if (pythonEnvironmentManager != null) {
+ pythonEnvironmentManager.close();
+ }
+ }
+}
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
index 646d1e63..3add3e35 100644
--- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
@@ -184,12 +184,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)),
ActionExecutionOperator operator =
(ActionExecutionOperator) testHarness.getOperator();
- // Use reflection to access the action state store for validation
- Field actionStateStoreField =
- ActionExecutionOperator.class.getDeclaredField("actionStateStore");
- actionStateStoreField.setAccessible(true);
InMemoryActionStateStore actionStateStore =
- (InMemoryActionStateStore) actionStateStoreField.get(operator);
+ (InMemoryActionStateStore)
+ operator.getDurableExecutionManager().getActionStateStore();
assertThat(actionStateStore).isNotNull();
assertThat(actionStateStore.getKeyedActionStates()).isEmpty();
@@ -244,9 +241,10 @@ void testEventLogBaseDirFromAgentConfig() throws Exception {
testHarness.open();
ActionExecutionOperator operator =
(ActionExecutionOperator) testHarness.getOperator();
- Field eventLoggerField = ActionExecutionOperator.class.getDeclaredField("eventLogger");
+ EventRouter eventRouterObj = operator.getEventRouter();
+ Field eventLoggerField = EventRouter.class.getDeclaredField("eventLogger");
eventLoggerField.setAccessible(true);
- Object eventLogger = eventLoggerField.get(operator);
+ Object eventLogger = eventLoggerField.get(eventRouterObj);
assertThat(eventLogger).isInstanceOf(FileEventLogger.class);
Field configField = FileEventLogger.class.getDeclaredField("config");
@@ -276,12 +274,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)),
ActionExecutionOperator operator =
(ActionExecutionOperator) testHarness.getOperator();
- // Use reflection to access the action state store for validation
- Field actionStateStoreField =
- ActionExecutionOperator.class.getDeclaredField("actionStateStore");
- actionStateStoreField.setAccessible(true);
InMemoryActionStateStore actionStateStore =
- (InMemoryActionStateStore) actionStateStoreField.get(operator);
+ (InMemoryActionStateStore)
+ operator.getDurableExecutionManager().getActionStateStore();
Long inputValue = 3L;
testHarness.processElement(new StreamRecord<>(inputValue));
@@ -352,11 +347,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)),
(ActionExecutionOperator) testHarness.getOperator();
// Access the action state store
- java.lang.reflect.Field actionStateStoreField =
- ActionExecutionOperator.class.getDeclaredField("actionStateStore");
- actionStateStoreField.setAccessible(true);
InMemoryActionStateStore actionStateStore =
- (InMemoryActionStateStore) actionStateStoreField.get(operator);
+ (InMemoryActionStateStore)
+ operator.getDurableExecutionManager().getActionStateStore();
// Process multiple elements with same key to test state persistence
testHarness.processElement(new StreamRecord<>(1L));
@@ -421,11 +414,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(true)),
assertThat(recordOutput.size()).isEqualTo(3);
// Access the action state store
- Field actionStateStoreField =
- ActionExecutionOperator.class.getDeclaredField("actionStateStore");
- actionStateStoreField.setAccessible(true);
InMemoryActionStateStore actionStateStore =
- (InMemoryActionStateStore) actionStateStoreField.get(operator);
+ (InMemoryActionStateStore)
+ operator.getDurableExecutionManager().getActionStateStore();
assertThat(actionStateStore.getKeyedActionStates()).isEmpty();
}
}
@@ -445,10 +436,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)),
(ActionExecutionOperator) testHarness.getOperator();
// Access the action state store
- Field actionStateStoreField =
- ActionExecutionOperator.class.getDeclaredField("actionStateStore");
- actionStateStoreField.setAccessible(true);
- actionStateStore = (InMemoryActionStateStore) actionStateStoreField.get(operator);
+ actionStateStore =
+ (InMemoryActionStateStore)
+ operator.getDurableExecutionManager().getActionStateStore();
Long inputValue = 7L;