--- /dev/null
+import asyncio
+import json
+import pickle
+import sys
+from asyncio.exceptions import CancelledError
+from functools import partial
+from traceback import print_exception
+
+from falcon import WebSocketDisconnected
+from redis.asyncio import StrictRedis
+
+from .utils import get_redis_pass
+
+
+class BaseWebSocketHub:
+ first = True
+ client_ids_sem = asyncio.Semaphore(0)
+
+ @classmethod
+ def __class_init(cls, redis):
+ if not cls.first:
+ return
+ cls.first = False
+ asyncio.create_task(cls.initialize_client_ids(redis))
+
+ @classmethod
+ async def initialize_client_ids(cls, redis):
+ await redis.set("client_id", 0)
+ cls.client_ids_sem.release()
+
+ def __init__(self):
+ self.redis = StrictRedis(password=get_redis_pass("/etc/redis/redis.conf"))
+ self.__class_init(self.redis)
+
+ def task_done(self):
+ self.task = None
+
+ @staticmethod
+ async def process_websocket(redis, web_socket, extra_data={}, recipients=[]):
+ try:
+ while True:
+ data = json.loads(await web_socket.receive_text())
+ data.update(extra_data)
+ if callable(recipients):
+ current_recipients = recipients(data)
+ else:
+ current_recipients = recipients
+ for recipient in current_recipients:
+ await redis.publish(recipient, pickle.dumps(data))
+ except (CancelledError, WebSocketDisconnected):
+ pass
+
+ @staticmethod
+ async def process_pubsub(pubsub, web_socket):
+ try:
+ while True:
+ data = await pubsub.get_message(True, 0.3)
+ if not web_socket.ready or web_socket.closed:
+ break
+ if data is not None:
+ await web_socket.send_text(json.dumps(pickle.loads(data["data"])))
+ except (CancelledError, WebSocketDisconnected):
+ pass
+
+ async def on_websocket(
+ self,
+ req,
+ web_socket,
+ pubsub_name=None,
+ process_websockets_kwargs=None,
+ join_cb=None,
+ leave_cb=None,
+ ):
+ await web_socket.accept()
+ pubsub = self.redis.pubsub()
+ if pubsub_name:
+ await pubsub.subscribe(pubsub_name)
+ if callable(join_cb):
+ await join_cb()
+ try:
+ await asyncio.gather(
+ self.process_websocket(
+ self.redis, web_socket, **(process_websockets_kwargs or {})
+ ),
+ self.process_pubsub(pubsub, web_socket),
+ return_exceptions=True,
+ )
+ except (CancelledError, WebSocketDisconnected):
+ pass
+ except Exception:
+ print_exception(*sys.exc_info())
+ finally:
+ await web_socket.close()
+ if callable(leave_cb):
+ await leave_cb()
+
+
+class WebSocketHub(BaseWebSocketHub):
+ def __init__(self, hubapp):
+ super().__init__()
+ self.hubapp = hubapp
+
+ async def join_leave_client_notify(self, redis, action, client_id):
+ await redis.publish(
+ f"{self.hubapp.name}-master",
+ pickle.dumps({"action": action, "client_id": client_id}),
+ )
+
+ async def on_websocket_client(self, req, web_socket):
+ await self.client_ids_sem.acquire()
+ try:
+ client_id = await self.redis.incr("client_id")
+ finally:
+ self.client_ids_sem.release()
+ return await self.on_websocket(
+ req,
+ web_socket,
+ f"{self.hubapp.name}-client-{client_id}",
+ {
+ "extra_data": {"client_id": client_id},
+ "recipients": [f"{self.hubapp.name}-master"],
+ },
+ partial(self.join_leave_client_notify, self.redis, "join", client_id),
+ partial(self.join_leave_client_notify, self.redis, "leave", client_id),
+ )
+
+ async def on_websocket_master(self, req, web_socket):
+ return await self.on_websocket(
+ req,
+ web_socket,
+ f"{self.hubapp.name}-master",
+ {"recipients": self.get_master_recipients},
+ )
+
+ def get_master_recipients(self, data):
+ return [
+ f"{self.hubapp.name}-client-{int(client_id)}"
+ for client_id in data.pop("client_ids", ())
+ ]