1use crate::api::{ConnectAddr, NetworkConnectError};
2use async_trait::async_trait;
3use bytes::BytesMut;
4use futures_util::FutureExt;
5use hashbrown::HashMap;
6use network_protocol::{
7 Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid,
8 ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol,
9 TcpSendProtocol, UnreliableDrain, UnreliableSink,
10};
11#[cfg(feature = "quic")]
12use network_protocol::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol};
13use std::{
14 io,
15 net::SocketAddr,
16 sync::{
17 Arc,
18 atomic::{AtomicU64, Ordering},
19 },
20 time::Duration,
21};
22use tokio::{
23 io::{AsyncReadExt, AsyncWriteExt},
24 net,
25 net::tcp::{OwnedReadHalf, OwnedWriteHalf},
26 select,
27 sync::{Mutex, mpsc, oneshot},
28};
29use tracing::{error, info, trace, warn};
30
/// A connected transport, holding both its send and its receive protocol half.
///
/// Use [`Protocols::split`] to separate the halves into [`SendProtocols`] and
/// [`RecvProtocols`].
#[derive(Debug)]
pub(crate) enum Protocols {
    /// TCP stream split into a write half (drain) and a read half (sink).
    Tcp((TcpSendProtocol<TcpDrain>, TcpRecvProtocol<TcpSink>)),
    /// In-process transport built on bounded tokio mpsc channels.
    Mpsc((MpscSendProtocol<MpscDrain>, MpscRecvProtocol<MpscSink>)),
    /// QUIC connection (only with the `quic` feature).
    #[cfg(feature = "quic")]
    Quic((QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>)),
}

/// The sending half of a [`Protocols`] value, as produced by [`Protocols::split`].
#[derive(Debug)]
pub(crate) enum SendProtocols {
    Tcp(TcpSendProtocol<TcpDrain>),
    Mpsc(MpscSendProtocol<MpscDrain>),
    #[cfg(feature = "quic")]
    Quic(QuicSendProtocol<QuicDrain>),
}

/// The receiving half of a [`Protocols`] value, as produced by [`Protocols::split`].
#[derive(Debug)]
pub(crate) enum RecvProtocols {
    Tcp(TcpRecvProtocol<TcpSink>),
    Mpsc(MpscRecvProtocol<MpscSink>),
    #[cfg(feature = "quic")]
    Quic(QuicRecvProtocol<QuicSink>),
}
54
lazy_static::lazy_static! {
    /// Process-wide registry of in-process (mpsc) listeners, keyed by mpsc address.
    ///
    /// `Protocols::with_mpsc_listen` inserts (or overwrites) an entry.
    /// `Protocols::with_mpsc_connect` looks one up to reach the listener.
    /// No code in this file removes entries.
    pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<C2cMpscConnect>>> = {
        Mutex::new(HashMap::new())
    };
}

