agent-Specialization/server/context.py

447 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""用户终端与工作区相关的共享辅助函数。"""
from __future__ import annotations
from functools import wraps
from typing import Optional, Tuple, Dict, Any
from flask import session, jsonify, has_request_context
from core.web_terminal import WebTerminal
from modules.gui_file_manager import GuiFileManager
from modules.upload_security import UploadQuarantineManager, UploadSecurityError
from modules.personalization_manager import load_personalization_config
import json
from pathlib import Path
from modules.usage_tracker import UsageTracker
from config import (
HOST_PROJECT_PATH,
DATA_DIR,
LOGS_DIR,
TERMINAL_SANDBOX_MODE,
UPLOAD_QUARANTINE_SUBDIR,
)
from . import state
from .utils_common import debug_log
from .auth_helpers import get_current_username, get_current_user_record, get_current_user_role # will create helper module
def make_terminal_callback(username: str):
    """Build a broadcast function bound to one user's room.

    The returned callable takes ``(event_type, data)`` and emits the event
    through ``socketio`` to the room ``user_<username>``. Emit failures are
    logged with ``debug_log`` and never propagate to the caller.
    """
    # Imported lazily inside the function (presumably to avoid an import
    # cycle with the extensions module -- confirm before hoisting).
    from .extensions import socketio

    target_room = f"user_{username}"

    def _callback(event_type, data):
        try:
            socketio.emit(event_type, data, room=target_room)
        except Exception as exc:
            debug_log(f"广播事件失败 ({username}): {event_type} - {exc}")

    return _callback
def attach_user_broadcast(terminal: WebTerminal, username: str):
    """Make sure the terminal broadcasts into the given user's room.

    A fresh callback is installed on the terminal itself and, when the
    terminal has a terminal manager, on that manager as well.
    """
    broadcast = make_terminal_callback(username)
    terminal.message_callback = broadcast
    manager = terminal.terminal_manager
    if manager:
        manager.broadcast = broadcast
def _make_terminal_key(username: str, workspace_id: Optional[str] = None) -> str:
return f"{username}::{workspace_id}" if workspace_id else username
def _apply_admin_policy(terminal: WebTerminal, record) -> None:
    """Apply the effective admin policy for ``record`` to ``terminal``.

    Installs tool categories, forced category states and disabled models on
    the terminal, and switches away from the current model when the policy
    disables it. Any failure is logged and swallowed so that a policy problem
    never prevents the terminal from being returned to the caller.
    """
    try:
        from core.tool_config import ToolCategory
        from modules import admin_policy_manager

        policy = admin_policy_manager.get_effective_policy(
            record.username if record else None,
            get_current_user_role(record),
            getattr(record, "invite_code", None),
        )
        categories_map = {
            cid: ToolCategory(
                label=cat.get("label") or cid,
                tools=list(cat.get("tools") or []),
                default_enabled=bool(cat.get("default_enabled", True)),
                silent_when_disabled=bool(cat.get("silent_when_disabled", False)),
            )
            for cid, cat in policy.get("categories", {}).items()
        }
        forced_states = policy.get("forced_category_states") or {}
        disabled_models = policy.get("disabled_models") or []
        terminal.set_admin_policy(categories_map, forced_states, disabled_models)
        terminal.admin_policy_ui_blocks = policy.get("ui_blocks") or {}
        terminal.admin_policy_version = policy.get("updated_at")

        if terminal.model_key in disabled_models:
            # Fall back to the first model in this preference order that the
            # policy still allows.
            for candidate in ["kimi-k2.5", "kimi", "deepseek", "qwen3-vl-plus", "qwen3-max"]:
                if candidate in disabled_models:
                    continue
                try:
                    terminal.set_model(candidate)
                except Exception:
                    continue
                # The session only exists inside a request. Writing to it
                # unguarded would raise after a successful switch and make
                # the loop move on to a less-preferred candidate.
                if has_request_context():
                    session["model_key"] = terminal.model_key
                break
    except Exception as exc:
        debug_log(f"[admin_policy] 应用失败: {exc}")


def get_user_resources(username: Optional[str] = None, workspace_id: Optional[str] = None) -> Tuple[Optional[WebTerminal], Optional['modules.user_manager.UserWorkspace']]:
    """Return the ``(terminal, workspace)`` pair for a user, creating it on first use.

    Terminals are cached in ``state.user_terminals``. On a cache hit the
    terminal's container session and broadcast callback are refreshed.

    Args:
        username: User to resolve; defaults to the current session user.
        workspace_id: Workspace to use. Required when the session is flagged
            as an API user; ignored for web users.

    Returns:
        ``(terminal, workspace)``, or ``(None, None)`` when no username can
        be determined.

    Raises:
        RuntimeError: An API-user call did not supply ``workspace_id``.
            Callers in this module also treat ``RuntimeError`` from the
            underlying managers as "resources unavailable".
    """
    from modules.user_manager import UserWorkspace

    username = (username or get_current_username())
    if not username:
        return None, None

    in_request = has_request_context()

    # Host mode without login: work directly in HOST_PROJECT_PATH instead of
    # creating a per-user /users/<user>/project tree.
    host_mode_session = bool(session.get("host_mode")) if in_request else False
    sandbox_is_host = (TERMINAL_SANDBOX_MODE or "host").lower() == "host"
    if host_mode_session and sandbox_is_host:
        project_path = Path(HOST_PROJECT_PATH).expanduser().resolve()
        project_path.mkdir(parents=True, exist_ok=True)
        data_dir = Path(DATA_DIR).expanduser().resolve()
        data_dir.mkdir(parents=True, exist_ok=True)
        logs_dir = Path(LOGS_DIR).expanduser().resolve()
        logs_dir.mkdir(parents=True, exist_ok=True)
        uploads_dir = project_path / "user_upload"
        uploads_dir.mkdir(parents=True, exist_ok=True)
        quarantine_root = Path(UPLOAD_QUARANTINE_SUBDIR).expanduser()
        if not quarantine_root.is_absolute():
            # A relative quarantine dir lives next to the project directory.
            quarantine_root = (project_path.parent / UPLOAD_QUARANTINE_SUBDIR).resolve()
        quarantine_root.mkdir(parents=True, exist_ok=True)
        workspace = UserWorkspace(
            username="host",
            root=project_path.parent,
            project_path=project_path,
            data_dir=data_dir,
            logs_dir=logs_dir,
            uploads_dir=uploads_dir,
            quarantine_dir=quarantine_root,
        )
        if not hasattr(workspace, "workspace_id"):
            workspace.workspace_id = "host"
        term_key = "host"
        container_handle = state.container_manager.ensure_container("host", str(project_path), container_key=term_key, preferred_mode="host")
        usage_tracker = None  # host mode is not subject to quotas
        terminal = state.user_terminals.get(term_key)
        if not terminal:
            run_mode = session.get('run_mode') if in_request else None
            thinking_mode_flag = session.get('thinking_mode') if in_request else None
            if run_mode not in {"fast", "thinking", "deep"}:
                run_mode = "fast"
                thinking_mode_flag = False
            thinking_mode = bool(thinking_mode_flag) if thinking_mode_flag is not None else (run_mode != "fast")
            terminal = WebTerminal(
                project_path=str(project_path),
                thinking_mode=thinking_mode,
                run_mode=run_mode,
                message_callback=make_terminal_callback("host"),
                data_dir=str(data_dir),
                container_session=container_handle,
                usage_tracker=usage_tracker
            )
            if terminal.terminal_manager:
                terminal.terminal_manager.broadcast = terminal.message_callback
            state.user_terminals[term_key] = terminal
            terminal.quota_update_callback = None
            if in_request:
                session['run_mode'] = terminal.run_mode
                session['thinking_mode'] = terminal.thinking_mode
        else:
            terminal.update_container_session(container_handle)
            attach_user_broadcast(terminal, "host")
        terminal.username = "host"
        terminal.user_role = "admin"
        if in_request:
            session['workspace_id'] = getattr(workspace, "workspace_id", None)
        return terminal, workspace

    is_api_user = bool(session.get("is_api_user")) if in_request else False
    # API users and web users are served by different managers.
    if is_api_user:
        record = None
        if workspace_id is None:
            raise RuntimeError("API 调用缺少 workspace_id")
        workspace = state.api_user_manager.ensure_workspace(username, workspace_id)
    else:
        record = get_current_user_record()
        workspace = state.user_manager.ensure_user_workspace(username)
        # Later logic expects a workspace_id attribute; add one when missing.
        if not hasattr(workspace, "workspace_id"):
            try:
                workspace.workspace_id = "default"
            except Exception:
                # Best effort: the workspace object may reject new attributes.
                pass

    term_key = _make_terminal_key(username, getattr(workspace, "workspace_id", None) if is_api_user else None)
    container_handle = state.container_manager.ensure_container(username, str(workspace.project_path), container_key=term_key, preferred_mode="docker")
    usage_tracker = None if is_api_user else get_or_create_usage_tracker(username, workspace)
    terminal = state.user_terminals.get(term_key)
    if not terminal:
        run_mode = session.get('run_mode') if in_request else None
        thinking_mode_flag = session.get('thinking_mode') if in_request else None
        if run_mode not in {"fast", "thinking", "deep"}:
            # No valid mode in the session: prefer the user's personalization
            # default, then the legacy thinking flag, then "fast".
            preferred_run_mode = None
            try:
                personal_config = load_personalization_config(workspace.data_dir)
                candidate_mode = (personal_config or {}).get('default_run_mode')
                if isinstance(candidate_mode, str) and candidate_mode.lower() in {"fast", "thinking", "deep"}:
                    preferred_run_mode = candidate_mode.lower()
            except Exception as exc:
                debug_log(f"[UserInit] 加载个性化偏好失败: {exc}")
            if preferred_run_mode:
                run_mode = preferred_run_mode
            elif thinking_mode_flag:
                run_mode = "deep"
            else:
                run_mode = "fast"
        thinking_mode = run_mode != "fast"
        terminal = WebTerminal(
            project_path=str(workspace.project_path),
            thinking_mode=thinking_mode,
            run_mode=run_mode,
            message_callback=make_terminal_callback(username),
            data_dir=str(workspace.data_dir),
            container_session=container_handle,
            usage_tracker=usage_tracker
        )
        if terminal.terminal_manager:
            terminal.terminal_manager.broadcast = terminal.message_callback
        state.user_terminals[term_key] = terminal
        if in_request:
            session['run_mode'] = terminal.run_mode
            session['thinking_mode'] = terminal.thinking_mode
    else:
        terminal.update_container_session(container_handle)
        attach_user_broadcast(terminal, username)

    # Identity and quota hooks are refreshed on every call, cached or not.
    terminal.username = username
    terminal.user_role = "api" if is_api_user else get_current_user_role(record)
    terminal.quota_update_callback = (lambda metric=None: emit_user_quota_update(username)) if not is_api_user else None
    if in_request and is_api_user:
        session['workspace_id'] = getattr(workspace, "workspace_id", None)

    # Admin policy applies to web users only.
    if not is_api_user:
        _apply_admin_policy(terminal, record)
    return terminal, workspace
def get_or_create_usage_tracker(username: Optional[str], workspace: Optional['modules.user_manager.UserWorkspace'] = None) -> Optional[UsageTracker]:
    """Return the cached ``UsageTracker`` for ``username``, creating it if needed.

    Returns ``None`` for a falsy username. When no tracker is cached, one is
    built from the workspace's data directory (the workspace is resolved via
    ``state.user_manager`` when not supplied) and the user's role, which
    defaults to ``"user"``. The new tracker is stored in
    ``state.usage_trackers``.
    """
    if not username:
        return None

    cached = state.usage_trackers.get(username)
    if cached:
        return cached

    from modules.user_manager import UserWorkspace  # noqa: F401  (kept from the original)

    if workspace is None:
        workspace = state.user_manager.ensure_user_workspace(username)

    user_record = state.user_manager.get_user(username)
    if user_record:
        role = getattr(user_record, "role", "user")
    else:
        role = "user"

    created = UsageTracker(str(workspace.data_dir), role=role or "user")
    state.usage_trackers[username] = created
    return created
def emit_user_quota_update(username: Optional[str]):
    """Push the user's current quota snapshot to their room.

    Does nothing for a falsy username or when no usage tracker is available.
    Snapshot or emit failures are logged rather than raised, so quota
    notifications stay best-effort for the caller.
    """
    if not username:
        return
    tracker = get_or_create_usage_tracker(username)
    if not tracker:
        return

    # Imported only once there is something to send.
    from .extensions import socketio

    try:
        snapshot = tracker.get_quota_snapshot()
        socketio.emit('quota_update', {'quotas': snapshot}, room=f"user_{username}")
    except Exception as exc:
        # Previously swallowed silently; keep it non-fatal but visible.
        debug_log(f"推送配额更新失败 ({username}): {exc}")
def _resolve_override_file(base_dir: Path, name: str, suffix: str) -> Optional[Path]:
    """Return ``<base_dir>/<name><suffix>`` if it is a regular file inside ``base_dir``.

    ``None`` is returned when the file does not exist, is not a regular file,
    or when the resolved path lands outside ``base_dir`` (for example a name
    containing ``..`` segments, an absolute path, or a symlink pointing
    elsewhere). Unsafe names are therefore treated exactly like missing files.
    """
    try:
        root = base_dir.resolve()
        candidate = (root / f"{name}{suffix}").resolve()
        # relative_to raises ValueError when candidate is not under root.
        candidate.relative_to(root)
    except (OSError, RuntimeError, ValueError):
        return None
    return candidate if candidate.is_file() else None


def apply_conversation_overrides(terminal: WebTerminal, workspace, conversation_id: Optional[str]):
    """Apply a conversation's custom prompt / personalization to the terminal (API use only).

    Reads ``<data_dir>/conversations/<conversation_id>.json`` and looks at
    its ``metadata``:

    * ``custom_prompt_name`` -> ``<data_dir>/prompts/<name>.txt`` becomes
      ``context_manager.custom_system_prompt``.
    * ``personalization_name`` -> ``<data_dir>/personalization/<name>.json``
      becomes ``context_manager.custom_personalization_config``.

    Either override is reset to ``None`` when its name is absent or its file
    cannot be used. The resulting personalization is then applied to the
    terminal. Nothing happens when ``conversation_id`` is falsy or the
    conversation file is missing. Errors are logged, never raised.

    The names come from stored conversation metadata, so they are not
    trusted: paths are confined to their base directory.
    """
    if not conversation_id:
        return
    data_root = Path(workspace.data_dir)
    conv_path = _resolve_override_file(data_root / "conversations", conversation_id, ".json")
    if conv_path is None:
        return
    try:
        data = json.loads(conv_path.read_text(encoding="utf-8"))
        meta = data.get("metadata") or {}
        prompt_name = meta.get("custom_prompt_name")
        personalization_name = meta.get("personalization_name")
        context = terminal.context_manager

        # Prompt override: None unless a usable prompt file is found.
        system_prompt = None
        if prompt_name:
            prompt_path = _resolve_override_file(data_root / "prompts", prompt_name, ".txt")
            if prompt_path is not None:
                system_prompt = prompt_path.read_text(encoding="utf-8")
        context.custom_system_prompt = system_prompt

        # Personalization override: None unless the file exists and parses.
        personalization = None
        if personalization_name:
            pers_path = _resolve_override_file(data_root / "personalization", personalization_name, ".json")
            if pers_path is not None:
                try:
                    personalization = json.loads(pers_path.read_text(encoding="utf-8"))
                except Exception:
                    personalization = None
        context.custom_personalization_config = personalization

        # Apply the personalization preferences (including disabled tool
        # categories) to the current terminal.
        try:
            terminal.apply_personalization_preferences(context.custom_personalization_config)
        except Exception as exc:
            debug_log(f"[apply_overrides] 应用个性化失败: {exc}")
    except Exception as exc:
        debug_log(f"[apply_overrides] 读取对话元数据失败: {exc}")
def with_terminal(func):
    """Decorator injecting the current user's terminal and workspace.

    The wrapped view receives ``terminal``, ``workspace`` and ``username`` as
    keyword arguments. If the resources cannot be obtained, a JSON error with
    HTTP status 503 is returned instead of calling the view.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        username = get_current_username()
        try:
            resources = get_user_resources(username)
        except RuntimeError as exc:
            return jsonify({"error": str(exc), "code": "resource_busy"}), 503

        terminal, workspace = resources
        if not (terminal and workspace):
            return jsonify({"error": "System not initialized"}), 503

        kwargs['terminal'] = terminal
        kwargs['workspace'] = workspace
        kwargs['username'] = username
        return func(*args, **kwargs)

    return wrapper
