import asyncio
import json
import os
from typing import Literal, Sequence, Optional, List, Union

import httpx

from .utils import get_max_items_from_list
from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError


class AsyncTavilyClient:
    """
    Async Tavily API client class.
    """

    def __init__(self, api_key: Optional[str] = None,
                 company_info_tags: Sequence[str] = ("news", "general", "finance"),
                 proxies: Optional[dict[str, str]] = None,
                 api_base_url: Optional[str] = None):
        if api_key is None:
            api_key = os.getenv("TAVILY_API_KEY")

        if not api_key:
            raise MissingAPIKeyError()

        proxies = proxies or {}

        mapped_proxies = {
            "http://": proxies.get("http", os.getenv("TAVILY_HTTP_PROXY")),
            "https://": proxies.get("https", os.getenv("TAVILY_HTTPS_PROXY")),
        }

        mapped_proxies = {key: value for key, value in mapped_proxies.items() if value}

        proxy_mounts = (
            {scheme: httpx.AsyncHTTPTransport(proxy=proxy) for scheme, proxy in mapped_proxies.items()}
            if mapped_proxies
            else None
        )

        self._api_base_url = api_base_url or "https://api.tavily.com"
        self._client_creator = lambda: httpx.AsyncClient(
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
                "X-Client-Source": "tavily-python"
            },
            base_url=self._api_base_url,
            mounts=proxy_mounts
        )
        self._company_info_tags = company_info_tags

    async def _search(
        self,
        query: str,
        search_depth: Literal["basic", "advanced"] = None,
        topic: Literal["general", "news", "finance"] = None,
        time_range: Literal["day", "week", "month", "year"] = None,
        start_date: str = None,
        end_date: str = None,
        days: int = None,
        max_results: int = None,
        include_domains: Sequence[str] = None,
        exclude_domains: Sequence[str] = None,
        include_answer: Union[bool, Literal["basic", "advanced"]] = None,
        include_raw_content: Union[bool, Literal["markdown", "text"]] = None,
        include_images: bool = None,
        timeout: int = 60,
        country: str = None,
        auto_parameters: bool = None,
        include_favicon: bool = None,
        **kwargs,
    ) -> dict:
        """
        Internal search method to send the request to the API.
        """
        data = {
            "query": query,
            "search_depth": search_depth,
            "topic": topic,
            "time_range": time_range,
            "start_date": start_date,
            "end_date": end_date,
            "days": days,
            "include_answer": include_answer,
            "include_raw_content": include_raw_content,
            "max_results": max_results,
            "include_domains": include_domains,
            "exclude_domains": exclude_domains,
            "include_images": include_images,
            "country": country,
            "auto_parameters": auto_parameters,
            "include_favicon": include_favicon,
        }

        data = {k: v for k, v in data.items() if v is not None}

        if kwargs:
            data.update(kwargs)

        timeout = min(timeout, 120)

        async with self._client_creator() as client:
            try:
                response = await client.post("/search", content=json.dumps(data), timeout=timeout)
            except httpx.TimeoutException:
                raise TimeoutError(timeout)

            if response.status_code == 200:
                return response.json()
            else:
                detail = ""
                try:
                    detail = response.json().get("detail", {}).get("error", None)
                except Exception:
                    pass

                if response.status_code == 429:
                    raise UsageLimitExceededError(detail)
                elif response.status_code in [403, 432, 433]:
                    raise ForbiddenError(detail)
                elif response.status_code == 401:
                    raise InvalidAPIKeyError(detail)
                elif response.status_code == 400:
                    raise BadRequestError(detail)
                else:
                    raise response.raise_for_status()

    async def search(self,
                     query: str,
                     search_depth: Literal["basic", "advanced"] = None,
                     topic: Literal["general", "news", "finance"] = None,
                     time_range: Literal["day", "week", "month", "year"] = None,
                     start_date: str = None,
                     end_date: str = None,
                     days: int = None,
                     max_results: int = None,
                     include_domains: Sequence[str] = None,
                     exclude_domains: Sequence[str] = None,
                     include_answer: Union[bool, Literal["basic", "advanced"]] = None,
                     include_raw_content: Union[bool, Literal["markdown", "text"]] = None,
                     include_images: bool = None,
                     timeout: int = 60,
                     country: str = None,
                     auto_parameters: bool = None,
                     include_favicon: bool = None,
                     **kwargs,  # Accept custom arguments
                     ) -> dict:
        """
        Combined search method. Set search_depth to either "basic" or "advanced".
        """
        timeout = min(timeout, 120)
        response_dict = await self._search(query,
                                           search_depth=search_depth,
                                           topic=topic,
                                           time_range=time_range,
                                           start_date=start_date,
                                           end_date=end_date,
                                           days=days,
                                           max_results=max_results,
                                           include_domains=include_domains,
                                           exclude_domains=exclude_domains,
                                           include_answer=include_answer,
                                           include_raw_content=include_raw_content,
                                           include_images=include_images,
                                           timeout=timeout,
                                           country=country,
                                           auto_parameters=auto_parameters,
                                           include_favicon=include_favicon,
                                           **kwargs,
                                           )

        tavily_results = response_dict.get("results", [])

        response_dict["results"] = tavily_results

        return response_dict
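
    # Illustrative usage sketch (an assumption, not part of the client): a typical call to
    # search(); the query string and parameter values below are placeholders.
    #
    #     client = AsyncTavilyClient()  # reads TAVILY_API_KEY from the environment
    #     response = await client.search("example query", search_depth="advanced", max_results=3)
    #     for result in response["results"]:
    #         print(result["title"], result["url"])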

    async def _extract(
        self,
        urls: Union[List[str], str],
        include_images: bool = None,
        extract_depth: Literal["basic", "advanced"] = None,
        format: Literal["markdown", "text"] = None,
        timeout: int = 30,
        include_favicon: bool = None,
        **kwargs
    ) -> dict:
        """
        Internal extract method to send the request to the API.

        include_favicon: If True, include the favicon in the extraction results.
        """
        data = {
            "urls": urls,
            "include_images": include_images,
            "extract_depth": extract_depth,
            "format": format,
            "timeout": timeout,
            "include_favicon": include_favicon,
        }

        data = {k: v for k, v in data.items() if v is not None}

        if kwargs:
            data.update(kwargs)

        timeout = min(timeout, 120)

        async with self._client_creator() as client:
            try:
                response = await client.post("/extract", content=json.dumps(data), timeout=timeout)
            except httpx.TimeoutException:
                raise TimeoutError(timeout)

            if response.status_code == 200:
                return response.json()
            else:
                detail = ""
                try:
                    detail = response.json().get("detail", {}).get("error", None)
                except Exception:
                    pass

                if response.status_code == 429:
                    raise UsageLimitExceededError(detail)
                elif response.status_code in [403, 432, 433]:
                    raise ForbiddenError(detail)
                elif response.status_code == 401:
                    raise InvalidAPIKeyError(detail)
                elif response.status_code == 400:
                    raise BadRequestError(detail)
                else:
                    raise response.raise_for_status()

    async def extract(self,
                      urls: Union[List[str], str],  # Accept a list of URLs or a single URL
                      include_images: bool = None,
                      extract_depth: Literal["basic", "advanced"] = None,
                      format: Literal["markdown", "text"] = None,
                      timeout: int = 30,
                      include_favicon: bool = None,
                      **kwargs,  # Accept custom arguments
                      ) -> dict:
        """
        Combined extract method.

        include_favicon: If True, include the favicon in the extraction results.
        """
        timeout = min(timeout, 120)
        response_dict = await self._extract(urls,
                                            include_images,
                                            extract_depth,
                                            format,
                                            timeout,
                                            include_favicon=include_favicon,
                                            **kwargs,
                                            )

        tavily_results = response_dict.get("results", [])
        failed_results = response_dict.get("failed_results", [])

        response_dict["results"] = tavily_results
        response_dict["failed_results"] = failed_results

        return response_dict
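
    # Illustrative usage sketch (an assumption, not library-provided): extract() accepts
    # either a single URL or a list; the URLs below are placeholders.
    #
    #     response = await client.extract(["https://example.com", "https://example.org"],
    #                                     extract_depth="advanced")
    #     print(response["results"], response["failed_results"])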

    async def _crawl(self,
                     url: str,
                     max_depth: int = None,
                     max_breadth: int = None,
                     limit: int = None,
                     instructions: str = None,
                     select_paths: Sequence[str] = None,
                     select_domains: Sequence[str] = None,
                     exclude_paths: Sequence[str] = None,
                     exclude_domains: Sequence[str] = None,
                     allow_external: bool = None,
                     include_images: bool = None,
                     extract_depth: Literal["basic", "advanced"] = None,
                     format: Literal["markdown", "text"] = None,
                     timeout: int = 60,
                     include_favicon: bool = None,
                     **kwargs
                     ) -> dict:
        """
        Internal crawl method to send the request to the API.
        """
        data = {
            "url": url,
            "max_depth": max_depth,
            "max_breadth": max_breadth,
            "limit": limit,
            "instructions": instructions,
            "select_paths": select_paths,
            "select_domains": select_domains,
            "exclude_paths": exclude_paths,
            "exclude_domains": exclude_domains,
            "allow_external": allow_external,
            "include_images": include_images,
            "extract_depth": extract_depth,
            "format": format,
            "include_favicon": include_favicon,
        }

        if kwargs:
            data.update(kwargs)

        data = {k: v for k, v in data.items() if v is not None}

        timeout = min(timeout, 120)

        async with self._client_creator() as client:
            try:
                response = await client.post("/crawl", content=json.dumps(data), timeout=timeout)
            except httpx.TimeoutException:
                raise TimeoutError(timeout)

            if response.status_code == 200:
                return response.json()
            else:
                detail = ""
                try:
                    detail = response.json().get("detail", {}).get("error", None)
                except Exception:
                    pass

                if response.status_code == 429:
                    raise UsageLimitExceededError(detail)
                elif response.status_code in [403, 432, 433]:
                    raise ForbiddenError(detail)
                elif response.status_code == 401:
                    raise InvalidAPIKeyError(detail)
                elif response.status_code == 400:
                    raise BadRequestError(detail)
                else:
                    raise response.raise_for_status()

    async def crawl(self,
                    url: str,
                    max_depth: int = None,
                    max_breadth: int = None,
                    limit: int = None,
                    instructions: str = None,
                    select_paths: Sequence[str] = None,
                    select_domains: Sequence[str] = None,
                    exclude_paths: Sequence[str] = None,
                    exclude_domains: Sequence[str] = None,
                    allow_external: bool = None,
                    extract_depth: Literal["basic", "advanced"] = None,
                    include_images: bool = None,
                    format: Literal["markdown", "text"] = None,
                    timeout: int = 60,
                    include_favicon: bool = None,
                    **kwargs
                    ) -> dict:
        """
        Combined crawl method.
        """
        timeout = min(timeout, 120)
        response_dict = await self._crawl(url,
                                          max_depth=max_depth,
                                          max_breadth=max_breadth,
                                          limit=limit,
                                          instructions=instructions,
                                          select_paths=select_paths,
                                          select_domains=select_domains,
                                          exclude_paths=exclude_paths,
                                          exclude_domains=exclude_domains,
                                          allow_external=allow_external,
                                          extract_depth=extract_depth,
                                          include_images=include_images,
                                          format=format,
                                          timeout=timeout,
                                          include_favicon=include_favicon,
                                          **kwargs)

        return response_dict
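
    # Illustrative usage sketch (an assumption): crawl() starts from a root URL and follows
    # links subject to the depth/breadth/limit arguments; the values below are placeholders.
    #
    #     response = await client.crawl("https://example.com", max_depth=2, limit=20,
    #                                   instructions="find documentation pages")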

    async def _map(self,
                   url: str,
                   max_depth: int = None,
                   max_breadth: int = None,
                   limit: int = None,
                   instructions: str = None,
                   select_paths: Sequence[str] = None,
                   select_domains: Sequence[str] = None,
                   exclude_paths: Sequence[str] = None,
                   exclude_domains: Sequence[str] = None,
                   allow_external: bool = None,
                   include_images: bool = None,
                   timeout: int = 60,
                   **kwargs
                   ) -> dict:
        """
        Internal map method to send the request to the API.
        """
        data = {
            "url": url,
            "max_depth": max_depth,
            "max_breadth": max_breadth,
            "limit": limit,
            "instructions": instructions,
            "select_paths": select_paths,
            "select_domains": select_domains,
            "exclude_paths": exclude_paths,
            "exclude_domains": exclude_domains,
            "allow_external": allow_external,
            "include_images": include_images,
        }

        if kwargs:
            data.update(kwargs)

        data = {k: v for k, v in data.items() if v is not None}

        timeout = min(timeout, 120)

        async with self._client_creator() as client:
            try:
                response = await client.post("/map", content=json.dumps(data), timeout=timeout)
            except httpx.TimeoutException:
                raise TimeoutError(timeout)

            if response.status_code == 200:
                return response.json()
            else:
                detail = ""
                try:
                    detail = response.json().get("detail", {}).get("error", None)
                except Exception:
                    pass

                if response.status_code == 429:
                    raise UsageLimitExceededError(detail)
                elif response.status_code in [403, 432, 433]:
                    raise ForbiddenError(detail)
                elif response.status_code == 401:
                    raise InvalidAPIKeyError(detail)
                elif response.status_code == 400:
                    raise BadRequestError(detail)
                else:
                    raise response.raise_for_status()

    async def map(self,
                  url: str,
                  max_depth: int = None,
                  max_breadth: int = None,
                  limit: int = None,
                  instructions: str = None,
                  select_paths: Sequence[str] = None,
                  select_domains: Sequence[str] = None,
                  exclude_paths: Sequence[str] = None,
                  exclude_domains: Sequence[str] = None,
                  allow_external: bool = None,
                  include_images: bool = None,
                  timeout: int = 60,
                  **kwargs
                  ) -> dict:
        """
        Combined map method.
        """
        timeout = min(timeout, 120)
        response_dict = await self._map(url,
                                        max_depth=max_depth,
                                        max_breadth=max_breadth,
                                        limit=limit,
                                        instructions=instructions,
                                        select_paths=select_paths,
                                        select_domains=select_domains,
                                        exclude_paths=exclude_paths,
                                        exclude_domains=exclude_domains,
                                        allow_external=allow_external,
                                        include_images=include_images,
                                        timeout=timeout,
                                        **kwargs)

        return response_dict
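
    # Illustrative usage sketch (an assumption): map() discovers site structure from a root
    # URL without extracting page content; the URL and limits below are placeholders.
    #
    #     response = await client.map("https://example.com", max_depth=1, limit=50)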

    async def get_search_context(self,
                                 query: str,
                                 search_depth: Literal["basic", "advanced"] = "basic",
                                 topic: Literal["general", "news", "finance"] = "general",
                                 days: int = 7,
                                 max_results: int = 5,
                                 include_domains: Sequence[str] = None,
                                 exclude_domains: Sequence[str] = None,
                                 max_tokens: int = 4000,
                                 timeout: int = 60,
                                 country: str = None,
                                 include_favicon: bool = None,
                                 **kwargs,  # Accept custom arguments
                                 ) -> str:
        """
        Get the search context for a query. Useful for getting only related content from retrieved websites
        without having to deal with context extraction and limitation yourself.

        max_tokens: The maximum number of tokens to return (based on OpenAI token counting). Defaults to 4000.

        Returns a JSON string containing the search context, up to the token limit.
        """
        timeout = min(timeout, 120)
        response_dict = await self._search(query,
                                           search_depth=search_depth,
                                           topic=topic,
                                           days=days,
                                           max_results=max_results,
                                           include_domains=include_domains,
                                           exclude_domains=exclude_domains,
                                           include_answer=False,
                                           include_raw_content=False,
                                           include_images=False,
                                           timeout=timeout,
                                           country=country,
                                           include_favicon=include_favicon,
                                           **kwargs,
                                           )
        sources = response_dict.get("results", [])
        context = [{"url": source["url"], "content": source["content"]} for source in sources]
        return json.dumps(get_max_items_from_list(context, max_tokens))
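
    # Illustrative usage sketch (an assumption): get_search_context() returns a JSON string
    # of {"url", "content"} pairs trimmed to max_tokens; the query below is a placeholder.
    #
    #     context = await client.get_search_context("example query", max_tokens=2000)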

    async def qna_search(self,
                         query: str,
                         search_depth: Literal["basic", "advanced"] = "advanced",
                         topic: Literal["general", "news", "finance"] = "general",
                         days: int = 7,
                         max_results: int = 5,
                         include_domains: Sequence[str] = None,
                         exclude_domains: Sequence[str] = None,
                         timeout: int = 60,
                         country: str = None,
                         include_favicon: bool = None,
                         **kwargs,  # Accept custom arguments
                         ) -> str:
        """
        Q&A search method. Search depth is advanced by default to get the best answer.
        """
        timeout = min(timeout, 120)
        response_dict = await self._search(query,
                                           search_depth=search_depth,
                                           topic=topic,
                                           days=days,
                                           max_results=max_results,
                                           include_domains=include_domains,
                                           exclude_domains=exclude_domains,
                                           include_raw_content=False,
                                           include_images=False,
                                           include_answer=True,
                                           timeout=timeout,
                                           country=country,
                                           include_favicon=include_favicon,
                                           **kwargs,
                                           )
        return response_dict.get("answer", "")

    async def get_company_info(self,
                               query: str,
                               search_depth: Literal["basic", "advanced"] = "advanced",
                               max_results: int = 5,
                               timeout: int = 60,
                               country: str = None,
                               ) -> Sequence[dict]:
        """
        Company information search method. Search depth is advanced by default to get the best answer.
        """
        timeout = min(timeout, 120)

        async def _perform_search(topic: str):
            return await self._search(query,
                                      search_depth=search_depth,
                                      topic=topic,
                                      max_results=max_results,
                                      include_answer=False,
                                      timeout=timeout,
                                      country=country)

        all_results = []
        for data in await asyncio.gather(*[_perform_search(topic) for topic in self._company_info_tags]):
            if "results" in data:
                all_results.extend(data["results"])

        # Sort all the results by score in descending order and take the top 'max_results' items
        sorted_results = sorted(all_results, key=lambda x: x["score"], reverse=True)[:max_results]

        return sorted_results
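

# Minimal end-to-end usage sketch (illustrative only, not part of the client itself): it
# assumes TAVILY_API_KEY is set in the environment, and the query and URL below are
# placeholder values chosen for demonstration.
if __name__ == "__main__":
    async def _demo():
        client = AsyncTavilyClient()
        # Run a basic search and print the top result titles.
        search_response = await client.search("example query", max_results=3)
        for result in search_response["results"]:
            print(result["title"], result["url"])
        # Extract page content from a single URL as markdown.
        extract_response = await client.extract("https://example.com", format="markdown")
        print(extract_response["results"])

    asyncio.run(_demo())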