From adb1f1249abc092ab2cc2963b48d9591dac26558 Mon Sep 17 00:00:00 2001 From: JOJO <1498581755@qq.com> Date: Tue, 18 Nov 2025 17:06:48 +0800 Subject: [PATCH] feat: add ocr tool for main and sub agent --- config/__init__.py | 6 +- config/ocr.py | 13 +++++ core/main_terminal.py | 25 ++++++++ core/tool_config.py | 2 +- modules/ocr_client.py | 100 ++++++++++++++++++++++++++++++++ sub_agent/config/__init__.py | 6 +- sub_agent/config/ocr.py | 13 +++++ sub_agent/core/main_terminal.py | 25 ++++++++ sub_agent/core/tool_config.py | 2 +- sub_agent/modules/ocr_client.py | 95 ++++++++++++++++++++++++++++++ 10 files changed, 281 insertions(+), 6 deletions(-) create mode 100644 config/ocr.py create mode 100644 modules/ocr_client.py create mode 100644 sub_agent/config/ocr.py create mode 100644 sub_agent/modules/ocr_client.py diff --git a/config/__init__.py b/config/__init__.py index 3b886f5..a706c5f 100644 --- a/config/__init__.py +++ b/config/__init__.py @@ -8,6 +8,7 @@ from . import conversation as _conversation from . import security as _security from . import ui as _ui from . import memory as _memory +from . import ocr as _ocr from . import todo as _todo from . import auth as _auth from . import sub_agent as _sub_agent @@ -20,12 +21,13 @@ from .conversation import * from .security import * from .ui import * from .memory import * +from .ocr import * from .todo import * from .auth import * from .sub_agent import * __all__ = [] -for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent): +for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent): __all__ += getattr(module, "__all__", []) -del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent +del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent diff --git a/config/ocr.py b/config/ocr.py new file mode 100644 index 0000000..79bd199 --- /dev/null +++ b/config/ocr.py @@ -0,0 +1,13 @@ +"""OCR 配置:DeepSeek-OCR 接口信息。""" + +OCR_API_BASE_URL = "https://api.siliconflow.cn" +OCR_API_KEY = "sk-suqqgewtlwajjkylvnotdhkzmsrshmrqptkakdqjmlrilaes" +OCR_MODEL_ID = "deepseek-ai/DeepSeek-OCR" +OCR_MAX_TOKENS = 4096 + +__all__ = [ + "OCR_API_BASE_URL", + "OCR_API_KEY", + "OCR_MODEL_ID", + "OCR_MAX_TOKENS", +] diff --git a/core/main_terminal.py b/core/main_terminal.py index 3cd64e3..9304acb 100644 --- a/core/main_terminal.py +++ b/core/main_terminal.py @@ -40,6 +40,7 @@ from modules.terminal_manager import TerminalManager from modules.todo_manager import TodoManager from modules.sub_agent_manager import SubAgentManager from modules.webpage_extractor import extract_webpage_content, tavily_extract +from modules.ocr_client import OCRClient from core.tool_config import TOOL_CATEGORIES from utils.api_client import DeepSeekClient from utils.context_manager import ContextManager @@ -67,6 +68,7 @@ class MainTerminal: self.file_manager = FileManager(project_path) self.search_engine = SearchEngine() self.terminal_ops = TerminalOperator(project_path) + self.ocr_client = OCRClient(project_path, self.file_manager) # 新增:终端管理器 self.terminal_manager = TerminalManager( @@ -607,6 +609,8 @@ class MainTerminal: print(f"{OUTPUT_FORMATS['file']} 创建文件") elif tool_name == "read_file": print(f"{OUTPUT_FORMATS['file']} 读取文件") + elif tool_name == "ocr_image": + print(f"{OUTPUT_FORMATS['file']} 图片OCR") elif tool_name == "modify_file": print(f"{OUTPUT_FORMATS['file']} 修改文件") elif tool_name == "delete_file": @@ -987,6 +991,21 @@ class MainTerminal: } } }, + { + "type": "function", + "function": { + "name": "ocr_image", + "description": "使用 DeepSeek-OCR 读取图片中的文字或根据提示生成描述,仅支持本地图片路径。", + "parameters": { + "type": "object", + "properties": { + "path": {"type": "string", "description": "项目内的图片路径"}, + "prompt": {"type": "string", "description": "传递给 OCR 模型的提示词,如“请识别图片中的文字”"} + }, + "required": ["path", "prompt"] + } + } + }, { "type": "function", "function": { @@ -1478,6 +1497,12 @@ class MainTerminal: try: if tool_name == "read_file": result = self._handle_read_tool(arguments) + elif tool_name == "ocr_image": + path = arguments.get("path") + prompt = arguments.get("prompt") + if not path: + return json.dumps({"success": False, "error": "缺少 path 参数", "warnings": []}, ensure_ascii=False) + result = self.ocr_client.ocr_image(path=path, prompt=prompt or "") # 终端会话管理工具 elif tool_name == "terminal_session": diff --git a/core/tool_config.py b/core/tool_config.py index 1d2c114..d4e87e7 100644 --- a/core/tool_config.py +++ b/core/tool_config.py @@ -32,7 +32,7 @@ TOOL_CATEGORIES: Dict[str, ToolCategory] = { ), "read_focus": ToolCategory( label="阅读聚焦", - tools=["read_file", "focus_file", "unfocus_file"], + tools=["read_file", "focus_file", "unfocus_file", "ocr_image"], ), "terminal_realtime": ToolCategory( label="实时终端", diff --git a/modules/ocr_client.py b/modules/ocr_client.py new file mode 100644 index 0000000..03a5d81 --- /dev/null +++ b/modules/ocr_client.py @@ -0,0 +1,100 @@ +"""DeepSeek-OCR 客户端(主智能体专用)。""" + +import base64 +import mimetypes +from pathlib import Path +from typing import Dict, List + +from openai import OpenAI + +from config import OCR_API_BASE_URL, OCR_API_KEY, OCR_MODEL_ID, OCR_MAX_TOKENS +from modules.file_manager import FileManager + + +class OCRClient: + """封装 DeepSeek-OCR 调用逻辑。""" + + def __init__(self, project_path: str, file_manager: FileManager): + self.project_path = Path(project_path).resolve() + self.file_manager = file_manager + + # 补全 base_url,兼容是否包含 /v1 + base_url = (OCR_API_BASE_URL or "").rstrip("/") + if not base_url.endswith("/v1"): + base_url = f"{base_url}/v1" + + self.client = OpenAI( + api_key=OCR_API_KEY, + base_url=base_url, + ) + self.model = OCR_MODEL_ID or "deepseek-ai/DeepSeek-OCR" + self.max_tokens = OCR_MAX_TOKENS or 4096 + + # 默认大小上限(10MB),超出则警告并拒绝 + self.max_image_size = 10 * 1024 * 1024 + + def _validate_image_path(self, path: str): + """复用 FileManager 的路径校验,确保在项目内。""" + valid, error, full_path = self.file_manager._validate_path(path) + if not valid: + return False, error, None + if not full_path.exists(): + return False, "文件不存在", None + if not full_path.is_file(): + return False, "不是文件", None + return True, "", full_path + + def ocr_image(self, path: str, prompt: str) -> Dict: + """执行 OCR,返回最简结果格式。""" + warnings: List[str] = [] + + valid, error, full_path = self._validate_image_path(path) + if not valid: + return {"success": False, "error": error, "warnings": warnings} + + if not prompt or not str(prompt).strip(): + return {"success": False, "error": "prompt 不能为空", "warnings": warnings} + + try: + data = full_path.read_bytes() + except Exception as exc: + return {"success": False, "error": f"读取文件失败: {exc}", "warnings": warnings} + + size = len(data) + if size <= 0: + return {"success": False, "error": "文件为空,无法识别", "warnings": warnings} + + if size > self.max_image_size: + return { + "success": False, + "error": f"图片过大({size}字节),上限为{self.max_image_size}字节", + "warnings": warnings, + } + + mime_type, _ = mimetypes.guess_type(str(full_path)) + if not mime_type or not mime_type.startswith("image/"): + warnings.append("无法确定图片类型,已按 JPEG 处理") + mime_type = "image/jpeg" + + base64_image = base64.b64encode(data).decode("utf-8") + data_url = f"data:{mime_type};base64,{base64_image}" + + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": data_url}}, + {"type": "text", "text": prompt}, + ], + } + ], + max_tokens=self.max_tokens, + temperature=0, + ) + content = response.choices[0].message.content if response.choices else "" + return {"success": True, "content": content or "", "warnings": warnings} + except Exception as exc: + return {"success": False, "error": f"OCR 调用失败: {exc}", "warnings": warnings} diff --git a/sub_agent/config/__init__.py b/sub_agent/config/__init__.py index 3b886f5..a706c5f 100644 --- a/sub_agent/config/__init__.py +++ b/sub_agent/config/__init__.py @@ -8,6 +8,7 @@ from . import conversation as _conversation from . import security as _security from . import ui as _ui from . import memory as _memory +from . import ocr as _ocr from . import todo as _todo from . import auth as _auth from . import sub_agent as _sub_agent @@ -20,12 +21,13 @@ from .conversation import * from .security import * from .ui import * from .memory import * +from .ocr import * from .todo import * from .auth import * from .sub_agent import * __all__ = [] -for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent): +for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent): __all__ += getattr(module, "__all__", []) -del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent +del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent diff --git a/sub_agent/config/ocr.py b/sub_agent/config/ocr.py new file mode 100644 index 0000000..cc51f9c --- /dev/null +++ b/sub_agent/config/ocr.py @@ -0,0 +1,13 @@ +"""OCR 配置:DeepSeek-OCR 接口信息(子智能体)。""" + +OCR_API_BASE_URL = "https://api.siliconflow.cn" +OCR_API_KEY = "sk-suqqgewtlwajjkylvnotdhkzmsrshmrqptkakdqjmlrilaes" +OCR_MODEL_ID = "deepseek-ai/DeepSeek-OCR" +OCR_MAX_TOKENS = 4096 + +__all__ = [ + "OCR_API_BASE_URL", + "OCR_API_KEY", + "OCR_MODEL_ID", + "OCR_MAX_TOKENS", +] diff --git a/sub_agent/core/main_terminal.py b/sub_agent/core/main_terminal.py index 8a29368..0203fd4 100644 --- a/sub_agent/core/main_terminal.py +++ b/sub_agent/core/main_terminal.py @@ -40,6 +40,7 @@ from modules.terminal_manager import TerminalManager from modules.todo_manager import TodoManager from modules.sub_agent_manager import SubAgentManager from modules.webpage_extractor import extract_webpage_content, tavily_extract +from modules.ocr_client import OCRClient from core.tool_config import TOOL_CATEGORIES from utils.api_client import DeepSeekClient from utils.context_manager import ContextManager @@ -67,6 +68,7 @@ class MainTerminal: self.file_manager = FileManager(project_path) self.search_engine = SearchEngine() self.terminal_ops = TerminalOperator(project_path) + self.ocr_client = OCRClient(project_path, self.file_manager) # 新增:终端管理器 self.terminal_manager = TerminalManager( @@ -607,6 +609,8 @@ class MainTerminal: print(f"{OUTPUT_FORMATS['file']} 创建文件") elif tool_name == "read_file": print(f"{OUTPUT_FORMATS['file']} 读取文件") + elif tool_name == "ocr_image": + print(f"{OUTPUT_FORMATS['file']} 图片OCR") elif tool_name == "modify_file": print(f"{OUTPUT_FORMATS['file']} 修改文件") elif tool_name == "delete_file": @@ -987,6 +991,21 @@ class MainTerminal: } } }, + { + "type": "function", + "function": { + "name": "ocr_image", + "description": "使用 DeepSeek-OCR 读取图片中的文字或根据提示生成描述,仅支持本地图片路径。", + "parameters": { + "type": "object", + "properties": { + "path": {"type": "string", "description": "项目内的图片路径"}, + "prompt": {"type": "string", "description": "传递给 OCR 模型的提示词,如“请识别图片中的文字”"} + }, + "required": ["path", "prompt"] + } + } + }, { "type": "function", "function": { @@ -1464,6 +1483,12 @@ class MainTerminal: try: if tool_name == "read_file": result = self._handle_read_tool(arguments) + elif tool_name == "ocr_image": + path = arguments.get("path") + prompt = arguments.get("prompt") + if not path: + return json.dumps({"success": False, "error": "缺少 path 参数", "warnings": []}, ensure_ascii=False) + result = self.ocr_client.ocr_image(path=path, prompt=prompt or "") # 终端会话管理工具 elif tool_name == "terminal_session": diff --git a/sub_agent/core/tool_config.py b/sub_agent/core/tool_config.py index af4c746..4de22ee 100644 --- a/sub_agent/core/tool_config.py +++ b/sub_agent/core/tool_config.py @@ -32,7 +32,7 @@ TOOL_CATEGORIES: Dict[str, ToolCategory] = { ), "read_focus": ToolCategory( label="阅读聚焦", - tools=["read_file", "focus_file", "unfocus_file"], + tools=["read_file", "focus_file", "unfocus_file", "ocr_image"], ), "terminal_realtime": ToolCategory( label="实时终端", diff --git a/sub_agent/modules/ocr_client.py b/sub_agent/modules/ocr_client.py new file mode 100644 index 0000000..957e5b4 --- /dev/null +++ b/sub_agent/modules/ocr_client.py @@ -0,0 +1,95 @@ +"""DeepSeek-OCR 客户端(子智能体专用)。""" + +import base64 +import mimetypes +from pathlib import Path +from typing import Dict, List + +from openai import OpenAI + +from config import OCR_API_BASE_URL, OCR_API_KEY, OCR_MODEL_ID, OCR_MAX_TOKENS +from modules.file_manager import FileManager + + +class OCRClient: + """封装 DeepSeek-OCR 调用逻辑。""" + + def __init__(self, project_path: str, file_manager: FileManager): + self.project_path = Path(project_path).resolve() + self.file_manager = file_manager + + base_url = (OCR_API_BASE_URL or "").rstrip("/") + if not base_url.endswith("/v1"): + base_url = f"{base_url}/v1" + + self.client = OpenAI( + api_key=OCR_API_KEY, + base_url=base_url, + ) + self.model = OCR_MODEL_ID or "deepseek-ai/DeepSeek-OCR" + self.max_tokens = OCR_MAX_TOKENS or 4096 + self.max_image_size = 10 * 1024 * 1024 # 10MB + + def _validate_image_path(self, path: str): + valid, error, full_path = self.file_manager._validate_path(path) + if not valid: + return False, error, None + if not full_path.exists(): + return False, "文件不存在", None + if not full_path.is_file(): + return False, "不是文件", None + return True, "", full_path + + def ocr_image(self, path: str, prompt: str) -> Dict: + warnings: List[str] = [] + + valid, error, full_path = self._validate_image_path(path) + if not valid: + return {"success": False, "error": error, "warnings": warnings} + + if not prompt or not str(prompt).strip(): + return {"success": False, "error": "prompt 不能为空", "warnings": warnings} + + try: + data = full_path.read_bytes() + except Exception as exc: + return {"success": False, "error": f"读取文件失败: {exc}", "warnings": warnings} + + size = len(data) + if size <= 0: + return {"success": False, "error": "文件为空,无法识别", "warnings": warnings} + + if size > self.max_image_size: + return { + "success": False, + "error": f"图片过大({size}字节),上限为{self.max_image_size}字节", + "warnings": warnings, + } + + mime_type, _ = mimetypes.guess_type(str(full_path)) + if not mime_type or not mime_type.startswith("image/"): + warnings.append("无法确定图片类型,已按 JPEG 处理") + mime_type = "image/jpeg" + + base64_image = base64.b64encode(data).decode("utf-8") + data_url = f"data:{mime_type};base64,{base64_image}" + + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": data_url}}, + {"type": "text", "text": prompt}, + ], + } + ], + max_tokens=self.max_tokens, + temperature=0, + ) + content = response.choices[0].message.content if response.choices else "" + return {"success": True, "content": content or "", "warnings": warnings} + except Exception as exc: + return {"success": False, "error": f"OCR 调用失败: {exc}", "warnings": warnings}