about summary refs log tree commit diff
path: root/tvix/nix-compat/src/wire/bytes/writer.rs
blob: 347934b3dc8f17222206048409088d8cfb7c84e9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
use pin_project_lite::pin_project;
use std::task::{ready, Poll};

use tokio::io::AsyncWrite;

use super::{padding_len, BytesPacketPosition, EMPTY_BYTES, LEN_SIZE};

pin_project! {
    /// Writes a "bytes wire packet" to the underlying writer.
    /// The format is the same as in [crate::wire::bytes::write_bytes],
    /// however this structure provides an [AsyncWrite] interface,
    /// so the entire payload does not have to be held in memory at once.
    ///
    /// It internally takes care of writing the (non-payload) framing (size and
    /// padding).
    ///
    /// During construction, the expected payload size needs to be provided.
    ///
    /// After writing the payload to it, the user MUST call flush (or shutdown),
    /// which writes the necessary padding once the whole payload has been
    /// written.
    ///
    /// Writing more payload than announced is rejected with an error by the
    /// write itself. Writing *less* than announced is only detected by
    /// shutdown, which returns an error in that case; a flush on an incomplete
    /// payload merely flushes the underlying writer and succeeds.
    ///
    /// In case neither flush nor shutdown is called at the end, invalid data
    /// might be sent silently.
    ///
    /// The underlying writer returning `Ok(0)` is considered an EOF situation,
    /// which is stronger than the "typically means the underlying object is no
    /// longer able to accept bytes" interpretation from the docs. If such a
    /// situation occurs, an error is returned.
    ///
    /// The struct holds three fields: the underlying writer, the (expected)
    /// payload length, and an enum tracking the state.
    pub struct BytesWriter<W>
    where
        W: AsyncWrite,
    {
        // The wrapped writer. Marked `#[pin]` so the projection hands out a
        // pinned reference to it, as required to call its poll methods.
        #[pin]
        inner: W,
        // Payload length announced at construction; it is emitted as the
        // little-endian size field and bounds how much payload is accepted.
        payload_len: u64,
        // Which part of the packet (size, payload or padding) is currently
        // being written, and how far into that part we are.
        state: BytesPacketPosition,
    }
}

impl<W: AsyncWrite> BytesWriter<W> {
    /// Wraps the writer `w` in a new [BytesWriter] which frames a payload of
    /// exactly `payload_len` bytes.
    ///
    /// Construction itself performs no I/O; the returned writer is positioned
    /// at the very start of the size field.
    pub fn new(w: W, payload_len: u64) -> Self {
        // No byte of the size field has been emitted yet.
        let state = BytesPacketPosition::Size(0);

        Self {
            inner: w,
            payload_len,
            state,
        }
    }
}

/// Passes a non-zero byte count through unchanged, and turns a count of 0
/// into a [std::io::ErrorKind::WriteZero] error.
#[inline]
fn ensure_nonzero_bytes_written(bytes_written: usize) -> Result<usize, std::io::Error> {
    match bytes_written {
        0 => Err(std::io::Error::new(
            std::io::ErrorKind::WriteZero,
            "underlying writer accepted 0 bytes",
        )),
        accepted => Ok(accepted),
    }
}

