auth.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. from fastapi import APIRouter, Request, Depends, HTTPException, WebSocket, WebSocketDisconnect, Query
  2. from typing import Optional, List
  3. from services.global_manager import global_manager
  4. from services.rate_limit_service import rate_limit_service
  5. import auth_utils
  6. import db
  7. import schemas
  8. import session_utils
  9. import uuid
  10. from datetime import datetime, timedelta
  11. import locales
  12. from dependencies import get_current_user, require_admin
  13. import config
  14. try:
  15. from google.oauth2 import id_token
  16. from google.auth.transport import requests as google_requests
  17. except ImportError:
  18. id_token = None
  19. google_requests = None
  20. router = APIRouter(prefix="/auth", tags=["auth"])
  21. @router.post("/register", response_model=schemas.UserResponse)
  22. async def register(request: Request, user: schemas.UserCreate, lang: str = "en"):
  23. existing_user = db.execute_query("SELECT id FROM users WHERE email = %s", (user.email,))
  24. if existing_user:
  25. raise HTTPException(status_code=400, detail=locales.translate_error("email_already_registered", lang))
  26. ip_address = request.client.host if request.client else None
  27. hashed_password = auth_utils.get_password_hash(user.password)
  28. query = """
  29. INSERT INTO users (email, password_hash, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address, is_company, company_name, company_pib, company_address)
  30. VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  31. """
  32. params = (user.email, hashed_password, user.first_name, user.last_name, user.phone, user.shipping_address, user.preferred_language, 'user', ip_address, user.is_company, user.company_name, user.company_pib, user.company_address)
  33. user_id = db.execute_commit(query, params)
  34. new_user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, is_company, company_name, company_pib, company_address, ip_address, created_at FROM users WHERE id = %s", (user_id,))
  35. return new_user[0]
  36. @router.post("/login", response_model=schemas.Token)
  37. async def login(request: Request, user_data: schemas.UserLogin, lang: str = "en"):
  38. ip = request.client.host if request.client else "unknown"
  39. email = user_data.email.lower()
  40. # 1. Check Global Rate Limit
  41. if rate_limit_service.is_rate_limited(email, ip):
  42. raise HTTPException(
  43. status_code=429,
  44. detail=locales.translate_error("too_many_attempts", lang)
  45. )
  46. # 2. Check if Captcha is Required
  47. if rate_limit_service.is_captcha_required(email, ip):
  48. if not user_data.captcha_token:
  49. raise HTTPException(
  50. status_code=403,
  51. detail=locales.translate_error("captcha_required", lang)
  52. )
  53. # 3. Verify Captcha
  54. if not await rate_limit_service.verify_captcha(user_data.captcha_token):
  55. raise HTTPException(
  56. status_code=403,
  57. detail=locales.translate_error("invalid_token", lang)
  58. )
  59. # 4. Attempt Authentication
  60. user = db.execute_query("SELECT * FROM users WHERE email = %s", (email,))
  61. if not user or not auth_utils.verify_password(user_data.password, user[0]['password_hash']):
  62. # Log failure
  63. rate_limit_service.record_failed_attempt(email, ip)
  64. raise HTTPException(status_code=401, detail=locales.translate_error("incorrect_credentials", lang))
  65. if not user[0].get('is_active', True):
  66. raise HTTPException(status_code=403, detail="Your account has been suspended.")
  67. # 5. Success - Reset Rate Limits
  68. rate_limit_service.reset_attempts(email, ip)
  69. access_token = auth_utils.create_access_token(
  70. data={"sub": user[0]['email'], "id": user[0]['id'], "role": user[0]['role']}
  71. )
  72. return {"access_token": access_token, "token_type": "bearer"}
  73. @router.post("/social-login", response_model=schemas.Token)
  74. async def social_login(request: Request, data: schemas.SocialLogin):
  75. email = data.email.lower() if data.email else None
  76. first_name = data.first_name
  77. last_name = data.last_name
  78. # 1. Verify token if provider is Google
  79. if data.provider == 'google':
  80. print(f"DEBUG: Social Login attempt. id_token library available: {id_token is not None}")
  81. print(f"DEBUG: config.GOOGLE_CLIENT_ID exists: {bool(config.GOOGLE_CLIENT_ID)}")
  82. print(f"DEBUG: config.GOOGLE_CLIENT_ID value: {config.GOOGLE_CLIENT_ID}")
  83. if not id_token or not config.GOOGLE_CLIENT_ID:
  84. msg = f"Config error: id_token_lib={id_token is not None}, client_id_set={bool(config.GOOGLE_CLIENT_ID)}"
  85. raise HTTPException(status_code=500, detail=f"Google Auth not configured on server ({msg})")
  86. try:
  87. # Verify the ID token
  88. idinfo = id_token.verify_oauth2_token(data.token, google_requests.Request(), config.GOOGLE_CLIENT_ID)
  89. # ID token is valid. Get user's Google info
  90. if idinfo['iss'] not in ['accounts.google.com', 'https://accounts.google.com']:
  91. raise ValueError('Wrong issuer.')
  92. email = idinfo['email'].lower()
  93. first_name = idinfo.get('given_name', first_name)
  94. last_name = idinfo.get('family_name', last_name)
  95. except Exception as e:
  96. print(f"Google Token Verification Error: {e}")
  97. raise HTTPException(status_code=401, detail="Invalid Google token")
  98. if not email:
  99. raise HTTPException(status_code=400, detail="Email is required")
  100. # 2. Proceed with login/registration
  101. user = db.execute_query("SELECT id, email, role, is_active FROM users WHERE email = %s", (email,))
  102. if user:
  103. if not user[0].get('is_active', True):
  104. raise HTTPException(status_code=403, detail="Your account has been suspended.")
  105. access_token = auth_utils.create_access_token(
  106. data={"sub": user[0]['email'], "id": user[0]['id'], "role": user[0]['role']}
  107. )
  108. return {"access_token": access_token, "token_type": "bearer"}
  109. else:
  110. ip_address = request.client.host if request.client else None
  111. hashed_password = auth_utils.get_password_hash(str(uuid.uuid4()))
  112. query = """
  113. INSERT INTO users (email, password_hash, first_name, last_name, preferred_language, role, ip_address)
  114. VALUES (%s, %s, %s, %s, %s, %s, %s)
  115. """
  116. params = (email, hashed_password, first_name, last_name, data.preferred_language, 'user', ip_address)
  117. user_id = db.execute_commit(query, params)
  118. access_token = auth_utils.create_access_token(data={"sub": email, "id": user_id, "role": 'user'})
  119. return {"access_token": access_token, "token_type": "bearer"}
  120. @router.post("/logout")
  121. async def logout(user: dict = Depends(get_current_user)):
  122. sid = user.get("sid")
  123. if sid: session_utils.delete_session(sid)
  124. return {"message": "Successfully logged out"}
  125. @router.post("/forgot-password")
  126. async def forgot_password(request: schemas.ForgotPassword):
  127. user = db.execute_query("SELECT id FROM users WHERE email = %s", (request.email,))
  128. if not user: raise HTTPException(status_code=404, detail="Email not found")
  129. token = str(uuid.uuid4())
  130. expires_at = datetime.utcnow() + timedelta(minutes=15)
  131. db.execute_commit("INSERT INTO password_reset_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)", (user[0]['id'], token, expires_at))
  132. return {"message": "Reset instructions sent to your email", "demo_token": token}
  133. @router.post("/reset-password")
  134. async def reset_password(request: schemas.ResetPassword):
  135. reset_data = db.execute_query("SELECT user_id, expires_at FROM password_reset_tokens WHERE token = %s", (request.token,))
  136. if not reset_data: raise HTTPException(status_code=400, detail="Invalid token")
  137. if reset_data[0]['expires_at'] < datetime.utcnow(): raise HTTPException(status_code=400, detail="Token expired")
  138. hashed_password = auth_utils.get_password_hash(request.new_password)
  139. db.execute_commit("UPDATE users SET password_hash = %s WHERE id = %s", (hashed_password, reset_data[0]['user_id']))
  140. db.execute_commit("DELETE FROM password_reset_tokens WHERE token = %s", (request.token,))
  141. return {"message": "Password reset successfully"}
  142. @router.get("/me", response_model=schemas.UserResponse)
  143. async def get_me(user: dict = Depends(get_current_user)):
  144. user_id = user.get("id")
  145. user_data = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, is_company, company_name, company_pib, company_address, ip_address, created_at FROM users WHERE id = %s", (user_id,))
  146. if not user_data: raise HTTPException(status_code=404, detail="User not found")
  147. return user_data[0]
  148. @router.put("/me", response_model=schemas.UserResponse)
  149. async def update_me(data: schemas.UserUpdate, user: dict = Depends(get_current_user)):
  150. user_id = user.get("id")
  151. update_fields = []
  152. params = []
  153. for field, value in data.dict(exclude_unset=True).items():
  154. update_fields.append(f"{field} = %s")
  155. params.append(value)
  156. if update_fields:
  157. query = f"UPDATE users SET {', '.join(update_fields)} WHERE id = %s"
  158. params.append(user_id)
  159. db.execute_commit(query, tuple(params))
  160. user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, is_company, company_name, company_pib, company_address, ip_address, created_at FROM users WHERE id = %s", (user_id,))
  161. return user[0]
  162. @router.get("/admin/users")
  163. async def admin_get_users(page: int = 1, size: int = 50, search: Optional[str] = None, admin: dict = Depends(require_admin)):
  164. offset = (page - 1) * size
  165. base_query = "SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, is_company, company_name, company_pib, company_address, ip_address, created_at FROM users"
  166. count_query = "SELECT COUNT(*) as total FROM users"
  167. params = []
  168. if search and search.strip():
  169. where_clause = " WHERE email LIKE %s OR first_name LIKE %s OR last_name LIKE %s OR phone LIKE %s"
  170. base_query += where_clause
  171. count_query += where_clause
  172. pattern = f"%{search.strip()}%"
  173. params = [pattern] * 4
  174. base_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
  175. users = db.execute_query(base_query, tuple(params + [size, offset]))
  176. total = db.execute_query(count_query, tuple(params))[0]['total']
  177. return {"users": users, "total": total, "page": page, "size": size}
  178. @router.post("/admin/users", response_model=schemas.UserResponse)
  179. async def admin_create_user(data: schemas.UserCreate, admin: dict = Depends(require_admin)):
  180. existing_user = db.execute_query("SELECT id FROM users WHERE email = %s", (data.email,))
  181. if existing_user:
  182. raise HTTPException(status_code=400, detail="Email already registered")
  183. hashed_password = auth_utils.get_password_hash(data.password)
  184. user_id = db.execute_commit(
  185. "INSERT INTO users (email, password_hash, first_name, last_name, phone, role, can_chat) VALUES (%s, %s, %s, %s, %s, %s, %s)",
  186. (data.email, hashed_password, data.first_name, data.last_name, data.phone, 'user', True)
  187. )
  188. user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, is_company, company_name, company_pib, company_address, ip_address, created_at FROM users WHERE id = %s", (user_id,))
  189. return user[0]
  190. @router.patch("/users/{target_id}/admin", response_model=schemas.UserResponse)
  191. async def admin_update_user(target_id: int, data: schemas.AdminUserUpdate, admin: dict = Depends(require_admin)):
  192. update_fields = []
  193. params = []
  194. update_dict = data.dict(exclude_unset=True)
  195. # Handle password hashing
  196. if "password" in update_dict:
  197. password = update_dict.pop("password")
  198. update_dict["password_hash"] = auth_utils.get_password_hash(password)
  199. for field, value in update_dict.items():
  200. update_fields.append(f"`{field}` = %s")
  201. params.append(value)
  202. if update_fields:
  203. query = f"UPDATE users SET {', '.join(update_fields)} WHERE id = %s"
  204. params.append(target_id)
  205. db.execute_commit(query, tuple(params))
  206. # If user was deactivated, kick from active sessions
  207. if update_dict.get("is_active") is False:
  208. await global_manager.kick_user(target_id)
  209. user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, is_company, company_name, company_pib, company_address, ip_address, created_at FROM users WHERE id = %s", (target_id,))
  210. if not user: raise HTTPException(status_code=404, detail="User not found")
  211. return user[0]
  212. # WebSocket implementation moved to main.py to handle path prefixing issues
  213. # @router.websocket("/ws/global")
  214. # async def ws_global(websocket: WebSocket, token: str = Query(...)):
  215. # ...