From a67b84bd1958084befcf9a003795b5a16be5c67c Mon Sep 17 00:00:00 2001 From: Marcus Griep Date: Fri, 20 May 2022 10:21:26 -0400 Subject: [PATCH 1/3] feat: add `Limited` body --- src/lib.rs | 2 + src/limited.rs | 280 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 282 insertions(+) create mode 100644 src/limited.rs diff --git a/src/lib.rs b/src/lib.rs index b63924f..9ffc8ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,7 @@ mod empty; mod full; +mod limited; mod next; mod size_hint; @@ -22,6 +23,7 @@ pub mod combinators; pub use self::empty::Empty; pub use self::full::Full; +pub use self::limited::{LengthLimitError, Limited}; pub use self::next::{Data, Trailers}; pub use self::size_hint::SizeHint; diff --git a/src/limited.rs b/src/limited.rs new file mode 100644 index 0000000..5a1dd3a --- /dev/null +++ b/src/limited.rs @@ -0,0 +1,280 @@ +use crate::{Body, SizeHint}; +use bytes::Buf; +use http::HeaderMap; +use std::error::Error; +use std::fmt; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A length limited body. +/// +/// This body will return an error if more than `N` bytes are returned +/// on polling the wrapped body. +#[derive(Clone, Copy, Debug)] +pub struct Limited { + remaining: usize, + inner: B, +} + +impl Limited { + /// Create a new `Limited`. + pub fn new(inner: B) -> Limited { + Limited { + remaining: N, + inner, + } + } +} + +impl Default for Limited +where + B: Default, +{ + fn default() -> Self { + Limited::new(B::default()) + } +} + +impl Body for Limited +where + B: Body + Unpin, +{ + type Data = B::Data; + type Error = LengthLimitError; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let mut this = self; + let res = match Pin::new(&mut this.inner).poll_data(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => None, + Poll::Ready(Some(Ok(data))) => { + if data.remaining() > this.remaining { + this.remaining = 0; + Some(Err(LengthLimitError::LengthLimitExceeded)) + } else { + this.remaining -= data.remaining(); + Some(Ok(data)) + } + } + Poll::Ready(Some(Err(err))) => Some(Err(LengthLimitError::Other(err))), + }; + + Poll::Ready(res) + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + let mut this = self; + let res = match Pin::new(&mut this.inner).poll_trailers(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(data)) => Ok(data), + Poll::Ready(Err(err)) => Err(LengthLimitError::Other(err)), + }; + + Poll::Ready(res) + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> SizeHint { + use std::convert::TryFrom; + match u64::try_from(N) { + Ok(n) => { + let mut hint = self.inner.size_hint(); + if hint.lower() >= n { + hint.set_exact(n) + } else if let Some(max) = hint.upper() { + hint.set_upper(n.min(max)) + } else { + hint.set_upper(n) + } + hint + } + Err(_) => self.inner.size_hint(), + } + } +} + +/// An error returned when reading from a [`Limited`] body. +#[derive(Debug)] +pub enum LengthLimitError { + /// The body exceeded the length limit. + LengthLimitExceeded, + /// Some other error was encountered while reading from the underlying body. + Other(E), +} + +impl fmt::Display for LengthLimitError +where + E: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::LengthLimitExceeded => f.write_str("length limit exceeded"), + Self::Other(err) => err.fmt(f), + } + } +} + +impl Error for LengthLimitError +where + E: Error, +{ + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Self::LengthLimitExceeded => None, + Self::Other(err) => err.source(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Full; + use bytes::Bytes; + use std::convert::Infallible; + + #[tokio::test] + async fn read_for_body_under_limit_returns_data() { + const DATA: &[u8] = b"testing"; + let inner = Full::new(Bytes::from(DATA)); + let body = &mut Limited::new::<8>(inner); + let data = body.data().await.unwrap().unwrap(); + assert_eq!(data, DATA); + assert!(matches!(body.data().await, None)); + } + + #[tokio::test] + async fn read_for_body_over_limit_returns_error() { + const DATA: &[u8] = b"testing a string that is too long"; + let inner = Full::new(Bytes::from(DATA)); + let body = &mut Limited::new::<8>(inner); + let error = body.data().await.unwrap().unwrap_err(); + assert!(matches!(error, LengthLimitError::LengthLimitExceeded)); + } + + struct Chunky(&'static [&'static [u8]]); + + impl Body for Chunky { + type Data = &'static [u8]; + type Error = Infallible; + + fn poll_data( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + let mut this = self; + match this.0.split_first().map(|(&head, tail)| (Ok(head), tail)) { + Some((data, new_tail)) => { + this.0 = new_tail; + + Poll::Ready(Some(data)) + } + None => Poll::Ready(None), + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + Poll::Ready(Ok(Some(HeaderMap::new()))) + } + } + + #[tokio::test] + async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk( + ) { + const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"]; + let inner = Chunky(DATA); + let body = &mut Limited::new::<8>(inner); + let data = body.data().await.unwrap().unwrap(); + assert_eq!(data, DATA[0]); + let error = body.data().await.unwrap().unwrap_err(); + assert!(matches!(error, LengthLimitError::LengthLimitExceeded)); + } + + #[tokio::test] + async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() { + const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"]; + let inner = Chunky(DATA); + let body = &mut Limited::new::<8>(inner); + let error = body.data().await.unwrap().unwrap_err(); + assert!(matches!(error, LengthLimitError::LengthLimitExceeded)); + } + + #[tokio::test] + async fn read_for_chunked_body_under_limit_is_okay() { + const DATA: &[&[u8]] = &[b"test", b"ing!"]; + let inner = Chunky(DATA); + let body = &mut Limited::new::<8>(inner); + let data = body.data().await.unwrap().unwrap(); + assert_eq!(data, DATA[0]); + let data = body.data().await.unwrap().unwrap(); + assert_eq!(data, DATA[1]); + assert!(matches!(body.data().await, None)); + } + + #[tokio::test] + async fn read_for_trailers_propagates_inner_trailers() { + const DATA: &[&[u8]] = &[b"test", b"ing!"]; + let inner = Chunky(DATA); + let body = &mut Limited::new::<8>(inner); + let trailers = body.trailers().await.unwrap(); + assert_eq!(trailers, Some(HeaderMap::new())) + } + + enum ErrorBodyError { + Data, + Trailers, + } + + struct ErrorBody; + + impl Body for ErrorBody { + type Data = &'static [u8]; + type Error = ErrorBodyError; + + fn poll_data( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + Poll::Ready(Some(Err(ErrorBodyError::Data))) + } + + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + Poll::Ready(Err(ErrorBodyError::Trailers)) + } + } + + #[tokio::test] + async fn read_for_body_returning_error_propagates_error() { + let body = &mut Limited::new::<8>(ErrorBody); + let error = body.data().await.unwrap().unwrap_err(); + assert!(matches!( + error, + LengthLimitError::Other(ErrorBodyError::Data) + )); + } + + #[tokio::test] + async fn trailers_for_body_returning_error_propagates_error() { + let body = &mut Limited::new::<8>(ErrorBody); + let error = body.trailers().await.unwrap_err(); + assert!(matches!( + error, + LengthLimitError::Other(ErrorBodyError::Trailers) + )); + } +} From 909645ad7d5026a0ece678ae602c828ed05058ed Mon Sep 17 00:00:00 2001 From: Marcus Griep Date: Fri, 20 May 2022 11:07:04 -0400 Subject: [PATCH 2/3] fix: correct size_hint, remove const generic --- src/limited.rs | 80 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 26 deletions(-) diff --git a/src/limited.rs b/src/limited.rs index 5a1dd3a..75904ee 100644 --- a/src/limited.rs +++ b/src/limited.rs @@ -8,34 +8,25 @@ use std::task::{Context, Poll}; /// A length limited body. /// -/// This body will return an error if more than `N` bytes are returned -/// on polling the wrapped body. +/// This body will return an error if more than the configured number +/// of bytes are returned on polling the wrapped body. #[derive(Clone, Copy, Debug)] -pub struct Limited { +pub struct Limited { remaining: usize, inner: B, } -impl Limited { +impl Limited { /// Create a new `Limited`. - pub fn new(inner: B) -> Limited { - Limited { - remaining: N, + pub fn new(inner: B, limit: usize) -> Self { + Self { + remaining: limit, inner, } } } -impl Default for Limited -where - B: Default, -{ - fn default() -> Self { - Limited::new(B::default()) - } -} - -impl Body for Limited +impl Body for Limited where B: Body + Unpin, { @@ -85,7 +76,7 @@ where fn size_hint(&self) -> SizeHint { use std::convert::TryFrom; - match u64::try_from(N) { + match u64::try_from(self.remaining) { Ok(n) => { let mut hint = self.inner.size_hint(); if hint.lower() >= n { @@ -146,9 +137,17 @@ mod tests { async fn read_for_body_under_limit_returns_data() { const DATA: &[u8] = b"testing"; let inner = Full::new(Bytes::from(DATA)); - let body = &mut Limited::new::<8>(inner); + let body = &mut Limited::new(inner, 8); + + let mut hint = SizeHint::new(); + hint.set_upper(7); + assert_eq!(body.size_hint().upper(), hint.upper()); + let data = body.data().await.unwrap().unwrap(); assert_eq!(data, DATA); + hint.set_upper(0); + assert_eq!(body.size_hint().upper(), hint.upper()); + assert!(matches!(body.data().await, None)); } @@ -156,7 +155,12 @@ mod tests { async fn read_for_body_over_limit_returns_error() { const DATA: &[u8] = b"testing a string that is too long"; let inner = Full::new(Bytes::from(DATA)); - let body = &mut Limited::new::<8>(inner); + let body = &mut Limited::new(inner, 8); + + let mut hint = SizeHint::new(); + hint.set_upper(8); + assert_eq!(body.size_hint().upper(), hint.upper()); + let error = body.data().await.unwrap().unwrap_err(); assert!(matches!(error, LengthLimitError::LengthLimitExceeded)); } @@ -195,9 +199,17 @@ mod tests { ) { const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"]; let inner = Chunky(DATA); - let body = &mut Limited::new::<8>(inner); + let body = &mut Limited::new(inner, 8); + + let mut hint = SizeHint::new(); + hint.set_upper(8); + assert_eq!(body.size_hint().upper(), hint.upper()); + let data = body.data().await.unwrap().unwrap(); assert_eq!(data, DATA[0]); + hint.set_upper(0); + assert_eq!(body.size_hint().upper(), hint.upper()); + let error = body.data().await.unwrap().unwrap_err(); assert!(matches!(error, LengthLimitError::LengthLimitExceeded)); } @@ -206,7 +218,12 @@ mod tests { async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() { const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"]; let inner = Chunky(DATA); - let body = &mut Limited::new::<8>(inner); + let body = &mut Limited::new(inner, 8); + + let mut hint = SizeHint::new(); + hint.set_upper(8); + assert_eq!(body.size_hint().upper(), hint.upper()); + let error = body.data().await.unwrap().unwrap_err(); assert!(matches!(error, LengthLimitError::LengthLimitExceeded)); } @@ -215,11 +232,22 @@ mod tests { async fn read_for_chunked_body_under_limit_is_okay() { const DATA: &[&[u8]] = &[b"test", b"ing!"]; let inner = Chunky(DATA); - let body = &mut Limited::new::<8>(inner); + let body = &mut Limited::new(inner, 8); + + let mut hint = SizeHint::new(); + hint.set_upper(8); + assert_eq!(body.size_hint().upper(), hint.upper()); + let data = body.data().await.unwrap().unwrap(); assert_eq!(data, DATA[0]); + hint.set_upper(4); + assert_eq!(body.size_hint().upper(), hint.upper()); + let data = body.data().await.unwrap().unwrap(); assert_eq!(data, DATA[1]); + hint.set_upper(0); + assert_eq!(body.size_hint().upper(), hint.upper()); + assert!(matches!(body.data().await, None)); } @@ -227,7 +255,7 @@ mod tests { async fn read_for_trailers_propagates_inner_trailers() { const DATA: &[&[u8]] = &[b"test", b"ing!"]; let inner = Chunky(DATA); - let body = &mut Limited::new::<8>(inner); + let body = &mut Limited::new(inner, 8); let trailers = body.trailers().await.unwrap(); assert_eq!(trailers, Some(HeaderMap::new())) } @@ -260,7 +288,7 @@ mod tests { #[tokio::test] async fn read_for_body_returning_error_propagates_error() { - let body = &mut Limited::new::<8>(ErrorBody); + let body = &mut Limited::new(ErrorBody, 8); let error = body.data().await.unwrap().unwrap_err(); assert!(matches!( error, @@ -270,7 +298,7 @@ mod tests { #[tokio::test] async fn trailers_for_body_returning_error_propagates_error() { - let body = &mut Limited::new::<8>(ErrorBody); + let body = &mut Limited::new(ErrorBody, 8); let error = body.trailers().await.unwrap_err(); assert!(matches!( error, From 6573327147a13cd0b341673452722c241d737f07 Mon Sep 17 00:00:00 2001 From: Marcus Griep Date: Fri, 20 May 2022 15:31:15 -0400 Subject: [PATCH 3/3] chore: use boxed error, pin project Co-authored-by: Programatik --- src/limited.rs | 101 ++++++++++++++++++++++--------------------------- 1 file changed, 46 insertions(+), 55 deletions(-) diff --git a/src/limited.rs b/src/limited.rs index 75904ee..a40add9 100644 --- a/src/limited.rs +++ b/src/limited.rs @@ -1,19 +1,23 @@ use crate::{Body, SizeHint}; use bytes::Buf; use http::HeaderMap; +use pin_project_lite::pin_project; use std::error::Error; use std::fmt; use std::pin::Pin; use std::task::{Context, Poll}; -/// A length limited body. -/// -/// This body will return an error if more than the configured number -/// of bytes are returned on polling the wrapped body. -#[derive(Clone, Copy, Debug)] -pub struct Limited { - remaining: usize, - inner: B, +pin_project! { + /// A length limited body. + /// + /// This body will return an error if more than the configured number + /// of bytes are returned on polling the wrapped body. + #[derive(Clone, Copy, Debug)] + pub struct Limited { + remaining: usize, + #[pin] + inner: B, + } } impl Limited { @@ -28,29 +32,30 @@ impl Limited { impl Body for Limited where - B: Body + Unpin, + B: Body, + B::Error: Into>, { type Data = B::Data; - type Error = LengthLimitError; + type Error = Box; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - let mut this = self; - let res = match Pin::new(&mut this.inner).poll_data(cx) { + let this = self.project(); + let res = match this.inner.poll_data(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(None) => None, Poll::Ready(Some(Ok(data))) => { - if data.remaining() > this.remaining { - this.remaining = 0; - Some(Err(LengthLimitError::LengthLimitExceeded)) + if data.remaining() > *this.remaining { + *this.remaining = 0; + Some(Err(LengthLimitError.into())) } else { - this.remaining -= data.remaining(); + *this.remaining -= data.remaining(); Some(Ok(data)) } } - Poll::Ready(Some(Err(err))) => Some(Err(LengthLimitError::Other(err))), + Poll::Ready(Some(Err(err))) => Some(Err(err.into())), }; Poll::Ready(res) @@ -60,11 +65,11 @@ where self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>> { - let mut this = self; - let res = match Pin::new(&mut this.inner).poll_trailers(cx) { + let this = self.project(); + let res = match this.inner.poll_trailers(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(data)) => Ok(data), - Poll::Ready(Err(err)) => Err(LengthLimitError::Other(err)), + Poll::Ready(Err(err)) => Err(err.into()), }; Poll::Ready(res) @@ -93,38 +98,18 @@ where } } -/// An error returned when reading from a [`Limited`] body. +/// An error returned when body length exceeds the configured limit. #[derive(Debug)] -pub enum LengthLimitError { - /// The body exceeded the length limit. - LengthLimitExceeded, - /// Some other error was encountered while reading from the underlying body. - Other(E), -} +#[non_exhaustive] +pub struct LengthLimitError; -impl fmt::Display for LengthLimitError -where - E: fmt::Display, -{ +impl fmt::Display for LengthLimitError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::LengthLimitExceeded => f.write_str("length limit exceeded"), - Self::Other(err) => err.fmt(f), - } + f.write_str("length limit exceeded") } } -impl Error for LengthLimitError -where - E: Error, -{ - fn source(&self) -> Option<&(dyn Error + 'static)> { - match self { - Self::LengthLimitExceeded => None, - Self::Other(err) => err.source(), - } - } -} +impl Error for LengthLimitError {} #[cfg(test)] mod tests { @@ -162,7 +147,7 @@ mod tests { assert_eq!(body.size_hint().upper(), hint.upper()); let error = body.data().await.unwrap().unwrap_err(); - assert!(matches!(error, LengthLimitError::LengthLimitExceeded)); + assert!(matches!(error.downcast_ref(), Some(LengthLimitError))); } struct Chunky(&'static [&'static [u8]]); @@ -211,7 +196,7 @@ mod tests { assert_eq!(body.size_hint().upper(), hint.upper()); let error = body.data().await.unwrap().unwrap_err(); - assert!(matches!(error, LengthLimitError::LengthLimitExceeded)); + assert!(matches!(error.downcast_ref(), Some(LengthLimitError))); } #[tokio::test] @@ -225,7 +210,7 @@ mod tests { assert_eq!(body.size_hint().upper(), hint.upper()); let error = body.data().await.unwrap().unwrap_err(); - assert!(matches!(error, LengthLimitError::LengthLimitExceeded)); + assert!(matches!(error.downcast_ref(), Some(LengthLimitError))); } #[tokio::test] @@ -260,11 +245,20 @@ mod tests { assert_eq!(trailers, Some(HeaderMap::new())) } + #[derive(Debug)] enum ErrorBodyError { Data, Trailers, } + impl fmt::Display for ErrorBodyError { + fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { + Ok(()) + } + } + + impl Error for ErrorBodyError {} + struct ErrorBody; impl Body for ErrorBody { @@ -290,10 +284,7 @@ mod tests { async fn read_for_body_returning_error_propagates_error() { let body = &mut Limited::new(ErrorBody, 8); let error = body.data().await.unwrap().unwrap_err(); - assert!(matches!( - error, - LengthLimitError::Other(ErrorBodyError::Data) - )); + assert!(matches!(error.downcast_ref(), Some(ErrorBodyError::Data))); } #[tokio::test] @@ -301,8 +292,8 @@ mod tests { let body = &mut Limited::new(ErrorBody, 8); let error = body.trailers().await.unwrap_err(); assert!(matches!( - error, - LengthLimitError::Other(ErrorBodyError::Trailers) + error.downcast_ref(), + Some(ErrorBodyError::Trailers) )); } }