283 lines
10 KiB
Python
283 lines
10 KiB
Python
"""User and workspace management utilities for multi-user support."""
|
|
|
|
import json
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Optional, Tuple

from werkzeug.security import check_password_hash, generate_password_hash

from config import (
    ADMIN_PASSWORD_HASH,
    ADMIN_USERNAME,
    INVITE_CODES_FILE,
    USER_SPACE_DIR,
    USERS_DB_FILE,
    UPLOAD_QUARANTINE_SUBDIR,
)
from modules.personalization_manager import ensure_personalization_config
|
|
|
|
|
|
@dataclass
class UserRecord:
    """A single user account as persisted in the users database file."""

    # Account name; used as the key of the in-memory user table.
    username: str
    # E-mail address; used as the login identifier by authentication.
    email: str
    # Password hash string (never the plain-text password).
    password_hash: str
    # Creation timestamp stored as an ISO-8601 string.
    created_at: str
    # Invite code used at registration, or None when no code was involved.
    invite_code: Optional[str] = None
    # Role label; "user" unless set otherwise (the built-in admin uses "admin").
    role: str = "user"
|
|
|
|
|
|
@dataclass
class UserWorkspace:
    """Filesystem locations that make up one user's workspace."""

    # Owner of the workspace.
    username: str
    # Top-level directory of this user's workspace.
    root: Path
    # Project directory inside the workspace.
    project_path: Path
    # Directory holding per-user data.
    data_dir: Path
    # Directory holding per-user logs.
    logs_dir: Path
    # Directory receiving the user's uploads.
    uploads_dir: Path
    # Per-user quarantine directory for uploads.
    quarantine_dir: Path
