agent-Specialization/sub_agent/modules/search_engine.py

493 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# modules/search_engine.py - 网络搜索模块
import httpx
import json
from typing import Dict, Optional, Any
from datetime import datetime
import re
try:
    from config import TAVILY_API_KEY, SEARCH_MAX_RESULTS, OUTPUT_FORMATS
except ImportError:
    # Fallback when the module is run/imported without the project root on
    # sys.path: add the directory two levels above this file (the parent of
    # `modules/`) and retry the import.
    import sys
    from pathlib import Path

    project_root = Path(__file__).resolve().parents[1]
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
    from config import TAVILY_API_KEY, SEARCH_MAX_RESULTS, OUTPUT_FORMATS
class SearchEngine:
    """Async client for the Tavily web-search API.

    Validates the caller's search filters, posts the request to Tavily and
    normalizes the response into plain dicts. The search methods report
    expected failures (missing API key, invalid filters, HTTP errors,
    timeouts) as ``{"success": False, "error": ...}`` dicts instead of
    raising.
    """

    def __init__(self):
        """Read the API key from config and prepare the validation tables."""
        self.api_key = TAVILY_API_KEY
        self.api_url = "https://api.tavily.com/search"
        # Topics this client accepts for the ``topic`` request field.
        self._valid_topics = {"general", "news", "finance"}
        # Accepted time_range spellings (full word or one-letter abbreviation)
        # mapped to the canonical value sent in the request.
        self._valid_time_ranges = {
            "day": "day",
            "d": "day",
            "week": "week",
            "w": "week",
            "month": "month",
            "m": "month",
            "year": "year",
            "y": "year"
        }
        # Strict ASCII YYYY-MM-DD. It is applied with fullmatch(), so a
        # trailing newline or non-ASCII digits are rejected (plain `\d` with
        # `match` + `$` would let both through).
        self._date_pattern = re.compile(r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$")

    @staticmethod
    def _error(message: str) -> Dict[str, Any]:
        """Build the uniform failure dict returned by the search methods."""
        return {
            "success": False,
            "error": message,
            "results": []
        }

    @staticmethod
    def _preview(text: str, limit: int) -> str:
        """Return ``text`` cut to ``limit`` chars, adding "..." only if cut."""
        if len(text) > limit:
            return f"{text[:limit]}..."
        return text

    async def search(
        self,
        query: str,
        max_results: Optional[int] = None,
        topic: Optional[str] = None,
        time_range: Optional[str] = None,
        days: Optional[int] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        country: Optional[str] = None
    ) -> Dict:
        """Run a web search through the Tavily API.

        Args:
            query: Search keywords.
            max_results: Maximum number of results; falsy values fall back
                to ``SEARCH_MAX_RESULTS``.
            topic: Search category: general / news / finance
                (defaults to general).
            time_range: Relative window: day/week/month/year or d/w/m/y.
            days: Look back N days; only valid with ``topic="news"``.
            start_date: Range start, ``YYYY-MM-DD``.
            end_date: Range end, ``YYYY-MM-DD``.
            country: Country filter; only valid with ``topic="general"``.

        At most one of ``time_range``, ``days`` or ``start_date`` +
        ``end_date`` may be given.

        Returns:
            On success, the dict produced by ``_format_results``; otherwise
            ``{"success": False, "error": <message>, "results": []}``.
        """
        if not self.api_key or self.api_key == "your-tavily-api-key":
            return self._error("Tavily API密钥未配置")

        validation = self._build_payload(
            query=query,
            max_results=max_results,
            topic=topic,
            time_range=time_range,
            days=days,
            start_date=start_date,
            end_date=end_date,
            country=country
        )
        if not validation["success"]:
            return validation
        payload = validation["payload"]
        applied_filters = validation["filters"]

        print(f"{OUTPUT_FORMATS['search']} 搜索: {query}")

        try:
            async with httpx.AsyncClient(timeout=30) as client:
                response = await client.post(
                    self.api_url,
                    json=payload,
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json"
                    }
                )
                if response.status_code != 200:
                    return self._error(f"API请求失败: {response.status_code}")
                data = response.json()
                formatted_results = self._format_results(data, applied_filters)
                print(f"{OUTPUT_FORMATS['success']} 搜索完成,找到 {len(formatted_results['results'])} 条结果")
                return formatted_results
        except httpx.TimeoutException:
            return self._error("搜索超时")
        except Exception as e:
            # Top-level boundary: transport errors, undecodable JSON, etc. are
            # reported to the caller as an error dict rather than raised.
            return self._error(f"搜索失败: {str(e)}")

    def _format_results(self, raw_data: Dict, filters: Dict[str, Any]) -> Dict:
        """Convert a raw Tavily response into this module's result dict.

        Missing *or null* fields are replaced with defaults, so downstream
        string handling (e.g. ``len(content)`` in ``search_with_summary``)
        never receives ``None``.

        Args:
            raw_data: Decoded JSON body of the Tavily response.
            filters: The validated filters that were applied to the request.

        Returns:
            Dict with ``success``, ``query``, ``answer``, ``results`` (each
            entry numbered from 1), ``timestamp``, ``filters`` and
            ``total_results``.
        """
        raw_results = raw_data.get("results") or []
        formatted = {
            "success": True,
            "query": raw_data.get("query") or "",
            "answer": raw_data.get("answer") or "",
            "results": [],
            "timestamp": datetime.now().isoformat(),
            "filters": filters,
            "total_results": len(raw_results)
        }
        for idx, result in enumerate(raw_results, 1):
            score = result.get("score")
            formatted["results"].append({
                "index": idx,
                "title": result.get("title") or "无标题",
                "url": result.get("url") or "",
                "content": result.get("content") or "",
                "score": score if score is not None else 0,
                "published_date": result.get("published_date") or ""
            })
        return formatted

    async def search_with_summary(
        self,
        query: str,
        max_results: Optional[int] = None,
        topic: Optional[str] = None,
        time_range: Optional[str] = None,
        days: Optional[int] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        country: Optional[str] = None
    ) -> Dict[str, Any]:
        """Search and render the results as a human-readable text summary.

        Accepts the same arguments as ``search``.

        Returns:
            On success: ``{"success": True, "summary": <text>, "filters",
            "query", "results", "total_results"}``.
            On failure: ``{"success": False, "error": <message>,
            "summary": ""}``.
        """
        results = await self.search(
            query=query,
            max_results=max_results,
            topic=topic,
            time_range=time_range,
            days=days,
            start_date=start_date,
            end_date=end_date,
            country=country
        )
        if not results["success"]:
            return {
                "success": False,
                "error": results.get("error", "未知错误"),
                "summary": ""
            }

        summary_lines = [
            f"🔍 搜索查询: {query}",
            f"📅 搜索时间: {results['timestamp']}"
        ]
        filter_notes = self._summarize_filters(results.get("filters", {}))
        if filter_notes:
            summary_lines.append(filter_notes)
        summary_lines.append("")

        # Include the AI-generated answer when the API returned one.
        if results.get("answer"):
            summary_lines.extend([
                "📝 AI摘要:",
                results["answer"],
                "",
                "---",
                ""
            ])

        if results["results"]:
            summary_lines.append("📊 搜索结果:")
            for result in results["results"]:
                content = result.get("content") or ""
                summary_lines.extend([
                    f"\n{result['index']}. {result['title']}",
                    f" 🔗 {result['url']}",
                    f" 📄 {self._preview(content, 200)}",
                ])
                if result.get("published_date"):
                    summary_lines.append(f" 📅 发布时间: {result['published_date']}")
        else:
            summary_lines.append("未找到相关结果")

        return {
            "success": True,
            "summary": "\n".join(summary_lines),
            "filters": results.get("filters", {}),
            "query": results.get("query", query),
            "results": results.get("results", []),
            "total_results": results.get("total_results", len(results.get("results", [])))
        }

    async def quick_answer(self, query: str) -> str:
        """Return only a short answer for ``query``.

        Prefers the AI answer; otherwise falls back to the first result's
        title plus up to 300 chars of its content ("..." is appended only
        when the content was actually truncated).

        Returns:
            The answer text, or an error / not-found message.
        """
        results = await self.search(query, max_results=5)
        if not results["success"]:
            return f"搜索失败: {results['error']}"
        if results.get("answer"):
            return results["answer"]
        if results["results"]:
            first_result = results["results"][0]
            snippet = self._preview(first_result.get("content") or "", 300)
            return f"{first_result['title']}\n{snippet}"
        return "未找到相关信息"

    def save_results(self, results: Dict, filename: Optional[str] = None) -> str:
        """Write search results as pretty-printed UTF-8 JSON.

        Args:
            results: The results dict to save.
            filename: Target file name under ``./data/searches``; defaults to
                a timestamped ``search_YYYYmmdd_HHMMSS.json``.

        Returns:
            The path the file was written to. I/O errors propagate.
        """
        import os

        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"search_{timestamp}.json"
        # NOTE(review): filename is interpolated unchecked, so a value like
        # "../x.json" escapes ./data/searches — sanitize if it can come from
        # untrusted input.
        file_path = f"./data/searches/{filename}"
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        print(f"{OUTPUT_FORMATS['file']} 搜索结果已保存到: {file_path}")
        return file_path

    def load_results(self, filename: str) -> Optional[Dict]:
        """Load previously saved results from ``./data/searches/<filename>``.

        Best effort: a missing or unreadable file prints a message and
        returns ``None`` instead of raising.
        """
        file_path = f"./data/searches/{filename}"
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"{OUTPUT_FORMATS['error']} 文件不存在: {file_path}")
            return None
        except Exception as e:
            print(f"{OUTPUT_FORMATS['error']} 加载失败: {e}")
            return None

    def _build_payload(
        self,
        query: str,
        max_results: Optional[int],
        topic: Optional[str],
        time_range: Optional[str],
        days: Optional[int],
        start_date: Optional[str],
        end_date: Optional[str],
        country: Optional[str]
    ) -> Dict[str, Any]:
        """Validate search arguments and build the Tavily request body.

        Returns:
            ``{"success": True, "payload": <request body>, "filters":
            <applied filters>, "results": []}`` on success, or an error dict
            from ``_error`` describing the first invalid argument.
        """
        payload: Dict[str, Any] = {
            "query": query,
            "search_depth": "advanced",
            "include_answer": True,
            "include_images": False,
            "include_raw_content": False
        }
        filters: Dict[str, Any] = {}

        # Falsy (None / 0) means "use the configured default".
        if max_results:
            try:
                max_results_value = int(max_results)
            except (TypeError, ValueError):
                return self._error(f"max_results 必须是正整数,当前值: {max_results}")
            if max_results_value <= 0:
                return self._error(f"max_results 必须是正整数,当前值: {max_results}")
            payload["max_results"] = max_results_value
        else:
            payload["max_results"] = SEARCH_MAX_RESULTS

        normalized_topic = (topic or "general").strip().lower() or "general"
        if normalized_topic not in self._valid_topics:
            # sorted() keeps the message stable; set order varies between runs.
            return self._error(
                f"无效的topic: {topic}. 可选值: {', '.join(sorted(self._valid_topics))}"
            )
        payload["topic"] = normalized_topic
        filters["topic"] = normalized_topic

        # The three kinds of time filter are mutually exclusive.
        has_time_range = bool(time_range)
        has_days = days is not None
        has_date_range = bool(start_date or end_date)
        if sum([has_time_range, has_days, has_date_range]) > 1:
            return self._error("时间参数只能三选一time_range、days、start_date+end_date 不能同时使用")

        if has_days:
            try:
                days_value = int(days)  # type: ignore[arg-type]
            except (TypeError, ValueError):
                return self._error(f"days 必须是正整数,当前值: {days}")
            if days_value <= 0:
                return self._error(f"days 必须大于0当前值: {days_value}")
            if normalized_topic != "news":
                return self._error("days 参数仅在 topic=\"news\" 时可用,请调整 topic 或改用其他时间参数")
            payload["days"] = days_value
            filters["days"] = days_value

        if has_time_range:
            normalized_range = self._valid_time_ranges.get(
                time_range.strip().lower(), ""  # type: ignore[union-attr]
            )
            if not normalized_range:
                return self._error(f"无效的time_range: {time_range}. 可选值: day/week/month/year 或缩写 d/w/m/y")
            payload["time_range"] = normalized_range
            filters["time_range"] = normalized_range

        if has_date_range:
            if not start_date or not end_date:
                return self._error("start_date 与 end_date 必须同时提供且格式为 YYYY-MM-DD")
            # fullmatch: the whole string must be the date, no trailing "\n".
            if not self._date_pattern.fullmatch(start_date):
                return self._error(f"start_date 格式无效: {start_date},请使用 YYYY-MM-DD")
            if not self._date_pattern.fullmatch(end_date):
                return self._error(f"end_date 格式无效: {end_date},请使用 YYYY-MM-DD")
            try:
                start_dt = datetime.fromisoformat(start_date)
                end_dt = datetime.fromisoformat(end_date)
            except ValueError:
                # Well-formed but impossible dates, e.g. 2024-02-30.
                return self._error("start_date 或 end_date 含无效日期,请检查是否为有效的公历日期")
            if start_dt > end_dt:
                return self._error(f"start_date ({start_date}) 不能晚于 end_date ({end_date})")
            payload["start_date"] = start_date
            payload["end_date"] = end_date
            filters["start_date"] = start_date
            filters["end_date"] = end_date

        if country:
            normalized_country = country.strip().lower()
            # A whitespace-only country is silently ignored.
            if normalized_country:
                if normalized_topic != "general":
                    return self._error("country 参数仅在 topic=\"general\" 时可用,请调整 topic 或移除 country")
                payload["country"] = normalized_country
                filters["country"] = normalized_country

        return {
            "success": True,
            "payload": payload,
            "filters": filters,
            "results": []
        }

    def _summarize_filters(self, filters: Dict[str, Any]) -> str:
        """Render the applied filters as one human-readable line.

        Returns an empty string when there is nothing to show.
        """
        if not filters:
            return ""
        parts = []
        topic = filters.get("topic")
        if topic:
            parts.append(f"Topic: {topic}")
        # Time filters are mutually exclusive (enforced in _build_payload),
        # so at most one of these branches applies.
        if "time_range" in filters:
            parts.append(f"Time Range: {filters['time_range']}")
        elif "days" in filters:
            parts.append(f"最近 {filters['days']} 天")
        elif "start_date" in filters and "end_date" in filters:
            parts.append(f"{filters['start_date']} ~ {filters['end_date']}")
        if "country" in filters:
            parts.append(f"Country: {filters['country']}")
        if not parts:
            return ""
        return "🎯 过滤条件: " + " | ".join(parts)