| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- from fastapi import APIRouter, Request, Depends, HTTPException, WebSocket, WebSocketDisconnect, Query
- from typing import Optional, List
- from services.global_manager import global_manager
- from services.rate_limit_service import rate_limit_service
- import auth_utils
- import db
- import schemas
- import session_utils
- import uuid
- from datetime import datetime, timedelta
- import locales
- from dependencies import get_current_user, require_admin
# Router for the authentication / account endpoints defined below; every
# path is registered under the "/auth" prefix and the "auth" OpenAPI tag.
router = APIRouter(tags=["auth"], prefix="/auth")
@router.post("/register", response_model=schemas.UserResponse)
async def register(request: Request, user: schemas.UserCreate, lang: str = "en"):
    """Create a self-service account and return its public profile.

    The email is lower-cased before it is checked and stored. ``login``
    lower-cases the submitted email before its lookup, so storing the address
    as typed could make a mixed-case registration impossible to log into
    (depending on the column collation).

    Args:
        request: Incoming request; the client host is stored as ``ip_address``.
        user: Registration payload.
        lang: Language code passed to ``locales.translate_error`` for errors.

    Returns:
        The newly inserted ``users`` row (public columns only).

    Raises:
        HTTPException: 400 if the email is already registered.
    """
    email = user.email.lower()

    existing_user = db.execute_query("SELECT id FROM users WHERE email = %s", (email,))
    if existing_user:
        raise HTTPException(status_code=400, detail=locales.translate_error("email_already_registered", lang))

    ip_address = request.client.host if request.client else None
    hashed_password = auth_utils.get_password_hash(user.password)

    # Self-registration always creates a plain 'user' role.
    query = """
        INSERT INTO users (email, password_hash, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address, is_company, company_name, company_pib, company_address)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    """
    params = (
        email, hashed_password, user.first_name, user.last_name, user.phone,
        user.shipping_address, user.preferred_language, 'user', ip_address,
        user.is_company, user.company_name, user.company_pib, user.company_address,
    )

    # NOTE(review): the check above and this insert are not atomic; presumably a
    # UNIQUE index on users.email is what stops concurrent duplicates -- confirm.
    user_id = db.execute_commit(query, params)  # used here as the new row id
    new_user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, is_company, company_name, company_pib, company_address, ip_address, created_at FROM users WHERE id = %s", (user_id,))
    return new_user[0]
@router.post("/login", response_model=schemas.Token)
async def login(request: Request, user_data: schemas.UserLogin, lang: str = "en"):
    """Authenticate an email/password pair and return a bearer token.

    Checks run in this order; each one short-circuits with an HTTPException:

    1. rate limit for (email, client IP)          -> 429 ``too_many_attempts``
    2. captcha, only when the rate-limit service
       asks for one: missing token                -> 403 ``captcha_required``
       token rejected by ``verify_captcha``       -> 403 ``invalid_token``
    3. unknown email or wrong password            -> 401 ``incorrect_credentials``
       (also recorded as a failed attempt)
    4. account whose ``is_active`` is falsy       -> 403 (untranslated message)

    On success the attempt counters for (email, IP) are reset. The email is
    lower-cased before every lookup. A request without a client address is
    keyed as ``"unknown"``.
    """
    client_ip = request.client.host if request.client else "unknown"
    normalized_email = user_data.email.lower()

    if rate_limit_service.is_rate_limited(normalized_email, client_ip):
        raise HTTPException(status_code=429, detail=locales.translate_error("too_many_attempts", lang))

    if rate_limit_service.is_captcha_required(normalized_email, client_ip):
        captcha_token = user_data.captcha_token
        if not captcha_token:
            raise HTTPException(status_code=403, detail=locales.translate_error("captcha_required", lang))
        if not await rate_limit_service.verify_captcha(captcha_token):
            raise HTTPException(status_code=403, detail=locales.translate_error("invalid_token", lang))

    rows = db.execute_query("SELECT * FROM users WHERE email = %s", (normalized_email,))
    account = rows[0] if rows else None
    # The password is only checked when a row exists.
    credentials_ok = account is not None and auth_utils.verify_password(user_data.password, account['password_hash'])
    if not credentials_ok:
        rate_limit_service.record_failed_attempt(normalized_email, client_ip)
        raise HTTPException(status_code=401, detail=locales.translate_error("incorrect_credentials", lang))

    if not account.get('is_active', True):
        raise HTTPException(status_code=403, detail="Your account has been suspended.")

    rate_limit_service.reset_attempts(normalized_email, client_ip)
    claims = {"sub": account['email'], "id": account['id'], "role": account['role']}
    return {"access_token": auth_utils.create_access_token(data=claims), "token_type": "bearer"}
