agent-Specialization/server/security.py

262 lines
9.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Security helpers: rate limiting, CSRF, Socket tokens, tool-result compaction, etc."""
from __future__ import annotations
import hmac
import secrets
import time
from typing import Dict, Any, Optional, Tuple
from flask import request, session, jsonify
from functools import wraps
from . import state
# 便捷别名
def get_client_ip() -> str:
    """Return the client IP used for rate-limit / lockout bookkeeping.

    Prefers the first (left-most) entry of the ``X-Forwarded-For`` header and
    falls back to ``request.remote_addr``; returns ``"unknown"`` when neither
    yields a value.

    NOTE(review): ``X-Forwarded-For`` is client-controlled unless a trusted
    reverse proxy overwrites it. If the app can be reached without such a
    proxy, callers can spoof this value (and dodge IP-scoped limits) —
    confirm the deployment topology.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        first_hop = forwarded.split(",")[0].strip()
        # A malformed header such as ", 1.2.3.4" or " " yields an empty first
        # hop; fall back instead of returning "" so such clients do not all
        # collapse into one shared bucket.
        if first_hop:
            return first_hop
    return request.remote_addr or "unknown"
def resolve_identifier(scope: str = "ip", identifier: Optional[str] = None, kwargs: Optional[Dict[str, Any]] = None) -> str:
    """Pick the key used to bucket a caller for rate limiting and lockouts.

    Resolution order:
      1. an explicit, truthy ``identifier``;
      2. when ``scope == "user"``: ``kwargs["username"]`` (e.g. route kwargs),
         then the logged-in user from ``auth.get_current_username()``;
      3. otherwise / as a last resort, the client IP from ``get_client_ip()``.
    """
    if identifier:
        return identifier
    if scope != "user":
        return get_client_ip()
    route_user = kwargs.get('username') if kwargs else None
    if route_user:
        return route_user
    from .auth import get_current_username  # imported lazily to avoid a circular import
    return get_current_username() or get_client_ip()
def check_rate_limit(action: str, limit: int, window_seconds: int, identifier: Optional[str]) -> Tuple[bool, int]:
    """Simple sliding-window rate limiter.

    Records one hit for ``(action, identifier)`` unless the caller is already
    at the limit.

    Args:
        action: Logical action name; part of the bucket key.
        limit: Maximum number of hits allowed inside the window.
        window_seconds: Window length in seconds.
        identifier: Caller identity; ``None``/``""`` map to the shared
            ``"anonymous"`` bucket.

    Returns:
        ``(True, retry_after)`` with ``retry_after >= 1`` when the request must
        be rejected (the hit is NOT recorded), otherwise ``(False, 0)``.
    """
    bucket_key = f"{action}:{identifier or 'anonymous'}"
    # assumes RATE_LIMIT_BUCKETS creates an empty deque on first access
    # (e.g. a defaultdict(deque)) — TODO confirm in state.py
    bucket = state.RATE_LIMIT_BUCKETS[bucket_key]
    now = time.time()
    # Evict timestamps (oldest first) that have slid out of the window.
    while bucket and now - bucket[0] > window_seconds:
        bucket.popleft()
    if len(bucket) >= limit:
        if not bucket:
            # Only reachable when limit <= 0: there is no oldest hit to measure
            # from, so avoid indexing bucket[0] and report a full window.
            return True, max(int(window_seconds), 1)
        retry_after = window_seconds - int(now - bucket[0])
        return True, max(retry_after, 1)
    bucket.append(now)
    return False, 0
def rate_limited(action: str, limit: int, window_seconds: int, scope: str = "ip", error_message: Optional[str] = None):
    """Decorator factory that rate-limits a route.

    The caller is identified via ``resolve_identifier(scope, kwargs=...)``
    using the route's keyword arguments. When ``check_rate_limit`` reports the
    caller as limited, the route is not invoked and a JSON 429 response with
    ``retry_after`` is returned instead.
    """
    def decorator(func):
        @wraps(func)
        def wrapped(*args, **kwargs):
            caller = resolve_identifier(scope, kwargs=kwargs)
            blocked, wait_for = check_rate_limit(action, limit, window_seconds, caller)
            if not blocked:
                return func(*args, **kwargs)
            payload = {
                "success": False,
                "error": error_message or "请求过于频繁,请稍后再试。",
                "retry_after": wait_for,
            }
            return jsonify(payload), 429
        return wrapped
    return decorator
