Compare commits

..

7 Commits

15 changed files with 404 additions and 51 deletions

25
.gitignore vendored Normal file
View File

@ -0,0 +1,25 @@
# OS & tools
.DS_Store
*.swp
*.swo
*.tmp
# Python artifacts
__pycache__/
*.pyc
# Runtime data (main agent)
logs/
data/
project/
users/
webapp.pid
# Runtime data (sub agent)
sub_agent/tasks/
sub_agent/data/
sub_agent/logs/
sub_agent/project/
# Misc
*.pid

View File

@ -8,6 +8,7 @@ from . import conversation as _conversation
from . import security as _security from . import security as _security
from . import ui as _ui from . import ui as _ui
from . import memory as _memory from . import memory as _memory
from . import ocr as _ocr
from . import todo as _todo from . import todo as _todo
from . import auth as _auth from . import auth as _auth
from . import sub_agent as _sub_agent from . import sub_agent as _sub_agent
@ -20,12 +21,13 @@ from .conversation import *
from .security import * from .security import *
from .ui import * from .ui import *
from .memory import * from .memory import *
from .ocr import *
from .todo import * from .todo import *
from .auth import * from .auth import *
from .sub_agent import * from .sub_agent import *
__all__ = [] __all__ = []
for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent): for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent):
__all__ += getattr(module, "__all__", []) __all__ += getattr(module, "__all__", [])
del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent

13
config/ocr.py Normal file
View File

@ -0,0 +1,13 @@
"""OCR 配置DeepSeek-OCR 接口信息。"""
OCR_API_BASE_URL = "https://api.siliconflow.cn"
OCR_API_KEY = "sk-suqqgewtlwajjkylvnotdhkzmsrshmrqptkakdqjmlrilaes"
OCR_MODEL_ID = "deepseek-ai/DeepSeek-OCR"
OCR_MAX_TOKENS = 4096
__all__ = [
"OCR_API_BASE_URL",
"OCR_API_KEY",
"OCR_MODEL_ID",
"OCR_MAX_TOKENS",
]

View File

