|
|
@@ -1,11 +1,13 @@
|
|
|
from datetime import datetime, timedelta, timezone
|
|
|
from typing import Optional, Annotated
|
|
|
import jwt
|
|
|
+import hashlib
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
|
from pydantic import BaseModel
|
|
|
from pwdlib import PasswordHash
|
|
|
from ..config.config import Config
|
|
|
+from ..db.redis_client import redis_client
|
|
|
|
|
|
config = Config()
|
|
|
# =====================================================
|
|
|
@@ -49,6 +51,7 @@ class Token(BaseModel):
|
|
|
"""
|
|
|
message: str
|
|
|
access_token: str # JWT访问令牌
|
|
|
+ refresh_token: str # JWT刷新令牌
|
|
|
token_type: str # 令牌类型,通常是"bearer"
|
|
|
username: str
|
|
|
|
|
|
@@ -116,7 +119,8 @@ fake_users_db = {
|
|
|
# 这是"admin"的bcrypt哈希值
|
|
|
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$AnPKKWadE68rd3n9CSM7vQ$p0LA3GqVSYxGRopB0B1yzzlO7W1LRCWagShF+Sbre9I",
|
|
|
"disabled": False,
|
|
|
- },
|
|
|
+ }
|
|
|
+ ,
|
|
|
"admin": {
|
|
|
"userId":"2",
|
|
|
"username": "admin",
|
|
|
@@ -201,30 +205,26 @@ def authenticate_user(fake_db: dict, username: str, password: str) -> Optional[U
|
|
|
return user
|
|
|
|
|
|
|
|
|
-def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
|
|
+def create_access_token(data: dict, expires_delta: Optional[timedelta] = None, token_type: str = "access") -> str:
|
|
|
"""
|
|
|
- 创建JWT访问令牌
|
|
|
+ 创建JWT令牌(access 或 refresh)
|
|
|
|
|
|
Args:
|
|
|
data (dict): 要编码到令牌中的数据(通常包含用户标识)
|
|
|
expires_delta (Optional[timedelta]): 令牌过期时间,如果不提供则使用默认值
|
|
|
+ token_type (str): 令牌类型,"access" 或 "refresh"
|
|
|
|
|
|
Returns:
|
|
|
str: 编码后的JWT令牌
|
|
|
"""
|
|
|
- # 复制数据以避免修改原始数据
|
|
|
to_encode = data.copy()
|
|
|
|
|
|
- # 计算过期时间,没有设置过期时间,就默认设置15分钟
|
|
|
if expires_delta:
|
|
|
expire = datetime.now(timezone.utc) + expires_delta
|
|
|
else:
|
|
|
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
|
|
|
|
|
|
- # 添加过期时间到令牌数据中
|
|
|
- to_encode.update({"exp": expire})
|
|
|
-
|
|
|
- # 使用密钥和算法对数据进行编码
|
|
|
+ to_encode.update({"exp": expire, "type": token_type})
|
|
|
encoded_jwt = jwt.encode(to_encode, config.SECRET_KEY, algorithm=config.ALGORITHM)
|
|
|
return encoded_jwt
|
|
|
|
|
|
@@ -334,19 +334,27 @@ async def login_for_access_token(
|
|
|
headers={"WWW-Authenticate": "Bearer"},
|
|
|
)
|
|
|
|
|
|
- # 设置令牌过期时间
|
|
|
- access_token_expires = timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
|
-
|
|
|
- # 创建访问令牌,将用户名作为subject存储在令牌中
|
|
|
+ # 创建访问令牌和刷新令牌
|
|
|
access_token = create_access_token(
|
|
|
data={"sub": user.username, "userId": user.userId},
|
|
|
- expires_delta=access_token_expires
|
|
|
+ expires_delta=timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES),
|
|
|
+ token_type="access"
|
|
|
+ )
|
|
|
+ refresh_token = create_access_token(
|
|
|
+ data={"sub": user.username, "userId": user.userId},
|
|
|
+ expires_delta=timedelta(days=7),
|
|
|
+ token_type="refresh"
|
|
|
)
|
|
|
|
|
|
+ # 存储刷新令牌哈希到Redis
|
|
|
+ refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
|
|
+ redis_client.setex(f"refresh:{user.userId}", 7*24*60*60, refresh_hash)
|
|
|
+
|
|
|
# 返回令牌和令牌类型
|
|
|
return {
|
|
|
"message": "登录成功",
|
|
|
"access_token": access_token,
|
|
|
+ "refresh_token": refresh_token,
|
|
|
"token_type": "bearer",
|
|
|
"username": user.username
|
|
|
}
|
|
|
@@ -401,8 +409,7 @@ async def logout(
|
|
|
"""
|
|
|
用户退出登录端点
|
|
|
|
|
|
- 由于JWT是无状态的,服务端不需要做特殊处理
|
|
|
- 主要是返回成功消息,让前端清除本地存储的token
|
|
|
+ 删除Redis中的刷新令牌,让前端清除本地存储的token
|
|
|
|
|
|
Args:
|
|
|
current_user (User): 通过依赖注入获取的当前用户信息
|
|
|
@@ -410,7 +417,7 @@ async def logout(
|
|
|
Returns:
|
|
|
dict: 退出成功的消息
|
|
|
"""
|
|
|
- print(current_user.username)
|
|
|
+ redis_client.delete(f"refresh:{current_user.userId}")
|
|
|
return {
|
|
|
"message": "退出登录成功",
|
|
|
"username": current_user.username,
|
|
|
@@ -418,6 +425,68 @@ async def logout(
|
|
|
}
|
|
|
|
|
|
|
|
|
+@router.post("/refresh", response_model=Token, summary="刷新令牌", description="使用刷新令牌获取新的访问令牌")
|
|
|
+async def refresh_access_token(refresh_token: str) -> Token:
|
|
|
+ """
|
|
|
+ 刷新令牌端点
|
|
|
+
|
|
|
+ 使用刷新令牌换取新的访问令牌
|
|
|
+
|
|
|
+ Args:
|
|
|
+ refresh_token (str): 刷新令牌
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Token: 包含新访问令牌的对象
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ HTTPException: 如果刷新令牌无效或已过期
|
|
|
+ """
|
|
|
+ credentials_exception = HTTPException(
|
|
|
+ status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
+ detail="Invalid refresh token",
|
|
|
+ headers={"WWW-Authenticate": "Bearer"},
|
|
|
+ )
|
|
|
+
|
|
|
+ try:
|
|
|
+ payload = jwt.decode(refresh_token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
|
|
|
+ if payload.get("type") != "refresh":
|
|
|
+ raise credentials_exception
|
|
|
+
|
|
|
+ username: str = payload.get("sub")
|
|
|
+ user_id: str = payload.get("userId")
|
|
|
+ if not username or not user_id:
|
|
|
+ raise credentials_exception
|
|
|
+
|
|
|
+ except jwt.PyJWTError:
|
|
|
+ raise credentials_exception
|
|
|
+
|
|
|
+ # 验证Redis中的刷新令牌哈希
|
|
|
+ refresh_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
|
|
+ stored_hash = redis_client.get(f"refresh:{user_id}")
|
|
|
+ if not stored_hash or stored_hash.decode() != refresh_hash:
|
|
|
+ raise credentials_exception
|
|
|
+
|
|
|
+ # 验证用户是否存在
|
|
|
+ user = get_user(fake_users_db, username)
|
|
|
+ if not user:
|
|
|
+ raise credentials_exception
|
|
|
+
|
|
|
+ # 生成新的访问令牌
|
|
|
+ new_access_token = create_access_token(
|
|
|
+ data={"sub": username, "userId": user_id},
|
|
|
+ expires_delta=timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES),
|
|
|
+ token_type="access"
|
|
|
+ )
|
|
|
+
|
|
|
+ return {
|
|
|
+ "message": "令牌刷新成功",
|
|
|
+ "access_token": new_access_token,
|
|
|
+ "refresh_token": refresh_token,
|
|
|
+ "token_type": "bearer",
|
|
|
+ "username": username
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
@router.get("/me", response_model=User, summary="获取用户信息", description="获取当前登录用户的个人信息")
|
|
|
async def read_users_me(
|
|
|
current_user: Annotated[User, Depends(get_current_active_user)]
|