diff --git a/src/login.rs b/src/login.rs index 22d6ec1..4ad229f 100644 --- a/src/login.rs +++ b/src/login.rs @@ -69,7 +69,20 @@ struct Account { pub id: String, } #[get("/cs/auth?&&")] -pub async fn cs_auth(code: String, redirect_url: String, jump_to_url: String, db: Db) -> Redirect { +pub async fn cs_auth( + code: String, + redirect_url: String, + jump_to_url: String, + db: Db, +) -> Result { + if !env::var("FRONTEND_WHITELIST") + .unwrap_or_default() + .split(',') + .any(|url| jump_to_url.starts_with(url)) + { + return Err("前端地址不在白名单内"); + } + let mast_url = env::var("MAST_BASE_URL").unwrap(); let mast_cli = env::var("MAST_CLIENT").unwrap(); let mast_sec = env::var("MAST_SECRET").unwrap(); @@ -121,19 +134,109 @@ pub async fn cs_auth(code: String, redirect_url: String, jump_to_url: String, db .await .unwrap(); - Redirect::to(format!( - "{}?token={}", + Ok(Redirect::to(format!("{}?token={}", &jump_to_url, &tk))) +} + +#[get("/gh")] +pub fn gh_login(r: RefHeader) -> Redirect { + let gh_url = "https://github.com/login/oauth/authorize"; + let gh_cli = env::var("GH_CLIENT").unwrap(); + let gh_scope = "user:email"; + + let jump_to_url = Url::parse(&r.0).unwrap(); + + let mut redirect_url = env::var("AUTH_BACKEND_URL") + .map(|url| Url::parse(&url).unwrap()) + .unwrap_or_else(|_| jump_to_url.clone()); + redirect_url.set_path("/_login/gh/auth"); + + redirect_url = Url::parse_with_params( + redirect_url.as_str(), + &[("jump_to_url", jump_to_url.as_str())], + ) + .unwrap(); + + let url = Url::parse_with_params( + gh_url, + &[ + ("redirect_uri", redirect_url.as_str()), + ("client_id", &gh_cli), + ("scope", gh_scope), + ], + ) + .unwrap(); + + Redirect::to(url.to_string()) +} + +#[derive(Deserialize, Debug)] +#[serde(crate = "rocket::serde")] +struct GithubEmail { + pub email: String, + pub verified: bool, +} + +#[get("/gh/auth?&")] +pub async fn gh_auth(code: String, jump_to_url: String, db: Db) -> Result { + if !env::var("FRONTEND_WHITELIST") + .unwrap_or_default() + .split(',') + .any(|url| jump_to_url.starts_with(url)) + { + return Err("前端地址不在白名单内"); + } + + let gh_cli = env::var("GH_CLIENT").unwrap(); + let gh_sec = env::var("GH_SECRET").unwrap(); + + let client = reqwest::Client::new(); + let r = client + .post("https://github.com/login/oauth/access_token") + .header(reqwest::header::ACCEPT, "application/json") + .form(&[ + ("client_id", gh_cli.as_str()), + ("client_secret", gh_sec.as_str()), + ("code", code.as_str()), + ]) + .send() + .await + .unwrap(); + + //let token: rocket::serde::json::Value = r.json().await.unwrap(); + let token: Token = r.json().await.unwrap(); + + dbg!(&token); + + let client = reqwest::Client::new(); + let r = client + .get("https://api.github.com/user/emails") + .bearer_auth(token.access_token) + .header(reqwest::header::USER_AGENT, "hole_thu LoginBot") + .send() + .await + .unwrap(); + // dbg!(&r); + let emails = r + .json::>() + //.json::() + .await + .unwrap(); + + //dbg!(&emails); + + for email in emails { + if let Some(email_name) = email + .email + .strip_suffix("@mails.tsinghua.edu.cn") + .and_then(|name| email.verified.then_some(name)) { - if env::var("FRONTEND_WHITELIST") - .unwrap_or_default() - .split(',') - .any(|url| jump_to_url.starts_with(url)) - { - &jump_to_url - } else { - "/" - } - }, - &tk - )) + let tk = User::find_or_create_token(&db, &format!("gh_{}", email_name), false) + .await + .unwrap(); + + return Ok(Redirect::to(format!("{}?token={}", &jump_to_url, &tk))); + } + } + + Err("没有找到已验证的清华邮箱") } diff --git a/src/main.rs b/src/main.rs index f57d785..8167285 100644 --- a/src/main.rs +++ b/src/main.rs @@ -99,7 +99,12 @@ async fn main() { "/_login", [ #[cfg(feature = "mastlogin")] - routes![login::cs_login, login::cs_auth], + routes![ + login::cs_login, + login::cs_auth, + login::gh_login, + login::gh_auth + ], routes![], ] .concat(),