252 lines
8.4 KiB
Python
252 lines
8.4 KiB
Python
"""Utilities for managing per-user personalization settings."""
|
|
|
|
from __future__ import annotations

import json
from collections.abc import Mapping
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Union

try:
    from config.limits import THINKING_FAST_INTERVAL
except ImportError:
    # Fallback used when config.limits is not importable.
    THINKING_FAST_INTERVAL = 10

from core.tool_config import TOOL_CATEGORIES
|
|
|
|
# Values accepted for the "default_run_mode" setting (checked by
# _sanitize_run_mode after strip().lower()).
ALLOWED_RUN_MODES = {"fast", "thinking", "deep"}

# File name used inside a directory when a directory, not a file, is given.
PERSONALIZATION_FILENAME = "personalization.json"

# Limits applied while sanitizing user-supplied text.
MAX_SHORT_FIELD_LENGTH = 20      # self_identify / user_name / profession / tone
MAX_CONSIDERATION_LENGTH = 50    # characters kept per "considerations" entry
MAX_CONSIDERATION_ITEMS = 10     # entries kept in "considerations"

# Tone labels exported for callers (not referenced elsewhere in this module).
TONE_PRESETS = [
    "健谈",        # talkative
    "幽默",        # humorous
    "直言不讳",    # outspoken
    "鼓励性",      # encouraging
    "诗意",        # poetic
    "企业商务",    # corporate / business
    "打破常规",    # unconventional
    "同理心",      # empathetic
]

# Inclusive clamp bounds for a custom "thinking_interval".
THINKING_INTERVAL_MIN, THINKING_INTERVAL_MAX = 1, 50
|
|
|
|
# Every personalization field together with its default value.
# sanitize_personalization_payload starts from a deep copy of this dict, and
# this key order is the order json.dump writes when defaults are persisted.
DEFAULT_PERSONALIZATION_CONFIG: Dict[str, Any] = dict(
    enabled=False,                  # build_personalization_prompt emits nothing unless True
    self_identify="",               # name the assistant should call itself
    user_name="",                   # how the assistant should address the user
    profession="",
    tone="",
    considerations=[],              # free-text notes, limited by MAX_CONSIDERATION_*
    thinking_interval=None,         # None = no custom interval stored
    disabled_tool_categories=[],    # subset of TOOL_CATEGORIES keys
    default_run_mode=None,          # one of ALLOWED_RUN_MODES, or None
    auto_generate_title=True,
    tool_intent_enabled=True,
)
|
|
|
|
# Public API of this module.
__all__ = [
    # constants
    "PERSONALIZATION_FILENAME",
    "DEFAULT_PERSONALIZATION_CONFIG",
    "TONE_PRESETS",
    "MAX_CONSIDERATION_ITEMS",
    # file persistence
    "load_personalization_config",
    "save_personalization_config",
    "ensure_personalization_config",
    # prompt building / payload validation
    "build_personalization_prompt",
    "sanitize_personalization_payload",
]

# A location argument: a str or Path naming either the personalization file
# itself or an existing directory that contains it (see _to_path).
PathLike = Union[str, Path]
|
|
|
|
|
|
def _ensure_parent(path: Path) -> None:
    """Create the directory that will hold *path*, including missing ancestors.

    Does nothing if the directory already exists.
    """
    parent_dir = path.parent
    parent_dir.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
|
def _to_path(base: PathLike) -> Path:
    """Resolve *base* to the personalization file location.

    ``~`` is expanded. If the result is an existing directory, the file named
    PERSONALIZATION_FILENAME inside it is used; any other path is returned
    unchanged and treated as the file itself.

    NOTE(review): a directory that does not exist yet is therefore treated as
    a file path; confirm callers always create the user directory first.
    """
    resolved = Path(base).expanduser()
    return resolved / PERSONALIZATION_FILENAME if resolved.is_dir() else resolved
|
|
|
|
|
|
def ensure_personalization_config(base_dir: PathLike) -> Dict[str, Any]:
    """Return the personalization config, creating the file with defaults if absent.

    *base_dir* may be a directory or a file path (resolved by _to_path). If the
    file does not exist, DEFAULT_PERSONALIZATION_CONFIG is written to it and a
    deep copy of the defaults is returned. Otherwise the existing file is read
    through load_personalization_config.
    """
    target = _to_path(base_dir)
    _ensure_parent(target)
    if target.exists():
        return load_personalization_config(base_dir)
    with open(target, "w", encoding="utf-8") as handle:
        json.dump(DEFAULT_PERSONALIZATION_CONFIG, handle, ensure_ascii=False, indent=2)
    return deepcopy(DEFAULT_PERSONALIZATION_CONFIG)
