diff --git a/.gitignore b/.gitignore index 0563df4..f399965 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# private key +keys + # --> sqlite3 *.db diff --git a/Cargo.toml b/Cargo.toml index 23d8cda..18abe7b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ license = "WTFPL-2.0" [features] default = ["mastlogin"] -mastlogin = ["url", "reqwest"] +mastlogin = ["reqwest"] [dependencies] rocket = { version = "=0.5.0-rc.1", features = ["json"] } @@ -26,6 +26,7 @@ dotenv = "0.15.0" sha2 = "0.10.2" log = "0.4.16" env_logger = "0.9.0" +web-push = "0.9.2" +url = "2.2.2" -url = { version="2.2.2",optional = true } reqwest = { version = "0.11.10", features = ["json"], optional = true } diff --git a/src/api/attention.rs b/src/api/attention.rs index 0ad9c50..4b04fc0 100644 --- a/src/api/attention.rs +++ b/src/api/attention.rs @@ -9,6 +9,13 @@ use crate::schema; use diesel::{ExpressionMethods, QueryDsl, RunQueryDsl}; use rocket::form::Form; use rocket::serde::json::json; +use rocket::serde::json::serde_json; +use rocket::serde::Serialize; +use std::fs::File; +use url::Url; +use web_push::{ + ContentEncoding, SubscriptionInfo, VapidSignatureBuilder, WebPushClient, WebPushMessageBuilder, +}; #[derive(FromForm)] pub struct AttentionInput { @@ -76,3 +83,57 @@ pub async fn get_attention(user: CurrentUser, db: Db, rconn: RdsConn) -> JsonApi code0!(ps_data) } + +#[derive(FromForm)] +pub struct NotificatinInput { + enable: bool, + endpoint: String, + auth: String, + p256dh: String, +} + +#[derive(Serialize)] +#[serde(crate = "rocket::serde")] +struct PushData { + title: String, + pid: i32, + text: String, +} + +#[post("/post//notification", data = "")] +pub async fn set_notification(pid: i32, ni: Form, _user: CurrentUser) -> JsonApi { + let url_host = Url::parse(&ni.endpoint) + .map_err(|_| UnknownPushEndpoint)? + .host() + .ok_or(UnknownPushEndpoint)? + .to_string(); + (url_host.ends_with("googleapis.com") || url_host.ends_with("mozilla.com")) + .then(|| ()) + .ok_or(UnknownPushEndpoint)?; + + if ni.enable { + let subscription_info = SubscriptionInfo::new(&ni.endpoint, &ni.p256dh, &ni.auth); + + let file = File::open("keys/private.pem").unwrap(); + let sig_builder = VapidSignatureBuilder::from_pem(file, &subscription_info) + .unwrap() + .build() + .unwrap(); + + let mut builder = WebPushMessageBuilder::new(&subscription_info).unwrap(); + let data = PushData { + title: "测试".to_owned(), + pid, + text: format!("#{} 开启提醒测试成功,消息提醒功能即将正式上线", &pid), + }; + let content = serde_json::to_string(&data).unwrap(); + builder.set_payload(ContentEncoding::Aes128Gcm, content.as_bytes()); + builder.set_vapid_signature(sig_builder); + + let client = WebPushClient::new()?; + + client.send(builder.build()?).await?; + } + + code0!() +} diff --git a/src/api/mod.rs b/src/api/mod.rs index d4e3f16..9647ae1 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -119,12 +119,14 @@ pub enum PolicyError { YouAreTmp, NoReason, OldApi, + UnknownPushEndpoint, } #[derive(Debug)] pub enum ApiError { Db(diesel::result::Error), Rds(redis::RedisError), + WebPush(web_push::WebPushError), Pc(PolicyError), IO(std::io::Error), } @@ -134,6 +136,7 @@ impl<'r> Responder<'r, 'static> for ApiError { match self { ApiError::Db(e) => e2s!(e).respond_to(req), ApiError::Rds(e) => e2s!(e).respond_to(req), + ApiError::WebPush(e) => e2s!(e).respond_to(req), ApiError::IO(e) => e2s!(e).respond_to(req), ApiError::Pc(e) => json!({ "code": -1, @@ -144,7 +147,8 @@ impl<'r> Responder<'r, 'static> for ApiError { PolicyError::TitleUsed => "头衔已被使用", PolicyError::YouAreTmp => "临时用户只可发布内容和进入单个洞", PolicyError::NoReason => "未填写理由", - PolicyError::OldApi => "请使用最新版前端地址并检查更新" + PolicyError::OldApi => "请使用最新版前端地址并检查更新", + PolicyError::UnknownPushEndpoint => "未知的浏览器推送地址", } }) .respond_to(req), @@ -152,6 +156,12 @@ impl<'r> Responder<'r, 'static> for ApiError { } } +impl From for ApiError { + fn from(err: web_push::WebPushError) -> ApiError { + ApiError::WebPush(err) + } +} + impl From for ApiError { fn from(err: diesel::result::Error) -> ApiError { ApiError::Db(err) diff --git a/src/main.rs b/src/main.rs index c8dbf32..439cab2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -90,6 +90,7 @@ async fn main() -> Result<(), rocket::Error> { api::comment::add_comment, api::upload::local_upload, cors::options_handler, + api::attention::set_notification, ], ) .mount(