about summary refs log tree commit diff
diff options
context:
space:
mode:
authoredef <edef@edef.eu>2024-04-30T07·15+0000
committeredef <edef@edef.eu>2024-04-30T09·55+0000
commit095f715a8045933159cb7daf3302246b9ca50658 (patch)
treed66389e63c9672974f210d0e0691aa16c51c720b
parentb3305ea6e26bef913cfa1a1d7b5cb0c13392ed4c (diff)
refactor(nix-compat/wire): drop primitive functions r/8042
These may as well be inlined, and hardly need tests, since they just
alias AsyncReadExt::read_u64_le / AsyncWriteExt::write_u64_le.

Boolean reading is worth making explicit, since callers may differ on
how they want to handle values other than 0 and 1.

Boolean writing simplifies to `.write_u64_le(x as u64)`, which is also
fine to inline.

Change-Id: Ief9722fe886688693feb924ff0306b5bc68dd7a2
Reviewed-on: https://cl.tvl.fyi/c/depot/+/11549
Reviewed-by: flokli <flokli@flokli.de>
Tested-by: BuildkiteCI
-rw-r--r--tvix/nix-compat/src/nix_daemon/worker_protocol.rs46
-rw-r--r--tvix/nix-compat/src/wire/bytes/mod.rs8
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader/mod.rs6
-rw-r--r--tvix/nix-compat/src/wire/mod.rs3
-rw-r--r--tvix/nix-compat/src/wire/primitive.rs74
-rw-r--r--users/picnoir/tvix-daemon/src/main.rs8
6 files changed, 33 insertions, 112 deletions
diff --git a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
index 58a48d1bdd25..9ffceffced1b 100644
--- a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
+++ b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
@@ -131,27 +131,27 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
     r: &mut R,
     client_version: ProtocolVersion,
 ) -> std::io::Result<ClientSettings> {
-    let keep_failed = wire::read_bool(r).await?;
-    let keep_going = wire::read_bool(r).await?;
-    let try_fallback = wire::read_bool(r).await?;
-    let verbosity_uint = wire::read_u64(r).await?;
+    let keep_failed = r.read_u64_le().await? != 0;
+    let keep_going = r.read_u64_le().await? != 0;
+    let try_fallback = r.read_u64_le().await? != 0;
+    let verbosity_uint = r.read_u64_le().await?;
     let verbosity = Verbosity::from_u64(verbosity_uint).ok_or_else(|| {
         Error::new(
             ErrorKind::InvalidData,
             format!("Can't convert integer {} to verbosity", verbosity_uint),
         )
     })?;
-    let max_build_jobs = wire::read_u64(r).await?;
-    let max_silent_time = wire::read_u64(r).await?;
-    _ = wire::read_u64(r).await?; // obsolete useBuildHook
-    let verbose_build = wire::read_bool(r).await?;
-    _ = wire::read_u64(r).await?; // obsolete logType
-    _ = wire::read_u64(r).await?; // obsolete printBuildTrace
-    let build_cores = wire::read_u64(r).await?;
-    let use_substitutes = wire::read_bool(r).await?;
+    let max_build_jobs = r.read_u64_le().await?;
+    let max_silent_time = r.read_u64_le().await?;
+    _ = r.read_u64_le().await?; // obsolete useBuildHook
+    let verbose_build = r.read_u64_le().await? != 0;
+    _ = r.read_u64_le().await?; // obsolete logType
+    _ = r.read_u64_le().await?; // obsolete printBuildTrace
+    let build_cores = r.read_u64_le().await?;
+    let use_substitutes = r.read_u64_le().await? != 0;
     let mut overrides = HashMap::new();
     if client_version.minor() >= 12 {
-        let num_overrides = wire::read_u64(r).await?;
+        let num_overrides = r.read_u64_le().await?;
         for _ in 0..num_overrides {
             let name = wire::read_string(r, 0..MAX_SETTING_SIZE).await?;
             let value = wire::read_string(r, 0..MAX_SETTING_SIZE).await?;
@@ -197,17 +197,17 @@ pub async fn server_handshake_client<'a, RW: 'a>(
 where
     &'a mut RW: AsyncReadExt + AsyncWriteExt + Unpin,
 {
-    let worker_magic_1 = wire::read_u64(&mut conn).await?;
+    let worker_magic_1 = conn.read_u64_le().await?;
     if worker_magic_1 != WORKER_MAGIC_1 {
         Err(std::io::Error::new(
             ErrorKind::InvalidData,
             format!("Incorrect worker magic number received: {}", worker_magic_1),
         ))
     } else {
-        wire::write_u64(&mut conn, WORKER_MAGIC_2).await?;
-        wire::write_u64(&mut conn, PROTOCOL_VERSION.into()).await?;
+        conn.write_u64_le(WORKER_MAGIC_2).await?;
+        conn.write_u64_le(PROTOCOL_VERSION.into()).await?;
         conn.flush().await?;
-        let client_version = wire::read_u64(&mut conn).await?;
+        let client_version = conn.read_u64_le().await?;
         // Parse into ProtocolVersion.
         let client_version: ProtocolVersion = client_version
             .try_into()
@@ -220,14 +220,14 @@ where
         }
         if client_version.minor() >= 14 {
             // Obsolete CPU affinity.
-            let read_affinity = wire::read_u64(&mut conn).await?;
+            let read_affinity = conn.read_u64_le().await?;
             if read_affinity != 0 {
-                let _cpu_affinity = wire::read_u64(&mut conn).await?;
+                let _cpu_affinity = conn.read_u64_le().await?;
             };
         }
         if client_version.minor() >= 11 {
             // Obsolete reserveSpace
-            let _reserve_space = wire::read_u64(&mut conn).await?;
+            let _reserve_space = conn.read_u64_le().await?;
         }
         if client_version.minor() >= 33 {
             // Nix version. We're plain lying, we're not Nix, but eh…
@@ -245,7 +245,7 @@ where
 
 /// Read a worker [Operation] from the wire.
 pub async fn read_op<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<Operation> {
-    let op_number = wire::read_u64(r).await?;
+    let op_number = r.read_u64_le().await?;
     Operation::from_u64(op_number).ok_or(Error::new(
         ErrorKind::InvalidData,
         format!("Invalid OP number {}", op_number),
@@ -278,8 +278,8 @@ where
     W: AsyncReadExt + AsyncWriteExt + Unpin,
 {
     match t {
-        Trust::Trusted => wire::write_u64(conn, 1).await,
-        Trust::NotTrusted => wire::write_u64(conn, 2).await,
+        Trust::Trusted => conn.write_u64_le(1).await,
+        Trust::NotTrusted => conn.write_u64_le(2).await,
     }
 }
 
diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs
index 031d969e287f..fc777bafe20f 100644
--- a/tvix/nix-compat/src/wire/bytes/mod.rs
+++ b/tvix/nix-compat/src/wire/bytes/mod.rs
@@ -9,8 +9,6 @@ pub use reader::BytesReader;
 mod writer;
 pub use writer::BytesWriter;
 
-use super::primitive;
-
 /// 8 null bytes, used to write out padding.
 const EMPTY_BYTES: &[u8; 8] = &[0u8; 8];
 
@@ -41,7 +39,7 @@ where
     S: RangeBounds<u64>,
 {
     // read the length field
-    let len = primitive::read_u64(r).await?;
+    let len = r.read_u64_le().await?;
 
     if !allowed_size.contains(&len) {
         return Err(std::io::Error::new(
@@ -52,7 +50,7 @@ where
 
     // calculate the total length, including padding.
     // byte packets are padded to 8 byte blocks each.
-    let padded_len = padding_len(len) as u64 + (len as u64);
+    let padded_len = padding_len(len) as u64 + len;
     let mut limited_reader = r.take(padded_len);
 
     let mut buf = Vec::new();
@@ -105,7 +103,7 @@ pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>(
     b: B,
 ) -> std::io::Result<()> {
     // write the size packet.
-    primitive::write_u64(w, b.as_ref().len() as u64).await?;
+    w.write_u64_le(b.as_ref().len() as u64).await?;
 
     // write the payload
     w.write_all(b.as_ref()).await?;
diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs
index ef59e9c160e4..50398d9b9e40 100644
--- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs
@@ -5,9 +5,7 @@ use std::{
     pin::Pin,
     task::{self, ready, Poll},
 };
-use tokio::io::{AsyncRead, ReadBuf};
-
-use crate::wire::read_u64;
+use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
 
 use trailer::{read_trailer, ReadTrailer, Trailer};
 mod trailer;
@@ -52,7 +50,7 @@ where
 {
     /// Constructs a new BytesReader, using the underlying passed reader.
     pub async fn new<S: RangeBounds<u64>>(mut reader: R, allowed_size: S) -> io::Result<Self> {
-        let size = read_u64(&mut reader).await?;
+        let size = reader.read_u64_le().await?;
 
         if !allowed_size.contains(&size) {
             return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid size"));
diff --git a/tvix/nix-compat/src/wire/mod.rs b/tvix/nix-compat/src/wire/mod.rs
index 65c053d58e21..a197e3a1f451 100644
--- a/tvix/nix-compat/src/wire/mod.rs
+++ b/tvix/nix-compat/src/wire/mod.rs
@@ -3,6 +3,3 @@
 
 mod bytes;
 pub use bytes::*;
-
-mod primitive;
-pub use primitive::*;
diff --git a/tvix/nix-compat/src/wire/primitive.rs b/tvix/nix-compat/src/wire/primitive.rs
deleted file mode 100644
index ee0f5fc4279d..000000000000
--- a/tvix/nix-compat/src/wire/primitive.rs
+++ /dev/null
@@ -1,74 +0,0 @@
-// SPDX-FileCopyrightText: 2023 embr <git@liclac.eu>
-//
-// SPDX-License-Identifier: EUPL-1.2
-
-use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
-
-#[allow(dead_code)]
-/// Read a u64 from the AsyncRead (little endian).
-pub async fn read_u64<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<u64> {
-    r.read_u64_le().await
-}
-
-/// Write a u64 to the AsyncWrite (little endian).
-pub async fn write_u64<W: AsyncWrite + Unpin>(w: &mut W, v: u64) -> std::io::Result<()> {
-    w.write_u64_le(v).await
-}
-
-#[allow(dead_code)]
-/// Read a boolean from the AsyncRead, encoded as u64 (>0 is true).
-pub async fn read_bool<R: AsyncRead + Unpin>(r: &mut R) -> std::io::Result<bool> {
-    Ok(read_u64(r).await? > 0)
-}
-
-#[allow(dead_code)]
-/// Write a boolean to the AsyncWrite, encoded as u64 (>0 is true).
-pub async fn write_bool<W: AsyncWrite + Unpin>(w: &mut W, v: bool) -> std::io::Result<()> {
-    write_u64(w, if v { 1u64 } else { 0u64 }).await
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use tokio_test::io::Builder;
-
-    // Integers.
-    #[tokio::test]
-    async fn test_read_u64() {
-        let mut mock = Builder::new().read(&1234567890u64.to_le_bytes()).build();
-        assert_eq!(1234567890u64, read_u64(&mut mock).await.unwrap());
-    }
-    #[tokio::test]
-    async fn test_write_u64() {
-        let mut mock = Builder::new().write(&1234567890u64.to_le_bytes()).build();
-        write_u64(&mut mock, 1234567890).await.unwrap();
-    }
-
-    // Booleans.
-    #[tokio::test]
-    async fn test_read_bool_0() {
-        let mut mock = Builder::new().read(&0u64.to_le_bytes()).build();
-        assert!(!read_bool(&mut mock).await.unwrap());
-    }
-    #[tokio::test]
-    async fn test_read_bool_1() {
-        let mut mock = Builder::new().read(&1u64.to_le_bytes()).build();
-        assert!(read_bool(&mut mock).await.unwrap());
-    }
-    #[tokio::test]
-    async fn test_read_bool_2() {
-        let mut mock = Builder::new().read(&2u64.to_le_bytes()).build();
-        assert!(read_bool(&mut mock).await.unwrap());
-    }
-
-    #[tokio::test]
-    async fn test_write_bool_false() {
-        let mut mock = Builder::new().write(&0u64.to_le_bytes()).build();
-        write_bool(&mut mock, false).await.unwrap();
-    }
-    #[tokio::test]
-    async fn test_write_bool_true() {
-        let mut mock = Builder::new().write(&1u64.to_le_bytes()).build();
-        write_bool(&mut mock, true).await.unwrap();
-    }
-}
diff --git a/users/picnoir/tvix-daemon/src/main.rs b/users/picnoir/tvix-daemon/src/main.rs
index 102067fcf7d4..dc49b209e009 100644
--- a/users/picnoir/tvix-daemon/src/main.rs
+++ b/users/picnoir/tvix-daemon/src/main.rs
@@ -4,7 +4,7 @@ use tokio_listener::{self, SystemOptions, UserOptions};
 use tracing::{debug, error, info, instrument, Level};
 
 use nix_compat::worker_protocol::{self, server_handshake_client, ClientSettings, Trust};
-use nix_compat::{wire, ProtocolVersion};
+use nix_compat::ProtocolVersion;
 
 #[derive(Parser, Debug)]
 struct Cli {
@@ -78,7 +78,9 @@ where
             // TODO: implement logging. For now, we'll just send
             // STDERR_LAST, which is good enough to get Nix respond to
             // us.
-            wire::write_u64(&mut client_connection.conn, worker_protocol::STDERR_LAST)
+            client_connection
+                .conn
+                .write_u64_le(worker_protocol::STDERR_LAST)
                 .await
                 .unwrap();
             loop {
@@ -109,6 +111,6 @@ where
     let settings = worker_protocol::read_client_settings(&mut conn.conn, conn.version).await?;
     // The client expects us to send some logs when we're processing
     // the settings. Sending STDERR_LAST signal we're done processing.
-    wire::write_u64(&mut conn.conn, worker_protocol::STDERR_LAST).await?;
+    conn.conn.write_u64_le(worker_protocol::STDERR_LAST).await?;
     Ok(settings)
 }