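"""Administrator policy configuration management.

Responsibilities:

- Persist the administrator's policies for tool categories, disabled models,
  and disabled front-end features.
- Support the scopes global / role / user / invite (invite code). The scope
  policies that apply to a user are merged in the order
  global -> role -> invite -> user to produce the effective policy: for mapping
  fields a later scope overrides an earlier one, list fields are unioned.
- Read and write the policy configuration only from this module, so existing
  modules are not affected.
"""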
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Dict, Any, Tuple
|
||
|
||
from config.paths import ADMIN_POLICY_FILE
|
||
from modules.custom_tool_registry import CustomToolRegistry, build_default_tool_category
|
||
|
||
# Model keys that a policy may disable. Keep in sync with the front end and
# with model_profiles.
ALLOWED_MODELS = {
    "kimi",
    "deepseek",
    "qwen3-max",
    "qwen3-vl-plus",
}

# Names of the UI features that can be blocked; shared by back end and front end.
UI_BLOCK_KEYS = [
    "collapse_workspace",
    "block_file_manager",
    "block_personal_space",
    "block_upload",
    "block_conversation_review",
    "block_tool_toggle",
    "block_realtime_terminal",
    "block_focus_panel",
    "block_token_panel",
    "block_compress_conversation",
    "block_virtual_monitor",
]
|
||
|
||
|
||
def _ensure_file() -> Path:
    """Return the resolved policy file path, creating it if necessary.

    The parent directory is created when missing. If the file itself does not
    exist yet, it is initialised with ``_default_policy()``.
    """
    policy_path = Path(ADMIN_POLICY_FILE).expanduser().resolve()
    policy_path.parent.mkdir(parents=True, exist_ok=True)
    if policy_path.exists():
        return policy_path
    _save_policy(_default_policy(), policy_path)
    return policy_path
|
||
|
||
|
||
def _read_policy(path: Path) -> Dict[str, Any]:
|
||
try:
|
||
return json.loads(path.read_text(encoding="utf-8"))
|
||
except Exception:
|
||
return _default_policy()
|
||
|
||
|
||
def _save_policy(payload: Dict[str, Any], path: Path | None = None) -> None:
|
||
target = path or _ensure_file()
|
||
target.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||
|
||
|
||
def _default_policy() -> Dict[str, Any]:
    """Build a fresh policy document with an empty global config and no scopes.

    Keys, in order: ``updated_at`` (local time, ``%Y-%m-%d %H:%M:%S``),
    ``global``, ``roles``, ``users``, ``invites``.
    """
    policy: Dict[str, Any] = {"updated_at": time.strftime("%Y-%m-%d %H:%M:%S")}
    policy["global"] = _blank_config()
    for scope_key in ("roles", "users", "invites"):
        policy[scope_key] = {}
    return policy
|
||
|
||
|
||
def _blank_config() -> Dict[str, Any]:
|
||
return {
|
||
"category_overrides": {}, # id -> {label, tools, default_enabled?}
|
||
"remove_categories": [], # ids
|
||
"forced_category_states": {}, # id -> true/false
|
||
"disabled_models": [], # list of model keys
|
||
"ui_blocks": {}, # key -> bool
|
||
}
|
||
|
||
|
||
def _mapping_section(cfg: Dict[str, Any], key: str) -> Dict[str, Any]:
    """Return ``cfg[key]`` when it is a dict, otherwise an empty dict."""
    value = cfg.get(key)
    return value if isinstance(value, dict) else {}


def _string_list_section(cfg: Dict[str, Any], key: str) -> list:
    """Return the str items of ``cfg[key]`` when it is a list/tuple/set, otherwise ``[]``.

    A bare string is rejected rather than iterated, since iterating would
    split it into single characters.
    """
    value = cfg.get(key)
    if not isinstance(value, (list, tuple, set)):
        return []
    return [item for item in value if isinstance(item, str)]


def _merge_config(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Merge two scope configs into a new normalized config; *override* wins.

    - ``category_overrides``, ``forced_category_states`` and ``ui_blocks`` are
      shallow-merged key by key; keys from *override* replace keys from *base*.
    - ``remove_categories`` and ``disabled_models`` are unioned without
      duplicates, keeping first-seen order (base items first) so the output
      is deterministic. ``disabled_models`` keeps only keys in ``ALLOWED_MODELS``.
    - A section of the wrong type (e.g. a list where a dict is expected) is
      treated as empty instead of raising. *override* may be the raw
      ``config`` argument of ``save_scope_policy``.
    - Keys that are not part of ``_blank_config()`` are dropped.

    A ``None`` (or any non-dict) *base*/*override* counts as an empty config.
    """
    base = base if isinstance(base, dict) else {}
    override = override if isinstance(override, dict) else {}
    merged = _blank_config()

    for key in ("category_overrides", "forced_category_states", "ui_blocks"):
        merged[key] = {**_mapping_section(base, key), **_mapping_section(override, key)}

    removed = _string_list_section(base, "remove_categories") + _string_list_section(
        override, "remove_categories"
    )
    merged["remove_categories"] = list(dict.fromkeys(removed))

    models = _string_list_section(base, "disabled_models") + _string_list_section(
        override, "disabled_models"
    )
    merged["disabled_models"] = [m for m in dict.fromkeys(models) if m in ALLOWED_MODELS]
    return merged
|
||
|
||
|
||
def load_policy() -> Dict[str, Any]:
    """Load the policy document, repairing missing or malformed top-level sections.

    - A non-dict document is replaced by ``_default_policy()``.
    - A missing ``updated_at`` is set to the current local time.
    - ``global`` is reset to ``_blank_config()`` when missing or not a dict.
    - ``roles`` / ``users`` / ``invites`` are reset to ``{}`` when missing or
      not a dict. A non-dict container would otherwise make
      ``save_scope_policy`` fail when it assigns into it.

    Returns:
        The (possibly repaired) policy dict. It is not written back here.
    """
    path = _ensure_file()
    payload = _read_policy(path)
    if not isinstance(payload, dict):
        payload = _default_policy()

    payload.setdefault("updated_at", time.strftime("%Y-%m-%d %H:%M:%S"))
    if not isinstance(payload.get("global"), dict):
        payload["global"] = _blank_config()
    for scope_key in ("roles", "users", "invites"):
        if not isinstance(payload.get(scope_key), dict):
            payload[scope_key] = {}
    return payload
|
||
|
||
|
||
def save_scope_policy(target_type: str, target_value: str, config: Dict[str, Any]) -> Dict[str, Any]:
    """Replace the policy of one scope, persist the whole document, and return it.

    Args:
        target_type: One of ``"global"``, ``"role"``, ``"user"``, ``"invite"``.
        target_value: Role name / username / invite code. Ignored for ``"global"``.
        config: Scope config. It is normalized through ``_merge_config``:
            unknown keys are dropped, list fields are de-duplicated, and unknown
            model keys are filtered out.

    Returns:
        The full, updated policy document.

    Raises:
        ValueError: ``target_type`` is unknown, or ``target_value`` is ``None``
            or ``""`` for a non-global scope.
    """
    scope_containers = {"role": "roles", "user": "users", "invite": "invites"}
    if target_type != "global" and target_type not in scope_containers:
        raise ValueError("invalid target_type")
    if target_type != "global" and target_value in (None, ""):
        # None would be persisted under the JSON key "null", and "" is the key
        # get_effective_policy falls back to for a missing role/user/invite.
        raise ValueError("invalid target_value")

    policy = load_policy()
    normalized = _merge_config(_blank_config(), config or {})

    if target_type == "global":
        policy["global"] = normalized
    else:
        container_key = scope_containers[target_type]
        container = policy.get(container_key)
        if not isinstance(container, dict):
            container = policy[container_key] = {}
        container[target_value] = normalized

    policy["updated_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
    _save_policy(policy)
    return policy
|
||
|
||
|
||
def _collect_categories_with_overrides(overrides: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Build the effective tool-category map from a merged scope config.

    Result shape:
    ``{category_id: {"label", "tools", "default_enabled", "silent_when_disabled"}}``.

    Steps:
    1. Start from the built-in ``TOOL_CATEGORIES``.
    2. Add (or replace) the dynamic custom-tool category from
       ``build_default_tool_category()``. Its tools are the ids returned by
       ``CustomToolRegistry().list_tools()``.
    3. Drop every id listed in ``overrides["remove_categories"]``.
    4. Apply ``overrides["category_overrides"]``. Each well-formed entry
       replaces the whole category, so it can also re-add a removed id. Entries
       with a non-str id, a non-dict payload, or a non-list ``tools`` are skipped.
    """
    from core.tool_config import TOOL_CATEGORIES  # deferred import to avoid an import cycle

    registry = CustomToolRegistry()

    categories: Dict[str, Dict[str, Any]] = {}
    for cat_key, cat in TOOL_CATEGORIES.items():
        categories[cat_key] = {
            "label": cat.label,
            "tools": list(cat.tools),
            "default_enabled": bool(cat.default_enabled),
            "silent_when_disabled": getattr(cat, "silent_when_disabled", False),
        }

    # Dynamic category holding the user-defined tools.
    custom_meta = build_default_tool_category()
    custom_key = custom_meta.get("id", "custom")
    custom_tool_ids = [entry.get("id") for entry in registry.list_tools() if entry.get("id")]
    categories[custom_key] = {
        "label": custom_meta.get("label", "自定义工具"),
        "tools": custom_tool_ids,
        "default_enabled": True,
        "silent_when_disabled": False,
    }

    for removed_key in set(overrides.get("remove_categories") or []):
        categories.pop(removed_key, None)

    for cat_key, spec in (overrides.get("category_overrides") or {}).items():
        if not (isinstance(cat_key, str) and isinstance(spec, dict)):
            continue
        tool_list = spec.get("tools") or []
        if not isinstance(tool_list, list):
            continue
        categories[cat_key] = {
            "label": str(spec.get("label") or cat_key),
            "tools": [tool for tool in tool_list if isinstance(tool, str)],
            "default_enabled": bool(spec.get("default_enabled", True)),
            "silent_when_disabled": bool(spec.get("silent_when_disabled", False)),
        }
    return categories
|
||
|
||
|
||
def get_effective_policy(username: str | None, role: str | None, invite_code: str | None) -> Dict[str, Any]:
    """Merge the scope policies that apply to one user into the effective configuration.

    Scopes are merged in the order global -> role -> invite -> user with
    ``_merge_config``: for mapping fields a later scope overrides an earlier
    one, and list fields are unioned.

    Args:
        username: Key into ``policy["users"]``.
        role: Key into ``policy["roles"]``.
        invite_code: Key into ``policy["invites"]``.
        A ``None`` or empty identifier contributes nothing. It must not pick up
        a policy stored under the empty-string key. Malformed (non-dict)
        containers or entries are ignored.

    Returns:
        ``categories``: effective tool categories.
        ``forced_category_states``: category id -> bool, or None for a non-bool value.
        ``disabled_models``: allowed model keys only.
        ``ui_blocks``: only keys listed in ``UI_BLOCK_KEYS``, values coerced to bool.
        ``updated_at``: timestamp of the policy document.
    """
    policy = load_policy()

    def _scope_config(container_key: str, identifier: str | None) -> Dict[str, Any]:
        if not identifier:
            return {}
        container = policy.get(container_key)
        if not isinstance(container, dict):
            return {}
        cfg = container.get(identifier)
        return cfg if isinstance(cfg, dict) else {}

    global_cfg = policy.get("global")
    if not isinstance(global_cfg, dict):
        global_cfg = _blank_config()

    layers = (
        global_cfg,
        _scope_config("roles", role),
        _scope_config("invites", invite_code),
        _scope_config("users", username),
    )

    merged = _blank_config()
    for layer in layers:
        merged = _merge_config(merged, layer)

    categories = _collect_categories_with_overrides(merged)
    forced_states = {
        key: val if isinstance(val, bool) else None
        for key, val in (merged.get("forced_category_states") or {}).items()
        if key
    }
    disabled_models = [m for m in merged.get("disabled_models") or [] if m in ALLOWED_MODELS]
    ui_blocks = {k: bool(v) for k, v in (merged.get("ui_blocks") or {}).items() if k in UI_BLOCK_KEYS}

    return {
        "categories": categories,
        "forced_category_states": forced_states,
        "disabled_models": disabled_models,
        "ui_blocks": ui_blocks,
        "updated_at": policy.get("updated_at"),
    }
|
||
|
||
|
||
def describe_defaults() -> Dict[str, Any]:
    """Return the default (non-overridden) options for front-end rendering.

    Returns:
        ``categories``: built-in tool categories plus the dynamic custom-tool
            category. They are built by ``_collect_categories_with_overrides``
            with no overrides applied, so this stays identical to what the
            effective policy starts from.
        ``models``: sorted model keys from ``ALLOWED_MODELS``.
        ``ui_block_keys``: a copy of ``UI_BLOCK_KEYS``, so callers cannot
            mutate the module-level constant.
    """
    return {
        "categories": _collect_categories_with_overrides({}),
        "models": sorted(ALLOWED_MODELS),
        "ui_block_keys": list(UI_BLOCK_KEYS),
    }
|