336 lines
13 KiB
Python
336 lines
13 KiB
Python
from __future__ import annotations
|
||
import asyncio, time, json, re
|
||
from typing import Dict, Any
|
||
from flask import request
|
||
from flask_socketio import emit, join_room, leave_room, disconnect
|
||
|
||
from .extensions import socketio
|
||
from .auth_helpers import get_current_username, resolve_admin_policy
|
||
from .context import (
|
||
get_terminal_for_sid,
|
||
ensure_conversation_loaded,
|
||
reset_system_state,
|
||
get_user_resources,
|
||
)
|
||
from .utils_common import debug_log, log_frontend_chunk, log_streaming_debug_entry
|
||
from .state import connection_users, stop_flags, terminal_rooms, pending_socket_tokens, user_manager, get_stop_flag, set_stop_flag, clear_stop_flag
|
||
from .usage import record_user_activity
|
||
from .chat_flow import start_chat_task
|
||
from .security import consume_socket_token, prune_socket_tokens
|
||
from config import OUTPUT_FORMATS, AGENT_VERSION
|
||
|
||
@socketio.on('connect')
def handle_connect(auth):
    """Authenticate a new Socket.IO connection and push the initial state.

    The connection is accepted only when ``get_current_username`` yields a
    user *and* ``consume_socket_token`` accepts the ``socket_token`` from the
    ``auth`` payload. Otherwise an ``error`` event is sent and the socket is
    disconnected.

    On success:

    * the sid is registered in ``connection_users``;
    * any stop flag stored under this sid is dropped;
    * a task entry stored under the username (if any) is re-bound to this sid;
    * the sid joins ``user_{username}`` and ``user_{username}_terminal``;
    * when the user has a terminal, ``system_ready``, ``terminal_list_update``
      and one ``terminal_started`` per session are sent to this sid only.

    Args:
        auth: Socket.IO auth payload; only a dict carrying ``socket_token``
            is used.
    """
    sid = request.sid
    print(f"[WebSocket] 客户端连接: {sid}")

    username = get_current_username()
    token_value = auth.get('socket_token') if isinstance(auth, dict) else None
    # Short-circuits exactly like the original: the token is only consumed
    # when a username is present.
    authorised = bool(username) and consume_socket_token(token_value, username)
    if not authorised:
        emit('error', {'message': '未登录或连接凭证无效'})
        disconnect()
        return

    emit('connected', {'status': 'Connected to server'})
    connection_users[sid] = username

    # Drop a stale sid-level stop flag, then attach any task still registered
    # under the username to this new sid so it can be stopped from here.
    stop_flags.pop(sid, None)
    carried_over = get_stop_flag(None, username)
    if carried_over:
        set_stop_flag(sid, username, carried_over)

    user_room = f"user_{username}"
    user_terminal_room = f"user_{username}_terminal"
    for room in (user_room, user_terminal_room):
        join_room(room)
    terminal_rooms.setdefault(sid, set()).update({user_room, user_terminal_room})

    terminal, workspace = get_user_resources(username)
    if not terminal:
        return

    reset_system_state(terminal)
    emit('system_ready', {
        'project_path': str(workspace.project_path),
        'thinking_mode': bool(getattr(terminal, "thinking_mode", False)),
        'version': AGENT_VERSION
    }, room=sid)

    manager = terminal.terminal_manager
    if not manager:
        return

    emit('terminal_list_update', {
        'terminals': manager.get_terminal_list(),
        'active': manager.active_terminal
    }, room=sid)

    # Replay per-session start events only when some terminal is active.
    if not manager.active_terminal:
        return
    for session_name, session in manager.terminals.items():
        emit('terminal_started', {
            'session': session_name,
            'working_dir': str(session.working_dir),
            'shell': session.shell_command,
            'time': session.start_time.isoformat() if session.start_time else None
        }, room=sid)
