Compare commits

...

14 Commits

  1. 4
      .gitignore
  2. 5
      clear_redis.py
  3. 10
      config.sample.py
  4. 8
      fix_n_comments.py
  5. 422
      hole.py
  6. 8
      hot_score_attenuation.py
  7. 22
      migration_search_table.py
  8. 1
      migrations/README
  9. 50
      migrations/alembic.ini
  10. 91
      migrations/env.py
  11. 24
      migrations/script.py.mako
  12. 30
      migrations/versions/4f4a8c914911_add_hot_score.py
  13. 28
      migrations/versions/865bf933ea82_add_n_comments.py
  14. 30
      migrations/versions/91e5c7d37d43_add_author_title.py
  15. 28
      migrations/versions/9ac8682d438c_add_bool_can_search.py
  16. 47
      models.py
  17. 15
      requirements.txt
  18. 103
      utils.py

4
.gitignore vendored

@ -1,4 +1,6 @@
/migrations/ /backup/
/.venv/
/libsimple
__pycache__/ __pycache__/
*.pyc *.pyc

5
clear_redis.py

@ -0,0 +1,5 @@
# Run on every reset.
from utils import rds, RDS_KEY_TITLE, RDS_KEY_BLOCKED_COUNT

# Delete the blocked-count key first, then the custom-title key.
for stale_key in (RDS_KEY_BLOCKED_COUNT, RDS_KEY_TITLE):
    rds.delete(stale_key)

10
config.sample.py

@ -1,5 +1,3 @@
import random
import string
import time import time
SQLALCHEMY_DATABASE_URI = 'sqlite:///hole.db' SQLALCHEMY_DATABASE_URI = 'sqlite:///hole.db'
@ -9,7 +7,13 @@ CLIENT_ID = '<id>'
CLIENT_SECRET = '<secret>' CLIENT_SECRET = '<secret>'
MASTODON_URL = 'https://mastodon.social' MASTODON_URL = 'https://mastodon.social'
REDIRECT_URI = 'http://hole.thu.monster/_auth' REDIRECT_URI = 'http://hole.thu.monster/_auth'
SALT = ''.join(random.choices(string.ascii_letters + string.digits, k=32))
ADMINS = ['cs_114514'] ADMINS = ['cs_114514']
START_TIME = int(time.time()) START_TIME = int(time.time())
ENABLE_TMP = True ENABLE_TMP = True
RDS_CONFIG = {
'host': 'localhost',
'port': 6379,
'decode_responses': True
}
SEARCH_DB = 'hole_search.db'
EXT_SIMPLE_URL = 'libsimple/libsimple'

8
fix_n_comments.py

@ -0,0 +1,8 @@
# Maintenance script: recompute Post.n_comments for every post from the
# comments that are not flagged as deleted, then commit the session.
# NOTE(review): indentation was lost in this view, so the exact nesting of the
# statements below (in particular whether commit() runs once after the loop or
# once per post) must be confirmed against the real file.
from hole import app
from models import Post, db
# The queries below run inside the Flask application context.
with app.app_context():
for post in Post.query:
# Count only comments whose `deleted` flag is falsy.
post.n_comments = len([c for c in post.comments if not c.deleted])
db.session.commit()

422
hole.py

