veloren_network_protocol/
handshake.rs

1use crate::{
2    InitProtocol,
3    error::{InitProtocolError, ProtocolError},
4    frame::InitFrame,
5    types::{
6        Pid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, Sid, VELOREN_MAGIC_NUMBER,
7        VELOREN_NETWORK_VERSION,
8    },
9};
10use async_trait::async_trait;
11use tracing::{debug, error, info, trace};
12
13/// Implement this for auto Handshake with [`ReliableSink`].
14/// You must make sure that EVERY message send this way actually is received on
15/// the receiving site:
16///  - exactly once
17///  - in the correct order
18///  - correctly
19///
20/// [`ReliableSink`]: crate::ReliableSink
21/// [`RecvProtocol`]: crate::RecvProtocol
22#[async_trait]
23pub trait ReliableDrain {
24    type CustomErr: std::fmt::Debug + Send;
25    async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>>;
26}
27
28/// Implement this for auto Handshake with [`ReliableDrain`]. See
29/// [`ReliableDrain`].
30///
31/// [`ReliableDrain`]: crate::ReliableDrain
32#[async_trait]
33pub trait ReliableSink {
34    type CustomErr: std::fmt::Debug + Send;
35    async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>>;
36}
37
38#[async_trait]
39impl<D, S, E> InitProtocol for (D, S)
40where
41    D: ReliableDrain<CustomErr = E> + Send,
42    S: ReliableSink<CustomErr = E> + Send,
43    E: std::fmt::Debug + Send,
44{
45    type CustomErr = E;
46
47    async fn initialize(
48        &mut self,
49        initializer: bool,
50        local_pid: Pid,
51        local_secret: u128,
52    ) -> Result<(Pid, Sid, u128), InitProtocolError<E>> {
53        #[cfg(debug_assertions)]
54        const WRONG_NUMBER: &str = "Handshake does not contain the magic number required by \
55                                    veloren server.\nWe are not sure if you are a valid veloren \
56                                    client.\nClosing the connection";
57        #[cfg(debug_assertions)]
58        const WRONG_VERSION: &str = "Handshake does contain a correct magic number, but invalid \
59                                     version.\nWe don't know how to communicate with \
60                                     you.\nClosing the connection";
61        const ERR_S: &str = "Got A Raw Message, these are usually Debug Messages indicating that \
62                             something went wrong on network layer and connection will be closed";
63
64        let drain = &mut self.0;
65        let sink = &mut self.1;
66
67        if initializer {
68            drain
69                .send(InitFrame::Handshake {
70                    magic_number: VELOREN_MAGIC_NUMBER,
71                    version: VELOREN_NETWORK_VERSION,
72                })
73                .await?;
74        }
75
76        match sink.recv().await? {
77            InitFrame::Handshake {
78                magic_number,
79                version,
80            } => {
81                trace!(?magic_number, ?version, "Recv handshake");
82                if magic_number != VELOREN_MAGIC_NUMBER {
83                    error!(?magic_number, "Connection with invalid magic_number");
84                    #[cfg(debug_assertions)]
85                    drain
86                        .send(InitFrame::Raw(WRONG_NUMBER.as_bytes().to_vec()))
87                        .await?;
88                    Err(InitProtocolError::WrongMagicNumber(magic_number))
89                } else if version[0] != VELOREN_NETWORK_VERSION[0]
90                    || version[1] != VELOREN_NETWORK_VERSION[1]
91                {
92                    error!(?version, "Connection with wrong network version");
93                    #[cfg(debug_assertions)]
94                    drain
95                        .send(InitFrame::Raw(
96                            format!(
97                                "{} Our Version: {:?}\nYour Version: {:?}\nClosing the connection",
98                                WRONG_VERSION, VELOREN_NETWORK_VERSION, version,
99                            )
100                            .as_bytes()
101                            .to_vec(),
102                        ))
103                        .await?;
104                    Err(InitProtocolError::WrongVersion(version))
105                } else {
106                    trace!("Handshake Frame completed");
107                    if initializer {
108                        drain
109                            .send(InitFrame::Init {
110                                pid: local_pid,
111                                secret: local_secret,
112                            })
113                            .await?;
114                    } else {
115                        drain
116                            .send(InitFrame::Handshake {
117                                magic_number: VELOREN_MAGIC_NUMBER,
118                                version: VELOREN_NETWORK_VERSION,
119                            })
120                            .await?;
121                    }
122                    Ok(())
123                }
124            },
125            InitFrame::Raw(bytes) => {
126                match std::str::from_utf8(bytes.as_slice()) {
127                    Ok(string) => error!(?string, ERR_S),
128                    _ => error!(?bytes, ERR_S),
129                }
130                Err(InitProtocolError::NotHandshake)
131            },
132            _ => {
133                info!("Handshake failed");
134                Err(InitProtocolError::NotHandshake)
135            },
136        }?;
137
138        match sink.recv().await? {
139            InitFrame::Init { pid, secret } => {
140                debug!(?pid, "Participant send their ID");
141                let stream_id_offset = if initializer {
142                    STREAM_ID_OFFSET1
143                } else {
144                    drain
145                        .send(InitFrame::Init {
146                            pid: local_pid,
147                            secret: local_secret,
148                        })
149                        .await?;
150                    STREAM_ID_OFFSET2
151                };
152                info!(?pid, "This Handshake is now configured!");
153                Ok((pid, stream_id_offset, secret))
154            },
155            InitFrame::Raw(bytes) => {
156                match std::str::from_utf8(bytes.as_slice()) {
157                    Ok(string) => error!(?string, ERR_S),
158                    _ => error!(?bytes, ERR_S),
159                }
160                Err(InitProtocolError::NotId)
161            },
162            _ => {
163                info!("Handshake failed");
164                Err(InitProtocolError::NotId)
165            },
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{InitProtocolError, mpsc::test_utils::*};

    /// The remote half is dropped without ever taking part in the handshake,
    /// so the initializer is expected to end with a `Custom` error.
    #[tokio::test]
    async fn handshake_drop_start() {
        let [mut local, remote] = ac_bound(10, None);
        let local_task =
            tokio::spawn(async move { local.initialize(true, Pid::fake(2), 1337).await });
        // `drop` moves `remote` into the task, where it is released right away.
        let remote_task = tokio::spawn(async move { drop(remote) });

        let (local_result, _) = tokio::join!(local_task, remote_task);
        assert_eq!(local_result.unwrap(), Err(InitProtocolError::Custom(())));
    }

    /// The remote side answers with a bogus magic number; the initializer has
    /// to reject it with `WrongMagicNumber`.
    #[tokio::test]
    async fn handshake_wrong_magic_number() {
        let [mut local, mut remote] = ac_bound(10, None);
        let local_task =
            tokio::spawn(async move { local.initialize(true, Pid::fake(2), 1337).await });
        let remote_task = tokio::spawn(async move {
            // Handshake frame of the initializer.
            let _ = remote.1.recv().await?;
            let bogus = InitFrame::Handshake {
                magic_number: *b"woopsie",
                version: VELOREN_NETWORK_VERSION,
            };
            remote.0.send(bogus).await?;
            // One more frame is expected from the initializer after the
            // rejection.
            let _ = remote.1.recv().await?;
            Result::<(), InitProtocolError<()>>::Ok(())
        });

        let (local_result, remote_result) = tokio::join!(local_task, remote_task);
        assert_eq!(
            local_result.unwrap(),
            Err(InitProtocolError::WrongMagicNumber(*b"woopsie"))
        );
        assert_eq!(remote_result.unwrap(), Ok(()));
    }

    /// The remote side answers with an incompatible version; the initializer
    /// has to reject it with `WrongVersion` and the remote side then fails to
    /// receive any further frame.
    #[tokio::test]
    async fn handshake_wrong_version() {
        let [mut local, mut remote] = ac_bound(10, None);
        let local_task =
            tokio::spawn(async move { local.initialize(true, Pid::fake(2), 1337).await });
        let remote_task = tokio::spawn(async move {
            // Handshake frame of the initializer.
            let _ = remote.1.recv().await?;
            let outdated = InitFrame::Handshake {
                magic_number: VELOREN_MAGIC_NUMBER,
                version: [0, 1, 2],
            };
            remote.0.send(outdated).await?;
            let _ = remote.1.recv().await?;
            // this should be closed now
            let _ = remote.1.recv().await?;
            Result::<(), InitProtocolError<()>>::Ok(())
        });

        let (local_result, remote_result) = tokio::join!(local_task, remote_task);
        assert_eq!(
            local_result.unwrap(),
            Err(InitProtocolError::WrongVersion([0, 1, 2]))
        );
        assert_eq!(remote_result.unwrap(), Err(InitProtocolError::Custom(())));
    }

    /// The handshake frames are fine, but the remote side sends a `Raw` frame
    /// where the `Init` frame is expected; the initializer has to fail with
    /// `NotId`.
    #[tokio::test]
    async fn handshake_unexpected_raw() {
        let [mut local, mut remote] = ac_bound(10, None);
        let local_task =
            tokio::spawn(async move { local.initialize(true, Pid::fake(2), 1337).await });
        let remote_task = tokio::spawn(async move {
            // Handshake frame of the initializer.
            let _ = remote.1.recv().await?;
            let valid = InitFrame::Handshake {
                magic_number: VELOREN_MAGIC_NUMBER,
                version: VELOREN_NETWORK_VERSION,
            };
            remote.0.send(valid).await?;
            // `Init` frame of the initializer.
            let _ = remote.1.recv().await?;
            let unexpected = InitFrame::Raw(b"Hello World".to_vec());
            remote.0.send(unexpected).await?;
            Result::<(), InitProtocolError<()>>::Ok(())
        });

        let (local_result, remote_result) = tokio::join!(local_task, remote_task);
        assert_eq!(local_result.unwrap(), Err(InitProtocolError::NotId));
        assert_eq!(remote_result.unwrap(), Ok(()));
    }
}