Browse Source

强化搜索

master
hole-thu 4 years ago
parent
commit
5793455cf3
  1. 1
      .gitignore
  2. 2
      config.sample.py
  3. 8
      fix_n_comments.py
  4. 135
      hole.py
  5. 22
      migration_search_table.py
  6. 28
      migrations/versions/865bf933ea82_add_n_comments.py
  7. 28
      migrations/versions/9ac8682d438c_add_bool_can_search.py
  8. 40
      models.py
  9. 8
      utils.py

1
.gitignore vendored

@ -1,5 +1,6 @@
/backup/ /backup/
/.venv/ /.venv/
/libsimple
__pycache__/ __pycache__/
*.pyc *.pyc

2
config.sample.py

@ -15,3 +15,5 @@ RDS_CONFIG = {
'port': 6379, 'port': 6379,
'decode_responses': True 'decode_responses': True
} }
SEARCH_DB = 'hole_search.db'
EXT_SIMPLE_URL = 'libsimple/libsimple'

8
fix_n_comments.py

@ -0,0 +1,8 @@
from hole import app
from models import Post, db
# One-off maintenance script: rebuild the cached ``n_comments`` counter of
# every post from the comments that are not flagged as deleted.
with app.app_context():
    for post in Post.query:
        post.n_comments = sum(1 for comment in post.comments if not comment.deleted)

    # Persist the recalculated counters once every post has been visited.
    db.session.commit()

135
hole.py

