diff --git a/migrations/2022-03-15-061041_create_users/down.sql b/migrations/2022-03-15-061041_create_users/down.sql new file mode 100644 index 0000000..cf7e736 --- /dev/null +++ b/migrations/2022-03-15-061041_create_users/down.sql @@ -0,0 +1,2 @@ +-- This file should undo anything in `up.sql` +DROP TABLE users diff --git a/migrations/2022-03-15-061041_create_users/up.sql b/migrations/2022-03-15-061041_create_users/up.sql new file mode 100644 index 0000000..8a7be68 --- /dev/null +++ b/migrations/2022-03-15-061041_create_users/up.sql @@ -0,0 +1,9 @@ +-- Your SQL goes here + +CREATE TABLE users ( + id INTEGER NOT NULL PRIMARY KEY, + name VARCHAR NOT NULL UNIQUE, + token VARCHAR NOT NULL UNIQUE, + is_admin BOOLEAN NOT NULL DEFAULT FALSE +); +CREATE INDEX users_toekn_idx ON users (`token`); diff --git a/src/api/mod.rs b/src/api/mod.rs index ba11f30..5cc46a4 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,8 +1,9 @@ +use crate::models::*; use crate::random_hasher::RandomHasher; use rocket::http::Status; use rocket::request::{self, FromRequest, Request}; -use rocket::serde::json::{Value, json}; use rocket::response::{self, Responder}; +use rocket::serde::json::{json, Value}; #[catch(401)] pub fn catch_401_error() -> Value { @@ -21,31 +22,32 @@ pub struct CurrentUser { impl<'r> FromRequest<'r> for CurrentUser { type Error = (); async fn from_request(request: &'r Request<'_>) -> request::Outcome { - let token = request.headers().get_one("User-Token"); - match token { - Some(t) => request::Outcome::Success(CurrentUser { - namehash: request - .rocket() - .state::() - .unwrap() - .hash_with_salt(t), - is_admin: t == "admin", // TODO - }), - None => request::Outcome::Failure((Status::Unauthorized, ())), + if let Some(token) = request.headers().get_one("User-Token") { + let conn = establish_connection(); + if let Some(user) = User::get_by_token(&conn, token) { + return request::Outcome::Success(CurrentUser { + namehash: request + .rocket() + .state::() + .unwrap() + .hash_with_salt(&user.name), + is_admin: user.is_admin, + }); + } } + request::Outcome::Failure((Status::Unauthorized, ())) } } pub enum PolicyError { IsReported, IsDeleted, - NotAllowed + NotAllowed, } - pub enum APIError { DbError(diesel::result::Error), - PcError(PolicyError) + PcError(PolicyError), } impl APIError { @@ -57,20 +59,20 @@ impl APIError { impl<'r> Responder<'r, 'static> for APIError { fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { match self { - APIError::DbError(e) => - json!({ - "code": -1, - "msg": e.to_string() - }).respond_to(req), - APIError::PcError(e) => - json!({ - "code": -1, - "msg": match e { - PolicyError::IsReported => "内容被举报,处理中", - PolicyError::IsDeleted => "内容被删除", - PolicyError::NotAllowed => "不允许的操作", - } - }).respond_to(req), + APIError::DbError(e) => json!({ + "code": -1, + "msg": e.to_string() + }) + .respond_to(req), + APIError::PcError(e) => json!({ + "code": -1, + "msg": match e { + PolicyError::IsReported => "内容被举报,处理中", + PolicyError::IsDeleted => "内容被删除", + PolicyError::NotAllowed => "不允许的操作", + } + }) + .respond_to(req), } } } diff --git a/src/models.rs b/src/models.rs index 458f5f7..120205b 100644 --- a/src/models.rs +++ b/src/models.rs @@ -4,7 +4,7 @@ use chrono::NaiveDateTime; use diesel::{insert_into, Connection, ExpressionMethods, QueryDsl, RunQueryDsl, SqliteConnection}; use std::env; -use crate::schema::posts; +use crate::schema::*; type MR = Result; @@ -80,3 +80,17 @@ impl Post { insert_into(posts::table).values(&new_post).execute(conn) } } + +#[derive(Queryable, Debug)] +pub struct User { + pub id: i32, + pub name: String, + pub token: String, + pub is_admin: bool, +} + +impl User { + pub fn get_by_token(conn: &SqliteConnection, token: &str) -> Option { + users::table.filter(users::token.eq(token)).first(conn).ok() + } +} diff --git a/src/schema.rs b/src/schema.rs index 0e9cc3a..6277029 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -15,3 +15,17 @@ table! { allow_search -> Bool, } } + +table! { + users (id) { + id -> Integer, + name -> Text, + token -> Text, + is_admin -> Bool, + } +} + +allow_tables_to_appear_in_same_query!( + posts, + users, +); diff --git a/tools/migdb.py b/tools/migdb.py index c605717..4d6c98e 100644 --- a/tools/migdb.py +++ b/tools/migdb.py @@ -1,9 +1,13 @@ import sqlite3 from datetime import datetime -def mig_post(db_old, db_new): - c_old = db_old.cursor() - c_new = db_new.cursor() +db_old = sqlite3.connect('hole.db') +db_new = sqlite3.connect('hole_v2.db') +c_old = db_old.cursor() +c_new = db_new.cursor() + + +def mig_post(): rs = c_old.execute( 'SELECT id, name_hash, content, cw, author_title, ' 'likenum, n_comments, timestamp, comment_timestamp, ' @@ -25,13 +29,22 @@ def mig_post(db_old, db_new): ) db_new.commit() - c_old.close() - c_new.close() + +def mig_user(): + rs = c_old.execute('SELECT name, token FROM user') + + for r in rs: + c_new.execute( + 'INSERT OR REPLACE INTO users(name, token) VALUES(?, ?)', + r + ) + db_new.commit() if __name__ == '__main__': - db_old = sqlite3.connect('hole.db') - db_new = sqlite3.connect('hole_v2.db') + # mig_post() + mig_user() - mig_post(db_old, db_new) +c_old.close() +c_new.close()