|
@@ -4,12 +4,13 @@ from typing import List, Annotated, Optional
|
|
|
from datetime import datetime
|
|
from datetime import datetime
|
|
|
import json
|
|
import json
|
|
|
import asyncio
|
|
import asyncio
|
|
|
|
|
+import threading
|
|
|
|
|
|
|
|
from ..core.ark_client import config, client
|
|
from ..core.ark_client import config, client
|
|
|
from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
|
|
from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
|
|
|
from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools
|
|
from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools
|
|
|
from ..routers.users import get_current_active_user, User
|
|
from ..routers.users import get_current_active_user, User
|
|
|
-from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history
|
|
|
|
|
|
|
+from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history, get_sessions
|
|
|
from ..dependencies.auth import resolve_user_id
|
|
from ..dependencies.auth import resolve_user_id
|
|
|
|
|
|
|
|
router = APIRouter()
|
|
router = APIRouter()
|
|
@@ -93,19 +94,37 @@ async def generate_stream_response(request: ChatRequest, user_id: str):
|
|
|
accumulated_thinking = ""
|
|
accumulated_thinking = ""
|
|
|
accumulated_searching = ""
|
|
accumulated_searching = ""
|
|
|
response_id = None
|
|
response_id = None
|
|
|
- thinking_started = False
|
|
|
|
|
- answering_started = False
|
|
|
|
|
|
|
+
|
|
|
|
|
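+ # Run the synchronous, blocking stream iteration in a worker thread and hand each chunk to the async generator through a Queue,
|
|
|
|
|
+ # so the event loop is not blocked and every chunk is yielded to the frontend as soon as it arrives.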
|
|
|
|
|
+ loop = asyncio.get_event_loop()
|
|
|
|
|
+ queue: asyncio.Queue = asyncio.Queue()
|
|
|
|
|
+
|
|
|
|
|
+ def _iterate_stream():
|
|
|
|
|
+ try:
|
|
|
|
|
+ for chunk in stream:
|
|
|
|
|
+ loop.call_soon_threadsafe(queue.put_nowait, chunk)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ loop.call_soon_threadsafe(queue.put_nowait, e)
|
|
|
|
|
+ finally:
|
|
|
|
|
+ loop.call_soon_threadsafe(queue.put_nowait, None) # 结束哨兵
|
|
|
|
|
+
|
|
|
|
|
+ threading.Thread(target=_iterate_stream, daemon=True).start()
|
|
|
|
|
|
|
|
print("=== 边想边搜启动 ===")
|
|
print("=== 边想边搜启动 ===")
|
|
|
- for chunk in stream:
|
|
|
|
|
|
|
+ while True:
|
|
|
|
|
+ chunk = await queue.get()
|
|
|
|
|
+ if chunk is None:
|
|
|
|
|
+ break
|
|
|
|
|
+ if isinstance(chunk, Exception):
|
|
|
|
|
+ raise chunk
|
|
|
|
|
+
|
|
|
chunk_type = getattr(chunk, 'type', '')
|
|
chunk_type = getattr(chunk, 'type', '')
|
|
|
|
|
|
|
|
# ① Handle the AI's reasoning (thinking) output
|
|
# ① Handle the AI's reasoning (thinking) output
|
|
|
if chunk_type == 'response.reasoning_summary_text.delta':
|
|
if chunk_type == 'response.reasoning_summary_text.delta':
|
|
|
delta_text = getattr(chunk, 'delta', '')
|
|
delta_text = getattr(chunk, 'delta', '')
|
|
|
if delta_text:
|
|
if delta_text:
|
|
|
- if not thinking_started:
|
|
|
|
|
- thinking_started = True
|
|
|
|
|
accumulated_thinking += delta_text
|
|
accumulated_thinking += delta_text
|
|
|
yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='thinking').model_dump_json()}\n\n"
|
|
yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='thinking').model_dump_json()}\n\n"
|
|
|
|
|
|
|
@@ -136,12 +155,8 @@ async def generate_stream_response(request: ChatRequest, user_id: str):
|
|
|
elif chunk_type == 'response.output_text.delta':
|
|
elif chunk_type == 'response.output_text.delta':
|
|
|
delta_text = getattr(chunk, 'delta', '')
|
|
delta_text = getattr(chunk, 'delta', '')
|
|
|
if delta_text:
|
|
if delta_text:
|
|
|
- if not answering_started:
|
|
|
|
|
- # print(f"\n\n💬 AI回答 [{datetime.now().strftime('%H:%M:%S')}]:")
|
|
|
|
|
- answering_started = True
|
|
|
|
|
accumulated_content += delta_text
|
|
accumulated_content += delta_text
|
|
|
yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
|
|
yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
|
|
|
- await asyncio.sleep(0.01)
|
|
|
|
|
|
|
|
|
|
# ⑤ Handle the response-completed event
|
|
# ⑤ Handle the response-completed event
|
|
|
elif chunk_type == 'response.completed':
|
|
elif chunk_type == 'response.completed':
|
|
@@ -347,6 +362,13 @@ async def clear_user_history(
|
|
|
return {"message": "用户没有聊天历史", "user": current_user.userId, "deleted_messages": 0, "timestamp": datetime.now()}
|
|
return {"message": "用户没有聊天历史", "user": current_user.userId, "deleted_messages": 0, "timestamp": datetime.now()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/sessions")
async def get_user_sessions(
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Return the sessions recorded for the authenticated user.

    The caller is resolved through the ``get_current_active_user``
    dependency; that user's ``userId`` is passed to ``get_sessions`` and
    the result is returned unchanged.
    """
    # NOTE(review): get_sessions is called without ``await``; confirm it is a
    # synchronous helper in ..db.mongo and not a coroutine function.
    owner_id = current_user.userId
    sessions = get_sessions(owner_id)
    return sessions
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
@router.get("/health")
async def health_check():
    """Report service health.

    Returns a dict with a fixed ``status`` of ``"healthy"``, the current
    local time, a hard-coded version string and the configured model name.
    """
    report = {
        "status": "healthy",
        "timestamp": datetime.now(),
        "version": "1.0.0",
        "model": config.MODEL_NAME,
    }
    return report
|