1use crate::{
2 api::{ConnectAddr, ParticipantError, ParticipantEvent, Stream},
3 channel::{Protocols, ProtocolsError, RecvProtocols, SendProtocols},
4 metrics::NetworkMetrics,
5 util::DeferredTracer,
6};
7use bytes::Bytes;
8use futures_util::{FutureExt, StreamExt};
9use hashbrown::HashMap;
10use network_protocol::{
11 _internal::SortedVec, Bandwidth, Cid, Pid, Prio, Promises, ProtocolEvent, RecvProtocol,
12 SendProtocol, Sid,
13};
14use std::{
15 sync::{
16 Arc,
17 atomic::{AtomicBool, AtomicI32, Ordering},
18 },
19 time::{Duration, Instant},
20};
21use tokio::{
22 select,
23 sync::{Mutex, RwLock, mpsc, oneshot, watch},
24 task::JoinHandle,
25};
26use tokio_stream::wrappers::UnboundedReceiverStream;
27use tracing::*;
28
/// Request to open a stream: priority, promises, guaranteed bandwidth and a oneshot on which
/// `send_mgr` returns the created [`Stream`].
pub(crate) type A2bStreamOpen = (Prio, Promises, Bandwidth, oneshot::Sender<Stream>);
/// Request to add a channel: channel id, a `Sid` (ignored by `create_channel_mgr`), the
/// protocols to split into a send and a recv half, the remote address and a oneshot that is
/// fired once the channel has been registered.
pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, ConnectAddr, oneshot::Sender<()>);
/// Shutdown request: how long `participant_shutdown_mgr` waits for the other managers before
/// force closing the recv side, and a oneshot on which the shutdown result is reported.
pub(crate) type S2bShutdownBparticipant = (Duration, oneshot::Sender<Result<(), ParticipantError>>);
/// Statistic tuple of a `Pid` and two `u64` values. This file only passes the sender around
/// and never sends on it, so the meaning of the two numbers is not visible here.
pub(crate) type B2sPrioStatistic = (Pid, u64, u64);
33
/// Bookkeeping for one channel of this participant, stored in `BParticipant::channels`.
#[derive(Debug)]
#[expect(dead_code)]
struct ChannelInfo {
    /// Id of the channel; also the key under which this entry is stored.
    cid: Cid,
    /// `cid` rendered once via `to_string()` when the channel is created.
    cid_string: String,
    /// Address of the remote side; reported in `ParticipantEvent::ChannelDeleted`.
    remote_con_addr: ConnectAddr,
}
41
/// Per-stream state kept in `BParticipant::streams` for every open stream.
#[derive(Debug)]
struct StreamInfo {
    #[expect(dead_code)]
    prio: Prio,
    #[expect(dead_code)]
    promises: Promises,
    /// Flag shared with the `Stream` handed to the API; set to `true` when the stream is
    /// deleted or the participant shuts down.
    send_closed: Arc<AtomicBool>,
    /// Sender on which `recv_mgr` forwards received message bytes to the `Stream`.
    b2a_msg_recv_s: Mutex<async_channel::Sender<Bytes>>,
}
51
/// Channel ends created in `BParticipant::new` and consumed by `BParticipant::run`, which
/// distributes them to the individual managers.
#[derive(Debug)]
struct ControlChannels {
    /// Incoming "open stream" requests, handled by `send_mgr`.
    a2b_open_stream_r: mpsc::UnboundedReceiver<A2bStreamOpen>,
    /// Streams opened by the remote side are published here by `recv_mgr`.
    b2a_stream_opened_s: mpsc::UnboundedSender<Stream>,
    /// `ChannelCreated` / `ChannelDeleted` events are published here.
    b2a_event_s: mpsc::UnboundedSender<ParticipantEvent>,
    /// Incoming "create channel" requests, handled by `create_channel_mgr`.
    s2b_create_channel_r: mpsc::UnboundedReceiver<S2bCreateChannel>,
    /// Smoothed send bandwidth computed by `send_mgr`.
    b2a_bandwidth_stats_s: watch::Sender<f32>,
    /// One-time shutdown request, awaited by `participant_shutdown_mgr`.
    s2b_shutdown_bparticipant_r: oneshot::Receiver<S2bShutdownBparticipant>,
}
61
/// Senders that `create_stream` clones into every new `Stream`; their receiving ends are
/// owned by `send_mgr`.
#[derive(Debug)]
struct OpenStreamInfo {
    /// Outgoing messages of all streams, tagged with the stream id.
    a2b_msg_s: crossbeam_channel::Sender<(Sid, Bytes)>,
    /// Close requests for streams, identified by stream id.
    a2b_close_stream_s: mpsc::UnboundedSender<Sid>,
}
67
/// State of the connection to one remote participant.
///
/// Holds the channel and stream tables for a single remote `Pid` and is driven to completion
/// by [`BParticipant::run`].
#[derive(Debug)]
pub struct BParticipant {
    /// Pid of the local side; passed on to every `Stream` that is created.
    local_pid: Pid,
    /// Pid of the remote side; passed on to every `Stream` that is created.
    remote_pid: Pid,
    /// `remote_pid` rendered once via `to_string()`; used as the label of all metric calls.
    remote_pid_string: String,
    /// First `Sid` handed out by `send_mgr` for locally opened streams; later ones count up.
    offset_sid: Sid,
    /// All channels registered by `create_channel_mgr`, keyed by channel id.
    channels: Arc<RwLock<HashMap<Cid, Mutex<ChannelInfo>>>>,
    /// All currently open streams, keyed by stream id.
    streams: RwLock<HashMap<Sid, StreamInfo>>,
    /// Channel ends consumed by `run`; `Some` from `new` until `run` takes it.
    run_channels: Option<ControlChannels>,
    /// Sum of the `BARR_*` values of the managers that have not stopped yet; `0` once
    /// `send_mgr`, `recv_mgr` and `create_channel_mgr` have all finished.
    shutdown_barrier: AtomicI32,
    metrics: Arc<NetworkMetrics>,
    /// Senders cloned into every new `Stream`; set by `run`, cleared when `send_mgr` stops.
    open_stream_channels: Arc<Mutex<Option<OpenStreamInfo>>>,
}
81
82impl BParticipant {
    /// Share of `shutdown_barrier` that `create_channel_mgr` subtracts when it stops.
    const BARR_CHANNEL: i32 = 1;
    /// Share of `shutdown_barrier` that `recv_mgr` subtracts when it stops.
    const BARR_RECV: i32 = 4;
    /// Share of `shutdown_barrier` that `send_mgr` subtracts when it stops.
    const BARR_SEND: i32 = 2;
    /// Period of the `send_mgr` tick that triggers sending and flushing.
    const TICK_TIME: Duration = Duration::from_millis(Self::TICK_TIME_MS);
    /// `TICK_TIME` in milliseconds.
    const TICK_TIME_MS: u64 = 5;
89
90 pub(crate) fn new(
91 local_pid: Pid,
92 remote_pid: Pid,
93 offset_sid: Sid,
94 metrics: Arc<NetworkMetrics>,
95 ) -> (
96 Self,
97 mpsc::UnboundedSender<A2bStreamOpen>,
98 mpsc::UnboundedReceiver<Stream>,
99 mpsc::UnboundedReceiver<ParticipantEvent>,
100 mpsc::UnboundedSender<S2bCreateChannel>,
101 oneshot::Sender<S2bShutdownBparticipant>,
102 watch::Receiver<f32>,
103 ) {
104 let (a2b_open_stream_s, a2b_open_stream_r) = mpsc::unbounded_channel::<A2bStreamOpen>();
105 let (b2a_stream_opened_s, b2a_stream_opened_r) = mpsc::unbounded_channel::<Stream>();
106 let (b2a_event_s, b2a_event_r) = mpsc::unbounded_channel::<ParticipantEvent>();
107 let (s2b_shutdown_bparticipant_s, s2b_shutdown_bparticipant_r) = oneshot::channel();
108 let (s2b_create_channel_s, s2b_create_channel_r) = mpsc::unbounded_channel();
109 let (b2a_bandwidth_stats_s, b2a_bandwidth_stats_r) = watch::channel::<f32>(0.0);
110
111 let run_channels = Some(ControlChannels {
112 a2b_open_stream_r,
113 b2a_stream_opened_s,
114 b2a_event_s,
115 s2b_create_channel_r,
116 b2a_bandwidth_stats_s,
117 s2b_shutdown_bparticipant_r,
118 });
119
120 (
121 Self {
122 local_pid,
123 remote_pid,
124 remote_pid_string: remote_pid.to_string(),
125 offset_sid,
126 channels: Arc::new(RwLock::new(HashMap::new())),
127 streams: RwLock::new(HashMap::new()),
128 shutdown_barrier: AtomicI32::new(
129 Self::BARR_CHANNEL + Self::BARR_SEND + Self::BARR_RECV,
130 ),
131 run_channels,
132 metrics,
133 open_stream_channels: Arc::new(Mutex::new(None)),
134 },
135 a2b_open_stream_s,
136 b2a_stream_opened_r,
137 b2a_event_r,
138 s2b_create_channel_s,
139 s2b_shutdown_bparticipant_s,
140 b2a_bandwidth_stats_r,
141 )
142 }
143
    /// Runs the participant until all four managers have finished.
    ///
    /// Creates the manager-to-manager (`b2b_*`) channels and the stream-to-manager (`a2b_*`)
    /// channels, publishes the `a2b_*` senders through `open_stream_channels` so that
    /// `create_stream` can clone them into new streams, and then drives `send_mgr`,
    /// `recv_mgr`, `create_channel_mgr` and `participant_shutdown_mgr` concurrently with
    /// `tokio::join!`.
    ///
    /// # Panics
    /// Panics if `run_channels` is `None`. `new` always fills it and `run` consumes `self`,
    /// so it cannot be taken twice.
    pub async fn run(mut self, b2s_prio_statistic_s: mpsc::UnboundedSender<B2sPrioStatistic>) {
        // create_channel_mgr -> send_mgr / recv_mgr: newly split protocol halves
        let (b2b_add_send_protocol_s, b2b_add_send_protocol_r) =
            mpsc::unbounded_channel::<(Cid, SendProtocols)>();
        let (b2b_add_recv_protocol_s, b2b_add_recv_protocol_r) =
            mpsc::unbounded_channel::<(Cid, RecvProtocols)>();
        // recv_mgr / participant_shutdown_mgr -> send_mgr: remove the send half of a channel
        let (b2b_close_send_protocol_s, b2b_close_send_protocol_r) =
            async_channel::unbounded::<Cid>();
        // participant_shutdown_mgr -> recv_mgr: abort the recv half of a channel
        let (b2b_force_close_recv_protocol_s, b2b_force_close_recv_protocol_r) =
            async_channel::unbounded::<Cid>();
        // recv_mgr -> send_mgr: a stream was opened / closed by the remote side
        // NOTE(review): the last element is spelled `u64` here but `Bandwidth` in the manager
        // signatures — presumably the same type; confirm.
        let (b2b_notify_send_of_recv_open_s, b2b_notify_send_of_recv_open_r) =
            crossbeam_channel::unbounded::<(Cid, Sid, Prio, Promises, u64)>();
        let (b2b_notify_send_of_recv_close_s, b2b_notify_send_of_recv_close_r) =
            crossbeam_channel::unbounded::<(Cid, Sid)>();

        // streams -> send_mgr: close requests and outgoing messages
        let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel::<Sid>();
        let (a2b_msg_s, a2b_msg_r) = crossbeam_channel::unbounded::<(Sid, Bytes)>();

        // From now on `create_stream` hands real senders to new streams.
        *self.open_stream_channels.lock().await = Some(OpenStreamInfo {
            a2b_msg_s,
            a2b_close_stream_s,
        });
        let run_channels = self.run_channels.take().unwrap();
        trace!("start all managers");
        tokio::join!(
            self.send_mgr(
                run_channels.a2b_open_stream_r,
                a2b_close_stream_r,
                a2b_msg_r,
                b2b_add_send_protocol_r,
                b2b_close_send_protocol_r,
                b2b_notify_send_of_recv_open_r,
                b2b_notify_send_of_recv_close_r,
                run_channels.b2a_event_s.clone(),
                b2s_prio_statistic_s,
                run_channels.b2a_bandwidth_stats_s,
            )
            .instrument(tracing::info_span!("send")),
            self.recv_mgr(
                run_channels.b2a_stream_opened_s,
                b2b_add_recv_protocol_r,
                b2b_force_close_recv_protocol_r,
                b2b_close_send_protocol_s.clone(),
                b2b_notify_send_of_recv_open_s,
                b2b_notify_send_of_recv_close_s,
            )
            .instrument(tracing::info_span!("recv")),
            self.create_channel_mgr(
                run_channels.s2b_create_channel_r,
                b2b_add_send_protocol_s,
                b2b_add_recv_protocol_s,
                run_channels.b2a_event_s,
            ),
            self.participant_shutdown_mgr(
                run_channels.s2b_shutdown_bparticipant_r,
                b2b_close_send_protocol_s.clone(),
                b2b_force_close_recv_protocol_s,
            ),
        );
    }
203
204 fn best_protocol(all: &SortedVec<Cid, SendProtocols>, promises: Promises) -> Option<Cid> {
205 all.data.iter().find(|(_, p)| matches!(p, SendProtocols::Mpsc(_))).map(|(c, _)| *c).or_else(
207 || if network_protocol::TcpSendProtocol::<crate::channel::TcpDrain>::supported_promises()
208 .contains(promises)
209 {
210 all.data.iter().find(|(_, p)| matches!(p, SendProtocols::Tcp(_))).map(|(c, _)| *c)
212 } else {
213 None
214 }
215 ).or_else(
216 || if network_protocol::QuicSendProtocol::<crate::channel::QuicDrain>::supported_promises()
218 .contains(promises)
219 {
220 all.data.iter().find(|(_, p)| matches!(p, SendProtocols::Quic(_))).map(|(c, _)| *c)
221 } else {
222 None
223 }
224 ).or_else(
225 || {
226 warn!("couldn't satisfy promises");
227 all.data.first().map(|(c, _)| *c)
228 }
229 )
230 }
231
232 async fn send_mgr(
234 &self,
235 mut a2b_open_stream_r: mpsc::UnboundedReceiver<A2bStreamOpen>,
236 mut a2b_close_stream_r: mpsc::UnboundedReceiver<Sid>,
237 a2b_msg_r: crossbeam_channel::Receiver<(Sid, Bytes)>,
238 mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, SendProtocols)>,
239 b2b_close_send_protocol_r: async_channel::Receiver<Cid>,
240 b2b_notify_send_of_recv_open_r: crossbeam_channel::Receiver<(
241 Cid,
242 Sid,
243 Prio,
244 Promises,
245 Bandwidth,
246 )>,
247 b2b_notify_send_of_recv_close_r: crossbeam_channel::Receiver<(Cid, Sid)>,
248 b2a_event_s: mpsc::UnboundedSender<ParticipantEvent>,
249 _b2s_prio_statistic_s: mpsc::UnboundedSender<B2sPrioStatistic>,
250 b2a_bandwidth_stats_s: watch::Sender<f32>,
251 ) {
252 let mut sorted_send_protocols = SortedVec::<Cid, SendProtocols>::default();
253 let mut sorted_stream_protocols = SortedVec::<Sid, Cid>::default();
254 let mut interval = tokio::time::interval(Self::TICK_TIME);
255 let mut last_instant = Instant::now();
256 let mut stream_ids = self.offset_sid;
257 let mut part_bandwidth = 0.0f32;
258 trace!("workaround, actively wait for first protocol");
259 if let Some((c, p)) = b2b_add_protocol_r.recv().await {
260 sorted_send_protocols.insert(c, p)
261 }
262 loop {
263 let (open, close, _, addp, remp) = select!(
264 Some(n) = a2b_open_stream_r.recv().fuse() => (Some(n), None, None, None, None),
265 Some(n) = a2b_close_stream_r.recv().fuse() => (None, Some(n), None, None, None),
266 _ = interval.tick() => (None, None, Some(()), None, None),
267 Some(n) = b2b_add_protocol_r.recv().fuse() => (None, None, None, Some(n), None),
268 Ok(n) = b2b_close_send_protocol_r.recv().fuse() => (None, None, None, None, Some(n)),
269 );
270
271 if let Some((cid, p)) = addp {
272 debug!(?cid, "add protocol");
273 sorted_send_protocols.insert(cid, p);
274 }
275
276 if sorted_send_protocols.data.is_empty() {
278 warn!("no channel");
279 tokio::time::sleep(Self::TICK_TIME * 1000).await; continue;
281 }
282
283 let mut cid = u64::MAX;
286
287 let active_err = async {
288 if let Some((prio, promises, guaranteed_bandwidth, return_s)) = open {
289 let sid = stream_ids;
290 stream_ids += Sid::from(1);
291 cid = Self::best_protocol(&sorted_send_protocols, promises).unwrap();
292 trace!(?sid, ?cid, "open stream");
293
294 let stream = self
295 .create_stream(sid, prio, promises, guaranteed_bandwidth)
296 .await;
297
298 let event = ProtocolEvent::OpenStream {
299 sid,
300 prio,
301 promises,
302 guaranteed_bandwidth,
303 };
304
305 sorted_stream_protocols.insert(sid, cid);
306 return_s.send(stream).unwrap();
307 sorted_send_protocols
308 .get_mut(&cid)
309 .unwrap()
310 .send(event)
311 .await?;
312 }
313
314 for (cid, sid, prio, promises, guaranteed_bandwidth) in
316 b2b_notify_send_of_recv_open_r.try_iter()
317 {
318 match sorted_send_protocols.get_mut(&cid) {
319 Some(p) => {
320 sorted_stream_protocols.insert(sid, cid);
321 p.notify_from_recv(ProtocolEvent::OpenStream {
322 sid,
323 prio,
324 promises,
325 guaranteed_bandwidth,
326 });
327 },
328 None => warn!(?cid, "couldn't notify create protocol, doesn't exist"),
329 };
330 }
331
332 for (sid, buffer) in a2b_msg_r.try_iter() {
334 cid = *sorted_stream_protocols.get(&sid).unwrap();
335 let event = ProtocolEvent::Message { data: buffer, sid };
336 sorted_send_protocols
337 .get_mut(&cid)
338 .unwrap()
339 .send(event)
340 .await?;
341 }
342
343 for (cid, sid) in b2b_notify_send_of_recv_close_r.try_iter() {
345 match sorted_send_protocols.get_mut(&cid) {
346 Some(p) => {
347 let _ = sorted_stream_protocols.delete(&sid);
348 p.notify_from_recv(ProtocolEvent::CloseStream { sid });
349 },
350 None => warn!(?cid, "couldn't notify close protocol, doesn't exist"),
351 };
352 }
353
354 if let Some(sid) = close {
355 trace!(?stream_ids, "delete stream");
356 self.delete_stream(sid).await;
357 if let Some(c) = sorted_stream_protocols.delete(&sid) {
360 cid = c;
361 let event = ProtocolEvent::CloseStream { sid };
362 sorted_send_protocols
363 .get_mut(&c)
364 .unwrap()
365 .send(event)
366 .await?;
367 }
368 }
369
370 let send_time = Instant::now();
371 let diff = send_time.duration_since(last_instant);
372 last_instant = send_time;
373 let mut cnt = 0;
374 for (c, p) in sorted_send_protocols.data.iter_mut() {
375 cid = *c;
376 cnt += p.flush(1_000_000_000, diff).await?; }
378 let flush_time = send_time.elapsed().as_secs_f32();
379 part_bandwidth = 0.99 * part_bandwidth + 0.01 * (cnt as f32 / flush_time);
380 self.metrics
381 .participant_bandwidth(&self.remote_pid_string, part_bandwidth);
382 let _ = b2a_bandwidth_stats_s.send(part_bandwidth);
383 let r: Result<(), network_protocol::ProtocolError<ProtocolsError>> = Ok(());
384 r
385 }
386 .await;
387 if let Err(e) = active_err {
388 info!(?cid, ?e, "protocol failed, shutting down channel");
389 trace!("TODO: for now decide to FAIL this participant and not wait for a failover");
392 sorted_send_protocols.delete(&cid).unwrap();
393 if let Some(info) = self.channels.write().await.get(&cid) {
394 if let Err(e) = b2a_event_s.send(ParticipantEvent::ChannelDeleted(
395 info.lock().await.remote_con_addr.clone(),
396 )) {
397 debug!(?e, "Participant was dropped during channel disconnect");
398 };
399 }
400 self.metrics.channels_disconnected(&self.remote_pid_string);
401 if sorted_send_protocols.data.is_empty() {
402 break;
403 }
404 }
405
406 if let Some(cid) = remp {
407 debug!(?cid, "remove protocol");
408 match sorted_send_protocols.delete(&cid) {
409 Some(mut prot) => {
410 if let Some(info) = self.channels.write().await.get(&cid) {
411 if let Err(e) = b2a_event_s.send(ParticipantEvent::ChannelDeleted(
412 info.lock().await.remote_con_addr.clone(),
413 )) {
414 debug!(?e, "Participant was dropped during channel disconnect");
415 };
416 }
417 self.metrics.channels_disconnected(&self.remote_pid_string);
418 trace!("blocking flush");
419 let _ = prot.flush(u64::MAX, Duration::from_secs(1)).await;
420 trace!("shutdown prot");
421 let _ = prot.send(ProtocolEvent::Shutdown).await;
422 },
423 None => trace!("tried to remove protocol twice"),
424 };
425 if sorted_send_protocols.data.is_empty() {
426 break;
427 }
428 }
429 }
430 trace!("stop sending in api!");
431 self.open_stream_channels.lock().await.take();
432 trace!("Stop send_mgr");
433 self.shutdown_barrier
434 .fetch_sub(Self::BARR_SEND, Ordering::SeqCst);
435 }
436
    /// Receiving half of the participant.
    ///
    /// For every recv protocol a task is spawned that awaits a single `recv()` and sends the
    /// result, together with the protocol itself, back to this loop; after handling the event
    /// the loop spawns the next such task ("retrigger"). The loop handles:
    ///
    /// * `OpenStream`: notifies `send_mgr`, creates the stream and publishes it on
    ///   `b2a_stream_opened_s`,
    /// * `CloseStream`: notifies `send_mgr` and deletes the stream,
    /// * `Message`: forwards the data to the stream, or records it as orphaned when the
    ///   stream is unknown,
    /// * `Shutdown` or a protocol error: asks `send_mgr` to close the channel's send half and
    ///   drops the recv half,
    /// * newly added protocols and force-close requests.
    ///
    /// The loop ends when the last recv protocol has been removed (or when none of the
    /// `select!` branches can make progress any more). Afterwards all remaining streams are
    /// marked closed and `BARR_RECV` is subtracted from the barrier.
    ///
    /// # Panics
    /// Panics if a remotely opened stream cannot be published because the receiver of
    /// `b2a_stream_opened_s` is gone.
    async fn recv_mgr(
        &self,
        b2a_stream_opened_s: mpsc::UnboundedSender<Stream>,
        mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, RecvProtocols)>,
        b2b_force_close_recv_protocol_r: async_channel::Receiver<Cid>,
        b2b_close_send_protocol_s: async_channel::Sender<Cid>,
        // NOTE(review): despite the `_r` suffix this is the *sender* towards `send_mgr`.
        b2b_notify_send_of_recv_open_r: crossbeam_channel::Sender<(
            Cid,
            Sid,
            Prio,
            Promises,
            Bandwidth,
        )>,
        b2b_notify_send_of_recv_close_s: crossbeam_channel::Sender<(Cid, Sid)>,
    ) {
        // One pending recv task per channel.
        let mut recv_protocols: HashMap<Cid, JoinHandle<()>> = HashMap::new();
        // The recv tasks report `(cid, result, protocol)` back through this channel.
        let (hacky_recv_s, mut hacky_recv_r) = mpsc::unbounded_channel();

        // Spawns a task that awaits the next event of `p` and remembers its handle in `map`,
        // replacing the handle of the previous task of that channel.
        let retrigger = |cid: Cid, mut p: RecvProtocols, map: &mut HashMap<_, _>| {
            let hacky_recv_s = hacky_recv_s.clone();
            let handle = tokio::spawn(async move {
                let r = p.recv().await;
                // Fails only if the loop below is gone already.
                let _ = hacky_recv_s.send((cid, r, p));
            });
            map.insert(cid, handle);
        };

        // Aborts and forgets the recv task of `cid`; returns `true` when no recv protocol is
        // left afterwards.
        let remove_c = |recv_protocols: &mut HashMap<Cid, JoinHandle<()>>, cid: &Cid| {
            match recv_protocols.remove(cid) {
                Some(h) => {
                    h.abort();
                    debug!(?cid, "remove protocol");
                },
                None => trace!("tried to remove protocol twice"),
            };
            recv_protocols.is_empty()
        };

        // Collects messages for unknown streams; reported in batches at the end of the loop.
        let mut defered_orphan = DeferredTracer::new(Level::WARN);

        loop {
            // Exactly one element of the tuple is `Some`, depending on what woke us up.
            let (event, addp, remp) = select!(
                Some(n) = hacky_recv_r.recv().fuse() => (Some(n), None, None),
                Some(n) = b2b_add_protocol_r.recv().fuse() => (None, Some(n), None),
                Ok(n) = b2b_force_close_recv_protocol_r.recv().fuse() => (None, None, Some(n)),
                else => {
                    error!("recv_mgr -> something is seriously wrong!, end recv_mgr");
                    break;
                }
            );

            if let Some((cid, p)) = addp {
                debug!(?cid, "add protocol");
                retrigger(cid, p, &mut recv_protocols);
            };
            if let Some(cid) = remp {
                // Force close: stop as soon as the last protocol is gone.
                if remove_c(&mut recv_protocols, &cid) {
                    break;
                }
            };

            if let Some((cid, r, p)) = event {
                match r {
                    Ok(ProtocolEvent::OpenStream {
                        sid,
                        prio,
                        promises,
                        guaranteed_bandwidth,
                    }) => {
                        trace!(?sid, "open stream");
                        // Tell send_mgr first, so it knows the channel of this stream.
                        let _ = b2b_notify_send_of_recv_open_r.send((
                            cid,
                            sid,
                            prio,
                            promises,
                            guaranteed_bandwidth,
                        ));
                        let stream = self
                            .create_stream(sid, prio, promises, guaranteed_bandwidth)
                            .await;
                        b2a_stream_opened_s.send(stream).unwrap();
                        retrigger(cid, p, &mut recv_protocols);
                    },
                    Ok(ProtocolEvent::CloseStream { sid }) => {
                        trace!(?sid, "close stream");
                        let _ = b2b_notify_send_of_recv_close_s.send((cid, sid));
                        self.delete_stream(sid).await;
                        retrigger(cid, p, &mut recv_protocols);
                    },
                    Ok(ProtocolEvent::Message { data, sid }) => {
                        let lock = self.streams.read().await;
                        match lock.get(&sid) {
                            Some(stream) => {
                                // A send error (stream receiver gone) is ignored.
                                let _ = stream.b2a_msg_recv_s.lock().await.send(data).await;
                            },
                            None => defered_orphan.log(sid),
                        };
                        retrigger(cid, p, &mut recv_protocols);
                    },
                    Ok(ProtocolEvent::Shutdown) => {
                        info!(?cid, "shutdown protocol");
                        if let Err(e) = b2b_close_send_protocol_s.send(cid).await {
                            debug!(?e, ?cid, "send_mgr was already closed simultaneously");
                        }
                        // The protocol is not retriggered; `p` is dropped here.
                        if remove_c(&mut recv_protocols, &cid) {
                            break;
                        }
                    },
                    Err(e) => {
                        info!(?e, ?cid, "protocol failed, shutting down channel");
                        if let Err(e) = b2b_close_send_protocol_s.send(cid).await {
                            debug!(?e, ?cid, "send_mgr was already closed simultaneously");
                        }
                        if remove_c(&mut recv_protocols, &cid) {
                            break;
                        }
                    },
                }
            }

            if let Some(table) = defered_orphan.print() {
                for (sid, cnt) in table.iter() {
                    warn!(?sid, ?cnt, "recv messages with orphan stream");
                }
            }
        }
        trace!("receiving no longer possible, closing all streams");
        for (_, si) in self.streams.write().await.drain() {
            si.send_closed.store(true, Ordering::SeqCst);
            self.metrics.streams_closed(&self.remote_pid_string);
        }
        trace!("Stop recv_mgr");
        self.shutdown_barrier
            .fetch_sub(Self::BARR_RECV, Ordering::SeqCst);
    }
