Skip to content

Commit 3b72d58

Browse files
committed
Allow the request URL to be used for subsequent responses
1 parent da28771 commit 3b72d58

File tree

1 file changed

+77
-4
lines changed

1 file changed

+77
-4
lines changed

src/lib.rs

+77-4
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,15 @@ pub enum CheckSupportMethod {
131131
Head,
132132
}
133133

134+
/// Which URL should be used for subsequent range requests?
135+
pub enum RangeRequestUrlSource {
136+
/// Use the initial request URL
137+
Request,
138+
139+
/// Use the initial response URL
140+
Response,
141+
}
142+
134143
fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
135144
response
136145
.error_for_status()
@@ -143,6 +152,7 @@ impl AsyncHttpRangeReader {
143152
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
144153
url: reqwest::Url,
145154
check_method: CheckSupportMethod,
155+
range_request_url_source: RangeRequestUrlSource,
146156
extra_headers: HeaderMap,
147157
) -> Result<(Self, HeaderMap), AsyncHttpRangeReaderError> {
148158
let client = client.into();
@@ -156,15 +166,23 @@ impl AsyncHttpRangeReader {
156166
)
157167
.await?;
158168
let response_headers = response.headers().clone();
159-
let self_ = Self::from_tail_response(client, response, extra_headers).await?;
169+
let url = match range_request_url_source {
170+
RangeRequestUrlSource::Request => url,
171+
RangeRequestUrlSource::Response => response.url().clone(),
172+
};
173+
let self_ = Self::from_tail_response(client, response, url, extra_headers).await?;
160174
Ok((self_, response_headers))
161175
}
162176
CheckSupportMethod::Head => {
163177
let response =
164178
Self::initial_head_request(client.clone(), url.clone(), HeaderMap::default())
165179
.await?;
166180
let response_headers = response.headers().clone();
167-
let self_ = Self::from_head_response(client, response, extra_headers).await?;
181+
let url = match range_request_url_source {
182+
RangeRequestUrlSource::Request => url,
183+
RangeRequestUrlSource::Response => response.url().clone(),
184+
};
185+
let self_ = Self::from_head_response(client, response, url, extra_headers).await?;
168186
Ok((self_, response_headers))
169187
}
170188
}
@@ -200,6 +218,7 @@ impl AsyncHttpRangeReader {
200218
pub async fn from_tail_response(
201219
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
202220
tail_request_response: Response,
221+
url: Url,
203222
extra_headers: HeaderMap,
204223
) -> Result<Self, AsyncHttpRangeReaderError> {
205224
let client = client.into();
@@ -245,7 +264,7 @@ impl AsyncHttpRangeReader {
245264
let (state_tx, state_rx) = watch::channel(StreamerState::default());
246265
tokio::spawn(run_streamer(
247266
client,
248-
tail_request_response.url().clone(),
267+
url,
249268
extra_headers,
250269
Some((tail_request_response, start)),
251270
memory_map,
@@ -300,6 +319,7 @@ impl AsyncHttpRangeReader {
300319
pub async fn from_head_response(
301320
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
302321
head_response: Response,
322+
url: Url,
303323
extra_headers: HeaderMap,
304324
) -> Result<Self, AsyncHttpRangeReaderError> {
305325
let client = client.into();
@@ -345,7 +365,7 @@ impl AsyncHttpRangeReader {
345365
let (state_tx, state_rx) = watch::channel(StreamerState::default());
346366
tokio::spawn(run_streamer(
347367
client,
348-
head_response.url().clone(),
368+
url,
349369
extra_headers,
350370
None,
351371
memory_map,
@@ -688,6 +708,7 @@ mod test {
688708
Client::new(),
689709
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
690710
check_method,
711+
RangeRequestUrlSource::Response,
691712
HeaderMap::default(),
692713
)
693714
.await
@@ -783,6 +804,57 @@ mod test {
783804
Client::new(),
784805
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
785806
check_method,
807+
RangeRequestUrlSource::Response,
808+
HeaderMap::default(),
809+
)
810+
.await
811+
.expect("bla");
812+
813+
// Also open a simple file reader
814+
let mut file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda"))
815+
.await
816+
.unwrap();
817+
818+
// Read until the end and make sure that the contents matches
819+
let mut range_read = vec![0; 64 * 1024];
820+
let mut file_read = vec![0; 64 * 1024];
821+
loop {
822+
// Read with the async reader
823+
let range_read_bytes = range.read(&mut range_read).await.unwrap();
824+
825+
// Read directly from the file
826+
let file_read_bytes = file
827+
.read_exact(&mut file_read[0..range_read_bytes])
828+
.await
829+
.unwrap();
830+
831+
assert_eq!(range_read_bytes, file_read_bytes);
832+
assert_eq!(
833+
range_read[0..range_read_bytes],
834+
file_read[0..file_read_bytes]
835+
);
836+
837+
if file_read_bytes == 0 && range_read_bytes == 0 {
838+
break;
839+
}
840+
}
841+
}
842+
843+
#[rstest]
844+
#[case(RangeRequestUrlSource::Request)]
845+
#[case(RangeRequestUrlSource::Response)]
846+
#[tokio::test]
847+
async fn async_range_reader_url_source(#[case] url_source: RangeRequestUrlSource) {
848+
// Spawn a static file server
849+
let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data");
850+
let server = StaticDirectoryServer::new(&path);
851+
852+
// Construct an AsyncRangeReader
853+
let (mut range, _) = AsyncHttpRangeReader::new(
854+
Client::new(),
855+
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
856+
CheckSupportMethod::Head,
857+
url_source,
786858
HeaderMap::default(),
787859
)
788860
.await
@@ -825,6 +897,7 @@ mod test {
825897
Client::new(),
826898
server.url().join("not-found").unwrap(),
827899
CheckSupportMethod::Head,
900+
RangeRequestUrlSource::Response,
828901
HeaderMap::default(),
829902
)
830903
.await

0 commit comments

Comments
 (0)