From 47d1167369ec977e771bdbf2d4bdafb49c0ff58d Mon Sep 17 00:00:00 2001 From: Mattias Nissler Date: Tue, 20 May 2025 07:14:54 -0700 Subject: [PATCH 01/13] Enhanced Allocation capability support This adds minimal support for the Enhanced Allocation capability. We only support the 64-bit entry format and ignore emulated reads/writes. --- drivers/src/pci/capabilities.rs | 208 +++++++++++++++++++++++++++++++- drivers/src/pci/registers.rs | 51 ++++++++ 2 files changed, 256 insertions(+), 3 deletions(-) diff --git a/drivers/src/pci/capabilities.rs b/drivers/src/pci/capabilities.rs index d84b5c85..e8f8d59f 100644 --- a/drivers/src/pci/capabilities.rs +++ b/drivers/src/pci/capabilities.rs @@ -7,11 +7,13 @@ use core::mem::size_of; use enum_dispatch::enum_dispatch; use memoffset::offset_of; use tock_registers::interfaces::{Readable, Writeable}; +use tock_registers::registers::ReadOnly; use tock_registers::LocalRegisterCopy; use super::error::*; use super::mmio_builder::{MmioReadBuilder, MmioWriteBuilder}; use super::registers::*; +use super::resource::*; // Standard PCI capability IDs. #[repr(u8)] @@ -23,6 +25,7 @@ enum CapabilityId { BridgeSubsystem = 13, PciExpress = 16, MsiX = 17, + EnhancedAllocation = 20, } impl CapabilityId { @@ -36,6 +39,7 @@ impl CapabilityId { 13 => Some(BridgeSubsystem), 16 => Some(PciExpress), 17 => Some(MsiX), + 20 => Some(EnhancedAllocation), _ => None, } } @@ -114,6 +118,7 @@ enum CapabilityType { Vendor, BridgeSubsystem, PciExpress, + EnhancedAllocation, } // Common functionality required by all capabilities. @@ -437,6 +442,149 @@ impl Capability for BridgeSubsystem { } } +// The maximum number of extend allocation entries. The capability's num_entries field allows up to +// 63 entries, but we have no reason to support that many. 
+const MAX_ENHANCED_ALLOCATION_ENTRIES: usize = 8; + +pub struct EnhancedAllocationEntry { + fixed: &'static mut EnhancedAllocationEntryFixed, + base_high: Option<&'static mut ReadOnly>, + max_offset_high: Option<&'static mut ReadOnly>, +} + +impl EnhancedAllocationEntry { + pub fn enabled(&self) -> bool { + self.fixed + .header + .read(EnhancedAllocationEntryHeader::Enable) + != 0 + } + + pub fn bar_equivalent_indicator(&self) -> usize { + self.fixed + .header + .read(EnhancedAllocationEntryHeader::BarEquivalentIndicator) as usize + } + + fn properties(&self) -> Option { + use EnhancedAllocationEntryHeader::*; + self.fixed + .header + .read_as_enum(PrimaryProperties) + .or_else(|| self.fixed.header.read_as_enum(SecondaryProperties)) + } + + pub fn resource_type(&self) -> Option { + use EnhancedAllocationEntryHeader::PrimaryProperties::Value::*; + Some(match self.properties()? { + Mem => PciResourceType::Mem64, + PrefetchableMem => PciResourceType::PrefetchableMem64, + IoPort => PciResourceType::IoPort, + _ => return None, + }) + } + + pub fn base(&self) -> u64 { + use EnhancedAllocationEntryAddress::Address; + let low = self.fixed.base.read(Address) << Address.shift; + let high = self.base_high.as_ref().map(|r| r.get()).unwrap_or(0); + low as u64 | ((high as u64) << 32) + } + + pub fn max_offset(&self) -> u64 { + use EnhancedAllocationEntryAddress::Address; + let low = self.fixed.max_offset.read(Address) << Address.shift; + // Fill in the low-order bits. 
+ let low = low | !(Address.mask << Address.shift); + let high = self.max_offset_high.as_ref().map(|r| r.get()).unwrap_or(0); + low as u64 | ((high as u64) << 32) + } +} + +pub struct EnhancedAllocation { + pub entries: ArrayVec, + length: usize, +} + +impl EnhancedAllocation { + fn new(config_regs: &mut CommonRegisters, header: &mut CapabilityHeader) -> Self { + // Safety: `header` points to a valid and uniquely-owned capability structure and we are + // trusting that the hardware reported the type of the capability correctly. + let header = unsafe { + (header as *mut CapabilityHeader as *mut EnhancedAllocationHeader) + .as_mut() + .unwrap() + }; + + let mut length = size_of::(); + if config_regs.header_type.read(Type::Layout) == 1 { + length += size_of::(); + } + + let num_entries = header + .num_entries + .read(EnhancedAllocationNumEntries::NumEntries) as usize; + let mut entries = ArrayVec::new(); + for _ in 0..core::cmp::min(num_entries, MAX_ENHANCED_ALLOCATION_ENTRIES) { + // Safety: `header` points to a valid capability structure and we are trusting that the + // hardware-supplied num_entries field doesn't move us outside the ECAM memory region. + let fixed = unsafe { + ((header as *mut EnhancedAllocationHeader).byte_add(length) + as *mut EnhancedAllocationEntryFixed) + .as_mut() + .unwrap() + }; + + let num_dw = fixed.header.read(EnhancedAllocationEntryHeader::Size) as usize + 1; + let size = num_dw * size_of::(); + length += size; + + if size < size_of::() { + // entry too short, ignore. + continue; + } + + // Compute references to optional registers following the fixed part. + let ptr = fixed as *mut EnhancedAllocationEntryFixed; + let mut regs = (size_of::()..size) + .step_by(size_of::()) + // Safety: the memory location is within the entry of a valid capability structure, + // and we trust hardware with ECAM layout. 
+ .map(|pos| unsafe { (ptr.byte_add(pos) as *mut ReadOnly).as_mut().unwrap() }); + + use EnhancedAllocationEntryAddress::Size; + let base_high = (fixed.base.read(Size) == 1).then(|| regs.next()).flatten(); + let max_offset_high = (fixed.max_offset.read(Size) == 1) + .then(|| regs.next()) + .flatten(); + + entries.push(EnhancedAllocationEntry { + fixed, + base_high, + max_offset_high, + }); + } + + Self { entries, length } + } +} + +impl Capability for EnhancedAllocation { + fn length(&self) -> usize { + self.length + } + + fn emulate_read(&self, op: &mut MmioReadBuilder, _cap_offset: usize) { + // TODO: Implement reading if necessary. + op.push_byte(0); + } + + fn emulate_write(&mut self, op: &mut MmioWriteBuilder, _cap_offset: usize) { + // TODO: support entry enable bit writes if necessary. + op.pop_byte(); + } +} + // The possible PCI express device types as reported in the FLAGS register of the PCI express // capabilities register. #[repr(u8)] @@ -592,7 +740,12 @@ struct PciCapability { impl PciCapability { // Creates a new capability of type `id` at `header`, which itself is at `offset` within the // configuration space. - fn new(header: &mut CapabilityHeader, id: CapabilityId, offset: usize) -> Result { + fn new( + config_regs: &mut CommonRegisters, + header: &mut CapabilityHeader, + id: CapabilityId, + offset: usize, + ) -> Result { let cap_type = match id { CapabilityId::PowerManagement => PowerManagement::new(header).into(), CapabilityId::Msi => Msi::new(header)?.into(), @@ -600,6 +753,7 @@ impl PciCapability { CapabilityId::Vendor => Vendor::new(header)?.into(), CapabilityId::BridgeSubsystem => BridgeSubsystem::new(header).into(), CapabilityId::PciExpress => PciExpress::new(header)?.into(), + CapabilityId::EnhancedAllocation => EnhancedAllocation::new(config_regs, header).into(), }; Ok(PciCapability { id, @@ -694,7 +848,7 @@ impl PciCapabilities { // that the hardware provides a valid config space. 
current_offset = (header.next_cap.get() as usize) & !0x3; if let Some(id) = CapabilityId::from_raw(header.cap_id.get()) { - let cap = PciCapability::new(header, id, offset)?; + let cap = PciCapability::new(config_regs, header, id, offset)?; caps.try_push(cap).map_err(|_| Error::TooManyCapabilities)?; }; } @@ -746,6 +900,15 @@ impl PciCapabilities { self.caps.first().map(|cap| cap.offset()).unwrap_or(0) } + /// Returns the enhanced allocation capability, if present. + pub fn enhanced_allocation(&self) -> Option<&EnhancedAllocation> { + self.capability_by_id(CapabilityId::EnhancedAllocation) + .and_then(|c| match c.cap_type { + CapabilityType::EnhancedAllocation(ref cap) => Some(cap), + _ => None, + }) + } + // Returns a reference to the capability at `offset`. fn capability_by_offset(&self, offset: usize) -> Option<&PciCapability> { self.caps @@ -782,7 +945,24 @@ mod tests { test_config[20] = 0x0000_0004; test_config[21] = 0xaaaa_5c03; // VPD (don't care) test_config[22] = 0xbbbb_cccc; - test_config[23] = 0x0004_0009; // Vendor + test_config[23] = 0x0004_6009; // Vendor + test_config[24] = 0x0004_0014; // Enhanced Allocation + test_config[25] = 0x8000_0002; // - entry 0 + test_config[26] = 0x1000_0000; // - base + test_config[27] = 0x0000_0ffc; // - max_offset + test_config[28] = 0x8000_0003; // - entry 1 + test_config[29] = 0x2000_0002; // - base + test_config[30] = 0x0000_fffc; // - max_offset + test_config[31] = 0x0000_0002; // - base high + test_config[32] = 0x8000_0003; // - entry 2 + test_config[33] = 0x3000_0000; // - base + test_config[34] = 0xffff_fffe; // - max_offset + test_config[35] = 0x0000_000f; // - max_offset high + test_config[36] = 0x8001_1724; // - entry 3 + test_config[37] = 0x4000_0002; // - base + test_config[38] = 0xffff_fffe; // - max_offset + test_config[39] = 0x0000_0004; // - base high + test_config[40] = 0x0000_00ff; // - max_offset high let mut header_mem: Vec = test_config .iter() .map(|v| v.to_le_bytes()) @@ -806,5 +986,27 @@ mod 
tests { .offset(), 0x5c ); + + let ea_cap = caps.enhanced_allocation().unwrap(); + assert_eq!(ea_cap.entries.len(), 4); + assert!(ea_cap.entries[0].enabled()); + assert_eq!(ea_cap.entries[0].bar_equivalent_indicator(), 0); + assert_eq!( + ea_cap.entries[0].resource_type(), + Some(PciResourceType::Mem64) + ); + assert_eq!(ea_cap.entries[0].base(), 0x1000_0000); + assert_eq!(ea_cap.entries[0].max_offset(), 0xfff); + assert_eq!(ea_cap.entries[1].base(), 0x2_2000_0000); + assert_eq!(ea_cap.entries[1].max_offset(), 0xffff); + assert_eq!(ea_cap.entries[2].base(), 0x3000_0000); + assert_eq!(ea_cap.entries[2].max_offset(), 0xf_ffff_ffff); + assert_eq!(ea_cap.entries[3].bar_equivalent_indicator(), 2); + assert_eq!( + ea_cap.entries[3].resource_type(), + Some(PciResourceType::PrefetchableMem64) + ); + assert_eq!(ea_cap.entries[3].base(), 0x4_4000_0000); + assert_eq!(ea_cap.entries[3].max_offset(), 0xff_ffff_ffff); } } diff --git a/drivers/src/pci/registers.rs b/drivers/src/pci/registers.rs index 46cd31e6..74e4600d 100644 --- a/drivers/src/pci/registers.rs +++ b/drivers/src/pci/registers.rs @@ -111,6 +111,10 @@ register_bitfields![u16, pub MemWindow [ Address OFFSET(4) NUMBITS(12) [], ], + + pub EnhancedAllocationNumEntries [ + NumEntries OFFSET(0) NUMBITS(6), + ], ]; register_bitfields![u8, @@ -157,6 +161,34 @@ register_bitfields![u32, MaxLinkWidth OFFSET(4) NUMBITS(6), PortNumber OFFSET(24) NUMBITS(8), ], + + pub EnhancedAllocationEntryHeader [ + Size OFFSET(0) NUMBITS(3) [], + BarEquivalentIndicator OFFSET(4) NUMBITS(4) [], + PrimaryProperties OFFSET(8) NUMBITS(8) + [ + Mem = 0, + PrefetchableMem = 1, + IoPort = 2, + VirtualFunctionPrefetchableMem = 3, + VirtualFunctionMem = 4, + BridgeMem = 5, + BridgePrefetchableMem = 6, + BridgeIoPort = 7, + UnavailableMem = 0xfd, + UnavailableIoPort = 0xfe, + Unavailable = 0xff, + ], + // Actually uses same enum constants as PrimaryProperties. 
+ SecondaryProperties OFFSET(16) NUMBITS(8) [], + Writable OFFSET(30) NUMBITS(1) [], + Enable OFFSET(31) NUMBITS(1) [], + ], + + pub EnhancedAllocationEntryAddress [ + Size OFFSET(1) NUMBITS(1), + Address OFFSET(2) NUMBITS(30), + ], ]; /// Common portion of the PCI configuration header. @@ -346,6 +378,25 @@ pub struct ExpressRegisters { pub slot_status2: ReadOnly, } +/// Enhanced allocation capability. This is followed by a variable number of entries. +#[repr(C)] +#[derive(FieldOffsets)] +pub struct EnhancedAllocationHeader { + pub header: CapabilityHeader, + pub num_entries: ReadOnly, +} + +/// Enhanced allocation entry. This only covers the fixed layout part, the registers holding the +/// upper halves of base and max_offset are optional and at variable offsets depending on the Size +/// field in the respective lower half register. +#[repr(C)] +#[derive(FieldOffsets)] +pub struct EnhancedAllocationEntryFixed { + pub header: ReadOnly, + pub base: ReadOnly, + pub max_offset: ReadOnly, +} + /// Trait for specifying various mask values for a register. /// /// TODO: Make the `*_mask()` functions const values. From 64e8c99f31e8d1242caf1e9def541eb6c486a902 Mon Sep 17 00:00:00 2001 From: Mattias Nissler Date: Thu, 22 May 2025 00:06:39 -0700 Subject: [PATCH 02/13] Consult Enhanced Allocation when discovering BARs In addition to the BARs in the PCI header, we now also reflect BARs specified in Enhanced Allocation capabilities into `PciDeviceBarInfo`. 
--- drivers/src/pci/device.rs | 83 +++++++++++++++++++++++++++++++++------ drivers/src/pci/error.rs | 2 + drivers/src/pci/root.rs | 36 ++++++++++++++--- 3 files changed, 105 insertions(+), 16 deletions(-) diff --git a/drivers/src/pci/device.rs b/drivers/src/pci/device.rs index 72bbc395..cd3c2d86 100644 --- a/drivers/src/pci/device.rs +++ b/drivers/src/pci/device.rs @@ -225,10 +225,18 @@ mod bridge_offsets { define_field_span!(BridgeRegisters, bridge_control, u16); } +/// Indicates whether a BAR was found in the PCI header or via the Enhanced Allocation capability. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BarProvenance { + PciHeader, + EnhancedAllocation, +} + /// Describes a single PCI BAR. #[derive(Clone, Debug)] pub struct PciBarInfo { index: usize, + provenance: BarProvenance, bar_type: PciResourceType, size: u64, } @@ -239,6 +247,11 @@ impl PciBarInfo { self.index } + /// Returns the provenance of this BAR. + pub fn provenance(&self) -> BarProvenance { + self.provenance + } + /// Returns the type of resource this BAR maps. pub fn bar_type(&self) -> PciResourceType { self.bar_type @@ -258,8 +271,34 @@ pub struct PciDeviceBarInfo { impl PciDeviceBarInfo { // Probes the size and type of each BAR from `registers`. - fn new(registers: &mut [ReadWrite]) -> Result { + fn new( + registers: &mut [ReadWrite], + enhanced_allocation: Option<&EnhancedAllocation>, + ) -> Result { let mut bars = ArrayVec::new(); + + // Check the enhanced allocation capability first to give regions discovered via enhanced + // allocation precedence. Per the specification, traditional BARs must be hardware-wired to + // zero if there is a corresponding enhanced allocation entry. 
+ if let Some(enhanced_allocation) = enhanced_allocation { + for entry in enhanced_allocation.entries.iter() { + let bei = entry.bar_equivalent_indicator(); + if bei > PCI_ENDPOINT_BARS || !entry.enabled() { + continue; + } + + if let Some(bar_type) = entry.resource_type() { + let bar = PciBarInfo { + index: bei, + provenance: BarProvenance::EnhancedAllocation, + bar_type, + size: entry.max_offset() + 1, + }; + bars.push(bar); + } + } + } + let mut index = 0; while index < registers.len() { let bar_index = index; @@ -293,8 +332,14 @@ impl PciDeviceBarInfo { return Err(Error::InvalidBarSize(size)); } + // Ignore BARs that conflict with enhanced allocation entries. + if bars.iter().any(|b| b.index == bar_index) { + continue; + } + let bar = PciBarInfo { index: bar_index, + provenance: BarProvenance::PciHeader, bar_type, size, }; @@ -319,6 +364,7 @@ impl PciDeviceBarInfo { // If `index` is the upper half of a 64-bit BAR, return the type of the lower half. self.bars .iter() + .filter(|b| b.provenance() == BarProvenance::PciHeader) .find(|b| b.index() == index || (b.index() + 1 == index && b.bar_type().is_64bit())) .map(|b| b.bar_type()) } @@ -344,7 +390,8 @@ impl PciEndpoint { fn new(registers: &'static mut EndpointRegisters, info: PciDeviceInfo) -> Result { let capabilities = PciCapabilities::new(&mut registers.common, registers.cap_ptr.get() as usize)?; - let bar_info = PciDeviceBarInfo::new(&mut registers.bar)?; + let bar_info = + PciDeviceBarInfo::new(&mut registers.bar, capabilities.enhanced_allocation())?; let common = PciDeviceCommon { info, capabilities, @@ -440,7 +487,8 @@ impl PciBridge { }; let capabilities = PciCapabilities::new(&mut registers.common, registers.cap_ptr.get() as usize)?; - let bar_info = PciDeviceBarInfo::new(&mut registers.bar)?; + let bar_info = + PciDeviceBarInfo::new(&mut registers.bar, capabilities.enhanced_allocation())?; let common = PciDeviceCommon { info, capabilities, @@ -976,14 +1024,24 @@ impl PciDevice { .bar_info() .get(index) 
.ok_or(Error::BarNotPresent(index))?; - let regs = self.bar_registers(); - let addr_lo = regs[index].get() & !((1u32 << BaseAddress::Address.shift) - 1); - let addr_hi = if bar.bar_type().is_64bit() { - regs[index + 1].get() - } else { - 0 - }; - Ok((addr_lo as u64) | ((addr_hi as u64) << 32)) + match bar.provenance { + BarProvenance::PciHeader => { + let regs = self.bar_registers(); + let addr_lo = regs[index].get() & !((1u32 << BaseAddress::Address.shift) - 1); + let addr_hi = if bar.bar_type().is_64bit() { + regs[index + 1].get() + } else { + 0 + }; + Ok((addr_lo as u64) | ((addr_hi as u64) << 32)) + } + BarProvenance::EnhancedAllocation => { + // OK to unwrap the enhanced allocation capability and its entry because the + // PciBarInfo existence implies their presence. + let cap = self.common().capabilities.enhanced_allocation().unwrap(); + Ok(cap.entries.get(index).unwrap().base()) + } + } } /// Programs the BAR at `bar_index` with the given address. @@ -992,6 +1050,9 @@ impl PciDevice { .bar_info() .get(index) .ok_or(Error::BarNotPresent(index))?; + if bar.provenance != BarProvenance::PciHeader { + return Err(Error::BarIsFixed(index)); + } let regs = self.bar_registers(); regs[index].set(pci_addr as u32); if bar.bar_type().is_64bit() { diff --git a/drivers/src/pci/error.rs b/drivers/src/pci/error.rs index fb1f1db3..9f4784a2 100644 --- a/drivers/src/pci/error.rs +++ b/drivers/src/pci/error.rs @@ -79,6 +79,8 @@ pub enum Error { InvalidBarAddress(u64), /// The device does not have a BAR at the specified index. BarNotPresent(usize), + /// The BAR can't be programmed because the address is fixed. + BarIsFixed(usize), /// A VM has programmed a BAR or bridge window to cover a page it does not own. UnownedBarPage(SupervisorPageAddr), /// The PCI device is already owned. 
diff --git a/drivers/src/pci/root.rs b/drivers/src/pci/root.rs index 8237cb64..60ce3998 100644 --- a/drivers/src/pci/root.rs +++ b/drivers/src/pci/root.rs @@ -305,11 +305,37 @@ impl PcieRoot { // Now assign BAR resources. let bar_info = dev.bar_info().clone(); for bar in bar_info.bars() { - let range = self.alloc_hypervisor_resource(bar.bar_type(), bar.size())?; - // `range.base()` must be within a PCI root window. - let base = self.physical_to_pci_addr(range.base().into()).unwrap(); - // BAR index must be valid. - dev.set_bar_addr(bar.index(), base).unwrap(); + match bar.provenance() { + BarProvenance::PciHeader => { + let range = self.alloc_hypervisor_resource(bar.bar_type(), bar.size())?; + // `range.base()` must be within a PCI root window. + let base = self.physical_to_pci_addr(range.base().into()).unwrap(); + // BAR index must be valid. + dev.set_bar_addr(bar.index(), base).unwrap(); + } + BarProvenance::EnhancedAllocation => { + // TODO: Because Enhanced Allocation regions use fixed addresses they require + // special handling, which the current resource management implementation does + // not meet: + // 1. When allocating resources for traditional BARS, we need to make sure + // these don't overlap with Enhanced Allocation regions. + // 2. The currently implemented simplistic allocation strategy that splits the + // MMIO space to steal some address space for the hypervisor at the top end + // doesn't work, since the Enhanced Allocation region may reside anywhere + // in MMIO space. + // As a result, we currently can't reserve Enhanced Allocation regions for the + // hypervisor and guarantee that the host VM can't access them. Thus, bail on + // attempts by the hypervisor to grab devices with Enhanced Allocation regions. 
+ // + // We still allow adventurous souls to proceed after building with a special + // feature flag - but this may or may not work depending on where the + // respective Enhanced Allocation region(s) reside in MMIO space and whether + // they happen to collide with other allocations. + if !cfg!(feature = "unsafe_enhanced_allocation") { + unimplemented!(); + } + } + } } if bar_info .bars() From c23d487e16acb92d1e2091d4431373d395b142d7 Mon Sep 17 00:00:00 2001 From: Mattias Nissler Date: Wed, 11 Jun 2025 02:51:54 -0700 Subject: [PATCH 03/13] drivers: Add `unsafe_enhanced_allocation` feature flag Building with the flag enabled allows salus to grab PCI devices for the hypervisor even if these have Enhanced Allocation regions. Due to missing support in resource allocation, this may cause resource collisions and break isolation, so it is disabled by default. --- BUILD | 13 +++++++++++++ drivers/BUILD | 4 ++++ 2 files changed, 17 insertions(+) diff --git a/BUILD b/BUILD index 333612ba..a6d5e24b 100644 --- a/BUILD +++ b/BUILD @@ -10,10 +10,23 @@ # bazel test //:rustfmt-all # bazel test //:test-all +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@rules_rust//rust:defs.bzl", "rust_binary", "rust_clippy", "rust_doc", "rust_test", "rustfmt_test") load("//:objcopy.bzl", "objcopy_to_object") load("//:lds.bzl", "lds_rule") +bool_flag( + name = "enable_unsafe_enhanced_allocation", + build_setting_default = 0, +) + +config_setting( + name = "unsafe_enhanced_allocation", + flag_values = { + "enable_unsafe_enhanced_allocation": "true", + }, +) + filegroup( name = "salus-all", srcs = [ diff --git a/drivers/BUILD b/drivers/BUILD index 7c5dffdb..d3b445b6 100644 --- a/drivers/BUILD +++ b/drivers/BUILD @@ -27,6 +27,10 @@ rust_library( "@salus-index//:static_assertions", "@salus-index//:tock-registers", ], + crate_features = select({ + "//:unsafe_enhanced_allocation": ["unsafe_enhanced_allocation"], + "//conditions:default": [], + }), ) rust_clippy( From 
691d001af496d5c320b2ccf599f0b90381d5b7f6 Mon Sep 17 00:00:00 2001 From: Mattias Nissler Date: Thu, 22 May 2025 00:08:11 -0700 Subject: [PATCH 04/13] drivers/iommu: Device directory to support both base and extended format Generalize the device directory implementation to allow using either base or extended device context format. At the table implementation level, there is now a type parameter for the device context that we're using, along with layout information and member accessors in a trait implemented for both the `DeviceContextBase` and `DeviceContextExtended` variants. The API remains agnostic of the type parameter, and we're instead selecting the underlying table implementation type to use via a run time parameter, with a new `DeviceDirectoryOps` trait and the `enum_dispatch` crate helping to bridge from the type-agnostic layer to the typed table implementation. --- drivers/src/iommu/core.rs | 19 +- drivers/src/iommu/device_directory.rs | 362 +++++++++++++++++--------- drivers/src/iommu/error.rs | 2 + drivers/src/iommu/mod.rs | 12 +- src/vm_pages.rs | 14 +- 5 files changed, 274 insertions(+), 135 deletions(-) diff --git a/drivers/src/iommu/core.rs b/drivers/src/iommu/core.rs index d8ce4308..f7e80707 100644 --- a/drivers/src/iommu/core.rs +++ b/drivers/src/iommu/core.rs @@ -83,9 +83,6 @@ impl Iommu { if !registers.capabilities.is_set(Capabilities::Sv48x4) { return Err(Error::MissingGStageSupport); } - if !registers.capabilities.is_set(Capabilities::MsiFlat) { - return Err(Error::MissingMsiSupport); - } // Initialize the command queue. let command_queue = CommandQueue::new(get_page().ok_or(Error::OutOfPages)?); @@ -101,7 +98,12 @@ impl Iommu { // TODO: Set up fault queue. // Set up an initial device directory table. 
- let ddt = DeviceDirectory::new(get_page().ok_or(Error::OutOfPages)?); + let format = if registers.capabilities.is_set(Capabilities::MsiFlat) { + DeviceContextFormat::Extended + } else { + DeviceContextFormat::Base + }; + let ddt = DeviceDirectory::new(get_page().ok_or(Error::OutOfPages)?, format); for dev in pci.devices() { let addr = dev.lock().info().address(); if addr == iommu_addr { @@ -138,6 +140,11 @@ impl Iommu { self.registers.capabilities.read(Capabilities::Version) } + /// Returns whether this IOMMU instance supports MSI page tables. + pub fn supports_msi_page_tables(&self) -> bool { + self.ddt.supports_msi_page_tables() + } + /// Allocates a new GSCID for `owner`. pub fn alloc_gscid(&self, owner: PageOwnerId) -> Result { let mut gscids = self.gscids.lock(); @@ -179,7 +186,7 @@ impl Iommu { &self, dev: &mut PciDevice, pt: &GuestStagePageTable, - msi_pt: &MsiPageTable, + msi_pt: Option<&MsiPageTable>, gscid: GscId, ) -> Result<()> { let dev_id = DeviceId::try_from(dev.info().address())?; @@ -191,7 +198,7 @@ impl Iommu { .and_then(|g| g.as_mut()) .ok_or(Error::InvalidGscId(gscid))?; if pt.page_owner_id() != state.owner - || msi_pt.owner() != state.owner + || !msi_pt.is_none_or(|pt| pt.owner() == state.owner) || dev.owner() != Some(state.owner) { return Err(Error::OwnerMismatch); diff --git a/drivers/src/iommu/device_directory.rs b/drivers/src/iommu/device_directory.rs index b665a722..571c3d6b 100644 --- a/drivers/src/iommu/device_directory.rs +++ b/drivers/src/iommu/device_directory.rs @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 use core::marker::PhantomData; +use enum_dispatch::enum_dispatch; use riscv_page_tables::{GuestStagePageTable, GuestStagePagingMode}; use riscv_pages::*; use riscv_regs::dma_wmb; @@ -15,10 +16,6 @@ use crate::pci::Address; // Maximum number of device ID bits used by the IOMMU. const DEVICE_ID_BITS: usize = 24; -// Number of bits used to index into the leaf table. 
-const LEAF_INDEX_BITS: usize = 6; -// Number of bits used to index into intermediate tables. -const NON_LEAF_INDEX_BITS: usize = 9; /// The device ID. Used to index into the device directory table. For PCI devices behind an IOMMU /// this is equivalent to the requester ID of the PCI device (i.e. the bits of the B/D/F). @@ -39,16 +36,6 @@ impl DeviceId { pub fn bits(&self) -> u32 { self.0 } - - // Returns the bits from this `DeviceId` used to index at `level`. - fn level_index_bits(&self, level: usize) -> usize { - if level == 0 { - (self.0 as usize) & ((1 << LEAF_INDEX_BITS) - 1) - } else { - let shift = LEAF_INDEX_BITS + NON_LEAF_INDEX_BITS * (level - 1); - ((self.0 as usize) >> shift) & ((1 << NON_LEAF_INDEX_BITS) - 1) - } - } } impl TryFrom
for DeviceId { @@ -77,12 +64,20 @@ impl GscId { // Defines the translation context for a device. A valid device context enables translation for // DMAs from the corresponding device according to the tables programmed into the device context. +// This is the base format, used when capabilities.MSI_FLAT isn't set. #[repr(C)] -struct DeviceContext { +struct DeviceContextBase { tc: u64, iohgatp: u64, fsc: u64, ta: u64, +} + +// Extended format of the device context, used when capabilities.MSI_FLAT is set. The additional +// fields control MSI address matching and translation. +#[repr(C)] +struct DeviceContextExtended { + base: DeviceContextBase, msiptp: u64, msi_addr_mask: u64, msi_addr_pattern: u64, @@ -96,46 +91,52 @@ const DC_VALID: u64 = 1 << 0; // device. Prevents enabling of device contexts that weren't explicitly added with `add_device()`. const DC_SW_INVALIDATED: u64 = 1 << 31; -impl DeviceContext { +// Trait abstracting over base / extended device context format. +trait DeviceContext { + const INDEX_BITS: [u8; 3]; + + fn base(&self) -> &DeviceContextBase; + fn base_mut(&mut self) -> &mut DeviceContextBase; + + // Returns the bits from `device_id` used to index at `level`. + fn level_index_bits(device_id: DeviceId, level: usize) -> usize { + let mask = (1 << Self::INDEX_BITS[level]) - 1; + let shift = Self::INDEX_BITS.iter().take(level).sum::() as usize; + (device_id.0 as usize >> shift) & mask + } + // Clears the device context structure. fn init(&mut self) { - self.tc = DC_SW_INVALIDATED; - self.iohgatp = 0; - self.fsc = 0; - self.ta = 0; - self.msiptp = 0; - self.msi_addr_mask = 0; - self.msi_addr_pattern = 0; + self.base_mut().tc = DC_SW_INVALIDATED; + self.base_mut().iohgatp = 0; + self.base_mut().fsc = 0; + self.base_mut().ta = 0; } // Returns if the device context corresponds to a present device. 
fn present(&self) -> bool { - (self.tc & (DC_VALID | DC_SW_INVALIDATED)) != 0 + (self.base().tc & (DC_VALID | DC_SW_INVALIDATED)) != 0 } // Returns if the device context is valid. fn valid(&self) -> bool { - (self.tc & DC_VALID) != 0 + (self.base().tc & DC_VALID) != 0 } // Marks the device context as valid, using `pt` and `msi_pt` for translation. fn set( &mut self, pt: &GuestStagePageTable, - msi_pt: &MsiPageTable, + msi_pt: Option<&MsiPageTable>, gscid: GscId, ) { - const MSI_MODE_FLAT: u64 = 0x1; - const MSI_MODE_SHIFT: u64 = 60; - self.msiptp = msi_pt.base_address().pfn().bits() | (MSI_MODE_FLAT << MSI_MODE_SHIFT); - - let (addr, mask) = msi_pt.msi_address_pattern(); - self.msi_addr_mask = mask >> PFN_SHIFT; - self.msi_addr_pattern = addr.pfn().bits(); + // This default implementation (appropriate for the base format) doesn't support MSI + // translation, and hence should never receive a valid `msi_pt` parameter. + assert!(msi_pt.is_none()); const GSCID_SHIFT: u64 = 44; const HGATP_MODE_SHIFT: u64 = 60; - self.iohgatp = pt.get_root_address().pfn().bits() + self.base_mut().iohgatp = pt.get_root_address().pfn().bits() | ((gscid.bits() as u64) << GSCID_SHIFT) | (T::HGATP_MODE << HGATP_MODE_SHIFT); @@ -143,12 +144,56 @@ impl DeviceContext { // as valid. dma_wmb(); - self.tc = DC_VALID; + self.base_mut().tc = DC_VALID; } // Marks the device context as invalid. 
fn invalidate(&mut self) { - self.tc = DC_SW_INVALIDATED; + self.base_mut().tc = DC_SW_INVALIDATED; + } +} + +impl DeviceContext for DeviceContextBase { + const INDEX_BITS: [u8; 3] = [7, 9, 8]; + + fn base(&self) -> &DeviceContextBase { + self + } + fn base_mut(&mut self) -> &mut DeviceContextBase { + self + } +} + +impl DeviceContext for DeviceContextExtended { + const INDEX_BITS: [u8; 3] = [6, 9, 9]; + + fn base(&self) -> &DeviceContextBase { + &self.base + } + fn base_mut(&mut self) -> &mut DeviceContextBase { + &mut self.base + } + + fn set( + &mut self, + pt: &GuestStagePageTable, + msi_pt: Option<&MsiPageTable>, + gscid: GscId, + ) { + if let Some(msi_pt) = msi_pt { + const MSI_MODE_FLAT: u64 = 0x1; + const MSI_MODE_SHIFT: u64 = 60; + self.msiptp = msi_pt.base_address().pfn().bits() | (MSI_MODE_FLAT << MSI_MODE_SHIFT); + + let (addr, mask) = msi_pt.msi_address_pattern(); + self.msi_addr_mask = mask >> PFN_SHIFT; + self.msi_addr_pattern = addr.pfn().bits(); + } else { + self.msiptp = 0; + } + + // NB: Pass a None `msi_pt` parameter to the base implementation. + self.base.set(pt, None, gscid); } } @@ -179,23 +224,34 @@ impl NonLeafEntry { } } +// Checks whether the table sizes at the different levels for a given DeviceContext look ok. +const fn _check_ddt_layout() -> bool { + DC::INDEX_BITS.len() == 3 + && (size_of::() << DC::INDEX_BITS[0] == 4096) + && (size_of::() << DC::INDEX_BITS[1] == 4096) + && (size_of::() << DC::INDEX_BITS[2] <= 4096) +} + +const_assert!(_check_ddt_layout::()); +const_assert!(_check_ddt_layout::()); + // Represents a single entry in the device directory hierarchy. -enum DeviceDirectoryEntry<'a> { - PresentLeaf(&'a mut DeviceContext), - NotPresentLeaf(&'a mut DeviceContext), - NextLevel(DeviceDirectoryTable<'a>), +enum DeviceDirectoryEntry<'a, DC: DeviceContext> { + Leaf(&'a mut DC), + NextLevel(DeviceDirectoryTable<'a, DC>), Invalid(&'a mut NonLeafEntry, usize), } // Represents a single device directory table. 
Intermediate DDTs (level > 0) are made up entirely // of non-leaf entries, while Leaf DDTs (level == 0) are made up entirely of `DeviceContext`s. -struct DeviceDirectoryTable<'a> { +#[derive(Clone)] +struct DeviceDirectoryTable<'a, DC: DeviceContext> { table_addr: SupervisorPageAddr, level: usize, - phantom: PhantomData<&'a mut DeviceDirectoryInner>, + phantom: PhantomData<&'a DC>, } -impl<'a> DeviceDirectoryTable<'a> { +impl<'a, DC: DeviceContext + 'a> DeviceDirectoryTable<'a, DC> { // Creates the root `DeviceDirectoryTable` from `owner`. fn from_root(owner: &'a mut DeviceDirectoryInner) -> Self { Self { @@ -225,23 +281,19 @@ impl<'a> DeviceDirectoryTable<'a> { } // Returns the `DeviceDirectoryEntry` for `id` in this table. - fn entry_for_id(&mut self, id: DeviceId) -> DeviceDirectoryEntry<'a> { - let index = id.level_index_bits(self.level); + fn entry_for_id(&mut self, id: DeviceId) -> DeviceDirectoryEntry<'a, DC> { + let index = DC::level_index_bits(id, self.level); use DeviceDirectoryEntry::*; if self.is_leaf() { // Safety: self.table_addr is properly aligned and must point to an array of // `DeviceContext`s if this is a leaf table. Further, `index` is guaranteed // to be within in the bounds of the table. let dc = unsafe { - let ptr = (self.table_addr.bits() as *mut DeviceContext).add(index); + let ptr = (self.table_addr.bits() as *mut DC).add(index); // Pointer must be non-NULL. ptr.as_mut().unwrap() }; - if dc.present() { - PresentLeaf(dc) - } else { - NotPresentLeaf(dc) - } + Leaf(dc) } else { // Safety: self.table_addr is properly aligned and must point to an array of // `NonLeafEntry`s if this is an intermediate table. Further, `index` is guaranteed @@ -262,54 +314,142 @@ impl<'a> DeviceDirectoryTable<'a> { } } - // Returns the next-level table mapping `id`, using `get_page` to allocate a directory table - // page if necessary. - fn next_level_or_fill( + // Returns if this is a leaf directory table. 
+ fn is_leaf(&self) -> bool { + self.level == 0 + } + + // Get the device context for the device identified by `id`. + fn get_context_for_id(&mut self, id: DeviceId) -> Option<&mut DC> { + let mut entry = self.entry_for_id(id); + use DeviceDirectoryEntry::*; + while let NextLevel(mut t) = entry { + entry = t.entry_for_id(id); + } + match entry { + Leaf(dc) if dc.present() => Some(dc), + _ => None, + } + } + + // Returns the device context for the device identified by `id`, creating it if necessary. + fn create_context_for_id( &mut self, id: DeviceId, get_page: &mut dyn FnMut() -> Option>, - ) -> Result> { - use DeviceDirectoryEntry::*; - let table = match self.entry_for_id(id) { - NextLevel(t) => t, - Invalid(nle, level) => { - let page = get_page().ok_or(Error::OutOfPages)?; - nle.set(page.pfn()); - // Safety: We just allocated the page this entry points to and thus have unique - // ownership over the memory it refers to. - unsafe { - // Unwrap ok, we just marked the entry as valid. - Self::from_non_leaf_entry(nle, level).unwrap() + ) -> Result<&mut DC> { + let mut entry = self.entry_for_id(id); + loop { + use DeviceDirectoryEntry::*; + let mut table = match entry { + NextLevel(t) => t, + Invalid(nle, level) => { + let page = get_page().ok_or(Error::OutOfPages)?; + nle.set(page.pfn()); + // Safety: We just allocated the page this entry points to and thus have unique + // ownership over the memory it refers to. + unsafe { + // Unwrap ok, we just marked the entry as valid. + Self::from_non_leaf_entry(nle, level).unwrap() + } } - } - _ => { - return Err(Error::NotIntermediateTable); - } - }; - Ok(table) + Leaf(dc) => return Ok(dc), + }; + entry = table.entry_for_id(id); + } } +} - // Returns if this is a leaf directory table. - fn is_leaf(&self) -> bool { - self.level == 0 +// A trait providing directory mutation operations. 
Note that it does not have the specific device +// context as type parameter, but is implemented (generically) for `DeviceDirectoryTable` +// parameterized with specific device context types. Thus, this trait connects the API layer to the +// differently typed table implementations. +#[enum_dispatch(DeviceDirectoryOpsDispatch)] +trait DeviceDirectoryOps { + // Adds the device with `id`, creating tables as necessary. + fn add_device( + &mut self, + id: DeviceId, + get_page: &mut dyn FnMut() -> Option>, + ) -> Result<()>; + + // Updates the device context for `id` with the given parameters and marks it valid. + fn enable_device( + &mut self, + id: DeviceId, + pt: &GuestStagePageTable, + msi_pt: Option<&MsiPageTable>, + gscid: GscId, + ) -> Result<()>; + + // Invalidates the device context for `id`. + fn disable_device(&mut self, id: DeviceId) -> Result<()>; +} + +impl<'a, DC: DeviceContext + 'a> DeviceDirectoryOps for DeviceDirectoryTable<'a, DC> { + fn add_device( + &mut self, + id: DeviceId, + get_page: &mut dyn FnMut() -> Option>, + ) -> Result<()> { + let dc = self.create_context_for_id(id, get_page)?; + if !dc.present() { + dc.init(); + } + Ok(()) + } + + fn enable_device( + &mut self, + id: DeviceId, + pt: &GuestStagePageTable, + msi_pt: Option<&MsiPageTable>, + gscid: GscId, + ) -> Result<()> { + let entry = self + .get_context_for_id(id) + .ok_or(Error::DeviceNotFound(id))?; + if entry.valid() { + return Err(Error::DeviceAlreadyEnabled(id)); + } + entry.set(pt, msi_pt, gscid); + Ok(()) } + + fn disable_device(&mut self, id: DeviceId) -> Result<()> { + let entry = self + .get_context_for_id(id) + .ok_or(Error::DeviceNotFound(id))?; + if !entry.valid() { + return Err(Error::DeviceNotEnabled(id)); + } + entry.invalidate(); + Ok(()) + } +} + +// A helper enum for dispatching DeviceDirectoryOps calls from non-type-parameterized context to +// the type-parameterized DeviceDirectoryOps implementations for DeviceDirectoryTable. 
+#[enum_dispatch] +enum DeviceDirectoryOpsDispatch<'a> { + Base(DeviceDirectoryTable<'a, DeviceContextBase>), + Extended(DeviceDirectoryTable<'a, DeviceContextExtended>), } +// Represents a device directory instance. The instance is protected by a Mutex and thus separate +// from the API layer offered by `DeviceDirectory`. struct DeviceDirectoryInner { root: Page, num_levels: usize, + format: DeviceContextFormat, } impl DeviceDirectoryInner { - fn get_context_for_id(&mut self, id: DeviceId) -> Option<&mut DeviceContext> { - let mut entry = DeviceDirectoryTable::from_root(self).entry_for_id(id); - use DeviceDirectoryEntry::*; - while let NextLevel(mut t) = entry { - entry = t.entry_for_id(id); - } - match entry { - PresentLeaf(dc) => Some(dc), - _ => None, + fn ops(&mut self) -> DeviceDirectoryOpsDispatch { + use DeviceDirectoryOpsDispatch::*; + match self.format { + DeviceContextFormat::Base => Base(DeviceDirectoryTable::from_root(self)), + DeviceContextFormat::Extended => Extended(DeviceDirectoryTable::from_root(self)), } } } @@ -331,6 +471,12 @@ impl DirectoryMode for Ddt3Level { const IOMMU_MODE: u64 = 4; } +/// Indicates which device context format to use. +pub enum DeviceContextFormat { + Base, + Extended, +} + /// Represents the device directory table for the IOMMU. The IOMMU hardware uses the DDT to map /// a requester ID to the translation context for the device. pub struct DeviceDirectory { @@ -340,10 +486,11 @@ pub struct DeviceDirectory { impl DeviceDirectory { /// Creates a new `DeviceDirectory` using `root` as the root table page. - pub fn new(root: Page) -> Self { + pub fn new(root: Page, format: DeviceContextFormat) -> Self { let inner = DeviceDirectoryInner { root, num_levels: D::LEVELS, + format, }; Self { inner: Mutex::new(inner), @@ -356,6 +503,11 @@ impl DeviceDirectory { self.inner.lock().root.addr() } + /// Returns whether this IOMMU instance supports MSI page tables. 
+ pub fn supports_msi_page_tables(&self) -> bool { + matches!(self.inner.lock().format, DeviceContextFormat::Extended) + } + /// Adds and initializes a device context for `id` in this `DeviceDirectory`. The device /// context is initially invalid, i.e. translation is off for the device. Uses `get_page` /// to allocate intermediate directory table pages, if necessary. @@ -365,16 +517,7 @@ impl DeviceDirectory { get_page: &mut dyn FnMut() -> Option>, ) -> Result<()> { let mut inner = self.inner.lock(); - // Silence bogus auto-deref lint, see https://github.com/rust-lang/rust-clippy/issues/9101. - #[allow(clippy::explicit_auto_deref)] - let mut table = DeviceDirectoryTable::from_root(&mut *inner); - while !table.is_leaf() { - table = table.next_level_or_fill(id, get_page)?; - } - if let DeviceDirectoryEntry::NotPresentLeaf(dc) = table.entry_for_id(id) { - dc.init(); - } - Ok(()) + inner.ops().add_device(id, get_page) } /// Enables IOMMU translation for the specified device, using `pt` for 2nd-stage translation @@ -384,38 +527,23 @@ impl DeviceDirectory { &self, id: DeviceId, pt: &GuestStagePageTable, - msi_pt: &MsiPageTable, + msi_pt: Option<&MsiPageTable>, gscid: GscId, ) -> Result<()> { - if pt.page_owner_id() != msi_pt.owner() { + if msi_pt.is_some_and(|msi_pt| msi_pt.owner() != pt.page_owner_id()) { return Err(Error::OwnerMismatch); } + let mut inner = self.inner.lock(); - let entry = inner - .get_context_for_id(id) - .ok_or(Error::DeviceNotFound(id))?; - if entry.valid() { - return Err(Error::DeviceAlreadyEnabled(id)); + if msi_pt.is_some() && !matches!(inner.format, DeviceContextFormat::Extended) { + return Err(Error::MsiTranslationUnsupported); } - entry.set(pt, msi_pt, gscid); - Ok(()) + inner.ops().enable_device(id, pt, msi_pt, gscid) } /// Disables IOMMU translation for the specified device. 
pub fn disable_device(&self, id: DeviceId) -> Result<()> { let mut inner = self.inner.lock(); - let entry = inner - .get_context_for_id(id) - .ok_or(Error::DeviceNotFound(id))?; - if !entry.valid() { - return Err(Error::DeviceNotEnabled(id)); - } - entry.invalidate(); - Ok(()) + inner.ops().disable_device(id) } } - -fn _assert_ddt_layout() { - const_assert!(core::mem::size_of::() << LEAF_INDEX_BITS == 4096); - const_assert!(core::mem::size_of::() << NON_LEAF_INDEX_BITS == 4096); -} diff --git a/drivers/src/iommu/error.rs b/drivers/src/iommu/error.rs index fc1bac40..22ae9726 100644 --- a/drivers/src/iommu/error.rs +++ b/drivers/src/iommu/error.rs @@ -65,6 +65,8 @@ pub enum Error { GscIdAlreadyFree(GscId), /// Attempted to free a GSCID that's currently being used for translation. GscIdInUse(GscId), + /// MSI translation not supported. + MsiTranslationUnsupported, } /// Holds results for IOMMU operations. diff --git a/drivers/src/iommu/mod.rs b/drivers/src/iommu/mod.rs index 2b2a52c0..ca8a7532 100644 --- a/drivers/src/iommu/mod.rs +++ b/drivers/src/iommu/mod.rs @@ -204,7 +204,7 @@ mod tests { let ddt_page = page_tracker .assign_page_for_internal_state(pages.pop().unwrap(), PageOwnerId::host()) .unwrap(); - let ddt = DeviceDirectory::::new(ddt_page); + let ddt = DeviceDirectory::::new(ddt_page, DeviceContextFormat::Extended); for i in 0..16 { let id = DeviceId::new(i).unwrap(); ddt.add_device(id, &mut || { @@ -217,17 +217,21 @@ mod tests { let gscid = GscId::new(0); let dev = DeviceId::new(2).unwrap(); - assert!(ddt.enable_device(dev, &pt, &msi_pt, gscid).is_ok()); + assert!(ddt.enable_device(dev, &pt, Some(&msi_pt), gscid).is_ok()); assert!(ddt.disable_device(dev).is_ok()); let bad_dev = DeviceId::new(1 << 16).unwrap(); - assert!(ddt.enable_device(bad_dev, &pt, &msi_pt, gscid).is_err()); + assert!(ddt + .enable_device(bad_dev, &pt, Some(&msi_pt), gscid) + .is_err()); let (bad_msi_pt, _) = stub_msi_page_table( page_tracker.clone(), &mut pages, 
PageOwnerId::new(5).unwrap(), ); - assert!(ddt.enable_device(dev, &pt, &bad_msi_pt, gscid).is_err()); + assert!(ddt + .enable_device(dev, &pt, Some(&bad_msi_pt), gscid) + .is_err()); } #[test] diff --git a/src/vm_pages.rs b/src/vm_pages.rs index 57b6a4aa..4996c483 100644 --- a/src/vm_pages.rs +++ b/src/vm_pages.rs @@ -2209,14 +2209,12 @@ impl<'a, T: GuestStagePagingMode> InitializingVmPages<'a, T> { /// this VM's page tables. pub fn attach_pci_device(&self, dev: &mut PciDevice) -> Result<()> { let iommu_context = self.inner.iommu_context.get().ok_or(Error::NoIommu)?; - Iommu::get() - .unwrap() - .attach_pci_device( - dev, - &self.inner.root, - &iommu_context.msi_page_table, - iommu_context.gscid, - ) + let iommu = Iommu::get().unwrap(); + let msi_pt = iommu + .supports_msi_page_tables() + .then_some(&iommu_context.msi_page_table); + iommu + .attach_pci_device(dev, &self.inner.root, msi_pt, iommu_context.gscid) .map_err(Error::AttachingDevice) } } From ebd1c3b3106a87100869441db667dcb7a8b17a53 Mon Sep 17 00:00:00 2001 From: Mattias Nissler Date: Tue, 27 May 2025 07:42:06 -0700 Subject: [PATCH 05/13] Consider segment when computing config space offset When computing the config space offset for an address, the segment number does not contribute to the offset (each segment has its own config space). So, make sure the segment in the address matches the config space's segment, then compensate for the segment in the offset computation. 
--- drivers/src/pci/address.rs | 11 +++++++++++ drivers/src/pci/config_space.rs | 13 ++++++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/drivers/src/pci/address.rs b/drivers/src/pci/address.rs index fe09165a..661775d8 100644 --- a/drivers/src/pci/address.rs +++ b/drivers/src/pci/address.rs @@ -103,6 +103,12 @@ pub struct BusRange { pub end: Bus, } +impl BusRange { + pub fn contains(&self, bus: Bus) -> bool { + bus >= self.start && bus <= self.end + } +} + // Because Functions are only 3 bits, they are trivally valid Bus numbers(8 bits). impl From for Bus { fn from(f: Function) -> Self { @@ -142,6 +148,11 @@ impl Address { Address(seg.0 << Segment::SHIFT | bus.0 << Bus::SHIFT | dev.0 << Device::SHIFT | func.0) } + /// Creates an `Address` for the given segment. + pub fn segment_address(segment: Segment) -> Address { + Address(segment.0 << Segment::SHIFT) + } + /// Creates an `Address` for the given `Bus` on the first segment. pub fn bus_address(bus: Bus) -> Address { Address(bus.0 << Bus::SHIFT) diff --git a/drivers/src/pci/config_space.rs b/drivers/src/pci/config_space.rs index affcae27..c181f608 100644 --- a/drivers/src/pci/config_space.rs +++ b/drivers/src/pci/config_space.rs @@ -123,9 +123,16 @@ impl PciConfigSpace { // Returns the offset of the given address within this PciConfigSpace. fn config_space_offset(&self, address: Address) -> Option { - (address.bits() as u64) - .checked_sub(Address::bus_address(self.bus_range.start).bits() as u64) - .map(|a| a << PCIE_ECAM_FN_SHIFT) + // Make sure the address is on the correct segment and within the bus range. + if address.segment() != self.segment || !self.bus_range.contains(address.bus()) { + return None; + } + Some( + (address.bits() as u64) + .checked_sub(Address::segment_address(self.segment).bits() as u64)? + .checked_sub(Address::bus_address(self.bus_range.start).bits() as u64)? 
+ << PCIE_ECAM_FN_SHIFT, + ) } } From e53f13bbb15a166975ac78255bb9f38087bf3fd9 Mon Sep 17 00:00:00 2001 From: Mattias Nissler Date: Tue, 27 May 2025 07:44:59 -0700 Subject: [PATCH 06/13] Support multiple PCI roots A machine might contain multiple PCI segments, each with an independent ECAM. Thus, generalize PCI enumeration to go through however many PCI entries are found in the device tree, and gather them into a vector rather than a singleton. --- drivers/src/pci/root.rs | 35 ++++++++++++---------- src/host_vm.rs | 64 ++++++++++++++++++++++------------------- src/main.rs | 55 ++++++++++++++++++++--------------- src/vm_pages.rs | 21 +++++++------- 4 files changed, 97 insertions(+), 78 deletions(-) diff --git a/drivers/src/pci/root.rs b/drivers/src/pci/root.rs index 60ce3998..bc2ef30e 100644 --- a/drivers/src/pci/root.rs +++ b/drivers/src/pci/root.rs @@ -3,9 +3,10 @@ // SPDX-License-Identifier: Apache-2.0 use alloc::alloc::Global; +use alloc::vec::Vec; use arrayvec::{ArrayString, ArrayVec}; use core::marker::PhantomData; -use device_tree::{DeviceTree, DeviceTreeResult}; +use device_tree::{DeviceTree, DeviceTreeNode, DeviceTreeResult}; use hyp_alloc::{Arena, ArenaId}; use page_tracking::{HwMemMap, PageTracker}; use riscv_pages::*; @@ -41,7 +42,7 @@ pub struct PcieRoot { msi_parent_phandle: u32, } -static PCIE_ROOT: Once = Once::new(); +static PCIE_ROOTS: Once> = Once::new(); // A `u64` from two `u32` cells in a device tree. struct U64Cell(u32, u32); @@ -59,13 +60,7 @@ fn valid_config_mmio_access(offset: u64, len: usize) -> bool { } impl PcieRoot { - /// Creates a `PcieRoot` singleton by finding a supported configuration in the passed `DeviceTree`. 
- pub fn probe_from(dt: &DeviceTree, mem_map: &mut HwMemMap) -> Result<()> { - let pci_node = dt - .iter() - .find(|n| n.compatible(["pci-host-ecam-generic"]) && !n.disabled()) - .ok_or(Error::NoCompatibleHostNode)?; - + fn probe_one(pci_node: &DeviceTreeNode, mem_map: &mut HwMemMap) -> Result { // Find the ECAM MMIO region, which should be the first entry in the `reg` property. let mut regs = pci_node .props() @@ -192,20 +187,30 @@ impl PcieRoot { let mut device_arena = PciDeviceArena::new(Global); let root_bus = PciBus::enumerate(&config_space, bus_range.start, &mut device_arena)?; - PCIE_ROOT.call_once(|| Self { + Ok(Self { config_space, root_bus, device_arena, resources: Mutex::new(resources), msi_parent_phandle, - }); + }) + } + + /// Probes `PcieRoot`s from supported configurations in the passed `DeviceTree`. + pub fn probe_from(dt: &DeviceTree, mem_map: &mut HwMemMap) -> Result<()> { + let roots = dt + .iter() + .filter(|n| n.compatible(["pci-host-ecam-generic"]) && !n.disabled()) + .map(|n| Self::probe_one(n, mem_map)) + .collect::>>()?; + + PCIE_ROOTS.call_once(|| roots); Ok(()) } - /// Gets a reference to the `PcieRoot` singleton. Panics if `PcieRoot::probe_from()` has not yet - /// been called to initialize it. - pub fn get() -> &'static Self { - PCIE_ROOT.get().unwrap() + /// Returns an iterator over all PcieRoots that have been probed. + pub fn get_roots() -> impl Iterator { + PCIE_ROOTS.get().unwrap().iter() } /// Returns an iterator over all PCI devices. 
diff --git a/src/host_vm.rs b/src/host_vm.rs index 355d8acf..194541de 100644 --- a/src/host_vm.rs +++ b/src/host_vm.rs @@ -103,7 +103,9 @@ impl HostDtBuilder { soc_node.add_prop("ranges")?; Imsic::get().add_host_imsic_node(&mut self.tree)?; - PcieRoot::get().add_host_pcie_node(&mut self.tree)?; + for pci in PcieRoot::get_roots() { + pci.add_host_pcie_node(&mut self.tree)?; + } Ok(self) } @@ -265,23 +267,24 @@ impl HostVmLoader { self.vm.add_imsic_pages(cpu_id, imsic_pages); } - let pci = PcieRoot::get(); - pci.take_host_devices(); - // Identity-map the PCIe BAR resources. - for (res_type, range) in pci.resources() { - let gpa = range.base().as_guest_phys(PageOwnerId::host()); - self.vm.add_pci_region(gpa, range.length_bytes()); - let pages = pci.take_host_resource(res_type).unwrap(); - self.vm.add_pci_pages(gpa, pages); - } - // Attach our PCI devices to the IOMMU. - if Iommu::get().is_some() { - for dev in pci.devices() { - let mut dev = dev.lock(); - if dev.owner() == Some(PageOwnerId::host()) { - // Silence buggy clippy warning. - #[allow(clippy::explicit_auto_deref)] - self.vm.attach_pci_device(&mut *dev); + for pci in PcieRoot::get_roots() { + pci.take_host_devices(); + // Identity-map the PCIe BAR resources. + for (res_type, range) in pci.resources() { + let gpa = range.base().as_guest_phys(PageOwnerId::host()); + self.vm.add_pci_region(gpa, range.length_bytes()); + let pages = pci.take_host_resource(res_type).unwrap(); + self.vm.add_pci_pages(gpa, pages); + } + // Attach our PCI devices to the IOMMU. + if Iommu::get().is_some() { + for dev in pci.devices() { + let mut dev = dev.lock(); + if dev.owner() == Some(PageOwnerId::host()) { + // Silence buggy clippy warning. + #[allow(clippy::explicit_auto_deref)] + self.vm.attach_pci_device(&mut *dev); + } } } } @@ -365,10 +368,12 @@ impl HostVmLoader { self.vm.add_zero_pages(current_gpa, self.zero_pages); // Set up MMIO emulation for the PCIe config space. 
- let config_mem = pci.config_space(); - let config_gpa = config_mem.base().as_guest_phys(PageOwnerId::host()); - self.vm - .add_mmio_region(config_gpa, config_mem.length_bytes()); + for pci in PcieRoot::get_roots() { + let config_mem = pci.config_space(); + let config_gpa = config_mem.base().as_guest_phys(PageOwnerId::host()); + self.vm + .add_mmio_region(config_gpa, config_mem.length_bytes()); + } self.vm } @@ -487,14 +492,12 @@ impl HostVmRunner { ) -> core::result::Result<(), MmioEmulationError> { // For now, the only thing we're expecting is MMIO emulation faults in PCI config space. let addr = (self.htval << 2) | (self.stval & 0x3); - let pci = PcieRoot::get(); - if addr < pci.config_space().base().bits() { - return Err(MmioEmulationError::InvalidAddress(addr)); - } - let offset = addr - pci.config_space().base().bits(); - if offset > pci.config_space().length_bytes() { - return Err(MmioEmulationError::InvalidAddress(addr)); - } + let pci = PcieRoot::get_roots() + .find(|pci| { + let base = pci.config_space().base().bits(); + addr >= base && (addr - base) < pci.config_space().length_bytes() + }) + .ok_or(MmioEmulationError::InvalidAddress(addr))?; // Figure out from HTINST what the MMIO operation was. We know the source/destination is // always A0. 
@@ -516,6 +519,7 @@ impl HostVmRunner { } }; + let offset = addr - pci.config_space().base().bits(); if write { let val = self.gprs.reg(GprIndex::A0); pci.emulate_config_write(offset, val, width, page_tracker, PageOwnerId::host()); diff --git a/src/main.rs b/src/main.rs index 8ff91afd..c13391f9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -44,8 +44,8 @@ mod vm_pmu; use backtrace::backtrace; use device_tree::{DeviceTree, DeviceTreeError, Fdt}; use drivers::{ - imsic::Imsic, iommu::Iommu, pci::PcieRoot, pmu::PmuInfo, reset::ResetDriver, uart::UartDriver, - CpuInfo, + imsic::Imsic, iommu::Iommu, iommu::IommuError, pci::PciError, pci::PcieRoot, pmu::PmuInfo, + reset::ResetDriver, uart::UartDriver, CpuInfo, }; use host_vm::{HostVm, HostVmLoader, HOST_VM_ALIGN}; use hyp_alloc::HypAlloc; @@ -544,24 +544,25 @@ fn primary_init(hart_id: u64, fdt_addr: u64) -> Result { // Probe for a PCI bus. PcieRoot::probe_from(&hyp_dt, &mut mem_map) .map_err(|e| Error::RequiredDeviceProbe(RequiredDeviceProbe::Pci(e)))?; - let pci = PcieRoot::get(); - for dev in pci.devices() { - let dev = dev.lock(); - println!( - "Found func {}; type: {}, MSI: {}, MSI-X: {}, PCIe: {}", - dev.info(), - dev.info().header_type(), - dev.has_msi(), - dev.has_msix(), - dev.is_pcie(), - ); - for bar in dev.bar_info().bars() { + for pci in PcieRoot::get_roots() { + for dev in pci.devices() { + let dev = dev.lock(); println!( - "BAR{:}: type {:?}, size 0x{:x}", - bar.index(), - bar.bar_type(), - bar.size() + "Found func {}; type: {}, MSI: {}, MSI-X: {}, PCIe: {}", + dev.info(), + dev.info().header_type(), + dev.has_msi(), + dev.has_msix(), + dev.is_pcie(), ); + for bar in dev.bar_info().bars() { + println!( + "BAR{:}: type {:?}, size 0x{:x}", + bar.index(), + bar.bar_type(), + bar.size() + ); + } } } @@ -629,18 +630,26 @@ fn primary_init(hart_id: u64, fdt_addr: u64) -> Result { PerCpu::init(hart_id, &mut hyp_mem).map_err(Error::CreateSmpState)?; // Find and initialize the IOMMU. 
- match Iommu::probe_from(PcieRoot::get(), &mut || { - hyp_mem.take_pages_for_host_state(1).into_iter().next() - }) { - Ok(_) => { + match PcieRoot::get_roots() + .map(|pci| { + Iommu::probe_from(pci, &mut || { + hyp_mem.take_pages_for_host_state(1).into_iter().next() + }) + }) + .find(|r| !matches!(r, Err(IommuError::ProbingIommu(PciError::DeviceNotFound)))) + { + Some(Ok(_)) => { println!( "Found RISC-V IOMMU version 0x{:x}", Iommu::get().unwrap().version() ); } - Err(e) => { + Some(Err(e)) => { println!("Failed to probe IOMMU: {:?}", e); } + None => { + println!("No IOMMU found!"); + } }; // Initialize global Umode state. diff --git a/src/vm_pages.rs b/src/vm_pages.rs index 4996c483..57c48230 100644 --- a/src/vm_pages.rs +++ b/src/vm_pages.rs @@ -981,16 +981,17 @@ impl Drop for VmIommuContext { // Detach any devices we own from the IOMMU. let owner = self.msi_page_table.owner(); - let pci = PcieRoot::get(); - for dev in pci.devices() { - let mut dev = dev.lock(); - if dev.owner() == Some(owner) { - // Unwrap ok: `self.gscid` must be valid and match the ownership of the device - // to have been attached in the first place. - // - // Silence buggy clippy warning. - #[allow(clippy::explicit_auto_deref)] - iommu.detach_pci_device(&mut *dev, self.gscid).unwrap(); + for pci in PcieRoot::get_roots() { + for dev in pci.devices() { + let mut dev = dev.lock(); + if dev.owner() == Some(owner) { + // Unwrap ok: `self.gscid` must be valid and match the ownership of the device + // to have been attached in the first place. + // + // Silence buggy clippy warning. + #[allow(clippy::explicit_auto_deref)] + iommu.detach_pci_device(&mut *dev, self.gscid).unwrap(); + } } } From 00db1d7149c3df59964c71f1274ba6c009a28ecb Mon Sep 17 00:00:00 2001 From: Mattias Nissler Date: Tue, 3 Jun 2025 03:42:48 -0700 Subject: [PATCH 07/13] Add `hardware_ad_updates` build flag The new build flag indicates whether hardware support for updating A/D bits in PTEs should be used. 
This is enabled by default and can be adjusted by passing `--//:enable_hardware_ad_updates=false` to the bazel invocation. The flag gets reflected into a cargo feature, which respective crates inspect. In hardware A/D update mode, the `svadu` CPU extension and the `AMO_HWAD` IOMMU capability must be present. When built with `hardware_ad_updates` disabled, PTEs are initialized with the A/D bits set, side-stepping the need for hardware updates and thus not requiring `svadu` and `AMO_HWAD`. --- BUILD | 16 ++++++++++++++++ drivers/BUILD | 3 +++ drivers/src/iommu/core.rs | 6 ++++++ drivers/src/iommu/device_directory.rs | 13 ++++++++++++- drivers/src/iommu/error.rs | 2 ++ drivers/src/iommu/registers.rs | 1 + riscv-page-tables/BUILD | 4 ++++ riscv-page-tables/src/pte.rs | 10 ++++++++-- src/main.rs | 2 +- 9 files changed, 53 insertions(+), 4 deletions(-) diff --git a/BUILD b/BUILD index a6d5e24b..d8b8544e 100644 --- a/BUILD +++ b/BUILD @@ -27,6 +27,18 @@ config_setting( }, ) +bool_flag( + name = "enable_hardware_ad_updates", + build_setting_default = 1, +) + +config_setting( + name = "hardware_ad_updates", + flag_values = { + "enable_hardware_ad_updates": "true", + }, +) + filegroup( name = "salus-all", srcs = [ @@ -204,6 +216,10 @@ rust_binary( "-Clink-arg=-T$(location //:l_rule)", ], deps = salus_deps, + crate_features = select({ + ":hardware_ad_updates": ["hardware_ad_updates"], + "//conditions:default": [], + }), ) rust_clippy( diff --git a/drivers/BUILD b/drivers/BUILD index d3b445b6..42c3305d 100644 --- a/drivers/BUILD +++ b/drivers/BUILD @@ -30,6 +30,9 @@ rust_library( crate_features = select({ "//:unsafe_enhanced_allocation": ["unsafe_enhanced_allocation"], "//conditions:default": [], + }) + select({ + "//:hardware_ad_updates": ["hardware_ad_updates"], + "//conditions:default": [], }), ) diff --git a/drivers/src/iommu/core.rs b/drivers/src/iommu/core.rs index f7e80707..3a0f4094 100644 --- a/drivers/src/iommu/core.rs +++ b/drivers/src/iommu/core.rs @@ -84,6 +84,12 
@@ impl Iommu { return Err(Error::MissingGStageSupport); } + if cfg!(feature = "hardware_ad_updates") + && !registers.capabilities.is_set(Capabilities::AmoHwad) + { + return Err(Error::MissingAmoHwadSupport); + } + // Initialize the command queue. let command_queue = CommandQueue::new(get_page().ok_or(Error::OutOfPages)?); let mut cqb = LocalRegisterCopy::::new(0); diff --git a/drivers/src/iommu/device_directory.rs b/drivers/src/iommu/device_directory.rs index 571c3d6b..8a7a9198 100644 --- a/drivers/src/iommu/device_directory.rs +++ b/drivers/src/iommu/device_directory.rs @@ -87,6 +87,12 @@ struct DeviceContextExtended { // There are a bunch of other bits in `tc` for ATS, etc. but we only care about V for now. const DC_VALID: u64 = 1 << 0; +// Indicates that hardware should update the AD bits in G stage translation PTEs. +const DC_GADE: u64 = 1 << 7; + +// Indicates that hardware should update the AD bits in first translation PTEs. +const DC_SADE: u64 = 1 << 8; + // Set in invalidated device contexts to indicate that the device context corresponds to a real // device. Prevents enabling of device contexts that weren't explicitly added with `add_device()`. const DC_SW_INVALIDATED: u64 = 1 << 31; @@ -144,7 +150,12 @@ trait DeviceContext { // as valid. dma_wmb(); - self.base_mut().tc = DC_VALID; + let ade = if cfg!(feature = "hardware_ad_updates") { + DC_SADE | DC_GADE + } else { + 0 + }; + self.base_mut().tc = DC_VALID | ade; } // Marks the device context as invalid. diff --git a/drivers/src/iommu/error.rs b/drivers/src/iommu/error.rs index 22ae9726..7b47b3e6 100644 --- a/drivers/src/iommu/error.rs +++ b/drivers/src/iommu/error.rs @@ -23,6 +23,8 @@ pub enum Error { MissingGStageSupport, /// Missing required MSI translation support. MissingMsiSupport, + /// Missing A/D update support. + MissingAmoHwadSupport, /// Not enough pages were supplied to create an MSI page table. InsufficientMsiTablePages, /// The supplied MSI page table pages were not properly aligned. 
diff --git a/drivers/src/iommu/registers.rs b/drivers/src/iommu/registers.rs index ca974419..6d78ef6b 100644 --- a/drivers/src/iommu/registers.rs +++ b/drivers/src/iommu/registers.rs @@ -21,6 +21,7 @@ register_bitfields![u64, Sv57x4 OFFSET(19) NUMBITS(1), MsiFlat OFFSET(22) NUMBITS(1), MsiMrif OFFSET(23) NUMBITS(1), + AmoHwad OFFSET(24) NUMBITS(1), ], pub DirectoryPointer [ diff --git a/riscv-page-tables/BUILD b/riscv-page-tables/BUILD index 89e2d12d..e357b9cb 100644 --- a/riscv-page-tables/BUILD +++ b/riscv-page-tables/BUILD @@ -17,6 +17,10 @@ rust_library( rustc_flags = [ "-Ctarget-feature=+h", ], + crate_features = select({ + "//:hardware_ad_updates": ["hardware_ad_updates"], + "//conditions:default": [], + }), ) rust_clippy( diff --git a/riscv-page-tables/src/pte.rs b/riscv-page-tables/src/pte.rs index 141e3a03..4b65f594 100644 --- a/riscv-page-tables/src/pte.rs +++ b/riscv-page-tables/src/pte.rs @@ -189,13 +189,19 @@ impl PteFieldBits { pub fn leaf_with_perms(perms: PteLeafPerms) -> Self { let mut ret = Self::default(); ret.bits |= perms as u64; + + #[cfg(not(feature = "hardware_ad_updates"))] + { + ret.set_bit(PteFieldBit::Accessed); + ret.set_bit(PteFieldBit::Dirty); + } + ret } /// Creates a new status for a leaf entry with the given `perms`. pub fn user_leaf_with_perms(perms: PteLeafPerms) -> Self { - let mut ret = Self::default(); - ret.bits |= perms as u64; + let mut ret = Self::leaf_with_perms(perms); ret.set_bit(PteFieldBit::User); ret } diff --git a/src/main.rs b/src/main.rs index c13391f9..594fbc67 100644 --- a/src/main.rs +++ b/src/main.rs @@ -492,7 +492,7 @@ fn primary_init(hart_id: u64, fdt_addr: u64) -> Result { // We don't implement or use the SBI timer extension and thus require Sstc for timers. 
         return Err(Error::CpuMissingFeature(RequiredCpuFeature::Sstc));
     }
 
-    if !cpu_info.has_svadu() {
+    if cfg!(feature = "hardware_ad_updates") && !cpu_info.has_svadu() {
         // Salus assumes that hardware will update the accessed and dirty bits in the page table.
         // It can't handle faults for updating them.
         return Err(Error::CpuMissingFeature(RequiredCpuFeature::Svadu));

From 8f2dc0d598a40bfe30ffe511b14e44b46c082943 Mon Sep 17 00:00:00 2001
From: Mattias Nissler
Date: Tue, 3 Jun 2025 06:57:24 -0700
Subject: [PATCH 08/13] drivers/iommu: Probe device directory mode

IOMMU implementations are not required to support all device directory
modes. The way for software to determine whether a mode is supported is
to attempt to program the mode and read back the DDTP register to see
whether the mode value was accepted.

This change replaces the hard-coded 3-level mode with a probe loop to
attempt mode values. The loop tries modes in decreasing number of levels
to maximize the number of devices that can be managed in the table.
---
 drivers/src/iommu/core.rs             | 82 ++++++++++++++++++++++-----
 drivers/src/iommu/device_directory.rs | 33 ++---------
 drivers/src/iommu/error.rs            |  2 +
 drivers/src/iommu/mod.rs              |  2 +-
 4 files changed, 76 insertions(+), 43 deletions(-)

diff --git a/drivers/src/iommu/core.rs b/drivers/src/iommu/core.rs
index 3a0f4094..6b155c62 100644
--- a/drivers/src/iommu/core.rs
+++ b/drivers/src/iommu/core.rs
@@ -35,7 +35,7 @@ pub struct Iommu {
     _arena_id: PciArenaId,
     registers: &'static mut IommuRegisters,
     command_queue: Mutex,
-    ddt: DeviceDirectory,
+    ddt: DeviceDirectory,
     gscids: Mutex<[Option; MAX_GSCIDS]>,
 }
 
@@ -46,6 +46,34 @@ static IOMMU: Once = Once::new();
 const IOMMU_VENDOR_ID: u16 = 0x1efd;
 const IOMMU_DEVICE_ID: u16 = 0xedf1;
 
+// Suppress clippy warning about common suffix in favor of matching mode names as per IOMMU spec.
+#[allow(clippy::enum_variant_names)] +enum DirectoryMode { + OneLevel, + TwoLevel, + ThreeLevel, +} + +impl DirectoryMode { + fn id(&self) -> u64 { + use DirectoryMode::*; + match self { + OneLevel => 2, + TwoLevel => 3, + ThreeLevel => 4, + } + } + + fn num_levels(&self) -> usize { + use DirectoryMode::*; + match self { + OneLevel => 1, + TwoLevel => 2, + ThreeLevel => 3, + } + } +} + impl Iommu { /// Probes for and initializes the IOMMU device on the given PCI root. Uses `get_page` to /// allocate pages for IOMMU-internal structures. @@ -109,21 +137,40 @@ impl Iommu { } else { DeviceContextFormat::Base }; - let ddt = DeviceDirectory::new(get_page().ok_or(Error::OutOfPages)?, format); - for dev in pci.devices() { - let addr = dev.lock().info().address(); - if addr == iommu_addr { - // Skip the IOMMU itself. - continue; + + let ddt_root = get_page().ok_or(Error::OutOfPages)?; + let mut ddtp = LocalRegisterCopy::::new(0); + ddtp.modify(DirectoryPointer::Ppn.val(ddt_root.pfn().bits())); + + // Probe the directory mode to use. + let mode = [ + DirectoryMode::ThreeLevel, + DirectoryMode::TwoLevel, + DirectoryMode::OneLevel, + ] + .iter() + .find(|mode| { + ddtp.modify(DirectoryPointer::Mode.val(mode.id())); + registers.ddtp.set(ddtp.get()); + while registers.ddtp.is_set(DirectoryPointer::Busy) { + pause(); + } + registers.ddtp.read(DirectoryPointer::Mode) == mode.id() + }) + .ok_or(Error::DeviceDirectoryUnsupported)?; + + let ddt = DeviceDirectory::new(ddt_root, format, mode.num_levels()); + + for pci in PcieRoot::get_roots() { + for dev in pci.devices() { + let addr = dev.lock().info().address(); + if addr == iommu_addr { + // Skip the IOMMU itself. 
+ continue; + } + ddt.add_device(addr.try_into()?, get_page)?; } - ddt.add_device(addr.try_into()?, get_page)?; } - let mut ddtp = LocalRegisterCopy::::new(0); - ddtp.modify(DirectoryPointer::Ppn.val(ddt.base_address().pfn().bits())); - ddtp.modify(DirectoryPointer::Mode.val(Ddt3Level::IOMMU_MODE)); - // Ensure writes to the DDT have completed before we point the IOMMU at it. - mmio_wmb(); - registers.ddtp.set(ddtp.get()); let iommu = Iommu { _arena_id: arena_id, @@ -132,6 +179,13 @@ impl Iommu { ddt, gscids: Mutex::new([None; MAX_GSCIDS]), }; + + // Send a DDT invalidation command to make sure the IOMMU notices the added devices. + let commands = [Command::iodir_inval_ddt(None), Command::iofence()]; + // Unwrap ok: These are the first commands to the IOMMU, so 2 CQ entries will be + // available. + iommu.submit_commands_sync(&commands).unwrap(); + IOMMU.call_once(|| iommu); Ok(()) } diff --git a/drivers/src/iommu/device_directory.rs b/drivers/src/iommu/device_directory.rs index 8a7a9198..758b9532 100644 --- a/drivers/src/iommu/device_directory.rs +++ b/drivers/src/iommu/device_directory.rs @@ -465,24 +465,8 @@ impl DeviceDirectoryInner { } } -/// Defines the layout of the device directory table. Intermediate and leaf tables have the same -/// format regardless of the number of levels. -pub trait DirectoryMode { - /// The number of levels in the device directory hierarchy. - const LEVELS: usize; - /// The value that should be programmed into ddtp.iommu_mode for this translation mode. - const IOMMU_MODE: u64; -} - -/// A 3-level device directory table supporting up to 24-bit requester IDs. -pub enum Ddt3Level {} - -impl DirectoryMode for Ddt3Level { - const LEVELS: usize = 3; - const IOMMU_MODE: u64 = 4; -} - /// Indicates which device context format to use. +#[derive(Debug)] pub enum DeviceContextFormat { Base, Extended, @@ -490,30 +474,23 @@ pub enum DeviceContextFormat { /// Represents the device directory table for the IOMMU. 
The IOMMU hardware uses the DDT to map /// a requester ID to the translation context for the device. -pub struct DeviceDirectory { +pub struct DeviceDirectory { inner: Mutex, - phantom: PhantomData, } -impl DeviceDirectory { +impl DeviceDirectory { /// Creates a new `DeviceDirectory` using `root` as the root table page. - pub fn new(root: Page, format: DeviceContextFormat) -> Self { + pub fn new(root: Page, format: DeviceContextFormat, num_levels: usize) -> Self { let inner = DeviceDirectoryInner { root, - num_levels: D::LEVELS, + num_levels, format, }; Self { inner: Mutex::new(inner), - phantom: PhantomData, } } - /// Returns the base address of this `DeviceDirectory`. - pub fn base_address(&self) -> SupervisorPageAddr { - self.inner.lock().root.addr() - } - /// Returns whether this IOMMU instance supports MSI page tables. pub fn supports_msi_page_tables(&self) -> bool { matches!(self.inner.lock().format, DeviceContextFormat::Extended) diff --git a/drivers/src/iommu/error.rs b/drivers/src/iommu/error.rs index 7b47b3e6..50bfa505 100644 --- a/drivers/src/iommu/error.rs +++ b/drivers/src/iommu/error.rs @@ -69,6 +69,8 @@ pub enum Error { GscIdInUse(GscId), /// MSI translation not supported. MsiTranslationUnsupported, + /// No feasible device directory mode. + DeviceDirectoryUnsupported, } /// Holds results for IOMMU operations. 
diff --git a/drivers/src/iommu/mod.rs b/drivers/src/iommu/mod.rs index ca8a7532..dba9bceb 100644 --- a/drivers/src/iommu/mod.rs +++ b/drivers/src/iommu/mod.rs @@ -204,7 +204,7 @@ mod tests { let ddt_page = page_tracker .assign_page_for_internal_state(pages.pop().unwrap(), PageOwnerId::host()) .unwrap(); - let ddt = DeviceDirectory::::new(ddt_page, DeviceContextFormat::Extended); + let ddt = DeviceDirectory::new(ddt_page, DeviceContextFormat::Extended, 3); for i in 0..16 { let id = DeviceId::new(i).unwrap(); ddt.add_device(id, &mut || { From e2d9bbfbc45468c46b95e2f64112a785a5f84430 Mon Sep 17 00:00:00 2001 From: Mattias Nissler Date: Tue, 3 Jun 2025 09:40:38 -0700 Subject: [PATCH 09/13] drivers/pci: Augment PCI device with DT node ID When enumerating PCI devices, look up the corresponding device tree node and save its node ID in the PCI device information. This is for the benefit of drivers which may need to obtain information from the device tree node. --- device-tree/src/device_tree.rs | 2 ++ device-tree/src/lib.rs | 2 +- drivers/src/pci/bus.rs | 37 ++++++++++++++++++++++++++++++--- drivers/src/pci/device.rs | 38 +++++++++++++++++++++++++--------- drivers/src/pci/root.rs | 16 +++++++++++--- 5 files changed, 78 insertions(+), 17 deletions(-) diff --git a/device-tree/src/device_tree.rs b/device-tree/src/device_tree.rs index c8cbde34..6e94439a 100644 --- a/device-tree/src/device_tree.rs +++ b/device-tree/src/device_tree.rs @@ -46,6 +46,8 @@ pub struct DeviceTreeNode { } pub type NodeArena = Arena; + +/// An ID for a device tree node. The node itself can be obtained via `DeviceTree::get_node`. pub type NodeId = ArenaId; /// A tree representation of the hardware in a system based on v0.3 of the Devicetree Specification. 
diff --git a/device-tree/src/lib.rs b/device-tree/src/lib.rs index ae4b9904..ea017445 100644 --- a/device-tree/src/lib.rs +++ b/device-tree/src/lib.rs @@ -13,7 +13,7 @@ mod error; mod fdt; mod serialize; -pub use crate::device_tree::{DeviceTree, DeviceTreeIter, DeviceTreeNode}; +pub use crate::device_tree::{DeviceTree, DeviceTreeIter, DeviceTreeNode, NodeId}; pub use error::Error as DeviceTreeError; pub use error::Result as DeviceTreeResult; pub use fdt::{Cpu, Fdt, FdtMemoryRegion, ImsicInfo}; diff --git a/drivers/src/pci/bus.rs b/drivers/src/pci/bus.rs index 590ac5d4..c523decf 100644 --- a/drivers/src/pci/bus.rs +++ b/drivers/src/pci/bus.rs @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 use alloc::vec::Vec; +use device_tree::{DeviceTree, DeviceTreeNode, NodeId}; use sync::Mutex; use super::address::*; @@ -26,12 +27,24 @@ pub struct PciBus { devices: Vec, } +fn pci_address_from_dt_node( + config_space: &PciConfigSpace, + node: &DeviceTreeNode, +) -> Option
{ + let regs = node.props().find(|p| p.name() == "reg")?; + let config_offset = regs.value_u32().next()?; + let (addr, _) = config_space.offset_to_address((config_offset as usize >> 8) << 12)?; + Some(addr) +} + impl PciBus { /// Creates a `PciBus` by enumerating `bus_num` in `config_space`. Devices discovered while /// enumerating the bus are addeded to `device_arena`. pub fn enumerate( + dt: &DeviceTree, config_space: &PciConfigSpace, bus_num: Bus, + bus_node: Option, device_arena: &mut PciDeviceArena, ) -> Result { let bus_config = config_space @@ -43,11 +56,22 @@ impl PciBus { // Unwrap ok, if we have a header the config space for the corresponding function // must exist. let registers_ptr = config_space.registers_for(info.address()).unwrap(); + + // Locate the corresponding device tree node. + let dev_node = bus_node.and_then(|id| { + dt.get_node(id)?.children().copied().find(|&id| { + dt.get_node(id) + .filter(|node| !node.disabled()) + .and_then(|node| pci_address_from_dt_node(config_space, node)) + .is_some_and(|addr| addr == info.address()) + }) + }); + // Safety: We trust that PciConfigSpace returned a valid config space pointer for the // same device as the one referred to by info.address(). We guarantee that the created // device has unique ownership of the register space via the bus enumeration process // by creating at most one device per PCI address. - let pci_dev = unsafe { PciDevice::new(registers_ptr, info.clone()) }?; + let pci_dev = unsafe { PciDevice::new(registers_ptr, info.clone(), dev_node) }?; let id = device_arena .try_insert(Mutex::new(pci_dev)) .map_err(|_| Error::AllocError)?; @@ -65,8 +89,11 @@ impl PciBus { for bd in devices.iter() { let bridge_id = bd.id; let sec_bus = cur_bus.next().ok_or(Error::OutOfBuses)?; + // ID must be valid, we just added it above. 
- match *device_arena.get(bridge_id).unwrap().lock() { + let mut dev = device_arena.get(bridge_id).unwrap().lock(); + let dt_node = dev.dt_node(); + match *dev { PciDevice::Bridge(ref mut bridge) => { // Set the bridge to cover everything beyond sec_bus until we've enumerated // the buses behind the bridge. @@ -78,7 +105,11 @@ impl PciBus { _ => continue, }; - let child_bus = PciBus::enumerate(config_space, sec_bus, device_arena)?; + // Unlock `dev` and drop the mutable borrow of `device_arena` for the + // `PciBus::enumerate()` call. + drop(dev); + + let child_bus = PciBus::enumerate(dt, config_space, sec_bus, dt_node, device_arena)?; let sub_bus = child_bus.subordinate_bus_num(); // Avoid double mutable borrow of device_arena by re-acquiring the reference to the bridge diff --git a/drivers/src/pci/device.rs b/drivers/src/pci/device.rs index cd3c2d86..b5a11ba8 100644 --- a/drivers/src/pci/device.rs +++ b/drivers/src/pci/device.rs @@ -6,6 +6,7 @@ use arrayvec::ArrayVec; use core::fmt; use core::mem::size_of; use core::ptr::NonNull; +use device_tree::NodeId; use page_tracking::PageTracker; use riscv_pages::*; use tock_registers::interfaces::{ReadWriteable, Readable, Writeable}; @@ -373,6 +374,7 @@ impl PciDeviceBarInfo { // Common state between bridges and endpoints. struct PciDeviceCommon { info: PciDeviceInfo, + dt_node: Option, capabilities: PciCapabilities, bar_info: PciDeviceBarInfo, owner: Option, @@ -387,13 +389,18 @@ pub struct PciEndpoint { impl PciEndpoint { /// Creates a new `PciEndpoint` using the config space at `registers`. 
- fn new(registers: &'static mut EndpointRegisters, info: PciDeviceInfo) -> Result { + fn new( + registers: &'static mut EndpointRegisters, + info: PciDeviceInfo, + dt_node: Option, + ) -> Result { let capabilities = PciCapabilities::new(&mut registers.common, registers.cap_ptr.get() as usize)?; let bar_info = PciDeviceBarInfo::new(&mut registers.bar, capabilities.enhanced_allocation())?; let common = PciDeviceCommon { info, + dt_node, capabilities, bar_info, owner: None, @@ -436,9 +443,9 @@ impl PciEndpoint { // Discard BAR writes if the BAR is enabled. let io_enabled = self.registers.common.command.is_set(Command::IoEnable); let mem_enabled = self.registers.common.command.is_set(Command::MemoryEnable); - if let Some(bar_type) = self.common.bar_info.index_to_type(index) && - ((bar_type == PciResourceType::IoPort && io_enabled) || - (bar_type != PciResourceType::IoPort && mem_enabled)) + if let Some(bar_type) = self.common.bar_info.index_to_type(index) + && ((bar_type == PciResourceType::IoPort && io_enabled) + || (bar_type != PciResourceType::IoPort && mem_enabled)) { return; } @@ -466,7 +473,11 @@ pub struct PciBridge { impl PciBridge { /// Creates a new `PciBridge` use the config space at `registers`. Downstream buses are initially /// unenumerated. - fn new(registers: &'static mut BridgeRegisters, info: PciDeviceInfo) -> Result { + fn new( + registers: &'static mut BridgeRegisters, + info: PciDeviceInfo, + dt_node: Option, + ) -> Result { // Prevent config cycles from passing beyond this bridge until we're ready to enumreate. 
registers.sub_bus.set(0); registers.sec_bus.set(0); @@ -491,6 +502,7 @@ impl PciBridge { PciDeviceBarInfo::new(&mut registers.bar, capabilities.enhanced_allocation())?; let common = PciDeviceCommon { info, + dt_node, capabilities, bar_info, owner: None, @@ -611,9 +623,9 @@ impl PciBridge { let index = (op.offset() - bar::START_OFFSET) / size_of::(); let reg = op.pop_dword(self.registers.bar[index].get()); // Discard BAR writes if the BAR is enabled. - if let Some(bar_type) = self.common.bar_info.index_to_type(index) && - ((bar_type == PciResourceType::IoPort && io_enabled) || - (bar_type != PciResourceType::IoPort && mem_enabled)) + if let Some(bar_type) = self.common.bar_info.index_to_type(index) + && ((bar_type == PciResourceType::IoPort && io_enabled) + || (bar_type != PciResourceType::IoPort && mem_enabled)) { return; } @@ -832,16 +844,17 @@ impl PciDevice { pub(super) unsafe fn new( registers_ptr: NonNull, info: PciDeviceInfo, + dt_node: Option, ) -> Result { match info.header_type() { HeaderType::Endpoint => { let registers = registers_ptr.cast().as_mut(); - let ep = PciEndpoint::new(registers, info)?; + let ep = PciEndpoint::new(registers, info, dt_node)?; Ok(PciDevice::Endpoint(ep)) } HeaderType::PciBridge => { let registers = registers_ptr.cast().as_mut(); - let bridge = PciBridge::new(registers, info)?; + let bridge = PciBridge::new(registers, info, dt_node)?; Ok(PciDevice::Bridge(bridge)) } h => Err(Error::UnsupportedHeaderType(info.address(), h)), @@ -853,6 +866,11 @@ impl PciDevice { &self.common().info } + /// Returns the device tree node ID for this device, if present. + pub fn dt_node(&self) -> Option { + self.common().dt_node + } + /// Returns the `PciDeviceBarInfo` for this device. 
pub fn bar_info(&self) -> &PciDeviceBarInfo { &self.common().bar_info diff --git a/drivers/src/pci/root.rs b/drivers/src/pci/root.rs index bc2ef30e..8c68c6d9 100644 --- a/drivers/src/pci/root.rs +++ b/drivers/src/pci/root.rs @@ -60,7 +60,11 @@ fn valid_config_mmio_access(offset: u64, len: usize) -> bool { } impl PcieRoot { - fn probe_one(pci_node: &DeviceTreeNode, mem_map: &mut HwMemMap) -> Result { + fn probe_one( + dt: &DeviceTree, + pci_node: &DeviceTreeNode, + mem_map: &mut HwMemMap, + ) -> Result { // Find the ECAM MMIO region, which should be the first entry in the `reg` property. let mut regs = pci_node .props() @@ -185,7 +189,13 @@ impl PcieRoot { // Enumerate the PCI hierarchy. let mut device_arena = PciDeviceArena::new(Global); - let root_bus = PciBus::enumerate(&config_space, bus_range.start, &mut device_arena)?; + let root_bus = PciBus::enumerate( + dt, + &config_space, + bus_range.start, + Some(pci_node.id()), + &mut device_arena, + )?; Ok(Self { config_space, @@ -201,7 +211,7 @@ impl PcieRoot { let roots = dt .iter() .filter(|n| n.compatible(["pci-host-ecam-generic"]) && !n.disabled()) - .map(|n| Self::probe_one(n, mem_map)) + .map(|n| Self::probe_one(dt, n, mem_map)) .collect::>>()?; PCIE_ROOTS.call_once(|| roots); From 1d096aecc6cf35153d9b0e406e0263ac2659b37a Mon Sep 17 00:00:00 2001 From: Mattias Nissler Date: Tue, 3 Jun 2025 14:08:59 -0700 Subject: [PATCH 10/13] drivers/pci: Record IOMMU specifier for enumerated devices Inspect the "iommu-map" device tree property to determine which IOMMU and device identifier to use for devices on a PCI bus. The resulting IOMMU specifier is stored in device information for later use when attaching devices to the IOMMU. 
--- drivers/src/pci/bus.rs | 21 +++++++++++++++++---- drivers/src/pci/device.rs | 39 +++++++++++++++++++++++++++++++++++++-- drivers/src/pci/root.rs | 26 ++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 6 deletions(-) diff --git a/drivers/src/pci/bus.rs b/drivers/src/pci/bus.rs index c523decf..699657bc 100644 --- a/drivers/src/pci/bus.rs +++ b/drivers/src/pci/bus.rs @@ -8,7 +8,7 @@ use sync::Mutex; use super::address::*; use super::config_space::PciConfigSpace; -use super::device::PciDevice; +use super::device::{IommuSpecifier, PciDevice}; use super::error::*; use super::root::{PciArenaId, PciDeviceArena}; @@ -40,11 +40,12 @@ fn pci_address_from_dt_node( impl PciBus { /// Creates a `PciBus` by enumerating `bus_num` in `config_space`. Devices discovered while /// enumerating the bus are addeded to `device_arena`. - pub fn enumerate( + pub fn enumerate Option>( dt: &DeviceTree, config_space: &PciConfigSpace, bus_num: Bus, bus_node: Option, + build_iommu_specifier: &F, device_arena: &mut PciDeviceArena, ) -> Result { let bus_config = config_space @@ -67,11 +68,16 @@ impl PciBus { }) }); + // Determine the IOMMU specifier for this device. + let iommu_specifier = build_iommu_specifier(info.address()); + // Safety: We trust that PciConfigSpace returned a valid config space pointer for the // same device as the one referred to by info.address(). We guarantee that the created // device has unique ownership of the register space via the bus enumeration process // by creating at most one device per PCI address. - let pci_dev = unsafe { PciDevice::new(registers_ptr, info.clone(), dev_node) }?; + let pci_dev = unsafe { + PciDevice::new(registers_ptr, info.clone(), dev_node, iommu_specifier) + }?; let id = device_arena .try_insert(Mutex::new(pci_dev)) .map_err(|_| Error::AllocError)?; @@ -109,7 +115,14 @@ impl PciBus { // `PciBus::enumerate()` call. 
drop(dev); - let child_bus = PciBus::enumerate(dt, config_space, sec_bus, dt_node, device_arena)?; + let child_bus = PciBus::enumerate( + dt, + config_space, + sec_bus, + dt_node, + build_iommu_specifier, + device_arena, + )?; let sub_bus = child_bus.subordinate_bus_num(); // Avoid double mutable borrow of device_arena by re-acquiring the reference to the bridge diff --git a/drivers/src/pci/device.rs b/drivers/src/pci/device.rs index b5a11ba8..9e35fe70 100644 --- a/drivers/src/pci/device.rs +++ b/drivers/src/pci/device.rs @@ -13,6 +13,8 @@ use tock_registers::interfaces::{ReadWriteable, Readable, Writeable}; use tock_registers::registers::ReadWrite; use tock_registers::LocalRegisterCopy; +use crate::iommu::DeviceId as IommuDeviceId; + use super::address::*; use super::bus::PciBus; use super::capabilities::*; @@ -371,12 +373,34 @@ impl PciDeviceBarInfo { } } +/// IOMMU specifier for a PCI device. This indicates the IOMMU the device is behind and the device +/// identifier used by the IOMMU to distinguish the device. +pub struct IommuSpecifier { + phandle: u32, + dev_id: IommuDeviceId, +} + +impl IommuSpecifier { + pub fn new(phandle: u32, dev_id: IommuDeviceId) -> Self { + Self { phandle, dev_id } + } + + pub fn iommu_phandle(&self) -> u32 { + self.phandle + } + + pub fn iommu_dev_id(&self) -> IommuDeviceId { + self.dev_id + } +} + // Common state between bridges and endpoints. 
struct PciDeviceCommon { info: PciDeviceInfo, dt_node: Option, capabilities: PciCapabilities, bar_info: PciDeviceBarInfo, + iommu_specifier: Option, owner: Option, iommu_attached: bool, } @@ -393,6 +417,7 @@ impl PciEndpoint { registers: &'static mut EndpointRegisters, info: PciDeviceInfo, dt_node: Option, + iommu_specifier: Option, ) -> Result { let capabilities = PciCapabilities::new(&mut registers.common, registers.cap_ptr.get() as usize)?; @@ -403,6 +428,7 @@ impl PciEndpoint { dt_node, capabilities, bar_info, + iommu_specifier, owner: None, iommu_attached: false, }; @@ -477,6 +503,7 @@ impl PciBridge { registers: &'static mut BridgeRegisters, info: PciDeviceInfo, dt_node: Option, + iommu_specifier: Option, ) -> Result { // Prevent config cycles from passing beyond this bridge until we're ready to enumreate. registers.sub_bus.set(0); @@ -505,6 +532,7 @@ impl PciBridge { dt_node, capabilities, bar_info, + iommu_specifier, owner: None, iommu_attached: false, }; @@ -845,16 +873,17 @@ impl PciDevice { registers_ptr: NonNull, info: PciDeviceInfo, dt_node: Option, + iommu_specifier: Option, ) -> Result { match info.header_type() { HeaderType::Endpoint => { let registers = registers_ptr.cast().as_mut(); - let ep = PciEndpoint::new(registers, info, dt_node)?; + let ep = PciEndpoint::new(registers, info, dt_node, iommu_specifier)?; Ok(PciDevice::Endpoint(ep)) } HeaderType::PciBridge => { let registers = registers_ptr.cast().as_mut(); - let bridge = PciBridge::new(registers, info, dt_node)?; + let bridge = PciBridge::new(registers, info, dt_node, iommu_specifier)?; Ok(PciDevice::Bridge(bridge)) } h => Err(Error::UnsupportedHeaderType(info.address(), h)), @@ -891,6 +920,12 @@ impl PciDevice { self.common().capabilities.is_pcie() } + /// Returns the IOMMU specifier for this device, indicating which IOMMU it is behind and the + /// device identifier used by the IOMMU to distinguish this device. 
+    pub fn iommu_specifier(&self) -> Option<&IommuSpecifier> {
+        self.common().iommu_specifier.as_ref()
+    }
+
     /// Returns the device's owner.
     pub fn owner(&self) -> Option {
         self.common().owner
diff --git a/drivers/src/pci/root.rs b/drivers/src/pci/root.rs
index 8c68c6d9..87daef19 100644
--- a/drivers/src/pci/root.rs
+++ b/drivers/src/pci/root.rs
@@ -13,6 +13,7 @@ use riscv_pages::*;
 use sync::{Mutex, Once};
 
 use crate::imsic::Imsic;
+use crate::iommu::DeviceId as IommuDeviceId;
 
 use super::address::*;
 use super::bus::PciBus;
@@ -187,6 +188,30 @@ impl PcieRoot {
             }
         }
 
+        // Obtain IOMMU mapping information from the respective DT properties.
+        let iommu_map_mask = pci_node
+            .props()
+            .find(|p| p.name() == "iommu-map-mask")
+            .and_then(|prop| prop.value_u32().next())
+            .unwrap_or(!0u32);
+        let iommu_map = pci_node.props().find(|p| p.name() == "iommu-map");
+
+        let build_iommu_specifier = |addr: Address| {
+            let rid = addr.bits() & iommu_map_mask;
+            let mut mapping = iommu_map?.value_u32();
+            loop {
+                let rid_base = mapping.next()?;
+                let phandle = mapping.next()?;
+                let iommu_base = mapping.next()?;
+                let len = mapping.next()?;
+
+                if rid >= rid_base && rid - rid_base < len {
+                    let dev_id = IommuDeviceId::new(rid - rid_base + iommu_base)?;
+                    return Some(IommuSpecifier::new(phandle, dev_id));
+                }
+            }
+        };
+
         // Enumerate the PCI hierarchy.
         let mut device_arena = PciDeviceArena::new(Global);
         let root_bus = PciBus::enumerate(
@@ -194,6 +219,7 @@ impl PcieRoot {
             &config_space,
             bus_range.start,
             Some(pci_node.id()),
+            &build_iommu_specifier,
             &mut device_arena,
         )?;

From 9376e61f50efeac0e1c66cb6fb5cc8f69274b8d9 Mon Sep 17 00:00:00 2001
From: Mattias Nissler
Date: Wed, 4 Jun 2025 01:23:18 -0700
Subject: [PATCH 11/13] drivers/iommu: Break out GSCID allocation

In preparation for operating multiple IOMMUs, break out the GSCID
allocation to be backed by a dedicated global allocation table.
This change just moves the existing code around, but there's probably an opportunity here to switch to an alternative API that hands out ref-counted RAII handles representing allocated GSCIDs. --- drivers/src/iommu/core.rs | 56 ++------------------ drivers/src/iommu/device_directory.rs | 17 +----- drivers/src/iommu/error.rs | 3 +- drivers/src/iommu/gscid.rs | 76 +++++++++++++++++++++++++++ drivers/src/iommu/mod.rs | 4 +- drivers/src/iommu/queue.rs | 3 +- src/vm_pages.rs | 7 +-- 7 files changed, 89 insertions(+), 77 deletions(-) create mode 100644 drivers/src/iommu/gscid.rs diff --git a/drivers/src/iommu/core.rs b/drivers/src/iommu/core.rs index 6b155c62..61753b90 100644 --- a/drivers/src/iommu/core.rs +++ b/drivers/src/iommu/core.rs @@ -11,32 +11,18 @@ use tock_registers::LocalRegisterCopy; use super::device_directory::*; use super::error::{Error, Result}; +use super::gscid::{GscId, GSCIDS}; use super::msi_page_table::MsiPageTable; use super::queue::*; use super::registers::*; use crate::pci::{self, PciArenaId, PciDevice, PcieRoot}; -// Tracks the state of an allocated global soft-context ID (GSCID). -#[derive(Clone, Copy, Debug)] -struct GscIdState { - owner: PageOwnerId, - ref_count: usize, -} - -// We use a fixed-sized array to track available GSCIDs. We can't use a versioning scheme like we -// would for CPU VMIDs since reassigning GSCIDs on overflow would require us to temporarily disable -// DMA from all devices, which is extremely disruptive. Set a max of 64 allocated GSCIDs for now -// since it's unlikely we'll have more than that number of active VMs with assigned devices for -// the time being. -const MAX_GSCIDS: usize = 64; - /// IOMMU device. Responsible for managing address translation for PCI devices. pub struct Iommu { _arena_id: PciArenaId, registers: &'static mut IommuRegisters, command_queue: Mutex, ddt: DeviceDirectory, - gscids: Mutex<[Option; MAX_GSCIDS]>, } // The global IOMMU singleton. 
@@ -177,7 +163,6 @@ impl Iommu { registers, command_queue: Mutex::new(command_queue), ddt, - gscids: Mutex::new([None; MAX_GSCIDS]), }; // Send a DDT invalidation command to make sure the IOMMU notices the added devices. @@ -205,41 +190,6 @@ impl Iommu { self.ddt.supports_msi_page_tables() } - /// Allocates a new GSCID for `owner`. - pub fn alloc_gscid(&self, owner: PageOwnerId) -> Result { - let mut gscids = self.gscids.lock(); - let next = gscids - .iter() - .position(|g| g.is_none()) - .ok_or(Error::OutOfGscIds)?; - let state = GscIdState { - owner, - ref_count: 0, - }; - gscids[next] = Some(state); - Ok(GscId::new(next as u16)) - } - - /// Releases `gscid`, which must not be in use in any active device contexts. - pub fn free_gscid(&self, gscid: GscId) -> Result<()> { - let mut gscids = self.gscids.lock(); - let state = gscids - .get_mut(gscid.bits() as usize) - .ok_or(Error::InvalidGscId(gscid))?; - match state { - Some(s) if s.ref_count > 0 => { - return Err(Error::GscIdInUse(gscid)); - } - None => { - return Err(Error::GscIdAlreadyFree(gscid)); - } - _ => { - *state = None; - } - } - Ok(()) - } - /// Enables DMA for the given PCI device, using `pt` for 2nd-stage and `msi_pt` for MSI /// translation. pub fn attach_pci_device( @@ -252,7 +202,7 @@ impl Iommu { let dev_id = DeviceId::try_from(dev.info().address())?; // Make sure the GSCID is valid and that it matches up with the device and page table // owner. - let mut gscids = self.gscids.lock(); + let mut gscids = GSCIDS.lock(); let state = gscids .get_mut(gscid.bits() as usize) .and_then(|g| g.as_mut()) @@ -274,7 +224,7 @@ impl Iommu { let dev_id = DeviceId::try_from(dev.info().address())?; { // Verify that the GSCID is valid and that it matches up with the device owner. 
- let mut gscids = self.gscids.lock(); + let mut gscids = GSCIDS.lock(); let state = gscids .get_mut(gscid.bits() as usize) .and_then(|g| g.as_mut()) diff --git a/drivers/src/iommu/device_directory.rs b/drivers/src/iommu/device_directory.rs index 758b9532..9f9f0e73 100644 --- a/drivers/src/iommu/device_directory.rs +++ b/drivers/src/iommu/device_directory.rs @@ -11,6 +11,7 @@ use static_assertions::const_assert; use sync::Mutex; use super::error::*; +use super::gscid::GscId; use super::msi_page_table::MsiPageTable; use crate::pci::Address; @@ -46,22 +47,6 @@ impl TryFrom
for DeviceId { } } -/// Global Soft-Context ID. The equivalent of hgatp.VMID, but always 16 bits. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub struct GscId(u16); - -impl GscId { - /// Creates a `GscId` from the raw `id`. - pub(super) fn new(id: u16) -> Self { - GscId(id) - } - - /// Returns the raw bits of this `GscId`. - pub fn bits(&self) -> u16 { - self.0 - } -} - // Defines the translation context for a device. A valid device context enables translation for // DMAs from the corresponding device according to the tables programmed into the device context. // This is the base format, used when capabilities.MSI_FLAT isn't set. diff --git a/drivers/src/iommu/error.rs b/drivers/src/iommu/error.rs index 50bfa505..429bc6ba 100644 --- a/drivers/src/iommu/error.rs +++ b/drivers/src/iommu/error.rs @@ -4,7 +4,8 @@ use riscv_pages::SupervisorPageAddr; -use super::device_directory::{DeviceId, GscId}; +use super::device_directory::DeviceId; +use super::gscid::GscId; use crate::imsic::ImsicLocation; use crate::pci::{Address, PciError}; diff --git a/drivers/src/iommu/gscid.rs b/drivers/src/iommu/gscid.rs new file mode 100644 index 00000000..b75b3678 --- /dev/null +++ b/drivers/src/iommu/gscid.rs @@ -0,0 +1,76 @@ +// SPDX-FileCopyrightText: 2025 Rivos Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +use riscv_pages::*; +use sync::Mutex; + +use super::error::*; + +/// Global Soft-Context ID. The equivalent of hgatp.VMID, but always 16 bits. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct GscId(u16); + +impl GscId { + /// Creates a `GscId` from the raw `id`. + pub(super) fn new(id: u16) -> Self { + GscId(id) + } + + /// Returns the raw bits of this `GscId`. + pub fn bits(&self) -> u16 { + self.0 + } +} + +// Tracks the state of an allocated global soft-context ID (GSCID). +#[derive(Clone, Copy, Debug)] +pub(super) struct GscIdState { + pub(super) owner: PageOwnerId, + pub(super) ref_count: usize, +} + +// We use a fixed-sized array to track available GSCIDs. 
We can't use a versioning scheme like we +// would for CPU VMIDs since reassigning GSCIDs on overflow would require us to temporarily disable +// DMA from all devices, which is extremely disruptive. Set a max of 64 allocated GSCIDs for now +// since it's unlikely we'll have more than that number of active VMs with assigned devices for +// the time being. +const MAX_GSCIDS: usize = 64; + +// The global GSCID allocation table. +pub(super) static GSCIDS: Mutex<[Option; MAX_GSCIDS]> = Mutex::new([None; MAX_GSCIDS]); + +/// Allocates a new GSCID for `owner`. +pub fn alloc_gscid(owner: PageOwnerId) -> Result { + let mut gscids = GSCIDS.lock(); + let next = gscids + .iter() + .position(|g| g.is_none()) + .ok_or(Error::OutOfGscIds)?; + let state = GscIdState { + owner, + ref_count: 0, + }; + gscids[next] = Some(state); + Ok(GscId::new(next as u16)) +} + +/// Releases `gscid`, which must not be in use in any active device contexts. +pub fn free_gscid(gscid: GscId) -> Result<()> { + let mut gscids = GSCIDS.lock(); + let state = gscids + .get_mut(gscid.bits() as usize) + .ok_or(Error::InvalidGscId(gscid))?; + match state { + Some(s) if s.ref_count > 0 => { + return Err(Error::GscIdInUse(gscid)); + } + None => { + return Err(Error::GscIdAlreadyFree(gscid)); + } + _ => { + *state = None; + } + } + Ok(()) +} diff --git a/drivers/src/iommu/mod.rs b/drivers/src/iommu/mod.rs index dba9bceb..d2bd33df 100644 --- a/drivers/src/iommu/mod.rs +++ b/drivers/src/iommu/mod.rs @@ -5,14 +5,16 @@ mod core; mod device_directory; mod error; +mod gscid; mod msi_page_table; mod queue; mod registers; pub use self::core::Iommu; -pub use device_directory::{DeviceId, GscId}; +pub use device_directory::DeviceId; pub use error::Error as IommuError; pub use error::Result as IommuResult; +pub use gscid::{alloc_gscid, free_gscid, GscId}; pub use msi_page_table::MsiPageTable; #[cfg(test)] diff --git a/drivers/src/iommu/queue.rs b/drivers/src/iommu/queue.rs index f1bad014..52661e25 100644 --- 
a/drivers/src/iommu/queue.rs +++ b/drivers/src/iommu/queue.rs @@ -7,8 +7,9 @@ use core::mem::size_of; use data_model::{DataInit, VolatileMemory, VolatileSlice}; use riscv_pages::*; -use super::device_directory::{DeviceId, GscId}; +use super::device_directory::DeviceId; use super::error::*; +use super::gscid::GscId; /// Type marker for a queue where software is the producer. pub enum Producer {} diff --git a/src/vm_pages.rs b/src/vm_pages.rs index 57c48230..f1eb04a0 100644 --- a/src/vm_pages.rs +++ b/src/vm_pages.rs @@ -963,10 +963,7 @@ pub struct VmIommuContext { impl VmIommuContext { // Creates a new `VmIommuContext` using `msi_page_table`. fn new(msi_page_table: MsiPageTable) -> Result { - let gscid = Iommu::get() - .ok_or(Error::NoIommu)? - .alloc_gscid(msi_page_table.owner()) - .map_err(Error::AllocatingGscId)?; + let gscid = alloc_gscid(msi_page_table.owner()).map_err(Error::AllocatingGscId)?; Ok(Self { msi_page_table, gscid, @@ -997,7 +994,7 @@ impl Drop for VmIommuContext { // Unwrap ok: `self.gscid` must be valid and freeable since we've detached all devices // using it. - iommu.free_gscid(self.gscid).unwrap(); + free_gscid(self.gscid).unwrap(); } } From 8818046c5900c8ea50bcb66ad2647cf687c08c47 Mon Sep 17 00:00:00 2001 From: Mattias Nissler Date: Wed, 4 Jun 2025 03:27:42 -0700 Subject: [PATCH 12/13] Support multiple IOMMUs It is perfectly legal for systems to employ multiple IOMMUs, each handling a subset of PCI devices. Thus, change the probing code to discover all IOMMU devices across all PCI roots. Each device is then assigned to its corresponding IOMMU via the IOMMU phandle from device tree. 
--- drivers/src/iommu/core.rs | 138 ++++++++++++++++++++++++++----------- drivers/src/iommu/error.rs | 6 ++ drivers/src/pci/root.rs | 23 ++----- src/host_vm.rs | 17 ++--- src/main.rs | 40 +++++------ src/vm_pages.rs | 19 +++-- 6 files changed, 142 insertions(+), 101 deletions(-) diff --git a/drivers/src/iommu/core.rs b/drivers/src/iommu/core.rs index 61753b90..37adbc48 100644 --- a/drivers/src/iommu/core.rs +++ b/drivers/src/iommu/core.rs @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 +use device_tree::DeviceTree; use riscv_page_tables::{GuestStagePageTable, GuestStagePagingMode}; use riscv_pages::*; use riscv_regs::{mmio_wmb, pause}; @@ -15,18 +16,18 @@ use super::gscid::{GscId, GSCIDS}; use super::msi_page_table::MsiPageTable; use super::queue::*; use super::registers::*; -use crate::pci::{self, PciArenaId, PciDevice, PcieRoot}; +use crate::pci::{PciDevice, PcieRoot}; /// IOMMU device. Responsible for managing address translation for PCI devices. pub struct Iommu { - _arena_id: PciArenaId, registers: &'static mut IommuRegisters, command_queue: Mutex, ddt: DeviceDirectory, + phandle: Option, } -// The global IOMMU singleton. -static IOMMU: Once = Once::new(); +// The global list of IOMMUs. +static IOMMUS: [Once; 8] = [Once::INIT; 8]; // Identifiers from the QEMU RFC implementation. const IOMMU_VENDOR_ID: u16 = 0x1efd; @@ -61,28 +62,38 @@ impl DirectoryMode { } impl Iommu { - /// Probes for and initializes the IOMMU device on the given PCI root. Uses `get_page` to - /// allocate pages for IOMMU-internal structures. - pub fn probe_from( + /// Probes for and initializes the given IOMMU device. Uses `get_page` to allocate pages for + /// IOMMU-internal structures. 
+ pub fn probe( + dt: &DeviceTree, pci: &PcieRoot, + dev: &Mutex, get_page: &mut dyn FnMut() -> Option>, - ) -> Result<()> { - let arena_id = pci - .take_and_enable_hypervisor_device( - pci::VendorId::new(IOMMU_VENDOR_ID), - pci::DeviceId::new(IOMMU_DEVICE_ID), - ) + ) -> Result<&'static Iommu> { + let mut dev = dev.lock(); + + if dev.info().vendor_id().bits() != IOMMU_VENDOR_ID + || dev.info().device_id().bits() != IOMMU_DEVICE_ID + { + return Err(Error::NotAnIommu); + } + + pci.take_and_enable_hypervisor_device(&mut dev) .map_err(Error::ProbingIommu)?; - let (iommu_addr, regs_base, regs_size) = { - let dev = pci.get_device(arena_id).unwrap().lock(); - // IOMMU registers are in BAR0. - let bar = dev.bar_info().get(0).ok_or(Error::MissingRegisters)?; - // Unwrap ok: we've already determined BAR0 is valid. - let pci_addr = dev.get_bar_addr(0).unwrap(); - let regs_base = pci.pci_to_physical_addr(pci_addr).unwrap(); - let regs_size = bar.size(); - (dev.info().address(), regs_base, regs_size) - }; + + // IOMMU registers are in BAR0. + let bar = dev.bar_info().get(0).ok_or(Error::MissingRegisters)?; + // Unwrap ok: we've already determined BAR0 is valid. + let pci_addr = dev.get_bar_addr(0).unwrap(); + let regs_base = pci.pci_to_physical_addr(pci_addr).unwrap(); + let regs_size = bar.size(); + let dt_node_id = dev.dt_node(); + + // We're done with inspecting `dev`, so unlock the mutex. It's not only good practice to + // keep the locked section small, but we'll be iterating all devices when checking which to + // add to this IOMMU, which would attempt to acquire the same lock again. + drop(dev); + if regs_size < core::mem::size_of::() as u64 { return Err(Error::InvalidRegisterSize(regs_size)); } @@ -147,37 +158,72 @@ impl Iommu { let ddt = DeviceDirectory::new(ddt_root, format, mode.num_levels()); - for pci in PcieRoot::get_roots() { - for dev in pci.devices() { - let addr = dev.lock().info().address(); - if addr == iommu_addr { - // Skip the IOMMU itself. 
- continue; + let phandle = dt_node_id.and_then(|id| dt.get_node(id)).and_then(|node| { + node.props() + .find(|p| p.name() == "phandle") + .and_then(|p| p.value_u32().next()) + }); + + // Add devices assigned to this IOMMU to the ddt. + if let Some(phandle) = phandle { + for pci in PcieRoot::get_roots() { + for dev in pci.devices() { + let dev = dev.lock(); + if let Some(spec) = dev.iommu_specifier() + && spec.iommu_phandle() == phandle + { + ddt.add_device(spec.iommu_dev_id(), get_page).unwrap(); + } } - ddt.add_device(addr.try_into()?, get_page)?; } } let iommu = Iommu { - _arena_id: arena_id, registers, command_queue: Mutex::new(command_queue), ddt, + phandle, }; - // Send a DDT invalidation command to make sure the IOMMU notices the added devices. + // Send a DDT invalidation command to make sure the IOMMU notices any added devices. let commands = [Command::iodir_inval_ddt(None), Command::iofence()]; - // Unwrap ok: These are the first commands to the IOMMU, so 2 CQ entries will be - // available. + // Unwrap ok: These are the first commands to the IOMMU, so 2 CQ entries will be available. iommu.submit_commands_sync(&commands).unwrap(); - IOMMU.call_once(|| iommu); - Ok(()) + // Store the iommu object in a slot in `IOMMUS`. We try slots in order until we find one + // that initializes successfully. + let mut iommu = Some(iommu); + for slot in IOMMUS.iter() { + assert!(iommu.is_some()); + + // Note that `Once::call_once()` guarantees to only invoke the closure when it is the + // first call. The closure holds a mutable reference to `iommu`, so it will only move + // the object out of the option when the slot gets initialized successfully. We break + // the loop once that happens, and this maintains the loop invariant `iommu.is_some()` + // which is why the `unwrap` call in the closure is OK. + slot.call_once(|| iommu.take().unwrap()); + if iommu.is_none() { + // Unwrap OK: We just wrote the slot. 
+ return Ok(slot.get().unwrap()); + } + } + + Err(Error::TooManyIommus) + } + + /// Iterates all probed `Iommu`s in the system. + pub fn get_iommus() -> impl Iterator { + IOMMUS.iter().map_while(|slot| slot.get()) + } + + /// Gets the IOMMU matching the given phandle. + pub fn get_by_phandle(phandle: u32) -> Option<&'static Iommu> { + Iommu::get_iommus().find(|iommu| iommu.phandle == Some(phandle)) } - /// Gets a reference to the `Iommu` singleton. - pub fn get() -> Option<&'static Self> { - IOMMU.get() + /// Gets the IOMMU for the given PciDevice. + pub fn get_for_device(dev: &PciDevice) -> Option<&'static Iommu> { + Self::get_by_phandle(dev.iommu_specifier()?.iommu_phandle()) } /// Returns the version of this IOMMU device. @@ -199,7 +245,12 @@ impl Iommu { msi_pt: Option<&MsiPageTable>, gscid: GscId, ) -> Result<()> { - let dev_id = DeviceId::try_from(dev.info().address())?; + let dev_id = dev + .iommu_specifier() + .filter(|spec| Some(spec.iommu_phandle()) == self.phandle) + .map(|spec| spec.iommu_dev_id()) + .ok_or(Error::IommuMismatch)?; + // Make sure the GSCID is valid and that it matches up with the device and page table // owner. let mut gscids = GSCIDS.lock(); @@ -221,7 +272,12 @@ impl Iommu { /// Disables DMA translation for the given PCI device. pub fn detach_pci_device(&self, dev: &mut PciDevice, gscid: GscId) -> Result<()> { - let dev_id = DeviceId::try_from(dev.info().address())?; + let dev_id = dev + .iommu_specifier() + .filter(|spec| Some(spec.iommu_phandle()) == self.phandle) + .map(|spec| spec.iommu_dev_id()) + .ok_or(Error::IommuMismatch)?; + { // Verify that the GSCID is valid and that it matches up with the device owner. let mut gscids = GSCIDS.lock(); diff --git a/drivers/src/iommu/error.rs b/drivers/src/iommu/error.rs index 429bc6ba..73414fb6 100644 --- a/drivers/src/iommu/error.rs +++ b/drivers/src/iommu/error.rs @@ -12,6 +12,10 @@ use crate::pci::{Address, PciError}; /// Errors resulting from interacting with the IOMMU. 
#[derive(Clone, Copy, Debug)] pub enum Error { + /// The device doesn't identify as an IOMMU. + NotAnIommu, + /// There are more than the maximum number of IOMMUs supported by the code. + TooManyIommus, /// Error encountered while probing and enabling the IOMMU PCI device. ProbingIommu(PciError), /// Couldn't find the IOMMU registers BAR. @@ -48,6 +52,8 @@ pub enum Error { PciAddressTooLarge(Address), /// Mismatch between page table and device ownership. OwnerMismatch, + /// Device isn't managed by this IOMMU. + IommuMismatch, /// No device context found. DeviceNotFound(DeviceId), /// The device already has an active device context. diff --git a/drivers/src/pci/root.rs b/drivers/src/pci/root.rs index 87daef19..16e8b3bf 100644 --- a/drivers/src/pci/root.rs +++ b/drivers/src/pci/root.rs @@ -322,26 +322,17 @@ impl PcieRoot { /// Takes ownership over the PCI device with the given `vendor_id` and `device_id`, and enables /// it for use within the hypervisor by assigning it resources. Returns a `PciDeviceId` which /// can be used to retrieve a reference to the device on success. - pub fn take_and_enable_hypervisor_device( - &self, - vendor_id: VendorId, - device_id: DeviceId, - ) -> Result { - let dev_id = self - .device_arena - .ids() - .find(|&id| { - let d = self.device_arena.get(id).unwrap().lock(); - d.info().vendor_id() == vendor_id && d.info().device_id() == device_id - }) - .ok_or(Error::DeviceNotFound)?; + pub fn take_and_enable_hypervisor_device(&self, dev: &mut PciDevice) -> Result<()> { // Make sure the device is on the root bus. We don't support distributing resources behind // bridges. - if !self.root_bus.devices().any(|bd| bd.id == dev_id) { + if !self + .root_bus + .devices() + .any(|bd| bd.address == dev.info().address()) + { return Err(Error::DeviceNotOnRootBus); } - let mut dev = self.device_arena.get(dev_id).unwrap().lock(); dev.take(PageOwnerId::hypervisor())?; // Now assign BAR resources. 
let bar_info = dev.bar_info().clone(); @@ -392,7 +383,7 @@ impl PcieRoot { } dev.enable_dma(); - Ok(dev_id) + Ok(()) } /// Adds a node for this PCIe root complex to the host's device tree in `dt`. It's assumed that diff --git a/src/host_vm.rs b/src/host_vm.rs index 194541de..f445006d 100644 --- a/src/host_vm.rs +++ b/src/host_vm.rs @@ -276,15 +276,12 @@ impl HostVmLoader { let pages = pci.take_host_resource(res_type).unwrap(); self.vm.add_pci_pages(gpa, pages); } - // Attach our PCI devices to the IOMMU. - if Iommu::get().is_some() { - for dev in pci.devices() { - let mut dev = dev.lock(); - if dev.owner() == Some(PageOwnerId::host()) { - // Silence buggy clippy warning. - #[allow(clippy::explicit_auto_deref)] - self.vm.attach_pci_device(&mut *dev); - } + // Attach our PCI devices to their respective IOMMU. + for dev in pci.devices() { + let mut dev = dev.lock(); + if dev.owner() == Some(PageOwnerId::host()) && Iommu::get_for_device(&dev).is_some() + { + self.vm.attach_pci_device(&mut dev); } } } @@ -642,7 +639,7 @@ impl HostVm { let imsic_geometry = Imsic::get().host_vm_geometry(); // Reserve MSI page table pages if we have an IOMMU. 
- let msi_table_pages = Iommu::get().map(|_| { + let msi_table_pages = Iommu::get_iommus().next().map(|_| { let msi_table_size = MsiPageTable::required_table_size(&imsic_geometry); hyp_mem.take_pages_for_host_state_with_alignment( PageSize::num_4k_pages(msi_table_size) as usize, diff --git a/src/main.rs b/src/main.rs index 594fbc67..c571c21c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -44,8 +44,8 @@ mod vm_pmu; use backtrace::backtrace; use device_tree::{DeviceTree, DeviceTreeError, Fdt}; use drivers::{ - imsic::Imsic, iommu::Iommu, iommu::IommuError, pci::PciError, pci::PcieRoot, pmu::PmuInfo, - reset::ResetDriver, uart::UartDriver, CpuInfo, + imsic::Imsic, iommu::Iommu, iommu::IommuError, pci::PcieRoot, pmu::PmuInfo, reset::ResetDriver, + uart::UartDriver, CpuInfo, }; use host_vm::{HostVm, HostVmLoader, HOST_VM_ALIGN}; use hyp_alloc::HypAlloc; @@ -629,28 +629,22 @@ fn primary_init(hart_id: u64, fdt_addr: u64) -> Result { // Set up per-CPU memory and prepare the structures for secondary CPUs boot. PerCpu::init(hart_id, &mut hyp_mem).map_err(Error::CreateSmpState)?; - // Find and initialize the IOMMU. - match PcieRoot::get_roots() - .map(|pci| { - Iommu::probe_from(pci, &mut || { - hyp_mem.take_pages_for_host_state(1).into_iter().next() - }) - }) - .find(|r| !matches!(r, Err(IommuError::ProbingIommu(PciError::DeviceNotFound)))) - { - Some(Ok(_)) => { - println!( - "Found RISC-V IOMMU version 0x{:x}", - Iommu::get().unwrap().version() - ); - } - Some(Err(e)) => { - println!("Failed to probe IOMMU: {:?}", e); - } - None => { - println!("No IOMMU found!"); + // Find and initialize the IOMMUs. 
+ for pci in PcieRoot::get_roots() { + for dev in pci.devices() { + let mut get_page = || hyp_mem.take_pages_for_host_state(1).into_iter().next(); + let addr = dev.lock().info().address(); + match Iommu::probe(&hyp_dt, pci, dev, &mut get_page) { + Ok(iommu) => println!("{} RISC-V IOMMU version 0x{:x}", addr, iommu.version()), + Err(IommuError::NotAnIommu) => {} + Err(err) => println!("{} RISC-V IOMMU probe failure: {:?}", addr, err), + } } - }; + } + + if Iommu::get_iommus().next().is_none() { + println!("No IOMMU found!"); + } // Initialize global Umode state. UmodeTask::init(umode_elf); diff --git a/src/vm_pages.rs b/src/vm_pages.rs index f1eb04a0..970cf7b7 100644 --- a/src/vm_pages.rs +++ b/src/vm_pages.rs @@ -973,21 +973,17 @@ impl VmIommuContext { impl Drop for VmIommuContext { fn drop(&mut self) { - // Unwrap ok: presence of an IOMMU is checked at creation time - let iommu = Iommu::get().unwrap(); - // Detach any devices we own from the IOMMU. let owner = self.msi_page_table.owner(); for pci in PcieRoot::get_roots() { for dev in pci.devices() { let mut dev = dev.lock(); - if dev.owner() == Some(owner) { + if dev.owner() == Some(owner) + && let Some(iommu) = Iommu::get_for_device(&dev) + { // Unwrap ok: `self.gscid` must be valid and match the ownership of the device // to have been attached in the first place. - // - // Silence buggy clippy warning. - #[allow(clippy::explicit_auto_deref)] - iommu.detach_pci_device(&mut *dev, self.gscid).unwrap(); + iommu.detach_pci_device(&mut dev, self.gscid).unwrap(); } } } @@ -1805,8 +1801,9 @@ impl<'a, T: GuestStagePagingMode> FinalizedVmPages<'a, T> { // If we have an IOMMU context then we need to issue a fence there as well as our page // tables may be used for DMA translation. if let Some(iommu_context) = self.inner.iommu_context.get() { - // Unwrap ok since we must have an IOMMU to have a `VmIommuContext`. 
- Iommu::get().unwrap().fence(iommu_context.gscid, None); + for iommu in Iommu::get_iommus() { + iommu.fence(iommu_context.gscid, None); + } } Ok(()) } @@ -2207,7 +2204,7 @@ impl<'a, T: GuestStagePagingMode> InitializingVmPages<'a, T> { /// this VM's page tables. pub fn attach_pci_device(&self, dev: &mut PciDevice) -> Result<()> { let iommu_context = self.inner.iommu_context.get().ok_or(Error::NoIommu)?; - let iommu = Iommu::get().unwrap(); + let iommu = Iommu::get_for_device(dev).ok_or(Error::NoIommu)?; let msi_pt = iommu .supports_msi_page_tables() .then_some(&iommu_context.msi_page_table); From 056a642801826cf0263825997e38677240ec81fd Mon Sep 17 00:00:00 2001 From: Mattias Nissler Date: Wed, 4 Jun 2025 06:31:58 -0700 Subject: [PATCH 13/13] drivers/iommu: Recognize alternative IOMMU PCI ids There are a few different PCI device/vendor ID pairs used for RISCV IOMMUs. Match against a list instead of expecting a specific pair. --- drivers/src/iommu/core.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/drivers/src/iommu/core.rs b/drivers/src/iommu/core.rs index 37adbc48..2766cb27 100644 --- a/drivers/src/iommu/core.rs +++ b/drivers/src/iommu/core.rs @@ -30,8 +30,11 @@ pub struct Iommu { static IOMMUS: [Once; 8] = [Once::INIT; 8]; // Identifiers from the QEMU RFC implementation. -const IOMMU_VENDOR_ID: u16 = 0x1efd; -const IOMMU_DEVICE_ID: u16 = 0xedf1; +const IOMMU_PCI_ID_TABLE: [(u16, u16); 3] = [ + (0x1b36, 0x0014), // vanilla qemu IOMMU model + (0x1efd, 0xedf1), // Rivos qemu IOMMU model + (0x1efd, 0x0008), // Rivos hardware IOMMU +]; // Suppress clippy warning about common suffix in favor or matching mode names as per IOMMU spec. 
#[allow(clippy::enum_variant_names)] @@ -72,9 +75,8 @@ impl Iommu { ) -> Result<&'static Iommu> { let mut dev = dev.lock(); - if dev.info().vendor_id().bits() != IOMMU_VENDOR_ID - || dev.info().device_id().bits() != IOMMU_DEVICE_ID - { + let pci_ids = (dev.info().vendor_id().bits(), dev.info().device_id().bits()); + if !IOMMU_PCI_ID_TABLE.contains(&pci_ids) { return Err(Error::NotAnIommu); }