#![allow(clippy::unused_unit)] use crate::db_conn::Db; use crate::models::User; use crate::random_hasher::RandomHasher; use rocket::request::{FromRequest, Outcome, Request}; use rocket::response::Redirect; use rocket::serde::Deserialize; use rocket::State; use std::env; use url::Url; #[derive(Debug)] pub struct FrontendAddr(pub String); #[rocket::async_trait] impl<'r> FromRequest<'r> for FrontendAddr { type Error = (); async fn from_request(request: &'r Request<'_>) -> Outcome { Outcome::Success(Self( request .headers() .get_one("Referer") .map(|s| s.to_string()) .unwrap_or_else(|| env::var("DEFAULT_FRONTEND").unwrap()), )) } } #[derive(Debug)] pub struct BackendAddr(pub String); #[rocket::async_trait] impl<'r> FromRequest<'r> for BackendAddr { type Error = (); async fn from_request(request: &'r Request<'_>) -> Outcome { Outcome::Success(Self( request .headers() .get_one("Host") .map(|s| format!("https://{}", s)) .unwrap_or_else(|| env::var("DEFAULT_BACKEND").unwrap()), )) } } #[get("/?p=cs")] pub fn cs_login(r: FrontendAddr, h: BackendAddr) -> Redirect { let mast_url = env::var("MAST_BASE_URL").unwrap(); let mast_cli = env::var("MAST_CLIENT").unwrap(); let mast_scope = env::var("MAST_SCOPE").unwrap(); let jump_to_url = Url::parse(&r.0).unwrap(); let mut redirect_url = Url::parse(&h.0).unwrap(); redirect_url.set_path("/_login/cs/auth"); redirect_url = Url::parse_with_params( redirect_url.as_str(), &[ ("redirect_url", redirect_url.as_str()), ("jump_to_url", jump_to_url.as_str()), ], ) .unwrap(); let url = Url::parse_with_params( &format!("{}oauth/authorize", mast_url), &[ ("redirect_uri", redirect_url.as_str()), ("client_id", &mast_cli), ("scope", &mast_scope), ("response_type", "code"), ], ) .unwrap(); Redirect::to(url.to_string()) } #[derive(Deserialize, Debug)] #[serde(crate = "rocket::serde")] struct Token { pub access_token: String, } #[derive(Deserialize, Debug)] #[serde(crate = "rocket::serde")] struct Account { pub id: String, } #[get("/cs/auth?&&")] pub async fn cs_auth( code: String, redirect_url: String, jump_to_url: String, db: Db, rh: &State, ) -> Result { if !env::var("FRONTEND_WHITELIST") .unwrap_or_default() .split(',') .any(|url| jump_to_url.starts_with(url)) { return Err("前端地址不在白名单内"); } let mast_url = env::var("MAST_BASE_URL").unwrap(); let mast_cli = env::var("MAST_CLIENT").unwrap(); let mast_sec = env::var("MAST_SECRET").unwrap(); let mast_scope = env::var("MAST_SCOPE").unwrap(); // to keep same let redirect_url = Url::parse_with_params( redirect_url.as_str(), &[ ("redirect_url", redirect_url.as_str()), ("jump_to_url", jump_to_url.as_str()), ], ) .unwrap(); let client = reqwest::Client::new(); let r = client .post(format!("{}oauth/token", &mast_url)) .form(&[ ("client_id", mast_cli.as_str()), ("client_secret", mast_sec.as_str()), ("scope", mast_scope.as_str()), ("redirect_uri", redirect_url.as_str()), ("grant_type", "authorization_code"), ("code", code.as_str()), ]) .send() .await .unwrap(); //dbg!(&r); let token: Token = r.json().await.unwrap(); //dbg!(&token); let client = reqwest::Client::new(); let account = client .get(format!("{}api/v1/accounts/verify_credentials", &mast_url)) .bearer_auth(token.access_token) .send() .await .unwrap() .json::() .await .unwrap(); //dbg!(&account); let tk = User::find_or_create_token( &db, &rh.hash_with_salt(&format!("cs_{}", &account.id)), false, ) .await .unwrap(); Ok(Redirect::to(format!("{}?token={}", &jump_to_url, &tk))) } #[get("/gh")] pub fn gh_login(r: FrontendAddr, h: BackendAddr) -> Redirect { let gh_url = "https://github.com/login/oauth/authorize"; let gh_cli = env::var("GH_CLIENT").unwrap(); let gh_scope = "user:email"; let jump_to_url = Url::parse(&r.0).unwrap(); let mut redirect_url = Url::parse(&h.0).unwrap(); redirect_url.set_path("/_login/gh/auth"); redirect_url = Url::parse_with_params( redirect_url.as_str(), &[("jump_to_url", jump_to_url.as_str())], ) .unwrap(); let url = Url::parse_with_params( gh_url, &[ ("redirect_uri", redirect_url.as_str()), ("client_id", &gh_cli), ("scope", gh_scope), ], ) .unwrap(); Redirect::to(url.to_string()) } #[derive(Deserialize, Debug)] #[serde(crate = "rocket::serde")] struct GithubEmail { pub email: String, pub verified: bool, } #[get("/gh/auth?&")] pub async fn gh_auth( code: String, jump_to_url: String, db: Db, rh: &State, ) -> Result { if !env::var("FRONTEND_WHITELIST") .unwrap_or_default() .split(',') .any(|url| jump_to_url.starts_with(url)) { return Err("前端地址不在白名单内"); } let gh_cli = env::var("GH_CLIENT").unwrap(); let gh_sec = env::var("GH_SECRET").unwrap(); let client = reqwest::Client::new(); let r = client .post("https://github.com/login/oauth/access_token") .header(reqwest::header::ACCEPT, "application/json") .form(&[ ("client_id", gh_cli.as_str()), ("client_secret", gh_sec.as_str()), ("code", code.as_str()), ]) .send() .await .unwrap(); //let token: rocket::serde::json::Value = r.json().await.unwrap(); let token: Token = r.json().await.unwrap(); dbg!(&token); let client = reqwest::Client::new(); let r = client .get("https://api.github.com/user/emails") .bearer_auth(token.access_token) .header(reqwest::header::USER_AGENT, "hole_thu LoginBot") .send() .await .unwrap(); // dbg!(&r); let emails = r .json::>() //.json::() .await .unwrap(); //dbg!(&emails); let name = emails .iter() .filter(|email| email.verified) .find_map( |email| match email.email.split('@').collect::>()[..] { [name, "mails.tsinghua.edu.cn"] | [name, "tsinghua.org.cn"] => Some(name), _ => None, }, ); if let Some(name) = name { let tk = User::find_or_create_token(&db, &rh.hash_with_salt(&format!("email_{}", name)), false) .await .unwrap(); Ok(Redirect::to(format!("{}?token={}", &jump_to_url, &tk))) } else { Err("没有找到已验证的清华邮箱/校友邮箱") } }