Skip to content

Commit 4081a05

Browse files
authored
Merge pull request #1565 from njsmith/stop-leaking-sockets
2 parents e7d0571 + 0376be0 commit 4081a05

7 files changed

+168
-149
lines changed

trio/_core/tests/test_multierror.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def run_script(name, use_ipython=False):
600600
print("subprocess PYTHONPATH:", env.get("PYTHONPATH"))
601601

602602
if use_ipython:
603-
lines = [script_path.open().read(), "exit()"]
603+
lines = [script_path.read_text(), "exit()"]
604604

605605
cmd = [
606606
sys.executable,

trio/_highlevel_open_tcp_stream.py

+1
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ async def open_tcp_stream(
222222
open_ssl_over_tcp_stream
223223
224224
"""
225+
225226
# To keep our public API surface smaller, rule out some cases that
226227
# getaddrinfo will accept in some circumstances, but that act weird or
227228
# have non-portable behavior or are just plain not useful.

trio/tests/test_highlevel_open_tcp_listeners.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,11 @@ async def check_backlog(nominal, required_min, required_max):
115115
async def test_open_tcp_listeners_ipv6_v6only():
116116
# Check IPV6_V6ONLY is working properly
117117
(ipv6_listener,) = await open_tcp_listeners(0, host="::1")
118-
_, port, *_ = ipv6_listener.socket.getsockname()
118+
async with ipv6_listener:
119+
_, port, *_ = ipv6_listener.socket.getsockname()
119120

120-
with pytest.raises(OSError):
121-
await open_tcp_stream("127.0.0.1", port)
121+
with pytest.raises(OSError):
122+
await open_tcp_stream("127.0.0.1", port)
122123

123124

124125
async def test_open_tcp_listeners_rebind():
@@ -127,10 +128,10 @@ async def test_open_tcp_listeners_rebind():
127128

128129
# Plain old rebinding while it's still there should fail, even if we have
129130
# SO_REUSEADDR set
130-
probe = stdlib_socket.socket()
131-
probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1)
132-
with pytest.raises(OSError):
133-
probe.bind(sockaddr1)
131+
with stdlib_socket.socket() as probe:
132+
probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1)
133+
with pytest.raises(OSError):
134+
probe.bind(sockaddr1)
134135

135136
# Now use the first listener to set up some connections in various states,
136137
# and make sure that they don't create any obstacle to rebinding a second

trio/tests/test_highlevel_socket.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,9 @@ async def accept(self):
259259

260260
async def test_socket_stream_works_when_peer_has_already_closed():
261261
sock_a, sock_b = tsocket.socketpair()
262-
await sock_b.send(b"x")
263-
sock_b.close()
264-
stream = SocketStream(sock_a)
265-
assert await stream.receive_some(1) == b"x"
266-
assert await stream.receive_some(1) == b""
262+
with sock_a, sock_b:
263+
await sock_b.send(b"x")
264+
sock_b.close()
265+
stream = SocketStream(sock_a)
266+
assert await stream.receive_some(1) == b"x"
267+
assert await stream.receive_some(1) == b""

trio/tests/test_highlevel_ssl_helpers.py

+47-43
Original file line numberDiff line numberDiff line change
@@ -43,53 +43,57 @@ async def getnameinfo(self, *args): # pragma: no cover
4343

4444
# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
4545
# noqa is needed because flake8 doesn't understand how pytest fixtures work.
46-
async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx,): # noqa: F811
46+
async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa: F811
4747
async with trio.open_nursery() as nursery:
4848
(listener,) = await nursery.start(
4949
partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1",)
5050
)
51-
sockaddr = listener.transport_listener.socket.getsockname()
52-
hostname_resolver = FakeHostnameResolver(sockaddr)
53-
trio.socket.set_custom_hostname_resolver(hostname_resolver)
54-
55-
# We don't have the right trust set up
56-
# (checks that ssl_context=None is doing some validation)
57-
stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
58-
with pytest.raises(trio.BrokenResourceError):
59-
await stream.do_handshake()
60-
61-
# We have the trust but not the hostname
62-
# (checks custom ssl_context + hostname checking)
63-
stream = await open_ssl_over_tcp_stream(
64-
"xyzzy.example.org", 80, ssl_context=client_ctx,
65-
)
66-
with pytest.raises(trio.BrokenResourceError):
67-
await stream.do_handshake()
68-
69-
# This one should work!
70-
stream = await open_ssl_over_tcp_stream(
71-
"trio-test-1.example.org", 80, ssl_context=client_ctx,
72-
)
73-
assert isinstance(stream, trio.SSLStream)
74-
assert stream.server_hostname == "trio-test-1.example.org"
75-
await stream.send_all(b"x")
76-
assert await stream.receive_some(1) == b"x"
77-
await stream.aclose()
78-
79-
# Check https_compatible settings are being passed through
80-
assert not stream._https_compatible
81-
stream = await open_ssl_over_tcp_stream(
82-
"trio-test-1.example.org",
83-
80,
84-
ssl_context=client_ctx,
85-
https_compatible=True,
86-
# also, smoke test happy_eyeballs_delay
87-
happy_eyeballs_delay=1,
88-
)
89-
assert stream._https_compatible
90-
91-
# Stop the echo server
92-
nursery.cancel_scope.cancel()
51+
async with listener:
52+
sockaddr = listener.transport_listener.socket.getsockname()
53+
hostname_resolver = FakeHostnameResolver(sockaddr)
54+
trio.socket.set_custom_hostname_resolver(hostname_resolver)
55+
56+
# We don't have the right trust set up
57+
# (checks that ssl_context=None is doing some validation)
58+
stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
59+
async with stream:
60+
with pytest.raises(trio.BrokenResourceError):
61+
await stream.do_handshake()
62+
63+
# We have the trust but not the hostname
64+
# (checks custom ssl_context + hostname checking)
65+
stream = await open_ssl_over_tcp_stream(
66+
"xyzzy.example.org", 80, ssl_context=client_ctx,
67+
)
68+
async with stream:
69+
with pytest.raises(trio.BrokenResourceError):
70+
await stream.do_handshake()
71+
72+
# This one should work!
73+
stream = await open_ssl_over_tcp_stream(
74+
"trio-test-1.example.org", 80, ssl_context=client_ctx,
75+
)
76+
async with stream:
77+
assert isinstance(stream, trio.SSLStream)
78+
assert stream.server_hostname == "trio-test-1.example.org"
79+
await stream.send_all(b"x")
80+
assert await stream.receive_some(1) == b"x"
81+
82+
# Check https_compatible settings are being passed through
83+
assert not stream._https_compatible
84+
stream = await open_ssl_over_tcp_stream(
85+
"trio-test-1.example.org",
86+
80,
87+
ssl_context=client_ctx,
88+
https_compatible=True,
89+
# also, smoke test happy_eyeballs_delay
90+
happy_eyeballs_delay=1,
91+
)
92+
async with stream:
93+
assert stream._https_compatible
94+
95+
# Stop the echo server
96+
nursery.cancel_scope.cancel()
9397

9498

9599
async def test_open_ssl_over_tcp_listeners():

trio/tests/test_socket.py

+78-69
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,9 @@ async def test_from_stdlib_socket():
224224
class MySocket(stdlib_socket.socket):
225225
pass
226226

227-
mysock = MySocket()
228-
with pytest.raises(TypeError):
229-
tsocket.from_stdlib_socket(mysock)
227+
with MySocket() as mysock:
228+
with pytest.raises(TypeError):
229+
tsocket.from_stdlib_socket(mysock)
230230

231231

232232
async def test_from_fd():
@@ -292,12 +292,15 @@ async def test_sniff_sockopts():
292292
# check family / type for correctness:
293293
assert tsocket_socket.family == socket.family
294294
assert tsocket_socket.type == socket.type
295+
tsocket_socket.detach()
295296

296297
# fromfd constructor
297298
tsocket_from_fd = tsocket.fromfd(socket.fileno(), AF_INET, SOCK_STREAM)
298299
# check family / type for correctness:
299300
assert tsocket_from_fd.family == socket.family
300301
assert tsocket_from_fd.type == socket.type
302+
tsocket_from_fd.close()
303+
301304
socket.close()
302305

303306

@@ -482,73 +485,78 @@ class Addresses:
482485
async def test_SocketType_resolve(socket_type, addrs):
483486
v6 = socket_type == tsocket.AF_INET6
484487

485-
# For some reason the stdlib special-cases "" to pass NULL to getaddrinfo
486-
# They also error out on None, but whatever, None is much more consistent,
487-
# so we accept it too.
488-
for null in [None, ""]:
489-
sock = tsocket.socket(family=socket_type)
490-
got = await sock._resolve_local_address((null, 80))
491-
assert got == (addrs.bind_all, 80, *addrs.extra)
492-
got = await sock._resolve_remote_address((null, 80))
493-
assert got == (addrs.localhost, 80, *addrs.extra)
494-
495-
# AI_PASSIVE only affects the wildcard address, so for everything else
496-
# _resolve_local_address and _resolve_remote_address should work the same:
497-
for resolver in ["_resolve_local_address", "_resolve_remote_address"]:
498-
499-
async def res(*args):
500-
return await getattr(sock, resolver)(*args)
501-
502-
assert await res((addrs.arbitrary, "http")) == (
503-
addrs.arbitrary,
504-
80,
505-
*addrs.extra,
506-
)
507-
if v6:
508-
assert await res(("1::2", 80, 1)) == ("1::2", 80, 1, 0)
509-
assert await res(("1::2", 80, 1, 2)) == ("1::2", 80, 1, 2)
510-
511-
# V4 mapped addresses resolved if V6ONLY is False
512-
sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, False)
513-
assert await res(("1.2.3.4", "http")) == ("::ffff:1.2.3.4", 80, 0, 0,)
514-
515-
# Check the <broadcast> special case, because why not
516-
assert await res(("<broadcast>", 123)) == (addrs.broadcast, 123, *addrs.extra,)
517-
518-
# But not if it's true (at least on systems where getaddrinfo works
519-
# correctly)
520-
if v6 and not gai_without_v4mapped_is_buggy():
521-
sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, True)
522-
with pytest.raises(tsocket.gaierror) as excinfo:
523-
await res(("1.2.3.4", 80))
524-
# Windows, macOS
525-
expected_errnos = {tsocket.EAI_NONAME}
526-
# Linux
527-
if hasattr(tsocket, "EAI_ADDRFAMILY"):
528-
expected_errnos.add(tsocket.EAI_ADDRFAMILY)
529-
assert excinfo.value.errno in expected_errnos
530-
531-
# A family where we know nothing about the addresses, so should just
532-
# pass them through. This should work on Linux, which is enough to
533-
# smoke test the basic functionality...
534-
try:
535-
netlink_sock = tsocket.socket(
536-
family=tsocket.AF_NETLINK, type=tsocket.SOCK_DGRAM
488+
with tsocket.socket(family=socket_type) as sock:
489+
# For some reason the stdlib special-cases "" to pass NULL to
490+
# getaddrinfo They also error out on None, but whatever, None is much
491+
# more consistent, so we accept it too.
492+
for null in [None, ""]:
493+
got = await sock._resolve_local_address((null, 80))
494+
assert got == (addrs.bind_all, 80, *addrs.extra)
495+
got = await sock._resolve_remote_address((null, 80))
496+
assert got == (addrs.localhost, 80, *addrs.extra)
497+
498+
# AI_PASSIVE only affects the wildcard address, so for everything else
499+
# _resolve_local_address and _resolve_remote_address should work the same:
500+
for resolver in ["_resolve_local_address", "_resolve_remote_address"]:
501+
502+
async def res(*args):
503+
return await getattr(sock, resolver)(*args)
504+
505+
assert await res((addrs.arbitrary, "http")) == (
506+
addrs.arbitrary,
507+
80,
508+
*addrs.extra,
537509
)
538-
except (AttributeError, OSError):
539-
pass
540-
else:
541-
assert await getattr(netlink_sock, resolver)("asdf") == "asdf"
542-
543-
with pytest.raises(ValueError):
544-
await res("1.2.3.4")
545-
with pytest.raises(ValueError):
546-
await res(("1.2.3.4",))
547-
with pytest.raises(ValueError):
548510
if v6:
549-
await res(("1.2.3.4", 80, 0, 0, 0))
511+
assert await res(("1::2", 80, 1)) == ("1::2", 80, 1, 0)
512+
assert await res(("1::2", 80, 1, 2)) == ("1::2", 80, 1, 2)
513+
514+
# V4 mapped addresses resolved if V6ONLY is False
515+
sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, False)
516+
assert await res(("1.2.3.4", "http")) == ("::ffff:1.2.3.4", 80, 0, 0,)
517+
518+
# Check the <broadcast> special case, because why not
519+
assert await res(("<broadcast>", 123)) == (
520+
addrs.broadcast,
521+
123,
522+
*addrs.extra,
523+
)
524+
525+
# But not if it's true (at least on systems where getaddrinfo works
526+
# correctly)
527+
if v6 and not gai_without_v4mapped_is_buggy():
528+
sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, True)
529+
with pytest.raises(tsocket.gaierror) as excinfo:
530+
await res(("1.2.3.4", 80))
531+
# Windows, macOS
532+
expected_errnos = {tsocket.EAI_NONAME}
533+
# Linux
534+
if hasattr(tsocket, "EAI_ADDRFAMILY"):
535+
expected_errnos.add(tsocket.EAI_ADDRFAMILY)
536+
assert excinfo.value.errno in expected_errnos
537+
538+
# A family where we know nothing about the addresses, so should just
539+
# pass them through. This should work on Linux, which is enough to
540+
# smoke test the basic functionality...
541+
try:
542+
netlink_sock = tsocket.socket(
543+
family=tsocket.AF_NETLINK, type=tsocket.SOCK_DGRAM
544+
)
545+
except (AttributeError, OSError):
546+
pass
550547
else:
551-
await res(("1.2.3.4", 80, 0, 0))
548+
assert await getattr(netlink_sock, resolver)("asdf") == "asdf"
549+
netlink_sock.close()
550+
551+
with pytest.raises(ValueError):
552+
await res("1.2.3.4")
553+
with pytest.raises(ValueError):
554+
await res(("1.2.3.4",))
555+
with pytest.raises(ValueError):
556+
if v6:
557+
await res(("1.2.3.4", 80, 0, 0, 0))
558+
else:
559+
await res(("1.2.3.4", 80, 0, 0))
552560

553561

554562
async def test_SocketType_unresolved_names():
@@ -923,8 +931,9 @@ async def check_AF_UNIX(path):
923931
with tsocket.socket(family=tsocket.AF_UNIX) as csock:
924932
await csock.connect(path)
925933
ssock, _ = await lsock.accept()
926-
await csock.send(b"x")
927-
assert await ssock.recv(1) == b"x"
934+
with ssock:
935+
await csock.send(b"x")
936+
assert await ssock.recv(1) == b"x"
928937

929938
# Can't use tmpdir fixture, because we can exceed the maximum AF_UNIX path
930939
# length on macOS.

0 commit comments

Comments
 (0)