@router.post("/social-login", response_model=schemas.Token)
async def social_login(request: Request, data: schemas.SocialLogin):
    """Issue a bearer token for a social login, creating the user on first sight.

    SECURITY NOTE(review): nothing in this handler proves that the caller owns
    ``data.email``. The email is read from the request body, and a token is
    minted for the matching account with whatever role it has. Unless
    ``schemas.SocialLogin`` itself verifies a provider-signed ID token (confirm),
    anyone can obtain a token for any existing account, admins included. The
    provider's ID token should be verified server-side, and the email taken
    from its verified claims rather than from the body.

    Existing accounts:
        403 if ``is_active`` is falsy. Otherwise a token is issued from the
        stored email/id/role.

    New accounts:
        Inserted with role 'user' and the client IP. The password hash is
        derived from a random UUID4 that is then discarded, so the account
        has no usable password.
    """
    existing = db.execute_query("SELECT id, email, role, is_active FROM users WHERE email = %s", (data.email,))
    if existing:
        account = existing[0]
        if not account.get('is_active', True):
            raise HTTPException(status_code=403, detail="Your account has been suspended.")
        access_token = auth_utils.create_access_token(
            data={"sub": account['email'], "id": account['id'], "role": account['role']}
        )
        return {"access_token": access_token, "token_type": "bearer"}

    ip_address = request.client.host if request.client else None
    hashed_password = auth_utils.get_password_hash(str(uuid.uuid4()))
    query = "INSERT INTO users (email, password_hash, first_name, last_name, preferred_language, role, ip_address) VALUES (%s, %s, %s, %s, %s, %s, %s)"
    params = (data.email, hashed_password, data.first_name, data.last_name, data.preferred_language, 'user', ip_address)
    user_id = db.execute_commit(query, params)  # used here as the new row id
    access_token = auth_utils.create_access_token(data={"sub": data.email, "id": user_id, "role": 'user'})
    return {"access_token": access_token, "token_type": "bearer"}
@router.post("/logout")
async def logout(user: dict = Depends(get_current_user)):
    """Delete the caller's server-side session, if the principal carries one.

    ``user`` is whatever ``get_current_user`` resolves. When it has a truthy
    ``sid``, that session is removed through ``session_utils.delete_session``.
    The response is the same either way.
    """
    session_id = user.get("sid")
    if session_id:
        session_utils.delete_session(session_id)
    return {"message": "Successfully logged out"}
@router.post("/forgot-password")
async def forgot_password(request: schemas.ForgotPassword):
    """Create a password-reset token for the given email, valid for 15 minutes.

    The token is stored in ``password_reset_tokens`` with a naive-UTC
    ``expires_at``.

    Raises:
        HTTPException: 404 if no user has that email.

    SECURITY NOTE(review):
      * The raw reset token is returned in the response as ``demo_token``.
        Together with ``/reset-password``, this lets anyone who knows an email
        address reset that account's password. The key is kept only so
        existing clients don't break. It must not be returned outside a demo
        setup; the token should reach the user only out of band.
      * Nothing in this handler sends an email, despite the response message.
      * Answering 404 for unknown emails reveals which addresses are
        registered (account enumeration).
    """
    from datetime import timezone  # not imported at module level

    user = db.execute_query("SELECT id FROM users WHERE email = %s", (request.email,))
    if not user:
        raise HTTPException(status_code=404, detail="Email not found")

    token = str(uuid.uuid4())
    # Same naive-UTC value that datetime.utcnow() produced; utcnow() is
    # deprecated since Python 3.12.
    expires_at = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(minutes=15)
    db.execute_commit("INSERT INTO password_reset_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)", (user[0]['id'], token, expires_at))
    return {"message": "Reset instructions sent to your email", "demo_token": token}
