// veloren_network_protocol/
mpsc.rs1#[cfg(feature = "metrics")]
2use crate::metrics::RemoveReason;
3use crate::{
4 RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink,
5 error::ProtocolError,
6 event::ProtocolEvent,
7 frame::InitFrame,
8 handshake::{ReliableDrain, ReliableSink},
9 metrics::ProtocolMetricCache,
10 types::{Bandwidth, Promises},
11};
12use async_trait::async_trait;
13use std::time::{Duration, Instant};
14#[cfg(feature = "trace_pedantic")]
15use tracing::trace;
16
17#[derive(Debug)]
19pub enum MpscMsg {
20 Event(ProtocolEvent),
21 InitFrame(InitFrame),
22}
23
24#[derive(Debug)]
28pub struct MpscSendProtocol<D>
29where
30 D: UnreliableDrain<DataFormat = MpscMsg>,
31{
32 drain: D,
33 #[expect(dead_code)]
34 last: Instant,
35 metrics: ProtocolMetricCache,
36}
37
38#[derive(Debug)]
42pub struct MpscRecvProtocol<S>
43where
44 S: UnreliableSink<DataFormat = MpscMsg>,
45{
46 sink: S,
47 metrics: ProtocolMetricCache,
48}
49
50impl<D> MpscSendProtocol<D>
51where
52 D: UnreliableDrain<DataFormat = MpscMsg>,
53{
54 pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self {
55 Self {
56 drain,
57 last: Instant::now(),
58 metrics,
59 }
60 }
61
62 pub fn supported_promises() -> Promises {
65 Promises::ORDERED
66 | Promises::CONSISTENCY
67 | Promises::GUARANTEED_DELIVERY
68 | Promises::COMPRESSED
69 | Promises::ENCRYPTED }
71}
72
73impl<S> MpscRecvProtocol<S>
74where
75 S: UnreliableSink<DataFormat = MpscMsg>,
76{
77 pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { Self { sink, metrics } }
78}
79
80#[async_trait]
81impl<D> SendProtocol for MpscSendProtocol<D>
82where
83 D: UnreliableDrain<DataFormat = MpscMsg>,
84{
85 type CustomErr = D::CustomErr;
86
87 fn notify_from_recv(&mut self, _event: ProtocolEvent) {}
88
89 async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
90 #[cfg(feature = "trace_pedantic")]
91 trace!(?event, "send");
92 match &event {
93 ProtocolEvent::Message {
94 data: _data,
95 sid: _sid,
96 } => {
97 #[cfg(feature = "metrics")]
98 let (bytes, line) = {
99 let sid = *_sid;
100 let bytes = _data.len() as u64;
101 let line = self.metrics.init_sid(sid);
102 line.smsg_it.inc();
103 line.smsg_ib.inc_by(bytes);
104 (bytes, line)
105 };
106 let r = self.drain.send(MpscMsg::Event(event)).await;
107 #[cfg(feature = "metrics")]
108 {
109 line.smsg_ot[RemoveReason::Finished.i()].inc();
110 line.smsg_ob[RemoveReason::Finished.i()].inc_by(bytes);
111 }
112 r
113 },
114 _ => self.drain.send(MpscMsg::Event(event)).await,
115 }
116 }
117
118 async fn flush(
119 &mut self,
120 _: Bandwidth,
121 _: Duration,
122 ) -> Result<Bandwidth, ProtocolError<Self::CustomErr>> {
123 Ok(0)
124 }
125}
126
127#[async_trait]
128impl<S> RecvProtocol for MpscRecvProtocol<S>
129where
130 S: UnreliableSink<DataFormat = MpscMsg>,
131{
132 type CustomErr = S::CustomErr;
133
134 async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
135 let event = self.sink.recv().await?;
136 #[cfg(feature = "trace_pedantic")]
137 trace!(?event, "recv");
138 match event {
139 MpscMsg::Event(e) => {
140 #[cfg(feature = "metrics")]
141 {
142 if let ProtocolEvent::Message { data, sid } = &e {
143 let sid = *sid;
144 let bytes = data.len() as u64;
145 let line = self.metrics.init_sid(sid);
146 line.rmsg_it.inc();
147 line.rmsg_ib.inc_by(bytes);
148 line.rmsg_ot[RemoveReason::Finished.i()].inc();
149 line.rmsg_ob[RemoveReason::Finished.i()].inc_by(bytes);
150 }
151 }
152 Ok(e)
153 },
154 MpscMsg::InitFrame(_) => Err(ProtocolError::Violated),
155 }
156 }
157}
158
159#[async_trait]
160impl<D> ReliableDrain for MpscSendProtocol<D>
161where
162 D: UnreliableDrain<DataFormat = MpscMsg>,
163{
164 type CustomErr = D::CustomErr;
165
166 async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>> {
167 self.drain.send(MpscMsg::InitFrame(frame)).await
168 }
169}
170
171#[async_trait]
172impl<S> ReliableSink for MpscRecvProtocol<S>
173where
174 S: UnreliableSink<DataFormat = MpscMsg>,
175{
176 type CustomErr = S::CustomErr;
177
178 async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>> {
179 match self.sink.recv().await? {
180 MpscMsg::Event(_) => Err(ProtocolError::Violated),
181 MpscMsg::InitFrame(f) => Ok(f),
182 }
183 }
184}
185
#[cfg(test)]
pub mod test_utils {
    use super::*;
    use crate::metrics::{ProtocolMetricCache, ProtocolMetrics};
    use async_channel::*;
    use std::sync::Arc;

