author    Florian Klink <flokli@flokli.de>  2024-04-15 15:46 +0300
committer flokli <flokli@flokli.de>         2024-04-16 18:45 +0000
commit    9398bc46b60d809aac679157fe4c491586e649ff (patch)
tree      5f350f4ebcc6a7b1ebe5dfac989d7508e61c7d25 /tvix/castore/src/blobservice/naive_seeker.rs
parent    4d802fa0aefbc22b5f78bf7281c0028ad3bf2fff (diff)
refactor(tvix/castore/blob/naive_seeker): rework skipping for clarity r/7944
Increase the discard_buf size to 4096 (a size I've seen used elsewhere).
Use the ready! macro to propagate Poll::Pending.
Make it clearer what exactly should be skipped in total, and what should
be skipped during the current iteration.

Also document that the poll_read call already takes care of updating
self.pos, as I ran into that trap earlier (and added it here).

Change-Id: I2d22e1c8a835c0f3dd0c648917009b2bad4fd57c
Reviewed-on: https://cl.tvl.fyi/c/depot/+/11442
Reviewed-by: raitobezarius <tvl@lahfa.xyz>
Tested-by: BuildkiteCI
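
For illustration, the skip loop this change puts into poll_complete can be sketched
as a standalone, poll-based helper. This is a hypothetical free function, not part of
the patch; it borrows the DISCARD_BUF_SIZE constant and the EOF error message from
the code below, and drops the pin-project plumbing of the real NaiveSeeker:

    use std::io;
    use std::pin::Pin;
    use std::task::{ready, Context, Poll};
    use tokio::io::{AsyncRead, ReadBuf};

    /// Hypothetical helper: discard `bytes_to_skip` bytes from `reader`,
    /// propagating Poll::Pending via the ready! macro, as poll_complete does.
    fn poll_skip<R: AsyncRead + Unpin>(
        reader: &mut R,
        cx: &mut Context<'_>,
        bytes_to_skip: &mut u64,
    ) -> Poll<io::Result<()>> {
        const DISCARD_BUF_SIZE: usize = 4096;
        let mut discard_buf = [0u8; DISCARD_BUF_SIZE];

        loop {
            if *bytes_to_skip == 0 {
                return Poll::Ready(Ok(()));
            }

            // Skip at most one buffer's worth per iteration.
            let bytes_to_skip_now = std::cmp::min(*bytes_to_skip as usize, DISCARD_BUF_SIZE);
            let mut buf = ReadBuf::new(&mut discard_buf[..bytes_to_skip_now]);

            // ready! returns Poll::Pending from this function if the inner
            // poll_read is not ready yet; `?` propagates I/O errors.
            ready!(Pin::new(&mut *reader).poll_read(cx, &mut buf))?;

            let bytes_skipped = buf.filled().len();
            if bytes_skipped == 0 {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "got EOF while trying to skip bytes",
                )));
            }
            *bytes_to_skip -= bytes_skipped as u64;
        }
    }

As a worked example, a forward skip of 10_000 bytes takes at most three iterations:
min(10_000, 4096) = 4096 bytes, then 4096, then the remaining 1_808, assuming each
poll_read fills the buffer it is given.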
Diffstat
-rw-r--r--  tvix/castore/src/blobservice/naive_seeker.rs | 77
1 file changed, 35 insertions(+), 42 deletions(-)
diff --git a/tvix/castore/src/blobservice/naive_seeker.rs b/tvix/castore/src/blobservice/naive_seeker.rs
index 1475de5808..5c943b9870 100644
--- a/tvix/castore/src/blobservice/naive_seeker.rs
+++ b/tvix/castore/src/blobservice/naive_seeker.rs
@@ -65,6 +65,9 @@ pin_project! {
     }
 }
 
+/// The buffer size used to discard data.
+const DISCARD_BUF_SIZE: usize = 4096;
+
 impl<R: tokio::io::AsyncRead> NaiveSeeker<R> {
     pub fn new(r: R) -> Self {
         NaiveSeeker {
@@ -174,44 +177,34 @@ impl<R: tokio::io::AsyncRead> tokio::io::AsyncSeek for NaiveSeeker<R> {
 
         // discard some bytes, until pos is where we want it to be.
         // We create a buffer that we'll discard later on.
-        let mut buf = [0; 1024];
+        let mut discard_buf = [0; DISCARD_BUF_SIZE];
 
         // Loop until we've reached the desired seek position. This is done by issuing repeated
-        // `poll_read` calls. If the data is not available yet, we will yield back to the executor
+        // `poll_read` calls.
+        // If the data is not available yet, we will yield back to the executor
         // and wait to be polled again.
         loop {
+            if self.bytes_to_skip == 0 {
+                return Poll::Ready(Ok(self.pos));
+            }
+
             // calculate the length we want to skip at most, which is either a max
             // buffer size, or the number of remaining bytes to read, whatever is
             // smaller.
-            let bytes_to_skip = std::cmp::min(self.bytes_to_skip as usize, buf.len());
-
-            let mut read_buf = tokio::io::ReadBuf::new(&mut buf[..bytes_to_skip]);
-
-            match self.as_mut().poll_read(cx, &mut read_buf) {
-                Poll::Ready(_a) => {
-                    let bytes_read = read_buf.filled().len() as u64;
+            let bytes_to_skip_now = std::cmp::min(self.bytes_to_skip as usize, discard_buf.len());
+            let mut discard_buf = tokio::io::ReadBuf::new(&mut discard_buf[..bytes_to_skip_now]);
 
-                    if bytes_read == 0 {
-                        return Poll::Ready(Err(io::Error::new(
-                            io::ErrorKind::UnexpectedEof,
-                            format!(
-                                "tried to skip {} bytes, but only was able to skip {} until reaching EOF",
-                                bytes_to_skip, bytes_read
-                            ),
-                        )));
-                    }
+            ready!(self.as_mut().poll_read(cx, &mut discard_buf))?;
+            let bytes_skipped = discard_buf.filled().len();
 
-                    // calculate bytes to skip
-                    let bytes_to_skip = self.bytes_to_skip - bytes_read;
-
-                    *self.as_mut().project().bytes_to_skip = bytes_to_skip;
-
-                    if bytes_to_skip == 0 {
-                        return Poll::Ready(Ok(self.pos));
-                    }
-                }
-                Poll::Pending => return Poll::Pending,
-            };
+            if bytes_skipped == 0 {
+                return Poll::Ready(Err(io::Error::new(
+                    io::ErrorKind::UnexpectedEof,
+                    "got EOF while trying to skip bytes",
+                )));
+            }
+            // decrement bytes to skip. The poll_read call already updated self.pos.
+            *self.as_mut().project().bytes_to_skip -= bytes_skipped as u64;
         }
     }
 }
@@ -220,16 +213,16 @@ impl<R: tokio::io::AsyncRead + Send + Unpin + 'static> BlobReader for NaiveSeeke
 
 #[cfg(test)]
 mod tests {
-    use super::NaiveSeeker;
+    use super::{NaiveSeeker, DISCARD_BUF_SIZE};
     use std::io::{Cursor, SeekFrom};
     use tokio::io::{AsyncReadExt, AsyncSeekExt};
 
-    /// This seek requires multiple `poll_read` as we use a 1024 bytes internal
-    /// buffer when doing the seek.
+    /// This seek requires multiple `poll_read` calls, as we use multiples of
+    /// DISCARD_BUF_SIZE when doing the seek.
     /// This ensures we don't hang indefinitely.
     #[tokio::test]
     async fn seek() {
-        let buf = vec![0u8; 4096];
+        let buf = vec![0u8; DISCARD_BUF_SIZE * 4];
         let reader = Cursor::new(&buf);
         let mut seeker = NaiveSeeker::new(reader);
         seeker.seek(SeekFrom::Start(4000)).await.unwrap();
@@ -237,29 +230,29 @@ mod tests {
 
     #[tokio::test]
     async fn seek_read() {
-        let mut buf = vec![0u8; 2048];
-        buf.extend_from_slice(&[1u8; 2048]);
-        buf.extend_from_slice(&[2u8; 2048]);
+        let mut buf = vec![0u8; DISCARD_BUF_SIZE * 2];
+        buf.extend_from_slice(&[1u8; DISCARD_BUF_SIZE * 2]);
+        buf.extend_from_slice(&[2u8; DISCARD_BUF_SIZE * 2]);
 
         let reader = Cursor::new(&buf);
         let mut seeker = NaiveSeeker::new(reader);
 
-        let mut read_buf = vec![0u8; 1024];
+        let mut read_buf = vec![0u8; DISCARD_BUF_SIZE];
         seeker.read_exact(&mut read_buf).await.expect("must read");
-        assert_eq!(read_buf.as_slice(), &[0u8; 1024]);
+        assert_eq!(read_buf.as_slice(), &[0u8; DISCARD_BUF_SIZE]);
 
         seeker
-            .seek(SeekFrom::Current(1024))
+            .seek(SeekFrom::Current(DISCARD_BUF_SIZE as i64))
             .await
             .expect("must seek");
         seeker.read_exact(&mut read_buf).await.expect("must read");
-        assert_eq!(read_buf.as_slice(), &[1u8; 1024]);
+        assert_eq!(read_buf.as_slice(), &[1u8; DISCARD_BUF_SIZE]);
 
         seeker
-            .seek(SeekFrom::Start(2 * 2048))
+            .seek(SeekFrom::Start(2 * 2 * DISCARD_BUF_SIZE as u64))
             .await
             .expect("must seek");
         seeker.read_exact(&mut read_buf).await.expect("must read");
-        assert_eq!(read_buf.as_slice(), &[2u8; 1024]);
+        assert_eq!(read_buf.as_slice(), &[2u8; DISCARD_BUF_SIZE]);
     }
 }
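
For reference, exercising the reworked seek path from async code mirrors the tests
above; a minimal sketch, assuming the crate-internal NaiveSeeker is in scope:

    use std::io::{Cursor, SeekFrom};
    use tokio::io::{AsyncReadExt, AsyncSeekExt};

    #[tokio::main]
    async fn main() -> std::io::Result<()> {
        // 16 KiB of zeroes as a stand-in for blob data.
        let data = vec![0u8; 4 * 4096];
        let mut seeker = NaiveSeeker::new(Cursor::new(data));

        // A forward seek discards 10_000 bytes internally via repeated
        // poll_read calls, in chunks of at most DISCARD_BUF_SIZE (4096) bytes.
        seeker.seek(SeekFrom::Start(10_000)).await?;

        // Subsequent reads continue from the new position.
        let mut buf = vec![0u8; 1024];
        seeker.read_exact(&mut buf).await?;
        Ok(())
    }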