diff --git a/Cargo.toml b/Cargo.toml index 792a1df..f3725bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,5 +26,6 @@ web-push = "0.9.2" url = "2.2.2" futures = "0.3.24" futures-util = "0.3.24" +lru = "0.11" reqwest = { version = "0.11.10", features = ["json"], optional = true } diff --git a/src/api/mod.rs b/src/api/mod.rs index 8ed5df5..9834e49 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -3,9 +3,10 @@ use crate::db_conn::Db; use crate::models::*; use crate::random_hasher::RandomHasher; +use crate::rate_limit::MainLimiters; use crate::rds_conn::RdsConn; use crate::rds_models::*; -use rocket::http::Status; +use rocket::http::{Method, Status}; use rocket::outcome::try_outcome; use rocket::request::{FromRequest, Outcome, Request}; use rocket::response::{self, Responder}; @@ -91,6 +92,7 @@ impl<'r> FromRequest<'r> for CurrentUser { async fn from_request(request: &'r Request<'_>) -> Outcome { let rh = request.rocket().state::().unwrap(); let rconn = try_outcome!(request.guard::().await); + let limiters = request.rocket().state::().unwrap(); if let Some(user) = { if let Some(token) = request.headers().get_one("User-Token") { @@ -123,6 +125,11 @@ impl<'r> FromRequest<'r> for CurrentUser { } { if BannedUsers::has(&rconn, &user.namehash).await.unwrap() { Outcome::Error((Status::Forbidden, ())) + } else if !limiters.check( + request.method() == Method::Post, + user.id.unwrap_or_default(), + ) { + Outcome::Error((Status::TooManyRequests, ())) } else { Outcome::Success(user) } diff --git a/src/cache.rs b/src/cache.rs index 3952b02..fe37018 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -355,7 +355,9 @@ impl BlockDictCache { if !missing.is_empty() { self.rconn.hset_multiple(&self.key, &missing).await?; - self.rconn.expire(&self.key, INSTANCE_EXPIRE_TIME as i64).await?; + self.rconn + .expire(&self.key, INSTANCE_EXPIRE_TIME as i64) + .await?; block_dict.extend(missing.into_iter()); } diff --git a/src/main.rs b/src/main.rs index 87dd72f..3965997 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,20 +18,23 @@ mod db_conn; mod login; mod models; mod random_hasher; +mod rate_limit; mod rds_conn; mod rds_models; mod schema; use std::env; -use db_conn::{establish_connection, Conn, Db}; use diesel::Connection; use diesel_migrations::{EmbeddedMigrations, MigrationHarness}; +use rocket::tokio; +use rocket::tokio::time::{sleep, Duration}; + +use db_conn::{establish_connection, Conn, Db}; use random_hasher::RandomHasher; +use rate_limit::MainLimiters; use rds_conn::{init_rds_client, RdsConn}; use rds_models::clear_outdate_redis_data; -use rocket::tokio; -use rocket::tokio::time::{sleep, Duration}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres"); @@ -120,6 +123,7 @@ async fn main() { api::catch_404_error ], ) + .manage(MainLimiters::init()) .manage(RandomHasher::get_random_one()) .manage(rmc) .attach(Db::fairing()) diff --git a/src/models.rs b/src/models.rs index c62567d..67b7ac6 100644 --- a/src/models.rs +++ b/src/models.rs @@ -229,8 +229,7 @@ impl Post { let mut cacher = PostListCache::init(room_id, order_mode, rconn); if cacher.need_fill().await { let pids = - Self::_get_ids_by_page(db, room_id, order_mode, 0, cacher.i64_minlen()) - .await?; + Self::_get_ids_by_page(db, room_id, order_mode, 0, cacher.i64_minlen()).await?; let ps = Self::get_multi(db, rconn, &pids).await?; cacher.fill(&ps).await; } diff --git a/src/rate_limit.rs b/src/rate_limit.rs new file mode 100644 index 0000000..fe1220f --- /dev/null +++ b/src/rate_limit.rs @@ -0,0 +1,67 @@ +use std::iter; +use std::num::NonZeroUsize; +use std::sync::Mutex; +use std::time::SystemTime; + +use lru::LruCache; + +pub struct Limiter { + record: Mutex>>, + amount: u64, + interval: u64, +} + +impl Limiter { + pub fn init(amount: u64, interval: u64) -> Self { + Self { + record: Mutex::new(LruCache::new(NonZeroUsize::new(2000).unwrap())), + amount, + interval, + } + } + + pub fn check(&self, uid: i32) -> bool { + let t = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(); + let mut r = self.record.lock().unwrap(); + if let Some(ts) = r.pop(&uid) { + let new_ts: Vec = ts + .into_iter() + .chain(iter::once(t)) + .filter(|&tt| tt + self.interval > t) + .collect(); + let len = new_ts.len() as u64; + r.put(uid, new_ts); + len < self.amount + } else { + r.put(uid, vec![t]); + true + } + } +} + +pub struct MainLimiters { + post_min: Limiter, + post_hour: Limiter, + get_hour: Limiter, +} + +impl MainLimiters { + pub fn init() -> Self { + Self { + post_min: Limiter::init(6, 60), + post_hour: Limiter::init(50, 3600), + get_hour: Limiter::init(1000, 3600), + } + } + + pub fn check(&self, is_post: bool, uid: i32) -> bool { + if is_post { + self.post_hour.check(uid) && self.post_min.check(uid) + } else { + self.get_hour.check(uid) + } + } +} diff --git a/src/schema.rs b/src/schema.rs index ee347e4..c7a1a57 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -45,8 +45,4 @@ table! { joinable!(comments -> posts (post_id)); -allow_tables_to_appear_in_same_query!( - comments, - posts, - users, -); +allow_tables_to_appear_in_same_query!(comments, posts, users,);