1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use axum::{
    extract::{ConnectInfo, Query, State},
    middleware::Next,
    response::{IntoResponse, Response},
    routing::get,
    Json, Router,
};
use chrono::DateTime;
use hyper::{Request, StatusCode};
use serde::{Deserialize, Deserializer};
use server::chat::ChatCache;
use std::{
    collections::HashSet,
    net::{IpAddr, SocketAddr},
    str::FromStr,
    sync::Arc,
};
use tokio::sync::Mutex;

/// Shared secret that callers must present to reach the chat endpoints.
///
/// `secret_token == None` means the endpoints are disabled entirely.
/// This state is cloned for every request, so keep it small.
#[derive(Clone)]
struct ChatToken {
    /// Expected value of the secret header, or `None` to disable access.
    secret_token: Option<String>,
}

/// Set of client IP addresses that have already been logged.
///
/// The set lives behind an `Arc`, so every clone of this state shares the
/// same underlying set.
#[derive(Default, Clone)]
struct IpAddresses {
    /// Distinct client addresses seen so far, guarded by an async mutex.
    users: Arc<Mutex<HashSet<IpAddr>>>,
}

/// Middleware that only lets requests through when they carry the configured
/// shared secret in the `X-Secret-Token` header.
///
/// # Errors
/// - `405 Method Not Allowed` when no secret is configured (endpoint disabled).
/// - `401 Unauthorized` when the header is missing or does not match.
async fn validate_secret<B>(
    State(token): State<ChatToken>,
    req: Request<B>,
    next: Next<B>,
) -> Result<Response, StatusCode> {
    const X_SECRET_TOKEN: &str = "X-Secret-Token";

    // check if this endpoint is disabled
    let secret_token = token.secret_token.ok_or(StatusCode::METHOD_NOT_ALLOWED)?;

    let provided = req
        .headers()
        .get(X_SECRET_TOKEN)
        .ok_or(StatusCode::UNAUTHORIZED)?;

    // Compare in constant time: a short-circuiting `!=` would let an attacker
    // recover the secret byte by byte by measuring response latency.
    if !constant_time_eq(provided.as_bytes(), secret_token.as_bytes()) {
        return Err(StatusCode::UNAUTHORIZED);
    }

    Ok(next.run(req).await)
}

/// Returns `true` iff `a` and `b` are equal, examining every byte regardless
/// of where the first difference occurs.
///
/// Only the lengths are compared early. That reveals whether the lengths match,
/// but not the content.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // Accumulate all differing bits instead of returning at the first mismatch.
    a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}

/// Logs each new IP address that accesses this API authenticated.
///
/// The first request from an address adds it to the shared set. It also emits a log line
/// with the address and the number of distinct addresses seen so far. The set
/// is never pruned.
///
/// NOTE(review): the `ConnectInfo<SocketAddr>` extractor requires the server to
/// be started with connect info enabled (e.g. `into_make_service_with_connect_info`);
/// confirm at the call site.
async fn log_users<B>(
    State(ip_addresses): State<IpAddresses>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    req: Request<B>,
    next: Next<B>,
) -> Result<Response, StatusCode> {
    let ip_addr = addr.ip();

    // Keep the guard in its own scope so the lock is released *before* awaiting
    // the downstream handler. Holding it across `next.run(..).await` would
    // serialize every request through this single mutex.
    let new_user_count = {
        let mut seen = ip_addresses.users.lock().await;
        // `insert` returns `true` only for addresses not already present.
        if seen.insert(ip_addr) {
            Some(seen.len())
        } else {
            None
        }
    };

    if let Some(users_so_far) = new_user_count {
        tracing::info!(?ip_addr, ?users_so_far, "Is accessing the /chat endpoint");
    }

    Ok(next.run(req).await)
}

/// Builds the router that serves `GET /history` from `cache`.
///
/// Requests pass through [`validate_secret`] first, because axum runs layers
/// added later first. They then pass through [`log_users`] and finally reach
/// the handler. With `secret_token == None`, every request is rejected by
/// [`validate_secret`].
pub fn router(cache: ChatCache, secret_token: Option<String>) -> Router {
    let auth = ChatToken { secret_token };
    let seen_ips = IpAddresses::default();

    let routes: Router<ChatCache> = Router::new().route("/history", get(history));
    routes
        .layer(axum::middleware::from_fn_with_state(seen_ips, log_users))
        .layer(axum::middleware::from_fn_with_state(auth, validate_secret))
        .with_state(cache)
}

/// Query-string parameters accepted by `GET /history`.
#[derive(Debug, Deserialize)]
struct Params {
    /// RFC 3339 timestamp; only messages strictly newer than it are returned.
    ///
    /// Intended for incremental polling: pass the newest timestamp already
    /// received to fetch the remaining messages without duplicates. An absent
    /// or empty value returns every cached message.
    /// NOTE(review): messages sharing the exact cutoff timestamp are excluded;
    /// confirm timestamps are unique enough that none are skipped.
    #[serde(default, deserialize_with = "empty_string_as_none")]
    from_time_exclusive_rfc3339: Option<String>,
}

fn empty_string_as_none<'de, D, T>(de: D) -> Result<Option<T>, D::Error>
where
    D: Deserializer<'de>,
    T: FromStr,
    T::Err: core::fmt::Display,
{
    let opt = Option::<String>::deserialize(de)?;
    match opt.as_deref() {
        None | Some("") => Ok(None),
        Some(s) => FromStr::from_str(s)
            .map_err(serde::de::Error::custom)
            .map(Some),
    }
}

/// `GET /history`: returns the cached chat messages as a JSON array.
///
/// When `from_time_exclusive_rfc3339` is given, only messages whose `time` is
/// strictly later than it are returned. Otherwise, every cached message is returned.
///
/// # Errors
/// Responds with `400 Bad Request` if the timestamp is not valid RFC 3339.
async fn history(
    State(cache): State<ChatCache>,
    Query(params): Query<Params>,
) -> Result<impl IntoResponse, StatusCode> {
    // Validate the parameter before taking the lock, so a malformed request
    // never holds the message cache.
    let cutoff = params
        .from_time_exclusive_rfc3339
        .as_deref()
        .map(DateTime::parse_from_rfc3339)
        .transpose()
        .map_err(|_| StatusCode::BAD_REQUEST)?;

    let messages = cache.messages.lock().await;
    // With no cutoff, everything is kept. Cloning happens after filtering, so
    // only the messages that are returned get copied.
    let filtered: Vec<_> = messages
        .iter()
        .filter(|msg| cutoff.as_ref().map_or(true, |since| msg.time > *since))
        .cloned()
        .collect();
    Ok(Json(filtered))
}