1138 lines
50 KiB
Python
1138 lines
50 KiB
Python
# ========== api_client.py ==========
|
||
# utils/api_client.py - DeepSeek API 客户端(支持Web模式)- 简化版
|
||
|
||
import httpx
|
||
import json
|
||
import asyncio
|
||
import base64
|
||
import mimetypes
|
||
import os
|
||
from typing import List, Dict, Optional, AsyncGenerator, Any
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Tuple
|
||
# Import project configuration. When this module is executed outside the
# package (e.g. run directly as a script), the bare `config` import fails,
# so we add the project root to sys.path and retry.
try:
    from config import (
        API_BASE_URL,
        API_KEY,
        MODEL_ID,
        OUTPUT_FORMATS,
        DEFAULT_RESPONSE_MAX_TOKENS,
        THINKING_API_BASE_URL,
        THINKING_API_KEY,
        THINKING_MODEL_ID
    )
except ImportError:
    import sys
    from pathlib import Path
    # parents[1] = project root (this file lives in utils/).
    project_root = Path(__file__).resolve().parents[1]
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
    from config import (
        API_BASE_URL,
        API_KEY,
        MODEL_ID,
        OUTPUT_FORMATS,
        DEFAULT_RESPONSE_MAX_TOKENS,
        THINKING_API_BASE_URL,
        THINKING_API_KEY,
        THINKING_MODEL_ID
    )
|
||
|
||
class DeepSeekClient:
|
||
def __init__(self, thinking_mode: bool = True, web_mode: bool = False):
    """
    Create a client holding separate fast/thinking API configurations plus
    per-task conversation state.

    Args:
        thinking_mode: True enables smart thinking mode, False forces fast mode.
        web_mode: True suppresses console printing (a web frontend renders output).
    """
    fast_cfg = {
        "base_url": API_BASE_URL,
        "api_key": API_KEY,
        "model_id": MODEL_ID
    }
    self.fast_api_config = fast_cfg
    # Thinking config falls back field-by-field to the fast config.
    self.thinking_api_config = {
        "base_url": THINKING_API_BASE_URL or fast_cfg["base_url"],
        "api_key": THINKING_API_KEY or fast_cfg["api_key"],
        "model_id": THINKING_MODEL_ID or fast_cfg["model_id"]
    }
    # Token limits / extra params are injected later via apply_profile().
    self.fast_max_tokens = None
    self.thinking_max_tokens = None
    self.fast_extra_params: Dict = {}
    self.thinking_extra_params: Dict = {}
    # Context budget, refreshed by the caller before every request.
    self.current_context_tokens: int = 0
    self.max_context_tokens: Optional[int] = None
    self.default_context_window: Optional[int] = None
    self.thinking_mode = thinking_mode  # True = smart thinking mode, False = fast mode
    self.deep_thinking_mode = False     # deep-thinking mode: thinking model for the whole round
    self.deep_thinking_session = False  # whether the current task is a deep-thinking session
    self.web_mode = web_mode            # disables print output in web mode
    # Legacy aliases kept in sync for older code paths.
    self.api_base_url = fast_cfg["base_url"]
    self.api_key = fast_cfg["api_key"]
    self.model_id = fast_cfg["model_id"]
    self.model_key = None  # injected by the hosting terminal for model-specific handling
    self.project_path: Optional[str] = None
    # Per-task state.
    self.current_task_first_call = True   # whether this task has made its first call yet
    self.current_task_thinking = ""       # thinking content captured for this task
    self.force_thinking_next_call = False  # one-shot: force thinking on the next call
    self.skip_thinking_next_call = False   # one-shot: force fast mode on the next call
    self.last_call_used_thinking = False   # whether the most recent call used the thinking model
    # Details of the most recent API error, if any.
    self.last_error_info: Optional[Dict[str, Any]] = None
    # Request payload dump directory and debug log file under <project>/logs.
    logs_root = Path(__file__).resolve().parents[1] / "logs"
    self.request_dump_dir = logs_root / "api_requests"
    self.request_dump_dir.mkdir(parents=True, exist_ok=True)
    self.debug_log_path = logs_root / "api_debug.log"
