nianjie/dialog/backend/rag.py
2026-01-11 18:52:11 +08:00

265 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import os
import sys
import types
from functools import lru_cache
from pathlib import Path
from typing import List, Dict, Any

import numpy as np
from minirag import MiniRAG, QueryParam, minirag as minirag_mod
from minirag.base import BaseKVStorage, BaseVectorStorage, BaseGraphStorage
from minirag.utils import wrap_embedding_func_with_attrs, compute_mdhash_id
from sentence_transformers import SentenceTransformer

from .config import PROJECT_ROOT, QA_PATH
# Environment: turn off minirag's entity extraction to avoid extra dependencies.
# Only applied when the caller has not already set the variable.
# NOTE(review): this runs after `from minirag import ...`; if minirag reads the
# variable at import time rather than at call time it has no effect — confirm.
if "MINIRAG_DISABLE_ENTITY_EXTRACT" not in os.environ:
    os.environ["MINIRAG_DISABLE_ENTITY_EXTRACT"] = "1"
# --------- 为 pip 版 minirag 注入轻量存储实现,避免缺失 kg 模块 ---------
class _JsonKVStorage(BaseKVStorage):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.storage: Dict[str, dict] = {}
async def all_keys(self):
return list(self.storage.keys())
async def get_by_id(self, id: str):
return self.storage.get(id)
async def get_by_ids(self, ids, fields=None):
out = []
for i in ids:
v = self.storage.get(i)
if v and fields:
v = {k: v[k] for k in fields if k in v}
out.append(v)
return out
async def filter_keys(self, data):
return {k for k in data if k not in self.storage}
async def upsert(self, data):
self.storage.update(data)
async def drop(self):
self.storage.clear()
async def index_done_callback(self):
return
async def query_done_callback(self):
return
class _NanoVectorDBStorage(BaseVectorStorage):
    """In-memory vector store that ranks records by cosine similarity.

    Stand-in for minirag's ``NanoVectorDBStorage``. Upserted records are kept in
    ``storage`` and the embedding of each record's ``"content"`` field in
    ``embeddings``, both keyed by record id; nothing is written to disk.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.storage: Dict[str, Dict[str, Any]] = {}
        self.embeddings: Dict[str, np.ndarray] = {}

    async def query(self, query: str, top_k: int):
        """Return up to ``top_k`` stored records most similar to ``query``.

        Each hit is a shallow copy of the stored record plus ``"id"`` and
        ``"score"`` (cosine similarity as a Python float). Hits come in
        descending score order, and ties keep insertion order. Returns ``[]``
        when the store is empty or ``top_k <= 0``.
        """
        # A negative top_k would otherwise slice off the tail instead of
        # returning nothing.
        if not self.embeddings or top_k <= 0:
            return []
        q_embs = await self.embedding_func([query])
        q_emb = np.asarray(q_embs[0], dtype=np.float32)

        keys = list(self.embeddings)
        matrix = np.stack([self.embeddings[k] for k in keys])
        # Score every stored vector in one vectorised pass; the query norm is
        # computed once instead of once per stored vector.
        denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(q_emb)
        denom[denom == 0] = 1e-6  # avoid division by zero for all-zero vectors
        sims = (matrix @ q_emb) / denom

        hits = []
        # Stable sort on negated scores: descending order, ties in insertion order.
        for idx in np.argsort(-sims, kind="stable")[:top_k]:
            key = keys[idx]
            item = dict(self.storage.get(key, {}))
            item.update({"id": key, "score": float(sims[idx])})
            hits.append(item)
        return hits

    async def upsert(self, data):
        """Embed each record's ``"content"`` and store the record and its vector.

        ``data`` maps record id -> record dict. Existing ids are overwritten.
        """
        if not data:
            return  # nothing to embed; skip calling the embedder with an empty batch
        texts = [v.get("content", "") for v in data.values()]
        embs = await self.embedding_func(texts)
        if isinstance(embs, np.ndarray):
            embs = list(embs)
        for (key, record), emb in zip(data.items(), embs):
            self.storage[key] = record
            self.embeddings[key] = np.array(emb, dtype=np.float32)

    async def index_done_callback(self):
        return

    async def query_done_callback(self):
        return
class _NetworkXStorage(BaseGraphStorage):
    """Minimal dict-backed stand-in for minirag's ``NetworkXStorage``.

    Nodes live in ``nodes`` (node id -> attributes). Edges live in ``edges``,
    keyed by the ordered ``(source, target)`` pair. Type, degree, edge-listing
    and embedding queries are stubs: they return constant empty/zero values
    regardless of what has been stored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.nodes: Dict[str, dict] = {}
        self.edges: Dict[tuple, dict] = {}

    # ---- stubbed queries: constant results, independent of stored data ----
    async def get_types(self):
        return [], []

    async def node_degree(self, node_id: str):
        return 0

    async def edge_degree(self, src_id: str, tgt_id: str):
        return 0

    async def get_node_edges(self, source_node_id: str):
        return []

    async def embed_nodes(self, algorithm: str):
        return np.zeros((0,)), []

    # ---- lookups ----
    async def has_node(self, node_id: str):
        return node_id in self.nodes

    async def has_edge(self, source_node_id: str, target_node_id: str):
        pair = (source_node_id, target_node_id)
        return pair in self.edges

    async def get_node(self, node_id: str):
        return self.nodes.get(node_id)

    async def get_edge(self, source_node_id: str, target_node_id: str):
        pair = (source_node_id, target_node_id)
        return self.edges.get(pair)

    # ---- mutations ----
    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        self.nodes[node_id] = node_data

    async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
        pair = (source_node_id, target_node_id)
        self.edges[pair] = edge_data

    async def delete_node(self, node_id: str):
        """Remove the node and every edge that has it as source or target."""
        self.nodes.pop(node_id, None)
        stale = [pair for pair in self.edges if pair[0] == node_id or pair[1] == node_id]
        for pair in stale:
            del self.edges[pair]

    async def index_done_callback(self):
        return None

    async def query_done_callback(self):
        return None
