Selaa lähdekoodia

security: invalidate all sessions on password reset

unknown 2 päivää sitten
vanhempi
commit
d68b9aa3cf
2 muutettua tiedostoa jossa 24 lisäystä ja 5 poistoa
  1. 2 1
      backend/routers/auth.py
  2. 22 4
      backend/session_utils.py

+ 2 - 1
backend/routers/auth.py

@@ -195,8 +195,9 @@ async def reset_password(request: schemas.ResetPassword):
     hashed_password = auth_utils.get_password_hash(request.new_password)
     db.execute_commit("UPDATE users SET password_hash = %s WHERE id = %s", (hashed_password, user_id))
     
-    # Successful reset - Cleanup ALL reset tokens for this user
+    # Successful reset - Cleanup ALL reset tokens AND sessions for this user
     token_service.cleanup_reset_tokens(user_id)
+    session_utils.delete_all_user_sessions(user_id)
     
     return {"message": "Password updated successfully"}
 

+ 22 - 4
backend/session_utils.py

@@ -11,10 +11,17 @@ REDIS_DB = int(os.getenv("REDIS_DB", 0))
 r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
 
 def create_session(user_id: int, expires_days: int = 365) -> str:
-    """Create a unique session ID in Redis and map it to user_id"""
+    """Create a unique session ID and track it for the user"""
     session_id = str(uuid.uuid4())
-    # Save the session with an expiration time
-    r.setex(f"session:{session_id}", timedelta(days=expires_days), str(user_id))
+    expiration = timedelta(days=expires_days)
+    
+    # Save the session mapping
+    r.setex(f"session:{session_id}", expiration, str(user_id))
+    
+    # Add to user's active sessions set
+    r.sadd(f"user_sessions:{user_id}", session_id)
+    r.expire(f"user_sessions:{user_id}", expiration)
+    
     return session_id
 
 def validate_session(session_id: str) -> bool:
@@ -22,9 +29,20 @@ def validate_session(session_id: str) -> bool:
     return r.exists(f"session:{session_id}") == 1
 
 def delete_session(session_id: str):
-    """Delete a session from Redis (Logout)"""
+    """Delete a specific session and remove from user tracking"""
+    user_id = r.get(f"session:{session_id}")
+    if user_id:
+        r.srem(f"user_sessions:{user_id}", session_id)
     r.delete(f"session:{session_id}")
 
+def delete_all_user_sessions(user_id: int):
+    """Invalidate ALL active sessions for a user (e.g. after password reset)"""
+    sessions = r.smembers(f"user_sessions:{user_id}")
+    if sessions:
+        for session_id in sessions:
+            r.delete(f"session:{session_id}")
+        r.delete(f"user_sessions:{user_id}")
+
 def get_user_id_from_session(session_id: str):
     """Retrieve the user_id associated with a session"""
     return r.get(f"session:{session_id}")