use super::BlobReader; use pin_project_lite::pin_project; use std::io; use std::task::Poll; use tokio::io::AsyncRead; use tracing::{debug, instrument}; pin_project! { /// This implements [tokio::io::AsyncSeek] for and [tokio::io::AsyncRead] by /// simply skipping over some bytes, keeping track of the position. /// It fails whenever you try to seek backwards. /// /// ## Pinning concerns: /// /// [NaiveSeeker] is itself pinned by callers, and we do not need to concern /// ourselves regarding that. /// /// Though, its fields as per /// /// can be pinned or unpinned. /// /// So we need to go over each field and choose our policy carefully. /// /// The obvious cases are the bookkeeping integers we keep in the structure, /// those are private and not shared to anyone, we never build a /// `Pin<&mut X>` out of them at any point, therefore, we can safely never /// mark them as pinned. Of course, it is expected that no developer here /// attempt to `pin!(self.pos)` to pin them because it makes no sense. If /// they have to become pinned, they should be marked `#[pin]` and we need /// to discuss it. /// /// So the bookkeeping integers are in the right state with respect to their /// pinning status. The projection should offer direct access. /// /// On the `r` field, i.e. a `BufReader`, given that /// /// is available, even a `Pin<&mut BufReader>` can be safely moved. /// /// The only care we should have regards the internal reader itself, i.e. /// the `R` instance, see that Tokio decided to `#[pin]` it too: /// /// /// In general, there's no `Unpin` instance for `R: tokio::io::AsyncRead` /// (see ). /// /// Therefore, we could keep it unpinned and pin it in every call site /// whenever we need to call `poll_*` which can be confusing to the non- /// expert developer and we have a fair share amount of situations where the /// [BufReader] instance is naked, i.e. in its `&mut BufReader` /// form, this is annoying because it could lead to expose the naked `R` /// internal instance somehow and would produce a risk of making it move /// unexpectedly. /// /// We choose the path of the least resistance as we have no reason to have /// access to the raw `BufReader` instance, we just `#[pin]` it too and /// enjoy its `poll_*` safe APIs and push the unpinning concerns to the /// internal implementations themselves, which studied the question longer /// than us. pub struct NaiveSeeker { #[pin] r: tokio::io::BufReader, pos: u64, bytes_to_skip: u64, } } impl NaiveSeeker { pub fn new(r: R) -> Self { NaiveSeeker { r: tokio::io::BufReader::new(r), pos: 0, bytes_to_skip: 0, } } } impl tokio::io::AsyncRead for NaiveSeeker { fn poll_read( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { // The amount of data read can be determined by the increase // in the length of the slice returned by `ReadBuf::filled`. let filled_before = buf.filled().len(); let this = self.project(); let pos: &mut u64 = this.pos; match this.r.poll_read(cx, buf) { Poll::Ready(a) => { let bytes_read = buf.filled().len() - filled_before; *pos += bytes_read as u64; Poll::Ready(a) } Poll::Pending => Poll::Pending, } } } impl tokio::io::AsyncBufRead for NaiveSeeker { fn poll_fill_buf( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { self.project().r.poll_fill_buf(cx) } fn consume(self: std::pin::Pin<&mut Self>, amt: usize) { let this = self.project(); this.r.consume(amt); let pos: &mut u64 = this.pos; *pos += amt as u64; } } impl tokio::io::AsyncSeek for NaiveSeeker { #[instrument(skip(self))] fn start_seek( self: std::pin::Pin<&mut Self>, position: std::io::SeekFrom, ) -> std::io::Result<()> { let absolute_offset: u64 = match position { io::SeekFrom::Start(start_offset) => { if start_offset < self.pos { return Err(io::Error::new( io::ErrorKind::Unsupported, format!("can't seek backwards ({} -> {})", self.pos, start_offset), )); } else { start_offset } } // we don't know the total size, can't support this. io::SeekFrom::End(_end_offset) => { return Err(io::Error::new( io::ErrorKind::Unsupported, "can't seek from end", )); } io::SeekFrom::Current(relative_offset) => { if relative_offset < 0 { return Err(io::Error::new( io::ErrorKind::Unsupported, "can't seek backwards relative to current position", )); } else { self.pos + relative_offset as u64 } } }; debug!(absolute_offset=?absolute_offset, "seek"); // we already know absolute_offset is larger than self.pos debug_assert!( absolute_offset >= self.pos, "absolute_offset {} is larger than self.pos {}", absolute_offset, self.pos ); // calculate bytes to skip *self.project().bytes_to_skip = absolute_offset - self.pos; Ok(()) } #[instrument(skip(self))] fn poll_complete( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { if self.bytes_to_skip == 0 { // return the new position (from the start of the stream) return Poll::Ready(Ok(self.pos)); } // discard some bytes, until pos is where we want it to be. // We create a buffer that we'll discard later on. let mut buf = [0; 1024]; // Loop until we've reached the desired seek position. This is done by issuing repeated // `poll_read` calls. If the data is not available yet, we will yield back to the executor // and wait to be polled again. loop { // calculate the length we want to skip at most, which is either a max // buffer size, or the number of remaining bytes to read, whatever is // smaller. let bytes_to_skip = std::cmp::min(self.bytes_to_skip as usize, buf.len()); let mut read_buf = tokio::io::ReadBuf::new(&mut buf[..bytes_to_skip]); match self.as_mut().poll_read(cx, &mut read_buf) { Poll::Ready(_a) => { let bytes_read = read_buf.filled().len() as u64; if bytes_read == 0 { return Poll::Ready(Err(io::Error::new( io::ErrorKind::UnexpectedEof, format!( "tried to skip {} bytes, but only was able to skip {} until reaching EOF", bytes_to_skip, bytes_read ), ))); } // calculate bytes to skip let bytes_to_skip = self.bytes_to_skip - bytes_read; *self.as_mut().project().bytes_to_skip = bytes_to_skip; if bytes_to_skip == 0 { return Poll::Ready(Ok(self.pos)); } } Poll::Pending => return Poll::Pending, }; } } } impl BlobReader for NaiveSeeker {} #[cfg(test)] mod tests { use super::NaiveSeeker; use std::io::{Cursor, SeekFrom}; use tokio::io::{AsyncReadExt, AsyncSeekExt}; /// This seek requires multiple `poll_read` as we use a 1024 bytes internal /// buffer when doing the seek. /// This ensures we don't hang indefinitely. #[tokio::test] async fn seek() { let buf = vec![0u8; 4096]; let reader = Cursor::new(&buf); let mut seeker = NaiveSeeker::new(reader); seeker.seek(SeekFrom::Start(4000)).await.unwrap(); } #[tokio::test] async fn seek_read() { let mut buf = vec![0u8; 2048]; buf.extend_from_slice(&[1u8; 2048]); buf.extend_from_slice(&[2u8; 2048]); let reader = Cursor::new(&buf); let mut seeker = NaiveSeeker::new(reader); let mut read_buf = vec![0u8; 1024]; seeker.read_exact(&mut read_buf).await.expect("must read"); assert_eq!(read_buf.as_slice(), &[0u8; 1024]); seeker .seek(SeekFrom::Current(1024)) .await .expect("must seek"); seeker.read_exact(&mut read_buf).await.expect("must read"); assert_eq!(read_buf.as_slice(), &[1u8; 1024]); seeker .seek(SeekFrom::Start(2 * 2048)) .await .expect("must seek"); seeker.read_exact(&mut read_buf).await.expect("must read"); assert_eq!(read_buf.as_slice(), &[2u8; 1024]); } }