初始提交

This commit is contained in:
JOJO 2025-07-02 15:35:36 +08:00
commit 47fa0cace0
92 changed files with 13816 additions and 0 deletions

52
.gitignore vendored Normal file
View File

@ -0,0 +1,52 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
venv/
ENV/
.venv
# Flask
instance/
.webassets-cache
# Environment variables
.env
.env.local
.env.*.local
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
# Logs
logs/
*.log
# Data
data/sessions/*
data/reports/*
data/cache/*
!data/sessions/.gitkeep
!data/reports/.gitkeep
!data/cache/.gitkeep
# Testing
.pytest_cache/
.coverage
htmlcov/
.tox/
# OS
.DS_Store
Thumbs.db
# Celery
celerybeat-schedule
celerybeat.pid

0
README.md Normal file
View File

59
app.py Normal file
View File

@ -0,0 +1,59 @@
#!/usr/bin/env python3
"""DeepResearch application entry point.

Builds the Flask app via the application factory and serves it with
Flask-SocketIO so WebSocket clients are supported.
"""
import os
import sys
import signal

from app import create_app, socketio
from config import config

# Configuration name (e.g. 'development' / 'production') from the environment.
config_name = os.environ.get('FLASK_CONFIG', 'development')
app = create_app(config_name)


def shutdown_handler(signum, frame):
    """Graceful-shutdown signal handler: stop the task manager, then exit."""
    print("\n正在关闭应用...")
    # Shut down the task manager (best effort: failures are printed, not fatal).
    try:
        from app.services.task_manager import task_manager
        task_manager.shutdown()
        print("任务管理器已关闭")
    except Exception as e:
        print(f"关闭任务管理器时出错: {e}")
    sys.exit(0)


if __name__ == '__main__':
    # Register signal handlers for Ctrl+C and termination.
    signal.signal(signal.SIGINT, shutdown_handler)
    signal.signal(signal.SIGTERM, shutdown_handler)
    # Fail fast when required API keys are missing.
    required_env_vars = ['DEEPSEEK_API_KEY', 'TAVILY_API_KEY']
    missing_vars = [var for var in required_env_vars if not os.environ.get(var)]
    if missing_vars:
        print(f"错误: 缺少必要的环境变量: {', '.join(missing_vars)}")
        print("请在.env文件中设置这些变量")
        sys.exit(1)
    # Start the application.
    port = int(os.environ.get('PORT', 8088))
    debug = app.config.get('DEBUG', False)
    print(f"启动 DeepResearch 服务器...")
    print(f"配置: {config_name}")
    print(f"调试模式: {debug}")
    print(f"访问地址: http://localhost:{port}")
    print(f"\n提示: 不再需要 Redis 和 Celery Worker")
    print(f"按 Ctrl+C 优雅关闭应用\n")
    # Run through socketio (not app.run) so WebSocket transport works.
    socketio.run(app,
                 host='0.0.0.0',
                 port=port,
                 debug=debug,
                 use_reloader=debug)

203
app/__init__.py Normal file
View File

@ -0,0 +1,203 @@
# 文件位置: app/__init__.py
# 文件名: __init__.py
"""
Flask应用工厂
"""
import os
import logging
from flask import Flask, render_template, jsonify
from flask_cors import CORS
from flask_socketio import SocketIO
from config import config
# Module-level SocketIO instance, shared by the whole app; it is bound to the
# Flask app later inside create_app() via socketio.init_app().
socketio = SocketIO()
def create_app(config_name=None):
    """Application factory: build and fully configure the Flask app.

    Args:
        config_name: Key into the ``config`` mapping; defaults to the
            ``FLASK_CONFIG`` environment variable (or 'development').

    Returns:
        The configured :class:`flask.Flask` instance.
    """
    if config_name is None:
        config_name = os.environ.get('FLASK_CONFIG', 'development')
    # Create the Flask application.
    app = Flask(__name__)
    # Load configuration (object attributes + per-config init hook).
    app.config.from_object(config[config_name])
    config[config_name].init_app(app)
    # Initialise extensions. CORS is opened only for the API routes.
    CORS(app, resources={r"/api/*": {"origins": "*"}})
    socketio.init_app(app,
                      cors_allowed_origins="*",
                      async_mode='eventlet',
                      logger=True,
                      engineio_logger=True if app.debug else False)
    # Configure logging.
    setup_logging(app)
    # Register blueprints.
    register_blueprints(app)
    # Register WebSocket event handlers (import deferred to avoid cycles).
    with app.app_context():
        from app.routes.websocket import register_handlers
        register_handlers(socketio)
    # Register error handlers.
    register_error_handlers(app)
    # Create required data/log directories.
    create_directories(app)
    # Initialise the task manager.
    init_task_manager(app)
    app.logger.info(f'DeepResearch应用已创建配置: {config_name}')
    return app
def register_blueprints(app):
    """Register all blueprints on *app*.

    Imports are local to avoid circular imports at module load time.
    """
    from app.routes.main import main_bp
    from app.routes.api import api_bp
    from app.routes.frontend import frontend_bp
    # Register the frontend blueprint first so '/' resolves correctly.
    app.register_blueprint(frontend_bp)
    app.register_blueprint(api_bp, url_prefix='/api')
    app.register_blueprint(main_bp)
    app.logger.info('蓝图注册完成')
def register_error_handlers(app):
    """Register 404/500/catch-all error handlers on *app*.

    Requests under ``/api/`` receive a JSON error body; every other request
    gets the rendered HTML error template.
    """
    # Imported here so the handlers do not depend on the module-bottom
    # ``from flask import request`` import surviving a future refactor.
    from flask import request

    @app.errorhandler(404)
    def not_found_error(error):
        # API clients expect JSON, browsers expect a page.
        if request.path.startswith('/api/'):
            return jsonify({'error': 'Not found'}), 404
        return render_template('404.html'), 404

    @app.errorhandler(500)
    def internal_error(error):
        app.logger.error(f'服务器错误: {error}')
        if request.path.startswith('/api/'):
            return jsonify({'error': 'Internal server error'}), 500
        return render_template('500.html'), 500

    @app.errorhandler(Exception)
    def unhandled_exception(error):
        # Last-resort handler: log with traceback, expose message to API only.
        app.logger.error(f'未处理的异常: {error}', exc_info=True)
        if request.path.startswith('/api/'):
            return jsonify({'error': str(error)}), 500
        return render_template('500.html'), 500
def setup_logging(app):
    """Configure application logging and quieten noisy dependency loggers."""
    from app.utils.logger import setup_logging as logger_setup
    logger_setup(app)
    # Raise the log level of chatty libraries; keep them at INFO in debug mode.
    logging.getLogger('werkzeug').setLevel(logging.WARNING)
    logging.getLogger('socketio').setLevel(logging.INFO if app.debug else logging.WARNING)
    logging.getLogger('engineio').setLevel(logging.INFO if app.debug else logging.WARNING)
def create_directories(app):
    """Ensure all data/log directories named in the app config exist.

    Fixes two defects in the original:
    - ``os.path.join(app.config.get('DATA_DIR'), 'debug')`` raised
      ``TypeError`` when ``DATA_DIR`` was not configured; the debug
      directory is now only derived when ``DATA_DIR`` is set.
    - the ``exists``/``makedirs`` pair raced with concurrent creators;
      ``exist_ok=True`` makes creation idempotent.
    """
    data_dir = app.config.get('DATA_DIR')
    directories = [
        data_dir,
        app.config.get('SESSIONS_DIR'),
        app.config.get('REPORTS_DIR'),
        app.config.get('CACHE_DIR'),
        app.config.get('LOG_DIR'),
    ]
    if data_dir:
        directories.append(os.path.join(data_dir, 'debug'))  # debug-log directory
    for directory in directories:
        if directory and not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)
            app.logger.info(f'创建目录: {directory}')
def init_task_manager(app):
    """Initialise the in-process task manager; never fails app start-up."""
    try:
        from app.services.task_manager import task_manager
        # Purge stale tasks from previous runs (guarded: method may not exist).
        if hasattr(task_manager, 'cleanup_old_tasks'):
            cleaned = task_manager.cleanup_old_tasks(hours=24)
            app.logger.info(f'清理了 {cleaned} 个旧任务')
        app.logger.info('任务管理器初始化完成')
    except Exception as e:
        # Deliberate best-effort: a broken task manager is logged, not fatal.
        app.logger.error(f'任务管理器初始化失败: {e}')
# 导入必要的模块以避免导入错误
from flask import request
# Create simple error-page templates if they do not already exist.
def create_error_templates(app):
    """Write basic 404.html / 500.html templates unless they already exist.

    Best-effort: I/O failures (read-only FS, missing templates directory) are
    ignored so application start-up is never blocked. Improvements over the
    original: the two copy-pasted write paths share one helper, and the bare
    ``except`` (which swallowed even ``KeyboardInterrupt``) is narrowed to
    ``OSError``.
    """
    templates_dir = os.path.join(app.root_path, 'templates')

    def _write_if_missing(filename, content):
        # Write *content* to templates_dir/filename unless the file exists.
        path = os.path.join(templates_dir, filename)
        if os.path.exists(path):
            return
        try:
            with open(path, 'w', encoding='utf-8') as f:
                f.write(content)
        except OSError:
            # Only I/O errors are expected here; anything else should surface.
            pass

    # 404 page
    error_404_content = '''
<!DOCTYPE html>
<html>
<head>
<title>404 - 页面未找到</title>
<style>
body { font-family: Arial, sans-serif; text-align: center; padding: 50px; }
h1 { color: #333; }
a { color: #007bff; text-decoration: none; }
</style>
</head>
<body>
<h1>404 - 页面未找到</h1>
<p>抱歉您访问的页面不存在</p>
<p><a href="/">返回首页</a></p>
</body>
</html>
'''
    _write_if_missing('404.html', error_404_content)

    # 500 page
    error_500_content = '''
<!DOCTYPE html>
<html>
<head>
<title>500 - 服务器错误</title>
<style>
body { font-family: Arial, sans-serif; text-align: center; padding: 50px; }
h1 { color: #dc3545; }
a { color: #007bff; text-decoration: none; }
</style>
</head>
<body>
<h1>500 - 服务器错误</h1>
<p>抱歉服务器遇到了一个错误</p>
<p><a href="/">返回首页</a></p>
</body>
</html>
'''
    _write_if_missing('500.html', error_500_content)
# Convenience accessor for the module-level socketio instance.
def get_socketio():
    """Return the module-level SocketIO instance."""
    return socketio

0
app/agents/__init__.py Normal file
View File

340
app/agents/prompts.py Normal file
View File

@ -0,0 +1,340 @@
"""
所有AI模型的提示词模板
"""
# Prompt templates keyed by step name; values are str.format templates whose
# placeholders are filled by get_prompt(). The template text itself is runtime
# content (sent to the model) and is kept verbatim.
PROMPTS = {
# 1. Classify the question type
"question_type_analysis": """
请分析以下用户问题判断其属于哪种类型
用户问题{question}
请从以下类型中选择最合适的一个
1. factual - 事实查询型需要具体准确的信息
2. comparative - 分析对比型需要多角度分析和比较
3. exploratory - 探索发现型需要广泛探索未知领域
4. decision - 决策支持型需要综合分析支持决策
请直接返回类型代码factual不需要其他解释
""",
# 2. Refine the question
"refine_questions": """
基于用户的问题和问题类型请提出3-5个细化问题帮助更好地理解和研究这个主题
原始问题{question}
问题类型{question_type}
请思考
1. 还需要哪些具体信息
2. 应该关注问题的哪些方面
3. 有哪些潜在的相关维度需要探索
请以列表形式返回细化问题每个问题独占一行
""",
# 3. Initial research approach
"research_approach": """
基于用户问题和细化问题请制定初步的研究思路
原始问题{question}
问题类型{question_type}
细化问题
{refined_questions}
请简要说明研究这个问题的整体思路和方法200字以内
""",
# 4. Build the research outline (expects JSON output)
"create_outline": """
请为以下研究主题制定详细的研究大纲
主题{question}
问题类型{question_type}
细化问题{refined_questions}
研究思路{research_approach}
请按以下JSON格式输出大纲
```json
{{
"main_topic": "用户输入的主题",
"research_questions": [
"核心问题1",
"核心问题2",
"核心问题3"
],
"sub_topics": [
{{
"topic": "子主题1",
"explain": "子主题1的简单解释",
"priority": "high",
"related_questions": ["核心问题1", "核心问题2"]
}},
{{
"topic": "子主题2",
"explain": "子主题2的简单解释",
"priority": "medium",
"related_questions": ["核心问题2"]
}}
]
}}
```
注意
- 子主题数量建议3-6
- priority可选值high/medium/low
- 确保子主题覆盖所有核心问题
""",
# 5. Outline-validation search
"outline_validation": """
请评估这个研究大纲是否完整和合理
研究大纲
{outline}
请思考并搜索验证
1. 核心问题是否全面
2. 子主题划分是否合理
3. 是否有遗漏的重要方面
如果需要改进请提供具体建议
""",
# 6. Modify the outline
"modify_outline": """
基于用户反馈和验证结果请修改研究大纲
原大纲
{original_outline}
用户反馈
{user_feedback}
验证发现的问题
{validation_issues}
请输出修改后的大纲格式与原大纲相同
重点关注用户提出的修改意见
""",
# 8. Evaluate search results
"evaluate_search_results": """
请评估以下搜索结果对于研究子主题的重要性
子主题{subtopic}
搜索结果
标题{title}
URL{url}
摘要{snippet}
评估标准
1. 主题匹配度内容与子主题的相关程度
2. 问题覆盖度能否回答相关的核心问题
3. 信息新颖度是否提供了独特或深入的见解
请直接返回重要性级别high/medium/low
""",
# 9. Information reflection
"information_reflection": """
<think>
好的现在需要梳理已获取的信息
子主题{subtopic}
已获得信息总结
{search_summary}
让我再仔细思考总结一下是否有哪些信息非常重要需要更细节的内容
{detailed_analysis}
</think>
基于以上分析以下信息还需要进一步获取细节内容
""",
# 11. Structured information integration (expects JSON output)
"integrate_information": """
请将子主题的所有搜索结果整合为结构化信息
子主题{subtopic}
所有搜索结果
{all_search_results}
请按以下JSON格式输出整合后的信息
```json
{{
"key_points": [
{{
"point": "关键点描述",
"evidence": [
{{
"source_url": "https://example.com",
"confidence": "high"
}}
],
"contradictions": [],
"related_points": []
}}
],
"themes": [
{{
"theme": "主题归类",
"points": ["关键点1", "关键点2"]
}}
]
}}
```
""",
# 12. Subtopic report writing
"write_subtopic_report": """
请基于整合的信息撰写子主题研究报告
子主题{subtopic}
整合信息
{integrated_info}
撰写要求
1. 使用以下格式
2. 每个观点必须标注来源URL
3. 保持客观准确
4. 突出关键发现和洞察
格式要求
## [子主题名称]
### 一、[主要发现1]
#### 1.1 [子标题]
[内容]来源[具体URL]
#### 1.2 [子标题]
[内容]来源[具体URL]
### 二、[主要发现2]
#### 2.1 [子标题]
[内容]来源[具体URL]
### 三、关键洞察
1. **[洞察1]**基于[来源URL]的数据显示...
2. **[洞察2]**根据[来源URL]的分析...
### 四、建议与展望
[基于研究的可执行建议]
""",
# 13. Hallucination detection (expects JSON output)
"hallucination_detection": """
请检查撰写内容是否存在幻觉与原始来源不符
撰写内容
{written_content}
声称的来源URL{claimed_url}
原始搜索结果中的对应内容
{original_content}
请判断
1. 撰写内容是否准确反映了原始来源
2. 是否存在夸大错误归因或无中生有
如果存在幻觉请指出具体问题
返回格式
{{
"is_hallucination": true/false,
"hallucination_type": "夸大/错误归因/无中生有/无",
"explanation": "具体说明"
}}
""",
# 14. Rewrite hallucinated content (used by V3)
"rewrite_hallucination": """
请基于原始搜索材料重新撰写这部分内容
原始内容存在幻觉
{hallucinated_content}
原始搜索材料
{original_sources}
请严格基于搜索材料重新撰写确保准确性
保持原有的格式和风格
""",
# 15. Final report generation
"generate_final_report": """
请基于所有子主题报告生成最终的综合研究报告
研究主题{main_topic}
研究问题{research_questions}
各子主题报告
{subtopic_reports}
要求
1. 综合各子主题的发现
2. 提炼整体洞察
3. 保持URL引用格式
4. 提供可执行的建议
报告结构
# [研究主题]
## 执行摘要
[整体研究发现概述]
## 主要发现
### 1. [综合发现1]
基于多个来源的分析...来源[URL1], [URL2]
### 2. [综合发现2]
研究表明...来源[URL3], [URL4]
## 综合洞察
[基于所有研究的深度洞察]
## 建议
[具体可执行的建议]
## 详细子主题报告
[插入所有子主题的详细报告]
"""
}
# Search-related prompt templates (used by the V3 agent).
SEARCH_PROMPTS = {
"generate_search_queries": """
为以下子主题生成{count}个搜索查询
子主题{subtopic}
子主题说明{explanation}
相关问题{related_questions}
要求
1. 查询要具体有针对性
2. 覆盖不同角度
3. 使用不同的关键词组合
4. 每个查询独占一行
请直接返回搜索查询列表
""",
"generate_refined_queries": """
基于信息反思结果为以下重点生成细节搜索查询
重点信息{key_info}
需要的细节{detail_needed}
请生成3个针对性的搜索查询每个查询独占一行
"""
}
def get_prompt(prompt_name: str, **kwargs) -> str:
    """Look up a prompt template by name and fill in its placeholders.

    Both prompt tables are searched in order (PROMPTS first, then
    SEARCH_PROMPTS); the template is formatted with **kwargs.

    Raises:
        ValueError: if no table contains *prompt_name*.
    """
    for table in (PROMPTS, SEARCH_PROMPTS):
        template = table.get(prompt_name)
        if template is not None:
            return template.format(**kwargs)
    raise ValueError(f"Unknown prompt: {prompt_name}")

332
app/agents/r1_agent.py Normal file
View File

@ -0,0 +1,332 @@
"""
DeepSeek R1模型智能体 - 带调试功能
负责推理判断规划撰写等思考密集型任务
"""
import json
import logging
from typing import Dict, List, Any, Optional
from openai import OpenAI
from config import Config
from app.agents.prompts import get_prompt
from app.utils.json_parser import parse_json_safely
from app.utils.debug_logger import ai_debug_logger
logger = logging.getLogger(__name__)
class R1Agent:
    """DeepSeek R1 model agent (with debug logging).

    Handles the reasoning-heavy steps of the pipeline: question
    classification, refinement, planning, outline creation, information
    integration, report writing and hallucination detection.
    """

    def __init__(self, api_key: str = None):
        # Fall back to the configured key when none is passed explicitly.
        self.api_key = api_key or Config.DEEPSEEK_API_KEY
        base_url = Config.DEEPSEEK_BASE_URL
        # The Volcengine ARK platform uses its own model naming scheme.
        if 'volces.com' in base_url:
            self.model = "deepseek-r1-250120"  # R1 model name on Volcengine ARK
        else:
            self.model = Config.R1_MODEL
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=base_url
        )

    def _call_api(self, prompt: str, temperature: float = 0.7,
                  max_tokens: int = 4096, json_mode: bool = False) -> str:
        """Call the R1 API and return the response text.

        In ``json_mode`` the JSON body is extracted from a ```` ```json ````
        fence in the response. Every call (success or failure) is mirrored to
        ``ai_debug_logger``. Re-raises any client exception after logging.
        """
        try:
            messages = [{"role": "user", "content": prompt}]
            # For JSON output, use the assistant-prefill completion trick:
            # send everything up to the ```json fence and pre-open the fence
            # in an assistant message so the model continues with raw JSON.
            if json_mode and "```json" in prompt:
                prefix = prompt.split("```json")[0] + "```json\n"
                messages = [
                    {"role": "user", "content": prefix},
                    {"role": "assistant", "content": "```json\n"}
                ]
            logger.info(f"调用R1 API: temperature={temperature}, max_tokens={max_tokens}")
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content
            # Record the raw output in the debug log.
            ai_debug_logger.log_api_call(
                model=self.model,
                agent_type="R1",
                method=self._get_caller_method(),
                prompt=prompt,
                response=content,
                temperature=temperature,
                max_tokens=max_tokens,
                metadata={
                    "json_mode": json_mode,
                    "prompt_tokens": response.usage.prompt_tokens if hasattr(response, 'usage') else None,
                    "completion_tokens": response.usage.completion_tokens if hasattr(response, 'usage') else None
                }
            )
            # The chain-of-thought is no longer extracted separately; it stays
            # inside the raw output. In JSON mode, extract the JSON body.
            original_content = content
            if json_mode:
                if "```json" in content:
                    json_start = content.find("```json") + 7
                    json_end = content.find("```", json_start)
                    if json_end > json_start:
                        content = content[json_start:json_end].strip()
                elif content.startswith("```json\n"):
                    # NOTE(review): this branch looks unreachable — any string
                    # that starts with "```json\n" also matches the `in` test
                    # above. Kept as-is for the prefill-mode response.
                    content = content[8:]
                    if content.endswith("```"):
                        content = content[:-3]
                # Log when extraction actually changed the content.
                if content != original_content:
                    logger.debug(f"JSON提取: 原始长度={len(original_content)}, 提取后长度={len(content)}")
            return content.strip()
        except Exception as e:
            logger.error(f"R1 API调用失败: {e}")
            ai_debug_logger.log_api_call(
                model=self.model,
                agent_type="R1",
                method=self._get_caller_method(),
                prompt=prompt,
                response=f"ERROR: {str(e)}",
                temperature=temperature,
                max_tokens=max_tokens,
                metadata={"error": str(e)}
            )
            raise

    def _get_caller_method(self) -> str:
        """Best-effort name of the public method that invoked ``_call_api``.

        Walks two frames up (this method -> _call_api -> caller).
        """
        import inspect
        frame = inspect.currentframe()
        if frame and frame.f_back and frame.f_back.f_back:
            return frame.f_back.f_back.f_code.co_name
        return "unknown"

    def analyze_question_type(self, question: str) -> str:
        """Classify *question*; falls back to 'exploratory' on bad output."""
        prompt = get_prompt("question_type_analysis", question=question)
        result = self._call_api(prompt, temperature=0.3)
        # Validate the model's answer against the known type codes.
        valid_types = ["factual", "comparative", "exploratory", "decision"]
        result = result.lower().strip()
        if result not in valid_types:
            logger.warning(f"无效的问题类型: {result}默认使用exploratory")
            return "exploratory"
        return result

    def refine_questions(self, question: str, question_type: str) -> List[str]:
        """Ask the model for refinement questions; returns at most five."""
        prompt = get_prompt("refine_questions",
                            question=question,
                            question_type=question_type)
        result = self._call_api(prompt)
        # Parse the answer into a list, one question per line.
        questions = [q.strip() for q in result.split('\n') if q.strip()]
        # Strip any leading list numbering the model may have added.
        questions = [q.lstrip('0123456789.-) ') for q in questions]
        return questions[:5]  # at most five

    def create_research_approach(self, question: str, question_type: str,
                                 refined_questions: List[str]) -> str:
        """Draft the overall research approach from the refined questions."""
        refined_questions_text = '\n'.join(f"- {q}" for q in refined_questions)
        prompt = get_prompt("research_approach",
                            question=question,
                            question_type=question_type,
                            refined_questions=refined_questions_text)
        return self._call_api(prompt)

    def create_outline(self, question: str, question_type: str,
                       refined_questions: List[str], research_approach: str) -> Dict[str, Any]:
        """Create the research outline, retrying JSON parsing up to 3 times.

        Returns a minimal default outline when no attempt yields a parseable
        outline containing the required fields.
        """
        refined_questions_text = '\n'.join(f"- {q}" for q in refined_questions)
        prompt = get_prompt("create_outline",
                            question=question,
                            question_type=question_type,
                            refined_questions=refined_questions_text,
                            research_approach=research_approach)
        # Try up to three times to obtain a valid JSON outline.
        for attempt in range(3):
            try:
                result = self._call_api(prompt, temperature=0.5, json_mode=True)
                outline = parse_json_safely(result)
                # Validate required fields before accepting the outline.
                if all(key in outline for key in ["main_topic", "research_questions", "sub_topics"]):
                    return outline
                else:
                    logger.warning(f"大纲缺少必要字段,第{attempt+1}次尝试")
                    ai_debug_logger.log_json_parse_error(
                        result,
                        f"Missing required fields: {outline.keys()}",
                        None
                    )
            except Exception as e:
                logger.error(f"解析大纲失败,第{attempt+1}次尝试: {e}")
                # `result` may be unbound when _call_api itself raised.
                ai_debug_logger.log_json_parse_error(
                    result if 'result' in locals() else '',
                    str(e),
                    None
                )
        # All attempts failed: fall back to a default single-subtopic outline.
        default_outline = {
            "main_topic": question,
            "research_questions": refined_questions[:3],
            "sub_topics": [
                {
                    "topic": "主要方面分析",
                    "explain": "针对问题的核心方面进行深入分析",
                    "priority": "high",
                    "related_questions": refined_questions[:2]
                }
            ]
        }
        ai_debug_logger.log_json_parse_error(
            result if 'result' in locals() else '',
            "Failed to parse after 3 attempts, using default outline",
            json.dumps(default_outline, ensure_ascii=False)
        )
        return default_outline

    def validate_outline(self, outline: Dict[str, Any]) -> str:
        """Ask the model to assess completeness/soundness of *outline*."""
        prompt = get_prompt("outline_validation", outline=json.dumps(outline, ensure_ascii=False))
        return self._call_api(prompt)

    def modify_outline(self, original_outline: Dict[str, Any],
                       user_feedback: str, validation_issues: str) -> Dict[str, Any]:
        """Revise the outline from user feedback and validation findings."""
        prompt = get_prompt("modify_outline",
                            original_outline=json.dumps(original_outline, ensure_ascii=False),
                            user_feedback=user_feedback,
                            validation_issues=validation_issues)
        result = self._call_api(prompt, json_mode=True)
        return parse_json_safely(result)

    def evaluate_search_result(self, subtopic: str, title: str,
                               url: str, snippet: str) -> str:
        """Rate one search result's importance; defaults to 'medium'."""
        prompt = get_prompt("evaluate_search_results",
                            subtopic=subtopic,
                            title=title,
                            url=url,
                            snippet=snippet)
        result = self._call_api(prompt, temperature=0.3).lower().strip()
        # Validate the returned level.
        if result not in ["high", "medium", "low"]:
            return "medium"
        return result

    def reflect_on_information(self, subtopic: str, search_summary: str) -> List[Dict[str, str]]:
        """Reflect on gathered info; return points needing deeper search.

        Each returned dict has ``key_info`` and ``detail_needed`` keys.
        """
        # A richer analysis could be derived from search_summary here.
        prompt = get_prompt("information_reflection",
                            subtopic=subtopic,
                            search_summary=search_summary,
                            detailed_analysis="[基于搜索结果的详细分析]")
        result = self._call_api(prompt)
        # Parse the response for lines marked as needing further search.
        # Simple implementation; a more robust parser may be needed.
        key_points = []
        lines = result.split('\n')
        for line in lines:
            if line.strip() and '还需要搜索' in line:
                parts = line.split('还需要搜索')
                if len(parts) == 2:
                    key_points.append({
                        "key_info": parts[0].strip(),
                        "detail_needed": parts[1].strip('() ')
                    })
        return key_points

    def integrate_information(self, subtopic: str, all_search_results: str) -> Dict[str, Any]:
        """Integrate all search results into the structured JSON format."""
        prompt = get_prompt("integrate_information",
                            subtopic=subtopic,
                            all_search_results=all_search_results)
        result = self._call_api(prompt, json_mode=True)
        parsed = parse_json_safely(result)
        # Guarantee the required keys exist for downstream consumers.
        if 'key_points' not in parsed:
            parsed['key_points'] = []
        if 'themes' not in parsed:
            parsed['themes'] = []
        return parsed

    def write_subtopic_report(self, subtopic: str, integrated_info: Dict[str, Any]) -> str:
        """Write the subtopic report from the integrated information."""
        prompt = get_prompt("write_subtopic_report",
                            subtopic=subtopic,
                            integrated_info=json.dumps(integrated_info, ensure_ascii=False))
        return self._call_api(prompt, temperature=0.7, max_tokens=8192)

    def detect_hallucination(self, written_content: str, claimed_url: str,
                             original_content: str) -> Dict[str, Any]:
        """Check written content against its claimed source.

        Returns a dict with ``is_hallucination``, ``hallucination_type`` and
        ``explanation`` keys (defaults filled in when missing).
        """
        prompt = get_prompt("hallucination_detection",
                            written_content=written_content,
                            claimed_url=claimed_url,
                            original_content=original_content)
        result = self._call_api(prompt, temperature=0.3, json_mode=True)
        parsed = parse_json_safely(result)
        # Normalise the structure for callers.
        if 'is_hallucination' not in parsed:
            parsed['is_hallucination'] = False
        if 'hallucination_type' not in parsed:
            parsed['hallucination_type'] = None
        if 'explanation' not in parsed:
            parsed['explanation'] = ''
        return parsed

    def generate_final_report(self, main_topic: str, research_questions: List[str],
                              subtopic_reports: Dict[str, str]) -> str:
        """Generate the final combined report from all subtopic reports."""
        # Format the subtopic reports as one delimited document.
        reports_text = "\n\n---\n\n".join([
            f"### {topic}\n{report}"
            for topic, report in subtopic_reports.items()
        ])
        prompt = get_prompt("generate_final_report",
                            main_topic=main_topic,
                            research_questions='\n'.join(f"- {q}" for q in research_questions),
                            subtopic_reports=reports_text)
        return self._call_api(prompt, temperature=0.7, max_tokens=16384)

252
app/agents/v3_agent.py Normal file
View File

@ -0,0 +1,252 @@
"""
DeepSeek V3模型智能体 - 带调试功能
负责API调用内容重写等执行型任务
"""
import json
import logging
from typing import Dict, List, Any, Optional
from openai import OpenAI
from config import Config
from app.agents.prompts import get_prompt
from app.utils.debug_logger import ai_debug_logger
logger = logging.getLogger(__name__)
class V3Agent:
    """DeepSeek V3 model agent (with debug logging).

    Handles execution-oriented tasks: search-query generation, content
    rewriting, summarisation and (optional) function calling.
    """

    def __init__(self, api_key: str = None):
        # Fall back to the configured key when none is passed explicitly.
        self.api_key = api_key or Config.DEEPSEEK_API_KEY
        base_url = Config.DEEPSEEK_BASE_URL
        # The Volcengine ARK platform uses its own model naming scheme.
        if 'volces.com' in base_url:
            self.model = "deepseek-v3-241226"  # V3 model name on Volcengine ARK
        else:
            self.model = Config.V3_MODEL
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=base_url
        )

    def _call_api(self, prompt: str, temperature: float = 0.3,
                  max_tokens: int = 4096, functions: List[Dict] = None) -> Any:
        """Call the V3 API.

        Returns the response text, or — when *functions* is given and the
        model chose to call one — a dict of the form
        ``{"function_call": {"name": ..., "arguments": {...}}}``.
        Re-raises any client exception after logging it.
        """
        try:
            messages = [{"role": "user", "content": prompt}]
            kwargs = {
                "model": self.model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            # Enable function calling when function specs were supplied.
            # NOTE(review): ``functions``/``function_call`` is the legacy
            # OpenAI chat API; newer clients use ``tools`` — confirm the
            # installed client/provider still accepts it.
            if functions:
                kwargs["functions"] = functions
                kwargs["function_call"] = "auto"
            logger.info(f"调用V3 API: temperature={temperature}, max_tokens={max_tokens}, functions={bool(functions)}")
            response = self.client.chat.completions.create(**kwargs)
            # Build the result and a serialisable form for the debug log.
            if functions and response.choices[0].message.function_call:
                result = {
                    "function_call": {
                        "name": response.choices[0].message.function_call.name,
                        "arguments": json.loads(response.choices[0].message.function_call.arguments)
                    }
                }
                response_text = json.dumps(result, ensure_ascii=False)
            else:
                result = response.choices[0].message.content.strip()
                response_text = result
            # Record the call in the debug log.
            ai_debug_logger.log_api_call(
                model=self.model,
                agent_type="V3",
                method=self._get_caller_method(),
                prompt=prompt,
                response=response_text,
                temperature=temperature,
                max_tokens=max_tokens,
                metadata={
                    "has_functions": bool(functions),
                    "function_count": len(functions) if functions else 0,
                    "prompt_tokens": response.usage.prompt_tokens if hasattr(response, 'usage') else None,
                    "completion_tokens": response.usage.completion_tokens if hasattr(response, 'usage') else None
                }
            )
            return result
        except Exception as e:
            logger.error(f"V3 API调用失败: {e}")
            ai_debug_logger.log_api_call(
                model=self.model,
                agent_type="V3",
                method=self._get_caller_method(),
                prompt=prompt,
                response=f"ERROR: {str(e)}",
                temperature=temperature,
                max_tokens=max_tokens,
                metadata={"error": str(e)}
            )
            raise

    def _get_caller_method(self) -> str:
        """Best-effort name of the public method that invoked ``_call_api``."""
        import inspect
        frame = inspect.currentframe()
        # Two frames up: _get_caller_method -> _call_api -> caller.
        if frame and frame.f_back and frame.f_back.f_back:
            return frame.f_back.f_back.f_code.co_name
        return "unknown"

    def generate_search_queries(self, subtopic: str, explanation: str,
                                related_questions: List[str], count: int) -> List[str]:
        """Generate up to *count* search queries for a subtopic."""
        prompt = get_prompt("generate_search_queries",
                            subtopic=subtopic,
                            explanation=explanation,
                            related_questions=', '.join(related_questions),
                            count=count)
        result = self._call_api(prompt, temperature=0.7)
        # Parse the answer into a list, one query per line.
        queries = [q.strip() for q in result.split('\n') if q.strip()]
        # Strip any leading numbering.
        queries = [q.lstrip('0123456789.-) ') for q in queries]
        # Log how many queries were parsed.
        logger.debug(f"生成了{len(queries)}个搜索查询")
        return queries[:count]

    def generate_refined_queries(self, key_info: str, detail_needed: str) -> List[str]:
        """Generate up to three follow-up queries for a reflection point."""
        prompt = get_prompt("generate_refined_queries",
                            key_info=key_info,
                            detail_needed=detail_needed)
        result = self._call_api(prompt, temperature=0.7)
        queries = [q.strip() for q in result.split('\n') if q.strip()]
        queries = [q.lstrip('0123456789.-) ') for q in queries]
        logger.debug(f"'{key_info}'生成了{len(queries)}个细化查询")
        return queries[:3]

    def rewrite_hallucination(self, hallucinated_content: str,
                              original_sources: str) -> str:
        """Rewrite hallucinated content strictly from the original sources."""
        prompt = get_prompt("rewrite_hallucination",
                            hallucinated_content=hallucinated_content,
                            original_sources=original_sources)
        rewritten = self._call_api(prompt, temperature=0.3)
        # Log the correction for traceability.
        logger.info(f"修正幻觉内容: 原始长度={len(hallucinated_content)}, 修正后长度={len(rewritten)}")
        return rewritten

    def call_tavily_search(self, query: str, max_results: int = 10) -> Dict[str, Any]:
        """Derive Tavily search parameters via function calling.

        Example implementation only — the real Tavily call lives in
        search_service.py. Falls back to default parameters when the model
        does not emit a function call.
        """
        # Function spec describing the Tavily search tool.
        tavily_function = {
            "name": "tavily_search",
            "description": "Search the web using Tavily API",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query"
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Maximum number of results",
                        "default": 10
                    },
                    "search_depth": {
                        "type": "string",
                        "enum": ["basic", "advanced"],
                        "default": "advanced"
                    }
                },
                "required": ["query"]
            }
        }
        prompt = f"Please search for information about: {query}"
        result = self._call_api(prompt, functions=[tavily_function])
        # If a function call came back, use its arguments.
        if isinstance(result, dict) and "function_call" in result:
            return result["function_call"]["arguments"]
        # Otherwise return default parameters.
        return {
            "query": query,
            "max_results": max_results,
            "search_depth": "advanced"
        }

    def format_search_results(self, results: List[Dict[str, Any]]) -> str:
        """Format search-result dicts as a numbered, human-readable list."""
        formatted = []
        for i, result in enumerate(results, 1):
            formatted.append(f"{i}. 标题: {result.get('title', 'N/A')}")
            formatted.append(f" URL: {result.get('url', 'N/A')}")
            formatted.append(f" 摘要: {result.get('snippet', 'N/A')}")
            if result.get('score'):
                formatted.append(f" 相关度: {result.get('score', 0):.2f}")
            formatted.append("")
        return '\n'.join(formatted)

    def extract_key_points(self, text: str, max_points: int = 5) -> List[str]:
        """Extract up to *max_points* key points from *text*."""
        prompt = f"""
请从以下文本中提取最多{max_points}个关键点
{text}
每个关键点独占一行简洁明了
"""
        result = self._call_api(prompt, temperature=0.5)
        points = [p.strip() for p in result.split('\n') if p.strip()]
        points = [p.lstrip('0123456789.-) ') for p in points]
        logger.debug(f"从文本中提取了{len(points)}个关键点")
        return points[:max_points]

    def summarize_content(self, content: str, max_length: int = 200) -> str:
        """Summarise *content* to at most *max_length* characters."""
        prompt = f"""
请将以下内容总结为不超过{max_length}字的摘要
{content}
要求保留关键信息语言流畅
"""
        summary = self._call_api(prompt, temperature=0.5)
        logger.debug(f"内容总结: 原始长度={len(content)}, 摘要长度={len(summary)}")
        return summary

0
app/models/__init__.py Normal file
View File

171
app/models/report.py Normal file
View File

@ -0,0 +1,171 @@
"""
研究报告数据模型
"""
import uuid
from datetime import datetime
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
class ReportSection(BaseModel):
    """One report section, possibly with nested subsections."""
    title: str
    content: str
    # Nested subsections (recursive); pydantic copies these defaults per instance.
    subsections: List['ReportSection'] = []
    sources: List[str] = []  # list of source URLs

    def to_markdown(self, level: int = 1) -> str:
        """Render this section (and its subsections) as Markdown.

        Args:
            level: heading depth for this section's title.
        """
        header = "#" * level
        markdown = f"{header} {self.title}\n\n{self.content}\n\n"
        # Append subsections one heading level deeper.
        for subsection in self.subsections:
            markdown += subsection.to_markdown(level + 1)
        # Append the source list, if any.
        if self.sources:
            markdown += f"\n{'#' * (level + 1)} 参考来源\n"
            for i, source in enumerate(self.sources, 1):
                markdown += f"{i}. [{source}]({source})\n"
            markdown += "\n"
        return markdown


# Resolve the recursive 'ReportSection' forward reference.
ReportSection.model_rebuild()
class KeyInsight(BaseModel):
    """A key insight with its supporting evidence and sources."""
    insight: str
    supporting_evidence: List[str] = []
    source_urls: List[str] = []
    confidence: float = 0.0  # confidence score in [0, 1]
class SubtopicReport(BaseModel):
    """Research report for a single subtopic."""
    subtopic_id: str
    subtopic_name: str
    sections: List[ReportSection] = []
    key_insights: List[KeyInsight] = []
    recommendations: List[str] = []
    created_at: datetime = Field(default_factory=datetime.now)
    word_count: int = 0

    def to_markdown(self) -> str:
        """Render the subtopic report as Markdown (sections, insights, advice)."""
        markdown = f"## {self.subtopic_name}\n\n"
        # Sections, rendered at heading level 3.
        for section in self.sections:
            markdown += section.to_markdown(level=3)
        # Key insights with their evidence and numbered source links.
        if self.key_insights:
            markdown += "### 关键洞察\n\n"
            for i, insight in enumerate(self.key_insights, 1):
                markdown += f"{i}. **{insight.insight}**\n"
                if insight.supporting_evidence:
                    for evidence in insight.supporting_evidence:
                        markdown += f" - {evidence}\n"
                if insight.source_urls:
                    markdown += f" - 来源: "
                    # Comprehension-scoped i/url do not clobber the outer loop's i.
                    markdown += ", ".join([f"[{i+1}]({url})" for i, url in enumerate(insight.source_urls)])
                    markdown += "\n"
                markdown += "\n"
        # Recommendations.
        if self.recommendations:
            markdown += "### 建议与展望\n\n"
            for recommendation in self.recommendations:
                markdown += f"- {recommendation}\n"
            markdown += "\n"
        return markdown
class HallucinationCheck(BaseModel):
    """Record of one hallucination check against a claimed source."""
    content: str
    source_url: str
    original_text: Optional[str] = None
    is_hallucination: bool = False
    # Category of hallucination: exaggeration / misattribution / fabrication.
    hallucination_type: Optional[str] = None
    corrected_content: Optional[str] = None
    checked_at: datetime = Field(default_factory=datetime.now)
class FinalReport(BaseModel):
    """The final combined research report for one session."""
    session_id: str
    title: str
    executive_summary: str
    main_findings: List[ReportSection] = []
    subtopic_reports: List[SubtopicReport] = []
    overall_insights: List[KeyInsight] = []
    recommendations: List[str] = []
    methodology: Optional[str] = None
    limitations: List[str] = []
    created_at: datetime = Field(default_factory=datetime.now)
    total_sources: int = 0
    total_searches: int = 0

    def to_markdown(self) -> str:
        """Render the complete report as a single Markdown document."""
        markdown = f"# {self.title}\n\n"
        markdown += f"*生成时间: {self.created_at.strftime('%Y-%m-%d %H:%M:%S')}*\n\n"
        # Executive summary.
        markdown += "## 执行摘要\n\n"
        markdown += f"{self.executive_summary}\n\n"
        # Main findings.
        if self.main_findings:
            markdown += "## 主要发现\n\n"
            for finding in self.main_findings:
                markdown += finding.to_markdown(level=3)
        # Overall insights.
        if self.overall_insights:
            markdown += "## 综合洞察\n\n"
            for i, insight in enumerate(self.overall_insights, 1):
                markdown += f"### {i}. {insight.insight}\n\n"
                if insight.supporting_evidence:
                    for evidence in insight.supporting_evidence:
                        markdown += f"- {evidence}\n"
                    markdown += "\n"
        # Recommendations.
        if self.recommendations:
            markdown += "## 建议\n\n"
            for recommendation in self.recommendations:
                markdown += f"- {recommendation}\n"
            markdown += "\n"
        # Detailed subtopic reports, separated by horizontal rules.
        markdown += "## 详细分析\n\n"
        for report in self.subtopic_reports:
            markdown += report.to_markdown()
            markdown += "---\n\n"
        # Methodology.
        if self.methodology:
            markdown += "## 研究方法\n\n"
            markdown += f"{self.methodology}\n\n"
        # Limitations.
        if self.limitations:
            markdown += "## 研究局限性\n\n"
            for limitation in self.limitations:
                markdown += f"- {limitation}\n"
            markdown += "\n"
        # Statistics footer.
        markdown += "## 研究统计\n\n"
        markdown += f"- 总搜索次数: {self.total_searches}\n"
        markdown += f"- 引用来源数: {self.total_sources}\n"
        markdown += f"- 分析子主题数: {len(self.subtopic_reports)}\n"
        return markdown

    def save_to_file(self, filepath: str):
        """Write the Markdown rendering of this report to *filepath*."""
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(self.to_markdown())

126
app/models/research.py Normal file
View File

@ -0,0 +1,126 @@
"""
研究会话数据模型
"""
import json
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field
class QuestionType(str, Enum):
    """Question-type enum (values match the model's classification codes)."""
    FACTUAL = "factual"            # factual lookup: needs specific, accurate info
    COMPARATIVE = "comparative"    # analysis/comparison: multiple angles
    EXPLORATORY = "exploratory"    # exploration: broad discovery of unknowns
    DECISION = "decision"          # decision support: synthesis for a choice
class ResearchStatus(str, Enum):
    """Lifecycle states of a research session or subtopic."""
    PENDING = "pending"
    ANALYZING = "analyzing"
    OUTLINING = "outlining"
    RESEARCHING = "researching"
    WRITING = "writing"
    REVIEWING = "reviewing"
    COMPLETED = "completed"
    ERROR = "error"
    CANCELLED = "cancelled"
class SubtopicPriority(str, Enum):
    """Priority of a subtopic within the research outline."""
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
class Subtopic(BaseModel):
    """One subtopic of the research outline, with its search bookkeeping."""
    # Short random id of the form "ST" + 8 hex chars.
    id: str = Field(default_factory=lambda: f"ST{uuid.uuid4().hex[:8]}")
    topic: str
    explain: str
    priority: SubtopicPriority
    related_questions: List[str] = []
    status: ResearchStatus = ResearchStatus.PENDING
    search_count: int = 0
    max_searches: int = 15  # cap on searches for this subtopic
    searches: List[Dict[str, Any]] = []
    refined_searches: List[Dict[str, Any]] = []
    integrated_info: Optional[Dict[str, Any]] = None
    report: Optional[str] = None
    hallucination_checks: List[Dict[str, Any]] = []

    def get_total_searches(self) -> int:
        """Total of initial plus refined searches performed."""
        return len(self.searches) + len(self.refined_searches)
class ResearchOutline(BaseModel):
    """Research outline: main topic, core questions and subtopics."""
    main_topic: str
    research_questions: List[str]
    sub_topics: List[Subtopic]
    version: int = 1  # incremented on outline revisions
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
class ResearchSession(BaseModel):
    """One end-to-end research session, with JSON-file persistence."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    question: str
    question_type: Optional[QuestionType] = None
    refined_questions: List[str] = []
    research_approach: Optional[str] = None
    status: ResearchStatus = ResearchStatus.PENDING
    outline: Optional[ResearchOutline] = None
    final_report: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
    completed_at: Optional[datetime] = None
    error_message: Optional[str] = None
    # Progress tracking
    total_steps: int = 0
    completed_steps: int = 0
    current_phase: str = "初始化"

    def update_status(self, status: ResearchStatus):
        """Set *status*, refresh updated_at, stamp completed_at when done."""
        self.status = status
        self.updated_at = datetime.now()
        if status == ResearchStatus.COMPLETED:
            self.completed_at = datetime.now()

    def get_progress_percentage(self) -> float:
        """Completed-step percentage (0.0 when no steps are planned)."""
        if self.total_steps == 0:
            return 0.0
        return (self.completed_steps / self.total_steps) * 100

    def to_dict(self) -> Dict[str, Any]:
        """Serialise to a plain dict with ISO-format datetime strings."""
        # NOTE(review): ``self.dict()`` is the pydantic-v1 API; report.py uses
        # the v2 ``model_rebuild`` — under v2 this is deprecated in favour of
        # ``model_dump()``. Works today; confirm the pinned pydantic version.
        data = self.dict()
        # Convert datetime objects to strings for JSON serialisation.
        for key in ['created_at', 'updated_at', 'completed_at']:
            if data.get(key):
                data[key] = data[key].isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ResearchSession':
        """Build an instance from a dict produced by :meth:`to_dict`."""
        # Convert ISO strings back to datetime objects.
        for key in ['created_at', 'updated_at', 'completed_at']:
            if data.get(key) and isinstance(data[key], str):
                data[key] = datetime.fromisoformat(data[key])
        return cls(**data)

    def save_to_file(self, filepath: str):
        """Persist this session as pretty-printed UTF-8 JSON."""
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)

    @classmethod
    def load_from_file(cls, filepath: str) -> 'ResearchSession':
        """Load a session previously written by :meth:`save_to_file`."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)

119
app/models/search_result.py Normal file
View File

@ -0,0 +1,119 @@
"""
搜索结果数据模型
"""
import uuid
from datetime import datetime
from enum import Enum
from typing import List, Optional, Dict, Any

from pydantic import BaseModel, Field
class SearchImportance(str, Enum):
    """Importance rating assigned to a search result during evaluation."""
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
class SearchResult(BaseModel):
    """A single web search hit; hashable and comparable by URL for deduplication."""
    title: str
    url: str
    snippet: str
    score: float = 0.0
    published_date: Optional[str] = None
    raw_content: Optional[str] = None
    # Filled in by the evaluation step; None until evaluated.
    importance: Optional[SearchImportance] = None
    key_findings: List[str] = []

    def __hash__(self):
        """Hash by URL so results can be deduplicated via sets."""
        return hash(self.url)

    def __eq__(self, other):
        """URL-based equality.

        Fix: return NotImplemented (rather than False) for non-SearchResult
        operands, per the data-model contract, so Python may try the reflected
        comparison; `result == other_type` still evaluates to False.
        """
        if isinstance(other, SearchResult):
            return self.url == other.url
        return NotImplemented
class TavilySearchResponse(BaseModel):
    """Raw response payload returned by the Tavily search API."""
    query: str
    answer: Optional[str] = None
    images: List[str] = []
    results: List[Dict[str, Any]] = []
    response_time: float = 0.0

    def to_search_results(self) -> List[SearchResult]:
        """Map each raw result dict onto a SearchResult instance."""
        return [
            SearchResult(
                title=item.get('title', ''),
                url=item.get('url', ''),
                snippet=item.get('content', ''),
                score=item.get('score', 0.0),
                published_date=item.get('published_date'),
                raw_content=item.get('raw_content'),
            )
            for item in self.results
        ]
class SearchBatch(BaseModel):
    """One search request for a subtopic plus its deduplicated results."""
    # Short random id, e.g. "S1a2b3c4d" (uses the module-level `uuid` import).
    search_id: str = Field(default_factory=lambda: f"S{uuid.uuid4().hex[:8]}")
    subtopic_id: str
    query: str
    timestamp: datetime = Field(default_factory=datetime.now)
    results: List[SearchResult] = []
    # True when this batch came from the reflection/refinement step.
    is_refined_search: bool = False
    parent_search_id: Optional[str] = None
    detail_type: Optional[str] = None
    total_results: int = 0
    def add_results(self, results: List[SearchResult]):
        """Append results, skipping URLs already present, and refresh the count."""
        existing_urls = {r.url for r in self.results}
        for result in results:
            if result.url not in existing_urls:
                self.results.append(result)
                existing_urls.add(result.url)
        self.total_results = len(self.results)
class SearchSummary(BaseModel):
    """Aggregate statistics over all search batches of one subtopic."""
    subtopic_id: str
    total_searches: int = 0
    total_results: int = 0
    high_importance_count: int = 0
    medium_importance_count: int = 0
    low_importance_count: int = 0
    unique_domains: List[str] = []

    @classmethod
    def from_search_batches(cls, subtopic_id: str, batches: List[SearchBatch]) -> 'SearchSummary':
        """Build a summary from a subtopic's search batches.

        Counts searches/results, collects unique result domains, and tallies
        results per importance level (unrated results are not counted).
        """
        # Fix: import hoisted out of the per-result inner loop.
        from urllib.parse import urlparse

        summary = cls(subtopic_id=subtopic_id)
        summary.total_searches = len(batches)
        all_results = []
        domains = set()
        for batch in batches:
            all_results.extend(batch.results)
            for result in batch.results:
                domain = urlparse(result.url).netloc
                if domain:
                    domains.add(domain)
        summary.total_results = len(all_results)
        summary.unique_domains = list(domains)
        # Tally by importance level.
        for result in all_results:
            if result.importance == SearchImportance.HIGH:
                summary.high_importance_count += 1
            elif result.importance == SearchImportance.MEDIUM:
                summary.medium_importance_count += 1
            elif result.importance == SearchImportance.LOW:
                summary.low_importance_count += 1
        return summary

0
app/routes/__init__.py Normal file
View File

427
app/routes/api.py Normal file
View File

@ -0,0 +1,427 @@
# 文件位置: app/routes/api.py
# 文件名: api.py
"""
API路由
处理研究相关的API请求
"""
from flask import Blueprint, request, jsonify, current_app, send_file
from app.services.research_manager import ResearchManager
from app.services.task_manager import task_manager
from app.utils.validators import validate_question, validate_outline_feedback
import os
api_bp = Blueprint('api', __name__)
research_manager = ResearchManager()
@api_bp.route('/research', methods=['POST'])
def create_research():
    """Create a new research task.

    Body: ``{"question": str, "auto_start": bool (default True)}``.
    Returns 400 for an invalid/missing question, 500 on internal errors.
    """
    try:
        # Fix: get_json() returns None for a missing or non-JSON body;
        # silent=True plus the `or {}` fallback turns that into a clean 400
        # from the validator instead of an AttributeError that became a 500.
        data = request.get_json(silent=True) or {}

        question = data.get('question', '').strip()
        error = validate_question(question)
        if error:
            return jsonify({"error": error}), 400

        session = research_manager.create_session(question)

        # Optionally kick off the research immediately.
        auto_start = data.get('auto_start', True)
        if auto_start:
            research_manager.start_research(session.id)
            return jsonify({
                "session_id": session.id,
                "status": "started",
                "message": "研究已开始",
                "created_at": session.created_at.isoformat()
            })
        return jsonify({
            "session_id": session.id,
            "status": "created",
            "message": "研究会话已创建,等待开始",
            "created_at": session.created_at.isoformat()
        })
    except Exception as e:
        current_app.logger.error(f"创建研究失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/start', methods=['POST'])
def start_research(session_id):
    """Manually start a research session (for sessions created with auto_start=False)."""
    try:
        result = research_manager.start_research(session_id)
        return jsonify(result)
    except ValueError as e:
        # A ValueError from the manager is mapped to 404 (unknown session).
        return jsonify({"error": str(e)}), 404
    except Exception as e:
        current_app.logger.error(f"开始研究失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/status', methods=['GET'])
def get_research_status(session_id):
    """Return session status enriched with the session's background tasks."""
    try:
        status = research_manager.get_session_status(session_id)
        if "error" in status:
            # The manager reports unknown sessions via an "error" key.
            return jsonify(status), 404
        # Attach task info from the in-process task manager.
        tasks = task_manager.get_session_tasks(session_id)
        status['tasks'] = tasks
        return jsonify(status)
    except Exception as e:
        current_app.logger.error(f"获取状态失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/outline', methods=['GET'])
def get_research_outline(session_id):
    """Return the research outline (topic, questions, subtopics) for a session."""
    try:
        session = research_manager.get_session(session_id)
        if not session:
            return jsonify({"error": "Session not found"}), 404
        if not session.outline:
            return jsonify({"error": "Outline not yet created"}), 400

        outline = session.outline
        subtopics = [
            {
                "id": st.id,
                "topic": st.topic,
                "explain": st.explain,
                "priority": st.priority,
                "status": st.status,
            }
            for st in outline.sub_topics
        ]
        return jsonify({
            "main_topic": outline.main_topic,
            "research_questions": outline.research_questions,
            "sub_topics": subtopics,
            "version": outline.version,
        })
    except Exception as e:
        current_app.logger.error(f"获取大纲失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/outline', methods=['PUT'])
def update_research_outline(session_id):
    """Accept user feedback on the outline.

    NOTE: the actual outline rewrite is not implemented yet — the feedback
    is validated and acknowledged only.
    """
    try:
        data = request.get_json()
        feedback = data.get('feedback', '').strip()
        error = validate_outline_feedback(feedback)
        if error:
            return jsonify({"error": error}), 400
        # TODO: implement the outline update; requires an AI-service call
        # to rewrite the outline from this feedback.
        return jsonify({
            "message": "大纲更新请求已接收",
            "status": "processing"
        })
    except Exception as e:
        current_app.logger.error(f"更新大纲失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/cancel', methods=['POST'])
def cancel_research(session_id):
    """Cancel a running research: stop its background tasks, then flag the session."""
    try:
        # Cancel queued/running tasks first so no new work starts.
        cancelled_count = task_manager.cancel_session_tasks(session_id)
        # Then update the session state.
        result = research_manager.cancel_research(session_id)
        if "error" in result:
            return jsonify(result), 404
        result['cancelled_tasks'] = cancelled_count
        return jsonify(result)
    except Exception as e:
        current_app.logger.error(f"取消研究失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/report', methods=['GET'])
def get_research_report(session_id):
    """Return the research report.

    Query param ``format``: ``json`` (default, report inlined in the response)
    or ``markdown`` (report sent as a downloadable .md file).
    """
    try:
        # Fix: renamed from `format`, which shadowed the builtin.
        report_format = request.args.get('format', 'json')
        report_content = research_manager.get_research_report(session_id)
        if not report_content:
            return jsonify({"error": "Report not available"}), 404

        if report_format == 'markdown':
            report_path = os.path.join(
                current_app.config['REPORTS_DIR'],
                f"{session_id}.md"
            )
            # Fix: both original branches sent the identical response, so
            # materialize the file only when missing and send it once.
            if not os.path.exists(report_path):
                with open(report_path, 'w', encoding='utf-8') as f:
                    f.write(report_content)
            return send_file(
                report_path,
                mimetype='text/markdown',
                as_attachment=True,
                download_name=f"research_report_{session_id}.md"
            )

        # Default: inline JSON response.
        return jsonify({
            "session_id": session_id,
            "report": report_content,
            "format": "markdown"
        })
    except Exception as e:
        current_app.logger.error(f"获取报告失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/subtopic/<subtopic_id>', methods=['GET'])
def get_subtopic_detail(session_id, subtopic_id):
    """Return detail and search progress for a single subtopic."""
    try:
        session = research_manager.get_session(session_id)
        if not session:
            return jsonify({"error": "Session not found"}), 404
        if not session.outline:
            return jsonify({"error": "Outline not created"}), 400

        # Locate the subtopic by id.
        subtopic = next(
            (st for st in session.outline.sub_topics if st.id == subtopic_id),
            None,
        )
        if not subtopic:
            return jsonify({"error": "Subtopic not found"}), 404

        progress = subtopic.get_total_searches() / subtopic.max_searches * 100
        return jsonify({
            "id": subtopic.id,
            "topic": subtopic.topic,
            "explain": subtopic.explain,
            "priority": subtopic.priority,
            "status": subtopic.status,
            "search_count": subtopic.search_count,
            "max_searches": subtopic.max_searches,
            "progress": progress,
            "has_report": subtopic.report is not None,
        })
    except Exception as e:
        current_app.logger.error(f"获取子主题详情失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/sessions', methods=['GET'])
def list_research_sessions():
    """List research sessions with limit/offset paging.

    NOTE(review): "total" is the size of the returned page, not the overall
    session count — confirm whether clients expect the latter.
    """
    try:
        limit = request.args.get('limit', 20, type=int)
        offset = request.args.get('offset', 0, type=int)
        sessions = research_manager.list_sessions(limit=limit, offset=offset)
        return jsonify({
            "sessions": sessions,
            "total": len(sessions),
            "limit": limit,
            "offset": offset
        })
    except Exception as e:
        current_app.logger.error(f"列出会话失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/tasks/status', methods=['GET'])
def get_tasks_status():
    """Return aggregate statistics from the in-process task manager."""
    try:
        tasks = task_manager.tasks
        status_counts = {
            'pending': 0,
            'running': 0,
            'completed': 0,
            'failed': 0,
            'cancelled': 0
        }
        for task in tasks.values():
            # Fix: count unexpected status values too instead of raising
            # KeyError (which turned the whole endpoint into a 500).
            key = task.status.value
            status_counts[key] = status_counts.get(key, 0) + 1
        return jsonify({
            "total_tasks": len(tasks),
            "status_counts": status_counts,
            "sessions_count": len(task_manager.session_tasks)
        })
    except Exception as e:
        current_app.logger.error(f"获取任务状态失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/test/connections', methods=['GET'])
def test_connections():
    """Smoke-test external dependencies (development mode only).

    Checks the DeepSeek API, the Tavily API and the in-process task manager,
    returning per-service booleans plus an aggregate "all_connected" flag.
    """
    if not current_app.debug:
        return jsonify({"error": "Not available in production"}), 403
    from app.services.search_service import SearchService
    from app.services.ai_service import AIService
    results = {
        "deepseek_api": False,
        "tavily_api": False,
        "task_manager": False
    }
    try:
        # DeepSeek: any truthy analysis result counts as reachable.
        ai_service = AIService()
        test_result = ai_service.analyze_question_type("test question")
        results["deepseek_api"] = bool(test_result)
    except Exception as e:
        current_app.logger.error(f"DeepSeek API测试失败: {e}")
    try:
        # Tavily: the search service exposes an explicit connectivity check.
        search_service = SearchService()
        results["tavily_api"] = search_service.test_connection()
    except Exception as e:
        current_app.logger.error(f"Tavily API测试失败: {e}")
    # Task manager: submit a trivial task and confirm it is tracked.
    try:
        def test_task():
            return "test"
        task_id = task_manager.submit_task(test_task)
        status = task_manager.get_task_status(task_id)
        results["task_manager"] = status is not None
    except Exception as e:
        current_app.logger.error(f"任务管理器测试失败: {e}")
    return jsonify({
        "connections": results,
        "all_connected": all(results.values())
    })
# Appended to the end of app/routes/api.py
@api_bp.route('/research/<session_id>/debug', methods=['GET'])
def get_debug_logs(session_id):
    """Return debug logs for a session.

    Query params: ``type`` (all|api_calls|thinking|errors, default "all")
    and ``limit`` (default 100; values <= 0 disable truncation).
    """
    try:
        from app.utils.debug_logger import ai_debug_logger
        log_type = request.args.get('type', 'all')  # all, api_calls, thinking, errors
        limit = request.args.get('limit', 100, type=int)
        # Ensure the session exists before reading its logs.
        session = research_manager.get_session(session_id)
        if not session:
            return jsonify({"error": "Session not found"}), 404
        logs = ai_debug_logger.get_session_logs(session_id, log_type)
        if limit > 0:
            logs = logs[-limit:]  # keep only the newest N entries
        return jsonify({
            "session_id": session_id,
            "log_type": log_type,
            "count": len(logs),
            "logs": logs
        })
    except Exception as e:
        current_app.logger.error(f"获取调试日志失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/debug/download', methods=['GET'])
def download_debug_logs(session_id):
    """Bundle a session's debug files into an in-memory ZIP and send it."""
    try:
        from app.utils.debug_logger import ai_debug_logger
        import zipfile
        import io
        # Validate the session first.
        session = research_manager.get_session(session_id)
        if not session:
            return jsonify({"error": "Session not found"}), 404
        # Build the archive entirely in memory; nothing is written to disk.
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
            # Add every regular file from the session's debug directory.
            debug_dir = os.path.join(ai_debug_logger.debug_dir, session_id)
            if os.path.exists(debug_dir):
                for filename in os.listdir(debug_dir):
                    file_path = os.path.join(debug_dir, filename)
                    if os.path.isfile(file_path):
                        zip_file.write(file_path, filename)
        zip_buffer.seek(0)
        return send_file(
            zip_buffer,
            mimetype='application/zip',
            as_attachment=True,
            download_name=f"debug_logs_{session_id}.zip"
        )
    except Exception as e:
        current_app.logger.error(f"下载调试日志失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/debug/enable', methods=['POST'])
def enable_debug_mode():
    """Enable debug logging for a session.

    Body: ``{"session_id": str}``. Returns 400 when missing, 500 on failure.
    """
    try:
        # Fix: a missing or non-JSON body made get_json() return None and
        # raised AttributeError (a 500); silent=True + fallback yields the
        # intended 400 instead.
        data = request.get_json(silent=True) or {}
        session_id = data.get('session_id')
        if not session_id:
            return jsonify({"error": "session_id required"}), 400
        # Point the debug logger at this session.
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)
        # Wire the logger to Socket.IO when the app exposes it.
        if hasattr(current_app, 'socketio'):
            ai_debug_logger.set_socketio(current_app.socketio)
        return jsonify({
            "status": "enabled",
            "session_id": session_id,
            "message": "调试模式已启用"
        })
    except Exception as e:
        current_app.logger.error(f"启用调试模式失败: {e}")
        return jsonify({"error": str(e)}), 500

30
app/routes/frontend.py Normal file
View File

@ -0,0 +1,30 @@
# 文件位置: app/routes/frontend.py
# 文件名: frontend.py
"""
前端页面路由
"""
from flask import Blueprint, render_template, send_from_directory
import os
frontend_bp = Blueprint('frontend', __name__)
@frontend_bp.route('/')
def index():
    """Render the landing page."""
    return render_template('index.html')
@frontend_bp.route('/research/<session_id>')
def research_detail(session_id):
    """Render the detail page for one research session."""
    return render_template('research.html', session_id=session_id)
@frontend_bp.route('/debug')
def debug_page():
    """Render the debug console page."""
    return render_template('debug.html')
@frontend_bp.route('/static/<path:filename>')
def static_files(filename):
    """Serve assets from the local ``static`` directory.

    NOTE(review): Flask serves /static by default; confirm this explicit
    route is actually needed.
    """
    return send_from_directory('static', filename)

45
app/routes/main.py Normal file
View File

@ -0,0 +1,45 @@
"""
主路由
处理API信息和健康检查
"""
from flask import Blueprint, jsonify, current_app
main_bp = Blueprint('main', __name__)
@main_bp.route('/api')
def api_info():
    """Return API metadata and a short index of the main endpoints."""
    return jsonify({
        "message": "Welcome to DeepResearch API",
        "version": "1.0.0",
        "endpoints": {
            "create_research": "POST /api/research",
            "get_status": "GET /api/research/<session_id>/status",
            "get_report": "GET /api/research/<session_id>/report",
            "list_sessions": "GET /api/research/sessions"
        }
    })
@main_bp.route('/health')
def health_check():
    """Liveness probe: always reports healthy while the process serves requests."""
    return jsonify({
        "status": "healthy",
        "service": "DeepResearch"
    })
@main_bp.route('/api/config')
def get_config():
    """Expose selected configuration values; available only in debug mode."""
    # Guard clause: hide configuration outside of debug mode.
    if not current_app.debug:
        return jsonify({"error": "Not available in production"}), 403
    return jsonify({
        "debug": current_app.debug,
        "max_concurrent_subtopics": current_app.config.get('MAX_CONCURRENT_SUBTOPICS'),
        "search_priorities": {
            "high": current_app.config.get('MAX_SEARCHES_HIGH_PRIORITY'),
            "medium": current_app.config.get('MAX_SEARCHES_MEDIUM_PRIORITY'),
            "low": current_app.config.get('MAX_SEARCHES_LOW_PRIORITY')
        }
    })

141
app/routes/websocket.py Normal file
View File

@ -0,0 +1,141 @@
"""
WebSocket事件处理
实时推送研究进度
"""
from flask_socketio import emit, join_room, leave_room
from flask import request
import logging
logger = logging.getLogger(__name__)
def register_handlers(socketio):
    """Register all Socket.IO event handlers on the given server instance."""
    @socketio.on('connect')
    def handle_connect():
        """Client connected: acknowledge with its Socket.IO session id."""
        client_id = request.sid
        logger.info(f"客户端连接: {client_id}")
        emit('connected', {'message': '连接成功', 'client_id': client_id})
    @socketio.on('disconnect')
    def handle_disconnect():
        """Client disconnected: just log it."""
        client_id = request.sid
        logger.info(f"客户端断开: {client_id}")
    @socketio.on('join_session')
    def handle_join_session(data):
        """Join the room named after a research session to receive its updates."""
        session_id = data.get('session_id')
        if session_id:
            join_room(session_id)
            logger.info(f"客户端 {request.sid} 加入房间 {session_id}")
            emit('joined', {'session_id': session_id, 'message': '已加入研究会话'})
    @socketio.on('leave_session')
    def handle_leave_session(data):
        """Leave a research-session room."""
        session_id = data.get('session_id')
        if session_id:
            leave_room(session_id)
            logger.info(f"客户端 {request.sid} 离开房间 {session_id}")
            emit('left', {'session_id': session_id, 'message': '已离开研究会话'})
    # The handlers below relay incoming events to everyone in the session's
    # room (intended to be triggered by background tasks).
    @socketio.on('research_progress')
    def broadcast_progress(data):
        """Relay a research-progress event into the session room."""
        session_id = data.get('session_id')
        if session_id:
            socketio.emit('progress_update', data, room=session_id)
    @socketio.on('research_status_change')
    def broadcast_status_change(data):
        """Relay a status-change event into the session room."""
        session_id = data.get('session_id')
        if session_id:
            socketio.emit('status_changed', data, room=session_id)
    @socketio.on('subtopic_update')
    def broadcast_subtopic_update(data):
        """Relay a subtopic update into the session room."""
        session_id = data.get('session_id')
        if session_id:
            socketio.emit('subtopic_updated', data, room=session_id)
    @socketio.on('search_result')
    def broadcast_search_result(data):
        """Relay a new search result into the session room."""
        session_id = data.get('session_id')
        if session_id:
            socketio.emit('new_search_result', data, room=session_id)
    @socketio.on('report_ready')
    def broadcast_report_ready(data):
        """Relay a report-ready notification into the session room."""
        session_id = data.get('session_id')
        if session_id:
            socketio.emit('report_available', data, room=session_id)
    @socketio.on('error_occurred')
    def broadcast_error(data):
        """Relay an error notification into the session room."""
        session_id = data.get('session_id')
        if session_id:
            socketio.emit('research_error', data, room=session_id)
def emit_progress(socketio, session_id: str, progress_data: dict):
    """Push a progress update into the session's room (called by tasks)."""
    payload = {'session_id': session_id}
    payload.update(progress_data)
    socketio.emit('progress_update', payload, room=session_id)
def emit_status_change(socketio, session_id: str, status: str, phase: str = None):
    """Push a status-change event; `phase` is attached only when truthy."""
    message = dict(session_id=session_id, status=status)
    if phase:
        message['phase'] = phase
    socketio.emit('status_changed', message, room=session_id)
def emit_subtopic_progress(socketio, session_id: str, subtopic_id: str,
                          progress: float, status: str):
    """Push a subtopic progress/status update into the session's room."""
    payload = {
        'session_id': session_id,
        'subtopic_id': subtopic_id,
        'progress': progress,
        'status': status,
    }
    socketio.emit('subtopic_updated', payload, room=session_id)
def emit_search_complete(socketio, session_id: str, subtopic_id: str,
                        search_count: int, results_count: int):
    """Notify the session's room that a search batch finished."""
    payload = {
        'session_id': session_id,
        'subtopic_id': subtopic_id,
        'search_count': search_count,
        'results_count': results_count,
    }
    socketio.emit('search_completed', payload, room=session_id)
def emit_report_ready(socketio, session_id: str, report_type: str):
    """Notify the session's room that a report of the given type is ready."""
    payload = {
        'session_id': session_id,
        'report_type': report_type,
        'message': f'{report_type}报告已生成',
    }
    socketio.emit('report_available', payload, room=session_id)
def emit_error(socketio, session_id: str, error_message: str, error_type: str = 'general'):
    """Push an error notification into the session's room."""
    payload = {
        'session_id': session_id,
        'error_type': error_type,
        'error_message': error_message,
    }
    socketio.emit('research_error', payload, room=session_id)

0
app/services/__init__.py Normal file
View File

316
app/services/ai_service.py Normal file
View File

@ -0,0 +1,316 @@
"""
AI服务层
封装对R1和V3智能体的调用
"""
import logging
from typing import Dict, List, Any, Optional, Tuple
from app.agents.r1_agent import R1Agent
from app.agents.v3_agent import V3Agent
from app.models.search_result import SearchResult, SearchImportance
from config import Config
logger = logging.getLogger(__name__)
class AIService:
    """Unified facade over the R1 (reasoning) and V3 (generation) agents.

    Every public method wraps the corresponding agent call in a try/except
    and falls back to a safe default so one failed AI call does not abort
    the whole research pipeline.
    """
    def __init__(self):
        self.r1_agent = R1Agent()
        self.v3_agent = V3Agent()
        # Attach the debug logger to Socket.IO (best effort; failures only warn).
        try:
            from app import socketio
            from app.utils.debug_logger import ai_debug_logger
            ai_debug_logger.set_socketio(socketio)
            logger.info("AI调试器已初始化")
        except Exception as e:
            logger.warning(f"无法初始化调试器: {e}")
    # ========== Question analysis phase (R1) ==========
    def analyze_question_type(self, question: str) -> str:
        """Classify the question type; falls back to "exploratory" on error."""
        try:
            return self.r1_agent.analyze_question_type(question)
        except Exception as e:
            logger.error(f"分析问题类型失败: {e}")
            return "exploratory"  # default fallback
    def refine_questions(self, question: str, question_type: str) -> List[str]:
        """Break the question into refined sub-questions; on error, return the original."""
        try:
            return self.r1_agent.refine_questions(question, question_type)
        except Exception as e:
            logger.error(f"细化问题失败: {e}")
            return [question]  # fall back to the original question
    def create_research_approach(self, question: str, question_type: str,
                               refined_questions: List[str]) -> str:
        """Draft the overall research approach; generic text on error."""
        try:
            return self.r1_agent.create_research_approach(
                question, question_type, refined_questions
            )
        except Exception as e:
            logger.error(f"制定研究思路失败: {e}")
            return "采用系统化的方法深入研究这个问题。"
    # ========== Outline phase (R1) ==========
    def create_outline(self, question: str, question_type: str,
                      refined_questions: List[str], research_approach: str) -> Dict[str, Any]:
        """Create the research outline; a minimal single-subtopic outline on error."""
        try:
            return self.r1_agent.create_outline(
                question, question_type, refined_questions, research_approach
            )
        except Exception as e:
            logger.error(f"创建大纲失败: {e}")
            # Fallback: a minimal but well-formed outline.
            return {
                "main_topic": question,
                "research_questions": refined_questions[:3],
                "sub_topics": [
                    {
                        "topic": "核心分析",
                        "explain": "对问题进行深入分析",
                        "priority": "high",
                        "related_questions": refined_questions[:2]
                    }
                ]
            }
    def validate_outline(self, outline: Dict[str, Any]) -> str:
        """Ask R1 to validate the outline; a benign verdict string on error."""
        try:
            return self.r1_agent.validate_outline(outline)
        except Exception as e:
            logger.error(f"验证大纲失败: {e}")
            return "大纲结构合理。"
    def modify_outline(self, original_outline: Dict[str, Any],
                      user_feedback: str, validation_issues: str = "") -> Dict[str, Any]:
        """Revise the outline from user feedback; the original outline on error."""
        try:
            return self.r1_agent.modify_outline(
                original_outline, user_feedback, validation_issues
            )
        except Exception as e:
            logger.error(f"修改大纲失败: {e}")
            return original_outline
    # ========== Search phase (V3 + R1) ==========
    def generate_search_queries(self, subtopic: str, explanation: str,
                              related_questions: List[str], priority: str) -> List[str]:
        """Generate search queries via V3, budgeted by subtopic priority."""
        # The priority level determines how many queries to generate.
        count_map = {
            "high": Config.MAX_SEARCHES_HIGH_PRIORITY,
            "medium": Config.MAX_SEARCHES_MEDIUM_PRIORITY,
            "low": Config.MAX_SEARCHES_LOW_PRIORITY
        }
        count = count_map.get(priority, 10)
        try:
            return self.v3_agent.generate_search_queries(
                subtopic, explanation, related_questions, count
            )
        except Exception as e:
            logger.error(f"生成搜索查询失败: {e}")
            # Fallback: two naive queries built from the subtopic text.
            return [subtopic, f"{subtopic} {explanation}"][:count]
    def evaluate_search_results(self, subtopic: str,
                               search_results: List[SearchResult]) -> List[SearchResult]:
        """Rate each search result's importance via R1 (MEDIUM on failure)."""
        evaluated_results = []
        for result in search_results:
            try:
                importance = self.r1_agent.evaluate_search_result(
                    subtopic,
                    result.title,
                    result.url,
                    result.snippet
                )
                result.importance = SearchImportance(importance)
                evaluated_results.append(result)
            except Exception as e:
                logger.error(f"评估搜索结果失败: {e}")
                # Keep the result but tag it as medium importance.
                result.importance = SearchImportance.MEDIUM
                evaluated_results.append(result)
        return evaluated_results
    # ========== Reflection phase (R1) ==========
    def reflect_on_information(self, subtopic: str,
                              search_results: List[SearchResult]) -> List[Dict[str, str]]:
        """Reflect on gathered info and return points needing deeper search ([] on error)."""
        # Summarize the results first; R1 reflects on the summary, not raw hits.
        summary = self._generate_search_summary(search_results)
        try:
            return self.r1_agent.reflect_on_information(subtopic, summary)
        except Exception as e:
            logger.error(f"信息反思失败: {e}")
            return []
    def generate_refined_queries(self, key_points: List[Dict[str, str]]) -> Dict[str, List[str]]:
        """Generate follow-up queries per key point via V3.

        Each point dict is expected to carry "key_info" and "detail_needed"
        keys (the shape produced by reflect_on_information).
        """
        refined_queries = {}
        for point in key_points:
            try:
                queries = self.v3_agent.generate_refined_queries(
                    point["key_info"],
                    point["detail_needed"]
                )
                refined_queries[point["key_info"]] = queries
            except Exception as e:
                logger.error(f"生成细化查询失败: {e}")
                # Fallback: search for the key info verbatim.
                refined_queries[point["key_info"]] = [point["key_info"]]
        return refined_queries
    # ========== Integration phase (R1) ==========
    def integrate_information(self, subtopic: str,
                             all_search_results: List[SearchResult]) -> Dict[str, Any]:
        """Integrate all gathered results into key points/themes via R1."""
        # Flatten the results into a numbered text block for the prompt.
        formatted_results = self._format_search_results_for_integration(all_search_results)
        try:
            return self.r1_agent.integrate_information(subtopic, formatted_results)
        except Exception as e:
            logger.error(f"整合信息失败: {e}")
            # Fallback: empty but well-formed structure.
            return {
                "key_points": [],
                "themes": []
            }
    # ========== Report-writing phase (R1) ==========
    def write_subtopic_report(self, subtopic: str, integrated_info: Dict[str, Any]) -> str:
        """Write the subtopic report; a Markdown error stub on failure."""
        try:
            return self.r1_agent.write_subtopic_report(subtopic, integrated_info)
        except Exception as e:
            logger.error(f"撰写子主题报告失败: {e}")
            return f"## {subtopic}\n\n撰写报告时发生错误。"
    # ========== Hallucination-check phase (R1 + V3) ==========
    def detect_and_fix_hallucinations(self, report: str,
                                     original_sources: Dict[str, str]) -> Tuple[str, List[Dict]]:
        """Detect hallucinated citations (R1) and rewrite them (V3).

        Returns the (possibly) fixed report plus a list describing each
        detected hallucination. Only URLs present in `original_sources`
        are checked.
        """
        hallucinations = []
        fixed_report = report
        # Map cited URL -> the report text attributed to it.
        url_references = self._extract_url_references(report)
        for url, content in url_references.items():
            if url in original_sources:
                try:
                    # R1 compares the cited claim against the source text.
                    result = self.r1_agent.detect_hallucination(
                        content, url, original_sources[url]
                    )
                    if result.get("is_hallucination", False):
                        hallucinations.append({
                            "url": url,
                            "content": content,
                            "type": result.get("hallucination_type", "未知"),
                            "explanation": result.get("explanation", "")
                        })
                        # V3 rewrites the offending passage from the source.
                        try:
                            new_content = self.v3_agent.rewrite_hallucination(
                                content, original_sources[url]
                            )
                            fixed_report = fixed_report.replace(content, new_content)
                        except Exception as e:
                            logger.error(f"重写幻觉内容失败: {e}")
                except Exception as e:
                    logger.error(f"检测幻觉失败: {e}")
        return fixed_report, hallucinations
    # ========== Final report phase (R1) ==========
    def generate_final_report(self, main_topic: str, research_questions: List[str],
                             subtopic_reports: Dict[str, str]) -> str:
        """Assemble the final report via R1; naive concatenation on error."""
        try:
            return self.r1_agent.generate_final_report(
                main_topic, research_questions, subtopic_reports
            )
        except Exception as e:
            logger.error(f"生成最终报告失败: {e}")
            # Fallback: concatenate the subtopic reports with separators.
            reports_text = "\n\n---\n\n".join(subtopic_reports.values())
            return f"# {main_topic}\n\n## 研究报告\n\n{reports_text}"
    # ========== Helpers ==========
    def _generate_search_summary(self, search_results: List[SearchResult]) -> str:
        """Build a short text summary (counts + top HIGH findings) for reflection."""
        high_count = sum(1 for r in search_results if r.importance == SearchImportance.HIGH)
        medium_count = sum(1 for r in search_results if r.importance == SearchImportance.MEDIUM)
        low_count = sum(1 for r in search_results if r.importance == SearchImportance.LOW)
        summary_lines = [
            f"共找到 {len(search_results)} 条搜索结果",
            f"高重要性: {high_count}",
            f"中重要性: {medium_count}",
            f"低重要性: {low_count}",
            "",
            "主要发现:"
        ]
        # NOTE(review): only the first 10 results are scanned for HIGH items,
        # so HIGH results beyond index 9 never appear here — confirm intent.
        for result in search_results[:10]:
            if result.importance == SearchImportance.HIGH:
                summary_lines.append(f"- {result.title}: {result.snippet[:100]}...")
        return '\n'.join(summary_lines)
    def _format_search_results_for_integration(self, search_results: List[SearchResult]) -> str:
        """Render results as a numbered source/title/content/importance list."""
        formatted_lines = []
        for i, result in enumerate(search_results, 1):
            formatted_lines.extend([
                f"{i}. 来源: {result.url}",
                f"   标题: {result.title}",
                f"   内容: {result.snippet}",
                f"   重要性: {result.importance.value if result.importance else '未评估'}",
                ""
            ])
        return '\n'.join(formatted_lines)
    def _extract_url_references(self, report: str) -> Dict[str, str]:
        """Extract URL citations and their attributed text from the report.

        Simple regex-based implementation; a real parser may be needed.
        """
        import re
        url_references = {}
        # Match pattern: content (来源: URL)
        # NOTE(review): this pattern looks garbled/mis-encoded (unbalanced
        # character classes and missing full-width parentheses) and likely
        # raises re.error or matches nothing — verify against the intended
        # "(content)(来源:URL)" citation format.
        pattern = r'([^]+)(来源:([^]+)'
        matches = re.finditer(pattern, report)
        for match in matches:
            content = match.group(1).strip()
            url = match.group(2).strip()
            url_references[url] = content
        return url_references

View File

@ -0,0 +1,347 @@
"""
报告生成服务
负责生成各类研究报告
"""
import os
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional
from app.models.report import (
SubtopicReport, FinalReport, ReportSection,
KeyInsight, HallucinationCheck
)
from app.models.research import ResearchSession, Subtopic
from app.models.search_result import SearchResult
from config import Config
logger = logging.getLogger(__name__)
class ReportGenerator:
"""报告生成器"""
def generate_subtopic_report(self, subtopic: Subtopic,
integrated_info: Dict[str, Any],
report_content: str) -> SubtopicReport:
"""生成子主题报告"""
try:
# 解析报告内容为结构化格式
sections = self._parse_report_sections(report_content)
key_insights = self._extract_key_insights(report_content)
recommendations = self._extract_recommendations(report_content)
# 统计字数
word_count = len(report_content.replace(" ", ""))
# 创建子主题报告
report = SubtopicReport(
subtopic_id=subtopic.id,
subtopic_name=subtopic.topic,
sections=sections,
key_insights=key_insights,
recommendations=recommendations,
word_count=word_count
)
return report
except Exception as e:
logger.error(f"生成子主题报告失败: {e}")
# 返回基本报告
return SubtopicReport(
subtopic_id=subtopic.id,
subtopic_name=subtopic.topic,
sections=[
ReportSection(
title="报告内容",
content=report_content
)
]
)
def generate_final_report(self, session: ResearchSession,
subtopic_reports: List[SubtopicReport],
final_content: str) -> FinalReport:
"""生成最终报告"""
try:
# 解析最终报告内容
executive_summary = self._extract_executive_summary(final_content)
main_findings = self._parse_main_findings(final_content)
overall_insights = self._extract_overall_insights(final_content)
recommendations = self._extract_final_recommendations(final_content)
# 统计信息
total_sources = self._count_total_sources(subtopic_reports)
total_searches = self._count_total_searches(session)
# 创建最终报告
report = FinalReport(
session_id=session.id,
title=session.question,
executive_summary=executive_summary,
main_findings=main_findings,
subtopic_reports=subtopic_reports,
overall_insights=overall_insights,
recommendations=recommendations,
methodology=self._generate_methodology(session),
limitations=self._identify_limitations(session),
total_sources=total_sources,
total_searches=total_searches
)
return report
except Exception as e:
logger.error(f"生成最终报告失败: {e}")
# 返回基本报告
return FinalReport(
session_id=session.id,
title=session.question,
executive_summary="研究报告生成过程中出现错误。",
subtopic_reports=subtopic_reports,
total_sources=total_sources if 'total_sources' in locals() else 0,
total_searches=total_searches if 'total_searches' in locals() else 0
)
def save_report(self, report: FinalReport, format: str = "markdown") -> str:
"""保存报告到文件"""
try:
# 生成文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{report.session_id}_{timestamp}.md"
filepath = os.path.join(Config.REPORTS_DIR, filename)
# 保存文件
if format == "markdown":
report.save_to_file(filepath)
else:
# 未来可以支持其他格式PDF、HTML等
raise ValueError(f"不支持的格式: {format}")
logger.info(f"报告已保存: {filepath}")
return filepath
except Exception as e:
logger.error(f"保存报告失败: {e}")
raise
def create_hallucination_report(self, hallucinations: List[Dict[str, Any]]) -> str:
"""创建幻觉检测报告"""
if not hallucinations:
return "未检测到幻觉内容。"
report_lines = ["# 幻觉检测报告", ""]
report_lines.append(f"共检测到 {len(hallucinations)} 处可能的幻觉内容:")
report_lines.append("")
for i, h in enumerate(hallucinations, 1):
report_lines.extend([
f"## {i}. {h.get('type', '未知类型')}",
f"**URL**: {h.get('url', 'N/A')}",
f"**原始内容**: {h.get('content', 'N/A')}",
f"**说明**: {h.get('explanation', '无说明')}",
""
])
return '\n'.join(report_lines)
# ========== 解析辅助方法 ==========
def _parse_report_sections(self, content: str) -> List[ReportSection]:
"""解析报告章节"""
sections = []
# 简单的Markdown解析
lines = content.split('\n')
current_section = None
current_content = []
for line in lines:
if line.startswith('### '):
# 保存前一个章节
if current_section:
current_section.content = '\n'.join(current_content).strip()
sections.append(current_section)
# 开始新章节
current_section = ReportSection(title=line[4:].strip(), content="")
current_content = []
elif line.startswith('#### ') and current_section:
# 子章节
subsection_title = line[5:].strip()
# 收集子章节内容(简化处理)
current_content.append(line)
elif current_section:
current_content.append(line)
# 保存最后一个章节
if current_section:
current_section.content = '\n'.join(current_content).strip()
sections.append(current_section)
return sections
def _extract_key_insights(self, content: str) -> List[KeyInsight]:
"""提取关键洞察"""
insights = []
# 查找"关键洞察"部分
lines = content.split('\n')
in_insights_section = False
for i, line in enumerate(lines):
if '关键洞察' in line and line.startswith('#'):
in_insights_section = True
continue
if in_insights_section:
if line.startswith('#') and '关键洞察' not in line:
break
if line.strip().startswith(('1.', '2.', '3.', '4.', '5.')):
# 提取洞察内容
insight_text = line.split('.', 1)[1].strip()
# 移除Markdown格式
insight_text = insight_text.replace('**', '').replace('*', '')
# 查找来源URL
source_urls = self._extract_urls_from_text(insight_text)
insights.append(KeyInsight(
insight=insight_text.split('')[0] if '' in insight_text else insight_text,
source_urls=source_urls,
confidence=0.8 # 默认置信度
))
return insights
def _extract_recommendations(self, content: str) -> List[str]:
"""提取建议"""
recommendations = []
lines = content.split('\n')
in_recommendations_section = False
for line in lines:
if '建议' in line and line.startswith('#'):
in_recommendations_section = True
continue
if in_recommendations_section:
if line.startswith('#') and '建议' not in line:
break
if line.strip().startswith(('-', '*', '')):
recommendation = line.strip()[1:].strip()
if recommendation:
recommendations.append(recommendation)
return recommendations
def _extract_executive_summary(self, content: str) -> str:
"""提取执行摘要"""
lines = content.split('\n')
in_summary = False
summary_lines = []
for line in lines:
if '执行摘要' in line and line.startswith('#'):
in_summary = True
continue
if in_summary:
if line.startswith('#'):
break
summary_lines.append(line)
return '\n'.join(summary_lines).strip()
def _parse_main_findings(self, content: str) -> List[ReportSection]:
    """Parse the "主要发现" (main findings) part of the final report.

    Placeholder: would mirror ``_parse_report_sections`` but scoped to
    the main-findings section only. Currently returns an empty list.
    """
    # TODO: implement; simplified stub for now.
    return []
def _extract_overall_insights(self, content: str) -> List[KeyInsight]:
    """Extract cross-topic insights from the final report.

    Placeholder: would mirror ``_extract_key_insights`` but target the
    "综合洞察" (overall insights) section. Currently returns an empty list.
    """
    # TODO: implement; simplified stub for now.
    return []
def _extract_final_recommendations(self, content: str) -> List[str]:
    """Extract the final recommendations from the report.

    Placeholder: would mirror ``_extract_recommendations``. Currently
    returns an empty list.
    """
    # TODO: implement; simplified stub for now.
    return []
def _extract_urls_from_text(self, text: str) -> List[str]:
"""从文本中提取URL"""
import re
# 简单的URL提取
url_pattern = r'https?://[^\s)]+|www\.[^\s)]+'
urls = re.findall(url_pattern, text)
# 清理URL
cleaned_urls = []
for url in urls:
# 移除末尾的标点
url = url.rstrip('.,;:!?)')
if url:
cleaned_urls.append(url)
return cleaned_urls
def _count_total_sources(self, subtopic_reports: List[SubtopicReport]) -> int:
"""统计总来源数"""
all_urls = set()
for report in subtopic_reports:
for section in report.sections:
all_urls.update(section.sources)
for insight in report.key_insights:
all_urls.update(insight.source_urls)
return len(all_urls)
def _count_total_searches(self, session: ResearchSession) -> int:
"""统计总搜索次数"""
if not session.outline:
return 0
total = 0
for subtopic in session.outline.sub_topics:
total += subtopic.get_total_searches()
return total
def _generate_methodology(self, session: ResearchSession) -> str:
    """Build the "methodology" section text of the final report.

    NOTE(review): several fullwidth punctuation characters appear to have
    been lost from this template during an encoding round-trip — compare
    with the original source before editing the wording.
    """
    methodology = f"""
本研究采用系统化的深度研究方法具体流程如下
1. **问题分析**: 识别问题类型为"{session.question_type.value if session.question_type else '未知'}"并细化为{len(session.refined_questions)}个具体问题
2. **研究规划**: 制定包含{len(session.outline.sub_topics) if session.outline else 0}个子主题的研究大纲每个子主题根据重要性分配不同的搜索资源
3. **信息收集**: 使用Tavily搜索引擎进行多轮搜索共执行{self._count_total_searches(session)}次搜索
4. **质量控制**: 通过AI评估搜索结果重要性并进行幻觉检测和内容验证
5. **综合分析**: 整合所有信息提炼关键洞察形成结构化报告
"""
    return methodology.strip()
def _identify_limitations(self, session: ResearchSession) -> List[str]:
"""识别研究局限性"""
limitations = [
"搜索结果受限于公开可访问的网络信息",
"部分专业领域可能缺乏深度分析",
"时效性信息可能存在延迟"
]
# 根据实际情况添加更多局限性
if session.outline and any(st.status == "cancelled" for st in session.outline.sub_topics):
limitations.append("部分子主题研究未完成")
return limitations

View File

@ -0,0 +1,323 @@
# 文件位置: app/services/research_manager.py
# 文件名: research_manager.py
"""
研究流程管理器
协调整个研究过程
"""
import os
import json
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional
from app.models.research import ResearchSession, ResearchStatus, ResearchOutline, Subtopic
from app.services.ai_service import AIService
from app.services.search_service import SearchService
from app.services.report_generator import ReportGenerator
from config import Config
logger = logging.getLogger(__name__)
class ResearchManager:
    """Coordinate the end-to-end research workflow.

    Responsibilities:
      * lifecycle of ``ResearchSession`` objects (create / load / update),
      * persistence of sessions as JSON files under ``Config.SESSIONS_DIR``
        (the in-memory ``self.sessions`` dict is only a cache),
      * progress/status queries and cancellation,
      * per-phase processing hooks driven by the async task chain in
        ``app.tasks.research_tasks``.
    """

    def __init__(self):
        self.ai_service = AIService()
        self.search_service = SearchService()
        self.report_generator = ReportGenerator()
        # Cache of sessions already loaded in this process; disk is the
        # source of truth (see get_session / _save_session).
        self.sessions: Dict[str, ResearchSession] = {}

    def create_session(self, question: str) -> ResearchSession:
        """Create, persist and return a new research session for *question*."""
        session = ResearchSession(question=question)
        self.sessions[session.id] = session
        self._save_session(session)
        logger.info(f"创建研究会话: {session.id}")
        return session

    def start_research(self, session_id: str) -> Dict[str, Any]:
        """Mark the session as analyzing and launch the async task chain.

        Raises:
            ValueError: If *session_id* does not exist.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        try:
            session.update_status(ResearchStatus.ANALYZING)
            self._save_session(session)
            # Imported lazily to break the circular dependency between the
            # task module and this manager.
            from app.tasks.research_tasks import analyze_question_chain
            analyze_question_chain.delay(session_id)
            return {
                "status": "started",
                "session_id": session_id,
                "message": "研究已开始"
            }
        except Exception as e:
            # Record the failure on the session so the UI can surface it.
            logger.error(f"启动研究失败: {e}")
            session.update_status(ResearchStatus.ERROR)
            session.error_message = str(e)
            self._save_session(session)
            raise

    def get_session(self, session_id: str) -> Optional[ResearchSession]:
        """Return the session, loading it from disk on a cache miss."""
        if session_id in self.sessions:
            return self.sessions[session_id]
        filepath = self._get_session_filepath(session_id)
        if os.path.exists(filepath):
            session = ResearchSession.load_from_file(filepath)
            self.sessions[session_id] = session
            return session
        return None

    def update_session(self, session_id: str, updates: Dict[str, Any]):
        """Apply *updates* to matching session attributes and persist.

        Unknown keys are silently ignored on purpose.

        Raises:
            ValueError: If *session_id* does not exist.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        for key, value in updates.items():
            if hasattr(session, key):
                setattr(session, key, value)
        session.updated_at = datetime.now()
        self._save_session(session)

    def get_session_status(self, session_id: str) -> Dict[str, Any]:
        """Return a progress snapshot suitable for the status API/UI."""
        session = self.get_session(session_id)
        if not session:
            return {"error": "Session not found"}
        # Per-subtopic progress as a percentage of its search budget.
        subtopic_progress = []
        if session.outline:
            for subtopic in session.outline.sub_topics:
                subtopic_progress.append({
                    "id": subtopic.id,
                    "topic": subtopic.topic,
                    "status": subtopic.status,
                    "progress": subtopic.get_total_searches() / subtopic.max_searches * 100
                })
        return {
            "session_id": session_id,
            "status": session.status,
            "current_phase": session.current_phase,
            "progress_percentage": session.get_progress_percentage(),
            "subtopic_progress": subtopic_progress,
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "error_message": session.error_message
        }

    def cancel_research(self, session_id: str) -> Dict[str, Any]:
        """Mark the session as cancelled and persist the new state."""
        session = self.get_session(session_id)
        if not session:
            return {"error": "Session not found"}
        session.update_status(ResearchStatus.CANCELLED)
        self._save_session(session)
        return {
            "status": "cancelled",
            "session_id": session_id,
            "message": "研究已取消"
        }

    def get_research_report(self, session_id: str) -> Optional[str]:
        """Return the final report text, or None if unavailable/incomplete."""
        session = self.get_session(session_id)
        if not session:
            return None
        if session.status != ResearchStatus.COMPLETED:
            return None
        if session.final_report:
            return session.final_report
        # Fall back to the markdown file written by the report generator.
        report_path = os.path.join(Config.REPORTS_DIR, f"{session_id}.md")
        if os.path.exists(report_path):
            with open(report_path, 'r', encoding='utf-8') as f:
                return f.read()
        return None

    def list_sessions(self, limit: int = 20, offset: int = 0) -> List[Dict[str, Any]]:
        """List persisted sessions, newest first, with summary metadata.

        Sessions are ordered by file modification time. (The previous
        implementation sorted by filename, but filenames are random
        session UUIDs, so that ordering was meaningless.)
        """
        def _mtime(filename: str) -> float:
            try:
                return os.path.getmtime(os.path.join(Config.SESSIONS_DIR, filename))
            except OSError:
                # File vanished between listdir and stat; sort it last.
                return 0.0

        session_files = sorted(
            (f for f in os.listdir(Config.SESSIONS_DIR) if f.endswith('.json')),
            key=_mtime,
            reverse=True  # newest first
        )
        sessions = []
        for filename in session_files[offset:offset + limit]:
            filepath = os.path.join(Config.SESSIONS_DIR, filename)
            try:
                session = ResearchSession.load_from_file(filepath)
                sessions.append({
                    "id": session.id,
                    "question": session.question,
                    "status": session.status,
                    "created_at": session.created_at.isoformat(),
                    "progress": session.get_progress_percentage()
                })
            except Exception as e:
                # Name the offending file so corrupt sessions can be found.
                logger.error(f"加载会话失败 ({filename}): {e}")
        return sessions

    def _save_session(self, session: ResearchSession):
        """Serialize the session to its JSON file under Config.SESSIONS_DIR."""
        filepath = self._get_session_filepath(session.id)
        data = session.dict()
        # Convert top-level datetime fields to ISO strings; they may
        # already be strings if the session was round-tripped through JSON.
        for key in ['created_at', 'updated_at', 'completed_at']:
            if data.get(key):
                data[key] = data[key].isoformat() if hasattr(data[key], 'isoformat') else data[key]
        # Same guard for the nested outline timestamps.
        if data.get('outline'):
            for key in ['created_at', 'updated_at']:
                value = data['outline'].get(key)
                if value:
                    data['outline'][key] = value.isoformat() if hasattr(value, 'isoformat') else value
        with open(filepath, 'w', encoding='utf-8') as f:
            # default=str is a safety net for any remaining non-JSON types.
            json.dump(data, f, ensure_ascii=False, indent=2, default=str)

    def _get_session_filepath(self, session_id: str) -> str:
        """Path of the JSON file backing *session_id*."""
        return os.path.join(Config.SESSIONS_DIR, f"{session_id}.json")

    # ----- Hooks invoked by the async task chain -----

    def process_question_analysis(self, session_id: str):
        """Phase 1: classify the question, refine it, plan the approach."""
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        session.question_type = self.ai_service.analyze_question_type(session.question)
        session.refined_questions = self.ai_service.refine_questions(
            session.question,
            session.question_type
        )
        session.research_approach = self.ai_service.create_research_approach(
            session.question,
            session.question_type,
            session.refined_questions
        )
        session.current_phase = "制定大纲"
        session.completed_steps += 1
        self._save_session(session)

    def process_outline_creation(self, session_id: str):
        """Phase 2: build the research outline, budgeting searches per subtopic."""
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        outline_dict = self.ai_service.create_outline(
            session.question,
            session.question_type,
            session.refined_questions,
            session.research_approach
        )
        subtopics = []
        for st in outline_dict.get('sub_topics', []):
            subtopic = Subtopic(
                topic=st['topic'],
                explain=st['explain'],
                priority=st['priority'],
                related_questions=st.get('related_questions', [])
            )
            # Search budget scales with priority.
            if subtopic.priority == "high":
                subtopic.max_searches = Config.MAX_SEARCHES_HIGH_PRIORITY
            elif subtopic.priority == "medium":
                subtopic.max_searches = Config.MAX_SEARCHES_MEDIUM_PRIORITY
            else:
                subtopic.max_searches = Config.MAX_SEARCHES_LOW_PRIORITY
            subtopics.append(subtopic)
        session.outline = ResearchOutline(
            main_topic=outline_dict['main_topic'],
            research_questions=outline_dict['research_questions'],
            sub_topics=subtopics
        )
        session.current_phase = "研究子主题"
        session.update_status(ResearchStatus.RESEARCHING)
        # preparation + outline + one step per subtopic + final report
        session.total_steps = 3 + len(subtopics) + 1
        session.completed_steps = 2
        self._save_session(session)

    def process_subtopic_research(self, session_id: str, subtopic_id: str):
        """Phase 3: record one subtopic as researched.

        The actual search/analysis work lives in research_tasks.py; this
        hook only updates bookkeeping.
        """
        session = self.get_session(session_id)
        if not session or not session.outline:
            raise ValueError(f"Session or outline not found: {session_id}")
        subtopic = None
        for st in session.outline.sub_topics:
            if st.id == subtopic_id:
                subtopic = st
                break
        if not subtopic:
            raise ValueError(f"Subtopic not found: {subtopic_id}")
        subtopic.status = ResearchStatus.COMPLETED
        session.completed_steps += 1
        self._save_session(session)

    def finalize_research(self, session_id: str):
        """Phase 4: mark the whole session as completed.

        Report generation itself happens in report_generator.py.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        session.update_status(ResearchStatus.COMPLETED)
        session.current_phase = "研究完成"
        session.completed_steps = session.total_steps
        self._save_session(session)

View File

@ -0,0 +1,204 @@
"""
搜索服务
封装Tavily API调用
"""
import logging
from typing import List, Dict, Any, Optional
from tavily import TavilyClient
from app.models.search_result import SearchResult, TavilySearchResponse, SearchBatch
from config import Config
import time
logger = logging.getLogger(__name__)
class SearchService:
    """Thin wrapper around the Tavily search API with an in-process cache.

    NOTE(review): the cache is unbounded and keyed only on
    (query, max_results, search_depth); long-running processes may want
    an eviction policy.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Create a Tavily client, using Config.TAVILY_API_KEY when none is given."""
        self.api_key = api_key or Config.TAVILY_API_KEY
        self.client = TavilyClient(api_key=self.api_key)
        # Maps "query:max_results:search_depth" -> TavilySearchResponse.
        self._search_cache = {}

    def search(self, query: str, max_results: Optional[int] = None,
               search_depth: Optional[str] = None, include_answer: Optional[bool] = None,
               include_raw_content: Optional[bool] = None) -> TavilySearchResponse:
        """Run one Tavily search, serving repeated queries from the cache.

        Any unset option falls back to its ``Config.TAVILY_*`` default.
        API failures are swallowed and reported as an empty response so
        the research pipeline can keep going.
        """
        # The cache key is computed before defaults are applied, so an
        # explicit default value and None use distinct cache entries.
        cache_key = f"{query}:{max_results}:{search_depth}"
        if cache_key in self._search_cache:
            logger.info(f"从缓存返回搜索结果: {query}")
            return self._search_cache[cache_key]
        if max_results is None:
            max_results = Config.TAVILY_MAX_RESULTS
        if search_depth is None:
            search_depth = Config.TAVILY_SEARCH_DEPTH
        if include_answer is None:
            include_answer = Config.TAVILY_INCLUDE_ANSWER
        if include_raw_content is None:
            include_raw_content = Config.TAVILY_INCLUDE_RAW_CONTENT
        try:
            logger.info(f"执行Tavily搜索: {query}")
            start_time = time.time()
            response = self.client.search(
                query=query,
                max_results=max_results,
                search_depth=search_depth,
                include_answer=include_answer,
                include_raw_content=include_raw_content
            )
            response_time = time.time() - start_time
            # Normalize the raw API dict into our response model.
            tavily_response = TavilySearchResponse(
                query=query,
                answer=response.get('answer'),
                images=response.get('images', []),
                results=response.get('results', []),
                response_time=response_time
            )
            self._search_cache[cache_key] = tavily_response
            logger.info(f"搜索完成,耗时 {response_time:.2f}秒,返回 {len(tavily_response.results)} 条结果")
            return tavily_response
        except Exception as e:
            # Deliberate best-effort: an empty response keeps callers alive.
            logger.error(f"Tavily搜索失败: {e}")
            return TavilySearchResponse(
                query=query,
                answer=None,
                images=[],
                results=[],
                response_time=0.0
            )

    def batch_search(self, queries: List[str], max_results_per_query: int = 10) -> List[TavilySearchResponse]:
        """Search each query in turn, throttled to avoid rate limits.

        A failed query yields an empty response at its position so the
        returned list always matches *queries* in length and order.
        """
        results = []
        for query in queries:
            if results:  # throttle every request after the first
                time.sleep(0.5)  # 500ms delay
            try:
                response = self.search(query, max_results=max_results_per_query)
                results.append(response)
            except Exception as e:
                logger.error(f"批量搜索中的查询失败 '{query}': {e}")
                results.append(TavilySearchResponse(
                    query=query,
                    results=[],
                    response_time=0.0
                ))
        return results

    def search_subtopic(self, subtopic_id: str, subtopic_name: str,
                        queries: List[str]) -> SearchBatch:
        """Run all *queries* for one subtopic and return a deduplicated batch."""
        all_results = []
        for query in queries:
            response = self.search(query)
            search_results = response.to_search_results()
            all_results.extend(search_results)
        batch = SearchBatch(
            subtopic_id=subtopic_id,
            query=f"子主题搜索: {subtopic_name}",
            results=[]
        )
        # add_results deduplicates while inserting.
        batch.add_results(all_results)
        return batch

    def refined_search(self, subtopic_id: str, key_info: str,
                       queries: List[str], parent_search_id: Optional[str] = None) -> SearchBatch:
        """Run a deeper ("advanced") follow-up search focused on *key_info*."""
        all_results = []
        for query in queries:
            response = self.search(query, search_depth="advanced")
            search_results = response.to_search_results()
            all_results.extend(search_results)
        batch = SearchBatch(
            subtopic_id=subtopic_id,
            query=f"细化搜索: {key_info}",
            results=[],
            is_refined_search=True,
            parent_search_id=parent_search_id,
            detail_type=key_info
        )
        batch.add_results(all_results)
        return batch

    def extract_content(self, urls: List[str]) -> Dict[str, str]:
        """Fetch full page content for *urls* via Tavily's extract endpoint.

        Falls back to content already present in cached search results
        when extract is unavailable. Cached results are the raw dicts
        returned by the Tavily API (see ``search``).
        """
        content_map = {}
        try:
            # The extract endpoint accepts at most 20 URLs per call.
            response = self.client.extract(urls=urls[:20])
            for result in response.get('results', []):
                url = result.get('url')
                content = result.get('raw_content', '')
                if url and content:
                    content_map[url] = content
        except Exception as e:
            logger.error(f"提取内容失败: {e}")
            # Best-effort fallback: reuse snippets from cached searches.
            for url in urls:
                for cached_response in self._search_cache.values():
                    for result in cached_response.results:
                        if result.get('url') == url:
                            content_map[url] = result.get('content', '')
                            break
        return content_map

    def get_search_statistics(self) -> Dict[str, Any]:
        """Return aggregate numbers about the in-process search cache."""
        total_searches = len(self._search_cache)
        total_results = sum(len(r.results) for r in self._search_cache.values())
        return {
            "total_searches": total_searches,
            "total_results": total_results,
            "cache_size": len(self._search_cache),
            "cached_queries": list(self._search_cache.keys())
        }

    def clear_cache(self):
        """Drop every cached search response."""
        self._search_cache.clear()
        logger.info("搜索缓存已清空")

    def test_connection(self) -> bool:
        """Best-effort connectivity probe for the Tavily API.

        NOTE(review): ``search()`` swallows API errors and returns an
        empty response, and ``len(...) >= 0`` is always true, so this
        only returns False when something outside ``search()`` raises.
        Kept as-is to preserve behavior.
        """
        try:
            response = self.search("test query", max_results=1)
            return len(response.results) >= 0
        except Exception as e:
            logger.error(f"Tavily API连接测试失败: {e}")
            return False

View File

@ -0,0 +1,202 @@
# 文件位置: app/services/task_manager.py
# 文件名: task_manager.py
"""
任务管理器
替代 Celery 的轻量级任务队列实现
"""
import uuid
import logging
import threading
from concurrent.futures import ThreadPoolExecutor, Future
from typing import Dict, Any, Callable, Optional, List
from datetime import datetime
from enum import Enum
logger = logging.getLogger(__name__)
class TaskStatus(Enum):
    """Lifecycle states of a task tracked by ``TaskManager``."""
    PENDING = "pending"        # queued, not yet started
    RUNNING = "running"        # executing in a worker thread
    COMPLETED = "completed"    # finished without raising
    FAILED = "failed"          # raised an exception
    CANCELLED = "cancelled"    # cancelled before it started running
class TaskInfo:
    """Bookkeeping record for one submitted task.

    Mutated in place by ``TaskManager`` as the task moves through its
    lifecycle (see ``TaskStatus``).
    """
    def __init__(self, task_id: str, func_name: str, args: tuple, kwargs: dict):
        self.id = task_id
        self.func_name = func_name
        # Original call arguments, kept for debugging/inspection.
        self.args = args
        self.kwargs = kwargs
        self.status = TaskStatus.PENDING
        self.created_at = datetime.now()
        self.started_at: Optional[datetime] = None
        self.completed_at: Optional[datetime] = None
        self.result: Any = None  # return value on success
        self.error: Optional[str] = None  # str(exception) on failure
        self.future: Optional[Future] = None  # set right after executor.submit
class TaskManager:
    """Singleton thread-pool task manager (lightweight Celery replacement).

    Tracks every submitted task in ``self.tasks`` and groups tasks by
    research session in ``self.session_tasks``.

    NOTE(review): the bookkeeping dicts are mutated from worker threads
    without a lock; CPython's GIL makes individual dict operations
    atomic, but compound read-modify-write sequences are not formally
    synchronized — confirm this is acceptable for the workload.
    """
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking so only one instance ever exists.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every TaskManager() call; this guard makes the
        # singleton's state initialize only once.
        if not hasattr(self, 'initialized'):
            self.executor = ThreadPoolExecutor(max_workers=10)
            self.tasks: Dict[str, TaskInfo] = {}
            self.session_tasks: Dict[str, List[str]] = {}  # session_id -> task_ids
            self.initialized = True
            logger.info("任务管理器初始化完成")

    def submit_task(self, func: Callable, *args, **kwargs) -> str:
        """Queue *func* on the pool and return its task id immediately."""
        task_id = str(uuid.uuid4())
        task_info = TaskInfo(task_id, func.__name__, args, kwargs)
        # Heuristic session association: a first positional str argument
        # containing '-' is assumed to be a session UUID — TODO confirm
        # this holds for every task signature.
        session_id = None
        if args and isinstance(args[0], str) and '-' in args[0]:
            session_id = args[0]
        elif 'session_id' in kwargs:
            session_id = kwargs['session_id']
        self.tasks[task_id] = task_info
        if session_id:
            if session_id not in self.session_tasks:
                self.session_tasks[session_id] = []
            self.session_tasks[session_id].append(task_id)
        future = self.executor.submit(self._execute_task, task_info, func, *args, **kwargs)
        task_info.future = future
        logger.info(f"任务提交成功: {task_id} - {func.__name__}")
        return task_id

    def _execute_task(self, task_info: TaskInfo, func: Callable, *args, **kwargs):
        """Worker-side wrapper: run *func* and record the outcome on *task_info*."""
        try:
            task_info.status = TaskStatus.RUNNING
            task_info.started_at = datetime.now()
            logger.info(f"任务开始执行: {task_info.id} - {task_info.func_name}")
            result = func(*args, **kwargs)
            task_info.status = TaskStatus.COMPLETED
            task_info.completed_at = datetime.now()
            task_info.result = result
            logger.info(f"任务执行成功: {task_info.id}")
            return result
        except Exception as e:
            task_info.status = TaskStatus.FAILED
            task_info.completed_at = datetime.now()
            task_info.error = str(e)
            logger.error(f"任务执行失败: {task_info.id} - {e}")
            # Re-raise so the Future also reflects the failure.
            raise

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return a JSON-friendly status dict, or None for unknown ids."""
        if task_id not in self.tasks:
            return None
        task_info = self.tasks[task_id]
        return {
            "task_id": task_info.id,
            "status": task_info.status.value,
            "func_name": task_info.func_name,
            "created_at": task_info.created_at.isoformat(),
            "started_at": task_info.started_at.isoformat() if task_info.started_at else None,
            "completed_at": task_info.completed_at.isoformat() if task_info.completed_at else None,
            "error": task_info.error
        }

    def get_session_tasks(self, session_id: str) -> List[Dict[str, Any]]:
        """Return status dicts for every task recorded against *session_id*."""
        # Fetch each status exactly once (the previous version called
        # get_task_status twice per id); unknown ids yield None and drop out.
        statuses = (self.get_task_status(task_id)
                    for task_id in self.session_tasks.get(session_id, []))
        return [status for status in statuses if status]

    def cancel_task(self, task_id: str) -> bool:
        """Try to cancel a task; only tasks not yet started can be cancelled."""
        if task_id not in self.tasks:
            return False
        task_info = self.tasks[task_id]
        if task_info.future and not task_info.future.done():
            # Future.cancel() fails once the task is already running.
            cancelled = task_info.future.cancel()
            if cancelled:
                task_info.status = TaskStatus.CANCELLED
                task_info.completed_at = datetime.now()
                logger.info(f"任务已取消: {task_id}")
                return True
        return False

    def cancel_session_tasks(self, session_id: str) -> int:
        """Cancel all cancellable tasks of a session; return how many were cancelled."""
        task_ids = self.session_tasks.get(session_id, [])
        cancelled_count = 0
        for task_id in task_ids:
            if self.cancel_task(task_id):
                cancelled_count += 1
        return cancelled_count

    def cleanup_old_tasks(self, hours: int = 24):
        """Drop finished tasks older than *hours*; return how many were removed."""
        cutoff_time = datetime.now().timestamp() - (hours * 3600)
        tasks_to_remove = []
        for task_id, task_info in self.tasks.items():
            if task_info.completed_at and task_info.completed_at.timestamp() < cutoff_time:
                tasks_to_remove.append(task_id)
        for task_id in tasks_to_remove:
            del self.tasks[task_id]
            # Also drop the id from any session index that references it.
            for session_id, task_ids in self.session_tasks.items():
                if task_id in task_ids:
                    task_ids.remove(task_id)
        logger.info(f"清理了 {len(tasks_to_remove)} 个旧任务")
        return len(tasks_to_remove)

    def shutdown(self):
        """Block until running tasks finish, then shut the pool down."""
        self.executor.shutdown(wait=True)
        logger.info("任务管理器已关闭")
# Global singleton instance shared by the whole application.
task_manager = TaskManager()
# Decorator: turn an ordinary function into a background task.
def async_task(func):
    """Decorator that turns *func* into a fire-and-forget background task.

    Calling the decorated function submits it to the global
    ``task_manager`` thread pool and returns the task id immediately.
    ``wrapper.delay`` is aliased to the wrapper itself for drop-in
    compatibility with Celery's ``task.delay(...)`` call style.
    """
    import functools  # local import: the module does not otherwise need it

    @functools.wraps(func)  # preserves __name__, __doc__, __module__, __qualname__, ...
    def wrapper(*args, **kwargs):
        return task_manager.submit_task(func, *args, **kwargs)

    # Celery compatibility: task.delay(...) == task(...).
    wrapper.delay = wrapper
    return wrapper

417
app/static/css/style.css Normal file
View File

@ -0,0 +1,417 @@
/* app/static/css/style.css */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background-color: #f5f7fa;
color: #2c3e50;
}
.app {
min-height: 100vh;
display: flex;
flex-direction: column;
}
/* Header */
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 20px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.header h1 {
font-size: 24px;
font-weight: 600;
}
.progress-bar {
margin-top: 10px;
background: rgba(255,255,255,0.3);
height: 8px;
border-radius: 4px;
overflow: hidden;
}
.progress-fill {
height: 100%;
background: white;
transition: width 0.3s ease;
}
.progress-message {
font-size: 14px;
opacity: 0.9;
margin-top: 5px;
}
/* Main Container */
.main-container {
flex: 1;
padding: 30px;
}
/* Start Screen */
.start-screen {
max-width: 800px;
margin: 0 auto;
}
.start-card {
background: white;
border-radius: 12px;
padding: 40px;
box-shadow: 0 4px 20px rgba(0,0,0,0.08);
}
.start-card h2 {
font-size: 28px;
margin-bottom: 30px;
text-align: center;
}
.input-group {
display: flex;
gap: 10px;
margin-bottom: 40px;
}
.question-input {
flex: 1;
padding: 15px 20px;
font-size: 16px;
border: 2px solid #e0e0e0;
border-radius: 8px;
transition: border-color 0.3s ease;
}
.question-input:focus {
outline: none;
border-color: #667eea;
}
.start-button {
padding: 15px 30px;
background: #667eea;
color: white;
border: none;
border-radius: 8px;
font-size: 16px;
font-weight: 500;
cursor: pointer;
transition: all 0.3s ease;
display: flex;
align-items: center;
gap: 8px;
}
.start-button:hover {
background: #5a67d8;
transform: translateY(-1px);
}
.start-button:disabled {
background: #cbd5e0;
cursor: not-allowed;
}
/* History Section */
.history-section {
margin-top: 40px;
padding-top: 40px;
border-top: 1px solid #e0e0e0;
}
.session-item {
background: #f8fafc;
border: 1px solid #e5e7eb;
border-radius: 8px;
padding: 16px;
margin-bottom: 12px;
cursor: pointer;
transition: all 0.3s ease;
}
.session-item:hover {
border-color: #667eea;
box-shadow: 0 2px 8px rgba(0,0,0,0.05);
}
/* Tree Container */
.tree-container {
background: white;
border-radius: 12px;
padding: 30px;
box-shadow: 0 4px 20px rgba(0,0,0,0.08);
overflow-x: auto;
min-width: 800px;
}
/* Tree Structure */
.tree-node {
position: relative;
padding-left: 30px;
margin: 10px 0;
}
.tree-node::before {
content: '';
position: absolute;
left: 0;
top: -10px;
width: 1px;
height: calc(100% + 20px);
background: #e0e0e0;
}
.tree-node:last-child::before {
height: 30px;
}
.tree-node::after {
content: '';
position: absolute;
left: 0;
top: 20px;
width: 20px;
height: 1px;
background: #e0e0e0;
}
/* Root Node */
.root-node {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 20px;
border-radius: 10px;
margin-bottom: 30px;
}
.root-node h2 {
font-size: 20px;
margin-bottom: 8px;
}
/* Node Card */
.node-card {
background: white;
border: 2px solid #e0e0e0;
border-radius: 8px;
padding: 15px;
cursor: pointer;
transition: all 0.3s ease;
display: inline-block;
min-width: 300px;
}
.node-card:hover {
border-color: #667eea;
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.15);
}
.node-card.processing {
border-color: #3b82f6;
animation: pulse 2s infinite;
}
.node-card.completed {
border-color: #10b981;
}
.node-card.error {
border-color: #ef4444;
}
@keyframes pulse {
0% { box-shadow: 0 0 0 0 rgba(59, 130, 246, 0.4); }
70% { box-shadow: 0 0 0 8px rgba(59, 130, 246, 0); }
100% { box-shadow: 0 0 0 0 rgba(59, 130, 246, 0); }
}
/* Node Content */
.node-header {
display: flex;
align-items: center;
justify-content: space-between;
}
.node-title {
font-weight: 600;
font-size: 16px;
flex: 1;
}
.node-status {
display: flex;
align-items: center;
gap: 8px;
}
.status-icon {
width: 24px;
height: 24px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
font-size: 14px;
}
.status-icon.completed {
background: #10b981;
color: white;
}
.status-icon.processing {
background: #3b82f6;
color: white;
}
.status-icon.pending {
background: #9ca3af;
color: white;
}
.expand-icon {
transition: transform 0.3s ease;
margin-right: 8px;
}
.expand-icon.expanded {
transform: rotate(90deg);
}
.node-content {
margin-top: 15px;
padding-top: 15px;
border-top: 1px solid #f0f0f0;
max-height: 0;
overflow: hidden;
opacity: 0;
transition: all 0.3s ease;
}
.node-content.expanded {
max-height: 1000px;
opacity: 1;
}
/* Phase Card */
.phase-card {
background: #f8fafc;
border: 1px solid #e5e7eb;
border-radius: 6px;
padding: 12px;
margin: 8px 0;
}
.phase-card h4 {
font-size: 14px;
margin-bottom: 8px;
color: #4b5563;
}
/* Action Buttons */
.action-buttons {
position: fixed;
bottom: 30px;
right: 30px;
display: flex;
gap: 12px;
}
.action-button {
background: white;
border: 1px solid #e5e7eb;
padding: 12px 20px;
border-radius: 8px;
cursor: pointer;
font-size: 14px;
transition: all 0.3s ease;
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
}
.action-button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(0,0,0,0.12);
}
.action-button.primary {
background: #667eea;
color: white;
border-color: #667eea;
}
.action-button.danger {
background: #ef4444;
color: white;
border-color: #ef4444;
}
/* Detail Panel */
.detail-panel {
position: fixed;
right: -400px;
top: 0;
width: 400px;
height: 100vh;
background: white;
box-shadow: -4px 0 20px rgba(0,0,0,0.1);
transition: right 0.3s ease;
z-index: 1000;
overflow-y: auto;
}
.detail-panel.open {
right: 0;
}
.panel-header {
padding: 20px;
border-bottom: 1px solid #e5e7eb;
display: flex;
justify-content: space-between;
align-items: center;
}
.panel-close {
background: none;
border: none;
font-size: 24px;
cursor: pointer;
color: #6b7280;
}
/* Loading */
.loading-overlay {
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: rgba(0, 0, 0, 0.5);
display: flex;
align-items: center;
justify-content: center;
z-index: 9999;
}
.loading-spinner {
width: 50px;
height: 50px;
border: 5px solid #f3f3f3;
border-top: 5px solid #667eea;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}

51
app/static/js/api.js Normal file
View File

@ -0,0 +1,51 @@
// app/static/js/api.js
const API_BASE = '/api';

// Thin wrapper around the backend REST API; every method returns the
// parsed JSON body (callers handle error payloads themselves).
const api = {
    // Create a new research session; auto_start launches it immediately.
    createResearch: async (question) => {
        const response = await fetch(`${API_BASE}/research`, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify({
                question: question,
                auto_start: true
            })
        });
        return response.json();
    },
    // List past research sessions (paged).
    getSessions: async (limit = 20, offset = 0) => {
        const response = await fetch(`${API_BASE}/research/sessions?limit=${limit}&offset=${offset}`);
        return response.json();
    },
    // Fetch the current status/progress of one session.
    getSessionStatus: async (sessionId) => {
        const response = await fetch(`${API_BASE}/research/${sessionId}/status`);
        return response.json();
    },
    // Fetch the research outline of one session.
    getOutline: async (sessionId) => {
        const response = await fetch(`${API_BASE}/research/${sessionId}/outline`);
        return response.json();
    },
    // Ask the backend to cancel a running session.
    cancelResearch: async (sessionId) => {
        const response = await fetch(`${API_BASE}/research/${sessionId}/cancel`, {
            method: 'POST'
        });
        return response.json();
    },
    // Open the markdown report in a new tab (the browser handles the download).
    downloadReport: async (sessionId) => {
        window.open(`${API_BASE}/research/${sessionId}/report?format=markdown`, '_blank');
    }
};

87
app/static/js/index.js Normal file
View File

@ -0,0 +1,87 @@
// app/static/js/index.js
// Populate the session-history list as soon as the page is ready.
document.addEventListener('DOMContentLoaded', function() {
    loadSessions();
});
// Validate the question input, create a research session via the API,
// and navigate to its live research page on success.
async function startResearch() {
    const input = document.getElementById('questionInput');
    const question = input.value.trim();
    if (!question) {
        alert('请输入研究问题');
        return;
    }
    const startBtn = document.getElementById('startBtn');
    const loading = document.getElementById('loading');
    // Disable the button and show the spinner while the request is in flight.
    startBtn.disabled = true;
    loading.style.display = 'flex';
    try {
        const result = await api.createResearch(question);
        if (result.session_id) {
            // Jump to the live research view for the new session.
            window.location.href = `/research/${result.session_id}`;
        } else {
            alert('创建研究失败: ' + (result.error || '未知错误'));
        }
    } catch (error) {
        console.error('Error:', error);
        alert('创建研究失败,请重试');
    } finally {
        startBtn.disabled = false;
        loading.style.display = 'none';
    }
}
// Fetch recent sessions and render them as clickable history items;
// the history section stays hidden when there are none. Load failures
// are logged but do not disturb the page.
async function loadSessions() {
    try {
        const data = await api.getSessions();
        if (data.sessions && data.sessions.length > 0) {
            const historySection = document.getElementById('historySection');
            const sessionList = document.getElementById('sessionList');
            historySection.style.display = 'block';
            sessionList.innerHTML = '';
            data.sessions.forEach(session => {
                const item = document.createElement('div');
                item.className = 'session-item';
                // Clicking a history entry opens that session's page.
                item.onclick = () => {
                    window.location.href = `/research/${session.id}`;
                };
                item.innerHTML = `
                    <div class="session-question">${session.question}</div>
                    <div class="session-meta">
                        <span class="status-badge ${session.status}">${getStatusText(session.status)}</span>
                        <span class="session-date">${new Date(session.created_at).toLocaleDateString()}</span>
                    </div>
                `;
                sessionList.appendChild(item);
            });
        }
    } catch (error) {
        console.error('Failed to load sessions:', error);
    }
}
// Map a backend status code to its human-readable (Chinese) label;
// unknown codes are returned unchanged.
function getStatusText(status) {
    const STATUS_LABELS = {
        'pending': '等待中',
        'analyzing': '分析中',
        'outlining': '制定大纲',
        'researching': '研究中',
        'writing': '撰写中',
        'reviewing': '审核中',
        'completed': '已完成',
        'error': '错误',
        'cancelled': '已取消'
    };
    return STATUS_LABELS[status] || status;
}

View File

@ -0,0 +1,149 @@
// app/static/js/research-tree.js
// Escape &, < and > so externally-supplied text can be interpolated into
// innerHTML template strings without injecting markup (XSS hardening).
function escapeTreeText(text) {
    return String(text)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;');
}

/**
 * Render the research tree for a session into #treeContainer.
 * @param {Object} session        Session status (question, status, timestamps...).
 * @param {Object|null} outline   Outline with sub_topics, or null before it exists.
 */
function renderTree(session, outline) {
    const container = document.getElementById('treeContainer');
    if (!container) {
        console.error('Tree container not found!');
        return;
    }
    container.innerHTML = '';
    // Root node - always shown.
    const rootNode = document.createElement('div');
    rootNode.className = 'root-node';
    rootNode.innerHTML = `
        <h2>${escapeTreeText(session.question || '研究问题')}</h2>
        <p>状态${getStatusText(session.status)} |
        开始时间${new Date(session.created_at).toLocaleString()}</p>
        ${session.error_message ? `<p style="color: #ff6b6b;">错误: ${escapeTreeText(session.error_message)}</p>` : ''}
    `;
    container.appendChild(rootNode);
    // When the session failed, show an error node and stop.
    if (session.status === 'error') {
        const errorNode = createTreeNode('研究出现错误', 'error');
        container.appendChild(wrapInTreeNode(errorNode));
        return;
    }
    // Research-preparation node (expands once refined questions exist).
    const prepNode = createTreeNode('研究准备', session.refined_questions ? 'completed' : session.status);
    prepNode.onclick = () => showDetail('preparation', session);
    if (session.refined_questions) {
        const content = document.createElement('div');
        content.className = 'node-content expanded';
        content.innerHTML = `
            <div class="phase-card">
                <h4>🎯 问题细化</h4>
                <ul>
                    ${session.refined_questions.map(q => `<li>• ${escapeTreeText(q)}</li>`).join('')}
                </ul>
            </div>
        `;
        prepNode.querySelector('.node-card').appendChild(content);
    }
    container.appendChild(wrapInTreeNode(prepNode));
    // Outline node plus one child node per sub-topic.
    if (outline) {
        const outlineNode = createTreeNode('研究大纲', 'completed');
        const outlineContent = document.createElement('div');
        outlineContent.className = 'node-content expanded';
        outlineContent.innerHTML = `
            <div class="phase-card">
                <h4>📋 主要研究问题</h4>
                <ul>
                    ${outline.research_questions.map(q => `<li>• ${escapeTreeText(q)}</li>`).join('')}
                </ul>
            </div>
        `;
        outlineNode.querySelector('.node-card').appendChild(outlineContent);
        const outlineWrapper = wrapInTreeNode(outlineNode);
        container.appendChild(outlineWrapper);
        outline.sub_topics.forEach((subtopic, idx) => {
            const subtopicNode = createSubtopicNode(subtopic, idx + 1);
            outlineWrapper.appendChild(wrapInTreeNode(subtopicNode, true));
        });
    } else {
        // Outline not created yet: show a processing/pending placeholder.
        const outlineStatus = session.status === 'outlining' ? 'processing' :
                              session.status === 'error' ? 'error' : 'pending';
        const outlineNode = createTreeNode('研究大纲', outlineStatus);
        container.appendChild(wrapInTreeNode(outlineNode));
    }
    // Final-report node.
    const reportNode = createTreeNode('研究报告生成', session.final_report ? 'completed' : 'pending');
    container.appendChild(wrapInTreeNode(reportNode));
}
/**
 * Build a generic tree-node card with a title and a status icon.
 * Titles are internal literals; status styling comes from getStatusInfo.
 */
function createTreeNode(title, status) {
    const node = document.createElement('div');
    const statusInfo = getStatusInfo(status);
    node.innerHTML = `
        <div class="node-card ${statusInfo.className}">
            <div class="node-header">
                <div class="node-title-wrapper">
                    <span class="node-title">${title}</span>
                </div>
                <div class="node-status">
                    <span class="status-icon ${statusInfo.className}">${statusInfo.icon}</span>
                </div>
            </div>
        </div>
    `;
    return node;
}

/**
 * Build the card for one sub-topic (topic title, priority badge, status icon).
 *
 * Fix: the click handler is now attached as a DOM property instead of
 * serializing the subtopic into an inline onclick attribute - the old
 * JSON.stringify/&quot; approach produced broken HTML whenever the topic
 * contained quotes or newlines. The topic text is also escaped before
 * interpolation into innerHTML.
 */
function createSubtopicNode(subtopic, index) {
    const node = document.createElement('div');
    const statusInfo = getStatusInfo(subtopic.status);
    const priorityLabel = subtopic.priority === 'high' ? '高'
        : subtopic.priority === 'medium' ? '中' : '低';
    const safeTopic = String(subtopic.topic)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;');
    node.innerHTML = `
        <div class="node-card ${statusInfo.className}">
            <div class="node-header">
                <div class="node-title-wrapper">
                    <span class="node-title">子主题${index}${safeTopic}</span>
                </div>
                <div class="node-status">
                    <span class="priority-badge ${subtopic.priority}">
                        ${priorityLabel}优先级
                    </span>
                    <span class="status-icon ${statusInfo.className}">${statusInfo.icon}</span>
                </div>
            </div>
        </div>
    `;
    // Property handler keeps the live subtopic object; no serialization needed.
    node.querySelector('.node-card').onclick = () => showDetail('subtopic', subtopic);
    return node;
}
/** Wrap a node element in the standard tree-node container div. */
function wrapInTreeNode(node, isSubtopic = false) {
    const wrapper = document.createElement('div');
    wrapper.className = isSubtopic ? 'tree-node subtopic-node' : 'tree-node';
    wrapper.appendChild(node);
    return wrapper;
}

/**
 * Icon and CSS class for a status code.
 * Unknown codes render like 'pending'.
 */
function getStatusInfo(status) {
    const infoByStatus = {
        pending:     { icon: '○', className: 'pending' },
        analyzing:   { icon: '●', className: 'processing' },
        outlining:   { icon: '●', className: 'processing' },
        researching: { icon: '●', className: 'processing' },
        writing:     { icon: '●', className: 'processing' },
        reviewing:   { icon: '●', className: 'processing' },
        completed:   { icon: '✓', className: 'completed' },
        error:       { icon: '✗', className: 'error' },
        cancelled:   { icon: '⊘', className: 'cancelled' }
    };
    return infoByStatus[status] || infoByStatus.pending;
}

156
app/static/js/research.js Normal file
View File

@ -0,0 +1,156 @@
// app/static/js/research.js
// Shared page state for the research view.
let socket = null;            // live Socket.IO connection
let currentSession = null;    // most recently fetched session status

document.addEventListener('DOMContentLoaded', () => {
    initWebSocket();
    loadSessionData();
    // Poll as a safety net in case a WebSocket event is missed.
    setInterval(loadSessionData, 3000);
});
/** Open the Socket.IO connection and wire up the server push events. */
function initWebSocket() {
    socket = io();
    socket.on('connect', () => {
        console.log('WebSocket connected');
        // Join this session's room so we receive its events.
        socket.emit('join_session', { session_id: SESSION_ID });
    });
    socket.on('progress_update', (data) => {
        updateProgress(data.percentage, data.message);
    });
    // Any server-side state change triggers a full refresh.
    socket.on('status_changed', () => loadSessionData());
    socket.on('subtopic_updated', () => loadSessionData());
    socket.on('report_available', () => {
        document.getElementById('downloadBtn').style.display = 'block';
        document.getElementById('cancelBtn').style.display = 'none';
    });
}
/**
 * Fetch the latest session status and re-render the research tree.
 * Runs on page load, on every WebSocket event, and on the 3s poll timer.
 */
async function loadSessionData() {
    try {
        const status = await api.getSessionStatus(SESSION_ID);
        currentSession = status;
        updateProgress(status.progress_percentage || 0, status.current_phase || '准备中');
        // Always render the base tree first so the page is never empty.
        renderTree(status, null);
        // Once past the analysis phases, try to overlay the outline.
        const beforeOutline = status.status === 'pending' || status.status === 'analyzing';
        if (!beforeOutline) {
            try {
                const outline = await api.getOutline(SESSION_ID);
                renderTree(status, outline);
            } catch (error) {
                console.log('大纲尚未创建');
            }
        }
        // Swap the cancel button for the download button on completion.
        if (status.status === 'completed') {
            document.getElementById('downloadBtn').style.display = 'block';
            document.getElementById('cancelBtn').style.display = 'none';
        }
    } catch (error) {
        console.error('Failed to load session data:', error);
    }
}
/** Update the progress-bar fill width and its status message. */
function updateProgress(percentage, message) {
    document.getElementById('progressFill').style.width = `${percentage}%`;
    document.getElementById('progressMessage').textContent = message;
}

/** Ask for confirmation, then cancel the running research session. */
async function cancelResearch() {
    if (!confirm('确定要取消当前研究吗?')) {
        return;
    }
    try {
        await api.cancelResearch(SESSION_ID);
        alert('研究已取消');
        window.location.href = '/';
    } catch (error) {
        alert('取消失败,请重试');
    }
}

/** Trigger the report download for the current session. */
async function downloadReport() {
    api.downloadReport(SESSION_ID);
}
/**
 * Render the slide-out detail panel for a clicked tree node.
 * @param {string} type  'preparation' (session data) or 'subtopic'.
 * @param {Object} data  Session object or subtopic object matching `type`.
 *
 * Security fix: topic/explain/refined-question text originates from
 * external content (user input, AI output, search results); it is now
 * escaped before interpolation into innerHTML.
 */
function showDetail(type, data) {
    const panel = document.getElementById('detailPanel');
    const content = document.getElementById('panelContent');
    // Escape &, < and > so untrusted text stays text inside innerHTML.
    const esc = (s) => String(s)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;');
    let html = '';
    switch (type) {
        case 'preparation':
            html = `
                <h4>研究准备</h4>
                ${data.refined_questions ? `
                    <div class="detail-section">
                        <h5>细化的问题</h5>
                        <ul>
                            ${data.refined_questions.map(q => `<li>${esc(q)}</li>`).join('')}
                        </ul>
                    </div>
                ` : ''}
                ${data.research_approach ? `
                    <div class="detail-section">
                        <h5>研究思路</h5>
                        <p>${esc(data.research_approach)}</p>
                    </div>
                ` : ''}
            `;
            break;
        case 'subtopic':
            html = `
                <h4>${esc(data.topic)}</h4>
                <p>${esc(data.explain)}</p>
                <div class="meta-info">
                    <span>优先级${esc(data.priority)}</span>
                    <span>状态${getStatusText(data.status)}</span>
                </div>
            `;
            break;
        default:
            html = '<p>暂无详细信息</p>';
    }
    content.innerHTML = html;
    panel.classList.add('open');
}
/** Hide the slide-out detail panel. */
function closePanel() {
    document.getElementById('detailPanel').classList.remove('open');
}

/**
 * Map a research-session status code to its Chinese display label.
 * Unknown codes are returned unchanged.
 */
function getStatusText(status) {
    const labels = {
        pending: '等待中',
        analyzing: '分析中',
        outlining: '制定大纲',
        researching: '研究中',
        writing: '撰写中',
        reviewing: '审核中',
        completed: '已完成',
        error: '错误',
        cancelled: '已取消'
    };
    return labels[status] || status;
}

8
app/tasks/__init__.py Normal file
View File

@ -0,0 +1,8 @@
# File location: app/tasks/__init__.py
# File name: __init__.py
"""
Task package.
"""
# Deliberately left empty to avoid circular imports between the task
# functions and the services they use.

471
app/tasks/research_tasks.py Normal file
View File

@ -0,0 +1,471 @@
# 文件位置: app/tasks/research_tasks.py
# 文件名: research_tasks.py
"""
研究相关的异步任务
使用线程池替代Celery
"""
import logging
from typing import Dict, List, Any
from app.services.task_manager import async_task
from app.models.research import ResearchStatus, Subtopic
logger = logging.getLogger(__name__)
@async_task
def analyze_question_chain(session_id: str):
    """First stage of the research pipeline: analyze the user's question.

    Attaches the AI debug logger to the session, runs the analysis via the
    research manager, then chains into outline creation. On failure the
    session is marked as errored and the exception re-raised.
    """
    try:
        # Attach debug logging first so every model call is captured.
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)
        logger.info(f"启用调试模式: {session_id}")

        # Imported inside the function to avoid a circular import.
        from app.services.research_manager import ResearchManager

        manager = ResearchManager()
        if not manager.get_session(session_id):
            raise ValueError(f"Session not found: {session_id}")

        _emit_status(session_id, ResearchStatus.ANALYZING, "分析问题")
        manager.process_question_analysis(session_id)
        _emit_progress(session_id, 20, "问题分析完成")

        # Hand off to the next pipeline stage.
        create_outline_task.delay(session_id)
    except Exception as e:
        logger.error(f"问题分析失败: {e}")
        _handle_task_error(session_id, str(e))
        raise
@async_task
def create_outline_task(session_id: str):
    """Outline stage: build the research outline, then fan out one research
    task per sub-topic plus a monitor that waits for all of them."""
    try:
        # Keep debug logging attached to this session.
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)

        # Imported inside the function to avoid a circular import.
        from app.services.research_manager import ResearchManager

        manager = ResearchManager()
        _emit_status(session_id, ResearchStatus.OUTLINING, "制定大纲")
        manager.process_outline_creation(session_id)
        _emit_progress(session_id, 30, "大纲制定完成")

        # Re-read the session to pick up the freshly created outline.
        session = manager.get_session(session_id)
        if session.outline and session.outline.sub_topics:
            # Launch the sub-topic research tasks concurrently, remembering
            # their ids so the monitor can wait on all of them before the
            # final report is generated.
            subtopic_task_ids = [
                research_subtopic.delay(session_id, sub.id)
                for sub in session.outline.sub_topics
            ]
            monitor_subtopics_completion.delay(session_id, subtopic_task_ids)
    except Exception as e:
        logger.error(f"创建大纲失败: {e}")
        _handle_task_error(session_id, str(e))
        raise
@async_task
def research_subtopic(session_id: str, subtopic_id: str):
    """Run the full research pipeline for a single sub-topic.

    Stages: query generation -> search + relevance evaluation -> reflection
    -> refined searches -> information integration -> report writing ->
    hallucination detection/fixing. Per-stage progress is pushed over
    WebSocket.

    Returns a summary dict on success. On failure the sub-topic is marked
    ERROR (best effort) and the exception re-raised.

    Bug fixes versus the original:
    * The refined-search loop shadowed ``queries``, so the returned
      ``search_count`` reflected the last refined batch instead of the
      initial query list.
    * Refined search results never reached integration: the old code
      flattened ``subtopic.refined_searches`` with ``batch.get('results',
      [])``, but those entries are flat dicts with no ``'results'`` key.
      The evaluated result objects are now accumulated directly.
    """
    try:
        # Attach the AI debug logger to this session.
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)
        logger.info(f"开始研究子主题: {subtopic_id}")

        # Imported inside the function to avoid circular imports.
        from app.services.research_manager import ResearchManager
        from app.services.ai_service import AIService
        from app.services.search_service import SearchService

        research_manager = ResearchManager()
        ai_service = AIService()
        search_service = SearchService()

        # Locate the session and the requested sub-topic.
        session = research_manager.get_session(session_id)
        if not session or not session.outline:
            raise ValueError("Session or outline not found")
        subtopic = next(
            (st for st in session.outline.sub_topics if st.id == subtopic_id),
            None
        )
        if not subtopic:
            raise ValueError(f"Subtopic not found: {subtopic_id}")

        # Mark the sub-topic as in progress.
        subtopic.status = ResearchStatus.RESEARCHING
        research_manager.update_session(session_id, {'outline': session.outline})
        _emit_subtopic_progress(session_id, subtopic_id, 0, "researching")

        # 1. Generate search queries.
        queries = ai_service.generate_search_queries(
            subtopic.topic,
            subtopic.explain,
            subtopic.related_questions,
            subtopic.priority
        )

        # 2. Run the searches and evaluate result relevance.
        logger.info(f"开始搜索子主题 {subtopic.topic}: {len(queries)} 个查询")
        search_results = []
        for i, query in enumerate(queries):
            try:
                response = search_service.search(query)
                results = response.to_search_results()
                evaluated_results = ai_service.evaluate_search_results(
                    subtopic.topic, results
                )
                search_results.extend(evaluated_results)
                # Searching accounts for the first 50% of sub-topic progress.
                progress = (i + 1) / len(queries) * 50
                _emit_subtopic_progress(session_id, subtopic_id, progress, "searching")
            except Exception as e:
                logger.error(f"搜索失败 '{query}': {e}")

        # De-duplicate by URL (later duplicates win).
        unique_results = list({r.url: r for r in search_results}.values())
        subtopic.searches = [
            {
                "url": r.url,
                "title": r.title,
                "snippet": r.snippet,
                "importance": r.importance.value if r.importance else "medium"
            }
            for r in unique_results
        ]

        # 3. Reflect on the gathered information to find gaps.
        key_points = ai_service.reflect_on_information(subtopic.topic, unique_results)

        # Collect refined-search result objects so they feed integration
        # and hallucination checking below.
        refined_result_objects = []
        if key_points:
            # 4. Generate refined queries per key information gap.
            refined_queries_map = ai_service.generate_refined_queries(key_points)
            # 5. Run the refined searches (distinct loop variable - do not
            # shadow `queries`, which is reported in the return value).
            for key_info, refined_queries in refined_queries_map.items():
                refined_batch = search_service.refined_search(
                    subtopic_id, key_info, refined_queries
                )
                evaluated_refined = ai_service.evaluate_search_results(
                    subtopic.topic, refined_batch.results
                )
                refined_result_objects.extend(evaluated_refined)
                subtopic.refined_searches.extend([
                    {
                        "key_info": key_info,
                        "url": r.url,
                        "title": r.title,
                        "snippet": r.snippet,
                        "importance": r.importance.value if r.importance else "medium"
                    }
                    for r in evaluated_refined
                ])

        _emit_subtopic_progress(session_id, subtopic_id, 70, "integrating")

        # 6. Integrate information from both search rounds.
        all_results = unique_results + refined_result_objects
        integrated_info = ai_service.integrate_information(subtopic.topic, all_results)
        subtopic.integrated_info = integrated_info

        # 7. Write the sub-topic report.
        _emit_subtopic_progress(session_id, subtopic_id, 80, "writing")
        report_content = ai_service.write_subtopic_report(subtopic.topic, integrated_info)

        # 8. Detect and fix hallucinations against the source snippets.
        _emit_subtopic_progress(session_id, subtopic_id, 90, "reviewing")
        url_content_map = {result.url: result.snippet for result in all_results}
        fixed_report, hallucinations = ai_service.detect_and_fix_hallucinations(
            report_content, url_content_map
        )
        subtopic.report = fixed_report
        subtopic.hallucination_checks = hallucinations
        subtopic.status = ResearchStatus.COMPLETED

        # Persist the finished sub-topic and run post-processing.
        research_manager.update_session(session_id, {'outline': session.outline})
        research_manager.process_subtopic_research(session_id, subtopic_id)
        _emit_subtopic_progress(session_id, subtopic_id, 100, "completed")
        logger.info(f"子主题研究完成: {subtopic_id}")
        return {
            "subtopic_id": subtopic_id,
            "status": "completed",
            "search_count": len(queries),
            "results_count": len(unique_results),
            "hallucinations_fixed": len(hallucinations)
        }
    except Exception as e:
        logger.error(f"子主题研究失败 {subtopic_id}: {e}")
        # Best effort: flag the sub-topic as failed and notify clients.
        try:
            from app.services.research_manager import ResearchManager
            research_manager = ResearchManager()
            session = research_manager.get_session(session_id)
            if session and session.outline:
                for st in session.outline.sub_topics:
                    if st.id == subtopic_id:
                        st.status = ResearchStatus.ERROR
                        break
                research_manager.update_session(session_id, {'outline': session.outline})
            _emit_subtopic_progress(session_id, subtopic_id, -1, "error")
        except Exception:
            pass
        raise
@async_task
def monitor_subtopics_completion(session_id: str, task_ids: List[str]):
    """Poll the sub-topic tasks until all finish (or a 30-minute timeout),
    then start final-report generation if at least one succeeded."""
    import time
    from app.services.task_manager import task_manager
    try:
        # Keep debug logging attached to this session.
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)

        deadline = time.time() + 1800  # allow the sub-topics 30 minutes
        failed_count = 0
        while True:
            failed_count = 0
            still_running = False
            for task_id in task_ids:
                info = task_manager.get_task_status(task_id)
                if not info:
                    # NOTE(review): a missing status record counts as
                    # finished (matches the original behaviour) - confirm
                    # task_manager never drops records for live tasks.
                    continue
                if info['status'] in ('running', 'pending'):
                    still_running = True
                elif info['status'] == 'failed':
                    failed_count += 1
            if not still_running:
                break
            if time.time() > deadline:
                logger.error(f"等待子主题完成超时: {session_id}")
                break
            time.sleep(5)  # poll interval

        # Generate the final report as long as not every sub-topic failed.
        if failed_count < len(task_ids):
            generate_final_report_task.delay(session_id)
        else:
            _handle_task_error(session_id, "所有子主题研究失败")
    except Exception as e:
        logger.error(f"监控子主题完成失败: {e}")
        _handle_task_error(session_id, str(e))
@async_task
def generate_final_report_task(session_id: str):
    """Assemble and persist the final research report for a session.

    Collects every finished sub-topic report, asks the AI service to merge
    them into one document, saves it to disk, marks the session COMPLETED
    and notifies connected clients. On failure the session is marked as
    errored and the exception re-raised.
    """
    try:
        # Keep debug logging attached to this session.
        from app.utils.debug_logger import ai_debug_logger
        ai_debug_logger.set_session(session_id)
        logger.info(f"开始生成最终报告: {session_id}")
        # Imported inside the function to avoid circular imports.
        from app.services.research_manager import ResearchManager
        from app.services.ai_service import AIService
        from app.services.report_generator import ReportGenerator
        research_manager = ResearchManager()
        ai_service = AIService()
        report_generator = ReportGenerator()
        # Announce the writing phase to clients.
        _emit_status(session_id, ResearchStatus.WRITING, "生成最终报告")
        _emit_progress(session_id, 90, "整合所有子主题报告")
        # Without a session and outline there is nothing to report on.
        session = research_manager.get_session(session_id)
        if not session or not session.outline:
            raise ValueError("Session or outline not found")
        # Collect sub-topic reports: a topic->text map for the AI merge and
        # structured report objects for the report generator.
        subtopic_reports_dict = {}
        subtopic_report_objects = []
        for subtopic in session.outline.sub_topics:
            if subtopic.report:
                subtopic_reports_dict[subtopic.topic] = subtopic.report
                # Build the structured report object for this sub-topic.
                report_obj = report_generator.generate_subtopic_report(
                    subtopic,
                    subtopic.integrated_info or {},
                    subtopic.report
                )
                subtopic_report_objects.append(report_obj)
        # Merge everything into the final report body.
        final_content = ai_service.generate_final_report(
            session.outline.main_topic,
            session.outline.research_questions,
            subtopic_reports_dict
        )
        # Wrap the merged content into the final report object.
        final_report = report_generator.generate_final_report(
            session,
            subtopic_report_objects,
            final_content
        )
        # Persist the report to disk.
        report_path = report_generator.save_report(final_report)
        # Store the markdown on the session and mark it completed.
        session.final_report = final_report.to_markdown()
        session.update_status(ResearchStatus.COMPLETED)
        research_manager.update_session(session_id, {
            'final_report': session.final_report,
            'status': session.status
        })
        research_manager.finalize_research(session_id)
        # Notify clients that the research is finished.
        _emit_progress(session_id, 100, "研究完成")
        _emit_status(session_id, ResearchStatus.COMPLETED, "研究完成")
        _emit_report_ready(session_id, "final")
        logger.info(f"研究完成: {session_id}")
        return {
            "session_id": session_id,
            "status": "completed",
            "report_path": report_path
        }
    except Exception as e:
        logger.error(f"生成最终报告失败: {e}")
        _handle_task_error(session_id, str(e))
        raise
# ========== Helper functions ==========

def _get_socketio():
    """Return the application's SocketIO instance.

    Imported lazily to avoid a circular dependency with the app package.
    """
    from app import socketio
    return socketio


def _emit_progress(session_id: str, percentage: float, message: str):
    """Push an overall progress update to the session's WebSocket room."""
    try:
        # Imported lazily to avoid a circular dependency.
        from app.routes.websocket import emit_progress
        emit_progress(_get_socketio(), session_id, {
            'percentage': percentage,
            'message': message
        })
    except Exception as e:
        logger.error(f"发送进度更新失败: {e}")


def _emit_status(session_id: str, status: ResearchStatus, phase: str):
    """Push a status-change notification to the session's WebSocket room."""
    try:
        from app.routes.websocket import emit_status_change
        emit_status_change(_get_socketio(), session_id, status.value, phase)
    except Exception as e:
        logger.error(f"发送状态更新失败: {e}")


def _emit_subtopic_progress(session_id: str, subtopic_id: str,
                            progress: float, status: str):
    """Push per-sub-topic progress to the session's WebSocket room."""
    try:
        from app.routes.websocket import emit_subtopic_progress
        emit_subtopic_progress(_get_socketio(), session_id, subtopic_id, progress, status)
    except Exception as e:
        logger.error(f"发送子主题进度失败: {e}")


def _emit_report_ready(session_id: str, report_type: str):
    """Notify the session's WebSocket room that a report is downloadable."""
    try:
        from app.routes.websocket import emit_report_ready
        emit_report_ready(_get_socketio(), session_id, report_type)
    except Exception as e:
        logger.error(f"发送报告就绪通知失败: {e}")


def _handle_task_error(session_id: str, error_message: str):
    """Mark the session as failed and broadcast the error to clients."""
    try:
        # Imported inside the function to avoid circular imports.
        from app.services.research_manager import ResearchManager
        from app.routes.websocket import emit_error

        manager = ResearchManager()
        session = manager.get_session(session_id)
        if session:
            session.update_status(ResearchStatus.ERROR)
            session.error_message = error_message
            manager.update_session(session_id, {
                'status': session.status,
                'error_message': error_message
            })
        # Tell connected clients even when the session lookup failed.
        emit_error(_get_socketio(), session_id, error_message)
    except Exception as e:
        logger.error(f"处理任务错误失败: {e}")

26
app/templates/base.html Normal file
View File

@ -0,0 +1,26 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{% block title %}DeepResearch - 智能深度研究系统{% endblock %}</title>
    <link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}">
    {% block extra_css %}{% endblock %}
</head>
<body>
    <!-- Shared page shell: fixed header plus a main container that child
         templates fill via the content / header_content blocks. -->
    <div class="app">
        <header class="header">
            <h1>DeepResearch - 智能深度研究系统</h1>
            {% block header_content %}{% endblock %}
        </header>
        <main class="main-container">
            {% block content %}{% endblock %}
        </main>
    </div>
    <!-- Socket.IO client (pinned CDN version) and the shared API wrapper;
         page-specific scripts are injected through extra_js. -->
    <script src="https://cdn.socket.io/4.5.4/socket.io.min.js"></script>
    <script src="{{ url_for('static', filename='js/api.js') }}"></script>
    {% block extra_js %}{% endblock %}
</body>
</html>

669
app/templates/debug.html Normal file
View File

@ -0,0 +1,669 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>DeepResearch 调试查看器</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #0a0a0a;
color: #e0e0e0;
line-height: 1.6;
padding: 20px;
}
.container {
max-width: 1600px;
margin: 0 auto;
}
h1 {
color: #00ff88;
margin-bottom: 20px;
font-size: 24px;
}
.controls {
background: #1a1a1a;
padding: 15px;
border-radius: 8px;
margin-bottom: 20px;
display: flex;
gap: 15px;
align-items: center;
flex-wrap: wrap;
}
.controls input, .controls select, .controls button {
padding: 8px 15px;
background: #2a2a2a;
border: 1px solid #3a3a3a;
color: #e0e0e0;
border-radius: 4px;
font-size: 14px;
}
.controls input {
flex: 1;
min-width: 300px;
}
.controls button {
cursor: pointer;
background: #00ff88;
color: #000;
font-weight: bold;
transition: all 0.3s;
}
.controls button:hover {
background: #00cc6a;
}
.tabs {
display: flex;
gap: 5px;
margin-bottom: 20px;
border-bottom: 1px solid #3a3a3a;
}
.tab {
padding: 10px 20px;
background: #1a1a1a;
border: none;
color: #888;
cursor: pointer;
border-radius: 4px 4px 0 0;
transition: all 0.3s;
}
.tab.active {
background: #2a2a2a;
color: #00ff88;
}
.log-container {
background: #1a1a1a;
border-radius: 8px;
padding: 20px;
max-height: 80vh;
overflow-y: auto;
}
.log-entry {
background: #2a2a2a;
padding: 15px;
margin-bottom: 15px;
border-radius: 6px;
border-left: 3px solid #00ff88;
}
.log-entry.error {
border-left-color: #ff4444;
}
.log-header {
display: flex;
justify-content: space-between;
margin-bottom: 10px;
font-size: 12px;
color: #888;
}
.log-model {
color: #00ff88;
font-weight: bold;
}
.log-method {
color: #4488ff;
}
.log-content {
margin-top: 10px;
}
.content-section {
margin: 10px 0;
}
.content-label {
font-weight: bold;
color: #4488ff;
margin-bottom: 5px;
display: flex;
justify-content: space-between;
align-items: center;
}
.toggle-btn {
font-size: 12px;
padding: 2px 8px;
background: #3a3a3a;
border: none;
color: #888;
cursor: pointer;
border-radius: 3px;
}
.toggle-btn:hover {
background: #4a4a4a;
}
.content-box {
background: #1a1a1a;
padding: 15px;
border-radius: 6px;
font-family: 'Consolas', 'Monaco', monospace;
font-size: 13px;
white-space: pre-wrap;
word-break: break-word;
overflow-x: auto;
max-height: none;
transition: max-height 0.3s ease;
}
.content-box.collapsed {
max-height: 200px;
overflow: hidden;
position: relative;
}
.content-box.collapsed::after {
content: "... (点击展开查看更多)";
position: absolute;
bottom: 0;
left: 0;
right: 0;
padding: 20px;
background: linear-gradient(transparent, #1a1a1a);
text-align: center;
color: #888;
}
.prompt {
border-left: 3px solid #4488ff;
}
.response {
border-left: 3px solid #00ff88;
}
/* 高亮<think>标签内容 */
.think-content {
background: #1a2a3a;
border: 1px solid #4488ff;
padding: 10px;
margin: 10px 0;
border-radius: 4px;
color: #88bbff;
}
.status {
position: fixed;
bottom: 20px;
right: 20px;
padding: 10px 20px;
background: #2a2a2a;
border-radius: 20px;
font-size: 12px;
}
.status.connected {
background: #00ff88;
color: #000;
}
.no-logs {
text-align: center;
color: #666;
padding: 40px;
}
.metadata {
margin-top: 10px;
font-size: 12px;
color: #666;
}
.download-btn {
position: fixed;
top: 20px;
right: 20px;
background: #4488ff;
color: white;
padding: 10px 20px;
border-radius: 4px;
text-decoration: none;
font-weight: bold;
}
.copy-content-btn {
float: right;
font-size: 12px;
padding: 2px 8px;
background: #3a3a3a;
border: none;
color: #888;
cursor: pointer;
border-radius: 3px;
}
.copy-content-btn:hover {
background: #00ff88;
color: #000;
}
/* 语法高亮 */
.json-key { color: #ff79c6; }
.json-string { color: #f1fa8c; }
.json-number { color: #bd93f9; }
.json-boolean { color: #50fa7b; }
.json-null { color: #ff5555; }
</style>
</head>
<body>
<div class="container">
<h1>🔍 DeepResearch AI 调试查看器</h1>
<div class="controls">
<input type="text" id="sessionId" placeholder="输入会话ID格式uuid">
<select id="logType">
<option value="all">所有日志</option>
<option value="api_calls">API调用</option>
<option value="errors">错误日志</option>
</select>
<button onclick="loadLogs()">加载日志</button>
<button onclick="connectWebSocket()">实时监听</button>
<button onclick="clearLogs()">清空显示</button>
<button onclick="toggleAllContent()">展开/折叠全部</button>
</div>
<div class="tabs">
<button class="tab active" onclick="switchTab('all')">全部</button>
<button class="tab" onclick="switchTab('r1')">R1模型</button>
<button class="tab" onclick="switchTab('v3')">V3模型</button>
<button class="tab" onclick="switchTab('errors')">错误</button>
</div>
<div class="log-container" id="logContainer">
<div class="no-logs">请输入会话ID并点击"加载日志"</div>
</div>
<div class="status" id="status">未连接</div>
<a href="#" class="download-btn" id="downloadBtn" style="display:none;" onclick="downloadLogs()">
📥 下载日志
</a>
</div>
<script src="https://cdn.socket.io/4.5.4/socket.io.min.js"></script>
<script>
// Viewer state shared by all the handlers below.
let socket = null;          // Socket.IO connection (null until live mode is enabled)
let currentSessionId = '';  // session whose logs are currently loaded
let currentTab = 'all';     // active filter tab: all | r1 | v3 | errors
let allLogs = [];           // every log entry fetched or streamed so far
let allExpanded = false;    // state for the expand/collapse-all toggle
/** Fetch debug logs for the session id in the input box and render them. */
function loadLogs() {
    const sessionId = document.getElementById('sessionId').value;
    const logType = document.getElementById('logType').value;
    if (!sessionId) {
        alert('请输入会话ID');
        return;
    }
    currentSessionId = sessionId;
    document.getElementById('downloadBtn').style.display = 'block';
    fetch(`/api/research/${sessionId}/debug?type=${logType}&limit=0`)
        .then(resp => resp.json())
        .then(data => {
            if (data.error) {
                alert('加载失败: ' + data.error);
                return;
            }
            allLogs = data.logs || [];
            displayLogs();
        })
        .catch(err => {
            alert('加载失败: ' + err.message);
        });
}

/** Re-render the log list, filtered by the currently selected tab. */
function displayLogs() {
    const container = document.getElementById('logContainer');
    // Per-tab predicates; the 'all' tab has no predicate and shows everything.
    const filters = {
        r1: log => log.agent_type === 'R1',
        v3: log => log.agent_type === 'V3',
        errors: log => log.type === 'json_parse_error' || log.response?.startsWith('ERROR:')
    };
    const predicate = filters[currentTab];
    const filteredLogs = predicate ? allLogs.filter(predicate) : allLogs;
    if (filteredLogs.length === 0) {
        container.innerHTML = '<div class="no-logs">没有找到相关日志</div>';
        return;
    }
    container.innerHTML = filteredLogs
        .map((log, index) => log.type === 'json_parse_error'
            ? createErrorEntry(log, index)
            : createLogEntry(log, index))
        .join('');
}
/**
 * Build the HTML for one normal API-call log entry.
 * `index` keys the prompt/response element ids used by the copy and
 * expand handlers. Entries whose response starts with "ERROR:" get the
 * error styling.
 */
function createLogEntry(log, index) {
    const isError = log.response?.startsWith('ERROR:');
    return `
        <div class="log-entry ${isError ? 'error' : ''}">
            <div class="log-header">
                <div>
                    <span class="log-model">${log.model}</span> |
                    <span class="log-method">${log.method}</span> |
                    <span>${log.agent_type}</span>
                </div>
                <div>${new Date(log.timestamp).toLocaleString()}</div>
            </div>
            <div class="log-content">
                <div class="content-section">
                    <div class="content-label">
                        <span>Prompt (${log.prompt_length} 字符):</span>
                        <button class="copy-content-btn" onclick="copyContent('${index}_prompt')">📋 复制</button>
                    </div>
                    <div class="content-box prompt" id="${index}_prompt_content" onclick="toggleExpand(this)">
                        ${highlightThinkTags(escapeHtml(log.prompt))}
                    </div>
                </div>
                <div class="content-section">
                    <div class="content-label">
                        <span>Response (${log.response_length} 字符):</span>
                        <button class="copy-content-btn" onclick="copyContent('${index}_response')">📋 复制</button>
                    </div>
                    <div class="content-box response" id="${index}_response_content" onclick="toggleExpand(this)">
                        ${highlightContent(log.response)}
                    </div>
                </div>
                ${log.metadata ? `<div class="metadata">
                    温度: ${log.temperature || 'N/A'} |
                    最大tokens: ${log.max_tokens || 'N/A'} |
                    Prompt tokens: ${log.metadata.prompt_tokens || 'N/A'} |
                    Completion tokens: ${log.metadata.completion_tokens || 'N/A'}
                </div>` : ''}
            </div>
        </div>
    `;
}

/**
 * Build the HTML for a JSON-parse-error entry: the raw model text, the
 * parse error, and (when available) the auto-repaired text.
 */
function createErrorEntry(log, index) {
    return `
        <div class="log-entry error">
            <div class="log-header">
                <div>
                    <span>JSON解析错误</span>
                </div>
                <div>${new Date(log.timestamp).toLocaleString()}</div>
            </div>
            <div class="log-content">
                <div class="content-section">
                    <div class="content-label">
                        <span>原始文本:</span>
                        <button class="copy-content-btn" onclick="copyContent('${index}_raw')">📋 复制</button>
                    </div>
                    <div class="content-box prompt" id="${index}_raw_content" onclick="toggleExpand(this)">
                        ${escapeHtml(log.raw_text)}
                    </div>
                </div>
                <div class="content-section">
                    <div class="content-label">错误:</div>
                    <div class="content-box" style="border-left-color: #ff4444;">
                        ${escapeHtml(log.error)}
                    </div>
                </div>
                ${log.fixed_text ? `<div class="content-section">
                    <div class="content-label">
                        <span>修复后:</span>
                        <button class="copy-content-btn" onclick="copyContent('${index}_fixed')">📋 复制</button>
                    </div>
                    <div class="content-box response" id="${index}_fixed_content" onclick="toggleExpand(this)">
                        ${escapeHtml(log.fixed_text)}
                    </div>
                </div>` : ''}
            </div>
        </div>
    `;
}
/**
 * Enable server-side debug logging for the session, then open a Socket.IO
 * connection, join the session's room and append streamed log entries to
 * the display as they arrive.
 */
function connectWebSocket() {
    const sessionId = document.getElementById('sessionId').value;
    if (!sessionId) {
        alert('请先输入会话ID');
        return;
    }
    // Drop any previous connection before opening a new one.
    if (socket) {
        socket.disconnect();
    }
    // Turn on debug mode for this session on the server first.
    fetch('/api/debug/enable', {
        method: 'POST',
        headers: {'Content-Type': 'application/json'},
        body: JSON.stringify({session_id: sessionId})
    }).then(() => {
        // Connect the WebSocket.
        socket = io();
        socket.on('connect', () => {
            document.getElementById('status').textContent = '已连接';
            document.getElementById('status').classList.add('connected');
            // Join the session's room to receive its events.
            socket.emit('join_session', {session_id: sessionId});
        });
        socket.on('disconnect', () => {
            document.getElementById('status').textContent = '已断开';
            document.getElementById('status').classList.remove('connected');
        });
        // Stream new AI debug log entries.
        socket.on('ai_debug_log', (data) => {
            if (data.log_entry) {
                allLogs.push(data.log_entry);
                displayLogs();
                // Keep the newest entry scrolled into view.
                const container = document.getElementById('logContainer');
                container.scrollTop = container.scrollHeight;
            }
        });
        // Stream JSON parse errors too.
        socket.on('parse_error', (data) => {
            allLogs.push(data);
            displayLogs();
        });
    });
}
/** Switch the active filter tab and re-render the log list. */
function switchTab(tab) {
    currentTab = tab;
    document.querySelectorAll('.tab').forEach(t => t.classList.remove('active'));
    // NOTE(review): relies on the implicit global `event` (window.event) set
    // during inline onclick dispatch - confirm target browsers support it.
    event.target.classList.add('active');
    displayLogs();
}

/** Clear the in-memory log buffer and the display. */
function clearLogs() {
    allLogs = [];
    displayLogs();
}

/** Download the full debug log file for the currently loaded session. */
function downloadLogs() {
    if (!currentSessionId) return;
    window.location.href = `/api/research/${currentSessionId}/debug/download`;
}
/**
 * Escape &, < and > so arbitrary text is safe inside innerHTML.
 * Pure-string implementation: equivalent output to the previous
 * textContent/innerHTML round-trip for these characters, but needs no
 * live document and allocates no scratch DOM node per call.
 */
function escapeHtml(text) {
    return String(text)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;');
}

/** Wrap already-escaped <think>…</think> sections in a highlight box. */
function highlightThinkTags(text) {
    return text.replace(/&lt;think&gt;([\s\S]*?)&lt;\/think&gt;/g,
        '<div class="think-content">&lt;think&gt;<br>$1<br>&lt;/think&gt;</div>');
}

/**
 * Escape a model response and apply highlighting: <think> sections always;
 * JSON syntax colouring only when the whole text parses as JSON and
 * contains no <think> block, so the two markups never collide.
 */
function highlightContent(text) {
    let escaped = escapeHtml(text);
    escaped = highlightThinkTags(escaped);
    if (!text.includes('<think>') && (text.trim().startsWith('{') || text.trim().startsWith('['))) {
        try {
            // Only colour strings that are genuinely valid JSON.
            JSON.parse(text);
            escaped = highlightJson(escaped);
        } catch (e) {
            // Not valid JSON - leave the plain escaped text.
        }
    }
    return escaped;
}

/**
 * Minimal regex-based JSON syntax highlighter (keys, string values,
 * numbers, booleans, null). Operates on escaped text; best effort only.
 */
function highlightJson(text) {
    text = text.replace(/"([^"]+)":/g, '<span class="json-key">"$1"</span>:');
    text = text.replace(/: "([^"]+)"/g, ': <span class="json-string">"$1"</span>');
    text = text.replace(/: (\d+)/g, ': <span class="json-number">$1</span>');
    text = text.replace(/: (true|false)/g, ': <span class="json-boolean">$1</span>');
    text = text.replace(/: null/g, ': <span class="json-null">null</span>');
    return text;
}
/**
 * Click handler for a content box: expand it if collapsed, otherwise
 * collapse it again. Boxes at most 200px tall never collapse.
 */
function toggleExpand(element) {
    if (element.classList.contains('collapsed')) {
        element.classList.remove('collapsed');
        return;
    }
    if (element.scrollHeight > 200) {
        element.classList.add('collapsed');
    }
}

/** Expand or collapse every content box on the page at once. */
function toggleAllContent() {
    allExpanded = !allExpanded;
    document.querySelectorAll('.content-box').forEach(box => {
        if (allExpanded) {
            box.classList.remove('collapsed');
        } else if (box.scrollHeight > 200) {
            box.classList.add('collapsed');
        }
    });
}
// Copy the RENDERED text of element `<elementId>_content` to the clipboard
// and flash a transient confirmation tooltip.
// NOTE(review): this declaration is shadowed at load time by the
// `window.copyContent = function(prefix)` assignment further down, which
// copies the raw log values instead — confirm which one is intended.
function copyContent(elementId) {
    const element = document.getElementById(elementId + '_content');
    const text = element.textContent;
    navigator.clipboard.writeText(text).then(() => {
        // Show a centered "copied" tooltip for two seconds.
        const tooltip = document.createElement('div');
        tooltip.className = 'copy-tooltip';
        tooltip.textContent = '已复制!';
        tooltip.style.cssText = `
            position: fixed;
            top: 50%;
            left: 50%;
            transform: translate(-50%, -50%);
            background: #00ff88;
            color: #000;
            padding: 10px 20px;
            border-radius: 8px;
            z-index: 1000;
        `;
        document.body.appendChild(tooltip);
        setTimeout(() => {
            tooltip.remove();
        }, 2000);
    });
}
// Copy the ORIGINAL (unescaped, unrendered) value of a log field to the
// clipboard. `prefix` is "<logIndex>_<field>", e.g. "3_prompt".
// This assignment overrides the earlier copyContent(elementId) definition.
// Fixes: guard against stale/out-of-range indexes (previously threw on
// `log.prompt` of undefined), explicit parseInt radix, and never pass
// undefined to writeText (which would copy the string "undefined").
window.copyContent = function(prefix) {
    const log = allLogs[parseInt(prefix.split('_')[0], 10)];
    if (!log) return; // stale index — nothing to copy
    let text = '';
    if (prefix.endsWith('_prompt')) {
        text = log.prompt;
    } else if (prefix.endsWith('_response')) {
        text = log.response;
    } else if (prefix.endsWith('_raw')) {
        text = log.raw_text;
    } else if (prefix.endsWith('_fixed')) {
        text = log.fixed_text;
    }
    navigator.clipboard.writeText(text || '').then(() => {
        const tooltip = document.createElement('div');
        tooltip.className = 'copy-tooltip';
        tooltip.textContent = '已复制原始内容!';
        tooltip.style.cssText = `
            position: fixed;
            top: 50%;
            left: 50%;
            transform: translate(-50%, -50%);
            background: #00ff88;
            color: #000;
            padding: 10px 20px;
            border-radius: 8px;
            z-index: 1000;
        `;
        document.body.appendChild(tooltip);
        setTimeout(() => {
            tooltip.remove();
        }, 2000);
    });
};
// Deep-link support: when the page URL carries ?session_id=…, pre-fill the
// session input and immediately load that session's logs + live stream.
const urlParams = new URLSearchParams(window.location.search);
const sessionIdParam = urlParams.get('session_id');
if (sessionIdParam) {
    document.getElementById('sessionId').value = sessionIdParam;
    loadLogs();
    connectWebSocket();
}
</script>
</body>
</html>

38
app/templates/index.html Normal file
View File

@ -0,0 +1,38 @@
{% extends "base.html" %}
{% block content %}
<div class="start-screen">
<div class="start-card">
<h2>开始新的研究</h2>
<div class="input-group">
<input
type="text"
id="questionInput"
placeholder="输入你想研究的问题..."
class="question-input"
onkeypress="if(event.key === 'Enter') startResearch()"
/>
<button onclick="startResearch()" class="start-button" id="startBtn">
<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<circle cx="11" cy="11" r="8"></circle>
<path d="m21 21-4.35-4.35"></path>
</svg>
</button>
</div>
<div class="history-section" id="historySection" style="display: none;">
<h3>历史研究</h3>
<div class="session-list" id="sessionList"></div>
</div>
</div>
</div>
<!-- 加载动画 -->
<div id="loading" class="loading-overlay" style="display: none;">
<div class="loading-spinner"></div>
</div>
{% endblock %}
{% block extra_js %}
<script src="{{ url_for('static', filename='js/index.js') }}"></script>
{% endblock %}

View File

@ -0,0 +1,46 @@
{% extends "base.html" %}
{% block header_content %}
<div class="progress-bar" id="progressBar">
<div class="progress-fill" id="progressFill" style="width: 0%"></div>
</div>
<div class="progress-message" id="progressMessage"></div>
{% endblock %}
{% block content %}
<div class="research-view">
<div class="tree-container" id="treeContainer">
<!-- 动态生成的研究树 -->
</div>
<div class="action-buttons">
<button class="action-button" onclick="window.location.href='/'">返回</button>
<button class="action-button primary" id="downloadBtn" style="display: none;" onclick="downloadReport()">
📄 下载报告
</button>
<button class="action-button danger" id="cancelBtn" onclick="cancelResearch()">
✕ 取消研究
</button>
</div>
</div>
<!-- 详情面板 -->
<div class="detail-panel" id="detailPanel">
<div class="panel-header">
<h3>详细信息</h3>
<button class="panel-close" onclick="closePanel()">×</button>
</div>
<div class="panel-content" id="panelContent">
<!-- 动态内容 -->
</div>
</div>
<script>
const SESSION_ID = "{{ session_id }}";
</script>
{% endblock %}
{% block extra_js %}
<script src="{{ url_for('static', filename='js/research-tree.js') }}"></script>
<script src="{{ url_for('static', filename='js/research.js') }}"></script>
{% endblock %}

0
app/utils/__init__.py Normal file
View File

159
app/utils/debug_logger.py Normal file
View File

@ -0,0 +1,159 @@
# 文件位置: app/utils/debug_logger.py
# 文件名: debug_logger.py
"""
AI调试日志记录器
记录所有AI模型的输入输出用于调试和优化
"""
import os
import json
import logging
from datetime import datetime
from typing import Dict, Any, Optional, List
from config import Config
logger = logging.getLogger(__name__)
class AIDebugLogger:
    """Persist every AI model call (full prompt + response) for debugging.

    Entries are appended as JSON Lines under ``<DATA_DIR>/debug/<session>/``
    and, when a Socket.IO server has been attached, mirrored to the browser
    in real time.
    """

    def __init__(self):
        # Root directory for debug output; one subdirectory per session.
        self.debug_dir = os.path.join(Config.DATA_DIR, 'debug')
        os.makedirs(self.debug_dir, exist_ok=True)
        self.current_session_id = None  # set via set_session()
        self.socketio = None            # set via set_socketio()

    def set_socketio(self, socketio):
        """Attach the Socket.IO server used for live log streaming."""
        self.socketio = socketio

    def set_session(self, session_id: str):
        """Switch to *session_id* and ensure its debug directory exists."""
        self.current_session_id = session_id
        session_debug_dir = os.path.join(self.debug_dir, session_id)
        os.makedirs(session_debug_dir, exist_ok=True)

    def log_api_call(self, model: str, agent_type: str, method: str,
                     prompt: str, response: str, temperature: float = None,
                     max_tokens: int = None, metadata: Dict[str, Any] = None):
        """Record one model API call.

        agent_type is "R1" or "V3". Prompt and response are stored in full
        (deliberately untruncated); only the standard log gets a summary.
        """
        timestamp = datetime.now()
        log_entry = {
            "timestamp": timestamp.isoformat(),
            "session_id": self.current_session_id,
            "model": model,
            "agent_type": agent_type,
            "method": method,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_length": len(prompt),
            "response_length": len(response),
            "prompt": prompt,        # full content, deliberately untruncated
            "response": response,    # full content, deliberately untruncated
            "metadata": metadata or {}
        }
        if self.current_session_id:
            self._save_to_file(log_entry)
        if self.socketio and self.current_session_id:
            self._emit_debug_log(log_entry)
        # The standard logger only gets a short summary to stay readable.
        logger.debug(f"AI Call - {agent_type}/{method}: "
                     f"prompt={len(prompt)}chars, response={len(response)}chars")

    def log_json_parse_error(self, raw_text: str, error: str, fixed_text: Optional[str] = None):
        """Record a JSON parse failure, plus the repaired text if a fix was
        attempted (fixed_text is None when no repair succeeded)."""
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "session_id": self.current_session_id,
            "type": "json_parse_error",
            "raw_text": raw_text,
            "error": error,
            "fixed_text": fixed_text,
            "fixed": fixed_text is not None
        }
        if self.current_session_id:
            self._save_error_log(log_entry)
        if self.socketio and self.current_session_id:
            self.socketio.emit('parse_error', {
                'session_id': self.current_session_id,
                **log_entry
            }, room=self.current_session_id)

    def _save_to_file(self, log_entry: Dict[str, Any]):
        """Append *log_entry* to the session's dated api_calls JSONL file."""
        if not self.current_session_id:
            return
        session_log_file = os.path.join(
            self.debug_dir,
            self.current_session_id,
            f"api_calls_{datetime.now().strftime('%Y%m%d')}.jsonl"
        )
        with open(session_log_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(log_entry, ensure_ascii=False) + '\n')

    def _save_error_log(self, log_entry: Dict[str, Any]):
        """Append *log_entry* to the session's errors.jsonl file."""
        if not self.current_session_id:
            return
        error_log_file = os.path.join(
            self.debug_dir,
            self.current_session_id,
            "errors.jsonl"
        )
        with open(error_log_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(log_entry, ensure_ascii=False) + '\n')

    def _emit_debug_log(self, log_entry: Dict[str, Any]):
        """Push a full (untruncated) log entry over the WebSocket."""
        self.socketio.emit('ai_debug_log', {
            'session_id': self.current_session_id,
            'log_entry': log_entry
        }, room=self.current_session_id)

    def get_session_logs(self, session_id: str, log_type: str = 'all') -> List[Dict[str, Any]]:
        """Return a session's logs sorted by timestamp.

        log_type is 'api_calls', 'errors' or 'all'. Blank or corrupt JSONL
        lines are skipped instead of aborting the whole read (the previous
        version raised json.JSONDecodeError on any damaged line).
        """
        logs = []
        session_debug_dir = os.path.join(self.debug_dir, session_id)
        if not os.path.exists(session_debug_dir):
            return logs
        if log_type in ('all', 'api_calls'):
            for file_name in os.listdir(session_debug_dir):
                if file_name.startswith('api_calls_'):
                    logs.extend(self._read_jsonl(os.path.join(session_debug_dir, file_name)))
        if log_type in ('all', 'errors'):
            error_file = os.path.join(session_debug_dir, 'errors.jsonl')
            if os.path.exists(error_file):
                logs.extend(self._read_jsonl(error_file))
        logs.sort(key=lambda x: x.get('timestamp', ''))
        return logs

    @staticmethod
    def _read_jsonl(file_path: str) -> List[Dict[str, Any]]:
        """Parse a JSON Lines file, skipping blank and malformed lines."""
        entries = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    entries.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.warning("Skipping corrupt debug log line in %s", file_path)
        return entries
# Module-level singleton: import `ai_debug_logger` anywhere to record AI calls.
ai_debug_logger = AIDebugLogger()

180
app/utils/formatters.py Normal file
View File

@ -0,0 +1,180 @@
"""
格式化工具
"""
import re
from datetime import datetime
from typing import List, Dict, Any, Optional
def format_datetime(dt: datetime, format: str = "full") -> str:
    """Render *dt* as text.

    format is 'full', 'date', 'time' or 'relative'; any other value falls
    back to ISO-8601.
    """
    strftime_patterns = {
        "full": "%Y-%m-%d %H:%M:%S",
        "date": "%Y-%m-%d",
        "time": "%H:%M:%S",
    }
    if format == "relative":
        return get_relative_time(dt)
    pattern = strftime_patterns.get(format)
    return dt.strftime(pattern) if pattern else dt.isoformat()
def get_relative_time(dt: datetime) -> str:
    """Return a coarse Chinese description of how long ago *dt* was
    (e.g. 刚刚 / N分钟前 / N小时前 / N天前 / N个月前 / N年前)."""
    delta = datetime.now() - dt
    seconds = delta.total_seconds()
    if seconds < 60:
        return "刚刚"
    if seconds < 3600:
        return f"{int(seconds / 60)}分钟前"
    if seconds < 86400:
        return f"{int(seconds / 3600)}小时前"
    # Beyond one day, switch to whole-day granularity.
    if delta.days < 30:
        return f"{delta.days}天前"
    if delta.days < 365:
        return f"{int(delta.days / 30)}个月前"
    return f"{int(delta.days / 365)}年前"
def format_file_size(size_bytes: int) -> str:
    """Format a byte count with a binary (1024-based) unit suffix."""
    size = float(size_bytes)
    for unit in ('B', 'KB', 'MB', 'GB', 'TB'):
        if size < 1024.0:
            return f"{size:.2f} {unit}"
        size /= 1024.0
    # Anything past TB is reported in petabytes.
    return f"{size:.2f} PB"
def format_percentage(value: float, decimals: int = 1) -> str:
    """Format *value* (already scaled to 0-100) with a trailing % sign."""
    return "{:.{}f}%".format(value, decimals)
def format_search_results(results: List[Dict[str, Any]]) -> str:
    """Render search hits as a numbered plain-text list, one blank line
    between entries; missing fields get Chinese placeholders."""
    rendered = []
    for index, hit in enumerate(results, start=1):
        rendered.append(f"{index}. {hit.get('title', '无标题')}")
        rendered.append(f"   URL: {hit.get('url', 'N/A')}")
        rendered.append(f"   {hit.get('snippet', '无摘要')}")
        rendered.append("")
    return '\n'.join(rendered)
def format_outline_text(outline: Dict[str, Any]) -> str:
    """Render an outline dict as Markdown: title, numbered research
    questions, then numbered sub-topics with priority and explanation."""
    parts = [f"# {outline.get('main_topic', '研究主题')}", ""]
    parts.append("## 研究问题")
    for n, question in enumerate(outline.get('research_questions', []), start=1):
        parts.append(f"{n}. {question}")
    parts.append("")
    parts.append("## 子主题")
    for n, sub in enumerate(outline.get('sub_topics', []), start=1):
        parts.append(f"{n}. **{sub.get('topic', '')}** ({sub.get('priority', '')})")
        parts.append(f"   {sub.get('explain', '')}")
    return '\n'.join(parts)
def clean_markdown(text: str) -> str:
    """Normalize Markdown whitespace.

    Collapses runs of 3+ newlines to a single blank line and guarantees a
    blank line before and after ATX headings. Leading/trailing whitespace
    is stripped.
    """
    # Collapse excessive blank lines.
    text = re.sub(r'\n{3,}', '\n\n', text)
    # Ensure a blank line before and after headings.
    text = re.sub(r'([^\n])\n(#{1,6} )', r'\1\n\n\2', text)
    text = re.sub(r'(#{1,6} [^\n]+)\n([^\n])', r'\1\n\n\2', text)
    # Fixed: the original "list fixing" step ran three substitutions that
    # replaced each match with itself ('\n- '->'\n- ', '\n* '->'\n* ', and a
    # lambda that re-inserted the same '\n<digits>. '), i.e. pure no-ops;
    # they have been removed with no behavior change.
    return text.strip()
def truncate_text(text: str, max_length: int, ellipsis: str = "...") -> str:
    """Shorten *text* to at most *max_length* characters plus an ellipsis,
    preferring to cut at a word boundary within the final 20%."""
    if len(text) <= max_length:
        return text
    cut = text[:max_length]
    boundary = cut.rfind(' ')
    # Only back up to the space when it keeps at least 80% of the budget.
    if boundary > max_length * 0.8:
        cut = cut[:boundary]
    return cut + ellipsis
def highlight_keywords(text: str, keywords: List[str]) -> str:
    """Bold every (case-insensitive) occurrence of each keyword.

    Matches are replaced with the keyword as spelled in *keywords*, so the
    original casing in *text* is not preserved.
    """
    for keyword in keywords:
        matcher = re.compile(re.escape(keyword), re.IGNORECASE)
        text = matcher.sub(f"**{keyword}**", text)
    return text
def extract_urls(text: str) -> List[str]:
    """Extract distinct http(s) URLs from *text*.

    Deduplication goes through a set, so the returned order is arbitrary.
    """
    url_matcher = re.compile(
        r'https?://(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b'
        r'(?:[-a-zA-Z0-9()@:%_\+.~#?&/=]*)'
    )
    return list(set(url_matcher.findall(text)))
def format_json_output(data: Any, indent: int = 2) -> str:
    """Pretty-print *data* as JSON with sorted keys.

    Non-serializable objects (e.g. datetime) fall back to their str() form
    via default=str — a deliberate choice for log-friendly output.
    """
    import json

    return json.dumps(
        data,
        ensure_ascii=False,
        indent=indent,
        sort_keys=True,
        default=str,
    )
def create_summary(text: str, max_sentences: int = 3) -> str:
    """Take the first *max_sentences* sentences of *text*.

    Sentences are split on Chinese/ASCII terminators; the separators are
    discarded, so the joined summary has no internal punctuation. When the
    text was truncated, '。...' is appended.
    """
    sentences = [part.strip()
                 for part in re.split(r'[。!?.!?]+', text)
                 if part.strip()]
    summary = ''.join(sentences[:max_sentences])
    suffix = '。...' if len(sentences) > max_sentences else ''
    return summary + suffix
def format_status_message(status: str, phase: Optional[str] = None) -> str:
    """Map a status code to its Chinese label; unknown codes pass through.
    When *phase* is given it is appended as ' - <phase>'."""
    labels = {
        "pending": "等待开始",
        "analyzing": "分析问题中",
        "outlining": "制定大纲中",
        "researching": "研究进行中",
        "writing": "撰写报告中",
        "reviewing": "审核内容中",
        "completed": "研究完成",
        "error": "发生错误",
        "cancelled": "已取消",
    }
    label = labels.get(status, status)
    return f"{label} - {phase}" if phase else label

283
app/utils/json_parser.py Normal file
View File

@ -0,0 +1,283 @@
"""
JSON解析和修复工具
"""
import json
import re
import logging
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
def parse_json_safely(text: str) -> Dict[str, Any]:
"""安全解析JSON带错误修复"""
# 首先尝试直接解析
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# 尝试修复常见问题
fixed_text = fix_json_common_issues(text)
try:
return json.loads(fixed_text)
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败: {e}")
logger.debug(f"原始文本: {text[:500]}...")
# 尝试更激进的修复
fixed_text = fix_json_aggressive(fixed_text)
try:
return json.loads(fixed_text)
except json.JSONDecodeError:
# 最后的尝试提取JSON部分
json_part = extract_json_from_text(text)
if json_part:
try:
return json.loads(json_part)
except:
pass
# 返回空字典而不是抛出异常
logger.error("无法解析JSON返回空字典")
return {}
def fix_json_common_issues(text: str) -> str:
    """Repair the JSON defects LLM output most commonly contains.

    Handles Markdown fences, BOM, control characters, trailing commas,
    single quotes, unquoted keys, Python literals and comments. The result
    is a best-effort candidate, not guaranteed valid — callers re-validate
    with json.loads.
    """
    # Strip Markdown code-fence markers.
    text = re.sub(r'^```json\s*', '', text, flags=re.MULTILINE)
    text = re.sub(r'^```\s*$', '', text, flags=re.MULTILINE)
    # Strip a UTF-8 BOM if present.
    text = text.lstrip('\ufeff')
    # Drop control characters (including raw newlines/tabs), which are
    # illegal inside JSON strings.
    text = re.sub(r'[\x00-\x1F\x7F]', '', text)
    # Remove trailing commas before closing braces/brackets.
    text = re.sub(r',\s*}', '}', text)
    text = re.sub(r',\s*]', ']', text)
    # JSON only allows double quotes; convert string-delimiting single quotes
    # (carefully, so apostrophes inside values survive).
    text = fix_single_quotes(text)
    # Quote bare object keys.
    text = fix_unquoted_keys(text)
    # Translate Python literals to JSON. Fixed: the original used plain
    # str.replace, which corrupted any value merely containing these words
    # (e.g. "Truelove" -> "truelove"); word boundaries prevent that. It may
    # still rewrite whole-word occurrences inside string values, which is
    # inherent to this regex-based approach.
    text = re.sub(r'\bTrue\b', 'true', text)
    text = re.sub(r'\bFalse\b', 'false', text)
    text = re.sub(r'\bNone\b', 'null', text)
    # Remove // and /* */ comments.
    text = remove_json_comments(text)
    return text.strip()
def fix_json_aggressive(text: str) -> str:
    """Riskier structural repairs, applied only after the common fixes fail.

    These regexes can mangle legitimate content (e.g. braces inside string
    values), so callers must re-validate the result with json.loads.
    """
    # Join strings that were broken across lines ("abc"\n"def" -> "abc" "def").
    text = re.sub(r'"\s*\n\s*"', '" "', text)
    # Insert commas that are likely missing between adjacent values:
    # after } or ] when immediately followed by a string, object or array.
    text = re.sub(r'}\s*"', '},\n"', text)
    text = re.sub(r']\s*"', '],\n"', text)
    text = re.sub(r'}\s*{', '},\n{', text)
    text = re.sub(r']\s*\[', '],\n[', text)
    # Insert a colon between two adjacent strings ("key" "value").
    # NOTE(review): this also fires on string arrays like ["a" "b"],
    # rewriting them as "a": "b" — confirm this trade-off is intended.
    text = re.sub(r'"([^"]+)"\s*"', r'"\1": "', text)
    # Ensuring every string value is quoted would require a real parser,
    # so it is deliberately not attempted here.
    return text
def fix_single_quotes(text: str) -> str:
    """Convert single-quoted strings to double-quoted ones.

    Walks the text with a small state machine instead of a blind replace,
    so apostrophes inside values are left alone: a quote only opens a
    string when preceded by a structural character, and only closes one
    when followed by a structural character (or end of text).
    """
    result = []
    in_string = False     # currently inside a single-quoted string?
    string_char = None    # the quote character that opened the string
    i = 0
    while i < len(text):
        char = text[i]
        if not in_string:
            if char == "'" and (i == 0 or text[i-1] in ' \n\t:,{['):
                # Likely the start of a string key/value.
                result.append('"')
                in_string = True
                string_char = "'"
            else:
                result.append(char)
        else:
            if char == string_char and (i + 1 >= len(text) or text[i+1] in ' \n\t,}]:'):
                # Likely the end of the string.
                result.append('"')
                in_string = False
                string_char = None
            elif char == '\\' and i + 1 < len(text):
                # Copy escape sequences through verbatim (two characters).
                result.append(char)
                result.append(text[i + 1])
                i += 1
            else:
                result.append(char)
        i += 1
    return ''.join(result)
def fix_unquoted_keys(text: str) -> str:
    """Wrap bare identifiers used as object keys in double quotes.

    Only keys preceded by ',', '{' or whitespace are matched, e.g.
    '{name: 1}' becomes '{"name": 1}'.
    """
    return re.sub(
        r'([,\{\s])([a-zA-Z_][a-zA-Z0-9_]*)\s*:',
        r'\1"\2":',
        text,
    )
def remove_json_comments(text: str) -> str:
    """Strip // line comments and /* */ block comments from JSON-ish text.

    Fixed: the original regex pass truncated any string value containing
    '//' (e.g. "http://example.com"). This scanner tracks double-quoted
    strings (including escapes) so comment markers inside strings survive.
    Newlines terminating // comments are preserved.
    """
    result = []
    i = 0
    n = len(text)
    in_string = False
    while i < n:
        char = text[i]
        if in_string:
            result.append(char)
            if char == '\\' and i + 1 < n:
                # Copy the escaped character so \" does not end the string.
                result.append(text[i + 1])
                i += 1
            elif char == '"':
                in_string = False
            i += 1
            continue
        if char == '"':
            in_string = True
            result.append(char)
            i += 1
        elif char == '/' and i + 1 < n and text[i + 1] == '/':
            # Line comment: skip to (but keep) the newline.
            while i < n and text[i] != '\n':
                i += 1
        elif char == '/' and i + 1 < n and text[i + 1] == '*':
            # Block comment: skip past the closing */ (or to end of text).
            end = text.find('*/', i + 2)
            i = n if end == -1 else end + 2
        else:
            result.append(char)
            i += 1
    return ''.join(result)
def extract_json_from_text(text: str) -> Optional[str]:
    """Return the first balanced JSON object/array embedded in *text*,
    or None when no opener is found or it is never balanced."""
    # Locate the first opening brace or bracket.
    start_idx = -1
    start_char = None
    for i, char in enumerate(text):
        if char in '{[':
            start_idx = i
            start_char = char
            break
    if start_idx == -1:
        return None
    # Scan forward counting bracket depth, skipping over string literals
    # and escape sequences, until the opener is balanced again.
    end_char = '}' if start_char == '{' else ']'
    bracket_count = 0
    in_string = False
    escape = False
    for i in range(start_idx, len(text)):
        char = text[i]
        if escape:
            # Previous character was a backslash: this one is literal.
            escape = False
            continue
        if char == '\\':
            escape = True
            continue
        if char == '"' and not escape:
            # Toggle string state; quotes never affect bracket depth.
            in_string = not in_string
            continue
        if not in_string:
            if char == start_char:
                bracket_count += 1
            elif char == end_char:
                bracket_count -= 1
                if bracket_count == 0:
                    return text[start_idx:i+1]
    # Opener was never balanced before the text ended.
    return None
def validate_json_schema(data: Dict[str, Any], schema: Dict[str, Any]) -> List[str]:
    """Check *data* against a minimal JSON-schema subset.

    Supports only 'required' and per-property 'type'. Returns a list of
    human-readable error strings; empty means valid.
    """
    errors = []
    properties = schema.get('properties', {})

    # All required fields must be present.
    for field in schema.get('required', []):
        if field not in data:
            errors.append(f"缺少必需字段: {field}")

    # Map JSON-schema type names onto Python type names.
    type_mapping = {
        'string': 'str',
        'number': 'float',
        'integer': 'int',
        'boolean': 'bool',
        'array': 'list',
        'object': 'dict'
    }
    for field, value in data.items():
        spec = properties.get(field)
        if not spec:
            continue
        expected_type = spec.get('type')
        if not expected_type:
            continue
        actual_type = type(value).__name__
        expected_python_type = type_mapping.get(expected_type, expected_type)
        if actual_type == expected_python_type:
            continue
        # Special case: an int is an acceptable JSON "number".
        if expected_python_type == 'float' and actual_type == 'int':
            continue
        errors.append(
            f"字段 '{field}' 类型错误: "
            f"期望 {expected_type}, 实际 {actual_type}"
        )
    return errors
def merge_json_objects(obj1: Dict[str, Any], obj2: Dict[str, Any],
                       deep: bool = True) -> Dict[str, Any]:
    """Merge *obj2* into a copy of *obj1*.

    With deep=True, nested dicts are merged recursively and lists are
    concatenated with duplicates removed. Fixed: the previous set-based
    dedupe raised TypeError on unhashable list elements (the common
    JSON case of a list of dicts) and scrambled element order; dedupe now
    preserves first-seen order and tolerates unhashable items.
    """
    result = obj1.copy()
    for key, value in obj2.items():
        if key in result and deep and isinstance(result[key], dict) and isinstance(value, dict):
            # Recursively merge nested objects.
            result[key] = merge_json_objects(result[key], value, deep=True)
        elif key in result and deep and isinstance(result[key], list) and isinstance(value, list):
            # Concatenate arrays, dropping duplicates while keeping order.
            result[key] = _dedupe_list(result[key] + value)
        else:
            # Scalars, mismatched types, or shallow mode: obj2 wins.
            result[key] = value
    return result


def _dedupe_list(items: List[Any]) -> List[Any]:
    """Remove duplicates from *items*, preserving first-seen order and
    falling back to linear equality scans for unhashable elements."""
    seen = set()
    unique: List[Any] = []
    for item in items:
        try:
            if item in seen:
                continue
            seen.add(item)
        except TypeError:
            # Unhashable (dict/list): compare by equality instead.
            if item in unique:
                continue
        unique.append(item)
    return unique
def json_to_flat_dict(data: Dict[str, Any], parent_key: str = '',
                      separator: str = '.') -> Dict[str, Any]:
    """Flatten nested dicts/lists into a single level.

    Nested keys are joined with *separator*; list elements get an indexed
    suffix, e.g. {'a': {'b': [1]}} -> {'a.b[0]': 1}.
    """
    flat: Dict[str, Any] = {}
    for key, value in data.items():
        path = f"{parent_key}{separator}{key}" if parent_key else key
        if isinstance(value, dict):
            flat.update(json_to_flat_dict(value, path, separator))
        elif isinstance(value, list):
            for index, element in enumerate(value):
                indexed = f"{path}[{index}]"
                if isinstance(element, dict):
                    flat.update(json_to_flat_dict(element, indexed, separator))
                else:
                    flat[indexed] = element
        else:
            flat[path] = value
    return flat

223
app/utils/logger.py Normal file
View File

@ -0,0 +1,223 @@
"""
日志配置工具
"""
import os
import logging
import logging.handlers
from datetime import datetime
from pythonjsonlogger import jsonlogger
def setup_logging(app):
    """Configure the full logging stack for the Flask app.

    Installs: a colored console handler, rotating general/error/JSON file
    handlers on the root logger, and a dedicated 'research' logger, then
    quiets noisy third-party libraries. All files go under LOG_DIR.
    """
    log_level = app.config.get('LOG_LEVEL', 'INFO')
    log_dir = app.config.get('LOG_DIR', 'logs')
    # Make sure the log directory exists before attaching file handlers.
    os.makedirs(log_dir, exist_ok=True)
    # Root logger: all module loggers propagate here.
    root_logger = logging.getLogger()
    root_logger.setLevel(getattr(logging, log_level))
    # Drop handlers left over from previous initializations (e.g. reloader).
    root_logger.handlers = []
    # Console handler with ANSI-colored level names.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(getattr(logging, log_level))
    console_formatter = ColoredFormatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    console_handler.setFormatter(console_formatter)
    root_logger.addHandler(console_handler)
    # General application log (rotated at 10MB, 10 backups).
    file_handler = logging.handlers.RotatingFileHandler(
        os.path.join(log_dir, 'app.log'),
        maxBytes=10485760,  # 10MB
        backupCount=10
    )
    file_handler.setLevel(logging.INFO)
    file_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    file_handler.setFormatter(file_formatter)
    root_logger.addHandler(file_handler)
    # Errors-only log for quick triage.
    error_handler = logging.handlers.RotatingFileHandler(
        os.path.join(log_dir, 'error.log'),
        maxBytes=10485760,
        backupCount=10
    )
    error_handler.setLevel(logging.ERROR)
    error_handler.setFormatter(file_formatter)
    root_logger.addHandler(error_handler)
    # Structured JSON log for machine analysis.
    json_handler = logging.handlers.RotatingFileHandler(
        os.path.join(log_dir, 'app.json.log'),
        maxBytes=10485760,
        backupCount=10
    )
    json_formatter = CustomJsonFormatter()
    json_handler.setFormatter(json_formatter)
    json_handler.setLevel(logging.INFO)
    root_logger.addHandler(json_handler)
    # Dedicated logger for research tasks (captures DEBUG-level detail).
    research_logger = logging.getLogger('research')
    research_handler = logging.handlers.RotatingFileHandler(
        os.path.join(log_dir, 'research.log'),
        maxBytes=10485760,
        backupCount=10
    )
    research_handler.setFormatter(file_formatter)
    research_logger.addHandler(research_handler)
    research_logger.setLevel(logging.DEBUG)
    # Silence chatty third-party libraries.
    logging.getLogger('urllib3').setLevel(logging.WARNING)
    logging.getLogger('requests').setLevel(logging.WARNING)
    logging.getLogger('openai').setLevel(logging.WARNING)
    app.logger.info(f"日志系统初始化完成,级别: {log_level}")
class ColoredFormatter(logging.Formatter):
    """Console formatter that wraps the level name in ANSI color codes.

    Fixed: the original mutated ``record.levelname`` in place, so every
    handler that formatted the same record afterwards (e.g. the rotating
    file handlers) wrote the ANSI escape codes into its output. The level
    name is now restored after formatting.
    """

    COLORS = {
        'DEBUG': '\033[36m',     # cyan
        'INFO': '\033[32m',      # green
        'WARNING': '\033[33m',   # yellow
        'ERROR': '\033[31m',     # red
        'CRITICAL': '\033[35m',  # magenta
    }
    RESET = '\033[0m'

    def format(self, record):
        """Format *record* with a colorized level name, leaving the record
        unchanged for any other handlers that run after this one."""
        original = record.levelname
        log_color = self.COLORS.get(original, self.RESET)
        record.levelname = f"{log_color}{original}{self.RESET}"
        try:
            return super().format(record)
        finally:
            record.levelname = original
class CustomJsonFormatter(jsonlogger.JsonFormatter):
    """JSON log formatter that enriches each record with a timestamp,
    level, logger name, exception text and request-context identifiers."""

    def add_fields(self, log_record, record, message_dict):
        """Populate *log_record* (the output dict) from *record*."""
        super().add_fields(log_record, record, message_dict)
        # Standard fields present on every entry.
        log_record['timestamp'] = datetime.utcnow().isoformat()
        log_record['level'] = record.levelname
        log_record['logger'] = record.name
        # Attach the formatted traceback when the record carries one.
        if record.exc_info:
            log_record['exception'] = self.formatException(record.exc_info)
        # Optional context attached via the `extra=` logging kwarg.
        if hasattr(record, 'session_id'):
            log_record['session_id'] = record.session_id
        if hasattr(record, 'subtopic_id'):
            log_record['subtopic_id'] = record.subtopic_id
        if hasattr(record, 'user_id'):
            log_record['user_id'] = record.user_id
def get_logger(name: str) -> logging.Logger:
    """Thin convenience wrapper around logging.getLogger()."""
    return logging.getLogger(name)
def log_performance(func):
    """Decorator that logs the wrapped callable's wall-clock duration.

    Successful calls log at INFO, failures at ERROR (with traceback) and
    re-raise. Duration is also attached as structured 'performance' extra.
    """
    import functools
    import time

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger = logging.getLogger(func.__module__)
        start_time = time.time()
        try:
            result = func(*args, **kwargs)
            elapsed_time = time.time() - start_time
            # Fixed: the success message was missing the "秒" (seconds)
            # unit that the failure message already carried.
            logger.info(
                f"{func.__name__} 执行成功,耗时: {elapsed_time:.3f}秒",
                extra={'performance': {'function': func.__name__, 'duration': elapsed_time}}
            )
            return result
        except Exception as e:
            elapsed_time = time.time() - start_time
            logger.error(
                f"{func.__name__} 执行失败,耗时: {elapsed_time:.3f}秒,错误: {str(e)}",
                extra={'performance': {'function': func.__name__, 'duration': elapsed_time}},
                exc_info=True
            )
            raise
    return wrapper
def log_api_call(service_name: str):
    """Decorator factory: log request, success and failure of a call to the
    named external service. Arguments are truncated to 200 chars so one
    oversized payload cannot flood the log."""
    def decorator(func):
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            api_logger = logging.getLogger('api_calls')
            # Log the outgoing request.
            api_logger.info(
                f"调用 {service_name} API: {func.__name__}",
                extra={
                    'api_service': service_name,
                    'api_method': func.__name__,
                    'args': str(args)[:200],
                    'kwargs': str(kwargs)[:200],
                }
            )
            try:
                result = func(*args, **kwargs)
                api_logger.info(
                    f"{service_name} API 调用成功: {func.__name__}",
                    extra={
                        'api_service': service_name,
                        'api_method': func.__name__,
                        'success': True,
                    }
                )
                return result
            except Exception as e:
                api_logger.error(
                    f"{service_name} API 调用失败: {func.__name__} - {str(e)}",
                    extra={
                        'api_service': service_name,
                        'api_method': func.__name__,
                        'success': False,
                        'error': str(e),
                    },
                    exc_info=True
                )
                raise
        return wrapper
    return decorator
class SessionLoggerAdapter(logging.LoggerAdapter):
    """LoggerAdapter that injects the adapter's session_id (when one has
    been assigned) into every record's `extra` dict."""

    def process(self, msg, kwargs):
        """Attach session_id to the call's extra mapping, creating it when
        absent; the message itself passes through unchanged."""
        extra = kwargs.setdefault('extra', {})
        if hasattr(self, 'session_id'):
            extra['session_id'] = self.session_id
        return msg, kwargs
def get_session_logger(session_id: str, logger_name: str = 'research') -> SessionLoggerAdapter:
    """Build a SessionLoggerAdapter bound to *session_id* on top of the
    named logger (defaults to the dedicated 'research' logger)."""
    adapter = SessionLoggerAdapter(logging.getLogger(logger_name), {})
    adapter.session_id = session_id
    return adapter

131
app/utils/validators.py Normal file
View File

@ -0,0 +1,131 @@
"""
输入验证工具
"""
import re
from typing import Optional
def validate_question(question: str) -> Optional[str]:
    """Return a Chinese error message for an invalid research question,
    or None when it is acceptable (5-1000 chars with real content)."""
    if not question:
        return "问题不能为空"
    if len(question) < 5:
        return "问题太短,请提供更详细的描述"
    if len(question) > 1000:
        return "问题太长请精简到1000字以内"
    # Must contain at least one letter or CJK character, not just punctuation.
    if re.search(r'[a-zA-Z\u4e00-\u9fa5]+', question) is None:
        return "请输入有效的问题内容"
    return None
def validate_outline_feedback(feedback: str) -> Optional[str]:
    """Return an error message for invalid outline feedback (empty, under
    10 chars, or over 500 chars), or None when acceptable."""
    if not feedback:
        return "反馈内容不能为空"
    length = len(feedback)
    if length < 10:
        return "请提供更详细的修改建议"
    if length > 500:
        return "反馈内容请控制在500字以内"
    return None
def validate_session_id(session_id: str) -> Optional[str]:
    """Return an error message unless *session_id* is a canonical
    8-4-4-4-12 hex UUID (case-insensitive)."""
    if not session_id:
        return "会话ID不能为空"
    uuid_pattern = re.compile(
        r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$',
        re.IGNORECASE
    )
    return None if uuid_pattern.match(session_id) else "无效的会话ID格式"
def validate_subtopic_id(subtopic_id: str) -> Optional[str]:
    """Return an error message unless the ID is 'ST' followed by exactly
    8 hex digits (case-insensitive)."""
    if not subtopic_id:
        return "子主题ID不能为空"
    if re.match(r'^ST[0-9a-f]{8}$', subtopic_id, re.IGNORECASE) is None:
        return "无效的子主题ID格式"
    return None
def validate_search_query(query: str) -> Optional[str]:
    """Return an error message for an empty, too-short, over-long or
    dangerous-looking search query; None when it is acceptable."""
    if not query:
        return "搜索查询不能为空"
    if len(query) < 2:
        return "搜索查询太短"
    if len(query) > 200:
        return "搜索查询太长请控制在200字符以内"
    # Reject obvious XSS / SQL-injection payloads.
    dangerous_patterns = (
        r'<script',
        r'javascript:',
        r'onerror=',
        r'onclick=',
        r'DROP TABLE',
        r'DELETE FROM',
        r'INSERT INTO',
    )
    if any(re.search(pattern, query, re.IGNORECASE) for pattern in dangerous_patterns):
        return "搜索查询包含不允许的内容"
    return None
def validate_priority(priority: str) -> Optional[str]:
    """Return an error message unless *priority* is high/medium/low."""
    valid_priorities = ['high', 'medium', 'low']
    if priority in valid_priorities:
        return None
    return f"优先级必须是以下之一: {', '.join(valid_priorities)}"
def validate_report_format(format: str) -> Optional[str]:
    """Return an error message unless *format* is a supported report type
    (json, markdown, html or pdf)."""
    valid_formats = ['json', 'markdown', 'html', 'pdf']
    if format in valid_formats:
        return None
    return f"报告格式必须是以下之一: {', '.join(valid_formats)}"
def sanitize_filename(filename: str) -> str:
    """Make *filename* safe for saving: replace reserved characters with
    underscores, cap the length at 200 chars, and avoid creating a
    dot-prefixed (hidden) file."""
    safe = re.sub(r'[<>:"/\\|?*]', '_', filename)
    safe = safe[:200]
    if safe.startswith('.'):
        safe = '_' + safe[1:]
    return safe
def validate_json_structure(data: dict, required_fields: list) -> Optional[str]:
    """Return an error naming the first missing required field, or None
    when every required field is present."""
    missing = next((field for field in required_fields if field not in data), None)
    if missing is not None:
        return f"缺少必要字段: {missing}"
    return None

115
config.py Normal file
View File

@ -0,0 +1,115 @@
# 文件位置: config.py
# 文件名: config.py
import os
from datetime import timedelta
from dotenv import load_dotenv
# 加载环境变量
load_dotenv()
class Config:
    """Base configuration shared by all environments.

    Values come from environment variables (loaded via dotenv at module
    import) with hard-coded fallbacks for local development.
    """
    # Flask core settings.
    SECRET_KEY = os.environ.get('SECRET_KEY', 'dev-secret-key-change-in-production')
    DEBUG = False
    TESTING = False
    # External API credentials / endpoints.
    DEEPSEEK_API_KEY = os.environ.get('DEEPSEEK_API_KEY')
    DEEPSEEK_BASE_URL = os.environ.get('DEEPSEEK_BASE_URL', 'https://api.deepseek.com/v1')
    TAVILY_API_KEY = os.environ.get('TAVILY_API_KEY')
    # Model identifiers.
    R1_MODEL = "deepseek-reasoner"  # R1-0528
    V3_MODEL = "deepseek-chat"      # V3-0324
    # Research limits: concurrency and per-priority search budgets.
    MAX_CONCURRENT_SUBTOPICS = 10
    MAX_SEARCHES_HIGH_PRIORITY = 15
    MAX_SEARCHES_MEDIUM_PRIORITY = 10
    MAX_SEARCHES_LOW_PRIORITY = 5
    # Tavily search defaults.
    TAVILY_MAX_RESULTS = 10
    TAVILY_SEARCH_DEPTH = "advanced"
    TAVILY_INCLUDE_ANSWER = True
    TAVILY_INCLUDE_RAW_CONTENT = False
    # In-process task manager (replaces Celery).
    TASK_POOL_SIZE = 10  # thread-pool size
    # Per-task-type timeouts, in seconds.
    TASK_TIMEOUT = {
        'question_analysis': 60,
        'outline_creation': 120,
        'search': 30,
        'report_generation': 180
    }
    # Retry policy for failed tasks.
    MAX_RETRIES = 3
    RETRY_DELAY = 5  # seconds
    # On-disk storage layout, rooted next to this file.
    DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
    SESSIONS_DIR = os.path.join(DATA_DIR, 'sessions')
    REPORTS_DIR = os.path.join(DATA_DIR, 'reports')
    CACHE_DIR = os.path.join(DATA_DIR, 'cache')
    # Logging.
    LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO')
    LOG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')
    # MongoDB connection (optional feature).
    MONGODB_URI = os.environ.get('MONGODB_URI', 'mongodb://localhost:27017/deepresearch')
    # WebSocket async backend for Flask-SocketIO.
    SOCKETIO_ASYNC_MODE = 'eventlet'

    @staticmethod
    def init_app(app):
        """Create the data/log directories the app expects to exist."""
        for dir_path in [Config.DATA_DIR, Config.SESSIONS_DIR,
                         Config.REPORTS_DIR, Config.CACHE_DIR, Config.LOG_DIR]:
            os.makedirs(dir_path, exist_ok=True)
class DevelopmentConfig(Config):
    """Development settings: debug mode enabled."""
    DEBUG = True
class ProductionConfig(Config):
    """Production settings: debug off, rotating file log attached at startup."""
    DEBUG = False

    @classmethod
    def init_app(cls, app):
        """Run base initialization, then attach a rotating file handler to
        the Flask logger (only outside debug/testing modes)."""
        Config.init_app(app)
        # Production-specific logging setup.
        import logging
        from logging.handlers import RotatingFileHandler
        if not app.debug and not app.testing:
            file_handler = RotatingFileHandler(
                os.path.join(cls.LOG_DIR, 'deepresearch.log'),
                maxBytes=10485760,  # 10MB per file, 10 backups
                backupCount=10
            )
            file_handler.setFormatter(logging.Formatter(
                '%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]'
            ))
            file_handler.setLevel(logging.INFO)
            app.logger.addHandler(file_handler)
            app.logger.setLevel(logging.INFO)
            app.logger.info('DeepResearch startup')
class TestingConfig(Config):
    """Test settings: TESTING flag enables Flask's test behaviors."""
    TESTING = True
# Registry mapping FLASK_CONFIG names to configuration classes (used by
# create_app / app.py to select the environment at startup).
config = {
    'development': DevelopmentConfig,
    'production': ProductionConfig,
    'testing': TestingConfig,
    'default': DevelopmentConfig
}

0
data/cache/.gitkeep vendored Normal file
View File

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,4 @@
{"timestamp": "2025-07-02T10:22:06.050479", "session_id": "1d3103b4-05e4-4f7f-bac3-df18318667c4", "type": "json_parse_error", "raw_text": "{\n \"研究标题\": \"人工智能在医疗诊断领域的应用现状、挑战与未来趋势研究\",\n \"研究目标\": [\n \"系统梳理AI在医疗诊断中的核心应用场景及代表性案例\",\n \"揭示AI诊断系统在临床验证中的技术瓶颈与数据需求\",\n \"评估数据隐私与孤岛问题对技术落地的制约机制及应对方案\",\n \"构建医患接受度评估模型与伦理责任分析框架\",\n \"预测跨模态学习驱动下的个性化诊疗创新路径\"\n ],\n \"研究问题\": [\n \"Q1: 当前AI在医疗诊断中的主要应用领域及典型案例\",\n \"Q2: AI与传统诊断的临床准确性差异及技术瓶颈\",\n \"Q3: 数据隐私政策与孤岛问题对技术发展的影响\",\n \"Q4: 医患群体对AI诊断的信任差异与伦理挑战\",\n \"Q5: 跨模态融合在个性化诊疗中的发展潜力\"\n ],\n \"方法论\": [\n {\n \"方法名称\": \"文献计量分析\",\n \"方法描述\": \"运用VOSviewer对近五年核心文献进行共现分析构建技术应用知识图谱\",\n \"适用研究问题\": [\"Q1\"]\n },\n {\n \"方法名称\": \"对比实验验证\",\n \"方法描述\": \"选择CT影像诊断等典型场景设计ROC曲线与混淆矩阵对比分析框架\",\n \"适用研究问题\": [\"Q2\"]\n },\n {\n \"方法名称\": \"政策文本分析\",\n \"方法描述\": \"基于GDPR/HIPAA等法规构建隐私保护强度指数评估政策约束力\",\n \"适用研究问题\": [\"Q3\"]\n },\n {\n \"方法名称\": \"德尔菲专家调查\",\n \"方法描述\": \"通过三轮专家咨询建立医患信任度评估指标体系\",\n \"适用研究问题\": [\"Q4\"]\n },\n {\n \"方法名称\": \"技术成熟度预测\",\n \"方法描述\": \"应用Gartner曲线模型评估跨模态学习技术的商业化成熟度\",\n \"适用研究问题\": [\"Q5\"]\n }\n ],\n \"预期成果\": [\n \"发布医疗AI应用场景全景图与典型案例库\",\n \"建立AI诊断临床验证标准评估框架\",\n \"提出联邦学习优化方案与数据治理政策建议\",\n \"形成AI医疗伦理责任分配指南\",\n \"构建跨模态诊疗创新应用预测模型\"\n ]\n}", "error": "Missing required fields: dict_keys(['研究标题', '研究目标', '研究问题', '方法论', '预期成果'])", "fixed_text": null, "fixed": false}
{"timestamp": "2025-07-02T10:22:25.307986", "session_id": "1d3103b4-05e4-4f7f-bac3-df18318667c4", "type": "json_parse_error", "raw_text": "{\n \"研究主题\": \"分析当前人工智能在医疗诊断领域的应用现状和未来发展趋势\",\n \"研究类型\": \"探索性研究\",\n \"细化问题\": [\n \"当前人工智能在医疗诊断中的主要应用集中在哪些医学领域(如影像识别、病理分析、基因组学等)?有哪些具体案例或代表性研究成果?\",\n \"人工智能医疗诊断系统在临床应用中的准确性与传统诊断方法相比存在哪些技术瓶颈或验证挑战?需要哪些数据支撑和技术突破?\",\n \"医疗数据隐私保护政策和数据孤岛问题如何影响AI诊断模型的开发与部署现有解决方案如联邦学习的实践效果如何\",\n \"医生和患者对AI辅助诊断的接受度及信任度存在哪些差异伦理责任划分和算法透明性要求如何影响技术落地\",\n \"未来5-10年内AI在个性化诊疗和罕见病筛查领域的发展潜力如何跨模态医学数据融合会带来哪些新型应用场景\"\n ],\n \"研究框架\": {\n \"现状梳理\": {\n \"方法\": [\"文献计量法\", \"案例分析法\"],\n \"内容\": [\n \"通过文献计量法分析AI在医疗诊断领域的研究热点与趋势\",\n \"构建技术应用图谱涵盖影像识别如DeepMind眼科诊断、病理分析如IBM Watson肿瘤分析、基因组学等领域的代表性案例\"\n ]\n },\n \"瓶颈分析\": {\n \"方法\": [\"对比分析法\", \"政策文本分析\", \"案例评估\"],\n \"内容\": [\n \"对比AI与传统诊断方法在准确性、敏感性等指标上的临床验证数据\",\n \"从数据质量(标注一致性、多中心数据差异)、算法可解释性等维度剖析技术瓶颈\",\n \"评估医疗数据隐私政策如GDPR、HIPAA对模型开发的影响\",\n \"通过联邦学习案例如FATE框架分析数据孤岛问题的解决方案效能\"\n ]\n },\n \"趋势预测\": {\n \"方法\": [\"德尔菲法\", \"技术成熟度曲线分析\", \"跨模态学习进展评估\"],\n \"内容\": [\n \"通过德尔菲法整合专家意见,预测个性化诊疗与罕见病筛查的技术发展路径\",\n \"基于技术成熟度曲线Gartner曲线分析AI诊断技术的商业化落地阶段\",\n \"探讨跨模态医学数据融合(如影像-文本-生物信号)带来的新型应用场景\"\n ]\n },\n \"人文伦理研究\": {\n \"方法\": [\"问卷调查\", \"伦理框架分析\"],\n \"内容\": [\n \"设计医患双群体问卷量化AI辅助诊断的接受度差异及信任度影响因素\",\n \"结合伦理责任矩阵模型,探讨诊断错误时的责任划分机制\",\n \"从算法透明性需求出发提出可解释性AIXAI的临床部署建议\"\n ]\n }\n },\n \"预期成果\": [\n \"绘制AI医疗诊断技术应用热力图谱及技术成熟度评估矩阵\",\n \"建立涵盖数据、算法、临床验证的AI诊断瓶颈分析框架\",\n \"提出跨机构数据协作的联邦学习优化路径与隐私计算部署方案\",\n \"形成医患双视角的AI诊断接受度评估报告及伦理责任白皮书\",\n \"预测跨模态融合驱动的三类新型应用场景及罕见病筛查技术路线图\"\n ]\n}\n```", "error": "Missing required fields: dict_keys(['研究主题', '研究类型', '细化问题', '研究框架', '预期成果'])", "fixed_text": null, "fixed": false}
{"timestamp": "2025-07-02T10:22:43.652864", "session_id": "1d3103b4-05e4-4f7f-bac3-df18318667c4", "type": "json_parse_error", "raw_text": "{\n \"研究标题\": \"人工智能在医疗诊断领域的应用现状、技术瓶颈与未来趋势研究\",\n \"研究目标\": [\n \"系统梳理AI在医疗诊断各细分领域的技术应用现状\",\n \"识别AI医疗诊断系统在临床验证中的关键技术瓶颈\",\n \"评估数据隐私与伦理问题对技术落地的实际影响\",\n \"预测跨模态学习等新兴技术在诊疗场景中的发展潜力\",\n \"构建兼顾技术创新与伦理规范的AI医疗发展框架\"\n ],\n \"主要研究问题\": [\n \"当前AI医疗诊断的核心技术路径与典型应用场景分布\",\n \"深度学习模型在医学影像分析中的敏感性与特异性表现\",\n \"异构医疗数据标准化处理与多中心验证的实践难点\",\n \"联邦学习在保护患者隐私前提下的模型泛化能力验证\",\n \"临床决策支持系统的人机协同机制与责任认定边界\",\n \"多组学数据融合对精准诊疗方案制定的赋能效应\"\n ],\n \"研究方法\": [\n \"文献计量分析Web of Science核心论文聚类\",\n \"多中心临床数据对比实验ROC曲线与SHAP值分析\",\n \"政策文本主题建模LDA算法\",\n \"联邦学习仿真实验FATE平台跨机构建模\",\n \"医患双盲对照试验Likert量表与SEM模型\",\n \"技术成熟度评估Gartner曲线修正模型\"\n ],\n \"数据来源\": [\n \"PubMed/Embase收录的AI诊断相关临床研究论文\",\n \"NIH临床影像数据库CT/MRI标准化数据集\",\n \"中国卫生健康统计年鉴2018-2023\",\n \"欧盟GDPR与美国HIPAA政策白皮书\",\n \"腾讯觅影、推想医疗等企业技术白皮书\",\n \"三甲医院多模态诊疗数据集(脱敏处理)\"\n ],\n \"预期成果\": [\n \"绘制AI医疗诊断技术应用热力图谱含技术成熟度分级\",\n \"建立AI诊断模型临床验证的量化评估指标体系\",\n \"提出基于区块链的分布式医疗数据共享框架\",\n \"设计可解释性增强的临床决策支持系统架构\",\n \"发布医疗AI伦理治理的跨学科共识指南\",\n \"预测跨模态学习驱动的五大创新应用场景\"\n ]\n}", "error": "Missing required fields: dict_keys(['研究标题', '研究目标', '主要研究问题', '研究方法', '数据来源', '预期成果'])", "fixed_text": null, "fixed": false}
{"timestamp": "2025-07-02T10:22:43.654964", "session_id": "1d3103b4-05e4-4f7f-bac3-df18318667c4", "type": "json_parse_error", "raw_text": "{\n \"研究标题\": \"人工智能在医疗诊断领域的应用现状、技术瓶颈与未来趋势研究\",\n \"研究目标\": [\n \"系统梳理AI在医疗诊断各细分领域的技术应用现状\",\n \"识别AI医疗诊断系统在临床验证中的关键技术瓶颈\",\n \"评估数据隐私与伦理问题对技术落地的实际影响\",\n \"预测跨模态学习等新兴技术在诊疗场景中的发展潜力\",\n \"构建兼顾技术创新与伦理规范的AI医疗发展框架\"\n ],\n \"主要研究问题\": [\n \"当前AI医疗诊断的核心技术路径与典型应用场景分布\",\n \"深度学习模型在医学影像分析中的敏感性与特异性表现\",\n \"异构医疗数据标准化处理与多中心验证的实践难点\",\n \"联邦学习在保护患者隐私前提下的模型泛化能力验证\",\n \"临床决策支持系统的人机协同机制与责任认定边界\",\n \"多组学数据融合对精准诊疗方案制定的赋能效应\"\n ],\n \"研究方法\": [\n \"文献计量分析Web of Science核心论文聚类\",\n \"多中心临床数据对比实验ROC曲线与SHAP值分析\",\n \"政策文本主题建模LDA算法\",\n \"联邦学习仿真实验FATE平台跨机构建模\",\n \"医患双盲对照试验Likert量表与SEM模型\",\n \"技术成熟度评估Gartner曲线修正模型\"\n ],\n \"数据来源\": [\n \"PubMed/Embase收录的AI诊断相关临床研究论文\",\n \"NIH临床影像数据库CT/MRI标准化数据集\",\n \"中国卫生健康统计年鉴2018-2023\",\n \"欧盟GDPR与美国HIPAA政策白皮书\",\n \"腾讯觅影、推想医疗等企业技术白皮书\",\n \"三甲医院多模态诊疗数据集(脱敏处理)\"\n ],\n \"预期成果\": [\n \"绘制AI医疗诊断技术应用热力图谱含技术成熟度分级\",\n \"建立AI诊断模型临床验证的量化评估指标体系\",\n \"提出基于区块链的分布式医疗数据共享框架\",\n \"设计可解释性增强的临床决策支持系统架构\",\n \"发布医疗AI伦理治理的跨学科共识指南\",\n \"预测跨模态学习驱动的五大创新应用场景\"\n ]\n}", "error": "Failed to parse after 3 attempts, using default outline", "fixed_text": "{\"main_topic\": \"分析当前人工智能在医疗诊断领域的应用现状和未来发展趋势\", \"research_questions\": [\"当前人工智能在医疗诊断中的主要应用集中在哪些医学领域(如影像识别、病理分析、基因组学等)?有哪些具体案例或代表性研究成果?\", \"人工智能医疗诊断系统在临床应用中的准确性与传统诊断方法相比存在哪些技术瓶颈或验证挑战?需要哪些数据支撑和技术突破?\", \"医疗数据隐私保护政策和数据孤岛问题如何影响AI诊断模型的开发与部署现有解决方案如联邦学习的实践效果如何\"], \"sub_topics\": [{\"topic\": \"主要方面分析\", \"explain\": \"针对问题的核心方面进行深入分析\", \"priority\": \"high\", \"related_questions\": [\"当前人工智能在医疗诊断中的主要应用集中在哪些医学领域(如影像识别、病理分析、基因组学等)?有哪些具体案例或代表性研究成果?\", \"人工智能医疗诊断系统在临床应用中的准确性与传统诊断方法相比存在哪些技术瓶颈或验证挑战?需要哪些数据支撑和技术突破?\"]}]}", "fixed": true}

0
data/reports/.gitkeep Normal file
View File

0
data/sessions/.gitkeep Normal file
View File

17
requirements.txt Normal file
View File

@ -0,0 +1,17 @@
Flask==3.0.0
Flask-CORS==4.0.0
Flask-SocketIO==5.3.5
python-socketio==5.10.0
python-dotenv==1.0.0
requests==2.31.0
openai>=1.0.0,<2.0.0
tavily-python==0.5.0
celery==5.3.4
redis==5.0.1
pymongo==4.6.1
pydantic==2.5.3
python-json-logger==2.0.7
pytest==7.4.4
pytest-asyncio==0.23.3
gunicorn==21.2.0
eventlet==0.33.3

172
scripts/init_db.py Executable file
View File

@ -0,0 +1,172 @@
#!/usr/bin/env python3
"""
初始化数据库和目录结构
"""
import os
import sys
# 添加项目根目录到Python路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import Config
def init_directories():
    """Create the data/log directory tree and .gitkeep placeholder files."""
    wanted = (
        Config.DATA_DIR,
        Config.SESSIONS_DIR,
        Config.REPORTS_DIR,
        Config.CACHE_DIR,
        Config.LOG_DIR,
    )
    for path in wanted:
        if os.path.exists(path):
            print(f"目录已存在: {path}")
        else:
            os.makedirs(path)
            print(f"创建目录: {path}")
    # Keep otherwise-empty data directories under version control.
    for path in (Config.SESSIONS_DIR, Config.REPORTS_DIR, Config.CACHE_DIR):
        gitkeep_path = os.path.join(path, '.gitkeep')
        if not os.path.exists(gitkeep_path):
            with open(gitkeep_path, 'w') as handle:
                handle.write('')
            print(f"创建.gitkeep: {gitkeep_path}")
def init_mongodb():
    """Best-effort MongoDB setup: create collections and their indexes.

    MongoDB is optional, so every failure (missing driver, unreachable
    server, ...) is reported and swallowed.
    """
    try:
        from pymongo import MongoClient
        client = MongoClient(Config.MONGODB_URI)
        db = client.get_database()
        # Index plan per collection: (field, direction) pairs.
        index_plan = {
            'sessions': [
                ('created_at', -1),
                ('status', 1),
                ('question_type', 1)
            ],
            'search_results': [
                ('session_id', 1),
                ('subtopic_id', 1),
                ('created_at', -1)
            ],
            'reports': [
                ('session_id', 1),
                ('created_at', -1)
            ]
        }
        for name, indexes in index_plan.items():
            coll = db[name]
            for spec in indexes:
                if isinstance(spec, tuple):
                    coll.create_index([spec])
                else:
                    coll.create_index(spec)
            print(f"初始化集合: {name}")
        print("MongoDB初始化完成")
    except Exception as e:
        print(f"MongoDB初始化失败可选: {e}")
def check_environment():
    """Warn about any required API-key environment variables that are unset."""
    required_vars = [
        'DEEPSEEK_API_KEY',
        'TAVILY_API_KEY'
    ]
    missing_vars = [name for name in required_vars if not os.environ.get(name)]
    if not missing_vars:
        print("\n环境变量检查通过")
        return
    print("\n警告: 缺少以下环境变量:")
    for var in missing_vars:
        print(f" - {var}")
    print("\n请在.env文件中设置这些变量")
def test_task_manager():
    """Smoke-test the in-process task manager by submitting a trivial task."""
    print("\n测试任务管理器...")
    try:
        from app.services.task_manager import task_manager

        def probe():
            return "Task manager is working!"

        task_id = task_manager.submit_task(probe)
        print(f"✓ 任务管理器正常工作测试任务ID: {task_id}")
        # Release the worker threads once the smoke test is done.
        task_manager.shutdown()
    except Exception as e:
        print(f"✗ 任务管理器测试失败: {e}")
def create_test_data():
    """In development mode only, drop a sample session file into the sessions dir."""
    if os.environ.get('FLASK_ENV') != 'development':
        return
    print("\n开发环境:创建测试数据...")
    sample_session = {
        "id": "test-session-001",
        "question": "这是一个测试研究问题",
        "status": "completed",
        "created_at": "2024-01-01T00:00:00"
    }
    import json
    test_file = os.path.join(Config.SESSIONS_DIR, 'test-session-001.json')
    if not os.path.exists(test_file):
        with open(test_file, 'w', encoding='utf-8') as f:
            json.dump(sample_session, f, ensure_ascii=False, indent=2)
        print(f"创建测试会话文件: {test_file}")
def main():
    """Run every initialisation step in order and print progress banners."""
    print("DeepResearch 初始化脚本")
    print("=" * 50)
    # (banner, step) pairs, executed in order.
    steps = (
        ("\n1. 初始化目录结构...", init_directories),
        ("\n2. 初始化MongoDB...", init_mongodb),
        ("\n3. 检查环境变量...", check_environment),
        ("\n4. 测试任务管理器...", test_task_manager),
        ("\n5. 创建测试数据...", create_test_data),
    )
    for banner, step in steps:
        print(banner)
        step()
    print("\n" + "=" * 50)
    print("初始化完成!")
    print("\n下一步:")
    print("1. 确保在.env文件中设置了必要的API密钥")
    print("2. 运行 'python app.py' 启动应用")
    print("\n注意: 不再需要启动 Redis 和 Celery Worker")


if __name__ == '__main__':
    main()

279
scripts/test_api_keys.py Executable file
View File

@ -0,0 +1,279 @@
#!/usr/bin/env python3
"""
测试API密钥是否有效
"""
import os
import sys
from dotenv import load_dotenv
# 添加项目根目录到Python路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 加载环境变量
load_dotenv()
def test_deepseek_api():
    """Check that the configured DeepSeek API key actually works.

    Tries the V3 (chat) model first because it is the more stable endpoint;
    a V3 failure fails the whole check, while an R1 (reasoner) failure only
    warns, since V3 alone is enough for the application.

    Returns:
        bool: True when the key is usable, False otherwise.
    """
    print("\n测试 DeepSeek API...")
    api_key = os.environ.get('DEEPSEEK_API_KEY')
    if not api_key:
        print("❌ 错误: 未设置 DEEPSEEK_API_KEY")
        return False
    try:
        from openai import OpenAI
        base_url = os.environ.get('DEEPSEEK_BASE_URL', 'https://api.deepseek.com/v1')
        # Volcano Engine (ARK) hosts DeepSeek under different model ids.
        if 'volces.com' in base_url:
            print(" 检测到火山引擎 ARK 平台")
            r1_model = "deepseek-r1-250120"
            v3_model = "deepseek-v3-241226"
        else:
            r1_model = "deepseek-reasoner"
            v3_model = "deepseek-chat"
        client = OpenAI(
            api_key=api_key,
            base_url=base_url
        )
        # Test the V3 model first (more stable; its failure is fatal).
        print(f" 测试 V3 模型 ({v3_model})...")
        try:
            response = client.chat.completions.create(
                model=v3_model,
                messages=[{"role": "user", "content": "Hello, this is a test. Reply with OK."}],
                max_tokens=10
            )
            if response.choices[0].message.content:
                print(" ✓ V3 模型测试成功")
            else:
                print(" ❌ V3 模型响应异常")
                return False
        except Exception as e:
            print(f" ❌ V3 模型测试失败: {e}")
            # Volcano Engine deployments may require a custom endpoint id.
            if 'volces.com' in base_url:
                print(" 提示: 火山引擎可能需要使用自定义的 endpoint ID (如 ep-xxxxx)")
            return False
        # Test the R1 model; failure here is non-fatal because V3 works.
        print(f" 测试 R1 模型 ({r1_model})...")
        try:
            response = client.chat.completions.create(
                model=r1_model,
                messages=[{"role": "user", "content": "Hello, this is a test. Reply with OK."}],
                max_tokens=10
            )
            if response.choices[0].message.content:
                print(" ✓ R1 模型测试成功")
            else:
                print(" ❌ R1 模型响应异常")
                # An R1 failure does not fail the overall check; V3 works.
        except Exception as e:
            print(f" ⚠️ R1 模型测试失败: {e}")
            print(" 注意: R1模型可能需要特殊配置或不可用")
        print("✅ DeepSeek API 测试通过")
        return True
    except Exception as e:
        print(f"❌ DeepSeek API 测试失败: {e}")
        return False
def test_tavily_api():
    """Verify the Tavily search API key with a one-result probe query."""
    print("\n测试 Tavily API...")
    api_key = os.environ.get('TAVILY_API_KEY')
    if not api_key:
        print("❌ 错误: 未设置 TAVILY_API_KEY")
        return False
    try:
        from tavily import TavilyClient
        client = TavilyClient(api_key=api_key)
        # Run a minimal search so a bad key fails fast.
        print(" 执行测试搜索...")
        response = client.search("test query", max_results=1)
        if not response or 'results' not in response:
            print(" ❌ 搜索响应异常")
            return False
        print(f" ✓ 搜索返回 {len(response['results'])} 条结果")
        print("✅ Tavily API 测试通过")
        return True
    except Exception as e:
        print(f"❌ Tavily API 测试失败: {e}")
        return False
def test_task_manager():
    """Exercise the task manager end-to-end: submit a task, wait, check status."""
    print("\n测试任务管理器...")
    try:
        from app.services.task_manager import task_manager

        def doubler(x):
            return x * 2

        task_id = task_manager.submit_task(doubler, 5)
        print(f" ✓ 任务提交成功: {task_id}")
        # Give the worker thread a moment to finish the task.
        import time
        time.sleep(1)
        status = task_manager.get_task_status(task_id)
        if not status or status['status'] != 'completed':
            print(f" ❌ 任务状态异常: {status}")
            return False
        print(" ✓ 任务执行成功")
        print("✅ 任务管理器测试通过")
        return True
    except Exception as e:
        print(f"❌ 任务管理器测试失败: {e}")
        return False
def test_mongodb_connection():
    """Optionally verify MongoDB connectivity plus a read/write round-trip."""
    print("\n测试 MongoDB 连接(可选)...")
    mongodb_uri = os.environ.get('MONGODB_URI', 'mongodb://localhost:27017/deepresearch')
    try:
        from pymongo import MongoClient
        client = MongoClient(mongodb_uri, serverSelectionTimeoutMS=5000)
        client.server_info()  # raises if the server is unreachable
        print("✅ MongoDB 连接测试通过")
        db = client.get_database()
        test_collection = db['test_collection']
        # Round-trip a throwaway document: insert, read back, delete.
        result = test_collection.insert_one({'test': 'document'})
        doc = test_collection.find_one({'_id': result.inserted_id})
        test_collection.delete_one({'_id': result.inserted_id})
        if doc and doc['test'] == 'document':
            print(" ✓ MongoDB 读写测试通过")
            return True
        print(" ❌ MongoDB 读写测试失败")
        return False
    except Exception as e:
        print(f"⚠️ MongoDB 连接测试失败(可选): {e}")
        print(" 提示: MongoDB是可选的不影响基本功能")
        return False
def check_python_version():
    """Require Python 3.8+; report the interpreter version either way."""
    print("\n检查 Python 版本...")
    major, minor, micro = sys.version_info[:3]
    if major == 3 and minor >= 8:
        print(f"✅ Python 版本: {major}.{minor}.{micro}")
        return True
    print(f"❌ Python 版本过低: {major}.{minor}.{micro}")
    print(" 需要 Python 3.8 或更高版本")
    return False
def check_dependencies():
    """Check that every required third-party package can be imported.

    Bug fix: the pip distribution name can differ from the importable module
    name — ``__import__('python-dotenv')`` can never succeed (hyphens are not
    valid in module names; the package installs module ``dotenv``), so the
    original always reported python-dotenv as missing. Each requirement is
    therefore a mapping of display name -> importable module name.

    Returns:
        bool: True when all packages import cleanly, else False after
        printing the missing ones.
    """
    print("\n检查依赖包...")
    # display name (shown to the user) -> importable module name
    required_packages = {
        'flask': 'flask',
        'flask_cors': 'flask_cors',
        'flask_socketio': 'flask_socketio',
        'openai': 'openai',
        'tavily': 'tavily',
        'pydantic': 'pydantic',
        'python-dotenv': 'dotenv',  # distribution name != module name
    }
    missing_packages = []
    for display_name, module_name in required_packages.items():
        try:
            __import__(module_name)
        except ImportError:
            missing_packages.append(display_name)
    if missing_packages:
        print(f"❌ 缺少以下依赖包: {', '.join(missing_packages)}")
        print(" 请运行: pip install -r requirements.txt")
        return False
    else:
        print("✅ 所有必需的依赖包已安装")
        return True
def main():
    """Run every environment/API check and print a pass/fail summary."""
    print("=" * 60)
    print("DeepResearch API 密钥测试工具")
    print("=" * 60)
    # Required checks, in the order they should run.
    checks = (
        check_python_version,
        check_dependencies,
        test_deepseek_api,
        test_tavily_api,
        test_task_manager,
    )
    all_passed = True
    for check in checks:
        if not check():
            all_passed = False
    # MongoDB is optional, so its result never fails the run.
    test_mongodb_connection()
    print("\n" + "=" * 60)
    if all_passed:
        print("✅ 所有必需的测试都已通过!")
        print("\n您可以运行以下命令启动应用:")
        print("1. 启动应用: python app.py")
        print("\n注意: 不再需要启动 Celery Worker 和 Redis")
    else:
        print("❌ 有些测试未通过,请检查上述错误信息")
        print("\n常见问题:")
        print("1. 确保在.env文件中正确设置了API密钥")
        print("2. 检查网络连接是否正常")


if __name__ == '__main__':
    main()

17
start.sh Executable file
View File

@ -0,0 +1,17 @@
#!/bin/bash
# Launch the DeepResearch application.

# Run from the directory this script lives in, so relative paths
# (venv/, logs/) resolve regardless of where it is invoked from.
# (Previously hard-coded to /Users/jojo/Desktop/deepresearch.)
cd "$(dirname "$0")"

# Activate the project virtual environment
source venv/bin/activate

# Remove logs from previous runs
echo "清理旧日志..."
rm -f logs/*.log

# Start the application
echo "启动 DeepResearch..."
python3 app.py

0
tests/__init__.py Normal file
View File

0
tests/conftest.py Normal file
View File

152
一 准备工作.ini Normal file
View File

@ -0,0 +1,152 @@
一 准备工作
1.用户输入问题
ai判断 r1
事实查询型:需要具体、准确的信息 2025
分析对比型:需要多角度分析和比较 对比M4Max的macbookpro和5090幻16air谁是最强全能本
探索发现型:需要广泛探索未知领域 如何让小参数大语言模型进行稳定的结构化输出
决策支持型:需要综合分析支持决策 500万元怎么做投资
2.ai输出 询问一些具体细节,比如还需要哪些问题相关信息,应该更关注问题的什么方面等数个问题 r1
3.ai根据上一问的结果和初始问题开始研究 r1
二 制定大纲
1.ai开始制定大纲json格式 r1
包括主研究问题 研究的数个分点
research_plan = {
"main_topic": "用户输入的主题",
"research_questions": [
"核心问题1",
"核心问题2",
...
],
"sub_topics1": [
{
"topic": "子主题1", #可以是核心问题,也可以是其拓展
"explain":"子主题1的简单解释"
"priority": "高/中/低"
}
],
"sub_topics2": [
{
"topic": "子主题2", #可以是核心问题,也可以是其拓展
"explain":
"priority": "高/中/低"
}
],
...
}
2.ai进行第一轮搜索看看自己的这些核心问题和子主题是否需要补充更多/需要更改/替换 r1
3.输出给用户大纲,是否满意,用户给出修改意见
4.ai再次搜索修改大纲再次询问 r1
5.直到满意为止(用户手动触发)
三 开展研究
1.开始搜索各个子主题(并发)
创建多个子智能体,每个智能体负责一个子主题
子智能体:
根据priority确定搜索数量15/10/5根据子主题确定搜索内容
搜索使用tavily的api服务 v3 其返回结果类似
```markdown
AI摘要
Ollama is a local framework for running large language models, and Open WebUI is a user-friendly front-end interface. It supports various models and can be deployed locally.
搜索结果:
1. 使用这些前端工具与Ollama一起工作 - 知乎专栏
来源: https://zhuanlan.zhihu.com/p/699544176
大型语言模型LLMs是可以做到这一点的强大工具但使用它们可能会很棘手。本文探讨了三种用户友好的界面——Msty、Open WebUI 和Page Assist——它们使通过...
2. 如何与Ollama 一起在本地运行大语言模型 - 知乎专栏
来源: https://zhuanlan.zhihu.com/p/690122036
它支持许多可以运行的模型,不断添加新模型,并且您也可以引入自定义模型。 它有一个优秀的第3 方开源前端名为“Ollama Web-UI”您可以使用。 它支持多模式模型。...
3. 带你认识本地大语言模型框架Ollama(可直接上手)
来源: https://wiki.eryajf.net/pages/97047e/
Copyright 二丫讲梵(opens new window) 版权所有
```
根据搜索结果,给每个搜索结果评定重要性 r1 高/中/低
评判依据 主题匹配度 问题覆盖度 信息新颖度
储存进后台json文件中
四 信息反思
根据每个子主题已获取的信息,进行反思 r1
告知ai这样思考
<think>
好的,现在需要梳理已获取的信息
已获得信息总结:
来源1 [关键发现1]
从这个信息中可知:
来源2 [关键发现2]
...
让我再仔细思考总结一下是否有哪些信息非常重要,需要更细节的内容
1.由于...来源1的信息是重点
2.由于...来源2的信息是重点
</think>
以下信息还需要进一步获取细节内容,以下是信息和具体在什么细节搜索
[重点信息1]还需要搜索(细节)
[重点信息2]还需要搜索...
五 二次搜索
根据[重点信息]和需要搜索的细节使用tavily的api服务进行搜索 v3
结果储存进json文件位置要在原本的信息源后面并说明 这是一个重要信息,以及(细节)是什么
四 子主题信息整合
根据获取的信息,信息的重要性 将每个子主题的信息点结构化
(这我暂时想不到怎么设计结构和内容)
五 子主题撰写
输入所有这个子主题的整合信息
r1
撰写要求格式如下
一、[主要发现1]
1.1 [子标题]
[内容](来源:[具体URL]
二、[主要发现2]
2.1 [子标题]
[内容](来源:[具体URL]
三、关键洞察
1. **[洞察1]**:基于[来源URL]的数据显示...
2. **[洞察2]**:根据[来源URL]的分析...
四、建议与展望
[基于研究的可执行建议]
六 幻觉判断
根据子主题撰写的内容进行幻觉消除
具体步骤为根据撰写内容中的所有每个url和对应的内容看看能否找到该子信息在搜索结果中的出处
如果找到,就把两条内容都摘出来 交给ai r1
让ai判断是否有任何幻觉内容 如果有,标记 没有,通过
七 幻觉内容重新撰写
给ai每一条被标记为幻觉的内容 以及原始搜索材料
重新撰写这一部分,直接替换原内容 v3
八 最终材料报告
根据每一个子主题的报告,写汇总报告 格式要求还是必须要url的那种 r1
在报告的最后直接插入每个子主题的报告

52
所有文件/.gitignore vendored Normal file
View File

@ -0,0 +1,52 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
venv/
ENV/
.venv
# Flask
instance/
.webassets-cache
# Environment variables
.env
.env.local
.env.*.local
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
# Logs
logs/
*.log
# Data
data/sessions/*
data/reports/*
data/cache/*
!data/sessions/.gitkeep
!data/reports/.gitkeep
!data/cache/.gitkeep
# Testing
.pytest_cache/
.coverage
htmlcov/
.tox/
# OS
.DS_Store
Thumbs.db
# Celery
celerybeat-schedule
celerybeat.pid

0
所有文件/README.md Normal file
View File

0
所有文件/__init__.py Normal file
View File

307
所有文件/ai_service.py Normal file
View File

@ -0,0 +1,307 @@
"""
AI服务层
封装对R1和V3智能体的调用
"""
import logging
from typing import Dict, List, Any, Optional, Tuple
from app.agents.r1_agent import R1Agent
from app.agents.v3_agent import V3Agent
from app.models.search_result import SearchResult, SearchImportance
from config import Config
logger = logging.getLogger(__name__)
class AIService:
    """Unified facade over the R1 (reasoning) and V3 (chat) agents.

    Each public method wraps one stage of the research pipeline; agent
    failures are caught, logged, and replaced with a conservative fallback
    so a single AI error does not abort the whole pipeline.
    """
    def __init__(self):
        # R1 handles reasoning-heavy stages; V3 handles lighter generation.
        self.r1_agent = R1Agent()
        self.v3_agent = V3Agent()
    # ========== Question-analysis stage (R1) ==========
    def analyze_question_type(self, question: str) -> str:
        """Classify the user's question; falls back to "exploratory" on error."""
        try:
            return self.r1_agent.analyze_question_type(question)
        except Exception as e:
            logger.error(f"分析问题类型失败: {e}")
            return "exploratory"  # default question type
    def refine_questions(self, question: str, question_type: str) -> List[str]:
        """Break the question into refined sub-questions; falls back to [question]."""
        try:
            return self.r1_agent.refine_questions(question, question_type)
        except Exception as e:
            logger.error(f"细化问题失败: {e}")
            return [question]  # fall back to the original question
    def create_research_approach(self, question: str, question_type: str,
                                 refined_questions: List[str]) -> str:
        """Draft a research approach; falls back to a generic statement on error."""
        try:
            return self.r1_agent.create_research_approach(
                question, question_type, refined_questions
            )
        except Exception as e:
            logger.error(f"制定研究思路失败: {e}")
            return "采用系统化的方法深入研究这个问题。"
    # ========== Outline stage (R1) ==========
    def create_outline(self, question: str, question_type: str,
                       refined_questions: List[str], research_approach: str) -> Dict[str, Any]:
        """Create the research outline; falls back to a minimal one-subtopic outline."""
        try:
            return self.r1_agent.create_outline(
                question, question_type, refined_questions, research_approach
            )
        except Exception as e:
            logger.error(f"创建大纲失败: {e}")
            # Minimal well-formed fallback outline
            return {
                "main_topic": question,
                "research_questions": refined_questions[:3],
                "sub_topics": [
                    {
                        "topic": "核心分析",
                        "explain": "对问题进行深入分析",
                        "priority": "high",
                        "related_questions": refined_questions[:2]
                    }
                ]
            }
    def validate_outline(self, outline: Dict[str, Any]) -> str:
        """Ask R1 to critique the outline; falls back to an "outline OK" message."""
        try:
            return self.r1_agent.validate_outline(outline)
        except Exception as e:
            logger.error(f"验证大纲失败: {e}")
            return "大纲结构合理。"
    def modify_outline(self, original_outline: Dict[str, Any],
                       user_feedback: str, validation_issues: str = "") -> Dict[str, Any]:
        """Revise the outline from user feedback; falls back to the unmodified outline."""
        try:
            return self.r1_agent.modify_outline(
                original_outline, user_feedback, validation_issues
            )
        except Exception as e:
            logger.error(f"修改大纲失败: {e}")
            return original_outline
    # ========== Search stage (V3 + R1) ==========
    def generate_search_queries(self, subtopic: str, explanation: str,
                                related_questions: List[str], priority: str) -> List[str]:
        """Generate search queries with V3; priority caps how many are requested."""
        # Number of searches allotted per priority level (from Config).
        count_map = {
            "high": Config.MAX_SEARCHES_HIGH_PRIORITY,
            "medium": Config.MAX_SEARCHES_MEDIUM_PRIORITY,
            "low": Config.MAX_SEARCHES_LOW_PRIORITY
        }
        count = count_map.get(priority, 10)
        try:
            return self.v3_agent.generate_search_queries(
                subtopic, explanation, related_questions, count
            )
        except Exception as e:
            logger.error(f"生成搜索查询失败: {e}")
            # Basic fallback queries derived from the subtopic itself
            return [subtopic, f"{subtopic} {explanation}"][:count]
    def evaluate_search_results(self, subtopic: str,
                                search_results: List[SearchResult]) -> List[SearchResult]:
        """Score each result's importance with R1 (defaults to MEDIUM on error)."""
        evaluated_results = []
        for result in search_results:
            try:
                importance = self.r1_agent.evaluate_search_result(
                    subtopic,
                    result.title,
                    result.url,
                    result.snippet
                )
                result.importance = SearchImportance(importance)
                evaluated_results.append(result)
            except Exception as e:
                logger.error(f"评估搜索结果失败: {e}")
                result.importance = SearchImportance.MEDIUM
                evaluated_results.append(result)
        return evaluated_results
    # ========== Reflection stage (R1) ==========
    def reflect_on_information(self, subtopic: str,
                               search_results: List[SearchResult]) -> List[Dict[str, str]]:
        """Reflect on gathered results; returns key points that need deeper search."""
        # Summarize what has been gathered before asking R1 to reflect on it.
        summary = self._generate_search_summary(search_results)
        try:
            return self.r1_agent.reflect_on_information(subtopic, summary)
        except Exception as e:
            logger.error(f"信息反思失败: {e}")
            return []
    def generate_refined_queries(self, key_points: List[Dict[str, str]]) -> Dict[str, List[str]]:
        """Generate follow-up queries (V3) per key point; maps key_info -> queries."""
        refined_queries = {}
        for point in key_points:
            try:
                queries = self.v3_agent.generate_refined_queries(
                    point["key_info"],
                    point["detail_needed"]
                )
                refined_queries[point["key_info"]] = queries
            except Exception as e:
                logger.error(f"生成细化查询失败: {e}")
                refined_queries[point["key_info"]] = [point["key_info"]]
        return refined_queries
    # ========== Integration stage (R1) ==========
    def integrate_information(self, subtopic: str,
                              all_search_results: List[SearchResult]) -> Dict[str, Any]:
        """Integrate all of a subtopic's results; falls back to an empty structure."""
        formatted_results = self._format_search_results_for_integration(all_search_results)
        try:
            return self.r1_agent.integrate_information(subtopic, formatted_results)
        except Exception as e:
            logger.error(f"整合信息失败: {e}")
            # Empty but well-formed integration result
            return {
                "key_points": [],
                "themes": []
            }
    # ========== Report-writing stage (R1) ==========
    def write_subtopic_report(self, subtopic: str, integrated_info: Dict[str, Any]) -> str:
        """Write one subtopic's report; falls back to an error stub section."""
        try:
            return self.r1_agent.write_subtopic_report(subtopic, integrated_info)
        except Exception as e:
            logger.error(f"撰写子主题报告失败: {e}")
            return f"## {subtopic}\n\n撰写报告时发生错误。"
    # ========== Hallucination-check stage (R1 + V3) ==========
    def detect_and_fix_hallucinations(self, report: str,
                                      original_sources: Dict[str, str]) -> Tuple[str, List[Dict]]:
        """Detect hallucinated claims (R1) and rewrite them in place (V3).

        Returns (fixed_report, hallucinations), where each hallucination dict
        records the url, the offending content, a type label, and an explanation.
        """
        hallucinations = []
        fixed_report = report
        # Map each cited URL to the claim attributed to it in the report.
        url_references = self._extract_url_references(report)
        for url, content in url_references.items():
            if url in original_sources:
                try:
                    # R1 judges the claim against the original source text.
                    result = self.r1_agent.detect_hallucination(
                        content, url, original_sources[url]
                    )
                    if result.get("is_hallucination", False):
                        hallucinations.append({
                            "url": url,
                            "content": content,
                            "type": result.get("hallucination_type", "未知"),
                            "explanation": result.get("explanation", "")
                        })
                        # V3 rewrites the flagged passage from the source material.
                        try:
                            new_content = self.v3_agent.rewrite_hallucination(
                                content, original_sources[url]
                            )
                            fixed_report = fixed_report.replace(content, new_content)
                        except Exception as e:
                            logger.error(f"重写幻觉内容失败: {e}")
                except Exception as e:
                    logger.error(f"检测幻觉失败: {e}")
        return fixed_report, hallucinations
    # ========== Final-report stage (R1) ==========
    def generate_final_report(self, main_topic: str, research_questions: List[str],
                              subtopic_reports: Dict[str, str]) -> str:
        """Compose the final report; falls back to concatenating subtopic reports."""
        try:
            return self.r1_agent.generate_final_report(
                main_topic, research_questions, subtopic_reports
            )
        except Exception as e:
            logger.error(f"生成最终报告失败: {e}")
            # Fallback: stitch the subtopic reports together under one heading.
            reports_text = "\n\n---\n\n".join(subtopic_reports.values())
            return f"# {main_topic}\n\n## 研究报告\n\n{reports_text}"
    # ========== Helpers ==========
    def _generate_search_summary(self, search_results: List[SearchResult]) -> str:
        """Build a short textual summary: importance counts + top high-importance hits."""
        high_count = sum(1 for r in search_results if r.importance == SearchImportance.HIGH)
        medium_count = sum(1 for r in search_results if r.importance == SearchImportance.MEDIUM)
        low_count = sum(1 for r in search_results if r.importance == SearchImportance.LOW)
        summary_lines = [
            f"共找到 {len(search_results)} 条搜索结果",
            f"高重要性: {high_count}",
            f"中重要性: {medium_count}",
            f"低重要性: {low_count}",
            "",
            "主要发现:"
        ]
        # Only the first 10 results are scanned for high-importance highlights.
        for result in search_results[:10]:
            if result.importance == SearchImportance.HIGH:
                summary_lines.append(f"- {result.title}: {result.snippet[:100]}...")
        return '\n'.join(summary_lines)
    def _format_search_results_for_integration(self, search_results: List[SearchResult]) -> str:
        """Render search results as a numbered plain-text listing for the R1 prompt."""
        formatted_lines = []
        for i, result in enumerate(search_results, 1):
            formatted_lines.extend([
                f"{i}. 来源: {result.url}",
                f" 标题: {result.title}",
                f" 内容: {result.snippet}",
                f" 重要性: {result.importance.value if result.importance else '未评估'}",
                ""
            ])
        return '\n'.join(formatted_lines)
    def _extract_url_references(self, report: str) -> Dict[str, str]:
        """Extract url -> cited-content pairs from the report text.

        NOTE(review): the regex below looks mojibake-damaged — the fullwidth
        parentheses that should delimit the "内容(来源:URL)" citation form
        appear to have been lost from the character classes. Verify that it
        still matches the report's citation format and that group(1)/group(2)
        select the intended content/url spans. Also note that a URL cited
        twice keeps only the last content (dict key overwrite).
        """
        # Simple regex-based implementation; a proper parser may be needed later.
        import re
        url_references = {}
        # Intended pattern: content followed by "(来源:URL)"
        pattern = r'([^]+)(来源:([^]+)'
        matches = re.finditer(pattern, report)
        for match in matches:
            content = match.group(1).strip()
            url = match.group(2).strip()
            url_references[url] = content
        return url_references

51
所有文件/api.js Normal file
View File

@ -0,0 +1,51 @@
// app/static/js/api.js
const API_BASE = '/api';

/**
 * Thin client for the research REST API.
 * Every call resolves with the parsed JSON response body
 * (except downloadReport, which opens a browser download).
 */
const api = {
    // Create a research session; auto_start makes the backend begin at once.
    async createResearch(question) {
        const response = await fetch(`${API_BASE}/research`, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify({
                question: question,
                auto_start: true,
            }),
        });
        return response.json();
    },

    // List sessions with limit/offset paging.
    async getSessions(limit = 20, offset = 0) {
        const response = await fetch(`${API_BASE}/research/sessions?limit=${limit}&offset=${offset}`);
        return response.json();
    },

    // Fetch the current status of one session.
    async getSessionStatus(sessionId) {
        const response = await fetch(`${API_BASE}/research/${sessionId}/status`);
        return response.json();
    },

    // Fetch the research outline of one session.
    async getOutline(sessionId) {
        const response = await fetch(`${API_BASE}/research/${sessionId}/outline`);
        return response.json();
    },

    // Ask the backend to cancel a running session.
    async cancelResearch(sessionId) {
        const response = await fetch(`${API_BASE}/research/${sessionId}/cancel`, {
            method: 'POST',
        });
        return response.json();
    },

    // Open the Markdown report in a new tab (the browser handles the download).
    async downloadReport(sessionId) {
        window.open(`${API_BASE}/research/${sessionId}/report?format=markdown`, '_blank');
    },
};

328
所有文件/api.py Normal file
View File

@ -0,0 +1,328 @@
# 文件位置: app/routes/api.py
# 文件名: api.py
"""
API路由
处理研究相关的API请求
"""
from flask import Blueprint, request, jsonify, current_app, send_file
from app.services.research_manager import ResearchManager
from app.services.task_manager import task_manager
from app.utils.validators import validate_question, validate_outline_feedback
import os
# Blueprint exposing the research REST API; one shared manager instance
# coordinates research sessions for every request handler below.
api_bp = Blueprint('api', __name__)
research_manager = ResearchManager()
@api_bp.route('/research', methods=['POST'])
def create_research():
    """Create a new research session.

    Body: ``{"question": str, "auto_start": bool (default True)}``.
    Returns 400 for a missing/invalid body or question, 500 on internal error.
    """
    try:
        # silent=True yields None (instead of raising) for a malformed or
        # missing JSON body, so bad requests get a 400 rather than a 500
        # from calling .get() on None.
        data = request.get_json(silent=True) or {}
        question = data.get('question', '').strip()
        error = validate_question(question)
        if error:
            return jsonify({"error": error}), 400
        session = research_manager.create_session(question)
        # Optionally kick off the pipeline immediately (the default).
        auto_start = data.get('auto_start', True)
        if auto_start:
            research_manager.start_research(session.id)
            return jsonify({
                "session_id": session.id,
                "status": "started",
                "message": "研究已开始",
                "created_at": session.created_at.isoformat()
            })
        else:
            return jsonify({
                "session_id": session.id,
                "status": "created",
                "message": "研究会话已创建,等待开始",
                "created_at": session.created_at.isoformat()
            })
    except Exception as e:
        current_app.logger.error(f"创建研究失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/start', methods=['POST'])
def start_research(session_id):
    """Manually start a previously created research session."""
    try:
        return jsonify(research_manager.start_research(session_id))
    except ValueError as e:
        # Unknown session id
        return jsonify({"error": str(e)}), 404
    except Exception as e:
        current_app.logger.error(f"开始研究失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/status', methods=['GET'])
def get_research_status(session_id):
    """Return the session status enriched with its task list."""
    try:
        status = research_manager.get_session_status(session_id)
        if "error" in status:
            return jsonify(status), 404
        # Attach the session's task information from the task manager.
        status['tasks'] = task_manager.get_session_tasks(session_id)
        return jsonify(status)
    except Exception as e:
        current_app.logger.error(f"获取状态失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/outline', methods=['GET'])
def get_research_outline(session_id):
    """Return the research outline for a session, if it has been created."""
    try:
        session = research_manager.get_session(session_id)
        if not session:
            return jsonify({"error": "Session not found"}), 404
        outline = session.outline
        if not outline:
            return jsonify({"error": "Outline not yet created"}), 400
        # Serialize subtopics to plain dicts for the JSON response.
        sub_topics = []
        for st in outline.sub_topics:
            sub_topics.append({
                "id": st.id,
                "topic": st.topic,
                "explain": st.explain,
                "priority": st.priority,
                "status": st.status
            })
        return jsonify({
            "main_topic": outline.main_topic,
            "research_questions": outline.research_questions,
            "sub_topics": sub_topics,
            "version": outline.version
        })
    except Exception as e:
        current_app.logger.error(f"获取大纲失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/outline', methods=['PUT'])
def update_research_outline(session_id):
    """Accept user feedback on the outline (revision itself not implemented yet)."""
    try:
        payload = request.get_json()
        feedback = payload.get('feedback', '').strip()
        error = validate_outline_feedback(feedback)
        if error:
            return jsonify({"error": error}), 400
        # TODO: apply the feedback via the AI service to revise the outline.
        return jsonify({
            "message": "大纲更新请求已接收",
            "status": "processing"
        })
    except Exception as e:
        current_app.logger.error(f"更新大纲失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/cancel', methods=['POST'])
def cancel_research(session_id):
    """Cancel a running research session and its pending tasks."""
    try:
        # Stop any queued/running tasks first, then flip the session state.
        cancelled = task_manager.cancel_session_tasks(session_id)
        outcome = research_manager.cancel_research(session_id)
        if "error" in outcome:
            return jsonify(outcome), 404
        outcome['cancelled_tasks'] = cancelled
        return jsonify(outcome)
    except Exception as e:
        current_app.logger.error(f"取消研究失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/report', methods=['GET'])
def get_research_report(session_id):
    """Return the research report as JSON (default) or as a Markdown download.

    Query args:
        format: "json" (default) or "markdown".
    """
    try:
        report_format = request.args.get('format', 'json')  # renamed: don't shadow builtin format()
        report_content = research_manager.get_research_report(session_id)
        if not report_content:
            return jsonify({"error": "Report not available"}), 404
        if report_format == 'markdown':
            report_path = os.path.join(
                current_app.config['REPORTS_DIR'],
                f"{session_id}.md"
            )
            # Materialize the file on first download, then serve it once.
            # (Previously the send_file call was duplicated in both branches.)
            if not os.path.exists(report_path):
                with open(report_path, 'w', encoding='utf-8') as f:
                    f.write(report_content)
            return send_file(
                report_path,
                mimetype='text/markdown',
                as_attachment=True,
                download_name=f"research_report_{session_id}.md"
            )
        # Default: JSON payload carrying the Markdown text.
        return jsonify({
            "session_id": session_id,
            "report": report_content,
            "format": "markdown"
        })
    except Exception as e:
        current_app.logger.error(f"获取报告失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/<session_id>/subtopic/<subtopic_id>', methods=['GET'])
def get_subtopic_detail(session_id, subtopic_id):
    """Return detail/progress for one subtopic of a session's outline.

    Fix over the original: progress is 0 when ``max_searches`` is 0 instead
    of raising ZeroDivisionError (caught as a misleading 500).
    """
    try:
        session = research_manager.get_session(session_id)
        if not session:
            return jsonify({"error": "Session not found"}), 404
        if not session.outline:
            return jsonify({"error": "Outline not created"}), 400
        # Locate the requested subtopic in the outline.
        subtopic = next(
            (st for st in session.outline.sub_topics if st.id == subtopic_id),
            None
        )
        if not subtopic:
            return jsonify({"error": "Subtopic not found"}), 404
        if subtopic.max_searches:
            progress = subtopic.get_total_searches() / subtopic.max_searches * 100
        else:
            progress = 0  # avoid division by zero on unbudgeted subtopics
        return jsonify({
            "id": subtopic.id,
            "topic": subtopic.topic,
            "explain": subtopic.explain,
            "priority": subtopic.priority,
            "status": subtopic.status,
            "search_count": subtopic.search_count,
            "max_searches": subtopic.max_searches,
            "progress": progress,
            "has_report": subtopic.report is not None
        })
    except Exception as e:
        current_app.logger.error(f"获取子主题详情失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/research/sessions', methods=['GET'])
def list_research_sessions():
    """List research sessions with limit/offset paging.

    NOTE(review): "total" is the size of the returned page, not the overall
    number of stored sessions — confirm whether API clients expect the
    latter before relying on it for pagination UIs.
    """
    try:
        # Paging parameters; defaults to the first 20 sessions.
        limit = request.args.get('limit', 20, type=int)
        offset = request.args.get('offset', 0, type=int)
        sessions = research_manager.list_sessions(limit=limit, offset=offset)
        return jsonify({
            "sessions": sessions,
            "total": len(sessions),
            "limit": limit,
            "offset": offset
        })
    except Exception as e:
        current_app.logger.error(f"列出会话失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/tasks/status', methods=['GET'])
def get_tasks_status():
    """Summarize the in-process task manager: counts per task status.

    Fix over the original: an unexpected task status is counted instead of
    raising KeyError on ``status_counts[task.status.value] += 1``.
    """
    try:
        tasks = task_manager.tasks
        status_counts = {
            'pending': 0,
            'running': 0,
            'completed': 0,
            'failed': 0,
            'cancelled': 0
        }
        for task in tasks.values():
            state = task.status.value
            status_counts[state] = status_counts.get(state, 0) + 1
        return jsonify({
            "total_tasks": len(tasks),
            "status_counts": status_counts,
            "sessions_count": len(task_manager.session_tasks)
        })
    except Exception as e:
        current_app.logger.error(f"获取任务状态失败: {e}")
        return jsonify({"error": str(e)}), 500
@api_bp.route('/test/connections', methods=['GET'])
def test_connections():
    """Smoke-test the external services and the task manager (dev only).

    Returns a per-service boolean map plus an aggregate flag. Each probe is
    isolated so one failing service does not mask the others.
    """
    if not current_app.debug:
        # Probes hit real APIs; never expose this in production.
        return jsonify({"error": "Not available in production"}), 403
    from app.services.search_service import SearchService
    from app.services.ai_service import AIService
    results = {
        "deepseek_api": False,
        "tavily_api": False,
        "task_manager": False
    }
    try:
        # DeepSeek: a trivial analysis call proves auth + connectivity.
        ai_service = AIService()
        test_result = ai_service.analyze_question_type("test question")
        results["deepseek_api"] = bool(test_result)
    except Exception as e:
        current_app.logger.error(f"DeepSeek API测试失败: {e}")
    try:
        # Tavily: the service exposes its own connectivity check.
        search_service = SearchService()
        results["tavily_api"] = search_service.test_connection()
    except Exception as e:
        current_app.logger.error(f"Tavily API测试失败: {e}")
    # Task manager: submit a no-op task and confirm it gets a status.
    try:
        def test_task():
            return "test"
        task_id = task_manager.submit_task(test_task)
        status = task_manager.get_task_status(task_id)
        results["task_manager"] = status is not None
    except Exception as e:
        current_app.logger.error(f"任务管理器测试失败: {e}")
    return jsonify({
        "connections": results,
        "all_connected": all(results.values())
    })

59
所有文件/app.py Normal file
View File

@ -0,0 +1,59 @@
#!/usr/bin/env python3
"""
DeepResearch 应用入口
"""
import os
import sys
import signal
from app import create_app, socketio
from config import config
# 获取配置名称
config_name = os.environ.get('FLASK_CONFIG', 'development')
app = create_app(config_name)
def shutdown_handler(signum, frame):
    """Signal handler (SIGINT/SIGTERM): stop the task manager, then exit 0.

    Args:
        signum: signal number delivered by the OS (unused).
        frame: current stack frame at delivery time (unused).
    """
    print("\n正在关闭应用...")
    try:
        # Imported lazily so registering the handler can never fail at startup.
        from app.services.task_manager import task_manager
        task_manager.shutdown()
        print("任务管理器已关闭")
    except Exception as e:
        # Best effort: report the problem but still exit cleanly.
        print(f"关闭任务管理器时出错: {e}")
    sys.exit(0)
if __name__ == '__main__':
    # Install handlers so Ctrl+C / SIGTERM shut the task manager down cleanly.
    signal.signal(signal.SIGINT, shutdown_handler)
    signal.signal(signal.SIGTERM, shutdown_handler)
    # Refuse to start without the required API credentials.
    required_env_vars = ['DEEPSEEK_API_KEY', 'TAVILY_API_KEY']
    missing_vars = [var for var in required_env_vars if not os.environ.get(var)]
    if missing_vars:
        print(f"错误: 缺少必要的环境变量: {', '.join(missing_vars)}")
        print("请在.env文件中设置这些变量")
        sys.exit(1)
    # Server settings; the PORT env var overrides the 8088 default.
    port = int(os.environ.get('PORT', 8088))
    debug = app.config.get('DEBUG', False)
    print(f"启动 DeepResearch 服务器...")
    print(f"配置: {config_name}")
    print(f"调试模式: {debug}")
    print(f"访问地址: http://localhost:{port}")
    print(f"\n提示: 不再需要 Redis 和 Celery Worker")
    print(f"按 Ctrl+C 优雅关闭应用\n")
    # Run through Flask-SocketIO so WebSocket support is available.
    socketio.run(app,
                 host='0.0.0.0',
                 port=port,
                 debug=debug,
                 use_reloader=debug)

26
所有文件/base.html Normal file
View File

@ -0,0 +1,26 @@
<!DOCTYPE html>
<!-- Base Jinja2 layout: pages extend this template and fill the
     title / extra_css / header_content / content / extra_js blocks. -->
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{% block title %}DeepResearch - 智能深度研究系统{% endblock %}</title>
    <link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}">
    {% block extra_css %}{% endblock %}
</head>
<body>
    <div class="app">
        <header class="header">
            <h1>DeepResearch - 智能深度研究系统</h1>
            {% block header_content %}{% endblock %}
        </header>
        <main class="main-container">
            {% block content %}{% endblock %}
        </main>
    </div>
    <!-- Socket.IO client from CDN; api.js defines the shared `api` helper
         used by the page scripts. -->
    <script src="https://cdn.socket.io/4.5.4/socket.io.min.js"></script>
    <script src="{{ url_for('static', filename='js/api.js') }}"></script>
    {% block extra_js %}{% endblock %}
</body>
</html>

115
所有文件/config.py Normal file
View File

@ -0,0 +1,115 @@
# 文件位置: config.py
# 文件名: config.py
import os
from datetime import timedelta
from dotenv import load_dotenv
# 加载环境变量
load_dotenv()
class Config:
    """Base configuration shared by all environments."""
    # Flask core settings
    SECRET_KEY = os.environ.get('SECRET_KEY', 'dev-secret-key-change-in-production')
    DEBUG = False
    TESTING = False
    # External API credentials and endpoints
    DEEPSEEK_API_KEY = os.environ.get('DEEPSEEK_API_KEY')
    DEEPSEEK_BASE_URL = os.environ.get('DEEPSEEK_BASE_URL', 'https://api.deepseek.com/v1')
    TAVILY_API_KEY = os.environ.get('TAVILY_API_KEY')
    # Model identifiers
    R1_MODEL = "deepseek-reasoner"  # R1-0528
    V3_MODEL = "deepseek-chat"  # V3-0324
    # Research limits: concurrency and per-priority search budgets
    MAX_CONCURRENT_SUBTOPICS = 10
    MAX_SEARCHES_HIGH_PRIORITY = 15
    MAX_SEARCHES_MEDIUM_PRIORITY = 10
    MAX_SEARCHES_LOW_PRIORITY = 5
    # Tavily search settings
    TAVILY_MAX_RESULTS = 10
    TAVILY_SEARCH_DEPTH = "advanced"
    TAVILY_INCLUDE_ANSWER = True
    TAVILY_INCLUDE_RAW_CONTENT = False
    # In-process task manager (replaces Celery)
    TASK_POOL_SIZE = 10  # thread-pool size
    TASK_TIMEOUT = {
        'question_analysis': 60,
        'outline_creation': 120,
        'search': 30,
        'report_generation': 180
    }
    # Retry policy
    MAX_RETRIES = 3
    RETRY_DELAY = 5  # seconds
    # On-disk storage layout, rooted next to this file
    DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
    SESSIONS_DIR = os.path.join(DATA_DIR, 'sessions')
    REPORTS_DIR = os.path.join(DATA_DIR, 'reports')
    CACHE_DIR = os.path.join(DATA_DIR, 'cache')
    # Logging
    LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO')
    LOG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')
    # Optional MongoDB persistence
    MONGODB_URI = os.environ.get('MONGODB_URI', 'mongodb://localhost:27017/deepresearch')
    # WebSocket async backend
    SOCKETIO_ASYNC_MODE = 'eventlet'

    @staticmethod
    def init_app(app):
        """Create the data/log directories the application expects."""
        for dir_path in [Config.DATA_DIR, Config.SESSIONS_DIR,
                         Config.REPORTS_DIR, Config.CACHE_DIR, Config.LOG_DIR]:
            os.makedirs(dir_path, exist_ok=True)
class DevelopmentConfig(Config):
    """Development environment: debug mode on, everything else inherited."""
    DEBUG = True
class ProductionConfig(Config):
    """Production environment: debug off plus rotating file logging."""
    DEBUG = False

    @classmethod
    def init_app(cls, app):
        Config.init_app(app)
        # Production-only initialization: attach a rotating file handler.
        import logging
        from logging.handlers import RotatingFileHandler
        if not app.debug and not app.testing:
            file_handler = RotatingFileHandler(
                os.path.join(cls.LOG_DIR, 'deepresearch.log'),
                maxBytes=10485760,  # 10 MB per file
                backupCount=10
            )
            file_handler.setFormatter(logging.Formatter(
                '%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]'
            ))
            file_handler.setLevel(logging.INFO)
            app.logger.addHandler(file_handler)
            app.logger.setLevel(logging.INFO)
            app.logger.info('DeepResearch startup')
class TestingConfig(Config):
    """Testing environment: enables Flask's TESTING flag."""
    TESTING = True
# Registry used by create_app(config_name) to select the environment class.
config = {
    'development': DevelopmentConfig,
    'production': ProductionConfig,
    'testing': TestingConfig,
    'default': DevelopmentConfig
}

0
所有文件/conftest.py Normal file
View File

View File

@ -0,0 +1,20 @@
# Flask配置
SECRET_KEY=your-secret-key-here
FLASK_ENV=development
FLASK_DEBUG=True
# DeepSeek API配置
DEEPSEEK_API_KEY=your-deepseek-api-key-here
DEEPSEEK_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
# Tavily API配置
TAVILY_API_KEY=your-tavily-api-key-here
# Redis配置用于Celery
REDIS_URL=redis://localhost:6379/0
# MongoDB配置可选用于持久化存储
MONGODB_URI=mongodb://localhost:27017/deepresearch
# 日志级别
LOG_LEVEL=INFO

View File

@ -0,0 +1,44 @@
total 560
drwxr-xr-x 43 jojo staff 1376 7 1 21:17 .
drwxr-xr-x 18 jojo staff 576 7 1 20:53 ..
-rw-r--r--@ 1 jojo staff 10244 7 1 21:15 .DS_Store
-rw-r--r--@ 1 jojo staff 499 7 1 19:34 .env
-rw-r--r--@ 1 jojo staff 467 7 1 19:34 .gitignore
-rw-r--r-- 1 jojo staff 0 7 1 21:17 README.md
-rw-r--r-- 1 jojo staff 0 7 1 21:17 __init__.py
-rw-r--r--@ 1 jojo staff 12270 7 1 21:17 ai_service.py
-rw-r--r--@ 1 jojo staff 1469 7 1 21:17 api.js
-rw-r--r--@ 1 jojo staff 11105 7 1 21:17 api.py
-rw-r--r--@ 1 jojo staff 1772 7 1 21:17 app.py
-rw-r--r--@ 1 jojo staff 875 7 1 21:17 base.html
-rw-r--r--@ 1 jojo staff 3429 7 1 21:17 config.py
-rw-r--r-- 1 jojo staff 0 7 1 21:17 conftest.py
-rw-r--r--@ 1 jojo staff 499 7 1 21:17 env_example.txt
-rw-r--r-- 1 jojo staff 0 7 1 21:17 file_list.txt
-rw-r--r--@ 1 jojo staff 5685 7 1 21:17 formatters.py
-rw-r--r--@ 1 jojo staff 631 7 1 21:17 frontend.py
-rw-r--r--@ 1 jojo staff 1284 7 1 21:17 index.html
-rw-r--r--@ 1 jojo staff 2791 7 1 21:17 index.js
-rwxr-xr-x@ 1 jojo staff 5040 7 1 21:17 init_db.py
-rw-r--r--@ 1 jojo staff 8745 7 1 21:17 json_parser.py
-rw-r--r--@ 1 jojo staff 7522 7 1 21:17 logger.py
-rw-r--r--@ 1 jojo staff 1374 7 1 21:17 main.py
-rw-r--r--@ 1 jojo staff 8013 7 1 21:17 prompts.py
-rw-r--r--@ 1 jojo staff 10598 7 1 21:17 r1_agent.py
-rw-r--r--@ 1 jojo staff 5891 7 1 21:17 report.py
-rw-r--r--@ 1 jojo staff 13164 7 1 21:17 report_generator.py
-rw-r--r--@ 1 jojo staff 311 7 1 21:17 requirements.txt
-rw-r--r--@ 1 jojo staff 5767 7 1 21:17 research-tree.js
-rw-r--r--@ 1 jojo staff 1421 7 1 21:17 research.html
-rw-r--r--@ 1 jojo staff 4596 7 1 21:17 research.js
-rw-r--r--@ 1 jojo staff 4220 7 1 21:17 research.py
-rw-r--r--@ 1 jojo staff 11872 7 1 21:17 research_manager.py
-rw-r--r--@ 1 jojo staff 16334 7 1 21:17 research_tasks.py
-rw-r--r--@ 1 jojo staff 3894 7 1 21:17 search_result.py
-rw-r--r--@ 1 jojo staff 7351 7 1 21:17 search_service.py
-rw-r--r--@ 1 jojo staff 7117 7 1 21:17 style.css
-rw-r--r--@ 1 jojo staff 7075 7 1 21:17 task_manager.py
-rwxr-xr-x@ 1 jojo staff 8531 7 1 21:17 test_api_keys.py
-rw-r--r--@ 1 jojo staff 7070 7 1 21:17 v3_agent.py
-rw-r--r--@ 1 jojo staff 3604 7 1 21:17 validators.py
-rw-r--r--@ 1 jojo staff 4971 7 1 21:17 websocket.py

180
所有文件/formatters.py Normal file
View File

@ -0,0 +1,180 @@
"""
格式化工具
"""
import re
from datetime import datetime
from typing import List, Dict, Any, Optional
def format_datetime(dt: datetime, format: str = "full") -> str:
    """Render *dt* in one of the named formats.

    Supported values for *format*: "full", "date", "time", "relative";
    anything else falls back to ISO-8601.
    """
    if format == "relative":
        return get_relative_time(dt)
    patterns = {
        "full": "%Y-%m-%d %H:%M:%S",
        "date": "%Y-%m-%d",
        "time": "%H:%M:%S",
    }
    pattern = patterns.get(format)
    return dt.strftime(pattern) if pattern else dt.isoformat()
def get_relative_time(dt: datetime) -> str:
    """Return a human-friendly Chinese description of how long ago *dt* was."""
    elapsed = datetime.now() - dt
    seconds = elapsed.total_seconds()
    # Sub-day buckets are driven by seconds, longer ones by whole days.
    if seconds < 60:
        return "刚刚"
    if seconds < 3600:
        return f"{int(seconds / 60)}分钟前"
    if seconds < 86400:
        return f"{int(seconds / 3600)}小时前"
    if elapsed.days < 30:
        return f"{elapsed.days}天前"
    if elapsed.days < 365:
        return f"{int(elapsed.days / 30)}个月前"
    return f"{int(elapsed.days / 365)}年前"
def format_file_size(size_bytes: int) -> str:
    """Format a byte count using binary-step units from B up to PB."""
    size = float(size_bytes)
    for unit in ('B', 'KB', 'MB', 'GB', 'TB'):
        if size < 1024.0:
            return f"{size:.2f} {unit}"
        size /= 1024.0
    # Anything that survived five divisions is petabyte-scale.
    return f"{size:.2f} PB"
def format_percentage(value: float, decimals: int = 1) -> str:
    """Render an already-scaled percentage value with a trailing '%'."""
    return "{:.{}f}%".format(value, decimals)
def format_search_results(results: List[Dict[str, Any]]) -> str:
    """Render search hits as a numbered plain-text list (title, URL, snippet)."""
    lines = []
    for index, result in enumerate(results, 1):
        lines.append(f"{index}. {result.get('title', '无标题')}")
        lines.append(f" URL: {result.get('url', 'N/A')}")
        lines.append(f" {result.get('snippet', '无摘要')}")
        lines.append("")  # blank separator line after each hit
    return '\n'.join(lines)
def format_outline_text(outline: Dict[str, Any]) -> str:
    """Render an outline dict as a small Markdown document."""
    parts = [f"# {outline.get('main_topic', '研究主题')}", ""]
    parts.append("## 研究问题")
    parts.extend(
        f"{num}. {question}"
        for num, question in enumerate(outline.get('research_questions', []), 1)
    )
    parts.append("")
    parts.append("## 子主题")
    for num, subtopic in enumerate(outline.get('sub_topics', []), 1):
        parts.append(f"{num}. **{subtopic.get('topic', '')}** ({subtopic.get('priority', '')})")
        parts.append(f" {subtopic.get('explain', '')}")
    return '\n'.join(parts)
def clean_markdown(text: str) -> str:
    """Normalize Markdown whitespace.

    - Collapses runs of 3+ newlines down to a single blank line.
    - Ensures headings are preceded and followed by a blank line.

    Fix over the original: the three list-"fix" substitutions replaced
    every match with an identical string (e.g. ``'\\n- '`` -> ``'\\n- '``),
    so they were dead code and have been removed.
    """
    # Collapse excess blank lines.
    text = re.sub(r'\n{3,}', '\n\n', text)
    # Blank line before a heading ...
    text = re.sub(r'([^\n])\n(#{1,6} )', r'\1\n\n\2', text)
    # ... and after it.
    text = re.sub(r'(#{1,6} [^\n]+)\n([^\n])', r'\1\n\n\2', text)
    return text.strip()
def truncate_text(text: str, max_length: int, ellipsis: str = "...") -> str:
    """Cut *text* to at most *max_length* characters plus an ellipsis.

    Prefers breaking at a word boundary when that boundary falls in the
    last 20% of the budget.
    """
    if len(text) <= max_length:
        return text
    clipped = text[:max_length]
    boundary = clipped.rfind(' ')
    if boundary > max_length * 0.8:
        clipped = clipped[:boundary]
    return clipped + ellipsis
def highlight_keywords(text: str, keywords: List[str]) -> str:
    """Wrap every occurrence of each keyword in ``**bold**`` markers.

    Matching is case-insensitive. Fix over the original: the matched text
    keeps its original casing (the old code substituted the keyword as
    spelled by the caller, rewriting e.g. "Hello" to "**hello**").
    """
    for keyword in keywords:
        if not keyword:
            continue  # an empty keyword would match between every character
        pattern = re.compile(re.escape(keyword), re.IGNORECASE)
        text = pattern.sub(lambda m: f"**{m.group(0)}**", text)
    return text
def extract_urls(text: str) -> List[str]:
    """Collect the unique http(s) URLs found in *text* (order not guaranteed)."""
    url_pattern = re.compile(
        r'https?://(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b'
        r'(?:[-a-zA-Z0-9()@:%_\+.~#?&/=]*)'
    )
    # Deduplicate via a set comprehension over the match objects.
    unique = {match.group(0) for match in url_pattern.finditer(text)}
    return list(unique)
def format_json_output(data: Any, indent: int = 2) -> str:
    """Serialize *data* to pretty-printed JSON with readable non-ASCII text.

    ``default=str`` stringifies objects json cannot encode natively
    (e.g. datetime instances).
    """
    import json
    return json.dumps(data, ensure_ascii=False, indent=indent,
                      sort_keys=True, default=str)
def create_summary(text: str, max_sentences: int = 3) -> str:
    """Return the first *max_sentences* sentences of *text* as a summary."""
    # Naive sentence split on Chinese/Latin terminal punctuation; the
    # delimiters themselves are discarded by re.split.
    sentences = re.split(r'[。!?.!?]+', text)
    sentences = [s.strip() for s in sentences if s.strip()]
    # Keep only the leading sentences.
    summary_sentences = sentences[:max_sentences]
    if len(sentences) > max_sentences:
        return ''.join(summary_sentences) + '。...'
    else:
        # NOTE(review): this appends an empty string — presumably '。' was
        # intended (mirroring the truncated branch); confirm before changing.
        return ''.join(summary_sentences) + ''
def format_status_message(status: str, phase: Optional[str] = None) -> str:
    """Translate a status code to its Chinese label, optionally adding a phase."""
    status_messages = {
        "pending": "等待开始",
        "analyzing": "分析问题中",
        "outlining": "制定大纲中",
        "researching": "研究进行中",
        "writing": "撰写报告中",
        "reviewing": "审核内容中",
        "completed": "研究完成",
        "error": "发生错误",
        "cancelled": "已取消"
    }
    # Unknown codes fall through unchanged.
    label = status_messages.get(status, status)
    return f"{label} - {phase}" if phase else label

25
所有文件/frontend.py Normal file
View File

@ -0,0 +1,25 @@
# 文件位置: app/routes/frontend.py
# 文件名: frontend.py
"""
前端页面路由
"""
from flask import Blueprint, render_template, send_from_directory
import os
frontend_bp = Blueprint('frontend', __name__)
@frontend_bp.route('/')
def index():
    """Landing page: question input plus the session history list."""
    return render_template('index.html')
@frontend_bp.route('/research/<session_id>')
def research_detail(session_id):
    """Research detail page for one session; the template loads live data."""
    return render_template('research.html', session_id=session_id)
@frontend_bp.route('/static/<path:filename>')
def static_files(filename):
    """Serve static assets from the app's static directory."""
    return send_from_directory('static', filename)

38
所有文件/index.html Normal file
View File

@ -0,0 +1,38 @@
{% extends "base.html" %}
{% block content %}
<!-- Start screen: a single question input; Enter or the button kicks off
     startResearch() from js/index.js. -->
<div class="start-screen">
    <div class="start-card">
        <h2>开始新的研究</h2>
        <div class="input-group">
            <input
                type="text"
                id="questionInput"
                placeholder="输入你想研究的问题..."
                class="question-input"
                onkeypress="if(event.key === 'Enter') startResearch()"
            />
            <button onclick="startResearch()" class="start-button" id="startBtn">
                <svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
                    <circle cx="11" cy="11" r="8"></circle>
                    <path d="m21 21-4.35-4.35"></path>
                </svg>
            </button>
        </div>
        <!-- Hidden until loadSessions() finds previous sessions. -->
        <div class="history-section" id="historySection" style="display: none;">
            <h3>历史研究</h3>
            <div class="session-list" id="sessionList"></div>
        </div>
    </div>
</div>
<!-- Full-page spinner shown while the create-research request runs. -->
<div id="loading" class="loading-overlay" style="display: none;">
    <div class="loading-spinner"></div>
</div>
{% endblock %}
{% block extra_js %}
<script src="{{ url_for('static', filename='js/index.js') }}"></script>
{% endblock %}

87
所有文件/index.js Normal file
View File

@ -0,0 +1,87 @@
// app/static/js/index.js
// Populate the session-history list as soon as the page is ready.
document.addEventListener('DOMContentLoaded', function() {
    loadSessions();
});
// Create a new research session from the question input and navigate to it.
async function startResearch() {
    const input = document.getElementById('questionInput');
    const question = input.value.trim();
    if (!question) {
        alert('请输入研究问题');
        return;
    }
    // Disable the button and show the spinner while the request is in flight.
    const startBtn = document.getElementById('startBtn');
    const loading = document.getElementById('loading');
    startBtn.disabled = true;
    loading.style.display = 'flex';
    try {
        const result = await api.createResearch(question);
        if (result.session_id) {
            // Jump to the detail page of the freshly created session.
            window.location.href = `/research/${result.session_id}`;
        } else {
            alert('创建研究失败: ' + (result.error || '未知错误'));
        }
    } catch (error) {
        console.error('Error:', error);
        alert('创建研究失败,请重试');
    } finally {
        // Restore the UI whether or not the request succeeded.
        startBtn.disabled = false;
        loading.style.display = 'none';
    }
}
// Fetch previous research sessions and render them as clickable cards.
// Fix over the original: the user-supplied question (and other session
// fields) were interpolated into innerHTML, allowing stored XSS. The card
// is now built with DOM APIs and textContent so all values render as text.
async function loadSessions() {
    try {
        const data = await api.getSessions();
        if (data.sessions && data.sessions.length > 0) {
            const historySection = document.getElementById('historySection');
            const sessionList = document.getElementById('sessionList');
            historySection.style.display = 'block';
            sessionList.innerHTML = '';
            data.sessions.forEach(session => {
                const item = document.createElement('div');
                item.className = 'session-item';
                item.onclick = () => {
                    window.location.href = `/research/${session.id}`;
                };
                const question = document.createElement('div');
                question.className = 'session-question';
                question.textContent = session.question;
                const badge = document.createElement('span');
                badge.className = `status-badge ${session.status}`;
                badge.textContent = getStatusText(session.status);
                const date = document.createElement('span');
                date.className = 'session-date';
                date.textContent = new Date(session.created_at).toLocaleDateString();
                const meta = document.createElement('div');
                meta.className = 'session-meta';
                meta.appendChild(badge);
                meta.appendChild(date);
                item.appendChild(question);
                item.appendChild(meta);
                sessionList.appendChild(item);
            });
        }
    } catch (error) {
        console.error('Failed to load sessions:', error);
    }
}
// Map a research-session status code to its Chinese display label.
// Unknown codes are passed through unchanged.
function getStatusText(status) {
    const statusMap = {
        'pending': '等待中',
        'analyzing': '分析中',
        'outlining': '制定大纲',
        'researching': '研究中',
        'writing': '撰写中',
        'reviewing': '审核中',
        'completed': '已完成',
        'error': '错误',
        'cancelled': '已取消'
    };
    const label = statusMap[status];
    return label || status;
}

172
所有文件/init_db.py Executable file
View File

@ -0,0 +1,172 @@
#!/usr/bin/env python3
"""
初始化数据库和目录结构
"""
import os
import sys
# 添加项目根目录到Python路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import Config
def init_directories():
    """Create the data/log directory tree and .gitkeep placeholders."""
    required_dirs = [
        Config.DATA_DIR,
        Config.SESSIONS_DIR,
        Config.REPORTS_DIR,
        Config.CACHE_DIR,
        Config.LOG_DIR
    ]
    for directory in required_dirs:
        if os.path.exists(directory):
            print(f"目录已存在: {directory}")
        else:
            os.makedirs(directory)
            print(f"创建目录: {directory}")
    # .gitkeep files keep the otherwise-empty data dirs under version control.
    for directory in (Config.SESSIONS_DIR, Config.REPORTS_DIR, Config.CACHE_DIR):
        gitkeep_path = os.path.join(directory, '.gitkeep')
        if not os.path.exists(gitkeep_path):
            with open(gitkeep_path, 'w') as f:
                f.write('')
            print(f"创建.gitkeep: {gitkeep_path}")
def init_mongodb():
    """Create MongoDB collections and indexes (optional; failures are non-fatal)."""
    try:
        from pymongo import MongoClient
        client = MongoClient(Config.MONGODB_URI)
        db = client.get_database()
        # collection name -> index specs; tuples are (field, direction) pairs.
        collections = {
            'sessions': [
                ('created_at', -1),
                ('status', 1),
                ('question_type', 1)
            ],
            'search_results': [
                ('session_id', 1),
                ('subtopic_id', 1),
                ('created_at', -1)
            ],
            'reports': [
                ('session_id', 1),
                ('created_at', -1)
            ]
        }
        for collection_name, indexes in collections.items():
            collection = db[collection_name]
            for index in indexes:
                # create_index expects a list of (field, direction) tuples.
                if isinstance(index, tuple):
                    collection.create_index([index])
                else:
                    collection.create_index(index)
            print(f"初始化集合: {collection_name}")
        print("MongoDB初始化完成")
    except Exception as e:
        # MongoDB is optional persistence — report the failure and carry on.
        print(f"MongoDB初始化失败可选: {e}")
def check_environment():
    """Warn about any missing required API-key environment variables."""
    required_vars = [
        'DEEPSEEK_API_KEY',
        'TAVILY_API_KEY'
    ]
    missing_vars = [name for name in required_vars if not os.environ.get(name)]
    if missing_vars:
        print("\n警告: 缺少以下环境变量:")
        for var in missing_vars:
            print(f" - {var}")
        print("\n请在.env文件中设置这些变量")
    else:
        print("\n环境变量检查通过")
def test_task_manager():
    """Smoke-test the thread-pool task manager by submitting a no-op task."""
    print("\n测试任务管理器...")
    try:
        from app.services.task_manager import task_manager
        # Submitting a trivial task proves the pool accepts work.
        def test_task():
            return "Task manager is working!"
        task_id = task_manager.submit_task(test_task)
        print(f"✓ 任务管理器正常工作测试任务ID: {task_id}")
        # Shut the pool down so the init script can exit cleanly.
        task_manager.shutdown()
    except Exception as e:
        print(f"✗ 任务管理器测试失败: {e}")
def create_test_data():
    """Write a sample session file, but only in the development environment."""
    if os.environ.get('FLASK_ENV') == 'development':
        print("\n开发环境:创建测试数据...")
        # Minimal session shape: id, question, status, creation timestamp.
        sample_session = {
            "id": "test-session-001",
            "question": "这是一个测试研究问题",
            "status": "completed",
            "created_at": "2024-01-01T00:00:00"
        }
        import json
        test_file = os.path.join(Config.SESSIONS_DIR, 'test-session-001.json')
        # Only create the fixture once; never overwrite an existing file.
        if not os.path.exists(test_file):
            with open(test_file, 'w', encoding='utf-8') as f:
                json.dump(sample_session, f, ensure_ascii=False, indent=2)
            print(f"创建测试会话文件: {test_file}")
def main():
    """Run all initialization steps in order and print a final summary."""
    print("DeepResearch 初始化脚本")
    print("=" * 50)
    # 1. Directory tree (data/sessions/reports/cache/logs).
    print("\n1. 初始化目录结构...")
    init_directories()
    # 2. Optional MongoDB collections/indexes.
    print("\n2. 初始化MongoDB...")
    init_mongodb()
    # 3. Required API-key environment variables.
    print("\n3. 检查环境变量...")
    check_environment()
    # 4. Thread-pool task manager smoke test.
    print("\n4. 测试任务管理器...")
    test_task_manager()
    # 5. Development-only sample data.
    print("\n5. 创建测试数据...")
    create_test_data()
    print("\n" + "=" * 50)
    print("初始化完成!")
    print("\n下一步:")
    print("1. 确保在.env文件中设置了必要的API密钥")
    print("2. 运行 'python app.py' 启动应用")
    print("\n注意: 不再需要启动 Redis 和 Celery Worker")
if __name__ == '__main__':
    main()

283
所有文件/json_parser.py Normal file
View File

@ -0,0 +1,283 @@
"""
JSON解析和修复工具
"""
import json
import re
import logging
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
def parse_json_safely(text: str) -> Dict[str, Any]:
    """Parse JSON with progressively more aggressive repair fallbacks.

    Attempt order: raw parse -> common-issue fixes -> aggressive fixes ->
    extracting the first JSON-looking span. Returns {} instead of raising
    when every attempt fails.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    repaired = fix_json_common_issues(text)
    try:
        return json.loads(repaired)
    except json.JSONDecodeError as e:
        logger.error(f"JSON解析失败: {e}")
        logger.debug(f"原始文本: {text[:500]}...")
    repaired = fix_json_aggressive(repaired)
    try:
        return json.loads(repaired)
    except json.JSONDecodeError:
        # Last resort: pull out the first balanced JSON span from the text.
        json_part = extract_json_from_text(text)
        if json_part:
            try:
                return json.loads(json_part)
            except:
                pass
    logger.error("无法解析JSON返回空字典")
    return {}
def fix_json_common_issues(text: str) -> str:
    """Repair frequent LLM-output JSON problems.

    Handles Markdown fences, BOM/control characters, trailing commas,
    quoting, Python literals and comments.

    Fix over the original: ``True``/``False``/``None`` are now replaced only
    as whole words (word-boundary regex); the old plain ``str.replace`` also
    corrupted occurrences inside string values (e.g. "Nonetheless").
    """
    # Strip Markdown code-fence markers.
    text = re.sub(r'^```json\s*', '', text, flags=re.MULTILINE)
    text = re.sub(r'^```\s*$', '', text, flags=re.MULTILINE)
    # Strip BOM and control characters.
    text = text.lstrip('\ufeff')
    text = re.sub(r'[\x00-\x1F\x7F]', '', text)
    # Remove trailing commas before closing brackets.
    text = re.sub(r',\s*}', '}', text)
    text = re.sub(r',\s*]', ']', text)
    # Normalize quoting: JSON requires double quotes and quoted keys.
    text = fix_single_quotes(text)
    text = fix_unquoted_keys(text)
    # Translate Python literals to their JSON equivalents, whole words only.
    text = re.sub(r'\bTrue\b', 'true', text)
    text = re.sub(r'\bFalse\b', 'false', text)
    text = re.sub(r'\bNone\b', 'null', text)
    # Drop // and /* */ comments.
    text = remove_json_comments(text)
    return text.strip()
def fix_json_aggressive(text: str) -> str:
    """Last-resort structural repairs on almost-JSON text.

    Re-joins strings broken across lines, inserts commas between adjacent
    values, and inserts a colon between a key and a string value.
    """
    substitutions = (
        (r'"\s*\n\s*"', '" "'),          # string broken across a line
        (r'}\s*"', '},\n"'),             # missing comma after an object
        (r']\s*"', '],\n"'),             # missing comma after an array
        (r'}\s*{', '},\n{'),             # adjacent objects
        (r']\s*\[', '],\n['),            # adjacent arrays
        (r'"([^"]+)"\s*"', r'"\1": "'),  # missing colon after a key
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
def fix_single_quotes(text: str) -> str:
    """Convert single-quoted strings to double-quoted ones.

    Walks the text with a small state machine so only quotes acting as
    string delimiters are rewritten; apostrophes inside values survive.
    Heuristic: a single quote opens a string when preceded by a structural
    character, and closes it when followed by one.
    """
    result = []
    in_string = False   # currently inside a quoted string?
    string_char = None  # the quote character that opened it
    i = 0
    while i < len(text):
        char = text[i]
        if not in_string:
            # A quote right after whitespace/structure likely opens a string.
            if char == "'" and (i == 0 or text[i-1] in ' \n\t:,{['):
                result.append('"')
                in_string = True
                string_char = "'"
            else:
                result.append(char)
        else:
            # The matching quote followed by structure likely closes it.
            if char == string_char and (i + 1 >= len(text) or text[i+1] in ' \n\t,}]:'):
                result.append('"')
                in_string = False
                string_char = None
            elif char == '\\' and i + 1 < len(text):
                # Preserve escape sequences verbatim (consume both characters).
                result.append(char)
                result.append(text[i + 1])
                i += 1
            else:
                result.append(char)
        i += 1
    return ''.join(result)
def fix_unquoted_keys(text: str) -> str:
    """Quote bare identifiers used as object keys: ``{a: 1}`` -> ``{"a": 1}``."""
    # The leading [,{whitespace] anchor avoids touching quoted keys.
    return re.sub(
        r'([,\{\s])([a-zA-Z_][a-zA-Z0-9_]*)\s*:',
        r'\1"\2":',
        text,
    )
def remove_json_comments(text: str) -> str:
    """Strip ``//`` line comments and ``/* */`` block comments from JSON text.

    Fix over the original: the regex-only version also deleted the tail of
    any string value containing ``//`` (e.g. URLs like "https://a/b"). This
    version scans character by character and ignores comment markers inside
    double-quoted strings.
    """
    result = []
    i = 0
    n = len(text)
    in_string = False
    while i < n:
        char = text[i]
        if in_string:
            result.append(char)
            if char == '\\' and i + 1 < n:
                # Copy escape sequences verbatim (handles \" inside strings).
                result.append(text[i + 1])
                i += 2
                continue
            if char == '"':
                in_string = False
            i += 1
            continue
        if char == '"':
            in_string = True
            result.append(char)
            i += 1
            continue
        if char == '/' and i + 1 < n and text[i + 1] == '/':
            # Line comment: skip to end of line (the newline is kept).
            while i < n and text[i] != '\n':
                i += 1
            continue
        if char == '/' and i + 1 < n and text[i + 1] == '*':
            # Block comment: skip past the closing */ (or to EOF).
            end = text.find('*/', i + 2)
            i = n if end == -1 else end + 2
            continue
        result.append(char)
        i += 1
    return ''.join(result)
def extract_json_from_text(text: str) -> Optional[str]:
    """Return the first balanced ``{...}`` or ``[...]`` span in *text*, or None.

    Used as a last resort when a model wrapped JSON in prose. Bracket
    matching is string-aware so braces inside quoted values do not count
    toward the nesting depth.
    """
    # Locate the first opening bracket of either kind.
    start_idx = -1
    start_char = None
    for i, char in enumerate(text):
        if char in '{[':
            start_idx = i
            start_char = char
            break
    if start_idx == -1:
        return None
    # Scan forward tracking nesting depth, string state and escapes.
    end_char = '}' if start_char == '{' else ']'
    bracket_count = 0
    in_string = False
    escape = False
    for i in range(start_idx, len(text)):
        char = text[i]
        if escape:
            escape = False
            continue
        if char == '\\':
            escape = True
            continue
        # NOTE: `escape` is always False here (handled above), so the extra
        # `and not escape` is redundant but harmless.
        if char == '"' and not escape:
            in_string = not in_string
            continue
        if not in_string:
            if char == start_char:
                bracket_count += 1
            elif char == end_char:
                bracket_count -= 1
                if bracket_count == 0:
                    # Depth back to zero: this bracket closes the span.
                    return text[start_idx:i+1]
    # Unbalanced input — no complete JSON span found.
    return None
def validate_json_schema(data: Dict[str, Any], schema: Dict[str, Any]) -> List[str]:
    """Minimal schema check: required fields plus coarse type matching.

    Returns a list of error strings; empty when *data* conforms.
    """
    errors = []
    properties = schema.get('properties', {})
    # Required fields must be present.
    for field in schema.get('required', []):
        if field not in data:
            errors.append(f"缺少必需字段: {field}")
    # Coarse type check: map JSON-schema type names onto Python type names.
    type_mapping = {
        'string': 'str',
        'number': 'float',
        'integer': 'int',
        'boolean': 'bool',
        'array': 'list',
        'object': 'dict'
    }
    for field, value in data.items():
        if field not in properties:
            continue
        expected_type = properties[field].get('type')
        if not expected_type:
            continue
        actual_type = type(value).__name__
        expected_python_type = type_mapping.get(expected_type, expected_type)
        if actual_type == expected_python_type:
            continue
        # An int is acceptable wherever a float ("number") is expected.
        if expected_python_type == 'float' and actual_type == 'int':
            continue
        errors.append(
            f"字段 '{field}' 类型错误: "
            f"期望 {expected_type}, 实际 {actual_type}"
        )
    return errors
def merge_json_objects(obj1: Dict[str, Any], obj2: Dict[str, Any],
                       deep: bool = True) -> Dict[str, Any]:
    """Merge *obj2* into a copy of *obj1*.

    With ``deep=True``, nested dicts are merged recursively and lists are
    concatenated with duplicates removed; otherwise values from *obj2*
    simply overwrite.

    Fix over the original: list merging used ``list(set(...))``, which
    raised TypeError when list items were unhashable (dicts, lists) and
    scrambled element order. Duplicates are now removed with an
    order-preserving equality scan that works for any item type.
    """
    result = obj1.copy()
    for key, value in obj2.items():
        if key in result and deep and isinstance(result[key], dict) and isinstance(value, dict):
            # Recursive merge of nested objects.
            result[key] = merge_json_objects(result[key], value, deep=True)
        elif key in result and deep and isinstance(result[key], list) and isinstance(value, list):
            # Concatenate, keeping first-seen order and dropping duplicates.
            merged = list(result[key])
            for item in value:
                if item not in merged:
                    merged.append(item)
            result[key] = merged
        else:
            # Scalar / non-deep case: obj2 wins.
            result[key] = value
    return result
def json_to_flat_dict(data: Dict[str, Any], parent_key: str = '',
                      separator: str = '.') -> Dict[str, Any]:
    """Flatten nested dicts/lists into a single-level dict.

    Nested keys are joined with *separator*; list elements get an
    ``[index]`` suffix on their parent key.
    """
    flat: Dict[str, Any] = {}
    for key, value in data.items():
        full_key = f"{parent_key}{separator}{key}" if parent_key else key
        if isinstance(value, dict):
            flat.update(json_to_flat_dict(value, full_key, separator))
        elif isinstance(value, list):
            for index, element in enumerate(value):
                if isinstance(element, dict):
                    flat.update(
                        json_to_flat_dict(element, f"{full_key}[{index}]", separator)
                    )
                else:
                    flat[f"{full_key}[{index}]"] = element
        else:
            flat[full_key] = value
    return flat

223
所有文件/logger.py Normal file
View File

@ -0,0 +1,223 @@
"""
日志配置工具
"""
import os
import logging
import logging.handlers
from datetime import datetime
from pythonjsonlogger import jsonlogger
def setup_logging(app):
    """Configure root, file, JSON and research loggers for the Flask app.

    Handlers installed on the root logger: colored console, rotating
    ``app.log``, rotating ``error.log`` (ERROR+), and a rotating JSON log
    for aggregation. A dedicated 'research' logger gets its own file.
    """
    log_level = app.config.get('LOG_LEVEL', 'INFO')
    log_dir = app.config.get('LOG_DIR', 'logs')
    # Make sure the log directory exists before opening file handlers.
    os.makedirs(log_dir, exist_ok=True)
    # Root logger at the configured level; drop any pre-existing handlers
    # so repeated setup does not duplicate output.
    root_logger = logging.getLogger()
    root_logger.setLevel(getattr(logging, log_level))
    root_logger.handlers = []
    # Console handler with ANSI colors.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(getattr(logging, log_level))
    console_formatter = ColoredFormatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    console_handler.setFormatter(console_formatter)
    root_logger.addHandler(console_handler)
    # General rotating file log (INFO+).
    file_handler = logging.handlers.RotatingFileHandler(
        os.path.join(log_dir, 'app.log'),
        maxBytes=10485760,  # 10MB
        backupCount=10
    )
    file_handler.setLevel(logging.INFO)
    file_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    file_handler.setFormatter(file_formatter)
    root_logger.addHandler(file_handler)
    # Error-only rotating file log.
    error_handler = logging.handlers.RotatingFileHandler(
        os.path.join(log_dir, 'error.log'),
        maxBytes=10485760,
        backupCount=10
    )
    error_handler.setLevel(logging.ERROR)
    error_handler.setFormatter(file_formatter)
    root_logger.addHandler(error_handler)
    # Structured JSON log for log-analysis tooling.
    json_handler = logging.handlers.RotatingFileHandler(
        os.path.join(log_dir, 'app.json.log'),
        maxBytes=10485760,
        backupCount=10
    )
    json_formatter = CustomJsonFormatter()
    json_handler.setFormatter(json_formatter)
    json_handler.setLevel(logging.INFO)
    root_logger.addHandler(json_handler)
    # Dedicated logger for research tasks, at DEBUG, with its own file.
    research_logger = logging.getLogger('research')
    research_handler = logging.handlers.RotatingFileHandler(
        os.path.join(log_dir, 'research.log'),
        maxBytes=10485760,
        backupCount=10
    )
    research_handler.setFormatter(file_formatter)
    research_logger.addHandler(research_handler)
    research_logger.setLevel(logging.DEBUG)
    # Quiet the chatty third-party libraries.
    logging.getLogger('urllib3').setLevel(logging.WARNING)
    logging.getLogger('requests').setLevel(logging.WARNING)
    logging.getLogger('openai').setLevel(logging.WARNING)
    app.logger.info(f"日志系统初始化完成,级别: {log_level}")
class ColoredFormatter(logging.Formatter):
    """Console formatter that wraps the level name in ANSI color codes.

    Fix over the original: the level name is restored after formatting.
    The old code permanently rewrote ``record.levelname``, so every other
    handler (file, JSON) formatting the same record saw escape codes, and
    repeated formatting nested them.
    """

    COLORS = {
        'DEBUG': '\033[36m',     # cyan
        'INFO': '\033[32m',      # green
        'WARNING': '\033[33m',   # yellow
        'ERROR': '\033[31m',     # red
        'CRITICAL': '\033[35m',  # magenta
    }
    RESET = '\033[0m'

    def format(self, record):
        original_levelname = record.levelname
        log_color = self.COLORS.get(original_levelname, self.RESET)
        record.levelname = f"{log_color}{original_levelname}{self.RESET}"
        try:
            return super().format(record)
        finally:
            # Undo the mutation so other handlers see the plain level name.
            record.levelname = original_levelname
class CustomJsonFormatter(jsonlogger.JsonFormatter):
    """JSON log formatter that enriches each record with context fields."""

    def add_fields(self, log_record, record, message_dict):
        super().add_fields(log_record, record, message_dict)
        # Standard fields expected by log-aggregation tooling.
        log_record['timestamp'] = datetime.utcnow().isoformat()
        log_record['level'] = record.levelname
        log_record['logger'] = record.name
        # Include traceback text when an exception was logged.
        if record.exc_info:
            log_record['exception'] = self.formatException(record.exc_info)
        # Optional request/session context attached via `extra=`.
        if hasattr(record, 'session_id'):
            log_record['session_id'] = record.session_id
        if hasattr(record, 'subtopic_id'):
            log_record['subtopic_id'] = record.subtopic_id
        if hasattr(record, 'user_id'):
            log_record['user_id'] = record.user_id
def get_logger(name: str) -> logging.Logger:
    """Return the logger with the given name (thin logging.getLogger wrapper)."""
    return logging.getLogger(name)
def log_performance(func):
    """Decorator: log how long *func* took, on success and on failure.

    The timing is attached under the 'performance' key of the record's
    ``extra`` so the JSON formatter can surface it.
    """
    import functools
    import time

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        perf_logger = logging.getLogger(func.__module__)
        started = time.time()
        try:
            result = func(*args, **kwargs)
            duration = time.time() - started
            perf_logger.info(
                f"{func.__name__} 执行成功,耗时: {duration:.3f}",
                extra={'performance': {'function': func.__name__, 'duration': duration}}
            )
            return result
        except Exception as e:
            duration = time.time() - started
            perf_logger.error(
                f"{func.__name__} 执行失败,耗时: {duration:.3f}秒,错误: {str(e)}",
                extra={'performance': {'function': func.__name__, 'duration': duration}},
                exc_info=True
            )
            raise
    return wrapper
def log_api_call(service_name: str):
    """Decorator factory: log the request, success or failure of an API call.

    Arguments are stringified and truncated to 200 characters so huge
    payloads never bloat the log.
    """
    def decorator(func):
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            api_logger = logging.getLogger('api_calls')
            # Log the outgoing call with truncated arguments.
            api_logger.info(
                f"调用 {service_name} API: {func.__name__}",
                extra={
                    'api_service': service_name,
                    'api_method': func.__name__,
                    'args': str(args)[:200],
                    'kwargs': str(kwargs)[:200]
                }
            )
            try:
                result = func(*args, **kwargs)
                api_logger.info(
                    f"{service_name} API 调用成功: {func.__name__}",
                    extra={
                        'api_service': service_name,
                        'api_method': func.__name__,
                        'success': True
                    }
                )
                return result
            except Exception as e:
                api_logger.error(
                    f"{service_name} API 调用失败: {func.__name__} - {str(e)}",
                    extra={
                        'api_service': service_name,
                        'api_method': func.__name__,
                        'success': False,
                        'error': str(e)
                    },
                    exc_info=True
                )
                raise
        return wrapper
    return decorator
class SessionLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that injects a ``session_id`` field into records.

    The session id is read from a ``session_id`` attribute that callers
    set on the adapter instance (see ``get_session_logger``); when the
    attribute is absent, messages pass through without extra context.
    """

    def process(self, msg, kwargs):
        extra = kwargs.setdefault('extra', {})
        if hasattr(self, 'session_id'):
            extra['session_id'] = self.session_id
        return msg, kwargs
def get_session_logger(session_id: str, logger_name: str = 'research') -> SessionLoggerAdapter:
    """Build a SessionLoggerAdapter bound to *session_id*.

    The id is attached as an instance attribute because the adapter's
    ``process`` reads ``self.session_id`` rather than ``self.extra``.
    """
    base_logger = logging.getLogger(logger_name)
    session_adapter = SessionLoggerAdapter(base_logger, {})
    session_adapter.session_id = session_id
    return session_adapter

45
所有文件/main.py Normal file
View File

@ -0,0 +1,45 @@
"""
主路由
处理页面请求
"""
from flask import Blueprint, jsonify, current_app
main_bp = Blueprint('main', __name__)
@main_bp.route('/')
def index():
    """API root: welcome message plus a map of the primary endpoints."""
    endpoints = {
        "create_research": "POST /api/research",
        "get_status": "GET /api/research/<session_id>/status",
        "get_report": "GET /api/research/<session_id>/report",
        "list_sessions": "GET /api/research/sessions"
    }
    payload = {
        "message": "Welcome to DeepResearch API",
        "version": "1.0.0",
        "endpoints": endpoints
    }
    return jsonify(payload)
@main_bp.route('/health')
def health_check():
    """Health-check endpoint; reports a static 'healthy' payload."""
    payload = {
        "status": "healthy",
        "service": "DeepResearch"
    }
    return jsonify(payload)
@main_bp.route('/config')
def get_config():
    """Expose selected configuration values; only when DEBUG is on."""
    # Guard clause: never leak configuration outside debug mode.
    if not current_app.debug:
        return jsonify({"error": "Not available in production"}), 403
    cfg = current_app.config
    return jsonify({
        "debug": current_app.debug,
        "max_concurrent_subtopics": cfg.get('MAX_CONCURRENT_SUBTOPICS'),
        "search_priorities": {
            "high": cfg.get('MAX_SEARCHES_HIGH_PRIORITY'),
            "medium": cfg.get('MAX_SEARCHES_MEDIUM_PRIORITY'),
            "low": cfg.get('MAX_SEARCHES_LOW_PRIORITY')
        }
    })

340
所有文件/prompts.py Normal file
View File

@ -0,0 +1,340 @@
"""
所有AI模型的提示词模板
"""
PROMPTS = {
# 1. 判断问题类型
"question_type_analysis": """
请分析以下用户问题判断其属于哪种类型
用户问题{question}
请从以下类型中选择最合适的一个
1. factual - 事实查询型需要具体准确的信息
2. comparative - 分析对比型需要多角度分析和比较
3. exploratory - 探索发现型需要广泛探索未知领域
4. decision - 决策支持型需要综合分析支持决策
请直接返回类型代码factual不需要其他解释
""",
# 2. 细化问题
"refine_questions": """
基于用户的问题和问题类型请提出3-5个细化问题帮助更好地理解和研究这个主题
原始问题{question}
问题类型{question_type}
请思考
1. 还需要哪些具体信息
2. 应该关注问题的哪些方面
3. 有哪些潜在的相关维度需要探索
请以列表形式返回细化问题每个问题独占一行
""",
# 3. 初步研究思路
"research_approach": """
基于用户问题和细化问题请制定初步的研究思路
原始问题{question}
问题类型{question_type}
细化问题
{refined_questions}
请简要说明研究这个问题的整体思路和方法200字以内
""",
# 4. 制定研究大纲
"create_outline": """
请为以下研究主题制定详细的研究大纲
主题{question}
问题类型{question_type}
细化问题{refined_questions}
研究思路{research_approach}
请按以下JSON格式输出大纲
```json
{{
"main_topic": "用户输入的主题",
"research_questions": [
"核心问题1",
"核心问题2",
"核心问题3"
],
"sub_topics": [
{{
"topic": "子主题1",
"explain": "子主题1的简单解释",
"priority": "high",
"related_questions": ["核心问题1", "核心问题2"]
}},
{{
"topic": "子主题2",
"explain": "子主题2的简单解释",
"priority": "medium",
"related_questions": ["核心问题2"]
}}
]
}}
```
注意
- 子主题数量建议3-6
- priority可选值high/medium/low
- 确保子主题覆盖所有核心问题
""",
# 5. 大纲验证搜索
"outline_validation": """
请评估这个研究大纲是否完整和合理
研究大纲
{outline}
请思考并搜索验证
1. 核心问题是否全面
2. 子主题划分是否合理
3. 是否有遗漏的重要方面
如果需要改进请提供具体建议
""",
# 6. 修改大纲
"modify_outline": """
基于用户反馈和验证结果请修改研究大纲
原大纲
{original_outline}
用户反馈
{user_feedback}
验证发现的问题
{validation_issues}
请输出修改后的大纲格式与原大纲相同
重点关注用户提出的修改意见
""",
# 8. 评估搜索结果
"evaluate_search_results": """
请评估以下搜索结果对于研究子主题的重要性
子主题{subtopic}
搜索结果
标题{title}
URL{url}
摘要{snippet}
评估标准
1. 主题匹配度内容与子主题的相关程度
2. 问题覆盖度能否回答相关的核心问题
3. 信息新颖度是否提供了独特或深入的见解
请直接返回重要性级别high/medium/low
""",
# 9. 信息反思
"information_reflection": """
<think>
好的现在需要梳理已获取的信息
子主题{subtopic}
已获得信息总结
{search_summary}
让我再仔细思考总结一下是否有哪些信息非常重要需要更细节的内容
{detailed_analysis}
</think>
基于以上分析以下信息还需要进一步获取细节内容
""",
# 11. 信息结构化整合
"integrate_information": """
请将子主题的所有搜索结果整合为结构化信息
子主题{subtopic}
所有搜索结果
{all_search_results}
请按以下JSON格式输出整合后的信息
```json
{{
"key_points": [
{{
"point": "关键点描述",
"evidence": [
{{
"source_url": "https://example.com",
"confidence": "high"
}}
],
"contradictions": [],
"related_points": []
}}
],
"themes": [
{{
"theme": "主题归类",
"points": ["关键点1", "关键点2"]
}}
]
}}
```
""",
# 12. 子主题报告撰写
"write_subtopic_report": """
请基于整合的信息撰写子主题研究报告
子主题{subtopic}
整合信息
{integrated_info}
撰写要求
1. 使用以下格式
2. 每个观点必须标注来源URL
3. 保持客观准确
4. 突出关键发现和洞察
格式要求
## [子主题名称]
### 一、[主要发现1]
#### 1.1 [子标题]
[内容]来源[具体URL]
#### 1.2 [子标题]
[内容]来源[具体URL]
### 二、[主要发现2]
#### 2.1 [子标题]
[内容]来源[具体URL]
### 三、关键洞察
1. **[洞察1]**基于[来源URL]的数据显示...
2. **[洞察2]**根据[来源URL]的分析...
### 四、建议与展望
[基于研究的可执行建议]
""",
# 13. 幻觉内容检测
"hallucination_detection": """
请检查撰写内容是否存在幻觉与原始来源不符
撰写内容
{written_content}
声称的来源URL{claimed_url}
原始搜索结果中的对应内容
{original_content}
请判断
1. 撰写内容是否准确反映了原始来源
2. 是否存在夸大错误归因或无中生有
如果存在幻觉请指出具体问题
返回格式
{{
"is_hallucination": true/false,
"hallucination_type": "夸大/错误归因/无中生有/无",
"explanation": "具体说明"
}}
""",
# 14. 幻觉内容重写V3使用
"rewrite_hallucination": """
请基于原始搜索材料重新撰写这部分内容
原始内容存在幻觉
{hallucinated_content}
原始搜索材料
{original_sources}
请严格基于搜索材料重新撰写确保准确性
保持原有的格式和风格
""",
# 15. 最终报告生成
"generate_final_report": """
请基于所有子主题报告生成最终的综合研究报告
研究主题{main_topic}
研究问题{research_questions}
各子主题报告
{subtopic_reports}
要求
1. 综合各子主题的发现
2. 提炼整体洞察
3. 保持URL引用格式
4. 提供可执行的建议
报告结构
# [研究主题]
## 执行摘要
[整体研究发现概述]
## 主要发现
### 1. [综合发现1]
基于多个来源的分析...来源[URL1], [URL2]
### 2. [综合发现2]
研究表明...来源[URL3], [URL4]
## 综合洞察
[基于所有研究的深度洞察]
## 建议
[具体可执行的建议]
## 详细子主题报告
[插入所有子主题的详细报告]
"""
}
# 搜索相关的提示词V3使用
SEARCH_PROMPTS = {
"generate_search_queries": """
为以下子主题生成{count}个搜索查询
子主题{subtopic}
子主题说明{explanation}
相关问题{related_questions}
要求
1. 查询要具体有针对性
2. 覆盖不同角度
3. 使用不同的关键词组合
4. 每个查询独占一行
请直接返回搜索查询列表
""",
"generate_refined_queries": """
基于信息反思结果为以下重点生成细节搜索查询
重点信息{key_info}
需要的细节{detail_needed}
请生成3个针对性的搜索查询每个查询独占一行
"""
}
def get_prompt(prompt_name: str, **kwargs) -> str:
    """Look up a prompt template by name and fill in its placeholders.

    Searches PROMPTS first, then SEARCH_PROMPTS.

    Raises:
        ValueError: if *prompt_name* exists in neither table.
        KeyError: if the template references a placeholder missing from
            **kwargs (str.format semantics).
    """
    for table in (PROMPTS, SEARCH_PROMPTS):
        if prompt_name in table:
            return table[prompt_name].format(**kwargs)
    raise ValueError(f"Unknown prompt: {prompt_name}")

254
所有文件/r1_agent.py Normal file
View File

@ -0,0 +1,254 @@
"""
DeepSeek R1模型智能体
负责推理判断规划撰写等思考密集型任务
"""
import json
import logging
from typing import Dict, List, Any, Optional
from openai import OpenAI
from config import Config
from app.agents.prompts import get_prompt
from app.utils.json_parser import parse_json_safely
logger = logging.getLogger(__name__)
class R1Agent:
    """Agent wrapping the DeepSeek R1 chat model.

    Handles the reasoning-heavy steps of the research pipeline: question
    analysis, outline planning, search-result evaluation, information
    integration, report writing and hallucination detection. Every public
    method is a prompt-template + API-call wrapper around ``_call_api``.
    """
    def __init__(self, api_key: Optional[str] = None):
        # Fall back to the configured key when none is supplied.
        self.api_key = api_key or Config.DEEPSEEK_API_KEY
        base_url = Config.DEEPSEEK_BASE_URL
        # The Volcengine ARK platform serves R1 under a different model id.
        if 'volces.com' in base_url:
            self.model = "deepseek-r1-250120"  # Volcengine's R1 model name
        else:
            self.model = Config.R1_MODEL
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=base_url
        )
    def _call_api(self, prompt: str, temperature: float = 0.7,
                  max_tokens: int = 4096, json_mode: bool = False) -> str:
        """Send *prompt* to the R1 model and return the text completion.

        When ``json_mode`` is True, the call both primes the model with an
        assistant-side "```json" prefix (completion trick, when the prompt
        contains a ```json example) and strips any markdown code fence from
        the response, so callers receive the bare JSON payload.

        Raises:
            Exception: whatever the OpenAI client raises; logged first.
        """
        try:
            messages = [{"role": "user", "content": prompt}]
            # JSON output: use the completion-priming trick when the prompt
            # itself contains a ```json example block.
            if json_mode and "```json" in prompt:
                # Keep everything up to (and including) the ```json marker
                # as the user turn, and pre-fill the assistant turn so the
                # model continues the JSON directly.
                prefix = prompt.split("```json")[0] + "```json\n"
                messages = [
                    {"role": "user", "content": prefix},
                    {"role": "assistant", "content": "```json\n"}
                ]
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content
            # In JSON mode, extract the JSON body from the fenced block.
            if json_mode:
                if "```json" in content:
                    json_start = content.find("```json") + 7
                    json_end = content.find("```", json_start)
                    if json_end > json_start:
                        content = content[json_start:json_end].strip()
                elif content.startswith("```json\n"):
                    # Response produced by the completion-priming path.
                    content = content[8:]
                    if content.endswith("```"):
                        content = content[:-3]
            return content.strip()
        except Exception as e:
            logger.error(f"R1 API调用失败: {e}")
            raise
    def analyze_question_type(self, question: str) -> str:
        """Classify the question as factual/comparative/exploratory/decision."""
        prompt = get_prompt("question_type_analysis", question=question)
        result = self._call_api(prompt, temperature=0.3)
        # Validate the model's answer; fall back to the broadest category.
        valid_types = ["factual", "comparative", "exploratory", "decision"]
        result = result.lower().strip()
        if result not in valid_types:
            logger.warning(f"无效的问题类型: {result}默认使用exploratory")
            return "exploratory"
        return result
    def refine_questions(self, question: str, question_type: str) -> List[str]:
        """Ask the model for up to five follow-up questions refining *question*."""
        prompt = get_prompt("refine_questions",
                            question=question,
                            question_type=question_type)
        result = self._call_api(prompt)
        # One question per line; drop blanks.
        questions = [q.strip() for q in result.split('\n') if q.strip()]
        # Strip any leading list numbering the model may have added.
        questions = [q.lstrip('0123456789.-) ') for q in questions]
        return questions[:5]  # cap at five questions
    def create_research_approach(self, question: str, question_type: str,
                                 refined_questions: List[str]) -> str:
        """Draft a short overall research approach for the question set."""
        refined_questions_text = '\n'.join(f"- {q}" for q in refined_questions)
        prompt = get_prompt("research_approach",
                            question=question,
                            question_type=question_type,
                            refined_questions=refined_questions_text)
        return self._call_api(prompt)
    def create_outline(self, question: str, question_type: str,
                       refined_questions: List[str], research_approach: str) -> Dict[str, Any]:
        """Build the research outline dict; fall back to a minimal default.

        Retries the JSON-producing call up to three times, validating that
        the parsed outline carries the mandatory top-level keys before
        returning it; otherwise returns a single-subtopic default outline.
        """
        refined_questions_text = '\n'.join(f"- {q}" for q in refined_questions)
        prompt = get_prompt("create_outline",
                            question=question,
                            question_type=question_type,
                            refined_questions=refined_questions_text,
                            research_approach=research_approach)
        # Up to three attempts to obtain a well-formed JSON outline.
        for attempt in range(3):
            try:
                result = self._call_api(prompt, temperature=0.5, json_mode=True)
                outline = parse_json_safely(result)
                # Ensure the mandatory top-level keys are present.
                if all(key in outline for key in ["main_topic", "research_questions", "sub_topics"]):
                    return outline
                else:
                    logger.warning(f"大纲缺少必要字段,第{attempt+1}次尝试")
            except Exception as e:
                logger.error(f"解析大纲失败,第{attempt+1}次尝试: {e}")
        # All attempts failed: return a minimal single-subtopic outline.
        return {
            "main_topic": question,
            "research_questions": refined_questions[:3],
            "sub_topics": [
                {
                    "topic": "主要方面分析",
                    "explain": "针对问题的核心方面进行深入分析",
                    "priority": "high",
                    "related_questions": refined_questions[:2]
                }
            ]
        }
    def validate_outline(self, outline: Dict[str, Any]) -> str:
        """Ask the model to assess the outline's completeness (free text)."""
        prompt = get_prompt("outline_validation", outline=json.dumps(outline, ensure_ascii=False))
        return self._call_api(prompt)
    def modify_outline(self, original_outline: Dict[str, Any],
                       user_feedback: str, validation_issues: str) -> Dict[str, Any]:
        """Revise the outline based on user feedback and validation findings."""
        prompt = get_prompt("modify_outline",
                            original_outline=json.dumps(original_outline, ensure_ascii=False),
                            user_feedback=user_feedback,
                            validation_issues=validation_issues)
        result = self._call_api(prompt, json_mode=True)
        return parse_json_safely(result)
    def evaluate_search_result(self, subtopic: str, title: str,
                               url: str, snippet: str) -> str:
        """Rate one search result's importance: 'high'/'medium'/'low'."""
        prompt = get_prompt("evaluate_search_results",
                            subtopic=subtopic,
                            title=title,
                            url=url,
                            snippet=snippet)
        result = self._call_api(prompt, temperature=0.3).lower().strip()
        # Defensive default when the model answers something unexpected.
        if result not in ["high", "medium", "low"]:
            return "medium"
        return result
    def reflect_on_information(self, subtopic: str, search_summary: str) -> List[Dict[str, str]]:
        """Reflect on gathered info; return points needing deeper searches.

        Returns a list of {"key_info", "detail_needed"} dicts parsed from
        lines of the form "<info>还需要搜索<detail>".
        """
        # detailed_analysis is currently a static placeholder; a richer
        # analysis derived from search_summary could be injected here.
        prompt = get_prompt("information_reflection",
                            subtopic=subtopic,
                            search_summary=search_summary,
                            detailed_analysis="[基于搜索结果的详细分析]")
        result = self._call_api(prompt)
        # NOTE(review): heuristic line parser; a structured output format
        # would be more robust.
        key_points = []
        lines = result.split('\n')
        for line in lines:
            if line.strip() and '还需要搜索' in line:
                parts = line.split('还需要搜索')
                if len(parts) == 2:
                    key_points.append({
                        "key_info": parts[0].strip(),
                        "detail_needed": parts[1].strip('() ')
                    })
        return key_points
    def integrate_information(self, subtopic: str, all_search_results: str) -> Dict[str, Any]:
        """Consolidate raw search results into the structured JSON format."""
        prompt = get_prompt("integrate_information",
                            subtopic=subtopic,
                            all_search_results=all_search_results)
        result = self._call_api(prompt, json_mode=True)
        return parse_json_safely(result)
    def write_subtopic_report(self, subtopic: str, integrated_info: Dict[str, Any]) -> str:
        """Write the Markdown report for one subtopic (long-form output)."""
        prompt = get_prompt("write_subtopic_report",
                            subtopic=subtopic,
                            integrated_info=json.dumps(integrated_info, ensure_ascii=False))
        return self._call_api(prompt, temperature=0.7, max_tokens=8192)
    def detect_hallucination(self, written_content: str, claimed_url: str,
                             original_content: str) -> Dict[str, Any]:
        """Check written content against its claimed source; returns the
        model's JSON verdict (is_hallucination/type/explanation)."""
        prompt = get_prompt("hallucination_detection",
                            written_content=written_content,
                            claimed_url=claimed_url,
                            original_content=original_content)
        result = self._call_api(prompt, temperature=0.3, json_mode=True)
        return parse_json_safely(result)
    def generate_final_report(self, main_topic: str, research_questions: List[str],
                              subtopic_reports: Dict[str, str]) -> str:
        """Merge all subtopic reports into the final research report."""
        # Concatenate the subtopic reports, separated by horizontal rules.
        reports_text = "\n\n---\n\n".join([
            f"### {topic}\n{report}"
            for topic, report in subtopic_reports.items()
        ])
        prompt = get_prompt("generate_final_report",
                            main_topic=main_topic,
                            research_questions='\n'.join(f"- {q}" for q in research_questions),
                            subtopic_reports=reports_text)
        return self._call_api(prompt, temperature=0.7, max_tokens=16384)

171
所有文件/report.py Normal file
View File

@ -0,0 +1,171 @@
"""
研究报告数据模型
"""
import uuid
from datetime import datetime
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
class ReportSection(BaseModel):
    """A titled report section with optional nested subsections."""
    title: str
    content: str
    # Child sections, rendered one heading level deeper.
    subsections: List['ReportSection'] = []
    # Source URLs cited by this section.
    sources: List[str] = []
    def to_markdown(self, level: int = 1) -> str:
        """Render this section (and its children) as Markdown at *level*."""
        header = "#" * level
        markdown = f"{header} {self.title}\n\n{self.content}\n\n"
        # Render child sections one heading level deeper.
        for subsection in self.subsections:
            markdown += subsection.to_markdown(level + 1)
        # Append a numbered sources list when any URLs were recorded.
        if self.sources:
            markdown += f"\n{'#' * (level + 1)} 参考来源\n"
            for i, source in enumerate(self.sources, 1):
                markdown += f"{i}. [{source}]({source})\n"
            markdown += "\n"
        return markdown
# Resolve the self-referencing forward reference ('ReportSection').
ReportSection.model_rebuild()
class KeyInsight(BaseModel):
    """A single key insight distilled from the research."""
    insight: str
    # Evidence statements backing the insight.
    supporting_evidence: List[str] = []
    # URLs of the sources the evidence came from.
    source_urls: List[str] = []
    confidence: float = 0.0  # between 0 and 1
class SubtopicReport(BaseModel):
    """Structured research report for a single subtopic."""
    subtopic_id: str
    subtopic_name: str
    sections: List[ReportSection] = []
    key_insights: List[KeyInsight] = []
    recommendations: List[str] = []
    created_at: datetime = Field(default_factory=datetime.now)
    # Character count of the source report text (ASCII spaces stripped).
    word_count: int = 0
    def to_markdown(self) -> str:
        """Render the subtopic report as a Markdown fragment (## level)."""
        markdown = f"## {self.subtopic_name}\n\n"
        # Body sections start at heading level 3.
        for section in self.sections:
            markdown += section.to_markdown(level=3)
        # Numbered insight list with evidence bullets and source links.
        if self.key_insights:
            markdown += "### 关键洞察\n\n"
            for i, insight in enumerate(self.key_insights, 1):
                markdown += f"{i}. **{insight.insight}**\n"
                if insight.supporting_evidence:
                    for evidence in insight.supporting_evidence:
                        markdown += f" - {evidence}\n"
                if insight.source_urls:
                    markdown += f" - 来源: "
                    # The comprehension's `i` is comprehension-scoped in
                    # Python 3, so it does not clobber the outer counter.
                    markdown += ", ".join([f"[{i+1}]({url})" for i, url in enumerate(insight.source_urls)])
                    markdown += "\n"
                markdown += "\n"
        # Bullet list of recommendations.
        if self.recommendations:
            markdown += "### 建议与展望\n\n"
            for recommendation in self.recommendations:
                markdown += f"- {recommendation}\n"
            markdown += "\n"
        return markdown
class HallucinationCheck(BaseModel):
    """Record of one hallucination check against a claimed source."""
    # The written content that was checked.
    content: str
    # URL the content claims as its source.
    source_url: str
    # The matching text from the original search result, when available.
    original_text: Optional[str] = None
    is_hallucination: bool = False
    # Category when flagged: exaggeration / misattribution / fabrication.
    hallucination_type: Optional[str] = None  # 夸大/错误归因/无中生有
    # Rewritten content produced after a positive detection.
    corrected_content: Optional[str] = None
    checked_at: datetime = Field(default_factory=datetime.now)
class FinalReport(BaseModel):
    """Complete research report aggregating all subtopic reports."""
    session_id: str
    title: str
    executive_summary: str
    main_findings: List[ReportSection] = []
    subtopic_reports: List[SubtopicReport] = []
    overall_insights: List[KeyInsight] = []
    recommendations: List[str] = []
    methodology: Optional[str] = None
    limitations: List[str] = []
    created_at: datetime = Field(default_factory=datetime.now)
    # Run statistics filled in by the report generator.
    total_sources: int = 0
    total_searches: int = 0
    def to_markdown(self) -> str:
        """Render the full report as a single Markdown document."""
        markdown = f"# {self.title}\n\n"
        markdown += f"*生成时间: {self.created_at.strftime('%Y-%m-%d %H:%M:%S')}*\n\n"
        # Executive summary (always present).
        markdown += "## 执行摘要\n\n"
        markdown += f"{self.executive_summary}\n\n"
        # Main findings.
        if self.main_findings:
            markdown += "## 主要发现\n\n"
            for finding in self.main_findings:
                markdown += finding.to_markdown(level=3)
        # Overall insights with their supporting evidence.
        if self.overall_insights:
            markdown += "## 综合洞察\n\n"
            for i, insight in enumerate(self.overall_insights, 1):
                markdown += f"### {i}. {insight.insight}\n\n"
                if insight.supporting_evidence:
                    for evidence in insight.supporting_evidence:
                        markdown += f"- {evidence}\n"
                markdown += "\n"
        # Recommendations.
        if self.recommendations:
            markdown += "## 建议\n\n"
            for recommendation in self.recommendations:
                markdown += f"- {recommendation}\n"
            markdown += "\n"
        # Detailed subtopic reports, separated by horizontal rules.
        markdown += "## 详细分析\n\n"
        for report in self.subtopic_reports:
            markdown += report.to_markdown()
            markdown += "---\n\n"
        # Methodology, when provided.
        if self.methodology:
            markdown += "## 研究方法\n\n"
            markdown += f"{self.methodology}\n\n"
        # Limitations.
        if self.limitations:
            markdown += "## 研究局限性\n\n"
            for limitation in self.limitations:
                markdown += f"- {limitation}\n"
            markdown += "\n"
        # Run statistics.
        markdown += "## 研究统计\n\n"
        markdown += f"- 总搜索次数: {self.total_searches}\n"
        markdown += f"- 引用来源数: {self.total_sources}\n"
        markdown += f"- 分析子主题数: {len(self.subtopic_reports)}\n"
        return markdown
    def save_to_file(self, filepath: str):
        """Write the Markdown rendering of this report to *filepath* (UTF-8)."""
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(self.to_markdown())

View File

@ -0,0 +1,347 @@
"""
报告生成服务
负责生成各类研究报告
"""
import os
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional
from app.models.report import (
SubtopicReport, FinalReport, ReportSection,
KeyInsight, HallucinationCheck
)
from app.models.research import ResearchSession, Subtopic
from app.models.search_result import SearchResult
from config import Config
logger = logging.getLogger(__name__)
class ReportGenerator:
    """Builds structured report models from raw LLM-generated Markdown.

    Parsing is heuristic: it scans Markdown headings and list markers to
    recover sections, insights and recommendations, and degrades to a
    coarse single-section report when parsing fails.
    """

    def generate_subtopic_report(self, subtopic: Subtopic,
                                 integrated_info: Dict[str, Any],
                                 report_content: str) -> SubtopicReport:
        """Parse *report_content* into a structured SubtopicReport.

        Args:
            subtopic: the subtopic model this report belongs to.
            integrated_info: structured info from the integration step
                (currently unused by the parser; kept for interface
                stability).
            report_content: Markdown text produced by the writer agent.

        Returns:
            A SubtopicReport; on any parsing error, a fallback report
            wrapping the raw content in a single section.
        """
        try:
            sections = self._parse_report_sections(report_content)
            key_insights = self._extract_key_insights(report_content)
            recommendations = self._extract_recommendations(report_content)
            # Rough size metric: character count excluding ASCII spaces.
            word_count = len(report_content.replace(" ", ""))
            report = SubtopicReport(
                subtopic_id=subtopic.id,
                subtopic_name=subtopic.topic,
                sections=sections,
                key_insights=key_insights,
                recommendations=recommendations,
                word_count=word_count
            )
            return report
        except Exception as e:
            logger.error(f"生成子主题报告失败: {e}")
            # Degrade gracefully: wrap the raw content in one section.
            return SubtopicReport(
                subtopic_id=subtopic.id,
                subtopic_name=subtopic.topic,
                sections=[
                    ReportSection(
                        title="报告内容",
                        content=report_content
                    )
                ]
            )

    def generate_final_report(self, session: ResearchSession,
                              subtopic_reports: List[SubtopicReport],
                              final_content: str) -> FinalReport:
        """Assemble the FinalReport model from the synthesized content.

        The statistics are pre-initialized to 0 so the error fallback no
        longer needs the fragile ``'name' in locals()`` pattern the
        original used.
        """
        total_sources = 0
        total_searches = 0
        try:
            executive_summary = self._extract_executive_summary(final_content)
            main_findings = self._parse_main_findings(final_content)
            overall_insights = self._extract_overall_insights(final_content)
            recommendations = self._extract_final_recommendations(final_content)
            total_sources = self._count_total_sources(subtopic_reports)
            total_searches = self._count_total_searches(session)
            report = FinalReport(
                session_id=session.id,
                title=session.question,
                executive_summary=executive_summary,
                main_findings=main_findings,
                subtopic_reports=subtopic_reports,
                overall_insights=overall_insights,
                recommendations=recommendations,
                methodology=self._generate_methodology(session),
                limitations=self._identify_limitations(session),
                total_sources=total_sources,
                total_searches=total_searches
            )
            return report
        except Exception as e:
            logger.error(f"生成最终报告失败: {e}")
            # Fallback report keeps whatever statistics were computed.
            return FinalReport(
                session_id=session.id,
                title=session.question,
                executive_summary="研究报告生成过程中出现错误。",
                subtopic_reports=subtopic_reports,
                total_sources=total_sources,
                total_searches=total_searches
            )

    def save_report(self, report: FinalReport, format: str = "markdown") -> str:
        """Persist *report* under Config.REPORTS_DIR and return the path.

        Raises:
            ValueError: for any *format* other than "markdown".
        """
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{report.session_id}_{timestamp}.md"
            filepath = os.path.join(Config.REPORTS_DIR, filename)
            if format == "markdown":
                report.save_to_file(filepath)
            else:
                # Other formats (PDF/HTML) may be supported in the future.
                raise ValueError(f"不支持的格式: {format}")
            logger.info(f"报告已保存: {filepath}")
            return filepath
        except Exception as e:
            logger.error(f"保存报告失败: {e}")
            raise

    def create_hallucination_report(self, hallucinations: List[Dict[str, Any]]) -> str:
        """Render the hallucination-check findings as a Markdown report."""
        if not hallucinations:
            return "未检测到幻觉内容。"
        report_lines = ["# 幻觉检测报告", ""]
        report_lines.append(f"共检测到 {len(hallucinations)} 处可能的幻觉内容:")
        report_lines.append("")
        for i, h in enumerate(hallucinations, 1):
            report_lines.extend([
                f"## {i}. {h.get('type', '未知类型')}",
                f"**URL**: {h.get('url', 'N/A')}",
                f"**原始内容**: {h.get('content', 'N/A')}",
                f"**说明**: {h.get('explanation', '无说明')}",
                ""
            ])
        return '\n'.join(report_lines)

    # ========== parsing helpers ==========

    def _parse_report_sections(self, content: str) -> List[ReportSection]:
        """Split Markdown on '### ' headings into flat ReportSection objects.

        '#### ' sub-headings are kept inline in the parent section's text
        (simplified handling; no nested subsections are built).
        """
        sections = []
        lines = content.split('\n')
        current_section = None
        current_content = []
        for line in lines:
            if line.startswith('### '):
                # Close the previous section before starting a new one.
                if current_section:
                    current_section.content = '\n'.join(current_content).strip()
                    sections.append(current_section)
                current_section = ReportSection(title=line[4:].strip(), content="")
                current_content = []
            elif line.startswith('#### ') and current_section:
                # Sub-heading: kept as plain text inside the section.
                current_content.append(line)
            elif current_section:
                current_content.append(line)
        # Flush the final section.
        if current_section:
            current_section.content = '\n'.join(current_content).strip()
            sections.append(current_section)
        return sections

    def _extract_key_insights(self, content: str) -> List[KeyInsight]:
        """Extract numbered insight lines from the '关键洞察' section.

        FIX(review): the original split the insight text on '' (an empty
        separator), which always raises ValueError (since '' is "in" every
        string) and silently knocked every report back to the fallback
        path. The separator is restored to the full-width '(' that opens
        the source citation — TODO confirm against the report format in
        prompts.py.
        """
        insights = []
        lines = content.split('\n')
        in_insights_section = False
        for line in lines:
            if '关键洞察' in line and line.startswith('#'):
                in_insights_section = True
                continue
            if in_insights_section:
                # Any other heading ends the insights section.
                if line.startswith('#') and '关键洞察' not in line:
                    break
                if line.strip().startswith(('1.', '2.', '3.', '4.', '5.')):
                    insight_text = line.split('.', 1)[1].strip()
                    # Strip bold/italic markers.
                    insight_text = insight_text.replace('**', '').replace('*', '')
                    source_urls = self._extract_urls_from_text(insight_text)
                    insights.append(KeyInsight(
                        # Keep only the text before the citation parenthesis.
                        insight=insight_text.split('(')[0] if '(' in insight_text else insight_text,
                        source_urls=source_urls,
                        confidence=0.8  # default confidence
                    ))
        return insights

    def _extract_recommendations(self, content: str) -> List[str]:
        """Collect bullet lines from the '建议' section.

        FIX(review): the original bullet tuple contained '' (empty
        string), which matches every line via startswith and chopped the
        first character off arbitrary text; replaced with the '•' bullet
        used elsewhere in the app — TODO confirm the intended marker.
        """
        recommendations = []
        lines = content.split('\n')
        in_recommendations_section = False
        for line in lines:
            if '建议' in line and line.startswith('#'):
                in_recommendations_section = True
                continue
            if in_recommendations_section:
                if line.startswith('#') and '建议' not in line:
                    break
                if line.strip().startswith(('-', '*', '•')):
                    recommendation = line.strip()[1:].strip()
                    if recommendation:
                        recommendations.append(recommendation)
        return recommendations

    def _extract_executive_summary(self, content: str) -> str:
        """Return the text between the '执行摘要' heading and the next heading."""
        lines = content.split('\n')
        in_summary = False
        summary_lines = []
        for line in lines:
            if '执行摘要' in line and line.startswith('#'):
                in_summary = True
                continue
            if in_summary:
                if line.startswith('#'):
                    break
                summary_lines.append(line)
        return '\n'.join(summary_lines).strip()

    def _parse_main_findings(self, content: str) -> List[ReportSection]:
        """Parse the '主要发现' section (placeholder: not implemented yet)."""
        return []

    def _extract_overall_insights(self, content: str) -> List[KeyInsight]:
        """Extract the '综合洞察' section (placeholder: not implemented yet)."""
        return []

    def _extract_final_recommendations(self, content: str) -> List[str]:
        """Extract final recommendations (placeholder: not implemented yet)."""
        return []

    def _extract_urls_from_text(self, text: str) -> List[str]:
        """Find http(s)/www URLs in *text*, trimming trailing punctuation."""
        import re
        url_pattern = r'https?://[^\s)]+|www\.[^\s)]+'
        urls = re.findall(url_pattern, text)
        cleaned_urls = []
        for url in urls:
            # Trim punctuation that commonly trails a citation.
            url = url.rstrip('.,;:!?)')
            if url:
                cleaned_urls.append(url)
        return cleaned_urls

    def _count_total_sources(self, subtopic_reports: List[SubtopicReport]) -> int:
        """Count distinct source URLs across all subtopic reports."""
        all_urls = set()
        for report in subtopic_reports:
            for section in report.sections:
                all_urls.update(section.sources)
            for insight in report.key_insights:
                all_urls.update(insight.source_urls)
        return len(all_urls)

    def _count_total_searches(self, session: ResearchSession) -> int:
        """Sum the search counts over every subtopic in the session outline."""
        if not session.outline:
            return 0
        total = 0
        for subtopic in session.outline.sub_topics:
            total += subtopic.get_total_searches()
        return total

    def _generate_methodology(self, session: ResearchSession) -> str:
        """Produce the boilerplate methodology text for the final report."""
        methodology = f"""
本研究采用系统化的深度研究方法具体流程如下
1. **问题分析**: 识别问题类型为"{session.question_type.value if session.question_type else '未知'}"并细化为{len(session.refined_questions)}个具体问题
2. **研究规划**: 制定包含{len(session.outline.sub_topics) if session.outline else 0}个子主题的研究大纲每个子主题根据重要性分配不同的搜索资源
3. **信息收集**: 使用Tavily搜索引擎进行多轮搜索共执行{self._count_total_searches(session)}次搜索
4. **质量控制**: 通过AI评估搜索结果重要性并进行幻觉检测和内容验证
5. **综合分析**: 整合所有信息提炼关键洞察形成结构化报告
"""
        return methodology.strip()

    def _identify_limitations(self, session: ResearchSession) -> List[str]:
        """List the standard limitations plus any run-specific ones."""
        limitations = [
            "搜索结果受限于公开可访问的网络信息",
            "部分专业领域可能缺乏深度分析",
            "时效性信息可能存在延迟"
        ]
        # Flag incomplete coverage when any subtopic was cancelled mid-run.
        if session.outline and any(st.status == "cancelled" for st in session.outline.sub_topics):
            limitations.append("部分子主题研究未完成")
        return limitations

View File

@ -0,0 +1,17 @@
Flask==3.0.0
Flask-CORS==4.0.0
Flask-SocketIO==5.3.5
python-socketio==5.10.0
python-dotenv==1.0.0
requests==2.31.0
openai>=1.0.0,<2.0.0
tavily-python==0.5.0
celery==5.3.4
redis==5.0.1
pymongo==4.6.1
pydantic==2.5.3
python-json-logger==2.0.7
pytest==7.4.4
pytest-asyncio==0.23.3
gunicorn==21.2.0
eventlet==0.33.3

View File

@ -0,0 +1,149 @@
// app/static/js/research-tree.js
// Render the full research-progress tree for a session.
// `session` is the status payload from the API; `outline` may be null
// while the outline phase has not completed yet.
// NOTE(review): session/outline values are interpolated into innerHTML
// unescaped (HTML-injection risk if they can contain markup), and
// getStatusText is defined in another script file — confirm it is loaded.
function renderTree(session, outline) {
    const container = document.getElementById('treeContainer');
    if (!container) {
        console.error('Tree container not found!');
        return;
    }
    container.innerHTML = '';
    // Root node - always shown.
    const rootNode = document.createElement('div');
    rootNode.className = 'root-node';
    rootNode.innerHTML = `
        <h2>${session.question || '研究问题'}</h2>
        <p>状态${getStatusText(session.status)} |
        开始时间${new Date(session.created_at).toLocaleString()}</p>
        ${session.error_message ? `<p style="color: #ff6b6b;">错误: ${session.error_message}</p>` : ''}
    `;
    container.appendChild(rootNode);
    // On error, show a single error node and stop rendering.
    if (session.status === 'error') {
        const errorNode = createTreeNode('研究出现错误', 'error');
        container.appendChild(wrapInTreeNode(errorNode));
        return;
    }
    // Preparation-phase node; expands with the refined questions once known.
    const prepNode = createTreeNode('研究准备', session.refined_questions ? 'completed' : session.status);
    prepNode.onclick = () => showDetail('preparation', session);
    if (session.refined_questions) {
        const content = document.createElement('div');
        content.className = 'node-content expanded';
        content.innerHTML = `
            <div class="phase-card">
                <h4>🎯 问题细化</h4>
                <ul>
                    ${session.refined_questions.map(q => `<li>• ${q}</li>`).join('')}
                </ul>
            </div>
        `;
        prepNode.querySelector('.node-card').appendChild(content);
    }
    container.appendChild(wrapInTreeNode(prepNode));
    // Outline node (with subtopic children) once the outline exists.
    if (outline) {
        const outlineNode = createTreeNode('研究大纲', 'completed');
        const outlineContent = document.createElement('div');
        outlineContent.className = 'node-content expanded';
        outlineContent.innerHTML = `
            <div class="phase-card">
                <h4>📋 主要研究问题</h4>
                <ul>
                    ${outline.research_questions.map(q => `<li>• ${q}</li>`).join('')}
                </ul>
            </div>
        `;
        outlineNode.querySelector('.node-card').appendChild(outlineContent);
        const outlineWrapper = wrapInTreeNode(outlineNode);
        container.appendChild(outlineWrapper);
        // One child node per subtopic.
        outline.sub_topics.forEach((subtopic, idx) => {
            const subtopicNode = createSubtopicNode(subtopic, idx + 1);
            outlineWrapper.appendChild(wrapInTreeNode(subtopicNode, true));
        });
    } else {
        // Outline still pending / in progress / failed: show a placeholder.
        const outlineStatus = session.status === 'outlining' ? 'processing' :
                              session.status === 'error' ? 'error' : 'pending';
        const outlineNode = createTreeNode('研究大纲', outlineStatus);
        container.appendChild(wrapInTreeNode(outlineNode));
    }
    // Final-report node: completed once the report is available.
    const reportNode = createTreeNode('研究报告生成', session.final_report ? 'completed' : 'pending');
    container.appendChild(wrapInTreeNode(reportNode));
}
// Build a generic tree-node card showing a title and a status icon.
// NOTE(review): `title` is interpolated into innerHTML unescaped —
// HTML-injection risk if titles can contain markup.
function createTreeNode(title, status) {
    const node = document.createElement('div');
    const statusInfo = getStatusInfo(status);
    node.innerHTML = `
        <div class="node-card ${statusInfo.className}">
            <div class="node-header">
                <div class="node-title-wrapper">
                    <span class="node-title">${title}</span>
                </div>
                <div class="node-status">
                    <span class="status-icon ${statusInfo.className}">${statusInfo.icon}</span>
                </div>
            </div>
        </div>
    `;
    return node;
}
// Build a subtopic card showing its index, priority badge and status.
// NOTE(review): the subtopic object is serialized straight into the
// onclick attribute with only double quotes escaped — a single quote,
// '<' or '&' in subtopic text will break the attribute or inject HTML.
// Attaching the handler via addEventListener would be safer.
function createSubtopicNode(subtopic, index) {
    const node = document.createElement('div');
    const statusInfo = getStatusInfo(subtopic.status);
    node.innerHTML = `
        <div class="node-card ${statusInfo.className}" onclick="showDetail('subtopic', ${JSON.stringify(subtopic).replace(/"/g, '&quot;')})">
            <div class="node-header">
                <div class="node-title-wrapper">
                    <span class="node-title">子主题${index}${subtopic.topic}</span>
                </div>
                <div class="node-status">
                    <span class="priority-badge ${subtopic.priority}">
                        ${subtopic.priority === 'high' ? '高' : subtopic.priority === 'medium' ? '中' : '低'}优先级
                    </span>
                    <span class="status-icon ${statusInfo.className}">${statusInfo.icon}</span>
                </div>
            </div>
        </div>
    `;
    return node;
}
// Wrap a node card in the tree-layout container element; subtopic
// nodes get an extra class for indentation styling.
function wrapInTreeNode(node, isSubtopic = false) {
    const wrapper = document.createElement('div');
    const classes = ['tree-node'];
    if (isSubtopic) {
        classes.push('subtopic-node');
    }
    wrapper.className = classes.join(' ');
    wrapper.appendChild(node);
    return wrapper;
}
// Map a research/subtopic status value to its display icon and CSS class.
// Unknown statuses fall back to the 'pending' presentation.
function getStatusInfo(status) {
    const STATUS_MAP = {
        'pending': { icon: '○', className: 'pending' },
        'analyzing': { icon: '●', className: 'processing' },
        'outlining': { icon: '●', className: 'processing' },
        'researching': { icon: '●', className: 'processing' },
        'writing': { icon: '●', className: 'processing' },
        'reviewing': { icon: '●', className: 'processing' },
        'completed': { icon: '✓', className: 'completed' },
        'error': { icon: '✗', className: 'error' },
        'cancelled': { icon: '⊘', className: 'cancelled' }
    };
    const info = STATUS_MAP[status];
    return info || STATUS_MAP['pending'];
}

View File

@ -0,0 +1,46 @@
{% extends "base.html" %}
{% block header_content %}
<div class="progress-bar" id="progressBar">
<div class="progress-fill" id="progressFill" style="width: 0%"></div>
</div>
<div class="progress-message" id="progressMessage"></div>
{% endblock %}
{% block content %}
<div class="research-view">
<div class="tree-container" id="treeContainer">
<!-- 动态生成的研究树 -->
</div>
<div class="action-buttons">
<button class="action-button" onclick="window.location.href='/'">返回</button>
<button class="action-button primary" id="downloadBtn" style="display: none;" onclick="downloadReport()">
📄 下载报告
</button>
<button class="action-button danger" id="cancelBtn" onclick="cancelResearch()">
✕ 取消研究
</button>
</div>
</div>
<!-- 详情面板 -->
<div class="detail-panel" id="detailPanel">
<div class="panel-header">
<h3>详细信息</h3>
<button class="panel-close" onclick="closePanel()">×</button>
</div>
<div class="panel-content" id="panelContent">
<!-- 动态内容 -->
</div>
</div>
<script>
const SESSION_ID = "{{ session_id }}";
</script>
{% endblock %}
{% block extra_js %}
<script src="{{ url_for('static', filename='js/research-tree.js') }}"></script>
<script src="{{ url_for('static', filename='js/research.js') }}"></script>
{% endblock %}

156
所有文件/research.js Normal file
View File

@ -0,0 +1,156 @@
// app/static/js/research.js
// Page state: the Socket.IO connection and the latest session payload.
let socket = null;
let currentSession = null;

// Boot the page: open the WebSocket, load once immediately, then poll
// every 3s as a fallback in case push events are missed.
document.addEventListener('DOMContentLoaded', function() {
    initWebSocket();
    loadSessionData();
    // Polling fallback alongside the WebSocket push updates.
    setInterval(loadSessionData, 3000);
});
// Connect to the server and wire up all WebSocket event handlers.
function initWebSocket() {
    socket = io();

    // Join the session room so the server can target events at this page.
    socket.on('connect', () => {
        console.log('WebSocket connected');
        socket.emit('join_session', { session_id: SESSION_ID });
    });

    // Server-pushed updates: progress text/bar, and full refreshes on change.
    socket.on('progress_update', (data) => updateProgress(data.percentage, data.message));
    socket.on('status_changed', () => loadSessionData());
    socket.on('subtopic_updated', () => loadSessionData());

    // Final report is ready: swap the cancel button for the download button.
    socket.on('report_available', () => {
        document.getElementById('downloadBtn').style.display = 'block';
        document.getElementById('cancelBtn').style.display = 'none';
    });
}
// Fetch the session status (and outline, once it exists) and redraw the page.
async function loadSessionData() {
    try {
        const status = await api.getSessionStatus(SESSION_ID);
        currentSession = status;

        updateProgress(status.progress_percentage || 0, status.current_phase || '准备中');

        // Always render the bare tree first so the page is never empty.
        renderTree(status, null);

        // Once analysis is done an outline may exist; fetch it and re-render.
        const outlineReady = status.status !== 'pending' && status.status !== 'analyzing';
        if (outlineReady) {
            try {
                const outline = await api.getOutline(SESSION_ID);
                renderTree(status, outline);
            } catch (error) {
                console.log('大纲尚未创建');
            }
        }

        // Completed sessions expose the download button and hide "cancel".
        if (status.status === 'completed') {
            document.getElementById('downloadBtn').style.display = 'block';
            document.getElementById('cancelBtn').style.display = 'none';
        }
    } catch (error) {
        console.error('Failed to load session data:', error);
    }
}
// Reflect overall progress in the header bar and its message line.
function updateProgress(percentage, message) {
    const fill = document.getElementById('progressFill');
    const label = document.getElementById('progressMessage');
    fill.style.width = percentage + '%';
    label.textContent = message;
}
// Ask for confirmation, then cancel the running research session and go home.
async function cancelResearch() {
    if (!confirm('确定要取消当前研究吗?')) {
        return;
    }
    try {
        await api.cancelResearch(SESSION_ID);
        alert('研究已取消');
        window.location.href = '/';
    } catch (error) {
        alert('取消失败,请重试');
    }
}
// Trigger download of the final report for this session.
async function downloadReport() {
    api.downloadReport(SESSION_ID);
}
// Open the right-hand detail panel and render content for the clicked node.
// type: 'preparation' shows refined questions and the research approach;
//       'subtopic' shows one sub-topic's summary; anything else a placeholder.
// data: the node's payload object (fields depend on type).
function showDetail(type, data) {
    const panel = document.getElementById('detailPanel');
    const content = document.getElementById('panelContent');
    // Build panel HTML according to the node type
    let html = '';
    switch(type) {
        case 'preparation':
            html = `
                <h4>研究准备</h4>
                ${data.refined_questions ? `
                <div class="detail-section">
                    <h5>细化的问题</h5>
                    <ul>
                        ${data.refined_questions.map(q => `<li>${q}</li>`).join('')}
                    </ul>
                </div>
                ` : ''}
                ${data.research_approach ? `
                <div class="detail-section">
                    <h5>研究思路</h5>
                    <p>${data.research_approach}</p>
                </div>
                ` : ''}
            `;
            break;
        case 'subtopic':
            html = `
                <h4>${data.topic}</h4>
                <p>${data.explain}</p>
                <div class="meta-info">
                    <span>优先级${data.priority}</span>
                    <span>状态${getStatusText(data.status)}</span>
                </div>
            `;
            break;
        default:
            html = '<p>暂无详细信息</p>';
    }
    content.innerHTML = html;
    panel.classList.add('open');
}
// Hide the detail side panel.
function closePanel() {
    document.getElementById('detailPanel').classList.remove('open');
}
// Map an internal status code to its human-readable (Chinese) label.
// Unknown codes are returned unchanged.
function getStatusText(status) {
    const labels = {
        pending: '等待中',
        analyzing: '分析中',
        outlining: '制定大纲',
        researching: '研究中',
        writing: '撰写中',
        reviewing: '审核中',
        completed: '已完成',
        error: '错误',
        cancelled: '已取消'
    };
    return Object.prototype.hasOwnProperty.call(labels, status) ? labels[status] : status;
}

126
所有文件/research.py Normal file
View File

@ -0,0 +1,126 @@
"""
研究会话数据模型
"""
import json
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field
class QuestionType(str, Enum):
    """Question classification used to steer the research strategy."""
    FACTUAL = "factual"          # fact lookup
    COMPARATIVE = "comparative"  # analysis / comparison
    EXPLORATORY = "exploratory"  # open-ended exploration
    DECISION = "decision"        # decision support
class ResearchStatus(str, Enum):
    """Lifecycle states for a research session (also reused per sub-topic)."""
    PENDING = "pending"          # created, not started yet
    ANALYZING = "analyzing"      # classifying / refining the question
    OUTLINING = "outlining"      # building the research outline
    RESEARCHING = "researching"  # searching sub-topics
    WRITING = "writing"          # drafting reports
    REVIEWING = "reviewing"      # hallucination checks / review
    COMPLETED = "completed"
    ERROR = "error"
    CANCELLED = "cancelled"
class SubtopicPriority(str, Enum):
    """Sub-topic priority; controls how many searches a sub-topic may run."""
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
class Subtopic(BaseModel):
    """One sub-topic of a research outline, with its search state and report."""
    id: str = Field(default_factory=lambda: f"ST{uuid.uuid4().hex[:8]}")  # short unique id
    topic: str    # sub-topic title
    explain: str  # short description of what to investigate
    priority: SubtopicPriority
    related_questions: List[str] = []  # questions this sub-topic should answer
    status: ResearchStatus = ResearchStatus.PENDING
    search_count: int = 0   # searches executed so far
    max_searches: int = 15  # search budget (set from priority by the manager)
    searches: List[Dict[str, Any]] = []          # initial search results
    refined_searches: List[Dict[str, Any]] = []  # follow-up (refined) results
    integrated_info: Optional[Dict[str, Any]] = None  # merged key information
    report: Optional[str] = None  # markdown report for this sub-topic
    hallucination_checks: List[Dict[str, Any]] = []  # detected/fixed issues

    def get_total_searches(self) -> int:
        """Return the total number of collected results (initial + refined)."""
        return len(self.searches) + len(self.refined_searches)
class ResearchOutline(BaseModel):
    """Research outline: main topic, guiding questions and sub-topics."""
    main_topic: str
    research_questions: List[str]
    sub_topics: List[Subtopic]
    version: int = 1  # outline revision number
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
class ResearchSession(BaseModel):
    """A complete research session: question, outline, progress and report."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    question: str                            # the user's original question
    question_type: Optional[QuestionType] = None
    refined_questions: List[str] = []        # refined sub-questions
    research_approach: Optional[str] = None  # high-level research plan
    status: ResearchStatus = ResearchStatus.PENDING
    outline: Optional[ResearchOutline] = None
    final_report: Optional[str] = None       # final markdown report
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
    completed_at: Optional[datetime] = None
    error_message: Optional[str] = None
    # Progress tracking
    total_steps: int = 0
    completed_steps: int = 0
    current_phase: str = "初始化"

    def update_status(self, status: ResearchStatus):
        """Set a new status and refresh timestamps (records completion time)."""
        self.status = status
        self.updated_at = datetime.now()
        if status == ResearchStatus.COMPLETED:
            self.completed_at = datetime.now()

    def get_progress_percentage(self) -> float:
        """Return overall progress in percent (0.0 when no steps are planned)."""
        if self.total_steps == 0:
            return 0.0
        return (self.completed_steps / self.total_steps) * 100

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dict.

        Uses pydantic's JSON round-trip so that *nested* datetime fields
        (e.g. inside ``outline``) become ISO strings as well. The previous
        manual conversion only handled the top-level timestamps, which made
        ``save_to_file`` raise TypeError once an outline existed.
        """
        return json.loads(self.json())

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ResearchSession':
        """Create an instance from a dict; ISO timestamp strings are parsed."""
        for key in ['created_at', 'updated_at', 'completed_at']:
            if data.get(key) and isinstance(data[key], str):
                data[key] = datetime.fromisoformat(data[key])
        # Nested timestamps/enums are parsed by pydantic itself.
        return cls(**data)

    def save_to_file(self, filepath: str):
        """Persist the session as pretty-printed UTF-8 JSON."""
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)

    @classmethod
    def load_from_file(cls, filepath: str) -> 'ResearchSession':
        """Load a session previously written by :meth:`save_to_file`."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)

View File

@ -0,0 +1,323 @@
# 文件位置: app/services/research_manager.py
# 文件名: research_manager.py
"""
研究流程管理器
协调整个研究过程
"""
import os
import json
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional
from app.models.research import ResearchSession, ResearchStatus, ResearchOutline, Subtopic
from app.services.ai_service import AIService
from app.services.search_service import SearchService
from app.services.report_generator import ReportGenerator
from config import Config
logger = logging.getLogger(__name__)
class ResearchManager:
    """Coordinates the whole research workflow for ResearchSession objects.

    Sessions live in an in-memory cache and are persisted as JSON files under
    ``Config.SESSIONS_DIR``; long-running work is delegated to the async task
    chain in ``app.tasks.research_tasks``.
    """

    def __init__(self):
        self.ai_service = AIService()
        self.search_service = SearchService()
        self.report_generator = ReportGenerator()
        # In-memory session cache, keyed by session id.
        self.sessions: Dict[str, ResearchSession] = {}

    def create_session(self, question: str) -> ResearchSession:
        """Create, cache and persist a new research session."""
        session = ResearchSession(question=question)
        self.sessions[session.id] = session
        # Persist immediately so the session survives restarts.
        self._save_session(session)
        logger.info(f"创建研究会话: {session.id}")
        return session

    def start_research(self, session_id: str) -> Dict[str, Any]:
        """Kick off the async research pipeline for an existing session.

        Raises ValueError when the session does not exist; re-raises any
        startup failure after recording it on the session.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        try:
            session.update_status(ResearchStatus.ANALYZING)
            self._save_session(session)
            # Imported lazily to avoid a circular dependency with the task module.
            from app.tasks.research_tasks import analyze_question_chain
            analyze_question_chain.delay(session_id)
            return {
                "status": "started",
                "session_id": session_id,
                "message": "研究已开始"
            }
        except Exception as e:
            logger.error(f"启动研究失败: {e}")
            session.update_status(ResearchStatus.ERROR)
            session.error_message = str(e)
            self._save_session(session)
            raise

    def get_session(self, session_id: str) -> Optional[ResearchSession]:
        """Return a session from the cache, falling back to its JSON file."""
        if session_id in self.sessions:
            return self.sessions[session_id]
        filepath = self._get_session_filepath(session_id)
        if os.path.exists(filepath):
            session = ResearchSession.load_from_file(filepath)
            self.sessions[session_id] = session
            return session
        return None

    def update_session(self, session_id: str, updates: Dict[str, Any]):
        """Apply a dict of attribute updates to a session and persist it.

        Unknown keys are silently ignored; raises ValueError when the
        session does not exist.
        """
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        for key, value in updates.items():
            if hasattr(session, key):
                setattr(session, key, value)
        session.updated_at = datetime.now()
        self._save_session(session)

    def get_session_status(self, session_id: str) -> Dict[str, Any]:
        """Return a progress snapshot suitable for the status API / WebSocket."""
        session = self.get_session(session_id)
        if not session:
            return {"error": "Session not found"}
        # Per-sub-topic progress, based on the search budget used so far.
        subtopic_progress = []
        if session.outline:
            for subtopic in session.outline.sub_topics:
                subtopic_progress.append({
                    "id": subtopic.id,
                    "topic": subtopic.topic,
                    "status": subtopic.status,
                    "progress": subtopic.get_total_searches() / subtopic.max_searches * 100
                })
        return {
            "session_id": session_id,
            "status": session.status,
            "current_phase": session.current_phase,
            "progress_percentage": session.get_progress_percentage(),
            "subtopic_progress": subtopic_progress,
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "error_message": session.error_message
        }

    def cancel_research(self, session_id: str) -> Dict[str, Any]:
        """Mark a session as cancelled and persist the new state."""
        session = self.get_session(session_id)
        if not session:
            return {"error": "Session not found"}
        session.update_status(ResearchStatus.CANCELLED)
        self._save_session(session)
        return {
            "status": "cancelled",
            "session_id": session_id,
            "message": "研究已取消"
        }

    def get_research_report(self, session_id: str) -> Optional[str]:
        """Return the final report markdown, or None if not yet available."""
        session = self.get_session(session_id)
        if not session:
            return None
        if session.status != ResearchStatus.COMPLETED:
            return None
        if session.final_report:
            return session.final_report
        # Fall back to the report file written by the report generator.
        report_path = os.path.join(Config.REPORTS_DIR, f"{session_id}.md")
        if os.path.exists(report_path):
            with open(report_path, 'r', encoding='utf-8') as f:
                return f.read()
        return None

    def list_sessions(self, limit: int = 20, offset: int = 0) -> List[Dict[str, Any]]:
        """List persisted sessions (newest first) with paging support."""
        sessions = []
        session_files = sorted(
            [f for f in os.listdir(Config.SESSIONS_DIR) if f.endswith('.json')],
            reverse=True  # newest first (ids sort lexically)
        )
        for filename in session_files[offset:offset + limit]:
            filepath = os.path.join(Config.SESSIONS_DIR, filename)
            try:
                session = ResearchSession.load_from_file(filepath)
                sessions.append({
                    "id": session.id,
                    "question": session.question,
                    "status": session.status,
                    "created_at": session.created_at.isoformat(),
                    "progress": session.get_progress_percentage()
                })
            except Exception as e:
                # Include the offending filename so broken files can be located
                # (previously logged a literal "(unknown)" placeholder).
                logger.error(f"加载会话失败 {filename}: {e}")
        return sessions

    def _save_session(self, session: ResearchSession):
        """Serialize a session to its JSON file, converting datetimes to ISO."""
        filepath = self._get_session_filepath(session.id)
        data = session.dict()
        # Convert top-level datetime objects.
        for key in ['created_at', 'updated_at', 'completed_at']:
            if data.get(key):
                data[key] = data[key].isoformat() if hasattr(data[key], 'isoformat') else data[key]
        # Convert the nested outline timestamps as well.
        if data.get('outline'):
            if data['outline'].get('created_at'):
                data['outline']['created_at'] = data['outline']['created_at'].isoformat()
            if data['outline'].get('updated_at'):
                data['outline']['updated_at'] = data['outline']['updated_at'].isoformat()
        with open(filepath, 'w', encoding='utf-8') as f:
            # default=str is a safety net for any remaining non-JSON types.
            json.dump(data, f, ensure_ascii=False, indent=2, default=str)

    def _get_session_filepath(self, session_id: str) -> str:
        """Return the JSON file path for a session id."""
        return os.path.join(Config.SESSIONS_DIR, f"{session_id}.json")

    # ----- Methods below are invoked by the async tasks -----

    def process_question_analysis(self, session_id: str):
        """Phase 1: classify the question, refine it and draft an approach."""
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        session.question_type = self.ai_service.analyze_question_type(session.question)
        session.refined_questions = self.ai_service.refine_questions(
            session.question,
            session.question_type
        )
        session.research_approach = self.ai_service.create_research_approach(
            session.question,
            session.question_type,
            session.refined_questions
        )
        session.current_phase = "制定大纲"
        session.completed_steps += 1
        self._save_session(session)

    def process_outline_creation(self, session_id: str):
        """Phase 2: build the outline and size each sub-topic's search budget."""
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        outline_dict = self.ai_service.create_outline(
            session.question,
            session.question_type,
            session.refined_questions,
            session.research_approach
        )
        subtopics = []
        for st in outline_dict.get('sub_topics', []):
            subtopic = Subtopic(
                topic=st['topic'],
                explain=st['explain'],
                priority=st['priority'],
                related_questions=st.get('related_questions', [])
            )
            # Search budget depends on the sub-topic's priority.
            if subtopic.priority == "high":
                subtopic.max_searches = Config.MAX_SEARCHES_HIGH_PRIORITY
            elif subtopic.priority == "medium":
                subtopic.max_searches = Config.MAX_SEARCHES_MEDIUM_PRIORITY
            else:
                subtopic.max_searches = Config.MAX_SEARCHES_LOW_PRIORITY
            subtopics.append(subtopic)
        session.outline = ResearchOutline(
            main_topic=outline_dict['main_topic'],
            research_questions=outline_dict['research_questions'],
            sub_topics=subtopics
        )
        session.current_phase = "研究子主题"
        session.update_status(ResearchStatus.RESEARCHING)
        # preparation + outline + one step per sub-topic + final report
        session.total_steps = 3 + len(subtopics) + 1
        session.completed_steps = 2
        self._save_session(session)

    def process_subtopic_research(self, session_id: str, subtopic_id: str):
        """Mark one sub-topic as finished and advance overall progress.

        The actual research work happens in research_tasks; this only
        records completion.
        """
        session = self.get_session(session_id)
        if not session or not session.outline:
            raise ValueError(f"Session or outline not found: {session_id}")
        subtopic = None
        for st in session.outline.sub_topics:
            if st.id == subtopic_id:
                subtopic = st
                break
        if not subtopic:
            raise ValueError(f"Subtopic not found: {subtopic_id}")
        subtopic.status = ResearchStatus.COMPLETED
        session.completed_steps += 1
        self._save_session(session)

    def finalize_research(self, session_id: str):
        """Mark the whole session as completed (report writing happens elsewhere)."""
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        session.update_status(ResearchStatus.COMPLETED)
        session.current_phase = "研究完成"
        session.completed_steps = session.total_steps
        self._save_session(session)

View File

@ -0,0 +1,446 @@
# 文件位置: app/tasks/research_tasks.py
# 文件名: research_tasks.py
"""
研究相关的异步任务
使用线程池替代Celery
"""
import logging
from typing import Dict, List, Any
from app.services.task_manager import async_task
from app.models.research import ResearchStatus, Subtopic
logger = logging.getLogger(__name__)
@async_task
def analyze_question_chain(session_id: str):
    """Task-chain entry point: analyze the question, then queue outline creation."""
    try:
        # Imported inside the function to avoid a circular import.
        from app.services.research_manager import ResearchManager
        research_manager = ResearchManager()
        session = research_manager.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")
        # Notify clients that analysis has started.
        _emit_status(session_id, ResearchStatus.ANALYZING, "分析问题")
        # Run the analysis phase (question type, refined questions, approach).
        research_manager.process_question_analysis(session_id)
        _emit_progress(session_id, 20, "问题分析完成")
        # Queue the next stage of the pipeline.
        create_outline_task.delay(session_id)
    except Exception as e:
        logger.error(f"问题分析失败: {e}")
        _handle_task_error(session_id, str(e))
        raise
@async_task
def create_outline_task(session_id: str):
    """Create the research outline, then fan out one task per sub-topic."""
    try:
        # Imported inside the function to avoid a circular import.
        from app.services.research_manager import ResearchManager
        research_manager = ResearchManager()
        _emit_status(session_id, ResearchStatus.OUTLINING, "制定大纲")
        research_manager.process_outline_creation(session_id)
        _emit_progress(session_id, 30, "大纲制定完成")
        # Re-read the session to pick up the freshly created outline.
        session = research_manager.get_session(session_id)
        if session.outline and session.outline.sub_topics:
            # Research every sub-topic concurrently.
            subtopic_task_ids = []
            for st in session.outline.sub_topics:
                task_id = research_subtopic.delay(session_id, st.id)
                subtopic_task_ids.append(task_id)
            # Watchdog: waits for all sub-topics, then builds the final report.
            monitor_subtopics_completion.delay(session_id, subtopic_task_ids)
    except Exception as e:
        logger.error(f"创建大纲失败: {e}")
        _handle_task_error(session_id, str(e))
        raise
@async_task
def research_subtopic(session_id: str, subtopic_id: str):
    """Research a single sub-topic end to end.

    Pipeline: generate queries -> search -> reflect -> refined search ->
    integrate -> write report -> hallucination check. Progress is pushed to
    the client via WebSocket after each stage.

    Returns a summary dict with the query/result counts; raises on failure
    after marking the sub-topic as errored.
    """
    try:
        # Imported inside the function to avoid circular imports.
        from app.services.research_manager import ResearchManager
        from app.services.ai_service import AIService
        from app.services.search_service import SearchService
        research_manager = ResearchManager()
        ai_service = AIService()
        search_service = SearchService()

        session = research_manager.get_session(session_id)
        if not session or not session.outline:
            raise ValueError("Session or outline not found")
        subtopic = None
        for st in session.outline.sub_topics:
            if st.id == subtopic_id:
                subtopic = st
                break
        if not subtopic:
            raise ValueError(f"Subtopic not found: {subtopic_id}")

        subtopic.status = ResearchStatus.RESEARCHING
        research_manager.update_session(session_id, {'outline': session.outline})
        _emit_subtopic_progress(session_id, subtopic_id, 0, "researching")

        # 1. Generate the initial search queries.
        queries = ai_service.generate_search_queries(
            subtopic.topic,
            subtopic.explain,
            subtopic.related_questions,
            subtopic.priority
        )

        # 2. Run the searches and score each result's importance.
        logger.info(f"开始搜索子主题 {subtopic.topic}: {len(queries)} 个查询")
        search_results = []
        for i, query in enumerate(queries):
            try:
                response = search_service.search(query)
                results = response.to_search_results()
                evaluated_results = ai_service.evaluate_search_results(
                    subtopic.topic, results
                )
                search_results.extend(evaluated_results)
                # Searching accounts for the first 50% of sub-topic progress.
                progress = (i + 1) / len(queries) * 50
                _emit_subtopic_progress(session_id, subtopic_id, progress, "searching")
            except Exception as e:
                logger.error(f"搜索失败 '{query}': {e}")

        # Deduplicate by URL (last occurrence wins).
        unique_results = list({r.url: r for r in search_results}.values())
        subtopic.searches = [
            {
                "url": r.url,
                "title": r.title,
                "snippet": r.snippet,
                "importance": r.importance.value if r.importance else "medium"
            }
            for r in unique_results
        ]

        # 3. Reflect on the gathered information to find gaps.
        key_points = ai_service.reflect_on_information(subtopic.topic, unique_results)
        # Keep the refined SearchResult objects so they can be integrated in
        # step 6. (Previously the refined results were lost: the code tried
        # to read a 'results' key from the stored dicts, which never exists.)
        refined_result_objects = []
        if key_points:
            # 4. Generate follow-up queries for each key point.
            refined_queries_map = ai_service.generate_refined_queries(key_points)
            # 5. Run the refined searches. Uses a distinct loop variable so
            # the original `queries` (reported in the return value) is not
            # clobbered.
            for key_info, refined_queries in refined_queries_map.items():
                refined_batch = search_service.refined_search(
                    subtopic_id, key_info, refined_queries
                )
                evaluated_refined = ai_service.evaluate_search_results(
                    subtopic.topic, refined_batch.results
                )
                refined_result_objects.extend(evaluated_refined)
                subtopic.refined_searches.extend([
                    {
                        "key_info": key_info,
                        "url": r.url,
                        "title": r.title,
                        "snippet": r.snippet,
                        "importance": r.importance.value if r.importance else "medium"
                    }
                    for r in evaluated_refined
                ])

        _emit_subtopic_progress(session_id, subtopic_id, 70, "integrating")
        # 6. Integrate the initial and refined results.
        all_results = unique_results + refined_result_objects
        integrated_info = ai_service.integrate_information(subtopic.topic, all_results)
        subtopic.integrated_info = integrated_info

        # 7. Draft the sub-topic report.
        _emit_subtopic_progress(session_id, subtopic_id, 80, "writing")
        report_content = ai_service.write_subtopic_report(subtopic.topic, integrated_info)

        # 8. Detect and fix hallucinations against the source snippets.
        _emit_subtopic_progress(session_id, subtopic_id, 90, "reviewing")
        url_content_map = {}
        for result in all_results:
            url_content_map[result.url] = result.snippet
        fixed_report, hallucinations = ai_service.detect_and_fix_hallucinations(
            report_content, url_content_map
        )
        subtopic.report = fixed_report
        subtopic.hallucination_checks = hallucinations
        subtopic.status = ResearchStatus.COMPLETED

        # Persist and advance the session's overall progress.
        research_manager.update_session(session_id, {'outline': session.outline})
        research_manager.process_subtopic_research(session_id, subtopic_id)
        _emit_subtopic_progress(session_id, subtopic_id, 100, "completed")
        return {
            "subtopic_id": subtopic_id,
            "status": "completed",
            "search_count": len(queries),
            "results_count": len(unique_results),
            "hallucinations_fixed": len(hallucinations)
        }
    except Exception as e:
        logger.error(f"子主题研究失败 {subtopic_id}: {e}")
        # Best effort: flag the sub-topic as failed so the UI can show it.
        try:
            from app.services.research_manager import ResearchManager
            research_manager = ResearchManager()
            session = research_manager.get_session(session_id)
            if session and session.outline:
                for st in session.outline.sub_topics:
                    if st.id == subtopic_id:
                        st.status = ResearchStatus.ERROR
                        break
                research_manager.update_session(session_id, {'outline': session.outline})
            _emit_subtopic_progress(session_id, subtopic_id, -1, "error")
        except Exception:
            pass
        raise
@async_task
def monitor_subtopics_completion(session_id: str, task_ids: List[str]):
    """Poll the sub-topic tasks; once all finish, queue final-report generation.

    Gives up after 30 minutes. The final report is generated as long as at
    least one sub-topic task did not fail.
    """
    import time
    from app.services.task_manager import task_manager
    try:
        max_wait_time = 1800  # 30-minute timeout
        start_time = time.time()
        while True:
            all_completed = True
            failed_count = 0
            for task_id in task_ids:
                status = task_manager.get_task_status(task_id)
                # NOTE(review): an unknown task id (status None) counts as
                # completed — confirm that is the intended behavior.
                if status:
                    if status['status'] == 'running' or status['status'] == 'pending':
                        all_completed = False
                    elif status['status'] == 'failed':
                        failed_count += 1
            if all_completed:
                break
            if time.time() - start_time > max_wait_time:
                logger.error(f"等待子主题完成超时: {session_id}")
                break
            time.sleep(5)  # poll every 5 seconds
        # Generate the final report if at least one sub-topic succeeded.
        if failed_count < len(task_ids):
            generate_final_report_task.delay(session_id)
        else:
            _handle_task_error(session_id, "所有子主题研究失败")
    except Exception as e:
        logger.error(f"监控子主题完成失败: {e}")
        _handle_task_error(session_id, str(e))
@async_task
def generate_final_report_task(session_id: str):
    """Combine all sub-topic reports into the final report and mark completion."""
    try:
        # Imported inside the function to avoid circular imports.
        from app.services.research_manager import ResearchManager
        from app.services.ai_service import AIService
        from app.services.report_generator import ReportGenerator
        research_manager = ResearchManager()
        ai_service = AIService()
        report_generator = ReportGenerator()
        _emit_status(session_id, ResearchStatus.WRITING, "生成最终报告")
        _emit_progress(session_id, 90, "整合所有子主题报告")
        session = research_manager.get_session(session_id)
        if not session or not session.outline:
            raise ValueError("Session or outline not found")
        # Collect every finished sub-topic report: a dict for the AI prompt
        # and report objects for the report generator.
        subtopic_reports_dict = {}
        subtopic_report_objects = []
        for subtopic in session.outline.sub_topics:
            if subtopic.report:
                subtopic_reports_dict[subtopic.topic] = subtopic.report
                report_obj = report_generator.generate_subtopic_report(
                    subtopic,
                    subtopic.integrated_info or {},
                    subtopic.report
                )
                subtopic_report_objects.append(report_obj)
        # Let the AI merge the sub-topic reports into one narrative.
        final_content = ai_service.generate_final_report(
            session.outline.main_topic,
            session.outline.research_questions,
            subtopic_reports_dict
        )
        final_report = report_generator.generate_final_report(
            session,
            subtopic_report_objects,
            final_content
        )
        # Persist the report file and the session's final state.
        report_path = report_generator.save_report(final_report)
        session.final_report = final_report.to_markdown()
        session.update_status(ResearchStatus.COMPLETED)
        research_manager.update_session(session_id, {
            'final_report': session.final_report,
            'status': session.status
        })
        research_manager.finalize_research(session_id)
        # Notify clients that everything is done.
        _emit_progress(session_id, 100, "研究完成")
        _emit_status(session_id, ResearchStatus.COMPLETED, "研究完成")
        _emit_report_ready(session_id, "final")
        logger.info(f"研究完成: {session_id}")
        return {
            "session_id": session_id,
            "status": "completed",
            "report_path": report_path
        }
    except Exception as e:
        logger.error(f"生成最终报告失败: {e}")
        _handle_task_error(session_id, str(e))
        raise
# ========== 辅助函数 ==========
def _get_socketio():
    """Return the app's Socket.IO instance (lazy import avoids a cycle)."""
    from app import socketio
    return socketio
def _emit_progress(session_id: str, percentage: float, message: str):
    """Push an overall-progress update to clients watching this session."""
    try:
        # Lazy import keeps this module free of circular dependencies.
        from app.routes.websocket import emit_progress
        payload = {
            'percentage': percentage,
            'message': message
        }
        emit_progress(_get_socketio(), session_id, payload)
    except Exception as e:
        logger.error(f"发送进度更新失败: {e}")
def _emit_status(session_id: str, status: ResearchStatus, phase: str):
    """Push a session status-change event to clients."""
    try:
        from app.routes.websocket import emit_status_change  # lazy: avoids cycle
        emit_status_change(_get_socketio(), session_id, status.value, phase)
    except Exception as e:
        logger.error(f"发送状态更新失败: {e}")
def _emit_subtopic_progress(session_id: str, subtopic_id: str,
                            progress: float, status: str):
    """Push a per-sub-topic progress event to clients."""
    try:
        from app.routes.websocket import emit_subtopic_progress  # lazy: avoids cycle
        sio = _get_socketio()
        emit_subtopic_progress(sio, session_id, subtopic_id, progress, status)
    except Exception as e:
        logger.error(f"发送子主题进度失败: {e}")
def _emit_report_ready(session_id: str, report_type: str):
    """Notify clients that a report is ready for download."""
    try:
        from app.routes.websocket import emit_report_ready  # lazy: avoids cycle
        emit_report_ready(_get_socketio(), session_id, report_type)
    except Exception as e:
        logger.error(f"发送报告就绪通知失败: {e}")
def _handle_task_error(session_id: str, error_message: str):
    """Record a task failure on the session and notify connected clients."""
    try:
        # Imported inside the function to avoid circular imports.
        from app.services.research_manager import ResearchManager
        from app.routes.websocket import emit_error
        # Persist the error state on the session.
        research_manager = ResearchManager()
        session = research_manager.get_session(session_id)
        if session:
            session.update_status(ResearchStatus.ERROR)
            session.error_message = error_message
            research_manager.update_session(session_id, {
                'status': session.status,
                'error_message': error_message
            })
        # Tell clients about the failure.
        socketio = _get_socketio()
        emit_error(socketio, session_id, error_message)
    except Exception as e:
        logger.error(f"处理任务错误失败: {e}")

View File

@ -0,0 +1,119 @@
"""
搜索结果数据模型
"""
import uuid
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

from pydantic import BaseModel, Field
class SearchImportance(str, Enum):
    """Importance grade assigned to a search result during evaluation."""
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
class SearchResult(BaseModel):
    """A single web search result, optionally annotated with an importance grade."""
    title: str
    url: str
    snippet: str                 # short content excerpt
    score: float = 0.0           # relevance score from the search engine
    published_date: Optional[str] = None
    raw_content: Optional[str] = None  # full page content, when requested
    importance: Optional[SearchImportance] = None  # set during evaluation
    key_findings: List[str] = []

    def __hash__(self):
        """Hash by URL so results can be deduplicated in sets/dicts."""
        return hash(self.url)

    def __eq__(self, other):
        """Two results are equal when they point at the same URL."""
        if isinstance(other, SearchResult):
            return self.url == other.url
        return False
class TavilySearchResponse(BaseModel):
    """Raw response payload returned by the Tavily search API."""
    query: str
    answer: Optional[str] = None
    images: List[str] = []
    results: List[Dict[str, Any]] = []
    response_time: float = 0.0

    def to_search_results(self) -> List[SearchResult]:
        """Convert the raw result dicts into ``SearchResult`` models."""
        return [
            SearchResult(
                title=item.get('title', ''),
                url=item.get('url', ''),
                snippet=item.get('content', ''),
                score=item.get('score', 0.0),
                published_date=item.get('published_date'),
                raw_content=item.get('raw_content')
            )
            for item in self.results
        ]
class SearchBatch(BaseModel):
    """A batch of search results gathered for one sub-topic query round."""
    # Relies on the module-level `import uuid` — it was previously missing,
    # which made every instantiation raise NameError in the default factory.
    search_id: str = Field(default_factory=lambda: f"S{uuid.uuid4().hex[:8]}")
    subtopic_id: str
    query: str  # human-readable description of the batch
    timestamp: datetime = Field(default_factory=datetime.now)
    results: List[SearchResult] = []
    is_refined_search: bool = False         # True for follow-up (refined) rounds
    parent_search_id: Optional[str] = None  # batch this one refines, if any
    detail_type: Optional[str] = None       # key info targeted by a refined round
    total_results: int = 0

    def add_results(self, results: List[SearchResult]):
        """Append results, skipping URLs already present, and refresh the count."""
        existing_urls = {r.url for r in self.results}
        for result in results:
            if result.url not in existing_urls:
                self.results.append(result)
                existing_urls.add(result.url)
        self.total_results = len(self.results)
class SearchSummary(BaseModel):
    """Aggregate statistics over all search batches of a sub-topic."""
    subtopic_id: str
    total_searches: int = 0  # number of batches
    total_results: int = 0   # results across all batches
    high_importance_count: int = 0
    medium_importance_count: int = 0
    low_importance_count: int = 0
    unique_domains: List[str] = []  # distinct result domains

    @classmethod
    def from_search_batches(cls, subtopic_id: str, batches: List[SearchBatch]) -> 'SearchSummary':
        """Build a summary (result/domain/importance counts) from search batches."""
        summary = cls(subtopic_id=subtopic_id)
        summary.total_searches = len(batches)
        all_results = []
        domains = set()
        for batch in batches:
            all_results.extend(batch.results)
            for result in batch.results:
                # urlparse is imported once at module level (it was previously
                # re-imported inside this loop for every single result).
                domain = urlparse(result.url).netloc
                if domain:
                    domains.add(domain)
        summary.total_results = len(all_results)
        summary.unique_domains = list(domains)
        # Tally results by importance grade (unset importance is not counted).
        for result in all_results:
            if result.importance == SearchImportance.HIGH:
                summary.high_importance_count += 1
            elif result.importance == SearchImportance.MEDIUM:
                summary.medium_importance_count += 1
            elif result.importance == SearchImportance.LOW:
                summary.low_importance_count += 1
        return summary

View File

@ -0,0 +1,204 @@
"""
搜索服务
封装Tavily API调用
"""
import logging
from typing import List, Dict, Any, Optional
from tavily import TavilyClient
from app.models.search_result import SearchResult, TavilySearchResponse, SearchBatch
from config import Config
import time
logger = logging.getLogger(__name__)
class SearchService:
"""搜索服务"""
def __init__(self, api_key: str = None):
self.api_key = api_key or Config.TAVILY_API_KEY
self.client = TavilyClient(api_key=self.api_key)
self._search_cache = {} # 简单的搜索缓存
def search(self, query: str, max_results: int = None,
search_depth: str = None, include_answer: bool = None,
include_raw_content: bool = None) -> TavilySearchResponse:
"""执行搜索"""
# 检查缓存
cache_key = f"{query}:{max_results}:{search_depth}"
if cache_key in self._search_cache:
logger.info(f"从缓存返回搜索结果: {query}")
return self._search_cache[cache_key]
# 设置默认值
if max_results is None:
max_results = Config.TAVILY_MAX_RESULTS
if search_depth is None:
search_depth = Config.TAVILY_SEARCH_DEPTH
if include_answer is None:
include_answer = Config.TAVILY_INCLUDE_ANSWER
if include_raw_content is None:
include_raw_content = Config.TAVILY_INCLUDE_RAW_CONTENT
try:
logger.info(f"执行Tavily搜索: {query}")
start_time = time.time()
# 调用Tavily API
response = self.client.search(
query=query,
max_results=max_results,
search_depth=search_depth,
include_answer=include_answer,
include_raw_content=include_raw_content
)
response_time = time.time() - start_time
# 转换为我们的响应模型
tavily_response = TavilySearchResponse(
query=query,
answer=response.get('answer'),
images=response.get('images', []),
results=response.get('results', []),
response_time=response_time
)
# 缓存结果
self._search_cache[cache_key] = tavily_response
logger.info(f"搜索完成,耗时 {response_time:.2f}秒,返回 {len(tavily_response.results)} 条结果")
return tavily_response
except Exception as e:
logger.error(f"Tavily搜索失败: {e}")
# 返回空结果
return TavilySearchResponse(
query=query,
answer=None,
images=[],
results=[],
response_time=0.0
)
def batch_search(self, queries: List[str], max_results_per_query: int = 10) -> List[TavilySearchResponse]:
"""批量搜索"""
results = []
for query in queries:
# 添加延迟以避免速率限制
if results: # 不是第一个查询
time.sleep(0.5) # 500ms延迟
try:
response = self.search(query, max_results=max_results_per_query)
results.append(response)
except Exception as e:
logger.error(f"批量搜索中的查询失败 '{query}': {e}")
# 添加空结果
results.append(TavilySearchResponse(
query=query,
results=[],
response_time=0.0
))
return results
def search_subtopic(self, subtopic_id: str, subtopic_name: str,
queries: List[str]) -> SearchBatch:
"""为子主题执行搜索"""
all_results = []
for query in queries:
response = self.search(query)
search_results = response.to_search_results()
all_results.extend(search_results)
# 创建搜索批次
batch = SearchBatch(
subtopic_id=subtopic_id,
query=f"子主题搜索: {subtopic_name}",
results=[]
)
# 去重并添加结果
batch.add_results(all_results)
return batch
def refined_search(self, subtopic_id: str, key_info: str,
queries: List[str], parent_search_id: str = None) -> SearchBatch:
"""执行细化搜索"""
all_results = []
for query in queries:
response = self.search(query, search_depth="advanced")
search_results = response.to_search_results()
all_results.extend(search_results)
# 创建细化搜索批次
batch = SearchBatch(
subtopic_id=subtopic_id,
query=f"细化搜索: {key_info}",
results=[],
is_refined_search=True,
parent_search_id=parent_search_id,
detail_type=key_info
)
batch.add_results(all_results)
return batch
def extract_content(self, urls: List[str]) -> Dict[str, str]:
    """Fetch full page content for a list of URLs.

    Tries the Tavily ``extract`` endpoint first (capped at 20 URLs per
    call).  If that fails — e.g. the API plan does not include extract —
    falls back to snippet text already held in the local search cache.

    Returns a mapping of url -> content; URLs with no content are omitted.
    """
    content_map: Dict[str, str] = {}
    try:
        # Tavily extract accepts at most 20 URLs per request.
        response = self.client.extract(urls=urls[:20])
        for result in response.get('results', []):
            url = result.get('url')
            content = result.get('raw_content', '')
            if url and content:
                content_map[url] = content
    except Exception as e:
        logger.error(f"提取内容失败: {e}")
        # Fallback: reuse snippet text from cached search results.
        for url in urls:
            for cached_response in self._search_cache.values():
                match = next(
                    (r for r in cached_response.results if r.get('url') == url),
                    None
                )
                if match is not None:
                    content_map[url] = match.get('content', '')
                    # Bug fix: the original `break` only exited the inner
                    # results loop, so later cache batches kept being
                    # scanned and could overwrite the content found here.
                    break
    return content_map
def get_search_statistics(self) -> Dict[str, Any]:
    """Summarize the in-memory search cache.

    Returns the number of cached searches, the total result count across
    all cached responses, the cache size, and the cached query keys.
    """
    cache = self._search_cache
    result_count = 0
    for cached in cache.values():
        result_count += len(cached.results)
    return {
        "total_searches": len(cache),
        "total_results": result_count,
        "cache_size": len(cache),
        "cached_queries": list(cache.keys())
    }
def clear_cache(self):
    """Empty the search cache so subsequent queries hit the API again."""
    cache = self._search_cache
    cache.clear()
    logger.info("搜索缓存已清空")
def test_connection(self) -> bool:
"""测试Tavily API连接"""
try:
response = self.search("test query", max_results=1)
return len(response.results) >= 0
except Exception as e:
logger.error(f"Tavily API连接测试失败: {e}")
return False

417
所有文件/style.css Normal file
View File

@ -0,0 +1,417 @@
/* app/static/css/style.css */
/* Global reset: collapse default margins/padding, predictable box model. */
* {
    margin: 0;
    padding: 0;
    box-sizing: border-box;
}
body {
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
    background-color: #f5f7fa;
    color: #2c3e50;
}
/* Full-viewport-height column layout for the whole app. */
.app {
    min-height: 100vh;
    display: flex;
    flex-direction: column;
}
/* Header */
.header {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 20px;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.header h1 {
    font-size: 24px;
    font-weight: 600;
}
.progress-bar {
    margin-top: 10px;
    background: rgba(255,255,255,0.3);
    height: 8px;
    border-radius: 4px;
    overflow: hidden;
}
/* Fill width is presumably driven from script as progress advances —
   the width transition animates those updates. */
.progress-fill {
    height: 100%;
    background: white;
    transition: width 0.3s ease;
}
.progress-message {
    font-size: 14px;
    opacity: 0.9;
    margin-top: 5px;
}
/* Main Container */
.main-container {
    flex: 1;
    padding: 30px;
}
/* Start Screen */
.start-screen {
    max-width: 800px;
    margin: 0 auto;
}
.start-card {
    background: white;
    border-radius: 12px;
    padding: 40px;
    box-shadow: 0 4px 20px rgba(0,0,0,0.08);
}
.start-card h2 {
    font-size: 28px;
    margin-bottom: 30px;
    text-align: center;
}
.input-group {
    display: flex;
    gap: 10px;
    margin-bottom: 40px;
}
.question-input {
    flex: 1;
    padding: 15px 20px;
    font-size: 16px;
    border: 2px solid #e0e0e0;
    border-radius: 8px;
    transition: border-color 0.3s ease;
}
.question-input:focus {
    outline: none;
    border-color: #667eea;
}
.start-button {
    padding: 15px 30px;
    background: #667eea;
    color: white;
    border: none;
    border-radius: 8px;
    font-size: 16px;
    font-weight: 500;
    cursor: pointer;
    transition: all 0.3s ease;
    display: flex;
    align-items: center;
    gap: 8px;
}
.start-button:hover {
    background: #5a67d8;
    transform: translateY(-1px);
}
.start-button:disabled {
    background: #cbd5e0;
    cursor: not-allowed;
}
/* History Section */
.history-section {
    margin-top: 40px;
    padding-top: 40px;
    border-top: 1px solid #e0e0e0;
}
.session-item {
    background: #f8fafc;
    border: 1px solid #e5e7eb;
    border-radius: 8px;
    padding: 16px;
    margin-bottom: 12px;
    cursor: pointer;
    transition: all 0.3s ease;
}
.session-item:hover {
    border-color: #667eea;
    box-shadow: 0 2px 8px rgba(0,0,0,0.05);
}
/* Tree Container */
.tree-container {
    background: white;
    border-radius: 12px;
    padding: 30px;
    box-shadow: 0 4px 20px rgba(0,0,0,0.08);
    overflow-x: auto;
    min-width: 800px;
}
/* Tree Structure */
.tree-node {
    position: relative;
    padding-left: 30px;
    margin: 10px 0;
}
/* Vertical connector line to the left of each node. */
.tree-node::before {
    content: '';
    position: absolute;
    left: 0;
    top: -10px;
    width: 1px;
    height: calc(100% + 20px);
    background: #e0e0e0;
}
/* Last sibling: stop the vertical line at its own connector. */
.tree-node:last-child::before {
    height: 30px;
}
/* Horizontal stub joining the vertical line to the node card. */
.tree-node::after {
    content: '';
    position: absolute;
    left: 0;
    top: 20px;
    width: 20px;
    height: 1px;
    background: #e0e0e0;
}
/* Root Node */
.root-node {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 30px;
}
.root-node h2 {
    font-size: 20px;
    margin-bottom: 8px;
}
/* Node Card */
.node-card {
    background: white;
    border: 2px solid #e0e0e0;
    border-radius: 8px;
    padding: 15px;
    cursor: pointer;
    transition: all 0.3s ease;
    display: inline-block;
    min-width: 300px;
}
.node-card:hover {
    border-color: #667eea;
    box-shadow: 0 4px 12px rgba(102, 126, 234, 0.15);
}
/* Status modifiers: blue pulse while processing, green when done, red on error. */
.node-card.processing {
    border-color: #3b82f6;
    animation: pulse 2s infinite;
}
.node-card.completed {
    border-color: #10b981;
}
.node-card.error {
    border-color: #ef4444;
}
/* Expanding, fading box-shadow ring used by .node-card.processing. */
@keyframes pulse {
    0% { box-shadow: 0 0 0 0 rgba(59, 130, 246, 0.4); }
    70% { box-shadow: 0 0 0 8px rgba(59, 130, 246, 0); }
    100% { box-shadow: 0 0 0 0 rgba(59, 130, 246, 0); }
}
/* Node Content */
.node-header {
    display: flex;
    align-items: center;
    justify-content: space-between;
}
.node-title {
    font-weight: 600;
    font-size: 16px;
    flex: 1;
}
.node-status {
    display: flex;
    align-items: center;
    gap: 8px;
}
.status-icon {
    width: 24px;
    height: 24px;
    border-radius: 50%;
    display: flex;
    align-items: center;
    justify-content: center;
    font-size: 14px;
}
.status-icon.completed {
    background: #10b981;
    color: white;
}
.status-icon.processing {
    background: #3b82f6;
    color: white;
}
.status-icon.pending {
    background: #9ca3af;
    color: white;
}
.expand-icon {
    transition: transform 0.3s ease;
    margin-right: 8px;
}
.expand-icon.expanded {
    transform: rotate(90deg);
}
/* Collapsed by default; .expanded reveals it via max-height + opacity. */
.node-content {
    margin-top: 15px;
    padding-top: 15px;
    border-top: 1px solid #f0f0f0;
    max-height: 0;
    overflow: hidden;
    opacity: 0;
    transition: all 0.3s ease;
}
.node-content.expanded {
    max-height: 1000px;
    opacity: 1;
}
/* Phase Card */
.phase-card {
    background: #f8fafc;
    border: 1px solid #e5e7eb;
    border-radius: 6px;
    padding: 12px;
    margin: 8px 0;
}
.phase-card h4 {
    font-size: 14px;
    margin-bottom: 8px;
    color: #4b5563;
}
/* Action Buttons */
/* Floating action buttons pinned to the bottom-right corner. */
.action-buttons {
    position: fixed;
    bottom: 30px;
    right: 30px;
    display: flex;
    gap: 12px;
}
.action-button {
    background: white;
    border: 1px solid #e5e7eb;
    padding: 12px 20px;
    border-radius: 8px;
    cursor: pointer;
    font-size: 14px;
    transition: all 0.3s ease;
    box-shadow: 0 2px 8px rgba(0,0,0,0.08);
}
.action-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(0,0,0,0.12);
}
.action-button.primary {
    background: #667eea;
    color: white;
    border-color: #667eea;
}
.action-button.danger {
    background: #ef4444;
    color: white;
    border-color: #ef4444;
}
/* Detail Panel */
/* Hidden off-canvas to the right; .open slides it into view. */
.detail-panel {
    position: fixed;
    right: -400px;
    top: 0;
    width: 400px;
    height: 100vh;
    background: white;
    box-shadow: -4px 0 20px rgba(0,0,0,0.1);
    transition: right 0.3s ease;
    z-index: 1000;
    overflow-y: auto;
}
.detail-panel.open {
    right: 0;
}
.panel-header {
    padding: 20px;
    border-bottom: 1px solid #e5e7eb;
    display: flex;
    justify-content: space-between;
    align-items: center;
}
.panel-close {
    background: none;
    border: none;
    font-size: 24px;
    cursor: pointer;
    color: #6b7280;
}
/* Loading */
/* Full-screen dimmed overlay with a centered spinner. */
.loading-overlay {
    position: fixed;
    top: 0;
    left: 0;
    right: 0;
    bottom: 0;
    background: rgba(0, 0, 0, 0.5);
    display: flex;
    align-items: center;
    justify-content: center;
    z-index: 9999;
}
.loading-spinner {
    width: 50px;
    height: 50px;
    border: 5px solid #f3f3f3;
    border-top: 5px solid #667eea;
    border-radius: 50%;
    animation: spin 1s linear infinite;
}
@keyframes spin {
    0% { transform: rotate(0deg); }
    100% { transform: rotate(360deg); }
}

View File

@ -0,0 +1,202 @@
# 文件位置: app/services/task_manager.py
# 文件名: task_manager.py
"""
任务管理器
替代 Celery 的轻量级任务队列实现
"""
import uuid
import logging
import threading
from concurrent.futures import ThreadPoolExecutor, Future
from typing import Dict, Any, Callable, Optional, List
from datetime import datetime
from enum import Enum
logger = logging.getLogger(__name__)
class TaskStatus(Enum):
    """Lifecycle states of a submitted task."""
    PENDING = "pending"      # queued, not yet started
    RUNNING = "running"      # currently executing in a worker thread
    COMPLETED = "completed"  # finished without raising
    FAILED = "failed"        # raised an exception
    CANCELLED = "cancelled"  # cancelled before it started running
class TaskInfo:
    """Bookkeeping record for one submitted task."""
    def __init__(self, task_id: str, func_name: str, args: tuple, kwargs: dict):
        self.id = task_id
        self.func_name = func_name  # name of the callable, for logging/status
        self.args = args
        self.kwargs = kwargs
        self.status = TaskStatus.PENDING
        self.created_at = datetime.now()
        self.started_at: Optional[datetime] = None    # set when execution begins
        self.completed_at: Optional[datetime] = None  # set on success, failure or cancel
        self.result: Any = None                       # return value on success
        self.error: Optional[str] = None              # stringified exception on failure
        self.future: Optional[Future] = None          # executor handle, used for cancellation
class TaskManager:
    """Singleton task manager backed by a ThreadPoolExecutor.

    Lightweight replacement for Celery: tasks run on worker threads in
    this process, and their metadata (status, timing, result, error) is
    kept in memory, indexed by task id and optionally by session id.
    """
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking: only the first caller creates the instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on *every* TaskManager() call; take the class lock
        # so two concurrent first calls cannot both build an executor.
        with self._lock:
            if not hasattr(self, 'initialized'):
                self.executor = ThreadPoolExecutor(max_workers=10)
                self.tasks: Dict[str, TaskInfo] = {}
                self.session_tasks: Dict[str, List[str]] = {}  # session_id -> task_ids
                # Guards tasks/session_tasks against concurrent mutation
                # from request threads.
                self._state_lock = threading.Lock()
                self.initialized = True
                logger.info("任务管理器初始化完成")

    def submit_task(self, func: Callable, *args, **kwargs) -> str:
        """Queue *func* for background execution and return its task id.

        If the first positional argument looks like a UUID-style string,
        or a ``session_id`` keyword is present, the task is additionally
        indexed under that session so it can be queried/cancelled per
        session.
        """
        task_id = str(uuid.uuid4())
        task_info = TaskInfo(task_id, func.__name__, args, kwargs)
        # Best-effort extraction of the owning session id.
        session_id = None
        if args and isinstance(args[0], str) and '-' in args[0]:
            # Heuristic: assume a dash-containing first string argument is
            # a UUID-formatted session id.
            session_id = args[0]
        elif 'session_id' in kwargs:
            session_id = kwargs['session_id']
        with self._state_lock:
            self.tasks[task_id] = task_info
            if session_id:
                # setdefault avoids the check-then-append race between threads.
                self.session_tasks.setdefault(session_id, []).append(task_id)
        future = self.executor.submit(self._execute_task, task_info, func, *args, **kwargs)
        task_info.future = future
        logger.info(f"任务提交成功: {task_id} - {func.__name__}")
        return task_id

    def _execute_task(self, task_info: TaskInfo, func: Callable, *args, **kwargs):
        """Worker-thread entry: run *func* and record the outcome on *task_info*."""
        try:
            task_info.status = TaskStatus.RUNNING
            task_info.started_at = datetime.now()
            logger.info(f"任务开始执行: {task_info.id} - {task_info.func_name}")
            result = func(*args, **kwargs)
            task_info.status = TaskStatus.COMPLETED
            task_info.completed_at = datetime.now()
            task_info.result = result
            logger.info(f"任务执行成功: {task_info.id}")
            return result
        except Exception as e:
            task_info.status = TaskStatus.FAILED
            task_info.completed_at = datetime.now()
            task_info.error = str(e)
            logger.error(f"任务执行失败: {task_info.id} - {e}")
            # Re-raise so the Future records the failure as well.
            raise

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return a JSON-serializable status dict, or None for unknown ids."""
        task_info = self.tasks.get(task_id)
        if task_info is None:
            return None
        return {
            "task_id": task_info.id,
            "status": task_info.status.value,
            "func_name": task_info.func_name,
            "created_at": task_info.created_at.isoformat(),
            "started_at": task_info.started_at.isoformat() if task_info.started_at else None,
            "completed_at": task_info.completed_at.isoformat() if task_info.completed_at else None,
            "error": task_info.error
        }

    def get_session_tasks(self, session_id: str) -> List[Dict[str, Any]]:
        """Return status dicts for every known task of *session_id*."""
        with self._state_lock:
            # Copy under the lock so iteration is safe if submits race.
            task_ids = list(self.session_tasks.get(session_id, []))
        statuses = (self.get_task_status(task_id) for task_id in task_ids)
        return [s for s in statuses if s]

    def cancel_task(self, task_id: str) -> bool:
        """Try to cancel a task; only tasks not yet running can be cancelled."""
        task_info = self.tasks.get(task_id)
        if task_info is None:
            return False
        if task_info.future and not task_info.future.done():
            # Future.cancel() only succeeds before execution starts.
            if task_info.future.cancel():
                task_info.status = TaskStatus.CANCELLED
                task_info.completed_at = datetime.now()
                logger.info(f"任务已取消: {task_id}")
                return True
        return False

    def cancel_session_tasks(self, session_id: str) -> int:
        """Cancel every cancellable task of a session; returns the count."""
        with self._state_lock:
            task_ids = list(self.session_tasks.get(session_id, []))
        return sum(1 for task_id in task_ids if self.cancel_task(task_id))

    def cleanup_old_tasks(self, hours: int = 24):
        """Remove tasks that finished more than *hours* ago.

        Also drops session index entries that become empty, so the session
        map does not grow without bound.  Returns the number of tasks
        removed.
        """
        cutoff_time = datetime.now().timestamp() - (hours * 3600)
        with self._state_lock:
            tasks_to_remove = [
                task_id for task_id, task_info in self.tasks.items()
                if task_info.completed_at and task_info.completed_at.timestamp() < cutoff_time
            ]
            removed = set(tasks_to_remove)
            for task_id in tasks_to_remove:
                del self.tasks[task_id]
            for session_id in list(self.session_tasks):
                remaining = [t for t in self.session_tasks[session_id] if t not in removed]
                if remaining:
                    self.session_tasks[session_id] = remaining
                else:
                    # Drop empty session entries so the index does not leak.
                    del self.session_tasks[session_id]
        logger.info(f"清理了 {len(tasks_to_remove)} 个旧任务")
        return len(tasks_to_remove)

    def shutdown(self):
        """Block until running tasks finish, then release the thread pool."""
        self.executor.shutdown(wait=True)
        logger.info("任务管理器已关闭")
# Global task-manager instance shared by the whole application.
task_manager = TaskManager()
# Decorator: turn an ordinary function into a background task.
def async_task(func):
    """Decorator that runs the wrapped function as a background task.

    Calling the decorated function submits it to the global task manager
    and returns a task id instead of executing synchronously.  ``.delay()``
    is provided as an alias for Celery compatibility.
    """
    from functools import wraps  # local import keeps module deps unchanged

    @wraps(func)  # copies __name__, __doc__, __module__, sets __wrapped__
    def wrapper(*args, **kwargs):
        return task_manager.submit_task(func, *args, **kwargs)

    wrapper.delay = wrapper  # 兼容Celery的.delay()调用方式
    return wrapper

279
所有文件/test_api_keys.py Executable file
View File

@ -0,0 +1,279 @@
#!/usr/bin/env python3
"""
测试API密钥是否有效
"""
import os
import sys
from dotenv import load_dotenv
# 添加项目根目录到Python路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 加载环境变量
load_dotenv()
def test_deepseek_api():
    """Check the DeepSeek API key and probe both chat models.

    Returns True when the V3 model answers; an R1 failure only produces a
    warning.  Returns False when the key is missing or V3 is unreachable.
    """
    print("\n测试 DeepSeek API...")
    api_key = os.environ.get('DEEPSEEK_API_KEY')
    if not api_key:
        print("❌ 错误: 未设置 DEEPSEEK_API_KEY")
        return False
    try:
        from openai import OpenAI
        base_url = os.environ.get('DEEPSEEK_BASE_URL', 'https://api.deepseek.com/v1')
        # Detect the Volcengine ARK platform, which uses its own model names.
        if 'volces.com' in base_url:
            print(" 检测到火山引擎 ARK 平台")
            r1_model = "deepseek-r1-250120"
            v3_model = "deepseek-v3-241226"
        else:
            r1_model = "deepseek-reasoner"
            v3_model = "deepseek-chat"
        client = OpenAI(
            api_key=api_key,
            base_url=base_url
        )
        # Probe the V3 model first: it is the more stable of the two.
        print(f" 测试 V3 模型 ({v3_model})...")
        try:
            response = client.chat.completions.create(
                model=v3_model,
                messages=[{"role": "user", "content": "Hello, this is a test. Reply with OK."}],
                max_tokens=10
            )
            if response.choices[0].message.content:
                print(" ✓ V3 模型测试成功")
            else:
                print(" ❌ V3 模型响应异常")
                return False
        except Exception as e:
            print(f" ❌ V3 模型测试失败: {e}")
            # On Volcengine a custom endpoint ID may be required instead.
            if 'volces.com' in base_url:
                print(" 提示: 火山引擎可能需要使用自定义的 endpoint ID (如 ep-xxxxx)")
            return False
        # Probe the R1 model; failure here is non-fatal since V3 works.
        print(f" 测试 R1 模型 ({r1_model})...")
        try:
            response = client.chat.completions.create(
                model=r1_model,
                messages=[{"role": "user", "content": "Hello, this is a test. Reply with OK."}],
                max_tokens=10
            )
            if response.choices[0].message.content:
                print(" ✓ R1 模型测试成功")
            else:
                print(" ❌ R1 模型响应异常")
                # R1 failure does not fail the overall check; V3 is enough.
        except Exception as e:
            print(f" ⚠️ R1 模型测试失败: {e}")
            print(" 注意: R1模型可能需要特殊配置或不可用")
        print("✅ DeepSeek API 测试通过")
        return True
    except Exception as e:
        print(f"❌ DeepSeek API 测试失败: {e}")
        return False
def test_tavily_api():
    """Check that the Tavily API key is set and can run a search.

    Returns True when a probe search yields a well-formed response, False
    when the key is missing or the call fails.
    """
    print("\n测试 Tavily API...")
    api_key = os.environ.get('TAVILY_API_KEY')
    if not api_key:
        print("❌ 错误: 未设置 TAVILY_API_KEY")
        return False
    try:
        from tavily import TavilyClient
        client = TavilyClient(api_key=api_key)
        # One-result probe query keeps the check cheap.
        print(" 执行测试搜索...")
        response = client.search("test query", max_results=1)
        well_formed = bool(response) and 'results' in response
        if well_formed:
            print(f" ✓ 搜索返回 {len(response['results'])} 条结果")
            print("✅ Tavily API 测试通过")
            return True
        print(" ❌ 搜索响应异常")
        return False
    except Exception as e:
        print(f"❌ Tavily API 测试失败: {e}")
        return False
def test_task_manager():
    """Smoke-test the in-process task manager with one trivial task."""
    print("\n测试任务管理器...")
    try:
        from app.services.task_manager import task_manager
        import time

        def test_func(x):
            return x * 2

        task_id = task_manager.submit_task(test_func, 5)
        print(f" ✓ 任务提交成功: {task_id}")
        # Give the worker thread a moment to finish.
        time.sleep(1)
        status = task_manager.get_task_status(task_id)
        completed = bool(status) and status['status'] == 'completed'
        if not completed:
            print(f" ❌ 任务状态异常: {status}")
            return False
        print(" ✓ 任务执行成功")
        print("✅ 任务管理器测试通过")
        return True
    except Exception as e:
        print(f"❌ 任务管理器测试失败: {e}")
        return False
def test_mongodb_connection():
    """Optional check: verify MongoDB connectivity and basic read/write."""
    print("\n测试 MongoDB 连接(可选)...")
    mongodb_uri = os.environ.get('MONGODB_URI', 'mongodb://localhost:27017/deepresearch')
    try:
        from pymongo import MongoClient
        client = MongoClient(mongodb_uri, serverSelectionTimeoutMS=5000)
        # Raises when the server is unreachable within the timeout.
        client.server_info()
        print("✅ MongoDB 连接测试通过")
        db = client.get_database()
        test_collection = db['test_collection']
        # Round-trip a throwaway document: insert, read back, delete.
        result = test_collection.insert_one({'test': 'document'})
        doc = test_collection.find_one({'_id': result.inserted_id})
        test_collection.delete_one({'_id': result.inserted_id})
        if doc and doc['test'] == 'document':
            print(" ✓ MongoDB 读写测试通过")
            return True
        print(" ❌ MongoDB 读写测试失败")
        return False
    except Exception as e:
        print(f"⚠️ MongoDB 连接测试失败(可选): {e}")
        print(" 提示: MongoDB是可选的不影响基本功能")
        return False
def check_python_version():
    """Require Python 3.8+; report the detected interpreter version."""
    print("\n检查 Python 版本...")
    v = sys.version_info
    ok = v.major == 3 and v.minor >= 8
    if ok:
        print(f"✅ Python 版本: {v.major}.{v.minor}.{v.micro}")
    else:
        print(f"❌ Python 版本过低: {v.major}.{v.minor}.{v.micro}")
        print(" 需要 Python 3.8 或更高版本")
    return ok
def check_dependencies():
    """Verify that all required third-party packages are importable.

    Distribution names can differ from import names — notably the
    ``python-dotenv`` package is imported as ``dotenv`` — so each entry
    maps the display name to the module that is actually imported.
    (The original ``__import__('python-dotenv')`` always failed because a
    hyphen is not valid in a module name.)
    """
    print("\n检查依赖包...")
    # (display/distribution name, importable module name)
    required_packages = [
        ('flask', 'flask'),
        ('flask_cors', 'flask_cors'),
        ('flask_socketio', 'flask_socketio'),
        ('openai', 'openai'),
        ('tavily', 'tavily'),
        ('pydantic', 'pydantic'),
        ('python-dotenv', 'dotenv'),
    ]
    missing_packages = []
    for display_name, module_name in required_packages:
        try:
            __import__(module_name)
        except ImportError:
            missing_packages.append(display_name)
    if missing_packages:
        print(f"❌ 缺少以下依赖包: {', '.join(missing_packages)}")
        print(" 请运行: pip install -r requirements.txt")
        return False
    else:
        print("✅ 所有必需的依赖包已安装")
        return True
def main():
    """Run every check in order and print an overall summary."""
    print("=" * 60)
    print("DeepResearch API 密钥测试工具")
    print("=" * 60)
    # Required checks: every one runs, any failure flips the result.
    required_checks = [
        check_python_version,
        check_dependencies,
        test_deepseek_api,
        test_tavily_api,
        test_task_manager,
    ]
    all_passed = True
    for check in required_checks:
        if not check():
            all_passed = False
    # MongoDB is optional and does not affect the overall result.
    test_mongodb_connection()
    print("\n" + "=" * 60)
    if all_passed:
        print("✅ 所有必需的测试都已通过!")
        print("\n您可以运行以下命令启动应用:")
        print("1. 启动应用: python app.py")
        print("\n注意: 不再需要启动 Celery Worker 和 Redis")
    else:
        print("❌ 有些测试未通过,请检查上述错误信息")
        print("\n常见问题:")
        print("1. 确保在.env文件中正确设置了API密钥")
        print("2. 检查网络连接是否正常")
# Script entry point: run all checks when executed directly.
if __name__ == '__main__':
    main()

194
所有文件/v3_agent.py Normal file
View File

@ -0,0 +1,194 @@
"""
DeepSeek V3模型智能体
负责API调用内容重写等执行型任务
"""
import json
import logging
from typing import Dict, List, Any, Optional
from openai import OpenAI
from config import Config
from app.agents.prompts import get_prompt
logger = logging.getLogger(__name__)
class V3Agent:
    """DeepSeek V3 model agent.

    Handles execution-style work: direct API calls, search-query
    generation, content rewriting and summarization.
    """
    def __init__(self, api_key: str = None):
        # An explicit key wins; otherwise fall back to the configured one.
        self.api_key = api_key or Config.DEEPSEEK_API_KEY
        base_url = Config.DEEPSEEK_BASE_URL
        # The Volcengine ARK platform uses its own model identifiers.
        if 'volces.com' in base_url:
            self.model = "deepseek-v3-241226"  # Volcengine's V3 model name
        else:
            self.model = Config.V3_MODEL
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=base_url
        )
    def _call_api(self, prompt: str, temperature: float = 0.3,
                  max_tokens: int = 4096, functions: List[Dict] = None) -> Any:
        """Call the V3 chat API with a single user message.

        Returns the stripped text reply, or — when *functions* were
        supplied and the model chose one — a dict of the form
        ``{"function_call": {"name": ..., "arguments": ...}}``.
        Re-raises client errors after logging them.
        """
        try:
            messages = [{"role": "user", "content": prompt}]
            kwargs = {
                "model": self.model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            # Enable function calling only when definitions are provided.
            # NOTE(review): the legacy `functions`/`function_call` API is
            # deprecated in newer OpenAI SDKs in favor of `tools` —
            # confirm the pinned SDK version still supports it.
            if functions:
                kwargs["functions"] = functions
                kwargs["function_call"] = "auto"
            response = self.client.chat.completions.create(**kwargs)
            # If the model chose a function, surface its name + parsed args.
            if functions and response.choices[0].message.function_call:
                return {
                    "function_call": {
                        "name": response.choices[0].message.function_call.name,
                        "arguments": json.loads(response.choices[0].message.function_call.arguments)
                    }
                }
            return response.choices[0].message.content.strip()
        except Exception as e:
            logger.error(f"V3 API调用失败: {e}")
            raise
    def generate_search_queries(self, subtopic: str, explanation: str,
                                related_questions: List[str], count: int) -> List[str]:
        """Generate up to *count* search queries for a subtopic."""
        prompt = get_prompt("generate_search_queries",
                            subtopic=subtopic,
                            explanation=explanation,
                            related_questions=', '.join(related_questions),
                            count=count)
        result = self._call_api(prompt, temperature=0.7)
        # One query per non-empty line of the model's reply.
        queries = [q.strip() for q in result.split('\n') if q.strip()]
        # Strip any leading list numbering (digits, dots, dashes, parens).
        queries = [q.lstrip('0123456789.-) ') for q in queries]
        return queries[:count]
    def generate_refined_queries(self, key_info: str, detail_needed: str) -> List[str]:
        """Generate up to three follow-up queries for a detail needing depth."""
        prompt = get_prompt("generate_refined_queries",
                            key_info=key_info,
                            detail_needed=detail_needed)
        result = self._call_api(prompt, temperature=0.7)
        queries = [q.strip() for q in result.split('\n') if q.strip()]
        queries = [q.lstrip('0123456789.-) ') for q in queries]
        return queries[:3]
    def rewrite_hallucination(self, hallucinated_content: str,
                              original_sources: str) -> str:
        """Rewrite hallucinated content so it is grounded in the sources."""
        prompt = get_prompt("rewrite_hallucination",
                            hallucinated_content=hallucinated_content,
                            original_sources=original_sources)
        # Low temperature: stay close to the source material.
        return self._call_api(prompt, temperature=0.3)
    def call_tavily_search(self, query: str, max_results: int = 10) -> Dict[str, Any]:
        """
        Derive Tavily search parameters via function calling.

        Note: this is a demonstration — the real Tavily call lives in
        search_service.py.  Falls back to default parameters when the
        model does not produce a function call.
        """
        # Function schema describing the Tavily search tool.
        tavily_function = {
            "name": "tavily_search",
            "description": "Search the web using Tavily API",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query"
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Maximum number of results",
                        "default": 10
                    },
                    "search_depth": {
                        "type": "string",
                        "enum": ["basic", "advanced"],
                        "default": "advanced"
                    }
                },
                "required": ["query"]
            }
        }
        prompt = f"Please search for information about: {query}"
        result = self._call_api(prompt, functions=[tavily_function])
        # When the model produced a function call, use its arguments.
        if isinstance(result, dict) and "function_call" in result:
            return result["function_call"]["arguments"]
        # Otherwise fall back to sensible defaults.
        return {
            "query": query,
            "max_results": max_results,
            "search_depth": "advanced"
        }
    def format_search_results(self, results: List[Dict[str, Any]]) -> str:
        """Render search results as a numbered, human-readable text list."""
        formatted = []
        for i, result in enumerate(results, 1):
            formatted.append(f"{i}. 标题: {result.get('title', 'N/A')}")
            formatted.append(f" URL: {result.get('url', 'N/A')}")
            formatted.append(f" 摘要: {result.get('snippet', 'N/A')}")
            if result.get('score'):
                formatted.append(f" 相关度: {result.get('score', 0):.2f}")
            formatted.append("")
        return '\n'.join(formatted)
    def extract_key_points(self, text: str, max_points: int = 5) -> List[str]:
        """Extract up to *max_points* key points from *text*, one per line."""
        prompt = f"""
请从以下文本中提取最多{max_points}个关键点
{text}
每个关键点独占一行简洁明了
"""
        result = self._call_api(prompt, temperature=0.5)
        points = [p.strip() for p in result.split('\n') if p.strip()]
        points = [p.lstrip('0123456789.-) ') for p in points]
        return points[:max_points]
    def summarize_content(self, content: str, max_length: int = 200) -> str:
        """Summarize *content* to at most *max_length* characters."""
        prompt = f"""
请将以下内容总结为不超过{max_length}字的摘要
{content}
要求保留关键信息语言流畅
"""
        return self._call_api(prompt, temperature=0.5)

131
所有文件/validators.py Normal file
View File

@ -0,0 +1,131 @@
"""
输入验证工具
"""
import re
from typing import Optional
def validate_question(question: str) -> Optional[str]:
    """Validate a user question; return an error message, or None if valid."""
    if not question:
        return "问题不能为空"
    length = len(question)
    if length < 5:
        return "问题太短,请提供更详细的描述"
    if length > 1000:
        return "问题太长请精简到1000字以内"
    # Must contain at least one Latin letter or CJK character,
    # not just punctuation or digits.
    has_content = re.search(r'[a-zA-Z\u4e00-\u9fa5]+', question) is not None
    return None if has_content else "请输入有效的问题内容"
def validate_outline_feedback(feedback: str) -> Optional[str]:
    """Validate outline feedback text; None means it is acceptable."""
    if not feedback:
        return "反馈内容不能为空"
    n = len(feedback)
    if n < 10:
        return "请提供更详细的修改建议"
    if n > 500:
        return "反馈内容请控制在500字以内"
    return None
def validate_session_id(session_id: str) -> Optional[str]:
    """Validate that *session_id* is a canonical 8-4-4-4-12 UUID string."""
    if not session_id:
        return "会话ID不能为空"
    # Hex groups separated by dashes, case-insensitive.
    pattern = re.compile(
        r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$',
        re.IGNORECASE
    )
    return None if pattern.match(session_id) else "无效的会话ID格式"
def validate_subtopic_id(subtopic_id: str) -> Optional[str]:
    """Validate a subtopic id of the form 'ST' + 8 hex digits."""
    if not subtopic_id:
        return "子主题ID不能为空"
    if re.match(r'^ST[0-9a-f]{8}$', subtopic_id, re.IGNORECASE):
        return None
    return "无效的子主题ID格式"
def validate_search_query(query: str) -> Optional[str]:
    """Validate a search query for length and for injection/XSS markers."""
    if not query:
        return "搜索查询不能为空"
    if len(query) < 2:
        return "搜索查询太短"
    if len(query) > 200:
        return "搜索查询太长请控制在200字符以内"
    # Reject obvious XSS / SQL-injection fragments.
    dangerous_patterns = (
        r'<script',
        r'javascript:',
        r'onerror=',
        r'onclick=',
        r'DROP TABLE',
        r'DELETE FROM',
        r'INSERT INTO',
    )
    if any(re.search(p, query, re.IGNORECASE) for p in dangerous_patterns):
        return "搜索查询包含不允许的内容"
    return None
def validate_priority(priority: str) -> Optional[str]:
    """Check that *priority* is one of the supported levels."""
    valid_priorities = ['high', 'medium', 'low']
    if priority in valid_priorities:
        return None
    return f"优先级必须是以下之一: {', '.join(valid_priorities)}"
def validate_report_format(format: str) -> Optional[str]:
    """Check that *format* is a supported report output format."""
    valid_formats = ['json', 'markdown', 'html', 'pdf']
    if format in valid_formats:
        return None
    return f"报告格式必须是以下之一: {', '.join(valid_formats)}"
def sanitize_filename(filename: str) -> str:
    """Make *filename* safe for the filesystem.

    Replaces characters illegal on common filesystems with '_', caps the
    length at 200 characters, and de-hides names starting with a dot.
    """
    cleaned = re.sub(r'[<>:"/\\|?*]', '_', filename)[:200]
    if cleaned.startswith('.'):
        cleaned = '_' + cleaned[1:]
    return cleaned
def validate_json_structure(data: dict, required_fields: list) -> Optional[str]:
    """Return an error naming the first missing required field, else None."""
    _missing = object()  # sentinel so a literal None field name still works
    first_absent = next((f for f in required_fields if f not in data), _missing)
    if first_absent is _missing:
        return None
    return f"缺少必要字段: {first_absent}"

141
所有文件/websocket.py Normal file
View File

@ -0,0 +1,141 @@
"""
WebSocket事件处理
实时推送研究进度
"""
from flask_socketio import emit, join_room, leave_room
from flask import request
import logging
logger = logging.getLogger(__name__)
def register_handlers(socketio):
    """Register all WebSocket event handlers on the given SocketIO server.

    Rooms are keyed by research-session id so that progress events reach
    every client watching the same session.
    """
    @socketio.on('connect')
    def handle_connect():
        """A client connected; acknowledge with its socket id."""
        client_id = request.sid
        logger.info(f"客户端连接: {client_id}")
        emit('connected', {'message': '连接成功', 'client_id': client_id})
    @socketio.on('disconnect')
    def handle_disconnect():
        """A client disconnected; just log it."""
        client_id = request.sid
        logger.info(f"客户端断开: {client_id}")
    @socketio.on('join_session')
    def handle_join_session(data):
        """Add the client to the room for a research session."""
        session_id = data.get('session_id')
        if session_id:
            join_room(session_id)
            logger.info(f"客户端 {request.sid} 加入房间 {session_id}")
            emit('joined', {'session_id': session_id, 'message': '已加入研究会话'})
    @socketio.on('leave_session')
    def handle_leave_session(data):
        """Remove the client from the room for a research session."""
        session_id = data.get('session_id')
        if session_id:
            leave_room(session_id)
            logger.info(f"客户端 {request.sid} 离开房间 {session_id}")
            emit('left', {'session_id': session_id, 'message': '已离开研究会话'})
    # The handlers below rebroadcast incoming events into the session room
    # (intended to be triggered by tasks).
    # NOTE(review): they are registered as *incoming* socket events, so any
    # connected client could emit them too — confirm clients cannot spoof
    # progress/status updates this way.
    @socketio.on('research_progress')
    def broadcast_progress(data):
        """Rebroadcast a research-progress event to the session room."""
        session_id = data.get('session_id')
        if session_id:
            socketio.emit('progress_update', data, room=session_id)
    @socketio.on('research_status_change')
    def broadcast_status_change(data):
        """Rebroadcast a status-change event to the session room."""
        session_id = data.get('session_id')
        if session_id:
            socketio.emit('status_changed', data, room=session_id)
    @socketio.on('subtopic_update')
    def broadcast_subtopic_update(data):
        """Rebroadcast a subtopic update to the session room."""
        session_id = data.get('session_id')
        if session_id:
            socketio.emit('subtopic_updated', data, room=session_id)
    @socketio.on('search_result')
    def broadcast_search_result(data):
        """Rebroadcast a new search result to the session room."""
        session_id = data.get('session_id')
        if session_id:
            socketio.emit('new_search_result', data, room=session_id)
    @socketio.on('report_ready')
    def broadcast_report_ready(data):
        """Rebroadcast a report-ready notification to the session room."""
        session_id = data.get('session_id')
        if session_id:
            socketio.emit('report_available', data, room=session_id)
    @socketio.on('error_occurred')
    def broadcast_error(data):
        """Rebroadcast an error notification to the session room."""
        session_id = data.get('session_id')
        if session_id:
            socketio.emit('research_error', data, room=session_id)
def emit_progress(socketio, session_id: str, progress_data: dict):
    """Push a progress update into the session's room (called from tasks)."""
    payload = {'session_id': session_id}
    payload.update(progress_data)
    socketio.emit('progress_update', payload, room=session_id)
def emit_status_change(socketio, session_id: str, status: str, phase: str = None):
    """Broadcast a status change, optionally tagged with the current phase."""
    payload = {'session_id': session_id, 'status': status}
    if phase:
        payload['phase'] = phase
    socketio.emit('status_changed', payload, room=session_id)
def emit_subtopic_progress(socketio, session_id: str, subtopic_id: str,
                           progress: float, status: str):
    """Push one subtopic's progress/status to the session room."""
    payload = {
        'session_id': session_id,
        'subtopic_id': subtopic_id,
        'progress': progress,
        'status': status,
    }
    socketio.emit('subtopic_updated', payload, room=session_id)
def emit_search_complete(socketio, session_id: str, subtopic_id: str,
                         search_count: int, results_count: int):
    """Notify the session room that one subtopic's searches finished."""
    payload = {
        'session_id': session_id,
        'subtopic_id': subtopic_id,
        'search_count': search_count,
        'results_count': results_count,
    }
    socketio.emit('search_completed', payload, room=session_id)
def emit_report_ready(socketio, session_id: str, report_type: str):
    """Notify the session room that a report has been generated."""
    payload = {
        'session_id': session_id,
        'report_type': report_type,
        'message': f'{report_type}报告已生成',
    }
    socketio.emit('report_available', payload, room=session_id)
def emit_error(socketio, session_id: str, error_message: str, error_type: str = 'general'):
    """Push an error notification to the session room (called from tasks)."""
    payload = {
        'session_id': session_id,
        'error_type': error_type,
        'error_message': error_message,
    }
    socketio.emit('research_error', payload, room=session_id)