diff --git a/src/error.rs b/src/error.rs index 0609817..14b60af 100644 --- a/src/error.rs +++ b/src/error.rs @@ -30,6 +30,10 @@ pub enum AsyncHttpRangeReaderError { /// Memory mapping the file failed #[error("memory mapping the file failed")] MemoryMapError(#[source] Arc), + + /// Error building the reader + #[error("error building the reader: {0}")] + BuilderError(#[source] Arc), } impl From for AsyncHttpRangeReaderError { @@ -49,3 +53,25 @@ impl From for AsyncHttpRangeReaderError { AsyncHttpRangeReaderError::TransportError(Arc::new(err.into())) } } + +impl From for AsyncHttpRangeReaderError { + fn from(err: AsyncHttpRangeReaderBuilderError) -> Self { + AsyncHttpRangeReaderError::BuilderError(Arc::new(err)) + } +} + +/// Error type used for [`crate::AsyncHttpRangeReaderBuilder`] +#[derive(Clone, Debug, thiserror::Error)] +pub enum AsyncHttpRangeReaderBuilderError { + /// Required field 'content_length' is zero + #[error("required field 'content_length' is zero")] + InvalidContentLength, + + /// Required field 'url' is missing + #[error("required field 'url' is missing")] + MissingUrl, + + /// Memory mapping the file failed + #[error("memory mapping the file failed")] + MemoryMapError(#[source] Arc), +} diff --git a/src/lib.rs b/src/lib.rs index 786f454..bfc94bb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,7 @@ mod error; mod sparse_range; +use error::AsyncHttpRangeReaderBuilderError; use futures::{FutureExt, Stream, StreamExt}; use http_content_range::{ContentRange, ContentRangeBytes}; use memmap2::MmapMut; @@ -72,7 +73,7 @@ pub use error::AsyncHttpRangeReaderError; /// if response.status() == reqwest::StatusCode::NOT_MODIFIED { /// Ok(None) /// } else { -/// let reader = AsyncHttpRangeReader::from_head_response(client, response, HeaderMap::default()).await?; +/// let reader = AsyncHttpRangeReader::from_head_response(client, response, HeaderMap::default())?; /// Ok(Some(reader)) /// } /// } @@ -156,7 +157,7 @@ impl AsyncHttpRangeReader { ) .await?; let response_headers = response.headers().clone(); - let self_ = Self::from_tail_response(client, response, extra_headers).await?; + let self_ = Self::from_tail_response(client, response, extra_headers)?; Ok((self_, response_headers)) } CheckSupportMethod::Head => { @@ -164,12 +165,19 @@ impl AsyncHttpRangeReader { Self::initial_head_request(client.clone(), url.clone(), HeaderMap::default()) .await?; let response_headers = response.headers().clone(); - let self_ = Self::from_head_response(client, response, extra_headers).await?; + let self_ = Self::from_head_response(client, response, extra_headers)?; Ok((self_, response_headers)) } } } + // Make a builder for AsyncHttpRangeReader + pub fn builder( + client: reqwest_middleware::ClientWithMiddleware, + ) -> AsyncHttpRangeReaderBuilder { + AsyncHttpRangeReaderBuilder::new(client) + } + /// Send an initial range request to determine if the remote accepts range /// requests. This will return a number of bytes from the end of the stream. Use the /// `initial_chunk_size` parameter to define how many bytes should be requested from the end. @@ -197,81 +205,16 @@ impl AsyncHttpRangeReader { /// Initialize the reader from [`AsyncHttpRangeReader::initial_tail_request`] (or a user /// provided response that also has a range of bytes from the end as body) - pub async fn from_tail_response( + pub fn from_tail_response( client: impl Into, tail_request_response: Response, extra_headers: HeaderMap, ) -> Result { - let client = client.into(); - - // Get the size of the file from this initial request - let content_range = ContentRange::parse( - tail_request_response - .headers() - .get(reqwest::header::CONTENT_RANGE) - .ok_or(AsyncHttpRangeReaderError::ContentRangeMissing)? - .to_str() - .map_err(|_err| AsyncHttpRangeReaderError::ContentRangeMissing)?, - ); - let (start, finish, complete_length) = match content_range { - ContentRange::Bytes(ContentRangeBytes { - first_byte, - last_byte, - complete_length, - }) => (first_byte, last_byte, complete_length), - _ => return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported), - }; - - // Allocate a memory map to hold the data - let memory_map = memmap2::MmapOptions::new() - .len(complete_length as usize) - .map_anon() - .map_err(Arc::new) - .map_err(AsyncHttpRangeReaderError::MemoryMapError)?; - - // SAFETY: Get a read-only slice to the memory. This is safe because the memory map is never - // reallocated and we keep track of the initialized part. - let memory_map_slice = - unsafe { std::slice::from_raw_parts(memory_map.as_ptr(), memory_map.len()) }; - - let requested_range = - SparseRange::from_range(complete_length - (finish - start)..complete_length); - - // adding more than 2 entries to the channel would block the sender. I assumed two would - // suffice because I would want to 1) prefetch a certain range and 2) read stuff via the - // AsyncRead implementation. Any extra would simply have to wait for one of these to - // succeed. I eventually used 10 because who cares. - let (request_tx, request_rx) = tokio::sync::mpsc::channel(10); - let (state_tx, state_rx) = watch::channel(StreamerState::default()); - tokio::spawn(run_streamer( - client, - tail_request_response.url().clone(), - extra_headers, - Some((tail_request_response, start)), - memory_map, - state_tx, - request_rx, - )); - - // Configure the initial state of the streamer. - let mut streamer_state = StreamerState::default(); - streamer_state - .requested_ranges - .push(complete_length - (finish - start)..complete_length); - - let reader = Self { - len: memory_map_slice.len() as u64, - inner: Mutex::new(Inner { - data: memory_map_slice, - pos: 0, - requested_range, - streamer_state, - streamer_state_rx: WatchStream::new(state_rx), - request_tx, - poll_request_tx: None, - }), - }; - Ok(reader) + AsyncHttpRangeReaderBuilder::new(client.into()) + .from_tail_response(tail_request_response)? + .extra_headers(extra_headers) + .build() + .map_err(AsyncHttpRangeReaderError::from) } /// Send an initial range request to determine if the remote accepts range @@ -297,13 +240,110 @@ impl AsyncHttpRangeReader { /// Initialize the reader from [`AsyncHttpRangeReader::initial_head_request`] (or a user /// provided response the) - pub async fn from_head_response( + pub fn from_head_response( client: impl Into, head_response: Response, extra_headers: HeaderMap, ) -> Result { - let client = client.into(); + AsyncHttpRangeReaderBuilder::new(client.into()) + .from_head_response(head_response)? + .extra_headers(extra_headers) + .build() + .map_err(AsyncHttpRangeReaderError::from) + } + + /// Returns the ranges that this instance actually performed HTTP requests for. + pub async fn requested_ranges(&self) -> Vec> { + let mut inner = self.inner.lock().await; + if let Some(Some(new_state)) = inner.streamer_state_rx.next().now_or_never() { + inner.streamer_state = new_state; + } + inner.streamer_state.requested_ranges.clone() + } + + /// Prefetches a range of bytes from the remote. When specifying a large range this can + /// drastically reduce the number of requests required to the server. + pub async fn prefetch(&mut self, bytes: Range) { + let inner = self.inner.get_mut(); + + // Ensure the range is withing the file size and non-zero of length. + let range = bytes.start..(bytes.end.min(inner.data.len() as u64)); + if range.start >= range.end { + return; + } + + // Check if the range has been requested or not. + let inner = self.inner.get_mut(); + if let Some((new_range, _)) = inner.requested_range.cover(range.clone()) { + let _ = inner.request_tx.send(range).await; + inner.requested_range = new_range; + } + } + + /// Returns the length of the stream in bytes + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> u64 { + self.len + } +} + +pub struct AsyncHttpRangeReaderBuilder { + client: reqwest_middleware::ClientWithMiddleware, + url: Option, + extra_headers: HeaderMap, + requested_range: SparseRange, + streamer_state: StreamerState, + initial_tail_response: Option<(Response, u64)>, + content_length: usize, +} + +impl AsyncHttpRangeReaderBuilder { + pub fn new(client: reqwest_middleware::ClientWithMiddleware) -> Self { + Self { + client, + url: None, + extra_headers: HeaderMap::default(), + requested_range: SparseRange::default(), + streamer_state: StreamerState::default(), + initial_tail_response: None, + content_length: 0, + } + } + + pub fn url(mut self, url: Url) -> Self { + self.url = Some(url); + self + } + + pub fn extra_headers(mut self, extra_headers: HeaderMap) -> Self { + self.extra_headers = extra_headers; + self + } + + fn requested_range(mut self, requested_range: SparseRange) -> Self { + self.requested_range = requested_range; + self + } + + fn streamer_state(mut self, streamer_state: StreamerState) -> Self { + self.streamer_state = streamer_state; + self + } + + fn initial_tail_response(mut self, initial_tail_response: Option<(Response, u64)>) -> Self { + self.initial_tail_response = initial_tail_response; + self + } + + pub fn content_length(mut self, content_length: usize) -> Self { + self.content_length = content_length; + self + } + pub fn from_head_response( + self, + head_response: Response, + ) -> Result { // Are range requests supported? if head_response .headers() @@ -323,46 +363,95 @@ impl AsyncHttpRangeReader { .parse() .map_err(|_err| AsyncHttpRangeReaderError::ContentLengthMissing)?; - // Allocate a memory map to hold the data + let builder = self + .url(head_response.url().clone()) + .content_length(content_length as usize); + + Ok(builder) + } + + pub fn from_tail_response( + self, + tail_response: Response, + ) -> Result { + let content_range = ContentRange::parse( + tail_response + .headers() + .get(reqwest::header::CONTENT_RANGE) + .ok_or(AsyncHttpRangeReaderError::ContentRangeMissing)? + .to_str() + .map_err(|_err| AsyncHttpRangeReaderError::ContentRangeMissing)?, + ); + let (start, finish, complete_length) = match content_range { + ContentRange::Bytes(ContentRangeBytes { + first_byte, + last_byte, + complete_length, + }) => (first_byte, last_byte, complete_length), + _ => return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported), + }; + + let requested_range = + SparseRange::from_range(complete_length - (finish - start)..complete_length); + + // Configure the initial state of the streamer. + let mut streamer_state = StreamerState::default(); + streamer_state + .requested_ranges + .push(complete_length - (finish - start)..complete_length); + let builder = self + .url(tail_response.url().clone()) + .initial_tail_response(Some((tail_response, start))) + .requested_range(requested_range) + .content_length(complete_length as usize) + .streamer_state(streamer_state); + Ok(builder) + } + + pub fn build(self) -> Result { + let Some(url) = self.url else { + return Err(AsyncHttpRangeReaderBuilderError::MissingUrl); + }; + + if self.content_length == 0 { + return Err(AsyncHttpRangeReaderBuilderError::InvalidContentLength); + } + let memory_map = memmap2::MmapOptions::new() - .len(content_length as _) + .len(self.content_length) .map_anon() .map_err(Arc::new) - .map_err(AsyncHttpRangeReaderError::MemoryMapError)?; + .map_err(AsyncHttpRangeReaderBuilderError::MemoryMapError)?; // SAFETY: Get a read-only slice to the memory. This is safe because the memory map is never // reallocated and we keep track of the initialized part. let memory_map_slice = unsafe { std::slice::from_raw_parts(memory_map.as_ptr(), memory_map.len()) }; - let requested_range = SparseRange::default(); - // adding more than 2 entries to the channel would block the sender. I assumed two would // suffice because I would want to 1) prefetch a certain range and 2) read stuff via the // AsyncRead implementation. Any extra would simply have to wait for one of these to // succeed. I eventually used 10 because who cares. let (request_tx, request_rx) = tokio::sync::mpsc::channel(10); let (state_tx, state_rx) = watch::channel(StreamerState::default()); + tokio::spawn(run_streamer( - client, - head_response.url().clone(), - extra_headers, - None, + self.client, + url.clone(), + self.extra_headers, + self.initial_tail_response, memory_map, state_tx, request_rx, )); - // Configure the initial state of the streamer. - let streamer_state = StreamerState::default(); - - let reader = Self { + let reader = AsyncHttpRangeReader { len: memory_map_slice.len() as u64, inner: Mutex::new(Inner { data: memory_map_slice, pos: 0, - requested_range, - streamer_state, + requested_range: self.requested_range, + streamer_state: self.streamer_state, streamer_state_rx: WatchStream::new(state_rx), request_tx, poll_request_tx: None, @@ -370,40 +459,6 @@ impl AsyncHttpRangeReader { }; Ok(reader) } - - /// Returns the ranges that this instance actually performed HTTP requests for. - pub async fn requested_ranges(&self) -> Vec> { - let mut inner = self.inner.lock().await; - if let Some(Some(new_state)) = inner.streamer_state_rx.next().now_or_never() { - inner.streamer_state = new_state; - } - inner.streamer_state.requested_ranges.clone() - } - - /// Prefetches a range of bytes from the remote. When specifying a large range this can - /// drastically reduce the number of requests required to the server. - pub async fn prefetch(&mut self, bytes: Range) { - let inner = self.inner.get_mut(); - - // Ensure the range is withing the file size and non-zero of length. - let range = bytes.start..(bytes.end.min(inner.data.len() as u64)); - if range.start >= range.end { - return; - } - - // Check if the range has been requested or not. - let inner = self.inner.get_mut(); - if let Some((new_range, _)) = inner.requested_range.cover(range.clone()) { - let _ = inner.request_tx.send(range).await; - inner.requested_range = new_range; - } - } - - /// Returns the length of the stream in bytes - #[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> u64 { - self.len - } } /// A task that will download parts from the remote archive and "send" them to the frontend as they @@ -659,7 +714,7 @@ mod test { use reqwest::{Client, StatusCode}; use rstest::*; use std::path::Path; - use tokio::io::AsyncReadExt as _; + use tokio::{fs::File, io::AsyncReadExt as _}; use tokio_util::compat::TokioAsyncReadCompatExt; #[rstest] @@ -783,7 +838,7 @@ mod test { .expect("could not initialize server"); // Construct an AsyncRangeReader - let (mut range, _) = AsyncHttpRangeReader::new( + let (range, _) = AsyncHttpRangeReader::new( Client::new(), server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(), check_method, @@ -793,10 +848,14 @@ mod test { .expect("bla"); // Also open a simple file reader - let mut file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda")) + let file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda")) .await .unwrap(); + assert_range_and_file_contents_match(range, file).await; + } + + async fn assert_range_and_file_contents_match(mut range: AsyncHttpRangeReader, mut file: File) { // Read until the end and make sure that the contents matches let mut range_read = vec![0; 64 * 1024]; let mut file_read = vec![0; 64 * 1024]; @@ -840,4 +899,69 @@ mod test { err, AsyncHttpRangeReaderError::HttpError(err) if err.status() == Some(StatusCode::NOT_FOUND) ); } + + #[tokio::test] + async fn test_builder_happy_path() { + let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data"); + let server = StaticDirectoryServer::new(&path) + .await + .expect("could not initialize server"); + + let url = server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(); + + // First, let's try to build the reader from a head response + let head_response = AsyncHttpRangeReader::initial_head_request( + Client::new(), + url.clone(), + HeaderMap::default(), + ) + .await + .expect("could not perform head request"); + + let builder = AsyncHttpRangeReader::builder(Client::new().into()) + .from_head_response(head_response) + .expect("could not build reader from head response") + .build() + .expect("could not build reader"); + let file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda")) + .await + .unwrap(); + + // And let's make sure that the contents match the file + assert_range_and_file_contents_match(builder, file).await; + + // Now let's try to build the reader from a tail response + let tail_response = AsyncHttpRangeReader::initial_tail_request( + Client::new(), + url.clone(), + 8192, + HeaderMap::default(), + ) + .await + .expect("could not perform tail request"); + + let builder = AsyncHttpRangeReader::builder(Client::new().into()) + .from_tail_response(tail_response) + .expect("could not build reader from tail response") + .build() + .expect("could not build reader"); + let file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda")) + .await + .unwrap(); + + // And again let's make sure that the contents match the file + assert_range_and_file_contents_match(builder, file).await; + } + + #[test] + fn test_builder_fails_on_missing_content_length() { + let url = Url::parse("http://localhost").unwrap(); + let result = AsyncHttpRangeReader::builder(Client::new().into()) + .url(url) + .build(); + assert_matches!( + result, + Err(AsyncHttpRangeReaderBuilderError::InvalidContentLength) + ); + } }