class _JsonDocStatusStorage(_JsonKVStorage):
    """Document-status store; behaves exactly like ``_JsonKVStorage``."""
# Publish the in-memory storages as a synthetic ``minirag.simple_storage`` module
# and point minirag's ``STORAGES`` registry (storage class name -> module path)
# at it. Presumably MiniRAG resolves its storage classes through this registry.
_simple_module = types.ModuleType("minirag.simple_storage")
for _attr_name, _storage_cls in (
    ("JsonKVStorage", _JsonKVStorage),
    ("NanoVectorDBStorage", _NanoVectorDBStorage),
    ("NetworkXStorage", _NetworkXStorage),
    ("JsonDocStatusStorage", _JsonDocStatusStorage),
):
    setattr(_simple_module, _attr_name, _storage_cls)
del _attr_name, _storage_cls
sys.modules[_simple_module.__name__] = _simple_module
minirag_mod.STORAGES.update(
    dict.fromkeys(
        ("NetworkXStorage", "JsonKVStorage", "NanoVectorDBStorage", "JsonDocStatusStorage"),
        "minirag.simple_storage",
    )
)
# -------------------------------------------------------------------------
# Locations of the local embedding model and of MiniRAG's working directory.
MODEL_DIR = PROJECT_ROOT.parent.joinpath(
    "minirag", "minirag", "models", "bge-small-zh-v1.5"
).resolve()
WORKDIR = PROJECT_ROOT.joinpath("minirag_cache").resolve()
# Created at import time; it is later passed to MiniRAG as ``working_dir``.
WORKDIR.mkdir(exist_ok=True, parents=True)
# Loading of the FAQ question/answer records.
def _load_qas() -> List[Dict[str, Any]]:
    """Read and parse the QA JSON file at ``QA_PATH`` (decoded as UTF-8).

    Returns:
        The decoded JSON document. It is assumed to be a list of dicts with
        ``id`` / ``question`` / ``answer`` keys (that is how ``_rag_bundle``
        consumes it); the shape is not validated here.

    Raises:
        OSError: if the file cannot be read.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    return json.loads(QA_PATH.read_text(encoding="utf-8"))
def _build_embedder():
    """Load the local bge-small-zh model and wrap it as a MiniRAG embedding function.

    The SentenceTransformer at ``MODEL_DIR`` is loaded on CPU. The returned
    async callable accepts a single string or a list of strings. It returns
    normalised NumPy embeddings, and is tagged with the model's embedding
    dimension and a 512-token limit.
    """
    encoder = SentenceTransformer(str(MODEL_DIR), device="cpu")
    dim = encoder.get_sentence_embedding_dimension()

    @wrap_embedding_func_with_attrs(embedding_dim=dim, max_token_size=512)
    async def embed(texts):
        batch = [texts] if isinstance(texts, str) else texts
        return encoder.encode(batch, normalize_embeddings=True, convert_to_numpy=True)

    return embed
@lru_cache(maxsize=1)
def _rag_bundle():
    """Build, once per process, a MiniRAG index over all FAQ entries.

    Every QA record is rendered as a single ``Q...\\nA...`` text and inserted
    into a MiniRAG instance whose LLM function is a no-op.

    Returns:
        ``(rag, id_to_qa)``: the populated MiniRAG instance, and a dict mapping
        each chunk id (``compute_mdhash_id(text, prefix="chunk-")``) to its
        original QA record.
    """
    qas = _load_qas()
    embed = _build_embedder()

    async def _no_llm(*args, **kwargs):
        # MiniRAG awaits its LLM function. A plain lambda returning "" would
        # raise TypeError ("can't be used in 'await' expression") if it were
        # ever called, so the stub must be a coroutine function.
        return ""

    rag = MiniRAG(
        working_dir=str(WORKDIR),
        embedding_func=embed,
        chunk_token_size=1200,  # large enough that each Q+A stays one chunk (no re-splitting)
        chunk_overlap_token_size=0,
        llm_model_func=_no_llm,  # retrieval never calls the LLM
        log_level="WARNING",
    )

    # Map each chunk id back to its source QA record.
    # NOTE(review): a QA longer than chunk_token_size tokens would be split, and
    # its pieces would then have no entry in id_to_qa.
    chunks = []
    id_to_qa = {}
    for qa in qas:
        raw = f"Q{qa.get('id')}{qa.get('question','')}\nA{qa.get('answer','')}"
        # Upstream MiniRAG strips surrounding whitespace before hashing doc/chunk
        # ids (confirm for the installed version). Strip here too, so the id
        # computed below matches the stored one even when an answer ends in
        # whitespace.
        chunk_text = raw.strip()
        chunks.append(chunk_text)
        id_to_qa[compute_mdhash_id(chunk_text, prefix="chunk-")] = qa

    # asyncio.run always closes its loop (even if ainsert raises) and does not
    # leave a closed loop installed as the thread's current event loop.
    asyncio.run(rag.ainsert(chunks))
    return rag, id_to_qa
def search_rag(query: str, limit: int = 5) -> List[Dict[str, str]]:
    """Retrieve FAQ entries for ``query`` via MiniRAG as question/answer pairs.

    Performs exactly the same retrieval as ``search_rag_full`` and drops the
    ``id`` field. Delegating keeps the two entry points from drifting apart.

    Args:
        query: Free-text user query.
        limit: Maximum number of vector-store hits to consider.

    Returns:
        Dicts with ``question`` and ``answer`` keys, in retrieval order. Hits
        that do not map to a QA record are skipped, so fewer than ``limit``
        entries may be returned.
    """
    return [
        {"question": hit["question"], "answer": hit["answer"]}
        for hit in search_rag_full(query, limit)
    ]
def search_rag_full(query: str, limit: int = 10) -> List[Dict[str, Any]]:
    """Retrieve FAQ entries with ``id`` / ``question`` / ``answer`` for the FAQ stage.

    Args:
        query: Free-text user query.
        limit: Maximum number of vector-store hits to consider; values <= 0
            return ``[]``.

    Returns:
        One dict per hit whose chunk id maps to a QA record, in the order
        returned by ``rag.chunks_vdb.query``. Unmapped hits are skipped, so
        fewer than ``limit`` entries may be returned.
    """
    rag, id_to_qa = _rag_bundle()
    if limit <= 0:
        return []

    # asyncio.run creates a fresh loop, closes it afterwards, and resets the
    # thread's current loop instead of leaving a closed loop installed.
    hits = asyncio.run(rag.chunks_vdb.query(query, top_k=limit))

    out = []
    for hit in hits:
        qa = id_to_qa.get(hit.get("id"))
        if not qa:
            continue
        out.append({"id": qa.get("id"), "question": qa.get("question", ""), "answer": qa.get("answer", "")})
    return out