|
|
|
|
|
|
class UserManager:
    """Handle user registration, authentication and workspace provisioning.

    Users and invite codes are kept in memory and mirrored to two JSON files.
    Both files are loaded (or created when missing) at construction time, and
    the configured admin account is added if it is not present yet.
    """

    USERNAME_REGEX = re.compile(r"^[a-z0-9_\-]{3,32}$")

    def __init__(
        self,
        users_file: str = USERS_DB_FILE,
        invite_codes_file: str = INVITE_CODES_FILE,
        workspace_root: str = USER_SPACE_DIR,
    ):
        """Load persisted state and make sure the workspace root exists.

        Args:
            users_file: Path of the JSON file storing user records.
            invite_codes_file: Path of the JSON file storing invite codes.
            workspace_root: Directory under which per-user workspaces live;
                ``~`` is expanded and the path is resolved.

        Raises:
            RuntimeError: If either JSON file exists but cannot be parsed.
        """
        self.users_file = Path(users_file)
        self.invite_codes_file = Path(invite_codes_file)
        self.workspace_root = Path(workspace_root).expanduser().resolve()
        self.workspace_root.mkdir(parents=True, exist_ok=True)

        self._users: Dict[str, UserRecord] = {}
        self._invites: Dict[str, Dict] = {}
        # Lower-cased e-mail -> username, used for login lookups.
        self._email_map: Dict[str, str] = {}

        self._load_users()
        self._load_invite_codes()
        self._ensure_admin_user()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def register_user(
        self, username: str, email: str, password: str, invite_code: str
    ) -> Tuple[UserRecord, UserWorkspace]:
        """Create a new account, consume its invite code and build its workspace.

        Args:
            username: Desired account name; normalized to lower case.
            email: Contact/login e-mail; normalized to lower case.
            password: Plain-text password. Surrounding whitespace is stripped
                before the length check and before hashing.
            invite_code: Invite code authorizing the registration.

        Returns:
            The stored ``UserRecord`` and the provisioned ``UserWorkspace``.

        Raises:
            ValueError: If any field is invalid, the username or e-mail is
                already taken, or the invite code is missing/unknown/used up.
        """
        username = self._normalize_username(username)
        email = self._normalize_email(email)
        # Tolerate None the same way the other fields do, so a missing password
        # is reported as a ValueError rather than an AttributeError.
        password = (password or "").strip()
        invite_code = (invite_code or "").strip()

        if len(password) < 8:
            raise ValueError("密码长度至少 8 位。")
        if username in self._users:
            raise ValueError("该用户名已被注册。")
        if email in self._email_map:
            raise ValueError("该邮箱已被注册。")

        invite_entry = self._validate_invite_code(invite_code)
        password_hash = generate_password_hash(password)
        created_at = self._utc_timestamp()

        record = UserRecord(
            username=username,
            email=email,
            password_hash=password_hash,
            created_at=created_at,
            invite_code=invite_entry["code"],
        )
        self._users[username] = record
        self._index_user(record)
        self._save_users()
        self._consume_invite(invite_entry)

        workspace = self.ensure_user_workspace(username)
        return record, workspace

    def authenticate(self, email: str, password: str) -> Optional[UserRecord]:
        """Return the user matching ``email``/``password``, or None on failure.

        The e-mail is stripped and lower-cased before lookup. The password is
        checked exactly as given (it is not stripped here).
        """
        email = (email or "").strip().lower()
        username = self._email_map.get(email)
        if not username:
            return None
        record = self._users.get(username)
        if not record or not record.password_hash:
            return None
        if not check_password_hash(record.password_hash, password or ""):
            return None
        return record

    def get_user(self, username: str) -> Optional[UserRecord]:
        """Look up a user by name (stripped and lower-cased); None if unknown."""
        return self._users.get((username or "").strip().lower())

    def ensure_user_workspace(self, username: str) -> UserWorkspace:
        """Create (if needed) and describe the directory layout for ``username``.

        Under ``<workspace_root>/<username>`` this creates ``project``, ``data``,
        ``logs``, ``project/user_upload``, ``data/conversations`` and
        ``data/backups``, calls ``ensure_personalization_config`` on the data
        directory, and creates a per-user quarantine directory.

        Raises:
            ValueError: If ``username`` does not satisfy ``USERNAME_REGEX``.
        """
        username = self._normalize_username(username)
        root = (self.workspace_root / username).resolve()
        project_path = root / "project"
        data_dir = root / "data"
        logs_dir = root / "logs"
        uploads_dir = project_path / "user_upload"

        for path in (project_path, data_dir, logs_dir, uploads_dir):
            path.mkdir(parents=True, exist_ok=True)

        # Initialize the data sub-directories.
        (data_dir / "conversations").mkdir(parents=True, exist_ok=True)
        (data_dir / "backups").mkdir(parents=True, exist_ok=True)
        ensure_personalization_config(data_dir)

        # A relative quarantine setting is anchored next to the workspace root
        # (i.e. in its parent directory); an absolute one is used as-is.
        quarantine_root = Path(UPLOAD_QUARANTINE_SUBDIR).expanduser()
        if not quarantine_root.is_absolute():
            quarantine_root = (self.workspace_root.parent / UPLOAD_QUARANTINE_SUBDIR).resolve()
        quarantine_dir = (quarantine_root / username).resolve()
        quarantine_dir.mkdir(parents=True, exist_ok=True)

        return UserWorkspace(
            username=username,
            root=root,
            project_path=project_path,
            data_dir=data_dir,
            logs_dir=logs_dir,
            uploads_dir=uploads_dir,
            quarantine_dir=quarantine_dir,
        )

    def list_invite_codes(self):
        """Return the invite-code entries as a list.

        The list is new, but its items are the live entry dicts held by this
        manager, not copies.
        """
        return list(self._invites.values())

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Produces the same format as the deprecated ``datetime.utcnow()``
        (no UTC offset suffix), so previously stored timestamps stay
        consistent with new ones.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def _normalize_username(self, username: str) -> str:
        """Strip and lower-case ``username``; raise ValueError if it is invalid."""
        candidate = (username or "").strip().lower()
        if not candidate or not self.USERNAME_REGEX.match(candidate):
            raise ValueError("用户名需为 3-32 位小写字母、数字、下划线或连字符。")
        return candidate

    def _normalize_email(self, email: str) -> str:
        """Strip and lower-case ``email``; raise ValueError on an obviously bad one.

        Only a minimal sanity check is done: an ``@`` must be present and the
        address must be at least 6 characters long.
        """
        email = (email or "").strip().lower()
        if "@" not in email or len(email) < 6:
            raise ValueError("邮箱格式不正确。")
        return email

    def _index_user(self, record: UserRecord):
        """Register ``record`` in the e-mail lookup table (skipped if no e-mail)."""
        email = (record.email or "").strip().lower()
        if email:
            self._email_map[email] = record.username

    def _load_users(self):
        """Populate the in-memory user table from ``users_file``.

        A missing file is created with the current (empty) state. Two layouts
        are accepted: ``{"users": {name: {...}}}`` and a top-level list of
        objects that each carry a ``"username"`` key. Any other top-level JSON
        value is treated as "no users".

        Raises:
            RuntimeError: If the file is not valid JSON, or the user table or
                one of its entries is not a JSON object.
        """
        if not self.users_file.exists():
            self._save_users()
            return

        parse_error = f"无法解析用户数据文件: {self.users_file}"
        try:
            with open(self.users_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except json.JSONDecodeError as exc:
            raise RuntimeError(parse_error) from exc

        if isinstance(data, dict):
            raw_users = data.get("users", {})
        elif isinstance(data, list):
            raw_users = {
                item.get("username"): item
                for item in data
                if isinstance(item, dict) and item.get("username")
            }
        else:
            raw_users = {}

        # Structurally broken data used to surface as a bare AttributeError;
        # report it with the same error as a JSON syntax problem instead.
        if not isinstance(raw_users, dict):
            raise RuntimeError(parse_error)

        for username, payload in raw_users.items():
            if not isinstance(payload, dict):
                raise RuntimeError(parse_error)
            record = UserRecord(
                username=username,
                email=payload.get("email", ""),
                password_hash=payload.get("password_hash", ""),
                created_at=payload.get("created_at", ""),
                invite_code=payload.get("invite_code"),
                role=payload.get("role", "user"),
            )
            self._users[username] = record
            self._index_user(record)

    def _save_users(self):
        """Write every in-memory user to ``users_file`` as ``{"users": {...}}``."""
        payload = {
            "users": {
                username: {
                    "email": record.email,
                    "password_hash": record.password_hash,
                    "created_at": record.created_at,
                    "invite_code": record.invite_code,
                    "role": record.role,
                }
                for username, record in self._users.items()
            }
        }
        self.users_file.parent.mkdir(parents=True, exist_ok=True)
        with open(self.users_file, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    def _load_invite_codes(self):
        """Populate the invite table from ``invite_codes_file``.

        A missing file is created empty. Accepted layouts are
        ``{"codes": [...]}`` and a top-level list; entries that are not
        objects or lack a ``"code"`` key are ignored.

        Raises:
            RuntimeError: If the file is not valid JSON.
        """
        if not self.invite_codes_file.exists():
            self._save_invite_codes({})
            return

        try:
            with open(self.invite_codes_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except json.JSONDecodeError as exc:
            raise RuntimeError(f"无法解析邀请码文件: {self.invite_codes_file}") from exc

        if isinstance(data, dict):
            codes = data.get("codes", [])
        elif isinstance(data, list):
            codes = data
        else:
            codes = []
        self._invites = {
            item["code"]: item
            for item in codes
            if isinstance(item, dict) and "code" in item
        }

    def _save_invite_codes(self, overrides: Optional[Dict[str, Dict]] = None):
        """Write invite codes to ``invite_codes_file`` as ``{"codes": [...]}``.

        Args:
            overrides: Mapping to write instead of the in-memory invites. Only
                ``None`` selects the in-memory table; an explicitly passed
                empty dict is written as an empty list.
        """
        # `overrides or self._invites` would discard an explicit empty mapping.
        codes = self._invites if overrides is None else overrides
        payload = {"codes": list(codes.values())}
        self.invite_codes_file.parent.mkdir(parents=True, exist_ok=True)
        with open(self.invite_codes_file, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    def _validate_invite_code(self, code: str) -> Dict:
        """Return the invite entry for ``code`` or raise ValueError.

        An entry whose ``"remaining"`` value is absent/None is treated as
        unlimited; otherwise it must be greater than zero.
        """
        if not code:
            raise ValueError("邀请码不能为空。")
        entry = self._invites.get(code)
        if not entry:
            raise ValueError("邀请码不存在或已失效。")

        remaining = entry.get("remaining")
        if remaining is not None and remaining <= 0:
            raise ValueError("邀请码已被使用。")
        return entry

    def _consume_invite(self, entry: Dict):
        """Decrement a limited invite's remaining count (floored at 0) and save."""
        if entry.get("remaining") is None:
            # Unlimited invite: nothing to count down, nothing to persist.
            return
        entry["remaining"] = max(0, entry["remaining"] - 1)
        self._save_invite_codes()

    def _ensure_admin_user(self):
        """Create the configured admin account if it does not exist yet.

        Does nothing when ``ADMIN_USERNAME`` or ``ADMIN_PASSWORD_HASH`` is
        empty, or when a user with the admin name is already loaded. Otherwise
        the admin is stored with role ``"admin"`` and the e-mail
        ``<name>@local``, and its workspace is provisioned.
        """
        admin_name = (ADMIN_USERNAME or "").strip().lower()
        if not admin_name or not ADMIN_PASSWORD_HASH:
            return

        if admin_name in self._users:
            return

        record = UserRecord(
            username=admin_name,
            email=f"{admin_name}@local",
            password_hash=ADMIN_PASSWORD_HASH,
            created_at=self._utc_timestamp(),
            invite_code=None,
            role="admin",
        )
        self._users[admin_name] = record
        self._index_user(record)
        self._save_users()
        self.ensure_user_workspace(admin_name)
|