# nianjie/backend/routes/chat.py
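"""Chat API blueprint.

POST /api/chat streams the assistant reply as NDJSON events and lets the model call a
single `search_rag` function tool over the local QA dataset; conversations are persisted
via the conversation_store helpers.
"""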

import json
import re
import time
import uuid
from datetime import datetime

from flask import Blueprint, request, jsonify, Response, stream_with_context
from openai import OpenAI

from ..config import TOKEN_INTERVAL, MODEL_NAME, OPENAI_API_KEY, OPENAI_BASE_URL
from ..prompt import SYSTEM_PROMPT
from ..rag import search_rag
from ..conversation_store import make_conversation, load as load_conversation, save as save_conversation

bp = Blueprint("chat", __name__, url_prefix="/api")
client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
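
# The single function tool exposed to the model; the /chat route below executes it
# locally via search_rag(query, limit=5) and streams the result back as a tool event.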
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "search_rag",
            # "Search the local QA dataset by the question field; return at most 5 matches."
            "description": "在本地 QA 数据集中按 question 字段检索,返回最多 5 条匹配。",
            "parameters": {
                "type": "object",
                "required": ["query"],
                "properties": {
                    "query": {
                        "type": "string",
                        # "The user's question or keywords."
                        "description": "用户问题或关键词",
                    }
                },
            },
        },
    }
]


def sanitize_legacy_markers(text: str) -> str:
    """Strip legacy search-progress markers from stored assistant text before reuse."""
    if not text:
        return text
    text = text.replace("[[SEARCH_START]]搜索中...", "")
    text = text.replace("[[SEARCH_DONE]]搜索完成", "")
    text = text.replace("[[SEARCH_START]]", "")
    text = text.replace("[[SEARCH_DONE]]", "")
    text = re.sub(r"(?m)^[ \t]*搜索中\.\.\.[ \t]*\n?", "", text)
    text = re.sub(r"(?m)^[ \t]*搜索完成[ \t]*\n?", "", text)
    return text


def messages_for_model(messages: list[dict]) -> list[dict]:
    """Return a copy of the history with legacy markers removed from assistant turns."""
    sanitized = []
    for m in messages:
        mm = dict(m)
        if mm.get("role") == "assistant" and isinstance(mm.get("content"), str):
            mm["content"] = sanitize_legacy_markers(mm["content"])
        sanitized.append(mm)
    return sanitized


def emit_events_for_text(text: str):
    """Yield one NDJSON `assistant_delta` event per character, paced by TOKEN_INTERVAL."""
    for ch in text:
        yield json.dumps({"type": "assistant_delta", "delta": ch}, ensure_ascii=False) + "\n"
        time.sleep(TOKEN_INTERVAL)
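

# NDJSON event stream emitted by /api/chat (one JSON object per line):
#   {"type": "assistant_delta", "delta": ...}                       incremental assistant text
#   {"type": "tool_call_start"}                                     model started a tool call
#   {"type": "tool_call", "tool_call_id": ..., "name": ..., "arguments": ...}
#   {"type": "tool_result", "tool_call_id": ..., "content": ...}    local search_rag result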
@bp.post("/chat")
def chat():
data = request.get_json(force=True)
user_text = (data.get('message') or '').strip()
if not user_text:
return jsonify({"error": "message is required"}), 400
cid = data.get('conversation_id')
convo = load_conversation(cid) if cid else None
if not convo:
convo = make_conversation(user_text[:20])
cid = convo['id']
# 仅记录 AI 对话的消息
convo['messages'].append({"role": "user", "content": user_text})
convo['updated_at'] = datetime.utcnow().isoformat()
save_conversation(convo)
all_messages = [SYSTEM_PROMPT] + messages_for_model(convo['messages'])

    def generate():
        new_messages = []
        try:
            model_messages = list(all_messages)
            max_tool_rounds = 2
            tool_round = 0
            # Stream model output; after each executed tool call, query the model again,
            # for at most max_tool_rounds rounds.
            while True:
                completion = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=model_messages,
                    temperature=0.6,
                    stream=True,
                    tools=TOOLS,
                    tool_choice="auto",
                )
                print("[chat] streaming start, cid=", cid, "tool_round=", tool_round, "len(messages)=", len(model_messages))
                tool_calls_acc = {}
                segment_text = ""
                tool_start_emitted = False
                for chunk in completion:
                    delta = chunk.choices[0].delta
                    text = delta.content or ""
                    if text:
                        segment_text += text
                        yield from emit_events_for_text(text)
                        print("[chat] delta text len", len(text))
                    if delta.tool_calls:
                        if not tool_start_emitted:
                            tool_start_emitted = True
                            yield json.dumps({"type": "tool_call_start"}, ensure_ascii=False) + "\n"
                            print("[chat] tool_call_start")
                        # Tool-call arguments arrive as fragments; accumulate them per call index.
                        for tc in delta.tool_calls:
                            idx = tc.index or 0
                            entry = tool_calls_acc.setdefault(idx, {"id": tc.id, "name": None, "arguments": ""})
                            if tc.id:
                                entry["id"] = tc.id
                            if tc.function:
                                if tc.function.name:
                                    entry["name"] = tc.function.name
                                if tc.function.arguments:
                                    entry["arguments"] += tc.function.arguments
                            print("[chat] tool_call accumulating args len", len(entry["arguments"]))
                # This round's stream has ended: either finish, or execute the requested tool.
                if not tool_calls_acc:
                    if segment_text:
                        new_messages.append({"role": "assistant", "content": segment_text})
                    break
                if tool_round >= max_tool_rounds:
                    if segment_text:
                        new_messages.append({"role": "assistant", "content": segment_text})
                    break
                tool_round += 1
                # Only the first accumulated tool call is executed in each round.
                call = list(tool_calls_acc.values())[0]
                tool_name = call.get("name")
                tool_args = call.get("arguments") or ""
                tool_call_id = call.get("id") or str(uuid.uuid4())
                yield json.dumps(
                    {
                        "type": "tool_call",
                        "tool_call_id": tool_call_id,
                        "name": tool_name,
                        "arguments": tool_args,
                    },
                    ensure_ascii=False,
                ) + "\n"
                print("[chat] tool_call emit", tool_name, tool_args)
                query = ""
                try:
                    parsed = json.loads(tool_args or "{}")
                    query = parsed.get("query", "")
                except Exception:
                    # Fall back to the raw argument string if it is not valid JSON.
                    query = tool_args
                matches = search_rag(query, limit=5)
                tool_response_content = json.dumps({"query": query, "matches": matches}, ensure_ascii=False)
                time.sleep(0.5)  # brief pause so the frontend can display the tool call before its result
                yield json.dumps(
                    {
                        "type": "tool_result",
                        "tool_call_id": tool_call_id,
                        "content": tool_response_content,
                    },
                    ensure_ascii=False,
                ) + "\n"
                print("[chat] tool_result emit, matches", len(matches))
                # Record the tool call and its result so the next round (and the saved
                # conversation) contains the full tool exchange.
                assistant_tool_msg = {
                    "role": "assistant",
                    "content": segment_text,
                    "tool_calls": [{
                        "id": tool_call_id,
                        "type": "function",
                        "function": {"name": tool_name, "arguments": tool_args},
                    }],
                }
                tool_result_msg = {
                    "role": "tool",
                    "tool_call_id": tool_call_id,
                    "content": tool_response_content,
                }
                model_messages = model_messages + [assistant_tool_msg, tool_result_msg]
                new_messages.append(assistant_tool_msg)
                new_messages.append(tool_result_msg)
        except Exception as e:
            # Errors are streamed to the client as regular assistant text ("[出错]" = "[Error]").
            err = f"[出错]{e}"
            yield from emit_events_for_text(err)
            new_messages.append({"role": "assistant", "content": err})
        finally:
            # Persist whatever was produced, even if the stream failed midway.
            if new_messages:
                convo['messages'].extend(new_messages)
                convo['updated_at'] = datetime.utcnow().isoformat()
                save_conversation(convo)

    headers = {'X-Conversation-Id': cid}
    return Response(stream_with_context(generate()), mimetype='application/x-ndjson; charset=utf-8', headers=headers)
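
# Minimal client sketch (assumptions: the server listens on http://localhost:5000 and the
# third-party `requests` package is installed); each streamed line is one JSON event:
#
#   import json, requests
#
#   resp = requests.post(
#       "http://localhost:5000/api/chat",
#       json={"message": "hello"},
#       stream=True,
#   )
#   print(resp.headers.get("X-Conversation-Id"))   # conversation id for follow-up requests
#   for line in resp.iter_lines(decode_unicode=True):
#       if line:
#           print(json.loads(line))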