diff --git a/Cargo.toml b/Cargo.toml index 7e249893..e9aa5a89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,7 @@ tokio = { version = "1.43.0", features = [ "rt-multi-thread", "sync", "signal", + "test-util", "io-util", ] } tokio-stream = { version = "0.1.17", features = ["sync"] } diff --git a/crates/worker/src/executor/process.rs b/crates/worker/src/executor/process.rs index 4eef4ba3..a16cec6d 100644 --- a/crates/worker/src/executor/process.rs +++ b/crates/worker/src/executor/process.rs @@ -1,4 +1,4 @@ -use std::{future::Future, path::PathBuf, pin::Pin, process::Stdio, time::Duration}; +use std::{ffi::OsStr, future::Future, path::PathBuf, pin::Pin, process::Stdio, time::Duration}; use hypha_messages::Executor; use nix::{ @@ -72,78 +72,26 @@ impl JobExecutor for ProcessExecutor { let job_json = serde_json::to_string(&job).expect("valid JobSpec JSON"); // This should come from the config and should be bassed when crearting the executor - let mut process = Command::new(get_process_call(&job.executor)) - .args(get_process_args(&job.executor)) - .arg("--socket") - .args([&sock_path]) // passes the actual path - .args(["--work-dir"]) - .arg(&work_dir) // passes the actual path - .args(["--job"]) - .arg(job_json) // serialized JobSpec JSON - .stdout(Stdio::piped()) - .spawn()?; + let run_future = try_run( + get_process_call(&job.executor), + get_process_args(&job.executor).into_iter().chain( + vec![ + "--socket".to_string(), + sock_path.to_str().expect("a valid path").to_string(), + "--work-dir".to_string(), + work_dir.to_str().expect("a valid path").to_string(), + "--job".to_string(), + job_json, + ] + .into_iter(), + ), + cancel.clone(), + )?; let task_tracker = TaskTracker::new(); let shutdown = cancel.clone(); task_tracker.spawn(async move { - let stdout = process.stdout.take().expect("stdout is available"); - - // Stream output. - let mut lines = BufReader::new(stdout).lines(); - - loop { - tokio::select! { - _ = shutdown.cancelled() => { - tracing::trace!("Received shutdown signal. Stopping executor process"); - - // Send SIGTERM to process - // TODO: This is only available in UNIX environment. - // We need to have a Windows-specific code-path - // if we want it to work there as well. - if let Some(pid) = process.id() { - if let Err(e) = signal::kill(Pid::from_raw(pid as pid_t), Signal::SIGTERM) { - tracing::warn!(error = ?e, "Failed to send SIGTERM to executor process"); - } - } else { - tracing::trace!("Executor process already exited"); - } - break; - } - line = lines.next_line() => { - match line { - Ok(Some(line)) => { - println!("{line}") - } - Ok(None) => { - // TODO - } - Err(_) => { - // TODO - } - } - } - // Received if the driver stopped. - _ = process.wait() => { - tracing::debug!("Executor process task terminated"); - // TODO: Decide what to do if the process failed. - // We could, e.g., restart it. - break - } - } - } - - tokio::select! { - status = process.wait() => { - tracing::trace!(status = ?status, "Executor task exited"); - } - // If the driver process does not exit in time, send SIGKILL. - _ = sleep(Duration::from_secs(5)) => { - tracing::trace!("Executor task didn't exit in time, sending SIGKILL"); - if let Err(e) = process.kill().await { - tracing::warn!(error = ?e, "Failed to send SIGKILL to executor process"); - } - } - } + run_future.await; // At this point the process is no longer running but the bridge still is. // By cancelling here, the bridge will stop serving and terminate. @@ -191,3 +139,132 @@ fn get_process_call(executor: &Executor) -> String { _ => ".".into(), } } + +fn try_run( + command: S, + arguments: I, + token: CancellationToken, +) -> Result + 'static, std::io::Error> +where + S: AsRef, + I: IntoIterator, +{ + // This should come from the config and should be bassed when crearting the executor + let mut process = Command::new(command) + .args(arguments) // serialized JobSpec JSON + .stdout(Stdio::piped()) + .spawn()?; + + Ok(async move { + let stdout = process.stdout.take().expect("stdout is available"); + + // Stream output. + let mut lines = BufReader::new(stdout).lines(); + + loop { + tokio::select! { + _ = token.cancelled() => { + tracing::trace!("Received shutdown signal. Stopping executor process"); + + // Send SIGTERM to process + // TODO: This is only available in UNIX environment. + // We need to have a Windows-specific code-path + // if we want it to work there as well. + if let Some(pid) = process.id() { + if let Err(e) = signal::kill(Pid::from_raw(pid as pid_t), Signal::SIGTERM) { + tracing::warn!(error = ?e, "Failed to send SIGTERM to executor process"); + } + } else { + tracing::trace!("Executor process already exited"); + } + break; + } + line = lines.next_line() => { + match line { + Ok(Some(line)) => { + println!("{line}") + } + Ok(None) => { + // TODO + } + Err(_) => { + // TODO + } + } + } + // Received if the driver stopped. + _ = process.wait() => { + tracing::debug!("Executor process task terminated"); + // TODO: Decide what to do if the process failed. + // We could, e.g., restart it. + break + } + } + } + + tokio::select! { + status = process.wait() => { + tracing::trace!(status = ?status, "Executor task exited"); + } + // If the driver process does not exit in time, send SIGKILL. + _ = sleep(Duration::from_secs(5)) => { + tracing::trace!("Executor task didn't exit in time, sending SIGKILL"); + if let Err(e) = process.kill().await { + tracing::warn!(error = ?e, "Failed to send SIGKILL to executor process"); + } + } + } + }) +} + +#[cfg(test)] +mod tests { + use tokio::time; + + use super::*; + + #[tokio::test] + async fn test_run_graceful_termination() { + let token = CancellationToken::new(); + + let process = tokio::spawn( + try_run( + "bash", + vec!["-c", "while true; do sleep 60; done"], + token.clone(), + ) + .unwrap(), + ); + + // We need to wait a while for the process to actually run. + sleep(Duration::from_secs(1)).await; + + token.cancel(); + process.await.unwrap(); + } + + #[tokio::test] + async fn test_run_sigkill_termination() { + let token = CancellationToken::new(); + + // This command doesn't react to SIGTERM, only to SIGKILL. + let process = tokio::spawn( + try_run( + "bash", + vec!["-c", "trap \"\" SIGTERM; while true; do sleep 60; done"], + token.clone(), + ) + .unwrap(), + ); + + // We need to wait a while for the process to actually run. + sleep(Duration::from_secs(1)).await; + + token.cancel(); + + time::pause(); + time::advance(Duration::from_secs(5)).await; + + process.await.unwrap(); + } +}