576
577 async fn create_channel_mgr(
578 &self,
579 s2b_create_channel_r: mpsc::UnboundedReceiver<S2bCreateChannel>,
580 b2b_add_send_protocol_s: mpsc::UnboundedSender<(Cid, SendProtocols)>,
581 b2b_add_recv_protocol_s: mpsc::UnboundedSender<(Cid, RecvProtocols)>,
582 b2a_event_s: mpsc::UnboundedSender<ParticipantEvent>,
583 ) {
584 let s2b_create_channel_r = UnboundedReceiverStream::new(s2b_create_channel_r);
585 s2b_create_channel_r
586 .for_each_concurrent(
587 None,
588 |(cid, _, protocol, remote_con_addr, b2s_create_channel_done_s)| {
589 let channels = Arc::clone(&self.channels);
592 let b2b_add_send_protocol_s = b2b_add_send_protocol_s.clone();
593 let b2b_add_recv_protocol_s = b2b_add_recv_protocol_s.clone();
594 let b2a_event_s = b2a_event_s.clone();
595 async move {
596 let mut lock = channels.write().await;
597 let mut channel_no = lock.len();
598 lock.insert(
599 cid,
600 Mutex::new(ChannelInfo {
601 cid,
602 cid_string: cid.to_string(),
603 remote_con_addr: remote_con_addr.clone(),
604 }),
605 );
606 drop(lock);
607 let (send, recv) = protocol.split();
608 b2b_add_send_protocol_s.send((cid, send)).unwrap();
609 b2b_add_recv_protocol_s.send((cid, recv)).unwrap();
610 if let Err(e) =
611 b2a_event_s.send(ParticipantEvent::ChannelCreated(remote_con_addr))
612 {
613 debug!(?e, "Participant was dropped during channel connect");
614 };
615 b2s_create_channel_done_s.send(()).unwrap();
616 if channel_no > 5 {
617 debug!(?channel_no, "metrics will overwrite channel #5");
618 channel_no = 5;
619 }
620 self.metrics
621 .channels_connected(&self.remote_pid_string, channel_no, cid);
622 }
623 },
624 )
625 .await;
626 trace!("Stop create_channel_mgr");
627 self.shutdown_barrier
628 .fetch_sub(Self::BARR_CHANNEL, Ordering::SeqCst);
629 }
630
631 async fn participant_shutdown_mgr(
652 &self,
653 s2b_shutdown_bparticipant_r: oneshot::Receiver<S2bShutdownBparticipant>,
654 b2b_close_send_protocol_s: async_channel::Sender<Cid>,
655 b2b_force_close_recv_protocol_s: async_channel::Sender<Cid>,
656 ) {
657 let wait_for_manager = || async {
658 let mut sleep = 0.01f64;
659 loop {
660 let bytes = self.shutdown_barrier.load(Ordering::SeqCst);
661 if bytes == 0 {
662 break;
663 }
664 sleep *= 1.4;
665 tokio::time::sleep(Duration::from_secs_f64(sleep)).await;
666 if sleep > 0.2 {
667 trace!(?bytes, "wait for mgr to close");
668 }
669 }
670 };
671
672 let awaited = s2b_shutdown_bparticipant_r.await.ok();
673 debug!("participant_shutdown_mgr triggered. Closing all streams for send");
674 {
675 let lock = self.streams.read().await;
676 for si in lock.values() {
677 si.send_closed.store(true, Ordering::SeqCst);
678 }
679 }
680
681 let lock = self.channels.read().await;
682 assert!(
683 !lock.is_empty(),
684 "no channel existed remote_pid={}",
685 self.remote_pid
686 );
687 for cid in lock.keys() {
688 if let Err(e) = b2b_close_send_protocol_s.send(*cid).await {
689 debug!(
690 ?e,
691 ?cid,
692 "closing send_mgr may fail if we got a recv error simultaneously"
693 );
694 }
695 }
696 drop(lock);
697
698 trace!("wait for other managers");
699 let timeout = tokio::time::sleep(
700 awaited
701 .as_ref()
702 .map(|(timeout_time, _)| *timeout_time)
703 .unwrap_or_default(),
704 );
705 let timeout = select! {
706 _ = wait_for_manager() => false,
707 _ = timeout => true,
708 };
709 if timeout {
710 warn!("timeout triggered: for killing recv");
711 let lock = self.channels.read().await;
712 for cid in lock.keys() {
713 if let Err(e) = b2b_force_close_recv_protocol_s.send(*cid).await {
714 debug!(
715 ?e,
716 ?cid,
717 "closing recv_mgr may fail if we got a recv error simultaneously"
718 );
719 }
720 }
721 }
722
723 trace!("wait again");
724 wait_for_manager().await;
725
726 if let Some((_, sender)) = awaited {
727 let _ = sender.send(Ok(()));
730 }
731
732 #[cfg(feature = "metrics")]
733 self.metrics.participants_disconnected_total.inc();
734 self.metrics.cleanup_participant(&self.remote_pid_string);
735 trace!("Stop participant_shutdown_mgr");
736 }
737
738 async fn delete_stream(&self, sid: Sid) {
741 let stream = { self.streams.write().await.remove(&sid) };
742 match stream {
743 Some(si) => {
744 si.send_closed.store(true, Ordering::SeqCst);
745 si.b2a_msg_recv_s.lock().await.close();
746 },
747 None => {
748 trace!("Couldn't find the stream, might be simultaneous close from local/remote")
749 },
750 }
751 self.metrics.streams_closed(&self.remote_pid_string);
752 }
753
754 async fn create_stream(
755 &self,
756 sid: Sid,
757 prio: Prio,
758 promises: Promises,
759 guaranteed_bandwidth: Bandwidth,
760 ) -> Stream {
761 let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::<Bytes>();
762 let send_closed = Arc::new(AtomicBool::new(false));
763 self.streams.write().await.insert(sid, StreamInfo {
764 prio,
765 promises,
766 send_closed: Arc::clone(&send_closed),
767 b2a_msg_recv_s: Mutex::new(b2a_msg_recv_s),
768 });
769 self.metrics.streams_opened(&self.remote_pid_string);
770
771 let (a2b_msg_s, a2b_close_stream_s) = {
772 let lock = self.open_stream_channels.lock().await;
773 match &*lock {
774 Some(osi) => (osi.a2b_msg_s.clone(), osi.a2b_close_stream_s.clone()),
775 None => {
776 debug!(
778 "It seems that a stream was requested to open, while the send_mgr is \
779 already closed"
780 );
781 let (a2b_msg_s, _) = crossbeam_channel::unbounded();
782 let (a2b_close_stream_s, _) = mpsc::unbounded_channel();
783 (a2b_msg_s, a2b_close_stream_s)
784 },
785 }
786 };
787
788 Stream::new(
789 self.local_pid,
790 self.remote_pid,
791 sid,
792 prio,
793 promises,
794 guaranteed_bandwidth,
795 send_closed,
796 a2b_msg_s,
797 b2a_msg_recv_r,
798 a2b_close_stream_s,
799 )
800 }
801}
802
803#[cfg(test)]
804mod tests {
805 use super::*;
806 use core::assert_matches::assert_matches;
807 use network_protocol::{ProtocolMetricCache, ProtocolMetrics};
808 use tokio::{
809 runtime::Runtime,
810 sync::{mpsc, oneshot},
811 task::JoinHandle,
812 };
813
814 fn mock_bparticipant() -> (
815 Arc<Runtime>,
816 mpsc::UnboundedSender<A2bStreamOpen>,
817 mpsc::UnboundedReceiver<Stream>,
818 mpsc::UnboundedReceiver<ParticipantEvent>,
819 mpsc::UnboundedSender<S2bCreateChannel>,
820 oneshot::Sender<S2bShutdownBparticipant>,
821 mpsc::UnboundedReceiver<B2sPrioStatistic>,
822 watch::Receiver<f32>,
823 JoinHandle<()>,
824 ) {
825 let runtime = Arc::new(Runtime::new().unwrap());
826 let runtime_clone = Arc::clone(&runtime);
827
828 let (b2s_prio_statistic_s, b2s_prio_statistic_r) =
829 mpsc::unbounded_channel::<B2sPrioStatistic>();
830
831 let (
832 bparticipant,
833 a2b_open_stream_s,
834 b2a_stream_opened_r,
835 b2a_event_r,
836 s2b_create_channel_s,
837 s2b_shutdown_bparticipant_s,
838 b2a_bandwidth_stats_r,
839 ) = runtime_clone.block_on(async move {
840 let local_pid = Pid::fake(0);
841 let remote_pid = Pid::fake(1);
842 let sid = Sid::new(1000);
843 let metrics = Arc::new(NetworkMetrics::new(&local_pid).unwrap());
844
845 BParticipant::new(local_pid, remote_pid, sid, Arc::clone(&metrics))
846 });
847
848 let handle = runtime_clone.spawn(bparticipant.run(b2s_prio_statistic_s));
849 (
850 runtime_clone,
851 a2b_open_stream_s,
852 b2a_stream_opened_r,
853 b2a_event_r,
854 s2b_create_channel_s,
855 s2b_shutdown_bparticipant_s,
856 b2s_prio_statistic_r,
857 b2a_bandwidth_stats_r,
858 handle,
859 )
860 }
861
862 #[expect(clippy::needless_pass_by_ref_mut)]
863 async fn mock_mpsc(
864 cid: Cid,
865 _runtime: &Arc<Runtime>,
866 create_channel: &mut mpsc::UnboundedSender<S2bCreateChannel>,
867 ) -> Protocols {
868 let (s1, r1) = mpsc::channel(100);
869 let (s2, r2) = mpsc::channel(100);
870 let met = Arc::new(ProtocolMetrics::new().unwrap());
871 let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&met));
872 let p1 = Protocols::new_mpsc(s1, r2, metrics);
873 let (complete_s, complete_r) = oneshot::channel();
874 create_channel
875 .send((cid, Sid::new(0), p1, ConnectAddr::Mpsc(42), complete_s))
876 .unwrap();
877 complete_r.await.unwrap();
878 let metrics = ProtocolMetricCache::new(&cid.to_string(), met);
879 Protocols::new_mpsc(s2, r1, metrics)
880 }
881
    /// Shutdown with a 1s timeout while the remote end of the channel stays alive: the
    /// shutdown must take at least ~900ms, i.e. it only completes through the timeout path
    /// (presumably because the recv side never sees the remote closing — confirm).
    /// Exactly one `ChannelCreated` and one `ChannelDeleted` event are expected.
    #[test]
    fn close_bparticipant_by_timeout_during_close() {
        let (
            runtime,
            a2b_open_stream_s,
            b2a_stream_opened_r,
            mut b2a_event_r,
            mut s2b_create_channel_s,
            s2b_shutdown_bparticipant_s,
            b2s_prio_statistic_r,
            _b2a_bandwidth_stats_r,
            handle,
        ) = mock_bparticipant();

        // `_remote` is kept alive until the end of the test.
        let _remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s));
        // Give the managers time to pick up the new channel.
        std::thread::sleep(Duration::from_millis(50));

        let (s, r) = oneshot::channel();
        let before = Instant::now();
        runtime.block_on(async {
            // Closing the request channel lets `create_channel_mgr` finish.
            drop(s2b_create_channel_s);
            s2b_shutdown_bparticipant_s
                .send((Duration::from_secs(1), s))
                .unwrap();
            r.await.unwrap().unwrap();
        });
        assert!(
            before.elapsed() > Duration::from_millis(900),
            "timeout wasn't triggered"
        );
        assert_matches!(
            b2a_event_r.try_recv().unwrap(),
            ParticipantEvent::ChannelCreated(_)
        );
        assert_matches!(
            b2a_event_r.try_recv().unwrap(),
            ParticipantEvent::ChannelDeleted(_)
        );
        assert_matches!(b2a_event_r.try_recv(), Err(_));

        // `run` itself must have finished without panicking.
        runtime.block_on(handle).unwrap();

        drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r));
        drop(runtime);
    }
