about summary refs log tree commit diff
path: root/tvix/nix-compat/src/wire/worker_protocol.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/wire/worker_protocol.rs')
-rw-r--r--tvix/nix-compat/src/wire/worker_protocol.rs106
1 files changed, 100 insertions, 6 deletions
diff --git a/tvix/nix-compat/src/wire/worker_protocol.rs b/tvix/nix-compat/src/wire/worker_protocol.rs
index 121b9b2ea5cb..a9d3bd85478b 100644
--- a/tvix/nix-compat/src/wire/worker_protocol.rs
+++ b/tvix/nix-compat/src/wire/worker_protocol.rs
@@ -7,15 +7,15 @@ use enum_primitive_derive::Primitive;
 use num_traits::{FromPrimitive, ToPrimitive};
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
 
-use crate::wire::primitive;
+use crate::wire::{bytes, primitive};
 
 use super::bytes::read_string;
 
-pub static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc"
-pub static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio"
+static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc"
+static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio"
 pub static STDERR_LAST: u64 = 0x616c7473; // "alts"
-/// Protocol version (1.35)
-pub static PROTOCOL_VERSION: u64 = 1 << 8 | 35;
+/// Protocol version (1.37)
+static PROTOCOL_VERSION: [u8; 8] = [37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
 
 /// Max length of a Nix setting name/value. In bytes.
 ///
@@ -172,6 +172,75 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
     })
 }
 
+/// Performs the initial handshake the server is sending to a connecting client.
+///
+/// During the handshake, the client first send a magic u64, to which
+/// the daemon needs to respond with another magic u64. Then, the
+/// daemon retrieve the client version, and discard a bunch of now
+/// obsolete data.
+///
+/// # Arguments
+///
+/// * conn: connection with the Nix client.
+/// * nix_version: semantic version of the Nix daemon. "2.18.2" for
+///   instance.
+/// * trusted: trust level of the Nix client.
+///
+/// # Return
+///
+/// The protocol version of a client encoded as a u64.
+pub async fn server_handshake_client<'a, RW: 'a>(
+    mut conn: &'a mut RW,
+    nix_version: &str,
+    trusted: Trust,
+) -> std::io::Result<u64>
+where
+    &'a mut RW: AsyncReadExt + AsyncWriteExt + Unpin,
+{
+    let worker_magic_1 = primitive::read_u64(&mut conn).await?;
+    if worker_magic_1 != WORKER_MAGIC_1 {
+        Err(std::io::Error::new(
+            ErrorKind::InvalidData,
+            format!("Incorrect worker magic number received: {}", worker_magic_1),
+        ))
+    } else {
+        primitive::write_u64(&mut conn, WORKER_MAGIC_2).await?;
+        conn.write_all(&PROTOCOL_VERSION).await?;
+        conn.flush().await?;
+        let client_version = primitive::read_u64(&mut conn).await?;
+        if client_version < 0x10a {
+            return Err(Error::new(
+                ErrorKind::Unsupported,
+                format!("The nix client version {} is too old", client_version),
+            ));
+        }
+        let protocol_minor = client_version & 0x00ff;
+        let _protocol_major = client_version & 0xff00;
+        if protocol_minor >= 14 {
+            // Obsolete CPU affinity.
+            let read_affinity = primitive::read_u64(&mut conn).await?;
+            if read_affinity != 0 {
+                let _cpu_affinity = primitive::read_u64(&mut conn).await?;
+            };
+        }
+        if protocol_minor >= 11 {
+            // Obsolete reserveSpace
+            let _reserve_space = primitive::read_u64(&mut conn).await?;
+        }
+        if protocol_minor >= 33 {
+            // Nix version. We're plain lying, we're not Nix, but eh…
+            // Setting it to the 2.3 lineage. Not 100% sure this is a
+            // good idea.
+            bytes::write_bytes(&mut conn, nix_version).await?;
+            conn.flush().await?;
+        }
+        if protocol_minor >= 35 {
+            write_worker_trust_level(&mut conn, trusted).await?;
+        }
+        Ok(protocol_minor)
+    }
+}
+
 /// Read a worker [Operation] from the wire.
 pub async fn read_op<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<Operation> {
     let op_number = primitive::read_u64(r).await?;
@@ -204,7 +273,7 @@ pub enum Trust {
 /// decided not to implement it here.
 pub async fn write_worker_trust_level<W>(conn: &mut W, t: Trust) -> std::io::Result<()>
 where
-    W: AsyncReadExt + AsyncWriteExt + Unpin + std::fmt::Debug,
+    W: AsyncReadExt + AsyncWriteExt + Unpin,
 {
     match t {
         Trust::Trusted => primitive::write_u64(conn, 1).await,
@@ -219,6 +288,31 @@ mod tests {
     use tokio_test::io::Builder;
 
     #[tokio::test]
+    async fn test_init_hanshake() {
+        let mut test_conn = tokio_test::io::Builder::new()
+            .read(&WORKER_MAGIC_1.to_le_bytes())
+            .write(&WORKER_MAGIC_2.to_le_bytes())
+            .write(&PROTOCOL_VERSION)
+            // Let's say the client is in sync with the daemon
+            // protocol-wise
+            .read(&PROTOCOL_VERSION)
+            // cpu affinity
+            .read(&[0; 8])
+            // reservespace
+            .read(&[0; 8])
+            // version (size)
+            .write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
+            // version (data == 2.18.2 + padding)
+            .write(&[50, 46, 49, 56, 46, 50, 0, 0])
+            // Trusted (1 == client trusted
+            .write(&[1, 0, 0, 0, 0, 0, 0, 0])
+            .build();
+        server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
+            .await
+            .unwrap();
+    }
+
+    #[tokio::test]
     async fn test_read_client_settings_without_overrides() {
         // Client settings bits captured from a Nix 2.3.17 run w/ sockdump (protocol version 21).
         let wire_bits = hex!(