Browse Source

feat:完善jwt机制,实现token续约

zhangwl 1 month ago
parent
commit
9615deef13
1 changed files with 86 additions and 17 deletions
  1. 86 17
      app/routers/users.py

+ 86 - 17
app/routers/users.py

@@ -1,11 +1,13 @@
 from datetime import datetime, timedelta, timezone
 from typing import Optional, Annotated
 import jwt
+import hashlib
 from fastapi import APIRouter, Depends, HTTPException, status
 from fastapi.security import OAuth2PasswordBearer
 from pydantic import BaseModel
 from pwdlib import PasswordHash
 from ..config.config import Config
+from ..db.redis_client import redis_client
 
 config = Config()
 # =====================================================
@@ -49,6 +51,7 @@ class Token(BaseModel):
     """
     message: str
     access_token: str  # JWT访问令牌
+    refresh_token: str  # JWT刷新令牌
     token_type: str  # 令牌类型,通常是"bearer"
     username: str
 
@@ -116,7 +119,8 @@ fake_users_db = {
         # 这是"admin"的bcrypt哈希值
         "hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$AnPKKWadE68rd3n9CSM7vQ$p0LA3GqVSYxGRopB0B1yzzlO7W1LRCWagShF+Sbre9I",
         "disabled": False,
-    },
+    }
+    ,
     "admin": {
         "userId":"2",
         "username": "admin",
@@ -201,30 +205,26 @@ def authenticate_user(fake_db: dict, username: str, password: str) -> Optional[U
     return user
 
 
-def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
+def create_access_token(data: dict, expires_delta: Optional[timedelta] = None, token_type: str = "access") -> str:
     """
-    创建JWT访问令牌
+    创建JWT令牌(access 或 refresh)
 
     Args:
         data (dict): 要编码到令牌中的数据(通常包含用户标识)
         expires_delta (Optional[timedelta]): 令牌过期时间,如果不提供则使用默认值
+        token_type (str): 令牌类型,"access" 或 "refresh"
 
     Returns:
         str: 编码后的JWT令牌
     """
-    # 复制数据以避免修改原始数据
     to_encode = data.copy()
 
-    # 计算过期时间,没有设置过期时间,就默认设置15分钟
     if expires_delta:
         expire = datetime.now(timezone.utc) + expires_delta
     else:
         expire = datetime.now(timezone.utc) + timedelta(minutes=15)
 
-    # 添加过期时间到令牌数据中
-    to_encode.update({"exp": expire})
-
-    # 使用密钥和算法对数据进行编码
+    to_encode.update({"exp": expire, "type": token_type})
     encoded_jwt = jwt.encode(to_encode, config.SECRET_KEY, algorithm=config.ALGORITHM)
     return encoded_jwt
 
@@ -334,19 +334,27 @@ async def login_for_access_token(
             headers={"WWW-Authenticate": "Bearer"},
         )
 
-    # 设置令牌过期时间
-    access_token_expires = timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES)
-
-    # 创建访问令牌,将用户名作为subject存储在令牌中
+    # 创建访问令牌和刷新令牌
     access_token = create_access_token(
         data={"sub": user.username, "userId": user.userId},
-        expires_delta=access_token_expires
+        expires_delta=timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES),
+        token_type="access"
+    )
+    refresh_token = create_access_token(
+        data={"sub": user.username, "userId": user.userId},
+        expires_delta=timedelta(days=7),
+        token_type="refresh"
     )
 
+    # 存储刷新令牌哈希到Redis
+    refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
+    redis_client.setex(f"refresh:{user.userId}", 7*24*60*60, refresh_hash)
+
     # 返回令牌和令牌类型
     return {
         "message": "登录成功",
         "access_token": access_token,
+        "refresh_token": refresh_token,
         "token_type": "bearer",
         "username": user.username
     }
@@ -401,8 +409,7 @@ async def logout(
     """
     用户退出登录端点
 
-    由于JWT是无状态的,服务端不需要做特殊处理
-    主要是返回成功消息,让前端清除本地存储的token
+    删除Redis中的刷新令牌,让前端清除本地存储的token
 
     Args:
         current_user (User): 通过依赖注入获取的当前用户信息
@@ -410,7 +417,7 @@ async def logout(
     Returns:
         dict: 退出成功的消息
     """
-    print(current_user.username)
+    redis_client.delete(f"refresh:{current_user.userId}")
     return {
         "message": "退出登录成功",
         "username": current_user.username,
@@ -418,6 +425,68 @@ async def logout(
     }
 
 
+@router.post("/refresh", response_model=Token, summary="刷新令牌", description="使用刷新令牌获取新的访问令牌")
+async def refresh_access_token(refresh_token: str) -> Token:
+    """
+    刷新令牌端点
+
+    使用刷新令牌换取新的访问令牌
+
+    Args:
+        refresh_token (str): 刷新令牌
+
+    Returns:
+        Token: 包含新访问令牌的对象
+
+    Raises:
+        HTTPException: 如果刷新令牌无效或已过期
+    """
+    credentials_exception = HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail="Invalid refresh token",
+        headers={"WWW-Authenticate": "Bearer"},
+    )
+
+    try:
+        payload = jwt.decode(refresh_token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
+        if payload.get("type") != "refresh":
+            raise credentials_exception
+
+        username: str = payload.get("sub")
+        user_id: str = payload.get("userId")
+        if not username or not user_id:
+            raise credentials_exception
+
+    except jwt.PyJWTError:
+        raise credentials_exception
+
+    # 验证Redis中的刷新令牌哈希
+    refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
+    stored_hash = redis_client.get(f"refresh:{user_id}")
+    if not stored_hash or stored_hash.decode() != refresh_hash:
+        raise credentials_exception
+
+    # 验证用户是否存在
+    user = get_user(fake_users_db, username)
+    if not user:
+        raise credentials_exception
+
+    # 生成新的访问令牌
+    new_access_token = create_access_token(
+        data={"sub": username, "userId": user_id},
+        expires_delta=timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES),
+        token_type="access"
+    )
+
+    return {
+        "message": "令牌刷新成功",
+        "access_token": new_access_token,
+        "refresh_token": refresh_token,
+        "token_type": "bearer",
+        "username": username
+    }
+
+
 @router.get("/me", response_model=User, summary="获取用户信息", description="获取当前登录用户的个人信息")
 async def read_users_me(
         current_user: Annotated[User, Depends(get_current_active_user)]