# Source-view metadata (page-header residue): 547 lines · 20 KiB · Python
from __future__ import annotations

import json
import re
from typing import Any, Callable, Dict, List, Optional, Tuple


def _append_line_count(text: str) -> int:
    """Return how many lines *text* spans, counting an unterminated last line."""
    count = text.count("\n")
    if text and not text.endswith("\n"):
        count += 1
    return count


async def finalize_pending_append(
    *,
    pending_append: Optional[Dict[str, Any]],
    append_probe_buffer: str,
    response_text: str,
    stream_completed: bool,
    finish_reason: Optional[str] = None,
    web_terminal: Any,
    sender: Callable[[str, Dict[str, Any]], Any],
    debug_log: Callable[[str], Any],
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], str]:
    """Finish a streamed ``append_to_file`` call after the model output ends.

    Extracts the text between the append start marker and the end marker
    (or, when the end marker never arrived, up to the next ``<<<`` or the end
    of the buffer), appends it via ``web_terminal.file_manager.append_file``
    and reports the outcome to the UI through ``sender('update_action', ...)``.

    Args:
        pending_append: Stream state for the append request. Keys read:
            ``path``, ``tool_call_id``, ``buffer``, ``start_marker``,
            ``end_marker``, ``content_start``, ``end_index``, ``display_id``.
            A falsy value means there is nothing to finalize.
        append_probe_buffer: Caller-owned probe buffer; returned unchanged when
            there is nothing to finalize, otherwise reset to ``""``.
        response_text: Raw assistant text, used as the default
            ``assistant_content``.
        stream_completed: Currently unused; kept for interface compatibility.
        finish_reason: Model finish reason, echoed into the result/tool payload.
        web_terminal: Object exposing ``file_manager.append_file(path, text)``
            (returning a dict with ``success``/``error``) and optionally a
            ``pending_append_request`` attribute, which is cleared here.
        sender: Event emitter called as ``sender(event_name, payload)``.
        debug_log: Logging callable taking a single message string.

    Returns:
        ``(result, pending_append, append_probe_buffer)``. Once a request has
        been handled (success or failure) the state is ``None`` and the probe
        buffer is ``""``.
    """
    result: Dict[str, Any] = {
        "handled": False,
        "success": False,
        "summary": None,
        "summary_message": None,
        "tool_content": None,
        "tool_call_id": None,
        "path": None,
        "forced": False,
        "error": None,
        "assistant_content": response_text,
        "lines": 0,
        "bytes": 0,
        "finish_reason": finish_reason,
        "appended_content": "",
        "assistant_metadata": None,
    }

    if not pending_append:
        return result, pending_append, append_probe_buffer

    state = pending_append
    path = state.get("path")
    tool_call_id = state.get("tool_call_id")
    buffer = state.get("buffer", "")
    start_marker = state.get("start_marker")
    end_marker = state.get("end_marker")
    start_idx = state.get("content_start")
    end_idx = state.get("end_index")
    display_id = state.get("display_id")

    result.update({
        "handled": True,
        "path": path,
        "tool_call_id": tool_call_id,
        "display_id": display_id,
    })

    def clear_request() -> None:
        # The request is consumed whichever way it ends; leaving it set would
        # make the caller believe an append is still in flight.
        if hasattr(web_terminal, "pending_append_request"):
            web_terminal.pending_append_request = None

    def fail_early(error_msg: str, tool_payload: Dict[str, Any]):
        """Record a pre-write failure, notify the UI and reset all append state."""
        debug_log(error_msg)
        result["error"] = error_msg
        result["summary_message"] = error_msg
        result["tool_content"] = json.dumps(tool_payload, ensure_ascii=False)
        if display_id:
            sender('update_action', {
                'id': display_id,
                'status': 'failed',
                'preparing_id': tool_call_id,
                'message': error_msg,
            })
        clear_request()
        return result, None, ""

    if path is None or tool_call_id is None:
        error_msg = "append_to_file 状态不完整,缺少路径或ID。"
        return fail_early(error_msg, {"success": False, "error": error_msg})

    if start_idx is None:
        error_msg = f"未检测到格式正确的开始标识 {start_marker}。"
        return fail_early(error_msg, {"success": False, "path": path, "error": error_msg})

    forced = end_idx is None
    if forced:
        # No end marker arrived: stop at the next "<<<" (likely another marker)
        # or fall back to the end of the buffer. The search skips the first
        # len(end_marker) characters, as before; a missing marker skips none.
        remaining = buffer[start_idx:]
        next_marker = remaining.find("<<<", len(end_marker or ""))
        end_idx = start_idx + next_marker if next_marker != -1 else len(buffer)

    content = buffer[start_idx:end_idx]
    if content.startswith('\n'):
        # Drop the line break that directly follows the start marker.
        content = content[1:]

    if not content:
        error_msg = "未检测到需要追加的内容,请严格按照<<<APPEND:path>>>...<<<END_APPEND>>>格式输出。"
        result["forced"] = forced
        return fail_early(error_msg, {"success": False, "path": path, "error": error_msg})

    # Reconstruct the assistant message as it would appear in well-formed output.
    assistant_message_lines = []
    if start_marker:
        assistant_message_lines.append(start_marker)
    assistant_message_lines.append(content)
    if not forced and end_marker:
        assistant_message_lines.append(end_marker)
    result["assistant_content"] = "\n".join(assistant_message_lines)

    append_payload: Dict[str, Any] = {
        "path": path,
        "tool_call_id": tool_call_id,
        "forced": forced,
        "has_end_marker": not forced,
    }
    result["assistant_metadata"] = {"append_payload": append_payload}

    line_count = _append_line_count(content)
    bytes_written = len(content.encode('utf-8'))

    write_result = web_terminal.file_manager.append_file(path, content)
    if write_result.get("success"):
        summary = f"已向 {path} 追加 {line_count} 行({bytes_written} 字节)"
        if forced:
            summary += "。未检测到 <<<END_APPEND>>> 标记,系统已在流结束处完成写入。如内容未完成,请重新调用 append_to_file 并按标准格式补充;如已完成,可继续后续步骤。"

        result.update({
            "success": True,
            "summary": summary,
            "summary_message": summary,
            "forced": forced,
            "lines": line_count,
            "bytes": bytes_written,
            "appended_content": content,
            "tool_content": json.dumps({
                "success": True,
                "path": path,
                "lines": line_count,
                "bytes": bytes_written,
                "forced": forced,
                "message": summary,
                "finish_reason": finish_reason,
            }, ensure_ascii=False),
        })
        append_payload.update(lines=line_count, bytes=bytes_written, success=True)

        summary_payload = {
            "success": True,
            "path": path,
            "lines": line_count,
            "bytes": bytes_written,
            "forced": forced,
            "message": summary,
        }
        if display_id:
            sender('update_action', {
                'id': display_id,
                'status': 'completed',
                'result': summary_payload,
                'preparing_id': tool_call_id,
                'message': summary,
            })
        debug_log(f"追加写入完成: {summary}")
    else:
        # `or` (not a .get default) so an explicit None error still gets a message.
        error_msg = write_result.get("error") or "追加写入失败"
        result.update({
            "error": error_msg,
            "summary_message": error_msg,
            "forced": forced,
            "appended_content": content,
            "tool_content": json.dumps({
                "success": False,
                "path": path,
                "error": error_msg,
                "finish_reason": finish_reason,
            }, ensure_ascii=False),
        })
        debug_log(f"追加写入失败: {error_msg}")
        append_payload.update(lines=line_count, bytes=bytes_written, success=False)

        failure_payload = {
            "success": False,
            "path": path,
            "error": error_msg,
            "forced": forced,
        }
        if display_id:
            # Report a failed write as 'failed', matching finalize_pending_modify.
            sender('update_action', {
                'id': display_id,
                'status': 'failed',
                'result': failure_payload,
                'preparing_id': tool_call_id,
                'message': error_msg,
            })

    clear_request()
    return result, None, ""
