from fastapi import APIRouter, Request, Depends, HTTPException, WebSocket, WebSocketDisconnect, Query
from services.global_manager import global_manager
import auth_utils
import db
import schemas
import session_utils
import uuid
from datetime import datetime, timedelta
import locales

# All endpoints in this module are mounted under /auth.
router = APIRouter(prefix="/auth", tags=["auth"])


@router.post("/register", response_model=schemas.UserResponse)
async def register(request: Request, user: schemas.UserCreate, lang: str = "en"):
    """Create a new account with the 'user' role and return its public profile.

    Args:
        request: Incoming request; the client host (if any) is stored as ip_address.
        user: Registration payload (email, password, profile fields).
        lang: Language code passed to ``locales.translate_error`` for error text.

    Raises:
        HTTPException(400): the email is already registered.
    """
    already_taken = db.execute_query("SELECT id FROM users WHERE email = %s", (user.email,))
    if already_taken:
        message = locales.translate_error("email_already_registered", lang)
        raise HTTPException(status_code=400, detail=message)

    client_ip = None
    if request.client:
        client_ip = request.client.host
    password_digest = auth_utils.get_password_hash(user.password)

    insert_sql = """
        INSERT INTO users (email, password_hash, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
    """
    insert_values = (
        user.email,
        password_digest,
        user.first_name,
        user.last_name,
        user.phone,
        user.shipping_address,
        user.preferred_language,
        'user',  # self-registration never grants elevated roles
        client_ip,
    )
    created_id = db.execute_commit(insert_sql, insert_values)

    # Re-read the row so DB-generated columns (e.g. created_at) are included.
    created_rows = db.execute_query(
        "SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address, created_at FROM users WHERE id = %s",
        (created_id,),
    )
    return created_rows[0]


@router.post("/login", response_model=schemas.Token)
async def login(user_data: schemas.UserLogin, lang: str = "en"):
    """Verify email/password and return a bearer access token.

    The token payload carries the user's email (``sub``), ``id`` and ``role``.

    Raises:
        HTTPException(401): unknown email or wrong password (same message for both).
    """
    rows = db.execute_query("SELECT * FROM users WHERE email = %s", (user_data.email,))
    account = rows[0] if rows else None
    credentials_ok = account is not None and auth_utils.verify_password(
        user_data.password, account['password_hash']
    )
    if not credentials_ok:
        raise HTTPException(status_code=401, detail=locales.translate_error("incorrect_credentials", lang))

    claims = {"sub": account['email'], "id": account['id'], "role": account['role']}
    return {"access_token": auth_utils.create_access_token(data=claims), "token_type": "bearer"}
@router.post("/social-login", response_model=schemas.Token)
async def social_login(request: Request, data: schemas.SocialLogin):
    """Issue a bearer token for a social-login email, creating the user on first sight.

    If a user with ``data.email`` exists, a token is issued for that account
    with its stored role. Otherwise a new ``'user'``-role account is inserted
    with an unusable random password (hash of a fresh UUID4) and the client IP.

    NOTE(security): nothing visible here verifies that the caller actually
    controls ``data.email`` (no provider ID token / OAuth code is checked), so
    any client can obtain a token for any existing account, including admins,
    just by posting its email. Verify the provider credential before trusting
    ``data.email`` — TODO confirm whether schemas.SocialLogin carries one.
    """
    existing = db.execute_query("SELECT id, email, role FROM users WHERE email = %s", (data.email,))
    if existing:
        account = existing[0]
        return _token_response(account['email'], account['id'], account['role'])

    ip_address = request.client.host if request.client else None
    # Social accounts have no known password; store a hash of a random value
    # so password login cannot succeed until a reset is performed.
    hashed_password = auth_utils.get_password_hash(str(uuid.uuid4()))
    query = "INSERT INTO users (email, password_hash, first_name, last_name, preferred_language, role, ip_address) VALUES (%s, %s, %s, %s, %s, %s, %s)"
    params = (data.email, hashed_password, data.first_name, data.last_name, data.preferred_language, 'user', ip_address)
    user_id = db.execute_commit(query, params)
    return _token_response(data.email, user_id, 'user')


def _token_response(email, user_id, role):
    """Build the ``schemas.Token``-shaped body for the given user claims."""
    access_token = auth_utils.create_access_token(data={"sub": email, "id": user_id, "role": role})
    return {"access_token": access_token, "token_type": "bearer"}


@router.post("/logout")
async def logout(token: str = Depends(auth_utils.oauth2_scheme)):
    """Delete the server-side session named by the token's ``sid`` claim, if any.

    Always reports success, even for undecodable tokens or tokens without a
    ``sid`` claim (logout is best-effort).

    NOTE(review): the tokens minted in this module (login / social-login) are
    created with only ``sub``/``id``/``role`` claims. Unless
    ``auth_utils.create_access_token`` adds a ``sid`` itself, this deletes
    nothing and the JWT stays valid until it expires — confirm.
    """
    payload = auth_utils.decode_token(token)
    if payload:
        sid = payload.get("sid")
        if sid:
            session_utils.delete_session(sid)
    return {"message": "Successfully logged out"}