927
    /// Shutdown with a 2s timeout where the remote end of the channel is dropped right after
    /// the shutdown request: the shutdown must finish in under ~1.9s, i.e. without running
    /// into the timeout. Exactly one `ChannelCreated` and one `ChannelDeleted` event are
    /// expected.
    #[test]
    fn close_bparticipant_cleanly() {
        let (
            runtime,
            a2b_open_stream_s,
            b2a_stream_opened_r,
            mut b2a_event_r,
            mut s2b_create_channel_s,
            s2b_shutdown_bparticipant_s,
            b2s_prio_statistic_r,
            _b2a_bandwidth_stats_r,
            handle,
        ) = mock_bparticipant();

        let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s));
        // Give the managers time to pick up the new channel.
        std::thread::sleep(Duration::from_millis(50));

        let (s, r) = oneshot::channel();
        let before = Instant::now();
        runtime.block_on(async {
            // Closing the request channel lets `create_channel_mgr` finish.
            drop(s2b_create_channel_s);
            s2b_shutdown_bparticipant_s
                .send((Duration::from_secs(2), s))
                .unwrap();
            // Dropping the remote end is what distinguishes this test from the timeout test.
            drop(remote);
            r.await.unwrap().unwrap();
        });
        assert!(
            before.elapsed() < Duration::from_millis(1900),
            "timeout was triggered"
        );
        assert_matches!(
            b2a_event_r.try_recv().unwrap(),
            ParticipantEvent::ChannelCreated(_)
        );
        assert_matches!(
            b2a_event_r.try_recv().unwrap(),
            ParticipantEvent::ChannelDeleted(_)
        );
        assert_matches!(b2a_event_r.try_recv(), Err(_));

        // `run` itself must have finished without panicking.
        runtime.block_on(handle).unwrap();

        drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r));
        drop(runtime);
    }
