240 lines
8.6 KiB
Python
240 lines
8.6 KiB
Python
"""API 专用用户与工作区管理(JSON + Bearer Token 哈希)。
|
||
|
||
仅支持手动维护:在 `API_USERS_DB_FILE` 中添加用户与 SHA256(token)。
|
||
结构示例:
|
||
{
|
||
"users": {
|
||
"api_jojo": {
|
||
"token_sha256": "abc123...",
|
||
"created_at": "2026-01-23",
|
||
"note": "for mobile app"
|
||
}
|
||
}
|
||
}
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
import json
|
||
import hashlib
|
||
import threading
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Dict, Optional, Tuple
|
||
|
||
from config import (
|
||
API_USER_SPACE_DIR,
|
||
API_USERS_DB_FILE,
|
||
API_TOKENS_FILE,
|
||
)
|
||
from modules.personalization_manager import ensure_personalization_config
|
||
|
||
|
||
@dataclass()
class ApiUserRecord:
    """One API user entry as read from the users JSON file.

    Only the SHA256 hex digest of the bearer token is held here; there is no
    field for the plaintext token.
    """

    username: str       # user name (the manager strips and lower-cases it on load)
    token_sha256: str   # hex SHA256 digest of the user's bearer token
    created_at: str     # timestamp string copied from the file ("" if absent)
    note: str = ""      # optional free-text annotation
|
||
|
||
|
||
@dataclass()
class ApiUserWorkspace:
    """Paths describing a single workspace of an API user.

    Per-workspace directories live under ``<root>/<username>/workspaces/<ws>/``;
    ``shared_dir`` (and the prompts/personalization directories inside it) is
    shared by all workspaces of the same user.
    """

    username: str
    workspace_id: str
    root: Path                 # the workspace directory itself
    project_path: Path         # project directory of this workspace
    data_dir: Path             # conversations/backups are persisted here (separate per workspace)
    logs_dir: Path
    uploads_dir: Path          # project/user_upload
    quarantine_dir: Path       # upload quarantine area, partitioned by user/workspace
    shared_dir: Path           # user-level shared directory (prompts/personalization)
    prompts_dir: Path          # prompts directory actually used (shared_dir/prompts)
    personalization_dir: Path  # personalization directory actually used (shared_dir/personalization)
|
||
|
||
|
||
class ApiUserManager:
    """Minimal API user management: verify token hashes and prepare isolated workspaces.

    Users are read once from ``users_file`` when the manager is constructed.
    Each user owns a directory ``<workspace_root>/<username>/`` containing a
    ``shared/`` area and any number of ``workspaces/<ws>/`` directories.
    User names and workspace ids are used verbatim as single path components,
    so values that could escape that tree are rejected.
    """

    # Characters that must never appear in a user name or workspace id,
    # because either would let the value act as more than one path component.
    _FORBIDDEN_COMPONENT_CHARS = ("/", "\\", "\x00")

    def __init__(
        self,
        users_file: str = API_USERS_DB_FILE,
        tokens_file: str = API_TOKENS_FILE,
        workspace_root: str = API_USER_SPACE_DIR,
    ):
        """Load the user list and make sure the workspace root exists.

        Args:
            users_file: JSON file of the form ``{"users": {name: {...}}}``;
                created with an empty user list if it does not exist.
            tokens_file: stored as ``self.tokens_file``; not read by this class.
            workspace_root: directory under which per-user trees are created
                (``~`` is expanded and the path resolved).

        Raises:
            RuntimeError: if ``users_file`` exists but is not valid JSON.
        """
        self.users_file = Path(users_file)
        # NOTE(review): kept for compatibility; nothing in this class reads it.
        self.tokens_file = Path(tokens_file)
        self.workspace_root = Path(workspace_root).expanduser().resolve()
        self.workspace_root.mkdir(parents=True, exist_ok=True)

        self._users: Dict[str, ApiUserRecord] = {}
        self._lock = threading.Lock()

        self._load_users()

    # ----------------------- public APIs -----------------------
    def get_user_by_token(self, bearer_token: str) -> Optional[ApiUserRecord]:
        """Return the user whose stored hash matches ``bearer_token``, or None.

        The token is hashed with SHA256 and compared against every stored
        ``token_sha256`` with a constant-time comparison. An empty/None token
        always yields None.
        """
        if not bearer_token:
            return None
        import hmac  # local import, matching this module's lazy-import style

        token_sha = self._sha256(bearer_token).encode("ascii")
        with self._lock:
            for user in self._users.values():
                # Compare bytes: compare_digest raises TypeError on non-ASCII str,
                # and a hand-edited users file might contain such characters.
                if hmac.compare_digest(user.token_sha256.encode("utf-8"), token_sha):
                    return user
        return None

    def ensure_workspace(self, username: str, workspace_id: str = "default") -> ApiUserWorkspace:
        """Create (if needed) and return the given workspace of an API user.

        Directory layout per user::

            <root>/<username>/
                shared/                  # user-level shared (prompts/personalization)
                    prompts/
                    personalization/
                workspaces/<ws>/         # a single workspace
                    project/
                        user_upload/
                    data/
                        conversations/
                        backups/
                    logs/

        Args:
            username: API user name; stripped and lower-cased before use.
            workspace_id: workspace name; stripped, and empty/None falls back
                to ``"default"``.

        Returns:
            An ApiUserWorkspace whose directories (plus the upload quarantine
            directory) all exist on return.

        Raises:
            ValueError: if ``username`` or ``workspace_id`` is not a single safe
                path component (empty, ``.``/``..``, or containing a separator).
        """
        username = self._validate_component(username.strip().lower(), "username")
        ws_id = (workspace_id or "default").strip() or "default"
        ws_id = self._validate_component(ws_id, "workspace_id")

        user_root = (self.workspace_root / username).resolve()
        shared_dir = user_root / "shared"
        prompts_dir = shared_dir / "prompts"
        personalization_dir = shared_dir / "personalization"
        work_root = user_root / "workspaces" / ws_id

        project_path = work_root / "project"
        data_dir = work_root / "data"
        logs_dir = work_root / "logs"
        uploads_dir = project_path / "user_upload"

        for path in (
            project_path,
            data_dir,
            logs_dir,
            uploads_dir,
            shared_dir,
            prompts_dir,
            personalization_dir,
            # workspace-level data subdirectories
            data_dir / "conversations",
            data_dir / "backups",
        ):
            path.mkdir(parents=True, exist_ok=True)

        # User-level (shared) personalization config file.
        ensure_personalization_config(personalization_dir)

        # Backward compatibility: expose the shared directories under data_dir too.
        for name, target in (("prompts", prompts_dir), ("personalization", personalization_dir)):
            self._link_shared_dir(data_dir / name, target)

        quarantine_dir = self._quarantine_dir_for(username, ws_id)
        quarantine_dir.mkdir(parents=True, exist_ok=True)

        return ApiUserWorkspace(
            username=username,
            workspace_id=ws_id,
            root=work_root,
            project_path=project_path,
            data_dir=data_dir,
            logs_dir=logs_dir,
            uploads_dir=uploads_dir,
            quarantine_dir=quarantine_dir,
            shared_dir=shared_dir,
            prompts_dir=prompts_dir,
            personalization_dir=personalization_dir,
        )

    def list_workspaces(self, username: str) -> Dict[str, Dict]:
        """Describe every workspace directory of ``username``, keyed by workspace id.

        Returns an empty dict when the user has no ``workspaces`` directory or
        the user name is not a safe path component.
        """
        try:
            username = self._validate_component(username.strip().lower(), "username")
        except ValueError:
            return {}

        workspaces_root = (self.workspace_root / username / "workspaces").resolve()
        if not workspaces_root.is_dir():
            return {}

        result = {}
        for p in sorted(workspaces_root.iterdir()):
            if not p.is_dir():
                continue
            data_dir = p / "data"
            result[p.name] = {
                "workspace_id": p.name,
                "project_path": str(p / "project"),
                "data_dir": str(data_dir),
                "has_conversations": (data_dir / "conversations").exists(),
            }
        return result

    def delete_workspace(self, username: str, workspace_id: str) -> bool:
        """Delete one workspace directory; shared prompts/personalization are kept.

        Returns:
            True if the workspace directory existed and is gone afterwards.
            False if the name/id is empty or unsafe, the workspace does not
            exist, it resolves outside the user's ``workspaces`` directory,
            or removal did not complete.
        """
        import shutil

        try:
            username = self._validate_component(username.strip().lower(), "username")
            ws_id = self._validate_component((workspace_id or "").strip(), "workspace_id")
        except ValueError:
            return False

        workspaces_root = (self.workspace_root / username / "workspaces").resolve()
        work_root = (workspaces_root / ws_id).resolve()
        # Defense in depth: never rmtree anything that is not a direct child of
        # the user's workspaces directory (e.g. a workspace replaced by a symlink).
        if work_root.parent != workspaces_root or not work_root.exists():
            return False

        shutil.rmtree(work_root, ignore_errors=True)
        # rmtree swallows errors, so only report success if the tree is really gone.
        return not work_root.exists()

    # ----------------------- internal helpers -----------------------
    @classmethod
    def _validate_component(cls, value: str, what: str) -> str:
        """Return ``value`` unchanged if it is safe to use as one path component.

        Rejects empty strings, ``.``/``..``, anything containing a path
        separator or NUL, and anything with a drive/root anchor. Otherwise a
        value like ``".."`` or ``"/some/abs/path"`` joined onto the workspace
        root would point outside it.

        Raises:
            ValueError: if ``value`` is not a safe single component.
        """
        if (
            not value
            or value in (".", "..")
            or any(ch in value for ch in cls._FORBIDDEN_COMPONENT_CHARS)
            or Path(value).anchor
        ):
            raise ValueError(f"invalid {what}: {value!r}")
        return value

    @staticmethod
    def _link_shared_dir(link: Path, target: Path) -> None:
        """Best-effort: create ``link`` as a directory symlink to ``target``.

        Some environments forbid symlinks; the shared directories are returned
        explicitly in ApiUserWorkspace anyway, so a failure here is ignored.
        """
        # is_symlink() also catches a dangling link, which exists() reports as missing.
        if link.is_symlink() or link.exists():
            return
        try:
            link.symlink_to(target, target_is_directory=True)
        except (OSError, NotImplementedError):
            pass

    def _quarantine_dir_for(self, username: str, ws_id: str) -> Path:
        """Return the (resolved) upload quarantine directory for ``username``/``ws_id``.

        ``UPLOAD_QUARANTINE_SUBDIR`` is used as-is when absolute (after ``~``
        expansion); otherwise it is taken relative to the parent of the
        workspace root.
        """
        from config import UPLOAD_QUARANTINE_SUBDIR

        quarantine_root = Path(UPLOAD_QUARANTINE_SUBDIR).expanduser()
        if not quarantine_root.is_absolute():
            quarantine_root = (self.workspace_root.parent / quarantine_root).resolve()
        return (quarantine_root / username / ws_id).resolve()

    def _sha256(self, token: str) -> str:
        """Return the lowercase hex SHA256 digest of ``token`` (None/"" hash as "")."""
        return hashlib.sha256((token or "").encode("utf-8")).hexdigest()

    def _load_users(self) -> None:
        """Populate ``self._users`` from ``users_file``; only token hashes are accepted.

        A missing file is created with an empty user list. Entries that are not
        objects, lack a non-empty string ``token_sha256``, or have an empty name
        are skipped. Names are stripped and lower-cased; hashes are stripped and
        lower-cased so they match ``hexdigest()`` output however they were
        pasted into the file.

        Raises:
            RuntimeError: if the file exists but is not valid JSON.
        """
        if not self.users_file.exists():
            self._save_users()
            return
        try:
            raw = json.loads(self.users_file.read_text(encoding="utf-8"))
        except json.JSONDecodeError as exc:
            raise RuntimeError(f"无法解析 API 用户文件: {self.users_file} ({exc})") from exc

        users = raw.get("users") if isinstance(raw, dict) else None
        if not isinstance(users, dict):
            return

        for raw_name, payload in users.items():
            if not isinstance(payload, dict):
                continue
            token_sha = payload.get("token_sha256")
            if not isinstance(token_sha, str) or not token_sha.strip():
                continue
            username = raw_name.strip().lower()
            if not username:
                continue
            self._users[username] = ApiUserRecord(
                username=username,
                token_sha256=token_sha.strip().lower(),
                created_at=str(payload.get("created_at") or ""),
                note=str(payload.get("note") or ""),
            )

    def _save_users(self) -> None:
        """Write ``self._users`` to ``users_file``, replacing it atomically.

        The JSON is first written to a sibling ``.tmp`` file and then moved
        over the target, so an interrupted write cannot leave a truncated users
        file (which ``_load_users`` would reject). Records without
        ``created_at`` get the current UTC time in ISO format.
        """
        from datetime import timezone

        now = datetime.now(timezone.utc).isoformat()
        payload = {
            "users": {
                username: {
                    "token_sha256": record.token_sha256,
                    "created_at": record.created_at or now,
                    "note": record.note,
                }
                for username, record in self._users.items()
            }
        }
        self.users_file.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.users_file.with_name(self.users_file.name + ".tmp")
        tmp_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
        # Path.replace has os.replace semantics: atomic on the same filesystem.
        tmp_path.replace(self.users_file)
|
||
|
||
|
||
__all__ = ["ApiUserManager", "ApiUserRecord", "ApiUserWorkspace"]
|