import asyncio
import json
import os
import sys
import types
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
from minirag import MiniRAG, QueryParam, minirag as minirag_mod
from minirag.base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
from minirag.utils import compute_mdhash_id, wrap_embedding_func_with_attrs
from sentence_transformers import SentenceTransformer

from .config import PROJECT_ROOT, QA_PATH
# Environment setup: disable entity extraction to avoid extra dependencies.
os.environ.setdefault("MINIRAG_DISABLE_ENTITY_EXTRACT", "1")
# --------- Inject lightweight storage implementations into the pip-installed minirag, avoiding the missing kg module ---------
class _JsonKVStorage(BaseKVStorage):
    """In-memory key/value store standing in for minirag's JSON-backed KV storage."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps record id -> record dict; lives only for the process lifetime.
        self.storage: Dict[str, dict] = {}

    async def all_keys(self):
        """Return every stored key as a list."""
        return [*self.storage]

    async def get_by_id(self, id: str):
        """Fetch one record by id, or None when absent."""
        return self.storage.get(id)

    async def get_by_ids(self, ids, fields=None):
        """Fetch several records; optionally project each onto *fields*.

        Missing ids yield None entries, preserving input order.
        """
        results = []
        for key in ids:
            record = self.storage.get(key)
            if record and fields:
                record = {name: record[name] for name in fields if name in record}
            results.append(record)
        return results

    async def filter_keys(self, data):
        """Return the subset of *data*'s keys not yet present in the store."""
        return {key for key in data if key not in self.storage}

    async def upsert(self, data):
        """Insert or overwrite the given id -> record mapping."""
        self.storage.update(data)

    async def drop(self):
        """Remove every stored record."""
        self.storage.clear()

    async def index_done_callback(self):
        # Purely in-memory; nothing to flush.
        return

    async def query_done_callback(self):
        # Purely in-memory; nothing to flush.
        return
class _NanoVectorDBStorage(BaseVectorStorage):
    """Minimal in-memory vector store using brute-force cosine similarity.

    Stands in for minirag's NanoVectorDB backend so the pip package works
    without the optional kg module.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # id -> original payload dict (embedded from its "content" field).
        self.storage: Dict[str, Dict[str, Any]] = {}
        # id -> float32 embedding vector.
        self.embeddings: Dict[str, np.ndarray] = {}

    async def query(self, query: str, top_k: int):
        """Embed *query* and return the *top_k* most similar stored items.

        Each result is a copy of the stored payload plus "id" and cosine
        "score" keys, sorted best-first. Returns [] when the store is empty.
        """
        if not self.embeddings:
            return []
        q_embs = await self.embedding_func([query])
        q_emb = np.array(q_embs[0], dtype=np.float32)
        # Hoisted out of the loop: the query norm is invariant across items.
        q_norm = np.linalg.norm(q_emb)
        sims = []
        for k, emb in self.embeddings.items():
            # `or 1e-6` guards against a zero-norm vector (division by zero).
            denom = (np.linalg.norm(emb) * q_norm) or 1e-6
            sims.append((k, float(np.dot(emb, q_emb) / denom)))
        sims.sort(key=lambda x: x[1], reverse=True)
        res = []
        for k, score in sims[:top_k]:
            item = dict(self.storage.get(k, {}))
            item.update({"id": k, "score": score})
            res.append(item)
        return res

    async def upsert(self, data):
        """Embed each record's "content" and store payload and vector by id."""
        if not data:
            # Avoid invoking the embedding model on an empty batch.
            return
        texts = [v.get("content", "") for v in data.values()]
        embs = await self.embedding_func(texts)
        if isinstance(embs, np.ndarray):
            embs = list(embs)
        for (k, v), emb in zip(data.items(), embs):
            self.storage[k] = v
            self.embeddings[k] = np.array(emb, dtype=np.float32)

    async def index_done_callback(self):
        # Purely in-memory; nothing to persist.
        return

    async def query_done_callback(self):
        # Purely in-memory; nothing to persist.
        return