/// Handshake message that a connecting mpsc client sends to a listener.
///
/// * `.0`: sender the listener uses to send messages to the client.
/// * `.1`: one-shot channel on which the listener answers with the sender the client uses
///   to send messages to the listener.
pub(crate) type C2cMpscConnect = (
    mpsc::Sender<MpscMsg>,
    oneshot::Sender<mpsc::Sender<MpscMsg>>,
);
/// A newly accepted connection handed from a listener task to its consumer.
///
/// Contains the protocol pair, the [`ConnectAddr`] associated with it, and its connection id.
pub(crate) type C2sProtocol = (Protocols, ConnectAddr, Cid);
66
/// Renders `addr` for logging with parts of the IP address masked out.
///
/// * IPv4 keeps the 1st and 3rd octet: `a.xxx.c.xxx:port`.
/// * IPv6 keeps segments 0, 1, 4 and 5 (zero-padded hex) and masks the rest:
///   `[s0:s1:xxxx:xxxx:s4:s5:xxxx:xxxx]:port`.
///
/// The port is always shown in full.
fn anonymize_addr(addr: &SocketAddr) -> String {
    let port = addr.port();
    match addr {
        SocketAddr::V4(v4) => {
            let octets = v4.ip().octets();
            format!("{}.xxx.{}.xxx:{}", octets[0], octets[2], port)
        },
        SocketAddr::V6(v6) => {
            let seg = v6.ip().segments();
            format!(
                "[{:04x}:{:04x}:xxxx:xxxx:{:04x}:{:04x}:xxxx:xxxx]:{}",
                seg[0], seg[1], seg[4], seg[5], port
            )
        },
    }
}
83
84impl Protocols {
85 const MPSC_CHANNEL_BOUND: usize = 1000;
86
87 pub(crate) async fn with_tcp_connect(
88 addr: SocketAddr,
89 metrics: ProtocolMetricCache,
90 ) -> Result<Self, NetworkConnectError> {
91 let stream = net::TcpStream::connect(addr)
92 .await
93 .and_then(|s| {
94 s.set_nodelay(true)?;
95 Ok(s)
96 })
97 .map_err(NetworkConnectError::Io)?;
98 info!(
99 "Connecting Tcp to: {}",
100 stream.peer_addr().map_err(NetworkConnectError::Io)?
101 );
102 Ok(Self::new_tcp(stream, metrics))
103 }
104
105 pub(crate) async fn with_tcp_listen(
106 addr: SocketAddr,
107 cids: Arc<AtomicU64>,
108 metrics: Arc<ProtocolMetrics>,
109 s2s_stop_listening_r: oneshot::Receiver<()>,
110 c2s_protocol_s: mpsc::UnboundedSender<C2sProtocol>,
111 ) -> io::Result<()> {
112 use socket2::{Domain, Socket, Type};
113 let domain = Domain::for_address(addr);
114 let socket2_socket = Socket::new(domain, Type::STREAM, None)?;
115 if domain == Domain::IPV6 {
116 socket2_socket.set_only_v6(true)?
117 }
118 socket2_socket.set_nonblocking(true)?; #[cfg(not(windows))]
121 socket2_socket.set_reuse_address(true)?;
122 const SEND_BUFFER_SIZE: usize = 262144;
123 const RECV_BUFFER_SIZE: usize = SEND_BUFFER_SIZE * 2;
124 if let Err(e) = socket2_socket.set_recv_buffer_size(RECV_BUFFER_SIZE) {
125 warn!(?e, "Couldn't set recv_buffer size")
126 };
127 if let Err(e) = socket2_socket.set_send_buffer_size(SEND_BUFFER_SIZE) {
128 warn!(?e, "Couldn't set set_buffer size")
129 };
130 let socket2_addr = addr.into();
131 socket2_socket.bind(&socket2_addr)?;
132 socket2_socket.listen(1024)?;
133 let std_listener: std::net::TcpListener = socket2_socket.into();
134 let listener = net::TcpListener::from_std(std_listener)?;
135 trace!(?addr, "Tcp Listener bound");
136 let mut end_receiver = s2s_stop_listening_r.fuse();
137 tokio::spawn(async move {
138 while let Some(data) = select! {
139 next = listener.accept().fuse() => Some(next),
140 _ = &mut end_receiver => None,
141 } {
142 let (stream, remote_addr) = match data {
143 Ok((s, p)) => (s, p),
144 Err(e) => {
145 trace!(?e, "TcpStream Error, ignoring connection attempt");
146 continue;
147 },
148 };
149 if let Err(e) = stream.set_nodelay(true) {
150 warn!(
151 ?e,
152 "Failed to set TCP_NODELAY, client may have degraded latency"
153 );
154 }
155 let cid = cids.fetch_add(1, Ordering::Relaxed);
156 info!(
157 remote_addr = anonymize_addr(&remote_addr),
158 ?cid,
159 "Accepting Tcp from"
160 );
161 let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
162 let _ = c2s_protocol_s.send((
163 Self::new_tcp(stream, metrics.clone()),
164 ConnectAddr::Tcp(remote_addr),
165 cid,
166 ));
167 }
168 });
169 Ok(())
170 }
171
172 pub(crate) fn new_tcp(stream: net::TcpStream, metrics: ProtocolMetricCache) -> Self {
173 let (r, w) = stream.into_split();
174 let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone());
175 let rp = TcpRecvProtocol::new(
176 TcpSink {
177 half: r,
178 buffer: BytesMut::new(),
179 },
180 metrics,
181 );
182 Protocols::Tcp((sp, rp))
183 }
184
185 pub(crate) async fn with_mpsc_connect(
186 addr: u64,
187 metrics: ProtocolMetricCache,
188 ) -> Result<Self, NetworkConnectError> {
189 let mpsc_s = MPSC_POOL
190 .lock()
191 .await
192 .get(&addr)
193 .ok_or_else(|| {
194 NetworkConnectError::Io(io::Error::new(
195 io::ErrorKind::NotConnected,
196 "no mpsc listen on this addr",
197 ))
198 })?
199 .clone();
200 let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND);
201 let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel();
202 mpsc_s
203 .send((remote_to_local_s, local_to_remote_oneshot_s))
204 .map_err(|_| {
205 NetworkConnectError::Io(io::Error::new(
206 io::ErrorKind::BrokenPipe,
207 "mpsc pipe broke during connect",
208 ))
209 })?;
210 let local_to_remote_s = local_to_remote_oneshot_r
211 .await
212 .map_err(|e| NetworkConnectError::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;
213 info!(?addr, "Connecting Mpsc");
214 Ok(Self::new_mpsc(
215 local_to_remote_s,
216 remote_to_local_r,
217 metrics,
218 ))
219 }
220
221 pub(crate) async fn with_mpsc_listen(
222 addr: u64,
223 cids: Arc<AtomicU64>,
224 metrics: Arc<ProtocolMetrics>,
225 s2s_stop_listening_r: oneshot::Receiver<()>,
226 c2s_protocol_s: mpsc::UnboundedSender<C2sProtocol>,
227 ) -> io::Result<()> {
228 let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel();
229 MPSC_POOL.lock().await.insert(addr, mpsc_s);
230 trace!(?addr, "Mpsc Listener bound");
231 let mut end_receiver = s2s_stop_listening_r.fuse();
232 tokio::spawn(async move {
233 while let Some((local_to_remote_s, local_remote_to_local_s)) = select! {
234 next = mpsc_r.recv().fuse() => next,
235 _ = &mut end_receiver => None,
236 } {
237 let (remote_to_local_s, remote_to_local_r) =
238 mpsc::channel(Self::MPSC_CHANNEL_BOUND);
239 if let Err(e) = local_remote_to_local_s.send(remote_to_local_s) {
240 error!(?e, "mpsc listen aborted");
241 }
242
243 let cid = cids.fetch_add(1, Ordering::Relaxed);
244 info!(?addr, ?cid, "Accepting Mpsc from");
245 let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
246 let _ = c2s_protocol_s.send((
247 Self::new_mpsc(local_to_remote_s, remote_to_local_r, metrics.clone()),
248 ConnectAddr::Mpsc(addr),
249 cid,
250 ));
251 }
252 warn!("MpscStream Failed, stopping");
253 });
254 Ok(())
255 }
256
257 pub(crate) fn new_mpsc(
258 sender: mpsc::Sender<MpscMsg>,
259 receiver: mpsc::Receiver<MpscMsg>,
260 metrics: ProtocolMetricCache,
261 ) -> Self {
262 let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone());
263 let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics);
264 Protocols::Mpsc((sp, rp))
265 }
266
267 #[cfg(feature = "quic")]
268 pub(crate) async fn with_quic_connect(
269 addr: SocketAddr,
270 config: quinn::ClientConfig,
271 name: String,
272 metrics: ProtocolMetricCache,
273 ) -> Result<Self, NetworkConnectError> {
274 let config = config.clone();
275
276 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
277
278 let bindsock = match addr {
279 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
280 SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
281 };
282 let endpoint = match quinn::Endpoint::client(bindsock) {
283 Ok(e) => e,
284 Err(e) => return Err(NetworkConnectError::Io(e)),
285 };
286
287 info!("Connecting Quic to: {}", &addr);
288 let connecting = endpoint.connect_with(config, addr, &name).map_err(|e| {
289 trace!(?e, "error setting up quic");
290 NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
291 })?;
292 let connection = connecting.await.map_err(|e| {
293 trace!(?e, "error with quic connection");
294 NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
295 })?;
296 Self::new_quic(connection, false, metrics)
297 .await
298 .map_err(|e| {
299 trace!(?e, "error with quic");
300 NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
301 })
302 }
303
304 #[cfg(feature = "quic")]
305 pub(crate) async fn with_quic_listen(
306 addr: SocketAddr,
307 server_config: quinn::ServerConfig,
308 cids: Arc<AtomicU64>,
309 metrics: Arc<ProtocolMetrics>,
310 s2s_stop_listening_r: oneshot::Receiver<()>,
311 c2s_protocol_s: mpsc::UnboundedSender<C2sProtocol>,
312 ) -> io::Result<()> {
313 let endpoint = quinn::Endpoint::server(server_config, addr)?;
314 trace!(?addr, "Quic Listener bound");
315 let mut end_receiver = s2s_stop_listening_r.fuse();
316 let config = quinn::ClientConfig::try_with_platform_verifier()
317 .map_err(|e| io::Error::other(Box::new(e)))?;
318 tokio::spawn(async move {
319 while let Some(Some(connecting)) = select! {
320 next = endpoint.accept().fuse() => Some(next),
321 _ = &mut end_receiver => None,
322 } {
323 let remote_addr = anonymize_addr(&connecting.remote_address());
324 let connection = match connecting.await {
325 Ok(c) => c,
326 Err(e) => {
327 tracing::debug!(?e, ?remote_addr, "skipping connection attempt");
328 continue;
329 },
330 };
331
332 let cid = cids.fetch_add(1, Ordering::Relaxed);
333 info!(?remote_addr, ?cid, "Accepting Quic from");
334 let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
335 match Protocols::new_quic(connection, true, metrics).await {
336 Ok(quic) => {
337 let connect_addr = ConnectAddr::Quic(
342 addr,
343 config.clone(),
344 "TODO_remote_hostname".to_string(),
345 );
346 let _ = c2s_protocol_s.send((quic, connect_addr, cid));
347 },
348 Err(e) => {
349 trace!(?e, "failed to start quic");
350 continue;
351 },
352 }
353 }
354 });
355 Ok(())
356 }
357
358 #[cfg(feature = "quic")]
359 pub(crate) async fn new_quic(
360 connection: quinn::Connection,
361 listen: bool,
362 metrics: ProtocolMetricCache,
363 ) -> Result<Self, quinn::ConnectionError> {
364 let (sendstream, recvstream) = if listen {
365 connection.open_bi().await?
366 } else {
367 connection
368 .accept_bi()
369 .await
370 .or(Err(quinn::ConnectionError::LocallyClosed))?
371 };
372 let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel();
373 let streams_s_clone = recvstreams_s.clone();
374 let (sendstreams_s, sendstreams_r) = mpsc::unbounded_channel();
375 let sp = QuicSendProtocol::new(
376 QuicDrain {
377 con: connection.clone(),
378 main: sendstream,
379 reliables: HashMap::new(),
380 recvstreams_s: streams_s_clone,
381 sendstreams_r,
382 },
383 metrics.clone(),
384 );
385 spawn_new(recvstream, None, &recvstreams_s);
386 let rp = QuicRecvProtocol::new(
387 QuicSink {
388 con: connection,
389 recvstreams_r,
390 recvstreams_s,
391 sendstreams_s,
392 },
393 metrics,
394 );
395 Ok(Protocols::Quic((sp, rp)))
396 }
397
398 pub(crate) fn split(self) -> (SendProtocols, RecvProtocols) {
399 match self {
400 Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)),
401 Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)),
402 #[cfg(feature = "quic")]
403 Protocols::Quic((s, r)) => (SendProtocols::Quic(s), RecvProtocols::Quic(r)),
404 }
405 }
406}
407
408#[async_trait]
409impl network_protocol::InitProtocol for Protocols {
410 type CustomErr = ProtocolsError;
411
412 async fn initialize(
413 &mut self,
414 initializer: bool,
415 local_pid: Pid,
416 secret: u128,
417 ) -> Result<(Pid, Sid, u128), InitProtocolError<Self::CustomErr>> {
418 match self {
419 Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await,
420 Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await,
421 #[cfg(feature = "quic")]
422 Protocols::Quic(p) => p.initialize(initializer, local_pid, secret).await,
423 }
424 }
425}
426
427#[async_trait]
428impl network_protocol::SendProtocol for SendProtocols {
429 type CustomErr = ProtocolsError;
430
431 fn notify_from_recv(&mut self, event: ProtocolEvent) {
432 match self {
433 SendProtocols::Tcp(s) => s.notify_from_recv(event),
434 SendProtocols::Mpsc(s) => s.notify_from_recv(event),
435 #[cfg(feature = "quic")]
436 SendProtocols::Quic(s) => s.notify_from_recv(event),
437 }
438 }
439
440 async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
441 match self {
442 SendProtocols::Tcp(s) => s.send(event).await,
443 SendProtocols::Mpsc(s) => s.send(event).await,
444 #[cfg(feature = "quic")]
445 SendProtocols::Quic(s) => s.send(event).await,
446 }
447 }
448
449 async fn flush(
450 &mut self,
451 bandwidth: Bandwidth,
452 dt: Duration,
453 ) -> Result<Bandwidth, ProtocolError<Self::CustomErr>> {
454 match self {
455 SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await,
456 SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await,
457 #[cfg(feature = "quic")]
458 SendProtocols::Quic(s) => s.flush(bandwidth, dt).await,
459 }
460 }
461}
462
463#[async_trait]
464impl network_protocol::RecvProtocol for RecvProtocols {
465 type CustomErr = ProtocolsError;
466
467 async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
468 match self {
469 RecvProtocols::Tcp(r) => r.recv().await,
470 RecvProtocols::Mpsc(r) => r.recv().await,
471 #[cfg(feature = "quic")]
472 RecvProtocols::Quic(r) => r.recv().await,
473 }
474 }
475}
476
/// Errors raised by the in-process (mpsc) transport.
#[derive(Debug)]
pub enum MpscError {
    /// Sending into the peer's channel failed because its receiver is gone.
    Send(mpsc::error::SendError<MpscMsg>),
    /// The channel from the peer was closed.
    Recv,
}

