agent-Specialization/modules/api_user_manager.py

240 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""API 专用用户与工作区管理(JSON + Bearer Token 哈希)。

仅支持手动维护:在 `API_USERS_DB_FILE` 中添加用户与 SHA256(token)。
结构示例:

    {
      "users": {
        "api_jojo": {
          "token_sha256": "abc123...",
          "created_at": "2026-01-23",
          "note": "for mobile app"
        }
      }
    }
"""
from __future__ import annotations
import json
import hashlib
import threading
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Tuple
from config import (
API_USER_SPACE_DIR,
API_USERS_DB_FILE,
API_TOKENS_FILE,
)
from modules.personalization_manager import ensure_personalization_config
@dataclass
class ApiUserRecord:
    """One entry of the hand-maintained API user registry.

    Only the SHA-256 hex digest of the user's bearer token is kept; the
    plaintext token is never stored.
    """

    username: str  # registry key, normalized by the loader (stripped, lower-cased)
    token_sha256: str  # hex digest of SHA-256(bearer token)
    created_at: str  # free-form creation timestamp copied from the registry file
    note: str = ""  # optional human-readable remark
@dataclass
class ApiUserWorkspace:
    """Directory set describing a single workspace of one API user.

    Everything under ``root`` belongs to this workspace alone, while
    ``shared_dir`` and its children are shared by all of the user's
    workspaces.
    """

    # --- identity ---
    username: str
    workspace_id: str
    # --- per-workspace tree ---
    root: Path
    project_path: Path
    data_dir: Path  # conversations/backups etc. are persisted here (one per workspace)
    logs_dir: Path
    uploads_dir: Path  # project/user_upload
    quarantine_dir: Path  # upload quarantine area, partitioned by user/workspace
    # --- user-level shared tree ---
    shared_dir: Path  # user-level shared directory (prompts/personalization)
    prompts_dir: Path  # prompts directory actually in use (shared_dir/prompts)
    personalization_dir: Path  # personalization directory actually in use (shared_dir/personalization)
class ApiUserManager:
    """Minimal API user management: verify token hashes and prepare isolated workspaces.

    Users are maintained by hand in ``users_file``: a JSON document whose
    top-level ``"users"`` object maps username -> {"token_sha256", "created_at",
    "note"}. Only SHA-256 hex digests of tokens are read; entries without a
    ``token_sha256`` are ignored.
    """

    def __init__(
        self,
        users_file: str = API_USERS_DB_FILE,
        tokens_file: str = API_TOKENS_FILE,
        workspace_root: str = API_USER_SPACE_DIR,
    ):
        """Load the user registry and make sure the workspace root exists.

        Args:
            users_file: Path of the JSON user registry; an empty registry is
                written if the file does not exist.
            tokens_file: Kept on the instance as ``self.tokens_file``; this class
                does not read it.
            workspace_root: Directory under which per-user trees are created.
                ``~`` is expanded and the path is resolved to an absolute path.

        Raises:
            RuntimeError: If ``users_file`` exists but is not valid JSON.
        """
        self.users_file = Path(users_file)
        self.tokens_file = Path(tokens_file)
        self.workspace_root = Path(workspace_root).expanduser().resolve()
        self.workspace_root.mkdir(parents=True, exist_ok=True)
        self._users: Dict[str, ApiUserRecord] = {}
        self._lock = threading.Lock()
        self._load_users()

    # ----------------------- public APIs -----------------------
    def get_user_by_token(self, bearer_token: str) -> Optional[ApiUserRecord]:
        """Return the user whose stored token digest matches ``bearer_token``.

        The token is hashed with SHA-256 and compared against each stored
        digest using ``hmac.compare_digest``.

        Returns:
            The matching ``ApiUserRecord``, or ``None`` for an empty or unknown token.
        """
        import hmac

        if not bearer_token:
            return None
        token_sha = self._sha256(bearer_token).encode("ascii")
        with self._lock:
            for user in self._users.values():
                # Compare as bytes: compare_digest rejects non-ASCII str input,
                # and a hand-edited registry might contain any characters.
                if hmac.compare_digest(user.token_sha256.encode("utf-8"), token_sha):
                    return user
        return None

    def ensure_workspace(self, username: str, workspace_id: str = "default") -> ApiUserWorkspace:
        """Create (if needed) and describe one workspace of an API user.

        Directory layout per user::

            <root>/<username>/
                shared/                  # user-level, shared by all workspaces
                    prompts/
                    personalization/
                workspaces/<ws>/         # a single workspace
                    project/
                        user_upload/
                    data/
                        conversations/
                        backups/
                        prompts          -> shared/prompts (symlink, best effort)
                        personalization  -> shared/personalization (symlink, best effort)
                    logs/

        An upload quarantine directory ``<quarantine_root>/<username>/<ws>`` is
        created as well. ``quarantine_root`` is ``config.UPLOAD_QUARANTINE_SUBDIR``;
        when that is a relative path it is taken relative to the parent of the
        workspace root.

        Args:
            username: Owner name; stripped and lower-cased.
            workspace_id: Workspace name; stripped. ``None``/blank means ``"default"``.

        Returns:
            An ``ApiUserWorkspace`` describing every directory above.

        Raises:
            ValueError: If ``username`` or ``workspace_id`` is not a single plain
                path component (e.g. contains a separator or is ``"."``/``".."``).
                Such values would otherwise escape the user's directory.
        """
        username = username.strip().lower()
        ws_id = (workspace_id or "default").strip() or "default"
        if not self._is_safe_segment(username):
            raise ValueError(f"invalid API username: {username!r}")
        if not self._is_safe_segment(ws_id):
            raise ValueError(f"invalid workspace id: {ws_id!r}")

        user_root = (self.workspace_root / username).resolve()
        shared_dir = user_root / "shared"
        prompts_dir = shared_dir / "prompts"
        personalization_dir = shared_dir / "personalization"
        work_root = user_root / "workspaces" / ws_id
        project_path = work_root / "project"
        data_dir = work_root / "data"
        logs_dir = work_root / "logs"
        uploads_dir = project_path / "user_upload"
        for path in (project_path, data_dir, logs_dir, uploads_dir, shared_dir, prompts_dir, personalization_dir):
            path.mkdir(parents=True, exist_ok=True)
        # Workspace-level data subdirectories.
        (data_dir / "conversations").mkdir(parents=True, exist_ok=True)
        (data_dir / "backups").mkdir(parents=True, exist_ok=True)
        # User-level (shared) personalization config; what exactly this writes is
        # defined in modules.personalization_manager.
        ensure_personalization_config(personalization_dir)

        # Convenience links under data_dir for backward compatibility.
        for name, target in (("prompts", prompts_dir), ("personalization", personalization_dir)):
            link = data_dir / name
            # is_symlink() also catches dangling links, which exists() reports as absent.
            if link.exists() or link.is_symlink():
                continue
            try:
                link.symlink_to(target, target_is_directory=True)
            except (OSError, NotImplementedError):
                # Some environments forbid symlinks; callers receive the shared
                # paths explicitly through the returned record instead.
                pass

        # Upload quarantine area, partitioned by user/workspace.
        from config import UPLOAD_QUARANTINE_SUBDIR
        quarantine_root = Path(UPLOAD_QUARANTINE_SUBDIR).expanduser()
        if not quarantine_root.is_absolute():
            quarantine_root = (self.workspace_root.parent / UPLOAD_QUARANTINE_SUBDIR).resolve()
        quarantine_dir = (quarantine_root / username / ws_id).resolve()
        quarantine_dir.mkdir(parents=True, exist_ok=True)

        return ApiUserWorkspace(
            username=username,
            workspace_id=ws_id,
            root=work_root,
            project_path=project_path,
            data_dir=data_dir,
            logs_dir=logs_dir,
            uploads_dir=uploads_dir,
            quarantine_dir=quarantine_dir,
            shared_dir=shared_dir,
            prompts_dir=prompts_dir,
            personalization_dir=personalization_dir,
        )

    def list_workspaces(self, username: str) -> Dict[str, Dict]:
        """List the user's workspaces, keyed by workspace id (sorted by path).

        Each value holds ``workspace_id``, ``project_path``, ``data_dir`` (as
        strings) and ``has_conversations`` (whether ``data/conversations`` exists).
        Returns ``{}`` for an invalid username or when the user has no
        ``workspaces`` directory.
        """
        username = username.strip().lower()
        if not self._is_safe_segment(username):
            return {}
        workspaces_dir = (self.workspace_root / username / "workspaces").resolve()
        if not workspaces_dir.is_dir():
            return {}
        result: Dict[str, Dict] = {}
        for entry in sorted(workspaces_dir.iterdir()):
            if not entry.is_dir():
                continue
            data_dir = entry / "data"
            result[entry.name] = {
                "workspace_id": entry.name,
                "project_path": str(entry / "project"),
                "data_dir": str(data_dir),
                "has_conversations": (data_dir / "conversations").exists(),
            }
        return result

    def delete_workspace(self, username: str, workspace_id: str) -> bool:
        """Delete one workspace directory; the shared prompts/personalization stay.

        A workspace entry that is a symlink is unlinked; its target is left
        untouched.

        Returns:
            ``True`` if the workspace entry is gone afterwards. ``False`` when:
            - the ids are invalid (blank, a path separator, ``"."``/``".."``);
            - there is no such workspace directory;
            - removal failed.
        """
        import shutil

        username = username.strip().lower()
        ws_id = (workspace_id or "").strip()
        if not self._is_safe_segment(username) or not self._is_safe_segment(ws_id):
            return False
        # Resolve only the fixed parent: resolving the workspace entry itself
        # would follow a symlink and delete data outside the workspace tree.
        work_root = (self.workspace_root / username / "workspaces").resolve() / ws_id
        if work_root.is_symlink():
            try:
                work_root.unlink()
            except OSError:
                return False
            return True
        if not work_root.is_dir():
            return False
        shutil.rmtree(work_root, ignore_errors=True)
        # rmtree swallowed any errors above; report whether it really worked.
        return not work_root.exists()

    # ----------------------- internal helpers -----------------------
    @staticmethod
    def _is_safe_segment(name: str) -> bool:
        """Return True if ``name`` is one plain path component.

        A plain component is non-empty, is not ``"."``/``".."``, and contains no
        separator or drive prefix.
        """
        return bool(name) and name not in (".", "..") and Path(name).name == name

    def _sha256(self, token: str) -> str:
        """Return the lower-case hex SHA-256 digest of ``token`` (UTF-8); ``None`` hashes as ``""``."""
        return hashlib.sha256((token or "").encode("utf-8")).hexdigest()

    def _load_users(self):
        """Populate ``self._users`` from ``users_file``.

        Only ``token_sha256`` digests are read; plaintext tokens are not
        supported. If the file is missing, an empty registry is written.
        Entries with a non-object payload or without ``token_sha256`` are
        skipped. Usernames and digests are stripped and lower-cased so a
        hand-pasted upper-case digest still matches ``_sha256`` output.

        Raises:
            RuntimeError: If the file is not valid JSON.
        """
        if not self.users_file.exists():
            self._save_users()
            return
        try:
            raw = json.loads(self.users_file.read_text(encoding="utf-8"))
        except json.JSONDecodeError as exc:
            raise RuntimeError(f"无法解析 API 用户文件: {self.users_file} ({exc})") from exc
        users = raw.get("users", {}) if isinstance(raw, dict) else {}
        if not isinstance(users, dict):
            users = {}
        for username, payload in users.items():
            if not isinstance(payload, dict):
                continue
            token_sha = str(payload.get("token_sha256") or "").strip().lower()
            if not token_sha:
                continue
            record = ApiUserRecord(
                username=username.strip().lower(),
                token_sha256=token_sha,
                created_at=payload.get("created_at") or "",
                note=payload.get("note") or "",
            )
            self._users[record.username] = record

    def _save_users(self):
        """Write ``self._users`` to ``users_file`` as pretty-printed UTF-8 JSON.

        Records with an empty ``created_at`` are written with the current UTC
        time in ISO-8601 format. Parent directories are created as needed.
        """
        from datetime import timezone

        payload = {
            "users": {
                username: {
                    "token_sha256": record.token_sha256,
                    # datetime.utcnow() is deprecated since Python 3.12.
                    "created_at": record.created_at or datetime.now(timezone.utc).isoformat(),
                    "note": record.note,
                }
                for username, record in self._users.items()
            }
        }
        self.users_file.parent.mkdir(parents=True, exist_ok=True)
        self.users_file.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
# Public API of this module (what ``from ... import *`` exposes).
__all__ = [
    "ApiUserManager",
    "ApiUserRecord",
    "ApiUserWorkspace",
]