diff options
-rw-r--r-- | tvix/castore/src/blobservice/naive_seeker.rs | 77 |
1 files changed, 35 insertions, 42 deletions
diff --git a/tvix/castore/src/blobservice/naive_seeker.rs b/tvix/castore/src/blobservice/naive_seeker.rs index 1475de580814..5c943b9870b4 100644 --- a/tvix/castore/src/blobservice/naive_seeker.rs +++ b/tvix/castore/src/blobservice/naive_seeker.rs @@ -65,6 +65,9 @@ pin_project! { } } +/// The buffer size used to discard data. +const DISCARD_BUF_SIZE: usize = 4096; + impl<R: tokio::io::AsyncRead> NaiveSeeker<R> { pub fn new(r: R) -> Self { NaiveSeeker { @@ -174,44 +177,34 @@ impl<R: tokio::io::AsyncRead> tokio::io::AsyncSeek for NaiveSeeker<R> { // discard some bytes, until pos is where we want it to be. // We create a buffer that we'll discard later on. - let mut buf = [0; 1024]; + let mut discard_buf = [0; DISCARD_BUF_SIZE]; // Loop until we've reached the desired seek position. This is done by issuing repeated - // `poll_read` calls. If the data is not available yet, we will yield back to the executor + // `poll_read` calls. + // If the data is not available yet, we will yield back to the executor // and wait to be polled again. loop { + if self.bytes_to_skip == 0 { + return Poll::Ready(Ok(self.pos)); + } + // calculate the length we want to skip at most, which is either a max // buffer size, or the number of remaining bytes to read, whatever is // smaller. - let bytes_to_skip = std::cmp::min(self.bytes_to_skip as usize, buf.len()); - - let mut read_buf = tokio::io::ReadBuf::new(&mut buf[..bytes_to_skip]); - - match self.as_mut().poll_read(cx, &mut read_buf) { - Poll::Ready(_a) => { - let bytes_read = read_buf.filled().len() as u64; + let bytes_to_skip_now = std::cmp::min(self.bytes_to_skip as usize, discard_buf.len()); + let mut discard_buf = tokio::io::ReadBuf::new(&mut discard_buf[..bytes_to_skip_now]); - if bytes_read == 0 { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - format!( - "tried to skip {} bytes, but only was able to skip {} until reaching EOF", - bytes_to_skip, bytes_read - ), - ))); - } + ready!(self.as_mut().poll_read(cx, &mut discard_buf))?; + let bytes_skipped = discard_buf.filled().len(); - // calculate bytes to skip - let bytes_to_skip = self.bytes_to_skip - bytes_read; - - *self.as_mut().project().bytes_to_skip = bytes_to_skip; - - if bytes_to_skip == 0 { - return Poll::Ready(Ok(self.pos)); - } - } - Poll::Pending => return Poll::Pending, - }; + if bytes_skipped == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "got EOF while trying to skip bytes", + ))); + } + // decrement bytes to skip. The poll_read call already updated self.pos. + *self.as_mut().project().bytes_to_skip -= bytes_skipped as u64; } } } @@ -220,16 +213,16 @@ impl<R: tokio::io::AsyncRead + Send + Unpin + 'static> BlobReader for NaiveSeeke #[cfg(test)] mod tests { - use super::NaiveSeeker; + use super::{NaiveSeeker, DISCARD_BUF_SIZE}; use std::io::{Cursor, SeekFrom}; use tokio::io::{AsyncReadExt, AsyncSeekExt}; - /// This seek requires multiple `poll_read` as we use a 1024 bytes internal - /// buffer when doing the seek. + /// This seek requires multiple `poll_read` as we use a multiples of + /// DISCARD_BUF_SIZE when doing the seek. /// This ensures we don't hang indefinitely. #[tokio::test] async fn seek() { - let buf = vec![0u8; 4096]; + let buf = vec![0u8; DISCARD_BUF_SIZE * 4]; let reader = Cursor::new(&buf); let mut seeker = NaiveSeeker::new(reader); seeker.seek(SeekFrom::Start(4000)).await.unwrap(); @@ -237,29 +230,29 @@ mod tests { #[tokio::test] async fn seek_read() { - let mut buf = vec![0u8; 2048]; - buf.extend_from_slice(&[1u8; 2048]); - buf.extend_from_slice(&[2u8; 2048]); + let mut buf = vec![0u8; DISCARD_BUF_SIZE * 2]; + buf.extend_from_slice(&[1u8; DISCARD_BUF_SIZE * 2]); + buf.extend_from_slice(&[2u8; DISCARD_BUF_SIZE * 2]); let reader = Cursor::new(&buf); let mut seeker = NaiveSeeker::new(reader); - let mut read_buf = vec![0u8; 1024]; + let mut read_buf = vec![0u8; DISCARD_BUF_SIZE]; seeker.read_exact(&mut read_buf).await.expect("must read"); - assert_eq!(read_buf.as_slice(), &[0u8; 1024]); + assert_eq!(read_buf.as_slice(), &[0u8; DISCARD_BUF_SIZE]); seeker - .seek(SeekFrom::Current(1024)) + .seek(SeekFrom::Current(DISCARD_BUF_SIZE as i64)) .await .expect("must seek"); seeker.read_exact(&mut read_buf).await.expect("must read"); - assert_eq!(read_buf.as_slice(), &[1u8; 1024]); + assert_eq!(read_buf.as_slice(), &[1u8; DISCARD_BUF_SIZE]); seeker - .seek(SeekFrom::Start(2 * 2048)) + .seek(SeekFrom::Start(2 * 2 * DISCARD_BUF_SIZE as u64)) .await .expect("must seek"); seeker.read_exact(&mut read_buf).await.expect("must read"); - assert_eq!(read_buf.as_slice(), &[2u8; 1024]); + assert_eq!(read_buf.as_slice(), &[2u8; DISCARD_BUF_SIZE]); } } |