/// Errors raised by the QUIC transport.
#[cfg(feature = "quic")]
#[derive(Debug)]
pub enum QuicError {
    /// I/O error. Used both when writing a stream's sid header fails and for the
    /// synthesized `BrokenPipe` when a read returns 0 bytes or `None`.
    Send(io::Error),
    /// Opening or accepting a stream on the connection failed.
    Connection(quinn::ConnectionError),
    /// Writing payload data to a stream failed.
    Write(quinn::WriteError),
    /// Reading from a stream failed.
    Read(quinn::ReadError),
    /// An internal channel between `QuicSink` and `QuicDrain` was closed.
    InternalMpsc,
}

/// Transport-specific error used as the `CustomErr` of all protocol wrappers in this file.
#[derive(Debug)]
pub enum ProtocolsError {
    Tcp(io::Error),
    /// Not constructed anywhere in this file.
    Udp(io::Error),
    #[cfg(feature = "quic")]
    Quic(QuicError),
    Mpsc(MpscError),
}
502
/// Sending side of a TCP connection: writes raw bytes to the stream's write half.
#[derive(Debug)]
pub struct TcpDrain {
    half: OwnedWriteHalf,
}

/// Receiving side of a TCP connection: reads raw bytes from the stream's read half.
#[derive(Debug)]
pub struct TcpSink {
    half: OwnedReadHalf,
    // Read buffer. `recv` splits off the filled bytes and keeps the spare capacity.
    buffer: BytesMut,
}
515
516#[async_trait]
517impl UnreliableDrain for TcpDrain {
518 type CustomErr = ProtocolsError;
519 type DataFormat = BytesMut;
520
521 async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
522 self.half
523 .write_all(&data)
524 .await
525 .map_err(|e| ProtocolError::Custom(ProtocolsError::Tcp(e)))
526 }
527}
528
529#[async_trait]
530impl UnreliableSink for TcpSink {
531 type CustomErr = ProtocolsError;
532 type DataFormat = BytesMut;
533
534 async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
535 if self.buffer.capacity() < 1500 {
536 self.buffer.reserve(1500 * 4); }
538 match self.half.read_buf(&mut self.buffer).await {
539 Ok(0) => Err(ProtocolError::Custom(ProtocolsError::Tcp(io::Error::new(
540 io::ErrorKind::BrokenPipe,
541 "read returned 0 bytes",
542 )))),
543 Ok(_) => Ok(self.buffer.split()),
544 Err(e) => Err(ProtocolError::Custom(ProtocolsError::Tcp(e))),
545 }
546 }
547}
548
/// Sending side of the in-process transport: pushes messages into a bounded channel.
#[derive(Debug)]
pub struct MpscDrain {
    sender: mpsc::Sender<MpscMsg>,
}

