about summary refs log tree commit diff
path: root/tvix/nix-compat/src/nix_daemon/handler.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/nix_daemon/handler.rs')
-rw-r--r--tvix/nix-compat/src/nix_daemon/handler.rs151
1 files changed, 119 insertions, 32 deletions
diff --git a/tvix/nix-compat/src/nix_daemon/handler.rs b/tvix/nix-compat/src/nix_daemon/handler.rs
index 4f43612114d8..6fb45bdb7e2d 100644
--- a/tvix/nix-compat/src/nix_daemon/handler.rs
+++ b/tvix/nix-compat/src/nix_daemon/handler.rs
@@ -254,45 +254,17 @@ where
 #[cfg(test)]
 mod tests {
     use super::*;
-    use std::{io::Result, sync::Arc};
+    use std::{io::ErrorKind, sync::Arc};
 
+    use mockall::predicate;
     use tokio::io::AsyncWriteExt;
 
     use crate::{
-        nix_daemon::types::UnkeyedValidPathInfo,
+        nix_daemon::MockNixDaemonIO,
         wire::ProtocolVersion,
         worker_protocol::{ClientSettings, WORKER_MAGIC_1, WORKER_MAGIC_2},
     };
 
-    struct MockDaemonIO {}
-
-    impl NixDaemonIO for MockDaemonIO {
-        async fn query_path_info(
-            &self,
-            _path: &crate::store_path::StorePath<String>,
-        ) -> Result<Option<UnkeyedValidPathInfo>> {
-            Ok(None)
-        }
-
-        async fn query_path_from_hash_part(
-            &self,
-            _hash: &[u8],
-        ) -> Result<Option<UnkeyedValidPathInfo>> {
-            Ok(None)
-        }
-
-        async fn add_to_store_nar<R>(
-            &self,
-            _request: crate::nix_daemon::types::AddToStoreNarRequest,
-            _reader: &mut R,
-        ) -> Result<()>
-        where
-            R: tokio::io::AsyncRead + Send + Unpin,
-        {
-            Ok(())
-        }
-    }
-
     #[tokio::test]
     async fn test_daemon_initialization() {
         let mut builder = tokio_test::io::Builder::new();
@@ -332,10 +304,125 @@ mod tests {
             .write(&[115, 116, 108, 97, 0, 0, 0, 0])
             .build();
 
-        let daemon = NixDaemon::initialize(Arc::new(MockDaemonIO {}), test_conn)
+        let mock = MockNixDaemonIO::new();
+        let daemon = NixDaemon::initialize(Arc::new(mock), test_conn)
             .await
             .unwrap();
         assert_eq!(daemon.client_settings, ClientSettings::default());
         assert_eq!(daemon.protocol_version, ProtocolVersion::from_parts(1, 35));
     }
+
+    async fn serialize<T>(req: &T, protocol_version: ProtocolVersion) -> Vec<u8>
+    where
+        T: NixSerialize + Send,
+    {
+        let mut result: Vec<u8> = Vec::new();
+        let mut w = NixWriter::builder()
+            .set_version(protocol_version)
+            .build(&mut result);
+        w.write_value(req).await.unwrap();
+        w.flush().await.unwrap();
+        result
+    }
+
+    async fn respond<T>(
+        resp: &Result<T, std::io::Error>,
+        protocol_version: ProtocolVersion,
+    ) -> Vec<u8>
+    where
+        T: NixSerialize + Send,
+    {
+        let mut result: Vec<u8> = Vec::new();
+        let mut w = NixWriter::builder()
+            .set_version(protocol_version)
+            .build(&mut result);
+        match resp {
+            Ok(value) => {
+                w.write_value(&STDERR_LAST).await.unwrap();
+                w.write_value(value).await.unwrap();
+            }
+            Err(e) => {
+                w.write_value(&STDERR_ERROR).await.unwrap();
+                w.write_value(&NixError::new(format!("{:?}", e)))
+                    .await
+                    .unwrap();
+            }
+        }
+        w.flush().await.unwrap();
+        result
+    }
+
+    #[tokio::test]
+    async fn test_handle_is_valid_path_ok() {
+        let version = ProtocolVersion::from_parts(1, 37);
+        let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
+        let mut mock = MockNixDaemonIO::new();
+        let (reader, writer) = split(io);
+        let path: StorePath<String> = StorePath::<String>::from_absolute_path(
+            "/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1".as_bytes(),
+        )
+        .unwrap();
+        mock.expect_is_valid_path()
+            .with(predicate::eq(path.clone()))
+            .times(1)
+            .returning(|_| Box::pin(async { Ok(true) }));
+
+        handle.read(&Into::<u64>::into(Operation::IsValidPath).to_le_bytes());
+        handle.read(&serialize(&path, version).await);
+        handle.write(&respond(&Ok(true), version).await);
+        drop(handle);
+
+        let mut daemon = NixDaemon::new(
+            Arc::new(mock),
+            version,
+            ClientSettings::default(),
+            NixReader::new(reader),
+            NixWriter::new(writer),
+        );
+        assert_eq!(
+            ErrorKind::UnexpectedEof,
+            daemon
+                .handle_client()
+                .await
+                .expect_err("Expecting eof")
+                .kind()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_handle_is_valid_path_err() {
+        let version = ProtocolVersion::from_parts(1, 37);
+        let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
+        let mut mock = MockNixDaemonIO::new();
+        let (reader, writer) = split(io);
+        let path: StorePath<String> = StorePath::<String>::from_absolute_path(
+            "/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1".as_bytes(),
+        )
+        .unwrap();
+        mock.expect_is_valid_path()
+            .with(predicate::eq(path.clone()))
+            .times(1)
+            .returning(|_| Box::pin(async { Err(std::io::Error::other("hello")) }));
+
+        handle.read(&Into::<u64>::into(Operation::IsValidPath).to_le_bytes());
+        handle.read(&serialize(&path, version).await);
+        handle.write(&respond::<bool>(&Err(std::io::Error::other("hello")), version).await);
+        drop(handle);
+
+        let mut daemon = NixDaemon::new(
+            Arc::new(mock),
+            version,
+            ClientSettings::default(),
+            NixReader::new(reader),
+            NixWriter::new(writer),
+        );
+        assert_eq!(
+            ErrorKind::UnexpectedEof,
+            daemon
+                .handle_client()
+                .await
+                .expect_err("Expecting eof")
+                .kind()
+        );
+    }
 }