@ -4,7 +4,7 @@ import asyncio
import json import json
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional, Set
from datetime import datetime from datetime import datetime
try: try:
@ -40,6 +40,7 @@ from modules.terminal_manager import TerminalManager
from modules.todo_manager import TodoManager from modules.todo_manager import TodoManager
from modules.sub_agent_manager import SubAgentManager from modules.sub_agent_manager import SubAgentManager
from modules.webpage_extractor import extract_webpage_content, tavily_extract from modules.webpage_extractor import extract_webpage_content, tavily_extract
from modules.ocr_client import OCRClient
from core.tool_config import TOOL_CATEGORIES from core.tool_config import TOOL_CATEGORIES
from utils.api_client import DeepSeekClient from utils.api_client import DeepSeekClient
from utils.context_manager import ContextManager from utils.context_manager import ContextManager
@ -67,6 +68,7 @@ class MainTerminal:
self.file_manager = FileManager(project_path) self.file_manager = FileManager(project_path)
self.search_engine = SearchEngine() self.search_engine = SearchEngine()
self.terminal_ops = TerminalOperator(project_path) self.terminal_ops = TerminalOperator(project_path)
self.ocr_client = OCRClient(project_path, self.file_manager)
# 新增:终端管理器 # 新增:终端管理器
self.terminal_manager = TerminalManager( self.terminal_manager = TerminalManager(
@ -533,6 +535,7 @@ class MainTerminal:
collected_tool_calls.append(tool_call_info) collected_tool_calls.append(tool_call_info)
# 处理工具结果用于保存 # 处理工具结果用于保存
result_data = {}
try: try:
result_data = json.loads(result) result_data = json.loads(result)
if tool_name == "read_file" and result_data.get("success"): if tool_name == "read_file" and result_data.get("success"):
@ -547,7 +550,9 @@ class MainTerminal:
collected_tool_results.append({ collected_tool_results.append({
"tool_call_id": tool_call_id, "tool_call_id": tool_call_id,
"name": tool_name, "name": tool_name,
"content": tool_result_content "content": tool_result_content,
"system_message": result_data.get("system_message") if isinstance(result_data, dict) else None,
"task_id": result_data.get("task_id") if isinstance(result_data, dict) else None
}) })
return result return result
@ -597,6 +602,9 @@ class MainTerminal:
tool_call_id=tool_result["tool_call_id"], tool_call_id=tool_result["tool_call_id"],
name=tool_result["name"] name=tool_result["name"]
) )
system_message = tool_result.get("system_message")
if system_message:
self._record_sub_agent_message(system_message, tool_result.get("task_id"), inline=False)
# 4. 在终端显示执行信息(不保存到历史) # 4. 在终端显示执行信息(不保存到历史)
if collected_tool_calls: if collected_tool_calls:
@ -607,6 +615,8 @@ class MainTerminal:
print(f"{OUTPUT_FORMATS['file']} 创建文件") print(f"{OUTPUT_FORMATS['file']} 创建文件")
elif tool_name == "read_file": elif tool_name == "read_file":
print(f"{OUTPUT_FORMATS['file']} 读取文件") print(f"{OUTPUT_FORMATS['file']} 读取文件")
elif tool_name == "ocr_image":
print(f"{OUTPUT_FORMATS['file']} 图片OCR")
elif tool_name == "modify_file": elif tool_name == "modify_file":
print(f"{OUTPUT_FORMATS['file']} 修改文件") print(f"{OUTPUT_FORMATS['file']} 修改文件")
elif tool_name == "delete_file": elif tool_name == "delete_file":
@ -987,6 +997,21 @@ class MainTerminal:
} }
} }
}, },
{
"type": "function",
"function": {
"name": "ocr_image",
"description": "使用 DeepSeek-OCR 读取图片中的文字或根据提示生成描述,仅支持本地图片路径。",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "项目内的图片路径"},
"prompt": {"type": "string", "description": "传递给 OCR 模型的提示词,如“请识别图片中的文字”,必须使用中文提示词。"}
},
"required": ["path", "prompt"]
}
}
},
{ {
"type": "function", "type": "function",
"function": { "function": {
@ -1397,7 +1422,7 @@ class MainTerminal:
"target_dir": {"type": "string", "description": "项目下用于接收交付的相对目录"}, "target_dir": {"type": "string", "description": "项目下用于接收交付的相对目录"},
"reference_files": { "reference_files": {
"type": "array", "type": "array",
"description": "提供给子智能体的参考文件列表(相对路径)", "description": "提供给子智能体的参考文件列表(相对路径)禁止在summary和task中直接告知子智能体引用图片的路径必须使用本参数提供",
"items": {"type": "string"}, "items": {"type": "string"},
"maxItems": 10 "maxItems": 10
}, },
@ -1478,6 +1503,12 @@ class MainTerminal:
try: try:
if tool_name == "read_file": if tool_name == "read_file":
result = self._handle_read_tool(arguments) result = self._handle_read_tool(arguments)
elif tool_name == "ocr_image":
path = arguments.get("path")
prompt = arguments.get("prompt")
if not path:
return json.dumps({"success": False, "error": "缺少 path 参数", "warnings": []}, ensure_ascii=False)
result = self.ocr_client.ocr_image(path=path, prompt=prompt or "")
# 终端会话管理工具 # 终端会话管理工具
elif tool_name == "terminal_session": elif tool_name == "terminal_session":
@ -2003,15 +2034,12 @@ class MainTerminal:
agent_id=arguments.get("agent_id"), agent_id=arguments.get("agent_id"),
timeout_seconds=wait_timeout timeout_seconds=wait_timeout
) )
self._record_sub_agent_message(result.get("system_message"), result.get("task_id"), inline=False)
elif tool_name == "close_sub_agent": elif tool_name == "close_sub_agent":
result = self.sub_agent_manager.terminate_sub_agent( result = self.sub_agent_manager.terminate_sub_agent(
task_id=arguments.get("task_id"), task_id=arguments.get("task_id"),
agent_id=arguments.get("agent_id") agent_id=arguments.get("agent_id")
) )
message = result.get("message") or result.get("error")
self._record_sub_agent_message(message, result.get("task_id"), inline=False)
else: else:
result = {"success": False, "error": f"未知工具: {tool_name}"} result = {"success": False, "error": f"未知工具: {tool_name}"}
@ -2039,6 +2067,30 @@ class MainTerminal:
# 构建上下文 # 构建上下文
return self.context_manager.build_main_context(memory) return self.context_manager.build_main_context(memory)
def _tool_calls_followed_by_tools(self, conversation: List[Dict], start_idx: int, tool_calls: List[Dict]) -> bool:
"""判断指定助手消息的工具调用是否拥有后续的工具响应。"""
if not tool_calls:
return False
expected_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
if not expected_ids:
return False
matched: Set[str] = set()
idx = start_idx + 1
total = len(conversation)
while idx < total and len(matched) < len(expected_ids):
next_conv = conversation[idx]
role = next_conv.get("role")
if role == "tool":
call_id = next_conv.get("tool_call_id")
if call_id in expected_ids:
matched.add(call_id)
else:
break
elif role in ("assistant", "user"):
break
idx += 1
return len(matched) == len(expected_ids)
def build_messages(self, context: Dict, user_input: str) -> List[Dict]: def build_messages(self, context: Dict, user_input: str) -> List[Dict]:
"""构建消息列表(添加终端内容注入)""" """构建消息列表(添加终端内容注入)"""
# 加载系统提示 # 加载系统提示
@ -2072,7 +2124,8 @@ class MainTerminal:
messages.append({"role": "system", "content": thinking_prompt}) messages.append({"role": "system", "content": thinking_prompt})
# 添加对话历史保留完整结构包括tool_calls和tool消息 # 添加对话历史保留完整结构包括tool_calls和tool消息
for conv in context["conversation"]: conversation = context["conversation"]
for idx, conv in enumerate(conversation):
metadata = conv.get("metadata") or {} metadata = conv.get("metadata") or {}
if conv["role"] == "assistant": if conv["role"] == "assistant":
# Assistant消息可能包含工具调用 # Assistant消息可能包含工具调用
@ -2081,8 +2134,9 @@ class MainTerminal:
"content": conv["content"] "content": conv["content"]
} }
# 如果有工具调用信息,添加到消息中 # 如果有工具调用信息,添加到消息中
if "tool_calls" in conv and conv["tool_calls"]: tool_calls = conv.get("tool_calls") or []
message["tool_calls"] = conv["tool_calls"] if tool_calls and self._tool_calls_followed_by_tools(conversation, idx, tool_calls):
message["tool_calls"] = tool_calls
messages.append(message) messages.append(message)
elif conv["role"] == "tool": elif conv["role"] == "tool":