|
||
|
||
def _maybe_mark_aliyun_quota(self, error_text: str) -> None:
|
||
if not error_text or not self.model_key:
|
||
return
|
||
try:
|
||
from utils.aliyun_fallback import compute_disabled_until, set_disabled_until
|
||
except Exception:
|
||
return
|
||
disabled_until, reason = compute_disabled_until(error_text)
|
||
if disabled_until and reason:
|
||
set_disabled_until(self.model_key, disabled_until, reason)
|
||
# 立即切换到官方 API(仅在有配置时)
|
||
base_env_key = None
|
||
key_env_key = None
|
||
if self.model_key == "kimi-k2.5":
|
||
base_env_key = "API_BASE_KIMI_OFFICIAL"
|
||
key_env_key = "API_KEY_KIMI_OFFICIAL"
|
||
elif self.model_key == "qwen3-vl-plus":
|
||
base_env_key = "API_BASE_QWEN_OFFICIAL"
|
||
key_env_key = "API_KEY_QWEN_OFFICIAL"
|
||
elif self.model_key == "minimax-m2.5":
|
||
base_env_key = "API_BASE_MINIMAX_OFFICIAL"
|
||
key_env_key = "API_KEY_MINIMAX_OFFICIAL"
|
||
if base_env_key and key_env_key:
|
||
official_base = self._resolve_env_value(base_env_key)
|
||
official_key = self._resolve_env_value(key_env_key)
|
||
if official_base and official_key:
|
||
self.fast_api_config["base_url"] = official_base
|
||
self.fast_api_config["api_key"] = official_key
|
||
self.thinking_api_config["base_url"] = official_base
|
||
self.thinking_api_config["api_key"] = official_key
|
||
self.api_base_url = official_base
|
||
self.api_key = official_key
|
||
|
||
def _debug_log(self, payload: Dict[str, Any]) -> None:
|
||
try:
|
||
entry = {
|
||
"ts": datetime.now().isoformat(),
|
||
**payload
|
||
}
|
||
self.debug_log_path.parent.mkdir(parents=True, exist_ok=True)
|
||
with self.debug_log_path.open("a", encoding="utf-8") as f:
|
||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||
except Exception:
|
||
pass
|
||
|
||
def _resolve_env_value(self, name: str) -> Optional[str]:
|
||
value = os.environ.get(name)
|
||
if value is None:
|
||
env_path = Path(__file__).resolve().parents[1] / ".env"
|
||
if env_path.exists():
|
||
try:
|
||
for raw_line in env_path.read_text(encoding="utf-8").splitlines():
|
||
line = raw_line.strip()
|
||
if not line or line.startswith("#") or "=" not in line:
|
||
continue
|
||
key, val = line.split("=", 1)
|
||
if key.strip() == name:
|
||
value = val.strip().strip('"').strip("'")
|
||
break
|
||
except Exception:
|
||
value = None
|
||
if value is None:
|
||
return None
|
||
value = value.strip()
|
||
return value or None
|
||
|
||
def _print(self, message: str, end: str = "\n", flush: bool = False):
|
||
"""安全的打印函数,在Web模式下不输出"""
|
||
if not self.web_mode:
|
||
print(message, end=end, flush=flush)
|
||
|
||
def _format_read_file_result(self, data: Dict) -> str:
|
||
"""根据读取模式格式化 read_file 工具结果。"""
|
||
if not isinstance(data, dict):
|
||
return json.dumps(data, ensure_ascii=False)
|
||
if not data.get("success"):
|
||
return json.dumps(data, ensure_ascii=False)
|
||
|
||
read_type = data.get("type", "read")
|
||
truncated_note = "(内容已截断)" if data.get("truncated") else ""
|
||
path = data.get("path", "未知路径")
|
||
max_chars = data.get("max_chars")
|
||
max_note = f"(max_chars={max_chars})" if max_chars else ""
|
||
|
||
if read_type == "read":
|
||
line_start = data.get("line_start")
|
||
line_end = data.get("line_end")
|
||
char_count = data.get("char_count", len(data.get("content", "") or ""))
|
||
header = f"读取 {path} 行 {line_start}~{line_end},返回 {char_count} 字符 {max_note}{truncated_note}".strip()
|
||
content = data.get("content", "")
|
||
return f"{header}\n```\n{content}\n```"
|
||
|
||
if read_type == "search":
|
||
query = data.get("query", "")
|
||
actual = data.get("actual_matches", 0)
|
||
returned = data.get("returned_matches", 0)
|
||
case_hint = "区分大小写" if data.get("case_sensitive") else "不区分大小写"
|
||
header = (
|
||
f"在 {path} 中搜索 \"{query}\",返回 {returned}/{actual} 条结果({case_hint})"
|
||
f" {max_note}{truncated_note}"
|
||
).strip()
|
||
match_texts = []
|
||
for idx, match in enumerate(data.get("matches", []), 1):
|
||
match_note = "(片段截断)" if match.get("truncated") else ""
|
||
hits = match.get("hits") or []
|
||
hit_text = ", ".join(str(h) for h in hits) if hits else "无"
|
||
label = match.get("id") or f"match_{idx}"
|
||
snippet = match.get("snippet", "")
|
||
match_texts.append(
|
||
f"[{label}] 行 {match.get('line_start')}~{match.get('line_end')} 命中行: {hit_text}{match_note}\n```\n{snippet}\n```"
|
||
)
|
||
if not match_texts:
|
||
match_texts.append("未找到匹配内容。")
|
||
return "\n".join([header] + match_texts)
|
||
|
||
def _build_content_with_images(self, text: str, images: List[str], videos: Optional[List[Any]] = None) -> Any:
|
||
"""将文本与图片/视频路径拼成多模态 content(用于 tool 消息)。"""
|
||
videos = videos or []
|
||
if not images and not videos:
|
||
return text
|
||
qwen_video_fps = 2
|
||
parts: List[Dict[str, Any]] = []
|
||
if text:
|
||
parts.append({"type": "text", "text": text})
|
||
base_path = Path(self.project_path or ".")
|
||
for path in images:
|
||
try:
|
||
abs_path = (base_path / path).resolve()
|
||
if not abs_path.exists() or not abs_path.is_file():
|
||
continue
|
||
mime, _ = mimetypes.guess_type(abs_path.name)
|
||
if not mime:
|
||
mime = "image/png"
|
||
data = abs_path.read_bytes()
|
||
b64 = base64.b64encode(data).decode("utf-8")
|
||
parts.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
|
||
except Exception:
|
||
continue
|
||
for item in videos:
|
||
try:
|
||
if isinstance(item, dict):
|
||
path = item.get("path") or ""
|
||
else:
|
||
path = item
|
||
if not path:
|
||
continue
|
||
abs_path = (base_path / path).resolve()
|
||
if not abs_path.exists() or not abs_path.is_file():
|
||
continue
|
||
mime, _ = mimetypes.guess_type(abs_path.name)
|
||
if not mime:
|
||
mime = "video/mp4"
|
||
data = abs_path.read_bytes()
|
||
b64 = base64.b64encode(data).decode("utf-8")
|
||
payload: Dict[str, Any] = {
|
||
"type": "video_url",
|
||
"video_url": {"url": f"data:{mime};base64,{b64}"}
|
||
}
|
||
if self.model_key == "qwen3-vl-plus":
|
||
payload["fps"] = qwen_video_fps
|
||
parts.append(payload)
|
||
except Exception:
|
||
continue
|
||
return parts if parts else text
|
||
|
||
# NOTE(review): this fragment references `read_type`, `data`, `path`,
# `max_note` and `truncated_note`, none of which are defined at this
# scope — it appears to be a displaced tail of `_format_read_file_result`
# (the "extract" branch plus the final JSON fallback). Confirm its
# intended placement before relying on it; as written here it is
# unreachable/undefined.
if read_type == "extract":
    segments = data.get("segments", [])
    header = (
        f"从 {path} 抽取 {len(segments)} 个片段 {max_note}{truncated_note}"
    ).strip()
    seg_texts = []
    for idx, segment in enumerate(segments, 1):
        seg_note = "(片段截断)" if segment.get("truncated") else ""
        label = segment.get("label") or f"segment_{idx}"
        snippet = segment.get("content", "")
        seg_texts.append(
            f"[{label}] 行 {segment.get('line_start')}~{segment.get('line_end')}{seg_note}\n```\n{snippet}\n```"
        )
    if not seg_texts:
        seg_texts.append("未提供可抽取的片段。")
    return "\n".join([header] + seg_texts)

