deepresearch/所有文件/r1_agent.py
2025-07-02 15:35:36 +08:00

254 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
DeepSeek R1 model agent.
Handles thinking-intensive tasks such as reasoning, judgment, planning and writing.
"""
import json
import logging
import re
from typing import Dict, List, Any, Optional

from openai import OpenAI

from config import Config
from app.agents.prompts import get_prompt
from app.utils.json_parser import parse_json_safely
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class R1Agent:
    """Agent backed by the DeepSeek R1 model.

    Handles the thinking-intensive steps of the research workflow:
    classifying the question, refining it into sub-questions, planning and
    revising an outline, judging search results, reflecting on and
    integrating information, checking for hallucinations, and writing the
    sub-topic and final reports.
    """

    # Labels ``analyze_question_type`` may return.
    QUESTION_TYPES = ("factual", "comparative", "exploratory", "decision")
    # Labels ``evaluate_search_result`` may return.
    IMPORTANCE_LEVELS = ("high", "medium", "low")
    # Keys an outline produced by the model must contain to be accepted.
    REQUIRED_OUTLINE_KEYS = ("main_topic", "research_questions", "sub_topics")
    # Text that marks a line of a reflection reply as needing a follow-up search.
    _FOLLOW_UP_MARKER = "还需要搜索"
    # A leading list marker on a model-produced line: a bullet ("-", "•", "·",
    # or "*" followed by whitespace) or an enumeration such as "1.", "2)",
    # "(3)", "（4）", "5、", "6：".  A terminator character is required and
    # ``(?!\d)`` rejects decimals, so numbers that belong to the text itself
    # ("2024年", "1.5倍") are kept.
    _LIST_MARKER_RE = re.compile(
        r"^\s*(?:[-•·]\s*|\*\s+|[(（]?\d+[.)）、:：](?!\d)\s*)"
    )

    def __init__(self, api_key: Optional[str] = None):
        """Create the agent.

        Args:
            api_key: API key to use; defaults to ``Config.DEEPSEEK_API_KEY``.
        """
        self.api_key = api_key or Config.DEEPSEEK_API_KEY
        base_url = Config.DEEPSEEK_BASE_URL
        # Volcano Engine's ARK platform serves R1 under its own model name.
        # ``base_url`` is checked first so an unset URL does not raise here.
        if base_url and 'volces.com' in base_url:
            self.model = "deepseek-r1-250120"
        else:
            self.model = Config.R1_MODEL
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=base_url,
        )

    # ------------------------------------------------------------------
    # Low-level API access and reply post-processing
    # ------------------------------------------------------------------

    def _call_api(self, prompt: str, temperature: float = 0.7,
                  max_tokens: int = 4096, json_mode: bool = False) -> str:
        """Send ``prompt`` to the R1 model and return the reply text.

        Args:
            prompt: The complete user prompt.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Upper bound on generated tokens.
            json_mode: When true, the reply is reduced to its JSON payload
                (see ``_extract_json_text``).

        Returns:
            The stripped reply text, or the extracted JSON text in JSON mode.
            Any inline ``<think>...</think>`` section is removed first.

        Raises:
            Exception: Errors from the client are logged and re-raised.
        """
        # The whole prompt is always sent as a single user message.  The old
        # "assistant prefill" trick truncated the prompt at its first ```json
        # fence (silently dropping any JSON template after it) and relied on
        # provider-specific prefix-completion support for a trailing
        # assistant message.  Fenced replies are handled by the extractor.
        messages = [{"role": "user", "content": prompt}]
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # ``message.content`` is optional in the SDK; treat None as "".
            content = response.choices[0].message.content or ""
        except Exception as e:
            logger.error(f"R1 API调用失败: {e}")
            raise
        content = self._strip_reasoning(content)
        if json_mode:
            return self._extract_json_text(content)
        return content.strip()

    @staticmethod
    def _strip_reasoning(text: str) -> str:
        """Drop an inline ``<think>...</think>`` section, if present.

        In case the provider returns the model's chain of thought inside the
        reply content, everything up to the last closing tag is discarded so
        callers only see the final answer.  Text without the tag is returned
        unchanged.
        """
        marker = "</think>"
        if marker in text:
            return text.rsplit(marker, 1)[1]
        return text

    @staticmethod
    def _extract_json_text(text: str) -> str:
        """Return the JSON payload from a model reply.

        Handles, in order:
          * a ```json fence anywhere in the reply (case-insensitive), whether
            or not it has a closing fence (e.g. a reply cut off by max_tokens);
          * a reply that opens with a bare ``` fence;
          * a reply with no opening fence, trimming one dangling closing fence.
        The result is whitespace-stripped.
        """
        text = text.strip()
        fence = re.search(r"```json", text, re.IGNORECASE)
        if fence:
            body_start = fence.end()
        elif text.startswith("```"):
            body_start = 3
        else:
            if text.endswith("```"):
                text = text[:-3]
            return text.strip()
        body_end = text.find("```", body_start)
        if body_end == -1:
            body_end = len(text)
        return text[body_start:body_end].strip()

    @staticmethod
    def _match_label(text: str, labels: tuple) -> Optional[str]:
        """Map a free-form model reply onto one of ``labels``.

        An exact match wins, ignoring case, surrounding whitespace, quotes,
        asterisks and trailing punctuation.  Otherwise the reply is accepted
        only if exactly one label occurs in it as a standalone ASCII word.
        Chinese characters next to it count as boundaries.

        Returns:
            The matched label, or None when no unambiguous label is found.
        """
        cleaned = text.strip().strip("\"'`*.。:：!！").strip().lower()
        if cleaned in labels:
            return cleaned
        found = [
            label for label in labels
            if re.search(rf"(?<![a-z]){re.escape(label)}(?![a-z])", cleaned)
        ]
        return found[0] if len(found) == 1 else None

    @staticmethod
    def _format_bullets(items: List[str]) -> str:
        """Render ``items`` as ``- item`` lines joined by newlines."""
        return "\n".join(f"- {item}" for item in items)

    @classmethod
    def _parse_question_list(cls, text: str) -> List[str]:
        """Split a reply into one question per line.

        A leading bullet or enumeration marker is removed from each line, and
        lines that are empty afterwards are dropped.
        """
        questions = []
        for line in text.splitlines():
            item = cls._LIST_MARKER_RE.sub("", line, count=1).strip()
            if item:
                questions.append(item)
        return questions

    @classmethod
    def _parse_reflection(cls, text: str) -> List[Dict[str, str]]:
        """Extract follow-up search points from a reflection reply.

        Each line containing ``_FOLLOW_UP_MARKER`` is split at its first
        occurrence.  The text before it, minus any list marker and trailing
        brackets or separators, becomes ``key_info``.  The text after it,
        minus surrounding brackets, colons and separators (ASCII or
        full-width), becomes ``detail_needed``.  Other lines are ignored.
        """
        points = []
        for line in text.splitlines():
            head, found, tail = line.partition(cls._FOLLOW_UP_MARKER)
            if not found:
                continue
            head = cls._LIST_MARKER_RE.sub("", head, count=1)
            points.append({
                "key_info": head.strip().rstrip("(（:：,，、 "),
                "detail_needed": tail.strip().strip("()（）:：,，、 "),
            })
        return points

    @staticmethod
    def _default_outline(question: str, refined_questions: List[str]) -> Dict[str, Any]:
        """Minimal single-sub-topic outline used when outline generation fails."""
        return {
            "main_topic": question,
            "research_questions": refined_questions[:3],
            "sub_topics": [
                {
                    "topic": "主要方面分析",
                    "explain": "针对问题的核心方面进行深入分析",
                    "priority": "high",
                    "related_questions": refined_questions[:2],
                }
            ],
        }

    # ------------------------------------------------------------------
    # Planning
    # ------------------------------------------------------------------

    def analyze_question_type(self, question: str) -> str:
        """Classify ``question`` as one of ``QUESTION_TYPES``.

        Falls back to ``"exploratory"`` (with a warning) when the reply cannot
        be matched to exactly one type.
        """
        prompt = get_prompt("question_type_analysis", question=question)
        result = self._call_api(prompt, temperature=0.3)
        label = self._match_label(result, self.QUESTION_TYPES)
        if label is None:
            logger.warning(f"无效的问题类型: {result}默认使用exploratory")
            return "exploratory"
        return label

    def refine_questions(self, question: str, question_type: str,
                         max_questions: int = 5) -> List[str]:
        """Break ``question`` down into more specific sub-questions.

        Args:
            question: The user's original question.
            question_type: Label returned by ``analyze_question_type``.
            max_questions: Maximum number of sub-questions to return.

        Returns:
            Up to ``max_questions`` non-empty questions, one per reply line,
            with list numbering/bullets removed.
        """
        prompt = get_prompt("refine_questions",
                            question=question,
                            question_type=question_type)
        result = self._call_api(prompt)
        return self._parse_question_list(result)[:max_questions]

    def create_research_approach(self, question: str, question_type: str,
                                 refined_questions: List[str]) -> str:
        """Ask the model for a research approach; return its reply text."""
        prompt = get_prompt("research_approach",
                            question=question,
                            question_type=question_type,
                            refined_questions=self._format_bullets(refined_questions))
        return self._call_api(prompt)

    def create_outline(self, question: str, question_type: str,
                       refined_questions: List[str], research_approach: str) -> Dict[str, Any]:
        """Create the research outline, retrying up to three times.

        A reply is accepted only if it parses to a dict containing every key
        in ``REQUIRED_OUTLINE_KEYS``.  If all attempts fail, a minimal default
        outline built from ``question`` and ``refined_questions`` is returned.
        """
        prompt = get_prompt("create_outline",
                            question=question,
                            question_type=question_type,
                            refined_questions=self._format_bullets(refined_questions),
                            research_approach=research_approach)
        for attempt in range(3):
            try:
                result = self._call_api(prompt, temperature=0.5, json_mode=True)
                outline = parse_json_safely(result)
            except Exception as e:
                logger.error(f"解析大纲失败,第{attempt+1}次尝试: {e}")
                continue
            # Check the type first: on a non-dict result (e.g. a raw string),
            # ``key in outline`` would be a substring test or would raise.
            if isinstance(outline, dict) and all(
                    key in outline for key in self.REQUIRED_OUTLINE_KEYS):
                return outline
            logger.warning(f"大纲缺少必要字段,第{attempt+1}次尝试")
        logger.warning("多次尝试后仍未得到有效大纲,使用默认大纲")
        return self._default_outline(question, refined_questions)

    def validate_outline(self, outline: Dict[str, Any]) -> str:
        """Ask the model to review ``outline`` for completeness; return its reply text."""
        prompt = get_prompt("outline_validation",
                            outline=json.dumps(outline, ensure_ascii=False))
        return self._call_api(prompt)

    def modify_outline(self, original_outline: Dict[str, Any],
                       user_feedback: str, validation_issues: str) -> Dict[str, Any]:
        """Revise an outline using user feedback and validation issues.

        Returns whatever ``parse_json_safely`` yields for the model's JSON
        reply.  The result is not validated here.
        """
        prompt = get_prompt("modify_outline",
                            original_outline=json.dumps(original_outline, ensure_ascii=False),
                            user_feedback=user_feedback,
                            validation_issues=validation_issues)
        result = self._call_api(prompt, json_mode=True)
        return parse_json_safely(result)

    # ------------------------------------------------------------------
    # Search evaluation and information handling
    # ------------------------------------------------------------------

    def evaluate_search_result(self, subtopic: str, title: str,
                               url: str, snippet: str) -> str:
        """Rate a search hit's importance for ``subtopic``.

        Returns one of ``IMPORTANCE_LEVELS``, or ``"medium"`` when the reply
        cannot be matched to exactly one level.
        """
        prompt = get_prompt("evaluate_search_results",
                            subtopic=subtopic,
                            title=title,
                            url=url,
                            snippet=snippet)
        result = self._call_api(prompt, temperature=0.3)
        return self._match_label(result, self.IMPORTANCE_LEVELS) or "medium"

    def reflect_on_information(self, subtopic: str, search_summary: str) -> List[Dict[str, str]]:
        """Reflect on gathered information and list points needing deeper search.

        Returns:
            One ``{"key_info": ..., "detail_needed": ...}`` dict per reply line
            containing the follow-up marker (see ``_parse_reflection``).
        """
        # TODO: ``detailed_analysis`` is a fixed placeholder; a richer analysis
        # derived from ``search_summary`` could be supplied here instead.
        prompt = get_prompt("information_reflection",
                            subtopic=subtopic,
                            search_summary=search_summary,
                            detailed_analysis="[基于搜索结果的详细分析]")
        return self._parse_reflection(self._call_api(prompt))

    def integrate_information(self, subtopic: str, all_search_results: str) -> Dict[str, Any]:
        """Integrate search results for ``subtopic`` into structured data.

        Returns whatever ``parse_json_safely`` yields for the model's JSON reply.
        """
        prompt = get_prompt("integrate_information",
                            subtopic=subtopic,
                            all_search_results=all_search_results)
        result = self._call_api(prompt, json_mode=True)
        return parse_json_safely(result)

    # ------------------------------------------------------------------
    # Writing and checking
    # ------------------------------------------------------------------

    def write_subtopic_report(self, subtopic: str, integrated_info: Dict[str, Any]) -> str:
        """Write the report for one sub-topic from its integrated information."""
        prompt = get_prompt("write_subtopic_report",
                            subtopic=subtopic,
                            integrated_info=json.dumps(integrated_info, ensure_ascii=False))
        return self._call_api(prompt, temperature=0.7, max_tokens=8192)

    def detect_hallucination(self, written_content: str, claimed_url: str,
                             original_content: str) -> Dict[str, Any]:
        """Check written content against the source it cites.

        Returns whatever ``parse_json_safely`` yields for the model's JSON reply.
        """
        prompt = get_prompt("hallucination_detection",
                            written_content=written_content,
                            claimed_url=claimed_url,
                            original_content=original_content)
        result = self._call_api(prompt, temperature=0.3, json_mode=True)
        return parse_json_safely(result)

    def generate_final_report(self, main_topic: str, research_questions: List[str],
                              subtopic_reports: Dict[str, str]) -> str:
        """Assemble the final report from the per-sub-topic reports.

        Each sub-topic report is rendered under a ``### topic`` heading.
        Reports are separated by ``---`` rules, in the dict's iteration order.
        """
        reports_text = "\n\n---\n\n".join(
            f"### {topic}\n{report}" for topic, report in subtopic_reports.items()
        )
        prompt = get_prompt("generate_final_report",
                            main_topic=main_topic,
                            research_questions=self._format_bullets(research_questions),
                            subtopic_reports=reports_text)
        return self._call_api(prompt, temperature=0.7, max_tokens=16384)