471 lines
17 KiB
Python
471 lines
17 KiB
Python
# 文件位置: app/tasks/research_tasks.py
|
|
# 文件名: research_tasks.py
|
|
|
|
"""
|
|
研究相关的异步任务
|
|
使用线程池替代Celery
|
|
"""
|
|
import logging
|
|
from typing import Dict, List, Any
|
|
from app.services.task_manager import async_task
|
|
from app.models.research import ResearchStatus, Subtopic
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
@async_task
def analyze_question_chain(session_id: str):
    """Run the question-analysis stage and chain into outline creation.

    Enables AI debug logging for the session, checks that the session
    exists, runs ``ResearchManager.process_question_analysis`` and then
    schedules ``create_outline_task`` for the same session.

    Args:
        session_id: ID of the research session to analyze.

    Raises:
        ValueError: if no session exists for ``session_id``.
        Any exception raised along the way is first reported through
        ``_handle_task_error`` and then re-raised.
    """
    try:
        # Route AI debug output of this thread to the session.
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)
        logger.info(f"启用调试模式: {session_id}")

        # Imported lazily to avoid a circular import.
        from app.services.research_manager import ResearchManager
        manager = ResearchManager()

        if not manager.get_session(session_id):
            raise ValueError(f"Session not found: {session_id}")

        _emit_status(session_id, ResearchStatus.ANALYZING, "分析问题")
        manager.process_question_analysis(session_id)
        _emit_progress(session_id, 20, "问题分析完成")

        # Next pipeline stage.
        create_outline_task.delay(session_id)
    except Exception as err:
        logger.error(f"问题分析失败: {err}")
        _handle_task_error(session_id, str(err))
        raise
|
|
|
|
@async_task
def create_outline_task(session_id: str):
    """Create the research outline and fan out one task per sub-topic.

    Emits an OUTLINING status, runs
    ``ResearchManager.process_outline_creation``, then schedules
    ``research_subtopic`` for every sub-topic of the outline. It also
    schedules a single ``monitor_subtopics_completion`` task that waits
    for them.

    Args:
        session_id: ID of the research session.

    Raises:
        ValueError: if the session disappears after outline creation, or
            the outline has no sub-topics. Without this check the pipeline
            would stop silently and leave the session stuck in OUTLINING.
        Any exception is reported via ``_handle_task_error`` and re-raised.
    """
    try:
        # Make sure AI debug output of this thread goes to the session.
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)

        # Imported lazily to avoid a circular import.
        from app.services.research_manager import ResearchManager
        research_manager = ResearchManager()

        _emit_status(session_id, ResearchStatus.OUTLINING, "制定大纲")
        research_manager.process_outline_creation(session_id)
        _emit_progress(session_id, 30, "大纲制定完成")

        # Re-read the session to see the outline just created.
        session = research_manager.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        sub_topics = session.outline.sub_topics if session.outline else None
        if not sub_topics:
            # Nothing to research: fail loudly instead of hanging forever.
            raise ValueError(f"Outline has no sub topics: {session_id}")

        # Research all sub-topics concurrently; keep the handles returned
        # by ``delay`` so the monitor can poll their status.
        subtopic_task_ids = [
            research_subtopic.delay(session_id, st.id) for st in sub_topics
        ]

        # Wait for every sub-topic, then produce the final report.
        monitor_subtopics_completion.delay(session_id, subtopic_task_ids)
    except Exception as e:
        logger.error(f"创建大纲失败: {e}")
        _handle_task_error(session_id, str(e))
        raise
