| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
import json
import logging
from typing import Dict, List

from fastapi import WebSocket
class GlobalConnectionManager:
    """Registry of open WebSocket connections that pushes unread-message counts.

    A user may hold several sockets at once; they are kept in
    ``active_connections[user_id]``. The role given to :meth:`connect` is stored
    in ``user_roles`` so :meth:`notify_admins` can find users whose role is
    ``'admin'``.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self):
        # user_id -> list of that user's open sockets
        self.active_connections: Dict[int, List["WebSocket"]] = {}
        # user_id -> role passed on the most recent connect()
        self.user_roles: Dict[int, str] = {}

    async def connect(self, websocket: "WebSocket", user_id: int, role: str) -> None:
        """Accept *websocket* and register it for *user_id* with *role*."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, []).append(websocket)
        self.user_roles[user_id] = role

    def disconnect(self, websocket: "WebSocket", user_id: int) -> None:
        """Unregister *websocket* for *user_id*.

        The user's role is forgotten only once their last socket is removed, so
        a user with several open sockets keeps receiving role-based
        notifications. Calling this for an unknown socket or user is a no-op.
        """
        sockets = self.active_connections.get(user_id)
        if sockets is None:
            return
        # Compare by identity: we want to drop exactly this socket object.
        remaining = [ws for ws in sockets if ws is not websocket]
        if remaining:
            self.active_connections[user_id] = remaining
            return
        del self.active_connections[user_id]
        self.user_roles.pop(user_id, None)

    async def _send_to(self, user_id: int, payload: str) -> None:
        """Send *payload* to every open socket of *user_id*.

        A socket whose send raises is logged and unregistered so later
        broadcasts do not keep retrying it. Only ``Exception`` subclasses are
        caught, so task cancellation still propagates.
        """
        # Iterate over a copy: each await lets other coroutines run
        # connect()/disconnect() and mutate the registry.
        for ws in list(self.active_connections.get(user_id, ())):
            try:
                await ws.send_text(payload)
            except Exception:
                self._logger.warning(
                    "Dropping websocket for user %s after failed send",
                    user_id,
                    exc_info=True,
                )
                self.disconnect(ws, user_id)

    async def send_unread_count(self, user_id: int, count: int) -> None:
        """Push ``{"type": "unread_count", "count": count}`` to *user_id*'s sockets.

        Does nothing if the user has no open sockets.
        """
        if user_id not in self.active_connections:
            return
        payload = json.dumps({"type": "unread_count", "count": count})
        await self._send_to(user_id, payload)

    @staticmethod
    def _fetch_count(query: str, *params: object) -> int:
        """Run *query* via ``db.execute_query`` and return row 0's ``cnt`` (0 if no rows)."""
        # Deferred import, presumably to avoid an import cycle with the db module.
        import db

        # NOTE(review): if execute_query is synchronous it blocks the event loop
        # while the query runs; consider moving it to an executor if db is thread-safe.
        rows = db.execute_query(query, *params)
        return rows[0]["cnt"] if rows else 0

    async def notify_user(self, user_id: int) -> None:
        """Send *user_id* the number of unread admin messages on their orders."""
        # No open socket means nobody would receive the count; skip the query.
        if user_id not in self.active_connections:
            return
        count = self._fetch_count(
            "SELECT count(*) as cnt FROM order_messages om JOIN orders o ON om.order_id = o.id WHERE o.user_id = %s AND om.is_from_admin = TRUE AND om.is_read = FALSE",
            (user_id,),
        )
        await self.send_unread_count(user_id, count)

    async def notify_admins(self) -> None:
        """Send every connected admin the number of unread non-admin messages."""
        # Snapshot the admin ids up front: sends below await, and connect()/
        # disconnect() may resize user_roles meanwhile.
        admin_ids = [
            uid
            for uid, role in list(self.user_roles.items())
            if role == 'admin' and uid in self.active_connections
        ]
        if not admin_ids:
            return
        count = self._fetch_count(
            "SELECT count(*) as cnt FROM order_messages WHERE is_from_admin = FALSE AND is_read = FALSE"
        )
        payload = json.dumps({"type": "unread_count", "count": count})
        for uid in admin_ids:
            await self._send_to(uid, payload)
# Single module-level instance of the connection registry.
global_manager: GlobalConnectionManager = GlobalConnectionManager()
|