Browse Source

opt: use database connection pool

master
hole-thu 3 years ago
parent
commit
658e5b5902
  1. 2
      Cargo.toml
  2. 4
      src/api/comment.rs
  3. 11
      src/api/mod.rs
  4. 13
      src/api/post.rs
  5. 3
      src/api/systemlog.rs
  6. 5
      src/main.rs
  7. 20
      src/models.rs

2
Cargo.toml

@ -8,7 +8,7 @@ license = "AGPL-3.0"
[dependencies] [dependencies]
rocket = { version = "0.5.0-rc.1", features = ["json"] } rocket = { version = "0.5.0-rc.1", features = ["json"] }
diesel = { version = "1.4.8", features= ["sqlite", "chrono"] } diesel = { version = "1.4.8", features= ["sqlite", "chrono", "r2d2"] }
chrono = { version="0.4", features=["serde"] } chrono = { version="0.4", features=["serde"] }
rand = "0.8.5" rand = "0.8.5"
dotenv = "0.15.0" dotenv = "0.15.0"

4
src/api/comment.rs

@ -5,6 +5,7 @@ use rocket::serde::{
json::{json, Value}, json::{json, Value},
Serialize, Serialize,
}; };
use crate::db_conn::DbConn;
use std::collections::HashMap; use std::collections::HashMap;
#[derive(Serialize)] #[derive(Serialize)]
@ -48,8 +49,7 @@ pub fn c2output(p: &Post, cs: &Vec<Comment>, user: &CurrentUser) -> Vec<CommentO
} }
#[get("/getcomment?<pid>")] #[get("/getcomment?<pid>")]
pub fn get_comment(pid: i32, user: CurrentUser) -> API<Value> { pub fn get_comment(pid: i32, user: CurrentUser, conn: DbConn) -> API<Value> {
let conn = establish_connection();
let p = Post::get(&conn, pid).map_err(APIError::from_db)?; let p = Post::get(&conn, pid).map_err(APIError::from_db)?;
if p.is_deleted { if p.is_deleted {
return Err(APIError::PcError(IsDeleted)); return Err(APIError::PcError(IsDeleted));

11
src/api/mod.rs

@ -1,9 +1,10 @@
use crate::models::*; use crate::models::*;
use crate::random_hasher::RandomHasher; use crate::random_hasher::RandomHasher;
use rocket::http::Status; use rocket::http::Status;
use rocket::request::{self, FromRequest, Request}; use rocket::request::{FromRequest, Request, Outcome};
use rocket::response::{self, Responder}; use rocket::response::{self, Responder};
use rocket::serde::json::json; use rocket::serde::json::json;
use crate::db_conn::DbPool;
#[catch(401)] #[catch(401)]
pub fn catch_401_error() -> &'static str { pub fn catch_401_error() -> &'static str {
@ -20,7 +21,7 @@ pub struct CurrentUser {
#[rocket::async_trait] #[rocket::async_trait]
impl<'r> FromRequest<'r> for CurrentUser { impl<'r> FromRequest<'r> for CurrentUser {
type Error = (); type Error = ();
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let rh = request.rocket().state::<RandomHasher>().unwrap(); let rh = request.rocket().state::<RandomHasher>().unwrap();
let mut cu: Option<CurrentUser> = None; let mut cu: Option<CurrentUser> = None;
@ -35,7 +36,7 @@ impl<'r> FromRequest<'r> for CurrentUser {
is_admin: false, is_admin: false,
}); });
} else { } else {
let conn = establish_connection(); let conn = request.rocket().state::<DbPool>().unwrap().get().unwrap();
if let Some(user) = User::get_by_token(&conn, token) { if let Some(user) = User::get_by_token(&conn, token) {
let namehash = rh.hash_with_salt(&user.name); let namehash = rh.hash_with_salt(&user.name);
cu = Some(CurrentUser { cu = Some(CurrentUser {
@ -48,8 +49,8 @@ impl<'r> FromRequest<'r> for CurrentUser {
} }
} }
match cu { match cu {
Some(u) => request::Outcome::Success(u), Some(u) => Outcome::Success(u),
None => request::Outcome::Failure((Status::Unauthorized, ())), None => Outcome::Failure((Status::Unauthorized, ())),
} }
} }
} }

13
src/api/post.rs

@ -2,12 +2,12 @@ use crate::api::comment::{c2output, CommentOutput};
use crate::api::{APIError, CurrentUser, PolicyError::*, API}; use crate::api::{APIError, CurrentUser, PolicyError::*, API};
use crate::models::*; use crate::models::*;
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use diesel::SqliteConnection;
use rocket::form::Form; use rocket::form::Form;
use rocket::serde::{ use rocket::serde::{
json::{json, Value}, json::{json, Value},
Serialize, Serialize,
}; };
use crate::db_conn::DbConn;
#[derive(FromForm)] #[derive(FromForm)]
pub struct PostInput<'r> { pub struct PostInput<'r> {
@ -40,7 +40,7 @@ pub struct PostOutput {
reply: i32, reply: i32,
} }
fn p2output(p: &Post, user: &CurrentUser, conn: &SqliteConnection) -> PostOutput { fn p2output(p: &Post, user: &CurrentUser, conn: &DbConn) -> PostOutput {
PostOutput { PostOutput {
pid: p.id, pid: p.id,
text: p.content.to_string(), text: p.content.to_string(),
@ -80,8 +80,7 @@ fn p2output(p: &Post, user: &CurrentUser, conn: &SqliteConnection) -> PostOutput
} }
#[get("/getone?<pid>")] #[get("/getone?<pid>")]
pub fn get_one(pid: i32, user: CurrentUser) -> API<Value> { pub fn get_one(pid: i32, user: CurrentUser, conn: DbConn) -> API<Value> {
let conn = establish_connection();
let p = Post::get(&conn, pid).map_err(APIError::from_db)?; let p = Post::get(&conn, pid).map_err(APIError::from_db)?;
if !user.is_admin { if !user.is_admin {
if p.is_reported { if p.is_reported {
@ -98,9 +97,8 @@ pub fn get_one(pid: i32, user: CurrentUser) -> API<Value> {
} }
#[get("/getlist?<p>&<order_mode>")] #[get("/getlist?<p>&<order_mode>")]
pub fn get_list(p: Option<u32>, order_mode: u8, user: CurrentUser) -> API<Value> { pub fn get_list(p: Option<u32>, order_mode: u8, user: CurrentUser, conn: DbConn) -> API<Value> {
let page = p.unwrap_or(1); let page = p.unwrap_or(1);
let conn = establish_connection();
let ps = Post::gets_by_page(&conn, order_mode, page, 25, user.is_admin) let ps = Post::gets_by_page(&conn, order_mode, page, 25, user.is_admin)
.map_err(APIError::from_db)?; .map_err(APIError::from_db)?;
let ps_data = ps let ps_data = ps
@ -115,8 +113,7 @@ pub fn get_list(p: Option<u32>, order_mode: u8, user: CurrentUser) -> API<Value>
} }
#[post("/dopost", data = "<poi>")] #[post("/dopost", data = "<poi>")]
pub fn publish_post(poi: Form<PostInput>, user: CurrentUser) -> API<Value> { pub fn publish_post(poi: Form<PostInput>, user: CurrentUser, conn: DbConn) -> API<Value> {
let conn = establish_connection();
dbg!(poi.use_title, poi.allow_search); dbg!(poi.use_title, poi.allow_search);
let r = Post::create( let r = Post::create(
&conn, &conn,

3
src/api/systemlog.rs

@ -2,9 +2,10 @@ use crate::api::{CurrentUser, API};
use crate::random_hasher::RandomHasher; use crate::random_hasher::RandomHasher;
use rocket::serde::json::{json, Value}; use rocket::serde::json::{json, Value};
use rocket::State; use rocket::State;
use crate::db_conn::DbConn;
#[get("/systemlog")] #[get("/systemlog")]
pub fn get_systemlog(user: CurrentUser, rh: &State<RandomHasher>) -> API<Value> { pub fn get_systemlog(user: CurrentUser, rh: &State<RandomHasher>, conn: DbConn) -> API<Value> {
Ok(json!({ Ok(json!({
"tmp_token": rh.get_tmp_token(), "tmp_token": rh.get_tmp_token(),
"salt": look!(rh.salt), "salt": look!(rh.salt),

5
src/main.rs

@ -4,12 +4,16 @@ extern crate rocket;
#[macro_use] #[macro_use]
extern crate diesel; extern crate diesel;
mod api; mod api;
mod models; mod models;
mod random_hasher; mod random_hasher;
mod schema; mod schema;
mod db_conn;
use random_hasher::RandomHasher; use random_hasher::RandomHasher;
use db_conn::init_pool;
#[launch] #[launch]
fn rocket() -> _ { fn rocket() -> _ {
@ -27,6 +31,7 @@ fn rocket() -> _ {
) )
.register("/_api", catchers![api::catch_401_error]) .register("/_api", catchers![api::catch_401_error])
.manage(RandomHasher::get_random_one()) .manage(RandomHasher::get_random_one())
.manage(init_pool())
} }
fn load_env() { fn load_env() {

20
src/models.rs

@ -1,20 +1,16 @@
#![allow(clippy::all)] #![allow(clippy::all)]
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use diesel::{insert_into, Connection, ExpressionMethods, QueryDsl, RunQueryDsl, SqliteConnection}; use diesel::{insert_into, ExpressionMethods, QueryDsl, RunQueryDsl};
use std::env;
use crate::schema::*; use crate::schema::*;
use crate::db_conn::Conn;
type MR<T> = Result<T, diesel::result::Error>; type MR<T> = Result<T, diesel::result::Error>;
no_arg_sql_function!(RANDOM, (), "Represents the sql RANDOM() function"); no_arg_sql_function!(RANDOM, (), "Represents the sql RANDOM() function");
pub fn establish_connection() -> SqliteConnection {
let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
SqliteConnection::establish(&database_url)
.expect(&format!("Error connecting to {}", database_url))
}
#[derive(Queryable, Debug)] #[derive(Queryable, Debug)]
pub struct Post { pub struct Post {
@ -45,12 +41,12 @@ pub struct NewPost<'a> {
} }
impl Post { impl Post {
pub fn get(conn: &SqliteConnection, id: i32) -> MR<Self> { pub fn get(conn: &Conn, id: i32) -> MR<Self> {
posts::table.find(id).first(conn) posts::table.find(id).first(conn)
} }
pub fn gets_by_page( pub fn gets_by_page(
conn: &SqliteConnection, conn: &Conn,
order_mode: u8, order_mode: u8,
page: u32, page: u32,
page_size: u32, page_size: u32,
@ -76,13 +72,13 @@ impl Post {
.load(conn) .load(conn)
} }
pub fn get_comments(&self, conn: &SqliteConnection) -> MR<Vec<Comment>> { pub fn get_comments(&self, conn: &Conn) -> MR<Vec<Comment>> {
comments::table comments::table
.filter(comments::post_id.eq(self.id)) .filter(comments::post_id.eq(self.id))
.load(conn) .load(conn)
} }
pub fn create(conn: &SqliteConnection, new_post: NewPost) -> MR<usize> { pub fn create(conn: &Conn, new_post: NewPost) -> MR<usize> {
// TODO: tags // TODO: tags
insert_into(posts::table).values(&new_post).execute(conn) insert_into(posts::table).values(&new_post).execute(conn)
} }
@ -97,7 +93,7 @@ pub struct User {
} }
impl User { impl User {
pub fn get_by_token(conn: &SqliteConnection, token: &str) -> Option<Self> { pub fn get_by_token(conn: &Conn, token: &str) -> Option<Self> {
users::table.filter(users::token.eq(token)).first(conn).ok() users::table.filter(users::token.eq(token)).first(conn).ok()
} }
} }

Loading…
Cancel
Save