// veloren_query_server/ratelimit.rs

1use std::{
2    collections::HashMap,
3    net::IpAddr,
4    time::{Duration, Instant},
5};
6
/// Duration of one [`IpState`] segment: once this much time has passed,
/// [`RateLimiter::maintain`] shifts the per-IP counters by one segment.
const SHIFT_EVERY: Duration = Duration::from_millis(15_000);
8
/// Key under which [`RateLimiter`] counts requests.
///
/// Addresses are reduced before being used as a key (see the `From<IpAddr>`
/// impl), so several distinct addresses can map to the same key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReducedIpAddr {
    /// A whole IPv4 address as an integer, e.g. `1.2.3.4` is `0x0102_0304`.
    V4(u32),
    /// The upper 64 bits (the /64 prefix) of an IPv6 address.
    V6(u64),
}
14
/// Request counters for a single [`ReducedIpAddr`], split into 4 segments of
/// [`SHIFT_EVERY`] each, i.e. `4 * SHIFT_EVERY` in total (one minute at the
/// time of writing).
///
/// Index 0 is the current segment; higher indices are progressively older.
#[derive(Clone, Debug)]
pub struct IpState([u16; 4]);
18
19pub struct RateLimiter {
20    states: HashMap<ReducedIpAddr, IpState>,
21    last_shift: Instant,
22    /// Maximum amount requests that can be done in `4 * SHIFT_EVERY`
23    limit: u16,
24}
25
26impl RateLimiter {
27    pub fn new(limit: u16) -> Self {
28        Self {
29            states: Default::default(),
30            last_shift: Instant::now(),
31            limit,
32        }
33    }
34
35    pub fn maintain(&mut self, now: Instant) {
36        if now.duration_since(self.last_shift) > SHIFT_EVERY {
37            // Remove empty states
38            self.states.retain(|_, state| {
39                state.shift();
40                !state.is_empty()
41            });
42            self.last_shift = now;
43        }
44    }
45
46    pub fn can_request(&mut self, ip: ReducedIpAddr) -> bool {
47        if let Some(state) = self.states.get_mut(&ip) {
48            state.0[0] = state.0[0].saturating_add(1);
49
50            state.total() < self.limit
51        } else {
52            self.states.insert(ip, IpState::default());
53            true
54        }
55    }
56}
57
58impl IpState {
59    fn shift(&mut self) {
60        self.0.rotate_right(1);
61        self.0[0] = 0;
62    }
63
64    fn is_empty(&self) -> bool { self.0.iter().all(|&freq| freq == 0) }
65
66    fn total(&self) -> u16 { self.0.iter().fold(0, |total, &v| total.saturating_add(v)) }
67}
68
69impl Default for IpState {
70    fn default() -> Self { Self([1, 0, 0, 0]) }
71}
72
73impl From<IpAddr> for ReducedIpAddr {
74    fn from(value: IpAddr) -> Self {
75        match value {
76            IpAddr::V4(v4) => Self::V4(u32::from_be_bytes(v4.octets())),
77            IpAddr::V6(v6) => {
78                let bytes = v6.octets();
79                Self::V6(u64::from_be_bytes([
80                    bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
81                ]))
82            },
83        }
84    }
85}