class _NetworkXStorage(BaseGraphStorage):
    """Trivial in-memory graph stub; degree and edge-list queries return empties."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # node_id -> node attribute dict
        self.nodes: Dict[str, dict] = {}
        # (source_id, target_id) -> edge attribute dict
        self.edges: Dict[tuple, dict] = {}

    async def get_types(self):
        """Return (node_types, edge_types); always empty for this stub."""
        return [], []

    async def has_node(self, node_id: str):
        return node_id in self.nodes

    async def has_edge(self, source_node_id: str, target_node_id: str):
        return (source_node_id, target_node_id) in self.edges

    async def node_degree(self, node_id: str):
        # Degrees are not tracked by this stub.
        return 0

    async def edge_degree(self, src_id: str, tgt_id: str):
        # Degrees are not tracked by this stub.
        return 0

    async def get_node(self, node_id: str):
        return self.nodes.get(node_id)

    async def get_edge(self, source_node_id: str, target_node_id: str):
        return self.edges.get((source_node_id, target_node_id))

    async def get_node_edges(self, source_node_id: str):
        # Edge enumeration is not supported by this stub.
        return []

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        self.nodes[node_id] = node_data

    async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
        self.edges[(source_node_id, target_node_id)] = edge_data

    async def delete_node(self, node_id: str):
        """Remove a node and every edge touching it."""
        self.nodes.pop(node_id, None)
        stale = [pair for pair in self.edges if node_id in pair]
        for pair in stale:
            del self.edges[pair]

    async def embed_nodes(self, algorithm: str):
        """Return (embeddings, node_ids); always empty for this stub."""
        return np.zeros((0,)), []

    async def index_done_callback(self):
        # Purely in-memory; nothing to persist.
        return

    async def query_done_callback(self):
        # Purely in-memory; nothing to persist.
        return
class _JsonDocStatusStorage(_JsonKVStorage):
    # Document-status storage needs no behavior beyond the plain KV store.
    pass
# Build a synthetic "minirag.simple_storage" module exposing the lightweight
# storage classes above, and register it in sys.modules so minirag can import
# its storage backends from here instead of the missing kg module.
_simple_module = types.ModuleType("minirag.simple_storage")
_simple_module.JsonKVStorage = _JsonKVStorage
_simple_module.NanoVectorDBStorage = _NanoVectorDBStorage
_simple_module.NetworkXStorage = _NetworkXStorage
_simple_module.JsonDocStatusStorage = _JsonDocStatusStorage
sys.modules["minirag.simple_storage"] = _simple_module

# Point every storage name minirag resolves at the injected module.
minirag_mod.STORAGES.update({
    "NetworkXStorage": "minirag.simple_storage",
    "JsonKVStorage": "minirag.simple_storage",
    "NanoVectorDBStorage": "minirag.simple_storage",
    "JsonDocStatusStorage": "minirag.simple_storage",
})
# -------------------------------------------------------------------------
# Paths: local embedding model directory and the MiniRAG working directory.
MODEL_DIR: Path = (PROJECT_ROOT.parent / "minirag" / "minirag" / "models" / "bge-small-zh-v1.5").resolve()
WORKDIR: Path = (PROJECT_ROOT / "minirag_cache").resolve()
WORKDIR.mkdir(parents=True, exist_ok=True)
# Preload the QA data
def _load_qas() -> List[Dict[str, Any]]:
    """Read and parse the QA JSON file at QA_PATH.

    Returns:
        The parsed JSON — presumably a list of QA dicts with "id",
        "question" and "answer" keys (see how `_rag_bundle` consumes it).
    """
    # Use the normal module-level `json` import instead of the original
    # `__import__("json")` inline hack.
    return json.loads(QA_PATH.read_text(encoding="utf-8"))
def _build_embedder():
    """Create an async embedding function backed by the local BGE model.

    Loads a SentenceTransformer from MODEL_DIR on CPU and wraps the encode
    call with minirag's attribute decorator (embedding_dim from the model,
    max_token_size=512).
    """
    encoder = SentenceTransformer(str(MODEL_DIR), device="cpu")
    dim = encoder.get_sentence_embedding_dimension()

    @wrap_embedding_func_with_attrs(embedding_dim=dim, max_token_size=512)
    async def embed(texts):
        # Accept a bare string for convenience; the encoder expects a list.
        batch = [texts] if isinstance(texts, str) else texts
        return encoder.encode(batch, normalize_embeddings=True, convert_to_numpy=True)

    return embed
@lru_cache(maxsize=1)
def _rag_bundle():
    """Build (once, cached) the MiniRAG index over all QA pairs.

    Returns:
        (rag, id_to_qa): the MiniRAG instance and a mapping from chunk id
        (as produced by compute_mdhash_id) back to the source QA dict.
    """
    qas = _load_qas()
    embed = _build_embedder()
    rag = MiniRAG(
        working_dir=str(WORKDIR),
        embedding_func=embed,
        chunk_token_size=1200,  # large enough to hold question + answer without re-chunking
        chunk_overlap_token_size=0,
        llm_model_func=lambda *a, **k: "",  # no LLM calls during retrieval
        log_level="WARNING",
    )
    # Build the chunks and a map from each chunk id back to its QA record.
    chunks = []
    id_to_qa = {}
    for qa in qas:
        chunk_text = f"Q{qa.get('id')}:{qa.get('question','')}\nA:{qa.get('answer','')}"
        cid = compute_mdhash_id(chunk_text, prefix="chunk-")
        chunks.append(chunk_text)
        id_to_qa[cid] = qa

    # asyncio.run creates, runs and reliably closes a fresh event loop.
    # The previous new_event_loop/run_until_complete/close sequence leaked
    # the loop if ainsert raised, and left a closed loop installed as the
    # thread's current event loop afterwards.
    asyncio.run(rag.ainsert(chunks))
    return rag, id_to_qa
def search_rag(query: str, limit: int = 5) -> List[Dict[str, str]]:
    """Retrieve the *limit* most similar QA pairs for *query* via minirag.

    Returns:
        A list of {"question", "answer"} dicts, best match first. Chunk ids
        with no corresponding QA record are skipped.
    """
    rag, id_to_qa = _rag_bundle()

    async def _search():
        results = await rag.chunks_vdb.query(query, top_k=limit)
        out = []
        for r in results:
            qa = id_to_qa.get(r.get("id"))
            if not qa:
                continue
            out.append({"question": qa.get("question", ""), "answer": qa.get("answer", "")})
        return out

    # asyncio.run closes the loop and restores event-loop state; the previous
    # new_event_loop/set_event_loop code left a *closed* loop installed as the
    # thread's current event loop after every call.
    return asyncio.run(_search())
def search_rag_full(query: str, limit: int = 10) -> List[Dict[str, Any]]:
    """Retrieve the *limit* most similar QA pairs, including their ids.

    Like `search_rag` but each result also carries the QA "id", for display
    in the FAQ stage.

    Returns:
        A list of {"id", "question", "answer"} dicts, best match first.
        Chunk ids with no corresponding QA record are skipped.
    """
    rag, id_to_qa = _rag_bundle()

    async def _search():
        results = await rag.chunks_vdb.query(query, top_k=limit)
        out = []
        for r in results:
            qa = id_to_qa.get(r.get("id"))
            if not qa:
                continue
            out.append({"id": qa.get("id"), "question": qa.get("question", ""), "answer": qa.get("answer", "")})
        return out

    # asyncio.run closes the loop and restores event-loop state; the previous
    # new_event_loop/set_event_loop code left a *closed* loop installed as the
    # thread's current event loop after every call.
    return asyncio.run(_search())