import requests import json import warnings import os from typing import Literal, Sequence, Optional, List, Union from concurrent.futures import ThreadPoolExecutor, as_completed from .utils import get_max_items_from_list from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError class TavilyClient: """ Tavily API client class. """ def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, str]] = None, api_base_url: Optional[str] = None): if api_key is None: api_key = os.getenv("TAVILY_API_KEY") if not api_key: raise MissingAPIKeyError() resolved_proxies = { "http": proxies.get("http") if proxies else os.getenv("TAVILY_HTTP_PROXY"), "https": proxies.get("https") if proxies else os.getenv("TAVILY_HTTPS_PROXY"), } resolved_proxies = {k: v for k, v in resolved_proxies.items() if v} or None self.base_url = api_base_url or "https://api.tavily.com" self.api_key = api_key self.proxies = resolved_proxies self.headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", "X-Client-Source": "tavily-python" } def _search(self, query: str, search_depth: Literal["basic", "advanced"] = None, topic: Literal["general", "news", "finance"] = None, time_range: Literal["day", "week", "month", "year"] = None, start_date: str = None, end_date: str = None, days: int = None, max_results: int = None, include_domains: Sequence[str] = None, exclude_domains: Sequence[str] = None, include_answer: Union[bool, Literal["basic", "advanced"]] = None, include_raw_content: Union[bool, Literal["markdown", "text"]] = None, include_images: bool = None, timeout: int = 60, country: str = None, auto_parameters: bool = None, include_favicon: bool = None, **kwargs ) -> dict: """ Internal search method to send the request to the API. """ data = { "query": query, "search_depth": search_depth, "topic": topic, "time_range": time_range, "start_date": start_date, "end_date": end_date, "days": days, "include_answer": include_answer, "include_raw_content": include_raw_content, "max_results": max_results, "include_domains": include_domains, "exclude_domains": exclude_domains, "include_images": include_images, "country": country, "auto_parameters": auto_parameters, "include_favicon": include_favicon, } data = {k: v for k, v in data.items() if v is not None} if kwargs: data.update(kwargs) timeout = min(timeout, 120) try: response = requests.post(self.base_url + "/search", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies) except requests.exceptions.Timeout: raise TimeoutError(timeout) if response.status_code == 200: return response.json() else: detail = "" try: detail = response.json().get("detail", {}).get("error", None) except Exception: pass if response.status_code == 429: raise UsageLimitExceededError(detail) elif response.status_code in [403,432,433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) elif response.status_code == 400: raise BadRequestError(detail) else: raise response.raise_for_status() def search(self, query: str, search_depth: Literal["basic", "advanced"] = None, topic: Literal["general", "news", "finance" ] = None, time_range: Literal["day", "week", "month", "year"] = None, start_date: str = None, end_date: str = None, days: int = None, max_results: int = None, include_domains: Sequence[str] = None, exclude_domains: Sequence[str] = None, include_answer: Union[bool, Literal["basic", "advanced"]] = None, include_raw_content: Union[bool, Literal["markdown", "text"]] = None, include_images: bool = None, timeout: int = 60, country: str = None, auto_parameters: bool = None, include_favicon: bool = None, **kwargs, # Accept custom arguments ) -> dict: """ Combined search method. """ timeout = min(timeout, 120) response_dict = self._search(query, search_depth=search_depth, topic=topic, time_range=time_range, start_date=start_date, end_date=end_date, days=days, max_results=max_results, include_domains=include_domains, exclude_domains=exclude_domains, include_answer=include_answer, include_raw_content=include_raw_content, include_images=include_images, timeout=timeout, country=country, auto_parameters=auto_parameters, include_favicon=include_favicon, **kwargs, ) tavily_results = response_dict.get("results", []) response_dict["results"] = tavily_results return response_dict def _extract(self, urls: Union[List[str], str], include_images: bool = None, extract_depth: Literal["basic", "advanced"] = None, format: Literal["markdown", "text"] = None, timeout: int = 30, include_favicon: bool = None, **kwargs ) -> dict: """ Internal extract method to send the request to the API. """ data = { "urls": urls, "include_images": include_images, "extract_depth": extract_depth, "format": format, "timeout": timeout, "include_favicon": include_favicon, } data = {k: v for k, v in data.items() if v is not None} if kwargs: data.update(kwargs) timeout = min(timeout, 120) try: response = requests.post(self.base_url + "/extract", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies) except requests.exceptions.Timeout: raise TimeoutError(timeout) if response.status_code == 200: return response.json() else: detail = "" try: detail = response.json().get("detail", {}).get("error", None) except Exception: pass if response.status_code == 429: raise UsageLimitExceededError(detail) elif response.status_code in [403,432,433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) elif response.status_code == 400: raise BadRequestError(detail) else: raise response.raise_for_status() def extract(self, urls: Union[List[str], str], # Accept a list of URLs or a single URL include_images: bool = None, extract_depth: Literal["basic", "advanced"] = None, format: Literal["markdown", "text"] = None, timeout: int = 30, include_favicon: bool = None, **kwargs, # Accept custom arguments ) -> dict: """ Combined extract method. """ timeout = min(timeout, 120) response_dict = self._extract(urls, include_images, extract_depth, format, timeout, include_favicon=include_favicon, **kwargs) tavily_results = response_dict.get("results", []) failed_results = response_dict.get("failed_results", []) response_dict["results"] = tavily_results response_dict["failed_results"] = failed_results return response_dict def _crawl(self, url: str, max_depth: int = None, max_breadth: int = None, limit: int = None, instructions: str = None, select_paths: Sequence[str] = None, select_domains: Sequence[str] = None, exclude_paths: Sequence[str] = None, exclude_domains: Sequence[str] = None, allow_external: bool = None, include_images: bool = None, extract_depth: Literal["basic", "advanced"] = None, format: Literal["markdown", "text"] = None, timeout: int = 60, include_favicon: bool = None, **kwargs ) -> dict: """ Internal crawl method to send the request to the API. include_favicon: If True, include the favicon in the crawl results. """ data = { "url": url, "max_depth": max_depth, "max_breadth": max_breadth, "limit": limit, "instructions": instructions, "select_paths": select_paths, "select_domains": select_domains, "exclude_paths": exclude_paths, "exclude_domains": exclude_domains, "allow_external": allow_external, "include_images": include_images, "extract_depth": extract_depth, "format": format, "include_favicon": include_favicon, } if kwargs: data.update(kwargs) data = {k: v for k, v in data.items() if v is not None} timeout = min(timeout, 120) try: response = requests.post( self.base_url + "/crawl", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies) except requests.exceptions.Timeout: raise TimeoutError(timeout) if response.status_code == 200: return response.json() else: detail = "" try: detail = response.json().get("detail", {}).get("error", None) except Exception: pass if response.status_code == 429: raise UsageLimitExceededError(detail) elif response.status_code in [403,432,433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) elif response.status_code == 400: raise BadRequestError(detail) else: raise response.raise_for_status() def crawl(self, url: str, max_depth: int = None, max_breadth: int = None, limit: int = None, instructions: str = None, select_paths: Sequence[str] = None, select_domains: Sequence[str] = None, exclude_paths: Sequence[str] = None, exclude_domains: Sequence[str] = None, allow_external: bool = None, include_images: bool = None, extract_depth: Literal["basic", "advanced"] = None, format: Literal["markdown", "text"] = None, timeout: int = 60, include_favicon: bool = None, **kwargs ) -> dict: """ Combined crawl method. include_favicon: If True, include the favicon in the crawl results. """ timeout = min(timeout, 120) response_dict = self._crawl(url, max_depth=max_depth, max_breadth=max_breadth, limit=limit, instructions=instructions, select_paths=select_paths, select_domains=select_domains, exclude_paths=exclude_paths, exclude_domains=exclude_domains, allow_external=allow_external, include_images=include_images, extract_depth=extract_depth, format=format, timeout=timeout, include_favicon=include_favicon, **kwargs) return response_dict def _map(self, url: str, max_depth: int = None, max_breadth: int = None, limit: int = None, instructions: str = None, select_paths: Sequence[str] = None, select_domains: Sequence[str] = None, exclude_paths: Sequence[str] = None, exclude_domains: Sequence[str] = None, allow_external: bool = None, include_images: bool = None, timeout: int = 60, **kwargs ) -> dict: """ Internal map method to send the request to the API. """ data = { "url": url, "max_depth": max_depth, "max_breadth": max_breadth, "limit": limit, "instructions": instructions, "select_paths": select_paths, "select_domains": select_domains, "exclude_paths": exclude_paths, "exclude_domains": exclude_domains, "allow_external": allow_external, "include_images": include_images, } if kwargs: data.update(kwargs) data = {k: v for k, v in data.items() if v is not None} timeout = min(timeout, 120) try: response = requests.post( self.base_url + "/map", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies) except requests.exceptions.Timeout: raise TimeoutError(timeout) if response.status_code == 200: return response.json() else: detail = "" try: detail = response.json().get("detail", {}).get("error", None) except Exception: pass if response.status_code == 429: raise UsageLimitExceededError(detail) elif response.status_code in [403,432,433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) elif response.status_code == 400: raise BadRequestError(detail) else: raise response.raise_for_status() def map(self, url: str, max_depth: int = None, max_breadth: int = None, limit: int = None, instructions: str = None, select_paths: Sequence[str] = None, select_domains: Sequence[str] = None, exclude_paths: Sequence[str] = None, exclude_domains: Sequence[str] = None, allow_external: bool = None, include_images: bool = None, timeout: int = 60, **kwargs ) -> dict: """ Combined map method. """ timeout = min(timeout, 120) response_dict = self._map(url, max_depth=max_depth, max_breadth=max_breadth, limit=limit, instructions=instructions, select_paths=select_paths, select_domains=select_domains, exclude_paths=exclude_paths, exclude_domains=exclude_domains, allow_external=allow_external, include_images=include_images, timeout=timeout, **kwargs) return response_dict def get_search_context(self, query: str, search_depth: Literal["basic", "advanced"] = "basic", topic: Literal["general", "news", "finance"] = "general", days: int = 7, max_results: int = 5, include_domains: Sequence[str] = None, exclude_domains: Sequence[str] = None, max_tokens: int = 4000, timeout: int = 60, country: str = None, include_favicon: bool = None, **kwargs, # Accept custom arguments ) -> str: """ Get the search context for a query. Useful for getting only related content from retrieved websites without having to deal with context extraction and limitation yourself. max_tokens: The maximum number of tokens to return (based on openai token compute). Defaults to 4000. Returns a string of JSON containing the search context up to context limit. """ timeout = min(timeout, 120) response_dict = self._search(query, search_depth=search_depth, topic=topic, days=days, max_results=max_results, include_domains=include_domains, exclude_domains=exclude_domains, include_answer=False, include_raw_content=False, include_images=False, timeout=timeout, country=country, include_favicon=include_favicon, **kwargs, ) sources = response_dict.get("results", []) context = [{"url": source["url"], "content": source["content"]} for source in sources] return json.dumps(get_max_items_from_list(context, max_tokens)) def qna_search(self, query: str, search_depth: Literal["basic", "advanced"] = "advanced", topic: Literal["general", "news", "finance"] = "general", days: int = 7, max_results: int = 5, include_domains: Sequence[str] = None, exclude_domains: Sequence[str] = None, timeout: int = 60, country: str = None, include_favicon: bool = None, **kwargs, # Accept custom arguments ) -> str: """ Q&A search method. Search depth is advanced by default to get the best answer. """ timeout = min(timeout, 120) response_dict = self._search(query, search_depth=search_depth, topic=topic, days=days, max_results=max_results, include_domains=include_domains, exclude_domains=exclude_domains, include_raw_content=False, include_images=False, include_answer=True, timeout=timeout, country=country, include_favicon=include_favicon, **kwargs, ) return response_dict.get("answer", "") def get_company_info(self, query: str, search_depth: Literal["basic", "advanced"] = "advanced", max_results: int = 5, timeout: int = 60, country: str = None, ) -> Sequence[dict]: """ Company information search method. Search depth is advanced by default to get the best answer. """ timeout = min(timeout, 120) def _perform_search(topic): return self._search(query, search_depth=search_depth, topic=topic, max_results=max_results, include_answer=False, timeout=timeout, country=country) with ThreadPoolExecutor() as executor: # Initiate the search for each topic in parallel future_to_topic = {executor.submit(_perform_search, topic): topic for topic in ["news", "general", "finance"]} all_results = [] # Process the results as they become available for future in as_completed(future_to_topic): data = future.result() if 'results' in data: all_results.extend(data['results']) # Sort all the results by score in descending order and take the top 'max_results' items sorted_results = sorted(all_results, key=lambda x: x['score'], reverse=True)[ :max_results] return sorted_results class Client(TavilyClient): """ Tavily API client class. WARNING! This class is deprecated. Please use TavilyClient instead. """ def __init__(self, kwargs): warnings.warn("Client is deprecated, please use TavilyClient instead", DeprecationWarning, stacklevel=2) super().__init__(kwargs)