agent-Specialization/modules/admin_policy_manager.py
2026-01-05 21:48:55 +08:00

255 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""管理员策略配置管理。
职责:
- 持久化管理员在工具分类、模型禁用、前端功能禁用等方面的策略。
- 支持作用域(global / role / user / invite_code)按优先级合并得到最终生效策略。
- 仅在此文件读写配置,避免侵入现有模块。
"""
from __future__ import annotations
import json
import time
from pathlib import Path
from typing import Dict, Any, Tuple
from config.paths import ADMIN_POLICY_FILE
from modules.custom_tool_registry import CustomToolRegistry, build_default_tool_category
# Model keys an admin policy may disable (must stay in sync with the frontend
# and model_profiles).
ALLOWED_MODELS = {"deepseek", "kimi", "qwen3-max", "qwen3-vl-plus"}

# Keys of frontend features an admin can block; frontend and backend use the
# same names. Order matters: the list is reported to callers in this order.
UI_BLOCK_KEYS = [
    "collapse_workspace",
    "block_file_manager",
    "block_personal_space",
    "block_upload",
    "block_conversation_review",
    "block_tool_toggle",
    "block_realtime_terminal",
    "block_focus_panel",
    "block_token_panel",
    "block_compress_conversation",
    "block_virtual_monitor",
]
def _ensure_file() -> Path:
    """Return the resolved policy file path, creating it on first use.

    The parent directory is created if needed. When the file does not exist
    yet, it is seeded with ``_default_policy()``.
    """
    policy_path = Path(ADMIN_POLICY_FILE).expanduser().resolve()
    policy_path.parent.mkdir(parents=True, exist_ok=True)
    if policy_path.exists():
        return policy_path
    _save_policy(_default_policy(), policy_path)
    return policy_path
def _read_policy(path: Path) -> Dict[str, Any]:
    """Parse the policy JSON stored at *path*.

    Best-effort: if the file cannot be read (``OSError``) or its content is not
    valid UTF-8 JSON (``json.JSONDecodeError`` / ``UnicodeDecodeError``, both
    ``ValueError`` subclasses), the default policy is returned instead of
    raising. Other exceptions (programming errors) propagate.

    The parsed value is not type-checked here; a JSON document that is not an
    object is returned as-is, so callers must validate its type.
    """
    try:
        raw = path.read_text(encoding="utf-8")
        return json.loads(raw)
    except (OSError, ValueError):
        return _default_policy()
def _save_policy(payload: Dict[str, Any], path: Path | None = None) -> None:
    """Write *payload* to the policy file as pretty-printed UTF-8 JSON.

    Args:
        payload: Policy document to persist (must be JSON-serializable).
        path: Destination file. When omitted, ``_ensure_file()`` supplies the
            configured policy path.

    Raises:
        TypeError: if *payload* is not JSON-serializable; the file on disk is
            left untouched in that case.
        OSError: on filesystem errors; any temporary file is removed.

    The document is first written to a temporary sibling file, which is
    flushed and fsynced, and then moved over the target with ``os.replace``.
    Without this, a crash mid-write could leave a truncated file. Reading that
    file back yields the default policy, and the next save would then
    overwrite every stored scope.
    """
    import os
    import threading

    target = path if path is not None else _ensure_file()
    # Serialize before touching the filesystem so a serialization error cannot
    # damage the existing file.
    data = json.dumps(payload, ensure_ascii=False, indent=2)
    # The name is unique per process and thread, so concurrent savers never
    # share a temporary file. It sits in the same directory so os.replace
    # stays on one filesystem.
    tmp = target.with_name(f".{target.name}.{os.getpid()}.{threading.get_ident()}.tmp")
    try:
        with open(tmp, "w", encoding="utf-8") as fh:
            fh.write(data)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp, target)
    except BaseException:
        try:
            tmp.unlink()
        except OSError:
            pass
        raise
def _default_policy() -> Dict[str, Any]:
    """Build a brand-new policy document with no overrides in any scope.

    ``updated_at`` is the current local time, and ``global`` is a blank
    config. The ``roles``, ``users`` and ``invites`` sections start out empty.
    """
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    policy: Dict[str, Any] = {"updated_at": stamp, "global": _blank_config()}
    for scope_section in ("roles", "users", "invites"):
        policy[scope_section] = {}
    return policy
def _blank_config() -> Dict[str, Any]:
    """Return a fresh, empty per-scope policy config.

    Keys:
        category_overrides: category id -> {label, tools, default_enabled?}
        remove_categories: list of category ids to drop
        forced_category_states: category id -> true/false
        disabled_models: list of model keys
        ui_blocks: UI block key -> bool

    New container objects are created on every call, so callers may mutate
    the result freely.
    """
    empty: Dict[str, Any] = dict(
        category_overrides={},
        remove_categories=[],
        forced_category_states={},
        disabled_models=[],
        ui_blocks={},
    )
    return empty
def _merge_config(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Shallow-merge two per-scope configs, with *override* taking precedence.

    - Dict fields (``category_overrides``, ``forced_category_states``,
      ``ui_blocks``) are merged key by key; keys in *override* win.
    - List fields are unioned. ``remove_categories`` keeps only strings, and
      ``disabled_models`` keeps only keys in ``ALLOWED_MODELS``. Duplicates are
      dropped while preserving first-seen order, with *base* entries first.

    ``None`` or non-dict inputs, and fields of the wrong type (e.g. a list
    where a dict is expected, or a bare string where a list is expected), are
    treated as empty. They are never iterated character by character and never
    raise.

    Returns:
        A new config dict with exactly the keys of ``_blank_config()``.
    """
    base = base if isinstance(base, dict) else {}
    override = override if isinstance(override, dict) else {}

    def _dict_field(key: str) -> Dict[str, Any]:
        # Later sources overwrite earlier ones, key by key.
        combined: Dict[str, Any] = {}
        for source in (base, override):
            value = source.get(key)
            if isinstance(value, dict):
                combined.update(value)
        return combined

    def _list_field(key: str, keep) -> list:
        items = []
        for source in (base, override):
            value = source.get(key)
            if isinstance(value, (list, tuple)):
                items.extend(v for v in value if keep(v))
        # dict.fromkeys dedupes in insertion order. A set would produce an
        # order that varies between processes (str hash randomization).
        return list(dict.fromkeys(items))

    merged = _blank_config()
    merged["category_overrides"] = _dict_field("category_overrides")
    merged["remove_categories"] = _list_field("remove_categories", lambda c: isinstance(c, str))
    merged["forced_category_states"] = _dict_field("forced_category_states")
    # isinstance guard first: unhashable entries would make `in` on a set raise.
    merged["disabled_models"] = _list_field(
        "disabled_models", lambda m: isinstance(m, str) and m in ALLOWED_MODELS
    )
    merged["ui_blocks"] = _dict_field("ui_blocks")
    return merged
