X-Git-Url: https://git.mar77i.info/?a=blobdiff_plain;f=hub%2Fwebsocket.py;fp=hub%2Fwebsocket.py;h=3d511168968e8386898623b9d70c6f84eba077fe;hb=6128e895bc2a5da5fe645cc9a7ad74ac75af4f6b;hp=0000000000000000000000000000000000000000;hpb=66e1cc7886b1ce7092281a43b9ee7969366e6835;p=hublib diff --git a/hub/websocket.py b/hub/websocket.py new file mode 100644 index 0000000..3d51116 --- /dev/null +++ b/hub/websocket.py @@ -0,0 +1,139 @@ +import asyncio +import json +import pickle +import sys +from asyncio.exceptions import CancelledError +from functools import partial +from traceback import print_exception + +from falcon import WebSocketDisconnected +from redis.asyncio import StrictRedis + +from .utils import get_redis_pass + + +class BaseWebSocketHub: + first = True + client_ids_sem = asyncio.Semaphore(0) + + @classmethod + def __class_init(cls, redis): + if not cls.first: + return + cls.first = False + asyncio.create_task(cls.initialize_client_ids(redis)) + + @classmethod + async def initialize_client_ids(cls, redis): + await redis.set("client_id", 0) + cls.client_ids_sem.release() + + def __init__(self): + self.redis = StrictRedis(password=get_redis_pass("/etc/redis/redis.conf")) + self.__class_init(self.redis) + + def task_done(self): + self.task = None + + @staticmethod + async def process_websocket(redis, web_socket, extra_data={}, recipients=[]): + try: + while True: + data = json.loads(await web_socket.receive_text()) + data.update(extra_data) + if callable(recipients): + current_recipients = recipients(data) + else: + current_recipients = recipients + for recipient in current_recipients: + await redis.publish(recipient, pickle.dumps(data)) + except (CancelledError, WebSocketDisconnected): + pass + + @staticmethod + async def process_pubsub(pubsub, web_socket): + try: + while True: + data = await pubsub.get_message(True, 0.3) + if not web_socket.ready or web_socket.closed: + break + if data is not None: + await web_socket.send_text(json.dumps(pickle.loads(data["data"]))) + except (CancelledError, WebSocketDisconnected): + pass + + async def on_websocket( + self, + req, + web_socket, + pubsub_name=None, + process_websockets_kwargs=None, + join_cb=None, + leave_cb=None, + ): + await web_socket.accept() + pubsub = self.redis.pubsub() + if pubsub_name: + await pubsub.subscribe(pubsub_name) + if callable(join_cb): + await join_cb() + try: + await asyncio.gather( + self.process_websocket( + self.redis, web_socket, **(process_websockets_kwargs or {}) + ), + self.process_pubsub(pubsub, web_socket), + return_exceptions=True, + ) + except (CancelledError, WebSocketDisconnected): + pass + except Exception: + print_exception(*sys.exc_info()) + finally: + await web_socket.close() + if callable(leave_cb): + await leave_cb() + + +class WebSocketHub(BaseWebSocketHub): + def __init__(self, hubapp): + super().__init__() + self.hubapp = hubapp + + async def join_leave_client_notify(self, redis, action, client_id): + await redis.publish( + f"{self.hubapp.name}-master", + pickle.dumps({"action": action, "client_id": client_id}), + ) + + async def on_websocket_client(self, req, web_socket): + await self.client_ids_sem.acquire() + try: + client_id = await self.redis.incr("client_id") + finally: + self.client_ids_sem.release() + return await self.on_websocket( + req, + web_socket, + f"{self.hubapp.name}-client-{client_id}", + { + "extra_data": {"client_id": client_id}, + "recipients": [f"{self.hubapp.name}-master"], + }, + partial(self.join_leave_client_notify, self.redis, "join", client_id), + partial(self.join_leave_client_notify, self.redis, "leave", client_id), + ) + + async def on_websocket_master(self, req, web_socket): + return await self.on_websocket( + req, + web_socket, + f"{self.hubapp.name}-master", + {"recipients": self.get_master_recipients}, + ) + + def get_master_recipients(self, data): + return [ + f"{self.hubapp.name}-client-{int(client_id)}" + for client_id in data.pop("client_ids", ()) + ]