diff --git a/flowfile_core/flowfile_core/flowfile/flow_graph.py b/flowfile_core/flowfile_core/flowfile/flow_graph.py index 641b28ca..fb03441a 100644 --- a/flowfile_core/flowfile_core/flowfile/flow_graph.py +++ b/flowfile_core/flowfile_core/flowfile/flow_graph.py @@ -1208,7 +1208,12 @@ def _func(*flowfile_tables: FlowDataEngine) -> FlowDataEngine: log_callback_url=log_callback_url, internal_token=internal_token, ) - result = manager.execute_sync(kernel_id, request, self.flow_logger) + node = self.get_node(node_id) + node._kernel_cancel_context = (kernel_id, manager) + try: + result = manager.execute_sync(kernel_id, request, self.flow_logger) + finally: + node._kernel_cancel_context = None # Forward captured stdout/stderr to the flow logger if result.stdout: diff --git a/flowfile_core/flowfile_core/flowfile/flow_node/flow_node.py b/flowfile_core/flowfile_core/flowfile/flow_node/flow_node.py index 10cf8eb6..380242bd 100644 --- a/flowfile_core/flowfile_core/flowfile/flow_node/flow_node.py +++ b/flowfile_core/flowfile_core/flowfile/flow_node/flow_node.py @@ -136,6 +136,7 @@ def post_init(self): self._schema_callback = None self._state_needs_reset = False self._execution_lock = threading.RLock() # Protects concurrent access to get_resulting_data + self._kernel_cancel_context = None # Initialize execution state self._execution_state = NodeExecutionState() self._executor = None # Will be lazily created @@ -1089,7 +1090,13 @@ def cancel(self): if self._fetch_cached_df is not None: self._fetch_cached_df.cancel() - self.node_stats.is_canceled = True + elif self._kernel_cancel_context is not None: + kernel_id, manager = self._kernel_cancel_context + logger.info("Cancelling kernel execution for kernel '%s'", kernel_id) + try: + manager.interrupt_execution_sync(kernel_id) + except Exception: + logger.exception("Failed to interrupt kernel execution for kernel '%s'", kernel_id) else: logger.warning("No external process to cancel") self.node_stats.is_canceled = True diff --git a/flowfile_core/flowfile_core/kernel/manager.py b/flowfile_core/flowfile_core/kernel/manager.py index b9c8aeae..d247a2a2 100644 --- a/flowfile_core/flowfile_core/kernel/manager.py +++ b/flowfile_core/flowfile_core/kernel/manager.py @@ -627,6 +627,30 @@ def execute_sync( if kernel.state == KernelState.EXECUTING: kernel.state = KernelState.IDLE + def interrupt_execution_sync(self, kernel_id: str) -> bool: + """Send SIGUSR1 to a kernel container to interrupt running user code.""" + kernel = self._kernels.get(kernel_id) + if kernel is None or kernel.container_id is None: + logger.warning("Cannot interrupt kernel '%s': not found or no container", kernel_id) + return False + if kernel.state != KernelState.EXECUTING: + return False + try: + container = self._docker.containers.get(kernel.container_id) + container.kill(signal="SIGUSR1") + logger.info("Sent SIGUSR1 to kernel '%s' (container %s)", kernel_id, kernel.container_id[:12]) + return True + except docker.errors.NotFound: + logger.warning("Container for kernel '%s' not found", kernel_id) + return False + except (docker.errors.APIError, docker.errors.DockerException) as exc: + logger.error("Failed to send SIGUSR1 to kernel '%s': %s", kernel_id, exc) + return False + + async def interrupt_execution(self, kernel_id: str) -> bool: + """Async wrapper around :meth:`interrupt_execution_sync`.""" + return self.interrupt_execution_sync(kernel_id) + async def clear_artifacts(self, kernel_id: str) -> None: kernel = self._get_kernel_or_raise(kernel_id) if kernel.state not in (KernelState.IDLE, KernelState.EXECUTING): diff --git a/flowfile_core/tests/test_kernel_cancel.py b/flowfile_core/tests/test_kernel_cancel.py new file mode 100644 index 00000000..52f81a8c --- /dev/null +++ b/flowfile_core/tests/test_kernel_cancel.py @@ -0,0 +1,118 @@ +"""Tests for kernel execution cancellation support.""" + +from unittest.mock import MagicMock, patch + +import docker.errors +import pytest + +from flowfile_core.kernel.manager import KernelManager +from flowfile_core.kernel.models import KernelInfo, KernelState + + +def _make_manager(kernel_id="k1", state=KernelState.EXECUTING, container_id="abc123"): + """Build a KernelManager with a mocked Docker client and one kernel.""" + with patch.object(KernelManager, "__init__", lambda self, *a, **kw: None): + mgr = KernelManager.__new__(KernelManager) + mgr._docker = MagicMock() + mgr._kernels = {} + mgr._kernel_owners = {} + mgr._shared_volume = "/tmp/test" + mgr._docker_network = None + mgr._kernel_volume = None + mgr._kernel_volume_type = None + mgr._kernel_mount_target = None + + kernel = KernelInfo(id=kernel_id, name="test-kernel", state=state, container_id=container_id) + mgr._kernels[kernel_id] = kernel + return mgr + + +def _make_node(): + """Build a minimal FlowNode for cancel testing.""" + from flowfile_core.flowfile.flow_node.flow_node import FlowNode + + setting_input = MagicMock() + setting_input.is_setup = False + setting_input.cache_results = False + + return FlowNode( + node_id=1, + function=lambda: None, + parent_uuid="test-uuid", + setting_input=setting_input, + name="test_node", + node_type="python_script", + ) + + +# -- KernelManager.interrupt_execution_sync ----------------------------------- + + +class TestKernelManagerInterrupt: + def test_sends_sigusr1(self): + mgr = _make_manager() + container = MagicMock() + mgr._docker.containers.get.return_value = container + + assert mgr.interrupt_execution_sync("k1") is True + container.kill.assert_called_once_with(signal="SIGUSR1") + + def test_unknown_kernel(self): + mgr = _make_manager() + assert mgr.interrupt_execution_sync("nonexistent") is False + + def test_kernel_not_executing(self): + mgr = _make_manager(state=KernelState.IDLE) + assert mgr.interrupt_execution_sync("k1") is False + mgr._docker.containers.get.assert_not_called() + + def test_no_container_id(self): + mgr = _make_manager(container_id=None) + assert mgr.interrupt_execution_sync("k1") is False + + def test_docker_not_found(self): + mgr = _make_manager() + mgr._docker.containers.get.side_effect = docker.errors.NotFound("gone") + assert mgr.interrupt_execution_sync("k1") is False + + +# -- FlowNode.cancel with kernel context -------------------------------------- + + +class TestFlowNodeCancelWithKernel: + def test_cancel_calls_interrupt(self): + node = _make_node() + mock_mgr = MagicMock() + node._kernel_cancel_context = ("k1", mock_mgr) + + node.cancel() + + mock_mgr.interrupt_execution_sync.assert_called_once_with("k1") + assert node.node_stats.is_canceled is True + + def test_cancel_without_context(self): + node = _make_node() + node.cancel() + assert node.node_stats.is_canceled is True + + def test_worker_fetcher_takes_priority(self): + node = _make_node() + fetcher = MagicMock() + mock_mgr = MagicMock() + node._fetch_cached_df = fetcher + node._kernel_cancel_context = ("k1", mock_mgr) + + node.cancel() + + fetcher.cancel.assert_called_once() + mock_mgr.interrupt_execution_sync.assert_not_called() + assert node.node_stats.is_canceled is True + + def test_interrupt_exception_does_not_crash(self): + node = _make_node() + mock_mgr = MagicMock() + mock_mgr.interrupt_execution_sync.side_effect = RuntimeError("Docker unavailable") + node._kernel_cancel_context = ("k1", mock_mgr) + + node.cancel() # must not raise + assert node.node_stats.is_canceled is True diff --git a/kernel_runtime/kernel_runtime/main.py b/kernel_runtime/kernel_runtime/main.py index 09c53fc5..9bb42d6d 100644 --- a/kernel_runtime/kernel_runtime/main.py +++ b/kernel_runtime/kernel_runtime/main.py @@ -3,6 +3,7 @@ import io import logging import os +import signal import time from collections.abc import AsyncIterator from pathlib import Path @@ -56,6 +57,20 @@ def _clear_namespace(flow_id: int) -> None: _namespace_access.pop(flow_id, None) +# --------------------------------------------------------------------------- +# Execution cancellation via SIGUSR1 +# --------------------------------------------------------------------------- +_is_executing = False + + +def _cancel_signal_handler(signum, frame): + """Interrupt running user code when the container receives SIGUSR1.""" + if _is_executing: + logger.warning("SIGUSR1 received – interrupting execution") + raise KeyboardInterrupt("Execution cancelled by user") + logger.debug("SIGUSR1 received outside execution, ignoring") + + # --------------------------------------------------------------------------- # Persistence setup (driven by environment variables) # --------------------------------------------------------------------------- @@ -152,6 +167,10 @@ def _setup_persistence() -> None: @contextlib.asynccontextmanager async def _lifespan(app: FastAPI) -> AsyncIterator[None]: _setup_persistence() + try: + signal.signal(signal.SIGUSR1, _cancel_signal_handler) + except ValueError: + pass # not in main thread (e.g. TestClient) yield @@ -295,6 +314,7 @@ async def execute(request: ExecuteRequest): artifacts_before = set(artifact_store.list_all(flow_id=request.flow_id).keys()) + global _is_executing try: flowfile_client._set_context( node_id=request.node_id, @@ -332,8 +352,12 @@ async def execute(request: ExecuteRequest): if request.interactive: user_code = _maybe_wrap_last_expression(user_code) - # Execute user code - exec(user_code, exec_globals) # noqa: S102 + # Execute user code (with cancel support via SIGUSR1) + _is_executing = True + try: + exec(user_code, exec_globals) # noqa: S102 + finally: + _is_executing = False # Collect display outputs display_outputs = [DisplayOutput(**d) for d in flowfile_client._get_displays()] @@ -358,6 +382,18 @@ async def execute(request: ExecuteRequest): stderr=stderr_buf.getvalue(), execution_time_ms=elapsed, ) + except KeyboardInterrupt: + _is_executing = False + display_outputs = [DisplayOutput(**d) for d in flowfile_client._get_displays()] + elapsed = (time.perf_counter() - start) * 1000 + return ExecuteResponse( + success=False, + display_outputs=display_outputs, + stdout=stdout_buf.getvalue(), + stderr=stderr_buf.getvalue(), + error="Execution cancelled by user", + execution_time_ms=elapsed, + ) except Exception as exc: # Still collect any display outputs that were generated before the error display_outputs = [DisplayOutput(**d) for d in flowfile_client._get_displays()] @@ -371,6 +407,7 @@ async def execute(request: ExecuteRequest): execution_time_ms=elapsed, ) finally: + _is_executing = False flowfile_client._clear_context() diff --git a/kernel_runtime/tests/conftest.py b/kernel_runtime/tests/conftest.py index a8c8bf09..b8705945 100644 --- a/kernel_runtime/tests/conftest.py +++ b/kernel_runtime/tests/conftest.py @@ -32,6 +32,7 @@ def _clear_global_state(): main._recovery_status = {"status": "pending", "recovered": [], "errors": []} main._kernel_id = "default" main._persistence_path = "/shared/artifacts" + main._is_executing = False # Detach persistence from artifact store artifact_store._persistence = None artifact_store._lazy_index.clear() @@ -46,6 +47,7 @@ def _clear_global_state(): main._recovery_status = {"status": "pending", "recovered": [], "errors": []} main._kernel_id = "default" main._persistence_path = "/shared/artifacts" + main._is_executing = False artifact_store._persistence = None artifact_store._lazy_index.clear() artifact_store._loading_locks.clear() diff --git a/kernel_runtime/tests/test_main.py b/kernel_runtime/tests/test_main.py index be1866d3..595a59c2 100644 --- a/kernel_runtime/tests/test_main.py +++ b/kernel_runtime/tests/test_main.py @@ -1135,3 +1135,44 @@ def test_clear_node_artifacts_scoped_to_flow(self, client: TestClient): # Flow 2's artifact survives artifacts_f2 = client.get("/artifacts", params={"flow_id": 2}).json() assert "model" in artifacts_f2 + + +class TestExecutionCancellation: + """Tests for SIGUSR1-based execution cancellation.""" + + def test_signal_handler_raises_when_executing(self): + """The handler raises KeyboardInterrupt only while user code is running.""" + import kernel_runtime.main as main_module + + main_module._is_executing = True + with pytest.raises(KeyboardInterrupt, match="cancelled"): + main_module._cancel_signal_handler(None, None) + + def test_signal_handler_ignores_when_not_executing(self): + """Outside of exec(), the handler is a no-op (no crash, no exception).""" + import kernel_runtime.main as main_module + + main_module._is_executing = False + main_module._cancel_signal_handler(None, None) # should not raise + + def test_is_executing_flag_cleared_after_success(self, client: TestClient): + """_is_executing must be False after a successful execution.""" + import kernel_runtime.main as main_module + + resp = client.post( + "/execute", + json={"node_id": 200, "code": "x = 1", "flow_id": 1, "input_paths": {}, "output_dir": ""}, + ) + assert resp.json()["success"] is True + assert main_module._is_executing is False + + def test_is_executing_flag_cleared_after_error(self, client: TestClient): + """_is_executing must be False even when user code raises.""" + import kernel_runtime.main as main_module + + resp = client.post( + "/execute", + json={"node_id": 201, "code": "1/0", "flow_id": 1, "input_paths": {}, "output_dir": ""}, + ) + assert resp.json()["success"] is False + assert main_module._is_executing is False