| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- from fastapi import APIRouter, HTTPException, Depends
- from fastapi.responses import StreamingResponse
- from typing import List, Annotated
- from datetime import datetime
- import json
- import asyncio
- from ..core.ark_client import config, client
- from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
- from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools
- from ..routers.users import get_current_active_user, User
- from ..db.mongo import save_chat_log
- from ..dependencies.auth import resolve_username
# In-memory chat history per user: {username: [ChatMessage, ...]}.
# NOTE(review): this is a plain module-level dict, so it lives in process
# memory only: it is lost on restart, is not shared between worker processes,
# and only shrinks when a user clears their history.
chatHistory = {}

# Router-wide OpenAPI metadata is passed to the constructor because FastAPI
# merges router-level tags/responses into each route at registration time
# (APIRouter.add_api_route); assigning router.tags / router.responses after the
# routes exist does not reach them.
# NOTE(review): if the app also passes tags=["聊天服务"] to include_router, the
# tag will appear twice per operation; confirm how this router is included.
router = APIRouter(
    tags=["聊天服务"],
    responses={
        401: {"description": "未授权 - 需要有效的JWT令牌"},
        429: {"description": "请求过多 - 配额已用完"},
        500: {"description": "服务器内部错误"},
    },
)
async def generate_stream_response(request: ChatRequest, username: str):
    """Stream the model's reply to the latest user message as Server-Sent Events.

    The latest ``user`` message is appended to ``chatHistory[username]`` and
    sent to the model together with ``previous_response_id`` (``store=True``),
    so only the newest turn is transmitted. When the stream completes, the
    accumulated assistant text is appended to the history with the response id,
    and a chat-log entry is written (``status="success"`` on the
    ``response.completed`` event, ``status="error"`` on any exception).

    Every yielded item is one SSE frame, ``"data: <json>\\n\\n"``:

    * a ``StreamResponse`` with ``finished=False`` for each non-empty text delta;
    * a final ``StreamResponse`` with ``content=''`` and ``finished=True``;
    * or, if anything fails, a ``{"error", "finished", "timestamp"}`` object
      (and no further frames).

    Args:
        request: Incoming chat request; only its latest ``user`` message is used.
        username: Key into ``chatHistory`` and the chat log.

    Yields:
        SSE-formatted strings.
    """
    latest_user_msg = None
    stream = None
    try:
        latest_user_msg = get_latest_user_message(request.messages)
        if not latest_user_msg:
            raise ValueError("请求中没有找到user角色的消息")
        # setdefault rather than indexing: DELETE /history may have removed the
        # entry after the endpoint created it but before this generator runs.
        chatHistory.setdefault(username, []).append(ChatMessage(
            role=latest_user_msg.role,
            content=latest_user_msg.content,
            timestamp=datetime.now()
        ))
        api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
        tools = get_doubao_tools()
        previous_response_id = get_previous_response_id(username, chatHistory)
        # The SDK client is synchronous: run the request, and every chunk read
        # below, in a worker thread so one long stream does not stall the event
        # loop for all other requests.
        stream = await asyncio.to_thread(
            client.responses.create,
            model=config.MODEL_NAME,
            input=api_messages,
            tools=tools,
            stream=True,
            store=True,
            previous_response_id=previous_response_id,
            # `thinking` cannot be combined with the doubao_app built-in tool.
            # thinking={"type": "auto"},
        )
        accumulated_content = ""
        response_id = None
        chunk_iter = iter(stream)
        end_of_stream = object()  # sentinel; avoids StopIteration crossing a Future
        while True:
            chunk = await asyncio.to_thread(next, chunk_iter, end_of_stream)
            if chunk is end_of_stream:
                break
            chunk_dict = chunk.__dict__ if hasattr(chunk, '__dict__') else {}
            event_type = chunk_dict.get('type', '')
            if event_type in ('response.output_text.delta', 'response.doubao_app_call_output_text.delta'):
                delta_text = chunk_dict.get('delta', '')
                if delta_text:
                    accumulated_content += delta_text
                    response_data = StreamResponse(
                        content=delta_text,
                        finished=False,
                        model=config.MODEL_NAME,
                        timestamp=datetime.now()
                    )
                    yield f"data: {response_data.model_dump_json()}\n\n"
                    # NOTE(review): fixed 10 ms pause per delta; the awaits above
                    # already yield to the event loop, so confirm this pacing is
                    # still wanted.
                    await asyncio.sleep(0.01)
            elif event_type == 'response.completed':
                completed = chunk_dict.get('response')
                if completed is not None and hasattr(completed, 'id'):
                    response_id = completed.id
                save_chat_log(
                    username=username,
                    question=latest_user_msg.content,
                    stream_mode=True,
                    raw_response=repr(completed),
                    status="success",
                )
        if accumulated_content:
            chatHistory.setdefault(username, []).append(ChatMessage(
                role="assistant",
                content=accumulated_content,
                timestamp=datetime.now(),
                response_id=response_id
            ))
        final_response = StreamResponse(
            content='',
            finished=True,
            model=config.MODEL_NAME,
            timestamp=datetime.now()
        )
        yield f"data: {final_response.model_dump_json()}\n\n"
        print("流式内容已全部输出")
    except Exception as e:
        error_response = {
            "error": str(e),
            "finished": True,
            "timestamp": datetime.now().isoformat()
        }
        save_chat_log(
            username=username,
            question=latest_user_msg.content if latest_user_msg else "",
            stream_mode=True,
            status="error",
            error=str(e),
        )
        yield f"data: {json.dumps(error_response)}\n\n"
    finally:
        # Release the upstream HTTP stream on every exit path: normal end, error,
        # or the client disconnecting (which closes this generator).
        close = getattr(stream, "close", None)
        if callable(close):
            try:
                close()
            except Exception:
                # Best-effort cleanup; must not mask the outcome already sent.
                pass