@router.post("/reset-password")
async def reset_password(request: schemas.ResetPassword):
    """Set a new password using a token from ``password_reset_tokens``.

    On success the user's password hash is replaced. Every outstanding reset
    token for that user is then deleted, not only the one just used.

    Raises:
        HTTPException: 400 "Invalid token" for an unknown token, and
            400 "Token expired" once ``expires_at`` is in the past.
    """
    from datetime import timezone  # not imported at module level

    reset_data = db.execute_query("SELECT user_id, expires_at FROM password_reset_tokens WHERE token = %s", (request.token,))
    if not reset_data:
        raise HTTPException(status_code=400, detail="Invalid token")

    user_id = reset_data[0]['user_id']
    # Compare against naive UTC, matching how /forgot-password computes
    # expires_at (datetime.utcnow() is deprecated since Python 3.12).
    if reset_data[0]['expires_at'] < datetime.now(timezone.utc).replace(tzinfo=None):
        raise HTTPException(status_code=400, detail="Token expired")

    hashed_password = auth_utils.get_password_hash(request.new_password)
    db.execute_commit("UPDATE users SET password_hash = %s WHERE id = %s", (hashed_password, user_id))
    # Delete all of this user's reset tokens. Otherwise an older, still-unexpired
    # token could be used to reset the freshly chosen password again.
    db.execute_commit("DELETE FROM password_reset_tokens WHERE user_id = %s", (user_id,))
    return {"message": "Password reset successfully"}
@router.get("/me", response_model=schemas.UserResponse)
async def get_me(user: dict = Depends(get_current_user)):
    """Return the caller's own ``users`` row (public columns only).

    The row is looked up by the ``id`` of the authenticated principal.

    Raises:
        HTTPException: 404 if no row has that id.
    """
    rows = db.execute_query(
        "SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, is_company, company_name, company_pib, company_address, ip_address, created_at FROM users WHERE id = %s",
        (user.get("id"),),
    )
    if rows:
        return rows[0]
    raise HTTPException(status_code=404, detail="User not found")
@router.put("/me", response_model=schemas.UserResponse)
async def update_me(data: schemas.UserUpdate, user: dict = Depends(get_current_user)):
    """Apply a partial profile update for the caller and return the fresh row.

    Only fields explicitly present in the request (``exclude_unset``) are
    written. An empty update simply re-reads the profile.

    Raises:
        HTTPException: 404 if the caller's ``users`` row does not exist,
            consistent with ``get_me``.
    """
    user_id = user.get("id")
    changes = data.dict(exclude_unset=True)

    if changes:
        # Column names are interpolated into the SQL text. This is only safe
        # while the keys are the fixed field names declared on
        # schemas.UserUpdate. NOTE(review): confirm the model rejects extra
        # fields and exposes no privileged columns (e.g. role, is_active).
        assignments = ", ".join(f"{field} = %s" for field in changes)
        db.execute_commit(f"UPDATE users SET {assignments} WHERE id = %s", (*changes.values(), user_id))

    rows = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, is_company, company_name, company_pib, company_address, ip_address, created_at FROM users WHERE id = %s", (user_id,))
    if not rows:
        raise HTTPException(status_code=404, detail="User not found")
    return rows[0]
@router.get("/admin/users")
async def admin_get_users(page: int = 1, size: int = 50, search: Optional[str] = None, admin: dict = Depends(require_admin)):
    """List users for the admin panel, newest id first, with optional search.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; negative values are treated as 0. ``size=0``
            returns no rows, but ``total`` is still computed.
        search: Optional substring matched with LIKE against email, first
            name, last name and phone. A blank/whitespace-only value means
            no filter.
        admin: Present so that ``require_admin`` runs; not used otherwise.

    Returns:
        ``{"users": [...], "total": <matching row count>, "page": page, "size": size}``
    """
    # Without clamping, page <= 0 or size < 0 yields a negative OFFSET/LIMIT.
    # The database rejects that, and the request fails with a 500.
    page = max(page, 1)
    size = max(size, 0)
    offset = (page - 1) * size

    base_query = "SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, is_company, company_name, company_pib, company_address, ip_address, created_at FROM users"
    count_query = "SELECT COUNT(*) as total FROM users"
    params = []

    if search and search.strip():
        where_clause = " WHERE email LIKE %s OR first_name LIKE %s OR last_name LIKE %s OR phone LIKE %s"
        base_query += where_clause
        count_query += where_clause
        pattern = f"%{search.strip()}%"
        params = [pattern] * 4

    base_query += " ORDER BY id DESC LIMIT %s OFFSET %s"

    users = db.execute_query(base_query, tuple(params + [size, offset]))
    total = db.execute_query(count_query, tuple(params))[0]['total']

    return {"users": users, "total": total, "page": page, "size": size}
