agent-Specialization/modules/user_manager.py

274 lines
9.6 KiB
Python

"""User and workspace management utilities for multi-user support."""
import json
import os
import re
import tempfile
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Optional, Tuple

from werkzeug.security import check_password_hash, generate_password_hash

from config import (
    ADMIN_PASSWORD_HASH,
    ADMIN_USERNAME,
    INVITE_CODES_FILE,
    USER_SPACE_DIR,
    USERS_DB_FILE,
)
from modules.personalization_manager import ensure_personalization_config
@dataclass
class UserRecord:
    """Persisted account data for one user.

    ``password_hash`` is excluded from ``repr`` so that logging or printing a
    record (e.g. in a traceback) does not leak the credential hash. It still
    takes part in equality comparison and in ``dataclasses.asdict``.
    """

    username: str
    email: str
    password_hash: str = field(repr=False)
    created_at: str  # ISO-8601 timestamp string
    invite_code: Optional[str] = None  # code used at registration, if any
    role: str = "user"  # "user" or "admin" in this module
@dataclass
class UserWorkspace:
    """Directory layout belonging to a single user.

    In this module the instances are produced by
    ``UserManager.ensure_user_workspace``, which places ``project_path``,
    ``data_dir`` and ``logs_dir`` directly under ``root`` and ``uploads_dir``
    inside ``project_path``.
    """

    username: str      # normalized (lower-case) username
    root: Path         # per-user top-level directory
    project_path: Path  # <root>/project
    data_dir: Path     # <root>/data
    logs_dir: Path     # <root>/logs
    uploads_dir: Path  # <root>/project/user_upload
