agent-Specialization/modules/personalization_manager.py

225 lines
7.4 KiB
Python

"""Utilities for managing per-user personalization settings."""
from __future__ import annotations
import json
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Union
try:
from config.limits import THINKING_FAST_INTERVAL
except ImportError:
THINKING_FAST_INTERVAL = 10
from core.tool_config import TOOL_CATEGORIES
# File name appended by _to_path when it is handed a directory.
PERSONALIZATION_FILENAME = "personalization.json"
# Character cap applied by _sanitize_short_field to self_identify, user_name,
# profession and tone.
MAX_SHORT_FIELD_LENGTH = 20
# Character cap applied to each individual consideration entry.
MAX_CONSIDERATION_LENGTH = 50
# Maximum number of consideration entries kept by _sanitize_considerations.
MAX_CONSIDERATION_ITEMS = 10
# Preset tone labels. Exported through __all__; the sanitizers visible in this
# file do not restrict "tone" to these values.
# NOTE(review): presumably offered to the UI as suggestions - confirm against callers.
TONE_PRESETS = ["健谈", "幽默", "直言不讳", "鼓励性", "诗意", "企业商务", "打破常规", "同理心"]
# Inclusive bounds that _sanitize_thinking_interval clamps a value into.
THINKING_INTERVAL_MIN = 1
THINKING_INTERVAL_MAX = 50
# Baseline configuration: deep-copied as the starting point of every sanitize
# call and written to disk when the personalization file is missing or unreadable.
DEFAULT_PERSONALIZATION_CONFIG: Dict[str, Any] = {
# Prompt injection switch: build_personalization_prompt returns None while falsy.
"enabled": False,
"self_identify": "",
"user_name": "",
"profession": "",
"tone": "",
"considerations": [],
# None is also what the sanitizer returns for empty/invalid input and for a
# value equal to THINKING_FAST_INTERVAL.
# NOTE(review): presumably None means "use the system default" - confirm.
"thinking_interval": None,
# Only names that are keys of TOOL_CATEGORIES survive sanitizing.
"disabled_tool_categories": [],
}
# Names exported by "from ... import *".
__all__ = [
"PERSONALIZATION_FILENAME",
"DEFAULT_PERSONALIZATION_CONFIG",
"TONE_PRESETS",
"MAX_CONSIDERATION_ITEMS",
"load_personalization_config",
"save_personalization_config",
"ensure_personalization_config",
"build_personalization_prompt",
"sanitize_personalization_payload",
]
# Accepted type for "base_dir" arguments: a string path or a pathlib.Path.
PathLike = Union[str, Path]
def _ensure_parent(path: Path) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
def _to_path(base: PathLike) -> Path:
    """Resolve ``base`` to the path of the personalization file.

    ``~`` is expanded first. When the result is an existing directory, the
    personalization file inside that directory is returned; any other path is
    returned unchanged and is taken to be the file itself.
    """
    resolved = Path(base).expanduser()
    return resolved / PERSONALIZATION_FILENAME if resolved.is_dir() else resolved
def ensure_personalization_config(base_dir: PathLike) -> Dict[str, Any]:
    """Return the stored personalization config, creating the file if absent.

    Args:
        base_dir: Directory holding the personalization file, or the path of
            the file itself.

    Returns:
        The loaded config when the file already exists; otherwise a fresh copy
        of the defaults, after writing them to disk.
    """
    config_path = _to_path(base_dir)
    _ensure_parent(config_path)
    if config_path.exists():
        return load_personalization_config(base_dir)

    # First use: seed the file with the defaults and hand back a private copy.
    with open(config_path, "w", encoding="utf-8") as handle:
        json.dump(DEFAULT_PERSONALIZATION_CONFIG, handle, ensure_ascii=False, indent=2)
    return deepcopy(DEFAULT_PERSONALIZATION_CONFIG)
