auth.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. from fastapi import APIRouter, Request, Depends, HTTPException
  2. import auth_utils
  3. import db
  4. import schemas
  5. import session_utils
  6. import uuid
  7. from datetime import datetime, timedelta
  8. import locales
  9. router = APIRouter(prefix="/auth", tags=["auth"])
  10. @router.post("/register", response_model=schemas.UserResponse)
  11. async def register(request: Request, user: schemas.UserCreate, lang: str = "en"):
  12. existing_user = db.execute_query("SELECT id FROM users WHERE email = %s", (user.email,))
  13. if existing_user:
  14. raise HTTPException(status_code=400, detail=locales.translate_error("email_already_registered", lang))
  15. ip_address = request.client.host if request.client else None
  16. hashed_password = auth_utils.get_password_hash(user.password)
  17. query = """
  18. INSERT INTO users (email, password_hash, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address)
  19. VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
  20. """
  21. params = (user.email, hashed_password, user.first_name, user.last_name, user.phone, user.shipping_address, user.preferred_language, 'user', ip_address)
  22. user_id = db.execute_commit(query, params)
  23. new_user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address, created_at FROM users WHERE id = %s", (user_id,))
  24. return new_user[0]
  25. @router.post("/login", response_model=schemas.Token)
  26. async def login(user_data: schemas.UserLogin, lang: str = "en"):
  27. user = db.execute_query("SELECT * FROM users WHERE email = %s", (user_data.email,))
  28. if not user or not auth_utils.verify_password(user_data.password, user[0]['password_hash']):
  29. raise HTTPException(status_code=401, detail=locales.translate_error("incorrect_credentials", lang))
  30. access_token = auth_utils.create_access_token(
  31. data={"sub": user[0]['email'], "id": user[0]['id'], "role": user[0]['role']}
  32. )
  33. return {"access_token": access_token, "token_type": "bearer"}
  34. @router.post("/social-login", response_model=schemas.Token)
  35. async def social_login(request: Request, data: schemas.SocialLogin):
  36. user = db.execute_query("SELECT id, email, role FROM users WHERE email = %s", (data.email,))
  37. if user:
  38. access_token = auth_utils.create_access_token(
  39. data={"sub": user[0]['email'], "id": user[0]['id'], "role": user[0]['role']}
  40. )
  41. return {"access_token": access_token, "token_type": "bearer"}
  42. else:
  43. ip_address = request.client.host if request.client else None
  44. hashed_password = auth_utils.get_password_hash(str(uuid.uuid4()))
  45. query = "INSERT INTO users (email, password_hash, first_name, last_name, preferred_language, role, ip_address) VALUES (%s, %s, %s, %s, %s, %s, %s)"
  46. params = (data.email, hashed_password, data.first_name, data.last_name, data.preferred_language, 'user', ip_address)
  47. user_id = db.execute_commit(query, params)
  48. access_token = auth_utils.create_access_token(data={"sub": data.email, "id": user_id, "role": 'user'})
  49. return {"access_token": access_token, "token_type": "bearer"}
  50. @router.post("/logout")
  51. async def logout(token: str = Depends(auth_utils.oauth2_scheme)):
  52. payload = auth_utils.decode_token(token)
  53. if payload:
  54. sid = payload.get("sid")
  55. if sid: session_utils.delete_session(sid)
  56. return {"message": "Successfully logged out"}
  57. @router.post("/forgot-password")
  58. async def forgot_password(request: schemas.ForgotPassword):
  59. user = db.execute_query("SELECT id FROM users WHERE email = %s", (request.email,))
  60. if not user: raise HTTPException(status_code=404, detail="Email not found")
  61. token = str(uuid.uuid4())
  62. expires_at = datetime.utcnow() + timedelta(minutes=15)
  63. db.execute_commit("INSERT INTO password_reset_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)", (user[0]['id'], token, expires_at))
  64. return {"message": "Reset instructions sent to your email", "demo_token": token}
  65. @router.post("/reset-password")
  66. async def reset_password(request: schemas.ResetPassword):
  67. reset_data = db.execute_query("SELECT user_id, expires_at FROM password_reset_tokens WHERE token = %s", (request.token,))
  68. if not reset_data: raise HTTPException(status_code=400, detail="Invalid token")
  69. if reset_data[0]['expires_at'] < datetime.utcnow(): raise HTTPException(status_code=400, detail="Token expired")
  70. hashed_password = auth_utils.get_password_hash(request.new_password)
  71. db.execute_commit("UPDATE users SET password_hash = %s WHERE id = %s", (hashed_password, reset_data[0]['user_id']))
  72. db.execute_commit("DELETE FROM password_reset_tokens WHERE token = %s", (request.token,))
  73. return {"message": "Password reset successfully"}
  74. @router.get("/me", response_model=schemas.UserResponse)
  75. async def get_me(token: str = Depends(auth_utils.oauth2_scheme)):
  76. payload = auth_utils.decode_token(token)
  77. if not payload: raise HTTPException(status_code=401, detail="Invalid token")
  78. user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address, created_at FROM users WHERE id = %s", (payload.get("id"),))
  79. if not user: raise HTTPException(status_code=404, detail="User not found")
  80. return user[0]
  81. @router.put("/me", response_model=schemas.UserResponse)
  82. async def update_me(data: schemas.UserUpdate, token: str = Depends(auth_utils.oauth2_scheme)):
  83. payload = auth_utils.decode_token(token)
  84. if not payload: raise HTTPException(status_code=401, detail="Invalid token")
  85. user_id = payload.get("id")
  86. update_fields = []
  87. params = []
  88. for field, value in data.dict(exclude_unset=True).items():
  89. update_fields.append(f"{field} = %s")
  90. params.append(value)
  91. if update_fields:
  92. query = f"UPDATE users SET {', '.join(update_fields)} WHERE id = %s"
  93. params.append(user_id)
  94. db.execute_commit(query, tuple(params))
  95. user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address, created_at FROM users WHERE id = %s", (user_id,))
  96. return user[0]
  97. @router.post("/ping")
  98. async def track_ping(token: str = Depends(auth_utils.oauth2_scheme)):
  99. payload = auth_utils.decode_token(token)
  100. unread_count = 0
  101. if payload and payload.get("id"):
  102. user_id = payload.get("id")
  103. role = payload.get("role")
  104. session_utils.track_user_ping(user_id)
  105. if role == 'admin':
  106. row = db.execute_query("SELECT count(*) as cnt FROM order_messages WHERE is_from_admin = FALSE AND is_read = FALSE")
  107. else:
  108. row = db.execute_query(
  109. "SELECT count(*) as cnt FROM order_messages om JOIN orders o ON om.order_id = o.id WHERE o.user_id = %s AND om.is_from_admin = TRUE AND om.is_read = FALSE",
  110. (user_id,)
  111. )
  112. if row:
  113. unread_count = row[0]['cnt']
  114. return {"status": "ok", "unread_count": unread_count}