From 3402375710f4c562dd576e5510b6a9858c3a89da Mon Sep 17 00:00:00 2001 From: JOJO <1498581755@qq.com> Date: Sat, 7 Mar 2026 20:37:58 +0800 Subject: [PATCH] refactor: split pending writes and task support from chat task main --- server/chat_flow_pending_writes.py | 546 ++++++++++++++++++++++++ server/chat_flow_task_main.py | 648 +---------------------------- server/chat_flow_task_support.py | 94 +++++ 3 files changed, 652 insertions(+), 636 deletions(-) create mode 100644 server/chat_flow_pending_writes.py create mode 100644 server/chat_flow_task_support.py diff --git a/server/chat_flow_pending_writes.py b/server/chat_flow_pending_writes.py new file mode 100644 index 0000000..c131f83 --- /dev/null +++ b/server/chat_flow_pending_writes.py @@ -0,0 +1,546 @@ +from __future__ import annotations + +import json +import re +from typing import Dict, List, Optional + + +async def finalize_pending_append(*, pending_append, append_probe_buffer: str, response_text: str, stream_completed: bool, finish_reason: str = None, web_terminal, sender, debug_log): + """在流式输出结束后处理追加写入""" + + result = { + "handled": False, + "success": False, + "summary": None, + "summary_message": None, + "tool_content": None, + "tool_call_id": None, + "path": None, + "forced": False, + "error": None, + "assistant_content": response_text, + "lines": 0, + "bytes": 0, + "finish_reason": finish_reason, + "appended_content": "", + "assistant_metadata": None + } + + if not pending_append: + return result, pending_append, append_probe_buffer + + state = pending_append + path = state.get("path") + tool_call_id = state.get("tool_call_id") + buffer = state.get("buffer", "") + start_marker = state.get("start_marker") + end_marker = state.get("end_marker") + start_idx = state.get("content_start") + end_idx = state.get("end_index") + + display_id = state.get("display_id") + + result.update({ + "handled": True, + "path": path, + "tool_call_id": tool_call_id, + "display_id": display_id + }) + + if path is None or tool_call_id is None: + error_msg = "append_to_file 状态不完整,缺少路径或ID。" + debug_log(error_msg) + result["error"] = error_msg + result["summary_message"] = error_msg + result["tool_content"] = json.dumps({ + "success": False, + "error": error_msg + }, ensure_ascii=False) + if display_id: + sender('update_action', { + 'id': display_id, + 'status': 'failed', + 'preparing_id': tool_call_id, + 'message': error_msg + }) + pending_append = None + return result, pending_append, append_probe_buffer + + if start_idx is None: + error_msg = f"未检测到格式正确的开始标识 {start_marker}。" + debug_log(error_msg) + result["error"] = error_msg + result["summary_message"] = error_msg + result["tool_content"] = json.dumps({ + "success": False, + "path": path, + "error": error_msg + }, ensure_ascii=False) + if display_id: + sender('update_action', { + 'id': display_id, + 'status': 'failed', + 'preparing_id': tool_call_id, + 'message': error_msg + }) + pending_append = None + return result, pending_append, append_probe_buffer + + forced = False + if end_idx is None: + forced = True + # 查找下一个<<<,否则使用整个缓冲结尾 + remaining = buffer[start_idx:] + next_marker = remaining.find("<<<", len(end_marker)) + if next_marker != -1: + end_idx = start_idx + next_marker + else: + end_idx = len(buffer) + + content = buffer[start_idx:end_idx] + if content.startswith('\n'): + content = content[1:] + + if not content: + error_msg = "未检测到需要追加的内容,请严格按照<<>>...<<>>格式输出。" + debug_log(error_msg) + result["error"] = error_msg + result["forced"] = forced + result["tool_content"] = json.dumps({ + "success": False, + "path": path, + "error": error_msg + }, ensure_ascii=False) + if display_id: + sender('update_action', { + 'id': display_id, + 'status': 'failed', + 'preparing_id': tool_call_id, + 'message': error_msg + }) + pending_append = None + return result, pending_append, append_probe_buffer + + assistant_message_lines = [] + if start_marker: + assistant_message_lines.append(start_marker) + assistant_message_lines.append(content) + if not forced and end_marker: + assistant_message_lines.append(end_marker) + assistant_message_text = "\n".join(assistant_message_lines) + result["assistant_content"] = assistant_message_text + assistant_metadata = { + "append_payload": { + "path": path, + "tool_call_id": tool_call_id, + "forced": forced, + "has_end_marker": not forced + } + } + result["assistant_metadata"] = assistant_metadata + + write_result = web_terminal.file_manager.append_file(path, content) + if write_result.get("success"): + bytes_written = len(content.encode('utf-8')) + line_count = content.count('\n') + if content and not content.endswith('\n'): + line_count += 1 + + summary = f"已向 {path} 追加 {line_count} 行({bytes_written} 字节)" + if forced: + summary += "。未检测到 <<>> 标记,系统已在流结束处完成写入。如内容未完成,请重新调用 append_to_file 并按标准格式补充;如已完成,可继续后续步骤。" + + result.update({ + "success": True, + "summary": summary, + "summary_message": summary, + "forced": forced, + "lines": line_count, + "bytes": bytes_written, + "appended_content": content, + "tool_content": json.dumps({ + "success": True, + "path": path, + "lines": line_count, + "bytes": bytes_written, + "forced": forced, + "message": summary, + "finish_reason": finish_reason + }, ensure_ascii=False) + }) + + assistant_meta_payload = result["assistant_metadata"]["append_payload"] + assistant_meta_payload["lines"] = line_count + assistant_meta_payload["bytes"] = bytes_written + assistant_meta_payload["success"] = True + + summary_payload = { + "success": True, + "path": path, + "lines": line_count, + "bytes": bytes_written, + "forced": forced, + "message": summary + } + + if display_id: + sender('update_action', { + 'id': display_id, + 'status': 'completed', + 'result': summary_payload, + 'preparing_id': tool_call_id, + 'message': summary + }) + + debug_log(f"追加写入完成: {summary}") + else: + error_msg = write_result.get("error", "追加写入失败") + result.update({ + "error": error_msg, + "summary_message": error_msg, + "forced": forced, + "appended_content": content, + "tool_content": json.dumps({ + "success": False, + "path": path, + "error": error_msg, + "finish_reason": finish_reason + }, ensure_ascii=False) + }) + debug_log(f"追加写入失败: {error_msg}") + + if result["assistant_metadata"]: + assistant_meta_payload = result["assistant_metadata"]["append_payload"] + assistant_meta_payload["lines"] = content.count('\n') + (0 if content.endswith('\n') or not content else 1) + assistant_meta_payload["bytes"] = len(content.encode('utf-8')) + assistant_meta_payload["success"] = False + + failure_payload = { + "success": False, + "path": path, + "error": error_msg, + "forced": forced + } + + if display_id: + sender('update_action', { + 'id': display_id, + 'status': 'completed', + 'result': failure_payload, + 'preparing_id': tool_call_id, + 'message': error_msg + }) + + pending_append = None + append_probe_buffer = "" + if hasattr(web_terminal, "pending_append_request"): + web_terminal.pending_append_request = None + return result, pending_append, append_probe_buffer + + + +async def finalize_pending_modify(*, pending_modify, modify_probe_buffer: str, response_text: str, stream_completed: bool, finish_reason: str = None, web_terminal, sender, debug_log): + """在流式输出结束后处理修改写入""" + + result = { + "handled": False, + "success": False, + "path": None, + "tool_call_id": None, + "display_id": None, + "total_blocks": 0, + "completed_blocks": [], + "failed_blocks": [], + "forced": False, + "details": [], + "error": None, + "assistant_content": response_text, + "assistant_metadata": None, + "tool_content": None, + "summary_message": None, + "finish_reason": finish_reason + } + + if not pending_modify: + return result, pending_modify, modify_probe_buffer + + state = pending_modify + path = state.get("path") + tool_call_id = state.get("tool_call_id") + display_id = state.get("display_id") + start_marker = state.get("start_marker") + end_marker = state.get("end_marker") + buffer = state.get("buffer", "") + raw_buffer = state.get("raw_buffer", "") + end_index = state.get("end_index") + + result.update({ + "handled": True, + "path": path, + "tool_call_id": tool_call_id, + "display_id": display_id + }) + + if not state.get("start_seen"): + error_msg = "未检测到格式正确的 <<>> 标记。" + debug_log(error_msg) + result["error"] = error_msg + result["summary_message"] = error_msg + result["tool_content"] = json.dumps({ + "success": False, + "path": path, + "error": error_msg, + "finish_reason": finish_reason + }, ensure_ascii=False) + if display_id: + sender('update_action', { + 'id': display_id, + 'status': 'failed', + 'preparing_id': tool_call_id, + 'message': error_msg + }) + if hasattr(web_terminal, "pending_modify_request"): + web_terminal.pending_modify_request = None + pending_modify = None + modify_probe_buffer = "" + return result, pending_modify, modify_probe_buffer + + forced = end_index is None + apply_text = buffer if forced else buffer[:end_index] + raw_content = raw_buffer if forced else raw_buffer[:len(start_marker) + end_index + len(end_marker)] + if raw_content: + result["assistant_content"] = raw_content + + blocks_info = [] + block_reports = {} + detected_indices = set() + block_pattern = re.compile(r"\[replace:(\d+)\](.*?)\[/replace\]", re.DOTALL) + structure_warnings: List[str] = [] + structure_detail_entries: List[Dict] = [] + + def record_structure_warning(message: str, hint: Optional[str] = None): + """记录结构性缺陷,便于给出更具体的反馈。""" + if message in structure_warnings: + return + structure_warnings.append(message) + structure_detail_entries.append({ + "index": 0, + "status": "failed", + "reason": message, + "removed_lines": 0, + "added_lines": 0, + "hint": hint or "请严格按照模板输出:[replace:n] + <>/<> + [/replace],并使用 <<>> 收尾。" + }) + + def extract_segment(body: str, tag: str): + marker = f"<<{tag}>>" + end_tag = "<>" + start_pos = body.find(marker) + if start_pos == -1: + return None, f"缺少 {marker}" + start_pos += len(marker) + if body[start_pos:start_pos+2] == "\r\n": + start_pos += 2 + elif body[start_pos:start_pos+1] == "\n": + start_pos += 1 + end_pos = body.find(end_tag, start_pos) + if end_pos == -1: + return None, f"缺少 {end_tag}" + segment = body[start_pos:end_pos] + return segment, None + + for match in block_pattern.finditer(apply_text): + try: + index = int(match.group(1)) + except ValueError: + continue + body = match.group(2) + if index in detected_indices: + continue + detected_indices.add(index) + block_reports[index] = { + "index": index, + "status": "pending", + "reason": None, + "removed_lines": 0, + "added_lines": 0, + "hint": None + } + old_content, old_error = extract_segment(body, "OLD") + new_content, new_error = extract_segment(body, "NEW") + if old_error or new_error: + reason = old_error or new_error + block_reports[index]["status"] = "failed" + block_reports[index]["reason"] = reason + blocks_info.append({ + "index": index, + "old": old_content, + "new": new_content, + "error": old_error or new_error + }) + + if not blocks_info: + has_replace_start = bool(re.search(r"\[replace:\s*\d+\]", apply_text)) + has_replace_end = "[/replace]" in apply_text + has_old_tag = "<>" in apply_text + has_new_tag = "<>" in apply_text + + if has_replace_start and not has_replace_end: + record_structure_warning("检测到 [replace:n] 标记但缺少对应的 [/replace] 结束标记。") + if has_replace_end and not has_replace_start: + record_structure_warning("检测到 [/replace] 结束标记但缺少对应的 [replace:n] 起始标记。") + + old_tags = len(re.findall(r"<>", apply_text)) + completed_old_tags = len(re.findall(r"<>[\s\S]*?<>", apply_text)) + if old_tags and completed_old_tags < old_tags: + record_structure_warning("检测到 <> 段落但未看到对应的 <> 结束标记。") + + new_tags = len(re.findall(r"<>", apply_text)) + completed_new_tags = len(re.findall(r"<>[\s\S]*?<>", apply_text)) + if new_tags and completed_new_tags < new_tags: + record_structure_warning("检测到 <> 段落但未看到对应的 <> 结束标记。") + + if (has_replace_start or has_replace_end or has_old_tag or has_new_tag) and not structure_warnings: + record_structure_warning("检测到部分补丁标记,但整体结构不完整,请严格按照模板填写所有标记。") + + total_blocks = len(blocks_info) + result["total_blocks"] = total_blocks + if forced: + debug_log("未检测到 <<>>,将在流结束处执行已识别的修改块。") + result["forced"] = True + + blocks_to_apply = [ + {"index": block["index"], "old": block["old"], "new": block["new"]} + for block in blocks_info + if block["error"] is None and block["old"] is not None and block["new"] is not None + ] + + # 记录格式残缺的块 + for block in blocks_info: + if block["error"]: + idx = block["index"] + block_reports[idx]["status"] = "failed" + block_reports[idx]["reason"] = block["error"] + block_reports[idx]["hint"] = "请检查补丁块的 OLD/NEW 标记是否完整,必要时复用 terminal_snapshot 或终端命令重新调整。" + + apply_result = {} + if blocks_to_apply: + apply_result = web_terminal.file_manager.apply_modify_blocks(path, blocks_to_apply) + else: + apply_result = {"success": False, "completed": [], "failed": [], "results": [], "write_performed": False, "error": None} + + block_result_map = {item["index"]: item for item in apply_result.get("results", [])} + + for block in blocks_info: + idx = block["index"] + report = block_reports.get(idx) + if report is None: + continue + if report["status"] == "failed": + continue + block_apply = block_result_map.get(idx) + if not block_apply: + report["status"] = "failed" + report["reason"] = "未执行,可能未找到匹配原文" + report["hint"] = report.get("hint") or "请确认 OLD 文本与文件内容完全一致;若多次失败,可改用终端命令/Python 进行精准替换。" + continue + status = block_apply.get("status") + report["removed_lines"] = block_apply.get("removed_lines", 0) + report["added_lines"] = block_apply.get("added_lines", 0) + if block_apply.get("hint"): + report["hint"] = block_apply.get("hint") + if status == "success": + report["status"] = "completed" + elif status == "not_found": + report["status"] = "failed" + report["reason"] = block_apply.get("reason") or "未找到匹配的原文" + if not report.get("hint"): + report["hint"] = "请使用 terminal_snapshot/grep -n 校验原文,或在说明后改用 run_command/python 精确替换。" + else: + report["status"] = "failed" + report["reason"] = block_apply.get("reason") or "替换失败" + if not report.get("hint"): + report["hint"] = block_apply.get("hint") or "若多次尝试仍失败,可考虑利用终端命令或 Python 小脚本完成此次修改。" + + completed_blocks = sorted([idx for idx, rep in block_reports.items() if rep["status"] == "completed"]) + failed_blocks = sorted([idx for idx, rep in block_reports.items() if rep["status"] != "completed"]) + + result["completed_blocks"] = completed_blocks + result["failed_blocks"] = failed_blocks + details = sorted(block_reports.values(), key=lambda x: x["index"]) + if structure_detail_entries: + details = structure_detail_entries + details + result["details"] = details + + summary_parts = [] + if total_blocks == 0: + summary_parts.append("未检测到有效的修改块,未执行任何修改。") + summary_parts.extend(structure_warnings) + else: + if not completed_blocks and failed_blocks: + summary_parts.append(f"共检测到 {total_blocks} 个修改块,全部未执行。") + elif completed_blocks and not failed_blocks: + summary_parts.append(f"共 {total_blocks} 个修改块全部完成。") + else: + summary_parts.append( + f"共检测到 {total_blocks} 个修改块,其中成功 {len(completed_blocks)} 个,失败 {len(failed_blocks)} 个。" + ) + if forced: + summary_parts.append("未检测到 <<>> 标记,系统已在流结束处执行补丁。") + if apply_result.get("error"): + summary_parts.append(apply_result["error"]) + + matching_note = "提示:补丁匹配基于完整文本,包含注释和空白符,请确保 <<>> 段落与文件内容逐字一致。如果修改成功,请忽略,如果失败,请明确原文后再次尝试。" + summary_parts.append(matching_note) + summary_message = " ".join(summary_parts).strip() + result["summary_message"] = summary_message + result["success"] = bool(completed_blocks) and not failed_blocks and apply_result.get("error") is None + + tool_payload = { + "success": result["success"], + "path": path, + "total_blocks": total_blocks, + "completed": completed_blocks, + "failed": [ + { + "index": rep["index"], + "reason": rep.get("reason"), + "hint": rep.get("hint") + } + for rep in result["details"] if rep["status"] != "completed" + ], + "forced": forced, + "message": summary_message, + "finish_reason": finish_reason, + "details": result["details"] + } + if apply_result.get("error"): + tool_payload["error"] = apply_result["error"] + + result["tool_content"] = json.dumps(tool_payload, ensure_ascii=False) + result["assistant_metadata"] = { + "modify_payload": { + "path": path, + "total_blocks": total_blocks, + "completed": completed_blocks, + "failed": failed_blocks, + "forced": forced, + "details": result["details"] + } + } + + if display_id: + sender('update_action', { + 'id': display_id, + 'status': 'completed' if result["success"] else 'failed', + 'result': tool_payload, + 'preparing_id': tool_call_id, + 'message': summary_message + }) + + pending_modify = None + modify_probe_buffer = "" + if hasattr(web_terminal, "pending_modify_request"): + web_terminal.pending_modify_request = None + return result, pending_modify, modify_probe_buffer + diff --git a/server/chat_flow_task_main.py b/server/chat_flow_task_main.py index 52964f5..9671445 100644 --- a/server/chat_flow_task_main.py +++ b/server/chat_flow_task_main.py @@ -123,6 +123,9 @@ from .chat_flow_runtime import ( detect_malformed_tool_call, ) +from .chat_flow_pending_writes import finalize_pending_append, finalize_pending_modify +from .chat_flow_task_support import process_sub_agent_updates, wait_retry_delay, cancel_pending_tools + async def handle_task_with_sender(terminal: WebTerminal, workspace: UserWorkspace, message, images, sender, client_sid, username: str, videos=None): """处理任务并发送消息 - 集成token统计版本""" web_terminal = terminal @@ -251,607 +254,6 @@ async def handle_task_with_sender(terminal: WebTerminal, workspace: UserWorkspac pending_modify = None # {"path": str, "tool_call_id": str, "buffer": str, ...} modify_probe_buffer = "" - async def finalize_pending_append(response_text: str, stream_completed: bool, finish_reason: str = None) -> Dict: - """在流式输出结束后处理追加写入""" - nonlocal pending_append, append_probe_buffer - - result = { - "handled": False, - "success": False, - "summary": None, - "summary_message": None, - "tool_content": None, - "tool_call_id": None, - "path": None, - "forced": False, - "error": None, - "assistant_content": response_text, - "lines": 0, - "bytes": 0, - "finish_reason": finish_reason, - "appended_content": "", - "assistant_metadata": None - } - - if not pending_append: - return result - - state = pending_append - path = state.get("path") - tool_call_id = state.get("tool_call_id") - buffer = state.get("buffer", "") - start_marker = state.get("start_marker") - end_marker = state.get("end_marker") - start_idx = state.get("content_start") - end_idx = state.get("end_index") - - display_id = state.get("display_id") - - result.update({ - "handled": True, - "path": path, - "tool_call_id": tool_call_id, - "display_id": display_id - }) - - if path is None or tool_call_id is None: - error_msg = "append_to_file 状态不完整,缺少路径或ID。" - debug_log(error_msg) - result["error"] = error_msg - result["summary_message"] = error_msg - result["tool_content"] = json.dumps({ - "success": False, - "error": error_msg - }, ensure_ascii=False) - if display_id: - sender('update_action', { - 'id': display_id, - 'status': 'failed', - 'preparing_id': tool_call_id, - 'message': error_msg - }) - pending_append = None - return result - - if start_idx is None: - error_msg = f"未检测到格式正确的开始标识 {start_marker}。" - debug_log(error_msg) - result["error"] = error_msg - result["summary_message"] = error_msg - result["tool_content"] = json.dumps({ - "success": False, - "path": path, - "error": error_msg - }, ensure_ascii=False) - if display_id: - sender('update_action', { - 'id': display_id, - 'status': 'failed', - 'preparing_id': tool_call_id, - 'message': error_msg - }) - pending_append = None - return result - - forced = False - if end_idx is None: - forced = True - # 查找下一个<<<,否则使用整个缓冲结尾 - remaining = buffer[start_idx:] - next_marker = remaining.find("<<<", len(end_marker)) - if next_marker != -1: - end_idx = start_idx + next_marker - else: - end_idx = len(buffer) - - content = buffer[start_idx:end_idx] - if content.startswith('\n'): - content = content[1:] - - if not content: - error_msg = "未检测到需要追加的内容,请严格按照<<>>...<<>>格式输出。" - debug_log(error_msg) - result["error"] = error_msg - result["forced"] = forced - result["tool_content"] = json.dumps({ - "success": False, - "path": path, - "error": error_msg - }, ensure_ascii=False) - if display_id: - sender('update_action', { - 'id': display_id, - 'status': 'failed', - 'preparing_id': tool_call_id, - 'message': error_msg - }) - pending_append = None - return result - - assistant_message_lines = [] - if start_marker: - assistant_message_lines.append(start_marker) - assistant_message_lines.append(content) - if not forced and end_marker: - assistant_message_lines.append(end_marker) - assistant_message_text = "\n".join(assistant_message_lines) - result["assistant_content"] = assistant_message_text - assistant_metadata = { - "append_payload": { - "path": path, - "tool_call_id": tool_call_id, - "forced": forced, - "has_end_marker": not forced - } - } - result["assistant_metadata"] = assistant_metadata - - write_result = web_terminal.file_manager.append_file(path, content) - if write_result.get("success"): - bytes_written = len(content.encode('utf-8')) - line_count = content.count('\n') - if content and not content.endswith('\n'): - line_count += 1 - - summary = f"已向 {path} 追加 {line_count} 行({bytes_written} 字节)" - if forced: - summary += "。未检测到 <<>> 标记,系统已在流结束处完成写入。如内容未完成,请重新调用 append_to_file 并按标准格式补充;如已完成,可继续后续步骤。" - - result.update({ - "success": True, - "summary": summary, - "summary_message": summary, - "forced": forced, - "lines": line_count, - "bytes": bytes_written, - "appended_content": content, - "tool_content": json.dumps({ - "success": True, - "path": path, - "lines": line_count, - "bytes": bytes_written, - "forced": forced, - "message": summary, - "finish_reason": finish_reason - }, ensure_ascii=False) - }) - - assistant_meta_payload = result["assistant_metadata"]["append_payload"] - assistant_meta_payload["lines"] = line_count - assistant_meta_payload["bytes"] = bytes_written - assistant_meta_payload["success"] = True - - summary_payload = { - "success": True, - "path": path, - "lines": line_count, - "bytes": bytes_written, - "forced": forced, - "message": summary - } - - if display_id: - sender('update_action', { - 'id': display_id, - 'status': 'completed', - 'result': summary_payload, - 'preparing_id': tool_call_id, - 'message': summary - }) - - debug_log(f"追加写入完成: {summary}") - else: - error_msg = write_result.get("error", "追加写入失败") - result.update({ - "error": error_msg, - "summary_message": error_msg, - "forced": forced, - "appended_content": content, - "tool_content": json.dumps({ - "success": False, - "path": path, - "error": error_msg, - "finish_reason": finish_reason - }, ensure_ascii=False) - }) - debug_log(f"追加写入失败: {error_msg}") - - if result["assistant_metadata"]: - assistant_meta_payload = result["assistant_metadata"]["append_payload"] - assistant_meta_payload["lines"] = content.count('\n') + (0 if content.endswith('\n') or not content else 1) - assistant_meta_payload["bytes"] = len(content.encode('utf-8')) - assistant_meta_payload["success"] = False - - failure_payload = { - "success": False, - "path": path, - "error": error_msg, - "forced": forced - } - - if display_id: - sender('update_action', { - 'id': display_id, - 'status': 'completed', - 'result': failure_payload, - 'preparing_id': tool_call_id, - 'message': error_msg - }) - - pending_append = None - append_probe_buffer = "" - if hasattr(web_terminal, "pending_append_request"): - web_terminal.pending_append_request = None - return result - - async def finalize_pending_modify(response_text: str, stream_completed: bool, finish_reason: str = None) -> Dict: - """在流式输出结束后处理修改写入""" - nonlocal pending_modify, modify_probe_buffer - - result = { - "handled": False, - "success": False, - "path": None, - "tool_call_id": None, - "display_id": None, - "total_blocks": 0, - "completed_blocks": [], - "failed_blocks": [], - "forced": False, - "details": [], - "error": None, - "assistant_content": response_text, - "assistant_metadata": None, - "tool_content": None, - "summary_message": None, - "finish_reason": finish_reason - } - - if not pending_modify: - return result - - state = pending_modify - path = state.get("path") - tool_call_id = state.get("tool_call_id") - display_id = state.get("display_id") - start_marker = state.get("start_marker") - end_marker = state.get("end_marker") - buffer = state.get("buffer", "") - raw_buffer = state.get("raw_buffer", "") - end_index = state.get("end_index") - - result.update({ - "handled": True, - "path": path, - "tool_call_id": tool_call_id, - "display_id": display_id - }) - - if not state.get("start_seen"): - error_msg = "未检测到格式正确的 <<>> 标记。" - debug_log(error_msg) - result["error"] = error_msg - result["summary_message"] = error_msg - result["tool_content"] = json.dumps({ - "success": False, - "path": path, - "error": error_msg, - "finish_reason": finish_reason - }, ensure_ascii=False) - if display_id: - sender('update_action', { - 'id': display_id, - 'status': 'failed', - 'preparing_id': tool_call_id, - 'message': error_msg - }) - if hasattr(web_terminal, "pending_modify_request"): - web_terminal.pending_modify_request = None - pending_modify = None - modify_probe_buffer = "" - return result - - forced = end_index is None - apply_text = buffer if forced else buffer[:end_index] - raw_content = raw_buffer if forced else raw_buffer[:len(start_marker) + end_index + len(end_marker)] - if raw_content: - result["assistant_content"] = raw_content - - blocks_info = [] - block_reports = {} - detected_indices = set() - block_pattern = re.compile(r"\[replace:(\d+)\](.*?)\[/replace\]", re.DOTALL) - structure_warnings: List[str] = [] - structure_detail_entries: List[Dict] = [] - - def record_structure_warning(message: str, hint: Optional[str] = None): - """记录结构性缺陷,便于给出更具体的反馈。""" - if message in structure_warnings: - return - structure_warnings.append(message) - structure_detail_entries.append({ - "index": 0, - "status": "failed", - "reason": message, - "removed_lines": 0, - "added_lines": 0, - "hint": hint or "请严格按照模板输出:[replace:n] + <>/<> + [/replace],并使用 <<>> 收尾。" - }) - - def extract_segment(body: str, tag: str): - marker = f"<<{tag}>>" - end_tag = "<>" - start_pos = body.find(marker) - if start_pos == -1: - return None, f"缺少 {marker}" - start_pos += len(marker) - if body[start_pos:start_pos+2] == "\r\n": - start_pos += 2 - elif body[start_pos:start_pos+1] == "\n": - start_pos += 1 - end_pos = body.find(end_tag, start_pos) - if end_pos == -1: - return None, f"缺少 {end_tag}" - segment = body[start_pos:end_pos] - return segment, None - - for match in block_pattern.finditer(apply_text): - try: - index = int(match.group(1)) - except ValueError: - continue - body = match.group(2) - if index in detected_indices: - continue - detected_indices.add(index) - block_reports[index] = { - "index": index, - "status": "pending", - "reason": None, - "removed_lines": 0, - "added_lines": 0, - "hint": None - } - old_content, old_error = extract_segment(body, "OLD") - new_content, new_error = extract_segment(body, "NEW") - if old_error or new_error: - reason = old_error or new_error - block_reports[index]["status"] = "failed" - block_reports[index]["reason"] = reason - blocks_info.append({ - "index": index, - "old": old_content, - "new": new_content, - "error": old_error or new_error - }) - - if not blocks_info: - has_replace_start = bool(re.search(r"\[replace:\s*\d+\]", apply_text)) - has_replace_end = "[/replace]" in apply_text - has_old_tag = "<>" in apply_text - has_new_tag = "<>" in apply_text - - if has_replace_start and not has_replace_end: - record_structure_warning("检测到 [replace:n] 标记但缺少对应的 [/replace] 结束标记。") - if has_replace_end and not has_replace_start: - record_structure_warning("检测到 [/replace] 结束标记但缺少对应的 [replace:n] 起始标记。") - - old_tags = len(re.findall(r"<>", apply_text)) - completed_old_tags = len(re.findall(r"<>[\s\S]*?<>", apply_text)) - if old_tags and completed_old_tags < old_tags: - record_structure_warning("检测到 <> 段落但未看到对应的 <> 结束标记。") - - new_tags = len(re.findall(r"<>", apply_text)) - completed_new_tags = len(re.findall(r"<>[\s\S]*?<>", apply_text)) - if new_tags and completed_new_tags < new_tags: - record_structure_warning("检测到 <> 段落但未看到对应的 <> 结束标记。") - - if (has_replace_start or has_replace_end or has_old_tag or has_new_tag) and not structure_warnings: - record_structure_warning("检测到部分补丁标记,但整体结构不完整,请严格按照模板填写所有标记。") - - total_blocks = len(blocks_info) - result["total_blocks"] = total_blocks - if forced: - debug_log("未检测到 <<>>,将在流结束处执行已识别的修改块。") - result["forced"] = True - - blocks_to_apply = [ - {"index": block["index"], "old": block["old"], "new": block["new"]} - for block in blocks_info - if block["error"] is None and block["old"] is not None and block["new"] is not None - ] - - # 记录格式残缺的块 - for block in blocks_info: - if block["error"]: - idx = block["index"] - block_reports[idx]["status"] = "failed" - block_reports[idx]["reason"] = block["error"] - block_reports[idx]["hint"] = "请检查补丁块的 OLD/NEW 标记是否完整,必要时复用 terminal_snapshot 或终端命令重新调整。" - - apply_result = {} - if blocks_to_apply: - apply_result = web_terminal.file_manager.apply_modify_blocks(path, blocks_to_apply) - else: - apply_result = {"success": False, "completed": [], "failed": [], "results": [], "write_performed": False, "error": None} - - block_result_map = {item["index"]: item for item in apply_result.get("results", [])} - - for block in blocks_info: - idx = block["index"] - report = block_reports.get(idx) - if report is None: - continue - if report["status"] == "failed": - continue - block_apply = block_result_map.get(idx) - if not block_apply: - report["status"] = "failed" - report["reason"] = "未执行,可能未找到匹配原文" - report["hint"] = report.get("hint") or "请确认 OLD 文本与文件内容完全一致;若多次失败,可改用终端命令/Python 进行精准替换。" - continue - status = block_apply.get("status") - report["removed_lines"] = block_apply.get("removed_lines", 0) - report["added_lines"] = block_apply.get("added_lines", 0) - if block_apply.get("hint"): - report["hint"] = block_apply.get("hint") - if status == "success": - report["status"] = "completed" - elif status == "not_found": - report["status"] = "failed" - report["reason"] = block_apply.get("reason") or "未找到匹配的原文" - if not report.get("hint"): - report["hint"] = "请使用 terminal_snapshot/grep -n 校验原文,或在说明后改用 run_command/python 精确替换。" - else: - report["status"] = "failed" - report["reason"] = block_apply.get("reason") or "替换失败" - if not report.get("hint"): - report["hint"] = block_apply.get("hint") or "若多次尝试仍失败,可考虑利用终端命令或 Python 小脚本完成此次修改。" - - completed_blocks = sorted([idx for idx, rep in block_reports.items() if rep["status"] == "completed"]) - failed_blocks = sorted([idx for idx, rep in block_reports.items() if rep["status"] != "completed"]) - - result["completed_blocks"] = completed_blocks - result["failed_blocks"] = failed_blocks - details = sorted(block_reports.values(), key=lambda x: x["index"]) - if structure_detail_entries: - details = structure_detail_entries + details - result["details"] = details - - summary_parts = [] - if total_blocks == 0: - summary_parts.append("未检测到有效的修改块,未执行任何修改。") - summary_parts.extend(structure_warnings) - else: - if not completed_blocks and failed_blocks: - summary_parts.append(f"共检测到 {total_blocks} 个修改块,全部未执行。") - elif completed_blocks and not failed_blocks: - summary_parts.append(f"共 {total_blocks} 个修改块全部完成。") - else: - summary_parts.append( - f"共检测到 {total_blocks} 个修改块,其中成功 {len(completed_blocks)} 个,失败 {len(failed_blocks)} 个。" - ) - if forced: - summary_parts.append("未检测到 <<>> 标记,系统已在流结束处执行补丁。") - if apply_result.get("error"): - summary_parts.append(apply_result["error"]) - - matching_note = "提示:补丁匹配基于完整文本,包含注释和空白符,请确保 <<>> 段落与文件内容逐字一致。如果修改成功,请忽略,如果失败,请明确原文后再次尝试。" - summary_parts.append(matching_note) - summary_message = " ".join(summary_parts).strip() - result["summary_message"] = summary_message - result["success"] = bool(completed_blocks) and not failed_blocks and apply_result.get("error") is None - - tool_payload = { - "success": result["success"], - "path": path, - "total_blocks": total_blocks, - "completed": completed_blocks, - "failed": [ - { - "index": rep["index"], - "reason": rep.get("reason"), - "hint": rep.get("hint") - } - for rep in result["details"] if rep["status"] != "completed" - ], - "forced": forced, - "message": summary_message, - "finish_reason": finish_reason, - "details": result["details"] - } - if apply_result.get("error"): - tool_payload["error"] = apply_result["error"] - - result["tool_content"] = json.dumps(tool_payload, ensure_ascii=False) - result["assistant_metadata"] = { - "modify_payload": { - "path": path, - "total_blocks": total_blocks, - "completed": completed_blocks, - "failed": failed_blocks, - "forced": forced, - "details": result["details"] - } - } - - if display_id: - sender('update_action', { - 'id': display_id, - 'status': 'completed' if result["success"] else 'failed', - 'result': tool_payload, - 'preparing_id': tool_call_id, - 'message': summary_message - }) - - pending_modify = None - modify_probe_buffer = "" - if hasattr(web_terminal, "pending_modify_request"): - web_terminal.pending_modify_request = None - return result - - async def process_sub_agent_updates( - messages: List[Dict], - inline: bool = False, - after_tool_call_id: Optional[str] = None - ): - """轮询子智能体任务并通知前端,并把结果插入当前对话上下文。""" - manager = getattr(web_terminal, "sub_agent_manager", None) - if not manager: - return - try: - updates = manager.poll_updates() - debug_log(f"[SubAgent] poll inline={inline} updates={len(updates)}") - except Exception as exc: - debug_log(f"子智能体状态检查失败: {exc}") - return - for update in updates: - message = update.get("system_message") - if not message: - continue - task_id = update.get("task_id") - debug_log(f"[SubAgent] update task={task_id} inline={inline} msg={message}") - web_terminal._record_sub_agent_message(message, task_id, inline=inline) - debug_log(f"[SubAgent] recorded task={task_id}, 计算插入位置") - - insert_index = len(messages) - if after_tool_call_id: - for idx, msg in enumerate(messages): - if msg.get("role") == "tool" and msg.get("tool_call_id") == after_tool_call_id: - insert_index = idx + 1 - break - - messages.insert(insert_index, { - "role": "system", - "content": message, - "metadata": {"sub_agent_notice": True, "inline": inline, "task_id": task_id} - }) - debug_log(f"[SubAgent] 插入系统消息位置: {insert_index}") - sender('system_message', { - 'content': message, - 'inline': inline - }) - maybe_mark_failure_from_message(web_terminal, message) - - async def _wait_retry_delay(delay_seconds: int) -> bool: - """等待重试间隔,同时检查是否收到停止请求。""" - if delay_seconds <= 0: - return False - deadline = time.time() + delay_seconds - while time.time() < deadline: - client_stop_info = get_stop_flag(client_sid, username) - if client_stop_info: - stop_requested = client_stop_info.get('stop', False) if isinstance(client_stop_info, dict) else client_stop_info - if stop_requested: - sender('task_stopped', { - 'message': '命令执行被用户取消', - 'reason': 'user_stop' - }) - clear_stop_flag(client_sid, username) - return True - await asyncio.sleep(0.2) - return False - iteration = 0 while max_iterations is None or iteration < max_iterations: current_iteration = iteration + 1 @@ -898,32 +300,6 @@ async def handle_task_with_sender(terminal: WebTerminal, workspace: UserWorkspac modify_result = {"handled": False} last_finish_reason = None - def _cancel_pending_tools(tool_calls_list): - """为尚未返回结果的工具生成取消结果,防止缺失 tool_call_id 造成后续 400。""" - if not tool_calls_list: - return - for tc in tool_calls_list: - tc_id = tc.get("id") - func_name = tc.get("function", {}).get("name") - sender('update_action', { - 'preparing_id': tc_id, - 'status': 'cancelled', - 'result': { - "success": False, - "status": "cancelled", - "message": "命令执行被用户取消", - "tool": func_name - } - }) - if tc_id: - messages.append({ - "role": "tool", - "tool_call_id": tc_id, - "name": func_name, - "content": "命令执行被用户取消", - "metadata": {"status": "cancelled"} - }) - thinking_expected = web_terminal.api_client.get_current_thinking_mode() debug_log(f"思考模式: {thinking_expected}") quota_allowed = True @@ -989,10 +365,10 @@ async def handle_task_with_sender(terminal: WebTerminal, workspace: UserWorkspac if stop_requested: debug_log(f"检测到停止请求,中断流处理") if pending_append: - append_result = await finalize_pending_append(full_response, False, finish_reason="user_stop") + append_result, pending_append, append_probe_buffer = await finalize_pending_append(pending_append=pending_append, append_probe_buffer=append_probe_buffer, response_text=full_response, stream_completed=False, finish_reason="user_stop", web_terminal=web_terminal, sender=sender, debug_log=debug_log) if pending_modify: - modify_result = await finalize_pending_modify(full_response, False, finish_reason="user_stop") - _cancel_pending_tools(tool_calls) + modify_result, pending_modify, modify_probe_buffer = await finalize_pending_modify(pending_modify=pending_modify, modify_probe_buffer=modify_probe_buffer, response_text=full_response, stream_completed=False, finish_reason="user_stop", web_terminal=web_terminal, sender=sender, debug_log=debug_log) + cancel_pending_tools(tool_calls_list=tool_calls, sender=sender, messages=messages) sender('task_stopped', { 'message': '命令执行被用户取消', 'reason': 'user_stop' @@ -1358,7 +734,7 @@ async def handle_task_with_sender(terminal: WebTerminal, workspace: UserWorkspac 'message': '命令执行被用户取消', 'reason': 'user_stop' }) - _cancel_pending_tools(tool_calls) + cancel_pending_tools(tool_calls_list=tool_calls, sender=sender, messages=messages) clear_stop_flag(client_sid, username) return @@ -1449,11 +825,11 @@ async def handle_task_with_sender(terminal: WebTerminal, workspace: UserWorkspac web_terminal.apply_model_profile(profile) except Exception as exc: debug_log(f"重试前更新模型配置失败: {exc}") - cancelled = await _wait_retry_delay(retry_delay_seconds) + cancelled = await wait_retry_delay(delay_seconds=retry_delay_seconds, client_sid=client_sid, username=username, sender=sender, get_stop_flag=get_stop_flag, clear_stop_flag=clear_stop_flag) if cancelled: return continue - _cancel_pending_tools(tool_calls) + cancel_pending_tools(tool_calls_list=tool_calls, sender=sender, messages=messages) return break @@ -1468,9 +844,9 @@ async def handle_task_with_sender(terminal: WebTerminal, workspace: UserWorkspac debug_log(f" 收集到的工具: {len(tool_calls)} 个") if not append_result["handled"] and pending_append: - append_result = await finalize_pending_append(full_response, True, finish_reason=last_finish_reason) + append_result, pending_append, append_probe_buffer = await finalize_pending_append(pending_append=pending_append, append_probe_buffer=append_probe_buffer, response_text=full_response, stream_completed=True, finish_reason=last_finish_reason, web_terminal=web_terminal, sender=sender, debug_log=debug_log) if not modify_result["handled"] and pending_modify: - modify_result = await finalize_pending_modify(full_response, True, finish_reason=last_finish_reason) + modify_result, pending_modify, modify_probe_buffer = await finalize_pending_modify(pending_modify=pending_modify, modify_probe_buffer=modify_probe_buffer, response_text=full_response, stream_completed=True, finish_reason=last_finish_reason, web_terminal=web_terminal, sender=sender, debug_log=debug_log) # 结束未完成的流 if in_thinking and not thinking_ended: @@ -2159,7 +1535,7 @@ async def handle_task_with_sender(terminal: WebTerminal, workspace: UserWorkspac }) if function_name not in {'write_file', 'edit_file'}: - await process_sub_agent_updates(messages, inline=True, after_tool_call_id=tool_call_id) + await process_sub_agent_updates(messages=messages, inline=True, after_tool_call_id=tool_call_id, web_terminal=web_terminal, sender=sender, debug_log=debug_log, maybe_mark_failure_from_message=maybe_mark_failure_from_message) await asyncio.sleep(0.2) diff --git a/server/chat_flow_task_support.py b/server/chat_flow_task_support.py new file mode 100644 index 0000000..d30f261 --- /dev/null +++ b/server/chat_flow_task_support.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Dict, List, Optional + + +async def process_sub_agent_updates(*, messages: List[Dict], inline: bool = False, after_tool_call_id: Optional[str] = None, web_terminal, sender, debug_log, maybe_mark_failure_from_message): + """轮询子智能体任务并通知前端,并把结果插入当前对话上下文。""" + manager = getattr(web_terminal, "sub_agent_manager", None) + if not manager: + return + try: + updates = manager.poll_updates() + debug_log(f"[SubAgent] poll inline={inline} updates={len(updates)}") + except Exception as exc: + debug_log(f"子智能体状态检查失败: {exc}") + return + for update in updates: + message = update.get("system_message") + if not message: + continue + task_id = update.get("task_id") + debug_log(f"[SubAgent] update task={task_id} inline={inline} msg={message}") + web_terminal._record_sub_agent_message(message, task_id, inline=inline) + debug_log(f"[SubAgent] recorded task={task_id}, 计算插入位置") + + insert_index = len(messages) + if after_tool_call_id: + for idx, msg in enumerate(messages): + if msg.get("role") == "tool" and msg.get("tool_call_id") == after_tool_call_id: + insert_index = idx + 1 + break + + messages.insert(insert_index, { + "role": "system", + "content": message, + "metadata": {"sub_agent_notice": True, "inline": inline, "task_id": task_id} + }) + debug_log(f"[SubAgent] 插入系统消息位置: {insert_index}") + sender('system_message', { + 'content': message, + 'inline': inline + }) + maybe_mark_failure_from_message(web_terminal, message) + + + +async def wait_retry_delay(*, delay_seconds: int, client_sid: str, username: str, sender, get_stop_flag, clear_stop_flag) -> bool: + """等待重试间隔,同时检查是否收到停止请求。""" + if delay_seconds <= 0: + return False + deadline = time.time() + delay_seconds + while time.time() < deadline: + client_stop_info = get_stop_flag(client_sid, username) + if client_stop_info: + stop_requested = client_stop_info.get('stop', False) if isinstance(client_stop_info, dict) else client_stop_info + if stop_requested: + sender('task_stopped', { + 'message': '命令执行被用户取消', + 'reason': 'user_stop' + }) + clear_stop_flag(client_sid, username) + return True + await asyncio.sleep(0.2) + return False + + + +def cancel_pending_tools(*, tool_calls_list, sender, messages): + """为尚未返回结果的工具生成取消结果,防止缺失 tool_call_id 造成后续 400。""" + if not tool_calls_list: + return + for tc in tool_calls_list: + tc_id = tc.get("id") + func_name = tc.get("function", {}).get("name") + sender('update_action', { + 'preparing_id': tc_id, + 'status': 'cancelled', + 'result': { + "success": False, + "status": "cancelled", + "message": "命令执行被用户取消", + "tool": func_name + } + }) + if tc_id: + messages.append({ + "role": "tool", + "tool_call_id": tc_id, + "name": func_name, + "content": "命令执行被用户取消", + "metadata": {"status": "cancelled"} + })