return json.dumps(data, ensure_ascii=False)
|
||
|
||
def _extract_reasoning_delta(self, delta: Dict[str, Any]) -> str:
|
||
"""统一提取思考内容,兼容 reasoning_content / reasoning_details。"""
|
||
if not isinstance(delta, dict):
|
||
return ""
|
||
if "reasoning_content" in delta:
|
||
return delta.get("reasoning_content") or ""
|
||
details = delta.get("reasoning_details")
|
||
if isinstance(details, list):
|
||
parts: List[str] = []
|
||
for item in details:
|
||
if isinstance(item, dict):
|
||
text = item.get("text")
|
||
if text:
|
||
parts.append(text)
|
||
if parts:
|
||
return "".join(parts)
|
||
return ""
|
||
|
||
def _merge_system_messages(self, messages: List[Dict]) -> List[Dict]:
|
||
"""
|
||
仅合并最开头连续的 system 消息(系统提示),后续插入的 system 消息保持原样。
|
||
"""
|
||
if not messages:
|
||
return messages
|
||
|
||
merged_contents: List[str] = []
|
||
idx = 0
|
||
while idx < len(messages) and messages[idx].get("role") == "system":
|
||
content = messages[idx].get("content", "")
|
||
if isinstance(content, str):
|
||
merged_contents.append(content)
|
||
else:
|
||
merged_contents.append(json.dumps(content, ensure_ascii=False))
|
||
idx += 1
|
||
|
||
if not merged_contents:
|
||
return messages
|
||
|
||
merged = {
|
||
"role": "system",
|
||
"content": "\n\n".join(c for c in merged_contents if c)
|
||
}
|
||
return [merged] + messages[idx:]
|
||
|
||
def set_deep_thinking_mode(self, enabled: bool):
    """Toggle deep-thinking mode (thinking model used for the whole round)."""
    flag = bool(enabled)
    self.deep_thinking_mode = flag
    # Disabling the mode also ends any deep-thinking session in progress.
    if not flag:
        self.deep_thinking_session = False
|
||
|
||
def start_new_task(self, force_deep: bool = False):
    """Reset all per-task state ahead of a new task."""
    self.current_task_first_call = True
    self.current_task_thinking = ""
    self.force_thinking_next_call = False
    self.skip_thinking_next_call = False
    self.last_call_used_thinking = False
    # A task runs as a deep-thinking session when explicitly forced or
    # when deep-thinking mode is globally enabled.
    self.deep_thinking_session = bool(force_deep or self.deep_thinking_mode)
