import asyncio
import os
import sys
import types
from functools import lru_cache
from pathlib import Path
from typing import List, Dict, Any

import numpy as np
from minirag import MiniRAG, QueryParam, minirag as minirag_mod
from minirag.base import BaseKVStorage, BaseVectorStorage, BaseGraphStorage
from minirag.utils import wrap_embedding_func_with_attrs, compute_mdhash_id
from sentence_transformers import SentenceTransformer

from .config import PROJECT_ROOT, QA_PATH

# Environment setup: disable entity extraction to avoid extra dependencies.
# NOTE(review): this runs after `minirag` has been imported; if minirag reads
# the variable at import time this has no effect — confirm.
os.environ.setdefault("MINIRAG_DISABLE_ENTITY_EXTRACT", "1")


# --------- Lightweight storage implementations injected into the pip build of
# minirag, so the missing `kg` storage modules are not needed. ---------


class _JsonKVStorage(BaseKVStorage):
    """Key/value storage kept entirely in an in-memory dict.

    Despite the name, nothing is serialized: the index/query callbacks are
    no-ops, so all data lives only for the lifetime of the process.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.storage: Dict[str, dict] = {}

    async def all_keys(self):
        return list(self.storage.keys())

    async def get_by_id(self, id: str):
        return self.storage.get(id)

    async def get_by_ids(self, ids, fields=None):
        """Return the value for each id (None when missing), in order.

        When *fields* is given, each found value is projected down to those keys.
        """
        out = []
        for i in ids:
            v = self.storage.get(i)
            if v and fields:
                v = {k: v[k] for k in fields if k in v}
            out.append(v)
        return out

    async def filter_keys(self, data):
        """Return the subset of *data* keys that are not stored yet."""
        return {k for k in data if k not in self.storage}

    async def upsert(self, data):
        self.storage.update(data)

    async def drop(self):
        self.storage.clear()

    async def index_done_callback(self):
        return

    async def query_done_callback(self):
        return


class _NanoVectorDBStorage(BaseVectorStorage):
    """In-memory vector storage that ranks entries by cosine similarity."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.storage: Dict[str, Dict[str, Any]] = {}  # key -> original payload
        self.embeddings: Dict[str, np.ndarray] = {}  # key -> float32 vector

    async def query(self, query: str, top_k: int):
        """Return up to *top_k* stored items most similar to *query*.

        Each result is a copy of the stored payload plus ``id`` (the storage
        key) and ``score`` (cosine similarity), ordered by descending score;
        equal scores keep insertion order. Returns ``[]`` when nothing is
        stored or ``top_k <= 0``.
        """
        if not self.embeddings or top_k <= 0:
            return []
        q_embs = await self.embedding_func([query])
        q_emb = np.asarray(q_embs[0], dtype=np.float32)

        keys = list(self.embeddings)
        # Score all stored vectors in one matrix-vector product rather than a
        # Python-level loop over every entry.
        matrix = np.stack([self.embeddings[k] for k in keys])
        denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(q_emb)
        # Avoid division by zero for zero-norm vectors (same 1e-6 fallback).
        denom = np.where(denom == 0, 1e-6, denom)
        scores = (matrix @ q_emb) / denom

        # Stable sort so ties keep insertion order, as list.sort() did.
        order = np.argsort(-scores, kind="stable")[:top_k]
        res = []
        for idx in order:
            key = keys[idx]
            item = dict(self.storage.get(key, {}))
            item.update({"id": key, "score": float(scores[idx])})
            res.append(item)
        return res

    async def upsert(self, data):
        """Embed each payload's ``content`` and store payload + vector by key."""
        if not data:
            # Nothing to embed; skip calling the embedder with an empty batch.
            return
        texts = [v.get("content", "") for v in data.values()]
        embs = await self.embedding_func(texts)
        if isinstance(embs, np.ndarray):
            embs = list(embs)
        for (k, v), emb in zip(data.items(), embs):
            self.storage[k] = v
            self.embeddings[k] = np.array(emb, dtype=np.float32)

    async def index_done_callback(self):
        return

    async def query_done_callback(self):
        return


