| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670 |
- from fastapi import APIRouter, HTTPException, Depends
- from fastapi.responses import StreamingResponse
- from volcenginesdkarkruntime import Ark
- from pydantic import BaseModel
- from typing import List, Optional, Dict, Any, Annotated
- from datetime import datetime
- import json
- import asyncio
- from ..config.config import Config
- from ..routers.users import get_current_active_user, User
- from ..db.mongo import save_chat_log
- # =====================================================
- # 全局变量和配置
- # =====================================================
- # 内存存储用户聊天历史
- # 结构: {username: [ChatMessage, ChatMessage, ...]}
- # 注意: 生产环境中应该使用数据库存储,如Redis或PostgreSQL
- chatHistory = {}
- # 创建配置实例
- # 从配置文件中读取API密钥、模型名称等设置
- config = Config()
- # 创建OpenAI客户端实例
- # 使用配置中的API密钥和基础URL
- client = Ark(api_key=config.API_KEY, base_url=config.BASE_URL)
- # 创建FastAPI路由器实例
- # 所有的聊天相关路由都会注册到这个路由器上
- router = APIRouter()
- # =====================================================
- # Pydantic数据模型定义
- # =====================================================
- class ChatMessage(BaseModel):
- """
- 聊天消息数据模型
- 用于表示对话中的单条消息,包含角色、内容和时间戳
- """
- role: str # 消息角色: "user"(用户), "assistant"(AI助手), "system"(系统)
- content: str # 消息的文本内容
- timestamp: Optional[datetime] = None # 消息创建时间戳(可选)
- response_id: Optional[str] = None # API响应ID,用于多轮对话的上下文关联
- class ChatRequest(BaseModel):
- """
- 客户端聊天请求数据模型
- 定义了客户端发送聊天请求时需要包含的所有参数
- 注意: 移除了userId字段,改为从JWT认证中获取用户身份
- """
- messages: List[ChatMessage] # 对话历史消息列表,包含用户和助手的所有消息
- model: Optional[str] = config.MODEL_NAME # 要使用的AI模型名称,默认使用配置中的模型
- temperature: Optional[float] = config.TEMPERATURE # 创造性温度值(0-2),控制回答的随机性
- max_tokens: Optional[int] = config.MAX_TOKENS # 最大生成token数,限制回答长度
- stream: Optional[bool] = False # 是否启用流式输出,True时会实时返回生成的内容
- class ChatResponse(BaseModel):
- """
- 服务端聊天响应数据模型(非流式)
- 用于非流式请求的响应,包含完整的AI回答和使用统计
- """
- message: ChatMessage # AI助手的回复消息
- model: str # 实际使用的模型名称
- usage: Optional[Dict[str, Any]] = None # token使用情况统计
- response_id: Optional[str] = None # API响应ID,用于后续多轮对话
- class StreamResponse(BaseModel):
- """
- 流式响应数据模型
- 用于流式输出时的每个数据块,支持Server-Sent Events (SSE)
- """
- content: str # 本次返回的内容片段
- finished: bool # 是否为最后一个片段,True表示流式响应结束
- model: str # 使用的模型名称
- timestamp: datetime # 当前片段的时间戳
- # =====================================================
- # 消息处理工具函数
- # =====================================================
- def convert_messages_for_api(messages: List[ChatMessage]) -> List[Dict[str, str]]:
- """
- 将自定义ChatMessage转换为OpenAI API需要的格式
- OpenAI API需要的消息格式是字典列表,每个字典包含role和content字段
- Args:
- messages (List[ChatMessage]): 自定义的消息对象列表
- Returns:
- List[Dict[str, str]]: OpenAI API格式的消息列表
- Example:
- 输入: [ChatMessage(role="user", content="你好")]
- 输出: [{"role": "user", "content": "你好"}]
- """
- return [{"role": msg.role, "content": msg.content} for msg in messages]
- def get_latest_user_message(messages: List[ChatMessage]) -> Optional[ChatMessage]:
- """
- 获取消息列表中最后一条user角色的消息
- 在多轮对话中,消息列表可能包含user和assistant的消息,
- 流式场景下客户端会预先添加空的assistant消息作为占位符,
- 此函数确保获取到最后一条用户发送的消息
- Args:
- messages (List[ChatMessage]): 消息列表
- Returns:
- Optional[ChatMessage]: 最后一条user角色的消息,如果不存在则返回None
- """
- for message in reversed(messages):
- if message.role == "user":
- return message
- return None
- async def generate_stream_response(request: ChatRequest, username: str):
- """
- 生成流式响应的异步生成器
- 这个函数处理流式AI响应,将Ark API的流式输出转换为SSE格式
- Args:
- request (ChatRequest): 聊天请求对象
- username (str): 当前用户名
- Yields:
- str: 格式化的SSE数据,每行以"data: "开头
- Note:
- 使用Server-Sent Events (SSE) 协议进行实时数据传输
- 客户端需要使用EventSource或类似技术接收流式数据
- """
- try:
- # 获取最后一条user角色的消息
- latest_user_msg = get_latest_user_message(request.messages)
- if not latest_user_msg:
- raise ValueError("请求中没有找到user角色的消息")
- # 将用户消息添加到历史记录
- user_message = ChatMessage(
- role=latest_user_msg.role,
- content=latest_user_msg.content,
- timestamp=datetime.now()
- )
- chatHistory[username].append(user_message)
- # 转换消息格式为API需要的格式
- api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
- tools = [{
- "type": "doubao_app",
- "feature": {
- "ai_search": {
- "type": "enabled",
- "role_description": "你是浙江云悦有限公司助手,专业解答云悦问题"
- }
- },
- "user_location": {
- "type": "approximate",
- "country": "中国",
- "region": "浙江",
- "city": "杭州"
- }
- }]
- # 获取上一轮对话的response_id,用于多轮对话的上下文关联
- previous_response_id = None
- if username in chatHistory and len(chatHistory[username]) > 0:
- # 从后往前查找最后一条assistant消息的response_id
- for message in reversed(chatHistory[username]):
- if message.role == "assistant" and message.response_id:
- previous_response_id = message.response_id
- break
- # stream=True 启用流式输出,API会返回一个迭代器
- stream = client.responses.create(
- model=config.MODEL_NAME,
- input=api_messages,
- tools=tools,
- stream=True,
- store=True, # 存储当前对话上下文。此字段不存储tools,每次调用仍需给tools赋值。
- previous_response_id=previous_response_id,
- )
- # 用于累积完整的回答内容和response_id
- accumulated_content = ""
- response_id = None
- # 遍历流式响应的每个数据块
- for chunk in stream:
- # 处理不同类型的流式事件
- chunk_dict = chunk.__dict__ if hasattr(chunk, '__dict__') else {}
- event_type = chunk_dict.get('type', '')
- # 处理文本内容增量事件(普通文本)
- if event_type == 'response.output_text.delta':
- delta_text = chunk_dict.get('delta', '')
- if delta_text:
- # 累积到完整内容中
- accumulated_content += delta_text
- # 构建流式响应数据对象
- response_data = StreamResponse(
- content=delta_text, # 本次片段内容
- finished=False, # 标记为未完成
- model= config.MODEL_NAME, # 使用的模型
- timestamp=datetime.now() # 当前时间戳
- )
- # 格式化为SSE格式并发送
- # SSE格式: "data: {json_data}\n\n"
- yield f"data: {response_data.model_dump_json()}\n\n"
- # 异步让出控制权,避免阻塞事件循环
- await asyncio.sleep(0.01)
- # 处理DoubaoApp调用的文本输出增量事件
- elif event_type == 'response.doubao_app_call_output_text.delta':
- delta_text = chunk_dict.get('delta', '')
- if delta_text:
- # 累积到完整内容中
- accumulated_content += delta_text
- # 构建流式响应数据对象
- response_data = StreamResponse(
- content=delta_text, # 本次片段内容
- finished=False, # 标记为未完成
- model= config.MODEL_NAME, # 使用的模型
- timestamp=datetime.now() # 当前时间戳
- )
- # 格式化为SSE格式并发送
- yield f"data: {response_data.model_dump_json()}\n\n"
- # 异步让出控制权,避免阻塞事件循环
- await asyncio.sleep(0.01)
- # 处理响应完成事件,获取response_id并记录原始响应日志
- elif event_type == 'response.completed':
- if 'response' in chunk_dict and hasattr(chunk_dict['response'], 'id'):
- response_id = chunk_dict['response'].id
- # 记录原始响应日志
- save_chat_log(
- username=username,
- question=latest_user_msg.content,
- stream_mode=True,
- raw_response=repr(chunk_dict.get('response')),
- status="success",
- )
- # 流式响应结束后的处理
- if accumulated_content:
- # 构建结束信号响应
- final_response = StreamResponse(
- content='', # 结束信号不包含内容
- finished=True, # 标记为已完成
- model= config.MODEL_NAME,
- timestamp=datetime.now()
- )
- # 将完整的AI回复保存到用户的聊天历史中,包含response_id
- chatHistory[username].append(
- ChatMessage(
- role="assistant",
- content=accumulated_content,
- timestamp=datetime.now(),
- response_id=response_id # 保存response_id用于后续多轮对话
- )
- )
- # 发送结束信号
- yield f"data: {final_response.model_dump_json()}\n\n"
- # 在控制台输出提示
- print("流式内容已全部输出")
- except Exception as e:
- # 流式响应过程中的错误处理
- # 构建错误响应并发送给客户端
- error_response = {
- "error": str(e), # 错误信息
- "finished": True, # 标记为结束
- "timestamp": datetime.now().isoformat() # 错误发生时间
- }
- # 记录错误日志到 MongoDB
- save_chat_log(
- username=username,
- question=latest_user_msg.content if latest_user_msg else "",
- stream_mode=True,
- status="error",
- error=str(e),
- )
- # 发送错误信息
- yield f"data: {json.dumps(error_response)}\n\n"
- # =====================================================
- # API路由端点定义
- # =====================================================
- @router.post("/chat", response_model=ChatResponse)
- async def chat(
- request: ChatRequest,
- current_user: Annotated[User, Depends(get_current_active_user)]
- ):
- """
- 聊天对话接口 - 需要登录认证
- 这是核心的聊天接口,支持流式和非流式两种模式:
- - 非流式: 等待AI完整回答后一次性返回
- - 流式: 实时返回AI生成的内容片段
- 安全特性:
- - 需要有效的JWT令牌
- - 自动配额检查和限制
- - 用户数据隔离
- Args:
- request (ChatRequest): 聊天请求数据
- current_user (User): 通过JWT认证获取的当前用户信息
- Returns:
- ChatResponse: 非流式模式的完整响应
- StreamingResponse: 流式模式的SSE响应
- Raises:
- HTTPException:
- - 429: 配额已用完
- - 500: AI模型调用失败或其他服务器错误
- """
- try:
- # 从认证信息中获取用户名,确保数据安全
- username = current_user.username
- # 初始化用户的聊天历史记录(如果不存在)
- if username not in chatHistory:
- chatHistory[username] = []
- # 根据请求类型处理:流式 vs 非流式
- if request.stream:
- # ===== 流式输出处理 =====
- # 返回流式响应
- # StreamingResponse 用于处理SSE协议
- return StreamingResponse(
- generate_stream_response(request, username), # 异步生成器
- media_type="text/plain", # 媒体类型
- headers={
- "Cache-Control": "no-cache", # 禁用缓存
- "Connection": "keep-alive", # 保持连接
- "Content-Type": "text/event-stream", # SSE内容类型
- }
- )
- else:
- # ===== 非流式输出处理 =====
- # 获取最后一条user角色的消息
- latest_user_msg = get_latest_user_message(request.messages)
- if not latest_user_msg:
- raise ValueError("请求中没有找到user角色的消息")
- # 将用户消息添加到历史记录
- user_message = ChatMessage(
- role=latest_user_msg.role,
- content=latest_user_msg.content,
- timestamp=datetime.now()
- )
- chatHistory[username].append(user_message)
- # 转换消息格式为API需要的格式
- api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
- tools = [{
- "type": "doubao_app",
- "feature": {
- "ai_search": {
- "type": "enabled",
- "role_description": "你是浙江云悦有限公司助手,专业解答云悦问题"
- }
- },
- "user_location": {
- "type": "approximate",
- "country": "中国",
- "region": "浙江",
- "city": "杭州"
- }
- }]
- # 获取上一轮对话的response_id,用于多轮对话的上下文关联
- previous_response_id = None
- if username in chatHistory and len(chatHistory[username]) > 0:
- # 从后往前查找最后一条assistant消息的response_id
- for message in reversed(chatHistory[username]):
- if message.role == "assistant" and message.response_id:
- previous_response_id = message.response_id
- break
- response = client.responses.create(
- model=config.MODEL_NAME,
- input=api_messages,
- tools=tools,
- stream=False,
- store=True, # 存储当前对话上下文。此字段不存储tools,每次调用仍需给tools赋值。
- previous_response_id=previous_response_id,
- )
- # 记录原始响应日志到 MongoDB(解析前)
- save_chat_log(
- username=username,
- question=latest_user_msg.content,
- stream_mode=False,
- raw_response=repr(response),
- status="success",
- )
- # 检查API响应是否有效
- if response.output and len(response.output) > 0:
- # 从output中提取文本内容
- message_content = ""
- for item in response.output:
- # 处理 ItemDoubaoAppCall 类型(包含搜索结果和文本输出)
- if hasattr(item, 'type') and item.type == 'doubao_app_call':
- if hasattr(item, 'blocks') and item.blocks:
- # 从blocks中找到output_text类型的块
- for block in item.blocks:
- if hasattr(block, 'type') and block.type == 'output_text':
- if hasattr(block, 'text'):
- message_content += block.text
- # 处理其他类型的消息项
- elif hasattr(item, 'type') and item.type == 'message':
- if hasattr(item, 'content'):
- if isinstance(item.content, list):
- for content_item in item.content:
- if hasattr(content_item, 'text'):
- message_content += content_item.text
- else:
- message_content += str(item.content)
- if message_content:
- # 构建AI助手的回复消息,包含response_id用于多轮对话
- assistant_message = ChatMessage(
- role="assistant",
- content=message_content,
- timestamp=datetime.now(),
- response_id=response.id # 保存response_id用于后续多轮对话
- )
- # 将AI回复添加到用户的聊天历史
- chatHistory[username].append(assistant_message)
- # 保存聊天日志到 MongoDB
- save_chat_log(
- username=username,
- question=latest_user_msg.content,
- answer=message_content,
- stream_mode=False,
- response_id=response.id,
- )
- # 构建完整的响应对象
- chat_response = ChatResponse(
- message=assistant_message, # AI回复消息
- model=response.model, # 实际使用的模型
- usage=response.usage.model_dump() if response.usage else None, # token使用统计
- response_id=response.id # 返回response_id供客户端保存
- )
- return chat_response
- else:
- # 没有提取到文本内容的错误处理
- raise HTTPException(
- status_code=500,
- detail="无法从AI响应中提取文本内容"
- )
- else:
- # API返回空响应的错误处理
- raise HTTPException(
- status_code=500,
- detail="AI模型返回了空响应"
- )
- except HTTPException:
- # 重新抛出HTTP异常(如配额限制)
- raise
- except Exception as e:
- # 捕获所有其他异常并转换为HTTP异常
- error_message = f"处理聊天请求时发生错误: {str(e)}"
- save_chat_log(
- username=current_user.username,
- question=request.messages[-1].content if request.messages else "",
- stream_mode=request.stream,
- status="error",
- error=error_message,
- )
- raise HTTPException(status_code=500, detail=error_message)
- @router.get("/models")
- async def get_models(
- current_user: Annotated[User, Depends(get_current_active_user)]
- ):
- """
- 获取可用的AI模型列表 - 需要登录
- 返回当前可用的AI模型列表,用户可以在聊天时选择使用
- Args:
- current_user (User): 通过JWT认证获取的当前用户信息
- Returns:
- dict: 包含模型列表和默认模型的字典
- Note:
- 如果无法获取模型列表,会返回默认配置
- """
- try:
- # 尝试从OpenAI API获取可用模型列表
- models = client.models.list()
- return {
- "models": [model.id for model in models.data], # 模型ID列表
- "default_model": config.MODEL_NAME, # 默认模型
- "user": current_user.username # 请求用户
- }
- except Exception as e:
- # 如果获取模型列表失败,返回默认配置
- return {
- "models": [config.MODEL_NAME], # 只返回默认模型
- "default_model": config.MODEL_NAME,
- "note": "使用默认模型配置", # 提示信息
- "user": current_user.username
- }
- @router.get("/history")
- async def get_user_history(
- current_user: Annotated[User, Depends(get_current_active_user)]
- ) -> List[ChatMessage]:
- """
- 获取当前用户的聊天历史 - 安全版本
- 只返回当前认证用户的聊天历史,确保数据隐私
- Args:
- current_user (User): 通过JWT认证获取的当前用户信息
- Returns:
- List[ChatMessage]: 用户的历史消息列表
- Security:
- 用户只能访问自己的聊天历史,无法访问他人数据
- """
- username = current_user.username
- # 如果用户没有聊天历史,返回空列表
- if username not in chatHistory:
- return []
- # 返回用户的完整聊天历史
- # 创建新的ChatMessage对象确保数据一致性
- return [
- ChatMessage(
- role=msg.role,
- content=msg.content,
- timestamp=msg.timestamp
- )
- for msg in chatHistory[username]
- ]
- @router.delete("/history")
- async def clear_user_history(
- current_user: Annotated[User, Depends(get_current_active_user)]
- ):
- """
- 清空当前用户的聊天历史
- 删除用户的所有聊天记录,此操作不可逆
- Args:
- current_user (User): 通过JWT认证获取的当前用户信息
- Returns:
- dict: 操作确认信息
- Warning:
- 此操作会永久删除用户的聊天历史,无法恢复
- """
- username = current_user.username
- # 如果用户有聊天历史,则删除
- if username in chatHistory:
- # 获取删除前的消息数量用于统计
- message_count = len(chatHistory[username])
- # 删除用户的聊天历史
- del chatHistory[username]
- return {
- "message": "聊天历史已清空",
- "user": username,
- "deleted_messages": message_count, # 删除的消息数量
- "timestamp": datetime.now()
- }
- else:
- # 用户没有聊天历史
- return {
- "message": "用户没有聊天历史",
- "user": username,
- "deleted_messages": 0,
- "timestamp": datetime.now()
- }
- @router.get("/health")
- async def health_check():
- """
- 健康检查接口 - 无需认证
- 用于监控服务状态,通常被负载均衡器或监控系统调用
- Returns:
- dict: 服务状态信息
- Note:
- 此接口不需要认证,可被任何人访问
- """
- return {
- "status": "healthy", # 服务状态
- "timestamp": datetime.now(), # 当前时间
- "version": "1.0.0", # 服务版本
- "model": config.MODEL_NAME, # 默认AI模型
- }
- # =====================================================
- # 路由器配置和元数据
- # =====================================================
- # 为路由器添加标签和元数据,用于API文档生成
- router.tags = ["聊天服务"]
- router.responses = {
- 401: {"description": "未授权 - 需要有效的JWT令牌"},
- 429: {"description": "请求过多 - 配额已用完"},
- 500: {"description": "服务器内部错误"}
- }
|