auth.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. from fastapi import APIRouter, Request, Depends, HTTPException, WebSocket, WebSocketDisconnect, Query
  2. from typing import Optional, List
  3. from services.global_manager import global_manager
  4. import auth_utils
  5. import db
  6. import schemas
  7. import session_utils
  8. import uuid
  9. from datetime import datetime, timedelta
  10. import locales
  # Every route declared in this module is served under the /auth prefix and tagged "auth".
  router = APIRouter(tags=["auth"], prefix="/auth")
  @router.post("/register", response_model=schemas.UserResponse)
  async def register(request: Request, user: schemas.UserCreate, lang: str = "en"):
      """Self-service sign-up: create a "user"-role account and return its profile.

      Raises HTTPException(400), with a message localized for ``lang`` via
      ``locales.translate_error``, when the e-mail is already registered.
      Only a hash of the password is stored, and the caller's IP address is
      recorded on the new row when the request exposes a client address.

      NOTE(review): the duplicate check and the INSERT are separate statements,
      so two concurrent sign-ups for the same e-mail can race; presumably a
      UNIQUE constraint on users.email backs this up — confirm.
      """
      if db.execute_query("SELECT id FROM users WHERE email = %s", (user.email,)):
          raise HTTPException(
              status_code=400,
              detail=locales.translate_error("email_already_registered", lang),
          )

      client_ip = None if not request.client else request.client.host
      password_hash = auth_utils.get_password_hash(user.password)

      insert_sql = (
          "\n"
          "INSERT INTO users (email, password_hash, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address)\n"
          "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)\n"
      )
      row_values = (
          user.email,
          password_hash,
          user.first_name,
          user.last_name,
          user.phone,
          user.shipping_address,
          user.preferred_language,
          'user',
          client_ip,
      )
      new_id = db.execute_commit(insert_sql, row_values)

      created = db.execute_query(
          "SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, ip_address, created_at FROM users WHERE id = %s",
          (new_id,),
      )
      return created[0]
  @router.post("/login", response_model=schemas.Token)
  async def login(user_data: schemas.UserLogin, lang: str = "en"):
      """Exchange e-mail/password credentials for a bearer access token.

      An unknown e-mail and a wrong password produce the same 401 status and
      the same localized message. A suspended account (``is_active`` falsy)
      gets a 403, checked only after the password has been verified.
      The token carries the user's e-mail (``sub``), ``id`` and ``role``.
      """
      rows = db.execute_query("SELECT * FROM users WHERE email = %s", (user_data.email,))
      credentials_ok = rows and auth_utils.verify_password(user_data.password, rows[0]['password_hash'])
      if not credentials_ok:
          raise HTTPException(status_code=401, detail=locales.translate_error("incorrect_credentials", lang))

      account = rows[0]
      if not account.get('is_active', True):
          raise HTTPException(status_code=403, detail="Your account has been suspended.")

      claims = {"sub": account['email'], "id": account['id'], "role": account['role']}
      return {"access_token": auth_utils.create_access_token(data=claims), "token_type": "bearer"}
  @router.post("/social-login", response_model=schemas.Token)
  async def social_login(request: Request, data: schemas.SocialLogin):
      """Issue a bearer token for a social sign-in, creating the account on first use.

      Existing accounts are matched by e-mail; suspended ones get a 403.
      A new account gets role "user" and the hash of a random UUID as its
      password, so nobody knows a password that would work on /login for it.

      SECURITY(review): nothing in this handler verifies a provider credential;
      a token is issued for whatever ``data.email`` the client sends. Unless
      ``schemas.SocialLogin`` validates a provider ID token during parsing,
      this allows signing in as any existing user. Confirm and fix.
      """
      found = db.execute_query("SELECT id, email, role, is_active FROM users WHERE email = %s", (data.email,))

      if not found:
          client_ip = None if not request.client else request.client.host
          unusable_hash = auth_utils.get_password_hash(str(uuid.uuid4()))
          new_id = db.execute_commit(
              "INSERT INTO users (email, password_hash, first_name, last_name, preferred_language, role, ip_address) VALUES (%s, %s, %s, %s, %s, %s, %s)",
              (data.email, unusable_hash, data.first_name, data.last_name, data.preferred_language, 'user', client_ip),
          )
          claims = {"sub": data.email, "id": new_id, "role": 'user'}
      else:
          existing = found[0]
          if not existing.get('is_active', True):
              raise HTTPException(status_code=403, detail="Your account has been suspended.")
          claims = {"sub": existing['email'], "id": existing['id'], "role": existing['role']}

      return {"access_token": auth_utils.create_access_token(data=claims), "token_type": "bearer"}
  @router.post("/logout")
  async def logout(token: str = Depends(auth_utils.oauth2_scheme)):
      """Delete the server-side session named by the token's ``sid`` claim, if any.

      Always answers with a success message: an undecodable token, or one
      without a ``sid``, is simply a no-op. This handler does not revoke the
      JWT itself.

      NOTE(review): the tokens built by /login and /social-login in this
      module contain no ``sid`` claim, so for them this deletes nothing.
      """
      claims = auth_utils.decode_token(token)
      session_id = claims.get("sid") if claims else None
      if session_id:
          session_utils.delete_session(session_id)
      return {"message": "Successfully logged out"}
  @router.post("/forgot-password")
  async def forgot_password(request: schemas.ForgotPassword):
      """Issue a password-reset token for the account with the given e-mail.

      The token is a random UUID string valid for 15 minutes (naive UTC
      timestamp). Any reset tokens previously issued to the same user are
      deleted first, so only the most recently requested token is redeemable.

      Raises HTTPException(404) when no account has the given e-mail.

      SECURITY(review): the response exposes the token as ``demo_token`` and
      the 404 reveals whether an e-mail is registered; nothing here actually
      sends an e-mail. Acceptable for a demo only; before real use, deliver the
      token out-of-band and respond identically for unknown e-mails.
      """
      user = db.execute_query("SELECT id FROM users WHERE email = %s", (request.email,))
      if not user:
          raise HTTPException(status_code=404, detail="Email not found")
      user_id = user[0]['id']

      # Invalidate older links so a stale/leaked one cannot be used once a newer
      # reset has been requested.
      db.execute_commit("DELETE FROM password_reset_tokens WHERE user_id = %s", (user_id,))

      token = str(uuid.uuid4())
      expires_at = datetime.utcnow() + timedelta(minutes=15)
      db.execute_commit(
          "INSERT INTO password_reset_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)",
          (user_id, token, expires_at),
      )
      return {"message": "Reset instructions sent to your email", "demo_token": token}
  @router.post("/reset-password")
  async def reset_password(request: schemas.ResetPassword):
      """Set a new password using a token issued by /forgot-password.

      Raises HTTPException(400) when the token is unknown ("Invalid token")
      or past its expiry ("Token expired"). On success, every outstanding
      reset token belonging to that user is deleted, not only the one just
      redeemed, so no other reset link can change the password again.
      """
      reset_data = db.execute_query(
          "SELECT user_id, expires_at FROM password_reset_tokens WHERE token = %s",
          (request.token,),
      )
      if not reset_data:
          raise HTTPException(status_code=400, detail="Invalid token")
      record = reset_data[0]

      # /forgot-password stores expires_at from naive datetime.utcnow(), so compare
      # against the same; presumably the DB driver returns a naive datetime — confirm.
      if record['expires_at'] < datetime.utcnow():
          raise HTTPException(status_code=400, detail="Token expired")

      hashed_password = auth_utils.get_password_hash(request.new_password)
      db.execute_commit("UPDATE users SET password_hash = %s WHERE id = %s", (hashed_password, record['user_id']))
      # Scope the cleanup to the user: other tokens issued earlier must not remain usable.
      db.execute_commit("DELETE FROM password_reset_tokens WHERE user_id = %s", (record['user_id'],))
      return {"message": "Password reset successfully"}
  @router.get("/me", response_model=schemas.UserResponse)
  async def get_me(token: str = Depends(auth_utils.oauth2_scheme)):
      """Return the profile row for the user whose ``id`` claim is in the bearer token.

      Raises HTTPException 401 when the token cannot be decoded and 404 when
      no user row has that id.
      """
      claims = auth_utils.decode_token(token)
      if not claims:
          raise HTTPException(status_code=401, detail="Invalid token")

      profile_sql = (
          "SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, "
          "role, can_chat, is_active, ip_address, created_at FROM users WHERE id = %s"
      )
      rows = db.execute_query(profile_sql, (claims.get("id"),))
      if not rows:
          raise HTTPException(status_code=404, detail="User not found")
      return rows[0]
  @router.put("/me", response_model=schemas.UserResponse)
  async def update_me(data: schemas.UserUpdate, token: str = Depends(auth_utils.oauth2_scheme)):
      """Partially update the caller's own profile and return the refreshed row.

      Only fields the client actually sent (``exclude_unset=True``) are written.
      Privileged/system columns are silently ignored: ``schemas.UserUpdate`` is
      also the body of the admin-only ``admin_update_user`` endpoint, so it may
      carry fields such as ``role`` or ``is_active`` that a user must not be
      able to set on their own account.

      Raises HTTPException 401 when the token cannot be decoded and 404 when
      the user row no longer exists.
      """
      # Columns a user may never change on themselves. NOTE(review): can_chat is
      # assumed to be an admin-granted permission — confirm against product rules.
      forbidden_fields = {"id", "role", "can_chat", "is_active", "ip_address", "password_hash", "created_at"}

      payload = auth_utils.decode_token(token)
      if not payload:
          raise HTTPException(status_code=401, detail="Invalid token")
      user_id = payload.get("id")

      changes = {
          field: value
          for field, value in data.dict(exclude_unset=True).items()
          if field not in forbidden_fields
      }
      if changes:
          # Column names come from UserUpdate's model field names (assuming the
          # schema does not allow extra keys), never from raw client input, so
          # interpolating them into the SQL text is safe; values stay parameterized.
          assignments = ", ".join(f"{field} = %s" for field in changes)
          db.execute_commit(f"UPDATE users SET {assignments} WHERE id = %s", (*changes.values(), user_id))

      user = db.execute_query(
          "SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, ip_address, created_at FROM users WHERE id = %s",
          (user_id,),
      )
      if not user:
          raise HTTPException(status_code=404, detail="User not found")
      return user[0]
  @router.get("/admin/users")
  async def admin_get_users(page: int = 1, size: int = 50, search: Optional[str] = None, token: str = Depends(auth_utils.oauth2_scheme)):
      """List users for the admin panel, newest id first, with optional search and paging.

      ``search`` (whitespace-trimmed; blank means no filter) is matched as a
      substring against email, first_name, last_name and phone. ``page`` is
      1-based. Out-of-range paging input is clamped (``page`` to at least 1,
      ``size`` to at least 0) so the generated LIMIT/OFFSET is never negative.
      The ``page``/``size`` in the response are the values actually used.

      Raises HTTPException(403) unless the token's ``role`` claim is "admin".
      """
      payload = auth_utils.decode_token(token)
      if not payload or payload.get("role") != 'admin':
          raise HTTPException(status_code=403, detail="Admin role required")

      # page=0 or a negative size previously produced a negative OFFSET/LIMIT,
      # which the database rejects (surfacing as a 500).
      page = max(page, 1)
      size = max(size, 0)
      offset = (page - 1) * size

      base_query = "SELECT id, email, first_name, last_name, phone, role, can_chat, is_active, ip_address, created_at FROM users"
      count_query = "SELECT COUNT(*) as total FROM users"
      params = []
      term = search.strip() if search else ""
      if term:
          where_clause = " WHERE email LIKE %s OR first_name LIKE %s OR last_name LIKE %s OR phone LIKE %s"
          base_query += where_clause
          count_query += where_clause
          # NOTE: '%' and '_' inside the search term are not escaped and act as LIKE wildcards.
          params = [f"%{term}%"] * 4

      base_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
      users = db.execute_query(base_query, tuple(params + [size, offset]))
      total = db.execute_query(count_query, tuple(params))[0]['total']
      return {"users": users, "total": total, "page": page, "size": size}
  @router.post("/admin/users", response_model=schemas.UserResponse)
  async def admin_create_user(data: schemas.UserCreate, token: str = Depends(auth_utils.oauth2_scheme)):
      """Create an account on an admin's behalf and return its profile.

      The new account always gets role "user" with can_chat enabled. The same
      ``UserCreate`` profile fields that /register persists (including
      shipping_address and preferred_language) are stored, rather than being
      silently dropped. The returned row uses the same column set as the other
      ``UserResponse`` endpoints in this module.

      Raises HTTPException 403 for a non-admin token and 400 when the e-mail
      is already registered.
      """
      payload = auth_utils.decode_token(token)
      if not payload or payload.get("role") != 'admin':
          raise HTTPException(status_code=403, detail="Admin role required")

      existing_user = db.execute_query("SELECT id FROM users WHERE email = %s", (data.email,))
      if existing_user:
          raise HTTPException(status_code=400, detail="Email already registered")

      hashed_password = auth_utils.get_password_hash(data.password)
      user_id = db.execute_commit(
          "INSERT INTO users (email, password_hash, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)",
          (data.email, hashed_password, data.first_name, data.last_name, data.phone, data.shipping_address, data.preferred_language, 'user', True),
      )
      # Same column list as /register and /me so the response model is fully populated.
      user = db.execute_query(
          "SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, ip_address, created_at FROM users WHERE id = %s",
          (user_id,),
      )
      return user[0]
  @router.patch("/users/{target_id}/admin", response_model=schemas.UserResponse)
  async def admin_update_user(target_id: int, data: schemas.UserUpdate, token: str = Depends(auth_utils.oauth2_scheme)):
      """Admin-only partial update of any user's row; returns the refreshed profile.

      Every field the client sent (``exclude_unset=True``) is written as-is.
      Column names are the ``UserUpdate`` model's field names (assuming the
      schema rejects or ignores extra keys), not raw client keys, which is
      what keeps the f-string SQL below safe; values are parameterized.

      Raises HTTPException 403 unless the token's ``role`` claim is "admin",
      and 404 when no user has ``target_id``.
      """
      claims = auth_utils.decode_token(token)
      is_admin = bool(claims) and claims.get("role") == 'admin'
      if not is_admin:
          raise HTTPException(status_code=403, detail="Admin role required")

      patch = data.dict(exclude_unset=True)
      if patch:
          set_clause = ", ".join(f"{column} = %s" for column in patch)
          db.execute_commit(f"UPDATE users SET {set_clause} WHERE id = %s", tuple(patch.values()) + (target_id,))

      refreshed = db.execute_query(
          "SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, ip_address, created_at FROM users WHERE id = %s",
          (target_id,),
      )
      if not refreshed:
          raise HTTPException(status_code=404, detail="User not found")
      return refreshed[0]
  @router.websocket("/ws/global")
  async def ws_global(websocket: WebSocket, token: str = Query(...)):
      """Global notification socket for an authenticated user.

      The JWT arrives as the ``token`` query parameter. A token that cannot
      be decoded, or that has no ``id`` claim, causes the socket to be closed
      with code 4001. Once the socket is registered with ``global_manager``:
      - the user's presence is recorded;
      - an initial unread-count notification is sent (``notify_admins`` for
        admins, ``notify_user`` for everyone else);
      - every "ping" text frame refreshes presence via
        ``session_utils.track_user_ping``.

      The socket is always unregistered on exit. This covers unexpected errors,
      not only a clean WebSocketDisconnect, so a failing connection can no
      longer stay registered in ``global_manager``.
      """
      payload = auth_utils.decode_token(token)
      user_id = payload.get("id") if payload else None
      if not user_id:
          await websocket.close(code=4001)
          return
      role = payload.get("role")

      await global_manager.connect(websocket, user_id, role)
      try:
          session_utils.track_user_ping(user_id)
          # Send initial unread count
          if role == 'admin':
              await global_manager.notify_admins()
          else:
              await global_manager.notify_user(user_id)
          while True:
              data = await websocket.receive_text()
              if data == "ping":
                  session_utils.track_user_ping(user_id)
      except WebSocketDisconnect:
          pass  # normal client-initiated close; cleanup happens in `finally`
      finally:
          global_manager.disconnect(websocket, user_id)