# global_manager.py (2.3 KB)

import json
import logging
from typing import Dict, List

from fastapi import WebSocket
  4. class GlobalConnectionManager:
  5. def __init__(self):
  6. self.active_connections: Dict[int, List[WebSocket]] = {}
  7. self.user_roles: Dict[int, str] = {}
  8. async def connect(self, websocket: WebSocket, user_id: int, role: str):
  9. await websocket.accept()
  10. if user_id not in self.active_connections:
  11. self.active_connections[user_id] = []
  12. self.active_connections[user_id].append(websocket)
  13. self.user_roles[user_id] = role
  14. def disconnect(self, websocket: WebSocket, user_id: int):
  15. if user_id in self.active_connections:
  16. self.active_connections[user_id] = [ws for ws in self.active_connections[user_id] if ws != websocket]
  17. if not self.active_connections[user_id]:
  18. del self.active_connections[user_id]
  19. if user_id in self.user_roles:
  20. del self.user_roles[user_id]
  21. async def send_unread_count(self, user_id: int, count: int):
  22. if user_id in self.active_connections:
  23. payload = json.dumps({"type": "unread_count", "count": count})
  24. for ws in self.active_connections[user_id]:
  25. try:
  26. await ws.send_text(payload)
  27. except:
  28. pass
  29. async def notify_user(self, user_id: int):
  30. import db
  31. row = db.execute_query(
  32. "SELECT count(*) as cnt FROM order_messages om JOIN orders o ON om.order_id = o.id WHERE o.user_id = %s AND om.is_from_admin = TRUE AND om.is_read = FALSE",
  33. (user_id,)
  34. )
  35. count = row[0]['cnt'] if row else 0
  36. await self.send_unread_count(user_id, count)
  37. async def notify_admins(self):
  38. import db
  39. row = db.execute_query("SELECT count(*) as cnt FROM order_messages WHERE is_from_admin = FALSE AND is_read = FALSE")
  40. count = row[0]['cnt'] if row else 0
  41. payload = json.dumps({"type": "unread_count", "count": count})
  42. for uid, role in self.user_roles.items():
  43. if role == 'admin' and uid in self.active_connections:
  44. for ws in self.active_connections[uid]:
  45. try:
  46. await ws.send_text(payload)
  47. except:
  48. pass
  49. global_manager = GlobalConnectionManager()