impl<W> AsyncWrite for BytesWriter<W>
where
    W: AsyncWrite,
{
    /// Writes payload bytes from `buf` to the underlying writer.
    ///
    /// If the size field has not been (fully) written yet, it is written
    /// first. The returned count only ever covers payload bytes, never
    /// framing.
    ///
    /// An empty `buf` is a no-op returning `Ok(0)` once the size field is out;
    /// it is never forwarded to the underlying writer.
    ///
    /// # Errors
    ///
    /// - `InvalidData` if `buf` would exceed the payload length announced at
    ///   construction (including any non-empty write after the payload is
    ///   complete).
    /// - `WriteZero` if the underlying writer accepts 0 bytes of a non-empty
    ///   write.
    /// - Any error returned by the underlying writer.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let mut this = self.project();

        // Use a loop, so we can deal with (multiple) state transitions.
        loop {
            match *this.state {
                BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
                BytesPacketPosition::Size(pos) => {
                    // Write (the rest of) the little-endian size field.
                    let size_field = this.payload_len.to_le_bytes();

                    let bytes_written = ensure_nonzero_bytes_written(ready!(this
                        .inner
                        .as_mut()
                        .poll_write(cx, &size_field[pos..]))?)?;

                    let new_pos = pos + bytes_written;
                    if new_pos == LEN_SIZE {
                        *this.state = BytesPacketPosition::Payload(0);
                    } else {
                        *this.state = BytesPacketPosition::Size(new_pos);
                    }
                }
                BytesPacketPosition::Payload(pos) => {
                    // Ensure we still have space for more payload. Comparing
                    // against the remaining budget (instead of adding to
                    // `pos`) cannot overflow, whatever the payload length.
                    let remaining = this.payload_len.saturating_sub(pos);
                    if (buf.len() as u64) > remaining {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "tried to write excess bytes",
                        )));
                    }

                    // An empty write is a legal no-op. Forwarding it would
                    // make the inner writer report 0 bytes, which we would
                    // then misinterpret as EOF.
                    if buf.is_empty() {
                        return Poll::Ready(Ok(0));
                    }

                    let bytes_written = ready!(this.inner.as_mut().poll_write(cx, buf))?;
                    ensure_nonzero_bytes_written(bytes_written)?;
                    let new_pos = pos + (bytes_written as u64);
                    if new_pos == *this.payload_len {
                        *this.state = BytesPacketPosition::Padding(0)
                    } else {
                        *this.state = BytesPacketPosition::Payload(new_pos)
                    }

                    return Poll::Ready(Ok(bytes_written));
                }
                BytesPacketPosition::Padding(_pos) => {
                    // Writing nothing never exceeds the announced length.
                    if buf.is_empty() {
                        return Poll::Ready(Ok(0));
                    }

                    // If we're already in padding state, there should be no
                    // more payload left to write!
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "tried to write excess bytes",
                    )));
                }
            }
        }
    }

    /// Writes any outstanding framing, then flushes the underlying writer.
    ///
    /// The size field is written if it is still (partially) missing. The
    /// padding is written once the payload is complete, which for a payload
    /// length of 0 is immediately after the size field.
    ///
    /// Flushing while payload is still outstanding is not an error: no padding
    /// is written and only the underlying writer is flushed.
    ///
    /// # Errors
    ///
    /// `WriteZero` if the underlying writer accepts 0 bytes of framing, or any
    /// error returned by the underlying writer.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let mut this = self.project();

        loop {
            match *this.state {
                BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
                BytesPacketPosition::Size(pos) => {
                    // More bytes to write in the size field
                    let size_field = this.payload_len.to_le_bytes();
                    let bytes_written = ensure_nonzero_bytes_written(ready!(this
                        .inner
                        .as_mut()
                        .poll_write(cx, &size_field[pos..]))?)?;
                    let new_pos = pos + bytes_written;
                    if new_pos == LEN_SIZE {
                        // Size field written, now ready to receive payload
                        *this.state = BytesPacketPosition::Payload(0);
                    } else {
                        *this.state = BytesPacketPosition::Size(new_pos);
                    }
                }
                BytesPacketPosition::Payload(_pos) => {
                    // If we want to write 0 bytes of payload in total, we can
                    // transition to padding right away.
                    // Otherwise, break, as we're expecting more payload to
                    // be written.
                    if *this.payload_len == 0 {
                        *this.state = BytesPacketPosition::Padding(0);
                    } else {
                        break;
                    }
                }
                BytesPacketPosition::Padding(pos) => {
                    // Write remaining padding, if there is padding to write.
                    let total_padding_len = padding_len(*this.payload_len) as usize;

                    if pos != total_padding_len {
                        let bytes_written = ensure_nonzero_bytes_written(ready!(this
                            .inner
                            .as_mut()
                            .poll_write(cx, &EMPTY_BYTES[pos..total_padding_len]))?)?;
                        *this.state = BytesPacketPosition::Padding(pos + bytes_written);
                    } else {
                        // everything written, break
                        break;
                    }
                }
            }
        }
        // Flush the underlying writer.
        this.inner.as_mut().poll_flush(cx)
    }

    /// Flushes (see `poll_flush`) and then shuts down the underlying writer.
    ///
    /// # Errors
    ///
    /// If the packet is incomplete after the flush (payload or padding still
    /// missing), the underlying writer is shut down nevertheless, and a
    /// `BrokenPipe` "unclean shutdown" error is returned. Errors from the
    /// flush or from the underlying shutdown are passed through.
    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // Call flush.
        ready!(self.as_mut().poll_flush(cx))?;

        let this = self.project();

        // After a flush, being inside the padding state, and at the end of the padding
        // is the only way to prevent a dirty shutdown.
        if let BytesPacketPosition::Padding(pos) = *this.state {
            let total_padding_len = padding_len(*this.payload_len) as usize;
            if total_padding_len == pos {
                // Shutdown the underlying writer
                return this.inner.poll_shutdown(cx);
            }
        }

        // Shutdown the underlying writer, bubbling up any errors.
        ready!(this.inner.poll_shutdown(cx))?;

        // return an error about unclean shutdown
        Poll::Ready(Err(std::io::Error::new(
            std::io::ErrorKind::BrokenPipe,
            "unclean shutdown",
        )))
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::wire::bytes::write_bytes;
    use hex_literal::hex;
    use lazy_static::lazy_static;
    use tokio::io::AsyncWriteExt;
    use tokio_test::{assert_err, assert_ok, io::Builder};

    use super::*;

    lazy_static! {
        pub static ref LARGE_PAYLOAD: Vec<u8> = (0..255).collect::<Vec<u8>>().repeat(4 * 1024);
    }

    /// Encodes `payload` with the (simpler) one-shot write_bytes.
    /// The result is the byte sequence we expect to show up on the wire.
    async fn produce_exp_bytes(payload: &[u8]) -> Vec<u8> {
        let mut wire = Vec::new();
        write_bytes(&mut wire, payload).await.unwrap();
        wire
    }

    /// An empty packet, written via an (empty) write followed by a flush.
    #[tokio::test]
    async fn write_empty() {
        let data = &[];
        let wire = produce_exp_bytes(data).await;
        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, 0);
        assert_ok!(writer.write_all(&[]).await, "write all data");
        assert_ok!(writer.flush().await, "flush");
    }

    /// An empty packet, produced by flushing alone, without any write.
    #[tokio::test]
    async fn write_empty_only_flush() {
        let data = &[];
        let wire = produce_exp_bytes(data).await;
        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, 0);
        assert_ok!(writer.flush().await, "flush");
    }

    /// An empty packet, produced by shutdown alone, without write or flush.
    #[tokio::test]
    async fn write_empty_only_shutdown() {
        let data = &[];
        let wire = produce_exp_bytes(data).await;
        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, 0);
        assert_ok!(writer.shutdown().await, "shutdown");
    }

    /// A packet with a single byte of payload.
    #[tokio::test]
    async fn write_1b() {
        let data = &[0xff];
        let wire = produce_exp_bytes(data).await;
        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);
        assert_ok!(writer.write_all(data).await);
        assert_ok!(writer.flush().await, "flush");
    }

    /// A packet with 8 bytes of payload, which needs no padding.
    #[tokio::test]
    async fn write_8b() {
        let data = &hex!("0001020304050607");
        let wire = produce_exp_bytes(data).await;
        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);
        assert_ok!(writer.write_all(data).await);
        assert_ok!(writer.flush().await, "flush");
    }

    /// A packet with 9 bytes of payload, which needs 7 bytes of padding.
    #[tokio::test]
    async fn write_9b() {
        let data = &hex!("000102030405060708");
        let wire = produce_exp_bytes(data).await;
        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);
        assert_ok!(writer.write_all(data).await);
        assert_ok!(writer.flush().await, "flush");
    }

    /// A 9 byte payload written in small pieces, flushing between every step
    /// and shutting down at the very end.
    #[tokio::test]
    async fn write_9b_flush() {
        let data = &hex!("000102030405060708");
        let wire = produce_exp_bytes(data).await;

        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);
        assert_ok!(writer.flush().await);

        assert_ok!(writer.write_all(&data[..4]).await);
        assert_ok!(writer.flush().await);

        // an empty write in between must not disturb anything
        assert_ok!(writer.write_all(&[]).await);
        assert_ok!(writer.flush().await);

        assert_ok!(writer.write_all(&data[4..]).await);
        assert_ok!(writer.flush().await);
        assert_ok!(writer.shutdown().await);
    }

    /// A 9 byte payload where the sink accepts the padding in two steps only.
    /// The remainder of the padding (and nothing more) must be written in the
    /// second step. Two extra "bait" bytes are written to the sink afterwards;
    /// a faulty implementation (pre cl/11384) would emit too many null bytes
    /// and fail to match them.
    #[tokio::test]
    async fn write_9b_write_padding_2steps() {
        let data = &hex!("000102030405060708");
        let wire = produce_exp_bytes(data).await;

        let mut sink = Builder::new()
            .write(&wire[0..8]) // size
            .write(&wire[8..17]) // payload
            .write(&wire[17..19]) // padding (2 of 7 bytes)
            // the wait keeps Mock from merging the adjacent writes into one
            .wait(Duration::from_nanos(1))
            .write(&hex!("0000000000ffff")) // padding (5 of 7 bytes, plus 2 bytes of "bait")
            .build();

        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);
        assert_ok!(writer.write_all(&data[..]).await);
        assert_ok!(writer.flush().await);
        // Now the bait, straight to the sink
        assert_ok!(sink.write_all(&hex!("ffff")).await);
    }

    /// A packet with a large payload.
    #[tokio::test]
    async fn write_1m() {
        let data = LARGE_PAYLOAD.as_slice();
        let wire = produce_exp_bytes(data).await;

        let mut sink = Builder::new().write(&wire).build();
        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);

        assert_ok!(writer.write_all(data).await);
        assert_ok!(writer.flush().await, "flush");
    }

    /// Ending with shutdown instead of flush is fine too, as long as all
    /// promised bytes were written (shutdown implies flush).
    #[tokio::test]
    async fn write_shutdown_without_flush_end() {
        let data = &[0xf0, 0xff];
        let wire = produce_exp_bytes(data).await;

        let mut sink = Builder::new().write(&wire).build();
        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);

        // the initial flush emits the size field
        assert_ok!(writer.flush().await);

        // then the payload
        assert_ok!(writer.write_all(data).await);

        // and finally shutdown, without a flush in between
        assert_ok!(writer.shutdown().await);
    }

    /// A single write exceeding the announced length must fail.
    #[tokio::test]
    async fn write_more_than_signalled_fail() {
        let mut out = Vec::new();
        let mut writer = BytesWriter::new(&mut out, 2);

        assert_err!(writer.write_all(&hex!("000102")).await);
    }

    /// Exceeding the announced length across two writes must fail as well.
    #[tokio::test]
    async fn write_more_than_signalled_split_fail() {
        let mut out = Vec::new();
        let mut writer = BytesWriter::new(&mut out, 2);

        // exactly the two announced bytes
        assert_ok!(writer.write_all(&hex!("0001")).await);

        // one byte too many
        assert_err!(writer.write_all(&hex!("02")).await);
    }

    /// Exceeding the announced length must also fail if a flush happened
    /// right after the announced amount was written.
    #[tokio::test]
    async fn write_more_than_signalled_flush_fail() {
        let mut out = Vec::new();
        let mut writer = BytesWriter::new(&mut out, 2);

        // exactly the two announced bytes, followed by a flush
        assert_ok!(writer.write_all(&hex!("0001")).await);
        assert_ok!(writer.flush().await);

        // one byte too many
        assert_err!(writer.write_all(&hex!("02")).await);
    }

    /// Shutting down before all promised bytes were written is an error.
    /// Note that silent corruption is still possible if the user never calls
    /// shutdown explicitly (and only drops the writer).
    #[tokio::test]
    async fn premature_shutdown() {
        let data = &[0xf0, 0xff];
        let mut out = Vec::new();
        let mut writer = BytesWriter::new(&mut out, data.len() as u64);

        // the initial flush emits the size field
        assert_ok!(writer.flush().await);

        // only half of the payload gets written (!)
        assert_ok!(writer.write_all(&data[0..1]).await);

        // shutdown has to notice and fail
        assert_err!(writer.shutdown().await);
    }

    /// The inner writer fails in the middle of the size field (after 4 bytes).
    /// The error has to surface from the first call to write.
    #[tokio::test]
    async fn inner_writer_fail_during_size_firstwrite() {
        let data = &[0xf0];

        let mut sink = Builder::new()
            .write(&1u32.to_le_bytes())
            .write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿"))
            .build();
        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);

        assert_err!(writer.write_all(data).await);
    }

    /// The inner writer fails in the middle of the size field (after 4 bytes).
    /// The error has to surface from an initial flush.
    #[tokio::test]
    async fn inner_writer_fail_during_size_initial_flush() {
        let data = &[0xf0];

        let mut sink = Builder::new()
            .write(&1u32.to_le_bytes())
            .write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿"))
            .build();
        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);

        assert_err!(writer.flush().await);
    }

    /// The inner writer fails in the middle of the payload (after 9 bytes).
    /// The error has to surface from the write of the affected byte.
    #[tokio::test]
    async fn inner_writer_fail_during_write() {
        let data = &hex!("f0ff");

        let mut sink = Builder::new()
            .write(&2u64.to_le_bytes())
            .write(&hex!("f0"))
            .write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿"))
            .build();
        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);

        assert_ok!(writer.write(&hex!("f0")).await);
        assert_err!(writer.write(&hex!("ff")).await);
    }

    /// The inner writer fails in the middle of the padding (after 10 bytes).
    /// The error has to surface from the flush.
    #[tokio::test]
    async fn inner_writer_fail_during_padding_flush() {
        let data = &hex!("f0");

        let mut sink = Builder::new()
            .write(&1u64.to_le_bytes())
            .write(&hex!("f0"))
            .write(&hex!("00"))
            .write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿"))
            .build();
        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);

        assert_ok!(writer.write(&hex!("f0")).await);
        assert_err!(writer.flush().await);
    }
}