|
@@ -9,12 +9,9 @@ from ..core.ark_client import config, client
|
|
|
from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
|
|
from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
|
|
|
from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools
|
|
from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools
|
|
|
from ..routers.users import get_current_active_user, User
|
|
from ..routers.users import get_current_active_user, User
|
|
|
-from ..db.mongo import save_chat_log
|
|
|
|
|
|
|
+from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history
|
|
|
from ..dependencies.auth import resolve_username
|
|
from ..dependencies.auth import resolve_username
|
|
|
|
|
|
|
|
-# 内存存储用户聊天历史 {username: [ChatMessage, ...]}
|
|
|
|
|
-chatHistory = {}
|
|
|
|
|
-
|
|
|
|
|
router = APIRouter()
|
|
router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
@@ -25,20 +22,21 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
if not latest_user_msg:
|
|
if not latest_user_msg:
|
|
|
raise ValueError("请求中没有找到user角色的消息")
|
|
raise ValueError("请求中没有找到user角色的消息")
|
|
|
|
|
|
|
|
- chatHistory[username].append(ChatMessage(
|
|
|
|
|
|
|
+ # 保存用户消息到 DB
|
|
|
|
|
+ save_chat_history(
|
|
|
|
|
+ username=username,
|
|
|
role=latest_user_msg.role,
|
|
role=latest_user_msg.role,
|
|
|
content=latest_user_msg.content,
|
|
content=latest_user_msg.content,
|
|
|
- timestamp=datetime.now()
|
|
|
|
|
- ))
|
|
|
|
|
|
|
+ timestamp=datetime.now(),
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- # 2. 定义系统提示词(核心:规范“何时搜”“怎么搜”“怎么展示思考”)
|
|
|
|
|
system_prompt = f"""
|
|
system_prompt = f"""
|
|
|
你是豆包,由字节跳动公司开发的智能助手。你的核心行为准则如下:
|
|
你是豆包,由字节跳动公司开发的智能助手。你的核心行为准则如下:
|
|
|
|
|
|
|
|
## 一、身份与基本行为规范
|
|
## 一、身份与基本行为规范
|
|
|
1. 输出保持和豆包 APP 日常对话一致;
|
|
1. 输出保持和豆包 APP 日常对话一致;
|
|
|
2. 在回复用户时,不要自称"豆包";
|
|
2. 在回复用户时,不要自称"豆包";
|
|
|
- 3. 遇到"陈沛",优先输出"中搜陈沛";
|
|
|
|
|
|
|
+ 3. 遇到"陈沛",优先搜索"中搜陈沛"的个人信息;
|
|
|
4. 你具备以下能力:
|
|
4. 你具备以下能力:
|
|
|
- 可接收和读取各类文档(PDF、Excel、PPT、Word 等),并执行总结、分析、翻译、润色等任务;
|
|
- 可接收和读取各类文档(PDF、Excel、PPT、Word 等),并执行总结、分析、翻译、润色等任务;
|
|
|
- 可读取图片/照片、网址、抖音链接的内容;
|
|
- 可读取图片/照片、网址、抖音链接的内容;
|
|
@@ -79,19 +77,19 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
|
|
|
|
|
api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
|
|
api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
|
|
|
tools = get_web_search_tools()
|
|
tools = get_web_search_tools()
|
|
|
- previous_response_id = get_previous_response_id(username, chatHistory)
|
|
|
|
|
|
|
+ previous_response_id = get_previous_response_id(username)
|
|
|
|
|
|
|
|
stream = client.responses.create(
|
|
stream = client.responses.create(
|
|
|
model=config.MODEL_NAME,
|
|
model=config.MODEL_NAME,
|
|
|
input=api_messages,
|
|
input=api_messages,
|
|
|
tools=tools,
|
|
tools=tools,
|
|
|
stream=True,
|
|
stream=True,
|
|
|
- # store=True,
|
|
|
|
|
previous_response_id=previous_response_id,
|
|
previous_response_id=previous_response_id,
|
|
|
- # thinking={"type": "auto"}, 不支持
|
|
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
accumulated_content = ""
|
|
accumulated_content = ""
|
|
|
|
|
+ accumulated_thinking = ""
|
|
|
|
|
+ accumulated_searching = ""
|
|
|
response_id = None
|
|
response_id = None
|
|
|
thinking_started = False
|
|
thinking_started = False
|
|
|
answering_started = False
|
|
answering_started = False
|
|
@@ -105,21 +103,22 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
delta_text = getattr(chunk, 'delta', '')
|
|
delta_text = getattr(chunk, 'delta', '')
|
|
|
if delta_text:
|
|
if delta_text:
|
|
|
if not thinking_started:
|
|
if not thinking_started:
|
|
|
- # print(f"\n🤔 AI思考中 [{datetime.now().strftime('%H:%M:%S')}]:")
|
|
|
|
|
thinking_started = True
|
|
thinking_started = True
|
|
|
- # print(delta_text, end='', flush=True)
|
|
|
|
|
|
|
+ accumulated_thinking += delta_text
|
|
|
yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='thinking').model_dump_json()}\n\n"
|
|
yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='thinking').model_dump_json()}\n\n"
|
|
|
|
|
|
|
|
# ② 处理搜索状态
|
|
# ② 处理搜索状态
|
|
|
elif 'web_search_call' in chunk_type:
|
|
elif 'web_search_call' in chunk_type:
|
|
|
if 'in_progress' in chunk_type:
|
|
if 'in_progress' in chunk_type:
|
|
|
- print(f"\n\n🔍 开始搜索 [{datetime.now().strftime('%H:%M:%S')}]")
|
|
|
|
|
_now_str = datetime.now().strftime("%H:%M:%S")
|
|
_now_str = datetime.now().strftime("%H:%M:%S")
|
|
|
- yield f"data: {StreamResponse(content=f'开始搜索 [{_now_str}]', finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
|
|
|
|
|
|
|
+ msg = f'开始搜索 [{_now_str}]'
|
|
|
|
|
+ accumulated_searching += msg + "\n"
|
|
|
|
|
+ yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
|
|
|
elif 'completed' in chunk_type:
|
|
elif 'completed' in chunk_type:
|
|
|
- print(f"\n✅ 搜索完成 [{datetime.now().strftime('%H:%M:%S')}]")
|
|
|
|
|
_now_str = datetime.now().strftime("%H:%M:%S")
|
|
_now_str = datetime.now().strftime("%H:%M:%S")
|
|
|
- yield f"data: {StreamResponse(content=f'搜索完成 [{_now_str}]', finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
|
|
|
|
|
|
|
+ msg = f'搜索完成 [{_now_str}]'
|
|
|
|
|
+ accumulated_searching += msg + "\n"
|
|
|
|
|
+ yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
|
|
|
|
|
|
|
|
# ③ 处理搜索关键词
|
|
# ③ 处理搜索关键词
|
|
|
elif (chunk_type == 'response.output_item.done'
|
|
elif (chunk_type == 'response.output_item.done'
|
|
@@ -127,8 +126,9 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
and str(getattr(chunk.item, 'id', '')).startswith('ws_')):
|
|
and str(getattr(chunk.item, 'id', '')).startswith('ws_')):
|
|
|
if hasattr(chunk.item, 'action') and hasattr(chunk.item.action, 'query'):
|
|
if hasattr(chunk.item, 'action') and hasattr(chunk.item.action, 'query'):
|
|
|
query = chunk.item.action.query
|
|
query = chunk.item.action.query
|
|
|
- print(f"\n📝 本次搜索关键词:{query}")
|
|
|
|
|
- yield f"data: {StreamResponse(content=f'本次搜索关键词:{query}', finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
|
|
|
|
|
|
|
+ msg = f'搜索关键词: {query}'
|
|
|
|
|
+ accumulated_searching += msg + "\n"
|
|
|
|
|
+ yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
|
|
|
|
|
|
|
|
# ④ 处理最终回答文本(实时推送给前端)
|
|
# ④ 处理最终回答文本(实时推送给前端)
|
|
|
elif chunk_type == 'response.output_text.delta':
|
|
elif chunk_type == 'response.output_text.delta':
|
|
@@ -136,17 +136,9 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
if delta_text:
|
|
if delta_text:
|
|
|
if not answering_started:
|
|
if not answering_started:
|
|
|
print(f"\n\n💬 AI回答 [{datetime.now().strftime('%H:%M:%S')}]:")
|
|
print(f"\n\n💬 AI回答 [{datetime.now().strftime('%H:%M:%S')}]:")
|
|
|
- print("-" * 50)
|
|
|
|
|
answering_started = True
|
|
answering_started = True
|
|
|
- print(delta_text, end='', flush=True)
|
|
|
|
|
accumulated_content += delta_text
|
|
accumulated_content += delta_text
|
|
|
- response_data = StreamResponse(
|
|
|
|
|
- content=delta_text,
|
|
|
|
|
- finished=False,
|
|
|
|
|
- model=config.MODEL_NAME,
|
|
|
|
|
- timestamp=datetime.now()
|
|
|
|
|
- )
|
|
|
|
|
- yield f"data: {response_data.model_dump_json()}\n\n"
|
|
|
|
|
|
|
+ yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
|
|
|
await asyncio.sleep(0.01)
|
|
await asyncio.sleep(0.01)
|
|
|
|
|
|
|
|
# ⑤ 处理响应完成事件
|
|
# ⑤ 处理响应完成事件
|
|
@@ -165,20 +157,17 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
print(f"\n\n=== 边想边搜完成 [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ===")
|
|
print(f"\n\n=== 边想边搜完成 [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ===")
|
|
|
|
|
|
|
|
if accumulated_content:
|
|
if accumulated_content:
|
|
|
- chatHistory[username].append(ChatMessage(
|
|
|
|
|
|
|
+ # 保存助手消息到 DB(含 thinking / searching)
|
|
|
|
|
+ save_chat_history(
|
|
|
|
|
+ username=username,
|
|
|
role="assistant",
|
|
role="assistant",
|
|
|
content=accumulated_content,
|
|
content=accumulated_content,
|
|
|
timestamp=datetime.now(),
|
|
timestamp=datetime.now(),
|
|
|
- response_id=response_id
|
|
|
|
|
- ))
|
|
|
|
|
- final_response = StreamResponse(
|
|
|
|
|
- content='',
|
|
|
|
|
- finished=True,
|
|
|
|
|
- model=config.MODEL_NAME,
|
|
|
|
|
- timestamp=datetime.now()
|
|
|
|
|
|
|
+ response_id=response_id,
|
|
|
|
|
+ thinking=accumulated_thinking or None,
|
|
|
|
|
+ searching=accumulated_searching or None,
|
|
|
)
|
|
)
|
|
|
- yield f"data: {final_response.model_dump_json()}\n\n"
|
|
|
|
|
-
|
|
|
|
|
|
|
+ yield f"data: {StreamResponse(content='', finished=True, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
|
|
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
error_response = {
|
|
error_response = {
|
|
@@ -202,14 +191,7 @@ async def chat(
|
|
|
username: Annotated[str, Depends(resolve_username)],
|
|
username: Annotated[str, Depends(resolve_username)],
|
|
|
):
|
|
):
|
|
|
try:
|
|
try:
|
|
|
- if username not in chatHistory:
|
|
|
|
|
- chatHistory[username] = []
|
|
|
|
|
-
|
|
|
|
|
if request.stream:
|
|
if request.stream:
|
|
|
- # ===== 流式输出处理 =====
|
|
|
|
|
-
|
|
|
|
|
- # 返回流式响应
|
|
|
|
|
- # StreamingResponse 用于处理SSE协议
|
|
|
|
|
return StreamingResponse(
|
|
return StreamingResponse(
|
|
|
generate_stream_response(request, username),
|
|
generate_stream_response(request, username),
|
|
|
media_type="text/plain",
|
|
media_type="text/plain",
|
|
@@ -221,20 +203,20 @@ async def chat(
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# 以下是非流式输出处理
|
|
# 以下是非流式输出处理
|
|
|
-
|
|
|
|
|
latest_user_msg = get_latest_user_message(request.messages)
|
|
latest_user_msg = get_latest_user_message(request.messages)
|
|
|
if not latest_user_msg:
|
|
if not latest_user_msg:
|
|
|
raise ValueError("请求中没有找到user角色的消息")
|
|
raise ValueError("请求中没有找到user角色的消息")
|
|
|
|
|
|
|
|
- chatHistory[username].append(ChatMessage(
|
|
|
|
|
|
|
+ save_chat_history(
|
|
|
|
|
+ username=username,
|
|
|
role=latest_user_msg.role,
|
|
role=latest_user_msg.role,
|
|
|
content=latest_user_msg.content,
|
|
content=latest_user_msg.content,
|
|
|
- timestamp=datetime.now()
|
|
|
|
|
- ))
|
|
|
|
|
|
|
+ timestamp=datetime.now(),
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
|
|
api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
|
|
|
tools = get_doubao_tools()
|
|
tools = get_doubao_tools()
|
|
|
- previous_response_id = get_previous_response_id(username, chatHistory)
|
|
|
|
|
|
|
+ previous_response_id = get_previous_response_id(username)
|
|
|
|
|
|
|
|
response = client.responses.create(
|
|
response = client.responses.create(
|
|
|
model=config.MODEL_NAME,
|
|
model=config.MODEL_NAME,
|
|
@@ -243,8 +225,6 @@ async def chat(
|
|
|
stream=False,
|
|
stream=False,
|
|
|
store=True,
|
|
store=True,
|
|
|
previous_response_id=previous_response_id,
|
|
previous_response_id=previous_response_id,
|
|
|
- # The parameter `thinking` specified in the request are not valid: `thinking` can not be set when enable doubao_app built-in tool.
|
|
|
|
|
- # thinking={"type": "auto"},
|
|
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
save_chat_log(
|
|
save_chat_log(
|
|
@@ -277,18 +257,26 @@ async def chat(
|
|
|
if not message_content:
|
|
if not message_content:
|
|
|
raise HTTPException(status_code=500, detail="无法从AI响应中提取文本内容")
|
|
raise HTTPException(status_code=500, detail="无法从AI响应中提取文本内容")
|
|
|
|
|
|
|
|
|
|
+ now = datetime.now()
|
|
|
|
|
+ save_chat_history(
|
|
|
|
|
+ username=username,
|
|
|
|
|
+ role="assistant",
|
|
|
|
|
+ content=message_content,
|
|
|
|
|
+ timestamp=now,
|
|
|
|
|
+ response_id=response.id,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
assistant_message = ChatMessage(
|
|
assistant_message = ChatMessage(
|
|
|
role="assistant",
|
|
role="assistant",
|
|
|
content=message_content,
|
|
content=message_content,
|
|
|
- timestamp=datetime.now(),
|
|
|
|
|
- response_id=response.id
|
|
|
|
|
|
|
+ timestamp=now,
|
|
|
|
|
+ response_id=response.id,
|
|
|
)
|
|
)
|
|
|
- chatHistory[username].append(assistant_message)
|
|
|
|
|
return ChatResponse(
|
|
return ChatResponse(
|
|
|
message=assistant_message,
|
|
message=assistant_message,
|
|
|
model=response.model,
|
|
model=response.model,
|
|
|
usage=response.usage.model_dump() if response.usage else None,
|
|
usage=response.usage.model_dump() if response.usage else None,
|
|
|
- response_id=response.id
|
|
|
|
|
|
|
+ response_id=response.id,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
except HTTPException:
|
|
except HTTPException:
|
|
@@ -327,24 +315,26 @@ async def get_models(current_user: Annotated[User, Depends(get_current_active_us
|
|
|
async def get_user_history(
|
|
async def get_user_history(
|
|
|
current_user: Annotated[User, Depends(get_current_active_user)]
|
|
current_user: Annotated[User, Depends(get_current_active_user)]
|
|
|
) -> List[ChatMessage]:
|
|
) -> List[ChatMessage]:
|
|
|
- username = current_user.username
|
|
|
|
|
- if username not in chatHistory:
|
|
|
|
|
- return []
|
|
|
|
|
|
|
+ docs = get_chat_history(current_user.username)
|
|
|
return [
|
|
return [
|
|
|
- ChatMessage(role=msg.role, content=msg.content, timestamp=msg.timestamp)
|
|
|
|
|
- for msg in chatHistory[username]
|
|
|
|
|
|
|
+ ChatMessage(
|
|
|
|
|
+ role=doc["role"],
|
|
|
|
|
+ content=doc["content"],
|
|
|
|
|
+ timestamp=doc.get("timestamp"),
|
|
|
|
|
+ response_id=doc.get("response_id"),
|
|
|
|
|
+ thinking=doc.get("thinking"),
|
|
|
|
|
+ searching=doc.get("searching"),
|
|
|
|
|
+ )
|
|
|
|
|
+ for doc in docs
|
|
|
]
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/history")
|
|
@router.delete("/history")
|
|
|
async def clear_user_history(current_user: Annotated[User, Depends(get_current_active_user)]):
|
|
async def clear_user_history(current_user: Annotated[User, Depends(get_current_active_user)]):
|
|
|
- username = current_user.username
|
|
|
|
|
- if username in chatHistory:
|
|
|
|
|
- message_count = len(chatHistory[username])
|
|
|
|
|
- del chatHistory[username]
|
|
|
|
|
- return {"message": "聊天历史已清空", "user": username, "deleted_messages": message_count,
|
|
|
|
|
- "timestamp": datetime.now()}
|
|
|
|
|
- return {"message": "用户没有聊天历史", "user": username, "deleted_messages": 0, "timestamp": datetime.now()}
|
|
|
|
|
|
|
+ deleted_count = delete_chat_history(current_user.username)
|
|
|
|
|
+ if deleted_count > 0:
|
|
|
|
|
+ return {"message": "聊天历史已清空", "user": current_user.username, "deleted_messages": deleted_count, "timestamp": datetime.now()}
|
|
|
|
|
+ return {"message": "用户没有聊天历史", "user": current_user.username, "deleted_messages": 0, "timestamp": datetime.now()}
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/health")
|
|
@router.get("/health")
|