# modules/search_engine.py - web search module

import json
import os
import re
from datetime import datetime
from typing import Any, Dict, Optional

import httpx

try:
    from config import TAVILY_API_KEY, SEARCH_MAX_RESULTS, OUTPUT_FORMATS
except ImportError:
    # Fallback for when the module is imported from outside the project
    # root: put the directory two levels up from this file on sys.path and
    # retry the import.
    import sys
    from pathlib import Path

    project_root = Path(__file__).resolve().parents[1]
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
    from config import TAVILY_API_KEY, SEARCH_MAX_RESULTS, OUTPUT_FORMATS


class SearchEngine:
    """Thin async client for the Tavily search endpoint.

    Validates search parameters, posts the request, and reshapes the
    response into plain dictionaries. Every public search method reports
    failure through a ``{"success": False, "error": ...}`` dictionary
    instead of raising.
    """

    def __init__(self):
        self.api_key = TAVILY_API_KEY
        self.api_url = "https://api.tavily.com/search"
        self._valid_topics = {"general", "news", "finance"}
        # Maps both the full names and the one-letter abbreviations to the
        # canonical value placed in the request payload.
        self._valid_time_ranges = {
            "day": "day",
            "d": "day",
            "week": "week",
            "w": "week",
            "month": "month",
            "m": "month",
            "year": "year",
            "y": "year",
        }
        self._date_pattern = re.compile(r"^\d{4}-\d{2}-\d{2}$")

    @staticmethod
    def _failure(message: str) -> Dict[str, Any]:
        """Build the standard failure dictionary carrying *message*."""
        return {
            "success": False,
            "error": message,
            "results": [],
        }

    async def search(
        self,
        query: str,
        max_results: Optional[int] = None,
        topic: Optional[str] = None,
        time_range: Optional[str] = None,
        days: Optional[int] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        country: Optional[str] = None,
    ) -> Dict:
        """Run a web search.

        Args:
            query: Search keywords.
            max_results: Maximum number of results; falls back to
                ``SEARCH_MAX_RESULTS`` when omitted or falsy.
            topic: Search type (general/news/finance).
            time_range: Relative time range (day/week/month/year or d/w/m/y).
            days: Look back N days; only accepted with topic=news.
            start_date: Start date, formatted YYYY-MM-DD.
            end_date: End date, formatted YYYY-MM-DD.
            country: Country filter; only accepted with topic=general.

        Returns:
            On success, the dictionary produced by ``_format_results``.
            On failure, a dictionary with ``success`` False, an ``error``
            message and an empty ``results`` list.
        """
        if not self.api_key or self.api_key == "your-tavily-api-key":
            return self._failure("Tavily API密钥未配置")

        validation = self._build_payload(
            query=query,
            max_results=max_results,
            topic=topic,
            time_range=time_range,
            days=days,
            start_date=start_date,
            end_date=end_date,
            country=country,
        )
        if not validation["success"]:
            return validation

        payload = validation["payload"]
        applied_filters = validation["filters"]

        print(f"{OUTPUT_FORMATS['search']} 搜索: {query}")

        try:
            async with httpx.AsyncClient(timeout=30) as client:
                response = await client.post(
                    self.api_url,
                    json=payload,
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json",
                    },
                )

                if response.status_code != 200:
                    return self._failure(f"API请求失败: {response.status_code}")

                data = response.json()

                # Reshape the raw response into the module's result format.
                formatted_results = self._format_results(data, applied_filters)

                print(
                    f"{OUTPUT_FORMATS['success']} 搜索完成,找到 "
                    f"{len(formatted_results['results'])} 条结果"
                )

                return formatted_results

        except httpx.TimeoutException:
            return self._failure("搜索超时")
        except Exception as e:
            # Boundary handler: any transport, decoding or formatting error
            # is reported to the caller as a failed search.
            return self._failure(f"搜索失败: {str(e)}")

    def _format_results(self, raw_data: Dict, filters: Dict[str, Any]) -> Dict:
        """Convert a raw API response into the module's result dictionary.

        Fields that are missing or explicitly null in the response are
        replaced with safe defaults so callers can slice and measure the
        text fields without type checks.

        Args:
            raw_data: Decoded JSON body of the search response.
            filters: Filters that were applied to the request; echoed back
                under the ``filters`` key.

        Returns:
            Dictionary with ``success``, ``query``, ``answer``, ``results``,
            ``timestamp``, ``filters`` and ``total_results`` keys.
        """
        # `or []` also covers an explicit null, which `.get(key, [])` would
        # pass through as None.
        raw_results = raw_data.get("results") or []

        return {
            "success": True,
            "query": raw_data.get("query", ""),
            "answer": raw_data.get("answer", ""),
            "results": [
                {
                    "index": idx,
                    "title": result.get("title") or "无标题",
                    "url": result.get("url") or "",
                    "content": result.get("content") or "",
                    "score": result.get("score", 0),
                    "published_date": result.get("published_date") or "",
                }
                for idx, result in enumerate(raw_results, 1)
            ],
            "timestamp": datetime.now().isoformat(),
            "filters": filters,
            "total_results": len(raw_results),
        }

    async def search_with_summary(
        self,
        query: str,
        max_results: Optional[int] = None,
        topic: Optional[str] = None,
        time_range: Optional[str] = None,
        days: Optional[int] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        country: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Search and return the results together with a text summary.

        Args:
            query: Search keywords.
            max_results: Maximum number of results.
            topic: Search type (general/news/finance).
            time_range: Relative time range (day/week/month/year or d/w/m/y).
            days: Look back N days; only accepted with topic=news.
            start_date: Start date, formatted YYYY-MM-DD.
            end_date: End date, formatted YYYY-MM-DD.
            country: Country filter; only accepted with topic=general.

        Returns:
            On success, a dictionary with ``success``, ``summary`` (the
            formatted multi-line text), ``filters``, ``query``, ``results``
            and ``total_results``. On failure, a dictionary with
            ``success`` False, ``error`` and an empty ``summary``.
        """
        results = await self.search(
            query=query,
            max_results=max_results,
            topic=topic,
            time_range=time_range,
            days=days,
            start_date=start_date,
            end_date=end_date,
            country=country,
        )

        if not results["success"]:
            return {
                "success": False,
                "error": results.get("error", "未知错误"),
                "summary": "",
            }

        # Build the summary header.
        summary_lines = [
            f"🔍 搜索查询: {query}",
            f"📅 搜索时间: {results['timestamp']}",
        ]

        filter_notes = self._summarize_filters(results.get("filters", {}))
        if filter_notes:
            summary_lines.append(filter_notes)

        summary_lines.append("")

        # Include the AI-generated answer when the response carries one.
        if results.get("answer"):
            summary_lines.extend([
                "📝 AI摘要:",
                results["answer"],
                "",
                "---",
                "",
            ])

        # List the individual results.
        if results["results"]:
            summary_lines.append("📊 搜索结果:")

            for result in results["results"]:
                content = result["content"]
                # Append an ellipsis only when the content was truncated.
                if len(content) > 200:
                    content_line = f" 📄 {content[:200]}..."
                else:
                    content_line = f" 📄 {content}"

                summary_lines.extend([
                    f"\n{result['index']}. {result['title']}",
                    f" 🔗 {result['url']}",
                    content_line,
                ])

                if result.get("published_date"):
                    summary_lines.append(f" 📅 发布时间: {result['published_date']}")
        else:
            summary_lines.append("未找到相关结果")

        return {
            "success": True,
            "summary": "\n".join(summary_lines),
            "filters": results.get("filters", {}),
            "query": results.get("query", query),
            "results": results.get("results", []),
            "total_results": results.get(
                "total_results", len(results.get("results", []))
            ),
        }

    async def quick_answer(self, query: str) -> str:
        """Fetch a short answer for *query*.

        Args:
            query: The question to look up.

        Returns:
            The AI answer when the response has one; otherwise the title
            and up to 300 characters of the first result; otherwise a
            "nothing found" message. On failure, an error message string.
        """
        results = await self.search(query, max_results=5)

        if not results["success"]:
            return f"搜索失败: {results['error']}"

        if results.get("answer"):
            return results["answer"]

        # No AI answer: fall back to an excerpt of the first result.
        if results["results"]:
            first_result = results["results"][0]
            content = first_result["content"]
            # Only mark the excerpt with an ellipsis when text was cut off.
            excerpt = f"{content[:300]}..." if len(content) > 300 else content
            return f"{first_result['title']}\n{excerpt}"

        return "未找到相关信息"

    def save_results(self, results: Dict, filename: Optional[str] = None) -> str:
        """Write search results to a JSON file under ``./data/searches/``.

        Args:
            results: Search results to serialize.
            filename: Target file name; a timestamped name is generated
                when omitted.

        Returns:
            Path of the file that was written.
        """
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"search_{timestamp}.json"

        file_path = f"./data/searches/{filename}"

        # Make sure the target directory exists.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print(f"{OUTPUT_FORMATS['file']} 搜索结果已保存到: {file_path}")
        return file_path

    def load_results(self, filename: str) -> Optional[Dict]:
        """Load previously saved search results.

        Args:
            filename: File name inside ``./data/searches/``.

        Returns:
            The decoded results, or None when the file is missing or
            cannot be read/decoded (the reason is printed).
        """
        file_path = f"./data/searches/{filename}"

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"{OUTPUT_FORMATS['error']} 文件不存在: {file_path}")
            return None
        except Exception as e:
            print(f"{OUTPUT_FORMATS['error']} 加载失败: {e}")
            return None

    def _build_payload(
        self,
        query: str,
        max_results: Optional[int],
        topic: Optional[str],
        time_range: Optional[str],
        days: Optional[int],
        start_date: Optional[str],
        end_date: Optional[str],
        country: Optional[str],
    ) -> Dict[str, Any]:
        """Validate the search parameters and build the request payload.

        Returns:
            On success, ``{"success": True, "payload": ..., "filters": ...,
            "results": []}`` where ``filters`` records the normalized
            filters that were applied. On a validation error, the standard
            failure dictionary with an explanatory ``error`` message.
        """
        # Reject an empty query locally rather than sending it to the API.
        if not isinstance(query, str) or not query.strip():
            return self._failure("搜索关键词不能为空")

        payload: Dict[str, Any] = {
            "query": query,
            "search_depth": "advanced",
            "include_answer": True,
            "include_images": False,
            "include_raw_content": False,
        }
        filters: Dict[str, Any] = {}

        payload["max_results"] = max_results if max_results else SEARCH_MAX_RESULTS

        normalized_topic = (topic or "general").strip().lower()
        if not normalized_topic:
            normalized_topic = "general"
        if normalized_topic not in self._valid_topics:
            # Sort so the message is stable; set iteration order is not.
            return self._failure(
                f"无效的topic: {topic}. 可选值: {', '.join(sorted(self._valid_topics))}"
            )
        payload["topic"] = normalized_topic
        filters["topic"] = normalized_topic

        # The three kinds of time filter are mutually exclusive.
        has_time_range = bool(time_range)
        has_days = days is not None
        has_date_range = bool(start_date or end_date)

        if sum([has_time_range, has_days, has_date_range]) > 1:
            return self._failure(
                "时间参数只能三选一:time_range、days、start_date+end_date 不能同时使用"
            )

        # Validate days.
        if has_days:
            try:
                days_value = int(days)  # type: ignore[arg-type]
            except (TypeError, ValueError):
                return self._failure(f"days 必须是正整数,当前值: {days}")
            if days_value <= 0:
                return self._failure(f"days 必须大于0,当前值: {days_value}")
            if normalized_topic != "news":
                return self._failure(
                    "days 参数仅在 topic=\"news\" 时可用,请调整 topic 或改用其他时间参数"
                )
            payload["days"] = days_value
            filters["days"] = days_value

        # Validate time_range and map abbreviations to the canonical value.
        if has_time_range:
            normalized_range = time_range.strip().lower()  # type: ignore[union-attr]
            normalized_range = self._valid_time_ranges.get(normalized_range, "")
            if not normalized_range:
                return self._failure(
                    f"无效的time_range: {time_range}. 可选值: day/week/month/year 或缩写 d/w/m/y"
                )
            payload["time_range"] = normalized_range
            filters["time_range"] = normalized_range

        # Validate the explicit date range.
        if has_date_range:
            if not start_date or not end_date:
                return self._failure(
                    "start_date 与 end_date 必须同时提供且格式为 YYYY-MM-DD"
                )
            # fullmatch: `match` with `$` would accept a trailing newline.
            if not self._date_pattern.fullmatch(start_date):
                return self._failure(
                    f"start_date 格式无效: {start_date},请使用 YYYY-MM-DD"
                )
            if not self._date_pattern.fullmatch(end_date):
                return self._failure(
                    f"end_date 格式无效: {end_date},请使用 YYYY-MM-DD"
                )
            # The regex only checks the shape; parsing catches impossible
            # dates such as 2024-02-30.
            try:
                start_dt = datetime.fromisoformat(start_date)
                end_dt = datetime.fromisoformat(end_date)
            except ValueError:
                return self._failure(
                    "start_date 或 end_date 含无效日期,请检查是否为有效的公历日期"
                )
            if start_dt > end_dt:
                return self._failure(
                    f"start_date ({start_date}) 不能晚于 end_date ({end_date})"
                )
            payload["start_date"] = start_date
            payload["end_date"] = end_date
            filters["start_date"] = start_date
            filters["end_date"] = end_date

        # Country filter.
        if country:
            normalized_country = country.strip().lower()
            if normalized_country:
                if normalized_topic != "general":
                    return self._failure(
                        "country 参数仅在 topic=\"general\" 时可用,请调整 topic 或移除 country"
                    )
                payload["country"] = normalized_country
                filters["country"] = normalized_country

        return {
            "success": True,
            "payload": payload,
            "filters": filters,
            "results": [],
        }

    def _summarize_filters(self, filters: Dict[str, Any]) -> str:
        """Render the applied filters as a single summary line.

        Args:
            filters: Filter dictionary as produced by ``_build_payload``.

        Returns:
            A one-line description, or an empty string when there is
            nothing to report.
        """
        if not filters:
            return ""

        parts = []
        topic = filters.get("topic")
        if topic:
            parts.append(f"Topic: {topic}")

        # At most one time filter is ever present.
        if "time_range" in filters:
            parts.append(f"Time Range: {filters['time_range']}")
        elif "days" in filters:
            parts.append(f"最近 {filters['days']} 天")
        elif "start_date" in filters and "end_date" in filters:
            parts.append(f"{filters['start_date']} 至 {filters['end_date']}")

        if "country" in filters:
            parts.append(f"Country: {filters['country']}")

        if not parts:
            return ""
        return "🎯 过滤条件: " + " | ".join(parts)