131 lines
3.5 KiB
Python
131 lines
3.5 KiB
Python
"""
|
||
输入验证工具
|
||
"""
|
||
import re
|
||
from typing import Optional
|
||
|
||
def validate_question(question: str) -> Optional[str]:
    """Validate a user-supplied question.

    Checks are applied in order: non-empty, at least 5 characters, at most
    1000 characters, and at least one ASCII letter or CJK character in the
    range U+4E00..U+9FA5 (so input made only of punctuation, digits or
    whitespace is rejected).

    Args:
        question: The raw question text.

    Returns:
        None when the question is acceptable, otherwise a user-facing
        error message describing the first failed check.
    """
    if not question:
        return "问题不能为空"

    length = len(question)
    if length < 5:
        return "问题太短,请提供更详细的描述"
    if length > 1000:
        return "问题太长,请精简到1000字以内"

    # Require real content, not just punctuation.
    has_text = re.search(r'[a-zA-Z\u4e00-\u9fa5]+', question) is not None
    return None if has_text else "请输入有效的问题内容"
|
||
|
||
def validate_outline_feedback(feedback: str) -> Optional[str]:
    """Validate feedback text given on an outline.

    The feedback must be non-empty and between 10 and 500 characters
    (inclusive).

    Args:
        feedback: The raw feedback text.

    Returns:
        None when the feedback is acceptable, otherwise a user-facing
        error message describing the first failed check.
    """
    if not feedback:
        return "反馈内容不能为空"

    size = len(feedback)
    if size < 10:
        error = "请提供更详细的修改建议"
    elif size > 500:
        error = "反馈内容请控制在500字以内"
    else:
        error = None
    return error
|
||
|
||
def validate_session_id(session_id: str) -> Optional[str]:
    """Validate that a session ID has the hyphenated UUID text layout.

    The accepted form is five groups of hexadecimal digits with lengths
    8-4-4-4-12, separated by hyphens, in either letter case. Only the
    layout is checked; UUID version/variant bits are not inspected.

    Args:
        session_id: The session identifier to check.

    Returns:
        None when the ID is well-formed, otherwise a user-facing error
        message.
    """
    if not session_id:
        return "会话ID不能为空"

    # Use fullmatch rather than match() with '^...$': '$' also matches just
    # before a trailing newline, which would let "<uuid>\n" through.
    uuid_pattern = (
        r'[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}'
    )
    if re.fullmatch(uuid_pattern, session_id, re.IGNORECASE) is None:
        return "无效的会话ID格式"

    return None
|
||
|
||
def validate_subtopic_id(subtopic_id: str) -> Optional[str]:
    """Validate a subtopic ID of the form ``ST`` + 8 hexadecimal digits.

    Matching is case-insensitive for both the ``ST`` prefix and the hex
    digits.

    Args:
        subtopic_id: The subtopic identifier to check.

    Returns:
        None when the ID is well-formed, otherwise a user-facing error
        message.
    """
    if not subtopic_id:
        return "子主题ID不能为空"

    # fullmatch instead of '^...$' with match(): '$' would also accept an ID
    # followed by a trailing newline.
    if re.fullmatch(r'ST[0-9a-f]{8}', subtopic_id, re.IGNORECASE) is None:
        return "无效的子主题ID格式"

    return None
|
||
|
||
def validate_search_query(query: str) -> Optional[str]:
    """Validate a search query string.

    Checks are applied in order: non-blank, at least 2 characters, at most
    200 characters, and free of a small blacklist of script/SQL fragments
    (matched case-insensitively).

    NOTE: the blacklist is a coarse sanity filter only. It is not a
    substitute for output escaping or parameterized SQL at the point where
    the query is actually used.

    Args:
        query: The raw search query.

    Returns:
        None when the query is acceptable, otherwise a user-facing error
        message describing the first failed check.
    """
    # A whitespace-only query carries no search terms; treat it as empty.
    if not query or not query.strip():
        return "搜索查询不能为空"

    if len(query) < 2:
        return "搜索查询太短"

    if len(query) > 200:
        return "搜索查询太长,请控制在200字符以内"

    # HTML permits whitespace between an attribute name and '=', and SQL
    # permits any run of whitespace between keywords, so the patterns must
    # tolerate it or "DROP  TABLE" / "onerror =" slip past the filter.
    dangerous_patterns = (
        r'<script',
        r'javascript:',
        r'onerror\s*=',
        r'onclick\s*=',
        r'DROP\s+TABLE',
        r'DELETE\s+FROM',
        r'INSERT\s+INTO',
    )

    if any(re.search(p, query, re.IGNORECASE) for p in dangerous_patterns):
        return "搜索查询包含不允许的内容"

    return None
|
||
|
||
def validate_priority(priority: str) -> Optional[str]:
    """Validate that a priority is one of 'high', 'medium' or 'low'.

    The comparison is exact (case-sensitive, no trimming).

    Args:
        priority: The priority value to check.

    Returns:
        None when the value is allowed, otherwise a user-facing error
        message listing the allowed values.
    """
    allowed = ('high', 'medium', 'low')
    if priority in allowed:
        return None
    return f"优先级必须是以下之一: {', '.join(allowed)}"
|
||
|
||
def validate_report_format(format: str) -> Optional[str]:
    """Validate that a report format is 'json', 'markdown', 'html' or 'pdf'.

    The comparison is exact (case-sensitive, no trimming).

    Args:
        format: The requested report format. The name shadows the builtin
            but is kept because it is part of the public signature.

    Returns:
        None when the value is allowed, otherwise a user-facing error
        message listing the allowed values.
    """
    allowed = ('json', 'markdown', 'html', 'pdf')
    if format in allowed:
        return None
    return f"报告格式必须是以下之一: {', '.join(allowed)}"
|
||
|
||
def sanitize_filename(filename: str) -> str:
    """Return a copy of ``filename`` with unsafe characters neutralized.

    Steps, in order:
      1. Replace each of ``< > : " / \\ | ? *`` and every ASCII control
         character (0x00-0x1F) with ``_``.
      2. Truncate the result to at most 200 characters.
      3. Replace a leading ``.`` with ``_`` so the name is not a hidden
         file.

    An empty input yields an empty string; callers that need a non-empty
    name must supply their own fallback.

    Args:
        filename: The candidate file name (a single path component).

    Returns:
        The sanitized file name.
    """
    # Control characters are included because NUL makes open() raise and the
    # others are rejected in Windows file names.
    filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', filename)

    # Cap the length; slicing is a no-op for shorter names.
    filename = filename[:200]

    # Avoid producing a hidden (dot-prefixed) file.
    if filename.startswith('.'):
        filename = '_' + filename[1:]

    return filename
|
||
|
||
def validate_json_structure(data: dict, required_fields: list) -> Optional[str]:
    """Check that every required field is present as a key in ``data``.

    Only key presence is tested (via ``in``); values are not inspected.
    Fields are checked in the order given and the first missing one is
    reported.

    Args:
        data: The mapping to inspect.
        required_fields: Keys that must be present in ``data``.

    Returns:
        None when all fields are present, otherwise a user-facing error
        message naming the first missing field.
    """
    # A private sentinel keeps "nothing missing" distinct from a missing
    # field whose own value happens to be None.
    nothing_missing = object()
    first_missing = next(
        (name for name in required_fields if name not in data),
        nothing_missing,
    )
    if first_missing is nothing_missing:
        return None
    return f"缺少必要字段: {first_missing}"