def register_failure(action: str, limit: int, lock_seconds: int, scope: str = "ip", identifier: Optional[str] = None, kwargs: Optional[Dict[str, Any]] = None) -> int:
    """Record one failed attempt; lock the caller out once ``limit`` is reached.

    Args:
        action: Logical action name; part of the tracker key.
        limit: Number of failures that triggers a lock.
        lock_seconds: Lock duration in seconds.
        scope, identifier, kwargs: Passed to ``resolve_identifier``.

    Returns:
        - While a lock is active: the remaining lock time (at least 1); the
          failure is not counted.
        - When this failure triggers a lock: ``lock_seconds``.
        - Otherwise: ``0``.
    """
    ident = resolve_identifier(scope, identifier, kwargs)
    key = f"{action}:{ident}"
    now = time.time()
    entry = state.FAILURE_TRACKERS.setdefault(key, {"count": 0, "blocked_until": 0})
    blocked_until = entry.get("blocked_until", 0)
    if blocked_until and blocked_until > now:
        # Clamp to >= 1: int() truncation would otherwise report 0 ("not
        # locked") during the final fractional second of an active lock.
        return max(int(blocked_until - now), 1)
    entry["count"] = entry.get("count", 0) + 1
    if entry["count"] >= limit:
        # Reset so a fresh run of failures is needed after this lock expires.
        entry["count"] = 0
        entry["blocked_until"] = now + lock_seconds
        return lock_seconds
    return 0
def is_action_blocked(action: str, scope: str = "ip", identifier: Optional[str] = None, kwargs: Optional[Dict[str, Any]] = None) -> Tuple[bool, int]:
    """Report whether the caller is currently locked out of ``action``.

    Returns:
        ``(True, remaining_seconds)`` with ``remaining_seconds >= 1`` while a
        lock recorded in ``state.FAILURE_TRACKERS`` is active, else
        ``(False, 0)``.
    """
    ident = resolve_identifier(scope, identifier, kwargs)
    key = f"{action}:{ident}"
    entry = state.FAILURE_TRACKERS.get(key)
    if not entry:
        return False, 0
    now = time.time()
    blocked_until = entry.get("blocked_until", 0)
    if blocked_until and blocked_until > now:
        # Clamp so an active lock never reports a 0-second wait.
        return True, max(int(blocked_until - now), 1)
    return False, 0
def clear_failures(action: str, scope: str = "ip", identifier: Optional[str] = None, kwargs: Optional[Dict[str, Any]] = None):
    """Forget the failure counter and any lock for ``action`` and this caller.

    A missing tracker entry is silently ignored.
    """
    who = resolve_identifier(scope, identifier, kwargs)
    tracker_key = "{}:{}".format(action, who)
    state.FAILURE_TRACKERS.pop(tracker_key, None)
def get_csrf_token(force_new: bool = False) -> str:
    """Return the session's CSRF token.

    A new token (``secrets.token_urlsafe(32)``) is generated and stored in the
    session when none exists yet or when ``force_new`` is true.
    """
    existing = session.get(state.CSRF_SESSION_KEY)
    if existing and not force_new:
        return existing
    fresh = secrets.token_urlsafe(32)
    session[state.CSRF_SESSION_KEY] = fresh
    return fresh
def requires_csrf_protection(path: str) -> bool:
    """Decide whether a state-changing request to ``path`` needs a CSRF token.

    Skipped for:
      - requests carrying an ``Authorization: Bearer ...`` header (stateless auth);
      - every ``/api/v1/`` path (relies on the API auth layer rejecting
        requests without Authorization with 401 — not enforced here);
      - paths listed in ``state.CSRF_EXEMPT_PATHS``.
    Otherwise required for paths in ``state.CSRF_PROTECTED_PATHS`` or under any
    prefix in ``state.CSRF_PROTECTED_PREFIXES``.
    """
    authorization = request.headers.get("Authorization") or ""
    if authorization.lower().startswith("bearer "):
        return False
    if path.startswith("/api/v1/") or path in state.CSRF_EXEMPT_PATHS:
        return False
    if path in state.CSRF_PROTECTED_PATHS:
        return True
    return any(map(path.startswith, state.CSRF_PROTECTED_PREFIXES))
def validate_csrf_request() -> bool:
    """Check the submitted CSRF token against the one stored in the session.

    The token is read from the ``state.CSRF_HEADER_NAME`` header, falling back
    to the ``csrf_token`` form field. Comparison is constant-time; any failure
    (missing token, comparison error) yields False.
    """
    session_token = session.get(state.CSRF_SESSION_KEY)
    submitted = request.headers.get(state.CSRF_HEADER_NAME) or request.form.get("csrf_token")
    if session_token and submitted:
        try:
            return hmac.compare_digest(str(submitted), str(session_token))
        except Exception:
            return False
    return False
def prune_socket_tokens(now: Optional[float] = None):
    """Drop expired entries from ``state.pending_socket_tokens``.

    Args:
        now: Reference timestamp in seconds since the epoch; defaults to
            ``time.time()``. An explicit ``0`` is honoured (the previous
            ``now or time.time()`` silently replaced it). Entries whose
            ``expires_at`` is <= ``now``, or missing, are removed.
    """
    current = time.time() if now is None else now
    # Collect first, then delete, so the dict is not mutated while iterated.
    expired = [
        token
        for token, meta in state.pending_socket_tokens.items()
        if meta.get("expires_at", 0) <= current
    ]
    for token in expired:
        state.pending_socket_tokens.pop(token, None)
