1use crate::api::{ConnectAddr, NetworkConnectError};
2use async_trait::async_trait;
3use bytes::BytesMut;
4use futures_util::FutureExt;
5use hashbrown::HashMap;
6use network_protocol::{
7 Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid,
8 ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol,
9 TcpSendProtocol, UnreliableDrain, UnreliableSink,
10};
11#[cfg(feature = "quic")]
12use network_protocol::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol};
13use std::{
14 io,
15 net::SocketAddr,
16 sync::{
17 Arc,
18 atomic::{AtomicU64, Ordering},
19 },
20 time::Duration,
21};
22use tokio::{
23 io::{AsyncReadExt, AsyncWriteExt},
24 net,
25 net::tcp::{OwnedReadHalf, OwnedWriteHalf},
26 select,
27 sync::{Mutex, mpsc, oneshot},
28};
29use tracing::{error, info, trace, warn};
30
/// A transport-specific send/recv protocol pair for a single connection.
///
/// Built by the `with_*`/`new_*` constructors on this type; use [`Protocols::split`]
/// to separate it into independently usable send and recv halves.
#[derive(Debug)]
pub(crate) enum Protocols {
    /// TCP: the send half writes to the stream's write half, the recv half reads
    /// from its read half.
    Tcp((TcpSendProtocol<TcpDrain>, TcpRecvProtocol<TcpSink>)),
    /// In-process connection backed by a pair of bounded tokio mpsc channels.
    Mpsc((MpscSendProtocol<MpscDrain>, MpscRecvProtocol<MpscSink>)),
    /// QUIC connection (only compiled with the `quic` feature).
    #[cfg(feature = "quic")]
    Quic((QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>)),
}
38
/// The send half of a [`Protocols`] pair, as returned by [`Protocols::split`].
#[derive(Debug)]
pub(crate) enum SendProtocols {
    /// Send half of a TCP connection.
    Tcp(TcpSendProtocol<TcpDrain>),
    /// Send half of an in-process mpsc connection.
    Mpsc(MpscSendProtocol<MpscDrain>),
    /// Send half of a QUIC connection (only with the `quic` feature).
    #[cfg(feature = "quic")]
    Quic(QuicSendProtocol<QuicDrain>),
}
46
/// The recv half of a [`Protocols`] pair, as returned by [`Protocols::split`].
#[derive(Debug)]
pub(crate) enum RecvProtocols {
    /// Recv half of a TCP connection.
    Tcp(TcpRecvProtocol<TcpSink>),
    /// Recv half of an in-process mpsc connection.
    Mpsc(MpscRecvProtocol<MpscSink>),
    /// Recv half of a QUIC connection (only with the `quic` feature).
    #[cfg(feature = "quic")]
    Quic(QuicRecvProtocol<QuicSink>),
}
54
lazy_static::lazy_static! {
    /// Registry of in-process mpsc listeners, keyed by their `u64` listen address.
    ///
    /// `Protocols::with_mpsc_listen` inserts (or replaces) the entry for its address;
    /// `Protocols::with_mpsc_connect` looks it up and sends a [`C2cMpscConnect`]
    /// handshake through it.
    /// NOTE(review): no code in this module removes an entry when a listener stops.
    pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<C2cMpscConnect>>> = {
        Mutex::new(HashMap::new())
    };
}
60
/// Handshake sent by a connecting mpsc client to a listener registered in [`MPSC_POOL`].
///
/// The first element is the sender the listener uses to reach the connector. The
/// oneshot is how the listener replies with the sender the connector uses to reach
/// the listener.
pub(crate) type C2cMpscConnect = (
    mpsc::Sender<MpscMsg>,
    oneshot::Sender<mpsc::Sender<MpscMsg>>,
);
/// A freshly accepted connection, as sent over `c2s_protocol_s` by the `with_*_listen`
/// functions: the protocol pair, its associated address and its connection id.
pub(crate) type C2sProtocol = (Protocols, ConnectAddr, Cid);
66
/// Formats `addr` with part of the IP replaced by `x`s so it can be logged without
/// fully identifying the peer.
///
/// For IPv4 the 1st and 3rd octets are kept. For IPv6 the 1st, 2nd, 5th and 6th
/// segments are kept (zero-padded hex). The port is always printed.
fn anonymize_addr(addr: &SocketAddr) -> String {
    use std::net::IpAddr;
    let ip = addr.ip();
    match ip {
        IpAddr::V6(v6) => {
            let segs = v6.segments();
            let (s0, s1) = (segs[0], segs[1]);
            let (s4, s5) = (segs[4], segs[5]);
            format!(
                "[{s0:04x}:{s1:04x}:xxxx:xxxx:{s4:04x}:{s5:04x}:xxxx:xxxx]:{}",
                addr.port()
            )
        },
        IpAddr::V4(v4) => {
            let octets = v4.octets();
            let (o0, o2) = (octets[0], octets[2]);
            format!("{o0}.xxx.{o2}.xxx:{}", addr.port())
        },
    }
}
83
84impl Protocols {
85 const MPSC_CHANNEL_BOUND: usize = 1000;
86
87 pub(crate) async fn with_tcp_connect(
88 addr: SocketAddr,
89 metrics: ProtocolMetricCache,
90 ) -> Result<Self, NetworkConnectError> {
91 let stream = net::TcpStream::connect(addr)
92 .await
93 .and_then(|s| {
94 s.set_nodelay(true)?;
95 Ok(s)
96 })
97 .map_err(NetworkConnectError::Io)?;
98 info!(
99 "Connecting Tcp to: {}",
100 stream.peer_addr().map_err(NetworkConnectError::Io)?
101 );
102 Ok(Self::new_tcp(stream, metrics))
103 }
104
105 pub(crate) async fn with_tcp_listen(
106 addr: SocketAddr,
107 cids: Arc<AtomicU64>,
108 metrics: Arc<ProtocolMetrics>,
109 s2s_stop_listening_r: oneshot::Receiver<()>,
110 c2s_protocol_s: mpsc::UnboundedSender<C2sProtocol>,
111 ) -> io::Result<()> {
112 use socket2::{Domain, Socket, Type};
113 let domain = Domain::for_address(addr);
114 let socket2_socket = Socket::new(domain, Type::STREAM, None)?;
115 if domain == Domain::IPV6 {
116 socket2_socket.set_only_v6(true)?
117 }
118 socket2_socket.set_nonblocking(true)?; #[cfg(not(windows))]
121 socket2_socket.set_reuse_address(true)?;
122 const SEND_BUFFER_SIZE: usize = 262144;
123 const RECV_BUFFER_SIZE: usize = SEND_BUFFER_SIZE * 2;
124 if let Err(e) = socket2_socket.set_recv_buffer_size(RECV_BUFFER_SIZE) {
125 warn!(?e, "Couldn't set recv_buffer size")
126 };
127 if let Err(e) = socket2_socket.set_send_buffer_size(SEND_BUFFER_SIZE) {
128 warn!(?e, "Couldn't set set_buffer size")
129 };
130 let socket2_addr = addr.into();
131 socket2_socket.bind(&socket2_addr)?;
132 socket2_socket.listen(1024)?;
133 let std_listener: std::net::TcpListener = socket2_socket.into();
134 let listener = net::TcpListener::from_std(std_listener)?;
135 trace!(?addr, "Tcp Listener bound");
136 let mut end_receiver = s2s_stop_listening_r.fuse();
137 tokio::spawn(async move {
138 while let Some(data) = select! {
139 next = listener.accept().fuse() => Some(next),
140 _ = &mut end_receiver => None,
141 } {
142 let (stream, remote_addr) = match data {
143 Ok((s, p)) => (s, p),
144 Err(e) => {
145 trace!(?e, "TcpStream Error, ignoring connection attempt");
146 continue;
147 },
148 };
149 if let Err(e) = stream.set_nodelay(true) {
150 warn!(
151 ?e,
152 "Failed to set TCP_NODELAY, client may have degraded latency"
153 );
154 }
155 let cid = cids.fetch_add(1, Ordering::Relaxed);
156 info!(
157 remote_addr = anonymize_addr(&remote_addr),
158 ?cid,
159 "Accepting Tcp from"
160 );
161 let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
162 let _ = c2s_protocol_s.send((
163 Self::new_tcp(stream, metrics.clone()),
164 ConnectAddr::Tcp(remote_addr),
165 cid,
166 ));
167 }
168 });
169 Ok(())
170 }
171
172 pub(crate) fn new_tcp(stream: net::TcpStream, metrics: ProtocolMetricCache) -> Self {
173 let (r, w) = stream.into_split();
174 let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone());
175 let rp = TcpRecvProtocol::new(
176 TcpSink {
177 half: r,
178 buffer: BytesMut::new(),
179 },
180 metrics,
181 );
182 Protocols::Tcp((sp, rp))
183 }
184
185 pub(crate) async fn with_mpsc_connect(
186 addr: u64,
187 metrics: ProtocolMetricCache,
188 ) -> Result<Self, NetworkConnectError> {
189 let mpsc_s = MPSC_POOL
190 .lock()
191 .await
192 .get(&addr)
193 .ok_or_else(|| {
194 NetworkConnectError::Io(io::Error::new(
195 io::ErrorKind::NotConnected,
196 "no mpsc listen on this addr",
197 ))
198 })?
199 .clone();
200 let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND);
201 let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel();
202 mpsc_s
203 .send((remote_to_local_s, local_to_remote_oneshot_s))
204 .map_err(|_| {
205 NetworkConnectError::Io(io::Error::new(
206 io::ErrorKind::BrokenPipe,
207 "mpsc pipe broke during connect",
208 ))
209 })?;
210 let local_to_remote_s = local_to_remote_oneshot_r
211 .await
212 .map_err(|e| NetworkConnectError::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?;
213 info!(?addr, "Connecting Mpsc");
214 Ok(Self::new_mpsc(
215 local_to_remote_s,
216 remote_to_local_r,
217 metrics,
218 ))
219 }
220
221 pub(crate) async fn with_mpsc_listen(
222 addr: u64,
223 cids: Arc<AtomicU64>,
224 metrics: Arc<ProtocolMetrics>,
225 s2s_stop_listening_r: oneshot::Receiver<()>,
226 c2s_protocol_s: mpsc::UnboundedSender<C2sProtocol>,
227 ) -> io::Result<()> {
228 let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel();
229 MPSC_POOL.lock().await.insert(addr, mpsc_s);
230 trace!(?addr, "Mpsc Listener bound");
231 let mut end_receiver = s2s_stop_listening_r.fuse();
232 tokio::spawn(async move {
233 while let Some((local_to_remote_s, local_remote_to_local_s)) = select! {
234 next = mpsc_r.recv().fuse() => next,
235 _ = &mut end_receiver => None,
236 } {
237 let (remote_to_local_s, remote_to_local_r) =
238 mpsc::channel(Self::MPSC_CHANNEL_BOUND);
239 if let Err(e) = local_remote_to_local_s.send(remote_to_local_s) {
240 error!(?e, "mpsc listen aborted");
241 }
242
243 let cid = cids.fetch_add(1, Ordering::Relaxed);
244 info!(?addr, ?cid, "Accepting Mpsc from");
245 let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
246 let _ = c2s_protocol_s.send((
247 Self::new_mpsc(local_to_remote_s, remote_to_local_r, metrics.clone()),
248 ConnectAddr::Mpsc(addr),
249 cid,
250 ));
251 }
252 warn!("MpscStream Failed, stopping");
253 });
254 Ok(())
255 }
256
257 pub(crate) fn new_mpsc(
258 sender: mpsc::Sender<MpscMsg>,
259 receiver: mpsc::Receiver<MpscMsg>,
260 metrics: ProtocolMetricCache,
261 ) -> Self {
262 let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone());
263 let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics);
264 Protocols::Mpsc((sp, rp))
265 }
266
267 #[cfg(feature = "quic")]
268 pub(crate) async fn with_quic_connect(
269 addr: SocketAddr,
270 config: quinn::ClientConfig,
271 name: String,
272 metrics: ProtocolMetricCache,
273 ) -> Result<Self, NetworkConnectError> {
274 let config = config.clone();
275
276 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
277
278 let bindsock = match addr {
279 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
280 SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
281 };
282 let endpoint = match quinn::Endpoint::client(bindsock) {
283 Ok(e) => e,
284 Err(e) => return Err(NetworkConnectError::Io(e)),
285 };
286
287 info!("Connecting Quic to: {}", &addr);
288 let connecting = endpoint.connect_with(config, addr, &name).map_err(|e| {
289 trace!(?e, "error setting up quic");
290 NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
291 })?;
292 let connection = connecting.await.map_err(|e| {
293 trace!(?e, "error with quic connection");
294 NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
295 })?;
296 Self::new_quic(connection, false, metrics)
297 .await
298 .map_err(|e| {
299 trace!(?e, "error with quic");
300 NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
301 })
302 }
303
304 #[cfg(feature = "quic")]
305 pub(crate) async fn with_quic_listen(
306 addr: SocketAddr,
307 server_config: quinn::ServerConfig,
308 cids: Arc<AtomicU64>,
309 metrics: Arc<ProtocolMetrics>,
310 s2s_stop_listening_r: oneshot::Receiver<()>,
311 c2s_protocol_s: mpsc::UnboundedSender<C2sProtocol>,
312 ) -> io::Result<()> {
313 let endpoint = quinn::Endpoint::server(server_config, addr)?;
314 trace!(?addr, "Quic Listener bound");
315 let mut end_receiver = s2s_stop_listening_r.fuse();
316 tokio::spawn(async move {
317 while let Some(Some(connecting)) = select! {
318 next = endpoint.accept().fuse() => Some(next),
319 _ = &mut end_receiver => None,
320 } {
321 let remote_addr = anonymize_addr(&connecting.remote_address());
322 let connection = match connecting.await {
323 Ok(c) => c,
324 Err(e) => {
325 tracing::debug!(?e, ?remote_addr, "skipping connection attempt");
326 continue;
327 },
328 };
329
330 let cid = cids.fetch_add(1, Ordering::Relaxed);
331 info!(?remote_addr, ?cid, "Accepting Quic from");
332 let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics));
333 match Protocols::new_quic(connection, true, metrics).await {
334 Ok(quic) => {
335 let connect_addr = ConnectAddr::Quic(
340 addr,
341 quinn::ClientConfig::with_platform_verifier(),
342 "TODO_remote_hostname".to_string(),
343 );
344 let _ = c2s_protocol_s.send((quic, connect_addr, cid));
345 },
346 Err(e) => {
347 trace!(?e, "failed to start quic");
348 continue;
349 },
350 }
351 }
352 });
353 Ok(())
354 }
355
356 #[cfg(feature = "quic")]
357 pub(crate) async fn new_quic(
358 connection: quinn::Connection,
359 listen: bool,
360 metrics: ProtocolMetricCache,
361 ) -> Result<Self, quinn::ConnectionError> {
362 let (sendstream, recvstream) = if listen {
363 connection.open_bi().await?
364 } else {
365 connection
366 .accept_bi()
367 .await
368 .or(Err(quinn::ConnectionError::LocallyClosed))?
369 };
370 let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel();
371 let streams_s_clone = recvstreams_s.clone();
372 let (sendstreams_s, sendstreams_r) = mpsc::unbounded_channel();
373 let sp = QuicSendProtocol::new(
374 QuicDrain {
375 con: connection.clone(),
376 main: sendstream,
377 reliables: HashMap::new(),
378 recvstreams_s: streams_s_clone,
379 sendstreams_r,
380 },
381 metrics.clone(),
382 );
383 spawn_new(recvstream, None, &recvstreams_s);
384 let rp = QuicRecvProtocol::new(
385 QuicSink {
386 con: connection,
387 recvstreams_r,
388 recvstreams_s,
389 sendstreams_s,
390 },
391 metrics,
392 );
393 Ok(Protocols::Quic((sp, rp)))
394 }
395
396 pub(crate) fn split(self) -> (SendProtocols, RecvProtocols) {
397 match self {
398 Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)),
399 Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)),
400 #[cfg(feature = "quic")]
401 Protocols::Quic((s, r)) => (SendProtocols::Quic(s), RecvProtocols::Quic(r)),
402 }
403 }
404}
405
406#[async_trait]
407impl network_protocol::InitProtocol for Protocols {
408 type CustomErr = ProtocolsError;
409
410 async fn initialize(
411 &mut self,
412 initializer: bool,
413 local_pid: Pid,
414 secret: u128,
415 ) -> Result<(Pid, Sid, u128), InitProtocolError<Self::CustomErr>> {
416 match self {
417 Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await,
418 Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await,
419 #[cfg(feature = "quic")]
420 Protocols::Quic(p) => p.initialize(initializer, local_pid, secret).await,
421 }
422 }
423}
424
425#[async_trait]
426impl network_protocol::SendProtocol for SendProtocols {
427 type CustomErr = ProtocolsError;
428
429 fn notify_from_recv(&mut self, event: ProtocolEvent) {
430 match self {
431 SendProtocols::Tcp(s) => s.notify_from_recv(event),
432 SendProtocols::Mpsc(s) => s.notify_from_recv(event),
433 #[cfg(feature = "quic")]
434 SendProtocols::Quic(s) => s.notify_from_recv(event),
435 }
436 }
437
438 async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
439 match self {
440 SendProtocols::Tcp(s) => s.send(event).await,
441 SendProtocols::Mpsc(s) => s.send(event).await,
442 #[cfg(feature = "quic")]
443 SendProtocols::Quic(s) => s.send(event).await,
444 }
445 }
446
447 async fn flush(
448 &mut self,
449 bandwidth: Bandwidth,
450 dt: Duration,
451 ) -> Result<Bandwidth, ProtocolError<Self::CustomErr>> {
452 match self {
453 SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await,
454 SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await,
455 #[cfg(feature = "quic")]
456 SendProtocols::Quic(s) => s.flush(bandwidth, dt).await,
457 }
458 }
459}
460
461#[async_trait]
462impl network_protocol::RecvProtocol for RecvProtocols {
463 type CustomErr = ProtocolsError;
464
465 async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
466 match self {
467 RecvProtocols::Tcp(r) => r.recv().await,
468 RecvProtocols::Mpsc(r) => r.recv().await,
469 #[cfg(feature = "quic")]
470 RecvProtocols::Quic(r) => r.recv().await,
471 }
472 }
473}
474
/// Errors of the in-process mpsc transport.
#[derive(Debug)]
pub enum MpscError {
    /// `MpscDrain::send` failed because the receiving side was dropped; carries the
    /// unsent message.
    Send(mpsc::error::SendError<MpscMsg>),
    /// `MpscSink::recv` found the channel closed.
    Recv,
}
480
/// Errors of the QUIC transport.
#[cfg(feature = "quic")]
#[derive(Debug)]
pub enum QuicError {
    /// I/O error. It comes from writing the sid header in `QuicDrain::send`, or is a
    /// synthetic `BrokenPipe` created by `QuicSink::recv` for empty or finished reads.
    Send(io::Error),
    /// Opening or accepting a bi-directional stream failed.
    Connection(quinn::ConnectionError),
    /// Writing to a send stream failed.
    Write(quinn::WriteError),
    /// Reading from a recv stream failed.
    Read(quinn::ReadError),
    /// The internal channel handing accepted send streams from `QuicSink` to
    /// `QuicDrain` is closed.
    InternalMpsc,
}
490
/// Transport-specific error used as `CustomErr` by every protocol wrapper, drain and
/// sink in this module.
#[derive(Debug)]
pub enum ProtocolsError {
    /// I/O error on a TCP stream. This includes a synthetic `BrokenPipe` when a read
    /// returned 0 bytes.
    Tcp(io::Error),
    /// NOTE(review): not constructed anywhere in this module.
    Udp(io::Error),
    /// Error of the QUIC transport (only with the `quic` feature).
    #[cfg(feature = "quic")]
    Quic(QuicError),
    /// Error of the in-process mpsc transport.
    Mpsc(MpscError),
}
500
/// Outgoing side of a TCP connection: raw bytes are written to the stream's write half.
#[derive(Debug)]
pub struct TcpDrain {
    half: OwnedWriteHalf,
}
507
/// Incoming side of a TCP connection.
#[derive(Debug)]
pub struct TcpSink {
    /// Read half of the TCP stream.
    half: OwnedReadHalf,
    /// Read buffer; each successful read is split off and returned, leaving the rest
    /// of the allocation for the next read.
    buffer: BytesMut,
}
513
514#[async_trait]
515impl UnreliableDrain for TcpDrain {
516 type CustomErr = ProtocolsError;
517 type DataFormat = BytesMut;
518
519 async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
520 self.half
521 .write_all(&data)
522 .await
523 .map_err(|e| ProtocolError::Custom(ProtocolsError::Tcp(e)))
524 }
525}
526
527#[async_trait]
528impl UnreliableSink for TcpSink {
529 type CustomErr = ProtocolsError;
530 type DataFormat = BytesMut;
531
532 async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
533 if self.buffer.capacity() < 1500 {
534 self.buffer.reserve(1500 * 4); }
536 match self.half.read_buf(&mut self.buffer).await {
537 Ok(0) => Err(ProtocolError::Custom(ProtocolsError::Tcp(io::Error::new(
538 io::ErrorKind::BrokenPipe,
539 "read returned 0 bytes",
540 )))),
541 Ok(_) => Ok(self.buffer.split()),
542 Err(e) => Err(ProtocolError::Custom(ProtocolsError::Tcp(e))),
543 }
544 }
545}
546
/// Outgoing side of an in-process mpsc connection.
#[derive(Debug)]
pub struct MpscDrain {
    sender: mpsc::Sender<MpscMsg>,
}
553
/// Incoming side of an in-process mpsc connection.
#[derive(Debug)]
pub struct MpscSink {
    receiver: mpsc::Receiver<MpscMsg>,
}
558
559#[async_trait]
560impl UnreliableDrain for MpscDrain {
561 type CustomErr = ProtocolsError;
562 type DataFormat = MpscMsg;
563
564 async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
565 self.sender
566 .send(data)
567 .await
568 .map_err(|e| ProtocolError::Custom(ProtocolsError::Mpsc(MpscError::Send(e))))
569 }
570}
571
572#[async_trait]
573impl UnreliableSink for MpscSink {
574 type CustomErr = ProtocolsError;
575 type DataFormat = MpscMsg;
576
577 async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
578 self.receiver
579 .recv()
580 .await
581 .ok_or(ProtocolError::Custom(ProtocolsError::Mpsc(MpscError::Recv)))
582 }
583}
584
#[cfg(feature = "quic")]
/// Outcome of one read on a QUIC recv stream, sent from reader tasks to `QuicSink`.
///
/// The fields are:
/// 1. the buffer that was read into;
/// 2. the result of `RecvStream::read`;
/// 3. the stream itself, so the next read can be scheduled;
/// 4. the stream's `Sid` — `None` for the main stream and for peer streams announced
///    with sid `u64::MAX`.
type QuicStream = (
    BytesMut,
    Result<Option<usize>, quinn::ReadError>,
    quinn::RecvStream,
    Option<Sid>,
);
594
/// Outgoing side of a QUIC connection.
#[cfg(feature = "quic")]
#[derive(Debug)]
pub struct QuicDrain {
    /// Connection, used to open new bi streams for reliable sids.
    con: quinn::Connection,
    /// Send half of the main bi stream.
    main: quinn::SendStream,
    /// Send streams already associated with a sid.
    reliables: HashMap<Sid, quinn::SendStream>,
    /// Where reader tasks for locally opened streams report their reads.
    recvstreams_s: mpsc::UnboundedSender<QuicStream>,
    /// Send halves of peer-opened streams, forwarded by `QuicSink::recv`.
    sendstreams_r: mpsc::UnboundedReceiver<quinn::SendStream>,
}
604
/// Incoming side of a QUIC connection.
#[cfg(feature = "quic")]
#[derive(Debug)]
pub struct QuicSink {
    /// Connection, used to accept bi streams opened by the peer.
    con: quinn::Connection,
    /// Completed reads from all recv streams.
    recvstreams_r: mpsc::UnboundedReceiver<QuicStream>,
    /// Cloned into each spawned reader task.
    recvstreams_s: mpsc::UnboundedSender<QuicStream>,
    /// Hands the send half of peer-opened streams to `QuicDrain`.
    sendstreams_s: mpsc::UnboundedSender<quinn::SendStream>,
}
613
/// Spawns a task that performs a single read (into a fresh 1500-byte zeroed buffer)
/// from `recvstream`.
///
/// The buffer, the read result, the stream and `sid` are then sent to `streams_s`.
/// A send failure (receiver gone) is ignored.
#[cfg(feature = "quic")]
fn spawn_new(
    mut recvstream: quinn::RecvStream,
    sid: Option<Sid>,
    streams_s: &mpsc::UnboundedSender<QuicStream>,
) {
    let reply_to = streams_s.clone();
    let reader = async move {
        let mut buf = BytesMut::new();
        buf.resize(1500, 0u8);
        let outcome = recvstream.read(&mut buf).await;
        let _ = reply_to.send((buf, outcome, recvstream, sid));
    };
    tokio::spawn(reader);
}
628
#[cfg(feature = "quic")]
#[async_trait]
impl UnreliableDrain for QuicDrain {
    type CustomErr = ProtocolsError;
    type DataFormat = QuicDataFormat;

