veloren_network_protocol/
quic.rs

1use crate::{
2    RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink,
3    error::ProtocolError,
4    event::ProtocolEvent,
5    frame::{ITFrame, InitFrame, OTFrame},
6    handshake::{ReliableDrain, ReliableSink},
7    message::{ALLOC_BLOCK, ITMessage},
8    metrics::{ProtocolMetricCache, RemoveReason},
9    prio::PrioManager,
10    types::{Bandwidth, Mid, Promises, Sid},
11    util::SortedVec,
12};
13use async_trait::async_trait;
14use bytes::BytesMut;
15use hashbrown::HashMap;
16use std::time::{Duration, Instant};
17use tracing::info;
18#[cfg(feature = "trace_pedantic")]
19use tracing::trace;
20
21#[derive(PartialEq, Eq)]
22pub enum QuicDataFormatStream {
23    Main,
24    Reliable(Sid),
25    Unreliable,
26}
27
28pub struct QuicDataFormat {
29    pub stream: QuicDataFormatStream,
30    pub data: BytesMut,
31}
32
33impl QuicDataFormat {
34    fn with_main(buffer: &mut BytesMut) -> Self {
35        Self {
36            stream: QuicDataFormatStream::Main,
37            data: buffer.split(),
38        }
39    }
40
41    fn with_reliable(buffer: &mut BytesMut, sid: Sid) -> Self {
42        Self {
43            stream: QuicDataFormatStream::Reliable(sid),
44            data: buffer.split(),
45        }
46    }
47
48    fn with_unreliable(frame: OTFrame) -> Self {
49        let mut buffer = BytesMut::new();
50        frame.write_bytes(&mut buffer);
51        Self {
52            stream: QuicDataFormatStream::Unreliable,
53            data: buffer,
54        }
55    }
56}
57
/// QUIC implementation of [`SendProtocol`]
///
/// Queues outgoing messages in a [`PrioManager`] and, on flush, serializes
/// the grabbed frames into per-stream chunks that are handed to the drain.
///
/// [`SendProtocol`]: crate::SendProtocol
#[derive(Debug)]
pub struct QuicSendProtocol<D>
where
    D: UnreliableDrain<DataFormat = QuicDataFormat>,
{
    // Scratch buffer in which init and control frames are serialized before
    // being split off into a main-stream chunk.
    main_buffer: BytesMut,
    // One serialization buffer per open reliable stream, keyed by stream id.
    reliable_buffers: SortedVec<Sid, BytesMut>,
    // Holds the open streams and their queued messages.
    store: PrioManager,
    // Message id handed to the next queued message.
    next_mid: Mid,
    // Locally requested closes that are held back until the store lets the
    // stream close; the CloseStream frame is sent from `flush`.
    closing_streams: Vec<Sid>,
    // Closes requested by the remote side that are held back the same way;
    // no frame is sent for these.
    notify_closing_streams: Vec<Sid>,
    // Set when a shutdown was requested while the store was not yet empty.
    pending_shutdown: bool,
    // Destination for every chunk this protocol produces.
    drain: D,
    // Initialized to the construction time; never read in the visible code,
    // hence the lint expectation. NOTE(review): purpose unclear — confirm
    // whether it can be removed.
    #[expect(dead_code)]
    last: Instant,
    metrics: ProtocolMetricCache,
}
78
/// QUIC implementation of [`RecvProtocol`]
///
/// Collects the chunks delivered by the sink into per-stream buffers and
/// parses them into [`ProtocolEvent`]s.
///
/// [`RecvProtocol`]: crate::RecvProtocol
#[derive(Debug)]
pub struct QuicRecvProtocol<S>
where
    S: UnreliableSink<DataFormat = QuicDataFormat>,
{
    // Bytes received on the main stream that have not been parsed yet.
    main_buffer: BytesMut,
    // Bytes received on the unreliable stream that have not been parsed yet.
    unreliable_buffer: BytesMut,
    // Unparsed bytes of each known reliable stream, keyed by stream id.
    reliable_buffers: SortedVec<Sid, BytesMut>,
    // Chunks that arrived on a reliable stream before a buffer was registered
    // for it; `recv` moves them into `reliable_buffers` once their first
    // frame is a readable DataHeader.
    pending_reliable_buffers: Vec<(Sid, BytesMut)>,
    // Backing allocation (created with capacity `ALLOC_BLOCK`) handed to
    // `ITMessage::new` for every incoming message.
    itmsg_allocator: BytesMut,
    // Messages whose header was received but whose data is still incomplete.
    incoming: HashMap<Mid, ITMessage>,
    // Source of incoming chunks.
    sink: S,
    metrics: ProtocolMetricCache,
}
96
97fn is_reliable(p: &Promises) -> bool {
98    p.contains(Promises::ORDERED)
99        || p.contains(Promises::CONSISTENCY)
100        || p.contains(Promises::GUARANTEED_DELIVERY)
101}
102
103impl<D> QuicSendProtocol<D>
104where
105    D: UnreliableDrain<DataFormat = QuicDataFormat>,
106{
107    pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self {
108        Self {
109            main_buffer: BytesMut::new(),
110            reliable_buffers: SortedVec::default(),
111            store: PrioManager::new(metrics.clone()),
112            next_mid: 0u64,
113            closing_streams: vec![],
114            notify_closing_streams: vec![],
115            pending_shutdown: false,
116            drain,
117            last: Instant::now(),
118            metrics,
119        }
120    }
121
122    /// returns all promises that this Protocol can take care of
123    /// If you open a Stream anyway, unsupported promises are ignored.
124    pub fn supported_promises() -> Promises {
125        Promises::ORDERED
126            | Promises::CONSISTENCY
127            | Promises::GUARANTEED_DELIVERY
128            | Promises::COMPRESSED
129            | Promises::ENCRYPTED
130    }
131}
132
133impl<S> QuicRecvProtocol<S>
134where
135    S: UnreliableSink<DataFormat = QuicDataFormat>,
136{
137    pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self {
138        Self {
139            main_buffer: BytesMut::new(),
140            unreliable_buffer: BytesMut::new(),
141            reliable_buffers: SortedVec::default(),
142            pending_reliable_buffers: vec![],
143            itmsg_allocator: BytesMut::with_capacity(ALLOC_BLOCK),
144            incoming: HashMap::new(),
145            sink,
146            metrics,
147        }
148    }
149
150    async fn recv_into_stream(
151        &mut self,
152    ) -> Result<QuicDataFormatStream, ProtocolError<S::CustomErr>> {
153        let chunk = self.sink.recv().await?;
154        let buffer = match chunk.stream {
155            QuicDataFormatStream::Main => &mut self.main_buffer,
156            QuicDataFormatStream::Unreliable => &mut self.unreliable_buffer,
157            QuicDataFormatStream::Reliable(id) => {
158                match self.reliable_buffers.get_mut(&id) {
159                    Some(buffer) => buffer,
160                    None => {
161                        self.pending_reliable_buffers.push((id, BytesMut::new()));
162                        //Violated but will never happen
163                        &mut self
164                            .pending_reliable_buffers
165                            .last_mut()
166                            .ok_or(ProtocolError::Violated)?
167                            .1
168                    },
169                }
170            },
171        };
172        if buffer.is_empty() {
173            *buffer = chunk.data
174        } else {
175            buffer.extend_from_slice(&chunk.data)
176        }
177        Ok(chunk.stream)
178    }
179}
180
181#[async_trait]
182impl<D> SendProtocol for QuicSendProtocol<D>
183where
184    D: UnreliableDrain<DataFormat = QuicDataFormat>,
185{
186    type CustomErr = D::CustomErr;
187
188    fn notify_from_recv(&mut self, event: ProtocolEvent) {
189        match event {
190            ProtocolEvent::OpenStream {
191                sid,
192                prio,
193                promises,
194                guaranteed_bandwidth,
195            } => {
196                self.store
197                    .open_stream(sid, prio, promises, guaranteed_bandwidth);
198                if is_reliable(&promises) {
199                    self.reliable_buffers.insert(sid, BytesMut::new());
200                }
201            },
202            ProtocolEvent::CloseStream { sid } => {
203                if !self.store.try_close_stream(sid) {
204                    #[cfg(feature = "trace_pedantic")]
205                    trace!(?sid, "hold back notify close stream");
206                    self.notify_closing_streams.push(sid);
207                }
208            },
209            _ => {},
210        }
211    }
212
213    async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
214        #[cfg(feature = "trace_pedantic")]
215        trace!(?event, "send");
216        match event {
217            ProtocolEvent::OpenStream {
218                sid,
219                prio,
220                promises,
221                guaranteed_bandwidth,
222            } => {
223                self.store
224                    .open_stream(sid, prio, promises, guaranteed_bandwidth);
225                if is_reliable(&promises) {
226                    self.reliable_buffers.insert(sid, BytesMut::new());
227                    //Send a empty message to notify local drain of stream
228                    self.drain
229                        .send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid))
230                        .await?;
231                }
232                event.to_frame().write_bytes(&mut self.main_buffer);
233                self.drain
234                    .send(QuicDataFormat::with_main(&mut self.main_buffer))
235                    .await?;
236            },
237            ProtocolEvent::CloseStream { sid } => {
238                if self.store.try_close_stream(sid) {
239                    let _ = self.reliable_buffers.delete(&sid); //delete if it was reliable
240                    event.to_frame().write_bytes(&mut self.main_buffer);
241                    self.drain
242                        .send(QuicDataFormat::with_main(&mut self.main_buffer))
243                        .await?;
244                } else {
245                    #[cfg(feature = "trace_pedantic")]
246                    trace!(?sid, "hold back close stream");
247                    self.closing_streams.push(sid);
248                }
249            },
250            ProtocolEvent::Shutdown => {
251                if self.store.is_empty() {
252                    event.to_frame().write_bytes(&mut self.main_buffer);
253                    self.drain
254                        .send(QuicDataFormat::with_main(&mut self.main_buffer))
255                        .await?;
256                } else {
257                    #[cfg(feature = "trace_pedantic")]
258                    trace!("hold back shutdown");
259                    self.pending_shutdown = true;
260                }
261            },
262            ProtocolEvent::Message { data, sid } => {
263                self.metrics.smsg_ib(sid, data.len() as u64);
264                self.store.add(data, self.next_mid, sid);
265                self.next_mid += 1;
266            },
267        }
268        Ok(())
269    }
270
271    async fn flush(
272        &mut self,
273        bandwidth: Bandwidth,
274        dt: Duration,
275    ) -> Result</* actual */ Bandwidth, ProtocolError<Self::CustomErr>> {
276        let (frames, _) = self.store.grab(bandwidth, dt);
277        //Todo: optimize reserve
278        let mut data_frames = 0;
279        let mut data_bandwidth = 0;
280        for (sid, frame) in frames {
281            if let OTFrame::Data { mid: _, data } = &frame {
282                data_bandwidth += data.len();
283                data_frames += 1;
284            }
285            match self.reliable_buffers.get_mut(&sid) {
286                Some(buffer) => frame.write_bytes(buffer),
287                None => {
288                    self.drain
289                        .send(QuicDataFormat::with_unreliable(frame))
290                        .await?
291                },
292            }
293        }
294        for (sid, buffer) in self.reliable_buffers.data.iter_mut() {
295            if !buffer.is_empty() {
296                self.drain
297                    .send(QuicDataFormat::with_reliable(buffer, *sid))
298                    .await?;
299            }
300        }
301        self.metrics
302            .sdata_frames_b(data_frames, data_bandwidth as u64);
303
304        let mut finished_streams = vec![];
305        for (i, &sid) in self.closing_streams.iter().enumerate() {
306            if self.store.try_close_stream(sid) {
307                #[cfg(feature = "trace_pedantic")]
308                trace!(?sid, "close stream, as it's now empty");
309                OTFrame::CloseStream { sid }.write_bytes(&mut self.main_buffer);
310                self.drain
311                    .send(QuicDataFormat::with_main(&mut self.main_buffer))
312                    .await?;
313                finished_streams.push(i);
314            }
315        }
316        for i in finished_streams.iter().rev() {
317            self.closing_streams.remove(*i);
318        }
319
320        let mut finished_streams = vec![];
321        for (i, sid) in self.notify_closing_streams.iter().enumerate() {
322            if self.store.try_close_stream(*sid) {
323                #[cfg(feature = "trace_pedantic")]
324                trace!(?sid, "close stream, as it's now empty");
325                finished_streams.push(i);
326            }
327        }
328        for i in finished_streams.iter().rev() {
329            self.notify_closing_streams.remove(*i);
330        }
331
332        if self.pending_shutdown && self.store.is_empty() {
333            #[cfg(feature = "trace_pedantic")]
334            trace!("shutdown, as it's now empty");
335            OTFrame::Shutdown {}.write_bytes(&mut self.main_buffer);
336            self.drain
337                .send(QuicDataFormat::with_main(&mut self.main_buffer))
338                .await?;
339            self.pending_shutdown = false;
340        }
341        Ok(data_bandwidth as u64)
342    }
343}
344
345#[async_trait]
346impl<S> RecvProtocol for QuicRecvProtocol<S>
347where
348    S: UnreliableSink<DataFormat = QuicDataFormat>,
349{
350    type CustomErr = S::CustomErr;
351
352    async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
353        'outer: loop {
354            match ITFrame::read_frame(&mut self.main_buffer) {
355                Ok(Some(frame)) => {
356                    #[cfg(feature = "trace_pedantic")]
357                    trace!(?frame, "recv");
358                    match frame {
359                        ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown),
360                        ITFrame::OpenStream {
361                            sid,
362                            prio,
363                            promises,
364                            guaranteed_bandwidth,
365                        } => {
366                            if is_reliable(&promises) {
367                                self.reliable_buffers.insert(sid, BytesMut::new());
368                            }
369                            break 'outer Ok(ProtocolEvent::OpenStream {
370                                sid,
371                                prio: prio.min(crate::types::HIGHEST_PRIO),
372                                promises,
373                                guaranteed_bandwidth,
374                            });
375                        },
376                        ITFrame::CloseStream { sid } => {
377                            //FIXME: defer close!
378                            //let _ = self.reliable_buffers.delete(sid); // if it was reliable
379                            break 'outer Ok(ProtocolEvent::CloseStream { sid });
380                        },
381                        _ => break 'outer Err(ProtocolError::Violated),
382                    };
383                },
384                Ok(None) => {},
385                Err(()) => return Err(ProtocolError::Violated),
386            }
387
388            // try to order pending
389            let mut pending_violated = false;
390            let mut reliable = vec![];
391
392            self.pending_reliable_buffers.retain(|(_, buffer)| {
393                // try to get Sid without touching buffer
394                let mut testbuffer = buffer.clone();
395                match ITFrame::read_frame(&mut testbuffer) {
396                    Ok(Some(ITFrame::DataHeader {
397                        sid,
398                        mid: _,
399                        length: _,
400                    })) => {
401                        reliable.push((sid, buffer.clone()));
402                        false
403                    },
404                    Ok(Some(_)) | Err(_) => {
405                        pending_violated = true;
406                        false
407                    },
408                    Ok(None) => true,
409                }
410            });
411
412            if pending_violated {
413                break 'outer Err(ProtocolError::Violated);
414            }
415            for (sid, buffer) in reliable.into_iter() {
416                self.reliable_buffers.insert(sid, buffer)
417            }
418
419            let mut iter = self
420                .reliable_buffers
421                .data
422                .iter_mut()
423                .map(|(_, b)| (b, true))
424                .collect::<Vec<_>>();
425            iter.push((&mut self.unreliable_buffer, false));
426
427            for (buffer, reliable) in iter {
428                loop {
429                    match ITFrame::read_frame(buffer) {
430                        Ok(Some(frame)) => {
431                            #[cfg(feature = "trace_pedantic")]
432                            trace!(?frame, "recv");
433                            match frame {
434                                ITFrame::DataHeader { sid, mid, length } => {
435                                    let m = ITMessage::new(sid, length, &mut self.itmsg_allocator);
436                                    self.metrics.rmsg_ib(sid, length);
437                                    self.incoming.insert(mid, m);
438                                },
439                                ITFrame::Data { mid, data } => {
440                                    self.metrics.rdata_frames_b(data.len() as u64);
441                                    let m = match self.incoming.get_mut(&mid) {
442                                        Some(m) => m,
443                                        None => {
444                                            if reliable {
445                                                info!(
446                                                    ?mid,
447                                                    "protocol violation by remote side: send Data \
448                                                     before Header"
449                                                );
450                                                break 'outer Err(ProtocolError::Violated);
451                                            } else {
452                                                //TODO: cleanup old messages from time to time
453                                                continue;
454                                            }
455                                        },
456                                    };
457                                    m.data.extend_from_slice(&data);
458                                    if m.data.len() == m.length as usize {
459                                        // finished, yay
460                                        let m = self
461                                            .incoming
462                                            .remove(&mid)
463                                            .ok_or(ProtocolError::Violated)?;
464                                        self.metrics.rmsg_ob(
465                                            m.sid,
466                                            RemoveReason::Finished,
467                                            m.data.len() as u64,
468                                        );
469                                        break 'outer Ok(ProtocolEvent::Message {
470                                            sid: m.sid,
471                                            data: m.data.freeze(),
472                                        });
473                                    }
474                                },
475                                _ => break 'outer Err(ProtocolError::Violated),
476                            };
477                        },
478                        Ok(None) => break, //inner => read more data
479                        Err(()) => return Err(ProtocolError::Violated),
480                    }
481                }
482            }
483
484            self.recv_into_stream().await?;
485        }
486    }
487}
488
489#[async_trait]
490impl<D> ReliableDrain for QuicSendProtocol<D>
491where
492    D: UnreliableDrain<DataFormat = QuicDataFormat>,
493{
494    type CustomErr = D::CustomErr;
495
496    async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>> {
497        self.main_buffer.reserve(500);
498        frame.write_bytes(&mut self.main_buffer);
499        self.drain
500            .send(QuicDataFormat::with_main(&mut self.main_buffer))
501            .await
502    }
503}
504
505#[async_trait]
506impl<S> ReliableSink for QuicRecvProtocol<S>
507where
508    S: UnreliableSink<DataFormat = QuicDataFormat>,
509{
510    type CustomErr = S::CustomErr;
511
512    async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>> {
513        while self.main_buffer.len() < 100 {
514            if self.recv_into_stream().await? == QuicDataFormatStream::Main {
515                if let Some(frame) = InitFrame::read_frame(&mut self.main_buffer) {
516                    return Ok(frame);
517                }
518            }
519        }
520        Err(ProtocolError::Violated)
521    }
522}
523
#[cfg(test)]
mod test_utils {
    //Quic protocol based on Channel
    use super::*;
    use crate::metrics::{ProtocolMetricCache, ProtocolMetrics};
    use async_channel::*;
    use std::sync::Arc;

