about summary refs log tree commit diff
path: root/tvix/nix-compat/src/nix_daemon/handler.rs
use std::{future::Future, ops::DerefMut, sync::Arc};

use bytes::Bytes;
use tokio::{
    io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
    sync::Mutex,
};
use tracing::{debug, warn};

use super::{
    framing::{NixFramedReader, StderrReadFramedReader},
    types::{AddToStoreNarRequest, QueryValidPaths},
    worker_protocol::{server_handshake_client, ClientSettings, Operation, Trust, STDERR_LAST},
    NixDaemonIO,
};

use crate::{
    store_path::StorePath,
    wire::{
        de::{NixRead, NixReader},
        ser::{NixSerialize, NixWrite, NixWriter, NixWriterBuilder},
        ProtocolVersion,
    },
};

use crate::{nix_daemon::types::NixError, worker_protocol::STDERR_ERROR};

/// Handles a single connection with a nix client.
///
/// As part of its [`initialization`] it performs the handshake with the client
/// and determines the [ProtocolVersion] and [ClientSettings] to use for the remainder of the session.
///
/// Once initialized, [`handle_client`] needs to be called to handle the rest of the session,
/// it delegates all operation handling to an instance of [NixDaemonIO].
///
/// [`initialization`]: NixDaemon::initialize
/// [`handle_client`]: NixDaemon::handle_client
// NOTE(review): presumably silences dead_code because `client_settings` is stored
// (and overwritten on every SetOptions) but not yet read outside of tests.
#[allow(dead_code)]
pub struct NixDaemon<IO, R, W> {
    /// Backend that the individual protocol operations are delegated to.
    io: Arc<IO>,
    /// Protocol version negotiated with the client during the handshake; it selects,
    /// among other things, how NAR data is framed for `AddToStoreNar`.
    protocol_version: ProtocolVersion,
    /// The most recent settings sent by the client via `SetOptions`.
    client_settings: ClientSettings,
    /// Reader for the client's requests.
    reader: NixReader<R>,
    /// Writer for responses to the client. It is shared behind a mutex because besides
    /// the response path, the STDERR_READ framing used for some protocol versions also
    /// needs to write to the client while a request is being processed.
    writer: Arc<Mutex<NixWriter<W>>>,
}

impl<IO, R, W> NixDaemon<IO, R, W>
where
    IO: NixDaemonIO + Sync + Send,
{
    pub fn new(
        io: Arc<IO>,
        protocol_version: ProtocolVersion,
        client_settings: ClientSettings,
        reader: NixReader<R>,
        writer: NixWriter<W>,
    ) -> Self {
        Self {
            io,
            protocol_version,
            client_settings,
            reader,
            writer: Arc::new(Mutex::new(writer)),
        }
    }
}

