You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

274 lines
7.3 KiB

#![allow(clippy::unused_unit)]
use crate::db_conn::Db;
use crate::models::User;
use crate::random_hasher::RandomHasher;
use rocket::request::{FromRequest, Outcome, Request};
use rocket::response::Redirect;
use rocket::serde::Deserialize;
use rocket::State;
use std::env;
use url::Url;
/// URL the user is sent back to after login (passed on as `jump_to_url`).
///
/// Taken verbatim from the request's `Referer` header. When that header is
/// missing, the `DEFAULT_FRONTEND` environment variable is used instead;
/// extraction panics if the variable is not set. Otherwise it always succeeds.
#[derive(Debug)]
pub struct FrontendAddr(pub String);

#[rocket::async_trait]
impl<'r> FromRequest<'r> for FrontendAddr {
    type Error = ();

    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        let addr = match request.headers().get_one("Referer") {
            Some(referer) => referer.to_owned(),
            None => env::var("DEFAULT_FRONTEND").unwrap(),
        };
        Outcome::Success(FrontendAddr(addr))
    }
}
/// Base URL of this backend, used to build the OAuth callback URLs.
///
/// Formed as `https://<Host>` from the request's `Host` header. When that
/// header is missing, the `DEFAULT_BACKEND` environment variable is used
/// instead; extraction panics if the variable is not set.
#[derive(Debug)]
pub struct BackendAddr(pub String);

