veloren_network/
channel.rs

1use crate::api::{ConnectAddr, NetworkConnectError};
2use async_trait::async_trait;
3use bytes::BytesMut;
4use futures_util::FutureExt;
5use hashbrown::HashMap;
6use network_protocol::{
7    Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid,
8    ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol,
9    TcpSendProtocol, UnreliableDrain, UnreliableSink,
10};
11#[cfg(feature = "quic")]
12use network_protocol::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol};
13use std::{
14    io,
15    net::SocketAddr,
16    sync::{
17        Arc,
18        atomic::{AtomicU64, Ordering},
19    },
20    time::Duration,
21};
22use tokio::{
23    io::{AsyncReadExt, AsyncWriteExt},
24    net,
25    net::tcp::{OwnedReadHalf, OwnedWriteHalf},
26    select,
27    sync::{Mutex, mpsc, oneshot},
28};
29use tracing::{error, info, trace, warn};
30
/// A connected transport that still holds both its send and its receive
/// protocol half.
///
/// Call `split` to separate it into a [`SendProtocols`] / [`RecvProtocols`]
/// pair. The `Quic` variant only exists when the `quic` feature is enabled.
#[derive(Debug)]
pub(crate) enum Protocols {
    Tcp((TcpSendProtocol<TcpDrain>, TcpRecvProtocol<TcpSink>)),
    Mpsc((MpscSendProtocol<MpscDrain>, MpscRecvProtocol<MpscSink>)),
    #[cfg(feature = "quic")]
    Quic((QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>)),
}
38
/// The sending half of a [`Protocols`], one variant per transport.
#[derive(Debug)]
pub(crate) enum SendProtocols {
    Tcp(TcpSendProtocol<TcpDrain>),
    Mpsc(MpscSendProtocol<MpscDrain>),
    #[cfg(feature = "quic")]
    Quic(QuicSendProtocol<QuicDrain>),
}
46
/// The receiving half of a [`Protocols`], one variant per transport.
#[derive(Debug)]
pub(crate) enum RecvProtocols {
    Tcp(TcpRecvProtocol<TcpSink>),
    Mpsc(MpscRecvProtocol<MpscSink>),
    #[cfg(feature = "quic")]
    Quic(QuicRecvProtocol<QuicSink>),
}
54
// Process-wide registry of in-process (mpsc) listeners, keyed by the `u64`
// address given to `Protocols::with_mpsc_listen`. `Protocols::with_mpsc_connect`
// looks the address up here and sends the listener a `C2cMpscConnect`.
// NOTE(review): in the code visible here entries are only inserted, never
// removed. After a listener stops, its address stays registered and connecting
// to it fails with a BrokenPipe error rather than NotConnected.
lazy_static::lazy_static! {
    pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<C2cMpscConnect>>> = {
        Mutex::new(HashMap::new())
    };
}
60
/// Connection request sent to an mpsc listener. It contains:
/// - the sender the listener should use to write to the connecting side;
/// - a oneshot on which the listener answers with the sender the connecting
///   side should use to write back.
pub(crate) type C2cMpscConnect = (
    mpsc::Sender<MpscMsg>,
    oneshot::Sender<mpsc::Sender<MpscMsg>>,
);
/// A newly established connection handed from a listener task to its consumer:
/// the transport, a `ConnectAddr` describing it, and its connection id.
pub(crate) type C2sProtocol = (Protocols, ConnectAddr, Cid);
66
/// Formats `addr` for logging with part of the IP address masked out.
///
/// - IPv4 keeps the first and third octets: `a.xxx.c.xxx:port`.
/// - IPv6 keeps segments 0, 1, 4 and 5 as zero-padded hex and replaces the
///   others with `xxxx`, wrapped in brackets.
///
/// The port is always shown unmodified.
fn anonymize_addr(addr: &SocketAddr) -> String {
    use std::net::IpAddr;

    let port = addr.port();
    let masked_ip = match addr.ip() {
        IpAddr::V4(v4) => {
            let octets = v4.octets();
            format!("{}.xxx.{}.xxx", octets[0], octets[2])
        },
        IpAddr::V6(v6) => {
            let seg = v6.segments();
            format!(
                "[{:04x}:{:04x}:xxxx:xxxx:{:04x}:{:04x}:xxxx:xxxx]",
                seg[0], seg[1], seg[4], seg[5]
            )
        },
    };
    format!("{masked_ip}:{port}")
}
83
84impl Protocols {
85    const MPSC_CHANNEL_BOUND: usize = 1000;
86
87    pub(crate) async fn with_tcp_connect(
88        addr: SocketAddr,
89        metrics: ProtocolMetricCache,
90    ) -> Result<Self, NetworkConnectError> {
91        let stream = net::TcpStream::connect(addr)
92            .await
93            .and_then(|s| {
94                s.set_nodelay(true)?;
95                Ok(s)
96            })
97            .map_err(NetworkConnectError::Io)?;
98        info!(
99            "Connecting Tcp to: {}",
100            stream.peer_addr().map_err(NetworkConnectError::Io)?
101        );
102        Ok(Self::new_tcp(stream, metrics))
103    }
104
105    pub(crate) async fn with_tcp_listen(
106        addr: SocketAddr,
107        cids: Arc<AtomicU64>,
108        metrics: Arc<ProtocolMetrics>,
109        s2s_stop_listening_r: oneshot::Receiver<()>,
110        c2s_protocol_s: mpsc::UnboundedSender<C2sProtocol>,
111    ) -> io::Result<()> {
112        use socket2::{Domain, Socket, Type};
113        let domain = Domain::for_address(addr);
114        let socket2_socket = Socket::new(domain, Type::STREAM, None)?;
115        if domain == Domain::IPV6 {
116            socket2_socket.set_only_v6(true)?
117        }
118        socket2_socket.set_nonblocking(true)?; // Needed by Tokio
119        // See https://docs.rs/tokio/latest/tokio/net/struct.TcpSocket.html
120        #[cfg(not(windows))]
121        socket2_socket.set_reuse_address(true)?;
122        const SEND_BUFFER_SIZE: usize = 262144;
123        const RECV_BUFFER_SIZE: usize = SEND_BUFFER_SIZE * 2;
124        if let Err(e) = socket2_socket.set_recv_buffer_size(RECV_BUFFER_SIZE) {
125            warn!(?e, "Couldn't set recv_buffer size")
126        };
127        if let Err(e) = socket2_socket.set_send_buffer_size(SEND_BUFFER_SIZE) {
128            warn!(?e, "Couldn't set set_buffer size")
129        };
130        let socket2_addr = addr.into();
131        socket2_socket.bind(&socket2_addr)?;
132        socket2_socket.listen(1024)?;
133        let std_listener: std::net::TcpListener = socket2_socket.into();
134        let listener = net::TcpListener::from_std(std_listener)?;
135        trace!(?addr, "Tcp Listener bound");
136        let mut end_receiver = s2s_stop_listening_r.fuse();
137        tokio::spawn(async move {
138            while let Some(data) = select! {
139                    next = listener.accept().fuse() => Some(next),
140                    _ = &mut end_receiver => None,
141            } {
142                let (stream, remote_addr) = match data {
143                    Ok((s, p)) => (s, p),
144                    Err(e) => {
145                        trace!(?e, "TcpStream Error, ignoring connection attempt");
146                        continue;
147                    },
148                };
149                if let Err(e) = stream.set_nodelay(true) {
150                    warn!(
151                        ?e,
152                        "Failed to set TCP_NODELAY, client may have degraded latency"
153                    );
154                }
155                let cid = cids.fetch_add(1, Ordering::Relaxed);
156                info!(
157                    remote_addr = anonymize_addr(&remote_addr),
158                    ?cid,
159                    "Accepting Tcp from"
160                );
161                let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
162                let _ = c2s_protocol_s.send((
163                    Self::new_tcp(stream, metrics.clone()),
164                    ConnectAddr::Tcp(remote_addr),
165                    cid,
166                ));
167            }
168        });
169        Ok(())
170    }
171
172    pub(crate) fn new_tcp(stream: net::TcpStream, metrics: ProtocolMetricCache) -> Self {
173        let (r, w) = stream.into_split();
174        let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone());
175        let rp = TcpRecvProtocol::new(
176            TcpSink {
177                half: r,
178                buffer: BytesMut::new(),
179            },
180            metrics,
181        );
182        Protocols::Tcp((sp, rp))
183    }
184
185    pub(crate) async fn with_mpsc_connect(
186        addr: u64,
187        metrics: ProtocolMetricCache,
188    ) -> Result<Self, NetworkConnectError> {
189        let mpsc_s = MPSC_POOL
190            .lock()
191            .await
192            .get(&addr)
193            .ok_or_else(|| {
194                NetworkConnectError::Io(io::Error::new(
195                    io::ErrorKind::NotConnected,
196                    "no mpsc listen on this addr",
197                ))
198            })?
199            .clone();
200        let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND);
201        let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel();
202        mpsc_s
203            .send((remote_to_local_s, local_to_remote_oneshot_s))
204            .map_err(|_| {
205                NetworkConnectError::Io(io::Error::new(
206                    io::ErrorKind::BrokenPipe,
207                    "mpsc pipe broke during connect",
208                ))
209            })?;
210        let local_to_remote_s = local_to_remote_oneshot_r
211            .await
212            .map_err(|e| NetworkConnectError::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;
213        info!(?addr, "Connecting Mpsc");
214        Ok(Self::new_mpsc(
215            local_to_remote_s,
216            remote_to_local_r,
217            metrics,
218        ))
219    }
220
221    pub(crate) async fn with_mpsc_listen(
222        addr: u64,
223        cids: Arc<AtomicU64>,
224        metrics: Arc<ProtocolMetrics>,
225        s2s_stop_listening_r: oneshot::Receiver<()>,
226        c2s_protocol_s: mpsc::UnboundedSender<C2sProtocol>,
227    ) -> io::Result<()> {
228        let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel();
229        MPSC_POOL.lock().await.insert(addr, mpsc_s);
230        trace!(?addr, "Mpsc Listener bound");
231        let mut end_receiver = s2s_stop_listening_r.fuse();
232        tokio::spawn(async move {
233            while let Some((local_to_remote_s, local_remote_to_local_s)) = select! {
234                    next = mpsc_r.recv().fuse() => next,
235                    _ = &mut end_receiver => None,
236            } {
237                let (remote_to_local_s, remote_to_local_r) =
238                    mpsc::channel(Self::MPSC_CHANNEL_BOUND);
239                if let Err(e) = local_remote_to_local_s.send(remote_to_local_s) {
240                    error!(?e, "mpsc listen aborted");
241                }
242
243                let cid = cids.fetch_add(1, Ordering::Relaxed);
244                info!(?addr, ?cid, "Accepting Mpsc from");
245                let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
246                let _ = c2s_protocol_s.send((
247                    Self::new_mpsc(local_to_remote_s, remote_to_local_r, metrics.clone()),
248                    ConnectAddr::Mpsc(addr),
249                    cid,
250                ));
251            }
252            warn!("MpscStream Failed, stopping");
253        });
254        Ok(())
255    }
256
257    pub(crate) fn new_mpsc(
258        sender: mpsc::Sender<MpscMsg>,
259        receiver: mpsc::Receiver<MpscMsg>,
260        metrics: ProtocolMetricCache,
261    ) -> Self {
262        let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone());
263        let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics);
264        Protocols::Mpsc((sp, rp))
265    }
266
267    #[cfg(feature = "quic")]
268    pub(crate) async fn with_quic_connect(
269        addr: SocketAddr,
270        config: quinn::ClientConfig,
271        name: String,
272        metrics: ProtocolMetricCache,
273    ) -> Result<Self, NetworkConnectError> {
274        let config = config.clone();
275
276        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
277
278        let bindsock = match addr {
279            SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
280            SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
281        };
282        let endpoint = match quinn::Endpoint::client(bindsock) {
283            Ok(e) => e,
284            Err(e) => return Err(NetworkConnectError::Io(e)),
285        };
286
287        info!("Connecting Quic to: {}", &addr);
288        let connecting = endpoint.connect_with(config, addr, &name).map_err(|e| {
289            trace!(?e, "error setting up quic");
290            NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
291        })?;
292        let connection = connecting.await.map_err(|e| {
293            trace!(?e, "error with quic connection");
294            NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
295        })?;
296        Self::new_quic(connection, false, metrics)
297            .await
298            .map_err(|e| {
299                trace!(?e, "error with quic");
300                NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
301            })
302    }
303
304    #[cfg(feature = "quic")]
305    pub(crate) async fn with_quic_listen(
306        addr: SocketAddr,
307        server_config: quinn::ServerConfig,
308        cids: Arc<AtomicU64>,
309        metrics: Arc<ProtocolMetrics>,
310        s2s_stop_listening_r: oneshot::Receiver<()>,
311        c2s_protocol_s: mpsc::UnboundedSender<C2sProtocol>,
312    ) -> io::Result<()> {
313        let endpoint = quinn::Endpoint::server(server_config, addr)?;
314        trace!(?addr, "Quic Listener bound");
315        let mut end_receiver = s2s_stop_listening_r.fuse();
316        let config = quinn::ClientConfig::try_with_platform_verifier()
317            .map_err(|e| io::Error::other(Box::new(e)))?;
318        tokio::spawn(async move {
319            while let Some(Some(connecting)) = select! {
320                next = endpoint.accept().fuse() => Some(next),
321                _ = &mut end_receiver => None,
322            } {
323                let remote_addr = anonymize_addr(&connecting.remote_address());
324                let connection = match connecting.await {
325                    Ok(c) => c,
326                    Err(e) => {
327                        tracing::debug!(?e, ?remote_addr, "skipping connection attempt");
328                        continue;
329                    },
330                };
331
332                let cid = cids.fetch_add(1, Ordering::Relaxed);
333                info!(?remote_addr, ?cid, "Accepting Quic from");
334                let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
335                match Protocols::new_quic(connection, true, metrics).await {
336                    Ok(quic) => {
337                        // TODO: we cannot guess the client hostname in quic server here.
338                        // though we need it for the certificate to be validated, in the future
339                        // this will either go away with new auth, or we have to do something like
340                        // a reverse DNS lookup
341                        let connect_addr = ConnectAddr::Quic(
342                            addr,
343                            config.clone(),
344                            "TODO_remote_hostname".to_string(),
345                        );
346                        let _ = c2s_protocol_s.send((quic, connect_addr, cid));
347                    },
348                    Err(e) => {
349                        trace!(?e, "failed to start quic");
350                        continue;
351                    },
352                }
353            }
354        });
355        Ok(())
356    }
357
358    #[cfg(feature = "quic")]
359    pub(crate) async fn new_quic(
360        connection: quinn::Connection,
361        listen: bool,
362        metrics: ProtocolMetricCache,
363    ) -> Result<Self, quinn::ConnectionError> {
364        let (sendstream, recvstream) = if listen {
365            connection.open_bi().await?
366        } else {
367            connection
368                .accept_bi()
369                .await
370                .or(Err(quinn::ConnectionError::LocallyClosed))?
371        };
372        let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel();
373        let streams_s_clone = recvstreams_s.clone();
374        let (sendstreams_s, sendstreams_r) = mpsc::unbounded_channel();
375        let sp = QuicSendProtocol::new(
376            QuicDrain {
377                con: connection.clone(),
378                main: sendstream,
379                reliables: HashMap::new(),
380                recvstreams_s: streams_s_clone,
381                sendstreams_r,
382            },
383            metrics.clone(),
384        );
385        spawn_new(recvstream, None, &recvstreams_s);
386        let rp = QuicRecvProtocol::new(
387            QuicSink {
388                con: connection,
389                recvstreams_r,
390                recvstreams_s,
391                sendstreams_s,
392            },
393            metrics,
394        );
395        Ok(Protocols::Quic((sp, rp)))
396    }
397
398    pub(crate) fn split(self) -> (SendProtocols, RecvProtocols) {
399        match self {
400            Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)),
401            Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)),
402            #[cfg(feature = "quic")]
403            Protocols::Quic((s, r)) => (SendProtocols::Quic(s), RecvProtocols::Quic(r)),
404        }
405    }
406}
407
408#[async_trait]
409impl network_protocol::InitProtocol for Protocols {
410    type CustomErr = ProtocolsError;
411
412    async fn initialize(
413        &mut self,
414        initializer: bool,
415        local_pid: Pid,
416        secret: u128,
417    ) -> Result<(Pid, Sid, u128), InitProtocolError<Self::CustomErr>> {
418        match self {
419            Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await,
420            Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await,
421            #[cfg(feature = "quic")]
422            Protocols::Quic(p) => p.initialize(initializer, local_pid, secret).await,
423        }
424    }
425}
426
427#[async_trait]
428impl network_protocol::SendProtocol for SendProtocols {
429    type CustomErr = ProtocolsError;
430
431    fn notify_from_recv(&mut self, event: ProtocolEvent) {
432        match self {
433            SendProtocols::Tcp(s) => s.notify_from_recv(event),
434            SendProtocols::Mpsc(s) => s.notify_from_recv(event),
435            #[cfg(feature = "quic")]
436            SendProtocols::Quic(s) => s.notify_from_recv(event),
437        }
438    }
439
440    async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
441        match self {
442            SendProtocols::Tcp(s) => s.send(event).await,
443            SendProtocols::Mpsc(s) => s.send(event).await,
444            #[cfg(feature = "quic")]
445            SendProtocols::Quic(s) => s.send(event).await,
446        }
447    }
448
449    async fn flush(
450        &mut self,
451        bandwidth: Bandwidth,
452        dt: Duration,
453    ) -> Result<Bandwidth, ProtocolError<Self::CustomErr>> {
454        match self {
455            SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await,
456            SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await,
457            #[cfg(feature = "quic")]
458            SendProtocols::Quic(s) => s.flush(bandwidth, dt).await,
459        }
460    }
461}
462
463#[async_trait]
464impl network_protocol::RecvProtocol for RecvProtocols {
465    type CustomErr = ProtocolsError;
466
467    async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
468        match self {
469            RecvProtocols::Tcp(r) => r.recv().await,
470            RecvProtocols::Mpsc(r) => r.recv().await,
471            #[cfg(feature = "quic")]
472            RecvProtocols::Quic(r) => r.recv().await,
473        }
474    }
475}
476
/// Failures of the in-process (mpsc) transport.
#[derive(Debug)]
pub enum MpscError {
    /// Error returned by `mpsc::Sender::send` in `MpscDrain::send`.
    Send(mpsc::error::SendError<MpscMsg>),
    /// `MpscSink::recv` got `None` from its receiver, i.e. the channel is closed.
    Recv,
}
482
/// Failures of the QUIC transport.
#[cfg(feature = "quic")]
#[derive(Debug)]
pub enum QuicError {
    /// I/O error. Used when writing the stream id to a newly opened stream
    /// fails. NOTE(review): also used for the synthetic BrokenPipe errors
    /// `QuicSink` emits when a read returns 0 bytes or `None`, although those
    /// are read-side conditions.
    Send(io::Error),
    /// Opening or accepting a bidirectional stream failed.
    Connection(quinn::ConnectionError),
    /// Writing to a QUIC send stream failed.
    Write(quinn::WriteError),
    /// Reading from a QUIC receive stream failed.
    Read(quinn::ReadError),
    /// The internal channel between `QuicSink` and `QuicDrain` was closed.
    InternalMpsc,
}
492
/// Transport-specific error carried in `ProtocolError::Custom` by every
/// drain/sink in this module.
#[derive(Debug)]
pub enum ProtocolsError {
    /// I/O error on a TCP socket (including a read of 0 bytes, reported as
    /// `BrokenPipe`).
    Tcp(io::Error),
    /// NOTE(review): not constructed anywhere in this file's visible code.
    Udp(io::Error),
    /// Error from the QUIC transport.
    #[cfg(feature = "quic")]
    Quic(QuicError),
    /// Error from the in-process (mpsc) transport.
    Mpsc(MpscError),
}
502
///////////////////////////////////////
// TCP
/// Writes raw protocol bytes to the owned write half of a TCP stream.
#[derive(Debug)]
pub struct TcpDrain {
    half: OwnedWriteHalf,
}
509
/// Reads raw protocol bytes from the owned read half of a TCP stream.
#[derive(Debug)]
pub struct TcpSink {
    half: OwnedReadHalf,
    /// Reused read buffer. Each `recv` splits off the bytes it read and keeps
    /// the remaining capacity for the next read.
    buffer: BytesMut,
}
515
516#[async_trait]
517impl UnreliableDrain for TcpDrain {
518    type CustomErr = ProtocolsError;
519    type DataFormat = BytesMut;
520
521    async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
522        self.half
523            .write_all(&data)
524            .await
525            .map_err(|e| ProtocolError::Custom(ProtocolsError::Tcp(e)))
526    }
527}
528
529#[async_trait]
530impl UnreliableSink for TcpSink {
531    type CustomErr = ProtocolsError;
532    type DataFormat = BytesMut;
533
534    async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
535        if self.buffer.capacity() < 1500 {
536            self.buffer.reserve(1500 * 4); // reserve multiple, so that we alloc less often
537        }
538        match self.half.read_buf(&mut self.buffer).await {
539            Ok(0) => Err(ProtocolError::Custom(ProtocolsError::Tcp(io::Error::new(
540                io::ErrorKind::BrokenPipe,
541                "read returned 0 bytes",
542            )))),
543            Ok(_) => Ok(self.buffer.split()),
544            Err(e) => Err(ProtocolError::Custom(ProtocolsError::Tcp(e))),
545        }
546    }
547}
548
///////////////////////////////////////
// MPSC
/// Sends protocol messages over a bounded in-process channel.
#[derive(Debug)]
pub struct MpscDrain {
    sender: mpsc::Sender<MpscMsg>,
}
555
/// Receives protocol messages from a bounded in-process channel.
#[derive(Debug)]
pub struct MpscSink {
    receiver: mpsc::Receiver<MpscMsg>,
}
560
561#[async_trait]
562impl UnreliableDrain for MpscDrain {
563    type CustomErr = ProtocolsError;
564    type DataFormat = MpscMsg;
565
566    async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
567        self.sender
568            .send(data)
569            .await
570            .map_err(|e| ProtocolError::Custom(ProtocolsError::Mpsc(MpscError::Send(e))))
571    }
572}
573
574#[async_trait]
575impl UnreliableSink for MpscSink {
576    type CustomErr = ProtocolsError;
577    type DataFormat = MpscMsg;
578
579    async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
580        self.receiver
581            .recv()
582            .await
583            .ok_or(ProtocolError::Custom(ProtocolsError::Mpsc(MpscError::Recv)))
584    }
585}
586
///////////////////////////////////////
// QUIC
/// Result of one read on a QUIC receive stream, passed from a reader task
/// (`spawn_new`, or the re-arm at the end of `QuicSink::recv`) to `QuicSink`.
///
/// The tuple holds:
/// - the read buffer;
/// - the outcome of `RecvStream::read` into it;
/// - the stream itself, so another read can be started;
/// - the stream's `Sid`. This is `None` for the main stream and for remote
///   streams that announced `u64::MAX`.
#[cfg(feature = "quic")]
type QuicStream = (
    BytesMut,
    Result<Option<usize>, quinn::ReadError>,
    quinn::RecvStream,
    Option<Sid>,
);
596
/// Sending side of a QUIC connection: one main stream plus one send stream per
/// reliable `Sid`.
#[cfg(feature = "quic")]
#[derive(Debug)]
pub struct QuicDrain {
    /// Used to open new bidirectional streams for locally created `Sid`s.
    con: quinn::Connection,
    /// Main stream, opened or accepted in `Protocols::new_quic`.
    main: quinn::SendStream,
    /// Send stream for each reliable `Sid`.
    reliables: HashMap<Sid, quinn::SendStream>,
    /// Receive halves of locally opened streams are registered here so the
    /// sink reads them.
    recvstreams_s: mpsc::UnboundedSender<QuicStream>,
    /// Send halves of remotely opened streams, forwarded by `QuicSink`.
    sendstreams_r: mpsc::UnboundedReceiver<quinn::SendStream>,
}
606
/// Receiving side of a QUIC connection. It collects reads from all receive
/// streams through a channel fed by per-read tasks.
#[cfg(feature = "quic")]
#[derive(Debug)]
pub struct QuicSink {
    /// Used to accept bidirectional streams opened by the peer.
    con: quinn::Connection,
    /// Completed reads from all receive streams.
    recvstreams_r: mpsc::UnboundedReceiver<QuicStream>,
    /// Handed to reader tasks so they can report back into `recvstreams_r`.
    recvstreams_s: mpsc::UnboundedSender<QuicStream>,
    /// Forwards send halves of peer-opened streams to `QuicDrain`.
    sendstreams_s: mpsc::UnboundedSender<quinn::SendStream>,
}
615
/// Starts a background read of up to 1500 bytes from `recvstream`.
///
/// The buffer, the read result, the stream and `sid` are reported on
/// `streams_s`. If the receiving end is already gone, the result is dropped.
#[cfg(feature = "quic")]
fn spawn_new(
    mut recvstream: quinn::RecvStream,
    sid: Option<Sid>,
    streams_s: &mpsc::UnboundedSender<QuicStream>,
) {
    let report_to = streams_s.clone();
    tokio::spawn(async move {
        let mut chunk = BytesMut::new();
        chunk.resize(1500, 0u8);
        let outcome = recvstream.read(&mut chunk).await;
        let _ = report_to.send((chunk, outcome, recvstream, sid));
    });
}
630
#[cfg(feature = "quic")]
#[async_trait]
impl UnreliableDrain for QuicDrain {
    type CustomErr = ProtocolsError;
    type DataFormat = QuicDataFormat;

