feat: persist api usage tokens

This commit is contained in:
JOJO 2025-11-29 17:45:32 +08:00
parent e42a924429
commit dd32db7677
6 changed files with 167 additions and 275 deletions

View File

@ -494,6 +494,7 @@ const appOptions = {
...mapActions(useResourceStore, { ...mapActions(useResourceStore, {
resourceUpdateCurrentContextTokens: 'updateCurrentContextTokens', resourceUpdateCurrentContextTokens: 'updateCurrentContextTokens',
resourceFetchConversationTokenStatistics: 'fetchConversationTokenStatistics', resourceFetchConversationTokenStatistics: 'fetchConversationTokenStatistics',
resourceSetCurrentContextTokens: 'setCurrentContextTokens',
resourceToggleTokenPanel: 'toggleTokenPanel', resourceToggleTokenPanel: 'toggleTokenPanel',
resourceApplyStatusSnapshot: 'applyStatusSnapshot', resourceApplyStatusSnapshot: 'applyStatusSnapshot',
resourceUpdateContainerStatus: 'updateContainerStatus', resourceUpdateContainerStatus: 'updateContainerStatus',

View File

@ -462,8 +462,13 @@ export async function initializeLegacySocket(ctx: any) {
console.log(`累计Token统计更新: 输入=${data.cumulative_input_tokens}, 输出=${data.cumulative_output_tokens}, 总计=${data.cumulative_total_tokens}`); console.log(`累计Token统计更新: 输入=${data.cumulative_input_tokens}, 输出=${data.cumulative_output_tokens}, 总计=${data.cumulative_total_tokens}`);
// 同时更新当前上下文Token关键修复 const hasContextTokens = typeof data.current_context_tokens === 'number';
ctx.updateCurrentContextTokens(); if (hasContextTokens && typeof ctx.resourceSetCurrentContextTokens === 'function') {
ctx.resourceSetCurrentContextTokens(data.current_context_tokens);
} else {
// 同时更新当前上下文Token关键修复
ctx.updateCurrentContextTokens();
}
ctx.$forceUpdate(); ctx.$forceUpdate();
} }

View File

@ -87,6 +87,9 @@ export const useResourceStore = defineStore('resource', {
cumulative_total_tokens: 0 cumulative_total_tokens: 0
}; };
}, },
setCurrentContextTokens(value: number) {
  // Overwrite the stored context-token count with the supplied value;
  // any falsy input (0, NaN, undefined, null) is stored as 0.
  this.currentContextTokens = value ? value : 0;
},
toggleTokenPanel() { toggleTokenPanel() {
this.tokenPanelCollapsed = !this.tokenPanelCollapsed; this.tokenPanelCollapsed = !this.tokenPanelCollapsed;
}, },
@ -120,6 +123,9 @@ export const useResourceStore = defineStore('resource', {
this.currentConversationTokens.cumulative_input_tokens = data.data.total_input_tokens || 0; this.currentConversationTokens.cumulative_input_tokens = data.data.total_input_tokens || 0;
this.currentConversationTokens.cumulative_output_tokens = data.data.total_output_tokens || 0; this.currentConversationTokens.cumulative_output_tokens = data.data.total_output_tokens || 0;
this.currentConversationTokens.cumulative_total_tokens = data.data.total_tokens || 0; this.currentConversationTokens.cumulative_total_tokens = data.data.total_tokens || 0;
if (typeof data.data.current_context_tokens === 'number') {
this.currentContextTokens = data.data.current_context_tokens;
}
} }
} catch (error) { } catch (error) {
console.warn('获取Token统计异常:', error); console.warn('获取Token统计异常:', error);

View File

@ -2,7 +2,6 @@
import os import os
import json import json
import tiktoken
from copy import deepcopy from copy import deepcopy
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
from pathlib import Path from pathlib import Path
@ -55,13 +54,6 @@ class ContextManager:
self.auto_save_enabled = True self.auto_save_enabled = True
self.main_terminal = None # 由宿主终端在初始化后回填,用于工具定义访问 self.main_terminal = None # 由宿主终端在初始化后回填,用于工具定义访问
# 新增Token计算相关
try:
self.encoding = tiktoken.get_encoding("cl100k_base")
except Exception as e:
print(f"⚠️ tiktoken初始化失败: {e}")
self.encoding = None
# 用于接收Web终端的回调函数 # 用于接收Web终端的回调函数
self._web_terminal_callback = None self._web_terminal_callback = None
self._focused_files = {} self._focused_files = {}
@ -187,87 +179,32 @@ class ContextManager:
# 新增Token统计相关方法 # 新增Token统计相关方法
# =========================================== # ===========================================
def calculate_input_tokens(self, messages: List[Dict], tools: List[Dict] = None) -> int: def apply_usage_statistics(self, usage: Dict[str, Any]) -> bool:
if not self.encoding:
return 0
try:
total_tokens = 0
print(f"[Debug] 开始计算输入tokenmessages数量: {len(messages)}")
# 详细分析每条消息
for i, message in enumerate(messages):
content = message.get("content", "")
role = message.get("role", "unknown")
if content:
msg_tokens = len(self.encoding.encode(content))
total_tokens += msg_tokens
print(f"[Debug] 消息 {i+1} ({role}): {msg_tokens} tokens - {content[:50]}...")
print(f"[Debug] 消息总token: {total_tokens}")
# 工具定义
if tools:
tools_str = json.dumps(tools, ensure_ascii=False)
tools_tokens = len(self.encoding.encode(tools_str))
total_tokens += tools_tokens
print(f"[Debug] 工具定义token: {tools_tokens}")
print(f"[Debug] 最终输入token: {total_tokens}")
return total_tokens
except Exception as e:
print(f"计算输入token失败: {e}")
return 0
def calculate_output_tokens(self, ai_content: str) -> int:
""" """
计算AI输出的token数量 根据模型返回的 usage 字段更新token统计
Args:
ai_content: AI输出的完整内容包括thinking文本工具调用
Returns:
int: 输出token数量
"""
if not self.encoding or not ai_content:
return 0
try:
return len(self.encoding.encode(ai_content))
except Exception as e:
print(f"计算输出token失败: {e}")
return 0
def update_token_statistics(self, input_tokens: int, output_tokens: int) -> bool:
"""
更新当前对话的token统计
Args:
input_tokens: 输入token数量
output_tokens: 输出token数量
Returns:
bool: 更新是否成功
""" """
if not self.current_conversation_id: if not self.current_conversation_id:
print("⚠️ 没有当前对话ID跳过token统计更新") print("⚠️ 没有当前对话ID跳过usage统计更新")
return False return False
try: try:
prompt_tokens = int(usage.get("prompt_tokens") or 0)
completion_tokens = int(usage.get("completion_tokens") or 0)
total_tokens = int(usage.get("total_tokens") or (prompt_tokens + completion_tokens))
success = self.conversation_manager.update_token_statistics( success = self.conversation_manager.update_token_statistics(
self.current_conversation_id, self.current_conversation_id,
input_tokens, prompt_tokens,
output_tokens completion_tokens,
total_tokens
) )
if success: if success:
# 广播token更新事件
self.safe_broadcast_token_update() self.safe_broadcast_token_update()
return success return success
except Exception as e: except Exception as e:
print(f"更新token统计失败: {e}") print(f"更新usage统计失败: {e}")
return False return False
def get_conversation_token_statistics(self, conversation_id: str = None) -> Optional[Dict]: def get_conversation_token_statistics(self, conversation_id: str = None) -> Optional[Dict]:
@ -286,6 +223,15 @@ class ContextManager:
return self.conversation_manager.get_token_statistics(target_id) return self.conversation_manager.get_token_statistics(target_id)
def get_current_context_tokens(self, conversation_id: Optional[str] = None) -> int:
    """
    Return the `current_context_tokens` value from a conversation's token statistics.

    Args:
        conversation_id: Conversation to look up; passed straight through to
            `get_conversation_token_statistics` (which receives None when omitted).

    Returns:
        int: The stored context-token count, or 0 when no statistics are
        available or the stored value is missing, None, or not numeric.
    """
    stats = self.get_conversation_token_statistics(conversation_id)
    if not stats:
        return 0
    # The statistics come from persisted data, so do not trust the stored
    # type: honour the declared `-> int` contract instead of leaking None/str.
    try:
        return int(stats.get("current_context_tokens") or 0)
    except (TypeError, ValueError):
        return 0
# =========================================== # ===========================================
# 新增:对话持久化相关方法 # 新增:对话持久化相关方法
# =========================================== # ===========================================
@ -647,6 +593,7 @@ class ContextManager:
'cumulative_input_tokens': cumulative_stats.get("total_input_tokens", 0) if cumulative_stats else 0, 'cumulative_input_tokens': cumulative_stats.get("total_input_tokens", 0) if cumulative_stats else 0,
'cumulative_output_tokens': cumulative_stats.get("total_output_tokens", 0) if cumulative_stats else 0, 'cumulative_output_tokens': cumulative_stats.get("total_output_tokens", 0) if cumulative_stats else 0,
'cumulative_total_tokens': cumulative_stats.get("total_tokens", 0) if cumulative_stats else 0, 'cumulative_total_tokens': cumulative_stats.get("total_tokens", 0) if cumulative_stats else 0,
'current_context_tokens': cumulative_stats.get("current_context_tokens", 0) if cumulative_stats else 0,
'updated_at': datetime.now().isoformat() 'updated_at': datetime.now().isoformat()
} }
@ -714,32 +661,8 @@ class ContextManager:
# 自动保存 # 自动保存
self.auto_save_conversation() self.auto_save_conversation()
# 特殊处理如果是用户消息需要计算并更新输入token print(f"[Debug] 添加{role}消息后广播token更新")
if role == "user": self.safe_broadcast_token_update()
self._handle_user_message_token_update()
else:
# 其他消息只需要广播现有统计
print(f"[Debug] 添加{role}消息后广播token更新")
self.safe_broadcast_token_update()
def _handle_user_message_token_update(self):
"""处理用户消息的token更新计算输入token并更新统计"""
try:
print(f"[Debug] 用户发送消息开始计算输入token")
# 需要访问web_terminal来构建完整的messages
# 这里有个问题add_conversation是在用户消息添加后调用的
# 但我们需要构建包含这条消息的完整context来计算输入token
# 临时解决方案延迟计算让web_server负责在构建messages后计算输入token
# 这里只广播现有统计
print(f"[Debug] 用户消息添加完成广播现有token统计")
self.safe_broadcast_token_update()
except Exception as e:
print(f"[Debug] 处理用户消息token更新失败: {e}")
# 失败时仍然广播现有统计
self.safe_broadcast_token_update()
def add_tool_result(self, tool_call_id: str, function_name: str, result: str): def add_tool_result(self, tool_call_id: str, function_name: str, result: str):
"""添加工具调用结果(保留方法以兼容)""" """添加工具调用结果(保留方法以兼容)"""