View File

@ -32,7 +32,7 @@ TOOL_CATEGORIES: Dict[str, ToolCategory] = {
), ),
"read_focus": ToolCategory( "read_focus": ToolCategory(
label="阅读聚焦", label="阅读聚焦",
tools=["read_file", "focus_file", "unfocus_file"], tools=["read_file", "focus_file", "unfocus_file", "ocr_image"],
), ),
"terminal_realtime": ToolCategory( "terminal_realtime": ToolCategory(
label="实时终端", label="实时终端",

104
modules/ocr_client.py Normal file
View File

@ -0,0 +1,104 @@
"""DeepSeek-OCR 客户端(主智能体专用)。"""
import base64
import mimetypes
from pathlib import Path
from typing import Dict, List
import httpx
from openai import OpenAI
from config import OCR_API_BASE_URL, OCR_API_KEY, OCR_MODEL_ID, OCR_MAX_TOKENS
from modules.file_manager import FileManager
class OCRClient:
"""封装 DeepSeek-OCR 调用逻辑。"""
def __init__(self, project_path: str, file_manager: FileManager):
self.project_path = Path(project_path).resolve()
self.file_manager = file_manager
# 补全 base_url兼容是否包含 /v1
base_url = (OCR_API_BASE_URL or "").rstrip("/")
if not base_url.endswith("/v1"):
base_url = f"{base_url}/v1"
# httpx 0.28 起不再支持 proxies 参数,显式传入 http_client 以避免默认封装报错
self.http_client = httpx.Client()
self.client = OpenAI(
api_key=OCR_API_KEY,
base_url=base_url,
http_client=self.http_client,
)
self.model = OCR_MODEL_ID or "deepseek-ai/DeepSeek-OCR"
self.max_tokens = OCR_MAX_TOKENS or 4096
# 默认大小上限10MB超出则警告并拒绝
self.max_image_size = 10 * 1024 * 1024
def _validate_image_path(self, path: str):
"""复用 FileManager 的路径校验,确保在项目内。"""
valid, error, full_path = self.file_manager._validate_path(path)
if not valid:
return False, error, None
if not full_path.exists():
return False, "文件不存在", None
if not full_path.is_file():
return False, "不是文件", None
return True, "", full_path
def ocr_image(self, path: str, prompt: str) -> Dict:
"""执行 OCR返回最简结果格式。"""
warnings: List[str] = []
valid, error, full_path = self._validate_image_path(path)
if not valid:
return {"success": False, "error": error, "warnings": warnings}
if not prompt or not str(prompt).strip():
return {"success": False, "error": "prompt 不能为空", "warnings": warnings}
try:
data = full_path.read_bytes()
except Exception as exc:
return {"success": False, "error": f"读取文件失败: {exc}", "warnings": warnings}
size = len(data)
if size <= 0:
return {"success": False, "error": "文件为空,无法识别", "warnings": warnings}
if size > self.max_image_size:
return {
"success": False,
"error": f"图片过大({size}字节),上限为{self.max_image_size}字节",
"warnings": warnings,
}
mime_type, _ = mimetypes.guess_type(str(full_path))
if not mime_type or not mime_type.startswith("image/"):
warnings.append("无法确定图片类型,已按 JPEG 处理")
mime_type = "image/jpeg"
base64_image = base64.b64encode(data).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_image}"
try:
response = self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{"type": "text", "text": prompt},
],
}
],
max_tokens=self.max_tokens,
temperature=0,
)
content = response.choices[0].message.content if response.choices else ""
return {"success": True, "content": content or "", "warnings": warnings}
except Exception as exc:
return {"success": False, "error": f"OCR 调用失败: {exc}", "warnings": warnings}

