Browse Source

feat:流式输出

zhangwl 22 hours ago
parent
commit
fb75cbd8a2
1 changed files with 182 additions and 69 deletions
  1. 182 69
      app/routers/chat.py

+ 182 - 69
app/routers/chat.py

@@ -108,11 +108,31 @@ def convert_messages_for_api(messages: List[ChatMessage]) -> List[Dict[str, str]
     return [{"role": msg.role, "content": msg.content} for msg in messages]
 
 
+def get_latest_user_message(messages: List[ChatMessage]) -> Optional[ChatMessage]:
+    """
+    获取消息列表中最后一条user角色的消息
+
+    在多轮对话中,消息列表可能包含user和assistant的消息,
+    流式场景下客户端会预先添加空的assistant消息作为占位符,
+    此函数确保获取到最后一条用户发送的消息
+
+    Args:
+        messages (List[ChatMessage]): 消息列表
+
+    Returns:
+        Optional[ChatMessage]: 最后一条user角色的消息,如果不存在则返回None
+    """
+    for message in reversed(messages):
+        if message.role == "user":
+            return message
+    return None
+
+
 async def generate_stream_response(request: ChatRequest, username: str):
     """
     生成流式响应的异步生成器
 
-    这个函数处理流式AI响应,将OpenAI的流式输出转换为SSE格式
+    这个函数处理流式AI响应,将Ark API的流式输出转换为SSE格式
 
     Args:
         request (ChatRequest): 聊天请求对象
@@ -126,47 +146,115 @@ async def generate_stream_response(request: ChatRequest, username: str):
         客户端需要使用EventSource或类似技术接收流式数据
     """
     try:
-        # 转换消息格式为OpenAI API需要的格式
-        api_messages = convert_messages_for_api(request.messages)
+        # 获取最后一条user角色的消息
+        latest_user_msg = get_latest_user_message(request.messages)
+        if not latest_user_msg:
+            raise ValueError("请求中没有找到user角色的消息")
+
+        # 将用户消息添加到历史记录
+        user_message = ChatMessage(
+            role=latest_user_msg.role,
+            content=latest_user_msg.content,
+            timestamp=datetime.now()
+        )
+        chatHistory[username].append(user_message)
+
+        # 转换消息格式为API需要的格式
+        api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
+
+        tools = [{
+            "type": "doubao_app",
+            "feature": {
+                "ai_search": {
+                    "type": "enabled",
+                    "role_description": "你是浙江云悦有限公司助手,专业解答云悦问题"
+                }
+            },
+            "user_location": {
+                "type": "approximate",
+                "country": "中国",
+                "region": "浙江",
+                "city": "杭州"
+            }
+        }]
+
+        # 获取上一轮对话的response_id,用于多轮对话的上下文关联
+        previous_response_id = None
+        if username in chatHistory and len(chatHistory[username]) > 0:
+            # 从后往前查找最后一条assistant消息的response_id
+            for message in reversed(chatHistory[username]):
+                if message.role == "assistant" and message.response_id:
+                    previous_response_id = message.response_id
+                    break
 
-        # 调用OpenAI流式API
         # stream=True 启用流式输出,API会返回一个迭代器
-        stream = client.chat.completions.create(
-            model=request.model or config.MODEL_NAME,  # 使用指定模型或默认模型
-            messages=api_messages,  # 对话历史
-            max_tokens=request.max_tokens,  # 最大token数
-            temperature=request.temperature,  # 创造性温度
-            stream=True  # 启用流式输出
+        stream = client.responses.create(
+            model=config.MODEL_NAME,
+            input=api_messages,
+            tools=tools,
+            stream=True,
+            store=True,  # 存储当前对话上下文。此字段不存储tools,每次调用仍需给tools赋值。
+            previous_response_id=previous_response_id,
         )
 
-        # 用于累积完整的回答内容
+
+        # 用于累积完整的回答内容和response_id
         accumulated_content = ""
+        response_id = None
 
         # 遍历流式响应的每个数据块
         for chunk in stream:
-            # 检查数据块是否包含有效内容
-            if chunk.choices and chunk.choices[0].delta.content:
-                # 提取本次数据块的内容
-                chunk_content = chunk.choices[0].delta.content
-
-                # 累积到完整内容中
-                accumulated_content += chunk_content
-
-                # 构建流式响应数据对象
-                response_data = StreamResponse(
-                    content=chunk.choices[0].delta.content,  # 本次片段内容
-                    finished=False,  # 标记为未完成
-                    model=request.model or config.MODEL_NAME,  # 使用的模型
-                    timestamp=datetime.now()  # 当前时间戳
-                )
+            # 处理不同类型的流式事件
+            chunk_dict = chunk.__dict__ if hasattr(chunk, '__dict__') else {}
+            event_type = chunk_dict.get('type', '')
+
+            # 处理文本内容增量事件(普通文本)
+            if event_type == 'response.output_text.delta':
+                delta_text = chunk_dict.get('delta', '')
+                if delta_text:
+                    # 累积到完整内容中
+                    accumulated_content += delta_text
+
+                    # 构建流式响应数据对象
+                    response_data = StreamResponse(
+                        content=delta_text,  # 本次片段内容
+                        finished=False,  # 标记为未完成
+                        model= config.MODEL_NAME,  # 使用的模型
+                        timestamp=datetime.now()  # 当前时间戳
+                    )
 
-                # 格式化为SSE格式并发送
-                # SSE格式: "data: {json_data}\n\n"
-                yield f"data: {response_data.model_dump_json()}\n\n"
+                    # 格式化为SSE格式并发送
+                    # SSE格式: "data: {json_data}\n\n"
+                    yield f"data: {response_data.model_dump_json()}\n\n"
+
+                    # 异步让出控制权,避免阻塞事件循环
+                    await asyncio.sleep(0.01)
+
+            # 处理DoubaoApp调用的文本输出增量事件
+            elif event_type == 'response.doubao_app_call_output_text.delta':
+                delta_text = chunk_dict.get('delta', '')
+                if delta_text:
+                    # 累积到完整内容中
+                    accumulated_content += delta_text
+
+                    # 构建流式响应数据对象
+                    response_data = StreamResponse(
+                        content=delta_text,  # 本次片段内容
+                        finished=False,  # 标记为未完成
+                        model= config.MODEL_NAME,  # 使用的模型
+                        timestamp=datetime.now()  # 当前时间戳
+                    )
+
+                    # 格式化为SSE格式并发送
+                    yield f"data: {response_data.model_dump_json()}\n\n"
 
-                # 异步让出控制权,避免阻塞事件循环
-                # 这对于处理大量并发请求很重要
-                await asyncio.sleep(0.01)
+                    # 异步让出控制权,避免阻塞事件循环
+                    await asyncio.sleep(0.01)
+
+            # 处理响应完成事件,获取response_id
+            elif event_type == 'response.completed':
+                if 'response' in chunk_dict and hasattr(chunk_dict['response'], 'id'):
+                    response_id = chunk_dict['response'].id
 
         # 流式响应结束后的处理
         if accumulated_content:
@@ -174,22 +262,27 @@ async def generate_stream_response(request: ChatRequest, username: str):
             final_response = StreamResponse(
                 content='',  # 结束信号不包含内容
                 finished=True,  # 标记为已完成
-                model=request.model or config.MODEL_NAME,
+                model= config.MODEL_NAME,
                 timestamp=datetime.now()
             )
 
-            # 将完整的AI回复保存到用户的聊天历史中
+            # 将完整的AI回复保存到用户的聊天历史中,包含response_id
             chatHistory[username].append(
                 ChatMessage(
                     role="assistant",
                     content=accumulated_content,
-                    timestamp=datetime.now()
+                    timestamp=datetime.now(),
+                    response_id=response_id  # 保存response_id用于后续多轮对话
                 )
             )
 
+
             # 发送结束信号
             yield f"data: {final_response.model_dump_json()}\n\n"
 
+            # 在控制台输出提示
+            print("流式内容已全部输出")
+
     except Exception as e:
         # 流式响应过程中的错误处理
         # 构建错误响应并发送给客户端
@@ -249,14 +342,6 @@ async def chat(
         if request.stream:
             # ===== 流式输出处理 =====
 
-            # 将用户的最新消息添加到历史记录
-            user_message = ChatMessage(
-                role=request.messages[-1].role,
-                content=request.messages[-1].content,
-                timestamp=datetime.now()
-            )
-            # chatHistory[username].append(user_message)
-
             # 返回流式响应
             # StreamingResponse 用于处理SSE协议
             return StreamingResponse(
@@ -272,21 +357,36 @@ async def chat(
         else:
             # ===== 非流式输出处理 =====
 
+            # 获取最后一条user角色的消息
+            latest_user_msg = get_latest_user_message(request.messages)
+            if not latest_user_msg:
+                raise ValueError("请求中没有找到user角色的消息")
+
             # 将用户消息添加到历史记录
             user_message = ChatMessage(
-                role=request.messages[-1].role,
-                content=request.messages[-1].content,
+                role=latest_user_msg.role,
+                content=latest_user_msg.content,
                 timestamp=datetime.now()
             )
             chatHistory[username].append(user_message)
 
-            # 转换消息格式为OpenAI API需要的格式
-            # api_messages = convert_messages_for_api(request.messages)
-            api_messages = [{"role": request.messages[-1].role, "content": request.messages[-1].content}]
+            # 转换消息格式为API需要的格式
+            api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
 
             tools = [{
-                "type": "web_search",
-                "max_keyword": 2,  # 可选参数,用于限制一轮搜索的最大关键词数量
+                "type": "doubao_app",
+                "feature": {
+                    "ai_search": {
+                        "type": "enabled",
+                        "role_description": "你是浙江云悦有限公司助手,专业解答云悦问题"
+                    }
+                },
+                "user_location": {
+                    "type": "approximate",
+                    "country": "中国",
+                    "region": "浙江",
+                    "city": "杭州"
+                }
             }]
 
             # 获取上一轮对话的response_id,用于多轮对话的上下文关联
@@ -303,33 +403,40 @@ async def chat(
                 input=api_messages,
                 tools=tools,
                 stream=False,
+                store=True,  # 存储当前对话上下文。此字段不存储tools,每次调用仍需给tools赋值。
                 previous_response_id=previous_response_id,
             )
 
             # 检查API响应是否有效
             if response.output and len(response.output) > 0:
-                # 从output中找到最后一条消息(ResponseOutputMessage类型)
-                last_message = None
-                for item in reversed(response.output):
-                    if hasattr(item, 'type') and item.type == 'message':
-                        last_message = item
-                        break
-
-                if last_message and hasattr(last_message, 'content'):
-                    # 提取消息内容
-                    message_content = ""
-                    if isinstance(last_message.content, list):
-                        # content是列表,提取所有文本内容
-                        for content_item in last_message.content:
-                            if hasattr(content_item, 'text'):
-                                message_content += content_item.text
-                    else:
-                        message_content = str(last_message.content)
-
+                # 从output中提取文本内容
+                message_content = ""
+
+                for item in response.output:
+                    # 处理 ItemDoubaoAppCall 类型(包含搜索结果和文本输出)
+                    if hasattr(item, 'type') and item.type == 'doubao_app_call':
+                        if hasattr(item, 'blocks') and item.blocks:
+                            # 从blocks中找到output_text类型的块
+                            for block in item.blocks:
+                                if hasattr(block, 'type') and block.type == 'output_text':
+                                    if hasattr(block, 'text'):
+                                        message_content += block.text
+
+                    # 处理其他类型的消息项
+                    elif hasattr(item, 'type') and item.type == 'message':
+                        if hasattr(item, 'content'):
+                            if isinstance(item.content, list):
+                                for content_item in item.content:
+                                    if hasattr(content_item, 'text'):
+                                        message_content += content_item.text
+                            else:
+                                message_content += str(item.content)
+
+                if message_content:
                     # 构建AI助手的回复消息,包含response_id用于多轮对话
                     assistant_message = ChatMessage(
                         role="assistant",
-                        content=message_content or "",
+                        content=message_content,
                         timestamp=datetime.now(),
                         response_id=response.id  # 保存response_id用于后续多轮对话
                     )
@@ -346,6 +453,12 @@ async def chat(
                     )
 
                     return chat_response
+                else:
+                    # 没有提取到文本内容的错误处理
+                    raise HTTPException(
+                        status_code=500,
+                        detail="无法从AI响应中提取文本内容"
+                    )
             else:
                 # API返回空响应的错误处理
                 raise HTTPException(