From 658e5b59024deb84db8719b5fcdbe845e7ee63d7 Mon Sep 17 00:00:00 2001 From: hole-thu Date: Wed, 16 Mar 2022 18:30:56 +0800 Subject: [PATCH] opt: use database connection pool --- Cargo.toml | 2 +- src/api/comment.rs | 4 ++-- src/api/mod.rs | 11 ++++++----- src/api/post.rs | 13 +++++-------- src/api/systemlog.rs | 3 ++- src/main.rs | 5 +++++ src/models.rs | 20 ++++++++------------ 7 files changed, 29 insertions(+), 29 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f564bf0..680cb04 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ license = "AGPL-3.0" [dependencies] rocket = { version = "0.5.0-rc.1", features = ["json"] } -diesel = { version = "1.4.8", features= ["sqlite", "chrono"] } +diesel = { version = "1.4.8", features= ["sqlite", "chrono", "r2d2"] } chrono = { version="0.4", features=["serde"] } rand = "0.8.5" dotenv = "0.15.0" diff --git a/src/api/comment.rs b/src/api/comment.rs index c17b524..771999d 100644 --- a/src/api/comment.rs +++ b/src/api/comment.rs @@ -5,6 +5,7 @@ use rocket::serde::{ json::{json, Value}, Serialize, }; +use crate::db_conn::DbConn; use std::collections::HashMap; #[derive(Serialize)] @@ -48,8 +49,7 @@ pub fn c2output(p: &Post, cs: &Vec, user: &CurrentUser) -> Vec")] -pub fn get_comment(pid: i32, user: CurrentUser) -> API { - let conn = establish_connection(); +pub fn get_comment(pid: i32, user: CurrentUser, conn: DbConn) -> API { let p = Post::get(&conn, pid).map_err(APIError::from_db)?; if p.is_deleted { return Err(APIError::PcError(IsDeleted)); diff --git a/src/api/mod.rs b/src/api/mod.rs index ac74b6a..a7f5a2c 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,9 +1,10 @@ use crate::models::*; use crate::random_hasher::RandomHasher; use rocket::http::Status; -use rocket::request::{self, FromRequest, Request}; +use rocket::request::{FromRequest, Request, Outcome}; use rocket::response::{self, Responder}; use rocket::serde::json::json; +use crate::db_conn::DbPool; #[catch(401)] pub fn catch_401_error() -> &'static str { @@ -20,7 +21,7 @@ pub struct CurrentUser { #[rocket::async_trait] impl<'r> FromRequest<'r> for CurrentUser { type Error = (); - async fn from_request(request: &'r Request<'_>) -> request::Outcome { + async fn from_request(request: &'r Request<'_>) -> Outcome { let rh = request.rocket().state::().unwrap(); let mut cu: Option = None; @@ -35,7 +36,7 @@ impl<'r> FromRequest<'r> for CurrentUser { is_admin: false, }); } else { - let conn = establish_connection(); + let conn = request.rocket().state::().unwrap().get().unwrap(); if let Some(user) = User::get_by_token(&conn, token) { let namehash = rh.hash_with_salt(&user.name); cu = Some(CurrentUser { @@ -48,8 +49,8 @@ impl<'r> FromRequest<'r> for CurrentUser { } } match cu { - Some(u) => request::Outcome::Success(u), - None => request::Outcome::Failure((Status::Unauthorized, ())), + Some(u) => Outcome::Success(u), + None => Outcome::Failure((Status::Unauthorized, ())), } } } diff --git a/src/api/post.rs b/src/api/post.rs index 1d6ef11..e085190 100644 --- a/src/api/post.rs +++ b/src/api/post.rs @@ -2,12 +2,12 @@ use crate::api::comment::{c2output, CommentOutput}; use crate::api::{APIError, CurrentUser, PolicyError::*, API}; use crate::models::*; use chrono::NaiveDateTime; -use diesel::SqliteConnection; use rocket::form::Form; use rocket::serde::{ json::{json, Value}, Serialize, }; +use crate::db_conn::DbConn; #[derive(FromForm)] pub struct PostInput<'r> { @@ -40,7 +40,7 @@ pub struct PostOutput { reply: i32, } -fn p2output(p: &Post, user: &CurrentUser, conn: &SqliteConnection) -> PostOutput { +fn p2output(p: &Post, user: &CurrentUser, conn: &DbConn) -> PostOutput { PostOutput { pid: p.id, text: p.content.to_string(), @@ -80,8 +80,7 @@ fn p2output(p: &Post, user: &CurrentUser, conn: &SqliteConnection) -> PostOutput } #[get("/getone?")] -pub fn get_one(pid: i32, user: CurrentUser) -> API { - let conn = establish_connection(); +pub fn get_one(pid: i32, user: CurrentUser, conn: DbConn) -> API { let p = Post::get(&conn, pid).map_err(APIError::from_db)?; if !user.is_admin { if p.is_reported { @@ -98,9 +97,8 @@ pub fn get_one(pid: i32, user: CurrentUser) -> API { } #[get("/getlist?

