feat: add ocr tool for main and sub agent

This commit is contained in:
JOJO 2025-11-18 17:06:48 +08:00
parent 92c3e25f7a
commit adb1f1249a
10 changed files with 281 additions and 6 deletions

View File

@ -8,6 +8,7 @@ from . import conversation as _conversation
from . import security as _security from . import security as _security
from . import ui as _ui from . import ui as _ui
from . import memory as _memory from . import memory as _memory
from . import ocr as _ocr
from . import todo as _todo from . import todo as _todo
from . import auth as _auth from . import auth as _auth
from . import sub_agent as _sub_agent from . import sub_agent as _sub_agent
@ -20,12 +21,13 @@ from .conversation import *
from .security import * from .security import *
from .ui import * from .ui import *
from .memory import * from .memory import *
from .ocr import *
from .todo import * from .todo import *
from .auth import * from .auth import *
from .sub_agent import * from .sub_agent import *
__all__ = [] __all__ = []
for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent): for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent):
__all__ += getattr(module, "__all__", []) __all__ += getattr(module, "__all__", [])
del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent

13
config/ocr.py Normal file
View File

@ -0,0 +1,13 @@
"""OCR 配置DeepSeek-OCR 接口信息。"""
OCR_API_BASE_URL = "https://api.siliconflow.cn"
OCR_API_KEY = "sk-suqqgewtlwajjkylvnotdhkzmsrshmrqptkakdqjmlrilaes"
OCR_MODEL_ID = "deepseek-ai/DeepSeek-OCR"
OCR_MAX_TOKENS = 4096
__all__ = [
"OCR_API_BASE_URL",
"OCR_API_KEY",
"OCR_MODEL_ID",
"OCR_MAX_TOKENS",
]

View File