#[rocket::async_trait]
impl<'r> FromRequest<'r> for BackendAddr {
    type Error = ();

    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        let addr = if let Some(host) = request.headers().get_one("Host") {
            format!("https://{}", host)
        } else {
            env::var("DEFAULT_BACKEND").unwrap()
        };
        Outcome::Success(BackendAddr(addr))
    }
}
/// Starts the Mastodon ("cs") OAuth login flow.
///
/// The callback handed to the provider is
/// `<backend>/_login/cs/auth?redirect_url=<callback>&jump_to_url=<frontend>`.
/// `cs_auth` rebuilds that same string from its two query parameters when it
/// exchanges the code. `MAST_BASE_URL` is concatenated directly with
/// `oauth/authorize`, so it is expected to end with `/`.
///
/// # Panics
/// Panics if `MAST_BASE_URL`, `MAST_CLIENT` or `MAST_SCOPE` is unset, or if the
/// frontend or backend address is not an absolute URL.
#[get("/?p=cs")]
pub fn cs_login(r: FrontendAddr, h: BackendAddr) -> Redirect {
    let base = env::var("MAST_BASE_URL").unwrap();
    let client_id = env::var("MAST_CLIENT").unwrap();
    let scope = env::var("MAST_SCOPE").unwrap();

    let frontend = Url::parse(&r.0).unwrap();

    let mut callback = Url::parse(&h.0).unwrap();
    callback.set_path("/_login/cs/auth");
    // The callback URL as it stands before these parameters are added is
    // itself carried along as `redirect_url`, so `cs_auth` can rebuild it.
    let callback_base = callback.as_str().to_owned();
    callback
        .query_pairs_mut()
        .append_pair("redirect_url", &callback_base)
        .append_pair("jump_to_url", frontend.as_str());

    let mut authorize = Url::parse(&format!("{}oauth/authorize", base)).unwrap();
    authorize
        .query_pairs_mut()
        .append_pair("redirect_uri", callback.as_str())
        .append_pair("client_id", &client_id)
        .append_pair("scope", &scope)
        .append_pair("response_type", "code");
    Redirect::to(authorize.as_str().to_owned())
}
/// The part of an OAuth token-endpoint JSON response that this module reads.
///
/// Deserialization fails if the body has no `access_token` field.
#[derive(Debug, Deserialize)]
#[serde(crate = "rocket::serde")]
struct Token {
    /// Credential sent as a bearer token on the follow-up API request.
    pub access_token: String,
}
/// The part of the Mastodon `api/v1/accounts/verify_credentials` response
/// that is read here.
#[derive(Debug, Deserialize)]
#[serde(crate = "rocket::serde")]
struct Account {
    /// Account id; it is salted and hashed to derive the local user identity.
    pub id: String,
}
/// Mastodon ("cs") OAuth callback.
///
/// Flow:
/// 1. Exchanges `code` for an access token.
/// 2. Fetches the account id via `verify_credentials`.
/// 3. Finds or creates the local user for that id.
/// 4. Redirects to `jump_to_url` with `?token=<site token>` appended.
///
/// `redirect_url` and `jump_to_url` are the query parameters that `cs_login`
/// put on the callback. They are used to rebuild the exact `redirect_uri` from
/// the authorize step, which the token request has to repeat.
///
/// # Errors
/// Returns a user-facing message when any of these happens:
/// - `jump_to_url` is not covered by `FRONTEND_WHITELIST`;
/// - `redirect_url` is not a valid URL;
/// - the token or account request fails, or returns a non-success status or an
///   unexpected body.
///
/// # Panics
/// Panics if the Mastodon environment variables are unset, or if the token
/// store fails.
#[get("/cs/auth?<code>&<redirect_url>&<jump_to_url>")]
pub async fn cs_auth(
    code: String,
    redirect_url: String,
    jump_to_url: String,
    db: Db,
    rh: &State<RandomHasher>,
) -> Result<Redirect, &'static str> {
    /// Whether `url` starts with an entry of the comma-separated
    /// `FRONTEND_WHITELIST`, at a URL boundary.
    fn frontend_whitelisted(url: &str) -> bool {
        env::var("FRONTEND_WHITELIST")
            .unwrap_or_default()
            .split(',')
            .map(str::trim)
            // An empty entry (unset variable, trailing comma) would be a prefix
            // of every URL and let the login token go anywhere.
            .filter(|prefix| !prefix.is_empty())
            .any(|prefix| match url.strip_prefix(prefix) {
                // Require a boundary after the prefix. Otherwise `https://a.com`
                // would also admit `https://a.com.evil.example` or
                // `https://a.com:x@evil.example`.
                Some(rest) => {
                    prefix.ends_with('/')
                        || rest.is_empty()
                        || rest.starts_with(|c: char| matches!(c, '/' | '?' | '#'))
                }
                None => false,
            })
    }

    if !frontend_whitelisted(&jump_to_url) {
        return Err("前端地址不在白名单内");
    }
    let mast_url = env::var("MAST_BASE_URL").expect("MAST_BASE_URL must be set");
    let mast_cli = env::var("MAST_CLIENT").expect("MAST_CLIENT must be set");
    let mast_sec = env::var("MAST_SECRET").expect("MAST_SECRET must be set");
    let mast_scope = env::var("MAST_SCOPE").expect("MAST_SCOPE must be set");

    // Rebuild the redirect_uri exactly as the login step built it. The token
    // request must send the same value that was used for authorization.
    // `redirect_url` comes from the query string, so a malformed value is
    // reported as an error instead of panicking.
    let redirect_uri = Url::parse_with_params(
        &redirect_url,
        &[
            ("redirect_url", redirect_url.as_str()),
            ("jump_to_url", jump_to_url.as_str()),
        ],
    )
    .map_err(|_| "回调地址无效")?;

    let client = reqwest::Client::new();
    let token = client
        .post(format!("{}oauth/token", mast_url))
        .form(&[
            ("client_id", mast_cli.as_str()),
            ("client_secret", mast_sec.as_str()),
            ("scope", mast_scope.as_str()),
            ("redirect_uri", redirect_uri.as_str()),
            ("grant_type", "authorization_code"),
            ("code", code.as_str()),
        ])
        .send()
        .await
        .and_then(reqwest::Response::error_for_status)
        .map_err(|_| "获取访问令牌失败")?
        .json::<Token>()
        .await
        // A body without `access_token` (e.g. an OAuth error object) fails here.
        .map_err(|_| "获取访问令牌失败")?;

    let account = client
        .get(format!("{}api/v1/accounts/verify_credentials", mast_url))
        .bearer_auth(&token.access_token)
        .send()
        .await
        .and_then(reqwest::Response::error_for_status)
        .map_err(|_| "获取账户信息失败")?
        .json::<Account>()
        .await
        .map_err(|_| "获取账户信息失败")?;

    let tk = User::find_or_create_token(
        &db,
        &rh.hash_with_salt(&format!("cs_{}", &account.id)),
        false,
    )
    .await
    .expect("failed to find or create user token");
    // NOTE(review): the token is appended with a literal `?`. This assumes
    // `jump_to_url` carries no query string of its own.
    Ok(Redirect::to(format!("{}?token={}", &jump_to_url, &tk)))
}
/// Starts the GitHub OAuth login flow, requesting the `user:email` scope.
///
/// The callback handed to GitHub is
/// `<backend>/_login/gh/auth?jump_to_url=<frontend>`.
///
/// # Panics
/// Panics if `GH_CLIENT` is unset, or if the frontend or backend address is
/// not an absolute URL.
#[get("/gh")]
pub fn gh_login(r: FrontendAddr, h: BackendAddr) -> Redirect {
    const AUTHORIZE_ENDPOINT: &str = "https://github.com/login/oauth/authorize";
    const SCOPE: &str = "user:email";

    let client_id = env::var("GH_CLIENT").unwrap();
    let frontend = Url::parse(&r.0).unwrap();

    let mut callback = Url::parse(&h.0).unwrap();
    callback.set_path("/_login/gh/auth");
    callback
        .query_pairs_mut()
        .append_pair("jump_to_url", frontend.as_str());

    let mut authorize = Url::parse(AUTHORIZE_ENDPOINT).unwrap();
    authorize
        .query_pairs_mut()
        .append_pair("redirect_uri", callback.as_str())
        .append_pair("client_id", &client_id)
        .append_pair("scope", SCOPE);
    Redirect::to(authorize.as_str().to_owned())
}
/// One element of the GitHub `user/emails` response; only the fields read
/// here are declared.
#[derive(Debug, Deserialize)]
#[serde(crate = "rocket::serde")]
struct GithubEmail {
    /// The email address.
    pub email: String,
    /// Whether GitHub reports the address as verified; unverified ones are ignored.
    pub verified: bool,
}
/// GitHub OAuth callback.
///
/// Flow:
/// 1. Exchanges `code` for an access token.
/// 2. Reads the user's email list.
/// 3. Logs in with the local part of the first verified address at
///    `mails.tsinghua.edu.cn` or `tsinghua.org.cn`.
/// 4. Redirects to `jump_to_url` with `?token=<site token>` appended.
///
/// # Errors
/// Returns a user-facing message when any of these happens:
/// - `jump_to_url` is not covered by `FRONTEND_WHITELIST`;
/// - a GitHub request fails, or returns a non-success status or an unexpected
///   body;
/// - no verified Tsinghua or alumni email is found.
///
/// # Panics
/// Panics if `GH_CLIENT`/`GH_SECRET` are unset, or if the token store fails.
#[get("/gh/auth?<code>&<jump_to_url>")]
pub async fn gh_auth(
    code: String,
    jump_to_url: String,
    db: Db,
    rh: &State<RandomHasher>,
) -> Result<Redirect, &'static str> {
    /// Whether `url` starts with an entry of the comma-separated
    /// `FRONTEND_WHITELIST`, at a URL boundary.
    fn frontend_whitelisted(url: &str) -> bool {
        env::var("FRONTEND_WHITELIST")
            .unwrap_or_default()
            .split(',')
            .map(str::trim)
            // An empty entry (unset variable, trailing comma) would be a prefix
            // of every URL and let the login token go anywhere.
            .filter(|prefix| !prefix.is_empty())
            .any(|prefix| match url.strip_prefix(prefix) {
                // Require a boundary after the prefix. Otherwise `https://a.com`
                // would also admit `https://a.com.evil.example` or
                // `https://a.com:x@evil.example`.
                Some(rest) => {
                    prefix.ends_with('/')
                        || rest.is_empty()
                        || rest.starts_with(|c: char| matches!(c, '/' | '?' | '#'))
                }
                None => false,
            })
    }

    if !frontend_whitelisted(&jump_to_url) {
        return Err("前端地址不在白名单内");
    }
    let gh_cli = env::var("GH_CLIENT").expect("GH_CLIENT must be set");
    let gh_sec = env::var("GH_SECRET").expect("GH_SECRET must be set");

    let client = reqwest::Client::new();
    // The access token is a live credential. It is deliberately never printed
    // or logged.
    let token = client
        .post("https://github.com/login/oauth/access_token")
        .header(reqwest::header::ACCEPT, "application/json")
        .form(&[
            ("client_id", gh_cli.as_str()),
            ("client_secret", gh_sec.as_str()),
            ("code", code.as_str()),
        ])
        .send()
        .await
        .and_then(reqwest::Response::error_for_status)
        .map_err(|_| "获取访问令牌失败")?
        .json::<Token>()
        .await
        // A body without `access_token` (e.g. an OAuth error object) fails here.
        .map_err(|_| "获取访问令牌失败")?;

    let emails = client
        .get("https://api.github.com/user/emails")
        .bearer_auth(&token.access_token)
        .header(reqwest::header::USER_AGENT, "hole_thu LoginBot")
        .send()
        .await
        .and_then(reqwest::Response::error_for_status)
        .map_err(|_| "获取邮箱列表失败")?
        .json::<Vec<GithubEmail>>()
        .await
        .map_err(|_| "获取邮箱列表失败")?;

    // Only addresses with exactly one '@' qualify: `split_once` leaves any
    // further '@' in the domain part, which then matches neither domain.
    let name = emails
        .iter()
        .filter(|email| email.verified)
        .find_map(|email| match email.email.split_once('@') {
            Some((name, "mails.tsinghua.edu.cn")) | Some((name, "tsinghua.org.cn")) => Some(name),
            _ => None,
        })
        .ok_or("没有找到已验证的清华邮箱/校友邮箱")?;

    let tk = User::find_or_create_token(&db, &rh.hash_with_salt(&format!("email_{}", name)), false)
        .await
        .expect("failed to find or create user token");
    // NOTE(review): the token is appended with a literal `?`. This assumes
    // `jump_to_url` carries no query string of its own.
    Ok(Redirect::to(format!("{}?token={}", &jump_to_url, &tk)))
}