# auth.py (7.0 KB) — authentication and account routes.

import os
import uuid
from datetime import datetime, timedelta, timezone

from fastapi import APIRouter, Request, Depends, HTTPException, WebSocket, WebSocketDisconnect, Query

from services.global_manager import global_manager
import auth_utils
import db
import schemas
import session_utils
import locales
  10. router = APIRouter(prefix="/auth", tags=["auth"])
  11. @router.post("/register", response_model=schemas.UserResponse)
  12. async def register(request: Request, user: schemas.UserCreate, lang: str = "en"):
  13. existing_user = db.execute_query("SELECT id FROM users WHERE email = %s", (user.email,))
  14. if existing_user:
  15. raise HTTPException(status_code=400, detail=locales.translate_error("email_already_registered", lang))
  16. ip_address = request.client.host if request.client else None
  17. hashed_password = auth_utils.get_password_hash(user.password)
  18. query = """
  19. INSERT INTO users (email, password_hash, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address)
  20. VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
  21. """
  22. params = (user.email, hashed_password, user.first_name, user.last_name, user.phone, user.shipping_address, user.preferred_language, 'user', ip_address)
  23. user_id = db.execute_commit(query, params)
  24. new_user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address, created_at FROM users WHERE id = %s", (user_id,))
  25. return new_user[0]
  26. @router.post("/login", response_model=schemas.Token)
  27. async def login(user_data: schemas.UserLogin, lang: str = "en"):
  28. user = db.execute_query("SELECT * FROM users WHERE email = %s", (user_data.email,))
  29. if not user or not auth_utils.verify_password(user_data.password, user[0]['password_hash']):
  30. raise HTTPException(status_code=401, detail=locales.translate_error("incorrect_credentials", lang))
  31. access_token = auth_utils.create_access_token(
  32. data={"sub": user[0]['email'], "id": user[0]['id'], "role": user[0]['role']}
  33. )
  34. return {"access_token": access_token, "token_type": "bearer"}
  35. @router.post("/social-login", response_model=schemas.Token)
  36. async def social_login(request: Request, data: schemas.SocialLogin):
  37. user = db.execute_query("SELECT id, email, role FROM users WHERE email = %s", (data.email,))
  38. if user:
  39. access_token = auth_utils.create_access_token(
  40. data={"sub": user[0]['email'], "id": user[0]['id'], "role": user[0]['role']}
  41. )
  42. return {"access_token": access_token, "token_type": "bearer"}
  43. else:
  44. ip_address = request.client.host if request.client else None
  45. hashed_password = auth_utils.get_password_hash(str(uuid.uuid4()))
  46. query = "INSERT INTO users (email, password_hash, first_name, last_name, preferred_language, role, ip_address) VALUES (%s, %s, %s, %s, %s, %s, %s)"
  47. params = (data.email, hashed_password, data.first_name, data.last_name, data.preferred_language, 'user', ip_address)
  48. user_id = db.execute_commit(query, params)
  49. access_token = auth_utils.create_access_token(data={"sub": data.email, "id": user_id, "role": 'user'})
  50. return {"access_token": access_token, "token_type": "bearer"}
  51. @router.post("/logout")
  52. async def logout(token: str = Depends(auth_utils.oauth2_scheme)):
  53. payload = auth_utils.decode_token(token)
  54. if payload:
  55. sid = payload.get("sid")
  56. if sid: session_utils.delete_session(sid)
  57. return {"message": "Successfully logged out"}
  58. @router.post("/forgot-password")
  59. async def forgot_password(request: schemas.ForgotPassword):
  60. user = db.execute_query("SELECT id FROM users WHERE email = %s", (request.email,))
  61. if not user: raise HTTPException(status_code=404, detail="Email not found")
  62. token = str(uuid.uuid4())
  63. expires_at = datetime.utcnow() + timedelta(minutes=15)
  64. db.execute_commit("INSERT INTO password_reset_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)", (user[0]['id'], token, expires_at))
  65. return {"message": "Reset instructions sent to your email", "demo_token": token}
  66. @router.post("/reset-password")
  67. async def reset_password(request: schemas.ResetPassword):
  68. reset_data = db.execute_query("SELECT user_id, expires_at FROM password_reset_tokens WHERE token = %s", (request.token,))
  69. if not reset_data: raise HTTPException(status_code=400, detail="Invalid token")
  70. if reset_data[0]['expires_at'] < datetime.utcnow(): raise HTTPException(status_code=400, detail="Token expired")
  71. hashed_password = auth_utils.get_password_hash(request.new_password)
  72. db.execute_commit("UPDATE users SET password_hash = %s WHERE id = %s", (hashed_password, reset_data[0]['user_id']))
  73. db.execute_commit("DELETE FROM password_reset_tokens WHERE token = %s", (request.token,))
  74. return {"message": "Password reset successfully"}
  75. @router.get("/me", response_model=schemas.UserResponse)
  76. async def get_me(token: str = Depends(auth_utils.oauth2_scheme)):
  77. payload = auth_utils.decode_token(token)
  78. if not payload: raise HTTPException(status_code=401, detail="Invalid token")
  79. user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address, created_at FROM users WHERE id = %s", (payload.get("id"),))
  80. if not user: raise HTTPException(status_code=404, detail="User not found")
  81. return user[0]
  82. @router.put("/me", response_model=schemas.UserResponse)
  83. async def update_me(data: schemas.UserUpdate, token: str = Depends(auth_utils.oauth2_scheme)):
  84. payload = auth_utils.decode_token(token)
  85. if not payload: raise HTTPException(status_code=401, detail="Invalid token")
  86. user_id = payload.get("id")
  87. update_fields = []
  88. params = []
  89. for field, value in data.dict(exclude_unset=True).items():
  90. update_fields.append(f"{field} = %s")
  91. params.append(value)
  92. if update_fields:
  93. query = f"UPDATE users SET {', '.join(update_fields)} WHERE id = %s"
  94. params.append(user_id)
  95. db.execute_commit(query, tuple(params))
  96. user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address, created_at FROM users WHERE id = %s", (user_id,))
  97. return user[0]
  98. @router.websocket("/ws/global")
  99. async def ws_global(websocket: WebSocket, token: str = Query(...)):
  100. payload = auth_utils.decode_token(token)
  101. if not payload:
  102. await websocket.close(code=4001)
  103. return
  104. user_id = payload.get("id")
  105. role = payload.get("role")
  106. if not user_id:
  107. await websocket.close(code=4001)
  108. return
  109. await global_manager.connect(websocket, user_id, role)
  110. session_utils.track_user_ping(user_id)
  111. # Send initial unread count
  112. await global_manager.notify_user(user_id) if role != 'admin' else await global_manager.notify_admins()
  113. try:
  114. while True:
  115. data = await websocket.receive_text()
  116. if data == "ping":
  117. session_utils.track_user_ping(user_id)
  118. except WebSocketDisconnect:
  119. global_manager.disconnect(websocket, user_id)