about summary refs log tree commit diff
path: root/tvix/castore/src/blobservice/naive_seeker.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/castore/src/blobservice/naive_seeker.rs')
-rw-r--r-- tvix/castore/src/blobservice/naive_seeker.rs 104
1 files changed, 52 insertions, 52 deletions
diff --git a/tvix/castore/src/blobservice/naive_seeker.rs b/tvix/castore/src/blobservice/naive_seeker.rs
index 1475de5808..f5a5307150 100644
--- a/tvix/castore/src/blobservice/naive_seeker.rs
+++ b/tvix/castore/src/blobservice/naive_seeker.rs
@@ -4,7 +4,7 @@ use pin_project_lite::pin_project;
 use std::io;
 use std::task::Poll;
 use tokio::io::AsyncRead;
-use tracing::{debug, instrument};
+use tracing::{debug, instrument, trace, warn};
 
 pin_project! {
     /// This implements [tokio::io::AsyncSeek] for and [tokio::io::AsyncRead] by
@@ -65,6 +65,9 @@ pin_project! {
     }
 }
 
+/// The buffer size used to discard data.
+const DISCARD_BUF_SIZE: usize = 4096;
+
 impl<R: tokio::io::AsyncRead> NaiveSeeker<R> {
     pub fn new(r: R) -> Self {
         NaiveSeeker {
@@ -76,6 +79,7 @@ impl<R: tokio::io::AsyncRead> NaiveSeeker<R> {
 }
 
 impl<R: tokio::io::AsyncRead> tokio::io::AsyncRead for NaiveSeeker<R> {
+    #[instrument(level = "trace", skip_all)]
     fn poll_read(
         self: std::pin::Pin<&mut Self>,
         cx: &mut std::task::Context<'_>,
@@ -87,9 +91,12 @@ impl<R: tokio::io::AsyncRead> tokio::io::AsyncRead for NaiveSeeker<R> {
 
         let this = self.project();
         ready!(this.r.poll_read(cx, buf))?;
+
         let bytes_read = buf.filled().len() - filled_before;
         *this.pos += bytes_read as u64;
 
+        trace!(bytes_read = bytes_read, new_pos = this.pos, "poll_read");
+
         Ok(()).into()
     }
 }
@@ -102,16 +109,18 @@ impl<R: tokio::io::AsyncRead> tokio::io::AsyncBufRead for NaiveSeeker<R> {
         self.project().r.poll_fill_buf(cx)
     }
 
+    #[instrument(level = "trace", skip(self))]
     fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
         let this = self.project();
         this.r.consume(amt);
-        let pos: &mut u64 = this.pos;
-        *pos += amt as u64;
+        *this.pos += amt as u64;
+
+        trace!(new_pos = this.pos, "consume");
     }
 }
 
 impl<R: tokio::io::AsyncRead> tokio::io::AsyncSeek for NaiveSeeker<R> {
-    #[instrument(skip(self), err(Debug))]
+    #[instrument(level="trace", skip(self), fields(inner_pos=%self.pos), err(Debug))]
     fn start_seek(
         self: std::pin::Pin<&mut Self>,
         position: std::io::SeekFrom,
@@ -146,23 +155,24 @@ impl<R: tokio::io::AsyncRead> tokio::io::AsyncSeek for NaiveSeeker<R> {
             }
         };
 
-        debug!(absolute_offset=?absolute_offset, "seek");
-
-        // we already know absolute_offset is larger than self.pos
+        // we already know absolute_offset is >= self.pos
         debug_assert!(
             absolute_offset >= self.pos,
-            "absolute_offset {} is larger than self.pos {}",
+            "absolute_offset {} must be >= self.pos {}",
             absolute_offset,
             self.pos
         );
 
         // calculate bytes to skip
-        *self.project().bytes_to_skip = absolute_offset - self.pos;
+        let this = self.project();
+        *this.bytes_to_skip = absolute_offset - *this.pos;
+
+        debug!(bytes_to_skip = *this.bytes_to_skip, "seek");
 
         Ok(())
     }
 
-    #[instrument(skip(self))]
+    #[instrument(skip_all)]
     fn poll_complete(
         mut self: std::pin::Pin<&mut Self>,
         cx: &mut std::task::Context<'_>,
@@ -174,44 +184,34 @@ impl<R: tokio::io::AsyncRead> tokio::io::AsyncSeek for NaiveSeeker<R> {
 
         // discard some bytes, until pos is where we want it to be.
         // We create a buffer that we'll discard later on.
-        let mut buf = [0; 1024];
+        let mut discard_buf = [0; DISCARD_BUF_SIZE];
 
         // Loop until we've reached the desired seek position. This is done by issuing repeated
-        // `poll_read` calls. If the data is not available yet, we will yield back to the executor
+        // `poll_read` calls.
+        // If the data is not available yet, we will yield back to the executor
         // and wait to be polled again.
         loop {
+            if self.bytes_to_skip == 0 {
+                return Poll::Ready(Ok(self.pos));
+            }
+
             // calculate the length we want to skip at most, which is either a max
             // buffer size, or the number of remaining bytes to read, whatever is
             // smaller.
-            let bytes_to_skip = std::cmp::min(self.bytes_to_skip as usize, buf.len());
-
-            let mut read_buf = tokio::io::ReadBuf::new(&mut buf[..bytes_to_skip]);
-
-            match self.as_mut().poll_read(cx, &mut read_buf) {
-                Poll::Ready(_a) => {
-                    let bytes_read = read_buf.filled().len() as u64;
-
-                    if bytes_read == 0 {
-                        return Poll::Ready(Err(io::Error::new(
-                            io::ErrorKind::UnexpectedEof,
-                            format!(
-                                "tried to skip {} bytes, but only was able to skip {} until reaching EOF",
-                                bytes_to_skip, bytes_read
-                            ),
-                        )));
-                    }
+            let bytes_to_skip_now = std::cmp::min(self.bytes_to_skip as usize, discard_buf.len());
+            let mut discard_buf = tokio::io::ReadBuf::new(&mut discard_buf[..bytes_to_skip_now]);
 
-                    // calculate bytes to skip
-                    let bytes_to_skip = self.bytes_to_skip - bytes_read;
+            ready!(self.as_mut().poll_read(cx, &mut discard_buf))?;
+            let bytes_skipped = discard_buf.filled().len();
 
-                    *self.as_mut().project().bytes_to_skip = bytes_to_skip;
-
-                    if bytes_to_skip == 0 {
-                        return Poll::Ready(Ok(self.pos));
-                    }
-                }
-                Poll::Pending => return Poll::Pending,
-            };
+            if bytes_skipped == 0 {
+                return Poll::Ready(Err(io::Error::new(
+                    io::ErrorKind::UnexpectedEof,
+                    "got EOF while trying to skip bytes",
+                )));
+            }
+            // decrement bytes to skip. The poll_read call already updated self.pos.
+            *self.as_mut().project().bytes_to_skip -= bytes_skipped as u64;
         }
     }
 }
@@ -220,16 +220,16 @@ impl<R: tokio::io::AsyncRead + Send + Unpin + 'static> BlobReader for NaiveSeeke
 
 #[cfg(test)]
 mod tests {
-    use super::NaiveSeeker;
+    use super::{NaiveSeeker, DISCARD_BUF_SIZE};
     use std::io::{Cursor, SeekFrom};
     use tokio::io::{AsyncReadExt, AsyncSeekExt};
 
-    /// This seek requires multiple `poll_read` as we use a 1024 bytes internal
-    /// buffer when doing the seek.
+    /// This seek requires multiple `poll_read` as we use multiples of
+    /// DISCARD_BUF_SIZE when doing the seek.
     /// This ensures we don't hang indefinitely.
     #[tokio::test]
     async fn seek() {
-        let buf = vec![0u8; 4096];
+        let buf = vec![0u8; DISCARD_BUF_SIZE * 4];
         let reader = Cursor::new(&buf);
         let mut seeker = NaiveSeeker::new(reader);
         seeker.seek(SeekFrom::Start(4000)).await.unwrap();
@@ -237,29 +237,29 @@ mod tests {
 
     #[tokio::test]
     async fn seek_read() {
-        let mut buf = vec![0u8; 2048];
-        buf.extend_from_slice(&[1u8; 2048]);
-        buf.extend_from_slice(&[2u8; 2048]);
+        let mut buf = vec![0u8; DISCARD_BUF_SIZE * 2];
+        buf.extend_from_slice(&[1u8; DISCARD_BUF_SIZE * 2]);
+        buf.extend_from_slice(&[2u8; DISCARD_BUF_SIZE * 2]);
 
         let reader = Cursor::new(&buf);
         let mut seeker = NaiveSeeker::new(reader);
 
-        let mut read_buf = vec![0u8; 1024];
+        let mut read_buf = vec![0u8; DISCARD_BUF_SIZE];
         seeker.read_exact(&mut read_buf).await.expect("must read");
-        assert_eq!(read_buf.as_slice(), &[0u8; 1024]);
+        assert_eq!(read_buf.as_slice(), &[0u8; DISCARD_BUF_SIZE]);
 
         seeker
-            .seek(SeekFrom::Current(1024))
+            .seek(SeekFrom::Current(DISCARD_BUF_SIZE as i64))
             .await
             .expect("must seek");
         seeker.read_exact(&mut read_buf).await.expect("must read");
-        assert_eq!(read_buf.as_slice(), &[1u8; 1024]);
+        assert_eq!(read_buf.as_slice(), &[1u8; DISCARD_BUF_SIZE]);
 
         seeker
-            .seek(SeekFrom::Start(2 * 2048))
+            .seek(SeekFrom::Start(2 * 2 * DISCARD_BUF_SIZE as u64))
             .await
             .expect("must seek");
         seeker.read_exact(&mut read_buf).await.expect("must read");
-        assert_eq!(read_buf.as_slice(), &[2u8; 1024]);
+        assert_eq!(read_buf.as_slice(), &[2u8; DISCARD_BUF_SIZE]);
     }
 }