# nianjie/backend/retrieval.py

import math
import re
from collections import Counter
from functools import lru_cache
from typing import Any, Dict, List, Tuple

from .qa import load_qa_data, top_questions

# Match runs of CJK Unified Ideographs plus ASCII letters and digits.
_WORD_RE = re.compile(r"[\u4e00-\u9fffA-Za-z0-9]+")


def _normalize(text: str) -> str:
    if not text:
        return ""
    return "".join(_WORD_RE.findall(text)).lower()


def _tokens(text: str) -> List[str]:
    """
    Lightweight tokenizer for Chinese retrieval:
    - Character-level normalization first (keeps CJK Unified Ideographs
      plus alphanumerics).
    - Emits 2-grams plus 1-grams, so very short queries still match.
    """
    s = _normalize(text)
    if not s:
        return []
    chars = list(s)
    if len(chars) == 1:
        return chars
    bigrams = [chars[i] + chars[i + 1] for i in range(len(chars) - 1)]
    return bigrams + chars
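
# For illustration: a four-character string yields its three bigrams followed
# by its four unigrams, e.g.
#   _tokens("你好世界")  ->  ["你好", "好世", "世界", "你", "好", "世", "界"]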


def _item_tokens(item: Dict[str, Any]) -> List[str]:
    """
    Supports "manual tokenization":
    - If the data provides `tokens` (a list[str], or a whitespace-separated
      str), use it directly.
    - Otherwise fall back to the local rule-based tokenizer (no model
      dependency).
    """
    provided = item.get("tokens")
    if isinstance(provided, list) and all(isinstance(x, str) for x in provided):
        return [x.strip().lower() for x in provided if x and x.strip()]
    if isinstance(provided, str) and provided.strip():
        return [x.strip().lower() for x in provided.split() if x.strip()]
    text = f"{item.get('question', '')}\n{item.get('answer', '')}"
    return _tokens(text)
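
# Illustrative item shapes (the token values here are made-up examples):
#   {"question": "...", "answer": "...", "tokens": ["退款", "流程"]}
#   {"question": "...", "answer": "...", "tokens": "退款 流程"}
# Both yield ["退款", "流程"]; items without `tokens` fall back to _tokens()
# over the concatenated question and answer text.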


@lru_cache(maxsize=1)
def _build_index() -> Tuple[List[Dict[str, Any]], List[Counter], List[int], Dict[str, float]]:
    """Build and cache the in-memory index: the docs themselves, per-doc term
    frequencies, per-doc lengths, and a BM25 IDF table."""
    docs = load_qa_data() or []
    tfs: List[Counter] = []
    doc_lens: List[int] = []
    df: Counter = Counter()
    for item in docs:
        tok = _item_tokens(item)
        tf = Counter(tok)
        tfs.append(tf)
        doc_lens.append(sum(tf.values()))
        df.update(set(tf.keys()))
    n = max(len(docs), 1)
    idf: Dict[str, float] = {}
    for term, freq in df.items():
        # BM25 IDF: log(1 + (n - df + 0.5) / (df + 0.5))
        idf[term] = math.log(1 + (n - freq + 0.5) / (freq + 0.5))
    return docs, tfs, doc_lens, idf
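
# Worked example of the IDF above (hypothetical corpus): with n = 100 docs
# and a term in 1 of them, idf = log(1 + 99.5 / 1.5) ≈ 4.21; a term in all
# 100 docs gets log(1 + 0.5 / 100.5) ≈ 0.005, so near-ubiquitous terms
# contribute almost nothing to the score.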


def _bm25_scores(query: str, k1: float = 1.5, b: float = 0.75) -> List[Tuple[int, float]]:
    """Score every doc against `query` with BM25; return (doc_index, score)
    pairs sorted by descending score, keeping only positive scores."""
    q_tokens = _tokens(query)
    if not q_tokens:
        return []
    docs, tfs, doc_lens, idf = _build_index()
    avgdl = (sum(doc_lens) / max(len(doc_lens), 1)) if doc_lens else 0.0
    q_tf = Counter(q_tokens)
    scored: List[Tuple[int, float]] = []
    for i, tf in enumerate(tfs):
        score = 0.0
        dl = doc_lens[i] or 0
        # Document-length normalization; degenerates to 1.0 on an empty corpus.
        denom_norm = (1.0 - b + b * (dl / avgdl)) if avgdl > 0 else 1.0
        for term, qf in q_tf.items():
            f = tf.get(term, 0)
            if not f:
                continue
            term_idf = idf.get(term, 0.0)
            denom = f + k1 * denom_norm
            score += term_idf * (f * (k1 + 1) / denom) * qf
        if score > 0:
            scored.append((i, score))
    scored.sort(key=lambda x: x[1], reverse=True)
    return scored
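
# The per-term contribution above is the standard BM25 form, weighted by the
# query-side term frequency qf:
#   idf(t) * (f * (k1 + 1)) / (f + k1 * (1 - b + b * dl / avgdl)) * qf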


def default_items(limit: int = 10) -> List[Dict[str, Any]]:
    items = top_questions() or []
    if items:
        return items[:limit]
    docs, *_ = _build_index()
    # No "top questions" configured: return the first N docs ordered by id,
    # so the result is stable rather than random.
    ordered = sorted(docs, key=lambda x: x.get("id", 0))
    return ordered[:limit]


def search_items(query: str, limit: int = 10) -> List[Dict[str, Any]]:
    q = (query or "").strip()
    if not q:
        return default_items(limit)
    docs, *_ = _build_index()
    ranked = _bm25_scores(q)
    if not ranked:
        return []
    return [docs[i] for i, _score in ranked[:limit]]


def search_pairs(query: str, limit: int = 5) -> List[Dict[str, str]]:
    items = search_items(query, limit=limit)
    return [{"question": x.get("question", ""), "answer": x.get("answer", "")} for x in items]