import math
import re
from collections import Counter
from functools import lru_cache
from typing import Any, Dict, List, Tuple

from .qa import load_qa_data, top_questions

# Matches runs of CJK Unified Ideographs (U+4E00..U+9FFF), ASCII letters, and digits.
_WORD_RE = re.compile(r"[\u4e00-\u9fffA-Za-z0-9]+")


def _normalize(text: str) -> str:
    """Drop punctuation/whitespace, keeping CJK + alphanumerics, lowercased."""
    if not text:
        return ""
    return "".join(_WORD_RE.findall(text)).lower()


def _tokens(text: str) -> List[str]:
    """
    Lightweight tokenizer for Chinese retrieval:
    - character-level normalization first (CJK ideographs + alphanumerics only)
    - emit 2-grams plus 1-grams (1-grams keep very short queries matchable)
    """
    s = _normalize(text)
    if not s:
        return []
    chars = list(s)
    if len(chars) == 1:
        return chars
    bigrams = [chars[i] + chars[i + 1] for i in range(len(chars) - 1)]
    return bigrams + chars


def _item_tokens(item: Dict[str, Any]) -> List[str]:
    """
    Support "manual tokenization":
    - if the record supplies ``tokens`` (list[str], or a whitespace-separated
      str), use it directly
    - otherwise fall back to the local rule-based tokenizer (no model needed)

    Fix: an empty or whitespace-only manual token *list* previously produced
    an empty document; it now falls back to rule tokenization, matching the
    existing behavior of the empty-string form.
    """
    provided = item.get("tokens")
    if isinstance(provided, list) and all(isinstance(x, str) for x in provided):
        cleaned = [x.strip().lower() for x in provided if x and x.strip()]
        if cleaned:  # empty/blank list falls through, same as an empty str
            return cleaned
    elif isinstance(provided, str) and provided.strip():
        return [x.strip().lower() for x in provided.split() if x.strip()]
    text = f"{item.get('question', '')}\n{item.get('answer', '')}"
    return _tokens(text)


@lru_cache(maxsize=1)
def _build_index() -> Tuple[List[Dict[str, Any]], List[Counter], List[int], Dict[str, float]]:
    """Build and memoize the in-memory BM25 index over the QA corpus.

    Returns:
        (docs, per-doc term-frequency Counters, per-doc lengths, idf table).
    """
    docs = load_qa_data() or []
    tfs: List[Counter] = []
    doc_lens: List[int] = []
    df: Counter = Counter()
    for item in docs:
        tf = Counter(_item_tokens(item))
        tfs.append(tf)
        doc_lens.append(sum(tf.values()))
        df.update(tf.keys())  # keys are unique, so each term counts once per doc
    n = max(len(docs), 1)
    idf: Dict[str, float] = {}
    for term, freq in df.items():
        # BM25 idf; the +1 inside the log keeps it strictly positive.
        idf[term] = math.log(1 + (n - freq + 0.5) / (freq + 0.5))
    return docs, tfs, doc_lens, idf


def _bm25_scores(query: str, k1: float = 1.5, b: float = 0.75) -> List[Tuple[int, float]]:
    """Score every indexed doc against *query* with Okapi BM25.

    Args:
        query: raw user query; tokenized with :func:`_tokens`.
        k1: term-frequency saturation parameter.
        b: document-length normalization strength.

    Returns:
        (doc_index, score) pairs sorted best-first; zero-score docs omitted.
    """
    q_tokens = _tokens(query)
    if not q_tokens:
        return []
    docs, tfs, doc_lens, idf = _build_index()
    avgdl = sum(doc_lens) / len(doc_lens) if doc_lens else 0.0
    q_tf = Counter(q_tokens)
    scored: List[Tuple[int, float]] = []
    for i, tf in enumerate(tfs):
        score = 0.0
        dl = doc_lens[i]
        # Length-normalization factor; a degenerate corpus (avgdl == 0) skips it.
        denom_norm = 1.0 - b + b * (dl / avgdl) if avgdl > 0 else 1.0
        for term, qf in q_tf.items():
            f = tf.get(term, 0)
            if not f:
                continue
            term_idf = idf.get(term, 0.0)
            denom = f + k1 * denom_norm
            # Weighting by query term frequency (qf) is equivalent to summing
            # over query tokens with multiplicity.
            score += term_idf * (f * (k1 + 1) / denom) * qf
        if score > 0:
            scored.append((i, score))
    scored.sort(key=lambda x: x[1], reverse=True)
    return scored


def default_items(limit: int = 10) -> List[Dict[str, Any]]:
    """Items shown when there is no query.

    Uses the configured hot questions when present; otherwise returns the
    first *limit* docs ordered by id (stable order, never random).
    """
    items = top_questions() or []
    if items:
        return items[:limit]
    docs, *_ = _build_index()
    return sorted(docs, key=lambda x: x.get("id", 0))[:limit]


def search_items(query: str, limit: int = 10) -> List[Dict[str, Any]]:
    """BM25 search over the QA corpus.

    Blank queries fall back to :func:`default_items`; a query with no
    matching documents returns an empty list.
    """
    q = (query or "").strip()
    if not q:
        return default_items(limit)
    docs, *_ = _build_index()
    ranked = _bm25_scores(q)
    return [docs[i] for i, _score in ranked[:limit]]


def search_pairs(query: str, limit: int = 5) -> List[Dict[str, str]]:
    """Return search hits reduced to question/answer string pairs."""
    items = search_items(query, limit=limit)
    return [{"question": x.get("question", ""), "answer": x.get("answer", "")} for x in items]