diff --git a/.env.sample b/.env.sample index edeabfa..94c0f37 100644 --- a/.env.sample +++ b/.env.sample @@ -3,6 +3,7 @@ MAST_CLIENT="" MAST_SECRET="" MAST_SCOPE="read:accounts" +AUTH_BACKEND_URL="http://hole.localhost" FRONTEND_WHITELIST="https://hole-thu.github.io" DATABASE_URL="postgres://hole:hole_pass@localhost/hole_v2" diff --git a/src/cors.rs b/src/cors.rs new file mode 100644 index 0000000..d4d5dd4 --- /dev/null +++ b/src/cors.rs @@ -0,0 +1,30 @@ +use rocket::fairing::{Fairing, Info, Kind}; +use rocket::http::Header; +use rocket::{Request, Response}; + +pub struct CORS { + pub whitelist: Vec, +} + +#[rocket::async_trait] +impl Fairing for CORS { + fn info(&self) -> Info { + Info { + name: "Add CORS headers to responses", + kind: Kind::Response, + } + } + + async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) { + request + .headers() + .get_one("Origin") + .and_then(|origin| self.whitelist.contains(&origin.to_string()).then(|| origin)) + .and_then(|origin| { + response.set_header(Header::new("Access-Control-Allow-Origin", origin)); + response.set_header(Header::new("Access-Control-Allow-Methods", "POST, GET")); + response.set_header(Header::new("Access-Control-Allow-Credentials", "true")); + Some(()) + }); + } +} diff --git a/src/login.rs b/src/login.rs index 7e40225..204024d 100644 --- a/src/login.rs +++ b/src/login.rs @@ -25,15 +25,22 @@ pub fn cs_login(r: RefHeader) -> Redirect { let mast_cli = env::var("MAST_CLIENT").unwrap(); let mast_scope = env::var("MAST_SCOPE").unwrap(); - let mut redirect_url = Url::parse(&r.0).unwrap(); + let jump_to_url = Url::parse(&r.0).unwrap(); + + let mut redirect_url = env::var("AUTH_BACKEND_URL") + .map(|url| Url::parse(&url).unwrap()) + .unwrap_or_else(|_| jump_to_url.clone()); redirect_url.set_path("/_login/cs/auth"); - redirect_url.set_query(None); redirect_url = Url::parse_with_params( redirect_url.as_str(), - &[("redirect_url", redirect_url.as_str())], + &[ + ("redirect_url", redirect_url.as_str()), + ("jump_to_url", jump_to_url.as_str()), + ], ) .unwrap(); + let url = Url::parse_with_params( &format!("{}oauth/authorize", mast_url), &[ @@ -59,8 +66,8 @@ struct Token { struct Account { pub id: String, } -#[get("/cs/auth?&")] -pub async fn cs_auth(code: String, redirect_url: String, db: Db) -> Redirect { +#[get("/cs/auth?&&")] +pub async fn cs_auth(code: String, redirect_url: String, jump_to_url: String, db: Db) -> Redirect { let mast_url = env::var("MAST_BASE_URL").unwrap(); let mast_cli = env::var("MAST_CLIENT").unwrap(); let mast_sec = env::var("MAST_SECRET").unwrap(); @@ -69,12 +76,15 @@ pub async fn cs_auth(code: String, redirect_url: String, db: Db) -> Redirect { // to keep same let redirect_url = Url::parse_with_params( redirect_url.as_str(), - &[("redirect_url", redirect_url.as_str())], + &[ + ("redirect_url", redirect_url.as_str()), + ("jump_to_url", jump_to_url.as_str()), + ], ) .unwrap(); let client = reqwest::Client::new(); - let token: Token = client + let r = client .post(format!("{}oauth/token", &mast_url)) .form(&[ ("client_id", mast_cli.as_str()), @@ -86,11 +96,10 @@ pub async fn cs_auth(code: String, redirect_url: String, db: Db) -> Redirect { ]) .send() .await - .unwrap() - .json() - .await .unwrap(); + //dbg!(&r); + let token: Token = r.json().await.unwrap(); //dbg!(&token); let client = reqwest::Client::new(); @@ -110,5 +119,5 @@ pub async fn cs_auth(code: String, redirect_url: String, db: Db) -> Redirect { .await .unwrap(); - Redirect::to(format!("/?token={}", tk)) + Redirect::to(format!("{}?token={}", &jump_to_url, &tk)) }