|
||
|
||
def _build_headers(self, api_key: str) -> Dict[str, str]:
|
||
return {
|
||
"Authorization": f"Bearer {api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
def _select_api_config(self, use_thinking: bool) -> Dict[str, str]:
|
||
"""
|
||
根据当前模式选择API配置,确保缺失字段回退到默认模型。
|
||
"""
|
||
config = self.thinking_api_config if use_thinking else self.fast_api_config
|
||
fallback = self.fast_api_config
|
||
return {
|
||
"base_url": config.get("base_url") or fallback["base_url"],
|
||
"api_key": config.get("api_key") or fallback["api_key"],
|
||
"model_id": config.get("model_id") or fallback["model_id"]
|
||
}
|
||
|
||
def _dump_request_payload(self, payload: Dict, api_config: Dict, headers: Dict) -> Path:
|
||
"""
|
||
将本次请求的payload、headers、配置落盘,便于排查400等错误。
|
||
返回写入的文件路径。
|
||
"""
|
||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||
filename = f"req_{timestamp}.json"
|
||
path = self.request_dump_dir / filename
|
||
try:
|
||
headers_sanitized = {}
|
||
for k, v in headers.items():
|
||
headers_sanitized[k] = "***" if k.lower() == "authorization" else v
|
||
data = {
|
||
"timestamp": datetime.now().isoformat(),
|
||
"api_config": {k: api_config.get(k) for k in ["base_url", "model_id"]},
|
||
"headers": headers_sanitized,
|
||
"payload": payload
|
||
}
|
||
path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||
except Exception as exc:
|
||
self._print(f"{OUTPUT_FORMATS['warning']} 请求体落盘失败: {exc}")
|
||
return path
|
||
|
||
def _mark_request_error(self, dump_path: Path, status_code: int = None, error_text: str = None):
|
||
"""
|
||
在已有请求文件中追加错误标记,便于快速定位。
|
||
"""
|
||
if not dump_path or not dump_path.exists():
|
||
return
|
||
try:
|
||
data = json.loads(dump_path.read_text(encoding="utf-8"))
|
||
data["error"] = {
|
||
"status_code": status_code,
|
||
"message": error_text,
|
||
"marked_at": datetime.now().isoformat()
|
||
}
|
||
dump_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
||
except Exception as exc:
|
||
self._print(f"{OUTPUT_FORMATS['warning']} 标记请求错误失败: {exc}")
|
||
|
||
|
||
def apply_profile(self, profile: Dict):
    """
    Apply a model profile at runtime.

    Expected shape::

        {
            "fast": {"base_url": ..., "api_key": ..., "model_id": ..., "max_tokens": 8192},
            "thinking": {...} or None,
            "supports_thinking": bool,
            "fast_only": bool
        }

    Raises:
        ValueError: if the profile is empty or has no "fast" entry.
    """
    if not profile or "fast" not in profile:
        raise ValueError("无效的模型配置")
    fast = profile["fast"] or {}
    # With no dedicated thinking profile, the thinking side mirrors fast.
    thinking = profile.get("thinking") or fast

    def merged(src: Dict, current: Dict) -> Dict:
        # Missing/empty fields keep their current values.
        return {
            key: src.get(key) or current.get(key)
            for key in ("base_url", "api_key", "model_id")
        }

    self.fast_api_config = merged(fast, self.fast_api_config)
    self.thinking_api_config = merged(thinking, self.thinking_api_config)
    self.fast_max_tokens = fast.get("max_tokens")
    self.thinking_max_tokens = thinking.get("max_tokens")
    self.fast_extra_params = fast.get("extra_params") or {}
    self.thinking_extra_params = thinking.get("extra_params") or {}
    self.default_context_window = profile.get("context_window") or fast.get("context_window")
    # Keep the legacy aliases in sync for older call sites.
    self.api_base_url = self.fast_api_config["base_url"]
    self.api_key = self.fast_api_config["api_key"]
    self.model_id = self.fast_api_config["model_id"]
|
||
|
||
def update_context_budget(self, current_tokens: int, max_tokens: Optional[int]):
    """
    Record the conversation's current token usage and the model's context
    ceiling, as supplied by the caller before each request. Invalid values
    degrade to 0 / None rather than raising.
    """
    try:
        self.current_context_tokens = max(0, int(current_tokens))
    except (TypeError, ValueError):
        self.current_context_tokens = 0
    if max_tokens is None:
        self.max_context_tokens = None
    else:
        try:
            self.max_context_tokens = int(max_tokens)
        except (TypeError, ValueError):
            self.max_context_tokens = None
|
||
|
||
def get_current_thinking_mode(self) -> bool:
    """Decide whether the next API call should use the thinking model."""
    # Deep-thinking sessions always use the thinking model.
    if self.deep_thinking_session:
        return True
    if self.thinking_mode:
        # One-shot overrides take precedence over the default policy.
        if self.force_thinking_next_call:
            return True
        if self.skip_thinking_next_call:
            return False
        # Default: think only on the first call of a task.
        return self.current_task_first_call
    # Thinking disabled globally -> always fast.
    return False
|
||
|
||
def _validate_json_string(self, json_str: str) -> tuple:
|
||
"""
|
||
验证JSON字符串的完整性
|
||
|
||
Returns:
|
||
(is_valid: bool, error_message: str, parsed_data: dict or None)
|
||
"""
|
||
if not json_str or not json_str.strip():
|
||
return True, "", {}
|
||
|
||
# 检查基本的JSON结构标记
|
||
stripped = json_str.strip()
|
||
if not stripped.startswith('{') or not stripped.endswith('}'):
|
||
return False, "JSON字符串格式不完整(缺少开始或结束大括号)", None
|
||
|
||
# 检查引号配对
|
||
in_string = False
|
||
escape_next = False
|
||
quote_count = 0
|
||
|
||
for char in stripped:
|
||
if escape_next:
|
||
escape_next = False
|
||
continue
|
||
|
||
if char == '\\':
|
||
escape_next = True
|
||
continue
|
||
|
||
if char == '"':
|
||
quote_count += 1
|
||
in_string = not in_string
|
||
|
||
if in_string:
|
||
return False, "JSON字符串中存在未闭合的引号", None
|
||
|
||
# 尝试解析JSON
|
||
try:
|
||
parsed_data = json.loads(stripped)
|
||
return True, "", parsed_data
|
||
except json.JSONDecodeError as e:
|
||
return False, f"JSON解析错误: {str(e)}", None
|
||
|
||
def _safe_tool_arguments_parse(self, arguments_str: str, tool_name: str) -> tuple:
|
||
"""
|
||
安全地解析工具参数,保持失败即时返回
|
||
|
||
Returns:
|
||
(success: bool, arguments: dict, error_message: str)
|
||
"""
|
||
if not arguments_str or not arguments_str.strip():
|
||
return True, {}, ""
|
||
|
||
# 长度检查
|
||
max_length = 999999999 # 50KB限制
|
||
if len(arguments_str) > max_length:
|
||
return False, {}, f"参数过长({len(arguments_str)}字符),超过{max_length}字符限制"
|
||
|
||
# 尝试直接解析JSON
|
||
try:
|
||
parsed_data = json.loads(arguments_str)
|
||
return True, parsed_data, ""
|
||
except json.JSONDecodeError as e:
|
||
preview_length = 200
|
||
stripped = arguments_str.strip()
|
||
preview = stripped[:preview_length] + "..." if len(stripped) > preview_length else stripped
|
||
return False, {}, f"JSON解析失败: {str(e)}\n参数预览: {preview}"
|
||
async def chat(
    self,
    messages: List[Dict],
    tools: Optional[List[Dict]] = None,
    stream: bool = True
) -> AsyncGenerator[Dict, None]:
    """
    Call the chat-completions API asynchronously.

    Args:
        messages: conversation messages (OpenAI-compatible format).
        tools: optional tool definitions passed through to the API.
        stream: whether to use streaming (SSE) responses.

    Yields:
        Parsed response chunks; on failure a single {"error": {...}} dict
        mirroring self.last_error_info, then the generator ends.
    """
    # Refuse to run without a configured API key.
    if not self.api_key or self.api_key == "your-deepseek-api-key":
        self._print(f"{OUTPUT_FORMATS['error']} API密钥未配置,请在config.py中设置API_KEY")
        return

    # Decide whether this call uses the thinking model.
    current_thinking_mode = self.get_current_thinking_mode()
    api_config = self._select_api_config(current_thinking_mode)
    headers = self._build_headers(api_config["api_key"])

    # Fast mode but prior thinking exists for this task: note that we
    # continue on top of it.
    if self.thinking_mode and not current_thinking_mode and self.current_task_thinking:
        self._print(f"{OUTPUT_FORMATS['info']} [任务内快速模式] 使用本次任务的思考继续处理...")

    # Record which mode this call used and consume one-shot overrides.
    self.last_call_used_thinking = current_thinking_mode
    if current_thinking_mode and self.force_thinking_next_call:
        self.force_thinking_next_call = False
    if not current_thinking_mode and self.skip_thinking_next_call:
        self.skip_thinking_next_call = False

    # Resolve max_tokens: per-mode override, else the configured default;
    # any bad value falls back to 4096.
    try:
        override_max = self.thinking_max_tokens if current_thinking_mode else self.fast_max_tokens
        if override_max is not None:
            max_tokens = int(override_max)
        else:
            max_tokens = int(DEFAULT_RESPONSE_MAX_TOKENS)
        if max_tokens <= 0:
            raise ValueError("max_tokens must be positive")
    except (TypeError, ValueError):
        max_tokens = 4096

    # Dynamically shrink max_tokens so we never exceed the model's
    # context window.
    budget_max_context = self.max_context_tokens or self.default_context_window
    if budget_max_context and budget_max_context > 0:
        used_tokens = max(0, int(self.current_context_tokens or 0))
        available = budget_max_context - used_tokens
        if available <= 0:
            # Last resort: let upstream handle the error; request at least
            # 1 token so the API does not reject the parameter outright.
            max_tokens = 1
        else:
            max_tokens = min(max_tokens, available)

    lower_base_url = (api_config.get("base_url") or "").lower()
    is_minimax = self.model_key == "minimax-m2.5" or "minimax" in lower_base_url

    final_messages = self._merge_system_messages(messages)

    payload = {
        "model": api_config["model_id"],
        "messages": final_messages,
        "stream": stream,
    }
    # MiniMax names the limit differently.
    if is_minimax:
        payload["max_completion_tokens"] = max_tokens
    else:
        payload["max_tokens"] = max_tokens
    # Some platforms (e.g. Qwen, DeepSeek) only return usage in the final
    # streamed chunk when it is explicitly requested.
    if stream:
        should_include_usage = False
        if self.model_key in {"qwen3-max", "qwen3-vl-plus", "deepseek", "minimax-m2.5"}:
            should_include_usage = True
        # Fallback: recognize OpenAI-compatible providers by base_url.
        if api_config["base_url"]:
            lower_url = api_config["base_url"].lower()
            if any(keyword in lower_url for keyword in ["dashscope", "aliyuncs", "deepseek.com"]):
                should_include_usage = True
        if should_include_usage:
            if is_minimax:
                # MiniMax streaming needs stream_options.include_usage to
                # return meaningful usage data.
                payload["include_usage"] = True
                payload.setdefault("stream_options", {})["include_usage"] = True
            else:
                payload.setdefault("stream_options", {})["include_usage"] = True
    # Inject model-specific extra parameters (e.g. Qwen enable_thinking).
    extra_params = self.thinking_extra_params if current_thinking_mode else self.fast_extra_params
    if extra_params:
        payload.update(extra_params)
    if tools:
        payload["tools"] = tools
        if not is_minimax:
            payload["tool_choice"] = "auto"

    # Persist this request so failures can be diagnosed quickly.
    dump_path = self._dump_request_payload(payload, api_config, headers)

    try:
        async with httpx.AsyncClient(http2=True, timeout=300) as client:
            if stream:
                async with client.stream(
                    "POST",
                    f"{api_config['base_url']}/chat/completions",
                    json=payload,
                    headers=headers
                ) as response:
                    # Non-200: capture full error body, record it, yield
                    # an error chunk and stop.
                    if response.status_code != 200:
                        error_bytes = await response.aread()
                        error_text = error_bytes.decode('utf-8', errors='ignore') if hasattr(error_bytes, 'decode') else str(error_bytes)
                        self.last_error_info = {
                            "status_code": response.status_code,
                            "error_text": error_text,
                            "error_type": None,
                            "error_message": None,
                            "request_dump": str(dump_path),
                            "base_url": api_config.get("base_url"),
                            "model_id": api_config.get("model_id"),
                            "model_key": self.model_key
                        }
                        # Best effort: enrich with the structured error body.
                        try:
                            parsed = json.loads(error_text)
                            err = parsed.get("error") if isinstance(parsed, dict) else {}
                            if isinstance(err, dict):
                                self.last_error_info["error_type"] = err.get("type")
                                self.last_error_info["error_message"] = err.get("message")
                        except Exception:
                            pass
                        self._maybe_mark_aliyun_quota(error_text)
                        self._debug_log({
                            "event": "http_error_stream",
                            "status_code": response.status_code,
                            "error_text": error_text,
                            "base_url": api_config.get("base_url"),
                            "model_id": api_config.get("model_id"),
                            "model_key": self.model_key,
                            "request_dump": str(dump_path)
                        })
                        self._print(
                            f"{OUTPUT_FORMATS['error']} API请求失败 ({response.status_code}): {error_text} "
                            f"(base_url={api_config.get('base_url')}, model_id={api_config.get('model_id')})"
                        )
                        self._mark_request_error(dump_path, response.status_code, error_text)
                        yield {"error": self.last_error_info}
                        return

                    # SSE stream: forward each parsed "data:" chunk until [DONE].
                    async for line in response.aiter_lines():
                        if line.startswith("data:"):
                            json_str = line[5:].strip()
                            if json_str == "[DONE]":
                                break
                            try:
                                data = json.loads(json_str)
                                yield data
                            except json.JSONDecodeError:
                                # Skip malformed/partial SSE lines.
                                continue
            else:
                response = await client.post(
                    f"{api_config['base_url']}/chat/completions",
                    json=payload,
                    headers=headers
                )
                if response.status_code != 200:
                    error_text = response.text
                    self.last_error_info = {
                        "status_code": response.status_code,
                        "error_text": error_text,
                        "error_type": None,
                        "error_message": None,
                        "request_dump": str(dump_path),
                        "base_url": api_config.get("base_url"),
                        "model_id": api_config.get("model_id"),
                        "model_key": self.model_key
                    }
                    try:
                        parsed = response.json()
                        err = parsed.get("error") if isinstance(parsed, dict) else {}
                        if isinstance(err, dict):
                            self.last_error_info["error_type"] = err.get("type")
                            self.last_error_info["error_message"] = err.get("message")
                    except Exception:
                        pass
                    self._maybe_mark_aliyun_quota(error_text)
                    self._debug_log({
                        "event": "http_error",
                        "status_code": response.status_code,
                        "error_text": error_text,
                        "base_url": api_config.get("base_url"),
                        "model_id": api_config.get("model_id"),
                        "model_key": self.model_key,
                        "request_dump": str(dump_path)
                    })
                    self._print(
                        f"{OUTPUT_FORMATS['error']} API请求失败 ({response.status_code}): {error_text} "
                        f"(base_url={api_config.get('base_url')}, model_id={api_config.get('model_id')})"
                    )
                    self._mark_request_error(dump_path, response.status_code, error_text)
                    yield {"error": self.last_error_info}
                    return
                # Success: clear any stale error state.
                self.last_error_info = None
                yield response.json()

    except httpx.ConnectError as e:
        connect_detail = str(e).strip() or repr(e)
        self._print(
            f"{OUTPUT_FORMATS['error']} 无法连接到API服务器,请检查网络连接"
            f"({connect_detail})"
        )
        self.last_error_info = {
            "status_code": None,
            "error_text": "connect_error",
            "error_type": "connection_error",
            "error_message": f"无法连接到API服务器: {connect_detail}",
            "error_detail": connect_detail,
            "request_dump": str(dump_path),
            "base_url": api_config.get("base_url"),
            "model_id": api_config.get("model_id"),
            "model_key": self.model_key
        }
        self._maybe_mark_aliyun_quota(self.last_error_info.get("error_text"))
        self._debug_log({
            "event": "connect_error",
            "status_code": None,
            "error_text": "connect_error",
            "error_detail": connect_detail,
            "base_url": api_config.get("base_url"),
            "model_id": api_config.get("model_id"),
            "model_key": self.model_key,
            "request_dump": str(dump_path)
        })
        self._mark_request_error(dump_path, error_text=f"connect_error: {connect_detail}")
        yield {"error": self.last_error_info}
    except httpx.TimeoutException:
        self._print(f"{OUTPUT_FORMATS['error']} API请求超时")
        self.last_error_info = {
            "status_code": None,
            "error_text": "timeout",
            "error_type": "timeout",
            "error_message": "API请求超时",
            "request_dump": str(dump_path),
            "base_url": api_config.get("base_url"),
            "model_id": api_config.get("model_id"),
            "model_key": self.model_key
        }
        self._maybe_mark_aliyun_quota(self.last_error_info.get("error_text"))
        self._debug_log({
            "event": "timeout",
            "status_code": None,
            "error_text": "timeout",
            "base_url": api_config.get("base_url"),
            "model_id": api_config.get("model_id"),
            "model_key": self.model_key,
            "request_dump": str(dump_path)
        })
        self._mark_request_error(dump_path, error_text="timeout")
        yield {"error": self.last_error_info}
    except Exception as e:
        # Catch-all boundary: surface any other failure as an error chunk.
        error_text = str(e).strip() or repr(e)
        self._print(f"{OUTPUT_FORMATS['error']} API调用异常: {error_text}")
        self.last_error_info = {
            "status_code": None,
            "error_text": error_text,
            "error_type": "exception",
            "error_message": error_text,
            "request_dump": str(dump_path),
            "base_url": api_config.get("base_url"),
            "model_id": api_config.get("model_id"),
            "model_key": self.model_key
        }
        self._maybe_mark_aliyun_quota(self.last_error_info.get("error_text"))
        self._debug_log({
            "event": "exception",
            "status_code": None,
            "error_text": error_text,
            "base_url": api_config.get("base_url"),
            "model_id": api_config.get("model_id"),
            "model_key": self.model_key,
            "request_dump": str(dump_path)
        })
        self._mark_request_error(dump_path, error_text=error_text)
        yield {"error": self.last_error_info}
|
||
|
||
async def chat_with_tools(
    self,
    messages: List[Dict],
    tools: List[Dict],
    tool_handler: callable
) -> str:
    """
    Run a multi-round tool-calling conversation.

    Args:
        messages: conversation messages (mutated in place as rounds proceed).
        tools: tool definitions.
        tool_handler: async callable (name, arguments) -> result string.

    Returns:
        The final assistant answer, or "" on API error / no content.

    NOTE(review): SOURCE indentation was lost; block nesting below is a
    reconstruction — in particular the repeat-operation check is placed at
    loop-body level so its `break` exits the while loop. Confirm against
    the original file.
    """
    final_response = ""
    max_iterations = 200  # hard cap on tool-call rounds
    iteration = 0
    all_tool_results = []  # record of every tool invocation

    while iteration < max_iterations:
        iteration += 1

        # Call the API (tool definitions are always supplied).
        full_response = ""
        tool_calls = []
        current_thinking = ""

        # Streaming state flags.
        in_thinking = False
        thinking_printed = False  # NOTE(review): written but never read

        async for chunk in self.chat(messages, tools, stream=True):
            if chunk.get("error"):
                # Surface the error and let the caller handle it.
                err = chunk["error"]
                self.last_error_info = err
                err_msg = err.get("error_message") or err.get("error_text") or "API调用失败"
                status = err.get("status_code")
                self._print(f"{OUTPUT_FORMATS['error']} 模型API错误{f'({status})' if status is not None else ''}: {err_msg}")
                return ""

            if "choices" not in chunk:
                continue

            delta = chunk["choices"][0].get("delta", {})

            # Reasoning (thinking) content.
            reasoning_content = self._extract_reasoning_delta(delta)
            if reasoning_content:
                if not in_thinking:
                    self._print("💭 [正在思考]\n", end="", flush=True)
                    in_thinking = True
                    thinking_printed = True
                current_thinking += reasoning_content
                self._print(reasoning_content, end="", flush=True)

            # Normal content — an independent `if`, not `elif`, since a
            # chunk may carry both reasoning and content.
            if "content" in delta:
                content = delta["content"]
                if content:  # only non-empty content
                    # Close the thinking section if one was open.
                    if in_thinking:
                        self._print("\n\n💭 [思考结束]\n\n", end="", flush=True)
                        in_thinking = False
                    full_response += content
                    self._print(content, end="", flush=True)

            # Collect tool calls — stitches JSON argument fragments by
            # streaming index so split chunks reassemble correctly.
            if "tool_calls" in delta:
                for tool_call in delta["tool_calls"]:
                    tool_index = tool_call.get("index", 0)

                    # Find or create the call with this index.
                    existing_call = None
                    for existing in tool_calls:
                        if existing.get("index") == tool_index:
                            existing_call = existing
                            break

                    if not existing_call and tool_call.get("id"):
                        # First fragment of a new tool call.
                        new_call = {
                            "id": tool_call.get("id"),
                            "index": tool_index,
                            "type": tool_call.get("type", "function"),
                            "function": {
                                "name": tool_call.get("function", {}).get("name", ""),
                                "arguments": ""
                            }
                        }
                        tool_calls.append(new_call)
                        existing_call = new_call

                    # Plain string concatenation of argument fragments —
                    # no per-fragment JSON validation on purpose.
                    if existing_call and "function" in tool_call and "arguments" in tool_call["function"]:
                        new_args = tool_call["function"]["arguments"]
                        if new_args:  # only append non-empty fragments
                            existing_call["function"]["arguments"] += new_args

        self._print("")  # final newline

        # If the model only called tools (no text), close the thinking section.
        if in_thinking:
            self._print("\n💭 [思考结束]\n")

        # Persist thinking for this task and update first-call state.
        if self.last_call_used_thinking and current_thinking:
            self.current_task_thinking = current_thinking
        if self.current_task_first_call:
            self.current_task_first_call = False  # first call of this task is done

        # No tool calls means the round (or the task) is finished.
        if not tool_calls:
            if full_response:  # got a normal answer — done
                final_response = full_response
                break
            elif iteration == 1:  # first round with neither tools nor text
                self._print(f"{OUTPUT_FORMATS['warning']} 模型未返回内容")
                break

        # Build the assistant message with everything collected this round.
        assistant_content_parts = []

        if full_response:
            assistant_content_parts.append(full_response)

        if tool_calls:
            tool_names = [tc['function']['name'] for tc in tool_calls]
            assistant_content_parts.append(f"执行工具: {', '.join(tool_names)}")

        assistant_content = "\n".join(assistant_content_parts) if assistant_content_parts else "执行工具调用"

        assistant_message = {
            "role": "assistant",
            "content": assistant_content,
            "tool_calls": tool_calls
        }
        if current_thinking:
            assistant_message["reasoning_content"] = current_thinking
        messages.append(assistant_message)

        # Execute each tool call with robust argument parsing.
        for tool_call in tool_calls:
            function_name = tool_call["function"]["name"]
            arguments_str = tool_call["function"]["arguments"]

            success, arguments, error_msg = self._safe_tool_arguments_parse(arguments_str, function_name)

            if not success:
                self._print(f"{OUTPUT_FORMATS['error']} 工具参数解析失败: {error_msg}")
                self._print(f"   工具名称: {function_name}")
                self._print(f"   参数长度: {len(arguments_str)} 字符")

                # Feed a detailed error back to the model.
                error_response = {
                    "success": False,
                    "error": error_msg,
                    "tool_name": function_name,
                    "arguments_length": len(arguments_str),
                    "suggestion": "请检查参数格式或减少参数长度后重试"
                }

                # Oversized arguments get a chunking suggestion instead.
                if len(arguments_str) > 10000:
                    error_response["suggestion"] = "参数过长,建议分块处理或使用更简洁的内容"

                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call["id"],
                    "name": function_name,
                    "content": json.dumps(error_response, ensure_ascii=False)
                })

                # Record the failed call so loop detection still works.
                all_tool_results.append({
                    "tool": function_name,
                    "args": {"parse_error": error_msg, "length": len(arguments_str)},
                    "result": f"参数解析失败: {error_msg}"
                })
                continue

            self._print(f"\n{OUTPUT_FORMATS['action']} 调用工具: {function_name}")

            tool_result = await tool_handler(function_name, arguments)

            # Parse the tool result and pretty-print read_file output.
            result_data = None
            try:
                result_data = json.loads(tool_result)
                if function_name == "read_file":
                    tool_result_msg = self._format_read_file_result(result_data)
                else:
                    tool_result_msg = tool_result
            except Exception:
                tool_result_msg = tool_result

            tool_message_content = tool_result_msg
            if (
                isinstance(result_data, dict)
                and result_data.get("success") is not False
            ):
                # Successful media tools get multimodal tool messages.
                if function_name == "view_image":
                    img_path = result_data.get("path")
                    if img_path:
                        text_part = tool_result_msg if isinstance(tool_result_msg, str) else ""
                        tool_message_content = self._build_content_with_images(text_part, [img_path])
                elif function_name == "view_video":
                    video_path = result_data.get("path")
                    if video_path:
                        text_part = tool_result_msg if isinstance(tool_result_msg, str) else ""
                        tool_message_content = self._build_content_with_images(text_part, [], [video_path])

            messages.append({
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "name": function_name,
                "content": tool_message_content
            })

            # Record the tool result.
            all_tool_results.append({
                "tool": function_name,
                "args": arguments,
                "result": tool_result_msg
            })

        # Repeatedly calling the same tool suggests the model is looping.
        if len(all_tool_results) >= 8:
            recent_tools = [r["tool"] for r in all_tool_results[-8:]]
            if len(set(recent_tools)) == 1:  # last 8 calls, all the same tool
                self._print(f"\n{OUTPUT_FORMATS['warning']} 检测到重复操作,停止执行")
                break

    if iteration >= max_iterations:
        self._print(f"\n{OUTPUT_FORMATS['warning']} 达到最大迭代次数限制")

    return final_response