/// Receiving side of the in-process transport: pulls messages from a bounded channel.
#[derive(Debug)]
pub struct MpscSink {
    receiver: mpsc::Receiver<MpscMsg>,
}
560
561#[async_trait]
562impl UnreliableDrain for MpscDrain {
563 type CustomErr = ProtocolsError;
564 type DataFormat = MpscMsg;
565
566 async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
567 self.sender
568 .send(data)
569 .await
570 .map_err(|e| ProtocolError::Custom(ProtocolsError::Mpsc(MpscError::Send(e))))
571 }
572}
573
574#[async_trait]
575impl UnreliableSink for MpscSink {
576 type CustomErr = ProtocolsError;
577 type DataFormat = MpscMsg;
578
579 async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
580 self.receiver
581 .recv()
582 .await
583 .ok_or(ProtocolError::Custom(ProtocolsError::Mpsc(MpscError::Recv)))
584 }
585}
586
/// Outcome of one background read on a QUIC receive stream, as sent by `spawn_new`.
///
/// Fields, in order:
/// * the read buffer;
/// * the read result;
/// * the stream itself, so it can be read again;
/// * the stream id: `None` for the main stream, or for an accepted stream whose header
///   was `u64::MAX`.
#[cfg(feature = "quic")]
type QuicStream = (
    BytesMut,
    Result<Option<usize>, quinn::ReadError>,
    quinn::RecvStream,
    Option<Sid>,
);

