deepresearch/app/agents/v3_agent.py

"""
DeepSeek V3模型智能体 - 带调试功能
负责API调用、内容重写等执行型任务
"""
import json
import logging
from typing import Dict, List, Any, Optional
from openai import OpenAI
from config import Config
from app.agents.prompts import get_prompt
from app.utils.debug_logger import ai_debug_logger
logger = logging.getLogger(__name__)


class V3Agent:
    """V3 model agent."""

    def __init__(self, api_key: str = None):
        self.api_key = api_key or Config.DEEPSEEK_API_KEY
        base_url = Config.DEEPSEEK_BASE_URL
        # The Volcengine ARK platform uses a different model name
        if 'volces.com' in base_url:
            self.model = "deepseek-v3-241226"  # V3 model name on Volcengine ARK
        else:
            self.model = Config.V3_MODEL
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=base_url
        )

    def _call_api(self, prompt: str, temperature: float = 0.3,
                  max_tokens: int = 4096, functions: List[Dict] = None) -> Any:
        """Call the V3 API."""
        try:
            messages = [{"role": "user", "content": prompt}]
            kwargs = {
                "model": self.model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            # If functions are provided, enable function calling
            if functions:
                kwargs["functions"] = functions
                kwargs["function_call"] = "auto"

            logger.info(f"Calling V3 API: temperature={temperature}, max_tokens={max_tokens}, functions={bool(functions)}")
            response = self.client.chat.completions.create(**kwargs)

            # Prepare the response content
            if functions and response.choices[0].message.function_call:
                result = {
                    "function_call": {
                        "name": response.choices[0].message.function_call.name,
                        "arguments": json.loads(response.choices[0].message.function_call.arguments)
                    }
                }
                response_text = json.dumps(result, ensure_ascii=False)
            else:
                result = response.choices[0].message.content.strip()
                response_text = result

            # Write a debug log entry
            ai_debug_logger.log_api_call(
                model=self.model,
                agent_type="V3",
                method=self._get_caller_method(),
                prompt=prompt,
                response=response_text,
                temperature=temperature,
                max_tokens=max_tokens,
                metadata={
                    "has_functions": bool(functions),
                    "function_count": len(functions) if functions else 0,
                    "prompt_tokens": response.usage.prompt_tokens if hasattr(response, 'usage') else None,
                    "completion_tokens": response.usage.completion_tokens if hasattr(response, 'usage') else None
                }
            )
            return result
        except Exception as e:
            logger.error(f"V3 API call failed: {e}")
            ai_debug_logger.log_api_call(
                model=self.model,
                agent_type="V3",
                method=self._get_caller_method(),
                prompt=prompt,
                response=f"ERROR: {str(e)}",
                temperature=temperature,
                max_tokens=max_tokens,
                metadata={"error": str(e)}
            )
            raise

    def _get_caller_method(self) -> str:
        """Return the name of the calling method."""
        import inspect
        frame = inspect.currentframe()
        if frame and frame.f_back and frame.f_back.f_back:
            return frame.f_back.f_back.f_code.co_name
        return "unknown"

    def generate_search_queries(self, subtopic: str, explanation: str,
                                related_questions: List[str], count: int) -> List[str]:
        """Generate search queries."""
        prompt = get_prompt("generate_search_queries",
                            subtopic=subtopic,
                            explanation=explanation,
                            related_questions=', '.join(related_questions),
                            count=count)
        result = self._call_api(prompt, temperature=0.7)
        # Parse the result into a list
        queries = [q.strip() for q in result.split('\n') if q.strip()]
        # Strip any leading numbering
        queries = [q.lstrip('0123456789.-) ') for q in queries]
        # Log the parsed queries
        logger.debug(f"Generated {len(queries)} search queries")
        return queries[:count]

    def generate_refined_queries(self, key_info: str, detail_needed: str) -> List[str]:
        """Generate refined search queries."""
        prompt = get_prompt("generate_refined_queries",
                            key_info=key_info,
                            detail_needed=detail_needed)
        result = self._call_api(prompt, temperature=0.7)
        queries = [q.strip() for q in result.split('\n') if q.strip()]
        queries = [q.lstrip('0123456789.-) ') for q in queries]
        logger.debug(f"Generated {len(queries)} refined queries for '{key_info}'")
        return queries[:3]

    def rewrite_hallucination(self, hallucinated_content: str,
                              original_sources: str) -> str:
        """Rewrite hallucinated content."""
        prompt = get_prompt("rewrite_hallucination",
                            hallucinated_content=hallucinated_content,
                            original_sources=original_sources)
        rewritten = self._call_api(prompt, temperature=0.3)
        # Log the hallucination fix
        logger.info(f"Rewrote hallucinated content: original length={len(hallucinated_content)}, rewritten length={len(rewritten)}")
        return rewritten

    def call_tavily_search(self, query: str, max_results: int = 10) -> Dict[str, Any]:
        """
        Call the Tavily search API via function calling.
        Note: this is an example implementation; the actual Tavily call lives in search_service.py.
        """
        # Define the Tavily search function schema
        tavily_function = {
            "name": "tavily_search",
            "description": "Search the web using Tavily API",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query"
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Maximum number of results",
                        "default": 10
                    },
                    "search_depth": {
                        "type": "string",
                        "enum": ["basic", "advanced"],
                        "default": "advanced"
                    }
                },
                "required": ["query"]
            }
        }
        prompt = f"Please search for information about: {query}"
        result = self._call_api(prompt, functions=[tavily_function])
        # If a function call came back, extract its arguments
        if isinstance(result, dict) and "function_call" in result:
            return result["function_call"]["arguments"]
        # Otherwise fall back to default parameters
        return {
            "query": query,
            "max_results": max_results,
            "search_depth": "advanced"
        }
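
    # How the returned parameters might be consumed downstream (a sketch, not the
    # actual search_service.py implementation, which is not shown here). With the
    # tavily-python SDK one would typically do something like:
    #
    #     from tavily import TavilyClient
    #     client = TavilyClient(api_key=Config.TAVILY_API_KEY)  # hypothetical config key
    #     params = agent.call_tavily_search("quantum error correction")
    #     results = client.search(**params)  # query / max_results / search_depth
    #
    # The parameter names mirror the function schema above; everything else is an
    # assumption about how search_service.py wires things together.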

    def format_search_results(self, results: List[Dict[str, Any]]) -> str:
        """Format search results as structured text."""
        formatted = []
        for i, result in enumerate(results, 1):
            formatted.append(f"{i}. Title: {result.get('title', 'N/A')}")
            formatted.append(f"   URL: {result.get('url', 'N/A')}")
            formatted.append(f"   Snippet: {result.get('snippet', 'N/A')}")
            if result.get('score'):
                formatted.append(f"   Relevance: {result.get('score', 0):.2f}")
            formatted.append("")
        return '\n'.join(formatted)

    def extract_key_points(self, text: str, max_points: int = 5) -> List[str]:
        """Extract key points from a piece of text."""
        prompt = f"""
        Please extract at most {max_points} key points from the following text:
        {text}
        Put each key point on its own line, concise and clear.
        """
        result = self._call_api(prompt, temperature=0.5)
        points = [p.strip() for p in result.split('\n') if p.strip()]
        points = [p.lstrip('0123456789.-) ') for p in points]
        logger.debug(f"Extracted {len(points)} key points from the text")
        return points[:max_points]

    def summarize_content(self, content: str, max_length: int = 200) -> str:
        """Summarize content."""
        prompt = f"""
        Please summarize the following content in no more than {max_length} characters:
        {content}
        Requirements: keep the key information and keep the wording fluent.
        """
        summary = self._call_api(prompt, temperature=0.5)
        logger.debug(f"Content summary: original length={len(content)}, summary length={len(summary)}")
        return summary
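

# --- Usage sketch (not part of the original module) ---
# A minimal example of how V3Agent might be driven, assuming Config provides
# DEEPSEEK_API_KEY / DEEPSEEK_BASE_URL / V3_MODEL and that the prompt templates
# referenced above are registered. The sample inputs below are hypothetical.
if __name__ == "__main__":
    agent = V3Agent()

    # Generate a handful of search queries for a sub-topic.
    queries = agent.generate_search_queries(
        subtopic="vector databases",
        explanation="How vector databases index and retrieve embeddings",
        related_questions=["What is HNSW?", "How does IVF-PQ work?"],
        count=3
    )
    print(queries)

    # Ask the model for Tavily search parameters via function calling,
    # then format some (hypothetical) search results for downstream prompts.
    params = agent.call_tavily_search("HNSW index construction", max_results=5)
    print(params)

    sample_results = [
        {"title": "HNSW explained", "url": "https://example.com/hnsw",
         "snippet": "Hierarchical Navigable Small World graphs...", "score": 0.92}
    ]
    print(agent.format_search_results(sample_results))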