about summary refs log tree commit diff
path: root/tvix/nix-compat/src
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src')
-rw-r--r--tvix/nix-compat/src/lib.rs2
-rw-r--r--tvix/nix-compat/src/nix_daemon/mod.rs3
-rw-r--r--tvix/nix-compat/src/nix_daemon/protocol_version.rs123
-rw-r--r--tvix/nix-compat/src/nix_daemon/worker_protocol.rs50
4 files changed, 156 insertions, 22 deletions
diff --git a/tvix/nix-compat/src/lib.rs b/tvix/nix-compat/src/lib.rs
index 5a9a3c69fa77..a71ede3eecf0 100644
--- a/tvix/nix-compat/src/lib.rs
+++ b/tvix/nix-compat/src/lib.rs
@@ -14,3 +14,5 @@ pub mod wire;
 mod nix_daemon;
 #[cfg(feature = "wire")]
 pub use nix_daemon::worker_protocol;
+#[cfg(feature = "wire")]
+pub use nix_daemon::ProtocolVersion;
diff --git a/tvix/nix-compat/src/nix_daemon/mod.rs b/tvix/nix-compat/src/nix_daemon/mod.rs
index 633fdbebd47c..fe652377d1b4 100644
--- a/tvix/nix-compat/src/nix_daemon/mod.rs
+++ b/tvix/nix-compat/src/nix_daemon/mod.rs
@@ -1 +1,4 @@
 pub mod worker_protocol;
+
+mod protocol_version;
+pub use protocol_version::ProtocolVersion;
diff --git a/tvix/nix-compat/src/nix_daemon/protocol_version.rs b/tvix/nix-compat/src/nix_daemon/protocol_version.rs
new file mode 100644
index 000000000000..8fd2b085c962
--- /dev/null
+++ b/tvix/nix-compat/src/nix_daemon/protocol_version.rs
@@ -0,0 +1,123 @@
+/// Protocol versions are represented as a u16.
+/// The upper 8 bits are the major version, the lower bits the minor.
+/// This is not aware of any endianness, use [crate::wire::read_u64] to get an
+/// u64 first, and the try_from() impl from here if you're receiving over the
+/// Nix Worker protocol.
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub struct ProtocolVersion(u16);
+
+impl ProtocolVersion {
+    pub const fn from_parts(major: u8, minor: u8) -> Self {
+        Self(((major as u16) << 8) | minor as u16)
+    }
+
+    pub fn major(&self) -> u8 {
+        ((self.0 & 0xff00) >> 8) as u8
+    }
+
+    pub fn minor(&self) -> u8 {
+        (self.0 & 0x00ff) as u8
+    }
+}
+
+impl PartialOrd for ProtocolVersion {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for ProtocolVersion {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        match self.major().cmp(&other.major()) {
+            std::cmp::Ordering::Less => std::cmp::Ordering::Less,
+            std::cmp::Ordering::Greater => std::cmp::Ordering::Greater,
+            std::cmp::Ordering::Equal => {
+                // same major, compare minor
+                self.minor().cmp(&other.minor())
+            }
+        }
+    }
+}
+
+impl From<u16> for ProtocolVersion {
+    fn from(value: u16) -> Self {
+        Self::from_parts(((value & 0xff00) >> 8) as u8, (value & 0x00ff) as u8)
+    }
+}
+
+impl TryFrom<u64> for ProtocolVersion {
+    type Error = &'static str;
+
+    fn try_from(value: u64) -> Result<Self, Self::Error> {
+        if value & !0xffff != 0 {
+            return Err("only two least significant bits might be populated");
+        }
+
+        Ok((value as u16).into())
+    }
+}
+
+impl From<ProtocolVersion> for u16 {
+    fn from(value: ProtocolVersion) -> Self {
+        value.0
+    }
+}
+
+impl From<ProtocolVersion> for u64 {
+    fn from(value: ProtocolVersion) -> Self {
+        value.0 as u64
+    }
+}
+
+impl std::fmt::Display for ProtocolVersion {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}.{}", self.major(), self.minor())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::ProtocolVersion;
+
+    #[test]
+    fn from_parts() {
+        let version = ProtocolVersion::from_parts(1, 37);
+        assert_eq!(version.major(), 1, "correct major");
+        assert_eq!(version.minor(), 37, "correct minor");
+        assert_eq!("1.37", &version.to_string(), "to_string");
+
+        assert_eq!(0x0125, Into::<u16>::into(version));
+        assert_eq!(0x0125, Into::<u64>::into(version));
+    }
+
+    #[test]
+    fn from_u16() {
+        let version = ProtocolVersion::from(0x0125_u16);
+        assert_eq!("1.37", &version.to_string());
+    }
+
+    #[test]
+    fn from_u64() {
+        let version = ProtocolVersion::try_from(0x0125_u64).expect("must succeed");
+        assert_eq!("1.37", &version.to_string());
+    }
+
+    /// This contains data in higher bits, which should fail.
+    #[test]
+    fn from_u64_fail() {
+        ProtocolVersion::try_from(0xaa0125_u64).expect_err("must fail");
+    }
+
+    #[test]
+    fn ord() {
+        let v0_37 = ProtocolVersion::from_parts(0, 37);
+        let v1_37 = ProtocolVersion::from_parts(1, 37);
+        let v1_40 = ProtocolVersion::from_parts(1, 40);
+
+        assert!(v0_37 < v1_37);
+        assert!(v1_37 > v0_37);
+        assert!(v1_37 < v1_40);
+        assert!(v1_40 > v1_37);
+        assert!(v1_40 <= v1_40);
+    }
+}
diff --git a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
index 4630a4f77067..58a48d1bdd25 100644
--- a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
+++ b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
@@ -9,11 +9,13 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
 
 use crate::wire;
 