/// Sending side of a QUIC connection.
#[cfg(feature = "quic")]
#[derive(Debug)]
pub struct QuicDrain {
    // Connection, used to open new bidirectional streams.
    con: quinn::Connection,
    // Send half of the main stream created in `Protocols::new_quic`.
    main: quinn::SendStream,
    // Send halves of reliable streams, keyed by stream id.
    reliables: HashMap<Sid, quinn::SendStream>,
    // Used to spawn readers for the receive halves of streams opened locally.
    recvstreams_s: mpsc::UnboundedSender<QuicStream>,
    // Send halves of remote-opened streams, handed over by `QuicSink::recv`.
    sendstreams_r: mpsc::UnboundedReceiver<quinn::SendStream>,
}

/// Receiving side of a QUIC connection.
#[cfg(feature = "quic")]
#[derive(Debug)]
pub struct QuicSink {
    // Connection, used to accept remote-opened bidirectional streams.
    con: quinn::Connection,
    // Completed background reads from all known streams.
    recvstreams_r: mpsc::UnboundedReceiver<QuicStream>,
    // Used to schedule further background reads.
    recvstreams_s: mpsc::UnboundedSender<QuicStream>,
    // Hands send halves of accepted streams to `QuicDrain`.
    sendstreams_s: mpsc::UnboundedSender<quinn::SendStream>,
}
615
/// Spawns a task that performs one read on `recvstream` into a fresh 1500-byte buffer.
///
/// The task reports the buffer, the read result, the stream and `sid` through
/// `streams_s`. If the receiving end of `streams_s` is already gone, the report is
/// silently dropped.
#[cfg(feature = "quic")]
fn spawn_new(
    mut recvstream: quinn::RecvStream,
    sid: Option<Sid>,
    streams_s: &mpsc::UnboundedSender<QuicStream>,
) {
    let reply_to = streams_s.clone();
    let read_once = async move {
        let mut buf = BytesMut::new();
        buf.resize(1500, 0u8);
        let outcome = recvstream.read(&mut buf).await;
        // Deliberately ignored: a closed channel means nobody consumes reads anymore.
        let _ = reply_to.send((buf, outcome, recvstream, sid));
    };
    tokio::spawn(read_once);
}
630
#[cfg(feature = "quic")]
#[async_trait]
impl UnreliableDrain for QuicDrain {
    type CustomErr = ProtocolsError;
    type DataFormat = QuicDataFormat;

