// veloren_network/participant.rs

1use crate::{
2    api::{ConnectAddr, ParticipantError, ParticipantEvent, Stream},
3    channel::{Protocols, ProtocolsError, RecvProtocols, SendProtocols},
4    metrics::NetworkMetrics,
5    util::DeferredTracer,
6};
7use bytes::Bytes;
8use futures_util::{FutureExt, StreamExt};
9use hashbrown::HashMap;
10use network_protocol::{
11    _internal::SortedVec, Bandwidth, Cid, Pid, Prio, Promises, ProtocolEvent, RecvProtocol,
12    SendProtocol, Sid,
13};
14use std::{
15    sync::{
16        Arc,
17        atomic::{AtomicBool, AtomicI32, Ordering},
18    },
19    time::{Duration, Instant},
20};
21use tokio::{
22    select,
23    sync::{Mutex, RwLock, mpsc, oneshot, watch},
24    task::JoinHandle,
25};
26use tokio_stream::wrappers::UnboundedReceiverStream;
27use tracing::*;
28
29pub(crate) type A2bStreamOpen = (Prio, Promises, Bandwidth, oneshot::Sender<Stream>);
30pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, ConnectAddr, oneshot::Sender<()>);
31pub(crate) type S2bShutdownBparticipant = (Duration, oneshot::Sender<Result<(), ParticipantError>>);
32pub(crate) type B2sPrioStatistic = (Pid, u64, u64);
33
/// Bookkeeping for one channel of this participant, stored per `Cid` in
/// `BParticipant::channels` by `create_channel_mgr`.
#[derive(Debug)]
#[expect(dead_code)]
struct ChannelInfo {
    cid: Cid,
    // cached `cid.to_string()` (optimisation for metrics); not read in the visible code
    cid_string: String,
    /// Remote address; reported in `ParticipantEvent::ChannelDeleted` when the channel goes away.
    remote_con_addr: ConnectAddr,
}
41
/// Participant-side state of one open stream, stored per `Sid` in `BParticipant::streams`.
#[derive(Debug)]
struct StreamInfo {
    #[expect(dead_code)]
    prio: Prio,
    #[expect(dead_code)]
    promises: Promises,
    /// Set to `true` once the stream may no longer send (stream deleted or participant
    /// shutting down). NOTE(review): presumably shared with the API-side `Stream` via the
    /// `Arc` — confirm in `create_stream`.
    send_closed: Arc<AtomicBool>,
    /// Delivers received message payloads towards the API side; closed by `delete_stream`.
    b2a_msg_recv_s: Mutex<async_channel::Sender<Bytes>>,
}
51
/// Manager-facing ends of the channels created in `BParticipant::new`; taken
/// once by `BParticipant::run` and distributed to the managers.
#[derive(Debug)]
struct ControlChannels {
    /// Stream-open requests from the API (consumed by `send_mgr`).
    a2b_open_stream_r: mpsc::UnboundedReceiver<A2bStreamOpen>,
    /// Hands remotely opened streams to the API (used by `recv_mgr`).
    b2a_stream_opened_s: mpsc::UnboundedSender<Stream>,
    /// Channel created/deleted events for the API.
    b2a_event_s: mpsc::UnboundedSender<ParticipantEvent>,
    /// New channels from the scheduler (consumed by `create_channel_mgr`).
    s2b_create_channel_r: mpsc::UnboundedReceiver<S2bCreateChannel>,
    /// Smoothed send bandwidth estimate published by `send_mgr`.
    b2a_bandwidth_stats_s: watch::Sender<f32>,
    /// Shutdown request (consumed by `participant_shutdown_mgr`).
    s2b_shutdown_bparticipant_r: oneshot::Receiver<S2bShutdownBparticipant>, /* own */
}
61
/// Senders through which the API pushes data into `send_mgr`; installed by
/// `BParticipant::run` and removed again when `send_mgr` stops.
#[derive(Debug)]
struct OpenStreamInfo {
    /// Outgoing message payloads, tagged with their stream id.
    a2b_msg_s: crossbeam_channel::Sender<(Sid, Bytes)>,
    /// Local stream-close requests.
    a2b_close_stream_s: mpsc::UnboundedSender<Sid>,
}
67
/// Backend of a participant: owns all channels and streams to one remote `Pid`
/// and runs the send/recv/create-channel/shutdown managers (see `run`).
#[derive(Debug)]
pub struct BParticipant {
    local_pid: Pid, //tracing
    remote_pid: Pid,
    // cached `remote_pid.to_string()`, used as the metrics label
    remote_pid_string: String,
    /// First `Sid` used for locally opened streams; `send_mgr` counts up from here.
    offset_sid: Sid,
    /// All registered channels, written by `create_channel_mgr`.
    channels: Arc<RwLock<HashMap<Cid, Mutex<ChannelInfo>>>>,
    /// All currently open streams.
    streams: RwLock<HashMap<Sid, StreamInfo>>,
    /// `Some` until `run` takes the control channels.
    run_channels: Option<ControlChannels>,
    /// Sum of the `BARR_*` weights of managers that are still running; 0 means all stopped.
    shutdown_barrier: AtomicI32,
    metrics: Arc<NetworkMetrics>,
    /// API-facing senders into `send_mgr`; `Some` while `send_mgr` accepts data.
    open_stream_channels: Arc<Mutex<Option<OpenStreamInfo>>>,
}
81
82impl BParticipant {
83    // We use integer instead of Barrier to not block mgr from freeing at the end
84    const BARR_CHANNEL: i32 = 1;
85    const BARR_RECV: i32 = 4;
86    const BARR_SEND: i32 = 2;
87    const TICK_TIME: Duration = Duration::from_millis(Self::TICK_TIME_MS);
88    const TICK_TIME_MS: u64 = 5;
89
90    pub(crate) fn new(
91        local_pid: Pid,
92        remote_pid: Pid,
93        offset_sid: Sid,
94        metrics: Arc<NetworkMetrics>,
95    ) -> (
96        Self,
97        mpsc::UnboundedSender<A2bStreamOpen>,
98        mpsc::UnboundedReceiver<Stream>,
99        mpsc::UnboundedReceiver<ParticipantEvent>,
100        mpsc::UnboundedSender<S2bCreateChannel>,
101        oneshot::Sender<S2bShutdownBparticipant>,
102        watch::Receiver<f32>,
103    ) {
104        let (a2b_open_stream_s, a2b_open_stream_r) = mpsc::unbounded_channel::<A2bStreamOpen>();
105        let (b2a_stream_opened_s, b2a_stream_opened_r) = mpsc::unbounded_channel::<Stream>();
106        let (b2a_event_s, b2a_event_r) = mpsc::unbounded_channel::<ParticipantEvent>();
107        let (s2b_shutdown_bparticipant_s, s2b_shutdown_bparticipant_r) = oneshot::channel();
108        let (s2b_create_channel_s, s2b_create_channel_r) = mpsc::unbounded_channel();
109        let (b2a_bandwidth_stats_s, b2a_bandwidth_stats_r) = watch::channel::<f32>(0.0);
110
111        let run_channels = Some(ControlChannels {
112            a2b_open_stream_r,
113            b2a_stream_opened_s,
114            b2a_event_s,
115            s2b_create_channel_r,
116            b2a_bandwidth_stats_s,
117            s2b_shutdown_bparticipant_r,
118        });
119
120        (
121            Self {
122                local_pid,
123                remote_pid,
124                remote_pid_string: remote_pid.to_string(),
125                offset_sid,
126                channels: Arc::new(RwLock::new(HashMap::new())),
127                streams: RwLock::new(HashMap::new()),
128                shutdown_barrier: AtomicI32::new(
129                    Self::BARR_CHANNEL + Self::BARR_SEND + Self::BARR_RECV,
130                ),
131                run_channels,
132                metrics,
133                open_stream_channels: Arc::new(Mutex::new(None)),
134            },
135            a2b_open_stream_s,
136            b2a_stream_opened_r,
137            b2a_event_r,
138            s2b_create_channel_s,
139            s2b_shutdown_bparticipant_s,
140            b2a_bandwidth_stats_r,
141        )
142    }
143
144    pub async fn run(mut self, b2s_prio_statistic_s: mpsc::UnboundedSender<B2sPrioStatistic>) {
145        let (b2b_add_send_protocol_s, b2b_add_send_protocol_r) =
146            mpsc::unbounded_channel::<(Cid, SendProtocols)>();
147        let (b2b_add_recv_protocol_s, b2b_add_recv_protocol_r) =
148            mpsc::unbounded_channel::<(Cid, RecvProtocols)>();
149        let (b2b_close_send_protocol_s, b2b_close_send_protocol_r) =
150            async_channel::unbounded::<Cid>();
151        let (b2b_force_close_recv_protocol_s, b2b_force_close_recv_protocol_r) =
152            async_channel::unbounded::<Cid>();
153        let (b2b_notify_send_of_recv_open_s, b2b_notify_send_of_recv_open_r) =
154            crossbeam_channel::unbounded::<(Cid, Sid, Prio, Promises, u64)>();
155        let (b2b_notify_send_of_recv_close_s, b2b_notify_send_of_recv_close_r) =
156            crossbeam_channel::unbounded::<(Cid, Sid)>();
157
158        let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel::<Sid>();
159        let (a2b_msg_s, a2b_msg_r) = crossbeam_channel::unbounded::<(Sid, Bytes)>();
160
161        *self.open_stream_channels.lock().await = Some(OpenStreamInfo {
162            a2b_msg_s,
163            a2b_close_stream_s,
164        });
165        let run_channels = self.run_channels.take().unwrap();
166        trace!("start all managers");
167        tokio::join!(
168            self.send_mgr(
169                run_channels.a2b_open_stream_r,
170                a2b_close_stream_r,
171                a2b_msg_r,
172                b2b_add_send_protocol_r,
173                b2b_close_send_protocol_r,
174                b2b_notify_send_of_recv_open_r,
175                b2b_notify_send_of_recv_close_r,
176                run_channels.b2a_event_s.clone(),
177                b2s_prio_statistic_s,
178                run_channels.b2a_bandwidth_stats_s,
179            )
180            .instrument(tracing::info_span!("send")),
181            self.recv_mgr(
182                run_channels.b2a_stream_opened_s,
183                b2b_add_recv_protocol_r,
184                b2b_force_close_recv_protocol_r,
185                b2b_close_send_protocol_s.clone(),
186                b2b_notify_send_of_recv_open_s,
187                b2b_notify_send_of_recv_close_s,
188            )
189            .instrument(tracing::info_span!("recv")),
190            self.create_channel_mgr(
191                run_channels.s2b_create_channel_r,
192                b2b_add_send_protocol_s,
193                b2b_add_recv_protocol_s,
194                run_channels.b2a_event_s,
195            ),
196            self.participant_shutdown_mgr(
197                run_channels.s2b_shutdown_bparticipant_r,
198                b2b_close_send_protocol_s.clone(),
199                b2b_force_close_recv_protocol_s,
200            ),
201        );
202    }
203
204    fn best_protocol(all: &SortedVec<Cid, SendProtocols>, promises: Promises) -> Option<Cid> {
205        // check for mpsc
206        all.data.iter().find(|(_, p)| matches!(p, SendProtocols::Mpsc(_))).map(|(c, _)| *c).or_else(
207            || if network_protocol::TcpSendProtocol::<crate::channel::TcpDrain>::supported_promises()
208                .contains(promises)
209            {
210                // check for tcp
211                all.data.iter().find(|(_, p)| matches!(p, SendProtocols::Tcp(_))).map(|(c, _)| *c)
212            } else {
213                None
214            }
215        ).or_else(
216            // check for quic, TODO: evaluate to order quic BEFORE tcp once its stable
217            || if network_protocol::QuicSendProtocol::<crate::channel::QuicDrain>::supported_promises()
218                .contains(promises)
219            {
220                all.data.iter().find(|(_, p)| matches!(p, SendProtocols::Quic(_))).map(|(c, _)| *c)
221            } else {
222                None
223            }
224        ).or_else(
225            || {
226                warn!("couldn't satisfy promises");
227                all.data.first().map(|(c, _)| *c)
228            }
229        )
230    }
231
232    //TODO: local stream_cid: HashMap<Sid, Cid> to know the respective protocol
233    async fn send_mgr(
234        &self,
235        mut a2b_open_stream_r: mpsc::UnboundedReceiver<A2bStreamOpen>,
236        mut a2b_close_stream_r: mpsc::UnboundedReceiver<Sid>,
237        a2b_msg_r: crossbeam_channel::Receiver<(Sid, Bytes)>,
238        mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, SendProtocols)>,
239        b2b_close_send_protocol_r: async_channel::Receiver<Cid>,
240        b2b_notify_send_of_recv_open_r: crossbeam_channel::Receiver<(
241            Cid,
242            Sid,
243            Prio,
244            Promises,
245            Bandwidth,
246        )>,
247        b2b_notify_send_of_recv_close_r: crossbeam_channel::Receiver<(Cid, Sid)>,
248        b2a_event_s: mpsc::UnboundedSender<ParticipantEvent>,
249        _b2s_prio_statistic_s: mpsc::UnboundedSender<B2sPrioStatistic>,
250        b2a_bandwidth_stats_s: watch::Sender<f32>,
251    ) {
252        let mut sorted_send_protocols = SortedVec::<Cid, SendProtocols>::default();
253        let mut sorted_stream_protocols = SortedVec::<Sid, Cid>::default();
254        let mut interval = tokio::time::interval(Self::TICK_TIME);
255        let mut last_instant = Instant::now();
256        let mut stream_ids = self.offset_sid;
257        let mut part_bandwidth = 0.0f32;
258        trace!("workaround, actively wait for first protocol");
259        if let Some((c, p)) = b2b_add_protocol_r.recv().await {
260            sorted_send_protocols.insert(c, p)
261        }
262        loop {
263            let (open, close, _, addp, remp) = select!(
264                Some(n) = a2b_open_stream_r.recv().fuse() => (Some(n), None, None, None, None),
265                Some(n) = a2b_close_stream_r.recv().fuse() => (None, Some(n), None, None, None),
266                _ = interval.tick() => (None, None, Some(()), None, None),
267                Some(n) = b2b_add_protocol_r.recv().fuse() => (None, None, None, Some(n), None),
268                Ok(n) = b2b_close_send_protocol_r.recv().fuse() => (None, None, None, None, Some(n)),
269            );
270
271            if let Some((cid, p)) = addp {
272                debug!(?cid, "add protocol");
273                sorted_send_protocols.insert(cid, p);
274            }
275
276            //verify that we have at LEAST 1 channel before continuing
277            if sorted_send_protocols.data.is_empty() {
278                warn!("no channel");
279                tokio::time::sleep(Self::TICK_TIME * 1000).await; //TODO: failover
280                continue;
281            }
282
283            //let (cid, active) = sorted_send_protocols.data.iter_mut().next().unwrap();
284            //used for error handling
285            let mut cid = u64::MAX;
286
287            let active_err = async {
288                if let Some((prio, promises, guaranteed_bandwidth, return_s)) = open {
289                    let sid = stream_ids;
290                    stream_ids += Sid::from(1);
291                    cid = Self::best_protocol(&sorted_send_protocols, promises).unwrap();
292                    trace!(?sid, ?cid, "open stream");
293
294                    let stream = self
295                        .create_stream(sid, prio, promises, guaranteed_bandwidth)
296                        .await;
297
298                    let event = ProtocolEvent::OpenStream {
299                        sid,
300                        prio,
301                        promises,
302                        guaranteed_bandwidth,
303                    };
304
305                    sorted_stream_protocols.insert(sid, cid);
306                    return_s.send(stream).unwrap();
307                    sorted_send_protocols
308                        .get_mut(&cid)
309                        .unwrap()
310                        .send(event)
311                        .await?;
312                }
313
314                // process recv content first
315                for (cid, sid, prio, promises, guaranteed_bandwidth) in
316                    b2b_notify_send_of_recv_open_r.try_iter()
317                {
318                    match sorted_send_protocols.get_mut(&cid) {
319                        Some(p) => {
320                            sorted_stream_protocols.insert(sid, cid);
321                            p.notify_from_recv(ProtocolEvent::OpenStream {
322                                sid,
323                                prio,
324                                promises,
325                                guaranteed_bandwidth,
326                            });
327                        },
328                        None => warn!(?cid, "couldn't notify create protocol, doesn't exist"),
329                    };
330                }
331
332                // get all messages and assign it to a channel
333                for (sid, buffer) in a2b_msg_r.try_iter() {
334                    cid = *sorted_stream_protocols.get(&sid).unwrap();
335                    let event = ProtocolEvent::Message { data: buffer, sid };
336                    sorted_send_protocols
337                        .get_mut(&cid)
338                        .unwrap()
339                        .send(event)
340                        .await?;
341                }
342
343                // process recv content afterwards
344                for (cid, sid) in b2b_notify_send_of_recv_close_r.try_iter() {
345                    match sorted_send_protocols.get_mut(&cid) {
346                        Some(p) => {
347                            let _ = sorted_stream_protocols.delete(&sid);
348                            p.notify_from_recv(ProtocolEvent::CloseStream { sid });
349                        },
350                        None => warn!(?cid, "couldn't notify close protocol, doesn't exist"),
351                    };
352                }
353
354                if let Some(sid) = close {
355                    trace!(?stream_ids, "delete stream");
356                    self.delete_stream(sid).await;
357                    // Fire&Forget the protocol will take care to verify that this Frame is delayed
358                    // till the last msg was received!
359                    if let Some(c) = sorted_stream_protocols.delete(&sid) {
360                        cid = c;
361                        let event = ProtocolEvent::CloseStream { sid };
362                        sorted_send_protocols
363                            .get_mut(&c)
364                            .unwrap()
365                            .send(event)
366                            .await?;
367                    }
368                }
369
370                let send_time = Instant::now();
371                let diff = send_time.duration_since(last_instant);
372                last_instant = send_time;
373                let mut cnt = 0;
374                for (c, p) in sorted_send_protocols.data.iter_mut() {
375                    cid = *c;
376                    cnt += p.flush(1_000_000_000, diff).await?; //this actually blocks, so we cant set streams while it.
377                }
378                let flush_time = send_time.elapsed().as_secs_f32();
379                part_bandwidth = 0.99 * part_bandwidth + 0.01 * (cnt as f32 / flush_time);
380                self.metrics
381                    .participant_bandwidth(&self.remote_pid_string, part_bandwidth);
382                let _ = b2a_bandwidth_stats_s.send(part_bandwidth);
383                let r: Result<(), network_protocol::ProtocolError<ProtocolsError>> = Ok(());
384                r
385            }
386            .await;
387            if let Err(e) = active_err {
388                info!(?cid, ?e, "protocol failed, shutting down channel");
389                // remote recv will now fail, which will trigger remote send which will trigger
390                // recv
391                trace!("TODO: for now decide to FAIL this participant and not wait for a failover");
392                sorted_send_protocols.delete(&cid).unwrap();
393                if let Some(info) = self.channels.write().await.get(&cid) {
394                    if let Err(e) = b2a_event_s.send(ParticipantEvent::ChannelDeleted(
395                        info.lock().await.remote_con_addr.clone(),
396                    )) {
397                        debug!(?e, "Participant was dropped during channel disconnect");
398                    };
399                }
400                self.metrics.channels_disconnected(&self.remote_pid_string);
401                if sorted_send_protocols.data.is_empty() {
402                    break;
403                }
404            }
405
406            if let Some(cid) = remp {
407                debug!(?cid, "remove protocol");
408                match sorted_send_protocols.delete(&cid) {
409                    Some(mut prot) => {
410                        if let Some(info) = self.channels.write().await.get(&cid) {
411                            if let Err(e) = b2a_event_s.send(ParticipantEvent::ChannelDeleted(
412                                info.lock().await.remote_con_addr.clone(),
413                            )) {
414                                debug!(?e, "Participant was dropped during channel disconnect");
415                            };
416                        }
417                        self.metrics.channels_disconnected(&self.remote_pid_string);
418                        trace!("blocking flush");
419                        let _ = prot.flush(u64::MAX, Duration::from_secs(1)).await;
420                        trace!("shutdown prot");
421                        let _ = prot.send(ProtocolEvent::Shutdown).await;
422                    },
423                    None => trace!("tried to remove protocol twice"),
424                };
425                if sorted_send_protocols.data.is_empty() {
426                    break;
427                }
428            }
429        }
430        trace!("stop sending in api!");
431        self.open_stream_channels.lock().await.take();
432        trace!("Stop send_mgr");
433        self.shutdown_barrier
434            .fetch_sub(Self::BARR_SEND, Ordering::SeqCst);
435    }
436
437    async fn recv_mgr(
438        &self,
439        b2a_stream_opened_s: mpsc::UnboundedSender<Stream>,
440        mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, RecvProtocols)>,
441        b2b_force_close_recv_protocol_r: async_channel::Receiver<Cid>,
442        b2b_close_send_protocol_s: async_channel::Sender<Cid>,
443        b2b_notify_send_of_recv_open_r: crossbeam_channel::Sender<(
444            Cid,
445            Sid,
446            Prio,
447            Promises,
448            Bandwidth,
449        )>,
450        b2b_notify_send_of_recv_close_s: crossbeam_channel::Sender<(Cid, Sid)>,
451    ) {
452        let mut recv_protocols: HashMap<Cid, JoinHandle<()>> = HashMap::new();
453        // we should be able to directly await futures imo
454        let (hacky_recv_s, mut hacky_recv_r) = mpsc::unbounded_channel();
455
456        let retrigger = |cid: Cid, mut p: RecvProtocols, map: &mut HashMap<_, _>| {
457            let hacky_recv_s = hacky_recv_s.clone();
458            let handle = tokio::spawn(async move {
459                let r = p.recv().await;
460                let _ = hacky_recv_s.send((cid, r, p)); // ignoring failed
461            });
462            map.insert(cid, handle);
463        };
464
465        let remove_c = |recv_protocols: &mut HashMap<Cid, JoinHandle<()>>, cid: &Cid| {
466            match recv_protocols.remove(cid) {
467                Some(h) => {
468                    h.abort();
469                    debug!(?cid, "remove protocol");
470                },
471                None => trace!("tried to remove protocol twice"),
472            };
473            recv_protocols.is_empty()
474        };
475
476        let mut defered_orphan = DeferredTracer::new(Level::WARN);
477
478        loop {
479            let (event, addp, remp) = select!(
480                Some(n) = hacky_recv_r.recv().fuse() => (Some(n), None, None),
481                Some(n) = b2b_add_protocol_r.recv().fuse() => (None, Some(n), None),
482                Ok(n) = b2b_force_close_recv_protocol_r.recv().fuse() => (None, None, Some(n)),
483                else => {
484                    error!("recv_mgr -> something is seriously wrong!, end recv_mgr");
485                    break;
486                }
487            );
488
489            if let Some((cid, p)) = addp {
490                debug!(?cid, "add protocol");
491                retrigger(cid, p, &mut recv_protocols);
492            };
493            if let Some(cid) = remp {
494                // no need to stop the send_mgr here as it has been canceled before
495                if remove_c(&mut recv_protocols, &cid) {
496                    break;
497                }
498            };
499
500            if let Some((cid, r, p)) = event {
501                match r {
502                    Ok(ProtocolEvent::OpenStream {
503                        sid,
504                        prio,
505                        promises,
506                        guaranteed_bandwidth,
507                    }) => {
508                        trace!(?sid, "open stream");
509                        let _ = b2b_notify_send_of_recv_open_r.send((
510                            cid,
511                            sid,
512                            prio,
513                            promises,
514                            guaranteed_bandwidth,
515                        ));
516                        // waiting for receiving is not necessary, because the send_mgr will first
517                        // process this before process messages!
518                        let stream = self
519                            .create_stream(sid, prio, promises, guaranteed_bandwidth)
520                            .await;
521                        b2a_stream_opened_s.send(stream).unwrap();
522                        retrigger(cid, p, &mut recv_protocols);
523                    },
524                    Ok(ProtocolEvent::CloseStream { sid }) => {
525                        trace!(?sid, "close stream");
526                        let _ = b2b_notify_send_of_recv_close_s.send((cid, sid));
527                        self.delete_stream(sid).await;
528                        retrigger(cid, p, &mut recv_protocols);
529                    },
530                    Ok(ProtocolEvent::Message { data, sid }) => {
531                        let lock = self.streams.read().await;
532                        match lock.get(&sid) {
533                            Some(stream) => {
534                                let _ = stream.b2a_msg_recv_s.lock().await.send(data).await;
535                            },
536                            None => defered_orphan.log(sid),
537                        };
538                        retrigger(cid, p, &mut recv_protocols);
539                    },
540                    Ok(ProtocolEvent::Shutdown) => {
541                        info!(?cid, "shutdown protocol");
542                        if let Err(e) = b2b_close_send_protocol_s.send(cid).await {
543                            debug!(?e, ?cid, "send_mgr was already closed simultaneously");
544                        }
545                        if remove_c(&mut recv_protocols, &cid) {
546                            break;
547                        }
548                    },
549                    Err(e) => {
550                        info!(?e, ?cid, "protocol failed, shutting down channel");
551                        if let Err(e) = b2b_close_send_protocol_s.send(cid).await {
552                            debug!(?e, ?cid, "send_mgr was already closed simultaneously");
553                        }
554                        if remove_c(&mut recv_protocols, &cid) {
555                            break;
556                        }
557                    },
558                }
559            }
560
561            if let Some(table) = defered_orphan.print() {
562                for (sid, cnt) in table.iter() {
563                    warn!(?sid, ?cnt, "recv messages with orphan stream");
564                }
565            }
566        }
567        trace!("receiving no longer possible, closing all streams");
568        for (_, si) in self.streams.write().await.drain() {
569            si.send_closed.store(true, Ordering::SeqCst);
570            self.metrics.streams_closed(&self.remote_pid_string);
571        }
572        trace!("Stop recv_mgr");
573        self.shutdown_barrier
574            .fetch_sub(Self::BARR_RECV, Ordering::SeqCst);
575    }
576
577    async fn create_channel_mgr(
578        &self,
579        s2b_create_channel_r: mpsc::UnboundedReceiver<S2bCreateChannel>,
580        b2b_add_send_protocol_s: mpsc::UnboundedSender<(Cid, SendProtocols)>,
581        b2b_add_recv_protocol_s: mpsc::UnboundedSender<(Cid, RecvProtocols)>,
582        b2a_event_s: mpsc::UnboundedSender<ParticipantEvent>,
583    ) {
584        let s2b_create_channel_r = UnboundedReceiverStream::new(s2b_create_channel_r);
585        s2b_create_channel_r
586            .for_each_concurrent(
587                None,
588                |(cid, _, protocol, remote_con_addr, b2s_create_channel_done_s)| {
589                    // This channel is now configured, and we are running it in scope of the
590                    // participant.
591                    let channels = Arc::clone(&self.channels);
592                    let b2b_add_send_protocol_s = b2b_add_send_protocol_s.clone();
593                    let b2b_add_recv_protocol_s = b2b_add_recv_protocol_s.clone();
594                    let b2a_event_s = b2a_event_s.clone();
595                    async move {
596                        let mut lock = channels.write().await;
597                        let mut channel_no = lock.len();
598                        lock.insert(
599                            cid,
600                            Mutex::new(ChannelInfo {
601                                cid,
602                                cid_string: cid.to_string(),
603                                remote_con_addr: remote_con_addr.clone(),
604                            }),
605                        );
606                        drop(lock);
607                        let (send, recv) = protocol.split();
608                        b2b_add_send_protocol_s.send((cid, send)).unwrap();
609                        b2b_add_recv_protocol_s.send((cid, recv)).unwrap();
610                        if let Err(e) =
611                            b2a_event_s.send(ParticipantEvent::ChannelCreated(remote_con_addr))
612                        {
613                            debug!(?e, "Participant was dropped during channel connect");
614                        };
615                        b2s_create_channel_done_s.send(()).unwrap();
616                        if channel_no > 5 {
617                            debug!(?channel_no, "metrics will overwrite channel #5");
618                            channel_no = 5;
619                        }
620                        self.metrics
621                            .channels_connected(&self.remote_pid_string, channel_no, cid);
622                    }
623                },
624            )
625            .await;
626        trace!("Stop create_channel_mgr");
627        self.shutdown_barrier
628            .fetch_sub(Self::BARR_CHANNEL, Ordering::SeqCst);
629    }
630
631    /// sink shutdown:
632    ///  Situation AS, AR, BS, BR. A wants to close.
633    ///  AS shutdown.
634    ///  BR notices shutdown and tries to stops BS. (success)
635    ///  BS shutdown
636    ///  AR notices shutdown and tries to stop AS. (fails)
637    /// For the case where BS didn't get shutdowned, e.g. by a handing situation
638    /// on the remote, we have a timeout to also force close AR.
639    ///
640    /// This fn will:
641    ///   1. stop api to interact with bparticipant by closing sendmsg and
642    ///      openstream
643    ///   2. stop the send_mgr (it will take care of clearing the queue and
644    ///      finish with a Shutdown)
645    ///   3. force stop recv after 60 seconds
646    ///   4. this fn finishes last and afterwards BParticipant drops
647    ///
648    /// before calling this fn, make sure `s2b_create_channel` is closed!
649    /// If BParticipant kills itself managers stay active till this function is
650    /// called by api to get the result status
651    async fn participant_shutdown_mgr(
652        &self,
653        s2b_shutdown_bparticipant_r: oneshot::Receiver<S2bShutdownBparticipant>,
654        b2b_close_send_protocol_s: async_channel::Sender<Cid>,
655        b2b_force_close_recv_protocol_s: async_channel::Sender<Cid>,
656    ) {
657        let wait_for_manager = || async {
658            let mut sleep = 0.01f64;
659            loop {
660                let bytes = self.shutdown_barrier.load(Ordering::SeqCst);
661                if bytes == 0 {
662                    break;
663                }
664                sleep *= 1.4;
665                tokio::time::sleep(Duration::from_secs_f64(sleep)).await;
666                if sleep > 0.2 {
667                    trace!(?bytes, "wait for mgr to close");
668                }
669            }
670        };
671
672        let awaited = s2b_shutdown_bparticipant_r.await.ok();
673        debug!("participant_shutdown_mgr triggered. Closing all streams for send");
674        {
675            let lock = self.streams.read().await;
676            for si in lock.values() {
677                si.send_closed.store(true, Ordering::SeqCst);
678            }
679        }
680
681        let lock = self.channels.read().await;
682        assert!(
683            !lock.is_empty(),
684            "no channel existed remote_pid={}",
685            self.remote_pid
686        );
687        for cid in lock.keys() {
688            if let Err(e) = b2b_close_send_protocol_s.send(*cid).await {
689                debug!(
690                    ?e,
691                    ?cid,
692                    "closing send_mgr may fail if we got a recv error simultaneously"
693                );
694            }
695        }
696        drop(lock);
697
698        trace!("wait for other managers");
699        let timeout = tokio::time::sleep(
700            awaited
701                .as_ref()
702                .map(|(timeout_time, _)| *timeout_time)
703                .unwrap_or_default(),
704        );
705        let timeout = select! {
706            _ = wait_for_manager() => false,
707            _ = timeout => true,
708        };
709        if timeout {
710            warn!("timeout triggered: for killing recv");
711            let lock = self.channels.read().await;
712            for cid in lock.keys() {
713                if let Err(e) = b2b_force_close_recv_protocol_s.send(*cid).await {
714                    debug!(
715                        ?e,
716                        ?cid,
717                        "closing recv_mgr may fail if we got a recv error simultaneously"
718                    );
719                }
720            }
721        }
722
723        trace!("wait again");
724        wait_for_manager().await;
725
726        if let Some((_, sender)) = awaited {
727            // Don't care whether this send succeeded since if the other end is dropped
728            // there's nothing to synchronize on.
729            let _ = sender.send(Ok(()));
730        }
731
732        #[cfg(feature = "metrics")]
733        self.metrics.participants_disconnected_total.inc();
734        self.metrics.cleanup_participant(&self.remote_pid_string);
735        trace!("Stop participant_shutdown_mgr");
736    }
737
738    /// Stopping API and participant usage
739    /// Protocol will take care of the order of the frame
740    async fn delete_stream(&self, sid: Sid) {
741        let stream = { self.streams.write().await.remove(&sid) };
742        match stream {
743            Some(si) => {
744                si.send_closed.store(true, Ordering::SeqCst);
745                si.b2a_msg_recv_s.lock().await.close();
746            },
747            None => {
748                trace!("Couldn't find the stream, might be simultaneous close from local/remote")
749            },
750        }
751        self.metrics.streams_closed(&self.remote_pid_string);
752    }
753
754    async fn create_stream(
755        &self,
756        sid: Sid,
757        prio: Prio,
758        promises: Promises,
759        guaranteed_bandwidth: Bandwidth,
760    ) -> Stream {
761        let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::<Bytes>();
762        let send_closed = Arc::new(AtomicBool::new(false));
763        self.streams.write().await.insert(sid, StreamInfo {
764            prio,
765            promises,
766            send_closed: Arc::clone(&send_closed),
767            b2a_msg_recv_s: Mutex::new(b2a_msg_recv_s),
768        });
769        self.metrics.streams_opened(&self.remote_pid_string);
770
771        let (a2b_msg_s, a2b_close_stream_s) = {
772            let lock = self.open_stream_channels.lock().await;
773            match &*lock {
774                Some(osi) => (osi.a2b_msg_s.clone(), osi.a2b_close_stream_s.clone()),
775                None => {
776                    // This Stream will not be able to send. feed it some "Dummy" Channels.
777                    debug!(
778                        "It seems that a stream was requested to open, while the send_mgr is \
779                         already closed"
780                    );
781                    let (a2b_msg_s, _) = crossbeam_channel::unbounded();
782                    let (a2b_close_stream_s, _) = mpsc::unbounded_channel();
783                    (a2b_msg_s, a2b_close_stream_s)
784                },
785            }
786        };
787
788        Stream::new(
789            self.local_pid,
790            self.remote_pid,
791            sid,
792            prio,
793            promises,
794            guaranteed_bandwidth,
795            send_closed,
796            a2b_msg_s,
797            b2a_msg_recv_r,
798            a2b_close_stream_s,
799        )
800    }
801}
802
// Unit tests for `BParticipant`. Each test spawns a participant on its own tokio
// runtime and attaches one in-memory (mpsc-backed) channel to it. The tests then
// exercise either the shutdown path or stream opening in one direction.
#[cfg(test)]
mod tests {
    use super::*;
    use core::assert_matches::assert_matches;
    use network_protocol::{ProtocolMetricCache, ProtocolMetrics};
    use tokio::{
        runtime::Runtime,
        sync::{mpsc, oneshot},
        task::JoinHandle,
    };

