use std::{future::Future, sync::Arc};
use bytes::Bytes;
use tokio::{
io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
sync::Mutex,
};
use tracing::debug;
use super::{
types::QueryValidPaths,
worker_protocol::{server_handshake_client, ClientSettings, Operation, Trust, STDERR_LAST},
NixDaemonIO,
};
use crate::{
store_path::StorePath,
wire::{
de::{NixRead, NixReader},
ser::{NixSerialize, NixWrite, NixWriter, NixWriterBuilder},
ProtocolVersion,
},
};
use crate::{nix_daemon::types::NixError, worker_protocol::STDERR_ERROR};
/// Handles a single connection with a nix client.
///
/// As part of its [`initialization`] it performs the handshake with the client
/// and determines the [ProtocolVersion] and [ClientSettings] to use for the remainder of the session.
///
/// Once initialized, [`handle_client`] needs to be called to handle the rest of the session;
/// it delegates all operation handling to an instance of [NixDaemonIO].
///
/// Type parameters:
/// * `IO` - the [NixDaemonIO] implementation that operations are delegated to.
/// * `R` - the underlying reader wrapped by the [NixReader].
/// * `W` - the underlying writer wrapped by the [NixWriter].
///
/// [`initialization`]: NixDaemon::initialize
/// [`handle_client`]: NixDaemon::handle_client
// NOTE(review): the reason for allowing dead code is not stated here; presumably some
// fields (e.g. `protocol_version`) are only stored for future use -- confirm and document.
#[allow(dead_code)]
pub struct NixDaemon<IO, R, W> {
/// Backend that performs the actual work for each supported operation.
io: Arc<IO>,
/// Protocol version returned by the handshake with the client.
protocol_version: ProtocolVersion,
/// Settings most recently sent by the client through a `SetOptions` operation.
client_settings: ClientSettings,
/// Decoder for the requests sent by the client.
reader: NixReader<R>,
/// Encoder for the responses sent to the client; locked for the duration of each response.
// NOTE(review): presumably shared behind `Arc<Mutex<..>>` so that other tasks can write
// to the client later on -- nothing in this file shares it yet; confirm.
writer: Arc<Mutex<NixWriter<W>>>,
}
impl<IO, R, W> NixDaemon<IO, R, W>
where
IO: NixDaemonIO + Sync + Send,
{
pub fn new(
io: Arc<IO>,
protocol_version: ProtocolVersion,
client_settings: ClientSettings,
reader: NixReader<R>,
writer: NixWriter<W>,
) -> Self {
Self {
io,
protocol_version,
client_settings,
reader,
writer: Arc::new(Mutex::new(writer)),
}
}
}
impl<IO, RW> NixDaemon<IO, ReadHalf<RW>, WriteHalf<RW>>
where
RW: AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static,
IO: NixDaemonIO + Sync + Send,
{
/// Async constructor for NixDaemon.
///
/// Performs the initial handshake with the client and retrieves the client's preferred
/// settings.
///
/// The resulting daemon can handle the client session by calling [NixDaemon::handle_client].
pub async fn initialize(io: Arc<IO>, mut connection: RW) -> Result<Self, std::io::Error>
where
RW: AsyncReadExt + AsyncWriteExt + Send + Unpin,
{
let protocol_version =
server_handshake_client(&mut connection, "2.18.2", Trust::Trusted).await?;
connection.write_u64_le(STDERR_LAST).await?;
let (reader, writer) = split(connection);
let mut reader = NixReader::builder()
.set_version(protocol_version)
.build(reader);
let mut writer = NixWriterBuilder::default()
.set_version(protocol_version)
.build(writer);
// The first op is always SetOptions
let operation: Operation = reader.read_value().await?;
if operation != Operation::SetOptions {
return Err(std::io::Error::other(
"Expected SetOptions operation, but got {operation}",
));
}
let client_settings: ClientSettings = reader.read_value().await?;
writer.write_number(STDERR_LAST).await?;
writer.flush().await?;
Ok(Self::new(
io,
protocol_version,
client_settings,
reader,
writer,
))
}
/// Main client connection loop, reads client's requests and responds to them accordingly.
pub async fn handle_client(&mut self) -> Result<(), std::io::Error> {
let io = self.io.clone();
loop {
let op_code = self.reader.read_number().await?;
match TryInto::<Operation>::try_into(op_code) {
// Note: please keep operations sorted in ascending order of their numerical op number.
Ok(operation) => match operation {
Operation::IsValidPath => {
let path: StorePath<String> = self.reader.read_value().await?;
self.handle(io.is_valid_path(&path)).await?
}
// Note this operation does not currently delegate to NixDaemonIO,
// The general idea is that we will pass relevant ClientSettings
// into individual NixDaemonIO method calls if the need arises.
// For now we just store the settings in the NixDaemon for future use.
Operation::SetOptions => {
self.client_settings = self.reader.read_value().await?;
self.handle(async { Ok(()) }).await?
}
Operation::QueryPathInfo => {
let path: StorePath<String> = self.reader.read_value().await?;
self.handle(io.query_path_info(&path)).await?
}
Operation::QueryPathFromHashPart => {
let hash: Bytes = self.reader.read_value().await?;
self.handle(io.query_path_from_hash_part(&hash)).await?
}
Operation::QueryValidPaths => {
let query: QueryValidPaths = self.reader.read_value().await?;
self.handle(io.query_valid_paths(&query)).await?
}
Operation::QueryValidDerivers => {
let path: StorePath<String> = self.reader.read_value().await?;
self.handle(io.query_valid_derivers(&path)).await?
}
_ => {
return Err(std::io::Error::other(format!(
"Operation {operation:?} is not implemented"
)));
}
},
_ => {
return Err(std::io::Error::other(format!(
"Unknown operation code received: {op_code}"
)));
}
}
}
}
/// Handles the operation and sends the response or error to the client.
///
/// As per nix daemon protocol, after sending the request, the client expects zero or more
/// log lines/activities followed by either
/// * STDERR_LAST and the response bytes
/// * STDERR_ERROR and the error
///
/// This is a helper method, awaiting on the passed in future and then
/// handling log lines/activities as described above.
async fn handle<T>(
&mut self,
future: impl Future<Output = std::io::Result<T>>,
) -> Result<(), std::io::Error>
where
T: NixSerialize + Send,
{
let result = future.await;
let mut writer = self.writer.lock().await;
match result {
Ok(r) => {
// the protocol requires that we first indicate that we are done sending logs
// by sending STDERR_LAST and then the response.
writer.write_number(STDERR_LAST).await?;
writer.write_value(&r).await?;
writer.flush().await
}
Err(e) => {
debug!(err = ?e, "IO error");
writer.write_number(STDERR_ERROR).await?;
writer.write_value(&NixError::new(format!("{e:?}"))).await?;
writer.flush().await
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::{io::Result, sync::Arc};
use tokio::io::AsyncWriteExt;
use crate::{
nix_daemon::types::UnkeyedValidPathInfo,
wire::ProtocolVersion,
worker_protocol::{ClientSettings, WORKER_MAGIC_1, WORKER_MAGIC_2},
};
// Minimal NixDaemonIO implementation: both lookups always answer "not found".
// Only these two methods are implemented here; the remaining trait methods
// presumably come with default implementations -- see the NixDaemonIO definition.
struct MockDaemonIO {}
impl NixDaemonIO for MockDaemonIO {
async fn query_path_info(
&self,
_path: &crate::store_path::StorePath<String>,
) -> Result<Option<UnkeyedValidPathInfo>> {
Ok(None)
}
async fn query_path_from_hash_part(
&self,
_hash: &[u8],
) -> Result<Option<UnkeyedValidPathInfo>> {
Ok(None)
}
}
// Drives NixDaemon::initialize against a scripted connection and checks the
// negotiated protocol version and the client settings that were read.
#[tokio::test]
async fn test_daemon_initialization() {
// In the script below, `.read(..)` supplies bytes coming from the client and
// `.write(..)` states the bytes the daemon is expected to send, in this exact order.
let mut builder = tokio_test::io::Builder::new();
let test_conn = builder
.read(&WORKER_MAGIC_1.to_le_bytes())
.write(&WORKER_MAGIC_2.to_le_bytes())
// Our version is 1.37
.write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
// The client's version is 1.35
.read(&[35, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
// cpu affinity
.read(&[0; 8])
// reservespace
.read(&[0; 8])
// version (size)
.write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
// version (data == 2.18.2 + padding)
.write(&[50, 46, 49, 56, 46, 50, 0, 0])
// Trusted (1 == client trusted)
.write(&[1, 0, 0, 0, 0, 0, 0, 0])
// STDERR_LAST
.write(&[115, 116, 108, 97, 0, 0, 0, 0]);
// Serialize the default ClientSettings into a buffer so it can be replayed
// as the payload of the client's SetOptions request.
let mut bytes = Vec::new();
let mut writer = NixWriter::new(&mut bytes);
writer
.write_value(&ClientSettings::default())
.await
.unwrap();
writer.flush().await.unwrap();
let test_conn = test_conn
// SetOptions op
.read(&[19, 0, 0, 0, 0, 0, 0, 0])
.read(&bytes)
// STDERR_LAST
.write(&[115, 116, 108, 97, 0, 0, 0, 0])
.build();
let daemon = NixDaemon::initialize(Arc::new(MockDaemonIO {}), test_conn)
.await
.unwrap();
assert_eq!(daemon.client_settings, ClientSettings::default());
// The session uses the client's version (1.35), not the 1.37 offered by the daemon.
assert_eq!(daemon.protocol_version, ProtocolVersion::from_parts(1, 35));
}
}