204 lines
7.2 KiB
Python
204 lines
7.2 KiB
Python
"""
|
||
搜索服务
|
||
封装Tavily API调用
|
||
"""
|
||
import logging
|
||
from typing import List, Dict, Any, Optional
|
||
from tavily import TavilyClient
|
||
from app.models.search_result import SearchResult, TavilySearchResponse, SearchBatch
|
||
from config import Config
|
||
import time
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class SearchService:
    """Search service that wraps calls to the Tavily API.

    Successful responses are kept in a simple per-instance, in-memory cache
    that is only emptied by :meth:`clear_cache`.
    """

    # extract_content() submits at most this many URLs to the extract endpoint.
    _MAX_EXTRACT_URLS = 20
    # Pause between consecutive queries in batch_search() to avoid rate limits.
    _BATCH_DELAY_SECONDS = 0.5

    def __init__(self, api_key: Optional[str] = None):
        """Create the Tavily client.

        Args:
            api_key: Tavily API key. Falls back to ``Config.TAVILY_API_KEY``
                when omitted or empty.
        """
        self.api_key = api_key or Config.TAVILY_API_KEY
        self.client = TavilyClient(api_key=self.api_key)
        # Simple search cache: cache key -> TavilySearchResponse.
        self._search_cache: Dict[str, "TavilySearchResponse"] = {}

    @staticmethod
    def _make_cache_key(query: str, max_results: Any, search_depth: Any,
                        include_answer: Any, include_raw_content: Any) -> str:
        """Build the cache key from every option that shapes the response."""
        return (
            f"{query}:{max_results}:{search_depth}:"
            f"{include_answer}:{include_raw_content}"
        )

    def search(self, query: str, max_results: Optional[int] = None,
               search_depth: Optional[str] = None,
               include_answer: Optional[bool] = None,
               include_raw_content: Optional[bool] = None) -> "TavilySearchResponse":
        """Run a single search, serving repeated requests from the cache.

        Args:
            query: Search query text.
            max_results: Result limit; defaults to ``Config.TAVILY_MAX_RESULTS``.
            search_depth: Search depth; defaults to ``Config.TAVILY_SEARCH_DEPTH``.
            include_answer: Whether to request an answer; defaults to
                ``Config.TAVILY_INCLUDE_ANSWER``.
            include_raw_content: Whether to request raw content; defaults to
                ``Config.TAVILY_INCLUDE_RAW_CONTENT``.

        Returns:
            The response for the query. If the API call or the response
            conversion fails, the error is logged and an empty response is
            returned instead; that empty response is not cached.
        """
        # Resolve defaults BEFORE building the cache key so the key describes
        # the effective request rather than the raw (possibly None) arguments.
        if max_results is None:
            max_results = Config.TAVILY_MAX_RESULTS
        if search_depth is None:
            search_depth = Config.TAVILY_SEARCH_DEPTH
        if include_answer is None:
            include_answer = Config.TAVILY_INCLUDE_ANSWER
        if include_raw_content is None:
            include_raw_content = Config.TAVILY_INCLUDE_RAW_CONTENT

        # The key covers include_answer / include_raw_content as well, so a
        # request for raw content is never answered with a response that was
        # fetched without it.
        cache_key = self._make_cache_key(
            query, max_results, search_depth, include_answer, include_raw_content
        )
        if cache_key in self._search_cache:
            logger.info(f"从缓存返回搜索结果: {query}")
            return self._search_cache[cache_key]

        try:
            logger.info(f"执行Tavily搜索: {query}")
            start_time = time.time()

            # Call the Tavily API.
            response = self.client.search(
                query=query,
                max_results=max_results,
                search_depth=search_depth,
                include_answer=include_answer,
                include_raw_content=include_raw_content,
            )

            response_time = time.time() - start_time

            # Convert into our own response model.
            tavily_response = TavilySearchResponse(
                query=query,
                answer=response.get('answer'),
                images=response.get('images', []),
                results=response.get('results', []),
                response_time=response_time,
            )

            # Only successful responses are cached.
            self._search_cache[cache_key] = tavily_response

            logger.info(f"搜索完成,耗时 {response_time:.2f}秒,返回 {len(tavily_response.results)} 条结果")
            return tavily_response

        except Exception as e:
            logger.error(f"Tavily搜索失败: {e}")
            # Best effort: hand back an empty result instead of raising.
            return TavilySearchResponse(
                query=query,
                answer=None,
                images=[],
                results=[],
                response_time=0.0,
            )

    def batch_search(self, queries: List[str], max_results_per_query: int = 10) -> List["TavilySearchResponse"]:
        """Run several searches sequentially.

        Args:
            queries: Queries to execute, in order.
            max_results_per_query: Result limit passed to each search.

        Returns:
            One response per query, in input order. A query that raises is
            logged and represented by an empty response.
        """
        results = []

        for index, query in enumerate(queries):
            # Delay before every query except the first to avoid rate limits.
            if index > 0:
                time.sleep(self._BATCH_DELAY_SECONDS)

            try:
                response = self.search(query, max_results=max_results_per_query)
            except Exception as e:
                logger.error(f"批量搜索中的查询失败 '{query}': {e}")
                # Keep the output aligned with the input by adding an empty result.
                response = TavilySearchResponse(
                    query=query,
                    results=[],
                    response_time=0.0,
                )
            results.append(response)

        return results

    def _collect_results(self, queries: List[str], **search_options: Any) -> list:
        """Search every query and return all converted results in one list."""
        collected = []
        for query in queries:
            response = self.search(query, **search_options)
            collected.extend(response.to_search_results())
        return collected

    def search_subtopic(self, subtopic_id: str, subtopic_name: str,
                        queries: List[str]) -> "SearchBatch":
        """Run all queries of a sub-topic and gather them into one batch.

        Args:
            subtopic_id: Identifier stored on the batch.
            subtopic_name: Name used in the batch's descriptive query label.
            queries: Queries to execute with the default search options.

        Returns:
            A ``SearchBatch`` holding the combined results.
        """
        all_results = self._collect_results(queries)

        # Create the batch empty, then add the results through add_results()
        # (the original comment says this step de-duplicates — see SearchBatch).
        batch = SearchBatch(
            subtopic_id=subtopic_id,
            query=f"子主题搜索: {subtopic_name}",
            results=[],
        )
        batch.add_results(all_results)

        return batch

    def refined_search(self, subtopic_id: str, key_info: str,
                       queries: List[str], parent_search_id: Optional[str] = None) -> "SearchBatch":
        """Run a refined follow-up search using ``search_depth="advanced"``.

        Args:
            subtopic_id: Identifier stored on the batch.
            key_info: Key information being refined; used in the batch label
                and stored as ``detail_type``.
            queries: Queries to execute.
            parent_search_id: Optional id of the search this one refines.

        Returns:
            A ``SearchBatch`` flagged with ``is_refined_search=True``.
        """
        all_results = self._collect_results(queries, search_depth="advanced")

        # Create the refined-search batch, then add the results to it.
        batch = SearchBatch(
            subtopic_id=subtopic_id,
            query=f"细化搜索: {key_info}",
            results=[],
            is_refined_search=True,
            parent_search_id=parent_search_id,
            detail_type=key_info,
        )
        batch.add_results(all_results)

        return batch

    def _find_cached_content(self, url: str) -> Optional[str]:
        """Return the content of the first cached result for ``url``, else None."""
        for cached_response in self._search_cache.values():
            for result in cached_response.results:
                if result.get('url') == url:
                    return result.get('content', '')
        return None

    def extract_content(self, urls: List[str]) -> Dict[str, str]:
        """Fetch the full content of the given URLs.

        Only the first ``_MAX_EXTRACT_URLS`` URLs are sent to the extract
        endpoint. If extraction fails, the error is logged and the content
        stored with cached search results is used as a fallback.

        Args:
            urls: URLs to extract.

        Returns:
            Mapping of URL to content; URLs without content are omitted.
        """
        content_map: Dict[str, str] = {}
        if not urls:
            # Nothing to extract; avoid a pointless API call.
            return content_map

        try:
            # Tavily's extract feature (requires API support for it).
            response = self.client.extract(urls=urls[:self._MAX_EXTRACT_URLS])

            for result in response.get('results', []):
                url = result.get('url')
                content = result.get('raw_content', '')
                if url and content:
                    content_map[url] = content

        except Exception as e:
            logger.error(f"提取内容失败: {e}")
            # Extract unavailable: fall back to content from cached searches.
            for url in urls:
                if url in content_map:
                    # Keep content that was already extracted successfully.
                    continue
                cached_content = self._find_cached_content(url)
                if cached_content is not None:
                    content_map[url] = cached_content

        return content_map

    def get_search_statistics(self) -> Dict[str, Any]:
        """Summarise the cache contents.

        Returns:
            Dict with ``total_searches`` and ``cache_size`` (both the number
            of cached responses), ``total_results`` (sum of their result
            counts) and ``cached_queries`` (the cache keys).
        """
        cached_responses = self._search_cache.values()
        total_results = sum(len(r.results) for r in cached_responses)

        return {
            "total_searches": len(self._search_cache),
            "total_results": total_results,
            "cache_size": len(self._search_cache),
            "cached_queries": list(self._search_cache.keys()),
        }

    def clear_cache(self):
        """Empty the search cache."""
        self._search_cache.clear()
        logger.info("搜索缓存已清空")

    def test_connection(self) -> bool:
        """Check that the Tavily API can be reached.

        The client is called directly rather than through :meth:`search`,
        because ``search`` swallows API errors and may answer from the cache,
        either of which would make this check always succeed.

        Returns:
            True if the API call completed, False if it raised.
        """
        try:
            self.client.search(query="test query", max_results=1)
        except Exception as e:
            logger.error(f"Tavily API连接测试失败: {e}")
            return False
        return True