"""Chat service routes: streaming / non-streaming chat, model listing and per-user history."""

import asyncio
import json
import logging
from datetime import datetime
from typing import List, Annotated

from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse

from ..core.ark_client import config, client
from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools
from ..routers.users import get_current_active_user, User
from ..db.mongo import save_chat_log
from ..dependencies.auth import resolve_username

logger = logging.getLogger(__name__)

# In-memory chat history per user: {username: [ChatMessage, ...]}.
# NOTE(review): process-local and unbounded -- lost on restart, not shared between
# worker processes, and only shrinks when the user calls DELETE /history.
chatHistory: dict[str, List[ChatMessage]] = {}

# Sentinel handed to next() so an exhausted SDK stream returns it instead of
# raising StopIteration across the thread boundary.
_STREAM_END = object()

# Tags/responses must be passed to the constructor: APIRouter copies them onto each
# route when the route is registered, so assigning router.tags / router.responses
# after the @router decorators have run would not affect any of these routes.
router = APIRouter(
    tags=["聊天服务"],
    responses={
        401: {"description": "未授权 - 需要有效的JWT令牌"},
        429: {"description": "请求过多 - 配额已用完"},
        500: {"description": "服务器内部错误"},
    },
)


def _append_history(username: str, message: ChatMessage) -> None:
    """Append *message* to *username*'s in-memory history.

    Uses setdefault because DELETE /history may remove the user's entry while a
    chat request for that user is still in flight.
    """
    chatHistory.setdefault(username, []).append(message)


def _require_latest_user_message(messages):
    """Return the latest user-role message in *messages*.

    Raises:
        ValueError: if no user-role message is present.
    """
    latest_user_msg = get_latest_user_message(messages)
    if not latest_user_msg:
        raise ValueError("请求中没有找到user角色的消息")
    return latest_user_msg


async def _send_user_turn(username: str, user_msg, *, stream: bool):
    """Record *user_msg* in the history and submit it to the model.

    Only the latest user message is sent as ``input``; together with
    ``store=True`` and ``previous_response_id`` this presumably lets the API
    supply the earlier conversation context.

    ``client`` is used synchronously (its results are consumed without
    ``await``), so the blocking call runs in a worker thread rather than
    stalling the event loop for the duration of the model call.

    ``thinking`` is intentionally not passed: the API rejects it when the
    doubao_app built-in tool is enabled.

    Returns:
        The SDK response object, or an iterable stream of events when *stream*
        is True.
    """
    # The user turn is recorded before looking up previous_response_id, as before.
    _append_history(username, ChatMessage(
        role=user_msg.role,
        content=user_msg.content,
        timestamp=datetime.now(),
    ))
    api_messages = [{"role": user_msg.role, "content": user_msg.content}]
    tools = get_doubao_tools()
    previous_response_id = get_previous_response_id(username, chatHistory)
    return await asyncio.to_thread(
        client.responses.create,
        model=config.MODEL_NAME,
        input=api_messages,
        tools=tools,
        stream=stream,
        store=True,
        previous_response_id=previous_response_id,
    )


def _close_quietly(stream) -> None:
    """Call ``stream.close()`` if the object provides it; never raises."""
    close = getattr(stream, "close", None)
    if not callable(close):
        return
    try:
        close()
    except Exception:
        logger.debug("Failed to close model stream", exc_info=True)


async def generate_stream_response(request: ChatRequest, username: str):
    """Yield the model's answer as Server-Sent Events (``data: <json>\\n\\n``).

    Each non-empty text delta is emitted as a ``StreamResponse`` with
    ``finished=False``, and a final ``StreamResponse`` with empty content and
    ``finished=True`` ends the stream. The concatenated answer is stored in the
    user's history together with the id from the ``response.completed`` event.

    On any error, a single JSON event with ``error``, ``finished`` and
    ``timestamp`` keys is emitted instead, and the failure is recorded via
    ``save_chat_log``.
    """
    latest_user_msg = None
    stream = None
    try:
        latest_user_msg = _require_latest_user_message(request.messages)
        stream = await _send_user_turn(username, latest_user_msg, stream=True)

        content_parts: List[str] = []
        response_id = None
        chunks = iter(stream)
        while True:
            # next() blocks on network reads, so each chunk is pulled in a worker thread.
            chunk = await asyncio.to_thread(next, chunks, _STREAM_END)
            if chunk is _STREAM_END:
                break

            chunk_dict = getattr(chunk, "__dict__", {})
            event_type = chunk_dict.get("type", "")

            if event_type in ("response.output_text.delta", "response.doubao_app_call_output_text.delta"):
                delta_text = chunk_dict.get("delta", "")
                if not delta_text:
                    continue
                content_parts.append(delta_text)
                delta_event = StreamResponse(
                    content=delta_text,
                    finished=False,
                    model=config.MODEL_NAME,
                    timestamp=datetime.now(),
                )
                yield f"data: {delta_event.model_dump_json()}\n\n"
                # Short pause between deltas, kept from the original implementation.
                await asyncio.sleep(0.01)
            elif event_type == "response.completed":
                completed = chunk_dict.get("response")
                response_id = getattr(completed, "id", response_id)
                save_chat_log(
                    username=username,
                    question=latest_user_msg.content,
                    stream_mode=True,
                    raw_response=repr(completed),
                    status="success",
                )

        if content_parts:
            _append_history(username, ChatMessage(
                role="assistant",
                content="".join(content_parts),
                timestamp=datetime.now(),
                response_id=response_id,
            ))

        final_event = StreamResponse(
            content="",
            finished=True,
            model=config.MODEL_NAME,
            timestamp=datetime.now(),
        )
        yield f"data: {final_event.model_dump_json()}\n\n"
        logger.info("Streamed response finished for user %s", username)
    except Exception as e:
        logger.exception("Streaming chat failed for user %s", username)
        # A failure while logging must not prevent the client from seeing the error event.
        try:
            save_chat_log(
                username=username,
                question=latest_user_msg.content if latest_user_msg else "",
                stream_mode=True,
                status="error",
                error=str(e),
            )
        except Exception:
            logger.exception("Failed to record streaming chat error for user %s", username)
        error_response = {
            "error": str(e),
            "finished": True,
            "timestamp": datetime.now().isoformat(),
        }
        yield f"data: {json.dumps(error_response, ensure_ascii=False)}\n\n"
    finally:
        # Also runs when the client disconnects (the generator is closed/cancelled).
        _close_quietly(stream)


