+ @asynccontextmanager
+ async def listen_notify_handler(self):
+ await sync_to_async(connection.connect)()
+ db_conn = connection.connection
+ db_conn.add_notify_handler(self.process_triggers)
+ self.loop.add_reader(db_conn.fileno(), self.db_idle, db_conn)
+ for name in TriggerChannel.registry:
+ await sync_to_async(db_conn.execute)(f"LISTEN {name}")
+ try:
+ yield
+ finally:
+ for name in TriggerChannel.registry:
+ try:
+ await sync_to_async(db_conn.execute)(f"UNLISTEN {name}")
+ except Exception:
+ pass
+ self.loop.remove_reader(db_conn.fileno())
+ db_conn.remove_notify_handler(self.process_triggers)
+ await sync_to_async(connection.close)()
+
+ async def process_ws(self):
+ try:
+ while True:
+ event = await self.receive()
+ if event["type"] == "websocket.connect":
+ await self.send({"type": "websocket.accept"})
+ self.active_event.set()
+ elif event["type"] == "websocket.disconnect":
+ break
+ elif event["type"] == "websocket.receive":
+ await self.send(
+ {
+ "type": "websocket.send",
+ "text": json.dumps(
+ await self.get_api_response(json.loads(event["text"]))
+ ),
+ }
+ )
+ finally:
+ self.active_event.clear()
+ # wake up send_loop() to end it
+ await self.notification_queue.put(ShutDownSentinel)
+
+ async def get_api_request(self, request_data):
+ body = request_data.get("body", "")
+ if not isinstance(body, str):
+ body = json.dumps(body)
+ ws_scope = {
+ **self.scope,
+ "type": "http",
+ "method": request_data["method"],
+ "path": request_data["path"],
+ "serial": request_data["serial"],
+ }
+ if "query_string" in request_data:
+ ws_scope["query_string"] = request_data["query_string"]
+ else:
+ ws_scope.pop("query_string", None)
+ request = ASGIRequest(ws_scope, BytesIO(body.encode()))
+ request.user = self.scope["user"]
+ request.resolver_match = self.resolver.resolve(request.path_info)
+ if not request.resolver_match.url_name.startswith("api-"):
+ raise Resolver404
+ return request
+
+ async def get_api_response(self, request_data):
+ try:
+ request = await self.get_api_request(request_data)
+ except Resolver404:
+ return {
+ "status_code": 404,
+ "content": "resource not found",
+ "serial": request_data["serial"],
+ }
+ else:
+ logger.info(f"ws api call: {request.method} {request.path}")
+ resolver_match = request.resolver_match
+ response = await sync_to_async(resolver_match.func)(
+ request,
+ *resolver_match.args,
+ **resolver_match.kwargs,
+ )
+ return {
+ "status_code": response.status_code,
+ "content": response.content.decode(),
+ "serial": request.scope["serial"],
+ }
+
+ async def send_loop(self):
+ await self.active_event.wait()
+ while self.active_event.is_set():
+ item = await self.notification_queue.get()
+ if item is ShutDownSentinel:
+ break
+ try:
+ await self.send(item)
+ finally:
+ self.notification_queue.task_done()