""" DeepSeek R1模型智能体 - 带调试功能 负责推理、判断、规划、撰写等思考密集型任务 """ import json import logging from typing import Dict, List, Any, Optional from openai import OpenAI from config import Config from app.agents.prompts import get_prompt from app.utils.json_parser import parse_json_safely from app.utils.debug_logger import ai_debug_logger logger = logging.getLogger(__name__) class R1Agent: """R1模型智能体""" def __init__(self, api_key: str = None): self.api_key = api_key or Config.DEEPSEEK_API_KEY base_url = Config.DEEPSEEK_BASE_URL # 火山引擎 ARK 平台使用不同的模型名称 if 'volces.com' in base_url: self.model = "deepseek-r1-250120" # 火山引擎的 R1 模型名称 else: self.model = Config.R1_MODEL self.client = OpenAI( api_key=self.api_key, base_url=base_url ) def _call_api(self, prompt: str, temperature: float = 0.7, max_tokens: int = 4096, json_mode: bool = False) -> str: """调用R1 API""" try: messages = [{"role": "user", "content": prompt}] # 对于JSON输出,使用补全技巧 if json_mode and "```json" in prompt: # 提取到```json之前的部分作为prompt prefix = prompt.split("```json")[0] + "```json\n" messages = [ {"role": "user", "content": prefix}, {"role": "assistant", "content": "```json\n"} ] logger.info(f"调用R1 API: temperature={temperature}, max_tokens={max_tokens}") response = self.client.chat.completions.create( model=self.model, messages=messages, temperature=temperature, max_tokens=max_tokens ) content = response.choices[0].message.content # 记录原始输出到调试日志 ai_debug_logger.log_api_call( model=self.model, agent_type="R1", method=self._get_caller_method(), prompt=prompt, response=content, temperature=temperature, max_tokens=max_tokens, metadata={ "json_mode": json_mode, "prompt_tokens": response.usage.prompt_tokens if hasattr(response, 'usage') else None, "completion_tokens": response.usage.completion_tokens if hasattr(response, 'usage') else None } ) # 不再单独提取思考过程,保持在原始输出中 # 如果是JSON模式,提取JSON内容 original_content = content if json_mode: if "```json" in content: json_start = content.find("```json") + 7 json_end = content.find("```", json_start) if json_end > json_start: content = content[json_start:json_end].strip() elif content.startswith("```json\n"): # 补全模式的响应 content = content[8:] if content.endswith("```"): content = content[:-3] # 如果提取后的内容与原始内容不同,记录 if content != original_content: logger.debug(f"JSON提取: 原始长度={len(original_content)}, 提取后长度={len(content)}") return content.strip() except Exception as e: logger.error(f"R1 API调用失败: {e}") ai_debug_logger.log_api_call( model=self.model, agent_type="R1", method=self._get_caller_method(), prompt=prompt, response=f"ERROR: {str(e)}", temperature=temperature, max_tokens=max_tokens, metadata={"error": str(e)} ) raise def _get_caller_method(self) -> str: """获取调用方法名""" import inspect frame = inspect.currentframe() if frame and frame.f_back and frame.f_back.f_back: return frame.f_back.f_back.f_code.co_name return "unknown" def analyze_question_type(self, question: str) -> str: """分析问题类型""" prompt = get_prompt("question_type_analysis", question=question) result = self._call_api(prompt, temperature=0.3) # 验证返回值 valid_types = ["factual", "comparative", "exploratory", "decision"] result = result.lower().strip() if result not in valid_types: logger.warning(f"无效的问题类型: {result},默认使用exploratory") return "exploratory" return result def refine_questions(self, question: str, question_type: str) -> List[str]: """细化问题""" prompt = get_prompt("refine_questions", question=question, question_type=question_type) result = self._call_api(prompt) # 解析结果为列表 questions = [q.strip() for q in result.split('\n') if q.strip()] # 过滤掉可能的序号 questions = 
    def create_research_approach(self, question: str, question_type: str,
                                 refined_questions: List[str]) -> str:
        """Draft the research approach."""
        refined_questions_text = '\n'.join(f"- {q}" for q in refined_questions)
        prompt = get_prompt("research_approach",
                            question=question,
                            question_type=question_type,
                            refined_questions=refined_questions_text)
        return self._call_api(prompt)

    def create_outline(self, question: str, question_type: str,
                       refined_questions: List[str],
                       research_approach: str) -> Dict[str, Any]:
        """Create the research outline."""
        refined_questions_text = '\n'.join(f"- {q}" for q in refined_questions)
        prompt = get_prompt("create_outline",
                            question=question,
                            question_type=question_type,
                            refined_questions=refined_questions_text,
                            research_approach=research_approach)

        # Try to obtain the outline as JSON, retrying up to three times
        for attempt in range(3):
            try:
                result = self._call_api(prompt, temperature=0.5, json_mode=True)
                outline = parse_json_safely(result)

                # Validate required fields
                if all(key in outline for key in ["main_topic", "research_questions", "sub_topics"]):
                    return outline
                else:
                    logger.warning(f"Outline missing required fields, attempt {attempt + 1}")
                    ai_debug_logger.log_json_parse_error(
                        result,
                        f"Missing required fields: {outline.keys()}",
                        None
                    )
            except Exception as e:
                logger.error(f"Failed to parse outline, attempt {attempt + 1}: {e}")
                ai_debug_logger.log_json_parse_error(
                    result if 'result' in locals() else '',
                    str(e),
                    None
                )

        # Fall back to a default outline (sub-topic text kept in Chinese to match the prompt language)
        default_outline = {
            "main_topic": question,
            "research_questions": refined_questions[:3],
            "sub_topics": [
                {
                    "topic": "主要方面分析",
                    "explain": "针对问题的核心方面进行深入分析",
                    "priority": "high",
                    "related_questions": refined_questions[:2]
                }
            ]
        }
        ai_debug_logger.log_json_parse_error(
            result if 'result' in locals() else '',
            "Failed to parse after 3 attempts, using default outline",
            json.dumps(default_outline, ensure_ascii=False)
        )
        return default_outline

    def validate_outline(self, outline: Dict[str, Any]) -> str:
        """Check the outline for completeness."""
        prompt = get_prompt("outline_validation",
                            outline=json.dumps(outline, ensure_ascii=False))
        return self._call_api(prompt)

    def modify_outline(self, original_outline: Dict[str, Any],
                       user_feedback: str,
                       validation_issues: str) -> Dict[str, Any]:
        """Modify the outline."""
        prompt = get_prompt("modify_outline",
                            original_outline=json.dumps(original_outline, ensure_ascii=False),
                            user_feedback=user_feedback,
                            validation_issues=validation_issues)
        result = self._call_api(prompt, json_mode=True)
        return parse_json_safely(result)

    def evaluate_search_result(self, subtopic: str, title: str,
                               url: str, snippet: str) -> str:
        """Rate the importance of a search result."""
        prompt = get_prompt("evaluate_search_results",
                            subtopic=subtopic,
                            title=title,
                            url=url,
                            snippet=snippet)
        result = self._call_api(prompt, temperature=0.3).lower().strip()

        # Validate the returned value
        if result not in ["high", "medium", "low"]:
            return "medium"
        return result

    def reflect_on_information(self, subtopic: str, search_summary: str) -> List[Dict[str, str]]:
        """Reflect on the gathered information and return points that need deeper searching."""
        # A more detailed analysis could be generated from search_summary here;
        # for now a placeholder string is passed to the prompt template.
        prompt = get_prompt("information_reflection",
                            subtopic=subtopic,
                            search_summary=search_summary,
                            detailed_analysis="[基于搜索结果的详细分析]")
        result = self._call_api(prompt)

        # Parse the result and extract points that need follow-up searches.
        # Simple implementation; real output may require more robust parsing.
        key_points = []
        lines = result.split('\n')
        for line in lines:
            # '还需要搜索' is the marker the Chinese prompt uses for "still needs searching"
            if line.strip() and '还需要搜索' in line:
                parts = line.split('还需要搜索')
                if len(parts) == 2:
                    key_points.append({
                        "key_info": parts[0].strip(),
                        "detail_needed": parts[1].strip('()() ')
                    })

        return key_points

    def integrate_information(self, subtopic: str, all_search_results: str) -> Dict[str, Any]:
        """Integrate information into a structured format."""
        prompt = get_prompt("integrate_information",
                            subtopic=subtopic,
                            all_search_results=all_search_results)
        result = self._call_api(prompt, json_mode=True)
        parsed = parse_json_safely(result)

        # Make sure the returned structure contains the required fields
        if 'key_points' not in parsed:
            parsed['key_points'] = []
        if 'themes' not in parsed:
            parsed['themes'] = []

        return parsed
    def write_subtopic_report(self, subtopic: str, integrated_info: Dict[str, Any]) -> str:
        """Write the report for a sub-topic."""
        prompt = get_prompt("write_subtopic_report",
                            subtopic=subtopic,
                            integrated_info=json.dumps(integrated_info, ensure_ascii=False))
        return self._call_api(prompt, temperature=0.7, max_tokens=8192)

    def detect_hallucination(self, written_content: str, claimed_url: str,
                             original_content: str) -> Dict[str, Any]:
        """Detect hallucinated content."""
        prompt = get_prompt("hallucination_detection",
                            written_content=written_content,
                            claimed_url=claimed_url,
                            original_content=original_content)
        result = self._call_api(prompt, temperature=0.3, json_mode=True)
        parsed = parse_json_safely(result)

        # Make sure the returned structure is well-formed
        if 'is_hallucination' not in parsed:
            parsed['is_hallucination'] = False
        if 'hallucination_type' not in parsed:
            parsed['hallucination_type'] = None
        if 'explanation' not in parsed:
            parsed['explanation'] = ''

        return parsed

    def generate_final_report(self, main_topic: str, research_questions: List[str],
                              subtopic_reports: Dict[str, str]) -> str:
        """Generate the final report."""
        # Format the sub-topic reports
        reports_text = "\n\n---\n\n".join([
            f"### {topic}\n{report}"
            for topic, report in subtopic_reports.items()
        ])

        prompt = get_prompt("generate_final_report",
                            main_topic=main_topic,
                            research_questions='\n'.join(f"- {q}" for q in research_questions),
                            subtopic_reports=reports_text)
        return self._call_api(prompt, temperature=0.7, max_tokens=16384)
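
# Minimal usage sketch, not part of the production pipeline: it shows the intended
# call order of the planning-stage methods above. It assumes Config.DEEPSEEK_API_KEY,
# Config.DEEPSEEK_BASE_URL, and the referenced prompt templates are configured; the
# sample question is hypothetical and each call issues a real API request.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    agent = R1Agent()
    question = "Long-term ownership cost comparison between electric and gasoline cars"  # hypothetical

    # Classify the question, refine it, draft an approach, then build the outline
    question_type = agent.analyze_question_type(question)
    refined = agent.refine_questions(question, question_type)
    approach = agent.create_research_approach(question, question_type, refined)
    outline = agent.create_outline(question, question_type, refined, approach)

    print(json.dumps(outline, ensure_ascii=False, indent=2))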