]> git.mar77i.info Git - hublib/blobdiff - hub/websocket.py
big cleanup and refactoring #1
[hublib] / hub / websocket.py
index 3d511168968e8386898623b9d70c6f84eba077fe..0ecf9875e29a5ba5f66665a1b54ef30bd2a41f9a 100644 (file)
@@ -7,36 +7,18 @@ from functools import partial
 from traceback import print_exception
 
 from falcon import WebSocketDisconnected
-from redis.asyncio import StrictRedis
 
-from .utils import get_redis_pass
 
-
-class BaseWebSocketHub:
-    first = True
-    client_ids_sem = asyncio.Semaphore(0)
-
-    @classmethod
-    def __class_init(cls, redis):
-        if not cls.first:
-            return
-        cls.first = False
-        asyncio.create_task(cls.initialize_client_ids(redis))
-
-    @classmethod
-    async def initialize_client_ids(cls, redis):
-        await redis.set("client_id", 0)
-        cls.client_ids_sem.release()
-
-    def __init__(self):
-        self.redis = StrictRedis(password=get_redis_pass("/etc/redis/redis.conf"))
-        self.__class_init(self.redis)
+class BaseWebSocketApp:
+    def __init__(self, hubapp):
+        self.hubapp = hubapp
+        self.conn = self.hubapp.app.hubapps["root"].conn
 
     def task_done(self):
         self.task = None
 
     @staticmethod
-    async def process_websocket(redis, web_socket, extra_data={}, recipients=[]):
+    async def process_websocket(conn, web_socket, extra_data={}, recipients=[]):
         try:
             while True:
                 data = json.loads(await web_socket.receive_text())
@@ -46,7 +28,7 @@ class BaseWebSocketHub:
                 else:
                     current_recipients = recipients
                 for recipient in current_recipients:
-                    await redis.publish(recipient, pickle.dumps(data))
+                    await conn.publish(recipient, pickle.dumps(data))
         except (CancelledError, WebSocketDisconnected):
             pass
 
@@ -72,7 +54,7 @@ class BaseWebSocketHub:
         leave_cb=None,
     ):
         await web_socket.accept()
-        pubsub = self.redis.pubsub()
+        pubsub = self.conn.pubsub()
         if pubsub_name:
             await pubsub.subscribe(pubsub_name)
         if callable(join_cb):
@@ -80,7 +62,7 @@ class BaseWebSocketHub:
         try:
             await asyncio.gather(
                 self.process_websocket(
-                    self.redis, web_socket, **(process_websockets_kwargs or {})
+                    self.conn, web_socket, **(process_websockets_kwargs or {})
                 ),
                 self.process_pubsub(pubsub, web_socket),
                 return_exceptions=True,
@@ -95,11 +77,7 @@ class BaseWebSocketHub:
                 await leave_cb()
 
 
-class WebSocketHub(BaseWebSocketHub):
-    def __init__(self, hubapp):
-        super().__init__()
-        self.hubapp = hubapp
-
+class WebSocketApp(BaseWebSocketApp):
     async def join_leave_client_notify(self, redis, action, client_id):
         await redis.publish(
             f"{self.hubapp.name}-master",
@@ -107,11 +85,7 @@ class WebSocketHub(BaseWebSocketHub):
         )
 
     async def on_websocket_client(self, req, web_socket):
-        await self.client_ids_sem.acquire()
-        try:
-            client_id = await self.redis.incr("client_id")
-        finally:
-            self.client_ids_sem.release()
+        client_id = await self.conn.incr("client_id")
         return await self.on_websocket(
             req,
             web_socket,
@@ -120,8 +94,8 @@ class WebSocketHub(BaseWebSocketHub):
                 "extra_data": {"client_id": client_id},
                 "recipients": [f"{self.hubapp.name}-master"],
             },
-            partial(self.join_leave_client_notify, self.redis, "join", client_id),
-            partial(self.join_leave_client_notify, self.redis, "leave", client_id),
+            partial(self.join_leave_client_notify, self.conn, "join", client_id),
+            partial(self.join_leave_client_notify, self.conn, "leave", client_id),
         )
 
     async def on_websocket_master(self, req, web_socket):