use crate::api::{StreamError, StreamParams};
use bytes::Bytes;
#[cfg(feature = "compression")]
use network_protocol::Promises;
use serde::{de::DeserializeOwned, Serialize};
use std::io;
#[cfg(all(feature = "compression", debug_assertions))]
use tracing::warn;
/// A message encoded for transport over a stream.
///
/// `data` holds the `bincode` encoding of the user's value. When the
/// `compression` feature is enabled, that encoding may additionally have been
/// compressed with `lz_fear`'s raw block encoder, which `compressed` records.
pub struct Message {
    /// Encoded (and possibly compressed) payload bytes.
    pub(crate) data: Bytes,
    /// Whether `data` was compressed by [`Message::serialize`].
    #[cfg(feature = "compression")]
    pub(crate) compressed: bool,
}

impl Message {
    /// Serializes `message` with `bincode` for sending on a stream described by
    /// `stream_params`.
    ///
    /// With the `compression` feature enabled, the payload is compressed when
    /// `stream_params.promises` contains `Promises::COMPRESSED`; without the
    /// feature `stream_params` is ignored.
    ///
    /// # Panics
    ///
    /// Panics if `bincode` fails to serialize `message`, or if the compressor
    /// reports an error while writing into its in-memory buffer.
    pub fn serialize<M: Serialize + ?Sized>(message: &M, stream_params: StreamParams) -> Self {
        let serialized_data =
            bincode::serialize(message).expect("message must be serializable by bincode");

        #[cfg(feature = "compression")]
        let compressed = stream_params.promises.contains(Promises::COMPRESSED);
        #[cfg(feature = "compression")]
        let data = if compressed {
            // Rough size guess only; the Vec grows if the estimate is too small.
            let mut compressed_data = Vec::with_capacity(serialized_data.len() / 4 + 10);
            let mut table = lz_fear::raw::U32Table::default();
            lz_fear::raw::compress2(&serialized_data, 0, &mut table, &mut compressed_data)
                .expect("compressing into an in-memory Vec should not fail");
            compressed_data
        } else {
            serialized_data
        };

        #[cfg(not(feature = "compression"))]
        let data = serialized_data;
        // Only needed to decide on compression; silence the unused-variable lint.
        #[cfg(not(feature = "compression"))]
        let _stream_params = stream_params;

        Self {
            data: Bytes::from(data),
            #[cfg(feature = "compression")]
            compressed,
        }
    }

    /// Decodes the payload back into an `M`, decompressing it first if it was
    /// compressed by [`Message::serialize`].
    ///
    /// # Errors
    ///
    /// Returns `StreamError::Compression` if decompression fails (only possible
    /// with the `compression` feature), and `StreamError::Deserialize` if
    /// `bincode` cannot decode the bytes as `M`.
    pub fn deserialize<M: DeserializeOwned>(self) -> Result<M, StreamError> {
        #[cfg(not(feature = "compression"))]
        let uncompressed_data = self.data;
        #[cfg(feature = "compression")]
        let uncompressed_data = if self.compressed {
            let mut buffer = Vec::with_capacity(self.data.len() * 2);
            // Empty prefix: no dictionary is shared between messages.
            // NOTE(review): `usize::MAX` places no bound on the decompressed
            // size; if `data` can come from an untrusted peer, consider a limit.
            lz_fear::raw::decompress_raw(&self.data, &[], &mut buffer, usize::MAX)
                .map_err(StreamError::Compression)?;
            Bytes::from(buffer)
        } else {
            self.data
        };

        bincode::deserialize(&uncompressed_data).map_err(StreamError::Deserialize)
    }

