+import asyncio
+from asyncio.exceptions import CancelledError
+import json
+import pickle
+import sys
+from traceback import print_exception
+
+from falcon import WebSocketDisconnected
+from redis.asyncio import StrictRedis
+
+from .utils import get_redis_pass, scramble
+
+
+class Hub:
+ def __init__(self, secret):
+ self.master_ws_uri = f"/{scramble(secret, 'ws')}"
+ self.redis = StrictRedis(password=get_redis_pass("/etc/redis/redis.conf"))
+ asyncio.ensure_future(self.redis.set("client_id", 0))
+
+ async def process_websocket(self, client_id, web_socket):
+ try:
+ while True:
+ data = await web_socket.receive_text()
+ try:
+ parsed_data = json.loads(data)
+ except json.JSONDecodeError:
+ parsed_data = None
+ if not isinstance(parsed_data, dict):
+ parsed_data = {"data": data}
+ parsed_data["client_id"] = client_id
+ await self.redis.publish("master", pickle.dumps(parsed_data))
+ except (CancelledError, WebSocketDisconnected):
+ pass
+
+ async def process_pubsub(self, pubsub, web_socket):
+ try:
+ while True:
+ data = await pubsub.get_message(True, .3)
+ if not web_socket.ready or web_socket.closed:
+ break
+ if data is not None:
+ await web_socket.send_text(json.dumps(pickle.loads(data["data"])))
+ except (CancelledError, WebSocketDisconnected):
+ pass
+
+ async def on_websocket(self, req, web_socket):
+ client_id = await self.redis.incr("client_id")
+ await web_socket.accept()
+ pubsub = self.redis.pubsub()
+ await pubsub.subscribe(f"client-{client_id}")
+ await self.redis.publish(
+ "master", pickle.dumps({"action": "join", "client_id": client_id}),
+ )
+ try:
+ await asyncio.gather(
+ self.process_websocket(client_id, web_socket),
+ self.process_pubsub(pubsub, web_socket),
+ return_exceptions=True,
+ )
+ except (CancelledError, WebSocketDisconnected):
+ pass
+ except Exception:
+ print_exception(*sys.exc_info())
+ finally:
+ await web_socket.close()
+ await self.redis.publish(
+ "master",
+ pickle.dumps({"action": "leave", "client_id": client_id}),
+ )
+
+ async def process_websocket_master(self, web_socket):
+ try:
+ while True:
+ data = json.loads(await web_socket.receive_text())
+ for client_id in data.pop("client_ids", ()):
+ await self.redis.publish(
+ f"client-{client_id}",
+ pickle.dumps(data),
+ )
+ except (CancelledError, WebSocketDisconnected) as e:
+ pass
+
+ async def on_websocket_master(self, req, web_socket):
+ await web_socket.accept()
+ pubsub = self.redis.pubsub()
+ await pubsub.subscribe("master")
+ try:
+ await asyncio.gather(
+ self.process_websocket_master(web_socket),
+ self.process_pubsub(pubsub, web_socket),
+ return_exceptions=True,
+ )
+ except (CancelledError, WebSocketDisconnected):
+ pass
+ except Exception:
+ print_exception(*sys.exc_info())
+ finally:
+ await web_socket.close()
+
+ def add_routes(self, app):
+ app.add_route("/ws", self)
+ app.add_route(self.master_ws_uri, self, suffix="master")