Files
hole-backend-rust/src/cache.rs
2026-03-24 03:19:49 +08:00

328 lines
9.6 KiB
Rust

use crate::api::{Api, CurrentUser};
use crate::db_conn::Db;
use crate::models::{Comment, Post, User};
use crate::rds_conn::RdsConn;
use crate::rds_models::BlockedUsers;
use diesel::result::{Error as DieselError, QueryResult};
use moka::future::Cache;
use rand::Rng;
use redis::RedisResult;
use rocket::futures::future;
use rocket::tokio::sync::RwLock;
use std::collections::HashMap;
use std::collections::HashSet;
use std::future::Future;
use std::io;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
/// Time-to-live, in seconds, of the cached user count (`user_count_cache`).
const USER_COUNT_EXPIRE_TIME: u64 = 60;
/// Time-to-idle, in seconds, used by the comment, user and block-dict caches.
const INSTANCE_EXPIRE_TIME: u64 = 60 * 60;
// Global cache getters. Each cache is created lazily on first access through a
// function-local `OnceLock` static and is then shared by the whole process.
/// Global cache of posts keyed by post id, bounded to 10,000 entries.
fn post_cache() -> &'static Cache<i32, Post> {
    static POSTS: OnceLock<Cache<i32, Post>> = OnceLock::new();
    POSTS.get_or_init(|| {
        Cache::builder()
            .max_capacity(10_000)
            .build()
    })
}
/// Global cache of each post's comment list, keyed by the string built in
/// `PostCommentCache::init`. Entries expire after `INSTANCE_EXPIRE_TIME`
/// seconds without access.
fn post_comment_cache() -> &'static Cache<String, Vec<Comment>> {
    static COMMENTS: OnceLock<Cache<String, Vec<Comment>>> = OnceLock::new();
    COMMENTS.get_or_init(|| {
        let idle = Duration::from_secs(INSTANCE_EXPIRE_TIME);
        Cache::builder().time_to_idle(idle).build()
    })
}
// Each entry of post_list_cache is keyed by room id and mode and holds a list
// sorted by ascending weight. Every element is a `(weight, post_id)` pair.
// No capacity bound or expiry is configured, so entries stay until invalidated.
fn post_list_cache() -> &'static Cache<String, Arc<RwLock<Vec<(i64, i32)>>>> {
    static LISTS: OnceLock<Cache<String, Arc<RwLock<Vec<(i64, i32)>>>>> = OnceLock::new();
    LISTS.get_or_init(|| {
        let builder = Cache::builder();
        builder.build()
    })
}
/// Global cache of users, keyed by the string built in `UserCache::init`.
/// Entries expire after `INSTANCE_EXPIRE_TIME` seconds without access.
fn user_cache() -> &'static Cache<String, User> {
    static USERS: OnceLock<Cache<String, User>> = OnceLock::new();
    USERS.get_or_init(|| {
        let idle = Duration::from_secs(INSTANCE_EXPIRE_TIME);
        Cache::builder().time_to_idle(idle).build()
    })
}
/// Global cache of block dictionaries (hash -> blocked flag), keyed by the
/// string built in `BlockDictCache::init`. Entries expire after
/// `INSTANCE_EXPIRE_TIME` seconds without access.
fn block_dict_cache() -> &'static Cache<String, Arc<RwLock<HashMap<String, bool>>>> {
    static BLOCK_DICTS: OnceLock<Cache<String, Arc<RwLock<HashMap<String, bool>>>>> =
        OnceLock::new();
    BLOCK_DICTS.get_or_init(|| {
        let idle = Duration::from_secs(INSTANCE_EXPIRE_TIME);
        Cache::builder().time_to_idle(idle).build()
    })
}
/// Global cache holding the total user count. Entries live for
/// `USER_COUNT_EXPIRE_TIME` seconds after insertion, whether read or not.
fn user_count_cache() -> &'static Cache<String, i64> {
    static USER_COUNT: OnceLock<Cache<String, i64>> = OnceLock::new();
    USER_COUNT.get_or_init(|| {
        let ttl = Duration::from_secs(USER_COUNT_EXPIRE_TIME);
        Cache::builder().time_to_live(ttl).build()
    })
}
fn map_shared_diesel_error(err: Arc<DieselError>) -> DieselError {
match err.as_ref() {
DieselError::NotFound => DieselError::NotFound,
DieselError::RollbackTransaction => DieselError::RollbackTransaction,
DieselError::AlreadyInTransaction => DieselError::AlreadyInTransaction,
DieselError::NotInTransaction => DieselError::NotInTransaction,
DieselError::BrokenTransactionManager => DieselError::BrokenTransactionManager,
_ => DieselError::QueryBuilderError(Box::new(io::Error::other(err.to_string()))),
}
}
/// Static accessors for the global post cache (keyed by post id).
pub struct PostCache;

impl PostCache {
    /// Stores a clone of every post in `ps`, replacing any cached copy.
    pub async fn sets(ps: &[&Post]) {
        let cache = post_cache();
        for post in ps.iter().copied() {
            cache.insert(post.id, post.clone()).await;
        }
    }

