deepresearch/app/services/search_service.py
2025-07-02 15:35:36 +08:00

204 lines
7.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
搜索服务
封装Tavily API调用
"""
import logging
from typing import List, Dict, Any, Optional
from tavily import TavilyClient
from app.models.search_result import SearchResult, TavilySearchResponse, SearchBatch
from config import Config
import time
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(name=__name__)
class SearchService:
    """Search service that wraps the Tavily API.

    Successful search responses are memoized in an in-process dict keyed by
    the effective search parameters, so repeating an identical search does
    not call the API again. Failed searches are never cached.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Create the Tavily client.

        Args:
            api_key: Tavily API key. When omitted (or falsy),
                ``Config.TAVILY_API_KEY`` is used instead.
        """
        self.api_key = api_key or Config.TAVILY_API_KEY
        self.client = TavilyClient(api_key=self.api_key)
        # Simple in-memory search cache: cache-key string -> TavilySearchResponse.
        # NOTE(review): unbounded; long-running processes may need clear_cache().
        self._search_cache = {}

    @staticmethod
    def _make_cache_key(query, max_results, search_depth,
                        include_answer, include_raw_content) -> str:
        """Build a cache key from every parameter that affects the response."""
        # The query stays first, like the previous "query:max:depth" key.
        # max_results and the two flags cannot contain ':'.
        return (f"{query}:{max_results}:{search_depth}:"
                f"{include_answer}:{include_raw_content}")

    @staticmethod
    def _empty_response(query: str) -> TavilySearchResponse:
        """Return the empty response used to signal a failed search."""
        return TavilySearchResponse(
            query=query,
            answer=None,
            images=[],
            results=[],
            response_time=0.0,
        )

    def search(self, query: str, max_results: Optional[int] = None,
               search_depth: Optional[str] = None,
               include_answer: Optional[bool] = None,
               include_raw_content: Optional[bool] = None) -> TavilySearchResponse:
        """Run a Tavily search, serving repeated identical searches from cache.

        Any argument left as ``None`` falls back to the matching
        ``Config.TAVILY_*`` setting.

        Args:
            query: Search query text.
            max_results: Maximum number of results to request.
            search_depth: Tavily search depth (e.g. ``"advanced"``).
            include_answer: Whether to request Tavily's generated answer.
            include_raw_content: Whether to request raw page content.

        Returns:
            The converted response. On any API or conversion error an empty
            response (no answer, images or results; ``response_time`` 0.0) is
            returned instead of raising, and nothing is cached.
        """
        # Resolve defaults BEFORE building the cache key. An omitted argument
        # and its explicit default then share one entry. The include_* flags,
        # which change the payload, become part of the key.
        if max_results is None:
            max_results = Config.TAVILY_MAX_RESULTS
        if search_depth is None:
            search_depth = Config.TAVILY_SEARCH_DEPTH
        if include_answer is None:
            include_answer = Config.TAVILY_INCLUDE_ANSWER
        if include_raw_content is None:
            include_raw_content = Config.TAVILY_INCLUDE_RAW_CONTENT

        cache_key = self._make_cache_key(
            query, max_results, search_depth, include_answer, include_raw_content
        )
        cached = self._search_cache.get(cache_key)
        if cached is not None:
            logger.info(f"从缓存返回搜索结果: {query}")
            return cached

        try:
            logger.info(f"执行Tavily搜索: {query}")
            # perf_counter is monotonic, so the measured duration cannot go
            # negative when the wall clock is adjusted.
            start_time = time.perf_counter()
            response = self.client.search(
                query=query,
                max_results=max_results,
                search_depth=search_depth,
                include_answer=include_answer,
                include_raw_content=include_raw_content,
            )
            response_time = time.perf_counter() - start_time
            # Convert the raw response dict into our model. `or []` guards
            # against the API returning null for a list field.
            tavily_response = TavilySearchResponse(
                query=query,
                answer=response.get('answer'),
                images=response.get('images') or [],
                results=response.get('results') or [],
                response_time=response_time,
            )
        except Exception as e:
            # Deliberately best-effort: callers receive an empty response
            # rather than an exception.
            logger.error(f"Tavily搜索失败: {e}")
            return self._empty_response(query)

        self._search_cache[cache_key] = tavily_response
        logger.info(f"搜索完成,耗时 {response_time:.2f}秒,返回 {len(tavily_response.results)} 条结果")
        return tavily_response

    def batch_search(self, queries: List[str], max_results_per_query: int = 10,
                     delay_seconds: float = 0.5) -> List[TavilySearchResponse]:
        """Search several queries sequentially, one response per query.

        Args:
            queries: Queries to run, in order.
            max_results_per_query: ``max_results`` passed to each search.
            delay_seconds: Pause inserted between consecutive queries to
                avoid rate limiting (skipped before the first query and when
                not positive).

        Returns:
            Responses in the same order as ``queries``. A failed query yields
            an empty response.
        """
        results = []
        for index, query in enumerate(queries):
            if index and delay_seconds > 0:
                time.sleep(delay_seconds)
            try:
                results.append(self.search(query, max_results=max_results_per_query))
            except Exception as e:
                # search() already swallows API errors. This is a defensive
                # net so that one bad query never aborts the whole batch.
                logger.error(f"批量搜索中的查询失败 '{query}': {e}")
                results.append(self._empty_response(query))
        return results

    def _collect_results(self, queries: List[str], **search_kwargs: Any) -> List[Any]:
        """Run ``search`` for each query and concatenate the converted results."""
        collected = []
        for query in queries:
            collected.extend(self.search(query, **search_kwargs).to_search_results())
        return collected

    def search_subtopic(self, subtopic_id: str, subtopic_name: str,
                        queries: List[str]) -> SearchBatch:
        """Search every query for a subtopic and gather the results in one batch.

        Args:
            subtopic_id: Identifier stored on the returned batch.
            subtopic_name: Human-readable name used in the batch's query label.
            queries: Queries to run with default search settings.

        Returns:
            A ``SearchBatch`` filled via ``add_results`` (which, per the
            original intent, de-duplicates; see SearchBatch).
        """
        all_results = self._collect_results(queries)
        batch = SearchBatch(
            subtopic_id=subtopic_id,
            query=f"子主题搜索: {subtopic_name}",
            results=[]
        )
        # De-duplicate and add the results.
        batch.add_results(all_results)
        return batch

    def refined_search(self, subtopic_id: str, key_info: str,
                       queries: List[str], parent_search_id: Optional[str] = None) -> SearchBatch:
        """Run an "advanced"-depth follow-up search for a specific detail.

        Args:
            subtopic_id: Identifier stored on the returned batch.
            key_info: The detail being refined; used as the label and
                ``detail_type``.
            queries: Queries to run with ``search_depth="advanced"``.
            parent_search_id: Optional id of the search this refines.

        Returns:
            A ``SearchBatch`` marked ``is_refined_search=True``.
        """
        all_results = self._collect_results(queries, search_depth="advanced")
        batch = SearchBatch(
            subtopic_id=subtopic_id,
            query=f"细化搜索: {key_info}",
            results=[],
            is_refined_search=True,
            parent_search_id=parent_search_id,
            detail_type=key_info
        )
        batch.add_results(all_results)
        return batch

    def _find_cached_content(self, url: str) -> Optional[str]:
        """Return the 'content' of the first cached search result for *url*, or None."""
        for cached_response in self._search_cache.values():
            for result in cached_response.results:
                if result.get('url') == url:
                    return result.get('content', '')
        return None

    def extract_content(self, urls: List[str], max_urls: int = 20) -> Dict[str, str]:
        """Fetch the full page content for URLs via Tavily's extract endpoint.

        Args:
            urls: URLs to extract.
            max_urls: Only the first ``max_urls`` URLs are sent to the API
                (default 20).

        Returns:
            Mapping of URL to raw content. Entries whose URL or content is
            empty are omitted. If the extract call fails, the mapping is
            instead filled from the 'content' snippets of cached search
            results (first cached match per URL).
        """
        content_map = {}
        if not urls:
            # Nothing to extract; avoid a pointless API call.
            return content_map
        try:
            # NOTE(review): requires the Tavily plan/SDK to support extract.
            response = self.client.extract(urls=urls[:max_urls])
            for result in response.get('results', []):
                url = result.get('url')
                content = result.get('raw_content', '')
                if url and content:
                    content_map[url] = content
        except Exception as e:
            logger.error(f"提取内容失败: {e}")
            # Extract unavailable: fall back to snippets from cached searches.
            for url in urls:
                content = self._find_cached_content(url)
                if content is not None:
                    content_map[url] = content
        return content_map

    def get_search_statistics(self) -> Dict[str, Any]:
        """Summarize the search cache.

        Returns:
            ``total_searches`` and ``cache_size``: number of distinct successful
            searches currently cached. ``total_results``: sum of their result
            counts. ``cached_queries``: the cache-key strings.
        """
        cache = self._search_cache
        return {
            "total_searches": len(cache),
            "total_results": sum(len(r.results) for r in cache.values()),
            "cache_size": len(cache),
            "cached_queries": list(cache.keys())
        }

    def clear_cache(self):
        """Remove every cached search response."""
        self._search_cache.clear()
        logger.info("搜索缓存已清空")

    def test_connection(self) -> bool:
        """Check that the Tavily API is reachable with the configured key.

        Calls the client directly: ``search()`` swallows errors and caches,
        so it could neither report a failure nor stay side-effect free.

        Returns:
            True if a minimal search request succeeds, False otherwise.
        """
        try:
            self.client.search(query="test query", max_results=1)
        except Exception as e:
            logger.error(f"Tavily API连接测试失败: {e}")
            return False
        return True