Skip to content

Commit 43cf88b

Browse files
authored
Merge pull request #3647 from aws/zoewang/lengthaware-enforce-length-merge
Validate sent body len, remove buffering for known content length cases
2 parents c121fca + 4958a32 commit 43cf88b

File tree

12 files changed

+374
-121
lines changed

12 files changed

+374
-121
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"type": "feature",
3+
"category": "AWS SDK for Java v2",
4+
"contributor": "",
5+
"description": "The SDK now throws exception for input streaming operation if the stream has fewer bytes (i.e. reaches EOF) before the expected length is reached."
6+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"type": "feature",
3+
"category": "AWS SDK for Java v2",
4+
"contributor": "",
5+
"description": "The SDK now does not buffer input data from `RequestBody#fromInputStream` in cases where the InputStream does not support mark and reset."
6+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"type": "feature",
3+
"category": "AWS SDK for Java v2",
4+
"contributor": "",
5+
"description": "The SDK now does not buffer input data from ContentStreamProvider in cases where content length is known."
6+
}

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/MakeHttpRequestStage.java

+44
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,29 @@
1515

1616
package software.amazon.awssdk.core.internal.http.pipeline.stages;
1717

18+
import static software.amazon.awssdk.http.Header.CONTENT_LENGTH;
19+
1820
import java.time.Duration;
21+
import java.util.Optional;
1922
import software.amazon.awssdk.annotations.SdkInternalApi;
2023
import software.amazon.awssdk.core.client.config.SdkClientOption;
2124
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
2225
import software.amazon.awssdk.core.internal.http.HttpClientDependencies;
2326
import software.amazon.awssdk.core.internal.http.InterruptMonitor;
2427
import software.amazon.awssdk.core.internal.http.RequestExecutionContext;
2528
import software.amazon.awssdk.core.internal.http.pipeline.RequestPipeline;
29+
import software.amazon.awssdk.core.internal.io.SdkLengthAwareInputStream;
2630
import software.amazon.awssdk.core.internal.util.MetricUtils;
2731
import software.amazon.awssdk.core.metrics.CoreMetric;
32+
import software.amazon.awssdk.http.ContentStreamProvider;
2833
import software.amazon.awssdk.http.ExecutableHttpRequest;
2934
import software.amazon.awssdk.http.HttpExecuteRequest;
3035
import software.amazon.awssdk.http.HttpExecuteResponse;
3136
import software.amazon.awssdk.http.SdkHttpClient;
3237
import software.amazon.awssdk.http.SdkHttpFullRequest;
3338
import software.amazon.awssdk.http.SdkHttpFullResponse;
3439
import software.amazon.awssdk.metrics.MetricCollector;
40+
import software.amazon.awssdk.utils.Logger;
3541
import software.amazon.awssdk.utils.Pair;
3642