|
||
|
||
@socketio.on('disconnect')
def handle_disconnect():
    """Clean up bookkeeping after a client disconnects.

    The sid is removed from ``connection_users``. The user may have no other
    open connection, and the task entry for this sid/user may be a dict. In
    that case the entry is marked ``stop``, its task is cancelled if still
    pending, and the attached terminal's system state is reset.

    Regardless of that, the stop flag for this sid is cleared (called with
    ``username=None``). The sid then leaves every room recorded for it in
    ``terminal_rooms``, plus the two per-user rooms.
    """
    sid = request.sid
    print(f"[WebSocket] 客户端断开: {sid}")
    username = connection_users.pop(sid, None)

    # Another live connection for the same user keeps the task running.
    still_connected = bool(username) and any(
        owner == username for owner in connection_users.values()
    )

    entry = get_stop_flag(sid, username)
    if isinstance(entry, dict) and not still_connected:
        entry['stop'] = True
        running = entry.get('task')
        if running and not running.done():
            debug_log(f"disconnect: cancel task for {sid}")
            running.cancel()
        attached_terminal = entry.get('terminal')
        if attached_terminal:
            reset_system_state(attached_terminal)

    clear_stop_flag(sid, None)

    # Leave every tracked room first, then forget the tracking entry.
    for room in list(terminal_rooms.get(sid, [])):
        leave_room(room)
    terminal_rooms.pop(sid, None)

    if username:
        for suffix in ("", "_terminal"):
            leave_room(f"user_{username}{suffix}")
|
||
|
||
@socketio.on('stop_task')
def handle_stop_task():
    """Handle a client's request to stop its current task.

    Looks up the task entry for this sid/user, creating an empty one when
    none is stored. The entry is marked ``stop``; its task is cancelled if
    still pending; the attached terminal (if any) has its system state reset.
    The entry is then written back with ``set_stop_flag``, and the request is
    acknowledged with ``stop_requested``.
    """
    sid = request.sid
    print(f"[停止] 收到停止请求: {sid}")

    owner = connection_users.get(sid)
    entry = get_stop_flag(sid, owner)
    if not isinstance(entry, dict):
        entry = {'stop': False, 'task': None, 'terminal': None}

    # Flag first, then try to cancel whatever is still running.
    entry['stop'] = True
    running = entry.get('task')
    if running and not running.done():
        debug_log(f"正在取消任务: {sid}")
        running.cancel()

    if entry.get('terminal'):
        reset_system_state(entry['terminal'])

    set_stop_flag(sid, owner, entry)

    emit('stop_requested', {
        'message': '停止请求已接收,正在取消任务...'
    })
|
||
|
||
@socketio.on('terminal_subscribe')
def handle_terminal_subscribe(data):
    """Subscribe this connection to real-time terminal events.

    Args:
        data: Client payload. A truthy ``all`` joins the user's aggregate
            room ``user_{username}_terminal``. Otherwise a ``session`` name
            joins ``user_{username}_terminal_{session}``. Non-dict payloads
            are ignored, since the payload is client-controlled.

    Emits ``error`` when the terminal system is not available for this sid,
    or when the admin policy sets ``ui_blocks.block_realtime_terminal``.

    On success it emits one of:

    * ``terminal_subscribed`` with the terminal list, for ``all``;
    * ``terminal_history`` for a single session, requesting 100 lines and
      only when the fetch reports ``success``.
    """
    if not isinstance(data, dict):
        return

    sid = request.sid
    session_name = data.get('session')
    subscribe_all = data.get('all', False)

    username, terminal, _ = get_terminal_for_sid(sid)
    if not username or not terminal or not terminal.terminal_manager:
        emit('error', {'message': 'Terminal system not initialized'})
        return
    policy = resolve_admin_policy(user_manager.get_user(username))
    if policy.get("ui_blocks", {}).get("block_realtime_terminal"):
        emit('error', {'message': '实时终端已被管理员禁用'})
        return

    # Rooms joined here are recorded so handle_disconnect can leave them.
    subscribed_rooms = terminal_rooms.setdefault(sid, set())

    if subscribe_all:
        # Subscribe to events from all of the user's terminals.
        room_name = f"user_{username}_terminal"
        join_room(room_name)
        subscribed_rooms.add(room_name)
        print(f"[Terminal] {sid} 订阅所有终端事件")

        # Send the current terminal state.
        emit('terminal_subscribed', {
            'type': 'all',
            'terminals': terminal.terminal_manager.get_terminal_list()
        })
    elif session_name:
        # Subscribe to a single terminal session.
        room_name = f'user_{username}_terminal_{session_name}'
        join_room(room_name)
        subscribed_rooms.add(room_name)
        print(f"[Terminal] {sid} 订阅终端: {session_name}")

        # Send that terminal's recent output.
        output_result = terminal.terminal_manager.get_terminal_output(session_name, 100)
        if output_result['success']:
            emit('terminal_history', {
                'session': session_name,
                'output': output_result['output']
            })