@router.post("/admin/users", response_model=schemas.UserResponse)
async def admin_create_user(data: schemas.UserCreate, admin: dict = Depends(require_admin)):
    """Create a user on behalf of an admin and return its public profile.

    Only email, password, first/last name and phone are taken from ``data``.
    The account gets role 'user' and ``can_chat`` set to True. The email is
    lower-cased before the duplicate check and insert, because ``login``
    lower-cases the email before its lookup.

    Raises:
        HTTPException: 400 if the email is already registered.
    """
    email = data.email.lower()

    existing_user = db.execute_query("SELECT id FROM users WHERE email = %s", (email,))
    if existing_user:
        raise HTTPException(status_code=400, detail="Email already registered")

    hashed_password = auth_utils.get_password_hash(data.password)
    user_id = db.execute_commit(
        "INSERT INTO users (email, password_hash, first_name, last_name, phone, role, can_chat) VALUES (%s, %s, %s, %s, %s, %s, %s)",
        (email, hashed_password, data.first_name, data.last_name, data.phone, 'user', True),
    )

    rows = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, is_company, company_name, company_pib, company_address, ip_address, created_at FROM users WHERE id = %s", (user_id,))
    return rows[0]
@router.patch("/users/{target_id}/admin", response_model=schemas.UserResponse)
async def admin_update_user(target_id: int, data: schemas.AdminUserUpdate, admin: dict = Depends(require_admin)):
    """Partially update any user as an admin and return the updated row.

    Only fields explicitly present in the request (``exclude_unset``) are
    written. A plain ``password`` is hashed and stored as ``password_hash``.
    Setting ``is_active`` to False also calls ``global_manager.kick_user``.

    Raises:
        HTTPException: 404 if ``target_id`` does not exist. This is checked
            after the UPDATE, which simply matches no rows in that case.
    """
    update_dict = data.dict(exclude_unset=True)

    if "password" in update_dict:
        # ``get_password_hash`` is the hashing helper every other endpoint in
        # this router uses. ``auth_utils.hash_password`` is referenced nowhere
        # else and presumably does not exist (AttributeError) -- confirm.
        update_dict["password_hash"] = auth_utils.get_password_hash(update_dict.pop("password"))

    if update_dict:
        # Column names come from the AdminUserUpdate field names; they are
        # backtick-quoted and the values are bound as parameters.
        assignments = ", ".join(f"`{field}` = %s" for field in update_dict)
        db.execute_commit(f"UPDATE users SET {assignments} WHERE id = %s", (*update_dict.values(), target_id))

    # If the user was deactivated, drop their active connections.
    if update_dict.get("is_active") is False:
        await global_manager.kick_user(target_id)

    user = db.execute_query("SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, can_chat, is_active, is_company, company_name, company_pib, company_address, ip_address, created_at FROM users WHERE id = %s", (target_id,))
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user[0]
@router.websocket("/ws/global")
async def ws_global(websocket: WebSocket, token: str = Query(...)):
    """Global notification WebSocket for an authenticated user.

    The access token arrives as the ``token`` query parameter. The socket is
    closed with code 4001 in two cases: the token cannot be decoded, or it
    has no ``id`` claim.

    Otherwise the socket is registered with ``global_manager`` and a ping is
    recorded. An initial unread-count notification is then triggered:
    ``notify_admins`` for role 'admin', ``notify_user`` for everyone else.
    After that, every incoming ``"ping"`` text frame is recorded via
    ``session_utils.track_user_ping``.

    The socket is always unregistered when the handler ends. A client
    disconnect is swallowed; any other exception propagates after cleanup.
    """
    payload = auth_utils.decode_token(token)
    if not payload:
        await websocket.close(code=4001)
        return
    user_id = payload.get("id")
    role = payload.get("role")
    if not user_id:
        await websocket.close(code=4001)
        return

    await global_manager.connect(websocket, user_id, role)
    try:
        session_utils.track_user_ping(user_id)

        # Send initial unread count
        if role == 'admin':
            await global_manager.notify_admins()
        else:
            await global_manager.notify_user(user_id)

        while True:
            data = await websocket.receive_text()
            if data == "ping":
                session_utils.track_user_ping(user_id)
    except WebSocketDisconnect:
        pass
    finally:
        # Unregister on every exit path, not only a clean disconnect. Otherwise
        # an unexpected error (e.g. a binary frame passed to receive_text)
        # would leave a dead socket registered in global_manager.
        global_manager.disconnect(websocket, user_id)
|