|
||
|
||
async def simple_chat(self, messages: List[Dict]) -> Tuple[str, str]:
    """
    Simple chat without tool calls.

    Streams a completion via ``self.chat`` and echoes both reasoning
    ("thinking") deltas and normal content to the console as they arrive.

    Args:
        messages: OpenAI-style message dicts (``role``/``content``).
            Side effect: in thinking mode, a system message carrying the
            task's earlier reasoning may be appended to this list in
            place, so the caller's list is mutated.

    Returns:
        ``(answer_text, thinking_text)``. Both are empty strings on an
        API error or when the API returns no content at all.
    """
    full_response = ""
    thinking_content = ""
    in_thinking = False

    # In thinking mode, if an earlier call within this task already
    # produced reasoning, re-inject it so the model's chain of thought
    # stays continuous across multiple calls of the same task.
    if (
        self.thinking_mode
        and not self.current_task_first_call
        and self.current_task_thinking
    ):
        thinking_context = (
            "\n=== 📋 本次任务的思考 ===\n"
            f"{self.current_task_thinking}\n"
            "=== 思考结束 ===\n"
            "提示:以上是本轮任务先前的思考,请在此基础上继续。"
        )
        messages.append({
            "role": "system",
            "content": thinking_context
        })

    try:
        async for chunk in self.chat(messages, tools=None, stream=True):
            if chunk.get("error"):
                err = chunk["error"]
                self.last_error_info = err
                err_msg = err.get("error_message") or err.get("error_text") or "API调用失败"
                status = err.get("status_code")
                self._print(f"{OUTPUT_FORMATS['error']} 模型API错误{f'({status})' if status is not None else ''}: {err_msg}")
                return "", ""

            if "choices" not in chunk:
                continue

            delta = chunk["choices"][0].get("delta", {})

            # Reasoning ("thinking") part of the stream.
            reasoning_content = self._extract_reasoning_delta(delta)
            if reasoning_content:
                if not in_thinking:
                    self._print("💭 [正在思考]\n", end="", flush=True)
                    in_thinking = True
                thinking_content += reasoning_content
                self._print(reasoning_content, end="", flush=True)

            # Normal content — deliberately an independent `if`, not
            # `elif`: a single delta may carry reasoning AND content.
            if "content" in delta:
                content = delta["content"]
                if content:  # skip empty fragments
                    if in_thinking:
                        self._print("\n\n💭 [思考结束]\n\n", end="", flush=True)
                        in_thinking = False
                    full_response += content
                    self._print(content, end="", flush=True)

        self._print("")  # final newline

        # Rarely the stream ends while still "thinking"; close it manually.
        if in_thinking:
            self._print("\n💭 [思考结束]\n")

        # Persist this task's reasoning so a later call can re-inject it.
        if self.last_call_used_thinking and thinking_content:
            self.current_task_thinking = thinking_content
        if self.current_task_first_call:
            self.current_task_first_call = False

        # Nothing at all came back from the API.
        if not full_response and not thinking_content:
            self._print(f"{OUTPUT_FORMATS['error']} API未返回任何内容,请检查API密钥和模型ID")
            return "", ""

    except Exception as e:
        # Broad catch is intentional: a streaming failure degrades to an
        # empty result instead of crashing the caller.
        self._print(f"{OUTPUT_FORMATS['error']} API调用失败: {e}")
        return "", ""

    return full_response, thinking_content