deepresearch/app/services/research_manager.py
2025-07-02 15:35:36 +08:00

323 lines
12 KiB
Python

# 文件位置: app/services/research_manager.py
# 文件名: research_manager.py
"""
研究流程管理器
协调整个研究过程
"""
import os
import json
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional
from app.models.research import ResearchSession, ResearchStatus, ResearchOutline, Subtopic
from app.services.ai_service import AIService
from app.services.search_service import SearchService
from app.services.report_generator import ReportGenerator
from config import Config
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ResearchManager:
    """Coordinate the research workflow and persist session state.

    Sessions are cached in ``self.sessions`` and mirrored to one JSON file
    per session under ``Config.SESSIONS_DIR``.
    """

    def __init__(self):
        self.ai_service = AIService()
        self.search_service = SearchService()
        self.report_generator = ReportGenerator()
        # In-memory cache of sessions, keyed by session id.
        self.sessions: Dict[str, "ResearchSession"] = {}

    def create_session(self, question: str) -> "ResearchSession":
        """Create a new research session for *question*, cache and persist it.

        Returns:
            The newly created session.
        """
        session = ResearchSession(question=question)
        self.sessions[session.id] = session
        # Persist to disk so the session can be reloaded later.
        self._save_session(session)
        logger.info(f"创建研究会话: {session.id}")
        return session

    def start_research(self, session_id: str) -> Dict[str, Any]:
        """Mark the session as analyzing and enqueue the asynchronous task chain.

        Returns:
            A dict with ``status``, ``session_id`` and ``message`` keys.

        Raises:
            ValueError: If no session with *session_id* exists.
            Exception: Any error raised while starting is recorded on the
                session (status ``ERROR`` plus ``error_message``) and re-raised.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        try:
            session.update_status(ResearchStatus.ANALYZING)
            self._save_session(session)
            # Deferred import: avoids a circular dependency with the task module.
            from app.tasks.research_tasks import analyze_question_chain
            analyze_question_chain.delay(session_id)
            return {
                "status": "started",
                "session_id": session_id,
                "message": "研究已开始"
            }
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"启动研究失败: {e}")
            session.update_status(ResearchStatus.ERROR)
            session.error_message = str(e)
            self._save_session(session)
            raise

    def get_session(self, session_id: str) -> Optional["ResearchSession"]:
        """Return the session with *session_id*, or ``None`` if it is unknown.

        The in-memory cache is consulted first; on a miss the session file is
        loaded from disk and cached.
        """
        # NOTE(review): once cached, a session is never re-read from disk. If
        # another process updates the session file, this instance may serve
        # stale data - confirm against the deployment model.
        if session_id in self.sessions:
            return self.sessions[session_id]
        filepath = self._get_session_filepath(session_id)
        if os.path.exists(filepath):
            session = ResearchSession.load_from_file(filepath)
            self.sessions[session_id] = session
            return session
        return None

    def update_session(self, session_id: str, updates: Dict[str, Any]):
        """Apply *updates* (attribute name -> value) to the session and persist it.

        Keys that are not existing attributes of the session are skipped and
        reported with a warning.

        Raises:
            ValueError: If no session with *session_id* exists.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        for key, value in updates.items():
            if hasattr(session, key):
                setattr(session, key, value)
            else:
                # Previously dropped silently; make the skipped key visible.
                logger.warning(f"忽略未知的会话字段: {key}")
        session.updated_at = datetime.now()
        self._save_session(session)

    def get_session_status(self, session_id: str) -> Dict[str, Any]:
        """Return a progress snapshot for the session.

        Returns:
            ``{"error": "Session not found"}`` for an unknown id; otherwise a
            dict with the session status, current phase, overall progress,
            per-subtopic progress, ISO timestamps and error message.
        """
        session = self.get_session(session_id)
        if not session:
            return {"error": "Session not found"}
        subtopic_progress = []
        if session.outline:
            for subtopic in session.outline.sub_topics:
                max_searches = subtopic.max_searches
                # Guard against a zero/unset search budget, which used to
                # raise ZeroDivisionError.
                if max_searches:
                    progress = subtopic.get_total_searches() / max_searches * 100
                else:
                    progress = 0.0
                subtopic_progress.append({
                    "id": subtopic.id,
                    "topic": subtopic.topic,
                    "status": subtopic.status,
                    "progress": progress
                })
        return {
            "session_id": session_id,
            "status": session.status,
            "current_phase": session.current_phase,
            "progress_percentage": session.get_progress_percentage(),
            "subtopic_progress": subtopic_progress,
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "error_message": session.error_message
        }

    def cancel_research(self, session_id: str) -> Dict[str, Any]:
        """Set the session status to ``CANCELLED`` and persist it.

        Returns:
            ``{"error": "Session not found"}`` for an unknown id; otherwise a
            dict with ``status``, ``session_id`` and ``message`` keys.
        """
        session = self.get_session(session_id)
        if not session:
            return {"error": "Session not found"}
        session.update_status(ResearchStatus.CANCELLED)
        self._save_session(session)
        return {
            "status": "cancelled",
            "session_id": session_id,
            "message": "研究已取消"
        }

    def get_research_report(self, session_id: str) -> Optional[str]:
        """Return the final report text, or ``None`` if it is not available.

        ``None`` is returned when the session is unknown, is not in the
        ``COMPLETED`` state, or has neither an in-session report nor a
        ``<session_id>.md`` file under ``Config.REPORTS_DIR``.
        """
        session = self.get_session(session_id)
        if not session:
            return None
        if session.status != ResearchStatus.COMPLETED:
            return None
        # Prefer the report stored on the session itself.
        if session.final_report:
            return session.final_report
        # Otherwise fall back to the report file on disk.
        report_path = os.path.join(Config.REPORTS_DIR, f"{session_id}.md")
        if os.path.exists(report_path):
            with open(report_path, 'r', encoding='utf-8') as f:
                return f.read()
        return None

    def list_sessions(self, limit: int = 20, offset: int = 0) -> List[Dict[str, Any]]:
        """List summaries of the sessions stored under ``Config.SESSIONS_DIR``.

        Args:
            limit: Maximum number of session files to read.
            offset: Number of session files to skip.

        Returns:
            One summary dict per session file that loaded successfully. Files
            that fail to load are logged and skipped, so a page may contain
            fewer than *limit* entries. An empty list is returned when the
            sessions directory does not exist yet.
        """
        try:
            filenames = os.listdir(Config.SESSIONS_DIR)
        except FileNotFoundError:
            # Nothing has been saved yet; an empty listing is the honest answer.
            return []
        # NOTE(review): files are ordered by name, descending. That is only
        # "newest first" if session ids sort chronologically - confirm.
        session_files = sorted(
            (name for name in filenames if name.endswith('.json')),
            reverse=True
        )
        sessions = []
        for filename in session_files[offset:offset + limit]:
            filepath = os.path.join(Config.SESSIONS_DIR, filename)
            try:
                session = ResearchSession.load_from_file(filepath)
                sessions.append({
                    "id": session.id,
                    "question": session.question,
                    "status": session.status,
                    "created_at": session.created_at.isoformat(),
                    "progress": session.get_progress_percentage()
                })
            except Exception as e:
                logger.error(f"加载会话失败 {filename}: {e}")
        return sessions

    @staticmethod
    def _to_isoformat(value):
        """Return ``value.isoformat()`` when available, else *value* unchanged."""
        return value.isoformat() if hasattr(value, 'isoformat') else value

    def _save_session(self, session: "ResearchSession"):
        """Serialize *session* to its JSON file.

        The data is written to a temporary sibling file and then moved over
        the target with ``os.replace``, so a failed dump never leaves a
        truncated session file behind.
        """
        filepath = self._get_session_filepath(session.id)
        data = session.dict()
        # Convert the known datetime fields to ISO-8601 strings.
        for key in ('created_at', 'updated_at', 'completed_at'):
            if data.get(key):
                data[key] = self._to_isoformat(data[key])
        # Same for the nested outline; values that are already strings are
        # tolerated here just like at the top level.
        outline = data.get('outline')
        if outline:
            for key in ('created_at', 'updated_at'):
                if outline.get(key):
                    outline[key] = self._to_isoformat(outline[key])
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # The temp name must not end in ".json", or list_sessions would see it.
        tmp_path = f"{filepath}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                # default=str stringifies any value json cannot encode natively.
                json.dump(data, f, ensure_ascii=False, indent=2, default=str)
            os.replace(tmp_path, filepath)
        finally:
            # After a successful replace the temp file is gone; this only
            # cleans up when the dump or the replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _get_session_filepath(self, session_id: str) -> str:
        """Return the path of the JSON file that stores *session_id*."""
        return os.path.join(Config.SESSIONS_DIR, f"{session_id}.json")

    # The methods below are meant to be called from the task layer.

    def process_question_analysis(self, session_id: str):
        """Run the question-analysis phase and store its results on the session.

        Raises:
            ValueError: If no session with *session_id* exists.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        # Classify the question.
        session.question_type = self.ai_service.analyze_question_type(session.question)
        # Refine it into more specific questions.
        session.refined_questions = self.ai_service.refine_questions(
            session.question,
            session.question_type
        )
        # Draft the research approach.
        session.research_approach = self.ai_service.create_research_approach(
            session.question,
            session.question_type,
            session.refined_questions
        )
        # Advance to the outline phase.
        session.current_phase = "制定大纲"
        session.completed_steps += 1
        self._save_session(session)

    def process_outline_creation(self, session_id: str):
        """Create the research outline and move the session to ``RESEARCHING``.

        Raises:
            ValueError: If no session with *session_id* exists.
            KeyError: If the outline returned by the AI service lacks a
                required key (``main_topic``, ``research_questions``, or a
                subtopic's ``topic`` / ``explain`` / ``priority``).
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        outline_dict = self.ai_service.create_outline(
            session.question,
            session.question_type,
            session.refined_questions,
            session.research_approach
        )
        # Convert the raw dicts into model objects.
        subtopics = []
        for st in outline_dict.get('sub_topics', []):
            subtopic = Subtopic(
                topic=st['topic'],
                explain=st['explain'],
                priority=st['priority'],
                related_questions=st.get('related_questions', [])
            )
            # The search budget depends on priority; anything other than
            # "high" or "medium" gets the low-priority budget.
            if subtopic.priority == "high":
                subtopic.max_searches = Config.MAX_SEARCHES_HIGH_PRIORITY
            elif subtopic.priority == "medium":
                subtopic.max_searches = Config.MAX_SEARCHES_MEDIUM_PRIORITY
            else:
                subtopic.max_searches = Config.MAX_SEARCHES_LOW_PRIORITY
            subtopics.append(subtopic)
        session.outline = ResearchOutline(
            main_topic=outline_dict['main_topic'],
            research_questions=outline_dict['research_questions'],
            sub_topics=subtopics
        )
        # Advance to the subtopic research phase.
        session.current_phase = "研究子主题"
        session.update_status(ResearchStatus.RESEARCHING)
        # Preparation + outline + one step per subtopic + final report.
        # NOTE(review): the leading 3 presumably counts the pre-research
        # phases - confirm against the task chain.
        session.total_steps = 3 + len(subtopics) + 1
        session.completed_steps = 2
        self._save_session(session)

    def process_subtopic_research(self, session_id: str, subtopic_id: str):
        """Mark one subtopic as completed and advance the step counter.

        The research itself is not performed here (per the original comment it
        lives in research_tasks.py); this method only updates state.

        Raises:
            ValueError: If the session, its outline, or the subtopic is missing.
        """
        session = self.get_session(session_id)
        if not session or not session.outline:
            raise ValueError(f"Session or outline not found: {session_id}")
        # Locate the subtopic by id.
        subtopic = None
        for st in session.outline.sub_topics:
            if st.id == subtopic_id:
                subtopic = st
                break
        if not subtopic:
            raise ValueError(f"Subtopic not found: {subtopic_id}")
        subtopic.status = ResearchStatus.COMPLETED
        session.completed_steps += 1
        self._save_session(session)

    def finalize_research(self, session_id: str):
        """Mark the session as ``COMPLETED`` and persist it.

        Report generation is not performed here (per the original comment it
        lives in report_generator.py); this method only updates state.

        Raises:
            ValueError: If no session with *session_id* exists.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        session.update_status(ResearchStatus.COMPLETED)
        session.current_phase = "研究完成"
        session.completed_steps = session.total_steps
        self._save_session(session)