View File

@ -566,11 +566,9 @@ class SubAgentManager:
"""返回子智能体任务概览,用于前端展示。""" """返回子智能体任务概览,用于前端展示。"""
overview: List[Dict[str, Any]] = [] overview: List[Dict[str, Any]] = []
for task_id, task in self.tasks.items(): for task_id, task in self.tasks.items():
status = task.get("status")
if status not in {"pending", "running"}:
continue
if conversation_id and task.get("conversation_id") != conversation_id: if conversation_id and task.get("conversation_id") != conversation_id:
continue continue
snapshot = { snapshot = {
"task_id": task_id, "task_id": task_id,
"agent_id": task.get("agent_id"), "agent_id": task.get("agent_id"),
@ -585,7 +583,8 @@ class SubAgentManager:
"conversation_id": task.get("conversation_id"), "conversation_id": task.get("conversation_id"),
"sub_conversation_id": task.get("sub_conversation_id"), "sub_conversation_id": task.get("sub_conversation_id"),
} }
# 对于运行中的任务,尝试获取最新状态
# 运行中的任务尝试同步远端最新状态
if snapshot["status"] not in TERMINAL_STATUSES: if snapshot["status"] not in TERMINAL_STATUSES:
remote = self._call_service("GET", f"/tasks/{task_id}", timeout=5) remote = self._call_service("GET", f"/tasks/{task_id}", timeout=5)
if remote.get("success"): if remote.get("success"):
@ -593,5 +592,13 @@ class SubAgentManager:
snapshot["remote_message"] = remote.get("message") snapshot["remote_message"] = remote.get("message")
snapshot["last_tool"] = remote.get("last_tool") snapshot["last_tool"] = remote.get("last_tool")
task["last_tool"] = snapshot["last_tool"] task["last_tool"] = snapshot["last_tool"]
else:
# 已结束的任务带上最终结果/系统消息,方便前端展示
final_result = task.get("final_result") or {}
snapshot["final_message"] = final_result.get("system_message") or final_result.get("message")
snapshot["success"] = final_result.get("success")
overview.append(snapshot) overview.append(snapshot)
overview.sort(key=lambda item: item.get("created_at") or 0, reverse=True)
return overview return overview