    /// Channel backed drain that can randomly drop unreliable chunks.
    pub struct QuicDrain {
        pub sender: Sender<QuicDataFormat>,
        /// Probability with which a chunk on the unreliable stream is dropped.
        pub drop_ratio: f32,
    }

    /// Channel backed sink.
    pub struct QuicSink {
        pub receiver: Receiver<QuicDataFormat>,
    }

    /// emulate Quic protocol on Channels
    ///
    /// Returns two connected endpoints: whatever the sender of one endpoint
    /// emits is delivered to the receiver of the other. Both channels are
    /// bounded by `cap`. Without `metrics` a cache named "quic" is created.
    pub fn quic_bound(
        cap: usize,
        drop_ratio: f32,
        metrics: Option<ProtocolMetricCache>,
    ) -> [(QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>); 2] {
        let (first_tx, second_rx) = bounded(cap);
        let (second_tx, first_rx) = bounded(cap);
        let cache = match metrics {
            Some(cache) => cache,
            None => ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())),
        };

        let first_drain = QuicDrain {
            sender: first_tx,
            drop_ratio,
        };
        let first_sink = QuicSink { receiver: first_rx };
        let second_drain = QuicDrain {
            sender: second_tx,
            drop_ratio,
        };
        let second_sink = QuicSink {
            receiver: second_rx,
        };

        [
            (
                QuicSendProtocol::new(first_drain, cache.clone()),
                QuicRecvProtocol::new(first_sink, cache.clone()),
            ),
            (
                QuicSendProtocol::new(second_drain, cache.clone()),
                QuicRecvProtocol::new(second_sink, cache),
            ),
        ]
    }

    #[async_trait]
    impl UnreliableDrain for QuicDrain {
        type CustomErr = ();
        type DataFormat = QuicDataFormat;

        /// Forwards `data` into the channel. Chunks on the unreliable stream
        /// are silently discarded with probability `drop_ratio`; a closed
        /// channel is reported as `ProtocolError::Custom(())`.
        async fn send(
            &mut self,
            data: Self::DataFormat,
        ) -> Result<(), ProtocolError<Self::CustomErr>> {
            use rand::Rng;
            // The random draw only happens for unreliable chunks.
            let droppable = data.stream == QuicDataFormatStream::Unreliable;
            if droppable && rand::thread_rng().gen::<f32>() < self.drop_ratio {
                return Ok(());
            }
            match self.sender.send(data).await {
                Ok(()) => Ok(()),
                Err(_) => Err(ProtocolError::Custom(())),
            }
        }
    }

    #[async_trait]
    impl UnreliableSink for QuicSink {
        type CustomErr = ();
        type DataFormat = QuicDataFormat;

        /// Waits for the next chunk from the channel; a closed channel is
        /// reported as `ProtocolError::Custom(())`.
        async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
            match self.receiver.recv().await {
                Ok(chunk) => Ok(chunk),
                Err(_) => Err(ProtocolError::Custom(())),
            }
        }
    }
}
611
612#[cfg(test)]
613mod tests {
614    use crate::{
615        InitProtocol, ProtocolEvent, RecvProtocol, SendProtocol,
616        error::ProtocolError,
617        frame::OTFrame,
618        metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason},
619        quic::{QuicDataFormat, test_utils::*},
620        types::{Pid, Promises, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, Sid},
621    };
622    use bytes::{Bytes, BytesMut};
623    use std::{sync::Arc, time::Duration};
624
    /// Runs the initialization of both endpoints concurrently and checks that
    /// each side ends up with the pid and secret the other side passed in,
    /// together with its stream id offset.
    #[tokio::test]
    async fn handshake_all_good() {
        let [mut p1, mut p2] = quic_bound(10, 0.5, None);
        // Both sides have to run at the same time, each waits for the other.
        let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await });
        let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await });
        let (r1, r2) = tokio::join!(r1, r2);
        assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42)));
        assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337)));
    }
