1use crate::{
2 RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink,
3 error::ProtocolError,
4 event::ProtocolEvent,
5 frame::{ITFrame, InitFrame, OTFrame},
6 handshake::{ReliableDrain, ReliableSink},
7 message::{ALLOC_BLOCK, ITMessage},
8 metrics::{ProtocolMetricCache, RemoveReason},
9 prio::PrioManager,
10 types::{Bandwidth, Mid, Promises, Sid},
11 util::SortedVec,
12};
13use async_trait::async_trait;
14use bytes::BytesMut;
15use hashbrown::HashMap;
16use std::time::{Duration, Instant};
17use tracing::info;
18#[cfg(feature = "trace_pedantic")]
19use tracing::trace;
20
21#[derive(PartialEq, Eq)]
22pub enum QuicDataFormatStream {
23 Main,
24 Reliable(Sid),
25 Unreliable,
26}
27
28pub struct QuicDataFormat {
29 pub stream: QuicDataFormatStream,
30 pub data: BytesMut,
31}
32
33impl QuicDataFormat {
34 fn with_main(buffer: &mut BytesMut) -> Self {
35 Self {
36 stream: QuicDataFormatStream::Main,
37 data: buffer.split(),
38 }
39 }
40
41 fn with_reliable(buffer: &mut BytesMut, sid: Sid) -> Self {
42 Self {
43 stream: QuicDataFormatStream::Reliable(sid),
44 data: buffer.split(),
45 }
46 }
47
48 fn with_unreliable(frame: OTFrame) -> Self {
49 let mut buffer = BytesMut::new();
50 frame.write_bytes(&mut buffer);
51 Self {
52 stream: QuicDataFormatStream::Unreliable,
53 data: buffer,
54 }
55 }
56}
57
58#[derive(Debug)]
62pub struct QuicSendProtocol<D>
63where
64 D: UnreliableDrain<DataFormat = QuicDataFormat>,
65{
66 main_buffer: BytesMut,
67 reliable_buffers: SortedVec<Sid, BytesMut>,
68 store: PrioManager,
69 next_mid: Mid,
70 closing_streams: Vec<Sid>,
71 notify_closing_streams: Vec<Sid>,
72 pending_shutdown: bool,
73 drain: D,
74 #[expect(dead_code)]
75 last: Instant,
76 metrics: ProtocolMetricCache,
77}
78
79#[derive(Debug)]
83pub struct QuicRecvProtocol<S>
84where
85 S: UnreliableSink<DataFormat = QuicDataFormat>,
86{
87 main_buffer: BytesMut,
88 unreliable_buffer: BytesMut,
89 reliable_buffers: SortedVec<Sid, BytesMut>,
90 pending_reliable_buffers: Vec<(Sid, BytesMut)>,
91 itmsg_allocator: BytesMut,
92 incoming: HashMap<Mid, ITMessage>,
93 sink: S,
94 metrics: ProtocolMetricCache,
95}
96
97fn is_reliable(p: &Promises) -> bool {
98 p.contains(Promises::ORDERED)
99 || p.contains(Promises::CONSISTENCY)
100 || p.contains(Promises::GUARANTEED_DELIVERY)
101}
102
103impl<D> QuicSendProtocol<D>
104where
105 D: UnreliableDrain<DataFormat = QuicDataFormat>,
106{
107 pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self {
108 Self {
109 main_buffer: BytesMut::new(),
110 reliable_buffers: SortedVec::default(),
111 store: PrioManager::new(metrics.clone()),
112 next_mid: 0u64,
113 closing_streams: vec![],
114 notify_closing_streams: vec![],
115 pending_shutdown: false,
116 drain,
117 last: Instant::now(),
118 metrics,
119 }
120 }
121
122 pub fn supported_promises() -> Promises {
125 Promises::ORDERED
126 | Promises::CONSISTENCY
127 | Promises::GUARANTEED_DELIVERY
128 | Promises::COMPRESSED
129 | Promises::ENCRYPTED
130 }
131}
132
133impl<S> QuicRecvProtocol<S>
134where
135 S: UnreliableSink<DataFormat = QuicDataFormat>,
136{
137 pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self {
138 Self {
139 main_buffer: BytesMut::new(),
140 unreliable_buffer: BytesMut::new(),
141 reliable_buffers: SortedVec::default(),
142 pending_reliable_buffers: vec![],
143 itmsg_allocator: BytesMut::with_capacity(ALLOC_BLOCK),
144 incoming: HashMap::new(),
145 sink,
146 metrics,
147 }
148 }
149
150 async fn recv_into_stream(
151 &mut self,
152 ) -> Result<QuicDataFormatStream, ProtocolError<S::CustomErr>> {
153 let chunk = self.sink.recv().await?;
154 let buffer = match chunk.stream {
155 QuicDataFormatStream::Main => &mut self.main_buffer,
156 QuicDataFormatStream::Unreliable => &mut self.unreliable_buffer,
157 QuicDataFormatStream::Reliable(id) => {
158 match self.reliable_buffers.get_mut(&id) {
159 Some(buffer) => buffer,
160 None => {
161 self.pending_reliable_buffers.push((id, BytesMut::new()));
162 &mut self
164 .pending_reliable_buffers
165 .last_mut()
166 .ok_or(ProtocolError::Violated)?
167 .1
168 },
169 }
170 },
171 };
172 if buffer.is_empty() {
173 *buffer = chunk.data
174 } else {
175 buffer.extend_from_slice(&chunk.data)
176 }
177 Ok(chunk.stream)
178 }
179}
180
181#[async_trait]
182impl<D> SendProtocol for QuicSendProtocol<D>
183where
184 D: UnreliableDrain<DataFormat = QuicDataFormat>,
185{
186 type CustomErr = D::CustomErr;
187
188 fn notify_from_recv(&mut self, event: ProtocolEvent) {
189 match event {
190 ProtocolEvent::OpenStream {
191 sid,
192 prio,
193 promises,
194 guaranteed_bandwidth,
195 } => {
196 self.store
197 .open_stream(sid, prio, promises, guaranteed_bandwidth);
198 if is_reliable(&promises) {
199 self.reliable_buffers.insert(sid, BytesMut::new());
200 }
201 },
202 ProtocolEvent::CloseStream { sid } if !self.store.try_close_stream(sid) => {
203 #[cfg(feature = "trace_pedantic")]
204 trace!(?sid, "hold back notify close stream");
205 self.notify_closing_streams.push(sid);
206 },
207 _ => {},
208 }
209 }
210
211 async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
212 #[cfg(feature = "trace_pedantic")]
213 trace!(?event, "send");
214 match event {
215 ProtocolEvent::OpenStream {
216 sid,
217 prio,
218 promises,
219 guaranteed_bandwidth,
220 } => {
221 self.store
222 .open_stream(sid, prio, promises, guaranteed_bandwidth);
223 if is_reliable(&promises) {
224 self.reliable_buffers.insert(sid, BytesMut::new());
225 self.drain
227 .send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid))
228 .await?;
229 }
230 event.to_frame().write_bytes(&mut self.main_buffer);
231 self.drain
232 .send(QuicDataFormat::with_main(&mut self.main_buffer))
233 .await?;
234 },
235 ProtocolEvent::CloseStream { sid } => {
236 if self.store.try_close_stream(sid) {
237 let _ = self.reliable_buffers.delete(&sid); event.to_frame().write_bytes(&mut self.main_buffer);
239 self.drain
240 .send(QuicDataFormat::with_main(&mut self.main_buffer))
241 .await?;
242 } else {
243 #[cfg(feature = "trace_pedantic")]
244 trace!(?sid, "hold back close stream");
245 self.closing_streams.push(sid);
246 }
247 },
248 ProtocolEvent::Shutdown => {
249 if self.store.is_empty() {
250 event.to_frame().write_bytes(&mut self.main_buffer);
251 self.drain
252 .send(QuicDataFormat::with_main(&mut self.main_buffer))
253 .await?;
254 } else {
255 #[cfg(feature = "trace_pedantic")]
256 trace!("hold back shutdown");
257 self.pending_shutdown = true;
258 }
259 },
260 ProtocolEvent::Message { data, sid } => {
261 self.metrics.smsg_ib(sid, data.len() as u64);
262 self.store.add(data, self.next_mid, sid);
263 self.next_mid += 1;
264 },
265 }
266 Ok(())
267 }
268
269 async fn flush(
270 &mut self,
271 bandwidth: Bandwidth,
272 dt: Duration,
273 ) -> Result<Bandwidth, ProtocolError<Self::CustomErr>> {
274 let (frames, _) = self.store.grab(bandwidth, dt);
275 let mut data_frames = 0;
277 let mut data_bandwidth = 0;
278 for (sid, frame) in frames {
279 if let OTFrame::Data { mid: _, data } = &frame {
280 data_bandwidth += data.len();
281 data_frames += 1;
282 }
283 match self.reliable_buffers.get_mut(&sid) {
284 Some(buffer) => frame.write_bytes(buffer),
285 None => {
286 self.drain
287 .send(QuicDataFormat::with_unreliable(frame))
288 .await?
289 },
290 }
291 }
292 for (sid, buffer) in self.reliable_buffers.data.iter_mut() {
293 if !buffer.is_empty() {
294 self.drain
295 .send(QuicDataFormat::with_reliable(buffer, *sid))
296 .await?;
297 }
298 }
299 self.metrics
300 .sdata_frames_b(data_frames, data_bandwidth as u64);
301
302 let mut finished_streams = vec![];
303 for (i, &sid) in self.closing_streams.iter().enumerate() {
304 if self.store.try_close_stream(sid) {
305 #[cfg(feature = "trace_pedantic")]
306 trace!(?sid, "close stream, as it's now empty");
307 OTFrame::CloseStream { sid }.write_bytes(&mut self.main_buffer);
308 self.drain
309 .send(QuicDataFormat::with_main(&mut self.main_buffer))
310 .await?;
311 finished_streams.push(i);
312 }
313 }
314 for i in finished_streams.iter().rev() {
315 self.closing_streams.remove(*i);
316 }
317
318 let mut finished_streams = vec![];
319 for (i, sid) in self.notify_closing_streams.iter().enumerate() {
320 if self.store.try_close_stream(*sid) {
321 #[cfg(feature = "trace_pedantic")]
322 trace!(?sid, "close stream, as it's now empty");
323 finished_streams.push(i);
324 }
325 }
326 for i in finished_streams.iter().rev() {
327 self.notify_closing_streams.remove(*i);
328 }
329
330 if self.pending_shutdown && self.store.is_empty() {
331 #[cfg(feature = "trace_pedantic")]
332 trace!("shutdown, as it's now empty");
333 OTFrame::Shutdown {}.write_bytes(&mut self.main_buffer);
334 self.drain
335 .send(QuicDataFormat::with_main(&mut self.main_buffer))
336 .await?;
337 self.pending_shutdown = false;
338 }
339 Ok(data_bandwidth as u64)
340 }
341}
342
343#[async_trait]
344impl<S> RecvProtocol for QuicRecvProtocol<S>
345where
346 S: UnreliableSink<DataFormat = QuicDataFormat>,
347{
348 type CustomErr = S::CustomErr;
349
350 async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
351 'outer: loop {
352 match ITFrame::read_frame(&mut self.main_buffer) {
353 Ok(Some(frame)) => {
354 #[cfg(feature = "trace_pedantic")]
355 trace!(?frame, "recv");
356 match frame {
357 ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown),
358 ITFrame::OpenStream {
359 sid,
360 prio,
361 promises,
362 guaranteed_bandwidth,
363 } => {
364 if is_reliable(&promises) {
365 self.reliable_buffers.insert(sid, BytesMut::new());
366 }
367 break 'outer Ok(ProtocolEvent::OpenStream {
368 sid,
369 prio: prio.min(crate::types::HIGHEST_PRIO),
370 promises,
371 guaranteed_bandwidth,
372 });
373 },
374 ITFrame::CloseStream { sid } => {
375 break 'outer Ok(ProtocolEvent::CloseStream { sid });
378 },
379 _ => break 'outer Err(ProtocolError::Violated),
380 };
381 },
382 Ok(None) => {},
383 Err(()) => return Err(ProtocolError::Violated),
384 }
385
386 let mut pending_violated = false;
388 let mut reliable = vec![];
389
390 self.pending_reliable_buffers.retain(|(_, buffer)| {
391 let mut testbuffer = buffer.clone();
393 match ITFrame::read_frame(&mut testbuffer) {
394 Ok(Some(ITFrame::DataHeader {
395 sid,
396 mid: _,
397 length: _,
398 })) => {
399 reliable.push((sid, buffer.clone()));
400 false
401 },
402 Ok(Some(_)) | Err(_) => {
403 pending_violated = true;
404 false
405 },
406 Ok(None) => true,
407 }
408 });
409
410 if pending_violated {
411 break 'outer Err(ProtocolError::Violated);
412 }
413 for (sid, buffer) in reliable.into_iter() {
414 self.reliable_buffers.insert(sid, buffer)
415 }
416
417 let mut iter = self
418 .reliable_buffers
419 .data
420 .iter_mut()
421 .map(|(_, b)| (b, true))
422 .collect::<Vec<_>>();
423 iter.push((&mut self.unreliable_buffer, false));
424
425 for (buffer, reliable) in iter {
426 loop {
427 match ITFrame::read_frame(buffer) {
428 Ok(Some(frame)) => {
429 #[cfg(feature = "trace_pedantic")]
430 trace!(?frame, "recv");
431 match frame {
432 ITFrame::DataHeader { sid, mid, length } => {
433 let m = ITMessage::new(sid, length, &mut self.itmsg_allocator);
434 self.metrics.rmsg_ib(sid, length);
435 self.incoming.insert(mid, m);
436 },
437 ITFrame::Data { mid, data } => {
438 self.metrics.rdata_frames_b(data.len() as u64);
439 let m = match self.incoming.get_mut(&mid) {
440 Some(m) => m,
441 None => {
442 if reliable {
443 info!(
444 ?mid,
445 "protocol violation by remote side: send Data \
446 before Header"
447 );
448 break 'outer Err(ProtocolError::Violated);
449 } else {
450 continue;
452 }
453 },
454 };
455 m.data.extend_from_slice(&data);
456 if m.data.len() == m.length as usize {
457 let m = self
459 .incoming
460 .remove(&mid)
461 .ok_or(ProtocolError::Violated)?;
462 self.metrics.rmsg_ob(
463 m.sid,
464 RemoveReason::Finished,
465 m.data.len() as u64,
466 );
467 break 'outer Ok(ProtocolEvent::Message {
468 sid: m.sid,
469 data: m.data.freeze(),
470 });
471 }
472 },
473 _ => break 'outer Err(ProtocolError::Violated),
474 };
475 },
476 Ok(None) => break, Err(()) => return Err(ProtocolError::Violated),
478 }
479 }
480 }
481
482 self.recv_into_stream().await?;
483 }
484 }
485}
486
487#[async_trait]
488impl<D> ReliableDrain for QuicSendProtocol<D>
489where
490 D: UnreliableDrain<DataFormat = QuicDataFormat>,
491{
492 type CustomErr = D::CustomErr;
493
494 async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>> {
495 self.main_buffer.reserve(500);
496 frame.write_bytes(&mut self.main_buffer);
497 self.drain
498 .send(QuicDataFormat::with_main(&mut self.main_buffer))
499 .await
500 }
501}
502
503#[async_trait]
504impl<S> ReliableSink for QuicRecvProtocol<S>
505where
506 S: UnreliableSink<DataFormat = QuicDataFormat>,
507{
508 type CustomErr = S::CustomErr;
509
510 async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>> {
511 while self.main_buffer.len() < 100 {
512 if self.recv_into_stream().await? == QuicDataFormatStream::Main
513 && let Some(frame) = InitFrame::read_frame(&mut self.main_buffer)
514 {
515 return Ok(frame);
516 }
517 }
518 Err(ProtocolError::Violated)
519 }
520}
521
#[cfg(test)]
mod test_utils {
    //! In-memory transport for the QUIC protocol tests.
    //!
    //! Chunks travel over bounded `async_channel`s; [`QuicDrain`] discards a
    //! chunk tagged `QuicDataFormatStream::Unreliable` whenever a random `f32`
    //! falls below its `drop_ratio`.
    use super::*;
    use crate::metrics::{ProtocolMetricCache, ProtocolMetrics};
    use async_channel::*;
    use std::sync::Arc;

