]> git.mar77i.info Git - hublib/blobdiff - hub/websocket.py
serve other hubapps too, consolidate and a lot more...
[hublib] / hub / websocket.py
diff --git a/hub/websocket.py b/hub/websocket.py
new file mode 100644 (file)
index 0000000..3d51116
--- /dev/null
@@ -0,0 +1,139 @@
+import asyncio
+import json
+import pickle
+import sys
+from asyncio.exceptions import CancelledError
+from functools import partial
+from traceback import print_exception
+
+from falcon import WebSocketDisconnected
+from redis.asyncio import StrictRedis
+
+from .utils import get_redis_pass
+
+
+class BaseWebSocketHub:
+    first = True
+    client_ids_sem = asyncio.Semaphore(0)
+
+    @classmethod
+    def __class_init(cls, redis):
+        if not cls.first:
+            return
+        cls.first = False
+        asyncio.create_task(cls.initialize_client_ids(redis))
+
+    @classmethod
+    async def initialize_client_ids(cls, redis):
+        await redis.set("client_id", 0)
+        cls.client_ids_sem.release()
+
+    def __init__(self):
+        self.redis = StrictRedis(password=get_redis_pass("/etc/redis/redis.conf"))
+        self.__class_init(self.redis)
+
+    def task_done(self):
+        self.task = None
+
+    @staticmethod
+    async def process_websocket(redis, web_socket, extra_data={}, recipients=[]):
+        try:
+            while True:
+                data = json.loads(await web_socket.receive_text())
+                data.update(extra_data)
+                if callable(recipients):
+                    current_recipients = recipients(data)
+                else:
+                    current_recipients = recipients
+                for recipient in current_recipients:
+                    await redis.publish(recipient, pickle.dumps(data))
+        except (CancelledError, WebSocketDisconnected):
+            pass
+
+    @staticmethod
+    async def process_pubsub(pubsub, web_socket):
+        try:
+            while True:
+                data = await pubsub.get_message(True, 0.3)
+                if not web_socket.ready or web_socket.closed:
+                    break
+                if data is not None:
+                    await web_socket.send_text(json.dumps(pickle.loads(data["data"])))
+        except (CancelledError, WebSocketDisconnected):
+            pass
+
+    async def on_websocket(
+        self,
+        req,
+        web_socket,
+        pubsub_name=None,
+        process_websockets_kwargs=None,
+        join_cb=None,
+        leave_cb=None,
+    ):
+        await web_socket.accept()
+        pubsub = self.redis.pubsub()
+        if pubsub_name:
+            await pubsub.subscribe(pubsub_name)
+        if callable(join_cb):
+            await join_cb()
+        try:
+            await asyncio.gather(
+                self.process_websocket(
+                    self.redis, web_socket, **(process_websockets_kwargs or {})
+                ),
+                self.process_pubsub(pubsub, web_socket),
+                return_exceptions=True,
+            )
+        except (CancelledError, WebSocketDisconnected):
+            pass
+        except Exception:
+            print_exception(*sys.exc_info())
+        finally:
+            await web_socket.close()
+            if callable(leave_cb):
+                await leave_cb()
+
+
+class WebSocketHub(BaseWebSocketHub):
+    def __init__(self, hubapp):
+        super().__init__()
+        self.hubapp = hubapp
+
+    async def join_leave_client_notify(self, redis, action, client_id):
+        await redis.publish(
+            f"{self.hubapp.name}-master",
+            pickle.dumps({"action": action, "client_id": client_id}),
+        )
+
+    async def on_websocket_client(self, req, web_socket):
+        await self.client_ids_sem.acquire()
+        try:
+            client_id = await self.redis.incr("client_id")
+        finally:
+            self.client_ids_sem.release()
+        return await self.on_websocket(
+            req,
+            web_socket,
+            f"{self.hubapp.name}-client-{client_id}",
+            {
+                "extra_data": {"client_id": client_id},
+                "recipients": [f"{self.hubapp.name}-master"],
+            },
+            partial(self.join_leave_client_notify, self.redis, "join", client_id),
+            partial(self.join_leave_client_notify, self.redis, "leave", client_id),
+        )
+
+    async def on_websocket_master(self, req, web_socket):
+            return await self.on_websocket(
+                req,
+                web_socket,
+                f"{self.hubapp.name}-master",
+                {"recipients": self.get_master_recipients},
+            )
+
+    def get_master_recipients(self, data):
+        return [
+            f"{self.hubapp.name}-client-{int(client_id)}"
+            for client_id in data.pop("client_ids", ())
+        ]