|
|
|
|
@async_task
def research_subtopic(session_id: str, subtopic_id: str):
    """Research a single sub-topic end to end.

    Pipeline:
      1. generate search queries,
      2. run and evaluate each search,
      3. reflect on the results,
      4-5. run refined searches for each key point,
      6. integrate the information,
      7. write the sub-topic report,
      8. detect and fix hallucinations.

    Progress is pushed via ``_emit_subtopic_progress``:
      - 0 at start;
      - up to 50 while searching;
      - then 70, 80, 90 and 100;
      - -1 on error.

    Args:
        session_id: ID of the research session owning the sub-topic.
        subtopic_id: ID of the sub-topic in ``session.outline.sub_topics``.

    Returns:
        A dict with ``subtopic_id``, ``status`` ("completed"),
        ``search_count`` (number of *initial* search queries),
        ``results_count`` (unique initial results) and
        ``hallucinations_fixed`` (length of the hallucination list).

    Raises:
        ValueError: if the session, its outline or the sub-topic is missing.
        Any exception is re-raised after the sub-topic has been marked as
        ``ResearchStatus.ERROR`` (best effort).
    """
    try:
        # Make sure AI debug output of this thread goes to the session.
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)
        logger.info(f"开始研究子主题: {subtopic_id}")

        # Imported lazily to avoid circular imports.
        from app.services.research_manager import ResearchManager
        from app.services.ai_service import AIService
        from app.services.search_service import SearchService

        research_manager = ResearchManager()
        ai_service = AIService()
        search_service = SearchService()

        session = research_manager.get_session(session_id)
        if not session or not session.outline:
            raise ValueError("Session or outline not found")

        subtopic = _find_subtopic(session.outline, subtopic_id)
        if not subtopic:
            raise ValueError(f"Subtopic not found: {subtopic_id}")

        subtopic.status = ResearchStatus.RESEARCHING
        research_manager.update_session(session_id, {'outline': session.outline})
        _emit_subtopic_progress(session_id, subtopic_id, 0, "researching")

        # 1. Generate search queries.
        queries = ai_service.generate_search_queries(
            subtopic.topic,
            subtopic.explain,
            subtopic.related_questions,
            subtopic.priority
        )

        # 2. Run the searches; a failing query is logged and skipped.
        logger.info(f"开始搜索子主题 {subtopic.topic}: {len(queries)} 个查询")
        search_results = []
        for i, query in enumerate(queries):
            try:
                response = search_service.search(query)
                results = response.to_search_results()

                # Let the AI rate the importance of each result.
                evaluated_results = ai_service.evaluate_search_results(
                    subtopic.topic, results
                )
                search_results.extend(evaluated_results)

                # Searching accounts for the first 50% of progress.
                progress = (i + 1) / len(queries) * 50
                _emit_subtopic_progress(session_id, subtopic_id, progress, "searching")
            except Exception as e:
                logger.error(f"搜索失败 '{query}': {e}")

        # De-duplicate by URL (a later result for the same URL wins).
        unique_results = list({r.url: r for r in search_results}.values())
        subtopic.searches = [_search_result_to_dict(r) for r in unique_results]

        # 3. Reflect on the information gathered so far.
        key_points = ai_service.reflect_on_information(subtopic.topic, unique_results)

        # Result objects from the refined searches, kept separately so they
        # can be integrated below. ``subtopic.refined_searches`` only holds
        # flat dicts.
        refined_results = []
        if key_points:
            # 4. Generate refined queries per key point.
            refined_queries_map = ai_service.generate_refined_queries(key_points)

            # 5. Run the refined searches. Use a distinct loop variable so
            #    ``queries`` still refers to the initial queries when we
            #    report ``search_count``.
            for key_info, refined_queries in refined_queries_map.items():
                refined_batch = search_service.refined_search(
                    subtopic_id, key_info, refined_queries
                )
                evaluated_refined = ai_service.evaluate_search_results(
                    subtopic.topic, refined_batch.results
                )
                subtopic.refined_searches.extend(
                    {"key_info": key_info, **_search_result_to_dict(r)}
                    for r in evaluated_refined
                )
                refined_results.extend(evaluated_refined)

        _emit_subtopic_progress(session_id, subtopic_id, 70, "integrating")

        # 6. Integrate initial and refined results.
        all_results = unique_results + refined_results
        integrated_info = ai_service.integrate_information(subtopic.topic, all_results)
        subtopic.integrated_info = integrated_info

        # 7. Write the sub-topic report.
        _emit_subtopic_progress(session_id, subtopic_id, 80, "writing")
        report_content = ai_service.write_subtopic_report(subtopic.topic, integrated_info)

        # 8. Hallucination detection against the collected source snippets.
        _emit_subtopic_progress(session_id, subtopic_id, 90, "reviewing")
        url_content_map = {r.url: r.snippet for r in all_results}
        fixed_report, hallucinations = ai_service.detect_and_fix_hallucinations(
            report_content, url_content_map
        )

        subtopic.report = fixed_report
        subtopic.hallucination_checks = hallucinations
        subtopic.status = ResearchStatus.COMPLETED

        # Persist the updated outline.
        research_manager.update_session(session_id, {'outline': session.outline})
        research_manager.process_subtopic_research(session_id, subtopic_id)

        _emit_subtopic_progress(session_id, subtopic_id, 100, "completed")
        logger.info(f"子主题研究完成: {subtopic_id}")

        return {
            "subtopic_id": subtopic_id,
            "status": "completed",
            "search_count": len(queries),
            "results_count": len(unique_results),
            "hallucinations_fixed": len(hallucinations)
        }

    except Exception as e:
        logger.error(f"子主题研究失败 {subtopic_id}: {e}")
        _mark_subtopic_error(session_id, subtopic_id)
        raise


