veloren_network/
channel.rs

1use crate::api::{ConnectAddr, NetworkConnectError};
2use async_trait::async_trait;
3use bytes::BytesMut;
4use futures_util::FutureExt;
5use hashbrown::HashMap;
6use network_protocol::{
7    Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid,
8    ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol,
9    TcpSendProtocol, UnreliableDrain, UnreliableSink,
10};
11#[cfg(feature = "quic")]
12use network_protocol::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol};
13use std::{
14    io,
15    net::SocketAddr,
16    sync::{
17        Arc,
18        atomic::{AtomicU64, Ordering},
19    },
20    time::Duration,
21};
22use tokio::{
23    io::{AsyncReadExt, AsyncWriteExt},
24    net,
25    net::tcp::{OwnedReadHalf, OwnedWriteHalf},
26    select,
27    sync::{Mutex, mpsc, oneshot},
28};
29use tracing::{error, info, trace, warn};
30
31#[derive(Debug)]
32pub(crate) enum Protocols {
33    Tcp((TcpSendProtocol<TcpDrain>, TcpRecvProtocol<TcpSink>)),
34    Mpsc((MpscSendProtocol<MpscDrain>, MpscRecvProtocol<MpscSink>)),
35    #[cfg(feature = "quic")]
36    Quic((QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>)),
37}
38
39#[derive(Debug)]
40pub(crate) enum SendProtocols {
41    Tcp(TcpSendProtocol<TcpDrain>),
42    Mpsc(MpscSendProtocol<MpscDrain>),
43    #[cfg(feature = "quic")]
44    Quic(QuicSendProtocol<QuicDrain>),
45}
46
47#[derive(Debug)]
48pub(crate) enum RecvProtocols {
49    Tcp(TcpRecvProtocol<TcpSink>),
50    Mpsc(MpscRecvProtocol<MpscSink>),
51    #[cfg(feature = "quic")]
52    Quic(QuicRecvProtocol<QuicSink>),
53}
54
55lazy_static::lazy_static! {
56    pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<C2cMpscConnect>>> = {
57        Mutex::new(HashMap::new())
58    };
59}
60
61pub(crate) type C2cMpscConnect = (
62    mpsc::Sender<MpscMsg>,
63    oneshot::Sender<mpsc::Sender<MpscMsg>>,
64);
65pub(crate) type C2sProtocol = (Protocols, ConnectAddr, Cid);
66
/// Renders `addr` for logging with parts of the IP address masked out.
///
/// For IPv4, the 2nd and 4th octets are replaced by `xxx`. For IPv6, segments 2, 3,
/// 6 and 7 (0-based) are replaced by `xxxx`, and the kept segments are printed as
/// zero-padded 4-digit hex. The port is always kept.
fn anonymize_addr(addr: &SocketAddr) -> String {
    use std::net::IpAddr;

    let port = addr.port();
    match addr.ip() {
        IpAddr::V6(v6) => {
            let [s0, s1, _, _, s4, s5, _, _] = v6.segments();
            format!(
                "[{s0:04x}:{s1:04x}:xxxx:xxxx:{s4:04x}:{s5:04x}:xxxx:xxxx]:{}",
                port
            )
        },
        IpAddr::V4(v4) => {
            let [o0, _, o2, _] = v4.octets();
            format!("{o0}.xxx.{o2}.xxx:{}", port)
        },
    }
}
83
84impl Protocols {
85    const MPSC_CHANNEL_BOUND: usize = 1000;
86
87    pub(crate) async fn with_tcp_connect(
88        addr: SocketAddr,
89        metrics: ProtocolMetricCache,
90    ) -> Result<Self, NetworkConnectError> {
91        let stream = net::TcpStream::connect(addr)
92            .await
93            .and_then(|s| {
94                s.set_nodelay(true)?;
95                Ok(s)
96            })
97            .map_err(NetworkConnectError::Io)?;
98        info!(
99            "Connecting Tcp to: {}",
100            stream.peer_addr().map_err(NetworkConnectError::Io)?
101        );
102        Ok(Self::new_tcp(stream, metrics))
103    }
104
105    pub(crate) async fn with_tcp_listen(
106        addr: SocketAddr,
107        cids: Arc<AtomicU64>,
108        metrics: Arc<ProtocolMetrics>,
109        s2s_stop_listening_r: oneshot::Receiver<()>,
110        c2s_protocol_s: mpsc::UnboundedSender<C2sProtocol>,
111    ) -> io::Result<()> {
112        use socket2::{Domain, Socket, Type};
113        let domain = Domain::for_address(addr);
114        let socket2_socket = Socket::new(domain, Type::STREAM, None)?;
115        if domain == Domain::IPV6 {
116            socket2_socket.set_only_v6(true)?
117        }
118        socket2_socket.set_nonblocking(true)?; // Needed by Tokio
119        // See https://docs.rs/tokio/latest/tokio/net/struct.TcpSocket.html
120        #[cfg(not(windows))]
121        socket2_socket.set_reuse_address(true)?;
122        const SEND_BUFFER_SIZE: usize = 262144;
123        const RECV_BUFFER_SIZE: usize = SEND_BUFFER_SIZE * 2;
124        if let Err(e) = socket2_socket.set_recv_buffer_size(RECV_BUFFER_SIZE) {
125            warn!(?e, "Couldn't set recv_buffer size")
126        };
127        if let Err(e) = socket2_socket.set_send_buffer_size(SEND_BUFFER_SIZE) {
128            warn!(?e, "Couldn't set set_buffer size")
129        };
130        let socket2_addr = addr.into();
131        socket2_socket.bind(&socket2_addr)?;
132        socket2_socket.listen(1024)?;
133        let std_listener: std::net::TcpListener = socket2_socket.into();
134        let listener = net::TcpListener::from_std(std_listener)?;
135        trace!(?addr, "Tcp Listener bound");
136        let mut end_receiver = s2s_stop_listening_r.fuse();
137        tokio::spawn(async move {
138            while let Some(data) = select! {
139                    next = listener.accept().fuse() => Some(next),
140                    _ = &mut end_receiver => None,
141            } {
142                let (stream, remote_addr) = match data {
143                    Ok((s, p)) => (s, p),
144                    Err(e) => {
145                        trace!(?e, "TcpStream Error, ignoring connection attempt");
146                        continue;
147                    },
148                };
149                if let Err(e) = stream.set_nodelay(true) {
150                    warn!(
151                        ?e,
152                        "Failed to set TCP_NODELAY, client may have degraded latency"
153                    );
154                }
155                let cid = cids.fetch_add(1, Ordering::Relaxed);
156                info!(
157                    remote_addr = anonymize_addr(&remote_addr),
158                    ?cid,
159                    "Accepting Tcp from"
160                );
161                let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
162                let _ = c2s_protocol_s.send((
163                    Self::new_tcp(stream, metrics.clone()),
164                    ConnectAddr::Tcp(remote_addr),
165                    cid,
166                ));
167            }
168        });
169        Ok(())
170    }
171
172    pub(crate) fn new_tcp(stream: net::TcpStream, metrics: ProtocolMetricCache) -> Self {
173        let (r, w) = stream.into_split();
174        let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone());
175        let rp = TcpRecvProtocol::new(
176            TcpSink {
177                half: r,
178                buffer: BytesMut::new(),
179            },
180            metrics,
181        );
182        Protocols::Tcp((sp, rp))
183    }
184
185    pub(crate) async fn with_mpsc_connect(
186        addr: u64,
187        metrics: ProtocolMetricCache,
188    ) -> Result<Self, NetworkConnectError> {
189        let mpsc_s = MPSC_POOL
190            .lock()
191            .await
192            .get(&addr)
193            .ok_or_else(|| {
194                NetworkConnectError::Io(io::Error::new(
195                    io::ErrorKind::NotConnected,
196                    "no mpsc listen on this addr",
197                ))
198            })?
199            .clone();
200        let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND);
201        let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel();
202        mpsc_s
203            .send((remote_to_local_s, local_to_remote_oneshot_s))
204            .map_err(|_| {
205                NetworkConnectError::Io(io::Error::new(
206                    io::ErrorKind::BrokenPipe,
207                    "mpsc pipe broke during connect",
208                ))
209            })?;
210        let local_to_remote_s = local_to_remote_oneshot_r
211            .await
212            .map_err(|e| NetworkConnectError::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;
213        info!(?addr, "Connecting Mpsc");
214        Ok(Self::new_mpsc(
215            local_to_remote_s,
216            remote_to_local_r,
217            metrics,
218        ))
219    }
220
221    pub(crate) async fn with_mpsc_listen(
222        addr: u64,
223        cids: Arc<AtomicU64>,
224        metrics: Arc<ProtocolMetrics>,
225        s2s_stop_listening_r: oneshot::Receiver<()>,
226        c2s_protocol_s: mpsc::UnboundedSender<C2sProtocol>,
227    ) -> io::Result<()> {
228        let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel();
229        MPSC_POOL.lock().await.insert(addr, mpsc_s);
230        trace!(?addr, "Mpsc Listener bound");
231        let mut end_receiver = s2s_stop_listening_r.fuse();
232        tokio::spawn(async move {
233            while let Some((local_to_remote_s, local_remote_to_local_s)) = select! {
234                    next = mpsc_r.recv().fuse() => next,
235                    _ = &mut end_receiver => None,
236            } {
237                let (remote_to_local_s, remote_to_local_r) =
238                    mpsc::channel(Self::MPSC_CHANNEL_BOUND);
239                if let Err(e) = local_remote_to_local_s.send(remote_to_local_s) {
240                    error!(?e, "mpsc listen aborted");
241                }
242
243                let cid = cids.fetch_add(1, Ordering::Relaxed);
244                info!(?addr, ?cid, "Accepting Mpsc from");
245                let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
246                let _ = c2s_protocol_s.send((
247                    Self::new_mpsc(local_to_remote_s, remote_to_local_r, metrics.clone()),
248                    ConnectAddr::Mpsc(addr),
249                    cid,
250                ));
251            }
252            warn!("MpscStream Failed, stopping");
253        });
254        Ok(())
255    }
256
257    pub(crate) fn new_mpsc(
258        sender: mpsc::Sender<MpscMsg>,
259        receiver: mpsc::Receiver<MpscMsg>,
260        metrics: ProtocolMetricCache,
261    ) -> Self {
262        let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone());
263        let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics);
264        Protocols::Mpsc((sp, rp))
265    }
266
267    #[cfg(feature = "quic")]
268    pub(crate) async fn with_quic_connect(
269        addr: SocketAddr,
270        config: quinn::ClientConfig,
271        name: String,
272        metrics: ProtocolMetricCache,
273    ) -> Result<Self, NetworkConnectError> {
274        let config = config.clone();
275
276        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
277
278        let bindsock = match addr {
279            SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
280            SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
281        };
282        let endpoint = match quinn::Endpoint::client(bindsock) {
283            Ok(e) => e,
284            Err(e) => return Err(NetworkConnectError::Io(e)),
285        };
286
287        info!("Connecting Quic to: {}", &addr);
288        let connecting = endpoint.connect_with(config, addr, &name).map_err(|e| {
289            trace!(?e, "error setting up quic");
290            NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
291        })?;
292        let connection = connecting.await.map_err(|e| {
293            trace!(?e, "error with quic connection");
294            NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
295        })?;
296        Self::new_quic(connection, false, metrics)
297            .await
298            .map_err(|e| {
299                trace!(?e, "error with quic");
300                NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
301            })
302    }
303
304    #[cfg(feature = "quic")]
305    pub(crate) async fn with_quic_listen(
306        addr: SocketAddr,
307        server_config: quinn::ServerConfig,
308        cids: Arc<AtomicU64>,
309        metrics: Arc<ProtocolMetrics>,
310        s2s_stop_listening_r: oneshot::Receiver<()>,
311        c2s_protocol_s: mpsc::UnboundedSender<C2sProtocol>,
312    ) -> io::Result<()> {
313        let endpoint = quinn::Endpoint::server(server_config, addr)?;
314        trace!(?addr, "Quic Listener bound");
315        let mut end_receiver = s2s_stop_listening_r.fuse();
316        tokio::spawn(async move {
317            while let Some(Some(connecting)) = select! {
318                next = endpoint.accept().fuse() => Some(next),
319                _ = &mut end_receiver => None,
320            } {
321                let remote_addr = anonymize_addr(&connecting.remote_address());
322                let connection = match connecting.await {
323                    Ok(c) => c,
324                    Err(e) => {
325                        tracing::debug!(?e, ?remote_addr, "skipping connection attempt");
326                        continue;
327                    },
328                };
329
330                let cid = cids.fetch_add(1, Ordering::Relaxed);
331                info!(?remote_addr, ?cid, "Accepting Quic from");
332                let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
333                match Protocols::new_quic(connection, true, metrics).await {
334                    Ok(quic) => {
335                        // TODO: we cannot guess the client hostname in quic server here.
336                        // though we need it for the certificate to be validated, in the future
337                        // this will either go away with new auth, or we have to do something like
338                        // a reverse DNS lookup
339                        let connect_addr = ConnectAddr::Quic(
340                            addr,
341                            quinn::ClientConfig::with_platform_verifier(),
342                            "TODO_remote_hostname".to_string(),
343                        );
344                        let _ = c2s_protocol_s.send((quic, connect_addr, cid));
345                    },
346                    Err(e) => {
347                        trace!(?e, "failed to start quic");
348                        continue;
349                    },
350                }
351            }
352        });
353        Ok(())
354    }
355
356    #[cfg(feature = "quic")]
357    pub(crate) async fn new_quic(
358        connection: quinn::Connection,
359        listen: bool,
360        metrics: ProtocolMetricCache,
361    ) -> Result<Self, quinn::ConnectionError> {
362        let (sendstream, recvstream) = if listen {
363            connection.open_bi().await?
364        } else {
365            connection
366                .accept_bi()
367                .await
368                .or(Err(quinn::ConnectionError::LocallyClosed))?
369        };
370        let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel();
371        let streams_s_clone = recvstreams_s.clone();
372        let (sendstreams_s, sendstreams_r) = mpsc::unbounded_channel();
373        let sp = QuicSendProtocol::new(
374            QuicDrain {
375                con: connection.clone(),
376                main: sendstream,
377                reliables: HashMap::new(),
378                recvstreams_s: streams_s_clone,
379                sendstreams_r,
380            },
381            metrics.clone(),
382        );
383        spawn_new(recvstream, None, &recvstreams_s);
384        let rp = QuicRecvProtocol::new(
385            QuicSink {
386                con: connection,
387                recvstreams_r,
388                recvstreams_s,
389                sendstreams_s,
390            },
391            metrics,
392        );
393        Ok(Protocols::Quic((sp, rp)))
394    }
395
396    pub(crate) fn split(self) -> (SendProtocols, RecvProtocols) {
397        match self {
398            Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)),
399            Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)),
400            #[cfg(feature = "quic")]
401            Protocols::Quic((s, r)) => (SendProtocols::Quic(s), RecvProtocols::Quic(r)),
402        }
403    }
404}
405
406#[async_trait]
407impl network_protocol::InitProtocol for Protocols {
408    type CustomErr = ProtocolsError;
409
410    async fn initialize(
411        &mut self,
412        initializer: bool,
413        local_pid: Pid,
414        secret: u128,
415    ) -> Result<(Pid, Sid, u128), InitProtocolError<Self::CustomErr>> {
416        match self {
417            Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await,
418            Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await,
419            #[cfg(feature = "quic")]
420            Protocols::Quic(p) => p.initialize(initializer, local_pid, secret).await,
421        }
422    }
423}
424
425#[async_trait]
426impl network_protocol::SendProtocol for SendProtocols {
427    type CustomErr = ProtocolsError;
428
429    fn notify_from_recv(&mut self, event: ProtocolEvent) {
430        match self {
431            SendProtocols::Tcp(s) => s.notify_from_recv(event),
432            SendProtocols::Mpsc(s) => s.notify_from_recv(event),
433            #[cfg(feature = "quic")]
434            SendProtocols::Quic(s) => s.notify_from_recv(event),
435        }
436    }
437
438    async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
439        match self {
440            SendProtocols::Tcp(s) => s.send(event).await,
441            SendProtocols::Mpsc(s) => s.send(event).await,
442            #[cfg(feature = "quic")]
443            SendProtocols::Quic(s) => s.send(event).await,
444        }
445    }
446
447    async fn flush(
448        &mut self,
449        bandwidth: Bandwidth,
450        dt: Duration,
451    ) -> Result<Bandwidth, ProtocolError<Self::CustomErr>> {
452        match self {
453            SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await,
454            SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await,
455            #[cfg(feature = "quic")]
456            SendProtocols::Quic(s) => s.flush(bandwidth, dt).await,
457        }
458    }
459}
460
461#[async_trait]
462impl network_protocol::RecvProtocol for RecvProtocols {
463    type CustomErr = ProtocolsError;
464
465    async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
466        match self {
467            RecvProtocols::Tcp(r) => r.recv().await,
468            RecvProtocols::Mpsc(r) => r.recv().await,
469            #[cfg(feature = "quic")]
470            RecvProtocols::Quic(r) => r.recv().await,
471        }
472    }
473}
474
475#[derive(Debug)]
476pub enum MpscError {
477    Send(mpsc::error::SendError<MpscMsg>),
478    Recv,
479}
480
481#[cfg(feature = "quic")]
482#[derive(Debug)]
483pub enum QuicError {
484    Send(io::Error),
485    Connection(quinn::ConnectionError),
486    Write(quinn::WriteError),
487    Read(quinn::ReadError),
488    InternalMpsc,
489}
490
491/// Error types for Protocols
492#[derive(Debug)]
493pub enum ProtocolsError {
494    Tcp(io::Error),
495    Udp(io::Error),
496    #[cfg(feature = "quic")]
497    Quic(QuicError),
498    Mpsc(MpscError),
499}
500
501///////////////////////////////////////
502// TCP
503#[derive(Debug)]
504pub struct TcpDrain {
505    half: OwnedWriteHalf,
506}
507
508#[derive(Debug)]
509pub struct TcpSink {
510    half: OwnedReadHalf,
511    buffer: BytesMut,
512}
513
514#[async_trait]
515impl UnreliableDrain for TcpDrain {
516    type CustomErr = ProtocolsError;
517    type DataFormat = BytesMut;
518
519    async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
520        self.half
521            .write_all(&data)
522            .await
523            .map_err(|e| ProtocolError::Custom(ProtocolsError::Tcp(e)))
524    }
525}
526
527#[async_trait]
528impl UnreliableSink for TcpSink {
529    type CustomErr = ProtocolsError;
530    type DataFormat = BytesMut;
531
532    async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
533        if self.buffer.capacity() < 1500 {
534            self.buffer.reserve(1500 * 4); // reserve multiple, so that we alloc less often
535        }
536        match self.half.read_buf(&mut self.buffer).await {
537            Ok(0) => Err(ProtocolError::Custom(ProtocolsError::Tcp(io::Error::new(
538                io::ErrorKind::BrokenPipe,
539                "read returned 0 bytes",
540            )))),
541            Ok(_) => Ok(self.buffer.split()),
542            Err(e) => Err(ProtocolError::Custom(ProtocolsError::Tcp(e))),
543        }
544    }
545}
546
547///////////////////////////////////////
548// MPSC
549#[derive(Debug)]
550pub struct MpscDrain {
551    sender: mpsc::Sender<MpscMsg>,
552}
553
554#[derive(Debug)]
555pub struct MpscSink {
556    receiver: mpsc::Receiver<MpscMsg>,
557}
558
559#[async_trait]
560impl UnreliableDrain for MpscDrain {
561    type CustomErr = ProtocolsError;
562    type DataFormat = MpscMsg;
563
564    async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
565        self.sender
566            .send(data)
567            .await
568            .map_err(|e| ProtocolError::Custom(ProtocolsError::Mpsc(MpscError::Send(e))))
569    }
570}
571
572#[async_trait]
573impl UnreliableSink for MpscSink {
574    type CustomErr = ProtocolsError;
575    type DataFormat = MpscMsg;
576
577    async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
578        self.receiver
579            .recv()
580            .await
581            .ok_or(ProtocolError::Custom(ProtocolsError::Mpsc(MpscError::Recv)))
582    }
583}
584
///////////////////////////////////////
// QUIC
/// Outcome of one read on a QUIC receive stream, passed from a reader task back
/// to [`QuicSink`].
///
/// The tuple is `(buffer, read result, the stream itself, stream id)`. A `None` id
/// is delivered as main-stream data.
#[cfg(feature = "quic")]
type QuicStream = (
    BytesMut,
    Result<Option<usize>, quinn::ReadError>,
    quinn::RecvStream,
    Option<Sid>,
);