    /// Writes `data.data` to the QUIC stream selected by `data.stream`.
    ///
    /// * `Main`: the main bi stream.
    /// * `Unreliable`: not implemented; this panics.
    /// * `Reliable(sid)`: the stream cached for `sid`. If none is cached yet:
    ///   - empty payload: a new bi stream is opened, `sid` is written to it as a u64
    ///     header, a reader task is spawned for its recv half, and it is cached;
    ///   - non-empty payload: the next send stream arriving on `sendstreams_r`
    ///     (forwarded by `QuicSink::recv`) is taken and cached.
    ///
    /// # Errors
    /// `QuicError::Connection` if opening a stream fails, `QuicError::Send` if writing
    /// the sid header fails, `QuicError::InternalMpsc` if `sendstreams_r` is closed,
    /// and `QuicError::Write` if writing the payload fails.
    async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
        match data.stream {
            QuicDataFormatStream::Main => self.main.write_all(&data.data).await,
            QuicDataFormatStream::Unreliable => unimplemented!(),
            QuicDataFormatStream::Reliable(sid) => {
                use hashbrown::hash_map::Entry;
                match self.reliables.entry(sid) {
                    // A stream for this sid exists already: just append to it.
                    Entry::Occupied(mut occupied) => occupied.get_mut().write_all(&data.data).await,
                    Entry::Vacant(vacant) => {
                        if data.data.is_empty() {
                            // NOTE(review): an empty payload for an unknown sid is treated
                            // as the local side opening the stream — presumably the
                            // stream-open event encodes to an empty frame; confirm.
                            let (mut sendstream, recvstream) =
                                self.con.open_bi().await.map_err(|e| {
                                    ProtocolError::Custom(ProtocolsError::Quic(
                                        QuicError::Connection(e),
                                    ))
                                })?;
                            // Announce the sid; the peer's QuicSink::recv reads this u64
                            // header from streams it accepts.
                            sendstream.write_u64(sid.get_u64()).await.map_err(|e| {
                                ProtocolError::Custom(ProtocolsError::Quic(QuicError::Send(e)))
                            })?;
                            spawn_new(recvstream, Some(sid), &self.recvstreams_s);
                            vacant.insert(sendstream).write_all(&data.data).await
                        } else {
                            // Presumably the peer opened this stream: block until
                            // QuicSink::recv forwards the accepted send half.
                            // NOTE(review): the forwarded stream is not checked against
                            // `sid`; this relies on streams arriving in matching order.
                            let sendstream =
                                self.sendstreams_r
                                    .recv()
                                    .await
                                    .ok_or(ProtocolError::Custom(ProtocolsError::Quic(
                                        QuicError::InternalMpsc,
                                    )))?;
                            vacant.insert(sendstream).write_all(&data.data).await
                        }
                    },
                }
            },
        }
        .map_err(|e| ProtocolError::Custom(ProtocolsError::Quic(QuicError::Write(e))))
    }
}
677
#[cfg(feature = "quic")]
#[async_trait]
impl UnreliableSink for QuicSink {
    type CustomErr = ProtocolsError;
    type DataFormat = QuicDataFormat;