def _find_subtopic(outline, subtopic_id: str):
    """Return the sub-topic of ``outline`` whose ``id`` matches, or None."""
    return next((st for st in outline.sub_topics if st.id == subtopic_id), None)


def _search_result_to_dict(result) -> Dict[str, Any]:
    """Serialize a search result; importance defaults to "medium" if unset."""
    return {
        "url": result.url,
        "title": result.title,
        "snippet": result.snippet,
        "importance": result.importance.value if result.importance else "medium"
    }


def _mark_subtopic_error(session_id: str, subtopic_id: str):
    """Best effort: mark the sub-topic as ERROR and notify clients.

    Never raises. The caller is already handling an exception, and a
    failure here must not mask it. Failures are logged instead of being
    silently dropped.
    """
    try:
        from app.services.research_manager import ResearchManager

        research_manager = ResearchManager()
        session = research_manager.get_session(session_id)
        if session and session.outline:
            subtopic = _find_subtopic(session.outline, subtopic_id)
            if subtopic:
                subtopic.status = ResearchStatus.ERROR
            research_manager.update_session(session_id, {'outline': session.outline})
    except Exception as inner:
        logger.error(f"标记子主题错误状态失败 {subtopic_id}: {inner}")

    # Notify clients even if persisting the error state failed.
    _emit_subtopic_progress(session_id, subtopic_id, -1, "error")
