# ========== api_client.py ==========
# utils/api_client.py - DeepSeek API client (with Web-mode support) - simplified version
import httpx
import json
import asyncio
from typing import List, Dict, Optional, AsyncGenerator
try:
from config import (
API_BASE_URL,
API_KEY,
MODEL_ID,
OUTPUT_FORMATS,
DEFAULT_RESPONSE_MAX_TOKENS,
THINKING_API_BASE_URL,
THINKING_API_KEY,
THINKING_MODEL_ID
)
except ImportError:
import sys
from pathlib import Path
project_root = Path(__file__).resolve().parents[1]
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from config import (
API_BASE_URL,
API_KEY,
MODEL_ID,
OUTPUT_FORMATS,
DEFAULT_RESPONSE_MAX_TOKENS,
THINKING_API_BASE_URL,
THINKING_API_KEY,
THINKING_MODEL_ID
)
class DeepSeekClient:
def __init__(self, thinking_mode: bool = True, web_mode: bool = False):
self.fast_api_config = {
"base_url": API_BASE_URL,
"api_key": API_KEY,
"model_id": MODEL_ID
}
self.thinking_api_config = {
"base_url": THINKING_API_BASE_URL or API_BASE_URL,
"api_key": THINKING_API_KEY or API_KEY,
"model_id": THINKING_MODEL_ID or MODEL_ID
}
self.fast_max_tokens = None
self.thinking_max_tokens = None
self.fast_extra_params: Dict = {}
self.thinking_extra_params: Dict = {}
        self.thinking_mode = thinking_mode  # True = smart thinking mode, False = fast mode
        self.deep_thinking_mode = False  # Deep-thinking mode: use the thinking model for the entire round
        self.deep_thinking_session = False  # Whether the current task is in a deep-thinking session
        self.web_mode = web_mode  # Web-mode flag; suppresses print output
        # Kept for compatibility with older code paths
        self.api_base_url = self.fast_api_config["base_url"]
        self.api_key = self.fast_api_config["api_key"]
        self.model_id = self.fast_api_config["model_id"]
        self.model_key = None  # Injected by the host terminal to ease model-compatibility handling
        # Per-task state
        self.current_task_first_call = True  # Whether the current task is on its first call
        self.current_task_thinking = ""  # Thinking content accumulated for the current task
        self.force_thinking_next_call = False  # Force the thinking model for the next call only
        self.skip_thinking_next_call = False  # Force the fast model for the next call only
        self.last_call_used_thinking = False  # Whether the most recent call used the thinking model
def _print(self, message: str, end: str = "\n", flush: bool = False):
"""安全的打印函数在Web模式下不输出"""
if not self.web_mode:
print(message, end=end, flush=flush)
def _format_read_file_result(self, data: Dict) -> str:
"""根据读取模式格式化 read_file 工具结果。"""
if not isinstance(data, dict):
return json.dumps(data, ensure_ascii=False)
if not data.get("success"):
return json.dumps(data, ensure_ascii=False)
read_type = data.get("type", "read")
truncated_note = "(内容已截断)" if data.get("truncated") else ""
path = data.get("path", "未知路径")
max_chars = data.get("max_chars")
max_note = f"(max_chars={max_chars})" if max_chars else ""
if read_type == "read":
line_start = data.get("line_start")
line_end = data.get("line_end")
char_count = data.get("char_count", len(data.get("content", "") or ""))
header = f"读取 {path}{line_start}~{line_end},返回 {char_count} 字符 {max_note}{truncated_note}".strip()
content = data.get("content", "")
return f"{header}\n```\n{content}\n```"
if read_type == "search":
query = data.get("query", "")
actual = data.get("actual_matches", 0)
returned = data.get("returned_matches", 0)
case_hint = "区分大小写" if data.get("case_sensitive") else "不区分大小写"
header = (
f"{path} 中搜索 \"{query}\",返回 {returned}/{actual} 条结果({case_hint}"
f" {max_note}{truncated_note}"
).strip()
match_texts = []
for idx, match in enumerate(data.get("matches", []), 1):
match_note = "(片段截断)" if match.get("truncated") else ""
hits = match.get("hits") or []
hit_text = ", ".join(str(h) for h in hits) if hits else ""
label = match.get("id") or f"match_{idx}"
snippet = match.get("snippet", "")
match_texts.append(
f"[{label}] 行 {match.get('line_start')}~{match.get('line_end')} 命中行: {hit_text}{match_note}\n```\n{snippet}\n```"
)
if not match_texts:
match_texts.append("未找到匹配内容。")
return "\n".join([header] + match_texts)
if read_type == "extract":
segments = data.get("segments", [])
header = (
f"{path} 抽取 {len(segments)} 个片段 {max_note}{truncated_note}"
).strip()
seg_texts = []
for idx, segment in enumerate(segments, 1):
seg_note = "(片段截断)" if segment.get("truncated") else ""
label = segment.get("label") or f"segment_{idx}"
snippet = segment.get("content", "")
seg_texts.append(
f"[{label}] 行 {segment.get('line_start')}~{segment.get('line_end')}{seg_note}\n```\n{snippet}\n```"
)
if not seg_texts:
seg_texts.append("未提供可抽取的片段。")
return "\n".join([header] + seg_texts)
return json.dumps(data, ensure_ascii=False)
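    # Shape of a "read"-type result this formatter expects, inferred from the branches
    # above (illustrative only; the real read_file tool may return additional fields):
    #
    #     {"success": True, "type": "read", "path": "utils/demo.txt", "line_start": 1,
    #      "line_end": 40, "char_count": 1200, "max_chars": 4000, "truncated": False,
    #      "content": "..."}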
def set_deep_thinking_mode(self, enabled: bool):
"""配置深度思考模式(持续使用思考模型)。"""
self.deep_thinking_mode = bool(enabled)
if not enabled:
self.deep_thinking_session = False
def start_new_task(self, force_deep: bool = False):
"""开始新任务(重置任务级别的状态)"""
self.current_task_first_call = True
self.current_task_thinking = ""
self.force_thinking_next_call = False
self.skip_thinking_next_call = False
self.last_call_used_thinking = False
self.deep_thinking_session = bool(force_deep) or bool(self.deep_thinking_mode)
def _build_headers(self, api_key: str) -> Dict[str, str]:
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
def _select_api_config(self, use_thinking: bool) -> Dict[str, str]:
"""
根据当前模式选择API配置确保缺失字段回退到默认模型。
"""
config = self.thinking_api_config if use_thinking else self.fast_api_config
fallback = self.fast_api_config
return {
"base_url": config.get("base_url") or fallback["base_url"],
"api_key": config.get("api_key") or fallback["api_key"],
"model_id": config.get("model_id") or fallback["model_id"]
}
def apply_profile(self, profile: Dict):
"""
动态应用模型配置
profile 示例:
{
"fast": {"base_url": "...", "api_key": "...", "model_id": "...", "max_tokens": 8192},
"thinking": {...} 或 None,
"supports_thinking": True/False,
"fast_only": True/False
}
"""
if not profile or "fast" not in profile:
raise ValueError("无效的模型配置")
fast = profile["fast"] or {}
thinking = profile.get("thinking") or fast
self.fast_api_config = {
"base_url": fast.get("base_url") or self.fast_api_config.get("base_url"),
"api_key": fast.get("api_key") or self.fast_api_config.get("api_key"),
"model_id": fast.get("model_id") or self.fast_api_config.get("model_id")
}
self.thinking_api_config = {
"base_url": thinking.get("base_url") or self.thinking_api_config.get("base_url"),
"api_key": thinking.get("api_key") or self.thinking_api_config.get("api_key"),
"model_id": thinking.get("model_id") or self.thinking_api_config.get("model_id")
}
self.fast_max_tokens = fast.get("max_tokens")
self.thinking_max_tokens = thinking.get("max_tokens")
self.fast_extra_params = fast.get("extra_params") or {}
self.thinking_extra_params = thinking.get("extra_params") or {}
        # Keep the legacy fields in sync
self.api_base_url = self.fast_api_config["base_url"]
self.api_key = self.fast_api_config["api_key"]
self.model_id = self.fast_api_config["model_id"]
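    # Usage sketch for apply_profile() (illustrative only; the endpoint, key, and model
    # names are placeholders, not values defined in this project):
    #
    #     client = DeepSeekClient(thinking_mode=True)
    #     client.apply_profile({
    #         "fast": {"base_url": "https://api.example.com/v1", "api_key": "sk-...",
    #                  "model_id": "fast-model", "max_tokens": 8192},
    #         "thinking": None,  # falls back to the "fast" entry
    #     })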
def get_current_thinking_mode(self) -> bool:
"""获取当前应该使用的思考模式"""
if self.deep_thinking_session:
return True
if not self.thinking_mode:
return False
if self.force_thinking_next_call:
return True
if self.skip_thinking_next_call:
return False
return self.current_task_first_call
def _validate_json_string(self, json_str: str) -> tuple:
"""
验证JSON字符串的完整性
Returns:
(is_valid: bool, error_message: str, parsed_data: dict or None)
"""
if not json_str or not json_str.strip():
return True, "", {}
        # Check the basic JSON structure markers
stripped = json_str.strip()
if not stripped.startswith('{') or not stripped.endswith('}'):
return False, "JSON字符串格式不完整缺少开始或结束大括号", None
        # Check that quotes are balanced
in_string = False
escape_next = False
quote_count = 0
for char in stripped:
if escape_next:
escape_next = False
continue
if char == '\\':
escape_next = True
continue
if char == '"':
quote_count += 1
in_string = not in_string
if in_string:
return False, "JSON字符串中存在未闭合的引号", None
        # Attempt to parse the JSON
try:
parsed_data = json.loads(stripped)
return True, "", parsed_data
except json.JSONDecodeError as e:
return False, f"JSON解析错误: {str(e)}", None
def _safe_tool_arguments_parse(self, arguments_str: str, tool_name: str) -> tuple:
"""
安全地解析工具参数,保持失败即时返回
Returns:
(success: bool, arguments: dict, error_message: str)
"""
if not arguments_str or not arguments_str.strip():
return True, {}, ""
        # Length guard; the current threshold (999,999,999 characters) is effectively unlimited
        max_length = 999999999
if len(arguments_str) > max_length:
return False, {}, f"参数过长({len(arguments_str)}字符),超过{max_length}字符限制"
        # Try to parse the JSON directly
try:
parsed_data = json.loads(arguments_str)
return True, parsed_data, ""
except json.JSONDecodeError as e:
preview_length = 200
stripped = arguments_str.strip()
preview = stripped[:preview_length] + "..." if len(stripped) > preview_length else stripped
return False, {}, f"JSON解析失败: {str(e)}\n参数预览: {preview}"
async def chat(
self,
messages: List[Dict],
tools: Optional[List[Dict]] = None,
stream: bool = True
) -> AsyncGenerator[Dict, None]:
"""
异步调用DeepSeek API
Args:
messages: 消息列表
tools: 工具定义列表
stream: 是否流式输出
Yields:
响应内容块
"""
        # Check the API key
if not self.api_key or self.api_key == "your-deepseek-api-key":
self._print(f"{OUTPUT_FORMATS['error']} API密钥未配置请在config.py中设置API_KEY")
return
        # Decide whether to use thinking mode for this call
current_thinking_mode = self.get_current_thinking_mode()
api_config = self._select_api_config(current_thinking_mode)
headers = self._build_headers(api_config["api_key"])
        # If this call is in fast mode but thinking content already exists, note that it carries over
if self.thinking_mode and not current_thinking_mode and self.current_task_thinking:
self._print(f"{OUTPUT_FORMATS['info']} [任务内快速模式] 使用本次任务的思考继续处理...")
        # Record which mode this call uses
self.last_call_used_thinking = current_thinking_mode
if current_thinking_mode and self.force_thinking_next_call:
self.force_thinking_next_call = False
if not current_thinking_mode and self.skip_thinking_next_call:
self.skip_thinking_next_call = False
try:
override_max = self.thinking_max_tokens if current_thinking_mode else self.fast_max_tokens
if override_max is not None:
max_tokens = int(override_max)
else:
max_tokens = int(DEFAULT_RESPONSE_MAX_TOKENS)
if max_tokens <= 0:
raise ValueError("max_tokens must be positive")
except (TypeError, ValueError):
max_tokens = 4096
payload = {
"model": api_config["model_id"],
"messages": messages,
"stream": stream,
"max_tokens": max_tokens
}
        # Some platforms (e.g. Qwen, DeepSeek) only include usage in the final streamed chunk if it is requested explicitly
if stream:
should_include_usage = False
if self.model_key in {"qwen3-max", "qwen3-vl-plus", "deepseek"}:
should_include_usage = True
            # Fallback: detect OpenAI-compatible providers from the base_url
if api_config["base_url"]:
lower_url = api_config["base_url"].lower()
if any(keyword in lower_url for keyword in ["dashscope", "aliyuncs", "deepseek.com"]):
should_include_usage = True
if should_include_usage:
payload.setdefault("stream_options", {})["include_usage"] = True
        # Inject model-specific extra parameters (e.g. Qwen's enable_thinking)
extra_params = self.thinking_extra_params if current_thinking_mode else self.fast_extra_params
if extra_params:
payload.update(extra_params)
if tools:
payload["tools"] = tools
payload["tool_choice"] = "auto"
try:
async with httpx.AsyncClient(http2=True, timeout=300) as client:
if stream:
async with client.stream(
"POST",
f"{api_config['base_url']}/chat/completions",
json=payload,
headers=headers
) as response:
                        # Check the response status
if response.status_code != 200:
error_text = await response.aread()
self._print(f"{OUTPUT_FORMATS['error']} API请求失败 ({response.status_code}): {error_text}")
return
async for line in response.aiter_lines():
if line.startswith("data:"):
json_str = line[5:].strip()
if json_str == "[DONE]":
break
try:
data = json.loads(json_str)
yield data
except json.JSONDecodeError:
continue
else:
response = await client.post(
f"{api_config['base_url']}/chat/completions",
json=payload,
headers=headers
)
if response.status_code != 200:
error_text = response.text
self._print(f"{OUTPUT_FORMATS['error']} API请求失败 ({response.status_code}): {error_text}")
return
yield response.json()
except httpx.ConnectError:
self._print(f"{OUTPUT_FORMATS['error']} 无法连接到API服务器请检查网络连接")
except httpx.TimeoutException:
self._print(f"{OUTPUT_FORMATS['error']} API请求超时")
except Exception as e:
self._print(f"{OUTPUT_FORMATS['error']} API调用异常: {e}")
async def chat_with_tools(
self,
messages: List[Dict],
tools: List[Dict],
tool_handler: callable
) -> str:
"""
带工具调用的对话(支持多轮)
Args:
messages: 消息列表
tools: 工具定义
tool_handler: 工具处理函数
Returns:
最终回答
"""
final_response = ""
        max_iterations = 200  # maximum number of iterations
        iteration = 0
        all_tool_results = []  # record every tool-call result
        while iteration < max_iterations:
            iteration += 1
            # Call the API, always providing the tool definitions
            full_response = ""
            tool_calls = []
            current_thinking = ""
            # Placeholder results for append_to_file / modify_file so the variables are always defined
            append_result = {"handled": False}
            modify_result = {"handled": False}
            # Streaming state flags
            in_thinking = False
            thinking_printed = False
async for chunk in self.chat(messages, tools, stream=True):
if "choices" not in chunk:
continue
delta = chunk["choices"][0].get("delta", {})
                # Handle thinking content
                if "reasoning_content" in delta:
                    reasoning_content = delta["reasoning_content"]
                    if reasoning_content:  # skip empty chunks
                        if not in_thinking:
                            self._print("💭 [正在思考]\n", end="", flush=True)
                            in_thinking = True
                            thinking_printed = True
                        current_thinking += reasoning_content
                        self._print(reasoning_content, end="", flush=True)
                # Handle regular content - an independent if, not elif
                if "content" in delta:
                    content = delta["content"]
                    if content:  # skip empty chunks
                        # If thinking output is still open, close it first
                        if in_thinking:
                            self._print("\n\n💭 [思考结束]\n\n", end="", flush=True)
                            in_thinking = False
                        full_response += content
                        self._print(content, end="", flush=True)
                # Collect tool calls, accumulating streamed argument fragments to avoid JSON splitting issues
if "tool_calls" in delta:
for tool_call in delta["tool_calls"]:
tool_index = tool_call.get("index", 0)
                        # Find or create the tool call for this index
                        existing_call = None
                        for existing in tool_calls:
                            if existing.get("index") == tool_index:
                                existing_call = existing
                                break
                        if not existing_call and tool_call.get("id"):
                            # Create a new tool-call entry
                            new_call = {
                                "id": tool_call.get("id"),
                                "index": tool_index,
                                "type": tool_call.get("type", "function"),
                                "function": {
                                    "name": tool_call.get("function", {}).get("name", ""),
                                    "arguments": ""
                                }
                            }
                            tool_calls.append(new_call)
                            existing_call = new_call
                        # Append the arguments fragment as a plain string; no JSON validation here
                        if existing_call and "function" in tool_call and "arguments" in tool_call["function"]:
                            new_args = tool_call["function"]["arguments"]
                            if new_args:  # only append non-empty fragments
                                existing_call["function"]["arguments"] += new_args
self._print("") # 最终换行
# 如果思考还没结束(只调用工具没有文本),手动结束
if in_thinking:
self._print("\n💭 [思考结束]\n")
# 记录思考内容并更新调用状态
if self.last_call_used_thinking and current_thinking:
self.current_task_thinking = current_thinking
if self.current_task_first_call:
self.current_task_first_call = False # 标记当前任务的第一次调用已完成
# 如果没有工具调用,说明完成了
if not tool_calls:
if full_response: # 有正常回复,任务完成
final_response = full_response
break
elif iteration == 1: # 第一次就没有工具调用也没有内容,可能有问题
self._print(f"{OUTPUT_FORMATS['warning']} 模型未返回内容")
break
            # Build the assistant message - always include everything that was collected
            assistant_content_parts = []
            # Add the formal reply content (if any)
            if full_response:
                assistant_content_parts.append(full_response)
            elif append_result["handled"] and append_result.get("assistant_content"):
                assistant_content_parts.append(append_result["assistant_content"])
            elif modify_result["handled"] and modify_result.get("assistant_content"):
                assistant_content_parts.append(modify_result["assistant_content"])
            # Add a note about the tool calls
            if tool_calls:
                tool_names = [tc['function']['name'] for tc in tool_calls]
                assistant_content_parts.append(f"执行工具: {', '.join(tool_names)}")
            # Merge everything
            assistant_content = "\n".join(assistant_content_parts) if assistant_content_parts else "执行工具调用"
assistant_message = {
"role": "assistant",
"content": assistant_content,
"tool_calls": tool_calls
}
if current_thinking:
assistant_message["reasoning_content"] = current_thinking
messages.append(assistant_message)
            # Execute every tool call, using robust argument parsing
            for tool_call in tool_calls:
                function_name = tool_call["function"]["name"]
                arguments_str = tool_call["function"]["arguments"]
                # Parse the arguments with the hardened helper
                success, arguments, error_msg = self._safe_tool_arguments_parse(arguments_str, function_name)
                if not success:
                    self._print(f"{OUTPUT_FORMATS['error']} 工具参数解析失败: {error_msg}")
                    self._print(f" 工具名称: {function_name}")
                    self._print(f" 参数长度: {len(arguments_str)} 字符")
                    # Return a detailed error message to the model
                    error_response = {
                        "success": False,
                        "error": error_msg,
                        "tool_name": function_name,
                        "arguments_length": len(arguments_str),
                        "suggestion": "请检查参数格式或减少参数长度后重试"
                    }
                    # If the arguments are very long, suggest chunking
                    if len(arguments_str) > 10000:
                        error_response["suggestion"] = "参数过长,建议分块处理或使用更简洁的内容"
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "name": function_name,
                        "content": json.dumps(error_response, ensure_ascii=False)
                    })
                    # Record the failed call so the repeat-call detection still works
                    all_tool_results.append({
                        "tool": function_name,
                        "args": {"parse_error": error_msg, "length": len(arguments_str)},
                        "result": f"参数解析失败: {error_msg}"
                    })
                    continue
self._print(f"\n{OUTPUT_FORMATS['action']} 调用工具: {function_name}")
                # Extra content-length guard for specific tools (the threshold below is effectively unlimited)
                if function_name == "modify_file" and "content" in arguments:
                    content_length = len(arguments.get("content", ""))
                    if content_length > 9999999999:
                        error_msg = f"内容过长({content_length}字符),超过长度限制"
self._print(f"{OUTPUT_FORMATS['warning']} {error_msg}")
messages.append({
"role": "tool",
"tool_call_id": tool_call["id"],
"name": function_name,
"content": json.dumps({
"success": False,
"error": error_msg,
"suggestion": "请将内容分成多个小块分别修改或使用replace操作只修改必要部分"
}, ensure_ascii=False)
})
all_tool_results.append({
"tool": function_name,
"args": arguments,
"result": error_msg
})
continue
tool_result = await tool_handler(function_name, arguments)
                # Parse the tool result and extract the key information
                try:
                    result_data = json.loads(tool_result)
                    if function_name == "read_file":
                        tool_result_msg = self._format_read_file_result(result_data)
                    else:
                        tool_result_msg = tool_result
                except Exception:
                    tool_result_msg = tool_result
messages.append({
"role": "tool",
"tool_call_id": tool_call["id"],
"name": function_name,
"content": tool_result_msg
})
                # Record the tool result
                all_tool_results.append({
                    "tool": function_name,
                    "args": arguments,
                    "result": tool_result_msg
                })
            # If the same tool keeps being called back to back, we are probably stuck in a loop
            if len(all_tool_results) >= 8:
                recent_tools = [r["tool"] for r in all_tool_results[-8:]]
                if len(set(recent_tools)) == 1:  # the last 8 calls all used the same tool
                    self._print(f"\n{OUTPUT_FORMATS['warning']} 检测到重复操作,停止执行")
                    break
if iteration >= max_iterations:
self._print(f"\n{OUTPUT_FORMATS['warning']} 达到最大迭代次数限制")
return final_response
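    # Sketch of a tool_handler compatible with chat_with_tools() (illustrative only; the
    # tool name and payload below are assumptions, not tools defined in this file):
    #
    #     async def tool_handler(name: str, args: dict) -> str:
    #         if name == "read_file":
    #             data = {"success": True, "type": "read", "path": args.get("path", ""),
    #                     "line_start": 1, "line_end": 10, "content": "..."}
    #             return json.dumps(data, ensure_ascii=False)
    #         return json.dumps({"success": False, "error": f"unknown tool: {name}"}, ensure_ascii=False)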
async def simple_chat(self, messages: List[Dict]) -> tuple:
"""
简单对话(无工具调用)
Args:
messages: 消息列表
Returns:
(模型回答, 思考内容)
"""
full_response = ""
thinking_content = ""
in_thinking = False
        # In thinking mode, if this task already has thinking content, append it to the context so the reasoning is not fragmented across calls
if (
self.thinking_mode
and not self.current_task_first_call
and self.current_task_thinking
):
thinking_context = (
"\n=== 📋 本次任务的思考 ===\n"
f"{self.current_task_thinking}\n"
"=== 思考结束 ===\n"
"提示:以上是本轮任务先前的思考,请在此基础上继续。"
)
messages.append({
"role": "system",
"content": thinking_context
})
thinking_context_injected = True
try:
async for chunk in self.chat(messages, tools=None, stream=True):
if "choices" not in chunk:
continue
delta = chunk["choices"][0].get("delta", {})
                # Handle thinking content
                if "reasoning_content" in delta:
                    reasoning_content = delta["reasoning_content"]
                    if reasoning_content:  # skip empty chunks
                        if not in_thinking:
                            self._print("💭 [正在思考]\n", end="", flush=True)
                            in_thinking = True
                        thinking_content += reasoning_content
                        self._print(reasoning_content, end="", flush=True)
                # Handle regular content - an independent if, not elif
                if "content" in delta:
                    content = delta["content"]
                    if content:  # skip empty chunks
                        if in_thinking:
                            self._print("\n\n💭 [思考结束]\n\n", end="", flush=True)
                            in_thinking = False
                        full_response += content
                        self._print(content, end="", flush=True)
self._print("") # 最终换行
# 如果思考还没结束(极少情况),手动结束
if in_thinking:
self._print("\n💭 [思考结束]\n")
if self.last_call_used_thinking and thinking_content:
self.current_task_thinking = thinking_content
if self.current_task_first_call:
self.current_task_first_call = False
# 如果没有收到任何响应
if not full_response and not thinking_content:
self._print(f"{OUTPUT_FORMATS['error']} API未返回任何内容请检查API密钥和模型ID")
return "", ""
except Exception as e:
self._print(f"{OUTPUT_FORMATS['error']} API调用失败: {e}")
return "", ""
return full_response, thinking_content
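# Minimal end-to-end sketch (an assumption, not part of the original module): it only
# exercises simple_chat() and needs a valid API_KEY in config.py to actually reach the API.
if __name__ == "__main__":
    async def _demo():
        client = DeepSeekClient(thinking_mode=False)
        client.start_new_task()
        reply, thinking = await client.simple_chat(
            [{"role": "user", "content": "Introduce yourself in one sentence."}]
        )
        print("reply:", reply)

    asyncio.run(_demo())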