diff --git a/alioth-cli/src/main.rs b/alioth-cli/src/main.rs index a14e5895..76e36d03 100644 --- a/alioth-cli/src/main.rs +++ b/alioth-cli/src/main.rs @@ -488,9 +488,7 @@ fn main_run(args: RunArgs) -> Result<(), Error> { } vm.boot().context(error::BootVm)?; - for result in vm.wait() { - result.context(error::WaitVm)?; - } + vm.wait().context(error::WaitVm)?; Ok(()) } diff --git a/alioth/src/board/board.rs b/alioth/src/board/board.rs index 44663b59..44e7e6f4 100644 --- a/alioth/src/board/board.rs +++ b/alioth/src/board/board.rs @@ -20,7 +20,6 @@ mod x86_64; #[cfg(target_os = "linux")] use std::collections::HashMap; use std::ffi::CStr; -use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::mpsc::{Receiver, Sender}; use std::sync::Arc; use std::thread::JoinHandle; @@ -92,14 +91,28 @@ pub enum Error { VmExit { msg: String }, #[snafu(display("Failed to configure firmware"))] Firmware { error: std::io::Error }, + #[snafu(display("Failed to notify the VMM thread"))] + NotifyVmm, + #[snafu(display("Another VCPU thread has signaled failure"))] + PeerFailure, } type Result = std::result::Result; -pub const STATE_CREATED: u8 = 0; -pub const STATE_RUNNING: u8 = 1; -pub const STATE_SHUTDOWN: u8 = 2; -pub const STATE_REBOOT_PENDING: u8 = 3; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum BoardState { + Created, + Running, + Shutdown, + RebootPending, +} + +#[derive(Debug)] +struct MpSync { + state: BoardState, + fatal: bool, + count: u32, +} pub const PCIE_MMIO_64_SIZE: u64 = 1 << 40; @@ -127,9 +140,7 @@ where pub vcpus: Arc>>, pub arch: ArchBoard, pub config: BoardConfig, - pub state: AtomicU8, pub payload: RwLock>, - pub mp_sync: Arc<(Mutex, Condvar)>, pub io_devs: RwLock)>>, #[cfg(target_arch = "aarch64")] pub mmio_devs: RwLock)>>, @@ -140,12 +151,53 @@ where pub vfio_ioases: Mutex, Arc>>, #[cfg(target_os = "linux")] pub vfio_containers: Mutex, Arc>>, + + mp_sync: Mutex, + cond_var: Condvar, } impl Board where V: Vm, { + pub fn new(vm: V, memory: Memory, arch: ArchBoard, config: BoardConfig) -> Self { + Board { + vm, + memory, + arch, + config, + payload: RwLock::new(None), + vcpus: Arc::new(RwLock::new(Vec::new())), + io_devs: RwLock::new(Vec::new()), + #[cfg(target_arch = "aarch64")] + mmio_devs: RwLock::new(Vec::new()), + pci_bus: PciBus::new(), + #[cfg(target_arch = "x86_64")] + fw_cfg: Mutex::new(None), + #[cfg(target_os = "linux")] + vfio_ioases: Mutex::new(HashMap::new()), + #[cfg(target_os = "linux")] + vfio_containers: Mutex::new(HashMap::new()), + + mp_sync: Mutex::new(MpSync { + state: BoardState::Created, + count: 0, + fatal: false, + }), + cond_var: Condvar::new(), + } + } + + pub fn boot(&self) -> Result<()> { + let vcpus = self.vcpus.read(); + let mut mp_sync = self.mp_sync.lock(); + mp_sync.state = BoardState::Running; + for (_, boot_tx) in vcpus.iter() { + boot_tx.send(()).unwrap(); + } + Ok(()) + } + fn load_payload(&self) -> Result { let payload = self.payload.read(); let Some(payload) = payload.as_ref() else { @@ -219,10 +271,10 @@ where break Ok(true); } VmExit::Interrupted => { - let state = self.state.load(Ordering::Acquire); - match state { - STATE_SHUTDOWN => VmEntry::Shutdown, - STATE_REBOOT_PENDING => VmEntry::Reboot, + let mp_sync = self.mp_sync.lock(); + match mp_sync.state { + BoardState::Shutdown => VmEntry::Shutdown, + BoardState::RebootPending => VmEntry::Reboot, _ => VmEntry::None, } } @@ -235,29 +287,36 @@ where } } - fn sync_vcpus(&self, vcpus: &VcpuGuard) { - let (lock, cvar) = &*self.mp_sync; - let mut count = lock.lock(); - *count += 1; - if *count == vcpus.len() as u32 { - *count = 0; - cvar.notify_all(); + fn sync_vcpus(&self, vcpus: &VcpuGuard) -> Result<()> { + let mut mp_sync = self.mp_sync.lock(); + if mp_sync.fatal { + return error::PeerFailure.fail(); + } + + mp_sync.count += 1; + if mp_sync.count == vcpus.len() as u32 { + mp_sync.count = 0; + self.cond_var.notify_all(); } else { - cvar.wait(&mut count) + self.cond_var.wait(&mut mp_sync) + } + + if mp_sync.fatal { + return error::PeerFailure.fail(); } + + Ok(()) } fn run_vcpu_inner( &self, id: u32, - event_tx: &Sender, + vcpu: &mut V::Vcpu, boot_rx: &Receiver<()>, ) -> Result<(), Error> { - let mut vcpu = self.vm.create_vcpu(id).context(error::CreateVcpu { id })?; - event_tx.send(id).unwrap(); - self.init_vcpu(id, &mut vcpu)?; + self.init_vcpu(id, vcpu)?; boot_rx.recv().unwrap(); - if self.state.load(Ordering::Acquire) != STATE_RUNNING { + if self.mp_sync.lock().state != BoardState::Running { return Ok(()); } loop { @@ -274,42 +333,33 @@ where } self.add_pci_devs()?; let init_state = self.load_payload()?; - self.init_boot_vcpu(&mut vcpu, &init_state)?; + self.init_boot_vcpu(vcpu, &init_state)?; self.create_firmware_data(&init_state)?; } - self.init_ap(id, &mut vcpu, &vcpus)?; + self.init_ap(id, vcpu, &vcpus)?; self.coco_finalize(id, &vcpus)?; + self.sync_vcpus(&vcpus)?; drop(vcpus); - let reboot = self.vcpu_loop(&mut vcpu, id)?; + let maybe_reboot = self.vcpu_loop(vcpu, id); - let new_state = if reboot { - STATE_REBOOT_PENDING - } else { - STATE_SHUTDOWN - }; let vcpus = self.vcpus.read(); - match self.state.compare_exchange( - STATE_RUNNING, - new_state, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(STATE_RUNNING) => { - for (vcpu_id, (handle, _)) in vcpus.iter().enumerate() { - if id != vcpu_id as u32 { - log::info!("vcpu{id} to kill {vcpu_id}"); - V::stop_vcpu(vcpu_id as u32, handle).context(error::StopVcpu { id })?; - } + let mut mp_sync = self.mp_sync.lock(); + if mp_sync.state == BoardState::Running { + mp_sync.state = if matches!(maybe_reboot, Ok(true)) { + BoardState::RebootPending + } else { + BoardState::Shutdown + }; + for (vcpu_id, (handle, _)) in vcpus.iter().enumerate() { + if id != vcpu_id as u32 { + log::info!("VCPU-{id}: stopping VCPU-{vcpu_id}"); + V::stop_vcpu(vcpu_id as u32, handle).context(error::StopVcpu { id })?; } } - Err(s) if s == new_state => {} - Ok(s) | Err(s) => { - log::error!("unexpected state: {s}"); - } } - - self.sync_vcpus(&vcpus); + drop(mp_sync); + self.sync_vcpus(&vcpus)?; if id == 0 { let devices = self.pci_bus.segment.devices.read(); @@ -319,22 +369,26 @@ where } self.memory.reset()?; } + self.reset_vcpu(id, vcpu)?; - if new_state == STATE_SHUTDOWN { - break Ok(()); + if let Err(e) = maybe_reboot { + break Err(e); } - match self.state.compare_exchange( - STATE_REBOOT_PENDING, - STATE_RUNNING, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(STATE_REBOOT_PENDING) | Err(STATE_RUNNING) => {} - _ => break Ok(()), + let mut mp_sync = self.mp_sync.lock(); + if mp_sync.state == BoardState::Shutdown { + break Ok(()); } + mp_sync.state = BoardState::Running; + } + } - self.reset_vcpu(id, &mut vcpu)?; + fn create_vcpu(&self, id: u32, event_tx: &Sender) -> Result { + let vcpu = self.vm.create_vcpu(id).context(error::CreateVcpu { id })?; + if event_tx.send(id).is_err() { + error::NotifyVmm.fail() + } else { + Ok(vcpu) } } @@ -344,9 +398,21 @@ where event_tx: Sender, boot_rx: Receiver<()>, ) -> Result<(), Error> { - let ret = self.run_vcpu_inner(id, &event_tx, &boot_rx); - self.state.store(STATE_SHUTDOWN, Ordering::Release); + let mut vcpu = self.create_vcpu(id, &event_tx)?; + + let ret = self.run_vcpu_inner(id, &mut vcpu, &boot_rx); event_tx.send(id).unwrap(); + + if matches!(ret, Ok(_) | Err(Error::PeerFailure { .. })) { + return Ok(()); + } + + log::warn!("VCPU-{id} reported error, unblocking other VCPUs..."); + let mut mp_sync = self.mp_sync.lock(); + mp_sync.fatal = true; + if mp_sync.count > 0 { + self.cond_var.notify_all(); + } ret } diff --git a/alioth/src/board/x86_64.rs b/alioth/src/board/x86_64.rs index 062a4fee..33604405 100644 --- a/alioth/src/board/x86_64.rs +++ b/alioth/src/board/x86_64.rs @@ -207,7 +207,7 @@ where Some(Coco::AmdSnp { .. }) => {} _ => return Ok(()), } - self.sync_vcpus(vcpus); + self.sync_vcpus(vcpus)?; if id == 0 { return Ok(()); } @@ -319,7 +319,7 @@ where pub fn coco_finalize(&self, id: u32, vcpus: &VcpuGuard) -> Result<()> { if let Some(coco) = &self.config.coco { - self.sync_vcpus(vcpus); + self.sync_vcpus(vcpus)?; if id == 0 { match coco { Coco::AmdSev { policy } => { @@ -334,7 +334,7 @@ where } } } - self.sync_vcpus(vcpus); + self.sync_vcpus(vcpus)?; } Ok(()) } diff --git a/alioth/src/vm.rs b/alioth/src/vm.rs index 842b8e6e..d2d6df79 100644 --- a/alioth/src/vm.rs +++ b/alioth/src/vm.rs @@ -12,23 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#[cfg(target_os = "linux")] -use std::collections::HashMap; #[cfg(target_os = "linux")] use std::path::Path; -use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::mpsc::{self, Receiver, Sender}; use std::sync::Arc; use std::thread; +use std::time::Duration; -use parking_lot::{Condvar, Mutex, RwLock}; +#[cfg(target_os = "linux")] +use parking_lot::Mutex; use snafu::{ResultExt, Snafu}; #[cfg(target_arch = "aarch64")] use crate::arch::layout::PL011_START; #[cfg(target_arch = "x86_64")] use crate::arch::layout::{PORT_COM1, PORT_FW_CFG_SELECTOR}; -use crate::board::{ArchBoard, Board, BoardConfig, STATE_CREATED, STATE_RUNNING}; +use crate::board::{ArchBoard, Board, BoardConfig}; #[cfg(target_arch = "x86_64")] use crate::device::fw_cfg::{FwCfg, FwCfgItemParam}; #[cfg(target_arch = "aarch64")] @@ -44,7 +43,6 @@ use crate::loader::Payload; use crate::mem::Memory; #[cfg(target_arch = "aarch64")] use crate::mem::{MemRegion, MemRegionType}; -use crate::pci::bus::PciBus; use crate::pci::{Bdf, PciDevice}; #[cfg(target_os = "linux")] use crate::vfio::bindings::VfioIommu; @@ -133,26 +131,7 @@ where let memory = Memory::new(vm_memory); let arch = ArchBoard::new(&hv, &vm, &config)?; - let board = Arc::new(Board { - vm, - memory, - arch, - config, - state: AtomicU8::new(STATE_CREATED), - payload: RwLock::new(None), - vcpus: Arc::new(RwLock::new(Vec::new())), - mp_sync: Arc::new((Mutex::new(0), Condvar::new())), - io_devs: RwLock::new(Vec::new()), - #[cfg(target_arch = "aarch64")] - mmio_devs: RwLock::new(Vec::new()), - pci_bus: PciBus::new(), - #[cfg(target_arch = "x86_64")] - fw_cfg: Mutex::new(None), - #[cfg(target_os = "linux")] - vfio_ioases: Mutex::new(HashMap::new()), - #[cfg(target_os = "linux")] - vfio_containers: Mutex::new(HashMap::new()), - }); + let board = Arc::new(Board::new(vm, memory, arch, config)); let (event_tx, event_rx) = mpsc::channel(); @@ -165,7 +144,10 @@ where .name(format!("vcpu_{}", vcpu_id)) .spawn(move || board.run_vcpu(vcpu_id, event_tx, boot_rx)) .context(error::VcpuThread { id: vcpu_id })?; - event_rx.recv().unwrap(); + if event_rx.recv_timeout(Duration::from_secs(2)).is_err() { + let err = std::io::ErrorKind::TimedOut.into(); + Err(err).context(error::VcpuThread { id: vcpu_id })?; + } vcpus.push((handle, boot_tx)); } drop(vcpus); @@ -287,15 +269,11 @@ where } pub fn boot(&self) -> Result<(), Error> { - let vcpus = self.board.vcpus.read(); - self.board.state.store(STATE_RUNNING, Ordering::Release); - for (_, boot_tx) in vcpus.iter() { - boot_tx.send(()).unwrap(); - } + self.board.boot()?; Ok(()) } - pub fn wait(&self) -> Vec> { + pub fn wait(&self) -> Result<()> { self.event_rx.recv().unwrap(); let vcpus = self.board.vcpus.read(); for _ in 1..vcpus.len() { @@ -303,17 +281,17 @@ where } drop(vcpus); let mut vcpus = self.board.vcpus.write(); - vcpus - .drain(..) - .enumerate() - .map(|(id, (handle, _))| match handle.join() { - Err(e) => { - log::error!("cannot join vcpu {}: {:?}", id, e); - Ok(()) - } - Ok(r) => r.context(error::Vcpu { id: id as u32 }), - }) - .collect() + let mut ret = Ok(()); + for (id, (handle, _)) in vcpus.drain(..).enumerate() { + let Ok(r) = handle.join() else { + log::error!("Cannot join VCPU-{id}"); + continue; + }; + if r.is_err() && ret.is_ok() { + ret = r.context(error::Vcpu { id: id as u32 }); + } + } + ret } }