|
16 | 16 | package software.amazon.awssdk.core.internal.http.pipeline.stages;
|
17 | 17 |
|
18 | 18 | import static org.assertj.core.api.Assertions.assertThat;
|
| 19 | +import static org.assertj.core.api.Assertions.assertThatThrownBy; |
19 | 20 | import static org.mockito.ArgumentMatchers.any;
|
20 | 21 | import static org.mockito.ArgumentMatchers.eq;
|
21 | 22 | import static org.mockito.Mockito.mock;
|
22 | 23 | import static org.mockito.Mockito.verify;
|
23 | 24 | import static org.mockito.Mockito.when;
|
24 | 25 | import static software.amazon.awssdk.core.client.config.SdkClientOption.SYNC_HTTP_CLIENT;
|
| 26 | + |
| 27 | +import java.io.ByteArrayInputStream; |
25 | 28 | import java.io.IOException;
|
26 |
| -import org.junit.Before; |
27 |
| -import org.junit.Test; |
28 |
| -import org.junit.runner.RunWith; |
| 29 | +import java.io.InputStream; |
| 30 | +import java.util.stream.Stream; |
| 31 | +import org.junit.jupiter.api.BeforeEach; |
| 32 | +import org.junit.jupiter.api.Test; |
| 33 | +import org.junit.jupiter.params.ParameterizedTest; |
| 34 | +import org.junit.jupiter.params.provider.Arguments; |
| 35 | +import org.junit.jupiter.params.provider.MethodSource; |
29 | 36 | import org.mockito.ArgumentCaptor;
|
30 |
| -import org.mockito.Mock; |
31 |
| -import org.mockito.junit.MockitoJUnitRunner; |
32 | 37 | import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
|
33 | 38 | import software.amazon.awssdk.core.http.ExecutionContext;
|
34 | 39 | import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
|
35 | 40 | import software.amazon.awssdk.core.internal.http.HttpClientDependencies;
|
36 | 41 | import software.amazon.awssdk.core.internal.http.RequestExecutionContext;
|
37 | 42 | import software.amazon.awssdk.core.internal.http.timers.TimeoutTracker;
|
| 43 | +import software.amazon.awssdk.core.internal.io.SdkLengthAwareInputStream; |
| 44 | +import software.amazon.awssdk.http.ContentStreamProvider; |
38 | 45 | import software.amazon.awssdk.http.HttpExecuteRequest;
|
39 | 46 | import software.amazon.awssdk.http.SdkHttpClient;
|
40 | 47 | import software.amazon.awssdk.http.SdkHttpFullRequest;
|
41 | 48 | import software.amazon.awssdk.http.SdkHttpMethod;
|
42 | 49 | import software.amazon.awssdk.metrics.MetricCollector;
|
43 | 50 | import utils.ValidSdkObjects;
|
44 | 51 |
|
45 |
| -@RunWith(MockitoJUnitRunner.class) |
46 | 52 | public class MakeHttpRequestStageTest {
|
47 | 53 |
|
48 |
| - @Mock |
49 | 54 | private SdkHttpClient mockClient;
|
50 | 55 |
|
51 | 56 | private MakeHttpRequestStage stage;
|
52 | 57 |
|
53 |
| - @Before |
| 58 | + @BeforeEach |
54 | 59 | public void setup() throws IOException {
|
| 60 | + mockClient = mock(SdkHttpClient.class); |
55 | 61 | SdkClientConfiguration config = SdkClientConfiguration.builder().option(SYNC_HTTP_CLIENT, mockClient).build();
|
56 | 62 | stage = new MakeHttpRequestStage(HttpClientDependencies.builder().clientConfiguration(config).build());
|
57 | 63 | }
|
@@ -94,4 +100,92 @@ public void testExecute_contextContainsMetricCollector_addsChildToExecuteRequest
|
94 | 100 | assertThat(httpRequestCaptor.getValue().metricCollector()).contains(childCollector);
|
95 | 101 | }
|
96 | 102 | }
|
| 103 | + |
| 104 | + @ParameterizedTest |
| 105 | + @MethodSource("contentLengthVerificationInputs") |
| 106 | + public void execute_testLengthChecking(String description, |
| 107 | + ContentStreamProvider provider, |
| 108 | + Long contentLength, |
| 109 | + boolean expectLengthAware) { |
| 110 | + SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder() |
| 111 | + .method(SdkHttpMethod.PUT) |
| 112 | + .host("mybucket.s3.us-west-2.amazonaws.com") |
| 113 | + .protocol("https"); |
| 114 | + |
| 115 | + if (provider != null) { |
| 116 | + requestBuilder.contentStreamProvider(provider); |
| 117 | + } |
| 118 | + |
| 119 | + if (contentLength != null) { |
| 120 | + requestBuilder.putHeader("Content-Length", String.valueOf(contentLength)); |
| 121 | + } |
| 122 | + |
| 123 | + when(mockClient.prepareRequest(any())) |
| 124 | + .thenThrow(new RuntimeException("BOOM")); |
| 125 | + |
| 126 | + assertThatThrownBy(() -> stage.execute(requestBuilder.build(), createContext())).hasMessage("BOOM"); |
| 127 | + |
| 128 | + ArgumentCaptor<HttpExecuteRequest> requestCaptor = ArgumentCaptor.forClass(HttpExecuteRequest.class); |
| 129 | + |
| 130 | + verify(mockClient).prepareRequest(requestCaptor.capture()); |
| 131 | + |
| 132 | + HttpExecuteRequest capturedRequest = requestCaptor.getValue(); |
| 133 | + |
| 134 | + if (provider != null) { |
| 135 | + InputStream requestContentStream = capturedRequest.contentStreamProvider().get().newStream(); |
| 136 | + |
| 137 | + if (expectLengthAware) { |
| 138 | + assertThat(requestContentStream).isInstanceOf(SdkLengthAwareInputStream.class); |
| 139 | + } else { |
| 140 | + assertThat(requestContentStream).isNotInstanceOf(SdkLengthAwareInputStream.class); |
| 141 | + } |
| 142 | + } else { |
| 143 | + assertThat(capturedRequest.contentStreamProvider()).isEmpty(); |
| 144 | + } |
| 145 | + } |
| 146 | + |
| 147 | + private static Stream<Arguments> contentLengthVerificationInputs() { |
| 148 | + return Stream.of( |
| 149 | + Arguments.of( |
| 150 | + "Provider present, ContentLength present", |
| 151 | + (ContentStreamProvider) () -> new ByteArrayInputStream(new byte[16]), |
| 152 | + 16L, |
| 153 | + true |
| 154 | + ), |
| 155 | + Arguments.of( |
| 156 | + "Provider present, ContentLength not present", |
| 157 | + (ContentStreamProvider) () -> new ByteArrayInputStream(new byte[16]), |
| 158 | + null, |
| 159 | + false |
| 160 | + ), |
| 161 | + Arguments.of( |
| 162 | + "Provider not present, ContentLength present", |
| 163 | + null, |
| 164 | + 16L, |
| 165 | + false |
| 166 | + ), |
| 167 | + Arguments.of( |
| 168 | + "Provider not present, ContentLength not present", |
| 169 | + null, |
| 170 | + null, |
| 171 | + false |
| 172 | + ) |
| 173 | + ); |
| 174 | + } |
| 175 | + |
| 176 | + private static RequestExecutionContext createContext() { |
| 177 | + ExecutionContext executionContext = ExecutionContext.builder() |
| 178 | + .executionAttributes(new ExecutionAttributes()) |
| 179 | + .build(); |
| 180 | + |
| 181 | + RequestExecutionContext context = RequestExecutionContext.builder() |
| 182 | + .originalRequest(ValidSdkObjects.sdkRequest()) |
| 183 | + .executionContext(executionContext) |
| 184 | + .build(); |
| 185 | + |
| 186 | + context.apiCallAttemptTimeoutTracker(mock(TimeoutTracker.class)); |
| 187 | + context.apiCallTimeoutTracker(mock(TimeoutTracker.class)); |
| 188 | + |
| 189 | + return context; |
| 190 | + } |
97 | 191 | }
|
0 commit comments