634
    /// An `OpenStream` event sent on one endpoint must arrive unchanged on
    /// the other endpoint.
    #[tokio::test]
    async fn open_stream() {
        let [p1, p2] = quic_bound(10, 0.5, None);
        // Sender of the first endpoint, receiver of the second one.
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(10),
            prio: 0u8,
            promises: Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        s.send(event.clone()).await.unwrap();
        let e = r.recv().await.unwrap();
        assert_eq!(event, e);
    }
649
    /// Two short messages sent one after the other on an ordered stream must
    /// each arrive unchanged after a flush.
    #[tokio::test]
    async fn send_short_msg() {
        let [p1, p2] = quic_bound(10, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(10),
            prio: 3u8,
            promises: Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        s.send(event).await.unwrap();
        let _ = r.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid: Sid::new(10),
            data: Bytes::from(&[188u8; 600][..]),
        };
        // Messages are only queued by `send`; `flush` puts them on the wire.
        s.send(event.clone()).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert_eq!(event, e);
        // 2nd short message
        let event = ProtocolEvent::Message {
            sid: Sid::new(10),
            data: Bytes::from(&[7u8; 30][..]),
        };
        s.send(event.clone()).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert_eq!(event, e)
    }
680
    /// A 500_000 byte message must arrive unchanged after a single flush, and
    /// the shared metric cache must record one finished message, its byte
    /// count and the expected number of data frames.
    #[tokio::test]
    async fn send_long_msg() {
        let mut metrics =
            ProtocolMetricCache::new("long_quic", Arc::new(ProtocolMetrics::new().unwrap()));
        let sid = Sid::new(1);
        let [p1, p2] = quic_bound(10000, 0.5, Some(metrics.clone()));
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        s.send(event).await.unwrap();
        let _ = r.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[99u8; 500_000][..]),
        };
        s.send(event.clone()).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert_eq!(event, e);
        metrics.assert_msg(sid, 1, RemoveReason::Finished);
        metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished);
        metrics.assert_data_frames(358);
        metrics.assert_data_frames_bytes(500_000);
    }