View File

@ -16,7 +16,6 @@ except ImportError:
if str(project_root) not in sys.path: if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root))
from config import DATA_DIR from config import DATA_DIR
import tiktoken
@dataclass @dataclass
class ConversationMetadata: class ConversationMetadata:
@ -44,13 +43,6 @@ class ConversationManager:
self._ensure_directories() self._ensure_directories()
self._load_index() self._load_index()
# 初始化tiktoken编码器
try:
self.encoding = tiktoken.get_encoding("cl100k_base")
except Exception as e:
print(f"⚠️ tiktoken初始化失败: {e}")
self.encoding = None
def _ensure_directories(self): def _ensure_directories(self):
"""确保必要的目录存在""" """确保必要的目录存在"""
self.base_dir.mkdir(parents=True, exist_ok=True) self.base_dir.mkdir(parents=True, exist_ok=True)
@ -219,10 +211,13 @@ class ConversationManager:
def _initialize_token_statistics(self) -> Dict: def _initialize_token_statistics(self) -> Dict:
"""初始化Token统计结构""" """初始化Token统计结构"""
now = datetime.now().isoformat()
return { return {
"total_input_tokens": 0, "total_input_tokens": 0,
"total_output_tokens": 0, "total_output_tokens": 0,
"updated_at": datetime.now().isoformat() "total_tokens": 0,
"current_context_tokens": 0,
"updated_at": now
} }
def _validate_token_statistics(self, data: Dict) -> Dict: def _validate_token_statistics(self, data: Dict) -> Dict:
@ -230,21 +225,20 @@ class ConversationManager:
token_stats = data.get("token_statistics", {}) token_stats = data.get("token_statistics", {})
# 确保必要字段存在 # 确保必要字段存在
if "total_input_tokens" not in token_stats: defaults = self._initialize_token_statistics()
token_stats["total_input_tokens"] = 0 for key, default_value in defaults.items():
if "total_output_tokens" not in token_stats: if key not in token_stats:
token_stats["total_output_tokens"] = 0 token_stats[key] = default_value
if "updated_at" not in token_stats:
token_stats["updated_at"] = datetime.now().isoformat()
# 确保数值类型正确 # 确保数值类型正确
try: try:
token_stats["total_input_tokens"] = int(token_stats["total_input_tokens"]) token_stats["total_input_tokens"] = int(token_stats.get("total_input_tokens", 0))
token_stats["total_output_tokens"] = int(token_stats["total_output_tokens"]) token_stats["total_output_tokens"] = int(token_stats.get("total_output_tokens", 0))
token_stats["total_tokens"] = int(token_stats.get("total_tokens", 0))
token_stats["current_context_tokens"] = int(token_stats.get("current_context_tokens", 0))
except (ValueError, TypeError): except (ValueError, TypeError):
print("⚠️ Token统计数据损坏重置为0") print("⚠️ Token统计数据损坏重置为0")
token_stats["total_input_tokens"] = 0 token_stats = defaults
token_stats["total_output_tokens"] = 0
data["token_statistics"] = token_stats data["token_statistics"] = token_stats
return data return data
@ -466,7 +460,13 @@ class ConversationManager:
print(f"⌘ 加载对话失败 {conversation_id}: {e}") print(f"⌘ 加载对话失败 {conversation_id}: {e}")
return None return None
def update_token_statistics(self, conversation_id: str, input_tokens: int, output_tokens: int) -> bool: def update_token_statistics(
self,
conversation_id: str,
input_tokens: int,
output_tokens: int,
total_tokens: int
) -> bool:
""" """
更新对话的Token统计 更新对话的Token统计
@ -474,6 +474,7 @@ class ConversationManager:
conversation_id: 对话ID conversation_id: 对话ID
input_tokens: 输入Token数量 input_tokens: 输入Token数量
output_tokens: 输出Token数量 output_tokens: 输出Token数量
total_tokens: 本次请求的总Token数量prompt+completion
Returns: Returns:
bool: 更新是否成功 bool: 更新是否成功
@ -492,6 +493,8 @@ class ConversationManager:
token_stats = conversation_data["token_statistics"] token_stats = conversation_data["token_statistics"]
token_stats["total_input_tokens"] = token_stats.get("total_input_tokens", 0) + input_tokens token_stats["total_input_tokens"] = token_stats.get("total_input_tokens", 0) + input_tokens
token_stats["total_output_tokens"] = token_stats.get("total_output_tokens", 0) + output_tokens token_stats["total_output_tokens"] = token_stats.get("total_output_tokens", 0) + output_tokens
token_stats["total_tokens"] = token_stats.get("total_tokens", 0) + total_tokens
token_stats["current_context_tokens"] = total_tokens
token_stats["updated_at"] = datetime.now().isoformat() token_stats["updated_at"] = datetime.now().isoformat()
# 保存更新 # 保存更新
@ -520,13 +523,14 @@ class ConversationManager:
if not conversation_data: if not conversation_data:
return None return None
token_stats = conversation_data.get("token_statistics", {}) validated = self._validate_token_statistics(conversation_data)
token_stats = validated.get("token_statistics", {})
# 确保基本字段存在
result = { result = {
"total_input_tokens": token_stats.get("total_input_tokens", 0), "total_input_tokens": token_stats.get("total_input_tokens", 0),
"total_output_tokens": token_stats.get("total_output_tokens", 0), "total_output_tokens": token_stats.get("total_output_tokens", 0),
"total_tokens": token_stats.get("total_input_tokens", 0) + token_stats.get("total_output_tokens", 0), "total_tokens": token_stats.get("total_tokens", 0),
"current_context_tokens": token_stats.get("current_context_tokens", 0),
"updated_at": token_stats.get("updated_at"), "updated_at": token_stats.get("updated_at"),
"conversation_id": conversation_id "conversation_id": conversation_id
} }
@ -536,6 +540,13 @@ class ConversationManager:
print(f"⌘ 获取Token统计失败 {conversation_id}: {e}") print(f"⌘ 获取Token统计失败 {conversation_id}: {e}")
return None return None
def get_current_context_tokens(self, conversation_id: str) -> int:
    """
    Return the `current_context_tokens` value recorded for a conversation.

    Args:
        conversation_id: ID of the conversation whose statistics are read
            through `get_token_statistics`.

    Returns:
        int: The stored context-token count, or 0 when the statistics cannot
        be loaded or the stored value is missing, None, or not numeric.
    """
    stats = self.get_token_statistics(conversation_id)
    if not stats:
        return 0
    # Coerce defensively so callers always receive an int, as annotated.
    try:
        return int(stats.get("current_context_tokens") or 0)
    except (TypeError, ValueError):
        return 0