class _NetworkXStorage(BaseGraphStorage):
    """Minimal in-memory graph storage stub.

    Nodes and edges are kept in dicts (edges keyed by ``(source, target)``).
    Degree, neighbour, type and embedding queries return empty/zero values,
    so no graph library is required.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.nodes: Dict[str, dict] = {}
        self.edges: Dict[tuple, dict] = {}

    async def get_types(self):
        return [], []

    async def has_node(self, node_id: str):
        return node_id in self.nodes

    async def has_edge(self, source_node_id: str, target_node_id: str):
        return (source_node_id, target_node_id) in self.edges

    async def node_degree(self, node_id: str):
        return 0

    async def edge_degree(self, src_id: str, tgt_id: str):
        return 0

    async def get_node(self, node_id: str):
        return self.nodes.get(node_id)

    async def get_edge(self, source_node_id: str, target_node_id: str):
        return self.edges.get((source_node_id, target_node_id))

    async def get_node_edges(self, source_node_id: str):
        return []

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        self.nodes[node_id] = node_data

    async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
        self.edges[(source_node_id, target_node_id)] = edge_data

    async def delete_node(self, node_id: str):
        """Remove a node and every edge touching it (in either direction)."""
        self.nodes.pop(node_id, None)
        for k in list(self.edges.keys()):
            if k[0] == node_id or k[1] == node_id:
                self.edges.pop(k, None)

    async def embed_nodes(self, algorithm: str):
        return np.zeros((0,)), []

    async def index_done_callback(self):
        return

    async def query_done_callback(self):
        return


class _JsonDocStatusStorage(_JsonKVStorage):
    """Document-status storage; identical to the in-memory KV storage."""


# Synthetic module exposing the storages under the names minirag looks up.
_simple_module = types.ModuleType("minirag.simple_storage")
_simple_module.JsonKVStorage = _JsonKVStorage
_simple_module.NanoVectorDBStorage = _NanoVectorDBStorage
_simple_module.NetworkXStorage = _NetworkXStorage
_simple_module.JsonDocStatusStorage = _JsonDocStatusStorage
sys.modules["minirag.simple_storage"] = _simple_module
# Point minirag's storage-name table at the injected module (presumably minirag
# imports storage classes by these module paths — confirm for the pip build).
minirag_mod.STORAGES.update({
    "NetworkXStorage": "minirag.simple_storage",
    "JsonKVStorage": "minirag.simple_storage",
    "NanoVectorDBStorage": "minirag.simple_storage",
    "JsonDocStatusStorage": "minirag.simple_storage",
})
# -------------------------------------------------------------------------

# Paths to the embedding model and the MiniRAG working directory.
MODEL_DIR = (PROJECT_ROOT.parent / "minirag" / "minirag" / "models" / "bge-small-zh-v1.5").resolve()
WORKDIR = (PROJECT_ROOT / "minirag_cache").resolve()
WORKDIR.mkdir(parents=True, exist_ok=True)


def _load_qas() -> List[Dict[str, Any]]:
    """Load the QA records from ``QA_PATH`` (UTF-8 encoded JSON)."""
    import json

    return json.loads(QA_PATH.read_text(encoding="utf-8"))


def _build_embedder():
    """Build an async embedding function backed by the local SentenceTransformer.

    The returned coroutine function accepts a string or a list of strings and
    returns L2-normalised embeddings as a numpy array, encoded on CPU.
    """
    model = SentenceTransformer(str(MODEL_DIR), device="cpu")
    emb_dim = model.get_sentence_embedding_dimension()

    @wrap_embedding_func_with_attrs(embedding_dim=emb_dim, max_token_size=512)
    async def embed(texts):
        if isinstance(texts, str):
            texts = [texts]
        # Note: encode() is synchronous and blocks the event loop while it runs.
        return model.encode(texts, normalize_embeddings=True, convert_to_numpy=True)

    return embed


async def _no_llm(*args, **kwargs) -> str:
    """Placeholder LLM that always returns an empty string.

    Retrieval in this module never needs generation. It is a coroutine function
    (not a plain lambda) so it stays safe if minirag awaits ``llm_model_func``,
    which it presumably does through its async rate-limit wrapper — a sync
    lambda's ``""`` result cannot be awaited.
    """
    return ""


@lru_cache(maxsize=1)
def _rag_bundle():
    """Build, once per process, a MiniRAG index over all QA pairs.

    Returns:
        ``(rag, id_to_qa)``: the populated MiniRAG instance and a dict mapping
        each chunk id to the original QA record it was built from.
    """
    qas = _load_qas()
    embed = _build_embedder()
    rag = MiniRAG(
        working_dir=str(WORKDIR),
        embedding_func=embed,
        chunk_token_size=1200,  # large enough that one question+answer is not re-split
        chunk_overlap_token_size=0,
        llm_model_func=_no_llm,  # the LLM is not called during retrieval
        log_level="WARNING",
    )

    # Map each chunk id back to its original QA record.
    chunks = []
    id_to_qa = {}
    for qa in qas:
        # Stripped so the hashed text matches what minirag stores; minirag
        # presumably strips documents/chunks before computing chunk ids —
        # confirm for the installed version. A QA longer than chunk_token_size
        # would still be split and its id would no longer match.
        chunk_text = f"Q{qa.get('id')}:{qa.get('question','')}\nA:{qa.get('answer','')}".strip()
        cid = compute_mdhash_id(chunk_text, prefix="chunk-")
        chunks.append(chunk_text)
        id_to_qa[cid] = qa

    # asyncio.run closes its loop even if insertion fails and leaves no closed
    # loop installed as the thread's current event loop afterwards.
    asyncio.run(rag.ainsert(chunks))
    return rag, id_to_qa


def _search_qas(query: str, limit: int) -> List[Dict[str, Any]]:
    """Return the original QA records whose chunks best match *query*, best first.

    Hits whose chunk id does not map to a known QA record are skipped, so fewer
    than *limit* records may be returned. ``limit <= 0`` returns ``[]``.
    """
    if limit <= 0:
        return []
    rag, id_to_qa = _rag_bundle()
    results = asyncio.run(rag.chunks_vdb.query(query, top_k=limit))
    hits = (id_to_qa.get(r.get("id")) for r in results)
    return [qa for qa in hits if qa]


def search_rag(query: str, limit: int = 5) -> List[Dict[str, str]]:
    """
    Retrieve with minirag and return a list of question/answer dicts.
    """
    return [
        {"question": qa.get("question", ""), "answer": qa.get("answer", "")}
        for qa in _search_qas(query, limit)
    ]


def search_rag_full(query: str, limit: int = 10) -> List[Dict[str, Any]]:
    """
    Return a list of id / question / answer dicts, for display in the FAQ stage.
    """
    return [
        {"id": qa.get("id"), "question": qa.get("question", ""), "answer": qa.get("answer", "")}
        for qa in _search_qas(query, limit)
    ]