View File

@ -8,6 +8,7 @@ from . import conversation as _conversation
from . import security as _security from . import security as _security
from . import ui as _ui from . import ui as _ui
from . import memory as _memory from . import memory as _memory
from . import ocr as _ocr
from . import todo as _todo from . import todo as _todo
from . import auth as _auth from . import auth as _auth
from . import sub_agent as _sub_agent from . import sub_agent as _sub_agent
@ -20,12 +21,13 @@ from .conversation import *
from .security import * from .security import *
from .ui import * from .ui import *
from .memory import * from .memory import *
from .ocr import *
from .todo import * from .todo import *
from .auth import * from .auth import *
from .sub_agent import * from .sub_agent import *
__all__ = [] __all__ = []
for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent): for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent):
__all__ += getattr(module, "__all__", []) __all__ += getattr(module, "__all__", [])
del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent

13
sub_agent/config/ocr.py Normal file
View File

@ -0,0 +1,13 @@
"""OCR 配置DeepSeek-OCR 接口信息(子智能体)。"""
OCR_API_BASE_URL = "https://api.siliconflow.cn"
OCR_API_KEY = "sk-suqqgewtlwajjkylvnotdhkzmsrshmrqptkakdqjmlrilaes"
OCR_MODEL_ID = "deepseek-ai/DeepSeek-OCR"
OCR_MAX_TOKENS = 4096
__all__ = [
"OCR_API_BASE_URL",
"OCR_API_KEY",
"OCR_MODEL_ID",
"OCR_MAX_TOKENS",
]

View File