|
||
|
||
@socketio.on('terminal_unsubscribe')
def handle_terminal_unsubscribe(data):
    """Unsubscribe this connection from a single terminal session's events.

    Args:
        data: Client payload; reads ``session``. Payloads that are not a
            dict, or that carry no session name, are ignored.

    The room left is ``user_{username}_terminal_{session}`` when the sid has
    a known user, else ``terminal_{session}``. The room is also dropped from
    this sid's entry in ``terminal_rooms`` when one exists.
    """
    # Client-controlled payload: a non-dict would raise on .get().
    if not isinstance(data, dict):
        return
    session_name = data.get('session')
    if not session_name:
        return

    sid = request.sid
    username = connection_users.get(sid)
    room_name = f'user_{username}_terminal_{session_name}' if username else f'terminal_{session_name}'
    leave_room(room_name)

    tracked = terminal_rooms.get(sid)
    if tracked is not None:
        tracked.discard(room_name)
    print(f"[Terminal] {sid} 取消订阅终端: {session_name}")
|
||
|
||
@socketio.on('get_terminal_output')
def handle_get_terminal_output(data):
    """Send a terminal session's output history to this client.

    Args:
        data: Client payload; reads ``session`` and ``lines`` (default 50).
            Non-dict payloads are ignored.

    Emits ``error`` when the terminal system is unavailable for this sid, or
    when the admin policy blocks the real-time terminal. Otherwise it emits
    ``terminal_output_history`` on success, or ``error`` with the manager's
    error message on failure.
    """
    # Client-controlled payload: a non-dict would raise on .get().
    if not isinstance(data, dict):
        return
    session_name = data.get('session')
    # NOTE(review): `lines` comes from the client unvalidated; it is passed
    # straight to get_terminal_output, which is assumed to handle its type
    # and range. Confirm before relying on it.
    lines = data.get('lines', 50)

    username, terminal, _ = get_terminal_for_sid(request.sid)
    # Require a user as well, consistent with handle_terminal_subscribe; the
    # policy lookup below needs one.
    if not username or not terminal or not terminal.terminal_manager:
        emit('error', {'message': 'Terminal system not initialized'})
        return
    policy = resolve_admin_policy(user_manager.get_user(username))
    if policy.get("ui_blocks", {}).get("block_realtime_terminal"):
        emit('error', {'message': '实时终端已被管理员禁用'})
        return

    result = terminal.terminal_manager.get_terminal_output(session_name, lines)

    if result['success']:
        emit('terminal_output_history', {
            'session': session_name,
            'output': result['output'],
            'is_interactive': result.get('is_interactive', False),
            'last_command': result.get('last_command', '')
        })
    else:
        # Don't crash with KeyError if a failure result lacks 'error'.
        emit('error', {'message': result.get('error', '获取终端输出失败')})
