323 lines
12 KiB
Python
323 lines
12 KiB
Python
# File location: app/services/research_manager.py
# File name: research_manager.py

"""
Research workflow manager.

Coordinates the entire research process.
"""
|
|
import os
|
|
import json
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Dict, List, Any, Optional
|
|
from app.models.research import ResearchSession, ResearchStatus, ResearchOutline, Subtopic
|
|
from app.services.ai_service import AIService
|
|
from app.services.search_service import SearchService
|
|
from app.services.report_generator import ReportGenerator
|
|
from config import Config
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
class ResearchManager:
    """Coordinate the research workflow and persist session state.

    Sessions are kept in an in-memory dict (``self.sessions``) and mirrored
    to one JSON file per session under ``Config.SESSIONS_DIR``.

    NOTE(review): ``get_session`` returns the in-memory copy without looking
    at the file again. If another process (for example the worker that runs
    the task chain) rewrites the session file, the copy held here may be
    stale -- confirm how the manager is shared between processes.
    """

    def __init__(self):
        """Create the collaborating services and an empty session cache."""
        self.ai_service = AIService()
        self.search_service = SearchService()
        self.report_generator = ReportGenerator()
        # In-process cache of sessions, keyed by session id.
        self.sessions: Dict[str, "ResearchSession"] = {}

    def create_session(self, question: str) -> "ResearchSession":
        """Create a new research session, cache it and save it to disk.

        Args:
            question: The research question the session is created for.

        Returns:
            The newly created session.
        """
        session = ResearchSession(question=question)
        self.sessions[session.id] = session

        # Persist immediately so the session can be loaded from its file.
        self._save_session(session)

        logger.info(f"创建研究会话: {session.id}")
        return session

    def start_research(self, session_id: str) -> Dict[str, Any]:
        """Start the research workflow for an existing session.

        Marks the session as ANALYZING, saves it, and dispatches
        ``analyze_question_chain.delay(session_id)``.

        Args:
            session_id: Id of the session to start.

        Returns:
            A dict with ``status``, ``session_id`` and ``message`` keys.

        Raises:
            ValueError: If the session does not exist.
            Exception: Any error raised while starting is re-raised after the
                session has been marked as ERROR and saved.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        try:
            session.update_status(ResearchStatus.ANALYZING)
            self._save_session(session)

            # Imported lazily, at call time, to avoid a circular import
            # between this module and the task module.
            from app.tasks.research_tasks import analyze_question_chain
            analyze_question_chain.delay(session_id)
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"启动研究失败: {e}")
            session.update_status(ResearchStatus.ERROR)
            session.error_message = str(e)
            self._save_session(session)
            raise

        return {
            "status": "started",
            "session_id": session_id,
            "message": "研究已开始",
        }

    def get_session(self, session_id: str) -> Optional["ResearchSession"]:
        """Return the session with the given id, or ``None`` if unknown.

        The in-memory cache is consulted first; on a miss the session file is
        loaded (if it exists) and the result is cached.
        """
        cached = self.sessions.get(session_id)
        if cached is not None:
            return cached

        filepath = self._get_session_filepath(session_id)
        if not os.path.exists(filepath):
            return None

        session = ResearchSession.load_from_file(filepath)
        self.sessions[session_id] = session
        return session

    def update_session(self, session_id: str, updates: Dict[str, Any]):
        """Apply attribute updates to a session and save it.

        Only keys that already exist as attributes on the session are set;
        other keys are skipped and reported in a warning log entry.

        Args:
            session_id: Id of the session to update.
            updates: Mapping of attribute name to new value.

        Raises:
            ValueError: If the session does not exist.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        ignored = []
        for key, value in updates.items():
            if hasattr(session, key):
                setattr(session, key, value)
            else:
                ignored.append(key)
        if ignored:
            # Previously these keys were dropped without any trace.
            logger.warning(
                "Ignoring unknown session fields for %s: %s", session_id, ignored
            )

        session.updated_at = datetime.now()
        self._save_session(session)

    def get_session_status(self, session_id: str) -> Dict[str, Any]:
        """Return a progress snapshot for a session.

        Returns:
            ``{"error": "Session not found"}`` when the session is unknown,
            otherwise a dict with the session status, current phase, overall
            progress percentage, per-subtopic progress and timestamps.
        """
        session = self.get_session(session_id)
        if not session:
            return {"error": "Session not found"}

        subtopic_progress = []
        if session.outline:
            for subtopic in session.outline.sub_topics:
                search_limit = subtopic.max_searches
                # A limit of 0 (or an unset limit) would otherwise raise
                # ZeroDivisionError/TypeError; report 0% in that case.
                if search_limit:
                    progress = subtopic.get_total_searches() / search_limit * 100
                else:
                    progress = 0.0
                subtopic_progress.append({
                    "id": subtopic.id,
                    "topic": subtopic.topic,
                    "status": subtopic.status,
                    "progress": progress,
                })

        return {
            "session_id": session_id,
            "status": session.status,
            "current_phase": session.current_phase,
            "progress_percentage": session.get_progress_percentage(),
            "subtopic_progress": subtopic_progress,
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "error_message": session.error_message,
        }

    def cancel_research(self, session_id: str) -> Dict[str, Any]:
        """Mark a session as CANCELLED and save it.

        Only the stored status is changed here; nothing in this method stops
        work that has already been dispatched.

        Returns:
            ``{"error": "Session not found"}`` when the session is unknown,
            otherwise a dict with ``status``, ``session_id`` and ``message``.
        """
        session = self.get_session(session_id)
        if not session:
            return {"error": "Session not found"}

        session.update_status(ResearchStatus.CANCELLED)
        self._save_session(session)

        return {
            "status": "cancelled",
            "session_id": session_id,
            "message": "研究已取消",
        }

    def get_research_report(self, session_id: str) -> Optional[str]:
        """Return the final report text of a completed session.

        Returns:
            The report text, or ``None`` when the session is unknown, is not
            COMPLETED, or has neither an in-session report nor a report file
            at ``Config.REPORTS_DIR/<session_id>.md``.
        """
        session = self.get_session(session_id)
        if not session:
            return None

        if session.status != ResearchStatus.COMPLETED:
            return None

        # Prefer the report stored on the session itself.
        if session.final_report:
            return session.final_report

        # Otherwise fall back to the report file, if there is one.
        report_path = os.path.join(Config.REPORTS_DIR, f"{session_id}.md")
        if os.path.exists(report_path):
            with open(report_path, 'r', encoding='utf-8') as f:
                return f.read()

        return None

    def list_sessions(self, limit: int = 20, offset: int = 0) -> List[Dict[str, Any]]:
        """List a page of saved sessions, read from the session files.

        Files that fail to load are logged and skipped.

        NOTE(review): files are ordered by file name in descending order. The
        original comment says "newest first", which only holds if session ids
        sort chronologically -- confirm against the id scheme.

        Args:
            limit: Maximum number of session files to read (negative values
                are treated as 0).
            offset: Number of session files to skip (negative values are
                treated as 0).

        Returns:
            A list of summary dicts (id, question, status, created_at,
            progress); empty when the sessions directory does not exist.
        """
        # A missing directory simply means nothing has been saved yet.
        if not os.path.isdir(Config.SESSIONS_DIR):
            return []

        # Clamp so negative values cannot produce wrap-around slices.
        offset = max(offset, 0)
        limit = max(limit, 0)

        session_files = sorted(
            (f for f in os.listdir(Config.SESSIONS_DIR) if f.endswith('.json')),
            reverse=True,
        )

        sessions = []
        for filename in session_files[offset:offset + limit]:
            filepath = os.path.join(Config.SESSIONS_DIR, filename)
            try:
                session = ResearchSession.load_from_file(filepath)
                sessions.append({
                    "id": session.id,
                    "question": session.question,
                    "status": session.status,
                    "created_at": session.created_at.isoformat(),
                    "progress": session.get_progress_percentage(),
                })
            except Exception as e:
                # Best effort: one unreadable file must not break the listing.
                logger.error(f"加载会话失败 {filename}: {e}")

        return sessions

    @staticmethod
    def _isoformat_fields(mapping: Dict[str, Any], keys) -> None:
        """Convert the given keys of ``mapping`` to ISO-8601 strings in place.

        Values that are missing, falsy, or have no ``isoformat`` method (for
        example values that are already strings) are left untouched.
        """
        for key in keys:
            value = mapping.get(key)
            if value and hasattr(value, 'isoformat'):
                mapping[key] = value.isoformat()

    def _save_session(self, session: "ResearchSession"):
        """Write the session to its JSON file.

        The data is written to a temporary file in the same directory and
        then moved over the target with ``os.replace``, so a reader never
        sees a partially written session file.
        """
        filepath = self._get_session_filepath(session.id)

        # session.dict() yields plain data; datetimes are converted below.
        data = session.dict()
        self._isoformat_fields(data, ('created_at', 'updated_at', 'completed_at'))

        # Nested outline timestamps get the same guarded conversion; the
        # previous unguarded .isoformat() failed on values that were strings.
        outline = data.get('outline')
        if outline:
            self._isoformat_fields(outline, ('created_at', 'updated_at'))

        # The suffix does not end in ".json", so list_sessions ignores it.
        tmp_path = f"{filepath}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                # default=str is a last-resort fallback for any remaining
                # value that json cannot serialize natively.
                json.dump(data, f, ensure_ascii=False, indent=2, default=str)
            os.replace(tmp_path, filepath)
        finally:
            # After a successful replace the temp file is gone; this only
            # cleans up after a failure.
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    def _get_session_filepath(self, session_id: str) -> str:
        """Return the path of the JSON file that stores the given session.

        NOTE(review): ``session_id`` is joined into the path as-is. If ids can
        come from untrusted input, validate them before they reach this point.
        """
        return os.path.join(Config.SESSIONS_DIR, f"{session_id}.json")

    # The methods below are intended to be called by the task layer.

    def process_question_analysis(self, session_id: str):
        """Run the question-analysis phase for a session.

        Determines the question type, refines the question and builds the
        research approach via the AI service, then advances the phase label
        and step counter and saves the session.

        Raises:
            ValueError: If the session does not exist.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        # Classify the question.
        session.question_type = self.ai_service.analyze_question_type(session.question)

        # Refine the question.
        session.refined_questions = self.ai_service.refine_questions(
            session.question,
            session.question_type,
        )

        # Work out the research approach.
        session.research_approach = self.ai_service.create_research_approach(
            session.question,
            session.question_type,
            session.refined_questions,
        )

        # Record progress.
        session.current_phase = "制定大纲"
        session.completed_steps += 1
        self._save_session(session)

    def process_outline_creation(self, session_id: str):
        """Run the outline-creation phase for a session.

        Asks the AI service for an outline, converts it into model objects,
        assigns each subtopic a search limit based on its priority, and moves
        the session into the RESEARCHING status.

        Raises:
            ValueError: If the session does not exist.
            KeyError: If the outline returned by the AI service lacks a
                required key.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        outline_dict = self.ai_service.create_outline(
            session.question,
            session.question_type,
            session.refined_questions,
            session.research_approach,
        )

        # Convert the raw dicts into model objects. "or []" also covers an
        # explicit null for "sub_topics".
        subtopics = []
        for st in outline_dict.get('sub_topics') or []:
            subtopic = Subtopic(
                topic=st['topic'],
                explain=st['explain'],
                priority=st['priority'],
                related_questions=st.get('related_questions', []),
            )

            # Search limit depends on priority; anything that is neither
            # "high" nor "medium" gets the low-priority limit.
            if subtopic.priority == "high":
                subtopic.max_searches = Config.MAX_SEARCHES_HIGH_PRIORITY
            elif subtopic.priority == "medium":
                subtopic.max_searches = Config.MAX_SEARCHES_MEDIUM_PRIORITY
            else:
                subtopic.max_searches = Config.MAX_SEARCHES_LOW_PRIORITY

            subtopics.append(subtopic)

        session.outline = ResearchOutline(
            main_topic=outline_dict['main_topic'],
            research_questions=outline_dict['research_questions'],
            sub_topics=subtopics,
        )

        # Record progress.
        session.current_phase = "研究子主题"
        session.update_status(ResearchStatus.RESEARCHING)
        # Original comment: preparation + outline + subtopics + final report.
        # NOTE(review): the constant 3 does not obviously match the two steps
        # counted as completed below -- confirm the intended step count.
        session.total_steps = 3 + len(subtopics) + 1
        session.completed_steps = 2
        self._save_session(session)

    def process_subtopic_research(self, session_id: str, subtopic_id: str):
        """Mark one subtopic of a session as completed.

        Per the original comments, the research itself is implemented in
        research_tasks.py; this method only updates and saves the state.

        Raises:
            ValueError: If the session or its outline is missing, or if no
                subtopic has the given id.
        """
        session = self.get_session(session_id)
        if not session or not session.outline:
            raise ValueError(f"Session or outline not found: {session_id}")

        # Find the first subtopic with a matching id.
        subtopic = next(
            (st for st in session.outline.sub_topics if st.id == subtopic_id),
            None,
        )
        if subtopic is None:
            raise ValueError(f"Subtopic not found: {subtopic_id}")

        subtopic.status = ResearchStatus.COMPLETED
        session.completed_steps += 1
        self._save_session(session)

    def finalize_research(self, session_id: str):
        """Mark a session as COMPLETED and save it.

        Per the original comments, report generation is implemented in
        report_generator.py; this method only updates the session state.

        Raises:
            ValueError: If the session does not exist.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        session.update_status(ResearchStatus.COMPLETED)
        session.current_phase = "研究完成"
        session.completed_steps = session.total_steps
        self._save_session(session)