fix: reset stop flag after cancellations

This commit is contained in:
JOJO 2025-11-20 17:44:34 +08:00
parent aee18837e4
commit fd00fda96e

View File

@ -1249,21 +1249,18 @@ def handle_stop_task():
"""处理停止任务请求"""
print(f"[停止] 收到停止请求: {request.sid}")
# 检查是否有正在运行的任务
if request.sid in stop_flags and isinstance(stop_flags[request.sid], dict):
# 获取任务引用并取消
task_info = stop_flags[request.sid]
if 'task' in task_info and not task_info['task'].done():
debug_log(f"正在取消任务: {request.sid}")
task_info['task'].cancel()
task_info = stop_flags.get(request.sid)
if not isinstance(task_info, dict):
task_info = {'stop': False, 'task': None, 'terminal': None}
stop_flags[request.sid] = task_info
# 设置停止标志
task_info['stop'] = True
if task_info.get('terminal'):
reset_system_state(task_info['terminal'])
else:
# 如果没有任务引用,使用旧的布尔标志
stop_flags[request.sid] = True
if task_info.get('task') and not task_info['task'].done():
debug_log(f"正在取消任务: {request.sid}")
task_info['task'].cancel()
task_info['stop'] = True
if task_info.get('terminal'):
reset_system_state(task_info['terminal'])
emit('stop_requested', {
'message': '停止请求已接收,正在取消任务...'
@ -1864,12 +1861,13 @@ def process_message_task(terminal: WebTerminal, message: str, sender, client_sid
# 创建可取消的任务
task = loop.create_task(handle_task_with_sender(terminal, message, sender, client_sid))
# 存储任务引用,以便取消
if client_sid not in stop_flags:
stop_flags[client_sid] = {'stop': False, 'task': task, 'terminal': terminal}
else:
stop_flags[client_sid]['task'] = task
stop_flags[client_sid]['terminal'] = terminal
entry = stop_flags.get(client_sid)
if not isinstance(entry, dict):
entry = {'stop': False, 'task': None, 'terminal': None}
stop_flags[client_sid] = entry
entry['stop'] = False
entry['task'] = task
entry['terminal'] = terminal
try:
loop.run_until_complete(task)
@ -1907,8 +1905,7 @@ def process_message_task(terminal: WebTerminal, message: str, sender, client_sid
finally:
# 清理任务引用
if client_sid in stop_flags and isinstance(stop_flags[client_sid], dict):
stop_flags.pop(client_sid, None)
stop_flags.pop(client_sid, None)
def detect_malformed_tool_call(text):
"""检测文本中是否包含格式错误的工具调用"""