feat: persist api usage tokens
This commit is contained in:
parent
e42a924429
commit
dd32db7677
@ -494,6 +494,7 @@ const appOptions = {
|
||||
...mapActions(useResourceStore, {
|
||||
resourceUpdateCurrentContextTokens: 'updateCurrentContextTokens',
|
||||
resourceFetchConversationTokenStatistics: 'fetchConversationTokenStatistics',
|
||||
resourceSetCurrentContextTokens: 'setCurrentContextTokens',
|
||||
resourceToggleTokenPanel: 'toggleTokenPanel',
|
||||
resourceApplyStatusSnapshot: 'applyStatusSnapshot',
|
||||
resourceUpdateContainerStatus: 'updateContainerStatus',
|
||||
|
||||
@ -462,8 +462,13 @@ export async function initializeLegacySocket(ctx: any) {
|
||||
|
||||
console.log(`累计Token统计更新: 输入=${data.cumulative_input_tokens}, 输出=${data.cumulative_output_tokens}, 总计=${data.cumulative_total_tokens}`);
|
||||
|
||||
const hasContextTokens = typeof data.current_context_tokens === 'number';
|
||||
if (hasContextTokens && typeof ctx.resourceSetCurrentContextTokens === 'function') {
|
||||
ctx.resourceSetCurrentContextTokens(data.current_context_tokens);
|
||||
} else {
|
||||
// 同时更新当前上下文Token(关键修复)
|
||||
ctx.updateCurrentContextTokens();
|
||||
}
|
||||
|
||||
ctx.$forceUpdate();
|
||||
}
|
||||
|
||||
@ -87,6 +87,9 @@ export const useResourceStore = defineStore('resource', {
|
||||
cumulative_total_tokens: 0
|
||||
};
|
||||
},
|
||||
setCurrentContextTokens(value: number) {
  // Store the context-token count pushed by the backend. Any falsy input
  // (0, NaN, undefined, null) is normalised to 0.
  const nextCount = value ? value : 0;
  this.currentContextTokens = nextCount;
},
|
||||
toggleTokenPanel() {
  // Flip the collapsed/expanded flag of the token panel.
  const wasCollapsed = this.tokenPanelCollapsed;
  this.tokenPanelCollapsed = !wasCollapsed;
},
|
||||
@ -120,6 +123,9 @@ export const useResourceStore = defineStore('resource', {
|
||||
this.currentConversationTokens.cumulative_input_tokens = data.data.total_input_tokens || 0;
|
||||
this.currentConversationTokens.cumulative_output_tokens = data.data.total_output_tokens || 0;
|
||||
this.currentConversationTokens.cumulative_total_tokens = data.data.total_tokens || 0;
|
||||
if (typeof data.data.current_context_tokens === 'number') {
|
||||
this.currentContextTokens = data.data.current_context_tokens;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('获取Token统计异常:', error);
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
import os
|
||||
import json
|
||||
import tiktoken
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Optional, Any
|
||||
from pathlib import Path
|
||||
@ -55,13 +54,6 @@ class ContextManager:
|
||||
self.auto_save_enabled = True
|
||||
self.main_terminal = None # 由宿主终端在初始化后回填,用于工具定义访问
|
||||
|
||||
# 新增:Token计算相关
|
||||
try:
|
||||
self.encoding = tiktoken.get_encoding("cl100k_base")
|
||||
except Exception as e:
|
||||
print(f"⚠️ tiktoken初始化失败: {e}")
|
||||
self.encoding = None
|
||||
|
||||
# 用于接收Web终端的回调函数
|
||||
self._web_terminal_callback = None
|
||||
self._focused_files = {}
|
||||
@ -187,87 +179,32 @@ class ContextManager:
|
||||
# 新增:Token统计相关方法
|
||||
# ===========================================
|
||||
|
||||
def calculate_input_tokens(self, messages: List[Dict], tools: Optional[List[Dict]] = None) -> int:
    """Estimate the number of input tokens for a request.

    Every message's ``content`` is encoded with ``self.encoding`` and the
    token counts are summed. When ``tools`` is given, its JSON serialisation
    is encoded and added as well.

    Args:
        messages: Message dicts; only the ``content`` and ``role`` keys are read.
        tools: Optional tool definitions to include in the count.

    Returns:
        The total token count, or 0 when no encoder is available or the
        calculation fails.
    """
    if not self.encoding:
        return 0

    try:
        total_tokens = 0

        print(f"[Debug] 开始计算输入token,messages数量: {len(messages)}")

        # Walk every message and count its content.
        for i, message in enumerate(messages):
            content = message.get("content", "")
            role = message.get("role", "unknown")
            if not content:
                continue

            # The encoder only accepts text. Structured content (for example a
            # list of content parts) is serialised first; otherwise one such
            # message would raise and zero out the count for the whole request.
            if isinstance(content, str):
                text = content
            else:
                try:
                    text = json.dumps(content, ensure_ascii=False)
                except (TypeError, ValueError):
                    text = str(content)

            msg_tokens = len(self.encoding.encode(text))
            total_tokens += msg_tokens
            print(f"[Debug] 消息 {i+1} ({role}): {msg_tokens} tokens - {text[:50]}...")

        print(f"[Debug] 消息总token: {total_tokens}")

        # Tool definitions are counted from their JSON form.
        if tools:
            tools_str = json.dumps(tools, ensure_ascii=False)
            tools_tokens = len(self.encoding.encode(tools_str))
            total_tokens += tools_tokens
            print(f"[Debug] 工具定义token: {tools_tokens}")

        print(f"[Debug] 最终输入token: {total_tokens}")
        return total_tokens
    except Exception as e:
        # Best-effort estimate: a failure must not break the caller.
        print(f"计算输入token失败: {e}")
        return 0
|
||||
|
||||
def calculate_output_tokens(self, ai_content: str) -> int:
|
||||
def apply_usage_statistics(self, usage: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
计算AI输出的token数量
|
||||
|
||||
Args:
|
||||
ai_content: AI输出的完整内容(包括thinking、文本、工具调用)
|
||||
|
||||
Returns:
|
||||
int: 输出token数量
|
||||
"""
|
||||
if not self.encoding or not ai_content:
|
||||
return 0
|
||||
|
||||
try:
|
||||
return len(self.encoding.encode(ai_content))
|
||||
except Exception as e:
|
||||
print(f"计算输出token失败: {e}")
|
||||
return 0
|
||||
|
||||
def update_token_statistics(self, input_tokens: int, output_tokens: int) -> bool:
|
||||
"""
|
||||
更新当前对话的token统计
|
||||
|
||||
Args:
|
||||
input_tokens: 输入token数量
|
||||
output_tokens: 输出token数量
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
根据模型返回的 usage 字段更新token统计
|
||||
"""
|
||||
if not self.current_conversation_id:
|
||||
print("⚠️ 没有当前对话ID,跳过token统计更新")
|
||||
print("⚠️ 没有当前对话ID,跳过usage统计更新")
|
||||
return False
|
||||
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens") or 0)
|
||||
completion_tokens = int(usage.get("completion_tokens") or 0)
|
||||
total_tokens = int(usage.get("total_tokens") or (prompt_tokens + completion_tokens))
|
||||
|
||||
success = self.conversation_manager.update_token_statistics(
|
||||
self.current_conversation_id,
|
||||
input_tokens,
|
||||
output_tokens
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens
|
||||
)
|
||||
|
||||
if success:
|
||||
# 广播token更新事件
|
||||
self.safe_broadcast_token_update()
|
||||
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"更新token统计失败: {e}")
|
||||
print(f"更新usage统计失败: {e}")
|
||||
return False
|
||||
|
||||
def get_conversation_token_statistics(self, conversation_id: str = None) -> Optional[Dict]:
|
||||
@ -286,6 +223,15 @@ class ContextManager:
|
||||
|
||||
return self.conversation_manager.get_token_statistics(target_id)
|
||||
|
||||
def get_current_context_tokens(self, conversation_id: str = None) -> int:
    """Return the context token count recorded for the most recent request.

    The value is read from the ``current_context_tokens`` field of the
    statistics returned by ``get_conversation_token_statistics``. 0 is
    returned when no statistics are available or the field is absent.
    """
    snapshot = self.get_conversation_token_statistics(conversation_id)
    return snapshot.get("current_context_tokens", 0) if snapshot else 0
|
||||
|
||||
# ===========================================
|
||||
# 新增:对话持久化相关方法
|
||||
# ===========================================
|
||||
@ -647,6 +593,7 @@ class ContextManager:
|
||||
'cumulative_input_tokens': cumulative_stats.get("total_input_tokens", 0) if cumulative_stats else 0,
|
||||
'cumulative_output_tokens': cumulative_stats.get("total_output_tokens", 0) if cumulative_stats else 0,
|
||||
'cumulative_total_tokens': cumulative_stats.get("total_tokens", 0) if cumulative_stats else 0,
|
||||
'current_context_tokens': cumulative_stats.get("current_context_tokens", 0) if cumulative_stats else 0,
|
||||
'updated_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
@ -714,33 +661,9 @@ class ContextManager:
|
||||
# 自动保存
|
||||
self.auto_save_conversation()
|
||||
|
||||
# 特殊处理:如果是用户消息,需要计算并更新输入token
|
||||
if role == "user":
|
||||
self._handle_user_message_token_update()
|
||||
else:
|
||||
# 其他消息只需要广播现有统计
|
||||
print(f"[Debug] 添加{role}消息后广播token更新")
|
||||
self.safe_broadcast_token_update()
|
||||
|
||||
def _handle_user_message_token_update(self):
|
||||
"""处理用户消息的token更新(计算输入token并更新统计)"""
|
||||
try:
|
||||
print(f"[Debug] 用户发送消息,开始计算输入token")
|
||||
|
||||
# 需要访问web_terminal来构建完整的messages
|
||||
# 这里有个问题:add_conversation是在用户消息添加后调用的
|
||||
# 但我们需要构建包含这条消息的完整context来计算输入token
|
||||
|
||||
# 临时解决方案:延迟计算,让web_server负责在构建messages后计算输入token
|
||||
# 这里只广播现有统计
|
||||
print(f"[Debug] 用户消息添加完成,广播现有token统计")
|
||||
self.safe_broadcast_token_update()
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Debug] 处理用户消息token更新失败: {e}")
|
||||
# 失败时仍然广播现有统计
|
||||
self.safe_broadcast_token_update()
|
||||
|
||||
def add_tool_result(self, tool_call_id: str, function_name: str, result: str):
|
||||
"""添加工具调用结果(保留方法以兼容)"""
|
||||
self.add_conversation(
|
||||
|
||||
@ -16,7 +16,6 @@ except ImportError:
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
from config import DATA_DIR
|
||||
import tiktoken
|
||||
|
||||
@dataclass
|
||||
class ConversationMetadata:
|
||||
@ -44,13 +43,6 @@ class ConversationManager:
|
||||
self._ensure_directories()
|
||||
self._load_index()
|
||||
|
||||
# 初始化tiktoken编码器
|
||||
try:
|
||||
self.encoding = tiktoken.get_encoding("cl100k_base")
|
||||
except Exception as e:
|
||||
print(f"⚠️ tiktoken初始化失败: {e}")
|
||||
self.encoding = None
|
||||
|
||||
def _ensure_directories(self):
|
||||
"""确保必要的目录存在"""
|
||||
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||
@ -219,10 +211,13 @@ class ConversationManager:
|
||||
|
||||
def _initialize_token_statistics(self) -> Dict:
|
||||
"""初始化Token统计结构"""
|
||||
now = datetime.now().isoformat()
|
||||
return {
|
||||
"total_input_tokens": 0,
|
||||
"total_output_tokens": 0,
|
||||
"updated_at": datetime.now().isoformat()
|
||||
"total_tokens": 0,
|
||||
"current_context_tokens": 0,
|
||||
"updated_at": now
|
||||
}
|
||||
|
||||
def _validate_token_statistics(self, data: Dict) -> Dict:
|
||||
@ -230,21 +225,20 @@ class ConversationManager:
|
||||
token_stats = data.get("token_statistics", {})
|
||||
|
||||
# 确保必要字段存在
|
||||
if "total_input_tokens" not in token_stats:
|
||||
token_stats["total_input_tokens"] = 0
|
||||
if "total_output_tokens" not in token_stats:
|
||||
token_stats["total_output_tokens"] = 0
|
||||
if "updated_at" not in token_stats:
|
||||
token_stats["updated_at"] = datetime.now().isoformat()
|
||||
defaults = self._initialize_token_statistics()
|
||||
for key, default_value in defaults.items():
|
||||
if key not in token_stats:
|
||||
token_stats[key] = default_value
|
||||
|
||||
# 确保数值类型正确
|
||||
try:
|
||||
token_stats["total_input_tokens"] = int(token_stats["total_input_tokens"])
|
||||
token_stats["total_output_tokens"] = int(token_stats["total_output_tokens"])
|
||||
token_stats["total_input_tokens"] = int(token_stats.get("total_input_tokens", 0))
|
||||
token_stats["total_output_tokens"] = int(token_stats.get("total_output_tokens", 0))
|
||||
token_stats["total_tokens"] = int(token_stats.get("total_tokens", 0))
|
||||
token_stats["current_context_tokens"] = int(token_stats.get("current_context_tokens", 0))
|
||||
except (ValueError, TypeError):
|
||||
print("⚠️ Token统计数据损坏,重置为0")
|
||||
token_stats["total_input_tokens"] = 0
|
||||
token_stats["total_output_tokens"] = 0
|
||||
token_stats = defaults
|
||||
|
||||
data["token_statistics"] = token_stats
|
||||
return data
|
||||
@ -466,7 +460,13 @@ class ConversationManager:
|
||||
print(f"⌘ 加载对话失败 {conversation_id}: {e}")
|
||||
return None
|
||||
|
||||
def update_token_statistics(self, conversation_id: str, input_tokens: int, output_tokens: int) -> bool:
|
||||
def update_token_statistics(
|
||||
self,
|
||||
conversation_id: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
total_tokens: int
|
||||
) -> bool:
|
||||
"""
|
||||
更新对话的Token统计
|
||||
|
||||
@ -474,6 +474,7 @@ class ConversationManager:
|
||||
conversation_id: 对话ID
|
||||
input_tokens: 输入Token数量
|
||||
output_tokens: 输出Token数量
|
||||
total_tokens: 本次请求的总Token数量(prompt+completion)
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
@ -492,6 +493,8 @@ class ConversationManager:
|
||||
token_stats = conversation_data["token_statistics"]
|
||||
token_stats["total_input_tokens"] = token_stats.get("total_input_tokens", 0) + input_tokens
|
||||
token_stats["total_output_tokens"] = token_stats.get("total_output_tokens", 0) + output_tokens
|
||||
token_stats["total_tokens"] = token_stats.get("total_tokens", 0) + total_tokens
|
||||
token_stats["current_context_tokens"] = total_tokens
|
||||
token_stats["updated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 保存更新
|
||||
@ -520,13 +523,14 @@ class ConversationManager:
|
||||
if not conversation_data:
|
||||
return None
|
||||
|
||||
token_stats = conversation_data.get("token_statistics", {})
|
||||
validated = self._validate_token_statistics(conversation_data)
|
||||
token_stats = validated.get("token_statistics", {})
|
||||
|
||||
# 确保基本字段存在
|
||||
result = {
|
||||
"total_input_tokens": token_stats.get("total_input_tokens", 0),
|
||||
"total_output_tokens": token_stats.get("total_output_tokens", 0),
|
||||
"total_tokens": token_stats.get("total_input_tokens", 0) + token_stats.get("total_output_tokens", 0),
|
||||
"total_tokens": token_stats.get("total_tokens", 0),
|
||||
"current_context_tokens": token_stats.get("current_context_tokens", 0),
|
||||
"updated_at": token_stats.get("updated_at"),
|
||||
"conversation_id": conversation_id
|
||||
}
|
||||
@ -536,6 +540,13 @@ class ConversationManager:
|
||||
print(f"⌘ 获取Token统计失败 {conversation_id}: {e}")
|
||||
return None
|
||||
|
||||
def get_current_context_tokens(self, conversation_id: str) -> int:
    """Return the context token count of the most recent request.

    Looks up the statistics via ``get_token_statistics`` and reads their
    ``current_context_tokens`` field. 0 is returned when the lookup yields
    nothing or the field is absent.
    """
    snapshot = self.get_token_statistics(conversation_id)
    if snapshot:
        return snapshot.get("current_context_tokens", 0)
    return 0
|
||||
|
||||
def get_conversation_list(self, limit: int = 50, offset: int = 0) -> Dict:
|
||||
"""
|
||||
获取对话列表
|
||||
@ -807,48 +818,3 @@ class ConversationManager:
|
||||
"""设置当前对话ID"""
|
||||
self.current_conversation_id = conversation_id
|
||||
|
||||
def calculate_conversation_tokens(self, conversation_id: str, context_manager=None, focused_files=None, terminal_content="") -> dict:
    """Compute the API token cost of a conversation's current context.

    The request is rebuilt the same way the host terminal builds it
    (context, messages, tools) when ``context_manager.main_terminal`` is set.
    Otherwise the context manager's own builders are used. The token count
    comes from ``context_manager.calculate_input_tokens``.

    ``focused_files`` and ``terminal_content`` are accepted but not read here.

    Returns:
        ``{"total_tokens": <int>}``; the count is 0 when no context manager is
        given, the conversation cannot be loaded, or any step raises.
    """
    empty_result = {"total_tokens": 0}
    try:
        if not context_manager:
            return empty_result

        # The conversation only has to exist; its payload is not used below.
        if not self.load_conversation(conversation_id):
            return empty_result

        host_terminal = getattr(context_manager, "main_terminal", None)
        if host_terminal:
            # Follow the host terminal's build pipeline so the estimate
            # matches the real API request.
            built_context = host_terminal.build_context()
            request_messages = host_terminal.build_messages(built_context, "")
            tool_definitions = host_terminal.define_tools()
        else:
            built_context = context_manager.build_main_context(memory_content="")
            request_messages = context_manager.build_messages(built_context, "")
            tool_definitions = self._get_tools_definition(context_manager) or []

        return {
            "total_tokens": context_manager.calculate_input_tokens(
                request_messages, tool_definitions
            )
        }
    except Exception as e:
        print(f"计算token失败: {e}")
        return {"total_tokens": 0}
|
||||
def _get_tools_definition(self, context_manager):
|
||||
"""获取工具定义"""
|
||||
try:
|
||||
# 需要找到工具定义的来源,通常在 main_terminal 中
|
||||
# 你需要找到 main_terminal 的引用或者 define_tools 方法
|
||||
|
||||
# 方法1: 如果 context_manager 有 main_terminal 引用
|
||||
if hasattr(context_manager, 'main_terminal') and context_manager.main_terminal:
|
||||
return context_manager.main_terminal.define_tools()
|
||||
|
||||
# 方法2: 如果有其他方式获取工具定义
|
||||
# 你需要去找一下在哪里调用了 calculate_input_tokens,看看 tools 参数是怎么传的
|
||||
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"获取工具定义失败: {e}")
|
||||
return []
|
||||
|
||||
197
web_server.py
197
web_server.py
@ -358,6 +358,9 @@ def format_tool_result_notice(tool_name: str, tool_call_id: Optional[str], conte
|
||||
|
||||
# 创建调试日志文件
|
||||
DEBUG_LOG_FILE = Path(LOGS_DIR).expanduser().resolve() / "debug_stream.log"
|
||||
CHUNK_BACKEND_LOG_FILE = Path(LOGS_DIR).expanduser().resolve() / "chunk_backend.log"
|
||||
CHUNK_FRONTEND_LOG_FILE = Path(LOGS_DIR).expanduser().resolve() / "chunk_frontend.log"
|
||||
STREAMING_DEBUG_LOG_FILE = Path(LOGS_DIR).expanduser().resolve() / "streaming_debug.log"
|
||||
UPLOAD_FOLDER_NAME = "user_upload"
|
||||
|
||||
|
||||
@ -586,14 +589,41 @@ def reset_system_state(terminal: Optional[WebTerminal]):
|
||||
debug_log(f"错误详情: {traceback.format_exc()}")
|
||||
|
||||
|
||||
def debug_log(message):
|
||||
"""写入调试日志"""
|
||||
DEBUG_LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
with DEBUG_LOG_FILE.open('a', encoding='utf-8') as f:
|
||||
def _write_log(file_path: Path, message: str) -> None:
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open('a', encoding='utf-8') as f:
|
||||
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
|
||||
f.write(f"[{timestamp}] {message}\n")
|
||||
|
||||
|
||||
def debug_log(message):
    """Append ``message`` to the debug log file (``DEBUG_LOG_FILE``)."""
    _write_log(file_path=DEBUG_LOG_FILE, message=message)
|
||||
|
||||
|
||||
def log_backend_chunk(conversation_id: str, iteration: int, chunk_index: int, elapsed: float, char_len: int, content_preview: str):
    """Record one backend text chunk in the backend chunk log.

    Newlines in ``content_preview`` are written as a literal ``\\n`` so each
    entry stays on one line. A falsy ``conversation_id`` is logged as
    ``unknown``.
    """
    # Same result as replacing every newline with the two characters "\n".
    preview = '\\n'.join(content_preview.split('\n'))
    entry = f"conv={conversation_id or 'unknown'} iter={iteration} chunk={chunk_index} elapsed={elapsed:.3f}s len={char_len} preview={preview}"
    _write_log(CHUNK_BACKEND_LOG_FILE, entry)
|
||||
|
||||
|
||||
def log_frontend_chunk(conversation_id: str, chunk_index: int, elapsed: float, char_len: int, client_ts: float):
    """Record one client-reported text chunk in the frontend chunk log.

    A falsy ``conversation_id`` is logged as ``unknown``; ``elapsed`` is
    written with millisecond precision.
    """
    entry = f"conv={conversation_id or 'unknown'} chunk={chunk_index} elapsed={elapsed:.3f}s len={char_len} client_ts={client_ts}"
    _write_log(CHUNK_FRONTEND_LOG_FILE, entry)
|
||||
|
||||
|
||||
def log_streaming_debug_entry(data: Dict[str, Any]):
    """Write ``data`` to the streaming debug log as a single line.

    The entry is serialised as JSON without ASCII escaping. If serialisation
    fails for any reason, ``str(data)`` is logged instead.
    """
    try:
        line = json.dumps(data, ensure_ascii=False)
    except Exception:
        # Best-effort logging: fall back to the plain string representation.
        line = str(data)
    _write_log(STREAMING_DEBUG_LOG_FILE, line)
|
||||
|
||||
|
||||
def get_thinking_state(terminal: WebTerminal) -> Dict[str, Any]:
|
||||
"""获取(或初始化)思考调度状态。"""
|
||||
state = getattr(terminal, "_thinking_state", None)
|
||||
@ -1884,6 +1914,27 @@ def handle_message(data):
|
||||
# 传递客户端ID
|
||||
socketio.start_background_task(process_message_task, terminal, message, send_to_client, client_sid)
|
||||
|
||||
|
||||
@socketio.on('client_chunk_log')
def handle_client_chunk_log(data):
    """Receive a chunk timing report from the frontend and log it.

    Fields read from ``data``: ``conversation_id``; ``index`` (or
    ``chunk_index``); ``elapsed``; ``length`` (falls back to the length of
    ``content``); ``ts``. Missing or falsy numeric fields default to 0.

    The payload comes from the client, so it is validated before use. A
    non-dict payload or a field that cannot be converted to a number is noted
    in the debug log and dropped, so nothing is raised from the handler.
    """
    if not isinstance(data, dict):
        return

    try:
        chunk_index = int(data.get('index') or data.get('chunk_index') or 0)
        elapsed = float(data.get('elapsed') or 0.0)
        length = int(data.get('length') or len(data.get('content') or ""))
        client_ts = float(data.get('ts') or 0.0)
    except (TypeError, ValueError) as e:
        debug_log(f"client_chunk_log payload ignored: {e}")
        return

    log_frontend_chunk(data.get('conversation_id'), chunk_index, elapsed, length, client_ts)
|
||||
|
||||
|
||||
@socketio.on('client_stream_debug_log')
def handle_client_stream_debug_log(data):
    """Receive a streaming debug record from the frontend and log it.

    Non-dict payloads are ignored. The payload is copied, a ``server_ts``
    timestamp is added unless the client already supplied one, and the copy
    is handed to ``log_streaming_debug_entry``.
    """
    if isinstance(data, dict):
        entry = {**data}
        if 'server_ts' not in entry:
            entry['server_ts'] = time.time()
        log_streaming_debug_entry(entry)
|
||||
|
||||
# 在 web_server.py 中添加以下对话管理API接口
|
||||
# 添加在现有路由之后,@socketio 事件处理之前
|
||||
|
||||
@ -2426,6 +2477,7 @@ def detect_malformed_tool_call(text):
|
||||
async def handle_task_with_sender(terminal: WebTerminal, message, sender, client_sid):
|
||||
"""处理任务并发送消息 - 集成token统计版本"""
|
||||
web_terminal = terminal
|
||||
conversation_id = getattr(web_terminal.context_manager, "current_conversation_id", None)
|
||||
|
||||
# 如果是思考模式,重置状态
|
||||
if web_terminal.thinking_mode:
|
||||
@ -3076,22 +3128,13 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
mark_force_thinking(web_terminal, reason="tool_limit")
|
||||
break
|
||||
|
||||
# === 修改:每次API调用前都计算输入token ===
|
||||
try:
|
||||
input_tokens = web_terminal.context_manager.calculate_input_tokens(messages, tools)
|
||||
debug_log(f"第{iteration + 1}次API调用输入token: {input_tokens}")
|
||||
|
||||
# 更新输入token统计
|
||||
web_terminal.context_manager.update_token_statistics(input_tokens, 0)
|
||||
except Exception as e:
|
||||
debug_log(f"输入token统计失败: {e}")
|
||||
|
||||
apply_thinking_schedule(web_terminal)
|
||||
|
||||
full_response = ""
|
||||
tool_calls = []
|
||||
current_thinking = ""
|
||||
detected_tools = {}
|
||||
last_usage_payload = None
|
||||
|
||||
# 状态标志
|
||||
in_thinking = False
|
||||
@ -3099,65 +3142,9 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
thinking_ended = False
|
||||
text_started = False
|
||||
text_has_content = False
|
||||
TEXT_BUFFER_MAX_CHARS = 1
|
||||
TEXT_BUFFER_MAX_INTERVAL = 0.0
|
||||
TEXT_BUFFER_FLUSH_CHARS = 1
|
||||
text_chunk_buffer: deque[str] = deque()
|
||||
text_chunk_buffer_size = 0
|
||||
last_text_flush_time = time.time()
|
||||
TEXT_BUFFER_CHAR_DELAY = 0.02
|
||||
|
||||
def _drain_text_buffer(force: bool = False) -> bool:
|
||||
nonlocal text_chunk_buffer, text_chunk_buffer_size, last_text_flush_time
|
||||
if not text_chunk_buffer:
|
||||
return False
|
||||
|
||||
drain_all = force or TEXT_BUFFER_MAX_INTERVAL == 0.0
|
||||
sent = False
|
||||
while text_chunk_buffer:
|
||||
now = time.time()
|
||||
should_flush = (
|
||||
force
|
||||
or text_chunk_buffer_size >= TEXT_BUFFER_MAX_CHARS
|
||||
or TEXT_BUFFER_MAX_INTERVAL == 0.0
|
||||
or (TEXT_BUFFER_MAX_INTERVAL > 0 and (now - last_text_flush_time) >= TEXT_BUFFER_MAX_INTERVAL)
|
||||
)
|
||||
if not should_flush:
|
||||
break
|
||||
|
||||
batch_size = text_chunk_buffer_size if drain_all else max(1, min(text_chunk_buffer_size, TEXT_BUFFER_FLUSH_CHARS or 1))
|
||||
pieces: List[str] = []
|
||||
remaining = batch_size
|
||||
|
||||
while text_chunk_buffer and remaining > 0:
|
||||
chunk = text_chunk_buffer.popleft()
|
||||
chunk_len = len(chunk)
|
||||
if chunk_len <= remaining:
|
||||
pieces.append(chunk)
|
||||
remaining -= chunk_len
|
||||
else:
|
||||
pieces.append(chunk[:remaining])
|
||||
text_chunk_buffer.appendleft(chunk[remaining:])
|
||||
chunk_len = remaining
|
||||
remaining = 0
|
||||
text_chunk_buffer_size -= chunk_len
|
||||
|
||||
if not pieces:
|
||||
break
|
||||
|
||||
sender('text_chunk', {'content': "".join(pieces)})
|
||||
last_text_flush_time = now
|
||||
sent = True
|
||||
|
||||
if not drain_all:
|
||||
break
|
||||
return sent
|
||||
|
||||
async def flush_text_buffer(force: bool = False):
|
||||
sent = _drain_text_buffer(force)
|
||||
if sent and not force and TEXT_BUFFER_CHAR_DELAY > 0:
|
||||
await asyncio.sleep(TEXT_BUFFER_CHAR_DELAY)
|
||||
text_streaming = False
|
||||
text_chunk_index = 0
|
||||
last_text_chunk_time: Optional[float] = None
|
||||
|
||||
# 计数器
|
||||
chunk_count = 0
|
||||
@ -3223,6 +3210,10 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
if finish_reason:
|
||||
last_finish_reason = finish_reason
|
||||
|
||||
usage_info = choice.get("usage")
|
||||
if usage_info:
|
||||
last_usage_payload = usage_info
|
||||
|
||||
# 处理思考内容
|
||||
if "reasoning_content" in delta:
|
||||
reasoning_content = delta["reasoning_content"]
|
||||
@ -3454,10 +3445,23 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
full_response += content
|
||||
accumulated_response += content
|
||||
text_has_content = True
|
||||
for ch in content:
|
||||
text_chunk_buffer.append(ch)
|
||||
text_chunk_buffer_size += 1
|
||||
await flush_text_buffer()
|
||||
emit_time = time.time()
|
||||
elapsed = 0.0 if last_text_chunk_time is None else emit_time - last_text_chunk_time
|
||||
last_text_chunk_time = emit_time
|
||||
text_chunk_index += 1
|
||||
log_backend_chunk(
|
||||
conversation_id,
|
||||
iteration + 1,
|
||||
text_chunk_index,
|
||||
elapsed,
|
||||
len(content),
|
||||
content[:32]
|
||||
)
|
||||
sender('text_chunk', {
|
||||
'content': content,
|
||||
'index': text_chunk_index,
|
||||
'elapsed': elapsed
|
||||
})
|
||||
|
||||
# 收集工具调用 - 实时发送准备状态
|
||||
if "tool_calls" in delta:
|
||||
@ -3508,24 +3512,20 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
return
|
||||
|
||||
# === API响应完成后只计算输出token ===
|
||||
if last_usage_payload:
|
||||
try:
|
||||
ai_output_content = full_response or append_result.get("assistant_content") or modify_result.get("assistant_content") or ""
|
||||
if tool_calls:
|
||||
ai_output_content += json.dumps(tool_calls, ensure_ascii=False)
|
||||
|
||||
if ai_output_content.strip():
|
||||
output_tokens = web_terminal.context_manager.calculate_output_tokens(ai_output_content)
|
||||
debug_log(f"第{iteration + 1}次API调用输出token: {output_tokens}")
|
||||
|
||||
# 只更新输出token统计
|
||||
web_terminal.context_manager.update_token_statistics(0, output_tokens)
|
||||
else:
|
||||
debug_log("没有AI输出内容,跳过输出token统计")
|
||||
web_terminal.context_manager.apply_usage_statistics(last_usage_payload)
|
||||
debug_log(
|
||||
f"Usage统计: prompt={last_usage_payload.get('prompt_tokens', 0)}, "
|
||||
f"completion={last_usage_payload.get('completion_tokens', 0)}, "
|
||||
f"total={last_usage_payload.get('total_tokens', 0)}"
|
||||
)
|
||||
except Exception as e:
|
||||
debug_log(f"输出token统计失败: {e}")
|
||||
debug_log(f"Usage统计更新失败: {e}")
|
||||
else:
|
||||
debug_log("未获取到usage字段,跳过token统计更新")
|
||||
|
||||
# 流结束后的处理
|
||||
await flush_text_buffer(force=True)
|
||||
debug_log(f"\n流结束统计:")
|
||||
debug_log(f" 总chunks: {chunk_count}")
|
||||
debug_log(f" 思考chunks: {reasoning_chunks}")
|
||||
@ -3548,7 +3548,6 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
|
||||
# 确保text_end事件被发送
|
||||
if text_started and text_has_content and not append_result["handled"] and not modify_result["handled"]:
|
||||
await flush_text_buffer(force=True)
|
||||
debug_log(f"发送text_end事件,完整内容长度: {len(full_response)}")
|
||||
sender('text_end', {'full_content': full_response})
|
||||
await asyncio.sleep(0.1)
|
||||
@ -4251,20 +4250,12 @@ def get_conversation_token_statistics(conversation_id, terminal: WebTerminal, wo
|
||||
def get_conversation_tokens(conversation_id, terminal: WebTerminal, workspace: UserWorkspace, username: str):
|
||||
"""获取对话的当前完整上下文token数(包含所有动态内容)"""
|
||||
try:
|
||||
# 获取当前聚焦文件状态
|
||||
focused_files = terminal.get_focused_files_info()
|
||||
|
||||
# 计算完整token
|
||||
tokens = terminal.context_manager.conversation_manager.calculate_conversation_tokens(
|
||||
conversation_id=conversation_id,
|
||||
context_manager=terminal.context_manager,
|
||||
focused_files=focused_files,
|
||||
terminal_content=""
|
||||
)
|
||||
|
||||
current_tokens = terminal.context_manager.get_current_context_tokens(conversation_id)
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"data": tokens
|
||||
"data": {
|
||||
"total_tokens": current_tokens
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
return jsonify({
|
||||
|
||||
Loading…
Reference in New Issue
Block a user