    /// Debug-build sanity check: logs a warning when this message's compression
    /// flag disagrees with the `COMPRESSED` promise in `params`.
    ///
    /// Without the `compression` feature this does nothing.
    #[cfg(debug_assertions)]
    pub(crate) fn verify(&self, params: StreamParams) {
        #[cfg(not(feature = "compression"))]
        let _params = params;
        #[cfg(feature = "compression")]
        if self.compressed != params.promises.contains(Promises::COMPRESSED) {
            warn!(
                ?params,
                "verify failed, msg is {} and it doesn't match with stream", self.compressed
            );
        }
    }
}
/// Best-effort equality for `io::Error`, which does not implement `PartialEq`.
///
/// - If both errors carry a raw OS error code, they are equal iff the codes match.
/// - If only one of them carries a raw OS error code, they are not equal.
/// - If neither does, they are equal iff their `ErrorKind`s match, except that
///   two `ErrorKind::Other` errors are never considered equal (the kind alone
///   says nothing about the underlying cause).
///
/// The relation is symmetric: swapping the arguments never changes the result.
pub(crate) fn partial_eq_io_error(first: &io::Error, second: &io::Error) -> bool {
    match (first.raw_os_error(), second.raw_os_error()) {
        (Some(f), Some(s)) => f == s,
        (None, None) => {
            let kind = first.kind();
            kind == second.kind() && kind != io::ErrorKind::Other
        },
        // Exactly one side is an OS error: treat as different. Comparing kinds
        // here would make the result depend on argument order.
        _ => false,
    }
}
pub(crate) fn partial_eq_bincode(first: &bincode::ErrorKind, second: &bincode::ErrorKind) -> bool {
use bincode::ErrorKind::*;
match *first {
Io(ref f) => matches!(*second, Io(ref s) if partial_eq_io_error(f, s)),
InvalidUtf8Encoding(f) => matches!(*second, InvalidUtf8Encoding(s) if f == s),
InvalidBoolEncoding(f) => matches!(*second, InvalidBoolEncoding(s) if f == s),
InvalidCharEncoding => matches!(*second, InvalidCharEncoding),
InvalidTagEncoding(f) => matches!(*second, InvalidTagEncoding(s) if f == s),
DeserializeAnyNotSupported => matches!(*second, DeserializeAnyNotSupported),
SizeLimit => matches!(*second, SizeLimit),
SequenceMustHaveLength => matches!(*second, SequenceMustHaveLength),
Custom(ref f) => matches!(*second, Custom(ref s) if f == s),
}
}
#[cfg(test)]
mod tests {
    use crate::{api::StreamParams, message::*};
    // The parent module only imports `Promises` when the `compression` feature
    // is enabled, so bring it in explicitly for the other build.
    #[cfg(not(feature = "compression"))]
    use network_protocol::Promises;

    /// Builds stream parameters for tests. `compressed` requests the
    /// `COMPRESSED` promise; it is ignored without the `compression` feature.
    fn stub_stream(compressed: bool) -> StreamParams {
        #[cfg(feature = "compression")]
        let promises = if compressed {
            Promises::COMPRESSED
        } else {
            Promises::empty()
        };
        #[cfg(not(feature = "compression"))]
        let promises = {
            let _ = compressed;
            Promises::empty()
        };
        StreamParams { promises }
    }

    #[test]
    fn serialize_test() {
        let msg = Message::serialize("abc", stub_stream(false));
        // Layout: 8-byte length prefix (3, then seven zero bytes), then "abc".
        assert_eq!(msg.data.len(), 11);
        assert_eq!(msg.data[0], 3);
        assert_eq!(msg.data[1..8], [0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(msg.data[8], b'a');
        assert_eq!(msg.data[9], b'b');
        assert_eq!(msg.data[10], b'c');
    }

    #[test]
    fn roundtrip_test() {
        for &compressed in &[false, true] {
            let msg = Message::serialize("abc", stub_stream(compressed));
            let decoded = msg.deserialize::<String>().ok();
            assert_eq!(decoded.as_deref(), Some("abc"));
        }
    }

    #[cfg(feature = "compression")]
    #[test]
    fn serialize_compress_small() {
        let msg = Message::serialize("abc", stub_stream(true));
        // One token byte, followed by the 11 bincode bytes stored as literals.
        assert_eq!(msg.data.len(), 12);
        assert_eq!(msg.data[0], 176);
        assert_eq!(msg.data[1], 3);
        assert_eq!(msg.data[2..9], [0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(msg.data[9], b'a');
        assert_eq!(msg.data[10], b'b');
        assert_eq!(msg.data[11], b'c');
    }

    #[cfg(feature = "compression")]
    #[test]
    fn serialize_compress_medium() {
        let msg = (
            "abccc",
            100u32,
            80u32,
            "DATA",
            4,
            0,
            0,
            0,
            "assets/data/plants/flowers/greenrose.ron",
        );
        let msg = Message::serialize(&msg, stub_stream(true));
        assert_eq!(msg.data.len(), 79);
        assert_eq!(msg.data[0], 34);
        assert_eq!(msg.data[1], 5);
        assert_eq!(msg.data[2], 0);
        assert_eq!(msg.data[3], 1);
        assert_eq!(msg.data[20], 20);
        assert_eq!(msg.data[40], 115);
        assert_eq!(msg.data[60], 111);
    }

    #[cfg(feature = "compression")]
    #[test]
    fn serialize_compress_large() {
        use rand::{Rng, SeedableRng};
        let mut seed = [0u8; 32];
        seed[8] = 13;
        seed[9] = 37;
        let mut rnd = rand::rngs::StdRng::from_seed(seed);
        let mut msg = vec![0u8; 10000];
        for (i, s) in msg.iter_mut().enumerate() {
            match i.rem_euclid(32) {
                2 => *s = 128,
                3 => *s = 128 + 16,
                4 => *s = 150,
                11 => *s = 64,
                12 => *s = rnd.gen::<u8>() / 32,
                _ => {},
            }
        }
        let msg = Message::serialize(&msg, stub_stream(true));
        assert_eq!(msg.data.len(), 1331);
    }
}