341 lines
14 KiB
Python
341 lines
14 KiB
Python
"""用户终端与工作区相关的共享辅助函数。"""
|
||
from __future__ import annotations
|
||
from functools import wraps
|
||
from typing import Optional, Tuple, Dict, Any
|
||
from flask import session, jsonify, has_request_context
|
||
|
||
from core.web_terminal import WebTerminal
|
||
from modules.gui_file_manager import GuiFileManager
|
||
from modules.upload_security import UploadQuarantineManager, UploadSecurityError
|
||
from modules.personalization_manager import load_personalization_config
|
||
import json
|
||
from pathlib import Path
|
||
from modules.usage_tracker import UsageTracker
|
||
|
||
from . import state
|
||
from .utils_common import debug_log
|
||
from .auth_helpers import get_current_username, get_current_user_record, get_current_user_role # will create helper module
|
||
|
||
|
||
def make_terminal_callback(username: str):
    """Build a broadcast function bound to a single user's Socket.IO room.

    Args:
        username: User whose room (``user_<username>``) receives every event.

    Returns:
        A ``callback(event_type, data)`` callable. Emission errors are logged
        through ``debug_log`` and never propagate to the caller.
    """
    # Imported lazily inside the function (presumably to avoid an import cycle
    # with the extensions module).
    from .extensions import socketio

    target_room = f"user_{username}"

    def _emit_to_user(event_type, data):
        try:
            socketio.emit(event_type, data, room=target_room)
        except Exception as err:
            debug_log(f"广播事件失败 ({username}): {event_type} - {err}")

    return _emit_to_user
|
||
|
||
|
||
def attach_user_broadcast(terminal: WebTerminal, username: str):
    """Point the terminal's broadcast hooks at ``username``'s room.

    Replaces ``terminal.message_callback`` with a fresh per-user callback and,
    when the terminal has a ``terminal_manager``, sets that manager's
    ``broadcast`` attribute to the same callback.
    """
    room_callback = make_terminal_callback(username)
    terminal.message_callback = room_callback
    manager = terminal.terminal_manager
    if manager:
        manager.broadcast = room_callback
|
||
|
||
|
||
def _resolve_initial_run_mode(workspace) -> Tuple[str, bool]:
    """Choose ``(run_mode, thinking_mode)`` for a newly created terminal.

    Precedence:
      1. a valid ``run_mode`` already stored in the Flask session;
      2. the workspace personalization ``default_run_mode`` (case-insensitive);
      3. ``"deep"`` when the session's ``thinking_mode`` flag is truthy;
      4. ``"fast"``.
    ``thinking_mode`` is True for every mode except ``"fast"``.
    """
    valid_modes = {"fast", "thinking", "deep"}
    in_request = has_request_context()
    run_mode = session.get('run_mode') if in_request else None
    thinking_mode_flag = session.get('thinking_mode') if in_request else None

    if run_mode not in valid_modes:
        preferred_run_mode = None
        try:
            personal_config = load_personalization_config(workspace.data_dir)
            candidate_mode = (personal_config or {}).get('default_run_mode')
            if isinstance(candidate_mode, str) and candidate_mode.lower() in valid_modes:
                preferred_run_mode = candidate_mode.lower()
        except Exception as exc:
            debug_log(f"[UserInit] 加载个性化偏好失败: {exc}")

        if preferred_run_mode:
            run_mode = preferred_run_mode
        elif thinking_mode_flag:
            run_mode = "deep"
        else:
            run_mode = "fast"

    return run_mode, run_mode != "fast"


def _switch_to_allowed_model(terminal: WebTerminal, disabled_models) -> None:
    """Move ``terminal`` to the first fallback model that is allowed and loads."""
    for candidate in ("kimi", "deepseek", "qwen3-vl-plus", "qwen3-max"):
        if candidate in disabled_models:
            continue
        try:
            terminal.set_model(candidate)
        except Exception:
            continue
        # The session is only writable inside a request context. Previously an
        # unguarded write raised RuntimeError here, which was treated as a
        # set_model failure, so the loop kept switching and ended on the LAST
        # allowed candidate instead of the first.
        if has_request_context():
            session["model_key"] = terminal.model_key
        return
    debug_log(f"[admin_policy] 模型 {terminal.model_key} 已被禁用,且没有可用的备选模型")


