1use crate::api::{StreamError, StreamParams};
2use bincode::{
3 config::legacy,
4 error::DecodeError,
5 serde::{decode_from_slice, encode_to_vec},
6};
7use bytes::Bytes;
8#[cfg(feature = "compression")]
9use network_protocol::Promises;
10use serde::{Serialize, de::DeserializeOwned};
11use std::io;
12#[cfg(all(feature = "compression", debug_assertions))]
13use tracing::warn;
14
15pub struct Message {
23 pub(crate) data: Bytes,
24 #[cfg(feature = "compression")]
25 pub(crate) compressed: bool,
26}
27
28impl Message {
29 pub fn serialize<M: Serialize + ?Sized>(message: &M, stream_params: StreamParams) -> Self {
45 let serialized_data = encode_to_vec(message, legacy()).unwrap();
47
48 #[cfg(feature = "compression")]
49 let compressed = stream_params.promises.contains(Promises::COMPRESSED);
50 #[cfg(feature = "compression")]
51 let data = if compressed {
52 let mut compressed_data = Vec::with_capacity(serialized_data.len() / 4 + 10);
53 let mut table = lz_fear::raw::U32Table::default();
54 lz_fear::raw::compress2(&serialized_data, 0, &mut table, &mut compressed_data).unwrap();
55 compressed_data
56 } else {
57 serialized_data
58 };
59 #[cfg(not(feature = "compression"))]
60 let data = serialized_data;
61 #[cfg(not(feature = "compression"))]
62 let _stream_params = stream_params;
63
64 Self {
65 data: Bytes::from(data),
66 #[cfg(feature = "compression")]
67 compressed,
68 }
69 }
70
71 pub fn deserialize<M: DeserializeOwned>(self) -> Result<M, StreamError> {
107 #[cfg(not(feature = "compression"))]
108 let uncompressed_data = self.data;
109
110 #[cfg(feature = "compression")]
111 let uncompressed_data = if self.compressed {
112 {
113 let mut uncompressed_data = Vec::with_capacity(self.data.len() * 2);
114 if let Err(e) = lz_fear::raw::decompress_raw(
115 &self.data,
116 &[0; 0],
117 &mut uncompressed_data,
118 usize::MAX,
119 ) {
120 return Err(StreamError::Compression(e));
121 }
122 Bytes::from(uncompressed_data)
123 }
124 } else {
125 self.data
126 };
127
128 match decode_from_slice(&uncompressed_data, legacy()) {
129 Ok((m, _)) => Ok(m),
130 Err(e) => Err(StreamError::Deserialize(Box::new(e))),
131 }
132 }
133
134 #[cfg(debug_assertions)]
135 pub(crate) fn verify(&self, params: StreamParams) {
136 #[cfg(not(feature = "compression"))]
137 let _params = params;
138 #[cfg(feature = "compression")]
139 if self.compressed != params.promises.contains(Promises::COMPRESSED) {
140 warn!(
141 ?params,
142 "verify failed, msg is {} and it doesn't match with stream", self.compressed
143 );
144 }
145 }
146}
147
/// Best-effort equality for [`io::Error`]s, which cannot be compared directly.
///
/// - If both errors carry a raw OS error code, they are equal iff the codes
///   match.
/// - If exactly one of them carries a raw OS error code, they are unequal.
/// - Otherwise they are equal iff their [`io::ErrorKind`]s match, except that
///   two `ErrorKind::Other` errors are never considered equal (presumably
///   because `Other` says nothing about the underlying cause).
///
/// The result does not depend on the order of the arguments.
pub(crate) fn partial_eq_io_error(first: &io::Error, second: &io::Error) -> bool {
    match (first.raw_os_error(), second.raw_os_error()) {
        (Some(f), Some(s)) => f == s,
        (None, None) => {
            let kind = first.kind();
            kind == second.kind() && kind != io::ErrorKind::Other
        },
        // An OS error never equals a non-OS error, whichever side it is on;
        // this keeps the comparison symmetric.
        _ => false,
    }
}
161
162pub(crate) fn partial_eq_bincode(first: &DecodeError, second: &DecodeError) -> bool {
163 use bincode::{error::DecodeError::*, serde::DecodeError::*};
164 match *first {
165 UnexpectedEnd { additional: f } => {
166 matches!(*second, UnexpectedEnd { additional: s } if f == s)
167 },
168 LimitExceeded => matches!(*second, LimitExceeded),
169 InvalidIntegerType {
170 expected: ref fe,
171 found: ref ff,
172 } => {
173 matches!(*second, InvalidIntegerType { expected: ref se, found: ref sf } if fe == se && ff == sf)
174 },
175 NonZeroTypeIsZero {
176 non_zero_type: ref f,
177 } => matches!(*second, NonZeroTypeIsZero { non_zero_type: ref s } if f == s),
178 UnexpectedVariant {
179 type_name: ft,
180 allowed: fa,
181 found: ff,
182 } => {
183 matches!(*second, UnexpectedVariant { type_name: st, allowed: sa, found: sf } if ft == st && fa == sa && ff == sf)
184 },
185 Utf8 { inner: f } => matches!(*second, Utf8 { inner: s } if f == s),
186 InvalidCharEncoding(f) => matches!(*second, InvalidCharEncoding(s) if f == s),
187 InvalidBooleanValue(f) => matches!(*second, InvalidBooleanValue(s) if f == s),
188 ArrayLengthMismatch {
189 required: fr,
190 found: ff,
191 } => {
192 matches!(*second, ArrayLengthMismatch { required: sr, found: sf } if fr == sr && ff == sf)
193 },
194 OutsideUsizeRange(f) => matches!(*second, OutsideUsizeRange(s) if f == s),
195 EmptyEnum { type_name: f } => matches!(*second, EmptyEnum { type_name: s } if f == s),
196 InvalidDuration {
197 secs: fs,
198 nanos: fnn,
199 } => matches!(*second, InvalidDuration { secs: ss, nanos: sn } if fs == ss && fnn == sn),
200 InvalidSystemTime { duration: fd } => {
201 matches!(*second, InvalidSystemTime { duration: sd } if fd == sd)
202 },
203 CStringNulError { position: fp } => {
204 matches!(*second, CStringNulError { position: sp } if fp == sp)
205 },
206 Io {
207 inner: ref fi,
208 additional: fa,
209 } => {
210 matches!(*second, Io { inner: ref si, additional: sa } if partial_eq_io_error(fi, si) && fa == sa)
211 },
212 Other(f) => matches!(*second, Other(s) if f == s),
213 OtherString(ref f) => matches!(*second, OtherString(ref s) if f == s),
214 Serde(ref f) => match f {
215 AnyNotSupported => matches!(*second, Serde(ref s) if matches!(*s, AnyNotSupported)),
216 IdentifierNotSupported => {
217 matches!(*second, Serde(ref s) if matches!(*s, IdentifierNotSupported))
218 },
219 IgnoredAnyNotSupported => {
220 matches!(*second, Serde(ref s) if matches!(*s, IgnoredAnyNotSupported))
221 },
222 CannotBorrowOwnedData => {
223 matches!(*second, Serde(ref s) if matches!(*s, CannotBorrowOwnedData))
224 },
225 _ => false,
227 },
228 _ => false,
230 }
231}
232
#[cfg(test)]
mod tests {
    use crate::{api::StreamParams, message::*};
    // The parent module only imports `Promises` with the `compression`
    // feature, so import it explicitly here; an explicit import takes
    // precedence over the glob import above.
    use network_protocol::Promises;

