basic rate limit
This commit is contained in:
@@ -3,9 +3,10 @@
|
||||
use crate::db_conn::Db;
|
||||
use crate::models::*;
|
||||
use crate::random_hasher::RandomHasher;
|
||||
use crate::rate_limit::MainLimiters;
|
||||
use crate::rds_conn::RdsConn;
|
||||
use crate::rds_models::*;
|
||||
use rocket::http::Status;
|
||||
use rocket::http::{Method, Status};
|
||||
use rocket::outcome::try_outcome;
|
||||
use rocket::request::{FromRequest, Outcome, Request};
|
||||
use rocket::response::{self, Responder};
|
||||
@@ -91,6 +92,7 @@ impl<'r> FromRequest<'r> for CurrentUser {
|
||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
let rh = request.rocket().state::<RandomHasher>().unwrap();
|
||||
let rconn = try_outcome!(request.guard::<RdsConn>().await);
|
||||
let limiters = request.rocket().state::<MainLimiters>().unwrap();
|
||||
|
||||
if let Some(user) = {
|
||||
if let Some(token) = request.headers().get_one("User-Token") {
|
||||
@@ -123,6 +125,11 @@ impl<'r> FromRequest<'r> for CurrentUser {
|
||||
} {
|
||||
if BannedUsers::has(&rconn, &user.namehash).await.unwrap() {
|
||||
Outcome::Error((Status::Forbidden, ()))
|
||||
} else if !limiters.check(
|
||||
request.method() == Method::Post,
|
||||
user.id.unwrap_or_default(),
|
||||
) {
|
||||
Outcome::Error((Status::TooManyRequests, ()))
|
||||
} else {
|
||||
Outcome::Success(user)
|
||||
}
|
||||
|
||||
@@ -355,7 +355,9 @@ impl BlockDictCache {
|
||||
|
||||
if !missing.is_empty() {
|
||||
self.rconn.hset_multiple(&self.key, &missing).await?;
|
||||
self.rconn.expire(&self.key, INSTANCE_EXPIRE_TIME as i64).await?;
|
||||
self.rconn
|
||||
.expire(&self.key, INSTANCE_EXPIRE_TIME as i64)
|
||||
.await?;
|
||||
block_dict.extend(missing.into_iter());
|
||||
}
|
||||
|
||||
|
||||
12
src/main.rs
12
src/main.rs
@@ -18,21 +18,24 @@ mod db_conn;
|
||||
mod login;
|
||||
mod models;
|
||||
mod random_hasher;
|
||||
mod rate_limit;
|
||||
mod rds_conn;
|
||||
mod rds_models;
|
||||
mod schema;
|
||||
|
||||
use std::env;
|
||||
|
||||
use db_conn::{establish_connection, Conn, Db};
|
||||
use diesel::Connection;
|
||||
use diesel_migrations::{EmbeddedMigrations, MigrationHarness};
|
||||
use random_hasher::RandomHasher;
|
||||
use rds_conn::{init_rds_client, RdsConn};
|
||||
use rds_models::clear_outdate_redis_data;
|
||||
use rocket::tokio;
|
||||
use rocket::tokio::time::{sleep, Duration};
|
||||
|
||||
use db_conn::{establish_connection, Conn, Db};
|
||||
use random_hasher::RandomHasher;
|
||||
use rate_limit::MainLimiters;
|
||||
use rds_conn::{init_rds_client, RdsConn};
|
||||
use rds_models::clear_outdate_redis_data;
|
||||
|
||||
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");
|
||||
|
||||
#[rocket::main]
|
||||
@@ -120,6 +123,7 @@ async fn main() {
|
||||
api::catch_404_error
|
||||
],
|
||||
)
|
||||
.manage(MainLimiters::init())
|
||||
.manage(RandomHasher::get_random_one())
|
||||
.manage(rmc)
|
||||
.attach(Db::fairing())
|
||||
|
||||
@@ -229,8 +229,7 @@ impl Post {
|
||||
let mut cacher = PostListCache::init(room_id, order_mode, rconn);
|
||||
if cacher.need_fill().await {
|
||||
let pids =
|
||||
Self::_get_ids_by_page(db, room_id, order_mode, 0, cacher.i64_minlen())
|
||||
.await?;
|
||||
Self::_get_ids_by_page(db, room_id, order_mode, 0, cacher.i64_minlen()).await?;
|
||||
let ps = Self::get_multi(db, rconn, &pids).await?;
|
||||
cacher.fill(&ps).await;
|
||||
}
|
||||
|
||||
67
src/rate_limit.rs
Normal file
67
src/rate_limit.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use std::iter;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::sync::Mutex;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use lru::LruCache;
|
||||
|
||||
pub struct Limiter {
|
||||
record: Mutex<LruCache<i32, Vec<u64>>>,
|
||||
amount: u64,
|
||||
interval: u64,
|
||||
}
|
||||
|
||||
impl Limiter {
|
||||
pub fn init(amount: u64, interval: u64) -> Self {
|
||||
Self {
|
||||
record: Mutex::new(LruCache::new(NonZeroUsize::new(2000).unwrap())),
|
||||
amount,
|
||||
interval,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(&self, uid: i32) -> bool {
|
||||
let t = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
let mut r = self.record.lock().unwrap();
|
||||
if let Some(ts) = r.pop(&uid) {
|
||||
let new_ts: Vec<u64> = ts
|
||||
.into_iter()
|
||||
.chain(iter::once(t))
|
||||
.filter(|&tt| tt + self.interval > t)
|
||||
.collect();
|
||||
let len = new_ts.len() as u64;
|
||||
r.put(uid, new_ts);
|
||||
len < self.amount
|
||||
} else {
|
||||
r.put(uid, vec![t]);
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MainLimiters {
|
||||
post_min: Limiter,
|
||||
post_hour: Limiter,
|
||||
get_hour: Limiter,
|
||||
}
|
||||
|
||||
impl MainLimiters {
|
||||
pub fn init() -> Self {
|
||||
Self {
|
||||
post_min: Limiter::init(6, 60),
|
||||
post_hour: Limiter::init(50, 3600),
|
||||
get_hour: Limiter::init(1000, 3600),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(&self, is_post: bool, uid: i32) -> bool {
|
||||
if is_post {
|
||||
self.post_hour.check(uid) && self.post_min.check(uid)
|
||||
} else {
|
||||
self.get_hour.check(uid)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -45,8 +45,4 @@ table! {
|
||||
|
||||
joinable!(comments -> posts (post_id));
|
||||
|
||||
allow_tables_to_appear_in_same_query!(
|
||||
comments,
|
||||
posts,
|
||||
users,
|
||||
);
|
||||
allow_tables_to_appear_in_same_query!(comments, posts, users,);
|
||||
|
||||
Reference in New Issue
Block a user