import math
import re
from collections import Counter
from functools import lru_cache
from typing import Any, Dict, List, Tuple

from .qa import load_qa_data, top_questions


_WORD_RE = re.compile(r"[\u4e00-\u9fffA-Za-z0-9]+")


def _normalize(text: str) -> str:
    if not text:
        return ""
    return "".join(_WORD_RE.findall(text)).lower()


def _tokens(text: str) -> List[str]:
    """
    Lightweight tokenization for Chinese retrieval:
    - character-level normalization first (keep CJK Unified Ideographs + alphanumerics)
    - emit 2-grams plus 1-grams (so short queries still match)
    """
    s = _normalize(text)
    if not s:
        return []
    chars = list(s)
    if len(chars) == 1:
        return chars
    bigrams = [chars[i] + chars[i + 1] for i in range(len(chars) - 1)]
    return bigrams + chars
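
# For reference, a few concrete inputs (derived from the rules above; the
# strings themselves are illustrative only):
#   _tokens("退货")  -> ["退货", "退", "货"]          (one bigram + both unigrams)
#   _tokens("abc")   -> ["ab", "bc", "a", "b", "c"]
#   _tokens("A b")   -> ["ab", "a", "b"]             (whitespace/punctuation dropped)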


def _item_tokens(item: Dict[str, Any]) -> List[str]:
    """
    Supports "manual tokenization":
    - if the item provides tokens (a list[str] or a whitespace-separated str), use them as-is
    - otherwise fall back to the local rule-based tokenizer (no model dependency)
    """
    provided = item.get("tokens")
    if isinstance(provided, list) and all(isinstance(x, str) for x in provided):
        return [x.strip().lower() for x in provided if x and x.strip()]
    if isinstance(provided, str) and provided.strip():
        return [x.strip().lower() for x in provided.split() if x.strip()]

    text = f"{item.get('question', '')}\n{item.get('answer', '')}"
    return _tokens(text)
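
# A hypothetical item illustrating both paths (field values are made up):
#   {"question": "如何退货?", "tokens": "退货 流程"}  -> ["退货", "流程"]
#   {"question": "如何退货?"}                         -> _tokens("如何退货?\n")
# Note that queries are always tokenized with _tokens(), so manual tokens only
# match when they line up with the query's 1/2-gram space (e.g. "退货").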


@lru_cache(maxsize=1)
def _build_index() -> Tuple[List[Dict[str, Any]], List[Counter], List[int], Dict[str, float]]:
    docs = load_qa_data() or []
    tfs: List[Counter] = []
    doc_lens: List[int] = []
    df: Counter = Counter()

    for item in docs:
        tok = _item_tokens(item)
        tf = Counter(tok)
        tfs.append(tf)
        doc_lens.append(sum(tf.values()))
        df.update(set(tf.keys()))

    n = max(len(docs), 1)
    idf: Dict[str, float] = {}
    for term, freq in df.items():
        # BM25 idf (Lucene-style variant): ln(1 + (n - df + 0.5) / (df + 0.5)),
        # which stays non-negative even for terms appearing in most documents
        idf[term] = math.log(1 + (n - freq + 0.5) / (freq + 0.5))

    return docs, tfs, doc_lens, idf
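
# Rough magnitudes for the idf above (n = 100 documents, assumed counts):
#   term in 10 docs: ln(1 + 90.5 / 10.5) ≈ 2.26   (rare   -> high weight)
#   term in 90 docs: ln(1 + 10.5 / 90.5) ≈ 0.11   (common -> low weight)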


def _bm25_scores(query: str, k1: float = 1.5, b: float = 0.75) -> List[Tuple[int, float]]:
    q_tokens = _tokens(query)
    if not q_tokens:
        return []

    docs, tfs, doc_lens, idf = _build_index()
    avgdl = (sum(doc_lens) / max(len(doc_lens), 1)) if doc_lens else 0.0

    q_tf = Counter(q_tokens)
    scored: List[Tuple[int, float]] = []
    for i, tf in enumerate(tfs):
        score = 0.0
        dl = doc_lens[i] or 0
        # document-length normalization: 1 - b + b * (dl / avgdl)
        denom_norm = (1.0 - b + b * (dl / avgdl)) if avgdl > 0 else 1.0
        for term, qf in q_tf.items():
            f = tf.get(term, 0)
            if not f:
                continue
            term_idf = idf.get(term, 0.0)
            denom = f + k1 * denom_norm
            # term-frequency saturation, weighted by the term's count in the query
            score += term_idf * (f * (k1 + 1) / denom) * qf
        if score > 0:
            scored.append((i, score))

    scored.sort(key=lambda x: x[1], reverse=True)
    return scored
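
# The loop above implements the classic BM25 term score,
#   idf(t) * f(t, D) * (k1 + 1) / (f(t, D) + k1 * (1 - b + b * |D| / avgdl)),
# summed over query terms. Weighting by raw qf is a simplification of the
# usual k3 query-term saturation component (it corresponds to k3 -> infinity),
# which is harmless here since 1/2-gram queries rarely repeat terms.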


def default_items(limit: int = 10) -> List[Dict[str, Any]]:
    items = top_questions() or []
    if items:
        return items[:limit]

    docs, *_ = _build_index()
    # no "top questions" configured: return the first N docs sorted by id
    # (stable ordering, avoids returning a random-looking selection)
    ordered = sorted(docs, key=lambda x: x.get("id", 0))
    return ordered[:limit]


def search_items(query: str, limit: int = 10) -> List[Dict[str, Any]]:
    q = (query or "").strip()
    if not q:
        return default_items(limit)

    docs, *_ = _build_index()
    ranked = _bm25_scores(q)
    if not ranked:
        return []

    return [docs[i] for i, _score in ranked[:limit]]


def search_pairs(query: str, limit: int = 5) -> List[Dict[str, str]]:
    items = search_items(query, limit=limit)
    return [{"question": x.get("question", ""), "answer": x.get("answer", "")} for x in items]
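
# Usage sketch (the query string is hypothetical; actual results depend on the
# data that load_qa_data() returns):
#   items = search_items("退货流程", limit=5)  # ranked QA item dicts
#   pairs = search_pairs("退货流程")           # [{"question": ..., "answer": ...}]
# An empty/whitespace query falls back to default_items(); a query with no
# BM25 hits returns [] rather than padding with unrelated items.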