    /// Writes `data.data` to the stream selected by `data.stream`.
    ///
    /// - `Main`: the main stream.
    /// - `Unreliable`: not implemented; this panics.
    /// - `Reliable(sid)`: the stream already mapped to `sid`. If there is none
    ///   yet, either open a new one (empty payload, locally created stream) or
    ///   take the next send half forwarded by `QuicSink` (non-empty payload).
    ///
    /// Write failures are reported as `QuicError::Write`.
    async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
        match data.stream {
            QuicDataFormatStream::Main => self.main.write_all(&data.data).await,
            QuicDataFormatStream::Unreliable => unimplemented!(),
            QuicDataFormatStream::Reliable(sid) => {
                use hashbrown::hash_map::Entry;
                match self.reliables.entry(sid) {
                    Entry::Occupied(mut occupied) => occupied.get_mut().write_all(&data.data).await,
                    Entry::Vacant(vacant) => {
                        // IF the buffer is empty this was created locally and WE are allowed to
                        // open_bi(), if not, we NEED to block on sendstreams_r
                        if data.data.is_empty() {
                            let (mut sendstream, recvstream) =
                                self.con.open_bi().await.map_err(|e| {
                                    ProtocolError::Custom(ProtocolsError::Quic(
                                        QuicError::Connection(e),
                                    ))
                                })?;
                            // send SID as first msg
                            // (QuicSink::recv on the peer reads it back with read_u64)
                            sendstream.write_u64(sid.get_u64()).await.map_err(|e| {
                                ProtocolError::Custom(ProtocolsError::Quic(QuicError::Send(e)))
                            })?;
                            // Start reading the receive half so the sink sees its data.
                            spawn_new(recvstream, Some(sid), &self.recvstreams_s);
                            vacant.insert(sendstream).write_all(&data.data).await
                        } else {
                            // NOTE(review): assumes the next send half forwarded by
                            // QuicSink belongs to this `sid`; it is not checked here.
                            let sendstream =
                                self.sendstreams_r
                                    .recv()
                                    .await
                                    .ok_or(ProtocolError::Custom(ProtocolsError::Quic(
                                        QuicError::InternalMpsc,
                                    )))?;
                            vacant.insert(sendstream).write_all(&data.data).await
                        }
                    },
                }
            },
        }
        .map_err(|e| ProtocolError::Custom(ProtocolsError::Quic(QuicError::Write(e))))
    }
}
679
#[cfg(feature = "quic")]
#[async_trait]
impl UnreliableSink for QuicSink {
    type CustomErr = ProtocolsError;
    type DataFormat = QuicDataFormat;

