|
10 | 10 | from aiohttp.client_exceptions import ClientConnectorError
|
11 | 11 | from aiohttp.client_ws import ClientWebSocketResponse
|
12 | 12 | from aiohttp.hdrs import AUTHORIZATION, CONTENT_TYPE
|
13 |
| -from aiohttp.http import WSMessage |
14 | 13 | from aiohttp.http_websocket import WSMsgType
|
15 | 14 | from aiohttp.web_exceptions import HTTPBadGateway, HTTPUnauthorized
|
16 | 15 |
|
| 16 | +from supervisor.utils.logging import AddonLoggerAdapter |
| 17 | + |
17 | 18 | from ..coresys import CoreSysAttributes
|
18 | 19 | from ..exceptions import APIError, HomeAssistantAPIError, HomeAssistantAuthError
|
19 | 20 | from ..utils.json import json_dumps
|
@@ -179,23 +180,39 @@ async def _websocket_client(self) -> ClientWebSocketResponse:
|
179 | 180 |
|
180 | 181 | async def _proxy_message(
|
181 | 182 | self,
|
182 |
| - read_task: asyncio.Task, |
| 183 | + source: web.WebSocketResponse | ClientWebSocketResponse, |
183 | 184 | target: web.WebSocketResponse | ClientWebSocketResponse,
|
| 185 | + logger: AddonLoggerAdapter, |
184 | 186 | ) -> None:
|
185 | 187 | """Proxy a message from client to server or vice versa."""
|
186 |
| - msg: WSMessage = read_task.result() |
187 |
| - match msg.type: |
188 |
| - case WSMsgType.TEXT: |
189 |
| - await target.send_str(msg.data) |
190 |
| - case WSMsgType.BINARY: |
191 |
| - await target.send_bytes(msg.data) |
192 |
| - case WSMsgType.CLOSE: |
193 |
| - _LOGGER.debug("Received close message from WebSocket.") |
194 |
| - await target.close() |
195 |
| - case _: |
196 |
| - raise TypeError( |
197 |
| - f"Cannot proxy websocket message of unsupported type: {msg.type}" |
198 |
| - ) |
| 188 | + while not source.closed and not target.closed: |
| 189 | + msg = await source.receive() |
| 190 | + match msg.type: |
| 191 | + case WSMsgType.TEXT: |
| 192 | + await target.send_str(msg.data) |
| 193 | + case WSMsgType.BINARY: |
| 194 | + await target.send_bytes(msg.data) |
| 195 | + case WSMsgType.CLOSE | WSMsgType.CLOSED: |
| 196 | + logger.debug( |
| 197 | + "Received WebSocket message type %r from %s.", |
| 198 | + msg.type, |
| 199 | + "add-on" if type(source) is web.WebSocketResponse else "Core", |
| 200 | + ) |
| 201 | + await target.close() |
| 202 | + case WSMsgType.CLOSING: |
| 203 | + pass |
| 204 | + case WSMsgType.ERROR: |
| 205 | + logger.warning( |
| 206 | + "Error WebSocket message received while proxying: %r", msg.data |
| 207 | + ) |
| 208 | + await target.close(code=source.close_code) |
| 209 | + case _: |
| 210 | + logger.warning( |
| 211 | + "Cannot proxy WebSocket message of unsupported type: %r", |
| 212 | + msg.type, |
| 213 | + ) |
| 214 | + await source.close() |
| 215 | + await target.close() |
199 | 216 |
|
200 | 217 | async def websocket(self, request: web.Request):
|
201 | 218 | """Initialize a WebSocket API connection."""
|
@@ -255,48 +272,32 @@ async def websocket(self, request: web.Request):
|
255 | 272 | except APIError:
|
256 | 273 | return server
|
257 | 274 |
|
258 |
| - _LOGGER.info("Home Assistant WebSocket API request running") |
259 |
| - try: |
260 |
| - client_read: asyncio.Task | None = None |
261 |
| - server_read: asyncio.Task | None = None |
262 |
| - while not server.closed and not client.closed: |
263 |
| - if not client_read: |
264 |
| - client_read = self.sys_create_task(client.receive()) |
265 |
| - if not server_read: |
266 |
| - server_read = self.sys_create_task(server.receive()) |
267 |
| - |
268 |
| - # wait until data need to be processed |
269 |
| - await asyncio.wait( |
270 |
| - [client_read, server_read], return_when=asyncio.FIRST_COMPLETED |
271 |
| - ) |
| 275 | + logger = AddonLoggerAdapter(_LOGGER, {"addon_name": addon_name}) |
| 276 | + logger.info("Home Assistant WebSocket API proxy running") |
272 | 277 |
|
273 |
| - # server |
274 |
| - if server_read.done() and not client.closed: |
275 |
| - await self._proxy_message(server_read, client) |
276 |
| - server_read = None |
| 278 | + client_task = self.sys_create_task(self._proxy_message(client, server, logger)) |
| 279 | + server_task = self.sys_create_task(self._proxy_message(server, client, logger)) |
277 | 280 |
|
278 |
| - # client |
279 |
| - if client_read.done() and not server.closed: |
280 |
| - await self._proxy_message(client_read, server) |
281 |
| - client_read = None |
| 281 | + # Typically, this will return with an empty pending set. However, if one of |
| 282 | + # the directions has an exception, make sure to close both connections and |
| 283 | + # wait for the other proxy task to exit gracefully. Using this over try-except |
| 284 | + # handling makes it easier to wait for the other direction to complete. |
| 285 | + _, pending = await asyncio.wait( |
| 286 | + (client_task, server_task), return_when=asyncio.FIRST_EXCEPTION |
| 287 | + ) |
282 | 288 |
|
283 |
| - except asyncio.CancelledError: |
284 |
| - pass |
| 289 | + if not client.closed: |
| 290 | + await client.close() |
| 291 | + if not server.closed: |
| 292 | + await server.close() |
285 | 293 |
|
286 |
| - except (RuntimeError, ConnectionError, TypeError) as err: |
287 |
| - _LOGGER.info("Home Assistant WebSocket API error: %s", err) |
288 |
| - |
289 |
| - finally: |
290 |
| - if client_read and not client_read.done(): |
291 |
| - client_read.cancel() |
292 |
| - if server_read and not server_read.done(): |
293 |
| - server_read.cancel() |
294 |
| - |
295 |
| - # close connections |
296 |
| - if not client.closed: |
297 |
| - await client.close() |
298 |
| - if not server.closed: |
299 |
| - await server.close() |
| 294 | + if pending: |
| 295 | + _, pending = await asyncio.wait( |
| 296 | + pending, timeout=10, return_when=asyncio.ALL_COMPLETED |
| 297 | + ) |
| 298 | + for task in pending: |
| 299 | + task.cancel() |
| 300 | + logger.critical("WebSocket proxy task: %s did not end gracefully", task) |
300 | 301 |
|
301 |
| - _LOGGER.info("Home Assistant WebSocket API for %s closed", addon_name) |
| 302 | + logger.info("Home Assistant WebSocket API closed") |
302 | 303 | return server
|
0 commit comments