    /// Builds a `BParticipant` (local `Pid::fake(0)`, remote `Pid::fake(1)`,
    /// `Sid::new(1000)`) on a fresh tokio `Runtime` and spawns its `run` task.
    ///
    /// Returns the runtime and the test-side ends of the participant's control
    /// channels. Also returns the receiver for the prio statistics passed to
    /// `run`, and the `JoinHandle` of the spawned `run` task.
    fn mock_bparticipant() -> (
        Arc<Runtime>,
        mpsc::UnboundedSender<A2bStreamOpen>,
        mpsc::UnboundedReceiver<Stream>,
        mpsc::UnboundedReceiver<ParticipantEvent>,
        mpsc::UnboundedSender<S2bCreateChannel>,
        oneshot::Sender<S2bShutdownBparticipant>,
        mpsc::UnboundedReceiver<B2sPrioStatistic>,
        watch::Receiver<f32>,
        JoinHandle<()>,
    ) {
        let runtime = Arc::new(Runtime::new().unwrap());
        let runtime_clone = Arc::clone(&runtime);

        let (b2s_prio_statistic_s, b2s_prio_statistic_r) =
            mpsc::unbounded_channel::<B2sPrioStatistic>();

        // `BParticipant::new` is not awaited; it is called inside `block_on`,
        // presumably because it needs a runtime context.
        // NOTE(review): confirm against `BParticipant::new`.
        let (
            bparticipant,
            a2b_open_stream_s,
            b2a_stream_opened_r,
            b2a_event_r,
            s2b_create_channel_s,
            s2b_shutdown_bparticipant_s,
            b2a_bandwidth_stats_r,
        ) = runtime_clone.block_on(async move {
            let local_pid = Pid::fake(0);
            let remote_pid = Pid::fake(1);
            let sid = Sid::new(1000);
            let metrics = Arc::new(NetworkMetrics::new(&local_pid).unwrap());

            BParticipant::new(local_pid, remote_pid, sid, Arc::clone(&metrics))
        });

        // Run the participant's main loop in the background. Tests join `handle`
        // at the end to check that it terminates without panicking.
        let handle = runtime_clone.spawn(bparticipant.run(b2s_prio_statistic_s));
        (
            runtime_clone,
            a2b_open_stream_s,
            b2a_stream_opened_r,
            b2a_event_r,
            s2b_create_channel_s,
            s2b_shutdown_bparticipant_s,
            b2s_prio_statistic_r,
            b2a_bandwidth_stats_r,
            handle,
        )
    }