    /// Returns the next chunk read from any receive stream, tagged with the
    /// stream it came from.
    ///
    /// While waiting, new bidirectional streams opened by the peer are handled
    /// as follows:
    /// - their leading `u64` stream id is read;
    /// - their send half is forwarded to `QuicDrain`;
    /// - a read is started on their receive half.
    ///
    /// After a successful read, another read on the same stream is re-armed in
    /// the background.
    ///
    /// # Errors
    /// - `ProtocolError::Violated` if a new stream's id cannot be read.
    /// - `QuicError::Connection` if accepting a stream fails.
    /// - `QuicError::InternalMpsc` if the drain is gone.
    /// - `QuicError::Read` on a read error.
    /// - A BrokenPipe `QuicError::Send` if a read returns 0 bytes or `None`.
    async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
        let (mut buffer, result, mut recvstream, id) = loop {
            // NOTE(review): redundant with the module-level import of the same trait.
            use futures_util::FutureExt;
            // first handle all bi streams!
            // (`biased` polls the branches in order, so accept_bi is checked first)
            let (a, b) = select! {
                biased;
                n = self.con.accept_bi().fuse() => (Some(n), None),
                Some(n) = self.recvstreams_r.recv().fuse() => (None, Some(n)),
            };

            if let Some(remote_stream) = a {
                let (sendstream, mut recvstream) = remote_stream.map_err(|e| {
                    ProtocolError::Custom(ProtocolsError::Quic(QuicError::Connection(e)))
                })?;
                // NOTE(review): awaited inline; a peer that opens a stream but never
                // sends its 8-byte id stalls this recv.
                let sid = match recvstream.read_u64().await {
                    Ok(u64::MAX) => None, //unreliable
                    Ok(sid) => Some(Sid::new(sid)),
                    Err(_) => return Err(ProtocolError::Violated),
                };
                if self.sendstreams_s.send(sendstream).is_err() {
                    return Err(ProtocolError::Custom(ProtocolsError::Quic(
                        QuicError::InternalMpsc,
                    )));
                }
                spawn_new(recvstream, sid, &self.recvstreams_s);
            }

            if let Some(data) = b {
                break data;
            }
        };