    /// Writes `data.data` to the QUIC stream selected by `data.stream`.
    ///
    /// * `Main`: written to the main stream's send half.
    /// * `Unreliable`: not implemented; this panics.
    /// * `Reliable(sid)`: written to the send half cached in `reliables`. If none is
    ///   cached for `sid` yet, it depends on the payload:
    ///   - Empty payload: a new bidirectional stream is opened locally and `sid` is
    ///     written to it as a `u64` header. A background reader is spawned for its
    ///     receive half, and the send half is cached.
    ///   - Non-empty payload: the next send half handed over by `QuicSink` (a stream the
    ///     remote opened) is awaited and cached.
    ///
    /// # Errors
    /// * `QuicError::Connection` if opening a stream fails.
    /// * `QuicError::Send` if writing the sid header fails.
    /// * `QuicError::InternalMpsc` if the hand-over channel from `QuicSink` is closed.
    /// * `QuicError::Write` for any failed payload write.
    async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
        match data.stream {
            QuicDataFormatStream::Main => self.main.write_all(&data.data).await,
            // Panics: unreliable sending is not supported by this drain.
            QuicDataFormatStream::Unreliable => unimplemented!(),
            QuicDataFormatStream::Reliable(sid) => {
                use hashbrown::hash_map::Entry;
                match self.reliables.entry(sid) {
                    // Known stream: write directly to its cached send half.
                    Entry::Occupied(mut occupied) => occupied.get_mut().write_all(&data.data).await,
                    Entry::Vacant(vacant) => {
                        // NOTE(review): presumably the protocol emits an empty write to
                        // announce a stream this side opens — confirm against the
                        // protocol crate.
                        if data.data.is_empty() {
                            let (mut sendstream, recvstream) =
                                self.con.open_bi().await.map_err(|e| {
                                    ProtocolError::Custom(ProtocolsError::Quic(
                                        QuicError::Connection(e),
                                    ))
                                })?;
                            // Announce the stream id first; `QuicSink::recv` reads this
                            // u64 header from every stream it accepts.
                            sendstream.write_u64(sid.get_u64()).await.map_err(|e| {
                                ProtocolError::Custom(ProtocolsError::Quic(QuicError::Send(e)))
                            })?;
                            spawn_new(recvstream, Some(sid), &self.recvstreams_s);
                            vacant.insert(sendstream).write_all(&data.data).await
                        } else {
                            // The remote opened this stream; take its send half from `QuicSink`.
                            // NOTE(review): streams are taken in arrival order and not checked
                            // against `sid` — verify this pairing cannot mix up streams.
                            let sendstream =
                                self.sendstreams_r
                                    .recv()
                                    .await
                                    .ok_or(ProtocolError::Custom(ProtocolsError::Quic(
                                        QuicError::InternalMpsc,
                                    )))?;
                            vacant.insert(sendstream).write_all(&data.data).await
                        }
                    },
                }
            },
        }
        .map_err(|e| ProtocolError::Custom(ProtocolsError::Quic(QuicError::Write(e))))
    }
}
679
#[cfg(feature = "quic")]
#[async_trait]
impl UnreliableSink for QuicSink {
    type CustomErr = ProtocolsError;
    type DataFormat = QuicDataFormat;