@ -4,7 +4,7 @@ import asyncio
import json import json
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional, Set
from datetime import datetime from datetime import datetime
try: try:
@ -40,6 +40,7 @@ from modules.terminal_manager import TerminalManager
from modules.todo_manager import TodoManager from modules.todo_manager import TodoManager
from modules.sub_agent_manager import SubAgentManager from modules.sub_agent_manager import SubAgentManager
from modules.webpage_extractor import extract_webpage_content, tavily_extract from modules.webpage_extractor import extract_webpage_content, tavily_extract
from modules.ocr_client import OCRClient
from core.tool_config import TOOL_CATEGORIES from core.tool_config import TOOL_CATEGORIES
from utils.api_client import DeepSeekClient from utils.api_client import DeepSeekClient
from utils.context_manager import ContextManager from utils.context_manager import ContextManager
@ -67,6 +68,7 @@ class MainTerminal:
self.file_manager = FileManager(project_path) self.file_manager = FileManager(project_path)
self.search_engine = SearchEngine() self.search_engine = SearchEngine()
self.terminal_ops = TerminalOperator(project_path) self.terminal_ops = TerminalOperator(project_path)
self.ocr_client = OCRClient(project_path, self.file_manager)
# 新增:终端管理器 # 新增:终端管理器
self.terminal_manager = TerminalManager( self.terminal_manager = TerminalManager(
@ -607,6 +609,8 @@ class MainTerminal:
print(f"{OUTPUT_FORMATS['file']} 创建文件") print(f"{OUTPUT_FORMATS['file']} 创建文件")
elif tool_name == "read_file": elif tool_name == "read_file":
print(f"{OUTPUT_FORMATS['file']} 读取文件") print(f"{OUTPUT_FORMATS['file']} 读取文件")
elif tool_name == "ocr_image":
print(f"{OUTPUT_FORMATS['file']} 图片OCR")
elif tool_name == "modify_file": elif tool_name == "modify_file":
print(f"{OUTPUT_FORMATS['file']} 修改文件") print(f"{OUTPUT_FORMATS['file']} 修改文件")
elif tool_name == "delete_file": elif tool_name == "delete_file":
@ -987,6 +991,21 @@ class MainTerminal:
} }
} }
}, },
{
"type": "function",
"function": {
"name": "ocr_image",
"description": "使用 DeepSeek-OCR 读取图片中的文字或根据提示生成描述,仅支持本地图片路径。",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "项目内的图片路径"},
"prompt": {"type": "string", "description": "传递给 OCR 模型的提示词,如“请识别图片中的文字”"}
},
"required": ["path", "prompt"]
}
}
},
{ {
"type": "function", "type": "function",
"function": { "function": {
@ -1464,6 +1483,12 @@ class MainTerminal:
try: try:
if tool_name == "read_file": if tool_name == "read_file":
result = self._handle_read_tool(arguments) result = self._handle_read_tool(arguments)
elif tool_name == "ocr_image":
path = arguments.get("path")
prompt = arguments.get("prompt")
if not path:
return json.dumps({"success": False, "error": "缺少 path 参数", "warnings": []}, ensure_ascii=False)
result = self.ocr_client.ocr_image(path=path, prompt=prompt or "")
# 终端会话管理工具 # 终端会话管理工具
elif tool_name == "terminal_session": elif tool_name == "terminal_session":
@ -1980,7 +2005,6 @@ class MainTerminal:
agent_id=arguments.get("agent_id"), agent_id=arguments.get("agent_id"),
timeout_seconds=arguments.get("timeout_seconds") timeout_seconds=arguments.get("timeout_seconds")
) )
self._record_sub_agent_message(result.get("system_message"), result.get("task_id"), inline=False)
else: else:
result = {"success": False, "error": f"未知工具: {tool_name}"} result = {"success": False, "error": f"未知工具: {tool_name}"}
@ -2008,6 +2032,29 @@ class MainTerminal:
# 构建上下文 # 构建上下文
return self.context_manager.build_main_context(memory) return self.context_manager.build_main_context(memory)
def _tool_calls_followed_by_tools(self, conversation: List[Dict], start_idx: int, tool_calls: List[Dict]) -> bool:
if not tool_calls:
return False
expected_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
if not expected_ids:
return False
matched: Set[str] = set()
idx = start_idx + 1
total = len(conversation)
while idx < total and len(matched) < len(expected_ids):
next_conv = conversation[idx]
role = next_conv.get("role")
if role == "tool":
call_id = next_conv.get("tool_call_id")
if call_id in expected_ids:
matched.add(call_id)
else:
break
elif role in ("assistant", "user"):
break
idx += 1
return len(matched) == len(expected_ids)
def build_messages(self, context: Dict, user_input: str) -> List[Dict]: def build_messages(self, context: Dict, user_input: str) -> List[Dict]:
"""构建消息列表(添加终端内容注入)""" """构建消息列表(添加终端内容注入)"""
# 加载系统提示 # 加载系统提示
@ -2031,7 +2078,8 @@ class MainTerminal:
messages.append({"role": "system", "content": todo_prompt}) messages.append({"role": "system", "content": todo_prompt})
# 添加对话历史保留完整结构包括tool_calls和tool消息 # 添加对话历史保留完整结构包括tool_calls和tool消息
for conv in context["conversation"]: conversation = context["conversation"]
for idx, conv in enumerate(conversation):
metadata = conv.get("metadata") or {} metadata = conv.get("metadata") or {}
if conv["role"] == "assistant": if conv["role"] == "assistant":
# Assistant消息可能包含工具调用 # Assistant消息可能包含工具调用
@ -2040,8 +2088,9 @@ class MainTerminal:
"content": conv["content"] "content": conv["content"]
} }
# 如果有工具调用信息,添加到消息中 # 如果有工具调用信息,添加到消息中
if "tool_calls" in conv and conv["tool_calls"]: tool_calls = conv.get("tool_calls") or []
message["tool_calls"] = conv["tool_calls"] if tool_calls and self._tool_calls_followed_by_tools(conversation, idx, tool_calls):
message["tool_calls"] = tool_calls
messages.append(message) messages.append(message)
elif conv["role"] == "tool": elif conv["role"] == "tool":

View File