    /// Test drain backed by the sending end of an `async_channel`.
    pub struct ACDrain {
        sender: Sender<MpscMsg>,
    }

    /// Test sink backed by the receiving end of an `async_channel`.
    pub struct ACSink {
        receiver: Receiver<MpscMsg>,
    }

    /// Builds two cross-wired protocol endpoints over bounded channels of capacity `cap`.
    ///
    /// Anything endpoint `[0]` sends is received by endpoint `[1]`, and vice versa.
    /// Both endpoints share one metrics cache. If `metrics` is `None`, a fresh cache
    /// labelled `"mpsc"` is created; its setup panics if `ProtocolMetrics::new` fails.
    pub fn ac_bound(
        cap: usize,
        metrics: Option<ProtocolMetricCache>,
    ) -> [(MpscSendProtocol<ACDrain>, MpscRecvProtocol<ACSink>); 2] {
        // Channel carrying traffic from the first endpoint to the second.
        let (to_second, at_second) = bounded(cap);
        // Channel carrying traffic from the second endpoint to the first.
        let (to_first, at_first) = bounded(cap);

        let shared = match metrics {
            Some(cache) => cache,
            None => ProtocolMetricCache::new("mpsc", Arc::new(ProtocolMetrics::new().unwrap())),
        };

        let first = (
            MpscSendProtocol::new(ACDrain { sender: to_second }, shared.clone()),
            MpscRecvProtocol::new(ACSink { receiver: at_first }, shared.clone()),
        );
        let second = (
            MpscSendProtocol::new(ACDrain { sender: to_first }, shared.clone()),
            MpscRecvProtocol::new(ACSink { receiver: at_second }, shared),
        );
        [first, second]
    }

    #[async_trait]
    impl UnreliableDrain for ACDrain {
        type CustomErr = ();
        type DataFormat = MpscMsg;

        /// Pushes `data` into the channel; a closed channel maps to `ProtocolError::Custom(())`.
        async fn send(
            &mut self,
            data: Self::DataFormat,
        ) -> Result<(), ProtocolError<Self::CustomErr>> {
            match self.sender.send(data).await {
                Ok(()) => Ok(()),
                Err(_) => Err(ProtocolError::Custom(())),
            }
        }
    }

    #[async_trait]
    impl UnreliableSink for ACSink {
        type CustomErr = ();
        type DataFormat = MpscMsg;

        /// Pops the next message; a closed, empty channel maps to `ProtocolError::Custom(())`.
        async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
            match self.receiver.recv().await {
                Ok(msg) => Ok(msg),
                Err(_) => Err(ProtocolError::Custom(())),
            }
        }
    }
}
251
#[cfg(test)]
mod tests {
    use crate::{
        InitProtocol,
        mpsc::test_utils::*,
        types::{Pid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2},
    };

    /// Runs the handshake concurrently on both ends of an MPSC pair.
    ///
    /// Each side must end up with the peer's pid and the peer's third `initialize`
    /// argument. The side passing `true` gets `STREAM_ID_OFFSET1`; the other side gets
    /// `STREAM_ID_OFFSET2`.
    #[tokio::test]
    async fn handshake_all_good() {
        let [mut left, mut right] = ac_bound(10, None);

        let left_task = tokio::spawn(async move { left.initialize(true, Pid::fake(2), 1337).await });
        let right_task =
            tokio::spawn(async move { right.initialize(false, Pid::fake(3), 42).await });

        let (left_outcome, right_outcome) = tokio::join!(left_task, right_task);
        assert_eq!(left_outcome.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42)));
        assert_eq!(right_outcome.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337)));
    }
}