use std::{ convert, error, fmt, io, ops::Deref, path::Path, sync::{Arc, MutexGuard, RwLock}, }; use fuse_backend_rs::{ api::{filesystem::FileSystem, server::Server}, transport::{FsCacheReqHandler, Reader, VirtioFsWriter}, }; use tracing::error; use vhost::vhost_user::{ Listener, SlaveFsCacheReq, VhostUserProtocolFeatures, VhostUserVirtioFeatures, }; use vhost_user_backend::{VhostUserBackendMut, VhostUserDaemon, VringMutex, VringState, VringT}; use virtio_bindings::bindings::virtio_ring::{ VIRTIO_RING_F_EVENT_IDX, VIRTIO_RING_F_INDIRECT_DESC, }; use virtio_queue::QueueT; use vm_memory::{GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap}; use vmm_sys_util::epoll::EventSet; const VIRTIO_F_VERSION_1: u32 = 32; const NUM_QUEUES: usize = 2; const QUEUE_SIZE: usize = 1024; #[derive(Debug)] enum Error { /// Failed to handle non-input event. HandleEventNotEpollIn, /// Failed to handle unknown event. HandleEventUnknownEvent, /// Invalid descriptor chain. InvalidDescriptorChain, /// Failed to handle filesystem requests. HandleRequests(fuse_backend_rs::Error), /// Failed to construct new vhost user daemon. NewDaemon, /// Failed to start the vhost user daemon. StartDaemon, /// Failed to wait for the vhost user daemon. WaitDaemon, } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "vhost_user_fs_error: {self:?}") } } impl error::Error for Error {} impl convert::From<Error> for io::Error { fn from(e: Error) -> Self { io::Error::new(io::ErrorKind::Other, e) } } struct VhostUserFsBackend<FS> where FS: FileSystem + Send + Sync, { server: Arc<Server<Arc<FS>>>, event_idx: bool, guest_mem: GuestMemoryAtomic<GuestMemoryMmap>, cache_req: Option<SlaveFsCacheReq>, } impl<FS> VhostUserFsBackend<FS> where FS: FileSystem + Send + Sync, { fn process_queue(&mut self, vring: &mut MutexGuard<VringState>) -> std::io::Result<bool> { let mut used_descs = false; while let Some(desc_chain) = vring .get_queue_mut() .pop_descriptor_chain(self.guest_mem.memory()) { let memory = desc_chain.memory(); let reader = Reader::from_descriptor_chain(memory, desc_chain.clone()) .map_err(|_| Error::InvalidDescriptorChain)?; let writer = VirtioFsWriter::new(memory, desc_chain.clone()) .map_err(|_| Error::InvalidDescriptorChain)?; self.server .handle_message( reader, writer.into(), self.cache_req .as_mut() .map(|req| req as &mut dyn FsCacheReqHandler), None, ) .map_err(Error::HandleRequests)?; // TODO: Is len 0 correct? if let Err(error) = vring .get_queue_mut() .add_used(memory, desc_chain.head_index(), 0) { error!(?error, "failed to add desc back to ring"); } // TODO: What happens if we error out before here? used_descs = true; } let needs_notification = if self.event_idx { match vring .get_queue_mut() .needs_notification(self.guest_mem.memory().deref()) { Ok(needs_notification) => needs_notification, Err(error) => { error!(?error, "failed to check if queue needs notification"); true } } } else { true }; if needs_notification { if let Err(error) = vring.signal_used_queue() { error!(?error, "failed to signal used queue"); } } Ok(used_descs) } } impl<FS> VhostUserBackendMut<VringMutex> for VhostUserFsBackend<FS> where FS: FileSystem + Send + Sync, { fn num_queues(&self) -> usize { NUM_QUEUES } fn max_queue_size(&self) -> usize { QUEUE_SIZE } fn features(&self) -> u64 { 1 << VIRTIO_F_VERSION_1 | 1 << VIRTIO_RING_F_INDIRECT_DESC | 1 << VIRTIO_RING_F_EVENT_IDX | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() } fn protocol_features(&self) -> VhostUserProtocolFeatures { VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::SLAVE_REQ } fn set_event_idx(&mut self, enabled: bool) { self.event_idx = enabled; } fn update_memory(&mut self, _mem: GuestMemoryAtomic<GuestMemoryMmap>) -> std::io::Result<()> { // This is what most the vhost user implementations do... Ok(()) } fn set_slave_req_fd(&mut self, cache_req: SlaveFsCacheReq) { self.cache_req = Some(cache_req); } fn handle_event( &mut self, device_event: u16, evset: vmm_sys_util::epoll::EventSet, vrings: &[VringMutex], _thread_id: usize, ) -> std::io::Result<bool> { if evset != EventSet::IN { return Err(Error::HandleEventNotEpollIn.into()); } let mut queue = match device_event { // High priority queue 0 => vrings[0].get_mut(), // Regurlar priority queue 1 => vrings[1].get_mut(), _ => { return Err(Error::HandleEventUnknownEvent.into()); } }; if self.event_idx { loop { queue .get_queue_mut() .enable_notification(self.guest_mem.memory().deref()) .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?; if !self.process_queue(&mut queue)? { break; } } } else { self.process_queue(&mut queue)?; } Ok(false) } } pub fn start_virtiofs_daemon<FS, P>(fs: FS, socket: P) -> io::Result<()> where FS: FileSystem + Send + Sync + 'static, P: AsRef<Path>, { let guest_mem = GuestMemoryAtomic::new(GuestMemoryMmap::new()); let server = Arc::new(fuse_backend_rs::api::server::Server::new(Arc::new(fs))); let backend = Arc::new(RwLock::new(VhostUserFsBackend { server, guest_mem: guest_mem.clone(), event_idx: false, cache_req: None, })); let listener = Listener::new(socket, true).unwrap(); let mut fs_daemon = VhostUserDaemon::new(String::from("vhost-user-fs-tvix-store"), backend, guest_mem) .map_err(|_| Error::NewDaemon)?; fs_daemon.start(listener).map_err(|_| Error::StartDaemon)?; fs_daemon.wait().map_err(|_| Error::WaitDaemon)?; Ok(()) }