""" 搜索服务 封装Tavily API调用 """ import logging from typing import List, Dict, Any, Optional from tavily import TavilyClient from app.models.search_result import SearchResult, TavilySearchResponse, SearchBatch from config import Config import time logger = logging.getLogger(__name__) class SearchService: """搜索服务""" def __init__(self, api_key: str = None): self.api_key = api_key or Config.TAVILY_API_KEY self.client = TavilyClient(api_key=self.api_key) self._search_cache = {} # 简单的搜索缓存 def search(self, query: str, max_results: int = None, search_depth: str = None, include_answer: bool = None, include_raw_content: bool = None) -> TavilySearchResponse: """执行搜索""" # 检查缓存 cache_key = f"{query}:{max_results}:{search_depth}" if cache_key in self._search_cache: logger.info(f"从缓存返回搜索结果: {query}") return self._search_cache[cache_key] # 设置默认值 if max_results is None: max_results = Config.TAVILY_MAX_RESULTS if search_depth is None: search_depth = Config.TAVILY_SEARCH_DEPTH if include_answer is None: include_answer = Config.TAVILY_INCLUDE_ANSWER if include_raw_content is None: include_raw_content = Config.TAVILY_INCLUDE_RAW_CONTENT try: logger.info(f"执行Tavily搜索: {query}") start_time = time.time() # 调用Tavily API response = self.client.search( query=query, max_results=max_results, search_depth=search_depth, include_answer=include_answer, include_raw_content=include_raw_content ) response_time = time.time() - start_time # 转换为我们的响应模型 tavily_response = TavilySearchResponse( query=query, answer=response.get('answer'), images=response.get('images', []), results=response.get('results', []), response_time=response_time ) # 缓存结果 self._search_cache[cache_key] = tavily_response logger.info(f"搜索完成,耗时 {response_time:.2f}秒,返回 {len(tavily_response.results)} 条结果") return tavily_response except Exception as e: logger.error(f"Tavily搜索失败: {e}") # 返回空结果 return TavilySearchResponse( query=query, answer=None, images=[], results=[], response_time=0.0 ) def batch_search(self, queries: List[str], max_results_per_query: int = 10) -> List[TavilySearchResponse]: """批量搜索""" results = [] for query in queries: # 添加延迟以避免速率限制 if results: # 不是第一个查询 time.sleep(0.5) # 500ms延迟 try: response = self.search(query, max_results=max_results_per_query) results.append(response) except Exception as e: logger.error(f"批量搜索中的查询失败 '{query}': {e}") # 添加空结果 results.append(TavilySearchResponse( query=query, results=[], response_time=0.0 )) return results def search_subtopic(self, subtopic_id: str, subtopic_name: str, queries: List[str]) -> SearchBatch: """为子主题执行搜索""" all_results = [] for query in queries: response = self.search(query) search_results = response.to_search_results() all_results.extend(search_results) # 创建搜索批次 batch = SearchBatch( subtopic_id=subtopic_id, query=f"子主题搜索: {subtopic_name}", results=[] ) # 去重并添加结果 batch.add_results(all_results) return batch def refined_search(self, subtopic_id: str, key_info: str, queries: List[str], parent_search_id: str = None) -> SearchBatch: """执行细化搜索""" all_results = [] for query in queries: response = self.search(query, search_depth="advanced") search_results = response.to_search_results() all_results.extend(search_results) # 创建细化搜索批次 batch = SearchBatch( subtopic_id=subtopic_id, query=f"细化搜索: {key_info}", results=[], is_refined_search=True, parent_search_id=parent_search_id, detail_type=key_info ) batch.add_results(all_results) return batch def extract_content(self, urls: List[str]) -> Dict[str, str]: """提取URL的完整内容""" content_map = {} try: # Tavily的extract功能(如果可用) # 注意:这需要Tavily API支持extract功能 response = self.client.extract(urls=urls[:20]) # 最多20个URL for result in response.get('results', []): url = result.get('url') content = result.get('raw_content', '') if url and content: content_map[url] = content except Exception as e: logger.error(f"提取内容失败: {e}") # 如果extract不可用,使用搜索结果中的内容 for url in urls: # 从缓存的搜索结果中查找 for cached_response in self._search_cache.values(): for result in cached_response.results: if result.get('url') == url: content_map[url] = result.get('content', '') break return content_map def get_search_statistics(self) -> Dict[str, Any]: """获取搜索统计信息""" total_searches = len(self._search_cache) total_results = sum(len(r.results) for r in self._search_cache.values()) return { "total_searches": total_searches, "total_results": total_results, "cache_size": len(self._search_cache), "cached_queries": list(self._search_cache.keys()) } def clear_cache(self): """清空搜索缓存""" self._search_cache.clear() logger.info("搜索缓存已清空") def test_connection(self) -> bool: """测试Tavily API连接""" try: response = self.search("test query", max_results=1) return len(response.results) >= 0 except Exception as e: logger.error(f"Tavily API连接测试失败: {e}") return False