use crate::{
api::{ConnectAddr, ParticipantError, ParticipantEvent, Stream},
channel::{Protocols, ProtocolsError, RecvProtocols, SendProtocols},
metrics::NetworkMetrics,
util::DeferredTracer,
};
use bytes::Bytes;
use futures_util::{FutureExt, StreamExt};
use hashbrown::HashMap;
use network_protocol::{
Bandwidth, Cid, Pid, Prio, Promises, ProtocolEvent, RecvProtocol, SendProtocol, Sid,
_internal::SortedVec,
};
use std::{
sync::{
atomic::{AtomicBool, AtomicI32, Ordering},
Arc,
},
time::{Duration, Instant},
};
use tokio::{
select,
sync::{mpsc, oneshot, watch, Mutex, RwLock},
task::JoinHandle,
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::*;
/// Request from the API to open a new stream: priority, promises, guaranteed
/// bandwidth and a oneshot sender on which the created `Stream` is handed back.
pub(crate) type A2bStreamOpen = (Prio, Promises, Bandwidth, oneshot::Sender<Stream>);
/// Request to attach a new channel to the participant: channel id, a `Sid` (ignored
/// by `create_channel_mgr`), the protocol pair, the remote address, and a oneshot
/// sender that is fired once the channel has been registered.
pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, ConnectAddr, oneshot::Sender<()>);
/// Shutdown request: how long to wait for the managers to stop on their own before
/// the receive side is force-closed, and a oneshot sender for the shutdown result.
pub(crate) type S2bShutdownBparticipant = (Duration, oneshot::Sender<Result<(), ParticipantError>>);
/// Per-participant priority statistic. NOTE(review): never constructed in this
/// file, so the meaning of the two `u64` values cannot be confirmed from here.
pub(crate) type B2sPrioStatistic = (Pid, u64, u64);
/// Bookkeeping entry for one channel of this participant, stored in
/// `BParticipant::channels` keyed by its `Cid`.
#[derive(Debug)]
#[allow(dead_code)]
struct ChannelInfo {
// Id of the channel this entry describes.
cid: Cid,
// `cid_string` is filled with `cid.to_string()` on creation; `remote_con_addr` is the
// address reported in `ParticipantEvent::ChannelDeleted` when the channel goes away.
cid_string: String, remote_con_addr: ConnectAddr,
}
/// Participant-side state of one open stream, stored in `BParticipant::streams`.
#[derive(Debug)]
struct StreamInfo {
// Priority the stream was opened with (stored only, not read in this file).
#[allow(dead_code)]
prio: Prio,
// Promises the stream was opened with (stored only, not read in this file).
#[allow(dead_code)]
promises: Promises,
// Shared with the API-side `Stream`; set to `true` once the stream is deleted or
// the participant shuts down.
send_closed: Arc<AtomicBool>,
// Sender used by `recv_mgr` to forward incoming message payloads for this stream;
// closed when the stream is deleted.
b2a_msg_recv_s: Mutex<async_channel::Sender<Bytes>>,
}
/// Channel ends created in `BParticipant::new` and consumed exactly once by
/// `BParticipant::run`, which hands each of them to the responsible manager.
#[derive(Debug)]
struct ControlChannels {
// Stream-open requests coming from the API (consumed by `send_mgr`).
a2b_open_stream_r: mpsc::UnboundedReceiver<A2bStreamOpen>,
// Streams opened by the remote side are published here (used by `recv_mgr`).
b2a_stream_opened_s: mpsc::UnboundedSender<Stream>,
// Channel created/deleted events for the API.
b2a_event_s: mpsc::UnboundedSender<ParticipantEvent>,
// Channel-creation requests (consumed by `create_channel_mgr`).
s2b_create_channel_r: mpsc::UnboundedReceiver<S2bCreateChannel>,
// Smoothed send bandwidth published by `send_mgr`.
b2a_bandwidth_stats_s: watch::Sender<f32>,
// One-shot shutdown request (consumed by `participant_shutdown_mgr`).
s2b_shutdown_bparticipant_r: oneshot::Receiver<S2bShutdownBparticipant>, }
/// Sender ends that every newly created `Stream` receives a clone of. Installed at
/// the start of `run` and removed again when `send_mgr` stops.
#[derive(Debug)]
struct OpenStreamInfo {
// Outgoing messages, tagged with the stream id, drained by `send_mgr`.
a2b_msg_s: crossbeam_channel::Sender<(Sid, Bytes)>,
// Stream-close requests from the API, handled by `send_mgr`.
a2b_close_stream_s: mpsc::UnboundedSender<Sid>,
}
/// Backend half of a participant: owns the channels and streams towards one remote
/// and runs the send / recv / channel-creation / shutdown managers in `run`.
#[derive(Debug)]
pub struct BParticipant {
// Ids of the local and remote endpoint; both are passed on to every `Stream`.
local_pid: Pid, remote_pid: Pid,
// `remote_pid_string` is `remote_pid.to_string()`, used as the metrics label;
// `offset_sid` is the first `Sid` handed out to locally opened streams.
remote_pid_string: String, offset_sid: Sid,
// All channels registered through `create_channel_mgr`.
channels: Arc<RwLock<HashMap<Cid, Mutex<ChannelInfo>>>>,
// All currently open streams, regardless of which side opened them.
streams: RwLock<HashMap<Sid, StreamInfo>>,
// Control channel ends; `Some` until `run` takes them.
run_channels: Option<ControlChannels>,
// Starts at BARR_CHANNEL + BARR_SEND + BARR_RECV; each manager subtracts its own
// value when it stops, and the shutdown manager waits for this to reach 0.
shutdown_barrier: AtomicI32,
metrics: Arc<NetworkMetrics>,
// Sender ends cloned into new streams; `None` before `run` and after `send_mgr` stopped.
open_stream_channels: Arc<Mutex<Option<OpenStreamInfo>>>,
}
impl BParticipant {
// Amounts the individual managers subtract from `shutdown_barrier` when they stop
// (`create_channel_mgr`, `recv_mgr` and `send_mgr` respectively). The values are
// distinct powers of two, so the remaining barrier value shows which managers are
// still running.
const BARR_CHANNEL: i32 = 1;
const BARR_RECV: i32 = 4;
const BARR_SEND: i32 = 2;
// Period of the `send_mgr` interval timer.
const TICK_TIME: Duration = Duration::from_millis(Self::TICK_TIME_MS);
const TICK_TIME_MS: u64 = 5;
pub(crate) fn new(
local_pid: Pid,
remote_pid: Pid,
offset_sid: Sid,
metrics: Arc<NetworkMetrics>,
) -> (
Self,
mpsc::UnboundedSender<A2bStreamOpen>,
mpsc::UnboundedReceiver<Stream>,
mpsc::UnboundedReceiver<ParticipantEvent>,
mpsc::UnboundedSender<S2bCreateChannel>,
oneshot::Sender<S2bShutdownBparticipant>,
watch::Receiver<f32>,
) {
let (a2b_open_stream_s, a2b_open_stream_r) = mpsc::unbounded_channel::<A2bStreamOpen>();
let (b2a_stream_opened_s, b2a_stream_opened_r) = mpsc::unbounded_channel::<Stream>();
let (b2a_event_s, b2a_event_r) = mpsc::unbounded_channel::<ParticipantEvent>();
let (s2b_shutdown_bparticipant_s, s2b_shutdown_bparticipant_r) = oneshot::channel();
let (s2b_create_channel_s, s2b_create_channel_r) = mpsc::unbounded_channel();
let (b2a_bandwidth_stats_s, b2a_bandwidth_stats_r) = watch::channel::<f32>(0.0);
let run_channels = Some(ControlChannels {
a2b_open_stream_r,
b2a_stream_opened_s,
b2a_event_s,
s2b_create_channel_r,
b2a_bandwidth_stats_s,
s2b_shutdown_bparticipant_r,
});
(
Self {
local_pid,
remote_pid,
remote_pid_string: remote_pid.to_string(),
offset_sid,
channels: Arc::new(RwLock::new(HashMap::new())),
streams: RwLock::new(HashMap::new()),
shutdown_barrier: AtomicI32::new(
Self::BARR_CHANNEL + Self::BARR_SEND + Self::BARR_RECV,
),
run_channels,
metrics,
open_stream_channels: Arc::new(Mutex::new(None)),
},
a2b_open_stream_s,
b2a_stream_opened_r,
b2a_event_r,
s2b_create_channel_s,
s2b_shutdown_bparticipant_s,
b2a_bandwidth_stats_r,
)
}
pub async fn run(mut self, b2s_prio_statistic_s: mpsc::UnboundedSender<B2sPrioStatistic>) {
let (b2b_add_send_protocol_s, b2b_add_send_protocol_r) =
mpsc::unbounded_channel::<(Cid, SendProtocols)>();
let (b2b_add_recv_protocol_s, b2b_add_recv_protocol_r) =
mpsc::unbounded_channel::<(Cid, RecvProtocols)>();
let (b2b_close_send_protocol_s, b2b_close_send_protocol_r) =
async_channel::unbounded::<Cid>();
let (b2b_force_close_recv_protocol_s, b2b_force_close_recv_protocol_r) =
async_channel::unbounded::<Cid>();
let (b2b_notify_send_of_recv_open_s, b2b_notify_send_of_recv_open_r) =
crossbeam_channel::unbounded::<(Cid, Sid, Prio, Promises, u64)>();
let (b2b_notify_send_of_recv_close_s, b2b_notify_send_of_recv_close_r) =
crossbeam_channel::unbounded::<(Cid, Sid)>();
let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel::<Sid>();
let (a2b_msg_s, a2b_msg_r) = crossbeam_channel::unbounded::<(Sid, Bytes)>();
*self.open_stream_channels.lock().await = Some(OpenStreamInfo {
a2b_msg_s,
a2b_close_stream_s,
});
let run_channels = self.run_channels.take().unwrap();
trace!("start all managers");
tokio::join!(
self.send_mgr(
run_channels.a2b_open_stream_r,
a2b_close_stream_r,
a2b_msg_r,
b2b_add_send_protocol_r,
b2b_close_send_protocol_r,
b2b_notify_send_of_recv_open_r,
b2b_notify_send_of_recv_close_r,
run_channels.b2a_event_s.clone(),
b2s_prio_statistic_s,
run_channels.b2a_bandwidth_stats_s,
)
.instrument(tracing::info_span!("send")),
self.recv_mgr(
run_channels.b2a_stream_opened_s,
b2b_add_recv_protocol_r,
b2b_force_close_recv_protocol_r,
b2b_close_send_protocol_s.clone(),
b2b_notify_send_of_recv_open_s,
b2b_notify_send_of_recv_close_s,
)
.instrument(tracing::info_span!("recv")),
self.create_channel_mgr(
run_channels.s2b_create_channel_r,
b2b_add_send_protocol_s,
b2b_add_recv_protocol_s,
run_channels.b2a_event_s,
),
self.participant_shutdown_mgr(
run_channels.s2b_shutdown_bparticipant_r,
b2b_close_send_protocol_s.clone(),
b2b_force_close_recv_protocol_s,
),
);
}
/// Picks the channel on which a stream with the given `promises` should be sent.
///
/// Preference order: the first Mpsc protocol (taken regardless of `promises`),
/// then the first Tcp protocol if Tcp supports all `promises`, then the first
/// Quic protocol if Quic supports all `promises`. If none of these match, a
/// warning is logged and the first registered protocol of any kind is returned.
/// Returns `None` only when `all` is empty.
fn best_protocol(all: &SortedVec<Cid, SendProtocols>, promises: Promises) -> Option<Cid> {
    // Cid of the first registered protocol accepted by `wanted`.
    fn first_cid(
        all: &SortedVec<Cid, SendProtocols>,
        wanted: fn(&SendProtocols) -> bool,
    ) -> Option<Cid> {
        all.data.iter().find(|(_, p)| wanted(p)).map(|(c, _)| *c)
    }

    if let Some(cid) = first_cid(all, |p| matches!(p, SendProtocols::Mpsc(_))) {
        return Some(cid);
    }

    if network_protocol::TcpSendProtocol::<crate::channel::TcpDrain>::supported_promises()
        .contains(promises)
    {
        if let Some(cid) = first_cid(all, |p| matches!(p, SendProtocols::Tcp(_))) {
            return Some(cid);
        }
    }

    if network_protocol::QuicSendProtocol::<crate::channel::QuicDrain>::supported_promises()
        .contains(promises)
    {
        if let Some(cid) = first_cid(all, |p| matches!(p, SendProtocols::Quic(_))) {
            return Some(cid);
        }
    }

    warn!("couldn't satisfy promises");
    all.data.first().map(|(c, _)| *c)
}
async fn send_mgr(
&self,
mut a2b_open_stream_r: mpsc::UnboundedReceiver<A2bStreamOpen>,
mut a2b_close_stream_r: mpsc::UnboundedReceiver<Sid>,
a2b_msg_r: crossbeam_channel::Receiver<(Sid, Bytes)>,
mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, SendProtocols)>,
b2b_close_send_protocol_r: async_channel::Receiver<Cid>,
b2b_notify_send_of_recv_open_r: crossbeam_channel::Receiver<(
Cid,
Sid,
Prio,
Promises,
Bandwidth,
)>,
b2b_notify_send_of_recv_close_r: crossbeam_channel::Receiver<(Cid, Sid)>,
b2a_event_s: mpsc::UnboundedSender<ParticipantEvent>,
_b2s_prio_statistic_s: mpsc::UnboundedSender<B2sPrioStatistic>,
b2a_bandwidth_stats_s: watch::Sender<f32>,
) {
let mut sorted_send_protocols = SortedVec::<Cid, SendProtocols>::default();
let mut sorted_stream_protocols = SortedVec::<Sid, Cid>::default();
let mut interval = tokio::time::interval(Self::TICK_TIME);
let mut last_instant = Instant::now();
let mut stream_ids = self.offset_sid;
let mut part_bandwidth = 0.0f32;
trace!("workaround, actively wait for first protocol");
if let Some((c, p)) = b2b_add_protocol_r.recv().await {
sorted_send_protocols.insert(c, p)
}
loop {
let (open, close, _, addp, remp) = select!(
Some(n) = a2b_open_stream_r.recv().fuse() => (Some(n), None, None, None, None),
Some(n) = a2b_close_stream_r.recv().fuse() => (None, Some(n), None, None, None),
_ = interval.tick() => (None, None, Some(()), None, None),
Some(n) = b2b_add_protocol_r.recv().fuse() => (None, None, None, Some(n), None),
Ok(n) = b2b_close_send_protocol_r.recv().fuse() => (None, None, None, None, Some(n)),
);
if let Some((cid, p)) = addp {
debug!(?cid, "add protocol");
sorted_send_protocols.insert(cid, p);
}
if sorted_send_protocols.data.is_empty() {
warn!("no channel");
tokio::time::sleep(Self::TICK_TIME * 1000).await; continue;
}
let mut cid = u64::MAX;
let active_err = async {
if let Some((prio, promises, guaranteed_bandwidth, return_s)) = open {
let sid = stream_ids;
stream_ids += Sid::from(1);
cid = Self::best_protocol(&sorted_send_protocols, promises).unwrap();
trace!(?sid, ?cid, "open stream");
let stream = self
.create_stream(sid, prio, promises, guaranteed_bandwidth)
.await;
let event = ProtocolEvent::OpenStream {
sid,
prio,
promises,
guaranteed_bandwidth,
};
sorted_stream_protocols.insert(sid, cid);
return_s.send(stream).unwrap();
sorted_send_protocols
.get_mut(&cid)
.unwrap()
.send(event)
.await?;
}
for (cid, sid, prio, promises, guaranteed_bandwidth) in
b2b_notify_send_of_recv_open_r.try_iter()
{
match sorted_send_protocols.get_mut(&cid) {
Some(p) => {
sorted_stream_protocols.insert(sid, cid);
p.notify_from_recv(ProtocolEvent::OpenStream {
sid,
prio,
promises,
guaranteed_bandwidth,
});
},
None => warn!(?cid, "couldn't notify create protocol, doesn't exist"),
};
}
for (sid, buffer) in a2b_msg_r.try_iter() {
cid = *sorted_stream_protocols.get(&sid).unwrap();
let event = ProtocolEvent::Message { data: buffer, sid };
sorted_send_protocols
.get_mut(&cid)
.unwrap()
.send(event)
.await?;
}
for (cid, sid) in b2b_notify_send_of_recv_close_r.try_iter() {
match sorted_send_protocols.get_mut(&cid) {
Some(p) => {
let _ = sorted_stream_protocols.delete(&sid);
p.notify_from_recv(ProtocolEvent::CloseStream { sid });
},
None => warn!(?cid, "couldn't notify close protocol, doesn't exist"),
};
}
if let Some(sid) = close {
trace!(?stream_ids, "delete stream");
self.delete_stream(sid).await;
if let Some(c) = sorted_stream_protocols.delete(&sid) {
cid = c;
let event = ProtocolEvent::CloseStream { sid };
sorted_send_protocols
.get_mut(&c)
.unwrap()
.send(event)
.await?;
}
}
let send_time = Instant::now();
let diff = send_time.duration_since(last_instant);
last_instant = send_time;
let mut cnt = 0;
for (c, p) in sorted_send_protocols.data.iter_mut() {
cid = *c;
cnt += p.flush(1_000_000_000, diff).await?; }
let flush_time = send_time.elapsed().as_secs_f32();
part_bandwidth = 0.99 * part_bandwidth + 0.01 * (cnt as f32 / flush_time);
self.metrics
.participant_bandwidth(&self.remote_pid_string, part_bandwidth);
let _ = b2a_bandwidth_stats_s.send(part_bandwidth);
let r: Result<(), network_protocol::ProtocolError<ProtocolsError>> = Ok(());
r
}
.await;
if let Err(e) = active_err {
info!(?cid, ?e, "protocol failed, shutting down channel");
trace!("TODO: for now decide to FAIL this participant and not wait for a failover");
sorted_send_protocols.delete(&cid).unwrap();
if let Some(info) = self.channels.write().await.get(&cid) {
if let Err(e) = b2a_event_s.send(ParticipantEvent::ChannelDeleted(
info.lock().await.remote_con_addr.clone(),
)) {
debug!(?e, "Participant was dropped during channel disconnect");
};
}
self.metrics.channels_disconnected(&self.remote_pid_string);
if sorted_send_protocols.data.is_empty() {
break;
}
}
if let Some(cid) = remp {
debug!(?cid, "remove protocol");
match sorted_send_protocols.delete(&cid) {
Some(mut prot) => {
if let Some(info) = self.channels.write().await.get(&cid) {
if let Err(e) = b2a_event_s.send(ParticipantEvent::ChannelDeleted(
info.lock().await.remote_con_addr.clone(),
)) {
debug!(?e, "Participant was dropped during channel disconnect");
};
}
self.metrics.channels_disconnected(&self.remote_pid_string);
trace!("blocking flush");
let _ = prot.flush(u64::MAX, Duration::from_secs(1)).await;
trace!("shutdown prot");
let _ = prot.send(ProtocolEvent::Shutdown).await;
},
None => trace!("tried to remove protocol twice"),
};
if sorted_send_protocols.data.is_empty() {
break;
}
}
}
trace!("stop sending in api!");
self.open_stream_channels.lock().await.take();
trace!("Stop send_mgr");
self.shutdown_barrier
.fetch_sub(Self::BARR_SEND, Ordering::SeqCst);
}
async fn recv_mgr(
&self,
b2a_stream_opened_s: mpsc::UnboundedSender<Stream>,
mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, RecvProtocols)>,
b2b_force_close_recv_protocol_r: async_channel::Receiver<Cid>,
b2b_close_send_protocol_s: async_channel::Sender<Cid>,
b2b_notify_send_of_recv_open_r: crossbeam_channel::Sender<(
Cid,
Sid,
Prio,
Promises,
Bandwidth,
)>,
b2b_notify_send_of_recv_close_s: crossbeam_channel::Sender<(Cid, Sid)>,
) {
let mut recv_protocols: HashMap<Cid, JoinHandle<()>> = HashMap::new();
let (hacky_recv_s, mut hacky_recv_r) = mpsc::unbounded_channel();
let retrigger = |cid: Cid, mut p: RecvProtocols, map: &mut HashMap<_, _>| {
let hacky_recv_s = hacky_recv_s.clone();
let handle = tokio::spawn(async move {
let r = p.recv().await;
let _ = hacky_recv_s.send((cid, r, p)); });
map.insert(cid, handle);
};
let remove_c = |recv_protocols: &mut HashMap<Cid, JoinHandle<()>>, cid: &Cid| {
match recv_protocols.remove(cid) {
Some(h) => {
h.abort();
debug!(?cid, "remove protocol");
},
None => trace!("tried to remove protocol twice"),
};
recv_protocols.is_empty()
};
let mut defered_orphan = DeferredTracer::new(Level::WARN);
loop {
let (event, addp, remp) = select!(
Some(n) = hacky_recv_r.recv().fuse() => (Some(n), None, None),
Some(n) = b2b_add_protocol_r.recv().fuse() => (None, Some(n), None),
Ok(n) = b2b_force_close_recv_protocol_r.recv().fuse() => (None, None, Some(n)),
else => {
error!("recv_mgr -> something is seriously wrong!, end recv_mgr");
break;
}
);
if let Some((cid, p)) = addp {
debug!(?cid, "add protocol");
retrigger(cid, p, &mut recv_protocols);
};
if let Some(cid) = remp {
if remove_c(&mut recv_protocols, &cid) {
break;
}
};
if let Some((cid, r, p)) = event {
match r {
Ok(ProtocolEvent::OpenStream {
sid,
prio,
promises,
guaranteed_bandwidth,
}) => {
trace!(?sid, "open stream");
let _ = b2b_notify_send_of_recv_open_r.send((
cid,
sid,
prio,
promises,
guaranteed_bandwidth,
));
let stream = self
.create_stream(sid, prio, promises, guaranteed_bandwidth)
.await;
b2a_stream_opened_s.send(stream).unwrap();
retrigger(cid, p, &mut recv_protocols);
},
Ok(ProtocolEvent::CloseStream { sid }) => {
trace!(?sid, "close stream");
let _ = b2b_notify_send_of_recv_close_s.send((cid, sid));
self.delete_stream(sid).await;
retrigger(cid, p, &mut recv_protocols);
},
Ok(ProtocolEvent::Message { data, sid }) => {
let lock = self.streams.read().await;
match lock.get(&sid) {
Some(stream) => {
let _ = stream.b2a_msg_recv_s.lock().await.send(data).await;
},
None => defered_orphan.log(sid),
};
retrigger(cid, p, &mut recv_protocols);
},
Ok(ProtocolEvent::Shutdown) => {
info!(?cid, "shutdown protocol");
if let Err(e) = b2b_close_send_protocol_s.send(cid).await {
debug!(?e, ?cid, "send_mgr was already closed simultaneously");
}
if remove_c(&mut recv_protocols, &cid) {
break;
}
},
Err(e) => {
info!(?e, ?cid, "protocol failed, shutting down channel");
if let Err(e) = b2b_close_send_protocol_s.send(cid).await {
debug!(?e, ?cid, "send_mgr was already closed simultaneously");
}
if remove_c(&mut recv_protocols, &cid) {
break;
}
},
}
}
if let Some(table) = defered_orphan.print() {
for (sid, cnt) in table.iter() {
warn!(?sid, ?cnt, "recv messages with orphan stream");
}
}
}
trace!("receiving no longer possible, closing all streams");
for (_, si) in self.streams.write().await.drain() {
si.send_closed.store(true, Ordering::SeqCst);
self.metrics.streams_closed(&self.remote_pid_string);
}
trace!("Stop recv_mgr");
self.shutdown_barrier
.fetch_sub(Self::BARR_RECV, Ordering::SeqCst);
}
/// Registers every channel that arrives on `s2b_create_channel_r`.
///
/// For each request the channel is stored in `self.channels`, its protocol is
/// split and the halves are handed to `send_mgr` and `recv_mgr`, a
/// `ChannelCreated` event is emitted, the requester is notified through the
/// oneshot sender, and the connection is reported to the metrics. Requests are
/// processed concurrently without a limit. Once the request channel is closed,
/// `BARR_CHANNEL` is subtracted from the shutdown barrier.
///
/// Panics if one of the managers or the requester is no longer listening.
async fn create_channel_mgr(
    &self,
    s2b_create_channel_r: mpsc::UnboundedReceiver<S2bCreateChannel>,
    b2b_add_send_protocol_s: mpsc::UnboundedSender<(Cid, SendProtocols)>,
    b2b_add_recv_protocol_s: mpsc::UnboundedSender<(Cid, RecvProtocols)>,
    b2a_event_s: mpsc::UnboundedSender<ParticipantEvent>,
) {
    UnboundedReceiverStream::new(s2b_create_channel_r)
        .for_each_concurrent(
            None,
            |(cid, _, protocol, remote_con_addr, b2s_create_channel_done_s)| {
                // Each request works on its own handles.
                let channels = Arc::clone(&self.channels);
                let b2b_add_send_protocol_s = b2b_add_send_protocol_s.clone();
                let b2b_add_recv_protocol_s = b2b_add_recv_protocol_s.clone();
                let b2a_event_s = b2a_event_s.clone();
                async move {
                    // Number of channels that existed before this one; the write
                    // lock is released at the end of this block.
                    let channel_no = {
                        let mut channel_map = channels.write().await;
                        let already_known = channel_map.len();
                        let info = ChannelInfo {
                            cid,
                            cid_string: cid.to_string(),
                            remote_con_addr: remote_con_addr.clone(),
                        };
                        channel_map.insert(cid, Mutex::new(info));
                        already_known
                    };

                    let (send, recv) = protocol.split();
                    b2b_add_send_protocol_s.send((cid, send)).unwrap();
                    b2b_add_recv_protocol_s.send((cid, recv)).unwrap();

                    let created = ParticipantEvent::ChannelCreated(remote_con_addr);
                    if let Err(e) = b2a_event_s.send(created) {
                        debug!(?e, "Participant was dropped during channel connect");
                    };
                    b2s_create_channel_done_s.send(()).unwrap();

                    // The metrics only distinguish channel numbers up to 5.
                    let channel_no = if channel_no > 5 {
                        debug!(?channel_no, "metrics will overwrite channel #5");
                        5
                    } else {
                        channel_no
                    };
                    self.metrics
                        .channels_connected(&self.remote_pid_string, channel_no, cid);
                }
            },
        )
        .await;
    trace!("Stop create_channel_mgr");
    self.shutdown_barrier
        .fetch_sub(Self::BARR_CHANNEL, Ordering::SeqCst);
}
/// Waits for the shutdown request and then winds the participant down.
///
/// Sequence: mark every stream as send-closed, ask `send_mgr` to close every
/// channel, then wait until all managers have left the shutdown barrier or the
/// requested timeout has elapsed. On timeout the recv protocols are force-closed
/// and the barrier is awaited again without a limit. Finally the requester (if
/// any) receives `Ok(())` and the participant's metrics are cleaned up.
///
/// If the request sender was dropped without sending, the same sequence runs with
/// a zero timeout. Panics if no channel was ever registered.
async fn participant_shutdown_mgr(
    &self,
    s2b_shutdown_bparticipant_r: oneshot::Receiver<S2bShutdownBparticipant>,
    b2b_close_send_protocol_s: async_channel::Sender<Cid>,
    b2b_force_close_recv_protocol_s: async_channel::Sender<Cid>,
) {
    // Polls the barrier with a growing pause until every manager has stopped.
    let wait_for_manager = || async {
        let mut pause_secs = 0.01f64;
        loop {
            let bytes = self.shutdown_barrier.load(Ordering::SeqCst);
            if bytes == 0 {
                break;
            }
            pause_secs *= 1.4;
            tokio::time::sleep(Duration::from_secs_f64(pause_secs)).await;
            if pause_secs > 0.2 {
                trace!(?bytes, "wait for mgr to close");
            }
        }
    };

    // A dropped request sender is treated like a request with zero timeout and
    // nobody to report to.
    let (timeout_time, result_s) = match s2b_shutdown_bparticipant_r.await {
        Ok((timeout_time, result_s)) => (timeout_time, Some(result_s)),
        Err(_) => (Duration::default(), None),
    };
    debug!("participant_shutdown_mgr triggered. Closing all streams for send");
    for si in self.streams.read().await.values() {
        si.send_closed.store(true, Ordering::SeqCst);
    }

    {
        let channel_map = self.channels.read().await;
        assert!(
            !channel_map.is_empty(),
            "no channel existed remote_pid={}",
            self.remote_pid
        );
        for cid in channel_map.keys() {
            if let Err(e) = b2b_close_send_protocol_s.send(*cid).await {
                debug!(
                    ?e,
                    ?cid,
                    "closing send_mgr may fail if we got a recv error simultaneously"
                );
            }
        }
    }

    trace!("wait for other managers");
    let timed_out = select! {
        _ = wait_for_manager() => false,
        _ = tokio::time::sleep(timeout_time) => true,
    };
    if timed_out {
        warn!("timeout triggered: for killing recv");
        let channel_map = self.channels.read().await;
        for cid in channel_map.keys() {
            if let Err(e) = b2b_force_close_recv_protocol_s.send(*cid).await {
                debug!(
                    ?e,
                    ?cid,
                    "closing recv_mgr may fail if we got a recv error simultaneously"
                );
            }
        }
    }

    trace!("wait again");
    wait_for_manager().await;

    if let Some(result_s) = result_s {
        let _ = result_s.send(Ok(()));
    }

    #[cfg(feature = "metrics")]
    self.metrics.participants_disconnected_total.inc();
    self.metrics.cleanup_participant(&self.remote_pid_string);
    trace!("Stop participant_shutdown_mgr");
}
/// Removes the stream `sid` from `self.streams`, if it is still registered.
///
/// A removed stream is marked send-closed, its message channel towards the API is
/// closed, and the close is reported to the metrics. If the stream is not known
/// (e.g. local and remote side closed it at the same time) nothing is changed and
/// nothing is counted, so one stream never shows up as closed twice.
async fn delete_stream(&self, sid: Sid) {
    // The write lock is only held for the removal itself.
    let removed = self.streams.write().await.remove(&sid);
    match removed {
        Some(si) => {
            si.send_closed.store(true, Ordering::SeqCst);
            si.b2a_msg_recv_s.lock().await.close();
            // Count the close only when a stream was actually removed.
            self.metrics.streams_closed(&self.remote_pid_string);
        },
        None => {
            trace!("Couldn't find the stream, might be simultaneous close from local/remote")
        },
    }
}
/// Registers a new stream under `sid` and returns the API-side `Stream` for it.
///
/// The stream's bookkeeping entry is stored in `self.streams` and the open is
/// reported to the metrics. The returned `Stream` gets clones of the senders from
/// `open_stream_channels`; if those are already gone (the `send_mgr` has stopped),
/// it gets senders whose receivers were dropped right away.
async fn create_stream(
    &self,
    sid: Sid,
    prio: Prio,
    promises: Promises,
    guaranteed_bandwidth: Bandwidth,
) -> Stream {
    let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::<Bytes>();
    let send_closed = Arc::new(AtomicBool::new(false));

    let info = StreamInfo {
        prio,
        promises,
        send_closed: Arc::clone(&send_closed),
        b2a_msg_recv_s: Mutex::new(b2a_msg_recv_s),
    };
    self.streams.write().await.insert(sid, info);
    self.metrics.streams_opened(&self.remote_pid_string);

    let (a2b_msg_s, a2b_close_stream_s) = {
        let api_channels = self.open_stream_channels.lock().await;
        if let Some(osi) = api_channels.as_ref() {
            (osi.a2b_msg_s.clone(), osi.a2b_close_stream_s.clone())
        } else {
            debug!(
                "It seems that a stream was requested to open, while the send_mgr is \
                 already closed"
            );
            // Senders without a receiver: anything sent on them fails.
            let (a2b_msg_s, _) = crossbeam_channel::unbounded();
            let (a2b_close_stream_s, _) = mpsc::unbounded_channel();
            (a2b_msg_s, a2b_close_stream_s)
        }
    };

    Stream::new(
        self.local_pid,
        self.remote_pid,
        sid,
        prio,
        promises,
        guaranteed_bandwidth,
        send_closed,
        a2b_msg_s,
        b2a_msg_recv_r,
        a2b_close_stream_s,
    )
}
}
#[cfg(test)]
mod tests {
use super::*;
use core::assert_matches::assert_matches;
use network_protocol::{ProtocolMetricCache, ProtocolMetrics};
use tokio::{
runtime::Runtime,
sync::{mpsc, oneshot},
task::JoinHandle,
};
fn mock_bparticipant() -> (
Arc<Runtime>,
mpsc::UnboundedSender<A2bStreamOpen>,
mpsc::UnboundedReceiver<Stream>,
mpsc::UnboundedReceiver<ParticipantEvent>,
mpsc::UnboundedSender<S2bCreateChannel>,
oneshot::Sender<S2bShutdownBparticipant>,
mpsc::UnboundedReceiver<B2sPrioStatistic>,
watch::Receiver<f32>,
JoinHandle<()>,
) {
let runtime = Arc::new(Runtime::new().unwrap());
let runtime_clone = Arc::clone(&runtime);
let (b2s_prio_statistic_s, b2s_prio_statistic_r) =
mpsc::unbounded_channel::<B2sPrioStatistic>();
let (
bparticipant,
a2b_open_stream_s,
b2a_stream_opened_r,
b2a_event_r,
s2b_create_channel_s,
s2b_shutdown_bparticipant_s,
b2a_bandwidth_stats_r,
) = runtime_clone.block_on(async move {
let local_pid = Pid::fake(0);
let remote_pid = Pid::fake(1);
let sid = Sid::new(1000);
let metrics = Arc::new(NetworkMetrics::new(&local_pid).unwrap());
BParticipant::new(local_pid, remote_pid, sid, Arc::clone(&metrics))
});
let handle = runtime_clone.spawn(bparticipant.run(b2s_prio_statistic_s));
(
runtime_clone,
a2b_open_stream_s,
b2a_stream_opened_r,
b2a_event_r,
s2b_create_channel_s,
s2b_shutdown_bparticipant_s,
b2s_prio_statistic_r,
b2a_bandwidth_stats_r,
handle,
)
}
/// Connects an in-memory mpsc channel with id `cid` to the participant.
///
/// One protocol end is sent to the participant through `create_channel` (with
/// `ConnectAddr::Mpsc(42)` as remote address); after the participant confirmed the
/// registration, the opposite end is returned so the test can act as the remote.
#[allow(clippy::needless_pass_by_ref_mut)]
async fn mock_mpsc(
    cid: Cid,
    _runtime: &Arc<Runtime>,
    create_channel: &mut mpsc::UnboundedSender<S2bCreateChannel>,
) -> Protocols {
    // Two bounded queues, one per direction.
    let (local_tx, remote_rx) = mpsc::channel(100);
    let (remote_tx, local_rx) = mpsc::channel(100);
    let protocol_metrics = Arc::new(ProtocolMetrics::new().unwrap());

    // Local end: handed to the participant.
    let local_cache = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&protocol_metrics));
    let local_end = Protocols::new_mpsc(local_tx, local_rx, local_cache);
    let (registered_s, registered_r) = oneshot::channel();
    create_channel
        .send((
            cid,
            Sid::new(0),
            local_end,
            ConnectAddr::Mpsc(42),
            registered_s,
        ))
        .unwrap();
    registered_r.await.unwrap();

    // Remote end: returned to the caller.
    let remote_cache = ProtocolMetricCache::new(&cid.to_string(), protocol_metrics);
    Protocols::new_mpsc(remote_tx, remote_rx, remote_cache)
}
// Shutdown with a remote that never reacts: the remote protocol end is kept alive
// (`_remote`) for the whole test, so the shutdown is expected to complete only
// after the 1 second timeout passed in the shutdown request.
#[test]
fn close_bparticipant_by_timeout_during_close() {
let (
runtime,
a2b_open_stream_s,
b2a_stream_opened_r,
mut b2a_event_r,
mut s2b_create_channel_s,
s2b_shutdown_bparticipant_s,
b2s_prio_statistic_r,
_b2a_bandwidth_stats_r,
handle,
) = mock_bparticipant();
let _remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s));
// Give the managers some time to pick up the new channel.
std::thread::sleep(Duration::from_millis(50));
let (s, r) = oneshot::channel();
let before = Instant::now();
runtime.block_on(async {
// Closing the create-channel sender lets create_channel_mgr finish.
drop(s2b_create_channel_s);
s2b_shutdown_bparticipant_s
.send((Duration::from_secs(1), s))
.unwrap();
r.await.unwrap().unwrap();
});
assert!(
before.elapsed() > Duration::from_millis(900),
"timeout wasn't triggered"
);
// Exactly one created and one deleted event are expected, in this order.
assert_matches!(
b2a_event_r.try_recv().unwrap(),
ParticipantEvent::ChannelCreated(_)
);
assert_matches!(
b2a_event_r.try_recv().unwrap(),
ParticipantEvent::ChannelDeleted(_)
);
assert_matches!(b2a_event_r.try_recv(), Err(_));
runtime.block_on(handle).unwrap();
// Keep these ends alive until the participant has fully stopped.
drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r));
drop(runtime);
}
// Shutdown with a cooperating remote: the remote protocol end is dropped right
// after the shutdown request was sent, so the shutdown is expected to complete
// clearly before the 2 second timeout.
#[test]
fn close_bparticipant_cleanly() {
let (
runtime,
a2b_open_stream_s,
b2a_stream_opened_r,
mut b2a_event_r,
mut s2b_create_channel_s,
s2b_shutdown_bparticipant_s,
b2s_prio_statistic_r,
_b2a_bandwidth_stats_r,
handle,
) = mock_bparticipant();
let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s));
// Give the managers some time to pick up the new channel.
std::thread::sleep(Duration::from_millis(50));
let (s, r) = oneshot::channel();
let before = Instant::now();
runtime.block_on(async {
// Closing the create-channel sender lets create_channel_mgr finish.
drop(s2b_create_channel_s);
s2b_shutdown_bparticipant_s
.send((Duration::from_secs(2), s))
.unwrap();
// Dropping the remote end is what allows the shutdown to finish early.
drop(remote); r.await.unwrap().unwrap();
});
assert!(
before.elapsed() < Duration::from_millis(1900),
"timeout was triggered"
);
// Exactly one created and one deleted event are expected, in this order.
assert_matches!(
b2a_event_r.try_recv().unwrap(),
ParticipantEvent::ChannelCreated(_)
);
assert_matches!(
b2a_event_r.try_recv().unwrap(),
ParticipantEvent::ChannelDeleted(_)
);
assert_matches!(b2a_event_r.try_recv(), Err(_));
runtime.block_on(handle).unwrap();
// Keep these ends alive until the participant has fully stopped.
drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r));
drop(runtime);
}
// Opening a stream locally must produce an `OpenStream` event on the remote end
// that carries the first stream id (1000) and the requested parameters.
#[test]
fn create_stream() {
let (
runtime,
a2b_open_stream_s,
b2a_stream_opened_r,
_b2a_event_r,
mut s2b_create_channel_s,
s2b_shutdown_bparticipant_s,
b2s_prio_statistic_r,
_b2a_bandwidth_stats_r,
handle,
) = mock_bparticipant();
let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s));
// Give the managers some time to pick up the new channel.
std::thread::sleep(Duration::from_millis(50));
let (rs, mut rr) = remote.split();
// The receiver is kept alive (though unused) because send_mgr unwraps the
// result of handing the created stream back.
let (stream_sender, _stream_receiver) = oneshot::channel();
a2b_open_stream_s
.send((7u8, Promises::ENCRYPTED, 1_000_000, stream_sender))
.unwrap();
let stream_event = runtime.block_on(rr.recv()).unwrap();
match stream_event {
ProtocolEvent::OpenStream {
sid,
prio,
promises,
guaranteed_bandwidth,
} => {
assert_eq!(sid, Sid::new(1000));
assert_eq!(prio, 7u8);
assert_eq!(promises, Promises::ENCRYPTED);
assert_eq!(guaranteed_bandwidth, 1_000_000);
},
_ => panic!("wrong event"),
};
// Regular shutdown: drop both remote halves after requesting it.
let (s, r) = oneshot::channel();
runtime.block_on(async {
drop(s2b_create_channel_s);
s2b_shutdown_bparticipant_s
.send((Duration::from_secs(1), s))
.unwrap();
drop((rs, rr));
r.await.unwrap().unwrap();
});
runtime.block_on(handle).unwrap();
// Keep these ends alive until the participant has fully stopped.
drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r));
drop(runtime);
}
// An `OpenStream` event sent by the remote must surface as a new `Stream` on the
// API side that carries the promises from the event.
#[test]
fn created_stream() {
let (
runtime,
a2b_open_stream_s,
mut b2a_stream_opened_r,
_b2a_event_r,
mut s2b_create_channel_s,
s2b_shutdown_bparticipant_s,
b2s_prio_statistic_r,
_b2a_bandwidth_stats_r,
handle,
) = mock_bparticipant();
let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s));
// Give the managers some time to pick up the new channel.
std::thread::sleep(Duration::from_millis(50));
let (mut rs, rr) = remote.split();
runtime
.block_on(rs.send(ProtocolEvent::OpenStream {
sid: Sid::new(1000),
prio: 9u8,
promises: Promises::ORDERED,
guaranteed_bandwidth: 1_000_000,
}))
.unwrap();
let stream = runtime.block_on(b2a_stream_opened_r.recv()).unwrap();
assert_eq!(stream.params().promises, Promises::ORDERED);
// Regular shutdown: drop both remote halves after requesting it.
let (s, r) = oneshot::channel();
runtime.block_on(async {
drop(s2b_create_channel_s);
s2b_shutdown_bparticipant_s
.send((Duration::from_secs(1), s))
.unwrap();
drop((rs, rr));
r.await.unwrap().unwrap();
});
runtime.block_on(handle).unwrap();
// Keep these ends alive until the participant has fully stopped.
drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r));
drop(runtime);
}
}