709
    /// A close requested while a message is still queued must be held back:
    /// the receiver sees the message first and the `CloseStream` afterwards.
    #[tokio::test]
    async fn msg_finishes_after_close() {
        let sid = Sid::new(1);
        let [p1, p2] = quic_bound(10000, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 0,
        };
        s.send(event).await.unwrap();
        let _ = r.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[99u8; 500_000][..]),
        };
        s.send(event).await.unwrap();
        // The message has not been flushed yet when the close is requested.
        let event = ProtocolEvent::CloseStream { sid };
        s.send(event).await.unwrap();
        //send
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
    }
737
    /// Shutdown and close requested while a message is still queued must both
    /// be held back: the receiver sees the message, then the `CloseStream`
    /// and the `Shutdown` last, although shutdown was requested before close.
    #[tokio::test]
    async fn msg_finishes_after_shutdown() {
        let sid = Sid::new(1);
        let [p1, p2] = quic_bound(10000, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 0,
        };
        s.send(event).await.unwrap();
        let _ = r.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[99u8; 500_000][..]),
        };
        s.send(event).await.unwrap();
        let event = ProtocolEvent::Shutdown {};
        s.send(event).await.unwrap();
        let event = ProtocolEvent::CloseStream { sid };
        s.send(event).await.unwrap();
        //send
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Shutdown));
    }
