from traceback import print_exception
from falcon import WebSocketDisconnected
-from redis.asyncio import StrictRedis
-from .utils import get_redis_pass
-
-class BaseWebSocketHub:
- first = True
- client_ids_sem = asyncio.Semaphore(0)
-
- @classmethod
- def __class_init(cls, redis):
- if not cls.first:
- return
- cls.first = False
- asyncio.create_task(cls.initialize_client_ids(redis))
-
- @classmethod
- async def initialize_client_ids(cls, redis):
- await redis.set("client_id", 0)
- cls.client_ids_sem.release()
-
- def __init__(self):
- self.redis = StrictRedis(password=get_redis_pass("/etc/redis/redis.conf"))
- self.__class_init(self.redis)
+class BaseWebSocketApp:
+ def __init__(self, hubapp):
+ self.hubapp = hubapp
+ self.conn = self.hubapp.app.hubapps["root"].conn
def task_done(self):
self.task = None
@staticmethod
- async def process_websocket(redis, web_socket, extra_data={}, recipients=[]):
+ async def process_websocket(conn, web_socket, extra_data={}, recipients=[]):
try:
while True:
data = json.loads(await web_socket.receive_text())
else:
current_recipients = recipients
for recipient in current_recipients:
- await redis.publish(recipient, pickle.dumps(data))
+ await conn.publish(recipient, pickle.dumps(data))
except (CancelledError, WebSocketDisconnected):
pass
leave_cb=None,
):
await web_socket.accept()
- pubsub = self.redis.pubsub()
+ pubsub = self.conn.pubsub()
if pubsub_name:
await pubsub.subscribe(pubsub_name)
if callable(join_cb):
try:
await asyncio.gather(
self.process_websocket(
- self.redis, web_socket, **(process_websockets_kwargs or {})
+ self.conn, web_socket, **(process_websockets_kwargs or {})
),
self.process_pubsub(pubsub, web_socket),
return_exceptions=True,
await leave_cb()
-class WebSocketHub(BaseWebSocketHub):
- def __init__(self, hubapp):
- super().__init__()
- self.hubapp = hubapp
-
+class WebSocketApp(BaseWebSocketApp):
async def join_leave_client_notify(self, redis, action, client_id):
await redis.publish(
f"{self.hubapp.name}-master",
)
async def on_websocket_client(self, req, web_socket):
- await self.client_ids_sem.acquire()
- try:
- client_id = await self.redis.incr("client_id")
- finally:
- self.client_ids_sem.release()
+ client_id = await self.conn.incr("client_id")
return await self.on_websocket(
req,
web_socket,
"extra_data": {"client_id": client_id},
"recipients": [f"{self.hubapp.name}-master"],
},
- partial(self.join_leave_client_notify, self.redis, "join", client_id),
- partial(self.join_leave_client_notify, self.redis, "leave", client_id),
+ partial(self.join_leave_client_notify, self.conn, "join", client_id),
+ partial(self.join_leave_client_notify, self.conn, "leave", client_id),
)
async def on_websocket_master(self, req, web_socket):