@ -32,7 +32,7 @@ TOOL_CATEGORIES: Dict[str, ToolCategory] = {
), ),
"read_focus": ToolCategory( "read_focus": ToolCategory(
label="阅读聚焦", label="阅读聚焦",
tools=["read_file", "focus_file", "unfocus_file"], tools=["read_file", "focus_file", "unfocus_file", "ocr_image"],
), ),
"terminal_realtime": ToolCategory( "terminal_realtime": ToolCategory(
label="实时终端", label="实时终端",

View File

@ -0,0 +1,98 @@
"""DeepSeek-OCR 客户端(子智能体专用)。"""
import base64
import mimetypes
from pathlib import Path
from typing import Dict, List
import httpx
from openai import OpenAI
from config import OCR_API_BASE_URL, OCR_API_KEY, OCR_MODEL_ID, OCR_MAX_TOKENS
from modules.file_manager import FileManager
class OCRClient:
"""封装 DeepSeek-OCR 调用逻辑。"""
def __init__(self, project_path: str, file_manager: FileManager):
self.project_path = Path(project_path).resolve()
self.file_manager = file_manager
base_url = (OCR_API_BASE_URL or "").rstrip("/")
if not base_url.endswith("/v1"):
base_url = f"{base_url}/v1"
self.http_client = httpx.Client()
self.client = OpenAI(
api_key=OCR_API_KEY,
base_url=base_url,
http_client=self.http_client,
)
self.model = OCR_MODEL_ID or "deepseek-ai/DeepSeek-OCR"
self.max_tokens = OCR_MAX_TOKENS or 4096
self.max_image_size = 10 * 1024 * 1024 # 10MB
def _validate_image_path(self, path: str):
valid, error, full_path = self.file_manager._validate_path(path)
if not valid:
return False, error, None
if not full_path.exists():
return False, "文件不存在", None
if not full_path.is_file():
return False, "不是文件", None
return True, "", full_path
def ocr_image(self, path: str, prompt: str) -> Dict:
warnings: List[str] = []
valid, error, full_path = self._validate_image_path(path)
if not valid:
return {"success": False, "error": error, "warnings": warnings}
if not prompt or not str(prompt).strip():
return {"success": False, "error": "prompt 不能为空", "warnings": warnings}
try:
data = full_path.read_bytes()
except Exception as exc:
return {"success": False, "error": f"读取文件失败: {exc}", "warnings": warnings}
size = len(data)
if size <= 0:
return {"success": False, "error": "文件为空,无法识别", "warnings": warnings}
if size > self.max_image_size:
return {
"success": False,
"error": f"图片过大({size}字节),上限为{self.max_image_size}字节",
"warnings": warnings,
}
mime_type, _ = mimetypes.guess_type(str(full_path))
if not mime_type or not mime_type.startswith("image/"):
warnings.append("无法确定图片类型,已按 JPEG 处理")
mime_type = "image/jpeg"
base64_image = base64.b64encode(data).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_image}"
try:
response = self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{"type": "text", "text": prompt},
],
}
],
max_tokens=self.max_tokens,
temperature=0,
)
content = response.choices[0].message.content if response.choices else ""
return {"success": True, "content": content or "", "warnings": warnings}
except Exception as exc:
return {"success": False, "error": f"OCR 调用失败: {exc}", "warnings": warnings}

View File

