diff options
Diffstat (limited to 'tvix/nix-compat/src/nix_daemon/handler.rs')
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/handler.rs | 277 |
1 files changed, 277 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/nix_daemon/handler.rs b/tvix/nix-compat/src/nix_daemon/handler.rs new file mode 100644 index 000000000000..a61a4810c807 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/handler.rs @@ -0,0 +1,277 @@ +use std::{future::Future, sync::Arc}; + +use bytes::Bytes; +use tokio::{ + io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, + sync::Mutex, +}; +use tracing::debug; + +use super::{ + types::QueryValidPaths, + worker_protocol::{server_handshake_client, ClientSettings, Operation, Trust, STDERR_LAST}, + NixDaemonIO, +}; + +use crate::{ + store_path::StorePath, + wire::{ + de::{NixRead, NixReader}, + ser::{NixSerialize, NixWrite, NixWriter, NixWriterBuilder}, + ProtocolVersion, + }, +}; + +use crate::{nix_daemon::types::NixError, worker_protocol::STDERR_ERROR}; + +/// Handles a single connection with a nix client. +/// +/// As part of its [`initialization`] it performs the handshake with the client +/// and determines the [ProtocolVersion] and [ClientSettings] to use for the remainder of the session. +/// +/// Once initialized, [`handle_client`] needs to be called to handle the rest of the session, +/// it delegates all operation handling to an instance of [NixDaemonIO]. +/// +/// [`initialization`]: NixDaemon::initialize +#[allow(dead_code)] +pub struct NixDaemon<IO, R, W> { + io: Arc<IO>, + protocol_version: ProtocolVersion, + client_settings: ClientSettings, + reader: NixReader<R>, + writer: Arc<Mutex<NixWriter<W>>>, +} + +impl<IO, R, W> NixDaemon<IO, R, W> +where + IO: NixDaemonIO + Sync + Send, +{ + pub fn new( + io: Arc<IO>, + protocol_version: ProtocolVersion, + client_settings: ClientSettings, + reader: NixReader<R>, + writer: NixWriter<W>, + ) -> Self { + Self { + io, + protocol_version, + client_settings, + reader, + writer: Arc::new(Mutex::new(writer)), + } + } +} + +impl<IO, RW> NixDaemon<IO, ReadHalf<RW>, WriteHalf<RW>> +where + RW: AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static, + IO: NixDaemonIO + Sync + Send, +{ + /// Async constructor for NixDaemon. + /// + /// Performs the initial handshake with the client and retrieves the client's preferred + /// settings. + /// + /// The resulting daemon can handle the client session by calling [NixDaemon::handle_client]. + pub async fn initialize(io: Arc<IO>, mut connection: RW) -> Result<Self, std::io::Error> + where + RW: AsyncReadExt + AsyncWriteExt + Send + Unpin, + { + let protocol_version = + server_handshake_client(&mut connection, "2.18.2", Trust::Trusted).await?; + + connection.write_u64_le(STDERR_LAST).await?; + let (reader, writer) = split(connection); + let mut reader = NixReader::builder() + .set_version(protocol_version) + .build(reader); + let mut writer = NixWriterBuilder::default() + .set_version(protocol_version) + .build(writer); + + // The first op is always SetOptions + let operation: Operation = reader.read_value().await?; + if operation != Operation::SetOptions { + return Err(std::io::Error::other( + "Expected SetOptions operation, but got {operation}", + )); + } + let client_settings: ClientSettings = reader.read_value().await?; + writer.write_number(STDERR_LAST).await?; + writer.flush().await?; + + Ok(Self::new( + io, + protocol_version, + client_settings, + reader, + writer, + )) + } + + /// Main client connection loop, reads client's requests and responds to them accordingly. + pub async fn handle_client(&mut self) -> Result<(), std::io::Error> { + let io = self.io.clone(); + loop { + let op_code = self.reader.read_number().await?; + match TryInto::<Operation>::try_into(op_code) { + // Note: please keep operations sorted in ascending order of their numerical op number. + Ok(operation) => match operation { + Operation::IsValidPath => { + let path: StorePath<String> = self.reader.read_value().await?; + self.handle(io.is_valid_path(&path)).await? + } + // Note this operation does not currently delegate to NixDaemonIO, + // The general idea is that we will pass relevant ClientSettings + // into individual NixDaemonIO method calls if the need arises. + // For now we just store the settings in the NixDaemon for future use. + Operation::SetOptions => { + self.client_settings = self.reader.read_value().await?; + self.handle(async { Ok(()) }).await? + } + Operation::QueryPathInfo => { + let path: StorePath<String> = self.reader.read_value().await?; + self.handle(io.query_path_info(&path)).await? + } + Operation::QueryPathFromHashPart => { + let hash: Bytes = self.reader.read_value().await?; + self.handle(io.query_path_from_hash_part(&hash)).await? + } + Operation::QueryValidPaths => { + let query: QueryValidPaths = self.reader.read_value().await?; + self.handle(io.query_valid_paths(&query)).await? + } + Operation::QueryValidDerivers => { + let path: StorePath<String> = self.reader.read_value().await?; + self.handle(io.query_valid_derivers(&path)).await? + } + _ => { + return Err(std::io::Error::other(format!( + "Operation {operation:?} is not implemented" + ))); + } + }, + _ => { + return Err(std::io::Error::other(format!( + "Unknown operation code received: {op_code}" + ))); + } + } + } + } + + /// Handles the operation and sends the response or error to the client. + /// + /// As per nix daemon protocol, after sending the request, the client expects zero or more + /// log lines/activities followed by either + /// * STDERR_LAST and the response bytes + /// * STDERR_ERROR and the error + /// + /// This is a helper method, awaiting on the passed in future and then + /// handling log lines/activities as described above. + async fn handle<T>( + &mut self, + future: impl Future<Output = std::io::Result<T>>, + ) -> Result<(), std::io::Error> + where + T: NixSerialize + Send, + { + let result = future.await; + let mut writer = self.writer.lock().await; + + match result { + Ok(r) => { + // the protocol requires that we first indicate that we are done sending logs + // by sending STDERR_LAST and then the response. + writer.write_number(STDERR_LAST).await?; + writer.write_value(&r).await?; + writer.flush().await + } + Err(e) => { + debug!(err = ?e, "IO error"); + writer.write_number(STDERR_ERROR).await?; + writer.write_value(&NixError::new(format!("{e:?}"))).await?; + writer.flush().await + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{io::Result, sync::Arc}; + + use tokio::io::AsyncWriteExt; + + use crate::{ + nix_daemon::types::UnkeyedValidPathInfo, + wire::ProtocolVersion, + worker_protocol::{ClientSettings, WORKER_MAGIC_1, WORKER_MAGIC_2}, + }; + + struct MockDaemonIO {} + + impl NixDaemonIO for MockDaemonIO { + async fn query_path_info( + &self, + _path: &crate::store_path::StorePath<String>, + ) -> Result<Option<UnkeyedValidPathInfo>> { + Ok(None) + } + + async fn query_path_from_hash_part( + &self, + _hash: &[u8], + ) -> Result<Option<UnkeyedValidPathInfo>> { + Ok(None) + } + } + + #[tokio::test] + async fn test_daemon_initialization() { + let mut builder = tokio_test::io::Builder::new(); + let test_conn = builder + .read(&WORKER_MAGIC_1.to_le_bytes()) + .write(&WORKER_MAGIC_2.to_le_bytes()) + // Our version is 1.37 + .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // The client's versin is 1.35 + .read(&[35, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // cpu affinity + .read(&[0; 8]) + // reservespace + .read(&[0; 8]) + // version (size) + .write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // version (data == 2.18.2 + padding) + .write(&[50, 46, 49, 56, 46, 50, 0, 0]) + // Trusted (1 == client trusted) + .write(&[1, 0, 0, 0, 0, 0, 0, 0]) + // STDERR_LAST + .write(&[115, 116, 108, 97, 0, 0, 0, 0]); + + let mut bytes = Vec::new(); + let mut writer = NixWriter::new(&mut bytes); + writer + .write_value(&ClientSettings::default()) + .await + .unwrap(); + writer.flush().await.unwrap(); + + let test_conn = test_conn + // SetOptions op + .read(&[19, 0, 0, 0, 0, 0, 0, 0]) + .read(&bytes) + // STDERR_LAST + .write(&[115, 116, 108, 97, 0, 0, 0, 0]) + .build(); + + let daemon = NixDaemon::initialize(Arc::new(MockDaemonIO {}), test_conn) + .await + .unwrap(); + assert_eq!(daemon.client_settings, ClientSettings::default()); + assert_eq!(daemon.protocol_version, ProtocolVersion::from_parts(1, 35)); + } +} |