]> git.mar77i.info Git - hublib/blob - hub/websocket.py
89f11928d97938ee9bd5be669ecafed79fd823ce
[hublib] / hub / websocket.py
1 import asyncio
2 import json
3 import pickle
4 import sys
5 from asyncio.exceptions import CancelledError
6 from functools import partial
7 from traceback import print_exception
8
9 from falcon import WebSocketDisconnected
10 from redis.asyncio import StrictRedis
11
12 from .utils import get_redis_pass
13
14
class BaseWebSocketHub:
    """Relay messages between a Falcon WebSocket and Redis pub/sub channels.

    Each connection runs two loops concurrently: one forwards JSON text
    received on the WebSocket to Redis channels (``process_websocket``),
    the other forwards messages published on a subscribed Redis channel
    back to the WebSocket (``process_pubsub``).
    """

    # Created with a value of zero, so ``acquire()`` blocks until
    # ``initialize_client_ids`` has written the counter and released it once.
    client_ids_sem = asyncio.Semaphore(0)

    # Strong reference to the one-off initialisation task.  The event loop
    # only keeps weak references to tasks, so a task whose handle is
    # discarded may be garbage-collected before it finishes.
    _client_ids_task = None

    @classmethod
    def _class_init(cls, redis):
        """Schedule the ``client_id`` counter initialisation exactly once.

        Repeated calls are no-ops once the task has been created.  Needs a
        running event loop because it calls ``asyncio.create_task``; if task
        creation fails, the next call tries again.
        """
        if BaseWebSocketHub._client_ids_task is not None:
            return
        BaseWebSocketHub._client_ids_task = asyncio.create_task(
            cls.initialize_client_ids(redis)
        )

    @classmethod
    async def initialize_client_ids(cls, redis):
        """Set the Redis key ``client_id`` to 0, then release the semaphore."""
        await redis.set("client_id", 0)
        cls.client_ids_sem.release()

    def __init__(self):
        """Create the Redis client and trigger the one-time class set-up."""
        self.redis = StrictRedis(password=get_redis_pass("/etc/redis/redis.conf"))
        BaseWebSocketHub._class_init(self.redis)

    def task_done(self):
        """Reset ``self.task`` to ``None``."""
        # NOTE(review): nothing visible in this file assigns or reads
        # ``self.task``; presumably used by code elsewhere -- confirm.
        self.task = None

    @staticmethod
    async def process_websocket(redis, web_socket, extra_data=None, recipients=()):
        """Forward JSON messages from ``web_socket`` to Redis channels.

        Args:
            redis: client whose ``publish(channel, payload)`` is awaited.
            web_socket: object whose ``receive_text()`` is awaited.
            extra_data: optional mapping merged into every decoded message,
                overriding keys supplied by the peer.
            recipients: iterable of channel names, or a callable that takes
                the decoded message and returns such an iterable.

        Runs until cancelled or the peer disconnects; both are swallowed.
        Any other exception (for example invalid JSON) propagates.
        """
        try:
            while True:
                data = json.loads(await web_socket.receive_text())
                data.update(extra_data or {})
                if callable(recipients):
                    current_recipients = recipients(data)
                else:
                    current_recipients = recipients
                for recipient in current_recipients:
                    await redis.publish(recipient, pickle.dumps(data))
        except (CancelledError, WebSocketDisconnected):
            pass

    @staticmethod
    async def process_pubsub(pubsub, web_socket):
        """Forward messages from ``pubsub`` to ``web_socket`` as JSON text.

        Polls with a 0.3 second timeout so that the state of the socket is
        re-checked regularly; returns once the socket is no longer ready or
        has been closed.  Cancellation and disconnects are swallowed.
        """
        try:
            while True:
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=0.3
                )
                if not web_socket.ready or web_socket.closed:
                    break
                if message is not None:
                    # SECURITY: unpickling executes arbitrary code if the
                    # payload is hostile.  This is only safe while every
                    # publisher on the Redis instance is trusted.
                    payload = pickle.loads(message["data"])
                    await web_socket.send_text(json.dumps(payload))
        except (CancelledError, WebSocketDisconnected):
            pass

    @staticmethod
    async def _close_pubsub(pubsub):
        """Best-effort release of the connection held by ``pubsub``."""
        # Newer redis-py releases name the coroutine ``aclose`` and keep
        # ``close`` only as a deprecated alias; older ones only have ``close``.
        closer = getattr(pubsub, "aclose", None) or pubsub.close
        try:
            await closer()
        except Exception:
            print_exception(*sys.exc_info())

    async def on_websocket(
        self,
        req,
        web_socket,
        pubsub_name=None,
        process_websockets_kwargs=None,
        join_cb=None,
        leave_cb=None,
    ):
        """Serve one WebSocket connection until either side goes away.

        Args:
            req: the request object (not used here).
            web_socket: the WebSocket to accept and serve.
            pubsub_name: Redis channel to subscribe to, if any.
            process_websockets_kwargs: extra keyword arguments passed on to
                ``process_websocket``.
            join_cb: optional coroutine function awaited before serving.
            leave_cb: optional coroutine function awaited after serving,
                even when closing the socket fails.
        """
        await web_socket.accept()
        pubsub = self.redis.pubsub()
        if pubsub_name:
            await pubsub.subscribe(pubsub_name)
        if callable(join_cb):
            await join_cb()
        try:
            results = await asyncio.gather(
                self.process_websocket(
                    self.redis, web_socket, **(process_websockets_kwargs or {})
                ),
                self.process_pubsub(pubsub, web_socket),
                return_exceptions=True,
            )
            # ``return_exceptions=True`` hands failures back as values
            # instead of raising them, so they have to be reported here or
            # they are lost.
            for result in results:
                if isinstance(result, Exception):
                    print_exception(type(result), result, result.__traceback__)
        except (CancelledError, WebSocketDisconnected):
            pass
        except Exception:
            print_exception(*sys.exc_info())
        finally:
            try:
                await web_socket.close()
            finally:
                try:
                    await self._close_pubsub(pubsub)
                finally:
                    if callable(leave_cb):
                        await leave_cb()
