global_manager.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
import json
import logging
from typing import Dict, List

from fastapi import WebSocket
  4. class GlobalConnectionManager:
  5. def __init__(self):
  6. self.active_connections: Dict[int, List[WebSocket]] = {}
  7. self.user_roles: Dict[int, str] = {}
  8. async def connect(self, websocket: WebSocket, user_id: int, role: str):
  9. await websocket.accept()
  10. if user_id not in self.active_connections:
  11. self.active_connections[user_id] = []
  12. self.active_connections[user_id].append(websocket)
  13. self.user_roles[user_id] = role
  14. def disconnect(self, websocket: WebSocket, user_id: int):
  15. if user_id in self.active_connections:
  16. self.active_connections[user_id] = [ws for ws in self.active_connections[user_id] if ws != websocket]
  17. if not self.active_connections[user_id]:
  18. del self.active_connections[user_id]
  19. if user_id in self.user_roles:
  20. del self.user_roles[user_id]
  21. async def send_unread_count(self, user_id: int, count: int):
  22. if user_id in self.active_connections:
  23. payload = json.dumps({"type": "unread_count", "count": count})
  24. for ws in self.active_connections[user_id]:
  25. try:
  26. await ws.send_text(payload)
  27. except:
  28. pass
  29. async def notify_user(self, user_id: int):
  30. import db
  31. row = db.execute_query(
  32. "SELECT count(*) as cnt FROM order_messages om JOIN orders o ON om.order_id = o.id WHERE o.user_id = %s AND om.is_from_admin = TRUE AND om.is_read = FALSE",
  33. (user_id,)
  34. )
  35. count = row[0]['cnt'] if row else 0
  36. await self.send_unread_count(user_id, count)
  37. async def broadcast_to_role(self, role_name: str, payload: str):
  38. for uid, role in self.user_roles.items():
  39. if role == role_name and uid in self.active_connections:
  40. for ws in self.active_connections[uid]:
  41. try:
  42. await ws.send_text(payload)
  43. except:
  44. pass
  45. async def notify_admins(self):
  46. import db
  47. row = db.execute_query("SELECT count(*) as cnt FROM order_messages WHERE is_from_admin = FALSE AND is_read = FALSE")
  48. count = row[0]['cnt'] if row else 0
  49. payload = json.dumps({"type": "unread_count", "count": count})
  50. await self.broadcast_to_role('admin', payload)
  51. async def notify_admins_new_message(self, order_id: int, message_text: str):
  52. hint = (message_text[:50] + '...') if len(message_text) > 50 else message_text
  53. payload = json.dumps({
  54. "type": "new_chat_message",
  55. "order_id": order_id,
  56. "text": hint
  57. })
  58. await self.broadcast_to_role('admin', payload)
  59. async def notify_order_read(self, order_id: int):
  60. payload = json.dumps({
  61. "type": "order_read",
  62. "order_id": order_id
  63. })
  64. await self.broadcast_to_role('admin', payload)
  65. async def kick_user(self, user_id: int):
  66. if user_id in self.active_connections:
  67. payload = json.dumps({"type": "account_suspended"})
  68. for ws in self.active_connections[user_id]:
  69. try:
  70. await ws.send_text(payload)
  71. await ws.close(code=4003) # Custom code for kick
  72. except:
  73. pass
  74. # Connections will be removed via disconnect() called on close
  75. pass
  76. async def notify_order_update(self, user_id: int, order_id: int):
  77. if user_id in self.active_connections:
  78. payload = json.dumps({
  79. "type": "order_updated",
  80. "order_id": order_id
  81. })
  82. for ws in self.active_connections[user_id]:
  83. try:
  84. await ws.send_text(payload)
  85. except:
  86. pass
  87. global_manager = GlobalConnectionManager()