# File location: app/services/research_manager.py
# File name: research_manager.py
"""
Research workflow manager.

Coordinates the whole research process.
"""
import os
import json
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional

from app.models.research import ResearchSession, ResearchStatus, ResearchOutline, Subtopic
from app.services.ai_service import AIService
from app.services.search_service import SearchService
from app.services.report_generator import ReportGenerator
from config import Config

logger = logging.getLogger(__name__)


class ResearchManager:
    """Research workflow manager.

    Keeps an in-memory cache of sessions (``self.sessions``) and persists
    every session as a JSON file under ``Config.SESSIONS_DIR``.
    """

    def __init__(self):
        self.ai_service = AIService()
        self.search_service = SearchService()
        self.report_generator = ReportGenerator()
        # In-memory cache keyed by session id; files on disk are the fallback.
        self.sessions: Dict[str, ResearchSession] = {}

    def create_session(self, question: str) -> ResearchSession:
        """Create a new research session, cache it and persist it to disk.

        Args:
            question: The research question the session is created for.

        Returns:
            The newly created session.
        """
        session = ResearchSession(question=question)
        self.sessions[session.id] = session

        # Persist to file.
        self._save_session(session)

        logger.info(f"创建研究会话: {session.id}")
        return session

    def start_research(self, session_id: str) -> Dict[str, Any]:
        """Start the research workflow for an existing session.

        Marks the session as ANALYZING, saves it, and enqueues
        ``analyze_question_chain`` via its ``.delay`` method.

        Args:
            session_id: Id of the session to start.

        Returns:
            A dict with ``status``, ``session_id`` and ``message`` keys.

        Raises:
            ValueError: If the session cannot be found.
            Exception: Any failure while starting is recorded on the session
                (status ERROR plus ``error_message``) and then re-raised.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        try:
            # Update status.
            session.update_status(ResearchStatus.ANALYZING)
            self._save_session(session)

            # Kick off the asynchronous task chain.
            # Deferred import so that a circular dependency is avoided.
            from app.tasks.research_tasks import analyze_question_chain
            analyze_question_chain.delay(session_id)

            return {
                "status": "started",
                "session_id": session_id,
                "message": "研究已开始"
            }
        except Exception as e:
            # logger.exception keeps the traceback, which a plain error log drops.
            logger.exception("启动研究失败: %s", e)
            session.update_status(ResearchStatus.ERROR)
            session.error_message = str(e)
            self._save_session(session)
            raise

    def get_session(self, session_id: str) -> Optional[ResearchSession]:
        """Return the session with the given id, or ``None`` if unknown.

        The in-memory cache is consulted first; on a miss the session file is
        loaded from disk (if it exists) and cached.
        """
        # Look in memory first.
        if session_id in self.sessions:
            return self.sessions[session_id]

        # Fall back to the session file.
        filepath = self._get_session_filepath(session_id)
        if os.path.exists(filepath):
            session = ResearchSession.load_from_file(filepath)
            self.sessions[session_id] = session
            return session

        return None

    def update_session(self, session_id: str, updates: Dict[str, Any]) -> None:
        """Apply attribute updates to a session and persist it.

        Keys in ``updates`` that are not existing attributes of the session
        are silently ignored. ``updated_at`` is refreshed afterwards.

        Args:
            session_id: Id of the session to update.
            updates: Mapping of attribute name to new value.

        Raises:
            ValueError: If the session cannot be found.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        # Update fields.
        for key, value in updates.items():
            if hasattr(session, key):
                setattr(session, key, value)

        session.updated_at = datetime.now()
        self._save_session(session)

    def get_session_status(self, session_id: str) -> Dict[str, Any]:
        """Return a progress summary for a session.

        Args:
            session_id: Id of the session to inspect.

        Returns:
            ``{"error": "Session not found"}`` when the session is unknown;
            otherwise a dict with the session status, current phase, overall
            progress percentage, per-subtopic progress, timestamps (ISO
            format) and the error message.
        """
        session = self.get_session(session_id)
        if not session:
            return {"error": "Session not found"}

        # Compute per-subtopic progress.
        subtopic_progress = []
        if session.outline:
            for subtopic in session.outline.sub_topics:
                max_searches = subtopic.max_searches
                # Guard against a zero/unset search budget, which would
                # otherwise raise ZeroDivisionError (or TypeError for None).
                if max_searches:
                    progress = subtopic.get_total_searches() / max_searches * 100
                else:
                    progress = 0
                subtopic_progress.append({
                    "id": subtopic.id,
                    "topic": subtopic.topic,
                    "status": subtopic.status,
                    "progress": progress
                })

        return {
            "session_id": session_id,
            "status": session.status,
            "current_phase": session.current_phase,
            "progress_percentage": session.get_progress_percentage(),
            "subtopic_progress": subtopic_progress,
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "error_message": session.error_message
        }

    def cancel_research(self, session_id: str) -> Dict[str, Any]:
        """Mark a session as cancelled and persist it.

        Returns:
            ``{"error": "Session not found"}`` when the session is unknown;
            otherwise a dict with ``status``, ``session_id`` and ``message``.
        """
        session = self.get_session(session_id)
        if not session:
            return {"error": "Session not found"}

        # Update status.
        session.update_status(ResearchStatus.CANCELLED)
        self._save_session(session)

        return {
            "status": "cancelled",
            "session_id": session_id,
            "message": "研究已取消"
        }

    def get_research_report(self, session_id: str) -> Optional[str]:
        """Return the final report of a completed session.

        Returns:
            The report text, or ``None`` when the session is unknown, not yet
            COMPLETED, or no report is stored on the session or on disk.
        """
        session = self.get_session(session_id)
        if not session:
            return None

        if session.status != ResearchStatus.COMPLETED:
            return None

        # Return the final report if the session carries one.
        if session.final_report:
            return session.final_report

        # Otherwise try to load it from the reports directory.
        report_path = os.path.join(Config.REPORTS_DIR, f"{session_id}.md")
        if os.path.exists(report_path):
            with open(report_path, 'r', encoding='utf-8') as f:
                return f.read()

        return None

    def list_sessions(self, limit: int = 20, offset: int = 0) -> List[Dict[str, Any]]:
        """List research sessions found on disk.

        Pagination is applied to the file list before loading, so files that
        fail to load are logged and skipped, which may shorten the page.

        Args:
            limit: Maximum number of session files to read.
            offset: Number of session files to skip.

        Returns:
            One summary dict per successfully loaded session (id, question,
            status, ISO creation time, progress percentage).
        """
        # Read all sessions from the file system. A missing directory simply
        # means that no session has been saved yet.
        try:
            entries = os.listdir(Config.SESSIONS_DIR)
        except FileNotFoundError:
            return []

        # Reverse lexicographic order of file names.
        # NOTE(review): the original comment says "newest first"; that only
        # holds if session ids sort chronologically - confirm.
        session_files = sorted(
            (name for name in entries if name.endswith('.json')),
            reverse=True
        )

        sessions = []
        for filename in session_files[offset:offset + limit]:
            filepath = os.path.join(Config.SESSIONS_DIR, filename)
            try:
                session = ResearchSession.load_from_file(filepath)
                sessions.append({
                    "id": session.id,
                    "question": session.question,
                    "status": session.status,
                    "created_at": session.created_at.isoformat(),
                    "progress": session.get_progress_percentage()
                })
            except Exception as e:
                # Best effort: one unreadable file must not break the listing.
                logger.error("加载会话失败 %s: %s", filename, e)

        return sessions

    @staticmethod
    def _isoformat_fields(mapping: Dict[str, Any], keys) -> None:
        """Replace datetime-like values under ``keys`` with ISO strings, in place.

        Values that are missing, falsy, or lack an ``isoformat`` method
        (e.g. already strings) are left untouched.
        """
        for key in keys:
            value = mapping.get(key)
            if value and hasattr(value, 'isoformat'):
                mapping[key] = value.isoformat()

    def _save_session(self, session: ResearchSession) -> None:
        """Serialize a session to its JSON file."""
        filepath = self._get_session_filepath(session.id)

        # Dump the model with its dict() method, then convert datetimes.
        data = session.dict()
        self._isoformat_fields(data, ('created_at', 'updated_at', 'completed_at'))

        # Handle datetimes nested in the outline with the same guard as above.
        outline = data.get('outline')
        if outline:
            self._isoformat_fields(outline, ('created_at', 'updated_at'))

        # Serialize before opening the file: opening with 'w' truncates it, so
        # a serialization error must not wipe the previously saved session.
        # default=str stringifies anything else json cannot encode natively.
        payload = json.dumps(data, ensure_ascii=False, indent=2, default=str)

        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)

    def _get_session_filepath(self, session_id: str) -> str:
        """Return the path of the JSON file for the given session id."""
        # NOTE(review): session_id is interpolated into the path unchecked;
        # confirm callers never pass untrusted values containing separators.
        return os.path.join(Config.SESSIONS_DIR, f"{session_id}.json")

    # The methods below are meant to be called from tasks.

    def process_question_analysis(self, session_id: str) -> None:
        """Run the question-analysis phase and persist the results.

        Stores the question type, refined questions and research approach
        returned by the AI service, then advances the phase and step counter.

        Raises:
            ValueError: If the session cannot be found.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        # Analyze the question type.
        session.question_type = self.ai_service.analyze_question_type(session.question)

        # Refine the question.
        session.refined_questions = self.ai_service.refine_questions(
            session.question,
            session.question_type
        )

        # Draw up the research approach.
        session.research_approach = self.ai_service.create_research_approach(
            session.question,
            session.question_type,
            session.refined_questions
        )

        # Update progress.
        session.current_phase = "制定大纲"
        session.completed_steps += 1
        self._save_session(session)

    def process_outline_creation(self, session_id: str) -> None:
        """Run the outline-creation phase and persist the outline.

        Converts the outline dict returned by the AI service into model
        objects, assigns each subtopic a search budget from ``Config`` based
        on its priority, and moves the session to RESEARCHING.

        Raises:
            ValueError: If the session cannot be found.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        # Create the outline.
        outline_dict = self.ai_service.create_outline(
            session.question,
            session.question_type,
            session.refined_questions,
            session.research_approach
        )

        # Convert to model objects.
        subtopics = []
        for st in outline_dict.get('sub_topics', []):
            subtopic = Subtopic(
                topic=st['topic'],
                explain=st['explain'],
                priority=st['priority'],
                related_questions=st.get('related_questions', [])
            )

            # Set the maximum number of searches; any priority other than
            # "high"/"medium" gets the low-priority budget.
            if subtopic.priority == "high":
                subtopic.max_searches = Config.MAX_SEARCHES_HIGH_PRIORITY
            elif subtopic.priority == "medium":
                subtopic.max_searches = Config.MAX_SEARCHES_MEDIUM_PRIORITY
            else:
                subtopic.max_searches = Config.MAX_SEARCHES_LOW_PRIORITY

            subtopics.append(subtopic)

        session.outline = ResearchOutline(
            main_topic=outline_dict['main_topic'],
            research_questions=outline_dict['research_questions'],
            sub_topics=subtopics
        )

        # Update progress.
        session.current_phase = "研究子主题"
        session.update_status(ResearchStatus.RESEARCHING)
        # preparation + outline + sub-topics + final report
        session.total_steps = 3 + len(subtopics) + 1
        session.completed_steps = 2
        self._save_session(session)

    def process_subtopic_research(self, session_id: str, subtopic_id: str) -> None:
        """Mark one subtopic as completed and advance the step counter.

        Per the original author's note, the research flow itself is
        implemented in research_tasks.py; this method only updates state.

        Raises:
            ValueError: If the session, its outline, or the subtopic cannot
                be found.
        """
        session = self.get_session(session_id)
        if not session or not session.outline:
            raise ValueError(f"Session or outline not found: {session_id}")

        # Find the matching subtopic.
        subtopic = None
        for st in session.outline.sub_topics:
            if st.id == subtopic_id:
                subtopic = st
                break

        if not subtopic:
            raise ValueError(f"Subtopic not found: {subtopic_id}")

        # Only the state is updated here.
        subtopic.status = ResearchStatus.COMPLETED
        session.completed_steps += 1
        self._save_session(session)

    def finalize_research(self, session_id: str) -> None:
        """Mark the research session as completed and persist it.

        Per the original author's note, generating the final report is
        implemented in report_generator.py; this method only updates state.

        Raises:
            ValueError: If the session cannot be found.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        # Update status.
        session.update_status(ResearchStatus.COMPLETED)
        session.current_phase = "研究完成"
        session.completed_steps = session.total_steps
        self._save_session(session)