    /// Returns the next chunk read from any of this connection's QUIC recv streams.
    ///
    /// Reads are done by spawned tasks that each push one [`QuicStream`] into
    /// `recvstreams_r`. While waiting for one, this also accepts bi streams opened by
    /// the peer. For each such stream it:
    /// - reads the u64 sid header;
    /// - forwards the send half to `QuicDrain` through `sendstreams_s`;
    /// - spawns a reader task for the recv half.
    ///
    /// After a successful read, another read on the same stream is spawned, reusing
    /// the buffer.
    ///
    /// # Errors
    /// Fails if:
    /// - accepting a stream fails;
    /// - the sid header cannot be read (`ProtocolError::Violated`);
    /// - `sendstreams_s` is closed;
    /// - a read yields 0 bytes, `None`, or a read error.
    ///
    /// A failure on any single stream is returned as an error of the whole sink.
    async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
        let (mut buffer, result, mut recvstream, id) = loop {
            use futures_util::FutureExt;
            let (a, b) = select! {
                // `biased`: branches are polled in the order written, accept_bi first.
                biased;
                n = self.con.accept_bi().fuse() => (Some(n), None),
                Some(n) = self.recvstreams_r.recv().fuse() => (None, Some(n)),
            };

            if let Some(remote_stream) = a {
                let (sendstream, mut recvstream) = remote_stream.map_err(|e| {
                    ProtocolError::Custom(ProtocolsError::Quic(QuicError::Connection(e)))
                })?;
                // The u64 header written by the peer's QuicDrain names the sid; u64::MAX
                // maps to `None`, i.e. it is reported like the main stream.
                let sid = match recvstream.read_u64().await {
                    Ok(u64::MAX) => None,
                    Ok(sid) => Some(Sid::new(sid)),
                    Err(_) => return Err(ProtocolError::Violated),
                };
                if self.sendstreams_s.send(sendstream).is_err() {
                    return Err(ProtocolError::Custom(ProtocolsError::Quic(
                        QuicError::InternalMpsc,
                    )));
                }
                spawn_new(recvstream, sid, &self.recvstreams_s);
            }

            if let Some(data) = b {
                break data;
            }
        };