def get_terminal_for_sid(sid: str):
    """Resolve ``(username, terminal, workspace)`` for a socket connection id.

    Returns ``(None, None, None)`` when the sid is not mapped to a user in
    ``state.connection_users``, and ``(username, None, None)`` when resolving
    the user's resources raises ``RuntimeError``.
    """
    username = state.connection_users.get(sid)
    if username:
        try:
            terminal, workspace = get_user_resources(username)
        except RuntimeError:
            terminal = workspace = None
        return username, terminal, workspace
    return None, None, None
def get_gui_manager(workspace):
    """Create a ``GuiFileManager`` rooted at the workspace's project path."""
    project_root = str(workspace.project_path)
    return GuiFileManager(project_root)
def get_upload_guard(workspace):
    """Create an ``UploadQuarantineManager`` for the given workspace."""
    guard = UploadQuarantineManager(workspace)
    return guard
def build_upload_error_response(exc: UploadSecurityError):
    """Turn an ``UploadSecurityError`` into a ``(json_response, status)`` pair.

    Scanner availability problems (``scanner_missing`` /
    ``scanner_unavailable``) map to HTTP 500; every other code maps to 400.
    """
    scanner_failure_codes = {"scanner_missing", "scanner_unavailable"}
    status = 500 if exc.code in scanner_failure_codes else 400
    payload = {
        "success": False,
        "error": str(exc),
        "code": exc.code,
    }
    return jsonify(payload), status
