"""User and workspace management utilities for multi-user support.""" import json import re from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Dict, Optional, Tuple from werkzeug.security import check_password_hash, generate_password_hash from config import ( ADMIN_PASSWORD_HASH, ADMIN_USERNAME, INVITE_CODES_FILE, USER_SPACE_DIR, USERS_DB_FILE, ) @dataclass class UserRecord: username: str email: str password_hash: str created_at: str invite_code: Optional[str] = None role: str = "user" @dataclass class UserWorkspace: username: str root: Path project_path: Path data_dir: Path logs_dir: Path uploads_dir: Path class UserManager: """Handle user registration, authentication and workspace provisioning.""" USERNAME_REGEX = re.compile(r"^[a-z0-9_\-]{3,32}$") def __init__( self, users_file: str = USERS_DB_FILE, invite_codes_file: str = INVITE_CODES_FILE, workspace_root: str = USER_SPACE_DIR, ): self.users_file = Path(users_file) self.invite_codes_file = Path(invite_codes_file) self.workspace_root = Path(workspace_root).expanduser().resolve() self.workspace_root.mkdir(parents=True, exist_ok=True) self._users: Dict[str, UserRecord] = {} self._invites: Dict[str, Dict] = {} self._email_map: Dict[str, str] = {} self._load_users() self._load_invite_codes() self._ensure_admin_user() # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def register_user( self, username: str, email: str, password: str, invite_code: str ) -> Tuple[UserRecord, UserWorkspace]: username = self._normalize_username(username) email = self._normalize_email(email) password = password.strip() invite_code = (invite_code or "").strip() if not password or len(password) < 8: raise ValueError("密码长度至少 8 位。") if username in self._users: raise ValueError("该用户名已被注册。") if email in self._email_map: raise ValueError("该邮箱已被注册。") invite_entry = self._validate_invite_code(invite_code) password_hash = generate_password_hash(password) created_at = datetime.utcnow().isoformat() record = UserRecord( username=username, email=email, password_hash=password_hash, created_at=created_at, invite_code=invite_entry["code"], ) self._users[username] = record self._index_user(record) self._save_users() self._consume_invite(invite_entry) workspace = self.ensure_user_workspace(username) return record, workspace def authenticate(self, email: str, password: str) -> Optional[UserRecord]: email = (email or "").strip().lower() username = self._email_map.get(email) if not username: return None record = self._users.get(username) if not record or not record.password_hash: return None if not check_password_hash(record.password_hash, password or ""): return None return record def get_user(self, username: str) -> Optional[UserRecord]: return self._users.get((username or "").strip().lower()) def ensure_user_workspace(self, username: str) -> UserWorkspace: username = self._normalize_username(username) root = (self.workspace_root / username).resolve() project_path = root / "project" data_dir = root / "data" logs_dir = root / "logs" uploads_dir = project_path / "user_upload" for path in [project_path, data_dir, logs_dir, uploads_dir]: path.mkdir(parents=True, exist_ok=True) # 初始化数据子目录 (data_dir / "conversations").mkdir(parents=True, exist_ok=True) (data_dir / "backups").mkdir(parents=True, exist_ok=True) return UserWorkspace( username=username, root=root, project_path=project_path, data_dir=data_dir, logs_dir=logs_dir, uploads_dir=uploads_dir, ) def list_invite_codes(self): return list(self._invites.values()) # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _normalize_username(self, username: str) -> str: candidate = (username or "").strip().lower() if not candidate or not self.USERNAME_REGEX.match(candidate): raise ValueError("用户名需为 3-32 位小写字母、数字、下划线或连字符。") return candidate def _normalize_email(self, email: str) -> str: email = (email or "").strip().lower() if "@" not in email or len(email) < 6: raise ValueError("邮箱格式不正确。") return email def _index_user(self, record: UserRecord): email = (record.email or '').strip().lower() if email: self._email_map[email] = record.username def _load_users(self): if not self.users_file.exists(): self._save_users() return try: with open(self.users_file, "r", encoding="utf-8") as f: data = json.load(f) if isinstance(data, dict): raw_users = data.get("users", {}) elif isinstance(data, list): raw_users = {item.get("username"): item for item in data if isinstance(item, dict) and item.get("username")} else: raw_users = {} for username, payload in raw_users.items(): record = UserRecord( username=username, email=payload.get("email", ""), password_hash=payload.get("password_hash", ""), created_at=payload.get("created_at", ""), invite_code=payload.get("invite_code"), role=payload.get("role", "user"), ) self._users[username] = record self._index_user(record) except json.JSONDecodeError: raise RuntimeError(f"无法解析用户数据文件: {self.users_file}") def _save_users(self): payload = { "users": { username: { "email": record.email, "password_hash": record.password_hash, "created_at": record.created_at, "invite_code": record.invite_code, "role": record.role, } for username, record in self._users.items() } } self.users_file.parent.mkdir(parents=True, exist_ok=True) with open(self.users_file, "w", encoding="utf-8") as f: json.dump(payload, f, ensure_ascii=False, indent=2) def _load_invite_codes(self): if not self.invite_codes_file.exists(): self._save_invite_codes({}) return try: with open(self.invite_codes_file, "r", encoding="utf-8") as f: data = json.load(f) if isinstance(data, dict): codes = data.get("codes", []) elif isinstance(data, list): codes = data else: codes = [] self._invites = {item["code"]: item for item in codes if isinstance(item, dict) and "code" in item} except json.JSONDecodeError: raise RuntimeError(f"无法解析邀请码文件: {self.invite_codes_file}") def _save_invite_codes(self, overrides: Optional[Dict[str, Dict]] = None): codes = overrides or self._invites payload = {"codes": list(codes.values())} self.invite_codes_file.parent.mkdir(parents=True, exist_ok=True) with open(self.invite_codes_file, "w", encoding="utf-8") as f: json.dump(payload, f, ensure_ascii=False, indent=2) def _validate_invite_code(self, code: str) -> Dict: if not code: raise ValueError("邀请码不能为空。") entry = self._invites.get(code) if not entry: raise ValueError("邀请码不存在或已失效。") remaining = entry.get("remaining") if remaining is not None and remaining <= 0: raise ValueError("邀请码已被使用。") return entry def _consume_invite(self, entry: Dict): if entry.get("remaining") is None: return entry["remaining"] = max(0, entry["remaining"] - 1) self._save_invite_codes() def _ensure_admin_user(self): admin_name = (ADMIN_USERNAME or "").strip().lower() if not admin_name or not ADMIN_PASSWORD_HASH: return if admin_name in self._users: return record = UserRecord( username=admin_name, email=f"{admin_name}@local", password_hash=ADMIN_PASSWORD_HASH, created_at=datetime.utcnow().isoformat(), invite_code=None, role="admin", ) self._users[admin_name] = record self._index_user(record) self._save_users() self.ensure_user_workspace(admin_name)