X-Git-Url: https://git.mar77i.info/?a=blobdiff_plain;f=hub%2Fwebsocket.py;h=0ce426ad30269d223eb10cb17a732f6423e1b06d;hb=f591748910835b1a11c3765723d9d30193a5bd26;hp=3d511168968e8386898623b9d70c6f84eba077fe;hpb=6128e895bc2a5da5fe645cc9a7ad74ac75af4f6b;p=hublib diff --git a/hub/websocket.py b/hub/websocket.py index 3d51116..0ce426a 100644 --- a/hub/websocket.py +++ b/hub/websocket.py @@ -7,46 +7,30 @@ from functools import partial from traceback import print_exception from falcon import WebSocketDisconnected -from redis.asyncio import StrictRedis -from .utils import get_redis_pass +from .static import TreeFileApp -class BaseWebSocketHub: - first = True - client_ids_sem = asyncio.Semaphore(0) - - @classmethod - def __class_init(cls, redis): - if not cls.first: - return - cls.first = False - asyncio.create_task(cls.initialize_client_ids(redis)) - - @classmethod - async def initialize_client_ids(cls, redis): - await redis.set("client_id", 0) - cls.client_ids_sem.release() - - def __init__(self): - self.redis = StrictRedis(password=get_redis_pass("/etc/redis/redis.conf")) - self.__class_init(self.redis) - - def task_done(self): - self.task = None +class WebSocketApp: + def __init__(self, hubapp: TreeFileApp): + self.name = hubapp.name + self.conn = hubapp.root.conn @staticmethod - async def process_websocket(redis, web_socket, extra_data={}, recipients=[]): + async def process_websocket(conn, web_socket, extra_data=None, recipients=None): try: while True: data = json.loads(await web_socket.receive_text()) - data.update(extra_data) + if extra_data: + data.update(extra_data) if callable(recipients): current_recipients = recipients(data) - else: + elif recipients: current_recipients = recipients + else: + raise ValueError("no recipients specified") for recipient in current_recipients: - await redis.publish(recipient, pickle.dumps(data)) + await conn.publish(recipient, pickle.dumps(data)) except (CancelledError, WebSocketDisconnected): pass @@ -72,7 +56,7 @@ class BaseWebSocketHub: leave_cb=None, ): await web_socket.accept() - pubsub = self.redis.pubsub() + pubsub = self.conn.pubsub() if pubsub_name: await pubsub.subscribe(pubsub_name) if callable(join_cb): @@ -80,7 +64,7 @@ class BaseWebSocketHub: try: await asyncio.gather( self.process_websocket( - self.redis, web_socket, **(process_websockets_kwargs or {}) + self.conn, web_socket, **(process_websockets_kwargs or {}) ), self.process_pubsub(pubsub, web_socket), return_exceptions=True, @@ -94,46 +78,36 @@ class BaseWebSocketHub: if callable(leave_cb): await leave_cb() - -class WebSocketHub(BaseWebSocketHub): - def __init__(self, hubapp): - super().__init__() - self.hubapp = hubapp - - async def join_leave_client_notify(self, redis, action, client_id): + async def client_notify(self, redis, action, client_id): await redis.publish( - f"{self.hubapp.name}-master", + f"{self.name}-master", pickle.dumps({"action": action, "client_id": client_id}), ) async def on_websocket_client(self, req, web_socket): - await self.client_ids_sem.acquire() - try: - client_id = await self.redis.incr("client_id") - finally: - self.client_ids_sem.release() + client_id = await self.conn.incr("client_id") return await self.on_websocket( req, web_socket, - f"{self.hubapp.name}-client-{client_id}", + f"{self.name}-client-{client_id}", { "extra_data": {"client_id": client_id}, - "recipients": [f"{self.hubapp.name}-master"], + "recipients": [f"{self.name}-master"], }, - partial(self.join_leave_client_notify, self.redis, "join", client_id), - partial(self.join_leave_client_notify, self.redis, "leave", client_id), + partial(self.client_notify, self.conn, "join", client_id), + partial(self.client_notify, self.conn, "leave", client_id), ) async def on_websocket_master(self, req, web_socket): return await self.on_websocket( req, web_socket, - f"{self.hubapp.name}-master", + f"{self.name}-master", {"recipients": self.get_master_recipients}, ) def get_master_recipients(self, data): return [ - f"{self.hubapp.name}-client-{int(client_id)}" + f"{self.name}-client-{int(client_id)}" for client_id in data.pop("client_ids", ()) ]