# File location: app/tasks/research_tasks.py
# File name: research_tasks.py
"""
Research-related asynchronous tasks.

These run on a thread pool instead of Celery.
"""
import logging
from typing import List

from app.services.task_manager import async_task
from app.models.research import ResearchStatus

logger = logging.getLogger(__name__)
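
# NOTE: `async_task` is assumed to be a lightweight, Celery-style wrapper provided
# by app.services.task_manager: decorating a function exposes a `.delay(*args)`
# method that submits the call to a shared thread pool and returns a task id, and
# `task_manager.get_task_status(task_id)` is assumed to return a dict whose
# 'status' field moves through 'pending' / 'running' before ending in 'completed'
# or 'failed'. That is the contract the tasks below rely on, not a documented API.
# A minimal illustrative sketch of such a decorator (imports omitted, not the
# actual implementation) might look like:
#
#     _executor = ThreadPoolExecutor()
#
#     def async_task(fn):
#         def delay(*args, **kwargs):
#             task_id = uuid4().hex
#             _record_status(task_id, 'pending')          # hypothetical helper
#             _executor.submit(_run_and_record, task_id,  # hypothetical helper
#                              fn, *args, **kwargs)
#             return task_id
#         fn.delay = delay
#         return fn
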
@async_task
def analyze_question_chain(session_id: str):
    """Question-analysis stage (entry point of the research task chain)."""
    try:
        # Enable AI debug logging for this session
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)
        logger.info(f"Debug logging enabled for session: {session_id}")

        # Import inside the function to avoid a circular import
        from app.services.research_manager import ResearchManager
        research_manager = ResearchManager()

        session = research_manager.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        # Notify clients of the status change
        _emit_status(session_id, ResearchStatus.ANALYZING, "Analyzing question")

        # Run the question analysis
        research_manager.process_question_analysis(session_id)

        # Report progress
        _emit_progress(session_id, 20, "Question analysis completed")

        # Kick off the outline-creation task
        create_outline_task.delay(session_id)
    except Exception as e:
        logger.error(f"Question analysis failed: {e}")
        _handle_task_error(session_id, str(e))
        raise

@async_task
def create_outline_task(session_id: str):
    """Create the research outline."""
    try:
        # Make sure the debug session is set for this worker thread
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)

        # Import inside the function to avoid a circular import
        from app.services.research_manager import ResearchManager
        research_manager = ResearchManager()

        # Notify clients of the status change
        _emit_status(session_id, ResearchStatus.OUTLINING, "Creating outline")

        # Build the outline
        research_manager.process_outline_creation(session_id)

        # Report progress
        _emit_progress(session_id, 30, "Outline completed")

        # Fetch the updated session
        session = research_manager.get_session(session_id)

        # Launch one research task per subtopic
        if session and session.outline and session.outline.sub_topics:
            # Subtopics are researched concurrently
            subtopic_task_ids = []
            for st in session.outline.sub_topics:
                task_id = research_subtopic.delay(session_id, st.id)
                subtopic_task_ids.append(task_id)

            # Start a monitor task that waits for all subtopics to finish,
            # then triggers the final report
            monitor_subtopics_completion.delay(session_id, subtopic_task_ids)
        else:
            _handle_task_error(session_id, "Outline creation produced no subtopics to research")
    except Exception as e:
        logger.error(f"Outline creation failed: {e}")
        _handle_task_error(session_id, str(e))
        raise

@async_task
def research_subtopic(session_id: str, subtopic_id: str):
    """Research a single subtopic."""
    try:
        # Make sure the debug session is set for this worker thread
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)
        logger.info(f"Starting research for subtopic: {subtopic_id}")

        # Import inside the function to avoid a circular import
        from app.services.research_manager import ResearchManager
        from app.services.ai_service import AIService
        from app.services.search_service import SearchService
        research_manager = ResearchManager()
        ai_service = AIService()
        search_service = SearchService()

        # Fetch the session and locate the subtopic
        session = research_manager.get_session(session_id)
        if not session or not session.outline:
            raise ValueError("Session or outline not found")
        subtopic = None
        for st in session.outline.sub_topics:
            if st.id == subtopic_id:
                subtopic = st
                break
        if not subtopic:
            raise ValueError(f"Subtopic not found: {subtopic_id}")

        # Mark the subtopic as being researched
        subtopic.status = ResearchStatus.RESEARCHING
        research_manager.update_session(session_id, {'outline': session.outline})
        _emit_subtopic_progress(session_id, subtopic_id, 0, "researching")

        # 1. Generate search queries
        queries = ai_service.generate_search_queries(
            subtopic.topic,
            subtopic.explain,
            subtopic.related_questions,
            subtopic.priority
        )

        # 2. Run the searches
        logger.info(f"Searching subtopic {subtopic.topic}: {len(queries)} queries")
        search_results = []
        for i, query in enumerate(queries):
            try:
                response = search_service.search(query)
                results = response.to_search_results()
                # Assess how important each result is
                evaluated_results = ai_service.evaluate_search_results(
                    subtopic.topic, results
                )
                search_results.extend(evaluated_results)
                # Update progress (searching accounts for the first 50%)
                progress = (i + 1) / len(queries) * 50
                _emit_subtopic_progress(session_id, subtopic_id, progress, "searching")
            except Exception as e:
                logger.error(f"Search failed for '{query}': {e}")

        # Deduplicate by URL (the dict keeps one result per URL)
        unique_results = list({r.url: r for r in search_results}.values())
        subtopic.searches = [
            {
                "url": r.url,
                "title": r.title,
                "snippet": r.snippet,
                "importance": r.importance.value if r.importance else "medium"
            }
            for r in unique_results
        ]

        # 3. Reflect on the gathered information
        key_points = ai_service.reflect_on_information(subtopic.topic, unique_results)
        refined_result_objects = []
        if key_points:
            # 4. Generate refined follow-up queries
            refined_queries_map = ai_service.generate_refined_queries(key_points)
            # 5. Run the refined searches
            for key_info, refined_queries in refined_queries_map.items():
                refined_batch = search_service.refined_search(
                    subtopic_id, key_info, refined_queries
                )
                # Assess the refined results as well
                evaluated_refined = ai_service.evaluate_search_results(
                    subtopic.topic, refined_batch.results
                )
                # Keep the result objects so they can be integrated below
                refined_result_objects.extend(evaluated_refined)
                subtopic.refined_searches.extend([
                    {
                        "key_info": key_info,
                        "url": r.url,
                        "title": r.title,
                        "snippet": r.snippet,
                        "importance": r.importance.value if r.importance else "medium"
                    }
                    for r in evaluated_refined
                ])

        _emit_subtopic_progress(session_id, subtopic_id, 70, "integrating")

        # 6. Integrate the information (original and refined results together)
        all_results = unique_results + refined_result_objects
        integrated_info = ai_service.integrate_information(subtopic.topic, all_results)
        subtopic.integrated_info = integrated_info

        # 7. Write the subtopic report
        _emit_subtopic_progress(session_id, subtopic_id, 80, "writing")
        report_content = ai_service.write_subtopic_report(subtopic.topic, integrated_info)

        # 8. Hallucination detection and correction
        _emit_subtopic_progress(session_id, subtopic_id, 90, "reviewing")
        # Collect the source text used for the hallucination check
        url_content_map = {}
        for result in all_results:
            url_content_map[result.url] = result.snippet
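        # NOTE: only search snippets are available at this point, so the
        # hallucination check below compares the drafted report against snippet
        # text rather than full page content (an assumption about what
        # ai_service.detect_and_fix_hallucinations expects as its source map).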
        fixed_report, hallucinations = ai_service.detect_and_fix_hallucinations(
            report_content, url_content_map
        )
        subtopic.report = fixed_report
        subtopic.hallucination_checks = hallucinations
        subtopic.status = ResearchStatus.COMPLETED

        # Persist the updated subtopic
        research_manager.update_session(session_id, {'outline': session.outline})
        research_manager.process_subtopic_research(session_id, subtopic_id)

        _emit_subtopic_progress(session_id, subtopic_id, 100, "completed")
        logger.info(f"Subtopic research completed: {subtopic_id}")

        return {
            "subtopic_id": subtopic_id,
            "status": "completed",
            "search_count": len(queries),
            "results_count": len(unique_results),
            "hallucinations_fixed": len(hallucinations)
        }
    except Exception as e:
        logger.error(f"Subtopic research failed for {subtopic_id}: {e}")
        # Mark the subtopic as errored
        try:
            from app.services.research_manager import ResearchManager
            research_manager = ResearchManager()
            session = research_manager.get_session(session_id)
            if session and session.outline:
                for st in session.outline.sub_topics:
                    if st.id == subtopic_id:
                        st.status = ResearchStatus.ERROR
                        break
                research_manager.update_session(session_id, {'outline': session.outline})
            _emit_subtopic_progress(session_id, subtopic_id, -1, "error")
        except Exception:
            pass
        raise

@async_task
def monitor_subtopics_completion(session_id: str, task_ids: List[str]):
    """Wait for the subtopic tasks to finish, then trigger the final report."""
    import time
    from app.services.task_manager import task_manager
    try:
        # Make sure the debug session is set for this worker thread
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)

        # Poll until every subtopic task has finished
        max_wait_time = 1800  # 30-minute timeout
        start_time = time.time()
        while True:
            all_completed = True
            failed_count = 0
            for task_id in task_ids:
                status = task_manager.get_task_status(task_id)
                if status:
                    if status['status'] in ('running', 'pending'):
                        all_completed = False
                    elif status['status'] == 'failed':
                        failed_count += 1
            if all_completed:
                break
            if time.time() - start_time > max_wait_time:
                # On timeout, stop waiting and continue with whatever
                # subtopics have completed so far
                logger.error(f"Timed out waiting for subtopics to complete: {session_id}")
                break
            time.sleep(5)  # check every 5 seconds

        # Once the subtopics are done, generate the final report
        if failed_count < len(task_ids):  # proceed as long as not every subtopic failed
            generate_final_report_task.delay(session_id)
        else:
            _handle_task_error(session_id, "All subtopic research tasks failed")
    except Exception as e:
        logger.error(f"Monitoring subtopic completion failed: {e}")
        _handle_task_error(session_id, str(e))

@async_task
def generate_final_report_task(session_id: str):
    """Generate the final research report."""
    try:
        # Make sure the debug session is set for this worker thread
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)
        logger.info(f"Generating final report: {session_id}")

        # Import inside the function to avoid a circular import
        from app.services.research_manager import ResearchManager
        from app.services.ai_service import AIService
        from app.services.report_generator import ReportGenerator
        research_manager = ResearchManager()
        ai_service = AIService()
        report_generator = ReportGenerator()

        # Notify clients of the status change
        _emit_status(session_id, ResearchStatus.WRITING, "Generating final report")
        _emit_progress(session_id, 90, "Integrating all subtopic reports")

        # Fetch the session
        session = research_manager.get_session(session_id)
        if not session or not session.outline:
            raise ValueError("Session or outline not found")

        # Collect every completed subtopic report
        subtopic_reports_dict = {}
        subtopic_report_objects = []
        for subtopic in session.outline.sub_topics:
            if subtopic.report:
                subtopic_reports_dict[subtopic.topic] = subtopic.report
                # Build the structured report object
                report_obj = report_generator.generate_subtopic_report(
                    subtopic,
                    subtopic.integrated_info or {},
                    subtopic.report
                )
                subtopic_report_objects.append(report_obj)

        # Generate the final report content
        final_content = ai_service.generate_final_report(
            session.outline.main_topic,
            session.outline.research_questions,
            subtopic_reports_dict
        )

        # Build the final report object
        final_report = report_generator.generate_final_report(
            session,
            subtopic_report_objects,
            final_content
        )

        # Save the report
        report_path = report_generator.save_report(final_report)

        # Update the session
        session.final_report = final_report.to_markdown()
        session.update_status(ResearchStatus.COMPLETED)
        research_manager.update_session(session_id, {
            'final_report': session.final_report,
            'status': session.status
        })
        research_manager.finalize_research(session_id)

        # Notify clients that the research is finished
        _emit_progress(session_id, 100, "Research completed")
        _emit_status(session_id, ResearchStatus.COMPLETED, "Research completed")
        _emit_report_ready(session_id, "final")

        logger.info(f"Research completed: {session_id}")
        return {
            "session_id": session_id,
            "status": "completed",
            "report_path": report_path
        }
    except Exception as e:
        logger.error(f"Final report generation failed: {e}")
        _handle_task_error(session_id, str(e))
        raise

# ========== Helper functions ==========

def _get_socketio():
    """Return the application's socketio instance."""
    # Deferred import to avoid a circular dependency
    from app import socketio
    return socketio


def _emit_progress(session_id: str, percentage: float, message: str):
    """Send an overall progress update."""
    try:
        # Deferred import to avoid a circular dependency
        from app.routes.websocket import emit_progress
        socketio = _get_socketio()
        emit_progress(socketio, session_id, {
            'percentage': percentage,
            'message': message
        })
    except Exception as e:
        logger.error(f"Failed to send progress update: {e}")


def _emit_status(session_id: str, status: ResearchStatus, phase: str):
    """Send a status-change update."""
    try:
        # Deferred import to avoid a circular dependency
        from app.routes.websocket import emit_status_change
        socketio = _get_socketio()
        emit_status_change(socketio, session_id, status.value, phase)
    except Exception as e:
        logger.error(f"Failed to send status update: {e}")


def _emit_subtopic_progress(session_id: str, subtopic_id: str,
                            progress: float, status: str):
    """Send a per-subtopic progress update."""
    try:
        # Deferred import to avoid a circular dependency
        from app.routes.websocket import emit_subtopic_progress
        socketio = _get_socketio()
        emit_subtopic_progress(socketio, session_id, subtopic_id, progress, status)
    except Exception as e:
        logger.error(f"Failed to send subtopic progress update: {e}")

def _emit_report_ready(session_id: str, report_type: str):
    """Notify clients that a report is ready."""
    try:
        # Deferred import to avoid a circular dependency
        from app.routes.websocket import emit_report_ready
        socketio = _get_socketio()
        emit_report_ready(socketio, session_id, report_type)
    except Exception as e:
        logger.error(f"Failed to send report-ready notification: {e}")


def _handle_task_error(session_id: str, error_message: str):
    """Record a task failure on the session and notify clients."""
    try:
        # Import inside the function to avoid a circular import
        from app.services.research_manager import ResearchManager
        from app.routes.websocket import emit_error

        # Update the session status
        research_manager = ResearchManager()
        session = research_manager.get_session(session_id)
        if session:
            session.update_status(ResearchStatus.ERROR)
            session.error_message = error_message
            research_manager.update_session(session_id, {
                'status': session.status,
                'error_message': error_message
            })

        # Send the error notification
        socketio = _get_socketio()
        emit_error(socketio, session_id, error_message)
    except Exception as e:
        logger.error(f"Failed to handle task error: {e}")