about summary refs log tree commit diff
path: root/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/nix_daemon/worker_protocol.rs')
-rw-r--r--tvix/nix-compat/src/nix_daemon/worker_protocol.rs50
1 files changed, 28 insertions, 22 deletions
diff --git a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
index 4630a4f77067..58a48d1bdd25 100644
--- a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
+++ b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
@@ -9,11 +9,13 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
 
 use crate::wire;
 
+use super::ProtocolVersion;
+
 static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc"
 static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio"
 pub static STDERR_LAST: u64 = 0x616c7473; // "alts"
-/// Protocol version (1.37)
-static PROTOCOL_VERSION: [u8; 8] = [37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
+
+static PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::from_parts(1, 37);
 
 /// Max length of a Nix setting name/value. In bytes.
 ///
@@ -127,7 +129,7 @@ pub struct ClientSettings {
 /// FUTUREWORK: write serialization.
 pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
     r: &mut R,
-    client_version: u64,
+    client_version: ProtocolVersion,
 ) -> std::io::Result<ClientSettings> {
     let keep_failed = wire::read_bool(r).await?;
     let keep_going = wire::read_bool(r).await?;
@@ -148,7 +150,7 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
     let build_cores = wire::read_u64(r).await?;
     let use_substitutes = wire::read_bool(r).await?;
     let mut overrides = HashMap::new();
-    if client_version >= 12 {
+    if client_version.minor() >= 12 {
         let num_overrides = wire::read_u64(r).await?;
         for _ in 0..num_overrides {
             let name = wire::read_string(r, 0..MAX_SETTING_SIZE).await?;
@@ -173,8 +175,8 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
 /// Performs the initial handshake the server is sending to a connecting client.
 ///
 /// During the handshake, the client first send a magic u64, to which
-/// the daemon needs to respond with another magic u64. Then, the
-/// daemon retrieve the client version, and discard a bunch of now
+/// the daemon needs to respond with another magic u64.
+/// Then, the daemon retrieves the client version, and discards a bunch of now
 /// obsolete data.
 ///
 /// # Arguments
@@ -186,12 +188,12 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
 ///
 /// # Return
 ///
-/// The protocol version of a client encoded as a u64.
+/// The protocol version of the client.
 pub async fn server_handshake_client<'a, RW: 'a>(
     mut conn: &'a mut RW,
     nix_version: &str,
     trusted: Trust,
-) -> std::io::Result<u64>
+) -> std::io::Result<ProtocolVersion>
 where
     &'a mut RW: AsyncReadExt + AsyncWriteExt + Unpin,
 {
@@ -203,39 +205,41 @@ where
         ))
     } else {
         wire::write_u64(&mut conn, WORKER_MAGIC_2).await?;
-        conn.write_all(&PROTOCOL_VERSION).await?;
+        wire::write_u64(&mut conn, PROTOCOL_VERSION.into()).await?;
         conn.flush().await?;
         let client_version = wire::read_u64(&mut conn).await?;
-        if client_version < 0x10a {
+        // Parse into ProtocolVersion.
+        let client_version: ProtocolVersion = client_version
+            .try_into()
+            .map_err(|e| Error::new(ErrorKind::Unsupported, e))?;
+        if client_version < ProtocolVersion::from_parts(1, 10) {
             return Err(Error::new(
                 ErrorKind::Unsupported,
                 format!("The nix client version {} is too old", client_version),
             ));
         }
-        let protocol_minor = client_version & 0x00ff;
-        let _protocol_major = client_version & 0xff00;
-        if protocol_minor >= 14 {
+        if client_version.minor() >= 14 {
             // Obsolete CPU affinity.
             let read_affinity = wire::read_u64(&mut conn).await?;
             if read_affinity != 0 {
                 let _cpu_affinity = wire::read_u64(&mut conn).await?;
             };
         }
-        if protocol_minor >= 11 {
+        if client_version.minor() >= 11 {
             // Obsolete reserveSpace
             let _reserve_space = wire::read_u64(&mut conn).await?;
         }
-        if protocol_minor >= 33 {
+        if client_version.minor() >= 33 {
             // Nix version. We're plain lying, we're not Nix, but eh…
             // Setting it to the 2.3 lineage. Not 100% sure this is a
             // good idea.
             wire::write_bytes(&mut conn, nix_version).await?;
             conn.flush().await?;
         }
-        if protocol_minor >= 35 {
+        if client_version.minor() >= 35 {
             write_worker_trust_level(&mut conn, trusted).await?;
         }
-        Ok(protocol_minor)
+        Ok(client_version)
     }
 }
 
@@ -290,10 +294,10 @@ mod tests {
         let mut test_conn = tokio_test::io::Builder::new()
             .read(&WORKER_MAGIC_1.to_le_bytes())
             .write(&WORKER_MAGIC_2.to_le_bytes())
-            .write(&PROTOCOL_VERSION)
+            .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
             // Let's say the client is in sync with the daemon
             // protocol-wise
-            .read(&PROTOCOL_VERSION)
+            .read(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
             // cpu affinity
             .read(&[0; 8])
             // reservespace
@@ -305,9 +309,11 @@ mod tests {
             // Trusted (1 == client trusted
             .write(&[1, 0, 0, 0, 0, 0, 0, 0])
             .build();
-        server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
+        let client_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
             .await
             .unwrap();
+
+        assert_eq!(client_version, PROTOCOL_VERSION)
     }
 
     #[tokio::test]
@@ -329,7 +335,7 @@ mod tests {
              00 00 00 00 00 00 00 00"
         );
         let mut mock = Builder::new().read(&wire_bits).build();
-        let settings = read_client_settings(&mut mock, 21)
+        let settings = read_client_settings(&mut mock, ProtocolVersion::from_parts(1, 21))
             .await
             .expect("should parse");
         let expected = ClientSettings {
@@ -380,7 +386,7 @@ mod tests {
              72 72 65 00 00 00 00 00"
         );
         let mut mock = Builder::new().read(&wire_bits).build();
-        let settings = read_client_settings(&mut mock, 21)
+        let settings = read_client_settings(&mut mock, ProtocolVersion::from_parts(1, 21))
             .await
             .expect("should parse");
         let overrides = HashMap::from([