diff --git a/src/vfio_device.rs b/src/vfio_device.rs index 1e8f0e6..6ecd33b 100644 --- a/src/vfio_device.rs +++ b/src/vfio_device.rs @@ -22,6 +22,12 @@ use kvm_bindings::{ #[cfg(feature = "kvm")] use kvm_ioctls::DeviceFd; use log::{debug, error, warn}; +#[cfg(feature = "mshv")] +use mshv_bindings::{ + mshv_device_attr, MSHV_DEV_VFIO_GROUP, MSHV_DEV_VFIO_GROUP_ADD, MSHV_DEV_VFIO_GROUP_DEL, +}; +#[cfg(feature = "mshv")] +use mshv_ioctls::DeviceFd; use vfio_bindings::bindings::vfio::*; use vm_memory::{Address, GuestMemory, GuestMemoryRegion, MemoryRegionAddress}; use vmm_sys_util::errno::Error as SysError; @@ -48,6 +54,8 @@ pub enum VfioError { GroupGetDeviceFD, #[cfg(feature = "kvm")] KvmSetDeviceAttr(SysError), + #[cfg(feature = "mshv")] + MshvSetDeviceAttr(SysError), VfioDeviceGetInfo, VfioDeviceGetRegionInfo(SysError), InvalidPath, @@ -91,6 +99,10 @@ impl fmt::Display for VfioError { VfioError::KvmSetDeviceAttr(e) => { write!(f, "failed to set KVM vfio device's attribute: {}", e) } + #[cfg(feature = "mshv")] + VfioError::MshvSetDeviceAttr(e) => { + write!(f, "failed to set MSHV vfio device's attribute: {}", e) + } VfioError::VfioDeviceGetInfo => { write!(f, "failed to get vfio device's info or info doesn't match") } @@ -125,6 +137,8 @@ impl std::error::Error for VfioError { VfioError::GroupGetDeviceFD => None, #[cfg(feature = "kvm")] VfioError::KvmSetDeviceAttr(e) => Some(e), + #[cfg(feature = "mshv")] + VfioError::MshvSetDeviceAttr(e) => Some(e), VfioError::VfioDeviceGetInfo => None, VfioError::VfioDeviceGetRegionInfo(e) => Some(e), VfioError::InvalidPath => None, @@ -259,6 +273,36 @@ impl VfioContainer { .map_err(VfioError::KvmSetDeviceAttr) } + #[cfg(feature = "mshv")] + fn mshv_device_add_group(&self, group_fd: RawFd) -> Result<()> { + let group_fd_ptr = &group_fd as *const i32; + let dev_attr = mshv_device_attr { + flags: 0, + group: MSHV_DEV_VFIO_GROUP, + attr: u64::from(MSHV_DEV_VFIO_GROUP_ADD), + addr: group_fd_ptr as u64, + }; + + self.device_fd + .set_device_attr(&dev_attr) + .map_err(VfioError::MshvSetDeviceAttr) + } + + #[cfg(feature = "mshv")] + fn mshv_device_del_group(&self, group_fd: RawFd) -> Result<()> { + let group_fd_ptr = &group_fd as *const i32; + let dev_attr = mshv_device_attr { + flags: 0, + group: MSHV_DEV_VFIO_GROUP, + attr: u64::from(MSHV_DEV_VFIO_GROUP_DEL), + addr: group_fd_ptr as u64, + }; + + self.device_fd + .set_device_attr(&dev_attr) + .map_err(VfioError::MshvSetDeviceAttr) + } + fn get_group(&self, group_id: u32) -> Result> { // Safe because there's no legal way to break the lock. let mut hash = self.groups.lock().unwrap(); @@ -289,7 +333,11 @@ impl VfioContainer { // Add the new group object to the KVM driver. #[cfg(feature = "kvm")] - if let Err(e) = self.kvm_device_add_group(group.as_raw_fd()) { + let result = self.kvm_device_add_group(group.as_raw_fd()); + #[cfg(feature = "mshv")] + let result = self.mshv_device_add_group(group.as_raw_fd()); + + if let Err(e) = result { let _ = unsafe { ioctl_with_ref(&*group, VFIO_GROUP_UNSET_CONTAINER(), &self.as_raw_fd()) }; return Err(e); @@ -311,7 +359,11 @@ impl VfioContainer { // - one reference held by the groups hashmap if Arc::strong_count(&group) == 3 { #[cfg(feature = "kvm")] - match self.kvm_device_del_group(group.as_raw_fd()) { + let result = self.kvm_device_del_group(group.as_raw_fd()); + #[cfg(feature = "mshv")] + let result = self.mshv_device_del_group(group.as_raw_fd()); + + match result { Ok(_) => {} Err(e) => { error!("Could not delete VFIO group: {:?}", e);