def _extract_output_text(response) -> str:
    """Concatenate, in output order, the text found in a non-streaming response.

    Collects ``output_text`` blocks of ``doubao_app_call`` items, the ``text`` of
    each part of list-valued ``message`` content, and ``str(content)`` for
    non-list ``message`` content. Missing or empty texts are skipped.
    """
    parts: List[str] = []
    for item in response.output:
        item_type = getattr(item, "type", None)
        if item_type == "doubao_app_call":
            for block in getattr(item, "blocks", None) or []:
                text = getattr(block, "text", None)
                if getattr(block, "type", None) == "output_text" and text:
                    parts.append(text)
        elif item_type == "message" and hasattr(item, "content"):
            if isinstance(item.content, list):
                for content_item in item.content:
                    text = getattr(content_item, "text", None)
                    if text:
                        parts.append(text)
            else:
                parts.append(str(item.content))
    return "".join(parts)


@router.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    username: Annotated[str, Depends(resolve_username)],
):
    """Send the latest user message to the model and return its answer.

    When ``request.stream`` is true the answer is streamed as Server-Sent Events
    (see ``generate_stream_response``); otherwise a ``ChatResponse`` is returned.

    Raises:
        HTTPException(500): on an empty/unparseable model response or any other
            failure (which is also recorded via ``save_chat_log``).
    """
    try:
        chatHistory.setdefault(username, [])

        if request.stream:
            # SSE stream; media_type alone sets the Content-Type header.
            return StreamingResponse(
                generate_stream_response(request, username),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )

        latest_user_msg = _require_latest_user_message(request.messages)
        response = await _send_user_turn(username, latest_user_msg, stream=False)

        save_chat_log(
            username=username,
            question=latest_user_msg.content,
            stream_mode=False,
            raw_response=repr(response),
            status="success",
        )

        if not response.output:
            raise HTTPException(status_code=500, detail="AI模型返回了空响应")

        message_content = _extract_output_text(response)
        if not message_content:
            raise HTTPException(status_code=500, detail="无法从AI响应中提取文本内容")

        assistant_message = ChatMessage(
            role="assistant",
            content=message_content,
            timestamp=datetime.now(),
            response_id=response.id,
        )
        _append_history(username, assistant_message)

        return ChatResponse(
            message=assistant_message,
            model=response.model,
            usage=response.usage.model_dump() if response.usage else None,
            response_id=response.id,
        )
    except HTTPException:
        raise
    except Exception as e:
        error_message = f"处理聊天请求时发生错误: {str(e)}"
        logger.exception("Chat request failed for user %s", username)
        save_chat_log(
            username=username,
            question=request.messages[-1].content if request.messages else "",
            stream_mode=request.stream,
            status="error",
            error=error_message,
        )
        raise HTTPException(status_code=500, detail=error_message) from e


@router.get("/models")
async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
    """List the models visible to the API client.

    Falls back to the configured default model (with a ``note``) if listing fails.
    """
    try:
        # Blocking SDK call; keep it off the event loop.
        models = await asyncio.to_thread(client.models.list)
        return {
            "models": [model.id for model in models.data],
            "default_model": config.MODEL_NAME,
            "user": current_user.username,
        }
    except Exception:
        logger.warning("Listing models failed; falling back to the configured default", exc_info=True)
        return {
            "models": [config.MODEL_NAME],
            "default_model": config.MODEL_NAME,
            "note": "使用默认模型配置",
            "user": current_user.username,
        }


@router.get("/history")
async def get_user_history(
    current_user: Annotated[User, Depends(get_current_active_user)]
) -> List[ChatMessage]:
    """Return the current user's in-memory history (role, content, timestamp only)."""
    messages = chatHistory.get(current_user.username, [])
    # response_id is deliberately not copied into the returned messages.
    return [
        ChatMessage(role=msg.role, content=msg.content, timestamp=msg.timestamp)
        for msg in messages
    ]


@router.delete("/history")
async def clear_user_history(current_user: Annotated[User, Depends(get_current_active_user)]):
    """Drop the current user's in-memory history and report how many messages were removed."""
    username = current_user.username
    removed = chatHistory.pop(username, None)
    if removed is not None:
        return {"message": "聊天历史已清空", "user": username, "deleted_messages": len(removed),
                "timestamp": datetime.now()}
    return {"message": "用户没有聊天历史", "user": username, "deleted_messages": 0,
            "timestamp": datetime.now()}


@router.get("/health")
async def health_check():
    """Liveness probe: static status plus the configured model name."""
    return {"status": "healthy", "timestamp": datetime.now(), "version": "1.0.0", "model": config.MODEL_NAME}