+use super::ProtocolVersion;
+
 static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc"
 static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio"
 pub static STDERR_LAST: u64 = 0x616c7473; // "alts"
-/// Protocol version (1.37)
-static PROTOCOL_VERSION: [u8; 8] = [37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
+
+static PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::from_parts(1, 37);
 
 /// Max length of a Nix setting name/value. In bytes.
 ///
@@ -127,7 +129,7 @@ pub struct ClientSettings {
 /// FUTUREWORK: write serialization.
 pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
     r: &mut R,
-    client_version: u64,
+    client_version: ProtocolVersion,
 ) -> std::io::Result<ClientSettings> {
     let keep_failed = wire::read_bool(r).await?;
     let keep_going = wire::read_bool(r).await?;
@@ -148,7 +150,7 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
     let build_cores = wire::read_u64(r).await?;
     let use_substitutes = wire::read_bool(r).await?;
     let mut overrides = HashMap::new();
-    if client_version >= 12 {
+    if client_version.minor() >= 12 {
         let num_overrides = wire::read_u64(r).await?;
         for _ in 0..num_overrides {
             let name = wire::read_string(r, 0..MAX_SETTING_SIZE).await?;
@@ -173,8 +175,8 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
 /// Performs the initial handshake the server is sending to a connecting client.
 ///
 /// During the handshake, the client first send a magic u64, to which
-/// the daemon needs to respond with another magic u64. Then, the
-/// daemon retrieve the client version, and discard a bunch of now
+/// the daemon needs to respond with another magic u64.
+/// Then, the daemon retrieves the client version, and discards a bunch of now
 /// obsolete data.
 ///
 /// # Arguments
@@ -186,12 +188,12 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
 ///
 /// # Return
 ///
-/// The protocol version of a client encoded as a u64.
+/// The protocol version of the client.
 pub async fn server_handshake_client<'a, RW: 'a>(
     mut conn: &'a mut RW,
     nix_version: &str,
     trusted: Trust,
-) -> std::io::Result<u64>
+) -> std::io::Result<ProtocolVersion>
 where
     &'a mut RW: AsyncReadExt + AsyncWriteExt + Unpin,
 {
@@ -203,39 +205,41 @@ where
         ))
     } else {
         wire::write_u64(&mut conn, WORKER_MAGIC_2).await?;
-        conn.write_all(&PROTOCOL_VERSION).await?;
+        wire::write_u64(&mut conn, PROTOCOL_VERSION.into()).await?;
         conn.flush().await?;
         let client_version = wire::read_u64(&mut conn).await?;
-        if client_version < 0x10a {
+        // Parse into ProtocolVersion.
+        let client_version: ProtocolVersion = client_version
+            .try_into()
+            .map_err(|e| Error::new(ErrorKind::Unsupported, e))?;
+        if client_version < ProtocolVersion::from_parts(1, 10) {
             return Err(Error::new(
                 ErrorKind::Unsupported,
                 format!("The nix client version {} is too old", client_version),
             ));
         }
-        let protocol_minor = client_version & 0x00ff;
-        let _protocol_major = client_version & 0xff00;
-        if protocol_minor >= 14 {
+        if client_version.minor() >= 14 {
             // Obsolete CPU affinity.
             let read_affinity = wire::read_u64(&mut conn).await?;
             if read_affinity != 0 {
                 let _cpu_affinity = wire::read_u64(&mut conn).await?;
             };
         }
-        if protocol_minor >= 11 {
+        if client_version.minor() >= 11 {
             // Obsolete reserveSpace
             let _reserve_space = wire::read_u64(&mut conn).await?;
         }
-        if protocol_minor >= 33 {
+        if client_version.minor() >= 33 {
             // Nix version. We're plain lying, we're not Nix, but eh…
             // Setting it to the 2.3 lineage. Not 100% sure this is a
             // good idea.
             wire::write_bytes(&mut conn, nix_version).await?;
             conn.flush().await?;
         }
-        if protocol_minor >= 35 {
+        if client_version.minor() >= 35 {
             write_worker_trust_level(&mut conn, trusted).await?;
         }
-        Ok(protocol_minor)
+        Ok(client_version)
     }
 }
 
@@ -290,10 +294,10 @@ mod tests {
         let mut test_conn = tokio_test::io::Builder::new()
             .read(&WORKER_MAGIC_1.to_le_bytes())
             .write(&WORKER_MAGIC_2.to_le_bytes())
-            .write(&PROTOCOL_VERSION)
+            .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
             // Let's say the client is in sync with the daemon
             // protocol-wise
-            .read(&PROTOCOL_VERSION)
+            .read(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
             // cpu affinity
             .read(&[0; 8])
             // reservespace
@@ -305,9 +309,11 @@ mod tests {
             // Trusted (1 == client trusted
             .write(&[1, 0, 0, 0, 0, 0, 0, 0])
             .build();
-        server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
+        let client_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
             .await
             .unwrap();
+
+        assert_eq!(client_version, PROTOCOL_VERSION)
     }
 
     #[tokio::test]
@@ -329,7 +335,7 @@ mod tests {
              00 00 00 00 00 00 00 00"
         );
         let mut mock = Builder::new().read(&wire_bits).build();
-        let settings = read_client_settings(&mut mock, 21)
+        let settings = read_client_settings(&mut mock, ProtocolVersion::from_parts(1, 21))
             .await
             .expect("should parse");
         let expected = ClientSettings {
@@ -380,7 +386,7 @@ mod tests {
              72 72 65 00 00 00 00 00"
         );
         let mut mock = Builder::new().read(&wire_bits).build();
-        let settings = read_client_settings(&mut mock, 21)
+        let settings = read_client_settings(&mut mock, ProtocolVersion::from_parts(1, 21))
             .await
             .expect("should parse");
         let overrides = HashMap::from([