Compare commits
No commits in common. "5bd61ae80d54a9968ecbc4d509a34b0e7ae8c8e3" and "92c3e25f7a883587d8400fcc311a833e9ba51e83" have entirely different histories.
5bd61ae80d
...
92c3e25f7a
25
.gitignore
vendored
25
.gitignore
vendored
@ -1,25 +0,0 @@
|
||||
# OS & tools
|
||||
.DS_Store
|
||||
*.swp
|
||||
*.swo
|
||||
*.tmp
|
||||
|
||||
# Python artifacts
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
||||
# Runtime data (main agent)
|
||||
logs/
|
||||
data/
|
||||
project/
|
||||
users/
|
||||
webapp.pid
|
||||
|
||||
# Runtime data (sub agent)
|
||||
sub_agent/tasks/
|
||||
sub_agent/data/
|
||||
sub_agent/logs/
|
||||
sub_agent/project/
|
||||
|
||||
# Misc
|
||||
*.pid
|
||||
@ -8,7 +8,6 @@ from . import conversation as _conversation
|
||||
from . import security as _security
|
||||
from . import ui as _ui
|
||||
from . import memory as _memory
|
||||
from . import ocr as _ocr
|
||||
from . import todo as _todo
|
||||
from . import auth as _auth
|
||||
from . import sub_agent as _sub_agent
|
||||
@ -21,13 +20,12 @@ from .conversation import *
|
||||
from .security import *
|
||||
from .ui import *
|
||||
from .memory import *
|
||||
from .ocr import *
|
||||
from .todo import *
|
||||
from .auth import *
|
||||
from .sub_agent import *
|
||||
|
||||
__all__ = []
|
||||
for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent):
|
||||
for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent):
|
||||
__all__ += getattr(module, "__all__", [])
|
||||
|
||||
del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent
|
||||
del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
"""OCR 配置:DeepSeek-OCR 接口信息。"""
|
||||
|
||||
import os

# Endpoint of the OCR provider; the client appends "/v1" when it is missing.
OCR_API_BASE_URL = "https://api.siliconflow.cn"

# The API key is a credential and must not be committed to the repository.
# It is read from the OCR_API_KEY environment variable; when the variable is
# unset the key is an empty string and OCR requests will be rejected by the
# remote service.
OCR_API_KEY = os.environ.get("OCR_API_KEY", "")

# Model identifier sent with every OCR request.
OCR_MODEL_ID = "deepseek-ai/DeepSeek-OCR"

# Upper bound on tokens generated per OCR response.
OCR_MAX_TOKENS = 4096