        let r = match result {
            Ok(Some(0)) => Err(ProtocolError::Custom(ProtocolsError::Quic(
                QuicError::Send(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "read returned 0 bytes",
                )),
            ))),
            Ok(Some(n)) => Ok(QuicDataFormat {
                stream: match id {
                    Some(id) => QuicDataFormatStream::Reliable(id),
                    None => QuicDataFormatStream::Main,
                },
                data: buffer.split_to(n),
            }),
            Ok(None) => Err(ProtocolError::Custom(ProtocolsError::Quic(
                QuicError::Send(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "read returned None",
                )),
            ))),
            Err(e) => Err(ProtocolError::Custom(ProtocolsError::Quic(
                QuicError::Read(e),
            ))),
        }?;

        // Schedule the next read on the same stream, reusing the remaining buffer
        // grown back to 1500 bytes.
        let streams_s_clone = self.recvstreams_s.clone();
        tokio::spawn(async move {
            buffer.resize(1500, 0u8);
            let r = recvstream.read(&mut buffer).await;
            let _ = streams_s_clone.send((buffer, r, recvstream, id));
        });
        Ok(r)
    }
}
750
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use network_protocol::{Promises, ProtocolMetrics, RecvProtocol, SendProtocol};
    use std::sync::Arc;
    use tokio::net::{TcpListener, TcpStream};

    /// Returns a connected `(client, server)` TCP pair on an OS-assigned loopback port,
    /// so tests never collide on fixed port numbers.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let accept = tokio::spawn(async move {
            let (server, _) = listener.accept().await.unwrap();
            server
        });
        let client = TcpStream::connect(addr).await.unwrap();
        let server = accept.await.unwrap();
        (client, server)
    }

    /// Creates a metrics cache for tests.
    fn test_metrics() -> ProtocolMetricCache {
        ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap()))
    }

    #[tokio::test]
    async fn tokio_sinks() {
        let (client, server) = tcp_pair().await;
        let metrics = test_metrics();
        let client = Protocols::new_tcp(client, metrics.clone());
        let server = Protocols::new_tcp(server, metrics);
        let (mut s, _) = client.split();
        let (_, mut r) = server.split();
        let event = ProtocolEvent::OpenStream {
            sid: Sid::new(1),
            prio: 4u8,
            promises: Promises::GUARANTEED_DELIVERY,
            guaranteed_bandwidth: 1_000,
        };
        s.send(event.clone()).await.unwrap();
        s.send(ProtocolEvent::Message {
            sid: Sid::new(1),
            data: Bytes::from(&[8u8; 8][..]),
        })
        .await
        .unwrap();
        s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
        // Events flushed before the sender is dropped must still be receivable afterwards.
        drop(s);
        tokio::time::sleep(Duration::from_secs(1)).await;
        let res = r.recv().await;
        match res {
            Ok(ProtocolEvent::OpenStream {
                sid,
                prio,
                promises,
                guaranteed_bandwidth: _,
            }) => {
                assert_eq!(sid, Sid::new(1));
                assert_eq!(prio, 4u8);
                assert_eq!(promises, Promises::GUARANTEED_DELIVERY);
            },
            _ => {
                panic!("wrong type {:?}", res);
            },
        }
        r.recv().await.unwrap();
    }

    #[tokio::test]
    async fn tokio_sink_stop_after_drop() {
        let (client, server) = tcp_pair().await;
        let metrics = test_metrics();
        let client = Protocols::new_tcp(client, metrics.clone());
        let server = Protocols::new_tcp(server, metrics);
        let (s, _) = client.split();
        let (_, mut r) = server.split();
        let e = tokio::spawn(async move { r.recv().await });
        drop(s);
        let e = e.await.unwrap().unwrap_err();
        // Peer close must surface as a BrokenPipe TCP error.
        match e {
            ProtocolError::Custom(ProtocolsError::Tcp(e)) => {
                assert_eq!(e.kind(), io::ErrorKind::BrokenPipe)
            },
            other => panic!("invalid error: {:?}", other),
        }
    }
}