@ -40,6 +40,7 @@ from modules.terminal_manager import TerminalManager
from modules.todo_manager import TodoManager from modules.todo_manager import TodoManager
from modules.sub_agent_manager import SubAgentManager from modules.sub_agent_manager import SubAgentManager
from modules.webpage_extractor import extract_webpage_content, tavily_extract from modules.webpage_extractor import extract_webpage_content, tavily_extract
from modules.ocr_client import OCRClient
from core.tool_config import TOOL_CATEGORIES from core.tool_config import TOOL_CATEGORIES
from utils.api_client import DeepSeekClient from utils.api_client import DeepSeekClient
from utils.context_manager import ContextManager from utils.context_manager import ContextManager
@ -67,6 +68,7 @@ class MainTerminal:
self.file_manager = FileManager(project_path) self.file_manager = FileManager(project_path)
self.search_engine = SearchEngine() self.search_engine = SearchEngine()
self.terminal_ops = TerminalOperator(project_path) self.terminal_ops = TerminalOperator(project_path)
self.ocr_client = OCRClient(project_path, self.file_manager)
# 新增:终端管理器 # 新增:终端管理器
self.terminal_manager = TerminalManager( self.terminal_manager = TerminalManager(
@ -607,6 +609,8 @@ class MainTerminal:
print(f"{OUTPUT_FORMATS['file']} 创建文件") print(f"{OUTPUT_FORMATS['file']} 创建文件")
elif tool_name == "read_file": elif tool_name == "read_file":
print(f"{OUTPUT_FORMATS['file']} 读取文件") print(f"{OUTPUT_FORMATS['file']} 读取文件")
elif tool_name == "ocr_image":
print(f"{OUTPUT_FORMATS['file']} 图片OCR")
elif tool_name == "modify_file": elif tool_name == "modify_file":
print(f"{OUTPUT_FORMATS['file']} 修改文件") print(f"{OUTPUT_FORMATS['file']} 修改文件")
elif tool_name == "delete_file": elif tool_name == "delete_file":
@ -987,6 +991,21 @@ class MainTerminal:
} }
} }
}, },
{
"type": "function",
"function": {
"name": "ocr_image",
"description": "使用 DeepSeek-OCR 读取图片中的文字或根据提示生成描述,仅支持本地图片路径。",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "项目内的图片路径"},
"prompt": {"type": "string", "description": "传递给 OCR 模型的提示词,如“请识别图片中的文字”"}
},
"required": ["path", "prompt"]
}
}
},
{ {
"type": "function", "type": "function",
"function": { "function": {
@ -1478,6 +1497,12 @@ class MainTerminal:
try: try:
if tool_name == "read_file": if tool_name == "read_file":
result = self._handle_read_tool(arguments) result = self._handle_read_tool(arguments)
elif tool_name == "ocr_image":
path = arguments.get("path")
prompt = arguments.get("prompt")
if not path:
return json.dumps({"success": False, "error": "缺少 path 参数", "warnings": []}, ensure_ascii=False)
result = self.ocr_client.ocr_image(path=path, prompt=prompt or "")
# 终端会话管理工具 # 终端会话管理工具
elif tool_name == "terminal_session": elif tool_name == "terminal_session":

View File

@ -32,7 +32,7 @@ TOOL_CATEGORIES: Dict[str, ToolCategory] = {
), ),
"read_focus": ToolCategory( "read_focus": ToolCategory(
label="阅读聚焦", label="阅读聚焦",
tools=["read_file", "focus_file", "unfocus_file"], tools=["read_file", "focus_file", "unfocus_file", "ocr_image"],
), ),
"terminal_realtime": ToolCategory( "terminal_realtime": ToolCategory(
label="实时终端", label="实时终端",

100
modules/ocr_client.py Normal file
View File

@ -0,0 +1,100 @@
"""DeepSeek-OCR 客户端(主智能体专用)。"""
import base64
import mimetypes
from pathlib import Path
from typing import Dict, List
from openai import OpenAI
from config import OCR_API_BASE_URL, OCR_API_KEY, OCR_MODEL_ID, OCR_MAX_TOKENS
from modules.file_manager import FileManager
class OCRClient:
"""封装 DeepSeek-OCR 调用逻辑。"""
def __init__(self, project_path: str, file_manager: FileManager):
self.project_path = Path(project_path).resolve()
self.file_manager = file_manager
# 补全 base_url兼容是否包含 /v1
base_url = (OCR_API_BASE_URL or "").rstrip("/")
if not base_url.endswith("/v1"):
base_url = f"{base_url}/v1"
self.client = OpenAI(
api_key=OCR_API_KEY,
base_url=base_url,
)
self.model = OCR_MODEL_ID or "deepseek-ai/DeepSeek-OCR"
self.max_tokens = OCR_MAX_TOKENS or 4096
# 默认大小上限10MB超出则警告并拒绝
self.max_image_size = 10 * 1024 * 1024
def _validate_image_path(self, path: str):
"""复用 FileManager 的路径校验,确保在项目内。"""
valid, error, full_path = self.file_manager._validate_path(path)
if not valid:
return False, error, None
if not full_path.exists():
return False, "文件不存在", None
if not full_path.is_file():
return False, "不是文件", None
return True, "", full_path
def ocr_image(self, path: str, prompt: str) -> Dict:
"""执行 OCR返回最简结果格式。"""
warnings: List[str] = []
valid, error, full_path = self._validate_image_path(path)
if not valid:
return {"success": False, "error": error, "warnings": warnings}
if not prompt or not str(prompt).strip():
return {"success": False, "error": "prompt 不能为空", "warnings": warnings}
try:
data = full_path.read_bytes()
except Exception as exc:
return {"success": False, "error": f"读取文件失败: {exc}", "warnings": warnings}
size = len(data)
if size <= 0:
return {"success": False, "error": "文件为空,无法识别", "warnings": warnings}
if size > self.max_image_size:
return {
"success": False,
"error": f"图片过大({size}字节),上限为{self.max_image_size}字节",
"warnings": warnings,
}
mime_type, _ = mimetypes.guess_type(str(full_path))
if not mime_type or not mime_type.startswith("image/"):
warnings.append("无法确定图片类型,已按 JPEG 处理")
mime_type = "image/jpeg"
base64_image = base64.b64encode(data).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_image}"
try:
response = self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{"type": "text", "text": prompt},
],
}
],
max_tokens=self.max_tokens,
temperature=0,
)
content = response.choices[0].message.content if response.choices else ""
return {"success": True, "content": content or "", "warnings": warnings}
except Exception as exc:
return {"success": False, "error": f"OCR 调用失败: {exc}", "warnings": warnings}

