"""Flask blueprint exposing a streaming chat endpoint (NDJSON) with an
optional local-RAG tool call round-trip against an OpenAI-compatible API."""

import json
import logging
import re
import time
import uuid
from datetime import datetime, timezone

from flask import Blueprint, Response, jsonify, request, stream_with_context
from openai import OpenAI

from ..config import API_KEY, MODEL_NAME, OPENAI_BASE_URL, TOKEN_INTERVAL
from ..conversation_store import load as load_conversation
from ..conversation_store import make_conversation
from ..conversation_store import save as save_conversation
from ..prompt import SYSTEM_PROMPT
from ..rag import search_rag

logger = logging.getLogger(__name__)

bp = Blueprint("chat", __name__, url_prefix="/api")
client = OpenAI(api_key=API_KEY, base_url=OPENAI_BASE_URL)

# Tool schema advertised to the model; it maps onto the local search_rag helper.
# The description strings are sent to the model at runtime and must stay as-is.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "search_rag",
            "description": "在本地 QA 数据集中按 question 字段检索,返回最多 5 条匹配。",
            "parameters": {
                "type": "object",
                "required": ["query"],
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "用户问题或关键词",
                    }
                },
            },
        },
    }
]


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string.

    Uses timezone-aware ``datetime.now(timezone.utc)`` instead of the
    deprecated ``datetime.utcnow()`` (deprecated since Python 3.12).
    """
    return datetime.now(timezone.utc).isoformat()


def sanitize_legacy_markers(text: str) -> str:
    """Strip legacy search-status markers from stored assistant text.

    Older clients embedded ``[[SEARCH_START]]`` / ``[[SEARCH_DONE]]`` markers
    (and bare "搜索中..." / "搜索完成" lines) in assistant messages; these must
    not be re-fed to the model.
    """
    if not text:
        return text
    # Remove marker+label combinations first, then any bare markers.
    text = text.replace("[[SEARCH_START]]搜索中...", "")
    text = text.replace("[[SEARCH_DONE]]搜索完成", "")
    text = text.replace("[[SEARCH_START]]", "")
    text = text.replace("[[SEARCH_DONE]]", "")
    # Drop whole lines that consist only of the legacy status labels.
    text = re.sub(r"(?m)^[ \t]*搜索中\.\.\.[ \t]*\n?", "", text)
    text = re.sub(r"(?m)^[ \t]*搜索完成[ \t]*\n?", "", text)
    return text


def messages_for_model(messages: list[dict]) -> list[dict]:
    """Return a copy of *messages* with assistant content sanitized.

    Non-assistant messages (and assistant messages whose content is not a
    plain string, e.g. tool-call entries) are passed through unchanged.
    """
    sanitized = []
    for m in messages:
        mm = dict(m)
        if mm.get("role") == "assistant" and isinstance(mm.get("content"), str):
            mm["content"] = sanitize_legacy_markers(mm["content"])
        sanitized.append(mm)
    return sanitized


def emit_events_for_text(text: str):
    """Yield one NDJSON ``assistant_delta`` event per character of *text*.

    Sleeps TOKEN_INTERVAL between characters to pace the client-side
    typewriter effect.
    """
    for ch in text:
        yield json.dumps({"type": "assistant_delta", "delta": ch}, ensure_ascii=False) + "\n"
        time.sleep(TOKEN_INTERVAL)


@bp.post("/chat")
def chat():
    """Handle one chat turn and stream NDJSON events back to the client.

    Request JSON: ``{"message": str, "conversation_id": str | None}``.
    Streams events of type ``assistant_delta``, ``tool_call_start``,
    ``tool_call`` and ``tool_result``; the conversation id is echoed in the
    ``X-Conversation-Id`` response header. Returns 400 when ``message`` is
    missing or blank.
    """
    data = request.get_json(force=True)
    user_text = (data.get('message') or '').strip()
    if not user_text:
        return jsonify({"error": "message is required"}), 400

    cid = data.get('conversation_id')
    convo = load_conversation(cid) if cid else None
    if not convo:
        # New conversation: title is the first 20 chars of the user message.
        convo = make_conversation(user_text[:20])
        cid = convo['id']

    # Only messages belonging to the AI conversation are recorded.
    convo['messages'].append({"role": "user", "content": user_text})
    convo['updated_at'] = _utc_now_iso()
    save_conversation(convo)

    all_messages = [SYSTEM_PROMPT] + messages_for_model(convo['messages'])

    def generate():
        """Stream model output, executing at most ``max_tool_rounds`` tool calls."""
        new_messages = []
        try:
            model_messages = list(all_messages)
            max_tool_rounds = 2
            tool_round = 0
            while True:
                completion = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=model_messages,
                    temperature=0.6,
                    stream=True,
                    tools=TOOLS,
                    tool_choice="auto",
                )
                logger.debug(
                    "streaming start, cid=%s tool_round=%s len(messages)=%s",
                    cid, tool_round, len(model_messages),
                )

                # Accumulate streamed tool-call fragments keyed by their index.
                tool_calls_acc = {}
                segment_text = ""
                tool_start_emitted = False

                for chunk in completion:
                    delta = chunk.choices[0].delta
                    text = delta.content or ""
                    if text:
                        segment_text += text
                        yield from emit_events_for_text(text)
                        logger.debug("delta text len %s", len(text))
                    if delta.tool_calls:
                        if not tool_start_emitted:
                            tool_start_emitted = True
                            yield json.dumps({"type": "tool_call_start"}, ensure_ascii=False) + "\n"
                            logger.debug("tool_call_start")
                        for tc in delta.tool_calls:
                            idx = tc.index if tc.index is not None else 0
                            entry = tool_calls_acc.setdefault(
                                idx, {"id": tc.id, "name": None, "arguments": ""}
                            )
                            if tc.id:
                                entry["id"] = tc.id
                            if tc.function:
                                if tc.function.name:
                                    entry["name"] = tc.function.name
                                if tc.function.arguments:
                                    entry["arguments"] += tc.function.arguments
                            logger.debug(
                                "tool_call accumulating args len %s", len(entry["arguments"])
                            )

                # No tool call requested: this segment is the final answer.
                if not tool_calls_acc:
                    if segment_text:
                        new_messages.append({"role": "assistant", "content": segment_text})
                    break
                # Tool budget exhausted: keep whatever text we have and stop.
                if tool_round >= max_tool_rounds:
                    if segment_text:
                        new_messages.append({"role": "assistant", "content": segment_text})
                    break

                tool_round += 1
                # NOTE: only the first accumulated tool call is executed;
                # additional parallel tool calls (if any) are ignored.
                call = list(tool_calls_acc.values())[0]
                tool_name = call.get("name")
                tool_args = call.get("arguments") or ""
                tool_call_id = call.get("id") or str(uuid.uuid4())

                yield json.dumps(
                    {
                        "type": "tool_call",
                        "tool_call_id": tool_call_id,
                        "name": tool_name,
                        "arguments": tool_args,
                    },
                    ensure_ascii=False,
                ) + "\n"
                logger.debug("tool_call emit %s %s", tool_name, tool_args)

                # Arguments may arrive as malformed JSON; fall back to the raw string.
                query = ""
                try:
                    parsed = json.loads(tool_args or "{}")
                    query = parsed.get("query", "")
                except Exception:
                    query = tool_args

                matches = search_rag(query, limit=5)
                tool_response_content = json.dumps(
                    {"query": query, "matches": matches}, ensure_ascii=False
                )
                time.sleep(0.5)  # short pause so the frontend can display promptly
                yield json.dumps(
                    {
                        "type": "tool_result",
                        "tool_call_id": tool_call_id,
                        "content": tool_response_content,
                    },
                    ensure_ascii=False,
                ) + "\n"
                logger.debug("tool_result emit, matches %s", len(matches))

                # Feed the tool exchange back to the model for the next round.
                assistant_tool_msg = {
                    "role": "assistant",
                    "content": segment_text,
                    "tool_calls": [{
                        "id": tool_call_id,
                        "type": "function",
                        "function": {"name": tool_name, "arguments": tool_args},
                    }],
                }
                tool_result_msg = {
                    "role": "tool",
                    "tool_call_id": tool_call_id,
                    "content": tool_response_content,
                }
                model_messages = model_messages + [assistant_tool_msg, tool_result_msg]
                new_messages.append(assistant_tool_msg)
                new_messages.append(tool_result_msg)
        except Exception as e:
            # Top-level boundary: surface the error to the client as stream text.
            logger.exception("chat stream failed, cid=%s", cid)
            err = f"[出错]{e}"
            yield from emit_events_for_text(err)
            new_messages.append({"role": "assistant", "content": err})
        finally:
            # Persist everything produced this turn, even on partial failure.
            if new_messages:
                convo['messages'].extend(new_messages)
                convo['updated_at'] = _utc_now_iso()
                save_conversation(convo)

    headers = {'X-Conversation-Id': cid}
    return Response(
        stream_with_context(generate()),
        mimetype='application/x-ndjson; charset=utf-8',
        headers=headers,
    )