|
@@ -1,14 +1,19 @@
|
|
|
from fastapi import APIRouter, HTTPException, Depends
|
|
from fastapi import APIRouter, HTTPException, Depends
|
|
|
from fastapi.responses import StreamingResponse
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
|
+from fastapi.security import OAuth2PasswordBearer
|
|
|
from volcenginesdkarkruntime import Ark
|
|
from volcenginesdkarkruntime import Ark
|
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
|
from typing import List, Optional, Dict, Any, Annotated
|
|
from typing import List, Optional, Dict, Any, Annotated
|
|
|
from datetime import datetime
|
|
from datetime import datetime
|
|
|
import json
|
|
import json
|
|
|
import asyncio
|
|
import asyncio
|
|
|
|
|
+import jwt
|
|
|
from ..config.config import Config
|
|
from ..config.config import Config
|
|
|
-from ..routers.users import get_current_active_user, User
|
|
|
|
|
|
|
+from ..routers.users import get_current_active_user, User, SECRET_KEY, ALGORITHM
|
|
|
from ..db.mongo import save_chat_log
|
|
from ..db.mongo import save_chat_log
|
|
|
|
|
+from ..db.redis_client import get_app_user
|
|
|
|
|
+
|
|
|
|
|
+oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="/users/token", auto_error=False)
|
|
|
|
|
|
|
|
# =====================================================
|
|
# =====================================================
|
|
|
# 全局变量和配置
|
|
# 全局变量和配置
|
|
@@ -60,6 +65,8 @@ class ChatRequest(BaseModel):
|
|
|
temperature: Optional[float] = config.TEMPERATURE # 创造性温度值(0-2),控制回答的随机性
|
|
temperature: Optional[float] = config.TEMPERATURE # 创造性温度值(0-2),控制回答的随机性
|
|
|
max_tokens: Optional[int] = config.MAX_TOKENS # 最大生成token数,限制回答长度
|
|
max_tokens: Optional[int] = config.MAX_TOKENS # 最大生成token数,限制回答长度
|
|
|
stream: Optional[bool] = False # 是否启用流式输出,True时会实时返回生成的内容
|
|
stream: Optional[bool] = False # 是否启用流式输出,True时会实时返回生成的内容
|
|
|
|
|
+ source: Optional[str] = None # 来源标识,source=app 时走第三方token认证
|
|
|
|
|
+ token: Optional[str] = None # App端传入的第三方token,source=app时必填
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatResponse(BaseModel):
|
|
class ChatResponse(BaseModel):
|
|
@@ -321,36 +328,28 @@ async def generate_stream_response(request: ChatRequest, username: str):
|
|
|
@router.post("/chat", response_model=ChatResponse)
|
|
@router.post("/chat", response_model=ChatResponse)
|
|
|
async def chat(
|
|
async def chat(
|
|
|
request: ChatRequest,
|
|
request: ChatRequest,
|
|
|
- current_user: Annotated[User, Depends(get_current_active_user)]
|
|
|
|
|
|
|
+ jwt_token: Annotated[Optional[str], Depends(oauth2_scheme_optional)] = None
|
|
|
):
|
|
):
|
|
|
- """
|
|
|
|
|
- 聊天对话接口 - 需要登录认证
|
|
|
|
|
-
|
|
|
|
|
- 这是核心的聊天接口,支持流式和非流式两种模式:
|
|
|
|
|
- - 非流式: 等待AI完整回答后一次性返回
|
|
|
|
|
- - 流式: 实时返回AI生成的内容片段
|
|
|
|
|
-
|
|
|
|
|
- 安全特性:
|
|
|
|
|
- - 需要有效的JWT令牌
|
|
|
|
|
- - 自动配额检查和限制
|
|
|
|
|
- - 用户数据隔离
|
|
|
|
|
-
|
|
|
|
|
- Args:
|
|
|
|
|
- request (ChatRequest): 聊天请求数据
|
|
|
|
|
- current_user (User): 通过JWT认证获取的当前用户信息
|
|
|
|
|
-
|
|
|
|
|
- Returns:
|
|
|
|
|
- ChatResponse: 非流式模式的完整响应
|
|
|
|
|
- StreamingResponse: 流式模式的SSE响应
|
|
|
|
|
-
|
|
|
|
|
- Raises:
|
|
|
|
|
- HTTPException:
|
|
|
|
|
- - 429: 配额已用完
|
|
|
|
|
- - 500: AI模型调用失败或其他服务器错误
|
|
|
|
|
- """
|
|
|
|
|
try:
|
|
try:
|
|
|
- # 从认证信息中获取用户名,确保数据安全
|
|
|
|
|
- username = current_user.username
|
|
|
|
|
|
|
+ # ===== 认证分支 =====
|
|
|
|
|
+ if request.source == "app" and request.token:
|
|
|
|
|
+ # 第三方 App token 认证
|
|
|
|
|
+ app_user = get_app_user(request.token)
|
|
|
|
|
+ if not app_user:
|
|
|
|
|
+ raise HTTPException(status_code=401, detail="无效的 App token")
|
|
|
|
|
+ username = f"app_{app_user['userId']}"
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 原有本地 JWT 认证
|
|
|
|
|
+ if not jwt_token:
|
|
|
|
|
+ raise HTTPException(status_code=401, detail="未提供认证令牌")
|
|
|
|
|
+ try:
|
|
|
|
|
+ payload = jwt.decode(jwt_token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
|
|
|
+ sub = payload.get("sub")
|
|
|
|
|
+ if not sub:
|
|
|
|
|
+ raise HTTPException(status_code=401, detail="无效的令牌")
|
|
|
|
|
+ except jwt.PyJWTError:
|
|
|
|
|
+ raise HTTPException(status_code=401, detail="无效的令牌")
|
|
|
|
|
+ username = sub
|
|
|
|
|
|
|
|
# 初始化用户的聊天历史记录(如果不存在)
|
|
# 初始化用户的聊天历史记录(如果不存在)
|
|
|
if username not in chatHistory:
|
|
if username not in chatHistory:
|