|
||
|
||
@socketio.on('send_message')
def handle_message(data):
    """Handle a chat message from the client and start the model task.

    Steps:

    1. Resolve the user, terminal and workspace bound to this sid.
    2. Validate the payload:
       * text and/or images/videos are required;
       * images need model ``qwen3-vl-plus`` or ``kimi-k2.5``;
       * videos need ``kimi-k2.5``;
       * images and videos cannot be sent together.
    3. Load (or create) the requested conversation and send
       ``conversation_resolved`` to this sid. When a new conversation was
       created, also send ``conversation_list_update`` to the user room and
       ``conversation_changed`` to this sid.
    4. Call ``start_chat_task`` with a sender that emits to this sid and
       refreshes the user's activity at most once per second for model
       events.

    Args:
        data: Client payload. Reads ``message``, ``images``, ``videos`` and
            ``conversation_id``. A non-dict payload is treated as empty, and
            a non-string ``message`` as no text.
    """
    username, terminal, workspace = get_terminal_for_sid(request.sid)
    if not terminal:
        emit('error', {'message': 'System not initialized'})
        return

    # The payload is client-controlled: don't let a malformed one raise
    # AttributeError inside the handler.
    if not isinstance(data, dict):
        data = {}
    raw_message = data.get('message')
    message = raw_message.strip() if isinstance(raw_message, str) else ''
    images = data.get('images') or []
    videos = data.get('videos') or []
    if not message and not images and not videos:
        emit('error', {'message': '消息不能为空'})
        return

    model_key = getattr(terminal, "model_key", None)
    if images and model_key not in {"qwen3-vl-plus", "kimi-k2.5"}:
        emit('error', {'message': '当前模型不支持图片,请切换到 Qwen-VL 或 Kimi-k2.5'})
        return
    if videos and model_key != "kimi-k2.5":
        emit('error', {'message': '当前模型不支持视频,请切换到 Kimi-k2.5'})
        return
    if images and videos:
        emit('error', {'message': '图片和视频请分开发送'})
        return

    print(f"[WebSocket] 收到消息: {message}")
    debug_log(f"\n{'='*80}\n新任务开始: {message}\n{'='*80}")
    record_user_activity(username)

    requested_conversation_id = data.get('conversation_id')
    try:
        conversation_id, created_new = ensure_conversation_loaded(terminal, requested_conversation_id)
    except RuntimeError as exc:
        emit('error', {'message': str(exc)})
        return
    try:
        conv_data = terminal.context_manager.conversation_manager.load_conversation(conversation_id) or {}
    except Exception as exc:
        # Only the title is read from it: fall back to the default title,
        # but leave a trace instead of swallowing the failure silently.
        debug_log(f"load_conversation failed for {conversation_id}: {exc}")
        conv_data = {}
    title = conv_data.get('title', '新对话')

    client_sid = request.sid

    socketio.emit('conversation_resolved', {
        'conversation_id': conversation_id,
        'title': title,
        'created': created_new
    }, room=client_sid)

    if created_new:
        socketio.emit('conversation_list_update', {
            'action': 'created',
            'conversation_id': conversation_id
        }, room=f"user_{username}")
        socketio.emit('conversation_changed', {
            'conversation_id': conversation_id,
            'title': title
        }, room=client_sid)

    def send_to_client(event_type, data):
        """Emit ``event_type`` with payload ``data`` to this client's sid only."""
        # `data` here is the event payload and shadows the request payload;
        # the parameter names are kept because this sender is handed to
        # start_chat_task as a callback.
        socketio.emit(event_type, data, room=client_sid)

    # Model activity events refresh the user's "online" heartbeat
    # (both replies and tool calls count as activity).
    activity_events = frozenset({
        'ai_message_start', 'thinking_start', 'thinking_chunk', 'thinking_end',
        'text_start', 'text_chunk', 'text_end',
        'tool_preparing', 'tool_start', 'update_action',
        'append_payload', 'modify_payload', 'system_message',
        'task_complete'
    })
    last_model_activity = 0.0

    def send_with_activity(event_type, data):
        """Forward an event to the client, refreshing activity on model events.

        This keeps a long reply from being judged offline. Activity is
        recorded at most once per second, as lightweight throttling during
        fast streaming.
        """
        nonlocal last_model_activity
        if event_type in activity_events:
            now = time.time()
            if now - last_model_activity >= 1.0:
                record_user_activity(username)
                last_model_activity = now
        send_to_client(event_type, data)

    # images/videos were already read and validated above; reuse them
    # rather than re-reading the payload.
    start_chat_task(terminal, message, images, send_with_activity, client_sid, workspace, username, videos)
|
||
|
||
|
||
@socketio.on('client_chunk_log')
def handle_client_chunk_log(data):
    """Record a frontend chunk-timing report (debug telemetry).

    Args:
        data: Client payload. Reads:
            * ``conversation_id``;
            * ``index`` (falling back to ``chunk_index``);
            * ``elapsed``;
            * ``length`` (falling back to the length of ``content``);
            * ``ts``.
            Missing numeric fields default to 0.

    This is best-effort logging. Payloads that are not a dict, or whose
    numeric fields cannot be converted, are dropped with a debug trace
    instead of raising inside the handler.
    """
    if not isinstance(data, dict):
        return
    try:
        chunk_index = int(data.get('index') or data.get('chunk_index') or 0)
        elapsed = float(data.get('elapsed') or 0.0)
        length = int(data.get('length') or len(data.get('content') or ""))
        client_ts = float(data.get('ts') or 0.0)
    except (TypeError, ValueError, OverflowError) as exc:
        debug_log(f"client_chunk_log: drop malformed payload ({exc})")
        return
    log_frontend_chunk(data.get('conversation_id'), chunk_index, elapsed, length, client_ts)
|
||
|
||
|
||
@socketio.on('client_stream_debug_log')
def handle_client_stream_debug_log(data):
    """Forward a frontend streaming-debug entry to the debug log.

    Only dict payloads are accepted. A shallow copy is logged, stamped with
    ``server_ts`` (the current ``time.time()``) unless the client already
    supplied one.
    """
    if isinstance(data, dict):
        record = {**data}
        if 'server_ts' not in record:
            record['server_ts'] = time.time()
        log_streaming_debug_entry(record)
|
||
|
||
# 历史说明:以下对话管理API接口原本添加在 web_server.py 中(位于现有路由之后、@socketio 事件处理之前)
|
||
|
||
# ==========================================
|
||
# 对话管理API接口
|
||
# ==========================================
|
||
|
||
|
||
# conversation routes moved to server/conversation.py
|