/// Sending side of a QUIC transport.
///
/// Writes go either to the `main` stream or to one stream per reliable `Sid`
/// (`reliables`). Streams opened locally via `open_bi` have their receive half
/// handed to a reader task through `recvstreams_s`. Send halves of streams opened
/// by the peer arrive from [`QuicSink`] via `sendstreams_r`.
#[cfg(feature = "quic")]
#[derive(Debug)]
pub struct QuicDrain {
    con: quinn::Connection,
    main: quinn::SendStream,
    reliables: HashMap<Sid, quinn::SendStream>,
    recvstreams_s: mpsc::UnboundedSender<QuicStream>,
    sendstreams_r: mpsc::UnboundedReceiver<quinn::SendStream>,
}

/// Receiving side of a QUIC transport.
///
/// Accepts bi-streams opened by the peer on `con` and forwards their send halves
/// to [`QuicDrain`] through `sendstreams_s`. Read results from reader tasks are
/// collected on `recvstreams_r`; `recvstreams_s` is kept so reads can be re-armed.
#[cfg(feature = "quic")]
#[derive(Debug)]
pub struct QuicSink {
    con: quinn::Connection,
    recvstreams_r: mpsc::UnboundedReceiver<QuicStream>,
    recvstreams_s: mpsc::UnboundedSender<QuicStream>,
    sendstreams_s: mpsc::UnboundedSender<quinn::SendStream>,
}
613
/// Spawns a task that performs one read (into a 1500-byte zeroed buffer) on
/// `recvstream`.
///
/// The task then forwards `(buffer, read result, stream, sid)` through
/// `streams_s`. If the receiving side is already gone, the result is dropped.
#[cfg(feature = "quic")]
fn spawn_new(
    mut recvstream: quinn::RecvStream,
    sid: Option<Sid>,
    streams_s: &mpsc::UnboundedSender<QuicStream>,
) {
    let sender = streams_s.clone();
    tokio::spawn(async move {
        let mut buf = BytesMut::new();
        buf.resize(1500, 0u8);
        let outcome = recvstream.read(&mut buf).await;
        // Ignoring the error is deliberate: a closed channel means the sink is gone.
        let _ = sender.send((buf, outcome, recvstream, sid));
    });
}
628
#[cfg(feature = "quic")]
#[async_trait]
impl UnreliableDrain for QuicDrain {
    type CustomErr = ProtocolsError;
    type DataFormat = QuicDataFormat;