def ensure_conversation_loaded(
    terminal: WebTerminal,
    conversation_id: Optional[str],
    workspace=None,
) -> Tuple[str, bool]:
    """Make sure the terminal has the requested conversation loaded.

    Without a ``conversation_id`` a new conversation is created. Otherwise
    the id is normalized to carry the ``conv_`` prefix, the conversation is
    loaded if it is not already current, and the terminal's run mode is
    synchronized from the conversation metadata before a new API task is
    started.

    Args:
        terminal: Terminal to operate on.
        conversation_id: Existing conversation id (with or without the
            ``conv_`` prefix), or a falsy value to create a new one.
        workspace: Optional workspace; when given, conversation-level prompt
            and personalization overrides are applied.

    Returns:
        ``(conversation_id, created_new)``.

    Raises:
        RuntimeError: Creating or loading the conversation reported failure.
    """
    created_new = False
    if not conversation_id:
        result = terminal.create_new_conversation()
        if not result.get("success"):
            raise RuntimeError(result.get("message", "创建对话失败"))
        conversation_id = result["conversation_id"]
        if has_request_context():
            session['run_mode'] = terminal.run_mode
            session['thinking_mode'] = terminal.thinking_mode
        created_new = True
    else:
        if not conversation_id.startswith('conv_'):
            conversation_id = f"conv_{conversation_id}"
        current_id = terminal.context_manager.current_conversation_id
        if current_id != conversation_id:
            load_result = terminal.load_conversation(conversation_id)
            if not load_result.get("success"):
                raise RuntimeError(load_result.get("message", "对话加载失败"))
        # Sync the run mode recorded in the conversation metadata. This is
        # best-effort: a failure must not block the conversation itself.
        try:
            conv_data = terminal.context_manager.conversation_manager.load_conversation(conversation_id) or {}
            meta = conv_data.get("metadata", {}) or {}
            run_mode_meta = meta.get("run_mode")
            if run_mode_meta:
                terminal.set_run_mode(run_mode_meta)
            elif meta.get("thinking_mode"):
                terminal.set_run_mode("thinking")
            else:
                terminal.set_run_mode("fast")
            if terminal.thinking_mode:
                terminal.api_client.start_new_task(force_deep=terminal.deep_thinking_mode)
            else:
                terminal.api_client.start_new_task()
            if has_request_context():
                session['run_mode'] = terminal.run_mode
                session['thinking_mode'] = terminal.thinking_mode
        except Exception as exc:
            # Previously swallowed silently; log so mode-sync bugs are visible.
            debug_log(f"[ensure_conversation] 同步运行模式失败: {exc}")
    # Apply conversation-level custom prompt / personalization (API only).
    # This function is reused from WebSocket handlers and background tasks,
    # and some call sites have no workspace, so a missing workspace only
    # skips the overrides and does not affect normal conversation loading.
    if workspace is not None:
        try:
            apply_conversation_overrides(terminal, workspace, conversation_id)
        except Exception as exc:
            debug_log(f"[apply_overrides] 失败: {exc}")
    return conversation_id, created_new