__all__ = [
    "OCR_API_BASE_URL",
    "OCR_API_KEY",
    "OCR_MODEL_ID",
    "OCR_MAX_TOKENS",
]
|
||||
@ -4,7 +4,7 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
@ -40,7 +40,6 @@ from modules.terminal_manager import TerminalManager
|
||||
from modules.todo_manager import TodoManager
|
||||
from modules.sub_agent_manager import SubAgentManager
|
||||
from modules.webpage_extractor import extract_webpage_content, tavily_extract
|
||||
from modules.ocr_client import OCRClient
|
||||
from core.tool_config import TOOL_CATEGORIES
|
||||
from utils.api_client import DeepSeekClient
|
||||
from utils.context_manager import ContextManager
|
||||
@ -68,7 +67,6 @@ class MainTerminal:
|
||||
self.file_manager = FileManager(project_path)
|
||||
self.search_engine = SearchEngine()
|
||||
self.terminal_ops = TerminalOperator(project_path)
|
||||
self.ocr_client = OCRClient(project_path, self.file_manager)
|
||||
|
||||
# 新增:终端管理器
|
||||
self.terminal_manager = TerminalManager(
|
||||
@ -535,7 +533,6 @@ class MainTerminal:
|
||||
collected_tool_calls.append(tool_call_info)
|
||||
|
||||
# 处理工具结果用于保存
|
||||
result_data = {}
|
||||
try:
|
||||
result_data = json.loads(result)
|
||||
if tool_name == "read_file" and result_data.get("success"):
|
||||
@ -550,9 +547,7 @@ class MainTerminal:
|
||||
collected_tool_results.append({
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": tool_name,
|
||||
"content": tool_result_content,
|
||||
"system_message": result_data.get("system_message") if isinstance(result_data, dict) else None,
|
||||
"task_id": result_data.get("task_id") if isinstance(result_data, dict) else None
|
||||
"content": tool_result_content
|
||||
})
|
||||
|
||||
return result
|
||||
@ -602,9 +597,6 @@ class MainTerminal:
|
||||
tool_call_id=tool_result["tool_call_id"],
|
||||
name=tool_result["name"]
|
||||
)
|
||||
system_message = tool_result.get("system_message")
|
||||
if system_message:
|
||||
self._record_sub_agent_message(system_message, tool_result.get("task_id"), inline=False)
|
||||
|
||||
# 4. 在终端显示执行信息(不保存到历史)
|
||||
if collected_tool_calls:
|
||||
@ -615,8 +607,6 @@ class MainTerminal:
|
||||
print(f"{OUTPUT_FORMATS['file']} 创建文件")
|
||||
elif tool_name == "read_file":
|
||||
print(f"{OUTPUT_FORMATS['file']} 读取文件")
|
||||
elif tool_name == "ocr_image":
|
||||
print(f"{OUTPUT_FORMATS['file']} 图片OCR")
|
||||
elif tool_name == "modify_file":
|
||||
print(f"{OUTPUT_FORMATS['file']} 修改文件")
|
||||
elif tool_name == "delete_file":
|
||||
@ -997,21 +987,6 @@ class MainTerminal:
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ocr_image",
|
||||
"description": "使用 DeepSeek-OCR 读取图片中的文字或根据提示生成描述,仅支持本地图片路径。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "项目内的图片路径"},
|
||||
"prompt": {"type": "string", "description": "传递给 OCR 模型的提示词,如“请识别图片中的文字”,必须使用中文提示词。"}
|
||||
},
|
||||
"required": ["path", "prompt"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
@ -1422,7 +1397,7 @@ class MainTerminal:
|
||||
"target_dir": {"type": "string", "description": "项目下用于接收交付的相对目录"},
|
||||
"reference_files": {
|
||||
"type": "array",
|
||||
"description": "提供给子智能体的参考文件列表(相对路径),禁止在summary和task中直接告知子智能体引用图片的路径,必须使用本参数提供",
|
||||
"description": "提供给子智能体的参考文件列表(相对路径)",
|
||||
"items": {"type": "string"},
|
||||
"maxItems": 10
|
||||
},
|
||||
@ -1503,12 +1478,6 @@ class MainTerminal:
|
||||
try:
|
||||
if tool_name == "read_file":
|
||||
result = self._handle_read_tool(arguments)
|
||||
elif tool_name == "ocr_image":
|
||||
path = arguments.get("path")
|
||||
prompt = arguments.get("prompt")
|
||||
if not path:
|
||||
return json.dumps({"success": False, "error": "缺少 path 参数", "warnings": []}, ensure_ascii=False)
|
||||
result = self.ocr_client.ocr_image(path=path, prompt=prompt or "")
|
||||
|
||||
# 终端会话管理工具
|
||||
elif tool_name == "terminal_session":
|
||||
@ -2034,12 +2003,15 @@ class MainTerminal:
|
||||
agent_id=arguments.get("agent_id"),
|
||||
timeout_seconds=wait_timeout
|
||||
)
|
||||
self._record_sub_agent_message(result.get("system_message"), result.get("task_id"), inline=False)
|
||||
|
||||
elif tool_name == "close_sub_agent":
|
||||
result = self.sub_agent_manager.terminate_sub_agent(
|
||||
task_id=arguments.get("task_id"),
|
||||
agent_id=arguments.get("agent_id")
|
||||
)
|
||||
message = result.get("message") or result.get("error")
|
||||
self._record_sub_agent_message(message, result.get("task_id"), inline=False)
|
||||
|
||||
else:
|
||||
result = {"success": False, "error": f"未知工具: {tool_name}"}
|
||||
@ -2067,30 +2039,6 @@ class MainTerminal:
|
||||
# 构建上下文
|
||||
return self.context_manager.build_main_context(memory)
|
||||
|
||||
def _tool_calls_followed_by_tools(self, conversation: List[Dict], start_idx: int, tool_calls: List[Dict]) -> bool:
|
||||
"""判断指定助手消息的工具调用是否拥有后续的工具响应。"""
|
||||
if not tool_calls:
|
||||
return False
|
||||
expected_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
|
||||
if not expected_ids:
|
||||
return False
|
||||
matched: Set[str] = set()
|
||||
idx = start_idx + 1
|
||||
total = len(conversation)
|
||||
while idx < total and len(matched) < len(expected_ids):
|
||||
next_conv = conversation[idx]
|
||||
role = next_conv.get("role")
|
||||
if role == "tool":
|
||||
call_id = next_conv.get("tool_call_id")
|
||||
if call_id in expected_ids:
|
||||
matched.add(call_id)
|
||||
else:
|
||||
break
|
||||
elif role in ("assistant", "user"):
|
||||
break
|
||||
idx += 1
|
||||
return len(matched) == len(expected_ids)
|
||||
|
||||
def build_messages(self, context: Dict, user_input: str) -> List[Dict]:
|
||||
"""构建消息列表(添加终端内容注入)"""
|
||||
# 加载系统提示
|
||||
@ -2124,8 +2072,7 @@ class MainTerminal:
|
||||
messages.append({"role": "system", "content": thinking_prompt})
|
||||
|
||||
# 添加对话历史(保留完整结构,包括tool_calls和tool消息)
|
||||
conversation = context["conversation"]
|
||||
for idx, conv in enumerate(conversation):
|
||||
for conv in context["conversation"]:
|
||||
metadata = conv.get("metadata") or {}
|
||||
if conv["role"] == "assistant":
|
||||
# Assistant消息可能包含工具调用
|
||||
@ -2134,9 +2081,8 @@ class MainTerminal:
|
||||
"content": conv["content"]
|
||||
}
|
||||
# 如果有工具调用信息,添加到消息中
|
||||
tool_calls = conv.get("tool_calls") or []
|
||||
if tool_calls and self._tool_calls_followed_by_tools(conversation, idx, tool_calls):
|
||||
message["tool_calls"] = tool_calls
|
||||
if "tool_calls" in conv and conv["tool_calls"]:
|
||||
message["tool_calls"] = conv["tool_calls"]
|
||||
messages.append(message)
|
||||
|
||||
elif conv["role"] == "tool":
|
||||
|
||||
@ -32,7 +32,7 @@ TOOL_CATEGORIES: Dict[str, ToolCategory] = {
|
||||
),
|
||||
"read_focus": ToolCategory(
|
||||
label="阅读聚焦",
|
||||
tools=["read_file", "focus_file", "unfocus_file", "ocr_image"],
|
||||
tools=["read_file", "focus_file", "unfocus_file"],
|
||||
),
|
||||
"terminal_realtime": ToolCategory(
|
||||
label="实时终端",
|
||||
|
||||
@ -1,104 +0,0 @@
|
||||
"""DeepSeek-OCR 客户端(主智能体专用)。"""
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import httpx
|
||||
from openai import OpenAI
|
||||
|
||||
from config import OCR_API_BASE_URL, OCR_API_KEY, OCR_MODEL_ID, OCR_MAX_TOKENS
|
||||
from modules.file_manager import FileManager
|
||||
|
||||
|
||||
class OCRClient:
    """Wrapper around the DeepSeek-OCR chat-completions call (main agent).

    An image is addressed by path, validated through the injected file
    manager, size-checked, base64-encoded into a data URL and sent together
    with a text prompt. ``ocr_image`` reports failures through its returned
    dict instead of raising.

    The instance owns an ``httpx.Client``; call ``close()`` (or use the
    instance as a context manager) to release it.
    """

    def __init__(self, project_path: str, file_manager: "FileManager"):
        """Build the HTTP and OpenAI clients from the ``OCR_*`` config values.

        Args:
            project_path: Project root; stored resolved in ``self.project_path``.
            file_manager: Object whose ``_validate_path`` is used to validate
                image paths.
        """
        self.project_path = Path(project_path).resolve()
        self.file_manager = file_manager

        # Normalise the base URL so it ends in /v1 whether or not the
        # configured value already contains it.
        base_url = (OCR_API_BASE_URL or "").rstrip("/")
        if not base_url.endswith("/v1"):
            base_url = f"{base_url}/v1"

        # httpx 0.28 no longer accepts the ``proxies`` argument; passing an
        # explicit http_client avoids the error raised by the default wrapper.
        self.http_client = httpx.Client()
        self.client = OpenAI(
            api_key=OCR_API_KEY,
            base_url=base_url,
            http_client=self.http_client,
        )
        self.model = OCR_MODEL_ID or "deepseek-ai/DeepSeek-OCR"
        self.max_tokens = OCR_MAX_TOKENS or 4096

        # Default size limit (10MB); larger images are rejected.
        self.max_image_size = 10 * 1024 * 1024

    def close(self) -> None:
        """Close the owned HTTP client, if one was created."""
        http_client = getattr(self, "http_client", None)
        if http_client is not None:
            http_client.close()

    def __enter__(self):
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the HTTP client on leaving the ``with`` block."""
        self.close()

    def _validate_image_path(self, path: str):
        """Validate ``path`` with the file manager and check it is a file.

        Returns:
            A ``(valid, error, full_path)`` tuple. On failure ``valid`` is
            False, ``error`` holds the message and ``full_path`` is None; on
            success ``error`` is an empty string.
        """
        valid, error, full_path = self.file_manager._validate_path(path)
        if not valid:
            return False, error, None
        if not full_path.exists():
            return False, "文件不存在", None
        if not full_path.is_file():
            return False, "不是文件", None
        return True, "", full_path

    def ocr_image(self, path: str, prompt: str) -> Dict:
        """Run OCR on one image and return a minimal result dict.

        Args:
            path: Image path, validated through the file manager.
            prompt: Text prompt sent alongside the image; must be non-blank.

        Returns:
            ``{"success": True, "content": str, "warnings": list}`` on
            success, otherwise ``{"success": False, "error": str,
            "warnings": list}``.
        """
        warnings: List[str] = []

        valid, error, full_path = self._validate_image_path(path)
        if not valid:
            return {"success": False, "error": error, "warnings": warnings}

        if not prompt or not str(prompt).strip():
            return {"success": False, "error": "prompt 不能为空", "warnings": warnings}

        # Check the on-disk size first so an oversized file is rejected
        # without being loaded into memory.
        try:
            size = full_path.stat().st_size
        except OSError as exc:
            return {"success": False, "error": f"读取文件失败: {exc}", "warnings": warnings}

        if size <= 0:
            return {"success": False, "error": "文件为空,无法识别", "warnings": warnings}

        if size > self.max_image_size:
            return {
                "success": False,
                "error": f"图片过大({size}字节),上限为{self.max_image_size}字节",
                "warnings": warnings,
            }

        try:
            data = full_path.read_bytes()
        except OSError as exc:
            return {"success": False, "error": f"读取文件失败: {exc}", "warnings": warnings}

        # Fall back to JPEG when the extension does not map to an image type.
        mime_type, _ = mimetypes.guess_type(str(full_path))
        if not mime_type or not mime_type.startswith("image/"):
            warnings.append("无法确定图片类型,已按 JPEG 处理")
            mime_type = "image/jpeg"

        base64_image = base64.b64encode(data).decode("utf-8")
        data_url = f"data:{mime_type};base64,{base64_image}"

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "image_url", "image_url": {"url": data_url}},
                            {"type": "text", "text": prompt},
                        ],
                    }
                ],
                max_tokens=self.max_tokens,
                temperature=0,
            )
            content = response.choices[0].message.content if response.choices else ""
            return {"success": True, "content": content or "", "warnings": warnings}
        except Exception as exc:
            # Boundary catch: any SDK or transport failure is reported in the
            # result dict rather than propagated to the tool dispatcher.
            return {"success": False, "error": f"OCR 调用失败: {exc}", "warnings": warnings}
|
||||
@ -566,9 +566,11 @@ class SubAgentManager:
|
||||
"""返回子智能体任务概览,用于前端展示。"""
|
||||
overview: List[Dict[str, Any]] = []
|
||||
for task_id, task in self.tasks.items():
|
||||
status = task.get("status")
|
||||
if status not in {"pending", "running"}:
|
||||
continue
|
||||
if conversation_id and task.get("conversation_id") != conversation_id:
|
||||
continue
|
||||
|
||||
snapshot = {
|
||||
"task_id": task_id,
|
||||
"agent_id": task.get("agent_id"),
|
||||
@ -583,8 +585,7 @@ class SubAgentManager:
|
||||
"conversation_id": task.get("conversation_id"),
|
||||
"sub_conversation_id": task.get("sub_conversation_id"),
|
||||
}
|
||||
|
||||
# 运行中的任务尝试同步远端最新状态
|
||||
# 对于运行中的任务,尝试获取最新状态
|
||||
if snapshot["status"] not in TERMINAL_STATUSES:
|
||||
remote = self._call_service("GET", f"/tasks/{task_id}", timeout=5)
|
||||
if remote.get("success"):
|
||||
@ -592,13 +593,5 @@ class SubAgentManager:
|
||||
snapshot["remote_message"] = remote.get("message")
|
||||
snapshot["last_tool"] = remote.get("last_tool")
|
||||
task["last_tool"] = snapshot["last_tool"]
|
||||
else:
|
||||
# 已结束的任务带上最终结果/系统消息,方便前端展示
|
||||
final_result = task.get("final_result") or {}
|
||||
snapshot["final_message"] = final_result.get("system_message") or final_result.get("message")
|
||||
snapshot["success"] = final_result.get("success")
|
||||
|
||||
overview.append(snapshot)
|
||||
|
||||
overview.sort(key=lambda item: item.get("created_at") or 0, reverse=True)
|
||||
return overview
|
||||
|
||||
@ -8,7 +8,6 @@ from . import conversation as _conversation
|
||||
from . import security as _security
|
||||
from . import ui as _ui
|
||||
from . import memory as _memory
|
||||
from . import ocr as _ocr
|
||||
from . import todo as _todo
|
||||
from . import auth as _auth
|
||||
from . import sub_agent as _sub_agent
|
||||
@ -21,13 +20,12 @@ from .conversation import *
|
||||
from .security import *
|
||||
from .ui import *
|
||||
from .memory import *
|
||||
from .ocr import *
|
||||
from .todo import *
|
||||
from .auth import *
|
||||
from .sub_agent import *
|
||||
|
||||
__all__ = []
|
||||
for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent):
|
||||
for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent):
|
||||
__all__ += getattr(module, "__all__", [])
|
||||
|
||||
del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent
|
||||
del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
"""OCR 配置:DeepSeek-OCR 接口信息(子智能体)。"""
|
||||
|
||||
import os

# Endpoint of the OCR provider used by the sub agent; the client appends
# "/v1" when it is missing.
OCR_API_BASE_URL = "https://api.siliconflow.cn"

# The API key is a credential and must not be committed to the repository.
# It is read from the OCR_API_KEY environment variable; when the variable is
# unset the key is an empty string and OCR requests will be rejected by the
# remote service.
OCR_API_KEY = os.environ.get("OCR_API_KEY", "")

# Model identifier sent with every OCR request.
OCR_MODEL_ID = "deepseek-ai/DeepSeek-OCR"

# Upper bound on tokens generated per OCR response.
OCR_MAX_TOKENS = 4096

__all__ = [
    "OCR_API_BASE_URL",
    "OCR_API_KEY",
    "OCR_MODEL_ID",
    "OCR_MAX_TOKENS",
]
|
||||
@ -4,7 +4,7 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
@ -40,7 +40,6 @@ from modules.terminal_manager import TerminalManager
|
||||
from modules.todo_manager import TodoManager
|
||||
from modules.sub_agent_manager import SubAgentManager
|
||||
from modules.webpage_extractor import extract_webpage_content, tavily_extract
|
||||
from modules.ocr_client import OCRClient
|
||||
from core.tool_config import TOOL_CATEGORIES
|
||||
from utils.api_client import DeepSeekClient
|
||||
from utils.context_manager import ContextManager
|
||||
@ -68,7 +67,6 @@ class MainTerminal:
|
||||
self.file_manager = FileManager(project_path)
|
||||
self.search_engine = SearchEngine()
|
||||
self.terminal_ops = TerminalOperator(project_path)
|
||||
self.ocr_client = OCRClient(project_path, self.file_manager)
|
||||
|
||||
# 新增:终端管理器
|
||||
self.terminal_manager = TerminalManager(
|
||||
@ -609,8 +607,6 @@ class MainTerminal:
|
||||
print(f"{OUTPUT_FORMATS['file']} 创建文件")
|
||||
elif tool_name == "read_file":
|
||||
print(f"{OUTPUT_FORMATS['file']} 读取文件")
|
||||
elif tool_name == "ocr_image":
|
||||
print(f"{OUTPUT_FORMATS['file']} 图片OCR")
|
||||
elif tool_name == "modify_file":
|
||||
print(f"{OUTPUT_FORMATS['file']} 修改文件")
|
||||
elif tool_name == "delete_file":
|
||||
@ -991,21 +987,6 @@ class MainTerminal:
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ocr_image",
|
||||
"description": "使用 DeepSeek-OCR 读取图片中的文字或根据提示生成描述,仅支持本地图片路径。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "项目内的图片路径"},
|
||||
"prompt": {"type": "string", "description": "传递给 OCR 模型的提示词,如“请识别图片中的文字”"}
|
||||
},
|
||||
"required": ["path", "prompt"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
@ -1483,12 +1464,6 @@ class MainTerminal:
|
||||
try:
|
||||
if tool_name == "read_file":
|
||||
result = self._handle_read_tool(arguments)
|
||||
elif tool_name == "ocr_image":
|
||||
path = arguments.get("path")
|
||||
prompt = arguments.get("prompt")
|
||||
if not path:
|
||||
return json.dumps({"success": False, "error": "缺少 path 参数", "warnings": []}, ensure_ascii=False)
|
||||
result = self.ocr_client.ocr_image(path=path, prompt=prompt or "")
|
||||
|
||||
# 终端会话管理工具
|
||||
elif tool_name == "terminal_session":
|
||||
@ -2005,6 +1980,7 @@ class MainTerminal:
|
||||
agent_id=arguments.get("agent_id"),
|
||||
timeout_seconds=arguments.get("timeout_seconds")
|
||||
)
|
||||
self._record_sub_agent_message(result.get("system_message"), result.get("task_id"), inline=False)
|
||||
|
||||
else:
|
||||
result = {"success": False, "error": f"未知工具: {tool_name}"}
|
||||
@ -2032,29 +2008,6 @@ class MainTerminal:
|
||||
# 构建上下文
|
||||
return self.context_manager.build_main_context(memory)
|
||||
|
||||
def _tool_calls_followed_by_tools(self, conversation: List[Dict], start_idx: int, tool_calls: List[Dict]) -> bool:
|
||||
if not tool_calls:
|
||||
return False
|
||||
expected_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
|
||||
if not expected_ids:
|
||||
return False
|
||||
matched: Set[str] = set()
|
||||
idx = start_idx + 1
|
||||
total = len(conversation)
|
||||
while idx < total and len(matched) < len(expected_ids):
|
||||
next_conv = conversation[idx]
|
||||
role = next_conv.get("role")
|
||||
if role == "tool":
|
||||
call_id = next_conv.get("tool_call_id")
|
||||
if call_id in expected_ids:
|
||||
matched.add(call_id)
|
||||
else:
|
||||
break
|
||||
elif role in ("assistant", "user"):
|
||||
break
|
||||
idx += 1
|
||||
return len(matched) == len(expected_ids)
|
||||
|
||||
def build_messages(self, context: Dict, user_input: str) -> List[Dict]:
|
||||
"""构建消息列表(添加终端内容注入)"""
|
||||
# 加载系统提示
|
||||
@ -2078,8 +2031,7 @@ class MainTerminal:
|
||||
messages.append({"role": "system", "content": todo_prompt})
|
||||
|
||||
# 添加对话历史(保留完整结构,包括tool_calls和tool消息)
|
||||
conversation = context["conversation"]
|
||||
for idx, conv in enumerate(conversation):
|
||||
for conv in context["conversation"]:
|
||||
metadata = conv.get("metadata") or {}
|
||||
if conv["role"] == "assistant":
|
||||
# Assistant消息可能包含工具调用
|
||||
@ -2088,9 +2040,8 @@ class MainTerminal:
|
||||
"content": conv["content"]
|
||||
}
|
||||
# 如果有工具调用信息,添加到消息中
|
||||
tool_calls = conv.get("tool_calls") or []
|
||||
if tool_calls and self._tool_calls_followed_by_tools(conversation, idx, tool_calls):
|
||||
message["tool_calls"] = tool_calls
|
||||
if "tool_calls" in conv and conv["tool_calls"]:
|
||||
message["tool_calls"] = conv["tool_calls"]
|
||||
messages.append(message)
|
||||
|
||||
elif conv["role"] == "tool":
|
||||
|
||||
@ -32,7 +32,7 @@ TOOL_CATEGORIES: Dict[str, ToolCategory] = {
|
||||
),
|
||||
"read_focus": ToolCategory(
|
||||
label="阅读聚焦",
|
||||
tools=["read_file", "focus_file", "unfocus_file", "ocr_image"],
|
||||
tools=["read_file", "focus_file", "unfocus_file"],
|
||||
),
|
||||
"terminal_realtime": ToolCategory(
|
||||
label="实时终端",
|
||||
|
||||
@ -1,98 +0,0 @@
|
||||
"""DeepSeek-OCR 客户端(子智能体专用)。"""
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import httpx
|
||||
from openai import OpenAI
|
||||
|
||||
from config import OCR_API_BASE_URL, OCR_API_KEY, OCR_MODEL_ID, OCR_MAX_TOKENS
|
||||
from modules.file_manager import FileManager
|
||||
|
||||
|
||||
class OCRClient:
    """Wrapper around the DeepSeek-OCR chat-completions call (sub agent).

    ``ocr_image`` validates an image path through the injected file manager,
    enforces a size limit, encodes the bytes as a base64 data URL and sends
    them with a text prompt. Failures are returned in the result dict rather
    than raised.

    The instance owns an ``httpx.Client``; release it with ``close()`` or by
    using the instance as a context manager.
    """

    def __init__(self, project_path: str, file_manager: "FileManager"):
        """Create the HTTP and OpenAI clients from the ``OCR_*`` config values.

        Args:
            project_path: Project root; kept resolved in ``self.project_path``.
            file_manager: Object providing ``_validate_path`` for path checks.
        """
        self.project_path = Path(project_path).resolve()
        self.file_manager = file_manager

        # Ensure the base URL ends in /v1 regardless of how it is configured.
        api_root = (OCR_API_BASE_URL or "").rstrip("/")
        if not api_root.endswith("/v1"):
            api_root = f"{api_root}/v1"

        # The OpenAI SDK is handed an explicit httpx client owned by this
        # instance, so it can be closed deterministically in ``close()``.
        self.http_client = httpx.Client()
        self.client = OpenAI(
            api_key=OCR_API_KEY,
            base_url=api_root,
            http_client=self.http_client,
        )
        self.model = OCR_MODEL_ID or "deepseek-ai/DeepSeek-OCR"
        self.max_tokens = OCR_MAX_TOKENS or 4096
        self.max_image_size = 10 * 1024 * 1024  # 10MB

    def close(self) -> None:
        """Close the owned HTTP client, if one was created."""
        owned = getattr(self, "http_client", None)
        if owned is not None:
            owned.close()

    def __enter__(self):
        """Return the client itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the HTTP client when the ``with`` block ends."""
        self.close()

    def _validate_image_path(self, path: str):
        """Validate ``path`` via the file manager and require a regular file.

        Returns:
            ``(valid, error, full_path)``; ``full_path`` is None and
            ``error`` is set when validation fails, and ``error`` is an
            empty string when it succeeds.
        """
        ok, reason, resolved = self.file_manager._validate_path(path)
        if not ok:
            return False, reason, None
        if not resolved.exists():
            return False, "文件不存在", None
        if not resolved.is_file():
            return False, "不是文件", None
        return True, "", resolved

    def ocr_image(self, path: str, prompt: str) -> Dict:
        """Run OCR on a single image.

        Args:
            path: Image path, validated through the file manager.
            prompt: Non-blank text prompt sent with the image.

        Returns:
            ``{"success": True, "content": str, "warnings": list}`` when the
            request succeeds, otherwise ``{"success": False, "error": str,
            "warnings": list}``.
        """
        warnings: List[str] = []

        def failure(message: str) -> Dict:
            # Every error path shares the same result shape.
            return {"success": False, "error": message, "warnings": warnings}

        valid, error, full_path = self._validate_image_path(path)
        if not valid:
            return failure(error)

        if not prompt or not str(prompt).strip():
            return failure("prompt 不能为空")

        # Inspect the on-disk size before reading so that a file over the
        # limit is never loaded into memory.
        try:
            size = full_path.stat().st_size
        except OSError as exc:
            return failure(f"读取文件失败: {exc}")

        if size <= 0:
            return failure("文件为空,无法识别")

        if size > self.max_image_size:
            return failure(f"图片过大({size}字节),上限为{self.max_image_size}字节")

        try:
            data = full_path.read_bytes()
        except OSError as exc:
            return failure(f"读取文件失败: {exc}")

        # Unknown or non-image extensions are sent as JPEG, with a warning.
        mime_type, _ = mimetypes.guess_type(str(full_path))
        if not mime_type or not mime_type.startswith("image/"):
            warnings.append("无法确定图片类型,已按 JPEG 处理")
            mime_type = "image/jpeg"

        encoded = base64.b64encode(data).decode("utf-8")
        data_url = f"data:{mime_type};base64,{encoded}"

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "image_url", "image_url": {"url": data_url}},
                            {"type": "text", "text": prompt},
                        ],
                    }
                ],
                max_tokens=self.max_tokens,
                temperature=0,
            )
            content = response.choices[0].message.content if response.choices else ""
            return {"success": True, "content": content or "", "warnings": warnings}
        except Exception as exc:
            # Boundary catch: SDK and transport errors are reported in the
            # result dict instead of propagating to the caller.
            return failure(f"OCR 调用失败: {exc}")
|
||||
@ -85,7 +85,6 @@ sub_agent_rooms: Dict[str, set] = defaultdict(set)
|
||||
sub_agent_connections: Dict[str, str] = {}
|
||||
SUB_AGENT_TERMINAL_STATUSES = {"completed", "failed", "timeout"}
|
||||
STOPPING_GRACE_SECONDS = 30
|
||||
TERMINAL_ARCHIVE_GRACE_SECONDS = 20
|
||||
|
||||
def format_read_file_result(result_data: Dict) -> str:
|
||||
"""格式化 read_file 工具的输出,便于在Web端展示。"""
|
||||
@ -254,9 +253,7 @@ def cleanup_inactive_sub_agent_tasks(force: bool = False):
|
||||
for task_id, task in list(sub_agent_tasks.items()):
|
||||
status = (task.get("status") or "").lower()
|
||||
if status in SUB_AGENT_TERMINAL_STATUSES:
|
||||
updated_at = task.get("updated_at") or task.get("created_at") or now
|
||||
if force or (now - updated_at) > TERMINAL_ARCHIVE_GRACE_SECONDS:
|
||||
_purge_sub_agent_task(task_id)
|
||||
_purge_sub_agent_task(task_id)
|
||||
continue
|
||||
if status == "stopping":
|
||||
updated_at = task.get("updated_at") or task.get("created_at") or now
|
||||
@ -3901,6 +3898,16 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
|
||||
last_tool_name = tool_name
|
||||
|
||||
# ===== 增量保存:保存工具调用信息 =====
|
||||
if tool_calls:
|
||||
# 保存assistant消息(只包含工具调用信息,内容为空)
|
||||
web_terminal.context_manager.add_conversation(
|
||||
"assistant",
|
||||
"", # 空内容,只记录工具调用
|
||||
tool_calls
|
||||
)
|
||||
debug_log(f"💾 增量保存:工具调用信息 ({len(tool_calls)} 个工具)")
|
||||
|
||||
# 更新统计
|
||||
total_tool_calls += len(tool_calls)
|
||||
|
||||
@ -4005,7 +4012,6 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
await asyncio.sleep(1.5 - execution_time)
|
||||
|
||||
# 更新工具状态
|
||||
result_data = {}
|
||||
try:
|
||||
result_data = json.loads(tool_result)
|
||||
except:
|
||||
@ -4118,9 +4124,6 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
name=function_name
|
||||
)
|
||||
debug_log(f"💾 增量保存:工具结果 {function_name}")
|
||||
system_message = result_data.get("system_message") if isinstance(result_data, dict) else None
|
||||
if system_message:
|
||||
web_terminal._record_sub_agent_message(system_message, result_data.get("task_id"), inline=False)
|
||||
|
||||
# 添加到消息历史(用于API继续对话)
|
||||
messages.append({
|
||||
|
||||
@ -266,6 +266,9 @@ class DeepSeekClient:
|
||||
"stream": stream,
|
||||
"max_tokens": max_tokens
|
||||
}
|
||||
if current_thinking_mode:
|
||||
payload["thinking"] = {"type": "enabled"}
|
||||
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = "auto"
|
||||
|
||||
@ -3137,7 +3137,8 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
}
|
||||
|
||||
messages.append(assistant_message)
|
||||
if assistant_content or current_thinking or tool_calls:
|
||||
|
||||
if assistant_content or current_thinking:
|
||||
web_terminal.context_manager.add_conversation(
|
||||
"assistant",
|
||||
assistant_content,
|
||||
@ -3208,6 +3209,17 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
consecutive_same_tool[tool_name] = 1
|
||||
|
||||
last_tool_name = tool_name
|
||||
|
||||
# ===== 增量保存:保存工具调用信息 =====
|
||||
if tool_calls:
|
||||
# 保存assistant消息(只包含工具调用信息,内容为空)
|
||||
web_terminal.context_manager.add_conversation(
|
||||
"assistant",
|
||||
"", # 空内容,只记录工具调用
|
||||
tool_calls
|
||||
)
|
||||
debug_log(f"💾 增量保存:工具调用信息 ({len(tool_calls)} 个工具)")
|
||||
|
||||
# 更新统计
|
||||
total_tool_calls += len(tool_calls)
|
||||
|
||||
@ -3310,7 +3322,6 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
await asyncio.sleep(1.5 - execution_time)
|
||||
|
||||
# 更新工具状态
|
||||
result_data = {}
|
||||
try:
|
||||
result_data = json.loads(tool_result)
|
||||
except:
|
||||
@ -3423,9 +3434,6 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
|
||||
name=function_name
|
||||
)
|
||||
debug_log(f"💾 增量保存:工具结果 {function_name}")
|
||||
system_message = result_data.get("system_message") if isinstance(result_data, dict) else None
|
||||
if system_message:
|
||||
web_terminal._record_sub_agent_message(system_message, result_data.get("task_id"), inline=False)
|
||||
|
||||
# 添加到消息历史(用于API继续对话)
|
||||
messages.append({
|
||||
|
||||
Loading…
Reference in New Issue
Block a user