diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index f4e58efe..80df7850 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -18,201 +18,95 @@ package org.apache.flink.agents.runtime.operator; import org.apache.flink.agents.api.Event; -import org.apache.flink.agents.api.EventContext; -import org.apache.flink.agents.api.InputEvent; -import org.apache.flink.agents.api.OutputEvent; -import org.apache.flink.agents.api.agents.AgentExecutionOptions; import org.apache.flink.agents.api.context.MemoryUpdate; -import org.apache.flink.agents.api.listener.EventListener; -import org.apache.flink.agents.api.logger.EventLogger; -import org.apache.flink.agents.api.logger.EventLoggerConfig; -import org.apache.flink.agents.api.logger.EventLoggerFactory; -import org.apache.flink.agents.api.logger.EventLoggerOpenParams; -import org.apache.flink.agents.api.resource.Resource; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.plan.AgentPlan; import org.apache.flink.agents.plan.JavaFunction; import org.apache.flink.agents.plan.PythonFunction; import org.apache.flink.agents.plan.actions.Action; -import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; import org.apache.flink.agents.runtime.actionstate.ActionState; import org.apache.flink.agents.runtime.actionstate.ActionStateStore; -import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore; -import org.apache.flink.agents.runtime.async.ContinuationActionExecutor; -import org.apache.flink.agents.runtime.async.ContinuationContext; -import org.apache.flink.agents.runtime.context.ActionStatePersister; -import org.apache.flink.agents.runtime.context.JavaRunnerContextImpl; 
-import org.apache.flink.agents.runtime.context.RunnerContextImpl; -import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; -import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; -import org.apache.flink.agents.runtime.eventlog.FileEventLogger; -import org.apache.flink.agents.runtime.memory.CachedMemoryStore; -import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; import org.apache.flink.agents.runtime.metrics.BuiltInMetrics; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; -import org.apache.flink.agents.runtime.operator.queue.SegmentedQueue; -import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; -import org.apache.flink.agents.runtime.python.event.PythonEvent; import org.apache.flink.agents.runtime.python.operator.PythonActionTask; -import org.apache.flink.agents.runtime.python.utils.JavaResourceAdapter; -import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; -import org.apache.flink.agents.runtime.python.utils.PythonResourceAdapterImpl; import org.apache.flink.agents.runtime.utils.EventUtil; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.operators.MailboxExecutor; -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.state.MapState; -import org.apache.flink.api.common.state.MapStateDescriptor; -import org.apache.flink.api.common.state.ValueState; -import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.python.env.PythonDependencyInfo; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.runtime.state.VoidNamespace; -import org.apache.flink.runtime.state.VoidNamespaceSerializer; import 
org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.BoundedOneInput; import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.Output; -import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxExecutorImpl; import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxProcessor; -import org.apache.flink.types.Row; import org.apache.flink.util.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import pemja.core.PythonInterpreter; import java.lang.reflect.Field; -import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND; -import static org.apache.flink.agents.api.configuration.AgentConfigOptions.BASE_LOG_DIR; import static org.apache.flink.agents.api.configuration.AgentConfigOptions.JOB_IDENTIFIER; -import static org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA; -import static org.apache.flink.agents.runtime.utils.StateUtil.*; import static org.apache.flink.util.Preconditions.checkState; /** * An operator that executes the actions defined in the agent. Upon receiving data from the - * upstream, it first wraps the data into an {@link InputEvent}. 
It then invokes the corresponding - * action that is interested in the {@link InputEvent}, and collects the output event produced by - * the action. + * upstream, it first wraps the data into an {@link org.apache.flink.agents.api.InputEvent}. It then + * invokes the corresponding action that is interested in the {@link + * org.apache.flink.agents.api.InputEvent}, and collects the output event produced by the action. * - *

<p>For events of type {@link OutputEvent}, the data contained in the event is sent downstream. - * For all other event types, the process is repeated: the event triggers the corresponding action, - * and the resulting output event is collected for further processing. + *

<p>For events of type {@link org.apache.flink.agents.api.OutputEvent}, the data contained in the + * event is sent downstream. For all other event types, the process is repeated: the event triggers + * the corresponding action, and the resulting output event is collected for further processing. + * + *

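<p>This operator delegates to the following package-private managers: + * + * <ul> + * <li>{@link PythonBridgeManager} - initializes the Python environment and provides the Python action executor and Python runner context. + * <li>{@link DurableExecutionManager} - wraps the action state store: initializes it, loads, persists and prunes action state, and handles recovery markers and per-checkpoint sequence numbers. + * <li>{@link ActionTaskContextManager} - creates the runner context for an action task and transfers or removes the contexts associated with action tasks. + * <li>{@link EventRouter} - wraps input into events, resolves the actions triggered by an event, extracts output from output events, notifies event processing, and processes watermarks via the key segment queue. + * <li>{@link OperatorStateManager} - initializes and accesses the Flink state used by the operator (action tasks, pending input events, processing keys, sequence numbers, memory state and recovery markers). + * </ul>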
*/ public class ActionExecutionOperator extends AbstractStreamOperator - implements OneInputStreamOperator, BoundedOneInput, ActionStatePersister { + implements OneInputStreamOperator, BoundedOneInput { private static final long serialVersionUID = 1L; private static final Logger LOG = LoggerFactory.getLogger(ActionExecutionOperator.class); - private static final String RECOVERY_MARKER_STATE_NAME = "recoveryMarker"; - private static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME = "messageSequenceNumber"; - private static final String PENDING_INPUT_EVENT_STATE_NAME = "pendingInputEvents"; - private final AgentPlan agentPlan; private final Boolean inputIsJava; private transient StreamRecord reusedStreamRecord; - private transient MapState sensoryMemState; - - private transient MapState shortTermMemState; - - private transient PythonEnvironmentManager pythonEnvironmentManager; - - private transient PythonInterpreter pythonInterpreter; - - // PythonActionExecutor for Python actions - private transient PythonActionExecutor pythonActionExecutor; - - // RunnerContext for Python actions - private transient PythonRunnerContextImpl pythonRunnerContext; - - // PythonResourceAdapter for Python resources in Java actions - private transient PythonResourceAdapterImpl pythonResourceAdapter; - - // PythonResourceAdapter for Java resources in Python actions or Python resources - private transient JavaResourceAdapter javaResourceAdapter; - - private transient FlinkAgentsMetricGroupImpl metricGroup; - - private transient BuiltInMetrics builtInMetrics; - - private transient SegmentedQueue keySegmentQueue; - private final transient MailboxExecutor mailboxExecutor; - // RunnerContext for Java Actions - private transient RunnerContextImpl runnerContext; - - // We need to check whether the current thread is the mailbox thread using the mailbox - // processor. - // TODO: This is a temporary workaround. 
In the future, we should add an interface in - // MailboxExecutor to check whether a thread is a mailbox thread, rather than using reflection - // to obtain the MailboxProcessor instance and make the determination. private transient MailboxProcessor mailboxProcessor; - // An action will be split into one or more ActionTask objects. We use a state to store the - // pending ActionTasks that are waiting to be executed. - private transient ListState actionTasksKState; - - // To avoid processing different InputEvents with the same key, we use a state to store pending - // InputEvents that are waiting to be processed. - private transient ListState pendingInputEventsKState; - - // An operator state is used to track the currently processing keys. This is useful when - // receiving an EndOfInput signal, as we need to wait until all related events are fully - // processed. - private transient ListState currentProcessingKeysOpState; - - private final transient EventLogger eventLogger; - private final transient List eventListeners; - - private transient ActionStateStore actionStateStore; - private transient ValueState sequenceNumberKState; - private transient ListState recoveryMarkerOpState; - private transient Map> checkpointIdToSeqNums; - - // This in memory map keep track of the runner context for the async action task that having - // been finished - private final transient Map - actionTaskMemoryContexts; - - // This in memory map keeps track of the durable execution context for async action tasks - // that have not been finished, allowing recovery of currentCallIndex across invocations - private final transient Map - actionTaskDurableContexts; - - private final transient Map continuationContexts; + private transient FlinkAgentsMetricGroupImpl metricGroup; - private final transient Map pythonAwaitableRefs; + private transient BuiltInMetrics builtInMetrics; // Each job can only have one identifier and this identifier must be consistent across restarts. 
- // We cannot use job id as the identifier here because user may change job id by - // creating a savepoint, stop the job and then resume from savepoint. - // We use this identifier to control the visibility for long-term memory. - // Inspired by Apache Paimon. private transient String jobIdentifier; - private transient ContinuationActionExecutor continuationActionExecutor; + // Managers + private transient PythonBridgeManager pythonBridgeManager; + private transient DurableExecutionManager durableExecutionManager; + private transient ActionTaskContextManager contextManager; + private transient EventRouter eventRouter; + private transient OperatorStateManager stateManager; public ActionExecutionOperator( AgentPlan agentPlan, @@ -224,14 +118,9 @@ public ActionExecutionOperator( this.inputIsJava = inputIsJava; this.processingTimeService = processingTimeService; this.mailboxExecutor = mailboxExecutor; - this.eventLogger = createEventLogger(agentPlan); - this.eventListeners = new ArrayList<>(); - this.actionStateStore = actionStateStore; - this.checkpointIdToSeqNums = new HashMap<>(); - this.actionTaskMemoryContexts = new HashMap<>(); - this.actionTaskDurableContexts = new HashMap<>(); - this.continuationContexts = new HashMap<>(); - this.pythonAwaitableRefs = new HashMap<>(); + this.durableExecutionManager = new DurableExecutionManager(actionStateStore); + this.eventRouter = new EventRouter(agentPlan, inputIsJava); + this.stateManager = new OperatorStateManager(); OperatorUtils.setChainStrategy(this, ChainingStrategy.ALWAYS); } @@ -247,79 +136,44 @@ public void setup( public void open() throws Exception { super.open(); reusedStreamRecord = new StreamRecord<>(null); - // init sensoryMemState - MapStateDescriptor sensoryMemStateDescriptor = - new MapStateDescriptor<>( - "sensoryMemory", - TypeInformation.of(String.class), - TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); - sensoryMemState = getRuntimeContext().getMapState(sensoryMemStateDescriptor); - - // init 
shortTermMemState - MapStateDescriptor shortTermMemStateDescriptor = - new MapStateDescriptor<>( - "shortTermMemory", - TypeInformation.of(String.class), - TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); - shortTermMemState = getRuntimeContext().getMapState(shortTermMemStateDescriptor); - metricGroup = new FlinkAgentsMetricGroupImpl(getMetricGroup()); - builtInMetrics = new BuiltInMetrics(metricGroup, agentPlan); + durableExecutionManager.maybeInitActionStateStore(agentPlan); - keySegmentQueue = new SegmentedQueue(); + stateManager.initializeStates( + getRuntimeContext(), + getOperatorStateBackend(), + durableExecutionManager.getActionStateStore() != null); - maybeInitActionStateStore(); - - if (actionStateStore != null) { - // init recovery marker state for recovery marker persistence - recoveryMarkerOpState = - getOperatorStateBackend() - .getUnionListState( - new ListStateDescriptor<>( - RECOVERY_MARKER_STATE_NAME, - TypeInformation.of(Object.class))); - } - // init sequence number state for per key message ordering - sequenceNumberKState = - getRuntimeContext() - .getState( - new ValueStateDescriptor<>( - MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class)); - - // init agent processing related state - actionTasksKState = - getRuntimeContext() - .getListState( - new ListStateDescriptor<>( - "actionTasks", TypeInformation.of(ActionTask.class))); - pendingInputEventsKState = - getRuntimeContext() - .getListState( - new ListStateDescriptor<>( - PENDING_INPUT_EVENT_STATE_NAME, - TypeInformation.of(Event.class))); - // We use UnionList here to ensure that the task can access all keys after parallelism - // modifications. - // Subsequent steps {@link #tryResumeProcessActionTasks} will then filter out keys that do - // not belong to the key range of current task. 
- currentProcessingKeysOpState = - getOperatorStateBackend() - .getUnionListState( - new ListStateDescriptor<>( - "currentProcessingKeys", TypeInformation.of(Object.class))); + metricGroup = new FlinkAgentsMetricGroupImpl(getMetricGroup()); + builtInMetrics = new BuiltInMetrics(metricGroup, agentPlan); + eventRouter.setBuiltInMetrics(builtInMetrics); // init PythonActionExecutor and PythonResourceAdapter - initPythonEnvironment(); - - // init executor for Java async execution - continuationActionExecutor = - new ContinuationActionExecutor( - agentPlan.getConfig().get(AgentExecutionOptions.NUM_ASYNC_THREADS)); + pythonBridgeManager = new PythonBridgeManager(); + pythonBridgeManager.initPythonEnvironment( + agentPlan, + getExecutionConfig(), + getRuntimeContext().getDistributedCache(), + getContainingTask().getEnvironment().getTaskManagerInfo().getTmpDirectories(), + getRuntimeContext().getJobInfo().getJobId(), + metricGroup, + this::checkMailboxThread, + jobIdentifier); + eventRouter.setPythonActionExecutor(pythonBridgeManager.getPythonActionExecutor()); + + // init context manager for runner context lifecycle + contextManager = + new ActionTaskContextManager( + metricGroup, + this::checkMailboxThread, + agentPlan, + jobIdentifier, + pythonBridgeManager.getPythonRunnerContext()); mailboxProcessor = getMailboxProcessor(); // Initialize the event logger if it is set. 
- initEventLogger(getRuntimeContext()); + eventRouter.initEventLogger(getRuntimeContext()); // Since an operator restart may change the key range it manages due to changes in // parallelism, @@ -328,17 +182,10 @@ public void open() throws Exception { tryResumeProcessActionTasks(); } - private void initEventLogger(StreamingRuntimeContext runtimeContext) throws Exception { - if (eventLogger == null) { - return; - } - eventLogger.open(new EventLoggerOpenParams(runtimeContext)); - } - @Override public void processWatermark(Watermark mark) throws Exception { - keySegmentQueue.addWatermark(mark); - processEligibleWatermarks(); + eventRouter.getKeySegmentQueue().addWatermark(mark); + eventRouter.processEligibleWatermarks(super::processWatermark); } @Override @@ -346,21 +193,16 @@ public void processElement(StreamRecord record) throws Exception { IN input = record.getValue(); LOG.debug("Receive an element {}", input); - // wrap to InputEvent first - Event inputEvent = wrapToInputEvent(input); + Event inputEvent = eventRouter.wrapToInputEvent(input); if (record.hasTimestamp()) { inputEvent.setSourceTimestamp(record.getTimestamp()); } - keySegmentQueue.addKeyToLastSegment(getCurrentKey()); + eventRouter.getKeySegmentQueue().addKeyToLastSegment(getCurrentKey()); - if (currentKeyHasMoreActionTask()) { - // If there are already actions being processed for the current key, the newly incoming - // event should be queued and processed later. Therefore, we add it to - // pendingInputEventsState. - pendingInputEventsKState.add(inputEvent); + if (stateManager.hasMoreActionTasks()) { + stateManager.addPendingInputEvent(inputEvent); } else { - // Otherwise, the new event is processed immediately. processEvent(getCurrentKey(), inputEvent); } } @@ -370,12 +212,11 @@ public void processElement(StreamRecord record) throws Exception { * `tryProcessActionTaskForKey` to continue processing. 
*/ private void processEvent(Object key, Event event) throws Exception { - notifyEventProcessed(event); + eventRouter.notifyEventProcessed(event); boolean isInputEvent = EventUtil.isInputEvent(event); if (EventUtil.isOutputEvent(event)) { - // If the event is an OutputEvent, we send it downstream. - OUT outputData = getOutputFromOutputEvent(event); + OUT outputData = eventRouter.getOutputFromOutputEvent(event); if (event.hasSourceTimestamp()) { output.collect(reusedStreamRecord.replace(outputData, event.getSourceTimestamp())); } else { @@ -384,45 +225,22 @@ private void processEvent(Object key, Event event) throws Exception { } } else { if (isInputEvent) { - // If the event is an InputEvent, we mark that the key is currently being processed. - currentProcessingKeysOpState.add(key); - initOrIncSequenceNumber(); + stateManager.addProcessingKey(key); + stateManager.initOrIncSequenceNumber(); } - // We then obtain the triggered action and add ActionTasks to the waiting processing - // queue. - List triggerActions = getActionsTriggeredBy(event); + List triggerActions = eventRouter.getActionsTriggeredBy(event); if (triggerActions != null && !triggerActions.isEmpty()) { for (Action triggerAction : triggerActions) { - actionTasksKState.add(createActionTask(key, triggerAction, event)); + stateManager.addActionTask(createActionTask(key, triggerAction, event)); } } } if (isInputEvent) { - // If the event is an InputEvent, we submit a new mail to try processing the actions. mailboxExecutor.submit(() -> tryProcessActionTaskForKey(key), "process action task"); } } - private void notifyEventProcessed(Event event) throws Exception { - EventContext eventContext = new EventContext(event); - if (eventLogger != null) { - // If event logging is enabled, we log the event along with its context. - eventLogger.append(eventContext, event); - // For now, we flush the event logger after each event to ensure immediate logging. 
- // This is a temporary solution to ensure that events are logged immediately. - // TODO: In the future, we may want to implement a more efficient batching mechanism. - eventLogger.flush(); - } - if (eventListeners != null) { - // Notify all registered event listeners about the event. - for (EventListener listener : eventListeners) { - listener.onEventProcessed(eventContext, event); - } - } - builtInMetrics.markEventProcessed(); - } - private void tryProcessActionTaskForKey(Object key) { try { processActionTaskForKey(key); @@ -440,9 +258,9 @@ private void processActionTaskForKey(Object key) throws Exception { // 1. Get an action task for the key. setCurrentKey(key); - ActionTask actionTask = pollFromListState(actionTasksKState); + ActionTask actionTask = stateManager.pollNextActionTask(); if (actionTask == null) { - int removedCount = removeFromListState(currentProcessingKeysOpState, key); + int removedCount = stateManager.removeProcessingKey(key); checkState( removedCount == 1, "Current processing key count for key " @@ -450,25 +268,29 @@ private void processActionTaskForKey(Object key) throws Exception { + " should be 1, but got " + removedCount); checkState( - keySegmentQueue.removeKey(key), + eventRouter.getKeySegmentQueue().removeKey(key), "Current key" + key + " is missing from the segmentedQueue."); - processEligibleWatermarks(); + eventRouter.processEligibleWatermarks(super::processWatermark); return; } // 2. Invoke the action task. 
- createAndSetRunnerContext(actionTask, key); + contextManager.createAndSetRunnerContext( + actionTask, + key, + stateManager.getSensoryMemState(), + stateManager.getShortTermMemState()); - long sequenceNumber = sequenceNumberKState.value(); + long sequenceNumber = stateManager.getSequenceNumber(); boolean isFinished; List outputEvents; Optional generatedActionTaskOpt = Optional.empty(); ActionState actionState = - maybeGetActionState(key, sequenceNumber, actionTask.action, actionTask.event); + durableExecutionManager.maybeGetActionState( + key, sequenceNumber, actionTask.action, actionTask.event); // Check if action is already completed if (actionState != null && actionState.isCompleted()) { - // Action has completed, skip execution and replay memory/events LOG.debug( "Skipping already completed action: {} for key: {}", actionTask.action.getName(), @@ -489,30 +311,27 @@ private void processActionTaskForKey(Object key) throws Exception { .set(memoryUpdate.getPath(), memoryUpdate.getValue()); } } else { - // Initialize ActionState if not exists, or use existing one for recovery if (actionState == null) { - maybeInitActionState(key, sequenceNumber, actionTask.action, actionTask.event); + durableExecutionManager.maybeInitActionState( + key, sequenceNumber, actionTask.action, actionTask.event); actionState = - maybeGetActionState( + durableExecutionManager.maybeGetActionState( key, sequenceNumber, actionTask.action, actionTask.event); } - // Set up durable execution context for fine-grained recovery - setupDurableExecutionContext(actionTask, actionState); + durableExecutionManager.setupDurableExecutionContext( + actionTask, + actionState, + sequenceNumber, + contextManager.getActionTaskDurableContexts()); ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke( getRuntimeContext().getUserCodeClassLoader(), - this.pythonActionExecutor); - - // We remove the contexts from the map after the task is processed. 
They will be added - // back later if the action task has a generated action task, meaning it is not - // finished. - actionTaskMemoryContexts.remove(actionTask); - actionTaskDurableContexts.remove(actionTask); - continuationContexts.remove(actionTask); - pythonAwaitableRefs.remove(actionTask); - maybePersistTaskResult( + pythonBridgeManager.getPythonActionExecutor()); + + contextManager.removeContexts(actionTask); + durableExecutionManager.maybePersistTaskResult( key, sequenceNumber, actionTask.action, @@ -531,55 +350,25 @@ private void processActionTaskForKey(Object key) throws Exception { boolean currentInputEventFinished = false; if (isFinished) { builtInMetrics.markActionExecuted(actionTask.action.getName()); - currentInputEventFinished = !currentKeyHasMoreActionTask(); + currentInputEventFinished = !stateManager.hasMoreActionTasks(); - // Persist memory to the Flink state when the action task is finished. actionTask.getRunnerContext().persistMemory(); } else { checkState( generatedActionTaskOpt.isPresent(), "ActionTask not finished, but the generated action task is null."); - // If the action task is not finished, we should get a new action task to continue the - // execution. ActionTask generatedActionTask = generatedActionTaskOpt.get(); - - // If the action task is not finished, we keep the contexts in memory for the - // next generated ActionTask to be invoked. 
- actionTaskMemoryContexts.put( - generatedActionTask, actionTask.getRunnerContext().getMemoryContext()); - RunnerContextImpl.DurableExecutionContext durableContext = - actionTask.getRunnerContext().getDurableExecutionContext(); - if (durableContext != null) { - actionTaskDurableContexts.put(generatedActionTask, durableContext); - } - if (actionTask.getRunnerContext() instanceof JavaRunnerContextImpl) { - continuationContexts.put( - generatedActionTask, - ((JavaRunnerContextImpl) actionTask.getRunnerContext()) - .getContinuationContext()); - } - if (actionTask.getRunnerContext() instanceof PythonRunnerContextImpl) { - String awaitableRef = - ((PythonRunnerContextImpl) actionTask.getRunnerContext()) - .getPythonAwaitableRef(); - if (awaitableRef != null) { - pythonAwaitableRefs.put(generatedActionTask, awaitableRef); - } - } - - actionTasksKState.add(generatedActionTask); + contextManager.transferContexts(actionTask, generatedActionTask); + stateManager.addActionTask(generatedActionTask); } // 3. Process the next InputEvent or next action task if (currentInputEventFinished) { - // Clean up sensory memory when a single run finished. actionTask.getRunnerContext().clearSensoryMemory(); - // Once all sub-events and actions related to the current InputEvent are completed, - // we can proceed to process the next InputEvent. 
- int removedCount = removeFromListState(currentProcessingKeysOpState, key); - maybePruneState(key, sequenceNumber); + int removedCount = stateManager.removeProcessingKey(key); + durableExecutionManager.maybePruneState(key, sequenceNumber); checkState( removedCount == 1, "Current processing key count for key " @@ -587,106 +376,18 @@ private void processActionTaskForKey(Object key) throws Exception { + " should be 1, but got " + removedCount); checkState( - keySegmentQueue.removeKey(key), + eventRouter.getKeySegmentQueue().removeKey(key), "Current key" + key + " is missing from the segmentedQueue."); - processEligibleWatermarks(); - Event pendingInputEvent = pollFromListState(pendingInputEventsKState); + eventRouter.processEligibleWatermarks(super::processWatermark); + Event pendingInputEvent = stateManager.pollNextPendingInputEvent(); if (pendingInputEvent != null) { processEvent(key, pendingInputEvent); } - } else if (currentKeyHasMoreActionTask()) { - // If the current key has additional action tasks remaining, we should submit a new mail - // to continue processing them. 
+ } else if (stateManager.hasMoreActionTasks()) { mailboxExecutor.submit(() -> tryProcessActionTaskForKey(key), "process action task"); } } - private Resource getResource(String name, ResourceType type) { - try { - return agentPlan.getResource(name, type); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - private void initPythonEnvironment() throws Exception { - boolean containPythonAction = - agentPlan.getActions().values().stream() - .anyMatch(action -> action.getExec() instanceof PythonFunction); - - boolean containPythonResource = - agentPlan.getResourceProviders().values().stream() - .anyMatch( - resourceProviderMap -> - resourceProviderMap.values().stream() - .anyMatch( - resourceProvider -> - resourceProvider - instanceof - PythonResourceProvider)); - - if (containPythonAction || containPythonResource) { - LOG.debug("Begin initialize PythonEnvironmentManager."); - PythonDependencyInfo dependencyInfo = - PythonDependencyInfo.create( - getExecutionConfig().toConfiguration(), - getRuntimeContext().getDistributedCache()); - pythonEnvironmentManager = - new PythonEnvironmentManager( - dependencyInfo, - getContainingTask() - .getEnvironment() - .getTaskManagerInfo() - .getTmpDirectories(), - new HashMap<>(System.getenv()), - getRuntimeContext().getJobInfo().getJobId()); - pythonEnvironmentManager.open(); - EmbeddedPythonEnvironment env = pythonEnvironmentManager.createEnvironment(); - pythonInterpreter = env.getInterpreter(); - pythonRunnerContext = - new PythonRunnerContextImpl( - this.metricGroup, - this::checkMailboxThread, - this.agentPlan, - this.jobIdentifier); - - javaResourceAdapter = new JavaResourceAdapter(this::getResource, pythonInterpreter); - if (containPythonResource) { - initPythonResourceAdapter(); - } - if (containPythonAction) { - initPythonActionExecutor(); - } - } - } - - private void initPythonActionExecutor() throws Exception { - pythonActionExecutor = - new PythonActionExecutor( - pythonInterpreter, - agentPlan, - 
javaResourceAdapter, - pythonRunnerContext, - jobIdentifier); - pythonActionExecutor.open(); - } - - private void initPythonResourceAdapter() throws Exception { - pythonResourceAdapter = - new PythonResourceAdapterImpl( - (String anotherName, ResourceType anotherType) -> { - try { - return agentPlan.getResource(anotherName, anotherType); - } catch (Exception e) { - throw new RuntimeException(e); - } - }, - pythonInterpreter, - javaResourceAdapter); - pythonResourceAdapter.open(); - agentPlan.setPythonResourceAdapter(pythonResourceAdapter); - } - @Override public void endInput() throws Exception { waitInFlightEventsFinished(); @@ -694,40 +395,24 @@ public void endInput() throws Exception { @VisibleForTesting public void waitInFlightEventsFinished() throws Exception { - while (listStateNotEmpty(currentProcessingKeysOpState)) { + while (stateManager.hasProcessingKeys()) { mailboxExecutor.yield(); } } @Override public void close() throws Exception { - if (runnerContext != null) { - try { - runnerContext.close(); - } finally { - runnerContext = null; - } - } - if (pythonActionExecutor != null) { - pythonActionExecutor.close(); - } - if (pythonInterpreter != null) { - pythonInterpreter.close(); - } - if (pythonEnvironmentManager != null) { - pythonEnvironmentManager.close(); - } - if (eventLogger != null) { - eventLogger.close(); + if (contextManager != null) { + contextManager.close(); } - if (actionStateStore != null) { - actionStateStore.close(); + if (pythonBridgeManager != null) { + pythonBridgeManager.close(); } - if (runnerContext != null) { - runnerContext.close(); + if (eventRouter != null) { + eventRouter.close(); } - if (continuationActionExecutor != null) { - continuationActionExecutor.close(); + if (durableExecutionManager != null) { + durableExecutionManager.close(); } super.close(); @@ -737,29 +422,13 @@ public void close() throws Exception { public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); 
- maybeInitActionStateStore(); - - if (actionStateStore != null) { - List markers = new ArrayList<>(); - - // We use UnionList here to ensure that the task can access all the recovery marker - // after - // parallelism modifications. - // The ActionStateStore will decide how to use the recovery markers. - ListState recoveryMarkerOpState = - getOperatorStateBackend() - .getUnionListState( - new ListStateDescriptor<>( - RECOVERY_MARKER_STATE_NAME, - TypeInformation.of(Object.class))); - - Iterable recoveryMarkers = recoveryMarkerOpState.get(); - if (recoveryMarkers != null) { - recoveryMarkers.forEach(markers::add); - } - LOG.info("Rebuilding action state from {} recovery markers", markers.size()); - actionStateStore.rebuildState(markers); - } + durableExecutionManager.maybeInitActionStateStore(agentPlan); + + List markers = + stateManager.collectRecoveryMarkers( + getOperatorStateBackend(), + durableExecutionManager.getActionStateStore() != null); + durableExecutionManager.rebuildStateFromMarkers(markers); // Get job identifier from user configuration. // If not configured, get from state. 
@@ -774,71 +443,40 @@ public void initializeState(StateInitializationContext context) throws Exception @Override public void snapshotState(StateSnapshotContext context) throws Exception { - if (actionStateStore != null) { - Object recoveryMarker = actionStateStore.getRecoveryMarker(); + if (durableExecutionManager.getActionStateStore() != null) { + Object recoveryMarker = durableExecutionManager.getRecoveryMarker(); if (recoveryMarker != null) { - recoveryMarkerOpState.update(List.of(recoveryMarker)); + stateManager.updateRecoveryMarker(recoveryMarker); } - } - HashMap keyToSeqNum = new HashMap<>(); - getKeyedStateBackend() - .applyToAllKeys( - VoidNamespace.INSTANCE, - VoidNamespaceSerializer.INSTANCE, - new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class), - (key, state) -> keyToSeqNum.put(key, state.value())); - checkpointIdToSeqNums.put(context.getCheckpointId(), keyToSeqNum); + Map keyToSeqNum = + stateManager.snapshotSequenceNumbers(getKeyedStateBackend()); + durableExecutionManager.recordCheckpointSeqNums(context.getCheckpointId(), keyToSeqNum); + } super.snapshotState(context); } @Override public void notifyCheckpointComplete(long checkpointId) throws Exception { - if (actionStateStore != null) { - Map keyToSeqNum = - checkpointIdToSeqNums.getOrDefault(checkpointId, new HashMap<>()); - for (Map.Entry entry : keyToSeqNum.entrySet()) { - actionStateStore.pruneState(entry.getKey(), entry.getValue()); - } - checkpointIdToSeqNums.remove(checkpointId); - } + durableExecutionManager.notifyCheckpointComplete(checkpointId); super.notifyCheckpointComplete(checkpointId); } - private Event wrapToInputEvent(IN input) { - if (inputIsJava) { - return new InputEvent(input); - } else { - // the input data must originate from Python and be of type Row with two fields — the - // first representing the key, and the second representing the actual data payload. 
- checkState(input instanceof Row && ((Row) input).getArity() == 2); - return pythonActionExecutor.wrapToInputEvent(((Row) input).getField(1)); - } - } + // --- Test support --- - private OUT getOutputFromOutputEvent(Event event) { - checkState(EventUtil.isOutputEvent(event)); - if (event instanceof OutputEvent) { - return (OUT) ((OutputEvent) event).getOutput(); - } else if (event instanceof PythonEvent) { - Object outputFromOutputEvent = - pythonActionExecutor.getOutputFromOutputEvent(((PythonEvent) event).getEvent()); - return (OUT) outputFromOutputEvent; - } else { - throw new IllegalStateException( - "Unsupported event type: " + event.getClass().getName()); - } + @VisibleForTesting + DurableExecutionManager getDurableExecutionManager() { + return durableExecutionManager; } - private List getActionsTriggeredBy(Event event) { - if (event instanceof PythonEvent) { - return agentPlan.getActionsTriggeredBy(((PythonEvent) event).getEventType()); - } else { - return agentPlan.getActionsTriggeredBy(event.getClass().getName()); - } + @VisibleForTesting + EventRouter getEventRouter() { + return eventRouter; } + // --- Private helpers --- + private MailboxProcessor getMailboxProcessor() throws Exception { Field field = MailboxExecutorImpl.class.getDeclaredField("mailboxProcessor"); field.setAccessible(true); @@ -862,264 +500,19 @@ private ActionTask createActionTask(Object key, Action action, Event event) { } } - private void createAndSetRunnerContext(ActionTask actionTask, Object key) { - RunnerContextImpl runnerContext; - if (actionTask.action.getExec() instanceof JavaFunction) { - runnerContext = createOrGetRunnerContext(true); - } else if (actionTask.action.getExec() instanceof PythonFunction) { - runnerContext = createOrGetRunnerContext(false); - } else { - throw new IllegalStateException( - "Unsupported action type: " + actionTask.action.getExec().getClass()); - } - - RunnerContextImpl.MemoryContext memoryContext; - if 
(actionTaskMemoryContexts.containsKey(actionTask)) { - // action task for async execution action, should retrieve intermediate results from - // map. - memoryContext = actionTaskMemoryContexts.get(actionTask); - } else { - memoryContext = - new RunnerContextImpl.MemoryContext( - new CachedMemoryStore(sensoryMemState), - new CachedMemoryStore(shortTermMemState)); - } - - runnerContext.switchActionContext( - actionTask.action.getName(), memoryContext, String.valueOf(key.hashCode())); - - if (runnerContext instanceof JavaRunnerContextImpl) { - ContinuationContext continuationContext; - if (continuationContexts.containsKey(actionTask)) { - // action task for async execution action, should retrieve intermediate results from - // map. - continuationContext = continuationContexts.get(actionTask); - } else { - continuationContext = new ContinuationContext(); - } - ((JavaRunnerContextImpl) runnerContext).setContinuationContext(continuationContext); - } - if (runnerContext instanceof PythonRunnerContextImpl) { - // Get the awaitable ref from the transient map. After checkpoint restore, this will be - // null, signaling that the awaitable was lost and needs re-execution. 
- String awaitableRef = pythonAwaitableRefs.get(actionTask); - ((PythonRunnerContextImpl) runnerContext).setPythonAwaitableRef(awaitableRef); - } - actionTask.setRunnerContext(runnerContext); - } - - private boolean currentKeyHasMoreActionTask() throws Exception { - return listStateNotEmpty(actionTasksKState); - } - private void tryResumeProcessActionTasks() throws Exception { - Iterable keys = currentProcessingKeysOpState.get(); + Iterable keys = stateManager.getProcessingKeys(); if (keys != null) { for (Object key : keys) { - keySegmentQueue.addKeyToLastSegment(key); + eventRouter.getKeySegmentQueue().addKeyToLastSegment(key); mailboxExecutor.submit( () -> tryProcessActionTaskForKey(key), "process action task"); } } - getKeyedStateBackend() - .applyToAllKeys( - VoidNamespace.INSTANCE, - VoidNamespaceSerializer.INSTANCE, - new ListStateDescriptor<>( - PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)), - (key, state) -> - state.get() - .forEach( - event -> keySegmentQueue.addKeyToLastSegment(key))); - } - - private void initOrIncSequenceNumber() throws Exception { - // Initialize the sequence number state if it does not exist. - Long sequenceNumber = sequenceNumberKState.value(); - if (sequenceNumber == null) { - sequenceNumberKState.update(0L); - } else { - sequenceNumberKState.update(sequenceNumber + 1); - } - } - - private ActionState maybeGetActionState( - Object key, long sequenceNum, Action action, Event event) throws Exception { - return actionStateStore == null - ? null - : actionStateStore.get(key.toString(), sequenceNum, action, event); - } - - private void maybeInitActionState(Object key, long sequenceNum, Action action, Event event) - throws Exception { - if (actionStateStore != null) { - // Initialize the action state if it does not exist. It will exist when the action is an - // async action and - // has been persisted before the action task is finished. 
- if (actionStateStore.get(key, sequenceNum, action, event) == null) { - actionStateStore.put(key, sequenceNum, action, event, new ActionState(event)); - } - } - } - - private void maybePersistTaskResult( - Object key, - long sequenceNum, - Action action, - Event event, - RunnerContextImpl context, - ActionTask.ActionTaskResult actionTaskResult) - throws Exception { - if (actionStateStore == null) { - return; - } - - // if the task is not finished, we skip the persistence for now and wait until it is - // finished. - if (!actionTaskResult.isFinished()) { - return; - } - - ActionState actionState = actionStateStore.get(key, sequenceNum, action, event); - - for (MemoryUpdate memoryUpdate : context.getSensoryMemoryUpdates()) { - actionState.addSensoryMemoryUpdate(memoryUpdate); - } - - for (MemoryUpdate memoryUpdate : context.getShortTermMemoryUpdates()) { - actionState.addShortTermMemoryUpdate(memoryUpdate); - } - - for (Event outputEvent : actionTaskResult.getOutputEvents()) { - actionState.addEvent(outputEvent); - } - - // Mark the action as completed and clear call records - // This indicates that recovery should skip the entire action - actionState.markCompleted(); - - actionStateStore.put(key, sequenceNum, action, event, actionState); - - // Clear durable execution context - context.clearDurableExecutionContext(); - } - - /** - * Sets up the durable execution context for fine-grained recovery. - * - *

This method initializes the runner context with a {@link - * RunnerContextImpl.DurableExecutionContext}, which enables execute/execute_async calls to: - * - *

    - *
  • Skip re-execution for already completed calls during recovery - *
  • Persist CallRecords after each code block completion - *
- */ - private void setupDurableExecutionContext(ActionTask actionTask, ActionState actionState) { - if (actionStateStore == null) { - return; - } - - RunnerContextImpl.DurableExecutionContext durableContext; - if (actionTaskDurableContexts.containsKey(actionTask)) { - // Reuse existing context for async action continuation - durableContext = actionTaskDurableContexts.get(actionTask); - } else { - // Create new context for first invocation - final long sequenceNumber; - try { - sequenceNumber = sequenceNumberKState.value(); - } catch (Exception e) { - throw new RuntimeException("Failed to get sequence number from state", e); - } - - durableContext = - new RunnerContextImpl.DurableExecutionContext( - actionTask.getKey(), - sequenceNumber, - actionTask.action, - actionTask.event, - actionState, - this); - } - - actionTask.getRunnerContext().setDurableExecutionContext(durableContext); - } - - @Override - public void persist( - Object key, long sequenceNumber, Action action, Event event, ActionState actionState) { - try { - actionStateStore.put(key, sequenceNumber, action, event, actionState); - } catch (Exception e) { - LOG.error("Failed to persist ActionState", e); - throw new RuntimeException("Failed to persist ActionState", e); - } - } - - private void maybePruneState(Object key, long sequenceNum) throws Exception { - if (actionStateStore != null) { - actionStateStore.pruneState(key, sequenceNum); - } - } - - private void processEligibleWatermarks() throws Exception { - Watermark mark = keySegmentQueue.popOldestWatermark(); - while (mark != null) { - super.processWatermark(mark); - mark = keySegmentQueue.popOldestWatermark(); - } - } - - private RunnerContextImpl createOrGetRunnerContext(Boolean isJava) { - if (isJava) { - if (runnerContext == null) { - if (continuationActionExecutor == null) { - continuationActionExecutor = - new ContinuationActionExecutor( - agentPlan - .getConfig() - .get(AgentExecutionOptions.NUM_ASYNC_THREADS)); - } - runnerContext = - new 
JavaRunnerContextImpl( - this.metricGroup, - this::checkMailboxThread, - this.agentPlan, - this.jobIdentifier, - continuationActionExecutor); - } - return runnerContext; - } else { - if (pythonRunnerContext == null) { - pythonRunnerContext = - new PythonRunnerContextImpl( - this.metricGroup, - this::checkMailboxThread, - this.agentPlan, - jobIdentifier); - } - return pythonRunnerContext; - } - } - - private EventLogger createEventLogger(AgentPlan agentPlan) { - EventLoggerConfig.Builder loggerConfigBuilder = EventLoggerConfig.builder(); - String baseLogDir = agentPlan.getConfig().get(BASE_LOG_DIR); - if (baseLogDir != null && !baseLogDir.trim().isEmpty()) { - loggerConfigBuilder.property(FileEventLogger.BASE_LOG_DIR_PROPERTY_KEY, baseLogDir); - } - return EventLoggerFactory.createLogger(loggerConfigBuilder.build()); - } - - private void maybeInitActionStateStore() { - if (actionStateStore == null - && KAFKA.getType() - .equalsIgnoreCase(agentPlan.getConfig().get(ACTION_STATE_STORE_BACKEND))) { - LOG.info("Using Kafka as backend of action state store."); - actionStateStore = new KafkaActionStateStore(agentPlan.getConfig()); - } + stateManager.forEachPendingInputEventKey( + getKeyedStateBackend(), + key -> eventRouter.getKeySegmentQueue().addKeyToLastSegment(key)); } /** Failed to execute Action task. */ diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java new file mode 100644 index 00000000..c3a5e701 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.agents.AgentExecutionOptions; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.JavaFunction; +import org.apache.flink.agents.plan.PythonFunction; +import org.apache.flink.agents.runtime.async.ContinuationActionExecutor; +import org.apache.flink.agents.runtime.async.ContinuationContext; +import org.apache.flink.agents.runtime.context.JavaRunnerContextImpl; +import org.apache.flink.agents.runtime.context.RunnerContextImpl; +import org.apache.flink.agents.runtime.memory.CachedMemoryStore; +import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; +import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; +import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; +import org.apache.flink.api.common.state.MapState; + +import java.util.HashMap; +import java.util.Map; + +/** + * Manages runner context lifecycle for action tasks, including creation, context switching for + * async continuation, and cleanup. + * + *

Holds four transient maps that track intermediate context state across async action task + * invocations: memory contexts, durable execution contexts, continuation contexts, and Python + * awaitable references. + */ +class ActionTaskContextManager { + + // Invariant dependencies (set once at construction, never change) + private final FlinkAgentsMetricGroupImpl metricGroup; + private final Runnable mailboxThreadChecker; + private final AgentPlan agentPlan; + private final String jobIdentifier; + private final PythonRunnerContextImpl pythonRunnerContext; + + // Lazily created singleton runner context for Java actions + private RunnerContextImpl javaRunnerContext; + + // Transient maps for in-flight async action tasks + private final Map actionTaskMemoryContexts; + private final Map + actionTaskDurableContexts; + private final Map continuationContexts; + private final Map pythonAwaitableRefs; + + private final ContinuationActionExecutor continuationActionExecutor; + + ActionTaskContextManager( + FlinkAgentsMetricGroupImpl metricGroup, + Runnable mailboxThreadChecker, + AgentPlan agentPlan, + String jobIdentifier, + PythonRunnerContextImpl pythonRunnerContext) { + this.metricGroup = metricGroup; + this.mailboxThreadChecker = mailboxThreadChecker; + this.agentPlan = agentPlan; + this.jobIdentifier = jobIdentifier; + this.pythonRunnerContext = pythonRunnerContext; + this.continuationActionExecutor = + new ContinuationActionExecutor( + agentPlan.getConfig().get(AgentExecutionOptions.NUM_ASYNC_THREADS)); + this.actionTaskMemoryContexts = new HashMap<>(); + this.actionTaskDurableContexts = new HashMap<>(); + this.continuationContexts = new HashMap<>(); + this.pythonAwaitableRefs = new HashMap<>(); + } + + Map getActionTaskDurableContexts() { + return actionTaskDurableContexts; + } + + /** + * Creates or retrieves the appropriate runner context for the given action task and configures + * it with the correct memory context, continuation context, and awaitable ref. + * + *

If the task has previously cached contexts (from an unfinished async action), those are + * restored. Otherwise, fresh contexts are created from Flink state. + */ + void createAndSetRunnerContext( + ActionTask actionTask, + Object key, + MapState sensoryMemState, + MapState shortTermMemState) { + RunnerContextImpl ctx; + if (actionTask.action.getExec() instanceof JavaFunction) { + ctx = getOrCreateJavaRunnerContext(); + } else if (actionTask.action.getExec() instanceof PythonFunction) { + ctx = getOrCreatePythonRunnerContext(); + } else { + throw new IllegalStateException( + "Unsupported action type: " + actionTask.action.getExec().getClass()); + } + + RunnerContextImpl.MemoryContext memoryContext = actionTaskMemoryContexts.get(actionTask); + if (memoryContext == null) { + memoryContext = + new RunnerContextImpl.MemoryContext( + new CachedMemoryStore(sensoryMemState), + new CachedMemoryStore(shortTermMemState)); + } + + ctx.switchActionContext( + actionTask.action.getName(), memoryContext, String.valueOf(key.hashCode())); + + if (ctx instanceof JavaRunnerContextImpl) { + ContinuationContext continuationContext = continuationContexts.get(actionTask); + if (continuationContext == null) { + continuationContext = new ContinuationContext(); + } + ((JavaRunnerContextImpl) ctx).setContinuationContext(continuationContext); + } + if (ctx instanceof PythonRunnerContextImpl) { + String awaitableRef = pythonAwaitableRefs.get(actionTask); + ((PythonRunnerContextImpl) ctx).setPythonAwaitableRef(awaitableRef); + } + actionTask.setRunnerContext(ctx); + } + + /** + * Removes cached contexts for the given action task after it completes execution. Called after + * each action task invocation. 
+ */ + void removeContexts(ActionTask actionTask) { + actionTaskMemoryContexts.remove(actionTask); + actionTaskDurableContexts.remove(actionTask); + continuationContexts.remove(actionTask); + pythonAwaitableRefs.remove(actionTask); + } + + /** + * Transfers context state from a completed (but not finished) action task to its generated + * continuation task. This preserves memory, durable execution, continuation, and awaitable + * state across async invocations. + */ + void transferContexts(ActionTask fromTask, ActionTask toTask) { + actionTaskMemoryContexts.put(toTask, fromTask.getRunnerContext().getMemoryContext()); + RunnerContextImpl.DurableExecutionContext durableContext = + fromTask.getRunnerContext().getDurableExecutionContext(); + if (durableContext != null) { + actionTaskDurableContexts.put(toTask, durableContext); + } + if (fromTask.getRunnerContext() instanceof JavaRunnerContextImpl) { + continuationContexts.put( + toTask, + ((JavaRunnerContextImpl) fromTask.getRunnerContext()).getContinuationContext()); + } + if (fromTask.getRunnerContext() instanceof PythonRunnerContextImpl) { + String awaitableRef = + ((PythonRunnerContextImpl) fromTask.getRunnerContext()).getPythonAwaitableRef(); + if (awaitableRef != null) { + pythonAwaitableRefs.put(toTask, awaitableRef); + } + } + } + + private RunnerContextImpl getOrCreateJavaRunnerContext() { + if (javaRunnerContext == null) { + javaRunnerContext = + new JavaRunnerContextImpl( + metricGroup, + mailboxThreadChecker, + agentPlan, + jobIdentifier, + continuationActionExecutor); + } + return javaRunnerContext; + } + + private RunnerContextImpl getOrCreatePythonRunnerContext() { + if (pythonRunnerContext == null) { + throw new IllegalStateException( + "PythonRunnerContext should have been initialized by PythonBridgeManager"); + } + return pythonRunnerContext; + } + + void close() throws Exception { + if (javaRunnerContext != null) { + try { + javaRunnerContext.close(); + } finally { + javaRunnerContext = null; + } + } 
+ continuationActionExecutor.close(); + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java new file mode 100644 index 00000000..f3315990 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.context.MemoryUpdate; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.actionstate.ActionState; +import org.apache.flink.agents.runtime.actionstate.ActionStateStore; +import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore; +import org.apache.flink.agents.runtime.context.ActionStatePersister; +import org.apache.flink.agents.runtime.context.RunnerContextImpl; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND; +import static org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA; + +/** + * Manages durable execution state including the {@link ActionStateStore}, action state persistence, + * recovery markers, and checkpoint-based state pruning. + * + *

When no {@link ActionStateStore} is configured (the common case for simple agents), all + * methods are no-ops. This class implements {@link ActionStatePersister} so it can be passed to + * {@link RunnerContextImpl.DurableExecutionContext} for fine-grained call-level recovery. + */ +class DurableExecutionManager implements ActionStatePersister { + + private static final Logger LOG = LoggerFactory.getLogger(DurableExecutionManager.class); + + private ActionStateStore actionStateStore; + private final Map> checkpointIdToSeqNums; + + DurableExecutionManager(ActionStateStore actionStateStore) { + this.actionStateStore = actionStateStore; + this.checkpointIdToSeqNums = new HashMap<>(); + } + + ActionStateStore getActionStateStore() { + return actionStateStore; + } + + /** Initializes the Kafka action state store if configured but not yet created. */ + void maybeInitActionStateStore(AgentPlan agentPlan) { + if (actionStateStore == null + && KAFKA.getType() + .equalsIgnoreCase(agentPlan.getConfig().get(ACTION_STATE_STORE_BACKEND))) { + LOG.info("Using Kafka as backend of action state store."); + actionStateStore = new KafkaActionStateStore(agentPlan.getConfig()); + } + } + + /** Rebuilds state from recovery markers during state initialization. */ + void rebuildStateFromMarkers(List markers) throws Exception { + if (actionStateStore != null) { + LOG.info("Rebuilding action state from {} recovery markers", markers.size()); + actionStateStore.rebuildState(markers); + } + } + + ActionState maybeGetActionState(Object key, long sequenceNum, Action action, Event event) + throws Exception { + return actionStateStore == null + ? 
null + : actionStateStore.get(key.toString(), sequenceNum, action, event); + } + + void maybeInitActionState(Object key, long sequenceNum, Action action, Event event) + throws Exception { + if (actionStateStore != null) { + if (actionStateStore.get(key, sequenceNum, action, event) == null) { + actionStateStore.put(key, sequenceNum, action, event, new ActionState(event)); + } + } + } + + void maybePersistTaskResult( + Object key, + long sequenceNum, + Action action, + Event event, + RunnerContextImpl context, + ActionTask.ActionTaskResult actionTaskResult) + throws Exception { + if (actionStateStore == null || !actionTaskResult.isFinished()) { + return; + } + + ActionState actionState = actionStateStore.get(key, sequenceNum, action, event); + + for (MemoryUpdate memoryUpdate : context.getSensoryMemoryUpdates()) { + actionState.addSensoryMemoryUpdate(memoryUpdate); + } + + for (MemoryUpdate memoryUpdate : context.getShortTermMemoryUpdates()) { + actionState.addShortTermMemoryUpdate(memoryUpdate); + } + + for (Event outputEvent : actionTaskResult.getOutputEvents()) { + actionState.addEvent(outputEvent); + } + + actionState.markCompleted(); + actionStateStore.put(key, sequenceNum, action, event, actionState); + context.clearDurableExecutionContext(); + } + + /** + * Sets up the durable execution context for fine-grained recovery. + * + *

This method initializes the runner context with a {@link + * RunnerContextImpl.DurableExecutionContext}, which enables execute/execute_async calls to: + * + *

    + *
  • Skip re-execution for already completed calls during recovery + *
  • Persist CallRecords after each code block completion + *
+ */ + void setupDurableExecutionContext( + ActionTask actionTask, + ActionState actionState, + long sequenceNumber, + Map actionTaskDurableContexts) { + if (actionStateStore == null) { + return; + } + + RunnerContextImpl.DurableExecutionContext durableContext = + actionTaskDurableContexts.get(actionTask); + if (durableContext == null) { + durableContext = + new RunnerContextImpl.DurableExecutionContext( + actionTask.getKey(), + sequenceNumber, + actionTask.action, + actionTask.event, + actionState, + this); + } + + actionTask.getRunnerContext().setDurableExecutionContext(durableContext); + } + + @Override + public void persist( + Object key, long sequenceNumber, Action action, Event event, ActionState actionState) { + try { + actionStateStore.put(key, sequenceNumber, action, event, actionState); + } catch (Exception e) { + LOG.error("Failed to persist ActionState", e); + throw new RuntimeException("Failed to persist ActionState", e); + } + } + + void maybePruneState(Object key, long sequenceNum) throws Exception { + if (actionStateStore != null) { + actionStateStore.pruneState(key, sequenceNum); + } + } + + /** Records sequence numbers per key at snapshot time for later pruning. */ + void recordCheckpointSeqNums(long checkpointId, Map keyToSeqNum) { + checkpointIdToSeqNums.put(checkpointId, keyToSeqNum); + } + + /** Prunes state for completed checkpoints. */ + void notifyCheckpointComplete(long checkpointId) throws Exception { + Map keyToSeqNum = checkpointIdToSeqNums.remove(checkpointId); + if (actionStateStore != null && keyToSeqNum != null) { + for (Map.Entry entry : keyToSeqNum.entrySet()) { + actionStateStore.pruneState(entry.getKey(), entry.getValue()); + } + } + } + + /** Gets recovery marker from the action state store, if available. */ + Object getRecoveryMarker() { + return actionStateStore != null ? 
actionStateStore.getRecoveryMarker() : null; + } + + void close() throws Exception { + if (actionStateStore != null) { + actionStateStore.close(); + } + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java new file mode 100644 index 00000000..47c063ca --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.EventContext; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.listener.EventListener; +import org.apache.flink.agents.api.logger.EventLogger; +import org.apache.flink.agents.api.logger.EventLoggerConfig; +import org.apache.flink.agents.api.logger.EventLoggerFactory; +import org.apache.flink.agents.api.logger.EventLoggerOpenParams; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.eventlog.FileEventLogger; +import org.apache.flink.agents.runtime.metrics.BuiltInMetrics; +import org.apache.flink.agents.runtime.operator.queue.SegmentedQueue; +import org.apache.flink.agents.runtime.python.event.PythonEvent; +import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; +import org.apache.flink.agents.runtime.utils.EventUtil; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.types.Row; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.BASE_LOG_DIR; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * Handles event wrapping, unwrapping, routing to actions, notification (event logger and + * listeners), and watermark management via a {@link SegmentedQueue}. 
+ */ +class EventRouter { + + private final AgentPlan agentPlan; + private final boolean inputIsJava; + private final EventLogger eventLogger; + private final List eventListeners; + private final SegmentedQueue keySegmentQueue; + private BuiltInMetrics builtInMetrics; + private PythonActionExecutor pythonActionExecutor; + + EventRouter(AgentPlan agentPlan, boolean inputIsJava) { + this.agentPlan = agentPlan; + this.inputIsJava = inputIsJava; + this.eventLogger = createEventLogger(agentPlan); + this.eventListeners = new ArrayList<>(); + this.keySegmentQueue = new SegmentedQueue(); + } + + SegmentedQueue getKeySegmentQueue() { + return keySegmentQueue; + } + + void setBuiltInMetrics(BuiltInMetrics builtInMetrics) { + this.builtInMetrics = builtInMetrics; + } + + void setPythonActionExecutor(PythonActionExecutor pythonActionExecutor) { + this.pythonActionExecutor = pythonActionExecutor; + } + + /** Initializes the event logger if it is set. */ + void initEventLogger(StreamingRuntimeContext runtimeContext) throws Exception { + if (eventLogger == null) { + return; + } + eventLogger.open(new EventLoggerOpenParams(runtimeContext)); + } + + /** Notifies event logger and listeners that an event has been processed. */ + void notifyEventProcessed(Event event) throws Exception { + EventContext eventContext = new EventContext(event); + if (eventLogger != null) { + eventLogger.append(eventContext, event); + eventLogger.flush(); + } + for (EventListener listener : eventListeners) { + listener.onEventProcessed(eventContext, event); + } + builtInMetrics.markEventProcessed(); + } + + /** Wraps raw input into an InputEvent (Java or Python). 
*/ + @SuppressWarnings("unchecked") + Event wrapToInputEvent(IN input) { + if (inputIsJava) { + return new InputEvent(input); + } else { + checkState(input instanceof Row && ((Row) input).getArity() == 2); + return pythonActionExecutor.wrapToInputEvent(((Row) input).getField(1)); + } + } + + /** Extracts output data from an OutputEvent (Java or Python). */ + @SuppressWarnings("unchecked") + OUT getOutputFromOutputEvent(Event event) { + checkState(EventUtil.isOutputEvent(event)); + if (event instanceof OutputEvent) { + return (OUT) ((OutputEvent) event).getOutput(); + } else if (event instanceof PythonEvent) { + return (OUT) + pythonActionExecutor.getOutputFromOutputEvent(((PythonEvent) event).getEvent()); + } else { + throw new IllegalStateException( + "Unsupported event type: " + event.getClass().getName()); + } + } + + /** Gets actions triggered by a given event from the agent plan. */ + List getActionsTriggeredBy(Event event) { + if (event instanceof PythonEvent) { + return agentPlan.getActionsTriggeredBy(((PythonEvent) event).getEventType()); + } else { + return agentPlan.getActionsTriggeredBy(event.getClass().getName()); + } + } + + /** Processes and emits all eligible watermarks from the segmented queue. 
*/ + void processEligibleWatermarks(WatermarkEmitter emitter) throws Exception { + Watermark mark = keySegmentQueue.popOldestWatermark(); + while (mark != null) { + emitter.emit(mark); + mark = keySegmentQueue.popOldestWatermark(); + } + } + + void close() throws Exception { + if (eventLogger != null) { + eventLogger.close(); + } + } + + private static EventLogger createEventLogger(AgentPlan agentPlan) { + EventLoggerConfig.Builder loggerConfigBuilder = EventLoggerConfig.builder(); + String baseLogDir = agentPlan.getConfig().get(BASE_LOG_DIR); + if (baseLogDir != null && !baseLogDir.trim().isEmpty()) { + loggerConfigBuilder.property(FileEventLogger.BASE_LOG_DIR_PROPERTY_KEY, baseLogDir); + } + return EventLoggerFactory.createLogger(loggerConfigBuilder.build()); + } + + /** Callback interface for emitting watermarks through the operator. */ + @FunctionalInterface + interface WatermarkEmitter { + void emit(Watermark watermark) throws Exception; + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java new file mode 100644 index 00000000..b668804a --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import static org.apache.flink.agents.runtime.utils.StateUtil.*; + +/** + * Manages all Flink state objects used by the {@link ActionExecutionOperator}: keyed states for + * memory, action tasks, pending input events, sequence numbers, and operator states for processing + * keys and recovery markers. 
+ */ +class OperatorStateManager { + + static final String RECOVERY_MARKER_STATE_NAME = "recoveryMarker"; + static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME = "messageSequenceNumber"; + static final String PENDING_INPUT_EVENT_STATE_NAME = "pendingInputEvents"; + + private MapState sensoryMemState; + private MapState shortTermMemState; + private ListState actionTasksKState; + private ListState pendingInputEventsKState; + private ListState currentProcessingKeysOpState; + private ValueState sequenceNumberKState; + private ListState recoveryMarkerOpState; + + // --- State Initialization --- + + /** + * Initializes all Flink state objects. Should be called from the operator's {@code open()} + * method. + */ + void initializeStates( + StreamingRuntimeContext runtimeContext, + OperatorStateBackend operatorStateBackend, + boolean hasActionStateStore) + throws Exception { + sensoryMemState = + runtimeContext.getMapState( + new MapStateDescriptor<>( + "sensoryMemory", + TypeInformation.of(String.class), + TypeInformation.of(MemoryObjectImpl.MemoryItem.class))); + + shortTermMemState = + runtimeContext.getMapState( + new MapStateDescriptor<>( + "shortTermMemory", + TypeInformation.of(String.class), + TypeInformation.of(MemoryObjectImpl.MemoryItem.class))); + + if (hasActionStateStore) { + recoveryMarkerOpState = + operatorStateBackend.getUnionListState( + new ListStateDescriptor<>( + RECOVERY_MARKER_STATE_NAME, TypeInformation.of(Object.class))); + } + + sequenceNumberKState = + runtimeContext.getState( + new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class)); + + actionTasksKState = + runtimeContext.getListState( + new ListStateDescriptor<>( + "actionTasks", TypeInformation.of(ActionTask.class))); + pendingInputEventsKState = + runtimeContext.getListState( + new ListStateDescriptor<>( + PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class))); + + // We use UnionList here to ensure that the task can access all keys after parallelism + // 
modifications. + currentProcessingKeysOpState = + operatorStateBackend.getUnionListState( + new ListStateDescriptor<>( + "currentProcessingKeys", TypeInformation.of(Object.class))); + } + + // --- Memory State Accessors --- + + MapState getSensoryMemState() { + return sensoryMemState; + } + + MapState getShortTermMemState() { + return shortTermMemState; + } + + // --- Action Task State --- + + ActionTask pollNextActionTask() throws Exception { + return pollFromListState(actionTasksKState); + } + + void addActionTask(ActionTask task) throws Exception { + actionTasksKState.add(task); + } + + boolean hasMoreActionTasks() throws Exception { + return listStateNotEmpty(actionTasksKState); + } + + // --- Pending Input Events --- + + Event pollNextPendingInputEvent() throws Exception { + return pollFromListState(pendingInputEventsKState); + } + + void addPendingInputEvent(Event event) throws Exception { + pendingInputEventsKState.add(event); + } + + // --- Processing Key Tracking --- + + void addProcessingKey(Object key) throws Exception { + currentProcessingKeysOpState.add(key); + } + + int removeProcessingKey(Object key) throws Exception { + return removeFromListState(currentProcessingKeysOpState, key); + } + + Iterable getProcessingKeys() throws Exception { + return currentProcessingKeysOpState.get(); + } + + boolean hasProcessingKeys() throws Exception { + return listStateNotEmpty(currentProcessingKeysOpState); + } + + // --- Sequence Number --- + + void initOrIncSequenceNumber() throws Exception { + Long sequenceNumber = sequenceNumberKState.value(); + if (sequenceNumber == null) { + sequenceNumberKState.update(0L); + } else { + sequenceNumberKState.update(sequenceNumber + 1); + } + } + + long getSequenceNumber() throws Exception { + return sequenceNumberKState.value(); + } + + // --- Recovery Marker --- + + /** + * Updates the recovery marker operator state. Only valid when the action state store is enabled + * (i.e., recoveryMarkerOpState was initialized). 
+ */ + void updateRecoveryMarker(Object marker) throws Exception { + recoveryMarkerOpState.update(List.of(marker)); + } + + /** + * Collects recovery markers from operator state during state initialization. + * + * @return list of recovery markers + */ + List collectRecoveryMarkers( + OperatorStateBackend operatorStateBackend, boolean hasActionStateStore) + throws Exception { + List markers = new ArrayList<>(); + if (hasActionStateStore) { + ListState recoveryMarkerState = + operatorStateBackend.getUnionListState( + new ListStateDescriptor<>( + RECOVERY_MARKER_STATE_NAME, TypeInformation.of(Object.class))); + Iterable recoveryMarkers = recoveryMarkerState.get(); + if (recoveryMarkers != null) { + recoveryMarkers.forEach(markers::add); + } + } + return markers; + } + + /** + * Iterates over all keys that have pending input events and passes each key to the given + * consumer. Used during recovery to re-register pending keys with the segmented queue. + */ + void forEachPendingInputEventKey( + KeyedStateBackend keyedStateBackend, Consumer action) throws Exception { + keyedStateBackend.applyToAllKeys( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new ListStateDescriptor<>( + PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)), + (key, state) -> state.get().forEach(event -> action.accept(key))); + } + + /** + * Snapshots sequence numbers for all keys. Returns a map of key to sequence number for the + * given checkpoint. 
+ */ + Map snapshotSequenceNumbers(KeyedStateBackend keyedStateBackend) + throws Exception { + HashMap keyToSeqNum = new HashMap<>(); + keyedStateBackend.applyToAllKeys( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class), + (key, state) -> keyToSeqNum.put(key, state.value())); + return keyToSeqNum; + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java new file mode 100644 index 00000000..4105efc4 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.PythonFunction; +import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; +import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; +import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; +import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; +import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; +import org.apache.flink.agents.runtime.python.utils.JavaResourceAdapter; +import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; +import org.apache.flink.agents.runtime.python.utils.PythonResourceAdapterImpl; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.cache.DistributedCache; +import org.apache.flink.python.env.PythonDependencyInfo; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import pemja.core.PythonInterpreter; + +import java.util.HashMap; +import java.util.function.BiFunction; + +/** + * Manages the Python execution environment, interpreter, action executor, and resource adapters. + * + *

This class is only active when the agent plan contains Python actions or Python resources. + * When no Python code is present, all fields remain null and {@link #getPythonActionExecutor()} + * returns null. + */ +class PythonBridgeManager { + + private static final Logger LOG = LoggerFactory.getLogger(PythonBridgeManager.class); + + private PythonEnvironmentManager pythonEnvironmentManager; + private PythonInterpreter pythonInterpreter; + private PythonActionExecutor pythonActionExecutor; + private PythonRunnerContextImpl pythonRunnerContext; + + /** Returns the Python action executor, or null if no Python actions exist. */ + PythonActionExecutor getPythonActionExecutor() { + return pythonActionExecutor; + } + + /** Returns the Python runner context, or null if no Python components exist. */ + PythonRunnerContextImpl getPythonRunnerContext() { + return pythonRunnerContext; + } + + /** + * Initializes the Python environment, interpreter, action executor, and resource adapters if + * the agent plan contains Python actions or resources. 
+ */ + void initPythonEnvironment( + AgentPlan agentPlan, + ExecutionConfig executionConfig, + DistributedCache distributedCache, + String[] tmpDirectories, + JobID jobId, + FlinkAgentsMetricGroupImpl metricGroup, + Runnable mailboxThreadChecker, + String jobIdentifier) + throws Exception { + boolean containPythonAction = + agentPlan.getActions().values().stream() + .anyMatch(action -> action.getExec() instanceof PythonFunction); + + boolean containPythonResource = + agentPlan.getResourceProviders().values().stream() + .anyMatch( + resourceProviderMap -> + resourceProviderMap.values().stream() + .anyMatch( + resourceProvider -> + resourceProvider + instanceof + PythonResourceProvider)); + + if (!containPythonAction && !containPythonResource) { + return; + } + + LOG.debug("Begin initialize PythonEnvironmentManager."); + PythonDependencyInfo dependencyInfo = + PythonDependencyInfo.create(executionConfig.toConfiguration(), distributedCache); + pythonEnvironmentManager = + new PythonEnvironmentManager( + dependencyInfo, tmpDirectories, new HashMap<>(System.getenv()), jobId); + pythonEnvironmentManager.open(); + EmbeddedPythonEnvironment env = pythonEnvironmentManager.createEnvironment(); + pythonInterpreter = env.getInterpreter(); + pythonRunnerContext = + new PythonRunnerContextImpl( + metricGroup, mailboxThreadChecker, agentPlan, jobIdentifier); + + BiFunction resourceResolver = + (name, type) -> getResource(agentPlan, name, type); + JavaResourceAdapter javaResourceAdapter = + new JavaResourceAdapter(resourceResolver, pythonInterpreter); + if (containPythonResource) { + PythonResourceAdapterImpl pythonResourceAdapter = + new PythonResourceAdapterImpl( + resourceResolver, pythonInterpreter, javaResourceAdapter); + pythonResourceAdapter.open(); + agentPlan.setPythonResourceAdapter(pythonResourceAdapter); + } + if (containPythonAction) { + pythonActionExecutor = + new PythonActionExecutor( + pythonInterpreter, + agentPlan, + javaResourceAdapter, + pythonRunnerContext, + 
jobIdentifier); + pythonActionExecutor.open(); + } + } + + private static Resource getResource(AgentPlan agentPlan, String name, ResourceType type) { + try { + return agentPlan.getResource(name, type); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** Closes all Python-related resources in the correct order. */ + void close() throws Exception { + if (pythonActionExecutor != null) { + pythonActionExecutor.close(); + } + if (pythonInterpreter != null) { + pythonInterpreter.close(); + } + if (pythonEnvironmentManager != null) { + pythonEnvironmentManager.close(); + } + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java index 646d1e63..3add3e35 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java @@ -184,12 +184,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - // Use reflection to access the action state store for validation - Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); InMemoryActionStateStore actionStateStore = - (InMemoryActionStateStore) actionStateStoreField.get(operator); + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); assertThat(actionStateStore).isNotNull(); assertThat(actionStateStore.getKeyedActionStates()).isEmpty(); @@ -244,9 +241,10 @@ void testEventLogBaseDirFromAgentConfig() throws Exception { testHarness.open(); ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - Field eventLoggerField = 
ActionExecutionOperator.class.getDeclaredField("eventLogger"); + EventRouter eventRouterObj = operator.getEventRouter(); + Field eventLoggerField = EventRouter.class.getDeclaredField("eventLogger"); eventLoggerField.setAccessible(true); - Object eventLogger = eventLoggerField.get(operator); + Object eventLogger = eventLoggerField.get(eventRouterObj); assertThat(eventLogger).isInstanceOf(FileEventLogger.class); Field configField = FileEventLogger.class.getDeclaredField("config"); @@ -276,12 +274,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - // Use reflection to access the action state store for validation - Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); InMemoryActionStateStore actionStateStore = - (InMemoryActionStateStore) actionStateStoreField.get(operator); + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); Long inputValue = 3L; testHarness.processElement(new StreamRecord<>(inputValue)); @@ -352,11 +347,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), (ActionExecutionOperator) testHarness.getOperator(); // Access the action state store - java.lang.reflect.Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); InMemoryActionStateStore actionStateStore = - (InMemoryActionStateStore) actionStateStoreField.get(operator); + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); // Process multiple elements with same key to test state persistence testHarness.processElement(new StreamRecord<>(1L)); @@ -421,11 +414,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(true)), assertThat(recordOutput.size()).isEqualTo(3); // Access the action state store - 
Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); InMemoryActionStateStore actionStateStore = - (InMemoryActionStateStore) actionStateStoreField.get(operator); + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); assertThat(actionStateStore.getKeyedActionStates()).isEmpty(); } } @@ -445,10 +436,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), (ActionExecutionOperator) testHarness.getOperator(); // Access the action state store - Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); - actionStateStore = (InMemoryActionStateStore) actionStateStoreField.get(operator); + actionStateStore = + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); Long inputValue = 7L;