Browse Source

feat:完善jwt机制,实现token续约

zhangwl 1 month ago
parent
commit
9615deef13
1 changed files with 86 additions and 17 deletions
  1. 86 17
      app/routers/users.py

+ 86 - 17
app/routers/users.py

@@ -1,11 +1,13 @@
 from datetime import datetime, timedelta, timezone
 from datetime import datetime, timedelta, timezone
 from typing import Optional, Annotated
 from typing import Optional, Annotated
 import jwt
 import jwt
+import hashlib
 from fastapi import APIRouter, Depends, HTTPException, status
 from fastapi import APIRouter, Depends, HTTPException, status
 from fastapi.security import OAuth2PasswordBearer
 from fastapi.security import OAuth2PasswordBearer
 from pydantic import BaseModel
 from pydantic import BaseModel
 from pwdlib import PasswordHash
 from pwdlib import PasswordHash
 from ..config.config import Config
 from ..config.config import Config
+from ..db.redis_client import redis_client
 
 
 config = Config()
 config = Config()
 # =====================================================
 # =====================================================
@@ -49,6 +51,7 @@ class Token(BaseModel):
     """
     """
     message: str
     message: str
     access_token: str  # JWT访问令牌
     access_token: str  # JWT访问令牌
+    refresh_token: str  # JWT刷新令牌
     token_type: str  # 令牌类型,通常是"bearer"
     token_type: str  # 令牌类型,通常是"bearer"
     username: str
     username: str
 
 
@@ -116,7 +119,8 @@ fake_users_db = {
         # 这是"admin"的bcrypt哈希值
         # 这是"admin"的bcrypt哈希值
         "hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$AnPKKWadE68rd3n9CSM7vQ$p0LA3GqVSYxGRopB0B1yzzlO7W1LRCWagShF+Sbre9I",
         "hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$AnPKKWadE68rd3n9CSM7vQ$p0LA3GqVSYxGRopB0B1yzzlO7W1LRCWagShF+Sbre9I",
         "disabled": False,
         "disabled": False,
-    },
+    }
+    ,
     "admin": {
     "admin": {
         "userId":"2",
         "userId":"2",
         "username": "admin",
         "username": "admin",
@@ -201,30 +205,26 @@ def authenticate_user(fake_db: dict, username: str, password: str) -> Optional[U
     return user
     return user
 
 
 
 
-def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
+def create_access_token(data: dict, expires_delta: Optional[timedelta] = None, token_type: str = "access") -> str:
     """
     """
-    创建JWT访问令牌
+    创建JWT令牌(access 或 refresh)
 
 
     Args:
     Args:
         data (dict): 要编码到令牌中的数据(通常包含用户标识)
         data (dict): 要编码到令牌中的数据(通常包含用户标识)
         expires_delta (Optional[timedelta]): 令牌过期时间,如果不提供则使用默认值
         expires_delta (Optional[timedelta]): 令牌过期时间,如果不提供则使用默认值
+        token_type (str): 令牌类型,"access" 或 "refresh"
 
 
     Returns:
     Returns:
         str: 编码后的JWT令牌
         str: 编码后的JWT令牌
     """
     """
-    # 复制数据以避免修改原始数据
     to_encode = data.copy()
     to_encode = data.copy()
 
 
-    # 计算过期时间,没有设置过期时间,就默认设置15分钟
     if expires_delta:
     if expires_delta:
         expire = datetime.now(timezone.utc) + expires_delta
         expire = datetime.now(timezone.utc) + expires_delta
     else:
     else:
         expire = datetime.now(timezone.utc) + timedelta(minutes=15)
         expire = datetime.now(timezone.utc) + timedelta(minutes=15)
 
 
-    # 添加过期时间到令牌数据中
-    to_encode.update({"exp": expire})
-
-    # 使用密钥和算法对数据进行编码
+    to_encode.update({"exp": expire, "type": token_type})
     encoded_jwt = jwt.encode(to_encode, config.SECRET_KEY, algorithm=config.ALGORITHM)
     encoded_jwt = jwt.encode(to_encode, config.SECRET_KEY, algorithm=config.ALGORITHM)
     return encoded_jwt
     return encoded_jwt
 
 
@@ -334,19 +334,27 @@ async def login_for_access_token(
             headers={"WWW-Authenticate": "Bearer"},
             headers={"WWW-Authenticate": "Bearer"},
         )
         )
 
 
-    # 设置令牌过期时间
-    access_token_expires = timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES)
-
-    # 创建访问令牌,将用户名作为subject存储在令牌中
+    # 创建访问令牌和刷新令牌
     access_token = create_access_token(
     access_token = create_access_token(
         data={"sub": user.username, "userId": user.userId},
         data={"sub": user.username, "userId": user.userId},
-        expires_delta=access_token_expires
+        expires_delta=timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES),
+        token_type="access"
+    )
+    refresh_token = create_access_token(
+        data={"sub": user.username, "userId": user.userId},
+        expires_delta=timedelta(days=7),
+        token_type="refresh"
     )
     )
 
 