96
97
class WebSocketHub(BaseWebSocketHub):
    """Hub with one "master" endpoint and many numbered "client" endpoints.

    Channel names are derived from ``hubapp.name``: the master listens on
    ``<name>-master`` and each client on ``<name>-client-<client_id>``.
    """

    def __init__(self, hubapp):
        """Remember ``hubapp`` after running the base-class set-up."""
        super().__init__()
        self.hubapp = hubapp

    async def join_leave_client_notify(self, redis, action, client_id):
        """Publish a pickled ``action`` event for ``client_id`` to the master."""
        channel = f"{self.hubapp.name}-master"
        payload = pickle.dumps({"action": action, "client_id": client_id})
        await redis.publish(channel, payload)

    async def on_websocket_client(self, req, web_socket):
        """Serve a client socket under a freshly incremented ``client_id``."""
        # The semaphore is taken and given back around the increment.
        async with self.client_ids_sem:
            client_id = await self.redis.incr("client_id")

        forwarding = {
            "extra_data": {"client_id": client_id},
            "recipients": [f"{self.hubapp.name}-master"],
        }
        announce_join = partial(
            self.join_leave_client_notify, self.redis, "join", client_id
        )
        announce_leave = partial(
            self.join_leave_client_notify, self.redis, "leave", client_id
        )
        return await self.on_websocket(
            req,
            web_socket,
            pubsub_name=f"{self.hubapp.name}-client-{client_id}",
            process_websockets_kwargs=forwarding,
            join_cb=announce_join,
            leave_cb=announce_leave,
        )

    async def on_websocket_master(self, req, web_socket):
        """Serve the master socket; recipients are chosen per message."""
        forwarding = {"recipients": self.get_master_recipients}
        return await self.on_websocket(
            req,
            web_socket,
            pubsub_name=f"{self.hubapp.name}-master",
            process_websockets_kwargs=forwarding,
        )

    def get_master_recipients(self, data):
        """Remove ``"client_ids"`` from ``data`` and return their channel names.

        A missing key yields an empty list.  Every id is passed through
        ``int()`` before it is placed into the channel name.
        """
        channels = []
        for raw_id in data.pop("client_ids", ()):
            channels.append(f"{self.hubapp.name}-client-{int(raw_id)}")
        return channels