about summary refs log tree commit diff
path: root/tvix/store/src/proto/grpc_blobservice_wrapper.rs
blob: 3ec1d68872c74a8712976eb5dfa989a202d59c7a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
use crate::{
    blobservice::{BlobService, BlobWriter},
    proto::sync_read_into_async_read::SyncReadIntoAsyncRead,
    B3Digest,
};
use std::{collections::VecDeque, io, pin::Pin};
use tokio::task;
use tokio_stream::StreamExt;
use tokio_util::io::ReaderStream;
use tonic::{async_trait, Request, Response, Status, Streaming};
use tracing::{instrument, warn};

/// Wraps a [BlobService] so it can be exposed over gRPC, by implementing
/// the tonic-generated `blob_service_server::BlobService` trait on top of it.
pub struct GRPCBlobServiceWrapper<BS: BlobService> {
    // The underlying blob service all gRPC calls are delegated to.
    blob_service: BS,
}

/// Allows constructing the wrapper directly from any [BlobService] value.
impl<BS: BlobService> From<BS> for GRPCBlobServiceWrapper<BS> {
    fn from(blob_service: BS) -> Self {
        Self { blob_service }
    }
}

#[async_trait]
impl<BS: BlobService + Send + Sync + Clone + 'static> super::blob_service_server::BlobService
    for GRPCBlobServiceWrapper<BS>
{
    // https://github.com/tokio-rs/tokio/issues/2723#issuecomment-1534723933
    type ReadStream =
        Pin<Box<dyn futures::Stream<Item = Result<super::BlobChunk, Status>> + Send + 'static>>;

    #[instrument(skip(self))]
    async fn stat(
        &self,
        request: Request<super::StatBlobRequest>,
    ) -> Result<Response<super::BlobMeta>, Status> {
        let rq = request.into_inner();
        let req_digest = B3Digest::from_vec(rq.digest)
            .map_err(|_e| Status::invalid_argument("invalid digest length"))?;

        if rq.include_chunks || rq.include_bao {
            return Err(Status::internal("not implemented"));
        }

        match self.blob_service.has(&req_digest) {
            Ok(true) => Ok(Response::new(super::BlobMeta::default())),
            Ok(false) => Err(Status::not_found(format!("blob {} not found", &req_digest))),
            Err(e) => Err(e.into()),
        }
    }

    #[instrument(skip(self))]
    async fn read(
        &self,
        request: Request<super::ReadBlobRequest>,
    ) -> Result<Response<Self::ReadStream>, Status> {
        let rq = request.into_inner();

        let req_digest = B3Digest::from_vec(rq.digest)
            .map_err(|_e| Status::invalid_argument("invalid digest length"))?;

        match self.blob_service.open_read(&req_digest) {
            Ok(Some(reader)) => {
                let async_reader: SyncReadIntoAsyncRead<_, bytes::BytesMut> = reader.into();

                fn stream_mapper(
                    x: Result<bytes::Bytes, io::Error>,
                ) -> Result<super::BlobChunk, Status> {
                    match x {
                        Ok(bytes) => Ok(super::BlobChunk {
                            data: bytes.to_vec(),
                        }),
                        Err(e) => Err(Status::from(e)),
                    }
                }

                let chunks_stream = ReaderStream::new(async_reader).map(stream_mapper);
                Ok(Response::new(Box::pin(chunks_stream)))
            }
            Ok(None) => Err(Status::not_found(format!("blob {} not found", &req_digest))),
            Err(e) => Err(e.into()),
        }
    }

    #[instrument(skip(self))]
    async fn put(
        &self,
        request: Request<Streaming<super::BlobChunk>>,
    ) -> Result<Response<super::PutBlobResponse>, Status> {
        let req_inner = request.into_inner();

        let data_stream = req_inner.map(|x| {
            x.map(|x| VecDeque::from(x.data))
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
        });

        let data_reader = tokio_util::io::StreamReader::new(data_stream);

        // prepare a writer, which we'll use in the blocking task below.
        let mut writer = self
            .blob_service
            .open_write()
            .map_err(|e| Status::internal(format!("unable to open for write: {}", e)))?;

        let result = task::spawn_blocking(move || -> Result<super::PutBlobResponse, Status> {
            // construct a sync reader to the data
            let mut reader = tokio_util::io::SyncIoBridge::new(data_reader);

            io::copy(&mut reader, &mut writer).map_err(|e| {
                warn!("error copying: {}", e);
                Status::internal("error copying")
            })?;

            let digest = writer
                .close()
                .map_err(|e| {
                    warn!("error closing stream: {}", e);
                    Status::internal("error closing stream")
                })?
                .to_vec();

            Ok(super::PutBlobResponse { digest })
        })
        .await
        .map_err(|_| Status::internal("failed to wait for task"))??;

        Ok(Response::new(result))
    }
}