From d6eaec911114b6f140c9723c7407a7a7f96de99a Mon Sep 17 00:00:00 2001 From: hole-thu Date: Sun, 27 Mar 2022 21:45:14 +0800 Subject: [PATCH] feat: cs login & add vote api file --- Cargo.toml | 5 +- src/api/vote.rs | 70 ++++++++++++++++++++++++++ src/login.rs | 114 +++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 7 +-- src/models.rs | 33 +++++++++++++ src/random_hasher.rs | 14 ++++-- 6 files changed, 234 insertions(+), 9 deletions(-) create mode 100644 src/api/vote.rs create mode 100644 src/login.rs diff --git a/Cargo.toml b/Cargo.toml index 6b02b58..c0dedcf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,9 +13,12 @@ diesel = { version = "1.4.8", features = ["postgres", "chrono"] } diesel_migrations = "1.4.0" tokio = "1.17.0" redis = { version="0.21.5", features = ["aio", "tokio-comp"] } -chrono = { version="0.4.19", features =["serde"] } +chrono = { version="0.4.19", features = ["serde"] } rand = "0.8.5" dotenv = "0.15.0" sha2 = "0.10.2" log = "0.4.16" env_logger = "0.9.0" + +url = "2.2.2" +reqwest = { version = "0.11.10", features = ["json"] } diff --git a/src/api/vote.rs b/src/api/vote.rs new file mode 100644 index 0000000..fb42944 --- /dev/null +++ b/src/api/vote.rs @@ -0,0 +1,70 @@ +use crate::api::{CurrentUser, JsonAPI, PolicyError::*}; +use crate::rds_conn::RdsConn; +use crate::rds_models::*; +use rocket::form::Form; +use rocket::futures::future; +use rocket::serde::json::{json, Value}; + +pub async fn get_poll_dict(pid: i32, rconn: &RdsConn, namehash: &str) -> Option { + let opts = PollOption::init(pid, rconn) + .get_list() + .await + .unwrap_or_default(); + if opts.is_empty() { + None + } else { + let choice = future::join_all(opts.iter().enumerate().map(|(idx, opt)| async move { + PollVote::init(pid, idx, rconn) + .has(namehash) + .await + .unwrap_or_default() + .then(|| opt) + })) + .await + .into_iter() + .filter_map(|x| x) + .collect::>() + .pop(); + Some(json!({ + "answers": future::join_all( + opts.iter().enumerate().map(|(idx, opt)| async move { + json!({ + "option": opt, + "votes": PollVote::init(pid, idx, rconn).count().await.unwrap_or_default(), + }) + }) + ).await, + "vote": choice, + })) + } +} + +#[derive(FromForm)] +pub struct VoteInput { + pid: i32, + vote: String, +} + +#[post("/vote", data = "")] +pub async fn vote(vi: Form, user: CurrentUser, rconn: RdsConn) -> JsonAPI { + let pid = vi.pid; + let opts = PollOption::init(pid, &rconn).get_list().await?; + if opts.is_empty() { + Err(NotAllowed)?; + } + + for idx in 0..opts.len() { + if PollVote::init(pid, idx, &rconn).has(&user.namehash).await? { + Err(NotAllowed)?; + } + } + + let idx: usize = opts + .iter() + .position(|x| x.eq(&vi.vote)) + .ok_or_else(|| NotAllowed)?; + + PollVote::init(pid, idx, &rconn).add(&user.namehash).await?; + + code0!(get_poll_dict(vi.pid, &rconn, &user.namehash).await) +} diff --git a/src/login.rs b/src/login.rs new file mode 100644 index 0000000..7e40225 --- /dev/null +++ b/src/login.rs @@ -0,0 +1,114 @@ +use crate::db_conn::Db; +use crate::models::User; +use rocket::request::{FromRequest, Outcome, Request}; +use rocket::response::Redirect; +use rocket::serde::Deserialize; +use std::env; +use url::Url; + +pub struct RefHeader(pub String); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for RefHeader { + type Error = (); + async fn from_request(request: &'r Request<'_>) -> Outcome { + match request.headers().get_one("Referer") { + Some(h) => Outcome::Success(RefHeader(h.to_string())), + None => Outcome::Forward(()), + } + } +} + +#[get("/?p=cs")] +pub fn cs_login(r: RefHeader) -> Redirect { + let mast_url = env::var("MAST_BASE_URL").unwrap(); + let mast_cli = env::var("MAST_CLIENT").unwrap(); + let mast_scope = env::var("MAST_SCOPE").unwrap(); + + let mut redirect_url = Url::parse(&r.0).unwrap(); + redirect_url.set_path("/_login/cs/auth"); + redirect_url.set_query(None); + + redirect_url = Url::parse_with_params( + redirect_url.as_str(), + &[("redirect_url", redirect_url.as_str())], + ) + .unwrap(); + let url = Url::parse_with_params( + &format!("{}oauth/authorize", mast_url), + &[ + ("redirect_uri", redirect_url.as_str()), + ("client_id", &mast_cli), + ("scope", &mast_scope), + ("response_type", "code"), + ], + ) + .unwrap(); + + Redirect::to(url.to_string()) +} + +#[derive(Deserialize, Debug)] +#[serde(crate = "rocket::serde")] +struct Token { + pub access_token: String, +} + +#[derive(Deserialize, Debug)] +#[serde(crate = "rocket::serde")] +struct Account { + pub id: String, +} +#[get("/cs/auth?&")] +pub async fn cs_auth(code: String, redirect_url: String, db: Db) -> Redirect { + let mast_url = env::var("MAST_BASE_URL").unwrap(); + let mast_cli = env::var("MAST_CLIENT").unwrap(); + let mast_sec = env::var("MAST_SECRET").unwrap(); + let mast_scope = env::var("MAST_SCOPE").unwrap(); + + // to keep same + let redirect_url = Url::parse_with_params( + redirect_url.as_str(), + &[("redirect_url", redirect_url.as_str())], + ) + .unwrap(); + + let client = reqwest::Client::new(); + let token: Token = client + .post(format!("{}oauth/token", &mast_url)) + .form(&[ + ("client_id", mast_cli.as_str()), + ("client_secret", mast_sec.as_str()), + ("scope", mast_scope.as_str()), + ("redirect_uri", redirect_url.as_str()), + ("grant_type", "authorization_code"), + ("code", code.as_str()), + ]) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); + + //dbg!(&token); + + let client = reqwest::Client::new(); + let account = client + .get(format!("{}api/v1/accounts/verify_credentials", &mast_url)) + .bearer_auth(token.access_token) + .send() + .await + .unwrap() + .json::() + .await + .unwrap(); + + //dbg!(&account); + + let tk = User::find_or_create_token(&db, &format!("cs_{}", &account.id), false) + .await + .unwrap(); + + Redirect::to(format!("/?token={}", tk)) +} diff --git a/src/main.rs b/src/main.rs index 93e7481..4b82f8a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,6 +16,7 @@ mod api; mod cache; mod db_conn; mod libs; +mod login; mod models; mod random_hasher; mod rds_conn; @@ -27,7 +28,7 @@ use diesel::Connection; use random_hasher::RandomHasher; use rds_conn::{init_rds_client, RdsConn}; use std::env; -use tokio::time::{interval, Duration}; +use tokio::time::{sleep, Duration}; embed_migrations!("migrations/postgres"); @@ -43,9 +44,8 @@ async fn main() -> Result<(), rocket::Error> { let rconn = RdsConn(rmc.clone()); clear_outdate_redis_data(&rconn.clone()).await; tokio::spawn(async move { - let mut itv = interval(Duration::from_secs(4 * 60 * 60)); loop { - itv.tick().await; + sleep(Duration::from_secs(4 * 60 * 60)).await; models::Post::annealing(establish_connection(), &rconn).await; } }); @@ -71,6 +71,7 @@ async fn main() -> Result<(), rocket::Error> { api::vote::vote, ], ) + .mount("/_login", routes![login::cs_login, login::cs_auth]) .register( "/_api", catchers![api::catch_401_error, api::catch_403_error,], diff --git a/src/models.rs b/src/models.rs index cf0df06..5659ac2 100644 --- a/src/models.rs +++ b/src/models.rs @@ -3,6 +3,7 @@ use crate::cache::*; use crate::db_conn::{Conn, Db}; use crate::libs::diesel_logger::LoggingConnection; +use crate::random_hasher::random_string; use crate::rds_conn::RdsConn; use crate::schema::*; use chrono::{offset::Utc, DateTime}; @@ -368,6 +369,38 @@ impl User { Some(u) } } + + pub async fn find_or_create_token( + db: &Db, + name: &str, + force_refresh: bool, + ) -> QueryResult { + let name = name.to_string(); + db.run(move |c| { + if let Some(u) = { + if force_refresh { + None + } else { + users::table + .filter(users::name.eq(&name)) + .first::(with_log!(c)) + .ok() + } + } { + Ok(u.token) + } else { + let token = random_string(16); + diesel::insert_into(users::table) + .values((users::name.eq(&name), users::token.eq(&token))) + .on_conflict(users::name) + .do_update() + .set(users::token.eq(&token)) + .execute(with_log!(c))?; + Ok(token) + } + }) + .await + } } #[derive(Insertable)] diff --git a/src/random_hasher.rs b/src/random_hasher.rs index bdafda6..7b103a8 100644 --- a/src/random_hasher.rs +++ b/src/random_hasher.rs @@ -2,6 +2,14 @@ use chrono::{offset::Local, DateTime}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use sha2::{Digest, Sha256}; +pub fn random_string(len: usize) -> String { + thread_rng() + .sample_iter(&Alphanumeric) + .take(len) + .map(char::from) + .collect() +} + pub struct RandomHasher { pub salt: String, pub start_time: DateTime, @@ -10,11 +18,7 @@ pub struct RandomHasher { impl RandomHasher { pub fn get_random_one() -> RandomHasher { RandomHasher { - salt: thread_rng() - .sample_iter(&Alphanumeric) - .take(16) - .map(char::from) - .collect(), + salt: random_string(16), start_time: Local::now(), } }