users.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. from datetime import datetime, timedelta, timezone
  2. from typing import Optional, Annotated
  3. import jwt
  4. from fastapi import APIRouter, Depends, HTTPException, status
  5. from fastapi.security import OAuth2PasswordBearer
  6. from passlib.context import CryptContext
  7. from pydantic import BaseModel
  8. # =====================================================
  9. # JWT和安全配置
  10. # =====================================================
  11. # JWT密钥配置 - 生产环境中应该使用环境变量或密钥管理服务
  12. SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" # 警告:生产环境请使用强密钥并通过环境变量管理
  13. ALGORITHM = "HS256" # JWT签名算法
  14. ACCESS_TOKEN_EXPIRE_MINUTES = 30 # Token过期时间(分钟)
  15. # 密码加密上下文配置
  16. # schemes: 支持的密码哈希方案,bcrypt是目前推荐的安全哈希算法
  17. # deprecated: 标记为已弃用的方案(用于向后兼容)
  18. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  19. # OAuth2密码Bearer令牌方案
  20. # tokenUrl: 获取token的端点URL,必须与实际的token端点路径匹配
  21. # 这告诉FastAPI和前端客户端在哪里获取访问令牌
  22. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/token")
  23. # 创建路由器实例
  24. # 这个路由器将包含所有用户相关的路由
  25. router = APIRouter()
  26. # =====================================================
  27. # Pydantic数据模型定义
  28. # =====================================================
  29. class LoginRequest(BaseModel):
  30. """
  31. 用户登录输入的模型
  32. """
  33. username: str
  34. password: str
  35. class Token(BaseModel):
  36. """
  37. 访问令牌响应模型
  38. 用于登录成功后返回JWT令牌
  39. """
  40. message: str
  41. access_token: str # JWT访问令牌
  42. token_type: str # 令牌类型,通常是"bearer"
  43. username: str
  44. class TokenData(BaseModel):
  45. """
  46. 令牌数据模型
  47. 用于解析JWT令牌中的用户信息
  48. """
  49. username: Optional[str] = None # 用户名(可选)
  50. class User(BaseModel):
  51. """
  52. 用户基础信息模型
  53. 定义用户的公开信息(不包含密码等敏感信息)
  54. """
  55. username: str # 用户名(必需)
  56. email: Optional[str] = None # 邮箱(可选)
  57. full_name: Optional[str] = None # 全名(可选)
  58. disabled: Optional[bool] = None # 是否禁用(可选)
  59. class UserInDB(User):
  60. """
  61. 数据库中的用户模型
  62. 继承User模型,添加了密码哈希字段
  63. """
  64. hashed_password: str # 哈希后的密码
  65. class UserCreate(BaseModel):
  66. """
  67. 用户创建请求模型
  68. 用于用户注册时接收前端传来的数据
  69. """
  70. username: str # 用户名(必需)
  71. password: str # 明文密码(必需)
  72. email: Optional[str] = None # 邮箱(可选)
  73. full_name: Optional[str] = None # 全名(可选)
  74. class UserUpdate(BaseModel):
  75. """
  76. 用户更新请求模型
  77. 用于更新用户信息时接收前端传来的数据
  78. """
  79. email: Optional[str] = None # 新邮箱(可选)
  80. full_name: Optional[str] = None # 新全名(可选)
  81. # =====================================================
  82. # 模拟数据库
  83. # =====================================================
  84. # 模拟用户数据库 - 生产环境中应该使用真实的数据库(如PostgreSQL、MySQL等)
  85. # 这里使用字典来模拟数据库存储,包含一个默认管理员账户
  86. fake_users_db = {
  87. "root": {
  88. "username": "root",
  89. "full_name": "Administrator",
  90. "email": "admin@example.com",
  91. # 这是"admin123"的bcrypt哈希值
  92. "hashed_password": "$2b$12$p2v617r0nPHKa4LVd6j7puYqR0lD8xivcwvtCp9UBziF5c2dRhFe.",
  93. "disabled": False,
  94. }
  95. }
  96. # =====================================================
  97. # 工具函数
  98. # =====================================================
  99. def verify_password(plain_password: str, hashed_password: str) -> bool:
  100. """
  101. 验证密码是否正确
  102. Args:
  103. plain_password (str): 用户输入的明文密码
  104. hashed_password (str): 数据库中存储的密码哈希
  105. Returns:
  106. bool: 密码是否匹配
  107. """
  108. return pwd_context.verify(plain_password, hashed_password)
  109. def get_password_hash(password: str) -> str:
  110. """
  111. 生成密码的哈希值
  112. Args:
  113. password (str): 明文密码
  114. Returns:
  115. str: 密码的bcrypt哈希值
  116. """
  117. return pwd_context.hash(password)
  118. def get_user(db: dict, username: str) -> Optional[UserInDB]:
  119. """
  120. 从数据库中获取用户信息
  121. Args:
  122. db (dict): 用户数据库
  123. username (str): 用户名
  124. Returns:
  125. Optional[UserInDB]: 用户信息对象,如果用户不存在则返回None
  126. """
  127. if username in db:
  128. user_dict = db[username]
  129. return UserInDB(**user_dict)
  130. return None
  131. def authenticate_user(fake_db: dict, username: str, password: str) -> Optional[UserInDB]:
  132. """
  133. 验证用户身份
  134. Args:
  135. fake_db (dict): 用户数据库
  136. username (str): 用户名
  137. password (str): 明文密码
  138. Returns:
  139. Optional[UserInDB]: 验证成功返回用户对象,失败返回False
  140. """
  141. # 首先获取用户信息
  142. user = get_user(fake_db, username)
  143. if not user:
  144. return False
  145. # 验证密码是否正确
  146. if not verify_password(password, user.hashed_password):
  147. return False
  148. return user
  149. def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
  150. """
  151. 创建JWT访问令牌
  152. Args:
  153. data (dict): 要编码到令牌中的数据(通常包含用户标识)
  154. expires_delta (Optional[timedelta]): 令牌过期时间,如果不提供则使用默认值
  155. Returns:
  156. str: 编码后的JWT令牌
  157. """
  158. # 复制数据以避免修改原始数据
  159. to_encode = data.copy()
  160. # 计算过期时间,没有设置过期时间,就默认设置15分钟
  161. if expires_delta:
  162. expire = datetime.now(timezone.utc) + expires_delta
  163. else:
  164. expire = datetime.now(timezone.utc) + timedelta(minutes=15)
  165. # 添加过期时间到令牌数据中
  166. to_encode.update({"exp": expire})
  167. # 使用密钥和算法对数据进行编码
  168. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  169. return encoded_jwt
  170. # =====================================================
  171. # 依赖函数(用于路由中的依赖注入)
  172. # =====================================================
  173. async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> UserInDB:
  174. """
  175. 从JWT令牌中获取当前用户信息
  176. 这是一个依赖函数,会被其他需要用户身份验证的路由使用
  177. Args:
  178. token (str): 从请求头中提取的Bearer令牌
  179. Returns:
  180. UserInDB: 当前用户信息
  181. Raises:
  182. HTTPException: 如果令牌无效或用户不存在
  183. """
  184. # 定义认证异常,当令牌验证失败时抛出
  185. credentials_exception = HTTPException(
  186. status_code=status.HTTP_401_UNAUTHORIZED,
  187. detail="Could not validate credentials", # 无法验证凭据
  188. headers={"WWW-Authenticate": "Bearer"}, # 告诉客户端使用Bearer认证
  189. )
  190. try:
  191. # 解码JWT令牌
  192. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  193. # 从令牌中提取用户名(sub是JWT标准字段,表示subject/主题)
  194. username: str = payload.get("sub")
  195. if username is None:
  196. raise credentials_exception
  197. # 创建令牌数据对象
  198. token_data = TokenData(username=username)
  199. except jwt.PyJWTError:
  200. # JWT解码失败(令牌无效、过期等)
  201. raise credentials_exception
  202. # 从数据库中获取用户信息
  203. user = get_user(fake_users_db, username=token_data.username)
  204. if user is None:
  205. raise credentials_exception
  206. return user
  207. async def get_current_active_user(
  208. current_user: Annotated[User, Depends(get_current_user)]
  209. ) -> User:
  210. """
  211. 获取当前活跃用户
  212. 这是另一个依赖函数,确保用户不仅通过了身份验证,而且账户是活跃的
  213. Args:
  214. current_user (User): 从get_current_user依赖中获取的当前用户
  215. Returns:
  216. User: 活跃的用户信息
  217. Raises:
  218. HTTPException: 如果用户账户被禁用
  219. """
  220. if current_user.disabled:
  221. raise HTTPException(status_code=400, detail="Inactive user")
  222. return current_user
  223. # =====================================================
  224. # API路由端点
  225. # =====================================================
  226. @router.post("/token", response_model=Token, summary="用户登录", description="使用用户名和密码获取JWT访问令牌")
  227. async def login_for_access_token(
  228. login_data: LoginRequest
  229. ) -> Token:
  230. """
  231. 用户登录端点
  232. 接受用户名和密码,返回JWT访问令牌
  233. 使用OAuth2PasswordRequestForm来接收表单数据(username、password字段)
  234. Args:
  235. login_data : 包含用户名和密码的json数据
  236. Returns:
  237. Token: 包含访问令牌和令牌类型的对象
  238. Raises:
  239. HTTPException: 如果用户名或密码不正确
  240. """
  241. # 验证用户身份
  242. user = authenticate_user(fake_users_db, login_data.username, login_data.password)
  243. if not user:
  244. # 认证失败,返回401未授权状态码
  245. raise HTTPException(
  246. status_code=status.HTTP_401_UNAUTHORIZED,
  247. detail="用户名或密码错误", # 用户名或密码错误
  248. headers={"WWW-Authenticate": "Bearer"},
  249. )
  250. # 设置令牌过期时间
  251. access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
  252. # 创建访问令牌,将用户名作为subject存储在令牌中
  253. access_token = create_access_token(
  254. data={"sub": user.username},
  255. expires_delta=access_token_expires
  256. )
  257. # 返回令牌和令牌类型
  258. return {
  259. "message": "登录成功",
  260. "access_token": access_token,
  261. "token_type": "bearer",
  262. "username": user.username
  263. }
  264. @router.post("/register", response_model=User, summary="用户注册", description="创建新用户账户")
  265. async def register_user(user: UserCreate) -> User:
  266. """
  267. 用户注册端点
  268. 创建新的用户账户,密码会被自动哈希加密存储
  269. Args:
  270. user (UserCreate): 包含用户注册信息的对象
  271. Returns:
  272. User: 创建成功的用户信息(不包含密码)
  273. Raises:
  274. HTTPException: 如果用户名已存在
  275. """
  276. # 检查用户名是否已经存在
  277. if user.username in fake_users_db:
  278. raise HTTPException(
  279. status_code=status.HTTP_400_BAD_REQUEST,
  280. detail="Username already registered" # 用户名已被注册
  281. )
  282. # 对密码进行哈希加密
  283. hashed_password = get_password_hash(user.password)
  284. # 创建用户数据字典
  285. user_dict = {
  286. "username": user.username,
  287. "full_name": user.full_name,
  288. "email": user.email,
  289. "hashed_password": hashed_password,
  290. "disabled": False # 新用户默认为启用状态
  291. }
  292. # 将用户数据保存到"数据库"
  293. fake_users_db[user.username] = user_dict
  294. # 返回用户信息(不包含密码哈希)
  295. return User(**user_dict)
  296. @router.post("/logout", summary="用户退出", description="用户退出登录")
  297. async def logout(
  298. current_user: Annotated[User, Depends(get_current_active_user)]
  299. ) -> dict:
  300. """
  301. 用户退出登录端点
  302. 由于JWT是无状态的,服务端不需要做特殊处理
  303. 主要是返回成功消息,让前端清除本地存储的token
  304. Args:
  305. current_user (User): 通过依赖注入获取的当前用户信息
  306. Returns:
  307. dict: 退出成功的消息
  308. """
  309. print(current_user.username)
  310. return {
  311. "message": "退出登录成功",
  312. "username": current_user.username,
  313. "logout_time": datetime.now(timezone.utc).isoformat()
  314. }
  315. @router.get("/me", response_model=User, summary="获取用户信息", description="获取当前登录用户的个人信息")
  316. async def read_users_me(
  317. current_user: Annotated[User, Depends(get_current_active_user)]
  318. ) -> User:
  319. """
  320. 获取当前用户信息端点
  321. 需要有效的JWT令牌才能访问
  322. Args:
  323. current_user (User): 通过依赖注入获取的当前用户信息
  324. Returns:
  325. User: 当前用户的信息
  326. """
  327. return current_user
  328. @router.put("/me", response_model=User, summary="更新用户信息", description="更新当前登录用户的个人信息")
  329. async def update_user_me(
  330. user_update: UserUpdate,
  331. current_user: Annotated[User, Depends(get_current_active_user)]
  332. ) -> User:
  333. """
  334. 更新当前用户信息端点
  335. 允许用户更新自己的邮箱和全名信息
  336. Args:
  337. user_update (UserUpdate): 包含要更新的用户信息
  338. current_user (User): 通过依赖注入获取的当前用户信息
  339. Returns:
  340. User: 更新后的用户信息
  341. """
  342. # 只更新非None的字段
  343. if user_update.email is not None:
  344. fake_users_db[current_user.username]["email"] = user_update.email
  345. if user_update.full_name is not None:
  346. fake_users_db[current_user.username]["full_name"] = user_update.full_name
  347. # 获取并返回更新后的用户信息
  348. updated_user = get_user(fake_users_db, current_user.username)
  349. return User(**updated_user.model_dump())
  350. @router.get("/protected", summary="受保护的路由示例", description="演示需要身份验证才能访问的路由")
  351. async def protected_route(
  352. current_user: Annotated[User, Depends(get_current_active_user)]
  353. ) -> dict:
  354. """
  355. 受保护的路由示例
  356. 这个端点演示了如何创建需要身份验证的路由
  357. 只有提供有效JWT令牌的用户才能访问
  358. Args:
  359. current_user (User): 通过依赖注入获取的当前用户信息
  360. Returns:
  361. dict: 包含欢迎消息的字典
  362. """
  363. return {
  364. "message": f"Hello {current_user.username}, this is a protected route!",
  365. "user_info": {
  366. "username": current_user.username,
  367. "email": current_user.email,
  368. "full_name": current_user.full_name
  369. },
  370. "access_time": datetime.now(timezone.utc).isoformat()
  371. }
  372. @router.get("/all", summary="获取所有用户", description="获取系统中所有用户的列表(需要管理员权限)")
  373. async def get_all_users(
  374. current_user: Annotated[User, Depends(get_current_active_user)]
  375. ) -> dict:
  376. """
  377. 获取所有用户列表端点
  378. 返回系统中所有用户的信息(不包含密码)
  379. 注意:在实际应用中,这个功能应该有权限控制
  380. Args:
  381. current_user (User): 通过依赖注入获取的当前用户信息
  382. Returns:
  383. dict: 包含用户列表和总数的字典
  384. """
  385. users = []
  386. for username, user_data in fake_users_db.items():
  387. # 创建User对象(不包含密码哈希)
  388. user_info = {k: v for k, v in user_data.items() if k != 'hashed_password'}
  389. users.append(User(**user_info))
  390. return {
  391. "users": users,
  392. "total": len(users),
  393. "requested_by": current_user.username
  394. }
  395. @router.delete("/me", summary="删除用户账户", description="删除当前登录用户的账户")
  396. async def delete_user_account(
  397. current_user: Annotated[User, Depends(get_current_active_user)]
  398. ) -> dict:
  399. """
  400. 删除当前用户账户端点
  401. 允许用户删除自己的账户
  402. 注意:在实际应用中,可能需要额外的确认步骤
  403. Args:
  404. current_user (User): 通过依赖注入获取的当前用户信息
  405. Returns:
  406. dict: 删除成功的确认消息
  407. Raises:
  408. HTTPException: 如果用户不存在(理论上不会发生)
  409. """
  410. if current_user.username in fake_users_db:
  411. # 从数据库中删除用户
  412. del fake_users_db[current_user.username]
  413. return {
  414. "message": "User account deleted successfully",
  415. "deleted_user": current_user.username,
  416. "deleted_at": datetime.now(timezone.utc).isoformat()
  417. }
  418. else:
  419. # 这种情况理论上不会发生,因为用户已经通过了身份验证
  420. raise HTTPException(
  421. status_code=status.HTTP_404_NOT_FOUND,
  422. detail="User not found"
  423. )