From 2fed83adf821f6392b9aaabe51eb516ad269d569 Mon Sep 17 00:00:00 2001 From: Changyuan Lyu Date: Sat, 19 Apr 2025 12:28:10 -0700 Subject: [PATCH 01/10] refactor(virtio): change DEVICE_ID to a function Signed-off-by: Changyuan Lyu --- alioth/src/virtio/dev/balloon.rs | 4 +++- alioth/src/virtio/dev/blk.rs | 4 +++- alioth/src/virtio/dev/dev.rs | 4 ++-- alioth/src/virtio/dev/entropy.rs | 4 +++- alioth/src/virtio/dev/fs.rs | 4 +++- alioth/src/virtio/dev/net/net.rs | 4 +++- alioth/src/virtio/dev/vsock/vhost_vsock.rs | 4 +++- 7 files changed, 20 insertions(+), 8 deletions(-) diff --git a/alioth/src/virtio/dev/balloon.rs b/alioth/src/virtio/dev/balloon.rs index b616db94..bdb3a347 100644 --- a/alioth/src/virtio/dev/balloon.rs +++ b/alioth/src/virtio/dev/balloon.rs @@ -193,7 +193,9 @@ impl Virtio for Balloon { type Config = BalloonConfigMmio; type Feature = BalloonFeature; - const DEVICE_ID: DeviceId = DeviceId::Balloon; + fn id(&self) -> DeviceId { + DeviceId::Balloon + } fn name(&self) -> &str { &self.name diff --git a/alioth/src/virtio/dev/blk.rs b/alioth/src/virtio/dev/blk.rs index 89a558d9..26b27315 100644 --- a/alioth/src/virtio/dev/blk.rs +++ b/alioth/src/virtio/dev/blk.rs @@ -278,7 +278,9 @@ impl Virtio for Block { type Config = BlockConfig; type Feature = BlockFeature; - const DEVICE_ID: DeviceId = DeviceId::Block; + fn id(&self) -> DeviceId { + DeviceId::Block + } fn name(&self) -> &str { &self.name diff --git a/alioth/src/virtio/dev/dev.rs b/alioth/src/virtio/dev/dev.rs index 4ceebcf9..13dfde9d 100644 --- a/alioth/src/virtio/dev/dev.rs +++ b/alioth/src/virtio/dev/dev.rs @@ -43,11 +43,11 @@ use crate::virtio::worker::Waker; use crate::virtio::{DeviceId, IrqSender, Result, VirtioFeature, error}; pub trait Virtio: Debug + Send + Sync + 'static { - const DEVICE_ID: DeviceId; type Config: Mmio; type Feature: Flags + Debug; fn name(&self) -> &str; + fn id(&self) -> DeviceId; fn num_queues(&self) -> u16; fn config(&self) -> Arc; fn feature(&self) -> u64; @@ -167,7 +167,7 
@@ where D: Virtio, { let name = name.into(); - let id = D::DEVICE_ID; + let id = dev.id(); let device_config = dev.config(); let mut device_feature = dev.feature(); if restricted_memory { diff --git a/alioth/src/virtio/dev/entropy.rs b/alioth/src/virtio/dev/entropy.rs index b433c913..e1c38210 100644 --- a/alioth/src/virtio/dev/entropy.rs +++ b/alioth/src/virtio/dev/entropy.rs @@ -83,7 +83,9 @@ impl Virtio for Entropy { type Config = EntropyConfig; type Feature = EntropyFeature; - const DEVICE_ID: DeviceId = DeviceId::Entropy; + fn id(&self) -> DeviceId { + DeviceId::Entropy + } fn name(&self) -> &str { &self.name diff --git a/alioth/src/virtio/dev/fs.rs b/alioth/src/virtio/dev/fs.rs index 3610f589..03148382 100644 --- a/alioth/src/virtio/dev/fs.rs +++ b/alioth/src/virtio/dev/fs.rs @@ -185,7 +185,9 @@ impl Virtio for VuFs { type Config = FsConfig; type Feature = FsFeature; - const DEVICE_ID: DeviceId = DeviceId::FileSystem; + fn id(&self) -> DeviceId { + DeviceId::FileSystem + } fn name(&self) -> &str { &self.name diff --git a/alioth/src/virtio/dev/net/net.rs b/alioth/src/virtio/dev/net/net.rs index 077303cc..875a6037 100644 --- a/alioth/src/virtio/dev/net/net.rs +++ b/alioth/src/virtio/dev/net/net.rs @@ -293,7 +293,9 @@ impl Virtio for Net { type Config = NetConfig; type Feature = NetFeature; - const DEVICE_ID: DeviceId = DeviceId::Net; + fn id(&self) -> DeviceId { + DeviceId::Net + } fn name(&self) -> &str { &self.name diff --git a/alioth/src/virtio/dev/vsock/vhost_vsock.rs b/alioth/src/virtio/dev/vsock/vhost_vsock.rs index a441b508..87afd014 100644 --- a/alioth/src/virtio/dev/vsock/vhost_vsock.rs +++ b/alioth/src/virtio/dev/vsock/vhost_vsock.rs @@ -103,7 +103,9 @@ impl Virtio for VhostVsock { type Config = VsockConfig; type Feature = VsockFeature; - const DEVICE_ID: DeviceId = DeviceId::Socket; + fn id(&self) -> DeviceId { + DeviceId::Socket + } fn name(&self) -> &str { &self.name From 853f6549bce0e0b0821c3fe5d5cce4da8c6d7484 Mon Sep 17 00:00:00 2001 From: 
Changyuan Lyu Date: Sun, 13 Apr 2025 17:39:58 -0700 Subject: [PATCH 02/10] refactor(utils): send/receive FDs over Unix domain sockets Signed-off-by: Changyuan Lyu --- alioth/src/utils/uds.rs | 133 ++++++++++++++++++++++++++++++++++++++ alioth/src/utils/utils.rs | 2 + alioth/src/virtio/vu.rs | 77 +--------------------- 3 files changed, 138 insertions(+), 74 deletions(-) create mode 100644 alioth/src/utils/uds.rs diff --git a/alioth/src/utils/uds.rs b/alioth/src/utils/uds.rs new file mode 100644 index 00000000..bdf316e6 --- /dev/null +++ b/alioth/src/utils/uds.rs @@ -0,0 +1,133 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::io::{ErrorKind, IoSlice, IoSliceMut, Result}; +use std::iter::zip; +use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::ptr::{null_mut, read_unaligned, write_unaligned}; + +use crate::ffi; + +pub const UDS_MAX_FD: usize = 32; + +const CMSG_BUF_LEN: usize = + unsafe { libc::CMSG_SPACE((UDS_MAX_FD * size_of::()) as u32) } as usize; + +pub fn recv_msg_with_fds( + conn: &UnixStream, + bufs: &mut [IoSliceMut], + fds: &mut [Option], +) -> Result { + let mut cmsg_buf = [0u64; CMSG_BUF_LEN / size_of::()]; + let mut uds_msg = libc::msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: bufs.as_mut_ptr() as _, + msg_iovlen: bufs.len() as _, + msg_control: cmsg_buf.as_mut_ptr() as _, + msg_controllen: CMSG_BUF_LEN as _, + msg_flags: 0, + }; + let flag = libc::MSG_CMSG_CLOEXEC; + let size = ffi!(unsafe { libc::recvmsg(conn.as_raw_fd(), &mut uds_msg, flag) })?; + + if size == 0 { + let buffer_size = bufs.iter().map(|b| b.len()).sum::(); + let err = if buffer_size == 0 { + ErrorKind::InvalidInput + } else { + ErrorKind::ConnectionAborted + }; + return Err(err.into()); + } + + if uds_msg.msg_flags & libc::MSG_CTRUNC > 0 { + return Err(ErrorKind::OutOfMemory.into()); + } + + let mut overflow = false; + let mut cmsg_ptr = unsafe { libc::CMSG_FIRSTHDR(&uds_msg) }; + let mut iter = fds.iter_mut(); + while !cmsg_ptr.is_null() { + let cmsg = unsafe { read_unaligned(cmsg_ptr) }; + if cmsg.cmsg_level != libc::SOL_SOCKET || cmsg.cmsg_type != libc::SCM_RIGHTS { + continue; + } + + let cmsg_data_ptr = unsafe { libc::CMSG_DATA(cmsg_ptr) } as *const RawFd; + for i in 0.. 
{ + let len = unsafe { libc::CMSG_LEN((size_of::() * (i + 1)) as u32) }; + if len > cmsg.cmsg_len as u32 { + break; + } + + let raw_fd = unsafe { read_unaligned(cmsg_data_ptr.add(i)) }; + let owned_fd = unsafe { OwnedFd::from_raw_fd(raw_fd) }; + if let Some(fd) = iter.next() { + *fd = Some(owned_fd); + } else { + overflow = true; + } + } + cmsg_ptr = unsafe { libc::CMSG_NXTHDR(&uds_msg, cmsg_ptr) }; + } + + if overflow { + Err(ErrorKind::OutOfMemory.into()) + } else { + Ok(size as usize) + } +} + +pub fn send_msg_with_fds(conn: &UnixStream, bufs: &[IoSlice], fds: &[RawFd]) -> Result { + if fds.len() > UDS_MAX_FD { + return Err(ErrorKind::OutOfMemory.into()); + } + + let mut raw_fds = [0; UDS_MAX_FD]; + for (raw_fd, fd) in zip(&mut raw_fds, fds) { + *raw_fd = fd.as_raw_fd(); + } + let fds_size = size_of_val(fds) as u32; + let buf_len = if fds_size > 0 { + unsafe { libc::CMSG_SPACE(fds_size) } + } else { + 0 + } as usize; + let mut cmsg_buf = [0u64; CMSG_BUF_LEN / size_of::()]; + let uds_msg = libc::msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: bufs.as_ptr() as _, + msg_iovlen: bufs.len() as _, + msg_control: cmsg_buf.as_mut_ptr() as _, + msg_controllen: buf_len as _, + msg_flags: 0, + }; + if fds_size > 0 { + let cmsg = libc::cmsghdr { + cmsg_level: libc::SOL_SOCKET, + cmsg_type: libc::SCM_RIGHTS, + cmsg_len: unsafe { libc::CMSG_LEN(fds_size) } as _, + }; + let cmsg_ptr = unsafe { libc::CMSG_FIRSTHDR(&uds_msg) }; + unsafe { + write_unaligned(cmsg_ptr, cmsg); + write_unaligned(libc::CMSG_DATA(cmsg_ptr) as *mut _, raw_fds); + } + } + let size = ffi!(unsafe { libc::sendmsg(conn.as_raw_fd(), &uds_msg, 0) })?; + Ok(size as usize) +} diff --git a/alioth/src/utils/utils.rs b/alioth/src/utils/utils.rs index ad238931..ced19dfc 100644 --- a/alioth/src/utils/utils.rs +++ b/alioth/src/utils/utils.rs @@ -15,6 +15,8 @@ pub mod endian; #[cfg(target_os = "linux")] pub mod ioctls; +#[cfg(target_os = "linux")] +pub mod uds; use std::sync::atomic::{AtomicU64, 
Ordering}; diff --git a/alioth/src/virtio/vu.rs b/alioth/src/virtio/vu.rs index ee5176bc..50e99930 100644 --- a/alioth/src/virtio/vu.rs +++ b/alioth/src/virtio/vu.rs @@ -13,12 +13,10 @@ // limitations under the License. use std::io::{IoSlice, IoSliceMut, Read, Write}; -use std::iter::zip; use std::mem::{size_of, size_of_val}; use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}; use std::os::unix::net::UnixStream; use std::path::{Path, PathBuf}; -use std::ptr::null_mut; use std::sync::Arc; use bitfield::bitfield; @@ -30,6 +28,7 @@ use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes}; use crate::errors::{BoxTrace, DebugTrace, trace_error}; use crate::mem::LayoutChanged; use crate::mem::mapped::ArcMemPages; +use crate::utils::uds::{recv_msg_with_fds, send_msg_with_fds}; use crate::{ffi, mem}; bitflags! { @@ -211,8 +210,6 @@ pub enum Error { DeviceFeature { feature: u64 }, #[snafu(display("vhost-user backend is missing protocol feature {feature:x?}"))] ProtocolFeature { feature: VuFeature }, - #[snafu(display("Insufficient buffer (len {len}) for holding {need} fds"))] - InsufficientBuffer { len: usize, need: usize }, } type Result = std::result::Result; @@ -268,40 +265,8 @@ impl VuDev { IoSlice::new(vhost_msg.as_bytes()), IoSlice::new(payload.as_bytes()), ]; - let fd_size = size_of_val(fds); - let mut cmsg_buf = if fds.is_empty() { - vec![] - } else { - vec![0u8; unsafe { libc::CMSG_SPACE(fd_size as _) } as _] - }; - let uds_msg = libc::msghdr { - msg_name: null_mut(), - msg_namelen: 0, - msg_iov: bufs.as_ptr() as _, - msg_iovlen: if size_of::() == 0 { 1 } else { 2 }, - msg_control: if fds.is_empty() { - null_mut() - } else { - cmsg_buf.as_mut_ptr() as _ - }, - msg_controllen: cmsg_buf.len(), - msg_flags: 0, - }; - if !fds.is_empty() { - let cmsg_ptr = unsafe { libc::CMSG_FIRSTHDR(&uds_msg) }; - let cmsg = libc::cmsghdr { - cmsg_level: libc::SOL_SOCKET, - cmsg_type: libc::SCM_RIGHTS, - cmsg_len: unsafe { libc::CMSG_LEN(fd_size as _) } as _, - }; - unsafe { 
std::ptr::write_unaligned(cmsg_ptr, cmsg) }; - let data = - unsafe { std::slice::from_raw_parts_mut(libc::CMSG_DATA(cmsg_ptr), fd_size) }; - data.copy_from_slice(fds.as_bytes()); - } - let mut conn = self.conn.lock(); - ffi!(unsafe { libc::sendmsg(conn.as_raw_fd(), &uds_msg, 0) })?; + send_msg_with_fds(&conn, &bufs, fds)?; let mut resp = Message::new_zeroed(); let mut payload = R::new_zeroed(); @@ -447,25 +412,13 @@ impl VuDev { ) -> Result<(u32, u32)> { let mut msg = Message::new_zeroed(); let mut bufs = [IoSliceMut::new(msg.as_mut_bytes()), IoSliceMut::new(buf)]; - const CMSG_BUF_LEN: usize = unsafe { libc::CMSG_SPACE(8) } as usize; - debug_assert_eq!(CMSG_BUF_LEN % size_of::(), 0); - let mut cmsg_buf = [0u64; CMSG_BUF_LEN / size_of::()]; - let mut uds_msg = libc::msghdr { - msg_name: null_mut(), - msg_namelen: 0, - msg_iov: bufs.as_mut_ptr() as _, - msg_iovlen: bufs.len(), - msg_control: cmsg_buf.as_mut_ptr() as _, - msg_controllen: CMSG_BUF_LEN, - msg_flags: 0, - }; let Some(channel) = &self.channel else { return error::ProtocolFeature { feature: VuFeature::BACKEND_REQ, } .fail(); }; - let r_size = ffi!(unsafe { libc::recvmsg(channel.as_raw_fd(), &mut uds_msg, 0) })? 
as usize; + let r_size = recv_msg_with_fds(channel, &mut bufs, fds)?; let expected_size = size_of::() + msg.size as usize; if r_size != expected_size { return error::MsgSize { @@ -474,30 +427,6 @@ impl VuDev { } .fail(); } - - let cmsg_ptr = unsafe { libc::CMSG_FIRSTHDR(&uds_msg) }; - if cmsg_ptr.is_null() { - return Ok((msg.request, msg.size)); - } - let cmsg = unsafe { &*cmsg_ptr }; - if cmsg.cmsg_level != libc::SOL_SOCKET || cmsg.cmsg_type != libc::SCM_RIGHTS { - return Ok((msg.request, msg.size)); - } - let cmsg_data_ptr = unsafe { libc::CMSG_DATA(cmsg_ptr) } as *const RawFd; - let count = - (cmsg_ptr as usize + cmsg.cmsg_len - cmsg_data_ptr as usize) / size_of::(); - if count > fds.len() { - return error::InsufficientBuffer { - len: fds.len(), - need: count, - } - .fail(); - } - for (fd, index) in zip(fds.iter_mut(), 0..count) { - *fd = Some(unsafe { - OwnedFd::from_raw_fd(std::ptr::read_unaligned(cmsg_data_ptr.add(index))) - }); - } Ok((msg.request, msg.size)) } From f55b39a384605f1f80e98b8d1c57bf8a29a1fae0 Mon Sep 17 00:00:00 2001 From: Changyuan Lyu Date: Sun, 13 Apr 2025 18:27:06 -0700 Subject: [PATCH 03/10] fix: replace RawFd with BorrowedFd in APIs Signed-off-by: Changyuan Lyu --- alioth/src/hv/kvm/device.rs | 10 +-- alioth/src/hv/kvm/vcpu/aarch64.rs | 2 +- alioth/src/hv/kvm/vm/aarch64.rs | 6 +- alioth/src/hv/kvm/vm/vm.rs | 91 +++++++++----------- alioth/src/hv/kvm/vm/x86_64.rs | 2 +- alioth/src/hv/kvm/x86_64.rs | 21 ++--- alioth/src/utils/ioctls.rs | 96 +++++++++++++++++----- alioth/src/utils/uds.rs | 4 +- alioth/src/virtio/dev/fs.rs | 13 +-- alioth/src/virtio/dev/net/net.rs | 4 +- alioth/src/virtio/dev/vsock/vhost_vsock.rs | 9 +- alioth/src/virtio/pci.rs | 27 ++++-- alioth/src/virtio/virtio.rs | 10 ++- alioth/src/virtio/vu.rs | 18 ++-- 14 files changed, 184 insertions(+), 129 deletions(-) diff --git a/alioth/src/hv/kvm/device.rs b/alioth/src/hv/kvm/device.rs index a6c62783..832b05eb 100644 --- a/alioth/src/hv/kvm/device.rs +++ 
b/alioth/src/hv/kvm/device.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::mem::{MaybeUninit, size_of_val}; -use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd}; +use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd}; use snafu::ResultExt; @@ -26,7 +26,7 @@ use crate::hv::{KvmError, Result}; pub(super) struct KvmDevice(pub OwnedFd); impl KvmDevice { - pub fn new(vm_fd: &impl AsRawFd, type_: KvmDevType) -> Result { + pub fn new(vm_fd: &impl AsFd, type_: KvmDevType) -> Result { let mut create_device = KvmCreateDevice { type_, fd: 0, @@ -44,12 +44,6 @@ impl AsFd for KvmDevice { } } -impl AsRawFd for KvmDevice { - fn as_raw_fd(&self) -> RawFd { - self.0.as_raw_fd() - } -} - impl KvmDevice { pub fn set_attr(&self, group: u32, attr: u64, val: &T) -> Result<(), KvmError> { let attr = KvmDeviceAttr { diff --git a/alioth/src/hv/kvm/vcpu/aarch64.rs b/alioth/src/hv/kvm/vcpu/aarch64.rs index 7c77d786..100d56d7 100644 --- a/alioth/src/hv/kvm/vcpu/aarch64.rs +++ b/alioth/src/hv/kvm/vcpu/aarch64.rs @@ -33,7 +33,7 @@ const fn encode_system_reg(reg: SReg) -> u64 { impl KvmVcpu { pub fn kvm_vcpu_init(&self, is_bsp: bool) -> Result<()> { let mut arm_cpu_init = - unsafe { kvm_arm_preferred_target(&self.vm) }.context(error::CreateVcpu)?; + unsafe { kvm_arm_preferred_target(&self.vm.fd) }.context(error::CreateVcpu)?; if self.vm.check_extension(KvmCap::ARM_PSCI_0_2)? 
== 1 { arm_cpu_init.features[0] |= KvmArmVcpuFeature::PSCI_0_2.bits(); } diff --git a/alioth/src/hv/kvm/vm/aarch64.rs b/alioth/src/hv/kvm/vm/aarch64.rs index 53100b72..c62a3874 100644 --- a/alioth/src/hv/kvm/vm/aarch64.rs +++ b/alioth/src/hv/kvm/vm/aarch64.rs @@ -79,7 +79,7 @@ impl KvmVm { distributor_base: u64, cpu_interface_base: u64, ) -> Result { - let dev = KvmDevice::new(&self.vm, KvmDevType::ARM_VGIC_V2)?; + let dev = KvmDevice::new(&self.vm.fd, KvmDevType::ARM_VGIC_V2)?; let gic = KvmGicV2 { dev }; gic.dev.set_attr( KvmDevArmVgicGrp::ADDR.raw(), @@ -118,7 +118,7 @@ impl KvmVm { redistributor_base: u64, redistributor_count: u32, ) -> Result { - let dev = KvmDevice::new(&self.vm, KvmDevType::ARM_VGIC_V3)?; + let dev = KvmDevice::new(&self.vm.fd, KvmDevType::ARM_VGIC_V3)?; dev.set_attr( KvmDevArmVgicGrp::ADDR.raw(), KvmVgicAddrType::DIST_V3.raw(), @@ -153,7 +153,7 @@ impl Its for KvmIts { impl KvmVm { pub fn kvm_create_its(&self, base: u64) -> Result { - let dev = KvmDevice::new(&self.vm, KvmDevType::ARM_ITS)?; + let dev = KvmDevice::new(&self.vm.fd, KvmDevType::ARM_ITS)?; dev.set_attr( KvmDevArmVgicGrp::ADDR.raw(), KvmVgicAddrType::ITS.raw(), diff --git a/alioth/src/hv/kvm/vm/vm.rs b/alioth/src/hv/kvm/vm/vm.rs index 80102edb..083e177d 100644 --- a/alioth/src/hv/kvm/vm/vm.rs +++ b/alioth/src/hv/kvm/vm/vm.rs @@ -18,8 +18,9 @@ mod aarch64; mod x86_64; use std::collections::HashMap; +use std::fmt::{self, Display, Formatter}; use std::io::ErrorKind; -use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd}; +use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd}; use std::os::unix::thread::JoinHandleExt; use std::sync::Arc; use std::sync::atomic::{AtomicU32, Ordering}; @@ -113,17 +114,13 @@ impl VmInner { _flags: 0, entries, }; - log::trace!( - "vm-{}: updating GSI routing table to {:#x?}", - self.as_raw_fd(), - irq_routing - ); - unsafe { kvm_set_gsi_routing(self, &irq_routing) }.context(kvm_error::GsiRouting)?; + log::trace!("{self}: 
updating GSI routing table to {irq_routing:#x?}"); + unsafe { kvm_set_gsi_routing(&self.fd, &irq_routing) }.context(kvm_error::GsiRouting)?; Ok(()) } pub fn check_extension(&self, id: KvmCap) -> Result { - let ret = unsafe { kvm_check_extension(self, id) }; + let ret = unsafe { kvm_check_extension(&self.fd, id) }; match ret { Ok(num) => Ok(num), Err(_) => error::Capability { @@ -134,9 +131,9 @@ impl VmInner { } } -impl AsRawFd for VmInner { - fn as_raw_fd(&self) -> RawFd { - self.fd.as_raw_fd() +impl Display for VmInner { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "kvm-{}", self.fd.as_raw_fd()) } } @@ -171,11 +168,11 @@ impl KvmMemory { userspace_addr: 0, flags, }; - unsafe { kvm_set_user_memory_region(&self.vm, ®ion) } + unsafe { kvm_set_user_memory_region(&self.vm.fd, ®ion) } .context(error::GuestUnmap { gpa, size })?; log::trace!( - "vm-{}: slot-{slot}: unmapped: {gpa:#018x}, size={size:#x}", - self.vm.as_raw_fd() + "{}: slot-{slot}: unmapped: {gpa:#018x}, size={size:#x}", + self.vm ); Ok(()) } @@ -206,7 +203,7 @@ impl VmMemory for KvmMemory { guest_memfd_offset: gpa, ..Default::default() }; - unsafe { kvm_set_user_memory_region2(&self.vm, ®ion) } + unsafe { kvm_set_user_memory_region2(&self.vm.fd, ®ion) } } else { let region = KvmUserspaceMemoryRegion { slot: *slot_id, @@ -215,13 +212,13 @@ impl VmMemory for KvmMemory { userspace_addr: hva as _, flags, }; - unsafe { kvm_set_user_memory_region(&self.vm, ®ion) } + unsafe { kvm_set_user_memory_region(&self.vm.fd, ®ion) } } .context(error::GuestMap { hva, gpa, size })?; slots.insert((gpa, size), *slot_id); log::trace!( - "vm-{}: slot-{slot_id}: mapped: {gpa:#018x} -> {hva:#018x}, size = {size:#x}", - self.vm.as_raw_fd() + "{}: slot-{slot_id}: mapped: {gpa:#018x} -> {hva:#018x}, size = {size:#x}", + self.vm ); *slot_id += 1; Ok(()) @@ -240,7 +237,7 @@ impl VmMemory for KvmMemory { addr: range.as_ptr() as u64, size: range.len() as u64, }; - unsafe { kvm_memory_encrypt_reg_region(&self.vm, 
®ion) } + unsafe { kvm_memory_encrypt_reg_region(&self.vm.fd, ®ion) } .context(error::EncryptedRegion)?; Ok(()) } @@ -250,7 +247,7 @@ impl VmMemory for KvmMemory { addr: range.as_ptr() as u64, size: range.len() as u64, }; - unsafe { kvm_memory_encrypt_unreg_region(&self.vm, ®ion) } + unsafe { kvm_memory_encrypt_unreg_region(&self.vm.fd, ®ion) } .context(error::EncryptedRegion)?; Ok(()) } @@ -266,7 +263,7 @@ impl VmMemory for KvmMemory { }, flags: 0, }; - unsafe { kvm_set_memory_attributes(&self.vm, &attr) }.context(error::EncryptedRegion)?; + unsafe { kvm_set_memory_attributes(&self.vm.fd, &attr) }.context(error::EncryptedRegion)?; Ok(()) } @@ -297,10 +294,10 @@ impl Drop for KvmIrqSender { flags: KvmIrqfdFlag::DEASSIGN, ..Default::default() }; - if let Err(e) = unsafe { kvm_irqfd(&self.vm, &request) } { + if let Err(e) = unsafe { kvm_irqfd(&self.vm.fd, &request) } { log::error!( - "vm-{}: removing irqfd {:#x}: {e}", - self.vm.as_raw_fd(), + "{}: removing irqfd {:#x}: {e}", + self.vm, self.event_fd.as_raw_fd(), ) } @@ -338,9 +335,9 @@ impl Drop for KvmIrqFd { let mut table = self.vm.msi_table.write(); let Some(entry) = table.remove(&self.gsi) else { log::error!( - "vm-{}: cannot find gsi {:#x} in the gsi table", + "{}: cannot find gsi {:#x} in the gsi table", + self.vm, self.gsi, - self.vm.as_raw_fd() ); return; }; @@ -349,8 +346,8 @@ impl Drop for KvmIrqFd { } if let Err(e) = self.deassign_irqfd() { log::error!( - "vm-{}: removing irqfd {:#x}: {e}", - self.vm.as_raw_fd(), + "{}: removing irqfd {:#x}: {e}", + self.vm, self.event_fd.as_raw_fd(), ) } @@ -370,10 +367,10 @@ impl KvmIrqFd { gsi: self.gsi, ..Default::default() }; - unsafe { kvm_irqfd(&self.vm, &request) }.context(error::IrqFd)?; + unsafe { kvm_irqfd(&self.vm.fd, &request) }.context(error::IrqFd)?; log::debug!( - "vm-{}: assigned: gsi {:#x} -> irqfd {:#x}", - self.vm.as_raw_fd(), + "{}: assigned: gsi {:#x} -> irqfd {:#x}", + self.vm, self.gsi, self.event_fd.as_raw_fd() ); @@ -387,10 +384,10 @@ impl 
KvmIrqFd { flags: KvmIrqfdFlag::DEASSIGN, ..Default::default() }; - unsafe { kvm_irqfd(&self.vm, &request) }.context(error::IrqFd)?; + unsafe { kvm_irqfd(&self.vm.fd, &request) }.context(error::IrqFd)?; log::debug!( - "vm-{}: de-assigned: gsi {:#x} -> irqfd {:#x}", - self.vm.as_raw_fd(), + "{}: de-assigned: gsi {:#x} -> irqfd {:#x}", + self.vm, self.gsi, self.event_fd.as_raw_fd() ); @@ -437,11 +434,7 @@ impl IrqFd for KvmIrqFd { fn get_masked(&self) -> bool { let table = self.vm.msi_table.read(); let Some(entry) = table.get(&self.gsi) else { - unreachable!( - "vm-{}: cannot find gsi {:#x}", - self.vm.as_raw_fd(), - self.gsi - ); + unreachable!("{}: cannot find gsi {:#x}", self.vm, self.gsi); }; entry.masked } @@ -449,11 +442,7 @@ impl IrqFd for KvmIrqFd { fn set_masked(&self, val: bool) -> Result { let mut table = self.vm.msi_table.write(); let Some(entry) = table.get_mut(&self.gsi) else { - unreachable!( - "vm-{}: cannot find gsi {:#x}", - self.vm.as_raw_fd(), - self.gsi - ); + unreachable!("{}: cannot find gsi {:#x}", self.vm, self.gsi); }; if entry.masked == val { return Ok(false); @@ -494,7 +483,7 @@ impl MsiSender for KvmMsiSender { flags: KvmMsiFlag::VALID_DEVID, ..Default::default() }; - unsafe { kvm_signal_msi(&self.vm, &kvm_msi) }.context(error::SendInterrupt)?; + unsafe { kvm_signal_msi(&self.vm.fd, &kvm_msi) }.context(error::SendInterrupt)?; Ok(()) } @@ -527,8 +516,8 @@ impl MsiSender for KvmMsiSender { return kvm_error::AllocateGsi.fail()?; }; log::debug!( - "vm-{}: allocated: gsi {gsi:#x} -> irqfd {:#x}", - self.vm.as_raw_fd(), + "{}: allocated: gsi {gsi:#x} -> irqfd {:#x}", + self.vm, event_fd.as_raw_fd() ); let entry = KvmIrqFd { @@ -580,7 +569,7 @@ impl IoeventFdRegistry for KvmIoeventFdRegistry { request.datamatch = data; request.flags |= KvmIoEventFdFlag::DATA_MATCH; } - unsafe { kvm_ioeventfd(&self.vm, &request) }.context(error::IoeventFd)?; + unsafe { kvm_ioeventfd(&self.vm.fd, &request) }.context(error::IoeventFd)?; let mut fds = 
self.vm.ioeventfds.lock(); fds.insert(request.fd, request); Ok(()) @@ -601,7 +590,7 @@ impl IoeventFdRegistry for KvmIoeventFdRegistry { let mut fds = self.vm.ioeventfds.lock(); if let Some(mut request) = fds.remove(&fd.as_fd().as_raw_fd()) { request.flags |= KvmIoEventFdFlag::DEASSIGN; - unsafe { kvm_ioeventfd(&self.vm, &request) }.context(error::IoeventFd)?; + unsafe { kvm_ioeventfd(&self.vm.fd, &request) }.context(error::IoeventFd)?; } Ok(()) } @@ -621,7 +610,7 @@ impl Vm for KvmVm { type Vcpu = KvmVcpu; fn create_vcpu(&self, id: u32) -> Result { - let vcpu_fd = unsafe { kvm_create_vcpu(&self.vm, id) }.context(error::CreateVcpu)?; + let vcpu_fd = unsafe { kvm_create_vcpu(&self.vm.fd, id) }.context(error::CreateVcpu)?; let kvm_run = unsafe { KvmRunBlock::new(vcpu_fd, self.vcpu_mmap_size) }?; Ok(KvmVcpu { fd: unsafe { OwnedFd::from_raw_fd(vcpu_fd) }, @@ -665,7 +654,7 @@ impl Vm for KvmVm { ..Default::default() }; self.vm.update_routing_table(&self.vm.msi_table.read())?; - unsafe { kvm_irqfd(&self.vm, &request) }.context(error::CreateIrq { pin })?; + unsafe { kvm_irqfd(&self.vm.fd, &request) }.context(error::CreateIrq { pin })?; Ok(KvmIrqSender { pin, vm: self.vm.clone(), diff --git a/alioth/src/hv/kvm/vm/x86_64.rs b/alioth/src/hv/kvm/vm/x86_64.rs index 536e55db..f4744c7a 100644 --- a/alioth/src/hv/kvm/vm/x86_64.rs +++ b/alioth/src/hv/kvm/vm/x86_64.rs @@ -48,7 +48,7 @@ impl KvmVm { id: cmd, error: 0, }; - unsafe { kvm_memory_encrypt_op(&self.vm, &mut req) }.context(kvm_error::SevCmd)?; + unsafe { kvm_memory_encrypt_op(&self.vm.fd, &mut req) }.context(kvm_error::SevCmd)?; Ok(()) } diff --git a/alioth/src/hv/kvm/x86_64.rs b/alioth/src/hv/kvm/x86_64.rs index d559586e..c2aa437a 100644 --- a/alioth/src/hv/kvm/x86_64.rs +++ b/alioth/src/hv/kvm/x86_64.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::os::fd::{AsRawFd, FromRawFd, OwnedFd}; +use std::os::fd::{FromRawFd, OwnedFd}; use snafu::ResultExt; @@ -79,10 +79,11 @@ impl Kvm { } } Some(Coco::AmdSnp { .. }) => { - let bitmap = unsafe { kvm_check_extension(&kvm_vm.vm, KvmCap::EXIT_HYPERCALL) } - .context(kvm_error::CheckExtension { - ext: "KVM_CAP_EXIT_HYPERCALL", - })?; + let bitmap = + unsafe { kvm_check_extension(&kvm_vm.vm.fd, KvmCap::EXIT_HYPERCALL) } + .context(kvm_error::CheckExtension { + ext: "KVM_CAP_EXIT_HYPERCALL", + })?; if bitmap != 0 { let request = KvmEnableCap { cap: KvmCap::EXIT_HYPERCALL, @@ -90,7 +91,7 @@ impl Kvm { flags: 0, pad: [0; 64], }; - unsafe { kvm_enable_cap(&kvm_vm.vm, &request) }.context( + unsafe { kvm_enable_cap(&kvm_vm.vm.fd, &request) }.context( kvm_error::EnableCap { cap: "KVM_CAP_EXIT_HYPERCALL", }, @@ -98,15 +99,15 @@ impl Kvm { } let mut init = KvmSevInit::default(); kvm_vm.sev_op(KVM_SEV_INIT2, Some(&mut init))?; - log::debug!("vm-{}: snp init: {init:#x?}", kvm_vm.vm.as_raw_fd()); + log::debug!("{}: snp init: {init:#x?}", kvm_vm.vm); } _ => {} } } - unsafe { kvm_create_irqchip(&kvm_vm.vm) }.context(error::CreateDevice)?; + unsafe { kvm_create_irqchip(&kvm_vm.vm.fd) }.context(error::CreateDevice)?; // TODO should be in parameters - unsafe { kvm_set_tss_addr(&kvm_vm.vm, 0xf000_0000) }.context(error::SetVmParam)?; - unsafe { kvm_set_identity_map_addr(&kvm_vm.vm, &0xf000_3000) } + unsafe { kvm_set_tss_addr(&kvm_vm.vm.fd, 0xf000_0000) }.context(error::SetVmParam)?; + unsafe { kvm_set_identity_map_addr(&kvm_vm.vm.fd, &0xf000_3000) } .context(error::SetVmParam)?; Ok(()) } diff --git a/alioth/src/utils/ioctls.rs b/alioth/src/utils/ioctls.rs index b3f9f0d0..128996e1 100644 --- a/alioth/src/utils/ioctls.rs +++ b/alioth/src/utils/ioctls.rs @@ -50,10 +50,12 @@ pub const fn ioctl_iowr(type_: u8, nr: u8) -> u32 { macro_rules! 
ioctl_none { ($name:ident, $type_:expr, $nr:expr, $val:expr) => { #[allow(clippy::missing_safety_doc)] - pub unsafe fn $name(fd: &F) -> ::std::io::Result { + pub unsafe fn $name(fd: &F) -> ::std::io::Result { let op = $crate::utils::ioctls::ioctl_io($type_, $nr); let v = $val as ::libc::c_ulong; - $crate::ffi!(unsafe { ::libc::ioctl(fd.as_raw_fd(), op as _, v,) }) + $crate::ffi!(unsafe { + ::libc::ioctl(::std::os::fd::AsRawFd::as_raw_fd(&fd.as_fd()), op as _, v) + }) } }; ($name:ident, $type_:expr, $nr:expr) => { @@ -65,22 +67,26 @@ macro_rules! ioctl_none { macro_rules! ioctl_write_val { ($name:ident, $code:expr) => { #[allow(clippy::missing_safety_doc)] - pub unsafe fn $name( + pub unsafe fn $name( fd: &F, val: ::libc::c_ulong, ) -> ::std::io::Result { let op = $code; - $crate::ffi!(unsafe { ::libc::ioctl(fd.as_raw_fd(), op as _, val) }) + $crate::ffi!(unsafe { + ::libc::ioctl(::std::os::fd::AsRawFd::as_raw_fd(&fd.as_fd()), op as _, val) + }) } }; ($name:ident, $code:expr, $ty:ty) => { #[allow(clippy::missing_safety_doc)] - pub unsafe fn $name( + pub unsafe fn $name( fd: &F, val: $ty, ) -> ::std::io::Result { let op = $code; - $crate::ffi!(unsafe { ::libc::ioctl(fd.as_raw_fd(), op as _, val) }) + $crate::ffi!(unsafe { + ::libc::ioctl(::std::os::fd::AsRawFd::as_raw_fd(&fd.as_fd()), op as _, val) + }) } }; } @@ -89,23 +95,35 @@ macro_rules! ioctl_write_val { macro_rules! 
ioctl_write_ptr { ($name:ident, $code:expr, $ty:ty) => { #[allow(clippy::missing_safety_doc)] - pub unsafe fn $name( + pub unsafe fn $name( fd: &F, val: &$ty, ) -> ::std::io::Result { let op = $code; - $crate::ffi!(unsafe { ::libc::ioctl(fd.as_raw_fd(), op as _, val as *const $ty) }) + $crate::ffi!(unsafe { + ::libc::ioctl( + ::std::os::fd::AsRawFd::as_raw_fd(&fd.as_fd()), + op as _, + val as *const $ty, + ) + }) } }; ($name:ident, $type_:expr, $nr:expr, $ty:ty) => { #[allow(clippy::missing_safety_doc)] - pub unsafe fn $name( + pub unsafe fn $name( fd: &F, val: &$ty, ) -> ::std::io::Result { let op = $crate::utils::ioctls::ioctl_iow::<$ty>($type_, $nr); - $crate::ffi!(unsafe { ::libc::ioctl(fd.as_raw_fd(), op as _, val as *const $ty,) }) + $crate::ffi!(unsafe { + ::libc::ioctl( + ::std::os::fd::AsRawFd::as_raw_fd(&fd.as_fd()), + op as _, + val as *const $ty, + ) + }) } }; } @@ -114,12 +132,18 @@ macro_rules! ioctl_write_ptr { macro_rules! ioctl_write_buf { ($name:ident, $code:expr, $ty:ident) => { #[allow(clippy::missing_safety_doc)] - pub unsafe fn $name( + pub unsafe fn $name( fd: &F, val: &$ty, ) -> ::std::io::Result { let op = $code; - $crate::ffi!(unsafe { ::libc::ioctl(fd.as_raw_fd(), op as _, val as *const $ty) }) + $crate::ffi!(unsafe { + ::libc::ioctl( + ::std::os::fd::AsRawFd::as_raw_fd(&fd.as_fd()), + op as _, + val as *const $ty, + ) + }) } }; ($name:ident, $type_:expr, $nr:expr, $ty:ident) => { @@ -135,12 +159,18 @@ macro_rules! ioctl_write_buf { macro_rules! 
ioctl_writeread { ($name:ident, $code:expr, $ty:ty) => { #[allow(clippy::missing_safety_doc)] - pub unsafe fn $name( + pub unsafe fn $name( fd: &F, val: &mut $ty, ) -> ::std::io::Result { let op = $code; - $crate::ffi!(unsafe { ::libc::ioctl(fd.as_raw_fd(), op as _, val as *mut $ty) }) + $crate::ffi!(unsafe { + ::libc::ioctl( + ::std::os::fd::AsRawFd::as_raw_fd(&fd.as_fd()), + op as _, + val as *mut $ty, + ) + }) } }; ($name:ident, $type_:expr, $nr:expr, $ty:ty) => { @@ -152,12 +182,18 @@ macro_rules! ioctl_writeread { }; ($name:ident, $code:expr) => { #[allow(clippy::missing_safety_doc)] - pub unsafe fn $name( + pub unsafe fn $name( fd: &F, val: &mut T, ) -> ::std::io::Result { let op = $code; - $crate::ffi!(unsafe { ::libc::ioctl(fd.as_raw_fd(), op as _, val as *mut T) }) + $crate::ffi!(unsafe { + ::libc::ioctl( + ::std::os::fd::AsRawFd::as_raw_fd(&fd.as_fd()), + op as _, + val as *mut T, + ) + }) } }; } @@ -166,12 +202,18 @@ macro_rules! ioctl_writeread { macro_rules! ioctl_writeread_buf { ($name:ident, $type_:expr, $nr:expr, $ty:ident) => { #[allow(clippy::missing_safety_doc)] - pub unsafe fn $name( + pub unsafe fn $name( fd: &F, val: &mut $ty, ) -> ::std::io::Result { let op = $crate::utils::ioctls::ioctl_iowr::<$ty<0>>($type_, $nr); - $crate::ffi!(unsafe { ::libc::ioctl(fd.as_raw_fd(), op as _, val as *mut $ty,) }) + $crate::ffi!(unsafe { + ::libc::ioctl( + ::std::os::fd::AsRawFd::as_raw_fd(&fd.as_fd()), + op as _, + val as *mut $ty, + ) + }) } }; } @@ -180,18 +222,28 @@ macro_rules! ioctl_writeread_buf { macro_rules! 
ioctl_read { ($name:ident, $code:expr, $ty:ty) => { #[allow(clippy::missing_safety_doc)] - pub unsafe fn $name(fd: &F) -> ::std::io::Result<$ty> { + pub unsafe fn $name(fd: &F) -> ::std::io::Result<$ty> { let mut val = ::core::mem::MaybeUninit::<$ty>::uninit(); - $crate::ffi!(::libc::ioctl(fd.as_raw_fd(), $code as _, val.as_mut_ptr()))?; + $crate::ffi!(::libc::ioctl( + ::std::os::fd::AsRawFd::as_raw_fd(&fd.as_fd()), + $code as _, + val.as_mut_ptr() + ))?; ::std::io::Result::Ok(val.assume_init()) } }; ($name:ident, $type_:expr, $nr:expr, $ty:ty) => { #[allow(clippy::missing_safety_doc)] - pub unsafe fn $name(fd: &F) -> ::std::io::Result<$ty> { + pub unsafe fn $name(fd: &F) -> ::std::io::Result<$ty> { let mut val = ::core::mem::MaybeUninit::<$ty>::uninit(); let op = $crate::utils::ioctls::ioctl_ior::<$ty>($type_, $nr); - $crate::ffi!(unsafe { ::libc::ioctl(fd.as_raw_fd(), op as _, val.as_mut_ptr(),) })?; + $crate::ffi!(unsafe { + ::libc::ioctl( + ::std::os::fd::AsRawFd::as_raw_fd(&fd.as_fd()), + op as _, + val.as_mut_ptr(), + ) + })?; ::std::io::Result::Ok(unsafe { val.assume_init() }) } }; diff --git a/alioth/src/utils/uds.rs b/alioth/src/utils/uds.rs index bdf316e6..10fa23bc 100644 --- a/alioth/src/utils/uds.rs +++ b/alioth/src/utils/uds.rs @@ -14,7 +14,7 @@ use std::io::{ErrorKind, IoSlice, IoSliceMut, Result}; use std::iter::zip; -use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}; +use std::os::fd::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd}; use std::os::unix::net::UnixStream; use std::ptr::{null_mut, read_unaligned, write_unaligned}; @@ -91,7 +91,7 @@ pub fn recv_msg_with_fds( } } -pub fn send_msg_with_fds(conn: &UnixStream, bufs: &[IoSlice], fds: &[RawFd]) -> Result { +pub fn send_msg_with_fds(conn: &UnixStream, bufs: &[IoSlice], fds: &[BorrowedFd]) -> Result { if fds.len() > UDS_MAX_FD { return Err(ErrorKind::OutOfMemory.into()); } diff --git a/alioth/src/virtio/dev/fs.rs b/alioth/src/virtio/dev/fs.rs index 03148382..f8415767 100644 --- 
a/alioth/src/virtio/dev/fs.rs +++ b/alioth/src/virtio/dev/fs.rs @@ -15,7 +15,7 @@ use std::io::ErrorKind; use std::iter::zip; use std::mem::size_of_val; -use std::os::fd::{AsRawFd, FromRawFd, OwnedFd}; +use std::os::fd::{AsFd, AsRawFd, FromRawFd, OwnedFd}; use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::Ordering; @@ -255,21 +255,22 @@ impl VirtioMio for VuFs { self.vu_dev .set_features(&(feature | VirtioFeature::VHOST_PROTOCOL.bits()))?; for (index, fd) in active_mio.ioeventfds.iter().enumerate() { - self.vu_dev - .set_virtq_kick(&(index as u64), fd.as_fd().as_raw_fd())?; + self.vu_dev.set_virtq_kick(&(index as u64), fd.as_fd())?; } for (index, queue) in active_mio.queues.iter().enumerate() { let Some(queue) = queue else { continue; }; let reg = queue.reg(); - let irq_fd = active_mio.irq_sender.queue_irqfd(index as _)?; - self.vu_dev.set_virtq_call(&(index as u64), irq_fd).unwrap(); + active_mio.irq_sender.queue_irqfd(index as _, |fd| { + self.vu_dev.set_virtq_call(&(index as u64), fd)?; + Ok(()) + })?; let err_fd = unsafe { OwnedFd::from_raw_fd(ffi!(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK))?) 
}; self.vu_dev - .set_virtq_err(&(index as u64), err_fd.as_raw_fd()) + .set_virtq_err(&(index as u64), err_fd.as_fd()) .unwrap(); active_mio.poll.registry().register( &mut SourceFd(&err_fd.as_raw_fd()), diff --git a/alioth/src/virtio/dev/net/net.rs b/alioth/src/virtio/dev/net/net.rs index 875a6037..6acf78fd 100644 --- a/alioth/src/virtio/dev/net/net.rs +++ b/alioth/src/virtio/dev/net/net.rs @@ -19,7 +19,7 @@ use std::fs::{File, OpenOptions}; use std::io::{ErrorKind, IoSlice}; use std::mem::MaybeUninit; use std::num::NonZeroU16; -use std::os::fd::AsRawFd; +use std::os::fd::{AsFd, AsRawFd}; use std::os::unix::prelude::OpenOptionsExt; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -538,7 +538,7 @@ fn setup_socket(file: &mut File, if_name: Option<&str>, mq: bool) -> Result<()> Ok(()) } -fn detect_tap_offload(tap: &impl AsRawFd) -> NetFeature { +fn detect_tap_offload(tap: &impl AsFd) -> NetFeature { let mut tap_feature = TunFeature::all(); let mut dev_feat = NetFeature::GUEST_CSUM | NetFeature::GUEST_TSO4 diff --git a/alioth/src/virtio/dev/vsock/vhost_vsock.rs b/alioth/src/virtio/dev/vsock/vhost_vsock.rs index 87afd014..85c2c8cc 100644 --- a/alioth/src/virtio/dev/vsock/vhost_vsock.rs +++ b/alioth/src/virtio/dev/vsock/vhost_vsock.rs @@ -175,8 +175,13 @@ impl VirtioMio for VhostVsock { }; let reg = queue.reg(); let index = index as u32; - let fd = active_mio.irq_sender.queue_irqfd(index as _)?; - self.vhost_dev.set_virtq_call(&VirtqFile { index, fd })?; + active_mio.irq_sender.queue_irqfd(index as _, |fd| { + self.vhost_dev.set_virtq_call(&VirtqFile { + index, + fd: fd.as_raw_fd(), + })?; + Ok(()) + })?; self.vhost_dev.set_virtq_num(&VirtqState { index, diff --git a/alioth/src/virtio/pci.rs b/alioth/src/virtio/pci.rs index e586e0a3..59f76195 100644 --- a/alioth/src/virtio/pci.rs +++ b/alioth/src/virtio/pci.rs @@ -14,7 +14,7 @@ use std::marker::PhantomData; use std::mem::size_of; -use std::os::fd::{AsFd, AsRawFd, RawFd}; +use std::os::fd::{AsFd, AsRawFd, 
BorrowedFd}; use std::sync::Arc; use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::mpsc::Sender; @@ -85,7 +85,10 @@ where } } - fn get_irqfd(&self, vector: u16) -> Result { + fn get_irqfd(&self, vector: u16, f: F) -> Result + where + F: FnOnce(BorrowedFd) -> Result, + { let mut entries = self.msix_table.entries.write(); let Some(entry) = entries.get_mut(vector as usize) else { return error::InvalidMsixVector { vector }.fail(); @@ -97,11 +100,11 @@ where irqfd.set_addr_lo(e.addr_lo)?; irqfd.set_data(e.data)?; irqfd.set_masked(e.control.masked())?; - let raw_fd = irqfd.as_fd().as_raw_fd(); + let r = f(irqfd.as_fd())?; *entry = MsixTableMmioEntry::IrqFd(irqfd); - Ok(raw_fd) + Ok(r) } - MsixTableMmioEntry::IrqFd(f) => Ok(f.as_fd().as_raw_fd()), + MsixTableMmioEntry::IrqFd(fd) => f(fd.as_fd()), } } } @@ -128,15 +131,21 @@ where } } - fn config_irqfd(&self) -> Result { - self.get_irqfd(self.msix_vector.config.load(Ordering::Acquire)) + fn config_irqfd(&self, f: F) -> Result + where + F: FnOnce(BorrowedFd) -> Result, + { + self.get_irqfd(self.msix_vector.config.load(Ordering::Acquire), f) } - fn queue_irqfd(&self, idx: u16) -> Result { + fn queue_irqfd(&self, idx: u16, f: F) -> Result + where + F: FnOnce(BorrowedFd) -> Result, + { let Some(vector) = self.msix_vector.queues.get(idx as usize) else { return error::InvalidQueueIndex { index: idx }.fail(); }; - self.get_irqfd(vector.load(Ordering::Acquire)) + self.get_irqfd(vector.load(Ordering::Acquire), f) } } diff --git a/alioth/src/virtio/virtio.rs b/alioth/src/virtio/virtio.rs index 9a0d59cf..058fd5d9 100644 --- a/alioth/src/virtio/virtio.rs +++ b/alioth/src/virtio/virtio.rs @@ -26,7 +26,7 @@ pub mod vu; pub mod worker; use std::fmt::Debug; -use std::os::fd::RawFd; +use std::os::fd::BorrowedFd; use std::path::PathBuf; use bitflags::bitflags; @@ -121,6 +121,10 @@ bitflags! 
{ pub trait IrqSender: Send + Sync + Debug + 'static { fn queue_irq(&self, idx: u16); fn config_irq(&self); - fn queue_irqfd(&self, idx: u16) -> Result; - fn config_irqfd(&self) -> Result; + fn queue_irqfd(&self, idx: u16, f: F) -> Result + where + F: FnOnce(BorrowedFd) -> Result; + fn config_irqfd(&self, f: F) -> Result + where + F: FnOnce(BorrowedFd) -> Result; } diff --git a/alioth/src/virtio/vu.rs b/alioth/src/virtio/vu.rs index 50e99930..8279ffdc 100644 --- a/alioth/src/virtio/vu.rs +++ b/alioth/src/virtio/vu.rs @@ -14,7 +14,7 @@ use std::io::{IoSlice, IoSliceMut, Read, Write}; use std::mem::{size_of, size_of_val}; -use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}; +use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd}; use std::os::unix::net::UnixStream; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -241,7 +241,7 @@ impl VuDev { })?; let channel = unsafe { UnixStream::from_raw_fd(socket_fds[0]) }; let peer = unsafe { OwnedFd::from_raw_fd(socket_fds[1]) }; - self.set_backend_req_fd(peer.as_raw_fd())?; + self.set_backend_req_fd(peer.as_fd())?; self.channel = Some(channel); Ok(()) } @@ -254,7 +254,7 @@ impl VuDev { &self, req: u32, payload: &T, - fds: &[RawFd], + fds: &[BorrowedFd], ) -> Result { let vhost_msg = Message { request: req, @@ -369,15 +369,15 @@ impl VuDev { self.send_msg(VHOST_USER_GET_QUEUE_NUM, &(), &[]) } - pub fn set_virtq_kick(&self, payload: &u64, fd: RawFd) -> Result<()> { + pub fn set_virtq_kick(&self, payload: &u64, fd: BorrowedFd) -> Result<()> { self.send_msg(VHOST_USER_SET_VIRTQ_KICK, payload, &[fd]) } - pub fn set_virtq_call(&self, payload: &u64, fd: RawFd) -> Result<()> { + pub fn set_virtq_call(&self, payload: &u64, fd: BorrowedFd) -> Result<()> { self.send_msg(VHOST_USER_SET_VIRTQ_CALL, payload, &[fd]) } - pub fn set_virtq_err(&self, payload: &u64, fd: RawFd) -> Result<()> { + pub fn set_virtq_err(&self, payload: &u64, fd: BorrowedFd) -> Result<()> { self.send_msg(VHOST_USER_SET_VIRTQ_ERR, payload, &[fd]) } 
@@ -393,7 +393,7 @@ impl VuDev { self.send_msg(VHOST_USER_GET_STATUS, &(), &[]) } - pub fn add_mem_region(&self, payload: &MemorySingleRegion, fd: RawFd) -> Result<()> { + pub fn add_mem_region(&self, payload: &MemorySingleRegion, fd: BorrowedFd) -> Result<()> { self.send_msg(VHOST_USER_ADD_MEM_REG, payload, &[fd]) } @@ -401,7 +401,7 @@ impl VuDev { self.send_msg(VHOST_USER_REM_MEM_REG, payload, &[]) } - fn set_backend_req_fd(&self, fd: RawFd) -> Result<()> { + fn set_backend_req_fd(&self, fd: BorrowedFd) -> Result<()> { self.send_msg(VHOST_USER_SET_BACKEND_REQ_FD, &0u64, &[fd]) } @@ -470,7 +470,7 @@ impl LayoutChanged for UpdateVuMem { mmap_offset: offset, }, }; - let ret = self.dev.add_mem_region(®ion, fd.as_raw_fd()); + let ret = self.dev.add_mem_region(®ion, fd); ret.box_trace(mem::error::ChangeLayout)?; log::trace!( "vu-{}: added memory region {:x?}", From 626c45d18e34397ee4cc13b8a3de5baa5a111903 Mon Sep 17 00:00:00 2001 From: Changyuan Lyu Date: Sun, 20 Apr 2025 22:44:35 -0700 Subject: [PATCH 04/10] refactor(vu): split VuDev into VuSession and VuChannel Signed-off-by: Changyuan Lyu --- alioth/src/virtio/dev/fs.rs | 89 +++--- alioth/src/virtio/virtio.rs | 1 + alioth/src/virtio/vu.rs | 505 ------------------------------- alioth/src/virtio/vu/bindings.rs | 207 +++++++++++++ alioth/src/virtio/vu/conn.rs | 295 ++++++++++++++++++ alioth/src/virtio/vu/frontend.rs | 75 +++++ alioth/src/virtio/vu/vu.rs | 57 ++++ 7 files changed, 685 insertions(+), 544 deletions(-) delete mode 100644 alioth/src/virtio/vu.rs create mode 100644 alioth/src/virtio/vu/bindings.rs create mode 100644 alioth/src/virtio/vu/conn.rs create mode 100644 alioth/src/virtio/vu/frontend.rs create mode 100644 alioth/src/virtio/vu/vu.rs diff --git a/alioth/src/virtio/dev/fs.rs b/alioth/src/virtio/dev/fs.rs index f8415767..3df4a5e1 100644 --- a/alioth/src/virtio/dev/fs.rs +++ b/alioth/src/virtio/dev/fs.rs @@ -39,9 +39,10 @@ use crate::mem::mapped::{ArcMemPages, RamBus}; use crate::mem::{LayoutChanged, 
MemRegion, MemRegionType}; use crate::virtio::dev::{DevParam, Virtio, WakeEvent}; use crate::virtio::queue::{Queue, VirtQueue}; -use crate::virtio::vu::{ - DeviceConfig, Error, UpdateVuMem, VirtqAddr, VirtqState, VuDev, VuFeature, error as vu_error, -}; +use crate::virtio::vu::bindings::{DeviceConfig, VirtqAddr, VirtqState, VuBackMsg, VuFeature}; +use crate::virtio::vu::conn::{VuChannel, VuSession}; +use crate::virtio::vu::frontend::UpdateVuMem; +use crate::virtio::vu::{Error, error as vu_error}; use crate::virtio::worker::Waker; use crate::virtio::worker::mio::{ActiveMio, Mio, VirtioMio}; use crate::virtio::{DeviceId, IrqSender, Result, VirtioFeature, error}; @@ -79,7 +80,8 @@ const VHOST_USER_BACKEND_FS_UNMAP: u32 = 7; #[derive(Debug)] pub struct VuFs { name: Arc, - vu_dev: Arc, + session: Arc, + channel: Option, config: Arc, feature: u64, num_queues: u16, @@ -90,8 +92,8 @@ pub struct VuFs { impl VuFs { pub fn new(param: VuFsParam, name: impl Into>) -> Result { let name = name.into(); - let mut vu_dev = VuDev::new(param.socket)?; - let dev_feat = vu_dev.get_features()?; + let session = Arc::new(VuSession::new(param.socket)?); + let dev_feat = session.get_features()?; let virtio_feat = VirtioFeature::from_bits_retain(dev_feat); let need_feat = VirtioFeature::VHOST_PROTOCOL | VirtioFeature::VERSION_1; if !virtio_feat.contains(need_feat) { @@ -101,7 +103,7 @@ impl VuFs { .fail()?; } - let prot_feat = VuFeature::from_bits_retain(vu_dev.get_protocol_features()?); + let prot_feat = VuFeature::from_bits_retain(session.get_protocol_features()?); log::debug!("{name}: vhost-user feat: {prot_feat:x?}"); let mut need_feat = VuFeature::MQ | VuFeature::REPLY_ACK | VuFeature::CONFIGURE_MEM_SLOTS; if param.tag.is_none() { @@ -117,10 +119,10 @@ impl VuFs { } .fail()?; } - vu_dev.set_protocol_features(&need_feat.bits())?; + session.set_protocol_features(&need_feat.bits())?; - vu_dev.set_owner()?; - let num_queues = vu_dev.get_queue_num()? 
as u16; + session.set_owner()?; + let num_queues = session.get_queue_num()? as u16; let config = if let Some(tag) = param.tag { assert!(tag.len() <= 36); assert_ne!(tag.len(), 0); @@ -134,21 +136,23 @@ impl VuFs { } else { let mut empty_cfg = DeviceConfig::new_zeroed(); empty_cfg.size = size_of_val(&empty_cfg.region) as _; - let dev_config = vu_dev.get_config(&empty_cfg)?; + let dev_config = session.get_config(&empty_cfg)?; FsConfig::read_from_prefix(&dev_config.region).unwrap().0 }; - let dax_region = if param.dax_window > 0 { - vu_dev.setup_channel()?; + let (dax_region, channel) = if param.dax_window > 0 { + let channel = session.create_channel()?; let size = align_up!(param.dax_window, 12); - Some(ArcMemPages::from_anonymous(size, Some(PROT_NONE), None)?) + let region = ArcMemPages::from_anonymous(size, Some(PROT_NONE), None)?; + (Some(region), Some(channel)) } else { - None + (None, None) }; Ok(VuFs { num_queues, name, - vu_dev: Arc::new(vu_dev), + session, + channel, config: Arc::new(config), feature: dev_feat & !VirtioFeature::VHOST_PROTOCOL.bits(), error_fds: Vec::new(), @@ -236,7 +240,8 @@ impl Virtio for VuFs { fn mem_change_callback(&self) -> Option> { Some(Box::new(UpdateVuMem { - dev: self.vu_dev.clone(), + name: self.name.clone(), + session: self.session.clone(), })) } } @@ -252,10 +257,10 @@ impl VirtioMio for VuFs { S: IrqSender, E: IoeventFd, { - self.vu_dev + self.session .set_features(&(feature | VirtioFeature::VHOST_PROTOCOL.bits()))?; for (index, fd) in active_mio.ioeventfds.iter().enumerate() { - self.vu_dev.set_virtq_kick(&(index as u64), fd.as_fd())?; + self.session.set_virtq_kick(&(index as u64), fd.as_fd())?; } for (index, queue) in active_mio.queues.iter().enumerate() { let Some(queue) = queue else { @@ -263,13 +268,13 @@ impl VirtioMio for VuFs { }; let reg = queue.reg(); active_mio.irq_sender.queue_irqfd(index as _, |fd| { - self.vu_dev.set_virtq_call(&(index as u64), fd)?; + self.session.set_virtq_call(&(index as u64), fd)?; Ok(()) 
})?; let err_fd = unsafe { OwnedFd::from_raw_fd(ffi!(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK))?) }; - self.vu_dev + self.session .set_virtq_err(&(index as u64), err_fd.as_fd()) .unwrap(); active_mio.poll.registry().register( @@ -283,14 +288,14 @@ impl VirtioMio for VuFs { index: index as _, val: reg.size.load(Ordering::Acquire) as _, }; - self.vu_dev.set_virtq_num(&virtq_num).unwrap(); + self.session.set_virtq_num(&virtq_num).unwrap(); log::info!("set_virtq_num: {virtq_num:x?}"); let virtq_base = VirtqState { index: index as _, val: 0, }; - self.vu_dev.set_virtq_base(&virtq_base).unwrap(); + self.session.set_virtq_base(&virtq_base).unwrap(); log::info!("set_virtq_base: {virtq_base:x?}"); let mem = active_mio.mem; @@ -302,7 +307,7 @@ impl VirtioMio for VuFs { avail_hva: mem.translate(reg.driver.load(Ordering::Acquire) as _)? as _, log_guest_addr: 0, }; - self.vu_dev.set_virtq_addr(&virtq_addr).unwrap(); + self.session.set_virtq_addr(&virtq_addr).unwrap(); log::info!("queue: {:x?}", reg); log::info!("virtq_addr: {virtq_addr:x?}"); } @@ -311,13 +316,13 @@ impl VirtioMio for VuFs { index: index as _, val: 1, }; - self.vu_dev.set_virtq_enable(&virtq_enable).unwrap(); + self.session.set_virtq_enable(&virtq_enable).unwrap(); log::info!("virtq_enable: {virtq_enable:x?}"); } - if let Some(channel) = self.vu_dev.get_channel() { - channel.set_nonblocking(true)?; + if let Some(channel) = &self.channel { + channel.conn.set_nonblocking(true)?; active_mio.poll.registry().register( - &mut SourceFd(&channel.as_raw_fd()), + &mut SourceFd(&channel.conn.as_raw_fd()), Token(self.num_queues as _), Interest::READABLE, )?; @@ -349,14 +354,19 @@ impl VirtioMio for VuFs { } .fail()?; }; + let Some(channel) = &self.channel else { + return vu_error::ProtocolFeature { + feature: VuFeature::BACKEND_REQ, + } + .fail()?; + }; loop { - let mut fs_map = VuFsMap::new_zeroed(); - let mut fds = [None, None, None, None, None, None, None, None]; - let ret = self - .vu_dev - 
.receive_from_channel(fs_map.as_mut_bytes(), &mut fds); - let (request, size) = match ret { - Ok((r, s)) => (r, s), + let mut fds = [const { None }; 8]; + let msg = channel.recv_msg(&mut fds); + let fs_map: VuFsMap = channel.recv_payload()?; + + let (request, size) = match msg { + Ok(m) => (m.request, m.size), Err(Error::System { error, .. }) if error.kind() == ErrorKind::WouldBlock => break, Err(e) => return Err(e)?, }; @@ -411,7 +421,7 @@ impl VirtioMio for VuFs { } _ => unimplemented!("unknown request {request:#x}"), } - self.vu_dev.ack_request(request, &0u64)?; + channel.reply(VuBackMsg::from(request), &0u64, &[])?; } Ok(()) } @@ -438,14 +448,15 @@ impl VirtioMio for VuFs { index: q_index as _, val: 0, }; - self.vu_dev.set_virtq_enable(&disable).unwrap(); + self.session.set_virtq_enable(&disable).unwrap(); } while let Some(fd) = self.error_fds.pop() { registry.deregister(&mut SourceFd(&fd.as_raw_fd())).unwrap(); } - if let Some(channel) = self.vu_dev.get_channel() { + if let Some(channel) = &self.channel { + let channel_fd = channel.conn.as_fd(); registry - .deregister(&mut SourceFd(&channel.as_raw_fd())) + .deregister(&mut SourceFd(&channel_fd.as_raw_fd())) .unwrap(); } } diff --git a/alioth/src/virtio/virtio.rs b/alioth/src/virtio/virtio.rs index 058fd5d9..8a70417a 100644 --- a/alioth/src/virtio/virtio.rs +++ b/alioth/src/virtio/virtio.rs @@ -21,6 +21,7 @@ pub mod queue; #[path = "vhost/vhost.rs"] pub mod vhost; #[cfg(target_os = "linux")] +#[path = "vu/vu.rs"] pub mod vu; #[path = "worker/worker.rs"] pub mod worker; diff --git a/alioth/src/virtio/vu.rs b/alioth/src/virtio/vu.rs deleted file mode 100644 index 8279ffdc..00000000 --- a/alioth/src/virtio/vu.rs +++ /dev/null @@ -1,505 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::io::{IoSlice, IoSliceMut, Read, Write}; -use std::mem::{size_of, size_of_val}; -use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd}; -use std::os::unix::net::UnixStream; -use std::path::{Path, PathBuf}; -use std::sync::Arc; - -use bitfield::bitfield; -use bitflags::bitflags; -use parking_lot::Mutex; -use snafu::{ResultExt, Snafu}; -use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes}; - -use crate::errors::{BoxTrace, DebugTrace, trace_error}; -use crate::mem::LayoutChanged; -use crate::mem::mapped::ArcMemPages; -use crate::utils::uds::{recv_msg_with_fds, send_msg_with_fds}; -use crate::{ffi, mem}; - -bitflags! 
{ - #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] - #[repr(transparent)] - pub struct VuFeature: u64 { - const MQ = 1 << 0; - const LOG_SHMFD = 1 << 1; - const RARP = 1 << 2; - const REPLY_ACK = 1 << 3; - const MTU = 1 << 4; - const BACKEND_REQ = 1 << 5; - const CROSS_ENDIAN = 1 << 6; - const CRYPTO_SESSION = 1 << 7; - const PAGEFAULT = 1 << 8; - const CONFIG = 1 << 9; - const BACKEND_SEND_FD = 1 << 10; - const HOST_NOTIFIER = 1 << 11; - const INFLIGHT_SHMFD = 1 << 12; - const RESET_DEVICE = 1 << 13; - const INBAND_NOTIFICATIONS = 1 << 14; - const CONFIGURE_MEM_SLOTS = 1 << 15; - const STATUS = 1 << 16; - const XEN_MMAP = 1 << 17; - const SHARED_OBJECT = 1 << 18; - const DEVICE_STATE = 1 << 19; - } -} - -pub const VHOST_USER_GET_FEATURES: u32 = 1; -pub const VHOST_USER_SET_FEATURES: u32 = 2; -pub const VHOST_USER_SET_OWNER: u32 = 3; -#[deprecated] -pub const VHOST_USER_RESET_OWNER: u32 = 4; -pub const VHOST_USER_SET_MEM_TABLE: u32 = 5; -pub const VHOST_USER_SET_LOG_BASE: u32 = 6; -pub const VHOST_USER_SET_LOG_FD: u32 = 7; -pub const VHOST_USER_SET_VIRTQ_NUM: u32 = 8; -pub const VHOST_USER_SET_VIRTQ_ADDR: u32 = 9; -pub const VHOST_USER_SET_VIRTQ_BASE: u32 = 10; -pub const VHOST_USER_GET_VIRTQ_BASE: u32 = 11; -pub const VHOST_USER_SET_VIRTQ_KICK: u32 = 12; -pub const VHOST_USER_SET_VIRTQ_CALL: u32 = 13; -pub const VHOST_USER_SET_VIRTQ_ERR: u32 = 14; -pub const VHOST_USER_GET_PROTOCOL_FEATURES: u32 = 15; -pub const VHOST_USER_SET_PROTOCOL_FEATURES: u32 = 16; -pub const VHOST_USER_GET_QUEUE_NUM: u32 = 17; -pub const VHOST_USER_SET_VIRTQ_ENABLE: u32 = 18; -pub const VHOST_USER_SEND_RARP: u32 = 19; -pub const VHOST_USER_NET_SET_MTU: u32 = 20; -pub const VHOST_USER_SET_BACKEND_REQ_FD: u32 = 21; -pub const VHOST_USER_IOTLB_MSG: u32 = 22; -pub const VHOST_USER_SET_VIRTQ_ENDIAN: u32 = 23; -pub const VHOST_USER_GET_CONFIG: u32 = 24; -pub const VHOST_USER_SET_CONFIG: u32 = 25; -pub const VHOST_USER_CREATE_CRYPTO_SESSION: u32 = 26; -pub const 
VHOST_USER_CLOSE_CRYPTO_SESSION: u32 = 27; -pub const VHOST_USER_POSTCOPY_ADVISE: u32 = 28; -pub const VHOST_USER_POSTCOPY_LISTEN: u32 = 29; -pub const VHOST_USER_POSTCOPY_END: u32 = 30; -pub const VHOST_USER_GET_INFLIGHT_FD: u32 = 31; -pub const VHOST_USER_SET_INFLIGHT_FD: u32 = 32; -pub const VHOST_USER_GPU_SET_SOCKET: u32 = 33; -pub const VHOST_USER_RESET_DEVICE: u32 = 34; -pub const VHOST_USER_GET_MAX_MEM_SLOTS: u32 = 36; -pub const VHOST_USER_ADD_MEM_REG: u32 = 37; -pub const VHOST_USER_REM_MEM_REG: u32 = 38; -pub const VHOST_USER_SET_STATUS: u32 = 39; -pub const VHOST_USER_GET_STATUS: u32 = 40; -pub const VHOST_USER_GET_SHARED_OBJECT: u32 = 41; -pub const VHOST_USER_SET_DEVICE_STATE_FD: u32 = 42; -pub const VHOST_USER_CHECK_DEVICE_STATE: u32 = 43; - -bitfield! { - #[derive(Copy, Clone, Default, IntoBytes, FromBytes, Immutable)] - #[repr(transparent)] - pub struct MessageFlag(u32); - impl Debug; - need_reply, set_need_reply: 3; - reply, set_reply: 2; - version, set_version: 1, 0; -} - -impl MessageFlag { - pub const NEED_REPLY: u32 = 1 << 3; - pub const REPLY: u32 = 1 << 2; - pub const VERSION_1: u32 = 0x1; - - pub const fn sender() -> Self { - MessageFlag(MessageFlag::VERSION_1 | MessageFlag::NEED_REPLY) - } - - pub const fn receiver() -> Self { - MessageFlag(MessageFlag::VERSION_1 | MessageFlag::REPLY) - } -} - -#[derive(Debug, IntoBytes, FromBytes, Immutable)] -#[repr(C)] -pub struct VirtqState { - pub index: u32, - pub val: u32, -} - -#[derive(Debug, IntoBytes, FromBytes, Immutable)] -#[repr(C)] -pub struct VirtqAddr { - pub index: u32, - pub flags: u32, - pub desc_hva: u64, - pub used_hva: u64, - pub avail_hva: u64, - pub log_guest_addr: u64, -} - -#[derive(Debug, IntoBytes, FromBytes, Immutable)] -#[repr(C)] -pub struct MemoryRegion { - pub gpa: u64, - pub size: u64, - pub hva: u64, - pub mmap_offset: u64, -} - -#[derive(Debug, IntoBytes, FromBytes, Immutable)] -#[repr(C)] -pub struct MemorySingleRegion { - pub _padding: u64, - pub region: MemoryRegion, 
-} - -#[derive(Debug, IntoBytes, FromBytes, Immutable)] -#[repr(C)] -pub struct MemoryMultipleRegion { - pub num: u32, - pub _padding: u32, - pub regions: [MemoryRegion; 8], -} - -#[derive(Debug, IntoBytes, FromBytes, Immutable)] -#[repr(C)] -pub struct DeviceConfig { - pub offset: u32, - pub size: u32, - pub flags: u32, - pub region: [u8; 256], -} - -#[derive(Debug, IntoBytes, FromBytes, Immutable)] -#[repr(C)] -pub struct Message { - pub request: u32, - pub flag: MessageFlag, - pub size: u32, -} - -#[trace_error] -#[derive(Snafu, DebugTrace)] -#[snafu(module, visibility(pub(crate)), context(suffix(false)))] -pub enum Error { - #[snafu(display("Cannot access socket {path:?}"))] - AccessSocket { - path: PathBuf, - error: std::io::Error, - }, - #[snafu(display("Error from OS"), context(false))] - System { error: std::io::Error }, - #[snafu(display("Invalid vhost-user response message, want {want}, got {got}"))] - InvalidResp { want: u32, got: u32 }, - #[snafu(display("Invalid vhost-user message size, want {want}, get {got}"))] - MsgSize { want: usize, got: usize }, - #[snafu(display("Invalid vhost-user message payload size, want {want}, got {got}"))] - PayloadSize { want: usize, got: u32 }, - #[snafu(display("vhost-user backend replied error code {ret:#x} to request {req:#x}"))] - RequestErr { ret: u64, req: u32 }, - #[snafu(display("vhost-user backend signaled an error of queue {index:#x}"))] - QueueErr { index: u16 }, - #[snafu(display("vhost-user backend is missing device feature {feature:#x}"))] - DeviceFeature { feature: u64 }, - #[snafu(display("vhost-user backend is missing protocol feature {feature:x?}"))] - ProtocolFeature { feature: VuFeature }, -} - -type Result = std::result::Result; - -#[derive(Debug)] -pub struct VuDev { - conn: Mutex, - channel: Option, -} - -impl VuDev { - pub fn new>(sock: P) -> Result { - let conn = UnixStream::connect(&sock).context(error::AccessSocket { - path: sock.as_ref(), - })?; - Ok(VuDev { - conn: Mutex::new(conn), - 
channel: None, - }) - } - - pub fn setup_channel(&mut self) -> Result<()> { - if self.channel.is_some() { - return Ok(()); - } - let mut socket_fds = [0; 2]; - ffi!(unsafe { - libc::socketpair(libc::PF_UNIX, libc::SOCK_STREAM, 0, socket_fds.as_mut_ptr()) - })?; - let channel = unsafe { UnixStream::from_raw_fd(socket_fds[0]) }; - let peer = unsafe { OwnedFd::from_raw_fd(socket_fds[1]) }; - self.set_backend_req_fd(peer.as_fd())?; - self.channel = Some(channel); - Ok(()) - } - - pub fn get_channel(&self) -> Option<&UnixStream> { - self.channel.as_ref() - } - - fn send_msg( - &self, - req: u32, - payload: &T, - fds: &[BorrowedFd], - ) -> Result { - let vhost_msg = Message { - request: req, - flag: MessageFlag::sender(), - size: size_of::() as u32, - }; - let bufs = [ - IoSlice::new(vhost_msg.as_bytes()), - IoSlice::new(payload.as_bytes()), - ]; - let mut conn = self.conn.lock(); - send_msg_with_fds(&conn, &bufs, fds)?; - - let mut resp = Message::new_zeroed(); - let mut payload = R::new_zeroed(); - let mut ret_code = u64::MAX; - let mut bufs = if size_of::() == 0 { - [ - IoSliceMut::new(resp.as_mut_bytes()), - IoSliceMut::new(ret_code.as_mut_bytes()), - ] - } else { - [ - IoSliceMut::new(resp.as_mut_bytes()), - IoSliceMut::new(payload.as_mut_bytes()), - ] - }; - let read_size = conn.read_vectored(&mut bufs)?; - let expect_size = size_of::() + bufs[1].len(); - if read_size != expect_size { - return error::MsgSize { - want: expect_size, - got: read_size, - } - .fail(); - } - if resp.request != req { - return error::InvalidResp { - want: req, - got: resp.request, - } - .fail(); - } - if size_of::() != 0 { - if resp.size != size_of::() as u32 { - return error::PayloadSize { - want: size_of::(), - got: resp.size, - } - .fail(); - } - } else { - if resp.size != size_of::() as u32 { - return error::PayloadSize { - want: size_of::(), - got: resp.size, - } - .fail(); - } - if ret_code != 0 { - return error::RequestErr { ret: ret_code, req }.fail(); - } - } - Ok(payload) - } - - 
pub fn get_features(&self) -> Result { - self.send_msg(VHOST_USER_GET_FEATURES, &(), &[]) - } - - pub fn set_features(&self, payload: &u64) -> Result<()> { - self.send_msg(VHOST_USER_SET_FEATURES, payload, &[]) - } - - pub fn get_protocol_features(&self) -> Result { - self.send_msg(VHOST_USER_GET_PROTOCOL_FEATURES, &(), &[]) - } - - pub fn set_protocol_features(&self, payload: &u64) -> Result { - self.send_msg(VHOST_USER_SET_PROTOCOL_FEATURES, payload, &[]) - } - - pub fn set_owner(&self) -> Result<()> { - self.send_msg(VHOST_USER_SET_OWNER, &(), &[]) - } - - pub fn set_virtq_num(&self, payload: &VirtqState) -> Result<()> { - self.send_msg(VHOST_USER_SET_VIRTQ_NUM, payload, &[]) - } - - pub fn set_virtq_addr(&self, payload: &VirtqAddr) -> Result<()> { - self.send_msg(VHOST_USER_SET_VIRTQ_ADDR, payload, &[]) - } - - pub fn set_virtq_base(&self, payload: &VirtqState) -> Result<()> { - self.send_msg(VHOST_USER_SET_VIRTQ_BASE, payload, &[]) - } - - pub fn get_config(&self, payload: &DeviceConfig) -> Result { - self.send_msg(VHOST_USER_GET_CONFIG, payload, &[]) - } - - pub fn set_config(&self, payload: &DeviceConfig) -> Result<()> { - self.send_msg(VHOST_USER_SET_CONFIG, payload, &[]) - } - - pub fn get_virtq_base(&self, payload: &VirtqState) -> Result { - self.send_msg(VHOST_USER_GET_VIRTQ_BASE, payload, &[]) - } - - pub fn get_queue_num(&self) -> Result { - self.send_msg(VHOST_USER_GET_QUEUE_NUM, &(), &[]) - } - - pub fn set_virtq_kick(&self, payload: &u64, fd: BorrowedFd) -> Result<()> { - self.send_msg(VHOST_USER_SET_VIRTQ_KICK, payload, &[fd]) - } - - pub fn set_virtq_call(&self, payload: &u64, fd: BorrowedFd) -> Result<()> { - self.send_msg(VHOST_USER_SET_VIRTQ_CALL, payload, &[fd]) - } - - pub fn set_virtq_err(&self, payload: &u64, fd: BorrowedFd) -> Result<()> { - self.send_msg(VHOST_USER_SET_VIRTQ_ERR, payload, &[fd]) - } - - pub fn set_virtq_enable(&self, payload: &VirtqState) -> Result<()> { - self.send_msg(VHOST_USER_SET_VIRTQ_ENABLE, payload, &[]) - } - - 
pub fn set_status(&self, payload: &u64) -> Result<()> { - self.send_msg(VHOST_USER_SET_STATUS, payload, &[]) - } - - pub fn get_status(&self) -> Result { - self.send_msg(VHOST_USER_GET_STATUS, &(), &[]) - } - - pub fn add_mem_region(&self, payload: &MemorySingleRegion, fd: BorrowedFd) -> Result<()> { - self.send_msg(VHOST_USER_ADD_MEM_REG, payload, &[fd]) - } - - pub fn remove_mem_region(&self, payload: &MemorySingleRegion) -> Result<()> { - self.send_msg(VHOST_USER_REM_MEM_REG, payload, &[]) - } - - fn set_backend_req_fd(&self, fd: BorrowedFd) -> Result<()> { - self.send_msg(VHOST_USER_SET_BACKEND_REQ_FD, &0u64, &[fd]) - } - - pub fn receive_from_channel( - &self, - buf: &mut [u8], - fds: &mut [Option], - ) -> Result<(u32, u32)> { - let mut msg = Message::new_zeroed(); - let mut bufs = [IoSliceMut::new(msg.as_mut_bytes()), IoSliceMut::new(buf)]; - let Some(channel) = &self.channel else { - return error::ProtocolFeature { - feature: VuFeature::BACKEND_REQ, - } - .fail(); - }; - let r_size = recv_msg_with_fds(channel, &mut bufs, fds)?; - let expected_size = size_of::() + msg.size as usize; - if r_size != expected_size { - return error::MsgSize { - want: expected_size, - got: r_size, - } - .fail(); - } - Ok((msg.request, msg.size)) - } - - pub fn ack_request(&self, req: u32, payload: &T) -> Result<()> { - let Some(channel) = &self.channel else { - return error::ProtocolFeature { - feature: VuFeature::BACKEND_REQ, - } - .fail(); - }; - let msg = Message { - request: req, - flag: MessageFlag::receiver(), - size: size_of_val(payload) as _, - }; - let bufs = [ - IoSlice::new(msg.as_bytes()), - IoSlice::new(payload.as_bytes()), - ]; - Write::write_vectored(&mut (&*channel), &bufs)?; - Ok(()) - } -} - -#[derive(Debug)] -pub struct UpdateVuMem { - pub dev: Arc, -} - -impl LayoutChanged for UpdateVuMem { - fn ram_added(&self, gpa: u64, pages: &ArcMemPages) -> mem::Result<()> { - let Some((fd, offset)) = pages.fd() else { - return Ok(()); - }; - let region = 
MemorySingleRegion { - _padding: 0, - region: MemoryRegion { - gpa: gpa as _, - size: pages.size() as _, - hva: pages.addr() as _, - mmap_offset: offset, - }, - }; - let ret = self.dev.add_mem_region(®ion, fd); - ret.box_trace(mem::error::ChangeLayout)?; - log::trace!( - "vu-{}: added memory region {:x?}", - self.dev.conn.lock().as_raw_fd(), - region.region - ); - Ok(()) - } - - fn ram_removed(&self, gpa: u64, pages: &ArcMemPages) -> mem::Result<()> { - let Some((_, offset)) = pages.fd() else { - return Ok(()); - }; - let region = MemorySingleRegion { - _padding: 0, - region: MemoryRegion { - gpa: gpa as _, - size: pages.size() as _, - hva: pages.addr() as _, - mmap_offset: offset, - }, - }; - let ret = self.dev.remove_mem_region(®ion); - ret.box_trace(mem::error::ChangeLayout)?; - log::trace!( - "vu-{}: removed memory region {:x?}", - self.dev.conn.lock().as_raw_fd(), - region.region - ); - Ok(()) - } -} diff --git a/alioth/src/virtio/vu/bindings.rs b/alioth/src/virtio/vu/bindings.rs new file mode 100644 index 00000000..e603b920 --- /dev/null +++ b/alioth/src/virtio/vu/bindings.rs @@ -0,0 +1,207 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bitfield::bitfield; +use bitflags::bitflags; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; + +use crate::c_enum; + +bitflags! 
{ + #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] + #[repr(transparent)] + pub struct VuFeature: u64 { + const MQ = 1 << 0; + const LOG_SHMFD = 1 << 1; + const RARP = 1 << 2; + const REPLY_ACK = 1 << 3; + const MTU = 1 << 4; + const BACKEND_REQ = 1 << 5; + const CROSS_ENDIAN = 1 << 6; + const CRYPTO_SESSION = 1 << 7; + const PAGEFAULT = 1 << 8; + const CONFIG = 1 << 9; + const BACKEND_SEND_FD = 1 << 10; + const HOST_NOTIFIER = 1 << 11; + const INFLIGHT_SHMFD = 1 << 12; + const RESET_DEVICE = 1 << 13; + const INBAND_NOTIFICATIONS = 1 << 14; + const CONFIGURE_MEM_SLOTS = 1 << 15; + const STATUS = 1 << 16; + const XEN_MMAP = 1 << 17; + const SHARED_OBJECT = 1 << 18; + const DEVICE_STATE = 1 << 19; + } +} + +c_enum! { + pub struct VuFrontMsg(u32); + { + GET_FEATURES = 1; + SET_FEATURES = 2; + SET_OWNER = 3; + RESET_OWNER = 4; + SET_MEM_TABLE = 5; + SET_LOG_BASE = 6; + SET_LOG_FD = 7; + SET_VIRTQ_NUM = 8; + SET_VIRTQ_ADDR = 9; + SET_VIRTQ_BASE = 10; + GET_VIRTQ_BASE = 11; + SET_VIRTQ_KICK = 12; + SET_VIRTQ_CALL = 13; + SET_VIRTQ_ERR = 14; + GET_PROTOCOL_FEATURES = 15; + SET_PROTOCOL_FEATURES = 16; + GET_QUEUE_NUM = 17; + SET_VIRTQ_ENABLE = 18; + SEND_RARP = 19; + NET_SET_MTU = 20; + SET_BACKEND_REQ_FD = 21; + IOTLB_MSG = 22; + SET_VIRTQ_ENDIAN = 23; + GET_CONFIG = 24; + SET_CONFIG = 25; + CREATE_CRYPTO_SESSION = 26; + CLOSE_CRYPTO_SESSION = 27; + POSTCOPY_ADVISE = 28; + POSTCOPY_LISTEN = 29; + POSTCOPY_END = 30; + GET_INFLIGHT_FD = 31; + SET_INFLIGHT_FD = 32; + GPU_SET_SOCKET = 33; + RESET_DEVICE = 34; + GET_MAX_MEM_SLOTS = 36; + ADD_MEM_REG = 37; + REM_MEM_REG = 38; + SET_STATUS = 39; + GET_STATUS = 40; + GET_SHARED_OBJECT = 41; + SET_DEVICE_STATE_FD = 42; + CHECK_DEVICE_STATE = 43; + } +} + +c_enum! { + pub struct VuFrontMsgSize((u32, usize)); + { + GET_FEATURES = (0, size_of::()); + } +} + +c_enum! 
{ + pub struct VuBackMsg(u32); + { + IOTLB_MSG = 1; + CONFIG_CHANGE_MSG = 2; + VIRTQ_HOST_NOTIFIER_MSG = 3; + VIRTQ_CALL = 4; + VIRTQ_ERR = 5; + SHARED_OBJECT_ADD = 6; + SHARED_OBJECT_REMOVE = 7; + SHARED_OBJECT_LOOKUP = 8; + } +} + +bitfield! { + #[derive(Copy, Clone, Default, IntoBytes, FromBytes, Immutable)] + #[repr(transparent)] + pub struct MessageFlag(u32); + impl Debug; + pub need_reply, set_need_reply: 3; + pub reply, set_reply: 2; + pub version, set_version: 1, 0; +} + +impl MessageFlag { + pub const NEED_REPLY: u32 = 1 << 3; + pub const REPLY: u32 = 1 << 2; + pub const VERSION_1: u32 = 0x1; + + pub const fn sender() -> Self { + MessageFlag(MessageFlag::VERSION_1 | MessageFlag::NEED_REPLY) + } + + pub const fn receiver() -> Self { + MessageFlag(MessageFlag::VERSION_1 | MessageFlag::REPLY) + } +} + +#[derive(Debug, IntoBytes, FromBytes, Immutable, KnownLayout)] +#[repr(C)] +pub struct VirtqState { + pub index: u32, + pub val: u32, +} + +#[derive(Debug, Clone, IntoBytes, FromBytes, Immutable, KnownLayout)] +#[repr(C)] +pub struct VirtqAddr { + pub index: u32, + pub flags: u32, + pub desc_hva: u64, + pub used_hva: u64, + pub avail_hva: u64, + pub log_guest_addr: u64, +} + +#[derive(Debug, Clone, IntoBytes, FromBytes, Immutable, KnownLayout)] +#[repr(C)] +pub struct MemoryRegion { + pub gpa: u64, + pub size: u64, + pub hva: u64, + pub mmap_offset: u64, +} + +#[derive(Debug, IntoBytes, FromBytes, Immutable, KnownLayout)] +#[repr(C)] +pub struct MemorySingleRegion { + pub _padding: u64, + pub region: MemoryRegion, +} + +#[derive(Debug, IntoBytes, FromBytes, Immutable, KnownLayout)] +#[repr(C)] +pub struct MemoryMultipleRegion { + pub num: u32, + pub _padding: u32, + pub regions: [MemoryRegion; 8], +} + +#[derive(Debug, IntoBytes, FromBytes, Immutable, KnownLayout)] +#[repr(C)] +pub struct DeviceConfig { + pub offset: u32, + pub size: u32, + pub flags: u32, + pub region: [u8; 256], +} + +#[derive(Debug, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)] 
+#[repr(C)] +pub struct FsMap { + pub fd_offset: [u64; 8], + pub cache_offset: [u64; 8], + pub len: [u64; 8], + pub flags: [u64; 8], +} + +#[derive(Debug, IntoBytes, FromBytes, Immutable, KnownLayout)] +#[repr(C)] +pub struct Message { + pub request: u32, + pub flag: MessageFlag, + pub size: u32, +} diff --git a/alioth/src/virtio/vu/conn.rs b/alioth/src/virtio/vu/conn.rs new file mode 100644 index 00000000..a9399830 --- /dev/null +++ b/alioth/src/virtio/vu/conn.rs @@ -0,0 +1,295 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::io::{IoSlice, IoSliceMut, Read}; +use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd}; +use std::os::unix::net::UnixStream; +use std::path::Path; + +use snafu::ResultExt; +use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes}; + +use crate::ffi; +use crate::utils::uds::{recv_msg_with_fds, send_msg_with_fds}; +use crate::virtio::vu::bindings::{ + DeviceConfig, MemorySingleRegion, Message, MessageFlag, VirtqAddr, VirtqState, VuBackMsg, + VuFrontMsg, +}; +use crate::virtio::vu::{Result, error}; + +fn send(mut conn: &UnixStream, req: u32, payload: &T, fds: &[BorrowedFd]) -> Result +where + T: IntoBytes + Immutable, + R: FromBytes + IntoBytes, +{ + let vhost_msg = Message { + request: req, + flag: MessageFlag::sender(), + size: size_of::() as u32, + }; + let bufs = [ + IoSlice::new(vhost_msg.as_bytes()), + IoSlice::new(payload.as_bytes()), + ]; + let done = send_msg_with_fds(conn, &bufs, fds)?; + let want = size_of_val(&vhost_msg) + vhost_msg.size as usize; + if done != want { + return error::PartialWrite { done, want }.fail(); + } + + let mut resp = Message::new_zeroed(); + let mut payload = R::new_zeroed(); + let mut ret_code = u64::MAX; + let mut bufs = [ + IoSliceMut::new(resp.as_mut_bytes()), + if size_of::() > 0 { + IoSliceMut::new(payload.as_mut_bytes()) + } else { + IoSliceMut::new(ret_code.as_mut_bytes()) + }, + ]; + let resp_size = bufs[1].len() as u32; + let expect_size = size_of::() + bufs[1].len(); + + let size = conn.read_vectored(&mut bufs)?; + if size != expect_size { + return error::MsgSize { + want: expect_size, + got: size, + } + .fail(); + } + if resp.request != req { + return error::Response { + want: req, + got: resp.request, + } + .fail(); + } + if resp.size != resp_size { + return error::PayloadSize { + want: size_of::(), + got: resp.size, + } + .fail(); + } + if size_of::() == 0 && ret_code != 0 { + return error::RequestErr { ret: ret_code, req }.fail(); + } + + Ok(payload) +} + +fn reply(conn: &UnixStream, req: u32, 
payload: &T, fds: &[BorrowedFd]) -> Result<()> +where + T: IntoBytes + Immutable, +{ + let msg = Message { + request: req, + flag: MessageFlag::receiver(), + size: size_of_val(payload) as _, + }; + let bufs = [ + IoSlice::new(msg.as_bytes()), + IoSlice::new(payload.as_bytes()), + ]; + let done = send_msg_with_fds(conn, &bufs, fds)?; + let want = size_of_val(&msg) + size_of_val(payload); + if done != want { + return error::PartialWrite { want, done }.fail(); + } + Ok(()) +} + +fn recv_with_fds(conn: &UnixStream, fds: &mut [Option]) -> Result +where + T: IntoBytes + Immutable + FromBytes, +{ + let mut msg = T::new_zeroed(); + let mut bufs = [IoSliceMut::new(msg.as_mut_bytes())]; + let size = recv_msg_with_fds(conn, &mut bufs, fds)?; + if size != size_of::() { + error::MsgSize { + want: size_of::(), + got: size, + } + .fail() + } else { + Ok(msg) + } +} + +#[derive(Debug)] +pub struct VuSession { + pub conn: UnixStream, +} + +impl VuSession { + pub fn new>(path: P) -> Result { + let conn = UnixStream::connect(&path).context(error::AccessSocket { + path: path.as_ref(), + })?; + Ok(VuSession { conn }) + } + + fn send(&self, req: VuFrontMsg, payload: &T, fds: &[BorrowedFd]) -> Result + where + T: IntoBytes + Immutable, + R: FromBytes + IntoBytes, + { + send(&self.conn, req.raw(), payload, fds) + } + + pub fn recv_payload(&self) -> Result + where + T: IntoBytes + Immutable + FromBytes, + { + recv_with_fds(&self.conn, &mut []) + } + + pub fn recv_msg(&self, fds: &mut [Option]) -> Result { + recv_with_fds(&self.conn, fds) + } + + pub fn reply( + &self, + req: VuFrontMsg, + payload: &T, + fds: &[BorrowedFd], + ) -> Result<()> { + reply(&self.conn, req.raw(), payload, fds) + } + + pub fn get_features(&self) -> Result { + self.send(VuFrontMsg::GET_FEATURES, &(), &[]) + } + + pub fn set_features(&self, payload: &u64) -> Result<()> { + self.send(VuFrontMsg::SET_FEATURES, payload, &[]) + } + + pub fn get_protocol_features(&self) -> Result { + 
self.send(VuFrontMsg::GET_PROTOCOL_FEATURES, &(), &[]) + } + + pub fn set_protocol_features(&self, payload: &u64) -> Result { + self.send(VuFrontMsg::SET_PROTOCOL_FEATURES, payload, &[]) + } + + pub fn set_owner(&self) -> Result<()> { + self.send(VuFrontMsg::SET_OWNER, &(), &[]) + } + + pub fn set_virtq_num(&self, payload: &VirtqState) -> Result<()> { + self.send(VuFrontMsg::SET_VIRTQ_NUM, payload, &[]) + } + + pub fn set_virtq_addr(&self, payload: &VirtqAddr) -> Result<()> { + self.send(VuFrontMsg::SET_VIRTQ_ADDR, payload, &[]) + } + + pub fn set_virtq_base(&self, payload: &VirtqState) -> Result<()> { + self.send(VuFrontMsg::SET_VIRTQ_BASE, payload, &[]) + } + + pub fn get_config(&self, payload: &DeviceConfig) -> Result { + self.send(VuFrontMsg::GET_CONFIG, payload, &[]) + } + + pub fn set_config(&self, payload: &DeviceConfig) -> Result<()> { + self.send(VuFrontMsg::SET_CONFIG, payload, &[]) + } + + pub fn get_virtq_base(&self, payload: &VirtqState) -> Result { + self.send(VuFrontMsg::GET_VIRTQ_BASE, payload, &[]) + } + + pub fn get_queue_num(&self) -> Result { + self.send(VuFrontMsg::GET_QUEUE_NUM, &(), &[]) + } + + pub fn set_virtq_kick(&self, payload: &u64, fd: BorrowedFd) -> Result<()> { + self.send(VuFrontMsg::SET_VIRTQ_KICK, payload, &[fd]) + } + + pub fn set_virtq_call(&self, payload: &u64, fd: BorrowedFd) -> Result<()> { + self.send(VuFrontMsg::SET_VIRTQ_CALL, payload, &[fd]) + } + + pub fn set_virtq_err(&self, payload: &u64, fd: BorrowedFd) -> Result<()> { + self.send(VuFrontMsg::SET_VIRTQ_ERR, payload, &[fd]) + } + + pub fn set_virtq_enable(&self, payload: &VirtqState) -> Result<()> { + self.send(VuFrontMsg::SET_VIRTQ_ENABLE, payload, &[]) + } + + pub fn set_status(&self, payload: &u64) -> Result<()> { + self.send(VuFrontMsg::SET_STATUS, payload, &[]) + } + + pub fn get_status(&self) -> Result { + self.send(VuFrontMsg::GET_STATUS, &(), &[]) + } + + pub fn add_mem_region(&self, payload: &MemorySingleRegion, fd: BorrowedFd) -> Result<()> { + 
self.send(VuFrontMsg::ADD_MEM_REG, payload, &[fd]) + } + + pub fn remove_mem_region(&self, payload: &MemorySingleRegion) -> Result<()> { + self.send(VuFrontMsg::REM_MEM_REG, payload, &[]) + } + + fn set_backend_req_fd(&self, fd: BorrowedFd) -> Result<()> { + self.send(VuFrontMsg::SET_BACKEND_REQ_FD, &(), &[fd]) + } + + pub fn create_channel(&self) -> Result { + let mut socket_fds = [0; 2]; + ffi!(unsafe { + libc::socketpair(libc::PF_UNIX, libc::SOCK_STREAM, 0, socket_fds.as_mut_ptr()) + })?; + let channel = unsafe { UnixStream::from_raw_fd(socket_fds[0]) }; + let peer = unsafe { OwnedFd::from_raw_fd(socket_fds[1]) }; + self.set_backend_req_fd(peer.as_fd())?; + Ok(VuChannel { conn: channel }) + } +} + +#[derive(Debug)] +pub struct VuChannel { + pub conn: UnixStream, +} + +impl VuChannel { + pub fn recv_payload(&self) -> Result + where + T: IntoBytes + Immutable + FromBytes, + { + recv_with_fds(&self.conn, &mut []) + } + + pub fn recv_msg(&self, fds: &mut [Option]) -> Result { + recv_with_fds(&self.conn, fds) + } + + pub fn reply( + &self, + req: VuBackMsg, + payload: &T, + fds: &[BorrowedFd], + ) -> Result<()> { + reply(&self.conn, req.raw(), payload, fds) + } +} diff --git a/alioth/src/virtio/vu/frontend.rs b/alioth/src/virtio/vu/frontend.rs new file mode 100644 index 00000000..fa53d515 --- /dev/null +++ b/alioth/src/virtio/vu/frontend.rs @@ -0,0 +1,75 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::sync::Arc; + +use bitflags::bitflags; + +use crate::errors::BoxTrace; +use crate::mem; +use crate::mem::LayoutChanged; +use crate::mem::mapped::ArcMemPages; +use crate::virtio::vu::bindings::{MemoryRegion, MemorySingleRegion}; +use crate::virtio::vu::conn::VuSession; + +bitflags! { + #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] + pub struct VuDevFeature: u64 { } +} + +#[derive(Debug)] +pub struct UpdateVuMem { + pub name: Arc, + pub session: Arc, +} + +impl LayoutChanged for UpdateVuMem { + fn ram_added(&self, gpa: u64, pages: &ArcMemPages) -> mem::Result<()> { + let Some((fd, offset)) = pages.fd() else { + return Ok(()); + }; + let region = MemorySingleRegion { + _padding: 0, + region: MemoryRegion { + gpa: gpa as _, + size: pages.size() as _, + hva: pages.addr() as _, + mmap_offset: offset, + }, + }; + let ret = self.session.add_mem_region(®ion, fd); + ret.box_trace(mem::error::ChangeLayout)?; + log::trace!("{}: add memory region: {:x?}", self.name, region.region); + Ok(()) + } + + fn ram_removed(&self, gpa: u64, pages: &ArcMemPages) -> mem::Result<()> { + let Some((_, offset)) = pages.fd() else { + return Ok(()); + }; + let region = MemorySingleRegion { + _padding: 0, + region: MemoryRegion { + gpa: gpa as _, + size: pages.size() as _, + hva: pages.addr() as _, + mmap_offset: offset, + }, + }; + let ret = self.session.remove_mem_region(®ion); + ret.box_trace(mem::error::ChangeLayout)?; + log::trace!("{}: remove memory region: {:x?}", self.name, region.region); + Ok(()) + } +} diff --git a/alioth/src/virtio/vu/vu.rs b/alioth/src/virtio/vu/vu.rs new file mode 100644 index 00000000..32ca61a7 --- /dev/null +++ b/alioth/src/virtio/vu/vu.rs @@ -0,0 +1,57 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod bindings; +pub mod conn; +pub mod frontend; + +use std::path::PathBuf; + +use snafu::Snafu; + +use crate::errors::{DebugTrace, trace_error}; +use crate::virtio::vu::bindings::VuFeature; + +#[trace_error] +#[derive(Snafu, DebugTrace)] +#[snafu(module, visibility(pub(crate)), context(suffix(false)))] +pub enum Error { + #[snafu(display("Cannot access socket {path:?}"))] + AccessSocket { + path: PathBuf, + error: std::io::Error, + }, + #[snafu(display("Error from OS"), context(false))] + System { error: std::io::Error }, + #[snafu(display("vhost-user message ({req:#x}) missing fd"))] + MissingFd { req: u32 }, + #[snafu(display("Unexpected vhost-user response, want {want}, got {got}"))] + Response { want: u32, got: u32 }, + #[snafu(display("Unexpected vhost-user message size, want {want}, get {got}"))] + MsgSize { want: usize, got: usize }, + #[snafu(display("Failed to send {want} bytes, only {done} bytes were sent"))] + PartialWrite { want: usize, done: usize }, + #[snafu(display("Invalid vhost-user message payload size, want {want}, got {got}"))] + PayloadSize { want: usize, got: u32 }, + #[snafu(display("vhost-user backend replied error code {ret:#x} to request {req:#x}"))] + RequestErr { ret: u64, req: u32 }, + #[snafu(display("vhost-user backend signaled an error of queue {index:#x}"))] + QueueErr { index: u16 }, + #[snafu(display("vhost-user backend is missing device feature {feature:#x}"))] + DeviceFeature { feature: u64 }, + #[snafu(display("vhost-user backend is missing protocol feature {feature:x?}"))] + ProtocolFeature { feature: 
VuFeature }, +} + +type Result = std::result::Result; From 0213aa314ad71724d5d90b381ebee8eb73482986 Mon Sep 17 00:00:00 2001 From: Changyuan Lyu Date: Sun, 20 Apr 2025 23:56:12 -0700 Subject: [PATCH 05/10] feat(vu): add a general vhost-user frontend device Signed-off-by: Changyuan Lyu --- alioth/src/virtio/dev/fs.rs | 205 ++++------------- alioth/src/virtio/vu/frontend.rs | 381 ++++++++++++++++++++++++++++++- 2 files changed, 418 insertions(+), 168 deletions(-) diff --git a/alioth/src/virtio/dev/fs.rs b/alioth/src/virtio/dev/fs.rs index 3df4a5e1..dce140d8 100644 --- a/alioth/src/virtio/dev/fs.rs +++ b/alioth/src/virtio/dev/fs.rs @@ -15,18 +15,14 @@ use std::io::ErrorKind; use std::iter::zip; use std::mem::size_of_val; -use std::os::fd::{AsFd, AsRawFd, FromRawFd, OwnedFd}; +use std::os::fd::AsRawFd; use std::path::PathBuf; use std::sync::Arc; -use std::sync::atomic::Ordering; use std::sync::mpsc::Receiver; use std::thread::JoinHandle; use bitflags::bitflags; -use libc::{ - EFD_CLOEXEC, EFD_NONBLOCK, MAP_ANONYMOUS, MAP_FAILED, MAP_FIXED, MAP_PRIVATE, MAP_SHARED, - PROT_NONE, eventfd, mmap, -}; +use libc::{MAP_ANONYMOUS, MAP_FAILED, MAP_FIXED, MAP_PRIVATE, MAP_SHARED, PROT_NONE, mmap}; use mio::event::Event; use mio::unix::SourceFd; use mio::{Interest, Registry, Token}; @@ -39,13 +35,12 @@ use crate::mem::mapped::{ArcMemPages, RamBus}; use crate::mem::{LayoutChanged, MemRegion, MemRegionType}; use crate::virtio::dev::{DevParam, Virtio, WakeEvent}; use crate::virtio::queue::{Queue, VirtQueue}; -use crate::virtio::vu::bindings::{DeviceConfig, VirtqAddr, VirtqState, VuBackMsg, VuFeature}; -use crate::virtio::vu::conn::{VuChannel, VuSession}; -use crate::virtio::vu::frontend::UpdateVuMem; +use crate::virtio::vu::bindings::{DeviceConfig, VuBackMsg, VuFeature}; +use crate::virtio::vu::frontend::VuFrontend; use crate::virtio::vu::{Error, error as vu_error}; use crate::virtio::worker::Waker; use crate::virtio::worker::mio::{ActiveMio, Mio, VirtioMio}; -use 
crate::virtio::{DeviceId, IrqSender, Result, VirtioFeature, error}; +use crate::virtio::{DeviceId, IrqSender, Result}; use crate::{align_up, ffi, impl_mmio_for_zerocopy}; #[repr(C, align(4))] @@ -79,83 +74,50 @@ const VHOST_USER_BACKEND_FS_UNMAP: u32 = 7; #[derive(Debug)] pub struct VuFs { - name: Arc, - session: Arc, - channel: Option, + frontend: VuFrontend, config: Arc, - feature: u64, - num_queues: u16, dax_region: Option, - error_fds: Vec, } impl VuFs { pub fn new(param: VuFsParam, name: impl Into>) -> Result { - let name = name.into(); - let session = Arc::new(VuSession::new(param.socket)?); - let dev_feat = session.get_features()?; - let virtio_feat = VirtioFeature::from_bits_retain(dev_feat); - let need_feat = VirtioFeature::VHOST_PROTOCOL | VirtioFeature::VERSION_1; - if !virtio_feat.contains(need_feat) { - return vu_error::DeviceFeature { - feature: need_feat.bits(), - } - .fail()?; - } - - let prot_feat = VuFeature::from_bits_retain(session.get_protocol_features()?); - log::debug!("{name}: vhost-user feat: {prot_feat:x?}"); - let mut need_feat = VuFeature::MQ | VuFeature::REPLY_ACK | VuFeature::CONFIGURE_MEM_SLOTS; - if param.tag.is_none() { - need_feat |= VuFeature::CONFIG; - } + let mut extra_features = VuFeature::empty(); if param.dax_window > 0 { - assert!(param.dax_window.count_ones() == 1 && param.dax_window > (4 << 10)); - need_feat |= VuFeature::BACKEND_REQ | VuFeature::BACKEND_SEND_FD; - } - if !prot_feat.contains(need_feat) { - return vu_error::ProtocolFeature { - feature: need_feat & !prot_feat, - } - .fail()?; + extra_features |= VuFeature::BACKEND_REQ | VuFeature::BACKEND_SEND_FD + }; + if param.tag.is_none() { + extra_features |= VuFeature::CONFIG; } - session.set_protocol_features(&need_feat.bits())?; - - session.set_owner()?; - let num_queues = session.get_queue_num()? 
as u16; + let mut frontend = + VuFrontend::new(name, ¶m.socket, DeviceId::FileSystem, extra_features)?; let config = if let Some(tag) = param.tag { assert!(tag.len() <= 36); assert_ne!(tag.len(), 0); let mut config = FsConfig::new_zeroed(); config.tag[0..tag.len()].copy_from_slice(tag.as_bytes()); - config.num_request_queues = num_queues as u32 - 1; - if FsFeature::from_bits_retain(dev_feat).contains(FsFeature::NOTIFICATION) { + config.num_request_queues = frontend.num_queues() as u32 - 1; + if FsFeature::from_bits_retain(frontend.feature()).contains(FsFeature::NOTIFICATION) { config.num_request_queues -= 1; } config } else { let mut empty_cfg = DeviceConfig::new_zeroed(); empty_cfg.size = size_of_val(&empty_cfg.region) as _; - let dev_config = session.get_config(&empty_cfg)?; + let dev_config = frontend.session().get_config(&empty_cfg)?; FsConfig::read_from_prefix(&dev_config.region).unwrap().0 }; - let (dax_region, channel) = if param.dax_window > 0 { - let channel = session.create_channel()?; + + let mut dax_region = None; + if param.dax_window > 0 { + let channel = frontend.session().create_channel()?; let size = align_up!(param.dax_window, 12); - let region = ArcMemPages::from_anonymous(size, Some(PROT_NONE), None)?; - (Some(region), Some(channel)) - } else { - (None, None) - }; + dax_region = Some(ArcMemPages::from_anonymous(size, Some(PROT_NONE), None)?); + frontend.set_channel(channel); + } Ok(VuFs { - num_queues, - name, - session, - channel, + frontend, config: Arc::new(config), - feature: dev_feat & !VirtioFeature::VHOST_PROTOCOL.bits(), - error_fds: Vec::new(), dax_region, }) } @@ -194,7 +156,7 @@ impl Virtio for VuFs { } fn name(&self) -> &str { - &self.name + self.frontend.name() } fn config(&self) -> Arc { @@ -202,11 +164,11 @@ impl Virtio for VuFs { } fn feature(&self) -> u64 { - self.feature + self.frontend.feature() } fn num_queues(&self) -> u16 { - self.num_queues + self.frontend.num_queues() } fn spawn_worker( @@ -223,11 +185,7 @@ impl Virtio 
for VuFs { } fn ioeventfd_offloaded(&self, q_index: u16) -> Result { - if q_index < self.num_queues { - Ok(true) - } else { - error::InvalidQueueIndex { index: q_index }.fail() - } + self.frontend.ioeventfd_offloaded(q_index) } fn shared_mem_regions(&self) -> Option> { @@ -239,10 +197,7 @@ impl Virtio for VuFs { } fn mem_change_callback(&self) -> Option> { - Some(Box::new(UpdateVuMem { - name: self.name.clone(), - session: self.session.clone(), - })) + self.frontend.mem_change_callback() } } @@ -257,73 +212,12 @@ impl VirtioMio for VuFs { S: IrqSender, E: IoeventFd, { - self.session - .set_features(&(feature | VirtioFeature::VHOST_PROTOCOL.bits()))?; - for (index, fd) in active_mio.ioeventfds.iter().enumerate() { - self.session.set_virtq_kick(&(index as u64), fd.as_fd())?; - } - for (index, queue) in active_mio.queues.iter().enumerate() { - let Some(queue) = queue else { - continue; - }; - let reg = queue.reg(); - active_mio.irq_sender.queue_irqfd(index as _, |fd| { - self.session.set_virtq_call(&(index as u64), fd)?; - Ok(()) - })?; - - let err_fd = - unsafe { OwnedFd::from_raw_fd(ffi!(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK))?) }; - self.session - .set_virtq_err(&(index as u64), err_fd.as_fd()) - .unwrap(); - active_mio.poll.registry().register( - &mut SourceFd(&err_fd.as_raw_fd()), - Token(index), - Interest::READABLE, - )?; - self.error_fds.push(err_fd); - - let virtq_num = VirtqState { - index: index as _, - val: reg.size.load(Ordering::Acquire) as _, - }; - self.session.set_virtq_num(&virtq_num).unwrap(); - log::info!("set_virtq_num: {virtq_num:x?}"); - - let virtq_base = VirtqState { - index: index as _, - val: 0, - }; - self.session.set_virtq_base(&virtq_base).unwrap(); - - log::info!("set_virtq_base: {virtq_base:x?}"); - let mem = active_mio.mem; - let virtq_addr = VirtqAddr { - index: index as _, - flags: 0, - desc_hva: mem.translate(reg.desc.load(Ordering::Acquire) as _)? as _, - used_hva: mem.translate(reg.device.load(Ordering::Acquire) as _)? 
as _, - avail_hva: mem.translate(reg.driver.load(Ordering::Acquire) as _)? as _, - log_guest_addr: 0, - }; - self.session.set_virtq_addr(&virtq_addr).unwrap(); - log::info!("queue: {:x?}", reg); - log::info!("virtq_addr: {virtq_addr:x?}"); - } - for index in 0..active_mio.queues.len() { - let virtq_enable = VirtqState { - index: index as _, - val: 1, - }; - self.session.set_virtq_enable(&virtq_enable).unwrap(); - log::info!("virtq_enable: {virtq_enable:x?}"); - } - if let Some(channel) = &self.channel { + self.frontend.activate(feature, active_mio)?; + if let Some(channel) = self.frontend.channel() { channel.conn.set_nonblocking(true)?; active_mio.poll.registry().register( &mut SourceFd(&channel.conn.as_raw_fd()), - Token(self.num_queues as _), + Token(self.frontend.num_queues() as _), Interest::READABLE, )?; } @@ -354,7 +248,7 @@ impl VirtioMio for VuFs { } .fail()?; }; - let Some(channel) = &self.channel else { + let Some(channel) = self.frontend.channel() else { return vu_error::ProtocolFeature { feature: VuFeature::BACKEND_REQ, } @@ -387,7 +281,7 @@ impl VirtioMio for VuFs { let map_addr = dax_region.addr() + fs_map.cache_offset[index] as usize; log::trace!( "{}: mapping fd {raw_fd} to offset {:#x}", - self.name, + self.name(), fs_map.cache_offset[index] ); ffi!( @@ -410,7 +304,10 @@ impl VirtioMio for VuFs { if len == 0 { continue; } - log::trace!("{}: unmapping offset {offset:#x}, size {len:#x}", self.name); + log::trace!( + "{}: unmapping offset {offset:#x}, size {len:#x}", + self.name() + ); let map_addr = dax_region.addr() + offset as usize; let flags = MAP_ANONYMOUS | MAP_PRIVATE | MAP_FIXED; ffi!( @@ -419,7 +316,7 @@ impl VirtioMio for VuFs { )?; } } - _ => unimplemented!("unknown request {request:#x}"), + _ => unimplemented!("{}: unknown request {request:#x}", self.name()), } channel.reply(VuBackMsg::from(request), &0u64, &[])?; } @@ -429,35 +326,17 @@ impl VirtioMio for VuFs { fn handle_queue<'a, 'm, Q, S, E>( &mut self, index: u16, - _active_mio: &mut 
ActiveMio<'a, 'm, Q, S, E>, + active_mio: &mut ActiveMio<'a, 'm, Q, S, E>, ) -> Result<()> where Q: VirtQueue<'m>, S: IrqSender, E: IoeventFd, { - unreachable!( - "{}: queue {index} notification should go to vhost-user backend", - self.name - ) + self.frontend.handle_queue(index, active_mio) } fn reset(&mut self, registry: &Registry) { - for q_index in 0..self.num_queues { - let disable = VirtqState { - index: q_index as _, - val: 0, - }; - self.session.set_virtq_enable(&disable).unwrap(); - } - while let Some(fd) = self.error_fds.pop() { - registry.deregister(&mut SourceFd(&fd.as_raw_fd())).unwrap(); - } - if let Some(channel) = &self.channel { - let channel_fd = channel.conn.as_fd(); - registry - .deregister(&mut SourceFd(&channel_fd.as_raw_fd())) - .unwrap(); - } + self.frontend.reset(registry) } } diff --git a/alioth/src/virtio/vu/frontend.rs b/alioth/src/virtio/vu/frontend.rs index fa53d515..fc815df6 100644 --- a/alioth/src/virtio/vu/frontend.rs +++ b/alioth/src/virtio/vu/frontend.rs @@ -12,16 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::os::fd::{AsFd, AsRawFd, FromRawFd, OwnedFd}; +use std::path::{Path, PathBuf}; use std::sync::Arc; +use std::sync::atomic::Ordering; +use std::sync::mpsc::Receiver; +use std::thread::JoinHandle; use bitflags::bitflags; +use mio::event::Event; +use mio::unix::SourceFd; +use mio::{Interest, Registry, Token}; +use zerocopy::IntoBytes; use crate::errors::BoxTrace; -use crate::mem; -use crate::mem::LayoutChanged; -use crate::mem::mapped::ArcMemPages; -use crate::virtio::vu::bindings::{MemoryRegion, MemorySingleRegion}; -use crate::virtio::vu::conn::VuSession; +use crate::hv::IoeventFd; +use crate::mem::emulated::{Action, Mmio}; +use crate::mem::mapped::{ArcMemPages, RamBus}; +use crate::mem::{LayoutChanged, MemRegion}; +use crate::virtio::dev::{DevParam, Virtio, WakeEvent}; +use crate::virtio::queue::{Queue, VirtQueue}; +use crate::virtio::vu::bindings::{ + DeviceConfig, MemoryRegion, MemorySingleRegion, VirtqAddr, VirtqState, VuFeature, +}; +use crate::virtio::vu::conn::{VuChannel, VuSession}; +use crate::virtio::vu::error as vu_error; +use crate::virtio::worker::Waker; +use crate::virtio::worker::mio::{ActiveMio, Mio, VirtioMio}; +use crate::virtio::{DevStatus, DeviceId, IrqSender, Result, VirtioFeature, error}; +use crate::{ffi, mem}; bitflags! 
{ #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -73,3 +92,355 @@ impl LayoutChanged for UpdateVuMem { Ok(()) } } + +#[derive(Debug)] +pub struct VuDevConfig { + session: Arc, +} + +impl Mmio for VuDevConfig { + fn size(&self) -> u64 { + 256 + } + + fn read(&self, offset: u64, size: u8) -> mem::Result { + let req = DeviceConfig { + offset: offset as u32, + size: size as u32, + flags: 0, + region: [0u8; 256], + }; + let resp = self.session.get_config(&req).unwrap(); + let mut ret = 0u64; + ret.as_mut_bytes().copy_from_slice(&resp.region[0..8]); + ret &= u64::MAX >> (64 - (size << 3)); + Ok(ret) + } + + fn write(&self, offset: u64, size: u8, val: u64) -> mem::Result { + let mut req = DeviceConfig { + offset: offset as u32, + size: size as u32, + flags: 0, + region: [0u8; 256], + }; + req.region[0..8].copy_from_slice(val.as_bytes()); + self.session.set_config(&req).unwrap(); + Ok(Action::None) + } +} + +#[derive(Debug)] +pub struct VuFrontend { + name: Arc, + session: Arc, + channel: Option, + id: DeviceId, + vu_feature: VuFeature, + device_feature: u64, + num_queues: u16, + err_fds: Box<[OwnedFd]>, +} + +impl VuFrontend { + pub fn new

( + name: impl Into>, + socket: P, + id: DeviceId, + extra_feat: VuFeature, + ) -> Result + where + P: AsRef, + { + let name = name.into(); + let session = Arc::new(VuSession::new(socket)?); + + let device_feature = session.get_features()?; + let feat = VirtioFeature::from_bits_retain(device_feature); + log::trace!("{name}: get device feature: {feat:x?}"); + let need_feat = VirtioFeature::VHOST_PROTOCOL | VirtioFeature::VERSION_1; + if !feat.contains(need_feat) { + return vu_error::DeviceFeature { + feature: need_feat.bits(), + } + .fail()?; + } + + let protocol_feat = VuFeature::from_bits_retain(session.get_protocol_features()?); + log::trace!("{name}: get protocol feature: {protocol_feat:x?}"); + let need_feat = + VuFeature::MQ | VuFeature::REPLY_ACK | VuFeature::CONFIGURE_MEM_SLOTS | extra_feat; + if !protocol_feat.contains(need_feat) { + return vu_error::ProtocolFeature { + feature: need_feat & !protocol_feat, + } + .fail()?; + } + + let mut vu_feature = need_feat; + if protocol_feat.contains(VuFeature::STATUS) { + vu_feature |= VuFeature::STATUS + }; + session.set_protocol_features(&vu_feature.bits())?; + log::trace!("{name}: set protocol feature: {vu_feature:x?}"); + + let num_queues = session.get_queue_num()? 
as u16; + log::trace!("{name}: get queue number: {num_queues}"); + + let mut err_fds = vec![]; + for index in 0..num_queues { + let raw_fd = ffi!(unsafe { libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK) })?; + let fd = unsafe { OwnedFd::from_raw_fd(raw_fd) }; + session.set_virtq_err(&(index as u64), fd.as_fd())?; + log::trace!("{name}: queue-{index}: set error fd: {}", fd.as_raw_fd()); + err_fds.push(fd); + } + + session.set_owner()?; + log::trace!("{name}: set owner"); + + Ok(VuFrontend { + name, + session, + channel: None, + id, + vu_feature, + device_feature, + num_queues, + err_fds: err_fds.into(), + }) + } + + pub fn session(&self) -> &VuSession { + &self.session + } + + pub fn channel(&self) -> Option<&VuChannel> { + self.channel.as_ref() + } + + pub fn set_channel(&mut self, channel: VuChannel) { + self.channel = Some(channel) + } +} + +impl Virtio for VuFrontend { + type Config = VuDevConfig; + type Feature = VuDevFeature; + + fn id(&self) -> DeviceId { + self.id + } + + fn name(&self) -> &str { + &self.name + } + + fn num_queues(&self) -> u16 { + self.num_queues + } + + fn config(&self) -> Arc { + assert!(self.vu_feature.contains(VuFeature::CONFIG)); + Arc::new(VuDevConfig { + session: self.session.clone(), + }) + } + + fn feature(&self) -> u64 { + self.device_feature + } + + fn spawn_worker( + self, + event_rx: Receiver>, + memory: Arc, + queue_regs: Arc<[Queue]>, + ) -> Result<(JoinHandle<()>, Arc)> + where + S: IrqSender, + E: IoeventFd, + { + Mio::spawn_worker(self, event_rx, memory, queue_regs) + } + + fn ioeventfd_offloaded(&self, q_index: u16) -> Result { + if q_index < self.num_queues { + Ok(true) + } else { + error::InvalidQueueIndex { index: q_index }.fail() + } + } + + fn shared_mem_regions(&self) -> Option> { + None + } + + fn mem_change_callback(&self) -> Option> { + Some(Box::new(UpdateVuMem { + name: self.name.clone(), + session: self.session.clone(), + })) + } +} + +impl VirtioMio for VuFrontend { + fn activate<'a, 'm, Q, S, E>( + 
&mut self, + feature: u64, + active_mio: &mut ActiveMio<'a, 'm, Q, S, E>, + ) -> Result<()> + where + Q: VirtQueue<'m>, + S: IrqSender, + E: IoeventFd, + { + let name = &*self.name; + self.session + .set_features(&(feature | VirtioFeature::VHOST_PROTOCOL.bits()))?; + log::trace!("{name}: set driver feature: {feature:x?}"); + + for (index, fd) in active_mio.ioeventfds.iter().enumerate() { + self.session.set_virtq_kick(&(index as u64), fd.as_fd())?; + let raw_fd = fd.as_fd().as_raw_fd(); + log::trace!("{name}: queue-{index}: set kick fd: {raw_fd}"); + } + + for (index, queue) in active_mio.queues.iter().enumerate() { + let Some(queue) = queue else { + log::trace!("{name}: queue-{index} is disabled"); + continue; + }; + let reg = queue.reg(); + + let _ = active_mio.irq_sender.queue_irqfd(index as _, |fd| { + self.session.set_virtq_call(&(index as u64), fd)?; + log::trace!("{name}: queue-{index}: set call fd: {}", fd.as_raw_fd()); + Ok(()) + }); + + let virtq_num = VirtqState { + index: index as _, + val: reg.size.load(Ordering::Acquire) as _, + }; + self.session.set_virtq_num(&virtq_num)?; + log::trace!("{name}: queue-{index}: set size: {}", virtq_num.val); + + let virtq_base = VirtqState { + index: index as _, + val: 0, + }; + self.session.set_virtq_base(&virtq_base)?; + log::trace!("{name}: queue-{index}: set base: {}", virtq_base.val); + + let mem = active_mio.mem; + let virtq_addr = VirtqAddr { + index: index as _, + flags: 0, + desc_hva: mem.translate(reg.desc.load(Ordering::Acquire) as _)? as _, + used_hva: mem.translate(reg.device.load(Ordering::Acquire) as _)? as _, + avail_hva: mem.translate(reg.driver.load(Ordering::Acquire) as _)? 
as _, + log_guest_addr: 0, + }; + self.session.set_virtq_addr(&virtq_addr)?; + log::trace!("{name}: queue-{index}: set addr: {virtq_addr:x?}"); + + let virtq_enable = VirtqState { + index: index as _, + val: 1, + }; + self.session.set_virtq_enable(&virtq_enable)?; + log::trace!("{name}: queue-{index}: set enabled: {}", virtq_enable.val); + } + + for (index, fd) in self.err_fds.iter().enumerate() { + active_mio.poll.registry().register( + &mut SourceFd(&fd.as_raw_fd()), + Token(index), + Interest::READABLE, + )?; + } + + if self.vu_feature.contains(VuFeature::STATUS) { + let dev_status = DevStatus::from_bits_retain(0xf); + self.session.set_status(&(dev_status.bits() as u64))?; + log::trace!("{name}: set status: {dev_status:x?}"); + } + Ok(()) + } + + fn handle_event<'a, 'm, Q, S, E>( + &mut self, + _: &Event, + _: &mut ActiveMio<'a, 'm, Q, S, E>, + ) -> Result<()> + where + Q: VirtQueue<'m>, + S: IrqSender, + E: IoeventFd, + { + unreachable!() + } + + fn handle_queue<'a, 'm, Q, S, E>( + &mut self, + index: u16, + _: &mut ActiveMio<'a, 'm, Q, S, E>, + ) -> Result<()> + where + Q: VirtQueue<'m>, + S: IrqSender, + E: IoeventFd, + { + unreachable!( + "{}: queue {index} notification should go to vhost-user backend", + self.name + ) + } + + fn reset(&mut self, registry: &Registry) { + let name = &*self.name; + for index in 0..self.num_queues { + let disable = VirtqState { + index: index as _, + val: 0, + }; + if let Err(e) = self.session.set_virtq_enable(&disable) { + log::error!("{name}: failed to disable queue-{index}: {e:?}") + } + } + if self.vu_feature.contains(VuFeature::STATUS) { + if let Err(e) = self.session.set_status(&0) { + log::error!("{name}: failed to reset device status: {e:?}"); + } + } + for (index, fd) in self.err_fds.iter().enumerate() { + if let Err(e) = registry.deregister(&mut SourceFd(&fd.as_raw_fd())) { + log::error!("{name}: queue-{index}: failed to deregister error fd: {e:?}"); + } + } + if let Some(channel) = &self.channel { + let channel_fd = 
channel.conn.as_fd(); + if let Err(e) = registry.deregister(&mut SourceFd(&channel_fd.as_raw_fd())) { + log::error!("{name}: failed to deregister backend channel fd: {e:?}") + } + } + } +} + +pub struct VuFrontendParam { + pub socket: PathBuf, + pub id: DeviceId, +} + +impl DevParam for VuFrontendParam { + type Device = VuFrontend; + + fn build(self, name: impl Into>) -> Result { + VuFrontend::new(name, self.socket, self.id, VuFeature::CONFIG) + } + + fn needs_mem_shared_fd(&self) -> bool { + true + } +} From bb3fde6e4f604c8d8cb2f0aaf4884c51ff675868 Mon Sep 17 00:00:00 2001 From: Changyuan Lyu Date: Sun, 27 Apr 2025 18:23:39 -0700 Subject: [PATCH 06/10] feat(vu): implement vhost-user backend as a transport layer Signed-off-by: Changyuan Lyu --- alioth/src/mem/mapped.rs | 2 +- alioth/src/virtio/vu/backend.rs | 514 ++++++++++++++++++++++++++++++++ alioth/src/virtio/vu/vu.rs | 1 + 3 files changed, 516 insertions(+), 1 deletion(-) create mode 100644 alioth/src/virtio/vu/backend.rs diff --git a/alioth/src/mem/mapped.rs b/alioth/src/mem/mapped.rs index 610bd678..b6e422fa 100644 --- a/alioth/src/mem/mapped.rs +++ b/alioth/src/mem/mapped.rs @@ -441,7 +441,7 @@ impl RamBus { Ok(()) } - pub(super) fn remove(&self, gpa: u64) -> Result { + pub(crate) fn remove(&self, gpa: u64) -> Result { let mut ram = self.ram.write(); ram.inner.remove(gpa) } diff --git a/alioth/src/virtio/vu/backend.rs b/alioth/src/virtio/vu/backend.rs new file mode 100644 index 00000000..9b922900 --- /dev/null +++ b/alioth/src/virtio/vu/backend.rs @@ -0,0 +1,514 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cmp::min; +use std::fs::File; +use std::io::{ErrorKind, Write}; +use std::iter::zip; +use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd}; +use std::os::unix::net::UnixStream; +use std::sync::Arc; +use std::sync::atomic::Ordering; + +use macros::trace_error; +use snafu::Snafu; +use zerocopy::{FromZeros, IntoBytes}; + +use crate::errors::DebugTrace; +use crate::hv::IoeventFd; +use crate::mem::mapped::{ArcMemPages, RamBus}; +use crate::virtio::dev::{StartParam, VirtioDevice, WakeEvent}; +use crate::virtio::vu::Error as VuError; +use crate::virtio::vu::bindings::{ + DeviceConfig, MemoryRegion, MemorySingleRegion, Message, VirtqAddr, VirtqState, VuFeature, + VuFrontMsg, +}; +use crate::virtio::vu::conn::{VuChannel, VuSession}; +use crate::virtio::{self, DevStatus, IrqSender, VirtioFeature}; + +#[trace_error] +#[derive(Snafu, DebugTrace)] +#[snafu(module, context(suffix(false)))] +pub enum Error { + #[snafu(display("Error from OS"), context(false))] + System { error: std::io::Error }, + #[snafu(display("Failed to access guest memory"), context(false))] + Memory { source: Box }, + #[snafu(display("vhost-user protocol error"), context(false))] + Vu { + source: Box, + }, + #[snafu(display("failed to parse the payload of {req:?}"))] + Parse { req: VuFrontMsg }, + #[snafu(display("frontend requested invalid queue index: {index}"))] + InvalidQueue { index: u16 }, + #[snafu(display("{req:?} did not contain an FD"))] + MissingFd { req: VuFrontMsg }, + #[snafu(display("frontend did not set size for queue {index}"))] + MissingSize { index: u16 }, + 
#[snafu(display("frontend did not set addresses for queue {index}"))] + MissingAddr { index: u16 }, + #[snafu(display("frontend did not set ioeventfd for queue {index}"))] + MissingIoeventfd { index: u16 }, + #[snafu(display("cannot convert frontend HVA {hva:#x} to GPA"))] + Convert { hva: u64 }, + #[snafu(display("invalid message {req:?} with payload size {size}"))] + InvalidMsg { req: VuFrontMsg, size: u32 }, +} + +type Result = std::result::Result; + +#[derive(Debug)] +pub struct VuIrqSender { + queues: Box<[Option]>, +} + +impl VuIrqSender { + fn signal_irqfd(&self, mut fd: &File) { + if let Err(e) = fd.write(1u64.as_bytes()) { + log::error!("failed to signal irqfd: {e:?}"); + } + } +} + +impl IrqSender for VuIrqSender { + fn config_irq(&self) { + // TODO: investigate VHOST_USER_BACKEND_CONFIG_CHANGE_MSG + log::error!("config irqfd is not available"); + } + + fn queue_irq(&self, idx: u16) { + let Some(queue) = self.queues.get(idx as usize) else { + log::error!("invalid queue index: {idx}"); + return; + }; + let Some(fd) = queue.as_ref() else { + log::error!("queue-{idx} irqfd is not available"); + return; + }; + self.signal_irqfd(fd); + } + + fn config_irqfd(&self, _: F) -> virtio::Result + where + F: FnOnce(BorrowedFd) -> virtio::Result, + { + unreachable!() + } + + fn queue_irqfd(&self, _: u16, _: F) -> virtio::Result + where + F: FnOnce(BorrowedFd) -> virtio::Result, + { + unreachable!() + } +} + +#[derive(Debug)] +pub struct VuEventfd { + fd: File, +} + +impl AsFd for VuEventfd { + fn as_fd(&self) -> BorrowedFd { + self.fd.as_fd() + } +} + +impl IoeventFd for VuEventfd {} + +#[derive(Debug, Default)] +struct VuQueueInit { + enable: bool, + size: Option, + addr: Option, + ioeventfd: Option, + irqfd: Option, + errfd: Option, +} + +#[derive(Debug)] +struct VuInit { + drv_feat: u64, + queues: Box<[VuQueueInit]>, + regions: Vec<(MemoryRegion, Option)>, +} + +pub struct VuBackend { + session: VuSession, + channel: Option>, + status: DevStatus, + memory: Arc, + 
dev: VirtioDevice, + init: VuInit, +} + +impl VuBackend { + pub fn new( + conn: UnixStream, + dev: VirtioDevice, + memory: Arc, + ) -> Result { + conn.set_nonblocking(false)?; + let queue_num = dev.queue_regs.len(); + Ok(VuBackend { + session: VuSession { conn }, + channel: None, + dev, + memory, + status: DevStatus::empty(), + init: VuInit { + drv_feat: 0, + queues: (0..queue_num).map(|_| VuQueueInit::default()).collect(), + regions: vec![], + }, + }) + } + + fn wake_up_dev(&self, event: WakeEvent) { + let is_start = matches!(event, WakeEvent::Start { .. }); + if let Err(e) = self.dev.event_tx.send(event) { + log::error!("{}: failed to send event: {e}", self.dev.name); + return; + } + if is_start { + return; + } + if let Err(e) = self.dev.waker.wake() { + log::error!("{}: failed to wake up device: {e}", self.dev.name); + } + } + + fn convert_frontend_hva(&self, hva: u64) -> Result { + for region in &self.init.regions { + let (r, _) = &region; + if hva >= r.hva && hva < r.hva + r.size { + return Ok(r.gpa + (hva - r.hva)); + } + } + error::Convert { hva }.fail() + } + + fn parse_init(&mut self) -> Result> { + for (index, (param, queue)) in zip(&self.init.queues, &*self.dev.queue_regs).enumerate() { + let index = index as u16; + queue.enabled.store(param.enable, Ordering::Release); + if !param.enable { + continue; + } + + let Some(size) = param.size else { + return error::MissingSize { index }.fail(); + }; + queue.size.store(size, Ordering::Release); + + let Some(addr) = &param.addr else { + return error::MissingAddr { index }.fail(); + }; + + let desc_gpa = self.convert_frontend_hva(addr.desc_hva)?; + queue.desc.store(desc_gpa, Ordering::Release); + + let dev_gpa = self.convert_frontend_hva(addr.used_hva)?; + queue.device.store(dev_gpa, Ordering::Release); + + let drv_gpa = self.convert_frontend_hva(addr.avail_hva)?; + queue.driver.store(drv_gpa, Ordering::Release); + } + + self.init.regions.sort_by_key(|(r, _)| r.gpa); + for (region, fd) in self.init.regions.iter_mut() {
+ let Some(fd) = fd.take() else { + continue; + }; + let user_mem = ArcMemPages::from_file( + File::from(fd), + region.mmap_offset as i64, + region.size as usize, + libc::PROT_READ | libc::PROT_WRITE, + )?; + self.memory.add(region.gpa, user_mem)?; + } + + let queues = &mut self.init.queues; + + let queue_irqfds = queues.iter_mut().map(|q| q.irqfd.take()).collect(); + let irq_sender = VuIrqSender { + queues: queue_irqfds, + }; + + let mut ioeventfds = vec![]; + for (index, q) in queues.iter_mut().enumerate() { + match q.ioeventfd.take() { + Some(fd) => ioeventfds.push(VuEventfd { fd }), + None => { + let index = index as u16; + return error::MissingIoeventfd { index }.fail(); + } + } + } + + Ok(StartParam { + feature: self.init.drv_feat, + irq_sender: Arc::new(irq_sender), + ioeventfds: Some(ioeventfds.into()), + }) + } + + fn handle_msg(&mut self, msg: &mut Message, fds: &mut [Option; 8]) -> Result<()> { + let name = &*self.dev.name; + let (req, size) = (VuFrontMsg::from(msg.request), msg.size); + + match (req, size) { + (VuFrontMsg::GET_PROTOCOL_FEATURES, 0) => { + let feature = VuFeature::MQ + | VuFeature::REPLY_ACK + | VuFeature::CONFIGURE_MEM_SLOTS + | VuFeature::BACKEND_REQ + | VuFeature::BACKEND_SEND_FD + | VuFeature::CONFIG + | VuFeature::STATUS; + self.session.reply(req, &feature.bits(), &[])?; + msg.flag.set_need_reply(false); + log::debug!("{name}: get protocol feature: {feature:x?}"); + } + (VuFrontMsg::SET_PROTOCOL_FEATURES, 8) => { + let feature: u64 = self.session.recv_payload()?; + let feature = VuFeature::from_bits_retain(feature); + log::debug!("{name}: set protocol feature: {feature:x?}"); + } + (VuFrontMsg::GET_FEATURES, 0) => { + let feature = self.dev.device_feature | VirtioFeature::VHOST_PROTOCOL.bits(); + self.session.reply(req, &feature, &[])?; + msg.flag.set_need_reply(false); + log::debug!("{name}: get device feature: {feature:#x}"); + } + (VuFrontMsg::SET_FEATURES, 8) => { + self.init.drv_feat = self.session.recv_payload()?; + 
log::debug!("{name}: set driver feature: {:#x}", self.init.drv_feat); + } + (VuFrontMsg::SET_OWNER, 0) => { + log::trace!("{name}: set owner"); + } + (VuFrontMsg::GET_QUEUE_NUM, 0) => { + let count = self.init.queues.len() as u64; + self.session.reply(req, &count, &[])?; + log::debug!("{name}: get queue number: {count}"); + msg.flag.set_need_reply(false); + } + (VuFrontMsg::SET_BACKEND_REQ_FD, 0) => { + let Some(fd) = fds[0].take() else { + return error::MissingFd { req }.fail()?; + }; + log::trace!("{name}: set backend request fd: {}", fd.as_raw_fd()); + self.channel = Some(Arc::new(VuChannel { + conn: UnixStream::from(fd), + })); + } + (VuFrontMsg::SET_VIRTQ_ERR, 8) => { + let index = self.session.recv_payload::()? as u16; + let Some(fd) = fds[0].take() else { + return error::MissingFd { req: msg.request }.fail(); + }; + let Some(q) = self.init.queues.get_mut(index as usize) else { + return error::InvalidQueue { index }.fail(); + }; + log::debug!("{name}: queue-{index}: set error fd: {}", fd.as_raw_fd()); + q.errfd = Some(File::from(fd)); + } + (VuFrontMsg::SET_VIRTQ_CALL, 8) => { + let index = self.session.recv_payload::()? as u16; + let Some(fd) = fds[0].take() else { + return error::MissingFd { req: msg.request }.fail(); + }; + let Some(q) = self.init.queues.get_mut(index as usize) else { + return error::InvalidQueue { index }.fail(); + }; + log::debug!("{name}: queue-{index}: set call fd: {}", fd.as_raw_fd()); + q.irqfd = Some(File::from(fd)); + } + (VuFrontMsg::SET_VIRTQ_KICK, 8) => { + let index = self.session.recv_payload::()? 
as u16; + let Some(fd) = fds[0].take() else { + return error::MissingFd { req: msg.request }.fail(); + }; + let Some(q) = self.init.queues.get_mut(index as usize) else { + return error::InvalidQueue { index }.fail(); + }; + log::debug!("{name}: queue-{index}: set kick fd: {}", fd.as_raw_fd()); + q.ioeventfd = Some(File::from(fd)); + } + (VuFrontMsg::SET_VIRTQ_NUM, 8) => { + let virtq_num: VirtqState = self.session.recv_payload()?; + let (index, size) = (virtq_num.index as u16, virtq_num.val as u16); + let Some(q) = self.init.queues.get_mut(index as usize) else { + return error::InvalidQueue { index }.fail(); + }; + q.size = Some(size); + log::debug!("{name}: queue-{index}: set size: {size}"); + } + (VuFrontMsg::SET_VIRTQ_BASE, 8) => { + let virtq_base: VirtqState = self.session.recv_payload()?; + let (index, base) = (virtq_base.index as u16, virtq_base.val); + let Some(_q) = self.init.queues.get_mut(index as usize) else { + return error::InvalidQueue { index }.fail(); + }; + log::warn!("{name}: queue-{index}: set base: {base}"); + } + (VuFrontMsg::GET_VIRTQ_BASE, 8) => { + let mut virtq_base: VirtqState = self.session.recv_payload()?; + let (index, base) = (virtq_base.index as u16, virtq_base.val); + let Some(_q) = self.init.queues.get_mut(index as usize) else { + return error::InvalidQueue { index }.fail(); + }; + virtq_base.val = 0; + self.session.reply(req, &virtq_base, &[])?; + msg.flag.set_need_reply(false); + log::warn!("{name}: queue-{index}: get base: {base}"); + } + (VuFrontMsg::SET_VIRTQ_ADDR, 40) => { + let virtq_addr: VirtqAddr = self.session.recv_payload()?; + let index = virtq_addr.index as u16; + let Some(q) = self.init.queues.get_mut(index as usize) else { + return error::InvalidQueue { index }.fail(); + }; + log::debug!("{name}: queue-{index}: set addr: {virtq_addr:x?}"); + q.addr = Some(virtq_addr); + } + (VuFrontMsg::SET_VIRTQ_ENABLE, 8) => { + let virtq_num: VirtqState = self.session.recv_payload()?; + let (index, enabled) = (virtq_num.index as 
u16, virtq_num.val != 0); + let Some(q) = self.init.queues.get_mut(index as usize) else { + return error::InvalidQueue { index }.fail(); + }; + q.enable = enabled; + log::debug!("{name}: queue-{index}: set enabled: {enabled}"); + } + (VuFrontMsg::GET_MAX_MEM_SLOTS, 0) => { + self.session.reply(req, &128u64, &[])?; + msg.flag.set_need_reply(false); + log::debug!("{name}: get max mem slots: 128"); + } + (VuFrontMsg::ADD_MEM_REG, 40) => { + let single: MemorySingleRegion = self.session.recv_payload()?; + let Some(fd) = fds[0].take() else { + return error::MissingFd { req: msg.request }.fail(); + }; + let region = &single.region; + log::debug!("{name}: add mem: {region:x?}, fd: {}", fd.as_raw_fd()); + self.init.regions.push((single.region.clone(), Some(fd))); + } + (VuFrontMsg::REM_MEM_REG, 40) => { + let single: MemorySingleRegion = self.session.recv_payload()?; + let region = &single.region; + let mut indexes = vec![]; + for (index, (r, _)) in self.init.regions.iter().enumerate() { + if r.gpa == region.gpa && r.hva == region.hva && r.size == region.size { + log::info!("{name}: remove mem: {r:x?}"); + indexes.push(index); + } + } + for index in indexes.iter().rev() { + self.init.regions.remove(*index); + } + let _ = self.memory.remove(region.gpa); + } + (VuFrontMsg::GET_STATUS, 0) => { + let status = self.status.bits() as u64; + self.session.reply(req, &status, &[])?; + msg.flag.set_need_reply(false); + log::debug!("{name}: get status: {status:x?}"); + } + (VuFrontMsg::SET_STATUS, 8) => { + let status: u64 = self.session.recv_payload()?; + let new = DevStatus::from_bits_retain(status as u8); + let old = self.status; + self.status = new; + log::debug!("{name}: set status: {old:x?} -> {new:x?}"); + if (old ^ new).contains(DevStatus::DRIVER_OK) { + let event = if new.contains(DevStatus::DRIVER_OK) { + let param = self.parse_init()?; + WakeEvent::Start { param } + } else { + WakeEvent::Reset + }; + self.wake_up_dev(event); + } + } + (VuFrontMsg::GET_CONFIG, 268) => { + 
let dev_config: DeviceConfig = self.session.recv_payload()?; + let mut done = 0; + let mut resp = DeviceConfig::new_zeroed(); + while let Some(n) = (dev_config.size as usize - done).checked_ilog2() { + let size = min(1 << n, 8) as u8; + let offset = dev_config.offset as u64 + done as u64; + let v = self.dev.device_config.read(offset, size)?; + resp.region[done..(done + size as usize)] + .copy_from_slice(&v.as_bytes()[..size as usize]); + done += size as usize; + } + resp.offset = dev_config.offset; + resp.size = dev_config.size; + resp.flags = dev_config.flags; + self.session.reply(req, &resp, &[])?; + log::debug!("{name}: get config: {dev_config:?}"); + msg.flag.set_need_reply(false); + } + (VuFrontMsg::SET_CONFIG, 268) => { + let dev_config: DeviceConfig = self.session.recv_payload()?; + let mut done = 0; + while let Some(n) = (dev_config.size as usize - done).checked_ilog2() { + let size = min(1 << n, 8) as u8; + let mut v = 0; + v.as_mut_bytes()[..size as usize] + .copy_from_slice(&dev_config.region[done..(done + size as usize)]); + let offset = dev_config.offset as u64 + done as u64; + self.dev.device_config.write(offset, size, v)?; + done += size as usize; + } + log::debug!("{name}: set config: {dev_config:?}"); + } + _ => return error::InvalidMsg { req, size }.fail(), + } + Ok(()) + } + + pub fn run(&mut self) -> Result<()> { + let mut fds = [const { None }; 8]; + loop { + let msg = self.session.recv_msg(&mut fds); + match msg { + Ok(mut msg) => { + let ret = self.handle_msg(&mut msg, &mut fds); + if let Err(e) = &ret { + let name = &*self.dev.name; + log::error!("{name}: cannot handle message {:#x}: {e:?}", msg.request); + } + let req = VuFrontMsg::from(msg.request); + if msg.flag.need_reply() { + let code = if ret.is_ok() { 0 } else { u64::MAX }; + self.session.reply(req, &code, &[])?; + } + } + Err(VuError::System { error, .. 
}) + if error.kind() == ErrorKind::ConnectionAborted => + { + break; + } + Err(e) => return Err(e)?, + } + } + Ok(()) + } +} diff --git a/alioth/src/virtio/vu/vu.rs b/alioth/src/virtio/vu/vu.rs index 32ca61a7..f9c0f273 100644 --- a/alioth/src/virtio/vu/vu.rs +++ b/alioth/src/virtio/vu/vu.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod backend; pub mod bindings; pub mod conn; pub mod frontend; From 72501916802db7d9d2f7bcb9b71d3156771c4aa1 Mon Sep 17 00:00:00 2001 From: Changyuan Lyu Date: Sun, 27 Apr 2025 18:39:24 -0700 Subject: [PATCH 07/10] feat(cli): add "run" as an alias of "boot" Also split main.rs into modules. Signed-off-by: Changyuan Lyu --- README.md | 2 +- alioth-cli/src/boot.rs | 440 ++++++++++++++++++++++++++++++++++++ alioth-cli/src/main.rs | 459 +------------------------------------- alioth-cli/src/objects.rs | 63 ++++++ 4 files changed, 510 insertions(+), 454 deletions(-) create mode 100644 alioth-cli/src/boot.rs create mode 100644 alioth-cli/src/objects.rs diff --git a/README.md b/README.md index e55eb469..a733aa3d 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ Alioth /AL-lee-oth/ is an experimental ```sh alioth -l info --log-to-file \ - run \ + boot \ --kernel /path/to/vmlinuz \ --cmd-line "console=ttyS0" \ --initramfs /path/to/initramfs \ diff --git a/alioth-cli/src/boot.rs b/alioth-cli/src/boot.rs new file mode 100644 index 00000000..ca456285 --- /dev/null +++ b/alioth-cli/src/boot.rs @@ -0,0 +1,440 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[cfg(target_arch = "x86_64")] +use std::ffi::CString; +#[cfg(target_arch = "x86_64")] +use std::fs::File; +use std::path::PathBuf; + +use alioth::board::BoardConfig; +#[cfg(target_arch = "x86_64")] +use alioth::device::fw_cfg::FwCfgItemParam; +use alioth::errors::{DebugTrace, trace_error}; +use alioth::hv::Coco; +#[cfg(target_os = "macos")] +use alioth::hv::Hvf; +#[cfg(target_os = "linux")] +use alioth::hv::{Kvm, KvmConfig}; +use alioth::loader::{ExecType, Payload}; +use alioth::mem::{MemBackend, MemConfig}; +#[cfg(target_os = "linux")] +use alioth::vfio::{CdevParam, ContainerParam, GroupParam, IoasParam}; +use alioth::virtio::dev::balloon::BalloonParam; +use alioth::virtio::dev::blk::BlockParam; +use alioth::virtio::dev::entropy::EntropyParam; +#[cfg(target_os = "linux")] +use alioth::virtio::dev::fs::VuFsParam; +#[cfg(target_os = "linux")] +use alioth::virtio::dev::net::NetParam; +#[cfg(target_os = "linux")] +use alioth::virtio::dev::vsock::VhostVsockParam; +use alioth::vm::Machine; +use clap::Args; +use serde::Deserialize; +use serde_aco::{Help, help_text}; +use snafu::{ResultExt, Snafu}; + +use crate::objects::{DOC_OBJECTS, parse_objects}; + +#[trace_error] +#[derive(Snafu, DebugTrace)] +#[snafu(module, context(suffix(false)))] +pub enum Error { + #[snafu(display("Failed to parse {arg}"))] + ParseArg { + arg: String, + error: serde_aco::Error, + }, + #[snafu(display("Failed to parse objects"), context(false))] + ParseObjects { source: crate::objects::Error }, + #[cfg(target_os = "linux")] + #[snafu(display("Failed to access system 
hypervisor"))] + Hypervisor { source: alioth::hv::Error }, + #[snafu(display("Failed to create a VM"))] + CreateVm { source: alioth::vm::Error }, + #[snafu(display("Failed to create a device"))] + CreateDevice { source: alioth::vm::Error }, + #[cfg(target_arch = "x86_64")] + #[snafu(display("Failed to open {path:?}"))] + OpenFile { + path: PathBuf, + error: std::io::Error, + }, + #[cfg(target_arch = "x86_64")] + #[snafu(display("Failed to configure the fw-cfg device"))] + FwCfg { error: std::io::Error }, + #[cfg(target_arch = "x86_64")] + #[snafu(display("{s} is not a valid CString"))] + CreateCString { s: String }, + #[snafu(display("Failed to boot a VM"))] + BootVm { source: alioth::vm::Error }, + #[snafu(display("VM did not shutdown peacefully"))] + WaitVm { source: alioth::vm::Error }, +} + +#[derive(Debug, Deserialize, Clone, Help)] +#[cfg_attr(target_os = "macos", derive(Default))] +enum Hypervisor { + /// KVM backed by the Linux kernel. + #[cfg(target_os = "linux")] + #[serde(alias = "kvm")] + Kvm(KvmConfig), + /// macOS Hypervisor Framework. + #[cfg(target_os = "macos")] + #[serde(alias = "hvf")] + #[default] + Hvf, +} + +#[cfg(target_os = "linux")] +impl Default for Hypervisor { + fn default() -> Self { + Hypervisor::Kvm(KvmConfig::default()) + } +} + +#[cfg(target_os = "linux")] +#[derive(Debug, Deserialize, Clone, Help)] +enum FsParam { + #[serde(alias = "vu")] + /// VirtIO device backed by a vhost-user process, e.g. virtiofsd. + Vu(VuFsParam), +} + +#[cfg(target_os = "linux")] +#[derive(Debug, Deserialize, Clone, Help)] +enum VsockParam { + /// Vsock device backed by host kernel vhost-vsock module. + #[serde(alias = "vhost")] + Vhost(VhostVsockParam), +} + +#[derive(Args, Debug, Clone)] +#[command(arg_required_else_help = true, alias("run"))] +pub struct BootArgs { + #[arg(long, help( + help_text::("Specify the Hypervisor to run on.") + ), value_name = "HV")] + hypervisor: Option, + + /// Path to a Linux kernel image. 
+ #[arg(short, long, value_name = "PATH")] + kernel: Option, + + /// Path to an ELF kernel with PVH note. + #[cfg(target_arch = "x86_64")] + #[arg(long, value_name = "PATH")] + pvh: Option, + + /// Path to a firmware image. + #[arg(long, short, value_name = "PATH")] + firmware: Option, + + /// Command line to pass to the kernel, e.g. `console=ttyS0`. + #[arg(short, long, value_name = "ARGS")] + cmd_line: Option, + + /// Path to an initramfs image. + #[arg(short, long, value_name = "PATH")] + initramfs: Option, + + /// Number of VCPUs assigned to the guest. + #[arg(long, default_value_t = 1)] + num_cpu: u32, + + /// DEPRECATED: Use --memory instead. + #[arg(long, default_value = "1G")] + mem_size: String, + + #[arg(short, long, help( + help_text::("Specify the memory of the guest.") + ))] + memory: Option, + + /// Add a pvpanic device. + #[arg(long)] + pvpanic: bool, + + #[cfg(target_arch = "x86_64")] + #[arg(long = "fw-cfg", help( + help_text::("Add an extra item to the fw_cfg device.") + ), value_name = "ITEM")] + fw_cfgs: Vec, + + /// Add a VirtIO entropy device. + #[arg(long)] + entropy: bool, + + #[cfg(target_os = "linux")] + #[arg(long, help( + help_text::("Add a VirtIO net device backed by TUN/TAP, MacVTap, or IPVTap.") + ))] + net: Vec, + + #[arg(long, help( + help_text::("Add a VirtIO block device.") + ))] + blk: Vec, + + #[arg(long, help( + help_text::("Enable confidential compute supported by host platform.") + ))] + coco: Option, + + #[cfg(target_os = "linux")] + #[arg(long, help( + help_text::("Add a VirtIO filesystem device.") + ))] + fs: Vec, + + #[cfg(target_os = "linux")] + #[arg(long, help( + help_text::("Add a VirtIO vsock device.") + ))] + vsock: Option, + + #[cfg(target_os = "linux")] + #[arg(long, help(help_text::( + "Assign a host PCI device to the guest using IOMMUFD API." 
+ ) ))] + vfio_cdev: Vec, + + #[cfg(target_os = "linux")] + #[arg(long, help(help_text::("Create a new IO address space.")))] + vfio_ioas: Vec, + + #[cfg(target_os = "linux")] + #[arg(long, help(help_text::( + "Assign a host PCI device to the guest using legacy VFIO API." + )))] + vfio_group: Vec, + + #[cfg(target_os = "linux")] + #[arg(long, help(help_text::("Add a new VFIO container.")))] + vfio_container: Vec, + + #[arg(long)] + #[arg(long, help(help_text::("Add a VirtIO balloon device.")))] + balloon: Option, + + #[arg(short, long("object"), help = DOC_OBJECTS, value_name = "OBJECT")] + objects: Vec, +} + +pub fn boot(args: BootArgs) -> Result<(), Error> { + let objects = parse_objects(&args.objects)?; + let hv_config = if let Some(hv_cfg_opt) = args.hypervisor { + serde_aco::from_args(&hv_cfg_opt, &objects).context(error::ParseArg { arg: hv_cfg_opt })? + } else { + Hypervisor::default() + }; + let hypervisor = match hv_config { + #[cfg(target_os = "linux")] + Hypervisor::Kvm(kvm_config) => Kvm::new(kvm_config).context(error::Hypervisor)?, + #[cfg(target_os = "macos")] + Hypervisor::Hvf => Hvf {}, + }; + let coco = match args.coco { + None => None, + Some(c) => Some(serde_aco::from_args(&c, &objects).context(error::ParseArg { arg: c })?), + }; + let mem_config = if let Some(s) = args.memory { + serde_aco::from_args(&s, &objects).context(error::ParseArg { arg: s })? 
+ } else { + #[cfg(target_os = "linux")] + eprintln!( + "Please update the cmd line to --memory size={},backend=memfd", + args.mem_size + ); + let size = serde_aco::from_args(&args.mem_size, &objects) + .context(error::ParseArg { arg: args.mem_size })?; + MemConfig { + size, + #[cfg(target_os = "linux")] + backend: MemBackend::Memfd, + #[cfg(not(target_os = "linux"))] + backend: MemBackend::Anonymous, + ..Default::default() + } + }; + let board_config = BoardConfig { + mem: mem_config, + num_cpu: args.num_cpu, + coco, + }; + let vm = Machine::new(hypervisor, board_config).context(error::CreateVm)?; + #[cfg(target_arch = "x86_64")] + vm.add_com1().context(error::CreateDevice)?; + #[cfg(target_arch = "aarch64")] + vm.add_pl011().context(error::CreateDevice)?; + + if args.pvpanic { + vm.add_pvpanic().context(error::CreateDevice)?; + } + + #[cfg(target_arch = "x86_64")] + if args.firmware.is_some() || !args.fw_cfgs.is_empty() { + let params = args + .fw_cfgs + .into_iter() + .map(|s| serde_aco::from_args(&s, &objects).context(error::ParseArg { arg: s })) + .collect::, _>>()?; + let fw_cfg = vm + .add_fw_cfg(params.into_iter()) + .context(error::CreateDevice)?; + let mut dev = fw_cfg.lock(); + + if let Some(kernel) = &args.kernel { + dev.add_kernel_data(File::open(kernel).context(error::OpenFile { path: kernel })?) + .context(error::FwCfg)? 
+ } + if let Some(initramfs) = &args.initramfs { + dev.add_initramfs_data( + File::open(initramfs).context(error::OpenFile { path: initramfs })?, + ) + .context(error::FwCfg)?; + } + if let Some(cmdline) = &args.cmd_line { + let Ok(cmdline_c) = CString::new(cmdline.as_str()) else { + return error::CreateCString { + s: cmdline.to_owned(), + } + .fail(); + }; + dev.add_kernel_cmdline(cmdline_c); + } + }; + + if args.entropy { + vm.add_virtio_dev("virtio-entropy", EntropyParam) + .context(error::CreateDevice)?; + } + #[cfg(target_os = "linux")] + for (index, net_opt) in args.net.into_iter().enumerate() { + let net_param: NetParam = + serde_aco::from_args(&net_opt, &objects).context(error::ParseArg { arg: net_opt })?; + vm.add_virtio_dev(format!("virtio-net-{index}"), net_param) + .context(error::CreateDevice)?; + } + for (index, blk) in args.blk.into_iter().enumerate() { + let param = match serde_aco::from_args(&blk, &objects) { + Ok(param) => param, + Err(serde_aco::Error::ExpectedMapEq) => { + eprintln!( + "Please update the cmd line to --blk path={blk}, see https://github.com/google/alioth/pull/72 for details" + ); + BlockParam { + path: blk.into(), + ..Default::default() + } + } + Err(e) => return Err(e).context(error::ParseArg { arg: blk })?, + }; + + vm.add_virtio_dev(format!("virtio-blk-{index}"), param) + .context(error::CreateDevice)?; + } + #[cfg(target_os = "linux")] + for (index, fs) in args.fs.into_iter().enumerate() { + let param: FsParam = + serde_aco::from_args(&fs, &objects).context(error::ParseArg { arg: fs })?; + match param { + FsParam::Vu(p) => vm + .add_virtio_dev(format!("vu-fs-{index}"), p) + .context(error::CreateDevice)?, + }; + } + #[cfg(target_os = "linux")] + if let Some(vsock) = args.vsock { + let param = + serde_aco::from_args(&vsock, &objects).context(error::ParseArg { arg: vsock })?; + match param { + VsockParam::Vhost(p) => vm + .add_virtio_dev("vhost-vsock", p) + .context(error::CreateDevice)?, + }; + } + if let Some(balloon) = 
args.balloon { + let param: BalloonParam = + serde_aco::from_args(&balloon, &objects).context(error::ParseArg { arg: balloon })?; + vm.add_virtio_dev("virtio-balloon", param) + .context(error::CreateDevice)?; + } + + #[cfg(target_os = "linux")] + for ioas in args.vfio_ioas.into_iter() { + let param: IoasParam = + serde_aco::from_args(&ioas, &objects).context(error::ParseArg { arg: ioas })?; + vm.add_vfio_ioas(param).context(error::CreateDevice)?; + } + #[cfg(target_os = "linux")] + for (index, vfio) in args.vfio_cdev.into_iter().enumerate() { + let param: CdevParam = + serde_aco::from_args(&vfio, &objects).context(error::ParseArg { arg: vfio })?; + vm.add_vfio_cdev(format!("vfio-{index}").into(), param) + .context(error::CreateDevice)?; + } + + #[cfg(target_os = "linux")] + for container in args.vfio_container.into_iter() { + let param: ContainerParam = serde_aco::from_args(&container, &objects) + .context(error::ParseArg { arg: container })?; + vm.add_vfio_container(param).context(error::CreateDevice)?; + } + #[cfg(target_os = "linux")] + for (index, group) in args.vfio_group.into_iter().enumerate() { + let param: GroupParam = + serde_aco::from_args(&group, &objects).context(error::ParseArg { arg: group })?; + vm.add_vfio_devs_in_group(&index.to_string(), param) + .context(error::CreateDevice)?; + } + + let payload = if let Some(fw) = args.firmware { + Some(Payload { + executable: fw, + exec_type: ExecType::Firmware, + initramfs: None, + cmd_line: None, + }) + } else if let Some(kernel) = args.kernel { + Some(Payload { + exec_type: ExecType::Linux, + executable: kernel, + initramfs: args.initramfs, + cmd_line: args.cmd_line, + }) + } else { + #[cfg(target_arch = "x86_64")] + if let Some(pvh_kernel) = args.pvh { + Some(Payload { + executable: pvh_kernel, + exec_type: ExecType::Pvh, + initramfs: args.initramfs, + cmd_line: args.cmd_line, + }) + } else { + None + } + #[cfg(not(target_arch = "x86_64"))] + None + }; + if let Some(payload) = payload { + 
vm.add_payload(payload); + } + + vm.boot().context(error::BootVm)?; + vm.wait().context(error::WaitVm)?; + Ok(()) +} diff --git a/alioth-cli/src/main.rs b/alioth-cli/src/main.rs index fcc2eef2..afc3066b 100644 --- a/alioth-cli/src/main.rs +++ b/alioth-cli/src/main.rs @@ -12,41 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; -#[cfg(target_arch = "x86_64")] -use std::ffi::CString; -#[cfg(target_arch = "x86_64")] -use std::fs::File; +mod boot; +mod objects; + use std::path::PathBuf; -use alioth::board::BoardConfig; -#[cfg(target_arch = "x86_64")] -use alioth::device::fw_cfg::FwCfgItemParam; -use alioth::errors::{DebugTrace, trace_error}; -use alioth::hv::Coco; -#[cfg(target_os = "macos")] -use alioth::hv::Hvf; -#[cfg(target_os = "linux")] -use alioth::hv::{Kvm, KvmConfig}; -use alioth::loader::{ExecType, Payload}; -use alioth::mem::{MemBackend, MemConfig}; -#[cfg(target_os = "linux")] -use alioth::vfio::{CdevParam, ContainerParam, GroupParam, IoasParam}; -use alioth::virtio::dev::balloon::BalloonParam; -use alioth::virtio::dev::blk::BlockParam; -use alioth::virtio::dev::entropy::EntropyParam; -#[cfg(target_os = "linux")] -use alioth::virtio::dev::fs::VuFsParam; -#[cfg(target_os = "linux")] -use alioth::virtio::dev::net::NetParam; -#[cfg(target_os = "linux")] -use alioth::virtio::dev::vsock::VhostVsockParam; -use alioth::vm::Machine; -use clap::{Args, Parser, Subcommand}; +use clap::{Parser, Subcommand}; use flexi_logger::{FileSpec, Logger}; -use serde::Deserialize; -use serde_aco::{Help, help_text}; -use snafu::{ResultExt, Snafu}; #[derive(Parser, Debug)] #[command(author, version, about)] @@ -72,426 +44,7 @@ struct Cli { #[derive(Subcommand, Debug)] enum Command { /// Create and boot a virtual machine. 
- Run(RunArgs), -} - -#[derive(Debug, Deserialize, Clone, Help)] -#[cfg_attr(target_os = "macos", derive(Default))] -enum Hypervisor { - /// KVM backed by the Linux kernel. - #[cfg(target_os = "linux")] - #[serde(alias = "kvm")] - Kvm(KvmConfig), - /// macOS Hypervisor Framework. - #[cfg(target_os = "macos")] - #[serde(alias = "hvf")] - #[default] - Hvf, -} - -#[cfg(target_os = "linux")] -impl Default for Hypervisor { - fn default() -> Self { - Hypervisor::Kvm(KvmConfig::default()) - } -} - -#[cfg(target_os = "linux")] -#[derive(Debug, Deserialize, Clone, Help)] -enum FsParam { - #[serde(alias = "vu")] - /// VirtIO device backed by a vhost-user process, e.g. virtiofsd. - Vu(VuFsParam), -} - -#[cfg(target_os = "linux")] -#[derive(Debug, Deserialize, Clone, Help)] -enum VsockParam { - /// Vsock device backed by host kernel vhost-vsock module. - #[serde(alias = "vhost")] - Vhost(VhostVsockParam), -} - -const DOC_OBJECTS: &str = r#"Supply additional data to other command line flags. -* , - -Any value that comes after an equal sign(=) and contains a comma(,) -or equal sign can be supplied using this flag. `` must start -with `id_` and `` cannot contain any comma or equal sign. - -Example: assuming we are going a add a virtio-blk device backed by -`/path/to/disk,2024.img` and a virtio-fs device backed by a -vhost-user process listening on socket `/path/to/socket=1`, these -2 devices can be expressed in the command line as follows: - --blk path=id_blk --fs vu,socket=id_fs,tag=shared-dir \ - -o id_blk,/path/to/disk,2024.img \ - -o id_fs,/path/to/socket=1"#; - -#[derive(Args, Debug, Clone)] -#[command(arg_required_else_help = true)] -struct RunArgs { - #[arg(long, help( - help_text::("Specify the Hypervisor to run on.") - ), value_name = "HV")] - hypervisor: Option, - - /// Path to a Linux kernel image. - #[arg(short, long, value_name = "PATH")] - kernel: Option, - - /// Path to an ELF kernel with PVH note. 
- #[cfg(target_arch = "x86_64")] - #[arg(long, value_name = "PATH")] - pvh: Option, - - /// Path to a firmware image. - #[arg(long, short, value_name = "PATH")] - firmware: Option, - - /// Command line to pass to the kernel, e.g. `console=ttyS0`. - #[arg(short, long, value_name = "ARGS")] - cmd_line: Option, - - /// Path to an initramfs image. - #[arg(short, long, value_name = "PATH")] - initramfs: Option, - - /// Number of VCPUs assigned to the guest. - #[arg(long, default_value_t = 1)] - num_cpu: u32, - - /// DEPRECATED: Use --memory instead. - #[arg(long, default_value = "1G")] - mem_size: String, - - #[arg(short, long, help( - help_text::("Specify the memory of the guest.") - ))] - memory: Option, - - /// Add a pvpanic device. - #[arg(long)] - pvpanic: bool, - - #[cfg(target_arch = "x86_64")] - #[arg(long = "fw-cfg", help( - help_text::("Add an extra item to the fw_cfg device.") - ), value_name = "ITEM")] - fw_cfgs: Vec, - - /// Add a VirtIO entropy device. - #[arg(long)] - entropy: bool, - - #[cfg(target_os = "linux")] - #[arg(long, help( - help_text::("Add a VirtIO net device backed by TUN/TAP, MacVTap, or IPVTap.") - ))] - net: Vec, - - #[arg(long, help( - help_text::("Add a VirtIO block device.") - ))] - blk: Vec, - - #[arg(long, help( - help_text::("Enable confidential compute supported by host platform.") - ))] - coco: Option, - - #[cfg(target_os = "linux")] - #[arg(long, help( - help_text::("Add a VirtIO filesystem device.") - ))] - fs: Vec, - - #[cfg(target_os = "linux")] - #[arg(long, help( - help_text::("Add a VirtIO vsock device.") - ))] - vsock: Option, - - #[cfg(target_os = "linux")] - #[arg(long, help(help_text::( - "Assign a host PCI device to the guest using IOMMUFD API." - ) ))] - vfio_cdev: Vec, - - #[cfg(target_os = "linux")] - #[arg(long, help(help_text::("Create a new IO address space.")))] - vfio_ioas: Vec, - - #[cfg(target_os = "linux")] - #[arg(long, help(help_text::( - "Assign a host PCI device to the guest using legacy VFIO API." 
- )))] - vfio_group: Vec, - - #[cfg(target_os = "linux")] - #[arg(long, help(help_text::("Add a new VFIO container.")))] - vfio_container: Vec, - - #[arg(long)] - #[arg(long, help(help_text::("Add a VirtIO balloon device.")))] - balloon: Option, - - #[arg(short, long("object"), help = DOC_OBJECTS, value_name = "OBJECT")] - objects: Vec, -} - -#[trace_error] -#[derive(Snafu, DebugTrace)] -#[snafu(module, context(suffix(false)))] -pub enum Error { - #[snafu(display("Failed to parse {arg}"))] - ParseArg { - arg: String, - error: serde_aco::Error, - }, - #[snafu(display("Failed to access system hypervisor"))] - Hypervisor { source: alioth::hv::Error }, - #[snafu(display("Failed to create a VM"))] - CreateVm { source: alioth::vm::Error }, - #[snafu(display("Failed to create a device"))] - CreateDevice { source: alioth::vm::Error }, - #[snafu(display("Failed to open {path:?}"))] - OpenFile { - path: PathBuf, - error: std::io::Error, - }, - #[snafu(display("Failed to configure the fw-cfg device"))] - FwCfg { error: std::io::Error }, - #[snafu(display("{s} is not a valid CString"))] - CreateCString { s: String }, - #[snafu(display("Failed to boot a VM"))] - BootVm { source: alioth::vm::Error }, - #[snafu(display("VM did not shutdown peacefully"))] - WaitVm { source: alioth::vm::Error }, - #[snafu(display("Invalid object key {key:?}, must start with `id_`"))] - InvalidKey { key: String }, - #[snafu(display("Key {key:?} showed up more than once"))] - DuplicateKey { key: String }, -} - -fn main_run(args: RunArgs) -> Result<(), Error> { - let mut objects = HashMap::new(); - for obj_s in &args.objects { - let (key, val) = obj_s.split_once(',').unwrap_or((obj_s, "")); - if !key.starts_with("id_") { - return error::InvalidKey { - key: key.to_owned(), - } - .fail(); - } - if objects.insert(key, val).is_some() { - return error::DuplicateKey { - key: key.to_owned(), - } - .fail(); - } - } - let hv_config = if let Some(hv_cfg_opt) = args.hypervisor { - 
serde_aco::from_args(&hv_cfg_opt, &objects).context(error::ParseArg { arg: hv_cfg_opt })? - } else { - Hypervisor::default() - }; - let hypervisor = match hv_config { - #[cfg(target_os = "linux")] - Hypervisor::Kvm(kvm_config) => Kvm::new(kvm_config).context(error::Hypervisor)?, - #[cfg(target_os = "macos")] - Hypervisor::Hvf => Hvf {}, - }; - let coco = match args.coco { - None => None, - Some(c) => Some(serde_aco::from_args(&c, &objects).context(error::ParseArg { arg: c })?), - }; - let mem_config = if let Some(s) = args.memory { - serde_aco::from_args(&s, &objects).context(error::ParseArg { arg: s })? - } else { - #[cfg(target_os = "linux")] - eprintln!( - "Please update the cmd line to --memory size={},backend=memfd", - args.mem_size - ); - let size = serde_aco::from_args(&args.mem_size, &objects) - .context(error::ParseArg { arg: args.mem_size })?; - MemConfig { - size, - #[cfg(target_os = "linux")] - backend: MemBackend::Memfd, - #[cfg(not(target_os = "linux"))] - backend: MemBackend::Anonymous, - ..Default::default() - } - }; - let board_config = BoardConfig { - mem: mem_config, - num_cpu: args.num_cpu, - coco, - }; - let vm = Machine::new(hypervisor, board_config).context(error::CreateVm)?; - #[cfg(target_arch = "x86_64")] - vm.add_com1().context(error::CreateDevice)?; - #[cfg(target_arch = "aarch64")] - vm.add_pl011().context(error::CreateDevice)?; - - if args.pvpanic { - vm.add_pvpanic().context(error::CreateDevice)?; - } - - #[cfg(target_arch = "x86_64")] - if args.firmware.is_some() || !args.fw_cfgs.is_empty() { - let params = args - .fw_cfgs - .into_iter() - .map(|s| serde_aco::from_args(&s, &objects).context(error::ParseArg { arg: s })) - .collect::, _>>()?; - let fw_cfg = vm - .add_fw_cfg(params.into_iter()) - .context(error::CreateDevice)?; - let mut dev = fw_cfg.lock(); - #[cfg(target_arch = "x86_64")] - if let Some(kernel) = &args.kernel { - dev.add_kernel_data(File::open(kernel).context(error::OpenFile { path: kernel })?) 
- .context(error::FwCfg)? - } - if let Some(initramfs) = &args.initramfs { - dev.add_initramfs_data( - File::open(initramfs).context(error::OpenFile { path: initramfs })?, - ) - .context(error::FwCfg)?; - } - if let Some(cmdline) = &args.cmd_line { - let Ok(cmdline_c) = CString::new(cmdline.as_str()) else { - return error::CreateCString { - s: cmdline.to_owned(), - } - .fail(); - }; - dev.add_kernel_cmdline(cmdline_c); - } - }; - - if args.entropy { - vm.add_virtio_dev("virtio-entropy", EntropyParam) - .context(error::CreateDevice)?; - } - #[cfg(target_os = "linux")] - for (index, net_opt) in args.net.into_iter().enumerate() { - let net_param: NetParam = - serde_aco::from_args(&net_opt, &objects).context(error::ParseArg { arg: net_opt })?; - vm.add_virtio_dev(format!("virtio-net-{index}"), net_param) - .context(error::CreateDevice)?; - } - for (index, blk) in args.blk.into_iter().enumerate() { - let param = match serde_aco::from_args(&blk, &objects) { - Ok(param) => param, - Err(serde_aco::Error::ExpectedMapEq) => { - eprintln!( - "Please update the cmd line to --blk path={blk}, see https://github.com/google/alioth/pull/72 for details" - ); - BlockParam { - path: blk.into(), - ..Default::default() - } - } - Err(e) => return Err(e).context(error::ParseArg { arg: blk })?, - }; - - vm.add_virtio_dev(format!("virtio-blk-{index}"), param) - .context(error::CreateDevice)?; - } - #[cfg(target_os = "linux")] - for (index, fs) in args.fs.into_iter().enumerate() { - let param: FsParam = - serde_aco::from_args(&fs, &objects).context(error::ParseArg { arg: fs })?; - match param { - FsParam::Vu(p) => vm - .add_virtio_dev(format!("vu-fs-{index}"), p) - .context(error::CreateDevice)?, - }; - } - #[cfg(target_os = "linux")] - if let Some(vsock) = args.vsock { - let param = - serde_aco::from_args(&vsock, &objects).context(error::ParseArg { arg: vsock })?; - match param { - VsockParam::Vhost(p) => vm - .add_virtio_dev("vhost-vsock", p) - .context(error::CreateDevice)?, - }; - } - if 
let Some(balloon) = args.balloon { - let param: BalloonParam = - serde_aco::from_args(&balloon, &objects).context(error::ParseArg { arg: balloon })?; - vm.add_virtio_dev("virtio-balloon", param) - .context(error::CreateDevice)?; - } - - #[cfg(target_os = "linux")] - for ioas in args.vfio_ioas.into_iter() { - let param: IoasParam = - serde_aco::from_args(&ioas, &objects).context(error::ParseArg { arg: ioas })?; - vm.add_vfio_ioas(param).context(error::CreateDevice)?; - } - #[cfg(target_os = "linux")] - for (index, vfio) in args.vfio_cdev.into_iter().enumerate() { - let param: CdevParam = - serde_aco::from_args(&vfio, &objects).context(error::ParseArg { arg: vfio })?; - vm.add_vfio_cdev(format!("vfio-{index}").into(), param) - .context(error::CreateDevice)?; - } - - #[cfg(target_os = "linux")] - for container in args.vfio_container.into_iter() { - let param: ContainerParam = serde_aco::from_args(&container, &objects) - .context(error::ParseArg { arg: container })?; - vm.add_vfio_container(param).context(error::CreateDevice)?; - } - #[cfg(target_os = "linux")] - for (index, group) in args.vfio_group.into_iter().enumerate() { - let param: GroupParam = - serde_aco::from_args(&group, &objects).context(error::ParseArg { arg: group })?; - vm.add_vfio_devs_in_group(&index.to_string(), param) - .context(error::CreateDevice)?; - } - - let payload = if let Some(fw) = args.firmware { - Some(Payload { - executable: fw, - exec_type: ExecType::Firmware, - initramfs: None, - cmd_line: None, - }) - } else if let Some(kernel) = args.kernel { - Some(Payload { - exec_type: ExecType::Linux, - executable: kernel, - initramfs: args.initramfs, - cmd_line: args.cmd_line, - }) - } else { - #[cfg(target_arch = "x86_64")] - if let Some(pvh_kernel) = args.pvh { - Some(Payload { - executable: pvh_kernel, - exec_type: ExecType::Pvh, - initramfs: args.initramfs, - cmd_line: args.cmd_line, - }) - } else { - None - } - #[cfg(not(target_arch = "x86_64"))] - None - }; - if let Some(payload) = payload 
{ - vm.add_payload(payload); - } - - vm.boot().context(error::BootVm)?; - vm.wait().context(error::WaitVm)?; - Ok(()) + Boot(boot::BootArgs), } fn main() -> Result<(), Box> { @@ -518,7 +71,7 @@ fn main() -> Result<(), Box> { ); match cli.cmd { - Command::Run(args) => main_run(args)?, + Command::Boot(args) => boot::boot(args)?, } Ok(()) } diff --git a/alioth-cli/src/objects.rs b/alioth-cli/src/objects.rs new file mode 100644 index 00000000..338ace96 --- /dev/null +++ b/alioth-cli/src/objects.rs @@ -0,0 +1,63 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; + +use alioth::errors::{DebugTrace, trace_error}; +use snafu::Snafu; + +pub const DOC_OBJECTS: &str = r#"Supply additional data to other command line flags. +* , + +Any value that comes after an equal sign(=) and contains a comma(,) +or equal sign can be supplied using this flag. `` must start +with `id_` and `` cannot contain any comma or equal sign. 
+ +Example: assuming we are going a add a virtio-blk device backed by +`/path/to/disk,2024.img` and a virtio-fs device backed by a +vhost-user process listening on socket `/path/to/socket=1`, these +2 devices can be expressed in the command line as follows: + --blk path=id_blk --fs vu,socket=id_fs,tag=shared-dir \ + -o id_blk,/path/to/disk,2024.img \ + -o id_fs,/path/to/socket=1"#; + +#[trace_error] +#[derive(Snafu, DebugTrace)] +#[snafu(module, context(suffix(false)))] +pub enum Error { + #[snafu(display("Invalid object key {key:?}, must start with `id_`"))] + InvalidKey { key: String }, + #[snafu(display("Key {key:?} showed up more than once"))] + DuplicateKey { key: String }, +} + +pub fn parse_objects(objects: &[String]) -> Result, Error> { + let mut map = HashMap::new(); + for obj_s in objects { + let (key, val) = obj_s.split_once(',').unwrap_or((obj_s, "")); + if !key.starts_with("id_") { + return error::InvalidKey { + key: key.to_owned(), + } + .fail(); + } + if map.insert(key, val).is_some() { + return error::DuplicateKey { + key: key.to_owned(), + } + .fail(); + } + } + Ok(map) +} From 79c7dcde97a0c9cf95bbf474d6060ae2e841c1b0 Mon Sep 17 00:00:00 2001 From: Changyuan Lyu Date: Tue, 29 Apr 2025 22:12:42 -0700 Subject: [PATCH 08/10] feat(cli): add cli interface for vhost-user backends Signed-off-by: Changyuan Lyu --- alioth-cli/src/main.rs | 21 +++-- alioth-cli/src/vu.rs | 154 +++++++++++++++++++++++++++++++ alioth/src/net/net.rs | 2 +- alioth/src/virtio/dev/net/net.rs | 2 +- 4 files changed, 170 insertions(+), 9 deletions(-) create mode 100644 alioth-cli/src/vu.rs diff --git a/alioth-cli/src/main.rs b/alioth-cli/src/main.rs index afc3066b..97540691 100644 --- a/alioth-cli/src/main.rs +++ b/alioth-cli/src/main.rs @@ -14,12 +14,23 @@ mod boot; mod objects; +#[cfg(target_os = "linux")] +mod vu; use std::path::PathBuf; use clap::{Parser, Subcommand}; use flexi_logger::{FileSpec, Logger}; +#[derive(Subcommand, Debug)] +enum Command { + /// Create and boot a 
virtual machine. + Boot(Box), + #[cfg(target_os = "linux")] + /// Start a vhost-user backend device. + Vu(Box), +} + #[derive(Parser, Debug)] #[command(author, version, about)] struct Cli { @@ -41,12 +52,6 @@ struct Cli { pub cmd: Command, } -#[derive(Subcommand, Debug)] -enum Command { - /// Create and boot a virtual machine. - Boot(boot::BootArgs), -} - fn main() -> Result<(), Box> { let cli = Cli::parse(); let logger = if let Some(ref spec) = cli.log_spec { @@ -71,7 +76,9 @@ fn main() -> Result<(), Box> { ); match cli.cmd { - Command::Boot(args) => boot::boot(args)?, + Command::Boot(args) => boot::boot(*args)?, + #[cfg(target_os = "linux")] + Command::Vu(args) => vu::start(*args)?, } Ok(()) } diff --git a/alioth-cli/src/vu.rs b/alioth-cli/src/vu.rs new file mode 100644 index 00000000..88469995 --- /dev/null +++ b/alioth-cli/src/vu.rs @@ -0,0 +1,154 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::marker::PhantomData; +use std::os::unix::net::{UnixListener, UnixStream}; +use std::path::PathBuf; +use std::sync::Arc; +use std::thread::spawn; + +use alioth::errors::{DebugTrace, trace_error}; +use alioth::mem::mapped::RamBus; +use alioth::virtio::dev::blk::BlockParam; +use alioth::virtio::dev::net::NetParam; +use alioth::virtio::dev::{DevParam, Virtio, VirtioDevice}; +use alioth::virtio::vu::backend::{VuBackend, VuEventfd, VuIrqSender}; +use clap::{Args, Subcommand}; +use serde::Deserialize; +use serde_aco::{Help, help_text}; +use snafu::{ResultExt, Snafu}; + +use crate::objects::{DOC_OBJECTS, parse_objects}; + +#[trace_error] +#[derive(Snafu, DebugTrace)] +#[snafu(module, context(suffix(false)))] +pub enum Error { + #[snafu(display("Failed to parse {arg}"))] + ParseArg { + arg: String, + error: serde_aco::Error, + }, + #[snafu(display("Failed to parse objects"), context(false))] + ParseObjects { source: crate::objects::Error }, + #[snafu(display("Failed to bind socket {socket:?}"))] + Bind { + socket: PathBuf, + error: std::io::Error, + }, + #[snafu(display("Failed to create a VirtIO device"))] + CreateVirtio { source: alioth::virtio::Error }, + #[snafu(display("Failed to create a vhost-user backend"))] + CreateVu { + source: alioth::virtio::vu::backend::Error, + }, + #[snafu(display("vhost-user device runtime error"))] + Runtime { + source: alioth::virtio::vu::backend::Error, + }, +} + +fn phantom_parser(_: &str) -> Result, &'static str> { + Ok(PhantomData) +} + +#[derive(Args, Debug, Clone)] +pub struct DevArgs +where + T: Help + Send + Sync + 'static, +{ + #[arg(short, long, value_name("PARAM"), help(help_text::("Specify device parameters.")))] + pub param: String, + + #[arg(short, long("object"), help(DOC_OBJECTS), value_name("OBJECT"))] + pub objects: Vec, + + #[arg(hide(true), value_parser(phantom_parser::), default_value(""))] + pub phantom: PhantomData, +} + +#[derive(Subcommand, Debug, Clone)] +pub enum DevType { + /// VirtIO net device 
backed by TUN/TAP, MacVTap, or IPVTap. + Net(DevArgs), + /// VirtIO block device backed by a file. + Blk(DevArgs), +} + +#[derive(Args, Debug, Clone)] +#[command(arg_required_else_help = true)] +pub struct VuArgs { + /// Path to a Unix domain socket to listen on. + #[arg(short, long, value_name = "PATH")] + pub socket: PathBuf, + + #[command(subcommand)] + pub ty: DevType, +} + +fn create_dev( + name: String, + args: &DevArgs

, + memory: Arc, +) -> Result, Error> +where + D: Virtio, + P: DevParam + Help + for<'a> Deserialize<'a> + Send + Sync + 'static, +{ + let name: Arc = name.into(); + let objects = parse_objects(&args.objects)?; + let param: P = serde_aco::from_args(&args.param, &objects) + .context(error::ParseArg { arg: &args.param })?; + let dev = param.build(name.clone()).context(error::CreateVirtio)?; + let dev = VirtioDevice::new(name, dev, memory, false).context(error::CreateVirtio)?; + Ok(dev) +} + +fn serve_conn(index: u32, conn: UnixStream, args: &VuArgs) -> Result<(), Error> { + let memory = Arc::new(RamBus::new()); + let dev = match &args.ty { + DevType::Net(args) => create_dev(format!("net-{index}"), args, memory.clone()), + DevType::Blk(args) => create_dev(format!("blk-{index}"), args, memory.clone()), + }?; + let mut backend = VuBackend::new(conn, dev, memory).context(error::CreateVu)?; + backend.run().context(error::Runtime) +} + +fn serve(index: u32, conn: UnixStream, args: &VuArgs) { + match serve_conn(index, conn, args) { + Ok(()) => log::info!("Serve {index}: done"), + Err(e) => log::error!("Serve {index}: {e:?}"), + } +} + +pub fn start(args: VuArgs) -> Result<(), Error> { + let listener = UnixListener::bind(&args.socket).context(error::Bind { + socket: &args.socket, + })?; + let args = Arc::new(args); + let mut index = 0; + for stream in listener.incoming() { + match stream { + Ok(conn) => { + let args = args.clone(); + spawn(move || serve(index, conn, &args)); + index = index.wrapping_add(1); + } + Err(e) => { + log::error!("Accept: {e:?}"); + } + } + } + Ok(()) +} diff --git a/alioth/src/net/net.rs b/alioth/src/net/net.rs index 065931e2..cd55a535 100644 --- a/alioth/src/net/net.rs +++ b/alioth/src/net/net.rs @@ -17,7 +17,7 @@ use serde::de::{self, Visitor}; use serde_aco::{Help, TypedHelp}; use zerocopy::{FromBytes, Immutable, IntoBytes}; -#[derive(Debug, Default, FromBytes, Immutable, IntoBytes, PartialEq, Eq)] +#[derive(Debug, Clone, Default, FromBytes, 
Immutable, IntoBytes, PartialEq, Eq)] #[repr(transparent)] pub struct MacAddr([u8; 6]); diff --git a/alioth/src/virtio/dev/net/net.rs b/alioth/src/virtio/dev/net/net.rs index 6acf78fd..61347aea 100644 --- a/alioth/src/virtio/dev/net/net.rs +++ b/alioth/src/virtio/dev/net/net.rs @@ -155,7 +155,7 @@ pub struct Net { api: WorkerApi, } -#[derive(Deserialize, Help)] +#[derive(Deserialize, Debug, Clone, Help)] pub struct NetParam { /// MAC address of the virtual NIC, e.g. 06:3a:76:53:da:3d. pub mac: MacAddr, From 09b40c064c8cceda4f02d9deb32c1c45fb442e75 Mon Sep 17 00:00:00 2001 From: Changyuan Lyu Date: Tue, 29 Apr 2025 22:17:53 -0700 Subject: [PATCH 09/10] feat(cli): add cli interface for vhost-user frontends Signed-off-by: Changyuan Lyu --- alioth-cli/src/boot.rs | 141 ++++++++++++++++++++++++------- alioth-cli/src/vu.rs | 8 +- alioth/src/virtio/dev/blk.rs | 6 +- alioth/src/virtio/dev/net/net.rs | 8 +- 4 files changed, 123 insertions(+), 40 deletions(-) diff --git a/alioth-cli/src/boot.rs b/alioth-cli/src/boot.rs index ca456285..a2448136 100644 --- a/alioth-cli/src/boot.rs +++ b/alioth-cli/src/boot.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::collections::HashMap; #[cfg(target_arch = "x86_64")] use std::ffi::CString; #[cfg(target_arch = "x86_64")] @@ -22,24 +23,28 @@ use alioth::board::BoardConfig; #[cfg(target_arch = "x86_64")] use alioth::device::fw_cfg::FwCfgItemParam; use alioth::errors::{DebugTrace, trace_error}; -use alioth::hv::Coco; #[cfg(target_os = "macos")] use alioth::hv::Hvf; +use alioth::hv::{self, Coco}; #[cfg(target_os = "linux")] use alioth::hv::{Kvm, KvmConfig}; use alioth::loader::{ExecType, Payload}; use alioth::mem::{MemBackend, MemConfig}; #[cfg(target_os = "linux")] use alioth::vfio::{CdevParam, ContainerParam, GroupParam, IoasParam}; +#[cfg(target_os = "linux")] +use alioth::virtio::DeviceId; use alioth::virtio::dev::balloon::BalloonParam; -use alioth::virtio::dev::blk::BlockParam; +use alioth::virtio::dev::blk::BlkFileParam; use alioth::virtio::dev::entropy::EntropyParam; #[cfg(target_os = "linux")] use alioth::virtio::dev::fs::VuFsParam; #[cfg(target_os = "linux")] -use alioth::virtio::dev::net::NetParam; +use alioth::virtio::dev::net::NetTapParam; #[cfg(target_os = "linux")] use alioth::virtio::dev::vsock::VhostVsockParam; +#[cfg(target_os = "linux")] +use alioth::virtio::vu::frontend::VuFrontendParam; use alioth::vm::Machine; use clap::Args; use serde::Deserialize; @@ -121,6 +126,34 @@ enum VsockParam { Vhost(VhostVsockParam), } +#[cfg(target_os = "linux")] +#[derive(Deserialize, Help)] +struct VuSocket { + socket: PathBuf, +} + +#[cfg(target_os = "linux")] +#[derive(Deserialize, Help)] +enum NetParam { + /// VirtIO net device backed by TUN/TAP, MacVTap, or IPVTap. + #[serde(alias = "tap")] + Tap(NetTapParam), + /// vhost-user net device over a Unix domain socket. + #[serde(alias = "vu")] + Vu(VuSocket), +} + +#[derive(Deserialize, Help)] +enum BlkParam { + /// VirtIO block device backed a disk image file. + #[serde(alias = "file")] + File(BlkFileParam), + #[cfg(target_os = "linux")] + #[serde(alias = "vu")] + /// vhost-user block device over a Unix domain socket. 
+ Vu(VuSocket), +} + #[derive(Args, Debug, Clone)] #[command(arg_required_else_help = true, alias("run"))] pub struct BootArgs { @@ -179,12 +212,12 @@ pub struct BootArgs { #[cfg(target_os = "linux")] #[arg(long, help( - help_text::("Add a VirtIO net device backed by TUN/TAP, MacVTap, or IPVTap.") + help_text::("Add a VirtIO net device.") ))] net: Vec, #[arg(long, help( - help_text::("Add a VirtIO block device.") + help_text::("Add a VirtIO block device.") ))] blk: Vec, @@ -233,6 +266,78 @@ pub struct BootArgs { objects: Vec, } +#[cfg(target_os = "linux")] +fn add_net( + vm: &Machine, + args: Vec, + objects: &HashMap<&str, &str>, +) -> Result<(), Error> +where + H: hv::Hypervisor + 'static, +{ + for (index, net_opt) in args.into_iter().enumerate() { + let net_param: NetParam = match serde_aco::from_args(&net_opt, objects) { + Ok(p) => p, + Err(_) => { + let tap_param = serde_aco::from_args::(&net_opt, objects) + .context(error::ParseArg { arg: net_opt })?; + NetParam::Tap(tap_param) + } + }; + match net_param { + NetParam::Tap(tap_param) => vm.add_virtio_dev(format!("virtio-net-{index}"), tap_param), + #[cfg(target_os = "linux")] + NetParam::Vu(sock) => { + let param = VuFrontendParam { + id: DeviceId::Net, + socket: sock.socket, + }; + vm.add_virtio_dev(format!("vu-net-{index}"), param) + } + } + .context(error::CreateDevice)?; + } + Ok(()) +} + +fn add_blk( + vm: &Machine, + args: Vec, + objects: &HashMap<&str, &str>, +) -> Result<(), Error> +where + H: hv::Hypervisor + 'static, +{ + for (index, opt) in args.into_iter().enumerate() { + let param: BlkParam = match serde_aco::from_args(&opt, objects) { + Ok(param) => param, + Err(_) => match serde_aco::from_args(&opt, objects) { + Ok(param) => BlkParam::File(param), + Err(_) => { + eprintln!("Please update the cmd line to --blk file,path={opt}"); + BlkParam::File(BlkFileParam { + path: opt.into(), + ..Default::default() + }) + } + }, + }; + match param { + BlkParam::File(p) => 
vm.add_virtio_dev(format!("virtio-blk-{index}"), p), + #[cfg(target_os = "linux")] + BlkParam::Vu(s) => { + let p = VuFrontendParam { + id: DeviceId::Block, + socket: s.socket, + }; + vm.add_virtio_dev(format!("vu-net-{index}"), p) + } + } + .context(error::CreateDevice)?; + } + Ok(()) +} + pub fn boot(args: BootArgs) -> Result<(), Error> { let objects = parse_objects(&args.objects)?; let hv_config = if let Some(hv_cfg_opt) = args.hypervisor { @@ -322,30 +427,8 @@ pub fn boot(args: BootArgs) -> Result<(), Error> { .context(error::CreateDevice)?; } #[cfg(target_os = "linux")] - for (index, net_opt) in args.net.into_iter().enumerate() { - let net_param: NetParam = - serde_aco::from_args(&net_opt, &objects).context(error::ParseArg { arg: net_opt })?; - vm.add_virtio_dev(format!("virtio-net-{index}"), net_param) - .context(error::CreateDevice)?; - } - for (index, blk) in args.blk.into_iter().enumerate() { - let param = match serde_aco::from_args(&blk, &objects) { - Ok(param) => param, - Err(serde_aco::Error::ExpectedMapEq) => { - eprintln!( - "Please update the cmd line to --blk path={blk}, see https://github.com/google/alioth/pull/72 for details" - ); - BlockParam { - path: blk.into(), - ..Default::default() - } - } - Err(e) => return Err(e).context(error::ParseArg { arg: blk })?, - }; - - vm.add_virtio_dev(format!("virtio-blk-{index}"), param) - .context(error::CreateDevice)?; - } + add_net(&vm, args.net, &objects)?; + add_blk(&vm, args.blk, &objects)?; #[cfg(target_os = "linux")] for (index, fs) in args.fs.into_iter().enumerate() { let param: FsParam = diff --git a/alioth-cli/src/vu.rs b/alioth-cli/src/vu.rs index 88469995..bde9f744 100644 --- a/alioth-cli/src/vu.rs +++ b/alioth-cli/src/vu.rs @@ -20,8 +20,8 @@ use std::thread::spawn; use alioth::errors::{DebugTrace, trace_error}; use alioth::mem::mapped::RamBus; -use alioth::virtio::dev::blk::BlockParam; -use alioth::virtio::dev::net::NetParam; +use alioth::virtio::dev::blk::BlkFileParam; +use 
alioth::virtio::dev::net::NetTapParam; use alioth::virtio::dev::{DevParam, Virtio, VirtioDevice}; use alioth::virtio::vu::backend::{VuBackend, VuEventfd, VuIrqSender}; use clap::{Args, Subcommand}; @@ -81,9 +81,9 @@ where #[derive(Subcommand, Debug, Clone)] pub enum DevType { /// VirtIO net device backed by TUN/TAP, MacVTap, or IPVTap. - Net(DevArgs), + Net(DevArgs), /// VirtIO block device backed by a file. - Blk(DevArgs), + Blk(DevArgs), } #[derive(Args, Debug, Clone)] diff --git a/alioth/src/virtio/dev/blk.rs b/alioth/src/virtio/dev/blk.rs index 26b27315..0825b6ba 100644 --- a/alioth/src/virtio/dev/blk.rs +++ b/alioth/src/virtio/dev/blk.rs @@ -141,7 +141,7 @@ pub struct BlockConfig { impl_mmio_for_zerocopy!(BlockConfig); #[derive(Debug, Clone, Deserialize, Help, Default)] -pub struct BlockParam { +pub struct BlkFileParam { /// Path to a raw-formatted disk image. pub path: PathBuf, /// Set the device as readonly. [default: false] @@ -152,7 +152,7 @@ pub struct BlockParam { pub api: WorkerApi, } -impl DevParam for BlockParam { +impl DevParam for BlkFileParam { type Device = Block; fn build(self, name: impl Into>) -> Result { @@ -189,7 +189,7 @@ pub struct Block { } impl Block { - pub fn new(param: BlockParam, name: impl Into>) -> Result { + pub fn new(param: BlkFileParam, name: impl Into>) -> Result { let access_disk = error::AccessFile { path: param.path.as_path(), }; diff --git a/alioth/src/virtio/dev/net/net.rs b/alioth/src/virtio/dev/net/net.rs index 61347aea..21d75e51 100644 --- a/alioth/src/virtio/dev/net/net.rs +++ b/alioth/src/virtio/dev/net/net.rs @@ -155,8 +155,8 @@ pub struct Net { api: WorkerApi, } -#[derive(Deserialize, Debug, Clone, Help)] -pub struct NetParam { +#[derive(Debug, Deserialize, Clone, Help)] +pub struct NetTapParam { /// MAC address of the virtual NIC, e.g. 06:3a:76:53:da:3d. pub mac: MacAddr, /// Maximum transmission unit. 
@@ -179,7 +179,7 @@ pub struct NetParam { pub api: WorkerApi, } -impl DevParam for NetParam { +impl DevParam for NetTapParam { type Device = Net; fn build(self, name: impl Into>) -> Result { @@ -199,7 +199,7 @@ fn new_socket(dev_tap: Option<&Path>, blocking: bool) -> Result { } impl Net { - pub fn new(param: NetParam, name: impl Into>) -> Result { + pub fn new(param: NetTapParam, name: impl Into>) -> Result { let mut socket = new_socket( param.tap.as_deref(), matches!(param.api, WorkerApi::IoUring), From 94d0725fe3fa86061a811d3e3bba4d66347e2d78 Mon Sep 17 00:00:00 2001 From: Changyuan Lyu Date: Sat, 3 May 2025 20:19:10 -0700 Subject: [PATCH 10/10] fix(vu): transfer device config with variable length Signed-off-by: Changyuan Lyu --- alioth/src/virtio/dev/fs.rs | 13 +++-- alioth/src/virtio/vu/backend.rs | 24 ++++----- alioth/src/virtio/vu/bindings.rs | 3 +- alioth/src/virtio/vu/conn.rs | 90 +++++++++++++++++++++++++++----- alioth/src/virtio/vu/frontend.rs | 17 +++--- 5 files changed, 107 insertions(+), 40 deletions(-) diff --git a/alioth/src/virtio/dev/fs.rs b/alioth/src/virtio/dev/fs.rs index dce140d8..3661b74c 100644 --- a/alioth/src/virtio/dev/fs.rs +++ b/alioth/src/virtio/dev/fs.rs @@ -101,10 +101,15 @@ impl VuFs { } config } else { - let mut empty_cfg = DeviceConfig::new_zeroed(); - empty_cfg.size = size_of_val(&empty_cfg.region) as _; - let dev_config = frontend.session().get_config(&empty_cfg)?; - FsConfig::read_from_prefix(&dev_config.region).unwrap().0 + let cfg = DeviceConfig { + offset: 0, + size: size_of::() as u32, + flags: 0, + }; + let mut config = FsConfig::new_zeroed(); + frontend.session().get_config(&cfg, config.as_mut_bytes())?; + log::info!("{}: get config: {config:?}", frontend.name()); + config }; let mut dax_region = None; diff --git a/alioth/src/virtio/vu/backend.rs b/alioth/src/virtio/vu/backend.rs index 9b922900..e833de1d 100644 --- a/alioth/src/virtio/vu/backend.rs +++ b/alioth/src/virtio/vu/backend.rs @@ -23,7 +23,7 @@ use 
std::sync::atomic::Ordering; use macros::trace_error; use snafu::Snafu; -use zerocopy::{FromZeros, IntoBytes}; +use zerocopy::IntoBytes; use crate::errors::DebugTrace; use crate::hv::IoeventFd; @@ -31,7 +31,7 @@ use crate::mem::mapped::{ArcMemPages, RamBus}; use crate::virtio::dev::{StartParam, VirtioDevice, WakeEvent}; use crate::virtio::vu::Error as VuError; use crate::virtio::vu::bindings::{ - DeviceConfig, MemoryRegion, MemorySingleRegion, Message, VirtqAddr, VirtqState, VuFeature, + MAX_CONFIG_SIZE, MemoryRegion, MemorySingleRegion, Message, VirtqAddr, VirtqState, VuFeature, VuFrontMsg, }; use crate::virtio::vu::conn::{VuChannel, VuSession}; @@ -446,33 +446,31 @@ impl VuBackend { self.wake_up_dev(event); } } - (VuFrontMsg::GET_CONFIG, 268) => { - let dev_config: DeviceConfig = self.session.recv_payload()?; + (VuFrontMsg::GET_CONFIG, 12..) => { + let mut region = [0u8; MAX_CONFIG_SIZE]; + let dev_config = self.session.recv_config(&mut region)?; let mut done = 0; - let mut resp = DeviceConfig::new_zeroed(); while let Some(n) = (dev_config.size as usize - done).checked_ilog2() { let size = min(1 << n, 8) as u8; let offset = dev_config.offset as u64 + done as u64; let v = self.dev.device_config.read(offset, size)?; - resp.region[done..(done + size as usize)] + region[done..(done + size as usize)] .copy_from_slice(&v.as_bytes()[..size as usize]); done += size as usize; } - resp.offset = dev_config.offset; - resp.size = dev_config.size; - resp.flags = dev_config.flags; - self.session.reply(req, &resp, &[])?; + self.session.reply_config(&dev_config, &region[..done])?; log::debug!("{name}: get config: {dev_config:?}"); msg.flag.set_need_reply(false); } - (VuFrontMsg::SET_CONFIG, 268) => { - let dev_config: DeviceConfig = self.session.recv_payload()?; + (VuFrontMsg::SET_CONFIG, 12..)
=> { + let mut region = [0u8; MAX_CONFIG_SIZE]; + let dev_config = self.session.recv_config(&mut region)?; let mut done = 0; while let Some(n) = (dev_config.size as usize - done).checked_ilog2() { let size = min(1 << n, 8) as u8; let mut v = 0; v.as_mut_bytes()[..size as usize] - .copy_from_slice(&dev_config.region[done..(done + size as usize)]); + .copy_from_slice(&region[done..(done + size as usize)]); let offset = dev_config.offset as u64 + done as u64; self.dev.device_config.write(offset, size, v)?; done += size as usize; diff --git a/alioth/src/virtio/vu/bindings.rs b/alioth/src/virtio/vu/bindings.rs index e603b920..3cd7ef87 100644 --- a/alioth/src/virtio/vu/bindings.rs +++ b/alioth/src/virtio/vu/bindings.rs @@ -180,13 +180,14 @@ pub struct MemoryMultipleRegion { pub regions: [MemoryRegion; 8], } +pub const MAX_CONFIG_SIZE: usize = 256; + #[derive(Debug, IntoBytes, FromBytes, Immutable, KnownLayout)] #[repr(C)] pub struct DeviceConfig { pub offset: u32, pub size: u32, pub flags: u32, - pub region: [u8; 256], } #[derive(Debug, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)] diff --git a/alioth/src/virtio/vu/conn.rs b/alioth/src/virtio/vu/conn.rs index a9399830..ca65ad5d 100644 --- a/alioth/src/virtio/vu/conn.rs +++ b/alioth/src/virtio/vu/conn.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License.
-use std::io::{IoSlice, IoSliceMut, Read}; +use std::io::{IoSlice, IoSliceMut, Read, Write}; use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd}; use std::os::unix::net::UnixStream; use std::path::Path; @@ -23,12 +23,19 @@ use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes}; use crate::ffi; use crate::utils::uds::{recv_msg_with_fds, send_msg_with_fds}; use crate::virtio::vu::bindings::{ - DeviceConfig, MemorySingleRegion, Message, MessageFlag, VirtqAddr, VirtqState, VuBackMsg, - VuFrontMsg, + DeviceConfig, MAX_CONFIG_SIZE, MemorySingleRegion, Message, MessageFlag, VirtqAddr, VirtqState, + VuBackMsg, VuFrontMsg, }; use crate::virtio::vu::{Result, error}; -fn send(mut conn: &UnixStream, req: u32, payload: &T, fds: &[BorrowedFd]) -> Result +fn send( + mut conn: &UnixStream, + req: u32, + payload: &T, + in_: &[u8], + out: &mut [u8], + fds: &[BorrowedFd], +) -> Result where T: IntoBytes + Immutable, R: FromBytes + IntoBytes, @@ -36,11 +43,12 @@ where let vhost_msg = Message { request: req, flag: MessageFlag::sender(), - size: size_of::() as u32, + size: (size_of::() + in_.len()) as u32, }; let bufs = [ IoSlice::new(vhost_msg.as_bytes()), IoSlice::new(payload.as_bytes()), + IoSlice::new(in_), ]; let done = send_msg_with_fds(conn, &bufs, fds)?; let want = size_of_val(&vhost_msg) + vhost_msg.size as usize; @@ -58,9 +66,10 @@ where } else { IoSliceMut::new(ret_code.as_mut_bytes()) }, + IoSliceMut::new(out), ]; - let resp_size = bufs[1].len() as u32; - let expect_size = size_of::() + bufs[1].len(); + let resp_size = bufs[1].len() + bufs[2].len(); + let expect_size = size_of::() + resp_size; let size = conn.read_vectored(&mut bufs)?; if size != expect_size { @@ -77,9 +86,9 @@ where } .fail(); } - if resp.size != resp_size { + if resp.size as usize != resp_size { return error::PayloadSize { - want: size_of::(), + want: resp_size, got: resp.size, } .fail(); @@ -112,6 +121,26 @@ where Ok(()) } +fn reply_config(mut conn: &UnixStream, config: &DeviceConfig, buf: &[u8]) 
-> Result<()> { + let msg = Message { + request: VuFrontMsg::GET_CONFIG.raw(), + flag: MessageFlag::receiver(), + size: (size_of_val(config) + buf.len()) as _, + }; + let bufs = [ + IoSlice::new(msg.as_bytes()), + IoSlice::new(config.as_bytes()), + IoSlice::new(buf), + ]; + let done = conn.write_vectored(&bufs)?; + let want = size_of_val(&msg) + msg.size as usize; + if done != want { + return error::PartialWrite { want, done }.fail(); + } + + Ok(()) +} + fn recv_with_fds(conn: &UnixStream, fds: &mut [Option]) -> Result where T: IntoBytes + Immutable + FromBytes, @@ -130,6 +159,24 @@ where } } +fn recv_config(mut conn: &UnixStream, buf: &mut [u8]) -> Result { + let mut dev_config = DeviceConfig::new_zeroed(); + let mut bufs = [ + IoSliceMut::new(dev_config.as_mut_bytes()), + IoSliceMut::new(buf), + ]; + let got = conn.read_vectored(&mut bufs)?; + let want = size_of::() + dev_config.size as usize; + if got != want { + return error::PayloadSize { + want, + got: got as u32, + } + .fail(); + } + Ok(dev_config) +} + #[derive(Debug)] pub struct VuSession { pub conn: UnixStream, @@ -148,7 +195,7 @@ impl VuSession { T: IntoBytes + Immutable, R: FromBytes + IntoBytes, { - send(&self.conn, req.raw(), payload, fds) + send(&self.conn, req.raw(), payload, &[], &mut [], fds) } pub fn recv_payload(&self) -> Result @@ -158,6 +205,10 @@ impl VuSession { recv_with_fds(&self.conn, &mut []) } + pub fn recv_config(&self, buf: &mut [u8]) -> Result { + recv_config(&self.conn, buf) + } + pub fn recv_msg(&self, fds: &mut [Option]) -> Result { recv_with_fds(&self.conn, fds) } @@ -171,6 +222,10 @@ impl VuSession { reply(&self.conn, req.raw(), payload, fds) } + pub fn reply_config(&self, config: &DeviceConfig, buf: &[u8]) -> Result<()> { + reply_config(&self.conn, config, buf) + } + pub fn get_features(&self) -> Result { self.send(VuFrontMsg::GET_FEATURES, &(), &[]) } @@ -203,12 +258,19 @@ impl VuSession { self.send(VuFrontMsg::SET_VIRTQ_BASE, payload, &[]) } - pub fn get_config(&self, 
payload: &DeviceConfig) -> Result { - self.send(VuFrontMsg::GET_CONFIG, payload, &[]) + pub fn get_config(&self, payload: &DeviceConfig, buf: &mut [u8]) -> Result { + let in_ = [0; MAX_CONFIG_SIZE]; + let len = buf.len(); + let req = VuFrontMsg::GET_CONFIG.raw(); + send(&self.conn, req, payload, &in_[..len], buf, &[]) } - pub fn set_config(&self, payload: &DeviceConfig) -> Result<()> { - self.send(VuFrontMsg::SET_CONFIG, payload, &[]) + pub fn set_config(&self, payload: &DeviceConfig, buf: &[u8]) -> Result<()> + where + DeviceConfig: IntoBytes, + { + let req = VuFrontMsg::SET_CONFIG.raw(); + send(&self.conn, req, payload, buf, &mut [], &[]) } pub fn get_virtq_base(&self, payload: &VirtqState) -> Result { diff --git a/alioth/src/virtio/vu/frontend.rs b/alioth/src/virtio/vu/frontend.rs index fc815df6..69663113 100644 --- a/alioth/src/virtio/vu/frontend.rs +++ b/alioth/src/virtio/vu/frontend.rs @@ -108,24 +108,25 @@ impl Mmio for VuDevConfig { offset: offset as u32, size: size as u32, flags: 0, - region: [0u8; 256], }; - let resp = self.session.get_config(&req).unwrap(); let mut ret = 0u64; - ret.as_mut_bytes().copy_from_slice(&resp.region[0..8]); - ret &= u64::MAX >> (64 - (size << 3)); + let buf = &mut ret.as_mut_bytes()[..size as usize]; + self.session + .get_config(&req, buf) + .box_trace(mem::error::Mmio)?; Ok(ret) } fn write(&self, offset: u64, size: u8, val: u64) -> mem::Result { - let mut req = DeviceConfig { + let req = DeviceConfig { offset: offset as u32, size: size as u32, flags: 0, - region: [0u8; 256], }; - req.region[0..8].copy_from_slice(val.as_bytes()); - self.session.set_config(&req).unwrap(); + let buf = &val.as_bytes()[..size as usize]; + self.session + .set_config(&req, buf) + .box_trace(mem::error::Mmio)?; Ok(Action::None) } }