"""MiniRAG-backed FAQ retrieval with lightweight in-memory storage backends."""
import asyncio
import json
import os
import sys
import types
from functools import lru_cache
from pathlib import Path
from typing import List, Dict, Any

import numpy as np
from minirag import MiniRAG, QueryParam, minirag as minirag_mod
from minirag.base import BaseKVStorage, BaseVectorStorage, BaseGraphStorage
from minirag.utils import wrap_embedding_func_with_attrs, compute_mdhash_id
from sentence_transformers import SentenceTransformer

from .config import PROJECT_ROOT, QA_PATH, QA_REPORT_PATH
# Environment setup: disable entity extraction so no extra dependencies are needed.
if "MINIRAG_DISABLE_ENTITY_EXTRACT" not in os.environ:
    os.environ["MINIRAG_DISABLE_ENTITY_EXTRACT"] = "1"
# --------- Inject lightweight storage implementations into the pip-installed
# minirag so the missing `kg` module is never needed ---------
class _JsonKVStorage(BaseKVStorage):
    """Dict-backed, in-memory key/value storage backend for MiniRAG."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Backing store: id -> record dict.
        self.storage: Dict[str, dict] = {}

    async def all_keys(self):
        """Return every stored key."""
        return [*self.storage]

    async def get_by_id(self, id: str):
        """Return the record for *id*, or None when absent."""
        return self.storage.get(id)

    async def get_by_ids(self, ids, fields=None):
        """Return records for *ids* (None for misses), optionally projected to *fields*."""

        def _project(record):
            # Only project truthy records; misses pass through unchanged.
            if record and fields:
                return {name: record[name] for name in fields if name in record}
            return record

        return [_project(self.storage.get(key)) for key in ids]

    async def filter_keys(self, data):
        """Return the subset of keys in *data* that are not stored yet."""
        return set(data) - self.storage.keys()

    async def upsert(self, data):
        """Insert or overwrite the given id -> record mapping."""
        self.storage.update(data)

    async def drop(self):
        """Discard every stored record."""
        self.storage.clear()

    async def index_done_callback(self):
        """No-op: nothing to persist for an in-memory store."""

    async def query_done_callback(self):
        """No-op: nothing to persist for an in-memory store."""
class _NanoVectorDBStorage(BaseVectorStorage):
    """In-memory vector store using brute-force cosine similarity."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Payloads and their embedding vectors, both keyed by chunk id.
        self.storage: Dict[str, Dict[str, Any]] = {}
        self.embeddings: Dict[str, np.ndarray] = {}

    async def query(self, query: str, top_k: int):
        """Return the *top_k* stored items most similar to *query* by cosine score."""
        if not self.embeddings:
            return []
        query_vecs = await self.embedding_func([query])
        query_vec = np.array(query_vecs[0], dtype=np.float32)

        def _cosine(vec):
            # Guard against zero-norm vectors with a tiny denominator.
            denom = (np.linalg.norm(vec) * np.linalg.norm(query_vec)) or 1e-6
            return float(np.dot(vec, query_vec) / denom)

        ranked = sorted(
            ((key, _cosine(vec)) for key, vec in self.embeddings.items()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        results = []
        for key, score in ranked[:top_k]:
            payload = dict(self.storage.get(key, {}))
            payload["id"] = key
            payload["score"] = score
            results.append(payload)
        return results

    async def upsert(self, data):
        """Embed each record's "content" field and store both payload and vector."""
        contents = [record.get("content", "") for record in data.values()]
        vectors = await self.embedding_func(contents)
        if isinstance(vectors, np.ndarray):
            vectors = list(vectors)
        for (key, record), vec in zip(data.items(), vectors):
            self.storage[key] = record
            self.embeddings[key] = np.array(vec, dtype=np.float32)

    async def index_done_callback(self):
        """No-op: nothing to persist for an in-memory store."""

    async def query_done_callback(self):
        """No-op: nothing to persist for an in-memory store."""
class _NetworkXStorage(BaseGraphStorage):
    """Minimal in-memory graph backend (degrees/edges stubbed out)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # node_id -> attributes; (src, tgt) -> edge attributes.
        self.nodes: Dict[str, dict] = {}
        self.edges: Dict[tuple, dict] = {}

    async def get_types(self):
        """Return empty node/edge type lists (graph features are unused here)."""
        return [], []

    async def has_node(self, node_id: str):
        """True when *node_id* is stored."""
        return node_id in self.nodes

    async def has_edge(self, source_node_id: str, target_node_id: str):
        """True when the directed edge (source, target) is stored."""
        return (source_node_id, target_node_id) in self.edges

    async def node_degree(self, node_id: str):
        """Stub: degrees are not tracked."""
        return 0

    async def edge_degree(self, src_id: str, tgt_id: str):
        """Stub: degrees are not tracked."""
        return 0

    async def get_node(self, node_id: str):
        """Return node attributes, or None when absent."""
        return self.nodes.get(node_id)

    async def get_edge(self, source_node_id: str, target_node_id: str):
        """Return edge attributes, or None when absent."""
        return self.edges.get((source_node_id, target_node_id))

    async def get_node_edges(self, source_node_id: str):
        """Stub: adjacency is not tracked."""
        return []

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Insert or overwrite a node's attributes."""
        self.nodes[node_id] = node_data

    async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
        """Insert or overwrite an edge's attributes."""
        self.edges[(source_node_id, target_node_id)] = edge_data

    async def delete_node(self, node_id: str):
        """Remove *node_id* and every edge touching it."""
        self.nodes.pop(node_id, None)
        stale = [pair for pair in self.edges if node_id in pair]
        for pair in stale:
            del self.edges[pair]

    async def embed_nodes(self, algorithm: str):
        """Stub: return an empty embedding matrix and id list."""
        return np.zeros((0,)), []

    async def index_done_callback(self):
        """No-op: nothing to persist for an in-memory store."""

    async def query_done_callback(self):
        """No-op: nothing to persist for an in-memory store."""
class _JsonDocStatusStorage(_JsonKVStorage):
    """Document-status store; identical behavior to the plain KV store."""
# Register the lightweight backends as a synthetic module so MiniRAG's storage
# loader can import them by dotted path.
_simple_module = types.ModuleType("minirag.simple_storage")
_simple_module.JsonKVStorage = _JsonKVStorage
_simple_module.NanoVectorDBStorage = _NanoVectorDBStorage
_simple_module.NetworkXStorage = _NetworkXStorage
_simple_module.JsonDocStatusStorage = _JsonDocStatusStorage
sys.modules["minirag.simple_storage"] = _simple_module

# Point every storage name MiniRAG may request at the synthetic module.
minirag_mod.STORAGES.update(
    dict.fromkeys(
        (
            "NetworkXStorage",
            "JsonKVStorage",
            "NanoVectorDBStorage",
            "JsonDocStatusStorage",
        ),
        "minirag.simple_storage",
    )
)
# -------------------------------------------------------------------------
# Model and working-directory paths.
MODEL_DIR = PROJECT_ROOT.parent.joinpath("minirag", "minirag", "models", "bge-small-zh-v1.5").resolve()
WORKDIR = PROJECT_ROOT.joinpath("minirag_cache").resolve()
WORKDIR.mkdir(parents=True, exist_ok=True)
# 预加载 QA
|
||
def _load_qas(path: Path) -> List[Dict[str, Any]]:
|
||
return __import__("json").loads(path.read_text(encoding="utf-8"))
|
||
|
||
|
||
def _build_embedder():
    """Load the local bge-small-zh model and wrap it as an async embedding function."""
    encoder = SentenceTransformer(str(MODEL_DIR), device="cpu")
    dim = encoder.get_sentence_embedding_dimension()

    @wrap_embedding_func_with_attrs(embedding_dim=dim, max_token_size=512)
    async def embed(texts):
        # Accept a bare string as a single-item batch.
        batch = [texts] if isinstance(texts, str) else texts
        vectors = encoder.encode(batch, normalize_embeddings=True, convert_to_numpy=True)
        return vectors.astype(np.float32)

    return embed
@lru_cache(maxsize=1)
def _embedder_cached():
    """Build the embedder once and return the cached instance on later calls.

    Fix: the original stacked `@lru_cache(maxsize=1)` twice; the second
    decorator was redundant (caching an already-cached zero-arg function).
    """
    return _build_embedder()
def _build_rag_for(path: Path, workdir: Path):
    """Build a MiniRAG index over the QA file at *path*.

    Each QA pair becomes exactly one chunk ("Qid:question\\nA:answer").

    Returns:
        (rag, id_to_qa): the MiniRAG instance and a map from chunk id back
        to the original QA dict.
    """
    qas = _load_qas(path)
    embed = _embedder_cached()
    rag = MiniRAG(
        working_dir=str(workdir),
        embedding_func=embed,
        chunk_token_size=1200,  # one chunk per Q+A pair; no further splitting
        chunk_overlap_token_size=0,
        llm_model_func=lambda *a, **k: "",  # no LLM needed for pure retrieval
        log_level="WARNING",
    )
    chunks = []
    id_to_qa = {}
    for qa in qas:
        chunk_text = f"Q{qa.get('id')}:{qa.get('question','')}\nA:{qa.get('answer','')}"
        cid = compute_mdhash_id(chunk_text, prefix="chunk-")
        chunks.append(chunk_text)
        id_to_qa[cid] = qa

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Fix: close the loop even when ainsert raises (previously leaked on error;
    # now consistent with the try/finally pattern used in _query_rag).
    try:
        loop.run_until_complete(rag.ainsert(chunks))
    finally:
        loop.close()
    return rag, id_to_qa
@lru_cache(maxsize=1)
def _rag_bundle_common():
    """Lazily build and cache the index over the common QA file."""
    subdir = WORKDIR / "common"
    subdir.mkdir(parents=True, exist_ok=True)
    return _build_rag_for(QA_PATH, subdir)
@lru_cache(maxsize=1)
def _rag_bundle_report():
    """Lazily build and cache the index over the inspection-report QA file."""
    subdir = WORKDIR / "report"
    subdir.mkdir(parents=True, exist_ok=True)
    return _build_rag_for(QA_REPORT_PATH, subdir)
# Keywords that route a query to the inspection-report index as well.
# Order matters: _hit_report reports hits in this order.
REPORT_KEYWORDS = [
    # Reports / certificates
    "检测", "检测报告", "检验", "检验报告", "质检", "质检报告", "第三方",
    "报告", "报告编号", "编号", "证书", "证明", "盖章", "章", "资质", "认证",
    # SDS / MSDS / COA
    "MSDS", "SDS", "安全数据单", "安全技术说明书", "COA", "COC", "CMA", "CNAS",
    # Metrics / result phrasing
    "成分", "成分表", "配方", "含量", "浓度", "数值", "指标", "限值", "限量",
    "合格", "达标", "超标", "未检出", "检出限", "ppm", "mg/kg",
    # Substances of concern / risk terms
    "化学", "毒", "有毒", "安全性", "刺激", "过敏", "致敏",
    "汞", "砷", "铅", "镉", "甲醇", "甲醛", "重金属",
]
def _hit_report(query: str):
    """Return the report keywords contained (case-insensitively) in *query*."""
    lowered = query.lower()
    hits = []
    for keyword in REPORT_KEYWORDS:
        if keyword.lower() in lowered:
            hits.append(keyword)
    return hits
def _query_rag(rag, id_to_qa, query: str, limit: int, source: str):
|
||
async def _search():
|
||
results = await rag.chunks_vdb.query(query, top_k=limit)
|
||
out = []
|
||
for r in results:
|
||
qa = id_to_qa.get(r.get("id"))
|
||
if not qa:
|
||
continue
|
||
out.append({
|
||
"id": qa.get("id"),
|
||
"question": qa.get("question", ""),
|
||
"answer": qa.get("answer", ""),
|
||
"score": r.get("score", 0.0),
|
||
"source": source,
|
||
})
|
||
return out
|
||
|
||
loop = asyncio.new_event_loop()
|
||
asyncio.set_event_loop(loop)
|
||
try:
|
||
return loop.run_until_complete(_search())
|
||
finally:
|
||
loop.close()
|
||
|
||
|
||
def search_rag(query: str, limit: int = 5) -> List[Dict[str, str]]:
    """
    Retrieve via minirag and return question/answer dicts (no id) for simple chat.

    When *query* hits report keywords, the report index is queried too and its
    results take priority in the merge; duplicates (by QA id) are dropped.
    """
    hits = _hit_report(query)
    rag_c, map_c = _rag_bundle_common()
    candidates = _query_rag(rag_c, map_c, query, limit, "common")
    if hits:
        rag_r, map_r = _rag_bundle_report()
        # Report results first so they win de-duplication.
        candidates = _query_rag(rag_r, map_r, query, limit, "report") + candidates

    merged: List[Dict[str, str]] = []
    seen = set()
    for item in candidates:
        key = item["id"]
        if key in seen:
            continue
        seen.add(key)
        merged.append(
            {"question": item["question"], "answer": item["answer"], "source": item["source"]}
        )
        if len(merged) >= limit:
            break
    print(f"[RAG][faq_chat] used={'both' if hits else 'common'} kw={hits} q='{query}'")
    return merged
def search_rag_full(query: str, limit: int = 10) -> tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """
    Return (results, debug) for the FAQ display stage.

    results: dicts with id / question / answer / source / score.
    debug:   which index was used and which report keywords were hit.

    When report keywords are hit, the report and common indexes are both
    queried and merged with report results taking priority; duplicates
    (by QA id) are dropped.

    Fix: the return annotation previously claimed ``List[Dict[str, Any]]``
    although the function has always returned a (results, debug) tuple;
    the annotation is corrected — runtime behavior is unchanged.
    """
    hits = _hit_report(query)
    rag_c, map_c = _rag_bundle_common()
    common = _query_rag(rag_c, map_c, query, limit, "common")
    report = []
    if hits:
        rag_r, map_r = _rag_bundle_report()
        report = _query_rag(rag_r, map_r, query, limit, "report")

    merged = []
    seen = set()
    # Report hits come first so they win the de-duplication below.
    for item in report + common:
        if item["id"] in seen:
            continue
        seen.add(item["id"])
        merged.append({
            "id": item["id"],
            "question": item["question"],
            "answer": item["answer"],
            "source": item["source"],
            "score": item.get("score", 0.0),
        })
        if len(merged) >= limit:
            break

    debug = {"used_index": "both" if hits else "common", "hit_keywords": hits}
    print(f"[RAG][faq_search] used={debug['used_index']} kw={hits} q='{query}'")
    return merged, debug