agent-Specialization/server/chat_flow_pending_writes.py

547 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import json
import re
from typing import Dict, List, Optional
async def finalize_pending_append(*, pending_append, append_probe_buffer: str, response_text: str, stream_completed: bool, finish_reason: Optional[str] = None, web_terminal, sender, debug_log):
    """Finish a streamed ``append_to_file`` request after model output has ended.

    The text between the start marker and the end marker is appended to the
    target file via ``web_terminal.file_manager.append_file``. If no end marker
    was seen, the text runs up to the next ``<<<`` or to the end of the buffer.

    Args:
        pending_append: State dict collected while streaming, or a falsy value
            when no append is pending. Keys read here: ``path``,
            ``tool_call_id``, ``buffer``, ``start_marker``, ``end_marker``,
            ``content_start``, ``end_index``, ``display_id``.
        append_probe_buffer: Probe buffer carried by the caller. It is reset to
            ``""`` whenever a pending append is finalized, successfully or not.
        response_text: Raw assistant text. It is the default
            ``assistant_content`` and is replaced by the reconstructed
            marker-wrapped text once content is found.
        stream_completed: Accepted for interface compatibility; not used here.
        finish_reason: Model finish reason, echoed into the result and tool payloads.
        web_terminal: Object exposing ``file_manager.append_file(path, content)``
            (returning a dict with ``success`` / ``error``). It may optionally
            have a ``pending_append_request`` attribute, which is reset to None.
        sender: Callable ``sender(event_name, payload)`` used for ``update_action`` events.
        debug_log: Callable taking one message string.

    Returns:
        A tuple ``(result, pending_append, append_probe_buffer)``. If nothing
        was pending, the inputs are returned unchanged with an unhandled result.
        Otherwise the last two items are always ``None`` and ``""``.
    """
    result = {
        "handled": False,
        "success": False,
        "summary": None,
        "summary_message": None,
        "tool_content": None,
        "tool_call_id": None,
        "path": None,
        "forced": False,
        "error": None,
        "assistant_content": response_text,
        "lines": 0,
        "bytes": 0,
        "finish_reason": finish_reason,
        "appended_content": "",
        "assistant_metadata": None,
    }
    if not pending_append:
        return result, pending_append, append_probe_buffer

    state = pending_append
    path = state.get("path")
    tool_call_id = state.get("tool_call_id")
    buffer = state.get("buffer", "")
    start_marker = state.get("start_marker")
    end_marker = state.get("end_marker")
    start_idx = state.get("content_start")
    end_idx = state.get("end_index")
    display_id = state.get("display_id")
    result.update({
        "handled": True,
        "path": path,
        "tool_call_id": tool_call_id,
        "display_id": display_id,
    })

    def notify(status: str, message: str, payload: Optional[Dict] = None) -> None:
        """Send an ``update_action`` event for this tool call, if it has a UI id."""
        if not display_id:
            return
        event = {"id": display_id, "status": status}
        if payload is not None:
            event["result"] = payload
        event["preparing_id"] = tool_call_id
        event["message"] = message
        sender('update_action', event)

    def clear_terminal_request() -> None:
        """Reset ``web_terminal.pending_append_request`` if the attribute exists."""
        if hasattr(web_terminal, "pending_append_request"):
            web_terminal.pending_append_request = None

    def fail_early(error_msg: str, *, include_path: bool = True):
        """Record a failure detected before any write; finalize and return."""
        debug_log(error_msg)
        result["error"] = error_msg
        result["summary_message"] = error_msg
        tool_payload = {"success": False}
        if include_path:
            tool_payload["path"] = path
        tool_payload["error"] = error_msg
        result["tool_content"] = json.dumps(tool_payload, ensure_ascii=False)
        notify('failed', error_msg)
        # The pending append is discarded, so clear the related caller/terminal state too.
        clear_terminal_request()
        return result, None, ""

    if path is None or tool_call_id is None:
        return fail_early("append_to_file 状态不完整缺少路径或ID。", include_path=False)
    if start_idx is None:
        return fail_early(f"未检测到格式正确的开始标识 {start_marker}")

    forced = end_idx is None
    if forced:
        # No end marker arrived: stop at the next "<<<" after the content start,
        # or at the end of the buffer.
        # NOTE(review): the search starts len(end_marker) characters into the
        # content (kept from the original); the reason is not visible here.
        remaining = buffer[start_idx:]
        next_marker = remaining.find("<<<", len(end_marker or ""))
        end_idx = start_idx + next_marker if next_marker != -1 else len(buffer)

    content = buffer[start_idx:end_idx]
    # Drop the single line break that follows the start marker (LF or CRLF).
    if content.startswith("\r\n"):
        content = content[2:]
    elif content.startswith("\n"):
        content = content[1:]
    if not content:
        result["forced"] = forced
        return fail_early("未检测到需要追加的内容,请严格按照<<<APPEND:path>>>...<<<END_APPEND>>>格式输出。")

    # Rebuild the assistant message as marker + content (+ end marker if it was seen).
    assistant_message_lines = [start_marker] if start_marker else []
    assistant_message_lines.append(content)
    if not forced and end_marker:
        assistant_message_lines.append(end_marker)
    result["assistant_content"] = "\n".join(assistant_message_lines)
    append_payload = {
        "path": path,
        "tool_call_id": tool_call_id,
        "forced": forced,
        "has_end_marker": not forced,
    }
    result["assistant_metadata"] = {"append_payload": append_payload}

    # content is non-empty here; a trailing partial line still counts as a line.
    line_count = content.count("\n") + (0 if content.endswith("\n") else 1)
    bytes_written = len(content.encode("utf-8"))

    write_result = web_terminal.file_manager.append_file(path, content)
    if write_result.get("success"):
        summary = f"已向 {path} 追加 {line_count} 行({bytes_written} 字节)"
        if forced:
            summary += "。未检测到 <<<END_APPEND>>> 标记,系统已在流结束处完成写入。如内容未完成,请重新调用 append_to_file 并按标准格式补充;如已完成,可继续后续步骤。"
        result.update({
            "success": True,
            "summary": summary,
            "summary_message": summary,
            "forced": forced,
            "lines": line_count,
            "bytes": bytes_written,
            "appended_content": content,
            "tool_content": json.dumps({
                "success": True,
                "path": path,
                "lines": line_count,
                "bytes": bytes_written,
                "forced": forced,
                "message": summary,
                "finish_reason": finish_reason,
            }, ensure_ascii=False),
        })
        append_payload.update({"lines": line_count, "bytes": bytes_written, "success": True})
        notify('completed', summary, {
            "success": True,
            "path": path,
            "lines": line_count,
            "bytes": bytes_written,
            "forced": forced,
            "message": summary,
        })
        debug_log(f"追加写入完成: {summary}")
    else:
        error_msg = write_result.get("error", "追加写入失败")
        result.update({
            "error": error_msg,
            "summary_message": error_msg,
            "forced": forced,
            "appended_content": content,
            "tool_content": json.dumps({
                "success": False,
                "path": path,
                "error": error_msg,
                "finish_reason": finish_reason,
            }, ensure_ascii=False),
        })
        debug_log(f"追加写入失败: {error_msg}")
        append_payload.update({"lines": line_count, "bytes": bytes_written, "success": False})
        # Report a failed write as 'failed', consistent with every other failure path.
        notify('failed', error_msg, {
            "success": False,
            "path": path,
            "error": error_msg,
            "forced": forced,
        })

    clear_terminal_request()
    return result, None, ""