@ -9,19 +9,21 @@ from flask_migrate import Migrate
from sqlalchemy.sql.expression import func from sqlalchemy.sql.expression import func
from mastodon import Mastodon from mastodon import Mastodon
from models import db, User, Post, Comment, Attention, TagRecord, Syslog from models import db, User, Post, Comment, Attention, TagRecord, Syslog, SearchDB
from utils import get_current_user, map_post, map_comment, map_syslog, check_attention, hash_name, look, get_num, tmp_token, is_admin from utils import get_current_username, map_post, map_comments, map_syslog, check_attention, hash_name, look, get_num, tmp_token, is_admin, check_can_del, rds, RDS_KEY_POLL_OPTS, RDS_KEY_POLL_VOTES, gen_poll_dict, name_with_tmp_limit, RDS_KEY_BLOCK_SET, RDS_KEY_BLOCKED_COUNT, RDS_KEY_DANGEROUS_USERS, RDS_KEY_TITLE
app = Flask(__name__) app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///hole.db' app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///hole.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['JSON_AS_ASCII'] = False app.config['JSON_AS_ASCII'] = False
app.config['SALT'] = ''.join(random.choices(
string.ascii_letters + string.digits, k=32
))
app.config.from_pyfile('config.py') app.config.from_pyfile('config.py')
db.init_app(app) db.init_app(app)
migrate = Migrate(app, db) migrate = Migrate(app, db)
CS_LOGIN_URL = Mastodon(api_base_url=app.config['MASTODON_URL']) \ CS_LOGIN_URL = Mastodon(api_base_url=app.config['MASTODON_URL']) \
.auth_request_url( .auth_request_url(
client_id=app.config['CLIENT_ID'], client_id=app.config['CLIENT_ID'],
@ -36,6 +38,22 @@ limiter = Limiter(
) )
PER_PAGE = 50 PER_PAGE = 50
DANGEROUS_USER_THRESHOLD = 10
class APIError(Exception):
    """Application-level error carrying a human-readable message.

    The message is stored on ``msg`` and is also what ``str(error)`` returns.

    Args:
        msg: Message describing the failure. When omitted, the class-level
            default message is used.
    """

    msg = '未知错误'

    def __init__(self, msg=None):
        # Fall back to the class default so that a bare ``APIError()`` works.
        if msg is None:
            msg = type(self).msg
        # Initialise the base class so that ``args`` (used by repr, logging
        # and pickling) carries the message as well.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return str(self.msg)
@app.errorhandler(APIError)
def handle_api_error(e):
    """Convert a raised ``APIError`` into the failure payload (``code`` 1)."""
    payload = {'code': 1}
    payload['msg'] = e.msg
    return payload
@app.route('/_login') @app.route('/_login')
@ -86,7 +104,7 @@ def auth():
@app.route('/_api/v1/getlist') @app.route('/_api/v1/getlist')
def get_list(): def get_list():
u = get_current_user() username = get_current_username()
p = request.args.get('p', type=int, default=1) p = request.args.get('p', type=int, default=1)
order_mode = request.args.get('order_mode', type=int, default=0) order_mode = request.args.get('order_mode', type=int, default=0)
@ -111,11 +129,12 @@ def get_list():
posts = query.order_by(order).paginate(p, PER_PAGE) posts = query.order_by(order).paginate(p, PER_PAGE)
data = list(map(map_post, posts.items, [u.name] * len(posts.items))) data = list(map(map_post, posts.items, [username] * len(posts.items)))
return { return {
'code': 0, 'code': 0,
'tmp_token': tmp_token(), 'tmp_token': tmp_token(),
'custom_title': rds.hget(RDS_KEY_TITLE, hash_name(username)),
'count': len(data), 'count': len(data),
'data': data 'data': data
} }
@ -123,15 +142,39 @@ def get_list():
@app.route('/_api/v1/getone') @app.route('/_api/v1/getone')
def get_one(): def get_one():
u = get_current_user() username = get_current_username()
pid = request.args.get('pid', type=int) pid = request.args.get('pid', type=int)
post = Post.query.get_or_404(pid) post = Post.query.get_or_404(pid)
if post.deleted or post.is_reported: if post.deleted or post.is_reported and not (
check_can_del(username, post.name_hash)
):
abort(451) abort(451)
data = map_post(post, u.name) data = map_post(post, username)
return {
'code': 0,
'data': data
}
@app.route('/_api/v1/getmulti')
def get_multi():
username = get_current_username()
pids = request.args.getlist('pids')
pids = pids[:500] or [0]
posts = Post.query.filter(
Post.id.in_(pids)
).filter_by(
deleted=False
).order_by(
Post.id.desc()
).all()
data = [map_post(post, username) for post in posts]
return { return {
'code': 0, 'code': 0,
@ -141,45 +184,77 @@ def get_one():
@app.route('/_api/v1/search') @app.route('/_api/v1/search')
def search(): def search():
u = get_current_user() username = get_current_username()
page = request.args.get('page', type=int, default=1) page = request.args.get('page', type=int, default=1)
pagesize = min(request.args.get('pagesize', type=int, default=200), 200) search_mode = request.args.get('search_mode', type=int)
keywords = request.args.get('keywords') pagesize = min(request.args.get('pagesize', type=int, default=PER_PAGE), 2 * PER_PAGE)
keywords = request.args.get('keywords', '').strip()
if not keywords: if not keywords:
abort(422) raise APIError("搜索词不可为空")
if search_mode is None:
raise APIError("请点击“强制检查更新”,更新网页到最新版")
tag_pids = TagRecord.query.with_entities( data = []
TagRecord.pid
).filter_by(
tag=keywords
).all()
tag_pids = [tag_pid for tag_pid, in tag_pids] or [0] # sql not allowed empty in if search_mode == 0: # tag 搜索
tag_pids = TagRecord.query.with_entities(
TagRecord.pid
).filter_by(tag=keywords).order_by(
TagRecord.pid.desc()
).limit(pagesize).offset((page - 1) * pagesize).all()
posts = Post.query.filter( tag_pids = [
Post.search_text.like("%{}%".format(keywords)) tag_pid for tag_pid, in tag_pids] or [0] # sql not allowed empty in
).filter(
Post.id.notin_(tag_pids)
).filter_by(
deleted=False, is_reported=False
).order_by(
Post.id.desc()
).limit(pagesize).offset((page - 1) * pagesize).all()
if page == 1: posts = Post.query.filter(Post.id.in_(tag_pids)).filter_by(
posts = Post.query.filter( deleted=False, is_reported=False
Post.id.in_(tag_pids) ).order_by(Post.id.desc()).all()
).filter_by(
data = [
map_post(post, username)
for post in posts
]
elif search_mode == 1: # 全文搜索
search_db = SearchDB()
for highlighted_content, obj_type, obj_id in search_db.query(
keywords, pagesize, (page - 1) * pagesize
):
if obj_type == 'post':
obj = Post.query.get(obj_id)
else:
obj = Comment.query.get(obj_id)
if not obj or obj.deleted:
continue
if obj_type == 'post':
post = obj
else:
post = obj.post
if not post or post.deleted or post.is_reported:
continue
obj.content = highlighted_content
if obj_type == 'post':
post_dict = map_post(post, username)
else:
post_dict = map_post(post, username, 1000)
post_dict['comments'] = [
c for c in post_dict['comments'] if c['cid'] == obj_id
]
post_dict['key'] = "search_%s_%s" % (obj_type, obj_id)
data.append(post_dict)
del search_db
elif search_mode == 2: # 头衔
posts = Post.query.filter_by(author_title=keywords).filter_by(
deleted=False, is_reported=False deleted=False, is_reported=False
).order_by( ).order_by(
Post.id.desc() Post.id.desc()
).all() + posts ).limit(pagesize).offset((page - 1) * pagesize).all()
data = [
data = [ map_post(post, username)
map_post(post, u.name) for post in posts
for post in posts ]
]
return { return {
'code': 0, 'code': 0,
@ -191,56 +266,61 @@ def search():
@app.route('/_api/v1/dopost', methods=['POST']) @app.route('/_api/v1/dopost', methods=['POST'])
@limiter.limit("50 / hour; 1 / 3 second") @limiter.limit("50 / hour; 1 / 3 second")
def do_post(): def do_post():
u = get_current_user() username = get_current_username()
allow_search = request.form.get('allow_search') allow_search = request.form.get('allow_search')
print(allow_search) content = request.form.get('text', '').strip()
content = request.form.get('text') content = ('[tmp]\n' if username[:4] == 'tmp_' else '') + content
content = content.strip() if content else None
content = '[tmp]\n' + content if u.name[:4] == 'tmp_' else content
post_type = request.form.get('type') post_type = request.form.get('type')
cw = request.form.get('cw') cw = request.form.get('cw', '').strip()
cw = cw.strip() if cw else None poll_options = request.form.getlist('poll_options')
use_title = request.form.get('use_title')
if not content or len(content) > 4096:
abort(422) if not content or len(content) > 4096 or len(cw) > 32:
if cw and len(cw) > 32: raise APIError('无内容或超长')
abort(422)
if poll_options and poll_options[0]:
search_text = content.replace( if len(poll_options) != len(set(poll_options)):
'\n', '') if allow_search else '' raise APIError('有重复的投票选项')
if len(poll_options) > 8:
raise APIError('选项过多')
if max(map(len, poll_options)) > 32:
raise APIError('选项过长')
name_hash = hash_name(username)
p = Post( p = Post(
name_hash=hash_name(u.name), name_hash=name_hash,
author_title=rds.hget(RDS_KEY_TITLE, name_hash) if use_title else None,
content=content, content=content,
search_text=search_text, allow_search=bool(allow_search),
post_type=post_type, post_type=post_type,
cw=cw or None, cw=cw or None,
likenum=1, likenum=1,
comments=[] comments=[]
) )
if post_type == 'text':
pass
elif post_type == 'image':
# TODO
p.file_url = 'foo bar'
else:
abort(422)
db.session.add(p) db.session.add(p)
db.session.commit() db.session.commit()
tags = re.findall('(^|\\s)#([^#\\s]{1,32})', content) tags = re.findall('(^|\\s)#([^#\\s]{1,32})', content)
# print(tags)
for t in tags: for t in tags:
tag = t[1] tag = t[1]
if not re.match('\\d+', tag): if not re.match('\\d+', tag):
db.session.add(TagRecord(tag=tag, pid=p.id)) db.session.add(TagRecord(tag=tag, pid=p.id))
db.session.add(Attention(name_hash=hash_name(u.name), pid=p.id)) db.session.add(Attention(name_hash=hash_name(username), pid=p.id))
db.session.commit() db.session.commit()
if allow_search:
search_db = SearchDB()
search_db.insert(content, 'post', p.id)
search_db.commit()
del search_db
rds.delete(RDS_KEY_POLL_OPTS % p.id) # 由于历史原因,现在的数据库里发布后删再发布可能导致id重复
if poll_options and poll_options[0]:
rds.rpush(RDS_KEY_POLL_OPTS % p.id, *poll_options)
return { return {
'code': 0, 'code': 0,
'date': p.id 'date': p.id
@ -250,7 +330,7 @@ def do_post():
@app.route('/_api/v1/editcw', methods=['POST']) @app.route('/_api/v1/editcw', methods=['POST'])
@limiter.limit("50 / hour; 1 / 2 second") @limiter.limit("50 / hour; 1 / 2 second")
def edit_cw(): def edit_cw():
u = get_current_user() username = get_current_username()
cw = request.form.get('cw') cw = request.form.get('cw')
pid = get_num(request.form.get('pid')) pid = get_num(request.form.get('pid'))
@ -260,11 +340,8 @@ def edit_cw():
abort(422) abort(422)
post = Post.query.get_or_404(pid) post = Post.query.get_or_404(pid)
if post.deleted:
abort(451)
if not (u.name in app.config.get('ADMINS') if not check_can_del(username, post.name_hash):
or hash_name(u.name) == post.name_hash):
abort(403) abort(403)
post.cw = cw post.cw = cw
@ -275,21 +352,19 @@ def edit_cw():
@app.route('/_api/v1/getcomment') @app.route('/_api/v1/getcomment')
def get_comment(): def get_comment():
u = get_current_user() username = get_current_username()
pid = get_num(request.args.get('pid')) pid = get_num(request.args.get('pid'))
post = Post.query.get(pid) post = Post.query.get_or_404(pid)
if not post: if post.deleted and not check_can_del(username, post.name_hash):
abort(404)
if post.deleted:
abort(451) abort(451)
data = map_comment(post, u.name) data = map_comments(post, username)
return { return {
'code': 0, 'code': 0,
'attention': check_attention(u.name, pid), 'attention': check_attention(username, pid),
'likenum': post.likenum, 'likenum': post.likenum,
'data': data 'data': data
} }
@ -298,38 +373,43 @@ def get_comment():
@app.route('/_api/v1/docomment', methods=['POST']) @app.route('/_api/v1/docomment', methods=['POST'])
@limiter.limit("50 / hour; 1 / 3 second") @limiter.limit("50 / hour; 1 / 3 second")
def do_comment(): def do_comment():
u = get_current_user() username = get_current_username()
pid = get_num(request.form.get('pid')) pid = get_num(request.form.get('pid'))
post = Post.query.get(pid) post = Post.query.get(pid)
if not post: if not post:
abort(404) abort(404)
if post.deleted: if post.deleted and not check_can_del(username, post.name_hash):
abort(451) abort(451)
content = request.form.get('text') content = request.form.get('text', '').strip()
content = content.strip() if content else None if username.startswith('tmp_'):
content = '[tmp]\n' + content if u.name[:4] == 'tmp_' else content content = '[tmp]\n' + content
if not content or len(content) > 4096: if not content or len(content) > 4096:
abort(422) abort(422)
use_title = request.form.get('use_title')
name_hash = hash_name(username)
c = Comment( c = Comment(
name_hash=hash_name(u.name), name_hash=name_hash,
author_title=rds.hget(RDS_KEY_TITLE, name_hash) if use_title else None,
content=content, content=content,
) )
post.comments.append(c) post.comments.append(c)
post.comment_timestamp = c.timestamp post.comment_timestamp = c.timestamp
post.n_comments += 1
if post.hot_score != -1: if post.hot_score != -1:
post.hot_score += 1 post.hot_score += 1
at = Attention.query.filter_by( at = Attention.query.filter_by(
name_hash=hash_name(u.name), pid=pid name_hash=hash_name(username), pid=pid
).first() ).first()
if not at: if not at:
at = Attention(name_hash=hash_name(u.name), pid=pid, disabled=False) at = Attention(name_hash=hash_name(username), pid=pid, disabled=False)
db.session.add(at) db.session.add(at)
post.likenum += 1 post.likenum += 1
if post.hot_score != -1: if post.hot_score != -1:
@ -341,6 +421,12 @@ def do_comment():
db.session.commit() db.session.commit()
if post.allow_search:
search_db = SearchDB()
search_db.insert(content, 'comment', c.id)
search_db.commit()
del search_db
return { return {
'code': 0, 'code': 0,
'data': pid 'data': pid
@ -350,51 +436,52 @@ def do_comment():
@app.route('/_api/v1/attention', methods=['POST']) @app.route('/_api/v1/attention', methods=['POST'])
@limiter.limit("200 / hour; 1 / second") @limiter.limit("200 / hour; 1 / second")
def attention(): def attention():
u = get_current_user() username = get_current_username()
if u.name[:4] == 'tmp_': if username[:4] == 'tmp_':
abort(403) raise APIError('临时用户无法手动关注')
s = request.form.get('switch') s = request.form.get('switch', type=int)
if s not in ['0', '1']: if s not in [0, 1]:
abort(422) abort(422)
pid = get_num(request.form.get('pid')) pid = request.form.get('pid', type=int)
post = Post.query.get(pid) post = Post.query.get_or_404(pid)
if not post:
abort(404)
at = Attention.query.filter_by( at = Attention.query.filter_by(
name_hash=hash_name(u.name), pid=pid name_hash=hash_name(username), pid=pid
).first() ).first()
if not at: if not at:
at = Attention(name_hash=hash_name(u.name), pid=pid, disabled=True) at = Attention(name_hash=hash_name(username), pid=pid, disabled=True)
db.session.add(at) db.session.add(at)
if post.hot_score != -1: if post.hot_score != -1:
post.hot_score += 2 post.hot_score += 2
if(at.disabled != (s == '0')): if at.disabled == bool(s):
at.disabled = (s == '0') at.disabled = not bool(s)
post.likenum += 1 - 2 * int(s == '0') post.likenum += 2 * s - 1
if is_admin(username) and s:
post.is_reported = False
db.session.commit() db.session.commit()
return { return {
'code': 0, 'code': 0,
'likenum': post.likenum, 'likenum': post.likenum,
'attention': (s == '1') 'attention': bool(s)
} }
@app.route('/_api/v1/getattention') @app.route('/_api/v1/getattention')
def get_attention(): def get_attention():
u = get_current_user() username = get_current_username()
ats = Attention.query.with_entities( ats = Attention.query.with_entities(
Attention.pid Attention.pid
).filter_by( ).filter_by(
name_hash=hash_name(u.name), disabled=False name_hash=hash_name(username), disabled=False
).all() ).all()
pids = [pid for pid, in ats] or [0] # sql not allow empty in pids = [pid for pid, in ats] or [0] # sql not allow empty in
@ -405,7 +492,7 @@ def get_attention():
).order_by(Post.id.desc()).all() ).order_by(Post.id.desc()).all()
data = [ data = [
map_post(post, u.name, 10) map_post(post, username, 10)
for post in posts for post in posts
] ]
@ -419,7 +506,7 @@ def get_attention():
@app.route('/_api/v1/delete', methods=['POST']) @app.route('/_api/v1/delete', methods=['POST'])
@limiter.limit("50 / hour; 1 / 3 second") @limiter.limit("50 / hour; 1 / 3 second")
def delete(): def delete():
u = get_current_user() username = get_current_username()
obj_type = request.form.get('type') obj_type = request.form.get('type')
obj_id = get_num(request.form.get('id')) obj_id = get_num(request.form.get('id'))
@ -428,29 +515,36 @@ def delete():
if note and len(note) > 100: if note and len(note) > 100:
abort(422) abort(422)
obj = None # 兼容
if obj_type == 'pid': if obj_type == 'pid':
obj = Post.query.get(obj_id) obj_type = 'post'
elif obj_type == 'cid': elif obj_type == 'cid':
obj_type = 'comment'
obj = None
if obj_type == 'post':
obj = Post.query.get(obj_id)
elif obj_type == 'comment':
obj = Comment.query.get(obj_id) obj = Comment.query.get(obj_id)
if not obj: if not obj:
abort(404) abort(404)
if obj.name_hash == hash_name(u.name): if obj.name_hash == hash_name(username):
if obj_type == 'pid': if obj_type == 'post':
if len(obj.comments): if obj.n_comments:
abort(403) abort("已经有评论了")
Attention.query.filter_by(pid=obj.id).delete() Attention.query.filter_by(pid=obj.id).delete()
TagRecord.query.filter_by(pid=obj.id).delete() TagRecord.query.filter_by(pid=obj.id).delete()
db.session.delete(obj) db.session.delete(obj)
else: else:
obj.deleted = True obj.deleted = True
elif u.name in app.config.get('ADMINS'):
elif username in app.config.get('ADMINS'):
obj.deleted = True obj.deleted = True
db.session.add(Syslog( db.session.add(Syslog(
log_type='ADMIN DELETE', log_type='ADMIN DELETE',
log_detail=f"{obj_type}={obj_id}\n{note}", log_detail=f"{obj_type}={obj_id}\n{note}",
name_hash=hash_name(u.name) name_hash=hash_name(username)
)) ))
if note.startswith('!ban'): if note.startswith('!ban'):
db.session.add(Syslog( db.session.add(Syslog(
@ -461,13 +555,16 @@ def delete():
else: else:
abort(403) abort(403)
if obj_type == 'comment':
obj.post.n_comments -= 1
db.session.commit() db.session.commit()
return {'code': 0} return {'code': 0}
@app.route('/_api/v1/systemlog') @app.route('/_api/v1/systemlog')
def system_log(): def system_log():
u = get_current_user() username = get_current_username()
ss = Syslog.query.order_by(db.desc('timestamp')).limit(100).all() ss = Syslog.query.order_by(db.desc('timestamp')).limit(100).all()
@ -475,23 +572,22 @@ def system_log():
'start_time': app.config['START_TIME'], 'start_time': app.config['START_TIME'],
'salt': look(app.config['SALT']), 'salt': look(app.config['SALT']),
'tmp_token': tmp_token(), 'tmp_token': tmp_token(),
'data': [map_syslog(s, u) for s in ss] 'custom_title': rds.hget(RDS_KEY_TITLE, hash_name(username)),
'data': [map_syslog(s, username) for s in ss]
} }
@app.route('/_api/v1/report', methods=['POST']) @app.route('/_api/v1/report', methods=['POST'])
@limiter.limit("10 / hour; 1 / 3 second") @limiter.limit("10 / hour; 1 / 3 second")
def report(): def report():
u = get_current_user() username = get_current_username()
pid = get_num(request.form.get('pid')) pid = get_num(request.form.get('pid'))
reason = request.form.get('reason', '') reason = request.form.get('reason', '')
db.session.add(Syslog( db.session.add(Syslog(
log_type='REPORT', log_type='REPORT',
log_detail=f"pid={pid}\n{reason}", log_detail=f"pid={pid}\n{reason}",
name_hash=hash_name(u.name) name_hash=hash_name(username)
)) ))
post = Post.query.get(pid) post = Post.query.get(pid)
@ -505,9 +601,8 @@ def report():
@app.route('/_api/v1/update_score', methods=['POST']) @app.route('/_api/v1/update_score', methods=['POST'])
def edit_hot_score(): def edit_hot_score():
u = get_current_user() username = get_current_username()
if not is_admin(u.name): if not is_admin(username):
print(u.name)
abort(403) abort(403)
pid = request.form.get('pid', type=int) pid = request.form.get('pid', type=int)
@ -520,5 +615,90 @@ def edit_hot_score():
return {'code': 0} return {'code': 0}
@app.route('/_api/v1/vote', methods=['POST'])
@limiter.limit("100 / hour; 1 / 2 second")
def add_vote():
    """Record the current user's vote on the poll attached to a post.

    Reads ``pid`` and ``vote`` from the form. Aborts with 404 when no poll
    options are stored for ``pid``; raises ``APIError`` when the user's hash is
    already present in any option's vote set or when ``vote`` is not one of
    the stored options. On success returns ``code`` 0 with the poll dict.
    """
    username = name_with_tmp_limit(get_current_username())

    pid = request.form.get('pid', type=int)
    vote = request.form.get('vote')

    opts_key = RDS_KEY_POLL_OPTS % pid
    if not rds.exists(opts_key):
        abort(404)

    opts = rds.lrange(opts_key, 0, -1)
    voter = hash_name(username)

    # One vote per user per poll: look for the voter in every option's set.
    already_voted = any(
        rds.sismember(RDS_KEY_POLL_VOTES % (pid, idx), voter)
        for idx, _ in enumerate(opts)
    )
    if already_voted:
        raise APIError('已经投过票了')

    if vote not in opts:
        raise APIError('无效的选项')

    chosen_idx = opts.index(vote)
    rds.sadd(RDS_KEY_POLL_VOTES % (pid, chosen_idx), voter)

    return {
        'code': 0,
        'data': gen_poll_dict(pid, username)
    }
@app.route('/_api/v1/block', methods=['POST'])
@limiter.limit("15 / hour; 1 / 2 second")
def block_user_by_target():
    """Block the author of a post or comment on behalf of the current user.

    Reads ``type`` ('post' or 'comment') and ``id`` from the form.

    * Admins add the author straight to the dangerous-users set; the stored
      blocked count is reported but not incremented.
    * Other users add the author to their own block set, increment the
      author's blocked count, and the author joins the dangerous-users set
      once the count reaches ``DANGEROUS_USER_THRESHOLD``.

    Returns:
        ``code`` 0 with ``data.curr`` (always an int) and ``data.threshold``.

    Raises:
        APIError: for temporary users, an invalid ``type``, blocking oneself,
            or blocking an author that is already blocked.
    """
    username = get_current_username()
    target_type = request.form.get('type')
    target_id = request.form.get('id', type=int)

    if username.startswith('tmp_'):
        raise APIError('临时用户无法拉黑')

    if target_type == 'post':
        target = Post.query.get_or_404(target_id)
    elif target_type == 'comment':
        target = Comment.query.get_or_404(target_id)
    else:
        raise APIError('无效的type')

    if hash_name(username) == target.name_hash:
        raise APIError('不可拉黑自己')

    if is_admin(username):
        rds.sadd(RDS_KEY_DANGEROUS_USERS, target.name_hash)
        # HGET yields the stored text (or None when the author has never been
        # blocked), whereas HINCRBY below yields an int. Normalise so `curr`
        # has one type regardless of who performs the block.
        stored = rds.hget(RDS_KEY_BLOCKED_COUNT, target.name_hash)
        curr_cnt = int(stored or 0)
    else:
        block_set_key = RDS_KEY_BLOCK_SET % username
        if rds.sismember(block_set_key, target.name_hash):
            raise APIError('已经拉黑了')
        rds.sadd(block_set_key, target.name_hash)
        curr_cnt = rds.hincrby(RDS_KEY_BLOCKED_COUNT, target.name_hash, 1)
        if curr_cnt >= DANGEROUS_USER_THRESHOLD:
            rds.sadd(RDS_KEY_DANGEROUS_USERS, target.name_hash)

    return {
        'code': 0,
        'data': {
            'curr': curr_cnt,
            'threshold': DANGEROUS_USER_THRESHOLD
        }
    }
@app.route('/_api/v1/title', methods=['POST'])
@limiter.limit("10 / hour; 1 / 2 second")
def set_title():
    """Set or clear the current user's custom title.

    An empty or missing ``title`` form field removes the stored title.
    Otherwise the title must be at most 10 characters long and must not
    already appear among the stored titles; violations raise ``APIError``.
    """
    username = get_current_username()
    title = request.form.get('title')

    if not title:
        rds.hdel(RDS_KEY_TITLE, hash_name(username))
        return {'code': 0}

    if len(title) > 10:
        raise APIError('自定义头衔太长')

    # If volume grows in the future, maintain a separate set for this check.
    taken_titles = rds.hvals(RDS_KEY_TITLE)
    if title in taken_titles:
        raise APIError('已经被使用了')

    rds.hset(RDS_KEY_TITLE, hash_name(username), title)
    return {'code': 0}
if __name__ == '__main__': if __name__ == '__main__':
app.run(debug=True) app.run(debug=True)

8
hot_score_attenuation.py

@ -1,9 +1,13 @@
import time
from hole import app from hole import app
from models import Post, db from models import Post, db
with app.app_context(): with app.app_context():
for p in Post.query.filter( for p in Post.query.filter(
Post.hot_score > 0 Post.hot_score > 10
).all(): ).all():
p.hot_score = int(p.hot_score * 0.9) if time.time() - p.timestamp > 60 * 60 * 24 * 3:
p.hot_score = 10
else:
p.hot_score = int(p.hot_score * 0.9)
db.session.commit() db.session.commit()

22
migration_search_table.py

@ -0,0 +1,22 @@
# Script: recreate the FTS5 `search_content` table in the search database and
# fill it from posts and comments, setting Post.allow_search along the way.
# NOTE(review): indentation was lost in this view, so the nesting below
# (notably which statement the `else:` pairs with, and which statements sit
# inside the `with` block) must be confirmed against the real file.
from hole import app
from models import SearchDB, Post, db
search_db = SearchDB()
# Drop any existing table before creating it again.
search_db.execute("DROP TABLE IF EXISTS search_content;")
# Only `content` is indexed; target_type and target_id are declared UNINDEXED.
# The 'simple' tokenizer is presumably supplied by the extension SearchDB
# loads -- TODO confirm.
search_db.execute("CREATE VIRTUAL TABLE search_content "
"USING fts5(content, target_type UNINDEXED, target_id UNINDEXED, tokenize = 'simple');")
with app.app_context():
# Only posts not flagged as deleted are considered.
for post in Post.query.filter_by(deleted=False):
# A non-empty legacy `search_text` leads to the post being inserted and
# marked as searchable.
if post.search_text:
search_db.insert(post.search_text, 'post', post.id)
post.allow_search = True
# Comments that are not flagged as deleted are inserted as well.
for comment in post.comments:
if not comment.deleted:
search_db.insert(comment.content, 'comment', comment.id)
else:
post.allow_search = False
# Persist the search rows, release the search connection, then persist the
# allow_search flags in the main database.
search_db.commit()
del search_db
db.session.commit()

1
migrations/README

@ -0,0 +1 @@
Single-database configuration for Flask.

50
migrations/alembic.ini

@ -0,0 +1,50 @@
# A generic, single database configuration.
[alembic]
# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic,flask_migrate
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[logger_flask_migrate]
level = INFO
handlers =
qualname = flask_migrate
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

91
migrations/env.py

@ -0,0 +1,91 @@
from __future__ import with_statement
import logging
from logging.config import fileConfig
from flask import current_app
from alembic import context
# Alembic Config object: gives access to the values within the .ini file
# that is in use.
config = context.config

# Set up Python logging from that config file, then fetch this module's
# logger.
fileConfig(config.config_file_name)
logger = logging.getLogger('alembic.env')

# Both the database URL and the MetaData used for 'autogenerate' support come
# from the SQLAlchemy handle held by the Flask-Migrate extension.
_migrate_db = current_app.extensions['migrate'].db
# '%' is doubled before the URL is stored -- presumably so the config
# parser's interpolation leaves it intact.
_escaped_url = str(_migrate_db.get_engine().url).replace('%', '%%')
config.set_main_option('sqlalchemy.url', _escaped_url)
target_metadata = _migrate_db.metadata

# Other values from the config, defined by the needs of env.py, can be
# acquired the same way, e.g.:
# my_important_option = config.get_main_option("my_important_option")
def run_migrations_offline():
    """Run migrations in 'offline' mode.

    The context is configured with just the database URL rather than an
    Engine, so no Engine is created here. Calls to ``context.execute()``
    emit the given string to the script output.
    """
    offline_settings = {
        'url': config.get_main_option("sqlalchemy.url"),
        'target_metadata': target_metadata,
        'literal_binds': True,
    }
    context.configure(**offline_settings)

    with context.begin_transaction():
        context.run_migrations()
def run_migrations_online():
    """Run migrations in 'online' mode.

    An Engine is obtained from the Flask-Migrate extension and a live
    connection is associated with the Alembic context before the migrations
    run inside a transaction.
    """
    migrate_ext = current_app.extensions['migrate']

    def process_revision_directives(context, revision, directives):
        # Prevents an auto-migration from being generated when there are no
        # changes to the schema.
        # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html
        if not getattr(config.cmd_opts, 'autogenerate', False):
            return
        if not directives[0].upgrade_ops.is_empty():
            return
        directives[:] = []
        logger.info('No changes in schema detected.')

    engine = migrate_ext.db.get_engine()

    with engine.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            process_revision_directives=process_revision_directives,
            **migrate_ext.configure_args
        )

        with context.begin_transaction():
            context.run_migrations()
# Choose the runner that matches the mode Alembic reports, then invoke it.
_selected_runner = (
    run_migrations_offline
    if context.is_offline_mode()
    else run_migrations_online
)
_selected_runner()

24
migrations/script.py.mako

@ -0,0 +1,24 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade():
${upgrades if upgrades else "pass"}
def downgrade():
${downgrades if downgrades else "pass"}

30
migrations/versions/4f4a8c914911_add_hot_score.py

@ -0,0 +1,30 @@
"""add hot score
Revision ID: 4f4a8c914911
Revises:
Create Date: 2021-12-18 03:37:13.716502
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4f4a8c914911'
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
    """Add ``post.hot_score`` and create the ``comment_timestamp`` index."""
    hot_score = sa.Column(
        'hot_score', sa.Integer(), server_default='0', nullable=False
    )
    op.add_column('post', hot_score)

    index_name = op.f('ix_post_comment_timestamp')
    op.create_index(index_name, 'post', ['comment_timestamp'], unique=False)
def downgrade():
    """Drop the ``comment_timestamp`` index, then the ``hot_score`` column."""
    index_name = op.f('ix_post_comment_timestamp')
    op.drop_index(index_name, table_name='post')

    table, column = 'post', 'hot_score'
    op.drop_column(table, column)

28
migrations/versions/865bf933ea82_add_n_comments.py

@ -0,0 +1,28 @@
"""add n_comments
Revision ID: 865bf933ea82
Revises: 9ac8682d438c
Create Date: 2021-12-24 20:21:53.928842
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '865bf933ea82'
down_revision = '9ac8682d438c'
branch_labels = None
depends_on = None
def upgrade():
    """Add the ``post.n_comments`` counter column.

    The column stays nullable, but a server-side default of 0 is supplied so
    that rows which exist when the migration runs read as 0 instead of NULL;
    in-place arithmetic such as ``post.n_comments += 1`` fails on NULL. This
    mirrors how the ``hot_score`` migration declares its default.
    """
    n_comments = sa.Column(
        'n_comments', sa.Integer(), server_default='0', nullable=True
    )
    op.add_column('post', n_comments)
def downgrade():
    """Remove the ``n_comments`` column from the ``post`` table."""
    table, column = 'post', 'n_comments'
    op.drop_column(table, column)

30
migrations/versions/91e5c7d37d43_add_author_title.py

@ -0,0 +1,30 @@
"""add author_title
Revision ID: 91e5c7d37d43
Revises: 4f4a8c914911
Create Date: 2021-12-23 17:31:49.909672
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '91e5c7d37d43'
down_revision = '4f4a8c914911'
branch_labels = None
depends_on = None
def upgrade():
    """Add a nullable ``author_title`` column to ``comment`` and ``post``."""
    # Same order as before: comment first, then post. A fresh Column object
    # is built for each table.
    for table in ('comment', 'post'):
        author_title = sa.Column(
            'author_title', sa.String(length=10), nullable=True
        )
        op.add_column(table, author_title)
def downgrade():
    """Drop ``author_title`` from ``post`` and then from ``comment``."""
    for table in ('post', 'comment'):
        op.drop_column(table, 'author_title')

28
migrations/versions/9ac8682d438c_add_bool_can_search.py

@ -0,0 +1,28 @@
"""add bool can_search
Revision ID: 9ac8682d438c
Revises: 91e5c7d37d43
Create Date: 2021-12-24 18:11:27.626988
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '9ac8682d438c'
down_revision = '91e5c7d37d43'
branch_labels = None
depends_on = None
def upgrade():
    """Add the nullable boolean ``allow_search`` column to ``post``."""
    allow_search = sa.Column('allow_search', sa.Boolean(), nullable=True)
    op.add_column('post', allow_search)
def downgrade():
    """Revert this revision: drop the ``allow_search`` column from ``post``."""
    # NOTE(review): SQLite releases before 3.35 have no ALTER TABLE ... DROP
    # COLUMN; confirm the target database supports it before relying on this.
    table_name, column_name = 'post', 'allow_search'
    op.drop_column(table_name, column_name)

47
models.py

@ -1,8 +1,40 @@
from flask_sqlalchemy import SQLAlchemy
import time import time
from flask_sqlalchemy import SQLAlchemy
import sqlite3
from config import SEARCH_DB, EXT_SIMPLE_URL
# 搜索用的fts表放到单独的database里,为了不影响flask-migrate和避免死锁
db = SQLAlchemy() db = SQLAlchemy()
SEARCH_INSERT_SQL = "INSERT INTO search_content VALUES(?, ?, ?);"
SEARCH_QUERY_SQL = "SELECT simple_highlight(search_content, 0, ' **', '** '), target_type, target_id FROM search_content WHERE content MATCH simple_query(?) ORDER BY rank LIMIT ? OFFSET ?;"
class SearchDB:
    """Thin wrapper around the SQLite connection used for full-text search.

    Creating an instance opens the database file named by ``SEARCH_DB``,
    loads the SQLite extension found at ``EXT_SIMPLE_URL`` and keeps one
    cursor that every statement is executed on.
    """

    def __init__(self):
        """Open the search database and load the extension.

        Raises whatever ``sqlite3`` raises when the database cannot be
        opened or the extension cannot be loaded.
        """
        self.db = sqlite3.connect(SEARCH_DB)
        self.db.enable_load_extension(True)
        try:
            self.db.load_extension(EXT_SIMPLE_URL)
        finally:
            # Extension loading is only needed for the call above. Turn it
            # back off so no later statement on this connection can load
            # further shared libraries.
            self.db.enable_load_extension(False)
        self.cursor = self.db.cursor()

    def __del__(self):
        """Close the connection, if one was opened, when the object is collected."""
        # hasattr guards against __init__ having failed before self.db was set.
        if hasattr(self, 'db') and self.db:
            self.db.close()
            del self.db

    def execute(self, sql, *params):
        """Run ``sql`` on the shared cursor with ``params`` bound positionally.

        Returns the cursor, so callers can fetch rows from the result.
        """
        return self.cursor.execute(sql, params)

    def commit(self):
        """Commit the current transaction on the search database."""
        self.db.commit()

    def insert(self, *args):
        """Insert one row into ``search_content``.

        ``args`` are bound to the three placeholders of ``SEARCH_INSERT_SQL``.
        The change is not committed here; call ``commit`` afterwards.
        """
        return self.execute(SEARCH_INSERT_SQL, *args)

    def query(self, *args):
        """Run the full-text search statement ``SEARCH_QUERY_SQL``.

        ``args`` are bound, in order, to the search text, the ``LIMIT`` and
        the ``OFFSET`` placeholders. Returns the cursor; each row holds the
        highlighted content, ``target_type`` and ``target_id``.
        """
        return self.execute(SEARCH_QUERY_SQL, *args)
class User(db.Model): class User(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
@ -14,19 +46,25 @@ class User(db.Model):
class Post(db.Model): class Post(db.Model):
__table_args__ = {'sqlite_autoincrement': True}
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
name_hash = db.Column(db.String(64)) name_hash = db.Column(db.String(64))
author_title = db.Column(db.String(10))
content = db.Column(db.String(4096)) content = db.Column(db.String(4096))
search_text = db.Column(db.String(4096), default='', index=True) search_text = db.Column(db.String(4096), default='', index=True)
allow_search = db.Column(db.Boolean, default=False)
post_type = db.Column(db.String(8)) post_type = db.Column(db.String(8))
cw = db.Column(db.String(32)) cw = db.Column(db.String(32))
file_url = db.Column(db.String(256)) file_url = db.Column(db.String(256))
likenum = db.Column(db.Integer, default=0) likenum = db.Column(db.Integer, default=0)
n_comments = db.Column(db.Integer, default=0)
timestamp = db.Column(db.Integer) timestamp = db.Column(db.Integer)
deleted = db.Column(db.Boolean, default=False) deleted = db.Column(db.Boolean, default=False)
is_reported = db.Column(db.Boolean, default=False) is_reported = db.Column(db.Boolean, default=False)
comment_timestamp = db.Column(db.Integer, default=0, index=True) comment_timestamp = db.Column(db.Integer, default=0, index=True)
hot_score = db.Column(db.Integer, default=0, nullable=False, server_default="0") hot_score = db.Column(db.Integer, default=0,
nullable=False, server_default="0")
comments = db.relationship('Comment', backref='post', lazy=True) comments = db.relationship('Comment', backref='post', lazy=True)
@ -41,6 +79,7 @@ class Post(db.Model):
class Comment(db.Model): class Comment(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
name_hash = db.Column(db.String(64)) name_hash = db.Column(db.String(64))
author_title = db.Column(db.String(10))
content = db.Column(db.String(4096)) content = db.Column(db.String(4096))
timestamp = db.Column(db.Integer) timestamp = db.Column(db.Integer)
deleted = db.Column(db.Boolean, default=False) deleted = db.Column(db.Boolean, default=False)
@ -48,6 +87,10 @@ class Comment(db.Model):
post_id = db.Column(db.Integer, db.ForeignKey('post.id'), post_id = db.Column(db.Integer, db.ForeignKey('post.id'),
nullable=False) nullable=False)
@property
def post(self):
return Post.query.get(self.post_id)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(Comment, self).__init__(**kwargs) super(Comment, self).__init__(**kwargs)
self.timestamp = int(time.time()) self.timestamp = int(time.time())

15
requirements.txt

@ -1,7 +1,8 @@
Flask>=1.1.2 Flask
Flask-Limit>=1.0.2 Flask-Limit
Flask-Limiter>=1.3.1 Flask-Limiter
Flask-Login>=0.5.0 Flask-Login
Flask-Migrate>=2.5.3 Flask-Migrate
Flask-SQLAlchemy>=2.4.4 Flask-SQLAlchemy
Mastodon.py>=1.5.1 Mastodon.py
redis

103
utils.py

@ -1,7 +1,22 @@
import hashlib import hashlib
import time import time
import redis
from datetime import date
from flask import request, abort, current_app from flask import request, abort, current_app
from models import User, Attention, Syslog from models import User, Attention, Syslog
from config import RDS_CONFIG, ADMINS, ENABLE_TMP
# Redis key templates; the %s placeholders are filled in by the callers.
RDS_KEY_POLL_OPTS = 'hole_thu:poll_opts:%s'
RDS_KEY_POLL_VOTES = 'hole_thu:poll_votes:%s:%s'
# The key is parameterised by name rather than by name hash, which makes
# cleanup and persistent blocking easier. Block lists are not that sensitive,
# so storing real names on the backend should be acceptable. The values are
# name hashes.
RDS_KEY_BLOCK_SET = 'hole_thu:block_list:%s'
RDS_KEY_BLOCKED_COUNT = 'hole_thu:blocked_count'  # name hash -> number of times blocked
RDS_KEY_DANGEROUS_USERS = 'hole_thu:dangerous_users'
RDS_KEY_TITLE = 'hole_thu:title'  # custom title chosen by the user, name hash -> title
# Shared Redis client, configured from RDS_CONFIG.
rds = redis.Redis(**RDS_CONFIG)
def get_config(key): def get_config(key):
@ -9,7 +24,7 @@ def get_config(key):
def is_admin(name): def is_admin(name):
return name in get_config('ADMINS') return name in ADMINS
def tmp_token(): def tmp_token():
@ -18,50 +33,99 @@ def tmp_token():
)[5:21] )[5:21]
def get_current_user(): def get_current_username() -> str:
token = request.headers.get('User-Token') or request.args.get('user_token') token = request.headers.get('User-Token') or request.args.get('user_token')
if not token: if not token:
abort(401) abort(401)
if len(token.split('_')) == 2 and get_config('ENABLE_TMP'): if len(token.split('_')) == 2 and ENABLE_TMP:
tt, suf = token.split('_') tt, suf = token.split('_')
if tt != tmp_token(): if tt != tmp_token():
abort(401) abort(401)
return User(name='tmp_' + suf) return 'tmp_' + suf
u = User.query.filter_by(token=token).first() u = User.query.filter_by(token=token).first()
if not u or Syslog.query.filter_by( if not u or Syslog.query.filter_by(
log_type='BANNED', name_hash=hash_name(u.name)).first(): log_type='BANNED', name_hash=hash_name(u.name)).first():
abort(401) abort(401)
return u return u.name
def hash_name(name): def hash_name(name):
return hashlib.sha256( return hashlib.sha256(
(get_config('SALT') + name).encode('utf-8')).hexdigest() (get_config('SALT') + name).encode('utf-8')
).hexdigest()
def map_post(p, name, mc=50): def map_post(p, name, mc=50):
blocked = is_blocked(p.name_hash, name)
# TODO: 如果未来量大还是sql里not in一下
r = { r = {
'blocked': blocked,
'pid': p.id, 'pid': p.id,
'likenum': p.likenum, 'likenum': p.likenum,
'cw': p.cw, 'cw': p.cw,
'text': p.content, 'text': '' if blocked else p.content,
'timestamp': p.timestamp, 'timestamp': p.timestamp,
'type': p.post_type, 'type': p.post_type,
'url': p.file_url, 'url': p.file_url,
'reply': len(p.comments), 'reply': p.n_comments,
'comments': map_comment(p, name) if len(p.comments) < mc else None, 'comments': map_comments(p, name) if p.n_comments < mc else None,
'attention': check_attention(name, p.id), 'attention': check_attention(name, p.id),
'can_del': check_can_del(name, p.name_hash), 'can_del': check_can_del(name, p.name_hash),
'allow_search': bool(p.search_text) 'allow_search': p.allow_search,
'poll': None if blocked else gen_poll_dict(p.id, name),
'author_title': p.author_title
} }
if is_admin(name): if is_admin(name):
r['hot_score'] = p.hot_score r['hot_score'] = p.hot_score
if rds.sismember(RDS_KEY_DANGEROUS_USERS, p.name_hash):
r['dangerous_user'] = p.name_hash[:4]
r['blocked_count'] = rds.hget(RDS_KEY_BLOCKED_COUNT, p.name_hash)
r['is_reported'] = p.is_reported
return r return r
def map_comment(p, name): def gen_poll_dict(pid, name):
if not rds.exists(RDS_KEY_POLL_OPTS % pid):
return None
name = name_with_tmp_limit(name)
vote = None
answers = []
for idx, opt in enumerate(rds.lrange(RDS_KEY_POLL_OPTS % pid, 0, -1)):
answers.append({
'option': opt,
'votes': rds.scard(RDS_KEY_POLL_VOTES % (pid, idx))
})
if rds.sismember(RDS_KEY_POLL_VOTES % (pid, idx), hash_name(name)):
vote = opt
return {
'answers': answers,
'vote': vote
}
def name_with_tmp_limit(name: str) -> str:
    """Return the identity that *name* should be counted under.

    Every name starting with ``tmp_`` collapses to the same
    ``tmp:<today's date>`` value; any other name is returned unchanged.
    """
    if not name.startswith('tmp_'):
        return name
    return 'tmp:%s' % date.today()
def is_blocked(target_name_hash, name):
    """Tell whether content by ``target_name_hash`` is hidden from ``name``.

    Returns True when the viewer's own block set contains the author, or
    when the author is in the dangerous-users set and the viewer is neither
    an admin nor in that set themselves. Returns False otherwise.
    """
    # The viewer blocked this author personally.
    if rds.sismember(RDS_KEY_BLOCK_SET % name, target_name_hash):
        return True
    # Authors outside the dangerous-users set are visible to everyone else.
    if not rds.sismember(RDS_KEY_DANGEROUS_USERS, target_name_hash):
        return False
    # Admins always see dangerous users.
    if is_admin(name):
        return False
    # So do viewers who are in the dangerous-users set themselves.
    return not rds.sismember(RDS_KEY_DANGEROUS_USERS, hash_name(name))
def map_comments(p, name):
names = {p.name_hash: 0} names = {p.name_hash: 0}
@ -71,20 +135,27 @@ def map_comment(p, name):
return names[nh] return names[nh]
return [{ return [{
'blocked': (blocked := is_blocked(c.name_hash, name)),
'cid': c.id, 'cid': c.id,
'name_id': gen_name_id(c.name_hash), 'name_id': gen_name_id(c.name_hash),
'author_title': c.author_title,
'pid': p.id, 'pid': p.id,
'text': c.content, 'text': '' if blocked else c.content,
'timestamp': c.timestamp, 'timestamp': c.timestamp,
'can_del': check_can_del(name, c.name_hash) 'can_del': check_can_del(name, c.name_hash),
**({
'dangerous_user': c.name_hash[:4] if rds.sismember(
RDS_KEY_DANGEROUS_USERS, c.name_hash) else None,
'blocked_count': rds.hget(RDS_KEY_BLOCKED_COUNT, c.name_hash)
} if is_admin(name) else {})
} for c in p.comments if not (c.deleted and gen_name_id(c.name_hash) >= 0) } for c in p.comments if not (c.deleted and gen_name_id(c.name_hash) >= 0)
] ]
def map_syslog(s, u=None): def map_syslog(s, username):
return { return {
'type': s.log_type, 'type': s.log_type,
'detail': s.log_detail if check_can_del(u.name, s.name_hash) else '', 'detail': s.log_detail if check_can_del(username, s.name_hash) else '',
'user': look(s.name_hash), 'user': look(s.name_hash),
'timestamp': s.timestamp 'timestamp': s.timestamp
} }
@ -99,7 +170,7 @@ def check_attention(name, pid):
def check_can_del(name, author_hash): def check_can_del(name, author_hash):
return int(hash_name(name) == author_hash or is_admin(name)) return hash_name(name) == author_hash or is_admin(name)
def look(s): def look(s):

Loading…
Cancel
Save