veloren_network_protocol/
mpsc.rs

1#[cfg(feature = "metrics")]
2use crate::metrics::RemoveReason;
3use crate::{
4    RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink,
5    error::ProtocolError,
6    event::ProtocolEvent,
7    frame::InitFrame,
8    handshake::{ReliableDrain, ReliableSink},
9    metrics::ProtocolMetricCache,
10    types::{Bandwidth, Promises},
11};
12use async_trait::async_trait;
13use std::time::{Duration, Instant};
14#[cfg(feature = "trace_pedantic")]
15use tracing::trace;
16
17/// used for implementing your own MPSC `Sink` and `Drain`
18#[derive(Debug)]
19pub enum MpscMsg {
20    Event(ProtocolEvent),
21    InitFrame(InitFrame),
22}
23
24/// MPSC implementation of [`SendProtocol`]
25///
26/// [`SendProtocol`]: crate::SendProtocol
27#[derive(Debug)]
28pub struct MpscSendProtocol<D>
29where
30    D: UnreliableDrain<DataFormat = MpscMsg>,
31{
32    drain: D,
33    #[expect(dead_code)]
34    last: Instant,
35    metrics: ProtocolMetricCache,
36}
37
38/// MPSC implementation of [`RecvProtocol`]
39///
40/// [`RecvProtocol`]: crate::RecvProtocol
41#[derive(Debug)]
42pub struct MpscRecvProtocol<S>
43where
44    S: UnreliableSink<DataFormat = MpscMsg>,
45{
46    sink: S,
47    metrics: ProtocolMetricCache,
48}
49
50impl<D> MpscSendProtocol<D>
51where
52    D: UnreliableDrain<DataFormat = MpscMsg>,
53{
54    pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self {
55        Self {
56            drain,
57            last: Instant::now(),
58            metrics,
59        }
60    }
61
62    /// returns all promises that this Protocol can take care of
63    /// If you open a Stream anyway, unsupported promises are ignored.
64    pub fn supported_promises() -> Promises {
65        Promises::ORDERED
66            | Promises::CONSISTENCY
67            | Promises::GUARANTEED_DELIVERY
68            | Promises::COMPRESSED
69            | Promises::ENCRYPTED /*assume a direct mpsc connection is secure*/
70    }
71}
72
73impl<S> MpscRecvProtocol<S>
74where
75    S: UnreliableSink<DataFormat = MpscMsg>,
76{
77    pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { Self { sink, metrics } }
78}
79
80#[async_trait]
81impl<D> SendProtocol for MpscSendProtocol<D>
82where
83    D: UnreliableDrain<DataFormat = MpscMsg>,
84{
85    type CustomErr = D::CustomErr;
86
87    fn notify_from_recv(&mut self, _event: ProtocolEvent) {}
88
89    async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
90        #[cfg(feature = "trace_pedantic")]
91        trace!(?event, "send");
92        match &event {
93            ProtocolEvent::Message {
94                data: _data,
95                sid: _sid,
96            } => {
97                #[cfg(feature = "metrics")]
98                let (bytes, line) = {
99                    let sid = *_sid;
100                    let bytes = _data.len() as u64;
101                    let line = self.metrics.init_sid(sid);
102                    line.smsg_it.inc();
103                    line.smsg_ib.inc_by(bytes);
104                    (bytes, line)
105                };
106                let r = self.drain.send(MpscMsg::Event(event)).await;
107                #[cfg(feature = "metrics")]
108                {
109                    line.smsg_ot[RemoveReason::Finished.i()].inc();
110                    line.smsg_ob[RemoveReason::Finished.i()].inc_by(bytes);
111                }
112                r
113            },
114            _ => self.drain.send(MpscMsg::Event(event)).await,
115        }
116    }
117
118    async fn flush(
119        &mut self,
120        _: Bandwidth,
121        _: Duration,
122    ) -> Result<Bandwidth, ProtocolError<Self::CustomErr>> {
123        Ok(0)
124    }
125}
126
127#[async_trait]
128impl<S> RecvProtocol for MpscRecvProtocol<S>
129where
130    S: UnreliableSink<DataFormat = MpscMsg>,
131{
132    type CustomErr = S::CustomErr;
133
134    async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
135        let event = self.sink.recv().await?;
136        #[cfg(feature = "trace_pedantic")]
137        trace!(?event, "recv");
138        match event {
139            MpscMsg::Event(e) => {
140                #[cfg(feature = "metrics")]
141                {
142                    if let ProtocolEvent::Message { data, sid } = &e {
143                        let sid = *sid;
144                        let bytes = data.len() as u64;
145                        let line = self.metrics.init_sid(sid);
146                        line.rmsg_it.inc();
147                        line.rmsg_ib.inc_by(bytes);
148                        line.rmsg_ot[RemoveReason::Finished.i()].inc();
149                        line.rmsg_ob[RemoveReason::Finished.i()].inc_by(bytes);
150                    }
151                }
152                Ok(e)
153            },
154            MpscMsg::InitFrame(_) => Err(ProtocolError::Violated),
155        }
156    }
157}
158
159#[async_trait]
160impl<D> ReliableDrain for MpscSendProtocol<D>
161where
162    D: UnreliableDrain<DataFormat = MpscMsg>,
163{
164    type CustomErr = D::CustomErr;
165
166    async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>> {
167        self.drain.send(MpscMsg::InitFrame(frame)).await
168    }
169}
170
171#[async_trait]
172impl<S> ReliableSink for MpscRecvProtocol<S>
173where
174    S: UnreliableSink<DataFormat = MpscMsg>,
175{
176    type CustomErr = S::CustomErr;
177
178    async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>> {
179        match self.sink.recv().await? {
180            MpscMsg::Event(_) => Err(ProtocolError::Violated),
181            MpscMsg::InitFrame(f) => Ok(f),
182        }
183    }
184}
185
#[cfg(test)]
pub mod test_utils {
    use super::*;
    use crate::metrics::{ProtocolMetricCache, ProtocolMetrics};
    use async_channel::*;
    use std::sync::Arc;