    /// Writes `data` to the QUIC stream selected by `data.stream`.
    ///
    /// The first write for a reliable `Sid` without a stream yet picks the stream:
    /// - An empty payload means the stream is created locally. `open_bi` is called,
    ///   the sid is written as a leading `u64`, and a reader is spawned for the
    ///   receive half.
    /// - Otherwise the next send half the peer opened is taken from
    ///   `sendstreams_r`. NOTE(review): nothing checks that this stream actually
    ///   belongs to `sid`; it is simply the next one the sink accepted.
    ///
    /// Unreliable data is not implemented and panics.
    async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
        use hashbrown::hash_map::Entry;

        let quic_err = |e: QuicError| ProtocolError::Custom(ProtocolsError::Quic(e));

        let target = match data.stream {
            QuicDataFormatStream::Main => &mut self.main,
            QuicDataFormatStream::Unreliable => unimplemented!(),
            QuicDataFormatStream::Reliable(sid) => match self.reliables.entry(sid) {
                Entry::Occupied(occupied) => occupied.into_mut(),
                Entry::Vacant(vacant) if data.data.is_empty() => {
                    // Created locally: we are allowed to open the stream ourselves.
                    let (mut sendstream, recvstream) = self
                        .con
                        .open_bi()
                        .await
                        .map_err(|e| quic_err(QuicError::Connection(e)))?;
                    // send SID as first msg
                    sendstream
                        .write_u64(sid.get_u64())
                        .await
                        .map_err(|e| quic_err(QuicError::Send(e)))?;
                    spawn_new(recvstream, Some(sid), &self.recvstreams_s);
                    vacant.insert(sendstream)
                },
                Entry::Vacant(vacant) => {
                    // Opened by the remote: block until the sink hands us its send half.
                    let sendstream = self
                        .sendstreams_r
                        .recv()
                        .await
                        .ok_or_else(|| quic_err(QuicError::InternalMpsc))?;
                    vacant.insert(sendstream)
                },
            },
        };

