diff options
Diffstat (limited to 'tvix/store/src/proto/sync_read_into_async_read.rs')
-rw-r--r-- | tvix/store/src/proto/sync_read_into_async_read.rs | 158 |
1 files changed, 158 insertions, 0 deletions
diff --git a/tvix/store/src/proto/sync_read_into_async_read.rs b/tvix/store/src/proto/sync_read_into_async_read.rs new file mode 100644 index 000000000000..0a0ef019781c --- /dev/null +++ b/tvix/store/src/proto/sync_read_into_async_read.rs @@ -0,0 +1,158 @@ +use bytes::Buf; +use core::task::Poll::Ready; +use futures::ready; +use futures::Future; +use std::io; +use std::io::Read; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Context; +use std::task::Poll; +use tokio::io::AsyncRead; +use tokio::runtime::Handle; +use tokio::sync::Mutex; +use tokio::task::JoinHandle; + +#[derive(Debug)] +enum State<Buf: bytes::Buf + bytes::BufMut> { + Idle(Option<Buf>), + Busy(JoinHandle<(io::Result<usize>, Buf)>), +} + +use State::{Busy, Idle}; + +/// Use a [`SyncReadIntoAsyncRead`] to asynchronously read from a +/// synchronous API. +#[derive(Debug)] +pub struct SyncReadIntoAsyncRead<R: Read + Send, Buf: bytes::Buf + bytes::BufMut> { + state: Mutex<State<Buf>>, + reader: Arc<Mutex<R>>, + rt: Handle, +} + +impl<R: Read + Send, Buf: bytes::Buf + bytes::BufMut> SyncReadIntoAsyncRead<R, Buf> { + /// This must be called from within a Tokio runtime context, or else it will panic. + #[track_caller] + pub fn new(rt: Handle, reader: R) -> Self { + Self { + rt, + state: State::Idle(None).into(), + reader: Arc::new(reader.into()), + } + } + + /// This must be called from within a Tokio runtime context, or else it will panic. + pub fn new_with_reader(readable: R) -> Self { + Self::new(Handle::current(), readable) + } +} + +/// Repeats operations that are interrupted. +macro_rules! uninterruptibly { + ($e:expr) => {{ + loop { + match $e { + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} + res => break res, + } + } + }}; +} + +impl< + R: Read + Send + 'static + std::marker::Unpin, + Buf: bytes::Buf + bytes::BufMut + Send + Default + std::marker::Unpin + 'static, + > AsyncRead for SyncReadIntoAsyncRead<R, Buf> +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + dst: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let me = self.get_mut(); + // Do we need this mutex? + let state = me.state.get_mut(); + + loop { + match state { + Idle(ref mut buf_cell) => { + let mut buf = buf_cell.take().unwrap_or_default(); + + if buf.has_remaining() { + // Here, we will split the `buf` into `[..dst.remaining()... ; rest ]` + // The `rest` is stuffed into the `buf_cell` for further poll_read. + // The other is completely consumed into the unfilled destination. + // `rest` can be empty. + let mut adjusted_src = + buf.copy_to_bytes(std::cmp::min(buf.remaining(), dst.remaining())); + let copied_size = adjusted_src.remaining(); + adjusted_src.copy_to_slice(dst.initialize_unfilled_to(copied_size)); + dst.set_filled(copied_size); + *buf_cell = Some(buf); + return Ready(Ok(())); + } + + let reader = me.reader.clone(); + *state = Busy(me.rt.spawn_blocking(move || { + let result = uninterruptibly!(reader.blocking_lock().read( + // SAFETY: `reader.read` will *ONLY* write initialized bytes + // and never *READ* uninitialized bytes + // inside this buffer. + // + // Furthermore, casting the slice as `*mut [u8]` + // is safe because it has the same layout. + // + // Finally, the pointer obtained is valid and owned + // by `buf` only as we have a valid mutable reference + // to it, it is valid for write. + // + // Here, we copy an nightly API: https://doc.rust-lang.org/stable/src/core/mem/maybe_uninit.rs.html#994-998 + unsafe { + &mut *(buf.chunk_mut().as_uninit_slice_mut() + as *mut [std::mem::MaybeUninit<u8>] + as *mut [u8]) + } + )); + + if let Ok(n) = result { + // SAFETY: given we initialize `n` bytes, we can move `n` bytes + // forward. + unsafe { + buf.advance_mut(n); + } + } + + (result, buf) + })); + } + Busy(ref mut rx) => { + let (result, mut buf) = ready!(Pin::new(rx).poll(cx))?; + + match result { + Ok(n) => { + if n > 0 { + let remaining = std::cmp::min(n, dst.remaining()); + let mut adjusted_src = buf.copy_to_bytes(remaining); + adjusted_src.copy_to_slice(dst.initialize_unfilled_to(remaining)); + dst.advance(remaining); + } + *state = Idle(Some(buf)); + return Ready(Ok(())); + } + Err(e) => { + *state = Idle(None); + return Ready(Err(e)); + } + } + } + } + } + } +} + +impl<R: Read + Send, Buf: bytes::Buf + bytes::BufMut> From<R> for SyncReadIntoAsyncRead<R, Buf> { + /// This must be called from within a Tokio runtime context, or else it will panic. + fn from(value: R) -> Self { + Self::new_with_reader(value) + } +} |