about summary refs log tree commit diff
path: root/tvix/castore/src/fs/fuse/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/castore/src/fs/fuse/mod.rs')
-rw-r--r--tvix/castore/src/fs/fuse/mod.rs56
1 files changed, 35 insertions, 21 deletions
diff --git a/tvix/castore/src/fs/fuse/mod.rs b/tvix/castore/src/fs/fuse/mod.rs
index 94b73d422a14..64ef29ed2aa1 100644
--- a/tvix/castore/src/fs/fuse/mod.rs
+++ b/tvix/castore/src/fs/fuse/mod.rs
@@ -1,6 +1,8 @@
-use std::{io, path::Path, sync::Arc, thread};
+use std::{io, path::Path, sync::Arc};
 
 use fuse_backend_rs::{api::filesystem::FileSystem, transport::FuseSession};
+use parking_lot::Mutex;
+use threadpool::ThreadPool;
 use tracing::{error, instrument};
 
 #[cfg(test)]
@@ -49,9 +51,12 @@ where
     }
 }
 
+/// Starts a [Filesystem] with the specified number of threads, and provides
+/// functions to unmount, and wait for it to have completed.
+#[derive(Clone)]
 pub struct FuseDaemon {
-    session: FuseSession,
-    threads: Vec<thread::JoinHandle<()>>,
+    session: Arc<Mutex<FuseSession>>,
+    threads: Arc<ThreadPool>,
 }
 
 impl FuseDaemon {
@@ -59,7 +64,7 @@ impl FuseDaemon {
     pub fn new<FS, P>(
         fs: FS,
         mountpoint: P,
-        threads: usize,
+        num_threads: usize,
         allow_other: bool,
     ) -> Result<Self, io::Error>
     where
@@ -76,40 +81,49 @@ impl FuseDaemon {
         session
             .mount()
             .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
-        let mut join_handles = Vec::with_capacity(threads);
-        for _ in 0..threads {
+
+        // construct a thread pool
+        let threads = threadpool::Builder::new()
+            .num_threads(num_threads)
+            .thread_name("fuse_server".to_string())
+            .build();
+
+        for _ in 0..num_threads {
+            // for each thread requested, create and start a FuseServer accepting requests.
             let mut server = FuseServer {
                 server: server.clone(),
                 channel: session
                     .new_channel()
                     .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?,
             };
-            let join_handle = thread::Builder::new()
-                .name("fuse_server".to_string())
-                .spawn(move || {
-                    let _ = server.start();
-                })?;
-            join_handles.push(join_handle);
+
+            threads.execute(move || {
+                let _ = server.start();
+            });
         }
 
         Ok(FuseDaemon {
-            session,
-            threads: join_handles,
+            session: Arc::new(Mutex::new(session)),
+            threads: Arc::new(threads),
         })
     }
 
+    /// Waits for all threads to finish.
+    #[instrument(skip_all)]
+    pub fn wait(&self) {
+        self.threads.join()
+    }
+
+    /// Send the unmount command, and waits for all threads to finish.
     #[instrument(skip_all, err)]
-    pub fn unmount(&mut self) -> Result<(), io::Error> {
+    pub fn unmount(&self) -> Result<(), io::Error> {
+        // Send the unmount command.
         self.session
+            .lock()
             .umount()
             .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
 
-        for thread in self.threads.drain(..) {
-            thread.join().map_err(|_| {
-                io::Error::new(io::ErrorKind::Other, "failed to join fuse server thread")
-            })?;
-        }
-
+        self.wait();
         Ok(())
     }
 }