import asyncio
import json
import pickle
import sys
from asyncio.exceptions import CancelledError
from functools import partial
from traceback import print_exception

from falcon import WebSocketDisconnected
from redis.asyncio import StrictRedis

from .utils import get_redis_pass


class BaseWebSocketHub:
    """Bridge Falcon WebSocket connections to Redis pub/sub channels.

    Each connection gets its own pubsub object. JSON text frames received
    from the socket are published (pickled) to one or more Redis channels.
    Messages arriving on the subscribed channel are unpickled and sent back
    to the socket as JSON text.
    """

    # Starts at 0 so that client-id allocation blocks until the Redis
    # "client_id" counter has been reset by initialize_client_ids().
    # NOTE(review): created at import time; on Python < 3.10 asyncio
    # primitives bind to the loop current at construction -- confirm the app
    # runs on that loop if older Pythons must be supported.
    client_ids_sem = asyncio.Semaphore(0)

    @classmethod
    def _class_init(cls, redis):
        """Schedule the one-time reset of the client-id counter.

        Removes itself from the class so it runs at most once per process.
        Uses asyncio.create_task, so it must be called while an event loop
        is running.
        """
        if not hasattr(cls, "_class_init"):
            return
        delattr(cls, "_class_init")
        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task can vanish mid-execution.
        cls._client_ids_init_task = asyncio.create_task(
            cls.initialize_client_ids(redis)
        )

    @classmethod
    async def initialize_client_ids(cls, redis):
        """Reset the Redis "client_id" counter to 0, then unblock allocation."""
        try:
            await redis.set("client_id", 0)
        except Exception:
            # The task object is retained, so its exception would never be
            # reported by the loop; print it here. Clients stay blocked on
            # the semaphore in this case.
            print_exception(*sys.exc_info())
            raise
        cls.client_ids_sem.release()

    def __init__(self):
        self.redis = StrictRedis(password=get_redis_pass("/etc/redis/redis.conf"))
        # The first instance created triggers the process-wide counter reset.
        if hasattr(BaseWebSocketHub, "_class_init"):
            BaseWebSocketHub._class_init(self.redis)

    def task_done(self):
        """Clear the ``task`` attribute on this instance."""
        self.task = None

    @staticmethod
    async def process_websocket(redis, web_socket, extra_data=None, recipients=None):
        """Publish JSON messages received on ``web_socket`` to Redis channels.

        Args:
            redis: async Redis client used to publish.
            web_socket: Falcon WebSocket to read text frames from.
            extra_data: mapping merged into every decoded message; its keys
                override client-supplied keys of the same name.
            recipients: iterable of channel names, or a callable that takes
                the decoded message and returns such an iterable (it may
                mutate the message before it is published).

        Frames that are not valid JSON objects are skipped. Returns when the
        socket disconnects or the task is cancelled.
        """
        if extra_data is None:
            extra_data = {}
        if recipients is None:
            recipients = ()
        try:
            while True:
                text = await web_socket.receive_text()
                try:
                    data = json.loads(text)
                except ValueError:
                    continue  # malformed JSON from the client: drop the frame
                if not isinstance(data, dict):
                    continue  # only JSON objects can be merged with extra_data
                data.update(extra_data)
                if callable(recipients):
                    current_recipients = recipients(data)
                else:
                    current_recipients = recipients
                # Serialize after the recipients callable ran, since it may
                # strip routing keys from the message.
                payload = pickle.dumps(data)
                for recipient in current_recipients:
                    await redis.publish(recipient, payload)
        except (CancelledError, WebSocketDisconnected):
            pass

    @staticmethod
    async def process_pubsub(pubsub, web_socket):
        """Forward messages from ``pubsub`` to ``web_socket`` as JSON text.

        Polls with a 0.3 s timeout so a closed socket is noticed even when no
        messages arrive. Returns when the socket is no longer ready/open,
        disconnects, or the task is cancelled.
        """
        try:
            while True:
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=0.3
                )
                if not web_socket.ready or web_socket.closed:
                    break
                if message is None:
                    continue
                # SECURITY NOTE(review): this unpickles whatever is published
                # on the channel; anyone able to PUBLISH to this Redis
                # instance can run code here. Consider JSON end-to-end.
                await web_socket.send_text(json.dumps(pickle.loads(message["data"])))
        except (CancelledError, WebSocketDisconnected):
            pass

    @staticmethod
    async def _close_pubsub(pubsub):
        """Release the pubsub's dedicated Redis connection (best effort)."""
        # redis-py >= 5.0.1 names this aclose(); older asyncio clients use reset().
        close = getattr(pubsub, "aclose", None) or pubsub.reset
        try:
            await close()
        except Exception:
            print_exception(*sys.exc_info())

    async def on_websocket(
        self,
        req,
        web_socket,
        pubsub_name=None,
        process_websockets_kwargs=None,
        join_cb=None,
        leave_cb=None,
    ):
        """Serve one WebSocket connection, relaying it through Redis pub/sub.

        Args:
            req: Falcon request (not used here).
            web_socket: Falcon WebSocket for this connection.
            pubsub_name: Redis channel whose messages are forwarded to the
                socket; if falsy, nothing is subscribed.
            process_websockets_kwargs: extra keyword arguments passed to
                process_websocket (``extra_data``, ``recipients``).
            join_cb: optional zero-argument coroutine function awaited after
                the subscription is set up.
            leave_cb: optional zero-argument coroutine function awaited on
                teardown, only if setup (including join_cb) completed.

        The socket is always closed and the pubsub connection released on
        exit.
        """
        await web_socket.accept()
        pubsub = self.redis.pubsub()
        joined = False
        try:
            if pubsub_name:
                await pubsub.subscribe(pubsub_name)
            if callable(join_cb):
                await join_cb()
            joined = True
            results = await asyncio.gather(
                self.process_websocket(
                    self.redis, web_socket, **(process_websockets_kwargs or {})
                ),
                self.process_pubsub(pubsub, web_socket),
                return_exceptions=True,
            )
            # return_exceptions=True hands child failures back as values
            # instead of raising them; report them rather than dropping them.
            for result in results:
                if isinstance(result, BaseException) and not isinstance(
                    result, (CancelledError, WebSocketDisconnected)
                ):
                    print_exception(type(result), result, result.__traceback__)
        except (CancelledError, WebSocketDisconnected):
            pass
        except Exception:
            print_exception(*sys.exc_info())
        finally:
            await web_socket.close()
            if joined and callable(leave_cb):
                await leave_cb()
            await self._close_pubsub(pubsub)


