Browse Source

强化搜索

master
hole-thu 4 years ago
parent
commit
5793455cf3
  1. 1
      .gitignore
  2. 2
      config.sample.py
  3. 8
      fix_n_comments.py
  4. 135
      hole.py
  5. 22
      migration_search_table.py
  6. 28
      migrations/versions/865bf933ea82_add_n_comments.py
  7. 28
      migrations/versions/9ac8682d438c_add_bool_can_search.py
  8. 40
      models.py
  9. 8
      utils.py

1
.gitignore vendored

@ -1,5 +1,6 @@
/backup/
/.venv/
/libsimple
__pycache__/
*.pyc

2
config.sample.py

@ -15,3 +15,5 @@ RDS_CONFIG = {
'port': 6379,
'decode_responses': True
}
SEARCH_DB = 'hole_search.db'
EXT_SIMPLE_URL = 'libsimple/libsimple'

8
fix_n_comments.py

@ -0,0 +1,8 @@
from hole import app
from models import Post, db
# One-off backfill: recompute the stored ``n_comments`` counter of every post
# from its comments, counting only comments that are not flagged as deleted.
# NOTE(review): the nesting below was reconstructed from a listing without
# indentation; a single commit after the loop is assumed -- confirm.
with app.app_context():
    for post in Post.query:
        post.n_comments = sum(
            1 for comment in post.comments if not comment.deleted
        )
    db.session.commit()

135
hole.py

