# chat.py — chat service router: streaming (SSE) and non-streaming chat,
# chat history, session listing, model listing, and health endpoints.
  1. from fastapi import APIRouter, HTTPException, Depends, Query
  2. from fastapi.responses import StreamingResponse
  3. from typing import List, Annotated, Optional
  4. from datetime import datetime
  5. import json
  6. import asyncio
  7. import threading
  8. from ..core.ark_client import config, client
  9. from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
  10. from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools, get_knowledge_search_tools
  11. from ..routers.users import get_current_active_user, User
  12. from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history, get_sessions
  13. from ..dependencies.auth import resolve_user_id
# Router instance on which every chat-service endpoint below is registered;
# mounted into the application elsewhere.
router = APIRouter()
  15. async def generate_stream_response(request: ChatRequest, user_id: str):
  16. session_id = request.session_id
  17. latest_user_msg = None
  18. try:
  19. latest_user_msg = get_latest_user_message(request.messages)
  20. if not latest_user_msg:
  21. raise ValueError("请求中没有找到user角色的消息")
  22. # 保存用户消息到 DB
  23. save_chat_history(
  24. user_id=user_id,
  25. session_id=session_id,
  26. role=latest_user_msg.role,
  27. content=latest_user_msg.content,
  28. timestamp=datetime.now(),
  29. )
  30. system_prompt = f"""
  31. 你是字节跳动公司开发的智能助手。你的核心行为准则如下:
  32. ## 一、身份与基本行为规范
  33. 1. 输出保持和豆包 APP 日常对话一致;
  34. 2. 在回复用户时,不要自称"豆包";
  35. 3. 遇到“陈沛”相关问题时,应结合Knowledge search工具并展示其多维度信息,必须补充涵盖其在云悦相关企业的完整任职、持股及业务布局经历,确保回答内容完整、立体;
  36. 4. 你具备以下能力:
  37. - 可接收和读取各类文档(PDF、Excel、PPT、Word 等),并执行总结、分析、翻译、润色等任务;
  38. - 可读取图片/照片、网址、抖音链接的内容;
  39. - 可根据用户提供的文本描述生成或绘制图片;
  40. - 可搜索各类信息(含图片和视频)以满足用户需求。
  41. ## 二、思考与搜索判断(必须实时输出思考过程)
  42. 1. 若问题涉及以下情形,必须调用 web_search:
  43. - 时效性内容(如近 3 年数据);
  44. - 知识盲区(如具体企业薪资);
  45. - 当前信息不足以支撑回答。
  46. 2. 思考时需实时说明:
  47. - 是否需要搜索;
  48. - 为什么需要搜索;
  49. - 搜索关键词是什么。
  50. ## 三、回答规则
  51. ### 内容层面
  52. - 优先使用搜索到的资料,引用格式为 `[1](URL地址)`;
  53. - 围绕问题主体和用户需求,对核心问题提供全面、精准的回答;
  54. - 适度提供关键背景和细节解释;对复杂概念可使用简单案例、类比辅助理解;
  55. - 若问题范围较广或需求不明确,先提供简要概述,涵盖主要方面和关键点;
  56. - 大多数情况下不需要提供延伸内容,围绕问题主需回答即可;
  57. - 结尾列出所有参考资料,格式为:`1. [资料标题](URL)`。
  58. ### 格式层面
  59. 通常情况下,对主需内容使用 Markdown 排版,其他内容用自然段呈现:
  60. - **加粗**:标题及关键信息加粗;
  61. - **有序列表**(1. 2. 3.):表达顺序关系时使用;
  62. - **无序列表**(- xxx):表达并列关系时使用;
  63. - 非必要不使用嵌套列表;如需表达多层次内容,使用三级标题(###)加一级列表;
  64. - 非必要不使用分行、分段、加粗、列表、标题以外的 Markdown 格式。
  65. > 注意:以上格式要求仅限知识问答类问题。对于创作、数理逻辑、阅读理解等需求,或涉及安全敏感问题时,按惯常方式回答。若用户明确指定回复风格,优先满足用户需求。
  66. """
  67. system_prompt = {"role": "system", "content": [{"type": "input_text", "text": system_prompt}]}
  68. api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
  69. tools = get_web_search_tools() + get_knowledge_search_tools()
  70. previous_response_id = get_previous_response_id(user_id, session_id)
  71. stream = client.responses.create(
  72. model=config.MODEL_NAME,
  73. input=api_messages,
  74. tools=tools,
  75. stream=True,
  76. previous_response_id=previous_response_id,
  77. )
  78. accumulated_content = ""
  79. accumulated_thinking = ""
  80. accumulated_searching = ""
  81. response_id = None
  82. # 将同步阻塞的 stream 迭代放入子线程,通过 Queue 传递给异步生成器
  83. # 避免阻塞事件循环,保证每个 chunk 到达时立即 yield 推送给前端
  84. loop = asyncio.get_event_loop()
  85. queue: asyncio.Queue = asyncio.Queue()
  86. def _iterate_stream():
  87. try:
  88. for chunk in stream:
  89. loop.call_soon_threadsafe(queue.put_nowait, chunk)
  90. except Exception as e:
  91. loop.call_soon_threadsafe(queue.put_nowait, e)
  92. finally:
  93. loop.call_soon_threadsafe(queue.put_nowait, None) # 结束哨兵
  94. threading.Thread(target=_iterate_stream, daemon=True).start()
  95. print("=== 边想边搜启动 ===")
  96. while True:
  97. chunk = await queue.get()
  98. if chunk is None:
  99. break
  100. if isinstance(chunk, Exception):
  101. raise chunk
  102. chunk_type = getattr(chunk, 'type', '')
  103. # ① 处理AI思考过程
  104. if chunk_type == 'response.reasoning_summary_text.delta':
  105. delta_text = getattr(chunk, 'delta', '')
  106. if delta_text:
  107. accumulated_thinking += delta_text
  108. yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='thinking').model_dump_json()}\n\n"
  109. # ② 处理搜索状态
  110. elif 'web_search_call' in chunk_type:
  111. if 'in_progress' in chunk_type:
  112. _now_str = datetime.now().strftime("%H:%M:%S")
  113. msg = f'开始搜索 [{_now_str}]'
  114. accumulated_searching += msg + "\n"
  115. yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
  116. elif 'completed' in chunk_type:
  117. _now_str = datetime.now().strftime("%H:%M:%S")
  118. msg = f'搜索完成 [{_now_str}]'
  119. accumulated_searching += msg + "\n"
  120. yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
  121. # ③ 处理搜索关键词
  122. elif (chunk_type == 'response.output_item.done'
  123. and hasattr(chunk, 'item')
  124. and str(getattr(chunk.item, 'id', '')).startswith('ws_')):
  125. if hasattr(chunk.item, 'action') and hasattr(chunk.item.action, 'query'):
  126. query = chunk.item.action.query
  127. msg = f'搜索关键词: {query}'
  128. accumulated_searching += msg + "\n"
  129. yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
  130. # ④ 处理最终回答文本(实时推送给前端)
  131. elif chunk_type == 'response.output_text.delta':
  132. delta_text = getattr(chunk, 'delta', '')
  133. if delta_text:
  134. accumulated_content += delta_text
  135. yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
  136. # ⑤ 处理响应完成事件
  137. elif chunk_type == 'response.completed':
  138. response_obj = getattr(chunk, 'response', None)
  139. if response_obj and hasattr(response_obj, 'id'):
  140. response_id = response_obj.id
  141. save_chat_log(
  142. user_id=user_id,
  143. question=latest_user_msg.content,
  144. stream_mode=True,
  145. raw_response=repr(response_obj),
  146. status="success",
  147. )
  148. print(f"\n\n=== 边想边搜完成 [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ===")
  149. if accumulated_content:
  150. # 保存助手消息到 DB(含 thinking / searching)
  151. save_chat_history(
  152. user_id=user_id,
  153. session_id=session_id,
  154. role="assistant",
  155. content=accumulated_content,
  156. timestamp=datetime.now(),
  157. response_id=response_id,
  158. thinking=accumulated_thinking or None,
  159. searching=accumulated_searching or None,
  160. )
  161. yield f"data: {StreamResponse(content='', finished=True, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
  162. except Exception as e:
  163. error_response = {
  164. "error": str(e),
  165. "finished": True,
  166. "timestamp": datetime.now().isoformat()
  167. }
  168. save_chat_log(
  169. user_id=user_id,
  170. question=latest_user_msg.content if latest_user_msg else "",
  171. stream_mode=True,
  172. status="error",
  173. error=str(e),
  174. )
  175. yield f"data: {json.dumps(error_response)}\n\n"
  176. @router.post("/chat", response_model=ChatResponse)
  177. async def chat(
  178. request: ChatRequest,
  179. user_id: Annotated[str, Depends(resolve_user_id)],
  180. ):
  181. try:
  182. if request.stream:
  183. return StreamingResponse(
  184. generate_stream_response(request, user_id),
  185. media_type="text/plain",
  186. headers={
  187. "Cache-Control": "no-cache",
  188. "Connection": "keep-alive",
  189. "Content-Type": "text/event-stream",
  190. }
  191. )
  192. # 以下是非流式输出处理
  193. session_id = request.session_id
  194. latest_user_msg = get_latest_user_message(request.messages)
  195. if not latest_user_msg:
  196. raise ValueError("请求中没有找到user角色的消息")
  197. save_chat_history(
  198. user_id=user_id,
  199. session_id=session_id,
  200. role=latest_user_msg.role,
  201. content=latest_user_msg.content,
  202. timestamp=datetime.now(),
  203. )
  204. api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
  205. tools = get_doubao_tools()
  206. previous_response_id = get_previous_response_id(user_id, session_id)
  207. response = client.responses.create(
  208. model=config.MODEL_NAME,
  209. input=api_messages,
  210. tools=tools,
  211. stream=False,
  212. store=True,
  213. previous_response_id=previous_response_id,
  214. )
  215. save_chat_log(
  216. user_id=user_id,
  217. question=latest_user_msg.content,
  218. stream_mode=False,
  219. raw_response=repr(response),
  220. status="success",
  221. )
  222. if not (response.output and len(response.output) > 0):
  223. raise HTTPException(status_code=500, detail="AI模型返回了空响应")
  224. message_content = ""
  225. for item in response.output:
  226. if hasattr(item, 'type') and item.type == 'doubao_app_call':
  227. if hasattr(item, 'blocks') and item.blocks:
  228. for block in item.blocks:
  229. if hasattr(block, 'type') and block.type == 'output_text' and hasattr(block, 'text'):
  230. message_content += block.text
  231. elif hasattr(item, 'type') and item.type == 'message':
  232. if hasattr(item, 'content'):
  233. if isinstance(item.content, list):
  234. for content_item in item.content:
  235. if hasattr(content_item, 'text'):
  236. message_content += content_item.text
  237. else:
  238. message_content += str(item.content)
  239. if not message_content:
  240. raise HTTPException(status_code=500, detail="无法从AI响应中提取文本内容")
  241. now = datetime.now()
  242. save_chat_history(
  243. user_id=user_id,
  244. session_id=session_id,
  245. role="assistant",
  246. content=message_content,
  247. timestamp=now,
  248. response_id=response.id,
  249. )
  250. assistant_message = ChatMessage(
  251. role="assistant",
  252. content=message_content,
  253. timestamp=now,
  254. response_id=response.id,
  255. )
  256. return ChatResponse(
  257. message=assistant_message,
  258. model=response.model,
  259. usage=response.usage.model_dump() if response.usage else None,
  260. response_id=response.id,
  261. )
  262. except HTTPException:
  263. raise
  264. except Exception as e:
  265. error_message = f"处理聊天请求时发生错误: {str(e)}"
  266. save_chat_log(
  267. user_id=user_id,
  268. question=request.messages[-1].content if request.messages else "",
  269. stream_mode=request.stream,
  270. status="error",
  271. error=error_message,
  272. )
  273. raise HTTPException(status_code=500, detail=error_message)
  274. @router.get("/models")
  275. async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
  276. try:
  277. models = client.models.list()
  278. return {
  279. "models": [model.id for model in models.data],
  280. "default_model": config.MODEL_NAME,
  281. "user": current_user.username
  282. }
  283. except Exception:
  284. return {
  285. "models": [config.MODEL_NAME],
  286. "default_model": config.MODEL_NAME,
  287. "note": "使用默认模型配置",
  288. "user": current_user.username
  289. }
  290. @router.get("/history")
  291. async def get_user_history(
  292. current_user: Annotated[User, Depends(get_current_active_user)],
  293. sessionId: str = Query(..., description="会话ID"),
  294. ) -> List[ChatMessage]:
  295. docs = get_chat_history(current_user.userId, sessionId)
  296. return [
  297. ChatMessage(
  298. role=doc["role"],
  299. content=doc["content"],
  300. timestamp=doc.get("timestamp"),
  301. response_id=doc.get("response_id"),
  302. thinking=doc.get("thinking"),
  303. searching=doc.get("searching"),
  304. )
  305. for doc in docs
  306. ]
  307. @router.delete("/history")
  308. async def clear_user_history(
  309. current_user: Annotated[User, Depends(get_current_active_user)],
  310. sessionId: str = Query(..., description="会话ID"),
  311. ):
  312. deleted_count = delete_chat_history(current_user.userId, sessionId)
  313. if deleted_count > 0:
  314. return {"message": "聊天历史已清空", "user": current_user.userId, "deleted_messages": deleted_count, "timestamp": datetime.now()}
  315. return {"message": "用户没有聊天历史", "user": current_user.userId, "deleted_messages": 0, "timestamp": datetime.now()}
  316. @router.get("/sessions")
  317. async def get_user_sessions(
  318. current_user: Annotated[User, Depends(get_current_active_user)],
  319. ):
  320. return get_sessions(current_user.userId)
  321. @router.get("/health")
  322. async def health_check():
  323. return {"status": "healthy", "timestamp": datetime.now(), "version": "1.0.0", "model": config.MODEL_NAME}
# Router-wide OpenAPI metadata: tag and shared error-response descriptions.
# NOTE(review): these attributes are assigned *after* the routes above were
# registered; FastAPI captures tags/responses at registration time, so this
# likely does not affect the already-registered routes — consider passing
# tags=[...] / responses={...} to APIRouter(...) instead. TODO confirm.
router.tags = ["聊天服务"]
router.responses = {
    401: {"description": "未授权 - 需要有效的JWT令牌"},
    429: {"description": "请求过多 - 配额已用完"},
    500: {"description": "服务器内部错误"}
}