# deepresearch/app/models/research.py
# 2025-07-02 15:35:36 +08:00
#
# 126 lines
# 4.1 KiB
# Python

"""
研究会话数据模型
"""
import json
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field
class QuestionType(str, Enum):
    """Category of a research question.

    Members are ``str`` subclasses, so each compares equal to its string value
    (e.g. ``QuestionType.FACTUAL == "factual"``).
    """

    FACTUAL = "factual"          # fact lookup
    COMPARATIVE = "comparative"  # analysis / comparison
    EXPLORATORY = "exploratory"  # open-ended exploration / discovery
    DECISION = "decision"        # decision support
class ResearchStatus(str, Enum):
    """Lifecycle state of a research session or sub-topic.

    Members are ``str`` subclasses and compare equal to their string values.
    """

    # Pipeline phases
    PENDING = "pending"
    ANALYZING = "analyzing"
    OUTLINING = "outlining"
    RESEARCHING = "researching"
    WRITING = "writing"
    REVIEWING = "reviewing"
    # End states
    COMPLETED = "completed"
    ERROR = "error"
    CANCELLED = "cancelled"
class SubtopicPriority(str, Enum):
    """Priority level assigned to a sub-topic (string-valued enum)."""

    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
class Subtopic(BaseModel):
    """A single sub-topic of a research outline, with its searches and outputs."""

    # Identifier of the form "ST" + first 8 hex digits of a random UUID4.
    id: str = Field(default_factory=lambda: "ST" + uuid.uuid4().hex[:8])
    topic: str
    explain: str
    priority: SubtopicPriority
    related_questions: List[str] = []
    status: ResearchStatus = ResearchStatus.PENDING
    # Counter and limit are stored fields; this class does not enforce the limit.
    search_count: int = 0
    max_searches: int = 15
    searches: List[Dict[str, Any]] = []
    refined_searches: List[Dict[str, Any]] = []
    integrated_info: Optional[Dict[str, Any]] = None
    report: Optional[str] = None
    hallucination_checks: List[Dict[str, Any]] = []

    def get_total_searches(self) -> int:
        """Return the number of recorded searches (``searches`` + ``refined_searches``).

        Note: this counts list entries; it does not read ``search_count``.
        """
        return sum(map(len, (self.searches, self.refined_searches)))
class ResearchOutline(BaseModel):
    """Outline of a research task: main topic, guiding questions and sub-topics."""

    main_topic: str
    research_questions: List[str]
    sub_topics: List[Subtopic]
    # Starts at 1; nothing in this class increments it.
    version: int = 1
    # Both default to the (naive, local) time at construction via datetime.now.
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
class ResearchSession(BaseModel):
    """One research session: the question, its outline, progress and final report.

    Sessions can be persisted as UTF-8 JSON with ``save_to_file`` and restored
    with ``load_from_file``.
    """

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    question: str
    question_type: Optional[QuestionType] = None
    refined_questions: List[str] = []
    research_approach: Optional[str] = None
    status: ResearchStatus = ResearchStatus.PENDING
    outline: Optional[ResearchOutline] = None
    final_report: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
    completed_at: Optional[datetime] = None
    error_message: Optional[str] = None
    # Progress tracking
    total_steps: int = 0
    completed_steps: int = 0
    current_phase: str = "初始化"

    def update_status(self, status: ResearchStatus):
        """Set ``status`` and refresh ``updated_at``.

        When ``status`` is ``COMPLETED``, ``completed_at`` is set to the same
        timestamp as ``updated_at``. It is not cleared on other transitions.
        """
        # Take one timestamp so updated_at and completed_at agree exactly.
        now = datetime.now()
        self.status = status
        self.updated_at = now
        if status == ResearchStatus.COMPLETED:
            self.completed_at = now

    def get_progress_percentage(self) -> float:
        """Return ``completed_steps / total_steps * 100``, or 0.0 when ``total_steps`` is 0."""
        if self.total_steps == 0:
            return 0.0
        return (self.completed_steps / self.total_steps) * 100

    def to_dict(self) -> Dict[str, Any]:
        """Return a dict of all fields that ``json.dump`` can serialize.

        Every ``datetime`` in the dump is replaced by its ISO-8601 string. This
        includes nested ones such as ``outline.created_at`` / ``updated_at``,
        which previously stayed ``datetime`` and made ``save_to_file`` fail.
        """
        def _jsonable(value: Any) -> Any:
            # Walk dicts/lists recursively so nested model dumps are covered.
            if isinstance(value, datetime):
                return value.isoformat()
            if isinstance(value, dict):
                return {k: _jsonable(v) for k, v in value.items()}
            if isinstance(value, list):
                return [_jsonable(v) for v in value]
            return value

        # Prefer pydantic v2's model_dump (avoids the .dict() deprecation
        # warning); fall back to .dict() on pydantic v1.
        dump = getattr(self, "model_dump", None)
        data = dump() if callable(dump) else self.dict()
        return _jsonable(data)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ResearchSession':
        """Build a session from a dict such as the one returned by ``to_dict``.

        ISO-8601 strings in the top-level ``created_at`` / ``updated_at`` /
        ``completed_at`` keys are parsed with ``datetime.fromisoformat``. All
        other values are passed to the model constructor unchanged. The
        caller's dict is not modified.
        """
        data = dict(data)  # shallow copy: don't mutate the caller's mapping
        for key in ('created_at', 'updated_at', 'completed_at'):
            value = data.get(key)
            if value and isinstance(value, str):
                data[key] = datetime.fromisoformat(value)
        return cls(**data)

    def save_to_file(self, filepath: str):
        """Write the session to ``filepath`` as UTF-8 JSON (indent=2, non-ASCII kept)."""
        # Serialize before opening the file. If serialization raises, an
        # existing file at ``filepath`` is left intact instead of being truncated.
        payload = json.dumps(self.to_dict(), ensure_ascii=False, indent=2)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)

    @classmethod
    def load_from_file(cls, filepath: str) -> 'ResearchSession':
        """Load a session previously written by ``save_to_file``."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)