@ -9,8 +9,8 @@ from flask_migrate import Migrate
from sqlalchemy.sql.expression import func from sqlalchemy.sql.expression import func
from mastodon import Mastodon from mastodon import Mastodon
from models import db, User, Post, Comment, Attention, TagRecord, Syslog from models import db, User, Post, Comment, Attention, TagRecord, Syslog, SearchDB
from utils import get_current_username, map_post, map_comment, map_syslog, check_attention, hash_name, look, get_num, tmp_token, is_admin, check_can_del, rds, RDS_KEY_POLL_OPTS, RDS_KEY_POLL_VOTES, gen_poll_dict, name_with_tmp_limit, RDS_KEY_BLOCK_SET, RDS_KEY_BLOCKED_COUNT, RDS_KEY_DANGEROUS_USERS, RDS_KEY_TITLE from utils import get_current_username, map_post, map_comments, map_syslog, check_attention, hash_name, look, get_num, tmp_token, is_admin, check_can_del, rds, RDS_KEY_POLL_OPTS, RDS_KEY_POLL_VOTES, gen_poll_dict, name_with_tmp_limit, RDS_KEY_BLOCK_SET, RDS_KEY_BLOCKED_COUNT, RDS_KEY_DANGEROUS_USERS, RDS_KEY_TITLE
app = Flask(__name__) app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///hole.db' app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///hole.db'
@ -187,43 +187,74 @@ def search():
username = get_current_username() username = get_current_username()
page = request.args.get('page', type=int, default=1) page = request.args.get('page', type=int, default=1)
pagesize = min(request.args.get('pagesize', type=int, default=200), 200) search_mode = request.args.get('search_mode', type=int)
keywords = request.args.get('keywords') pagesize = min(request.args.get('pagesize', type=int, default=PER_PAGE), 2 * PER_PAGE)
keywords = request.args.get('keywords', '').strip()
if not keywords: if not keywords:
abort(422) raise APIError("搜索词不可为空")
if search_mode is None:
raise APIError("请点击“强制检查更新”,更新网页到最新版")
tag_pids = TagRecord.query.with_entities( data = []
TagRecord.pid
).filter_by(
tag=keywords
).all()
tag_pids = [ if search_mode == 0: # tag 搜索
tag_pid for tag_pid, in tag_pids] or [0] # sql not allowed empty in tag_pids = TagRecord.query.with_entities(
TagRecord.pid
).filter_by(
tag=keywords
).limit(pagesize).offset((page - 1) * pagesize).all()
posts = Post.query.filter( tag_pids = [
Post.search_text.like("%{}%".format(keywords)) tag_pid for tag_pid, in tag_pids] or [0] # sql not allowed empty in
).filter(
Post.id.notin_(tag_pids)
).filter_by(
deleted=False, is_reported=False
).order_by(
Post.id.desc()
).limit(pagesize).offset((page - 1) * pagesize).all()
if page == 1: posts = Post.query.filter(Post.id.in_(tag_pids)).filter_by(
posts = Post.query.filter( deleted=False, is_reported=False
Post.id.in_(tag_pids) ).order_by(Post.id.desc()).all()
).filter_by(
data = [
map_post(post, username)
for post in posts
]
elif search_mode == 1: # 全文搜索
search_db = SearchDB()
for highlighted_content, obj_type, obj_id in search_db.query(
keywords, pagesize, (page - 1) * pagesize
):
if obj_type == 'post':
obj = Post.query.get(obj_id)
else:
obj = Comment.query.get(obj_id)
if not obj or obj.deleted:
continue
if obj_type == 'post':
post = obj
else:
post = obj.post
if not post or post.deleted or post.is_reported:
continue
obj.content = highlighted_content
if obj_type == 'post':
post_dict = map_post(post, username)
else:
post_dict = map_post(post, username, 1000)
post_dict['comments'] = [
c for c in post_dict['comments'] if c['cid'] == obj_id
]
post_dict['key'] = "search_%s_%s" % (obj_type, obj_id)
data.append(post_dict)
del search_db
elif search_mode == 2: # 头衔
posts = Post.query.filter_by(author_title=keywords).filter_by(
deleted=False, is_reported=False deleted=False, is_reported=False
).order_by( ).order_by(
Post.id.desc() Post.id.desc()
).all() + posts ).limit(pagesize).offset((page - 1) * pagesize).all()
data = [
data = [ map_post(post, username)
map_post(post, username) for post in posts
for post in posts ]
]
return { return {
'code': 0, 'code': 0,
@ -248,9 +279,6 @@ def do_post():
if not content or len(content) > 4096 or len(cw) > 32: if not content or len(content) > 4096 or len(cw) > 32:
raise APIError('无内容或超长') raise APIError('无内容或超长')
search_text = content.replace(
'\n', '') if allow_search else ''
if poll_options and poll_options[0]: if poll_options and poll_options[0]:
if len(poll_options) != len(set(poll_options)): if len(poll_options) != len(set(poll_options)):
raise APIError('有重复的投票选项') raise APIError('有重复的投票选项')
@ -264,7 +292,7 @@ def do_post():
name_hash=name_hash, name_hash=name_hash,
author_title=rds.hget(RDS_KEY_TITLE, name_hash) if use_title else None, author_title=rds.hget(RDS_KEY_TITLE, name_hash) if use_title else None,
content=content, content=content,
search_text=search_text, allow_search=bool(allow_search),
post_type=post_type, post_type=post_type,
cw=cw or None, cw=cw or None,
likenum=1, likenum=1,
@ -283,6 +311,12 @@ def do_post():
db.session.add(Attention(name_hash=hash_name(username), pid=p.id)) db.session.add(Attention(name_hash=hash_name(username), pid=p.id))
db.session.commit() db.session.commit()
if allow_search:
search_db = SearchDB()
search_db.insert(content, 'post', p.id)
search_db.commit()
del search_db
rds.delete(RDS_KEY_POLL_OPTS % p.id) # 由于历史原因,现在的数据库里发布后删再发布可能导致id重复 rds.delete(RDS_KEY_POLL_OPTS % p.id) # 由于历史原因,现在的数据库里发布后删再发布可能导致id重复
if poll_options and poll_options[0]: if poll_options and poll_options[0]:
rds.rpush(RDS_KEY_POLL_OPTS % p.id, *poll_options) rds.rpush(RDS_KEY_POLL_OPTS % p.id, *poll_options)
@ -326,7 +360,7 @@ def get_comment():
if post.deleted and not check_can_del(username, post.name_hash): if post.deleted and not check_can_del(username, post.name_hash):
abort(451) abort(451)
data = map_comment(post, username) data = map_comments(post, username)
return { return {
'code': 0, 'code': 0,
@ -365,6 +399,7 @@ def do_comment():
) )
post.comments.append(c) post.comments.append(c)
post.comment_timestamp = c.timestamp post.comment_timestamp = c.timestamp
post.n_comments += 1
if post.hot_score != -1: if post.hot_score != -1:
post.hot_score += 1 post.hot_score += 1
@ -386,6 +421,12 @@ def do_comment():
db.session.commit() db.session.commit()
if post.allow_search:
search_db = SearchDB()
search_db.insert(content, 'comment', c.id)
search_db.commit()
del search_db
return { return {
'code': 0, 'code': 0,
'data': pid 'data': pid
@ -397,7 +438,7 @@ def do_comment():
def attention(): def attention():
username = get_current_username() username = get_current_username()
if username[:4] == 'tmp_': if username[:4] == 'tmp_':
abort(403) raise APIError('临时用户无法手动关注')
s = request.form.get('switch', type=int) s = request.form.get('switch', type=int)
if s not in [0, 1]: if s not in [0, 1]:
@ -474,23 +515,30 @@ def delete():
if note and len(note) > 100: if note and len(note) > 100:
abort(422) abort(422)
obj = None # 兼容
if obj_type == 'pid': if obj_type == 'pid':
obj = Post.query.get(obj_id) obj_type = 'post'
elif obj_type == 'cid': elif obj_type == 'cid':
obj_type = 'comment'
obj = None
if obj_type == 'post':
obj = Post.query.get(obj_id)
elif obj_type == 'comment':
obj = Comment.query.get(obj_id) obj = Comment.query.get(obj_id)
if not obj: if not obj:
abort(404) abort(404)
if obj.name_hash == hash_name(username): if obj.name_hash == hash_name(username):
if obj_type == 'pid': if obj_type == 'post':
if len(obj.comments): if obj.n_comments:
abort(403) abort("已经有评论了")
Attention.query.filter_by(pid=obj.id).delete() Attention.query.filter_by(pid=obj.id).delete()
TagRecord.query.filter_by(pid=obj.id).delete() TagRecord.query.filter_by(pid=obj.id).delete()
db.session.delete(obj) db.session.delete(obj)
else: else:
obj.deleted = True obj.deleted = True
elif username in app.config.get('ADMINS'): elif username in app.config.get('ADMINS'):
obj.deleted = True obj.deleted = True
db.session.add(Syslog( db.session.add(Syslog(
@ -507,6 +555,9 @@ def delete():
else: else:
abort(403) abort(403)
if obj_type == 'comment':
obj.post.n_comments -= 1
db.session.commit() db.session.commit()
return {'code': 0} return {'code': 0}

22
migration_search_table.py

@ -0,0 +1,22 @@
from hole import app
from models import SearchDB, Post, db
# One-off migration script: rebuild the full-text search table from the
# existing posts and comments, and backfill ``Post.allow_search``.
search_db = SearchDB()

# Start from an empty FTS5 table; only ``content`` is indexed, the target
# columns are stored alongside it so a hit can be mapped back to its row.
search_db.execute("DROP TABLE IF EXISTS search_content;")
search_db.execute(
    "CREATE VIRTUAL TABLE search_content "
    "USING fts5(content, target_type UNINDEXED, target_id UNINDEXED, tokenize = 'simple');"
)

with app.app_context():
    for post in Post.query.filter_by(deleted=False):
        # A non-empty legacy ``search_text`` is what marks a post as searchable.
        post.allow_search = bool(post.search_text)
        if not post.allow_search:
            continue

        search_db.insert(post.search_text, 'post', post.id)
        for comment in post.comments:
            if comment.deleted:
                continue
            search_db.insert(comment.content, 'comment', comment.id)

    # Flush the search index first, release its connection, then persist the
    # ``allow_search`` flags in the main database.
    search_db.commit()
    del search_db
    db.session.commit()

28
migrations/versions/865bf933ea82_add_n_comments.py

@ -0,0 +1,28 @@
"""add n_comments
Revision ID: 865bf933ea82
Revises: 9ac8682d438c
Create Date: 2021-12-24 20:21:53.928842
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the revision
# it revises (both match the module docstring above).
revision = '865bf933ea82'
down_revision = '9ac8682d438c'
branch_labels = None
depends_on = None
def upgrade():
    """Apply this revision: add a nullable integer ``n_comments`` column to ``post``."""
    n_comments_column = sa.Column('n_comments', sa.Integer(), nullable=True)
    op.add_column('post', n_comments_column)
def downgrade():
    """Revert this revision: drop the ``n_comments`` column from ``post``."""
    table_name, column_name = 'post', 'n_comments'
    op.drop_column(table_name, column_name)

28
migrations/versions/9ac8682d438c_add_bool_can_search.py

@ -0,0 +1,28 @@
"""add bool can_search
Revision ID: 9ac8682d438c
Revises: 91e5c7d37d43
Create Date: 2021-12-24 18:11:27.626988
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the revision
# it revises (both match the module docstring above).
revision = '9ac8682d438c'
down_revision = '91e5c7d37d43'
branch_labels = None
depends_on = None
def upgrade():
    """Apply this revision: add a nullable boolean ``allow_search`` column to ``post``."""
    allow_search_column = sa.Column('allow_search', sa.Boolean(), nullable=True)
    op.add_column('post', allow_search_column)
def downgrade():
    """Revert this revision: drop the ``allow_search`` column from ``post``."""
    table_name, column_name = 'post', 'allow_search'
    op.drop_column(table_name, column_name)

40
models.py

@ -1,8 +1,40 @@
from flask_sqlalchemy import SQLAlchemy
import time import time
from flask_sqlalchemy import SQLAlchemy
import sqlite3
from config import SEARCH_DB, EXT_SIMPLE_URL
# 搜索用的fts表放到单独的database里,为了不影响flask-migrate和避免死锁
db = SQLAlchemy() db = SQLAlchemy()
# Inserts one row into ``search_content``; takes three positional parameters
# (callers in this change pass content, target type, target id).
SEARCH_INSERT_SQL = "INSERT INTO search_content VALUES(?, ?, ?);"
# Full-text lookup; parameters are (match text, LIMIT, OFFSET). Each result
# row is (highlighted content, target_type, target_id), ordered by ``rank``.
# NOTE(review): simple_highlight / simple_query are presumably provided by the
# extension loaded from EXT_SIMPLE_URL — confirm.
SEARCH_QUERY_SQL = "SELECT simple_highlight(search_content, 0, ' **', '** '), target_type, target_id FROM search_content WHERE content MATCH simple_query(?) ORDER BY rank LIMIT ? OFFSET ?;"
class SearchDB:
    """Small wrapper around the dedicated SQLite connection used for search.

    Opens the database at ``SEARCH_DB``, loads the extension at
    ``EXT_SIMPLE_URL`` and runs all statements through a single cursor.

    The connection is released by :meth:`close`, on leaving a ``with`` block,
    or (as before) when the object is garbage-collected, e.g. after
    ``del search_db``. Nothing is committed implicitly: call :meth:`commit`
    before closing if pending writes must be kept.
    """

    def __init__(self):
        """Connect to the search database and load the tokenizer extension.

        Raises whatever ``sqlite3`` raises when connecting or loading the
        extension fails; the connection is closed before the error propagates.
        """
        self.cursor = None
        self.db = sqlite3.connect(SEARCH_DB)
        try:
            self.db.enable_load_extension(True)
            self.db.load_extension(EXT_SIMPLE_URL)
            # Extension loading is only needed for the line above; switch it
            # back off so the connection cannot load further libraries.
            self.db.enable_load_extension(False)
        except Exception:
            # Do not leave a half-initialised connection open until GC.
            self.close()
            raise
        self.cursor = self.db.cursor()

    def __enter__(self):
        """Return ``self`` so the wrapper can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection on leaving the ``with`` block (no auto-commit)."""
        self.close()
        return False

    def __del__(self):
        """Fallback cleanup for callers that rely on ``del search_db``."""
        self.close()

    def close(self):
        """Close the underlying connection; safe to call more than once."""
        # getattr: __init__ may have failed before ``self.db`` was assigned.
        connection = getattr(self, 'db', None)
        self.cursor = None
        self.db = None
        if connection is not None:
            connection.close()

    def execute(self, sql, *params):
        """Run ``sql`` with positional ``params`` bound; return the cursor."""
        return self.cursor.execute(sql, params)

    def commit(self):
        """Commit the pending transaction on the search database."""
        self.db.commit()

    def insert(self, *args):
        """Execute ``SEARCH_INSERT_SQL`` with ``args`` as its parameters."""
        return self.execute(SEARCH_INSERT_SQL, *args)

    def query(self, *args):
        """Execute ``SEARCH_QUERY_SQL`` with ``args``; return the cursor of rows."""
        return self.execute(SEARCH_QUERY_SQL, *args)
class User(db.Model): class User(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
@ -21,10 +53,12 @@ class Post(db.Model):
author_title = db.Column(db.String(10)) author_title = db.Column(db.String(10))
content = db.Column(db.String(4096)) content = db.Column(db.String(4096))
search_text = db.Column(db.String(4096), default='', index=True) search_text = db.Column(db.String(4096), default='', index=True)
allow_search = db.Column(db.Boolean, default=False)
post_type = db.Column(db.String(8)) post_type = db.Column(db.String(8))
cw = db.Column(db.String(32)) cw = db.Column(db.String(32))
file_url = db.Column(db.String(256)) file_url = db.Column(db.String(256))
likenum = db.Column(db.Integer, default=0) likenum = db.Column(db.Integer, default=0)
n_comments = db.Column(db.Integer, default=0)
timestamp = db.Column(db.Integer) timestamp = db.Column(db.Integer)
deleted = db.Column(db.Boolean, default=False) deleted = db.Column(db.Boolean, default=False)
is_reported = db.Column(db.Boolean, default=False) is_reported = db.Column(db.Boolean, default=False)
@ -53,6 +87,10 @@ class Comment(db.Model):
post_id = db.Column(db.Integer, db.ForeignKey('post.id'), post_id = db.Column(db.Integer, db.ForeignKey('post.id'),
nullable=False) nullable=False)
@property
def post(self):
return Post.query.get(self.post_id)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(Comment, self).__init__(**kwargs) super(Comment, self).__init__(**kwargs)
self.timestamp = int(time.time()) self.timestamp = int(time.time())

8
utils.py

@ -69,11 +69,11 @@ def map_post(p, name, mc=50):
'timestamp': p.timestamp, 'timestamp': p.timestamp,
'type': p.post_type, 'type': p.post_type,
'url': p.file_url, 'url': p.file_url,
'reply': len(p.comments), 'reply': p.n_comments,
'comments': map_comment(p, name) if len(p.comments) < mc else None, 'comments': map_comments(p, name) if p.n_comments < mc else None,
'attention': check_attention(name, p.id), 'attention': check_attention(name, p.id),
'can_del': check_can_del(name, p.name_hash), 'can_del': check_can_del(name, p.name_hash),
'allow_search': bool(p.search_text), 'allow_search': p.allow_search,
'poll': None if blocked else gen_poll_dict(p.id, name), 'poll': None if blocked else gen_poll_dict(p.id, name),
'author_title': p.author_title 'author_title': p.author_title
} }
@ -125,7 +125,7 @@ def is_blocked(target_name_hash, name):
return False return False
def map_comment(p, name): def map_comments(p, name):
names = {p.name_hash: 0} names = {p.name_hash: 0}

Loading…
Cancel
Save