    /// Returns the next chunk of data read from any of the connection's QUIC streams.
    ///
    /// Reads are done by background tasks, which deliver their results through
    /// `recvstreams_r`. While waiting, each newly accepted remote bidirectional stream is
    /// registered in three steps:
    /// 1. Its `u64` sid header is read; `u64::MAX` means no sid.
    /// 2. Its send half is handed to `QuicDrain`.
    /// 3. A reader is spawned for its receive half.
    ///
    /// After a successful read, the next read on the same stream is scheduled, reusing
    /// the rest of the buffer.
    ///
    /// # Errors
    /// * `QuicError::Connection` if accepting a remote stream fails.
    /// * `ProtocolError::Violated` if an accepted stream's sid header can't be read.
    /// * `QuicError::InternalMpsc` if the hand-over channel to `QuicDrain` is closed.
    /// * `QuicError::Send` with `BrokenPipe` if a read returned 0 bytes or `None`
    ///   (presumably the stream was finished).
    /// * `QuicError::Read` if a read failed.
    async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
        // Loop until a background read has completed. Remote streams accepted in the
        // meantime are registered along the way.
        let (mut buffer, result, mut recvstream, id) = loop {
            use futures_util::FutureExt;
            let (a, b) = select! {
                // `biased`: check for new remote streams before delivering read results.
                biased;
                n = self.con.accept_bi().fuse() => (Some(n), None),
                Some(n) = self.recvstreams_r.recv().fuse() => (None, Some(n)),
            };

            if let Some(remote_stream) = a {
                let (sendstream, mut recvstream) = remote_stream.map_err(|e| {
                    ProtocolError::Custom(ProtocolsError::Quic(QuicError::Connection(e)))
                })?;
                // First value on an accepted stream is its sid; `u64::MAX` maps to no sid,
                // which is later reported as the main stream.
                let sid = match recvstream.read_u64().await {
                    Ok(u64::MAX) => None,
                    Ok(sid) => Some(Sid::new(sid)),
                    Err(_) => return Err(ProtocolError::Violated),
                };
                // `QuicDrain` picks this send half up on its first write to an unknown sid.
                if self.sendstreams_s.send(sendstream).is_err() {
                    return Err(ProtocolError::Custom(ProtocolsError::Quic(
                        QuicError::InternalMpsc,
                    )));
                }
                spawn_new(recvstream, sid, &self.recvstreams_s);
            }

            if let Some(data) = b {
                break data;
            }
        };

