Skip to content

Commit 079c963

Browse files
dholthreaperhulk
authored andcommitted
use _ffi.from_buffer() to support bytearray (pyca#852)
* use _ffi.from_buffer(buf) in send, to support bytearray * add bytearray test * update CHANGELOG.rst * move from_buffer before 'buffer too long' check * context-managed from_buffer + black * don't shadow buf in send() * test return count for sendall * test sending an array * fix test * also use from_buffer in bio_write * de-format _util.py * formatting * add simple bio_write tests * wrap line
1 parent 8543286 commit 079c963

File tree

5 files changed

+92
-40
lines changed

5 files changed

+92
-40
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ doc/_build/
1111
examples/simple/*.cert
1212
examples/simple/*.pkey
1313
.cache
14+
.mypy_cache

CHANGELOG.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ Deprecations:
2828
Changes:
2929
^^^^^^^^
3030

31-
*none*
31+
- Support ``bytearray`` in ``SSL.Connection.send()`` by using cffi's from_buffer.
32+
`#852 <https://github.com/pyca/pyopenssl/pull/852>`_
3233

3334

3435
----

src/OpenSSL/SSL.py

+35-37
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
UNSPECIFIED as _UNSPECIFIED,
1616
exception_from_error_queue as _exception_from_error_queue,
1717
ffi as _ffi,
18+
from_buffer as _from_buffer,
1819
lib as _lib,
1920
make_assert as _make_assert,
2021
native as _native,
@@ -1730,18 +1731,18 @@ def send(self, buf, flags=0):
17301731
# Backward compatibility
17311732
buf = _text_to_bytes_and_warn("buf", buf)
17321733

1733-
if isinstance(buf, memoryview):
1734-
buf = buf.tobytes()
1735-
if isinstance(buf, _buffer):
1736-
buf = str(buf)
1737-
if not isinstance(buf, bytes):
1738-
raise TypeError("data must be a memoryview, buffer or byte string")
1739-
if len(buf) > 2147483647:
1740-
raise ValueError("Cannot send more than 2**31-1 bytes at once.")
1734+
with _from_buffer(buf) as data:
1735+
# check len(buf) instead of len(data) for testability
1736+
if len(buf) > 2147483647:
1737+
raise ValueError(
1738+
"Cannot send more than 2**31-1 bytes at once."
1739+
)
1740+
1741+
result = _lib.SSL_write(self._ssl, data, len(data))
1742+
self._raise_ssl_error(self._ssl, result)
1743+
1744+
return result
17411745

1742-
result = _lib.SSL_write(self._ssl, buf, len(buf))
1743-
self._raise_ssl_error(self._ssl, result)
1744-
return result
17451746
write = send
17461747

17471748
def sendall(self, buf, flags=0):
@@ -1757,28 +1758,24 @@ def sendall(self, buf, flags=0):
17571758
"""
17581759
buf = _text_to_bytes_and_warn("buf", buf)
17591760

1760-
if isinstance(buf, memoryview):
1761-
buf = buf.tobytes()
1762-
if isinstance(buf, _buffer):
1763-
buf = str(buf)
1764-
if not isinstance(buf, bytes):
1765-
raise TypeError("buf must be a memoryview, buffer or byte string")
1766-
1767-
left_to_send = len(buf)
1768-
total_sent = 0
1769-
data = _ffi.new("char[]", buf)
1770-
1771-
while left_to_send:
1772-
# SSL_write's num arg is an int,
1773-
# so we cannot send more than 2**31-1 bytes at once.
1774-
result = _lib.SSL_write(
1775-
self._ssl,
1776-
data + total_sent,
1777-
min(left_to_send, 2147483647)
1778-
)
1779-
self._raise_ssl_error(self._ssl, result)
1780-
total_sent += result
1781-
left_to_send -= result
1761+
with _from_buffer(buf) as data:
1762+
1763+
left_to_send = len(buf)
1764+
total_sent = 0
1765+
1766+
while left_to_send:
1767+
# SSL_write's num arg is an int,
1768+
# so we cannot send more than 2**31-1 bytes at once.
1769+
result = _lib.SSL_write(
1770+
self._ssl,
1771+
data + total_sent,
1772+
min(left_to_send, 2147483647)
1773+
)
1774+
self._raise_ssl_error(self._ssl, result)
1775+
total_sent += result
1776+
left_to_send -= result
1777+
1778+
return total_sent
17821779

17831780
def recv(self, bufsiz, flags=None):
17841781
"""
@@ -1892,10 +1889,11 @@ def bio_write(self, buf):
18921889
if self._into_ssl is None:
18931890
raise TypeError("Connection sock was not None")
18941891

1895-
result = _lib.BIO_write(self._into_ssl, buf, len(buf))
1896-
if result <= 0:
1897-
self._handle_bio_errors(self._into_ssl, result)
1898-
return result
1892+
with _from_buffer(buf) as data:
1893+
result = _lib.BIO_write(self._into_ssl, data, len(data))
1894+
if result <= 0:
1895+
self._handle_bio_errors(self._into_ssl, result)
1896+
return result
18991897

19001898
def renegotiate(self):
19011899
"""

src/OpenSSL/_util.py

+14
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,17 @@ def text_to_bytes_and_warn(label, obj):
145145
)
146146
return obj.encode('utf-8')
147147
return obj
148+
149+
150+
try:
151+
# newer versions of cffi free the buffer deterministically
152+
with ffi.from_buffer(b""):
153+
pass
154+
from_buffer = ffi.from_buffer
155+
except AttributeError:
156+
# cffi < 0.12 frees the buffer with refcounting gc
157+
from contextlib import contextmanager
158+
159+
@contextmanager
160+
def from_buffer(*args):
161+
yield ffi.from_buffer(*args)

tests/test_ssl.py

+40-2
Original file line numberDiff line numberDiff line change
@@ -2087,6 +2087,29 @@ def test_wrong_args(self, bad_context):
20872087
with pytest.raises(TypeError):
20882088
Connection(bad_context)
20892089

2090+
@pytest.mark.parametrize('bad_bio', [object(), None, 1, [1, 2, 3]])
2091+
def test_bio_write_wrong_args(self, bad_bio):
2092+
"""
2093+
`Connection.bio_write` raises `TypeError` if called with a non-bytes
2094+
(or text) argument.
2095+
"""
2096+
context = Context(TLSv1_METHOD)
2097+
connection = Connection(context, None)
2098+
with pytest.raises(TypeError):
2099+
connection.bio_write(bad_bio)
2100+
2101+
def test_bio_write(self):
2102+
"""
2103+
`Connection.bio_write` does not raise if called with bytes or
2104+
bytearray, warns if called with text.
2105+
"""
2106+
context = Context(TLSv1_METHOD)
2107+
connection = Connection(context, None)
2108+
connection.bio_write(b'xy')
2109+
connection.bio_write(bytearray(b'za'))
2110+
with pytest.warns(DeprecationWarning):
2111+
connection.bio_write(u'deprecated')
2112+
20902113
def test_get_context(self):
20912114
"""
20922115
`Connection.get_context` returns the `Context` instance used to
@@ -2807,6 +2830,8 @@ def test_wrong_args(self):
28072830
connection = Connection(Context(TLSv1_METHOD), None)
28082831
with pytest.raises(TypeError):
28092832
connection.send(object())
2833+
with pytest.raises(TypeError):
2834+
connection.send([1, 2, 3])
28102835

28112836
def test_short_bytes(self):
28122837
"""
@@ -2845,6 +2870,16 @@ def test_short_memoryview(self):
28452870
assert count == 2
28462871
assert client.recv(2) == b'xy'
28472872

2873+
def test_short_bytearray(self):
2874+
"""
2875+
When passed a short bytearray, `Connection.send` transmits all of
2876+
it and returns the number of bytes sent.
2877+
"""
2878+
server, client = loopback()
2879+
count = server.send(bytearray(b'xy'))
2880+
assert count == 2
2881+
assert client.recv(2) == b'xy'
2882+
28482883
@skip_if_py3
28492884
def test_short_buffer(self):
28502885
"""
@@ -3015,6 +3050,8 @@ def test_wrong_args(self):
30153050
connection = Connection(Context(TLSv1_METHOD), None)
30163051
with pytest.raises(TypeError):
30173052
connection.sendall(object())
3053+
with pytest.raises(TypeError):
3054+
connection.sendall([1, 2, 3])
30183055

30193056
def test_short(self):
30203057
"""
@@ -3056,8 +3093,9 @@ def test_short_buffers(self):
30563093
`Connection.sendall` transmits all of them.
30573094
"""
30583095
server, client = loopback()
3059-
server.sendall(buffer(b'x'))
3060-
assert client.recv(1) == b'x'
3096+
count = server.sendall(buffer(b'xy'))
3097+
assert count == 2
3098+
assert client.recv(2) == b'xy'
30613099

30623100
def test_long(self):
30633101
"""

0 commit comments

Comments
 (0)