1use crate::api::{StreamError, StreamParams};
2use bytes::Bytes;
3#[cfg(feature = "compression")]
4use network_protocol::Promises;
5use serde::{Serialize, de::DeserializeOwned};
6use std::io;
7#[cfg(all(feature = "compression", debug_assertions))]
8use tracing::warn;
9
10pub struct Message {
18 pub(crate) data: Bytes,
19 #[cfg(feature = "compression")]
20 pub(crate) compressed: bool,
21}
22
23impl Message {
24 pub fn serialize<M: Serialize + ?Sized>(message: &M, stream_params: StreamParams) -> Self {
40 let serialized_data = bincode::serialize(message).unwrap();
42
43 #[cfg(feature = "compression")]
44 let compressed = stream_params.promises.contains(Promises::COMPRESSED);
45 #[cfg(feature = "compression")]
46 let data = if compressed {
47 let mut compressed_data = Vec::with_capacity(serialized_data.len() / 4 + 10);
48 let mut table = lz_fear::raw::U32Table::default();
49 lz_fear::raw::compress2(&serialized_data, 0, &mut table, &mut compressed_data).unwrap();
50 compressed_data
51 } else {
52 serialized_data
53 };
54 #[cfg(not(feature = "compression"))]
55 let data = serialized_data;
56 #[cfg(not(feature = "compression"))]
57 let _stream_params = stream_params;
58
59 Self {
60 data: Bytes::from(data),
61 #[cfg(feature = "compression")]
62 compressed,
63 }
64 }
65
66 pub fn deserialize<M: DeserializeOwned>(self) -> Result<M, StreamError> {
102 #[cfg(not(feature = "compression"))]
103 let uncompressed_data = self.data;
104
105 #[cfg(feature = "compression")]
106 let uncompressed_data = if self.compressed {
107 {
108 let mut uncompressed_data = Vec::with_capacity(self.data.len() * 2);
109 if let Err(e) = lz_fear::raw::decompress_raw(
110 &self.data,
111 &[0; 0],
112 &mut uncompressed_data,
113 usize::MAX,
114 ) {
115 return Err(StreamError::Compression(e));
116 }
117 Bytes::from(uncompressed_data)
118 }
119 } else {
120 self.data
121 };
122
123 match bincode::deserialize(&uncompressed_data) {
124 Ok(m) => Ok(m),
125 Err(e) => Err(StreamError::Deserialize(e)),
126 }
127 }
128
129 #[cfg(debug_assertions)]
130 pub(crate) fn verify(&self, params: StreamParams) {
131 #[cfg(not(feature = "compression"))]
132 let _params = params;
133 #[cfg(feature = "compression")]
134 if self.compressed != params.promises.contains(Promises::COMPRESSED) {
135 warn!(
136 ?params,
137 "verify failed, msg is {} and it doesn't match with stream", self.compressed
138 );
139 }
140 }
141}
142
/// Compares two [`io::Error`] values for equality.
///
/// The comparison is symmetric in its arguments:
///
/// * If both errors carry a raw OS error code, they are equal exactly when the codes match.
/// * If neither carries a raw OS error code, they are equal when their
///   [`io::ErrorKind`]s match and that kind is not [`io::ErrorKind::Other`]
///   (two `Other` errors are never considered equal).
/// * If only one of them carries a raw OS error code, they are not equal.
pub(crate) fn partial_eq_io_error(first: &io::Error, second: &io::Error) -> bool {
    match (first.raw_os_error(), second.raw_os_error()) {
        (Some(first_code), Some(second_code)) => first_code == second_code,
        (None, None) => {
            let kind = first.kind();
            kind == second.kind() && kind != io::ErrorKind::Other
        },
        // An OS-originated error and a synthetic one are treated as different,
        // regardless of argument order.
        (Some(_), None) | (None, Some(_)) => false,
    }
}
156
157pub(crate) fn partial_eq_bincode(first: &bincode::ErrorKind, second: &bincode::ErrorKind) -> bool {
158 use bincode::ErrorKind::*;
159 match *first {
160 Io(ref f) => matches!(*second, Io(ref s) if partial_eq_io_error(f, s)),
161 InvalidUtf8Encoding(f) => matches!(*second, InvalidUtf8Encoding(s) if f == s),
162 InvalidBoolEncoding(f) => matches!(*second, InvalidBoolEncoding(s) if f == s),
163 InvalidCharEncoding => matches!(*second, InvalidCharEncoding),
164 InvalidTagEncoding(f) => matches!(*second, InvalidTagEncoding(s) if f == s),
165 DeserializeAnyNotSupported => matches!(*second, DeserializeAnyNotSupported),
166 SizeLimit => matches!(*second, SizeLimit),
167 SequenceMustHaveLength => matches!(*second, SequenceMustHaveLength),
168 Custom(ref f) => matches!(*second, Custom(ref s) if f == s),
169 }
170}
171
#[cfg(test)]
mod tests {
    use crate::{api::StreamParams, message::*};
    // Imported explicitly: the parent module only imports `Promises` when the
    // `compression` feature is enabled, but `stub_stream` needs it in every configuration.
    use network_protocol::Promises;

