about summary refs log tree commit diff
path: root/tvix/castore/src/blobservice/naive_seeker.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/castore/src/blobservice/naive_seeker.rs')
-rw-r--r--tvix/castore/src/blobservice/naive_seeker.rs77
1 files changed, 35 insertions, 42 deletions
diff --git a/tvix/castore/src/blobservice/naive_seeker.rs b/tvix/castore/src/blobservice/naive_seeker.rs
index 1475de580814..5c943b9870b4 100644
--- a/tvix/castore/src/blobservice/naive_seeker.rs
+++ b/tvix/castore/src/blobservice/naive_seeker.rs
@@ -65,6 +65,9 @@ pin_project! {
     }
 }
 
+/// The buffer size used to discard data.
+const DISCARD_BUF_SIZE: usize = 4096;
+
 impl<R: tokio::io::AsyncRead> NaiveSeeker<R> {
     pub fn new(r: R) -> Self {
         NaiveSeeker {
@@ -174,44 +177,34 @@ impl<R: tokio::io::AsyncRead> tokio::io::AsyncSeek for NaiveSeeker<R> {
 
         // discard some bytes, until pos is where we want it to be.
         // We create a buffer that we'll discard later on.
-        let mut buf = [0; 1024];
+        let mut discard_buf = [0; DISCARD_BUF_SIZE];
 
         // Loop until we've reached the desired seek position. This is done by issuing repeated
-        // `poll_read` calls. If the data is not available yet, we will yield back to the executor
+        // `poll_read` calls.
+        // If the data is not available yet, we will yield back to the executor
         // and wait to be polled again.
         loop {
+            if self.bytes_to_skip == 0 {
+                return Poll::Ready(Ok(self.pos));
+            }
+
             // calculate the length we want to skip at most, which is either a max
             // buffer size, or the number of remaining bytes to read, whatever is
             // smaller.
-            let bytes_to_skip = std::cmp::min(self.bytes_to_skip as usize, buf.len());
-
-            let mut read_buf = tokio::io::ReadBuf::new(&mut buf[..bytes_to_skip]);
-
-            match self.as_mut().poll_read(cx, &mut read_buf) {
-                Poll::Ready(_a) => {
-                    let bytes_read = read_buf.filled().len() as u64;
+            let bytes_to_skip_now = std::cmp::min(self.bytes_to_skip as usize, discard_buf.len());
+            let mut discard_buf = tokio::io::ReadBuf::new(&mut discard_buf[..bytes_to_skip_now]);
 
-                    if bytes_read == 0 {
-                        return Poll::Ready(Err(io::Error::new(
-                            io::ErrorKind::UnexpectedEof,
-                            format!(
-                                "tried to skip {} bytes, but only was able to skip {} until reaching EOF",
-                                bytes_to_skip, bytes_read
-                            ),
-                        )));
-                    }
+            ready!(self.as_mut().poll_read(cx, &mut discard_buf))?;
+            let bytes_skipped = discard_buf.filled().len();
 
-                    // calculate bytes to skip
-                    let bytes_to_skip = self.bytes_to_skip - bytes_read;
-
-                    *self.as_mut().project().bytes_to_skip = bytes_to_skip;
-
-                    if bytes_to_skip == 0 {
-                        return Poll::Ready(Ok(self.pos));
-                    }
-                }
-                Poll::Pending => return Poll::Pending,
-            };
+            if bytes_skipped == 0 {
+                return Poll::Ready(Err(io::Error::new(
+                    io::ErrorKind::UnexpectedEof,
+                    "got EOF while trying to skip bytes",
+                )));
+            }
+            // decrement bytes to skip. The poll_read call already updated self.pos.
+            *self.as_mut().project().bytes_to_skip -= bytes_skipped as u64;
         }
     }
 }
@@ -220,16 +213,16 @@ impl<R: tokio::io::AsyncRead + Send + Unpin + 'static> BlobReader for NaiveSeeke
 
 #[cfg(test)]
 mod tests {
-    use super::NaiveSeeker;
+    use super::{NaiveSeeker, DISCARD_BUF_SIZE};
     use std::io::{Cursor, SeekFrom};
     use tokio::io::{AsyncReadExt, AsyncSeekExt};
 
-    /// This seek requires multiple `poll_read` as we use a 1024 bytes internal
-    /// buffer when doing the seek.
+    /// This seek requires multiple `poll_read` as we use a multiples of
+    /// DISCARD_BUF_SIZE when doing the seek.
     /// This ensures we don't hang indefinitely.
     #[tokio::test]
     async fn seek() {
-        let buf = vec![0u8; 4096];
+        let buf = vec![0u8; DISCARD_BUF_SIZE * 4];
         let reader = Cursor::new(&buf);
         let mut seeker = NaiveSeeker::new(reader);
         seeker.seek(SeekFrom::Start(4000)).await.unwrap();
@@ -237,29 +230,29 @@ mod tests {
 
     #[tokio::test]
     async fn seek_read() {
-        let mut buf = vec![0u8; 2048];
-        buf.extend_from_slice(&[1u8; 2048]);
-        buf.extend_from_slice(&[2u8; 2048]);
+        let mut buf = vec![0u8; DISCARD_BUF_SIZE * 2];
+        buf.extend_from_slice(&[1u8; DISCARD_BUF_SIZE * 2]);
+        buf.extend_from_slice(&[2u8; DISCARD_BUF_SIZE * 2]);
 
         let reader = Cursor::new(&buf);
         let mut seeker = NaiveSeeker::new(reader);
 
-        let mut read_buf = vec![0u8; 1024];
+        let mut read_buf = vec![0u8; DISCARD_BUF_SIZE];
         seeker.read_exact(&mut read_buf).await.expect("must read");
-        assert_eq!(read_buf.as_slice(), &[0u8; 1024]);
+        assert_eq!(read_buf.as_slice(), &[0u8; DISCARD_BUF_SIZE]);
 
         seeker
-            .seek(SeekFrom::Current(1024))
+            .seek(SeekFrom::Current(DISCARD_BUF_SIZE as i64))
             .await
             .expect("must seek");
         seeker.read_exact(&mut read_buf).await.expect("must read");
-        assert_eq!(read_buf.as_slice(), &[1u8; 1024]);
+        assert_eq!(read_buf.as_slice(), &[1u8; DISCARD_BUF_SIZE]);
 
         seeker
-            .seek(SeekFrom::Start(2 * 2048))
+            .seek(SeekFrom::Start(2 * 2 * DISCARD_BUF_SIZE as u64))
             .await
             .expect("must seek");
         seeker.read_exact(&mut read_buf).await.expect("must read");
-        assert_eq!(read_buf.as_slice(), &[2u8; 1024]);
+        assert_eq!(read_buf.as_slice(), &[2u8; DISCARD_BUF_SIZE]);
     }
 }