From 84943a396504675bb050c5283b908b4c7438cc30 Mon Sep 17 00:00:00 2001 From: hole-thu Date: Tue, 2 Jan 2024 17:40:05 +0800 Subject: [PATCH] rocket 5.0 & diesel 2.1 --- Cargo.toml | 13 ++-- src/api/mod.rs | 4 +- src/cache.rs | 10 +-- src/libs/diesel_logger.rs | 148 -------------------------------------- src/libs/mod.rs | 1 - src/main.rs | 11 +-- src/models.rs | 43 ++++++----- src/rds_models.rs | 4 +- 8 files changed, 41 insertions(+), 193 deletions(-) delete mode 100644 src/libs/diesel_logger.rs delete mode 100644 src/libs/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 5e10165..792a1df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,14 +11,11 @@ default = ["mastlogin"] mastlogin = ["reqwest"] [dependencies] -rocket = { version = "=0.5.0-rc.2", features = ["json"] } -rocket_codegen = "=0.5.0-rc.2" -rocket_http = "=0.5.0-rc.2" -rocket_sync_db_pools = { version = "=0.1.0-rc.2", features = ["diesel_postgres_pool"] } -rocket_sync_db_pools_codegen = "=0.1.0-rc.2" -diesel = { version = "1.4.8", features = ["postgres", "chrono"] } -diesel_migrations = "1.4.0" -redis = { version="0.23.0", features = ["aio", "tokio-comp"] } +rocket = { version = "0.5.0", features = ["json"] } +rocket_sync_db_pools = { version = "0.1.0", features = ["diesel_postgres_pool"] } +diesel = { version = "2.1", features = ["postgres", "chrono"] } +diesel_migrations = "2.1" +redis = { version="0.24.0", features = ["aio", "tokio-comp"] } chrono = { version="0.4.19", features = ["serde"] } rand = "0.8.5" dotenv = "0.15.0" diff --git a/src/api/mod.rs b/src/api/mod.rs index ae3f6f3..8ed5df5 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -122,12 +122,12 @@ impl<'r> FromRequest<'r> for CurrentUser { } } { if BannedUsers::has(&rconn, &user.namehash).await.unwrap() { - Outcome::Failure((Status::Forbidden, ())) + Outcome::Error((Status::Forbidden, ())) } else { Outcome::Success(user) } } else { - Outcome::Failure((Status::Unauthorized, ())) + Outcome::Error((Status::Unauthorized, ())) } } } diff --git a/src/cache.rs b/src/cache.rs index 1def013..3952b02 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -12,9 +12,9 @@ use rocket::futures::future; use std::collections::HashMap; const KEY_USER_COUNT: &str = "hole_v2:cache:user_count"; -const USER_COUNT_EXPIRE_TIME: usize = 5 * 60; +const USER_COUNT_EXPIRE_TIME: u64 = 5 * 60; -const INSTANCE_EXPIRE_TIME: usize = 60 * 60; +const INSTANCE_EXPIRE_TIME: u64 = 60 * 60; const MIN_LENGTH: isize = 200; const MAX_LENGTH: isize = 900; @@ -133,7 +133,7 @@ impl PostCommentCache { // dbg!(&rds_result); if let Ok(s) = rds_result { self.rconn - .expire::<&String, bool>(&self.key, INSTANCE_EXPIRE_TIME) + .expire::<&String, bool>(&self.key, INSTANCE_EXPIRE_TIME as i64) .await .unwrap_or_else(|e| { warn!( @@ -302,7 +302,7 @@ impl UserCache { let rds_result = self.rconn.get::<&String, String>(&self.key).await; if let Ok(s) = rds_result { self.rconn - .expire::<&String, bool>(&self.key, INSTANCE_EXPIRE_TIME) + .expire::<&String, bool>(&self.key, INSTANCE_EXPIRE_TIME as i64) .await .unwrap_or_else(|e| { warn!( @@ -355,7 +355,7 @@ impl BlockDictCache { if !missing.is_empty() { self.rconn.hset_multiple(&self.key, &missing).await?; - self.rconn.expire(&self.key, INSTANCE_EXPIRE_TIME).await?; + self.rconn.expire(&self.key, INSTANCE_EXPIRE_TIME as i64).await?; block_dict.extend(missing.into_iter()); } diff --git a/src/libs/diesel_logger.rs b/src/libs/diesel_logger.rs deleted file mode 100644 index 9ce4818..0000000 --- a/src/libs/diesel_logger.rs +++ /dev/null @@ -1,148 +0,0 @@ -/* - * from https://github.com/shssoichiro/diesel-logger - * change Connection to &mut Connection - */ - -use std::ops::Deref; -use std::time::{Duration, Instant}; - -use diesel::backend::{Backend, UsesAnsiSavepointSyntax}; -use diesel::connection::{AnsiTransactionManager, SimpleConnection}; -use diesel::debug_query; -use diesel::deserialize::QueryableByName; -use diesel::prelude::*; -use diesel::query_builder::{AsQuery, QueryFragment, QueryId}; -use diesel::sql_types::HasSqlType; - -/// Wraps a diesel `Connection` to time and log each query using -/// the configured logger for the `log` crate. -/// -/// Currently, this produces a `debug` log on every query, -/// an `info` on queries that take longer than 1 second, -/// and a `warn`ing on queries that take longer than 5 seconds. -/// These thresholds will be configurable in a future version. -pub struct LoggingConnection<'r, C: Connection>(&'r mut C); - -impl<'r, C: Connection> LoggingConnection<'r, C> { - pub fn new(conn: &'r mut C) -> Self { - LoggingConnection(conn) - } -} - -impl<'r, C: Connection> Deref for LoggingConnection<'r, C> { - type Target = C; - fn deref(&self) -> &Self::Target { - self.0 - } -} - -impl<'r, C> SimpleConnection for LoggingConnection<'r, C> -where - C: Connection + Send + 'static, -{ - fn batch_execute(&self, query: &str) -> QueryResult<()> { - let start_time = Instant::now(); - let result = self.0.batch_execute(query); - let duration = start_time.elapsed(); - log_query(query, duration); - result - } -} - -impl Connection for LoggingConnection<'_, C> -where - C: Connection + Send + 'static, - C::Backend: UsesAnsiSavepointSyntax, - ::QueryBuilder: Default, -{ - type Backend = C::Backend; - type TransactionManager = C::TransactionManager; - - fn establish(_: &str) -> ConnectionResult { - Err(ConnectionError::__Nonexhaustive) - //Ok(LoggingConnection(C::establish(database_url)?)) - } - - fn execute(&self, query: &str) -> QueryResult { - let start_time = Instant::now(); - let result = self.0.execute(query); - let duration = start_time.elapsed(); - log_query(query, duration); - result - } - - fn query_by_index(&self, source: T) -> QueryResult> - where - T: AsQuery, - T::Query: QueryFragment + QueryId, - Self::Backend: HasSqlType, - U: Queryable, - { - let query = source.as_query(); - let debug_query = debug_query::(&query).to_string(); - let start_time = Instant::now(); - let result = self.0.query_by_index(query); - let duration = start_time.elapsed(); - log_query(&debug_query, duration); - result - } - - fn query_by_name(&self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - U: QueryableByName, - { - let debug_query = debug_query::(&source).to_string(); - let start_time = Instant::now(); - let result = self.0.query_by_name(source); - let duration = start_time.elapsed(); - log_query(&debug_query, duration); - result - } - - fn execute_returning_count(&self, source: &T) -> QueryResult - where - T: QueryFragment + QueryId, - { - let debug_query = debug_query::(&source).to_string(); - let start_time = Instant::now(); - let result = self.0.execute_returning_count(source); - let duration = start_time.elapsed(); - log_query(&debug_query, duration); - result - } - - fn transaction_manager(&self) -> &Self::TransactionManager { - self.0.transaction_manager() - } -} - -fn log_query(query: &str, duration: Duration) { - if duration.as_secs() >= 5 { - warn!( - "Slow query ran in {:.2} seconds: {}", - duration_to_secs(duration), - query - ); - } else if duration.as_secs() >= 1 { - info!( - "Slow query ran in {:.2} seconds: {}", - duration_to_secs(duration), - query - ); - } else { - debug!("Query ran in {:.1} ms: {}", duration_to_ms(duration), query); - } -} - -const NANOS_PER_MILLI: u32 = 1_000_000; -const MILLIS_PER_SEC: u32 = 1_000; - -fn duration_to_secs(duration: Duration) -> f32 { - duration_to_ms(duration) / MILLIS_PER_SEC as f32 -} - -fn duration_to_ms(duration: Duration) -> f32 { - (duration.as_secs() as u32 * 1000) as f32 - + (duration.subsec_nanos() as f32 / NANOS_PER_MILLI as f32) -} diff --git a/src/libs/mod.rs b/src/libs/mod.rs deleted file mode 100644 index 98e4729..0000000 --- a/src/libs/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod diesel_logger; diff --git a/src/main.rs b/src/main.rs index 87b8812..87dd72f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,7 +14,6 @@ mod api; mod cache; mod cors; mod db_conn; -mod libs; #[cfg(feature = "mastlogin")] mod login; mod models; @@ -23,16 +22,18 @@ mod rds_conn; mod rds_models; mod schema; +use std::env; + use db_conn::{establish_connection, Conn, Db}; use diesel::Connection; +use diesel_migrations::{EmbeddedMigrations, MigrationHarness}; use random_hasher::RandomHasher; use rds_conn::{init_rds_client, RdsConn}; use rds_models::clear_outdate_redis_data; use rocket::tokio; use rocket::tokio::time::{sleep, Duration}; -use std::env; -embed_migrations!("migrations/postgres"); +pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres"); #[rocket::main] async fn main() { @@ -143,6 +144,6 @@ fn load_env() { fn init_database() { let database_url = env::var("DATABASE_URL").unwrap(); - let conn = Conn::establish(&database_url).unwrap(); - embedded_migrations::run(&conn).unwrap(); + let mut conn = Conn::establish(&database_url).unwrap(); + conn.run_pending_migrations(MIGRATIONS).unwrap(); } diff --git a/src/models.rs b/src/models.rs index 6e1beb8..c62567d 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,4 +1,4 @@ -#![allow(clippy::all)] +// #![allow(clippy::all)] use crate::cache::*; use crate::db_conn::{Conn, Db}; @@ -6,7 +6,6 @@ use crate::random_hasher::random_string; use crate::rds_conn::RdsConn; use crate::schema::*; use chrono::{offset::Utc, DateTime}; -use diesel::dsl::any; use diesel::sql_types::*; use diesel::{ insert_into, BoolExpressionMethods, ExpressionMethods, QueryDsl, QueryResult, RunQueryDsl, @@ -17,7 +16,7 @@ use rocket::serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use std::collections::HashMap; -no_arg_sql_function!(RANDOM, (), "Represents the sql RANDOM() function"); +sql_function!(fn random()); sql_function!(fn floor(x: Float) -> Int4); sql_function!(fn float4(x: Int4) -> Float); @@ -40,7 +39,7 @@ macro_rules! _get_multi { // eq(any()) is only for postgres db.run(move |c| { $table::table - .filter($table::id.eq(any(ids))) + .filter($table::id.eq_any(ids)) .filter($table::is_deleted.eq(false)) .load(with_log!(c)) }) @@ -84,11 +83,11 @@ macro_rules! base_query { }; } +// TODO: log sql query macro_rules! with_log { - ($c: expr) => {{ - use crate::libs::diesel_logger::LoggingConnection; - &LoggingConnection::new($c) - }}; + ($c: expr) => { + $c + }; } #[derive(Queryable, Insertable, Serialize, Deserialize, Debug)] @@ -137,7 +136,7 @@ pub struct User { } #[derive(Insertable)] -#[table_name = "posts"] +#[diesel(table_name = posts)] pub struct NewPost { pub content: String, pub cw: String, @@ -154,8 +153,8 @@ impl Post { _get_multi!(posts); - pub async fn get_multi(db: &Db, rconn: &RdsConn, ids: &Vec) -> QueryResult> { - let mut cacher = PostCache::init(&rconn); + pub async fn get_multi(db: &Db, rconn: &RdsConn, ids: &[i32]) -> QueryResult> { + let mut cacher = PostCache::init(rconn); let mut cached_posts = cacher.gets(ids).await; let mut id2po = HashMap::>::new(); @@ -166,7 +165,7 @@ impl Post { .zip(cached_posts.iter_mut()) .filter_map(|(pid, p)| match p { None => { - id2po.insert(pid.clone(), p); + id2po.insert(*pid, p); Some(pid) } _ => None, @@ -194,12 +193,12 @@ impl Post { pub async fn get(db: &Db, rconn: &RdsConn, id: i32) -> QueryResult { // 注意即使is_deleted也应该缓存和返回 - let mut cacher = PostCache::init(&rconn); + let mut cacher = PostCache::init(rconn); if let Some(p) = cacher.get(&id).await { Ok(p) } else { let p = Self::_get(db, id).await?; - cacher.sets(&vec![&p]).await; + cacher.sets(&[&p]).await; Ok(p) } } @@ -227,10 +226,10 @@ impl Post { start: i64, limit: i64, ) -> QueryResult> { - let mut cacher = PostListCache::init(room_id, order_mode, &rconn); + let mut cacher = PostListCache::init(room_id, order_mode, rconn); if cacher.need_fill().await { let pids = - Self::_get_ids_by_page(db, room_id, order_mode.clone(), 0, cacher.i64_minlen()) + Self::_get_ids_by_page(db, room_id, order_mode, 0, cacher.i64_minlen()) .await?; let ps = Self::get_multi(db, rconn, &pids).await?; cacher.fill(&ps).await; @@ -268,7 +267,7 @@ impl Post { 0 => query.order(posts::id.desc()), 1 => query.order(posts::last_comment_time.desc()), 2 => query.order(posts::hot_score.desc()), - 3 => query.order(RANDOM), + 3 => query.order(random()), 4 => query.order(posts::n_attentions.desc()), _ => panic!("Wrong order mode!"), }; @@ -287,7 +286,7 @@ impl Post { start: i64, limit: i64, ) -> QueryResult> { - let search_text2 = search_text.replace("%", "\\%"); + let search_text2 = search_text.replace('%', "\\%"); let pids = db .run(move |c| { let pat; @@ -314,7 +313,7 @@ impl Post { ) } 1 => { - pat = format!("%{}%", search_text2.replace(" ", "%")); + pat = format!("%{}%", search_text2.replace(' ', "%")); query .filter( posts::content.like(&pat).or(comments::content @@ -351,7 +350,7 @@ impl Post { } pub async fn set_instance_cache(&self, rconn: &RdsConn) { - PostCache::init(rconn).sets(&vec![self]).await; + PostCache::init(rconn).sets(&[self]).await; } pub async fn refresh_cache(&self, rconn: &RdsConn, is_new: bool) { join!( @@ -409,7 +408,7 @@ impl User { _ => token, }; // dbg!(token); - let mut cacher = UserCache::init(token, &rconn); + let mut cacher = UserCache::init(token, rconn); if let Some(u) = cacher.get().await { Some(u) } else { @@ -465,7 +464,7 @@ impl User { } #[derive(Insertable)] -#[table_name = "comments"] +#[diesel(table_name = comments)] pub struct NewComment { pub content: String, pub author_hash: String, diff --git a/src/rds_models.rs b/src/rds_models.rs index c85f74a..223ce69 100644 --- a/src/rds_models.rs +++ b/src/rds_models.rs @@ -79,7 +79,7 @@ const KEY_SYSTEMLOG: &str = "hole_v2:systemlog_list"; const KEY_BANNED_USERS: &str = "hole_v2:banned_user_hash_list"; const KEY_BLOCKED_COUNTER: &str = "hole_v2:blocked_counter"; const KEY_CUSTOM_TITLE: &str = "hole_v2:title"; -const CUSTOM_TITLE_KEEP_TIME: usize = 7 * 24 * 60 * 60; +const CUSTOM_TITLE_KEEP_TIME: u64 = 7 * 24 * 60 * 60; macro_rules! KEY_TITLE_SECRET { ($title: expr) => { format!("hole_v2:title_secret:{}", $title) @@ -299,7 +299,7 @@ impl CustomTitle { let secret = if let Some(ss) = s { rconn .clone() - .expire(KEY_TITLE_SECRET!(title), CUSTOM_TITLE_KEEP_TIME) + .expire(KEY_TITLE_SECRET!(title), CUSTOM_TITLE_KEEP_TIME as i64) .await?; ss } else {