769
    /// Everything that was flushed before the sender got dropped must still be
    /// readable by the receiver: the `OpenStream` and both messages.
    #[tokio::test]
    async fn msg_finishes_after_drop() {
        let sid = Sid::new(1);
        let [p1, p2] = quic_bound(10000, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 0,
        };
        s.send(event).await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[99u8; 500_000][..]),
        };
        s.send(event).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[100u8; 500_000][..]),
        };
        s.send(event).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        // The receiver has not read anything yet when the sender goes away.
        drop(s);
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));
    }
802
    /// Feeds hand-built chunks straight into the channel: the data header in
    /// one reliable chunk and both data frames in a second one, framed by an
    /// `OpenStream` and a `CloseStream` on the main stream. The receiver must
    /// yield the open, the reassembled message and the close.
    #[tokio::test]
    async fn header_and_data_in_seperate_msg() {
        let sid = Sid::new(1);
        let (s, r) = async_channel::bounded(10);
        let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()));
        let mut r = super::QuicRecvProtocol::new(QuicSink { receiver: r }, m.clone());

        const DATA1: &[u8; 69] =
            b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD ";
        const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. and then keep the connection open";
        // Scratch buffer; every `with_main` / `with_reliable` call empties it.
        let mut bytes = BytesMut::with_capacity(1500);
        OTFrame::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        }
        .write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();

        // The header announces the combined length of both data frames.
        OTFrame::DataHeader {
            mid: 99,
            sid,
            length: (DATA1.len() + DATA2.len()) as u64,
        }
        .write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_reliable(&mut bytes, sid))
            .await
            .unwrap();

        OTFrame::Data {
            mid: 99,
            data: Bytes::from(&DATA1[..]),
        }
        .write_bytes(&mut bytes);
        OTFrame::Data {
            mid: 99,
            data: Bytes::from(&DATA2[..]),
        }
        .write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_reliable(&mut bytes, sid))
            .await
            .unwrap();

        OTFrame::CloseStream { sid }.write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();

        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));

        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
    }
