diff options
Diffstat (limited to 'tvix/nix-compat/src/nix_daemon/worker_protocol.rs')
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/worker_protocol.rs | 50 |
1 files changed, 28 insertions, 22 deletions
diff --git a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs index 4630a4f77067..58a48d1bdd25 100644 --- a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs +++ b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs @@ -9,11 +9,13 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::wire; +use super::ProtocolVersion; + static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc" static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio" pub static STDERR_LAST: u64 = 0x616c7473; // "alts" -/// Protocol version (1.37) -static PROTOCOL_VERSION: [u8; 8] = [37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + +static PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::from_parts(1, 37); /// Max length of a Nix setting name/value. In bytes. /// @@ -127,7 +129,7 @@ pub struct ClientSettings { /// FUTUREWORK: write serialization. pub async fn read_client_settings<R: AsyncReadExt + Unpin>( r: &mut R, - client_version: u64, + client_version: ProtocolVersion, ) -> std::io::Result<ClientSettings> { let keep_failed = wire::read_bool(r).await?; let keep_going = wire::read_bool(r).await?; @@ -148,7 +150,7 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>( let build_cores = wire::read_u64(r).await?; let use_substitutes = wire::read_bool(r).await?; let mut overrides = HashMap::new(); - if client_version >= 12 { + if client_version.minor() >= 12 { let num_overrides = wire::read_u64(r).await?; for _ in 0..num_overrides { let name = wire::read_string(r, 0..MAX_SETTING_SIZE).await?; @@ -173,8 +175,8 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>( /// Performs the initial handshake the server is sending to a connecting client. /// /// During the handshake, the client first send a magic u64, to which -/// the daemon needs to respond with another magic u64. Then, the -/// daemon retrieve the client version, and discard a bunch of now +/// the daemon needs to respond with another magic u64. +/// Then, the daemon retrieves the client version, and discards a bunch of now /// obsolete data. /// /// # Arguments @@ -186,12 +188,12 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>( /// /// # Return /// -/// The protocol version of a client encoded as a u64. +/// The protocol version of the client. pub async fn server_handshake_client<'a, RW: 'a>( mut conn: &'a mut RW, nix_version: &str, trusted: Trust, -) -> std::io::Result<u64> +) -> std::io::Result<ProtocolVersion> where &'a mut RW: AsyncReadExt + AsyncWriteExt + Unpin, { @@ -203,39 +205,41 @@ where )) } else { wire::write_u64(&mut conn, WORKER_MAGIC_2).await?; - conn.write_all(&PROTOCOL_VERSION).await?; + wire::write_u64(&mut conn, PROTOCOL_VERSION.into()).await?; conn.flush().await?; let client_version = wire::read_u64(&mut conn).await?; - if client_version < 0x10a { + // Parse into ProtocolVersion. + let client_version: ProtocolVersion = client_version + .try_into() + .map_err(|e| Error::new(ErrorKind::Unsupported, e))?; + if client_version < ProtocolVersion::from_parts(1, 10) { return Err(Error::new( ErrorKind::Unsupported, format!("The nix client version {} is too old", client_version), )); } - let protocol_minor = client_version & 0x00ff; - let _protocol_major = client_version & 0xff00; - if protocol_minor >= 14 { + if client_version.minor() >= 14 { // Obsolete CPU affinity. let read_affinity = wire::read_u64(&mut conn).await?; if read_affinity != 0 { let _cpu_affinity = wire::read_u64(&mut conn).await?; }; } - if protocol_minor >= 11 { + if client_version.minor() >= 11 { // Obsolete reserveSpace let _reserve_space = wire::read_u64(&mut conn).await?; } - if protocol_minor >= 33 { + if client_version.minor() >= 33 { // Nix version. We're plain lying, we're not Nix, but eh… // Setting it to the 2.3 lineage. Not 100% sure this is a // good idea. wire::write_bytes(&mut conn, nix_version).await?; conn.flush().await?; } - if protocol_minor >= 35 { + if client_version.minor() >= 35 { write_worker_trust_level(&mut conn, trusted).await?; } - Ok(protocol_minor) + Ok(client_version) } } @@ -290,10 +294,10 @@ mod tests { let mut test_conn = tokio_test::io::Builder::new() .read(&WORKER_MAGIC_1.to_le_bytes()) .write(&WORKER_MAGIC_2.to_le_bytes()) - .write(&PROTOCOL_VERSION) + .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) // Let's say the client is in sync with the daemon // protocol-wise - .read(&PROTOCOL_VERSION) + .read(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) // cpu affinity .read(&[0; 8]) // reservespace @@ -305,9 +309,11 @@ mod tests { // Trusted (1 == client trusted .write(&[1, 0, 0, 0, 0, 0, 0, 0]) .build(); - server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted) + let client_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted) .await .unwrap(); + + assert_eq!(client_version, PROTOCOL_VERSION) } #[tokio::test] @@ -329,7 +335,7 @@ mod tests { 00 00 00 00 00 00 00 00" ); let mut mock = Builder::new().read(&wire_bits).build(); - let settings = read_client_settings(&mut mock, 21) + let settings = read_client_settings(&mut mock, ProtocolVersion::from_parts(1, 21)) .await .expect("should parse"); let expected = ClientSettings { @@ -380,7 +386,7 @@ mod tests { 72 72 65 00 00 00 00 00" ); let mut mock = Builder::new().read(&wire_bits).build(); - let settings = read_client_settings(&mut mock, 21) + let settings = read_client_settings(&mut mock, ProtocolVersion::from_parts(1, 21)) .await .expect("should parse"); let overrides = HashMap::from([ |