"""用户终端与工作区相关的共享辅助函数。""" from __future__ import annotations from functools import wraps from typing import Optional, Tuple, Dict, Any from flask import session, jsonify, has_request_context from core.web_terminal import WebTerminal from modules.gui_file_manager import GuiFileManager from modules.upload_security import UploadQuarantineManager, UploadSecurityError from modules.personalization_manager import load_personalization_config import json from pathlib import Path from modules.usage_tracker import UsageTracker from . import state from .utils_common import debug_log from .auth_helpers import get_current_username, get_current_user_record, get_current_user_role # will create helper module def make_terminal_callback(username: str): """生成面向指定用户的广播函数""" from .extensions import socketio def _callback(event_type, data): try: socketio.emit(event_type, data, room=f"user_{username}") except Exception as exc: debug_log(f"广播事件失败 ({username}): {event_type} - {exc}") return _callback def attach_user_broadcast(terminal: WebTerminal, username: str): """确保终端的广播函数指向当前用户的房间""" callback = make_terminal_callback(username) terminal.message_callback = callback if terminal.terminal_manager: terminal.terminal_manager.broadcast = callback def get_user_resources(username: Optional[str] = None) -> Tuple[Optional[WebTerminal], Optional['modules.user_manager.UserWorkspace']]: from modules.user_manager import UserWorkspace username = (username or get_current_username()) if not username: return None, None is_api_user = bool(session.get("is_api_user")) if has_request_context() else False # API 用户与网页用户使用不同的 manager if is_api_user: record = None workspace = state.api_user_manager.ensure_workspace(username) else: record = get_current_user_record() workspace = state.user_manager.ensure_user_workspace(username) container_handle = state.container_manager.ensure_container(username, str(workspace.project_path)) usage_tracker = None if is_api_user else get_or_create_usage_tracker(username, workspace) terminal = state.user_terminals.get(username) if not terminal: run_mode = session.get('run_mode') if has_request_context() else None thinking_mode_flag = session.get('thinking_mode') if has_request_context() else None if run_mode not in {"fast", "thinking", "deep"}: preferred_run_mode = None try: personal_config = load_personalization_config(workspace.data_dir) candidate_mode = (personal_config or {}).get('default_run_mode') if isinstance(candidate_mode, str) and candidate_mode.lower() in {"fast", "thinking", "deep"}: preferred_run_mode = candidate_mode.lower() except Exception as exc: debug_log(f"[UserInit] 加载个性化偏好失败: {exc}") if preferred_run_mode: run_mode = preferred_run_mode thinking_mode_flag = preferred_run_mode != "fast" elif thinking_mode_flag: run_mode = "deep" else: run_mode = "fast" thinking_mode = run_mode != "fast" terminal = WebTerminal( project_path=str(workspace.project_path), thinking_mode=thinking_mode, run_mode=run_mode, message_callback=make_terminal_callback(username), data_dir=str(workspace.data_dir), container_session=container_handle, usage_tracker=usage_tracker ) if terminal.terminal_manager: terminal.terminal_manager.broadcast = terminal.message_callback state.user_terminals[username] = terminal terminal.username = username terminal.user_role = "api" if is_api_user else get_current_user_role(record) terminal.quota_update_callback = (lambda metric=None: emit_user_quota_update(username)) if not is_api_user else None if has_request_context(): session['run_mode'] = terminal.run_mode session['thinking_mode'] = terminal.thinking_mode else: terminal.update_container_session(container_handle) attach_user_broadcast(terminal, username) terminal.username = username terminal.user_role = "api" if is_api_user else get_current_user_role(record) terminal.quota_update_callback = (lambda metric=None: emit_user_quota_update(username)) if not is_api_user else None # 应用管理员策略 if not is_api_user: try: from core.tool_config import ToolCategory from modules import admin_policy_manager policy = admin_policy_manager.get_effective_policy( record.username if record else None, get_current_user_role(record), getattr(record, "invite_code", None), ) categories_map = { cid: ToolCategory( label=cat.get("label") or cid, tools=list(cat.get("tools") or []), default_enabled=bool(cat.get("default_enabled", True)), silent_when_disabled=bool(cat.get("silent_when_disabled", False)), ) for cid, cat in policy.get("categories", {}).items() } forced_states = policy.get("forced_category_states") or {} disabled_models = policy.get("disabled_models") or [] terminal.set_admin_policy(categories_map, forced_states, disabled_models) terminal.admin_policy_ui_blocks = policy.get("ui_blocks") or {} terminal.admin_policy_version = policy.get("updated_at") if terminal.model_key in disabled_models: for candidate in ["kimi", "deepseek", "qwen3-vl-plus", "qwen3-max"]: if candidate not in disabled_models: try: terminal.set_model(candidate) session["model_key"] = terminal.model_key break except Exception: continue except Exception as exc: debug_log(f"[admin_policy] 应用失败: {exc}") return terminal, workspace def get_or_create_usage_tracker(username: Optional[str], workspace: Optional['modules.user_manager.UserWorkspace'] = None) -> Optional[UsageTracker]: if not username: return None tracker = state.usage_trackers.get(username) if tracker: return tracker from modules.user_manager import UserWorkspace if workspace is None: workspace = state.user_manager.ensure_user_workspace(username) record = state.user_manager.get_user(username) role = getattr(record, "role", "user") if record else "user" tracker = UsageTracker(str(workspace.data_dir), role=role or "user") state.usage_trackers[username] = tracker return tracker def emit_user_quota_update(username: Optional[str]): from .extensions import socketio if not username: return tracker = get_or_create_usage_tracker(username) if not tracker: return try: snapshot = tracker.get_quota_snapshot() socketio.emit('quota_update', {'quotas': snapshot}, room=f"user_{username}") except Exception: pass def apply_conversation_overrides(terminal: WebTerminal, workspace, conversation_id: Optional[str]): """根据对话元数据应用自定义 prompt / personalization(仅 API 用途)。""" if not conversation_id: return conv_path = Path(workspace.data_dir) / "conversations" / f"{conversation_id}.json" if not conv_path.exists(): return try: data = json.loads(conv_path.read_text(encoding="utf-8")) meta = data.get("metadata") or {} prompt_name = meta.get("custom_prompt_name") personalization_name = meta.get("personalization_name") # prompt override if prompt_name: prompt_path = Path(workspace.data_dir) / "prompts" / f"{prompt_name}.txt" if prompt_path.exists(): terminal.context_manager.custom_system_prompt = prompt_path.read_text(encoding="utf-8") else: terminal.context_manager.custom_system_prompt = None else: terminal.context_manager.custom_system_prompt = None # personalization override if personalization_name: pers_path = Path(workspace.data_dir) / "personalization" / f"{personalization_name}.json" if pers_path.exists(): try: terminal.context_manager.custom_personalization_config = json.loads(pers_path.read_text(encoding="utf-8")) except Exception: terminal.context_manager.custom_personalization_config = None else: terminal.context_manager.custom_personalization_config = None else: terminal.context_manager.custom_personalization_config = None except Exception as exc: debug_log(f"[apply_overrides] 读取对话元数据失败: {exc}") def with_terminal(func): """注入用户专属终端和工作区""" @wraps(func) def wrapper(*args, **kwargs): username = get_current_username() try: terminal, workspace = get_user_resources(username) except RuntimeError as exc: return jsonify({"error": str(exc), "code": "resource_busy"}), 503 if not terminal or not workspace: return jsonify({"error": "System not initialized"}), 503 kwargs.update({ 'terminal': terminal, 'workspace': workspace, 'username': username }) return func(*args, **kwargs) return wrapper def get_terminal_for_sid(sid: str): username = state.connection_users.get(sid) if not username: return None, None, None try: terminal, workspace = get_user_resources(username) except RuntimeError: return username, None, None return username, terminal, workspace def get_gui_manager(workspace): return GuiFileManager(str(workspace.project_path)) def get_upload_guard(workspace): return UploadQuarantineManager(workspace) def build_upload_error_response(exc: UploadSecurityError): status = 400 if exc.code in {"scanner_missing", "scanner_unavailable"}: status = 500 return jsonify({ "success": False, "error": str(exc), "code": exc.code, }), status def ensure_conversation_loaded(terminal: WebTerminal, conversation_id: Optional[str]): created_new = False if not conversation_id: result = terminal.create_new_conversation() if not result.get("success"): raise RuntimeError(result.get("message", "创建对话失败")) conversation_id = result["conversation_id"] if has_request_context(): session['run_mode'] = terminal.run_mode session['thinking_mode'] = terminal.thinking_mode created_new = True else: conversation_id = conversation_id if conversation_id.startswith('conv_') else f"conv_{conversation_id}" current_id = terminal.context_manager.current_conversation_id if current_id != conversation_id: load_result = terminal.load_conversation(conversation_id) if not load_result.get("success"): raise RuntimeError(load_result.get("message", "对话加载失败")) try: conv_data = terminal.context_manager.conversation_manager.load_conversation(conversation_id) or {} meta = conv_data.get("metadata", {}) or {} run_mode_meta = meta.get("run_mode") if run_mode_meta: terminal.set_run_mode(run_mode_meta) elif meta.get("thinking_mode"): terminal.set_run_mode("thinking") else: terminal.set_run_mode("fast") if terminal.thinking_mode: terminal.api_client.start_new_task(force_deep=terminal.deep_thinking_mode) else: terminal.api_client.start_new_task() if has_request_context(): session['run_mode'] = terminal.run_mode session['thinking_mode'] = terminal.thinking_mode except Exception: pass # 应用对话级自定义 prompt / personalization(仅 API) try: apply_conversation_overrides(terminal, workspace, conversation_id) except Exception as exc: debug_log(f"[apply_overrides] 失败: {exc}") return conversation_id, created_new def reset_system_state(terminal: Optional[WebTerminal]): """完整重置系统状态""" if not terminal: return try: if hasattr(terminal, 'api_client') and terminal.api_client: debug_log("重置API客户端状态") terminal.api_client.start_new_task(force_deep=getattr(terminal, "deep_thinking_mode", False)) if hasattr(terminal, 'current_session_id'): terminal.current_session_id += 1 debug_log(f"重置会话ID为: {terminal.current_session_id}") web_attrs = ['streamingMessage', 'currentMessageIndex', 'preparingTools', 'activeTools'] for attr in web_attrs: if hasattr(terminal, attr): if attr in ['streamingMessage']: setattr(terminal, attr, False) elif attr in ['currentMessageIndex']: setattr(terminal, attr, -1) elif attr in ['preparingTools', 'activeTools'] and hasattr(getattr(terminal, attr), 'clear'): getattr(terminal, attr).clear() debug_log("系统状态重置完成") except Exception as e: debug_log(f"状态重置过程中出现错误: {e}") import traceback debug_log(f"错误详情: {traceback.format_exc()}") __all__ = [ "get_user_resources", "with_terminal", "get_terminal_for_sid", "get_gui_manager", "get_upload_guard", "build_upload_error_response", "ensure_conversation_loaded", "reset_system_state", "get_or_create_usage_tracker", "emit_user_quota_update", "attach_user_broadcast", ]