    /// Test drain that pushes messages into an `async_channel` sender.
    pub struct ACDrain {
        sender: Sender<MpscMsg>,
    }

    /// Test sink that pulls messages from an `async_channel` receiver.
    pub struct ACSink {
        receiver: Receiver<MpscMsg>,
    }

    /// Builds two cross-connected send/recv protocol pairs for tests.
    ///
    /// Two bounded channels of capacity `cap` are created; whatever the
    /// sender of one pair emits arrives at the receiver of the other pair.
    /// All four halves share `metrics`; when it is `None`, a cache is created
    /// via `ProtocolMetricCache::new("mpsc", …)` on fresh `ProtocolMetrics`.
    pub fn ac_bound(
        cap: usize,
        metrics: Option<ProtocolMetricCache>,
    ) -> [(MpscSendProtocol<ACDrain>, MpscRecvProtocol<ACSink>); 2] {
        let (first_tx, second_rx) = bounded(cap);
        let (second_tx, first_rx) = bounded(cap);
        let cache = match metrics {
            Some(provided) => provided,
            None => {
                let registry = Arc::new(ProtocolMetrics::new().unwrap());
                ProtocolMetricCache::new("mpsc", registry)
            },
        };

        let first = (
            MpscSendProtocol::new(ACDrain { sender: first_tx }, cache.clone()),
            MpscRecvProtocol::new(ACSink { receiver: first_rx }, cache.clone()),
        );
        let second = (
            MpscSendProtocol::new(ACDrain { sender: second_tx }, cache.clone()),
            MpscRecvProtocol::new(ACSink { receiver: second_rx }, cache),
        );
        [first, second]
    }

    #[async_trait]
    impl UnreliableDrain for ACDrain {
        type CustomErr = ();
        type DataFormat = MpscMsg;

        /// Sends `data` over the channel; any channel error is reported as
        /// `ProtocolError::Custom(())`.
        async fn send(
            &mut self,
            data: Self::DataFormat,
        ) -> Result<(), ProtocolError<Self::CustomErr>> {
            match self.sender.send(data).await {
                Ok(()) => Ok(()),
                Err(_) => Err(ProtocolError::Custom(())),
            }
        }
    }

    #[async_trait]
    impl UnreliableSink for ACSink {
        type CustomErr = ();
        type DataFormat = MpscMsg;

        /// Receives the next message from the channel; any channel error is
        /// reported as `ProtocolError::Custom(())`.
        async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
            match self.receiver.recv().await {
                Ok(msg) => Ok(msg),
                Err(_) => Err(ProtocolError::Custom(())),
            }
        }
    }
}
251
#[cfg(test)]
mod tests {
    use crate::{
        InitProtocol,
        mpsc::test_utils::*,
        types::{Pid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2},
    };

    /// Runs `initialize` on both ends of a connected pair concurrently and
    /// checks that each side reports the other side's pid and secret together
    /// with its stream id offset.
    #[tokio::test]
    async fn handshake_all_good() {
        let [mut first, mut second] = ac_bound(10, None);

        let first_task =
            tokio::spawn(async move { first.initialize(true, Pid::fake(2), 1337).await });
        let second_task =
            tokio::spawn(async move { second.initialize(false, Pid::fake(3), 42).await });
        let (first_joined, second_joined) = tokio::join!(first_task, second_task);

        let first_expected = Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42));
        let second_expected = Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337));
        assert_eq!(first_joined.unwrap(), first_expected);
        assert_eq!(second_joined.unwrap(), second_expected);
    }
}
269}