agent-Specialization/modules/api_user_manager.py

414 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""API-only user and workspace management (JSON storage + Bearer-token hashes).

Supports creating/deleting API users, persisting tokens (a SHA-256 hash used for
authentication, plus an encrypted copy -- and a plaintext fallback -- kept for
display) and basic per-user request counting.
"""
from __future__ import annotations

import base64
import copy
import hashlib
import json
import secrets
import shutil
import threading
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

from config import (
    API_USER_SPACE_DIR,
    API_USERS_DB_FILE,
    API_TOKENS_FILE,
    API_USAGE_FILE,
    API_TOKEN_SECRET,
)
from modules.personalization_manager import ensure_personalization_config

try:
    from cryptography.fernet import Fernet, InvalidToken  # type: ignore
except Exception:  # pragma: no cover - degrade gracefully when cryptography is missing
    Fernet = None
    InvalidToken = Exception  # type: ignore
@dataclass()
class ApiUserRecord:
    """One API user as persisted in the users file.

    Only the hex SHA-256 digest of the user's bearer token is kept on the record;
    the token itself is never stored here.
    """

    username: str  # stripped, lower-cased user name
    token_sha256: str  # hex digest compared against sha256(bearer token)
    created_at: str  # ISO-8601 timestamp text; "" when the users file lacks one
    note: str = ""  # free-form admin note
@dataclass
class ApiUserWorkspace:
    """Description of a single workspace belonging to an API user.

    As built by ``ApiUserManager.ensure_workspace`` every path here is absolute.
    """

    username: str
    workspace_id: str
    root: Path  # <user_root>/workspaces/<workspace_id>
    project_path: Path  # <root>/project
    data_dir: Path  # conversations/backups etc. are persisted here (one per workspace)
    logs_dir: Path  # <root>/logs
    uploads_dir: Path  # project/user_upload
    quarantine_dir: Path  # upload quarantine area (partitioned per user/workspace)
    shared_dir: Path  # user-level shared directory (prompts/personalization)
    prompts_dir: Path  # prompts directory actually used (shared_dir/prompts)
    personalization_dir: Path  # personalization directory actually used (shared_dir/personalization)
class ApiUserManager:
    """Minimal API user management: verify token hashes and prepare isolated workspaces.

    Users, tokens and usage counters live in memory (public methods access them under
    ``self._lock``) and are persisted to three JSON files. Usernames and workspace ids
    that are turned into filesystem paths must be single path components, so values
    such as ``".."`` or ``"a/b"`` can never reach outside ``workspace_root``.
    """

    # Characters that would let a name span more than one path component.
    _FORBIDDEN_PATH_CHARS = ("/", "\\", "\x00")

    def __init__(
        self,
        users_file: str = API_USERS_DB_FILE,
        tokens_file: str = API_TOKENS_FILE,
        workspace_root: str = API_USER_SPACE_DIR,
        usage_file: str = API_USAGE_FILE,
    ):
        """Load persisted state, creating missing JSON files and the workspace root.

        Raises:
            RuntimeError: if the users file exists but cannot be parsed.
        """
        self.users_file = Path(users_file)
        self.tokens_file = Path(tokens_file)
        self.usage_file = Path(usage_file)
        self.workspace_root = Path(workspace_root).expanduser().resolve()
        self.workspace_root.mkdir(parents=True, exist_ok=True)
        self._users: Dict[str, ApiUserRecord] = {}
        self._tokens: Dict[str, Dict[str, Any]] = {}
        self._usage: Dict[str, Dict[str, Any]] = {}
        self._lock = threading.Lock()
        self._load_users()
        self._load_tokens()
        self._load_usage()

    # ----------------------- public APIs -----------------------

    def list_users(self) -> Dict[str, ApiUserRecord]:
        """Return a shallow copy of the username -> record map (records are shared)."""
        with self._lock:
            return dict(self._users)

    def get_user_by_token(self, bearer_token: str) -> Optional[ApiUserRecord]:
        """Return the user whose stored hash equals sha256(*bearer_token*), else None."""
        if not bearer_token:
            return None
        token_sha = self._sha256(bearer_token)
        with self._lock:
            for user in self._users.values():
                if user.token_sha256 == token_sha:
                    return user
        return None

    def ensure_workspace(self, username: str, workspace_id: str = "default") -> ApiUserWorkspace:
        """Create (if needed) and return the given workspace of an API user.

        Per-user directory layout:
            <root>/<username>/
                shared/                 # user-level shared data (prompts/personalization)
                    prompts/
                    personalization/
                workspaces/<ws>/        # a single workspace
                    project/
                        user_upload/
                        skills/
                    data/
                        conversations/
                        backups/
                        prompts, personalization   # symlinks to shared/ (best effort)
                    logs/

        An empty or blank *workspace_id* means ``"default"``.

        Raises:
            ValueError: if *username* is empty, or *username* / *workspace_id* is not a
                single safe path component (contains a separator, or is "." / "..").
        """
        username = self._normalize_username(username)
        ws_id = (workspace_id or "default").strip() or "default"
        self._require_path_component(username, "用户名")
        self._require_path_component(ws_id, "工作区 ID")

        user_root = (self.workspace_root / username).resolve()
        shared_dir = user_root / "shared"
        prompts_dir = shared_dir / "prompts"
        personalization_dir = shared_dir / "personalization"
        work_root = user_root / "workspaces" / ws_id
        project_path = work_root / "project"
        data_dir = work_root / "data"
        logs_dir = work_root / "logs"
        uploads_dir = project_path / "user_upload"
        skills_dir = project_path / "skills"
        for path in (
            project_path,
            data_dir,
            logs_dir,
            uploads_dir,
            skills_dir,
            shared_dir,
            prompts_dir,
            personalization_dir,
            # workspace-level data subdirectories
            data_dir / "conversations",
            data_dir / "backups",
        ):
            path.mkdir(parents=True, exist_ok=True)

        # User-level personalization master file (shared by all workspaces).
        ensure_personalization_config(personalization_dir)

        # Convenience links data/prompts and data/personalization -> shared dirs,
        # kept for backward compatibility with code that looks under data_dir.
        for name, target in (("prompts", prompts_dir), ("personalization", personalization_dir)):
            link = data_dir / name
            # is_symlink() also catches dangling links, for which exists() is False.
            if link.exists() or link.is_symlink():
                continue
            try:
                link.symlink_to(target, target_is_directory=True)
            except (OSError, NotImplementedError):
                # Some environments forbid symlinks; callers get the shared paths
                # explicitly through the returned ApiUserWorkspace instead.
                pass

        # Upload quarantine area, partitioned by user/workspace.
        from config import UPLOAD_QUARANTINE_SUBDIR

        quarantine_root = Path(UPLOAD_QUARANTINE_SUBDIR).expanduser()
        if not quarantine_root.is_absolute():
            # A relative setting is placed next to the API user space root.
            quarantine_root = (self.workspace_root.parent / UPLOAD_QUARANTINE_SUBDIR).resolve()
        quarantine_dir = (quarantine_root / username / ws_id).resolve()
        quarantine_dir.mkdir(parents=True, exist_ok=True)

        return ApiUserWorkspace(
            username=username,
            workspace_id=ws_id,
            root=work_root,
            project_path=project_path,
            data_dir=data_dir,
            logs_dir=logs_dir,
            uploads_dir=uploads_dir,
            quarantine_dir=quarantine_dir,
            shared_dir=shared_dir,
            prompts_dir=prompts_dir,
            personalization_dir=personalization_dir,
        )

    def create_user(self, username: str, note: str = "") -> Tuple[ApiUserRecord, str]:
        """Create a new API user and return ``(record, plaintext_token)``.

        Raises:
            ValueError: if the name is empty, not a safe path component, or taken.
        """
        username = self._normalize_username(username)
        # The name becomes a directory under workspace_root; reject traversal up front.
        self._require_path_component(username, "用户名")
        with self._lock:
            if username in self._users:
                raise ValueError("该 API 用户已存在")
            record, token = self._issue_token_locked(username, note=note)
            self._save_users()
            return record, token

    def issue_token(self, username: str, note: str = "") -> Tuple[ApiUserRecord, str]:
        """Regenerate the token of an existing user and return ``(record, plaintext_token)``.

        The stored hash is overwritten, so the previous token stops authenticating.

        Raises:
            ValueError: if the name is empty or the user does not exist.
        """
        username = self._normalize_username(username)
        with self._lock:
            if username not in self._users:
                raise ValueError("用户不存在")
            record, token = self._issue_token_locked(username, note=note)
            self._save_users()
            return record, token

    def delete_user(self, username: str) -> bool:
        """Remove a user's record, token and usage entry, then best-effort delete its files.

        Returns:
            True if a user record existed and was removed.

        Raises:
            ValueError: if *username* is empty.
        """
        username = self._normalize_username(username)
        with self._lock:
            removed = self._users.pop(username, None) is not None
            self._tokens.pop(username, None)
            self._usage.pop(username, None)
            self._save_users()
            self._save_tokens()
            self._save_usage()
        # Only a single, safe path component may be removed: a name such as ".."
        # would otherwise resolve to the parent of workspace_root and wipe it.
        if self._is_safe_path_component(username):
            try:
                user_root = (self.workspace_root / username).resolve()
                if user_root.exists():
                    shutil.rmtree(user_root, ignore_errors=True)
            except (OSError, RuntimeError):
                # best effort: resolve() may fail on broken/looping links
                pass
        return removed

    def get_plain_token(self, username: str) -> str:
        """Return the plaintext token of *username* (for display in the admin backend).

        The encrypted copy (``token_enc``) is preferred. If no Fernet key is available
        or decryption fails, the stored ``token_plain`` field is used instead.

        Raises:
            ValueError: if *username* is empty or has no token entry.
            RuntimeError: if neither decryption nor the plaintext fallback yields a token.
        """
        username = self._normalize_username(username)
        with self._lock:
            token_entry = self._tokens.get(username) or {}
        if not token_entry:
            raise ValueError("未找到 token 记录")
        enc = token_entry.get("token_enc")
        fernet = self._fernet() if enc else None
        if fernet:
            try:
                return fernet.decrypt(enc.encode("utf-8")).decode("utf-8")
            except InvalidToken:  # type: ignore
                pass  # wrong/rotated key: fall through to the plaintext field
        plain = token_entry.get("token_plain")
        if plain:
            return plain
        raise RuntimeError("缺少 API_TOKEN_SECRET 且无可用明文 token")

    def bump_usage(self, username: str, endpoint: Optional[str] = None):
        """Count one API request for *username* (total and per endpoint) and stamp its time.

        The query string is stripped from *endpoint*, so "/x?a=1" and "/x" share a
        counter. The whole usage file is rewritten on every call.
        """
        username = self._normalize_username(username)
        now_iso = self._utcnow_iso() + "Z"
        with self._lock:
            entry = self._usage.setdefault(username, {"total": 0, "endpoints": {}, "last_request_at": None})
            entry["total"] = int(entry.get("total", 0)) + 1
            if endpoint:
                endpoint_key = endpoint.partition("?")[0]
                endpoints = entry.setdefault("endpoints", {})
                endpoints[endpoint_key] = int(endpoints.get(endpoint_key, 0)) + 1
            entry["last_request_at"] = now_iso
            self._save_usage()

    def get_usage(self, username: str) -> Dict[str, Any]:
        """Return a deep copy of *username*'s usage entry ({} if none).

        A deep copy is required because bump_usage mutates the nested ``endpoints``
        dict in place; a shallow copy would share it with concurrent writers.
        """
        username = self._normalize_username(username)
        with self._lock:
            return copy.deepcopy(self._usage.get(username) or {})

    def list_usage(self) -> Dict[str, Dict[str, Any]]:
        """Return a deep copy of all usage entries, keyed by username."""
        with self._lock:
            return copy.deepcopy(self._usage)

    def list_workspaces(self, username: str) -> Dict[str, Dict]:
        """List the user's workspaces without exposing host paths.

        Returns {} for an unknown user or a name that is not a safe path component.
        """
        username = (username or "").strip().lower()
        if not self._is_safe_path_component(username):
            return {}
        workspaces_root = (self.workspace_root / username / "workspaces").resolve()
        if not workspaces_root.exists():
            return {}
        result = {}
        for p in sorted(workspaces_root.iterdir()):
            if not p.is_dir():
                continue
            result[p.name] = {
                "workspace_id": p.name,
                # Do not expose absolute host paths; report workspace-relative names only.
                "project_path": "project",
                "data_dir": "data",
                "has_conversations": (p / "data" / "conversations").exists(),
            }
        return result

    def delete_workspace(self, username: str, workspace_id: str) -> bool:
        """Delete one workspace directory (shared prompts/personalization are kept).

        Returns:
            True if the workspace directory existed and removal was attempted; False
            for a missing workspace or an unsafe/empty username or workspace id.
        """
        username = (username or "").strip().lower()
        ws_id = (workspace_id or "").strip()
        # Guard against "..", "a/b" etc., which would make rmtree leave the user's tree.
        if not (self._is_safe_path_component(username) and self._is_safe_path_component(ws_id)):
            return False
        work_root = (self.workspace_root / username / "workspaces" / ws_id).resolve()
        if not work_root.exists():
            return False
        shutil.rmtree(work_root, ignore_errors=True)
        return True

    # ----------------------- internal helpers -----------------------

    def _normalize_username(self, username: str) -> str:
        """Strip and lower-case *username*; raise ValueError if the result is empty."""
        candidate = (username or "").strip().lower()
        if not candidate:
            raise ValueError("用户名不能为空")
        return candidate

    @classmethod
    def _is_safe_path_component(cls, name: str) -> bool:
        """True if *name* is one non-empty path component that is not "." or ".."."""
        if name in ("", ".", ".."):
            return False
        return not any(ch in name for ch in cls._FORBIDDEN_PATH_CHARS)

    @classmethod
    def _require_path_component(cls, name: str, label: str) -> None:
        """Raise ValueError (mentioning *label*) unless *name* is a safe path component."""
        if not cls._is_safe_path_component(name):
            raise ValueError(f"非法的{label}: {name!r}")

    def _sha256(self, token: str) -> str:
        """Hex SHA-256 of *token* (None is treated as "")."""
        return hashlib.sha256((token or "").encode("utf-8")).hexdigest()

    @staticmethod
    def _utcnow_iso() -> str:
        """Current UTC time as a naive ISO-8601 string.

        Produces the same text as the deprecated ``datetime.utcnow().isoformat()``
        (no offset suffix), so timestamps already on disk keep one format.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def _write_json_atomic(path: Path, payload: Dict[str, Any]) -> None:
        """Write *payload* to *path* as indented UTF-8 JSON without truncating it in place.

        The data goes to a uniquely named sibling temp file that then replaces the
        target (``os.replace`` semantics: atomic on POSIX), so a crash mid-write leaves
        the previous file intact. The old file's permission bits are copied first so a
        restricted file (e.g. the tokens file) keeps its mode.
        """
        text = json.dumps(payload, ensure_ascii=False, indent=2)
        # Follow a symlinked target so the link itself is not replaced by a regular file.
        target = path.resolve()
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp = target.with_name(f".{target.name}.{secrets.token_hex(4)}.tmp")
        try:
            tmp.write_text(text, encoding="utf-8")
            if target.exists():
                shutil.copymode(target, tmp)
            tmp.replace(target)
        except BaseException:
            try:
                tmp.unlink()
            except OSError:
                pass
            raise

    @staticmethod
    def _read_json_section(path: Path, key: str) -> Dict[str, Any]:
        """Best-effort read of the object stored under *key* in the JSON file *path*.

        Returns {} when the file cannot be read or lacks the expected shape. A file
        that fails to parse is first copied to ``<name>.corrupt`` so that the next
        save, which overwrites *path*, does not silently destroy its contents.
        """
        try:
            raw = json.loads(path.read_text(encoding="utf-8"))
        except ValueError:  # JSONDecodeError / UnicodeDecodeError
            try:
                shutil.copy2(path, path.with_name(f"{path.name}.corrupt"))
            except OSError:
                pass
            return {}
        except OSError:
            return {}
        section = raw.get(key, {}) if isinstance(raw, dict) else {}
        return section if isinstance(section, dict) else {}

    def _load_users(self):
        """Load user records from the users file (only ``token_sha256`` is read).

        A missing file is created empty; entries without a token hash are skipped.

        Raises:
            RuntimeError: if the file is not valid JSON or ``users`` is not an object.
        """
        if not self.users_file.exists():
            self._save_users()
            return
        try:
            raw = json.loads(self.users_file.read_text(encoding="utf-8"))
        except json.JSONDecodeError as exc:
            raise RuntimeError(f"无法解析 API 用户文件: {self.users_file} ({exc})") from exc
        users = raw.get("users", {}) if isinstance(raw, dict) else {}
        if not isinstance(users, dict):
            raise RuntimeError(f"无法解析 API 用户文件: {self.users_file} (users 字段不是对象)")
        for username, payload in users.items():
            if not isinstance(payload, dict):
                continue
            token_sha = (payload.get("token_sha256") or "").strip()
            if not token_sha:
                continue
            record = ApiUserRecord(
                username=username.strip().lower(),
                token_sha256=token_sha,
                created_at=payload.get("created_at") or "",
                note=payload.get("note") or "",
            )
            self._users[record.username] = record

    def _save_users(self):
        """Persist all user records (hash, creation time, note) to the users file."""
        payload = {
            "users": {
                username: {
                    "token_sha256": record.token_sha256,
                    "created_at": record.created_at or self._utcnow_iso(),
                    "note": record.note,
                }
                for username, record in self._users.items()
            }
        }
        self._write_json_atomic(self.users_file, payload)

    # -------- Token storage (encrypted copy for display) --------

    def _fernet(self):
        """Return a Fernet cipher keyed by SHA-256(API_TOKEN_SECRET), or None.

        None is returned when the secret is empty or ``cryptography`` is not installed.
        """
        secret = (API_TOKEN_SECRET or "").strip()
        if not secret or not Fernet:
            return None
        # Fernet expects 32 url-safe-base64-encoded bytes; a SHA-256 digest is 32 bytes.
        key = hashlib.sha256(secret.encode("utf-8")).digest()
        return Fernet(base64.urlsafe_b64encode(key))

    def _load_tokens(self):
        """Load the token store; a missing file is created, an unreadable one yields {}."""
        if not self.tokens_file.exists():
            self._save_tokens()
            return
        self._tokens = self._read_json_section(self.tokens_file, "tokens")

    def _save_tokens(self):
        """Persist the token store to the tokens file."""
        self._write_json_atomic(self.tokens_file, {"tokens": self._tokens})

    def _store_token(self, username: str, token: str, note: str = ""):
        """Record *token* for later display and persist the token store.

        ``token_enc`` is the Fernet ciphertext, or "" when no key is available.

        NOTE(review): ``token_plain`` is written even when encryption succeeds, which
        makes ``token_enc`` redundant for anyone able to read the tokens file; only
        the file's permissions protect it. Kept for get_plain_token's fallback.
        """
        fernet = self._fernet()
        enc = fernet.encrypt(token.encode("utf-8")).decode("utf-8") if fernet else ""
        self._tokens[username] = {
            "token_enc": enc,
            "token_plain": token,  # fallback when the key is missing; see NOTE above
            "note": note,
            "created_at": self._utcnow_iso(),
        }
        self._save_tokens()

    def _issue_token_locked(self, username: str, note: str = "") -> Tuple[ApiUserRecord, str]:
        """Generate a new token for *username*, update/create its record and store it.

        Called with ``self._lock`` held (see create_user / issue_token). The users
        file is NOT saved here; callers do that. Returns ``(record, plaintext_token)``.
        """
        token = secrets.token_urlsafe(32)
        token_sha = self._sha256(token)
        record = self._users.get(username)
        if not record:
            record = ApiUserRecord(
                username=username,
                token_sha256=token_sha,
                created_at=self._utcnow_iso(),
                note=note or "",
            )
        record.token_sha256 = token_sha
        record.note = note or record.note  # an empty note keeps the existing one
        record.created_at = record.created_at or self._utcnow_iso()
        self._users[username] = record
        self._store_token(username, token, note=note)
        return record, token

    # -------- API request usage --------

    def _load_usage(self):
        """Load usage counters; a missing file is created, an unreadable one yields {}."""
        if not self.usage_file.exists():
            self._save_usage()
            return
        self._usage = self._read_json_section(self.usage_file, "usage")

    def _save_usage(self):
        """Persist usage counters to the usage file."""
        self._write_json_atomic(self.usage_file, {"usage": self._usage})
# Public API of this module.
__all__ = [
    "ApiUserManager",
    "ApiUserRecord",
    "ApiUserWorkspace",
]