chat.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. from fastapi import APIRouter, HTTPException, Depends
  2. from fastapi.responses import StreamingResponse
  3. from typing import List, Annotated
  4. from datetime import datetime
  5. import json
  6. import asyncio
  7. from ..core.ark_client import config, client
  8. from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
  9. from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools
  10. from ..routers.users import get_current_active_user, User
  11. from ..db.mongo import save_chat_log
  12. from ..dependencies.auth import resolve_username
  13. # 内存存储用户聊天历史 {username: [ChatMessage, ...]}
  14. chatHistory = {}
  15. router = APIRouter()
  16. async def generate_stream_response(request: ChatRequest, username: str):
  17. latest_user_msg = None
  18. try:
  19. latest_user_msg = get_latest_user_message(request.messages)
  20. if not latest_user_msg:
  21. raise ValueError("请求中没有找到user角色的消息")
  22. chatHistory[username].append(ChatMessage(
  23. role=latest_user_msg.role,
  24. content=latest_user_msg.content,
  25. timestamp=datetime.now()
  26. ))
  27. api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
  28. tools = get_doubao_tools()
  29. previous_response_id = get_previous_response_id(username, chatHistory)
  30. stream = client.responses.create(
  31. model=config.MODEL_NAME,
  32. input=api_messages,
  33. tools=tools,
  34. stream=True,
  35. store=True,
  36. previous_response_id=previous_response_id,
  37. # thinking={"type": "auto"},
  38. )
  39. accumulated_content = ""
  40. response_id = None
  41. for chunk in stream:
  42. chunk_dict = chunk.__dict__ if hasattr(chunk, '__dict__') else {}
  43. event_type = chunk_dict.get('type', '')
  44. if event_type in ('response.output_text.delta', 'response.doubao_app_call_output_text.delta'):
  45. delta_text = chunk_dict.get('delta', '')
  46. if delta_text:
  47. accumulated_content += delta_text
  48. response_data = StreamResponse(
  49. content=delta_text,
  50. finished=False,
  51. model=config.MODEL_NAME,
  52. timestamp=datetime.now()
  53. )
  54. yield f"data: {response_data.model_dump_json()}\n\n"
  55. await asyncio.sleep(0.01)
  56. elif event_type == 'response.completed':
  57. if 'response' in chunk_dict and hasattr(chunk_dict['response'], 'id'):
  58. response_id = chunk_dict['response'].id
  59. save_chat_log(
  60. username=username,
  61. question=latest_user_msg.content,
  62. stream_mode=True,
  63. raw_response=repr(chunk_dict.get('response')),
  64. status="success",
  65. )
  66. if accumulated_content:
  67. chatHistory[username].append(ChatMessage(
  68. role="assistant",
  69. content=accumulated_content,
  70. timestamp=datetime.now(),
  71. response_id=response_id
  72. ))
  73. final_response = StreamResponse(
  74. content='',
  75. finished=True,
  76. model=config.MODEL_NAME,
  77. timestamp=datetime.now()
  78. )
  79. yield f"data: {final_response.model_dump_json()}\n\n"
  80. print("流式内容已全部输出")
  81. except Exception as e:
  82. error_response = {
  83. "error": str(e),
  84. "finished": True,
  85. "timestamp": datetime.now().isoformat()
  86. }
  87. save_chat_log(
  88. username=username,
  89. question=latest_user_msg.content if latest_user_msg else "",
  90. stream_mode=True,
  91. status="error",
  92. error=str(e),
  93. )
  94. yield f"data: {json.dumps(error_response)}\n\n"
  95. @router.post("/chat", response_model=ChatResponse)
  96. async def chat(
  97. request: ChatRequest,
  98. username: Annotated[str, Depends(resolve_username)],
  99. ):
  100. try:
  101. if username not in chatHistory:
  102. chatHistory[username] = []
  103. if request.stream:
  104. # ===== 流式输出处理 =====
  105. # 返回流式响应
  106. # StreamingResponse 用于处理SSE协议
  107. return StreamingResponse(
  108. generate_stream_response(request, username),
  109. media_type="text/plain",
  110. headers={
  111. "Cache-Control": "no-cache",
  112. "Connection": "keep-alive",
  113. "Content-Type": "text/event-stream",
  114. }
  115. )
  116. # 以下是非流式输出处理
  117. latest_user_msg = get_latest_user_message(request.messages)
  118. if not latest_user_msg:
  119. raise ValueError("请求中没有找到user角色的消息")
  120. chatHistory[username].append(ChatMessage(
  121. role=latest_user_msg.role,
  122. content=latest_user_msg.content,
  123. timestamp=datetime.now()
  124. ))
  125. api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
  126. tools = get_doubao_tools()
  127. previous_response_id = get_previous_response_id(username, chatHistory)
  128. response = client.responses.create(
  129. model=config.MODEL_NAME,
  130. input=api_messages,
  131. tools=tools,
  132. stream=False,
  133. store=True,
  134. previous_response_id=previous_response_id,
  135. # The parameter `thinking` specified in the request are not valid: `thinking` can not be set when enable doubao_app built-in tool.
  136. # thinking={"type": "auto"},
  137. )
  138. save_chat_log(
  139. username=username,
  140. question=latest_user_msg.content,
  141. stream_mode=False,
  142. raw_response=repr(response),
  143. status="success",
  144. )
  145. if not (response.output and len(response.output) > 0):
  146. raise HTTPException(status_code=500, detail="AI模型返回了空响应")
  147. message_content = ""
  148. for item in response.output:
  149. if hasattr(item, 'type') and item.type == 'doubao_app_call':
  150. if hasattr(item, 'blocks') and item.blocks:
  151. for block in item.blocks:
  152. if hasattr(block, 'type') and block.type == 'output_text' and hasattr(block, 'text'):
  153. message_content += block.text
  154. elif hasattr(item, 'type') and item.type == 'message':
  155. if hasattr(item, 'content'):
  156. if isinstance(item.content, list):
  157. for content_item in item.content:
  158. if hasattr(content_item, 'text'):
  159. message_content += content_item.text
  160. else:
  161. message_content += str(item.content)
  162. if not message_content:
  163. raise HTTPException(status_code=500, detail="无法从AI响应中提取文本内容")
  164. assistant_message = ChatMessage(
  165. role="assistant",
  166. content=message_content,
  167. timestamp=datetime.now(),
  168. response_id=response.id
  169. )
  170. chatHistory[username].append(assistant_message)
  171. return ChatResponse(
  172. message=assistant_message,
  173. model=response.model,
  174. usage=response.usage.model_dump() if response.usage else None,
  175. response_id=response.id
  176. )
  177. except HTTPException:
  178. raise
  179. except Exception as e:
  180. error_message = f"处理聊天请求时发生错误: {str(e)}"
  181. save_chat_log(
  182. username=username,
  183. question=request.messages[-1].content if request.messages else "",
  184. stream_mode=request.stream,
  185. status="error",
  186. error=error_message,
  187. )
  188. raise HTTPException(status_code=500, detail=error_message)
  189. @router.get("/models")
  190. async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
  191. try:
  192. models = client.models.list()
  193. return {
  194. "models": [model.id for model in models.data],
  195. "default_model": config.MODEL_NAME,
  196. "user": current_user.username
  197. }
  198. except Exception:
  199. return {
  200. "models": [config.MODEL_NAME],
  201. "default_model": config.MODEL_NAME,
  202. "note": "使用默认模型配置",
  203. "user": current_user.username
  204. }
  205. @router.get("/history")
  206. async def get_user_history(
  207. current_user: Annotated[User, Depends(get_current_active_user)]
  208. ) -> List[ChatMessage]:
  209. username = current_user.username
  210. if username not in chatHistory:
  211. return []
  212. return [
  213. ChatMessage(role=msg.role, content=msg.content, timestamp=msg.timestamp)
  214. for msg in chatHistory[username]
  215. ]
  216. @router.delete("/history")
  217. async def clear_user_history(current_user: Annotated[User, Depends(get_current_active_user)]):
  218. username = current_user.username
  219. if username in chatHistory:
  220. message_count = len(chatHistory[username])
  221. del chatHistory[username]
  222. return {"message": "聊天历史已清空", "user": username, "deleted_messages": message_count, "timestamp": datetime.now()}
  223. return {"message": "用户没有聊天历史", "user": username, "deleted_messages": 0, "timestamp": datetime.now()}
  224. @router.get("/health")
  225. async def health_check():
  226. return {"status": "healthy", "timestamp": datetime.now(), "version": "1.0.0", "model": config.MODEL_NAME}
# Router-wide OpenAPI metadata: the "聊天服务" tag and shared descriptions for
# 401 / 429 / 500 responses.
# NOTE(review): these assignments run after every @router.* decorator above.
# FastAPI's APIRouter.add_api_route copies router.tags / router.responses into
# each route when it is registered, so assigning them here does not change the
# routes already defined in this module. Only the APIRouter(tags=...,
# responses=...) constructor arguments (or include_router parameters) apply
# them. Confirm against the FastAPI version in use.
router.tags = ["聊天服务"]
router.responses = {
    401: {"description": "未授权 - 需要有效的JWT令牌"},
    429: {"description": "请求过多 - 配额已用完"},
    500: {"description": "服务器内部错误"}
}