"""
|
|
搜索结果数据模型
|
|
"""
|
|
from datetime import datetime
|
|
from typing import List, Optional, Dict, Any
|
|
from enum import Enum
|
|
from pydantic import BaseModel, Field
|
|
|
|
class SearchImportance(str, Enum):
    """Importance level of a search result."""

    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"

class SearchResult(BaseModel):
    """A single search result."""

    title: str
    url: str
    snippet: str
    score: float = 0.0
    published_date: Optional[str] = None
    raw_content: Optional[str] = None
    importance: Optional[SearchImportance] = None
    key_findings: List[str] = []

    def __hash__(self):
        """Hash by URL so results can be deduplicated in sets."""
        return hash(self.url)

    def __eq__(self, other):
        """Treat two results as equal when they share the same URL."""
        if isinstance(other, SearchResult):
            return self.url == other.url
        return False

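# A quick sketch (illustrative URLs, not API output) of why __hash__ and
# __eq__ are URL-based: a set of SearchResult objects deduplicates by URL
# even when the other fields differ.
#
#     a = SearchResult(title="A", url="https://example.com/x", snippet="first")
#     b = SearchResult(title="B", url="https://example.com/x", snippet="second")
#     assert len({a, b}) == 1
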
class TavilySearchResponse(BaseModel):
    """Response model for the Tavily API."""

    query: str
    answer: Optional[str] = None
    images: List[str] = []
    results: List[Dict[str, Any]] = []
    response_time: float = 0.0

    def to_search_results(self) -> List[SearchResult]:
        """Convert the raw result dicts into a list of SearchResult objects."""
        search_results = []
        for result in self.results:
            search_results.append(SearchResult(
                title=result.get('title', ''),
                url=result.get('url', ''),
                snippet=result.get('content', ''),
                score=result.get('score', 0.0),
                published_date=result.get('published_date'),
                raw_content=result.get('raw_content'),
            ))
        return search_results

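# A sketch of the conversion, assuming the raw dicts use Tavily-style keys
# ("title", "url", "content", "score"):
#
#     resp = TavilySearchResponse(
#         query="pydantic models",
#         results=[{"title": "Docs", "url": "https://docs.pydantic.dev", "content": "Intro"}],
#     )
#     assert resp.to_search_results()[0].snippet == "Intro"
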
class SearchBatch(BaseModel):
    """A batch of search results for a single query."""

    search_id: str = Field(default_factory=lambda: f"S{uuid.uuid4().hex[:8]}")
    subtopic_id: str
    query: str
    timestamp: datetime = Field(default_factory=datetime.now)
    results: List[SearchResult] = []
    is_refined_search: bool = False
    parent_search_id: Optional[str] = None
    detail_type: Optional[str] = None
    total_results: int = 0

    def add_results(self, results: List[SearchResult]):
        """Append results, deduplicating by URL."""
        existing_urls = {r.url for r in self.results}
        for result in results:
            if result.url not in existing_urls:
                self.results.append(result)
                existing_urls.add(result.url)
        self.total_results = len(self.results)

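# A sketch of URL-based deduplication in add_results (IDs and URLs are
# illustrative):
#
#     batch = SearchBatch(subtopic_id="T1", query="example query")
#     r = SearchResult(title="A", url="https://example.com/a", snippet="...")
#     batch.add_results([r, r])  # the duplicate URL is dropped
#     assert batch.total_results == 1
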
class SearchSummary(BaseModel):
    """Aggregate statistics over search batches."""

    subtopic_id: str
    total_searches: int = 0
    total_results: int = 0
    high_importance_count: int = 0
    medium_importance_count: int = 0
    low_importance_count: int = 0
    unique_domains: List[str] = []

    @classmethod
    def from_search_batches(cls, subtopic_id: str, batches: List[SearchBatch]) -> 'SearchSummary':
        """Build a summary from a list of search batches."""
        summary = cls(subtopic_id=subtopic_id)
        summary.total_searches = len(batches)

        all_results = []
        domains = set()

        for batch in batches:
            all_results.extend(batch.results)
            for result in batch.results:
                # Extract the domain for the unique-domain list.
                domain = urlparse(result.url).netloc
                if domain:
                    domains.add(domain)

        summary.total_results = len(all_results)
        summary.unique_domains = list(domains)

        # Tally results by importance level.
        for result in all_results:
            if result.importance == SearchImportance.HIGH:
                summary.high_importance_count += 1
            elif result.importance == SearchImportance.MEDIUM:
                summary.medium_importance_count += 1
            elif result.importance == SearchImportance.LOW:
                summary.low_importance_count += 1

        return summary
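

if __name__ == "__main__":
    # A minimal end-to-end self-check wiring the models together. The payload
    # below is a hand-written sketch of a Tavily-style response, not real
    # API output.
    response = TavilySearchResponse(
        query="vector databases",
        results=[
            {"title": "Intro", "url": "https://example.com/intro",
             "content": "Overview.", "score": 0.9},
            {"title": "Deep dive", "url": "https://example.org/deep",
             "content": "Details.", "score": 0.7},
        ],
        response_time=0.42,
    )

    batch = SearchBatch(subtopic_id="T1", query=response.query)
    batch.add_results(response.to_search_results())
    assert batch.total_results == 2

    # Mark one result so the summary has something to tally.
    batch.results[0].importance = SearchImportance.HIGH

    summary = SearchSummary.from_search_batches("T1", [batch])
    assert summary.high_importance_count == 1
    print(summary)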