about summary refs log tree commit diff
path: root/tvix/tools/turbofetch/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/tools/turbofetch/src/main.rs')
-rw-r--r--tvix/tools/turbofetch/src/main.rs220
1 files changed, 220 insertions, 0 deletions
diff --git a/tvix/tools/turbofetch/src/main.rs b/tvix/tools/turbofetch/src/main.rs
new file mode 100644
index 000000000000..4b3a50eb3941
--- /dev/null
+++ b/tvix/tools/turbofetch/src/main.rs
@@ -0,0 +1,220 @@
+//! turbofetch is a high-performance bulk S3 object aggregator.
+//!
+//! It operates on two S3 buckets: a source bucket (nix-cache), and a
+//! work bucket defined at runtime. The work bucket contains a job file
+//! consisting of concatenated 32-character keys, representing narinfo
+//! files in the source bucket, without the `.narinfo` suffix or any
+//! other separators.
+//!
+//! Each run of turbofetch processes a half-open range of indices from the
+//! job file, and outputs a zstd stream of concatenated objects, without
+//! additional separators and in no particular order. These segment files
+//! are written into the work bucket, named for the range of indices they
+//! cover. `/narinfo.zst/000000000c380d40-000000000c385b60` covers the 20k
+//! objects `[0xc380d40, 0xc385b60) = [205000000, 205020000)`. Empirically,
+//! segment files of 20k objects achieve a compression ratio of 4.7x.
+//!
+//! Reassembly is left to narinfo2parquet, which interprets StorePath lines.
+//!
+//! TODO(edef): any retries/error handling whatsoever
+//! Currently, it fails an entire range if anything goes wrong, and doesn't
+//! write any output.
+
+use bytes::Bytes;
+use futures::{stream::FuturesUnordered, Stream, TryStreamExt};
+use rusoto_core::ByteStream;
+use rusoto_s3::{GetObjectRequest, PutObjectRequest, S3Client, S3};
+use serde::Deserialize;
+use std::{io::Write, mem, ops::Range, ptr};
+use tokio::{
+    io::{self, AsyncReadExt, AsyncWriteExt},
+    net::TcpStream,
+};
+
+/// Fetch a group of keys, streaming concatenated chunks as they arrive from S3.
+/// `keys` must be a slice from the job file. Any network error at all fails the
+/// entire batch, and there is no rate limiting.
+fn fetch(keys: &[[u8; 32]]) -> impl Stream<Item = io::Result<Bytes>> {
+    // S3 supports only HTTP/1.1, but we can ease the pain somewhat by using
+    // HTTP pipelining. It terminates the TCP connection after receiving 100
+    // requests, so we chunk the keys up accordingly, and make one connection
+    // for each chunk.
+    keys.chunks(100)
+        .map(|chunk| {
+            const PREFIX: &[u8] = b"GET /nix-cache/";
+            const SUFFIX: &[u8] = b".narinfo HTTP/1.1\nHost: s3.amazonaws.com\n\n";
+            const LENGTH: usize = PREFIX.len() + 32 + SUFFIX.len();
+
+            let mut request = Vec::with_capacity(LENGTH * 100);
+            for key in chunk {
+                request.extend_from_slice(PREFIX);
+                request.extend_from_slice(key);
+                request.extend_from_slice(SUFFIX);
+            }
+
+            (request, chunk.len())
+        })
+        .map(|(request, n)| async move {
+            let (mut read, mut write) = TcpStream::connect("s3.amazonaws.com:80")
+                .await?
+                .into_split();
+
+            let _handle = tokio::spawn(async move {
+                let request = request;
+                write.write_all(&request).await
+            });
+
+            let mut buffer = turbofetch::Buffer::new(512 * 1024);
+            let mut bodies = vec![];
+
+            for _ in 0..n {
+                let body = turbofetch::parse_response(&mut read, &mut buffer).await?;
+                bodies.extend_from_slice(body);
+            }
+
+            Ok::<_, io::Error>(Bytes::from(bodies))
+        })
+        .collect::<FuturesUnordered<_>>()
+}
+
+/// Retrieve a range of keys from the job file.
+async fn get_range(
+    s3: &'static S3Client,
+    bucket: String,
+    key: String,
+    range: Range<u64>,
+) -> io::Result<Box<[[u8; 32]]>> {
+    let resp = s3
+        .get_object(GetObjectRequest {
+            bucket,
+            key,
+            range: Some(format!("bytes={}-{}", range.start * 32, range.end * 32 - 1)),
+            ..GetObjectRequest::default()
+        })
+        .await
+        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
+
+    let mut body = vec![];
+    resp.body
+        .ok_or(io::ErrorKind::InvalidData)?
+        .into_async_read()
+        .read_to_end(&mut body)
+        .await?;
+
+    let body = exact_chunks(body.into_boxed_slice()).ok_or(io::ErrorKind::InvalidData)?;
+
+    Ok(body)
+}
+
+fn exact_chunks(mut buf: Box<[u8]>) -> Option<Box<[[u8; 32]]>> {
+    // SAFETY: We ensure that `buf.len()` is a multiple of 32, and there are no alignment requirements.
+    unsafe {
+        let ptr = buf.as_mut_ptr();
+        let len = buf.len();
+
+        if len % 32 != 0 {
+            return None;
+        }
+
+        let ptr = ptr as *mut [u8; 32];
+        let len = len / 32;
+        mem::forget(buf);
+
+        Some(Box::from_raw(ptr::slice_from_raw_parts_mut(ptr, len)))
+    }
+}
+
+// TODO(edef): factor this out into a separate entry point
+#[tokio::main(flavor = "current_thread")]
+async fn main() -> Result<(), lambda_runtime::Error> {
+    let s3 = S3Client::new(rusoto_core::Region::UsEast1);
+    let s3 = &*Box::leak(Box::new(s3));
+
+    tracing_subscriber::fmt()
+        .json()
+        .with_max_level(tracing::Level::INFO)
+        // this needs to be set to remove duplicated information in the log.
+        .with_current_span(false)
+        // this needs to be set to false, otherwise ANSI color codes will
+        // show up in a confusing manner in CloudWatch logs.
+        .with_ansi(false)
+        // disabling time is handy because CloudWatch will add the ingestion time.
+        .without_time()
+        // remove the name of the function from every log entry
+        .with_target(false)
+        .init();
+
+    lambda_runtime::run(lambda_runtime::service_fn(|event| func(s3, event))).await
+}
+
+/// Lambda request body
+#[derive(Debug, Deserialize)]
+struct Params {
+    work_bucket: String,
+    job_file: String,
+    start: u64,
+    end: u64,
+}
+
+#[tracing::instrument(skip(s3, event), fields(req_id = %event.context.request_id))]
+async fn func(
+    s3: &'static S3Client,
+    event: lambda_runtime::LambdaEvent<
+        aws_lambda_events::lambda_function_urls::LambdaFunctionUrlRequest,
+    >,
+) -> Result<&'static str, lambda_runtime::Error> {
+    let mut params = event.payload.body.ok_or("no body")?;
+
+    if event.payload.is_base64_encoded {
+        params = String::from_utf8(data_encoding::BASE64.decode(params.as_bytes())?)?;
+    }
+
+    let params: Params = serde_json::from_str(&params)?;
+
+    if params.start >= params.end {
+        return Err("nope".into());
+    }
+
+    let keys = get_range(
+        s3,
+        params.work_bucket.clone(),
+        params.job_file.to_owned(),
+        params.start..params.end,
+    )
+    .await?;
+
+    let zchunks = fetch(&keys)
+        .try_fold(
+            Box::new(zstd::Encoder::new(vec![], zstd::DEFAULT_COMPRESSION_LEVEL).unwrap()),
+            |mut w, buf| {
+                w.write_all(&buf).unwrap();
+                async { Ok(w) }
+            },
+        )
+        .await?;
+
+    let zchunks = to_byte_stream(zchunks.finish().unwrap());
+
+    tracing::info!("we got to put_object");
+
+    s3.put_object(PutObjectRequest {
+        bucket: params.work_bucket,
+        key: format!("narinfo.zst/{:016x}-{:016x}", params.start, params.end),
+        body: Some(zchunks),
+        ..Default::default()
+    })
+    .await
+    .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
+
+    tracing::info!("… and it worked!");
+
+    Ok("OK")
+}
+
+fn to_byte_stream(buffer: Vec<u8>) -> ByteStream {
+    let size_hint = buffer.len();
+    ByteStream::new_with_size(
+        futures::stream::once(async { Ok(buffer.into()) }),
+        size_hint,
+    )
+}