"""
AI服务层

封装对R1和V3智能体的调用
"""
import logging
import re
from typing import Any, Dict, List, Optional, Tuple

from app.agents.r1_agent import R1Agent
from app.agents.v3_agent import V3Agent
from app.models.search_result import SearchResult, SearchImportance
from config import Config

logger = logging.getLogger(__name__)

class AIService:
|
||
"""AI服务统一接口"""
|
||
|
||
def __init__(self):
|
||
self.r1_agent = R1Agent()
|
||
self.v3_agent = V3Agent()
|
||
|
||
# ========== 问题分析阶段 (R1) ==========
|
||
|
||
def analyze_question_type(self, question: str) -> str:
|
||
"""分析问题类型"""
|
||
try:
|
||
return self.r1_agent.analyze_question_type(question)
|
||
except Exception as e:
|
||
logger.error(f"分析问题类型失败: {e}")
|
||
return "exploratory" # 默认值
|
||
|
||
def refine_questions(self, question: str, question_type: str) -> List[str]:
|
||
"""细化问题"""
|
||
try:
|
||
return self.r1_agent.refine_questions(question, question_type)
|
||
except Exception as e:
|
||
logger.error(f"细化问题失败: {e}")
|
||
return [question] # 返回原问题
|
||
|
||
def create_research_approach(self, question: str, question_type: str,
|
||
refined_questions: List[str]) -> str:
|
||
"""制定研究思路"""
|
||
try:
|
||
return self.r1_agent.create_research_approach(
|
||
question, question_type, refined_questions
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"制定研究思路失败: {e}")
|
||
return "采用系统化的方法深入研究这个问题。"
|
||
|
||
# ========== 大纲制定阶段 (R1) ==========
|
||
|
||
def create_outline(self, question: str, question_type: str,
|
||
refined_questions: List[str], research_approach: str) -> Dict[str, Any]:
|
||
"""创建研究大纲"""
|
||
try:
|
||
return self.r1_agent.create_outline(
|
||
question, question_type, refined_questions, research_approach
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"创建大纲失败: {e}")
|
||
# 返回基本大纲
|
||
return {
|
||
"main_topic": question,
|
||
"research_questions": refined_questions[:3],
|
||
"sub_topics": [
|
||
{
|
||
"topic": "核心分析",
|
||
"explain": "对问题进行深入分析",
|
||
"priority": "high",
|
||
"related_questions": refined_questions[:2]
|
||
}
|
||
]
|
||
}
|
||
|
||
def validate_outline(self, outline: Dict[str, Any]) -> str:
|
||
"""验证大纲"""
|
||
try:
|
||
return self.r1_agent.validate_outline(outline)
|
||
except Exception as e:
|
||
logger.error(f"验证大纲失败: {e}")
|
||
return "大纲结构合理。"
|
||
|
||
def modify_outline(self, original_outline: Dict[str, Any],
|
||
user_feedback: str, validation_issues: str = "") -> Dict[str, Any]:
|
||
"""修改大纲"""
|
||
try:
|
||
return self.r1_agent.modify_outline(
|
||
original_outline, user_feedback, validation_issues
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"修改大纲失败: {e}")
|
||
return original_outline
|
||
|
||
# ========== 搜索阶段 (V3 + R1) ==========
|
||
|
||
def generate_search_queries(self, subtopic: str, explanation: str,
|
||
related_questions: List[str], priority: str) -> List[str]:
|
||
"""生成搜索查询(V3)"""
|
||
# 根据优先级确定搜索数量
|
||
count_map = {
|
||
"high": Config.MAX_SEARCHES_HIGH_PRIORITY,
|
||
"medium": Config.MAX_SEARCHES_MEDIUM_PRIORITY,
|
||
"low": Config.MAX_SEARCHES_LOW_PRIORITY
|
||
}
|
||
count = count_map.get(priority, 10)
|
||
|
||
try:
|
||
return self.v3_agent.generate_search_queries(
|
||
subtopic, explanation, related_questions, count
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"生成搜索查询失败: {e}")
|
||
# 返回基本查询
|
||
return [subtopic, f"{subtopic} {explanation}"][:count]
|
||
|
||
def evaluate_search_results(self, subtopic: str,
|
||
search_results: List[SearchResult]) -> List[SearchResult]:
|
||
"""评估搜索结果重要性(R1)"""
|
||
evaluated_results = []
|
||
|
||
for result in search_results:
|
||
try:
|
||
importance = self.r1_agent.evaluate_search_result(
|
||
subtopic,
|
||
result.title,
|
||
result.url,
|
||
result.snippet
|
||
)
|
||
result.importance = SearchImportance(importance)
|
||
evaluated_results.append(result)
|
||
except Exception as e:
|
||
logger.error(f"评估搜索结果失败: {e}")
|
||
result.importance = SearchImportance.MEDIUM
|
||
evaluated_results.append(result)
|
||
|
||
return evaluated_results
|
||
|
||
# ========== 信息反思阶段 (R1) ==========
|
||
|
||
def reflect_on_information(self, subtopic: str,
|
||
search_results: List[SearchResult]) -> List[Dict[str, str]]:
|
||
"""信息反思,返回需要深入的要点"""
|
||
# 生成搜索摘要
|
||
summary = self._generate_search_summary(search_results)
|
||
|
||
try:
|
||
return self.r1_agent.reflect_on_information(subtopic, summary)
|
||
except Exception as e:
|
||
logger.error(f"信息反思失败: {e}")
|
||
return []
|
||
|
||
def generate_refined_queries(self, key_points: List[Dict[str, str]]) -> Dict[str, List[str]]:
|
||
"""为关键点生成细化查询(V3)"""
|
||
refined_queries = {}
|
||
|
||
for point in key_points:
|
||
try:
|
||
queries = self.v3_agent.generate_refined_queries(
|
||
point["key_info"],
|
||
point["detail_needed"]
|
||
)
|
||
refined_queries[point["key_info"]] = queries
|
||
except Exception as e:
|
||
logger.error(f"生成细化查询失败: {e}")
|
||
refined_queries[point["key_info"]] = [point["key_info"]]
|
||
|
||
return refined_queries
|
||
|
||
# ========== 信息整合阶段 (R1) ==========
|
||
|
||
def integrate_information(self, subtopic: str,
|
||
all_search_results: List[SearchResult]) -> Dict[str, Any]:
|
||
"""整合信息"""
|
||
# 格式化搜索结果
|
||
formatted_results = self._format_search_results_for_integration(all_search_results)
|
||
|
||
try:
|
||
return self.r1_agent.integrate_information(subtopic, formatted_results)
|
||
except Exception as e:
|
||
logger.error(f"整合信息失败: {e}")
|
||
# 返回基本结构
|
||
return {
|
||
"key_points": [],
|
||
"themes": []
|
||
}
|
||
|
||
# ========== 报告撰写阶段 (R1) ==========
|
||
|
||
def write_subtopic_report(self, subtopic: str, integrated_info: Dict[str, Any]) -> str:
|
||
"""撰写子主题报告"""
|
||
try:
|
||
return self.r1_agent.write_subtopic_report(subtopic, integrated_info)
|
||
except Exception as e:
|
||
logger.error(f"撰写子主题报告失败: {e}")
|
||
return f"## {subtopic}\n\n撰写报告时发生错误。"
|
||
|
||
# ========== 幻觉检测阶段 (R1 + V3) ==========
|
||
|
||
def detect_and_fix_hallucinations(self, report: str,
|
||
original_sources: Dict[str, str]) -> Tuple[str, List[Dict]]:
|
||
"""检测并修复幻觉内容"""
|
||
hallucinations = []
|
||
fixed_report = report
|
||
|
||
# 提取报告中的所有URL引用
|
||
url_references = self._extract_url_references(report)
|
||
|
||
for url, content in url_references.items():
|
||
if url in original_sources:
|
||
try:
|
||
# 检测幻觉(R1)
|
||
result = self.r1_agent.detect_hallucination(
|
||
content, url, original_sources[url]
|
||
)
|
||
|
||
if result.get("is_hallucination", False):
|
||
hallucinations.append({
|
||
"url": url,
|
||
"content": content,
|
||
"type": result.get("hallucination_type", "未知"),
|
||
"explanation": result.get("explanation", "")
|
||
})
|
||
|
||
# 重写内容(V3)
|
||
try:
|
||
new_content = self.v3_agent.rewrite_hallucination(
|
||
content, original_sources[url]
|
||
)
|
||
fixed_report = fixed_report.replace(content, new_content)
|
||
except Exception as e:
|
||
logger.error(f"重写幻觉内容失败: {e}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"检测幻觉失败: {e}")
|
||
|
||
return fixed_report, hallucinations
|
||
|
||
# ========== 最终报告阶段 (R1) ==========
|
||
|
||
def generate_final_report(self, main_topic: str, research_questions: List[str],
|
||
subtopic_reports: Dict[str, str]) -> str:
|
||
"""生成最终报告"""
|
||
try:
|
||
return self.r1_agent.generate_final_report(
|
||
main_topic, research_questions, subtopic_reports
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"生成最终报告失败: {e}")
|
||
# 返回基本报告
|
||
reports_text = "\n\n---\n\n".join(subtopic_reports.values())
|
||
return f"# {main_topic}\n\n## 研究报告\n\n{reports_text}"
|
||
|
||
# ========== 辅助方法 ==========
|
||
|
||
def _generate_search_summary(self, search_results: List[SearchResult]) -> str:
|
||
"""生成搜索结果摘要"""
|
||
high_count = sum(1 for r in search_results if r.importance == SearchImportance.HIGH)
|
||
medium_count = sum(1 for r in search_results if r.importance == SearchImportance.MEDIUM)
|
||
low_count = sum(1 for r in search_results if r.importance == SearchImportance.LOW)
|
||
|
||
summary_lines = [
|
||
f"共找到 {len(search_results)} 条搜索结果",
|
||
f"高重要性: {high_count} 条",
|
||
f"中重要性: {medium_count} 条",
|
||
f"低重要性: {low_count} 条",
|
||
"",
|
||
"主要发现:"
|
||
]
|
||
|
||
# 添加高重要性结果的摘要
|
||
for result in search_results[:10]: # 最多10条
|
||
if result.importance == SearchImportance.HIGH:
|
||
summary_lines.append(f"- {result.title}: {result.snippet[:100]}...")
|
||
|
||
return '\n'.join(summary_lines)
|
||
|
||
def _format_search_results_for_integration(self, search_results: List[SearchResult]) -> str:
|
||
"""格式化搜索结果用于整合"""
|
||
formatted_lines = []
|
||
|
||
for i, result in enumerate(search_results, 1):
|
||
formatted_lines.extend([
|
||
f"{i}. 来源: {result.url}",
|
||
f" 标题: {result.title}",
|
||
f" 内容: {result.snippet}",
|
||
f" 重要性: {result.importance.value if result.importance else '未评估'}",
|
||
""
|
||
])
|
||
|
||
return '\n'.join(formatted_lines)
|
||
|
||
def _extract_url_references(self, report: str) -> Dict[str, str]:
|
||
"""从报告中提取URL引用及其对应内容"""
|
||
# 简单实现,实际可能需要更复杂的解析
|
||
import re
|
||
|
||
url_references = {}
|
||
# 匹配模式: 内容(来源:URL)
|
||
pattern = r'([^(]+)(来源:([^)]+))'
|
||
|
||
matches = re.finditer(pattern, report)
|
||
for match in matches:
|
||
content = match.group(1).strip()
|
||
url = match.group(2).strip()
|
||
url_references[url] = content
|
||
|
||
return url_references |