Skip to main content

veloren_network_protocol/
quic.rs

1use crate::{
2    RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink,
3    error::ProtocolError,
4    event::ProtocolEvent,
5    frame::{ITFrame, InitFrame, OTFrame},
6    handshake::{ReliableDrain, ReliableSink},
7    message::{ALLOC_BLOCK, ITMessage},
8    metrics::{ProtocolMetricCache, RemoveReason},
9    prio::PrioManager,
10    types::{Bandwidth, Mid, Promises, Sid},
11    util::SortedVec,
12};
13use async_trait::async_trait;
14use bytes::BytesMut;
15use hashbrown::HashMap;
16use std::time::{Duration, Instant};
17use tracing::info;
18#[cfg(feature = "trace_pedantic")]
19use tracing::trace;
20
21#[derive(PartialEq, Eq)]
22pub enum QuicDataFormatStream {
23    Main,
24    Reliable(Sid),
25    Unreliable,
26}
27
28pub struct QuicDataFormat {
29    pub stream: QuicDataFormatStream,
30    pub data: BytesMut,
31}
32
33impl QuicDataFormat {
34    fn with_main(buffer: &mut BytesMut) -> Self {
35        Self {
36            stream: QuicDataFormatStream::Main,
37            data: buffer.split(),
38        }
39    }
40
41    fn with_reliable(buffer: &mut BytesMut, sid: Sid) -> Self {
42        Self {
43            stream: QuicDataFormatStream::Reliable(sid),
44            data: buffer.split(),
45        }
46    }
47
48    fn with_unreliable(frame: OTFrame) -> Self {
49        let mut buffer = BytesMut::new();
50        frame.write_bytes(&mut buffer);
51        Self {
52            stream: QuicDataFormatStream::Unreliable,
53            data: buffer,
54        }
55    }
56}
57
58/// QUIC implementation of [`SendProtocol`]
59///
60/// [`SendProtocol`]: crate::SendProtocol
61#[derive(Debug)]
62pub struct QuicSendProtocol<D>
63where
64    D: UnreliableDrain<DataFormat = QuicDataFormat>,
65{
66    main_buffer: BytesMut,
67    reliable_buffers: SortedVec<Sid, BytesMut>,
68    store: PrioManager,
69    next_mid: Mid,
70    closing_streams: Vec<Sid>,
71    notify_closing_streams: Vec<Sid>,
72    pending_shutdown: bool,
73    drain: D,
74    #[expect(dead_code)]
75    last: Instant,
76    metrics: ProtocolMetricCache,
77}
78
79/// QUIC implementation of [`RecvProtocol`]
80///
81/// [`RecvProtocol`]: crate::RecvProtocol
82#[derive(Debug)]
83pub struct QuicRecvProtocol<S>
84where
85    S: UnreliableSink<DataFormat = QuicDataFormat>,
86{
87    main_buffer: BytesMut,
88    unreliable_buffer: BytesMut,
89    reliable_buffers: SortedVec<Sid, BytesMut>,
90    pending_reliable_buffers: Vec<(Sid, BytesMut)>,
91    itmsg_allocator: BytesMut,
92    incoming: HashMap<Mid, ITMessage>,
93    sink: S,
94    metrics: ProtocolMetricCache,
95}
96
97fn is_reliable(p: &Promises) -> bool {
98    p.contains(Promises::ORDERED)
99        || p.contains(Promises::CONSISTENCY)
100        || p.contains(Promises::GUARANTEED_DELIVERY)
101}
102
103impl<D> QuicSendProtocol<D>
104where
105    D: UnreliableDrain<DataFormat = QuicDataFormat>,
106{
107    pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self {
108        Self {
109            main_buffer: BytesMut::new(),
110            reliable_buffers: SortedVec::default(),
111            store: PrioManager::new(metrics.clone()),
112            next_mid: 0u64,
113            closing_streams: vec![],
114            notify_closing_streams: vec![],
115            pending_shutdown: false,
116            drain,
117            last: Instant::now(),
118            metrics,
119        }
120    }
121
122    /// returns all promises that this Protocol can take care of
123    /// If you open a Stream anyway, unsupported promises are ignored.
124    pub fn supported_promises() -> Promises {
125        Promises::ORDERED
126            | Promises::CONSISTENCY
127            | Promises::GUARANTEED_DELIVERY
128            | Promises::COMPRESSED
129            | Promises::ENCRYPTED
130    }
131}
132
133impl<S> QuicRecvProtocol<S>
134where
135    S: UnreliableSink<DataFormat = QuicDataFormat>,
136{
137    pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self {
138        Self {
139            main_buffer: BytesMut::new(),
140            unreliable_buffer: BytesMut::new(),
141            reliable_buffers: SortedVec::default(),
142            pending_reliable_buffers: vec![],
143            itmsg_allocator: BytesMut::with_capacity(ALLOC_BLOCK),
144            incoming: HashMap::new(),
145            sink,
146            metrics,
147        }
148    }
149
150    async fn recv_into_stream(
151        &mut self,
152    ) -> Result<QuicDataFormatStream, ProtocolError<S::CustomErr>> {
153        let chunk = self.sink.recv().await?;
154        let buffer = match chunk.stream {
155            QuicDataFormatStream::Main => &mut self.main_buffer,
156            QuicDataFormatStream::Unreliable => &mut self.unreliable_buffer,
157            QuicDataFormatStream::Reliable(id) => {
158                match self.reliable_buffers.get_mut(&id) {
159                    Some(buffer) => buffer,
160                    None => {
161                        self.pending_reliable_buffers.push((id, BytesMut::new()));
162                        //Violated but will never happen
163                        &mut self
164                            .pending_reliable_buffers
165                            .last_mut()
166                            .ok_or(ProtocolError::Violated)?
167                            .1
168                    },
169                }
170            },
171        };
172        if buffer.is_empty() {
173            *buffer = chunk.data
174        } else {
175            buffer.extend_from_slice(&chunk.data)
176        }
177        Ok(chunk.stream)
178    }
179}
180
181#[async_trait]
182impl<D> SendProtocol for QuicSendProtocol<D>
183where
184    D: UnreliableDrain<DataFormat = QuicDataFormat>,
185{
186    type CustomErr = D::CustomErr;
187
188    fn notify_from_recv(&mut self, event: ProtocolEvent) {
189        match event {
190            ProtocolEvent::OpenStream {
191                sid,
192                prio,
193                promises,
194                guaranteed_bandwidth,
195            } => {
196                self.store
197                    .open_stream(sid, prio, promises, guaranteed_bandwidth);
198                if is_reliable(&promises) {
199                    self.reliable_buffers.insert(sid, BytesMut::new());
200                }
201            },
202            ProtocolEvent::CloseStream { sid } if !self.store.try_close_stream(sid) => {
203                #[cfg(feature = "trace_pedantic")]
204                trace!(?sid, "hold back notify close stream");
205                self.notify_closing_streams.push(sid);
206            },
207            _ => {},
208        }
209    }
210
211    async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
212        #[cfg(feature = "trace_pedantic")]
213        trace!(?event, "send");
214        match event {
215            ProtocolEvent::OpenStream {
216                sid,
217                prio,
218                promises,
219                guaranteed_bandwidth,
220            } => {
221                self.store
222                    .open_stream(sid, prio, promises, guaranteed_bandwidth);
223                if is_reliable(&promises) {
224                    self.reliable_buffers.insert(sid, BytesMut::new());
225                    //Send a empty message to notify local drain of stream
226                    self.drain
227                        .send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid))
228                        .await?;
229                }
230                event.to_frame().write_bytes(&mut self.main_buffer);
231                self.drain
232                    .send(QuicDataFormat::with_main(&mut self.main_buffer))
233                    .await?;
234            },
235            ProtocolEvent::CloseStream { sid } => {
236                if self.store.try_close_stream(sid) {
237                    let _ = self.reliable_buffers.delete(&sid); //delete if it was reliable
238                    event.to_frame().write_bytes(&mut self.main_buffer);
239                    self.drain
240                        .send(QuicDataFormat::with_main(&mut self.main_buffer))
241                        .await?;
242                } else {
243                    #[cfg(feature = "trace_pedantic")]
244                    trace!(?sid, "hold back close stream");
245                    self.closing_streams.push(sid);
246                }
247            },
248            ProtocolEvent::Shutdown => {
249                if self.store.is_empty() {
250                    event.to_frame().write_bytes(&mut self.main_buffer);
251                    self.drain
252                        .send(QuicDataFormat::with_main(&mut self.main_buffer))
253                        .await?;
254                } else {
255                    #[cfg(feature = "trace_pedantic")]
256                    trace!("hold back shutdown");
257                    self.pending_shutdown = true;
258                }
259            },
260            ProtocolEvent::Message { data, sid } => {
261                self.metrics.smsg_ib(sid, data.len() as u64);
262                self.store.add(data, self.next_mid, sid);
263                self.next_mid += 1;
264            },
265        }
266        Ok(())
267    }
268
269    async fn flush(
270        &mut self,
271        bandwidth: Bandwidth,
272        dt: Duration,
273    ) -> Result</* actual */ Bandwidth, ProtocolError<Self::CustomErr>> {
274        let (frames, _) = self.store.grab(bandwidth, dt);
275        //Todo: optimize reserve
276        let mut data_frames = 0;
277        let mut data_bandwidth = 0;
278        for (sid, frame) in frames {
279            if let OTFrame::Data { mid: _, data } = &frame {
280                data_bandwidth += data.len();
281                data_frames += 1;
282            }
283            match self.reliable_buffers.get_mut(&sid) {
284                Some(buffer) => frame.write_bytes(buffer),
285                None => {
286                    self.drain
287                        .send(QuicDataFormat::with_unreliable(frame))
288                        .await?
289                },
290            }
291        }
292        for (sid, buffer) in self.reliable_buffers.data.iter_mut() {
293            if !buffer.is_empty() {
294                self.drain
295                    .send(QuicDataFormat::with_reliable(buffer, *sid))
296                    .await?;
297            }
298        }
299        self.metrics
300            .sdata_frames_b(data_frames, data_bandwidth as u64);
301
302        let mut finished_streams = vec![];
303        for (i, &sid) in self.closing_streams.iter().enumerate() {
304            if self.store.try_close_stream(sid) {
305                #[cfg(feature = "trace_pedantic")]
306                trace!(?sid, "close stream, as it's now empty");
307                OTFrame::CloseStream { sid }.write_bytes(&mut self.main_buffer);
308                self.drain
309                    .send(QuicDataFormat::with_main(&mut self.main_buffer))
310                    .await?;
311                finished_streams.push(i);
312            }
313        }
314        for i in finished_streams.iter().rev() {
315            self.closing_streams.remove(*i);
316        }
317
318        let mut finished_streams = vec![];
319        for (i, sid) in self.notify_closing_streams.iter().enumerate() {
320            if self.store.try_close_stream(*sid) {
321                #[cfg(feature = "trace_pedantic")]
322                trace!(?sid, "close stream, as it's now empty");
323                finished_streams.push(i);
324            }
325        }
326        for i in finished_streams.iter().rev() {
327            self.notify_closing_streams.remove(*i);
328        }
329
330        if self.pending_shutdown && self.store.is_empty() {
331            #[cfg(feature = "trace_pedantic")]
332            trace!("shutdown, as it's now empty");
333            OTFrame::Shutdown {}.write_bytes(&mut self.main_buffer);
334            self.drain
335                .send(QuicDataFormat::with_main(&mut self.main_buffer))
336                .await?;
337            self.pending_shutdown = false;
338        }
339        Ok(data_bandwidth as u64)
340    }
341}
342
#[async_trait]
impl<S> RecvProtocol for QuicRecvProtocol<S>
where
    S: UnreliableSink<DataFormat = QuicDataFormat>,
{
    type CustomErr = S::CustomErr;

    /// Returns the next event received from the remote side.
    ///
    /// Each pass of the outer loop:
    /// 1. parses at most one frame from the main buffer (open/close stream, shutdown);
    /// 2. moves pending reliable buffers whose first frame is a `DataHeader` into
    ///    `reliable_buffers`, keyed by that header's sid;
    /// 3. parses all complete frames from every reliable buffer and then the unreliable
    ///    buffer, returning as soon as one message is fully reassembled;
    /// 4. otherwise waits for one more chunk from the sink and starts over.
    ///
    /// # Errors
    /// Returns [`ProtocolError::Violated`] for unparsable frames, for frames on the
    /// wrong kind of stream, and for a `Data` frame on a reliable stream whose message
    /// id has no known header. Sink errors are propagated.
    async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
        'outer: loop {
            // 1. management frames on the main stream are handled first
            match ITFrame::read_frame(&mut self.main_buffer) {
                Ok(Some(frame)) => {
                    #[cfg(feature = "trace_pedantic")]
                    trace!(?frame, "recv");
                    match frame {
                        ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown),
                        ITFrame::OpenStream {
                            sid,
                            prio,
                            promises,
                            guaranteed_bandwidth,
                        } => {
                            // reliable streams get their own reassembly buffer
                            if is_reliable(&promises) {
                                self.reliable_buffers.insert(sid, BytesMut::new());
                            }
                            break 'outer Ok(ProtocolEvent::OpenStream {
                                sid,
                                // the remote-supplied prio is capped at HIGHEST_PRIO
                                prio: prio.min(crate::types::HIGHEST_PRIO),
                                promises,
                                guaranteed_bandwidth,
                            });
                        },
                        ITFrame::CloseStream { sid } => {
                            //FIXME: defer close!
                            //let _ = self.reliable_buffers.delete(sid); // if it was reliable
                            break 'outer Ok(ProtocolEvent::CloseStream { sid });
                        },
                        // data frames are not allowed on the main stream
                        _ => break 'outer Err(ProtocolError::Violated),
                    };
                },
                Ok(None) => {},
                Err(()) => return Err(ProtocolError::Violated),
            }

            // 2. try to order pending: adopt pending buffers once their first frame
            //    reveals which stream they belong to
            let mut pending_violated = false;
            let mut reliable = vec![];

            self.pending_reliable_buffers.retain(|(_, buffer)| {
                // try to get Sid without touching buffer
                // NOTE(review): per the bytes crate, `BytesMut::clone` copies the bytes,
                // so an adopted buffer is copied twice here; retain_mut + mem::take
                // would avoid the second copy.
                let mut testbuffer = buffer.clone();
                match ITFrame::read_frame(&mut testbuffer) {
                    Ok(Some(ITFrame::DataHeader {
                        sid,
                        mid: _,
                        length: _,
                    })) => {
                        reliable.push((sid, buffer.clone()));
                        false
                    },
                    // a reliable stream must start with a DataHeader
                    Ok(Some(_)) | Err(_) => {
                        pending_violated = true;
                        false
                    },
                    // first frame incomplete: keep waiting for more bytes
                    Ok(None) => true,
                }
            });

            if pending_violated {
                break 'outer Err(ProtocolError::Violated);
            }
            for (sid, buffer) in reliable.into_iter() {
                self.reliable_buffers.insert(sid, buffer)
            }

            // 3. drain every data buffer: all reliable ones first, the unreliable one last.
            //    The bool marks whether the buffer belongs to a reliable stream.
            let mut iter = self
                .reliable_buffers
                .data
                .iter_mut()
                .map(|(_, b)| (b, true))
                .collect::<Vec<_>>();
            iter.push((&mut self.unreliable_buffer, false));

            for (buffer, reliable) in iter {
                loop {
                    match ITFrame::read_frame(buffer) {
                        Ok(Some(frame)) => {
                            #[cfg(feature = "trace_pedantic")]
                            trace!(?frame, "recv");
                            match frame {
                                ITFrame::DataHeader { sid, mid, length } => {
                                    // start reassembling a new message
                                    let m = ITMessage::new(sid, length, &mut self.itmsg_allocator);
                                    self.metrics.rmsg_ib(sid, length);
                                    self.incoming.insert(mid, m);
                                },
                                ITFrame::Data { mid, data } => {
                                    self.metrics.rdata_frames_b(data.len() as u64);
                                    let m = match self.incoming.get_mut(&mid) {
                                        Some(m) => m,
                                        None => {
                                            if reliable {
                                                info!(
                                                    ?mid,
                                                    "protocol violation by remote side: send Data \
                                                     before Header"
                                                );
                                                break 'outer Err(ProtocolError::Violated);
                                            } else {
                                                // unreliable: the header is presumably lost,
                                                // so skip this frame and read the next one
                                                //TODO: cleanup old messages from time to time
                                                continue;
                                            }
                                        },
                                    };
                                    m.data.extend_from_slice(&data);
                                    if m.data.len() == m.length as usize {
                                        // finished, yay
                                        let m = self
                                            .incoming
                                            .remove(&mid)
                                            .ok_or(ProtocolError::Violated)?;
                                        self.metrics.rmsg_ob(
                                            m.sid,
                                            RemoveReason::Finished,
                                            m.data.len() as u64,
                                        );
                                        break 'outer Ok(ProtocolEvent::Message {
                                            sid: m.sid,
                                            data: m.data.freeze(),
                                        });
                                    }
                                },
                                // management frames are not allowed on data streams
                                _ => break 'outer Err(ProtocolError::Violated),
                            };
                        },
                        Ok(None) => break, //inner => read more data
                        Err(()) => return Err(ProtocolError::Violated),
                    }
                }
            }

            // 4. nothing complete yet: wait for the next chunk on any stream
            self.recv_into_stream().await?;
        }
    }
}
486
487#[async_trait]
488impl<D> ReliableDrain for QuicSendProtocol<D>
489where
490    D: UnreliableDrain<DataFormat = QuicDataFormat>,
491{
492    type CustomErr = D::CustomErr;
493
494    async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>> {
495        self.main_buffer.reserve(500);
496        frame.write_bytes(&mut self.main_buffer);
497        self.drain
498            .send(QuicDataFormat::with_main(&mut self.main_buffer))
499            .await
500    }
501}
502
503#[async_trait]
504impl<S> ReliableSink for QuicRecvProtocol<S>
505where
506    S: UnreliableSink<DataFormat = QuicDataFormat>,
507{
508    type CustomErr = S::CustomErr;
509
510    async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>> {
511        while self.main_buffer.len() < 100 {
512            if self.recv_into_stream().await? == QuicDataFormatStream::Main
513                && let Some(frame) = InitFrame::read_frame(&mut self.main_buffer)
514            {
515                return Ok(frame);
516            }
517        }
518        Err(ProtocolError::Violated)
519    }
520}
521
#[cfg(test)]
mod test_utils {
    //! In-memory stand-in for a QUIC connection, built on bounded async channels.
    use super::*;
    use crate::metrics::{ProtocolMetricCache, ProtocolMetrics};
    use async_channel::*;
    use std::sync::Arc;

