|
@@ -1,6 +1,6 @@
|
|
|
-from fastapi import APIRouter, HTTPException, Depends
|
|
|
|
|
|
|
+from fastapi import APIRouter, HTTPException, Depends, Query
|
|
|
from fastapi.responses import StreamingResponse
|
|
from fastapi.responses import StreamingResponse
|
|
|
-from typing import List, Annotated
|
|
|
|
|
|
|
+from typing import List, Annotated, Optional
|
|
|
from datetime import datetime
|
|
from datetime import datetime
|
|
|
import json
|
|
import json
|
|
|
import asyncio
|
|
import asyncio
|
|
@@ -10,12 +10,13 @@ from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamRespons
|
|
|
from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools
|
|
from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools
|
|
|
from ..routers.users import get_current_active_user, User
|
|
from ..routers.users import get_current_active_user, User
|
|
|
from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history
|
|
from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history
|
|
|
-from ..dependencies.auth import resolve_username
|
|
|
|
|
|
|
+from ..dependencies.auth import resolve_user_id
|
|
|
|
|
|
|
|
router = APIRouter()
|
|
router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
|
-async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
|
|
|
|
+async def generate_stream_response(request: ChatRequest, user_id: str):
|
|
|
|
|
+ session_id = request.session_id
|
|
|
latest_user_msg = None
|
|
latest_user_msg = None
|
|
|
try:
|
|
try:
|
|
|
latest_user_msg = get_latest_user_message(request.messages)
|
|
latest_user_msg = get_latest_user_message(request.messages)
|
|
@@ -24,7 +25,8 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
|
|
|
|
|
# 保存用户消息到 DB
|
|
# 保存用户消息到 DB
|
|
|
save_chat_history(
|
|
save_chat_history(
|
|
|
- username=username,
|
|
|
|
|
|
|
+ user_id=user_id,
|
|
|
|
|
+ session_id=session_id,
|
|
|
role=latest_user_msg.role,
|
|
role=latest_user_msg.role,
|
|
|
content=latest_user_msg.content,
|
|
content=latest_user_msg.content,
|
|
|
timestamp=datetime.now(),
|
|
timestamp=datetime.now(),
|
|
@@ -77,7 +79,7 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
|
|
|
|
|
api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
|
|
api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
|
|
|
tools = get_web_search_tools()
|
|
tools = get_web_search_tools()
|
|
|
- previous_response_id = get_previous_response_id(username)
|
|
|
|
|
|
|
+ previous_response_id = get_previous_response_id(user_id, session_id)
|
|
|
|
|
|
|
|
stream = client.responses.create(
|
|
stream = client.responses.create(
|
|
|
model=config.MODEL_NAME,
|
|
model=config.MODEL_NAME,
|
|
@@ -135,7 +137,7 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
delta_text = getattr(chunk, 'delta', '')
|
|
delta_text = getattr(chunk, 'delta', '')
|
|
|
if delta_text:
|
|
if delta_text:
|
|
|
if not answering_started:
|
|
if not answering_started:
|
|
|
- print(f"\n\n💬 AI回答 [{datetime.now().strftime('%H:%M:%S')}]:")
|
|
|
|
|
|
|
+ # print(f"\n\n💬 AI回答 [{datetime.now().strftime('%H:%M:%S')}]:")
|
|
|
answering_started = True
|
|
answering_started = True
|
|
|
accumulated_content += delta_text
|
|
accumulated_content += delta_text
|
|
|
yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
|
|
yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
|
|
@@ -147,7 +149,7 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
if response_obj and hasattr(response_obj, 'id'):
|
|
if response_obj and hasattr(response_obj, 'id'):
|
|
|
response_id = response_obj.id
|
|
response_id = response_obj.id
|
|
|
save_chat_log(
|
|
save_chat_log(
|
|
|
- username=username,
|
|
|
|
|
|
|
+ user_id=user_id,
|
|
|
question=latest_user_msg.content,
|
|
question=latest_user_msg.content,
|
|
|
stream_mode=True,
|
|
stream_mode=True,
|
|
|
raw_response=repr(response_obj),
|
|
raw_response=repr(response_obj),
|
|
@@ -159,7 +161,8 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
if accumulated_content:
|
|
if accumulated_content:
|
|
|
# 保存助手消息到 DB(含 thinking / searching)
|
|
# 保存助手消息到 DB(含 thinking / searching)
|
|
|
save_chat_history(
|
|
save_chat_history(
|
|
|
- username=username,
|
|
|
|
|
|
|
+ user_id=user_id,
|
|
|
|
|
+ session_id=session_id,
|
|
|
role="assistant",
|
|
role="assistant",
|
|
|
content=accumulated_content,
|
|
content=accumulated_content,
|
|
|
timestamp=datetime.now(),
|
|
timestamp=datetime.now(),
|
|
@@ -176,7 +179,7 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
"timestamp": datetime.now().isoformat()
|
|
"timestamp": datetime.now().isoformat()
|
|
|
}
|
|
}
|
|
|
save_chat_log(
|
|
save_chat_log(
|
|
|
- username=username,
|
|
|
|
|
|
|
+ user_id=user_id,
|
|
|
question=latest_user_msg.content if latest_user_msg else "",
|
|
question=latest_user_msg.content if latest_user_msg else "",
|
|
|
stream_mode=True,
|
|
stream_mode=True,
|
|
|
status="error",
|
|
status="error",
|
|
@@ -188,12 +191,12 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
@router.post("/chat", response_model=ChatResponse)
|
|
@router.post("/chat", response_model=ChatResponse)
|
|
|
async def chat(
|
|
async def chat(
|
|
|
request: ChatRequest,
|
|
request: ChatRequest,
|
|
|
- username: Annotated[str, Depends(resolve_username)],
|
|
|
|
|
|
|
+ user_id: Annotated[str, Depends(resolve_user_id)],
|
|
|
):
|
|
):
|
|
|
try:
|
|
try:
|
|
|
if request.stream:
|
|
if request.stream:
|
|
|
return StreamingResponse(
|
|
return StreamingResponse(
|
|
|
- generate_stream_response(request, username),
|
|
|
|
|
|
|
+ generate_stream_response(request, user_id),
|
|
|
media_type="text/plain",
|
|
media_type="text/plain",
|
|
|
headers={
|
|
headers={
|
|
|
"Cache-Control": "no-cache",
|
|
"Cache-Control": "no-cache",
|
|
@@ -203,12 +206,14 @@ async def chat(
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# 以下是非流式输出处理
|
|
# 以下是非流式输出处理
|
|
|
|
|
+ session_id = request.session_id
|
|
|
latest_user_msg = get_latest_user_message(request.messages)
|
|
latest_user_msg = get_latest_user_message(request.messages)
|
|
|
if not latest_user_msg:
|
|
if not latest_user_msg:
|
|
|
raise ValueError("请求中没有找到user角色的消息")
|
|
raise ValueError("请求中没有找到user角色的消息")
|
|
|
|
|
|
|
|
save_chat_history(
|
|
save_chat_history(
|
|
|
- username=username,
|
|
|
|
|
|
|
+ user_id=user_id,
|
|
|
|
|
+ session_id=session_id,
|
|
|
role=latest_user_msg.role,
|
|
role=latest_user_msg.role,
|
|
|
content=latest_user_msg.content,
|
|
content=latest_user_msg.content,
|
|
|
timestamp=datetime.now(),
|
|
timestamp=datetime.now(),
|
|
@@ -216,7 +221,7 @@ async def chat(
|
|
|
|
|
|
|
|
api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
|
|
api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
|
|
|
tools = get_doubao_tools()
|
|
tools = get_doubao_tools()
|
|
|
- previous_response_id = get_previous_response_id(username)
|
|
|
|
|
|
|
+ previous_response_id = get_previous_response_id(user_id, session_id)
|
|
|
|
|
|
|
|
response = client.responses.create(
|
|
response = client.responses.create(
|
|
|
model=config.MODEL_NAME,
|
|
model=config.MODEL_NAME,
|
|
@@ -228,7 +233,7 @@ async def chat(
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
save_chat_log(
|
|
save_chat_log(
|
|
|
- username=username,
|
|
|
|
|
|
|
+ user_id=user_id,
|
|
|
question=latest_user_msg.content,
|
|
question=latest_user_msg.content,
|
|
|
stream_mode=False,
|
|
stream_mode=False,
|
|
|
raw_response=repr(response),
|
|
raw_response=repr(response),
|
|
@@ -259,7 +264,8 @@ async def chat(
|
|
|
|
|
|
|
|
now = datetime.now()
|
|
now = datetime.now()
|
|
|
save_chat_history(
|
|
save_chat_history(
|
|
|
- username=username,
|
|
|
|
|
|
|
+ user_id=user_id,
|
|
|
|
|
+ session_id=session_id,
|
|
|
role="assistant",
|
|
role="assistant",
|
|
|
content=message_content,
|
|
content=message_content,
|
|
|
timestamp=now,
|
|
timestamp=now,
|
|
@@ -284,7 +290,7 @@ async def chat(
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
error_message = f"处理聊天请求时发生错误: {str(e)}"
|
|
error_message = f"处理聊天请求时发生错误: {str(e)}"
|
|
|
save_chat_log(
|
|
save_chat_log(
|
|
|
- username=username,
|
|
|
|
|
|
|
+ user_id=user_id,
|
|
|
question=request.messages[-1].content if request.messages else "",
|
|
question=request.messages[-1].content if request.messages else "",
|
|
|
stream_mode=request.stream,
|
|
stream_mode=request.stream,
|
|
|
status="error",
|
|
status="error",
|
|
@@ -313,9 +319,10 @@ async def get_models(current_user: Annotated[User, Depends(get_current_active_us
|
|
|
|
|
|
|
|
@router.get("/history")
|
|
@router.get("/history")
|
|
|
async def get_user_history(
|
|
async def get_user_history(
|
|
|
- current_user: Annotated[User, Depends(get_current_active_user)]
|
|
|
|
|
|
|
+ current_user: Annotated[User, Depends(get_current_active_user)],
|
|
|
|
|
+ sessionId: str = Query(..., description="会话ID"),
|
|
|
) -> List[ChatMessage]:
|
|
) -> List[ChatMessage]:
|
|
|
- docs = get_chat_history(current_user.username)
|
|
|
|
|
|
|
+ docs = get_chat_history(current_user.userId, sessionId)
|
|
|
return [
|
|
return [
|
|
|
ChatMessage(
|
|
ChatMessage(
|
|
|
role=doc["role"],
|
|
role=doc["role"],
|
|
@@ -330,11 +337,14 @@ async def get_user_history(
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/history")
|
|
@router.delete("/history")
|
|
|
-async def clear_user_history(current_user: Annotated[User, Depends(get_current_active_user)]):
|
|
|
|
|
- deleted_count = delete_chat_history(current_user.username)
|
|
|
|
|
|
|
+async def clear_user_history(
|
|
|
|
|
+ current_user: Annotated[User, Depends(get_current_active_user)],
|
|
|
|
|
+ sessionId: str = Query(..., description="会话ID"),
|
|
|
|
|
+):
|
|
|
|
|
+ deleted_count = delete_chat_history(current_user.userId, sessionId)
|
|
|
if deleted_count > 0:
|
|
if deleted_count > 0:
|
|
|
- return {"message": "聊天历史已清空", "user": current_user.username, "deleted_messages": deleted_count, "timestamp": datetime.now()}
|
|
|
|
|
- return {"message": "用户没有聊天历史", "user": current_user.username, "deleted_messages": 0, "timestamp": datetime.now()}
|
|
|
|
|
|
|
+ return {"message": "聊天历史已清空", "user": current_user.userId, "deleted_messages": deleted_count, "timestamp": datetime.now()}
|
|
|
|
|
+ return {"message": "用户没有聊天历史", "user": current_user.userId, "deleted_messages": 0, "timestamp": datetime.now()}
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/health")
|
|
@router.get("/health")
|