use crate::api::{Api, CurrentUser}; use crate::db_conn::Db; use crate::models::{Comment, Post, User}; use crate::rds_conn::RdsConn; use crate::rds_models::BlockedUsers; use diesel::result::{Error as DieselError, QueryResult}; use moka::future::Cache; use rand::Rng; use redis::RedisResult; use rocket::futures::future; use rocket::tokio::sync::RwLock; use std::collections::HashMap; use std::future::Future; use std::io; use std::sync::Arc; use std::sync::OnceLock; use std::time::Duration; const USER_COUNT_EXPIRE_TIME: u64 = 60; const INSTANCE_EXPIRE_TIME: u64 = 60 * 60; // Global cache getters using OnceLock fn post_cache() -> &'static Cache { static CACHE: OnceLock> = OnceLock::new(); CACHE.get_or_init(|| Cache::builder().max_capacity(10_000).build()) } fn post_comment_cache() -> &'static Cache> { static CACHE: OnceLock>> = OnceLock::new(); CACHE.get_or_init(|| { Cache::builder() .time_to_idle(Duration::from_secs(INSTANCE_EXPIRE_TIME)) .build() }) } // Each list in post_list_cache, keyed by room_id and mode, is a sorted list. The element is a pair of numbers. The first one is the weight used to sort, the second one is the post id. fn post_list_cache() -> &'static Cache>>> { static CACHE: OnceLock>>>> = OnceLock::new(); CACHE.get_or_init(|| Cache::builder().build()) } fn user_cache() -> &'static Cache { static CACHE: OnceLock> = OnceLock::new(); CACHE.get_or_init(|| { Cache::builder() .time_to_idle(Duration::from_secs(INSTANCE_EXPIRE_TIME)) .build() }) } fn block_dict_cache() -> &'static Cache>>> { static CACHE: OnceLock>>>> = OnceLock::new(); CACHE.get_or_init(|| { Cache::builder() .time_to_idle(Duration::from_secs(INSTANCE_EXPIRE_TIME)) .build() }) } fn user_count_cache() -> &'static Cache { static CACHE: OnceLock> = OnceLock::new(); CACHE.get_or_init(|| { Cache::builder() .time_to_live(Duration::from_secs(USER_COUNT_EXPIRE_TIME)) .build() }) } fn map_shared_diesel_error(err: Arc) -> DieselError { match err.as_ref() { DieselError::NotFound => DieselError::NotFound, DieselError::RollbackTransaction => DieselError::RollbackTransaction, DieselError::AlreadyInTransaction => DieselError::AlreadyInTransaction, DieselError::NotInTransaction => DieselError::NotInTransaction, DieselError::BrokenTransactionManager => DieselError::BrokenTransactionManager, _ => DieselError::QueryBuilderError(Box::new(io::Error::other(err.to_string()))), } } pub struct PostCache; impl PostCache { pub async fn sets(ps: &[&Post]) { if ps.is_empty() { return; } for p in ps { post_cache().insert(p.id, (*p).clone()).await; } } pub async fn get(pid: &i32) -> Option { post_cache().get(pid).await } pub async fn get_with(pid: i32, init: F) -> QueryResult where F: Future>, { post_cache() .try_get_with(pid, init) .await .map_err(map_shared_diesel_error) } pub async fn gets(pids: &[i32]) -> Vec> { future::join_all(pids.iter().map(Self::get)).await } pub async fn clear_all() { post_cache().invalidate_all(); } } pub struct PostCommentCache { key: String, } impl PostCommentCache { pub fn init(post_id: i32) -> Self { Self { key: format!("hole_v2:cache:post_comments:{}", post_id), } } pub async fn get_with(&self, init: F) -> QueryResult> where F: Future>>, { post_comment_cache() .try_get_with(self.key.clone(), init) .await .map_err(map_shared_diesel_error) } pub async fn clear(&mut self) { post_comment_cache().invalidate(&self.key).await; } } pub struct PostListCache { key: String, mode: u8, } impl PostListCache { pub const MAX_LENGTH: usize = 900; // pub const MIN_LENGTH: usize = 200; pub const CUT_LENGTH: usize = 100; pub fn init(room_id: Option, mode: u8) -> Self { Self { key: format!( "hole_v2:cache:post_list:{}:{}", room_id.map_or_else(String::new, |i| i.to_string()), &mode ), mode, } } fn p2pair(&self, p: &Post) -> (i64, i32) { ( match self.mode { 0 => (-p.id).into(), 1 => -p.last_comment_time.timestamp(), 2 => (-p.hot_score).into(), 3 => rand::thread_rng().gen_range(0..i64::MAX), 4 => (-p.n_attentions).into(), _ => panic!("wrong mode"), }, p.id, ) } pub async fn fill_with(&mut self, query_posts: F) -> QueryResult where F: Future>>, { let list_ref = post_list_cache() .try_get_with(self.key.clone(), async { let mut items: Vec<(i64, i32)> = query_posts.await?.iter().map(|p| self.p2pair(p)).collect(); items.sort_by(|a, b| a.0.cmp(&b.0)); Ok(Arc::new(RwLock::new(items))) }) .await .map_err(map_shared_diesel_error)?; let list = list_ref.read().await; // Double-Checked Locking if list.len() <= Self::MAX_LENGTH { return Ok(list.len()); } drop(list); let mut list = list_ref.write().await; if list.len() <= Self::MAX_LENGTH { Ok(list.len()) } else { list.truncate(Self::MAX_LENGTH - Self::CUT_LENGTH); Ok(list.len()) } } pub async fn put(&mut self, p: &Post) { // Don't put is there is no cache. Let fill_with handle it. if let Some(list_ref) = post_list_cache().get(&self.key).await { let mut list = list_ref.write().await; // Remove any existing entry for this post_id if let Some(pos) = list.iter().position(|(_, pid)| *pid == p.id) { list.remove(pos); } if p.is_deleted || (self.mode > 0 && p.is_reported) { return; } list.push(self.p2pair(p)); list.sort_by(|a, b| a.0.cmp(&b.0)); } } pub async fn get_pids(&mut self, start: usize, limit: usize) -> Vec { if let Some(list_ref) = post_list_cache().get(&self.key).await { let list = list_ref.read().await; list.iter() .skip(start) .take(limit) .map(|(_, pid)| *pid) .collect() } else { vec![] } } pub async fn clear(&mut self) { post_list_cache().invalidate(&self.key).await; } } pub struct UserCache { key: String, } impl UserCache { pub fn init(user_id: &str) -> Self { Self { key: format!("hole_v2:cache:user:{}", user_id), } } // No need to use get_with for User. Just check and set separately. pub async fn set(&self, u: &User) { user_cache().insert(self.key.clone(), u.clone()).await; } pub async fn get(&self) -> Option { user_cache().get(&self.key).await } pub async fn clear_all() { user_cache().invalidate_all(); } } pub struct BlockDictCache { key: String, } impl BlockDictCache { pub fn init(namehash: &str, post_id: i32) -> Self { Self { key: format!("hole_v2:cache:block_dict:{}:{}", namehash, post_id), } } pub async fn get_or_create( &mut self, user: &CurrentUser, hash_list: &[&String], rconn: &RdsConn, ) -> RedisResult> { let dict_ref = block_dict_cache() .get_with(self.key.clone(), async move { Arc::new(RwLock::new(HashMap::new())) }) .await; // Find missing hashes let mut missing_keys: Vec = Vec::new(); { let block_dict = dict_ref.read().await; for hash in hash_list { if !block_dict.contains_key(hash.as_str()) { missing_keys.push((*hash).clone()); } } } if !missing_keys.is_empty() { let mut missing: Vec<(String, bool)> = Vec::with_capacity(missing_keys.len()); for hash in missing_keys { let is_blocked = BlockedUsers::check_if_block(rconn, user, &hash).await?; missing.push((hash, is_blocked)); } let mut block_dict = dict_ref.write().await; for (hash, is_blocked) in missing { block_dict.entry(hash).or_insert(is_blocked); } } let out = dict_ref.read().await.clone(); Ok(out) } pub async fn clear(&mut self) { block_dict_cache().invalidate(&self.key).await; } } pub async fn cached_user_count(db: &Db) -> Api { let key = "hole_v2:cache:user_count"; Ok(user_count_cache() .try_get_with(key.to_string(), async { User::get_count(db).await }) .await .map_err(map_shared_diesel_error)?) }