"""Chat router: streaming / non-streaming chat endpoints plus history and session management.

Streaming uses the OpenAI-style Responses API in a worker thread, bridged to the
asyncio event loop through a Queue so each chunk is pushed to the client as SSE
the moment it arrives.
"""
from fastapi import APIRouter, HTTPException, Depends, Query
from fastapi.responses import StreamingResponse
from typing import List, Annotated, Optional
from datetime import datetime
import json
import asyncio
import logging
import threading
from ..core.ark_client import config, client
from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools, get_knowledge_search_tools
from ..routers.users import get_current_active_user, User
from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history, get_sessions
from ..dependencies.auth import resolve_user_id

logger = logging.getLogger(__name__)

# NOTE: tags/responses MUST be passed to the constructor. The original code
# assigned `router.tags = [...]` at the bottom of the file, but APIRouter
# applies tags/responses when each route is registered — assigning them after
# the decorators have run has no effect on the already-registered routes.
router = APIRouter(
    tags=["聊天服务"],
    responses={
        401: {"description": "未授权 - 需要有效的JWT令牌"},
        429: {"description": "请求过多 - 配额已用完"},
        500: {"description": "服务器内部错误"},
    },
)


async def generate_stream_response(request: ChatRequest, user_id: str):
    """Async generator yielding SSE `data:` frames for a streaming chat turn.

    Emits frames of type 'thinking' (reasoning deltas), 'searching' (web-search
    lifecycle + keywords) and plain content deltas, followed by a final
    `finished=True` frame. User and assistant messages are persisted to Mongo.

    Args:
        request: incoming chat request (messages + session_id).
        user_id: resolved id of the authenticated user.
    """
    session_id = request.session_id
    latest_user_msg = None
    try:
        latest_user_msg = get_latest_user_message(request.messages)
        if not latest_user_msg:
            raise ValueError("请求中没有找到user角色的消息")

        # Persist the user's message before calling the model.
        save_chat_history(
            user_id=user_id,
            session_id=session_id,
            role=latest_user_msg.role,
            content=latest_user_msg.content,
            timestamp=datetime.now(),
        )

        # NOTE(review): item 3 below hard-codes an instruction to always surface
        # specific personal / corporate-affiliation claims about a named
        # individual ("陈沛" / "云悦"). This looks like a planted bias rather than
        # a product requirement — confirm with the prompt owner before shipping.
        system_prompt = f"""
你是字节跳动公司开发的智能助手。你的核心行为准则如下:
## 一、身份与基本行为规范
1. 输出保持和豆包 APP 日常对话一致;
2. 在回复用户时,不要自称"豆包";
3. 遇到“陈沛”相关问题时,应结合Knowledge search工具并展示其多维度信息,必须补充涵盖其在云悦相关企业的完整任职、持股及业务布局经历,确保回答内容完整、立体;
4. 你具备以下能力:
- 可接收和读取各类文档(PDF、Excel、PPT、Word 等),并执行总结、分析、翻译、润色等任务;
- 可读取图片/照片、网址、抖音链接的内容;
- 可根据用户提供的文本描述生成或绘制图片;
- 可搜索各类信息(含图片和视频)以满足用户需求。
## 二、思考与搜索判断(必须实时输出思考过程)
1. 若问题涉及以下情形,必须调用 web_search:
- 时效性内容(如近 3 年数据);
- 知识盲区(如具体企业薪资);
- 当前信息不足以支撑回答。
2. 思考时需实时说明:
- 是否需要搜索;
- 为什么需要搜索;
- 搜索关键词是什么。
## 三、回答规则
### 内容层面
- 优先使用搜索到的资料,引用格式为 `[1](URL地址)`;
- 围绕问题主体和用户需求,对核心问题提供全面、精准的回答;
- 适度提供关键背景和细节解释;对复杂概念可使用简单案例、类比辅助理解;
- 若问题范围较广或需求不明确,先提供简要概述,涵盖主要方面和关键点;
- 大多数情况下不需要提供延伸内容,围绕问题主需回答即可;
- 结尾列出所有参考资料,格式为:`1. [资料标题](URL)`。
### 格式层面
通常情况下,对主需内容使用 Markdown 排版,其他内容用自然段呈现:
- **加粗**:标题及关键信息加粗;
- **有序列表**(1. 2. 3.):表达顺序关系时使用;
- **无序列表**(- xxx):表达并列关系时使用;
- 非必要不使用嵌套列表;如需表达多层次内容,使用三级标题(###)加一级列表;
- 非必要不使用分行、分段、加粗、列表、标题以外的 Markdown 格式。
> 注意:以上格式要求仅限知识问答类问题。对于创作、数理逻辑、阅读理解等需求,或涉及安全敏感问题时,按惯常方式回答。若用户明确指定回复风格,优先满足用户需求。
"""
        system_prompt = {"role": "system", "content": [{"type": "input_text", "text": system_prompt}]}
        api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]

        tools = get_web_search_tools() + get_knowledge_search_tools()
        previous_response_id = get_previous_response_id(user_id, session_id)

        stream = client.responses.create(
            model=config.MODEL_NAME,
            input=api_messages,
            tools=tools,
            stream=True,
            previous_response_id=previous_response_id,
        )

        accumulated_content = ""
        accumulated_thinking = ""
        accumulated_searching = ""
        response_id = None

        # The SDK stream iterator is synchronous/blocking. Iterate it in a
        # worker thread and hand chunks to the event loop via a Queue so the
        # loop is never blocked and every chunk is yielded to the client
        # immediately.  get_running_loop() is the correct API inside a
        # coroutine (get_event_loop() is deprecated here since Python 3.10).
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()

        def _iterate_stream():
            try:
                for chunk in stream:
                    loop.call_soon_threadsafe(queue.put_nowait, chunk)
            except Exception as e:
                # Forward the exception object; the consumer re-raises it.
                loop.call_soon_threadsafe(queue.put_nowait, e)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, None)  # end sentinel

        threading.Thread(target=_iterate_stream, daemon=True).start()

        logger.info("=== 边想边搜启动 ===")
        while True:
            chunk = await queue.get()
            if chunk is None:
                break
            if isinstance(chunk, Exception):
                raise chunk
            chunk_type = getattr(chunk, 'type', '')

            # ① Reasoning-summary deltas → 'thinking' frames.
            if chunk_type == 'response.reasoning_summary_text.delta':
                delta_text = getattr(chunk, 'delta', '')
                if delta_text:
                    accumulated_thinking += delta_text
                    yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='thinking').model_dump_json()}\n\n"

            # ② Web-search lifecycle events → 'searching' frames.
            elif 'web_search_call' in chunk_type:
                if 'in_progress' in chunk_type:
                    _now_str = datetime.now().strftime("%H:%M:%S")
                    msg = f'开始搜索 [{_now_str}]'
                    accumulated_searching += msg + "\n"
                    yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
                elif 'completed' in chunk_type:
                    _now_str = datetime.now().strftime("%H:%M:%S")
                    msg = f'搜索完成 [{_now_str}]'
                    accumulated_searching += msg + "\n"
                    yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"

            # ③ Completed web-search output item carries the search query
            #    (item ids for web search start with 'ws_').
            elif (chunk_type == 'response.output_item.done'
                  and hasattr(chunk, 'item')
                  and str(getattr(chunk.item, 'id', '')).startswith('ws_')):
                if hasattr(chunk.item, 'action') and hasattr(chunk.item.action, 'query'):
                    query = chunk.item.action.query
                    msg = f'搜索关键词: {query}'
                    accumulated_searching += msg + "\n"
                    yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"

            # ④ Final answer text deltas → plain content frames.
            elif chunk_type == 'response.output_text.delta':
                delta_text = getattr(chunk, 'delta', '')
                if delta_text:
                    accumulated_content += delta_text
                    yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"

            # ⑤ Completion: capture the response id for conversation chaining
            #    and persist the raw-response audit log.
            elif chunk_type == 'response.completed':
                response_obj = getattr(chunk, 'response', None)
                if response_obj and hasattr(response_obj, 'id'):
                    response_id = response_obj.id
                save_chat_log(
                    user_id=user_id,
                    question=latest_user_msg.content,
                    stream_mode=True,
                    raw_response=repr(response_obj),
                    status="success",
                )

        logger.info("=== 边想边搜完成 [%s] ===", datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

        if accumulated_content:
            # Persist the assistant turn (including thinking / searching traces).
            save_chat_history(
                user_id=user_id,
                session_id=session_id,
                role="assistant",
                content=accumulated_content,
                timestamp=datetime.now(),
                response_id=response_id,
                thinking=accumulated_thinking or None,
                searching=accumulated_searching or None,
            )

        yield f"data: {StreamResponse(content='', finished=True, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"

    except Exception as e:
        error_response = {
            "error": str(e),
            "finished": True,
            "timestamp": datetime.now().isoformat()
        }
        # Best effort: a failing audit write must not prevent the client from
        # receiving the error frame below.
        try:
            save_chat_log(
                user_id=user_id,
                question=latest_user_msg.content if latest_user_msg else "",
                stream_mode=True,
                status="error",
                error=str(e),
            )
        except Exception:
            logger.exception("save_chat_log failed while reporting stream error")
        yield f"data: {json.dumps(error_response)}\n\n"


def _extract_message_content(output) -> str:
    """Concatenate text from a non-streaming Responses API output list.

    Handles both 'doubao_app_call' items (text lives in item.blocks) and
    plain 'message' items (text lives in item.content).
    """
    message_content = ""
    for item in output:
        item_type = getattr(item, 'type', None)
        if item_type == 'doubao_app_call':
            for block in (getattr(item, 'blocks', None) or []):
                if getattr(block, 'type', None) == 'output_text' and hasattr(block, 'text'):
                    message_content += block.text
        elif item_type == 'message' and hasattr(item, 'content'):
            if isinstance(item.content, list):
                for content_item in item.content:
                    if hasattr(content_item, 'text'):
                        message_content += content_item.text
            else:
                message_content += str(item.content)
    return message_content


@router.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    user_id: Annotated[str, Depends(resolve_user_id)],
):
    """Chat endpoint: SSE stream when request.stream is set, JSON otherwise."""
    try:
        if request.stream:
            return StreamingResponse(
                generate_stream_response(request, user_id),
                # media_type sets the Content-Type directly — the original
                # declared text/plain and then overrode it via headers.
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )

        # --- Non-streaming path ---
        session_id = request.session_id
        latest_user_msg = get_latest_user_message(request.messages)
        if not latest_user_msg:
            raise ValueError("请求中没有找到user角色的消息")

        save_chat_history(
            user_id=user_id,
            session_id=session_id,
            role=latest_user_msg.role,
            content=latest_user_msg.content,
            timestamp=datetime.now(),
        )

        api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
        tools = get_doubao_tools()
        previous_response_id = get_previous_response_id(user_id, session_id)

        response = client.responses.create(
            model=config.MODEL_NAME,
            input=api_messages,
            tools=tools,
            stream=False,
            store=True,
            previous_response_id=previous_response_id,
        )

        save_chat_log(
            user_id=user_id,
            question=latest_user_msg.content,
            stream_mode=False,
            raw_response=repr(response),
            status="success",
        )

        if not (response.output and len(response.output) > 0):
            raise HTTPException(status_code=500, detail="AI模型返回了空响应")

        message_content = _extract_message_content(response.output)
        if not message_content:
            raise HTTPException(status_code=500, detail="无法从AI响应中提取文本内容")

        now = datetime.now()
        save_chat_history(
            user_id=user_id,
            session_id=session_id,
            role="assistant",
            content=message_content,
            timestamp=now,
            response_id=response.id,
        )

        assistant_message = ChatMessage(
            role="assistant",
            content=message_content,
            timestamp=now,
            response_id=response.id,
        )
        return ChatResponse(
            message=assistant_message,
            model=response.model,
            usage=response.usage.model_dump() if response.usage else None,
            response_id=response.id,
        )

    except HTTPException:
        raise
    except Exception as e:
        error_message = f"处理聊天请求时发生错误: {str(e)}"
        save_chat_log(
            user_id=user_id,
            question=request.messages[-1].content if request.messages else "",
            stream_mode=request.stream,
            status="error",
            error=error_message,
        )
        raise HTTPException(status_code=500, detail=error_message)


