import asyncio
import os
import sys
import types
from functools import lru_cache
from pathlib import Path
from typing import List, Dict, Any

import numpy as np
from minirag import MiniRAG, QueryParam, minirag as minirag_mod
from minirag.base import BaseKVStorage, BaseVectorStorage, BaseGraphStorage
from minirag.utils import wrap_embedding_func_with_attrs, compute_mdhash_id
from sentence_transformers import SentenceTransformer

from .config import PROJECT_ROOT, QA_PATH, QA_REPORT_PATH

# Environment: disable entity extraction to avoid pulling in extra dependencies.
os.environ.setdefault("MINIRAG_DISABLE_ENTITY_EXTRACT", "1")

# --------- Inject lightweight storage implementations for the pip-installed
# --------- minirag build, so its missing `kg` module is never imported. ------


class _JsonKVStorage(BaseKVStorage):
    """Plain in-memory key/value store implementing the BaseKVStorage API."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Backing dict: id -> record.
        self.storage: Dict[str, dict] = {}

    async def all_keys(self):
        return list(self.storage.keys())

    async def get_by_id(self, id: str):
        return self.storage.get(id)

    async def get_by_ids(self, ids, fields=None):
        # Missing ids yield None; `fields` (when given) projects each hit
        # down to the requested keys.
        def _project(record):
            if record and fields:
                return {f: record[f] for f in fields if f in record}
            return record

        return [_project(self.storage.get(key)) for key in ids]

    async def filter_keys(self, data):
        # Keys from `data` that are NOT yet stored.
        return set(data) - self.storage.keys()

    async def upsert(self, data):
        self.storage.update(data)

    async def drop(self):
        self.storage.clear()

    async def index_done_callback(self):
        return

    async def query_done_callback(self):
        return


class _NanoVectorDBStorage(BaseVectorStorage):
    """Brute-force cosine-similarity vector store kept entirely in memory."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Parallel maps: id -> original record, id -> embedding vector.
        self.storage: Dict[str, Dict[str, Any]] = {}
        self.embeddings: Dict[str, np.ndarray] = {}

    async def query(self, query: str, top_k: int):
        if not self.embeddings:
            return []
        query_vecs = await self.embedding_func([query])
        query_vec = np.array(query_vecs[0], dtype=np.float32)

        def _cosine(vec):
            # `or 1e-6` guards against a zero-norm vector (division by zero).
            denom = (np.linalg.norm(vec) * np.linalg.norm(query_vec)) or 1e-6
            return float(np.dot(vec, query_vec) / denom)

        ranked = sorted(
            ((key, _cosine(vec)) for key, vec in self.embeddings.items()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        results = []
        for key, score in ranked[:top_k]:
            record = dict(self.storage.get(key, {}))
            record.update({"id": key, "score": score})
            results.append(record)
        return results

    async def upsert(self, data):
        contents = [record.get("content", "") for record in data.values()]
        vectors = await self.embedding_func(contents)
        if isinstance(vectors, np.ndarray):
            vectors = list(vectors)
        for (key, record), vec in zip(data.items(), vectors):
            self.storage[key] = record
            self.embeddings[key] = np.array(vec, dtype=np.float32)

    async def index_done_callback(self):
        return

    async def query_done_callback(self):
        return


class _NetworkXStorage(BaseGraphStorage):
    """Minimal graph store; most query methods are deliberate no-op stubs."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.nodes: Dict[str, dict] = {}
        self.edges: Dict[tuple, dict] = {}

    async def get_types(self):
        return [], []

    async def has_node(self, node_id: str):
        return node_id in self.nodes

    async def has_edge(self, source_node_id: str, target_node_id: str):
        return (source_node_id, target_node_id) in self.edges

    async def node_degree(self, node_id: str):
        # Stub: degree information is never needed with entity extraction off.
        return 0

    async def edge_degree(self, src_id: str, tgt_id: str):
        return 0

    async def get_node(self, node_id: str):
        return self.nodes.get(node_id)

    async def get_edge(self, source_node_id: str, target_node_id: str):
        return self.edges.get((source_node_id, target_node_id))

    async def get_node_edges(self, source_node_id: str):
        return []

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        self.nodes[node_id] = node_data

    async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
        self.edges[(source_node_id, target_node_id)] = edge_data

    async def delete_node(self, node_id: str):
        self.nodes.pop(node_id, None)
        # Drop every edge touching the node (either endpoint).
        for pair in [p for p in self.edges if node_id in p]:
            self.edges.pop(pair, None)

    async def embed_nodes(self, algorithm: str):
        return np.zeros((0,)), []

    async def index_done_callback(self):
        return

    async def query_done_callback(self):
        return


class _JsonDocStatusStorage(_JsonKVStorage):
    """Doc-status store; inherits the JSON KV behaviour unchanged."""
    pass


# Fake module registered under the name minirag expects to import from.
_simple_module = types.ModuleType("minirag.simple_storage")
_simple_module.JsonKVStorage = _JsonKVStorage
_simple_module.NanoVectorDBStorage = _NanoVectorDBStorage
_simple_module.NetworkXStorage = _NetworkXStorage
_simple_module.JsonDocStatusStorage = _JsonDocStatusStorage
sys.modules["minirag.simple_storage"] = _simple_module
# Point every storage backend MiniRAG resolves by name at the injected module.
minirag_mod.STORAGES.update({
    "NetworkXStorage": "minirag.simple_storage",
    "JsonKVStorage": "minirag.simple_storage",
    "NanoVectorDBStorage": "minirag.simple_storage",
    "JsonDocStatusStorage": "minirag.simple_storage",
})
# -------------------------------------------------------------------------

# Model and working-directory paths.
MODEL_DIR = (PROJECT_ROOT.parent / "minirag" / "minirag" / "models" / "bge-small-zh-v1.5").resolve()
WORKDIR = (PROJECT_ROOT / "minirag_cache").resolve()
WORKDIR.mkdir(parents=True, exist_ok=True)


def _load_qas(path: Path) -> List[Dict[str, Any]]:
    """Load and return the QA list from a JSON file.

    FIX: replaced the ``__import__("json")`` hack with a regular import.
    The import is function-local because the top-of-file import block is
    outside this edit.
    """
    import json

    return json.loads(path.read_text(encoding="utf-8"))


def _build_embedder():
    """Build a CPU SentenceTransformer embedding function wrapped for MiniRAG.

    Returns an async callable annotated (via the wrapper) with the model's
    embedding dimension and a 512-token budget.
    """
    model = SentenceTransformer(str(MODEL_DIR), device="cpu")
    emb_dim = model.get_sentence_embedding_dimension()

    @wrap_embedding_func_with_attrs(embedding_dim=emb_dim, max_token_size=512)
    async def embed(texts):
        if isinstance(texts, str):
            texts = [texts]
        embs = model.encode(texts, normalize_embeddings=True, convert_to_numpy=True)
        return embs.astype(np.float32)

    return embed


# FIX: was decorated with @lru_cache(maxsize=1) TWICE — one redundant wrapper
# removed; a single cache already guarantees one model load per process.
@lru_cache(maxsize=1)
def _embedder_cached():
    """Return the process-wide shared embedding function (built once)."""
    return _build_embedder()


def _build_rag_for(path: Path, workdir: Path):
    """Build a MiniRAG index over the QA file at *path*.

    Each QA entry becomes one chunk ("Qid:question\\nA:answer"); the chunk
    token size is large enough that MiniRAG should keep each entry whole
    (assumes the chunk id matches compute_mdhash_id of the full text —
    TODO confirm against the installed minirag's chunking).

    Returns (rag, id_to_qa) where id_to_qa maps chunk id -> original QA dict.
    """
    qas = _load_qas(path)
    embed = _embedder_cached()
    rag = MiniRAG(
        working_dir=str(workdir),
        embedding_func=embed,
        chunk_token_size=1200,  # one Q+A per chunk; no further splitting
        chunk_overlap_token_size=0,
        llm_model_func=lambda *a, **k: "",
        log_level="WARNING",
    )
    chunks = []
    id_to_qa = {}
    for qa in qas:
        chunk_text = f"Q{qa.get('id')}:{qa.get('question','')}\nA:{qa.get('answer','')}"
        cid = compute_mdhash_id(chunk_text, prefix="chunk-")
        chunks.append(chunk_text)
        id_to_qa[cid] = qa
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # FIX: close the loop even when ainsert raises (was leaked on error).
        loop.run_until_complete(rag.ainsert(chunks))
    finally:
        loop.close()
    return rag, id_to_qa
@lru_cache(maxsize=1) def _rag_bundle_common(): workdir = WORKDIR / "common" workdir.mkdir(parents=True, exist_ok=True) return _build_rag_for(QA_PATH, workdir) @lru_cache(maxsize=1) def _rag_bundle_report(): workdir = WORKDIR / "report" workdir.mkdir(parents=True, exist_ok=True) return _build_rag_for(QA_REPORT_PATH, workdir) REPORT_KEYWORDS = [ # 报告/证书类 "检测", "检测报告", "检验", "检验报告", "质检", "质检报告", "第三方", "报告", "报告编号", "编号", "证书", "证明", "盖章", "章", "资质", "认证", # SDS / MSDS / COA "MSDS", "SDS", "安全数据单", "安全技术说明书", "COA", "COC", "CMA", "CNAS", # 指标/结果表达 "成分", "成分表", "配方", "含量", "浓度", "数值", "指标", "限值", "限量", "合格", "达标", "超标", "未检出", "检出限", "ppm", "mg/kg", # 重点关注物质/风险词 "化学", "毒", "有毒", "安全性", "刺激", "过敏", "致敏", "汞", "砷", "铅", "镉", "甲醇", "甲醛", "重金属", ] def _hit_report(query: str): q = query.lower() return [kw for kw in REPORT_KEYWORDS if kw.lower() in q] def _query_rag(rag, id_to_qa, query: str, limit: int, source: str): async def _search(): results = await rag.chunks_vdb.query(query, top_k=limit) out = [] for r in results: qa = id_to_qa.get(r.get("id")) if not qa: continue out.append({ "id": qa.get("id"), "question": qa.get("question", ""), "answer": qa.get("answer", ""), "score": r.get("score", 0.0), "source": source, }) return out loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: return loop.run_until_complete(_search()) finally: loop.close() def search_rag(query: str, limit: int = 5) -> List[Dict[str, str]]: """ 使用 minirag 检索,返回 question/answer 列表(无 id,供简单问答)。 命中检测关键词时,同时查检测库。 """ hits = _hit_report(query) rag_c, map_c = _rag_bundle_common() common = _query_rag(rag_c, map_c, query, limit, "common") report = [] if hits: rag_r, map_r = _rag_bundle_report() report = _query_rag(rag_r, map_r, query, limit, "report") merged = [] seen = set() for item in report + common: if item["id"] in seen: continue seen.add(item["id"]) merged.append({"question": item["question"], "answer": item["answer"], "source": item["source"]}) if len(merged) >= limit: break 
print(f"[RAG][faq_chat] used={'both' if hits else 'common'} kw={hits} q='{query}'") return merged def search_rag_full(query: str, limit: int = 10) -> List[Dict[str, Any]]: """ 返回带 id / question / answer 的列表,供 FAQ 阶段展示。 命中检测关键词时:检测库+常用库合并,检测结果优先。 """ hits = _hit_report(query) rag_c, map_c = _rag_bundle_common() common = _query_rag(rag_c, map_c, query, limit, "common") report = [] if hits: rag_r, map_r = _rag_bundle_report() report = _query_rag(rag_r, map_r, query, limit, "report") merged = [] seen = set() for item in report + common: if item["id"] in seen: continue seen.add(item["id"]) merged.append({ "id": item["id"], "question": item["question"], "answer": item["answer"], "source": item["source"], "score": item.get("score", 0.0), }) if len(merged) >= limit: break debug = {"used_index": "both" if hits else "common", "hit_keywords": hits} print(f"[RAG][faq_search] used={debug['used_index']} kw={hits} q='{query}'") return merged, debug