class UserManager:
    """Handle user registration, authentication and workspace provisioning.

    User records are kept in a JSON file (``{"users": {username: {...}}}``)
    and invite codes in another (``{"codes": [...]}``). Both are loaded once
    at construction and rewritten whenever the in-memory state changes. Each
    user also gets a directory tree under ``workspace_root``.

    NOTE(review): no locking is done here; concurrent registrations from
    several threads/processes could race on the JSON files and on invite
    consumption. Confirm how the hosting app serializes access.
    """

    USERNAME_REGEX = re.compile(r"^[a-z0-9_\-]{3,32}$")

    def __init__(
        self,
        users_file: str = USERS_DB_FILE,
        invite_codes_file: str = INVITE_CODES_FILE,
        workspace_root: str = USER_SPACE_DIR,
    ):
        """Load persisted users and invite codes and provision the admin user.

        Args:
            users_file: path of the JSON user database (created if missing).
            invite_codes_file: path of the JSON invite-code file (created if
                missing).
            workspace_root: directory holding per-user workspaces (created if
                missing).

        Raises:
            RuntimeError: if an existing users/invite file cannot be parsed.
        """
        self.users_file = Path(users_file)
        self.invite_codes_file = Path(invite_codes_file)
        self.workspace_root = Path(workspace_root).expanduser().resolve()
        self.workspace_root.mkdir(parents=True, exist_ok=True)
        self._users: Dict[str, UserRecord] = {}
        self._invites: Dict[str, Dict] = {}
        # Lower-cased email -> username; login is by email.
        self._email_map: Dict[str, str] = {}
        self._load_users()
        self._load_invite_codes()
        self._ensure_admin_user()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def register_user(
        self, username: str, email: str, password: str, invite_code: str
    ) -> Tuple[UserRecord, UserWorkspace]:
        """Create a user, consume one use of the invite code, and provision
        the user's workspace.

        Args:
            username: desired username; lower-cased and validated against
                ``USERNAME_REGEX``.
            email: email address; lower-cased and loosely validated.
            password: plain-text password. Surrounding whitespace is stripped
                before hashing. It must then be at least 8 characters.
            invite_code: an existing invite code that still has uses left.

        Returns:
            The persisted ``UserRecord`` and the user's ``UserWorkspace``.

        Raises:
            ValueError: on any validation failure (message is user-facing).
        """
        username = self._normalize_username(username)
        email = self._normalize_email(email)
        # Guard against None so callers get the user-facing ValueError below
        # rather than an AttributeError.
        password = (password or "").strip()
        invite_code = (invite_code or "").strip()
        if not password or len(password) < 8:
            raise ValueError("密码长度至少 8 位。")
        if username in self._users:
            raise ValueError("该用户名已被注册。")
        if email in self._email_map:
            raise ValueError("该邮箱已被注册。")
        invite_entry = self._validate_invite_code(invite_code)

        record = UserRecord(
            username=username,
            email=email,
            password_hash=generate_password_hash(password),
            created_at=self._utc_now_iso(),
            invite_code=invite_entry["code"],
        )
        self._users[username] = record
        self._index_user(record)
        try:
            self._save_users()
        except Exception:
            # Keep memory consistent with disk if persisting the user fails.
            self._users.pop(username, None)
            self._email_map.pop(email, None)
            raise
        self._consume_invite(invite_entry)
        workspace = self.ensure_user_workspace(username)
        return record, workspace

    def authenticate(self, email: str, password: str) -> Optional[UserRecord]:
        """Return the user for *email* if *password* matches, else ``None``.

        The email is matched case-insensitively after trimming whitespace.
        ``register_user`` hashes the *stripped* password. A password entered
        with surrounding whitespace is therefore also tried in stripped form,
        after the exact input has been tried.
        """
        email = (email or "").strip().lower()
        username = self._email_map.get(email)
        if not username:
            return None
        record = self._users.get(username)
        if not record or not record.password_hash:
            return None
        password = password or ""
        if check_password_hash(record.password_hash, password):
            return record
        stripped = password.strip()
        if stripped != password and check_password_hash(record.password_hash, stripped):
            return record
        return None

    def get_user(self, username: str) -> Optional[UserRecord]:
        """Look up a user by username (trimmed, case-insensitive); ``None`` if absent."""
        return self._users.get((username or "").strip().lower())

    def ensure_user_workspace(self, username: str) -> UserWorkspace:
        """Create (if needed) and return the directory layout for *username*.

        Creates ``<root>/project``, ``<root>/data``, ``<root>/logs`` and
        ``<root>/project/user_upload``, plus ``data/conversations`` and
        ``data/backups``. It then calls ``ensure_personalization_config`` on
        the data directory. That helper presumably writes a default config;
        see its module.

        Raises:
            ValueError: if *username* fails ``USERNAME_REGEX``. The check also
                prevents path traversal out of ``workspace_root``.
        """
        username = self._normalize_username(username)
        root = (self.workspace_root / username).resolve()
        project_path = root / "project"
        data_dir = root / "data"
        logs_dir = root / "logs"
        uploads_dir = project_path / "user_upload"
        for path in (
            project_path,
            data_dir,
            logs_dir,
            uploads_dir,
            data_dir / "conversations",
            data_dir / "backups",
        ):
            path.mkdir(parents=True, exist_ok=True)
        ensure_personalization_config(data_dir)
        return UserWorkspace(
            username=username,
            root=root,
            project_path=project_path,
            data_dir=data_dir,
            logs_dir=logs_dir,
            uploads_dir=uploads_dir,
        )

    def list_invite_codes(self):
        """Return all invite-code entries (the live dicts, not copies)."""
        return list(self._invites.values())

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Same format as the deprecated ``datetime.utcnow().isoformat()`` (no
        ``+00:00`` suffix), so stored ``created_at`` values stay uniform.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def _write_json_atomic(path: Path, payload: Dict) -> None:
        """Serialize *payload* to *path* through a temp file and ``os.replace``.

        Writing in place would leave a truncated file if the process died
        mid-write. The loaders then raise on the invalid JSON and the manager
        cannot start.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        fd, tmp_name = tempfile.mkstemp(
            prefix=f".{path.name}.", suffix=".tmp", dir=str(path.parent)
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(payload, f, ensure_ascii=False, indent=2)
                f.flush()
                os.fsync(f.fileno())
            if path.exists():
                # mkstemp creates the file 0600; keep the previous file's mode.
                os.chmod(tmp_name, path.stat().st_mode & 0o777)
            os.replace(tmp_name, path)
        except BaseException:
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
            raise

    def _normalize_username(self, username: str) -> str:
        """Trim and lower-case *username*.

        Raises:
            ValueError: if the result does not match ``USERNAME_REGEX``.
        """
        candidate = (username or "").strip().lower()
        if not candidate or not self.USERNAME_REGEX.match(candidate):
            raise ValueError("用户名需为 3-32 位小写字母、数字、下划线或连字符。")
        return candidate

    def _normalize_email(self, email: str) -> str:
        """Trim and lower-case *email*.

        Raises:
            ValueError: if the result lacks "@" or is shorter than 6 chars.
        """
        email = (email or "").strip().lower()
        if "@" not in email or len(email) < 6:
            raise ValueError("邮箱格式不正确。")
        return email

    def _index_user(self, record: UserRecord):
        """Map the record's (normalized) email to its username, if it has one."""
        email = (record.email or "").strip().lower()
        if email:
            self._email_map[email] = record.username

    def _load_users(self):
        """Populate ``_users`` and ``_email_map`` from ``users_file``.

        Accepts the current ``{"users": {username: {...}}}`` layout and a
        list of ``{"username": ..., ...}`` objects. Creates an empty database
        file if none exists.

        Raises:
            RuntimeError: if the file is not valid JSON or has an unexpected
                structure.
        """
        if not self.users_file.exists():
            self._save_users()
            return
        error_msg = f"无法解析用户数据文件: {self.users_file}"
        try:
            with open(self.users_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except json.JSONDecodeError as exc:
            raise RuntimeError(error_msg) from exc

        if isinstance(data, dict):
            raw_users = data.get("users", {})
        elif isinstance(data, list):
            raw_users = {
                item.get("username"): item
                for item in data
                if isinstance(item, dict) and item.get("username")
            }
        else:
            raw_users = {}
        if not isinstance(raw_users, dict):
            raise RuntimeError(error_msg)

        for username, payload in raw_users.items():
            if not isinstance(payload, dict):
                # Fail loudly: silently dropping the entry would erase it from
                # disk on the next save.
                raise RuntimeError(error_msg)
            record = UserRecord(
                username=username,
                email=payload.get("email", ""),
                password_hash=payload.get("password_hash", ""),
                created_at=payload.get("created_at", ""),
                invite_code=payload.get("invite_code"),
                role=payload.get("role", "user"),
            )
            self._users[username] = record
            self._index_user(record)

    def _save_users(self):
        """Persist all user records to ``users_file`` atomically."""
        payload = {
            "users": {
                username: {
                    "email": record.email,
                    "password_hash": record.password_hash,
                    "created_at": record.created_at,
                    "invite_code": record.invite_code,
                    "role": record.role,
                }
                for username, record in self._users.items()
            }
        }
        self._write_json_atomic(self.users_file, payload)

    def _load_invite_codes(self):
        """Populate ``_invites`` (code -> entry dict) from ``invite_codes_file``.

        Accepts ``{"codes": [...]}`` or a bare list. Entries that are not
        dicts with a ``"code"`` key are ignored. Creates an empty file if
        none exists.

        Raises:
            RuntimeError: if the file is not valid JSON.
        """
        if not self.invite_codes_file.exists():
            self._save_invite_codes({})
            return
        try:
            with open(self.invite_codes_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except json.JSONDecodeError as exc:
            raise RuntimeError(f"无法解析邀请码文件: {self.invite_codes_file}") from exc
        if isinstance(data, dict):
            codes = data.get("codes", [])
        elif isinstance(data, list):
            codes = data
        else:
            codes = []
        self._invites = {
            item["code"]: item for item in codes if isinstance(item, dict) and "code" in item
        }

    def _save_invite_codes(self, overrides: Optional[Dict[str, Dict]] = None):
        """Persist invite codes atomically.

        Args:
            overrides: mapping to write instead of ``self._invites``. An empty
                dict is honored (``None`` means "use current state").
        """
        codes = self._invites if overrides is None else overrides
        payload = {"codes": list(codes.values())}
        self._write_json_atomic(self.invite_codes_file, payload)

    def _validate_invite_code(self, code: str) -> Dict:
        """Return the invite entry for *code* if it still has uses left.

        An entry whose ``"remaining"`` is absent/``None`` is treated as
        unlimited.

        Raises:
            ValueError: if the code is empty, unknown, or exhausted.
        """
        if not code:
            raise ValueError("邀请码不能为空。")
        entry = self._invites.get(code)
        if not entry:
            raise ValueError("邀请码不存在或已失效。")
        remaining = entry.get("remaining")
        if remaining is not None and remaining <= 0:
            raise ValueError("邀请码已被使用。")
        return entry

    def _consume_invite(self, entry: Dict):
        """Decrement a limited invite's ``"remaining"`` count (not below 0) and save."""
        if entry.get("remaining") is None:
            return
        entry["remaining"] = max(0, entry["remaining"] - 1)
        self._save_invite_codes()

    def _ensure_admin_user(self):
        """Create the configured admin account on first run.

        Does nothing if ``ADMIN_USERNAME`` or ``ADMIN_PASSWORD_HASH`` is
        unset, or if the admin already exists. The admin gets the synthetic
        email ``<name>@local`` and the preconfigured password hash.
        """
        admin_name = (ADMIN_USERNAME or "").strip().lower()
        if not admin_name or not ADMIN_PASSWORD_HASH:
            return
        if admin_name in self._users:
            return
        record = UserRecord(
            username=admin_name,
            email=f"{admin_name}@local",
            password_hash=ADMIN_PASSWORD_HASH,
            created_at=self._utc_now_iso(),
            invite_code=None,
            role="admin",
        )
        self._users[admin_name] = record
        self._index_user(record)
        self._save_users()
        self.ensure_user_workspace(admin_name)