-import asyncio
-from asyncio.exceptions import CancelledError
-import json
-import pickle
-import sys
-from traceback import print_exception
-
-from falcon import WebSocketDisconnected
-from redis.asyncio import StrictRedis
-
-from .utils import get_redis_pass, scramble
-
-
-class Hub:
- def __init__(self, secret):
- self.master_ws_uri = f"/{scramble(secret, 'ws')}"
- self.redis = StrictRedis(password=get_redis_pass("/etc/redis/redis.conf"))
- asyncio.ensure_future(self.redis.set("client_id", 0))
-
- async def process_websocket(self, client_id, web_socket):
- try:
- while True:
- data = await web_socket.receive_text()
- try:
- parsed_data = json.loads(data)
- except json.JSONDecodeError:
- parsed_data = None
- if not isinstance(parsed_data, dict):
- parsed_data = {"data": data}
- parsed_data["client_id"] = client_id
- await self.redis.publish("master", pickle.dumps(parsed_data))
- except (CancelledError, WebSocketDisconnected):
- pass
-
- async def process_pubsub(self, pubsub, web_socket):
- try:
- while True:
- data = await pubsub.get_message(True, .3)
- if not web_socket.ready or web_socket.closed:
- break
- if data is not None:
- await web_socket.send_text(json.dumps(pickle.loads(data["data"])))
- except (CancelledError, WebSocketDisconnected):
- pass
-
- async def on_websocket(self, req, web_socket):
- client_id = await self.redis.incr("client_id")
- await web_socket.accept()
- pubsub = self.redis.pubsub()
- await pubsub.subscribe(f"client-{client_id}")
- await self.redis.publish(
- "master", pickle.dumps({"action": "join", "client_id": client_id}),
- )
- try:
- await asyncio.gather(
- self.process_websocket(client_id, web_socket),
- self.process_pubsub(pubsub, web_socket),
- return_exceptions=True,
- )
- except (CancelledError, WebSocketDisconnected):
- pass
- except Exception:
- print_exception(*sys.exc_info())
- finally:
- await web_socket.close()
- await self.redis.publish(
- "master",
- pickle.dumps({"action": "leave", "client_id": client_id}),
- )
-
- async def process_websocket_master(self, web_socket):
- try:
- while True:
- data = json.loads(await web_socket.receive_text())
- for client_id in data.pop("client_ids", ()):
- await self.redis.publish(
- f"client-{client_id}",
- pickle.dumps(data),
- )
- except (CancelledError, WebSocketDisconnected) as e:
- pass
-
- async def on_websocket_master(self, req, web_socket):
- await web_socket.accept()
- pubsub = self.redis.pubsub()
- await pubsub.subscribe("master")
- try:
- await asyncio.gather(
- self.process_websocket_master(web_socket),
- self.process_pubsub(pubsub, web_socket),
- return_exceptions=True,
- )
- except (CancelledError, WebSocketDisconnected):
- pass
- except Exception:
- print_exception(*sys.exc_info())
- finally:
- await web_socket.close()
-
- def add_routes(self, app):
- app.add_route("/ws", self)
- app.add_route(self.master_ws_uri, self, suffix="master")
-
- def update_context_vars(self, context_vars):
- context_vars["master_ws_uri"] = self.master_ws_uri