import asyncio
import json
import pickle
import sys
from asyncio.exceptions import CancelledError
from functools import partial
from traceback import print_exception

from falcon import WebSocketDisconnected


class BaseWebSocketApp:
    """Bridge falcon WebSocket connections to Redis pub/sub channels.

    Each connection served by :meth:`on_websocket` runs two loops
    concurrently: one forwards JSON messages received from the WebSocket to
    Redis channels, the other relays messages from one subscribed Redis
    channel back to the WebSocket.

    NOTE(review): Redis payloads are pickled, and ``pickle.loads`` executes
    arbitrary code for crafted input. Anyone able to PUBLISH to these
    channels can run code in this process. Confirm Redis is reachable only
    by trusted parties (or move the wire format to JSON).
    """

    def __init__(self, hubapp):
        # Only ``hubapp.name`` and ``hubapp.app.hubapps["root"].conn`` are
        # used in this module.
        self.hubapp = hubapp
        # Shared async Redis connection owned by the "root" hub app.
        self.conn = self.hubapp.app.hubapps["root"].conn

    def task_done(self):
        """Reset the ``task`` attribute to ``None``.

        Presumably used as a completion callback by code outside this
        module -- TODO confirm.
        """
        self.task = None

    @staticmethod
    async def process_websocket(conn, web_socket, extra_data=None, recipients=None):
        """Forward JSON messages from ``web_socket`` to Redis channels.

        Runs until the client disconnects or the coroutine is cancelled.

        :param conn: async Redis connection used to ``publish``.
        :param web_socket: falcon ``WebSocket`` to read text frames from.
        :param extra_data: mapping merged into every decoded message; its keys
            override the client's (e.g. a server-assigned ``client_id``).
        :param recipients: an iterable of channel names, or a callable that
            receives the decoded message and returns such an iterable. The
            callable may mutate the message; mutations are included in what
            gets published.
        :raises json.JSONDecodeError: if the client sends non-JSON text.
        """
        # None sentinels instead of shared mutable defaults.
        if extra_data is None:
            extra_data = {}
        if recipients is None:
            recipients = ()
        try:
            while True:
                data = json.loads(await web_socket.receive_text())
                data.update(extra_data)
                if callable(recipients):
                    current_recipients = recipients(data)
                else:
                    current_recipients = recipients
                # Serialize once per message, after ``recipients`` had the
                # chance to strip routing keys from ``data``.
                payload = None
                for recipient in current_recipients:
                    if payload is None:
                        payload = pickle.dumps(data)
                    await conn.publish(recipient, payload)
        except (CancelledError, WebSocketDisconnected):
            pass

    @staticmethod
    async def process_pubsub(pubsub, web_socket):
        """Relay messages from a Redis pub/sub subscription to ``web_socket``.

        Each message payload is unpickled and sent to the client as JSON
        text. The loop polls with a 0.3 s timeout so that it notices promptly
        when the WebSocket is no longer ready/open, and then stops.
        """
        try:
            while True:
                # Positional args: ignore_subscribe_messages=True, timeout=0.3
                # (assumes the redis-py asyncio ``get_message`` signature).
                data = await pubsub.get_message(True, 0.3)
                if not web_socket.ready or web_socket.closed:
                    break
                if data is not None:
                    await web_socket.send_text(json.dumps(pickle.loads(data["data"])))
        except (CancelledError, WebSocketDisconnected):
            pass

    @staticmethod
    async def _close_pubsub(pubsub):
        """Release the Redis connection held by ``pubsub``; never raises.

        Newer redis-py asyncio clients provide ``aclose()`` (with ``close()``
        as a deprecated alias); older clients only provide ``close()``.
        """
        close = getattr(pubsub, "aclose", None) or getattr(pubsub, "close", None)
        if close is None:
            return
        try:
            await close()
        except Exception:
            print_exception(*sys.exc_info())

    async def _relay(self, web_socket, pubsub, process_websockets_kwargs):
        """Run both forwarding loops until the first one finishes.

        Whichever loop ends first (client disconnect, closed socket, or an
        error) ends the session, and the other loop is cancelled. Exceptions
        raised by a loop are printed instead of being silently discarded, as
        ``gather(return_exceptions=True)`` used to do.
        """
        tasks = [
            asyncio.create_task(
                self.process_websocket(
                    self.conn, web_socket, **(process_websockets_kwargs or {})
                )
            ),
            asyncio.create_task(self.process_pubsub(pubsub, web_socket)),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Also runs if this coroutine itself is cancelled, so neither
            # loop outlives the connection.
            for task in tasks:
                task.cancel()
            results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            # CancelledError is a BaseException (3.8+), so it is skipped here.
            if isinstance(result, Exception):
                print_exception(type(result), result, result.__traceback__)

    async def on_websocket(
        self,
        req,
        web_socket,
        pubsub_name=None,
        process_websockets_kwargs=None,
        join_cb=None,
        leave_cb=None,
    ):
        """Serve one WebSocket connection by bridging it to Redis pub/sub.

        :param req: falcon request (not used here; part of the responder
            signature).
        :param web_socket: falcon ``WebSocket`` to serve.
        :param pubsub_name: Redis channel whose messages are relayed to the
            client; if falsy, nothing is subscribed.
        :param process_websockets_kwargs: keyword arguments (``extra_data``,
            ``recipients``) forwarded to :meth:`process_websocket`.
        :param join_cb: optional zero-argument coroutine function, awaited
            once the subscription is in place.
        :param leave_cb: optional zero-argument coroutine function, awaited
            after the WebSocket has been closed (only if ``join_cb`` stage was
            reached without error).
        """
        await web_socket.accept()
        pubsub = self.conn.pubsub()
        try:
            if pubsub_name:
                await pubsub.subscribe(pubsub_name)
            if callable(join_cb):
                await join_cb()
            try:
                await self._relay(web_socket, pubsub, process_websockets_kwargs)
            except (CancelledError, WebSocketDisconnected):
                pass
            except Exception:
                print_exception(*sys.exc_info())
            finally:
                await web_socket.close()
                if callable(leave_cb):
                    await leave_cb()
        finally:
            # Always give the pub/sub connection back, even if subscribe or
            # join_cb failed.
            await self._close_pubsub(pubsub)


class WebSocketApp(BaseWebSocketApp):
    """Hub with a "master" connection and any number of client connections.

    Redis channels (prefixed with the hub app's name):

    * ``<name>-master`` -- receives everything clients send (tagged with
      their ``client_id``) plus join/leave notifications; relayed to the
      master socket.
    * ``<name>-client-<id>`` -- messages for one client. The master
      addresses clients by listing their ids under the ``client_ids`` key.
    """

    async def join_leave_client_notify(self, redis, action, client_id):
        """Publish ``{"action": action, "client_id": client_id}`` to the master channel."""
        await redis.publish(
            f"{self.hubapp.name}-master",
            pickle.dumps({"action": action, "client_id": client_id}),
        )

    async def on_websocket_client(self, req, web_socket):
        """Serve a client connection under a freshly allocated ``client_id``.

        The id comes from INCR on the Redis key ``client_id``. The client
        listens on its own channel; everything it sends is tagged with its id
        and published to the master channel. Join/leave notifications are
        published to the master channel around the session.
        """
        client_id = await self.conn.incr("client_id")
        return await self.on_websocket(
            req,
            web_socket,
            f"{self.hubapp.name}-client-{client_id}",
            {
                "extra_data": {"client_id": client_id},
                "recipients": [f"{self.hubapp.name}-master"],
            },
            partial(self.join_leave_client_notify, self.conn, "join", client_id),
            partial(self.join_leave_client_notify, self.conn, "leave", client_id),
        )

    async def on_websocket_master(self, req, web_socket):
        """Serve the master connection.

        Messages sent by the master are routed to the clients named in their
        ``client_ids`` list (see :meth:`get_master_recipients`).
        """
        return await self.on_websocket(
            req,
            web_socket,
            f"{self.hubapp.name}-master",
            {"recipients": self.get_master_recipients},
        )

    def get_master_recipients(self, data):
        """Return the client channels addressed by a master message.

        Pops ``client_ids`` from ``data``, so the routing list is not
        forwarded to the clients themselves.

        :raises ValueError: if an id cannot be converted with ``int``.
        """
        return [
            f"{self.hubapp.name}-client-{int(client_id)}"
            for client_id in data.pop("client_ids", ())
        ]