Skip to content

Commit 8fe17d9

Browse files
authored
Improve Home Assistant Core WebSocket proxy implementation (#5790)
* Improve Home Assistant Core WebSocket proxy implementation This change removes unnecessary task creation for every WebSocket message and instead creates just two tasks, one for each direction. This improves performance by about factor of 3 when measuring 1000 WebSocket requests to Core (from ~530ms to ~160ms). While at it, also handle all WebSocket message related to closing the WebSocket and report all other errors as warnings instead of just info. * Improve logging and error handling * Add WS client error test case * Use asyncio.gather directly * Use asyncio.wait to handle exceptions gracefully * Drop cancellation handling and correctly wait for the other proxy task
1 parent 0a684bd commit 8fe17d9

File tree

3 files changed

+93
-58
lines changed

3 files changed

+93
-58
lines changed

supervisor/api/proxy.py

+55-54
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
from aiohttp.client_exceptions import ClientConnectorError
1111
from aiohttp.client_ws import ClientWebSocketResponse
1212
from aiohttp.hdrs import AUTHORIZATION, CONTENT_TYPE
13-
from aiohttp.http import WSMessage
1413
from aiohttp.http_websocket import WSMsgType
1514
from aiohttp.web_exceptions import HTTPBadGateway, HTTPUnauthorized
1615

16+
from supervisor.utils.logging import AddonLoggerAdapter
17+
1718
from ..coresys import CoreSysAttributes
1819
from ..exceptions import APIError, HomeAssistantAPIError, HomeAssistantAuthError
1920
from ..utils.json import json_dumps
@@ -179,23 +180,39 @@ async def _websocket_client(self) -> ClientWebSocketResponse:
179180

180181
async def _proxy_message(
181182
self,
182-
read_task: asyncio.Task,
183+
source: web.WebSocketResponse | ClientWebSocketResponse,
183184
target: web.WebSocketResponse | ClientWebSocketResponse,
185+
logger: AddonLoggerAdapter,
184186
) -> None:
185187
"""Proxy a message from client to server or vice versa."""
186-
msg: WSMessage = read_task.result()
187-
match msg.type:
188-
case WSMsgType.TEXT:
189-
await target.send_str(msg.data)
190-
case WSMsgType.BINARY:
191-
await target.send_bytes(msg.data)
192-
case WSMsgType.CLOSE:
193-
_LOGGER.debug("Received close message from WebSocket.")
194-
await target.close()
195-
case _:
196-
raise TypeError(
197-
f"Cannot proxy websocket message of unsupported type: {msg.type}"
198-
)
188+
while not source.closed and not target.closed:
189+
msg = await source.receive()
190+
match msg.type:
191+
case WSMsgType.TEXT:
192+
await target.send_str(msg.data)
193+
case WSMsgType.BINARY:
194+
await target.send_bytes(msg.data)
195+
case WSMsgType.CLOSE | WSMsgType.CLOSED:
196+
logger.debug(
197+
"Received WebSocket message type %r from %s.",
198+
msg.type,
199+
"add-on" if type(source) is web.WebSocketResponse else "Core",
200+
)
201+
await target.close()
202+
case WSMsgType.CLOSING:
203+
pass
204+
case WSMsgType.ERROR:
205+
logger.warning(
206+
"Error WebSocket message received while proxying: %r", msg.data
207+
)
208+
await target.close(code=source.close_code)
209+
case _:
210+
logger.warning(
211+
"Cannot proxy WebSocket message of unsupported type: %r",
212+
msg.type,
213+
)
214+
await source.close()
215+
await target.close()
199216

200217
async def websocket(self, request: web.Request):
201218
"""Initialize a WebSocket API connection."""
@@ -255,48 +272,32 @@ async def websocket(self, request: web.Request):
255272
except APIError:
256273
return server
257274

258-
_LOGGER.info("Home Assistant WebSocket API request running")
259-
try:
260-
client_read: asyncio.Task | None = None
261-
server_read: asyncio.Task | None = None
262-
while not server.closed and not client.closed:
263-
if not client_read:
264-
client_read = self.sys_create_task(client.receive())
265-
if not server_read:
266-
server_read = self.sys_create_task(server.receive())
267-
268-
# wait until data need to be processed
269-
await asyncio.wait(
270-
[client_read, server_read], return_when=asyncio.FIRST_COMPLETED
271-
)
275+
logger = AddonLoggerAdapter(_LOGGER, {"addon_name": addon_name})
276+
logger.info("Home Assistant WebSocket API proxy running")
272277

273-
# server
274-
if server_read.done() and not client.closed:
275-
await self._proxy_message(server_read, client)
276-
server_read = None
278+
client_task = self.sys_create_task(self._proxy_message(client, server, logger))
279+
server_task = self.sys_create_task(self._proxy_message(server, client, logger))
277280

278-
# client
279-
if client_read.done() and not server.closed:
280-
await self._proxy_message(client_read, server)
281-
client_read = None
281+
# Typically, this will return with an empty pending set. However, if one of
282+
# the directions has an exception, make sure to close both connections and
283+
# wait for the other proxy task to exit gracefully. Using this over try-except
284+
# handling makes it easier to wait for the other direction to complete.
285+
_, pending = await asyncio.wait(
286+
(client_task, server_task), return_when=asyncio.FIRST_EXCEPTION
287+
)
282288

283-
except asyncio.CancelledError:
284-
pass
289+
if not client.closed:
290+
await client.close()
291+
if not server.closed:
292+
await server.close()
285293

286-
except (RuntimeError, ConnectionError, TypeError) as err:
287-
_LOGGER.info("Home Assistant WebSocket API error: %s", err)
288-
289-
finally:
290-
if client_read and not client_read.done():
291-
client_read.cancel()
292-
if server_read and not server_read.done():
293-
server_read.cancel()
294-
295-
# close connections
296-
if not client.closed:
297-
await client.close()
298-
if not server.closed:
299-
await server.close()
294+
if pending:
295+
_, pending = await asyncio.wait(
296+
pending, timeout=10, return_when=asyncio.ALL_COMPLETED
297+
)
298+
for task in pending:
299+
task.cancel()
300+
logger.critical("WebSocket proxy task: %s did not end gracefully", task)
300301

301-
_LOGGER.info("Home Assistant WebSocket API for %s closed", addon_name)
302+
logger.info("Home Assistant WebSocket API closed")
302303
return server

supervisor/utils/logging.py

+8
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@
88
from typing import Any
99

1010

11+
class AddonLoggerAdapter(logging.LoggerAdapter):
12+
"""Logging Adapter which prepends log entries with add-on name."""
13+
14+
def process(self, msg, kwargs):
15+
"""Process the logging message by prepending the add-on name."""
16+
return f"[{self.extra['addon_name']}] {msg}", kwargs
17+
18+
1119
class SupervisorQueueHandler(logging.handlers.QueueHandler):
1220
"""Process the log in another thread."""
1321

tests/api/test_proxy.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Any, cast
1010
from unittest.mock import patch
1111

12-
from aiohttp import ClientWebSocketResponse
12+
from aiohttp import ClientWebSocketResponse, WSCloseCode
1313
from aiohttp.http_websocket import WSMessage, WSMsgType
1414
from aiohttp.test_utils import TestClient
1515
import pytest
@@ -37,16 +37,20 @@ class MockHAServerWebSocket:
3737
"""Mock of HA Websocket server."""
3838

3939
closed: bool = False
40+
close_code: int | None = None
4041

4142
def __init__(self) -> None:
4243
"""Initialize object."""
4344
self.outgoing: asyncio.Queue[WSMessage] = asyncio.Queue()
4445
self.incoming: asyncio.Queue[WSMessage] = asyncio.Queue()
4546
self._id_generator = id_generator()
4647

47-
def receive(self) -> Awaitable[WSMessage]:
48+
async def receive(self) -> WSMessage:
4849
"""Receive next message."""
49-
return self.outgoing.get()
50+
try:
51+
return await self.outgoing.get()
52+
except asyncio.QueueShutDown:
53+
return WSMessage(WSMsgType.CLOSED, None, None)
5054

5155
def send_str(self, data: str) -> Awaitable[None]:
5256
"""Incoming string message."""
@@ -68,9 +72,11 @@ def respond_bytes(self, data: bytes) -> Awaitable[None]:
6872
"""Respond with binary."""
6973
return self.outgoing.put(WSMessage(WSMsgType.BINARY, data, None))
7074

71-
async def close(self) -> None:
75+
async def close(self, code: int = WSCloseCode.OK) -> None:
7276
"""Close connection."""
7377
self.closed = True
78+
self.outgoing.shutdown(immediate=True)
79+
self.close_code = code
7480

7581

7682
WebSocketGenerator = Callable[..., Coroutine[Any, Any, MockHAClientWebSocket]]
@@ -162,6 +168,26 @@ async def test_proxy_binary_message(
162168
assert await client.close()
163169

164170

171+
async def test_proxy_large_message(
172+
proxy_ws_client: WebSocketGenerator,
173+
ha_ws_server: MockHAServerWebSocket,
174+
install_addon_ssh: Addon,
175+
):
176+
"""Test too large message handled gracefully."""
177+
install_addon_ssh.persist[ATTR_ACCESS_TOKEN] = "abc123"
178+
client: MockHAClientWebSocket = await proxy_ws_client(
179+
install_addon_ssh.supervisor_token
180+
)
181+
182+
# Test message over size limit of 4MB
183+
await client.send_bytes(bytearray(1024 * 1024 * 4))
184+
msg = await client.receive()
185+
assert msg.type == WSMsgType.CLOSE
186+
assert msg.data == WSCloseCode.MESSAGE_TOO_BIG
187+
188+
assert ha_ws_server.closed
189+
190+
165191
@pytest.mark.parametrize("auth_token", ["abc123", "bad"])
166192
async def test_proxy_invalid_auth(
167193
api_client: TestClient, install_addon_example: Addon, auth_token: str

0 commit comments

Comments
 (0)