diff --git a/src/api/mod.rs b/src/api/mod.rs index ffb10e6..8c61395 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -2,6 +2,7 @@ use crate::db_conn::Db; use crate::models::*; use crate::random_hasher::RandomHasher; use crate::rds_conn::RdsConn; +use crate::rds_models::BannedUsers; use rocket::http::Status; use rocket::outcome::try_outcome; use rocket::request::{FromRequest, Outcome, Request}; @@ -13,6 +14,12 @@ pub fn catch_401_error() -> &'static str { "未登录或token过期" } +#[catch(403)] +pub fn catch_403_error() -> &'static str { + "可能被封禁了,等下次重置吧" +} + + pub struct CurrentUser { id: Option, // tmp user has no id, only for block namehash: String, @@ -25,34 +32,40 @@ impl<'r> FromRequest<'r> for CurrentUser { type Error = (); async fn from_request(request: &'r Request<'_>) -> Outcome { let rh = request.rocket().state::().unwrap(); - let mut cu: Option = None; + let rconn = try_outcome!(request.guard::().await); + + let mut id = None; + let mut namehash = None; + let mut is_admin = false; if let Some(token) = request.headers().get_one("User-Token") { let sp = token.split('_').collect::>(); if sp.len() == 2 && sp[0] == rh.get_tmp_token() { - let namehash = rh.hash_with_salt(sp[1]); - cu = Some(CurrentUser { - id: None, - custom_title: format!("TODO: {}", &namehash), - namehash: namehash, - is_admin: false, - }); + namehash = Some(rh.hash_with_salt(sp[1])); + id = None; + is_admin = false; } else { let db = try_outcome!(request.guard::().await); - let rconn = try_outcome!(request.guard::().await); - if let Some(user) = User::get_by_token(&db, &rconn, token).await { - let namehash = rh.hash_with_salt(&user.name); - cu = Some(CurrentUser { - id: Some(user.id), - custom_title: format!("TODO: {}", &namehash), - namehash: namehash, - is_admin: user.is_admin, - }); + if let Some(u) = User::get_by_token(&db, &rconn, token).await { + id = Some(u.id); + namehash = Some(rh.hash_with_salt(&u.name)); + is_admin = u.is_admin; } } } - match cu { - Some(u) => Outcome::Success(u), + match namehash { + Some(nh) => { + if BannedUsers::has(&rconn, &nh).await.unwrap() { + Outcome::Failure((Status::Forbidden, ())) + } else { + Outcome::Success(CurrentUser { + id: id, + custom_title: format!("title todo: {}", &nh), + namehash: nh, + is_admin: is_admin, + }) + } + } None => Outcome::Failure((Status::Unauthorized, ())), } } diff --git a/src/api/operation.rs b/src/api/operation.rs index f58c922..d62a83c 100644 --- a/src/api/operation.rs +++ b/src/api/operation.rs @@ -49,9 +49,9 @@ pub async fn delete( _ => return Err(APIError::PcError(NotAllowed)), } - if user.is_admin && author_hash != user.namehash { + if user.is_admin && !user.namehash.eq(author_hash) { Systemlog { - user_hash: user.namehash, + user_hash: user.namehash.clone(), action_type: LogType::AdminDelete, target: format!("#{}, {}={}", p.id, di.id_type, di.id), detail: di.note.clone(), @@ -59,6 +59,19 @@ pub async fn delete( } .create(&rconn) .await?; + + if di.note.starts_with("!ban ") { + Systemlog { + user_hash: user.namehash.clone(), + action_type: LogType::Ban, + target: look!(author_hash), + detail: di.note.clone(), + time: Local::now(), + } + .create(&rconn) + .await?; + BannedUsers::add(&rconn, author_hash).await?; + } } Ok(json!({ diff --git a/src/main.rs b/src/main.rs index 4642424..6ec7d9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,7 +23,7 @@ mod schema; use db_conn::{establish_connection, Conn, Db}; use diesel::Connection; use random_hasher::RandomHasher; -use rds_conn::init_rds_client; +use rds_conn::{init_rds_client, RdsConn}; use std::env; use tokio::time::{interval, Duration}; @@ -38,7 +38,8 @@ async fn main() -> Result<(), rocket::Error> { } env_logger::init(); let rmc = init_rds_client().await; - let rconn = rds_conn::RdsConn(rmc.clone()); + let rconn = RdsConn(rmc.clone()); + clear_outdate_redis_data(&rconn.clone()).await; tokio::spawn(async move { let mut itv = interval(Duration::from_secs(4 * 60 * 60)); loop { @@ -64,7 +65,10 @@ async fn main() -> Result<(), rocket::Error> { api::operation::delete, ], ) - .register("/_api", catchers![api::catch_401_error]) + .register( + "/_api", + catchers![api::catch_401_error, api::catch_403_error,], + ) .manage(RandomHasher::get_random_one()) .manage(rmc) .attach(Db::fairing()) @@ -85,3 +89,7 @@ fn init_database() { let conn = Conn::establish(&database_url).unwrap(); embedded_migrations::run(&conn).unwrap(); } + +async fn clear_outdate_redis_data(rconn: &RdsConn) { + rds_models::BannedUsers::clear(&rconn).await.unwrap(); +} diff --git a/src/rds_models.rs b/src/rds_models.rs index 2b762e1..4a978ef 100644 --- a/src/rds_models.rs +++ b/src/rds_models.rs @@ -5,6 +5,7 @@ use rocket::serde::json::serde_json; use rocket::serde::{Deserialize, Serialize}; const KEY_SYSTEMLOG: &str = "hole_v2:systemlog_list"; +const KEY_BANNED_USERS: &str = "hole_v2:banned_user_hash_list"; const SYSTEMLOG_MAX_LEN: isize = 1000; pub struct Attention { @@ -35,6 +36,8 @@ impl Attention { pub async fn all(&mut self) -> RedisResult> { self.rconn.smembers(&self.key).await } + + // TODO: clear all } #[derive(Serialize, Deserialize, Debug)] @@ -42,10 +45,9 @@ impl Attention { pub enum LogType { AdminDelete, Report, - Ban + Ban, } - #[derive(Serialize, Deserialize, Debug)] #[serde(crate = "rocket::serde")] pub struct Systemlog { @@ -78,3 +80,22 @@ impl Systemlog { .collect()) } } + +pub struct BannedUsers; + +impl BannedUsers { + pub async fn add(rconn: &RdsConn, namehash: &str) -> RedisResult<()> { + rconn + .clone() + .sadd::<&str, &str, ()>(KEY_BANNED_USERS, namehash) + .await + } + + pub async fn has(rconn: &RdsConn, namehash: &str) -> RedisResult { + rconn.clone().sismember(KEY_BANNED_USERS, namehash).await + } + + pub async fn clear(rconn: &RdsConn) -> RedisResult<()> { + rconn.clone().del(KEY_BANNED_USERS).await + } +}