veloren_network/
scheduler.rs

1use crate::{
2    api::{ConnectAddr, ListenAddr, NetworkConnectError, Participant},
3    channel::Protocols,
4    metrics::{NetworkMetrics, ProtocolInfo},
5    participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant},
6};
7use futures_util::StreamExt;
8use hashbrown::HashMap;
9use network_protocol::{Cid, Pid, ProtocolMetricCache, ProtocolMetrics};
10#[cfg(feature = "metrics")]
11use prometheus::Registry;
12use rand::Rng;
13use std::{
14    sync::{
15        Arc,
16        atomic::{AtomicBool, AtomicU64, Ordering},
17    },
18    time::Duration,
19};
20use tokio::{
21    io,
22    sync::{Mutex, mpsc, oneshot},
23};
24use tokio_stream::wrappers::UnboundedReceiverStream;
25use tracing::*;
26
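// Naming of channels: `x2y_<what>_s` / `x2y_<what>_r` is the sender / receiver half of a
// channel that carries `<what>` from component `x` to component `y`, e.g. `a2s_listen_s`
// sends listen requests from the api to the scheduler. Component letters:
//  - a: api
//  - s: scheduler
//  - b: bparticipant
//  - p: prios
//  - r: protocol
//  - w: wire
//  - c: channel/handshake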
35
/// Bookkeeping the scheduler keeps for every connected remote participant.
#[derive(Debug)]
struct ParticipantInfo {
    /// Secret the remote presented in the handshake that created this participant.
    /// Every further channel claiming the same `Pid` is compared against it in
    /// `Scheduler::init_protocol`.
    secret: u128,
    /// Hands additional channels to the `BParticipant`. Stored, but not read in this
    /// file, hence the `dead_code` expectation.
    #[expect(dead_code)]
    s2b_create_channel_s: mpsc::UnboundedSender<S2bCreateChannel>,
    /// Requests the `BParticipant` shutdown. Wrapped in `Option` because the oneshot
    /// sender is consumed on use; the shutdown paths `take()` it.
    s2b_shutdown_bparticipant_s: Option<oneshot::Sender<S2bShutdownBparticipant>>,
}
43
/// Api -> scheduler: address to listen on, plus a oneshot that receives the result of
/// opening the listener.
type A2sListen = (ListenAddr, oneshot::Sender<io::Result<()>>);
/// Api -> scheduler: address to connect to, plus a oneshot that receives the resulting
/// [`Participant`] or the connect error.
pub(crate) type A2sConnect = (
    ConnectAddr,
    oneshot::Sender<Result<Participant, NetworkConnectError>>,
);
/// Api -> scheduler: request to close the participant `Pid`. `disconnect_mgr`
/// destructures the second element into a shutdown timeout and a sender that receives
/// the shutdown result.
type A2sDisconnect = (Pid, S2bShutdownBparticipant);
50
/// Receiving ends of the scheduler's control channels.
///
/// `Scheduler::run` takes them out of the scheduler and hands each one to its own
/// manager loop.
#[derive(Debug)]
struct ControlChannels {
    /// Listen requests, consumed by `listen_mgr`.
    a2s_listen_r: mpsc::UnboundedReceiver<A2sListen>,
    /// Connect requests, consumed by `connect_mgr`.
    a2s_connect_r: mpsc::UnboundedReceiver<A2sConnect>,
    /// Shutdown trigger, consumed by `scheduler_shutdown_mgr`.
    a2s_scheduler_shutdown_r: oneshot::Receiver<()>,
    /// Participant disconnect requests, consumed by `disconnect_mgr`.
    a2s_disconnect_r: mpsc::UnboundedReceiver<A2sDisconnect>,
    /// Priority statistics from the BParticipants, consumed by `prio_adj_mgr`.
    b2s_prio_statistic_r: mpsc::UnboundedReceiver<B2sPrioStatistic>,
}
59
/// Sending ends needed to create a new participant.
///
/// `Scheduler::init_protocol` clones this struct for every new channel.
/// `scheduler_shutdown_mgr` drops the scheduler's copy so no further participants can
/// be created.
#[derive(Debug, Clone)]
struct ParticipantChannels {
    /// Hands participants to the api when no `connect` call is waiting for them
    /// (the listener path).
    s2a_connected_s: mpsc::UnboundedSender<Participant>,
    /// Passed to `Participant::new`; the messages are handled by `disconnect_mgr`.
    a2s_disconnect_s: mpsc::UnboundedSender<A2sDisconnect>,
    /// Passed to `BParticipant::run`; the messages are handled by `prio_adj_mgr`.
    b2s_prio_statistic_s: mpsc::UnboundedSender<B2sPrioStatistic>,
}
66
/// Drives the network's background work.
///
/// It handles listen, connect and disconnect requests from the api, runs the handshake
/// for every new channel and tracks all connected remote participants.
#[derive(Debug)]
pub struct Scheduler {
    /// Our own participant id, sent during the handshake.
    local_pid: Pid,
    /// Our own random secret, sent during the handshake.
    local_secret: u128,
    /// Set to `true` once a shutdown was requested (not read elsewhere in this file).
    closed: AtomicBool,
    /// Control channel receivers; `None` after `run` took them.
    run_channels: Option<ControlChannels>,
    /// Senders needed for new participants; `None` once the scheduler shut down.
    participant_channels: Arc<Mutex<Option<ParticipantChannels>>>,
    /// All currently known remote participants.
    participants: Arc<Mutex<HashMap<Pid, ParticipantInfo>>>,
    /// Source of channel ids (`Cid`). Incremented for outgoing connects and shared with
    /// the listeners.
    channel_ids: Arc<AtomicU64>,
    /// One stop signal per active listener.
    channel_listener: Mutex<HashMap<ProtocolInfo, oneshot::Sender<()>>>,
    metrics: Arc<NetworkMetrics>,
    protocol_metrics: Arc<ProtocolMetrics>,
}
80
81impl Scheduler {
82    pub fn new(
83        local_pid: Pid,
84        #[cfg(feature = "metrics")] registry: Option<&Registry>,
85    ) -> (
86        Self,
87        mpsc::UnboundedSender<A2sListen>,
88        mpsc::UnboundedSender<A2sConnect>,
89        mpsc::UnboundedReceiver<Participant>,
90        oneshot::Sender<()>,
91    ) {
92        let (a2s_listen_s, a2s_listen_r) = mpsc::unbounded_channel::<A2sListen>();
93        let (a2s_connect_s, a2s_connect_r) = mpsc::unbounded_channel::<A2sConnect>();
94        let (s2a_connected_s, s2a_connected_r) = mpsc::unbounded_channel::<Participant>();
95        let (a2s_scheduler_shutdown_s, a2s_scheduler_shutdown_r) = oneshot::channel::<()>();
96        let (a2s_disconnect_s, a2s_disconnect_r) = mpsc::unbounded_channel::<A2sDisconnect>();
97        let (b2s_prio_statistic_s, b2s_prio_statistic_r) =
98            mpsc::unbounded_channel::<B2sPrioStatistic>();
99
100        let run_channels = Some(ControlChannels {
101            a2s_listen_r,
102            a2s_connect_r,
103            a2s_scheduler_shutdown_r,
104            a2s_disconnect_r,
105            b2s_prio_statistic_r,
106        });
107
108        let participant_channels = ParticipantChannels {
109            s2a_connected_s,
110            a2s_disconnect_s,
111            b2s_prio_statistic_s,
112        };
113
114        let metrics = Arc::new(NetworkMetrics::new(&local_pid).unwrap());
115        let protocol_metrics = Arc::new(ProtocolMetrics::new().unwrap());
116
117        #[cfg(feature = "metrics")]
118        {
119            if let Some(registry) = registry {
120                metrics.register(registry).unwrap();
121                protocol_metrics.register(registry).unwrap();
122            }
123        }
124
125        let mut rng = rand::thread_rng();
126        let local_secret: u128 = rng.gen();
127
128        (
129            Self {
130                local_pid,
131                local_secret,
132                closed: AtomicBool::new(false),
133                run_channels,
134                participant_channels: Arc::new(Mutex::new(Some(participant_channels))),
135                participants: Arc::new(Mutex::new(HashMap::new())),
136                channel_ids: Arc::new(AtomicU64::new(0)),
137                channel_listener: Mutex::new(HashMap::new()),
138                metrics,
139                protocol_metrics,
140            },
141            a2s_listen_s,
142            a2s_connect_s,
143            s2a_connected_r,
144            a2s_scheduler_shutdown_s,
145        )
146    }
147
148    pub async fn run(mut self) {
149        let run_channels = self
150            .run_channels
151            .take()
152            .expect("run() can only be called once");
153
154        tokio::join!(
155            self.listen_mgr(run_channels.a2s_listen_r),
156            self.connect_mgr(run_channels.a2s_connect_r),
157            self.disconnect_mgr(run_channels.a2s_disconnect_r),
158            self.prio_adj_mgr(run_channels.b2s_prio_statistic_r),
159            self.scheduler_shutdown_mgr(run_channels.a2s_scheduler_shutdown_r),
160        );
161    }
162
163    async fn listen_mgr(&self, a2s_listen_r: mpsc::UnboundedReceiver<A2sListen>) {
164        trace!("Start listen_mgr");
165        let a2s_listen_r = UnboundedReceiverStream::new(a2s_listen_r);
166        a2s_listen_r
167            .for_each_concurrent(None, |(address, s2a_listen_result_s)| {
168                let address = address;
169                let cids = Arc::clone(&self.channel_ids);
170
171                #[cfg(feature = "metrics")]
172                let mcache = self.metrics.connect_requests_cache(&address);
173
174                debug!(?address, "Got request to open a channel_creator");
175                self.metrics.listen_request(&address);
176                let (s2s_stop_listening_s, s2s_stop_listening_r) = oneshot::channel::<()>();
177                let (c2s_protocol_s, mut c2s_protocol_r) = mpsc::unbounded_channel();
178                let metrics = Arc::clone(&self.protocol_metrics);
179
180                async move {
181                    self.channel_listener
182                        .lock()
183                        .await
184                        .insert(address.clone().into(), s2s_stop_listening_s);
185
186                    #[cfg(feature = "metrics")]
187                    mcache.inc();
188
189                    let res = match address {
190                        ListenAddr::Tcp(addr) => {
191                            Protocols::with_tcp_listen(
192                                addr,
193                                cids,
194                                metrics,
195                                s2s_stop_listening_r,
196                                c2s_protocol_s,
197                            )
198                            .await
199                        },
200                        #[cfg(feature = "quic")]
201                        ListenAddr::Quic(addr, ref server_config) => {
202                            Protocols::with_quic_listen(
203                                addr,
204                                server_config.clone(),
205                                cids,
206                                metrics,
207                                s2s_stop_listening_r,
208                                c2s_protocol_s,
209                            )
210                            .await
211                        },
212                        ListenAddr::Mpsc(addr) => {
213                            Protocols::with_mpsc_listen(
214                                addr,
215                                cids,
216                                metrics,
217                                s2s_stop_listening_r,
218                                c2s_protocol_s,
219                            )
220                            .await
221                        },
222                        _ => unimplemented!(),
223                    };
224                    let _ = s2a_listen_result_s.send(res);
225
226                    while let Some((prot, con_addr, cid)) = c2s_protocol_r.recv().await {
227                        self.init_protocol(prot, con_addr, cid, None, true).await;
228                    }
229                }
230            })
231            .await;
232        trace!("Stop listen_mgr");
233    }
234
235    async fn connect_mgr(&self, mut a2s_connect_r: mpsc::UnboundedReceiver<A2sConnect>) {
236        trace!("Start connect_mgr");
237        while let Some((addr, pid_sender)) = a2s_connect_r.recv().await {
238            let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
239            let metrics =
240                ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&self.protocol_metrics));
241            self.metrics.connect_request(&addr);
242            let protocol = match addr.clone() {
243                ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, metrics).await,
244                #[cfg(feature = "quic")]
245                ConnectAddr::Quic(addr, ref config, name) => {
246                    Protocols::with_quic_connect(addr, config.clone(), name, metrics).await
247                },
248                ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, metrics).await,
249                _ => unimplemented!(),
250            };
251            let protocol = match protocol {
252                Ok(p) => p,
253                Err(e) => {
254                    pid_sender.send(Err(e)).unwrap();
255                    continue;
256                },
257            };
258            self.init_protocol(protocol, addr, cid, Some(pid_sender), false)
259                .await;
260        }
261        trace!("Stop connect_mgr");
262    }
263
264    async fn disconnect_mgr(&self, a2s_disconnect_r: mpsc::UnboundedReceiver<A2sDisconnect>) {
265        trace!("Start disconnect_mgr");
266
267        let a2s_disconnect_r = UnboundedReceiverStream::new(a2s_disconnect_r);
268        a2s_disconnect_r
269            .for_each_concurrent(
270                None,
271                |(pid, (timeout_time, return_once_successful_shutdown))| {
272                    //Closing Participants is done the following way:
273                    // 1. We drop our senders and receivers
274                    // 2. we need to close BParticipant, this will drop its senderns and receivers
275                    // 3. Participant will try to access the BParticipant senders and receivers with
276                    // their next api action, it will fail and be closed then.
277                    let participants = Arc::clone(&self.participants);
278                    async move {
279                        trace!(?pid, "Got request to close participant");
280                        let pi = participants.lock().await.remove(&pid);
281                        trace!(?pid, "dropped participants lock");
282                        let r = if let Some(mut pi) = pi {
283                            let (finished_sender, finished_receiver) = oneshot::channel();
284                            // NOTE: If there's nothing to synchronize on (because the send failed)
285                            // we can assume everything relevant was shut down.
286                            let _ = pi
287                                .s2b_shutdown_bparticipant_s
288                                .take()
289                                .unwrap()
290                                .send((timeout_time, finished_sender));
291                            drop(pi);
292                            trace!(?pid, "dropped bparticipant, waiting for finish");
293                            // If await fails, already shut down, so send Ok(()).
294                            let e = finished_receiver.await.unwrap_or(Ok(()));
295                            trace!(?pid, "waiting completed");
296                            // can fail as api.rs has a timeout
297                            return_once_successful_shutdown.send(e)
298                        } else {
299                            debug!(?pid, "Looks like participant is already dropped");
300                            return_once_successful_shutdown.send(Ok(()))
301                        };
302                        if r.is_err() {
303                            trace!(?pid, "Closed participant with timeout");
304                        } else {
305                            trace!(?pid, "Closed participant");
306                        }
307                    }
308                },
309            )
310            .await;
311        trace!("Stop disconnect_mgr");
312    }
313
314    async fn prio_adj_mgr(
315        &self,
316        mut b2s_prio_statistic_r: mpsc::UnboundedReceiver<B2sPrioStatistic>,
317    ) {
318        trace!("Start prio_adj_mgr");
319        while let Some((_pid, _frame_cnt, _unused)) = b2s_prio_statistic_r.recv().await {
320
321            //TODO adjust prios in participants here!
322        }
323        trace!("Stop prio_adj_mgr");
324    }
325
326    async fn scheduler_shutdown_mgr(&self, a2s_scheduler_shutdown_r: oneshot::Receiver<()>) {
327        trace!("Start scheduler_shutdown_mgr");
328        if a2s_scheduler_shutdown_r.await.is_err() {
329            warn!("Schedule shutdown got triggered because a2s_scheduler_shutdown_r failed");
330        };
331        info!("Shutdown of scheduler requested");
332        self.closed.store(true, Ordering::SeqCst);
333        debug!("Shutting down all BParticipants gracefully");
334        let mut participants = self.participants.lock().await;
335        let waitings = participants
336            .drain()
337            .map(|(pid, mut pi)| {
338                trace!(?pid, "Shutting down BParticipants");
339                let (finished_sender, finished_receiver) = oneshot::channel();
340                pi.s2b_shutdown_bparticipant_s
341                    .take()
342                    .unwrap()
343                    .send((Duration::from_secs(120), finished_sender))
344                    .unwrap();
345                (pid, finished_receiver)
346            })
347            .collect::<Vec<_>>();
348        drop(participants);
349        debug!("Wait for partiticipants to be shut down");
350        for (pid, recv) in waitings {
351            if let Err(e) = recv.await {
352                error!(
353                    ?pid,
354                    ?e,
355                    "Failed to finish sending all remaining messages to participant when shutting \
356                     down"
357                );
358            };
359        }
360        debug!("shutting down protocol listeners");
361        for (addr, end_channel_sender) in self.channel_listener.lock().await.drain() {
362            trace!(?addr, "stopping listen on protocol");
363            if let Err(e) = end_channel_sender.send(()) {
364                warn!(?addr, ?e, "listener crashed/disconnected already");
365            }
366        }
367        debug!("Scheduler shut down gracefully");
368        //removing the possibility to create new participants, needed to close down
369        // some mgr:
370        self.participant_channels.lock().await.take();
371
372        trace!("Stop scheduler_shutdown_mgr");
373    }
374
375    async fn init_protocol(
376        &self,
377        mut protocol: Protocols,
378        con_addr: ConnectAddr, //address necessary to connect to the remote
379        cid: Cid,
380        s2a_return_pid_s: Option<oneshot::Sender<Result<Participant, NetworkConnectError>>>,
381        send_handshake: bool,
382    ) {
383        //channels are unknown till PID is known!
384        /* When A connects to a NETWORK, we, the listener answers with a Handshake.
385          Pro: - Its easier to debug, as someone who opens a port gets a magic number back!
386          Contra: - DOS possibility because we answer first
387                  - Speed, because otherwise the message can be send with the creation
388        */
389        let participant_channels = self.participant_channels.lock().await.clone().unwrap();
390        // spawn is needed here, e.g. for TCP connect it would mean that only 1
391        // participant can be in handshake phase ever! Someone could deadlock
392        // the whole server easily for new clients UDP doesnt work at all, as
393        // the UDP listening is done in another place.
394        let participants = Arc::clone(&self.participants);
395        let metrics = Arc::clone(&self.metrics);
396        let local_pid = self.local_pid;
397        let local_secret = self.local_secret;
398        // this is necessary for UDP to work at all and to remove code duplication
399        tokio::spawn(
400            async move {
401                trace!(?cid, "Open channel and be ready for Handshake");
402                use network_protocol::InitProtocol;
403                let init_result = protocol
404                    .initialize(send_handshake, local_pid, local_secret)
405                    .instrument(info_span!("handshake", ?cid))
406                    .await;
407                match init_result {
408                    Ok((pid, sid, secret)) => {
409                        trace!(
410                            ?cid,
411                            ?pid,
412                            "Detected that my channel is ready!, activating it :)"
413                        );
414                        let mut participants = participants.lock().await;
415                        if !participants.contains_key(&pid) {
416                            debug!(?cid, "New participant connected via a channel");
417                            let (
418                                bparticipant,
419                                a2b_open_stream_s,
420                                b2a_stream_opened_r,
421                                b2a_event_r,
422                                s2b_create_channel_s,
423                                s2b_shutdown_bparticipant_s,
424                                b2a_bandwidth_stats_r,
425                            ) = BParticipant::new(local_pid, pid, sid, Arc::clone(&metrics));
426
427                            let participant = Participant::new(
428                                local_pid,
429                                pid,
430                                a2b_open_stream_s,
431                                b2a_stream_opened_r,
432                                b2a_event_r,
433                                b2a_bandwidth_stats_r,
434                                participant_channels.a2s_disconnect_s,
435                            );
436
437                            #[cfg(feature = "metrics")]
438                            metrics.participants_connected_total.inc();
439                            participants.insert(pid, ParticipantInfo {
440                                secret,
441                                s2b_create_channel_s: s2b_create_channel_s.clone(),
442                                s2b_shutdown_bparticipant_s: Some(s2b_shutdown_bparticipant_s),
443                            });
444                            drop(participants);
445                            trace!("dropped participants lock");
446                            let p = pid;
447                            tokio::spawn(
448                                bparticipant
449                                    .run(participant_channels.b2s_prio_statistic_s)
450                                    .instrument(info_span!("remote", ?p)),
451                            );
452                            //create a new channel within BParticipant and wait for it to run
453                            let (b2s_create_channel_done_s, b2s_create_channel_done_r) =
454                                oneshot::channel();
455                            //From now on wire connects directly with bparticipant!
456                            s2b_create_channel_s
457                                .send((cid, sid, protocol, con_addr, b2s_create_channel_done_s))
458                                .unwrap();
459                            b2s_create_channel_done_r.await.unwrap();
460                            if let Some(pid_oneshot) = s2a_return_pid_s {
461                                // someone is waiting with `connect`, so give them their PID
462                                pid_oneshot.send(Ok(participant)).unwrap();
463                            } else {
464                                // no one is waiting on this Participant, return in to Network
465                                if participant_channels
466                                    .s2a_connected_s
467                                    .send(participant)
468                                    .is_err()
469                                {
470                                    warn!("seems like Network already got closed");
471                                };
472                            }
473                        } else {
474                            let pi = &participants[&pid];
475                            trace!(
476                                ?cid,
477                                "2nd+ channel of participant, going to compare security ids"
478                            );
479                            if pi.secret != secret {
480                                warn!(
481                                    ?cid,
482                                    ?pid,
483                                    ?secret,
484                                    "Detected incompatible Secret!, this is probably an attack!"
485                                );
486                                error!(?cid, "Just dropping here, TODO handle this correctly!");
487                                //TODO
488                                if let Some(pid_oneshot) = s2a_return_pid_s {
489                                    // someone is waiting with `connect`, so give them their Error
490                                    pid_oneshot
491                                        .send(Err(NetworkConnectError::InvalidSecret))
492                                        .unwrap();
493                                }
494                                return;
495                            }
496                            error!(
497                                ?cid,
498                                "Ufff i cant answer the pid_oneshot. as i need to create the SAME \
499                                 participant. maybe switch to ARC"
500                            );
501                        }
502                        //From now on this CHANNEL can receiver other frames!
503                        // move directly to participant!
504                    },
505                    Err(e) => {
506                        debug!(?cid, ?e, "Handshake from a new connection failed");
507                        #[cfg(feature = "metrics")]
508                        metrics.failed_handshakes_total.inc();
509                        if let Some(pid_oneshot) = s2a_return_pid_s {
510                            // someone is waiting with `connect`, so give them their Error
511                            trace!(?cid, "returning the Err to api who requested the connect");
512                            pid_oneshot
513                                .send(Err(NetworkConnectError::Handshake(e)))
514                                .unwrap();
515                        }
516                    },
517                }
518            }
519            .instrument(info_span!("")),
520        ); /*WORKAROUND FOR SPAN NOT TO GET LOST*/
521    }
522}