    /// Builds stream params that do (or, without the `compression` feature,
    /// cannot) promise compression.
    fn stub_stream(compressed: bool) -> StreamParams {
        #[cfg(feature = "compression")]
        let promises = if compressed {
            Promises::COMPRESSED
        } else {
            Promises::empty()
        };

        // Without compression support the flag has no effect; consume it so
        // the parameter is not reported as unused.
        #[cfg(not(feature = "compression"))]
        let promises = {
            let _ = compressed;
            Promises::empty()
        };

        StreamParams { promises }
    }

    #[test]
    fn serialize_test() {
        let msg = Message::serialize("abc", stub_stream(false));
        // 8-byte little-endian length prefix followed by the string bytes.
        assert_eq!(msg.data.len(), 11);
        assert_eq!(msg.data[0], 3);
        assert_eq!(msg.data[1..8], [0u8; 7]);
        assert_eq!(msg.data[8], b'a');
        assert_eq!(msg.data[9], b'b');
        assert_eq!(msg.data[10], b'c');
    }

    #[cfg(feature = "compression")]
    #[test]
    fn serialize_compress_small() {
        let msg = Message::serialize("abc", stub_stream(true));
        // One token byte, then the 11 serialized bytes stored as literals.
        assert_eq!(msg.data.len(), 12);
        assert_eq!(msg.data[0], 176);
        assert_eq!(msg.data[1], 3);
        assert_eq!(msg.data[2..9], [0u8; 7]);
        assert_eq!(msg.data[9], b'a');
        assert_eq!(msg.data[10], b'b');
        assert_eq!(msg.data[11], b'c');
    }

    #[cfg(feature = "compression")]
    #[test]
    fn serialize_compress_medium() {
        let msg = (
            "abccc",
            100u32,
            80u32,
            "DATA",
            4,
            0,
            0,
            0,
            "assets/data/plants/flowers/greenrose.ron",
        );
        let msg = Message::serialize(&msg, stub_stream(true));
        assert_eq!(msg.data.len(), 79);
        assert_eq!(msg.data[0], 34);
        assert_eq!(msg.data[1], 5);
        assert_eq!(msg.data[2], 0);
        assert_eq!(msg.data[3], 1);
        assert_eq!(msg.data[20], 20);
        assert_eq!(msg.data[40], 115);
        assert_eq!(msg.data[60], 111);
    }

    #[cfg(feature = "compression")]
    #[test]
    fn serialize_compress_large() {
        use rand::{Rng, SeedableRng};
        let mut seed = [0u8; 32];
        seed[8] = 13;
        seed[9] = 37;
        let mut rnd = rand::rngs::StdRng::from_seed(seed);
        let mut msg = vec![0u8; 10000];
        for (i, s) in msg.iter_mut().enumerate() {
            match i.rem_euclid(32) {
                2 => *s = 128,
                3 => *s = 128 + 16,
                4 => *s = 150,
                11 => *s = 64,
                12 => *s = rnd.random::<u8>() / 32,
                _ => {},
            }
        }
        let msg = Message::serialize(&msg, stub_stream(true));
        assert_eq!(msg.data.len(), 1331);
    }

    #[test]
    fn serialize_deserialize_roundtrip() {
        let original = (String::from("abc"), 42u32, vec![1u8, 2, 3]);
        for compressed in [false, true] {
            let msg = Message::serialize(&original, stub_stream(compressed));
            let decoded = msg.deserialize::<(String, u32, Vec<u8>)>().ok();
            assert_eq!(decoded, Some(original.clone()));
        }
    }
}