Skip to content

Commit 74355b0

Browse files
committed
make InsistentReaderBytesIO properly handle a stream that closes while being read from
1 parent f47f942 commit 74355b0

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

src/aws_encryption_sdk/internal/utils/streams.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,12 @@ def read(self, b=-1):
8181
remaining_bytes = b
8282
data = io.BytesIO()
8383
while True:
84-
chunk = to_bytes(self.__wrapped__.read(remaining_bytes))
84+
try:
85+
chunk = to_bytes(self.__wrapped__.read(remaining_bytes))
86+
except ValueError:
87+
if self.__wrapped__.closed:
88+
break
89+
raise
8590

8691
if not chunk:
8792
break

test/unit/test_util_streams.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from aws_encryption_sdk.internal.str_ops import to_bytes, to_str
2020
from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO, ROStream, TeeStream
2121

22-
from .unit_test_utils import NothingButRead, SometimesIncompleteReaderIO
22+
from .unit_test_utils import ExactlyTwoReads, NothingButRead, SometimesIncompleteReaderIO
2323

2424
pytestmark = [pytest.mark.unit, pytest.mark.local]
2525

@@ -74,3 +74,16 @@ def test_insistent_stream(source_length, bytes_to_read, stream_type, converter):
7474
assert (source_length >= bytes_to_read and len(test) == bytes_to_read) or (
7575
source_length < bytes_to_read and len(test) == source_length
7676
)
77+
78+
79+
def test_insistent_stream_close_partway_through():
80+
raw = data(length=100)
81+
source = ExactlyTwoReads(raw.getvalue())
82+
83+
wrapped = InsistentReaderBytesIO(source)
84+
85+
test = b""
86+
test += wrapped.read(10) # actually reads 10 bytes
87+
test += wrapped.read(10) # reads 5 bytes, stream is closed before third read can complete, truncating the result
88+
89+
assert test == raw.getvalue()[:15]

test/unit/unit_test_utils.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ def build_valid_kwargs_list(base, optional_kwargs):
5555

5656
class SometimesIncompleteReaderIO(io.BytesIO):
5757
def __init__(self, *args, **kwargs):
58-
self.__read_counter = 0
58+
self._read_counter = 0
5959
super(SometimesIncompleteReaderIO, self).__init__(*args, **kwargs)
6060

6161
def read(self, size=-1):
6262
"""Every other read request, return fewer than the requested number of bytes if more than one byte requested."""
63-
self.__read_counter += 1
64-
if size > 1 and self.__read_counter % 2 == 0:
63+
self._read_counter += 1
64+
if size > 1 and self._read_counter % 2 == 0:
6565
size //= 2
6666
return super(SometimesIncompleteReaderIO, self).read(size)
6767

@@ -72,3 +72,10 @@ def __init__(self, data):
7272

7373
def read(self, size=-1):
7474
return self._data.read(size)
75+
76+
77+
class ExactlyTwoReads(SometimesIncompleteReaderIO):
78+
def read(self, size=-1):
79+
if self._read_counter >= 2:
80+
self.close()
81+
return super(ExactlyTwoReads, self).read(size)

0 commit comments

Comments
 (0)