    /// Creates a cross-wired pair of in-memory protocols and registers one side
    /// with the participant as channel `cid` through `create_channel`.
    ///
    /// Waits until the participant acknowledges the new channel, then returns
    /// the other side, which the test uses to act as the remote peer.
    // `create_channel` is only used through a shared-reference method here,
    // hence the `expect`; `_runtime` is unused.
    #[expect(clippy::needless_pass_by_ref_mut)]
    async fn mock_mpsc(
        cid: Cid,
        _runtime: &Arc<Runtime>,
        create_channel: &mut mpsc::UnboundedSender<S2bCreateChannel>,
    ) -> Protocols {
        // The participant's side sends on `s1` and receives on `r2`.
        // The returned remote side sends on `s2` and receives on `r1`.
        let (s1, r1) = mpsc::channel(100);
        let (s2, r2) = mpsc::channel(100);
        let met = Arc::new(ProtocolMetrics::new().unwrap());
        let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&met));
        let p1 = Protocols::new_mpsc(s1, r2, metrics);
        let (complete_s, complete_r) = oneshot::channel();
        create_channel
            .send((cid, Sid::new(0), p1, ConnectAddr::Mpsc(42), complete_s))
            .unwrap();
        // Wait for the participant to signal that the channel was taken over.
        complete_r.await.unwrap();
        let metrics = ProtocolMetricCache::new(&cid.to_string(), met);
        Protocols::new_mpsc(s2, r1, metrics)
    }

    /// Shutdown while the remote side of the channel stays alive. The
    /// participant must only acknowledge after its shutdown timeout (1s here)
    /// has elapsed.
    #[test]
    fn close_bparticipant_by_timeout_during_close() {
        let (
            runtime,
            a2b_open_stream_s,
            b2a_stream_opened_r,
            mut b2a_event_r,
            mut s2b_create_channel_s,
            s2b_shutdown_bparticipant_s,
            b2s_prio_statistic_r,
            _b2a_bandwidth_stats_r,
            handle,
        ) = mock_bparticipant();

        // Bound to `_remote` (not `_`) so the remote end stays alive for the
        // whole test.
        let _remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s));
        // Short pause, presumably to let the participant's tasks settle.
        std::thread::sleep(Duration::from_millis(50));

        let (s, r) = oneshot::channel();
        let before = Instant::now();
        runtime.block_on(async {
            drop(s2b_create_channel_s);
            s2b_shutdown_bparticipant_s
                .send((Duration::from_secs(1), s))
                .unwrap();
            r.await.unwrap().unwrap();
        });
        assert!(
            before.elapsed() > Duration::from_millis(900),
            "timeout wasn't triggered"
        );
        // Exactly one channel was created and then deleted; no other events.
        assert_matches!(
            b2a_event_r.try_recv().unwrap(),
            ParticipantEvent::ChannelCreated(_)
        );
        assert_matches!(
            b2a_event_r.try_recv().unwrap(),
            ParticipantEvent::ChannelDeleted(_)
        );
        assert_matches!(b2a_event_r.try_recv(), Err(_));

        runtime.block_on(handle).unwrap();

        // Channel ends not used by this test are held until here and dropped
        // before the runtime.
        drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r));
        drop(runtime);
    }

    /// Shutdown where the remote side is dropped right after the request. This
    /// must complete well before the 2s shutdown timeout.
    #[test]
    fn close_bparticipant_cleanly() {
        let (
            runtime,
            a2b_open_stream_s,
            b2a_stream_opened_r,
            mut b2a_event_r,
            mut s2b_create_channel_s,
            s2b_shutdown_bparticipant_s,
            b2s_prio_statistic_r,
            _b2a_bandwidth_stats_r,
            handle,
        ) = mock_bparticipant();

        let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s));
        // Short pause, presumably to let the participant's tasks settle.
        std::thread::sleep(Duration::from_millis(50));

        let (s, r) = oneshot::channel();
        let before = Instant::now();
        runtime.block_on(async {
            drop(s2b_create_channel_s);
            s2b_shutdown_bparticipant_s
                .send((Duration::from_secs(2), s))
                .unwrap();
            drop(remote); // remote needs to be dropped as soon as local.sender is closed
            r.await.unwrap().unwrap();
        });
        assert!(
            before.elapsed() < Duration::from_millis(1900),
            "timeout was triggered"
        );
        // Exactly one channel was created and then deleted; no other events.
        assert_matches!(
            b2a_event_r.try_recv().unwrap(),
            ParticipantEvent::ChannelCreated(_)
        );
        assert_matches!(
            b2a_event_r.try_recv().unwrap(),
            ParticipantEvent::ChannelDeleted(_)
        );
        assert_matches!(b2a_event_r.try_recv(), Err(_));

        runtime.block_on(handle).unwrap();

        // Channel ends not used by this test are held until here and dropped
        // before the runtime.
        drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r));
        drop(runtime);
    }

    /// Locally requested stream. The remote peer must receive an `OpenStream`
    /// event carrying the requested prio, promises and bandwidth, plus sid 1000
    /// (the `Sid` passed to `BParticipant::new` in `mock_bparticipant`).
    #[test]
    fn create_stream() {
        let (
            runtime,
            a2b_open_stream_s,
            b2a_stream_opened_r,
            _b2a_event_r,
            mut s2b_create_channel_s,
            s2b_shutdown_bparticipant_s,
            b2s_prio_statistic_r,
            _b2a_bandwidth_stats_r,
            handle,
        ) = mock_bparticipant();

        let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s));
        // Short pause, presumably to let the participant's tasks settle.
        std::thread::sleep(Duration::from_millis(50));

        // created stream
        let (rs, mut rr) = remote.split();
        let (stream_sender, _stream_receiver) = oneshot::channel();
        a2b_open_stream_s
            .send((7u8, Promises::ENCRYPTED, 1_000_000, stream_sender))
            .unwrap();

        let stream_event = runtime.block_on(rr.recv()).unwrap();
        match stream_event {
            ProtocolEvent::OpenStream {
                sid,
                prio,
                promises,
                guaranteed_bandwidth,
            } => {
                assert_eq!(sid, Sid::new(1000));
                assert_eq!(prio, 7u8);
                assert_eq!(promises, Promises::ENCRYPTED);
                assert_eq!(guaranteed_bandwidth, 1_000_000);
            },
            _ => panic!("wrong event"),
        };

        // Shut down; the remote halves are dropped after the request is sent.
        let (s, r) = oneshot::channel();
        runtime.block_on(async {
            drop(s2b_create_channel_s);
            s2b_shutdown_bparticipant_s
                .send((Duration::from_secs(1), s))
                .unwrap();
            drop((rs, rr));
            r.await.unwrap().unwrap();
        });

        runtime.block_on(handle).unwrap();

        // Channel ends not used by this test are held until here and dropped
        // before the runtime.
        drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r));
        drop(runtime);
    }

    /// Remotely opened stream. An `OpenStream` event sent by the remote peer
    /// must surface as a new `Stream` on `b2a_stream_opened_r` with the same
    /// promises.
    #[test]
    fn created_stream() {
        let (
            runtime,
            a2b_open_stream_s,
            mut b2a_stream_opened_r,
            _b2a_event_r,
            mut s2b_create_channel_s,
            s2b_shutdown_bparticipant_s,
            b2s_prio_statistic_r,
            _b2a_bandwidth_stats_r,
            handle,
        ) = mock_bparticipant();

        let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s));
        // Short pause, presumably to let the participant's tasks settle.
        std::thread::sleep(Duration::from_millis(50));

        // create stream
        let (mut rs, rr) = remote.split();
        runtime
            .block_on(rs.send(ProtocolEvent::OpenStream {
                sid: Sid::new(1000),
                prio: 9u8,
                promises: Promises::ORDERED,
                guaranteed_bandwidth: 1_000_000,
            }))
            .unwrap();

        let stream = runtime.block_on(b2a_stream_opened_r.recv()).unwrap();
        assert_eq!(stream.params().promises, Promises::ORDERED);

        // Shut down; the remote halves are dropped after the request is sent.
        let (s, r) = oneshot::channel();
        runtime.block_on(async {
            drop(s2b_create_channel_s);
            s2b_shutdown_bparticipant_s
                .send((Duration::from_secs(1), s))
                .unwrap();
            drop((rs, rr));
            r.await.unwrap().unwrap();
        });

        runtime.block_on(handle).unwrap();

        // Channel ends not used by this test are held until here and dropped
        // before the runtime.
        drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r));
        drop(runtime);
    }
}