3743
/**
@@ -40,6 +46,7 @@
4046
@SdkInternalApi
4147
public class MakeHttpRequestStage
4248
implements RequestPipeline<SdkHttpFullRequest, Pair<SdkHttpFullRequest, SdkHttpFullResponse>> {
49+
private static final Logger LOG = Logger.loggerFor(MakeHttpRequestStage.class);
4350

4451
private final SdkHttpClient sdkHttpClient;
4552

@@ -65,6 +72,8 @@ private HttpExecuteResponse executeHttpRequest(SdkHttpFullRequest request, Reque
6572

6673
MetricCollector httpMetricCollector = MetricUtils.createHttpMetricsCollector(context);
6774

75+
request = enforceContentLengthIfPresent(request);
76+
6877
ExecutableHttpRequest requestCallable = sdkHttpClient
6978
.prepareRequest(HttpExecuteRequest.builder()
7079
.request(request)
@@ -94,4 +103,39 @@ private static long updateMetricCollectionAttributes(RequestExecutionContext con
94103
now);
95104
return now;
96105
}
106+
107+
private static SdkHttpFullRequest enforceContentLengthIfPresent(SdkHttpFullRequest request) {
108+
Optional<ContentStreamProvider> requestContentStreamProviderOptional = request.contentStreamProvider();
109+
110+
if (!requestContentStreamProviderOptional.isPresent()) {
111+
return request;
112+
}
113+
114+
Optional<Long> contentLength = contentLength(request);
115+
if (!contentLength.isPresent()) {
116+
LOG.debug(() -> String.format("Request contains a body but does not have a Content-Length header. Not validating "
117+
+ "the amount of data sent to the service: %s", request));
118+
return request;
119+
}
120+
121+
ContentStreamProvider requestContentProvider = requestContentStreamProviderOptional.get();
122+
ContentStreamProvider lengthVerifyingProvider = () -> new SdkLengthAwareInputStream(requestContentProvider.newStream(),
123+
contentLength.get());
124+
return request.toBuilder()
125+
.contentStreamProvider(lengthVerifyingProvider)
126+
.build();
127+
}
128+
129+
private static Optional<Long> contentLength(SdkHttpFullRequest request) {
130+
Optional<String> contentLengthHeader = request.firstMatchingHeader(CONTENT_LENGTH);
131+
132+
if (contentLengthHeader.isPresent()) {
133+
try {
134+
return Optional.of(Long.parseLong(contentLengthHeader.get()));
135+
} catch (NumberFormatException e) {
136+
LOG.warn(() -> "Unable to parse 'Content-Length' header. Treating it as non existent.");
137+
}
138+
}
139+
return Optional.empty();
140+
}
97141
}

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/SdkLengthAwareInputStream.java

+19-5
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
import software.amazon.awssdk.utils.Validate;
2626

2727
/**
28-
* An {@code InputStream} that is aware of its length. The main purpose of this class is to support truncating streams to a
29-
* length that is shorter than the total length of the stream.
28+
* An {@code InputStream} that is aware of its length. This class enforces that we sent exactly the number of bytes equal to
29+
* the input length. If the wrapped stream has more bytes than the expected length, it will be truncated to length. If the stream
30+
* has less bytes (i.e. reaches EOF) before the expected length is reached, it will throw {@code IOException}.
3031
*/
3132
@SdkInternalApi
3233
public class SdkLengthAwareInputStream extends FilterInputStream {
@@ -48,8 +49,13 @@ public int read() throws IOException {
4849
}
4950

5051
int read = super.read();
52+
5153
if (read != -1) {
5254
remaining--;
55+
} else if (remaining != 0) { // EOF, ensure we've read the number of expected bytes
56+
throw new IllegalStateException("The request content has fewer bytes than the "
57+
+ "specified "
58+
+ "content-length: " + length + " bytes.");
5359
}
5460
return read;
5561
}
@@ -61,12 +67,20 @@ public int read(byte[] b, int off, int len) throws IOException {
6167
return -1;
6268
}
6369

64-
len = Math.min(len, saturatedCast(remaining));
65-
int read = super.read(b, off, len);
66-
if (read > 0) {
70+
int readLen = Math.min(len, saturatedCast(remaining));
71+
72+
int read = super.read(b, off, readLen);
73+
if (read != -1) {
6774
remaining -= read;
6875
}
6976

77+
// EOF, ensure we've read the number of expected bytes
78+
if (read == -1 && remaining != 0) {
79+
throw new IllegalStateException("The request content has fewer bytes than the "
80+
+ "specified "
81+
+ "content-length: " + length + " bytes.");
82+
}
83+
7084
return read;
7185
}
7286

core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java

