1use crate::{
2 RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink,
3 error::ProtocolError,
4 event::ProtocolEvent,
5 frame::{ITFrame, InitFrame, OTFrame},
6 handshake::{ReliableDrain, ReliableSink},
7 message::{ALLOC_BLOCK, ITMessage},
8 metrics::{ProtocolMetricCache, RemoveReason},
9 prio::PrioManager,
10 types::{Bandwidth, Mid, Promises, Sid},
11 util::SortedVec,
12};
13use async_trait::async_trait;
14use bytes::BytesMut;
15use hashbrown::HashMap;
16use std::time::{Duration, Instant};
17use tracing::info;
18#[cfg(feature = "trace_pedantic")]
19use tracing::trace;
20
/// Tags which logical QUIC stream a [`QuicDataFormat`] chunk travels on.
#[derive(PartialEq, Eq)]
pub enum QuicDataFormatStream {
    /// Control stream: handshake frames plus OpenStream / CloseStream /
    /// Shutdown frames are written here.
    Main,
    /// Per-stream channel used for streams whose promises are reliable
    /// (ordered, consistent or guaranteed delivery).
    Reliable(Sid),
    /// Frames of streams without a reliability promise; each frame is sent as
    /// its own chunk.
    Unreliable,
}
27
/// One chunk of serialized frames exchanged with the underlying
/// `UnreliableDrain` / `UnreliableSink`.
pub struct QuicDataFormat {
    /// Logical stream this chunk belongs to.
    pub stream: QuicDataFormatStream,
    /// Serialized frame bytes.
    pub data: BytesMut,
}
32
33impl QuicDataFormat {
34 fn with_main(buffer: &mut BytesMut) -> Self {
35 Self {
36 stream: QuicDataFormatStream::Main,
37 data: buffer.split(),
38 }
39 }
40
41 fn with_reliable(buffer: &mut BytesMut, sid: Sid) -> Self {
42 Self {
43 stream: QuicDataFormatStream::Reliable(sid),
44 data: buffer.split(),
45 }
46 }
47
48 fn with_unreliable(frame: OTFrame) -> Self {
49 let mut buffer = BytesMut::new();
50 frame.write_bytes(&mut buffer);
51 Self {
52 stream: QuicDataFormatStream::Unreliable,
53 data: buffer,
54 }
55 }
56}
57
/// Sending half of the QUIC protocol.
///
/// Control frames go out on the main stream. Message data is queued in a
/// [`PrioManager`] and emitted on `flush`, either batched per reliable stream
/// or as individual unreliable chunks.
#[derive(Debug)]
pub struct QuicSendProtocol<D>
where
    D: UnreliableDrain<DataFormat = QuicDataFormat>,
{
    /// Staging buffer for frames written to the main stream.
    main_buffer: BytesMut,
    /// Per-stream staging buffers for streams with reliable promises.
    reliable_buffers: SortedVec<Sid, BytesMut>,
    /// Queue of outgoing message data, drained in `flush`.
    store: PrioManager,
    /// Message id assigned to the next queued message.
    next_mid: Mid,
    /// Locally requested stream closes deferred until the stream's queued
    /// data is drained; a CloseStream frame is sent when they complete.
    closing_streams: Vec<Sid>,
    /// Remote-initiated stream closes deferred until queued data is drained;
    /// no frame is sent for these.
    notify_closing_streams: Vec<Sid>,
    /// A Shutdown was requested while data was still queued.
    pending_shutdown: bool,
    drain: D,
    // Currently unused (hence the `expect(dead_code)`).
    #[expect(dead_code)]
    last: Instant,
    metrics: ProtocolMetricCache,
}
78
/// Receiving half of the QUIC protocol.
///
/// Incoming chunks are appended to a buffer per logical stream, and frames
/// are parsed from those buffers in `recv`.
#[derive(Debug)]
pub struct QuicRecvProtocol<S>
where
    S: UnreliableSink<DataFormat = QuicDataFormat>,
{
    /// Bytes received on the main (control) stream.
    main_buffer: BytesMut,
    /// Bytes received as unreliable chunks.
    unreliable_buffer: BytesMut,
    /// Per-stream buffers for reliable streams, keyed by stream id.
    reliable_buffers: SortedVec<Sid, BytesMut>,
    /// Reliable chunks whose id has no buffer yet; promoted to
    /// `reliable_buffers` once a leading DataHeader can be parsed.
    pending_reliable_buffers: Vec<(Sid, BytesMut)>,
    /// Passed to `ITMessage::new`; presumably backing storage for incoming
    /// message data (TODO confirm). Starts with `ALLOC_BLOCK` capacity.
    itmsg_allocator: BytesMut,
    /// Messages whose header was seen but whose data is still incomplete.
    incoming: HashMap<Mid, ITMessage>,
    sink: S,
    metrics: ProtocolMetricCache,
}
96
97fn is_reliable(p: &Promises) -> bool {
98 p.contains(Promises::ORDERED)
99 || p.contains(Promises::CONSISTENCY)
100 || p.contains(Promises::GUARANTEED_DELIVERY)
101}
102
103impl<D> QuicSendProtocol<D>
104where
105 D: UnreliableDrain<DataFormat = QuicDataFormat>,
106{
107 pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self {
108 Self {
109 main_buffer: BytesMut::new(),
110 reliable_buffers: SortedVec::default(),
111 store: PrioManager::new(metrics.clone()),
112 next_mid: 0u64,
113 closing_streams: vec![],
114 notify_closing_streams: vec![],
115 pending_shutdown: false,
116 drain,
117 last: Instant::now(),
118 metrics,
119 }
120 }
121
122 pub fn supported_promises() -> Promises {
125 Promises::ORDERED
126 | Promises::CONSISTENCY
127 | Promises::GUARANTEED_DELIVERY
128 | Promises::COMPRESSED
129 | Promises::ENCRYPTED
130 }
131}
132
133impl<S> QuicRecvProtocol<S>
134where
135 S: UnreliableSink<DataFormat = QuicDataFormat>,
136{
137 pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self {
138 Self {
139 main_buffer: BytesMut::new(),
140 unreliable_buffer: BytesMut::new(),
141 reliable_buffers: SortedVec::default(),
142 pending_reliable_buffers: vec![],
143 itmsg_allocator: BytesMut::with_capacity(ALLOC_BLOCK),
144 incoming: HashMap::new(),
145 sink,
146 metrics,
147 }
148 }
149
150 async fn recv_into_stream(
151 &mut self,
152 ) -> Result<QuicDataFormatStream, ProtocolError<S::CustomErr>> {
153 let chunk = self.sink.recv().await?;
154 let buffer = match chunk.stream {
155 QuicDataFormatStream::Main => &mut self.main_buffer,
156 QuicDataFormatStream::Unreliable => &mut self.unreliable_buffer,
157 QuicDataFormatStream::Reliable(id) => {
158 match self.reliable_buffers.get_mut(&id) {
159 Some(buffer) => buffer,
160 None => {
161 self.pending_reliable_buffers.push((id, BytesMut::new()));
162 &mut self
164 .pending_reliable_buffers
165 .last_mut()
166 .ok_or(ProtocolError::Violated)?
167 .1
168 },
169 }
170 },
171 };
172 if buffer.is_empty() {
173 *buffer = chunk.data
174 } else {
175 buffer.extend_from_slice(&chunk.data)
176 }
177 Ok(chunk.stream)
178 }
179}
180
181#[async_trait]
182impl<D> SendProtocol for QuicSendProtocol<D>
183where
184 D: UnreliableDrain<DataFormat = QuicDataFormat>,
185{
186 type CustomErr = D::CustomErr;
187
188 fn notify_from_recv(&mut self, event: ProtocolEvent) {
189 match event {
190 ProtocolEvent::OpenStream {
191 sid,
192 prio,
193 promises,
194 guaranteed_bandwidth,
195 } => {
196 self.store
197 .open_stream(sid, prio, promises, guaranteed_bandwidth);
198 if is_reliable(&promises) {
199 self.reliable_buffers.insert(sid, BytesMut::new());
200 }
201 },
202 ProtocolEvent::CloseStream { sid } => {
203 if !self.store.try_close_stream(sid) {
204 #[cfg(feature = "trace_pedantic")]
205 trace!(?sid, "hold back notify close stream");
206 self.notify_closing_streams.push(sid);
207 }
208 },
209 _ => {},
210 }
211 }
212
213 async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
214 #[cfg(feature = "trace_pedantic")]
215 trace!(?event, "send");
216 match event {
217 ProtocolEvent::OpenStream {
218 sid,
219 prio,
220 promises,
221 guaranteed_bandwidth,
222 } => {
223 self.store
224 .open_stream(sid, prio, promises, guaranteed_bandwidth);
225 if is_reliable(&promises) {
226 self.reliable_buffers.insert(sid, BytesMut::new());
227 self.drain
229 .send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid))
230 .await?;
231 }
232 event.to_frame().write_bytes(&mut self.main_buffer);
233 self.drain
234 .send(QuicDataFormat::with_main(&mut self.main_buffer))
235 .await?;
236 },
237 ProtocolEvent::CloseStream { sid } => {
238 if self.store.try_close_stream(sid) {
239 let _ = self.reliable_buffers.delete(&sid); event.to_frame().write_bytes(&mut self.main_buffer);
241 self.drain
242 .send(QuicDataFormat::with_main(&mut self.main_buffer))
243 .await?;
244 } else {
245 #[cfg(feature = "trace_pedantic")]
246 trace!(?sid, "hold back close stream");
247 self.closing_streams.push(sid);
248 }
249 },
250 ProtocolEvent::Shutdown => {
251 if self.store.is_empty() {
252 event.to_frame().write_bytes(&mut self.main_buffer);
253 self.drain
254 .send(QuicDataFormat::with_main(&mut self.main_buffer))
255 .await?;
256 } else {
257 #[cfg(feature = "trace_pedantic")]
258 trace!("hold back shutdown");
259 self.pending_shutdown = true;
260 }
261 },
262 ProtocolEvent::Message { data, sid } => {
263 self.metrics.smsg_ib(sid, data.len() as u64);
264 self.store.add(data, self.next_mid, sid);
265 self.next_mid += 1;
266 },
267 }
268 Ok(())
269 }
270
271 async fn flush(
272 &mut self,
273 bandwidth: Bandwidth,
274 dt: Duration,
275 ) -> Result<Bandwidth, ProtocolError<Self::CustomErr>> {
276 let (frames, _) = self.store.grab(bandwidth, dt);
277 let mut data_frames = 0;
279 let mut data_bandwidth = 0;
280 for (sid, frame) in frames {
281 if let OTFrame::Data { mid: _, data } = &frame {
282 data_bandwidth += data.len();
283 data_frames += 1;
284 }
285 match self.reliable_buffers.get_mut(&sid) {
286 Some(buffer) => frame.write_bytes(buffer),
287 None => {
288 self.drain
289 .send(QuicDataFormat::with_unreliable(frame))
290 .await?
291 },
292 }
293 }
294 for (sid, buffer) in self.reliable_buffers.data.iter_mut() {
295 if !buffer.is_empty() {
296 self.drain
297 .send(QuicDataFormat::with_reliable(buffer, *sid))
298 .await?;
299 }
300 }
301 self.metrics
302 .sdata_frames_b(data_frames, data_bandwidth as u64);
303
304 let mut finished_streams = vec![];
305 for (i, &sid) in self.closing_streams.iter().enumerate() {
306 if self.store.try_close_stream(sid) {
307 #[cfg(feature = "trace_pedantic")]
308 trace!(?sid, "close stream, as it's now empty");
309 OTFrame::CloseStream { sid }.write_bytes(&mut self.main_buffer);
310 self.drain
311 .send(QuicDataFormat::with_main(&mut self.main_buffer))
312 .await?;
313 finished_streams.push(i);
314 }
315 }
316 for i in finished_streams.iter().rev() {
317 self.closing_streams.remove(*i);
318 }
319
320 let mut finished_streams = vec![];
321 for (i, sid) in self.notify_closing_streams.iter().enumerate() {
322 if self.store.try_close_stream(*sid) {
323 #[cfg(feature = "trace_pedantic")]
324 trace!(?sid, "close stream, as it's now empty");
325 finished_streams.push(i);
326 }
327 }
328 for i in finished_streams.iter().rev() {
329 self.notify_closing_streams.remove(*i);
330 }
331
332 if self.pending_shutdown && self.store.is_empty() {
333 #[cfg(feature = "trace_pedantic")]
334 trace!("shutdown, as it's now empty");
335 OTFrame::Shutdown {}.write_bytes(&mut self.main_buffer);
336 self.drain
337 .send(QuicDataFormat::with_main(&mut self.main_buffer))
338 .await?;
339 self.pending_shutdown = false;
340 }
341 Ok(data_bandwidth as u64)
342 }
343}
344
345#[async_trait]
346impl<S> RecvProtocol for QuicRecvProtocol<S>
347where
348 S: UnreliableSink<DataFormat = QuicDataFormat>,
349{
350 type CustomErr = S::CustomErr;
351
352 async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
353 'outer: loop {
354 match ITFrame::read_frame(&mut self.main_buffer) {
355 Ok(Some(frame)) => {
356 #[cfg(feature = "trace_pedantic")]
357 trace!(?frame, "recv");
358 match frame {
359 ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown),
360 ITFrame::OpenStream {
361 sid,
362 prio,
363 promises,
364 guaranteed_bandwidth,
365 } => {
366 if is_reliable(&promises) {
367 self.reliable_buffers.insert(sid, BytesMut::new());
368 }
369 break 'outer Ok(ProtocolEvent::OpenStream {
370 sid,
371 prio: prio.min(crate::types::HIGHEST_PRIO),
372 promises,
373 guaranteed_bandwidth,
374 });
375 },
376 ITFrame::CloseStream { sid } => {
377 break 'outer Ok(ProtocolEvent::CloseStream { sid });
380 },
381 _ => break 'outer Err(ProtocolError::Violated),
382 };
383 },
384 Ok(None) => {},
385 Err(()) => return Err(ProtocolError::Violated),
386 }
387
388 let mut pending_violated = false;
390 let mut reliable = vec![];
391
392 self.pending_reliable_buffers.retain(|(_, buffer)| {
393 let mut testbuffer = buffer.clone();
395 match ITFrame::read_frame(&mut testbuffer) {
396 Ok(Some(ITFrame::DataHeader {
397 sid,
398 mid: _,
399 length: _,
400 })) => {
401 reliable.push((sid, buffer.clone()));
402 false
403 },
404 Ok(Some(_)) | Err(_) => {
405 pending_violated = true;
406 false
407 },
408 Ok(None) => true,
409 }
410 });
411
412 if pending_violated {
413 break 'outer Err(ProtocolError::Violated);
414 }
415 for (sid, buffer) in reliable.into_iter() {
416 self.reliable_buffers.insert(sid, buffer)
417 }
418
419 let mut iter = self
420 .reliable_buffers
421 .data
422 .iter_mut()
423 .map(|(_, b)| (b, true))
424 .collect::<Vec<_>>();
425 iter.push((&mut self.unreliable_buffer, false));
426
427 for (buffer, reliable) in iter {
428 loop {
429 match ITFrame::read_frame(buffer) {
430 Ok(Some(frame)) => {
431 #[cfg(feature = "trace_pedantic")]
432 trace!(?frame, "recv");
433 match frame {
434 ITFrame::DataHeader { sid, mid, length } => {
435 let m = ITMessage::new(sid, length, &mut self.itmsg_allocator);
436 self.metrics.rmsg_ib(sid, length);
437 self.incoming.insert(mid, m);
438 },
439 ITFrame::Data { mid, data } => {
440 self.metrics.rdata_frames_b(data.len() as u64);
441 let m = match self.incoming.get_mut(&mid) {
442 Some(m) => m,
443 None => {
444 if reliable {
445 info!(
446 ?mid,
447 "protocol violation by remote side: send Data \
448 before Header"
449 );
450 break 'outer Err(ProtocolError::Violated);
451 } else {
452 continue;
454 }
455 },
456 };
457 m.data.extend_from_slice(&data);
458 if m.data.len() == m.length as usize {
459 let m = self
461 .incoming
462 .remove(&mid)
463 .ok_or(ProtocolError::Violated)?;
464 self.metrics.rmsg_ob(
465 m.sid,
466 RemoveReason::Finished,
467 m.data.len() as u64,
468 );
469 break 'outer Ok(ProtocolEvent::Message {
470 sid: m.sid,
471 data: m.data.freeze(),
472 });
473 }
474 },
475 _ => break 'outer Err(ProtocolError::Violated),
476 };
477 },
478 Ok(None) => break, Err(()) => return Err(ProtocolError::Violated),
480 }
481 }
482 }
483
484 self.recv_into_stream().await?;
485 }
486 }
487}
488
489#[async_trait]
490impl<D> ReliableDrain for QuicSendProtocol<D>
491where
492 D: UnreliableDrain<DataFormat = QuicDataFormat>,
493{
494 type CustomErr = D::CustomErr;
495
496 async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>> {
497 self.main_buffer.reserve(500);
498 frame.write_bytes(&mut self.main_buffer);
499 self.drain
500 .send(QuicDataFormat::with_main(&mut self.main_buffer))
501 .await
502 }
503}
504
505#[async_trait]
506impl<S> ReliableSink for QuicRecvProtocol<S>
507where
508 S: UnreliableSink<DataFormat = QuicDataFormat>,
509{
510 type CustomErr = S::CustomErr;
511
512 async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>> {
513 while self.main_buffer.len() < 100 {
514 if self.recv_into_stream().await? == QuicDataFormatStream::Main {
515 if let Some(frame) = InitFrame::read_frame(&mut self.main_buffer) {
516 return Ok(frame);
517 }
518 }
519 }
520 Err(ProtocolError::Violated)
521 }
522}
523
/// In-memory `UnreliableDrain` / `UnreliableSink` pair built on bounded
/// `async_channel`s, used to drive the QUIC protocol in tests.
#[cfg(test)]
mod test_utils {
    use super::*;
    use crate::metrics::{ProtocolMetricCache, ProtocolMetrics};
    use async_channel::*;
    use std::sync::Arc;

