auth_utils.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext

import session_utils
  7. # Configuration
  8. SECRET_KEY = "your-secret-key-replace-with-env-variable"
  9. ALGORITHM = "HS256"
  10. ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 365 # 1 year
  11. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
  12. oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="auth/login", auto_error=False)
  13. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  14. def verify_password(plain_password, hashed_password):
  15. return pwd_context.verify(plain_password, hashed_password)
  16. def get_password_hash(password):
  17. return pwd_context.hash(password)
  18. def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
  19. to_encode = data.copy()
  20. if expires_delta:
  21. expire = datetime.utcnow() + expires_delta
  22. else:
  23. expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
  24. # Create a persistent session in Redis for tracking
  25. sid = session_utils.create_session(data.get("id", 0))
  26. to_encode.update({"exp": expire, "sid": sid})
  27. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  28. return encoded_jwt
  29. def decode_token(token: str):
  30. try:
  31. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  32. sid = payload.get("sid")
  33. # Ensure session exists in Redis for stateful revocation
  34. if sid and not session_utils.validate_session(sid):
  35. return None
  36. return payload
  37. except JWTError:
  38. return None