"""WebSocket connection registry used to push unread-message counts and
admin notifications to connected clients."""

import json
import logging
from typing import Dict, List

from fastapi import WebSocket

logger = logging.getLogger(__name__)


class GlobalConnectionManager:
    """Track open WebSocket connections per user and push JSON messages to them.

    A user may hold several sockets at once; every message addressed to that
    user is sent to all of them. Each user's role (from the most recent
    ``connect``) is kept so messages can be broadcast to a whole role.
    """

    def __init__(self) -> None:
        # user_id -> every currently registered WebSocket for that user.
        # Lists are never left empty: disconnect() drops the key instead.
        self.active_connections: Dict[int, List[WebSocket]] = {}
        # user_id -> role string supplied on the most recent connect().
        self.user_roles: Dict[int, str] = {}

    async def connect(self, websocket: WebSocket, user_id: int, role: str) -> None:
        """Accept *websocket* and register it under *user_id* with *role*.

        Args:
            websocket: The incoming socket; ``accept()`` is awaited on it.
            user_id: Owner of the socket.
            role: Role name stored for the user (overwrites any previous one).
        """
        await websocket.accept()
        self.active_connections.setdefault(user_id, []).append(websocket)
        self.user_roles[user_id] = role

    def disconnect(self, websocket: WebSocket, user_id: int) -> None:
        """Unregister *websocket*; forget the user entirely once no sockets remain.

        Safe to call for an unknown user or an already-removed socket.
        """
        if user_id not in self.active_connections:
            return
        remaining = [ws for ws in self.active_connections[user_id] if ws != websocket]
        if remaining:
            self.active_connections[user_id] = remaining
        else:
            del self.active_connections[user_id]
            self.user_roles.pop(user_id, None)

    async def _send_all(self, sockets: List[WebSocket], payload: str) -> None:
        """Best-effort send of *payload* to each socket in *sockets*.

        A failure on one socket (for example a client that dropped before
        disconnect() ran) is logged and does not stop delivery to the others.
        Only ``Exception`` is caught, so ``asyncio.CancelledError`` still
        propagates and task cancellation keeps working.

        Callers must pass a snapshot list, never a live registry list, because
        connect()/disconnect() can run while a send is suspended.
        """
        for ws in sockets:
            try:
                await ws.send_text(payload)
            except Exception:
                logger.debug("Failed to deliver WebSocket message", exc_info=True)

    async def send_unread_count(self, user_id: int, count: int) -> None:
        """Push ``{"type": "unread_count", "count": count}`` to all of *user_id*'s sockets.

        Does nothing if the user has no registered sockets.
        """
        sockets = list(self.active_connections.get(user_id, ()))
        if not sockets:
            return
        payload = json.dumps({"type": "unread_count", "count": count})
        await self._send_all(sockets, payload)

    async def notify_user(self, user_id: int) -> None:
        """Recount unread admin-authored messages on *user_id*'s orders and push the count.

        The count covers rows of ``order_messages`` joined to ``orders`` where
        ``orders.user_id`` matches, ``is_from_admin`` is TRUE and ``is_read``
        is FALSE.
        """
        # The count is only ever delivered to live sockets, so skip the query
        # entirely for users who are not connected.
        if user_id not in self.active_connections:
            return
        # Imported lazily, presumably to avoid a circular import - confirm.
        import db
        # NOTE(review): execute_query is called without await; if it does
        # blocking I/O it stalls the event loop for the query's duration.
        # Consider running it in an executor if the db module is thread-safe.
        row = db.execute_query(
            "SELECT count(*) as cnt FROM order_messages om "
            "JOIN orders o ON om.order_id = o.id "
            "WHERE o.user_id = %s AND om.is_from_admin = TRUE AND om.is_read = FALSE",
            (user_id,),
        )
        count = row[0]['cnt'] if row else 0
        await self.send_unread_count(user_id, count)

    async def broadcast_to_role(self, role_name: str, payload: str) -> None:
        """Send the pre-serialized *payload* to every socket of every user whose role is *role_name*."""
        # Collect targets synchronously (no await in this loop) so that a
        # connect()/disconnect() running during a later send cannot mutate
        # user_roles mid-iteration ("dictionary changed size during iteration").
        targets: List[WebSocket] = []
        for uid, role in self.user_roles.items():
            if role == role_name:
                targets.extend(self.active_connections.get(uid, ()))
        await self._send_all(targets, payload)

    async def notify_admins(self) -> None:
        """Broadcast to every connected admin the number of unread messages with ``is_from_admin = FALSE``."""
        # Imported lazily, presumably to avoid a circular import - confirm.
        import db
        # NOTE(review): synchronous DB call inside a coroutine; see notify_user.
        row = db.execute_query("SELECT count(*) as cnt FROM order_messages WHERE is_from_admin = FALSE AND is_read = FALSE")
        count = row[0]['cnt'] if row else 0
        payload = json.dumps({"type": "unread_count", "count": count})
        await self.broadcast_to_role('admin', payload)

    async def notify_admins_new_message(self, order_id: int, message_text: str) -> None:
        """Tell every connected admin that a new chat message arrived on *order_id*.

        Only a preview is sent: the first 50 characters of *message_text*,
        followed by ``'...'`` when the text was longer.
        """
        hint = (message_text[:50] + '...') if len(message_text) > 50 else message_text
        payload = json.dumps({
            "type": "new_chat_message",
            "order_id": order_id,
            "text": hint,
        })
        await self.broadcast_to_role('admin', payload)

    async def notify_order_read(self, order_id: int) -> None:
        """Tell every connected admin that *order_id*'s messages were marked read."""
        payload = json.dumps({
            "type": "order_read",
            "order_id": order_id,
        })
        await self.broadcast_to_role('admin', payload)

    async def kick_user(self, user_id: int) -> None:
        """Send an ``account_suspended`` notice to each of *user_id*'s sockets, then close them.

        Each socket is closed with application code 4003 even if the notice
        could not be delivered, so the session is always terminated.
        """
        sockets = list(self.active_connections.get(user_id, ()))
        if not sockets:
            return
        payload = json.dumps({"type": "account_suspended"})
        for ws in sockets:
            try:
                await ws.send_text(payload)
            except Exception:
                logger.debug("Failed to send suspension notice", exc_info=True)
            try:
                await ws.close(code=4003)  # Custom code for kick
            except Exception:
                logger.debug("Failed to close kicked WebSocket", exc_info=True)
        # Registry entries are not removed here; the original design expects
        # disconnect() to be called (presumably by the endpoint) once each
        # socket closes.


global_manager = GlobalConnectionManager()