X-Git-Url: https://git.mar77i.info/?a=blobdiff_plain;f=hub%2Fwebsocket.py;h=0ecf9875e29a5ba5f66665a1b54ef30bd2a41f9a;hb=3c5ec422ace644d848d2f845b0f3ef8de73462ef;hp=3d511168968e8386898623b9d70c6f84eba077fe;hpb=6128e895bc2a5da5fe645cc9a7ad74ac75af4f6b;p=hublib diff --git a/hub/websocket.py b/hub/websocket.py index 3d51116..0ecf987 100644 --- a/hub/websocket.py +++ b/hub/websocket.py @@ -7,36 +7,18 @@ from functools import partial from traceback import print_exception from falcon import WebSocketDisconnected -from redis.asyncio import StrictRedis -from .utils import get_redis_pass - -class BaseWebSocketHub: - first = True - client_ids_sem = asyncio.Semaphore(0) - - @classmethod - def __class_init(cls, redis): - if not cls.first: - return - cls.first = False - asyncio.create_task(cls.initialize_client_ids(redis)) - - @classmethod - async def initialize_client_ids(cls, redis): - await redis.set("client_id", 0) - cls.client_ids_sem.release() - - def __init__(self): - self.redis = StrictRedis(password=get_redis_pass("/etc/redis/redis.conf")) - self.__class_init(self.redis) +class BaseWebSocketApp: + def __init__(self, hubapp): + self.hubapp = hubapp + self.conn = self.hubapp.app.hubapps["root"].conn def task_done(self): self.task = None @staticmethod - async def process_websocket(redis, web_socket, extra_data={}, recipients=[]): + async def process_websocket(conn, web_socket, extra_data={}, recipients=[]): try: while True: data = json.loads(await web_socket.receive_text()) @@ -46,7 +28,7 @@ class BaseWebSocketHub: else: current_recipients = recipients for recipient in current_recipients: - await redis.publish(recipient, pickle.dumps(data)) + await conn.publish(recipient, pickle.dumps(data)) except (CancelledError, WebSocketDisconnected): pass @@ -72,7 +54,7 @@ class BaseWebSocketHub: leave_cb=None, ): await web_socket.accept() - pubsub = self.redis.pubsub() + pubsub = self.conn.pubsub() if pubsub_name: await pubsub.subscribe(pubsub_name) if callable(join_cb): @@ -80,7 +62,7 @@ class BaseWebSocketHub: try: await asyncio.gather( self.process_websocket( - self.redis, web_socket, **(process_websockets_kwargs or {}) + self.conn, web_socket, **(process_websockets_kwargs or {}) ), self.process_pubsub(pubsub, web_socket), return_exceptions=True, @@ -95,11 +77,7 @@ class BaseWebSocketHub: await leave_cb() -class WebSocketHub(BaseWebSocketHub): - def __init__(self, hubapp): - super().__init__() - self.hubapp = hubapp - +class WebSocketApp(BaseWebSocketApp): async def join_leave_client_notify(self, redis, action, client_id): await redis.publish( f"{self.hubapp.name}-master", @@ -107,11 +85,7 @@ class WebSocketHub(BaseWebSocketHub): ) async def on_websocket_client(self, req, web_socket): - await self.client_ids_sem.acquire() - try: - client_id = await self.redis.incr("client_id") - finally: - self.client_ids_sem.release() + client_id = await self.conn.incr("client_id") return await self.on_websocket( req, web_socket, @@ -120,8 +94,8 @@ class WebSocketHub(BaseWebSocketHub): "extra_data": {"client_id": client_id}, "recipients": [f"{self.hubapp.name}-master"], }, - partial(self.join_leave_client_notify, self.redis, "join", client_id), - partial(self.join_leave_client_notify, self.redis, "leave", client_id), + partial(self.join_leave_client_notify, self.conn, "join", client_id), + partial(self.join_leave_client_notify, self.conn, "leave", client_id), ) async def on_websocket_master(self, req, web_socket):