"""
|
||
DeepSeek V3模型智能体
|
||
负责API调用、内容重写等执行型任务
|
||
"""
|
||
import json
|
||
import logging
|
||
from typing import Dict, List, Any, Optional
|
||
from openai import OpenAI
|
||
from config import Config
|
||
from app.agents.prompts import get_prompt
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class V3Agent:
    """V3 model agent."""

    def __init__(self, api_key: str = None):
        self.api_key = api_key or Config.DEEPSEEK_API_KEY
        base_url = Config.DEEPSEEK_BASE_URL

        # The Volcano Engine ARK platform uses a different model name
        if 'volces.com' in base_url:
            self.model = "deepseek-v3-241226"  # V3 model name on Volcano Engine
        else:
            self.model = Config.V3_MODEL

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=base_url
        )

    def _call_api(self, prompt: str, temperature: float = 0.3,
                  max_tokens: int = 4096, functions: List[Dict] = None) -> Any:
        """Call the V3 API."""
        try:
            messages = [{"role": "user", "content": prompt}]

            kwargs = {
                "model": self.model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }

            # If function definitions are provided, enable function calling
            if functions:
                kwargs["functions"] = functions
                kwargs["function_call"] = "auto"

            response = self.client.chat.completions.create(**kwargs)

            # Check whether the model returned a function call
            if functions and response.choices[0].message.function_call:
                return {
                    "function_call": {
                        "name": response.choices[0].message.function_call.name,
                        "arguments": json.loads(response.choices[0].message.function_call.arguments)
                    }
                }

            return response.choices[0].message.content.strip()

        except Exception as e:
            logger.error(f"V3 API call failed: {e}")
            raise
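
    # --- Editor's sketch (not part of the original module) ---------------------
    # Newer OpenAI-compatible SDKs expose function calling through the
    # "tools"/"tool_choice" parameters rather than the legacy
    # "functions"/"function_call" fields used in _call_api above. Whether the
    # configured DeepSeek / Volcano ARK endpoint accepts the newer form is an
    # assumption; the helper below is a hypothetical illustration only.
    def _call_api_with_tools(self, prompt: str, tools: List[Dict],
                             temperature: float = 0.3) -> Any:
        """Hypothetical variant of _call_api using the newer tools interface."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            tools=[{"type": "function", "function": f} for f in tools],
            tool_choice="auto",
        )
        message = response.choices[0].message
        if message.tool_calls:
            call = message.tool_calls[0]
            return {
                "function_call": {
                    "name": call.function.name,
                    "arguments": json.loads(call.function.arguments),
                }
            }
        return message.content.strip()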

    def generate_search_queries(self, subtopic: str, explanation: str,
                                related_questions: List[str], count: int) -> List[str]:
        """Generate search queries."""
        prompt = get_prompt("generate_search_queries",
                            subtopic=subtopic,
                            explanation=explanation,
                            related_questions=', '.join(related_questions),
                            count=count)

        result = self._call_api(prompt, temperature=0.7)

        # Parse the result into a list
        queries = [q.strip() for q in result.split('\n') if q.strip()]
        # Strip any leading list numbering
        queries = [q.lstrip('0123456789.-) ') for q in queries]

        return queries[:count]

    def generate_refined_queries(self, key_info: str, detail_needed: str) -> List[str]:
        """Generate refined search queries."""
        prompt = get_prompt("generate_refined_queries",
                            key_info=key_info,
                            detail_needed=detail_needed)

        result = self._call_api(prompt, temperature=0.7)

        queries = [q.strip() for q in result.split('\n') if q.strip()]
        queries = [q.lstrip('0123456789.-) ') for q in queries]

        return queries[:3]
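
    # --- Editor's sketch (not part of the original module) ---------------------
    # The split-and-strip logic above is repeated in generate_search_queries,
    # generate_refined_queries and extract_key_points. A shared helper along
    # these lines could remove the duplication; it is a hypothetical
    # illustration, not an existing method of this project.
    @staticmethod
    def _parse_numbered_lines(text: str, limit: int) -> List[str]:
        """Split model output into lines and strip leading list markers."""
        items = [line.strip() for line in text.split('\n') if line.strip()]
        # "1. foo", "2) bar" and "- baz" all become bare text
        items = [item.lstrip('0123456789.-) ') for item in items]
        return items[:limit]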

    def rewrite_hallucination(self, hallucinated_content: str,
                              original_sources: str) -> str:
        """Rewrite hallucinated content against the original sources."""
        prompt = get_prompt("rewrite_hallucination",
                            hallucinated_content=hallucinated_content,
                            original_sources=original_sources)

        return self._call_api(prompt, temperature=0.3)

    def call_tavily_search(self, query: str, max_results: int = 10) -> Dict[str, Any]:
        """
        Invoke the Tavily search API (via function calling).

        Note: this is an example implementation; the actual Tavily call
        lives in search_service.py.
        """
        # Define the Tavily search function schema
        tavily_function = {
            "name": "tavily_search",
            "description": "Search the web using Tavily API",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query"
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Maximum number of results",
                        "default": 10
                    },
                    "search_depth": {
                        "type": "string",
                        "enum": ["basic", "advanced"],
                        "default": "advanced"
                    }
                },
                "required": ["query"]
            }
        }

        prompt = f"Please search for information about: {query}"

        result = self._call_api(prompt, functions=[tavily_function])

        # If a function call was returned, extract its arguments
        if isinstance(result, dict) and "function_call" in result:
            return result["function_call"]["arguments"]

        # Otherwise fall back to the default arguments
        return {
            "query": query,
            "max_results": max_results,
            "search_depth": "advanced"
        }
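
    # Editor's note (illustrative): for a concrete query this typically yields
    # a dict such as
    #     {"query": "quantum computing 2024", "max_results": 10, "search_depth": "advanced"}
    # which search_service.py can pass on to the real Tavily client.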

    def format_search_results(self, results: List[Dict[str, Any]]) -> str:
        """Format search results as structured text."""
        formatted = []

        for i, result in enumerate(results, 1):
            formatted.append(f"{i}. Title: {result.get('title', 'N/A')}")
            formatted.append(f"   URL: {result.get('url', 'N/A')}")
            formatted.append(f"   Snippet: {result.get('snippet', 'N/A')}")
            if result.get('score'):
                formatted.append(f"   Relevance: {result.get('score', 0):.2f}")
            formatted.append("")

        return '\n'.join(formatted)
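
    # Editor's note (illustrative; values are made up): the formatted output
    # looks roughly like
    #     1. Title: Example page
    #        URL: https://example.com
    #        Snippet: Short excerpt of the page...
    #        Relevance: 0.87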

    def extract_key_points(self, text: str, max_points: int = 5) -> List[str]:
        """Extract key points from text."""
        prompt = f"""
        Please extract at most {max_points} key points from the following text:

        {text}

        Put each key point on its own line and keep it concise.
        """

        result = self._call_api(prompt, temperature=0.5)

        points = [p.strip() for p in result.split('\n') if p.strip()]
        points = [p.lstrip('0123456789.-) ') for p in points]

        return points[:max_points]

    def summarize_content(self, content: str, max_length: int = 200) -> str:
        """Summarize content."""
        prompt = f"""
        Please summarize the following content in no more than {max_length} characters:

        {content}

        Requirements: keep the key information and keep the language fluent.
        """

        return self._call_api(prompt, temperature=0.5)
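

# --- Editor's sketch: example usage (assumes Config supplies valid DeepSeek
# credentials and that app.agents.prompts defines the referenced templates;
# the arguments below are made up for illustration). ------------------------
if __name__ == "__main__":
    agent = V3Agent()
    example_queries = agent.generate_search_queries(
        subtopic="vector databases",
        explanation="How vector databases index embeddings for similarity search",
        related_questions=["What is HNSW?", "How does IVF indexing work?"],
        count=3,
    )
    print(example_queries)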