# auth.py - FastAPI routes mounted under the /auth prefix.
  1. from fastapi import APIRouter, Request, Depends, HTTPException, WebSocket, WebSocketDisconnect, Query
  2. from typing import Optional, List
  3. from services.global_manager import global_manager
  4. import auth_utils
  5. import db
  6. import schemas
  7. import session_utils
  8. import uuid
  9. from datetime import datetime, timedelta
  10. import locales
  11. router = APIRouter(prefix="/auth", tags=["auth"])
  12. @router.post("/register", response_model=schemas.UserResponse)
  13. async def register(request: Request, user: schemas.UserCreate, lang: str = "en"):
  14. existing_user = db.execute_query("SELECT id FROM users WHERE email = %s", (user.email,))
  15. if existing_user:
  16. raise HTTPException(status_code=400, detail=locales.translate_error("email_already_registered", lang))
  17. ip_address = request.client.host if request.client else None
  18. hashed_password = auth_utils.get_password_hash(user.password)
  19. query = """
  20. INSERT INTO users (email, password_hash, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address)
  21. VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
  22. """
  23. params = (user.email, hashed_password, user.first_name, user.last_name, user.phone, user.shipping_address, user.preferred_language, 'user', ip_address)
  24. user_id = db.execute_commit(query, params)
  25. new_user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, ip_address, created_at FROM users WHERE id = %s", (user_id,))
  26. return new_user[0]
  27. @router.post("/login", response_model=schemas.Token)
  28. async def login(user_data: schemas.UserLogin, lang: str = "en"):
  29. user = db.execute_query("SELECT * FROM users WHERE email = %s", (user_data.email,))
  30. if not user or not auth_utils.verify_password(user_data.password, user[0]['password_hash']):
  31. raise HTTPException(status_code=401, detail=locales.translate_error("incorrect_credentials", lang))
  32. access_token = auth_utils.create_access_token(
  33. data={"sub": user[0]['email'], "id": user[0]['id'], "role": user[0]['role']}
  34. )
  35. return {"access_token": access_token, "token_type": "bearer"}
  36. @router.post("/social-login", response_model=schemas.Token)
  37. async def social_login(request: Request, data: schemas.SocialLogin):
  38. user = db.execute_query("SELECT id, email, role FROM users WHERE email = %s", (data.email,))
  39. if user:
  40. access_token = auth_utils.create_access_token(
  41. data={"sub": user[0]['email'], "id": user[0]['id'], "role": user[0]['role']}
  42. )
  43. return {"access_token": access_token, "token_type": "bearer"}
  44. else:
  45. ip_address = request.client.host if request.client else None
  46. hashed_password = auth_utils.get_password_hash(str(uuid.uuid4()))
  47. query = "INSERT INTO users (email, password_hash, first_name, last_name, preferred_language, role, ip_address) VALUES (%s, %s, %s, %s, %s, %s, %s)"
  48. params = (data.email, hashed_password, data.first_name, data.last_name, data.preferred_language, 'user', ip_address)
  49. user_id = db.execute_commit(query, params)
  50. access_token = auth_utils.create_access_token(data={"sub": data.email, "id": user_id, "role": 'user'})
  51. return {"access_token": access_token, "token_type": "bearer"}
  52. @router.post("/logout")
  53. async def logout(token: str = Depends(auth_utils.oauth2_scheme)):
  54. payload = auth_utils.decode_token(token)
  55. if payload:
  56. sid = payload.get("sid")
  57. if sid: session_utils.delete_session(sid)
  58. return {"message": "Successfully logged out"}
  59. @router.post("/forgot-password")
  60. async def forgot_password(request: schemas.ForgotPassword):
  61. user = db.execute_query("SELECT id FROM users WHERE email = %s", (request.email,))
  62. if not user: raise HTTPException(status_code=404, detail="Email not found")
  63. token = str(uuid.uuid4())
  64. expires_at = datetime.utcnow() + timedelta(minutes=15)
  65. db.execute_commit("INSERT INTO password_reset_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)", (user[0]['id'], token, expires_at))
  66. return {"message": "Reset instructions sent to your email", "demo_token": token}
  67. @router.post("/reset-password")
  68. async def reset_password(request: schemas.ResetPassword):
  69. reset_data = db.execute_query("SELECT user_id, expires_at FROM password_reset_tokens WHERE token = %s", (request.token,))
  70. if not reset_data: raise HTTPException(status_code=400, detail="Invalid token")
  71. if reset_data[0]['expires_at'] < datetime.utcnow(): raise HTTPException(status_code=400, detail="Token expired")
  72. hashed_password = auth_utils.get_password_hash(request.new_password)
  73. db.execute_commit("UPDATE users SET password_hash = %s WHERE id = %s", (hashed_password, reset_data[0]['user_id']))
  74. db.execute_commit("DELETE FROM password_reset_tokens WHERE token = %s", (request.token,))
  75. return {"message": "Password reset successfully"}
  76. @router.get("/me", response_model=schemas.UserResponse)
  77. async def get_me(token: str = Depends(auth_utils.oauth2_scheme)):
  78. payload = auth_utils.decode_token(token)
  79. if not payload: raise HTTPException(status_code=401, detail="Invalid token")
  80. user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, ip_address, created_at FROM users WHERE id = %s", (payload.get("id"),))
  81. if not user: raise HTTPException(status_code=404, detail="User not found")
  82. return user[0]
  83. @router.put("/me", response_model=schemas.UserResponse)
  84. async def update_me(data: schemas.UserUpdate, token: str = Depends(auth_utils.oauth2_scheme)):
  85. payload = auth_utils.decode_token(token)
  86. if not payload: raise HTTPException(status_code=401, detail="Invalid token")
  87. user_id = payload.get("id")
  88. update_fields = []
  89. params = []
  90. for field, value in data.dict(exclude_unset=True).items():
  91. update_fields.append(f"{field} = %s")
  92. params.append(value)
  93. if update_fields:
  94. query = f"UPDATE users SET {', '.join(update_fields)} WHERE id = %s"
  95. params.append(user_id)
  96. db.execute_commit(query, tuple(params))
  97. user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, ip_address, created_at FROM users WHERE id = %s", (user_id,))
  98. return user[0]
  99. @router.get("/admin/users")
  100. async def admin_get_users(page: int = 1, size: int = 50, search: Optional[str] = None, token: str = Depends(auth_utils.oauth2_scheme)):
  101. payload = auth_utils.decode_token(token)
  102. if not payload or payload.get("role") != 'admin':
  103. raise HTTPException(status_code=403, detail="Admin role required")
  104. offset = (page - 1) * size
  105. base_query = "SELECT id, email, first_name, last_name, phone, role, can_chat, ip_address, created_at FROM users"
  106. count_query = "SELECT COUNT(*) as total FROM users"
  107. params = []
  108. if search and search.strip():
  109. where_clause = " WHERE email LIKE %s OR first_name LIKE %s OR last_name LIKE %s OR phone LIKE %s"
  110. base_query += where_clause
  111. count_query += where_clause
  112. pattern = f"%{search.strip()}%"
  113. params = [pattern] * 4
  114. base_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
  115. users = db.execute_query(base_query, tuple(params + [size, offset]))
  116. total = db.execute_query(count_query, tuple(params))[0]['total']
  117. return {"users": users, "total": total, "page": page, "size": size}
  118. @router.post("/admin/users", response_model=schemas.UserResponse)
  119. async def admin_create_user(data: schemas.UserCreate, token: str = Depends(auth_utils.oauth2_scheme)):
  120. payload = auth_utils.decode_token(token)
  121. if not payload or payload.get("role") != 'admin':
  122. raise HTTPException(status_code=403, detail="Admin role required")
  123. existing_user = db.execute_query("SELECT id FROM users WHERE email = %s", (data.email,))
  124. if existing_user:
  125. raise HTTPException(status_code=400, detail="Email already registered")
  126. hashed_password = auth_utils.get_password_hash(data.password)
  127. user_id = db.execute_commit(
  128. "INSERT INTO users (email, password_hash, first_name, last_name, phone, role, can_chat) VALUES (%s, %s, %s, %s, %s, %s, %s)",
  129. (data.email, hashed_password, data.first_name, data.last_name, data.phone, 'user', True)
  130. )
  131. user = db.execute_query("SELECT id, email, first_name, last_name, phone, role, can_chat, created_at FROM users WHERE id = %s", (user_id,))
  132. return user[0]
  133. @router.patch("/users/{target_id}/admin", response_model=schemas.UserResponse)
  134. async def admin_update_user(target_id: int, data: schemas.UserUpdate, token: str = Depends(auth_utils.oauth2_scheme)):
  135. payload = auth_utils.decode_token(token)
  136. if not payload or payload.get("role") != 'admin':
  137. raise HTTPException(status_code=403, detail="Admin role required")
  138. update_fields = []
  139. params = []
  140. for field, value in data.dict(exclude_unset=True).items():
  141. update_fields.append(f"{field} = %s")
  142. params.append(value)
  143. if update_fields:
  144. query = f"UPDATE users SET {', '.join(update_fields)} WHERE id = %s"
  145. params.append(target_id)
  146. db.execute_commit(query, tuple(params))
  147. user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, ip_address, created_at FROM users WHERE id = %s", (target_id,))
  148. if not user: raise HTTPException(status_code=404, detail="User not found")
  149. return user[0]
  150. @router.websocket("/ws/global")
  151. async def ws_global(websocket: WebSocket, token: str = Query(...)):
  152. payload = auth_utils.decode_token(token)
  153. if not payload:
  154. await websocket.close(code=4001)
  155. return
  156. user_id = payload.get("id")
  157. role = payload.get("role")
  158. if not user_id:
  159. await websocket.close(code=4001)
  160. return
  161. await global_manager.connect(websocket, user_id, role)
  162. session_utils.track_user_ping(user_id)
  163. # Send initial unread count
  164. await global_manager.notify_user(user_id) if role != 'admin' else await global_manager.notify_admins()
  165. try:
  166. while True:
  167. data = await websocket.receive_text()
  168. if data == "ping":
  169. session_utils.track_user_ping(user_id)
  170. except WebSocketDisconnect:
  171. global_manager.disconnect(websocket, user_id)