@ -9,8 +9,8 @@ from flask_migrate import Migrate
from sqlalchemy.sql.expression import func
from mastodon import Mastodon
from models import db, User, Post, Comment, Attention, TagRecord, Syslog
from utils import get_current_username, map_post, map_comment, map_syslog, check_attention, hash_name, look, get_num, tmp_token, is_admin, check_can_del, rds, RDS_KEY_POLL_OPTS, RDS_KEY_POLL_VOTES, gen_poll_dict, name_with_tmp_limit, RDS_KEY_BLOCK_SET, RDS_KEY_BLOCKED_COUNT, RDS_KEY_DANGEROUS_USERS, RDS_KEY_TITLE
from models import db, User, Post, Comment, Attention, TagRecord, Syslog, SearchDB
from utils import get_current_username, map_post, map_comments, map_syslog, check_attention, hash_name, look, get_num, tmp_token, is_admin, check_can_del, rds, RDS_KEY_POLL_OPTS, RDS_KEY_POLL_VOTES, gen_poll_dict, name_with_tmp_limit, RDS_KEY_BLOCK_SET, RDS_KEY_BLOCKED_COUNT, RDS_KEY_DANGEROUS_USERS, RDS_KEY_TITLE
app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///hole.db'
@ -187,43 +187,74 @@ def search():
username = get_current_username()
page = request.args.get('page', type=int, default=1)
pagesize = min(request.args.get('pagesize', type=int, default=200), 200)
keywords = request.args.get('keywords')
search_mode = request.args.get('search_mode', type=int)
pagesize = min(request.args.get('pagesize', type=int, default=PER_PAGE), 2 * PER_PAGE)
keywords = request.args.get('keywords', '').strip()
if not keywords:
abort(422)
raise APIError("搜索词不可为空")
if search_mode is None:
raise APIError("请点击“强制检查更新”,更新网页到最新版")
tag_pids = TagRecord.query.with_entities(
TagRecord.pid
).filter_by(
tag=keywords
).all()
data = []
tag_pids = [
tag_pid for tag_pid, in tag_pids] or [0] # sql not allowed empty in
if search_mode == 0: # tag 搜索
tag_pids = TagRecord.query.with_entities(
TagRecord.pid
).filter_by(
tag=keywords
).limit(pagesize).offset((page - 1) * pagesize).all()
posts = Post.query.filter(
Post.search_text.like("%{}%".format(keywords))
).filter(
Post.id.notin_(tag_pids)
).filter_by(
deleted=False, is_reported=False
).order_by(
Post.id.desc()
).limit(pagesize).offset((page - 1) * pagesize).all()
tag_pids = [
tag_pid for tag_pid, in tag_pids] or [0] # sql not allowed empty in
if page == 1:
posts = Post.query.filter(
Post.id.in_(tag_pids)
).filter_by(
posts = Post.query.filter(Post.id.in_(tag_pids)).filter_by(
deleted=False, is_reported=False
).order_by(Post.id.desc()).all()
data = [
map_post(post, username)
for post in posts
]
elif search_mode == 1: # 全文搜索
search_db = SearchDB()
for highlighted_content, obj_type, obj_id in search_db.query(
keywords, pagesize, (page - 1) * pagesize
):
if obj_type == 'post':
obj = Post.query.get(obj_id)
else:
obj = Comment.query.get(obj_id)
if not obj or obj.deleted:
continue
if obj_type == 'post':
post = obj
else:
post = obj.post
if not post or post.deleted or post.is_reported:
continue
obj.content = highlighted_content
if obj_type == 'post':
post_dict = map_post(post, username)
else:
post_dict = map_post(post, username, 1000)
post_dict['comments'] = [
c for c in post_dict['comments'] if c['cid'] == obj_id
]
post_dict['key'] = "search_%s_%s" % (obj_type, obj_id)
data.append(post_dict)
del search_db
elif search_mode == 2: # 头衔
posts = Post.query.filter_by(author_title=keywords).filter_by(
deleted=False, is_reported=False
).order_by(
Post.id.desc()
).all() + posts
data = [
map_post(post, username)
for post in posts
]
).limit(pagesize).offset((page - 1) * pagesize).all()
data = [
map_post(post, username)
for post in posts
]
return {
'code': 0,
@ -248,9 +279,6 @@ def do_post():
if not content or len(content) > 4096 or len(cw) > 32:
raise APIError('无内容或超长')
search_text = content.replace(
'\n', '') if allow_search else ''
if poll_options and poll_options[0]:
if len(poll_options) != len(set(poll_options)):
raise APIError('有重复的投票选项')
@ -264,7 +292,7 @@ def do_post():
name_hash=name_hash,
author_title=rds.hget(RDS_KEY_TITLE, name_hash) if use_title else None,
content=content,
search_text=search_text,
allow_search=bool(allow_search),
post_type=post_type,
cw=cw or None,
likenum=1,
@ -283,6 +311,12 @@ def do_post():
db.session.add(Attention(name_hash=hash_name(username), pid=p.id))
db.session.commit()
if allow_search:
search_db = SearchDB()
search_db.insert(content, 'post', p.id)
search_db.commit()
del search_db
rds.delete(RDS_KEY_POLL_OPTS % p.id) # 由于历史原因,现在的数据库里发布后删再发布可能导致id重复
if poll_options and poll_options[0]:
rds.rpush(RDS_KEY_POLL_OPTS % p.id, *poll_options)
@ -326,7 +360,7 @@ def get_comment():
if post.deleted and not check_can_del(username, post.name_hash):
abort(451)
data = map_comment(post, username)
data = map_comments(post, username)
return {
'code': 0,
@ -365,6 +399,7 @@ def do_comment():
)
post.comments.append(c)
post.comment_timestamp = c.timestamp
post.n_comments += 1
if post.hot_score != -1:
post.hot_score += 1
@ -386,6 +421,12 @@ def do_comment():
db.session.commit()
if post.allow_search:
search_db = SearchDB()
search_db.insert(content, 'comment', c.id)
search_db.commit()
del search_db
return {
'code': 0,
'data': pid
@ -397,7 +438,7 @@ def do_comment():
def attention():
username = get_current_username()
if username[:4] == 'tmp_':
abort(403)
raise APIError('临时用户无法手动关注')
s = request.form.get('switch', type=int)
if s not in [0, 1]:
@ -474,23 +515,30 @@ def delete():
if note and len(note) > 100:
abort(422)
obj = None
# 兼容
if obj_type == 'pid':
obj = Post.query.get(obj_id)
obj_type = 'post'
elif obj_type == 'cid':
obj_type = 'comment'
obj = None
if obj_type == 'post':
obj = Post.query.get(obj_id)
elif obj_type == 'comment':
obj = Comment.query.get(obj_id)
if not obj:
abort(404)
if obj.name_hash == hash_name(username):
if obj_type == 'pid':
if len(obj.comments):
abort(403)
if obj_type == 'post':
if obj.n_comments:
abort("已经有评论了")
Attention.query.filter_by(pid=obj.id).delete()
TagRecord.query.filter_by(pid=obj.id).delete()
db.session.delete(obj)
else:
obj.deleted = True
elif username in app.config.get('ADMINS'):
obj.deleted = True
db.session.add(Syslog(
@ -507,6 +555,9 @@ def delete():
else:
abort(403)
if obj_type == 'comment':
obj.post.n_comments -= 1
db.session.commit()
return {'code': 0}

22
migration_search_table.py

@ -0,0 +1,22 @@
from hole import app
from models import SearchDB, Post, db
# One-off migration: rebuild the ``search_content`` FTS5 table from scratch and
# derive the boolean ``Post.allow_search`` from the legacy ``search_text``.
# NOTE(review): the nesting below was reconstructed from a listing without
# indentation; the ``else`` is assumed to pair with ``if post.search_text``
# and the commits are assumed to run once after the loop -- confirm.
search_db = SearchDB()
search_db.execute("DROP TABLE IF EXISTS search_content;")
search_db.execute(
    "CREATE VIRTUAL TABLE search_content "
    "USING fts5(content, target_type UNINDEXED, target_id UNINDEXED, tokenize = 'simple');"
)

with app.app_context():
    for post in Post.query.filter_by(deleted=False):
        if not post.search_text:
            # An empty search_text marks a post that is not searchable.
            post.allow_search = False
            continue

        search_db.insert(post.search_text, 'post', post.id)
        post.allow_search = True

        # Comments are only indexed for searchable posts; deleted ones are skipped.
        for comment in post.comments:
            if comment.deleted:
                continue
            search_db.insert(comment.content, 'comment', comment.id)

    # Persist the search index first, release its connection, then the ORM changes.
    search_db.commit()
    del search_db
    db.session.commit()

28
migrations/versions/865bf933ea82_add_n_comments.py

@ -0,0 +1,28 @@
"""add n_comments
Revision ID: 865bf933ea82
Revises: 9ac8682d438c
Create Date: 2021-12-24 20:21:53.928842
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '865bf933ea82'
down_revision = '9ac8682d438c'
branch_labels = None
depends_on = None
def upgrade():
    """Add the nullable integer column ``n_comments`` to the ``post`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # No default is supplied here, so rows that already exist are left
    # without a value until something fills the column in.
    n_comments = sa.Column('n_comments', sa.Integer(), nullable=True)
    op.add_column('post', n_comments)
    # ### end Alembic commands ###
def downgrade():
    """Revert this revision by dropping ``n_comments`` from the ``post`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    table, column = 'post', 'n_comments'
    op.drop_column(table, column)
    # ### end Alembic commands ###

28
migrations/versions/9ac8682d438c_add_bool_can_search.py

@ -0,0 +1,28 @@
"""add bool can_search
Revision ID: 9ac8682d438c
Revises: 91e5c7d37d43
Create Date: 2021-12-24 18:11:27.626988
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '9ac8682d438c'
down_revision = '91e5c7d37d43'
branch_labels = None
depends_on = None
def upgrade():
    """Add the nullable boolean column ``allow_search`` to the ``post`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # No default is supplied here, so rows that already exist are left
    # without a value until something fills the column in.
    allow_search = sa.Column('allow_search', sa.Boolean(), nullable=True)
    op.add_column('post', allow_search)
    # ### end Alembic commands ###
def downgrade():
    """Revert this revision by dropping ``allow_search`` from the ``post`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    table, column = 'post', 'allow_search'
    op.drop_column(table, column)
    # ### end Alembic commands ###

40
models.py

@ -1,8 +1,40 @@
from flask_sqlalchemy import SQLAlchemy
import time
from flask_sqlalchemy import SQLAlchemy
import sqlite3
from config import SEARCH_DB, EXT_SIMPLE_URL
# 搜索用的fts表放到单独的database里,为了不影响flask-migrate和避免死锁
db = SQLAlchemy()
# Row layout of search_content: (content, target_type, target_id).
SEARCH_INSERT_SQL = "INSERT INTO search_content VALUES(?, ?, ?);"
# Parameters: (keywords, limit, offset). Yields rows of
# (highlighted content, target_type, target_id), best rank first.
SEARCH_QUERY_SQL = (
    "SELECT simple_highlight(search_content, 0, ' **', '** '), "
    "target_type, target_id FROM search_content "
    "WHERE content MATCH simple_query(?) ORDER BY rank LIMIT ? OFFSET ?;"
)


class SearchDB:
    """Thin wrapper around a dedicated SQLite connection for full-text search.

    The connection targets the separate ``SEARCH_DB`` file and loads the
    SQLite extension found at ``EXT_SIMPLE_URL`` (presumably the provider of
    the ``simple`` tokenizer and the ``simple_*`` SQL functions used by the
    statements above -- confirm against the extension).

    The wrapper can be released in three equivalent ways: an explicit
    ``close()``, use as a context manager (``with SearchDB() as s: ...``), or
    -- as existing callers do -- dropping the last reference (``del s``).
    """

    def __init__(self):
        """Open the search database and load the tokenizer extension.

        Raises:
            sqlite3.Error: if the database cannot be opened or the extension
                cannot be loaded. The connection is closed before re-raising.
        """
        conn = sqlite3.connect(SEARCH_DB)
        try:
            conn.enable_load_extension(True)
            conn.load_extension(EXT_SIMPLE_URL)
            # Extension loading is only needed once; switch it off again so
            # later statements cannot load arbitrary shared libraries.
            conn.enable_load_extension(False)
        except Exception:
            # Do not leak the half-initialised connection.
            conn.close()
            raise
        self.db = conn
        self.cursor = conn.cursor()

    def __enter__(self):
        """Return ``self`` so the wrapper can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection on leaving the ``with`` block.

        Exceptions are never suppressed, and nothing is committed implicitly.
        """
        self.close()

    def __del__(self):
        """Best-effort cleanup when the wrapper is garbage-collected."""
        self.close()

    def close(self):
        """Close the underlying connection; safe to call more than once."""
        # getattr: __init__ may have failed before ``db`` was assigned.
        conn = getattr(self, 'db', None)
        if conn is not None:
            conn.close()
        self.db = None

    def execute(self, sql, *params):
        """Run ``sql`` with positional ``params`` and return the cursor."""
        return self.cursor.execute(sql, params)

    def commit(self):
        """Commit the pending transaction on the search database."""
        self.db.commit()

    def insert(self, *args):
        """Index one row; ``args`` is ``(content, target_type, target_id)``."""
        return self.execute(SEARCH_INSERT_SQL, *args)

    def query(self, *args):
        """Run the ranked search; ``args`` is ``(keywords, limit, offset)``.

        Returns the cursor, which yields
        ``(highlighted_content, target_type, target_id)`` rows.
        """
        return self.execute(SEARCH_QUERY_SQL, *args)
class User(db.Model):
id = db.Column(db.Integer, primary_key=True)
@ -21,10 +53,12 @@ class Post(db.Model):
author_title = db.Column(db.String(10))
content = db.Column(db.String(4096))
search_text = db.Column(db.String(4096), default='', index=True)
allow_search = db.Column(db.Boolean, default=False)
post_type = db.Column(db.String(8))
cw = db.Column(db.String(32))
file_url = db.Column(db.String(256))
likenum = db.Column(db.Integer, default=0)
n_comments = db.Column(db.Integer, default=0)
timestamp = db.Column(db.Integer)
deleted = db.Column(db.Boolean, default=False)
is_reported = db.Column(db.Boolean, default=False)
@ -53,6 +87,10 @@ class Comment(db.Model):
post_id = db.Column(db.Integer, db.ForeignKey('post.id'),
nullable=False)
@property
def post(self):
    """Return the post this comment belongs to, looked up by ``post_id``."""
    parent_id = self.post_id
    return Post.query.get(parent_id)
def __init__(self, **kwargs):
    """Build the comment from column keyword arguments and stamp it.

    ``timestamp`` is always set to the current time in whole seconds,
    after the base-class constructor has consumed ``kwargs``.
    """
    super(Comment, self).__init__(**kwargs)
    created_at = time.time()
    self.timestamp = int(created_at)

8
utils.py

@ -69,11 +69,11 @@ def map_post(p, name, mc=50):
'timestamp': p.timestamp,
'type': p.post_type,
'url': p.file_url,
'reply': len(p.comments),
'comments': map_comment(p, name) if len(p.comments) < mc else None,
'reply': p.n_comments,
'comments': map_comments(p, name) if p.n_comments < mc else None,
'attention': check_attention(name, p.id),
'can_del': check_can_del(name, p.name_hash),
'allow_search': bool(p.search_text),
'allow_search': p.allow_search,
'poll': None if blocked else gen_poll_dict(p.id, name),
'author_title': p.author_title
}
@ -125,7 +125,7 @@ def is_blocked(target_name_hash, name):
return False
def map_comment(p, name):
def map_comments(p, name):
names = {p.name_hash: 0}

Loading…
Cancel
Save