def _apply_admin_policy(terminal: WebTerminal, record) -> None:
    """Apply the effective admin policy for ``record`` to ``terminal``.

    Installs tool categories, forced category states and the disabled-model
    list, stores the UI blocks and policy version on the terminal, and moves
    the terminal off a disabled model. Failures are logged and not raised.
    """
    try:
        from core.tool_config import ToolCategory
        from modules import admin_policy_manager

        policy = admin_policy_manager.get_effective_policy(
            record.username if record else None,
            get_current_user_role(record),
            getattr(record, "invite_code", None),
        )
        categories_map = {
            cid: ToolCategory(
                label=cat.get("label") or cid,
                tools=list(cat.get("tools") or []),
                default_enabled=bool(cat.get("default_enabled", True)),
                silent_when_disabled=bool(cat.get("silent_when_disabled", False)),
            )
            for cid, cat in policy.get("categories", {}).items()
        }
        forced_states = policy.get("forced_category_states") or {}
        disabled_models = policy.get("disabled_models") or []

        terminal.set_admin_policy(categories_map, forced_states, disabled_models)
        terminal.admin_policy_ui_blocks = policy.get("ui_blocks") or {}
        terminal.admin_policy_version = policy.get("updated_at")

        if terminal.model_key in disabled_models:
            _switch_to_allowed_model(terminal, disabled_models)
    except Exception as exc:
        debug_log(f"[admin_policy] 应用失败: {exc}")


def get_user_resources(username: Optional[str] = None) -> Tuple[Optional[WebTerminal], Optional['modules.user_manager.UserWorkspace']]:
    """Return the ``(terminal, workspace)`` pair for a user, creating them on demand.

    Args:
        username: Target user; defaults to ``get_current_username()``.

    Returns:
        ``(terminal, workspace)``, or ``(None, None)`` when no username is
        available.

    API users (``session["is_api_user"]`` inside a request) use
    ``state.api_user_manager``, get no usage tracker or quota callback, and
    skip admin policies. Terminals are cached in ``state.user_terminals``; a
    cached terminal is re-bound to the current container handle and broadcast
    room on every call.
    """
    username = username or get_current_username()
    if not username:
        return None, None

    is_api_user = bool(session.get("is_api_user")) if has_request_context() else False

    # API users and web users are backed by different managers.
    if is_api_user:
        record = None
        workspace = state.api_user_manager.ensure_workspace(username)
    else:
        # NOTE(review): this is the record of the current request's user, not
        # necessarily ``username`` when it is passed explicitly — confirm.
        record = get_current_user_record()
        workspace = state.user_manager.ensure_user_workspace(username)

    container_handle = state.container_manager.ensure_container(username, str(workspace.project_path))
    usage_tracker = None if is_api_user else get_or_create_usage_tracker(username, workspace)

    terminal = state.user_terminals.get(username)
    created = not terminal
    if created:
        run_mode, thinking_mode = _resolve_initial_run_mode(workspace)
        terminal = WebTerminal(
            project_path=str(workspace.project_path),
            thinking_mode=thinking_mode,
            run_mode=run_mode,
            message_callback=make_terminal_callback(username),
            data_dir=str(workspace.data_dir),
            container_session=container_handle,
            usage_tracker=usage_tracker,
        )
        if terminal.terminal_manager:
            terminal.terminal_manager.broadcast = terminal.message_callback
        state.user_terminals[username] = terminal
    else:
        terminal.update_container_session(container_handle)
        attach_user_broadcast(terminal, username)

    # Identity / quota wiring is refreshed for both new and cached terminals.
    terminal.username = username
    terminal.user_role = "api" if is_api_user else get_current_user_role(record)
    terminal.quota_update_callback = (
        None if is_api_user else (lambda metric=None: emit_user_quota_update(username))
    )

    # Persist the resolved mode so later requests see the same choice.
    if created and has_request_context():
        session['run_mode'] = terminal.run_mode
        session['thinking_mode'] = terminal.thinking_mode

    if not is_api_user:
        _apply_admin_policy(terminal, record)

    return terminal, workspace