def get_conversation_list(self, limit: int = 50, offset: int = 0) -> Dict: def get_conversation_list(self, limit: int = 50, offset: int = 0) -> Dict:
""" """
获取对话列表 获取对话列表
@ -807,48 +818,3 @@ class ConversationManager:
"""设置当前对话ID""" """设置当前对话ID"""
self.current_conversation_id = conversation_id self.current_conversation_id = conversation_id
def calculate_conversation_tokens(self, conversation_id: str, context_manager=None, focused_files=None, terminal_content="") -> dict:
"""计算对话的真实API token消耗"""
try:
if not context_manager:
return {"total_tokens": 0}
conversation_data = self.load_conversation(conversation_id)
if not conversation_data:
return {"total_tokens": 0}
# 使用宿主终端的构建流程以贴合真实API请求
if getattr(context_manager, "main_terminal", None):
main_terminal = context_manager.main_terminal
context = main_terminal.build_context()
messages = main_terminal.build_messages(context, "")
tools = main_terminal.define_tools()
else:
context = context_manager.build_main_context(memory_content="")
messages = context_manager.build_messages(context, "")
tools = self._get_tools_definition(context_manager) or []
total_tokens = context_manager.calculate_input_tokens(messages, tools)
return {"total_tokens": total_tokens}
except Exception as e:
print(f"计算token失败: {e}")
return {"total_tokens": 0}
def _get_tools_definition(self, context_manager):
"""获取工具定义"""
try:
# 需要找到工具定义的来源,通常在 main_terminal 中
# 你需要找到 main_terminal 的引用或者 define_tools 方法
# 方法1: 如果 context_manager 有 main_terminal 引用
if hasattr(context_manager, 'main_terminal') and context_manager.main_terminal:
return context_manager.main_terminal.define_tools()
# 方法2: 如果有其他方式获取工具定义
# 你需要去找一下在哪里调用了 calculate_input_tokens看看 tools 参数是怎么传的
return []
except Exception as e:
print(f"获取工具定义失败: {e}")
return []