|
|
|
|
|
|
def load_personalization_config(base_dir: PathLike) -> Dict[str, Any]:
    """Load and sanitize the personalization config, falling back to defaults.

    If the file does not exist, it is created through
    ensure_personalization_config. If the file cannot be read, is not valid
    UTF-8 JSON, or its top-level value is not a JSON object, the file is reset
    to DEFAULT_PERSONALIZATION_CONFIG. That reset is best effort: a failing
    write does not raise. A fresh copy of the defaults is returned in that case.

    Returns:
        A sanitized config dict (see sanitize_personalization_payload).
    """
    path = _to_path(base_dir)
    _ensure_parent(path)
    if not path.exists():
        return ensure_personalization_config(base_dir)

    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (ValueError, OSError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError
        # (a file that is not valid UTF-8).
        data = None

    # A valid JSON file may still hold a list/string/number; only an object
    # can be sanitized, so anything else falls through to the reset below.
    if isinstance(data, dict):
        return sanitize_personalization_payload(data)

    # Reset to the default config so a broken file does not block the user.
    defaults = deepcopy(DEFAULT_PERSONALIZATION_CONFIG)
    try:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(defaults, f, ensure_ascii=False, indent=2)
    except OSError:
        # Best effort: if the file cannot be rewritten (e.g. the same error
        # that prevented reading), still return defaults instead of raising.
        pass
    return defaults
|
|
|
|
|
|
def sanitize_personalization_payload(
    payload: Optional[Dict[str, Any]],
    fallback: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Normalize a personalization payload into a complete, validated config.

    Every key of DEFAULT_PERSONALIZATION_CONFIG is present in the result. For
    each field, the value from *payload* is used whenever the key is present,
    even if that value is empty or None, so an explicit "" clears a field.
    Otherwise the value comes from *fallback*, and otherwise from the default.
    Every chosen value is then re-validated, so an unsanitized *fallback* is
    normalized too.

    Args:
        payload: Incoming values. Anything that is not a mapping (for example
            a JSON list or string) is treated as an empty payload.
        fallback: Values for keys absent from *payload*, typically the
            currently stored config. Keys in *fallback* that are not part of
            the default schema are carried over unchanged. Unknown keys in
            *payload* are dropped.

    Returns:
        A new dict. Neither argument is mutated.
    """
    base = deepcopy(DEFAULT_PERSONALIZATION_CONFIG)
    if fallback:
        base.update(fallback)
    # Previously a truthy non-mapping payload crashed on `.get`.
    data = payload if isinstance(payload, Mapping) else {}
    allowed_tool_categories = set(TOOL_CATEGORIES.keys())

    def _pick(key: str) -> Any:
        # Presence in the payload, not truthiness, decides the source.
        return data[key] if key in data else base.get(key)

    for key in ("enabled", "auto_generate_title", "tool_intent_enabled"):
        base[key] = bool(_pick(key))

    for key in ("self_identify", "user_name", "profession", "tone"):
        base[key] = _sanitize_short_field(_pick(key))

    base["considerations"] = _sanitize_considerations(_pick("considerations"))
    base["thinking_interval"] = _sanitize_thinking_interval(_pick("thinking_interval"))
    base["disabled_tool_categories"] = _sanitize_tool_categories(
        _pick("disabled_tool_categories"), allowed_tool_categories
    )
    base["default_run_mode"] = _sanitize_run_mode(_pick("default_run_mode"))
    return base
|
|
|
|
|
|
def save_personalization_config(base_dir: PathLike, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Merge *payload* into the stored config, persist it, and return the result.

    Keys absent from *payload* keep their currently stored values, because the
    stored config is passed as ``fallback`` to sanitize_personalization_payload.

    The JSON is first written to a sibling ``<name>.tmp`` file, which is then
    renamed over the target with Path.replace. An exception or process crash
    mid-write therefore leaves the previous file intact. Without this, a
    truncated file would be reset to defaults by load_personalization_config.

    Raises:
        OSError: if the temporary file cannot be written or renamed. The
            temporary file is removed (best effort) before re-raising.
    """
    existing = load_personalization_config(base_dir)
    config = sanitize_personalization_payload(payload, fallback=existing)
    path = _to_path(base_dir)
    _ensure_parent(path)

    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, indent=2)
        # Path.replace (os.replace) swaps the file in a single rename and
        # overwrites an existing destination on all platforms.
        tmp_path.replace(path)
    except BaseException:
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
    return config
|
|
|
|
|
|
def build_personalization_prompt(
    config: Optional[Dict[str, Any]],
    include_header: bool = True,
) -> Optional[str]:
    """Render an enabled personalization config as newline-joined prompt text.

    Returns None when *config* is empty/None, is not enabled, or has no
    truthy field to report. Otherwise it returns one line per truthy field
    (self_identify, user_name, profession, tone), then a numbered list of the
    truthy "considerations" entries. A fixed header line is prepended when
    *include_header* is true.
    """
    if not config or not config.get("enabled"):
        return None

    templates = (
        ("self_identify", "用户希望你自称:{}"),
        ("user_name", "用户希望你称呼为:{}"),
        ("profession", "用户的职业是:{}"),
        ("tone", "用户希望你使用 {} 的语气与TA交流"),
    )
    body = []
    for key, template in templates:
        value = config.get(key)
        if value:
            body.append(template.format(value))

    notes = [note for note in (config.get("considerations") or []) if note]
    if notes:
        body.append("用户希望你在回答问题时必须考虑的信息是:")
        body.extend(f"{number}. {note}" for number, note in enumerate(notes, start=1))

    if not body:
        # Nothing personal to inject, so skip the header as well.
        return None

    header = ["用户的个性化数据,请回答时务必参照这些信息"] if include_header else []
    return "\n".join(header + body)
|
|
|
|
|
|
def _sanitize_short_field(value: Optional[str]) -> str:
    """Normalize a short free-text field.

    Falsy input (None, "", 0, ...) becomes "". Anything else is converted with
    str(), stripped of surrounding whitespace, and cut to
    MAX_SHORT_FIELD_LENGTH characters. Whitespace-only input therefore also
    yields "".
    """
    if not value:
        return ""
    return str(value).strip()[:MAX_SHORT_FIELD_LENGTH]
|
|
|
|
|
|
def _sanitize_considerations(value: Any) -> list:
    """Clean the list of free-text considerations.

    Non-list input yields []. Non-string entries are skipped. Strings are
    stripped, and empty results are dropped. Each kept entry is cut to
    MAX_CONSIDERATION_LENGTH characters, and at most MAX_CONSIDERATION_ITEMS
    entries are kept, in their original order.
    """
    if not isinstance(value, list):
        return []
    texts = (entry.strip() for entry in value if isinstance(entry, str))
    kept = [text[:MAX_CONSIDERATION_LENGTH] for text in texts if text]
    return kept[:MAX_CONSIDERATION_ITEMS]
|
|
|
|
|
|
def _sanitize_thinking_interval(value: Any) -> Optional[int]:
    """Coerce *value* to a clamped thinking interval, or None.

    Returns None when *value* is None or "", or when int() cannot convert it.
    That includes non-numeric strings, NaN, and infinities, which json.load
    accepts as ``Infinity``. Otherwise the integer is clamped to
    [THINKING_INTERVAL_MIN, THINKING_INTERVAL_MAX]. A result equal to
    THINKING_FAST_INTERVAL is also stored as None, presumably meaning "no
    custom interval"; confirm with the consumer of this setting.
    """
    if value is None or value == "":
        return None
    try:
        interval = int(value)
    except (TypeError, ValueError, OverflowError):
        # int(float("inf")) raises OverflowError; it was previously uncaught
        # and crashed the whole sanitize/load path.
        return None
    interval = max(THINKING_INTERVAL_MIN, min(THINKING_INTERVAL_MAX, interval))
    return None if interval == THINKING_FAST_INTERVAL else interval
|
|
|
|
|
|
def _sanitize_tool_categories(value: Any, allowed: set) -> list:
    """Keep only known tool-category names, de-duplicated, in first-seen order.

    Non-list input yields []. Non-string entries are skipped. Strings are
    stripped, and empty names or names not in *allowed* are dropped.
    """
    if not isinstance(value, list):
        return []
    names = (entry.strip() for entry in value if isinstance(entry, str))
    # dict.fromkeys keeps insertion order, so the first occurrence wins.
    return list(dict.fromkeys(name for name in names if name and name in allowed))
|
|
|
|
|
|
def _sanitize_run_mode(value: Any) -> Optional[str]:
    """Return the normalized run mode, or None if it is not an allowed mode.

    Only strings are considered. They are stripped and lower-cased before
    being checked against ALLOWED_RUN_MODES.
    """
    if not isinstance(value, str):
        return None
    normalized = value.strip().lower()
    return normalized if normalized in ALLOWED_RUN_MODES else None
|