]> git.mar77i.info Git - hublib/blobdiff - hub/websocket.py
add a bunch of typehints
[hublib] / hub / websocket.py
index 89f11928d97938ee9bd5be669ecafed79fd823ce..0ce426ad30269d223eb10cb17a732f6423e1b06d 100644 (file)
@@ -7,46 +7,30 @@ from functools import partial
 from traceback import print_exception
 
 from falcon import WebSocketDisconnected
-from redis.asyncio import StrictRedis
 
-from .utils import get_redis_pass
+from .static import TreeFileApp
 
 
-class BaseWebSocketHub:
-    client_ids_sem = asyncio.Semaphore(0)
-
-    @classmethod
-    def _class_init(cls, redis):
-        if not hasattr(cls, "_class_init"):
-            return
-        delattr(cls, "_class_init")
-        asyncio.create_task(cls.initialize_client_ids(redis))
-
-    @classmethod
-    async def initialize_client_ids(cls, redis):
-        await redis.set("client_id", 0)
-        cls.client_ids_sem.release()
-
-    def __init__(self):
-        self.redis = StrictRedis(password=get_redis_pass("/etc/redis/redis.conf"))
-        if hasattr(BaseWebSocketHub, "_class_init"):
-            BaseWebSocketHub._class_init(self.redis)
-
-    def task_done(self):
-        self.task = None
+class WebSocketApp:
+    def __init__(self, hubapp: TreeFileApp):
+        self.name = hubapp.name
+        self.conn = hubapp.root.conn
 
     @staticmethod
-    async def process_websocket(redis, web_socket, extra_data={}, recipients=[]):
+    async def process_websocket(conn, web_socket, extra_data=None, recipients=None):
         try:
             while True:
                 data = json.loads(await web_socket.receive_text())
-                data.update(extra_data)
+                if extra_data:
+                    data.update(extra_data)
                 if callable(recipients):
                     current_recipients = recipients(data)
-                else:
+                elif recipients:
                     current_recipients = recipients
+                else:
+                    raise ValueError("no recipients specified")
                 for recipient in current_recipients:
-                    await redis.publish(recipient, pickle.dumps(data))
+                    await conn.publish(recipient, pickle.dumps(data))
         except (CancelledError, WebSocketDisconnected):
             pass
 
@@ -72,7 +56,7 @@ class BaseWebSocketHub:
         leave_cb=None,
     ):
         await web_socket.accept()
-        pubsub = self.redis.pubsub()
+        pubsub = self.conn.pubsub()
         if pubsub_name:
             await pubsub.subscribe(pubsub_name)
         if callable(join_cb):
@@ -80,7 +64,7 @@ class BaseWebSocketHub:
         try:
             await asyncio.gather(
                 self.process_websocket(
-                    self.redis, web_socket, **(process_websockets_kwargs or {})
+                    self.conn, web_socket, **(process_websockets_kwargs or {})
                 ),
                 self.process_pubsub(pubsub, web_socket),
                 return_exceptions=True,
@@ -94,46 +78,36 @@ class BaseWebSocketHub:
             if callable(leave_cb):
                 await leave_cb()
 
-
-class WebSocketHub(BaseWebSocketHub):
-    def __init__(self, hubapp):
-        super().__init__()
-        self.hubapp = hubapp
-
-    async def join_leave_client_notify(self, redis, action, client_id):
+    async def client_notify(self, redis, action, client_id):
         await redis.publish(
-            f"{self.hubapp.name}-master",
+            f"{self.name}-master",
             pickle.dumps({"action": action, "client_id": client_id}),
         )
 
     async def on_websocket_client(self, req, web_socket):
-        await self.client_ids_sem.acquire()
-        try:
-            client_id = await self.redis.incr("client_id")
-        finally:
-            self.client_ids_sem.release()
+        client_id = await self.conn.incr("client_id")
         return await self.on_websocket(
             req,
             web_socket,
-            f"{self.hubapp.name}-client-{client_id}",
+            f"{self.name}-client-{client_id}",
             {
                 "extra_data": {"client_id": client_id},
-                "recipients": [f"{self.hubapp.name}-master"],
+                "recipients": [f"{self.name}-master"],
             },
-            partial(self.join_leave_client_notify, self.redis, "join", client_id),
-            partial(self.join_leave_client_notify, self.redis, "leave", client_id),
+            partial(self.client_notify, self.conn, "join", client_id),
+            partial(self.client_notify, self.conn, "leave", client_id),
         )
 
     async def on_websocket_master(self, req, web_socket):
             return await self.on_websocket(
                 req,
                 web_socket,
-                f"{self.hubapp.name}-master",
+                f"{self.name}-master",
                 {"recipients": self.get_master_recipients},
             )
 
     def get_master_recipients(self, data):
         return [
-            f"{self.hubapp.name}-client-{int(client_id)}"
+            f"{self.name}-client-{int(client_id)}"
             for client_id in data.pop("client_ids", ())
         ]