View File

@ -8,6 +8,7 @@ from . import conversation as _conversation
from . import security as _security from . import security as _security
from . import ui as _ui from . import ui as _ui
from . import memory as _memory from . import memory as _memory
from . import ocr as _ocr
from . import todo as _todo from . import todo as _todo
from . import auth as _auth from . import auth as _auth
from . import sub_agent as _sub_agent from . import sub_agent as _sub_agent
@ -20,12 +21,13 @@ from .conversation import *
from .security import * from .security import *
from .ui import * from .ui import *
from .memory import * from .memory import *
from .ocr import *
from .todo import * from .todo import *
from .auth import * from .auth import *
from .sub_agent import * from .sub_agent import *
__all__ = [] __all__ = []
for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent): for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent):
__all__ += getattr(module, "__all__", []) __all__ += getattr(module, "__all__", [])
del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent

13
sub_agent/config/ocr.py Normal file
View File

@ -0,0 +1,13 @@
"""OCR 配置DeepSeek-OCR 接口信息(子智能体)。"""
OCR_API_BASE_URL = "https://api.siliconflow.cn"
OCR_API_KEY = "sk-suqqgewtlwajjkylvnotdhkzmsrshmrqptkakdqjmlrilaes"
OCR_MODEL_ID = "deepseek-ai/DeepSeek-OCR"
OCR_MAX_TOKENS = 4096
__all__ = [
"OCR_API_BASE_URL",
"OCR_API_KEY",
"OCR_MODEL_ID",
"OCR_MAX_TOKENS",
]

View File