    /// Sending end of the in-memory transport.
    pub struct QuicDrain {
        /// Channel the chunks are pushed into.
        pub sender: Sender<QuicDataFormat>,
        /// Threshold for randomly discarding unreliable chunks.
        pub drop_ratio: f32,
    }

    /// Receiving end of the in-memory transport.
    pub struct QuicSink {
        /// Channel the chunks are read from.
        pub receiver: Receiver<QuicDataFormat>,
    }

    /// Builds two connected `(send, recv)` protocol pairs.
    ///
    /// The send half of the first pair feeds the recv half of the second pair
    /// and vice versa; both channels hold at most `cap` chunks. All protocols
    /// share `metrics`, or a fresh cache named `"quic"` when `None` is given.
    pub fn quic_bound(
        cap: usize,
        drop_ratio: f32,
        metrics: Option<ProtocolMetricCache>,
    ) -> [(QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>); 2] {
        let (sender_a, receiver_a) = bounded(cap);
        let (sender_b, receiver_b) = bounded(cap);
        let shared = match metrics {
            Some(cache) => cache,
            None => ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())),
        };
        let first = (
            QuicSendProtocol::new(
                QuicDrain {
                    sender: sender_a,
                    drop_ratio,
                },
                shared.clone(),
            ),
            QuicRecvProtocol::new(
                QuicSink {
                    receiver: receiver_b,
                },
                shared.clone(),
            ),
        );
        let second = (
            QuicSendProtocol::new(
                QuicDrain {
                    sender: sender_b,
                    drop_ratio,
                },
                shared.clone(),
            ),
            QuicRecvProtocol::new(
                QuicSink {
                    receiver: receiver_a,
                },
                shared,
            ),
        );
        [first, second]
    }

    #[async_trait]
    impl UnreliableDrain for QuicDrain {
        type CustomErr = ();
        type DataFormat = QuicDataFormat;

        /// Pushes `data` into the channel, unless it is an unreliable chunk
        /// that gets randomly discarded. A closed channel is reported as
        /// `ProtocolError::Custom(())`.
        async fn send(
            &mut self,
            data: Self::DataFormat,
        ) -> Result<(), ProtocolError<Self::CustomErr>> {
            use rand::RngExt;
            let discard = matches!(data.stream, QuicDataFormatStream::Unreliable)
                && rand::rng().random::<f32>() < self.drop_ratio;
            if discard {
                return Ok(());
            }
            match self.sender.send(data).await {
                Ok(()) => Ok(()),
                Err(_) => Err(ProtocolError::Custom(())),
            }
        }
    }

    #[async_trait]
    impl UnreliableSink for QuicSink {
        type CustomErr = ();
        type DataFormat = QuicDataFormat;

        /// Waits for the next chunk; any channel error is reported as
        /// `ProtocolError::Custom(())`.
        async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
            match self.receiver.recv().await {
                Ok(chunk) => Ok(chunk),
                Err(_) => Err(ProtocolError::Custom(())),
            }
        }
    }
}
609
#[cfg(test)]
mod tests {
    //! Tests for the QUIC protocol, run over the in-memory transport built by
    //! `test_utils::quic_bound` (bounded channels; unreliable chunks may be
    //! randomly discarded according to `drop_ratio`).
    use crate::{
        InitProtocol, ProtocolEvent, RecvProtocol, SendProtocol,
        error::ProtocolError,
        frame::OTFrame,
        metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason},
        quic::{QuicDataFormat, test_utils::*},
        types::{Pid, Promises, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, Sid},
    };
    use bytes::{Bytes, BytesMut};
    use std::{sync::Arc, time::Duration};

    /// Both endpoints finish the handshake and each learns the other side's
    /// pid and secret; the side initialized with `true` gets
    /// `STREAM_ID_OFFSET1`, the other `STREAM_ID_OFFSET2`.
    #[tokio::test]
    async fn handshake_all_good() {
        let [mut p1, mut p2] = quic_bound(10, 0.5, None);
        let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await });
        let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await });
        let (r1, r2) = tokio::join!(r1, r2);
        assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42)));
        assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337)));
    }

    /// An `OpenStream` sent by one side is received unchanged by the other.
    #[tokio::test]
    async fn open_stream() {
        let [p1, p2] = quic_bound(10, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(10),
            prio: 0u8,
            promises: Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        s.send(event.clone()).await.unwrap();
        let e = r.recv().await.unwrap();
        assert_eq!(event, e);
    }

    /// Two short messages on an ordered stream arrive intact, one per flush.
    #[tokio::test]
    async fn send_short_msg() {
        let [p1, p2] = quic_bound(10, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(10),
            prio: 3u8,
            promises: Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        s.send(event).await.unwrap();
        let _ = r.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid: Sid::new(10),
            data: Bytes::from(&[188u8; 600][..]),
        };
        s.send(event.clone()).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert_eq!(event, e);
        let event = ProtocolEvent::Message {
            sid: Sid::new(10),
            data: Bytes::from(&[7u8; 30][..]),
        };
        s.send(event.clone()).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert_eq!(event, e)
    }

    /// A 500_000-byte message is delivered in one flush and the shared
    /// metrics record it.
    #[tokio::test]
    async fn send_long_msg() {
        let mut metrics =
            ProtocolMetricCache::new("long_quic", Arc::new(ProtocolMetrics::new().unwrap()));
        let sid = Sid::new(1);
        let [p1, p2] = quic_bound(10000, 0.5, Some(metrics.clone()));
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        s.send(event).await.unwrap();
        let _ = r.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[99u8; 500_000][..]),
        };
        s.send(event.clone()).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert_eq!(event, e);
        metrics.assert_msg(sid, 1, RemoveReason::Finished);
        metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished);
        // The frame count depends on the data frame size used when splitting
        // the message (not visible in this file).
        metrics.assert_data_frames(358);
        metrics.assert_data_frames_bytes(500_000);
    }

    /// A close requested while a message is still queued is held back: the
    /// receiver sees the message first, then the `CloseStream`.
    #[tokio::test]
    async fn msg_finishes_after_close() {
        let sid = Sid::new(1);
        let [p1, p2] = quic_bound(10000, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 0,
        };
        s.send(event).await.unwrap();
        let _ = r.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[99u8; 500_000][..]),
        };
        s.send(event).await.unwrap();
        let event = ProtocolEvent::CloseStream { sid };
        s.send(event).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
    }

    /// Shutdown and close requested while data is queued are both held back;
    /// the receiver sees message, then close, then shutdown.
    #[tokio::test]
    async fn msg_finishes_after_shutdown() {
        let sid = Sid::new(1);
        let [p1, p2] = quic_bound(10000, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 0,
        };
        s.send(event).await.unwrap();
        let _ = r.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[99u8; 500_000][..]),
        };
        s.send(event).await.unwrap();
        let event = ProtocolEvent::Shutdown {};
        s.send(event).await.unwrap();
        let event = ProtocolEvent::CloseStream { sid };
        s.send(event).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Shutdown));
    }

    /// Everything flushed before the sender is dropped can still be read from
    /// the channel by the receiver.
    #[tokio::test]
    async fn msg_finishes_after_drop() {
        let sid = Sid::new(1);
        let [p1, p2] = quic_bound(10000, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 0,
        };
        s.send(event).await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[99u8; 500_000][..]),
        };
        s.send(event).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[100u8; 500_000][..]),
        };
        s.send(event).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        drop(s);
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));
    }

    /// Hand-built chunks: OpenStream on main, the DataHeader alone on the
    /// reliable stream, both Data frames together in a second reliable chunk,
    /// then CloseStream on main.
    #[tokio::test]
    async fn header_and_data_in_seperate_msg() {
        let sid = Sid::new(1);
        let (s, r) = async_channel::bounded(10);
        let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()));
        let mut r = super::QuicRecvProtocol::new(QuicSink { receiver: r }, m.clone());

        const DATA1: &[u8; 69] =
            b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD ";
        const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. and then keep the connection open";
        // One scratch buffer is reused; every `with_*` call takes its contents.
        let mut bytes = BytesMut::with_capacity(1500);
        OTFrame::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        }
        .write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();

        OTFrame::DataHeader {
            mid: 99,
            sid,
            length: (DATA1.len() + DATA2.len()) as u64,
        }
        .write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_reliable(&mut bytes, sid))
            .await
            .unwrap();

        OTFrame::Data {
            mid: 99,
            data: Bytes::from(&DATA1[..]),
        }
        .write_bytes(&mut bytes);
        OTFrame::Data {
            mid: 99,
            data: Bytes::from(&DATA2[..]),
        }
        .write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_reliable(&mut bytes, sid))
            .await
            .unwrap();

        OTFrame::CloseStream { sid }.write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();

        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));

        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
    }

    /// Dropping the only sender while a `recv` is pending makes it fail with
    /// the error `QuicSink` maps channel errors to (`Custom(())`).
    #[tokio::test]
    async fn drop_sink_while_recv() {
        let sid = Sid::new(1);
        let (s, r) = async_channel::bounded(10);
        let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()));
        let mut r = super::QuicRecvProtocol::new(QuicSink { receiver: r }, m.clone());

        let mut bytes = BytesMut::with_capacity(1500);
        OTFrame::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED,
            guaranteed_bandwidth: 1_000_000,
        }
        .write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::OpenStream { .. }));

        let e = tokio::spawn(async move { r.recv().await });
        drop(s);

        let e = e.await.unwrap();
        assert_eq!(e, Err(ProtocolError::Custom(())));
    }

    /// p2's send half never learns about stream 10 (no `notify_from_recv`), so
    /// sending on it is expected to panic somewhere in the calls below; which
    /// call panics is not visible from this file.
    #[tokio::test]
    #[should_panic]
    async fn send_on_stream_from_remote_without_notify() {
        let [mut p1, mut p2] = quic_bound(10, 0.5, None);
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(10),
            prio: 3u8,
            promises: Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        p1.0.send(event).await.unwrap();
        let _ = p2.1.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid: Sid::new(10),
            data: Bytes::from(&[188u8; 600][..]),
        };
        p2.0.send(event.clone()).await.unwrap();
        p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = p1.1.recv().await.unwrap();
        assert_eq!(event, e);
    }

    /// Same as above, but the received `OpenStream` is forwarded to p2's send
    /// half, which may then send on the remote-opened stream.
    #[tokio::test]
    async fn send_on_stream_from_remote() {
        let [mut p1, mut p2] = quic_bound(10, 0.5, None);
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(10),
            prio: 3u8,
            promises: Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        p1.0.send(event).await.unwrap();
        let e = p2.1.recv().await.unwrap();
        p2.0.notify_from_recv(e);
        let event = ProtocolEvent::Message {
            sid: Sid::new(10),
            data: Bytes::from(&[188u8; 600][..]),
        };
        p2.0.send(event.clone()).await.unwrap();
        p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = p1.1.recv().await.unwrap();
        assert_eq!(event, e);
    }

    /// On a stream without promises every frame is sent unreliably and about
    /// half of the chunks are discarded; at least `MIN_CHECK` complete
    /// messages must still arrive.
    #[tokio::test]
    async fn unrealiable_test() {
        const MIN_CHECK: usize = 10;
        const COUNT: usize = 10_000;
        let [mut p1, mut p2] = quic_bound(
            // Presumably one DataHeader and one Data chunk per 600-byte message,
            // i.e. 2 * COUNT chunks per flush; with half of them discarded the
            // bounded channel is not expected to fill up.
            COUNT * 2 - 1,
            0.5,
            None,
        );
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(1337),
            prio: 3u8,
            promises: Promises::empty(),
            guaranteed_bandwidth: 1_000_000,
        };
        p1.0.send(event).await.unwrap();
        let e = p2.1.recv().await.unwrap();
        p2.0.notify_from_recv(e);
        let event = ProtocolEvent::Message {
            sid: Sid::new(1337),
            data: Bytes::from(&[188u8; 600][..]),
        };
        for _ in 0..COUNT {
            p2.0.send(event.clone()).await.unwrap();
        }
        p2.0.flush(1_000_000_000, Duration::from_secs(1))
            .await
            .unwrap();
        // These are only queued; no flush follows, so they are never sent.
        for _ in 0..COUNT {
            p2.0.send(event.clone()).await.unwrap();
        }
        for _ in 0..MIN_CHECK {
            let e = p1.1.recv().await.unwrap();
            assert_eq!(event, e);
        }
    }
}
971}