deepresearch/所有文件/v3_agent.py
2025-07-02 15:35:36 +08:00

194 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
DeepSeek V3模型智能体
负责API调用、内容重写等执行型任务
"""
import json
import logging
from typing import Dict, List, Any, Optional
from openai import OpenAI
from config import Config
from app.agents.prompts import get_prompt
logger = logging.getLogger(__name__)
class V3Agent:
    """Agent wrapping the DeepSeek V3 chat model.

    Thin layer over an OpenAI-compatible chat-completions client used for
    execution-style tasks: generating search queries, rewriting hallucinated
    content, extracting key points and summarising text.
    """

    # Model name used when the configured base URL points at Volcengine ARK,
    # which exposes the V3 model under a different identifier.
    VOLCENGINE_V3_MODEL = "deepseek-v3-241226"

    def __init__(self, api_key: Optional[str] = None):
        """Create the client.

        Args:
            api_key: API key to use; falls back to ``Config.DEEPSEEK_API_KEY``
                when omitted or empty.
        """
        self.api_key = api_key or Config.DEEPSEEK_API_KEY
        base_url = Config.DEEPSEEK_BASE_URL
        # Guard against an unset base URL before the substring test.
        if base_url and 'volces.com' in base_url:
            self.model = self.VOLCENGINE_V3_MODEL
        else:
            self.model = Config.V3_MODEL
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=base_url,
        )

    @staticmethod
    def _parse_list_response(text: Any) -> List[str]:
        """Split a model reply into clean, non-empty list items.

        One leading list marker per line is removed ("1.", "2)", "(3)",
        "4、", "- ", "* ", "• "). Digits that belong to the item itself are
        kept, so "1. 2024 AI trends" yields "2024 AI trends" and
        "3.5 billion market" is left untouched.

        Args:
            text: Raw reply; anything that is not a string yields ``[]``.

        Returns:
            The cleaned items in their original order, without empty entries.
        """
        import re  # local import keeps this block self-contained

        if not isinstance(text, str):
            return []
        # A numbered marker must not be followed by another digit (that would
        # be a decimal such as "3.5"); a bullet must be followed by whitespace.
        marker = re.compile(r'^(?:\(?\d+[.)、](?!\d)|[-*•]\s)\s*')
        items = []
        for raw_line in text.split('\n'):
            cleaned = marker.sub('', raw_line.strip(), count=1).strip()
            if cleaned:
                items.append(cleaned)
        return items

    def _call_api(self, prompt: str, temperature: float = 0.3,
                  max_tokens: int = 4096,
                  functions: Optional[List[Dict]] = None) -> Any:
        """Send a single-turn user prompt to the V3 chat-completions API.

        Args:
            prompt: Text sent as the only user message.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            functions: Optional function definitions; when given, function
                calling is enabled with ``function_call="auto"``.

        Returns:
            ``{"function_call": {"name": ..., "arguments": <parsed JSON>}}``
            when functions were supplied and the model chose to call one;
            otherwise the stripped reply text (``""`` if the reply carried
            no content).

        Raises:
            Exception: Any error from the client or from decoding the
                function-call arguments is logged and re-raised unchanged.
        """
        try:
            kwargs = {
                "model": self.model,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": temperature,
                "max_tokens": max_tokens,
            }
            # NOTE(review): `functions` / `function_call` are the legacy
            # function-calling parameters; consider `tools` / `tool_choice`
            # once the target backend is confirmed to support them.
            if functions:
                kwargs["functions"] = functions
                kwargs["function_call"] = "auto"
            response = self.client.chat.completions.create(**kwargs)
            message = response.choices[0].message
            function_call = getattr(message, "function_call", None)
            if functions and function_call:
                return {
                    "function_call": {
                        "name": function_call.name,
                        "arguments": json.loads(function_call.arguments),
                    }
                }
            content = message.content
            if content is None:
                # The API may return no text; avoid AttributeError on None.
                logger.warning("V3 API returned an empty message content")
                return ""
            return content.strip()
        except Exception as e:
            logger.exception("V3 API调用失败: %s", e)
            raise

    def generate_search_queries(self, subtopic: str, explanation: str,
                                related_questions: List[str], count: int) -> List[str]:
        """Generate up to ``count`` search queries for a subtopic.

        Args:
            subtopic: Subtopic passed to the prompt template.
            explanation: Explanation passed to the prompt template.
            related_questions: Questions joined with ", " for the template;
                ``None`` is treated as an empty list.
            count: Maximum number of queries to return.

        Returns:
            The parsed queries, one per reply line, truncated to ``count``.
        """
        prompt = get_prompt("generate_search_queries",
                            subtopic=subtopic,
                            explanation=explanation,
                            related_questions=', '.join(related_questions or []),
                            count=count)
        result = self._call_api(prompt, temperature=0.7)
        return self._parse_list_response(result)[:count]

    def generate_refined_queries(self, key_info: str, detail_needed: str,
                                 count: int = 3) -> List[str]:
        """Generate refined follow-up search queries.

        Args:
            key_info: Key information passed to the prompt template.
            detail_needed: Description of the missing detail for the template.
            count: Maximum number of queries to return (defaults to 3).

        Returns:
            The parsed queries, one per reply line, truncated to ``count``.
        """
        prompt = get_prompt("generate_refined_queries",
                            key_info=key_info,
                            detail_needed=detail_needed)
        result = self._call_api(prompt, temperature=0.7)
        return self._parse_list_response(result)[:count]

    def rewrite_hallucination(self, hallucinated_content: str,
                              original_sources: str) -> str:
        """Ask the model to rewrite hallucinated content using the sources.

        Args:
            hallucinated_content: The text to rewrite.
            original_sources: Source material passed to the prompt template.

        Returns:
            The model's reply text.
        """
        prompt = get_prompt("rewrite_hallucination",
                            hallucinated_content=hallucinated_content,
                            original_sources=original_sources)
        return self._call_api(prompt, temperature=0.3)

    def call_tavily_search(self, query: str, max_results: int = 10) -> Dict[str, Any]:
        """Build Tavily search parameters via function calling.

        Note: this is a sample implementation; the actual Tavily call lives
        in search_service.py.

        Args:
            query: The search query.
            max_results: Result limit used when the model does not supply one.

        Returns:
            A parameter dict that always has ``query``, ``max_results`` and
            ``search_depth``. Values chosen by the model override the
            defaults derived from the arguments.
        """
        tavily_function = {
            "name": "tavily_search",
            "description": "Search the web using Tavily API",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query"
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Maximum number of results",
                        # Keep the advertised default in step with the argument.
                        "default": max_results
                    },
                    "search_depth": {
                        "type": "string",
                        "enum": ["basic", "advanced"],
                        "default": "advanced"
                    }
                },
                "required": ["query"]
            }
        }
        defaults = {
            "query": query,
            "max_results": max_results,
            "search_depth": "advanced",
        }
        prompt = f"Please search for information about: {query}"
        result = self._call_api(prompt, functions=[tavily_function])
        if isinstance(result, dict) and "function_call" in result:
            arguments = result["function_call"].get("arguments")
            if isinstance(arguments, dict):
                # Fill any keys the model left out with the caller's values.
                return {**defaults, **arguments}
        return defaults

    def format_search_results(self, results: List[Dict[str, Any]]) -> str:
        """Format search results as numbered, structured text.

        Each result contributes title, URL and snippet lines ('N/A' when a
        key is missing), a relevance line when ``score`` is truthy, and a
        trailing blank line.

        Args:
            results: Result dicts with optional ``title``, ``url``,
                ``snippet`` and ``score`` keys.

        Returns:
            The lines joined with newlines; ``""`` for an empty list.
        """
        lines: List[str] = []
        for index, item in enumerate(results, 1):
            lines.append(f"{index}. 标题: {item.get('title', 'N/A')}")
            lines.append(f" URL: {item.get('url', 'N/A')}")
            lines.append(f" 摘要: {item.get('snippet', 'N/A')}")
            score = item.get('score')
            if score:
                lines.append(f" 相关度: {score:.2f}")
            lines.append("")
        return '\n'.join(lines)

    def extract_key_points(self, text: str, max_points: int = 5) -> List[str]:
        """Extract up to ``max_points`` key points from ``text``.

        Args:
            text: The text to analyse.
            max_points: Maximum number of key points to return.

        Returns:
            The key points, one per reply line, with list markers removed.
        """
        prompt = (
            "\n"
            f"请从以下文本中提取最多{max_points}个关键点:\n"
            f"{text}\n"
            "每个关键点独占一行,简洁明了。\n"
        )
        result = self._call_api(prompt, temperature=0.5)
        return self._parse_list_response(result)[:max_points]

    def summarize_content(self, content: str, max_length: int = 200) -> str:
        """Summarise ``content``.

        Args:
            content: The text to summarise.
            max_length: Character limit stated in the prompt; the reply is
                not truncated here.

        Returns:
            The model's reply text.
        """
        prompt = (
            "\n"
            f"请将以下内容总结为不超过{max_length}字的摘要:\n"
            f"{content}\n"
            "要求:保留关键信息,语言流畅。\n"
        )
        return self._call_api(prompt, temperature=0.5)