    /// Sending end of an emulated QUIC connection.
    pub struct QuicDrain {
        pub sender: Sender<QuicDataFormat>,
        /// Each `Unreliable` chunk is discarded when a random `f32` draw is below this value.
        pub drop_ratio: f32,
    }

    /// Receiving end of an emulated QUIC connection.
    pub struct QuicSink {
        pub receiver: Receiver<QuicDataFormat>,
    }

    /// Builds two connected endpoints. Whatever endpoint 0's send half emits arrives at
    /// endpoint 1's receive half, and vice versa.
    ///
    /// `cap` bounds each direction's channel and `drop_ratio` is given to both drains.
    /// `metrics` defaults to a fresh cache labelled `"quic"`.
    pub fn quic_bound(
        cap: usize,
        drop_ratio: f32,
        metrics: Option<ProtocolMetricCache>,
    ) -> [(QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>); 2] {
        // channel "a" carries endpoint 0 -> endpoint 1, channel "b" the reverse
        let (a_tx, a_rx) = bounded(cap);
        let (b_tx, b_rx) = bounded(cap);
        let m = metrics.unwrap_or_else(|| {
            ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()))
        });
        let first = (
            QuicSendProtocol::new(
                QuicDrain {
                    sender: a_tx,
                    drop_ratio,
                },
                m.clone(),
            ),
            QuicRecvProtocol::new(QuicSink { receiver: b_rx }, m.clone()),
        );
        let second = (
            QuicSendProtocol::new(
                QuicDrain {
                    sender: b_tx,
                    drop_ratio,
                },
                m.clone(),
            ),
            QuicRecvProtocol::new(QuicSink { receiver: a_rx }, m),
        );
        [first, second]
    }

    #[async_trait]
    impl UnreliableDrain for QuicDrain {
        type CustomErr = ();
        type DataFormat = QuicDataFormat;

        async fn send(
            &mut self,
            data: Self::DataFormat,
        ) -> Result<(), ProtocolError<Self::CustomErr>> {
            use rand::RngExt;
            // only unreliable chunks are subject to simulated loss
            let lost = data.stream == QuicDataFormatStream::Unreliable
                && rand::rng().random::<f32>() < self.drop_ratio;
            if lost {
                return Ok(());
            }
            match self.sender.send(data).await {
                Ok(()) => Ok(()),
                Err(_) => Err(ProtocolError::Custom(())),
            }
        }
    }

    #[async_trait]
    impl UnreliableSink for QuicSink {
        type CustomErr = ();
        type DataFormat = QuicDataFormat;

        async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
            match self.receiver.recv().await {
                Ok(chunk) => Ok(chunk),
                Err(_) => Err(ProtocolError::Custom(())),
            }
        }
    }
}
609
610#[cfg(test)]
611mod tests {
612    use crate::{
613        InitProtocol, ProtocolEvent, RecvProtocol, SendProtocol,
614        error::ProtocolError,
615        frame::OTFrame,
616        metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason},
617        quic::{QuicDataFormat, test_utils::*},
618        types::{Pid, Promises, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, Sid},
619    };
620    use bytes::{Bytes, BytesMut};
621    use std::{sync::Arc, time::Duration};
622
623    #[tokio::test]
624    async fn handshake_all_good() {
625        let [mut p1, mut p2] = quic_bound(10, 0.5, None);
626        let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await });
627        let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await });
628        let (r1, r2) = tokio::join!(r1, r2);
629        assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42)));
630        assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337)));
631    }
632
633    #[tokio::test]
634    async fn open_stream() {
635        let [p1, p2] = quic_bound(10, 0.5, None);
636        let (mut s, mut r) = (p1.0, p2.1);
637        let event = ProtocolEvent::OpenStream {
638            sid: Sid::new(10),
639            prio: 0u8,
640            promises: Promises::ORDERED,
641            guaranteed_bandwidth: 1_000_000,
642        };
643        s.send(event.clone()).await.unwrap();
644        let e = r.recv().await.unwrap();
645        assert_eq!(event, e);
646    }
647
648    #[tokio::test]
649    async fn send_short_msg() {
650        let [p1, p2] = quic_bound(10, 0.5, None);
651        let (mut s, mut r) = (p1.0, p2.1);
652        let event = ProtocolEvent::OpenStream {
653            sid: Sid::new(10),
654            prio: 3u8,
655            promises: Promises::ORDERED,
656            guaranteed_bandwidth: 1_000_000,
657        };
658        s.send(event).await.unwrap();
659        let _ = r.recv().await.unwrap();
660        let event = ProtocolEvent::Message {
661            sid: Sid::new(10),
662            data: Bytes::from(&[188u8; 600][..]),
663        };
664        s.send(event.clone()).await.unwrap();
665        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
666        let e = r.recv().await.unwrap();
667        assert_eq!(event, e);
668        // 2nd short message
669        let event = ProtocolEvent::Message {
670            sid: Sid::new(10),
671            data: Bytes::from(&[7u8; 30][..]),
672        };
673        s.send(event.clone()).await.unwrap();
674        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
675        let e = r.recv().await.unwrap();
676        assert_eq!(event, e)
677    }
678
679    #[tokio::test]
680    async fn send_long_msg() {
681        let mut metrics =
682            ProtocolMetricCache::new("long_quic", Arc::new(ProtocolMetrics::new().unwrap()));
683        let sid = Sid::new(1);
684        let [p1, p2] = quic_bound(10000, 0.5, Some(metrics.clone()));
685        let (mut s, mut r) = (p1.0, p2.1);
686        let event = ProtocolEvent::OpenStream {
687            sid,
688            prio: 5u8,
689            promises: Promises::COMPRESSED | Promises::ORDERED,
690            guaranteed_bandwidth: 1_000_000,
691        };
692        s.send(event).await.unwrap();
693        let _ = r.recv().await.unwrap();
694        let event = ProtocolEvent::Message {
695            sid,
696            data: Bytes::from(&[99u8; 500_000][..]),
697        };
698        s.send(event.clone()).await.unwrap();
699        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
700        let e = r.recv().await.unwrap();
701        assert_eq!(event, e);
702        metrics.assert_msg(sid, 1, RemoveReason::Finished);
703        metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished);
704        metrics.assert_data_frames(358);
705        metrics.assert_data_frames_bytes(500_000);
706    }
707
708    #[tokio::test]
709    async fn msg_finishes_after_close() {
710        let sid = Sid::new(1);
711        let [p1, p2] = quic_bound(10000, 0.5, None);
712        let (mut s, mut r) = (p1.0, p2.1);
713        let event = ProtocolEvent::OpenStream {
714            sid,
715            prio: 5u8,
716            promises: Promises::COMPRESSED | Promises::ORDERED,
717            guaranteed_bandwidth: 0,
718        };
719        s.send(event).await.unwrap();
720        let _ = r.recv().await.unwrap();
721        let event = ProtocolEvent::Message {
722            sid,
723            data: Bytes::from(&[99u8; 500_000][..]),
724        };
725        s.send(event).await.unwrap();
726        let event = ProtocolEvent::CloseStream { sid };
727        s.send(event).await.unwrap();
728        //send
729        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
730        let e = r.recv().await.unwrap();
731        assert!(matches!(e, ProtocolEvent::Message { .. }));
732        let e = r.recv().await.unwrap();
733        assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
734    }
735
736    #[tokio::test]
737    async fn msg_finishes_after_shutdown() {
738        let sid = Sid::new(1);
739        let [p1, p2] = quic_bound(10000, 0.5, None);
740        let (mut s, mut r) = (p1.0, p2.1);
741        let event = ProtocolEvent::OpenStream {
742            sid,
743            prio: 5u8,
744            promises: Promises::COMPRESSED | Promises::ORDERED,
745            guaranteed_bandwidth: 0,
746        };
747        s.send(event).await.unwrap();
748        let _ = r.recv().await.unwrap();
749        let event = ProtocolEvent::Message {
750            sid,
751            data: Bytes::from(&[99u8; 500_000][..]),
752        };
753        s.send(event).await.unwrap();
754        let event = ProtocolEvent::Shutdown {};
755        s.send(event).await.unwrap();
756        let event = ProtocolEvent::CloseStream { sid };
757        s.send(event).await.unwrap();
758        //send
759        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
760        let e = r.recv().await.unwrap();
761        assert!(matches!(e, ProtocolEvent::Message { .. }));
762        let e = r.recv().await.unwrap();
763        assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
764        let e = r.recv().await.unwrap();
765        assert!(matches!(e, ProtocolEvent::Shutdown));
766    }
767
768    #[tokio::test]
769    async fn msg_finishes_after_drop() {
770        let sid = Sid::new(1);
771        let [p1, p2] = quic_bound(10000, 0.5, None);
772        let (mut s, mut r) = (p1.0, p2.1);
773        let event = ProtocolEvent::OpenStream {
774            sid,
775            prio: 5u8,
776            promises: Promises::COMPRESSED | Promises::ORDERED,
777            guaranteed_bandwidth: 0,
778        };
779        s.send(event).await.unwrap();
780        let event = ProtocolEvent::Message {
781            sid,
782            data: Bytes::from(&[99u8; 500_000][..]),
783        };
784        s.send(event).await.unwrap();
785        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
786        let event = ProtocolEvent::Message {
787            sid,
788            data: Bytes::from(&[100u8; 500_000][..]),
789        };
790        s.send(event).await.unwrap();
791        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
792        drop(s);
793        let e = r.recv().await.unwrap();
794        assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
795        let e = r.recv().await.unwrap();
796        assert!(matches!(e, ProtocolEvent::Message { .. }));
797        let e = r.recv().await.unwrap();
798        assert!(matches!(e, ProtocolEvent::Message { .. }));
799    }
800
801    #[tokio::test]
802    async fn header_and_data_in_seperate_msg() {
803        let sid = Sid::new(1);
804        let (s, r) = async_channel::bounded(10);
805        let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()));
806        let mut r = super::QuicRecvProtocol::new(QuicSink { receiver: r }, m.clone());
807
808        const DATA1: &[u8; 69] =
809            b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD ";
810        const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. and then keep the connection open";
811        let mut bytes = BytesMut::with_capacity(1500);
812        OTFrame::OpenStream {
813            sid,
814            prio: 5u8,
815            promises: Promises::COMPRESSED | Promises::ORDERED,
816            guaranteed_bandwidth: 1_000_000,
817        }
818        .write_bytes(&mut bytes);
819        s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();
820
821        OTFrame::DataHeader {
822            mid: 99,
823            sid,
824            length: (DATA1.len() + DATA2.len()) as u64,
825        }
826        .write_bytes(&mut bytes);
827        s.send(QuicDataFormat::with_reliable(&mut bytes, sid))
828            .await
829            .unwrap();
830
831        OTFrame::Data {
832            mid: 99,
833            data: Bytes::from(&DATA1[..]),
834        }
835        .write_bytes(&mut bytes);
836        OTFrame::Data {
837            mid: 99,
838            data: Bytes::from(&DATA2[..]),
839        }
840        .write_bytes(&mut bytes);
841        s.send(QuicDataFormat::with_reliable(&mut bytes, sid))
842            .await
843            .unwrap();
844
845        OTFrame::CloseStream { sid }.write_bytes(&mut bytes);
846        s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();
847
848        let e = r.recv().await.unwrap();
849        assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
850        let e = r.recv().await.unwrap();
851        assert!(matches!(e, ProtocolEvent::Message { .. }));
852
853        let e = r.recv().await.unwrap();
854        assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
855    }
856
857    #[tokio::test]
858    async fn drop_sink_while_recv() {
859        let sid = Sid::new(1);
860        let (s, r) = async_channel::bounded(10);
861        let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()));
862        let mut r = super::QuicRecvProtocol::new(QuicSink { receiver: r }, m.clone());
863
864        let mut bytes = BytesMut::with_capacity(1500);
865        OTFrame::OpenStream {
866            sid,
867            prio: 5u8,
868            promises: Promises::COMPRESSED,
869            guaranteed_bandwidth: 1_000_000,
870        }
871        .write_bytes(&mut bytes);
872        s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();
873        let e = r.recv().await.unwrap();
874        assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
875
876        let e = tokio::spawn(async move { r.recv().await });
877        drop(s);
878
879        let e = e.await.unwrap();
880        assert_eq!(e, Err(ProtocolError::Custom(())));
881    }
882
883    #[tokio::test]
884    #[should_panic]
885    async fn send_on_stream_from_remote_without_notify() {
886        //remote opens stream
887        //we send on it
888        let [mut p1, mut p2] = quic_bound(10, 0.5, None);
889        let event = ProtocolEvent::OpenStream {
890            sid: Sid::new(10),
891            prio: 3u8,
892            promises: Promises::ORDERED,
893            guaranteed_bandwidth: 1_000_000,
894        };
895        p1.0.send(event).await.unwrap();
896        let _ = p2.1.recv().await.unwrap();
897        let event = ProtocolEvent::Message {
898            sid: Sid::new(10),
899            data: Bytes::from(&[188u8; 600][..]),
900        };
901        p2.0.send(event.clone()).await.unwrap();
902        p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
903        let e = p1.1.recv().await.unwrap();
904        assert_eq!(event, e);
905    }
906
907    #[tokio::test]
908    async fn send_on_stream_from_remote() {
909        //remote opens stream
910        //we send on it
911        let [mut p1, mut p2] = quic_bound(10, 0.5, None);
912        let event = ProtocolEvent::OpenStream {
913            sid: Sid::new(10),
914            prio: 3u8,
915            promises: Promises::ORDERED,
916            guaranteed_bandwidth: 1_000_000,
917        };
918        p1.0.send(event).await.unwrap();
919        let e = p2.1.recv().await.unwrap();
920        p2.0.notify_from_recv(e);
921        let event = ProtocolEvent::Message {
922            sid: Sid::new(10),
923            data: Bytes::from(&[188u8; 600][..]),
924        };
925        p2.0.send(event.clone()).await.unwrap();
926        p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
927        let e = p1.1.recv().await.unwrap();
928        assert_eq!(event, e);
929    }
930
931    #[tokio::test]
932    async fn unrealiable_test() {
933        const MIN_CHECK: usize = 10;
934        const COUNT: usize = 10_000;
935        //We send COUNT msg with 50% of be send each. we check that >= MIN_CHECK && !=
936        // COUNT reach their target
937
938        let [mut p1, mut p2] = quic_bound(
939            COUNT * 2 - 1, /* 2 times as it is HEADER + DATA but -1 as we want to see not all
940                            * succeed */
941            0.5,
942            None,
943        );
944        let event = ProtocolEvent::OpenStream {
945            sid: Sid::new(1337),
946            prio: 3u8,
947            promises: Promises::empty(), /* on purpose! */
948            guaranteed_bandwidth: 1_000_000,
949        };
950        p1.0.send(event).await.unwrap();
951        let e = p2.1.recv().await.unwrap();
952        p2.0.notify_from_recv(e);
953        let event = ProtocolEvent::Message {
954            sid: Sid::new(1337),
955            data: Bytes::from(&[188u8; 600][..]),
956        };
957        for _ in 0..COUNT {
958            p2.0.send(event.clone()).await.unwrap();
959        }
960        p2.0.flush(1_000_000_000, Duration::from_secs(1))
961            .await
962            .unwrap();
963        for _ in 0..COUNT {
964            p2.0.send(event.clone()).await.unwrap();
965        }
966        for _ in 0..MIN_CHECK {
967            let e = p1.1.recv().await.unwrap();
968            assert_eq!(event, e);
969        }
970    }
971}