diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index efb40a7f4..4c8373930 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -230,6 +230,10 @@ class Server: logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. + connections: A set of open :class:`ServerConnection` instances + maintained by `handler`. When omitted, e.g., if the handler does + not maintain such a set, this defaults to an empty set and the + server will not attempt to close connections on shutdown. """ @@ -238,12 +242,20 @@ def __init__( socket: socket.socket, handler: Callable[[socket.socket, Any], None], logger: LoggerLike | None = None, + *, + connections: set[ServerConnection] | None = None, ) -> None: self.socket = socket self.handler = handler if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger + + # _connections tracks active connections + if connections is None: + connections = set() + self._connections = connections + if sys.platform != "win32": self.shutdown_watcher, self.shutdown_notifier = os.pipe() @@ -289,11 +301,22 @@ def shutdown(self) -> None: """ See :meth:`socketserver.BaseServer.shutdown`. + Shuts down the server and closes existing connections. + """ self.socket.close() if sys.platform != "win32": os.write(self.shutdown_notifier, b"x") + # Close all connections + conns = list(self._connections) + for conn in conns: + try: + conn.close() + except Exception as exc: + debug_msg = f"Could not close {conn.id}: {exc}" + self.logger.debug(debug_msg, exc_info=exc) + def fileno(self) -> int: """ See :meth:`socketserver.BaseServer.fileno`. @@ -516,6 +539,24 @@ def handler(websocket): do_handshake_on_connect=False, ) + # Stores active ServerConnection instances, used by the server to handle graceful + # shutdown in Server.shutdown() + connections: set[ServerConnection] = set() + + def on_connection_created(connection: ServerConnection) -> None: + # Invoked from conn_handler() to add a new ServerConnection instance to + # Server._connections + connections.add(connection) + + def on_connection_closed(connection: ServerConnection) -> None: + # Invoked from conn_handler() to remove a closed ServerConnection instance from + # Server._connections. Keeping only active references in the set is important + # for avoiding memory leaks. + try: + connections.remove(connection) + except KeyError: # pragma: no cover + pass + # Define request handler def conn_handler(sock: socket.socket, addr: Any) -> None: @@ -581,6 +622,7 @@ def protocol_select_subprotocol( close_timeout=close_timeout, max_queue=max_queue, ) + on_connection_created(connection) except Exception: sock.close() return @@ -595,11 +637,13 @@ def protocol_select_subprotocol( ) except TimeoutError: connection.close_socket() + on_connection_closed(connection) connection.recv_events_thread.join() return except Exception: connection.logger.error("opening handshake failed", exc_info=True) connection.close_socket() + on_connection_closed(connection) connection.recv_events_thread.join() return @@ -610,8 +654,10 @@ def protocol_select_subprotocol( except Exception: connection.logger.error("connection handler failed", exc_info=True) connection.close(CloseCode.INTERNAL_ERROR) + on_connection_closed(connection) else: connection.close() + on_connection_closed(connection) except Exception: # pragma: no cover # Don't leak sockets on unexpected errors. @@ -619,7 +665,7 @@ def protocol_select_subprotocol( # Initialize server - return Server(sock, conn_handler, logger) + return Server(sock, conn_handler, logger, connections=connections) def unix_serve( diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index d04d1859a..53b3eed14 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -3,9 +3,11 @@ import http import logging import socket +import threading import time import unittest +from websockets import CloseCode from websockets.exceptions import ( ConnectionClosedError, ConnectionClosedOK, @@ -338,6 +340,71 @@ def test_junk_handshake(self): ["invalid HTTP request line: HELO relay.invalid"], ) + def test_initialize_server_without_tracking_connections(self): + """Call Server() constructor without 'connections' arg.""" + with socket.create_server(("localhost", 0)) as sock: + server = Server(socket=sock, handler=handler) + self.assertIsInstance( + server._connections, set, "Server._connections property not initialized" + ) + + def test_connections_is_empty_after_disconnects(self): + """Clients are added to Server._connections, and removed when disconnected.""" + with run_server() as server: + connections: set[ServerConnection] = server._connections + with connect(get_uri(server)): + self.assertEqual(len(connections), 1) + time.sleep(0.5) + self.assertEqual(len(connections), 0) + + def test_shutdown_calls_close_for_all_connections(self): + """Graceful shutdown with broken ServerConnection.close() implementations.""" + CLIENTS_TO_LAUNCH = 3 + + connections_attempted = 0 + + class ServerConnectionWithBrokenClose(ServerConnection): + close_method_called = False + + def close(self, code=CloseCode.NORMAL_CLOSURE, reason=""): + """Custom close method that intentionally fails.""" + + # Do not increment the counter when calling .close() multiple times + if self.close_method_called: + return + self.close_method_called = True + + nonlocal connections_attempted + connections_attempted += 1 + raise Exception("broken close method") + + clients: set[threading.Thread] = set() + with run_server(create_connection=ServerConnectionWithBrokenClose) as server: + + def client(): + with connect(get_uri(server)): + time.sleep(1) + + for i in range(CLIENTS_TO_LAUNCH): + client_thread = threading.Thread(target=client) + client_thread.start() + clients.add(client_thread) + time.sleep(0.2) + self.assertEqual( + len(server._connections), + CLIENTS_TO_LAUNCH, + "not all clients connected to the server yet, increase sleep duration", + ) + server.shutdown() + while len(clients) > 0: + client = clients.pop() + client.join() + self.assertEqual( + connections_attempted, + CLIENTS_TO_LAUNCH, + "server did not call ServerConnection.close() on all connections", + ) + class SecureServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self):