974
    /// Opening a stream locally must send an `OpenStream` event to the remote side that
    /// carries the first stream id (1000) and the requested prio, promises and bandwidth.
    #[test]
    fn create_stream() {
        let (
            runtime,
            a2b_open_stream_s,
            b2a_stream_opened_r,
            _b2a_event_r,
            mut s2b_create_channel_s,
            s2b_shutdown_bparticipant_s,
            b2s_prio_statistic_r,
            _b2a_bandwidth_stats_r,
            handle,
        ) = mock_bparticipant();

        let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s));
        // Give the managers time to pick up the new channel.
        std::thread::sleep(Duration::from_millis(50));

        let (rs, mut rr) = remote.split();
        // The receiver must stay alive: `send_mgr` unwraps the send of the created stream.
        let (stream_sender, _stream_receiver) = oneshot::channel();
        a2b_open_stream_s
            .send((7u8, Promises::ENCRYPTED, 1_000_000, stream_sender))
            .unwrap();

        let stream_event = runtime.block_on(rr.recv()).unwrap();
        match stream_event {
            ProtocolEvent::OpenStream {
                sid,
                prio,
                promises,
                guaranteed_bandwidth,
            } => {
                assert_eq!(sid, Sid::new(1000));
                assert_eq!(prio, 7u8);
                assert_eq!(promises, Promises::ENCRYPTED);
                assert_eq!(guaranteed_bandwidth, 1_000_000);
            },
            _ => panic!("wrong event"),
        };

        let (s, r) = oneshot::channel();
        runtime.block_on(async {
            drop(s2b_create_channel_s);
            s2b_shutdown_bparticipant_s
                .send((Duration::from_secs(1), s))
                .unwrap();
            // Drop both remote halves so the shutdown does not have to wait for the timeout.
            drop((rs, rr));
            r.await.unwrap().unwrap();
        });

        // `run` itself must have finished without panicking.
        runtime.block_on(handle).unwrap();

        drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r));
        drop(runtime);
    }