    /// Returns the cached post with id `*pid`, if present.
    pub async fn get(pid: &i32) -> Option<Post> {
        let cache = post_cache();
        cache.get(pid).await
    }

    /// Returns the cached post `pid`, or awaits `init` to load and cache it on
    /// a miss.
    ///
    /// # Errors
    /// Propagates the error from `init`, converted by `map_shared_diesel_error`.
    pub async fn get_with<F>(pid: i32, init: F) -> QueryResult<Post>
    where
        F: Future<Output = QueryResult<Post>>,
    {
        let shared = post_cache().try_get_with(pid, init).await;
        shared.map_err(map_shared_diesel_error)
    }

    /// Looks up every id in `pids` concurrently, preserving input order.
    pub async fn gets(pids: &[i32]) -> Vec<Option<Post>> {
        let lookups = pids.iter().map(|pid| Self::get(pid));
        future::join_all(lookups).await
    }

    /// Invalidates every cached post.
    pub async fn clear_all() {
        let cache = post_cache();
        cache.invalidate_all();
    }
}
/// Handle to the cached comment list of a single post.
pub struct PostCommentCache {
    key: String,
}

impl PostCommentCache {
    /// Creates the handle for the post with id `post_id`.
    pub fn init(post_id: i32) -> Self {
        let key = format!("hole_v2:cache:post_comments:{}", post_id);
        Self { key }
    }

    /// Returns the cached comments, or awaits `init` to load and cache them
    /// on a miss.
    ///
    /// # Errors
    /// Propagates the error from `init`, converted by `map_shared_diesel_error`.
    pub async fn get_with<F>(&self, init: F) -> QueryResult<Vec<Comment>>
    where
        F: Future<Output = QueryResult<Vec<Comment>>>,
    {
        let cache = post_comment_cache();
        let result = cache.try_get_with(self.key.clone(), init).await;
        result.map_err(map_shared_diesel_error)
    }

    /// Removes this post's comment list from the cache.
    pub async fn clear(&mut self) {
        let cache = post_comment_cache();
        cache.invalidate(&self.key).await;
    }
}
/// Handle to the cached, sorted list of post ids for one room and sort mode.
///
/// The cache entry is a `Vec<(weight, post_id)>` kept sorted by ascending
/// weight, so reading it front to back yields posts in list order. See
/// `p2pair` for how each mode derives the weight.
pub struct PostListCache {
    key: String,
    mode: u8,
}

impl PostListCache {
    /// `fill_with` truncates any cached list that is longer than this.
    pub const MAX_LENGTH: usize = 900;
    /// An over-long list is truncated to `MAX_LENGTH - CUT_LENGTH` entries.
    pub const CUT_LENGTH: usize = 100;

    /// Builds the handle for `room_id` and `mode`. A `None` room yields an
    /// empty room segment in the cache key.
    pub fn init(room_id: Option<i32>, mode: u8) -> Self {
        Self {
            key: format!(
                "hole_v2:cache:post_list:{}:{}",
                room_id.map_or_else(String::new, |i| i.to_string()),
                &mode
            ),
            mode,
        }
    }

    /// Maps a post to its `(weight, post_id)` list entry. Smaller weights sort first:
    /// - 0: highest post id first
    /// - 1: most recent `last_comment_time` first
    /// - 2: highest `hot_score` first
    /// - 3: random order
    /// - 4: most `n_attentions` first
    ///
    /// Fields are widened to `i64` before negation, so negating the minimum
    /// value of a narrower integer type cannot overflow.
    ///
    /// # Panics
    /// Panics with "wrong mode" if `self.mode` is not in `0..=4`.
    fn p2pair(&self, p: &Post) -> (i64, i32) {
        let weight = match self.mode {
            0 => -i64::from(p.id),
            1 => -p.last_comment_time.timestamp(),
            2 => -i64::from(p.hot_score),
            3 => rand::thread_rng().gen_range(0..i64::MAX),
            4 => -i64::from(p.n_attentions),
            _ => panic!("wrong mode"),
        };
        (weight, p.id)
    }

    /// Ensures the list for this key is cached, building it from `query_posts`
    /// on a miss, and returns its length.
    ///
    /// `query_posts` is awaited only when the list is not already cached. A
    /// list longer than `MAX_LENGTH` is truncated to
    /// `MAX_LENGTH - CUT_LENGTH` entries before the length is returned.
    ///
    /// # Errors
    /// Returns the query's error (converted by `map_shared_diesel_error`).
    pub async fn fill_with<F>(&mut self, query_posts: F) -> QueryResult<usize>
    where
        F: Future<Output = QueryResult<Vec<Post>>>,
    {
        let list_ref = post_list_cache()
            .try_get_with(self.key.clone(), async {
                let mut items: Vec<(i64, i32)> =
                    query_posts.await?.iter().map(|p| self.p2pair(p)).collect();
                // Stable sort, so posts with equal weight keep query order.
                items.sort_by_key(|&(weight, _)| weight);
                Ok(Arc::new(RwLock::new(items)))
            })
            .await
            .map_err(map_shared_diesel_error)?;
        // Check under the read lock first. Take the write lock only when the list
        // is too long, then re-check, since another task may have truncated it in
        // between.
        let list = list_ref.read().await;
        if list.len() <= Self::MAX_LENGTH {
            return Ok(list.len());
        }
        drop(list);
        let mut list = list_ref.write().await;
        if list.len() > Self::MAX_LENGTH {
            list.truncate(Self::MAX_LENGTH - Self::CUT_LENGTH);
        }
        Ok(list.len())
    }