_MODIFY_TEMPLATE_HINT = "请严格按照模板输出:[replace:n] + <<OLD>>/<<NEW>> + [/replace],并使用 <<<END_MODIFY>>> 收尾。"


def _extract_modify_segment(body: str, tag: str) -> Tuple[Optional[str], Optional[str]]:
    """Return the text between ``<<TAG>>`` and the following ``<<END>>`` in *body*.

    A single line break (``\\r\\n`` or ``\\n``) directly after the opening tag
    is skipped. Returns ``(segment, None)`` on success, or ``(None, reason)``
    when either the opening tag or the ``<<END>>`` tag is missing.
    """
    marker = f"<<{tag}>>"
    end_tag = "<<END>>"
    start_pos = body.find(marker)
    if start_pos == -1:
        return None, f"缺少 {marker}"
    start_pos += len(marker)
    if body.startswith("\r\n", start_pos):
        start_pos += 2
    elif body.startswith("\n", start_pos):
        start_pos += 1
    end_pos = body.find(end_tag, start_pos)
    if end_pos == -1:
        return None, f"缺少 {end_tag}"
    return body[start_pos:end_pos], None


def _modify_structure_warnings(apply_text: str) -> List[str]:
    """Explain why no complete ``[replace:n]...[/replace]`` block was parsed.

    Returns de-duplicated warning messages in detection order. The list is
    empty only when *apply_text* contains no patch markers at all.
    """
    warnings: List[str] = []

    def add(message: str) -> None:
        if message not in warnings:
            warnings.append(message)

    has_replace_start = bool(re.search(r"\[replace:\s*\d+\]", apply_text))
    has_replace_end = "[/replace]" in apply_text
    has_old_tag = "<<OLD>>" in apply_text
    has_new_tag = "<<NEW>>" in apply_text

    if has_replace_start and not has_replace_end:
        add("检测到 [replace:n] 标记但缺少对应的 [/replace] 结束标记。")
    if has_replace_end and not has_replace_start:
        add("检测到 [/replace] 结束标记但缺少对应的 [replace:n] 起始标记。")

    old_tags = apply_text.count("<<OLD>>")
    completed_old_tags = len(re.findall(r"<<OLD>>[\s\S]*?<<END>>", apply_text))
    if old_tags and completed_old_tags < old_tags:
        add("检测到 <<OLD>> 段落但未看到对应的 <<END>> 结束标记。")

    new_tags = apply_text.count("<<NEW>>")
    completed_new_tags = len(re.findall(r"<<NEW>>[\s\S]*?<<END>>", apply_text))
    if new_tags and completed_new_tags < new_tags:
        add("检测到 <<NEW>> 段落但未看到对应的 <<END>> 结束标记。")

    if (has_replace_start or has_replace_end or has_old_tag or has_new_tag) and not warnings:
        add("检测到部分补丁标记,但整体结构不完整,请严格按照模板填写所有标记。")
    return warnings