|
||
|
||
|
||
def get_or_create_usage_tracker(username: Optional[str], workspace: Optional['modules.user_manager.UserWorkspace'] = None) -> Optional[UsageTracker]:
    """Return the cached UsageTracker for ``username``, creating it on first use.

    Args:
        username: User to track; a falsy value yields None.
        workspace: Workspace providing ``data_dir``; looked up through
            ``state.user_manager.ensure_user_workspace`` when omitted.

    Returns:
        The tracker stored in ``state.usage_trackers``, or None without a
        username. New trackers use the user's record ``role`` (default
        ``"user"``).
    """
    if not username:
        return None

    # Identity check: a cached tracker must be reused even if it were falsy.
    tracker = state.usage_trackers.get(username)
    if tracker is not None:
        return tracker

    if workspace is None:
        workspace = state.user_manager.ensure_user_workspace(username)
    record = state.user_manager.get_user(username)
    role = getattr(record, "role", "user") if record else "user"
    tracker = UsageTracker(str(workspace.data_dir), role=role or "user")
    state.usage_trackers[username] = tracker
    return tracker
|
||
|
||
|
||
def emit_user_quota_update(username: Optional[str]):
    """Push the user's current quota snapshot to their Socket.IO room.

    Emits ``quota_update`` with ``{"quotas": snapshot}`` to ``user_<username>``.
    Does nothing without a username or tracker. Failures are logged instead of
    being silently discarded, so the caller is never interrupted.
    """
    from .extensions import socketio

    if not username:
        return
    tracker = get_or_create_usage_tracker(username)
    if tracker is None:
        return
    try:
        snapshot = tracker.get_quota_snapshot()
        socketio.emit('quota_update', {'quotas': snapshot}, room=f"user_{username}")
    except Exception as exc:
        debug_log(f"[quota] 推送配额更新失败 ({username}): {exc}")
|
||
|
||
|
||
def _resolve_within(base_dir: Path, filename: str) -> Optional[Path]:
    """Join ``filename`` onto ``base_dir``; return None if the result escapes it.

    The check is lexical (``normpath``): ``..`` segments and absolute names are
    rejected, while symlinks inside ``base_dir`` keep working as before.
    """
    import os

    base = os.path.normpath(str(base_dir))
    candidate = os.path.normpath(os.path.join(base, filename))
    try:
        if os.path.commonpath([base, candidate]) != base:
            return None
    except ValueError:
        # Mixed absolute/relative paths, or different drives on Windows.
        return None
    return Path(candidate)


def apply_conversation_overrides(terminal: WebTerminal, workspace, conversation_id: Optional[str]):
    """Apply a conversation's custom prompt / personalization (intended for API use).

    Reads ``<data_dir>/conversations/<conversation_id>.json``. Its
    ``metadata.custom_prompt_name`` selects ``<data_dir>/prompts/<name>.txt``, and
    ``metadata.personalization_name`` selects
    ``<data_dir>/personalization/<name>.json``. Each override is set on
    ``terminal.context_manager``; it is set to None when it is unnamed, missing
    or unreadable. Does nothing when the id is falsy or the conversation file
    is absent. All names are confined to their directory, because the id comes
    from the request and the names from stored metadata. Other errors are
    logged, not raised.
    """
    if not conversation_id:
        return

    data_dir = Path(workspace.data_dir)
    conv_path = _resolve_within(data_dir / "conversations", f"{conversation_id}.json")
    if conv_path is None or not conv_path.exists():
        return

    try:
        data = json.loads(conv_path.read_text(encoding="utf-8"))
        meta = data.get("metadata") or {}
        context_manager = terminal.context_manager

        # prompt override
        prompt_text = None
        prompt_name = meta.get("custom_prompt_name")
        if prompt_name:
            prompt_path = _resolve_within(data_dir / "prompts", f"{prompt_name}.txt")
            if prompt_path is not None and prompt_path.exists():
                prompt_text = prompt_path.read_text(encoding="utf-8")
        context_manager.custom_system_prompt = prompt_text

        # personalization override
        personalization = None
        personalization_name = meta.get("personalization_name")
        if personalization_name:
            pers_path = _resolve_within(data_dir / "personalization", f"{personalization_name}.json")
            if pers_path is not None and pers_path.exists():
                try:
                    personalization = json.loads(pers_path.read_text(encoding="utf-8"))
                except (OSError, ValueError):
                    # Unreadable or malformed JSON: fall back to no personalization.
                    personalization = None
        context_manager.custom_personalization_config = personalization
    except Exception as exc:
        debug_log(f"[apply_overrides] 读取对话元数据失败: {exc}")