class WebSocketHub(BaseWebSocketHub):
    """Hub pairing master connections with numbered client connections.

    Channels are namespaced by ``hubapp.name`` (the only attribute of
    ``hubapp`` used here):
      * clients listen on ``{name}-client-{id}`` and publish to
        ``{name}-master``; join/leave notifications also go to
        ``{name}-master``;
      * masters listen on ``{name}-master`` and address clients through the
        ``client_ids`` key of the messages they send.
    """

    def __init__(self, hubapp):
        super().__init__()
        self.hubapp = hubapp

    async def join_leave_client_notify(self, redis, action, client_id):
        """Publish a pickled ``{"action", "client_id"}`` event to the master channel."""
        await redis.publish(
            f"{self.hubapp.name}-master",
            pickle.dumps({"action": action, "client_id": client_id}),
        )

    async def on_websocket_client(self, req, web_socket):
        """Serve a client connection under a freshly allocated client id."""
        # Waits until initialize_client_ids() has reset the counter.
        async with self.client_ids_sem:
            client_id = await self.redis.incr("client_id")
        return await self.on_websocket(
            req,
            web_socket,
            f"{self.hubapp.name}-client-{client_id}",
            {
                "extra_data": {"client_id": client_id},
                "recipients": [f"{self.hubapp.name}-master"],
            },
            partial(self.join_leave_client_notify, self.redis, "join", client_id),
            partial(self.join_leave_client_notify, self.redis, "leave", client_id),
        )

    async def on_websocket_master(self, req, web_socket):
        """Serve a master connection; its messages are routed to clients."""
        return await self.on_websocket(
            req,
            web_socket,
            f"{self.hubapp.name}-master",
            {"recipients": self.get_master_recipients},
        )

    def get_master_recipients(self, data):
        """Return client channels for a master message, removing ``client_ids``.

        ``client_ids`` is popped so it is not forwarded to the clients; each
        id is coerced with ``int`` before being formatted into a channel name.
        """
        return [
            f"{self.hubapp.name}-client-{int(client_id)}"
            for client_id in data.pop("client_ids", ())
        ]