def _extract_output_text(output) -> str:
    """Concatenate the text parts of a non-streaming response ``output`` list.

    Two item types contribute text:

    * ``doubao_app_call``: the ``text`` of each block whose ``type`` is
      ``output_text``;
    * ``message``: the ``text`` of each part when ``content`` is a list,
      otherwise ``str(content)``.

    Other item types, and ``None`` text/content values, are skipped.
    """
    parts = []
    for item in output:
        item_type = getattr(item, "type", None)
        if item_type == "doubao_app_call":
            for block in getattr(item, "blocks", None) or ():
                if getattr(block, "type", None) == "output_text":
                    text = getattr(block, "text", None)
                    if text:
                        parts.append(text)
        elif item_type == "message":
            content = getattr(item, "content", None)
            if isinstance(content, list):
                for content_item in content:
                    text = getattr(content_item, "text", None)
                    if text:
                        parts.append(text)
            elif content is not None:
                # Avoid recording the literal string "None" for missing content.
                parts.append(str(content))
    return "".join(parts)


@router.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    username: Annotated[str, Depends(resolve_username)],
):
    """Answer the latest user message, either streamed (SSE) or in one response.

    \f
    With ``request.stream`` set, the reply is produced by
    ``generate_stream_response`` as a ``text/event-stream``. Otherwise, the
    latest ``user`` message is appended to the user's in-memory history and sent
    to the model, chained to the previous turn via ``previous_response_id``.
    The reply text is then extracted, stored in the history, and returned as a
    ``ChatResponse``.

    Raises:
        HTTPException: 500 when the model output is empty or contains no text.
            Also 500 for any other failure, which is additionally written to the
            chat log.
    """
    latest_user_msg = None
    try:
        chatHistory.setdefault(username, [])
        if request.stream:
            # ===== Streaming =====
            # SSE is served with the text/event-stream media type directly,
            # instead of a text/plain media type overridden through headers.
            return StreamingResponse(
                generate_stream_response(request, username),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )

        # ===== Non-streaming =====
        latest_user_msg = get_latest_user_message(request.messages)
        if not latest_user_msg:
            raise ValueError("请求中没有找到user角色的消息")
        chatHistory[username].append(ChatMessage(
            role=latest_user_msg.role,
            content=latest_user_msg.content,
            timestamp=datetime.now(),
        ))
        api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
        tools = get_doubao_tools()
        previous_response_id = get_previous_response_id(username, chatHistory)
        # The SDK call is synchronous and waits for the whole model answer; run
        # it in a worker thread so the event loop keeps serving other requests.
        response = await asyncio.to_thread(
            client.responses.create,
            model=config.MODEL_NAME,
            input=api_messages,
            tools=tools,
            stream=False,
            store=True,
            previous_response_id=previous_response_id,
            # The API rejects `thinking` while the doubao_app built-in tool is
            # enabled ("`thinking` can not be set when enable doubao_app
            # built-in tool"), so it stays disabled.
            # thinking={"type": "auto"},
        )
        save_chat_log(
            username=username,
            question=latest_user_msg.content,
            stream_mode=False,
            raw_response=repr(response),
            status="success",
        )
        if not response.output:
            raise HTTPException(status_code=500, detail="AI模型返回了空响应")
        message_content = _extract_output_text(response.output)
        if not message_content:
            raise HTTPException(status_code=500, detail="无法从AI响应中提取文本内容")
        assistant_message = ChatMessage(
            role="assistant",
            content=message_content,
            timestamp=datetime.now(),
            response_id=response.id
        )
        # setdefault: the history may have been cleared while awaiting the model.
        chatHistory.setdefault(username, []).append(assistant_message)
        return ChatResponse(
            message=assistant_message,
            model=response.model,
            usage=response.usage.model_dump() if response.usage else None,
            response_id=response.id
        )
    except HTTPException:
        raise
    except Exception as e:
        error_message = f"处理聊天请求时发生错误: {str(e)}"
        # Prefer the resolved user message; fall back to the last message only
        # when the failure happened before it was resolved.
        if latest_user_msg:
            question = latest_user_msg.content
        else:
            question = request.messages[-1].content if request.messages else ""
        save_chat_log(
            username=username,
            question=question,
            stream_mode=request.stream,
            status="error",
            error=error_message,
        )
        raise HTTPException(status_code=500, detail=error_message) from e