|
||
|
||
|
||
def with_terminal(func):
    """Decorator that injects the caller's terminal, workspace and username.

    The wrapped view receives ``terminal``, ``workspace`` and ``username`` as
    keyword arguments. It responds with HTTP 503 instead of calling the view
    when ``get_user_resources`` raises ``RuntimeError`` (reported as
    ``resource_busy``) or returns no terminal/workspace.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        current_user = get_current_username()
        try:
            resources = get_user_resources(current_user)
        except RuntimeError as busy_err:
            return jsonify({"error": str(busy_err), "code": "resource_busy"}), 503

        terminal, workspace = resources
        if terminal and workspace:
            kwargs['terminal'] = terminal
            kwargs['workspace'] = workspace
            kwargs['username'] = current_user
            return func(*args, **kwargs)
        return jsonify({"error": "System not initialized"}), 503

    return wrapper
|
||
|
||
|
||
def get_terminal_for_sid(sid: str):
    """Resolve a Socket.IO ``sid`` to ``(username, terminal, workspace)``.

    Returns ``(None, None, None)`` when the sid has no associated user, and
    ``(username, None, None)`` when ``get_user_resources`` raises
    ``RuntimeError``.
    """
    username = state.connection_users.get(sid)
    if username:
        try:
            terminal, workspace = get_user_resources(username)
        except RuntimeError:
            return username, None, None
        return username, terminal, workspace
    return None, None, None
|
||
|
||
|
||
def get_gui_manager(workspace):
    """Create a GuiFileManager rooted at the workspace's project directory."""
    project_root = str(workspace.project_path)
    return GuiFileManager(project_root)
|
||
|
||
|
||
def get_upload_guard(workspace):
    """Create an UploadQuarantineManager bound to ``workspace``."""
    guard = UploadQuarantineManager(workspace)
    return guard
|
||
|
||
|
||
def build_upload_error_response(exc: UploadSecurityError):
    """Convert an UploadSecurityError into a Flask ``(response, status)`` pair.

    Scanner problems (``scanner_missing`` / ``scanner_unavailable``) map to
    HTTP 500; every other error code maps to HTTP 400.
    """
    server_side_codes = {"scanner_missing", "scanner_unavailable"}
    status = 500 if exc.code in server_side_codes else 400
    payload = {
        "success": False,
        "error": str(exc),
        "code": exc.code,
    }
    return jsonify(payload), status
|
||
|
||
|
||
def _resolve_api_workspace(terminal: WebTerminal):
    """Best-effort workspace lookup for an API-user terminal.

    Returns None for non-API terminals (``user_role != "api"``) or when no
    ``username`` is attached; otherwise mirrors the API branch of
    ``get_user_resources``.
    """
    if getattr(terminal, "user_role", None) != "api":
        return None
    username = getattr(terminal, "username", None)
    if not username:
        return None
    return state.api_user_manager.ensure_workspace(username)


def _restore_conversation_run_mode(terminal: WebTerminal, conversation_id: str) -> None:
    """Re-apply the run mode stored in a conversation's metadata (best effort).

    Uses ``metadata.run_mode`` when present, else ``"thinking"`` if
    ``metadata.thinking_mode`` is truthy, else ``"fast"``. Then starts a fresh
    API-client task and mirrors the mode into the session when inside a request.
    """
    try:
        conv_data = terminal.context_manager.conversation_manager.load_conversation(conversation_id) or {}
        meta = conv_data.get("metadata", {}) or {}
        run_mode_meta = meta.get("run_mode")
        if run_mode_meta:
            terminal.set_run_mode(run_mode_meta)
        elif meta.get("thinking_mode"):
            terminal.set_run_mode("thinking")
        else:
            terminal.set_run_mode("fast")

        if terminal.thinking_mode:
            terminal.api_client.start_new_task(force_deep=terminal.deep_thinking_mode)
        else:
            terminal.api_client.start_new_task()

        if has_request_context():
            session['run_mode'] = terminal.run_mode
            session['thinking_mode'] = terminal.thinking_mode
    except Exception as exc:
        # Previously swallowed silently; keep best-effort but leave a trace.
        debug_log(f"[ensure_conversation] 恢复运行模式失败: {exc}")


