# nianjie/dialog/backend/rag.py

import asyncio
import json
import os
import sys
import types
from functools import lru_cache
from pathlib import Path
from typing import List, Dict, Any, Tuple
import numpy as np
from minirag import MiniRAG, QueryParam, minirag as minirag_mod
from minirag.base import BaseKVStorage, BaseVectorStorage, BaseGraphStorage
from minirag.utils import wrap_embedding_func_with_attrs, compute_mdhash_id
from sentence_transformers import SentenceTransformer
from .config import PROJECT_ROOT, QA_PATH, QA_REPORT_PATH
# Environment setting: disable entity extraction to avoid pulling in extra dependencies
os.environ.setdefault("MINIRAG_DISABLE_ENTITY_EXTRACT", "1")
# --------- Inject lightweight storage backends for the pip-installed minirag, so the missing kg module is not required ---------
class _JsonKVStorage(BaseKVStorage):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.storage: Dict[str, dict] = {}

    async def all_keys(self):
        return list(self.storage.keys())

    async def get_by_id(self, id: str):
        return self.storage.get(id)

    async def get_by_ids(self, ids, fields=None):
        out = []
        for i in ids:
            v = self.storage.get(i)
            if v and fields:
                v = {k: v[k] for k in fields if k in v}
            out.append(v)
        return out

    async def filter_keys(self, data):
        return {k for k in data if k not in self.storage}

    async def upsert(self, data):
        self.storage.update(data)

    async def drop(self):
        self.storage.clear()

    async def index_done_callback(self):
        return

    async def query_done_callback(self):
        return
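

# Note on _JsonKVStorage above: it is purely in-memory and only needs to satisfy
# minirag's KV interface for this QA use case. filter_keys() returns the subset of
# incoming keys that are NOT yet stored, which is how minirag decides what still has
# to be indexed; get_by_ids() preserves input order and yields None for misses.
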
class _NanoVectorDBStorage(BaseVectorStorage):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.storage: Dict[str, Dict[str, Any]] = {}
        self.embeddings: Dict[str, np.ndarray] = {}

    async def query(self, query: str, top_k: int):
        if not self.embeddings:
            return []
        q_embs = await self.embedding_func([query])
        q_emb = np.array(q_embs[0], dtype=np.float32)
        sims = []
        for k, emb in self.embeddings.items():
            denom = (np.linalg.norm(emb) * np.linalg.norm(q_emb)) or 1e-6
            sims.append((k, float(np.dot(emb, q_emb) / denom)))
        sims.sort(key=lambda x: x[1], reverse=True)
        res = []
        for k, score in sims[:top_k]:
            item = dict(self.storage.get(k, {}))
            item.update({"id": k, "score": score})
            res.append(item)
        return res

    async def upsert(self, data):
        texts = [v.get("content", "") for v in data.values()]
        embs = await self.embedding_func(texts)
        if isinstance(embs, np.ndarray):
            embs = list(embs)
        for (k, v), emb in zip(data.items(), embs):
            self.storage[k] = v
            self.embeddings[k] = np.array(emb, dtype=np.float32)

    async def index_done_callback(self):
        return

    async def query_done_callback(self):
        return
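

# Note on the scorer above: _build_embedder() encodes with normalize_embeddings=True,
# so stored vectors are already unit length and the dot product alone would equal the
# cosine similarity; the explicit division by the norms simply keeps the ranking
# correct if an unnormalized embedding function is ever plugged in.
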
class _NetworkXStorage(BaseGraphStorage):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.nodes: Dict[str, dict] = {}
        self.edges: Dict[tuple, dict] = {}

    async def get_types(self):
        return [], []

    async def has_node(self, node_id: str):
        return node_id in self.nodes

    async def has_edge(self, source_node_id: str, target_node_id: str):
        return (source_node_id, target_node_id) in self.edges

    async def node_degree(self, node_id: str):
        return 0

    async def edge_degree(self, src_id: str, tgt_id: str):
        return 0

    async def get_node(self, node_id: str):
        return self.nodes.get(node_id)

    async def get_edge(self, source_node_id: str, target_node_id: str):
        return self.edges.get((source_node_id, target_node_id))

    async def get_node_edges(self, source_node_id: str):
        return []

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        self.nodes[node_id] = node_data

    async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
        self.edges[(source_node_id, target_node_id)] = edge_data

    async def delete_node(self, node_id: str):
        self.nodes.pop(node_id, None)
        for k in list(self.edges.keys()):
            if k[0] == node_id or k[1] == node_id:
                self.edges.pop(k, None)

    async def embed_nodes(self, algorithm: str):
        return np.zeros((0,)), []

    async def index_done_callback(self):
        return

    async def query_done_callback(self):
        return


class _JsonDocStatusStorage(_JsonKVStorage):
    pass
_simple_module = types.ModuleType("minirag.simple_storage")
_simple_module.JsonKVStorage = _JsonKVStorage
_simple_module.NanoVectorDBStorage = _NanoVectorDBStorage
_simple_module.NetworkXStorage = _NetworkXStorage
_simple_module.JsonDocStatusStorage = _JsonDocStatusStorage
sys.modules["minirag.simple_storage"] = _simple_module
minirag_mod.STORAGES.update({
    "NetworkXStorage": "minirag.simple_storage",
    "JsonKVStorage": "minirag.simple_storage",
    "NanoVectorDBStorage": "minirag.simple_storage",
    "JsonDocStatusStorage": "minirag.simple_storage",
})
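
# Sanity-check sketch (assumption: minirag lazily imports the module paths registered
# in STORAGES). After the injection above, the synthetic module is importable by name
# and hands back the lightweight classes defined in this file:
#
#   import importlib
#   _mod = importlib.import_module("minirag.simple_storage")
#   assert _mod.NanoVectorDBStorage is _NanoVectorDBStorage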
# -------------------------------------------------------------------------
# Model and working-directory paths
MODEL_DIR = (PROJECT_ROOT.parent / "minirag" / "minirag" / "models" / "bge-small-zh-v1.5").resolve()
WORKDIR = (PROJECT_ROOT / "minirag_cache").resolve()
WORKDIR.mkdir(parents=True, exist_ok=True)
# Preload the QA pairs
def _load_qas(path: Path) -> List[Dict[str, Any]]:
    return json.loads(path.read_text(encoding="utf-8"))
def _build_embedder():
    model = SentenceTransformer(str(MODEL_DIR), device="cpu")
    emb_dim = model.get_sentence_embedding_dimension()

    @wrap_embedding_func_with_attrs(embedding_dim=emb_dim, max_token_size=512)
    async def embed(texts):
        if isinstance(texts, str):
            texts = [texts]
        embs = model.encode(texts, normalize_embeddings=True, convert_to_numpy=True)
        return embs.astype(np.float32)

    return embed
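

# Usage sketch (illustrative, not executed at import time): the wrapped embedder is
# async and returns a float32 matrix of shape (len(texts), embedding_dim), with rows
# L2-normalized by SentenceTransformer.
#
#   embed = _embedder_cached()
#   vecs = asyncio.run(embed(["如何申请检测报告"]))
#   vecs.shape  # -> (1, model dimension)
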
@lru_cache(maxsize=1)
def _embedder_cached():
    return _build_embedder()
def _build_rag_for(path: Path, workdir: Path):
    qas = _load_qas(path)
    embed = _embedder_cached()
    rag = MiniRAG(
        working_dir=str(workdir),
        embedding_func=embed,
        chunk_token_size=1200,  # one chunk per question+answer pair, no further splitting
        chunk_overlap_token_size=0,
        llm_model_func=lambda *a, **k: "",
        log_level="WARNING",
    )
    chunks = []
    id_to_qa = {}
    for qa in qas:
        chunk_text = f"Q{qa.get('id')}{qa.get('question','')}\nA{qa.get('answer','')}"
        cid = compute_mdhash_id(chunk_text, prefix="chunk-")
        chunks.append(chunk_text)
        id_to_qa[cid] = qa
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(rag.ainsert(chunks))
    loop.close()
    return rag, id_to_qa
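

# Note on the id mapping above (assumption about minirag internals, consistent with
# how chunks_vdb is queried below): compute_mdhash_id(chunk_text, prefix="chunk-") is
# expected to reproduce the chunk ids MiniRAG assigns during ainsert(), so id_to_qa
# lets a vector hit be traced back to its original QA record.
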
@lru_cache(maxsize=1)
def _rag_bundle_common():
    workdir = WORKDIR / "common"
    workdir.mkdir(parents=True, exist_ok=True)
    return _build_rag_for(QA_PATH, workdir)


@lru_cache(maxsize=1)
def _rag_bundle_report():
    workdir = WORKDIR / "report"
    workdir.mkdir(parents=True, exist_ok=True)
    return _build_rag_for(QA_REPORT_PATH, workdir)
REPORT_KEYWORDS = [
    # Reports / certificates
    "检测", "检测报告", "检验", "检验报告", "质检", "质检报告", "第三方",
    "报告", "报告编号", "编号", "证书", "证明", "盖章", "资质", "认证",
    # SDS / MSDS / COA
    "MSDS", "SDS", "安全数据单", "安全技术说明书", "COA", "COC", "CMA", "CNAS",
    # Metrics / result wording
    "成分", "成分表", "配方", "含量", "浓度", "数值", "指标", "限值", "限量",
    "合格", "达标", "超标", "未检出", "检出限", "ppm", "mg/kg",
    # Substances of concern / risk terms
    "化学", "有毒", "安全性", "刺激", "过敏", "致敏",
    "甲醇", "甲醛", "重金属",
]
def _hit_report(query: str):
    q = query.lower()
    return [kw for kw in REPORT_KEYWORDS if kw.lower() in q]
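

# Illustrative example: for the query "有第三方检测报告吗" the comprehension above returns
# ["检测", "检测报告", "第三方", "报告"], i.e. every keyword that appears as a substring
# of the lower-cased query, in REPORT_KEYWORDS order.
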
def _query_rag(rag, id_to_qa, query: str, limit: int, source: str):
    async def _search():
        results = await rag.chunks_vdb.query(query, top_k=limit)
        out = []
        for r in results:
            qa = id_to_qa.get(r.get("id"))
            if not qa:
                continue
            out.append({
                "id": qa.get("id"),
                "question": qa.get("question", ""),
                "answer": qa.get("answer", ""),
                "score": r.get("score", 0.0),
                "source": source,
            })
        return out

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(_search())
    finally:
        loop.close()
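

# Note on _query_rag above: every call creates and closes its own event loop because
# the public search helpers below are synchronous; calling them from inside an already
# running asyncio loop (e.g. an async web handler) would raise and would need a
# different execution strategy.
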
def search_rag(query: str, limit: int = 5) -> List[Dict[str, str]]:
    """
    Retrieve via minirag and return a list of question/answer entries (without id)
    for simple Q&A. When report keywords are hit, the report index is queried as well.
    """
    hits = _hit_report(query)
    rag_c, map_c = _rag_bundle_common()
    common = _query_rag(rag_c, map_c, query, limit, "common")
    report = []
    if hits:
        rag_r, map_r = _rag_bundle_report()
        report = _query_rag(rag_r, map_r, query, limit, "report")
    merged = []
    seen = set()
    for item in report + common:
        if item["id"] in seen:
            continue
        seen.add(item["id"])
        merged.append({"question": item["question"], "answer": item["answer"], "source": item["source"]})
        if len(merged) >= limit:
            break
    print(f"[RAG][faq_chat] used={'both' if hits else 'common'} kw={hits} q='{query}'")
    return merged
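

# Usage sketch (illustrative): search_rag returns at most `limit` deduplicated entries,
# report-index hits first whenever report keywords are present.
#
#   for hit in search_rag("产品成分有哪些", limit=3):
#       print(hit["source"], hit["question"])
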
def search_rag_full(query: str, limit: int = 10) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """
    Return a list of entries with id / question / answer for the FAQ stage, plus a
    debug dict. When report keywords are hit, report-index and common-index results
    are merged, with report results taking priority.
    """
    hits = _hit_report(query)
    rag_c, map_c = _rag_bundle_common()
    common = _query_rag(rag_c, map_c, query, limit, "common")
    report = []
    if hits:
        rag_r, map_r = _rag_bundle_report()
        report = _query_rag(rag_r, map_r, query, limit, "report")
    merged = []
    seen = set()
    for item in report + common:
        if item["id"] in seen:
            continue
        seen.add(item["id"])
        merged.append({
            "id": item["id"],
            "question": item["question"],
            "answer": item["answer"],
            "source": item["source"],
            "score": item.get("score", 0.0),
        })
        if len(merged) >= limit:
            break
    debug = {"used_index": "both" if hits else "common", "hit_keywords": hits}
    print(f"[RAG][faq_search] used={debug['used_index']} kw={hits} q='{query}'")
    return merged, debug
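

if __name__ == "__main__":
    # Minimal manual smoke test (assumes QA_PATH / QA_REPORT_PATH and the local
    # bge-small-zh-v1.5 model directory exist as configured above; the query string
    # is just an example).
    results, debug = search_rag_full("有第三方检测报告吗", limit=3)
    for r in results:
        print(r["source"], round(r["score"], 3), r["question"])
    print(debug)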