|
@@ -108,11 +108,31 @@ def convert_messages_for_api(messages: List[ChatMessage]) -> List[Dict[str, str]
|
|
|
return [{"role": msg.role, "content": msg.content} for msg in messages]
|
|
return [{"role": msg.role, "content": msg.content} for msg in messages]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def get_latest_user_message(messages: List[ChatMessage]) -> Optional[ChatMessage]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 获取消息列表中最后一条user角色的消息
|
|
|
|
|
+
|
|
|
|
|
+ 在多轮对话中,消息列表可能包含user和assistant的消息,
|
|
|
|
|
+ 流式场景下客户端会预先添加空的assistant消息作为占位符,
|
|
|
|
|
+ 此函数确保获取到最后一条用户发送的消息
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ messages (List[ChatMessage]): 消息列表
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ Optional[ChatMessage]: 最后一条user角色的消息,如果不存在则返回None
|
|
|
|
|
+ """
|
|
|
|
|
+ for message in reversed(messages):
|
|
|
|
|
+ if message.role == "user":
|
|
|
|
|
+ return message
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
async def generate_stream_response(request: ChatRequest, username: str):
|
|
async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
"""
|
|
"""
|
|
|
生成流式响应的异步生成器
|
|
生成流式响应的异步生成器
|
|
|
|
|
|
|
|
- 这个函数处理流式AI响应,将OpenAI的流式输出转换为SSE格式
|
|
|
|
|
|
|
+ 这个函数处理流式AI响应,将Ark API的流式输出转换为SSE格式
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
request (ChatRequest): 聊天请求对象
|
|
request (ChatRequest): 聊天请求对象
|
|
@@ -126,47 +146,115 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
客户端需要使用EventSource或类似技术接收流式数据
|
|
客户端需要使用EventSource或类似技术接收流式数据
|
|
|
"""
|
|
"""
|
|
|
try:
|
|
try:
|
|
|
- # 转换消息格式为OpenAI API需要的格式
|
|
|
|
|
- api_messages = convert_messages_for_api(request.messages)
|
|
|
|
|
|
|
+ # 获取最后一条user角色的消息
|
|
|
|
|
+ latest_user_msg = get_latest_user_message(request.messages)
|
|
|
|
|
+ if not latest_user_msg:
|
|
|
|
|
+ raise ValueError("请求中没有找到user角色的消息")
|
|
|
|
|
+
|
|
|
|
|
+ # 将用户消息添加到历史记录
|
|
|
|
|
+ user_message = ChatMessage(
|
|
|
|
|
+ role=latest_user_msg.role,
|
|
|
|
|
+ content=latest_user_msg.content,
|
|
|
|
|
+ timestamp=datetime.now()
|
|
|
|
|
+ )
|
|
|
|
|
+ chatHistory[username].append(user_message)
|
|
|
|
|
+
|
|
|
|
|
+ # 转换消息格式为API需要的格式
|
|
|
|
|
+ api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
|
|
|
|
|
+
|
|
|
|
|
+ tools = [{
|
|
|
|
|
+ "type": "doubao_app",
|
|
|
|
|
+ "feature": {
|
|
|
|
|
+ "ai_search": {
|
|
|
|
|
+ "type": "enabled",
|
|
|
|
|
+ "role_description": "你是浙江云悦有限公司助手,专业解答云悦问题"
|
|
|
|
|
+ }
|
|
|
|
|
+ },
|
|
|
|
|
+ "user_location": {
|
|
|
|
|
+ "type": "approximate",
|
|
|
|
|
+ "country": "中国",
|
|
|
|
|
+ "region": "浙江",
|
|
|
|
|
+ "city": "杭州"
|
|
|
|
|
+ }
|
|
|
|
|
+ }]
|
|
|
|
|
+
|
|
|
|
|
+ # 获取上一轮对话的response_id,用于多轮对话的上下文关联
|
|
|
|
|
+ previous_response_id = None
|
|
|
|
|
+ if username in chatHistory and len(chatHistory[username]) > 0:
|
|
|
|
|
+ # 从后往前查找最后一条assistant消息的response_id
|
|
|
|
|
+ for message in reversed(chatHistory[username]):
|
|
|
|
|
+ if message.role == "assistant" and message.response_id:
|
|
|
|
|
+ previous_response_id = message.response_id
|
|
|
|
|
+ break
|
|
|
|
|
|
|
|
- # 调用OpenAI流式API
|
|
|
|
|
# stream=True 启用流式输出,API会返回一个迭代器
|
|
# stream=True 启用流式输出,API会返回一个迭代器
|
|
|
- stream = client.chat.completions.create(
|
|
|
|
|
- model=request.model or config.MODEL_NAME, # 使用指定模型或默认模型
|
|
|
|
|
- messages=api_messages, # 对话历史
|
|
|
|
|
- max_tokens=request.max_tokens, # 最大token数
|
|
|
|
|
- temperature=request.temperature, # 创造性温度
|
|
|
|
|
- stream=True # 启用流式输出
|
|
|
|
|
|
|
+ stream = client.responses.create(
|
|
|
|
|
+ model=config.MODEL_NAME,
|
|
|
|
|
+ input=api_messages,
|
|
|
|
|
+ tools=tools,
|
|
|
|
|
+ stream=True,
|
|
|
|
|
+ store=True, # 存储当前对话上下文。此字段不存储tools,每次调用仍需给tools赋值。
|
|
|
|
|
+ previous_response_id=previous_response_id,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # 用于累积完整的回答内容
|
|
|
|
|
|
|
+
|
|
|
|
|
+ # 用于累积完整的回答内容和response_id
|
|
|
accumulated_content = ""
|
|
accumulated_content = ""
|
|
|
|
|
+ response_id = None
|
|
|
|
|
|
|
|
# 遍历流式响应的每个数据块
|
|
# 遍历流式响应的每个数据块
|
|
|
for chunk in stream:
|
|
for chunk in stream:
|
|
|
- # 检查数据块是否包含有效内容
|
|
|
|
|
- if chunk.choices and chunk.choices[0].delta.content:
|
|
|
|
|
- # 提取本次数据块的内容
|
|
|
|
|
- chunk_content = chunk.choices[0].delta.content
|
|
|
|
|
-
|
|
|
|
|
- # 累积到完整内容中
|
|
|
|
|
- accumulated_content += chunk_content
|
|
|
|
|
-
|
|
|
|
|
- # 构建流式响应数据对象
|
|
|
|
|
- response_data = StreamResponse(
|
|
|
|
|
- content=chunk.choices[0].delta.content, # 本次片段内容
|
|
|
|
|
- finished=False, # 标记为未完成
|
|
|
|
|
- model=request.model or config.MODEL_NAME, # 使用的模型
|
|
|
|
|
- timestamp=datetime.now() # 当前时间戳
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ # 处理不同类型的流式事件
|
|
|
|
|
+ chunk_dict = chunk.__dict__ if hasattr(chunk, '__dict__') else {}
|
|
|
|
|
+ event_type = chunk_dict.get('type', '')
|
|
|
|
|
+
|
|
|
|
|
+ # 处理文本内容增量事件(普通文本)
|
|
|
|
|
+ if event_type == 'response.output_text.delta':
|
|
|
|
|
+ delta_text = chunk_dict.get('delta', '')
|
|
|
|
|
+ if delta_text:
|
|
|
|
|
+ # 累积到完整内容中
|
|
|
|
|
+ accumulated_content += delta_text
|
|
|
|
|
+
|
|
|
|
|
+ # 构建流式响应数据对象
|
|
|
|
|
+ response_data = StreamResponse(
|
|
|
|
|
+ content=delta_text, # 本次片段内容
|
|
|
|
|
+ finished=False, # 标记为未完成
|
|
|
|
|
+ model= config.MODEL_NAME, # 使用的模型
|
|
|
|
|
+ timestamp=datetime.now() # 当前时间戳
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- # 格式化为SSE格式并发送
|
|
|
|
|
- # SSE格式: "data: {json_data}\n\n"
|
|
|
|
|
- yield f"data: {response_data.model_dump_json()}\n\n"
|
|
|
|
|
|
|
+ # 格式化为SSE格式并发送
|
|
|
|
|
+ # SSE格式: "data: {json_data}\n\n"
|
|
|
|
|
+ yield f"data: {response_data.model_dump_json()}\n\n"
|
|
|
|
|
+
|
|
|
|
|
+ # 异步让出控制权,避免阻塞事件循环
|
|
|
|
|
+ await asyncio.sleep(0.01)
|
|
|
|
|
+
|
|
|
|
|
+ # 处理DoubaoApp调用的文本输出增量事件
|
|
|
|
|
+ elif event_type == 'response.doubao_app_call_output_text.delta':
|
|
|
|
|
+ delta_text = chunk_dict.get('delta', '')
|
|
|
|
|
+ if delta_text:
|
|
|
|
|
+ # 累积到完整内容中
|
|
|
|
|
+ accumulated_content += delta_text
|
|
|
|
|
+
|
|
|
|
|
+ # 构建流式响应数据对象
|
|
|
|
|
+ response_data = StreamResponse(
|
|
|
|
|
+ content=delta_text, # 本次片段内容
|
|
|
|
|
+ finished=False, # 标记为未完成
|
|
|
|
|
+ model= config.MODEL_NAME, # 使用的模型
|
|
|
|
|
+ timestamp=datetime.now() # 当前时间戳
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 格式化为SSE格式并发送
|
|
|
|
|
+ yield f"data: {response_data.model_dump_json()}\n\n"
|
|
|
|
|
|
|
|
- # 异步让出控制权,避免阻塞事件循环
|
|
|
|
|
- # 这对于处理大量并发请求很重要
|
|
|
|
|
- await asyncio.sleep(0.01)
|
|
|
|
|
|
|
+ # 异步让出控制权,避免阻塞事件循环
|
|
|
|
|
+ await asyncio.sleep(0.01)
|
|
|
|
|
+
|
|
|
|
|
+ # 处理响应完成事件,获取response_id
|
|
|
|
|
+ elif event_type == 'response.completed':
|
|
|
|
|
+ if 'response' in chunk_dict and hasattr(chunk_dict['response'], 'id'):
|
|
|
|
|
+ response_id = chunk_dict['response'].id
|
|
|
|
|
|
|
|
# 流式响应结束后的处理
|
|
# 流式响应结束后的处理
|
|
|
if accumulated_content:
|
|
if accumulated_content:
|
|
@@ -174,22 +262,27 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
final_response = StreamResponse(
|
|
final_response = StreamResponse(
|
|
|
content='', # 结束信号不包含内容
|
|
content='', # 结束信号不包含内容
|
|
|
finished=True, # 标记为已完成
|
|
finished=True, # 标记为已完成
|
|
|
- model=request.model or config.MODEL_NAME,
|
|
|
|
|
|
|
+ model= config.MODEL_NAME,
|
|
|
timestamp=datetime.now()
|
|
timestamp=datetime.now()
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # 将完整的AI回复保存到用户的聊天历史中
|
|
|
|
|
|
|
+ # 将完整的AI回复保存到用户的聊天历史中,包含response_id
|
|
|
chatHistory[username].append(
|
|
chatHistory[username].append(
|
|
|
ChatMessage(
|
|
ChatMessage(
|
|
|
role="assistant",
|
|
role="assistant",
|
|
|
content=accumulated_content,
|
|
content=accumulated_content,
|
|
|
- timestamp=datetime.now()
|
|
|
|
|
|
|
+ timestamp=datetime.now(),
|
|
|
|
|
+ response_id=response_id # 保存response_id用于后续多轮对话
|
|
|
)
|
|
)
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+
|
|
|
# 发送结束信号
|
|
# 发送结束信号
|
|
|
yield f"data: {final_response.model_dump_json()}\n\n"
|
|
yield f"data: {final_response.model_dump_json()}\n\n"
|
|
|
|
|
|
|
|
|
|
+ # 在控制台输出提示
|
|
|
|
|
+ print("流式内容已全部输出")
|
|
|
|
|
+
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
# 流式响应过程中的错误处理
|
|
# 流式响应过程中的错误处理
|
|
|
# 构建错误响应并发送给客户端
|
|
# 构建错误响应并发送给客户端
|
|
@@ -249,14 +342,6 @@ async def chat(
|
|
|
if request.stream:
|
|
if request.stream:
|
|
|
# ===== 流式输出处理 =====
|
|
# ===== 流式输出处理 =====
|
|
|
|
|
|
|
|
- # 将用户的最新消息添加到历史记录
|
|
|
|
|
- user_message = ChatMessage(
|
|
|
|
|
- role=request.messages[-1].role,
|
|
|
|
|
- content=request.messages[-1].content,
|
|
|
|
|
- timestamp=datetime.now()
|
|
|
|
|
- )
|
|
|
|
|
- # chatHistory[username].append(user_message)
|
|
|
|
|
-
|
|
|
|
|
# 返回流式响应
|
|
# 返回流式响应
|
|
|
# StreamingResponse 用于处理SSE协议
|
|
# StreamingResponse 用于处理SSE协议
|
|
|
return StreamingResponse(
|
|
return StreamingResponse(
|
|
@@ -272,21 +357,36 @@ async def chat(
|
|
|
else:
|
|
else:
|
|
|
# ===== 非流式输出处理 =====
|
|
# ===== 非流式输出处理 =====
|
|
|
|
|
|
|
|
|
|
+ # 获取最后一条user角色的消息
|
|
|
|
|
+ latest_user_msg = get_latest_user_message(request.messages)
|
|
|
|
|
+ if not latest_user_msg:
|
|
|
|
|
+ raise ValueError("请求中没有找到user角色的消息")
|
|
|
|
|
+
|
|
|
# 将用户消息添加到历史记录
|
|
# 将用户消息添加到历史记录
|
|
|
user_message = ChatMessage(
|
|
user_message = ChatMessage(
|
|
|
- role=request.messages[-1].role,
|
|
|
|
|
- content=request.messages[-1].content,
|
|
|
|
|
|
|
+ role=latest_user_msg.role,
|
|
|
|
|
+ content=latest_user_msg.content,
|
|
|
timestamp=datetime.now()
|
|
timestamp=datetime.now()
|
|
|
)
|
|
)
|
|
|
chatHistory[username].append(user_message)
|
|
chatHistory[username].append(user_message)
|
|
|
|
|
|
|
|
- # 转换消息格式为OpenAI API需要的格式
|
|
|
|
|
- # api_messages = convert_messages_for_api(request.messages)
|
|
|
|
|
- api_messages = [{"role": request.messages[-1].role, "content": request.messages[-1].content}]
|
|
|
|
|
|
|
+ # 转换消息格式为API需要的格式
|
|
|
|
|
+ api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
|
|
|
|
|
|
|
|
tools = [{
|
|
tools = [{
|
|
|
- "type": "web_search",
|
|
|
|
|
- "max_keyword": 2, # 可选参数,用于限制一轮搜索的最大关键词数量
|
|
|
|
|
|
|
+ "type": "doubao_app",
|
|
|
|
|
+ "feature": {
|
|
|
|
|
+ "ai_search": {
|
|
|
|
|
+ "type": "enabled",
|
|
|
|
|
+ "role_description": "你是浙江云悦有限公司助手,专业解答云悦问题"
|
|
|
|
|
+ }
|
|
|
|
|
+ },
|
|
|
|
|
+ "user_location": {
|
|
|
|
|
+ "type": "approximate",
|
|
|
|
|
+ "country": "中国",
|
|
|
|
|
+ "region": "浙江",
|
|
|
|
|
+ "city": "杭州"
|
|
|
|
|
+ }
|
|
|
}]
|
|
}]
|
|
|
|
|
|
|
|
# 获取上一轮对话的response_id,用于多轮对话的上下文关联
|
|
# 获取上一轮对话的response_id,用于多轮对话的上下文关联
|
|
@@ -303,33 +403,40 @@ async def chat(
|
|
|
input=api_messages,
|
|
input=api_messages,
|
|
|
tools=tools,
|
|
tools=tools,
|
|
|
stream=False,
|
|
stream=False,
|
|
|
|
|
+ store=True, # 存储当前对话上下文。此字段不存储tools,每次调用仍需给tools赋值。
|
|
|
previous_response_id=previous_response_id,
|
|
previous_response_id=previous_response_id,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# 检查API响应是否有效
|
|
# 检查API响应是否有效
|
|
|
if response.output and len(response.output) > 0:
|
|
if response.output and len(response.output) > 0:
|
|
|
- # 从output中找到最后一条消息(ResponseOutputMessage类型)
|
|
|
|
|
- last_message = None
|
|
|
|
|
- for item in reversed(response.output):
|
|
|
|
|
- if hasattr(item, 'type') and item.type == 'message':
|
|
|
|
|
- last_message = item
|
|
|
|
|
- break
|
|
|
|
|
-
|
|
|
|
|
- if last_message and hasattr(last_message, 'content'):
|
|
|
|
|
- # 提取消息内容
|
|
|
|
|
- message_content = ""
|
|
|
|
|
- if isinstance(last_message.content, list):
|
|
|
|
|
- # content是列表,提取所有文本内容
|
|
|
|
|
- for content_item in last_message.content:
|
|
|
|
|
- if hasattr(content_item, 'text'):
|
|
|
|
|
- message_content += content_item.text
|
|
|
|
|
- else:
|
|
|
|
|
- message_content = str(last_message.content)
|
|
|
|
|
-
|
|
|
|
|
|
|
+ # 从output中提取文本内容
|
|
|
|
|
+ message_content = ""
|
|
|
|
|
+
|
|
|
|
|
+ for item in response.output:
|
|
|
|
|
+ # 处理 ItemDoubaoAppCall 类型(包含搜索结果和文本输出)
|
|
|
|
|
+ if hasattr(item, 'type') and item.type == 'doubao_app_call':
|
|
|
|
|
+ if hasattr(item, 'blocks') and item.blocks:
|
|
|
|
|
+ # 从blocks中找到output_text类型的块
|
|
|
|
|
+ for block in item.blocks:
|
|
|
|
|
+ if hasattr(block, 'type') and block.type == 'output_text':
|
|
|
|
|
+ if hasattr(block, 'text'):
|
|
|
|
|
+ message_content += block.text
|
|
|
|
|
+
|
|
|
|
|
+ # 处理其他类型的消息项
|
|
|
|
|
+ elif hasattr(item, 'type') and item.type == 'message':
|
|
|
|
|
+ if hasattr(item, 'content'):
|
|
|
|
|
+ if isinstance(item.content, list):
|
|
|
|
|
+ for content_item in item.content:
|
|
|
|
|
+ if hasattr(content_item, 'text'):
|
|
|
|
|
+ message_content += content_item.text
|
|
|
|
|
+ else:
|
|
|
|
|
+ message_content += str(item.content)
|
|
|
|
|
+
|
|
|
|
|
+ if message_content:
|
|
|
# 构建AI助手的回复消息,包含response_id用于多轮对话
|
|
# 构建AI助手的回复消息,包含response_id用于多轮对话
|
|
|
assistant_message = ChatMessage(
|
|
assistant_message = ChatMessage(
|
|
|
role="assistant",
|
|
role="assistant",
|
|
|
- content=message_content or "",
|
|
|
|
|
|
|
+ content=message_content,
|
|
|
timestamp=datetime.now(),
|
|
timestamp=datetime.now(),
|
|
|
response_id=response.id # 保存response_id用于后续多轮对话
|
|
response_id=response.id # 保存response_id用于后续多轮对话
|
|
|
)
|
|
)
|
|
@@ -346,6 +453,12 @@ async def chat(
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
return chat_response
|
|
return chat_response
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 没有提取到文本内容的错误处理
|
|
|
|
|
+ raise HTTPException(
|
|
|
|
|
+ status_code=500,
|
|
|
|
|
+ detail="无法从AI响应中提取文本内容"
|
|
|
|
|
+ )
|
|
|
else:
|
|
else:
|
|
|
# API返回空响应的错误处理
|
|
# API返回空响应的错误处理
|
|
|
raise HTTPException(
|
|
raise HTTPException(
|