def consume_socket_token(token_value: Optional[str], username: Optional[str]) -> bool:
    """Validate and consume a one-time Socket token for ``username``.

    The token is removed from ``state.pending_socket_tokens`` as soon as it is
    looked up, so it cannot be reused even if a later check fails.

    Returns:
        True only if the token exists, belongs to ``username``, has not
        expired, and — when a fingerprint was stored with it and the request
        sends a User-Agent — the first 128 characters of the User-Agent match
        that fingerprint.
    """
    if not token_value or not username:
        return False
    prune_socket_tokens()
    token_meta = state.pending_socket_tokens.pop(token_value, None)
    if not token_meta:
        return False
    if token_meta.get("username") != username:
        return False
    if token_meta.get("expires_at", 0) <= time.time():
        return False
    # assumes the stored fingerprint is a str — TODO confirm with the issuer
    fingerprint = token_meta.get("fingerprint") or ""
    request_fp = (request.headers.get("User-Agent") or "")[:128]
    # NOTE(review): a request without a User-Agent skips this check entirely.
    if fingerprint and request_fp:
        # hmac.compare_digest raises TypeError for str arguments containing
        # non-ASCII characters, and the User-Agent is client-controlled;
        # compare UTF-8 bytes so such requests are rejected instead of crashing.
        if not hmac.compare_digest(fingerprint.encode("utf-8"), request_fp.encode("utf-8")):
            return False
    return True
def format_tool_result_notice(tool_name: str, tool_call_id: Optional[str], content: str) -> str:
    """Render a tool's output as system-message text for the conversation.

    Layout: a header line ``[工具结果] <tool_name>`` (plus
    `` (tool_call_id=<id>)`` when an id is given), a newline, then the stripped
    content, or ``(无附加输出)`` when the content is empty or whitespace-only.
    """
    id_suffix = f" (tool_call_id={tool_call_id})" if tool_call_id else ""
    text = (content or "").strip() or "(无附加输出)"
    return f"[工具结果] {tool_name}{id_suffix}\n{text}"
def compact_web_search_result(result_data: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce a web_search result to the fields the frontend displays.

    Keeps success/summary/query/filters/total_results and, for each dict entry
    in ``results``, only ``index``, ``title`` (falling back to ``name``) and
    ``url``, so the result list survives persistence in a compact form.
    Non-dict entries are skipped. ``error`` is copied only on failure. A
    non-dict input yields a fixed error payload.
    """
    if not isinstance(result_data, dict):
        return {"success": False, "error": "invalid search result"}
    fetch = result_data.get
    compact: Dict[str, Any] = {
        "success": bool(fetch("success")),
        "summary": fetch("summary"),
        "query": fetch("query"),
        "filters": fetch("filters") or {},
        "total_results": fetch("total_results", 0),
    }
    compact["results"] = [
        {
            "index": entry.get("index"),
            "title": entry.get("title") or entry.get("name"),
            "url": entry.get("url"),
        }
        for entry in (fetch("results") or [])
        if isinstance(entry, dict)
    ]
    if not compact["success"] and fetch("error"):
        compact["error"] = fetch("error")
    return compact
def attach_security_hooks(app):
    """Register CSRF enforcement and common security headers on ``app``.

    - before_request: for methods not in ``state.CSRF_SAFE_METHODS`` on paths
      where ``requires_csrf_protection`` is true, reject with a JSON 403
      unless ``validate_csrf_request()`` succeeds.
    - after_request: set (without overriding existing values) X-Frame-Options,
      X-Content-Type-Options and Referrer-Policy; ``Cache-Control: no-store``
      for JSON responses; HSTS when ``SESSION_COOKIE_SECURE`` is configured.
    """
    baseline_headers = (
        ("X-Frame-Options", "SAMEORIGIN"),
        ("X-Content-Type-Options", "nosniff"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
    )

    def _enforce_csrf_token():
        verb = (request.method or "GET").upper()
        needs_check = verb not in state.CSRF_SAFE_METHODS and requires_csrf_protection(request.path)
        if not needs_check or validate_csrf_request():
            return None
        return jsonify({"success": False, "error": "CSRF validation failed"}), 403

    def _apply_security_headers(response):
        hdrs = response.headers
        for header_name, header_value in baseline_headers:
            hdrs.setdefault(header_name, header_value)
        if response.mimetype == "application/json":
            hdrs.setdefault("Cache-Control", "no-store")
        if app.config.get("SESSION_COOKIE_SECURE"):
            hdrs.setdefault("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
        return response

    app.before_request(_enforce_csrf_token)
    app.after_request(_apply_security_headers)
# Public API of this module. Keep exactly ONE __all__ assignment: a later
# re-assignment silently replaces the earlier list (a duplicate here used to
# drop "attach_security_hooks" from star-imports).
__all__ = [
    "get_client_ip",
    "resolve_identifier",
    "check_rate_limit",
    "rate_limited",
    "register_failure",
    "is_action_blocked",
    "clear_failures",
    "get_csrf_token",
    "requires_csrf_protection",
    "validate_csrf_request",
    "prune_socket_tokens",
    "consume_socket_token",
    "format_tool_result_notice",
    "compact_web_search_result",
    "attach_security_hooks",
]