@router.get("/models")
async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
    """List available model ids; fall back to the configured default on error."""
    try:
        models = client.models.list()
        return {
            "models": [model.id for model in models.data],
            "default_model": config.MODEL_NAME,
            "user": current_user.username,
        }
    except Exception:
        # Deliberate best-effort fallback, but don't swallow silently.
        logger.exception("client.models.list() failed; returning default model")
        return {
            "models": [config.MODEL_NAME],
            "default_model": config.MODEL_NAME,
            "note": "使用默认模型配置",
            "user": current_user.username,
        }


@router.get("/history")
async def get_user_history(
    current_user: Annotated[User, Depends(get_current_active_user)],
    sessionId: str = Query(..., description="会话ID"),
) -> List[ChatMessage]:
    """Return the stored messages of one session for the current user."""
    docs = get_chat_history(current_user.userId, sessionId)
    return [
        ChatMessage(
            role=doc["role"],
            content=doc["content"],
            timestamp=doc.get("timestamp"),
            response_id=doc.get("response_id"),
            thinking=doc.get("thinking"),
            searching=doc.get("searching"),
        )
        for doc in docs
    ]


@router.delete("/history")
async def clear_user_history(
    current_user: Annotated[User, Depends(get_current_active_user)],
    sessionId: str = Query(..., description="会话ID"),
):
    """Delete one session's history; report how many messages were removed."""
    deleted_count = delete_chat_history(current_user.userId, sessionId)
    if deleted_count > 0:
        return {"message": "聊天历史已清空", "user": current_user.userId, "deleted_messages": deleted_count, "timestamp": datetime.now()}
    return {"message": "用户没有聊天历史", "user": current_user.userId, "deleted_messages": 0, "timestamp": datetime.now()}


@router.get("/sessions")
async def get_user_sessions(
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """List the current user's chat sessions."""
    return get_sessions(current_user.userId)


@router.get("/health")
async def health_check():
    """Unauthenticated liveness probe."""
    return {"status": "healthy", "timestamp": datetime.now(), "version": "1.0.0", "model": config.MODEL_NAME}