Browse Source

重构

pull/7/head
hole-thu 4 years ago
parent
commit
3978af5d09
  1. 3
      config.sample.py
  2. 108
      hole.py
  3. 3
      models.py
  4. 15
      requirements.txt
  5. 21
      utils.py

3
config.sample.py

@ -1,5 +1,3 @@
import random
import string
import time import time
SQLALCHEMY_DATABASE_URI = 'sqlite:///hole.db' SQLALCHEMY_DATABASE_URI = 'sqlite:///hole.db'
@ -9,7 +7,6 @@ CLIENT_ID = '<id>'
CLIENT_SECRET = '<secret>' CLIENT_SECRET = '<secret>'
MASTODON_URL = 'https://mastodon.social' MASTODON_URL = 'https://mastodon.social'
REDIRECT_URI = 'http://hole.thu.monster/_auth' REDIRECT_URI = 'http://hole.thu.monster/_auth'
SALT = ''.join(random.choices(string.ascii_letters + string.digits, k=32))
ADMINS = ['cs_114514'] ADMINS = ['cs_114514']
START_TIME = int(time.time()) START_TIME = int(time.time())
ENABLE_TMP = True ENABLE_TMP = True

108
hole.py

@ -10,12 +10,15 @@ from sqlalchemy.sql.expression import func
from mastodon import Mastodon from mastodon import Mastodon
from models import db, User, Post, Comment, Attention, TagRecord, Syslog from models import db, User, Post, Comment, Attention, TagRecord, Syslog
from utils import get_current_user, map_post, map_comment, map_syslog, check_attention, hash_name, look, get_num, tmp_token, is_admin from utils import get_current_username, map_post, map_comment, map_syslog, check_attention, hash_name, look, get_num, tmp_token, is_admin, check_can_del
app = Flask(__name__) app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///hole.db' app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///hole.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['JSON_AS_ASCII'] = False app.config['JSON_AS_ASCII'] = False
app.config['SALT'] = ''.join(random.choices(
string.ascii_letters + string.digits, k=32
))
app.config.from_pyfile('config.py') app.config.from_pyfile('config.py')
db.init_app(app) db.init_app(app)
@ -86,7 +89,7 @@ def auth():
@app.route('/_api/v1/getlist') @app.route('/_api/v1/getlist')
def get_list(): def get_list():
u = get_current_user() username = get_current_username()
p = request.args.get('p', type=int, default=1) p = request.args.get('p', type=int, default=1)
order_mode = request.args.get('order_mode', type=int, default=0) order_mode = request.args.get('order_mode', type=int, default=0)
@ -111,7 +114,7 @@ def get_list():
posts = query.order_by(order).paginate(p, PER_PAGE) posts = query.order_by(order).paginate(p, PER_PAGE)
data = list(map(map_post, posts.items, [u.name] * len(posts.items))) data = list(map(map_post, posts.items, [username] * len(posts.items)))
return { return {
'code': 0, 'code': 0,
@ -123,15 +126,17 @@ def get_list():
@app.route('/_api/v1/getone') @app.route('/_api/v1/getone')
def get_one(): def get_one():
u = get_current_user() username = get_current_username()
pid = request.args.get('pid', type=int) pid = request.args.get('pid', type=int)
post = Post.query.get_or_404(pid) post = Post.query.get_or_404(pid)
if post.deleted or post.is_reported: if post.deleted or post.is_reported and not (
check_can_del(username, post.name_hash)
):
abort(451) abort(451)
data = map_post(post, u.name) data = map_post(post, username)
return { return {
'code': 0, 'code': 0,
@ -141,7 +146,7 @@ def get_one():
@app.route('/_api/v1/search') @app.route('/_api/v1/search')
def search(): def search():
u = get_current_user() username = get_current_username()
page = request.args.get('page', type=int, default=1) page = request.args.get('page', type=int, default=1)
pagesize = min(request.args.get('pagesize', type=int, default=200), 200) pagesize = min(request.args.get('pagesize', type=int, default=200), 200)
@ -155,7 +160,8 @@ def search():
tag=keywords tag=keywords
).all() ).all()
tag_pids = [tag_pid for tag_pid, in tag_pids] or [0] # sql not allowed empty in tag_pids = [
tag_pid for tag_pid, in tag_pids] or [0] # sql not allowed empty in
posts = Post.query.filter( posts = Post.query.filter(
Post.search_text.like("%{}%".format(keywords)) Post.search_text.like("%{}%".format(keywords))
@ -177,7 +183,7 @@ def search():
).all() + posts ).all() + posts
data = [ data = [
map_post(post, u.name) map_post(post, username)
for post in posts for post in posts
] ]
@ -191,27 +197,23 @@ def search():
@app.route('/_api/v1/dopost', methods=['POST']) @app.route('/_api/v1/dopost', methods=['POST'])
@limiter.limit("50 / hour; 1 / 3 second") @limiter.limit("50 / hour; 1 / 3 second")
def do_post(): def do_post():
u = get_current_user() username = get_current_username()
allow_search = request.form.get('allow_search') allow_search = request.form.get('allow_search')
print(allow_search) print(allow_search)
content = request.form.get('text') content = request.form.get('text', '').strip()
content = content.strip() if content else None content = ('[tmp]\n' if username[:4] == 'tmp_' else '') + content
content = '[tmp]\n' + content if u.name[:4] == 'tmp_' else content
post_type = request.form.get('type') post_type = request.form.get('type')
cw = request.form.get('cw') cw = request.form.get('cw', '').strip()
cw = cw.strip() if cw else None
if not content or len(content) > 4096: if not content or len(content) > 4096 or len(cw) > 32:
abort(422)
if cw and len(cw) > 32:
abort(422) abort(422)
search_text = content.replace( search_text = content.replace(
'\n', '') if allow_search else '' '\n', '') if allow_search else ''
p = Post( p = Post(
name_hash=hash_name(u.name), name_hash=hash_name(username),
content=content, content=content,
search_text=search_text, search_text=search_text,
post_type=post_type, post_type=post_type,
@ -238,7 +240,7 @@ def do_post():
if not re.match('\\d+', tag): if not re.match('\\d+', tag):
db.session.add(TagRecord(tag=tag, pid=p.id)) db.session.add(TagRecord(tag=tag, pid=p.id))
db.session.add(Attention(name_hash=hash_name(u.name), pid=p.id)) db.session.add(Attention(name_hash=hash_name(username), pid=p.id))
db.session.commit() db.session.commit()
return { return {
@ -250,7 +252,7 @@ def do_post():
@app.route('/_api/v1/editcw', methods=['POST']) @app.route('/_api/v1/editcw', methods=['POST'])
@limiter.limit("50 / hour; 1 / 2 second") @limiter.limit("50 / hour; 1 / 2 second")
def edit_cw(): def edit_cw():
u = get_current_user() username = get_current_username()
cw = request.form.get('cw') cw = request.form.get('cw')
pid = get_num(request.form.get('pid')) pid = get_num(request.form.get('pid'))
@ -260,11 +262,8 @@ def edit_cw():
abort(422) abort(422)
post = Post.query.get_or_404(pid) post = Post.query.get_or_404(pid)
if post.deleted:
abort(451)
if not (u.name in app.config.get('ADMINS') if not check_can_del(username, post.name_hash):
or hash_name(u.name) == post.name_hash):
abort(403) abort(403)
post.cw = cw post.cw = cw
@ -275,21 +274,19 @@ def edit_cw():
@app.route('/_api/v1/getcomment') @app.route('/_api/v1/getcomment')
def get_comment(): def get_comment():
u = get_current_user() username = get_current_username()
pid = get_num(request.args.get('pid')) pid = get_num(request.args.get('pid'))
post = Post.query.get(pid) post = Post.query.get_or_404(pid)
if not post: if post.deleted and not check_can_del(username, post.name_hash):
abort(404)
if post.deleted:
abort(451) abort(451)
data = map_comment(post, u.name) data = map_comment(post, username)
return { return {
'code': 0, 'code': 0,
'attention': check_attention(u.name, pid), 'attention': check_attention(username, pid),
'likenum': post.likenum, 'likenum': post.likenum,
'data': data 'data': data
} }
@ -298,24 +295,24 @@ def get_comment():
@app.route('/_api/v1/docomment', methods=['POST']) @app.route('/_api/v1/docomment', methods=['POST'])
@limiter.limit("50 / hour; 1 / 3 second") @limiter.limit("50 / hour; 1 / 3 second")
def do_comment(): def do_comment():
u = get_current_user() username = get_current_username()
pid = get_num(request.form.get('pid')) pid = get_num(request.form.get('pid'))
post = Post.query.get(pid) post = Post.query.get(pid)
if not post: if not post:
abort(404) abort(404)
if post.deleted: if post.deleted and not check_can_del(username, post.name_hash):
abort(451) abort(451)
content = request.form.get('text') content = request.form.get('text')
content = content.strip() if content else None content = content.strip() if content else None
content = '[tmp]\n' + content if u.name[:4] == 'tmp_' else content content = '[tmp]\n' + content if username[:4] == 'tmp_' else content
if not content or len(content) > 4096: if not content or len(content) > 4096:
abort(422) abort(422)
c = Comment( c = Comment(
name_hash=hash_name(u.name), name_hash=hash_name(username),
content=content, content=content,
) )
post.comments.append(c) post.comments.append(c)
@ -325,11 +322,11 @@ def do_comment():
post.hot_score += 1 post.hot_score += 1
at = Attention.query.filter_by( at = Attention.query.filter_by(
name_hash=hash_name(u.name), pid=pid name_hash=hash_name(username), pid=pid
).first() ).first()
if not at: if not at:
at = Attention(name_hash=hash_name(u.name), pid=pid, disabled=False) at = Attention(name_hash=hash_name(username), pid=pid, disabled=False)
db.session.add(at) db.session.add(at)
post.likenum += 1 post.likenum += 1
if post.hot_score != -1: if post.hot_score != -1:
@ -350,8 +347,8 @@ def do_comment():
@app.route('/_api/v1/attention', methods=['POST']) @app.route('/_api/v1/attention', methods=['POST'])
@limiter.limit("200 / hour; 1 / second") @limiter.limit("200 / hour; 1 / second")
def attention(): def attention():
u = get_current_user() username = get_current_username()
if u.name[:4] == 'tmp_': if username[:4] == 'tmp_':
abort(403) abort(403)
s = request.form.get('switch') s = request.form.get('switch')
@ -365,11 +362,11 @@ def attention():
abort(404) abort(404)
at = Attention.query.filter_by( at = Attention.query.filter_by(
name_hash=hash_name(u.name), pid=pid name_hash=hash_name(username), pid=pid
).first() ).first()
if not at: if not at:
at = Attention(name_hash=hash_name(u.name), pid=pid, disabled=True) at = Attention(name_hash=hash_name(username), pid=pid, disabled=True)
db.session.add(at) db.session.add(at)
if post.hot_score != -1: if post.hot_score != -1:
post.hot_score += 2 post.hot_score += 2
@ -389,12 +386,12 @@ def attention():
@app.route('/_api/v1/getattention') @app.route('/_api/v1/getattention')
def get_attention(): def get_attention():
u = get_current_user() username = get_current_username()
ats = Attention.query.with_entities( ats = Attention.query.with_entities(
Attention.pid Attention.pid
).filter_by( ).filter_by(
name_hash=hash_name(u.name), disabled=False name_hash=hash_name(username), disabled=False
).all() ).all()
pids = [pid for pid, in ats] or [0] # sql not allow empty in pids = [pid for pid, in ats] or [0] # sql not allow empty in
@ -405,7 +402,7 @@ def get_attention():
).order_by(Post.id.desc()).all() ).order_by(Post.id.desc()).all()
data = [ data = [
map_post(post, u.name, 10) map_post(post, username, 10)
for post in posts for post in posts
] ]
@ -419,7 +416,7 @@ def get_attention():
@app.route('/_api/v1/delete', methods=['POST']) @app.route('/_api/v1/delete', methods=['POST'])
@limiter.limit("50 / hour; 1 / 3 second") @limiter.limit("50 / hour; 1 / 3 second")
def delete(): def delete():
u = get_current_user() username = get_current_username()
obj_type = request.form.get('type') obj_type = request.form.get('type')
obj_id = get_num(request.form.get('id')) obj_id = get_num(request.form.get('id'))
@ -436,7 +433,7 @@ def delete():
if not obj: if not obj:
abort(404) abort(404)
if obj.name_hash == hash_name(u.name): if obj.name_hash == hash_name(username):
if obj_type == 'pid': if obj_type == 'pid':
if len(obj.comments): if len(obj.comments):
abort(403) abort(403)
@ -445,12 +442,12 @@ def delete():
db.session.delete(obj) db.session.delete(obj)
else: else:
obj.deleted = True obj.deleted = True
elif u.name in app.config.get('ADMINS'): elif username in app.config.get('ADMINS'):
obj.deleted = True obj.deleted = True
db.session.add(Syslog( db.session.add(Syslog(
log_type='ADMIN DELETE', log_type='ADMIN DELETE',
log_detail=f"{obj_type}={obj_id}\n{note}", log_detail=f"{obj_type}={obj_id}\n{note}",
name_hash=hash_name(u.name) name_hash=hash_name(username)
)) ))
if note.startswith('!ban'): if note.startswith('!ban'):
db.session.add(Syslog( db.session.add(Syslog(
@ -467,7 +464,7 @@ def delete():
@app.route('/_api/v1/systemlog') @app.route('/_api/v1/systemlog')
def system_log(): def system_log():
u = get_current_user() username = get_current_username()
ss = Syslog.query.order_by(db.desc('timestamp')).limit(100).all() ss = Syslog.query.order_by(db.desc('timestamp')).limit(100).all()
@ -475,14 +472,14 @@ def system_log():
'start_time': app.config['START_TIME'], 'start_time': app.config['START_TIME'],
'salt': look(app.config['SALT']), 'salt': look(app.config['SALT']),
'tmp_token': tmp_token(), 'tmp_token': tmp_token(),
'data': [map_syslog(s, u) for s in ss] 'data': [map_syslog(s, username) for s in ss]
} }
@app.route('/_api/v1/report', methods=['POST']) @app.route('/_api/v1/report', methods=['POST'])
@limiter.limit("10 / hour; 1 / 3 second") @limiter.limit("10 / hour; 1 / 3 second")
def report(): def report():
u = get_current_user() username = get_current_username()
pid = get_num(request.form.get('pid')) pid = get_num(request.form.get('pid'))
@ -491,7 +488,7 @@ def report():
db.session.add(Syslog( db.session.add(Syslog(
log_type='REPORT', log_type='REPORT',
log_detail=f"pid={pid}\n{reason}", log_detail=f"pid={pid}\n{reason}",
name_hash=hash_name(u.name) name_hash=hash_name(username)
)) ))
post = Post.query.get(pid) post = Post.query.get(pid)
@ -505,9 +502,8 @@ def report():
@app.route('/_api/v1/update_score', methods=['POST']) @app.route('/_api/v1/update_score', methods=['POST'])
def edit_hot_score(): def edit_hot_score():
u = get_current_user() username = get_current_username()
if not is_admin(u.name): if not is_admin(username):
print(u.name)
abort(403) abort(403)
pid = request.form.get('pid', type=int) pid = request.form.get('pid', type=int)

3
models.py

@ -26,7 +26,8 @@ class Post(db.Model):
deleted = db.Column(db.Boolean, default=False) deleted = db.Column(db.Boolean, default=False)
is_reported = db.Column(db.Boolean, default=False) is_reported = db.Column(db.Boolean, default=False)
comment_timestamp = db.Column(db.Integer, default=0, index=True) comment_timestamp = db.Column(db.Integer, default=0, index=True)
hot_score = db.Column(db.Integer, default=0, nullable=False, server_default="0") hot_score = db.Column(db.Integer, default=0,
nullable=False, server_default="0")
comments = db.relationship('Comment', backref='post', lazy=True) comments = db.relationship('Comment', backref='post', lazy=True)

15
requirements.txt

@ -1,7 +1,8 @@
Flask>=1.1.2 Flask
Flask-Limit>=1.0.2 Flask-Limit
Flask-Limiter>=1.3.1 Flask-Limiter
Flask-Login>=0.5.0 Flask-Login
Flask-Migrate>=2.5.3 Flask-Migrate
Flask-SQLAlchemy>=2.4.4 Flask-SQLAlchemy
Mastodon.py>=1.5.1 Mastodon.py
redis

21
utils.py

@ -2,6 +2,7 @@ import hashlib
import time import time
from flask import request, abort, current_app from flask import request, abort, current_app
from models import User, Attention, Syslog from models import User, Attention, Syslog
from config import ADMINS, ENABLE_TMP
def get_config(key): def get_config(key):
@ -9,7 +10,7 @@ def get_config(key):
def is_admin(name): def is_admin(name):
return name in get_config('ADMINS') return name in ADMINS
def tmp_token(): def tmp_token():
@ -18,27 +19,29 @@ def tmp_token():
)[5:21] )[5:21]
def get_current_user(): def get_current_username():
token = request.headers.get('User-Token') or request.args.get('user_token') token = request.headers.get('User-Token') or request.args.get('user_token')
if not token: if not token:
abort(401) abort(401)
if len(token.split('_')) == 2 and get_config('ENABLE_TMP'): if len(token.split('_')) == 2 and ENABLE_TMP:
tt, suf = token.split('_') tt, suf = token.split('_')
if tt != tmp_token(): if tt != tmp_token():
abort(401) abort(401)
return User(name='tmp_' + suf) return 'tmp_' + suf
u = User.query.filter_by(token=token).first() u = User.query.filter_by(token=token).first()
if not u or Syslog.query.filter_by( if not u or Syslog.query.filter_by(
log_type='BANNED', name_hash=hash_name(u.name)).first(): log_type='BANNED', name_hash=hash_name(u.name)).first():
abort(401) abort(401)
return u return u.name
def hash_name(name): def hash_name(name):
print(name)
return hashlib.sha256( return hashlib.sha256(
(get_config('SALT') + name).encode('utf-8')).hexdigest() (get_config('SALT') + name).encode('utf-8')
).hexdigest()
def map_post(p, name, mc=50): def map_post(p, name, mc=50):
@ -81,10 +84,10 @@ def map_comment(p, name):
] ]
def map_syslog(s, u=None): def map_syslog(s, username):
return { return {
'type': s.log_type, 'type': s.log_type,
'detail': s.log_detail if check_can_del(u.name, s.name_hash) else '', 'detail': s.log_detail if check_can_del(username, s.name_hash) else '',
'user': look(s.name_hash), 'user': look(s.name_hash),
'timestamp': s.timestamp 'timestamp': s.timestamp
} }
@ -99,7 +102,7 @@ def check_attention(name, pid):
def check_can_del(name, author_hash): def check_can_del(name, author_hash):
return int(hash_name(name) == author_hash or is_admin(name)) return hash_name(name) == author_hash or is_admin(name)
def look(s): def look(s):

Loading…
Cancel
Save