309 lines
13 KiB
Python
309 lines
13 KiB
Python
import asyncio
|
||
import json
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional, Set
|
||
|
||
try:
|
||
from config import (
|
||
OUTPUT_FORMATS, DATA_DIR, PROMPTS_DIR, NEED_CONFIRMATION,
|
||
MAX_TERMINALS, TERMINAL_BUFFER_SIZE, TERMINAL_DISPLAY_SIZE,
|
||
MAX_READ_FILE_CHARS, READ_TOOL_DEFAULT_MAX_CHARS,
|
||
READ_TOOL_DEFAULT_CONTEXT_BEFORE, READ_TOOL_DEFAULT_CONTEXT_AFTER,
|
||
READ_TOOL_MAX_CONTEXT_BEFORE, READ_TOOL_MAX_CONTEXT_AFTER,
|
||
READ_TOOL_DEFAULT_MAX_MATCHES, READ_TOOL_MAX_MATCHES,
|
||
READ_TOOL_MAX_FILE_SIZE,
|
||
TERMINAL_SANDBOX_MOUNT_PATH,
|
||
TERMINAL_SANDBOX_MODE,
|
||
TERMINAL_SANDBOX_CPUS,
|
||
TERMINAL_SANDBOX_MEMORY,
|
||
PROJECT_MAX_STORAGE_MB,
|
||
CUSTOM_TOOLS_ENABLED,
|
||
)
|
||
except ImportError:
|
||
import sys
|
||
project_root = Path(__file__).resolve().parents[2]
|
||
if str(project_root) not in sys.path:
|
||
sys.path.insert(0, str(project_root))
|
||
from config import (
|
||
OUTPUT_FORMATS, DATA_DIR, PROMPTS_DIR, NEED_CONFIRMATION,
|
||
MAX_TERMINALS, TERMINAL_BUFFER_SIZE, TERMINAL_DISPLAY_SIZE,
|
||
MAX_READ_FILE_CHARS, READ_TOOL_DEFAULT_MAX_CHARS,
|
||
READ_TOOL_DEFAULT_CONTEXT_BEFORE, READ_TOOL_DEFAULT_CONTEXT_AFTER,
|
||
READ_TOOL_MAX_CONTEXT_BEFORE, READ_TOOL_MAX_CONTEXT_AFTER,
|
||
READ_TOOL_DEFAULT_MAX_MATCHES, READ_TOOL_MAX_MATCHES,
|
||
READ_TOOL_MAX_FILE_SIZE,
|
||
TERMINAL_SANDBOX_MOUNT_PATH,
|
||
TERMINAL_SANDBOX_MODE,
|
||
TERMINAL_SANDBOX_CPUS,
|
||
TERMINAL_SANDBOX_MEMORY,
|
||
PROJECT_MAX_STORAGE_MB,
|
||
CUSTOM_TOOLS_ENABLED,
|
||
)
|
||
|
||
from modules.file_manager import FileManager
|
||
from modules.search_engine import SearchEngine
|
||
from modules.terminal_ops import TerminalOperator
|
||
from modules.memory_manager import MemoryManager
|
||
from modules.terminal_manager import TerminalManager
|
||
from modules.todo_manager import TodoManager
|
||
from modules.sub_agent_manager import SubAgentManager
|
||
from modules.webpage_extractor import extract_webpage_content, tavily_extract
|
||
from modules.ocr_client import OCRClient
|
||
from modules.easter_egg_manager import EasterEggManager
|
||
from modules.personalization_manager import (
|
||
load_personalization_config,
|
||
build_personalization_prompt,
|
||
)
|
||
from modules.skills_manager import (
|
||
get_skills_catalog,
|
||
build_skills_list,
|
||
merge_enabled_skills,
|
||
build_skills_prompt,
|
||
)
|
||
from modules.custom_tool_registry import CustomToolRegistry, build_default_tool_category
|
||
from modules.custom_tool_executor import CustomToolExecutor
|
||
|
||
try:
|
||
from config.limits import THINKING_FAST_INTERVAL
|
||
except ImportError:
|
||
THINKING_FAST_INTERVAL = 10
|
||
|
||
from modules.container_monitor import collect_stats, inspect_state
|
||
from core.tool_config import TOOL_CATEGORIES
|
||
from utils.api_client import DeepSeekClient
|
||
from utils.context_manager import ContextManager
|
||
from utils.tool_result_formatter import format_tool_result_for_context
|
||
from utils.logger import setup_logger
|
||
from config.model_profiles import (
|
||
get_model_profile,
|
||
get_model_prompt_replacements,
|
||
get_model_context_window,
|
||
)
|
||
|
||
# Module-level logger obtained from the project's ``setup_logger`` helper.
# NOTE(review): not referenced by the code visible in this module — confirm it
# is still needed (it may be used by code outside this view).
logger = setup_logger(__name__)
|
||
# NOTE(review): this flag is not read anywhere in the visible code of this
# module; presumably it is imported elsewhere to switch off a length check —
# confirm before relying on or removing it.
DISABLE_LENGTH_CHECK = True
|
||
|
||
class MainTerminalToolsReadMixin:
    """Mixin implementing the ``read_file`` tool (``read`` / ``search`` / ``extract`` modes).

    The host class must provide ``self.file_manager`` (used via
    ``read_text_segment``, ``search_text`` and ``extract_segments``) and
    ``self.context_manager`` (used via ``load_file``); neither is defined here.
    """

    @staticmethod
    def _clamp_int(value, default, min_value=None, max_value=None):
        """Convert ``value`` to ``int`` and clamp it into ``[min_value, max_value]``.

        Args:
            value: Raw argument (int, numeric string, ...). ``None`` means "not given".
            default: Returned unchanged (not clamped) when ``value`` is ``None``
                or cannot be converted with ``int()``.
            min_value: Inclusive lower bound, or ``None`` for no lower bound.
            max_value: Inclusive upper bound, or ``None`` for no upper bound.

        Returns:
            The clamped integer, or ``default``.
        """
        if value is None:
            return default
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float("inf")) — e.g. a JSON argument of 1e999.
            return default
        if min_value is not None:
            number = max(min_value, number)
        if max_value is not None:
            number = min(max_value, number)
        return number

    @staticmethod
    def _parse_optional_line(value, field_name: str):
        """Parse an optional 1-based line-number argument.

        Returns:
            ``(number, None)`` on success, ``(None, None)`` when ``value`` is
            ``None``, or ``(None, error_message)`` when ``value`` is not an
            integer or is smaller than 1.
        """
        if value is None:
            return None, None
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            return None, f"{field_name} 必须是整数"
        if number < 1:
            return None, f"{field_name} 必须大于等于1"
        return number, None

    @classmethod
    def _parse_line_range(cls, arguments: Dict):
        """Parse and cross-validate the ``start_line`` / ``end_line`` arguments.

        Returns:
            ``(start_line, end_line, None)`` where each bound may be ``None``
            when not supplied, or ``(None, None, error_message)`` when either
            value is invalid or ``end_line < start_line``.
        """
        start_line, error = cls._parse_optional_line(arguments.get("start_line"), "start_line")
        if error:
            return None, None, error
        end_line, error = cls._parse_optional_line(arguments.get("end_line"), "end_line")
        if error:
            return None, None, error
        # Only compare when both bounds are present; either may be omitted.
        if start_line is not None and end_line is not None and end_line < start_line:
            return None, None, "end_line 必须大于等于 start_line"
        return start_line, end_line, None

    @staticmethod
    def _parse_bool_flag(value, default: bool = False) -> bool:
        """Interpret a tool argument as a boolean.

        Strings such as ``"false"``, ``"0"``, ``"no"``, ``"off"`` and ``""``
        are treated as False (``bool("false")`` would be True). Non-string
        values use normal truthiness; ``None`` yields ``default``.
        """
        if value is None:
            return default
        if isinstance(value, str):
            return value.strip().lower() not in {"", "0", "false", "no", "off", "none", "null"}
        return bool(value)

    @staticmethod
    def _truncate_text_block(text: str, max_chars: int):
        """Apply a character limit to a single text block.

        Returns:
            ``(text, truncated, char_count)``. A falsy ``max_chars`` disables
            the limit.
        """
        if max_chars and len(text) > max_chars:
            return text[:max_chars], True, max_chars
        return text, False, len(text)

    @staticmethod
    def _limit_text_chunks(chunks: List[Dict], text_key: str, max_chars: int):
        """Apply one global character budget across several text chunks.

        Chunks are consumed in order. The chunk that crosses the budget is cut
        and marked with ``"truncated": True``; any chunks after it are dropped.
        Input dicts are never mutated (kept chunks are shallow copies).

        Args:
            chunks: Dicts whose ``text_key`` entry holds the text (missing/None
                is treated as empty).
            text_key: Key of the text field inside each chunk.
            max_chars: Total budget; ``None`` or ``<= 0`` disables the limit.

        Returns:
            ``(limited_chunks, truncated, consumed_chars)``.
        """
        if max_chars is None or max_chars <= 0:
            total = sum(len(chunk.get(text_key, "") or "") for chunk in chunks)
            return chunks, False, total

        remaining = max_chars
        limited_chunks: List[Dict] = []
        truncated = False
        consumed = 0

        for chunk in chunks:
            if remaining <= 0:
                # Budget already used up: this and all later chunks are dropped.
                truncated = True
                break

            snippet = chunk.get(text_key, "") or ""
            chunk_copy = dict(chunk)
            if len(snippet) > remaining:
                chunk_copy[text_key] = snippet[:remaining]
                chunk_copy["truncated"] = True
                limited_chunks.append(chunk_copy)
                consumed += remaining
                truncated = True
                break

            limited_chunks.append(chunk_copy)
            consumed += len(snippet)
            remaining -= len(snippet)

        return limited_chunks, truncated, consumed

    def _handle_read_tool(self, arguments: Dict) -> Dict:
        """Dispatch the ``read_file`` tool to its ``read`` / ``search`` / ``extract`` mode.

        Args:
            arguments: Tool-call arguments. ``path`` is required. ``type``
                defaults to ``"read"`` (case/whitespace-insensitive).
                ``max_chars`` is clamped to ``[1, MAX_READ_FILE_CHARS]`` and
                defaults to ``READ_TOOL_DEFAULT_MAX_CHARS``. Mode-specific keys
                are described on the mode helpers.

        Returns:
            On success, a dict with ``success=True`` plus mode-specific fields.
            On invalid arguments, ``{"success": False, "error": ...}``. Failure
            dicts from ``self.file_manager`` are returned unchanged.
        """
        arguments = arguments or {}
        file_path = arguments.get("path")
        if not file_path:
            return {"success": False, "error": "缺少文件路径参数"}

        # str() keeps non-string values (numbers, bools) from crashing on
        # .lower(); blank values fall back to the default "read" mode.
        read_type = str(arguments.get("type") or "").strip().lower() or "read"
        handlers = {
            "read": self._read_tool_read,
            "search": self._read_tool_search,
            "extract": self._read_tool_extract,
        }
        handler = handlers.get(read_type)
        if handler is None:
            return {"success": False, "error": f"未知的读取类型: {read_type}"}

        max_chars = self._clamp_int(
            arguments.get("max_chars"),
            READ_TOOL_DEFAULT_MAX_CHARS,
            1,
            MAX_READ_FILE_CHARS,
        )
        base_result = {
            "success": True,
            "type": read_type,
            "path": None,
            "encoding": "utf-8",
            "max_chars": max_chars,
            "truncated": False,
        }
        return handler(file_path, arguments, base_result, max_chars)

    def _read_tool_read(self, file_path, arguments: Dict, base_result: Dict, max_chars: int) -> Dict:
        """``read`` mode: return lines ``start_line``..``end_line`` (1-based, both optional).

        The content is cut to ``max_chars``, and the file is registered with
        ``self.context_manager.load_file``.
        """
        start_line, end_line, error = self._parse_line_range(arguments)
        if error:
            return {"success": False, "error": error}

        read_result = self.file_manager.read_text_segment(
            file_path,
            start_line=start_line,
            end_line=end_line,
            size_limit=READ_TOOL_MAX_FILE_SIZE,
        )
        if not read_result.get("success"):
            return read_result

        content, truncated, char_count = self._truncate_text_block(read_result["content"], max_chars)
        base_result.update({
            "path": read_result["path"],
            "content": content,
            "line_start": read_result["line_start"],
            "line_end": read_result["line_end"],
            "total_lines": read_result["total_lines"],
            "file_size": read_result["size"],
            "char_count": char_count,
            "message": f"已读取 {read_result['path']} 的内容(行 {read_result['line_start']}~{read_result['line_end']})",
            "truncated": truncated,
        })
        self.context_manager.load_file(read_result["path"])
        return base_result

    def _read_tool_search(self, file_path, arguments: Dict, base_result: Dict, max_chars: int) -> Dict:
        """``search`` mode: find ``query`` in the file and return context snippets.

        ``max_matches``, ``context_before`` and ``context_after`` are clamped to
        the configured limits. ``case_sensitive`` accepts bools or strings like
        ``"false"``. Snippets share one ``max_chars`` budget.
        """
        query = arguments.get("query")
        if not query:
            return {"success": False, "error": "搜索模式需要提供 query 参数"}

        max_matches = self._clamp_int(
            arguments.get("max_matches"),
            READ_TOOL_DEFAULT_MAX_MATCHES,
            1,
            READ_TOOL_MAX_MATCHES,
        )
        context_before = self._clamp_int(
            arguments.get("context_before"),
            READ_TOOL_DEFAULT_CONTEXT_BEFORE,
            0,
            READ_TOOL_MAX_CONTEXT_BEFORE,
        )
        context_after = self._clamp_int(
            arguments.get("context_after"),
            READ_TOOL_DEFAULT_CONTEXT_AFTER,
            0,
            READ_TOOL_MAX_CONTEXT_AFTER,
        )
        # bool("false") is True, so string flags need explicit parsing.
        case_sensitive = self._parse_bool_flag(arguments.get("case_sensitive"))

        search_result = self.file_manager.search_text(
            file_path,
            query=query,
            max_matches=max_matches,
            context_before=context_before,
            context_after=context_after,
            case_sensitive=case_sensitive,
            size_limit=READ_TOOL_MAX_FILE_SIZE,
        )
        if not search_result.get("success"):
            return search_result

        matches = search_result["matches"]
        limited_matches, truncated, char_count = self._limit_text_chunks(matches, "snippet", max_chars)
        base_result.update({
            "path": search_result["path"],
            "file_size": search_result["size"],
            "query": query,
            "max_matches": max_matches,
            "actual_matches": len(matches),
            "returned_matches": len(limited_matches),
            "context_before": context_before,
            "context_after": context_after,
            "case_sensitive": case_sensitive,
            "matches": limited_matches,
            "char_count": char_count,
            "message": f"在 {search_result['path']} 中搜索 \"{query}\",返回 {len(limited_matches)} 条结果",
            "truncated": truncated,
        })
        return base_result

    def _read_tool_extract(self, file_path, arguments: Dict, base_result: Dict, max_chars: int) -> Dict:
        """``extract`` mode: return the requested ``segments`` (a non-empty list).

        The segment contents share one ``max_chars`` budget, and the file is
        registered with ``self.context_manager.load_file``.
        """
        segments = arguments.get("segments")
        if not isinstance(segments, list) or not segments:
            return {"success": False, "error": "extract 模式需要提供 segments 数组"}

        extract_result = self.file_manager.extract_segments(
            file_path,
            segments=segments,
            size_limit=READ_TOOL_MAX_FILE_SIZE,
        )
        if not extract_result.get("success"):
            return extract_result

        limited_segments, truncated, char_count = self._limit_text_chunks(
            extract_result["segments"],
            "content",
            max_chars,
        )
        base_result.update({
            "path": extract_result["path"],
            "segments": limited_segments,
            "file_size": extract_result["size"],
            "total_lines": extract_result["total_lines"],
            "segment_count": len(limited_segments),
            "char_count": char_count,
            "message": f"已从 {extract_result['path']} 抽取 {len(limited_segments)} 个片段",
            "truncated": truncated,
        })
        self.context_manager.load_file(extract_result["path"])
        return base_result
|