@ -85,6 +85,7 @@ sub_agent_rooms: Dict[str, set] = defaultdict(set)
sub_agent_connections: Dict[str, str] = {} sub_agent_connections: Dict[str, str] = {}
SUB_AGENT_TERMINAL_STATUSES = {"completed", "failed", "timeout"} SUB_AGENT_TERMINAL_STATUSES = {"completed", "failed", "timeout"}
STOPPING_GRACE_SECONDS = 30 STOPPING_GRACE_SECONDS = 30
TERMINAL_ARCHIVE_GRACE_SECONDS = 20
def format_read_file_result(result_data: Dict) -> str: def format_read_file_result(result_data: Dict) -> str:
"""格式化 read_file 工具的输出便于在Web端展示。""" """格式化 read_file 工具的输出便于在Web端展示。"""
@ -253,7 +254,9 @@ def cleanup_inactive_sub_agent_tasks(force: bool = False):
for task_id, task in list(sub_agent_tasks.items()): for task_id, task in list(sub_agent_tasks.items()):
status = (task.get("status") or "").lower() status = (task.get("status") or "").lower()
if status in SUB_AGENT_TERMINAL_STATUSES: if status in SUB_AGENT_TERMINAL_STATUSES:
_purge_sub_agent_task(task_id) updated_at = task.get("updated_at") or task.get("created_at") or now
if force or (now - updated_at) > TERMINAL_ARCHIVE_GRACE_SECONDS:
_purge_sub_agent_task(task_id)
continue continue
if status == "stopping": if status == "stopping":
updated_at = task.get("updated_at") or task.get("created_at") or now updated_at = task.get("updated_at") or task.get("created_at") or now
@ -3898,16 +3901,6 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
last_tool_name = tool_name last_tool_name = tool_name
# ===== 增量保存:保存工具调用信息 =====
if tool_calls:
# 保存assistant消息只包含工具调用信息内容为空
web_terminal.context_manager.add_conversation(
"assistant",
"", # 空内容,只记录工具调用
tool_calls
)
debug_log(f"💾 增量保存:工具调用信息 ({len(tool_calls)} 个工具)")
# 更新统计 # 更新统计
total_tool_calls += len(tool_calls) total_tool_calls += len(tool_calls)
@ -4012,6 +4005,7 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
await asyncio.sleep(1.5 - execution_time) await asyncio.sleep(1.5 - execution_time)
# 更新工具状态 # 更新工具状态
result_data = {}
try: try:
result_data = json.loads(tool_result) result_data = json.loads(tool_result)
except: except:
@ -4124,6 +4118,9 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
name=function_name name=function_name
) )
debug_log(f"💾 增量保存:工具结果 {function_name}") debug_log(f"💾 增量保存:工具结果 {function_name}")
system_message = result_data.get("system_message") if isinstance(result_data, dict) else None
if system_message:
web_terminal._record_sub_agent_message(system_message, result_data.get("task_id"), inline=False)
# 添加到消息历史用于API继续对话 # 添加到消息历史用于API继续对话
messages.append({ messages.append({

View File

@ -266,9 +266,6 @@ class DeepSeekClient:
"stream": stream, "stream": stream,
"max_tokens": max_tokens "max_tokens": max_tokens
} }
if current_thinking_mode:
payload["thinking"] = {"type": "enabled"}
if tools: if tools:
payload["tools"] = tools payload["tools"] = tools
payload["tool_choice"] = "auto" payload["tool_choice"] = "auto"

View File

@ -3137,8 +3137,7 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
} }
messages.append(assistant_message) messages.append(assistant_message)
if assistant_content or current_thinking or tool_calls:
if assistant_content or current_thinking:
web_terminal.context_manager.add_conversation( web_terminal.context_manager.add_conversation(
"assistant", "assistant",
assistant_content, assistant_content,
@ -3209,17 +3208,6 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
consecutive_same_tool[tool_name] = 1 consecutive_same_tool[tool_name] = 1
last_tool_name = tool_name last_tool_name = tool_name
# ===== 增量保存:保存工具调用信息 =====
if tool_calls:
# 保存assistant消息只包含工具调用信息内容为空
web_terminal.context_manager.add_conversation(
"assistant",
"", # 空内容,只记录工具调用
tool_calls
)
debug_log(f"💾 增量保存:工具调用信息 ({len(tool_calls)} 个工具)")
# 更新统计 # 更新统计
total_tool_calls += len(tool_calls) total_tool_calls += len(tool_calls)
@ -3322,6 +3310,7 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
await asyncio.sleep(1.5 - execution_time) await asyncio.sleep(1.5 - execution_time)
# 更新工具状态 # 更新工具状态
result_data = {}
try: try:
result_data = json.loads(tool_result) result_data = json.loads(tool_result)
except: except:
@ -3434,6 +3423,9 @@ async def handle_task_with_sender(terminal: WebTerminal, message, sender, client
name=function_name name=function_name
) )
debug_log(f"💾 增量保存:工具结果 {function_name}") debug_log(f"💾 增量保存:工具结果 {function_name}")
system_message = result_data.get("system_message") if isinstance(result_data, dict) else None
if system_message:
web_terminal._record_sub_agent_message(system_message, result_data.get("task_id"), inline=False)
# 添加到消息历史用于API继续对话 # 添加到消息历史用于API继续对话
messages.append({ messages.append({