1use crate::{
2 api::{ConnectAddr, ListenAddr, NetworkConnectError, Participant},
3 channel::Protocols,
4 metrics::{NetworkMetrics, ProtocolInfo},
5 participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant},
6};
7use futures_util::StreamExt;
8use hashbrown::HashMap;
9use network_protocol::{Cid, Pid, ProtocolMetricCache, ProtocolMetrics};
10#[cfg(feature = "metrics")]
11use prometheus::Registry;
12use rand::Rng;
13use std::{
14 sync::{
15 Arc,
16 atomic::{AtomicBool, AtomicU64, Ordering},
17 },
18 time::Duration,
19};
20use tokio::{
21 io,
22 sync::{Mutex, mpsc, oneshot},
23};
24use tokio_stream::wrappers::UnboundedReceiverStream;
25use tracing::*;
26
27#[derive(Debug)]
37struct ParticipantInfo {
38 secret: u128,
39 #[expect(dead_code)]
40 s2b_create_channel_s: mpsc::UnboundedSender<S2bCreateChannel>,
41 s2b_shutdown_bparticipant_s: Option<oneshot::Sender<S2bShutdownBparticipant>>,
42}
43
44type A2sListen = (ListenAddr, oneshot::Sender<io::Result<()>>);
45pub(crate) type A2sConnect = (
46 ConnectAddr,
47 oneshot::Sender<Result<Participant, NetworkConnectError>>,
48);
49type A2sDisconnect = (Pid, S2bShutdownBparticipant);
50
51#[derive(Debug)]
52struct ControlChannels {
53 a2s_listen_r: mpsc::UnboundedReceiver<A2sListen>,
54 a2s_connect_r: mpsc::UnboundedReceiver<A2sConnect>,
55 a2s_scheduler_shutdown_r: oneshot::Receiver<()>,
56 a2s_disconnect_r: mpsc::UnboundedReceiver<A2sDisconnect>,
57 b2s_prio_statistic_r: mpsc::UnboundedReceiver<B2sPrioStatistic>,
58}
59
60#[derive(Debug, Clone)]
61struct ParticipantChannels {
62 s2a_connected_s: mpsc::UnboundedSender<Participant>,
63 a2s_disconnect_s: mpsc::UnboundedSender<A2sDisconnect>,
64 b2s_prio_statistic_s: mpsc::UnboundedSender<B2sPrioStatistic>,
65}
66
67#[derive(Debug)]
68pub struct Scheduler {
69 local_pid: Pid,
70 local_secret: u128,
71 closed: AtomicBool,
72 run_channels: Option<ControlChannels>,
73 participant_channels: Arc<Mutex<Option<ParticipantChannels>>>,
74 participants: Arc<Mutex<HashMap<Pid, ParticipantInfo>>>,
75 channel_ids: Arc<AtomicU64>,
76 channel_listener: Mutex<HashMap<ProtocolInfo, oneshot::Sender<()>>>,
77 metrics: Arc<NetworkMetrics>,
78 protocol_metrics: Arc<ProtocolMetrics>,
79}
80
81impl Scheduler {
82 pub fn new(
83 local_pid: Pid,
84 #[cfg(feature = "metrics")] registry: Option<&Registry>,
85 ) -> (
86 Self,
87 mpsc::UnboundedSender<A2sListen>,
88 mpsc::UnboundedSender<A2sConnect>,
89 mpsc::UnboundedReceiver<Participant>,
90 oneshot::Sender<()>,
91 ) {
92 let (a2s_listen_s, a2s_listen_r) = mpsc::unbounded_channel::<A2sListen>();
93 let (a2s_connect_s, a2s_connect_r) = mpsc::unbounded_channel::<A2sConnect>();
94 let (s2a_connected_s, s2a_connected_r) = mpsc::unbounded_channel::<Participant>();
95 let (a2s_scheduler_shutdown_s, a2s_scheduler_shutdown_r) = oneshot::channel::<()>();
96 let (a2s_disconnect_s, a2s_disconnect_r) = mpsc::unbounded_channel::<A2sDisconnect>();
97 let (b2s_prio_statistic_s, b2s_prio_statistic_r) =
98 mpsc::unbounded_channel::<B2sPrioStatistic>();
99
100 let run_channels = Some(ControlChannels {
101 a2s_listen_r,
102 a2s_connect_r,
103 a2s_scheduler_shutdown_r,
104 a2s_disconnect_r,
105 b2s_prio_statistic_r,
106 });
107
108 let participant_channels = ParticipantChannels {
109 s2a_connected_s,
110 a2s_disconnect_s,
111 b2s_prio_statistic_s,
112 };
113
114 let metrics = Arc::new(NetworkMetrics::new(&local_pid).unwrap());
115 let protocol_metrics = Arc::new(ProtocolMetrics::new().unwrap());
116
117 #[cfg(feature = "metrics")]
118 {
119 if let Some(registry) = registry {
120 metrics.register(registry).unwrap();
121 protocol_metrics.register(registry).unwrap();
122 }
123 }
124
125 let mut rng = rand::thread_rng();
126 let local_secret: u128 = rng.gen();
127
128 (
129 Self {
130 local_pid,
131 local_secret,
132 closed: AtomicBool::new(false),
133 run_channels,
134 participant_channels: Arc::new(Mutex::new(Some(participant_channels))),
135 participants: Arc::new(Mutex::new(HashMap::new())),
136 channel_ids: Arc::new(AtomicU64::new(0)),
137 channel_listener: Mutex::new(HashMap::new()),
138 metrics,
139 protocol_metrics,
140 },
141 a2s_listen_s,
142 a2s_connect_s,
143 s2a_connected_r,
144 a2s_scheduler_shutdown_s,
145 )
146 }
147
148 pub async fn run(mut self) {
149 let run_channels = self
150 .run_channels
151 .take()
152 .expect("run() can only be called once");
153
154 tokio::join!(
155 self.listen_mgr(run_channels.a2s_listen_r),
156 self.connect_mgr(run_channels.a2s_connect_r),
157 self.disconnect_mgr(run_channels.a2s_disconnect_r),
158 self.prio_adj_mgr(run_channels.b2s_prio_statistic_r),
159 self.scheduler_shutdown_mgr(run_channels.a2s_scheduler_shutdown_r),
160 );
161 }
162
163 async fn listen_mgr(&self, a2s_listen_r: mpsc::UnboundedReceiver<A2sListen>) {
164 trace!("Start listen_mgr");
165 let a2s_listen_r = UnboundedReceiverStream::new(a2s_listen_r);
166 a2s_listen_r
167 .for_each_concurrent(None, |(address, s2a_listen_result_s)| {
168 let address = address;
169 let cids = Arc::clone(&self.channel_ids);
170
171 #[cfg(feature = "metrics")]
172 let mcache = self.metrics.connect_requests_cache(&address);
173
174 debug!(?address, "Got request to open a channel_creator");
175 self.metrics.listen_request(&address);
176 let (s2s_stop_listening_s, s2s_stop_listening_r) = oneshot::channel::<()>();
177 let (c2s_protocol_s, mut c2s_protocol_r) = mpsc::unbounded_channel();
178 let metrics = Arc::clone(&self.protocol_metrics);
179
180 async move {
181 self.channel_listener
182 .lock()
183 .await
184 .insert(address.clone().into(), s2s_stop_listening_s);
185
186 #[cfg(feature = "metrics")]
187 mcache.inc();
188
189 let res = match address {
190 ListenAddr::Tcp(addr) => {
191 Protocols::with_tcp_listen(
192 addr,
193 cids,
194 metrics,
195 s2s_stop_listening_r,
196 c2s_protocol_s,
197 )
198 .await
199 },
200 #[cfg(feature = "quic")]
201 ListenAddr::Quic(addr, ref server_config) => {
202 Protocols::with_quic_listen(
203 addr,
204 server_config.clone(),
205 cids,
206 metrics,
207 s2s_stop_listening_r,
208 c2s_protocol_s,
209 )
210 .await
211 },
212 ListenAddr::Mpsc(addr) => {
213 Protocols::with_mpsc_listen(
214 addr,
215 cids,
216 metrics,
217 s2s_stop_listening_r,
218 c2s_protocol_s,
219 )
220 .await
221 },
222 _ => unimplemented!(),
223 };
224 let _ = s2a_listen_result_s.send(res);
225
226 while let Some((prot, con_addr, cid)) = c2s_protocol_r.recv().await {
227 self.init_protocol(prot, con_addr, cid, None, true).await;
228 }
229 }
230 })
231 .await;
232 trace!("Stop listen_mgr");
233 }
234
235 async fn connect_mgr(&self, mut a2s_connect_r: mpsc::UnboundedReceiver<A2sConnect>) {
236 trace!("Start connect_mgr");
237 while let Some((addr, pid_sender)) = a2s_connect_r.recv().await {
238 let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
239 let metrics =
240 ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&self.protocol_metrics));
241 self.metrics.connect_request(&addr);
242 let protocol = match addr.clone() {
243 ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, metrics).await,
244 #[cfg(feature = "quic")]
245 ConnectAddr::Quic(addr, ref config, name) => {
246 Protocols::with_quic_connect(addr, config.clone(), name, metrics).await
247 },
248 ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, metrics).await,
249 _ => unimplemented!(),
250 };
251 let protocol = match protocol {
252 Ok(p) => p,
253 Err(e) => {
254 pid_sender.send(Err(e)).unwrap();
255 continue;
256 },
257 };
258 self.init_protocol(protocol, addr, cid, Some(pid_sender), false)
259 .await;
260 }
261 trace!("Stop connect_mgr");
262 }
263
264 async fn disconnect_mgr(&self, a2s_disconnect_r: mpsc::UnboundedReceiver<A2sDisconnect>) {
265 trace!("Start disconnect_mgr");
266
267 let a2s_disconnect_r = UnboundedReceiverStream::new(a2s_disconnect_r);
268 a2s_disconnect_r
269 .for_each_concurrent(
270 None,
271 |(pid, (timeout_time, return_once_successful_shutdown))| {
272 let participants = Arc::clone(&self.participants);
278 async move {
279 trace!(?pid, "Got request to close participant");
280 let pi = participants.lock().await.remove(&pid);
281 trace!(?pid, "dropped participants lock");
282 let r = if let Some(mut pi) = pi {
283 let (finished_sender, finished_receiver) = oneshot::channel();
284 let _ = pi
287 .s2b_shutdown_bparticipant_s
288 .take()
289 .unwrap()
290 .send((timeout_time, finished_sender));
291 drop(pi);
292 trace!(?pid, "dropped bparticipant, waiting for finish");
293 let e = finished_receiver.await.unwrap_or(Ok(()));
295 trace!(?pid, "waiting completed");
296 return_once_successful_shutdown.send(e)
298 } else {
299 debug!(?pid, "Looks like participant is already dropped");
300 return_once_successful_shutdown.send(Ok(()))
301 };
302 if r.is_err() {
303 trace!(?pid, "Closed participant with timeout");
304 } else {
305 trace!(?pid, "Closed participant");
306 }
307 }
308 },
309 )
310 .await;
311 trace!("Stop disconnect_mgr");
312 }
313
314 async fn prio_adj_mgr(
315 &self,
316 mut b2s_prio_statistic_r: mpsc::UnboundedReceiver<B2sPrioStatistic>,
317 ) {
318 trace!("Start prio_adj_mgr");
319 while let Some((_pid, _frame_cnt, _unused)) = b2s_prio_statistic_r.recv().await {
320
321 }
323 trace!("Stop prio_adj_mgr");
324 }
325
326 async fn scheduler_shutdown_mgr(&self, a2s_scheduler_shutdown_r: oneshot::Receiver<()>) {
327 trace!("Start scheduler_shutdown_mgr");
328 if a2s_scheduler_shutdown_r.await.is_err() {
329 warn!("Schedule shutdown got triggered because a2s_scheduler_shutdown_r failed");
330 };
331 info!("Shutdown of scheduler requested");
332 self.closed.store(true, Ordering::SeqCst);
333 debug!("Shutting down all BParticipants gracefully");
334 let mut participants = self.participants.lock().await;
335 let waitings = participants
336 .drain()
337 .map(|(pid, mut pi)| {
338 trace!(?pid, "Shutting down BParticipants");
339 let (finished_sender, finished_receiver) = oneshot::channel();
340 pi.s2b_shutdown_bparticipant_s
341 .take()
342 .unwrap()
343 .send((Duration::from_secs(120), finished_sender))
344 .unwrap();
345 (pid, finished_receiver)
346 })
347 .collect::<Vec<_>>();
348 drop(participants);
349 debug!("Wait for partiticipants to be shut down");
350 for (pid, recv) in waitings {
351 if let Err(e) = recv.await {
352 error!(
353 ?pid,
354 ?e,
355 "Failed to finish sending all remaining messages to participant when shutting \
356 down"
357 );
358 };
359 }
360 debug!("shutting down protocol listeners");
361 for (addr, end_channel_sender) in self.channel_listener.lock().await.drain() {
362 trace!(?addr, "stopping listen on protocol");
363 if let Err(e) = end_channel_sender.send(()) {
364 warn!(?addr, ?e, "listener crashed/disconnected already");
365 }
366 }
367 debug!("Scheduler shut down gracefully");
368 self.participant_channels.lock().await.take();
371
372 trace!("Stop scheduler_shutdown_mgr");
373 }
374
375 async fn init_protocol(
376 &self,
377 mut protocol: Protocols,
378 con_addr: ConnectAddr, cid: Cid,
380 s2a_return_pid_s: Option<oneshot::Sender<Result<Participant, NetworkConnectError>>>,
381 send_handshake: bool,
382 ) {
383 let participant_channels = self.participant_channels.lock().await.clone().unwrap();
390 let participants = Arc::clone(&self.participants);
395 let metrics = Arc::clone(&self.metrics);
396 let local_pid = self.local_pid;
397 let local_secret = self.local_secret;
398 tokio::spawn(
400 async move {
401 trace!(?cid, "Open channel and be ready for Handshake");
402 use network_protocol::InitProtocol;
403 let init_result = protocol
404 .initialize(send_handshake, local_pid, local_secret)
405 .instrument(info_span!("handshake", ?cid))
406 .await;
407 match init_result {
408 Ok((pid, sid, secret)) => {
409 trace!(
410 ?cid,
411 ?pid,
412 "Detected that my channel is ready!, activating it :)"
413 );
414 let mut participants = participants.lock().await;
415 if !participants.contains_key(&pid) {
416 debug!(?cid, "New participant connected via a channel");
417 let (
418 bparticipant,
419 a2b_open_stream_s,
420 b2a_stream_opened_r,
421 b2a_event_r,
422 s2b_create_channel_s,
423 s2b_shutdown_bparticipant_s,
424 b2a_bandwidth_stats_r,
425 ) = BParticipant::new(local_pid, pid, sid, Arc::clone(&metrics));
426
427 let participant = Participant::new(
428 local_pid,
429 pid,
430 a2b_open_stream_s,
431 b2a_stream_opened_r,
432 b2a_event_r,
433 b2a_bandwidth_stats_r,
434 participant_channels.a2s_disconnect_s,
435 );
436
437 #[cfg(feature = "metrics")]
438 metrics.participants_connected_total.inc();
439 participants.insert(pid, ParticipantInfo {
440 secret,
441 s2b_create_channel_s: s2b_create_channel_s.clone(),
442 s2b_shutdown_bparticipant_s: Some(s2b_shutdown_bparticipant_s),
443 });
444 drop(participants);
445 trace!("dropped participants lock");
446 let p = pid;
447 tokio::spawn(
448 bparticipant
449 .run(participant_channels.b2s_prio_statistic_s)
450 .instrument(info_span!("remote", ?p)),
451 );
452 let (b2s_create_channel_done_s, b2s_create_channel_done_r) =
454 oneshot::channel();
455 s2b_create_channel_s
457 .send((cid, sid, protocol, con_addr, b2s_create_channel_done_s))
458 .unwrap();
459 b2s_create_channel_done_r.await.unwrap();
460 if let Some(pid_oneshot) = s2a_return_pid_s {
461 pid_oneshot.send(Ok(participant)).unwrap();
463 } else {
464 if participant_channels
466 .s2a_connected_s
467 .send(participant)
468 .is_err()
469 {
470 warn!("seems like Network already got closed");
471 };
472 }
473 } else {
474 let pi = &participants[&pid];
475 trace!(
476 ?cid,
477 "2nd+ channel of participant, going to compare security ids"
478 );
479 if pi.secret != secret {
480 warn!(
481 ?cid,
482 ?pid,
483 ?secret,
484 "Detected incompatible Secret!, this is probably an attack!"
485 );
486 error!(?cid, "Just dropping here, TODO handle this correctly!");
487 if let Some(pid_oneshot) = s2a_return_pid_s {
489 pid_oneshot
491 .send(Err(NetworkConnectError::InvalidSecret))
492 .unwrap();
493 }
494 return;
495 }
496 error!(
497 ?cid,
498 "Ufff i cant answer the pid_oneshot. as i need to create the SAME \
499 participant. maybe switch to ARC"
500 );
501 }
502 },
505 Err(e) => {
506 debug!(?cid, ?e, "Handshake from a new connection failed");
507 #[cfg(feature = "metrics")]
508 metrics.failed_handshakes_total.inc();
509 if let Some(pid_oneshot) = s2a_return_pid_s {
510 trace!(?cid, "returning the Err to api who requested the connect");
512 pid_oneshot
513 .send(Err(NetworkConnectError::Handshake(e)))
514 .unwrap();
515 }
516 },
517 }
518 }
519 .instrument(info_span!("")),
520 ); }
522}