import json
|
|
import re
|
|
import time
|
|
import uuid
|
|
from datetime import datetime
|
|
|
|
from flask import Blueprint, request, jsonify, Response, stream_with_context
|
|
from openai import OpenAI
|
|
|
|
from ..config import TOKEN_INTERVAL, MODEL_NAME, OPENAI_BASE_URL, API_KEY
|
|
from ..prompt import SYSTEM_PROMPT
|
|
from ..rag import search_rag
|
|
from ..conversation_store import make_conversation, load as load_conversation, save as save_conversation
|
|
|
|
# All chat routes in this module are registered under the /api prefix.
bp = Blueprint("chat", __name__, url_prefix="/api")

# Shared OpenAI-compatible client; endpoint and credentials come from config.
client = OpenAI(api_key=API_KEY, base_url=OPENAI_BASE_URL)

# Tool schema advertised to the model: a single function that searches the
# local QA dataset by question and returns up to 5 matches. The description
# strings are sent to the model at runtime, so they are left as-is.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "search_rag",
            "description": "在本地 QA 数据集中按 question 字段检索,返回最多 5 条匹配。",
            "parameters": {
                "type": "object",
                "required": ["query"],
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "用户问题或关键词"
                    }
                }
            }
        }
    }
]
|
|
|
|
|
|
def sanitize_legacy_markers(text: str) -> str:
    """Strip obsolete search-progress markers from older assistant messages.

    Removes the combined marker+label forms first, then the bare markers,
    then any whole lines consisting only of the plain progress labels.
    Falsy input (empty string, None) is returned unchanged.
    """
    if not text:
        return text
    # Order matters: the combined forms must be stripped before the bare
    # markers, otherwise the label text would be left behind.
    legacy_fragments = (
        "[[SEARCH_START]]搜索中...",
        "[[SEARCH_DONE]]搜索完成",
        "[[SEARCH_START]]",
        "[[SEARCH_DONE]]",
    )
    for fragment in legacy_fragments:
        text = text.replace(fragment, "")
    # Drop whole lines that contain only the plain labels (plus whitespace).
    line_patterns = (
        r"(?m)^[ \t]*搜索中\.\.\.[ \t]*\n?",
        r"(?m)^[ \t]*搜索完成[ \t]*\n?",
    )
    for pattern in line_patterns:
        text = re.sub(pattern, "", text)
    return text
|
|
|
|
|
|
def messages_for_model(messages: list[dict]) -> list[dict]:
    """Return shallow copies of *messages* with legacy search markers
    stripped from assistant string content; other messages pass through
    unchanged (but still copied)."""

    def _cleaned(message: dict) -> dict:
        copy = dict(message)
        is_assistant = copy.get("role") == "assistant"
        if is_assistant and isinstance(copy.get("content"), str):
            copy["content"] = sanitize_legacy_markers(copy["content"])
        return copy

    return [_cleaned(message) for message in messages]
|
|
|
|
|
|
def emit_events_for_text(text: str):
    """Yield one NDJSON ``assistant_delta`` event line per character of
    *text*, sleeping TOKEN_INTERVAL seconds between characters to pace
    the stream like typing."""
    for character in text:
        event = {"type": "assistant_delta", "delta": character}
        yield json.dumps(event, ensure_ascii=False) + "\n"
        time.sleep(TOKEN_INTERVAL)
