chat.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. from fastapi import APIRouter, HTTPException, Depends, Query
  2. from fastapi.responses import StreamingResponse
  3. from typing import List, Annotated, Optional
  4. from datetime import datetime
  5. import json
  6. import asyncio
  7. import threading
  8. from ..core.ark_client import get_client
  9. from ..config.config import Config
  10. from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
  11. from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools, get_knowledge_search_tools
  12. from ..routers.users import get_current_active_user, User
  13. from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history, get_sessions
  14. from ..db.ai_config import get_config_by_app_name
  15. from ..dependencies.auth import resolve_user_id
  16. config = Config()
  17. router = APIRouter()
  18. async def generate_stream_response(request: ChatRequest, user_id: str, app_name: str = "com.yunxiangshengtai"):
  19. session_id = request.session_id
  20. latest_user_msg = None
  21. try:
  22. ai_config = get_config_by_app_name(app_name)
  23. if not ai_config:
  24. raise ValueError(f"未找到appName '{app_name}' 的配置")
  25. client = get_client(app_name)
  26. knowledge_id = ai_config["knowledgeId"]
  27. latest_user_msg = get_latest_user_message(request.messages)
  28. if not latest_user_msg:
  29. raise ValueError("请求中没有找到user角色的消息")
  30. # 保存用户消息到 DB
  31. save_chat_history(
  32. user_id=user_id,
  33. session_id=session_id,
  34. role=latest_user_msg.role,
  35. content=latest_user_msg.content,
  36. timestamp=datetime.now(),
  37. )
  38. system_prompt = f"""
  39. 你是字节跳动公司开发的智能助手。你的核心行为准则如下:
  40. ## 一、身份与基本行为规范
  41. 1. 输出保持和豆包 APP 日常对话一致;
  42. 2. 在回复用户时,不要自称“豆包”;
  43. 3. 你具备以下能力:
  44. - 可接收和读取各类文档(PDF、Excel、PPT、Word 等),并执行总结、分析、翻译、润色等任务;
  45. - 可读取图片/照片、网址、抖音链接的内容;
  46. - 可根据用户提供的文本描述生成或绘制图片;
  47. - 可搜索各类信息(含图片和视频)以满足用户需求。
  48. ## 二、工具使用总原则
  49. 1. 优先使用「知识库」检索信息,只有当知识库的信息不足以支撑回答时,才能使用联网搜索;如果知识库信息足够,则不联网。
  50. 2. 对于以下问题,优先参考「知识库」中的信息进行回复:
  51. - 云悦产品相关问题(如:XX宝);
  52. - 企业信息相关问题(如:云悦);
  53. - 创始人或负责人相关问题(如:陈沛)。
  54. 3. 当用户提问涉及企业、企业产品、企业负责人、人物信息等内容时,应先尝试通过知识库检索;若知识库无法提供足够信息,再判断为当前信息不足并启用联网搜索。
  55. 4. 若知识库无结果或结果不足,不需要向用户说明“知识库未命中”或“正在联网搜索”,直接继续完成检索与回答。
  56. 5. 不得为了形式完整而强行联网;若知识库已足够回答,则直接基于已有信息作答。
  57. ## 三、联网搜索触发规则
  58. 仅在以下情况下,才允许调用联网搜索:
  59. 1. 知识库信息不足以支撑回答;
  60. 2. 问题具有明显时效性,例如近3年的数据、最新动态、近期人事变动、当前价格、最新产品信息等;
  61. 3. 问题属于你的知识盲区,且知识库也未覆盖,例如特定企业薪资、实时工商状态、近期新闻事件等;
  62. 4. 用户问题需要依赖最新公开信息,而当前已有信息无法确保准确性。
  63. 若不满足以上条件,则不联网。
  64. ## 四、搜索与信息验证规则
  65. 当必须联网搜索时,应遵循以下原则:
  66. 1. 搜索范围
  67. - 默认获取 top10 搜索结果作为候选信息;
  68. - 优先关注与用户问题强相关的信息。
  69. 2. 来源可信度判断
  70. - 优先采用高可信来源的信息,例如:
  71. - 官方网站、官方公告、官方公众号;
  72. - 权威媒体;
  73. - 行业机构、公开财报、监管披露、学术或专业数据库。
  74. - 对来源不明、营销导向强、内容农场、明显搬运或缺乏佐证的信息,应降低权重或直接舍弃。
  75. 3. 信息真实性验证
  76. - 对关键事实进行交叉验证,尤其是:
  77. - 企业名称、产品名称;
  78. - 职位、负责人身份;
  79. - 时间、金额、价格、融资、营收等关键数据;
  80. - 产品能力、发布时间、合作关系等。
  81. - 重点检查:
  82. - 时间是否一致;
  83. - 表述是否存在逻辑冲突;
  84. - 是否有多个独立来源支持;
  85. - 是否存在明显异常或夸张描述。
  86. - 如果信息可能不实,则直接排除,不用于回答。
  87. 4. 信息整合
  88. - 优先采用高质量、可交叉验证的信息形成答案;
  89. - 若多个可信来源一致,可提高回答确定性;
  90. - 若信息存在冲突,应仅保留相对稳妥、可验证的部分,避免武断下结论;
  91. - 若搜索结果整体质量较低、无法形成可靠结论,则视为“未搜索到可靠信息”。
  92. 5. 搜索失败处理
  93. - 若联网搜索后仍无可靠信息,不编造、不猜测;
  94. - 应直接告诉用户目前无法找到可靠信息。
  95. ## 五、回答规则
  96. ### 1. 内容层面
  97. - 优先回答用户的核心问题,内容应准确、直接、完整;
  98. - 在不偏离主问题的前提下,可适度补充必要背景,帮助用户理解;
  99. - 对复杂概念可使用简洁例子或类比辅助说明;
  100. - 若问题范围较广或需求不明确,先给出简要概述,再覆盖关键点;
  101. - 大多数情况下不需要提供过多延伸内容,围绕用户主需回答即可;
  102. - 若信息不足或搜索结果不可靠,应明确说明无法确认,不得编造。
  103. ### 2. 来源呈现规则
  104. - 可以内部参考知识库和搜索结果进行作答;
  105. - 但对用户输出时,**不得暴露参考资料的存在**;
  106. - 不得出现类似:
  107. - “根据参考资料”
  108. - “根据知识库”
  109. - “根据检索结果”
  110. - “我查到”
  111. - “搜索显示”
  112. 等表述;
  113. - 不需要展示引用链接、角标引用、参考文献列表。
  114. ### 3. 时效性表达
  115. - 对企业、产品、负责人、人事变动、价格、营收、融资等容易变化的信息,应自然标注时间范围;
  116. - 推荐表达方式:
  117. - “截至2025年3月,……”
  118. - “从目前公开信息来看,……”
  119. - “根据2024年下半年的公开信息,……”
  120. - 时效性表达应自然融入回答,不要生硬罗列。
  121. ### 4. 格式层面
  122. 通常情况下,对知识问答类问题使用清晰、结构化表达,确保用户轻松理解和使用:
  123. - 优先使用自然分段;
  124. - 需要表达顺序关系时,使用有序列表(1. 2. 3.);
  125. - 需要表达并列关系时,使用无序列表;
  126. - 可适度使用加粗突出标题和关键信息;
  127. - 非必要不使用复杂嵌套列表;
  128. - 对创作、数理逻辑、阅读理解等任务,按惯常方式回答;
  129. - 若用户明确指定回复风格,优先满足用户需求。
  130. ## 六、特殊场景处理
  131. 1. 如果知识库已有云悦、XX宝、陈沛相关信息,优先使用知识库内容,不主动联网补充。
  132. 2. 如果知识库对上述主题信息不足,再进行联网搜索,并仅吸收可信、可验证的信息。
  133. 3. 对敏感、隐私、争议信息保持谨慎,尤其是个人资产、未经证实的履历、传闻、八卦、负面指控等;若缺乏可靠依据,应拒绝采纳或明确表示无法确认。
  134. 4. 若用户提问本身不清晰,可先简短追问澄清;但若已有足够上下文,也可先给出当前可确定的答案。
  135. ## 七、禁止事项
  136. 1. 不得在知识库信息足够时擅自联网;
  137. 2. 不得把低可信、未验证、可能不实的信息写入答案;
  138. 3. 不得编造事实、时间、数据、人物关系或产品能力;
  139. 4. 不得向用户暴露知识库、检索、搜索策略、来源筛选过程或内部判断过程;
  140. 5. 不得输出”思考过程””搜索关键词””为什么需要搜索”等内部推理内容;
  141. 6. 不得使用”根据参考资料/根据知识库/根据搜索结果”等表述。
  142. ## 八、最终目标
  143. 在保证回答自然、清晰、易懂的前提下:
  144. - 优先使用知识库;
  145. - 仅在必要时联网;
  146. - 对联网结果进行真实性与可信度验证;
  147. - 用结构化语言给出准确、稳妥、不过度暴露内部过程的回答。
  148. """
  149. system_prompt = {"role": "system", "content": [{"type": "input_text", "text": system_prompt}]}
  150. api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
  151. tools = get_web_search_tools() + get_knowledge_search_tools(knowledge_id)
  152. previous_response_id = get_previous_response_id(user_id, session_id)
  153. stream = client.responses.create(
  154. model=config.MODEL_NAME,
  155. input=api_messages,
  156. tools=tools,
  157. stream=True,
  158. previous_response_id=previous_response_id,
  159. )
  160. accumulated_content = ""
  161. accumulated_thinking = ""
  162. accumulated_searching = ""
  163. response_id = None
  164. # 将同步阻塞的 stream 迭代放入子线程,通过 Queue 传递给异步生成器
  165. # 避免阻塞事件循环,保证每个 chunk 到达时立即 yield 推送给前端
  166. loop = asyncio.get_event_loop()
  167. queue: asyncio.Queue = asyncio.Queue()
  168. def _iterate_stream():
  169. try:
  170. for chunk in stream:
  171. loop.call_soon_threadsafe(queue.put_nowait, chunk)
  172. except Exception as e:
  173. loop.call_soon_threadsafe(queue.put_nowait, e)
  174. finally:
  175. loop.call_soon_threadsafe(queue.put_nowait, None) # 结束哨兵
  176. threading.Thread(target=_iterate_stream, daemon=True).start()
  177. print("=== 边想边搜启动 ===")
  178. while True:
  179. chunk = await queue.get()
  180. if chunk is None:
  181. break
  182. if isinstance(chunk, Exception):
  183. raise chunk
  184. chunk_type = getattr(chunk, 'type', '')
  185. # ① 处理AI思考过程
  186. if chunk_type == 'response.reasoning_summary_text.delta':
  187. delta_text = getattr(chunk, 'delta', '')
  188. if delta_text:
  189. accumulated_thinking += delta_text
  190. yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='thinking').model_dump_json()}\n\n"
  191. # ② 处理搜索状态
  192. elif 'web_search_call' in chunk_type:
  193. if 'in_progress' in chunk_type:
  194. _now_str = datetime.now().strftime("%H:%M:%S")
  195. msg = f'开始搜索 [{_now_str}]'
  196. accumulated_searching += msg + "\n"
  197. yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
  198. elif 'completed' in chunk_type:
  199. _now_str = datetime.now().strftime("%H:%M:%S")
  200. msg = f'搜索完成 [{_now_str}]'
  201. accumulated_searching += msg + "\n"
  202. yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
  203. # ③ 处理搜索关键词
  204. elif (chunk_type == 'response.output_item.done'
  205. and hasattr(chunk, 'item')
  206. and str(getattr(chunk.item, 'id', '')).startswith('ws_')):
  207. if hasattr(chunk.item, 'action') and hasattr(chunk.item.action, 'query'):
  208. query = chunk.item.action.query
  209. msg = f'搜索关键词: {query}'
  210. accumulated_searching += msg + "\n"
  211. yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
  212. # ④ 处理最终回答文本(实时推送给前端)
  213. elif chunk_type == 'response.output_text.delta':
  214. delta_text = getattr(chunk, 'delta', '')
  215. if delta_text:
  216. accumulated_content += delta_text
  217. yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
  218. # ⑤ 处理响应完成事件
  219. elif chunk_type == 'response.completed':
  220. response_obj = getattr(chunk, 'response', None)
  221. if response_obj and hasattr(response_obj, 'id'):
  222. response_id = response_obj.id
  223. save_chat_log(
  224. user_id=user_id,
  225. question=latest_user_msg.content,
  226. stream_mode=True,
  227. raw_response=repr(response_obj),
  228. status="success",
  229. )
  230. print(f"\n\n=== 边想边搜完成 [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ===")
  231. if accumulated_content:
  232. # 保存助手消息到 DB(含 thinking / searching)
  233. save_chat_history(
  234. user_id=user_id,
  235. session_id=session_id,
  236. role="assistant",
  237. content=accumulated_content,
  238. timestamp=datetime.now(),
  239. response_id=response_id,
  240. thinking=accumulated_thinking or None,
  241. searching=accumulated_searching or None,
  242. )
  243. yield f"data: {StreamResponse(content='', finished=True, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
  244. except Exception as e:
  245. error_response = {
  246. "error": str(e),
  247. "finished": True,
  248. "timestamp": datetime.now().isoformat()
  249. }
  250. save_chat_log(
  251. user_id=user_id,
  252. question=latest_user_msg.content if latest_user_msg else "",
  253. stream_mode=True,
  254. status="error",
  255. error=str(e),
  256. )
  257. yield f"data: {json.dumps(error_response)}\n\n"
  258. # 以下是非流格式的输出, 采用了豆包助手工具,目前项目中没有使用到
  259. @router.post("/chat", response_model=ChatResponse)
  260. async def chat(
  261. request: ChatRequest,
  262. user_id: Annotated[str, Depends(resolve_user_id)],
  263. app_name: str = Query(default="com.yunxiangshengtai", description="应用包名"),
  264. ):
  265. try:
  266. if request.stream:
  267. return StreamingResponse(
  268. generate_stream_response(request, user_id, app_name),
  269. media_type="text/plain",
  270. headers={
  271. "Cache-Control": "no-cache",
  272. "Connection": "keep-alive",
  273. "Content-Type": "text/event-stream",
  274. }
  275. )
  276. # 以下是非流式输出处理
  277. session_id = request.session_id
  278. latest_user_msg = get_latest_user_message(request.messages)
  279. if not latest_user_msg:
  280. raise ValueError("请求中没有找到user角色的消息")
  281. ai_config = get_config_by_app_name(app_name)
  282. if not ai_config:
  283. raise ValueError(f"未找到appName '{app_name}' 的配置")
  284. client = get_client(app_name)
  285. save_chat_history(
  286. user_id=user_id,
  287. session_id=session_id,
  288. role=latest_user_msg.role,
  289. content=latest_user_msg.content,
  290. timestamp=datetime.now(),
  291. )
  292. api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
  293. tools = get_doubao_tools()
  294. previous_response_id = get_previous_response_id(user_id, session_id)
  295. response = client.responses.create(
  296. model=config.MODEL_NAME,
  297. input=api_messages,
  298. tools=tools,
  299. stream=False,
  300. store=True,
  301. previous_response_id=previous_response_id,
  302. )
  303. save_chat_log(
  304. user_id=user_id,
  305. question=latest_user_msg.content,
  306. stream_mode=False,
  307. raw_response=repr(response),
  308. status="success",
  309. )
  310. if not (response.output and len(response.output) > 0):
  311. raise HTTPException(status_code=500, detail="AI模型返回了空响应")
  312. message_content = ""
  313. for item in response.output:
  314. if hasattr(item, 'type') and item.type == 'doubao_app_call':
  315. if hasattr(item, 'blocks') and item.blocks:
  316. for block in item.blocks:
  317. if hasattr(block, 'type') and block.type == 'output_text' and hasattr(block, 'text'):
  318. message_content += block.text
  319. elif hasattr(item, 'type') and item.type == 'message':
  320. if hasattr(item, 'content'):
  321. if isinstance(item.content, list):
  322. for content_item in item.content:
  323. if hasattr(content_item, 'text'):
  324. message_content += content_item.text
  325. else:
  326. message_content += str(item.content)
  327. if not message_content:
  328. raise HTTPException(status_code=500, detail="无法从AI响应中提取文本内容")
  329. now = datetime.now()
  330. save_chat_history(
  331. user_id=user_id,
  332. session_id=session_id,
  333. role="assistant",
  334. content=message_content,
  335. timestamp=now,
  336. response_id=response.id,
  337. )
  338. assistant_message = ChatMessage(
  339. role="assistant",
  340. content=message_content,
  341. timestamp=now,
  342. response_id=response.id,
  343. )
  344. return ChatResponse(
  345. message=assistant_message,
  346. model=response.model,
  347. usage=response.usage.model_dump() if response.usage else None,
  348. response_id=response.id,
  349. )
  350. except HTTPException:
  351. raise
  352. except Exception as e:
  353. error_message = f"处理聊天请求时发生错误: {str(e)}"
  354. save_chat_log(
  355. user_id=user_id,
  356. question=request.messages[-1].content if request.messages else "",
  357. stream_mode=request.stream,
  358. status="error",
  359. error=error_message,
  360. )
  361. raise HTTPException(status_code=500, detail=error_message)
  362. #
  363. # @router.get("/models")
  364. # async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
  365. # try:
  366. # models = client.models.list()
  367. # return {
  368. # "models": [model.id for model in models.data],
  369. # "default_model": config.MODEL_NAME,
  370. # "user": current_user.username
  371. # }
  372. # except Exception:
  373. # return {
  374. # "models": [config.MODEL_NAME],
  375. # "default_model": config.MODEL_NAME,
  376. # "note": "使用默认模型配置",
  377. # "user": current_user.username
  378. # }
  379. @router.get("/history")
  380. async def get_user_history(
  381. current_user: Annotated[User, Depends(get_current_active_user)],
  382. sessionId: str = Query(..., description="会话ID"),
  383. ) -> List[ChatMessage]:
  384. docs = get_chat_history(current_user.userId, sessionId)
  385. return [
  386. ChatMessage(
  387. role=doc["role"],
  388. content=doc["content"],
  389. timestamp=doc.get("timestamp"),
  390. response_id=doc.get("response_id"),
  391. thinking=doc.get("thinking"),
  392. searching=doc.get("searching"),
  393. )
  394. for doc in docs
  395. ]
  396. @router.delete("/history")
  397. async def clear_user_history(
  398. current_user: Annotated[User, Depends(get_current_active_user)],
  399. sessionId: str = Query(..., description="会话ID"),
  400. ):
  401. deleted_count = delete_chat_history(current_user.userId, sessionId)
  402. if deleted_count > 0:
  403. return {"message": "聊天历史已清空", "user": current_user.userId, "deleted_messages": deleted_count, "timestamp": datetime.now()}
  404. return {"message": "用户没有聊天历史", "user": current_user.userId, "deleted_messages": 0, "timestamp": datetime.now()}
  405. @router.get("/sessions")
  406. async def get_user_sessions(
  407. current_user: Annotated[User, Depends(get_current_active_user)],
  408. ):
  409. return get_sessions(current_user.userId)
  410. @router.get("/health")
  411. async def health_check():
  412. return {"status": "healthy", "timestamp": datetime.now(), "version": "1.0.0", "model": config.MODEL_NAME}
  413. router.tags = ["聊天服务"]
  414. router.responses = {
  415. 401: {"description": "未授权 - 需要有效的JWT令牌"},
  416. 429: {"description": "请求过多 - 配额已用完"},
  417. 500: {"description": "服务器内部错误"}
  418. }