@ -40,6 +40,7 @@ from modules.terminal_manager import TerminalManager
from modules.todo_manager import TodoManager from modules.todo_manager import TodoManager
from modules.sub_agent_manager import SubAgentManager from modules.sub_agent_manager import SubAgentManager
from modules.webpage_extractor import extract_webpage_content, tavily_extract from modules.webpage_extractor import extract_webpage_content, tavily_extract
from modules.ocr_client import OCRClient
from core.tool_config import TOOL_CATEGORIES from core.tool_config import TOOL_CATEGORIES
from utils.api_client import DeepSeekClient from utils.api_client import DeepSeekClient
from utils.context_manager import ContextManager from utils.context_manager import ContextManager
@ -67,6 +68,7 @@ class MainTerminal:
self.file_manager = FileManager(project_path) self.file_manager = FileManager(project_path)
self.search_engine = SearchEngine() self.search_engine = SearchEngine()
self.terminal_ops = TerminalOperator(project_path) self.terminal_ops = TerminalOperator(project_path)
self.ocr_client = OCRClient(project_path, self.file_manager)
# 新增:终端管理器 # 新增:终端管理器
self.terminal_manager = TerminalManager( self.terminal_manager = TerminalManager(
@ -607,6 +609,8 @@ class MainTerminal:
print(f"{OUTPUT_FORMATS['file']} 创建文件") print(f"{OUTPUT_FORMATS['file']} 创建文件")
elif tool_name == "read_file": elif tool_name == "read_file":
print(f"{OUTPUT_FORMATS['file']} 读取文件") print(f"{OUTPUT_FORMATS['file']} 读取文件")
elif tool_name == "ocr_image":
print(f"{OUTPUT_FORMATS['file']} 图片OCR")
elif tool_name == "modify_file": elif tool_name == "modify_file":
print(f"{OUTPUT_FORMATS['file']} 修改文件") print(f"{OUTPUT_FORMATS['file']} 修改文件")
elif tool_name == "delete_file": elif tool_name == "delete_file":
@ -987,6 +991,21 @@ class MainTerminal:
} }
} }
}, },
{
"type": "function",
"function": {
"name": "ocr_image",
"description": "使用 DeepSeek-OCR 读取图片中的文字或根据提示生成描述,仅支持本地图片路径。",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "项目内的图片路径"},
"prompt": {"type": "string", "description": "传递给 OCR 模型的提示词,如“请识别图片中的文字”"}
},
"required": ["path", "prompt"]
}
}
},
{ {
"type": "function", "type": "function",
"function": { "function": {
@ -1464,6 +1483,12 @@ class MainTerminal:
try: try:
if tool_name == "read_file": if tool_name == "read_file":
result = self._handle_read_tool(arguments) result = self._handle_read_tool(arguments)
elif tool_name == "ocr_image":
path = arguments.get("path")
prompt = arguments.get("prompt")
if not path:
return json.dumps({"success": False, "error": "缺少 path 参数", "warnings": []}, ensure_ascii=False)
result = self.ocr_client.ocr_image(path=path, prompt=prompt or "")
# 终端会话管理工具 # 终端会话管理工具
elif tool_name == "terminal_session": elif tool_name == "terminal_session":

View File

@ -32,7 +32,7 @@ TOOL_CATEGORIES: Dict[str, ToolCategory] = {
), ),
"read_focus": ToolCategory( "read_focus": ToolCategory(
label="阅读聚焦", label="阅读聚焦",
tools=["read_file", "focus_file", "unfocus_file"], tools=["read_file", "focus_file", "unfocus_file", "ocr_image"],
), ),
"terminal_realtime": ToolCategory( "terminal_realtime": ToolCategory(
label="实时终端", label="实时终端",

View File

@ -0,0 +1,95 @@
"""DeepSeek-OCR 客户端(子智能体专用)。"""
import base64
import mimetypes
from pathlib import Path
from typing import Dict, List
from openai import OpenAI
from config import OCR_API_BASE_URL, OCR_API_KEY, OCR_MODEL_ID, OCR_MAX_TOKENS
from modules.file_manager import FileManager
class OCRClient:
"""封装 DeepSeek-OCR 调用逻辑。"""
def __init__(self, project_path: str, file_manager: FileManager):
self.project_path = Path(project_path).resolve()
self.file_manager = file_manager
base_url = (OCR_API_BASE_URL or "").rstrip("/")
if not base_url.endswith("/v1"):
base_url = f"{base_url}/v1"
self.client = OpenAI(
api_key=OCR_API_KEY,
base_url=base_url,
)
self.model = OCR_MODEL_ID or "deepseek-ai/DeepSeek-OCR"
self.max_tokens = OCR_MAX_TOKENS or 4096
self.max_image_size = 10 * 1024 * 1024 # 10MB
def _validate_image_path(self, path: str):
valid, error, full_path = self.file_manager._validate_path(path)
if not valid:
return False, error, None
if not full_path.exists():
return False, "文件不存在", None
if not full_path.is_file():
return False, "不是文件", None
return True, "", full_path
def ocr_image(self, path: str, prompt: str) -> Dict:
warnings: List[str] = []
valid, error, full_path = self._validate_image_path(path)
if not valid:
return {"success": False, "error": error, "warnings": warnings}
if not prompt or not str(prompt).strip():
return {"success": False, "error": "prompt 不能为空", "warnings": warnings}
try:
data = full_path.read_bytes()
except Exception as exc:
return {"success": False, "error": f"读取文件失败: {exc}", "warnings": warnings}
size = len(data)
if size <= 0:
return {"success": False, "error": "文件为空,无法识别", "warnings": warnings}
if size > self.max_image_size:
return {
"success": False,
"error": f"图片过大({size}字节),上限为{self.max_image_size}字节",
"warnings": warnings,
}
mime_type, _ = mimetypes.guess_type(str(full_path))
if not mime_type or not mime_type.startswith("image/"):
warnings.append("无法确定图片类型,已按 JPEG 处理")
mime_type = "image/jpeg"
base64_image = base64.b64encode(data).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_image}"
try:
response = self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{"type": "text", "text": prompt},
],
}
],
max_tokens=self.max_tokens,
temperature=0,
)
content = response.choices[0].message.content if response.choices else ""
return {"success": True, "content": content or "", "warnings": warnings}
except Exception as exc:
return {"success": False, "error": f"OCR 调用失败: {exc}", "warnings": warnings}