async def finalize_pending_modify(
    *,
    pending_modify: Optional[Dict[str, Any]],
    modify_probe_buffer: str,
    response_text: str,
    stream_completed: bool,
    finish_reason: Optional[str] = None,
    web_terminal: Any,
    sender: Callable[[str, Dict[str, Any]], Any],
    debug_log: Callable[[str], Any],
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], str]:
    """Finish a streamed ``modify_file`` call after the model output ends.

    Parses ``[replace:n] <<OLD>>...<<END>> <<NEW>>...<<END>> [/replace]``
    blocks from the buffered text. Well-formed blocks are applied with
    ``web_terminal.file_manager.apply_modify_blocks``. A per-block report, a
    summary message and a JSON tool payload are built, and the result is sent
    to the UI via ``sender('update_action', ...)``.

    Args:
        pending_modify: Stream state. Keys read: ``path``, ``tool_call_id``,
            ``display_id``, ``start_marker``, ``end_marker``, ``buffer``,
            ``raw_buffer``, ``end_index``, ``start_seen``. A falsy value means
            there is nothing to finalize.
        modify_probe_buffer: Caller-owned probe buffer; returned unchanged when
            there is nothing to finalize, otherwise reset to ``""``.
        response_text: Raw assistant text, the default ``assistant_content``.
        stream_completed: Currently unused; kept for interface compatibility.
        finish_reason: Model finish reason, echoed into the result/tool payload.
        web_terminal: Object exposing ``file_manager.apply_modify_blocks(path,
            blocks)`` and optionally ``pending_modify_request`` (cleared here).
        sender: Event emitter called as ``sender(event_name, payload)``.
        debug_log: Logging callable taking a single message string.

    Returns:
        ``(result, pending_modify, modify_probe_buffer)``. Once handled, the
        state is ``None`` and the probe buffer is ``""``.
    """
    result: Dict[str, Any] = {
        "handled": False,
        "success": False,
        "path": None,
        "tool_call_id": None,
        "display_id": None,
        "total_blocks": 0,
        "completed_blocks": [],
        "failed_blocks": [],
        "forced": False,
        "details": [],
        "error": None,
        "assistant_content": response_text,
        "assistant_metadata": None,
        "tool_content": None,
        "summary_message": None,
        "finish_reason": finish_reason,
    }

    if not pending_modify:
        return result, pending_modify, modify_probe_buffer

    state = pending_modify
    path = state.get("path")
    tool_call_id = state.get("tool_call_id")
    display_id = state.get("display_id")
    start_marker = state.get("start_marker")
    end_marker = state.get("end_marker")
    buffer = state.get("buffer", "")
    raw_buffer = state.get("raw_buffer", "")
    end_index = state.get("end_index")

    result.update({
        "handled": True,
        "path": path,
        "tool_call_id": tool_call_id,
        "display_id": display_id,
    })

    def clear_request() -> None:
        # The request is consumed whichever way it ends.
        if hasattr(web_terminal, "pending_modify_request"):
            web_terminal.pending_modify_request = None

    if not state.get("start_seen"):
        error_msg = "未检测到格式正确的 <<<MODIFY:path>>> 标记。"
        debug_log(error_msg)
        result["error"] = error_msg
        result["summary_message"] = error_msg
        result["tool_content"] = json.dumps({
            "success": False,
            "path": path,
            "error": error_msg,
            "finish_reason": finish_reason,
        }, ensure_ascii=False)
        if display_id:
            sender('update_action', {
                'id': display_id,
                'status': 'failed',
                'preparing_id': tool_call_id,
                'message': error_msg,
            })
        clear_request()
        return result, None, ""

    # Without an end marker, everything buffered so far is treated as the patch.
    forced = end_index is None
    if forced:
        apply_text = buffer
        raw_content = raw_buffer
    else:
        apply_text = buffer[:end_index]
        # NOTE(review): assumes raw_buffer == start_marker + buffer (+ end_marker);
        # missing markers count as zero length instead of raising TypeError.
        raw_end = len(start_marker or "") + end_index + len(end_marker or "")
        raw_content = raw_buffer[:raw_end]
    if raw_content:
        result["assistant_content"] = raw_content

    # --- Parse replace blocks --------------------------------------------
    block_pattern = re.compile(r"\[replace:(\d+)\](.*?)\[/replace\]", re.DOTALL)
    blocks_info: List[Dict[str, Any]] = []
    block_reports: Dict[int, Dict[str, Any]] = {}
    for match in block_pattern.finditer(apply_text):
        # `\d+` only matches Unicode decimal digits, all of which int() accepts.
        index = int(match.group(1))
        if index in block_reports:
            continue  # the first block with a given index wins
        body = match.group(2)
        old_content, old_error = _extract_modify_segment(body, "OLD")
        new_content, new_error = _extract_modify_segment(body, "NEW")
        error = old_error or new_error
        report: Dict[str, Any] = {
            "index": index,
            "status": "pending",
            "reason": None,
            "removed_lines": 0,
            "added_lines": 0,
            "hint": None,
        }
        if error:
            report["status"] = "failed"
            report["reason"] = error
            report["hint"] = "请检查补丁块的 OLD/NEW 标记是否完整,必要时复用 terminal_snapshot 或终端命令重新调整。"
        block_reports[index] = report
        blocks_info.append({"index": index, "old": old_content, "new": new_content, "error": error})

    # Structural diagnostics only matter when nothing could be parsed.
    structure_warnings = [] if blocks_info else _modify_structure_warnings(apply_text)
    structure_detail_entries = [
        {
            "index": 0,
            "status": "failed",
            "reason": message,
            "removed_lines": 0,
            "added_lines": 0,
            "hint": _MODIFY_TEMPLATE_HINT,
        }
        for message in structure_warnings
    ]

    total_blocks = len(blocks_info)
    result["total_blocks"] = total_blocks
    if forced:
        debug_log("未检测到 <<<END_MODIFY>>>,将在流结束处执行已识别的修改块。")
        result["forced"] = True

    # --- Apply well-formed blocks ----------------------------------------
    blocks_to_apply = [
        {"index": block["index"], "old": block["old"], "new": block["new"]}
        for block in blocks_info
        if block["error"] is None and block["old"] is not None and block["new"] is not None
    ]
    if blocks_to_apply:
        apply_result = web_terminal.file_manager.apply_modify_blocks(path, blocks_to_apply)
    else:
        apply_result = {"success": False, "completed": [], "failed": [], "results": [], "write_performed": False, "error": None}

    block_result_map = {item["index"]: item for item in apply_result.get("results", [])}

    for block in blocks_info:
        report = block_reports.get(block["index"])
        if report is None or report["status"] == "failed":
            continue
        block_apply = block_result_map.get(block["index"])
        if not block_apply:
            report["status"] = "failed"
            report["reason"] = "未执行,可能未找到匹配原文"
            report["hint"] = report.get("hint") or "请确认 OLD 文本与文件内容完全一致;若多次失败,可改用终端命令/Python 进行精准替换。"
            continue

        status = block_apply.get("status")
        report["removed_lines"] = block_apply.get("removed_lines", 0)
        report["added_lines"] = block_apply.get("added_lines", 0)
        if block_apply.get("hint"):
            report["hint"] = block_apply.get("hint")

        if status == "success":
            report["status"] = "completed"
        elif status == "not_found":
            report["status"] = "failed"
            report["reason"] = block_apply.get("reason") or "未找到匹配的原文"
            if not report.get("hint"):
                report["hint"] = "请使用 terminal_snapshot/grep -n 校验原文,或在说明后改用 run_command/python 精确替换。"
        else:
            report["status"] = "failed"
            report["reason"] = block_apply.get("reason") or "替换失败"
            if not report.get("hint"):
                report["hint"] = block_apply.get("hint") or "若多次尝试仍失败,可考虑利用终端命令或 Python 小脚本完成此次修改。"

    completed_blocks = sorted(idx for idx, rep in block_reports.items() if rep["status"] == "completed")
    failed_blocks = sorted(idx for idx, rep in block_reports.items() if rep["status"] != "completed")
    result["completed_blocks"] = completed_blocks
    result["failed_blocks"] = failed_blocks

    details = sorted(block_reports.values(), key=lambda rep: rep["index"])
    if structure_detail_entries:
        details = structure_detail_entries + details
    result["details"] = details

    # --- Summary -----------------------------------------------------------
    summary_parts: List[str] = []
    if total_blocks == 0:
        summary_parts.append("未检测到有效的修改块,未执行任何修改。")
        summary_parts.extend(structure_warnings)
    elif not completed_blocks and failed_blocks:
        summary_parts.append(f"共检测到 {total_blocks} 个修改块,全部未执行。")
    elif completed_blocks and not failed_blocks:
        summary_parts.append(f"共 {total_blocks} 个修改块全部完成。")
    else:
        summary_parts.append(
            f"共检测到 {total_blocks} 个修改块,其中成功 {len(completed_blocks)} 个,失败 {len(failed_blocks)} 个。"
        )
    if forced:
        summary_parts.append("未检测到 <<<END_MODIFY>>> 标记,系统已在流结束处执行补丁。")
    if apply_result.get("error"):
        summary_parts.append(apply_result["error"])

    # The tag name must match what the parser looks for (<<OLD>>, not <<<OLD>>>),
    # otherwise the model is steered towards a format that cannot match.
    matching_note = "提示:补丁匹配基于完整文本,包含注释和空白符,请确保 <<OLD>> 段落与文件内容逐字一致。如果修改成功,请忽略,如果失败,请明确原文后再次尝试。"
    summary_parts.append(matching_note)
    summary_message = " ".join(summary_parts).strip()
    result["summary_message"] = summary_message
    result["success"] = bool(completed_blocks) and not failed_blocks and apply_result.get("error") is None

    tool_payload: Dict[str, Any] = {
        "success": result["success"],
        "path": path,
        "total_blocks": total_blocks,
        "completed": completed_blocks,
        "failed": [
            {"index": rep["index"], "reason": rep.get("reason"), "hint": rep.get("hint")}
            for rep in details
            if rep["status"] != "completed"
        ],
        "forced": forced,
        "message": summary_message,
        "finish_reason": finish_reason,
        "details": details,
    }
    if apply_result.get("error"):
        tool_payload["error"] = apply_result["error"]

    result["tool_content"] = json.dumps(tool_payload, ensure_ascii=False)
    result["assistant_metadata"] = {
        "modify_payload": {
            "path": path,
            "total_blocks": total_blocks,
            "completed": completed_blocks,
            "failed": failed_blocks,
            "forced": forced,
            "details": details,
        }
    }

    if display_id:
        sender('update_action', {
            'id': display_id,
            'status': 'completed' if result["success"] else 'failed',
            'result': tool_payload,
            'preparing_id': tool_call_id,
            'message': summary_message,
        })

    clear_request()
    return result, None, ""