@router.get("/models")
async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
    """List the model ids available from the model service.

    \f
    If listing fails for any reason, returns only ``config.MODEL_NAME`` and adds
    a ``note`` field saying the default configuration is being used.
    """
    try:
        # Synchronous SDK network call; run it off the event loop.
        models = await asyncio.to_thread(client.models.list)
        return {
            "models": [model.id for model in models.data],
            "default_model": config.MODEL_NAME,
            "user": current_user.username
        }
    except Exception:
        # Deliberate best-effort fallback so callers always get a usable list.
        # NOTE(review): the failure reason is discarded; consider logging it.
        return {
            "models": [config.MODEL_NAME],
            "default_model": config.MODEL_NAME,
            "note": "使用默认模型配置",
            "user": current_user.username
        }
@router.get("/history")
async def get_user_history(
    current_user: Annotated[User, Depends(get_current_active_user)]
) -> List[ChatMessage]:
    """Return the current user's in-memory chat history.

    \f
    Each stored message is copied with only ``role``, ``content`` and
    ``timestamp``; any other field (e.g. ``response_id``) is left at its default
    in the returned copies. Users with no history get an empty list.
    """
    stored_messages = chatHistory.get(current_user.username, [])
    history: List[ChatMessage] = []
    for stored in stored_messages:
        history.append(
            ChatMessage(role=stored.role, content=stored.content, timestamp=stored.timestamp)
        )
    return history
@router.delete("/history")
async def clear_user_history(current_user: Annotated[User, Depends(get_current_active_user)]):
    """Delete the current user's in-memory chat history.

    \f
    The response reports how many messages were removed (0 when the user had
    no stored history) along with the current timestamp.
    """
    username = current_user.username
    removed = chatHistory.pop(username, None)
    if removed is None:
        return {"message": "用户没有聊天历史", "user": username, "deleted_messages": 0, "timestamp": datetime.now()}
    return {"message": "聊天历史已清空", "user": username, "deleted_messages": len(removed), "timestamp": datetime.now()}
@router.get("/health")
async def health_check():
    """Report service status, version and the configured model name.

    \f
    Returns static values plus the current time; the model backend is not
    contacted.
    """
    return dict(
        status="healthy",
        timestamp=datetime.now(),
        version="1.0.0",
        model=config.MODEL_NAME,
    )
# Router-wide OpenAPI metadata: tag plus shared error-response descriptions.
# NOTE(review): FastAPI copies router-level tags/responses into each route when
# the route is registered (APIRouter.add_api_route), and include_router does not
# re-read them from the included router. Because these assignments run after
# every route above is registered, they likely do not affect those routes.
# TODO confirm; the reliable place for them is the APIRouter(...) constructor.
router.tags = ["聊天服务"]
router.responses = {
    401: {"description": "未授权 - 需要有效的JWT令牌"},
    429: {"description": "请求过多 - 配额已用完"},
    500: {"description": "服务器内部错误"}
}
|