        // A `None` id (main stream, or a stream that announced u64::MAX) is
        // reported as the Main stream.
        let r = match result {
            Ok(Some(0)) => Err(ProtocolError::Custom(ProtocolsError::Quic(
                QuicError::Send(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "read returned 0 bytes",
                )),
            ))),
            Ok(Some(n)) => Ok(QuicDataFormat {
                stream: match id {
                    Some(id) => QuicDataFormatStream::Reliable(id),
                    None => QuicDataFormatStream::Main,
                },
                data: buffer.split_to(n),
            }),
            Ok(None) => Err(ProtocolError::Custom(ProtocolsError::Quic(
                QuicError::Send(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "read returned None",
                )),
            ))),
            Err(e) => Err(ProtocolError::Custom(ProtocolsError::Quic(
                QuicError::Read(e),
            ))),
        }?;

        // Re-arm a read on the same stream, reusing the rest of the buffer.
        let streams_s_clone = self.recvstreams_s.clone();
        tokio::spawn(async move {
            buffer.resize(1500, 0u8);
            let r = recvstream.read(&mut buffer).await;
            let _ = streams_s_clone.send((buffer, r, recvstream, id));
        });
        Ok(r)
    }
}
752
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use network_protocol::{Promises, ProtocolMetrics, RecvProtocol, SendProtocol};
    use std::sync::Arc;
    use tokio::net::{TcpListener, TcpStream};

    /// Returns a connected `(client, server)` TCP pair on a loopback port
    /// chosen by the OS. A fixed port can collide with parallel tests or with
    /// other processes.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let accept = tokio::spawn(async move {
            let (server, _) = listener.accept().await.unwrap();
            server
        });
        let client = TcpStream::connect(addr).await.unwrap();
        let server = accept.await.unwrap();
        (client, server)
    }

    /// Fresh metric cache for a single test.
    fn test_metrics() -> ProtocolMetricCache {
        ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap()))
    }

    #[tokio::test]
    async fn tokio_sinks() {
        let (client, server) = tcp_pair().await;
        let metrics = test_metrics();
        let client = Protocols::new_tcp(client, metrics.clone());
        let server = Protocols::new_tcp(server, metrics);
        let (mut s, _) = client.split();
        let (_, mut r) = server.split();
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(1),
            prio: 4u8,
            promises: Promises::GUARANTEED_DELIVERY,
            guaranteed_bandwidth: 1_000,
        };
        s.send(event.clone()).await.unwrap();
        s.send(ProtocolEvent::Message {
            sid: Sid::new(1),
            data: Bytes::from(&[8u8; 8][..]),
        })
        .await
        .unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        drop(s); // recv must work even after shutdown of send!
        tokio::time::sleep(Duration::from_secs(1)).await;
        let res = r.recv().await;
        match res {
            Ok(ProtocolEvent::OpenStream {
                sid,
                prio,
                promises,
                guaranteed_bandwidth: _,
            }) => {
                assert_eq!(sid, Sid::new(1));
                assert_eq!(prio, 4u8);
                assert_eq!(promises, Promises::GUARANTEED_DELIVERY);
            },
            _ => {
                panic!("wrong type {:?}", res);
            },
        }
        r.recv().await.unwrap();
    }

    #[tokio::test]
    async fn tokio_sink_stop_after_drop() {
        let (client, server) = tcp_pair().await;
        let metrics = test_metrics();
        let client = Protocols::new_tcp(client, metrics.clone());
        let server = Protocols::new_tcp(server, metrics);
        let (s, _) = client.split();
        let (_, mut r) = server.split();
        let e = tokio::spawn(async move { r.recv().await });
        drop(s);
        let e = e
            .await
            .unwrap()
            .expect_err("recv must fail once the sending side is dropped");
        match e {
            ProtocolError::Custom(ProtocolsError::Tcp(e)) => {
                assert_eq!(e.kind(), io::ErrorKind::BrokenPipe)
            },
            other => panic!("invalid error: {other:?}"),
        }
    }
}