        // Turn the completed read into a data chunk, or bail out with an error.
        let r = match result {
            Ok(Some(0)) => Err(ProtocolError::Custom(ProtocolsError::Quic(
                QuicError::Send(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "read returned 0 bytes",
                )),
            ))),
            Ok(Some(n)) => Ok(QuicDataFormat {
                stream: match id {
                    Some(id) => QuicDataFormatStream::Reliable(id),
                    None => QuicDataFormatStream::Main,
                },
                data: buffer.split_to(n),
            }),
            Ok(None) => Err(ProtocolError::Custom(ProtocolsError::Quic(
                QuicError::Send(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "read returned None",
                )),
            ))),
            Err(e) => Err(ProtocolError::Custom(ProtocolsError::Quic(
                QuicError::Read(e),
            ))),
        }?;

        // Re-arm this stream: schedule its next read, reusing the remaining buffer.
        // NOTE(review): the error paths above return before this point, so a stream that
        // failed or finished is not read again — confirm that is intended.
        let streams_s_clone = self.recvstreams_s.clone();
        tokio::spawn(async move {
            buffer.resize(1500, 0u8);
            let r = recvstream.read(&mut buffer).await;
            let _ = streams_s_clone.send((buffer, r, recvstream, id));
        });
        Ok(r)
    }
}
752
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use network_protocol::{Promises, ProtocolMetrics, RecvProtocol, SendProtocol};
    use std::sync::Arc;
    use tokio::net::{TcpListener, TcpStream};

    /// Creates a connected `(client, server)` pair of TCP streams on loopback.
    ///
    /// Binds to port 0 so the OS picks a free port. Hard-coded ports fail when the port
    /// is already taken or when tests run in parallel.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let accept = tokio::spawn(async move {
            let (server, _) = listener.accept().await.unwrap();
            server
        });
        let client = TcpStream::connect(addr).await.unwrap();
        let server = accept.await.unwrap();
        (client, server)
    }

    /// Events sent and flushed on one TCP endpoint arrive in order on the other, even
    /// after the sender has been dropped.
    #[tokio::test]
    async fn tokio_sinks() {
        let (client, server) = tcp_pair().await;
        let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap()));
        let client = Protocols::new_tcp(client, metrics.clone());
        let server = Protocols::new_tcp(server, metrics);
        let (mut s, _) = client.split();
        let (_, mut r) = server.split();
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(1),
            prio: 4u8,
            promises: Promises::GUARANTEED_DELIVERY,
            guaranteed_bandwidth: 1_000,
        };
        s.send(event).await.unwrap();
        s.send(ProtocolEvent::Message {
            sid: Sid::new(1),
            data: Bytes::from(&[8u8; 8][..]),
        })
        .await
        .unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        // Dropping the sender closes the write half; already flushed data must still arrive.
        drop(s);
        tokio::time::sleep(Duration::from_secs(1)).await;
        let res = r.recv().await;
        match res {
            Ok(ProtocolEvent::OpenStream {
                sid,
                prio,
                promises,
                guaranteed_bandwidth: _,
            }) => {
                assert_eq!(sid, Sid::new(1));
                assert_eq!(prio, 4u8);
                assert_eq!(promises, Promises::GUARANTEED_DELIVERY);
            },
            _ => {
                panic!("wrong type {:?}", res);
            },
        }
        r.recv().await.unwrap();
    }

    /// Dropping the send half makes a pending receive on the peer fail with a
    /// `BrokenPipe` TCP error.
    #[tokio::test]
    async fn tokio_sink_stop_after_drop() {
        let (client, server) = tcp_pair().await;
        let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap()));
        let client = Protocols::new_tcp(client, metrics.clone());
        let server = Protocols::new_tcp(server, metrics);
        let (s, _) = client.split();
        let (_, mut r) = server.split();
        let e = tokio::spawn(async move { r.recv().await });
        drop(s);
        let e = e.await.unwrap().unwrap_err();
        match e {
            ProtocolError::Custom(ProtocolsError::Tcp(e)) => {
                assert_eq!(e.kind(), io::ErrorKind::BrokenPipe)
            },
            _ => panic!("invalid error"),
        }
    }
}