254 lines
10 KiB
Python
254 lines
10 KiB
Python
"""
|
||
DeepSeek R1模型智能体
|
||
负责推理、判断、规划、撰写等思考密集型任务
|
||
"""
|
||
import json
|
||
import logging
|
||
from typing import Dict, List, Any, Optional
|
||
from openai import OpenAI
|
||
from config import Config
|
||
from app.agents.prompts import get_prompt
|
||
from app.utils.json_parser import parse_json_safely
|
||
|
||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||
|
||
class R1Agent:
    """Agent that drives the DeepSeek R1 chat model.

    Every public method builds a prompt with ``get_prompt``, sends it through
    :meth:`_call_api` and post-processes the text reply: label validation,
    line parsing, or JSON parsing via ``parse_json_safely``.
    """

    # Model name used when the configured base URL points at Volcengine ARK,
    # which exposes the R1 model under a different identifier.
    _VOLCENGINE_R1_MODEL = "deepseek-r1-250120"

    _QUESTION_TYPES = ("factual", "comparative", "exploratory", "decision")
    _IMPORTANCE_LEVELS = ("high", "medium", "low")
    _OUTLINE_REQUIRED_KEYS = ("main_topic", "research_questions", "sub_topics")
    _OUTLINE_ATTEMPTS = 3
    _MAX_REFINED_QUESTIONS = 5

    _JSON_FENCE = "```json"
    # Phrase that marks a "needs a deeper search" line in the reflection reply.
    _FOLLOW_UP_MARKER = "还需要搜索"
    # Characters models commonly wrap around a one-word label answer.
    _LABEL_DECORATION = " \t\r\n\"'`*.。!!::“”‘’"
    # Characters stripped from both ends of a follow-up detail.
    _DETAIL_DECORATION = "()():: \t\r"

    def __init__(self, api_key: Optional[str] = None):
        """Create the OpenAI-compatible client and pick the model name.

        Args:
            api_key: API key to use; falls back to ``Config.DEEPSEEK_API_KEY``
                when omitted or empty.
        """
        self.api_key = api_key or Config.DEEPSEEK_API_KEY
        base_url = Config.DEEPSEEK_BASE_URL

        # Guard against an unset base URL so the substring test cannot raise.
        if 'volces.com' in (base_url or ''):
            self.model = self._VOLCENGINE_R1_MODEL
        else:
            self.model = Config.R1_MODEL

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=base_url
        )

    # ------------------------------------------------------------------
    # Pure helpers (no network, no project imports)
    # ------------------------------------------------------------------

    @staticmethod
    def _bullet_list(items: List[str]) -> str:
        """Render *items* as a newline-separated ``- item`` list."""
        return '\n'.join(f"- {item}" for item in items)

    @classmethod
    def _normalize_label(cls, raw: str) -> str:
        """Lower-case *raw* after removing wrapping quotes/markup/punctuation."""
        return raw.strip(cls._LABEL_DECORATION).lower()

    @classmethod
    def _extract_json_payload(cls, content: str) -> str:
        """Return the JSON text of a reply with Markdown code fences removed.

        Handles three reply shapes:

        * a tagged, fenced block anywhere in the reply (closed or unclosed);
        * a prefill continuation, where the reply holds only the JSON
          followed by the closing fence;
        * unfenced text, which is returned stripped but otherwise unchanged.
        """
        body = content.strip()
        tagged = body.find(cls._JSON_FENCE)
        if tagged != -1:
            body = body[tagged + len(cls._JSON_FENCE):]
            closing = body.find("```")
            # An unclosed fence (e.g. a truncated reply) keeps the remainder.
            return (body if closing == -1 else body[:closing]).strip()

        # No tagged fence: when the fence was pre-filled by the request the
        # model continues after it, so only the closing fence shows up.
        if body.endswith("```"):
            body = body[:-3]
        if body.startswith("```"):
            body = body[3:]
        return body.strip()

    @staticmethod
    def _parse_question_lines(text: str) -> List[str]:
        """Split *text* into questions, one per line, dropping list markers.

        Only genuine list markers are removed ("1.", "1)", "(1)", "1、",
        "- ", "* "), so a question that itself starts with a number such as
        "2024..." or "3.5..." is preserved. Lines without any letter or
        digit (blank lines, "---" separators) are discarded.
        """
        # Imported locally so this block needs no change to the file imports.
        import re

        list_marker = re.compile(
            r'^(?:[-*•·]\s+|[((]?\d{1,3}[.)、.)](?!\d)\s*)+'
        )
        questions = []
        for line in text.split('\n'):
            cleaned = list_marker.sub('', line.strip()).strip()
            if any(ch.isalnum() for ch in cleaned):
                questions.append(cleaned)
        return questions

    @classmethod
    def _parse_follow_up_points(cls, text: str) -> List[Dict[str, str]]:
        """Extract follow-up search points from a reflection reply.

        Each line containing the follow-up marker is split at its first
        occurrence: the text before it becomes ``key_info`` and the text
        after it, with wrapping parentheses/colons removed, becomes
        ``detail_needed``.
        """
        points = []
        for line in text.split('\n'):
            key_info, marker, detail = line.partition(cls._FOLLOW_UP_MARKER)
            if not marker:
                continue
            points.append({
                "key_info": key_info.strip(),
                "detail_needed": detail.strip().strip(cls._DETAIL_DECORATION),
            })
        return points

    # ------------------------------------------------------------------
    # Model access
    # ------------------------------------------------------------------

    def _call_api(self, prompt: str, temperature: float = 0.7,
                  max_tokens: int = 4096, json_mode: bool = False) -> str:
        """Send *prompt* to the model and return the stripped reply text.

        Args:
            prompt: Full prompt text.
            temperature: Sampling temperature passed to the API.
            max_tokens: Completion token limit passed to the API.
            json_mode: When true, Markdown code fences are removed from the
                reply; if the prompt contains a json fence, the request also
                pre-fills the assistant turn with that fence.

        Returns:
            The reply text, or an empty string when the API returns no
            content.

        Raises:
            Exception: Any error raised by the client is logged and re-raised.
        """
        try:
            messages = [{"role": "user", "content": prompt}]

            # Completion trick for JSON output: keep only the prompt text up
            # to the first json fence and pre-fill the assistant turn with
            # the opening fence so the model continues with raw JSON.
            # NOTE(review): some endpoints need an explicit prefix flag for a
            # trailing assistant message - confirm for the configured one.
            if json_mode and self._JSON_FENCE in prompt:
                prefix = prompt.split(self._JSON_FENCE)[0] + "```json\n"
                messages = [
                    {"role": "user", "content": prefix},
                    {"role": "assistant", "content": "```json\n"}
                ]

            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )

            content = response.choices[0].message.content
            if content is None:
                # The content field can be absent; treat it as an empty
                # reply instead of failing on a None attribute access.
                logger.warning("R1 API返回内容为空")
                content = ""

            if json_mode:
                content = self._extract_json_payload(content)

            return content.strip()

        except Exception as e:
            logger.error(f"R1 API调用失败: {e}")
            raise

    # ------------------------------------------------------------------
    # Planning
    # ------------------------------------------------------------------

    def analyze_question_type(self, question: str) -> str:
        """Classify *question* into one of the supported question types.

        Returns:
            One of "factual", "comparative", "exploratory", "decision";
            "exploratory" when the model's answer is not a valid type.
        """
        prompt = get_prompt("question_type_analysis", question=question)
        result = self._normalize_label(
            self._call_api(prompt, temperature=0.3)
        )

        if result not in self._QUESTION_TYPES:
            logger.warning(f"无效的问题类型: {result},默认使用exploratory")
            return "exploratory"

        return result

    def refine_questions(self, question: str, question_type: str) -> List[str]:
        """Ask the model for refined sub-questions.

        Returns:
            At most five questions, one per reply line, with list numbering
            and bullets removed.
        """
        prompt = get_prompt("refine_questions",
                            question=question,
                            question_type=question_type)
        result = self._call_api(prompt)

        questions = self._parse_question_lines(result)
        return questions[:self._MAX_REFINED_QUESTIONS]

    def create_research_approach(self, question: str, question_type: str,
                                 refined_questions: List[str]) -> str:
        """Ask the model for a research approach and return its reply text."""
        prompt = get_prompt("research_approach",
                            question=question,
                            question_type=question_type,
                            refined_questions=self._bullet_list(refined_questions))
        return self._call_api(prompt)

    def create_outline(self, question: str, question_type: str,
                       refined_questions: List[str], research_approach: str) -> Dict[str, Any]:
        """Build a research outline, falling back to a default one.

        The model is asked up to three times for a JSON outline. A reply is
        accepted only if it parses to a dict containing "main_topic",
        "research_questions" and "sub_topics".

        Returns:
            The parsed outline, or a minimal default outline derived from
            *question* and *refined_questions* when every attempt fails.
        """
        prompt = get_prompt("create_outline",
                            question=question,
                            question_type=question_type,
                            refined_questions=self._bullet_list(refined_questions),
                            research_approach=research_approach)

        for attempt in range(1, self._OUTLINE_ATTEMPTS + 1):
            try:
                result = self._call_api(prompt, temperature=0.5, json_mode=True)
                outline = parse_json_safely(result)
            except Exception as e:
                logger.error(f"解析大纲失败,第{attempt}次尝试: {e}")
                continue

            # Require a dict: `in` on a str would match substrings and on
            # None would raise, so neither may count as a valid outline.
            if isinstance(outline, dict) and all(
                    key in outline for key in self._OUTLINE_REQUIRED_KEYS):
                return outline

            logger.warning(f"大纲缺少必要字段,第{attempt}次尝试")

        # Default outline used when no attempt produced a valid one.
        return {
            "main_topic": question,
            "research_questions": refined_questions[:3],
            "sub_topics": [
                {
                    "topic": "主要方面分析",
                    "explain": "针对问题的核心方面进行深入分析",
                    "priority": "high",
                    "related_questions": refined_questions[:2]
                }
            ]
        }

    def validate_outline(self, outline: Dict[str, Any]) -> str:
        """Ask the model to check *outline* and return its reply text."""
        prompt = get_prompt("outline_validation",
                            outline=json.dumps(outline, ensure_ascii=False))
        return self._call_api(prompt)

    def modify_outline(self, original_outline: Dict[str, Any],
                       user_feedback: str, validation_issues: str) -> Dict[str, Any]:
        """Ask the model to revise an outline.

        Returns:
            The result of ``parse_json_safely`` on the model's JSON reply;
            it is not validated here.
        """
        prompt = get_prompt("modify_outline",
                            original_outline=json.dumps(original_outline, ensure_ascii=False),
                            user_feedback=user_feedback,
                            validation_issues=validation_issues)

        result = self._call_api(prompt, json_mode=True)
        return parse_json_safely(result)

    # ------------------------------------------------------------------
    # Research
    # ------------------------------------------------------------------

    def evaluate_search_result(self, subtopic: str, title: str,
                               url: str, snippet: str) -> str:
        """Rate how important a search result is for *subtopic*.

        Returns:
            "high", "medium" or "low"; "medium" when the model's answer is
            not one of those labels.
        """
        prompt = get_prompt("evaluate_search_results",
                            subtopic=subtopic,
                            title=title,
                            url=url,
                            snippet=snippet)

        result = self._normalize_label(
            self._call_api(prompt, temperature=0.3)
        )

        if result not in self._IMPORTANCE_LEVELS:
            return "medium"

        return result

    def reflect_on_information(self, subtopic: str, search_summary: str) -> List[Dict[str, str]]:
        """Ask the model which points still need a deeper search.

        Returns:
            One ``{"key_info", "detail_needed"}`` dict per reply line that
            contains the follow-up marker; an empty list when none do.
        """
        # The detailed analysis is currently a fixed placeholder string; it
        # could instead be generated from search_summary.
        prompt = get_prompt("information_reflection",
                            subtopic=subtopic,
                            search_summary=search_summary,
                            detailed_analysis="[基于搜索结果的详细分析]")

        result = self._call_api(prompt)

        # Simple line-based parsing; a richer reply format would need more.
        return self._parse_follow_up_points(result)

    def integrate_information(self, subtopic: str, all_search_results: str) -> Dict[str, Any]:
        """Ask the model to merge search results into structured data.

        Returns:
            The result of ``parse_json_safely`` on the model's JSON reply;
            it is not validated here.
        """
        prompt = get_prompt("integrate_information",
                            subtopic=subtopic,
                            all_search_results=all_search_results)

        result = self._call_api(prompt, json_mode=True)
        return parse_json_safely(result)

    # ------------------------------------------------------------------
    # Writing
    # ------------------------------------------------------------------

    def write_subtopic_report(self, subtopic: str, integrated_info: Dict[str, Any]) -> str:
        """Ask the model to write the report for one subtopic."""
        prompt = get_prompt("write_subtopic_report",
                            subtopic=subtopic,
                            integrated_info=json.dumps(integrated_info, ensure_ascii=False))

        return self._call_api(prompt, temperature=0.7, max_tokens=8192)

    def detect_hallucination(self, written_content: str, claimed_url: str,
                             original_content: str) -> Dict[str, Any]:
        """Ask the model to compare written content with its claimed source.

        Returns:
            The result of ``parse_json_safely`` on the model's JSON reply;
            it is not validated here.
        """
        prompt = get_prompt("hallucination_detection",
                            written_content=written_content,
                            claimed_url=claimed_url,
                            original_content=original_content)

        result = self._call_api(prompt, temperature=0.3, json_mode=True)
        return parse_json_safely(result)

    def generate_final_report(self, main_topic: str, research_questions: List[str],
                              subtopic_reports: Dict[str, str]) -> str:
        """Ask the model to combine the subtopic reports into a final report.

        Subtopic reports are joined in dict order, each under a ``###``
        heading and separated by a horizontal rule.
        """
        reports_text = "\n\n---\n\n".join(
            f"### {topic}\n{report}"
            for topic, report in subtopic_reports.items()
        )

        prompt = get_prompt("generate_final_report",
                            main_topic=main_topic,
                            research_questions=self._bullet_list(research_questions),
                            subtopic_reports=reports_text)

        return self._call_api(prompt, temperature=0.7, max_tokens=16384)