async def finalize_pending_modify(*, pending_modify, modify_probe_buffer: str, response_text: str, stream_completed: bool, finish_reason: Optional[str] = None, web_terminal, sender, debug_log):
    """Finish a streamed ``modify_file`` request after model output has ended.

    The function reads ``[replace:n] <<OLD>>...<<END>> <<NEW>>...<<END>> [/replace]``
    blocks from the buffered text. The text is cut at ``state["end_index"]``;
    if that is None, the whole buffer is used and the result is marked
    ``forced``. Well-formed blocks are applied via
    ``web_terminal.file_manager.apply_modify_blocks(path, blocks)``. A
    per-block report, a summary message and a JSON tool payload are then built.

    Args:
        pending_modify: State dict collected while streaming, or a falsy value
            when no modify is pending. Keys read here: ``path``,
            ``tool_call_id``, ``display_id``, ``start_marker``,
            ``end_marker``, ``buffer``, ``raw_buffer``, ``end_index``,
            ``start_seen``.
        modify_probe_buffer: Probe buffer carried by the caller. It is reset to
            ``""`` whenever a pending modify is finalized.
        response_text: Raw assistant text. It is the default
            ``assistant_content`` and is replaced by the raw marker-wrapped
            text when that text is non-empty.
        stream_completed: Accepted for interface compatibility; not used here.
        finish_reason: Model finish reason, echoed into the result and tool payload.
        web_terminal: Object exposing ``file_manager.apply_modify_blocks``, which
            returns a dict that may have ``results`` (items with ``index``,
            ``status``, ``removed_lines``, ``added_lines``, ``reason``,
            ``hint``) and ``error``. It may optionally have a
            ``pending_modify_request`` attribute, which is reset to None.
        sender: Callable ``sender(event_name, payload)`` used for ``update_action`` events.
        debug_log: Callable taking one message string.

    Returns:
        A tuple ``(result, pending_modify, modify_probe_buffer)``. If nothing
        was pending, the inputs are returned unchanged with an unhandled result.
        Otherwise the last two items are always ``None`` and ``""``.
    """
    result = {
        "handled": False,
        "success": False,
        "path": None,
        "tool_call_id": None,
        "display_id": None,
        "total_blocks": 0,
        "completed_blocks": [],
        "failed_blocks": [],
        "forced": False,
        "details": [],
        "error": None,
        "assistant_content": response_text,
        "assistant_metadata": None,
        "tool_content": None,
        "summary_message": None,
        "finish_reason": finish_reason,
    }
    if not pending_modify:
        return result, pending_modify, modify_probe_buffer

    state = pending_modify
    path = state.get("path")
    tool_call_id = state.get("tool_call_id")
    display_id = state.get("display_id")
    start_marker = state.get("start_marker")
    end_marker = state.get("end_marker")
    buffer = state.get("buffer", "")
    raw_buffer = state.get("raw_buffer", "")
    end_index = state.get("end_index")
    result.update({
        "handled": True,
        "path": path,
        "tool_call_id": tool_call_id,
        "display_id": display_id,
    })

    def clear_terminal_request() -> None:
        """Reset ``web_terminal.pending_modify_request`` if the attribute exists."""
        if hasattr(web_terminal, "pending_modify_request"):
            web_terminal.pending_modify_request = None

    if not state.get("start_seen"):
        error_msg = "未检测到格式正确的 <<<MODIFY:path>>> 标记。"
        debug_log(error_msg)
        result["error"] = error_msg
        result["summary_message"] = error_msg
        result["tool_content"] = json.dumps({
            "success": False,
            "path": path,
            "error": error_msg,
            "finish_reason": finish_reason,
        }, ensure_ascii=False)
        if display_id:
            sender('update_action', {
                'id': display_id,
                'status': 'failed',
                'preparing_id': tool_call_id,
                'message': error_msg,
            })
        clear_terminal_request()
        return result, None, ""

    forced = end_index is None
    if forced:
        apply_text = buffer
        raw_content = raw_buffer
    else:
        apply_text = buffer[:end_index]
        # Keep raw text through the end marker. This assumes raw_buffer is the
        # start marker followed by buffer (TODO confirm). A missing marker counts
        # as zero length instead of raising TypeError.
        raw_end = len(start_marker or "") + end_index + len(end_marker or "")
        raw_content = raw_buffer[:raw_end]
    if raw_content:
        result["assistant_content"] = raw_content

    block_pattern = re.compile(r"\[replace:(\d+)\](.*?)\[/replace\]", re.DOTALL)
    blocks_info: List[Dict] = []
    block_reports: Dict[int, Dict] = {}
    structure_warnings: List[str] = []
    structure_detail_entries: List[Dict] = []

    def record_structure_warning(message: str, hint: Optional[str] = None) -> None:
        """Record a structural defect once, as a synthetic failed detail with index 0."""
        if message in structure_warnings:
            return
        structure_warnings.append(message)
        structure_detail_entries.append({
            "index": 0,
            "status": "failed",
            "reason": message,
            "removed_lines": 0,
            "added_lines": 0,
            "hint": hint or "请严格按照模板输出:[replace:n] + <<OLD>>/<<NEW>> + [/replace],并使用 <<<END_MODIFY>>> 收尾。",
        })

    def extract_segment(body: str, tag: str):
        """Return ``(segment, None)`` for ``<<tag>>...<<END>>`` in *body*, else ``(None, error)``.

        One line break (LF or CRLF) right after the opening tag is skipped.
        """
        marker = f"<<{tag}>>"
        end_tag = "<<END>>"
        start_pos = body.find(marker)
        if start_pos == -1:
            return None, f"缺少 {marker}"
        start_pos += len(marker)
        if body.startswith("\r\n", start_pos):
            start_pos += 2
        elif body.startswith("\n", start_pos):
            start_pos += 1
        end_pos = body.find(end_tag, start_pos)
        if end_pos == -1:
            return None, f"缺少 {end_tag}"
        return body[start_pos:end_pos], None

    for match in block_pattern.finditer(apply_text):
        try:
            index = int(match.group(1))
        except ValueError:
            continue
        if index in block_reports:
            # Only the first block with a given index is used; later duplicates are ignored.
            continue
        report = {
            "index": index,
            "status": "pending",
            "reason": None,
            "removed_lines": 0,
            "added_lines": 0,
            "hint": None,
        }
        block_reports[index] = report
        body = match.group(2)
        old_content, old_error = extract_segment(body, "OLD")
        new_content, new_error = extract_segment(body, "NEW")
        error = old_error or new_error
        if error:
            # Malformed block: it is reported as failed and never sent to the file manager.
            report["status"] = "failed"
            report["reason"] = error
            report["hint"] = "请检查补丁块的 OLD/NEW 标记是否完整,必要时复用 terminal_snapshot 或终端命令重新调整。"
        blocks_info.append({
            "index": index,
            "old": old_content,
            "new": new_content,
            "error": error,
        })

    if not blocks_info:
        # No complete block was parsed; try to explain which markers are missing.
        has_replace_start = bool(re.search(r"\[replace:\s*\d+\]", apply_text))
        has_replace_end = "[/replace]" in apply_text
        has_old_tag = "<<OLD>>" in apply_text
        has_new_tag = "<<NEW>>" in apply_text
        if has_replace_start and not has_replace_end:
            record_structure_warning("检测到 [replace:n] 标记但缺少对应的 [/replace] 结束标记。")
        if has_replace_end and not has_replace_start:
            record_structure_warning("检测到 [/replace] 结束标记但缺少对应的 [replace:n] 起始标记。")
        old_tags = len(re.findall(r"<<OLD>>", apply_text))
        completed_old_tags = len(re.findall(r"<<OLD>>[\s\S]*?<<END>>", apply_text))
        if old_tags and completed_old_tags < old_tags:
            record_structure_warning("检测到 <<OLD>> 段落但未看到对应的 <<END>> 结束标记。")
        new_tags = len(re.findall(r"<<NEW>>", apply_text))
        completed_new_tags = len(re.findall(r"<<NEW>>[\s\S]*?<<END>>", apply_text))
        if new_tags and completed_new_tags < new_tags:
            record_structure_warning("检测到 <<NEW>> 段落但未看到对应的 <<END>> 结束标记。")
        if (has_replace_start or has_replace_end or has_old_tag or has_new_tag) and not structure_warnings:
            record_structure_warning("检测到部分补丁标记,但整体结构不完整,请严格按照模板填写所有标记。")

    total_blocks = len(blocks_info)
    result["total_blocks"] = total_blocks
    if forced:
        debug_log("未检测到 <<<END_MODIFY>>>,将在流结束处执行已识别的修改块。")
        result["forced"] = True

    blocks_to_apply = [
        {"index": info["index"], "old": info["old"], "new": info["new"]}
        for info in blocks_info
        if info["error"] is None and info["old"] is not None and info["new"] is not None
    ]
    if blocks_to_apply:
        apply_result = web_terminal.file_manager.apply_modify_blocks(path, blocks_to_apply)
    else:
        apply_result = {"success": False, "completed": [], "failed": [], "results": [], "write_performed": False, "error": None}
    block_result_map = {item["index"]: item for item in apply_result.get("results", [])}

    for idx, report in block_reports.items():
        if report["status"] == "failed":
            continue
        block_apply = block_result_map.get(idx)
        if not block_apply:
            report["status"] = "failed"
            report["reason"] = "未执行,可能未找到匹配原文"
            report["hint"] = report.get("hint") or "请确认 OLD 文本与文件内容完全一致;若多次失败,可改用终端命令/Python 进行精准替换。"
            continue
        status = block_apply.get("status")
        report["removed_lines"] = block_apply.get("removed_lines", 0)
        report["added_lines"] = block_apply.get("added_lines", 0)
        if block_apply.get("hint"):
            report["hint"] = block_apply.get("hint")
        if status == "success":
            report["status"] = "completed"
        elif status == "not_found":
            report["status"] = "failed"
            report["reason"] = block_apply.get("reason") or "未找到匹配的原文"
            if not report.get("hint"):
                report["hint"] = "请使用 terminal_snapshot/grep -n 校验原文,或在说明后改用 run_command/python 精确替换。"
        else:
            report["status"] = "failed"
            report["reason"] = block_apply.get("reason") or "替换失败"
            if not report.get("hint"):
                report["hint"] = block_apply.get("hint") or "若多次尝试仍失败,可考虑利用终端命令或 Python 小脚本完成此次修改。"

    completed_blocks = sorted(idx for idx, rep in block_reports.items() if rep["status"] == "completed")
    failed_blocks = sorted(idx for idx, rep in block_reports.items() if rep["status"] != "completed")
    result["completed_blocks"] = completed_blocks
    result["failed_blocks"] = failed_blocks
    details = structure_detail_entries + sorted(block_reports.values(), key=lambda rep: rep["index"])
    result["details"] = details

    summary_parts: List[str] = []
    if total_blocks == 0:
        summary_parts.append("未检测到有效的修改块,未执行任何修改。")
        summary_parts.extend(structure_warnings)
    elif not completed_blocks and failed_blocks:
        summary_parts.append(f"共检测到 {total_blocks} 个修改块,全部未执行。")
    elif completed_blocks and not failed_blocks:
        summary_parts.append(f"{total_blocks} 个修改块全部完成。")
    else:
        summary_parts.append(
            f"共检测到 {total_blocks} 个修改块,其中成功 {len(completed_blocks)} 个,失败 {len(failed_blocks)} 个。"
        )
    if forced:
        summary_parts.append("未检测到 <<<END_MODIFY>>> 标记,系统已在流结束处执行补丁。")
    if apply_result.get("error"):
        summary_parts.append(apply_result["error"])
    # The tag is <<OLD>> (two angle brackets), matching extract_segment and the
    # template hint; the previous text named a non-existent <<<OLD>>> tag.
    summary_parts.append("提示:补丁匹配基于完整文本,包含注释和空白符,请确保 <<OLD>> 段落与文件内容逐字一致。如果修改成功,请忽略,如果失败,请明确原文后再次尝试。")
    summary_message = " ".join(summary_parts).strip()
    result["summary_message"] = summary_message
    result["success"] = bool(completed_blocks) and not failed_blocks and apply_result.get("error") is None

    tool_payload = {
        "success": result["success"],
        "path": path,
        "total_blocks": total_blocks,
        "completed": completed_blocks,
        "failed": [
            {
                "index": rep["index"],
                "reason": rep.get("reason"),
                "hint": rep.get("hint"),
            }
            for rep in details
            if rep["status"] != "completed"
        ],
        "forced": forced,
        "message": summary_message,
        "finish_reason": finish_reason,
        "details": details,
    }
    if apply_result.get("error"):
        tool_payload["error"] = apply_result["error"]
    result["tool_content"] = json.dumps(tool_payload, ensure_ascii=False)
    result["assistant_metadata"] = {
        "modify_payload": {
            "path": path,
            "total_blocks": total_blocks,
            "completed": completed_blocks,
            "failed": failed_blocks,
            "forced": forced,
            "details": details,
        }
    }
    if display_id:
        sender('update_action', {
            'id': display_id,
            'status': 'completed' if result["success"] else 'failed',
            'result': tool_payload,
            'preparing_id': tool_call_id,
            'message': summary_message,
        })
    clear_terminal_request()
    return result, None, ""