+6-16
Original file line numberDiff line numberDiff line change
@@ -136,21 +136,14 @@ public static RequestBody fromFile(File file) {
136136
* @return RequestBody instance.
137137
*/
138138
public static RequestBody fromInputStream(InputStream inputStream, long contentLength) {
139-
// NOTE: does not have an effect if mark not supported
140139
IoUtils.markStreamWithMaxReadLimit(inputStream);
141140
InputStream nonCloseable = nonCloseableInputStream(inputStream);
142-
ContentStreamProvider provider;
143-
if (nonCloseable.markSupported()) {
144-
// stream supports mark + reset
145-
provider = () -> {
141+
return fromContentProvider(() -> {
142+
if (nonCloseable.markSupported()) {
146143
invokeSafely(nonCloseable::reset);
147-
return nonCloseable;
148-
};
149-
} else {
150-
// stream doesn't support mark + reset, make sure to buffer it
151-
provider = new BufferingContentStreamProvider(() -> nonCloseable, contentLength);
152-
}
153-
return new RequestBody(provider, contentLength, Mimetype.MIMETYPE_OCTET_STREAM);
144+
}
145+
return nonCloseable;
146+
}, contentLength, Mimetype.MIMETYPE_OCTET_STREAM);
154147
}
155148

156149
/**
@@ -224,9 +217,6 @@ public static RequestBody empty() {
224217
/**
225218
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider}.
226219
* <p>
227-
* Important: Be aware that this implementation requires buffering the contents for {@code ContentStreamProvider}, which can
228-
* cause increased memory usage.
229-
* <p>
230220
* If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer
231221
* S3's documentation for
232222
* <a href="https://docs.aws.amazon.com/AmazonS3/latest/API/s3_example_s3_Scenario_UploadStream_section.html">alternative
@@ -239,7 +229,7 @@ public static RequestBody empty() {
239229
* @return The created {@code RequestBody}.
240230
*/
241231
public static RequestBody fromContentProvider(ContentStreamProvider provider, long contentLength, String mimeType) {
242-
return new RequestBody(new BufferingContentStreamProvider(provider, contentLength), contentLength, mimeType);
232+
return new RequestBody(provider, contentLength, mimeType);
243233
}
244234

245235
/**

core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/MakeHttpRequestStageTest.java

+102-8
Original file line numberDiff line numberDiff line change
@@ -16,42 +16,48 @@
1616
package software.amazon.awssdk.core.internal.http.pipeline.stages;
1717

1818
import static org.assertj.core.api.Assertions.assertThat;
19+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
1920
import static org.mockito.ArgumentMatchers.any;
2021
import static org.mockito.ArgumentMatchers.eq;
2122
import static org.mockito.Mockito.mock;
2223
import static org.mockito.Mockito.verify;
2324
import static org.mockito.Mockito.when;
2425
import static software.amazon.awssdk.core.client.config.SdkClientOption.SYNC_HTTP_CLIENT;
26+
27+
import java.io.ByteArrayInputStream;
2528
import java.io.IOException;
26-
import org.junit.Before;
27-
import org.junit.Test;
28-
import org.junit.runner.RunWith;
29+
import java.io.InputStream;
30+
import java.util.stream.Stream;
31+
import org.junit.jupiter.api.BeforeEach;
32+
import org.junit.jupiter.api.Test;
33+
import org.junit.jupiter.params.ParameterizedTest;
34+
import org.junit.jupiter.params.provider.Arguments;
35+
import org.junit.jupiter.params.provider.MethodSource;
2936
import org.mockito.ArgumentCaptor;
30-
import org.mockito.Mock;
31-
import org.mockito.junit.MockitoJUnitRunner;
3237
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
3338
import software.amazon.awssdk.core.http.ExecutionContext;
3439
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
3540
import software.amazon.awssdk.core.internal.http.HttpClientDependencies;
3641
import software.amazon.awssdk.core.internal.http.RequestExecutionContext;
3742
import software.amazon.awssdk.core.internal.http.timers.TimeoutTracker;
43+
import software.amazon.awssdk.core.internal.io.SdkLengthAwareInputStream;
44+
import software.amazon.awssdk.http.ContentStreamProvider;
3845
import software.amazon.awssdk.http.HttpExecuteRequest;
3946
import software.amazon.awssdk.http.SdkHttpClient;
4047
import software.amazon.awssdk.http.SdkHttpFullRequest;
4148
import software.amazon.awssdk.http.SdkHttpMethod;
4249
import software.amazon.awssdk.metrics.MetricCollector;
4350
import utils.ValidSdkObjects;
4451

45-
@RunWith(MockitoJUnitRunner.class)
4652
public class MakeHttpRequestStageTest {
4753

48-
@Mock
4954
private SdkHttpClient mockClient;
5055

5156
private MakeHttpRequestStage stage;
5257

53-
@Before
58+
@BeforeEach
5459
public void setup() throws IOException {
60+
mockClient = mock(SdkHttpClient.class);
5561
SdkClientConfiguration config = SdkClientConfiguration.builder().option(SYNC_HTTP_CLIENT, mockClient).build();
5662
stage = new MakeHttpRequestStage(HttpClientDependencies.builder().clientConfiguration(config).build());
5763
}
@@ -94,4 +100,92 @@ public void testExecute_contextContainsMetricCollector_addsChildToExecuteRequest
94100
assertThat(httpRequestCaptor.getValue().metricCollector()).contains(childCollector);
95101
}
96102
}
103+
104+
@ParameterizedTest
105+
@MethodSource("contentLengthVerificationInputs")
106+
public void execute_testLengthChecking(String description,
107+
ContentStreamProvider provider,
108+
Long contentLength,
109+
boolean expectLengthAware) {
110+
SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder()
111+
.method(SdkHttpMethod.PUT)
112+
.host("mybucket.s3.us-west-2.amazonaws.com")
113+
.protocol("https");
114+
115+
if (provider != null) {
116+
requestBuilder.contentStreamProvider(provider);
117+
}
118+
119+
if (contentLength != null) {
120+
requestBuilder.putHeader("Content-Length", String.valueOf(contentLength));
121+
}
122+
123+
when(mockClient.prepareRequest(any()))
124+
.thenThrow(new RuntimeException("BOOM"));
125+
126+
assertThatThrownBy(() -> stage.execute(requestBuilder.build(), createContext())).hasMessage("BOOM");
127+
128+
ArgumentCaptor<HttpExecuteRequest> requestCaptor = ArgumentCaptor.forClass(HttpExecuteRequest.class);
129+
130+
verify(mockClient).prepareRequest(requestCaptor.capture());
131+
132+
HttpExecuteRequest capturedRequest = requestCaptor.getValue();
133+
134+
if (provider != null) {
135+
InputStream requestContentStream = capturedRequest.contentStreamProvider().get().newStream();
136+
137+
if (expectLengthAware) {
138+
assertThat(requestContentStream).isInstanceOf(SdkLengthAwareInputStream.class);
139+
} else {
140+
assertThat(requestContentStream).isNotInstanceOf(SdkLengthAwareInputStream.class);
141+
}
142+
} else {
143+
assertThat(capturedRequest.contentStreamProvider()).isEmpty();
144+
}
145+
}
146+
147+
private static Stream<Arguments> contentLengthVerificationInputs() {
148+
return Stream.of(
149+
Arguments.of(
150+
"Provider present, ContentLength present",
151+
(ContentStreamProvider) () -> new ByteArrayInputStream(new byte[16]),
152+
16L,
153+
true
154+
),
155+
Arguments.of(
156+
"Provider present, ContentLength not present",
157+
(ContentStreamProvider) () -> new ByteArrayInputStream(new byte[16]),
158+
null,
159+
false
160+
),
161+
Arguments.of(
162+
"Provider not present, ContentLength present",
163+
null,
164+
16L,
165+
false
166+
),
167+
Arguments.of(
168+
"Provider not present, ContentLength not present",
169+
null,
170+
null,
171+
false
172+
)
173+
);
174+
}
175+
176+
private static RequestExecutionContext createContext() {
177+
ExecutionContext executionContext = ExecutionContext.builder()
178+
.executionAttributes(new ExecutionAttributes())
179+
.build();
180+
181+
RequestExecutionContext context = RequestExecutionContext.builder()
182+
.originalRequest(ValidSdkObjects.sdkRequest())
183+
.executionContext(executionContext)
184+
.build();
185+
186+
context.apiCallAttemptTimeoutTracker(mock(TimeoutTracker.class));
187+
context.apiCallTimeoutTracker(mock(TimeoutTracker.class));
188+
189+
return context;
190+
}
97191
}

0 commit comments

Comments
 (0)