858
859    #[tokio::test]
860    async fn drop_sink_while_recv() {
861        let sid = Sid::new(1);
862        let (s, r) = async_channel::bounded(10);
863        let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()));
864        let mut r = super::QuicRecvProtocol::new(QuicSink { receiver: r }, m.clone());
865
866        let mut bytes = BytesMut::with_capacity(1500);
867        OTFrame::OpenStream {
868            sid,
869            prio: 5u8,
870            promises: Promises::COMPRESSED,
871            guaranteed_bandwidth: 1_000_000,
872        }
873        .write_bytes(&mut bytes);
874        s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();
875        let e = r.recv().await.unwrap();
876        assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
877
878        let e = tokio::spawn(async move { r.recv().await });
879        drop(s);
880
881        let e = e.await.unwrap();
882        assert_eq!(e, Err(ProtocolError::Custom(())));
883    }
884
885    #[tokio::test]
886    #[should_panic]
887    async fn send_on_stream_from_remote_without_notify() {
888        //remote opens stream
889        //we send on it
890        let [mut p1, mut p2] = quic_bound(10, 0.5, None);
891        let event = ProtocolEvent::OpenStream {
892            sid: Sid::new(10),
893            prio: 3u8,
894            promises: Promises::ORDERED,
895            guaranteed_bandwidth: 1_000_000,
896        };
897        p1.0.send(event).await.unwrap();
898        let _ = p2.1.recv().await.unwrap();
899        let event = ProtocolEvent::Message {
900            sid: Sid::new(10),
901            data: Bytes::from(&[188u8; 600][..]),
902        };
903        p2.0.send(event.clone()).await.unwrap();
904        p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
905        let e = p1.1.recv().await.unwrap();
906        assert_eq!(event, e);
907    }
908
909    #[tokio::test]
910    async fn send_on_stream_from_remote() {
911        //remote opens stream
912        //we send on it
913        let [mut p1, mut p2] = quic_bound(10, 0.5, None);
914        let event = ProtocolEvent::OpenStream {
915            sid: Sid::new(10),
916            prio: 3u8,
917            promises: Promises::ORDERED,
918            guaranteed_bandwidth: 1_000_000,
919        };
920        p1.0.send(event).await.unwrap();
921        let e = p2.1.recv().await.unwrap();
922        p2.0.notify_from_recv(e);
923        let event = ProtocolEvent::Message {
924            sid: Sid::new(10),
925            data: Bytes::from(&[188u8; 600][..]),
926        };
927        p2.0.send(event.clone()).await.unwrap();
928        p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
929        let e = p1.1.recv().await.unwrap();
930        assert_eq!(event, e);
931    }
932
933    #[tokio::test]
934    async fn unrealiable_test() {
935        const MIN_CHECK: usize = 10;
936        const COUNT: usize = 10_000;
937        //We send COUNT msg with 50% of be send each. we check that >= MIN_CHECK && !=
938        // COUNT reach their target
939
940        let [mut p1, mut p2] = quic_bound(
941            COUNT * 2 - 1, /* 2 times as it is HEADER + DATA but -1 as we want to see not all
942                            * succeed */
943            0.5,
944            None,
945        );
946        let event = ProtocolEvent::OpenStream {
947            sid: Sid::new(1337),
948            prio: 3u8,
949            promises: Promises::empty(), /* on purpose! */
950            guaranteed_bandwidth: 1_000_000,
951        };
952        p1.0.send(event).await.unwrap();
953        let e = p2.1.recv().await.unwrap();
954        p2.0.notify_from_recv(e);
955        let event = ProtocolEvent::Message {
956            sid: Sid::new(1337),
957            data: Bytes::from(&[188u8; 600][..]),
958        };
959        for _ in 0..COUNT {
960            p2.0.send(event.clone()).await.unwrap();
961        }
962        p2.0.flush(1_000_000_000, Duration::from_secs(1))
963            .await
964            .unwrap();
965        for _ in 0..COUNT {
966            p2.0.send(event.clone()).await.unwrap();
967        }
968        for _ in 0..MIN_CHECK {
969            let e = p1.1.recv().await.unwrap();
970            assert_eq!(event, e);
971        }
972    }
973}