impl<IO, RW> NixDaemon<IO, ReadHalf<RW>, WriteHalf<RW>>
where
    RW: AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static,
    IO: NixDaemonIO + Sync + Send,
{
    /// Async constructor for NixDaemon.
    ///
    /// Performs the initial handshake with the client and retrieves the client's preferred
    /// settings:
    ///
    /// 1. Runs [`server_handshake_client`], advertising version `"2.18.2"` and
    ///    [`Trust::Trusted`], which yields the negotiated [ProtocolVersion].
    /// 2. Sends `STDERR_LAST` after the handshake, then splits `connection` into read/write
    ///    halves wrapped in a [NixReader] / [NixWriter] configured for that version.
    /// 3. Requires the client's first operation to be [`Operation::SetOptions`], reads the
    ///    [ClientSettings] and acknowledges them with `STDERR_LAST`.
    ///
    /// The resulting daemon can handle the client session by calling [NixDaemon::handle_client].
    ///
    /// # Errors
    ///
    /// Returns an error if the handshake or any read/write on `connection` fails, or if the
    /// client's first operation is not [`Operation::SetOptions`].
    pub async fn initialize(io: Arc<IO>, mut connection: RW) -> Result<Self, std::io::Error>
    where
        RW: AsyncReadExt + AsyncWriteExt + Send + Unpin,
    {
        let protocol_version =
            server_handshake_client(&mut connection, "2.18.2", Trust::Trusted).await?;

        connection.write_u64_le(STDERR_LAST).await?;
        let (reader, writer) = split(connection);
        let mut reader = NixReader::builder()
            .set_version(protocol_version)
            .build(reader);
        let mut writer = NixWriterBuilder::default()
            .set_version(protocol_version)
            .build(writer);

        // The first op is always SetOptions
        let operation: Operation = reader.read_value().await?;
        if operation != Operation::SetOptions {
            // `format!` is required here: a plain string literal would not interpolate
            // `{operation}` and the error would never say which operation was received.
            return Err(std::io::Error::other(format!(
                "Expected SetOptions operation, but got {operation:?}"
            )));
        }
        let client_settings: ClientSettings = reader.read_value().await?;
        writer.write_number(STDERR_LAST).await?;
        writer.flush().await?;

        Ok(Self::new(
            io,
            protocol_version,
            client_settings,
            reader,
            writer,
        ))
    }

    /// Main client connection loop, reads client's requests and responds to them accordingly.
    ///
    /// Each iteration reads an operation code, decodes the operation's arguments and
    /// delegates the work to [NixDaemonIO]. The outcome, success or error, is reported back
    /// to the client by the private `handle` helper.
    ///
    /// # Errors
    ///
    /// The loop has no successful exit; it only ends by returning an error when:
    /// * reading from the client fails (e.g. [`std::io::ErrorKind::UnexpectedEof`] once the
    ///   client closes the connection),
    /// * writing a response to the client fails,
    /// * the client sends an unknown operation code, or
    /// * the client sends an operation that is not implemented.
    ///
    /// Errors produced by the [NixDaemonIO] implementation itself are sent to the client
    /// and do not end the session.
    pub async fn handle_client(&mut self) -> Result<(), std::io::Error> {
        let io = self.io.clone();
        loop {
            let op_code = self.reader.read_number().await?;
            match TryInto::<Operation>::try_into(op_code) {
                // Note: please keep operations sorted in ascending order of their numerical op number.
                Ok(operation) => match operation {
                    Operation::IsValidPath => {
                        let path: StorePath<String> = self.reader.read_value().await?;
                        Self::handle(&self.writer, io.is_valid_path(&path)).await?
                    }
                    // Note this operation does not currently delegate to NixDaemonIO,
                    // The general idea is that we will pass relevant ClientSettings
                    // into individual NixDaemonIO method calls if the need arises.
                    // For now we just store the settings in the NixDaemon for future use.
                    Operation::SetOptions => {
                        self.client_settings = self.reader.read_value().await?;
                        Self::handle(&self.writer, async { Ok(()) }).await?
                    }
                    Operation::QueryPathInfo => {
                        let path: StorePath<String> = self.reader.read_value().await?;
                        Self::handle(&self.writer, io.query_path_info(&path)).await?
                    }
                    Operation::QueryPathFromHashPart => {
                        let hash: Bytes = self.reader.read_value().await?;
                        Self::handle(&self.writer, io.query_path_from_hash_part(&hash)).await?
                    }
                    Operation::QueryValidPaths => {
                        let query: QueryValidPaths = self.reader.read_value().await?;
                        Self::handle(&self.writer, io.query_valid_paths(&query)).await?
                    }
                    Operation::QueryValidDerivers => {
                        let path: StorePath<String> = self.reader.read_value().await?;
                        Self::handle(&self.writer, io.query_valid_derivers(&path)).await?
                    }
                    // FUTUREWORK: These are just stubs that return an empty list.
                    // It's important not to return an error for the local-overlay:// store
                    // to work properly. While it will not see certain referrers and realizations
                    // it will not fail on various operations like gc and optimize store. At the
                    // same time, returning an empty list here shouldn't break any of local-overlay store's
                    // invariants.
                    Operation::QueryReferrers | Operation::QueryRealisation => {
                        // The single argument is consumed (and discarded) so the stream stays
                        // in sync for the next request.
                        let _: String = self.reader.read_value().await?;
                        Self::handle(&self.writer, async move {
                            warn!(
                                ?operation,
                                "This operation is not implemented. Returning empty result..."
                            );
                            Ok(Vec::<StorePath<String>>::new())
                        })
                        .await?
                    }
                    Operation::AddToStoreNar => {
                        let request: AddToStoreNarRequest = self.reader.read_value().await?;
                        let minor_version = self.protocol_version.minor();
                        // How the NAR payload follows the request depends on the
                        // negotiated protocol minor version.
                        match minor_version {
                            ..21 => {
                                // Before protocol version 1.21, the nar is sent unframed, so we just
                                // pass the reader directly to the operation.
                                Self::handle(
                                    &self.writer,
                                    self.io.add_to_store_nar(request, &mut self.reader),
                                )
                                .await?
                            }
                            21..23 => {
                                // Protocol versions 1.21 .. 1.23 use STDERR_READ protocol, see logging.md#stderr_read.
                                // The writer lock taken here is released when this async block
                                // completes, before `handle` locks the writer to send the result.
                                Self::handle(&self.writer, async {
                                    let mut writer = self.writer.lock().await;
                                    let mut reader = StderrReadFramedReader::new(
                                        &mut self.reader,
                                        writer.deref_mut(),
                                    );
                                    self.io.add_to_store_nar(request, &mut reader).await
                                })
                                .await?
                            }
                            23.. => {
                                // Starting at protocol version 1.23, the framed protocol is used, see serialization.md#framed
                                let mut framed = NixFramedReader::new(&mut self.reader);
                                Self::handle(&self.writer, async {
                                    self.io.add_to_store_nar(request, &mut framed).await
                                })
                                .await?
                            }
                        }
                    }
                    _ => {
                        return Err(std::io::Error::other(format!(
                            "Operation {operation:?} is not implemented"
                        )));
                    }
                },
                _ => {
                    return Err(std::io::Error::other(format!(
                        "Unknown operation code received: {op_code}"
                    )));
                }
            }
        }
    }

    /// Handles the operation and sends the response or error to the client.
    ///
    /// As per nix daemon protocol, after sending the request, the client expects zero or more
    /// log lines/activities followed by either
    /// * STDERR_LAST and the response bytes
    /// * STDERR_ERROR and the error
    ///
    /// This is a helper method, awaiting on the passed in future and then
    /// handling log lines/activities as described above.
    ///
    /// The future is awaited *before* the writer lock is taken, so the future itself may
    /// lock the writer (as the STDERR_READ path does) without deadlocking.
    ///
    /// # Errors
    ///
    /// An error from `future` is not returned. It is sent to the client as a `NixError`
    /// carrying the error's `Debug` representation. Only failures writing to or flushing
    /// the client connection are returned.
    async fn handle<T>(
        writer: &Arc<Mutex<NixWriter<WriteHalf<RW>>>>,
        future: impl Future<Output = std::io::Result<T>>,
    ) -> Result<(), std::io::Error>
    where
        T: NixSerialize + Send,
    {
        let result = future.await;
        let mut writer = writer.lock().await;

        match result {
            Ok(r) => {
                // the protocol requires that we first indicate that we are done sending logs
                // by sending STDERR_LAST and then the response.
                writer.write_number(STDERR_LAST).await?;
                writer.write_value(&r).await?;
                writer.flush().await
            }
            Err(e) => {
                debug!(err = ?e, "IO error");
                writer.write_number(STDERR_ERROR).await?;
                writer.write_value(&NixError::new(format!("{e:?}"))).await?;
                writer.flush().await
            }
        }
    }
}