def load_policy() -> Dict[str, Any]:
    """Load the persisted admin policy document.

    The policy file is created with defaults on first use (via
    ``_ensure_file``). If the file does not contain a JSON object, the default
    policy is used. Missing top-level sections are filled in. Sections holding
    the wrong JSON type (e.g. ``"roles": null``) are replaced with empty ones,
    so callers can treat ``global``, ``roles``, ``users`` and ``invites`` as
    dicts. These repairs are made in memory only; nothing is written back here.

    Returns:
        A dict with ``updated_at``, ``global``, ``roles``, ``users`` and ``invites``.
    """
    payload = _read_policy(_ensure_file())
    if not isinstance(payload, dict):
        payload = _default_policy()
    payload.setdefault("updated_at", time.strftime("%Y-%m-%d %H:%M:%S"))
    if not isinstance(payload.get("global"), dict):
        payload["global"] = _blank_config()
    for section in ("roles", "users", "invites"):
        if not isinstance(payload.get(section), dict):
            payload[section] = {}
    return payload
def save_scope_policy(target_type: str, target_value: str, config: Dict[str, Any]) -> Dict[str, Any]:
    """Replace the config of one scope, persist the policy and return it.

    Args:
        target_type: ``"global"``, ``"role"``, ``"user"`` or ``"invite"``.
        target_value: Key inside the chosen scope section. It is ignored for
            ``"global"``.
        config: Raw scope config. It is normalized through ``_merge_config``
            against a blank config, so only the known config keys are stored,
            with the list filtering that ``_merge_config`` applies.

    Returns:
        The full, updated policy document (as saved).

    Raises:
        ValueError: if *target_type* is not one of the supported scopes.
    """
    sections = {"role": "roles", "user": "users", "invite": "invites"}
    if target_type != "global" and target_type not in sections:
        raise ValueError("invalid target_type")
    policy = load_policy()
    normalized = _merge_config(_blank_config(), config or {})
    if target_type == "global":
        policy["global"] = normalized
    else:
        section = sections[target_type]
        bucket = policy.get(section)
        # A missing or wrong-typed section (e.g. null in a hand-edited file)
        # is replaced rather than raising TypeError on item assignment.
        if not isinstance(bucket, dict):
            bucket = {}
            policy[section] = bucket
        bucket[target_value] = normalized
    policy["updated_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
    _save_policy(policy)
    return policy
def _collect_categories_with_overrides(overrides: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Build the effective tool-category table ``{id: {label, tools, default_enabled, silent_when_disabled}}``.

    The steps run in this order:
      1. Start from the built-in ``TOOL_CATEGORIES``.
      2. Add the dynamic custom-tool category, whose tools are the ids listed
         by ``CustomToolRegistry``.
      3. Drop every id listed in ``overrides["remove_categories"]``.
      4. Insert or replace categories from ``overrides["category_overrides"]``.
         Entries with a non-string id, a non-dict payload or a non-list
         ``tools`` are skipped.
    """
    from core.tool_config import TOOL_CATEGORIES  # lazy import to avoid a circular import

    registry = CustomToolRegistry()
    result: Dict[str, Dict[str, Any]] = {}
    for cat_id, category in TOOL_CATEGORIES.items():
        result[cat_id] = {
            "label": category.label,
            "tools": list(category.tools),
            "default_enabled": bool(category.default_enabled),
            "silent_when_disabled": getattr(category, "silent_when_disabled", False),
        }

    # Dynamic custom-tool category.
    custom_meta = build_default_tool_category()
    custom_id = custom_meta.get("id", "custom")
    custom_tool_ids = [entry.get("id") for entry in registry.list_tools() if entry.get("id")]
    result[custom_id] = {
        "label": custom_meta.get("label", "自定义工具"),
        "tools": custom_tool_ids,
        "default_enabled": True,
        "silent_when_disabled": False,
    }

    for doomed_id in set(overrides.get("remove_categories") or []):
        result.pop(doomed_id, None)

    for cat_id, spec in (overrides.get("category_overrides") or {}).items():
        if not (isinstance(cat_id, str) and isinstance(spec, dict)):
            continue
        tool_list = spec.get("tools") or []
        if not isinstance(tool_list, list):
            continue
        enabled = bool(spec.get("default_enabled", True))
        result[cat_id] = {
            "label": str(spec.get("label") or cat_id),
            "tools": [name for name in tool_list if isinstance(name, str)],
            "default_enabled": enabled,
            "silent_when_disabled": bool(spec.get("silent_when_disabled", False)),
        }
    return result
def get_effective_policy(username: str | None, role: str | None, invite_code: str | None) -> Dict[str, Any]:
    """Compute the policy that applies to a particular user.

    Scope configs are folded with ``_merge_config`` in ascending precedence:
    global < role < invite < user. Later scopes override earlier ones for dict
    fields, while list fields accumulate.

    Args:
        username: Key into the ``users`` section.
        role: Key into the ``roles`` section.
        invite_code: Key into the ``invites`` section.
        For all three, ``None`` is looked up as ``""``.

    Returns:
        A dict with the following keys:
        - ``categories``: the effective tool-category table.
        - ``forced_category_states``: category id -> bool, with ``None`` for
          non-boolean stored values.
        - ``disabled_models``.
        - ``ui_blocks``: restricted to ``UI_BLOCK_KEYS``, values cast to bool.
        - ``updated_at``.
    """
    policy = load_policy()

    def _scope_config(section: str, key: str | None) -> Dict[str, Any]:
        # Tolerate hand-edited or corrupted files: a non-dict section or entry
        # counts as an empty config instead of crashing the lookup or merge.
        bucket = policy.get(section)
        if not isinstance(bucket, dict):
            return {}
        # NOTE(review): None maps to "", so an entry stored under "" applies
        # to callers without a role/invite/username. Confirm that is intended.
        cfg = bucket.get(key or "", {})
        return cfg if isinstance(cfg, dict) else {}

    global_cfg = policy.get("global")
    scopes = (
        global_cfg if isinstance(global_cfg, dict) else {},
        _scope_config("roles", role),
        _scope_config("invites", invite_code),
        _scope_config("users", username),
    )
    merged = _blank_config()
    for cfg in scopes:
        merged = _merge_config(merged, cfg)

    # Resolve the final category table from the merged overrides.
    categories = _collect_categories_with_overrides(merged)
    forced_states = {
        key: (val if isinstance(val, bool) else None)
        for key, val in (merged.get("forced_category_states") or {}).items()
        if key
    }
    disabled_models = [m for m in merged.get("disabled_models") or [] if m in ALLOWED_MODELS]
    ui_blocks = {k: bool(v) for k, v in (merged.get("ui_blocks") or {}).items() if k in UI_BLOCK_KEYS}
    return {
        "categories": categories,
        "forced_category_states": forced_states,
        "disabled_models": disabled_models,
        "ui_blocks": ui_blocks,
        "updated_at": policy.get("updated_at"),
    }
def describe_defaults() -> Dict[str, Any]:
    """Describe the un-overridden defaults, for rendering in the admin frontend.

    Returns:
        A dict with the following keys:
        - ``categories``: the built-in ``TOOL_CATEGORIES`` plus the dynamic
          custom-tool category. This is exactly what
          ``_collect_categories_with_overrides`` builds when given no overrides.
        - ``models``: the sorted list of ``ALLOWED_MODELS``.
        - ``ui_block_keys``: a copy of ``UI_BLOCK_KEYS``, in declaration order.
    """
    return {
        # Empty overrides mean no removals and no replacements, i.e. the default
        # table. This reuses the shared builder instead of duplicating it.
        "categories": _collect_categories_with_overrides({}),
        "models": sorted(ALLOWED_MODELS),
        # Return a copy so callers mutating the response cannot alter the
        # module-level constant.
        "ui_block_keys": list(UI_BLOCK_KEYS),
    }