def load_personalization_config(base_dir: PathLike) -> Dict[str, Any]:
    """Load the personalization config, falling back to defaults on errors.

    Args:
        base_dir: Directory holding the personalization file, or the path of
            the file itself.

    Returns:
        The sanitized config read from disk. When the file is missing it is
        created with the defaults. When it cannot be read, is not valid UTF-8
        JSON, or its top-level value is not a JSON object, the file is reset
        to the defaults (best effort) and a copy of the defaults is returned.
    """
    path = _to_path(base_dir)
    _ensure_parent(path)
    if not path.exists():
        return ensure_personalization_config(base_dir)
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            # Valid JSON but not an object (e.g. a list): treat as corrupt
            # instead of letting the sanitizer fail on a non-mapping.
            raise ValueError("personalization config must be a JSON object")
    except (ValueError, OSError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        # Reset to the default config so a bad file does not block the caller.
        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(DEFAULT_PERSONALIZATION_CONFIG, f, ensure_ascii=False, indent=2)
        except OSError:
            # The reset is best effort: if the file cannot be rewritten either
            # (e.g. permissions), still honour the "fall back to defaults" contract.
            pass
        return deepcopy(DEFAULT_PERSONALIZATION_CONFIG)
    # Sanitizing happens outside the try block so that the broadened
    # ValueError handler above cannot mask errors raised by the sanitizer.
    return sanitize_personalization_payload(data)
def sanitize_personalization_payload(
    payload: Optional[Dict[str, Any]],
    fallback: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Normalize payload structure and clamp field lengths.

    Each known field is taken from ``payload`` when the key is present there,
    otherwise from ``fallback`` layered over the defaults, and is then passed
    through the matching ``_sanitize_*`` helper.

    Args:
        payload: Incoming settings. ``None`` or any non-mapping value is
            treated as an empty payload instead of raising.
        fallback: Optional existing settings used for keys that ``payload``
            omits. Its keys are merged over the defaults as-is, so keys not
            known to this function are carried into the result unchanged.

    Returns:
        A new dict; neither argument nor the module defaults are mutated.
    """
    # Local import: keeps this block self-contained without touching the
    # module-level import section.
    from collections.abc import Mapping

    base = deepcopy(DEFAULT_PERSONALIZATION_CONFIG)
    if fallback:
        base.update(fallback)
    # A truthy non-mapping (list, str, number...) has no usable ``.get`` and
    # previously raised AttributeError below.
    data = payload if isinstance(payload, Mapping) else {}
    allowed_tool_categories = set(TOOL_CATEGORIES.keys())

    def _resolve(key: str) -> Any:
        # An explicitly supplied key wins even when its value is falsy, which
        # is how a caller clears a field.
        return data.get(key) if key in data else base.get(key)

    base["enabled"] = bool(data.get("enabled", base["enabled"]))
    for short_key in ("self_identify", "user_name", "profession", "tone"):
        base[short_key] = _sanitize_short_field(_resolve(short_key))
    base["considerations"] = _sanitize_considerations(_resolve("considerations"))
    base["thinking_interval"] = _sanitize_thinking_interval(_resolve("thinking_interval"))
    base["disabled_tool_categories"] = _sanitize_tool_categories(
        _resolve("disabled_tool_categories"), allowed_tool_categories
    )
    return base
def save_personalization_config(base_dir: PathLike, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Merge ``payload`` into the stored config, write it, and return it.

    The currently stored config is loaded first and used as the fallback for
    every key that ``payload`` does not provide; the sanitized result is what
    gets written to disk and returned.
    """
    current = load_personalization_config(base_dir)
    merged = sanitize_personalization_payload(payload, fallback=current)

    target = _to_path(base_dir)
    _ensure_parent(target)
    with target.open("w", encoding="utf-8") as handle:
        json.dump(merged, handle, ensure_ascii=False, indent=2)
    return merged
def build_personalization_prompt(
    config: Optional[Dict[str, Any]],
    include_header: bool = True
) -> Optional[str]:
    """Render the personalization settings as prompt text.

    Args:
        config: Personalization settings; nothing is rendered unless its
            ``enabled`` entry is truthy.
        include_header: Whether to put the introductory header line first.

    Returns:
        The prompt lines joined with newlines, or ``None`` when the config is
        missing, disabled, or contributes no content besides the header.
    """
    if not (config and config.get("enabled")):
        return None

    body = []

    self_identify = config.get("self_identify")
    if self_identify:
        body.append(f"用户希望你自称:{self_identify}")

    user_name = config.get("user_name")
    if user_name:
        body.append(f"用户希望你称呼为:{user_name}")

    profession = config.get("profession")
    if profession:
        body.append(f"用户的职业是:{profession}")

    tone = config.get("tone")
    if tone:
        body.append(f"用户希望你使用 {tone} 的语气与TA交流")

    notes = [entry for entry in (config.get("considerations") or []) if entry]
    if notes:
        body.append("用户希望你在回答问题时必须考虑的信息是:")
        body.extend(f"{number}. {note}" for number, note in enumerate(notes, start=1))

    if not body:
        # Nothing meaningful to inject, so skip the prompt entirely.
        return None

    if include_header:
        body.insert(0, "用户的个性化数据,请回答时务必参照这些信息")
    return "\n".join(body)
def _sanitize_short_field(value: Optional[str]) -> str:
if not value:
return ""
text = str(value).strip()
if not text:
return ""
return text[:MAX_SHORT_FIELD_LENGTH]
def _sanitize_considerations(value: Any) -> list:
if not isinstance(value, list):
return []
cleaned = []
for item in value:
if not isinstance(item, str):
continue
text = item.strip()
if not text:
continue
cleaned.append(text[:MAX_CONSIDERATION_LENGTH])
if len(cleaned) >= MAX_CONSIDERATION_ITEMS:
break
return cleaned
def _sanitize_thinking_interval(value: Any) -> Optional[int]:
if value is None or value == "":
return None
try:
interval = int(value)
except (TypeError, ValueError):
return None
interval = max(THINKING_INTERVAL_MIN, min(THINKING_INTERVAL_MAX, interval))
if interval == THINKING_FAST_INTERVAL:
return None
return interval
def _sanitize_tool_categories(value: Any, allowed: set) -> list:
if not isinstance(value, list):
return []
result = []
for item in value:
if not isinstance(item, str):
continue
candidate = item.strip()
if not candidate or candidate not in allowed:
continue
if candidate not in result:
result.append(candidate)
return result