@router.post("/forgot-password")
async def forgot_password(request: schemas.ForgotPassword):
    """Create a 15-minute password-reset token for the given email.

    Raises:
        HTTPException(404): no user has that email.

    NOTE(security): the reset token is returned to the caller as
    ``demo_token``. Anyone who knows a user's email can therefore reset that
    user's password. This must be removed (and the token e-mailed instead)
    before production use. The 404 branch also reveals which emails are
    registered. Both are kept here only because clients may depend on them.
    """
    user = db.execute_query("SELECT id FROM users WHERE email = %s", (request.email,))
    if not user:
        raise HTTPException(status_code=404, detail="Email not found")
    token = str(uuid.uuid4())
    # Naive UTC, matching the comparison in reset_password.
    expires_at = datetime.utcnow() + timedelta(minutes=15)
    db.execute_commit(
        "INSERT INTO password_reset_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)",
        (user[0]['id'], token, expires_at),
    )
    return {"message": "Reset instructions sent to your email", "demo_token": token}


@router.post("/reset-password")
async def reset_password(request: schemas.ResetPassword):
    """Set a new password using a valid, unexpired reset token.

    On success, every outstanding reset token for that user is deleted, not
    only the one just used. Otherwise older tokens for the same account would
    remain usable after the password change.

    Raises:
        HTTPException(400): the token is unknown ("Invalid token") or past its
            ``expires_at`` ("Token expired").
    """
    reset_data = db.execute_query(
        "SELECT user_id, expires_at FROM password_reset_tokens WHERE token = %s", (request.token,)
    )
    if not reset_data:
        raise HTTPException(status_code=400, detail="Invalid token")
    # expires_at was stored as naive UTC by forgot_password.
    if reset_data[0]['expires_at'] < datetime.utcnow():
        raise HTTPException(status_code=400, detail="Token expired")

    user_id = reset_data[0]['user_id']
    hashed_password = auth_utils.get_password_hash(request.new_password)
    db.execute_commit("UPDATE users SET password_hash = %s WHERE id = %s", (hashed_password, user_id))
    db.execute_commit("DELETE FROM password_reset_tokens WHERE user_id = %s", (user_id,))
    return {"message": "Password reset successfully"}


@router.get("/me", response_model=schemas.UserResponse)
async def get_me(token: str = Depends(auth_utils.oauth2_scheme)):
    """Return the profile of the user identified by the token's ``id`` claim.

    Raises:
        HTTPException(401): the token cannot be decoded.
        HTTPException(404): no user row matches the claimed id.
    """
    payload = auth_utils.decode_token(token)
    if not payload:
        raise HTTPException(status_code=401, detail="Invalid token")
    user = db.execute_query(
        "SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address, created_at FROM users WHERE id = %s",
        (payload.get("id"),),
    )
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user[0]


# Columns that the self-service profile endpoint must never write, whatever
# fields schemas.UserUpdate happens to expose (or accept as extras).
_PROTECTED_USER_COLUMNS = frozenset({"id", "role", "password_hash", "ip_address", "created_at"})


@router.put("/me", response_model=schemas.UserResponse)
async def update_me(data: schemas.UserUpdate, token: str = Depends(auth_utils.oauth2_scheme)):
    """Update the caller's own profile with the fields explicitly sent.

    Only fields present in the request (``exclude_unset=True``) are written.

    Raises:
        HTTPException(401): the token cannot be decoded.
        HTTPException(400): a field names a protected column or is not a
            plain identifier.
        HTTPException(404): no user row matches the token's ``id`` claim.
    """
    payload = auth_utils.decode_token(token)
    if not payload:
        raise HTTPException(status_code=401, detail="Invalid token")
    user_id = payload.get("id")

    changes = data.dict(exclude_unset=True)
    for field in changes:
        # Column names are interpolated into the SQL text (only values are
        # parameterized), so refuse anything that is not a bare identifier or
        # that would let a user rewrite privileged columns such as role.
        if not field.isidentifier() or field in _PROTECTED_USER_COLUMNS:
            raise HTTPException(status_code=400, detail=f"Field cannot be updated: {field}")

    if changes:
        assignments = ", ".join(f"{field} = %s" for field in changes)
        params = (*changes.values(), user_id)
        db.execute_commit(f"UPDATE users SET {assignments} WHERE id = %s", params)

    user = db.execute_query(
        "SELECT id, email, first_name, last_name, phone, shipping_address, preferred_language, role, ip_address, created_at FROM users WHERE id = %s",
        (user_id,),
    )
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user[0]


@router.websocket("/ws/global")
async def ws_global(websocket: WebSocket, token: str = Query(...)):
    """Global notification socket for an authenticated user.

    The connection is closed with code 4001 if the ``token`` query parameter
    does not decode or has no ``id`` claim. Otherwise the socket is registered
    with ``global_manager``, an initial notification is pushed (to admins if the
    role is 'admin', else to this user), and each "ping" text frame refreshes
    the user's presence via ``session_utils.track_user_ping``.

    The socket is always unregistered when the handler exits, whether by
    client disconnect or by any other error.
    """
    payload = auth_utils.decode_token(token)
    user_id = payload.get("id") if payload else None
    if not user_id:
        await websocket.close(code=4001)
        return
    role = payload.get("role")

    await global_manager.connect(websocket, user_id, role)
    try:
        session_utils.track_user_ping(user_id)
        # Send initial unread count
        if role != 'admin':
            await global_manager.notify_user(user_id)
        else:
            await global_manager.notify_admins()
        while True:
            data = await websocket.receive_text()
            if data == "ping":
                session_utils.track_user_ping(user_id)
    except WebSocketDisconnect:
        pass  # normal client-initiated close; cleanup happens in finally
    finally:
        global_manager.disconnect(websocket, user_id)