|
|
@@ -11,10 +11,17 @@ REDIS_DB = int(os.getenv("REDIS_DB", 0))
|
|
|
r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
|
|
|
|
|
def create_session(user_id: int, expires_days: int = 365) -> str:
|
|
|
- """Create a unique session ID in Redis and map it to user_id"""
|
|
|
+ """Create a unique session ID and track it for the user"""
|
|
|
session_id = str(uuid.uuid4())
|
|
|
- # Save the session with an expiration time
|
|
|
- r.setex(f"session:{session_id}", timedelta(days=expires_days), str(user_id))
|
|
|
+ expiration = timedelta(days=expires_days)
|
|
|
+
|
|
|
+ # Save the session mapping
|
|
|
+ r.setex(f"session:{session_id}", expiration, str(user_id))
|
|
|
+
|
|
|
+ # Add to user's active sessions set
|
|
|
+ r.sadd(f"user_sessions:{user_id}", session_id)
|
|
|
+ r.expire(f"user_sessions:{user_id}", expiration)
|
|
|
+
|
|
|
return session_id
|
|
|
|
|
|
def validate_session(session_id: str) -> bool:
|
|
|
@@ -22,9 +29,20 @@ def validate_session(session_id: str) -> bool:
|
|
|
return r.exists(f"session:{session_id}") == 1
|
|
|
|
|
|
def delete_session(session_id: str):
|
|
|
- """Delete a session from Redis (Logout)"""
|
|
|
+ """Delete a specific session and remove from user tracking"""
|
|
|
+ user_id = r.get(f"session:{session_id}")
|
|
|
+ if user_id:
|
|
|
+ r.srem(f"user_sessions:{user_id}", session_id)
|
|
|
r.delete(f"session:{session_id}")
|
|
|
|
|
|
+def delete_all_user_sessions(user_id: int):
|
|
|
+ """Invalidate ALL active sessions for a user (e.g. after password reset)"""
|
|
|
+ sessions = r.smembers(f"user_sessions:{user_id}")
|
|
|
+ if sessions:
|
|
|
+ for session_id in sessions:
|
|
|
+ r.delete(f"session:{session_id}")
|
|
|
+ r.delete(f"user_sessions:{user_id}")
|
|
|
+
|
|
|
def get_user_id_from_session(session_id: str):
|
|
|
"""Retrieve the user_id associated with a session"""
|
|
|
return r.get(f"session:{session_id}")
|