diff options
Diffstat (limited to 'tvix/nix-compat/src')
-rw-r--r-- | tvix/nix-compat/src/lib.rs | 2 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/mod.rs | 3 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/protocol_version.rs | 123 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/worker_protocol.rs | 50 |
4 files changed, 156 insertions, 22 deletions
diff --git a/tvix/nix-compat/src/lib.rs b/tvix/nix-compat/src/lib.rs index 5a9a3c69fa77..a71ede3eecf0 100644 --- a/tvix/nix-compat/src/lib.rs +++ b/tvix/nix-compat/src/lib.rs @@ -14,3 +14,5 @@ pub mod wire; mod nix_daemon; #[cfg(feature = "wire")] pub use nix_daemon::worker_protocol; +#[cfg(feature = "wire")] +pub use nix_daemon::ProtocolVersion; diff --git a/tvix/nix-compat/src/nix_daemon/mod.rs b/tvix/nix-compat/src/nix_daemon/mod.rs index 633fdbebd47c..fe652377d1b4 100644 --- a/tvix/nix-compat/src/nix_daemon/mod.rs +++ b/tvix/nix-compat/src/nix_daemon/mod.rs @@ -1 +1,4 @@ pub mod worker_protocol; + +mod protocol_version; +pub use protocol_version::ProtocolVersion; diff --git a/tvix/nix-compat/src/nix_daemon/protocol_version.rs b/tvix/nix-compat/src/nix_daemon/protocol_version.rs new file mode 100644 index 000000000000..8fd2b085c962 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/protocol_version.rs @@ -0,0 +1,123 @@ +/// Protocol versions are represented as a u16. +/// The upper 8 bits are the major version, the lower bits the minor. +/// This is not aware of any endianness, use [crate::wire::read_u64] to get an +/// u64 first, and the try_from() impl from here if you're receiving over the +/// Nix Worker protocol. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct ProtocolVersion(u16); + +impl ProtocolVersion { + pub const fn from_parts(major: u8, minor: u8) -> Self { + Self(((major as u16) << 8) | minor as u16) + } + + pub fn major(&self) -> u8 { + ((self.0 & 0xff00) >> 8) as u8 + } + + pub fn minor(&self) -> u8 { + (self.0 & 0x00ff) as u8 + } +} + +impl PartialOrd for ProtocolVersion { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + Some(self.cmp(other)) + } +} + +impl Ord for ProtocolVersion { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match self.major().cmp(&other.major()) { + std::cmp::Ordering::Less => std::cmp::Ordering::Less, + std::cmp::Ordering::Greater => std::cmp::Ordering::Greater, + std::cmp::Ordering::Equal => { + // same major, compare minor + self.minor().cmp(&other.minor()) + } + } + } +} + +impl From<u16> for ProtocolVersion { + fn from(value: u16) -> Self { + Self::from_parts(((value & 0xff00) >> 8) as u8, (value & 0x00ff) as u8) + } +} + +impl TryFrom<u64> for ProtocolVersion { + type Error = &'static str; + + fn try_from(value: u64) -> Result<Self, Self::Error> { + if value & !0xffff != 0 { + return Err("only two least significant bits might be populated"); + } + + Ok((value as u16).into()) + } +} + +impl From<ProtocolVersion> for u16 { + fn from(value: ProtocolVersion) -> Self { + value.0 + } +} + +impl From<ProtocolVersion> for u64 { + fn from(value: ProtocolVersion) -> Self { + value.0 as u64 + } +} + +impl std::fmt::Display for ProtocolVersion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}", self.major(), self.minor()) + } +} + +#[cfg(test)] +mod tests { + use super::ProtocolVersion; + + #[test] + fn from_parts() { + let version = ProtocolVersion::from_parts(1, 37); + assert_eq!(version.major(), 1, "correct major"); + assert_eq!(version.minor(), 37, "correct minor"); + assert_eq!("1.37", &version.to_string(), "to_string"); + + assert_eq!(0x0125, Into::<u16>::into(version)); + assert_eq!(0x0125, Into::<u64>::into(version)); + } + + #[test] + fn from_u16() { + let version = ProtocolVersion::from(0x0125_u16); + assert_eq!("1.37", &version.to_string()); + } + + #[test] + fn from_u64() { + let version = ProtocolVersion::try_from(0x0125_u64).expect("must succeed"); + assert_eq!("1.37", &version.to_string()); + } + + /// This contains data in higher bits, which should fail. + #[test] + fn from_u64_fail() { + ProtocolVersion::try_from(0xaa0125_u64).expect_err("must fail"); + } + + #[test] + fn ord() { + let v0_37 = ProtocolVersion::from_parts(0, 37); + let v1_37 = ProtocolVersion::from_parts(1, 37); + let v1_40 = ProtocolVersion::from_parts(1, 40); + + assert!(v0_37 < v1_37); + assert!(v1_37 > v0_37); + assert!(v1_37 < v1_40); + assert!(v1_40 > v1_37); + assert!(v1_40 <= v1_40); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs index 4630a4f77067..58a48d1bdd25 100644 --- a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs +++ b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs @@ -9,11 +9,13 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::wire; +use super::ProtocolVersion; + static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc" static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio" pub static STDERR_LAST: u64 = 0x616c7473; // "alts" -/// Protocol version (1.37) -static PROTOCOL_VERSION: [u8; 8] = [37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + +static PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::from_parts(1, 37); /// Max length of a Nix setting name/value. In bytes. /// @@ -127,7 +129,7 @@ pub struct ClientSettings { /// FUTUREWORK: write serialization. pub async fn read_client_settings<R: AsyncReadExt + Unpin>( r: &mut R, - client_version: u64, + client_version: ProtocolVersion, ) -> std::io::Result<ClientSettings> { let keep_failed = wire::read_bool(r).await?; let keep_going = wire::read_bool(r).await?; @@ -148,7 +150,7 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>( let build_cores = wire::read_u64(r).await?; let use_substitutes = wire::read_bool(r).await?; let mut overrides = HashMap::new(); - if client_version >= 12 { + if client_version.minor() >= 12 { let num_overrides = wire::read_u64(r).await?; for _ in 0..num_overrides { let name = wire::read_string(r, 0..MAX_SETTING_SIZE).await?; @@ -173,8 +175,8 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>( /// Performs the initial handshake the server is sending to a connecting client. /// /// During the handshake, the client first send a magic u64, to which -/// the daemon needs to respond with another magic u64. Then, the -/// daemon retrieve the client version, and discard a bunch of now +/// the daemon needs to respond with another magic u64. +/// Then, the daemon retrieves the client version, and discards a bunch of now /// obsolete data. /// /// # Arguments @@ -186,12 +188,12 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>( /// /// # Return /// -/// The protocol version of a client encoded as a u64. +/// The protocol version of the client. pub async fn server_handshake_client<'a, RW: 'a>( mut conn: &'a mut RW, nix_version: &str, trusted: Trust, -) -> std::io::Result<u64> +) -> std::io::Result<ProtocolVersion> where &'a mut RW: AsyncReadExt + AsyncWriteExt + Unpin, { @@ -203,39 +205,41 @@ where )) } else { wire::write_u64(&mut conn, WORKER_MAGIC_2).await?; - conn.write_all(&PROTOCOL_VERSION).await?; + wire::write_u64(&mut conn, PROTOCOL_VERSION.into()).await?; conn.flush().await?; let client_version = wire::read_u64(&mut conn).await?; - if client_version < 0x10a { + // Parse into ProtocolVersion. + let client_version: ProtocolVersion = client_version + .try_into() + .map_err(|e| Error::new(ErrorKind::Unsupported, e))?; + if client_version < ProtocolVersion::from_parts(1, 10) { return Err(Error::new( ErrorKind::Unsupported, format!("The nix client version {} is too old", client_version), )); } - let protocol_minor = client_version & 0x00ff; - let _protocol_major = client_version & 0xff00; - if protocol_minor >= 14 { + if client_version.minor() >= 14 { // Obsolete CPU affinity. let read_affinity = wire::read_u64(&mut conn).await?; if read_affinity != 0 { let _cpu_affinity = wire::read_u64(&mut conn).await?; }; } - if protocol_minor >= 11 { + if client_version.minor() >= 11 { // Obsolete reserveSpace let _reserve_space = wire::read_u64(&mut conn).await?; } - if protocol_minor >= 33 { + if client_version.minor() >= 33 { // Nix version. We're plain lying, we're not Nix, but eh… // Setting it to the 2.3 lineage. Not 100% sure this is a // good idea. wire::write_bytes(&mut conn, nix_version).await?; conn.flush().await?; } - if protocol_minor >= 35 { + if client_version.minor() >= 35 { write_worker_trust_level(&mut conn, trusted).await?; } - Ok(protocol_minor) + Ok(client_version) } } @@ -290,10 +294,10 @@ mod tests { let mut test_conn = tokio_test::io::Builder::new() .read(&WORKER_MAGIC_1.to_le_bytes()) .write(&WORKER_MAGIC_2.to_le_bytes()) - .write(&PROTOCOL_VERSION) + .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) // Let's say the client is in sync with the daemon // protocol-wise - .read(&PROTOCOL_VERSION) + .read(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) // cpu affinity .read(&[0; 8]) // reservespace @@ -305,9 +309,11 @@ mod tests { // Trusted (1 == client trusted .write(&[1, 0, 0, 0, 0, 0, 0, 0]) .build(); - server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted) + let client_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted) .await .unwrap(); + + assert_eq!(client_version, PROTOCOL_VERSION) } #[tokio::test] @@ -329,7 +335,7 @@ mod tests { 00 00 00 00 00 00 00 00" ); let mut mock = Builder::new().read(&wire_bits).build(); - let settings = read_client_settings(&mut mock, 21) + let settings = read_client_settings(&mut mock, ProtocolVersion::from_parts(1, 21)) .await .expect("should parse"); let expected = ClientSettings { @@ -380,7 +386,7 @@ mod tests { 72 72 65 00 00 00 00 00" ); let mut mock = Builder::new().read(&wire_bits).build(); - let settings = read_client_settings(&mut mock, 21) + let settings = read_client_settings(&mut mock, ProtocolVersion::from_parts(1, 21)) .await .expect("should parse"); let overrides = HashMap::from([ |