from fastapi import APIRouter, HTTPException, Depends, Query from fastapi.responses import StreamingResponse from typing import List, Annotated, Optional from datetime import datetime import json import asyncio import threading from ..core.ark_client import config, client from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools, get_knowledge_search_tools from ..routers.users import get_current_active_user, User from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history, get_sessions from ..dependencies.auth import resolve_user_id router = APIRouter() async def generate_stream_response(request: ChatRequest, user_id: str): session_id = request.session_id latest_user_msg = None try: latest_user_msg = get_latest_user_message(request.messages) if not latest_user_msg: raise ValueError("请求中没有找到user角色的消息") # 保存用户消息到 DB save_chat_history( user_id=user_id, session_id=session_id, role=latest_user_msg.role, content=latest_user_msg.content, timestamp=datetime.now(), ) system_prompt = f""" 你是字节跳动公司开发的智能助手。你的核心行为准则如下: ## 一、身份与基本行为规范 1. 输出保持和豆包 APP 日常对话一致; 2. 在回复用户时,不要自称“豆包”; 3. 你具备以下能力: - 可接收和读取各类文档(PDF、Excel、PPT、Word 等),并执行总结、分析、翻译、润色等任务; - 可读取图片/照片、网址、抖音链接的内容; - 可根据用户提供的文本描述生成或绘制图片; - 可搜索各类信息(含图片和视频)以满足用户需求。 ## 二、工具使用总原则 1. 优先使用「知识库」检索信息,只有当知识库的信息不足以支撑回答时,才能使用联网搜索;如果知识库信息足够,则不联网。 2. 对于以下问题,优先参考「知识库」中的信息进行回复: - 云悦产品相关问题(如:XX宝); - 企业信息相关问题(如:云悦); - 创始人或负责人相关问题(如:陈沛)。 3. 当用户提问涉及企业、企业产品、企业负责人、人物信息等内容时,应先尝试通过知识库检索;若知识库无法提供足够信息,再判断为当前信息不足并启用联网搜索。 4. 若知识库无结果或结果不足,不需要向用户说明“知识库未命中”或“正在联网搜索”,直接继续完成检索与回答。 5. 不得为了形式完整而强行联网;若知识库已足够回答,则直接基于已有信息作答。 ## 三、联网搜索触发规则 仅在以下情况下,才允许调用联网搜索: 1. 知识库信息不足以支撑回答; 2. 问题具有明显时效性,例如近3年的数据、最新动态、近期人事变动、当前价格、最新产品信息等; 3. 问题属于你的知识盲区,且知识库也未覆盖,例如特定企业薪资、实时工商状态、近期新闻事件等; 4. 用户问题需要依赖最新公开信息,而当前已有信息无法确保准确性。 若不满足以上条件,则不联网。 ## 四、搜索与信息验证规则 当必须联网搜索时,应遵循以下原则: 1. 搜索范围 - 默认获取 top10 搜索结果作为候选信息; - 优先关注与用户问题强相关的信息。 2. 来源可信度判断 - 优先采用高可信来源的信息,例如: - 官方网站、官方公告、官方公众号; - 权威媒体; - 行业机构、公开财报、监管披露、学术或专业数据库。 - 对来源不明、营销导向强、内容农场、明显搬运或缺乏佐证的信息,应降低权重或直接舍弃。 3. 信息真实性验证 - 对关键事实进行交叉验证,尤其是: - 企业名称、产品名称; - 职位、负责人身份; - 时间、金额、价格、融资、营收等关键数据; - 产品能力、发布时间、合作关系等。 - 重点检查: - 时间是否一致; - 表述是否存在逻辑冲突; - 是否有多个独立来源支持; - 是否存在明显异常或夸张描述。 - 如果信息可能不实,则直接排除,不用于回答。 4. 信息整合 - 优先采用高质量、可交叉验证的信息形成答案; - 若多个可信来源一致,可提高回答确定性; - 若信息存在冲突,应仅保留相对稳妥、可验证的部分,避免武断下结论; - 若搜索结果整体质量较低、无法形成可靠结论,则视为“未搜索到可靠信息”。 5. 搜索失败处理 - 若联网搜索后仍无可靠信息,不编造、不猜测; - 应直接告诉用户目前无法找到可靠信息。 ## 五、回答规则 ### 1. 内容层面 - 优先回答用户的核心问题,内容应准确、直接、完整; - 在不偏离主问题的前提下,可适度补充必要背景,帮助用户理解; - 对复杂概念可使用简洁例子或类比辅助说明; - 若问题范围较广或需求不明确,先给出简要概述,再覆盖关键点; - 大多数情况下不需要提供过多延伸内容,围绕用户主需回答即可; - 若信息不足或搜索结果不可靠,应明确说明无法确认,不得编造。 ### 2. 来源呈现规则 - 可以内部参考知识库和搜索结果进行作答; - 但对用户输出时,**不得暴露参考资料的存在**; - 不得出现类似: - “根据参考资料” - “根据知识库” - “根据检索结果” - “我查到” - “搜索显示” 等表述; - 不需要展示引用链接、角标引用、参考文献列表。 ### 3. 时效性表达 - 对企业、产品、负责人、人事变动、价格、营收、融资等容易变化的信息,应自然标注时间范围; - 推荐表达方式: - “截至2025年3月,……” - “从目前公开信息来看,……” - “根据2024年下半年的公开信息,……” - 时效性表达应自然融入回答,不要生硬罗列。 ### 4. 格式层面 通常情况下,对知识问答类问题使用清晰、结构化表达,确保用户轻松理解和使用: - 优先使用自然分段; - 需要表达顺序关系时,使用有序列表(1. 2. 3.); - 需要表达并列关系时,使用无序列表; - 可适度使用加粗突出标题和关键信息; - 非必要不使用复杂嵌套列表; - 对创作、数理逻辑、阅读理解等任务,按惯常方式回答; - 若用户明确指定回复风格,优先满足用户需求。 ## 六、特殊场景处理 1. 如果知识库已有云悦、XX宝、陈沛相关信息,优先使用知识库内容,不主动联网补充。 2. 如果知识库对上述主题信息不足,再进行联网搜索,并仅吸收可信、可验证的信息。 3. 对敏感、隐私、争议信息保持谨慎,尤其是个人资产、未经证实的履历、传闻、八卦、负面指控等;若缺乏可靠依据,应拒绝采纳或明确表示无法确认。 4. 若用户提问本身不清晰,可先简短追问澄清;但若已有足够上下文,也可先给出当前可确定的答案。 ## 七、禁止事项 1. 不得在知识库信息足够时擅自联网; 2. 不得把低可信、未验证、可能不实的信息写入答案; 3. 不得编造事实、时间、数据、人物关系或产品能力; 4. 不得向用户暴露知识库、检索、搜索策略、来源筛选过程或内部判断过程; 5. 不得输出“思考过程”“搜索关键词”“为什么需要搜索”等内部推理内容; 6. 不得使用“根据参考资料/根据知识库/根据搜索结果”等表述。 ## 八、最终目标 在保证回答自然、清晰、易懂的前提下: - 优先使用知识库; - 仅在必要时联网; - 对联网结果进行真实性与可信度验证; - 用结构化语言给出准确、稳妥、不过度暴露内部过程的回答。 """ system_prompt = {"role": "system", "content": [{"type": "input_text", "text": system_prompt}]} api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}] tools = get_web_search_tools() + get_knowledge_search_tools() previous_response_id = get_previous_response_id(user_id, session_id) stream = client.responses.create( model=config.MODEL_NAME, input=api_messages, tools=tools, stream=True, previous_response_id=previous_response_id, ) accumulated_content = "" accumulated_thinking = "" accumulated_searching = "" response_id = None # 将同步阻塞的 stream 迭代放入子线程,通过 Queue 传递给异步生成器 # 避免阻塞事件循环,保证每个 chunk 到达时立即 yield 推送给前端 loop = asyncio.get_event_loop() queue: asyncio.Queue = asyncio.Queue() def _iterate_stream(): try: for chunk in stream: loop.call_soon_threadsafe(queue.put_nowait, chunk) except Exception as e: loop.call_soon_threadsafe(queue.put_nowait, e) finally: loop.call_soon_threadsafe(queue.put_nowait, None) # 结束哨兵 threading.Thread(target=_iterate_stream, daemon=True).start() print("=== 边想边搜启动 ===") while True: chunk = await queue.get() if chunk is None: break if isinstance(chunk, Exception): raise chunk chunk_type = getattr(chunk, 'type', '') # ① 处理AI思考过程 if chunk_type == 'response.reasoning_summary_text.delta': delta_text = getattr(chunk, 'delta', '') if delta_text: accumulated_thinking += delta_text yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='thinking').model_dump_json()}\n\n" # ② 处理搜索状态 elif 'web_search_call' in chunk_type: if 'in_progress' in chunk_type: _now_str = datetime.now().strftime("%H:%M:%S") msg = f'开始搜索 [{_now_str}]' accumulated_searching += msg + "\n" yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n" elif 'completed' in chunk_type: _now_str = datetime.now().strftime("%H:%M:%S") msg = f'搜索完成 [{_now_str}]' accumulated_searching += msg + "\n" yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n" # ③ 处理搜索关键词 elif (chunk_type == 'response.output_item.done' and hasattr(chunk, 'item') and str(getattr(chunk.item, 'id', '')).startswith('ws_')): if hasattr(chunk.item, 'action') and hasattr(chunk.item.action, 'query'): query = chunk.item.action.query msg = f'搜索关键词: {query}' accumulated_searching += msg + "\n" yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n" # ④ 处理最终回答文本(实时推送给前端) elif chunk_type == 'response.output_text.delta': delta_text = getattr(chunk, 'delta', '') if delta_text: accumulated_content += delta_text yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n" # ⑤ 处理响应完成事件 elif chunk_type == 'response.completed': response_obj = getattr(chunk, 'response', None) if response_obj and hasattr(response_obj, 'id'): response_id = response_obj.id save_chat_log( user_id=user_id, question=latest_user_msg.content, stream_mode=True, raw_response=repr(response_obj), status="success", ) print(f"\n\n=== 边想边搜完成 [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ===") if accumulated_content: # 保存助手消息到 DB(含 thinking / searching) save_chat_history( user_id=user_id, session_id=session_id, role="assistant", content=accumulated_content, timestamp=datetime.now(), response_id=response_id, thinking=accumulated_thinking or None, searching=accumulated_searching or None, ) yield f"data: {StreamResponse(content='', finished=True, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n" except Exception as e: error_response = { "error": str(e), "finished": True, "timestamp": datetime.now().isoformat() } save_chat_log( user_id=user_id, question=latest_user_msg.content if latest_user_msg else "", stream_mode=True, status="error", error=str(e), ) yield f"data: {json.dumps(error_response)}\n\n" # 以下是非流格式的输出, 采用了豆包助手工具,目前项目中没有使用到 @router.post("/chat", response_model=ChatResponse) async def chat( request: ChatRequest, user_id: Annotated[str, Depends(resolve_user_id)], ): try: if request.stream: return StreamingResponse( generate_stream_response(request, user_id), media_type="text/plain", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "Content-Type": "text/event-stream", } ) # 以下是非流式输出处理 session_id = request.session_id latest_user_msg = get_latest_user_message(request.messages) if not latest_user_msg: raise ValueError("请求中没有找到user角色的消息") save_chat_history( user_id=user_id, session_id=session_id, role=latest_user_msg.role, content=latest_user_msg.content, timestamp=datetime.now(), ) api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}] tools = get_doubao_tools() previous_response_id = get_previous_response_id(user_id, session_id) response = client.responses.create( model=config.MODEL_NAME, input=api_messages, tools=tools, stream=False, store=True, previous_response_id=previous_response_id, ) save_chat_log( user_id=user_id, question=latest_user_msg.content, stream_mode=False, raw_response=repr(response), status="success", ) if not (response.output and len(response.output) > 0): raise HTTPException(status_code=500, detail="AI模型返回了空响应") message_content = "" for item in response.output: if hasattr(item, 'type') and item.type == 'doubao_app_call': if hasattr(item, 'blocks') and item.blocks: for block in item.blocks: if hasattr(block, 'type') and block.type == 'output_text' and hasattr(block, 'text'): message_content += block.text elif hasattr(item, 'type') and item.type == 'message': if hasattr(item, 'content'): if isinstance(item.content, list): for content_item in item.content: if hasattr(content_item, 'text'): message_content += content_item.text else: message_content += str(item.content) if not message_content: raise HTTPException(status_code=500, detail="无法从AI响应中提取文本内容") now = datetime.now() save_chat_history( user_id=user_id, session_id=session_id, role="assistant", content=message_content, timestamp=now, response_id=response.id, ) assistant_message = ChatMessage( role="assistant", content=message_content, timestamp=now, response_id=response.id, ) return ChatResponse( message=assistant_message, model=response.model, usage=response.usage.model_dump() if response.usage else None, response_id=response.id, ) except HTTPException: raise except Exception as e: error_message = f"处理聊天请求时发生错误: {str(e)}" save_chat_log( user_id=user_id, question=request.messages[-1].content if request.messages else "", stream_mode=request.stream, status="error", error=error_message, ) raise HTTPException(status_code=500, detail=error_message) @router.get("/models") async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]): try: models = client.models.list() return { "models": [model.id for model in models.data], "default_model": config.MODEL_NAME, "user": current_user.username } except Exception: return { "models": [config.MODEL_NAME], "default_model": config.MODEL_NAME, "note": "使用默认模型配置", "user": current_user.username } @router.get("/history") async def get_user_history( current_user: Annotated[User, Depends(get_current_active_user)], sessionId: str = Query(..., description="会话ID"), ) -> List[ChatMessage]: docs = get_chat_history(current_user.userId, sessionId) return [ ChatMessage( role=doc["role"], content=doc["content"], timestamp=doc.get("timestamp"), response_id=doc.get("response_id"), thinking=doc.get("thinking"), searching=doc.get("searching"), ) for doc in docs ] @router.delete("/history") async def clear_user_history( current_user: Annotated[User, Depends(get_current_active_user)], sessionId: str = Query(..., description="会话ID"), ): deleted_count = delete_chat_history(current_user.userId, sessionId) if deleted_count > 0: return {"message": "聊天历史已清空", "user": current_user.userId, "deleted_messages": deleted_count, "timestamp": datetime.now()} return {"message": "用户没有聊天历史", "user": current_user.userId, "deleted_messages": 0, "timestamp": datetime.now()} @router.get("/sessions") async def get_user_sessions( current_user: Annotated[User, Depends(get_current_active_user)], ): return get_sessions(current_user.userId) @router.get("/health") async def health_check(): return {"status": "healthy", "timestamp": datetime.now(), "version": "1.0.0", "model": config.MODEL_NAME} router.tags = ["聊天服务"] router.responses = { 401: {"description": "未授权 - 需要有效的JWT令牌"}, 429: {"description": "请求过多 - 配额已用完"}, 500: {"description": "服务器内部错误"} }