def reset_system_state(terminal: Optional[WebTerminal]):
    """Fully reset the terminal's transient state.

    Starts a new API-client task, bumps ``current_session_id`` and resets the
    web streaming attributes, each only when the terminal actually has the
    attribute. A falsy terminal is ignored; any error is logged together
    with its traceback instead of being raised.
    """
    if not terminal:
        return
    try:
        if getattr(terminal, 'api_client', None):
            debug_log("重置API客户端状态")
            force_deep = getattr(terminal, "deep_thinking_mode", False)
            terminal.api_client.start_new_task(force_deep=force_deep)

        if hasattr(terminal, 'current_session_id'):
            terminal.current_session_id += 1
            debug_log(f"重置会话ID为: {terminal.current_session_id}")

        # Scalar web-UI flags go back to their idle values.
        if hasattr(terminal, 'streamingMessage'):
            terminal.streamingMessage = False
        if hasattr(terminal, 'currentMessageIndex'):
            terminal.currentMessageIndex = -1

        # Tool collections are emptied in place when they support clear().
        for name in ('preparingTools', 'activeTools'):
            collection = getattr(terminal, name, None)
            if hasattr(collection, 'clear'):
                collection.clear()

        debug_log("系统状态重置完成")
    except Exception as e:
        debug_log(f"状态重置过程中出现错误: {e}")
        import traceback
        debug_log(f"错误详情: {traceback.format_exc()}")
# Names exported by ``from <this module> import *``.
# NOTE(review): ``apply_conversation_overrides`` and ``make_terminal_callback``
# are defined in this module but not listed here -- confirm that is intentional.
__all__ = [
    "get_user_resources",
    "with_terminal",
    "get_terminal_for_sid",
    "get_gui_manager",
    "get_upload_guard",
    "build_upload_error_response",
    "ensure_conversation_loaded",
    "reset_system_state",
    "get_or_create_usage_tracker",
    "emit_user_quota_update",
    "attach_user_broadcast",
]