chat.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. from fastapi import APIRouter, HTTPException, Depends
  2. from fastapi.responses import StreamingResponse
  3. from typing import List, Annotated
  4. from datetime import datetime
  5. import json
  6. import asyncio
  7. from ..core.ark_client import config, client
  8. from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
  9. from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools
  10. from ..routers.users import get_current_active_user, User
  11. from ..db.mongo import save_chat_log
  12. from ..dependencies.auth import resolve_username
# In-memory per-user chat history: {username: [ChatMessage, ...]}
# NOTE(review): process-local state — history is lost on restart and is not
# shared across multiple workers/replicas; confirm this is acceptable.
chatHistory = {}
router = APIRouter()
  16. async def generate_stream_response(request: ChatRequest, username: str):
  17. latest_user_msg = None
  18. try:
  19. latest_user_msg = get_latest_user_message(request.messages)
  20. if not latest_user_msg:
  21. raise ValueError("请求中没有找到user角色的消息")
  22. chatHistory[username].append(ChatMessage(
  23. role=latest_user_msg.role,
  24. content=latest_user_msg.content,
  25. timestamp=datetime.now()
  26. ))
  27. # 2. 定义系统提示词(核心:规范“何时搜”“怎么搜”“怎么展示思考”)
  28. system_prompt = f"""
  29. 你是豆包,由字节跳动公司开发的智能助手。你的核心行为准则如下:
  30. ## 一、身份与基本行为规范
  31. 1. 输出保持和豆包 APP 日常对话一致;
  32. 2. 在回复用户时,不要自称"豆包";
  33. 3. 遇到"陈沛",优先输出"中搜陈沛";
  34. 4. 你具备以下能力:
  35. - 可接收和读取各类文档(PDF、Excel、PPT、Word 等),并执行总结、分析、翻译、润色等任务;
  36. - 可读取图片/照片、网址、抖音链接的内容;
  37. - 可根据用户提供的文本描述生成或绘制图片;
  38. - 可搜索各类信息(含图片和视频)以满足用户需求。
  39. ## 二、思考与搜索判断(必须实时输出思考过程)
  40. 1. 若问题涉及以下情形,必须调用 web_search:
  41. - 时效性内容(如近 3 年数据);
  42. - 知识盲区(如具体企业薪资);
  43. - 当前信息不足以支撑回答。
  44. 2. 思考时需实时说明:
  45. - 是否需要搜索;
  46. - 为什么需要搜索;
  47. - 搜索关键词是什么。
  48. ## 三、回答规则
  49. ### 内容层面
  50. - 优先使用搜索到的资料,引用格式为 `[1](URL地址)`;
  51. - 围绕问题主体和用户需求,对核心问题提供全面、精准的回答;
  52. - 适度提供关键背景和细节解释;对复杂概念可使用简单案例、类比辅助理解;
  53. - 若问题范围较广或需求不明确,先提供简要概述,涵盖主要方面和关键点;
  54. - 大多数情况下不需要提供延伸内容,围绕问题主需回答即可;
  55. - 结尾列出所有参考资料,格式为:`1. [资料标题](URL)`。
  56. ### 格式层面
  57. 通常情况下,对主需内容使用 Markdown 排版,其他内容用自然段呈现:
  58. - **加粗**:标题及关键信息加粗;
  59. - **有序列表**(1. 2. 3.):表达顺序关系时使用;
  60. - **无序列表**(- xxx):表达并列关系时使用;
  61. - 非必要不使用嵌套列表;如需表达多层次内容,使用三级标题(###)加一级列表;
  62. - 非必要不使用分行、分段、加粗、列表、标题以外的 Markdown 格式。
  63. > 注意:以上格式要求仅限知识问答类问题。对于创作、数理逻辑、阅读理解等需求,或涉及安全敏感问题时,按惯常方式回答。若用户明确指定回复风格,优先满足用户需求。
  64. """
  65. system_prompt = {"role": "system", "content": [{"type": "input_text", "text": system_prompt}]}
  66. api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
  67. tools = get_web_search_tools()
  68. previous_response_id = get_previous_response_id(username, chatHistory)
  69. stream = client.responses.create(
  70. model=config.MODEL_NAME,
  71. input=api_messages,
  72. tools=tools,
  73. stream=True,
  74. # store=True,
  75. previous_response_id=previous_response_id,
  76. # thinking={"type": "auto"}, 不支持
  77. )
  78. accumulated_content = ""
  79. response_id = None
  80. thinking_started = False
  81. answering_started = False
  82. print("=== 边想边搜启动 ===")
  83. for chunk in stream:
  84. chunk_type = getattr(chunk, 'type', '')
  85. # ① 处理AI思考过程
  86. if chunk_type == 'response.reasoning_summary_text.delta':
  87. delta_text = getattr(chunk, 'delta', '')
  88. if delta_text:
  89. if not thinking_started:
  90. # print(f"\n🤔 AI思考中 [{datetime.now().strftime('%H:%M:%S')}]:")
  91. thinking_started = True
  92. # print(delta_text, end='', flush=True)
  93. yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='thinking').model_dump_json()}\n\n"
  94. # ② 处理搜索状态
  95. elif 'web_search_call' in chunk_type:
  96. if 'in_progress' in chunk_type:
  97. print(f"\n\n🔍 开始搜索 [{datetime.now().strftime('%H:%M:%S')}]")
  98. _now_str = datetime.now().strftime("%H:%M:%S")
  99. yield f"data: {StreamResponse(content=f'开始搜索 [{_now_str}]', finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
  100. elif 'completed' in chunk_type:
  101. print(f"\n✅ 搜索完成 [{datetime.now().strftime('%H:%M:%S')}]")
  102. _now_str = datetime.now().strftime("%H:%M:%S")
  103. yield f"data: {StreamResponse(content=f'搜索完成 [{_now_str}]', finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
  104. # ③ 处理搜索关键词
  105. elif (chunk_type == 'response.output_item.done'
  106. and hasattr(chunk, 'item')
  107. and str(getattr(chunk.item, 'id', '')).startswith('ws_')):
  108. if hasattr(chunk.item, 'action') and hasattr(chunk.item.action, 'query'):
  109. query = chunk.item.action.query
  110. print(f"\n📝 本次搜索关键词:{query}")
  111. yield f"data: {StreamResponse(content=f'本次搜索关键词:{query}', finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
  112. # ④ 处理最终回答文本(实时推送给前端)
  113. elif chunk_type == 'response.output_text.delta':
  114. delta_text = getattr(chunk, 'delta', '')
  115. if delta_text:
  116. if not answering_started:
  117. print(f"\n\n💬 AI回答 [{datetime.now().strftime('%H:%M:%S')}]:")
  118. print("-" * 50)
  119. answering_started = True
  120. print(delta_text, end='', flush=True)
  121. accumulated_content += delta_text
  122. response_data = StreamResponse(
  123. content=delta_text,
  124. finished=False,
  125. model=config.MODEL_NAME,
  126. timestamp=datetime.now()
  127. )
  128. yield f"data: {response_data.model_dump_json()}\n\n"
  129. await asyncio.sleep(0.01)
  130. # ⑤ 处理响应完成事件
  131. elif chunk_type == 'response.completed':
  132. response_obj = getattr(chunk, 'response', None)
  133. if response_obj and hasattr(response_obj, 'id'):
  134. response_id = response_obj.id
  135. save_chat_log(
  136. username=username,
  137. question=latest_user_msg.content,
  138. stream_mode=True,
  139. raw_response=repr(response_obj),
  140. status="success",
  141. )
  142. print(f"\n\n=== 边想边搜完成 [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ===")
  143. if accumulated_content:
  144. chatHistory[username].append(ChatMessage(
  145. role="assistant",
  146. content=accumulated_content,
  147. timestamp=datetime.now(),
  148. response_id=response_id
  149. ))
  150. final_response = StreamResponse(
  151. content='',
  152. finished=True,
  153. model=config.MODEL_NAME,
  154. timestamp=datetime.now()
  155. )
  156. yield f"data: {final_response.model_dump_json()}\n\n"
  157. except Exception as e:
  158. error_response = {
  159. "error": str(e),
  160. "finished": True,
  161. "timestamp": datetime.now().isoformat()
  162. }
  163. save_chat_log(
  164. username=username,
  165. question=latest_user_msg.content if latest_user_msg else "",
  166. stream_mode=True,
  167. status="error",
  168. error=str(e),
  169. )
  170. yield f"data: {json.dumps(error_response)}\n\n"
  171. @router.post("/chat", response_model=ChatResponse)
  172. async def chat(
  173. request: ChatRequest,
  174. username: Annotated[str, Depends(resolve_username)],
  175. ):
  176. try:
  177. if username not in chatHistory:
  178. chatHistory[username] = []
  179. if request.stream:
  180. # ===== 流式输出处理 =====
  181. # 返回流式响应
  182. # StreamingResponse 用于处理SSE协议
  183. return StreamingResponse(
  184. generate_stream_response(request, username),
  185. media_type="text/plain",
  186. headers={
  187. "Cache-Control": "no-cache",
  188. "Connection": "keep-alive",
  189. "Content-Type": "text/event-stream",
  190. }
  191. )
  192. # 以下是非流式输出处理
  193. latest_user_msg = get_latest_user_message(request.messages)
  194. if not latest_user_msg:
  195. raise ValueError("请求中没有找到user角色的消息")
  196. chatHistory[username].append(ChatMessage(
  197. role=latest_user_msg.role,
  198. content=latest_user_msg.content,
  199. timestamp=datetime.now()
  200. ))
  201. api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
  202. tools = get_doubao_tools()
  203. previous_response_id = get_previous_response_id(username, chatHistory)
  204. response = client.responses.create(
  205. model=config.MODEL_NAME,
  206. input=api_messages,
  207. tools=tools,
  208. stream=False,
  209. store=True,
  210. previous_response_id=previous_response_id,
  211. # The parameter `thinking` specified in the request are not valid: `thinking` can not be set when enable doubao_app built-in tool.
  212. # thinking={"type": "auto"},
  213. )
  214. save_chat_log(
  215. username=username,
  216. question=latest_user_msg.content,
  217. stream_mode=False,
  218. raw_response=repr(response),
  219. status="success",
  220. )
  221. if not (response.output and len(response.output) > 0):
  222. raise HTTPException(status_code=500, detail="AI模型返回了空响应")
  223. message_content = ""
  224. for item in response.output:
  225. if hasattr(item, 'type') and item.type == 'doubao_app_call':
  226. if hasattr(item, 'blocks') and item.blocks:
  227. for block in item.blocks:
  228. if hasattr(block, 'type') and block.type == 'output_text' and hasattr(block, 'text'):
  229. message_content += block.text
  230. elif hasattr(item, 'type') and item.type == 'message':
  231. if hasattr(item, 'content'):
  232. if isinstance(item.content, list):
  233. for content_item in item.content:
  234. if hasattr(content_item, 'text'):
  235. message_content += content_item.text
  236. else:
  237. message_content += str(item.content)
  238. if not message_content:
  239. raise HTTPException(status_code=500, detail="无法从AI响应中提取文本内容")
  240. assistant_message = ChatMessage(
  241. role="assistant",
  242. content=message_content,
  243. timestamp=datetime.now(),
  244. response_id=response.id
  245. )
  246. chatHistory[username].append(assistant_message)
  247. return ChatResponse(
  248. message=assistant_message,
  249. model=response.model,
  250. usage=response.usage.model_dump() if response.usage else None,
  251. response_id=response.id
  252. )
  253. except HTTPException:
  254. raise
  255. except Exception as e:
  256. error_message = f"处理聊天请求时发生错误: {str(e)}"
  257. save_chat_log(
  258. username=username,
  259. question=request.messages[-1].content if request.messages else "",
  260. stream_mode=request.stream,
  261. status="error",
  262. error=error_message,
  263. )
  264. raise HTTPException(status_code=500, detail=error_message)
  265. @router.get("/models")
  266. async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
  267. try:
  268. models = client.models.list()
  269. return {
  270. "models": [model.id for model in models.data],
  271. "default_model": config.MODEL_NAME,
  272. "user": current_user.username
  273. }
  274. except Exception:
  275. return {
  276. "models": [config.MODEL_NAME],
  277. "default_model": config.MODEL_NAME,
  278. "note": "使用默认模型配置",
  279. "user": current_user.username
  280. }
  281. @router.get("/history")
  282. async def get_user_history(
  283. current_user: Annotated[User, Depends(get_current_active_user)]
  284. ) -> List[ChatMessage]:
  285. username = current_user.username
  286. if username not in chatHistory:
  287. return []
  288. return [
  289. ChatMessage(role=msg.role, content=msg.content, timestamp=msg.timestamp)
  290. for msg in chatHistory[username]
  291. ]
  292. @router.delete("/history")
  293. async def clear_user_history(current_user: Annotated[User, Depends(get_current_active_user)]):
  294. username = current_user.username
  295. if username in chatHistory:
  296. message_count = len(chatHistory[username])
  297. del chatHistory[username]
  298. return {"message": "聊天历史已清空", "user": username, "deleted_messages": message_count,
  299. "timestamp": datetime.now()}
  300. return {"message": "用户没有聊天历史", "user": username, "deleted_messages": 0, "timestamp": datetime.now()}
  301. @router.get("/health")
  302. async def health_check():
  303. return {"status": "healthy", "timestamp": datetime.now(), "version": "1.0.0", "model": config.MODEL_NAME}
# Router-level OpenAPI metadata (tags and common error responses).
# NOTE(review): assigning `router.tags` / `router.responses` AFTER the routes
# above were registered may not retro-apply to those routes; confirm, or pass
# tags=... / responses=... to APIRouter(...) at construction instead.
router.tags = ["聊天服务"]
router.responses = {
    401: {"description": "未授权 - 需要有效的JWT令牌"},
    429: {"description": "请求过多 - 配额已用完"},
    500: {"description": "服务器内部错误"}
}