+    # 存储刷新令牌哈希到Redis
+    refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
+    redis_client.setex(f"refresh:{user.userId}", 7*24*60*60, refresh_hash)
+
     # 返回令牌和令牌类型
     # 返回令牌和令牌类型
     return {
     return {
         "message": "登录成功",
         "message": "登录成功",
         "access_token": access_token,
         "access_token": access_token,
+        "refresh_token": refresh_token,
         "token_type": "bearer",
         "token_type": "bearer",
         "username": user.username
         "username": user.username
     }
     }
@@ -401,8 +409,7 @@ async def logout(
     """
     """
     用户退出登录端点
     用户退出登录端点
 
 
-    由于JWT是无状态的,服务端不需要做特殊处理
-    主要是返回成功消息,让前端清除本地存储的token
+    删除Redis中的刷新令牌,让前端清除本地存储的token
 
 
     Args:
     Args:
         current_user (User): 通过依赖注入获取的当前用户信息
         current_user (User): 通过依赖注入获取的当前用户信息
@@ -410,7 +417,7 @@ async def logout(
     Returns:
     Returns:
         dict: 退出成功的消息
         dict: 退出成功的消息
     """
     """
-    print(current_user.username)
+    redis_client.delete(f"refresh:{current_user.userId}")
     return {
     return {
         "message": "退出登录成功",
         "message": "退出登录成功",
         "username": current_user.username,
         "username": current_user.username,
@@ -418,6 +425,68 @@ async def logout(
     }
     }
 
 
 
 
+@router.post("/refresh", response_model=Token, summary="刷新令牌", description="使用刷新令牌获取新的访问令牌")
+async def refresh_access_token(refresh_token: str) -> Token:
+    """
+    刷新令牌端点
+
+    使用刷新令牌换取新的访问令牌
+
+    Args:
+        refresh_token (str): 刷新令牌
+
+    Returns:
+        Token: 包含新访问令牌的对象
+
+    Raises:
+        HTTPException: 如果刷新令牌无效或已过期
+    """
+    credentials_exception = HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail="Invalid refresh token",
+        headers={"WWW-Authenticate": "Bearer"},
+    )
+
+    try:
+        payload = jwt.decode(refresh_token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
+        if payload.get("type") != "refresh":
+            raise credentials_exception
+
+        username: str = payload.get("sub")
+        user_id: str = payload.get("userId")
+        if not username or not user_id:
+            raise credentials_exception
+
+    except jwt.PyJWTError:
+        raise credentials_exception
+
+    # 验证Redis中的刷新令牌哈希
+    refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
+    stored_hash = redis_client.get(f"refresh:{user_id}")
+    if not stored_hash or stored_hash.decode() != refresh_hash:
+        raise credentials_exception
+
+    # 验证用户是否存在
+    user = get_user(fake_users_db, username)
+    if not user:
+        raise credentials_exception
+
+    # 生成新的访问令牌
+    new_access_token = create_access_token(
+        data={"sub": username, "userId": user_id},
+        expires_delta=timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES),
+        token_type="access"
+    )
+
+    return {
+        "message": "令牌刷新成功",
+        "access_token": new_access_token,
+        "refresh_token": refresh_token,
+        "token_type": "bearer",
+        "username": username
+    }
+
+
 @router.get("/me", response_model=User, summary="获取用户信息", description="获取当前登录用户的个人信息")
 @router.get("/me", response_model=User, summary="获取用户信息", description="获取当前登录用户的个人信息")
 async def read_users_me(
 async def read_users_me(
         current_user: Annotated[User, Depends(get_current_active_user)]
         current_user: Annotated[User, Depends(get_current_active_user)]