        target
            .write_all(&data.data)
            .await
            .map_err(|e| quic_err(QuicError::Write(e)))
    }
}
677
#[cfg(feature = "quic")]
#[async_trait]
impl UnreliableSink for QuicSink {
    type CustomErr = ProtocolsError;
    type DataFormat = QuicDataFormat;

    /// Returns the next chunk of data read from any of the connection's streams.
    ///
    /// New bi-streams opened by the peer are handled first (`biased`). Their
    /// leading `u64` is read as the stream id (`u64::MAX` means none), the send
    /// half is handed to the drain, and a reader is spawned for the receive half.
    /// Once a reader task delivers a result, the stream's next read is re-armed
    /// and the data is returned.
    ///
    /// # Errors
    /// - `Violated` if a new stream's id cannot be read.
    /// - `Quic(Connection)` if accepting a stream fails.
    /// - `Quic(InternalMpsc)` if the drain is gone.
    /// - `Quic(Send)` / `Quic(Read)` for empty or failed reads.
    async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
        let (mut buffer, result, mut recvstream, id) = loop {
            select! {
                // first handle all bi streams!
                biased;
                incoming = self.con.accept_bi().fuse() => {
                    let (sendstream, mut new_recvstream) = incoming.map_err(|e| {
                        ProtocolError::Custom(ProtocolsError::Quic(QuicError::Connection(e)))
                    })?;
                    let sid = match new_recvstream.read_u64().await {
                        Ok(u64::MAX) => None, //unreliable
                        Ok(raw) => Some(Sid::new(raw)),
                        Err(_) => return Err(ProtocolError::Violated),
                    };
                    if self.sendstreams_s.send(sendstream).is_err() {
                        return Err(ProtocolError::Custom(ProtocolsError::Quic(
                            QuicError::InternalMpsc,
                        )));
                    }
                    spawn_new(new_recvstream, sid, &self.recvstreams_s);
                },
                Some(ready) = self.recvstreams_r.recv().fuse() => break ready,
            }
        };