    /// Builds `StreamParams` for the tests.
    ///
    /// The `COMPRESSED` promise is set only when `compressed` is `true` and the
    /// `compression` feature is enabled; otherwise the promises are empty.
    fn stub_stream(compressed: bool) -> StreamParams {
        #[cfg(feature = "compression")]
        let promises = if compressed {
            Promises::COMPRESSED
        } else {
            Promises::empty()
        };

        #[cfg(not(feature = "compression"))]
        let promises = {
            // The flag has no effect without the feature; consume it to avoid a warning.
            let _ = compressed;
            Promises::empty()
        };

        StreamParams { promises }
    }

    /// Uncompressed `"abc"`: an 8-byte length prefix followed by the three bytes.
    #[test]
    fn serialize_test() {
        let msg = Message::serialize("abc", stub_stream(false));
        assert_eq!(msg.data.len(), 11);
        assert_eq!(msg.data[0], 3);
        assert_eq!(msg.data[1..8], [0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(msg.data[8], b'a');
        assert_eq!(msg.data[9], b'b');
        assert_eq!(msg.data[10], b'c');
    }

    /// Compressed `"abc"`: one leading byte followed by the 11 uncompressed bytes.
    #[cfg(feature = "compression")]
    #[test]
    fn serialize_compress_small() {
        let msg = Message::serialize("abc", stub_stream(true));
        assert_eq!(msg.data.len(), 12);
        assert_eq!(msg.data[0], 176);
        assert_eq!(msg.data[1], 3);
        assert_eq!(msg.data[2..9], [0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(msg.data[9], b'a');
        assert_eq!(msg.data[10], b'b');
        assert_eq!(msg.data[11], b'c');
    }

    /// Spot-checks the compressed output for a mixed tuple of strings and integers.
    #[cfg(feature = "compression")]
    #[test]
    fn serialize_compress_medium() {
        let msg = (
            "abccc",
            100u32,
            80u32,
            "DATA",
            4,
            0,
            0,
            0,
            "assets/data/plants/flowers/greenrose.ron",
        );
        let msg = Message::serialize(&msg, stub_stream(true));
        assert_eq!(msg.data.len(), 79);
        assert_eq!(msg.data[0], 34);
        assert_eq!(msg.data[1], 5);
        assert_eq!(msg.data[2], 0);
        assert_eq!(msg.data[3], 1);
        assert_eq!(msg.data[20], 20);
        assert_eq!(msg.data[40], 115);
        assert_eq!(msg.data[60], 111);
    }

    /// Checks the compressed size of a 10000-byte, mostly repetitive input built from a
    /// fixed seed.
    #[cfg(feature = "compression")]
    #[test]
    fn serialize_compress_large() {
        use rand::{Rng, SeedableRng};
        let mut seed = [0u8; 32];
        seed[8] = 13;
        seed[9] = 37;
        let mut rnd = rand::rngs::StdRng::from_seed(seed);
        let mut msg = vec![0u8; 10000];
        for (i, s) in msg.iter_mut().enumerate() {
            match i.rem_euclid(32) {
                2 => *s = 128,
                3 => *s = 128 + 16,
                4 => *s = 150,
                11 => *s = 64,
                12 => *s = rnd.gen::<u8>() / 32,
                _ => {},
            }
        }
        let msg = Message::serialize(&msg, stub_stream(true));
        assert_eq!(msg.data.len(), 1331);
    }
}