|
|
|
|
@async_task
def monitor_subtopics_completion(session_id: str, task_ids: List[str],
                                 max_wait_time: float = 1800,
                                 poll_interval: float = 5):
    """Wait for all sub-topic tasks, then trigger the final report.

    Polls ``task_manager.get_task_status`` for every task until none is
    'running' or 'pending', or until ``max_wait_time`` seconds have
    elapsed.

    Completion rule:
      - If at least one task finished without failing, schedule
        ``generate_final_report_task``.
      - Otherwise, report the session as failed.
      - Tasks still unfinished at the timeout do NOT count as successes.
      - Tasks whose status is unknown (falsy) are treated as finished.

    Args:
        session_id: ID of the research session.
        task_ids: Handles returned by ``research_subtopic.delay``.
        max_wait_time: Timeout in seconds (default 30 minutes).
        poll_interval: Seconds between status checks.

    Errors are reported via ``_handle_task_error`` and not re-raised.

    NOTE(review): this task sleeps inside a task-pool worker while waiting.
    If the pool is small, confirm it cannot starve queued sub-topic tasks.
    """
    import time
    from app.services.task_manager import task_manager

    try:
        # Make sure AI debug output of this thread goes to the session.
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)

        # Monotonic clock: elapsed time is immune to wall-clock changes.
        start_time = time.monotonic()

        while True:
            failed_count = 0
            unfinished_count = 0

            for task_id in task_ids:
                status = task_manager.get_task_status(task_id)
                if not status:
                    continue
                state = status['status']
                if state in ('running', 'pending'):
                    unfinished_count += 1
                elif state == 'failed':
                    failed_count += 1

            if unfinished_count == 0:
                break

            if time.monotonic() - start_time > max_wait_time:
                logger.error(f"等待子主题完成超时: {session_id}")
                break

            time.sleep(poll_interval)

        # Only tasks that actually finished without failing are successes.
        succeeded_count = len(task_ids) - failed_count - unfinished_count
        if succeeded_count > 0:
            generate_final_report_task.delay(session_id)
        else:
            _handle_task_error(session_id, "所有子主题研究失败")

    except Exception as e:
        logger.error(f"监控子主题完成失败: {e}")
        _handle_task_error(session_id, str(e))
|
|
|
|
@async_task
def generate_final_report_task(session_id: str):
    """Build, save and publish the final research report for a session.

    The steps are:
      1. Gather every sub-topic that has a report. Each one is turned into a
         report object via ``ReportGenerator.generate_subtopic_report``.
      2. Ask the AI service for the final report content.
      3. Assemble the final report object and save it.
      4. Store its markdown on the session and mark the session COMPLETED.
      5. Finalize the research and notify clients.

    Args:
        session_id: ID of the research session.

    Returns:
        ``{"session_id", "status": "completed", "report_path"}`` where
        ``report_path`` is whatever ``ReportGenerator.save_report`` returned.

    Raises:
        ValueError: if the session or its outline is missing.
        Any exception is reported via ``_handle_task_error`` and re-raised.
    """
    try:
        # Make sure AI debug output of this thread goes to the session.
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)
        logger.info(f"开始生成最终报告: {session_id}")

        # Imported lazily to avoid circular imports.
        from app.services.research_manager import ResearchManager
        from app.services.ai_service import AIService
        from app.services.report_generator import ReportGenerator

        manager = ResearchManager()
        ai = AIService()
        generator = ReportGenerator()

        _emit_status(session_id, ResearchStatus.WRITING, "生成最终报告")
        _emit_progress(session_id, 90, "整合所有子主题报告")

        session = manager.get_session(session_id)
        if not session or not session.outline:
            raise ValueError("Session or outline not found")

        # Only sub-topics that produced a report take part in the final one.
        reports_by_topic = {}
        report_objects = []
        for sub in session.outline.sub_topics:
            if not sub.report:
                continue
            reports_by_topic[sub.topic] = sub.report
            report_objects.append(
                generator.generate_subtopic_report(
                    sub, sub.integrated_info or {}, sub.report
                )
            )

        content = ai.generate_final_report(
            session.outline.main_topic,
            session.outline.research_questions,
            reports_by_topic
        )
        final_report = generator.generate_final_report(session, report_objects, content)
        report_path = generator.save_report(final_report)

        # Persist the markdown and the terminal status.
        session.final_report = final_report.to_markdown()
        session.update_status(ResearchStatus.COMPLETED)
        manager.update_session(session_id, {
            'final_report': session.final_report,
            'status': session.status
        })
        manager.finalize_research(session_id)

        _emit_progress(session_id, 100, "研究完成")
        _emit_status(session_id, ResearchStatus.COMPLETED, "研究完成")
        _emit_report_ready(session_id, "final")
        logger.info(f"研究完成: {session_id}")

        return {
            "session_id": session_id,
            "status": "completed",
            "report_path": report_path
        }
    except Exception as err:
        logger.error(f"生成最终报告失败: {err}")
        _handle_task_error(session_id, str(err))
        raise
