Browse Source

feat:新增一种APP的认证方式,用redis缓存用户信息

zhangwl 1 month ago
parent
commit
f85bc384e4
4 changed files with 107 additions and 30 deletions
  1. 6 0
      README.md
  2. 2 1
      app/config/config.py
  3. 71 0
      app/db/redis_client.py
  4. 28 29
      app/routers/chat.py

+ 6 - 0
README.md

@@ -8,3 +8,9 @@ pip install fastapi
 pip install openai
 # 火山引擎的openAI
 pip install 'volcengine-python-sdk[ark]'
+# mongodb
+pip install "pymongo==3.13.0"
+# redis
+pip install "redis==3.5.3"
+# HTTP 请求库,用于向服务器发送网络请求
+pip install requests

+ 2 - 1
app/config/config.py

@@ -7,7 +7,8 @@ class Config:
     # 使用阿里云DashScope API
     API_KEY = os.getenv("DASHSCOPE_API_KEY")
     BASE_URL = os.getenv("DASHSCOPE_BASE_URL")
-    MODEL_NAME = "doubao-seed-2-0-mini-260215"
+    # MODEL_NAME = "doubao-seed-2-0-mini-260215"
+    MODEL_NAME = "doubao-seed-1-8-251228"
     MAX_TOKENS = 2000
     TEMPERATURE = 0.7
 

+ 71 - 0
app/db/redis_client.py

@@ -0,0 +1,71 @@
+import redis
+import json
+import requests
+from typing import Optional
+
+REDIS_HOST = "r-2zeitjlg0gdypzb4v6.redis.rds.aliyuncs.com"
+REDIS_PORT = 6379
+REDIS_TTL = 1800  # 30分钟
+
+THIRD_PARTY_API = "http://eco.zhongsou.com/eco/user/user.redis.info.groovy"
+
+try:
+    redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
+except Exception:
+    redis_client = None
+
+
+def get_app_user(token: str) -> Optional[dict]:
+    """
+    通过第三方 token 获取用户信息
+    先查 Redis 缓存,未命中则调第三方接口并写入缓存
+
+    Args:
+        token: App 端传入的第三方 token
+
+    Returns:
+        用户信息字典,token 无效或接口异常时返回 None
+    """
+    cache_key = f"app_user:{token}"
+
+    # 1. 查 Redis 缓存
+    try:
+        if redis_client:
+            cached = redis_client.get(cache_key)
+            if cached:
+                return json.loads(cached)
+    except Exception as e:
+        print(f"Redis 读取失败,降级调第三方接口: {e}")
+
+    # 2. 调第三方接口
+    try:
+        resp = requests.get(THIRD_PARTY_API, params={"token": token}, timeout=5)
+        data = resp.json()
+
+        if data.get("head", {}).get("status") != 200:
+            return None
+
+        body = data.get("body")
+        if not body or not body.get("userId"):
+            return None
+
+        user_info = {
+            "userId": body["userId"],
+            "userName": body["userName"],
+            "nickName": body["nickName"],
+            "mobile": body["mobile"],
+            "imageUrl": body.get("imageUrl", ""),
+        }
+
+        # 3. 写入 Redis 缓存
+        try:
+            if redis_client:
+                redis_client.setex(cache_key, REDIS_TTL, json.dumps(user_info))
+        except Exception as e:
+            print(f"Redis 写入失败: {e}")
+
+        return user_info
+
+    except Exception as e:
+        print(f"第三方接口调用失败: {e}")
+        return None

+ 28 - 29
app/routers/chat.py

@@ -1,14 +1,19 @@
 from fastapi import APIRouter, HTTPException, Depends
 from fastapi.responses import StreamingResponse
+from fastapi.security import OAuth2PasswordBearer
 from volcenginesdkarkruntime import Ark
 from pydantic import BaseModel
 from typing import List, Optional, Dict, Any, Annotated
 from datetime import datetime
 import json
 import asyncio
+import jwt
 from ..config.config import Config
-from ..routers.users import get_current_active_user, User
+from ..routers.users import get_current_active_user, User, SECRET_KEY, ALGORITHM
 from ..db.mongo import save_chat_log
+from ..db.redis_client import get_app_user
+
+oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="/users/token", auto_error=False)
 
 # =====================================================
 # 全局变量和配置
@@ -60,6 +65,8 @@ class ChatRequest(BaseModel):
     temperature: Optional[float] = config.TEMPERATURE  # 创造性温度值(0-2),控制回答的随机性
     max_tokens: Optional[int] = config.MAX_TOKENS  # 最大生成token数,限制回答长度
     stream: Optional[bool] = False  # 是否启用流式输出,True时会实时返回生成的内容
+    source: Optional[str] = None  # 来源标识,source=app 时走第三方token认证
+    token: Optional[str] = None  # App端传入的第三方token,source=app时必填
 
 
 class ChatResponse(BaseModel):
@@ -321,36 +328,28 @@ async def generate_stream_response(request: ChatRequest, username: str):
 @router.post("/chat", response_model=ChatResponse)
 async def chat(
         request: ChatRequest,
-        current_user: Annotated[User, Depends(get_current_active_user)]
+        jwt_token: Annotated[Optional[str], Depends(oauth2_scheme_optional)] = None
 ):
-    """
-    聊天对话接口 - 需要登录认证
-
-    这是核心的聊天接口,支持流式和非流式两种模式:
-    - 非流式: 等待AI完整回答后一次性返回
-    - 流式: 实时返回AI生成的内容片段
-
-    安全特性:
-    - 需要有效的JWT令牌
-    - 自动配额检查和限制
-    - 用户数据隔离
-
-    Args:
-        request (ChatRequest): 聊天请求数据
-        current_user (User): 通过JWT认证获取的当前用户信息
-
-    Returns:
-        ChatResponse: 非流式模式的完整响应
-        StreamingResponse: 流式模式的SSE响应
-
-    Raises:
-        HTTPException:
-            - 429: 配额已用完
-            - 500: AI模型调用失败或其他服务器错误
-    """
     try:
-        # 从认证信息中获取用户名,确保数据安全
-        username = current_user.username
+        # ===== 认证分支 =====
+        if request.source == "app" and request.token:
+            # 第三方 App token 认证
+            app_user = get_app_user(request.token)
+            if not app_user:
+                raise HTTPException(status_code=401, detail="无效的 App token")
+            username = f"app_{app_user['userId']}"
+        else:
+            # 原有本地 JWT 认证
+            if not jwt_token:
+                raise HTTPException(status_code=401, detail="未提供认证令牌")
+            try:
+                payload = jwt.decode(jwt_token, SECRET_KEY, algorithms=[ALGORITHM])
+                sub = payload.get("sub")
+                if not sub:
+                    raise HTTPException(status_code=401, detail="无效的令牌")
+            except jwt.PyJWTError:
+                raise HTTPException(status_code=401, detail="无效的令牌")
+            username = sub
 
         # 初始化用户的聊天历史记录(如果不存在)
         if username not in chatHistory: