feat: add ocr tool for main and sub agent
This commit is contained in:
parent
92c3e25f7a
commit
adb1f1249a
@ -8,6 +8,7 @@ from . import conversation as _conversation
|
||||
from . import security as _security
|
||||
from . import ui as _ui
|
||||
from . import memory as _memory
|
||||
from . import ocr as _ocr
|
||||
from . import todo as _todo
|
||||
from . import auth as _auth
|
||||
from . import sub_agent as _sub_agent
|
||||
@ -20,12 +21,13 @@ from .conversation import *
|
||||
from .security import *
|
||||
from .ui import *
|
||||
from .memory import *
|
||||
from .ocr import *
|
||||
from .todo import *
|
||||
from .auth import *
|
||||
from .sub_agent import *
|
||||
|
||||
__all__ = []
|
||||
for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent):
|
||||
for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent):
|
||||
__all__ += getattr(module, "__all__", [])
|
||||
|
||||
del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent
|
||||
del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent
|
||||
|
||||
13
config/ocr.py
Normal file
13
config/ocr.py
Normal file
@ -0,0 +1,13 @@
|
||||
"""OCR 配置:DeepSeek-OCR 接口信息。"""
|
||||
|
||||
OCR_API_BASE_URL = "https://api.siliconflow.cn"
|
||||
OCR_API_KEY = "sk-suqqgewtlwajjkylvnotdhkzmsrshmrqptkakdqjmlrilaes"
|
||||
OCR_MODEL_ID = "deepseek-ai/DeepSeek-OCR"
|
||||
OCR_MAX_TOKENS = 4096
|
||||
|
||||
__all__ = [
|
||||
"OCR_API_BASE_URL",
|
||||
"OCR_API_KEY",
|
||||
"OCR_MODEL_ID",
|
||||
"OCR_MAX_TOKENS",
|
||||
]
|
||||
@ -40,6 +40,7 @@ from modules.terminal_manager import TerminalManager
|
||||
from modules.todo_manager import TodoManager
|
||||
from modules.sub_agent_manager import SubAgentManager
|
||||
from modules.webpage_extractor import extract_webpage_content, tavily_extract
|
||||
from modules.ocr_client import OCRClient
|
||||
from core.tool_config import TOOL_CATEGORIES
|
||||
from utils.api_client import DeepSeekClient
|
||||
from utils.context_manager import ContextManager
|
||||
@ -67,6 +68,7 @@ class MainTerminal:
|
||||
self.file_manager = FileManager(project_path)
|
||||
self.search_engine = SearchEngine()
|
||||
self.terminal_ops = TerminalOperator(project_path)
|
||||
self.ocr_client = OCRClient(project_path, self.file_manager)
|
||||
|
||||
# 新增:终端管理器
|
||||
self.terminal_manager = TerminalManager(
|
||||
@ -607,6 +609,8 @@ class MainTerminal:
|
||||
print(f"{OUTPUT_FORMATS['file']} 创建文件")
|
||||
elif tool_name == "read_file":
|
||||
print(f"{OUTPUT_FORMATS['file']} 读取文件")
|
||||
elif tool_name == "ocr_image":
|
||||
print(f"{OUTPUT_FORMATS['file']} 图片OCR")
|
||||
elif tool_name == "modify_file":
|
||||
print(f"{OUTPUT_FORMATS['file']} 修改文件")
|
||||
elif tool_name == "delete_file":
|
||||
@ -987,6 +991,21 @@ class MainTerminal:
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ocr_image",
|
||||
"description": "使用 DeepSeek-OCR 读取图片中的文字或根据提示生成描述,仅支持本地图片路径。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "项目内的图片路径"},
|
||||
"prompt": {"type": "string", "description": "传递给 OCR 模型的提示词,如“请识别图片中的文字”"}
|
||||
},
|
||||
"required": ["path", "prompt"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
@ -1478,6 +1497,12 @@ class MainTerminal:
|
||||
try:
|
||||
if tool_name == "read_file":
|
||||
result = self._handle_read_tool(arguments)
|
||||
elif tool_name == "ocr_image":
|
||||
path = arguments.get("path")
|
||||
prompt = arguments.get("prompt")
|
||||
if not path:
|
||||
return json.dumps({"success": False, "error": "缺少 path 参数", "warnings": []}, ensure_ascii=False)
|
||||
result = self.ocr_client.ocr_image(path=path, prompt=prompt or "")
|
||||
|
||||
# 终端会话管理工具
|
||||
elif tool_name == "terminal_session":
|
||||
|
||||
@ -32,7 +32,7 @@ TOOL_CATEGORIES: Dict[str, ToolCategory] = {
|
||||
),
|
||||
"read_focus": ToolCategory(
|
||||
label="阅读聚焦",
|
||||
tools=["read_file", "focus_file", "unfocus_file"],
|
||||
tools=["read_file", "focus_file", "unfocus_file", "ocr_image"],
|
||||
),
|
||||
"terminal_realtime": ToolCategory(
|
||||
label="实时终端",
|
||||
|
||||
100
modules/ocr_client.py
Normal file
100
modules/ocr_client.py
Normal file
@ -0,0 +1,100 @@
|
||||
"""DeepSeek-OCR 客户端(主智能体专用)。"""
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from config import OCR_API_BASE_URL, OCR_API_KEY, OCR_MODEL_ID, OCR_MAX_TOKENS
|
||||
from modules.file_manager import FileManager
|
||||
|
||||
|
||||
class OCRClient:
|
||||
"""封装 DeepSeek-OCR 调用逻辑。"""
|
||||
|
||||
def __init__(self, project_path: str, file_manager: FileManager):
|
||||
self.project_path = Path(project_path).resolve()
|
||||
self.file_manager = file_manager
|
||||
|
||||
# 补全 base_url,兼容是否包含 /v1
|
||||
base_url = (OCR_API_BASE_URL or "").rstrip("/")
|
||||
if not base_url.endswith("/v1"):
|
||||
base_url = f"{base_url}/v1"
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=OCR_API_KEY,
|
||||
base_url=base_url,
|
||||
)
|
||||
self.model = OCR_MODEL_ID or "deepseek-ai/DeepSeek-OCR"
|
||||
self.max_tokens = OCR_MAX_TOKENS or 4096
|
||||
|
||||
# 默认大小上限(10MB),超出则警告并拒绝
|
||||
self.max_image_size = 10 * 1024 * 1024
|
||||
|
||||
def _validate_image_path(self, path: str):
|
||||
"""复用 FileManager 的路径校验,确保在项目内。"""
|
||||
valid, error, full_path = self.file_manager._validate_path(path)
|
||||
if not valid:
|
||||
return False, error, None
|
||||
if not full_path.exists():
|
||||
return False, "文件不存在", None
|
||||
if not full_path.is_file():
|
||||
return False, "不是文件", None
|
||||
return True, "", full_path
|
||||
|
||||
def ocr_image(self, path: str, prompt: str) -> Dict:
|
||||
"""执行 OCR,返回最简结果格式。"""
|
||||
warnings: List[str] = []
|
||||
|
||||
valid, error, full_path = self._validate_image_path(path)
|
||||
if not valid:
|
||||
return {"success": False, "error": error, "warnings": warnings}
|
||||
|
||||
if not prompt or not str(prompt).strip():
|
||||
return {"success": False, "error": "prompt 不能为空", "warnings": warnings}
|
||||
|
||||
try:
|
||||
data = full_path.read_bytes()
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": f"读取文件失败: {exc}", "warnings": warnings}
|
||||
|
||||
size = len(data)
|
||||
if size <= 0:
|
||||
return {"success": False, "error": "文件为空,无法识别", "warnings": warnings}
|
||||
|
||||
if size > self.max_image_size:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"图片过大({size}字节),上限为{self.max_image_size}字节",
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(str(full_path))
|
||||
if not mime_type or not mime_type.startswith("image/"):
|
||||
warnings.append("无法确定图片类型,已按 JPEG 处理")
|
||||
mime_type = "image/jpeg"
|
||||
|
||||
base64_image = base64.b64encode(data).decode("utf-8")
|
||||
data_url = f"data:{mime_type};base64,{base64_image}"
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": data_url}},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=0,
|
||||
)
|
||||
content = response.choices[0].message.content if response.choices else ""
|
||||
return {"success": True, "content": content or "", "warnings": warnings}
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": f"OCR 调用失败: {exc}", "warnings": warnings}
|
||||
@ -8,6 +8,7 @@ from . import conversation as _conversation
|
||||
from . import security as _security
|
||||
from . import ui as _ui
|
||||
from . import memory as _memory
|
||||
from . import ocr as _ocr
|
||||
from . import todo as _todo
|
||||
from . import auth as _auth
|
||||
from . import sub_agent as _sub_agent
|
||||
@ -20,12 +21,13 @@ from .conversation import *
|
||||
from .security import *
|
||||
from .ui import *
|
||||
from .memory import *
|
||||
from .ocr import *
|
||||
from .todo import *
|
||||
from .auth import *
|
||||
from .sub_agent import *
|
||||
|
||||
__all__ = []
|
||||
for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent):
|
||||
for module in (_api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent):
|
||||
__all__ += getattr(module, "__all__", [])
|
||||
|
||||
del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _todo, _auth, _sub_agent
|
||||
del _api, _paths, _limits, _terminal, _conversation, _security, _ui, _memory, _ocr, _todo, _auth, _sub_agent
|
||||
|
||||
13
sub_agent/config/ocr.py
Normal file
13
sub_agent/config/ocr.py
Normal file
@ -0,0 +1,13 @@
|
||||
"""OCR 配置:DeepSeek-OCR 接口信息(子智能体)。"""
|
||||
|
||||
OCR_API_BASE_URL = "https://api.siliconflow.cn"
|
||||
OCR_API_KEY = "sk-suqqgewtlwajjkylvnotdhkzmsrshmrqptkakdqjmlrilaes"
|
||||
OCR_MODEL_ID = "deepseek-ai/DeepSeek-OCR"
|
||||
OCR_MAX_TOKENS = 4096
|
||||
|
||||
__all__ = [
|
||||
"OCR_API_BASE_URL",
|
||||
"OCR_API_KEY",
|
||||
"OCR_MODEL_ID",
|
||||
"OCR_MAX_TOKENS",
|
||||
]
|
||||
@ -40,6 +40,7 @@ from modules.terminal_manager import TerminalManager
|
||||
from modules.todo_manager import TodoManager
|
||||
from modules.sub_agent_manager import SubAgentManager
|
||||
from modules.webpage_extractor import extract_webpage_content, tavily_extract
|
||||
from modules.ocr_client import OCRClient
|
||||
from core.tool_config import TOOL_CATEGORIES
|
||||
from utils.api_client import DeepSeekClient
|
||||
from utils.context_manager import ContextManager
|
||||
@ -67,6 +68,7 @@ class MainTerminal:
|
||||
self.file_manager = FileManager(project_path)
|
||||
self.search_engine = SearchEngine()
|
||||
self.terminal_ops = TerminalOperator(project_path)
|
||||
self.ocr_client = OCRClient(project_path, self.file_manager)
|
||||
|
||||
# 新增:终端管理器
|
||||
self.terminal_manager = TerminalManager(
|
||||
@ -607,6 +609,8 @@ class MainTerminal:
|
||||
print(f"{OUTPUT_FORMATS['file']} 创建文件")
|
||||
elif tool_name == "read_file":
|
||||
print(f"{OUTPUT_FORMATS['file']} 读取文件")
|
||||
elif tool_name == "ocr_image":
|
||||
print(f"{OUTPUT_FORMATS['file']} 图片OCR")
|
||||
elif tool_name == "modify_file":
|
||||
print(f"{OUTPUT_FORMATS['file']} 修改文件")
|
||||
elif tool_name == "delete_file":
|
||||
@ -987,6 +991,21 @@ class MainTerminal:
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ocr_image",
|
||||
"description": "使用 DeepSeek-OCR 读取图片中的文字或根据提示生成描述,仅支持本地图片路径。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "项目内的图片路径"},
|
||||
"prompt": {"type": "string", "description": "传递给 OCR 模型的提示词,如“请识别图片中的文字”"}
|
||||
},
|
||||
"required": ["path", "prompt"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
@ -1464,6 +1483,12 @@ class MainTerminal:
|
||||
try:
|
||||
if tool_name == "read_file":
|
||||
result = self._handle_read_tool(arguments)
|
||||
elif tool_name == "ocr_image":
|
||||
path = arguments.get("path")
|
||||
prompt = arguments.get("prompt")
|
||||
if not path:
|
||||
return json.dumps({"success": False, "error": "缺少 path 参数", "warnings": []}, ensure_ascii=False)
|
||||
result = self.ocr_client.ocr_image(path=path, prompt=prompt or "")
|
||||
|
||||
# 终端会话管理工具
|
||||
elif tool_name == "terminal_session":
|
||||
|
||||
@ -32,7 +32,7 @@ TOOL_CATEGORIES: Dict[str, ToolCategory] = {
|
||||
),
|
||||
"read_focus": ToolCategory(
|
||||
label="阅读聚焦",
|
||||
tools=["read_file", "focus_file", "unfocus_file"],
|
||||
tools=["read_file", "focus_file", "unfocus_file", "ocr_image"],
|
||||
),
|
||||
"terminal_realtime": ToolCategory(
|
||||
label="实时终端",
|
||||
|
||||
95
sub_agent/modules/ocr_client.py
Normal file
95
sub_agent/modules/ocr_client.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""DeepSeek-OCR 客户端(子智能体专用)。"""
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from config import OCR_API_BASE_URL, OCR_API_KEY, OCR_MODEL_ID, OCR_MAX_TOKENS
|
||||
from modules.file_manager import FileManager
|
||||
|
||||
|
||||
class OCRClient:
|
||||
"""封装 DeepSeek-OCR 调用逻辑。"""
|
||||
|
||||
def __init__(self, project_path: str, file_manager: FileManager):
|
||||
self.project_path = Path(project_path).resolve()
|
||||
self.file_manager = file_manager
|
||||
|
||||
base_url = (OCR_API_BASE_URL or "").rstrip("/")
|
||||
if not base_url.endswith("/v1"):
|
||||
base_url = f"{base_url}/v1"
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=OCR_API_KEY,
|
||||
base_url=base_url,
|
||||
)
|
||||
self.model = OCR_MODEL_ID or "deepseek-ai/DeepSeek-OCR"
|
||||
self.max_tokens = OCR_MAX_TOKENS or 4096
|
||||
self.max_image_size = 10 * 1024 * 1024 # 10MB
|
||||
|
||||
def _validate_image_path(self, path: str):
|
||||
valid, error, full_path = self.file_manager._validate_path(path)
|
||||
if not valid:
|
||||
return False, error, None
|
||||
if not full_path.exists():
|
||||
return False, "文件不存在", None
|
||||
if not full_path.is_file():
|
||||
return False, "不是文件", None
|
||||
return True, "", full_path
|
||||
|
||||
def ocr_image(self, path: str, prompt: str) -> Dict:
|
||||
warnings: List[str] = []
|
||||
|
||||
valid, error, full_path = self._validate_image_path(path)
|
||||
if not valid:
|
||||
return {"success": False, "error": error, "warnings": warnings}
|
||||
|
||||
if not prompt or not str(prompt).strip():
|
||||
return {"success": False, "error": "prompt 不能为空", "warnings": warnings}
|
||||
|
||||
try:
|
||||
data = full_path.read_bytes()
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": f"读取文件失败: {exc}", "warnings": warnings}
|
||||
|
||||
size = len(data)
|
||||
if size <= 0:
|
||||
return {"success": False, "error": "文件为空,无法识别", "warnings": warnings}
|
||||
|
||||
if size > self.max_image_size:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"图片过大({size}字节),上限为{self.max_image_size}字节",
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(str(full_path))
|
||||
if not mime_type or not mime_type.startswith("image/"):
|
||||
warnings.append("无法确定图片类型,已按 JPEG 处理")
|
||||
mime_type = "image/jpeg"
|
||||
|
||||
base64_image = base64.b64encode(data).decode("utf-8")
|
||||
data_url = f"data:{mime_type};base64,{base64_image}"
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": data_url}},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=0,
|
||||
)
|
||||
content = response.choices[0].message.content if response.choices else ""
|
||||
return {"success": True, "content": content or "", "warnings": warnings}
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": f"OCR 调用失败: {exc}", "warnings": warnings}
|
||||
Loading…
Reference in New Issue
Block a user