#[cfg(test)]
// Tests drive the daemon against a scripted `tokio_test::io` mock connection: every
// `.read(..)` is data the "client" sends and every `.write(..)` is data the daemon must
// emit, in exactly that order.
mod tests {
    use super::*;
    use std::{io::ErrorKind, sync::Arc};

    use mockall::predicate;
    use tokio::io::AsyncWriteExt;

    use crate::{
        nix_daemon::MockNixDaemonIO,
        wire::ProtocolVersion,
        worker_protocol::{ClientSettings, WORKER_MAGIC_1, WORKER_MAGIC_2},
    };

    /// Full `initialize` flow: handshake, version negotiation (server 1.37, client 1.35,
    /// the test expects 1.35 to be used) and the mandatory initial SetOptions.
    #[tokio::test]
    async fn test_daemon_initialization() {
        let mut builder = tokio_test::io::Builder::new();
        let test_conn = builder
            .read(&WORKER_MAGIC_1.to_le_bytes())
            .write(&WORKER_MAGIC_2.to_le_bytes())
            // Our version is 1.37
            .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            // The client's version is 1.35
            .read(&[35, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            // cpu affinity
            .read(&[0; 8])
            // reservespace
            .read(&[0; 8])
            // version (size)
            .write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            // version (data == 2.18.2 + padding)
            .write(&[50, 46, 49, 56, 46, 50, 0, 0])
            // Trusted (1 == client trusted)
            .write(&[1, 0, 0, 0, 0, 0, 0, 0])
            // STDERR_LAST
            .write(&[115, 116, 108, 97, 0, 0, 0, 0]);

        // Wire encoding of the default client settings, sent as the SetOptions payload.
        let mut bytes = Vec::new();
        let mut writer = NixWriter::new(&mut bytes);
        writer
            .write_value(&ClientSettings::default())
            .await
            .unwrap();
        writer.flush().await.unwrap();

        let test_conn = test_conn
            // SetOptions op
            .read(&[19, 0, 0, 0, 0, 0, 0, 0])
            .read(&bytes)
            // STDERR_LAST
            .write(&[115, 116, 108, 97, 0, 0, 0, 0])
            .build();

        let mock = MockNixDaemonIO::new();
        let daemon = NixDaemon::initialize(Arc::new(mock), test_conn)
            .await
            .unwrap();
        assert_eq!(daemon.client_settings, ClientSettings::default());
        assert_eq!(daemon.protocol_version, ProtocolVersion::from_parts(1, 35));
    }

    /// Encodes `req` with a `NixWriter` for `protocol_version`, i.e. the bytes a client
    /// would send for that value.
    async fn serialize<T>(req: &T, protocol_version: ProtocolVersion) -> Vec<u8>
    where
        T: NixSerialize + Send,
    {
        let mut result: Vec<u8> = Vec::new();
        let mut w = NixWriter::builder()
            .set_version(protocol_version)
            .build(&mut result);
        w.write_value(req).await.unwrap();
        w.flush().await.unwrap();
        result
    }

    /// Builds the bytes the daemon is expected to send for `resp`: STDERR_LAST followed by
    /// the value on success, or STDERR_ERROR followed by a `NixError` holding the error's
    /// `Debug` representation on failure.
    async fn respond<T>(
        resp: &Result<T, std::io::Error>,
        protocol_version: ProtocolVersion,
    ) -> Vec<u8>
    where
        T: NixSerialize + Send,
    {
        let mut result: Vec<u8> = Vec::new();
        let mut w = NixWriter::builder()
            .set_version(protocol_version)
            .build(&mut result);
        match resp {
            Ok(value) => {
                w.write_value(&STDERR_LAST).await.unwrap();
                w.write_value(value).await.unwrap();
            }
            Err(e) => {
                w.write_value(&STDERR_ERROR).await.unwrap();
                w.write_value(&NixError::new(format!("{:?}", e)))
                    .await
                    .unwrap();
            }
        }
        w.flush().await.unwrap();
        result
    }

    /// IsValidPath where the backend succeeds: the daemon must reply with STDERR_LAST and
    /// `true`, then stop with UnexpectedEof once the scripted input runs out.
    #[tokio::test]
    async fn test_handle_is_valid_path_ok() {
        let version = ProtocolVersion::from_parts(1, 37);
        let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
        let mut mock = MockNixDaemonIO::new();
        let (reader, writer) = split(io);
        let path: StorePath<String> = StorePath::<String>::from_absolute_path(
            "/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1".as_bytes(),
        )
        .unwrap();
        mock.expect_is_valid_path()
            .with(predicate::eq(path.clone()))
            .times(1)
            .returning(|_| Box::pin(async { Ok(true) }));

        handle.read(&Into::<u64>::into(Operation::IsValidPath).to_le_bytes());
        handle.read(&serialize(&path, version).await);
        handle.write(&respond(&Ok(true), version).await);
        // Dropping the handle ends the script, so the next read hits EOF.
        drop(handle);

        let mut daemon = NixDaemon::new(
            Arc::new(mock),
            version,
            ClientSettings::default(),
            NixReader::new(reader),
            NixWriter::new(writer),
        );
        assert_eq!(
            ErrorKind::UnexpectedEof,
            daemon
                .handle_client()
                .await
                .expect_err("Expecting eof")
                .kind()
        );
    }

    /// IsValidPath where the backend fails: the error must be reported to the client as
    /// STDERR_ERROR without ending the session, which then stops with UnexpectedEof.
    #[tokio::test]
    async fn test_handle_is_valid_path_err() {
        let version = ProtocolVersion::from_parts(1, 37);
        let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
        let mut mock = MockNixDaemonIO::new();
        let (reader, writer) = split(io);
        let path: StorePath<String> = StorePath::<String>::from_absolute_path(
            "/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1".as_bytes(),
        )
        .unwrap();
        mock.expect_is_valid_path()
            .with(predicate::eq(path.clone()))
            .times(1)
            .returning(|_| Box::pin(async { Err(std::io::Error::other("hello")) }));

        handle.read(&Into::<u64>::into(Operation::IsValidPath).to_le_bytes());
        handle.read(&serialize(&path, version).await);
        handle.write(&respond::<bool>(&Err(std::io::Error::other("hello")), version).await);
        // Dropping the handle ends the script, so the next read hits EOF.
        drop(handle);

        let mut daemon = NixDaemon::new(
            Arc::new(mock),
            version,
            ClientSettings::default(),
            NixReader::new(reader),
            NixWriter::new(writer),
        );
        assert_eq!(
            ErrorKind::UnexpectedEof,
            daemon
                .handle_client()
                .await
                .expect_err("Expecting eof")
                .kind()
        );
    }
}