global_manager.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  import json
  import logging
  from typing import Dict, List

  from fastapi import WebSocket
  class GlobalConnectionManager:
      """Registry of accepted WebSocket connections, keyed by user id.

      Each user id maps to a list of sockets, and the role passed to
      ``connect()`` is remembered per user id.  The ``send_*``/``notify_*``
      helpers push small JSON text messages to those sockets on a best-effort
      basis: a socket that fails to send is logged and skipped, never raised
      to the caller.
      """

      # Class-level logger so failures that used to be silently swallowed are
      # at least visible in the logs.
      _log = logging.getLogger(__name__)

      def __init__(self):
          # user_id -> every socket currently registered for that user.
          self.active_connections: Dict[int, List[WebSocket]] = {}
          # user_id -> role string given at connect time (only 'admin' is
          # interpreted by this class, in notify_admins()).
          self.user_roles: Dict[int, str] = {}

      @staticmethod
      def _unread_payload(count: int) -> str:
          """Return the JSON text message announcing an unread-message count."""
          return json.dumps({"type": "unread_count", "count": count})

      async def _safe_send(self, ws: WebSocket, payload: str) -> bool:
          """Send ``payload`` to one socket; return False instead of raising.

          Only ``Exception`` is caught, so task cancellation
          (``asyncio.CancelledError``) and interpreter shutdown signals still
          propagate to the caller.
          """
          try:
              await ws.send_text(payload)
          except Exception as exc:
              self._log.warning("WebSocket send failed: %r", exc)
              return False
          return True

      async def connect(self, websocket: WebSocket, user_id: int, role: str) -> None:
          """Accept ``websocket`` and register it under ``user_id``.

          The stored role for ``user_id`` is overwritten with ``role`` on
          every call.
          """
          await websocket.accept()
          self.active_connections.setdefault(user_id, []).append(websocket)
          self.user_roles[user_id] = role

      def disconnect(self, websocket: WebSocket, user_id: int) -> None:
          """Unregister ``websocket`` for ``user_id``.

          When it was the user's last socket, the user's entries in both
          ``active_connections`` and ``user_roles`` are removed.  Unknown user
          ids are ignored.
          """
          sockets = self.active_connections.get(user_id)
          if sockets is None:
              return
          # Build a new list rather than mutating in place, so a send loop that
          # is currently iterating the old list is not disturbed.
          remaining = [ws for ws in sockets if ws != websocket]
          if remaining:
              self.active_connections[user_id] = remaining
              return
          del self.active_connections[user_id]
          self.user_roles.pop(user_id, None)

      async def send_unread_count(self, user_id: int, count: int) -> None:
          """Push an ``unread_count`` message to every socket of ``user_id``.

          Does nothing when the user has no registered sockets.
          """
          payload = self._unread_payload(count)
          for ws in self.active_connections.get(user_id, ()):
              await self._safe_send(ws, payload)

      async def notify_user(self, user_id: int) -> None:
          """Query the user's unread admin messages and push the count to them."""
          # Function-scope import kept from the original.
          # NOTE(review): presumably avoids a circular import - confirm.
          import db

          # NOTE(review): execute_query is called synchronously; if it blocks on
          # I/O it stalls the event loop for its duration - confirm and consider
          # offloading.
          row = db.execute_query(
              "SELECT count(*) as cnt FROM order_messages om JOIN orders o ON om.order_id = o.id WHERE o.user_id = %s AND om.is_from_admin = TRUE AND om.is_read = FALSE",
              (user_id,)
          )
          count = row[0]['cnt'] if row else 0
          await self.send_unread_count(user_id, count)

      async def notify_admins(self) -> None:
          """Push the count of unread non-admin messages to every admin socket."""
          # Function-scope import kept from the original.
          # NOTE(review): presumably avoids a circular import - confirm.
          import db

          row = db.execute_query("SELECT count(*) as cnt FROM order_messages WHERE is_from_admin = FALSE AND is_read = FALSE")
          count = row[0]['cnt'] if row else 0
          payload = self._unread_payload(count)

          # Snapshot the admin ids before awaiting anything: connect() and
          # disconnect() can run while a send is pending, and resizing
          # user_roles during direct iteration raises RuntimeError.
          admin_ids = [uid for uid, role in self.user_roles.items() if role == 'admin']
          for uid in admin_ids:
              for ws in self.active_connections.get(uid, ()):
                  await self._safe_send(ws, payload)

      async def kick_user(self, user_id: int) -> None:
          """Send ``account_suspended`` to the user's sockets and close them.

          Each socket is closed with code 4003 even if the notification could
          not be delivered.  Entries are not removed here; that is left to
          ``disconnect()``, which the code handling the socket is expected to
          call once it observes the close.
          """
          payload = json.dumps({"type": "account_suspended"})
          for ws in self.active_connections.get(user_id, ()):
              # The notification is a courtesy; the close is the actual kick,
              # so it must be attempted regardless of the send outcome.
              await self._safe_send(ws, payload)
              try:
                  await ws.close(code=4003)  # Custom code for kick
              except Exception as exc:
                  self._log.warning("WebSocket close failed: %r", exc)
          # Connections will be removed via disconnect() called on close


  # Module-level instance created at import time.
  global_manager = GlobalConnectionManager()