|
|
|
|
# ========== 辅助函数 ==========
|
|
|
|
def _get_socketio():
    """Return the ``socketio`` object exported by the ``app`` package.

    The import is deferred to call time to avoid a circular dependency at
    module import.
    """
    from app import socketio as sio
    return sio
|
|
|
|
def _emit_progress(session_id: str, percentage: float, message: str):
    """Push an overall progress update for ``session_id`` to clients.

    Sends ``{'percentage', 'message'}`` through
    ``app.routes.websocket.emit_progress``. Any failure is logged and
    swallowed, so progress reporting can never abort a task.
    """
    try:
        # Deferred import: avoids a circular dependency.
        from app.routes.websocket import emit_progress as send_progress
        send_progress(_get_socketio(), session_id, {
            'percentage': percentage,
            'message': message
        })
    except Exception as exc:
        logger.error(f"发送进度更新失败: {exc}")
|
|
|
|
def _emit_status(session_id: str, status: ResearchStatus, phase: str):
    """Push a status change (``status.value`` plus a phase label) to clients.

    Delegates to ``app.routes.websocket.emit_status_change``. Failures are
    logged and swallowed.
    """
    try:
        # Deferred import: avoids a circular dependency.
        from app.routes.websocket import emit_status_change as send_status
        send_status(_get_socketio(), session_id, status.value, phase)
    except Exception as exc:
        logger.error(f"发送状态更新失败: {exc}")
|
|
|
|
def _emit_subtopic_progress(session_id: str, subtopic_id: str,
                            progress: float, status: str):
    """Push a per-sub-topic progress update to clients.

    Delegates to ``app.routes.websocket.emit_subtopic_progress``. Callers in
    this module pass -1 as ``progress`` when a sub-topic has failed.
    Failures are logged and swallowed.
    """
    try:
        # Deferred import: avoids a circular dependency.
        from app.routes.websocket import emit_subtopic_progress as send_subtopic
        send_subtopic(_get_socketio(), session_id, subtopic_id, progress, status)
    except Exception as exc:
        logger.error(f"发送子主题进度失败: {exc}")
|
|
|
|
def _emit_report_ready(session_id: str, report_type: str):
    """Tell clients that a report of ``report_type`` is available.

    Delegates to ``app.routes.websocket.emit_report_ready``. Failures are
    logged and swallowed.
    """
    try:
        # Deferred import: avoids a circular dependency.
        from app.routes.websocket import emit_report_ready as send_ready
        send_ready(_get_socketio(), session_id, report_type)
    except Exception as exc:
        logger.error(f"发送报告就绪通知失败: {exc}")
|
|
|
|
def _handle_task_error(session_id: str, error_message: str):
    """Mark a session as failed and notify connected clients.

    The two steps are independent and best effort:
      1. Persist ``ResearchStatus.ERROR`` and ``error_message`` on the
         session, if it still exists.
      2. Send the error to clients via ``app.routes.websocket.emit_error``.

    A failure in step 1 is logged but no longer prevents step 2, so clients
    are still told about the error. This function never raises: it runs
    inside other tasks' exception handlers.
    """
    try:
        # Imported lazily to avoid a circular import.
        from app.services.research_manager import ResearchManager

        research_manager = ResearchManager()
        session = research_manager.get_session(session_id)
        if session:
            session.update_status(ResearchStatus.ERROR)
            session.error_message = error_message
            research_manager.update_session(session_id, {
                'status': session.status,
                'error_message': error_message
            })
    except Exception as e:
        logger.error(f"处理任务错误失败: {e}")

    try:
        # Deferred import: avoids a circular dependency.
        from app.routes.websocket import emit_error
        emit_error(_get_socketio(), session_id, error_message)
    except Exception as e:
        logger.error(f"处理任务错误失败: {e}")