    /// Inserts or repositions `p` in the cached list. If `p` is deleted (or,
    /// in modes other than 0, reported), removes it instead.
    ///
    /// Does nothing when the list is not cached; `fill_with` builds it later.
    pub async fn put(&mut self, p: &Post) {
        let Some(list_ref) = post_list_cache().get(&self.key).await else {
            return;
        };
        let mut list = list_ref.write().await;
        // Drop every existing entry for this post; its weight may have changed.
        list.retain(|&(_, pid)| pid != p.id);
        if p.is_deleted || (self.mode > 0 && p.is_reported) {
            return;
        }
        let pair = self.p2pair(p);
        // The list is sorted by weight. Insert after all entries with an equal
        // weight, the same position a push followed by a stable sort would give,
        // in O(n) instead of O(n log n).
        let pos = list.partition_point(|&(weight, _)| weight <= pair.0);
        list.insert(pos, pair);
    }

    /// Returns up to `limit` post ids starting at offset `start`, or an
    /// empty vector if the list is not cached.
    pub async fn get_pids(&mut self, start: usize, limit: usize) -> Vec<i32> {
        if let Some(list_ref) = post_list_cache().get(&self.key).await {
            let list = list_ref.read().await;
            list.iter()
                .skip(start)
                .take(limit)
                .map(|&(_, pid)| pid)
                .collect()
        } else {
            vec![]
        }
    }

    /// Drops the cached list so the next `fill_with` rebuilds it.
    pub async fn clear(&mut self) {
        post_list_cache().invalidate(&self.key).await;
    }
}
/// Handle to one user's entry in the global user cache.
pub struct UserCache {
    key: String,
}

impl UserCache {
    /// Creates the handle for `user_id`.
    pub fn init(user_id: &str) -> Self {
        let key = format!("hole_v2:cache:user:{}", user_id);
        UserCache { key }
    }

    // Users do not go through get_with: callers check `get` and call `set`
    // themselves.

    /// Stores a clone of `u` under this handle's key, replacing any previous value.
    pub async fn set(&self, u: &User) {
        let cache = user_cache();
        cache.insert(self.key.clone(), u.clone()).await;
    }

    /// Returns the cached user, if present.
    pub async fn get(&self) -> Option<User> {
        let cache = user_cache();
        cache.get(&self.key).await
    }

    /// Invalidates every cached user.
    pub async fn clear_all() {
        let cache = user_cache();
        cache.invalidate_all();
    }
}
pub struct BlockDictCache {
key: String,
}
impl BlockDictCache {
pub fn init(namehash: &str, post_id: i32) -> Self {
Self {
key: format!("hole_v2:cache:block_dict:{}:{}", namehash, post_id),
}
}
pub async fn get_or_create(
&mut self,
user: &CurrentUser,
hash_list: &[&String],
rconn: &RdsConn,
) -> RedisResult<HashMap<String, bool>> {
let dict_ref = block_dict_cache()
.get_with(self.key.clone(), async move {
Arc::new(RwLock::new(HashMap::new()))
})
.await;
// Find missing hashes
let mut missing_keys: Vec<String> = Vec::new();
{
let block_dict = dict_ref.read().await;
for hash in hash_list {
if !block_dict.contains_key(hash.as_str()) {
missing_keys.push((*hash).clone());
}
}
}
if !missing_keys.is_empty() {
let mut missing: Vec<(String, bool)> = Vec::with_capacity(missing_keys.len());
for hash in missing_keys {
let is_blocked = BlockedUsers::check_if_block(rconn, user, &hash).await?;
missing.push((hash, is_blocked));
}
let mut block_dict = dict_ref.write().await;
for (hash, is_blocked) in missing {
block_dict.entry(hash).or_insert(is_blocked);
}
}
let out = dict_ref.read().await.clone();
Ok(out)
}
pub async fn clear(&mut self) {
block_dict_cache().invalidate(&self.key).await;
}
}
/// Returns the total user count, cached for `USER_COUNT_EXPIRE_TIME` seconds.
///
/// On a miss the count is loaded with `User::get_count`. A query error is
/// converted by `map_shared_diesel_error` and then into the API error type.
pub async fn cached_user_count(db: &Db) -> Api<i64> {
    const KEY: &str = "hole_v2:cache:user_count";
    let count = user_count_cache()
        .try_get_with(KEY.to_string(), async { User::get_count(db).await })
        .await
        .map_err(map_shared_diesel_error)?;
    Ok(count)
}