From 5793455cf3f55c2b710a126bc1b4e3540d91e5c1 Mon Sep 17 00:00:00 2001 From: hole-thu Date: Fri, 24 Dec 2021 22:31:58 +0800 Subject: [PATCH] =?UTF-8?q?=E5=BC=BA=E5=8C=96=E6=90=9C=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + config.sample.py | 2 + fix_n_comments.py | 8 ++ hole.py | 135 ++++++++++++------ migration_search_table.py | 22 +++ .../versions/865bf933ea82_add_n_comments.py | 28 ++++ .../9ac8682d438c_add_bool_can_search.py | 28 ++++ models.py | 40 +++++- utils.py | 8 +- 9 files changed, 225 insertions(+), 47 deletions(-) create mode 100644 fix_n_comments.py create mode 100644 migration_search_table.py create mode 100644 migrations/versions/865bf933ea82_add_n_comments.py create mode 100644 migrations/versions/9ac8682d438c_add_bool_can_search.py diff --git a/.gitignore b/.gitignore index 9062e9b..e7a9666 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ /backup/ /.venv/ +/libsimple __pycache__/ *.pyc diff --git a/config.sample.py b/config.sample.py index 9ee7c06..2c5a3b1 100644 --- a/config.sample.py +++ b/config.sample.py @@ -15,3 +15,5 @@ RDS_CONFIG = { 'port': 6379, 'decode_responses': True } +SEARCH_DB = 'hole_search.db' +EXT_SIMPLE_URL = 'libsimple/libsimple' diff --git a/fix_n_comments.py b/fix_n_comments.py new file mode 100644 index 0000000..ea83952 --- /dev/null +++ b/fix_n_comments.py @@ -0,0 +1,8 @@ +from hole import app +from models import Post, db + +with app.app_context(): + for post in Post.query: + post.n_comments = len([c for c in post.comments if not c.deleted]) + + db.session.commit() diff --git a/hole.py b/hole.py index 7663209..0e32684 100644 --- a/hole.py +++ b/hole.py @@ -9,8 +9,8 @@ from flask_migrate import Migrate from sqlalchemy.sql.expression import func from mastodon import Mastodon -from models import db, User, Post, Comment, Attention, TagRecord, Syslog -from utils import get_current_username, map_post, map_comment, map_syslog, check_attention, hash_name, look, get_num, tmp_token, is_admin, check_can_del, rds, RDS_KEY_POLL_OPTS, RDS_KEY_POLL_VOTES, gen_poll_dict, name_with_tmp_limit, RDS_KEY_BLOCK_SET, RDS_KEY_BLOCKED_COUNT, RDS_KEY_DANGEROUS_USERS, RDS_KEY_TITLE +from models import db, User, Post, Comment, Attention, TagRecord, Syslog, SearchDB +from utils import get_current_username, map_post, map_comments, map_syslog, check_attention, hash_name, look, get_num, tmp_token, is_admin, check_can_del, rds, RDS_KEY_POLL_OPTS, RDS_KEY_POLL_VOTES, gen_poll_dict, name_with_tmp_limit, RDS_KEY_BLOCK_SET, RDS_KEY_BLOCKED_COUNT, RDS_KEY_DANGEROUS_USERS, RDS_KEY_TITLE app = Flask(__name__) app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///hole.db' @@ -187,43 +187,74 @@ def search(): username = get_current_username() page = request.args.get('page', type=int, default=1) - pagesize = min(request.args.get('pagesize', type=int, default=200), 200) - keywords = request.args.get('keywords') + search_mode = request.args.get('search_mode', type=int) + pagesize = min(request.args.get('pagesize', type=int, default=PER_PAGE), 2 * PER_PAGE) + keywords = request.args.get('keywords', '').strip() + if not keywords: - abort(422) + raise APIError("搜索词不可为空") + if search_mode is None: + raise APIError("请点击“强制检查更新”,更新网页到最新版") - tag_pids = TagRecord.query.with_entities( - TagRecord.pid - ).filter_by( - tag=keywords - ).all() + data = [] - tag_pids = [ - tag_pid for tag_pid, in tag_pids] or [0] # sql not allowed empty in + if search_mode == 0: # tag 搜索 + tag_pids = TagRecord.query.with_entities( + TagRecord.pid + ).filter_by( + tag=keywords + ).limit(pagesize).offset((page - 1) * pagesize).all() - posts = Post.query.filter( - Post.search_text.like("%{}%".format(keywords)) - ).filter( - Post.id.notin_(tag_pids) - ).filter_by( - deleted=False, is_reported=False - ).order_by( - Post.id.desc() - ).limit(pagesize).offset((page - 1) * pagesize).all() + tag_pids = [ + tag_pid for tag_pid, in tag_pids] or [0] # sql not allowed empty in - if page == 1: - posts = Post.query.filter( - Post.id.in_(tag_pids) - ).filter_by( + posts = Post.query.filter(Post.id.in_(tag_pids)).filter_by( + deleted=False, is_reported=False + ).order_by(Post.id.desc()).all() + + data = [ + map_post(post, username) + for post in posts + ] + elif search_mode == 1: # 全文搜索 + search_db = SearchDB() + for highlighted_content, obj_type, obj_id in search_db.query( + keywords, pagesize, (page - 1) * pagesize + ): + if obj_type == 'post': + obj = Post.query.get(obj_id) + else: + obj = Comment.query.get(obj_id) + if not obj or obj.deleted: + continue + if obj_type == 'post': + post = obj + else: + post = obj.post + if not post or post.deleted or post.is_reported: + continue + obj.content = highlighted_content + if obj_type == 'post': + post_dict = map_post(post, username) + else: + post_dict = map_post(post, username, 1000) + post_dict['comments'] = [ + c for c in post_dict['comments'] if c['cid'] == obj_id + ] + + post_dict['key'] = "search_%s_%s" % (obj_type, obj_id) + data.append(post_dict) + del search_db + elif search_mode == 2: # 头衔 + posts = Post.query.filter_by(author_title=keywords).filter_by( deleted=False, is_reported=False ).order_by( Post.id.desc() - ).all() + posts - - data = [ - map_post(post, username) - for post in posts - ] + ).limit(pagesize).offset((page - 1) * pagesize).all() + data = [ + map_post(post, username) + for post in posts + ] return { 'code': 0, @@ -248,9 +279,6 @@ def do_post(): if not content or len(content) > 4096 or len(cw) > 32: raise APIError('无内容或超长') - search_text = content.replace( - '\n', '') if allow_search else '' - if poll_options and poll_options[0]: if len(poll_options) != len(set(poll_options)): raise APIError('有重复的投票选项') @@ -264,7 +292,7 @@ def do_post(): name_hash=name_hash, author_title=rds.hget(RDS_KEY_TITLE, name_hash) if use_title else None, content=content, - search_text=search_text, + allow_search=bool(allow_search), post_type=post_type, cw=cw or None, likenum=1, @@ -283,6 +311,12 @@ def do_post(): db.session.add(Attention(name_hash=hash_name(username), pid=p.id)) db.session.commit() + if allow_search: + search_db = SearchDB() + search_db.insert(content, 'post', p.id) + search_db.commit() + del search_db + rds.delete(RDS_KEY_POLL_OPTS % p.id) # 由于历史原因,现在的数据库里发布后删再发布可能导致id重复 if poll_options and poll_options[0]: rds.rpush(RDS_KEY_POLL_OPTS % p.id, *poll_options) @@ -326,7 +360,7 @@ def get_comment(): if post.deleted and not check_can_del(username, post.name_hash): abort(451) - data = map_comment(post, username) + data = map_comments(post, username) return { 'code': 0, @@ -365,6 +399,7 @@ def do_comment(): ) post.comments.append(c) post.comment_timestamp = c.timestamp + post.n_comments += 1 if post.hot_score != -1: post.hot_score += 1 @@ -386,6 +421,12 @@ def do_comment(): db.session.commit() + if post.allow_search: + search_db = SearchDB() + search_db.insert(content, 'comment', c.id) + search_db.commit() + del search_db + return { 'code': 0, 'data': pid @@ -397,7 +438,7 @@ def do_comment(): def attention(): username = get_current_username() if username[:4] == 'tmp_': - abort(403) + raise APIError('临时用户无法手动关注') s = request.form.get('switch', type=int) if s not in [0, 1]: @@ -474,23 +515,30 @@ def delete(): if note and len(note) > 100: abort(422) - obj = None + # 兼容 if obj_type == 'pid': - obj = Post.query.get(obj_id) + obj_type = 'post' elif obj_type == 'cid': + obj_type = 'comment' + + obj = None + if obj_type == 'post': + obj = Post.query.get(obj_id) + elif obj_type == 'comment': obj = Comment.query.get(obj_id) if not obj: abort(404) if obj.name_hash == hash_name(username): - if obj_type == 'pid': - if len(obj.comments): - abort(403) + if obj_type == 'post': + if obj.n_comments: + abort("已经有评论了") Attention.query.filter_by(pid=obj.id).delete() TagRecord.query.filter_by(pid=obj.id).delete() db.session.delete(obj) else: obj.deleted = True + elif username in app.config.get('ADMINS'): obj.deleted = True db.session.add(Syslog( @@ -507,6 +555,9 @@ def delete(): else: abort(403) + if obj_type == 'comment': + obj.post.n_comments -= 1 + db.session.commit() return {'code': 0} diff --git a/migration_search_table.py b/migration_search_table.py new file mode 100644 index 0000000..17cdffa --- /dev/null +++ b/migration_search_table.py @@ -0,0 +1,22 @@ +from hole import app +from models import SearchDB, Post, db + +search_db = SearchDB() +search_db.execute("DROP TABLE IF EXISTS search_content;") +search_db.execute("CREATE VIRTUAL TABLE search_content " + "USING fts5(content, target_type UNINDEXED, target_id UNINDEXED, tokenize = 'simple');") + +with app.app_context(): + for post in Post.query.filter_by(deleted=False): + if post.search_text: + search_db.insert(post.search_text, 'post', post.id) + post.allow_search = True + for comment in post.comments: + if not comment.deleted: + search_db.insert(comment.content, 'comment', comment.id) + else: + post.allow_search = False + + search_db.commit() + del search_db + db.session.commit() diff --git a/migrations/versions/865bf933ea82_add_n_comments.py b/migrations/versions/865bf933ea82_add_n_comments.py new file mode 100644 index 0000000..d84f259 --- /dev/null +++ b/migrations/versions/865bf933ea82_add_n_comments.py @@ -0,0 +1,28 @@ +"""add n_comments + +Revision ID: 865bf933ea82 +Revises: 9ac8682d438c +Create Date: 2021-12-24 20:21:53.928842 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '865bf933ea82' +down_revision = '9ac8682d438c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('post', sa.Column('n_comments', sa.Integer(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('post', 'n_comments') + # ### end Alembic commands ### diff --git a/migrations/versions/9ac8682d438c_add_bool_can_search.py b/migrations/versions/9ac8682d438c_add_bool_can_search.py new file mode 100644 index 0000000..0994f80 --- /dev/null +++ b/migrations/versions/9ac8682d438c_add_bool_can_search.py @@ -0,0 +1,28 @@ +"""add bool can_search + +Revision ID: 9ac8682d438c +Revises: 91e5c7d37d43 +Create Date: 2021-12-24 18:11:27.626988 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '9ac8682d438c' +down_revision = '91e5c7d37d43' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('post', sa.Column('allow_search', sa.Boolean(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('post', 'allow_search') + # ### end Alembic commands ### diff --git a/models.py b/models.py index 7b97fb4..65b5773 100644 --- a/models.py +++ b/models.py @@ -1,8 +1,40 @@ -from flask_sqlalchemy import SQLAlchemy import time +from flask_sqlalchemy import SQLAlchemy +import sqlite3 + +from config import SEARCH_DB, EXT_SIMPLE_URL +# 搜索用的fts表放到单独的database里,为了不影响flask-migrate和避免死锁 db = SQLAlchemy() +SEARCH_INSERT_SQL = "INSERT INTO search_content VALUES(?, ?, ?);" +SEARCH_QUERY_SQL = "SELECT simple_highlight(search_content, 0, ' **', '** '), target_type, target_id FROM search_content WHERE content MATCH simple_query(?) ORDER BY rank LIMIT ? OFFSET ?;" + + +class SearchDB: + def __init__(self): + self.db = sqlite3.connect(SEARCH_DB) + self.db.enable_load_extension(True) + self.db.load_extension(EXT_SIMPLE_URL) + self.cursor = self.db.cursor() + + def __del__(self): + if hasattr(self, 'db') and self.db: + self.db.close() + del self.db + + def execute(self, sql, *params): + return self.cursor.execute(sql, params) + + def commit(self): + self.db.commit() + + def insert(self, *args): + return self.execute(SEARCH_INSERT_SQL, *args) + + def query(self, *args): + return self.execute(SEARCH_QUERY_SQL, *args) + class User(db.Model): id = db.Column(db.Integer, primary_key=True) @@ -21,10 +53,12 @@ class Post(db.Model): author_title = db.Column(db.String(10)) content = db.Column(db.String(4096)) search_text = db.Column(db.String(4096), default='', index=True) + allow_search = db.Column(db.Boolean, default=False) post_type = db.Column(db.String(8)) cw = db.Column(db.String(32)) file_url = db.Column(db.String(256)) likenum = db.Column(db.Integer, default=0) + n_comments = db.Column(db.Integer, default=0) timestamp = db.Column(db.Integer) deleted = db.Column(db.Boolean, default=False) is_reported = db.Column(db.Boolean, default=False) @@ -53,6 +87,10 @@ class Comment(db.Model): post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) + @property + def post(self): + return Post.query.get(self.post_id) + def __init__(self, **kwargs): super(Comment, self).__init__(**kwargs) self.timestamp = int(time.time()) diff --git a/utils.py b/utils.py index 3f34662..839681e 100644 --- a/utils.py +++ b/utils.py @@ -69,11 +69,11 @@ def map_post(p, name, mc=50): 'timestamp': p.timestamp, 'type': p.post_type, 'url': p.file_url, - 'reply': len(p.comments), - 'comments': map_comment(p, name) if len(p.comments) < mc else None, + 'reply': p.n_comments, + 'comments': map_comments(p, name) if p.n_comments < mc else None, 'attention': check_attention(name, p.id), 'can_del': check_can_del(name, p.name_hash), - 'allow_search': bool(p.search_text), + 'allow_search': p.allow_search, 'poll': None if blocked else gen_poll_dict(p.id, name), 'author_title': p.author_title } @@ -125,7 +125,7 @@ def is_blocked(target_name_hash, name): return False -def map_comment(p, name): +def map_comments(p, name): names = {p.name_hash: 0}