&")] -pub fn get_list(p: Option, order_mode: u8, user: CurrentUser) -> API { +pub fn get_list(p: Option, order_mode: u8, user: CurrentUser, conn: DbConn) -> API { let page = p.unwrap_or(1); - let conn = establish_connection(); let ps = Post::gets_by_page(&conn, order_mode, page, 25, user.is_admin) .map_err(APIError::from_db)?; let ps_data = ps @@ -115,8 +113,7 @@ pub fn get_list(p: Option, order_mode: u8, user: CurrentUser) -> API } #[post("/dopost", data = "")] -pub fn publish_post(poi: Form, user: CurrentUser) -> API { - let conn = establish_connection(); +pub fn publish_post(poi: Form, user: CurrentUser, conn: DbConn) -> API { dbg!(poi.use_title, poi.allow_search); let r = Post::create( &conn, diff --git a/src/api/systemlog.rs b/src/api/systemlog.rs index 0119636..8c879e8 100644 --- a/src/api/systemlog.rs +++ b/src/api/systemlog.rs @@ -2,9 +2,10 @@ use crate::api::{CurrentUser, API}; use crate::random_hasher::RandomHasher; use rocket::serde::json::{json, Value}; use rocket::State; +use crate::db_conn::DbConn; #[get("/systemlog")] -pub fn get_systemlog(user: CurrentUser, rh: &State) -> API { +pub fn get_systemlog(user: CurrentUser, rh: &State, conn: DbConn) -> API { Ok(json!({ "tmp_token": rh.get_tmp_token(), "salt": look!(rh.salt), diff --git a/src/main.rs b/src/main.rs index 947da03..00b631a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,12 +4,16 @@ extern crate rocket; #[macro_use] extern crate diesel; + mod api; mod models; mod random_hasher; mod schema; +mod db_conn; + use random_hasher::RandomHasher; +use db_conn::init_pool; #[launch] fn rocket() -> _ { @@ -27,6 +31,7 @@ fn rocket() -> _ { ) .register("/_api", catchers![api::catch_401_error]) .manage(RandomHasher::get_random_one()) + .manage(init_pool()) } fn load_env() { diff --git a/src/models.rs b/src/models.rs index c2d8be1..6acbd2e 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,20 +1,16 @@ #![allow(clippy::all)] use chrono::NaiveDateTime; -use diesel::{insert_into, Connection, ExpressionMethods, QueryDsl, RunQueryDsl, SqliteConnection}; -use std::env; +use diesel::{insert_into, ExpressionMethods, QueryDsl, RunQueryDsl}; use crate::schema::*; +use crate::db_conn::Conn; + type MR = Result; no_arg_sql_function!(RANDOM, (), "Represents the sql RANDOM() function"); -pub fn establish_connection() -> SqliteConnection { - let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set"); - SqliteConnection::establish(&database_url) - .expect(&format!("Error connecting to {}", database_url)) -} #[derive(Queryable, Debug)] pub struct Post { @@ -45,12 +41,12 @@ pub struct NewPost<'a> { } impl Post { - pub fn get(conn: &SqliteConnection, id: i32) -> MR { + pub fn get(conn: &Conn, id: i32) -> MR { posts::table.find(id).first(conn) } pub fn gets_by_page( - conn: &SqliteConnection, + conn: &Conn, order_mode: u8, page: u32, page_size: u32, @@ -76,13 +72,13 @@ impl Post { .load(conn) } - pub fn get_comments(&self, conn: &SqliteConnection) -> MR> { + pub fn get_comments(&self, conn: &Conn) -> MR> { comments::table .filter(comments::post_id.eq(self.id)) .load(conn) } - pub fn create(conn: &SqliteConnection, new_post: NewPost) -> MR { + pub fn create(conn: &Conn, new_post: NewPost) -> MR { // TODO: tags insert_into(posts::table).values(&new_post).execute(conn) } @@ -97,7 +93,7 @@ pub struct User { } impl User { - pub fn get_by_token(conn: &SqliteConnection, token: &str) -> Option { + pub fn get_by_token(conn: &Conn, token: &str) -> Option { users::table.filter(users::token.eq(token)).first(conn).ok() } }