    /// Test drain: forwards chunks into a channel, randomly dropping a
    /// `drop_ratio` fraction of chunks tagged `Unreliable`.
    pub struct QuicDrain {
        pub sender: Sender<QuicDataFormat>,
        pub drop_ratio: f32,
    }

    /// Test sink: yields chunks from a channel.
    pub struct QuicSink {
        pub receiver: Receiver<QuicDataFormat>,
    }

    /// Builds two connected (send, recv) protocol pairs.
    ///
    /// Each pair's sender feeds the *other* pair's receiver. Both channels
    /// are bounded to `cap`. If `metrics` is `None`, a fresh "quic" metric
    /// cache is created and shared by all four halves.
    pub fn quic_bound(
        cap: usize,
        drop_ratio: f32,
        metrics: Option<ProtocolMetricCache>,
    ) -> [(QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>); 2] {
        let (s1, r1) = bounded(cap);
        let (s2, r2) = bounded(cap);
        let m = metrics.unwrap_or_else(|| {
            ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()))
        });
        [
            (
                QuicSendProtocol::new(
                    QuicDrain {
                        sender: s1,
                        drop_ratio,
                    },
                    m.clone(),
                ),
                QuicRecvProtocol::new(QuicSink { receiver: r2 }, m.clone()),
            ),
            (
                QuicSendProtocol::new(
                    QuicDrain {
                        sender: s2,
                        drop_ratio,
                    },
                    m.clone(),
                ),
                QuicRecvProtocol::new(QuicSink { receiver: r1 }, m),
            ),
        ]
    }

    #[async_trait]
    impl UnreliableDrain for QuicDrain {
        type CustomErr = ();
        type DataFormat = QuicDataFormat;

        /// Sends `data`, except that `Unreliable` chunks are silently dropped
        /// with probability `drop_ratio`. A closed channel maps to
        /// `ProtocolError::Custom(())`.
        async fn send(
            &mut self,
            data: Self::DataFormat,
        ) -> Result<(), ProtocolError<Self::CustomErr>> {
            use rand::Rng;
            if matches!(data.stream, QuicDataFormatStream::Unreliable)
                && rand::thread_rng().gen::<f32>() < self.drop_ratio
            {
                return Ok(());
            }
            self.sender
                .send(data)
                .await
                .map_err(|_| ProtocolError::Custom(()))
        }
    }

    #[async_trait]
    impl UnreliableSink for QuicSink {
        type CustomErr = ();
        type DataFormat = QuicDataFormat;

        /// Waits for the next chunk; a closed channel maps to
        /// `ProtocolError::Custom(())`.
        async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
            self.receiver
                .recv()
                .await
                .map_err(|_| ProtocolError::Custom(()))
        }
    }
}
611
#[cfg(test)]
mod tests {
    use crate::{
        InitProtocol, ProtocolEvent, RecvProtocol, SendProtocol,
        error::ProtocolError,
        frame::OTFrame,
        metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason},
        quic::{QuicDataFormat, test_utils::*},
        types::{Pid, Promises, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, Sid},
    };
    use bytes::{Bytes, BytesMut};
    use std::{sync::Arc, time::Duration};

    /// Both sides handshake concurrently. Each learns the other's pid and
    /// secret; the initiating side (`true`) gets `STREAM_ID_OFFSET1`.
    #[tokio::test]
    async fn handshake_all_good() {
        let [mut p1, mut p2] = quic_bound(10, 0.5, None);
        let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await });
        let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await });
        let (r1, r2) = tokio::join!(r1, r2);
        assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42)));
        assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337)));
    }

    /// An OpenStream event arrives unchanged on the other side.
    #[tokio::test]
    async fn open_stream() {
        let [p1, p2] = quic_bound(10, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(10),
            prio: 0u8,
            promises: Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        s.send(event.clone()).await.unwrap();
        let e = r.recv().await.unwrap();
        assert_eq!(event, e);
    }

    /// Two short messages on an ordered stream round-trip in order.
    #[tokio::test]
    async fn send_short_msg() {
        let [p1, p2] = quic_bound(10, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(10),
            prio: 3u8,
            promises: Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        s.send(event).await.unwrap();
        let _ = r.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid: Sid::new(10),
            data: Bytes::from(&[188u8; 600][..]),
        };
        s.send(event.clone()).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert_eq!(event, e);
        let event = ProtocolEvent::Message {
            sid: Sid::new(10),
            data: Bytes::from(&[7u8; 30][..]),
        };
        s.send(event.clone()).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert_eq!(event, e)
    }

    /// A 500 kB message is delivered intact, and the shared metrics record
    /// the expected message and data-frame counts.
    #[tokio::test]
    async fn send_long_msg() {
        let mut metrics =
            ProtocolMetricCache::new("long_quic", Arc::new(ProtocolMetrics::new().unwrap()));
        let sid = Sid::new(1);
        let [p1, p2] = quic_bound(10000, 0.5, Some(metrics.clone()));
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        s.send(event).await.unwrap();
        let _ = r.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[99u8; 500_000][..]),
        };
        s.send(event.clone()).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert_eq!(event, e);
        metrics.assert_msg(sid, 1, RemoveReason::Finished);
        metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished);
        metrics.assert_data_frames(358);
        metrics.assert_data_frames_bytes(500_000);
    }

    /// A CloseStream sent while a message is still queued is delivered after
    /// that message.
    #[tokio::test]
    async fn msg_finishes_after_close() {
        let sid = Sid::new(1);
        let [p1, p2] = quic_bound(10000, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 0,
        };
        s.send(event).await.unwrap();
        let _ = r.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[99u8; 500_000][..]),
        };
        s.send(event).await.unwrap();
        let event = ProtocolEvent::CloseStream { sid };
        s.send(event).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
    }

    /// Held-back events arrive in the order Message, CloseStream, Shutdown.
    #[tokio::test]
    async fn msg_finishes_after_shutdown() {
        let sid = Sid::new(1);
        let [p1, p2] = quic_bound(10000, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 0,
        };
        s.send(event).await.unwrap();
        let _ = r.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[99u8; 500_000][..]),
        };
        s.send(event).await.unwrap();
        let event = ProtocolEvent::Shutdown {};
        s.send(event).await.unwrap();
        let event = ProtocolEvent::CloseStream { sid };
        s.send(event).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Shutdown));
    }

    /// Messages flushed before the sender is dropped are still received.
    #[tokio::test]
    async fn msg_finishes_after_drop() {
        let sid = Sid::new(1);
        let [p1, p2] = quic_bound(10000, 0.5, None);
        let (mut s, mut r) = (p1.0, p2.1);
        let event = ProtocolEvent::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 0,
        };
        s.send(event).await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[99u8; 500_000][..]),
        };
        s.send(event).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let event = ProtocolEvent::Message {
            sid,
            data: Bytes::from(&[100u8; 500_000][..]),
        };
        s.send(event).await.unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        drop(s);
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));
    }

    /// Hand-crafted chunks: the DataHeader and the Data frames arrive in
    /// separate reliable chunks, surrounded by main-stream control chunks.
    #[tokio::test]
    async fn header_and_data_in_seperate_msg() {
        let sid = Sid::new(1);
        let (s, r) = async_channel::bounded(10);
        let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()));
        let mut r = super::QuicRecvProtocol::new(QuicSink { receiver: r }, m.clone());

        const DATA1: &[u8; 69] =
            b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD ";
        const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. and then keep the connection open";
        let mut bytes = BytesMut::with_capacity(1500);
        OTFrame::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED | Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        }
        .write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();

        OTFrame::DataHeader {
            mid: 99,
            sid,
            length: (DATA1.len() + DATA2.len()) as u64,
        }
        .write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_reliable(&mut bytes, sid))
            .await
            .unwrap();

        OTFrame::Data {
            mid: 99,
            data: Bytes::from(&DATA1[..]),
        }
        .write_bytes(&mut bytes);
        OTFrame::Data {
            mid: 99,
            data: Bytes::from(&DATA2[..]),
        }
        .write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_reliable(&mut bytes, sid))
            .await
            .unwrap();

        OTFrame::CloseStream { sid }.write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();

        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::Message { .. }));

        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
    }

    /// Dropping the channel sender while a `recv` is pending surfaces the
    /// sink's `Custom(())` error.
    #[tokio::test]
    async fn drop_sink_while_recv() {
        let sid = Sid::new(1);
        let (s, r) = async_channel::bounded(10);
        let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()));
        let mut r = super::QuicRecvProtocol::new(QuicSink { receiver: r }, m.clone());

        let mut bytes = BytesMut::with_capacity(1500);
        OTFrame::OpenStream {
            sid,
            prio: 5u8,
            promises: Promises::COMPRESSED,
            guaranteed_bandwidth: 1_000_000,
        }
        .write_bytes(&mut bytes);
        s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();
        let e = r.recv().await.unwrap();
        assert!(matches!(e, ProtocolEvent::OpenStream { .. }));

        let e = tokio::spawn(async move { r.recv().await });
        drop(s);

        let e = e.await.unwrap();
        assert_eq!(e, Err(ProtocolError::Custom(())));
    }

    /// Sending on a stream the remote opened, without first calling
    /// `notify_from_recv`, is expected to panic. This is presumably because
    /// the stream was never registered in the local priority manager; confirm.
    #[tokio::test]
    #[should_panic]
    async fn send_on_stream_from_remote_without_notify() {
        let [mut p1, mut p2] = quic_bound(10, 0.5, None);
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(10),
            prio: 3u8,
            promises: Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        p1.0.send(event).await.unwrap();
        let _ = p2.1.recv().await.unwrap();
        let event = ProtocolEvent::Message {
            sid: Sid::new(10),
            data: Bytes::from(&[188u8; 600][..]),
        };
        p2.0.send(event.clone()).await.unwrap();
        p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = p1.1.recv().await.unwrap();
        assert_eq!(event, e);
    }

    /// After `notify_from_recv`, the remote-opened stream can be used to
    /// send back.
    #[tokio::test]
    async fn send_on_stream_from_remote() {
        let [mut p1, mut p2] = quic_bound(10, 0.5, None);
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(10),
            prio: 3u8,
            promises: Promises::ORDERED,
            guaranteed_bandwidth: 1_000_000,
        };
        p1.0.send(event).await.unwrap();
        let e = p2.1.recv().await.unwrap();
        p2.0.notify_from_recv(e);
        let event = ProtocolEvent::Message {
            sid: Sid::new(10),
            data: Bytes::from(&[188u8; 600][..]),
        };
        p2.0.send(event.clone()).await.unwrap();
        p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        let e = p1.1.recv().await.unwrap();
        assert_eq!(event, e);
    }

    /// On a stream without reliability promises, where the test drain drops
    /// about half of the chunks, at least `MIN_CHECK` messages still arrive.
    #[tokio::test]
    async fn unrealiable_test() {
        const MIN_CHECK: usize = 10;
        const COUNT: usize = 10_000;
        let [mut p1, mut p2] = quic_bound(
            COUNT * 2 - 1,
            0.5,
            None,
        );
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(1337),
            prio: 3u8,
            promises: Promises::empty(),
            guaranteed_bandwidth: 1_000_000,
        };
        p1.0.send(event).await.unwrap();
        let e = p2.1.recv().await.unwrap();
        p2.0.notify_from_recv(e);
        let event = ProtocolEvent::Message {
            sid: Sid::new(1337),
            data: Bytes::from(&[188u8; 600][..]),
        };
        for _ in 0..COUNT {
            p2.0.send(event.clone()).await.unwrap();
        }
        p2.0.flush(1_000_000_000, Duration::from_secs(1))
            .await
            .unwrap();
        for _ in 0..COUNT {
            p2.0.send(event.clone()).await.unwrap();
        }
        for _ in 0..MIN_CHECK {
            let e = p1.1.recv().await.unwrap();
            assert_eq!(event, e);
        }
    }
}