        let payload = match result {
            Ok(Some(0)) => Err(ProtocolError::Custom(ProtocolsError::Quic(
                QuicError::Send(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "read returned 0 bytes",
                )),
            ))),
            Ok(Some(n)) => Ok(QuicDataFormat {
                stream: id.map_or(QuicDataFormatStream::Main, QuicDataFormatStream::Reliable),
                data: buffer.split_to(n),
            }),
            Ok(None) => Err(ProtocolError::Custom(ProtocolsError::Quic(
                QuicError::Send(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "read returned None",
                )),
            ))),
            Err(e) => Err(ProtocolError::Custom(ProtocolsError::Quic(
                QuicError::Read(e),
            ))),
        }?;

        // Re-arm: read the next chunk of this stream in the background.
        let streams_s = self.recvstreams_s.clone();
        tokio::spawn(async move {
            buffer.resize(1500, 0u8);
            let next = recvstream.read(&mut buffer).await;
            let _ = streams_s.send((buffer, next, recvstream, id));
        });
        Ok(payload)
    }
}
750
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use network_protocol::{Promises, ProtocolMetrics, RecvProtocol, SendProtocol};
    use std::sync::Arc;
    use tokio::net::{TcpListener, TcpStream};

    /// Returns a connected `(client, server)` pair over loopback.
    ///
    /// The listener binds port 0, so the OS picks a free port. Fixed ports can
    /// collide with other tests or services and make the tests flaky.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let accept = tokio::spawn(async move {
            let (server, _) = listener.accept().await.unwrap();
            server
        });
        let client = TcpStream::connect(addr).await.unwrap();
        let server = accept.await.unwrap();
        (client, server)
    }

    #[tokio::test]
    async fn tokio_sinks() {
        let (client, server) = tcp_pair().await;
        let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap()));
        let client = Protocols::new_tcp(client, metrics.clone());
        let server = Protocols::new_tcp(server, metrics);
        let (mut s, _) = client.split();
        let (_, mut r) = server.split();
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(1),
            prio: 4u8,
            promises: Promises::GUARANTEED_DELIVERY,
            guaranteed_bandwidth: 1_000,
        };
        s.send(event.clone()).await.unwrap();
        s.send(ProtocolEvent::Message {
            sid: Sid::new(1),
            data: Bytes::from(&[8u8; 8][..]),
        })
        .await
        .unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        drop(s); // recv must work even after shutdown of send!
        tokio::time::sleep(Duration::from_secs(1)).await;
        let res = r.recv().await;
        match res {
            Ok(ProtocolEvent::OpenStream {
                sid,
                prio,
                promises,
                guaranteed_bandwidth: _,
            }) => {
                assert_eq!(sid, Sid::new(1));
                assert_eq!(prio, 4u8);
                assert_eq!(promises, Promises::GUARANTEED_DELIVERY);
            },
            _ => {
                panic!("wrong type {:?}", res);
            },
        }
        r.recv().await.unwrap();
    }

    #[tokio::test]
    async fn tokio_sink_stop_after_drop() {
        let (client, server) = tcp_pair().await;
        let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap()));
        let client = Protocols::new_tcp(client, metrics.clone());
        let server = Protocols::new_tcp(server, metrics);
        let (s, _) = client.split();
        let (_, mut r) = server.split();
        let e = tokio::spawn(async move { r.recv().await });
        drop(s);
        // Closing the sender must surface as a BrokenPipe TCP error on the receiver.
        match e.await.unwrap() {
            Err(ProtocolError::Custom(ProtocolsError::Tcp(e))) => {
                assert_eq!(e.kind(), io::ErrorKind::BrokenPipe)
            },
            other => panic!("invalid error: {:?}", other),
        }
    }
}