Skip to content

Commit 8941d17

Browse files
authored
fix: reconnect when send message fails (#295)
1 parent 8b21e31 commit 8941d17

File tree

4 files changed

+85
-41
lines changed

4 files changed

+85
-41
lines changed

poetry.lock

+3-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1-4
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,15 @@ python-dateutil = "^2.8.1"
1717
typing-extensions = "^4.12.2"
1818
aiohttp = "^3.11.14"
1919

20-
[tool.poetry.dev-dependencies]
20+
[tool.poetry.group.dev.dependencies]
2121
pytest = "^8.3.5"
2222
pytest-cov = "^5.0.0"
2323
python-dotenv = "^1.1.0"
2424
pytest-asyncio = "^0.26.0"
2525
coveralls = "^3.0.0"
26-
27-
[tool.poetry.group.dev.dependencies]
2826
black = ">=23.11,<26.0"
2927
isort = "^6.0.1"
3028
pre-commit = "^4.2.0"
31-
pytest-cov = "^5.0.0"
3229

3330
[build-system]
3431
requires = ["poetry-core>=1.0.0"]

realtime/_async/client.py

+37-32
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474
self.access_token = token
7575
self.send_buffer: List[Callable] = []
7676
self.hb_interval = hb_interval
77-
self.ws_connection: Optional[ClientProtocol] = None
77+
self._ws_connection: Optional[ClientProtocol] = None
7878
self.ref = 0
7979
self.auto_reconnect = auto_reconnect
8080
self.channels: Dict[str, AsyncRealtimeChannel] = {}
@@ -86,39 +86,31 @@ def __init__(
8686

8787
@property
8888
def is_connected(self) -> bool:
89-
return self.ws_connection is not None
89+
return self._ws_connection is not None
9090

9191
async def _listen(self) -> None:
9292
"""
9393
An infinite loop that keeps listening.
9494
:return: None
9595
"""
9696

97-
if not self.ws_connection:
97+
if not self._ws_connection:
9898
raise Exception("WebSocket connection not established")
9999

100100
try:
101-
async for msg in self.ws_connection:
101+
async for msg in self._ws_connection:
102102
logger.info(f"receive: {msg}")
103103

104104
msg = Message(**json.loads(msg))
105105
channel = self.channels.get(msg.topic)
106106

107107
if channel:
108108
channel._trigger(msg.event, msg.payload, msg.ref)
109-
except websockets.exceptions.ConnectionClosedError as e:
110-
logger.error(
111-
f"WebSocket connection closed with code: {e.code}, reason: {e.reason}"
112-
)
113-
if self.auto_reconnect:
114-
logger.info("Initiating auto-reconnect sequence...")
115-
116-
await self._reconnect()
117-
else:
118-
logger.error("Auto-reconnect disabled, terminating connection")
109+
except Exception as e:
110+
await self._on_connect_error(e)
119111

120112
async def _reconnect(self) -> None:
121-
self.ws_connection = None
113+
self._ws_connection = None
122114
await self.connect()
123115

124116
if self.is_connected:
@@ -156,7 +148,7 @@ async def connect(self) -> None:
156148
while retries < self.max_retries:
157149
try:
158150
ws = await connect(self.url)
159-
self.ws_connection = ws
151+
self._ws_connection = ws
160152
logger.info("WebSocket connection established successfully")
161153
return await self._on_connect()
162154
except Exception as e:
@@ -197,6 +189,20 @@ async def _on_connect(self) -> None:
197189
self._heartbeat_task = asyncio.create_task(self._heartbeat())
198190
await self._flush_send_buffer()
199191

192+
async def _on_connect_error(self, e: Exception) -> None:
193+
if isinstance(e, websockets.exceptions.ConnectionClosedError):
194+
logger.error(
195+
f"WebSocket connection closed with code: {e.code}, reason: {e.reason}"
196+
)
197+
198+
if self.auto_reconnect:
199+
logger.info("Initiating auto-reconnect sequence...")
200+
await self._reconnect()
201+
else:
202+
logger.error("Auto-reconnect disabled, terminating connection")
203+
else:
204+
logger.error(f"Error on connect: {e}")
205+
200206
async def _flush_send_buffer(self):
201207
if self.is_connected and len(self.send_buffer) > 0:
202208
for callback in self.send_buffer:
@@ -214,10 +220,10 @@ async def close(self) -> None:
214220
NotConnectedError: If the connection is not established when this method is called.
215221
"""
216222

217-
if self.ws_connection:
218-
await self.ws_connection.close()
223+
if self._ws_connection:
224+
await self._ws_connection.close()
219225

220-
self.ws_connection = None
226+
self._ws_connection = None
221227

222228
if self._listen_task:
223229
self._listen_task.cancel()
@@ -228,7 +234,7 @@ async def close(self) -> None:
228234
self._heartbeat_task = None
229235

230236
async def _heartbeat(self) -> None:
231-
if not self.ws_connection:
237+
if not self._ws_connection:
232238
raise Exception("WebSocket connection not established")
233239

234240
while self.is_connected:
@@ -242,17 +248,8 @@ async def _heartbeat(self) -> None:
242248
await self.send(data)
243249
await asyncio.sleep(max(self.hb_interval, 15))
244250

245-
except websockets.exceptions.ConnectionClosed as e:
246-
logger.error(
247-
f"Connection closed during heartbeat. Code: {e.code}, reason: {e.reason}"
248-
)
249-
250-
if self.auto_reconnect:
251-
logger.info("Heartbeat failed - initiating reconnection sequence")
252-
await self._reconnect()
253-
else:
254-
logger.error("Heartbeat failed - auto-reconnect disabled")
255-
break
251+
except Exception as e:
252+
await self._on_connect_error(e)
256253

257254
def channel(
258255
self, topic: str, params: Optional[RealtimeChannelOptions] = None
@@ -373,7 +370,15 @@ async def send(self, message: Dict[str, Any]) -> None:
373370
logger.info(f"send: {message}")
374371

375372
async def send_message():
376-
await self.ws_connection.send(message)
373+
if not self._ws_connection:
374+
raise Exception(
375+
"WebSocket connection not established, a connection is expected to be established before sending a message"
376+
)
377+
378+
try:
379+
await self._ws_connection.send(message)
380+
except Exception as e:
381+
await self._on_connect_error(e)
377382

378383
if self.is_connected:
379384
await send_message()

tests/test_connection.py

+44-2
Original file line numberDiff line numberDiff line change
@@ -259,12 +259,12 @@ async def test_multiple_connect_attempts(socket: AsyncRealtimeClient):
259259
# First connection should succeed
260260
await socket.connect()
261261
assert socket.is_connected
262-
initial_ws = socket.ws_connection
262+
initial_ws = socket._ws_connection
263263

264264
# Second connection attempt should be a no-op since we're already connected
265265
await socket.connect()
266266
assert socket.is_connected
267-
assert socket.ws_connection == initial_ws # Should be the same connection object
267+
assert socket._ws_connection == initial_ws # Should be the same connection object
268268

269269
await socket.close()
270270
assert not socket.is_connected
@@ -311,3 +311,45 @@ async def test_multiple_connect_attempts(socket: AsyncRealtimeClient):
311311
assert socket.is_connected
312312

313313
await socket.close()
314+
315+
316+
@pytest.mark.asyncio
317+
async def test_send_message_reconnection(socket: AsyncRealtimeClient):
318+
# First establish a connection
319+
await socket.connect()
320+
assert socket.is_connected
321+
322+
# Create a channel and subscribe to it
323+
channel = socket.channel("test-channel")
324+
subscribe_event = asyncio.Event()
325+
await channel.subscribe(
326+
lambda state, _: (
327+
subscribe_event.set()
328+
if state == RealtimeSubscribeStates.SUBSCRIBED
329+
else None
330+
)
331+
)
332+
await asyncio.wait_for(subscribe_event.wait(), 5)
333+
334+
# Simulate a connection failure by closing the WebSocket
335+
if socket._ws_connection:
336+
await socket._ws_connection.close()
337+
338+
# Try to send a message - this should trigger reconnection
339+
message = {
340+
"topic": "test-channel",
341+
"event": "test-event",
342+
"payload": {"test": "data"},
343+
}
344+
await socket.send(message)
345+
346+
# Wait for reconnection to complete
347+
await asyncio.sleep(1) # Give some time for reconnection
348+
349+
# Verify we're connected again
350+
assert socket.is_connected
351+
352+
# Try sending another message to verify the connection is working
353+
await socket.send(message)
354+
355+
await socket.close()

0 commit comments

Comments
 (0)