1030
    /// An `OpenStream` event sent by the remote side must surface locally as a `Stream` on
    /// `b2a_stream_opened_r` with the promises the remote side asked for.
    #[test]
    fn created_stream() {
        let (
            runtime,
            a2b_open_stream_s,
            mut b2a_stream_opened_r,
            _b2a_event_r,
            mut s2b_create_channel_s,
            s2b_shutdown_bparticipant_s,
            b2s_prio_statistic_r,
            _b2a_bandwidth_stats_r,
            handle,
        ) = mock_bparticipant();

        let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s));
        // Give the managers time to pick up the new channel.
        std::thread::sleep(Duration::from_millis(50));

        let (mut rs, rr) = remote.split();
        runtime
            .block_on(rs.send(ProtocolEvent::OpenStream {
                sid: Sid::new(1000),
                prio: 9u8,
                promises: Promises::ORDERED,
                guaranteed_bandwidth: 1_000_000,
            }))
            .unwrap();

        let stream = runtime.block_on(b2a_stream_opened_r.recv()).unwrap();
        assert_eq!(stream.params().promises, Promises::ORDERED);

        let (s, r) = oneshot::channel();
        runtime.block_on(async {
            drop(s2b_create_channel_s);
            s2b_shutdown_bparticipant_s
                .send((Duration::from_secs(1), s))
                .unwrap();
            // Drop both remote halves so the shutdown does not have to wait for the timeout.
            drop((rs, rr));
            r.await.unwrap().unwrap();
        });

        // `run` itself must have finished without panicking.
        runtime.block_on(handle).unwrap();

        drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r));
        drop(runtime);
    }
1076 }
1077}