View File

@ -358,6 +358,9 @@ def format_tool_result_notice(tool_name: str, tool_call_id: Optional[str], conte
# 创建调试日志文件 # 创建调试日志文件
DEBUG_LOG_FILE = Path(LOGS_DIR).expanduser().resolve() / "debug_stream.log" DEBUG_LOG_FILE = Path(LOGS_DIR).expanduser().resolve() / "debug_stream.log"
CHUNK_BACKEND_LOG_FILE = Path(LOGS_DIR).expanduser().resolve() / "chunk_backend.log"
CHUNK_FRONTEND_LOG_FILE = Path(LOGS_DIR).expanduser().resolve() / "chunk_frontend.log"
STREAMING_DEBUG_LOG_FILE = Path(LOGS_DIR).expanduser().resolve() / "streaming_debug.log"
UPLOAD_FOLDER_NAME = "user_upload" UPLOAD_FOLDER_NAME = "user_upload"
@ -586,14 +589,41 @@ def reset_system_state(terminal: Optional[WebTerminal]):
debug_log(f"错误详情: {traceback.format_exc()}") debug_log(f"错误详情: {traceback.format_exc()}")
def debug_log(message): def _write_log(file_path: Path, message: str) -> None:
"""写入调试日志""" file_path.parent.mkdir(parents=True, exist_ok=True)
DEBUG_LOG_FILE.parent.mkdir(parents=True, exist_ok=True) with file_path.open('a', encoding='utf-8') as f:
with DEBUG_LOG_FILE.open('a', encoding='utf-8') as f:
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
f.write(f"[{timestamp}] {message}\n") f.write(f"[{timestamp}] {message}\n")
def debug_log(message):
    """Write `message` to the debug log file (DEBUG_LOG_FILE) via `_write_log`."""
    _write_log(file_path=DEBUG_LOG_FILE, message=message)
def log_backend_chunk(conversation_id: str, iteration: int, chunk_index: int, elapsed: float, char_len: int, content_preview: str):
    """
    Record one backend text chunk in the backend chunk log.

    The entry carries the conversation ID ('unknown' when falsy), the
    iteration number, the chunk index, the elapsed time formatted to three
    decimals, the chunk length, and a preview of the content.
    """
    # Escape line feeds so the preview stays on a single log line.
    preview = "\\n".join(content_preview.split('\n'))
    entry = f"conv={conversation_id or 'unknown'} iter={iteration} chunk={chunk_index} elapsed={elapsed:.3f}s len={char_len} preview={preview}"
    _write_log(CHUNK_BACKEND_LOG_FILE, entry)
def log_frontend_chunk(conversation_id: str, chunk_index: int, elapsed: float, char_len: int, client_ts: float):
    """
    Record one client-reported text chunk in the frontend chunk log.

    The entry carries the conversation ID ('unknown' when falsy), the chunk
    index, the elapsed time formatted to three decimals, the chunk length,
    and the timestamp supplied by the client.
    """
    entry = f"conv={conversation_id or 'unknown'} chunk={chunk_index} elapsed={elapsed:.3f}s len={char_len} client_ts={client_ts}"
    _write_log(CHUNK_FRONTEND_LOG_FILE, entry)
def log_streaming_debug_entry(data: Dict[str, Any]):
    """
    Write a streaming debug record to the streaming debug log.

    The record is serialized as JSON (non-ASCII characters kept as-is); if
    serialization fails for any reason, its `str()` form is logged instead.
    """
    try:
        line = json.dumps(data, ensure_ascii=False)
    except Exception:
        # Best-effort logging: never let an unserializable value drop the entry.
        line = str(data)
    _write_log(STREAMING_DEBUG_LOG_FILE, line)
def get_thinking_state(terminal: WebTerminal) -> Dict[str, Any]: def get_thinking_state(terminal: WebTerminal) -> Dict[str, Any]:
"""获取(或初始化)思考调度状态。""" """获取(或初始化)思考调度状态。"""
state = getattr(terminal, "_thinking_state", None) state = getattr(terminal, "_thinking_state", None)
@ -1884,6 +1914,27 @@ def handle_message(data):
# 传递客户端ID # 传递客户端ID
socketio.start_background_task(process_message_task, terminal, message, send_to_client, client_sid) socketio.start_background_task(process_message_task, terminal, message, send_to_client, client_sid)
@socketio.on('client_chunk_log')
def handle_client_chunk_log(data):
    """
    Handle a chunk-timing report sent by the front end.

    Reads `conversation_id`, `index` (or `chunk_index`), `elapsed`, `length`
    (falling back to the length of `content`) and `ts` from the payload and
    forwards them to `log_frontend_chunk`.

    The payload comes from the client, so it is validated first: anything
    that is not a dict, or whose numeric fields cannot be converted, is
    dropped (with a debug-log note) instead of raising inside the handler.
    """
    if not isinstance(data, dict):
        return
    try:
        chunk_index = int(data.get('index') or data.get('chunk_index') or 0)
        elapsed = float(data.get('elapsed') or 0.0)
        length = int(data.get('length') or len(data.get('content') or ""))
        client_ts = float(data.get('ts') or 0.0)
    except (TypeError, ValueError) as exc:
        debug_log(f"忽略无效的client_chunk_log数据: {exc}")
        return
    log_frontend_chunk(data.get('conversation_id'), chunk_index, elapsed, length, client_ts)
@socketio.on('client_stream_debug_log')
def handle_client_stream_debug_log(data):
    """
    Handle a streaming debug record sent by the front end.

    Non-dict payloads are ignored. Otherwise a shallow copy of the payload is
    logged through `log_streaming_debug_entry`, with `server_ts` set to the
    current time when the client did not provide that key.
    """
    if isinstance(data, dict):
        entry = {**data}
        if 'server_ts' not in entry:
            entry['server_ts'] = time.time()
        log_streaming_debug_entry(entry)
# 在 web_server.py 中添加以下对话管理API接口 # 在 web_server.py 中添加以下对话管理API接口
# 添加在现有路由之后,@socketio 事件处理之前 # 添加在现有路由之后,@socketio 事件处理之前
@ -2426,6 +2477,7 @@ def detect_malformed_tool_call(text):
async def handle_task_with_sender(terminal: WebTerminal, message, sender, client_sid): async def handle_task_with_sender(terminal: WebTerminal, message, sender, client_sid):
"""处理任务并发送消息 - 集成token统计版本""" """处理任务并发送消息 - 集成token统计版本"""
web_terminal = terminal web_terminal = terminal
conversation_id = getattr(web_terminal.context_manager, "current_conversation_id", None)
# 如果是思考模式,重置状态 # 如果是思考模式,重置状态
if web_terminal.thinking_mode: if web_terminal.thinking_mode:
@ -3076,22 +3128,13 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
mark_force_thinking(web_terminal, reason="tool_limit") mark_force_thinking(web_terminal, reason="tool_limit")
break break
# === 修改每次API调用前都计算输入token ===
try:
input_tokens = web_terminal.context_manager.calculate_input_tokens(messages, tools)
debug_log(f"{iteration + 1}次API调用输入token: {input_tokens}")
# 更新输入token统计
web_terminal.context_manager.update_token_statistics(input_tokens, 0)
except Exception as e:
debug_log(f"输入token统计失败: {e}")
apply_thinking_schedule(web_terminal) apply_thinking_schedule(web_terminal)
full_response = "" full_response = ""
tool_calls = [] tool_calls = []
current_thinking = "" current_thinking = ""
detected_tools = {} detected_tools = {}
last_usage_payload = None
# 状态标志 # 状态标志
in_thinking = False in_thinking = False
@ -3099,65 +3142,9 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
thinking_ended = False thinking_ended = False
text_started = False text_started = False
text_has_content = False text_has_content = False
TEXT_BUFFER_MAX_CHARS = 1
TEXT_BUFFER_MAX_INTERVAL = 0.0
TEXT_BUFFER_FLUSH_CHARS = 1
text_chunk_buffer: deque[str] = deque()
text_chunk_buffer_size = 0
last_text_flush_time = time.time()
TEXT_BUFFER_CHAR_DELAY = 0.02
def _drain_text_buffer(force: bool = False) -> bool:
nonlocal text_chunk_buffer, text_chunk_buffer_size, last_text_flush_time
if not text_chunk_buffer:
return False
drain_all = force or TEXT_BUFFER_MAX_INTERVAL == 0.0
sent = False
while text_chunk_buffer:
now = time.time()
should_flush = (
force
or text_chunk_buffer_size >= TEXT_BUFFER_MAX_CHARS
or TEXT_BUFFER_MAX_INTERVAL == 0.0
or (TEXT_BUFFER_MAX_INTERVAL > 0 and (now - last_text_flush_time) >= TEXT_BUFFER_MAX_INTERVAL)
)
if not should_flush:
break
batch_size = text_chunk_buffer_size if drain_all else max(1, min(text_chunk_buffer_size, TEXT_BUFFER_FLUSH_CHARS or 1))
pieces: List[str] = []
remaining = batch_size
while text_chunk_buffer and remaining > 0:
chunk = text_chunk_buffer.popleft()
chunk_len = len(chunk)
if chunk_len <= remaining:
pieces.append(chunk)
remaining -= chunk_len
else:
pieces.append(chunk[:remaining])
text_chunk_buffer.appendleft(chunk[remaining:])
chunk_len = remaining
remaining = 0
text_chunk_buffer_size -= chunk_len
if not pieces:
break
sender('text_chunk', {'content': "".join(pieces)})
last_text_flush_time = now
sent = True
if not drain_all:
break
return sent
async def flush_text_buffer(force: bool = False):
sent = _drain_text_buffer(force)
if sent and not force and TEXT_BUFFER_CHAR_DELAY > 0:
await asyncio.sleep(TEXT_BUFFER_CHAR_DELAY)
text_streaming = False text_streaming = False
text_chunk_index = 0
last_text_chunk_time: Optional[float] = None
# 计数器 # 计数器
chunk_count = 0 chunk_count = 0
@ -3223,6 +3210,10 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
if finish_reason: if finish_reason:
last_finish_reason = finish_reason last_finish_reason = finish_reason
usage_info = choice.get("usage")
if usage_info:
last_usage_payload = usage_info
# 处理思考内容 # 处理思考内容
if "reasoning_content" in delta: if "reasoning_content" in delta:
reasoning_content = delta["reasoning_content"] reasoning_content = delta["reasoning_content"]
@ -3454,10 +3445,23 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
full_response += content full_response += content
accumulated_response += content accumulated_response += content
text_has_content = True text_has_content = True
for ch in content: emit_time = time.time()
text_chunk_buffer.append(ch) elapsed = 0.0 if last_text_chunk_time is None else emit_time - last_text_chunk_time
text_chunk_buffer_size += 1 last_text_chunk_time = emit_time
await flush_text_buffer() text_chunk_index += 1
log_backend_chunk(
conversation_id,
iteration + 1,
text_chunk_index,
elapsed,
len(content),
content[:32]
)
sender('text_chunk', {
'content': content,
'index': text_chunk_index,
'elapsed': elapsed
})
# 收集工具调用 - 实时发送准备状态 # 收集工具调用 - 实时发送准备状态
if "tool_calls" in delta: if "tool_calls" in delta:
@ -3508,24 +3512,20 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
return return
# === API响应完成后只计算输出token === # === API响应完成后只计算输出token ===
try: if last_usage_payload:
ai_output_content = full_response or append_result.get("assistant_content") or modify_result.get("assistant_content") or "" try:
if tool_calls: web_terminal.context_manager.apply_usage_statistics(last_usage_payload)
ai_output_content += json.dumps(tool_calls, ensure_ascii=False) debug_log(
f"Usage统计: prompt={last_usage_payload.get('prompt_tokens', 0)}, "
if ai_output_content.strip(): f"completion={last_usage_payload.get('completion_tokens', 0)}, "
output_tokens = web_terminal.context_manager.calculate_output_tokens(ai_output_content) f"total={last_usage_payload.get('total_tokens', 0)}"
debug_log(f"{iteration + 1}次API调用输出token: {output_tokens}") )
except Exception as e:
# 只更新输出token统计 debug_log(f"Usage统计更新失败: {e}")
web_terminal.context_manager.update_token_statistics(0, output_tokens) else:
else: debug_log("未获取到usage字段跳过token统计更新")
debug_log("没有AI输出内容跳过输出token统计")
except Exception as e:
debug_log(f"输出token统计失败: {e}")
# 流结束后的处理 # 流结束后的处理
await flush_text_buffer(force=True)
debug_log(f"\n流结束统计:") debug_log(f"\n流结束统计:")
debug_log(f" 总chunks: {chunk_count}") debug_log(f" 总chunks: {chunk_count}")
debug_log(f" 思考chunks: {reasoning_chunks}") debug_log(f" 思考chunks: {reasoning_chunks}")
@ -3548,7 +3548,6 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
# 确保text_end事件被发送 # 确保text_end事件被发送
if text_started and text_has_content and not append_result["handled"] and not modify_result["handled"]: if text_started and text_has_content and not append_result["handled"] and not modify_result["handled"]:
await flush_text_buffer(force=True)
debug_log(f"发送text_end事件完整内容长度: {len(full_response)}") debug_log(f"发送text_end事件完整内容长度: {len(full_response)}")
sender('text_end', {'full_content': full_response}) sender('text_end', {'full_content': full_response})
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
@ -4251,20 +4250,12 @@ def get_conversation_token_statistics(conversation_id, terminal: WebTerminal, wo
def get_conversation_tokens(conversation_id, terminal: WebTerminal, workspace: UserWorkspace, username: str): def get_conversation_tokens(conversation_id, terminal: WebTerminal, workspace: UserWorkspace, username: str):
"""获取对话的当前完整上下文token数包含所有动态内容""" """获取对话的当前完整上下文token数包含所有动态内容"""
try: try:
# 获取当前聚焦文件状态 current_tokens = terminal.context_manager.get_current_context_tokens(conversation_id)
focused_files = terminal.get_focused_files_info()
# 计算完整token
tokens = terminal.context_manager.conversation_manager.calculate_conversation_tokens(
conversation_id=conversation_id,
context_manager=terminal.context_manager,
focused_files=focused_files,
terminal_content=""
)
return jsonify({ return jsonify({
"success": True, "success": True,
"data": tokens "data": {
"total_tokens": current_tokens
}
}) })
except Exception as e: except Exception as e:
return jsonify({ return jsonify({