def ensure_conversation_loaded(terminal: WebTerminal, conversation_id: Optional[str], workspace=None):
    """Make ``conversation_id`` the terminal's active conversation.

    Args:
        terminal: User terminal to operate on.
        conversation_id: Existing id (the ``conv_`` prefix is added when
            missing), or falsy to create a new conversation.
        workspace: Workspace whose ``data_dir`` holds conversation metadata for
            prompt/personalization overrides. When omitted, it is looked up only
            for API-user terminals (``terminal.user_role == "api"``).

    Returns:
        ``(conversation_id, created_new)``.

    Raises:
        RuntimeError: Creating or loading the conversation failed.
    """
    created_new = False
    if not conversation_id:
        result = terminal.create_new_conversation()
        if not result.get("success"):
            raise RuntimeError(result.get("message", "创建对话失败"))
        conversation_id = result["conversation_id"]
        if has_request_context():
            session['run_mode'] = terminal.run_mode
            session['thinking_mode'] = terminal.thinking_mode
        created_new = True
    else:
        if not conversation_id.startswith('conv_'):
            conversation_id = f"conv_{conversation_id}"
        if terminal.context_manager.current_conversation_id != conversation_id:
            load_result = terminal.load_conversation(conversation_id)
            if not load_result.get("success"):
                raise RuntimeError(load_result.get("message", "对话加载失败"))
            # NOTE(review): run mode is restored only when switching to a
            # different conversation — confirm this matches the intended flow.
            _restore_conversation_run_mode(terminal, conversation_id)

    # Conversation-level custom prompt / personalization (API use). This used to
    # reference an undefined ``workspace`` name; the NameError was swallowed
    # here, so overrides never ran.
    try:
        override_workspace = workspace if workspace is not None else _resolve_api_workspace(terminal)
        if override_workspace is not None:
            apply_conversation_overrides(terminal, override_workspace, conversation_id)
    except Exception as exc:
        debug_log(f"[apply_overrides] 失败: {exc}")

    return conversation_id, created_new
|
||
|
||
|
||
def reset_system_state(terminal: Optional[WebTerminal]):
    """Reset a terminal's transient per-run state; a no-op for a falsy terminal.

    Starts a fresh API-client task, increments ``current_session_id``, and
    restores UI-streaming attributes that exist on the terminal to their idle
    values. Errors are logged together with a traceback rather than raised.
    """
    if not terminal:
        return
    try:
        api_client = getattr(terminal, 'api_client', None)
        if api_client:
            debug_log("重置API客户端状态")
            api_client.start_new_task(force_deep=getattr(terminal, "deep_thinking_mode", False))

        if hasattr(terminal, 'current_session_id'):
            terminal.current_session_id += 1
            debug_log(f"重置会话ID为: {terminal.current_session_id}")

        # Scalar UI flags: reset to their idle values when present.
        for attr_name, idle_value in (('streamingMessage', False), ('currentMessageIndex', -1)):
            if hasattr(terminal, attr_name):
                setattr(terminal, attr_name, idle_value)

        # Collection-valued UI state: emptied in place when it supports clear().
        for attr_name in ('preparingTools', 'activeTools'):
            if hasattr(terminal, attr_name):
                container = getattr(terminal, attr_name)
                if hasattr(container, 'clear'):
                    container.clear()

        debug_log("系统状态重置完成")
    except Exception as e:
        debug_log(f"状态重置过程中出现错误: {e}")
        import traceback
        debug_log(f"错误详情: {traceback.format_exc()}")
|
||
|
||
|
||
# Public API of this module. ``apply_conversation_overrides`` and
# ``make_terminal_callback`` are public (non-underscore) helpers that were
# previously missing from this list.
__all__ = [
    "get_user_resources",
    "with_terminal",
    "get_terminal_for_sid",
    "get_gui_manager",
    "get_upload_guard",
    "build_upload_error_response",
    "ensure_conversation_loaded",
    "reset_system_state",
    "get_or_create_usage_tracker",
    "emit_user_quota_update",
    "attach_user_broadcast",
    "apply_conversation_overrides",
    "make_terminal_callback",
]
|