|
|
|
|
|
|
@bp.post("/chat")
def chat():
    """Handle one chat turn and stream the reply as NDJSON events.

    Expects JSON ``{"message": str, "conversation_id": optional str}``.
    The user message is persisted immediately, then model output is streamed
    character by character along with tool-call progress events. Up to
    ``max_tool_rounds`` rounds of the ``search_rag`` tool are executed per
    turn. The conversation id is echoed in the ``X-Conversation-Id`` header.
    """
    data = request.get_json(force=True)
    user_text = (data.get('message') or '').strip()
    if not user_text:
        return jsonify({"error": "message is required"}), 400

    # Load an existing conversation, or start a new one titled with the
    # first 20 characters of the user's message.
    cid = data.get('conversation_id')
    convo = load_conversation(cid) if cid else None
    if not convo:
        convo = make_conversation(user_text[:20])
        cid = convo['id']

    # Record only the AI-dialogue messages; saved before streaming so the
    # user message survives even if the stream fails midway.
    convo['messages'].append({"role": "user", "content": user_text})
    # NOTE(review): datetime.utcnow() is naive and deprecated since 3.12;
    # datetime.now(timezone.utc) would change the stored format — confirm
    # consumers before switching.
    convo['updated_at'] = datetime.utcnow().isoformat()
    save_conversation(convo)

    all_messages = [SYSTEM_PROMPT] + messages_for_model(convo['messages'])

    def generate():
        """Yield NDJSON event lines; on exit, persist every message
        produced during this turn (see the finally block)."""
        new_messages = []  # messages produced this turn, persisted in finally
        try:
            model_messages = list(all_messages)
            max_tool_rounds = 2  # cap tool round-trips to avoid infinite loops
            tool_round = 0

            while True:
                completion = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=model_messages,
                    temperature=0.6,
                    stream=True,
                    tools=TOOLS,
                    tool_choice="auto",
                )
                print("[chat] streaming start, cid=", cid, "tool_round=", tool_round, "len(messages)=", len(model_messages))

                # Per-segment accumulators: tool-call fragments arrive spread
                # across many stream chunks and are keyed by their index.
                tool_calls_acc = {}
                segment_text = ""
                tool_start_emitted = False

                for chunk in completion:
                    delta = chunk.choices[0].delta
                    text = delta.content or ""
                    if text:
                        segment_text += text
                        yield from emit_events_for_text(text)
                        print("[chat] delta text len", len(text))

                    if delta.tool_calls:
                        if not tool_start_emitted:
                            # Tell the client a tool call is starting (emitted once).
                            tool_start_emitted = True
                            yield json.dumps({"type": "tool_call_start"}, ensure_ascii=False) + "\n"
                            print("[chat] tool_call_start")
                        for tc in delta.tool_calls:
                            # NOTE(review): `tc.index or 0` maps a None index to 0.
                            idx = tc.index or 0
                            entry = tool_calls_acc.setdefault(idx, {"id": tc.id, "name": None, "arguments": ""})
                            if tc.id:
                                entry["id"] = tc.id
                            if tc.function:
                                if tc.function.name:
                                    entry["name"] = tc.function.name
                                if tc.function.arguments:
                                    # Argument JSON streams in pieces; concatenate.
                                    entry["arguments"] += tc.function.arguments
                            print("[chat] tool_call accumulating args len", len(entry["arguments"]))

                # No tool call requested: this segment is the final answer.
                if not tool_calls_acc:
                    if segment_text:
                        new_messages.append({"role": "assistant", "content": segment_text})
                    break

                # Tool budget exhausted: keep whatever text we got and stop.
                if tool_round >= max_tool_rounds:
                    if segment_text:
                        new_messages.append({"role": "assistant", "content": segment_text})
                    break

                tool_round += 1

                # Only the first accumulated tool call is executed; any
                # additional parallel calls are dropped.
                call = list(tool_calls_acc.values())[0]
                tool_name = call.get("name")
                tool_args = call.get("arguments") or ""
                tool_call_id = call.get("id") or str(uuid.uuid4())

                yield json.dumps(
                    {
                        "type": "tool_call",
                        "tool_call_id": tool_call_id,
                        "name": tool_name,
                        "arguments": tool_args,
                    },
                    ensure_ascii=False,
                ) + "\n"
                print("[chat] tool_call emit", tool_name, tool_args)

                # Parse {"query": ...}; fall back to the raw argument string
                # when the model emitted malformed JSON.
                query = ""
                try:
                    parsed = json.loads(tool_args or "{}")
                    query = parsed.get("query", "")
                except Exception:
                    query = tool_args

                matches = search_rag(query, limit=5)
                tool_response_content = json.dumps({"query": query, "matches": matches}, ensure_ascii=False)

                time.sleep(0.5)  # short pause so the frontend can render progress promptly

                yield json.dumps(
                    {
                        "type": "tool_result",
                        "tool_call_id": tool_call_id,
                        "content": tool_response_content,
                    },
                    ensure_ascii=False,
                ) + "\n"
                print("[chat] tool_result emit, matches", len(matches))

                # Feed the tool exchange back to the model for the next round.
                assistant_tool_msg = {
                    "role": "assistant",
                    "content": segment_text,
                    "tool_calls": [{
                        "id": tool_call_id,
                        "type": "function",
                        "function": {"name": tool_name, "arguments": tool_args}
                    }]
                }
                tool_result_msg = {
                    "role": "tool",
                    "tool_call_id": tool_call_id,
                    "content": tool_response_content,
                }

                model_messages = model_messages + [assistant_tool_msg, tool_result_msg]
                new_messages.append(assistant_tool_msg)
                new_messages.append(tool_result_msg)

        except Exception as e:
            # Top-level boundary: stream the error text to the client and
            # record it so the stored conversation reflects the failure.
            err = f"[出错]{e}"
            yield from emit_events_for_text(err)
            new_messages.append({"role": "assistant", "content": err})
        finally:
            # Persist everything produced this turn, even on partial failure.
            if new_messages:
                convo['messages'].extend(new_messages)
                convo['updated_at'] = datetime.utcnow().isoformat()
                save_conversation(convo)

    headers = {'X-Conversation-Id': cid}
    return Response(stream_with_context(generate()), mimetype='application/x-ndjson; charset=utf-8', headers=headers)
|