veloren_query_server/
ratelimit.rs1use std::{
2 collections::HashMap,
3 net::IpAddr,
4 time::{Duration, Instant},
5};
6
/// Length of one rate-limiting bucket.
///
/// `RateLimiter::maintain` rotates the per-address buckets once more than this
/// much time has passed since the previous rotation.
const SHIFT_EVERY: Duration = Duration::from_millis(15_000);
8
/// An IP address reduced to the part used as the rate-limiting key.
///
/// IPv4 addresses are kept whole. For IPv6 only the leading 64 bits are kept
/// (see the `From<IpAddr>` conversion), so every address within one /64
/// prefix maps to the same key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReducedIpAddr {
    /// Full IPv4 address, as the big-endian `u32` of its octets.
    V4(u32),
    /// Leading 64 bits (the /64 prefix) of an IPv6 address, big-endian.
    V6(u64),
}
13}
14
/// Per-address request counts split into four time buckets.
///
/// Index 0 is the bucket new requests are added to; `IpState::shift` moves
/// every count one slot towards the end (dropping the last one) and clears
/// index 0.
#[derive(Debug)]
pub struct IpState([u16; 4]);
18
19pub struct RateLimiter {
20 states: HashMap<ReducedIpAddr, IpState>,
21 last_shift: Instant,
22 limit: u16,
24}
25
26impl RateLimiter {
27 pub fn new(limit: u16) -> Self {
28 Self {
29 states: Default::default(),
30 last_shift: Instant::now(),
31 limit,
32 }
33 }
34
35 pub fn maintain(&mut self, now: Instant) {
36 if now.duration_since(self.last_shift) > SHIFT_EVERY {
37 self.states.retain(|_, state| {
39 state.shift();
40 !state.is_empty()
41 });
42 self.last_shift = now;
43 }
44 }
45
46 pub fn can_request(&mut self, ip: ReducedIpAddr) -> bool {
47 if let Some(state) = self.states.get_mut(&ip) {
48 state.0[0] = state.0[0].saturating_add(1);
49
50 state.total() < self.limit
51 } else {
52 self.states.insert(ip, IpState::default());
53 true
54 }
55 }
56}
57
58impl IpState {
59 fn shift(&mut self) {
60 self.0.rotate_right(1);
61 self.0[0] = 0;
62 }
63
64 fn is_empty(&self) -> bool { self.0.iter().all(|&freq| freq == 0) }
65
66 fn total(&self) -> u16 { self.0.iter().fold(0, |total, &v| total.saturating_add(v)) }
67}
68
69impl Default for IpState {
70 fn default() -> Self { Self([1, 0, 0, 0]) }
71}
72
73impl From<IpAddr> for ReducedIpAddr {
74 fn from(value: IpAddr) -> Self {
75 match value {
76 IpAddr::V4(v4) => Self::V4(u32::from_be_bytes(v4.octets())),
77 IpAddr::V6(v6) => {
78 let bytes = v6.octets();
79 Self::V6(u64::from_be_bytes([
80 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
81 ]))
82 },
83 }
84 }
85}